import os
import keras
from keras import ops
import typing as T
from tsgm.backend import get_backend
import tsgm.utils
class BetaVAE(keras.Model):
    """
    beta-VAE implementation for unlabeled time series.

    Training minimizes ``reconstruction_loss + beta * kl_loss``. The encoder is called as
    ``z_mean, z_log_var, z = encoder(X)`` and the decoder maps a latent vector back to a
    time series.
    """

    def __init__(self, encoder: keras.Model, decoder: keras.Model, beta: float = 1.0, **kwargs) -> None:
        """
        :param encoder: An encoder model which takes a time series as input and returns
            ``(z_mean, z_log_var, z)``.
        :type encoder: keras.Model
        :param decoder: Takes as input a random noise vector and returns a simulated time-series.
        :type decoder: keras.Model
        :param beta: The weight of the KL divergence term. Default is 1.0.
        :type beta: float
        """
        super().__init__(**kwargs)
        self.beta = beta
        self.encoder = encoder
        self.decoder = decoder
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(name="reconstruction_loss")
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")
        # Sequence length and latent size are read off the decoder's static shapes.
        self._seq_len = self.decoder.output_shape[1]
        self.latent_dim = self.decoder.input_shape[1]

    @property
    def metrics(self) -> T.List:
        """
        :returns: A list of metrics trackers (total loss, reconstruction loss, and KL loss).
        """
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def call(self, X: tsgm.types.Tensor) -> tsgm.types.Tensor:
        """
        Encodes and decodes time series dataset X.

        The decoder is fed ``z_mean`` (the encoder's first output), not the sampled ``z``.

        :param X: The input time series tensor.
        :type X: tsgm.types.Tensor
        :returns: Generated samples
        :rtype: tsgm.types.Tensor
        """
        z_mean, _, _ = self.encoder(X)
        x_decoded = self.decoder(z_mean)
        if len(x_decoded.shape) == 1:
            # Promote a rank-1 result to a single-row batch. ``ops.reshape`` works on every
            # backend; TensorFlow tensors have no ``.reshape`` method by default.
            x_decoded = ops.reshape(x_decoded, (1, -1))
        return x_decoded

    def _get_reconstruction_loss(self, X: tsgm.types.Tensor, Xr: tsgm.types.Tensor) -> tsgm.types.Tensor:
        # Sum of the per-axis reconstruction losses along axes 0, 1 and 2.
        return (
            tsgm.utils.reconstruction_loss_by_axis(X, Xr, axis=0)
            + tsgm.utils.reconstruction_loss_by_axis(X, Xr, axis=1)
            + tsgm.utils.reconstruction_loss_by_axis(X, Xr, axis=2)
        )

    def _vae_losses(
        self, data: tsgm.types.Tensor
    ) -> T.Tuple[tsgm.types.Tensor, tsgm.types.Tensor, tsgm.types.Tensor]:
        """
        Runs a forward pass on ``data`` and returns ``(total_loss, reconstruction_loss, kl_loss)``.

        Shared by every backend-specific train step so the objective is identical across backends.
        """
        z_mean, z_log_var, z = self.encoder(data)
        # Training decodes the sampled latent ``z`` returned by the encoder.
        reconstruction = self.decoder(z)
        reconstruction_loss = self._get_reconstruction_loss(data, reconstruction)
        # Gaussian KL term against a standard-normal prior (assuming z_log_var is a
        # log-variance): summed over axis 1, then averaged over the remaining entries.
        kl_loss = -0.5 * (1 + z_log_var - ops.square(z_mean) - ops.exp(z_log_var))
        kl_loss = ops.mean(ops.sum(kl_loss, axis=1))
        total_loss = reconstruction_loss + self.beta * kl_loss
        return total_loss, reconstruction_loss, kl_loss

    def _record_losses(self, total_loss, reconstruction_loss, kl_loss) -> T.Dict:
        """Updates the three loss trackers and returns their current results keyed by name."""
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

    def train_step_tf(self, tf, data: tsgm.types.Tensor) -> T.Dict:
        """
        TensorFlow training step: records the loss under a ``GradientTape`` and applies gradients.

        :param tf: The TensorFlow module.
        :param data: A batch of time series.
        :returns: A dict with losses.
        """
        with tf.GradientTape() as tape:
            total_loss, reconstruction_loss, kl_loss = self._vae_losses(data)
        grads = tape.gradient(total_loss, self.trainable_weights)
        # Keras 3 optimizers accept (gradient, variable) pairs through apply_gradients.
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        return self._record_losses(total_loss, reconstruction_loss, kl_loss)

    def train_step_torch(self, torch, data: tsgm.types.Tensor) -> T.Dict:
        """
        PyTorch training step: backpropagates the loss and applies the gradients with the optimizer.

        :param torch: The torch module.
        :param data: A batch of time series.
        :returns: A dict with losses.
        """
        total_loss, reconstruction_loss, kl_loss = self._vae_losses(data)
        # backward() needs a scalar; reduce defensively if the loss kept any dimensions.
        if hasattr(total_loss, "shape") and len(total_loss.shape) > 0:
            total_loss = ops.mean(total_loss)
        self.zero_grad()
        total_loss.backward()
        trainable_weights = list(self.trainable_weights)
        gradients = [v.value.grad for v in trainable_weights]
        with torch.no_grad():
            self.optimizer.apply(gradients, trainable_weights)
        return self._record_losses(total_loss, reconstruction_loss, kl_loss)

    def train_step_jax(self, jax, data: tsgm.types.Tensor) -> T.Dict:
        """
        JAX training step.

        NOTE(review): this computes and tracks the losses but computes no gradients and
        updates no weights. Keras 3's JAX trainer is documented to call a stateless
        ``train_step(state, data)``; confirm whether and how this path is reached in ``fit``.

        :param jax: The jax module (unused).
        :param data: A batch of time series.
        :returns: A dict with losses.
        """
        total_loss, reconstruction_loss, kl_loss = self._vae_losses(data)
        return self._record_losses(total_loss, reconstruction_loss, kl_loss)

    def train_step(self, data: tsgm.types.Tensor) -> T.Dict:
        """
        Performs a training step using a batch of data, stored in data.

        Dispatches on ``keras.backend.backend()``, i.e. the backend Keras actually loaded,
        rather than on the raw ``KERAS_BACKEND`` environment variable (which may be unset).

        :param data: A batch of data in a format batch_size x seq_len x feat_dim
        :type data: tsgm.types.Tensor
        :returns: A dict with losses
        :rtype: T.Dict
        :raises ValueError: If the active Keras backend is not tensorflow, torch or jax.
        """
        step_fns = {
            "tensorflow": self.train_step_tf,
            "torch": self.train_step_torch,
            "jax": self.train_step_jax,
        }
        backend_name = keras.backend.backend()
        step_fn = step_fns.get(backend_name)
        if step_fn is None:
            raise ValueError(f"Unsupported Keras backend: {backend_name!r}")
        return step_fn(get_backend(), data)

    def generate(self, n: int) -> tsgm.types.Tensor:
        """
        Generates new data from the model.

        :param n: the number of samples to be generated.
        :type n: int
        :returns: A tensor with generated samples.
        :rtype: tsgm.types.Tensor
        """
        # Sample latent codes from a standard normal and decode them.
        z = keras.random.normal((n, self.latent_dim))
        return self.decoder(z)
class cBetaVAE(keras.Model):
    """
    Conditional beta-VAE for labeled time series.

    Labels are concatenated to both the encoder input and the decoder input along axis 2.
    With ``temporal=True`` there is one label value per time step (``labels`` is indexed as
    ``labels[:, :, None]``). Otherwise a per-sample label vector is repeated ``seq_len``
    times along axis 1.
    """

    def __init__(self, encoder: keras.Model, decoder: keras.Model, latent_dim: int, temporal: bool, beta: float = 1.0, **kwargs) -> None:
        """
        :param encoder: Encoder taking time series concatenated with labels and returning
            ``(z_mean, z_log_var, z)``.
        :type encoder: keras.Model
        :param decoder: Decoder taking latent codes concatenated with labels.
        :type decoder: keras.Model
        :param latent_dim: Latent size per time step; latent codes are reshaped to
            ``(-1, seq_len, latent_dim)`` before decoding.
        :type latent_dim: int
        :param temporal: True for per-time-step labels, False for per-sample labels.
        :type temporal: bool
        :param beta: The weight of the KL divergence term. Default is 1.0.
        :type beta: float
        """
        super().__init__(**kwargs)
        self.beta = beta
        self.encoder = encoder
        self.decoder = decoder
        self.total_loss_tracker = keras.metrics.Mean(name="total_loss")
        self.reconstruction_loss_tracker = keras.metrics.Mean(name="reconstruction_loss")
        self.kl_loss_tracker = keras.metrics.Mean(name="kl_loss")
        self._temporal = temporal
        # Sequence length is read off the decoder's static output shape.
        self._seq_len = self.decoder.output_shape[1]
        self.latent_dim = latent_dim

    @property
    def metrics(self) -> T.List:
        """
        Returns the list of loss tracker: `[loss, reconstruction_loss, kl_loss]`.
        """
        return [
            self.total_loss_tracker,
            self.reconstruction_loss_tracker,
            self.kl_loss_tracker,
        ]

    def generate(self, labels: tsgm.types.Tensor) -> T.Tuple[tsgm.types.Tensor, tsgm.types.Tensor]:
        """
        Generates new data from the model.

        :param labels: The labels for which to generate conditional samples.
        :type labels: tsgm.types.Tensor
        :returns: A tuple of synthetically generated data and labels.
        :rtype: T.Tuple[tsgm.types.Tensor, tsgm.types.Tensor]
        """
        batch_size = ops.shape(labels)[0]
        # Under torch the noise is sampled as float32; otherwise it follows the labels' dtype.
        dtype = "float32" if keras.backend.backend() == "torch" else labels.dtype
        z = keras.random.normal((batch_size, self._seq_len, self.latent_dim), dtype=dtype)
        decoder_input = self._get_decoder_input(z, labels)
        return (self.decoder(decoder_input), labels)

    def call(self, data: tsgm.types.Tensor) -> tsgm.types.Tensor:
        """
        Encodes and decodes time series dataset.

        The decoder is fed ``z_mean`` (the encoder's first output), not the sampled ``z``.

        :param data: The input data, either a tensor or a tuple of (X, labels).
        :type data: tsgm.types.Tensor
        :returns: Generated samples, or ``data`` unchanged when it is not an ``(X, labels)`` pair.
        :rtype: tsgm.types.Tensor
        """
        if not (isinstance(data, (list, tuple)) and len(data) == 2):
            # During model building, just return the input.
            return data
        X, labels = data
        encoder_input = self._get_encoder_input(X, labels)
        z_mean, _, _ = self.encoder(encoder_input)
        decoder_input = self._get_decoder_input(z_mean, labels)
        x_decoded = self.decoder(decoder_input)
        if len(x_decoded.shape) == 1:
            # Promote a rank-1 result to a single-row batch. ``ops.reshape`` works on every
            # backend; TensorFlow tensors have no ``.reshape`` method by default.
            x_decoded = ops.reshape(x_decoded, (1, -1))
        return x_decoded

    def _get_reconstruction_loss(self, X: tsgm.types.Tensor, Xr: tsgm.types.Tensor) -> tsgm.types.Tensor:
        # Elementwise squared error, plus squared error of the means over axis 1 and over axis 2.
        return (
            ops.sum(ops.square(X - Xr))
            + ops.sum(ops.square(ops.mean(X, axis=1) - ops.mean(Xr, axis=1)))
            + ops.sum(ops.square(ops.mean(X, axis=2) - ops.mean(Xr, axis=2)))
        )

    @staticmethod
    def _prepare_labels(labels: tsgm.types.Tensor) -> tsgm.types.Tensor:
        """Casts tensor-like labels to float32 under the torch backend; otherwise returns them unchanged."""
        # NOTE(review): presumably needed so labels match float32 features on concatenation.
        if keras.backend.backend() == "torch" and hasattr(labels, "dtype"):
            labels = ops.cast(labels, "float32")
        return labels

    def _labels_along_time(self, labels: tsgm.types.Tensor) -> tsgm.types.Tensor:
        """Returns labels as a rank-3 tensor aligned on axis 1 with the time axis, ready to concatenate on axis 2."""
        if self._temporal:
            return labels[:, :, None]
        return ops.repeat(labels[:, None, :], [self._seq_len], axis=1)

    def _get_encoder_input(self, X: tsgm.types.Tensor, labels: tsgm.types.Tensor) -> tsgm.types.Tensor:
        labels = self._prepare_labels(labels)
        return ops.concatenate([X, self._labels_along_time(labels)], axis=2)

    def _get_decoder_input(self, z: tsgm.types.Tensor, labels: tsgm.types.Tensor) -> tsgm.types.Tensor:
        labels = self._prepare_labels(labels)
        rep_labels = self._labels_along_time(labels)
        z = ops.reshape(z, [-1, self._seq_len, self.latent_dim])
        return ops.concatenate([z, rep_labels], axis=2)

    def _vae_losses(
        self, X: tsgm.types.Tensor, labels: tsgm.types.Tensor
    ) -> T.Tuple[tsgm.types.Tensor, tsgm.types.Tensor, tsgm.types.Tensor]:
        """
        Runs a conditional forward pass and returns ``(total_loss, reconstruction_loss, kl_loss)``.

        Shared by every backend-specific train step so the objective is identical across backends.
        """
        encoder_input = self._get_encoder_input(X, labels)
        z_mean, z_log_var, z = self.encoder(encoder_input)
        # Decode the sampled latent ``z`` (previously the TF/torch paths used z_mean while
        # the JAX path used z; BetaVAE also decodes z during training).
        decoder_input = self._get_decoder_input(z, labels)
        reconstruction = self.decoder(decoder_input)
        reconstruction_loss = self._get_reconstruction_loss(X, reconstruction)
        # Gaussian KL term against a standard-normal prior (assuming z_log_var is a
        # log-variance): summed over axis 1, then averaged over the remaining entries.
        kl_loss = -0.5 * (1 + z_log_var - ops.square(z_mean) - ops.exp(z_log_var))
        kl_loss = ops.mean(ops.sum(kl_loss, axis=1))
        total_loss = reconstruction_loss + self.beta * kl_loss
        return total_loss, reconstruction_loss, kl_loss

    def _record_losses(self, total_loss, reconstruction_loss, kl_loss) -> T.Dict[str, float]:
        """Updates the three loss trackers and returns their current results keyed by name."""
        self.total_loss_tracker.update_state(total_loss)
        self.reconstruction_loss_tracker.update_state(reconstruction_loss)
        self.kl_loss_tracker.update_state(kl_loss)
        return {
            "loss": self.total_loss_tracker.result(),
            "reconstruction_loss": self.reconstruction_loss_tracker.result(),
            "kl_loss": self.kl_loss_tracker.result(),
        }

    def train_step_tf(self, tf, data: tsgm.types.Tensor) -> T.Dict[str, float]:
        """
        TensorFlow training step on an ``(X, labels)`` batch.

        :param tf: The TensorFlow module.
        :param data: A tuple ``(X, labels)``.
        :returns: A dict with losses.
        """
        X, labels = data
        with tf.GradientTape() as tape:
            total_loss, reconstruction_loss, kl_loss = self._vae_losses(X, labels)
        grads = tape.gradient(total_loss, self.trainable_weights)
        self.optimizer.apply_gradients(zip(grads, self.trainable_weights))
        return self._record_losses(total_loss, reconstruction_loss, kl_loss)

    def train_step_torch(self, torch, data: tsgm.types.Tensor) -> T.Dict[str, float]:
        """
        PyTorch training step on an ``(X, labels)`` batch.

        :param torch: The torch module.
        :param data: A tuple ``(X, labels)``.
        :returns: A dict with losses.
        """
        X, labels = data
        total_loss, reconstruction_loss, kl_loss = self._vae_losses(X, labels)
        # backward() needs a scalar; reduce defensively if the loss kept any dimensions.
        if hasattr(total_loss, "shape") and len(total_loss.shape) > 0:
            total_loss = ops.mean(total_loss)
        self.zero_grad()
        total_loss.backward()
        trainable_weights = list(self.trainable_weights)
        gradients = [v.value.grad for v in trainable_weights]
        with torch.no_grad():
            self.optimizer.apply(gradients, trainable_weights)
        return self._record_losses(total_loss, reconstruction_loss, kl_loss)

    def train_step_jax(self, jax, data: tsgm.types.Tensor) -> T.Dict[str, float]:
        """
        JAX training step on an ``(X, labels)`` batch.

        NOTE(review): this computes and tracks the losses but computes no gradients and
        updates no weights. Keras 3's JAX trainer is documented to call a stateless
        ``train_step(state, data)``; confirm whether and how this path is reached in ``fit``.

        :param jax: The jax module (unused).
        :param data: A tuple ``(X, labels)``.
        :returns: A dict with losses.
        """
        X, labels = data
        total_loss, reconstruction_loss, kl_loss = self._vae_losses(X, labels)
        return self._record_losses(total_loss, reconstruction_loss, kl_loss)

    def train_step(self, data: tsgm.types.Tensor) -> T.Dict[str, float]:
        """
        Performs a training step using a batch of data, stored in data.

        Dispatches on ``keras.backend.backend()``, i.e. the backend Keras actually loaded,
        rather than on the raw ``KERAS_BACKEND`` environment variable (which may be unset).

        :param data: A tuple ``(X, labels)``, X in a format batch_size x seq_len x feat_dim
        :type data: tsgm.types.Tensor
        :returns: A dict with losses
        :rtype: T.Dict[str, float]
        :raises ValueError: If the active Keras backend is not tensorflow, torch or jax.
        """
        step_fns = {
            "tensorflow": self.train_step_tf,
            "torch": self.train_step_torch,
            "jax": self.train_step_jax,
        }
        backend_name = keras.backend.backend()
        step_fn = step_fns.get(backend_name)
        if step_fn is None:
            raise ValueError(f"Unsupported Keras backend: {backend_name!r}")
        return step_fn(get_backend(), data)