Flow Matching and Diffusion Models 3


Training the Generative Model

Flow Matching

Recall that the flow matching loss is defined as

\[\begin{align} \mathcal{L}_\text{FM}(\theta) &= \mathbb{E}_{t\sim \text{Unif},x\sim p_t}[\|u_t^\theta(x)-u_t^{target}(x)\|^2]\\ &= \mathbb{E}_{t\sim \text{Unif},z\sim p_{target}, x\sim p_t(\cdot\mid z)}[\|u_t^\theta(x)-u_t^{target}(x)\|^2] \end{align}\]

The second equality follows from the law of total expectation. In practice, however, the marginal velocity field \(u_t^{target}(x)\) is intractable to compute, while the conditional velocity field \(u_t^{target}(x \mid z)\) is tractable. This motivates the conditional flow matching loss

\[\mathcal{L}_{\text{CFM}} = \mathbb{E}_{t\sim \text{Unif}, z\sim p_{target}, x\sim p_t(\cdot\mid z)}[\|u_t^\theta(x)-u_t^{target}(x\mid z)\|^2]\]

Now we prove that \(\mathcal{L}_{\text{FM}} = \mathcal{L}_{\text{CFM}} + C\), where \(C\) is a constant independent of \(\theta\); the two losses therefore have the same gradients and the same minimizers.

Proof. \(\begin{align} \mathcal{L}_\text{FM}&= \mathbb{E}_{t\sim \text{Unif},x\sim p_t}[\|u_t^\theta(x)-u_t^{target}(x)\|^2]\\ &= \mathbb{E}_{t\sim \text{Unif},x\sim p_t}[\|u_t^\theta(x)\|^2-2u_t^\theta(x)^Tu_t^{target}(x)+\|u_t^{target}(x)\|^2]\\ &= \mathbb{E}_{t\sim \text{Unif},z\sim p_{target}, x\sim p_t(\cdot\mid z)}[\|u_t^\theta(x)\|^2]-2\mathbb{E}_{t\sim \text{Unif},z\sim p_{target}, x\sim p_t(\cdot\mid z)}[u_t^\theta(x)^Tu_t^{target}(x\mid z)]+C_1\\ &= \mathbb{E}_{t\sim \text{Unif}, z\sim p_{target}, x\sim p_t(\cdot\mid z)}[\|u_t^\theta(x)-u_t^{target}(x\mid z)\|^2]+C_2+C_1\\ &= \mathcal{L}_{\text{CFM}} + C \end{align}\)

The third equality is the key step: it applies the law of total expectation to the cross term, using \(u_t^{target}(x) = \mathbb{E}[u_t^{target}(x\mid z)\mid X_t = x]\). Here \(C_1 = \mathbb{E}[\|u_t^{target}(x)\|^2]\) and \(C_2 = -\mathbb{E}[\|u_t^{target}(x\mid z)\|^2]\) are both independent of \(\theta\). ∎

Once \(u_t^\theta\) is trained, we can simulate the flow model

\[\text{d}X_t = u_t^\theta(X_t)\text{d}t,\quad X_0\sim p_{init}\]
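Simulating this ODE reduces to numerical integration. A minimal Euler-method sketch, where the velocity field `u` is a stand-in for a trained \(u_t^\theta\):

```python
import numpy as np

def simulate_flow(u, x0, n_steps=1000):
    """Integrate dX_t = u(t, X_t) dt from t = 0 to t = 1 with the Euler method."""
    x, dt = x0.copy(), 1.0 / n_steps
    for i in range(n_steps):
        x = x + dt * u(i * dt, x)
    return x

# Sanity check on a field with a known flow: for u(t, x) = x, X_1 = e * X_0.
x1 = simulate_flow(lambda t, x: x, np.ones(3), n_steps=10_000)
```

In practice higher-order integrators (e.g. Heun or Runge-Kutta) trade more function evaluations per step for fewer steps.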

Flow Matching Training Procedure (here for the Gaussian CondOT path \(p_t(x\mid z) = \mathcal{N}(tz, (1-t)^2I_d)\))

Require: a dataset of samples \(z\sim p_{target}\), a neural network \(u_t^\theta\)

For each mini-batch of data do

  • Sample a data example \(z\) from the dataset
  • Sample a random time \(t\sim \text{Unif}_{[0,1]}\)
  • Sample noise \(\epsilon\sim \mathcal{N}(0, I_d)\)
  • Set \(x = tz+(1-t)\epsilon\)
  • Compute loss \(\mathcal{L}(\theta) = \|u_t^\theta(x)-(z-\epsilon)\|^2\)
  • Update model parameter \(\theta\) via gradient descent on \(\mathcal{L}(\theta)\)
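The loop above can be sketched end to end. The linear model below is a hypothetical stand-in for the neural network \(u_t^\theta\) (a real implementation would use an MLP over \((x, t)\) and an optimizer such as Adam); the path-sampling and loss lines follow the procedure exactly:

```python
import numpy as np

rng = np.random.default_rng(0)
a, b, lr = 0.0, 0.0, 1e-2   # parameters of the stand-in model u(t, x) = a*x + b

def fm_train_step(z, a, b):
    """One conditional flow matching step on the CondOT path x = t*z + (1-t)*eps."""
    t = rng.uniform()                       # t ~ Unif[0, 1]
    eps = rng.standard_normal(z.shape)      # eps ~ N(0, I_d)
    x = t * z + (1 - t) * eps               # x ~ p_t(. | z)
    target = z - eps                        # conditional velocity u_t^target(x | z)
    pred = a * x + b
    loss = np.mean((pred - target) ** 2)
    grad = 2 * (pred - target) / z.size     # analytic gradient of the loss
    return a - lr * np.sum(grad * x), b - lr * np.sum(grad), loss

z = rng.standard_normal(8)                  # one data example from the dataset
for _ in range(500):
    a, b, loss = fm_train_step(z, a, b)
```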

Score Matching

Recall the SDE with the same marginal distribution

\[\text{d}X_t = \left[u_t^{target}(X_t)+\frac{\sigma^2_t}{2}\nabla\log p_t(X_t)\right]\text{d}t+\sigma_t\text{d}W_t\]

In a similar way, we define the score matching loss and the conditional score matching loss:

\[\begin{align} \mathcal{L}_{\text{SM}} &= \mathbb{E}_{t\sim \text{Unif}, z\sim p_{target}, x\sim p_t(\cdot\mid z)}[\|s_t^\theta(x)-\nabla\log p_t(x)\|^2]\\ \mathcal{L}_{\text{CSM}} &= \mathbb{E}_{t\sim \text{Unif}, z\sim p_{target}, x\sim p_t(\cdot\mid z)}[\|s_t^\theta(x)-\nabla\log p_t(x\mid z)\|^2] \end{align}\]

Consider the Gaussian probability path from the last post, \(p_t(x\mid z) = \mathcal{N}(\mu_t(z), \sigma_t^2I_d)\), whose conditional score is \(\nabla\log p_t(x\mid z) = -\frac{x-\mu_t(z)}{\sigma^2_t}\). Substituting \(x = \mu_t(z)+\sigma_t\epsilon\) with \(\epsilon\sim\mathcal{N}(0,I_d)\) gives

\[\mathcal{L}_{\text{CSM}} = \mathbb{E}_{t\sim \text{Unif}, z\sim p_{target}, \epsilon\sim\mathcal{N}(0,I_d)}\left[\frac{1}{\sigma^2_t}\|\sigma_ts_t^\theta(\mu_t(z)+\sigma_t\epsilon)+\epsilon\|^2\right]\]

Notice that the score network \(s_t^\theta\) essentially learns to predict (up to the scaling \(-\frac{1}{\sigma_t}\)) the noise \(\epsilon\) that was used to corrupt the data sample \(z\), so this training loss is also called denoising score matching in earlier work. In Denoising Diffusion Probabilistic Models (DDPM), the constant \(\frac{1}{\sigma_t^2}\) is dropped and \(s_t^\theta\) is reparameterized into a noise predictor network \(\epsilon_t^\theta: \mathbb{R}^d\times[0,1]\rightarrow\mathbb{R}^d\) via:

\[-\sigma_ts_t^\theta(x) = \epsilon_t^\theta(x) \Rightarrow \mathcal{L}_\text{DDPM}=\mathbb{E}_{t\sim\text{Unif}, z\sim p_{target},\epsilon\sim\mathcal{N}(0,I_d)}[\|\epsilon_t^\theta(\mu_t(z)+\sigma_t\epsilon)-\epsilon\|^2]\]
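The reparameterization is just a change of variables, which is easy to verify numerically. The vectors below are arbitrary stand-ins for \(\mu_t(z)\), the noise, and the score network's output:

```python
import numpy as np

rng = np.random.default_rng(1)
sigma_t = 0.5
mu = rng.standard_normal(4)      # stands in for mu_t(z)
eps = rng.standard_normal(4)     # eps ~ N(0, I_d)
x = mu + sigma_t * eps           # x ~ p_t(. | z)

s = rng.standard_normal(4)       # hypothetical score network output s_t^theta(x)
eps_pred = -sigma_t * s          # the reparameterized noise prediction eps_t^theta(x)

csm = (1 / sigma_t**2) * np.sum((sigma_t * s + eps) ** 2)  # CSM integrand
ddpm = np.sum((eps_pred - eps) ** 2)                       # DDPM integrand (weight dropped)
# The two agree up to the constant weight: csm == ddpm / sigma_t**2
```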

Score Matching Training Procedure (here for a Gaussian probability path \(p_t(x\mid z) = \mathcal{N}(\mu_t(z), \sigma_t^2I_d)\))

Require: a dataset of samples \(z\sim p_{target}\), a neural network \(s_t^\theta\)

For each mini-batch of data do

  • Sample a data example \(z\) from the dataset
  • Sample a random time \(t\sim \text{Unif}_{[0,1]}\)
  • Sample noise \(\epsilon \sim \mathcal{N}(0,I_d)\)
  • Set \(x_t = \mu_t(z)+\sigma_t\epsilon\)
  • Compute loss \(\mathcal{L}(\theta) = \|s_t^\theta(x_t) + \frac{\epsilon}{\sigma_t}\|^2\)
  • Update model parameter \(\theta\) via gradient descent on \(\mathcal{L}(\theta)\)
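As a sanity check on this procedure, the exact conditional score \(-\frac{x-\mu_t(z)}{\sigma_t^2}\) should drive the per-sample loss to zero. A sketch assuming the CondOT-style schedules \(\mu_t(z)=tz\) and \(\sigma_t=1-t\):

```python
import numpy as np

rng = np.random.default_rng(2)
mu_t = lambda z, t: t * z          # assumed mean schedule
sigma_t = lambda t: 1.0 - t        # assumed noise schedule

def dsm_loss(s_theta, z):
    """Denoising score matching loss for a single sample, per the steps above."""
    t = rng.uniform(0.0, 0.99)     # t ~ Unif[0, 1], clipped so sigma_t(t) > 0
    eps = rng.standard_normal(z.shape)
    x_t = mu_t(z, t) + sigma_t(t) * eps
    return np.sum((s_theta(x_t, t) + eps / sigma_t(t)) ** 2)

z = rng.standard_normal(5)
exact_score = lambda x, t: -(x - mu_t(z, t)) / sigma_t(t) ** 2
loss = dsm_loss(exact_score, z)    # ~0 for the exact conditional score
```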