Source code for tsgm.models.cvae

import os
import keras
from keras import ops
import typing as T

from tsgm.backend import get_backend

import tsgm.utils


class BetaVAE(keras.Model):
    """
    beta-VAE implementation for unlabeled time series.
    """
    def __init__(self, encoder: keras.Model, decoder: keras.Model, beta: float = 1.0, **kwargs) -> None:
        """
        :param encoder: An encoder model which takes a time series as input.
        :type encoder: keras.Model
        :param decoder: Takes as input a random noise vector and returns a simulated time-series.
        :type decoder: keras.Model
        :param beta: The weight of the KL divergence term. Default is 1.0.
        :type beta: float
        """
        super().__init__(**kwargs)
        self.beta = beta
        self.encoder = encoder
        self.decoder = decoder
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")
        # Both sizes are read from the decoder's static shapes (dimension 1).
        self._seq_len = self.decoder.output_shape[1]
        self.latent_dim = self.decoder.input_shape[1]

    @property
    def metrics(self) -> T.List:
        """
        :returns: A list of metrics trackers (total loss, reconstruction loss, and KL loss).
        """
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def call(self, X: tsgm.types.Tensor) -> tsgm.types.Tensor:
        """
        Encodes and decodes time series dataset X.

        The decoder is fed the encoder's first output (``z_mean``); the other
        two encoder outputs are ignored here.

        :param X: The input time series tensor.
        :type X: tsgm.types.Tensor

        :returns: Generated samples
        :rtype: tsgm.types.Tensor
        """
        z_mean, _, _ = self.encoder(X)
        x_decoded = self.decoder(z_mean)
        if len(x_decoded.shape) == 1:
            # Use keras.ops so the reshape works on every backend; not every
            # backend tensor type exposes a ``.reshape`` method.
            x_decoded = ops.reshape(x_decoded, (1, -1))
        return x_decoded

    def _get_reconstruction_loss(self, X: tsgm.types.Tensor, Xr: tsgm.types.Tensor) -> float:
        """
        Sums ``tsgm.utils.reconstruction_loss_by_axis`` over axes 0, 1 and 2.

        :param X: The original batch.
        :type X: tsgm.types.Tensor
        :param Xr: The reconstructed batch.
        :type Xr: tsgm.types.Tensor

        :returns: The combined reconstruction loss.
        """
        reconst_loss = (
            tsgm.utils.reconstruction_loss_by_axis(X, Xr, axis=0)
            + tsgm.utils.reconstruction_loss_by_axis(X, Xr, axis=1)
            + tsgm.utils.reconstruction_loss_by_axis(X, Xr, axis=2)
        )
        return reconst_loss

    def _vae_losses(self, data: tsgm.types.Tensor) -> T.Tuple:
        """
        Runs the encoder/decoder on a batch and computes the training losses.

        :param data: A batch of data.
        :type data: tsgm.types.Tensor

        :returns: A tuple ``(total_loss, reconstruction_loss, kl_loss)`` where
            ``total_loss = reconstruction_loss + beta * kl_loss``.
        :rtype: T.Tuple
        """
        z_mean, z_log_var, z = self.encoder(data)
        # The decoder consumes the encoder's third output (``z``).
        reconstruction = self.decoder(z)
        reconstruction_loss = self._get_reconstruction_loss(data, reconstruction)
        kl_loss = -0.5 * (1 + z_log_var - ops.square(z_mean) - ops.exp(z_log_var))
        kl_loss = ops.mean(ops.sum(kl_loss, axis=1))
        total_loss = reconstruction_loss + self.beta * kl_loss
        return total_loss, reconstruction_loss, kl_loss

    def _update_loss_trackers(self, total_loss, reconstruction_loss, kl_loss) -> T.Dict:
        """
        Feeds the three losses into their metric trackers.

        :returns: A dict with the running ``loss``, ``reconstruction_loss`` and ``kl_loss``.
        :rtype: T.Dict
        """
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

    def train_step_tf(self, tf, data: tsgm.types.Tensor) -> T.Dict:
        """
        Performs one training step with the TensorFlow backend.

        :param tf: The backend module passed in by ``train_step``; its ``GradientTape`` is used.
        :param data: A batch of data.
        :type data: tsgm.types.Tensor

        :returns: A dict with losses
        :rtype: T.Dict
        """
        # The forward pass must run inside the tape so gradients can be taken.
        with tf.GradientTape() as tape:
            total_loss, reconstruction_loss, kl_loss = self._vae_losses(data)
        grads = tape.gradient(total_loss, self.trainable_weights)
        # I am not sure if this should be self.optimizer.apply(grads, model.trainable_weights)
        # see https://keras.io/guides/writing_a_custom_training_loop_in_tensorflow/
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        return self._update_loss_trackers(total_loss, reconstruction_loss, kl_loss)

    def train_step_torch(self, torch, data: tsgm.types.Tensor) -> T.Dict:
        """
        Performs one training step with the PyTorch backend.

        :param torch: The backend module passed in by ``train_step``; its ``no_grad`` is used.
        :param data: A batch of data.
        :type data: tsgm.types.Tensor

        :returns: A dict with losses
        :rtype: T.Dict
        """
        total_loss, reconstruction_loss, kl_loss = self._vae_losses(data)

        # Ensure total_loss is a scalar for PyTorch backward()
        if hasattr(total_loss, 'shape') and len(total_loss.shape) > 0:
            total_loss = ops.mean(total_loss)

        # Clear gradients left over from the previous step before backprop.
        self.zero_grad()
        total_loss.backward()

        trainable_weights = list(self.trainable_weights)
        gradients = [v.value.grad for v in trainable_weights]

        with torch.no_grad():
            self.optimizer.apply(gradients, trainable_weights)

        return self._update_loss_trackers(total_loss, reconstruction_loss, kl_loss)

    def train_step_jax(self, jax, data: tsgm.types.Tensor) -> T.Dict:
        """
        Evaluates the losses on a batch and updates the metric trackers (JAX backend).

        :param jax: The backend module passed in by ``train_step`` (not used in this method).
        :param data: A batch of data.
        :type data: tsgm.types.Tensor

        :returns: A dict with losses
        :rtype: T.Dict
        """
        # NOTE(review): no gradients are computed and no optimizer update is
        # applied in this method, so it does not change the model weights.
        # TODO confirm whether a JAX gradient/optimizer update is intended here.
        total_loss, reconstruction_loss, kl_loss = self._vae_losses(data)
        return self._update_loss_trackers(total_loss, reconstruction_loss, kl_loss)

    def train_step(self, data: tsgm.types.Tensor) -> T.Dict:
        """
        Performs a training step using a batch of data, stored in data.

        :param data: A batch of data in a format batch_size x seq_len x feat_dim
        :type data: tsgm.types.Tensor

        :returns: A dict with losses
        :rtype: T.Dict

        :raises ValueError: If the active backend is not tensorflow, torch or jax.
        """
        backend = get_backend()
        # Prefer the environment variable (as before); fall back to the backend
        # Keras reports when the variable is unset, instead of silently
        # returning None.
        backend_name = os.environ.get("KERAS_BACKEND") or keras.backend.backend()
        if backend_name == "tensorflow":
            return self.train_step_tf(backend, data)
        elif backend_name == "torch":
            return self.train_step_torch(backend, data)
        elif backend_name == "jax":
            return self.train_step_jax(backend, data)
        raise ValueError(f"Unsupported Keras backend: {backend_name!r}")

    def generate(self, n: int) -> tsgm.types.Tensor:
        """
        Generates new data from the model.

        :param n: the number of samples to be generated.
        :type n: int

        :returns: A tensor with generated samples.
        :rtype: tsgm.types.Tensor
        """
        # keras 3.0 support
        z = keras.random.normal((n, self.latent_dim))
        return self.decoder(z)
class cBetaVAE(keras.Model):
    """
    Conditional beta-VAE: the encoder and decoder inputs are concatenated with labels.
    """
    def __init__(self, encoder: keras.Model, decoder: keras.Model, latent_dim: int, temporal: bool, beta: float = 1.0, **kwargs) -> None:
        """
        :param encoder: An encoder model; it receives the data concatenated with labels.
        :type encoder: keras.Model
        :param decoder: A decoder model; it receives the latent tensor concatenated with labels.
        :type decoder: keras.Model
        :param latent_dim: The size of the latent dimension.
        :type latent_dim: int
        :param temporal: If True, labels are indexed as ``labels[:, :, None]``
            (one label per time step); otherwise labels are repeated
            ``seq_len`` times along axis 1.
        :type temporal: bool
        :param beta: The weight of the KL divergence term. Default is 1.0.
        :type beta: float
        """
        super().__init__(**kwargs)
        self.beta = beta
        self.encoder = encoder
        self.decoder = decoder
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(
            name="reconstruction_loss"
        )
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")
        self._temporal = temporal
        # Sequence length is read from the decoder's static output shape.
        self._seq_len = self.decoder.output_shape[1]
        self.latent_dim = latent_dim

    @property
    def metrics(self) -> T.List:
        """
        Returns the list of loss tracker: `[loss, reconstruction_loss, kl_loss]`.
        """
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def generate(self, labels: tsgm.types.Tensor) -> T.Tuple[tsgm.types.Tensor, tsgm.types.Tensor]:
        """
        Generates new data from the model.

        :param labels: The labels for which to generate conditional samples.
        :type labels: tsgm.types.Tensor

        :returns: A tuple of synthetically generated data and labels.
        :rtype: T.Tuple[tsgm.types.Tensor, tsgm.types.Tensor]
        """
        # keras 3.0 support
        batch_size = ops.shape(labels)[0]
        # Under torch the noise is always float32; otherwise it follows the labels' dtype.
        dtype = 'float32' if os.environ.get("KERAS_BACKEND") == "torch" else labels.dtype
        z = keras.random.normal((batch_size, self._seq_len, self.latent_dim), dtype=dtype)
        decoder_input = self._get_decoder_input(z, labels)
        return (self.decoder(decoder_input), labels)

    def call(self, data: tsgm.types.Tensor) -> tsgm.types.Tensor:
        """
        Encodes and decodes time series dataset.

        :param data: The input data, either a tensor or a tuple of (X, labels).
        :type data: tsgm.types.Tensor

        :returns: Generated samples.
        :rtype: tsgm.types.Tensor
        """
        # Handle both single tensor and tuple of (X, labels)
        if not (isinstance(data, (list, tuple)) and len(data) == 2):
            # During model building, just return the input
            return data
        X, labels = data

        encoder_input = self._get_encoder_input(X, labels)
        z_mean, _, _ = self.encoder(encoder_input)
        decoder_input = self._get_decoder_input(z_mean, labels)
        x_decoded = self.decoder(decoder_input)
        if len(x_decoded.shape) == 1:
            # Use keras.ops so the reshape works on every backend; not every
            # backend tensor type exposes a ``.reshape`` method.
            x_decoded = ops.reshape(x_decoded, (1, -1))
        return x_decoded

    def _get_reconstruction_loss(self, X: tsgm.types.Tensor, Xr: tsgm.types.Tensor) -> float:
        """
        Sum of squared errors between X and Xr, plus the squared errors of
        their means taken over axis 1 and over axis 2.
        """
        # keras 3.0 support
        elementwise = ops.sum(ops.square(X - Xr))
        axis1_means = ops.sum(ops.square(ops.mean(X, axis=1) - ops.mean(Xr, axis=1)))
        axis2_means = ops.sum(ops.square(ops.mean(X, axis=2) - ops.mean(Xr, axis=2)))
        reconst_loss = elementwise + axis1_means + axis2_means
        return reconst_loss

    def _get_encoder_input(self, X: tsgm.types.Tensor, labels: tsgm.types.Tensor) -> tsgm.types.Tensor:
        """
        Concatenates X with the labels along axis 2 to form the encoder input.
        """
        # keras 3.0 support
        if os.environ.get("KERAS_BACKEND") == "torch" and hasattr(labels, 'dtype'):
            labels = ops.cast(labels, 'float32')
        if self._temporal:
            return ops.concatenate([X, labels[:, :, None]], axis=2)
        else:
            # Broadcast the per-sample labels across every time step.
            rep_labels = ops.repeat(labels[:, None, :], [self._seq_len], axis=1)
            return ops.concatenate([X, rep_labels], axis=2)

    def _get_decoder_input(self, z: tsgm.types.Tensor, labels: tsgm.types.Tensor) -> tsgm.types.Tensor:
        """
        Reshapes z to ``[-1, seq_len, latent_dim]`` and concatenates the labels along axis 2.
        """
        # keras 3.0 support
        if os.environ.get("KERAS_BACKEND") == "torch" and hasattr(labels, 'dtype'):
            labels = ops.cast(labels, 'float32')
        if self._temporal:
            rep_labels = labels[:, :, None]
        else:
            rep_labels = ops.repeat(labels[:, None, :], [self._seq_len], axis=1)
        z = ops.reshape(z, [-1, self._seq_len, self.latent_dim])
        return ops.concatenate([z, rep_labels], axis=2)

    def _vae_losses(self, X: tsgm.types.Tensor, labels: tsgm.types.Tensor, use_sampled_latent: bool) -> T.Tuple:
        """
        Runs the conditional encoder/decoder on a batch and computes the losses.

        :param X: A batch of data.
        :type X: tsgm.types.Tensor
        :param labels: The labels matching X.
        :type labels: tsgm.types.Tensor
        :param use_sampled_latent: If True the decoder is fed the encoder's
            third output (``z``); otherwise it is fed ``z_mean``.
        :type use_sampled_latent: bool

        :returns: A tuple ``(total_loss, reconstruction_loss, kl_loss)`` where
            ``total_loss = reconstruction_loss + beta * kl_loss``.
        :rtype: T.Tuple
        """
        encoder_input = self._get_encoder_input(X, labels)
        z_mean, z_log_var, z = self.encoder(encoder_input)
        latent = z if use_sampled_latent else z_mean
        decoder_input = self._get_decoder_input(latent, labels)
        reconstruction = self.decoder(decoder_input)
        reconstruction_loss = self._get_reconstruction_loss(X, reconstruction)
        kl_loss = -0.5 * (1 + z_log_var - ops.square(z_mean) - ops.exp(z_log_var))
        kl_loss = ops.mean(ops.sum(kl_loss, axis=1))
        total_loss = reconstruction_loss + self.beta * kl_loss
        return total_loss, reconstruction_loss, kl_loss

    def _update_loss_trackers(self, total_loss, reconstruction_loss, kl_loss) -> T.Dict[str, float]:
        """
        Feeds the three losses into their metric trackers.

        :returns: A dict with the running ``loss``, ``reconstruction_loss`` and ``kl_loss``.
        :rtype: T.Dict[str, float]
        """
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

    def train_step_tf(self, tf, data: tsgm.types.Tensor) -> T.Dict[str, float]:
        """
        Performs one training step with the TensorFlow backend.

        :param tf: The backend module passed in by ``train_step``; its ``GradientTape`` is used.
        :param data: A tuple ``(X, labels)``.
        :type data: tsgm.types.Tensor

        :returns: A dict with losses
        :rtype: T.Dict[str, float]
        """
        X, labels = data
        # The forward pass must run inside the tape so gradients can be taken.
        with tf.GradientTape() as tape:
            # NOTE(review): the decoder is fed z_mean here (and in the torch
            # step) while train_step_jax feeds the sampled z — confirm which
            # one is intended.
            total_loss, reconstruction_loss, kl_loss = self._vae_losses(
                X, labels, use_sampled_latent=False
            )
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        return self._update_loss_trackers(total_loss, reconstruction_loss, kl_loss)

    def train_step_torch(self, torch, data: tsgm.types.Tensor) -> T.Dict[str, float]:
        """
        Performs one training step with the PyTorch backend.

        :param torch: The backend module passed in by ``train_step``; its ``no_grad`` is used.
        :param data: A tuple ``(X, labels)``.
        :type data: tsgm.types.Tensor

        :returns: A dict with losses
        :rtype: T.Dict[str, float]
        """
        X, labels = data
        # NOTE(review): the decoder is fed z_mean here (and in the TF step)
        # while train_step_jax feeds the sampled z — confirm which one is
        # intended.
        total_loss, reconstruction_loss, kl_loss = self._vae_losses(
            X, labels, use_sampled_latent=False
        )

        # Ensure total_loss is a scalar for PyTorch backward()
        if hasattr(total_loss, 'shape') and len(total_loss.shape) > 0:
            total_loss = ops.mean(total_loss)

        # Clear gradients left over from the previous step before backprop.
        self.zero_grad()
        total_loss.backward()

        trainable_weights = list(self.trainable_weights)
        gradients = [v.value.grad for v in trainable_weights]

        with torch.no_grad():
            # Keras 3 expects (gradient, variable) pairs
            grads_and_vars = list(zip(gradients, trainable_weights))
            self.optimizer.apply_gradients(grads_and_vars)

        return self._update_loss_trackers(total_loss, reconstruction_loss, kl_loss)

    def train_step_jax(self, jax, data: tsgm.types.Tensor) -> T.Dict[str, float]:
        """
        Evaluates the losses on a batch and updates the metric trackers (JAX backend).

        :param jax: The backend module passed in by ``train_step`` (not used in this method).
        :param data: A tuple ``(X, labels)``.
        :type data: tsgm.types.Tensor

        :returns: A dict with losses
        :rtype: T.Dict[str, float]
        """
        X, labels = data
        # NOTE(review): no gradients are computed and no optimizer update is
        # applied in this method, so it does not change the model weights.
        # TODO confirm whether a JAX gradient/optimizer update is intended here.
        total_loss, reconstruction_loss, kl_loss = self._vae_losses(
            X, labels, use_sampled_latent=True
        )
        return self._update_loss_trackers(total_loss, reconstruction_loss, kl_loss)

    def train_step(self, data: tsgm.types.Tensor) -> T.Dict[str, float]:
        """
        Performs a training step using a batch of data, stored in data.

        :param data: A batch of data in a format batch_size x seq_len x feat_dim
        :type data: tsgm.types.Tensor

        :returns: A dict with losses
        :rtype: T.Dict[str, float]

        :raises ValueError: If the active backend is not tensorflow, torch or jax.
        """
        backend = get_backend()
        # Prefer the environment variable (as before); fall back to the backend
        # Keras reports when the variable is unset, instead of silently
        # returning None.
        backend_name = os.environ.get("KERAS_BACKEND") or keras.backend.backend()
        if backend_name == "tensorflow":
            return self.train_step_tf(backend, data)
        elif backend_name == "torch":
            return self.train_step_torch(backend, data)
        elif backend_name == "jax":
            return self.train_step_jax(backend, data)
        raise ValueError(f"Unsupported Keras backend: {backend_name!r}")