ZHH-015
[Paper] Simplifying, Stabilizing and Scaling Continuous-Time Consistency Models
这个主要就是把continous-time CM做work了,甚至效果上超过了discrete time(不管是distill还是直接训练)。CIFAR上面1步FID是2.97(没超过ICM),但2步是2.06。在ImageNet上distillation更是完全吊打
主要做的事情:
- Trigflow, 统一了EDM和FM;
- 修改time-conditioning,在embedding和adaGN两个方面;
- 修改了CM的训练目标。
这里我们先关注training了,因为细节其实就少。之后如果要搞distill再说。
Training
Discrete-time CM的训练目标是:
\[L=\mathbb{E}_{t,x,\epsilon}\left[w(t)\left\|f_{\theta}(x_t,t)-f_{\theta^-}(x_{t-\Delta t},t-\Delta t)\right\|^2\right]\](这里假设使用L2 distance作为d)其中
\[\theta^- := \text{SG}(\theta).\]我们可以把它的 梯度 改写为连续的形式(即取 $\Delta t\to 0$):
\[\frac{1}{\Delta t} \nabla_{\theta} L\approx \mathbb{E}_{t,x,\epsilon}\left[2w(t) \frac{df_{\theta^-}(x_t,t)}{dt}\cdot \nabla_{\theta} f_{\theta}(x_t,t)\right]\]因此我们可以定义Continous-time CM的训练目标为:
\[L_{\text{cont}}=\mathbb{E}_{t,x,\epsilon}\left[\tilde{w}(t) \left\|f_{\theta}(x_t,t) + k(t)\frac{df_{\theta^-}(x_t,t)}{dt} - f_{\theta^-}(x_t,t)\right\|^2\right]\]注意这个是fake loss,只有梯度有意义。只要 $\tilde{w}(t)k(t)=w(t)$,这两个函数就可以任意选择,给了很少的调参空间。活着。
TrigFlow: Improved diffusion schedule
CM原来用的是EDM的diffusion schedule,现在它改成用 TrigFlow:
\[x_t = \cos t\cdot x_0 + \sin t \cdot \epsilon\]注意 $t\in [0, \frac{\pi}{2}]$. 它采用的pre-conditioner形如:
\[f_{\theta}(x_t,t) = \cos t\cdot x_t - \sigma_d \sin t \cdot F_{\theta}\left(\frac{x_t}{\sigma_d},c_{\text{noise}}(t)\right)\]其中 $F_{\theta}$ 就是UNet, $c_{\text{noise}}(t)$ 是noise embedding。注意实际上只要 $f_{\theta}(x,0)\equiv x$,那么就是一个valid的CM;但是这里选择这个sin和cos好像是有深意的,可以参见appendix。我想看数学啊!
实验上经过感觉很少的调参,作者发现 $k(t)=\sigma_d \sin t\cos t$ 比较好。这样原来的loss就成为了
\[L_{\text{cont}}=\mathbb{E}_{t,x,\epsilon}\left[w_1(t) \left\|F_{\theta}\left(\frac{x_t}{\sigma_d},c_{\text{noise}}(t)\right)- F_{\theta^-}\left(\frac{x_t}{\sigma_d},c_{\text{noise}}(t)\right) - \cos t\cdot\frac{df_{\theta^-}(x_t,t)}{dt}\right\|^2\right]\]Stablizing training
为啥continous time CM的训练不稳定?作者认为核心是训练目标里的 $\frac{df_{\theta^-}(x_t,t)}{dt}$ 不稳定。具体而言可以再拆出来,因为UNet的输出和对image $x_t$ 的梯度可以认为是比较稳定的,所以最后发现实际上是
\[\frac{\partial}{\partial t} F_{\theta}\left(\frac{x_t}{\sigma_d},c_{\text{noise}}(t)\right)\]这一项的问题。这又可以进一步拆成三个部分:
- $c_{\text{noise}}(t)$ 对 $t$ 的导数不稳定(比如EDM使用的log condition):我们通过设置 $c_{\text{noise}}(t)\equiv t$来解决。
- Embedding 对 $c_{\text{noise}}(t)$ 的导数不稳定:我们不用fourier embedding了,改成用positional embedding。
- 神经网络输出对embedding不稳定:我们可以把AdaGN
改写为所谓 adaptive double normalization:
\[h_{\text{out}} = \text{pnorm}(s(t)) * \text{GN}(h_{\text{in}}) + \text{pnorm}(b(t))\]其中,$\text{pnorm}$ 是pixel-normalization:
\[\text{pnorm}(v) = \frac{v}{\sqrt{\frac{1}{f}\sum_i v_i^2+\epsilon}}\]其中 $i=1,…,f$, $f$ 是feature的数量。
- 最后,这些都做了之后还是不够稳定,于是做了一个clipping:
其中 $c=0.1$。
- 最后,还不够,我们观察到
作者敏锐的观察到, $\cos t$ 乘的部分都还算稳定,但 $\sin t$ 乘的部分都不太稳定。因此作者做了一个改动:
\[\frac{df_{\theta^-}(x_t,t)}{dt} \leftarrow \cos t \cdot \left(\frac{d x_t}{dt}- \sigma_d F_{\theta} \right) + r \sin t \left(-x_t - \sigma_d \cdot \frac{d}{dt} F_{\theta}\left(\frac{x_t}{\sigma_d},c_{\text{noise}}(t)\right) \right)\]其中 $r$ 在前10k steps从0线性增加到1。(???这有道理吗)
加上这一套东西,FID降的其实少。弱
More tricks
Adaptive Weighting:这个来自EDM2,就是把最终目标里面的 $w_1(t)$ 变成learnable的,从而不需要manual tune。
具体方法就是,我们把不同时间步骤的学习当成一种 multi-task learning:
\[L(\theta) = \mathbb{E}_{t}\left[e^{u(t)} L_t(\theta)-u(t)\right]\]设想现在我们同时让 $\theta$ 和 $u$ 可学习,会怎么样?我们发现 $u$ 会学到
\[u(t) = - \log \text{SG}(L_t(\theta))\]因此,对 $\theta$ 再进行优化的时候,就相当于优化
\[\mathbb{E}_{t}\left[\frac{L_t(\theta)}{\text{SG}(L_t(\theta))}\right]\]也就是说,不同的loss根据大小进行re-weight。用这个思路,我们最终给出了 sCM loss:
\[L_{\text{sCM}}=\mathbb{E}_{t,x,\epsilon}\left[e^{w_{\phi}(t)}\cdot \frac{1}{d}\left\|F_{\theta}\left(\frac{x_t}{\sigma_d},t\right)- F_{\theta^-}\left(\frac{x_t}{\sigma_d},t\right) - \cos t\cdot\frac{df_{\theta^-}(x_t,t)}{dt}\right\|^2 - w_{\phi}(t)\right]\]最后,取 $t$ 的分布为
\[t\sim \arctan \left(\frac{1}{\sigma_d} \exp \mathcal{N}(-1.0,1.4) \right)\]JVP computation:我们计算 $F_{\theta}$ 对 $t$ 的导数的时候,可以写为JVP
\[\frac{d}{dt} F_{\theta}\left(\frac{x_t}{\sigma_d},t\right) = \left(\frac{dx_t}{dt}\right) \cdot \nabla_{x_t} F_{\theta}\left(\frac{x_t}{\sigma_d},t\right) + (1)\cdot \frac{\partial}{\partial t} F_{\theta}\left(\frac{x_t}{\sigma_d},t\right)\]但是这样发现不够end-to-end,可能导致中间有numerical issue。注意最后落实到loss里我们要的实际上是
\[\sin t \cos t\cdot \frac{d}{dt} F_{\theta}\left(\frac{x_t}{\sigma_d},t\right) = \left(\frac{\sin t \cos t}{\sigma_d} \frac{dx_t}{dt}\right) \cdot \nabla_{\frac{x_t}{\sigma_d}} F_{\theta}\left(\frac{x_t}{\sigma_d},t\right) + (\sin t \cos t)\cdot \frac{\partial}{\partial t} F_{\theta}\left(\frac{x_t}{\sigma_d},t\right)\]如果使用的库的JVP实现的比较好的话,这样计算的JVP就可以大幅减少numerical error。