Unlocking the Power of Diffusion Models: A Dive into Relative Trajectory Balance

In the ever-evolving landscape of artificial intelligence, diffusion models have emerged as a powerful tool for generating high-quality data across domains such as images, text, and even action spaces. But what happens when we need these models not just to generate, but to adapt to specific constraints? We tackle this very problem in our paper “Amortizing Intractable Inference in Diffusion Models for Vision, Language, and Control”, where we propose Relative Trajectory Balance (RTB), an asymptotically unbiased objective for training diffusion models to sample from posterior distributions under a diffusion model prior.

Tackling the Intractable: Why Posterior Inference Matters

Diffusion models are excellent at generating data by transforming noise through a series of denoising steps. However, in real-world applications we often need them to conform to certain constraints: think of generating an image with specific features or producing text that fits a given context. This requirement leads to the problem of posterior inference, where we want samples that remain plausible under the original prior distribution while also satisfying the additional constraints. The catch? This posterior inference is typically intractable, because normalizing the product of the prior and the constraint requires summing or integrating over every possible output.

Relative Trajectory Balance (RTB)

In the paper, we introduce RTB, a novel framework designed to tackle the challenge of intractable posterior inference. But what exactly is RTB, and why is it a game-changer?

Mathematical Formulation of RTB

At its core, RTB is about ensuring that we can finetune a diffusion model prior into a posterior model, and sample according to the posterior distribution of interest. Ideally, the posterior model should not drift too far from the prior (since we want the samples to be still plausible with respect to the original prior distribution), but it should nonetheless change appropriately to capture the posterior of interest.

When tackling the posterior sampling problem, the goal is to sample from a posterior distribution p_{\text{post}}(\mathbf{x}) \propto p_\theta(\mathbf{x}) r(\mathbf{x}), where p_\theta(\mathbf{x}) is our diffusion model prior and r(\mathbf{x}) is the constraint. Importantly, r(\mathbf{x}) does not need to be differentiable: it can be the likelihood of a measurement, a reward from another model, a human preference score, or really any nonnegative function we can evaluate.
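To make the definition concrete, here is a minimal Python sketch in which every name is hypothetical: a toy prior over binary vectors with independent bits stands in for p_\theta, and r is a non-differentiable indicator. Because the state space is tiny, we can compute the normalizing constant Z = \sum_{\mathbf{x}} p_\theta(\mathbf{x}) r(\mathbf{x}) exactly by enumeration; that same enumeration is what becomes infeasible at realistic sizes, which is why the posterior is intractable.

```python
import itertools

P_ONE = 0.3  # hypothetical: each bit is 1 with probability 0.3 under the toy prior


def prior_prob(x):
    """Exact toy prior p_theta(x) for a tuple of independent 0/1 bits."""
    k = sum(x)
    return P_ONE ** k * (1 - P_ONE) ** (len(x) - k)


def reward(x):
    """Non-differentiable constraint r(x): 1 if at least half of the bits are set, else 0."""
    return 1.0 if 2 * sum(x) >= len(x) else 0.0


def exact_posterior(d):
    """Normalize p_theta(x) r(x) by enumerating all 2**d states.

    Returns the posterior table, the normalizing constant Z, and the number
    of terms summed, which doubles with every extra dimension.
    """
    states = list(itertools.product((0, 1), repeat=d))
    unnorm = {x: prior_prob(x) * reward(x) for x in states}
    z = sum(unnorm.values())  # Z = sum_x p_theta(x) r(x)
    return {x: w / z for x, w in unnorm.items()}, z, len(states)


posterior, z, n_terms = exact_posterior(4)
print(n_terms, round(z, 4))  # prints: 16 0.3483
```

With d = 50 the same loop would need more than 10^15 terms, and diffusion models over images live in far larger (and continuous) spaces, where the sum becomes an integral.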

Our RTB objective introduces a constraint that ensures the correct relationship between the prior and the posterior processes, mathematically expressed as:

\mathcal{L}_{RTB}(\mathbf{x}_0 \rightarrow \mathbf{x}_{\Delta t} \rightarrow \dots \rightarrow \mathbf{x}_1; \phi) := \left(\log \frac{Z_{\phi} \cdot p_\phi^{post}(\mathbf{x}_0, \mathbf{x}_{\Delta t}, \dots, \mathbf{x}_1)}{r(\mathbf{x}_1)\, p_{\theta}(\mathbf{x}_0, \mathbf{x}_{\Delta t}, \dots, \mathbf{x}_1)}\right)^2
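Both trajectory likelihoods factor into per-step transition terms, so the loss can be computed from sums of log-probabilities. The sketch below assumes that both processes start from the same noise distribution (so the density of \mathbf{x}_0 cancels in the ratio) and that the per-step log-probabilities have already been evaluated; the function `rtb_loss` and its arguments are hypothetical illustrations, not the API of the released code.

```python
import math
from typing import Sequence


def rtb_loss(log_z: float,
             post_step_logps: Sequence[float],
             prior_step_logps: Sequence[float],
             log_reward: float) -> float:
    """Squared RTB residual for one denoising trajectory x_0 -> ... -> x_1.

    post_step_logps[i]  : log p_phi^post(x_{t+dt} | x_t) at step i
    prior_step_logps[i] : log p_theta(x_{t+dt} | x_t) at the same step
    log_reward          : log r(x_1); r must be strictly positive here
    log_z               : the learned scalar log Z_phi

    Assumes both processes share the starting noise distribution, so the
    log-density of x_0 appears in numerator and denominator and cancels.
    """
    if len(post_step_logps) != len(prior_step_logps):
        raise ValueError("posterior and prior trajectories must have the same length")
    residual = (log_z + math.fsum(post_step_logps)
                - log_reward - math.fsum(prior_step_logps))
    return residual ** 2


print(rtb_loss(0.5, [-1.0, -2.0], [-1.25, -1.5], -0.25))  # prints 0.25
```

In an actual training loop these quantities would be tensors in an autodiff framework, so that gradients of the averaged loss flow to \phi and to \log Z_{\phi}; plain floats keep the sketch self-contained.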

When this loss is zero on every trajectory, Z_{\phi} \cdot p_\phi^{post}(\mathbf{x}_0, \mathbf{x}_{\Delta t}, \dots, \mathbf{x}_1) = r(\mathbf{x}_1) p_{\theta}(\mathbf{x}_0, \mathbf{x}_{\Delta t}, \dots, \mathbf{x}_1): the posterior model assigns each denoising trajectory its likelihood under the prior, reweighted by the constraint on the final sample, while the learned scalar Z_{\phi} (parameterized in log scale for numerical stability) absorbs the normalizing constant. As a result, the posterior model samples \mathbf{x}_1 proportionally to p_\theta(\mathbf{x}_1) r(\mathbf{x}_1). This is the key advantage of RTB: at its optimum it samples from the target posterior without bias (proof in the paper).
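Why does a zero loss give the posterior? For a trajectory \tau = (\mathbf{x}_0, \dots, \mathbf{x}_1), summing both sides of Z_{\phi} \cdot p_\phi^{post}(\tau) = r(\mathbf{x}_1) p_\theta(\tau) over all intermediate states gives Z_{\phi} \cdot p_\phi^{post}(\mathbf{x}_1) = r(\mathbf{x}_1) p_\theta(\mathbf{x}_1), and summing over \mathbf{x}_1 as well shows that Z_{\phi} equals the true normalizing constant. The toy check below illustrates this numerically on a hypothetical two-step discrete chain with a fixed start state. Instead of training, it builds the exact posterior transitions with a Doob h-transform (a construction used only for this illustration; RTB itself learns the transitions by minimizing the loss) and confirms that every RTB residual vanishes and that the resulting marginal over \mathbf{x}_1 matches p_\theta(\mathbf{x}_1) r(\mathbf{x}_1) / Z.

```python
import math

# Hypothetical two-step discrete "denoising" chain with a fixed start state x_0 = 0:
#   x_0 -> x_mid in {0, 1} -> x_1 in {0, 1, 2}
# PRIOR[t][s][n] is the prior transition probability p_theta(n | s) at step t.
PRIOR = [
    {0: {0: 0.6, 1: 0.4}},
    {0: {0: 0.5, 1: 0.3, 2: 0.2},
     1: {0: 0.1, 1: 0.3, 2: 0.6}},
]
REWARD = {0: 0.1, 1: 1.0, 2: 3.0}  # r(x_1); strictly positive so log r is finite


def h_transform(prior, reward):
    """Build exact posterior transitions q(n | s) = p(n | s) h(n) / h(s).

    h(s) = E_prior[r(x_1) | current state s] is computed backwards, starting
    from h = r at the final step. Returns the posterior tables and h(x_0) = Z.
    """
    h_next = dict(reward)
    post = [None] * len(prior)
    for t in reversed(range(len(prior))):
        h_now = {s: sum(p * h_next[n] for n, p in nxt.items())
                 for s, nxt in prior[t].items()}
        post[t] = {s: {n: p * h_next[n] / h_now[s] for n, p in nxt.items()}
                   for s, nxt in prior[t].items()}
        h_next = h_now
    return post, h_next[0]


post, z = h_transform(PRIOR, REWARD)
log_z = math.log(z)

worst_residual = 0.0                       # largest |RTB residual| over all paths
post_marginal = {x: 0.0 for x in REWARD}   # posterior-model distribution of x_1
prior_marginal = {x: 0.0 for x in REWARD}  # prior distribution of x_1
for mid, p_mid in PRIOR[0][0].items():
    for end, p_end in PRIOR[1][mid].items():
        log_prior = math.log(p_mid) + math.log(p_end)
        log_post = math.log(post[0][0][mid]) + math.log(post[1][mid][end])
        residual = log_z + log_post - math.log(REWARD[end]) - log_prior
        worst_residual = max(worst_residual, abs(residual))
        post_marginal[end] += math.exp(log_post)
        prior_marginal[end] += p_mid * p_end

# Target posterior over x_1: p_theta(x_1) r(x_1) / Z
target = {x: prior_marginal[x] * REWARD[x] / z for x in REWARD}
for x in REWARD:
    print(x, round(post_marginal[x], 6), round(target[x], 6))
```

The fixed start state keeps the example minimal; with a shared random \mathbf{x}_0, its density would simply cancel in each residual as in the loss above.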

Highlight Results

To show the effectiveness of RTB, we provide a diverse set of empirical results across multiple domains. Here are some highlights:

In Vision, RTB applied to image generation tasks achieves state-of-the-art results, improving on classifier guidance and RL-based baselines both in the diversity of the generated posterior samples and in their closeness to the true posterior distribution. Below is an example of finetuning an unconditional diffusion prior.

With Stable Diffusion, we can finetune a large text-to-image model to achieve high human-preference scores, on par with state-of-the-art methods.

In Language, we show that RTB can finetune a discrete language diffusion model prior into a posterior that performs coherent text infilling.

In Control, on reinforcement learning tasks, we train a posterior behavior policy from a diffusion behavioral prior and report state-of-the-art results, matching or outperforming other recent methods.

Conclusion

Relative Trajectory Balance (RTB) is a flexible objective that allows for efficient, asymptotically unbiased sampling from the posterior distribution, given a diffusion prior and a reward or likelihood. The strong empirical results across vision, language, and control underscore RTB's potential as a general tool for posterior inference under diffusion priors. Whether you’re diving into research or developing practical solutions, we hope the advances presented here change how you generate and refine data. Stay tuned for more exciting developments in this field!

Check out these links for more information:

Paper, Code
