Source code for gans.gan

import os
import typing

import matplotlib.pyplot as plt
import numpy as np
import torch
from networks.critic import Critic
from networks.generator import Generator
from sc_dataset import get_loader
from sklearn.manifold import TSNE
from torch.optim.lr_scheduler import ExponentialLR
from torch.utils.data.dataloader import DataLoader
from torch.utils.tensorboard import SummaryWriter


class GAN:
    def __init__(
        self,
        genes_no: int,
        batch_size: int,
        latent_dim: int,
        gen_layers: typing.List[int],
        crit_layers: typing.List[int],
        device: typing.Optional[str] = "cuda" if torch.cuda.is_available() else "cpu",
        library_size: typing.Optional[int] = 20000,
    ) -> None:
        """
        Non-conditional single-cell RNA-seq GAN.

        Parameters
        ----------
        genes_no : int
            Number of genes in the dataset.
        batch_size : int
            Training batch size.
        latent_dim : int
            Dimension of the latent space from which the noise vector is sampled.
        gen_layers : typing.List[int]
            List of integers corresponding to the number of neurons of each
            generator layer.
        crit_layers : typing.List[int]
            List of integers corresponding to the number of neurons of each
            critic layer.
        device : typing.Optional[str], optional
            Specifies whether to run on 'cpu' or 'cuda'. Only 'cuda' is
            supported for training the GAN, but 'cpu' can be used for
            inference, by default "cuda" if torch.cuda.is_available() else "cpu".
        library_size : typing.Optional[int], optional
            Total number of counts per generated cell, by default 20000.
        """
        torch.cuda.empty_cache()
        self.genes_no = genes_no
        self.batch_size = batch_size
        self.latent_dim = latent_dim
        self.gen_layers = gen_layers
        self.critic_layers = crit_layers
        self.device = device
        self.library_size = library_size
        self.gen = None
        self.crit = None
        self._build_model()
        self.step = 0
        self.gen_opt = None
        self.crit_opt = None
        self.gen_lr_scheduler = None
        self.crit_lr_scheduler = None
    @staticmethod
    def _generate_noise(batch_size: int, latent_dim: int, device: str) -> torch.Tensor:
        """
        Creates a noise tensor of shape (batch_size, latent_dim).

        Parameters
        ----------
        batch_size : int
            The number of samples to generate (normally equal to the training
            batch size).
        latent_dim : int
            Dimension of the latent space to sample from.
        device : str
            The device type.

        Returns
        -------
        torch.Tensor
            A tensor filled with random numbers from the standard normal
            distribution.
        """
        return torch.randn(batch_size, latent_dim, device=device)
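
    # Shape check for the noise sampler above (values are random draws):
    #
    #   >>> GAN._generate_noise(4, 128, "cpu").shape
    #   torch.Size([4, 128])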
    @staticmethod
    def _set_exponential_lr(
        optimizer: torch.optim.Optimizer,
        alpha_0: float,
        alpha_final: float,
        max_steps: int,
    ) -> ExponentialLR:
        """
        Sets up an exponentially decaying learning rate scheduler to be used
        with the optimizer.

        Parameters
        ----------
        optimizer : torch.optim.Optimizer
            Optimizer for which to create an exponential learning rate scheduler.
        alpha_0 : float
            Initial learning rate.
        alpha_final : float
            Final learning rate.
        max_steps : int
            Total number of training steps. When current_step=max_steps,
            alpha_final will be set as the learning rate.

        Returns
        -------
        ExponentialLR
            Exponential learning rate scheduler. Call the step() function on
            this scheduler in the training loop.
        """
        # Find the decay rate of the exponential learning rate
        decay_rate = (alpha_final / alpha_0) ** (1 / max_steps)
        return ExponentialLR(optimizer=optimizer, gamma=decay_rate)
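
    # Worked example with hypothetical values: annealing from alpha_0=1e-4 to
    # alpha_final=1e-5 over max_steps=100_000 gives
    # gamma = (1e-5 / 1e-4) ** (1 / 100_000) ≈ 0.999977, so after 100_000
    # scheduler steps the learning rate is 1e-4 * gamma ** 100_000 = 1e-5.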
    @staticmethod
    def _critic_loss(
        crit_fake_pred: torch.Tensor,
        crit_real_pred: torch.Tensor,
        gp: torch.Tensor,
        c_lambda: float,
    ) -> torch.Tensor:
        """
        Compute the critic's loss given its scores on real and fake cells, the
        gradient penalty, and the gradient penalty regularization
        hyper-parameter.

        Parameters
        ----------
        crit_fake_pred : torch.Tensor
            Critic's score on fake cells.
        crit_real_pred : torch.Tensor
            Critic's score on real cells.
        gp : torch.Tensor
            Unweighted gradient penalty.
        c_lambda : float
            Regularization hyper-parameter to be used with the gradient
            penalty in the WGAN loss.

        Returns
        -------
        torch.Tensor
            Critic's loss for the current batch.
        """
        return torch.mean(crit_fake_pred) - torch.mean(crit_real_pred) + c_lambda * gp
    @staticmethod
    def _generator_loss(crit_fake_pred: torch.Tensor) -> torch.Tensor:
        """
        Compute the generator loss from the critic's score of the generated cells.

        Parameters
        ----------
        crit_fake_pred : torch.Tensor
            The critic's score on fake generated cells.

        Returns
        -------
        torch.Tensor
            Generator's loss value for the current batch.
        """
        return -1.0 * torch.mean(crit_fake_pred)
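
    # Together, _critic_loss and _generator_loss implement the WGAN-GP
    # objective (Gulrajani et al., 2017). With critic D, generator G, noise z,
    # and real cells x:
    #
    #   L_crit = E[D(G(z))] - E[D(x)] + c_lambda * E[(||grad D(x_hat)||_2 - 1)^2]
    #   L_gen  = -E[D(G(z))]
    #
    # where x_hat = epsilon * x + (1 - epsilon) * G(z) is a random
    # interpolation between real and fake cells (see _get_gradient below).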
    def _get_gradient(
        self,
        real: torch.Tensor,
        fake: torch.Tensor,
        epsilon: torch.Tensor,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        """
        Compute the gradient of the critic's scores with respect to
        interpolations of real and fake cells.

        Parameters
        ----------
        real : torch.Tensor
            A batch of real cells.
        fake : torch.Tensor
            A batch of fake cells.
        epsilon : torch.Tensor
            A vector of the uniformly random proportions of real/fake per
            interpolated cell.

        Returns
        -------
        torch.Tensor
            Gradient of the critic's score with respect to the interpolated data.
        """
        # Mix real and fake cells together
        interpolates = real * epsilon + fake * (1 - epsilon)
        # Calculate the critic's scores on the mixed data
        critic_interpolates = self.crit(interpolates)
        # Take the gradient of the scores with respect to the data
        gradient = torch.autograd.grad(
            inputs=interpolates,
            outputs=critic_interpolates,
            grad_outputs=torch.ones_like(critic_interpolates),
            create_graph=True,
            retain_graph=True,
        )[0]
        return gradient
    @staticmethod
    def _gradient_penalty(gradient: torch.Tensor) -> torch.Tensor:
        """
        Compute the gradient penalty given a gradient.

        Parameters
        ----------
        gradient : torch.Tensor
            The gradient of the critic's score with respect to the
            interpolated data.

        Returns
        -------
        torch.Tensor
            Gradient penalty of the given gradient.
        """
        gradient = gradient.view(len(gradient), -1)
        gradient_norm = gradient.norm(2, dim=1)
        return torch.mean((gradient_norm - 1) ** 2)
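
    # Toy check (a minimal sketch): gradients whose rows have unit L2 norm
    # incur zero penalty, which is the value the penalty pushes the critic
    # toward.
    #
    #   >>> g = torch.ones(4, 2) / (2 ** 0.5)  # every row has L2 norm 1
    #   >>> GAN._gradient_penalty(g)
    #   tensor(0.)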
    def generate_cells(
        self,
        cells_no: int,
        checkpoint: typing.Optional[typing.Union[str, bytes, os.PathLike, None]] = None,
    ) -> np.ndarray:
        """
        Generate cells from the GAN model.

        Parameters
        ----------
        cells_no : int
            Number of cells to generate.
        checkpoint : typing.Optional[typing.Union[str, bytes, os.PathLike, None]], optional
            Path to the saved trained model, by default None.

        Returns
        -------
        np.ndarray
            Gene expression matrix of generated cells.
        """
        if checkpoint is not None:
            self._load(checkpoint)
        # Find how many batches to generate
        batch_no = int(np.ceil(cells_no / self.batch_size))
        fake_cells = []
        for _ in range(batch_no):
            noise = self._generate_noise(self.batch_size, self.latent_dim, self.device)
            fake_cells.append(self.gen(noise).cpu().detach().numpy())
        return np.concatenate(fake_cells)[:cells_no]
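
    # Typical inference usage (a sketch; the checkpoint path is hypothetical,
    # though it follows the step_{step}.pth naming used by _save below):
    #
    #   expr = gan.generate_cells(
    #       cells_no=1000,
    #       checkpoint="output/checkpoints/step_100000.pth",
    #   )
    #   assert expr.shape == (1000, gan.genes_no)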
    def _save(self, path: typing.Union[str, bytes, os.PathLike]) -> None:
        """
        Saves the model.

        Parameters
        ----------
        path : typing.Union[str, bytes, os.PathLike]
            Directory to save the model.
        """
        output_dir = path + "/checkpoints"
        if not os.path.isdir(output_dir):
            os.makedirs(output_dir)
        torch.save(
            {
                "step": self.step,
                "generator_state_dict": self.gen.module.state_dict(),
                "critic_state_dict": self.crit.module.state_dict(),
                "generator_optimizer_state_dict": self.gen_opt.state_dict(),
                "critic_optimizer_state_dict": self.crit_opt.state_dict(),
                "generator_lr_scheduler": self.gen_lr_scheduler.state_dict(),
                "critic_lr_scheduler": self.crit_lr_scheduler.state_dict(),
            },
            f"{path}/checkpoints/step_{self.step}.pth",
        )
    def _load(
        self,
        path: typing.Union[str, bytes, os.PathLike],
        mode: typing.Optional[str] = "inference",
    ) -> None:
        """
        Loads a saved model (.pth file).

        Parameters
        ----------
        path : typing.Union[str, bytes, os.PathLike]
            Path to the saved model.
        mode : typing.Optional[str], optional
            Specify if the loaded model is used for 'inference' or 'training',
            by default "inference".

        Raises
        ------
        ValueError
            If a mode other than 'inference' or 'training' is specified.
        """
        checkpoint = torch.load(path, map_location=torch.device(self.device))
        self.gen.load_state_dict(checkpoint["generator_state_dict"])
        self.crit.load_state_dict(checkpoint["critic_state_dict"])
        if mode == "inference":
            self.gen.eval()
            self.crit.eval()
        elif mode == "training":
            self.gen.train()
            self.crit.train()
            self.step = checkpoint["step"] + 1
            self.gen_opt.load_state_dict(checkpoint["generator_optimizer_state_dict"])
            self.crit_opt.load_state_dict(checkpoint["critic_optimizer_state_dict"])
            self.gen_lr_scheduler.load_state_dict(checkpoint["generator_lr_scheduler"])
            self.crit_lr_scheduler.load_state_dict(checkpoint["critic_lr_scheduler"])
        else:
            raise ValueError("mode should be 'inference' or 'training'")
    def _build_model(self) -> None:
        """Instantiates the Generator and Critic."""
        self.gen = Generator(
            self.latent_dim, self.genes_no, self.gen_layers, self.library_size
        ).to(self.device)
        self.crit = Critic(self.genes_no, self.critic_layers).to(self.device)
    def _get_loaders(
        self,
        train_file: typing.Union[str, bytes, os.PathLike],
        validation_file: typing.Union[str, bytes, os.PathLike],
    ) -> typing.Tuple[DataLoader, DataLoader]:
        """
        Gets training and validation DataLoaders for training.

        Parameters
        ----------
        train_file : typing.Union[str, bytes, os.PathLike]
            Path to training files.
        validation_file : typing.Union[str, bytes, os.PathLike]
            Path to validation files.

        Returns
        -------
        typing.Tuple[DataLoader, DataLoader]
            Training and validation DataLoaders.
        """
        return get_loader(train_file, self.batch_size), get_loader(validation_file)
    def _add_tensorboard_graph(
        self,
        output_dir: typing.Union[str, bytes, os.PathLike],
        gen_data: typing.Union[torch.Tensor, typing.Tuple[torch.Tensor]],
        crit_data: typing.Union[torch.Tensor, typing.Tuple[torch.Tensor]],
    ) -> None:
        """
        Adds the model graph to TensorBoard.

        Parameters
        ----------
        output_dir : typing.Union[str, bytes, os.PathLike]
            Directory to save the tfevents.
        gen_data : typing.Union[torch.Tensor, typing.Tuple[torch.Tensor]]
            Input to the generator.
        crit_data : typing.Union[torch.Tensor, typing.Tuple[torch.Tensor]]
            Input to the critic.
        """
        with SummaryWriter(f"{output_dir}/TensorBoard/model/generator") as w:
            w.add_graph(self.gen.module, gen_data)
        with SummaryWriter(f"{output_dir}/TensorBoard/model/critic") as w:
            w.add_graph(self.crit.module, crit_data)
    def _update_tensorboard(
        self,
        gen_loss: float,
        crit_loss: float,
        gp: torch.Tensor,
        gen_lr: float,
        crit_lr: float,
        output_dir: typing.Union[str, bytes, os.PathLike],
    ) -> None:
        """
        Updates the TensorBoard summary logs.

        Parameters
        ----------
        gen_loss : float
            Generator loss.
        crit_loss : float
            Critic loss.
        gp : torch.Tensor
            Gradient penalty.
        gen_lr : float
            Generator's optimizer learning rate.
        crit_lr : float
            Critic's optimizer learning rate.
        output_dir : typing.Union[str, bytes, os.PathLike]
            Directory to save the tfevents.
        """
        with SummaryWriter(f"{output_dir}/TensorBoard/generator") as w:
            w.add_scalar("loss", gen_loss, self.step)
        with SummaryWriter(f"{output_dir}/TensorBoard/critic") as w:
            w.add_scalar("loss", crit_loss, self.step)
        with SummaryWriter(f"{output_dir}/TensorBoard/gp") as w:
            w.add_scalar("gradient penalty", gp, self.step)
        with SummaryWriter(f"{output_dir}/TensorBoard/generator_lr") as w:
            w.add_scalar("learning rate", gen_lr, self.step)
        with SummaryWriter(f"{output_dir}/TensorBoard/critic_lr") as w:
            w.add_scalar("learning rate", crit_lr, self.step)
    def _generate_tsne_plot(
        self,
        valid_loader: DataLoader,
        output_dir: typing.Union[str, bytes, os.PathLike],
    ) -> None:
        """
        Generates t-SNE plots during training.

        Parameters
        ----------
        valid_loader : DataLoader
            Validation set DataLoader.
        output_dir : typing.Union[str, bytes, os.PathLike]
            Directory to save the t-SNE plots.
        """
        tsne_path = output_dir + "/TSNE"
        if not os.path.isdir(tsne_path):
            os.makedirs(tsne_path)
        fake_cells = self.generate_cells(len(valid_loader.dataset))
        valid_cells, _ = next(iter(valid_loader))
        embedded_cells = TSNE().fit_transform(
            np.concatenate((valid_cells, fake_cells), axis=0)
        )
        real_embedding = embedded_cells[0 : valid_cells.shape[0], :]
        fake_embedding = embedded_cells[valid_cells.shape[0] :, :]
        plt.clf()
        fig = plt.figure()
        plt.scatter(
            real_embedding[:, 0],
            real_embedding[:, 1],
            c="blue",
            label="real",
            alpha=0.5,
        )
        plt.scatter(
            fake_embedding[:, 0],
            fake_embedding[:, 1],
            c="red",
            label="fake",
            alpha=0.5,
        )
        plt.grid(True)
        plt.legend(
            loc="lower left", numpoints=1, ncol=2, fontsize=8, bbox_to_anchor=(0, 0)
        )
        plt.savefig(tsne_path + "/step_" + str(self.step) + ".jpg")
        with SummaryWriter(f"{output_dir}/TensorBoard/TSNE") as w:
            w.add_figure("t-SNE plot", fig, self.step)
        plt.close()
    def _train_critic(
        self, real_cells: torch.Tensor, real_labels: torch.Tensor, c_lambda: float
    ) -> typing.Tuple[torch.Tensor, torch.Tensor]:
        """
        Trains the critic for one iteration.

        Parameters
        ----------
        real_cells : torch.Tensor
            Tensor containing a batch of real cells.
        real_labels : torch.Tensor
            Tensor containing a batch of real labels (corresponding to real_cells).
        c_lambda : float
            Regularization hyper-parameter for gradient penalty.

        Returns
        -------
        typing.Tuple[torch.Tensor, torch.Tensor]
            The computed critic loss and gradient penalty.
        """
        self.crit_opt.zero_grad()
        fake_noise = self._generate_noise(self.batch_size, self.latent_dim, self.device)
        fake = self.gen(fake_noise)
        crit_fake_pred = self.crit(fake.detach())
        crit_real_pred = self.crit(real_cells)
        epsilon = torch.rand(len(real_cells), 1, device=self.device, requires_grad=True)
        gradient = self._get_gradient(real_cells, fake.detach(), epsilon)
        gp = self._gradient_penalty(gradient)
        crit_loss = self._critic_loss(crit_fake_pred, crit_real_pred, gp, c_lambda)
        # Update gradients
        crit_loss.backward(retain_graph=True)
        # Update optimizer
        self.crit_opt.step()
        return crit_loss, gp
    def _train_generator(self) -> torch.Tensor:
        """
        Trains the generator for one iteration.

        Returns
        -------
        torch.Tensor
            Tensor containing only 1 item, the generator loss.
        """
        self.gen_opt.zero_grad()
        fake_noise = self._generate_noise(
            self.batch_size, self.latent_dim, device=self.device
        )
        fake = self.gen(fake_noise)
        crit_fake_pred = self.crit(fake)
        gen_loss = self._generator_loss(crit_fake_pred)
        gen_loss.backward()
        # Update weights
        self.gen_opt.step()
        return gen_loss
    def train(
        self,
        train_files: str,
        valid_files: str,
        critic_iter: int,
        max_steps: int,
        c_lambda: float,
        beta1: float,
        beta2: float,
        gen_alpha_0: float,
        gen_alpha_final: float,
        crit_alpha_0: float,
        crit_alpha_final: float,
        checkpoint: typing.Optional[typing.Union[str, bytes, os.PathLike, None]] = None,
        output_dir: typing.Optional[str] = "output",
        summary_freq: typing.Optional[int] = 5000,
        plt_freq: typing.Optional[int] = 10000,
        save_feq: typing.Optional[int] = 10000,
    ) -> None:
        """
        Method for training the GAN.

        Parameters
        ----------
        train_files : str
            Path to training set files (TFRecords supported for now).
        valid_files : str
            Path to validation set files (TFRecords supported for now).
        critic_iter : int
            Number of training iterations of the critic for each iteration of
            the generator.
        max_steps : int
            Maximum number of steps to train the GAN.
        c_lambda : float
            Regularization hyper-parameter for gradient penalty.
        beta1 : float
            Coefficient used for computing running averages of the gradient in
            the optimizer.
        beta2 : float
            Coefficient used for computing running averages of the squared
            gradient in the optimizer.
        gen_alpha_0 : float
            Generator's initial learning rate value.
        gen_alpha_final : float
            Generator's final learning rate value.
        crit_alpha_0 : float
            Critic's initial learning rate value.
        crit_alpha_final : float
            Critic's final learning rate value.
        checkpoint : typing.Optional[typing.Union[str, bytes, os.PathLike, None]], optional
            Path to a trained model; if specified, the checkpoint will be used
            to resume training, by default None.
        output_dir : typing.Optional[str], optional
            Directory to which plots, tfevents, and checkpoints will be saved,
            by default "output".
        summary_freq : typing.Optional[int], optional
            Period between summary logs to TensorBoard, by default 5000.
        plt_freq : typing.Optional[int], optional
            Period between t-SNE plots, by default 10000.
        save_feq : typing.Optional[int], optional
            Period between saves of the model, by default 10000.
        """

        def should_run(freq):
            return freq > 0 and self.step % freq == 0 and self.step > 0

        loader, valid_loader = self._get_loaders(train_files, valid_files)
        loader_gen = iter(loader)
        # Instantiate optimizers
        self.gen_opt = torch.optim.AdamW(
            filter(lambda p: p.requires_grad, self.gen.parameters()),
            lr=gen_alpha_0,
            betas=(beta1, beta2),
            amsgrad=True,
        )
        self.crit_opt = torch.optim.AdamW(
            self.crit.parameters(),
            lr=crit_alpha_0,
            betas=(beta1, beta2),
            amsgrad=True,
        )
        # Exponential learning rate
        self.gen_lr_scheduler = self._set_exponential_lr(
            self.gen_opt, gen_alpha_0, gen_alpha_final, max_steps
        )
        self.crit_lr_scheduler = self._set_exponential_lr(
            self.crit_opt, crit_alpha_0, crit_alpha_final, max_steps
        )
        if checkpoint is not None:
            self._load(checkpoint, mode="training")
        self.gen.train()
        self.crit.train()
        # We only accept training on GPU since training on CPU is impractical.
        self.device = "cuda"
        self.gen = torch.nn.DataParallel(self.gen)
        self.crit = torch.nn.DataParallel(self.crit)
        # Main training loop
        generator_losses, critic_losses = [], []
        while self.step <= max_steps:
            try:
                real_cells, real_labels = next(loader_gen)
            except StopIteration:
                loader_gen = iter(loader)
                real_cells, real_labels = next(loader_gen)
            real_cells = real_cells.to(self.device)
            real_labels = real_labels.flatten().to(self.device)
            if self.step != 0:
                mean_iter_crit_loss = 0
                for _ in range(critic_iter):
                    crit_loss, gp = self._train_critic(
                        real_cells, real_labels, c_lambda
                    )
                    mean_iter_crit_loss += crit_loss.item() / critic_iter
                critic_losses += [mean_iter_crit_loss]
                # Update learning rate
                self.crit_lr_scheduler.step()
            gen_loss = self._train_generator()
            self.gen_lr_scheduler.step()
            generator_losses += [gen_loss.item()]
            # Log and visualize progress
            if should_run(summary_freq):
                gen_mean = sum(generator_losses[-summary_freq:]) / summary_freq
                crit_mean = sum(critic_losses[-summary_freq:]) / summary_freq
                # if self.step == summary_freq:
                #     self._add_tensorboard_graph(output_dir, fake_noise, fake)
                self._update_tensorboard(
                    gen_mean,
                    crit_mean,
                    gp,
                    self.gen_lr_scheduler.get_last_lr()[0],
                    self.crit_lr_scheduler.get_last_lr()[0],
                    output_dir,
                )
            if should_run(plt_freq):
                self._generate_tsne_plot(valid_loader, output_dir)
            if should_run(save_feq):
                self._save(output_dir)
            print("done training step", self.step, flush=True)
            self.step += 1
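

# A minimal end-to-end sketch of how this class is meant to be driven. All
# hyper-parameter values and file paths below are hypothetical placeholders
# (critic_iter=5 and c_lambda=10.0 are the values suggested in the WGAN-GP
# paper), not defaults shipped with this module.
if __name__ == "__main__":
    gan = GAN(
        genes_no=17789,  # hypothetical gene count
        batch_size=256,
        latent_dim=128,
        gen_layers=[256, 512, 1024],
        crit_layers=[1024, 512, 256],
    )
    gan.train(
        train_files="data/train.tfrecords",  # hypothetical paths
        valid_files="data/valid.tfrecords",
        critic_iter=5,
        max_steps=1_000_000,
        c_lambda=10.0,
        beta1=0.5,
        beta2=0.9,
        gen_alpha_0=1e-4,
        gen_alpha_final=1e-5,
        crit_alpha_0=1e-4,
        crit_alpha_final=1e-5,
    )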