import pickle
from configparser import ConfigParser
from itertools import chain
import pandas as pd
import scanpy as sc
from arboreto.algo import grnboost2
from tabulate import tabulate
def create_GRN(cfg: ConfigParser) -> None:
    """
    Infers a GRN using GRNBoost2 and uses it to construct a causal graph to impose onto GRouNdGAN.

    Side effects (all paths are read from ``cfg``):

    * the GRNBoost2 edge list is written to ``[GRN Preparation] Inferred GRN``;
    * the train, validation and test datasets (``[Data]`` section) are
      overwritten in place, restricted to the genes of the causal graph;
    * the causal graph is pickled to ``[Data] causal graph`` as a dict mapping
      each target's column index to the set of its regulators' column indices,
      where indices refer to the columns of the rewritten training data.

    Parameters
    ----------
    cfg : ConfigParser
        Parser for config file containing GRN creation params.

    Raises
    ------
    ValueError
        If ``[GRN Preparation] k`` is not an integer >= 1, or if the inferred
        GRN yields an empty causal graph. Both checks happen before any
        dataset file is overwritten.
    """
    # validate k up front so a bad config fails before the (slow) inference
    k = cfg.getint("GRN Preparation", "k")
    if k < 1:
        raise ValueError(f"[GRN Preparation] k must be >= 1, got {k}")

    real_cells = sc.read_h5ad(cfg.get("Data", "train"))
    real_cells_val = sc.read_h5ad(cfg.get("Data", "validation"))
    real_cells_test = sc.read_h5ad(cfg.get("Data", "test"))

    # keep only TFs that are genes of the dataset (presumably already restricted
    # to highly variable genes); sorted so the TF list is reproducible across runs
    tf_table = pd.read_csv(cfg.get("GRN Preparation", "TFs"), sep="\t")
    tf_names = sorted(set(tf_table["Symbol"]).intersection(real_cells.var_names))

    # GRNBoost2 takes a cells x genes DataFrame; to_df() also densifies a sparse X,
    # which pd.DataFrame(real_cells.X) would not handle
    print(f"Using {len(tf_names)} TFs for GRN inference.")
    real_grn = grnboost2(real_cells.to_df(), tf_names=tf_names, verbose=True, seed=1)
    grn_path = cfg.get("GRN Preparation", "Inferred GRN")
    real_grn.to_csv(grn_path)

    # build the causal graph from the edge list exactly as it was saved
    causal_graph = _build_causal_graph(pd.read_csv(grn_path), k)
    if not causal_graph:
        raise ValueError(
            "The inferred GRN yields an empty causal graph; "
            "refusing to overwrite the datasets."
        )

    tfs = set(chain.from_iterable(causal_graph.values()))
    targets = set(causal_graph)
    # sorted so the gene (column) order of the rewritten datasets is reproducible
    genes = sorted(tfs | targets)

    # overwrite train, validation, and test datasets in case some genes were excluded
    real_cells = real_cells[:, genes]
    real_cells.write_h5ad(cfg.get("Data", "train"))
    real_cells_val[:, genes].write_h5ad(cfg.get("Data", "validation"))
    real_cells_test[:, genes].write_h5ad(cfg.get("Data", "test"))

    _print_causal_graph_summary(causal_graph)

    # convert gene names to column indices of the rewritten training data
    # (var_names is the same index to_df().columns would give, without densifying X)
    gene_idx = real_cells.var_names
    indexed_graph = {
        gene_idx.get_loc(target): {gene_idx.get_loc(tf) for tf in regulators}
        for target, regulators in causal_graph.items()
    }

    # save causal graph
    with open(cfg.get("Data", "causal graph"), "wb") as fp:
        pickle.dump(indexed_graph, fp, protocol=pickle.HIGHEST_PROTOCOL)


def _build_causal_graph(grn: pd.DataFrame, k: int) -> dict:
    """Map each target gene to its top-``k`` TFs, dropping targets that are themselves regulators.

    ``grn`` is the GRNBoost2 edge list with ``TF``, ``target`` and ``importance``
    columns. Gene names in the result are strings. After filtering, no gene is
    both a target and a regulator.
    """
    # sort before grouping: groupby keeps row order within each group, so each
    # target's TF list is ordered by decreasing importance
    ranked = grn.sort_values("importance", ascending=False).astype(str)
    ranked_tfs = ranked.groupby("target")["TF"].apply(list)
    # NOTE: alternative samplings used previously: tfs[0:10:2] (even ranks),
    # tfs[1:10:2] (odd ranks)
    graph = {target: set(tfs[:k]) for target, tfs in ranked_tfs.items()}
    regulators = set(chain.from_iterable(graph.values()))
    return {target: tfs for target, tfs in graph.items() if target not in regulators}


def _print_causal_graph_summary(causal_graph: dict) -> None:
    """Print node and edge counts of a name-keyed causal graph as a table."""
    tfs = set(chain.from_iterable(causal_graph.values()))
    targets = set(causal_graph)
    # count the real edges: a target can have fewer than k regulators, so
    # k * len(targets) would overstate the imposed edges and the density
    imposed_edges = sum(len(regulators) for regulators in causal_graph.values())
    possible_edges = len(tfs) * len(targets)
    density = imposed_edges / possible_edges if possible_edges else 0.0
    print(
        "",
        "Causal Graph",
        tabulate(
            [
                ("TFs", len(tfs)),
                ("Targets", len(targets)),
                ("Genes", len(tfs | targets)),
                ("Possible Edges", possible_edges),
                ("Imposed Edges", imposed_edges),
                ("GRN density Edges", density),
            ]
        ),
        sep="\n",
    )