Source code for gans.conditional_gan_cat

import os
import typing

import numpy as np
import torch
from networks.critic import Critic
from networks.generator import Generator

from gans.conditional_gan import ConditionalGAN


class ConditionalCatGAN(ConditionalGAN):
    def __init__(
        self,
        genes_no: int,
        batch_size: int,
        latent_dim: int,
        gen_layers: typing.List[int],
        crit_layers: typing.List[int],
        num_classes: int,
        label_ratios: torch.Tensor,
        device: typing.Optional[str] = "cuda" if torch.cuda.is_available() else "cpu",
        library_size: typing.Optional[int] = 20000,
    ) -> None:
        """
        Conditional single-cell RNA-seq GAN using the conditioning method by
        concatenation.

        Parameters
        ----------
        genes_no : int
            Number of genes in the dataset.
        batch_size : int
            Training batch size.
        latent_dim : int
            Dimension of the latent space from which the noise vector is sampled.
        gen_layers : typing.List[int]
            List of integers corresponding to the number of neurons of each
            generator layer.
        crit_layers : typing.List[int]
            List of integers corresponding to the number of neurons of each
            critic layer.
        num_classes : int
            Number of classes in the dataset.
        label_ratios : torch.Tensor
            Tensor containing the ratio of each class in the dataset.
        device : typing.Optional[str], optional
            Specifies whether to train on 'cpu' or 'cuda'. Only 'cuda' is
            supported for training the GAN, but 'cpu' can be used for inference,
            by default "cuda" if torch.cuda.is_available() else "cpu".
        library_size : typing.Optional[int], optional
            Total number of counts per generated cell, by default 20000.
        """
        self.num_classes = num_classes
        self.label_ratios = label_ratios
        super(ConditionalCatGAN, self).__init__(
            genes_no,
            batch_size,
            latent_dim,
            gen_layers,
            crit_layers,
            device,
            library_size,
        )
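
    # A minimal instantiation sketch (all values below are hypothetical and
    # for illustration only; ``label_ratios`` holds one entry per class and
    # is assumed here to sum to 1):
    #
    #     ratios = torch.tensor([0.5, 0.3, 0.2])
    #     gan = ConditionalCatGAN(
    #         genes_no=1000,
    #         batch_size=64,
    #         latent_dim=128,
    #         gen_layers=[256, 512],
    #         crit_layers=[512, 256],
    #         num_classes=3,
    #         label_ratios=ratios,
    #     )
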
    def _get_gradient(
        self,
        real: torch.Tensor,
        fake: torch.Tensor,
        epsilon: torch.Tensor,
        labels: typing.Optional[torch.Tensor] = None,
        *args,
        **kwargs
    ) -> torch.Tensor:
        """
        Compute the gradient of the critic's scores with respect to
        interpolations of real and fake cells.

        Parameters
        ----------
        real : torch.Tensor
            A batch of real cells.
        fake : torch.Tensor
            A batch of fake cells.
        epsilon : torch.Tensor
            A vector of the uniformly random proportions of real/fake per
            interpolated cell.
        labels : torch.Tensor
            A batch of real class labels.
        *args
            Variable length argument list.
        **kwargs
            Arbitrary keyword arguments.

        Returns
        -------
        torch.Tensor
            Gradient of the critic's scores with respect to the interpolated data.
        """
        # Mix real and fake cells together
        interpolates = real * epsilon + fake * (1 - epsilon)

        # Calculate the critic's scores on the mixed data
        critic_interpolates = self.crit(self._cat_one_hot_labels(interpolates, labels))

        # Take the gradient of the scores with respect to the data
        gradient = torch.autograd.grad(
            inputs=interpolates,
            outputs=critic_interpolates,
            grad_outputs=torch.ones_like(critic_interpolates),
            create_graph=True,
            retain_graph=True,
        )[0]
        return gradient
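
    # For reference, the interpolation above follows the WGAN-GP scheme of
    # Gulrajani et al. (2017): x_hat = epsilon * x_real + (1 - epsilon) * x_fake.
    # Assuming the parent class's ``_gradient_penalty`` takes the standard
    # form, the penalty derived from this gradient would be roughly:
    #
    #     norm = gradient.view(len(gradient), -1).norm(2, dim=1)
    #     penalty = torch.mean((norm - 1) ** 2)
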
    def _cat_one_hot_labels(
        self, cells: torch.Tensor, labels: torch.Tensor
    ) -> torch.Tensor:
        """
        Concatenates one-hot encoded labels to a tensor.

        Parameters
        ----------
        cells : torch.Tensor
            Tensor to which to concatenate the one-hot encoded class labels.
        labels : torch.Tensor
            Class labels to concatenate.

        Returns
        -------
        torch.Tensor
            Tensor with the one-hot encoded labels concatenated at the tail.
        """
        one_hot = torch.nn.functional.one_hot(labels, self.num_classes)
        return torch.cat((cells.float(), one_hot.float()), 1)
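
    # For example, with ``num_classes == 3`` the labels ``torch.tensor([0, 2])``
    # are one-hot encoded to ``[[1., 0., 0.], [0., 0., 1.]]``, so a batch of
    # shape (2, genes_no) becomes (2, genes_no + 3) after concatenation.
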
    def generate_cells(
        self,
        cells_no: int,
        checkpoint: typing.Optional[typing.Union[str, bytes, os.PathLike]] = None,
        class_: typing.Optional[int] = None,
    ) -> typing.Tuple[np.ndarray, np.ndarray]:
        """
        Generate cells from the Conditional GAN model.

        Parameters
        ----------
        cells_no : int
            Number of cells to generate.
        checkpoint : typing.Optional[typing.Union[str, bytes, os.PathLike]], optional
            Path to the saved trained model, by default None.
        class_ : typing.Optional[int], optional
            Class of the cells to generate. If None, cells are generated with
            the same per-class ratios as the dataset, by default None.

        Returns
        -------
        typing.Tuple[np.ndarray, np.ndarray]
            Gene expression matrix of the generated cells and their
            corresponding class labels.
        """
        if checkpoint is not None:
            self._load(checkpoint)

        batch_no = int(np.ceil(cells_no / self.batch_size))
        fake_cells = []
        fake_labels = []

        for _ in range(batch_no):
            noise = self._generate_noise(self.batch_size, self.latent_dim, self.device)
            if class_ is None:
                labels = self._sample_pseudo_labels(
                    self.batch_size, self.label_ratios
                ).to(self.device)
            else:
                label_ratios = torch.zeros(self.num_classes).to(self.device)
                label_ratios[class_] = 0.99
                labels = self._sample_pseudo_labels(self.batch_size, label_ratios).to(
                    self.device
                )
            fake_cells.append(
                self.gen(self._cat_one_hot_labels(noise, labels)).cpu().detach().numpy()
            )
            fake_labels.append(labels.cpu().detach().numpy())

        return (
            np.concatenate(fake_cells)[:cells_no],
            np.concatenate(fake_labels)[:cells_no],
        )
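
    # A hypothetical call (the checkpoint path and class index are
    # illustrative only):
    #
    #     cells, labels = gan.generate_cells(
    #         cells_no=500, checkpoint="checkpoints/model.pt", class_=1
    #     )
    #     # cells.shape == (500, genes_no); labels.shape == (500,)
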
    def _build_model(self) -> None:
        """Initializes the Generator and Critic."""
        self.gen = Generator(
            self.latent_dim + self.num_classes,
            self.genes_no,
            self.gen_layers,
            self.library_size,
        ).to(self.device)
        self.crit = Critic(self.genes_no + self.num_classes, self.critic_layers).to(
            self.device
        )
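
    # Note the ``+ self.num_classes`` input widths: both networks receive the
    # one-hot class label concatenated to their usual input (the noise vector
    # for the Generator, the expression vector for the Critic).
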
    def _train_critic(
        self,
        real_cells: torch.Tensor,
        real_labels: torch.Tensor,
        c_lambda: float,
    ) -> typing.Tuple[torch.Tensor, torch.Tensor]:
        """
        Trains the critic for one iteration.

        Parameters
        ----------
        real_cells : torch.Tensor
            Tensor containing a batch of real cells.
        real_labels : torch.Tensor
            Tensor containing a batch of real labels (corresponding to
            real_cells).
        c_lambda : float
            Regularization hyperparameter for the gradient penalty.

        Returns
        -------
        typing.Tuple[torch.Tensor, torch.Tensor]
            The computed critic loss and gradient penalty.
        """
        self.crit_opt.zero_grad()
        fake_noise = self._generate_noise(self.batch_size, self.latent_dim, self.device)
        fake = self.gen(self._cat_one_hot_labels(fake_noise, real_labels))
        crit_fake_pred = self.crit(self._cat_one_hot_labels(fake, real_labels).detach())
        crit_real_pred = self.crit(self._cat_one_hot_labels(real_cells, real_labels))

        epsilon = torch.rand(len(real_cells), 1, device=self.device, requires_grad=True)
        gradient = self._get_gradient(real_cells, fake.detach(), epsilon, real_labels)
        gp = self._gradient_penalty(gradient)
        crit_loss = self._critic_loss(crit_fake_pred, crit_real_pred, gp, c_lambda)

        # Update gradients
        crit_loss.backward(retain_graph=True)
        # Update optimizer
        self.crit_opt.step()

        return crit_loss, gp
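
    # Assuming the parent class's ``_critic_loss`` implements the standard
    # WGAN-GP objective, the loss computed above is roughly:
    #
    #     crit_loss = torch.mean(crit_fake_pred) - torch.mean(crit_real_pred) + c_lambda * gp
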
    def _train_generator(self) -> torch.Tensor:
        """
        Trains the generator for one iteration.

        Returns
        -------
        torch.Tensor
            A single-element tensor containing the generator loss.
        """
        self.gen_opt.zero_grad()
        fake_noise = self._generate_noise(
            self.batch_size, self.latent_dim, device=self.device
        )
        fake_labels = self._sample_pseudo_labels(self.batch_size, self.label_ratios).to(
            self.device
        )
        fake = self.gen(self._cat_one_hot_labels(fake_noise, fake_labels))
        crit_fake_pred = self.crit(self._cat_one_hot_labels(fake, fake_labels))

        gen_loss = self._generator_loss(crit_fake_pred)
        gen_loss.backward()

        # Update weights
        self.gen_opt.step()

        return gen_loss
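
    # Assuming the parent class's ``_generator_loss`` is the standard WGAN
    # generator objective, the loss computed above is roughly:
    #
    #     gen_loss = -torch.mean(crit_fake_pred)
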