gans.causal_gan module#

class gans.causal_gan.CausalGAN(genes_no: int, batch_size: int, latent_dim: int, noise_per_gene: int, depth_per_gene: int, width_per_gene: int, cc_latent_dim: int, cc_layers: List[int], cc_pretrained_checkpoint: str, crit_layers: List[int], causal_graph: Dict[int, Set[int]], labeler_layers: List[int], device: str | None = 'cpu', library_size: int | None = 20000)[source]#

Bases: GAN

__init__(genes_no: int, batch_size: int, latent_dim: int, noise_per_gene: int, depth_per_gene: int, width_per_gene: int, cc_latent_dim: int, cc_layers: List[int], cc_pretrained_checkpoint: str, crit_layers: List[int], causal_graph: Dict[int, Set[int]], labeler_layers: List[int], device: str | None = 'cpu', library_size: int | None = 20000) None[source]#

Causal single-cell RNA-seq GAN (TODO: find a unique name).

  • genes_no (int) – Number of genes in the dataset.

  • batch_size (int) – Training batch size.

  • latent_dim (int) – Dimension of the latent space from which the noise vector used by the causal controller is sampled.

  • noise_per_gene (int) – Dimension of the latent space from which the noise vectors used by the target generators are sampled.

  • depth_per_gene (int) – Depth of the target generator networks.

  • width_per_gene (int) – The width scale used for the target generator networks.

  • cc_latent_dim (int) – Dimension of the latent space from which the noise vector to the causal controller is sampled.

  • cc_layers (List[int]) – List of integers corresponding to the number of neurons of each causal controller layer.

  • cc_pretrained_checkpoint (str) – Path to the pretrained causal controller.

  • crit_layers (List[int]) – List of integers corresponding to the number of neurons of each critic layer.

  • causal_graph (Dict[int, Set[int]]) –

    The causal graph is a dictionary representing the TRN to impose. It has the following format: {target gene index: {TF1 index, TF2 index, …}}. The causal graph must be acyclic and bipartite: a TF cannot be regulated by another TF.

    Invalid: {1: {2, 3, {4, 6}}, …} - a regulator (TF) is regulated by another regulator (TF)

    Invalid: {1: {2, 3, 4}, 2: {4, 3, 5}, …} - a regulator (TF) is also regulated

    Invalid: {4: {2, 3}, 2: {4, 3}} - contains a cycle

    Valid causal graph example: {1: {2, 3, 4}, 6: {5, 4, 2}, …}

  • labeler_layers (List[int]) – List of integers corresponding to the width of each labeler layer.

  • device (Optional[str], optional) – Specifies whether to run on ‘cpu’ or ‘cuda’. Only ‘cuda’ is supported for training the GAN, but ‘cpu’ can be used for inference, by default “cuda” if torch.cuda.is_available() else “cpu”.

  • library_size (Optional[int], optional) – Total number of counts per generated cell, by default 20000.
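
The constraints on causal_graph listed above can be checked before the model is constructed. The following is a minimal sketch, not the library's own validation: validate_causal_graph is a hypothetical helper, and it relies on the observation that if no TF also appears as a target, every edge runs from a TF to a target, so the graph is bipartite and cannot contain a cycle.

```python
from typing import Dict, Set


def validate_causal_graph(causal_graph: Dict[int, Set[int]]) -> None:
    """Hypothetical check (not part of gans.causal_gan).

    Raises ValueError if a regulator (TF) is itself a regulated target.
    When that never happens, the graph is bipartite and therefore acyclic.
    """
    targets = set(causal_graph)
    for target, tfs in causal_graph.items():
        if not all(isinstance(tf, int) for tf in tfs):
            raise ValueError(f"regulators of gene {target} must be integer indices")
        # A TF that is also a key of the dictionary is regulated by something.
        regulated_tfs = targets & set(tfs)
        if regulated_tfs:
            raise ValueError(
                f"TF(s) {sorted(regulated_tfs)} regulating gene {target} are also regulated"
            )


# The valid example from the parameter description passes silently.
validate_causal_graph({1: {2, 3, 4}, 6: {5, 4, 2}})
```

The first invalid example above, with a set nested inside a set, cannot even be built as a Python literal because sets are unhashable, so the sketch only needs to handle the remaining two cases.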

_build_model() None[source]#

Instantiates the Generator and Critic.

_save(path: str | bytes | PathLike) None[source]#

Saves the model.


path (Union[str, bytes, os.PathLike]) – Directory to save the model.

_load(path: str | bytes | PathLike, mode: str | None = 'inference') None[source]#

Loads a saved causal GAN model (.pth file).

  • path (Union[str, bytes, os.PathLike]) – Path to the saved model.

  • mode (Optional[str], optional) – Specify if the loaded model is used for ‘inference’ or ‘training’, by default “inference”.


ValueError – If a mode other than ‘inference’ or ‘training’ is specified.

_train_labelers(real_cells: Tensor) None[source]#

Trains the labeler (on real and fake) and anti-labeler (on fake only).


real_cells (torch.Tensor) – Tensor containing a batch of real cells.

_train_generator() Tensor[source]#

Trains the causal generator for one iteration.


Tensor containing only 1 item, the generator loss.

Return type:

torch.Tensor

train(train_files: str, valid_files: str, critic_iter: int, max_steps: int, c_lambda: float, beta1: float, beta2: float, gen_alpha_0: float, gen_alpha_final: float, crit_alpha_0: float, crit_alpha_final: float, labeler_alpha: float, antilabeler_alpha: float, labeler_training_interval: int, checkpoint: str | bytes | PathLike | None = None, output_dir: str | None = 'output', summary_freq: int | None = 5000, plt_freq: int | None = 10000, save_feq: int | None = 10000) None[source]#

Method for training the causal GAN.

  • train_files (str) – Path to training set files (TFrecords supported for now).

  • valid_files (str) – Path to validation set files (TFrecords supported for now).

  • critic_iter (int) – Number of training iterations of the critic for each iteration on the generator.

  • max_steps (int) – Maximum number of steps to train the GAN.

  • c_lambda (float) – Regularization hyper-parameter for gradient penalty.

  • beta1 (float) – Coefficient used for computing running averages of the gradient in the optimizer.

  • beta2 (float) – Coefficient used for computing running averages of gradient squares in the optimizer.

  • gen_alpha_0 (float) – Generator’s initial learning rate value.

  • gen_alpha_final (float) – Generator’s final learning rate value.

  • crit_alpha_0 (float) – Critic’s initial learning rate value.

  • crit_alpha_final (float) – Critic’s final learning rate value.

  • labeler_alpha (float) – Labeler’s learning rate value.

  • antilabeler_alpha (float) – Anti-labeler’s learning rate value.

  • labeler_training_interval (int) – The number of steps after which the labeler and anti-labeler are trained. If 20, the labeler and anti-labeler will be trained every 20 steps.

  • checkpoint (Optional[Union[str, bytes, os.PathLike, None]], optional) – Path to a trained model; if specified, the checkpoint is used to resume training, by default None.

  • output_dir (Optional[str], optional) – Directory to which plots, tfevents, and checkpoints will be saved, by default “output”.

  • summary_freq (Optional[int], optional) – Period between summary logs to TensorBoard, by default 5000.

  • plt_freq (Optional[int], optional) – Period between t-SNE plots, by default 10000.

  • save_feq (Optional[int], optional) – Period between saves of the model, by default 10000.
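
How critic_iter and labeler_training_interval interact can be pictured with a small schedule generator. This is only a sketch under stated assumptions: it assumes that one step consists of critic_iter critic updates followed by one generator update, and training_schedule is a hypothetical helper, not part of the library.

```python
from typing import Iterator, Tuple


def training_schedule(
    max_steps: int, critic_iter: int, labeler_training_interval: int
) -> Iterator[Tuple[int, str]]:
    """Yield (step, component) pairs in the assumed update order."""
    for step in range(1, max_steps + 1):
        # Assumption: the critic is updated critic_iter times per generator update.
        for _ in range(critic_iter):
            yield step, "critic"
        yield step, "generator"
        # The labeler and anti-labeler are trained every labeler_training_interval steps.
        if step % labeler_training_interval == 0:
            yield step, "labelers"


events = list(training_schedule(max_steps=4, critic_iter=2, labeler_training_interval=2))
```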

gans.conditional_gan module#

class gans.conditional_gan.ConditionalGAN(genes_no: int, batch_size: int, latent_dim: int, gen_layers: List[int], crit_layers: List[int], device: str | None = 'cpu', library_size: int | None = 20000)[source]#

Bases: GAN, ABC

static _sample_pseudo_labels(batch_size: int, cluster_ratios: Tensor) Tensor[source]#

Randomly samples cluster labels following a multinomial distribution.

  • batch_size (int) – The number of samples to generate (normally equal to training batch size).

  • cluster_ratios (torch.Tensor) – Tensor containing the parameters of the multinomial distribution (e.g., torch.Tensor([0.5, 0.3, 0.2]) for 3 clusters with occurrence probabilities of 0.5, 0.3, and 0.2 for clusters 0, 1, and 2, respectively).


Tensor containing a batch of sampled cluster labels.

Return type:

torch.Tensor

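The sampling described above can be illustrated without PyTorch. The sketch below draws cluster indices with probabilities given by cluster_ratios using the standard library; sample_pseudo_labels and its seed argument are hypothetical, and the actual method presumably relies on a tensor operation such as torch.multinomial instead.

```python
import random
from typing import List, Optional, Sequence


def sample_pseudo_labels(
    batch_size: int, cluster_ratios: Sequence[float], seed: Optional[int] = None
) -> List[int]:
    """Draw batch_size cluster indices, index i having probability cluster_ratios[i]."""
    rng = random.Random(seed)
    clusters = range(len(cluster_ratios))
    # random.choices samples with replacement according to the given weights.
    return rng.choices(clusters, weights=cluster_ratios, k=batch_size)


labels = sample_pseudo_labels(batch_size=8, cluster_ratios=[0.5, 0.3, 0.2], seed=0)
```
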
_generate_tsne_plot(valid_loader: DataLoader, output_dir: str | bytes | PathLike) None[source]#

Generate t-SNE plot during training.

  • valid_loader (DataLoader) – Validation set DataLoader.

  • output_dir (Union[str, bytes, os.PathLike]) – Directory to save the t-SNE plots.

gans.conditional_gan_cat module#

class gans.conditional_gan_cat.ConditionalCatGAN(genes_no: int, batch_size: int, latent_dim: int, gen_layers: List[int], crit_layers: List[int], num_classes: int, label_ratios: Tensor, device: str | None = 'cpu', library_size: int | None = 20000)[source]#

Bases: ConditionalGAN

__init__(genes_no: int, batch_size: int, latent_dim: int, gen_layers: List[int], crit_layers: List[int], num_classes: int, label_ratios: Tensor, device: str | None = 'cpu', library_size: int | None = 20000) None[source]#

Conditional single-cell RNA-seq GAN using the conditioning method by concatenation.

  • genes_no (int) – Number of genes in the dataset.

  • batch_size (int) – Training batch size.

  • latent_dim (int) – Dimension of the latent space from which the noise vector is sampled.

  • gen_layers (List[int]) – List of integers corresponding to the number of neurons of each generator layer.

  • crit_layers (List[int]) – List of integers corresponding to the number of neurons of each critic layer.

  • num_classes (int) – Number of classes in the dataset.

  • label_ratios (torch.Tensor) – Tensor containing the ratio of each class in the dataset.

  • device (Optional[str], optional) – Specifies whether to run on ‘cpu’ or ‘cuda’. Only ‘cuda’ is supported for training the GAN, but ‘cpu’ can be used for inference, by default “cuda” if torch.cuda.is_available() else “cpu”.

  • library_size (Optional[int], optional) – Total number of counts per generated cell, by default 20000.

_get_gradient(real: Tensor, fake: Tensor, epsilon: Tensor, labels: Tensor | None = None, *args, **kwargs) Tensor[source]#

Compute the gradient of the critic’s scores with respect to interpolations of real and fake cells.

  • real (torch.Tensor) – A batch of real cells.

  • fake (torch.Tensor) – A batch of fake cells.

  • epsilon (torch.Tensor) – A vector of the uniformly random proportions of real/fake per interpolated cells.

  • labels (torch.Tensor) – A batch of real class labels.

  • *args – Variable length argument list.

  • **kwargs – Arbitrary keyword arguments.


Gradient of the critic’s score with respect to interpolated data.

Return type:

torch.Tensor

_cat_one_hot_labels(cells: Tensor, labels: Tensor) Tensor[source]#

Concatenates one-hot encoded labels to a tensor.

  • cells (torch.Tensor) – Tensor to which to concatenate one-hot encoded class labels.

  • labels (torch.Tensor) – Class labels to concatenate.


Tensor with one-hot encoded labels concatenated at the tail.

Return type:

torch.Tensor

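Conditioning by concatenation appends a one-hot encoding of the class label to the end of each cell vector. The following is a dependency-free sketch of that idea using nested lists instead of tensors; cat_one_hot_labels and its explicit num_classes argument are hypothetical and do not reflect the method's real signature.

```python
from typing import List


def cat_one_hot_labels(
    cells: List[List[float]], labels: List[int], num_classes: int
) -> List[List[float]]:
    """Append a one-hot encoding of each label to the tail of its cell vector."""
    out = []
    for row, label in zip(cells, labels):
        one_hot = [0.0] * num_classes
        one_hot[label] = 1.0
        out.append(list(row) + one_hot)
    return out


# Two cells with two genes each, labelled with classes 2 and 0 out of 3.
conditioned = cat_one_hot_labels([[5.0, 0.0], [1.0, 2.0]], [2, 0], num_classes=3)
# conditioned == [[5.0, 0.0, 0.0, 0.0, 1.0], [1.0, 2.0, 1.0, 0.0, 0.0]]
```
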
generate_cells(cells_no: int, checkpoint: str | bytes | PathLike | None = None, class_: int | None = None) Tuple[ndarray, ndarray][source]#

Generate cells from the Conditional GAN model.

  • cells_no (int) – Number of cells to generate.

  • checkpoint (Optional[Union[str, bytes, os.PathLike, None]], optional) – Path to the saved trained model, by default None.

  • class_ (Optional[int], optional) – Class of the cells to generate. If None, cells with the same ratio per class will be generated, by default None.


Gene expression matrix of generated cells and their corresponding class labels.

Return type:

Tuple[np.ndarray, np.ndarray]

_build_model() None[source]#

Initializes the Generator and Critic.

_train_critic(real_cells, real_labels, c_lambda) Tuple[Tensor, Tensor][source]#

Trains the critic for one iteration.

  • real_cells (torch.Tensor) – Tensor containing a batch of real cells.

  • real_labels (torch.Tensor) – Tensor containing a batch of real labels (corresponding to real_cells).

  • c_lambda (float) – Regularization hyper-parameter for gradient penalty.


The computed critic loss and gradient penalty.

Return type:

Tuple[torch.Tensor, torch.Tensor]

_train_generator() Tensor[source]#

Trains the generator for one iteration.


Tensor containing only 1 item, the generator loss.

Return type:

torch.Tensor

gans.conditional_gan_proj module#

class gans.conditional_gan_proj.ConditionalProjGAN(genes_no: int, batch_size: int, latent_dim: int, gen_layers: List[int], crit_layers: List[int], num_classes: int, label_ratios: Tensor, device: str | None = 'cpu', library_size: int | None = 20000)[source]#

Bases: ConditionalGAN

__init__(genes_no: int, batch_size: int, latent_dim: int, gen_layers: List[int], crit_layers: List[int], num_classes: int, label_ratios: Tensor, device: str | None = 'cpu', library_size: int | None = 20000) None[source]#

Conditional single-cell RNA-seq GAN using the projection conditioning method.

  • genes_no (int) – Number of genes in the dataset.

  • batch_size (int) – Training batch size.

  • latent_dim (int) – Dimension of the latent space from which the noise vector is sampled.

  • gen_layers (List[int]) – List of integers corresponding to the number of neurons of each generator layer.

  • crit_layers (List[int]) – List of integers corresponding to the number of neurons of each critic layer.

  • num_classes (int) – Number of classes in the dataset.

  • label_ratios (torch.Tensor) – Tensor containing the ratio of each class in the dataset.

  • device (Optional[str], optional) – Specifies whether to run on ‘cpu’ or ‘cuda’. Only ‘cuda’ is supported for training the GAN, but ‘cpu’ can be used for inference, by default “cuda” if torch.cuda.is_available() else “cpu”.

  • library_size (Optional[int], optional) – Total number of counts per generated cell, by default 20000.

_get_gradient(real: Tensor, fake: Tensor, epsilon: Tensor, labels: Tensor | None = None, *args, **kwargs) Tensor[source]#

Compute the gradient of the critic’s scores with respect to interpolations of real and fake cells.

  • real (torch.Tensor) – A batch of real cells.

  • fake (torch.Tensor) – A batch of fake cells.

  • epsilon (torch.Tensor) – A vector of the uniformly random proportions of real/fake per interpolated cells.

  • labels (torch.Tensor) – A batch of real class labels.

  • *args – Variable length argument list.

  • **kwargs – Arbitrary keyword arguments.


Gradient of the critic’s score with respect to interpolated data.

Return type:

torch.Tensor

generate_cells(cells_no: int, checkpoint: str | bytes | PathLike | None = None, class_: int | None = None) Tuple[ndarray, ndarray][source]#

Generate cells from the Conditional GAN model.

  • cells_no (int) – Number of cells to generate.

  • checkpoint (Optional[Union[str, bytes, os.PathLike, None]], optional) – Path to the saved trained model, by default None.

  • class_ (Optional[int], optional) – Class of the cells to generate. If None, cells with the same ratio per class will be generated, by default None.


Gene expression matrix of generated cells and their corresponding class labels.

Return type:

Tuple[np.ndarray, np.ndarray]

_build_model() None[source]#

Initializes the Generator and Critic.

_train_critic(real_cells, real_labels, c_lambda) Tuple[Tensor, Tensor][source]#

Trains the critic for one iteration.

  • real_cells (torch.Tensor) – Tensor containing a batch of real cells.

  • real_labels (torch.Tensor) – Tensor containing a batch of real labels (corresponding to real_cells).

  • c_lambda (float) – Regularization hyper-parameter for gradient penalty.


The computed critic loss and gradient penalty.

Return type:

Tuple[torch.Tensor, torch.Tensor]

_train_generator() Tensor[source]#

Trains the generator for one iteration.


Tensor containing only 1 item, the generator loss.

Return type:

torch.Tensor

gans.gan module#

class gans.gan.GAN(genes_no: int, batch_size: int, latent_dim: int, gen_layers: List[int], crit_layers: List[int], device: str | None = 'cpu', library_size: int | None = 20000)[source]#

Bases: object

__init__(genes_no: int, batch_size: int, latent_dim: int, gen_layers: List[int], crit_layers: List[int], device: str | None = 'cpu', library_size: int | None = 20000) None[source]#

Non-conditional single-cell RNA-seq GAN.

  • genes_no (int) – Number of genes in the dataset.

  • batch_size (int) – Training batch size.

  • latent_dim (int) – Dimension of the latent space from which the noise vector is sampled.

  • gen_layers (List[int]) – List of integers corresponding to the number of neurons of each generator layer.

  • crit_layers (List[int]) – List of integers corresponding to the number of neurons of each critic layer.

  • device (Optional[str], optional) – Specifies whether to run on ‘cpu’ or ‘cuda’. Only ‘cuda’ is supported for training the GAN, but ‘cpu’ can be used for inference, by default “cuda” if torch.cuda.is_available() else “cpu”.

  • library_size (Optional[int], optional) – Total number of counts per generated cell, by default 20000.
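
One way to read library_size is as a normalization target: every generated cell is rescaled so that its gene counts sum to the same total. The sketch below shows that rescaling on a single cell. How the generator actually enforces the library size is not stated here, so scale_to_library_size is a hypothetical illustration only.

```python
from typing import List


def scale_to_library_size(cell: List[float], library_size: float = 20000) -> List[float]:
    """Rescale non-negative gene counts so that they sum to library_size."""
    total = sum(cell)
    if total == 0:
        # An all-zero cell cannot be rescaled; return it unchanged.
        return list(cell)
    return [count * library_size / total for count in cell]


scaled = scale_to_library_size([1.0, 3.0], library_size=20000)
# scaled == [5000.0, 15000.0]
```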

static _generate_noise(batch_size: int, latent_dim: int, device: str) Tensor[source]#

Creates a batch of noise vectors with dimensions (batch_size, latent_dim).

  • batch_size (int) – The number of samples to generate (normally equal to training batch size).

  • latent_dim (int) – Dimension of the latent space to sample from.

  • device (str) – The device type.


A tensor filled with random numbers from the standard normal distribution.

Return type:

torch.Tensor

static _set_exponential_lr(optimizer: Optimizer, alpha_0: float, alpha_final: float, max_steps: int) ExponentialLR[source]#

Sets up exponentially decaying learning rate scheduler to be used with the optimizer.

  • optimizer (torch.optim.Optimizer) – Optimizer for which to create an exponential learning rate scheduler.

  • alpha_0 (float) – Initial learning rate.

  • alpha_final (float) – Final learning rate.

  • max_steps (int) – Total number of training steps. When current_step=max_steps, alpha_final will be set as the learning rate.


Exponential learning rate scheduler. Call the step() function on this scheduler in the training loop.

Return type:

torch.optim.lr_scheduler.ExponentialLR

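An exponential scheduler such as torch.optim.lr_scheduler.ExponentialLR multiplies the learning rate by a constant factor gamma at every step, so the requirement that the rate equals alpha_final at max_steps fixes gamma from the two endpoints. The sketch below works through that arithmetic in plain Python; exponential_gamma and lr_at are hypothetical names, and whether the library derives gamma in exactly this way is an assumption.

```python
def exponential_gamma(alpha_0: float, alpha_final: float, max_steps: int) -> float:
    """Decay factor such that alpha_0 * gamma ** max_steps == alpha_final."""
    return (alpha_final / alpha_0) ** (1.0 / max_steps)


def lr_at(step: int, alpha_0: float, gamma: float) -> float:
    """Learning rate after `step` multiplications by gamma."""
    return alpha_0 * gamma ** step


gamma = exponential_gamma(alpha_0=1e-4, alpha_final=1e-5, max_steps=1000)
final_lr = lr_at(1000, alpha_0=1e-4, gamma=gamma)  # approximately 1e-5
```
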
static _critic_loss(crit_fake_pred: Tensor, crit_real_pred: Tensor, gp: Tensor, c_lambda: float) Tensor[source]#

Compute the critic’s loss given its scores on real and fake cells, the gradient penalty, and the gradient penalty regularization hyper-parameter.

  • crit_fake_pred (torch.Tensor) – Critic’s score on fake cells.

  • crit_real_pred (torch.Tensor) – Critic’s score on real cells.

  • gp (torch.Tensor) – Unweighted gradient penalty.

  • c_lambda (float) – Regularization hyper-parameter to be used with the gradient penalty in the WGAN loss.


Critic’s loss for the current batch.

Return type:

torch.Tensor

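The parameters above match the usual WGAN loss with gradient penalty, in which the critic is rewarded for scoring real cells higher than fake ones and the penalty is weighted by c_lambda. The sketch below assumes that standard form and uses plain floats instead of tensors; critic_loss is a hypothetical stand-in, and the library's exact expression may differ.

```python
from statistics import fmean
from typing import Sequence


def critic_loss(
    crit_fake_pred: Sequence[float],
    crit_real_pred: Sequence[float],
    gp: float,
    c_lambda: float,
) -> float:
    """Assumed WGAN-GP critic loss: mean(fake) - mean(real) + c_lambda * gp."""
    return fmean(crit_fake_pred) - fmean(crit_real_pred) + c_lambda * gp


loss = critic_loss([1.0, 3.0], [4.0, 6.0], gp=0.5, c_lambda=10.0)
# mean(fake) = 2.0, mean(real) = 5.0, weighted penalty = 5.0, so loss = 2.0
```
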
static _generator_loss(crit_fake_pred: Tensor) Tensor[source]#

Compute the generator loss from the critic’s score of the generated cells.


crit_fake_pred (torch.Tensor) – The critic’s score on fake generated cells.


Generator’s loss value for the current batch.

Return type:

torch.Tensor

_get_gradient(real: Tensor, fake: Tensor, epsilon: Tensor, *args, **kwargs) Tensor[source]#

Compute the gradient of the critic’s scores with respect to interpolations of real and fake cells.

  • real (torch.Tensor) – A batch of real cells.

  • fake (torch.Tensor) – A batch of fake cells.

  • epsilon (torch.Tensor) – A vector of the uniformly random proportions of real/fake per interpolated cells.


Gradient of the critic’s score with respect to interpolated data.

Return type:

torch.Tensor

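The interpolated cells at which the gradient is evaluated are per-cell mixtures of a real and a fake cell. The sketch below shows only that mixing step with nested lists; it assumes epsilon weights the real cell and 1 - epsilon the fake cell, and interpolate is a hypothetical helper. Computing the gradient of the critic's score at these points additionally requires automatic differentiation, which is omitted here.

```python
from typing import List


def interpolate(
    real: List[List[float]], fake: List[List[float]], epsilon: List[float]
) -> List[List[float]]:
    """Mix each real/fake pair: eps * real + (1 - eps) * fake (assumed convention)."""
    mixed = []
    for real_cell, fake_cell, eps in zip(real, fake, epsilon):
        mixed.append([eps * r + (1.0 - eps) * f for r, f in zip(real_cell, fake_cell)])
    return mixed


mixed = interpolate(real=[[2.0, 4.0]], fake=[[0.0, 0.0]], epsilon=[0.25])
# mixed == [[0.5, 1.0]]
```
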
static _gradient_penalty(gradient: Tensor) Tensor[source]#

Compute the gradient penalty given a gradient.


gradient (torch.Tensor) – The gradient of the critic’s score with respect to the interpolated data.


Gradient penalty of the given gradient.

Return type:

torch.Tensor

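In the standard WGAN-GP formulation, the penalty is the mean squared deviation of each per-cell gradient norm from 1. The sketch below assumes that formulation and operates on nested lists; gradient_penalty is a hypothetical stand-in for the method documented above, whose exact computation is not shown on this page.

```python
import math
from typing import List


def gradient_penalty(gradient: List[List[float]]) -> float:
    """Assumed WGAN-GP penalty: mean over cells of (||gradient||_2 - 1) ** 2."""
    norms = [math.sqrt(sum(g * g for g in row)) for row in gradient]
    return sum((norm - 1.0) ** 2 for norm in norms) / len(norms)


gp = gradient_penalty([[3.0, 4.0], [1.0, 0.0]])
# Norms are 5.0 and 1.0, so gp = ((5 - 1) ** 2 + 0) / 2 = 8.0
```
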
generate_cells(cells_no: int, checkpoint: str | bytes | PathLike | None = None) ndarray[source]#

Generate cells from the GAN model.

  • cells_no (int) – Number of cells to generate.

  • checkpoint (Optional[Union[str, bytes, os.PathLike, None]], optional) – Path to the saved trained model, by default None.


Gene expression matrix of generated cells.

Return type:

np.ndarray

_save(path: str | bytes | PathLike) None[source]#

Saves the model.


path (Union[str, bytes, os.PathLike]) – Directory to save the model.

_load(path: str | bytes | PathLike, mode: str | None = 'inference') None[source]#

Loads a saved model (.pth file).

  • path (Union[str, bytes, os.PathLike]) – Path to the saved model.

  • mode (Optional[str], optional) – Specify if the loaded model is used for ‘inference’ or ‘training’, by default “inference”.


ValueError – If a mode other than ‘inference’ or ‘training’ is specified.

_build_model() None[source]#

Instantiates the Generator and Critic.

_get_loaders(train_file: str | bytes | PathLike, validation_file: str | bytes | PathLike) Tuple[DataLoader, DataLoader][source]#

Gets training and validation DataLoaders for training.

  • train_file (Union[str, bytes, os.PathLike]) – Path to training files.

  • validation_file (Union[str, bytes, os.PathLike]) – Path to validation files.


Training and validation DataLoaders.

Return type:

Tuple[DataLoader, DataLoader]

_add_tensorboard_graph(output_dir: str | bytes | PathLike, gen_data: Tensor | Tuple[Tensor], crit_data: Tensor | Tuple[Tensor]) None[source]#

Adds the model graph to TensorBoard.

  • output_dir (Union[str, bytes, os.PathLike]) – Directory to save the tfevents.

  • gen_data (Union[torch.Tensor, Tuple[torch.Tensor]]) – Input to the generator.

  • crit_data (Union[torch.Tensor, Tuple[torch.Tensor]]) – Input to the critic.

_update_tensorboard(gen_loss: float, crit_loss: float, gp: Tensor, gen_lr: float, crit_lr: float, output_dir: str | bytes | PathLike) None[source]#

Updates the TensorBoard summary logs.

  • gen_loss (float) – Generator loss.

  • crit_loss (float) – Critic loss.

  • gp (torch.Tensor) – Gradient penalty.

  • gen_lr (float) – Generator’s optimizer learning rate.

  • crit_lr (float) – Critic’s optimizer learning rate.

  • output_dir (Union[str, bytes, os.PathLike]) – Directory to save the tfevents.

_generate_tsne_plot(valid_loader: DataLoader, output_dir: str | bytes | PathLike) None[source]#

Generates t-SNE plots during training.

  • valid_loader (DataLoader) – Validation set DataLoader.

  • output_dir (Union[str, bytes, os.PathLike]) – Directory to save the t-SNE plots.

_train_critic(real_cells: Tensor, real_labels: Tensor, c_lambda: float) Tuple[Tensor, Tensor][source]#

Trains the critic for one iteration.

  • real_cells (torch.Tensor) – Tensor containing a batch of real cells.

  • real_labels (torch.Tensor) – Tensor containing a batch of real labels (corresponding to real_cells).

  • c_lambda (float) – Regularization hyper-parameter for gradient penalty.


The computed critic loss and gradient penalty.

Return type:

Tuple[torch.Tensor, torch.Tensor]

_train_generator() Tensor[source]#

Trains the generator for one iteration.


Tensor containing only 1 item, the generator loss.

Return type:

torch.Tensor

train(train_files: str, valid_files: str, critic_iter: int, max_steps: int, c_lambda: float, beta1: float, beta2: float, gen_alpha_0: float, gen_alpha_final: float, crit_alpha_0: float, crit_alpha_final: float, checkpoint: str | bytes | PathLike | None = None, output_dir: str | None = 'output', summary_freq: int | None = 5000, plt_freq: int | None = 10000, save_feq: int | None = 10000) None[source]#

Method for training the GAN.

  • train_files (str) – Path to training set files (TFrecords supported for now).

  • valid_files (str) – Path to validation set files (TFrecords supported for now).

  • critic_iter (int) – Number of training iterations of the critic for each iteration on the generator.

  • max_steps (int) – Maximum number of steps to train the GAN.

  • c_lambda (float) – Regularization hyper-parameter for gradient penalty.

  • beta1 (float) – Coefficient used for computing running averages of the gradient in the optimizer.

  • beta2 (float) – Coefficient used for computing running averages of gradient squares in the optimizer.

  • gen_alpha_0 (float) – Generator’s initial learning rate value.

  • gen_alpha_final (float) – Generator’s final learning rate value.

  • crit_alpha_0 (float) – Critic’s initial learning rate value.

  • crit_alpha_final (float) – Critic’s final learning rate value.

  • checkpoint (Optional[Union[str, bytes, os.PathLike, None]], optional) – Path to a trained model; if specified, the checkpoint is used to resume training, by default None.

  • output_dir (Optional[str], optional) – Directory to which plots, tfevents, and checkpoints will be saved, by default “output”.

  • summary_freq (Optional[int], optional) – Period between summary logs to TensorBoard, by default 5000.

  • plt_freq (Optional[int], optional) – Period between t-SNE plots, by default 10000.

  • save_feq (Optional[int], optional) – Period between saves of the model, by default 10000.

