DiT (Diffusion Transformer)

DiT is a Transformer-based backbone for [[Diffusion Model|diffusion models]] that replaces the traditional [[U-Net]] architecture. Inspired by Vision Transformers (ViT), DiT patchifies the input image and processes it through a series of Transformer blocks with adaptive layer normalization (adaLN) for conditioning. Its main advantage is scalability: as model size and compute increase, sample quality keeps improving, and the largest DiT models outperform earlier U-Net baselines on class-conditional ImageNet.


1. Core Concept

1.1 From U-Net to Transformer

Traditional diffusion models (DDPM, Stable Diffusion) use a convolution-based [[U-Net]] as the denoising network $\epsilon_\theta(x_t, t)$. DiT asks a fundamental question:

Can we replace the inductive bias of convolutions with the scalability of Transformers?

Answer: Yes. At sufficient scale, Transformer backbones match or outperform U-Nets, as shown by DiT (Peebles & Xie, 2023) and U-ViT (Bao et al., 2022). DiT-style backbones have since been adopted in Stable Diffusion 3 and reported for SORA.

1.2 ViT-Inspired Design

DiT inherits the Vision Transformer paradigm:

DiT Architecture Overview
═══════════════════════════════════════════════════════
Noisy latent x_t (e.g., 32×32×4, VAE-encoded from a 256×256 image)
      │
      ├── Patchify: split into patches (e.g., 16×16 = 256 tokens for p=2)
      │     → Each patch = token of dimension d
      │
      ├── Positional Embedding (sinusoidal, learned, or RoPE)
      │
      ▼
┌─────────────────────────────────────────────────────┐
│  DiT Block × N (e.g., 28 blocks for DiT-XL)         │
│  ┌───────────────────────────────────────────────┐  │
│  │ adaLN (time + condition) → scale, shift, gate │  │
│  │                       ↓                       │  │
│  │           Multi-Head Self-Attention           │  │
│  │                       ↓                       │  │
│  │          adaLN → Pointwise FFN (MLP)          │  │
│  └───────────────────────────────────────────────┘  │
│  ... repeat N times ...                             │
└─────────────────────────────────────────────────────┘
      │
      ├── Final LayerNorm + Linear Projection + Unpatchify
      │
      ▼
Predicted Noise ε_θ(x_t, t, c) or Predicted v (velocity)
═══════════════════════════════════════════════════════

1.3 Key Design Principles

| Principle | U-Net Approach | DiT Approach |
|---|---|---|
| Spatial processing | Hierarchical convolutions | Global self-attention on patches |
| Multi-scale | Encoder-decoder with skip connections | Single-scale (all tokens at same resolution) |
| Conditioning | Scale-shift in ResBlocks, cross-attention | adaLN in every block |
| Inductive bias | Strong (locality, translation equivariance) | Weak (learned from data) |
| Scalability | Gains tend to flatten at larger sizes (roughly ~500M params as a rule of thumb) | Continues improving with scale |

2. Architecture in Detail

2.1 Patch Embedding

Images are split into non-overlapping patches, analogous to ViT:

import torch
import torch.nn as nn


class PatchEmbed(nn.Module):
    """Convert image to patch tokens for DiT."""

    def __init__(self, patch_size=2, in_channels=4, embed_dim=1152):
        super().__init__()
        self.patch_size = patch_size
        # A strided convolution applies one linear projection per patch
        self.proj = nn.Conv2d(
            in_channels, embed_dim,
            kernel_size=patch_size, stride=patch_size
        )

    def forward(self, x):
        # x: (B, C, H, W) → (B, N, D) where N = (H/p)*(W/p)
        x = self.proj(x)        # (B, D, H/p, W/p)
        x = x.flatten(2)        # (B, D, N)
        x = x.transpose(1, 2)   # (B, N, D)
        return x

Typical configuration for DiT-XL/2:

  • Input: latent space 32×32×4 (VAE-compressed 256×256 image)
  • Patch size: $p = 2$, giving $N = (32/2)^2 = 256$ tokens
  • Embedding dimension: $d = 1152$
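The token count in this configuration follows directly from the latent size and the patch size. A minimal sketch of that arithmetic (the helper name `num_patch_tokens` is hypothetical and not part of any DiT codebase):

```python
def num_patch_tokens(latent_size: int, patch_size: int) -> int:
    """Number of tokens for a square latent split into square patches."""
    if latent_size % patch_size != 0:
        raise ValueError("patch_size must divide latent_size evenly")
    per_side = latent_size // patch_size
    return per_side * per_side


# DiT-XL/2 on a 32×32 latent: 16 patches per side
tokens = num_patch_tokens(latent_size=32, patch_size=2)
print(tokens)  # 256
```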

2.2 DiT Block

The core DiT Block uses adaptive layer normalization (adaLN) for conditioning:

import torch.nn as nn


class DiTBlock(nn.Module):
    """A single DiT transformer block with adaLN conditioning."""

    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0):
        super().__init__()
        # LayerNorms carry no learned affine: scale/shift come from adaLN instead
        self.norm1 = nn.LayerNorm(hidden_size, elementwise_affine=False)
        self.attn = nn.MultiheadAttention(
            hidden_size, num_heads, batch_first=True
        )
        self.norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False)

        mlp_hidden = int(hidden_size * mlp_ratio)
        self.mlp = nn.Sequential(
            nn.Linear(hidden_size, mlp_hidden),
            nn.GELU(approximate='tanh'),
            nn.Linear(mlp_hidden, hidden_size),
        )

        # adaLN: regress scale, shift, gate from conditioning vector
        self.adaLN_modulation = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size),  # 6 params per block
        )

    def forward(self, x, c):
        # x: (B, N, D) tokens
        # c: (B, D) conditioning vector (time + class/text embedding)

        # Compute 6 modulation parameters, each of shape (B, D)
        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = \
            self.adaLN_modulation(c).chunk(6, dim=1)

        # Self-attention with adaLN (unsqueeze broadcasts over the N tokens)
        x_norm = self.norm1(x)
        x_modulated = x_norm * (1 + scale_msa.unsqueeze(1)) + shift_msa.unsqueeze(1)
        attn_out, _ = self.attn(x_modulated, x_modulated, x_modulated)
        x = x + gate_msa.unsqueeze(1) * attn_out

        # MLP with adaLN
        x_norm = self.norm2(x)
        x_modulated = x_norm * (1 + scale_mlp.unsqueeze(1)) + shift_mlp.unsqueeze(1)
        mlp_out = self.mlp(x_modulated)
        x = x + gate_mlp.unsqueeze(1) * mlp_out

        return x

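The 6 modulation parameters per block are:

$$[\beta_1, \gamma_1, \alpha_1, \beta_2, \gamma_2, \alpha_2] = \mathrm{adaLN}(c)$$

Applied as:

$$x = x + \alpha_1 \, \mathrm{Attn}\big((1 + \gamma_1)\,\mathrm{LN}(x) + \beta_1\big)$$

$$x = x + \alpha_2 \, \mathrm{MLP}\big((1 + \gamma_2)\,\mathrm{LN}(x) + \beta_2\big)$$

where $\gamma$ scales (applied as $1 + \gamma$, matching the code above), $\beta$ shifts, and $\alpha$ gates the residual path.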
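To make the broadcasting in the modulation step concrete, here is a minimal NumPy sketch of one adaLN-modulated residual update. It follows the `(1 + scale)` convention of the `DiTBlock` code above. The function names are illustrative, and `np.tanh` is only a stand-in for the attention or MLP branch.

```python
import numpy as np


def layer_norm(x, eps=1e-6):
    """LayerNorm over the last axis, without learned affine parameters."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def adaln_residual(x, shift, scale, gate, fn):
    """Compute x + gate * fn((1 + scale) * LN(x) + shift).

    x: (B, N, D) tokens; shift, scale, gate: (B, D), broadcast over the N tokens.
    """
    h = layer_norm(x) * (1.0 + scale[:, None, :]) + shift[:, None, :]
    return x + gate[:, None, :] * fn(h)


rng = np.random.default_rng(0)
x = rng.normal(size=(2, 4, 8))
shift, scale = rng.normal(size=(2, 8)), rng.normal(size=(2, 8))
gate = np.ones((2, 8))

out = adaln_residual(x, shift, scale, gate, fn=np.tanh)  # np.tanh stands in for Attn/MLP
print(out.shape)  # (2, 4, 8)
```

Because the same `(B, D)` parameters are broadcast over all tokens, every token in a sample receives identical modulation.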

2.3 Conditioning Vector Construction

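import math

import torch
import torch.nn as nn


class SinusoidalEmbedding(nn.Module):
    """Fixed sinusoidal timestep features.

    Minimal helper assumed by ConditioningEmbedder below; dim must be even.
    """

    def __init__(self, dim, max_period=10000):
        super().__init__()
        self.dim = dim
        self.max_period = max_period

    def forward(self, t):
        # t: (B,) timesteps → (B, dim)
        half = self.dim // 2
        freqs = torch.exp(
            -math.log(self.max_period)
            * torch.arange(half, dtype=torch.float32, device=t.device) / half
        )
        args = t[:, None].float() * freqs[None]
        return torch.cat([torch.cos(args), torch.sin(args)], dim=-1)


class ConditioningEmbedder(nn.Module):
    """Combine time and class/text embeddings into adaLN conditioning."""

    def __init__(self, hidden_size, num_classes=None):
        super().__init__()
        self.hidden_size = hidden_size

        # Time embedding (sinusoidal → MLP)
        self.time_embed = nn.Sequential(
            SinusoidalEmbedding(hidden_size),
            nn.Linear(hidden_size, hidden_size),
            nn.SiLU(),
            nn.Linear(hidden_size, hidden_size),
        )

        # Class embedding (optional)
        if num_classes is not None:
            self.class_embed = nn.Embedding(num_classes, hidden_size)
        else:
            self.class_embed = None

    def forward(self, t, class_labels=None):
        # Time conditioning: (B,) → (B, D)
        t_emb = self.time_embed(t)

        # Class conditioning (if available): summed with the time embedding
        if self.class_embed is not None and class_labels is not None:
            c_emb = self.class_embed(class_labels)
            return t_emb + c_emb

        return t_emb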

2.4 Full DiT Model

import torch
import torch.nn as nn


class DiT(nn.Module):
    """Full Diffusion Transformer."""

    def __init__(self, input_size=32, patch_size=2, in_channels=4,
                 hidden_size=1152, depth=28, num_heads=16,
                 mlp_ratio=4.0, num_classes=1000):
        super().__init__()

        self.input_size = input_size
        self.patch_size = patch_size
        self.hidden_size = hidden_size
        self.num_patches = (input_size // patch_size) ** 2

        # Patch embedding
        self.patch_embed = PatchEmbed(patch_size, in_channels, hidden_size)

        # Positional embedding (learned in this sketch; the original DiT
        # uses a fixed 2D sine/cosine table instead)
        self.pos_embed = nn.Parameter(
            torch.zeros(1, self.num_patches, hidden_size)
        )

        # Conditioning embedder
        self.cond_embed = ConditioningEmbedder(hidden_size, num_classes)

        # Transformer blocks
        self.blocks = nn.ModuleList([
            DiTBlock(hidden_size, num_heads, mlp_ratio)
            for _ in range(depth)
        ])

        # Final layer
        self.final_norm = nn.LayerNorm(hidden_size, elementwise_affine=False)
        self.final_adaLN = nn.Sequential(
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * hidden_size),
        )
        self.final_proj = nn.Linear(
            hidden_size, patch_size * patch_size * in_channels
        )

        self.initialize_weights()

    def initialize_weights(self):
        # Small random init for the learned positional embedding
        nn.init.normal_(self.pos_embed, std=0.02)

        # Zero-initialize final adaLN and projection
        nn.init.constant_(self.final_adaLN[-1].weight, 0)
        nn.init.constant_(self.final_adaLN[-1].bias, 0)
        nn.init.constant_(self.final_proj.weight, 0)
        nn.init.constant_(self.final_proj.bias, 0)

        # Zero-initialize the adaLN modulation layer of every block (adaLN-Zero)
        for block in self.blocks:
            nn.init.constant_(block.adaLN_modulation[-1].weight, 0)
            nn.init.constant_(block.adaLN_modulation[-1].bias, 0)

    def forward(self, x, t, class_labels=None):
        # x: (B, C, H, W)
        B, C, H, W = x.shape

        # Patchify
        x = self.patch_embed(x)  # (B, N, D)
        x = x + self.pos_embed[:, :x.shape[1]]

        # Conditioning
        c = self.cond_embed(t, class_labels)  # (B, D)

        # Transformer blocks
        for block in self.blocks:
            x = block(x, c)

        # Final layer with adaLN
        shift, scale = self.final_adaLN(c).chunk(2, dim=1)
        x = self.final_norm(x) * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
        x = self.final_proj(x)  # (B, N, p²·C)

        # Unpatchify back to image
        x = self.unpatchify(x, H, W)
        return x

    def unpatchify(self, x, H, W):
        p = self.patch_size
        h, w = H // p, W // p
        x = x.reshape(x.shape[0], h, w, p, p, -1)   # (B, h, w, p, p, C)
        x = x.permute(0, 5, 1, 3, 2, 4)             # (B, C, h, p, w, p)
        x = x.reshape(x.shape[0], -1, h * p, w * p)  # (B, C, H, W)
        return x
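The axis order in `unpatchify` is easy to get wrong, so it helps to check that it really inverts a patch split. The NumPy sketch below is an illustration under a stated assumption: a plain reshape-based `patchify` stands in for the learned patch projection, and both function names are hypothetical. It splits an array into patch tokens and restores it with the same reshape/permute sequence as `unpatchify` above.

```python
import numpy as np


def patchify(x, p):
    """(B, C, H, W) -> (B, N, p*p*C), with N = (H/p) * (W/p)."""
    B, C, H, W = x.shape
    h, w = H // p, W // p
    x = x.reshape(B, C, h, p, w, p)
    x = x.transpose(0, 2, 4, 3, 5, 1)      # (B, h, w, p, p, C)
    return x.reshape(B, h * w, p * p * C)


def unpatchify(tokens, p, H, W):
    """(B, N, p*p*C) -> (B, C, H, W), same axis order as DiT.unpatchify."""
    B = tokens.shape[0]
    h, w = H // p, W // p
    x = tokens.reshape(B, h, w, p, p, -1)
    x = x.transpose(0, 5, 1, 3, 2, 4)      # (B, C, h, p, w, p)
    return x.reshape(B, -1, h * p, w * p)


x = np.arange(2 * 4 * 8 * 8, dtype=np.float32).reshape(2, 4, 8, 8)
tokens = patchify(x, p=2)
restored = unpatchify(tokens, p=2, H=8, W=8)
print(tokens.shape)                 # (2, 16, 16)
print(np.array_equal(restored, x))  # True
```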

3. Conditioning Mechanisms

3.1 Adaptive Layer Normalization (adaLN)

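DiT conditions every block through adaLN; the best-performing variant in the DiT paper is adaLN-Zero, which adds a zero-initialized gate. It compares with other conditioning approaches as follows: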

| Conditioning Method | Mechanism | Pros | Cons |
|---|---|---|---|
| adaLN (DiT) | Regress scale/shift/gate from conditioning vector | Unified, compute-efficient, per-block | Same modulation for every token |
| In-context | Append condition tokens to sequence | Simple, flexible | Longer sequences |
| Cross-attention | Image tokens attend to condition tokens | Separate condition path, expressive | More parameters and compute |
| Add/Concat | Add or concat condition to features | Simple | Less expressive |

3.2 adaLN with Zero-Initialization

DiT uses zero-initialization for all adaLN output layers:

# Zero-init ensures identity function at initialization
# scale=0, shift=0, gate=0 → x_out = x (no modulation)
nn.init.constant_(adaLN_modulation[-1].weight, 0)
nn.init.constant_(adaLN_modulation[-1].bias, 0)

Because every gate is zero at initialization, each DiT block starts as the identity function. This stabilizes early training: the model gradually learns to modulate features rather than starting from random perturbations.
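The identity-at-initialization property can be checked numerically. The following NumPy sketch assumes a single linear adaLN output layer with all-zero weight and bias, and uses random values as a stand-in for the attention or MLP branch output. Variable names are illustrative.

```python
import numpy as np


def silu(z):
    return z / (1.0 + np.exp(-z))


D = 8
rng = np.random.default_rng(0)
x = rng.normal(size=(2, 4, D))   # tokens (B, N, D)
c = rng.normal(size=(2, D))      # conditioning vector (B, D)

# Zero-initialized adaLN output layer: weight and bias are all zeros
W = np.zeros((D, 6 * D))
b = np.zeros(6 * D)
shift, scale, gate, *_ = np.split(silu(c) @ W + b, 6, axis=-1)

branch = rng.normal(size=x.shape)        # stand-in for an arbitrary Attn/MLP output
x_out = x + gate[:, None, :] * branch    # gated residual update

print(np.array_equal(x_out, x))  # True
```

Whatever the branch computes, the zero gate removes it from the residual path, so gradients first have to move the modulation layer away from zero before a block changes its input.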

3.3 Conditioning Flow

Timestep t (int) ──→ Sinusoidal Embedding ──→ MLP ──→ t_emb (D-dim)
Class label y ─────→ Embedding ─────────────────────→ c_emb (D-dim)

                         t_emb + c_emb
                               │
                               ▼
                     Conditioning Vector c
                               │
               ┌───────────────┼───────────────┐
               ▼               ▼               ▼
          DiT Block 1     DiT Block 2     DiT Block N
        (adaLN params)  (adaLN params)  (adaLN params)

4. DiT Model Variants

4.1 DiT Family

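| Model | Hidden Dim | Depth | Heads | Params | FID-50K (ImageNet 256², 400K steps, no CFG) |
|---|---|---|---|---|---|
| DiT-S/2 | 384 | 12 | 6 | 33M | 68.40 |
| DiT-B/2 | 768 | 12 | 12 | 130M | 43.47 |
| DiT-L/2 | 1024 | 24 | 16 | 458M | 23.33 |
| DiT-XL/2 | 1152 | 28 | 16 | 675M | 19.47 |

With 7M training steps and classifier-free guidance (scale 1.5), DiT-XL/2 reaches FID 2.27.

Naming convention: DiT-{Size}/{Patch}, e.g., DiT-XL/2 = XL model, patch size 2.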

4.2 Patch Size Trade-off

| Patch Size | Tokens (32² latent) | FLOPs | Detail preservation | Best for |
|---|---|---|---|---|
| p=1 | 1024 | Very high | Maximum | Maximum quality (at cost) |
| p=2 | 256 | Moderate | Good | Default (best FLOP/quality trade-off) |
| p=4 | 64 | Low | Coarse | Fast prototyping |
| p=8 | 16 | Very low | Minimal | Ablation studies |

A smaller patch size means more tokens: halving $p$ quadruples the token count, and the attention cost grows quadratically with the number of tokens. The payoff is better detail.
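The trade-off can be quantified with a few lines of arithmetic. This sketch (the function name is hypothetical) reproduces the token counts from the table and reports the $N^2$ part of the attention cost relative to the default $p = 2$:

```python
def tokens_and_attention_cost(latent_size=32, patch_sizes=(1, 2, 4, 8), baseline=2):
    """Return {p: (num_tokens, attention_cost_relative_to_baseline)}."""
    base_tokens = (latent_size // baseline) ** 2
    table = {}
    for p in patch_sizes:
        n = (latent_size // p) ** 2
        # The N^2 term of self-attention scales with the square of the token count
        table[p] = (n, (n / base_tokens) ** 2)
    return table


for p, (n, rel) in tokens_and_attention_cost().items():
    print(f"p={p}: {n} tokens, attention N^2 cost = {rel:g}x of p=2")
```

Only the quadratic attention term is modeled here. The linear projections and the MLP scale with $N$ rather than $N^2$, so total FLOPs grow more slowly than this ratio suggests.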

4.3 U-ViT

An alternative Transformer backbone (Bao et al., 2022) that incorporates long skip connections between shallow and deep layers:

U-ViT Architecture:
Input Tokens ──→ Block 1 ──→ Block 2 ──→ ... ──→ Block N ──→ Output
                    │                               ▲
                    └────────── Long Skip ──────────┘
      (shallow-block output is concatenated with the deep-block input,
       then linearly projected back to the token dimension)
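A rough NumPy sketch of the long-skip wiring, under simplifying assumptions: `np.tanh` replaces real Transformer blocks, the projection matrices are random, and the function name `uvit_forward` is hypothetical. It shows the concatenate-then-project pattern, not U-ViT's actual implementation.

```python
import numpy as np


def uvit_forward(x, blocks, skip_projs):
    """Run blocks with U-ViT-style long skips (sketch).

    blocks: odd-length list of callables (B, N, D) -> (B, N, D):
            n shallow blocks, 1 middle block, n deep blocks.
    skip_projs: one (2D, D) matrix per deep block, applied after concatenation.
    """
    n = len(blocks) // 2
    skips = []
    for blk in blocks[:n]:                 # shallow half: remember each output
        x = blk(x)
        skips.append(x)
    x = blocks[n](x)                       # middle block
    for blk, W in zip(blocks[n + 1:], skip_projs):
        skip = skips.pop()                 # pair deep block i with shallow block n-1-i
        x = np.concatenate([x, skip], axis=-1) @ W   # (B, N, 2D) -> (B, N, D)
        x = blk(x)
    return x


D = 8
rng = np.random.default_rng(0)
blocks = [np.tanh] * 5                     # toy stand-ins for Transformer blocks
skip_projs = [rng.normal(size=(2 * D, D)) * 0.1 for _ in range(2)]
out = uvit_forward(rng.normal(size=(2, 4, D)), blocks, skip_projs)
print(out.shape)  # (2, 4, 8)
```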

Comparison with DiT:

| Aspect | DiT | U-ViT |
|---|---|---|
| Skip connections | None | Long skip (shallow → deep) |
| Implementation | Simpler (pure Transformer) | Extra concatenation layer |
| Performance | Better at large scale | Competitive at small scale |
| 3D extension | Straightforward | Needs adaptation |

5. Scaling Properties

5.1 DiT Scaling Law

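DiT exhibits approximately power-law scaling: sample quality improves predictably with model size and training compute.

$$\mathrm{FID} \propto C^{-\alpha}, \quad \alpha \approx 0.4$$

where $C$ is training compute (FLOPs). The exponent is an approximate fit; the DiT paper itself reports a strong negative correlation between model Gflops and FID rather than a fixed exponent.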
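As a worked example of what such a power law would imply, the sketch below takes $\alpha = 0.4$ at face value and computes the relative FID after scaling compute. The function name is hypothetical and the numbers illustrate the formula only; they are not measurements.

```python
def fid_ratio(compute_multiplier: float, alpha: float = 0.4) -> float:
    """Relative FID after scaling compute, assuming FID ∝ C^(-alpha)."""
    return compute_multiplier ** (-alpha)


print(round(fid_ratio(2.0), 3))   # 0.758
print(round(fid_ratio(10.0), 3))  # 0.398
```

Under this assumption, doubling compute would cut FID by roughly a quarter, and a tenfold increase would cut it by about 60%.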

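Trends reported for DiT-style models (the DiT paper itself scales model size and patch size jointly rather than depth and width in isolation):

| Scaling Factor | Observation |
|---|---|
| Model depth | Deeper → better, saturates slowly |
| Model width | Wider → better, saturates faster than depth |
| Training steps | More steps → better, no plateau at 7M steps |
| Data | DiT benefits more from data than U-Net |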

5.2 DiT vs. U-Net Scaling

| Metric | U-Net | DiT |
|---|---|---|
| Small scale (<100M) | ✅ Better (inductive bias helps) | ❌ Underperforms |
| Medium scale (100-500M) | ≈ Comparable | ≈ Comparable |
| Large scale (>500M) | ❌ Plateaus | ✅ Keeps improving |
| GFLOPs (forward) | Lower | Higher (quadratic attention) |
| Training stability | Good | Good (with zero-init) |

5.3 Why DiT Scales Better

  1. No architectural bottleneck: U-Net’s downsampling discards information; DiT preserves all tokens
  2. Global receptive field: Every token attends to every other token from the first block
  3. Homogeneous design: Same operation at every layer, easier to optimize
  4. Flexible conditioning: adaLN injects condition information uniformly through all blocks

6. DiT for Video: The SORA Architecture

6.1 From 2D to Spacetime Patches

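SORA (OpenAI, 2024) extends the DiT idea to video generation by treating video as a spacetime volume. OpenAI has described the approach only at a high level, so the code below is a SORA-style sketch rather than the actual implementation: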

import torch.nn as nn


class SpaceTimePatchEmbed(nn.Module):
    """3D patch embedding for video DiT (SORA-style)."""

    def __init__(self, patch_size_t=1, patch_size_h=2, patch_size_w=2,
                 in_channels=4, embed_dim=1152):
        super().__init__()
        self.patch_size = (patch_size_t, patch_size_h, patch_size_w)
        # A strided 3D convolution projects each spacetime patch to one token
        self.proj = nn.Conv3d(
            in_channels, embed_dim,
            kernel_size=self.patch_size,
            stride=self.patch_size,
        )

    def forward(self, x):
        # x: (B, C, T, H, W)
        x = self.proj(x)                  # (B, D, T/pt, H/ph, W/pw)
        x = x.flatten(2).transpose(1, 2)  # (B, N, D)
        return x

Key SORA design choices (as publicly described):

  • Spacetime patches (e.g., $1 \times 2 \times 2$; the actual patch size is not public) treat time and space jointly
  • Native variable-resolution and variable-duration training
  • Scalable to minute-long video generation
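Sequence length grows quickly once time is added. A small sketch of the token count for a latent video, assuming the illustrative $1 \times 2 \times 2$ patch size used above (the function name and the example clip size are hypothetical):

```python
def spacetime_tokens(frames, height, width, patch=(1, 2, 2)):
    """Token count for a (T, H, W) latent video split into (pt, ph, pw) patches."""
    pt, ph, pw = patch
    if frames % pt or height % ph or width % pw:
        raise ValueError("each patch size must divide its dimension evenly")
    return (frames // pt) * (height // ph) * (width // pw)


# 16 latent frames at 32×32 with 1×2×2 patches
print(spacetime_tokens(16, 32, 32))  # 4096
```

Even this short clip yields 16 times the 256 tokens of a single 32×32 latent image, which is why efficient attention matters for video.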

6.2 SORA vs. Image DiT

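| Aspect | Image DiT | Video DiT (SORA-style) |
|---|---|---|
| Patch dimension | 2D (h × w) | 3D (t × h × w) |
| Sequence length | 256 (32² / 2²) | Up to 10K+ tokens |
| Attention | Full self-attention | Efficient attention (flash, sparse) |
| Position encoding | 2D sine/cos | 3D RoPE or learned |
| Conditioning | Class label | Text (e.g., T5 encoder) |

The video column lists typical choices in SORA-style models; OpenAI has not published these details for SORA itself.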

7. Comparison with U-Net

7.1 Architectural Comparison

U-Net Backbone                          DiT Backbone
════════════════════                    ════════════════════
Input (image grid)                      Input (image grid)
    ↓                                       ↓
Input Conv + Embedding                  Patchify + Positional Embed
    ↓                                       ↓
ResBlock × 2 ──────────────────┐        DiT Block 1 (adaLN)
    ↓                          │            ↓
Downsample                     │        DiT Block 2 (adaLN)
    ↓                          │            ↓
ResBlock × 2 ──────────┐       │        DiT Block 3 (adaLN)
    ↓                  │       │            ↓
Downsample             │       │        ...
    ↓                  │       │            ↓
Bottleneck (Attn)      │       │        DiT Block N (adaLN)
    ↓                  │       │            ↓
Upsample + Concat ←────┘       │        Final LayerNorm + Proj
    ↓                          │            ↓
ResBlock × 2                   │        Unpatchify
    ↓                          │            ↓
Upsample + Concat ←────────────┘        Output (image grid)
    ↓
ResBlock × 2
    ↓
Output (image grid)
════════════════════                    ════════════════════

7.2 When to Use Each

| Scenario | Recommendation | Rationale |
|---|---|---|
| Small budget (<100M params) | [[U-Net]] | Convolutional inductive bias more data-efficient |
| Large budget (>500M params) | DiT | Transformer scaling law kicks in |
| High-resolution images | [[U-Net]] (with cascaded stages) | $O(N^2)$ attention cost in DiT |
| Text-to-image | Both viable | SD3 uses DiT, SDXL uses U-Net |
| Video generation | DiT / SORA-style | Spacetime patches handle temporal structure naturally |
| Multi-modal | DiT | Transformer’s flexibility with modalities |

7.3 Stable Diffusion 3 Architecture

Stable Diffusion 3 (2024) adopts a Multimodal Diffusion Transformer (MMDiT) that extends DiT:

  • Dual-stream: Separate weights for text and image tokens
  • Shared attention: Text and image tokens attend to each other
  • Rectified flow: Uses flow matching instead of DDPM noise prediction

$$v_\theta(x_t, t, c_{\text{text}}, c_{\text{image}}) = \mathrm{MMDiT}(x_t, t, c_{\text{text}}, c_{\text{image}})$$
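To illustrate what the velocity output is trained against, here is a minimal NumPy sketch of a rectified-flow training pair. It assumes the common straight-path convention $x_t = (1 - t)\,x_0 + t\,\epsilon$ with target $v = \epsilon - x_0$. The function name is hypothetical, and details such as timestep sampling and loss weighting are omitted.

```python
import numpy as np


def rectified_flow_pair(x0, eps, t):
    """Return (x_t, v_target) for the straight path x_t = (1 - t) * x0 + t * eps."""
    t = np.asarray(t).reshape(-1, *([1] * (x0.ndim - 1)))  # broadcast t over data dims
    x_t = (1.0 - t) * x0 + t * eps
    v_target = eps - x0            # d x_t / d t, constant along the path
    return x_t, v_target


rng = np.random.default_rng(0)
x0 = rng.normal(size=(3, 4, 8, 8))       # clean latents
eps = rng.normal(size=x0.shape)          # Gaussian noise
t = np.array([0.0, 0.5, 1.0])            # one time value per sample

x_t, v = rectified_flow_pair(x0, eps, t)
print(x_t.shape, v.shape)  # (3, 4, 8, 8) (3, 4, 8, 8)
```

The network output $v_\theta$ would then be regressed onto `v_target` with a mean-squared error, in place of the noise-prediction loss used by DDPM-style training.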

8. Practical Implementation

8.1 Training Configuration

# DiT-XL/2 training configuration (ImageNet 256²)
training_config = {
    "model": "DiT-XL/2",
    "params": 675_000_000,
    "input_size": 32,           # Latent space (VAE)
    "patch_size": 2,
    "batch_size": 256,
    "learning_rate": 1e-4,      # Constant; the DiT paper uses no warmup
    "weight_decay": 0.0,
    "optimizer": "AdamW",
    "beta1": 0.9,
    "beta2": 0.999,
    "training_steps": 7_000_000,
    "ema_decay": 0.9999,
    "mixed_precision": "fp16",  # Assumption of this config, not from the DiT paper
    "gradient_clip": 1.0,       # Assumption of this config, not from the DiT paper
    "vae": "sd-vae-ft-mse",     # Pre-trained VAE for latent encoding
}
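The `ema_decay` entry refers to an exponential moving average of the weights, which is typically what gets used for sampling. A minimal sketch of the update rule, with plain floats standing in for parameter tensors and a hypothetical function name:

```python
def ema_update(ema_params, params, decay=0.9999):
    """In-place EMA: ema = decay * ema + (1 - decay) * param, per parameter name."""
    for name, value in params.items():
        ema_params[name] = decay * ema_params[name] + (1.0 - decay) * value
    return ema_params


ema = {"w": 0.0}
for step in range(10_000):
    ema_update(ema, {"w": 1.0})   # pretend the raw weight is stuck at 1.0
print(round(ema["w"], 3))  # 0.632
```

With a decay of 0.9999 the average needs on the order of 10,000 steps to cover most of the distance to a new value, which is why EMA weights lag early in training.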

8.2 Inference Optimizations

| Technique | Speedup (approx.) | Memory Saving | Quality Impact |
|---|---|---|---|
| Flash Attention | 2-3× | 5-10× (smaller memory) | None (exact) |
| v-prediction | 1.2× (fewer steps) | - | Slight improvement |
| Classifier-free guidance | 2× cost (two forward passes per step, i.e. a slowdown) | - | Better quality |
| Token merging (ToMe) | 1.5× | 1.3× | Minor degradation |
| INT8 quantization | 1.3× | - | Slight degradation |
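The two forward passes of classifier-free guidance are combined by extrapolating from the unconditional output toward the conditional one. A minimal sketch, assuming the usual formulation with a guidance scale where 1.0 means "purely conditional" (the function name is hypothetical):

```python
import numpy as np


def cfg_combine(eps_uncond, eps_cond, scale):
    """Classifier-free guidance: move from the unconditional toward the conditional output."""
    return eps_uncond + scale * (eps_cond - eps_uncond)


eps_u = np.array([0.0, 1.0])   # model output with the condition dropped
eps_c = np.array([1.0, 3.0])   # model output with the condition

print(cfg_combine(eps_u, eps_c, 1.0).tolist())  # [1.0, 3.0]
print(cfg_combine(eps_u, eps_c, 1.5).tolist())  # [1.5, 4.0]
```

In practice both outputs are often computed in a single batched forward pass of size 2B, which costs the same compute but avoids a second kernel launch sequence.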

8.3 Common Pitfalls

| Pitfall | Symptom | Fix |
|---|---|---|
| Instability without zero-init | Training loss diverges early | Zero-initialize adaLN output layers |
| OOM with small patches | CUDA out of memory | Use a larger patch size or gradient checkpointing |
| Slow convergence | High FID after many steps | Check learning rate warmup, try $\text{lr} = 2 \times 10^{-4}$ |
| NaN in attention | Loss becomes NaN | Use fp32 for softmax, reduce learning rate |
| Patch boundary artifacts | Grid-like patterns in output | Ensure patch size divides input evenly |

9. Theoretical Properties

9.1 Expressiveness

DiT’s self-attention provides a global receptive field from the first block:

$$\mathrm{Attention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^\top}{\sqrt{d_k}}\right) V$$

Every patch token can directly attend to every other token, unlike U-Net’s hierarchical approach where global context only emerges at the bottleneck.
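A single-head NumPy sketch of this formula makes the "every token attends to every token" point visible: the weight matrix is dense and has one row per query token. Shapes and variable names are illustrative.

```python
import numpy as np


def attention(Q, K, V):
    """Scaled dot-product attention for inputs of shape (N, d_k)."""
    d_k = Q.shape[-1]
    scores = Q @ K.T / np.sqrt(d_k)                 # (N, N): every token vs. every token
    scores -= scores.max(axis=-1, keepdims=True)    # subtract row max for numerical stability
    weights = np.exp(scores)
    weights /= weights.sum(axis=-1, keepdims=True)  # softmax over keys
    return weights @ V, weights


rng = np.random.default_rng(0)
N, d_k = 16, 8
Q, K, V = (rng.normal(size=(N, d_k)) for _ in range(3))
out, weights = attention(Q, K, V)
print(out.shape, weights.shape)  # (16, 8) (16, 16)
```

Every entry of `weights` is strictly positive, so each output token is a mixture of all input tokens already in the first block.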

9.2 Computational Complexity

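For $N$ tokens (or spatial positions) and channel dimension $d$:

| Operation | U-Net | DiT |
|---|---|---|
| Per-block cost | $O(N k^2 d^2)$ (conv) | $O(N^2 d + N d^2)$ (attention) |
| Total blocks | ~200 (across all resolutions) | 28 (single resolution) |
| Dominant term | Convolution at high resolutions | Attention and MLP in every block |

For typical configurations ($N = 256$, $d = 1152$, $k = 3$):

$$N^2 d = 256^2 \cdot 1152 \approx 75\text{M} \quad \text{vs} \quad N k^2 d^2 = 256 \cdot 9 \cdot 1.3\text{M} \approx 3.1\text{B}$$

At equal $N$ and $d$, a $3 \times 3$ convolution is therefore the more expensive operation per block, so this comparison alone does not show DiT to be costlier. What differs is how the cost is distributed: a U-Net typically uses fewer channels at high resolution and reaches large widths only after downsampling, whereas DiT runs every block at full width on all $N$ tokens, and its $N^2 d$ term grows quadratically as resolution increases or patch size shrinks.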
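The order-of-magnitude estimates above are easy to check directly. This sketch counts only the leading terms and ignores constant factors, heads, and the MLP expansion ratio; the function names are hypothetical.

```python
def attention_cost(n_tokens, dim):
    """Leading terms of one Transformer block: N^2 * d (attention map), N * d^2 (projections)."""
    return n_tokens ** 2 * dim, n_tokens * dim ** 2


def conv_cost(n_positions, dim, kernel=3):
    """Leading term of one k×k convolution with `dim` input and output channels."""
    return n_positions * kernel ** 2 * dim ** 2


n2d, nd2 = attention_cost(256, 1152)
conv = conv_cost(256, 1152)
print(f"N^2 d     = {n2d:,}")   # 75,497,472
print(f"N d^2     = {nd2:,}")   # 339,738,624
print(f"N k^2 d^2 = {conv:,}")  # 3,057,647,616
```

At $N = 256$ the linear-projection term $N d^2$ is already larger than the quadratic term $N^2 d$. The quadratic term only dominates once $N$ exceeds $d$, for example at higher resolution or with smaller patches.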

9.3 adaLN as HyperNetwork

adaLN can be viewed as a hypernetwork that generates layer-specific parameters:

$$\theta^{(l)}_{\text{modulation}} = \mathrm{MLP}^{(l)}_{\text{adaLN}}(c)$$

This is more expressive than simple concatenation because it allows the conditioning signal to dynamically control the importance of each feature dimension per block.


10. Core Formula Cards

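| # | Formula | Meaning |
|---|---|---|
| 1 | $x = \mathrm{PatchEmbed}(x_t) + E_{\text{pos}}$ | Tokenization + position encoding |
| 2 | $[\beta_1, \gamma_1, \alpha_1, \beta_2, \gamma_2, \alpha_2] = \mathrm{MLP}(c)$ | adaLN modulation parameters from conditioning $c$ |
| 3 | $x = x + \alpha_1 \, \mathrm{Attn}((1 + \gamma_1)\,\mathrm{LN}(x) + \beta_1)$ | adaLN-modulated self-attention |
| 4 | $x = x + \alpha_2 \, \mathrm{MLP}((1 + \gamma_2)\,\mathrm{LN}(x) + \beta_2)$ | adaLN-modulated feed-forward |
| 5 | $\mathrm{FID} \propto C^{-0.4}$ | DiT scaling law ($C$ = compute), approximate |
| 6 | $\mathrm{Loss} = \lVert \epsilon - \epsilon_\theta(x_t, t, c) \rVert^2$ | Standard diffusion training objective |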

11. Summary

DiT represents a shift in diffusion model architecture, moving from the convolution-dominated U-Net to a pure Transformer design. Its key contributions:

  • adaLN conditioning: A unified, compute-efficient mechanism that injects time and class/text information into every Transformer block via learned scale/shift/gate parameters.
  • Scalability-first design: By removing convolutional inductive biases, DiT trades data efficiency at small scales for superior scaling behavior at large scales.
  • Architecture unification: DiT aligns diffusion models with the broader Transformer ecosystem (ViT, LLMs), enabling cross-domain techniques and infrastructure sharing.

DiT-style backbones underpin Stable Diffusion 3 and SORA and serve as a foundation for recent large generative models, which suggests that Transformers are a strong default backbone in the era of large-scale training.


  • [[Diffusion Model]]
  • [[U-Net]]
  • [[Flow Matching]]
  • [[Score Function]]
  • [[Neural ODE]]
  • [[ResNet]]
  • [[Stable Diffusion]]
  • [[ControlNet]]
  • [[DDIM]]
  • [[DPM-Solver]]
  • [[Vision Transformer (ViT)]]