Source code for tsgm.models.ddpm

"""
The implementation is based on the Keras DDPM example: https://keras.io/examples/generative/ddpm/
"""
import typing as T

import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.python.types.core import TensorLike


class GaussianDiffusion:
    """Gaussian diffusion utility for generating samples using a diffusion process.

    This class implements a Gaussian diffusion process, where a sample is gradually
    perturbed by adding Gaussian noise over a series of timesteps. It also includes
    methods to reverse the diffusion process, predicting the original data from the
    noisy samples.

    Args:
        beta_start (float): Start value of the scheduled variance for the diffusion process.
        beta_end (float): End value of the scheduled variance for the diffusion process.
        timesteps (int): Number of timesteps in the forward process.
    """

    def __init__(
        self,
        beta_start: float = 1e-4,
        beta_end: float = 0.02,
        timesteps: int = 1000,
    ) -> None:
        self.beta_start = beta_start
        self.beta_end = beta_end
        self.timesteps = timesteps

        # Define the linear variance schedule
        self.betas = betas = np.linspace(
            beta_start,
            beta_end,
            timesteps,
            dtype=np.float64,  # Using float64 for better precision
        )
        self.num_timesteps = int(timesteps)

        alphas = 1.0 - betas
        alphas_cumprod = np.cumprod(alphas, axis=0)
        alphas_cumprod_prev = np.append(1.0, alphas_cumprod[:-1])

        self.betas = tf.constant(betas, dtype=tf.float32)
        self.alphas_cumprod = tf.constant(alphas_cumprod, dtype=tf.float32)
        self.alphas_cumprod_prev = tf.constant(alphas_cumprod_prev, dtype=tf.float32)

        # Calculations for diffusion q(x_t | x_{t-1}) and others
        self.sqrt_alphas_cumprod = tf.constant(
            np.sqrt(alphas_cumprod), dtype=tf.float32
        )
        self.sqrt_one_minus_alphas_cumprod = tf.constant(
            np.sqrt(1.0 - alphas_cumprod), dtype=tf.float32
        )
        self.log_one_minus_alphas_cumprod = tf.constant(
            np.log(1.0 - alphas_cumprod), dtype=tf.float32
        )
        self.sqrt_recip_alphas_cumprod = tf.constant(
            np.sqrt(1.0 / alphas_cumprod), dtype=tf.float32
        )
        self.sqrt_recipm1_alphas_cumprod = tf.constant(
            np.sqrt(1.0 / alphas_cumprod - 1), dtype=tf.float32
        )

        # Calculations for posterior q(x_{t-1} | x_t, x_0)
        posterior_variance = (
            betas * (1.0 - alphas_cumprod_prev) / (1.0 - alphas_cumprod)
        )
        self.posterior_variance = tf.constant(posterior_variance, dtype=tf.float32)

        # Log calculation clipped because the posterior variance is 0 at the
        # beginning of the diffusion chain
        self.posterior_log_variance_clipped = tf.constant(
            np.log(np.maximum(posterior_variance, 1e-20)), dtype=tf.float32
        )

        self.posterior_mean_coef1 = tf.constant(
            betas * np.sqrt(alphas_cumprod_prev) / (1.0 - alphas_cumprod),
            dtype=tf.float32,
        )

        self.posterior_mean_coef2 = tf.constant(
            (1.0 - alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - alphas_cumprod),
            dtype=tf.float32,
        )

    def _extract(self, a: TensorLike, t: int, x_shape: tf.TensorShape) -> TensorLike:
        """Extracts coefficients for a specific timestep and reshapes them for broadcasting.

        Args:
            a: Tensor to extract from.
            t: Timestep for which the coefficients are to be extracted.
            x_shape: Shape of the current batched samples.

        Returns:
            Tensor reshaped to [batch_size, 1, 1] for broadcasting.
        """
        batch_size = x_shape[0]
        out = tf.gather(a, t)
        return tf.reshape(out, [batch_size, 1, 1])

    def q_mean_variance(self, x_start: TensorLike, t: float) -> T.Tuple:
        """Extracts the mean and variance at a specific timestep in the forward diffusion process.

        Args:
            x_start: Initial sample (before the first diffusion step).
            t: A timestep.

        Returns:
            mean, variance, log_variance: Tensors representing the mean, variance,
            and log variance of the distribution at `t`.
        """
        x_start_shape = tf.shape(x_start)
        mean = self._extract(self.sqrt_alphas_cumprod, t, x_start_shape) * x_start
        variance = self._extract(1.0 - self.alphas_cumprod, t, x_start_shape)
        log_variance = self._extract(
            self.log_one_minus_alphas_cumprod, t, x_start_shape
        )
        return mean, variance, log_variance

    def q_sample(self, x_start: TensorLike, t: TensorLike, noise: TensorLike) -> TensorLike:
        """Performs the forward diffusion step by adding Gaussian noise to the sample.

        Args:
            x_start: Initial sample (before the first diffusion step).
            t: Current timestep.
            noise: Gaussian noise to be added at timestep `t`.

        Returns:
            Diffused samples at timestep `t`.
        """
        x_start_shape = tf.shape(x_start)
        return (
            self._extract(self.sqrt_alphas_cumprod, t, x_start_shape) * x_start
            + self._extract(self.sqrt_one_minus_alphas_cumprod, t, x_start_shape)
            * noise
        )

    def predict_start_from_noise(self, x_t: TensorLike, t, noise):
        """Predicts the initial sample from the noisy sample at timestep `t`.

        Args:
            x_t: Noisy sample at timestep `t`.
            t: Current timestep.
            noise: Gaussian noise added at timestep `t`.

        Returns:
            Predicted initial sample.
        """
        x_t_shape = tf.shape(x_t)
        return (
            self._extract(self.sqrt_recip_alphas_cumprod, t, x_t_shape) * x_t
            - self._extract(self.sqrt_recipm1_alphas_cumprod, t, x_t_shape) * noise
        )

    def q_posterior(self, x_start, x_t, t):
        """Computes the mean and variance of the posterior distribution q(x_{t-1} | x_t, x_0).

        Args:
            x_start: Initial sample (x_0) for the posterior computation.
            x_t: Sample at timestep `t`.
            t: Current timestep.

        Returns:
            Posterior mean, variance, and clipped log variance at the current timestep.
        """
        x_t_shape = tf.shape(x_t)
        posterior_mean = (
            self._extract(self.posterior_mean_coef1, t, x_t_shape) * x_start
            + self._extract(self.posterior_mean_coef2, t, x_t_shape) * x_t
        )
        posterior_variance = self._extract(self.posterior_variance, t, x_t_shape)
        posterior_log_variance_clipped = self._extract(
            self.posterior_log_variance_clipped, t, x_t_shape
        )
        return posterior_mean, posterior_variance, posterior_log_variance_clipped

    def p_mean_variance(self, pred_noise, x, t):
        """Predicts the mean and variance for the reverse diffusion step.

        Args:
            pred_noise: Noise predicted by the diffusion model.
            x: Samples at a given timestep for which the noise was predicted.
            t: Current timestep.

        Returns:
            model_mean, posterior_variance, posterior_log_variance: Tensors representing
            the mean and variance of the model at the current timestep.
        """
        x_recon = self.predict_start_from_noise(x, t=t, noise=pred_noise)
        model_mean, posterior_variance, posterior_log_variance = self.q_posterior(
            x_start=x_recon, x_t=x, t=t
        )
        return model_mean, posterior_variance, posterior_log_variance

    def p_sample(self, pred_noise, x, t):
        """Generates a sample from the diffusion model by reversing the diffusion process.

        Args:
            pred_noise: Noise predicted by the diffusion model.
            x: Samples at a given timestep for which the noise was predicted.
            t: Current timestep.

        Returns:
            Sample generated by reversing the diffusion process at timestep `t`.
        """
        model_mean, _, model_log_variance = self.p_mean_variance(
            pred_noise, x=x, t=t
        )
        noise = tf.random.normal(shape=x.shape, dtype=x.dtype)
        # No noise when t == 0
        nonzero_mask = tf.reshape(
            1 - tf.cast(tf.equal(t, 0), tf.float32), [tf.shape(x)[0], 1, 1]
        )
        return model_mean + nonzero_mask * tf.exp(0.5 * model_log_variance) * noise
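
# --- Usage sketch (illustrative, not part of the original module) ---
# The schedule above gives the closed-form forward step
#     x_t = sqrt(alphas_cumprod[t]) * x_0 + sqrt(1 - alphas_cumprod[t]) * noise,
# implemented by `q_sample`, while `p_sample` performs one reverse (denoising)
# step. The toy shapes and the dummy predicted noise below are assumptions for
# demonstration only:
#
#     gdf = GaussianDiffusion(timesteps=1000)
#     x0 = tf.random.normal((8, 24, 3))                    # [batch, seq_len, feat_dim]
#     t = tf.fill((8,), tf.constant(500, dtype=tf.int64))  # one timestep per batch element
#     noise = tf.random.normal(tf.shape(x0))
#     xt = gdf.q_sample(x0, t, noise)                      # forward: diffuse x0 to step t
#     x_prev = gdf.p_sample(tf.zeros_like(xt), xt, t)      # reverse step with dummy predicted noise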


class DDPM(keras.Model):
    """Denoising Diffusion Probabilistic Model.

    Args:
        network (keras.Model): A Keras model that predicts the noise added to the images.
        ema_network (keras.Model): EMA model, a clone of `network`.
        timesteps (int): The number of timesteps in the diffusion process.
        ema (float): The decay factor for the EMA, default is 0.999.
    """

    def __init__(self, network: keras.Model, ema_network: keras.Model, timesteps: int, ema: float = 0.999) -> None:
        super().__init__()
        self.network = network
        self.ema_network = ema_network
        self.timesteps = timesteps
        self.gdf_util = GaussianDiffusion(timesteps=timesteps)
        self.ema = ema
        self.ema_network.set_weights(network.get_weights())  # Initially the weights are the same

        # Filled in during training
        self.seq_len = None
        self.feat_dim = None

    def train_step(self, images: TensorLike) -> T.Dict:
        """Performs a single training step on a batch of images.

        Args:
            images: A batch of images to train on.

        Returns:
            A dictionary containing the loss value for the training step.
        """
        self.seq_len, self.feat_dim = images.shape[1], images.shape[2]

        # 1. Get the batch size
        batch_size = tf.shape(images)[0]

        # 2. Sample timesteps uniformly
        t = tf.random.uniform(
            minval=0, maxval=self.timesteps, shape=(batch_size,), dtype=tf.int64
        )

        with tf.GradientTape() as tape:
            # 3. Sample random noise to be added to the images in the batch
            noise = tf.random.normal(shape=tf.shape(images), dtype=images.dtype)

            # 4. Diffuse the images with noise
            images_t = self.gdf_util.q_sample(images, t, noise)

            # 5. Pass the diffused images and time steps to the network
            pred_noise = self.network([images_t, t], training=True)

            # 6. Calculate the loss
            loss = self.loss(noise, pred_noise)

        # 7. Get the gradients
        gradients = tape.gradient(loss, self.network.trainable_weights)

        # 8. Update the weights of the network
        self.optimizer.apply_gradients(zip(gradients, self.network.trainable_weights))

        # 9. Update the EMA network weights with an exponential moving average
        #    of the trained network weights
        for weight, ema_weight in zip(self.network.weights, self.ema_network.weights):
            ema_weight.assign(self.ema * ema_weight + (1 - self.ema) * weight)

        # 10. Return loss values
        return {"loss": loss}

    def generate(self, n_samples: int = 16) -> TensorLike:
        """Generates new samples by running the reverse diffusion process.

        Args:
            n_samples: The number of samples to generate.

        Returns:
            Generated samples after running the reverse diffusion process.
        """
        if self.seq_len is None or self.feat_dim is None:
            raise ValueError("DDPM is not trained")

        # 1. Randomly sample noise (starting point for reverse process)
        samples = tf.random.normal(
            shape=(n_samples, self.seq_len, self.feat_dim), dtype=tf.float32
        )

        # 2. Sample from the model iteratively
        for t in reversed(range(0, self.timesteps)):
            tt = tf.cast(tf.fill(n_samples, t), dtype=tf.int64)
            pred_noise = self.ema_network.predict(
                [samples, tt], verbose=0, batch_size=n_samples
            )
            samples = self.gdf_util.p_sample(pred_noise, samples, tt)

        # 3. Return generated samples
        return samples

    def call(self, n_samples: int) -> TensorLike:
        """Calls the generate method to produce samples.

        Args:
            n_samples: The number of samples to generate.

        Returns:
            Generated samples.
        """
        return self.generate(n_samples)
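
# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal, hedged sketch of training and sampling with DDPM. The toy
# noise-prediction network below is a placeholder assumption: any keras.Model
# that maps [noisy_samples, timesteps] to a tensor shaped like the samples
# can be plugged in instead.
if __name__ == "__main__":
    seq_len, feat_dim, timesteps = 24, 3, 1000

    def build_toy_network() -> keras.Model:
        # Placeholder backbone: embeds the timestep, broadcasts it along the
        # sequence axis, and regresses the added noise with a small MLP.
        x_in = keras.layers.Input(shape=(seq_len, feat_dim))
        t_in = keras.layers.Input(shape=(), dtype="int64")
        t_emb = keras.layers.Embedding(timesteps, 16)(t_in)   # [batch, 16]
        t_emb = keras.layers.RepeatVector(seq_len)(t_emb)     # [batch, seq_len, 16]
        h = keras.layers.Concatenate(axis=-1)([x_in, t_emb])
        h = keras.layers.Dense(64, activation="relu")(h)
        noise_out = keras.layers.Dense(feat_dim)(h)
        return keras.Model([x_in, t_in], noise_out)

    network = build_toy_network()
    ema_network = keras.models.clone_model(network)

    model = DDPM(network, ema_network, timesteps=timesteps)
    model.compile(
        loss=keras.losses.MeanSquaredError(),
        optimizer=keras.optimizers.Adam(learning_rate=1e-4),
    )

    # Toy training data of shape [n, seq_len, feat_dim]
    data = np.random.normal(size=(128, seq_len, feat_dim)).astype(np.float32)
    model.fit(data, epochs=1, batch_size=32)

    # Reverse diffusion produces new samples of shape [4, seq_len, feat_dim]
    synthetic = model.generate(n_samples=4)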