simple diffusion: End-to-end diffusion for high resolution images

Summary: adds a number of tricks that make diffusion work on high resolution images

Problem

Previous diffusion models do not handle high resolution images well. Some existing approaches rely on external components to achieve this, which adds complexity to the model.

Key methods

Adjusting noise schedules

For high resolution images, using the cosine schedule causes the global structure to be determined very early.

They propose a shifted SNR schedule:

  • Some theoretical analysis is provided
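A minimal sketch of what a shifted schedule can look like. These notes do not record the formula, so the form below is an assumption recalled from the paper: the cosine log-SNR plus a constant offset $2\log(64/d)$, with 64 taken as the reference resolution. The function names are hypothetical.

```python
import math

def logsnr_cosine(t: float) -> float:
    """Log-SNR of the cosine schedule, alpha_t = cos(pi t / 2), sigma_t = sin(pi t / 2)."""
    return -2.0 * math.log(math.tan(math.pi * t / 2.0))

def logsnr_shifted(t: float, resolution: int, base_resolution: int = 64) -> float:
    """Cosine log-SNR shifted by 2 * log(base_resolution / resolution).

    Assumption: the offset form and the reference resolution of 64 are not
    stated in these notes. For resolution > base_resolution the offset is
    negative, i.e. more noise at every t.
    """
    return logsnr_cosine(t) + 2.0 * math.log(base_resolution / resolution)

def alpha_sigma(logsnr: float):
    """Variance-preserving (alpha_t, sigma_t) with alpha_t^2 = sigmoid(logsnr)."""
    alpha_sq = 1.0 / (1.0 + math.exp(-logsnr))
    return math.sqrt(alpha_sq), math.sqrt(1.0 - alpha_sq)

# Usage: compare the noise level at t = 0.5 for 64x64 and 256x256 images.
for d in (64, 256):
    alpha, sigma = alpha_sigma(logsnr_shifted(0.5, d))
    print(d, round(alpha, 3), round(sigma, 3))
```

Under this assumed form, a 256x256 image at the same $t$ has a smaller $\alpha_t$ than a 64x64 image, so more of the global structure is destroyed.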

Multi-scale training loss

Define \(L_{\theta}^{d\times d}(x)=\frac{1}{d^2}\mathbb{E}_{\epsilon, t}\|D^{d\times d}[\epsilon]-D^{d\times d}[\hat{\epsilon}_{\theta}(\alpha_t x+\sigma_t \epsilon, t)]\|_2^2\), where $D^{d\times d}$ denotes the operation that downsamples to resolution $d\times d$.

Training uses the loss function \(\sum_{s\in \{32, 64, \cdots, d\}}\frac{1}{s} L_{\theta}^{s\times s}(x)\)

  • Explanation: the loss at higher resolutions is noisier, so those terms get the smaller weight $1/s$
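The multi-scale loss above can be sketched for a single sample as follows. Two assumptions are not fixed by these notes: $D^{s\times s}$ is implemented as average pooling, and the scales $\{32, 64, \cdots, d\}$ double at each step. The expectation over $\epsilon$ and $t$ would be approximated by averaging over a minibatch, which is omitted here.

```python
import numpy as np

def downsample(x: np.ndarray, size: int) -> np.ndarray:
    """Average-pool a square (H, W, C) array down to (size, size, C).

    Assumption: average pooling stands in for the operator D^{s x s}.
    """
    h, w, c = x.shape
    assert h == w and h % size == 0
    f = h // size  # pooling factor
    return x.reshape(size, f, size, f, c).mean(axis=(1, 3))

def multiscale_loss(eps: np.ndarray, eps_hat: np.ndarray, base: int = 32) -> float:
    """Sum over s in {base, 2*base, ..., d} of (1/s) * L^{s x s} for one sample."""
    d = eps.shape[0]
    total = 0.0
    s = base
    while s <= d:
        diff = downsample(eps, s) - downsample(eps_hat, s)
        loss_s = np.sum(diff ** 2) / s ** 2  # L^{s x s}: squared error scaled by 1/s^2
        total += loss_s / s                  # weight 1/s for this scale
        s *= 2
    return float(total)

# Usage: true noise vs. a (hypothetical) imperfect prediction at 128x128.
rng = np.random.default_rng(0)
eps = rng.normal(size=(128, 128, 3))
eps_hat = eps + 0.1 * rng.normal(size=eps.shape)
print(multiscale_loss(eps, eps_hat))
```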

Scaling the architecture

  • After a careful hardware-level analysis, they choose to scale the network at a low resolution (16x16) layer

Interesting discovery: To avoid doing computations on the highest resolutions, we downsample images immediately as the first step of the neural network, and up-sample as the last step. Surprisingly, even though the neural networks are cheaper computationally and in terms of memory, we find empirically that they also achieve better performance.

2 solutions for downsampling:

  1. Use the invertible and linear 5/3 wavelet (as used in JPEG2000) to transform the image to lower resolution frequency responses
    • In the network, the responses are concatenated over the channel axis
    • The inverse wavelet transform is used to upsample the responses
  2. (simpler, but with a performance penalty) Directly use a $d\times d$ conv layer with stride $d$ for downsampling, and a transposed conv layer for upsampling
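The second option can be sketched in NumPy by writing the $d\times d$ stride-$d$ convolution as a linear map on non-overlapping patches, and the transposed convolution as the reverse scatter. The function names and the (H, W, C) layout are assumptions for illustration. In a real model both weight matrices are learned; the orthogonal matrix in the usage example is only there to show that the two operations can invert each other.

```python
import numpy as np

def patch_downsample(x: np.ndarray, weight: np.ndarray, d: int) -> np.ndarray:
    """d x d conv with stride d (no bias), as a matmul over non-overlapping patches.

    x: (H, W, C), weight: (d*d*C, C_out). Returns (H/d, W/d, C_out).
    """
    h, w, c = x.shape
    patches = x.reshape(h // d, d, w // d, d, c).transpose(0, 2, 1, 3, 4)
    patches = patches.reshape(h // d, w // d, d * d * c)
    return patches @ weight

def patch_upsample(y: np.ndarray, weight: np.ndarray, d: int) -> np.ndarray:
    """Transposed conv with a d x d kernel and stride d (no bias).

    y: (h, w, C_out), weight: (C_out, d*d*C). Returns (h*d, w*d, C).
    """
    h, w, _ = y.shape
    c = weight.shape[1] // (d * d)
    patches = (y @ weight).reshape(h, w, d, d, c).transpose(0, 2, 1, 3, 4)
    return patches.reshape(h * d, w * d, c)

# Usage: downsample an 8x8 RGB image by d = 2, then map it back.
rng = np.random.default_rng(0)
d = 2
x = rng.normal(size=(8, 8, 3))
q, _ = np.linalg.qr(rng.normal(size=(d * d * 3, d * d * 3)))  # demo-only orthogonal weights
y = patch_downsample(x, q, d)        # shape (4, 4, 12)
x_rec = patch_upsample(y, q.T, d)    # shape (8, 8, 3)
print(y.shape, np.allclose(x_rec, x))
```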

Dropout

  • regularization is important even for high resolution images!
  • dropout is only added to low resolution layers (because that is the only place being scaled), e.g. only at resolutions $\le$ 32x32

The U-ViT architecture

  • With the previous techniques in place: one can replace convolutional layers with MLP blocks if the architecture already uses self-attention at that resolution
  • A large transformer is used at the resolution where the model is scaled; the rest is a small convolutional U-Net

Technical Details

The parametrization adopts $v$ prediction, as proposed by Salimans & Ho, 2022. See SQA-002.

  • $v$ prediction is more suitable for high resolution images.
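For reference, a small sketch of the $v$ parametrization from Salimans & Ho, 2022: with $z_t=\alpha_t x+\sigma_t\epsilon$ and $\alpha_t^2+\sigma_t^2=1$, the target is $v=\alpha_t\epsilon-\sigma_t x$, and both $x$ and $\epsilon$ can be recovered from $(z_t, v)$. The function names are hypothetical, and the values of $\alpha_t$, $\sigma_t$ in the usage example are arbitrary.

```python
import numpy as np

def v_target(x, eps, alpha, sigma):
    """Training target v = alpha * eps - sigma * x."""
    return alpha * eps - sigma * x

def x_from_v(z, v, alpha, sigma):
    """Recover x from the noisy input z and v (requires alpha^2 + sigma^2 = 1)."""
    return alpha * z - sigma * v

def eps_from_v(z, v, alpha, sigma):
    """Recover eps from the noisy input z and v (requires alpha^2 + sigma^2 = 1)."""
    return sigma * z + alpha * v

# Usage: round trip with an arbitrary variance-preserving (alpha, sigma) pair.
rng = np.random.default_rng(0)
x = rng.normal(size=(4, 4, 3))
eps = rng.normal(size=x.shape)
alpha, sigma = 0.6, 0.8
z = alpha * x + sigma * eps
v = v_target(x, eps, alpha, sigma)
print(np.allclose(x_from_v(z, v, alpha, sigma), x),
      np.allclose(eps_from_v(z, v, alpha, sigma), eps))
```

A predicted $v$ can therefore be converted to $\hat{\epsilon}$ with `eps_from_v` before being plugged into an $\epsilon$-space loss.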

Weight: they use the unweighted loss $w(t)=1$ instead of SNR weighting.

Experiments and conclusion

  • A model conditioned on text data is also trained
  • up to 512x512 resolution, with only small modifications to the model