Flow Matching and Diffusion Models 3
Training the Generative Model
Flow Matching
Recall that the flow matching loss is defined as
\[\begin{align} \mathcal{L}_\text{FM}(\theta) &= \mathbb{E}_{t\sim \text{Unif},x\sim p_t}[\|u_t^\theta(x)-u_t^{target}(x)\|^2]\\ &= \mathbb{E}_{t\sim \text{Unif},z\sim p_{target}, x\sim p_t(\cdot\mid z)}[\|u_t^\theta(x)-u_t^{target}(x)\|^2] \end{align}\]The second equality follows from the law of total expectation. However, in real-world scenarios the marginal velocity field \(u_t^{target}(x)\) is intractable to compute, whereas the conditional velocity field \(u_t^{target}(x \mid z)\) is tractable. This motivates the conditional flow matching loss
\[\mathcal{L}_{\text{CFM}} = \mathbb{E}_{t\sim \text{Unif}, z\sim p_{target}, x\sim p_t(\cdot\mid z)}[\|u_t^\theta(x)-u_t^{target}(x\mid z)\|^2]\]Now we prove that \(\mathcal{L}_{\text{FM}} = \mathcal{L}_{\text{CFM}} + C\), where \(C\) is a constant independent of \(\theta\), so minimizing \(\mathcal{L}_{\text{CFM}}\) also minimizes \(\mathcal{L}_{\text{FM}}\).
Proof. \(\begin{align} \mathcal{L}_\text{FM}&= \mathbb{E}_{t\sim \text{Unif},x\sim p_t}[\|u_t^\theta(x)-u_t^{target}(x)\|^2]\\ &= \mathbb{E}_{t\sim \text{Unif},x\sim p_t}[\|u_t^\theta(x)\|^2-2u_t^\theta(x)^Tu_t^{target}(x)+\|u_t^{target}(x)\|^2]\\ &= \mathbb{E}_{t\sim \text{Unif},z\sim p_{target}, x\sim p_t(\cdot\mid z)}[\|u_t^\theta(x)\|^2]-2\mathbb{E}_{t\sim \text{Unif},z\sim p_{target}, x\sim p_t(\cdot\mid z)}[u_t^\theta(x)^Tu_t^{target}(x|z)]+C_1\\ &= \mathbb{E}_{t\sim \text{Unif}, z\sim p_{target}, x\sim p_t(\cdot\mid z)}[\|u_t^\theta(x)-u_t^{target}(x\mid z)\|^2]+C_2+C_1\\ &= \mathcal{L}_{\text{CFM}} + C \end{align}\) Here the third equality uses the marginalization identity \(p_t(x)u_t^{target}(x) = \int u_t^{target}(x\mid z)p_t(x\mid z)p_{target}(z)\text{d}z\), the fourth completes the square, and \(C_1, C_2\) collect terms independent of \(\theta\).
Once \(u_t^\theta\) is trained, we can simulate the flow model
\[\text{d}X_t = u_t^\theta(X_t)\text{d}t,\quad X_0\sim p_{init}\]Flow Matching Training Procedure (here for the Gaussian CondOT path \(p_t(x|z) = \mathcal{N}(tz, (1-t)^2I_d)\)) Require: A dataset of samples \(z\sim p_{target}\), neural network \(u_t^\theta\) For each mini-batch of data do
- Sample a data example \(z\) from the dataset
- Sample a random time \(t\sim \text{Unif}_{[0,1]}\)
- Sample noise \(\epsilon\sim \mathcal{N}(0, I_d)\)
- Set \(x = tz+(1-t)\epsilon\)
- Compute loss \(\mathcal{L}(\theta) = \|u_t^\theta(x)-(z-\epsilon)\|^2\)
- Update model parameter \(\theta\) via gradient descent on \(\mathcal{L}(\theta)\)
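The loop above can be sketched end to end. To keep the sketch self-contained, the neural network is replaced by a hypothetical toy model \(u_t^\theta(x) = \theta\), a constant velocity field whose gradient can be written by hand; in practice \(u_t^\theta\) is a neural network updated via autograd, but the sampling of \(z\), \(t\), \(\epsilon\), and \(x\) is identical.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy "dataset" of samples z ~ p_target: a point mass at (2, 2) for illustration.
dataset = np.full((1024, 2), 2.0)

# Toy model u_theta(x, t) = theta, so that
# grad_theta ||theta - (z - eps)||^2 = 2 (theta - (z - eps)).
theta = np.zeros(2)
lr, batch = 0.1, 64

for step in range(300):
    idx = rng.integers(len(dataset), size=batch)
    z = dataset[idx]                           # sample data examples z
    t = rng.uniform(size=(batch, 1))           # t ~ Unif[0, 1]
    eps = rng.standard_normal((batch, 2))      # eps ~ N(0, I_d)
    x = t * z + (1 - t) * eps                  # CondOT sample x ~ p_t(.|z)
                                               # (unused by this constant toy model)
    target = z - eps                           # conditional velocity u_t(x|z)
    grad = 2 * (theta - target).mean(axis=0)   # hand-written CFM gradient
    theta -= lr * grad                         # gradient descent step

print(theta)  # approaches E[z - eps] = (2, 2)
```

Because the toy model ignores \(x\), the learned \(\theta\) is simply the mean conditional velocity; a real network would additionally fit the dependence on \(x\) and \(t\).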
Score Matching
Recall the SDE with the same marginal distribution
\[\text{d}X_t = \left[u_t^{target}(X_t)+\frac{\sigma^2_t}{2}\nabla\log p_t(X_t)\right]\text{d}t+\sigma_t\text{d}W_t\]In a similar way, we define the score matching loss and the conditional score matching loss:
\[\begin{align} \mathcal{L}_{\text{SM}} &= \mathbb{E}_{t\sim \text{Unif}, z\sim p_{target}, x\sim p_t(\cdot\mid z)}[\|s_t^\theta(x)-\nabla\log p_t(x)\|^2]\\ \mathcal{L}_{\text{CSM}} &= \mathbb{E}_{t\sim \text{Unif}, z\sim p_{target}, x\sim p_t(\cdot\mid z)}[\|s_t^\theta(x)-\nabla\log p_t(x\mid z)\|^2] \end{align}\]Consider the toy example from the last post, \(p_t(x\mid z) = \mathcal{N}(\mu_t(z), \sigma_t^2I_d)\); the conditional score is \(\nabla\log p_t(x\mid z) = -\frac{x-\mu_t(z)}{\sigma^2_t}\), so writing \(x = \mu_t(z)+\sigma_t\epsilon\) gives
\[\mathcal{L}_{\text{CSM}} = \mathbb{E}_{t\sim \text{Unif}, z\sim p_{target}, \epsilon\sim\mathcal{N}(0,I_d)}[\frac{1}{\sigma^2_t}\|\sigma_ts_t^\theta(\mu_t(z)+\sigma_t\epsilon)+\epsilon\|^2]\]Notice that the score network \(s_t^\theta\) essentially learns to predict the noise that was used to corrupt a data sample \(z\); for this reason, the above training loss is also called denoising score matching in earlier work. In Denoising Diffusion Probabilistic Models, the constant \(\frac{1}{\sigma_t^2}\) is dropped and \(s_t^\theta\) is reparameterized into a noise predictor network \(\epsilon_t^\theta: \mathbb{R}^d\times[0,1]\rightarrow\mathbb{R}^d\) via:
\[-\sigma_ts_t^\theta(x) = \epsilon_t^\theta(x) \Rightarrow \mathcal{L}_\text{DDPM}=\mathbb{E}_{t\sim\text{Unif}, z\sim p_{target},\epsilon\sim\mathcal{N}(0,I_d)}[\|\epsilon_t^\theta(\mu_t(z)+\sigma_t\epsilon)-\epsilon\|^2]\]Score Matching Training Procedure (for a Gaussian probability path) Require: A dataset of samples \(z\sim p_{target}\), neural network \(s_t^\theta\) For each mini-batch of data do
- Sample a data example \(z\) from the dataset
- Sample a random time \(t\sim \text{Unif}_{[0,1]}\)
- Sample noise \(\epsilon \sim \mathcal{N}(0,I_d)\)
- Set \(x_t = \mu_t(z)+\sigma_t\epsilon\)
- Compute loss \(\mathcal{L}(\theta) = \|s_t^\theta(x_t) + \frac{\epsilon}{\sigma_t}\|^2\)
- Update model parameter \(\theta\) via gradient descent on \(\mathcal{L}(\theta)\)
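The loop above can likewise be sketched with the noise-prediction (DDPM-style) objective \(\|\epsilon_t^\theta - \epsilon\|^2\), which avoids the \(1/\sigma_t\) weighting near \(t=1\). As before, the network is replaced by a hypothetical constant predictor \(\epsilon_t^\theta(x) = \theta\) with a hand-written gradient, and the path \(\mu_t(z)=tz\), \(\sigma_t=1-t\) is assumed purely for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy "dataset" of samples z ~ p_target: a point mass at (2, 2) for illustration.
dataset = np.full((1024, 2), 2.0)

# Toy noise predictor eps_theta(x, t) = theta, so that
# grad_theta ||theta - eps||^2 = 2 (theta - eps).
theta = np.full(2, 5.0)   # deliberately bad initialization
lr, batch = 0.1, 64

for step in range(300):
    idx = rng.integers(len(dataset), size=batch)
    z = dataset[idx]                           # sample data examples z
    t = rng.uniform(size=(batch, 1))           # t ~ Unif[0, 1]
    eps = rng.standard_normal((batch, 2))      # eps ~ N(0, I_d)
    x_t = t * z + (1 - t) * eps                # x_t = mu_t(z) + sigma_t * eps
                                               # (unused by this constant toy model)
    grad = 2 * (theta - eps).mean(axis=0)      # hand-written DDPM gradient
    theta -= lr * grad                         # gradient descent step

print(theta)  # approaches E[eps] = (0, 0)
```

A real \(\epsilon_t^\theta\) depends on \(x_t\) and \(t\), so it learns which noise produced the observed corruption rather than just the noise mean.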
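Once the fields are learned, samples are drawn by numerically integrating the SDE above, e.g. with the Euler–Maruyama scheme; setting \(\sigma_t = 0\) recovers the deterministic flow ODE. The sketch below sidesteps training by using a toy Gaussian path \(p_t = \mathcal{N}(tz^*, (1-t)^2I_d)\) whose velocity and score are known in closed form; all names are illustrative, and integration stops at \(T<1\) because both fields blow up as \(t\to1\).

```python
import numpy as np

def em_sample(z_star, sigma=0.5, n_steps=200, T=0.9, seed=0):
    """Euler-Maruyama simulation of
        dX_t = [u_t(X_t) + (sigma^2/2) * score_t(X_t)] dt + sigma dW_t
    for the toy path p_t = N(t z*, (1-t)^2 I), where in closed form
        u_t(x)     = (z* - x) / (1 - t)
        score_t(x) = -(x - t z*) / (1 - t)^2.
    """
    rng = np.random.default_rng(seed)
    x = rng.standard_normal(z_star.shape)          # X_0 ~ p_init = N(0, I_d)
    dt = T / n_steps
    t = 0.0
    for _ in range(n_steps):
        u = (z_star - x) / (1 - t)                 # marginal velocity field
        score = -(x - t * z_star) / (1 - t) ** 2   # marginal score
        drift = u + 0.5 * sigma**2 * score         # Langevin-corrected drift
        x = x + drift * dt + sigma * np.sqrt(dt) * rng.standard_normal(x.shape)
        t += dt
    return x

z_star = np.array([2.0, 2.0])
print(em_sample(z_star))             # lands near 0.9 * z_star = (1.8, 1.8)
print(em_sample(z_star, sigma=0.0))  # sigma = 0 recovers the flow ODE
```

With \(\sigma = 0\) the Euler update happens to be exact for this linear field, returning \(X_0 + T(z^*-X_0)\); with \(\sigma>0\) the sample fluctuates around the same point at the scale \(\sigma_T = 1-T\).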
