Source code for sc_dataset

import os
import typing

import scanpy as sc
import torch
from torch.utils.data import DataLoader, Dataset


class SCDataset(Dataset):
    """Map-style PyTorch dataset over preprocessed scRNA-seq data.

    Wraps an h5ad file (produced by preprocessing/preprocess.py) and
    exposes each cell's gene-expression vector together with its
    integer cluster label.
    """

    def __init__(self, path: typing.Union[str, bytes, os.PathLike]) -> None:
        """Load the h5ad file at *path* and cache tensors for fast indexing.

        Parameters
        ----------
        path : typing.Union[str, bytes, os.PathLike]
            Path to the h5ad file created by the preprocessing script.
        """
        self.data = sc.read_h5ad(path)
        # NOTE(review): torch.from_numpy requires a dense ndarray — assumes
        # preprocessing stored .X densified; confirm it is not scipy-sparse.
        self.cells = torch.from_numpy(self.data.X)
        labels = self.data.obs.cluster.to_numpy(dtype=int)
        self.clusters = torch.from_numpy(labels)

    def __getitem__(self, index: int) -> typing.Tuple[torch.Tensor, torch.Tensor]:
        """Return the (expression, cluster label) pair for one cell.

        Parameters
        ----------
        index : int
            Position of the cell in the dataset.

        Returns
        -------
        typing.Tuple[torch.Tensor, torch.Tensor]
            Gene expression, Cluster label Tensor tuple.
        """
        expression = self.cells[index]
        label = self.clusters[index]
        return expression, label

    def __len__(self) -> int:
        """Return the number of samples (cells) in the dataset.

        Returns
        -------
        int
            Number of cells (rows of the expression matrix).
        """
        return self.cells.size(0)
def get_loader(
    file_path: typing.Union[str, bytes, os.PathLike],
    batch_size: typing.Optional[int] = None,
) -> DataLoader:
    """Build a shuffling DataLoader over the scRNA-seq dataset in *file_path*.

    Parameters
    ----------
    file_path : typing.Union[str, bytes, os.PathLike]
        Path to the h5ad file.
    batch_size : typing.Optional[int]
        Training batch size. If not specified, the entire dataset is
        returned at each load.

    Returns
    -------
    DataLoader
        Iterable data loader over the dataset (shuffled, dropping any
        final partial batch).
    """
    dataset = SCDataset(file_path)
    # No batch size given -> one batch containing the whole dataset.
    effective_batch = len(dataset) if batch_size is None else batch_size
    return DataLoader(dataset, effective_batch, shuffle=True, drop_last=True)