Source code for networks.masked_causal_generator

import itertools
import typing

import torch
from layers.lsn import LSN
from layers.masked_linear import MaskedLinear
from torch import nn
from torch.nn.modules.activation import ReLU


class CausalGenerator(nn.Module):
    def __init__(
        self,
        z_input: int,
        noise_per_gene: int,
        depth_per_gene: int,
        width_scale_per_gene: int,
        causal_controller: nn.Module,
        causal_graph: typing.Dict[int, typing.Set[int]],
        library_size: typing.Optional[int] = None,
        device: typing.Optional[str] = "cuda" if torch.cuda.is_available() else "cpu",
    ) -> None:
        """
        Causal Generator's constructor.

        Parameters
        ----------
        z_input : int
            The dimension of the noise tensor.
        noise_per_gene : int
            Dimension of the latent space from which the noise vectors used by
            the target generators are sampled.
        depth_per_gene : int
            Depth of the target generator networks.
        width_scale_per_gene : int
            The width scale used for the target generator networks. If
            width_scale_per_gene = 2 and a gene is regulated by 10 TFs and 1
            noise vector, the width of that target gene generator will be
            2 * (10 + 1) = 22. Assuming 1000 target genes, each regulated by
            10 TFs and 1 noise vector, the total width of the sparse target
            generator will be 22000.
        causal_controller : nn.Module
            Causal controller module (retrieved from checkpoint if pretrained).
            It is a GAN trained on genes and TFs with the LSN layer removed
            after training. It cannot be trained on TFs only, since the library
            size has to be enforced. However, during causal generator training,
            only TFs are used.
        causal_graph : typing.Dict[int, typing.Set[int]]
            The causal graph is a dictionary representing the TRN to impose.
            It has the following format:
            {target gene index: {TF1 index, TF2 index, ...}}.
            This causal graph has to be acyclic and bipartite: a TF cannot be
            regulated by another TF.
            Invalid: {1: {2, 3, {4, 6}}, ...}
                - a regulator (TF) is regulated by another regulator (TF)
            Invalid: {1: {2, 3, 4}, 2: {4, 3, 5}, ...}
                - a regulator (TF) is also regulated
            Invalid: {4: {2, 3}, 2: {4, 3}}
                - contains a cycle
            Valid causal graph example: {1: {2, 3, 4}, 6: {5, 4, 2}, ...}
        library_size : typing.Optional[int], optional
            Total number of counts per generated cell, by default None.
        device : typing.Optional[str], optional
            Specifies whether to train on 'cpu' or 'cuda'. Only 'cuda' is
            supported for training the GAN, but 'cpu' can be used for
            inference, by default "cuda" if torch.cuda.is_available() else "cpu".
        """
        super().__init__()
        self.z_input = z_input
        self.noise_per_gene = noise_per_gene
        self.depth_per_gene = depth_per_gene
        self.width_scale_per_gene = width_scale_per_gene
        self.causal_graph = causal_graph
        self.library_size = library_size
        self.device = device
        self._causal_controller = causal_controller
        self._generator = None

        self.genes = list(self.causal_graph.keys())
        # all gene-regulating TFs (can contain duplicate TFs)
        self.regulators = list(
            itertools.chain.from_iterable(self.causal_graph.values())
        )
        self.tfs = list(set(self.regulators))

        # if a gene has X regulators (TFs + noises), it will have a hidden
        # layer with a width of (hidden_width * num_regulators)
        self.num_genes = len(self.genes)
        self.num_noises = self.num_genes * self.noise_per_gene  # 1 noise vector / gene
        self.num_tfs = len(self.tfs)  # number of TFs

        # For performing perturbation studies.
        # In perturbation mode, TF expressions, noise vectors, and the LSN layer are frozen.
        self.pert_mode = False
        self.tf_expressions = None
        self.noise = None

        self._lsn = LSN(self.library_size)
        self._create_generator()
        self._create_labeler()
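    # Sizing note (hypothetical toy example, not taken from the original source):
    # for a graph such as causal_graph = {2: {0, 1}, 3: {1}}, the target genes are
    # 2 and 3 and the TFs are {0, 1}. With width_scale_per_gene = 2 and
    # noise_per_gene = 1, gene 2 gets a hidden width of 2 * (2 + 1) = 6 and gene 3
    # a hidden width of 2 * (1 + 1) = 4, so the sparse generator's total hidden
    # width is 10. This mirrors the 2 * (10 + 1) = 22 example in the docstring above.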
    def forward(self, noise: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """
        Function for completing a forward pass of the generator.

        This includes a forward pass of the causal controller to generate TFs.
        TFs and generated noise are then used to complete a forward pass of the
        causal generator.

        Parameters
        ----------
        noise : torch.Tensor
            The noise used as input by the causal controller.
        *args
            Variable length argument list.
        **kwargs
            Arbitrary keyword arguments.

        Returns
        -------
        torch.Tensor
            The output of the causal generator (gene expression matrix).
        """
        tf_expressions = self._causal_controller(noise)
        tf_expressions = tf_expressions[:, self.tfs]
        tf_expressions = tf_expressions.detach()

        # use the same TF expressions as the previous forward pass in perturbation mode
        if self.pert_mode:
            if self.tf_expressions is not None:
                tf_expressions = self.tf_expressions
            else:
                self.tf_expressions = tf_expressions

        batch_size = tf_expressions.shape[0]

        # create placeholder for cells
        cells = torch.zeros(batch_size, self.num_tfs + self.num_genes).to(self.device)
        cells = cells.index_add_(
            1, torch.tensor(self.tfs).to(self.device), tf_expressions
        )

        # lazy way of avoiding a circular dependency
        # FIXME: circular dependency
        from gans.gan import GAN

        noise = GAN._generate_noise(batch_size, self.num_noises, self.device)

        # use the same noise as the previous forward pass in perturbation mode
        if self.pert_mode:
            if self.noise is not None:
                noise = self.noise
            else:
                self.noise = noise

        regulators = torch.cat([tf_expressions, noise], dim=1)
        gene_expression = self._generator(regulators)

        cells = cells.index_add_(
            1, torch.tensor(self.genes).to(self.device), gene_expression
        )

        if self.library_size is not None:
            # reuse previous LSN scale in perturbation mode
            cells = self._lsn(cells, reuse_scale=self.pert_mode)

        return cells
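    # Perturbation-mode note: with `pert_mode = True`, the first forward pass
    # caches `tf_expressions` and `noise`, and later passes reuse the cached
    # values (and ask the LSN layer to reuse its previous scale), e.g. to compare
    # a baseline pass against a pass with edited TF expressions. Hypothetical
    # sketch (names and indices are illustrative only):
    #
    #     generator.pert_mode = True
    #     baseline = generator(z)               # caches TF expressions and noise
    #     generator.tf_expressions[:, 0] = 0.0  # e.g. silence the first cached TF
    #     perturbed = generator(z)              # reuses the cached values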
    def _create_generator(self) -> None:
        """
        Method for creating the Causal Generator's network.

        An independent generator could be created for each gene. In that case,
        a pass of the causal generator would require a pass of each gene's
        generator network individually in a loop (since all gene expressions
        are needed before being passed to the LSN layer), which is very
        inefficient. Instead, we create a single large Causal Generator
        containing sparse connections to logically create independent
        generators for each gene. This is done by creating 3 masks:

        input mask:
            contains connections between genes and their regulating TFs/noise
        hidden mask:
            contains connections between hidden layers such that there is no
            connection between the hidden layers of two genes' generators
        output mask:
            contains connections between the hidden layers of each gene's
            generator and its expression (before LSN)

        The MaskedLinear module is used to mask weights and gradients in
        linear layers.
        """
        hidden_dims = (
            len(self.regulators) + self.num_noises
        ) * self.width_scale_per_gene

        # noise mask will be added to the TF mask
        input_mask = torch.zeros(self.num_tfs, hidden_dims).to(self.device)
        hidden_mask = torch.zeros(hidden_dims, hidden_dims).to(self.device)
        output_mask = torch.zeros(hidden_dims, self.num_genes).to(self.device)

        prev_gene_hidden_dims = 0
        for gene, gene_regulators in self.causal_graph.items():
            gene_idx = self.genes.index(gene)
            curr_gene_hidden_dims = self.width_scale_per_gene * (
                len(gene_regulators) + self.noise_per_gene
            )

            for gene_regulator in gene_regulators:
                gene_regulator_idx = self.tfs.index(gene_regulator)
                # mask for the TFs
                input_mask[
                    gene_regulator_idx,
                    prev_gene_hidden_dims : prev_gene_hidden_dims
                    + curr_gene_hidden_dims,
                ] = 1

            # mask for the noises
            noise_mask = torch.zeros(self.noise_per_gene, hidden_dims).to(self.device)
            noise_mask[
                :, prev_gene_hidden_dims : prev_gene_hidden_dims + curr_gene_hidden_dims
            ] = 1
            input_mask = torch.cat([input_mask, noise_mask])

            # mask for the hidden layers
            hidden_mask[
                prev_gene_hidden_dims : prev_gene_hidden_dims + curr_gene_hidden_dims,
                prev_gene_hidden_dims : prev_gene_hidden_dims + curr_gene_hidden_dims,
            ] = 1

            # mask for the final layer
            output_mask[
                prev_gene_hidden_dims : prev_gene_hidden_dims + curr_gene_hidden_dims,
                gene_idx,
            ] = 1

            prev_gene_hidden_dims += curr_gene_hidden_dims

        generator_layers = nn.ModuleList()
        # input block
        generator_layers.append(self._create_generator_block(input_mask))
        # hidden blocks
        for _ in range(self.depth_per_gene):
            generator_layers.append(self._create_generator_block(hidden_mask))
        # output block
        generator_layers.append(
            self._create_generator_block(output_mask, final_layer=True)
        )
        self._generator = nn.Sequential(*generator_layers)
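    # Mask-shape walkthrough (same hypothetical toy graph as above, for
    # illustration only): with causal_graph = {2: {0, 1}, 3: {1}},
    # width_scale_per_gene = 2 and noise_per_gene = 1, hidden_dims = (3 + 2) * 2 = 10.
    # The input mask starts as (num_tfs=2, 10) and grows by one (1, 10) noise block
    # per gene, ending at (4, 10), which matches the 4-column `regulators` tensor
    # (2 TF columns + 2 noise columns) built in forward(). The hidden mask is a
    # (10, 10) block-diagonal matrix with a 6x6 block for gene 2 and a 4x4 block for
    # gene 3, and the output mask is (10, num_genes=2), routing each gene's private
    # hidden units to its own output column.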
    def _create_labeler(self):
        self._labeler = nn.Sequential(
            nn.Linear(self.num_genes, self.num_genes * 2),
            nn.BatchNorm1d(self.num_genes * 2),
            nn.ReLU(inplace=True),
            nn.Linear(self.num_genes * 2, self.num_genes * 2),
            nn.BatchNorm1d(self.num_genes * 2),
            nn.ReLU(inplace=True),
            nn.Linear(self.num_genes * 2, self.num_genes * 2),
            nn.BatchNorm1d(self.num_genes * 2),
            nn.ReLU(inplace=True),
            nn.Linear(self.num_genes * 2, self.num_tfs),
        )
    def _create_generator_block(
        self,
        mask: torch.Tensor,
        library_size: typing.Optional[int] = None,
        final_layer: typing.Optional[bool] = False,
    ) -> nn.Sequential:
        """
        Method for creating a sequence of operations corresponding to a masked
        causal generator block: a masked linear layer, a batch norm (except in
        the final block), and a ReLU.

        Parameters
        ----------
        mask : torch.Tensor
            Mask tensor with shape (n_input_feature, n_output_feature).
        library_size : typing.Optional[int], optional
            Total number of counts per generated cell, by default None.
        final_layer : typing.Optional[bool], optional
            Indicates if the block contains the final layer, by default False.

        Returns
        -------
        nn.Sequential
            Sequential container containing the modules.
        """
        masked_linear = MaskedLinear(mask, device=self.device)
        if not final_layer:
            nn.init.xavier_uniform_(masked_linear.weight)
            masked_linear.reapply_mask()
            return nn.Sequential(
                masked_linear,
                nn.BatchNorm1d(mask.shape[1]),
                nn.ReLU(inplace=True),
            )
        else:
            nn.init.kaiming_normal_(
                masked_linear.weight, mode="fan_in", nonlinearity="relu"
            )
            masked_linear.reapply_mask()
            torch.nn.init.zeros_(masked_linear.bias)
            if library_size is not None:
                return nn.Sequential(masked_linear, ReLU(), LSN(library_size))
            else:
                return nn.Sequential(masked_linear, ReLU())
    def freeze_causal_controller(self):
        """Freezes the pretrained causal controller and disallows any further updates."""
        for param in self._causal_controller.parameters():
            param.requires_grad = False
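
# Minimal usage sketch (hypothetical, not part of the original module). It assumes
# the surrounding package (layers.lsn, layers.masked_linear, gans.gan) is importable,
# and uses a plain nn.Linear as a stand-in for a pretrained causal controller; the
# graph and hyperparameters below are illustrative only.
if __name__ == "__main__":
    z_input = 128
    # hypothetical TRN: genes 2 and 3 regulated by TFs 0 and 1
    causal_graph = {2: {0, 1}, 3: {1}}
    # placeholder for a pretrained causal controller checkpoint
    controller = nn.Linear(z_input, 4)

    generator = CausalGenerator(
        z_input=z_input,
        noise_per_gene=1,
        depth_per_gene=2,
        width_scale_per_gene=2,
        causal_controller=controller,
        causal_graph=causal_graph,
        library_size=2000,
        device="cpu",  # CPU inference; training is expected to run on CUDA
    )
    generator.freeze_causal_controller()

    z = torch.randn(8, z_input)
    cells = generator(z)  # (8, num_tfs + num_genes) gene expression matrix
    print(cells.shape)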