Source code for networks.labeler

import typing

import torch
from torch import nn


class Labeler(nn.Module):
    """Regress TF expression from target-gene expression.

    The network is a stack of ``Linear -> BatchNorm1d -> ReLU`` hidden blocks,
    one per entry of ``labeler_layers``. It ends with a linear output layer
    that projects onto ``num_tfs`` values.
    """

    def __init__(
        self, num_genes: int, num_tfs: int, labeler_layers: typing.Sequence[int]
    ) -> None:
        """
        Labeler network's constructor.

        Parameters
        ----------
        num_genes : int
            Number of target genes (all genes excluding TFs) in the dataset.
            This is the input dimension of the network.
        num_tfs : int
            Number of transcription factors in the dataset. This is the
            output dimension of the network.
        labeler_layers : typing.Sequence[int]
            Number of neurons at each hidden layer of the labeler. May be
            empty, in which case the labeler is a single linear layer.

        Raises
        ------
        ValueError
            If ``num_genes``, ``num_tfs`` or any entry of ``labeler_layers``
            is not positive.
        """
        super().__init__()

        # Defensive copy: later mutation of the caller's sequence must not make
        # this attribute disagree with the layers actually built below.
        labeler_layers = list(labeler_layers)

        # Fail early with a clear message instead of building zero-sized
        # layers or erroring deep inside tensor allocation.
        for name, dim in (("num_genes", num_genes), ("num_tfs", num_tfs)):
            if dim <= 0:
                raise ValueError(f"{name} must be positive, got {dim}")
        for idx, dim in enumerate(labeler_layers):
            if dim <= 0:
                raise ValueError(
                    f"labeler_layers[{idx}] must be positive, got {dim}"
                )

        self.num_genes = num_genes
        self.num_tfs = num_tfs
        self.labeler_layers = labeler_layers
        self._create_labeler()

    def forward(self, target_genes: torch.Tensor) -> torch.Tensor:
        """
        Function for completing a forward pass of the labeler.

        This network performs a regression by predicting TF expression from
        target gene expression.

        Parameters
        ----------
        target_genes : torch.Tensor
            Target gene expression of (fake/real) cells, expected shape
            ``(batch_size, num_genes)``. In training mode the hidden
            ``BatchNorm1d`` layers require ``batch_size > 1``.

        Returns
        -------
        torch.Tensor
            Predicted TF expression of shape ``(batch_size, num_tfs)``. The
            output layer is linear, so values are unbounded.
        """
        return self._labeler(target_genes)

    def _create_labeler(self) -> None:
        """Build ``self._labeler`` from the configured layer sizes."""
        dims = [self.num_genes, *self.labeler_layers]

        # One Linear -> BatchNorm1d -> ReLU block per consecutive (in, out)
        # pair. Modules are created in the same order as before, so seeded
        # initialisation is unchanged.
        blocks = [
            nn.Sequential(
                nn.Linear(in_dim, out_dim),
                nn.BatchNorm1d(out_dim),
                nn.ReLU(inplace=True),
            )
            for in_dim, out_dim in zip(dims[:-1], dims[1:])
        ]

        # The output layer is wrapped in its own nn.Sequential, like the hidden
        # blocks. This fixes its parameter names as `_labeler.<i>.0.*`;
        # unwrapping it would rename the state_dict keys.
        blocks.append(nn.Sequential(nn.Linear(dims[-1], self.num_tfs)))
        self._labeler = nn.Sequential(*blocks)