Source code for tsgm.models.monitors

import os
import logging
import tensorflow as tf
import numpy as np
from tensorflow import keras
import typing as T

import seaborn as sns
import matplotlib.pyplot as plt

import tsgm.types
import tsgm.utils


# Module-level logger shared by the monitor callbacks defined in this file.
# The logger is registered under the fixed name "monitors" and is set to emit
# everything from DEBUG upwards.
logger = logging.getLogger("monitors")
logger.setLevel(level=logging.DEBUG)


class GANMonitor(keras.callbacks.Callback):
    """
    GANMonitor is a Keras callback for monitoring and visualizing generated samples during training.

    At the end of every epoch it draws `num_samples` latent vectors, concatenates them with
    the first `num_samples` rows of `labels`, runs the result through `self.model.generator`
    and plots every generated sample.

    :param num_samples: The number of samples to generate and visualize.
    :type num_samples: int
    :param latent_dim: The dimensionality of the latent space.
    :type latent_dim: int
    :param labels: Conditioning labels. The first `num_samples` rows are concatenated to the
        latent vectors along axis 1 (presumably one-hot encoded, since the plotted label is
        obtained with `argmax` -- confirm against the caller).
    :type labels: tsgm.types.Tensor
    :param save: Whether to save the generated samples. Defaults to True.
    :type save: bool
    :param save_path: The path to save the generated samples. Defaults to None.
    :type save_path: str
    :param mode: Either "clf" or "reg". Defaults to "clf".
    :type mode: str

    :raises ValueError: If the mode is not one of ['clf', 'reg']

    :note: If `save` is True and `save_path` is not specified, the default save path is "/tmp/".
    :warning: If `save_path` is specified but `save` is False, a warning is issued.
    """
    def __init__(self, num_samples: int, latent_dim: int, labels: tsgm.types.Tensor,
                 save: bool = True, save_path: T.Optional[str] = None, mode: str = "clf") -> None:
        # Let the Keras base class set up its own bookkeeping attributes.
        super().__init__()

        if mode not in ("clf", "reg"):
            raise ValueError("The mode should be in ['clf', 'reg']")

        self._num_samples = num_samples
        self._latent_dim = latent_dim
        self._labels = labels
        self._save = save
        self._save_path = save_path
        self._mode = mode

        if self._save and self._save_path is None:
            self._save_path = "/tmp/"
            logger.warning("save_path is not specified. Using `/tmp` as the default save_path")

        if self._save_path is not None:
            if not self._save:
                logger.warning("save_path is specified, but save is False.")
            # The directory is created whenever a path is known, even if saving is disabled.
            os.makedirs(self._save_path, exist_ok=True)

    def on_epoch_end(self, epoch: int, logs: T.Optional[T.Dict] = None) -> None:
        """
        Callback function called at the end of each training epoch.

        Generates samples with `self.model.generator` and either saves one image per sample
        to `save_path` (named `epoch_{epoch}_sample_{i}`) or shows it.

        :param epoch: Current epoch number.
        :type epoch: int
        :param logs: Dictionary containing the training loss values.
        :type logs: dict

        :raises NotImplementedError: If the mode is "temporal".
        :raises ValueError: If the mode is not recognized.
        """
        if self._mode in ("clf", "reg"):
            random_latent_vectors = tf.random.normal(shape=(self._num_samples, self._latent_dim))
        elif self._mode == "temporal":
            # TODO: sample (output_dim * num_samples, latent_dim) latent vectors
            # once the temporal mode is supported.
            raise NotImplementedError("The 'temporal' mode is not supported by GANMonitor yet")
        else:
            raise ValueError(f"Invalid `mode` in GANMonitor: {self._mode}")

        labels = self._labels[:self._num_samples]
        generator_input = tf.concat([random_latent_vectors, labels], 1)
        generated_samples = self.model.generator(generator_input)

        for i in range(generated_samples.shape[0]):
            label = np.argmax(labels[i][None, :], axis=1)

            tsgm.utils.visualize_ts_lineplot(
                generated_samples[i][None, :],
                label, 1)  # TODO: update visualize_ts API
            if self._save:
                plt.savefig(os.path.join(self._save_path, "epoch_{}_sample_{}".format(epoch, i)))
            else:
                plt.show()
            # Release the current figure; otherwise one figure per sample per epoch
            # stays alive for the whole training run.
            plt.close()
class VAEMonitor(keras.callbacks.Callback):
    """
    VAEMonitor is a Keras callback for monitoring and visualizing generated samples
    from a Variational Autoencoder (VAE) during training.

    At the end of every epoch it asks `self.model.generate` for `num_samples` samples
    per class (one-hot labels over `output_dim` classes) and plots each of them.

    :param num_samples: The number of samples to generate and visualize per class. Defaults to 6.
    :type num_samples: int
    :param latent_dim: The dimensionality of the latent space. Defaults to 128.
    :type latent_dim: int
    :param output_dim: The dimensionality of the output space (number of one-hot classes). Defaults to 2.
    :type output_dim: int
    :param save: Whether to save the generated samples. Defaults to True.
    :type save: bool
    :param save_path: The path to save the generated samples. Defaults to None.
    :type save_path: str

    :raises ValueError: If `output_dim` is less than or equal to 0.

    :note: If `save` is True and `save_path` is not specified, the default save path is "/tmp/".
    :warning: If `save_path` is specified but `save` is False, a warning is issued.
    """
    def __init__(self, num_samples: int = 6, latent_dim: int = 128, output_dim: int = 2,
                 save: bool = True, save_path: T.Optional[str] = None) -> None:
        # Let the Keras base class set up its own bookkeeping attributes.
        super().__init__()

        # Enforce the documented contract: no labels can be built for a non-positive output_dim.
        if output_dim <= 0:
            raise ValueError("output_dim must be greater than 0, got {}".format(output_dim))

        self._num_samples = num_samples
        self._latent_dim = latent_dim
        self._output_dim = output_dim
        self._save = save
        self._save_path = save_path

        if self._save and self._save_path is None:
            self._save_path = "/tmp/"
            logger.warning("save_path is not specified. Using `/tmp` as the default save_path")

        if self._save_path is not None:
            if not self._save:
                logger.warning("save_path is specified, but save is False.")
            # The directory is created whenever a path is known, even if saving is disabled.
            os.makedirs(self._save_path, exist_ok=True)

    def on_epoch_end(self, epoch: int, logs: T.Optional[T.Dict] = None) -> None:
        """
        Callback function called at the end of each training epoch.

        Generates `output_dim * num_samples` samples with `self.model.generate` and either
        saves one image per sample to `save_path` (named `epoch_{epoch}_sample_{i}`) or shows it.

        :param epoch: The current epoch number.
        :type epoch: int
        :param logs: Dictionary containing the training loss values.
        :type logs: dict
        """
        # One one-hot row per class (an identity matrix), each repeated num_samples times
        # so that consecutive rows share a class.
        labels = keras.utils.to_categorical(np.arange(self._output_dim), self._output_dim)
        labels = tf.repeat(labels, self._num_samples, axis=0)
        generated_images, _ = self.model.generate(labels)

        for i in range(self._output_dim * self._num_samples):
            sample = generated_images[i]
            # Draw every sample on its own figure: seaborn plots onto the current axes,
            # so reusing one figure would overlay all earlier samples in each saved image.
            fig = plt.figure()
            sns.lineplot(
                x=range(0, sample.shape[0]),
                y=tf.squeeze(sample).numpy()
            )
            if self._save:
                plt.savefig(os.path.join(self._save_path, "epoch_{}_sample_{}".format(epoch, i)))
            else:
                plt.show()
            # Free the figure so figures do not pile up over the training run.
            plt.close(fig)