import typing
import torch
from torch import nn
class Critic(nn.Module):
    """Non-conditional critic: hidden Linear+ReLU blocks followed by a
    single-unit linear output layer with no activation."""

    def __init__(self, x_input: int, critic_layers: typing.List[int]) -> None:
        """
        Non-conditional Critic's constructor.

        Parameters
        ----------
        x_input : int
            The dimension of the input tensor.
        critic_layers : typing.List[int]
            List of integers corresponding to the number of neurons
            at each hidden layer of the critic.
        """
        super().__init__()
        self.x_input = x_input
        self.critic_layers = critic_layers
        # _create_critic() reads the attributes above and may be overridden
        # by subclasses, so it must run last.
        self._create_critic()

    def forward(self, data: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        """
        Function for completing a forward pass of the critic.

        Parameters
        ----------
        data : torch.Tensor
            Tensor containing gene expression of (fake/real) cells; its last
            dimension must equal ``x_input``.
        *args
            Accepted for interface compatibility with conditional critics;
            ignored.
        **kwargs
            Accepted for interface compatibility with conditional critics;
            ignored.

        Returns
        -------
        torch.Tensor
            Critic scores for fake/real cells. The last dimension has size 1,
            e.g. shape (batch, 1) for a (batch, x_input) input.
        """
        return self._critic(data)

    def _create_critic(self) -> None:
        """Build ``self._critic``: one hidden block per entry of
        ``critic_layers`` followed by a single-unit output block."""
        sizes = [self.x_input, *self.critic_layers]
        # Consecutive (in, out) pairs give the hidden blocks, in order.
        layers = [
            self._create_critic_block(in_dim, out_dim)
            for in_dim, out_dim in zip(sizes[:-1], sizes[1:])
        ]
        # outermost layer
        layers.append(self._create_critic_block(sizes[-1], 1, final_layer=True))
        self._critic = nn.Sequential(*layers)

    @staticmethod
    def _create_critic_block(
        input_dim: int, output_dim: int, final_layer: bool = False
    ) -> nn.Sequential:
        """
        Function for creating a sequence of operations corresponding to
        a Critic block; a linear layer, and ReLU (except in the final block).

        Parameters
        ----------
        input_dim : int
            The block's input dimensions.
        output_dim : int
            The block's output dimensions.
        final_layer : bool, optional
            Indicates if the block contains the final layer, by default False.

        Returns
        -------
        nn.Sequential
            Sequential container containing the modules.
        """
        linear_layer = nn.Linear(input_dim, output_dim)
        torch.nn.init.zeros_(linear_layer.bias)
        if final_layer:
            # No activation at the outermost layer: the critic's score is an
            # unbounded real value.
            torch.nn.init.xavier_uniform_(linear_layer.weight)
            return nn.Sequential(linear_layer)
        # He (Kaiming) init matches the ReLU that follows this layer.
        torch.nn.init.kaiming_normal_(
            linear_layer.weight, mode="fan_in", nonlinearity="relu"
        )
        return nn.Sequential(linear_layer, nn.ReLU(inplace=True))
class ConditionalCritic(Critic):
    """Conditional critic using a projection term (Miyato et al., 2018):
    score = linear(features) + <embedding(label), features>."""

    def __init__(
        self, x_input: int, critic_layers: typing.List[int], num_classes: int
    ) -> None:
        """
        Conditional Critic's constructor - Projection Discriminator (Miyato et al., 2018).

        Parameters
        ----------
        x_input : int
            The dimension of the input tensor.
        critic_layers : typing.List[int]
            List of integers corresponding to the number of neurons
            at each hidden layer of the critic.
        num_classes : int
            Number of clusters.
        """
        # Must be set before Critic.__init__, which calls the overridden
        # _create_critic() below, and that method reads num_classes.
        self.num_classes = num_classes
        super().__init__(x_input, critic_layers)

    def forward(
        self, data: torch.Tensor, labels: torch.Tensor = None, *args, **kwargs
    ) -> torch.Tensor:
        """
        Function for completing a forward pass of the conditional critic.

        Parameters
        ----------
        data : torch.Tensor
            Tensor containing gene expression of (fake/real) cells, shape
            (batch, x_input); the projection sums over dim 1.
        labels : torch.Tensor
            Integer class indices for the cells in ``data`` (they index an
            ``nn.Embedding``). Required despite the ``None`` default.
            NOTE(review): presumably shape (batch,); a (batch, 1) tensor
            would silently broadcast to the wrong shape - confirm with callers.
        *args
            Accepted for interface compatibility; ignored.
        **kwargs
            Accepted for interface compatibility; ignored.

        Returns
        -------
        torch.Tensor
            Critic scores of shape (batch, 1) for fake/real cells.

        Raises
        ------
        TypeError
            If ``labels`` is not provided.
        """
        if labels is None:
            raise TypeError("ConditionalCritic.forward() requires `labels`.")
        # Layout built by _create_critic: hidden blocks..., output block, embedding.
        *hidden_blocks, output_block, projection = self._critic
        features = data
        for block in hidden_blocks:
            features = block(features)
        score = output_block(features)
        # Projection term: inner product of the class embedding with the
        # last hidden features.
        return score + torch.sum(projection(labels) * features, dim=1, keepdim=True)

    def _create_critic(self) -> None:
        """Build ``self._critic`` as a ModuleList of the hidden blocks, the
        single-unit output block, and the class-projection embedding.

        forward() relies on this exact order, and so do the ModuleList
        indices that appear in ``state_dict`` keys."""
        sizes = [self.x_input, *self.critic_layers]
        self._critic = nn.ModuleList(
            self._create_critic_block(in_dim, out_dim)
            for in_dim, out_dim in zip(sizes[:-1], sizes[1:])
        )
        feature_size = sizes[-1]
        # outermost layer
        self._critic.append(
            self._create_critic_block(feature_size, 1, final_layer=True)
        )
        # projection layer: one embedding vector per class, sized like the
        # last hidden features so it can be dotted with them.
        proj_layer = nn.Embedding(self.num_classes, feature_size)
        nn.init.xavier_uniform_(proj_layer.weight)
        self._critic.append(proj_layer)
class ConditionalCriticProj(Critic):
    """Modified projection critic (Marouf et al., 2020): the linear output
    layer is replaced by a per-class scalar bias, so
    score = bias(label) + <embedding(label), features>."""

    def __init__(
        self, x_input: int, critic_layers: typing.List[int], num_classes: int
    ) -> None:
        """
        Conditional Critic's constructor using a modified implementation of
        Projection Discriminator (Marouf et al, 2020).

        Parameters
        ----------
        x_input : int
            The dimension of the input tensor.
        critic_layers : typing.List[int]
            List of integers corresponding to the number of neurons
            at each hidden layer of the critic.
        num_classes : int
            Number of clusters.
        """
        # Must be set before Critic.__init__, which calls the overridden
        # _create_critic() below, and that method reads num_classes.
        self.num_classes = num_classes
        super().__init__(x_input, critic_layers)

    def forward(
        self, data: torch.Tensor, labels: torch.Tensor = None, *args, **kwargs
    ) -> torch.Tensor:
        """
        Function for completing a forward pass of the conditional critic.

        Parameters
        ----------
        data : torch.Tensor
            Tensor containing gene expression of (fake/real) cells, shape
            (batch, x_input); the projection sums over dim 1.
        labels : torch.Tensor
            Integer class indices for the cells in ``data`` (they index
            ``nn.Embedding`` layers). Required despite the ``None`` default.
            NOTE(review): presumably shape (batch,); a (batch, 1) tensor
            would silently broadcast to the wrong shape - confirm with callers.
        *args
            Accepted for interface compatibility; ignored.
        **kwargs
            Accepted for interface compatibility; ignored.

        Returns
        -------
        torch.Tensor
            Critic scores of shape (batch, 1) for fake/real cells.

        Raises
        ------
        TypeError
            If ``labels`` is not provided.
        """
        if labels is None:
            raise TypeError("ConditionalCriticProj.forward() requires `labels`.")
        # Layout built by _create_critic: hidden blocks..., class bias, embedding.
        *hidden_blocks, class_bias, projection = self._critic
        features = data
        for block in hidden_blocks:
            features = block(features)
        # No linear output layer: a learned per-class bias plus the
        # projection inner product forms the score.
        return class_bias(labels) + torch.sum(
            projection(labels) * features, dim=1, keepdim=True
        )

    def _create_critic(self) -> None:
        """Build ``self._critic`` as a ModuleList of the hidden blocks, a
        per-class scalar bias embedding, and the class-projection embedding.

        forward() relies on this exact order, and so do the ModuleList
        indices that appear in ``state_dict`` keys."""
        sizes = [self.x_input, *self.critic_layers]
        self._critic = nn.ModuleList(
            self._create_critic_block(in_dim, out_dim)
            for in_dim, out_dim in zip(sizes[:-1], sizes[1:])
        )
        feature_size = sizes[-1]
        # bias layer (replaces the output linear layer); starts at zero.
        proj_bias = nn.Embedding(self.num_classes, 1)
        torch.nn.init.zeros_(proj_bias.weight)
        self._critic.append(proj_bias)
        # projection layer: one embedding vector per class, sized like the
        # last hidden features so it can be dotted with them.
        proj_layer = nn.Embedding(self.num_classes, feature_size)
        nn.init.xavier_uniform_(proj_layer.weight)
        self._critic.append(proj_layer)