import typing
import torch
from tfrecord.torch.dataset import MultiTFRecordDataset, TFRecordDataset
from torch.utils.data.dataloader import DataLoader
[docs]
def get_loader(
    genes_no: int,
    file_path: typing.Union[str, typing.List[str]],
    batch_size: int,
    splits: typing.Optional[typing.Dict[str, float]] = None,
    description: typing.Union[typing.List[str], typing.Dict[str, str], None] = None,
    compression_type: typing.Optional[str] = "gzip",
    multi_read: typing.Optional[bool] = False,
    get_clusters: typing.Optional[bool] = False,
) -> DataLoader:
    """
    Build a PyTorch DataLoader over a dataset read from tfrecord file(s).

    Currently used to create data loaders from the PBMC preprocessed dataset
    in tfrecord from scGAN (Marouf et al., 2020). The ``description`` parameter
    and the ``post_process`` transform can be modified to accommodate more
    tfrecord datasets.

    Parameters
    ----------
    genes_no : int
        Number of genes in the expression matrix; this is the length of the
        dense expression vector built for every record.
    file_path : typing.Union[str, typing.List[str]]
        Tfrecord file path for reading a single tfrecord (multi_read=False)
        or file pattern for reading multiple tfrecords (ex: /path/{}.tfrecord).
    batch_size : int
        Training batch size.
    splits : typing.Optional[typing.Dict[str, float]], optional
        Dictionary of (key, value) pairs, where the key is used to construct
        the data and index path(s) and the value determines the contribution
        of each split to the batch. Required when reading from multiple
        tfrecords (multi_read=True), by default None.
    description : typing.Union[typing.List[str], typing.Dict[str, str], None], optional
        List of keys or dict of (key, value) pairs to extract from each record.
        The keys represent the name of the features and the values ("byte",
        "float", or "int"). When None, defaults to
        ``{"indices": None, "values": None}``, plus ``"cluster_int": None``
        if ``get_clusters`` is True.
    compression_type : typing.Optional[str], optional
        The type of compression used for the tfrecord. Either 'gzip' or None,
        by default "gzip".
    multi_read : typing.Optional[bool], optional
        Specifies whether to construct the dataset from multiple tfrecords.
        If True, a file pattern should be passed to file_path, by default False.
    get_clusters : typing.Optional[bool], optional
        If True, the default description also requests the cluster label of
        cells in addition to their gene expression values, by default False.
        Only consulted when ``description`` is None; an explicit description
        is used as given.

    Returns
    -------
    DataLoader
        Iterable data loader over the dataset.

    Raises
    ------
    ValueError
        If ``multi_read`` is True but ``splits`` is None.
    """
    # Fail fast with a clear message: the multi-file reader needs the split
    # names to build its file paths, so a missing ``splits`` would otherwise
    # only surface later as an opaque error from inside the tfrecord library.
    if multi_read and splits is None:
        raise ValueError(
            "`splits` must be provided when `multi_read` is True "
            "(mapping of split name to sampling ratio)."
        )

    def post_process(
        records: typing.Dict,
    ) -> typing.Union[typing.Tuple[torch.Tensor, torch.Tensor], torch.Tensor]:
        """
        Transform function to preprocess gene expression.

        Builds the dense gene expression tensor from a sparse representation
        based on a list of gene indices and corresponding expression values.

        Parameters
        ----------
        records : typing.Dict
            Parsed tfrecord. Must contain "indices" and "values";
            may contain "cluster_int".

        Returns
        -------
        typing.Union[typing.Tuple[torch.Tensor, torch.Tensor], torch.Tensor]
            A cell's vector of expression levels, paired with its cluster
            label when the record contains "cluster_int".
        """
        indices = torch.from_numpy(records["indices"])
        values = torch.from_numpy(records["values"])

        # Dense vector of zeros, one slot per gene.
        expression = torch.zeros([genes_no])

        # Reshape to a column and transpose so that tuple(...) yields a single
        # 1-D index tensor, then scatter the values into the zero vector
        # in place.
        indices = indices.reshape([indices.shape[0], 1])
        expression.index_put_(tuple(indices.t()), values)

        # The cluster label is only present when it was requested through the
        # description; without it, return the expression values alone.
        try:
            cluster = torch.from_numpy(records["cluster_int"])
        except KeyError:
            return expression
        return expression, cluster

    if description is None:
        description = {"indices": None, "values": None}
        if get_clusters:
            description["cluster_int"] = None

    # The second positional argument (index path / pattern) is left as None
    # in both cases.
    if multi_read:
        dataset = MultiTFRecordDataset(
            file_path,
            None,
            splits,
            description,
            compression_type=compression_type,
            transform=post_process,
        )
    else:
        dataset = TFRecordDataset(
            file_path,
            None,
            description,
            compression_type=compression_type,
            transform=post_process,
        )
    return DataLoader(dataset, batch_size=batch_size)