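"""Source code for perturbation.perturbation: UMAP/density plots of real vs. generated cells and the TF-perturbation experiment for a trained GAN."""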

import typing

import matplotlib.pyplot as plt
import matplotlib.font_manager as font_manager
from matplotlib import rcParams
import numpy as np
import scanpy as sc
import umap.umap_ as umap
import seaborn as sns
from matplotlib import cm
from matplotlib.lines import Line2D
import torch

from configparser import ConfigParser
from factory import get_factory, parse_list
from sc_dataset import get_loader

# Register every TTF found under the (working-directory-relative) Atkinson
# Hyperlegible folder with matplotlib, so the family name configured below can
# be resolved when figures are rendered.
font_dir = ["Atkinson_Hyperlegible/Web Fonts/TTF/"]
for font in font_manager.findSystemFonts(fontpaths=font_dir):
    font_manager.fontManager.addfont(font)

# Global typography shared by every figure this module draws.
rcParams.update({"font.family": "Atkinson Hyperlegible", "font.size": 13})

# Shared reducer: plot_UMAP fits it when called with fit=True and otherwise
# projects new data with the already-fitted model via transform().
UMAP = umap.UMAP(n_neighbors=15, random_state=60)


def plot_UMAP(
    real: np.ndarray,
    fake: np.ndarray,
    real_labels: typing.Optional[typing.Union[list[str], np.ndarray]] = None,
    fit: bool = False,
    fake_title: str = "Fake",
    case_ctr: str = "ctr",
    save_path: typing.Optional[str] = None,
) -> None:
    """
    Plot UMAP projections and density plots for real and generated (fake) cell data.

    Two figures are produced, each with three panels (real, fake, real + fake):
    scatter plots of the 2-D UMAP embedding, and 2-D kernel density estimates
    of that same embedding.

    Parameters
    ----------
    real : np.ndarray
        Real cell expression data (cells x genes). Must be dense.
    fake : np.ndarray
        Fake/generated cell expression data (cells x genes).
    real_labels : Optional[Union[list[str], np.ndarray]], optional
        Cell type label for each row of ``real``, used for colour-coded scatter
        plots and a shared "Cell Types" legend, by default None
    fit : bool, optional
        Whether to fit the module-level UMAP model on the concatenated data
        (``True``), or project with the already-fitted model (``False``),
        by default False
    fake_title : str, optional
        Title used for fake cells in the plots (e.g., "Generated", "Simulated"),
        by default "Fake"
    case_ctr : str, optional
        Identifier appended to the saved filenames (e.g. before or after
        perturbation), by default "ctr"
    save_path : Optional[str], optional
        If provided, the figures are saved as
        ``save_path + "scatter_plot_{case_ctr}.png"`` and
        ``save_path + "density_plot_{case_ctr}.png"`` (300 dpi). The string is
        prepended verbatim, so a directory must end with a path separator.
        By default None
    """
    if real_labels is not None:
        real_labels = np.asarray(real_labels)

    real_embedding, fake_embedding = _embed_real_and_fake(real, fake, fit)
    colors = _celltype_colors(real_labels) if real_labels is not None else None

    scatter_fig = _plot_scatter_figure(
        real_embedding, fake_embedding, real_labels, colors, fake_title
    )
    density_fig = _plot_density_figure(real_embedding, fake_embedding, fake_title)

    if save_path is not None:
        scatter_fig.savefig(
            save_path + f"scatter_plot_{case_ctr}.png", dpi=300, bbox_inches="tight"
        )
        density_fig.savefig(
            save_path + f"density_plot_{case_ctr}.png", dpi=300, bbox_inches="tight"
        )


def _embed_real_and_fake(
    real: np.ndarray, fake: np.ndarray, fit: bool
) -> typing.Tuple[np.ndarray, np.ndarray]:
    """Project real and fake cells jointly with the module-level UMAP and split the rows back."""
    stacked = np.concatenate((real, fake), axis=0)
    embedded = UMAP.fit_transform(stacked) if fit else UMAP.transform(stacked)
    n_real = real.shape[0]
    return embedded[:n_real, :], embedded[n_real:, :]


def _celltype_colors(labels: np.ndarray) -> dict:
    """Map each distinct label to a tab20 colour, excluding tab20's red (used for fake cells)."""
    from matplotlib.colors import ListedColormap

    base = plt.get_cmap("tab20")
    # Index 6 of tab20 is its red (#d62728); drop it so no real cell type is
    # drawn in the colour reserved for fake cells.
    palette = ListedColormap([base(i) for i in range(base.N) if i != 6])
    # Sorted (by string form) so colours and legend order are reproducible
    # across runs, unlike raw set() iteration order.
    celltypes = sorted(set(labels), key=str)
    positions = np.linspace(0, 1, len(celltypes))
    return {celltype: palette(pos) for celltype, pos in zip(celltypes, positions)}


def _scatter_real(ax, real_embedding: np.ndarray, real_labels, colors) -> None:
    """Draw real cells on ``ax``: one colour per cell type if labelled, else all blue."""
    if real_labels is None:
        ax.scatter(
            real_embedding[:, 0],
            real_embedding[:, 1],
            c="blue",
            label="real",
            alpha=0.7,
        )
        return
    for celltype, color in colors.items():
        mask = real_labels == celltype
        ax.scatter(
            real_embedding[mask, 0],
            real_embedding[mask, 1],
            c=np.array([color]),  # 2-D so matplotlib treats it as one RGBA colour
            label=str(celltype) + " (real)",
            alpha=0.7,
        )


def _plot_scatter_figure(
    real_embedding: np.ndarray,
    fake_embedding: np.ndarray,
    real_labels,
    colors,
    fake_title: str,
):
    """Build the three-panel scatter figure (real, fake, combined) and return it."""
    fig = plt.figure(figsize=(20, 6))

    ax_real = fig.add_subplot(1, 3, 1)
    _scatter_real(ax_real, real_embedding, real_labels, colors)
    ax_real.set_title("Real Cells")

    ax_fake = fig.add_subplot(1, 3, 2)
    ax_fake.scatter(
        fake_embedding[:, 0], fake_embedding[:, 1], c="red", label="fake", alpha=0.7
    )
    ax_fake.set_title(fake_title)

    ax_both = fig.add_subplot(1, 3, 3)
    _scatter_real(ax_both, real_embedding, real_labels, colors)
    ax_both.scatter(
        fake_embedding[:, 0], fake_embedding[:, 1], c="red", label="fake", alpha=0.7
    )
    ax_both.set_title(f"Real and {fake_title} Cells")

    for ax in (ax_real, ax_fake, ax_both):
        ax.grid(True)
        ax.set_xlabel("UMAP 1")
        ax.set_ylabel("UMAP 2")

    if real_labels is not None:
        # One shared legend, taken from the real-cells panel, placed far left.
        handles, labels = ax_real.get_legend_handles_labels()
        fig.legend(
            handles,
            labels,
            title="Cell Types",
            loc="center left",
            bbox_to_anchor=(-0.17, 0.5),
            borderaxespad=0.0,
        )
    fig.tight_layout(rect=[0.01, 0, 1, 1])  # leave space for the legend on the left
    return fig


def _plot_density_figure(
    real_embedding: np.ndarray, fake_embedding: np.ndarray, fake_title: str
):
    """Build the three-panel 2-D KDE figure (real, fake, combined) and return it."""
    fig = plt.figure(figsize=(20, 6))
    combined = np.concatenate((real_embedding, fake_embedding), axis=0)
    limited_sig_figs = {"format": "%.3g"}  # colour-bar tick format
    panels = (
        (real_embedding, "Blues", "Real Cells Density", None),
        (fake_embedding, "Reds", f"{fake_title} Cells Density", limited_sig_figs),
        (combined, "Greys", f"Real and {fake_title} Cells Density", limited_sig_figs),
    )
    for position, (points, cmap, title, cbar_kws) in enumerate(panels, start=1):
        ax = fig.add_subplot(1, 3, position)
        # Keyword x/y and fill=True are the seaborn >= 0.11 API; the default
        # thresh already leaves the lowest density level unfilled (the old
        # shade_lowest=False behaviour).
        sns.kdeplot(
            x=points[:, 0],
            y=points[:, 1],
            cmap=cmap,
            fill=True,
            ax=ax,
            cbar=True,
            cbar_kws=cbar_kws,
        )
        ax.set_title(title)
        ax.grid(True)
    fig.tight_layout()
    return fig
def perturb(cfg: ConfigParser) -> None:
    """
    Perform the perturbation experiment defined in the configuration file.

    Saves UMAP plots and generated cells before and after perturbation:

    1. Generate ``[Preprocessing] test set size`` cells from the GAN at
       ``[EXPERIMENT] checkpoint`` and plot them against the test set
       (``[Data] test``), fitting the module-level UMAP.
    2. Switch the generator to perturbation mode, regenerate baseline cells
       and check that generation is deterministic in that mode.
    3. Set the TFs in ``[Perturbation] tfs to perturb`` to
       ``[Perturbation] perturbation values``, generate perturbed cells,
       restore the original TF expressions, and plot with the fitted UMAP.
    4. Write baseline and perturbed cells to ``before_perturbation.h5ad`` and
       ``after_perturbation.h5ad`` under ``[Perturbation] save dir``.

    Parameters
    ----------
    cfg : ConfigParser
        Parser for config file containing program params.

    Raises
    ------
    AssertionError
        If two generations in perturbation mode differ.
    ValueError
        If a TF to perturb is not a gene of the test set or not one of the
        generator's TFs, or if the number of perturbation values is neither 1
        (applied to every TF) nor the number of TFs.
    """
    # use the same number of cells as the test set
    cells_no = cfg.getint("Preprocessing", "test set size")
    test_set = sc.read_h5ad(cfg.get("Data", "test"))
    save_dir = cfg.get("Perturbation", "save dir")

    # Read the GAN
    gan = get_factory(cfg).get_gan()
    print("Loaded GAN")
    checkpoint = cfg.get("EXPERIMENT", "checkpoint")
    print("Using checkpoint at", checkpoint)

    # NOTE(review): this real-cell batch is not used below (the plots use
    # test_set.X). The draw is kept because iterating the loader may consume
    # RNG state and thus affect the generated cells; confirm before removing.
    loader = get_loader(cfg.get("Data", "test"), cells_no)
    next(iter(loader))

    # Fake cells without perturbation; UMAP is fitted here and reused for the
    # after-perturbation plot so both share one embedding.
    fake_cells = gan.generate_cells(cells_no, checkpoint)
    _plot_against_test_set(test_set, fake_cells, True, "before_perturbation", save_dir)

    # Enter perturbation mode. tf_expressions is cleared here and read back
    # after generation, so generate_cells presumably repopulates it — TODO confirm.
    gan.gen.tf_expressions = None
    gan.gen.pert_mode = True
    fake_cells = gan.generate_cells(cells_no, checkpoint)
    fake_cells_new = gan.generate_cells(cells_no, checkpoint)
    # Explicit raise rather than `assert` so the check survives `python -O`.
    if not bool((fake_cells == fake_cells_new).all()):
        raise AssertionError("perturbation mode should be deterministic")

    tfs_to_perturb = parse_list(cfg.get("Perturbation", "tfs to perturb"), str)
    pert_values = parse_list(cfg.get("Perturbation", "perturbation values"), float)
    tf_idx = _generator_tf_columns(
        tfs_to_perturb, list(test_set.var_names), gan.gen.tfs
    )
    print(tf_idx)
    if len(pert_values) not in (1, len(tf_idx)):
        raise ValueError(
            f"got {len(pert_values)} perturbation values for {len(tf_idx)} TFs; "
            "expected one value per TF or a single value for all of them"
        )

    pert_tensor = torch.tensor(
        pert_values,
        device=gan.gen.tf_expressions.device,
        dtype=gan.gen.tf_expressions.dtype,
    )
    unperturbed_tfs = gan.gen.tf_expressions.clone()
    try:
        # Same values for every cell (row); broadcast over the batch dimension.
        gan.gen.tf_expressions[:, tf_idx] = pert_tensor.unsqueeze(0)
        fake_cells_perturbed = gan.generate_cells(cells_no, checkpoint)
    finally:
        # Restore even if generation fails, so the generator is not left perturbed.
        gan.gen.tf_expressions = unperturbed_tfs

    _plot_against_test_set(
        test_set, fake_cells_perturbed, False, "after_perturbation", save_dir
    )

    _write_cells(fake_cells, "ctr_fake", save_dir + "before_perturbation.h5ad")
    _write_cells(
        fake_cells_perturbed, "pert_fake", save_dir + "after_perturbation.h5ad"
    )
    print("Saved cells before and after perturbation to", save_dir)


def _generator_tf_columns(
    tfs_to_perturb: typing.List[str],
    gene_names: typing.List[str],
    generator_tfs,
) -> typing.List[int]:
    """Return, for each TF name, its position in the generator's TF list (looked up by gene index)."""
    columns = []
    for tf in tfs_to_perturb:
        if tf not in gene_names:
            raise ValueError(f"TF {tf!r} is not a gene of the test set")
        gene_idx = gene_names.index(tf)
        if gene_idx not in generator_tfs:
            raise ValueError(
                f"TF {tf!r} (gene index {gene_idx}) is not one of the generator's TFs"
            )
        columns.append(generator_tfs.index(gene_idx))
    return columns


def _plot_against_test_set(
    test_set, fake_cells, fit: bool, case_ctr: str, save_dir: str
) -> None:
    """Plot fake cells against the test set, colouring by `celltype` when that obs column exists."""
    real_labels = (
        test_set.obs.celltype.to_numpy() if "celltype" in test_set.obs else None
    )
    real = test_set.X
    # h5ad matrices are often sparse; plot_UMAP concatenates with numpy, which
    # needs a dense array.
    if hasattr(real, "toarray"):
        real = real.toarray()
    plot_UMAP(
        real,
        fake_cells,
        real_labels=real_labels,
        fit=fit,
        case_ctr=case_ctr,
        save_path=save_dir,
    )


def _write_cells(cells, name_prefix: str, path: str) -> None:
    """Write cells as an .h5ad file with unique obs names derived from `name_prefix`."""
    adata = sc.AnnData(cells)
    adata.obs_names = np.repeat(name_prefix, adata.shape[0])
    adata.obs_names_make_unique()
    adata.write(path)