# Source code for tsgm.models.cgan

import typing as T
import keras
from keras import ops
try:
    import tensorflow_privacy as tf_privacy
    __tf_privacy_available = True
except (ModuleNotFoundError, ImportError):
    __tf_privacy_available = False

import logging

import tsgm

import os
from tsgm.backend import get_backend


# Logger shared by the GAN models in this module; it records DEBUG messages and above.
logger = logging.getLogger(name='models')
logger.setLevel(level=logging.DEBUG)


def _is_dp_optimizer(optimizer: keras.optimizers.Optimizer) -> bool:
    """
    Tells whether `optimizer` is one of the `tensorflow_privacy` DP Keras optimizers.

    :param optimizer: The optimizer to inspect.
    :type optimizer: keras.optimizers.Optimizer

    :returns: False when `tensorflow_privacy` could not be imported; otherwise True only
        for instances of the DP Adagrad, Adam, or SGD Keras optimizers.
    :rtype: bool
    """
    # Guard first: `tf_privacy` is only bound when the optional import succeeded.
    if not __tf_privacy_available:
        return False
    dp_optimizer_types = (
        tf_privacy.DPKerasAdagradOptimizer,
        tf_privacy.DPKerasAdamOptimizer,
        tf_privacy.DPKerasSGDOptimizer,
    )
    return isinstance(optimizer, dp_optimizer_types)


class GAN(keras.Model):
    """
    GAN implementation for unlabeled time series.
    """
    def __init__(self, discriminator: keras.Model, generator: keras.Model, latent_dim: int, use_wgan: bool = False) -> None:
        """
        :param discriminator: A discriminator model which takes a time series as input and check
            whether the sample is real or fake.
        :type discriminator: keras.Model
        :param generator: Takes as input a random noise vector of `latent_dim` length and returns
            a simulated time-series.
        :type generator: keras.Model
        :param latent_dim: The size of the noise vector.
        :type latent_dim: int
        :param use_wgan: Use Wasserstein GAN with gradient penalty.
        :type use_wgan: bool
        """
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        # The sequence length is read from axis 1 of the generator's output shape.
        self._seq_len = self.generator.output_shape[1]
        self.use_wgan = use_wgan
        # Multiplier of the gradient-penalty term in the WGAN discriminator loss.
        self.gp_weight = 10.0

        self.gen_loss_tracker = keras.metrics.Mean(name="generator_loss")
        self.disc_loss_tracker = keras.metrics.Mean(name="discriminator_loss")

    def call(self, inputs):
        """
        Forward pass for the GAN model.
        This method is required for Keras 3 compatibility with PyTorch backend.

        :param inputs: Arbitrary inputs.

        :returns: `inputs`, unchanged.
        """
        # For GAN training, we don't typically call the model directly.
        # This is just a placeholder to satisfy Keras 3 requirements.
        return inputs

    def wgan_discriminator_loss(self, real_sample, fake_sample):
        """
        Wasserstein discriminator loss: mean of `fake_sample` minus mean of `real_sample`.

        :param real_sample: Discriminator outputs for real samples.
        :param fake_sample: Discriminator outputs for generated samples.

        :returns: A scalar loss tensor.
        """
        real_loss = ops.mean(real_sample)
        fake_loss = ops.mean(fake_sample)
        return fake_loss - real_loss

    # Define the loss functions to be used for generator
    def wgan_generator_loss(self, fake_sample):
        """
        Wasserstein generator loss: the negated mean of `fake_sample`.

        :param fake_sample: Discriminator outputs for generated samples.

        :returns: A scalar loss tensor.
        """
        return -ops.mean(fake_sample)

    def gradient_penalty_tf(self, tf, interpolated):
        """
        Computes the gradients of the discriminator output w.r.t. `interpolated` using a gradient tape.

        :param tf: Backend module providing `GradientTape`.
        :param interpolated: Samples at which the gradients are evaluated.

        :returns: Gradients with respect to `interpolated`.
        """
        with tf.GradientTape() as gp_tape:
            gp_tape.watch(interpolated)
            # 1. Get the discriminator output for this interpolated sample.
            pred = self.discriminator(interpolated, training=True)

        # 2. Calculate the gradients w.r.t to this interpolated sample.
        grads = gp_tape.gradient(pred, [interpolated])[0]
        return grads

    def gradient_penalty_torch(self, torch, interpolated):
        """
        Computes the gradients of the discriminator output w.r.t. `interpolated` using `torch.autograd.grad`.

        :param torch: Backend module providing `autograd.grad`.
        :param interpolated: Samples at which the gradients are evaluated.

        :returns: Gradients with respect to `interpolated`.
        """
        # Create a new tensor that requires grad instead of modifying existing one
        interpolated = interpolated.detach().requires_grad_(True)
        pred = self.discriminator(interpolated, training=True)
        # create_graph=True keeps the graph of the gradients so the penalty can be backpropagated.
        grads = torch.autograd.grad(
            outputs=pred,
            inputs=interpolated,
            grad_outputs=ops.ones_like(pred),
            create_graph=True,
            retain_graph=True,
            only_inputs=True,
        )[0]
        return grads

    def gradient_penalty(self, batch_size, real_samples, fake_samples):
        """
        Computes the gradient penalty on samples interpolated between real and fake data.

        :param batch_size: Number of samples in the batch.
        :param real_samples: A batch of real samples.
        :param fake_samples: A batch of generated samples.

        :returns: Mean of `(norm - 1) ** 2`, where `norm` is the gradient norm taken over axes 1 and 2.

        :raises ValueError: If the `KERAS_BACKEND` environment variable is neither "tensorflow" nor "torch".
        """
        # get the interpolated samples
        alpha_shape = [batch_size, 1, 1]
        # Create alpha on the same device as real_samples
        # NOTE(review): alpha is drawn from a standard normal distribution here; the WGAN-GP paper
        # samples it uniformly from [0, 1] -- confirm this is intended.
        alpha = ops.ones_like(real_samples[:, :1, :1]) * keras.random.normal(alpha_shape, 0.0, 1.0)
        diff = fake_samples - real_samples
        interpolated = real_samples + alpha * diff

        backend = get_backend()
        backend_name = os.environ.get("KERAS_BACKEND")
        if backend_name == "tensorflow":
            grads = self.gradient_penalty_tf(backend, interpolated)
        elif backend_name == "torch":
            grads = self.gradient_penalty_torch(backend, interpolated)
        else:
            # Fail explicitly instead of hitting an unbound `grads` below.
            raise ValueError(f"Unsupported backend: {backend_name}")

        # 3. Calculate the norm of the gradients
        norm = ops.sqrt(ops.sum(ops.square(grads), axis=[1, 2]))
        gp = ops.mean((norm - 1.0) ** 2)
        return gp

    @property
    def metrics(self) -> T.List:
        """
        :returns: A list of metrics trackers (e.g., generator's loss and discriminator's loss).
        :rtype: T.List
        """
        return [self.gen_loss_tracker, self.disc_loss_tracker]

    def compile(self, d_optimizer: keras.optimizers.Optimizer, g_optimizer: keras.optimizers.Optimizer, loss_fn: keras.losses.Loss) -> None:
        """
        Compiles the generator and discriminator models.

        :param d_optimizer: An optimizer for the GAN's discriminator.
        :type d_optimizer: keras.optimizers.Optimizer
        :param g_optimizer: An optimizer for the GAN's generator.
        :type g_optimizer: keras.optimizers.Optimizer
        :param loss_fn: Loss function.
        :type loss_fn: keras.losses.Loss
        """
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

        # Each flag is derived from the optimizer of the network it is named after.
        generator_dp = _is_dp_optimizer(g_optimizer)
        discriminator_dp = _is_dp_optimizer(d_optimizer)

        if generator_dp != discriminator_dp:
            logger.warning(f"One of the optimizers is DP and another one is not. generator_dp={generator_dp}, discriminator_dp={discriminator_dp}")

        # NOTE(review): `self.dp` is recorded here, but the train steps of this class do not branch on it.
        self.dp = generator_dp and discriminator_dp

    def _get_random_vector_labels(self, batch_size: int, labels=None) -> tsgm.types.Tensor:
        """
        Samples generator inputs from a normal distribution.

        :param batch_size: Number of noise vectors to sample.
        :type batch_size: int
        :param labels: Unused by this class.

        :returns: A tensor of shape `(batch_size, latent_dim)`.
        :rtype: tsgm.types.Tensor
        """
        return keras.random.normal(shape=(batch_size, self.latent_dim))

    @staticmethod
    def _apply_torch_gradients(torch, loss, model: keras.Model, optimizer: keras.optimizers.Optimizer) -> None:
        """
        Backpropagates `loss` and applies the resulting gradients to the trainable weights of `model`.

        :param torch: Backend module providing `no_grad`.
        :param loss: The loss tensor to backpropagate.
        :param model: The network whose trainable weights are updated.
        :type model: keras.Model
        :param optimizer: The optimizer that applies the gradients.
        :type optimizer: keras.optimizers.Optimizer
        """
        # Clear gradients accumulated by earlier backward passes.
        model.zero_grad()
        loss.backward()

        trainable_weights = list(model.trainable_weights)
        gradients = [v.value.grad for v in trainable_weights]

        with torch.no_grad():
            # Keras 3 expects (gradient, variable) pairs
            optimizer.apply_gradients(list(zip(gradients, trainable_weights)))

    def _discriminator_loss(self, batch_size, real_data, fake_data):
        """
        Computes the discriminator loss for one batch.

        :param batch_size: Number of samples in `real_data`.
        :param real_data: A batch of real samples.
        :param fake_data: A batch of generated samples.

        :returns: The WGAN loss plus weighted gradient penalty when `use_wgan` is set,
            otherwise `loss_fn` evaluated on the concatenated fake and real samples.
        """
        if self.use_wgan:
            fake_logits = self.discriminator(fake_data, training=True)
            # Get the logits for the real samples
            real_logits = self.discriminator(real_data, training=True)

            # Calculate the discriminator loss using the fake and real sample logits
            d_cost = self.wgan_discriminator_loss(real_logits, fake_logits)
            # Calculate the gradient penalty
            gp = self.gradient_penalty(batch_size, real_data, fake_data)
            # Add the gradient penalty to the original discriminator loss
            return d_cost + gp * self.gp_weight

        combined_data = ops.concatenate([fake_data, real_data], axis=0)
        # Labels for the discriminator follow the concatenation order above:
        # 1 == fake data
        # 0 == real data
        desc_labels = ops.concatenate(
            [ops.ones((batch_size, 1)), ops.zeros((batch_size, 1))], axis=0
        )
        predictions = self.discriminator(combined_data)
        return self.loss_fn(desc_labels, predictions)

    def _generator_loss(self, batch_size, random_vector):
        """
        Computes the generator loss for one batch of noise vectors.

        :param batch_size: Number of noise vectors in `random_vector`.
        :param random_vector: Generator inputs.

        :returns: The WGAN generator loss when `use_wgan` is set, otherwise `loss_fn`
            evaluated against all-zero labels.
        """
        fake_data = self.generator(random_vector)
        predictions = self.discriminator(fake_data)
        if self.use_wgan:
            # uses logits
            return self.wgan_generator_loss(predictions)
        # Pretend that all samples are real (label 0)
        misleading_labels = ops.zeros((batch_size, 1))
        return self.loss_fn(misleading_labels, predictions)

    def train_step_tf(self, tf, data: tsgm.types.Tensor) -> T.Dict[str, float]:
        """
        Performs a training step with gradient tapes.

        :param tf: Backend module providing `GradientTape`.
        :param data: A batch of real samples.
        :type data: tsgm.types.Tensor

        :returns: A dictionary with generator (key "g_loss") and discriminator (key "d_loss") losses
        :rtype: T.Dict[str, float]
        """
        real_data = data
        batch_size = ops.shape(real_data)[0]

        # Generate ts
        random_vector = self._get_random_vector_labels(batch_size)
        fake_data = self.generator(random_vector)

        # Train discriminator
        with tf.GradientTape() as tape:
            d_loss = self._discriminator_loss(batch_size, real_data, fake_data)
        grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
        self.d_optimizer.apply_gradients(
            zip(grads, self.discriminator.trainable_weights)
        )

        random_vector = self._get_random_vector_labels(batch_size=batch_size)

        # Train generator (without updating the discriminator)
        with tf.GradientTape() as tape:
            g_loss = self._generator_loss(batch_size, random_vector)
        grads = tape.gradient(g_loss, self.generator.trainable_weights)
        self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))

        self.gen_loss_tracker.update_state(g_loss)
        self.disc_loss_tracker.update_state(d_loss)
        return {
            "g_loss": self.gen_loss_tracker.result(),
            "d_loss": self.disc_loss_tracker.result(),
        }

    def train_step_torch(self, torch, data: tsgm.types.Tensor) -> T.Dict[str, float]:
        """
        Performs a training step with explicit `backward()` calls.

        :param torch: Backend module providing `no_grad` and `autograd`.
        :param data: A batch of real samples, optionally wrapped in a one-element list or tuple.
        :type data: tsgm.types.Tensor

        :returns: A dictionary with generator (key "g_loss") and discriminator (key "d_loss") losses
        :rtype: T.Dict[str, float]
        """
        # Handle PyTorch DataLoader format - extract tensor from list
        if isinstance(data, (list, tuple)) and len(data) == 1:
            real_data = data[0]
        else:
            real_data = data

        # Ensure real_data is on the same device as model parameters (MPS)
        if hasattr(real_data, 'device'):
            # Get device from first model parameter
            model_device = next(self.generator.parameters()).device
            if real_data.device != model_device:
                real_data = real_data.to(model_device)

        batch_size = ops.shape(real_data)[0]

        # Generate ts
        random_vector = self._get_random_vector_labels(batch_size)
        fake_data = self.generator(random_vector)

        # Train discriminator
        d_loss = self._discriminator_loss(batch_size, real_data, fake_data)
        self._apply_torch_gradients(torch, d_loss, self.discriminator, self.d_optimizer)

        # Train generator (without updating the discriminator)
        random_vector = self._get_random_vector_labels(batch_size=batch_size)
        g_loss = self._generator_loss(batch_size, random_vector)
        self._apply_torch_gradients(torch, g_loss, self.generator, self.g_optimizer)

        self.gen_loss_tracker.update_state(g_loss)
        self.disc_loss_tracker.update_state(d_loss)
        return {
            "g_loss": self.gen_loss_tracker.result(),
            "d_loss": self.disc_loss_tracker.result(),
        }

    def train_step(self, data: tsgm.types.Tensor) -> T.Dict[str, float]:
        """
        Performs a training step using a batch of data, stored in data.

        :param data: A batch of data in a format batch_size x seq_len x feat_dim
        :type data: tsgm.types.Tensor

        :returns: A dictionary with generator (key "g_loss") and discriminator (key "d_loss") losses
        :rtype: T.Dict[str, float]

        :raises ValueError: If the `KERAS_BACKEND` environment variable is neither "tensorflow" nor "torch".
        """
        backend = get_backend()
        backend_name = os.environ.get("KERAS_BACKEND")
        if backend_name == "tensorflow":
            return self.train_step_tf(backend, data)
        elif backend_name == "torch":
            return self.train_step_torch(backend, data)
        else:
            # Returning None here would only fail later, far from the cause.
            raise ValueError(f"Unsupported backend: {backend_name}")

    def generate(self, num: int) -> tsgm.types.Tensor:
        """
        Generates new data from the model.

        :param num: the number of samples to be generated.
        :type num: int

        :returns: Generated samples
        :rtype: tsgm.types.Tensor
        """
        random_vector_labels = self._get_random_vector_labels(batch_size=num)
        return self.generator(random_vector_labels)

    def clone(self) -> "GAN":
        """
        Clones GAN object

        The clone is built around the same `discriminator` and `generator` objects as this model.

        :returns: A new GAN with the same networks, latent dimension, and WGAN settings
        :rtype: "GAN"
        """
        copy_model = GAN(self.discriminator, self.generator, latent_dim=self.latent_dim, use_wgan=self.use_wgan)
        copy_model.gp_weight = self.gp_weight
        # `set_weights` returns None, so its result must not replace `copy_model`.
        copy_model.set_weights(self.get_weights())
        return copy_model
class ConditionalGAN(keras.Model):
    """
    Conditional GAN implementation for labeled and temporally labeled time series.
    """
    def __init__(self, discriminator: keras.Model, generator: keras.Model, latent_dim: int, temporal: bool = False, use_wgan: bool = False) -> None:
        """
        :param discriminator: A discriminator model which takes a time series as input and check
            whether the sample is real or fake.
        :type discriminator: keras.Model
        :param generator: Takes as input a random noise vector of `latent_dim` length and return
            a simulated time-series.
        :type generator: keras.Model
        :param latent_dim: The size of the noise vector.
        :type latent_dim: int
        :param temporal: Indicates whether the time series temporally labeled or not.
        :type temporal: bool
        :param use_wgan: Stored on the instance; the train steps of this class do not use it.
            Default is False.
        :type use_wgan: bool
        """
        super().__init__()
        self.discriminator = discriminator
        self.generator = generator
        self.latent_dim = latent_dim
        # The sequence length is read from axis 1 of the generator's output shape.
        self._seq_len = self.generator.output_shape[1]

        self.gen_loss_tracker = keras.metrics.Mean(name="generator_loss")
        self.disc_loss_tracker = keras.metrics.Mean(name="discriminator_loss")
        self._temporal = temporal

        # Keep the flag instead of silently discarding it, and tell the user it changes nothing.
        self.use_wgan = use_wgan
        if use_wgan:
            logger.warning("ConditionalGAN does not implement the WGAN objective; `use_wgan=True` has no effect.")

    def call(self, inputs):
        """
        Forward pass for the ConditionalGAN model.
        This method is required for Keras 3 compatibility with PyTorch backend.

        :param inputs: Arbitrary inputs.

        :returns: `inputs`, unchanged.
        """
        # For Conditional GAN training, we don't typically call the model directly.
        # This is just a placeholder to satisfy Keras 3 requirements.
        return inputs

    @property
    def metrics(self) -> T.List:
        """
        :returns: A list of metrics trackers (e.g., generator's loss and discriminator's loss).
        :rtype: T.List
        """
        return [self.gen_loss_tracker, self.disc_loss_tracker]

    def compile(self, d_optimizer: keras.optimizers.Optimizer, g_optimizer: keras.optimizers.Optimizer, loss_fn: T.Callable) -> None:
        """
        Compiles the generator and discriminator models.

        :param d_optimizer: An optimizer for the GAN's discriminator.
        :type d_optimizer: keras.optimizers.Optimizer
        :param g_optimizer: An optimizer for the GAN's generator.
        :type g_optimizer: keras.optimizers.Optimizer
        :param loss_fn: Loss function.
        :type loss_fn: keras.losses.Loss
        """
        # TODO: move `.compile` logic to a base GAN class
        super().compile()
        self.d_optimizer = d_optimizer
        self.g_optimizer = g_optimizer
        self.loss_fn = loss_fn

        # Each flag is derived from the optimizer of the network it is named after.
        generator_dp = _is_dp_optimizer(g_optimizer)
        discriminator_dp = _is_dp_optimizer(d_optimizer)

        if generator_dp != discriminator_dp:
            logger.warning(f"One of the optimizers is DP and another one is not. generator_dp={generator_dp}, discriminator_dp={discriminator_dp}")

        # DP-specific updates in `train_step_tf` are used only when both optimizers are DP.
        self.dp = generator_dp and discriminator_dp

    def _get_random_vector_labels(self, batch_size: int, labels: tsgm.types.Tensor) -> tsgm.types.Tensor:
        """
        Samples noise and concatenates the labels to it along the last axis.

        :param batch_size: Number of noise vectors to sample.
        :type batch_size: int
        :param labels: Labels to condition on. In temporal mode, 2-D labels get a trailing
            axis appended, while labels of any other rank are concatenated as they are.
        :type labels: tsgm.types.Tensor

        :returns: Generator inputs: noise followed by labels along the last axis.
        :rtype: tsgm.types.Tensor
        """
        if self._temporal:
            random_latent_vectors = keras.random.normal(shape=(batch_size, self._seq_len, self.latent_dim))
            # 2-D labels need a trailing axis to match the 3-D noise; labels that already have it
            # (the case `_get_output_shape` reports as `labels.shape[2]`) are used unchanged.
            if len(labels.shape) == 2:
                labels = labels[:, :, None]
            return ops.concatenate([random_latent_vectors, labels], axis=2)

        random_latent_vectors = keras.random.normal(shape=(batch_size, self.latent_dim))
        return ops.concatenate([random_latent_vectors, labels], axis=1)

    def _get_output_shape(self, labels: tsgm.types.Tensor) -> int:
        """
        Returns the size of the label axis that is appended to the discriminator input.

        :param labels: A batch of labels.
        :type labels: tsgm.types.Tensor

        :returns: In temporal mode, 1 for 2-D labels and `labels.shape[2]` otherwise;
            in non-temporal mode, `labels.shape[1]`.
        :rtype: int
        """
        if not self._temporal:
            return labels.shape[1]
        if len(labels.shape) == 2:
            return 1
        return labels.shape[2]

    def _repeat_labels(self, labels: tsgm.types.Tensor, output_dim: int) -> tsgm.types.Tensor:
        """
        Reshapes the labels to `(-1, seq_len, output_dim)` so they can be appended to time series.

        :param labels: A batch of labels.
        :type labels: tsgm.types.Tensor
        :param output_dim: Size of the last axis of the result.
        :type output_dim: int

        :returns: Labels with shape `(-1, seq_len, output_dim)`.
        :rtype: tsgm.types.Tensor
        """
        if not self._temporal:
            rep_labels = labels[:, :, None]
            # NOTE(review): `ops.repeat` is called without `axis`, so it repeats the elements of the
            # flattened labels. With more than one label column the reshape below does not place
            # labels[b, j] at [b, t, j] -- confirm this layout is intended before changing it,
            # since trained discriminators depend on it.
            rep_labels = ops.repeat(rep_labels, repeats=[self._seq_len])
        else:
            rep_labels = labels
        return ops.reshape(rep_labels, (-1, self._seq_len, output_dim))

    def _discriminator_loss(self, batch_size, real_ts, labels, rep_labels):
        """
        Computes the discriminator loss on generated and real series, each with labels appended.

        :param batch_size: Number of samples in `real_ts`.
        :param real_ts: A batch of real time series.
        :param labels: Labels used to condition the generator.
        :param rep_labels: Labels as returned by `_repeat_labels`.

        :returns: `loss_fn` evaluated on the discriminator predictions.
        """
        # Generate ts
        random_vector_labels = self._get_random_vector_labels(batch_size=batch_size, labels=labels)
        generated_ts = self.generator(random_vector_labels)

        fake_data = ops.concatenate([generated_ts, rep_labels], -1)
        real_data = ops.concatenate([real_ts, rep_labels], -1)
        combined_data = ops.concatenate([fake_data, real_data], axis=0)

        # Labels for the discriminator follow the concatenation order above:
        # 1 == fake data
        # 0 == real data
        desc_labels = ops.concatenate(
            [ops.ones((batch_size, 1)), ops.zeros((batch_size, 1))], axis=0
        )
        predictions = self.discriminator(combined_data)
        return self.loss_fn(desc_labels, predictions)

    def _generator_loss(self, batch_size, random_vector_labels, rep_labels):
        """
        Computes the generator loss for one batch of generator inputs.

        :param batch_size: Number of samples in the batch.
        :param random_vector_labels: Generator inputs (noise with labels).
        :param rep_labels: Labels as returned by `_repeat_labels`.

        :returns: `loss_fn` evaluated against all-zero labels.
        """
        # Pretend that all samples are real (label 0)
        misleading_labels = ops.zeros((batch_size, 1))
        fake_samples = self.generator(random_vector_labels)
        fake_data = ops.concatenate([fake_samples, rep_labels], -1)
        predictions = self.discriminator(fake_data)
        return self.loss_fn(misleading_labels, predictions)

    @staticmethod
    def _apply_torch_gradients(torch, loss, model: keras.Model, optimizer: keras.optimizers.Optimizer) -> None:
        """
        Backpropagates `loss` and applies the resulting gradients to the trainable weights of `model`.

        :param torch: Backend module providing `no_grad`.
        :param loss: The loss tensor to backpropagate.
        :param model: The network whose trainable weights are updated.
        :type model: keras.Model
        :param optimizer: The optimizer that applies the gradients.
        :type optimizer: keras.optimizers.Optimizer
        """
        # Clear gradients accumulated by earlier backward passes.
        model.zero_grad()
        loss.backward()

        trainable_weights = list(model.trainable_weights)
        gradients = [v.value.grad for v in trainable_weights]

        with torch.no_grad():
            # Keras 3 expects (gradient, variable) pairs
            optimizer.apply_gradients(list(zip(gradients, trainable_weights)))

    def train_step_tf(self, tf, data: T.Tuple) -> T.Dict[str, float]:
        """
        Performs a training step with gradient tapes.

        :param tf: Backend module providing `GradientTape`.
        :param data: A pair whose first item is a batch of time series and second item is the labels.
        :type data: T.Tuple

        :returns: A dictionary with generator (key "g_loss") and discriminator (key "d_loss") losses
        :rtype: T.Dict[str, float]
        """
        real_ts = data[0]
        labels = data[1]
        output_dim = self._get_output_shape(labels)
        batch_size = ops.shape(real_ts)[0]
        rep_labels = self._repeat_labels(labels, output_dim)

        # Generate ts
        random_vector_labels = self._get_random_vector_labels(batch_size=batch_size, labels=labels)
        generated_ts = self.generator(random_vector_labels)

        fake_data = ops.concatenate([generated_ts, rep_labels], -1)
        real_data = ops.concatenate([real_ts, rep_labels], -1)
        combined_data = ops.concatenate([fake_data, real_data], axis=0)

        # Labels for the discriminator follow the concatenation order above:
        # 1 == fake data
        # 0 == real data
        desc_labels = ops.concatenate(
            [ops.ones((batch_size, 1)), ops.zeros((batch_size, 1))], axis=0
        )

        # Train discriminator
        with tf.GradientTape() as tape:
            predictions = self.discriminator(combined_data)
            d_loss = self.loss_fn(desc_labels, predictions)
        if self.dp:
            # For DP optimizers from `tensorflow.privacy`
            self.d_optimizer.minimize(d_loss, self.discriminator.trainable_weights, tape=tape)
        else:
            grads = tape.gradient(d_loss, self.discriminator.trainable_weights)
            self.d_optimizer.apply_gradients(
                zip(grads, self.discriminator.trainable_weights)
            )

        random_vector_labels = self._get_random_vector_labels(batch_size=batch_size, labels=labels)

        # Train generator (without updating the discriminator)
        with tf.GradientTape() as tape:
            g_loss = self._generator_loss(batch_size, random_vector_labels, rep_labels)
        if self.dp:
            # For DP optimizers from `tensorflow.privacy`
            self.g_optimizer.minimize(g_loss, self.generator.trainable_weights, tape=tape)
        else:
            grads = tape.gradient(g_loss, self.generator.trainable_weights)
            self.g_optimizer.apply_gradients(zip(grads, self.generator.trainable_weights))

        self.gen_loss_tracker.update_state(g_loss)
        self.disc_loss_tracker.update_state(d_loss)
        return {
            "g_loss": self.gen_loss_tracker.result(),
            "d_loss": self.disc_loss_tracker.result(),
        }

    def train_step_torch(self, torch, data: T.Tuple) -> T.Dict[str, float]:
        """
        Performs a training step with explicit `backward()` calls.

        :param torch: Backend module providing `no_grad`.
        :param data: A two-element list or tuple: a batch of time series and the labels.
        :type data: T.Tuple

        :returns: A dictionary with generator (key "g_loss") and discriminator (key "d_loss") losses
        :rtype: T.Dict[str, float]

        :raises ValueError: If `data` is not a two-element list or tuple.
        """
        # Handle PyTorch DataLoader format
        if isinstance(data, (list, tuple)) and len(data) == 2:
            real_ts, labels = data
        else:
            # Labels are mandatory here; continuing with `labels=None` would only fail later
            # with an unrelated-looking error.
            raise ValueError("ConditionalGAN expects training data as a (time_series, labels) pair.")

        output_dim = self._get_output_shape(labels)
        batch_size = ops.shape(real_ts)[0]
        rep_labels = self._repeat_labels(labels, output_dim)

        # Train discriminator
        d_loss = self._discriminator_loss(batch_size, real_ts, labels, rep_labels)
        self._apply_torch_gradients(torch, d_loss, self.discriminator, self.d_optimizer)

        # Train generator (without updating the discriminator)
        random_vector_labels = self._get_random_vector_labels(batch_size=batch_size, labels=labels)
        g_loss = self._generator_loss(batch_size, random_vector_labels, rep_labels)
        self._apply_torch_gradients(torch, g_loss, self.generator, self.g_optimizer)

        self.gen_loss_tracker.update_state(g_loss)
        self.disc_loss_tracker.update_state(d_loss)
        return {
            "g_loss": self.gen_loss_tracker.result(),
            "d_loss": self.disc_loss_tracker.result(),
        }

    def train_step(self, data: T.Tuple) -> T.Dict[str, float]:
        """
        Performs a training step using a batch of data, stored in data.

        :param data: A pair of a batch of time series (batch_size x seq_len x feat_dim) and its labels
        :type data: T.Tuple

        :returns: A dictionary with generator (key "g_loss") and discriminator (key "d_loss") losses
        :rtype: T.Dict[str, float]

        :raises ValueError: If the `KERAS_BACKEND` environment variable is neither "tensorflow" nor "torch".
        """
        backend = get_backend()
        backend_name = os.environ.get("KERAS_BACKEND")
        if backend_name == "tensorflow":
            return self.train_step_tf(backend, data)
        elif backend_name == "torch":
            return self.train_step_torch(backend, data)
        else:
            raise ValueError(f"Unsupported backend: {os.environ.get('KERAS_BACKEND')}")

    def generate(self, labels: tsgm.types.Tensor) -> tsgm.types.Tensor:
        """
        Generates new data from the model.

        :param labels: The labels for which to generate samples.
        :type labels: tsgm.types.Tensor

        :returns: Generated samples.
        :rtype: tsgm.types.Tensor
        """
        batch_size = labels.shape[0]
        random_vector_labels = self._get_random_vector_labels(
            batch_size=batch_size, labels=labels)
        return self.generator(random_vector_labels)