Flows and Flow Matching

An introduction to flow matching.


This blog post will provide a brief introduction to flow-based models and flow matching.

Flow-based models are a class of probabilistic generative models. The goal of probabilistic modeling is to model the distribution of a random variable \(X\). This is typically done in a data-driven fashion using examples \(\{x^{(i)}\}_{i=1}^N\) collected from the data distribution. We learn to approximate the probability density function of the data distribution with a model \(p(x;\theta)\), where \(\theta\) represents the parameters of a neural network. Specifically, flow-based models learn to model the data distribution through normalizing flows. The main benefit of flow-based models is that they are efficient for both sample generation and likelihood evaluation.

Normalizing Flows

The framework for normalizing flows is based on a rather simple fact from probability theory. Suppose \(\mathbf{x_0} \in \mathbb{R}^d\) is distributed according to \(p\) i.e. \(\mathbf{x_0} \sim p\). Let \(f: \mathbb{R}^d \to \mathbb{R}^d\) be an invertible and differentiable function. Now, let’s do a change of variables, \(\mathbf{x_1} = f(\mathbf{x_0})\). We can write \(q\), the distribution of the transformed variable \(\mathbf{x_1}\) in terms of \(p\). Namely,

\(\begin{align} q(\mathbf{x_1}) &= p(\mathbf{x_0})\left|\det \frac{\partial f^{-1}}{\partial \mathbf{x_1}}(\mathbf{x_1})\right| \notag \\ &= p\left(f^{-1}(\mathbf{x_1})\right)\left|\det \frac{\partial f^{-1}}{\partial \mathbf{x_1}}(\mathbf{x_1})\right|. \end{align}\) The notation \(\frac{\partial f^{-1}}{\partial \mathbf{x_1}}\) denotes the Jacobian of \(f^{-1}\). Also, because the transformation is invertible, we can obtain \(p\) from \(q\) too:

\[\begin{align*} p(\mathbf{x_0}) &= q(\mathbf{x_1})\left|\det \frac{\partial f}{\partial \mathbf{x_0}}(\mathbf{x_0}) \right| \\ &= q(f(\mathbf{x_0}))\left|\det \frac{\partial f}{\partial \mathbf{x_0}}(\mathbf{x_0}) \right|. \end{align*}\]

Example 1. Scaling and shifting a Gaussian. Suppose \(\mathbf{x_0} \in \mathbb{R}\) and \(\mathbf{x_0} \sim \mathcal{N}(0,1)\). Let \(\mathbf{x_1} = f(\mathbf{x_0}) = \sigma \mathbf{x_0} + \mathbf{\mu}\). Then \(\mathbf{x_0} = f^{-1}(\mathbf{x_1}) = \frac{\mathbf{x_1} - \mathbf{\mu}}{\sigma}\), so \(\frac{df^{-1}}{d\mathbf{x_1}} = \frac{1}{\sigma}\). In this case, the Jacobian is a \(1 \times 1\) matrix with positive entry, so the absolute value of its determinant is simply \(\frac{1}{\sigma}\). Recall the pdf of a canonical Gaussian:

\[p(\mathbf{x_0}) = \frac{1}{\sqrt{2\pi}}e^{-\frac{1}{2}\mathbf{x_0}^2}.\]

Applying the formula we obtain a Gaussian with mean \(\mu\) and variance \(\sigma^2\),

\[\begin{align*} q(\mathbf{x_1}) &= p\left(f^{-1}(\mathbf{x_1})\right)\left|\det \frac{\partial f^{-1}}{\partial \mathbf{x_1}}(\mathbf{x_1})\right| \\ &= \frac{1}{\sqrt{2\pi}}e^{-\frac{1}{2}\left(\frac{\mathbf{x_1} - \mathbf{\mu}}{\sigma}\right)^2}\frac{1}{\sigma} \\ &= \frac{1}{\sqrt{2\pi}\sigma}e^{-\frac{(\mathbf{x_1}-\mathbf{\mu})^2}{2\sigma^2}}. \end{align*}\]

Intuitively, multiplying \(\mathbf{x_0}\) by \(\sigma\) stretches the domain which changes the variance of the Gaussian. Adding \(\mu\) applies a shift to this stretched Gaussian.
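As a quick numerical sanity check (a minimal sketch using NumPy and SciPy; the parameter values are arbitrary), we can evaluate the right-hand side of eq. (1) for this example and confirm it matches the \(\mathcal{N}(\mu, \sigma^2)\) density.

```python
import numpy as np
from scipy.stats import norm

mu, sigma = 2.0, 1.5

def q_via_change_of_variables(x1):
    """Apply eq. (1): q(x1) = p(f^{-1}(x1)) * |d f^{-1} / d x1|."""
    x0 = (x1 - mu) / sigma           # f^{-1}(x1)
    jac = 1.0 / sigma                # |d f^{-1} / d x1|
    return norm.pdf(x0, loc=0.0, scale=1.0) * jac

xs = np.linspace(-5.0, 9.0, 200)
# The transformed density matches a N(mu, sigma^2) density exactly.
assert np.allclose(q_via_change_of_variables(xs), norm.pdf(xs, loc=mu, scale=sigma))
```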

Example 2 . Non-linear transformation of Gaussian. Suppose \(\begin{bmatrix} x \\ y\end{bmatrix} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})\). The pdf of a canonical Gaussian in 2D is:

\[p(x,y) = \frac{1}{2\pi}e^{-\frac{x^2 + y^2}{2}}.\]

Let’s apply a cubic transformation to each coordinate, \(u = x^3\) and \(v = y^3\). The inverse is \(x = u^\frac{1}{3}\) and \(y = v^\frac{1}{3}\). The Jacobian of this transformation is the following:

\[\begin{bmatrix} \frac{\partial x}{\partial u} & \frac{\partial x}{\partial v} \\ \frac{\partial y}{\partial u} & \frac{\partial y}{\partial v} \\ \end{bmatrix} = \begin{bmatrix} \frac{1}{3}u^{-\frac{2}{3}} & 0 \\ 0 & \frac{1}{3}v^{-\frac{2}{3}}\\ \end{bmatrix}.\]

The absolute value of the determinant of this matrix is \(\frac{1}{9}\lvert uv\rvert ^{-\frac{2}{3}}\). Therefore,

\[\begin{align*} q(u, v) &= \frac{1}{9}\lvert uv\rvert ^{-\frac{2}{3}} p(x,y) \\ &= \frac{1}{9}\lvert uv\rvert ^{-\frac{2}{3}}p(u^\frac{1}{3}, v^\frac{1}{3}) \\ &= \frac{\lvert uv\rvert ^{-\frac{2}{3}}}{18\pi}e^{-\frac{\lvert u \rvert^\frac{2}{3} + \lvert v \rvert^\frac{2}{3}}{2}}. \end{align*}\]
By applying a cubic transformation to a Gaussian, we obtained a slightly more complex distribution.
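As a rough numerical check (a sketch in NumPy; the integration box, grid resolution and sample count are arbitrary choices), we can compare the probability mass the derived density assigns to a region against a Monte Carlo estimate obtained by cubing standard Gaussian samples.

```python
import numpy as np

rng = np.random.default_rng(0)

def q(u, v):
    """Density of (x^3, y^3) for (x, y) ~ N(0, I), from the change of variables above."""
    return (np.abs(u * v) ** (-2.0 / 3.0) / (18.0 * np.pi)
            * np.exp(-(np.abs(u) ** (2.0 / 3.0) + np.abs(v) ** (2.0 / 3.0)) / 2.0))

# Probability that (u, v) lands in the box [1, 8] x [1, 8], via the derived density ...
us = vs = np.linspace(1.0, 8.0, 400)
du = us[1] - us[0]
U, V = np.meshgrid(us, vs)
prob_density = q(U, V).sum() * du * du

# ... and via Monte Carlo: cube Gaussian samples and count how many land in the box.
x, y = rng.standard_normal((2, 1_000_000))
u_samp, v_samp = x**3, y**3
prob_mc = np.mean((u_samp >= 1) & (u_samp <= 8) & (v_samp >= 1) & (v_samp <= 8))

print(prob_density, prob_mc)   # the two estimates should roughly agree
```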

In the context of normalizing flows, \(p\) is a simple distribution, typically a canonical Gaussian. Our goal with this setup is to learn how to transform samples from \(p\) into samples from a complex data distribution \(q\). We can do this by learning the invertible transformation \(f\). The function \(f\) will involve the use of a neural network with parameters \(\theta\), so from now on we will denote the transformation as \(f_\theta\). Once we have learned \(f_\theta\), we will have access to \(\hat{q}\), which hopefully will be a good approximation of \(q\).

Given that we have learned \(f_\theta\), how do we do density estimation and generate samples from \(q\)? This is quite simple for flow models. If you have a data sample \(\mathbf{x}^{(i)}\), you can compute \(f^{-1}_\theta(\mathbf{x}^{(i)})\) and the determinant of the Jacobian, then plug those into eq. (1) to obtain \(\hat{q}(\mathbf{x}^{(i)})\). If you want to sample from \(q\), first obtain a sample \(\mathbf{x_0} \sim p\), which we know how to do because \(p\) is a simple distribution. Then compute \({\mathbf{x_1} = f_\theta(\mathbf{x_0})}\), so that \(\mathbf{x_1}\) is a sample from \(\hat{q}\).
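To make these two operations concrete, here is a minimal sketch in PyTorch (a toy elementwise affine flow of our own, not any particular published architecture): `log_prob` implements eq. (1) and `sample` pushes base samples through \(f_\theta\).

```python
import torch

class AffineFlow(torch.nn.Module):
    """Toy invertible map f_theta(x) = exp(log_scale) * x + shift (elementwise)."""

    def __init__(self, dim):
        super().__init__()
        self.log_scale = torch.nn.Parameter(torch.zeros(dim))
        self.shift = torch.nn.Parameter(torch.zeros(dim))

    def forward(self, x0):                       # f_theta: base -> data
        return x0 * self.log_scale.exp() + self.shift

    def inverse(self, x1):                       # f_theta^{-1}: data -> base
        return (x1 - self.shift) * (-self.log_scale).exp()

    def log_prob(self, x1):
        """log q_hat(x1) = log p(f^{-1}(x1)) + log |det d f^{-1} / d x1|."""
        x0 = self.inverse(x1)
        log_p0 = torch.distributions.Normal(0.0, 1.0).log_prob(x0).sum(-1)
        log_det = (-self.log_scale).sum()        # diagonal Jacobian => sum of logs
        return log_p0 + log_det

    def sample(self, n):
        x0 = torch.randn(n, self.log_scale.numel())   # x0 ~ p (canonical Gaussian)
        return self.forward(x0)                       # push forward through f_theta

flow = AffineFlow(dim=2)
print(flow.log_prob(torch.randn(4, 2)))   # density estimation on 4 points
print(flow.sample(3))                     # 3 samples from q_hat
```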

Normalizing flow methods impose a specific structure on \(f_\theta\). We want to learn the transformation from \(p\) to \(q\) as a sequence of simpler transformations. Define functions \(f_1, \dots, f_K\) to be invertible and differentiable. Note that these functions are still parameterized by \(\theta\), but we omit making this explicit for the sake of notation. Invertible and differentiable functions are closed under composition. We can use this fact to define \(f_\theta\) in the following manner:

\[f_\theta = f_K \circ f_{K-1} \circ \cdots \circ f_2 \circ f_1.\]

The intuition behind this formulation is somewhat analogous to the justification for stacking many layers in a deep learning model instead of using one wide layer. Learning the transformation from \(p\) to \(q\) in one step might be too difficult. Instead, we can learn a sequence of functions where each function is responsible for transforming its input distribution into a slightly more complex distribution. Eventually, we are able to model the complexity of the data distribution.

Figure 1: Each function transforms an input distribution into a slightly more complex distribution. The overall transformation maps the simple distribution to the complex data distribution.

Let’s reformulate the process of normalizing flows. Since we are performing multiple steps, \(\mathbf{x_1}\) is no longer a sample from \(q\) but a sample from a distribution slightly more complex than \(p_0 = p\). After applying \(K\) transformations we will have that \(\mathbf{x_K} \sim \hat{q}\):

\[\begin{align*} &\phantom{\Rightarrow} \ \ \mathbf{x_0} \sim p_0, \quad \mathbf{x_1} = f_1(\mathbf{x_0}) \\ &\Rightarrow \mathbf{x_1} \sim p_1, \quad \mathbf{x_2} = f_2(\mathbf{x_1}) \\ \phantom{\Rightarrow x_1} &\cdots \\ &\Rightarrow \mathbf{x}_{K-1} \sim p_{K-1}, \quad \mathbf{x}_K = f_K(\mathbf{x}_{K-1}) \\ &\Rightarrow \mathbf{x}_K \sim p_K = \hat{q} \approx q. \end{align*}\]

The sequence of transformations from \(p\) to the distribution \(q\) is called a flow. The term normalizing in normalizing flow refers to the fact that after each transformation is applied, the resulting pdf is still valid, i.e. it is non-negative and integrates to one over its support.

So how do we actually train normalizing flows? The objective is simply to maximize the log-likelihood of the data:

\[\begin{align*} \theta^* &= \arg\max_{\theta} \sum_{i=1}^{N} \log \hat{q}(\mathbf{x}^{(i)}) \\ &= \arg\max_{\theta} \sum_{i=1}^{N} \log\left(p\left(f^{-1}_\theta(\mathbf{x}^{(i)})\right)\left|\det \frac{\partial f^{-1}_\theta}{\partial \mathbf{x}_K}(\mathbf{x}^{(i)})\right|\right) \\ &= \arg\max_{\theta} \sum_{i=1}^{N} \log p\left(f^{-1}_\theta(\mathbf{x}^{(i)})\right) + \log\left|\det \frac{\partial f^{-1}_\theta}{\partial \mathbf{x}_K}(\mathbf{x}^{(i)})\right|. \end{align*}\]

Remember that \(f_\theta\) is actually the composition of a sequence of functions. By the chain rule, we can decompose the determinant of the Jacobian of \(f^{-1}_\theta\) into a product of the determinants of the individual Jacobians. Specifically,

\[\left| \det \frac{\partial f^{-1}_\theta}{\partial \mathbf{x}_K} \right| = \left| \det \prod_{k=1}^K \frac{\partial f^{-1}_k}{\partial \mathbf{x}_k} \right| = \prod_{k=1}^K \left| \det \frac{\partial f^{-1}_k}{\partial \mathbf{x}_k} \right|.\]

Substituting this back into the objective function we obtain:

\[\max_{\theta} \sum_{i=1}^{N} \left[ \log p\left(f^{-1}_\theta(\mathbf{x}^{(i)})\right) + \sum_{k=1}^{K} \log\left|\det \frac{\partial f^{-1}_k}{\partial \mathbf{x}_k} (\mathbf{x}^{(i)}) \right|\right].\]

We can interpret the sum of log-determinants in the objective as each “layer” of the flow receiving its own gradient information about the objective.
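A minimal training sketch (PyTorch; the stack of elementwise affine layers is purely illustrative, not one of the named architectures below): maximizing the objective above is implemented as minimizing the negative log-likelihood, with each layer contributing its own log-determinant term.

```python
import torch

class AffineLayer(torch.nn.Module):
    """One invertible layer f_k(x) = exp(log_scale) * x + shift."""
    def __init__(self, dim):
        super().__init__()
        self.log_scale = torch.nn.Parameter(0.01 * torch.randn(dim))
        self.shift = torch.nn.Parameter(torch.zeros(dim))

    def inverse_and_logdet(self, x):
        # Returns f_k^{-1}(x) and log |det of the Jacobian of f_k^{-1}|.
        return (x - self.shift) * (-self.log_scale).exp(), -self.log_scale.sum()

layers = torch.nn.ModuleList([AffineLayer(2) for _ in range(4)])   # f_theta = f_4 ∘ ... ∘ f_1
opt = torch.optim.Adam(layers.parameters(), lr=1e-2)
data = 3.0 + 0.5 * torch.randn(512, 2)          # stand-in dataset

for step in range(200):
    x, log_det_sum = data, 0.0
    for layer in reversed(layers):              # apply f_K^{-1} first, ..., f_1^{-1} last
        x, log_det = layer.inverse_and_logdet(x)
        log_det_sum = log_det_sum + log_det
    log_p0 = torch.distributions.Normal(0.0, 1.0).log_prob(x).sum(-1)
    loss = -(log_p0 + log_det_sum).mean()       # negative log-likelihood
    opt.zero_grad(); loss.backward(); opt.step()
```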

Research in normalizing flow methods typically consists of constructing transformations that are easily invertible and have simple, computable log-determinants. There are many normalizing flow methods such as RealNVP, NICE and Glow. Residual flows, in particular, provide good context for continuous normalizing flows, which we turn to next.

Continuous Normalizing Flows

In the normalizing flows setup, the transformation from the simple distribution to the data distribution is expressed as a finite composition of functions. We can interpret this as a discrete-time process with \(K\) time steps, where each time step has a corresponding intermediate distribution. But how can we obtain a transformation from \(p\) to \(q\) in continuous time rather than discrete time? Imagine this as taking an infinite composition of functions. We can express this idea using Ordinary Differential Equations (ODEs), the fundamental component of Continuous Normalizing Flows (CNFs).

To gain some intuition for flows and ODEs, consider a two dimensional vector field \(v(x,y)\) that describes the movement of water along a river. For simplicity, assume it’s time-independent. The velocity of the water at point \((x,y)\) is the vector \(v(x,y)\). The path of a pebble thrown into the water at time \(t=0\) is a curve we can parameterize as a function of time:

\[\mathbf{r}(t) = \langle x(t), y(t) \rangle, \qquad \mathbf{r}(0) = \langle x(0), y(0) \rangle.\]

We can solve for the position of the pebble at time \(t\) by making the following observation. At time \(t\), the velocity of the pebble, \(\frac{d\mathbf{r}(t)}{dt}\), is the same as the velocity of the water at the position of the pebble, \(\mathbf{r}(t)\). We can model this with the following ODE:

\[\frac{d\mathbf{r}(t)}{dt} = v(\mathbf{r}(t)) = v(x(t), y(t)), \qquad \mathbf{r}(0) = \langle x(0), y(0) \rangle.\]

This example demonstrates how we can describe the movement of a particle induced by a vector field, given some initial position. Specifically, we can construct a function \(\mathbf{r}(t)\) that describes the path taken by a single particle starting at a specific point in space at \(t=0\). As we will see, a flow in the context of CNFs is a more general object that represents the motion of all particles through time.

Example 3. Rotational vector field. Take the time-independent vector field \(v(x, y) = \langle -y, x \rangle\). The ODE \(\frac{d\mathbf{r}(t)}{dt} = v(\mathbf{r}(t))\) with \(\mathbf{r}(0) = \langle x_0, y_0 \rangle\) has the solution \(\mathbf{r}(t) = \langle x_0\cos t - y_0\sin t,\ x_0\sin t + y_0\cos t\rangle\), so a pebble dropped anywhere in this "river" travels counter-clockwise around the origin along a circle.
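A few lines of NumPy (a rough sketch; the step size is arbitrary) integrate this ODE with forward Euler steps and confirm that the numerical trajectory stays close to the exact circular path.

```python
import numpy as np

def v(r):
    """Time-independent vector field v(x, y) = (-y, x)."""
    x, y = r
    return np.array([-y, x])

r = np.array([1.0, 0.0])      # pebble starts at (1, 0)
dt, steps = 1e-3, 1000        # integrate up to t = 1

for _ in range(steps):        # forward Euler: r <- r + dt * v(r)
    r = r + dt * v(r)

exact = np.array([np.cos(1.0), np.sin(1.0)])   # exact solution at t = 1
print(r, exact, np.linalg.norm(r - exact))     # error shrinks as dt -> 0
```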

Let’s provide a more rigorous definition of a flow. Suppose we have a vector field \(u: \mathbb{R}^d \times [0, 1] \to \mathbb{R}^d\). Unlike the example above, this is a time-dependent vector field and we will denote the time parameter as a subscript, \(u_t(x)\). In this setup, \(d\) is the dimension of our data space.

A flow, which is induced by the vector field \(u\), is a mapping \(\phi: \mathbb{R}^d \times [0,1] \to \mathbb{R}^d\) which satisfies the following ODE:

\[\frac{d\phi_t(x)}{dt} = u_t(\phi_t(x)),\]

with initial condition \(\phi_0(x) = x\).

To gain a better intuition of what \(\phi\) represents, we can compare it to \(\mathbf{r}(t)\). Given some initial point \(\mathbf{x_0}\), \(\mathbf{r}(t)\) is the position of that point at time \(t\) induced by the movement of water. Similarly, when we provide \(\mathbf{x_0}\) as input to \(\phi\), we get the function \(\phi(t, \mathbf{x_0}): [0, 1] \to \mathbb{R}^d\), which is a function of time only. It parameterizes a curve in \(\mathbb{R}^d\) that represents the position of the point \(\mathbf{x_0}\) over time, induced by the vector field \(u_t\). We can also view \(\phi\) from another perspective. Given a specific point in time \(t_0 \in [0,1]\) as input to \(\phi\), we obtain a function \(\phi(t_0, \mathbf{x}): \mathbb{R}^d \to \mathbb{R}^d\). This function maps all points at time \(t=0\) to the positions they would occupy at time \(t=t_0\). Overall, the mapping \(\phi\) describes the movement of all points from time \(t=0\) to time \(t = 1\). For consistent notation, we will denote the time parameter as a subscript, \(\phi_t\).

Another important object in CNFs is the probability density path \({p_t: \mathbb{R}^d \times [0,1] \to \mathbb{R}_{>0}}\). It is a time-dependent probability density function i.e. \(\int p_t(\mathbf{x})d\mathbf{x} = 1\). Similar to normalizing flows, we let \(p_0 = p\) be a simple distribution such as a canonical Gaussian. Then \(p_t\) is defined by a change of variables from \(p_0\) using mapping \(\phi_t\):

\[\begin{equation} p_t(\mathbf{x}) = p_0(\phi_t^{-1}(\mathbf{x}))\left|\det \frac{\partial \phi_t^{-1}}{\partial \mathbf{x}}(\mathbf{x}) \right|. \end{equation}\]

Note: with some regularity conditions on \(u_t\), we can guarantee that \(\phi_t\) is invertible. The details are out of scope for this blog post.

In the setting of CNFs, we let \(p_1\) be the data distribution. The goal is to learn a vector field \(v_t\) which induces a flow \(\phi_t\). This flow is responsible for transforming the simple distribution \(p_0 = p\) at time \(t=0\) into the data distribution \(p_1 = q\) at time \(t=1\).

The training objective is the same as in normalizing flows: we maximize the log-likelihood of the data. Given a data point \(\mathbf{x_1} \in \mathbb{R}^d\), we could compute \(\log p_1(\mathbf{x_1})\) using eq. (2). However, as in normalizing flows, that would require computing the Jacobian determinant, which is an \(O(d^3)\) operation. A benefit of the continuous setting is that there is an alternative method that avoids this computation. The alternative method involves the continuity equation:

\[\frac{\partial}{\partial t}p_t(\mathbf{x}) + \nabla \cdot (p_t(\mathbf{x})u_t(\mathbf{x})) = 0.\]

The continuity equation is a Partial Differential Equation (PDE) where \(\nabla \cdot\) represents the divergence operator. The divergence is computed with respect to the spatial dimensions \(\frac{\partial}{\partial x_i}\). The continuity equation provides a necessary and sufficient condition for a vector field \(u_t\) to generate the probability density path \(p_t\). It can be derived using some basic vector calculus and has a clean physical interpretation. Suppose you have some arbitrary volume \(V\) in \(\mathbb{R}^3\). The key observation is that the probability density \(p_t\) has to integrate to \(1\) over \(\mathbb{R}^3\) by definition. So, analogous to mass, it is a conserved quantity: it cannot disappear or appear out of thin air. Therefore, the change in probability mass inside the volume must equal the difference between the probability mass that has entered the volume and that which has exited it. The idea is the same as if \(u_t\) represented the flow of water and \(p_t\) its mass density: the change in mass inside the volume must be the difference between the mass of water entering and leaving it. The change in probability mass in the volume can be written as follows:

\[\frac{d}{dt}\iiint_V p_t dV.\]

Let \(S\) be the surface (boundary) of the volume and let \(n: \mathbb{R}^3 \to \mathbb{R}^3\) be the outward normal vector to the surface at point \((x,y,z)\). The net rate at which probability mass enters the volume is then:

\[- \iint_S p_t (u_t \cdot n) dS = - \iiint_V \nabla \cdot (p_tu_t) dV.\]

The equality is an application of Gauss’s divergence theorem. Therefore,

\[\frac{d}{dt}\iiint_V p_t dV = - \iiint_V \nabla \cdot (p_tu_t) dV\]

Moving everything to one side and simplifying, we get

\[\iiint_V \frac{\partial}{\partial t}p_t + \nabla \cdot (p_tu_t)\, dV = 0.\]

Since this is true for every volume \(V\) it must be that the quantity inside the integral is equal to \(0\) and we arrive at the continuity equation. This reasoning can be applied in arbitrary dimension \(\mathbb{R}^d\).

Using the continuity equation and the ODE describing the flow \(\phi_t\) we get the instantaneous change of variable equation:

\[\frac{d}{dt}\log p_t(\phi_t(\mathbf{x})) + \nabla \cdot u_t(\phi_t(\mathbf{x})) = 0.\]

Now we have an ODE that describes the change of the log-probability along the flow trajectory. We can use an ODE solver to solve this ODE, which will give us \(\log p_1(\mathbf{x_1})\). However, the downside to this approach is that we have to simulate the flow trajectory and compute a divergence, which can still be expensive. As a result, continuous normalizing flows are not a very scalable method. Flow matching aims to solve this issue.
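To see where the cost comes from, here is a rough sketch (PyTorch, fixed-step Euler integration, exact divergence via autograd; a real implementation would use an adaptive solver, e.g. from `torchdiffeq`, and often a stochastic trace estimator) of evaluating \(\log p_1(\mathbf{x_1})\) by integrating the state and the divergence backwards from \(t=1\) to \(t=0\). The small MLP standing in for the vector field is hypothetical.

```python
import torch

def divergence(u, x, t):
    """Exact divergence of u(., t) at x via autograd (d backward passes; fine for small d)."""
    x = x.detach().requires_grad_(True)
    out = u(x, t)
    div = torch.zeros(x.shape[0])
    for i in range(x.shape[-1]):
        div = div + torch.autograd.grad(out[:, i].sum(), x, retain_graph=True)[0][:, i]
    return div.detach()

def cnf_log_prob(u, x1, base, n_steps=100):
    """Estimate log p_1(x1) by Euler-integrating the state and divergence from t=1 back to t=0."""
    dt = 1.0 / n_steps
    x, int_div = x1.clone(), torch.zeros(x1.shape[0])
    for k in range(n_steps, 0, -1):
        t = torch.full((x1.shape[0], 1), k * dt)
        int_div = int_div + dt * divergence(u, x, t)
        with torch.no_grad():
            x = x - dt * u(x, t)            # step backwards along the flow trajectory
    log_p0 = base.log_prob(x).sum(-1)       # x is (approximately) the preimage of x1 at t=0
    return log_p0 - int_div                 # log p_1(x1) = log p_0(x0) - \int_0^1 div u_t dt

# Hypothetical time-dependent vector field: a small MLP taking (x, t) as input.
u_net = torch.nn.Sequential(torch.nn.Linear(3, 64), torch.nn.Tanh(), torch.nn.Linear(64, 2))
u = lambda x, t: u_net(torch.cat([x, t], dim=-1))
base = torch.distributions.Normal(0.0, 1.0)
print(cnf_log_prob(u, torch.randn(5, 2), base))
```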

Flow Matching

Flow matching builds on the same framework as CNFs but uses a different loss function, mainly because we would like to train with a more scalable objective.

Notice that by the continuity equation, there is a direct correspondence between the probability density path \(p_t\) and the vector field \(u_t\): a given vector field generates a unique probability density path. Therefore, instead of directly optimizing the probability density path and computing \(\log p_1(\mathbf{x_1})\), we can optimize the vector field instead. The flow matching loss looks like this:

\[\mathcal{L}_{FM}(\theta) = \mathbb{E}_{t,p_t(\mathbf{x})}\left\lVert v_t(\mathbf{x}) - u_t(\mathbf{x})\right\rVert^2,\]

where \(v_t(\mathbf{x})\) is a learnable vector field parameterized by \(\theta\). We let \(p_t\) be a probability path such that \(p_0 = p\) is a simple distribution, e.g. a canonical Gaussian, and \(p_1\) is approximately the data distribution \(q\). We regress the learnable vector field \(v_t\) onto the true vector field \(u_t\) that generates the probability density path \(p_t\). However, in practice we cannot compute this loss because we don't know \(p_t\) or \(u_t\); if we did, there would be no point in learning \(v_t\). To overcome this obstacle, we construct another loss that is computable. This is the conditional flow matching loss:

\[\mathcal{L}_{CFM}(\theta) = \mathbb{E}_{t,q(\mathbf{x_1}), p_t(\mathbf{x}\vert\mathbf{x_1})}\left\lVert v_t(\mathbf{x}) - u_t(\mathbf{x}\vert\mathbf{x_1})\right\rVert^2.\]

We can prove that \(\mathcal{L}_{FM}\) and \(\mathcal{L}_{CFM}\) are equal up to a constant that does not depend on \(\theta\), so \(\nabla_\theta \mathcal{L}_{FM} = \nabla_\theta \mathcal{L}_{CFM}\) and the two losses have the same optima. Thus the conditional flow matching loss is a reasonable replacement. Now, let's describe all of the objects appearing in it.

The basic idea is that we can construct the marginal probability path from conditional probability paths, each conditioned on a data sample. Suppose we have a particular sample \(\mathbf{x_1}\) from the data distribution. We design the conditional probability path \(p_t(\mathbf{x} \vert \mathbf{x_1})\) so that the resulting marginal path satisfies \(p_0 = p\) and \(p_1 \approx q\). To satisfy these boundary conditions, we can set \(p_0(\mathbf{x}\vert\mathbf{x_1}) = p(\mathbf{x})\) and \(p_1(\mathbf{x} \vert \mathbf{x_1}) = \delta_{\mathbf{x_1}}\), where \(\delta_{\mathbf{x_1}}\) is the Dirac distribution centered at \(\mathbf{x_1}\). To obtain the marginal probability path, we marginalize over the data distribution:

\[p_t(x) = \int p_t(x|\mathbf{x_1})q(\mathbf{x_1})d\mathbf{x_1}.\]

We can see that setting \(p_0(\mathbf{x} \vert\mathbf{x_1}) = p(\mathbf{x})\) and \(p_1(\mathbf{x} \vert\mathbf{x_1}) = \delta_{\mathbf{x_1}}\) results in \(p_0(\mathbf{x}) = p(\mathbf{x})\) and \(p_1(\mathbf{x}) = q(\mathbf{x})\). For numerical reasons, instead of using the actual Dirac distribution, we approximate it by setting \(p_1(\mathbf{x}\vert\mathbf{x_1}) = \mathcal{N}(\mathbf{x_1}, \sigma^2_{min}\mathbf{I})\) with \(\sigma_{min}\) sufficiently small.

For each conditional probability path, there exists a conditional vector field, denoted \(u_t(\mathbf{x}\vert\mathbf{x_1})\), which generates it. We can express the marginal vector field as

\[\begin{align*} u_t(x) = \int u_t(x \vert \mathbf{x_1})\frac{p_t(x \vert \mathbf{x_1})q(\mathbf{x_1})}{p_t(x)}d\mathbf{x_1}. \end{align*}\]

To prove that the marginal vector field \(u_t\) generates the probability path \(p_t\), one can check that the pair satisfies the continuity equation. The conditional flow \(\phi_t(\mathbf{x}\vert \mathbf{x_1})\) satisfies the following ODE based on the conditional vector field:

\[\begin{equation} \frac{d}{dt}\phi_t(\mathbf{x} \vert \mathbf{x_1}) = u_t\left(\phi_t(\mathbf{x} \vert \mathbf{x_1}) \vert \mathbf{x_1}\right), \end{equation}\]

with initial condition \(\phi_0(\mathbf{x} \vert \mathbf{x_1}) = \mathbf{x}\). Therefore, we can integrate the conditional vector field to obtain the conditional flow. However, the flow matching loss requires computing the marginal vector field \(u_t\), and even if we define \(u_t\) in terms of conditional vector fields, computing the marginalization integral is still intractable. So we do need to use the conditional flow matching loss.

Returning to the conditional flow matching loss, the main idea is that we take an expectation over the data distribution and the conditional probability path. As a result, we can replace the marginal vector field in the flow matching loss with the conditional vector field. In practice, we sample a point from the dataset and then sample from the conditional probability path instead of the marginal probability path. Of course, computing the loss also involves sampling a time \(t \in [0,1]\).

Furthermore, we can make the following observation to reparameterize the conditional flow matching loss. The reparameterization avoids having to sample \(\mathbf{x} \sim p_t(\mathbf{x} \vert \mathbf{x_1})\). Instead, we can sample \(\mathbf{x_0} \sim p\) from the simple distribution. Then \(\mathbf{x_t} = \phi_t(\mathbf{x_0} \vert \mathbf{x_1})\) is a sample from \(p_t(\mathbf{x} \vert \mathbf{x_1})\) since the conditional flow is a transformation from \(p\) to \(p_t(\mathbf{x} \vert \mathbf{x_1})\). Therefore, \(\mathbf{x_t}\) is the solution to the ODE in equation (3) with \(\mathbf{x_0}\) substituted into the flow:

\[\frac{d\phi_t(\mathbf{x_0}|\mathbf{x_1})}{dt} = u_t(\phi_t(\mathbf{x_0}|\mathbf{x_1}) \vert \mathbf{x_1}),\]

with initial condition \(\phi_0(\mathbf{x_0}\vert\mathbf{x_1}) = \mathbf{x_0}\). Therefore, we can rewrite the conditional flow matching objective as:

\[\begin{align} \mathcal{L}_{CFM}(\theta) &= \mathbb{E}_{t, q(\mathbf{x_1}), p(\mathbf{x_0})}\left\lVert v_t(\phi_t(\mathbf{x_0}|\mathbf{x_1})) - u_t(\phi_t(\mathbf{x_0}|\mathbf{x_1}) | \mathbf{x_1})\right\rVert^2 \notag \\ &= \mathbb{E}_{t, q(\mathbf{x_1}), p(\mathbf{x_0})}\left\lVert v_t(\phi_t(\mathbf{x_0}|\mathbf{x_1})) - \frac{d\phi_t(\mathbf{x_0}|\mathbf{x_1})}{dt}\right\rVert^2. \end{align}\]

To summarize, we have a way of training CNFs by using conditional probability paths and flows. The conditional flow matching loss has the same optima as the flow matching loss and doesn't require access to the marginal probability path or vector field. We can compute the conditional flow matching loss efficiently as long as \(p_t(\mathbf{x}\vert\mathbf{x_1})\) is defined and can be sampled from efficiently. Furthermore, we are able to easily compute \(u_t(\mathbf{x}\vert\mathbf{x_1})\) because it is defined on a per-sample basis.

Now that we have covered the conditional flow matching framework, we have to choose how to define \(p_t(\mathbf{x} \vert \mathbf{x_1})\) and \(\phi_t(\mathbf{x} \vert \mathbf{x_1})\). The definitions of these objects are motivated primarily by simplicity and tractability. Flow matching was introduced as a vector field \(u_t\) inducing a flow \(\phi_t\) that results in a probability density path \(p_t\). Although this is the natural way to understand the framework, we are going to define these objects in the opposite order, and everything still works out.

We start off by defining the conditional probability path. A natural and simple choice for this is a Gaussian distribution,

\[p_t(\mathbf{x}\vert\mathbf{x_1}) = \mathcal{N}(\mu_t(\mathbf{x_1}), \sigma^2_t(\mathbf{x_1})\mathbf{I}),\]

where \(\mu: \mathbb{R}^d \times [0,1] \to \mathbb{R}^d\) and \(\sigma: \mathbb{R}^d \times [0,1] \to \mathbb{R}_{>0}\) are the time-dependent mean and standard deviation of the Gaussian. We choose \(p\) (the simple distribution) to be the canonical Gaussian. In order to satisfy the boundary conditions, we must have \(\mu_0(\mathbf{x_1}) = 0\), \(\sigma_0(\mathbf{x_1}) = 1\), \(\mu_1(\mathbf{x_1}) = \mathbf{x_1}\) and \(\sigma_1(\mathbf{x_1}) = \sigma_{min}\).

The simplest conditional flow that generates \(p_t(\mathbf{x} \vert \mathbf{x_1})\), given that \(p\) is a canonical Gaussian, is the following:

\[\phi_t(\mathbf{x} \vert\mathbf{x_1}) = \sigma_t(\mathbf{x_1})\mathbf{x} + \mu_t(\mathbf{x_1}),\]

where \(\mathbf{x} \sim p\). Indeed, Example 1 shows that this affine map transforms the canonical Gaussian into a Gaussian with mean \(\mu_t(\mathbf{x_1})\) and variance \(\sigma^2_t(\mathbf{x_1})\).

The conditional vector field that generates this flow is given by the following:

\[u_t(\mathbf{x}\vert\mathbf{x_1}) = \frac{\sigma_t'(\mathbf{x_1})}{\sigma_t(\mathbf{x_1})}(\mathbf{x} - \mu_t(\mathbf{x_1})) + \mu'_t(\mathbf{x_1}).\]

In this setup, \(\mu_t\) is an arbitrary function that we can choose. Essentially, this allows us to select any path from \(0\) to \(\mathbf{x_1}\). A natural choice is a straight line, which corresponds to the optimal transport solution.

Optimal Transport

The optimal transport solution is the path that requires the least amount of work to transform the canonical Gaussian into the Gaussian with mean \(\mu_t\) and standard deviation \(\sigma_t\). Specifically, the mean and standard deviation change linearly with time:

\[\mu_t(\mathbf{x_1}) = t\mathbf{x_1}, \quad \text{and} \quad \sigma_t(\mathbf{x_1}) = 1 - (1 - \sigma_{min})t.\]

This straight line path is generated by the vector field:

\[u_t(\mathbf{x} \vert \mathbf{x_1}) = \frac{\mathbf{x_1} - (1 - \sigma_{min})\mathbf{x}}{1 - (1 - \sigma_{min})t}.\]

By substituting \(\mu_t\) and \(\sigma_t\), we get that the conditional flow in the optimal transport case is:

\[\phi_t(\mathbf{x}|\mathbf{x_1}) = (1- (1 - \sigma_{min})t)\mathbf{x} + t\mathbf{x_1}.\]

Therefore, the reparameterized conditional flow matching loss is the following,

\[\mathbb{E}_{t, q(\mathbf{x_1}), p(\mathbf{x_0})}\left\lVert v_t(\phi_t(\mathbf{x_0}|\mathbf{x_1})) - (\mathbf{x_1} - (1 - \sigma_{min})\mathbf{x_0})\right\rVert^2.\]
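Putting the pieces together, a minimal training sketch might look like the following (PyTorch; `v_net` is a hypothetical small MLP standing in for whatever architecture is actually used, and the two-dimensional dataset is a stand-in).

```python
import torch

sigma_min = 1e-4
v_net = torch.nn.Sequential(torch.nn.Linear(3, 128), torch.nn.SiLU(), torch.nn.Linear(128, 2))
opt = torch.optim.Adam(v_net.parameters(), lr=1e-3)

def cfm_step(x1):
    """One optimal-transport conditional flow matching step on a batch of data x1."""
    x0 = torch.randn_like(x1)                           # x0 ~ p (canonical Gaussian)
    t = torch.rand(x1.shape[0], 1)                      # t ~ U[0, 1]
    xt = (1 - (1 - sigma_min) * t) * x0 + t * x1        # x_t = phi_t(x0 | x1)
    target = x1 - (1 - sigma_min) * x0                  # d phi_t / dt for the OT path
    v = v_net(torch.cat([xt, t], dim=-1))               # learnable vector field v_t(x_t)
    loss = ((v - target) ** 2).sum(-1).mean()           # conditional flow matching loss
    opt.zero_grad(); loss.backward(); opt.step()
    return loss.item()

data = 3.0 + 0.5 * torch.randn(4096, 2)                 # stand-in two-dimensional dataset
for step in range(1000):
    batch = data[torch.randint(0, data.shape[0], (256,))]
    cfm_step(batch)
```

Note that, unlike the CNF likelihood computation above, no ODE needs to be simulated during training; each step only requires sampling \(t\), \(\mathbf{x_0}\), \(\mathbf{x_1}\) and evaluating the network once.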

The conditional flow is the optimal transport displacement map between the two Gaussians. However, the fact that the conditional flow is optimal does not imply that the marginal vector field is optimal.
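After training, generating samples only requires integrating the learned ODE \(\frac{d\mathbf{x}}{dt} = v_t(\mathbf{x})\) forward in time. A crude Euler sketch (reusing the hypothetical `v_net` above; a higher-order or adaptive solver would typically be used) is enough to illustrate the idea.

```python
import torch

@torch.no_grad()
def sample(n, n_steps=100):
    """Draw samples by integrating dx/dt = v_t(x) from t = 0 to t = 1 with Euler steps."""
    x = torch.randn(n, 2)                               # x0 ~ p
    dt = 1.0 / n_steps
    for k in range(n_steps):
        t = torch.full((n, 1), k * dt)
        x = x + dt * v_net(torch.cat([x, t], dim=-1))   # Euler step along the learned field
    return x                                            # approximately distributed as p1 ~ q

print(sample(5))
```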

Riemannian Flow Matching (RFM)

In the previous section, we discussed how to do flow matching in \(\mathbb{R}^d\). Another interesting question is how do we do flow matching on non-Euclidean geometries? This is relevant if you already know that your data lies on a manifold.

Figure 2: Consider a simple case where your data lies on a simple manifold in \(\mathbb{R}^2\): the circle. On the left-hand side, you could use flow matching in Euclidean space to try to model this data. But it can be beneficial to build in as much prior knowledge about the data as possible, so performing flow matching on the manifold domain, the circle shown on the right, may lead to better performance.

There are many real-world applications where we would want to model data that resides on a manifold. Examples include protein modelling, molecule modelling, robotics, medical imaging and geological sciences.

In this section, we introduce Riemannian flow matching, a generalization of flow matching. Specifically, we consider complete, connected and smooth Riemannian manifolds \(\mathcal{M}\) endowed with a metric \(g\). Formally, we have a set of data samples \(\{x_i\}_{i=1}^N\) with \(x_i \in \mathcal{M}\) that arise from a probability distribution \(q\) on \(\mathcal{M}\). We aim to learn a flow that transforms a simple noise distribution \(p\) on \(\mathcal{M}\) into the data distribution.

The tangent space at \(x \in \mathcal{M}\) is denoted \(T_x\mathcal{M}\). The metric \(g\) induces several key quantities: an inner product over \(T_x\mathcal{M}\), denoted \(\langle u,v \rangle _g\); the exponential map \(\exp_x: T_x\mathcal{M} \to \mathcal{M}\); and extensions of the gradient, divergence and Laplacian. For \(x \in \mathcal{M}\), \(\text{div}_g\) denotes the divergence with respect to the spatial (\(x\)) argument. The integral of a function \(f: \mathcal{M} \to \mathbb{R}\) is denoted \(\int f(x)\, d\text{vol}_x\).

Fortunately, not many changes are required to make flow matching work on manifolds. The objects used in RFM are the same as in FM. The space of probability densities over \(\mathcal{M}\) is denoted \(\mathcal{P}\). We have a probability path \(p_t\), i.e. a time-dependent density on \(\mathcal{M}\) with \(\int p_t(x)\, d\text{vol}_x = 1\). The time-dependent vector field is \(u_t: \mathcal{M} \times [0,1] \to T\mathcal{M}\). The flow \(\phi_t: \mathcal{M} \times [0,1] \to \mathcal{M}\) satisfies the following ODE defined on \(\mathcal{M}\):

\[\frac{d\phi_t(\mathbf{x})}{dt} = u_t(\phi_t(\mathbf{x})),\]

with initial condition \(\phi_0(\mathbf{x}) = \mathbf{x}\). The vector field \(u_t\) and probability path \(p_t\) also satisfy the continuity equation on manifolds:

\[\frac{\partial p_t(\mathbf{x})}{\partial t} + \text{div}_g (p_t(\mathbf{x})u_t(\mathbf{x})) = 0.\]

The vector field \(u_t\) generates the probability path \(p_t\) such that \(p_0 = p\) is the simple distribution and \(p_1 = q\) is the data distribution. The Riemannian flow matching objective is almost the same except we use \(g\) as the metric for the norm:

\[\mathcal{L}_{RFM}(\theta) = \mathbb{E}_{t, p_t(\mathbf{x})} \left\lVert v_t(\mathbf{x}) - u_t(\mathbf{x})\right\rVert^2_g.\]

Again, \(v_t\) is a learnable time-dependent vector field parameterized by \(\theta\). However, as before we don’t know the probability path \(p_t\) nor the vector field that generates this probability path. Since we cannot compute this loss, we use the Riemannian conditional flow matching loss instead.

We condition on data samples to construct the conditional probability path and conditional vector field. Given \(\mathbf{x_1} \sim q\), we define the conditional path \(p_t(\mathbf{x}\vert\mathbf{x_1})\) to satisfy the boundary conditions. Note that we are keeping this general and not specifying the form of the conditional distribution; it does not have to be Gaussian as in Euclidean flow matching. We can write the marginal probability path as

\[p_t(\mathbf{x}) = \int_{\mathcal{M}} p_t(\mathbf{x}\vert\mathbf{x_1})q(\mathbf{x_1})d\text{vol}_{\mathbf{x_1}}.\]

We define the conditional vector field \(u_t(\mathbf{x}\vert\mathbf{x_1})\) that generates this probability path. The marginal vector field can be obtained in a similar fashion as before:

\[u_t(x) = \int_{\mathcal{M}} u_t(x|\mathbf{x_1}) \frac{p_t(x|\mathbf{x_1})q(\mathbf{x_1})}{p_t(x)} d\text{vol}_{\mathbf{x_1}}.\]

Once again computing this integral is intractable which motivates us to define the Riemannian conditional flow matching loss:

\[\mathcal{L}_{RCFM}(\theta) = \mathbb{E}_{t, q(\mathbf{x_1}), p_t(\mathbf{x}\vert\mathbf{x_1})} \left\lVert v_t(\mathbf{x}) - u_t(\mathbf{x}\vert\mathbf{x_1})\right\rVert^2_g.\]

We can reparameterize the loss as follows:

\[\mathcal{L}_{RCFM}(\theta) = \mathbb{E}_{t, q(\mathbf{x_1}), p(\mathbf{x_0})} \left\lVert v_t(\phi_t(\mathbf{x_0}\vert\mathbf{x_1})) - u_t(\phi_t(\mathbf{x_0} \vert \mathbf{x_1})\vert\mathbf{x_1})\right\rVert^2_g.\]

Now we need a way to construct the conditional flow. The conditional flow will map all points to \(\mathbf{x_1}\) at time \(t=1\) regardless of the choice of \(p\). So the flow satisfies:

\[\phi_1(\mathbf{x}\vert\mathbf{x_1}) = \mathbf{x_1}, \quad \forall \mathbf{x} \in \mathcal{M}.\]

Also, in the same manner in which we reparameterized the loss function, we can sample \(\mathbf{x_0} \sim p\) and then compute \(\mathbf{x_t} = \phi_t(\mathbf{x_0} \vert \mathbf{x_1})\). Now, in order to construct the conditional flow, we consider two different cases. The first case is when we are on a simple manifold, i.e. we have a closed form for the geodesics. Let \(d_g(\mathbf{x}, \mathbf{y})\) denote the geodesic distance between two points on the manifold, and let \(\kappa(t)\) be a monotonically decreasing function such that \(\kappa(0) = 1\) and \(\kappa(1) = 0\). We want to find a conditional flow \(\phi_t(\mathbf{x} \vert \mathbf{x_1})\) that satisfies the following equation according to the scheduler \(\kappa\):

\[d_g(\phi_t(\mathbf{x_0} \vert \mathbf{x_1}), \mathbf{x_1}) = \kappa(t)d_g(\mathbf{x_0}, \mathbf{x_1}).\]

This will guarantee that \(\phi_1(\mathbf{x} \vert \mathbf{x_1}) = \mathbf{x_1}\). A simple choice for the scheduler is \(\kappa(t) = 1 - t\). In fact, the conditional flow \(\phi_t(\mathbf{x_0} \vert \mathbf{x_1})\) traces the geodesic connecting \(\mathbf{x_0}\) and \(\mathbf{x_1}\). Additionally, the geodesic can be expressed as

\[\phi_t(\mathbf{x_0} \vert \mathbf{x_1}) = \exp_{\mathbf{x_1}}(\kappa(t)\log_{\mathbf{x_1}}(\mathbf{x_0})),\]

which is simple to compute and results in a highly scalable training objective. This conditional flow can be thought of as the analogue of interpolating between \(\mathbf{x_0}\) and \(\mathbf{x_1}\) in Euclidean space:

\[(1 - \kappa(t))\mathbf{x_1} + \kappa(t)\mathbf{x_0}.\]
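On a manifold where the exponential and logarithm maps are available in closed form, this conditional flow is cheap to evaluate. The sketch below (NumPy, the unit sphere \(S^2\) embedded in \(\mathbb{R}^3\), the standard sphere exp/log maps, and \(\kappa(t) = 1 - t\); all helper names are ours) computes \(\phi_t(\mathbf{x_0} \vert \mathbf{x_1}) = \exp_{\mathbf{x_1}}(\kappa(t)\log_{\mathbf{x_1}}(\mathbf{x_0}))\).

```python
import numpy as np

def exp_map(x, v):
    """Exponential map on the unit sphere: move from x along the tangent vector v."""
    norm_v = np.linalg.norm(v)
    if norm_v < 1e-12:
        return x
    return np.cos(norm_v) * x + np.sin(norm_v) * v / norm_v

def log_map(x, y):
    """Logarithm map on the unit sphere: tangent vector at x pointing towards y."""
    theta = np.arccos(np.clip(np.dot(x, y), -1.0, 1.0))   # geodesic distance d_g(x, y)
    w = y - np.cos(theta) * x
    return theta * w / np.linalg.norm(w)

def conditional_flow(x0, x1, t, kappa=lambda t: 1.0 - t):
    """phi_t(x0 | x1) = exp_{x1}(kappa(t) * log_{x1}(x0)) on the sphere."""
    return exp_map(x1, kappa(t) * log_map(x1, x0))

x0 = np.array([1.0, 0.0, 0.0])
x1 = np.array([0.0, 1.0, 0.0])
for t in [0.0, 0.5, 1.0]:
    print(t, conditional_flow(x0, x1, t))   # moves from x0 at t=0 to x1 at t=1 along the geodesic
```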

When we are not on a simple manifold and don't have access to geodesics in closed form, we have to work with a pre-metric. A pre-metric is a function \(d: \mathcal{M} \times \mathcal{M} \to \mathbb{R}\) which is non-negative (\(d(\mathbf{x}, \mathbf{y}) \geq 0\)), positive (\(d(\mathbf{x}, \mathbf{y}) = 0\) if and only if \(\mathbf{x} = \mathbf{y}\)), and non-degenerate (\(\nabla_x d(\mathbf{x}, \mathbf{y}) \neq 0\) whenever \(\mathbf{x} \neq \mathbf{y}\)).

Note that the geodesic distance satisfies the definition of a pre-metric. We then want a flow \(\phi_t(\mathbf{x_0} \vert \mathbf{x_1})\) to satisfy

\[d(\phi_t(\mathbf{x_0} \vert \mathbf{x_1}), \mathbf{x_1}) = \kappa(t)d(\mathbf{x_0}, \mathbf{x_1}).\]

Once again, this will guarantee that \(\phi_1(\mathbf{x_0} \vert \mathbf{x_1}) = \mathbf{x_1}\). Furthermore, the conditional vector field that generates this flow can be shown to be:

\[u_t(\mathbf{x} \vert \mathbf{x_1}) = \frac{d \log \kappa(t)}{dt} d(\mathbf{x}, \mathbf{x_1})\frac{\nabla_x d(\mathbf{x}, \mathbf{x_1})}{\lVert \nabla_x d(\mathbf{x}, \mathbf{x_1}) \rVert _g^2}.\]

Although this formula seems complicated, the basic component is the gradient of the distance, \(\nabla_x d(\mathbf{x}, \mathbf{x_1})\). This ensures we are moving in the direction of \(\mathbf{x_1}\). The other terms control the speed and make sure that the flow reaches \(\mathbf{x_1}\) at time \(t=1\).
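When the premetric is available as a differentiable function, the formula above can be evaluated directly with autograd. The sketch below (PyTorch; the squared Euclidean distance is used purely as a stand-in premetric, and \(\kappa(t) = 1 - t\)) computes \(u_t(\mathbf{x} \vert \mathbf{x_1})\).

```python
import torch

def premetric(x, x1):
    """Stand-in premetric: squared Euclidean distance (a real use case would plug in
    a geodesic or spectral distance on the manifold)."""
    return ((x - x1) ** 2).sum(-1)

def conditional_vector_field(x, x1, t, kappa=lambda t: 1.0 - t):
    """u_t(x | x1) = d/dt[log kappa(t)] * d(x, x1) * grad_x d / ||grad_x d||^2."""
    x = x.detach().requires_grad_(True)
    d = premetric(x, x1)
    (grad_d,) = torch.autograd.grad(d.sum(), x)
    t = torch.as_tensor(t, dtype=torch.float32).requires_grad_(True)
    (dlog_kappa_dt,) = torch.autograd.grad(torch.log(kappa(t)), t)
    return dlog_kappa_dt * d.unsqueeze(-1) * grad_d / (grad_d ** 2).sum(-1, keepdim=True)

x1 = torch.tensor([[1.0, 1.0]])
x = torch.tensor([[0.0, 0.0]])
print(conditional_vector_field(x, x1, t=0.5))   # points from x towards x1
```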

If we don't have access to the geodesic, then there is no simple interpolation-like formula for computing \(\mathbf{x_t}\). Therefore, we must use an ODE solver to simulate the flow and obtain \(\mathbf{x_t}\), which may be computationally expensive.

An example of a pre-metric is the spectral distance:

\[d_w(\mathbf{x},\mathbf{y})^2 = \sum_{i=1}^{\infty} w(\lambda_i) (\varphi_i(\mathbf{x}) - \varphi_i(\mathbf{y}))^2,\]

where \(\varphi_i: \mathcal{M} \to \mathbb{R}\) are the eigenfunctions of the Laplace-Beltrami operator \(\Delta_g\) over \(\mathcal{M}\) with eigenvalues \(\lambda_i\), i.e. \(\Delta_g \varphi_i = \lambda_i \varphi_i\), and \(w: \mathbb{R} \to \mathbb{R}_{>0}\) is a monotonically decreasing weighting function. Using a spectral distance can be more beneficial than using the geodesic distance because spectral distances are more robust to topological noise such as holes and shortcuts and are more geometry-aware. An example of a spectral distance is the biharmonic distance, which is helpful for avoiding the boundaries of manifolds, as shown in the following figure.