Source code for layers.lsn

import typing

import torch
from torch import nn


class LSN(nn.Module):
    """
    Library size normalization (LSN) layer.

    Rescales every row (cell) of the input so that its entries sum to
    ``library_size``. Rows whose sum is zero would produce NaNs and are
    mapped to zeros instead.
    """

    def __init__(
        self,
        library_size: int,
        device: typing.Optional[str] = "cuda" if torch.cuda.is_available() else "cpu",
    ) -> None:
        """
        Library size normalization (LSN) layer.

        Parameters
        ----------
        library_size : int
            Total number of counts per generated cell.
        device : typing.Optional[str], optional
            Preferred device ('cpu' or 'cuda'), by default "cuda" if
            torch.cuda.is_available() else "cpu". The value is stored on the
            instance as ``self.device``; the forward pass itself allocates its
            tensors on the device of the incoming batch.
        """
        super().__init__()
        self.library_size = library_size
        self.device = device
        # Cached per-cell scale; only populated while reuse_scale=True is used.
        self.scale: typing.Optional[torch.Tensor] = None

    def forward(
        self, in_: torch.Tensor, reuse_scale: typing.Optional[bool] = False
    ) -> torch.Tensor:
        """
        Function for completing a forward pass of the LSN layer.

        Parameters
        ----------
        in_ : torch.Tensor
            Tensor containing gene expression of cells. It is summed over
            dimension 1, and the resulting per-row scale is applied along
            dimension 0.
        reuse_scale : typing.Optional[bool], optional
            If set to true, the LSN layer will scale the cells by the same
            scale as the previous batch (the scale is frozen on the first such
            call and reused afterwards). Useful for performing perturbation
            studies. Passing a falsy value clears any frozen scale.
            By default False

        Returns
        -------
        torch.Tensor
            Gene expression of cells after library size normalization.
        """
        # Build the target library sizes directly on the input's device rather
        # than on self.device: the two can differ (e.g. CPU inference on a
        # CUDA-capable host), which would make the division below fail.
        gammas = torch.ones(in_.shape[0], device=in_.device) * self.library_size
        sigmas = torch.sum(in_, 1)
        scale = torch.div(gammas, sigmas)

        if not reuse_scale:
            # Unfreeze the LSN scale if it was set by an earlier call.
            self.scale = None
        elif self.scale is None:
            # First pass through the frozen LSN layer: remember this scale.
            self.scale = scale
        else:
            # Subsequent pass through the frozen LSN layer: reuse the stored scale.
            scale = self.scale

        # Transpose so the per-cell scale broadcasts over the trailing axis,
        # multiply, then restore the original layout.
        normalized = (in_.transpose(0, 1) * scale).transpose(0, 1)

        # Possible NaN if all genes are zero-expressed (0 * inf) - NaNs are
        # thus replaced with zeros.
        return torch.nan_to_num(normalized, nan=0.0)