Source code for preprocessing.grn_creation

import pickle
from configparser import ConfigParser
from itertools import chain

import pandas as pd
import scanpy as sc
from arboreto.algo import grnboost2
from tabulate import tabulate
from scipy import sparse


def _regulator_slice(strategy: str, k: int) -> slice:
    """Return the slice of importance-ranked TFs kept per target for ``strategy``."""
    if strategy == "top":
        return slice(None, k)  # the k highest-ranked TFs
    if strategy == "pos ctr":
        return slice(0, k, 2)  # even positions among the top k (top 1, 3, 5, ...)
    if strategy == "neg ctr":
        return slice(1, k, 2)  # odd positions among the top k (top 2, 4, 6, ...)
    raise ValueError(
        f"GRN preparation strategy not valid: {strategy!r} "
        "(expected 'top', 'pos ctr' or 'neg ctr')"
    )


def create_GRN(cfg: ConfigParser) -> None:
    """
    Infers a GRN using GRNBoost2 and uses it to construct a causal graph
    to impose onto GRouNdGAN.

    The train, validation and test datasets named in the config are
    overwritten in place, restricted to the genes that appear in the causal
    graph. The raw GRNBoost2 output is written to the ``Inferred GRN`` path
    and the causal graph (target column index -> set of TF column indices)
    is pickled to the ``causal graph`` path.

    Parameters
    ----------
    cfg : ConfigParser
        Parser for config file containing GRN creation params.

    Raises
    ------
    ValueError
        If ``strategy`` is not one of ``top``, ``pos ctr`` or ``neg ctr``,
        if ``k`` is not a positive integer, or if the selected strategy
        leaves no TF-target edge in the causal graph.
    """
    # Validate the cheap settings first so that a bad config fails before the
    # expensive inference runs and before any dataset is overwritten.
    strategy = cfg.get("GRN Preparation", "strategy")
    k = cfg.getint("GRN Preparation", "k")
    if k < 1:
        raise ValueError(f"GRN preparation k must be a positive integer, got {k}")
    keep = _regulator_slice(strategy, k)

    train_path = cfg.get("Data", "train")
    val_path = cfg.get("Data", "validation")
    test_path = cfg.get("Data", "test")
    grn_path = cfg.get("GRN Preparation", "Inferred GRN")
    graph_path = cfg.get("Data", "causal graph")

    real_cells = sc.read_h5ad(train_path)
    real_cells_val = sc.read_h5ad(val_path)
    real_cells_test = sc.read_h5ad(test_path)

    # find TFs that are in the dataset's genes; sorted so the list handed to
    # GRNBoost2 does not depend on set iteration order
    gene_names = real_cells.var_names.tolist()
    known_tfs = pd.read_csv(cfg.get("GRN Preparation", "TFs"), sep="\t")["Symbol"]
    candidate_tfs = sorted(set(known_tfs).intersection(gene_names))

    if sparse.issparse(real_cells.X):
        # toarray() gives a plain ndarray; todense() would give a np.matrix
        real_cells.X = real_cells.X.toarray()

    # preparing GRNBoost2's input
    real_cells_df = pd.DataFrame(real_cells.X, columns=real_cells.var_names)

    # we can optionally pass a list of TFs to GRNBoost2
    print(f"Using {len(candidate_tfs)} TFs for GRN inference.")
    real_grn = grnboost2(real_cells_df, tf_names=candidate_tfs, verbose=True, seed=1)
    real_grn.to_csv(grn_path)
    print("Successfully saved GRN inferred by GRNBoost2 GRN to", grn_path)

    # read GRN csv output, group TFs regulating genes, sort by importance
    # (sorting happens on the numeric column, before the cast to str)
    real_grn = (
        pd.read_csv(grn_path)
        .sort_values("importance", ascending=False)
        .astype(str)
    )
    ranked_tfs = dict(real_grn.groupby("target")["TF"].apply(list))

    if strategy == "pos ctr":
        print("Creating positive control GRN from even indexed top TFs (top 1, 3, 5, ...)")
    elif strategy == "neg ctr":
        print("Creating negative control GRN from odd indexed top TFs (top 2, 4, 6, ...)")
    causal_graph = {gene: set(tfs[keep]) for gene, tfs in ranked_tfs.items()}

    # delete targets that are also regulators
    regulators = set(chain.from_iterable(causal_graph.values()))
    causal_graph = {
        gene: regs for gene, regs in causal_graph.items() if gene not in regulators
    }

    # get gene, TF names (recomputed: dropping targets may drop regulators too)
    tfs = set(chain.from_iterable(causal_graph.values()))
    targets = set(causal_graph)
    if not tfs:
        # Without this check the density computation below would divide by
        # zero after the datasets had already been overwritten.
        raise ValueError(
            f"Causal graph has no TF-target edges for strategy {strategy!r} with k={k}"
        )
    # sorted so the gene order of the written datasets (and hence the indices
    # stored in the causal graph) is the same on every run
    genes = sorted(tfs | targets)

    # overwrite train, validation, and test datasets in case some genes
    # were excluded from the dataset
    real_cells = real_cells[:, genes]
    real_cells.write_h5ad(train_path)
    real_cells_val[:, genes].write_h5ad(val_path)
    real_cells_test[:, genes].write_h5ad(test_path)

    # print causal graph info; edges are counted from the graph itself because
    # the control strategies and sparsely regulated targets keep fewer than k
    possible_edges = len(tfs) * len(targets)
    imposed_edges = sum(len(regs) for regs in causal_graph.values())
    print(
        "",
        "Causal Graph",
        tabulate(
            [
                ("``TFs``", len(tfs)),
                ("``Targets``", len(targets)),
                ("Genes", len(genes)),
                ("Possible Edges", possible_edges),
                ("Imposed Edges", imposed_edges),
                ("GRN density Edges", imposed_edges / possible_edges),
            ]
        ),
        sep="\n",
    )

    # convert gene names to numerical indices (column positions in the
    # gene-filtered training data)
    gene_idx = real_cells.var_names
    causal_graph = {
        gene_idx.get_loc(gene): {gene_idx.get_loc(tf) for tf in regs}
        for gene, regs in causal_graph.items()
    }

    # save causal graph
    with open(graph_path, "wb") as fp:
        pickle.dump(causal_graph, fp, protocol=pickle.HIGHEST_PROTOCOL)

    print("Successfully saved GRouNdGAN causal graph to", graph_path)