Flow Matching and Diffusion Models 2


Constructing the Training Target

A quick review

In the last blog, we constructed flow and diffusion models, where we obtain trajectories \((X_t)_{0\leq t \leq 1}\) by simulating the ODE/SDE

\[\begin{align} X_0 \sim p_{init}, \quad \text{d}X_t &= u^\theta_t(X_t)\text{d}t, &\quad \triangleright \text{ Flow model} \\ X_0 \sim p_{init}, \quad \text{d}X_t &= u^\theta_t(X_t)\text{d}t+\sigma_t\text{d}W_t, &\quad \triangleright \text{ Diffusion Model} \end{align}\]

where \(u_t^\theta\) is a neural network and \(\sigma_t\) is a fixed diffusion coefficient. Naturally, if we randomly choose \(\theta\), the output will be nonsense. We train the network by minimizing a loss function \(\mathcal{L}(\theta)\), such as the mean-squared error

\[\mathcal{L}(\theta) = \mathbb{E}\left[\left\| u_t^\theta(x)-u_t^{target}(x)\right\| ^2\right],\]

where \(u_t^{target}\) is the training target we need to approximate. In this blog, we will introduce how to find an equation for the training target.
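To make the loss concrete, here is a minimal numpy sketch of a Monte-Carlo estimate of this mean-squared error; `u_theta` and `u_target` are hypothetical stand-ins for the network and the training target.

```python
import numpy as np

def flow_matching_loss(u_theta, u_target, xs, ts):
    """Monte-Carlo estimate of E[ ||u_theta(x, t) - u_target(x, t)||^2 ]."""
    diffs = u_theta(xs, ts) - u_target(xs, ts)
    return np.mean(np.sum(diffs ** 2, axis=-1))

# Toy check: if the "network" already equals the target, the loss is zero.
target = lambda x, t: -x  # a hypothetical target field, for illustration only
rng = np.random.default_rng(0)
xs = rng.normal(size=(128, 2))
ts = rng.uniform(size=(128, 1))
print(flow_matching_loss(target, target, xs, ts))  # 0.0
```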

Conditional and Marginal Probability Path

A conditional probability path is a set of distributions \(p_t(x\mid z)\) over \(\mathbb{R}^d\) that gradually converts \(p_{init}\) into a point mass at a single data point \(z\)

\[p_0(\cdot\mid z) = p_{init}, \quad p_1(\cdot\mid z) = \delta_z \quad \text{for all } z\in\mathbb{R}^d.\]

The marginal probability path \(p_t(x)\) is defined from the conditional probability path

\[\begin{align} z \sim p_{target}, x\sim p_t(\cdot\mid z) \Rightarrow x\sim p_t &\triangleright \text{sampling from marginal path},\\ p_t(x) = \int p_t(x\mid z)p_{target}(z)\text{d}z &\triangleright \text{density of marginal path}. \end{align}\]

The conditional probability path is the “path” that specifies a gradual interpolation: as \(t\) varies from 0 to 1, the distribution conditioned on a single data point \(z\) is converted from \(p_{init}\) into a point mass at \(z\). The marginal probability path is the overall distribution of \(x\) at time \(t\). The density formula above shows how the conditional and marginal probability paths are connected.

Here we provide a toy example. First we specify the two distributions

\[p_{init} = \mathcal{N}(0,\mathbf{I}),\quad p_{target}=\delta_z,\]

where \(\delta_z\) is the point mass at a single data point \(z\).

Assume that the conditional probability path is a Gaussian distribution whose mean and variance vary linearly with time \(t\), which is to say

\[p_t(x\mid z) = \mathcal{N}(x;\mu_t(z),\sigma_t^2\mathbf{I}).\]

The mean and variance must satisfy

\[\mu_0(z) = 0, \quad \mu_1(z) = z, \quad \sigma^2_0 = 1, \quad \sigma^2_1 = 0,\]

so that \(p_0(\cdot\mid z) = \mathcal{N}(0,\mathbf{I})\) and \(p_1(\cdot\mid z) = \delta_z\).

Therefore, an intuitive choice is

\[\mu_t(z) = t\cdot z, \quad \sigma^2_t = (1-t).\]

Finally we obtain the marginal probability path

\[\begin{align} p_t(x) &= \int p_t(x\mid z^\prime)p_{target}(z^\prime)\text{d}z^\prime \\ &= \int \mathcal{N}(x;\mu_t(z^\prime),\sigma_t^2\mathbf{I})\,\delta_z(z^\prime)\,\text{d}z^\prime\\ &=\mathcal{N}(x;t\cdot z,(1-t)\mathbf{I}). \end{align}\]
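For this toy example, sampling from \(p_t(x\mid z)\) is one line of numpy; the empirical mean and variance match \(t\cdot z\) and \((1-t)\). A sketch (the values of \(z\) and \(t\) below are arbitrary):

```python
import numpy as np

rng = np.random.default_rng(0)
z = np.array([2.0, -1.0])   # a hypothetical data point (p_target = delta_z)
t = 0.5

# Sample p_t(x|z) = N(t*z, (1-t) I) by reparameterization: x = t*z + sqrt(1-t)*eps
eps = rng.normal(size=(100_000, 2))
x = t * z + np.sqrt(1 - t) * eps

print(x.mean(axis=0))   # ≈ t*z   = [1.0, -0.5]
print(x.var(axis=0))    # ≈ (1-t) = [0.5,  0.5]
```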

How should we understand \(z\) as the condition in the conditional probability path?
In brief, \(z\) determines the start and end of the path. First, the conversion from \(p_{init}\) to \(p_{target}\) can be achieved by many probability paths; by conditioning on \(z\), we actually define one path per data point. Second, \(z\) serves as the signal in denoising: \(p_t(x\mid z)\) tells the model what a sample \(x\) looks like at time \(t\). Third, by integrating all the conditional probability paths weighted by their probability, we obtain the marginal probability path, which describes the distribution of all data points at time \(t\).

Conditional and Marginal Vector Fields

For every data point \(z\in \mathbb{R}^d\), let \(u_t^{target}(\cdot\mid z)\) denote a conditional vector field, which satisfies

\[X_0\sim p_{init},\quad \frac{\text{d}}{\text{d}t}X_t = u_t^{target}(X_t\mid z) \Rightarrow X_t\sim p_t(\cdot\mid z)\quad (0\leq t\leq 1).\]

Then the marginal vector field, defined by

\[u_t^{target}(x) = \int u_t^{target}(x\mid z)\frac{p_t(x\mid z)p_{target}(z)}{p_t(x)}\text{d}z \quad \triangleright\text{marginalization trick},\]

follows the marginal probability path, i.e.

\[X_0\sim p_{init},\quad \frac{\text{d}}{\text{d}t}X_t = u_t^{target}(X_t) \Rightarrow X_t\sim p_t \quad (0\leq t\leq 1).\]

How should we understand the marginalization trick? First, for each data point \(z\), we can define a conditional vector field \(u_t^{target}(\cdot\mid z)\) whose ODE simulates the conditional probability path \(p_t(\cdot \mid z)\). The marginal vector field is then the weighted average of all the conditional vector fields, where the weight \(\frac{p_t(x\mid z)p_{target}(z)}{p_t(x)} = p_t(z\mid x)\) is the posterior probability that \(z\) is the origin of \(x\) at time \(t\).

This can also be proved via the Continuity Equation. Consider a flow model with vector field \(u_t^{target}\) and \(X_0\sim p_{init}\). Then \(X_t\sim p_t\) for all \(0\leq t \leq 1\) if and only if

\[\frac{\partial p_t(x)}{\partial t} = -\nabla_\text{x}\cdot(p_t(x)u_t^{target}(x)) \quad \text{for all } x\in\mathbb{R}^d, 0\leq t\leq 1.\]

We will provide a proof in the appendix at the end of the blog. Here we just give an intuitive explanation: the LHS is the rate of change of the probability mass at \(x\), and the RHS is the net inflow of probability mass. Since probability mass is conserved, the two sides must be equal. Now we proceed to prove the marginalization trick

\[\begin{align} \frac{\partial p_t(x)}{\partial t} &= \frac{\partial}{\partial t}\int p_t(x\mid z)p_{target}(z)\text{d}z \\ &= \int \frac{\partial p_t(x\mid z)}{\partial t} p_{target}(z)\text{d}z \\ &= \int -\nabla_\text{x}\cdot\left[ (p_t(x\mid z)u_t^{target}(x\mid z))\right]p_{target}(z)\text{d}z\\ &= -\nabla_\text{x} \cdot \left[\int(p_t(x\mid z)u_t^{target}(x\mid z))p_{target}(z)\text{d}z\right]\\ &= -\nabla_\text{x}\cdot(p_t(x)u_t^{target}(x)), \end{align}\]

which implies

\[u_t^{target}(x) = \int u_t^{target}(x\mid z)\frac{p_t(x\mid z)p_{target}(z)}{p_t(x)}\text{d}z\]
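The marginalization trick can be sanity-checked numerically. The sketch below takes a hypothetical two-point data distribution \(z \in \{-1, +1\}\) (probability 1/2 each) with the Gaussian path \(p_t(x\mid z) = \mathcal{N}(tz,\, 1-t)\), whose conditional field is \(u_t^{target}(x\mid z) = z - \frac{x - tz}{2(1-t)}\), forms the marginal field by posterior weighting, and verifies the continuity equation by finite differences:

```python
import numpy as np

def gauss(x, mu, var):
    return np.exp(-(x - mu) ** 2 / (2 * var)) / np.sqrt(2 * np.pi * var)

zs = [-1.0, 1.0]                       # two data points, probability 1/2 each

def p_marg(x, t):                      # marginal density p_t(x)
    return 0.5 * sum(gauss(x, t * z, 1 - t) for z in zs)

def u_cond(x, t, z):                   # conditional field for mu_t = t*z, sigma_t = sqrt(1-t)
    return z - (x - t * z) / (2 * (1 - t))

def u_marg(x, t):                      # marginalization trick: posterior-weighted average
    w = np.stack([0.5 * gauss(x, t * z, 1 - t) for z in zs])
    w /= w.sum(axis=0)                 # posterior weights p_t(z|x)
    return sum(w[i] * u_cond(x, t, z) for i, z in enumerate(zs))

# Continuity equation check at t = 0.4: dp/dt vs -d/dx (p u), by finite differences
x = np.linspace(-4, 4, 2001)
t, h = 0.4, 1e-5
lhs = (p_marg(x, t + h) - p_marg(x, t - h)) / (2 * h)
rhs = -np.gradient(p_marg(x, t) * u_marg(x, t), x[1] - x[0])
print(np.max(np.abs(lhs - rhs)))       # small: the two sides agree
```

The two sides match up to finite-difference error, which is exactly what the proof above predicts.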

Consider the toy example mentioned before and let us compute the Gaussian conditional vector field. We have defined the conditional probability path by

\[p_t(x\mid z) = \mathcal{N}(x;\mu_t(z),\sigma_t^2\mathbf{I}).\]

Hence we can represent \(x_t\) as a linear combination of the data point \(z\) and the noise \(x_0\)

\[x_t = \mu_t(z)+\sigma_t x_0, \\ \mu_t(z) = t\cdot z,\quad \sigma^2_t=(1-t),\quad x_0\sim \mathcal{N}(0, \mathbf{I}).\]

For simplicity and generality, we keep \(\mu_t(z)\) and \(\sigma_t\) general in the following computation.

We assert that the conditional Gaussian vector field is given by

\[u^{target}_t(x_t\mid z) = \dot{\mu}_t(z)-\frac{\dot{\sigma}_t}{\sigma_t}\mu_t(z)+ \frac{\dot{\sigma}_t}{\sigma_t}x_t.\]

Proof. First we define the conditional flow model

\[\begin{align} \phi_t^{target}(x\mid z) &= \mu_t(z)+\sigma_t x, \\ X_t = \phi_t^{target}(X_0\mid z) &= \mu_t(z)+\sigma_tX_0. \end{align}\]

Then let us extract the vector field.

\[\begin{align} &\frac{\text{d}}{\text{d}t}\phi_t^{target}(x\mid z) = u_t^{target}(\phi_t^{target}(x\mid z)\mid z)\\ &\Leftrightarrow \dot{\mu}_t(z)+\dot{\sigma}_tx_0= u_t^{target}(\mu_t(z)+\sigma_tx_0\mid z)\\ &\Leftrightarrow \dot{\mu}_t(z)+\dot{\sigma}_t\left(\frac{x_t-\mu_t(z)}{\sigma_t}\right) = u_t^{target}(x_t\mid z)\\ &\Leftrightarrow u_t^{target}(x_t\mid z) = \dot{\mu}_t(z)-\frac{\dot{\sigma}_t}{\sigma_t}\mu_t(z)+ \frac{\dot{\sigma}_t}{\sigma_t}x_t \quad \blacksquare \end{align}\]
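This formula can be checked numerically by differentiating the flow \(\phi_t^{target}(x_0\mid z)=\mu_t(z)+\sigma_t x_0\) with respect to \(t\). A sketch for the toy schedule \(\mu_t(z)=tz\), \(\sigma_t=\sqrt{1-t}\), with arbitrary \(z\) and \(x_0\):

```python
import numpy as np

z, x0 = 1.5, 0.7                       # arbitrary data point and noise sample
t, h = 0.3, 1e-6

mu  = lambda t: t * z                  # mu_t(z) = t*z
sig = lambda t: np.sqrt(1 - t)         # sigma_t = sqrt(1 - t)
phi = lambda t: mu(t) + sig(t) * x0    # conditional flow

# Closed form: u_t(x_t|z) = mu_dot + (sig_dot / sig) * (x_t - mu_t)
x_t = phi(t)
mu_dot, sig_dot = z, -0.5 / np.sqrt(1 - t)
u = mu_dot + (sig_dot / sig(t)) * (x_t - mu(t))

u_fd = (phi(t + h) - phi(t - h)) / (2 * h)   # finite-difference d/dt of the flow
print(abs(u - u_fd))                          # ~ 0
```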

Conditional and Marginal Score Functions

Fokker-Planck Equation

Let \(p_t\) be a probability path and let us consider the SDE

\[X_0\sim p_{init},\quad \text{d}X_t = u_t(X_t)\text{d}t+\sigma_t\text{d}W_t.\]

Then \(X_t\) has distribution \(p_t\) for all \(0\leq t\leq 1\) if and only if the Fokker-Planck Equation holds

\[\frac{\partial p_t(x)}{\partial t} = -\nabla_x\cdot(p_t(x)u_t(x))+\frac{\sigma^2_t}{2}\Delta p_t(x)\quad \text{for all } x\in\mathbb{R}^d, 0\leq t\leq 1.\]

SDE Extension Trick

For a diffusion coefficient \(\sigma_t\geq 0\), we can now construct an SDE which follows the same probability path

\[X_0\sim p_{\text{init}}, \quad \text{d}X_t = \underbrace{\left[u_t^{\text{target}}(X_t)+\frac{\sigma^2_t}{2}\nabla\log p_t(X_t)\right]}_{\text{drift term}}\text{d}t+\underbrace{\sigma_t}_{\text{diffusion term}}\text{d}W_t\\ \Rightarrow X_t \sim p_t\]

Proof.

\[\begin{align} \frac{\partial p_t(x)}{\partial t} &= -\nabla_x\cdot(p_t(x)u_t^{target}(x))\\ &= -\nabla_x\cdot(p_t(x)u_t^{target}(x)) - \frac{\sigma_t^2}{2}\Delta p_t(x) + \frac{\sigma_t^2}{2}\Delta p_t(x)\\ &=-\nabla_x\cdot(p_t(x)u_t^{target}(x)) - \nabla_x\cdot\left(\frac{\sigma_t^2}{2} \nabla_x p_t(x)\right) + \frac{\sigma_t^2}{2}\Delta p_t(x)\\ &=-\nabla_x\cdot\left(p_t(x)u_t^{target}(x)\right) - \nabla_x\cdot\left(\frac{\sigma_t^2}{2}p_t(x) \nabla_x\log p_t(x)\right) + \frac{\sigma_t^2}{2}\Delta p_t(x)\\ &=-\nabla_x\cdot\left(p_t(x)\left[u_t^{target}(x)+\frac{\sigma_t^2}{2}\nabla_x\log p_t(x)\right]\right) + \frac{\sigma_t^2}{2}\Delta p_t(x)\quad \blacksquare \end{align}\]
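To see the trick in action, the sketch below simulates the extended SDE with Euler–Maruyama for the single-data-point toy path \(p_t(x\mid z)=\mathcal{N}(tz,\,1-t)\), where the marginal field coincides with the conditional one and the score is \(\nabla\log p_t(x) = -(x-tz)/(1-t)\); \(\sigma_t=1\) is an arbitrary choice. Samples at \(t=0.99\) should be approximately \(\mathcal{N}(0.99z,\,0.01)\):

```python
import numpy as np

rng = np.random.default_rng(0)
z, sigma, dt = 1.0, 1.0, 1e-3
x = rng.normal(size=20_000)                     # X_0 ~ p_init = N(0, 1)

for k in range(990):                            # Euler-Maruyama from t=0 to t=0.99
    t = k * dt
    u     = z - (x - t * z) / (2 * (1 - t))     # vector field of the toy path
    score = -(x - t * z) / (1 - t)              # grad log p_t
    drift = u + 0.5 * sigma**2 * score
    x += drift * dt + sigma * np.sqrt(dt) * rng.normal(size=x.size)

print(x.mean(), x.std())                        # ≈ 0.99*z and ≈ sqrt(0.01) = 0.1
```

Even though noise is injected at every step, the extra score term in the drift compensates for it, and the samples follow the same probability path as the deterministic flow.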

Consider the case where \(p_t=p\) for a fixed distribution \(p\). In this case, we set \(u_t^{target}=0\) to obtain the SDE

\[\text{d}X_t = \frac{\sigma_t^2}{2}\nabla_x\log p(X_t)\,\text{d}t+\sigma_t\text{d}W_t.\]

This is commonly known as Langevin dynamics. As a Markov process, even if \(X_0\sim p^\prime\neq p\), under mild conditions the distribution of \(X_t\) still converges to the target distribution \(p\).
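A quick numerical illustration of this convergence, with arbitrarily chosen \(p=\mathcal{N}(0,1)\), constant \(\sigma=\sqrt{2}\), and a deliberately wrong initialization \(p^\prime=\mathcal{N}(5,1)\):

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.normal(loc=5.0, size=10_000)      # X_0 ~ p' = N(5, 1), not the target

sigma, dt = np.sqrt(2.0), 0.01
score = lambda x: -x                      # grad log p for the target p = N(0, 1)

for _ in range(1000):                     # Euler-Maruyama up to T = 10
    x += 0.5 * sigma**2 * score(x) * dt + sigma * np.sqrt(dt) * rng.normal(size=x.size)

print(x.mean(), x.std())                  # ≈ 0 and ≈ 1: converged to p
```

Despite starting far from the target, the samples forget their initialization and settle into \(\mathcal{N}(0,1)\), the stationary distribution of this Langevin SDE.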