"""Critic networks (module ``networks.critic``)."""

import typing

import torch
from torch import nn


class Critic(nn.Module):
    """
    Non-conditional critic: a stack of hidden ``Linear -> ReLU`` blocks
    followed by a single-unit linear output block with no activation.
    """

    def __init__(self, x_input: int, critic_layers: typing.List[int]) -> None:
        """
        Non-conditional Critic's constructor.

        Parameters
        ----------
        x_input : int
            The dimension of the input tensor (size of its last axis).
        critic_layers : typing.List[int]
            Number of neurons in each hidden layer, ordered from input to
            output. An empty list yields a critic made of the output block only.
        """
        super().__init__()
        self.x_input = x_input
        self.critic_layers = critic_layers
        # Polymorphic call: subclasses override _create_critic to build
        # their own network layout.
        self._create_critic()

    def forward(self, data: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """
        Function for completing a forward pass of the critic.

        Parameters
        ----------
        data : torch.Tensor
            Tensor containing gene expression of (fake/real) cells; its last
            dimension must equal ``x_input``.
        *args
            Ignored; accepted for signature compatibility with conditional
            critics.
        **kwargs
            Ignored; accepted for signature compatibility with conditional
            critics.

        Returns
        -------
        torch.Tensor
            Critic scores with the same leading dimensions as ``data`` and a
            last dimension of size 1 (``(batch, 1)`` for 2-D input).
        """
        network = self._critic
        return network(data)

    def _create_critic(self) -> None:
        """Build ``self._critic`` as an ``nn.Sequential`` of critic blocks."""
        widths = [self.x_input, *self.critic_layers]
        blocks = [
            self._create_critic_block(fan_in, fan_out)
            for fan_in, fan_out in zip(widths[:-1], widths[1:])
        ]
        # Output block: maps the last hidden width to a single score.
        blocks.append(self._create_critic_block(widths[-1], 1, final_layer=True))
        self._critic = nn.Sequential(*blocks)

    @staticmethod
    def _create_critic_block(
        input_dim: int, output_dim: int, final_layer: typing.Optional[bool] = False
    ) -> nn.Sequential:
        """
        Create one critic block.

        A hidden block is a Kaiming-initialised linear layer followed by an
        in-place ReLU. The final block is a single Xavier-initialised linear
        layer without activation. Biases are always zero-initialised.

        Parameters
        ----------
        input_dim : int
            The block's input dimensions.
        output_dim : int
            The block's output dimensions.
        final_layer : typing.Optional[bool], optional
            Truthy if the block is the critic's output layer, by default False.

        Returns
        -------
        nn.Sequential
            Sequential container holding the block's modules.
        """
        linear = nn.Linear(input_dim, output_dim)
        torch.nn.init.zeros_(linear.bias)
        if final_layer:
            # No activation on the output layer, so critic scores are unbounded.
            torch.nn.init.xavier_uniform_(linear.weight)
            return nn.Sequential(linear)
        torch.nn.init.kaiming_normal_(
            linear.weight, mode="fan_in", nonlinearity="relu"
        )
        return nn.Sequential(linear, nn.ReLU(inplace=True))
class ConditionalCritic(Critic):
    """
    Conditional critic based on the Projection Discriminator
    (Miyato et al., 2018).

    With ``h`` the output of the last hidden block and ``c`` the class labels,
    the score is ``linear(h) + sum(embed(c) * h, dim=1)``.

    ``self._critic`` is an ``nn.ModuleList`` laid out as
    ``[hidden blocks..., output block, projection embedding]``. ``forward``
    indexes it positionally, and ``state_dict`` keys depend on it, so keep
    that order.
    """

    def __init__(
        self, x_input: int, critic_layers: typing.List[int], num_classes: int
    ) -> None:
        """
        Conditional Critic's constructor - Projection Discriminator
        (Miyato et al., 2018).

        Parameters
        ----------
        x_input : int
            The dimension of the input tensor.
        critic_layers : typing.List[int]
            List of integers corresponding to the number of neurons at each
            hidden layer of the critic.
        num_classes : int
            Number of clusters (rows of the projection embedding).
        """
        # Must be assigned *before* Critic.__init__: that constructor calls
        # self._create_critic(), which reads self.num_classes to size the
        # projection embedding. This is safe because num_classes is a plain
        # int, not a Module or Parameter.
        self.num_classes = num_classes
        super(ConditionalCritic, self).__init__(x_input, critic_layers)

    def forward(
        self,
        data: torch.Tensor,
        labels: typing.Optional[torch.Tensor] = None,
        *args,
        **kwargs
    ) -> torch.Tensor:
        """
        Function for completing a forward pass of the conditional critic.

        Parameters
        ----------
        data : torch.Tensor
            Tensor containing gene expression of (fake/real) cells, shape
            ``(N, x_input)``.
        labels : torch.Tensor
            Integer class indices for the cells in ``data``. The shape must be
            ``(N,)`` so the projection term lines up row-by-row with
            ``data``. Required despite the ``None`` default.
        *args
            Variable length argument list (ignored).
        **kwargs
            Arbitrary keyword arguments (ignored).

        Returns
        -------
        torch.Tensor
            Critic scores of shape ``(N, 1)``.

        Raises
        ------
        TypeError
            If ``labels`` is None.
        """
        if labels is None:
            # Fail early with a clear message instead of an opaque error from
            # nn.Embedding. TypeError is the same exception type that path raised.
            raise TypeError(
                "ConditionalCritic.forward() requires `labels` (class indices)."
            )
        hidden = data
        for block in self._critic[:-2]:
            hidden = block(hidden)
        output = self._critic[-2](hidden)  # unconditional term, (N, 1)
        proj = self._critic[-1](labels)  # class embeddings, (N, hidden width)
        # Per-row inner product <embed(c), h>, kept 2-D to match `output`.
        # The add is out-of-place so no autograd-tracked tensor is mutated.
        return output + torch.sum(proj * hidden, dim=1, keepdim=True)

    def _create_critic(self) -> None:
        """Method for creating the Conditional Critic's network."""
        self._critic = nn.ModuleList()
        input_size = self.x_input
        for output_size in self.critic_layers:
            self._critic.append(self._create_critic_block(input_size, output_size))
            input_size = output_size  # update input size for the next layer
        # Outermost layer: unconditional single-unit linear head.
        self._critic.append(
            self._create_critic_block(input_size, 1, final_layer=True)
        )
        # Projection layer: one embedding of the last hidden width per class.
        proj_layer = nn.Embedding(self.num_classes, input_size)
        nn.init.xavier_uniform_(proj_layer.weight)
        self._critic.append(proj_layer)
class ConditionalCriticProj(Critic):
    """
    Conditional critic using a modified Projection Discriminator
    (Marouf et al., 2020).

    There is no output linear layer. With ``h`` the output of the last hidden
    block and ``c`` the class labels, the score is
    ``bias[c] + sum(embed(c) * h, dim=1)``, where ``bias`` is a per-class
    learned scalar.

    ``self._critic`` is an ``nn.ModuleList`` laid out as
    ``[hidden blocks..., bias embedding, projection embedding]``. ``forward``
    indexes it positionally, and ``state_dict`` keys depend on it, so keep
    that order.
    """

    def __init__(
        self, x_input: int, critic_layers: typing.List[int], num_classes: int
    ) -> None:
        """
        Conditional Critic's constructor using a modified implementation of
        Projection Discriminator (Marouf et al, 2020).

        Parameters
        ----------
        x_input : int
            The dimension of the input tensor.
        critic_layers : typing.List[int]
            List of integers corresponding to the number of neurons at each
            hidden layer of the critic.
        num_classes : int
            Number of clusters (rows of the bias and projection embeddings).
        """
        # Must be assigned *before* Critic.__init__: that constructor calls
        # self._create_critic(), which reads self.num_classes to size the
        # embeddings. This is safe because num_classes is a plain int, not a
        # Module or Parameter.
        self.num_classes = num_classes
        super(ConditionalCriticProj, self).__init__(x_input, critic_layers)

    def forward(
        self,
        data: torch.Tensor,
        labels: typing.Optional[torch.Tensor] = None,
        *args,
        **kwargs
    ) -> torch.Tensor:
        """
        Function for completing a forward pass of the conditional critic.

        Parameters
        ----------
        data : torch.Tensor
            Tensor containing gene expression of (fake/real) cells, shape
            ``(N, x_input)``.
        labels : torch.Tensor
            Integer class indices for the cells in ``data``. The shape must be
            ``(N,)`` so both embedding lookups line up row-by-row with
            ``data``. Required despite the ``None`` default.
        *args
            Variable length argument list (ignored).
        **kwargs
            Arbitrary keyword arguments (ignored).

        Returns
        -------
        torch.Tensor
            Critic scores of shape ``(N, 1)``.

        Raises
        ------
        TypeError
            If ``labels`` is None.
        """
        if labels is None:
            # Fail early with a clear message instead of an opaque error from
            # nn.Embedding. TypeError is the same exception type that path raised.
            raise TypeError(
                "ConditionalCriticProj.forward() requires `labels` (class indices)."
            )
        hidden = data
        for block in self._critic[:-2]:
            hidden = block(hidden)
        bias = self._critic[-2](labels)  # per-class scalar bias, (N, 1)
        proj = self._critic[-1](labels)  # class embeddings, (N, hidden width)
        # Per-row inner product <embed(c), h>, kept 2-D to match `bias`.
        # The add is out-of-place rather than an in-place add on the
        # embedding output.
        return bias + torch.sum(proj * hidden, dim=1, keepdim=True)

    def _create_critic(self) -> None:
        """Method for creating the Conditional Critic's network."""
        self._critic = nn.ModuleList()
        input_size = self.x_input
        for output_size in self.critic_layers:
            self._critic.append(self._create_critic_block(input_size, output_size))
            input_size = output_size  # update input size for the next layer
        # Bias layer (replaces the output linear layer). Zero-initialised, so
        # every class starts with no bias.
        proj_bias = nn.Embedding(self.num_classes, 1)
        torch.nn.init.zeros_(proj_bias.weight)
        self._critic.append(proj_bias)
        # Projection layer: one embedding of the last hidden width per class.
        proj_layer = nn.Embedding(self.num_classes, input_size)
        nn.init.xavier_uniform_(proj_layer.weight)
        self._critic.append(proj_layer)