DPM-Solver

DPM-Solver is a fast ordinary differential equation (ODE) solver specifically designed for [[Diffusion Model|Diffusion Probabilistic Models]] (DPMs). It exploits the semi-linear structure of the [[Probability Flow ODE]] to achieve high-order convergence with significantly fewer function evaluations, enabling fast sampling with only 10-20 steps.


1. Core Concept

1.1 Motivation

Standard [[Diffusion Model|diffusion models]] require 1000+ steps for high-quality sampling due to:

  • Discretization error in Euler-Maruyama method
  • Stiff dynamics near t=0
  • Random noise accumulation in [[Stochastic Differential Equation (SDE)|SDE]]-based sampling

DPM-Solver addresses this by:

  1. Using the deterministic [[Probability Flow ODE]] formulation
  2. Exploiting the semi-linear structure for analytical solutions
  3. Designing high-order solvers specifically for diffusion models

1.2 Key Innovation

The [[Probability Flow ODE]] has a semi-linear structure:

$$\frac{dx}{dt} = f(t)\,x - \frac{1}{2}g(t)^2\,s_\theta(x,t)$$

where:

  • $f(t)\,x$: Linear term (can be solved analytically)
  • $-\tfrac{1}{2}g(t)^2\,s_\theta(x,t)$: Nonlinear term (learned score function, requires numerical integration)

[!NOTE] Semi-linear Advantage
By solving the linear part exactly and only approximating the nonlinear part, DPM-Solver achieves much higher accuracy than general-purpose ODE solvers with the same number of function evaluations.
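A minimal sketch of why solving the linear part exactly helps, using a hypothetical scalar toy problem $dx/dt = a\,x + \cos t$ with a stiff coefficient $a = -50$ (not a diffusion model). Forward Euler is compared with an exponential integrator that integrates $a\,x$ exactly and freezes the nonlinear term over each step, the same idea DPM-Solver-1 applies to the diffusion ODE. The coefficient, forcing term, and step size are illustrative choices.

```python
import numpy as np

# Hypothetical toy problem: dx/dt = A_COEF * x + cos(t), stiff because A_COEF = -50
A_COEF = -50.0

def nonlinear(t):
    return np.cos(t)

def exact(t, x0):
    # Closed form: particular part c_cos*cos(t) + c_sin*sin(t) plus a decaying transient
    c_cos = -A_COEF / (1.0 + A_COEF**2)
    c_sin = 1.0 / (1.0 + A_COEF**2)
    return c_cos * np.cos(t) + c_sin * np.sin(t) + (x0 - c_cos) * np.exp(A_COEF * t)

def forward_euler(x0, h, n_steps):
    x, t = x0, 0.0
    for _ in range(n_steps):
        x = x + h * (A_COEF * x + nonlinear(t))
        t += h
    return x

def exponential_euler(x0, h, n_steps):
    # Linear part integrated exactly; nonlinear term frozen at the start of each step
    x, t = x0, 0.0
    for _ in range(n_steps):
        x = np.exp(A_COEF * h) * x + np.expm1(A_COEF * h) / A_COEF * nonlinear(t)
        t += h
    return x

x0, h, n = 1.0, 0.1, 20
err_euler = abs(forward_euler(x0, h, n) - exact(h * n, x0))
err_exp = abs(exponential_euler(x0, h, n) - exact(h * n, x0))
print(f"forward Euler error: {err_euler:.3e}, exponential integrator error: {err_exp:.3e}")
```

With $h = 0.1$, forward Euler multiplies errors by $1 + ah = -4$ every step and diverges, whereas the exponential update applies $e^{ah}$ exactly and stays close to the closed-form solution.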


2. Mathematical Foundation

2.1 [[Probability Flow ODE]] Recap

The forward [[Stochastic Differential Equation (SDE)|SDE]]:

$$dx = f(t)\,x\,dt + g(t)\,dW_t$$

has an equivalent [[Probability Flow ODE]]:

$$dx = \left[f(t)\,x - \frac{1}{2}g(t)^2\,\nabla_x \log p_t(x)\right]dt$$

2.2 Change of Variables

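For the variance-preserving (VP) SDE with noise schedule $\beta(t)$ (so $f(t) = -\tfrac{1}{2}\beta(t)$ and $g(t)^2 = \beta(t)$), define the schedule coefficients, whose ratio gives the signal-to-noise ratio $\alpha_t^2/\sigma_t^2$:

$$\alpha_t = \exp\left(-\frac{1}{2}\int_0^t \beta(s)\,ds\right), \qquad \sigma_t = \sqrt{1-\alpha_t^2}$$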
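As a concrete instance, here is a sketch for a linear schedule $\beta(s) = \beta_0 + s(\beta_1 - \beta_0)$ on $t \in [0, 1]$, with the integral evaluated in closed form. The values $\beta_0 = 0.1$, $\beta_1 = 20$ are an assumption (the usual VP-SDE defaults), not something this note prescribes.

```python
import numpy as np

# Hypothetical linear schedule beta(s) = beta_0 + s * (beta_1 - beta_0), s in [0, 1]
BETA_0, BETA_1 = 0.1, 20.0

def alpha(t):
    # alpha_t = exp(-1/2 * int_0^t beta(s) ds), with the integral in closed form
    integral = BETA_0 * t + 0.5 * (BETA_1 - BETA_0) * t**2
    return np.exp(-0.5 * integral)

def sigma(t):
    # Variance-preserving: alpha_t^2 + sigma_t^2 = 1
    return np.sqrt(1.0 - alpha(t) ** 2)

t = np.linspace(1e-3, 1.0, 5)
snr = alpha(t) ** 2 / sigma(t) ** 2  # signal-to-noise ratio, decreasing in t
print(alpha(t), sigma(t), snr)
```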

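Writing the score through the noise prediction network, $s_\theta(x_t,t) = -\epsilon_\theta(x_t,t)/\sigma_t$, and introducing the log-SNR $\lambda_t = \log(\alpha_t/\sigma_t)$, the ODE can be rewritten as:

$$\frac{dx_t}{dt} = \frac{d\log\alpha_t}{dt}\,x_t - \sigma_t\,\frac{d\lambda_t}{dt}\,\epsilon_\theta(x_t,t)$$

where $\epsilon_\theta(x_t,t)$ is the noise prediction network.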

2.3 Semi-linear Form

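Labeling the two terms:

$$\frac{dx_t}{dt} = \underbrace{\frac{d\log\alpha_t}{dt}\,x_t}_{\text{Linear}}\;\underbrace{-\;\sigma_t\,\frac{d\lambda_t}{dt}\,\epsilon_\theta(x_t,t)}_{\text{Nonlinear}}$$

This is a semi-linear ODE: the linear coefficient depends only on $t$, so that part can be integrated exactly, and only the network term needs to be approximated.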
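The $\epsilon$-form can be checked numerically against the score form of the probability flow ODE. This sketch assumes the linear VP schedule ($f = -\beta/2$, $g^2 = \beta$, with $\beta_0 = 0.1$, $\beta_1 = 20$) and toy Gaussian data $x_0 \sim \mathcal{N}(0, S^2 I)$, for which the exact noise predictor is $\epsilon(x,t) = \sigma_t x / (\alpha_t^2 S^2 + \sigma_t^2)$. The derivative $d\lambda_t/dt$ is taken by central differences rather than derived by hand; all names are illustrative.

```python
import numpy as np

BETA_0, BETA_1 = 0.1, 20.0  # hypothetical linear VP schedule
S = 0.5                     # toy data distribution N(0, S^2), an assumption

def beta(t):
    return BETA_0 + t * (BETA_1 - BETA_0)

def log_alpha(t):
    return -0.5 * (BETA_0 * t + 0.5 * (BETA_1 - BETA_0) * t**2)

def sigma(t):
    return np.sqrt(1.0 - np.exp(2.0 * log_alpha(t)))

def log_snr(t):
    # lambda_t = log(alpha_t / sigma_t)
    return log_alpha(t) - np.log(sigma(t))

def marginal_var(t):
    return np.exp(2.0 * log_alpha(t)) * S**2 + sigma(t) ** 2

def eps_exact(x, t):
    # Exact noise predictor for Gaussian data: eps = -sigma_t * grad log p_t(x)
    return sigma(t) * x / marginal_var(t)

def drift_score_form(x, t):
    # f(t) x - 1/2 g(t)^2 grad log p_t(x), with f = -beta/2, g^2 = beta, grad log p = -x / var
    return -0.5 * beta(t) * x - 0.5 * beta(t) * (-x / marginal_var(t))

def drift_eps_form(x, t, dt=1e-5):
    # d log(alpha_t)/dt * x - sigma_t * d lambda_t/dt * eps
    dlog_alpha = -0.5 * beta(t)
    dlam = (log_snr(t + dt) - log_snr(t - dt)) / (2.0 * dt)  # central difference
    return dlog_alpha * x - sigma(t) * dlam * eps_exact(x, t)

x = np.array([-1.0, 0.3, 2.0])
for t in (0.1, 0.5, 0.9):
    print(t, drift_score_form(x, t), drift_eps_form(x, t))
```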


3. DPM-Solver Algorithm

3.1 Exact Solution of Linear Part

The linear ODE $\frac{dx_t}{dt} = \frac{d\log\alpha_t}{dt}\,x_t$ has the exact solution:

$$x_{t_s} = \frac{\alpha_{t_s}}{\alpha_t}\,x_t$$

3.2 Variation of Constants Formula

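Using variation of constants, the full solution from time $t$ to an earlier time $t_s < t$ is:

$$x_{t_s} = \frac{\alpha_{t_s}}{\alpha_t}\,x_t - \alpha_{t_s}\int_t^{t_s} \frac{d\lambda_\tau}{d\tau}\,\frac{\sigma_\tau}{\alpha_\tau}\,\epsilon_\theta(x_\tau,\tau)\,d\tau$$

Define the step size in log-SNR space:

$$h = \lambda_{t_s} - \lambda_t = \log\frac{\alpha_{t_s}}{\sigma_{t_s}} - \log\frac{\alpha_t}{\sigma_t} > 0$$

Changing the integration variable to $\lambda$ (using $\sigma_\tau/\alpha_\tau = e^{-\lambda_\tau}$) then gives:

$$x_{t_s} = \frac{\alpha_{t_s}}{\alpha_t}\,x_t - \sigma_{t_s}\int_0^h e^{\tau}\,\epsilon_\theta(x_{t_\tau}, t_\tau)\,d\tau$$

where $t_\tau$ denotes the time with $\lambda_{t_\tau} = \lambda_{t_s} - \tau$.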

3.3 First-Order DPM-Solver (DPM-Solver-1)

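Approximate $\epsilon_\theta(x_{t_\tau}, t_\tau) \approx \epsilon_\theta(x_t, t)$ as constant over the step; since $\int_0^h e^{\tau}\,d\tau = e^h - 1$:

$$x_{t_s} = \frac{\alpha_{t_s}}{\alpha_t}\,x_t - \sigma_{t_s}\,(e^h - 1)\,\epsilon_\theta(x_t, t)$$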

This is equivalent to the DDIM update rule.
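A quick numerical check of the DDIM equivalence, written in log-SNR coordinates for a VP schedule ($\alpha^2 + \sigma^2 = 1$ gives $\alpha = \sqrt{\mathrm{sigmoid}(2\lambda)}$, $\sigma = \sqrt{\mathrm{sigmoid}(-2\lambda)}$); the helper names are illustrative. Deterministic DDIM ($\eta = 0$) first predicts $x_0 = (x_t - \sigma_t\epsilon)/\alpha_t$ and then re-noises it to the next level, which expands to the same expression as DPM-Solver-1.

```python
import numpy as np

def alpha_sigma(lam):
    # VP parameterization in log-SNR: alpha^2 + sigma^2 = 1, lam = log(alpha / sigma)
    alpha = np.sqrt(1.0 / (1.0 + np.exp(-2.0 * lam)))
    sigma = np.sqrt(1.0 / (1.0 + np.exp(2.0 * lam)))
    return alpha, sigma

def dpm_solver_1_step(x_t, eps, lam_t, lam_s):
    """One DPM-Solver-1 step from log-SNR lam_t to lam_s (lam_s > lam_t when denoising)."""
    alpha_t, _ = alpha_sigma(lam_t)
    alpha_s, sigma_s = alpha_sigma(lam_s)
    return (alpha_s / alpha_t) * x_t - sigma_s * np.expm1(lam_s - lam_t) * eps

def ddim_step(x_t, eps, lam_t, lam_s):
    """Deterministic DDIM (eta = 0): predict x_0, then re-noise it to the next level."""
    alpha_t, sigma_t = alpha_sigma(lam_t)
    alpha_s, sigma_s = alpha_sigma(lam_s)
    x0_pred = (x_t - sigma_t * eps) / alpha_t
    return alpha_s * x0_pred + sigma_s * eps

rng = np.random.default_rng(0)
x_t, eps = rng.standard_normal(4), rng.standard_normal(4)
print(dpm_solver_1_step(x_t, eps, -1.0, 0.5))
print(ddim_step(x_t, eps, -1.0, 0.5))
```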

3.4 Second-Order DPM-Solver (DPM-Solver-2)

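Use a linear approximation of $\epsilon_\theta$ in $\lambda$, obtained from one extra evaluation at the log-SNR midpoint:

  1. Intermediate step: compute $x_{t_m}$ at the time $t_m$ with $\lambda_{t_m} = \tfrac{1}{2}(\lambda_t + \lambda_{t_s})$ (a DPM-Solver-1 step of size $h/2$):
$$x_{t_m} = \frac{\alpha_{t_m}}{\alpha_t}\,x_t - \sigma_{t_m}\,(e^{h/2} - 1)\,\epsilon_\theta(x_t, t)$$
  2. Full step: integrate the linear part exactly and evaluate the noise term at the midpoint (midpoint rule):
$$x_{t_s} = \frac{\alpha_{t_s}}{\alpha_t}\,x_t - \sigma_{t_s}\,(e^{h} - 1)\,\epsilon_\theta(x_{t_m}, t_m)$$

Function evaluations: 2 per step (at $t$ and $t_m$)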
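A sketch of DPM-Solver-2 on a toy problem where the exact answer is known: for Gaussian data $x_0 \sim \mathcal{N}(0, S^2)$ the probability flow ODE simply rescales $x$ by the marginal standard deviation, and the exact noise predictor stands in for a trained network. The Gaussian data, $S = 0.5$, the $\lambda$ range, and all helper names are assumptions of the demo. Working directly in $\lambda$ avoids needing $t$ at all, since for a VP schedule $\alpha$ and $\sigma$ are functions of $\lambda$ alone.

```python
import numpy as np

S = 0.5  # toy data ~ N(0, S^2): the exact eps-prediction is known in closed form

def alpha_sigma(lam):
    # VP parameterization in log-SNR: alpha^2 + sigma^2 = 1, lam = log(alpha / sigma)
    alpha = np.sqrt(1.0 / (1.0 + np.exp(-2.0 * lam)))
    sigma = np.sqrt(1.0 / (1.0 + np.exp(2.0 * lam)))
    return alpha, sigma

def eps_model(x, lam):
    # Exact noise predictor for Gaussian data (stands in for a trained network)
    alpha, sigma = alpha_sigma(lam)
    return sigma * x / (alpha**2 * S**2 + sigma**2)

def dpm_solver_1_step(x, lam_t, lam_s):
    alpha_t, _ = alpha_sigma(lam_t)
    alpha_s, sigma_s = alpha_sigma(lam_s)
    return (alpha_s / alpha_t) * x - sigma_s * np.expm1(lam_s - lam_t) * eps_model(x, lam_t)

def dpm_solver_2_step(x, lam_t, lam_s):
    h = lam_s - lam_t
    lam_m = lam_t + 0.5 * h                      # midpoint in log-SNR
    alpha_t, _ = alpha_sigma(lam_t)
    alpha_m, sigma_m = alpha_sigma(lam_m)
    alpha_s, sigma_s = alpha_sigma(lam_s)
    eps_t = eps_model(x, lam_t)                  # NFE 1
    x_m = (alpha_m / alpha_t) * x - sigma_m * np.expm1(0.5 * h) * eps_t
    eps_m = eps_model(x_m, lam_m)                # NFE 2
    return (alpha_s / alpha_t) * x - sigma_s * np.expm1(h) * eps_m

def solve(step, n_steps, lam_start=-3.0, lam_end=3.0, x_start=1.0):
    x = x_start
    lams = np.linspace(lam_start, lam_end, n_steps + 1)
    for lam_t, lam_s in zip(lams[:-1], lams[1:]):
        x = step(x, lam_t, lam_s)
    return x

def exact(lam_start=-3.0, lam_end=3.0, x_start=1.0):
    # Gaussian data: the probability flow ODE rescales x by the marginal std ratio
    a0, s0 = alpha_sigma(lam_start)
    a1, s1 = alpha_sigma(lam_end)
    return x_start * np.sqrt((a1**2 * S**2 + s1**2) / (a0**2 * S**2 + s0**2))

for n in (12, 24):
    err1 = abs(solve(dpm_solver_1_step, n) - exact())
    err2 = abs(solve(dpm_solver_2_step, n) - exact())
    print(f"{n} steps: DPM-Solver-1 error {err1:.2e}, DPM-Solver-2 error {err2:.2e}")
```

Doubling the number of steps should shrink the DPM-Solver-2 error roughly fourfold, versus roughly twofold for DPM-Solver-1.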

3.5 Third-Order DPM-Solver (DPM-Solver-3)

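Use two intermediate points at $\lambda_t + r_1 h$ and $\lambda_t + r_2 h$ with $r_1 = \tfrac{1}{3}$, $r_2 = \tfrac{2}{3}$, each adding a finite-difference correction for the change of $\epsilon_\theta$ in $\lambda$:

  1. First intermediate point $t_{m_1}$ ($\lambda_{t_{m_1}} = \lambda_t + r_1 h$):
$$x_{t_{m_1}} = \frac{\alpha_{t_{m_1}}}{\alpha_t}\,x_t - \sigma_{t_{m_1}}\,(e^{r_1 h} - 1)\,\epsilon_\theta(x_t, t), \qquad D_1 = \epsilon_\theta(x_{t_{m_1}}, t_{m_1}) - \epsilon_\theta(x_t, t)$$
  2. Second intermediate point $t_{m_2}$ ($\lambda_{t_{m_2}} = \lambda_t + r_2 h$):
$$x_{t_{m_2}} = \frac{\alpha_{t_{m_2}}}{\alpha_t}\,x_t - \sigma_{t_{m_2}}\,(e^{r_2 h} - 1)\,\epsilon_\theta(x_t, t) - \sigma_{t_{m_2}}\,\frac{r_2}{r_1}\left(\frac{e^{r_2 h} - 1}{r_2 h} - 1\right) D_1, \qquad D_2 = \epsilon_\theta(x_{t_{m_2}}, t_{m_2}) - \epsilon_\theta(x_t, t)$$
  3. Final step:
$$x_{t_s} = \frac{\alpha_{t_s}}{\alpha_t}\,x_t - \sigma_{t_s}\,(e^{h} - 1)\,\epsilon_\theta(x_t, t) - \frac{\sigma_{t_s}}{r_2}\left(\frac{e^{h} - 1}{h} - 1\right) D_2$$

The correction terms come from integrating the expansion $\epsilon_\theta \approx \epsilon_\theta(x_t,t) + (\lambda - \lambda_t)\,\epsilon_\theta^{(1)}$ exactly against $e^{\tau}$, with the derivative estimated by the finite differences $D_i/(r_i h)$.

Function evaluations: 3 per step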
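The same toy setup can exercise DPM-Solver-3 with $r_1 = 1/3$, $r_2 = 2/3$ as written above. Again the Gaussian data ($S = 0.5$), the exact noise predictor used in place of a network, the $\lambda$ range, and all helper names are illustrative assumptions.

```python
import numpy as np

S = 0.5  # toy data ~ N(0, S^2); its exact eps-prediction stands in for a network

def alpha_sigma(lam):
    # VP parameterization in log-SNR: alpha^2 + sigma^2 = 1, lam = log(alpha / sigma)
    alpha = np.sqrt(1.0 / (1.0 + np.exp(-2.0 * lam)))
    sigma = np.sqrt(1.0 / (1.0 + np.exp(2.0 * lam)))
    return alpha, sigma

def eps_model(x, lam):
    alpha, sigma = alpha_sigma(lam)
    return sigma * x / (alpha**2 * S**2 + sigma**2)

def dpm_solver_1_step(x, lam_t, lam_s):
    alpha_t, _ = alpha_sigma(lam_t)
    alpha_s, sigma_s = alpha_sigma(lam_s)
    return (alpha_s / alpha_t) * x - sigma_s * np.expm1(lam_s - lam_t) * eps_model(x, lam_t)

def dpm_solver_3_step(x, lam_t, lam_s, r1=1.0 / 3.0, r2=2.0 / 3.0):
    h = lam_s - lam_t
    alpha_t, _ = alpha_sigma(lam_t)
    alpha_1, sigma_1 = alpha_sigma(lam_t + r1 * h)
    alpha_2, sigma_2 = alpha_sigma(lam_t + r2 * h)
    alpha_s, sigma_s = alpha_sigma(lam_s)

    eps_t = eps_model(x, lam_t)                                   # NFE 1
    x_m1 = (alpha_1 / alpha_t) * x - sigma_1 * np.expm1(r1 * h) * eps_t
    d1 = eps_model(x_m1, lam_t + r1 * h) - eps_t                  # NFE 2
    x_m2 = ((alpha_2 / alpha_t) * x - sigma_2 * np.expm1(r2 * h) * eps_t
            - sigma_2 * (r2 / r1) * (np.expm1(r2 * h) / (r2 * h) - 1.0) * d1)
    d2 = eps_model(x_m2, lam_t + r2 * h) - eps_t                  # NFE 3
    return ((alpha_s / alpha_t) * x - sigma_s * np.expm1(h) * eps_t
            - (sigma_s / r2) * (np.expm1(h) / h - 1.0) * d2)

def solve(step, n_steps, lam_start=-3.0, lam_end=3.0, x_start=1.0):
    x = x_start
    lams = np.linspace(lam_start, lam_end, n_steps + 1)
    for lam_t, lam_s in zip(lams[:-1], lams[1:]):
        x = step(x, lam_t, lam_s)
    return x

def exact(lam_start=-3.0, lam_end=3.0, x_start=1.0):
    # Gaussian data: the flow rescales x by the ratio of marginal standard deviations
    a0, s0 = alpha_sigma(lam_start)
    a1, s1 = alpha_sigma(lam_end)
    return x_start * np.sqrt((a1**2 * S**2 + s1**2) / (a0**2 * S**2 + s0**2))

for n in (6, 12):
    err1 = abs(solve(dpm_solver_1_step, n) - exact())
    err3 = abs(solve(dpm_solver_3_step, n) - exact())
    print(f"{n} steps: DPM-Solver-1 error {err1:.2e}, DPM-Solver-3 error {err3:.2e}")
```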


4. Algorithm Summary

4.1 DPM-Solver Pseudocode

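```python
import numpy as np

# Schedule assumed here: linear VP, beta(t) = beta_0 + t * (beta_1 - beta_0) on t in [0, 1]
BETA_0, BETA_1 = 0.1, 20.0

def log_alpha(t):
    return -0.5 * (BETA_0 * t + 0.5 * (BETA_1 - BETA_0) * t**2)

def alpha(t):
    return np.exp(log_alpha(t))

def sigma(t):
    return np.sqrt(-np.expm1(2.0 * log_alpha(t)))

def log_snr(t):
    # lambda_t = log(alpha_t / sigma_t)
    return log_alpha(t) - np.log(sigma(t))

def inverse_log_snr(lam):
    # Closed-form inverse of lambda_t for the linear VP schedule
    L = np.logaddexp(0.0, -2.0 * lam)  # = -2 log(alpha_t)
    return 2.0 * L / (BETA_0 + np.sqrt(BETA_0**2 + 2.0 * (BETA_1 - BETA_0) * L))

# DPM-Solver-2 Sampling Algorithm
def dpm_solver_2_sample(x_T, model, timesteps):
    """
    x_T: Initial noise ~ N(0, I)
    model: Noise prediction network epsilon_theta(x, t)
    timesteps: Decreasing time points [T, ..., t_min] in (0, 1]
    """
    x_t = x_T
    for t, t_next in zip(timesteps[:-1], timesteps[1:]):
        # Step size in log-SNR space (positive: SNR grows as t decreases)
        h = log_snr(t_next) - log_snr(t)

        # First evaluation at current point
        eps_t = model(x_t, t)

        # Midpoint in log-SNR space, mapped back to time
        t_mid = inverse_log_snr(log_snr(t) + 0.5 * h)
        x_mid = (alpha(t_mid) / alpha(t)) * x_t - sigma(t_mid) * np.expm1(0.5 * h) * eps_t

        # Second evaluation at midpoint
        eps_mid = model(x_mid, t_mid)

        # Final update uses the midpoint noise estimate
        x_t = (alpha(t_next) / alpha(t)) * x_t - sigma(t_next) * np.expm1(h) * eps_mid

    return x_t
```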

4.2 Order Comparison

| Solver | Order | Function Evaluations | Steps Needed | Total NFE |
|---|---|---|---|---|
| Euler | 1st | 1 per step | 50-100 | 50-100 |
| DPM-Solver-1 | 1st | 1 per step | 50-100 | 50-100 |
| DPM-Solver-2 | 2nd | 2 per step | 10-20 | 20-40 |
| DPM-Solver-3 | 3rd | 3 per step | 10-15 | 30-45 |
| RK4 | 4th | 4 per step | 20-30 | 80-120 |

[!TIP] Efficiency Insight
DPM-Solver-2 achieves high quality with only 20-40 function evaluations, compared to 1000+ for [[Diffusion Model|DDPM]] or 100+ for general-purpose ODE solvers.


5. Advanced Variants

5.1 DPM-Solver++

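Improvements over DPM-Solver:

  • Solves the diffusion ODE with the data-prediction model instead of the noise-prediction model
  • More stable for guided sampling with large guidance scales
  • Compatible with thresholding methods
  • Singlestep (DPM-Solver++(2S)) and multistep (DPM-Solver++(2M)) variants

Key Innovation: Use $x_0$-prediction instead of $\epsilon$-prediction. The data prediction is obtained from the noise prediction by inverting $x_t = \alpha_t x_0 + \sigma_t \epsilon$:

$$x_\theta(x_t,t) = \frac{1}{\alpha_t}\,x_t - \frac{\sigma_t}{\alpha_t}\,\epsilon_\theta(x_t,t)$$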
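In DPM-Solver++ the first-order update is written in terms of the data prediction: $x_{t_s} = \frac{\sigma_{t_s}}{\sigma_t}x_t - \alpha_{t_s}(e^{-h} - 1)\,x_\theta(x_t,t)$. The sketch below (VP schedule in log-SNR coordinates, helper names hypothetical) checks that, when $x_\theta$ is derived from $\epsilon_\theta$ with the conversion above, this matches DPM-Solver-1. The two forms differ at higher orders and once $x_\theta$ is thresholded or clipped.

```python
import numpy as np

def alpha_sigma(lam):
    # VP parameterization in log-SNR: alpha^2 + sigma^2 = 1, lam = log(alpha / sigma)
    alpha = np.sqrt(1.0 / (1.0 + np.exp(-2.0 * lam)))
    sigma = np.sqrt(1.0 / (1.0 + np.exp(2.0 * lam)))
    return alpha, sigma

def x0_from_eps(x_t, eps, lam_t):
    alpha_t, sigma_t = alpha_sigma(lam_t)
    return (x_t - sigma_t * eps) / alpha_t

def dpm_solver_1_step(x_t, eps, lam_t, lam_s):
    # Noise-prediction form: x_s = (alpha_s / alpha_t) x_t - sigma_s (e^h - 1) eps
    alpha_t, _ = alpha_sigma(lam_t)
    alpha_s, sigma_s = alpha_sigma(lam_s)
    return (alpha_s / alpha_t) * x_t - sigma_s * np.expm1(lam_s - lam_t) * eps

def dpm_solver_pp_1_step(x_t, x0_pred, lam_t, lam_s):
    # Data-prediction form: x_s = (sigma_s / sigma_t) x_t - alpha_s (e^{-h} - 1) x0_pred
    _, sigma_t = alpha_sigma(lam_t)
    alpha_s, sigma_s = alpha_sigma(lam_s)
    return (sigma_s / sigma_t) * x_t - alpha_s * np.expm1(-(lam_s - lam_t)) * x0_pred

rng = np.random.default_rng(0)
x_t, eps = rng.standard_normal(4), rng.standard_normal(4)
x0_pred = x0_from_eps(x_t, eps, -0.5)
print(dpm_solver_1_step(x_t, eps, -0.5, 1.0))
print(dpm_solver_pp_1_step(x_t, x0_pred, -0.5, 1.0))
```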

5.2 DPM-Solver-Adaptive

Adaptive Step Size Control:

  1. Estimate local truncation error using embedded methods
  2. Adjust step size h based on error tolerance
  3. Accept/reject steps dynamically

Error Estimation (for DPM-Solver-2):

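$$\text{Error} \approx \left\| x_{t_s}^{(2)} - x_{t_s}^{(1)} \right\|$$

where $x_{t_s}^{(2)}$ is the DPM-Solver-2 result and $x_{t_s}^{(1)}$ is the DPM-Solver-1 result from the same starting point; both share the first evaluation $\epsilon_\theta(x_t, t)$, so the estimate costs no extra network call.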
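A minimal sketch of the embedded pair, assuming a VP schedule in log-SNR coordinates and the toy Gaussian data with its exact noise predictor used elsewhere in this note. DPM-Solver-1 and DPM-Solver-2 share the first evaluation, and their difference serves as the error estimate. The step-size rule $h \leftarrow h\,\min(2, \max(0.2,\, 0.9\,E^{-1/2}))$ is a standard embedded-method controller chosen for illustration; the tolerances, safety factor, clipping bounds, and the helper name `adaptive_dpm_solver_12` are assumptions, not values from the paper.

```python
import numpy as np

S = 0.5  # toy data ~ N(0, S^2) with exact eps-prediction (demo assumption)

def alpha_sigma(lam):
    # VP parameterization in log-SNR: alpha^2 + sigma^2 = 1, lam = log(alpha / sigma)
    alpha = np.sqrt(1.0 / (1.0 + np.exp(-2.0 * lam)))
    sigma = np.sqrt(1.0 / (1.0 + np.exp(2.0 * lam)))
    return alpha, sigma

def eps_model(x, lam):
    alpha, sigma = alpha_sigma(lam)
    return sigma * x / (alpha**2 * S**2 + sigma**2)

def adaptive_dpm_solver_12(x, lam_start, lam_end, h=0.5, rtol=1e-3, atol=1e-6, safety=0.9):
    """Embedded DPM-Solver-1/2 pair in log-SNR with accept/reject (hypothetical helper)."""
    lam, n_accept, n_reject = lam_start, 0, 0
    while lam < lam_end:
        h = min(h, lam_end - lam)
        alpha_t, _ = alpha_sigma(lam)
        alpha_m, sigma_m = alpha_sigma(lam + 0.5 * h)
        alpha_s, sigma_s = alpha_sigma(lam + h)

        eps_t = eps_model(x, lam)  # shared by both solvers
        x_low = (alpha_s / alpha_t) * x - sigma_s * np.expm1(h) * eps_t  # DPM-Solver-1
        x_mid = (alpha_m / alpha_t) * x - sigma_m * np.expm1(0.5 * h) * eps_t
        eps_mid = eps_model(x_mid, lam + 0.5 * h)
        x_high = (alpha_s / alpha_t) * x - sigma_s * np.expm1(h) * eps_mid  # DPM-Solver-2

        # Scaled error estimate ||x^(2) - x^(1)||: err <= 1 means within tolerance
        scale = atol + rtol * np.maximum(np.abs(x_low), np.abs(x))
        err = np.sqrt(np.mean(((x_high - x_low) / scale) ** 2))
        if err <= 1.0:
            x, lam = x_high, lam + h  # accept and keep the higher-order result
            n_accept += 1
        else:
            n_reject += 1             # reject; retry from the same point with a smaller h
        # First-order local error scales like h^2, hence the exponent -1/2
        h *= min(2.0, max(0.2, safety * (err + 1e-12) ** -0.5))
    return x, n_accept, n_reject

x_final, n_acc, n_rej = adaptive_dpm_solver_12(1.0, -3.0, 3.0)
a0, s0 = alpha_sigma(-3.0)
a1, s1 = alpha_sigma(3.0)
x_exact = np.sqrt((a1**2 * S**2 + s1**2) / (a0**2 * S**2 + s0**2))
print(f"x = {x_final:.6f} (exact {x_exact:.6f}), accepted {n_acc}, rejected {n_rej}")
```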

5.3 DPM-Solver with Correctors

Predictor-Corrector Framework:

  1. Predictor: Take one DPM-Solver step
  2. Corrector: Apply few steps of Langevin dynamics
  3. Repeat: For enhanced sample quality
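
```python
import numpy as np

# Predictor-Corrector with DPM-Solver
# dpm_solver_2_step(x, t, t_next, model) and sigma(t) are supplied by the caller
def predictor_corrector_sample(x_t, model, timesteps, dpm_solver_2_step, sigma,
                               num_corrector_steps=1, step_size=1e-3, rng=None):
    rng = np.random.default_rng() if rng is None else rng
    for t, t_next in zip(timesteps[:-1], timesteps[1:]):
        # Predictor: DPM-Solver-2 step from t to t_next
        x_pred = dpm_solver_2_step(x_t, t, t_next, model)

        # Corrector: Langevin dynamics at t_next (optional)
        for _ in range(num_corrector_steps):
            score = -model(x_pred, t_next) / sigma(t_next)  # s_theta = -eps_theta / sigma_t
            noise = rng.standard_normal(np.shape(x_pred))
            x_pred = x_pred + step_size * score + np.sqrt(2.0 * step_size) * noise

        x_t = x_pred
    return x_t
```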

6. Theoretical Analysis

6.1 Convergence Order

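Theorem: Under smoothness assumptions on $\epsilon_\theta$, DPM-Solver-$k$ has global convergence order $k$: the final error is $O(h_{\max}^k)$, where $h_{\max}$ is the largest step size in log-SNR space.

Proof Sketch:

  1. Expand $\epsilon_\theta(x_{t_\tau}, t_\tau)$ in a Taylor series in $\lambda$
  2. Keep terms up to order $k-1$ and integrate them exactly against $e^{\tau}$
  3. Show the local truncation error is $O(h^{k+1})$; summing over $O(1/h)$ steps gives $O(h^k)$ globally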
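The convergence order can be observed empirically: halving the step size should divide the global error by about $2^k$, so $\log_2(\text{err}(h)/\text{err}(h/2)) \approx k$. This sketch reuses the toy Gaussian setup (exact noise predictor, VP schedule in log-SNR coordinates, $S = 0.5$), which is an assumption for illustration rather than a real model.

```python
import numpy as np

S = 0.5  # toy data ~ N(0, S^2); exact eps-prediction replaces a trained network

def alpha_sigma(lam):
    # VP parameterization in log-SNR: alpha^2 + sigma^2 = 1, lam = log(alpha / sigma)
    alpha = np.sqrt(1.0 / (1.0 + np.exp(-2.0 * lam)))
    sigma = np.sqrt(1.0 / (1.0 + np.exp(2.0 * lam)))
    return alpha, sigma

def eps_model(x, lam):
    alpha, sigma = alpha_sigma(lam)
    return sigma * x / (alpha**2 * S**2 + sigma**2)

def dpm_solver_1_step(x, lam_t, lam_s):
    alpha_t, _ = alpha_sigma(lam_t)
    alpha_s, sigma_s = alpha_sigma(lam_s)
    return (alpha_s / alpha_t) * x - sigma_s * np.expm1(lam_s - lam_t) * eps_model(x, lam_t)

def dpm_solver_2_step(x, lam_t, lam_s):
    h = lam_s - lam_t
    alpha_t, _ = alpha_sigma(lam_t)
    alpha_m, sigma_m = alpha_sigma(lam_t + 0.5 * h)
    alpha_s, sigma_s = alpha_sigma(lam_s)
    x_m = (alpha_m / alpha_t) * x - sigma_m * np.expm1(0.5 * h) * eps_model(x, lam_t)
    return (alpha_s / alpha_t) * x - sigma_s * np.expm1(h) * eps_model(x_m, lam_t + 0.5 * h)

def solve(step, n_steps, lam_start=-3.0, lam_end=3.0, x_start=1.0):
    x = x_start
    lams = np.linspace(lam_start, lam_end, n_steps + 1)
    for lam_t, lam_s in zip(lams[:-1], lams[1:]):
        x = step(x, lam_t, lam_s)
    return x

# Exact answer: the flow rescales x by the ratio of marginal standard deviations
a0, s0 = alpha_sigma(-3.0)
a1, s1 = alpha_sigma(3.0)
x_exact = np.sqrt((a1**2 * S**2 + s1**2) / (a0**2 * S**2 + s0**2))

orders = {}
for name, step in (("DPM-Solver-1", dpm_solver_1_step), ("DPM-Solver-2", dpm_solver_2_step)):
    err_coarse = abs(solve(step, 40) - x_exact)
    err_fine = abs(solve(step, 80) - x_exact)
    orders[name] = np.log2(err_coarse / err_fine)  # observed order of convergence
print(orders)
```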

6.2 Stability Analysis

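Linear Stability: For the linear test equation $\frac{dx}{dt} = \mu x$ (written with $\mu$ to avoid a clash with the log-SNR $\lambda_t$), a one-step method gives $x_{n+1} = R(h\mu)\,x_n$ and is stable if:

$$|R(h\mu)| \le 1$$

where $R(z)$ is the stability function (for forward Euler, $R(z) = 1 + z$).

DPM-Solver Advantage: When the stiff term is the linear part, DPM-Solver integrates it exactly ($R(z) = e^{z}$, so $|R| \le 1$ whenever $\operatorname{Re}(z) \le 0$), so that term does not restrict the step size the way it does for explicit methods such as forward Euler or classical Runge-Kutta.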

6.3 Error Decomposition

Total error consists of:

  1. Discretization error: From numerical integration ($O(h^k)$)
  2. Score approximation error: From imperfect score model ($O(\epsilon_{\text{score}})$)
  3. Accumulated error: Propagated through multiple steps

Key Insight: Higher-order methods reduce discretization error, making score approximation error dominant.


7. Practical Implementation

7.1 Time Schedule Design

Uniform vs Non-uniform Schedules:

| Schedule | Step Distribution | Best For |
|---|---|---|
| Uniform | Equal spacing in $t$ | Simple implementation |
| Log-SNR | Equal spacing in $\lambda_t$ (more steps near $t=0$) | Better accuracy |
| Adaptive | Dynamic, based on error estimates | Optimal efficiency |

Recommended Schedule (for 20 steps):

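```python
import numpy as np

def get_time_schedule(log_snr, inverse_log_snr, N=20, t_min=1e-5, t_max=1.0):
    # Uniform spacing in log-SNR lambda_t (the "Log-SNR" row above);
    # log_snr / inverse_log_snr are the schedule's lambda(t) and its inverse
    lambdas = np.linspace(log_snr(t_max), log_snr(t_min), N + 1)
    # Map back to time: N steps need N + 1 time points, ordered T -> t_min
    return inverse_log_snr(lambdas)
```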
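For the linear VP schedule, both $\lambda_t$ and its inverse have closed forms (assuming again $\beta_0 = 0.1$, $\beta_1 = 20$): since $\log(1 + e^{-2\lambda_t}) = -2\log\alpha_t = \beta_0 t + \tfrac{1}{2}(\beta_1-\beta_0)t^2$, solving the quadratic for $t$ inverts $\lambda$. The sketch builds a 20-step log-SNR schedule directly, and checks the round trip and that the resulting time steps are much denser near $t = 0$. Function names are illustrative.

```python
import numpy as np

BETA_0, BETA_1 = 0.1, 20.0  # hypothetical linear VP schedule on t in [0, 1]

def log_snr(t):
    # lambda_t = log(alpha_t / sigma_t), log alpha_t = -(beta_0 t + (beta_1 - beta_0) t^2 / 2) / 2
    log_alpha = -0.5 * (BETA_0 * t + 0.5 * (BETA_1 - BETA_0) * t**2)
    return log_alpha - 0.5 * np.log(-np.expm1(2.0 * log_alpha))

def inverse_log_snr(lam):
    # Solve (beta_1 - beta_0)/2 t^2 + beta_0 t = log(1 + e^{-2 lam}) for t >= 0
    L = np.logaddexp(0.0, -2.0 * lam)
    return 2.0 * L / (BETA_0 + np.sqrt(BETA_0**2 + 2.0 * (BETA_1 - BETA_0) * L))

N, t_min, t_max = 20, 1e-5, 1.0
lambdas = np.linspace(log_snr(t_max), log_snr(t_min), N + 1)  # uniform in lambda
timesteps = inverse_log_snr(lambdas)                           # runs T -> t_min
print(np.round(timesteps, 5))
```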

7.2 Numerical Stability Tips

1. Avoid Division by Small Values:

```python
# Unstable
x_0 = x_t / alpha_t - (sigma_t / alpha_t) * eps

# Stable
x_0 = (x_t - sigma_t * eps) / alpha_t
```

2. Clamp Time Values:

```python
t = torch.clamp(t, min=1e-5, max=1.0)
```

3. Use Log-SNR Parameterization:

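$$\lambda_t = \log\frac{\alpha_t}{\sigma_t}$$

Computing $\lambda_t$ as $\log\alpha_t - \log\sigma_t$ avoids forming ratios of very small numbers near $t = 0$ (where $\sigma_t \to 0$) and $t = T$ (where $\alpha_t \to 0$), which improves numerical conditioning.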

7.3 Batch Processing

Parallel Sampling:

```python
import torch

# Sample multiple images in parallel
batch_size = 64
x_T = torch.randn(batch_size, 3, 64, 64)  # Batch of noise
with torch.no_grad():  # sampling needs no gradients
    x_0 = dpm_solver_2_sample(x_T, model, timesteps)
```

Memory Optimization:

  • Process in batches to fit GPU memory
  • Run sampling under torch.no_grad(), since no activations need to be stored for backpropagation
  • Precompute $\alpha_t$, $\sigma_t$ values

8. Performance Comparison

8.1 Sampling Speed vs Quality

| Method | FID (CIFAR-10) | Steps | Time (s) | NFE |
|---|---|---|---|---|
| [[Diffusion Model\|DDPM]] | 3.17 | 1000 | 21.7 | 1000 |
| DDIM | 4.16 | 100 | 2.2 | 100 |
| DPM-Solver-2 | 3.28 | 20 | 0.5 | 40 |
| DPM-Solver-3 | 3.19 | 15 | 0.4 | 45 |
| DPM-Solver++ | 3.15 | 20 | 0.5 | 20 |

8.2 Comparison with Other Fast Samplers

| Method | Type | Steps | Quality | Training Required |
|---|---|---|---|---|
| DPM-Solver | ODE solver | 10-20 | High | No (plug-and-play) |
| DDIM | ODE solver | 50-100 | Medium-High | No |
| Consistency Models | Distillation | 1-8 | High | Yes (retraining) |
| Progressive Distillation | Distillation | 2-8 | High | Yes (retraining) |
| Rectified Flows | Retraining | 1-10 | High | Yes (retraining) |

[!NOTE] Key Advantage
DPM-Solver is a plug-and-play solver that works with any pre-trained [[Diffusion Model]] without retraining, unlike distillation-based methods.


9. Applications

9.1 Text-to-Image Generation

Stable Diffusion + DPM-Solver:

  • Original: 50 DDIM steps (~10 seconds)
  • With DPM-Solver-2: 20 steps (~4 seconds)
  • Quality: Comparable or better FID scores

9.2 High-Resolution Synthesis

Benefits for Large Images:

  • Fewer steps mean fewer network evaluations, which dominate the cost at high resolution
  • Better stability for high-dimensional data
  • Enables near-interactive generation (1-2 seconds for 1024×1024, depending on model and hardware)

9.3 Video Generation

Temporal Consistency:

  • Deterministic ODE trajectories map the same initial noise to the same output, which helps keep frames consistent
  • Fewer steps reduce temporal artifacts
  • Suitable for frame interpolation tasks

9.4 3D Generation

Score Distillation Sampling (SDS):

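  • SDS (used in DreamFusion, Magic3D, etc.) queries the diffusion model at a single randomly sampled noise level per optimization step instead of solving the sampling ODE
  • DPM-Solver therefore does not accelerate SDS itself; it applies only to stages of a 3D pipeline that draw full samples from a diffusion model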

10. Core Formula Cards

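[!QUOTE] [[Probability Flow ODE]]

$$\frac{dx}{dt} = f(t)\,x - \frac{1}{2}g(t)^2\,\nabla_x \log p_t(x)$$

[!QUOTE] Semi-linear Form

$$\frac{dx_t}{dt} = \frac{d\log\alpha_t}{dt}\,x_t - \sigma_t\,\frac{d\lambda_t}{dt}\,\epsilon_\theta(x_t,t)$$

[!QUOTE] Exact Solution in Log-SNR Space

$$x_{t_s} = \frac{\alpha_{t_s}}{\alpha_t}\,x_t - \sigma_{t_s}\int_0^h e^{\tau}\,\epsilon_\theta(x_{t_\tau}, t_\tau)\,d\tau$$

[!QUOTE] DPM-Solver-1 (First-Order)

$$x_{t_s} = \frac{\alpha_{t_s}}{\alpha_t}\,x_t - \sigma_{t_s}\,(e^h - 1)\,\epsilon_\theta(x_t, t)$$

[!QUOTE] DPM-Solver-2 (Second-Order)

$$x_{t_m} = \frac{\alpha_{t_m}}{\alpha_t}\,x_t - \sigma_{t_m}\,(e^{h/2} - 1)\,\epsilon_\theta(x_t, t), \qquad x_{t_s} = \frac{\alpha_{t_s}}{\alpha_t}\,x_t - \sigma_{t_s}\,(e^{h} - 1)\,\epsilon_\theta(x_{t_m}, t_m)$$

[!QUOTE] Step Size in Log-SNR Space

$$h = \lambda_{t_s} - \lambda_t, \qquad \lambda_t = \log\frac{\alpha_t}{\sigma_t}$$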

11. Debugging and Troubleshooting

11.1 Common Issues

Problem 1: Sample quality degrades with fewer steps

Causes:

  • Step size too large
  • Low-order solver (DPM-Solver-1)
  • Stiff dynamics near t=0

Solutions:

  • Increase number of steps (try 20-30)
  • Use DPM-Solver-2 or DPM-Solver-3
  • Use non-uniform time schedule (more steps near t=0 )

Problem 2: Numerical instability (NaN values)

Causes:

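  • Division by $\sigma_t$ as $t \to 0$ (where $\sigma_t \to 0$), e.g. when converting $\epsilon_\theta$ to a score, or by $\alpha_t$ near $t = T$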
  • Score function explosion
  • Accumulated rounding errors

Solutions:

  • Clamp $t \in [10^{-5}, 1.0]$
  • Clip score values: $\|s_\theta\|_2 < \text{threshold}$
  • Use $x_0$-prediction instead of $\epsilon$-prediction

Problem 3: Slow sampling despite DPM-Solver

Causes:

  • Too many function evaluations
  • Inefficient implementation
  • Large batch size causing memory bottleneck

Solutions:

  • Use DPM-Solver-2 with 15-20 steps
  • Precompute $\alpha_t$, $\sigma_t$ values
  • Optimize model inference (TensorRT, ONNX)

11.2 Quality Checklist

Before deploying DPM-Solver:

  • [ ] Test with different step counts (10, 15, 20, 30)
  • [ ] Compare FID/IS scores with baseline (DDIM-50)
  • [ ] Check for artifacts in generated samples
  • [ ] Verify time schedule is appropriate
  • [ ] Monitor numerical stability (no NaN/Inf)
  • [ ] Profile inference time and memory usage

12. Extensions and Variants

12.1 Unified Framework

DPM-Solver can handle different model parameterizations:

| Parameterization | Network Predicts | Best For |
|---|---|---|
| $\epsilon$-prediction | Noise $\epsilon$ | Standard training |
| $x_0$-prediction | Clean data $x_0$ | Better stability |
| $v$-prediction | Velocity $v = \alpha_t \epsilon - \sigma_t x_0$ | Balanced performance |
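Since $x_t = \alpha_t x_0 + \sigma_t \epsilon$, any of the three network outputs can be converted into the others, which is what lets one solver serve all parameterizations. A sketch, assuming a VP schedule ($\alpha_t^2 + \sigma_t^2 = 1$, needed for the $v$ case); the function name `to_x0_and_eps` is hypothetical.

```python
import numpy as np

def to_x0_and_eps(x_t, pred, kind, alpha_t, sigma_t):
    """Convert a network output into (x0_hat, eps_hat), assuming x_t = alpha_t x0 + sigma_t eps
    and, for v-prediction, a VP schedule with alpha_t^2 + sigma_t^2 = 1."""
    if kind == "eps":
        eps = pred
        x0 = (x_t - sigma_t * eps) / alpha_t
    elif kind == "x0":
        x0 = pred
        eps = (x_t - alpha_t * x0) / sigma_t
    elif kind == "v":
        # v = alpha_t eps - sigma_t x0  =>  x0 = alpha_t x_t - sigma_t v, eps = sigma_t x_t + alpha_t v
        x0 = alpha_t * x_t - sigma_t * pred
        eps = sigma_t * x_t + alpha_t * pred
    else:
        raise ValueError(f"unknown parameterization: {kind}")
    return x0, eps

rng = np.random.default_rng(0)
x0_true, eps_true = rng.standard_normal(3), rng.standard_normal(3)
alpha_t = 0.8
sigma_t = np.sqrt(1.0 - alpha_t**2)
x_t = alpha_t * x0_true + sigma_t * eps_true
v_true = alpha_t * eps_true - sigma_t * x0_true
for kind, pred in (("eps", eps_true), ("x0", x0_true), ("v", v_true)):
    x0, eps = to_x0_and_eps(x_t, pred, kind, alpha_t, sigma_t)
    print(kind, np.allclose(x0, x0_true), np.allclose(eps, eps_true))
```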

12.2 Multistep Methods

DPM-Solver-Multistep: Use information from previous steps (like Adams-Bashforth):

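$$x_{t_s} = \frac{\alpha_{t_s}}{\alpha_t}\,x_t - \sigma_{t_s}\sum_{i=0}^{k-1} c_i\,\epsilon_\theta(x_{t_i}, t_i)$$

where $t_0 = t$, $t_1, \dots, t_{k-1}$ are earlier (already visited) time points, and the $c_i$ are coefficients that depend on the step sizes.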

Advantage: Fewer function evaluations (1 per step after initialization)
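A sketch of a second-order multistep update in the noise-prediction form, assuming the same toy Gaussian setup as before (VP schedule in $\lambda$, exact noise predictor, $S = 0.5$). It expands $\epsilon_\theta$ to first order in $\lambda$ and estimates the slope from the previous step's evaluation, giving the correction $-\sigma_{t_s}(e^h - 1 - h)\,\frac{\epsilon_\theta(x_t,t) - \epsilon_\theta(x_{t_{\text{prev}}},t_{\text{prev}})}{\lambda_t - \lambda_{t_{\text{prev}}}}$; the first step falls back to DPM-Solver-1. This is one illustrative variant, not necessarily the exact coefficients of the official multistep solvers, and the helper names are hypothetical.

```python
import numpy as np

S = 0.5  # toy data ~ N(0, S^2) with exact eps-prediction (demo assumption)

def alpha_sigma(lam):
    # VP parameterization in log-SNR: alpha^2 + sigma^2 = 1, lam = log(alpha / sigma)
    alpha = np.sqrt(1.0 / (1.0 + np.exp(-2.0 * lam)))
    sigma = np.sqrt(1.0 / (1.0 + np.exp(2.0 * lam)))
    return alpha, sigma

def eps_model(x, lam):
    alpha, sigma = alpha_sigma(lam)
    return sigma * x / (alpha**2 * S**2 + sigma**2)

def multistep_2_solve(n_steps, lam_start=-3.0, lam_end=3.0, x_start=1.0):
    """Second-order multistep sketch: one model call per step, slope from the previous step."""
    lams = np.linspace(lam_start, lam_end, n_steps + 1)
    x, eps_prev, lam_prev = x_start, None, None
    for lam_t, lam_s in zip(lams[:-1], lams[1:]):
        alpha_t, _ = alpha_sigma(lam_t)
        alpha_s, sigma_s = alpha_sigma(lam_s)
        h = lam_s - lam_t
        eps_t = eps_model(x, lam_t)  # the only model call in this step
        x_next = (alpha_s / alpha_t) * x - sigma_s * np.expm1(h) * eps_t
        if eps_prev is not None:
            # Backward-difference estimate of d eps / d lambda, integrated exactly
            slope = (eps_t - eps_prev) / (lam_t - lam_prev)
            x_next = x_next - sigma_s * (np.expm1(h) - h) * slope
        x, eps_prev, lam_prev = x_next, eps_t, lam_t
    return x

def dpm_solver_1_solve(n_steps, lam_start=-3.0, lam_end=3.0, x_start=1.0):
    lams = np.linspace(lam_start, lam_end, n_steps + 1)
    x = x_start
    for lam_t, lam_s in zip(lams[:-1], lams[1:]):
        alpha_t, _ = alpha_sigma(lam_t)
        alpha_s, sigma_s = alpha_sigma(lam_s)
        x = (alpha_s / alpha_t) * x - sigma_s * np.expm1(lam_s - lam_t) * eps_model(x, lam_t)
    return x

# Exact answer: the flow rescales x by the ratio of marginal standard deviations
a0, s0 = alpha_sigma(-3.0)
a1, s1 = alpha_sigma(3.0)
x_exact = np.sqrt((a1**2 * S**2 + s1**2) / (a0**2 * S**2 + s0**2))
for n in (12, 24):
    err_1 = abs(dpm_solver_1_solve(n) - x_exact)
    err_ms = abs(multistep_2_solve(n) - x_exact)
    print(f"{n} NFE: DPM-Solver-1 error {err_1:.2e}, multistep-2 error {err_ms:.2e}")
```

Each step costs one model call, the same as DPM-Solver-1, but the error is smaller because slope information is recycled from the previous step.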

12.3 Integration with Other Methods

DPM-Solver + Consistency Models:

  • Use DPM-Solver for high-quality sampling
  • Distill to consistency model for fast deployment

DPM-Solver + Rectified Flows:

  • Straighter ODE trajectories
  • Even fewer steps needed (5-10)

Related Notes

  • [[Diffusion Model]]
  • [[Probability Flow ODE]]
  • [[Stochastic Differential Equation (SDE)]]
  • [[Score Function]]
  • [[DDIM]]
  • [[Numerical ODE Methods]]
  • [[Runge-Kutta Methods]]
  • [[Consistency Models]]
  • [[Rectified Flows]]
  • [[Flow Matching]]
  • [[Wiener Process]]
  • [[Markov Process]]
  • [[Neural ODE]]
  • [[Fast Sampling Methods]]
  • [[Langevin Dynamics]]

Dataview Query

```dataview
LIST
FROM #dpm_solver OR #ode_solver OR #fast_sampling
SORT file.ctime DESC
```

References

  • Paper: DPM-Solver: A Fast ODE Solver for [[Diffusion Model|Diffusion Probabilistic Model]] Sampling (Lu et al., 2022)
  • Paper: DPM-Solver++: Fast Solver for Guided Sampling of Diffusion Probabilistic Models (Lu et al., 2022)
  • Paper: DPM-Solver-v3: Improved Diffusion ODE Solver with Empirical Model Statistics (Zheng et al., 2023)
  • GitHub: https://github.com/LuChengTHU/dpm-solver
  • Blog: Understanding DPM-Solver - Lilian Weng
  • Course: CS236 Deep Generative Models (Stanford) - Lecture on Fast Sampling