From b4392a28034305ee3264fa501ac882ffb1e8f8e4 Mon Sep 17 00:00:00 2001
From: geoffroyO
Date: Tue, 7 Oct 2025 16:06:05 +0200
Subject: [PATCH] DDCM integrator

---
 diffuse/integrator/stochastic.py | 83 +++++++++++++++++++++++++++++++-
 1 file changed, 81 insertions(+), 2 deletions(-)

diff --git a/diffuse/integrator/stochastic.py b/diffuse/integrator/stochastic.py
index e15c545..57e531e 100644
--- a/diffuse/integrator/stochastic.py
+++ b/diffuse/integrator/stochastic.py
@@ -1,14 +1,15 @@
 from dataclasses import dataclass
-from typing import Callable
+from typing import Tuple
 
 import jax
 import jax.numpy as jnp
+from jaxtyping import Array
 
 from diffuse.integrator.base import IntegratorState, Integrator
 from diffuse.diffusion.sde import DiffusionModel
 from diffuse.predictor import Predictor
 
-__all__ = ["EulerMaruyamaIntegrator"]
+__all__ = ["EulerMaruyamaIntegrator", "DDCMIntegrator"]
 
 
 @dataclass
@@ -70,3 +71,81 @@ def __call__(self, integrator_state: IntegratorState, predictor: Predictor) -> I
         dx = drift * dt + diffusion * noise
         _, rng_key_next = jax.random.split(rng_key)
         return IntegratorState(position + dx, rng_key_next, step + 1)
+
+
+@dataclass
+class DDCMIntegrator(Integrator):
+    """Denoising Diffusion Codebook Models (DDCM) integrator.
+
+    Implements a variant of Euler-Maruyama where the stochastic noise is picked from
+    a finite codebook of noise vectors rather than drawn afresh from a continuous
+    Gaussian distribution at every step. The same codebook is used for all steps.
+    https://arxiv.org/pdf/2502.01189
+
+    The integrator solves SDEs of the form:
+        dX(t) = μ(X,t)dt + σ(X,t)dW_codebook(t)
+
+    where:
+    - μ(X,t) is the drift term: g(t)² * (0.5 * X + score(X,t))
+    - σ(X,t) is the diffusion term: g(t)
+    - dW_codebook(t) is a codebook entry selected uniformly at random
+    - g(t) = sqrt(β(t)) where β(t) is the noise schedule
+
+    Discretization (dt = t - t_next, the step taken by the timer):
+        X(t_next) = X(t) + μ(X,t)dt + σ(X,t)√dt * codebook[i], i ~ Uniform{0, ..., |codebook| - 1}
+
+    Attributes:
+        model: Diffusion model providing SDE coefficients
+        codebook: Array of shape (size_codebook, *x0_shape) holding the candidate noise vectors
+
+    Initialization:
+        codebook = jax.random.normal(rng_key, (size_codebook, *x0_shape))
+        integrator = DDCMIntegrator(model=model, timer=timer, codebook=codebook)
+    """
+
+    model: DiffusionModel
+    codebook: Array  # Shape: (size_codebook, *x0_shape)
+
+    def __call__(self, integrator_state: IntegratorState, predictor: Predictor) -> IntegratorState:
+        """Perform one DDCM integration step.
+
+        Args:
+            integrator_state: Current state containing:
+                - position: Current position X(t) with shape (*x0_shape)
+                - rng_key: JAX random number generator key
+                - step: Current integration step index
+            predictor: Predictor providing the score function ∇ₓ log p(x|t)
+
+        Returns:
+            Updated IntegratorState containing:
+                - New position X(t) + dx
+                - Fresh RNG key derived from the incoming one
+                - Incremented step count
+
+        Notes:
+            The integration step implements:
+                dx = drift*dt + diffusion*√dt*codebook[i]
+            where:
+                - drift = g(t)² * (0.5 * position + score(position, t))
+                - diffusion = g(t), the second value returned by model.sde_coefficients(t)
+                - i is an integer drawn uniformly from [0, codebook_size), independently per step
+        """
+
+        position, rng_key, step = integrator_state
+        t, t_next = self.timer(step), self.timer(step + 1)
+        dt = t - t_next
+        _, g_t = self.model.sde_coefficients(t)  # f(t) is not needed: the drift is written with g(t)^2 only
+        # Reverse-time drift is f(t)*x - g(t)^2*score = -g(t)^2 * (0.5*x + score), assuming f(t) = -0.5*beta(t)
+        # and g(t)^2 = beta(t). dt = t - t_next is the negated time increment, which cancels the minus sign.
+        drift = g_t * g_t * (0.5 * position + predictor.score(position, t))
+        diffusion = g_t
+
+        rng_key, rng_noise = jax.random.split(rng_key)
+        # Draw one codebook index uniformly; maxval is exclusive, so every entry is reachable
+        codebook_index = jax.random.randint(rng_noise, shape=(), minval=0, maxval=self.codebook.shape[0])
+        # The selected entry plays the role of the unit noise; sqrt(dt) scales it like a Brownian increment
+        noise = self.codebook[codebook_index] * jnp.sqrt(dt)
+
+        dx = drift * dt + diffusion * noise
+        _, rng_key_next = jax.random.split(rng_key)
+        return IntegratorState(position + dx, rng_key_next, step + 1)