"""networks.generator: unconditional and conditional generator networks."""

import typing

import torch
from layers.cbn import ConditionalBatchNorm
from layers.lsn import LSN
from torch import nn
from torch.nn.modules.activation import ReLU


class Generator(nn.Module):
    """Fully-connected, non-conditional generator network.

    The network is a stack of ``Linear -> BatchNorm1d -> ReLU`` blocks, one per
    entry of ``gen_layers``, followed by a final ``Linear -> ReLU`` block that
    projects to ``output_cells_dim`` genes. When ``library_size`` is given, an
    ``LSN`` layer is appended to the final block.
    """

    def __init__(
        self,
        z_input: int,
        output_cells_dim: int,
        gen_layers: typing.List[int],
        library_size: typing.Optional[int] = None,
    ) -> None:
        """
        Non-conditional Generator's constructor.

        Parameters
        ----------
        z_input : int
            The dimension of the noise tensor.
        output_cells_dim : int
            The dimension of the output cells (number of genes).
        gen_layers : typing.List[int]
            List of integers corresponding to the number of neurons at each
            hidden layer of the generator. An empty list yields a network made
            of the final block only.
        library_size : typing.Optional[int], optional
            Total number of counts per generated cell. When None (default),
            no LSN layer is added to the final block.
        """
        super().__init__()

        self.z_input = z_input
        self.output_cells_dim = output_cells_dim
        self.gen_layers = gen_layers
        self.library_size = library_size

        # Built last: _create_generator reads the attributes above, and
        # subclasses (e.g. ConditionalGenerator) override it and may rely on
        # attributes they set before calling this constructor.
        self._create_generator()

    def forward(self, noise: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """
        Function for completing a forward pass of the generator.

        Parameters
        ----------
        noise : torch.Tensor
            The noise used as input by the generator.
        *args
            Ignored by this class (presumably accepted so callers can use the
            same call signature as for ConditionalGenerator).
        **kwargs
            Ignored by this class.

        Returns
        -------
        torch.Tensor
            The output of the generator (genes of the generated cell).
        """
        return self._generator(noise)

    def _create_generator(self) -> None:
        """Build ``self._generator`` as an ``nn.Sequential`` of generator blocks."""
        blocks = []
        in_features = self.z_input
        for out_features in self.gen_layers:
            blocks.append(self._create_generator_block(in_features, out_features))
            in_features = out_features  # next block consumes this block's output

        # outermost layer: projects the last hidden representation to gene space
        blocks.append(
            self._create_generator_block(
                in_features, self.output_cells_dim, self.library_size, final_layer=True
            )
        )
        self._generator = nn.Sequential(*blocks)

    @staticmethod
    def _create_generator_block(
        input_dim: int,
        output_dim: int,
        library_size: typing.Optional[int] = None,
        final_layer: bool = False,
        *args,
        **kwargs
    ) -> nn.Sequential:
        """
        Function for creating a sequence of operations corresponding to a
        Generator block: a linear layer, a batchnorm (except in the final
        block), a ReLU, and LSN in the final block when ``library_size`` is set.

        Parameters
        ----------
        input_dim : int
            The block's input dimensions.
        output_dim : int
            The block's output dimensions.
        library_size : typing.Optional[int], optional
            Total number of counts per generated cell, by default None.
            Only used for the final block.
        final_layer : bool, optional
            Indicates if the block contains the final layer, by default False.
        *args
            Ignored.
        **kwargs
            Ignored.

        Returns
        -------
        nn.Sequential
            Sequential container containing the modules.
        """
        linear_layer = nn.Linear(input_dim, output_dim)

        if not final_layer:
            # Hidden block: Xavier-uniform weights, default bias init.
            nn.init.xavier_uniform_(linear_layer.weight)
            return nn.Sequential(
                linear_layer,
                nn.BatchNorm1d(output_dim),
                nn.ReLU(inplace=True),
            )

        # NOTE: a variance-scaling initializer with FAN_AVG mode was not
        # available, so Kaiming normal with fan_in is used for the final layer.
        nn.init.kaiming_normal_(
            linear_layer.weight, mode="fan_in", nonlinearity="relu"
        )
        nn.init.zeros_(linear_layer.bias)

        if library_size is None:
            return nn.Sequential(linear_layer, ReLU())
        return nn.Sequential(linear_layer, ReLU(), LSN(library_size))
class ConditionalGenerator(Generator):
    """Generator conditioned on cluster labels through conditional batch norm.

    Hidden blocks are ``Linear -> ConditionalBatchNorm -> ReLU``; the modules
    of each hidden block are stored flat in an ``nn.ModuleList`` so that
    ``forward`` can pass the labels to every ``ConditionalBatchNorm``. The
    final block is stored as a single ``nn.Sequential``.
    """

    def __init__(
        self,
        z_input: int,
        output_cells_dim: int,
        num_classes: int,
        gen_layers: typing.List[int],
        library_size: typing.Optional[int] = None,
    ) -> None:
        """
        Conditional Generator's constructor.

        Parameters
        ----------
        z_input : int
            The dimension of the noise tensor.
        output_cells_dim : int
            The dimension of the output cells (number of genes).
        num_classes : int
            Number of clusters.
        gen_layers : typing.List[int]
            List of integers corresponding to the number of neurons at each
            hidden layer of the generator.
        library_size : typing.Optional[int], optional
            Total number of counts per generated cell, by default None.
        """
        # Must be set before Generator.__init__, which calls
        # self._create_generator() and that method reads self.num_classes.
        self.num_classes = num_classes
        super().__init__(z_input, output_cells_dim, gen_layers, library_size)

    def forward(
        self,
        noise: torch.Tensor,
        labels: typing.Optional[torch.Tensor] = None,
        *args,
        **kwargs
    ) -> torch.Tensor:
        """
        Function for completing a forward pass of the generator.

        Parameters
        ----------
        noise : torch.Tensor
            The noise used as input by the generator.
        labels : typing.Optional[torch.Tensor], optional
            Tensor containing labels corresponding to cells to generate. It is
            forwarded unchanged to every ConditionalBatchNorm layer.
        *args
            Ignored.
        **kwargs
            Ignored.

        Returns
        -------
        torch.Tensor
            The output of the generator (genes of the generated cell).
        """
        hidden = noise
        for module in self._generator:
            # Only conditional batch-norm layers consume the labels.
            if isinstance(module, ConditionalBatchNorm):
                hidden = module(hidden, labels)
            else:
                hidden = module(hidden)
        return hidden

    def _create_generator(self) -> None:
        """Build ``self._generator`` as an ``nn.ModuleList`` of generator modules."""
        self._generator = nn.ModuleList()
        in_features = self.z_input
        for out_features in self.gen_layers:
            # Hidden blocks come back as a tuple of modules; they are flattened
            # into the ModuleList so forward() can route labels to each
            # ConditionalBatchNorm individually.
            self._generator.extend(
                self._create_generator_block(
                    in_features, out_features, num_classes=self.num_classes
                )
            )
            in_features = out_features  # next block consumes this block's output

        # outermost layer: appended as one nn.Sequential (no conditional norm)
        self._generator.append(
            self._create_generator_block(
                in_features,
                self.output_cells_dim,
                self.library_size,
                final_layer=True,
                num_classes=self.num_classes,
            )
        )

    @staticmethod
    def _create_generator_block(
        input_dim: int,
        output_dim: int,
        library_size: typing.Optional[int] = None,
        final_layer: bool = False,
        num_classes: typing.Optional[int] = None,
        *args,
        **kwargs
    ) -> typing.Union[nn.Sequential, tuple]:
        """
        Function for creating a sequence of operations corresponding to a
        Conditional Generator block: a linear layer, a conditional batchnorm
        (except in the final block), a ReLU, and LSN in the final block when
        ``library_size`` is set.

        Parameters
        ----------
        input_dim : int
            The block's input dimensions.
        output_dim : int
            The block's output dimensions.
        library_size : typing.Optional[int], optional
            Total number of counts per generated cell, by default None.
            Only used for the final block.
        final_layer : bool, optional
            Indicates if the block contains the final layer, by default False.
        num_classes : typing.Optional[int], optional
            Number of clusters; passed to ConditionalBatchNorm in hidden blocks.
        *args
            Ignored.
        **kwargs
            Ignored.

        Returns
        -------
        typing.Union[nn.Sequential, tuple]
            A tuple of modules for hidden blocks, or an nn.Sequential for the
            final block.
        """
        linear_layer = nn.Linear(input_dim, output_dim)

        if not final_layer:
            # Returned as a plain tuple so the caller can flatten it.
            nn.init.xavier_uniform_(linear_layer.weight)
            return (
                linear_layer,
                ConditionalBatchNorm(output_dim, num_classes),
                nn.ReLU(inplace=True),
            )

        nn.init.kaiming_normal_(
            linear_layer.weight, mode="fan_in", nonlinearity="relu"
        )
        nn.init.zeros_(linear_layer.bias)

        if library_size is None:
            return nn.Sequential(linear_layer, ReLU())
        return nn.Sequential(linear_layer, ReLU(), LSN(library_size))