# Source code for networks.labeler
import typing
import torch
from torch import nn
class Labeler(nn.Module):
    """
    Fully connected regression network mapping target genes to TFs.

    The network is a stack of ``Linear -> BatchNorm1d -> ReLU`` blocks, one
    per entry of ``labeler_layers``, followed by a final ``Linear`` layer
    that projects onto ``num_tfs`` outputs.

    Attributes
    ----------
    num_genes : int
        Input width of the first linear layer.
    num_tfs : int
        Output width of the last linear layer.
    labeler_layers : typing.List[int]
        Widths of the hidden blocks, in order.
    """

    def __init__(
        self, num_genes: int, num_tfs: int, labeler_layers: typing.List[int]
    ) -> None:
        """
        Labeler network's constructor.

        Parameters
        ----------
        num_genes : int
            Number of target genes (all genes excluding TFs) in the dataset.
        num_tfs : int
            Number of transcription factors in the dataset.
        labeler_layers : typing.List[int]
            List of integers corresponding to the number of neurons
            at each hidden layer of the labeler. An empty list yields a
            network made of the single output ``Linear`` layer.
        """
        super().__init__()
        self.num_genes = num_genes
        self.num_tfs = num_tfs
        self.labeler_layers = labeler_layers
        # Must run after the attributes above are set: the builder reads them.
        self._create_labeler()

    def forward(self, target_genes: torch.Tensor) -> torch.Tensor:
        """
        Function for completing a forward pass of the labeler.

        This network performs a regression by predicting TF expression
        from target gene expression.

        Parameters
        ----------
        target_genes : torch.Tensor
            Tensor containing target gene expression of (fake/real) cells.
            Its last dimension must equal ``num_genes``; presumably shaped
            ``(batch, num_genes)`` -- confirm against the caller.

        Returns
        -------
        torch.Tensor
            Tensor containing regulatory TFs; its last dimension is
            ``num_tfs``.
        """
        return self._labeler(target_genes)

    def _create_labeler(self) -> None:
        """
        Method for creating a labeler network.

        Builds ``self._labeler`` from ``self.num_genes``,
        ``self.labeler_layers`` and ``self.num_tfs``.
        """
        layers = []
        input_dim = self.num_genes
        for output_dim in self.labeler_layers:
            # Each hidden block is wrapped in its own nn.Sequential, so its
            # parameters are named ``_labeler.<block>.<layer>.*``.
            layers.append(
                nn.Sequential(
                    nn.Linear(input_dim, output_dim),
                    nn.BatchNorm1d(output_dim),
                    nn.ReLU(inplace=True),
                )
            )
            # The next block consumes what this one produced.
            input_dim = output_dim
        # Output projection: no normalization or activation (regression head).
        layers.append(nn.Sequential(nn.Linear(input_dim, self.num_tfs)))
        self._labeler = nn.Sequential(*layers)