Source code for tsgm.models.monitors

import logging
import os
import sys
import typing as T

import keras
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
from keras import ops

import tsgm.types
import tsgm.utils


# Module-level logger for the callbacks in this module. Its level is set to
# DEBUG when the module is imported.
logger = logging.getLogger(name="monitors")
logger.setLevel("DEBUG")


def _to_numpy(tensor):
    """Convert tensor to numpy array safely across backends."""
    if os.environ.get("KERAS_BACKEND") == "torch":
        try:
            import torch
            if isinstance(tensor, torch.Tensor):
                return tensor.detach().cpu().numpy()
        except ImportError:
            pass
    elif hasattr(tensor, 'numpy'):
        try:
            return tensor.numpy()
        except TypeError:
            # Handle cases where .numpy() might fail (e.g., MPS tensors)
            if hasattr(tensor, 'cpu'):
                return tensor.cpu().numpy()
    return np.asarray(tensor)


class GANMonitor(keras.callbacks.Callback):
    """
    GANMonitor is a Keras callback for monitoring and visualizing generated samples during training.

    At the end of every epoch it draws random normal latent vectors and
    concatenates them with the first ``num_samples`` rows of ``labels``. It
    feeds the result to ``self.model.generator``. Each generated sample is then
    plotted with ``tsgm.utils.visualize_ts_lineplot`` and either saved or shown.

    :param num_samples: The number of samples to generate and visualize.
    :type num_samples: int
    :param latent_dim: The dimensionality of the latent space.
    :type latent_dim: int
    :param labels: The labels for conditional generation. Only the first ``num_samples`` rows are used.
    :type labels: tsgm.types.Tensor
    :param save: Whether to save the generated samples. Defaults to True.
    :type save: bool
    :param save_path: The path to save the generated samples. Defaults to None.
    :type save_path: str
    :param mode: The generation mode, one of 'clf' or 'reg'. Defaults to 'clf'.
    :type mode: str

    :raises ValueError: If the mode is not one of ['clf', 'reg']

    :note: If `save` is True and `save_path` is not specified, the default save path is "/tmp/".

    :warning: If `save_path` is specified but `save` is False, a warning is issued.
    """

    def __init__(self, num_samples: int, latent_dim: int, labels: tsgm.types.Tensor,
                 save: bool = True, save_path: T.Optional[str] = None, mode: str = "clf") -> None:
        # Initialise the Keras Callback base class state.
        super().__init__()
        self._num_samples = num_samples
        self._latent_dim = latent_dim
        self._save = save
        self._save_path = save_path
        self._mode = mode
        if self._mode not in ["clf", "reg"]:
            raise ValueError("The mode should be in ['clf', 'reg']")
        self._labels = labels
        if self._save and self._save_path is None:
            self._save_path = "/tmp/"
            logger.warning("save_path is not specified. Using `/tmp` as the default save_path")
        if self._save_path is not None:
            if self._save is False:
                logger.warning("save_path is specified, but save is False.")
            os.makedirs(self._save_path, exist_ok=True)

    def on_epoch_end(self, epoch: int, logs: T.Optional[T.Dict] = None) -> None:
        """
        Generate, plot and save (or show) conditional samples at the end of an epoch.

        :param epoch: Current epoch number; used in the saved file names.
        :type epoch: int
        :param logs: Dictionary containing the training loss values (not used).
        :type logs: dict

        :raises NotImplementedError: If ``mode`` has been changed to 'temporal'.
        :raises ValueError: If ``mode`` has been changed to any other unsupported value.
        """
        if self._mode == "temporal":
            raise NotImplementedError
        if self._mode not in ["clf", "reg"]:
            raise ValueError(f"Invalid `mode` in GANMonitor: {self._mode}")

        # Cast labels to the float dtype that keras.random.normal produces by
        # default. This keeps the concatenation from mixing dtypes, e.g. float64
        # labels, which some devices such as MPS do not support.
        labels = ops.cast(self._labels[:self._num_samples], keras.config.floatx())
        # Size the latent batch from the labels actually available, so both
        # concatenation operands always have the same number of rows.
        batch_size = labels.shape[0]
        random_latent_vectors = keras.random.normal(shape=(batch_size, self._latent_dim))
        generator_input = ops.concatenate([random_latent_vectors, labels], 1)
        generated_samples = _to_numpy(self.model.generator(generator_input))

        # Convert once, outside the per-sample loop.
        labels_np = _to_numpy(labels)
        for i in range(generated_samples.shape[0]):
            label = np.argmax(labels_np[i][None, :], axis=1)
            tsgm.utils.visualize_ts_lineplot(
                generated_samples[i][None, :], label, 1)  # TODO: update visualize_ts API
            if self._save:
                plt.savefig(os.path.join(self._save_path, f"epoch_{epoch}_sample_{i}"))
            else:
                plt.show()
            # Close the current figure so figures do not pile up (or get drawn
            # over) across samples and epochs.
            plt.close()
class VAEMonitor(keras.callbacks.Callback):
    """
    VAEMonitor is a Keras callback for monitoring and visualizing generated samples from a
    Variational Autoencoder (VAE) during training.

    At the end of every epoch it builds one-hot labels for every class, each
    repeated ``num_samples`` times. It calls ``self.model.generate`` on them and
    draws each generated sample as a line plot in its own figure. The figure is
    then saved or shown.

    :param num_samples: The number of samples to generate and visualize per class. Defaults to 6.
    :type num_samples: int
    :param latent_dim: The dimensionality of the latent space. Defaults to 128.
        It is stored on the instance but not used by this callback.
    :type latent_dim: int
    :param output_dim: The number of classes, i.e. the dimensionality of the one-hot labels. Defaults to 2.
    :type output_dim: int
    :param save: Whether to save the generated samples. Defaults to True.
    :type save: bool
    :param save_path: The path to save the generated samples. Defaults to None.
    :type save_path: str

    :raises ValueError: If `output_dim` is less than or equal to 0.

    :note: If `save` is True and `save_path` is not specified, the default save path is "/tmp/".

    :warning: If `save_path` is specified but `save` is False, a warning is issued.
    """

    def __init__(self, num_samples: int = 6, latent_dim: int = 128, output_dim: int = 2,
                 save: bool = True, save_path: T.Optional[str] = None) -> None:
        # Initialise the Keras Callback base class state.
        super().__init__()
        if output_dim <= 0:
            raise ValueError(f"output_dim should be positive, got {output_dim}")
        self._num_samples = num_samples
        self._latent_dim = latent_dim
        self._output_dim = output_dim
        self._save = save
        self._save_path = save_path
        if self._save and self._save_path is None:
            self._save_path = "/tmp/"
            logger.warning("save_path is not specified. Using `/tmp` as the default save_path")
        if self._save_path is not None:
            if self._save is False:
                logger.warning("save_path is specified, but save is False.")
            os.makedirs(self._save_path, exist_ok=True)

    def on_epoch_end(self, epoch: int, logs: T.Optional[T.Dict] = None) -> None:
        """
        Generate, plot and save (or show) per-class samples at the end of an epoch.

        :param epoch: The current epoch number; used in the saved file names.
        :type epoch: int
        :param logs: Dictionary containing the training loss values (not used).
        :type logs: dict
        """
        # The rows of the identity matrix are the one-hot labels for classes
        # 0..output_dim-1. Each row is repeated num_samples times consecutively:
        # [c0, c0, ..., c1, c1, ...]. float32 is used for MPS compatibility.
        labels = ops.repeat(ops.eye(self._output_dim, dtype="float32"), self._num_samples, axis=0)
        generated_images, _ = self.model.generate(labels)
        generated_images = _to_numpy(generated_images)

        for i in range(self._output_dim * self._num_samples):
            sample = generated_images[i]
            # Use a fresh figure per sample. Otherwise sns.lineplot keeps drawing
            # onto the same axes and every saved image contains all earlier lines.
            fig, ax = plt.subplots()
            sns.lineplot(x=range(0, sample.shape[0]), y=np.squeeze(sample), ax=ax)
            if self._save:
                fig.savefig(os.path.join(self._save_path, f"epoch_{epoch}_sample_{i}"))
            else:
                plt.show()
            plt.close(fig)