PROGRESSIVE DISTILLATION FOR FAST SAMPLING OF DIFFUSION MODELS

The paper presents new parameterizations of diffusion models that provide increased stability when using few sampling steps, and a method to distill a trained deterministic diffusion sampler that uses many steps into a new diffusion model that needs only half as many sampling steps.

Problem

Even for unconditional generation (no class label or text prompt), diffusion models are slow to sample from: good samples typically require hundreds to thousands of network evaluations.

Method

Progressive Distillation

  • The model architecture stays unchanged during distillation.
  • Requires the sampling process to be deterministic (no noise injection).
  • Distillation only needs to cover the discrete time steps actually used at sampling time.
  • Distillation takes roughly as long as training a model (see the training-loop sketch below).
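
A minimal PyTorch-style sketch of one distillation stage (cf. Algorithm 2 in the paper). The `model(z_t, t)` signature, the flat `(B, D)` data layout, and the hyperparameters are illustrative assumptions, not the paper's code; the schedule and the truncated-SNR weight follow the definitions later in these notes.

```python
# Sketch of one progressive-distillation stage: the student learns to match
# two teacher DDIM steps with a single step of its own.
import copy
import math
import torch

def alpha(t):
    return torch.cos(0.5 * math.pi * t)   # alpha(1) = 0  =>  zero SNR at t = 1

def sigma(t):
    return torch.sin(0.5 * math.pi * t)   # alpha^2 + sigma^2 = 1 (variance preserving)

def ddim_step(model, z_t, t, s):
    """Deterministic DDIM update z_t -> z_s for s < t (x-prediction model)."""
    x_hat = model(z_t, t)
    eps_hat = (z_t - alpha(t) * x_hat) / sigma(t)
    return alpha(s) * x_hat + sigma(s) * eps_hat

def distill_stage(teacher, data_iter, n_student_steps, num_iters, lr=1e-4):
    """Train a student that samples with n_student_steps (half the teacher's)."""
    student = copy.deepcopy(teacher)       # same architecture, initialized from teacher
    opt = torch.optim.Adam(student.parameters(), lr=lr)
    for _ in range(num_iters):
        x = next(data_iter)                                # clean data, shape (B, D)
        i = torch.randint(1, n_student_steps + 1, (x.shape[0], 1))
        t = i.float() / n_student_steps                    # student's discrete times
        eps = torch.randn_like(x)
        z_t = alpha(t) * x + sigma(t) * eps
        with torch.no_grad():                              # two teacher half-steps
            t_mid = t - 0.5 / n_student_steps
            s = t - 1.0 / n_student_steps
            z_mid = ddim_step(teacher, z_t, t, t_mid)
            z_end = ddim_step(teacher, z_mid, t_mid, s)
            # x-target that makes ONE student DDIM step from z_t land on z_end
            ratio = sigma(s) / sigma(t)
            x_target = (z_end - ratio * z_t) / (alpha(s) - ratio * alpha(t))
        x_hat = student(z_t, t)
        w = torch.clamp(alpha(t) ** 2 / sigma(t) ** 2, min=1.0)  # truncated SNR
        loss = (w * (x_target - x_hat) ** 2).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()
    return student
```

A full run applies `distill_stage` repeatedly, halving the step count each time and reusing the returned student as the next stage's teacher.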

Results: even when distilled down to 4-8 sampling steps, the model still achieves very good sample quality!

Model parameterization

Prior work predicts $\epsilon$ directly, but this is unsuitable for distillation: with few sampling steps the first few steps matter a lot, and since $\hat{x}=(z_t-\sigma_t\hat{\epsilon})/\alpha_t$, small $\epsilon$-prediction errors are amplified as $\alpha_t \to 0$ (low signal-to-noise ratio).

Three alternative prediction targets are proposed (a conversion sketch follows the list):

  1. predict $x$ directly
  2. predict both $\tilde{x}_{\theta}(z_t)$ and $\tilde{\epsilon}_{\theta}(z_t)$, and set
\[\hat{x}=\sigma_t^2\tilde{x}_{\theta}(z_t)+\alpha_t(z_t-\sigma_t \tilde{\epsilon}_{\theta}(z_t))\]
  3. predict $v=\alpha_t\epsilon-\sigma_t x$, and set $\hat{x}=\alpha_t z_t-\sigma_t\hat{v}$
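
For reference, a small sketch of how each target maps back to $(\hat{x}, \hat{\epsilon})$ given $z_t=\alpha_t x+\sigma_t\epsilon$; plain tensor algebra, assuming a variance-preserving schedule ($\alpha_t^2+\sigma_t^2=1$), with function names chosen here for illustration:

```python
# Recovering (x_hat, eps_hat) from each prediction target.
def from_x(z_t, x_hat, alpha_t, sigma_t):
    eps_hat = (z_t - alpha_t * x_hat) / sigma_t
    return x_hat, eps_hat

def from_x_and_eps(z_t, x_tilde, eps_tilde, alpha_t, sigma_t):
    # Merge the two heads: sigma^2 weights the direct x estimate,
    # alpha^2 weights the x estimate implied by eps (weights sum to 1).
    x_hat = sigma_t**2 * x_tilde + alpha_t * (z_t - sigma_t * eps_tilde)
    return from_x(z_t, x_hat, alpha_t, sigma_t)

def from_v(z_t, v_hat, alpha_t, sigma_t):
    # v = alpha*eps - sigma*x  =>  x = alpha*z - sigma*v,  eps = sigma*z + alpha*v
    x_hat = alpha_t * z_t - sigma_t * v_hat
    eps_hat = sigma_t * z_t + alpha_t * v_hat
    return x_hat, eps_hat
```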

In addition, three weighting functions are proposed, matched to these targets (a short sketch follows the list):

Recall the loss function $L_{\theta}=\mathbb{E}_{\epsilon, t}\big[w(\lambda_t)\,\|\hat{x}_{\theta}(z_t)-x\|_2^2\big]$

where $\lambda_t=\log[\alpha_t^2/\sigma_t^2]$ is the log signal-to-noise ratio.

  1. SNR weighting (the classical $\epsilon$-prediction loss): $w(\lambda_t)=\exp(\lambda_t)=\frac{\alpha_t^2}{\sigma_t^2}$, since $\|\epsilon-\hat{\epsilon}_t\|_2^2=\frac{\alpha_t^2}{\sigma_t^2}\|x-\hat{x}_t\|_2^2$
  2. truncated SNR: \(L_{\theta}=\max(\|x-\hat{x}_t\|_2^2, \|\epsilon-\hat{\epsilon}_t\|_2^2)=\max\big(\frac{\alpha_t^2}{\sigma_t^2}, 1\big)\|x-\hat{x}_t\|_2^2\)
  3. SNR+1 weighting: \(\|v_t-\hat{v}_t\|_2^2=(1+\frac{\alpha_t^2}{\sigma_t^2})\|x-\hat{x}_t\|_2^2\)
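
A one-function sketch of the three weightings as applied to the $x$-space loss; the `scheme` names are mine, not the paper's:

```python
import torch

def loss_weight(lambda_t, scheme):
    """w(lambda_t) multiplying ||x - x_hat||^2, with lambda_t = log(alpha^2 / sigma^2)."""
    snr = torch.exp(lambda_t)
    if scheme == "snr":            # classical eps-prediction loss
        return snr
    if scheme == "truncated_snr":  # max(SNR, 1): keeps the weight nonzero at zero SNR
        return torch.clamp(snr, min=1.0)
    if scheme == "snr_plus_one":   # implied by the v-prediction loss
        return snr + 1.0
    raise ValueError(f"unknown scheme: {scheme}")
```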

Technical Details

The paper distills a DDIM sampler, repeatedly halving the number of sampling steps.

  • DDIM sampling is deterministic
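
For reference, the deterministic DDIM update from time $t$ to an earlier time $s$, written in terms of the model's $x$-prediction (this is the step being distilled):

\[z_s = \alpha_s \hat{x}_{\theta}(z_t) + \frac{\sigma_s}{\sigma_t}\big(z_t - \alpha_t \hat{x}_{\theta}(z_t)\big)\]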

During distillation, the discrete times are sampled such that the highest time index corresponds to a signal-to-noise ratio of zero, i.e. $\alpha_1 = 0$, which exactly matches the distribution of input noise $z_1 \sim N(0, I)$ used at test time. The authors found this to work slightly better than starting from a non-zero signal-to-noise ratio (as used by e.g. Ho et al. (2020)), both for training the original model and when performing progressive distillation.
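
Concretely, with the variance-preserving cosine schedule used in the paper, this boundary condition holds exactly:

\[\alpha_t = \cos\big(\tfrac{\pi}{2}t\big), \quad \sigma_t = \sin\big(\tfrac{\pi}{2}t\big) \;\Rightarrow\; \alpha_1 = 0, \quad z_1 = \epsilon \sim N(0, I)\]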

repo