import typing
import matplotlib.pyplot as plt
import matplotlib.font_manager as font_manager
from matplotlib import rcParams
import numpy as np
import scanpy as sc
import umap.umap_ as umap
import seaborn as sns
from matplotlib import cm
from matplotlib.lines import Line2D
import torch
from configparser import ConfigParser
from factory import get_factory, parse_list
from sc_dataset import get_loader
font_dir = ["Atkinson_Hyperlegible/Web Fonts/TTF/"]
for font in font_manager.findSystemFonts(font_dir):
    font_manager.fontManager.addfont(font)
# Set font family globally
rcParams["font.family"] = "Atkinson Hyperlegible"
rcParams.update({"font.size": 13})
UMAP = umap.UMAP(random_state=60, n_neighbors=15)
[docs]def plot_UMAP(
    real: np.ndarray,
    fake: np.ndarray,
    real_labels: typing.Optional[typing.Union[list[str], np.ndarray]] = None,
    fit: bool = False,
    fake_title: str = "Fake",
    case_ctr: str = "ctr",
    save_path: typing.Optional[str] = None,
) -> None:
    """
    Plot UMAP projections and density plots for real and generated (fake) cell data.
    Parameters
    ----------
    real : np.ndarray
        Real cell expression data (cells x genes).
    fake : np.ndarray
        Fake/generated cell expression data (cells x genes).
    real_labels : Optional[Union[list[str], np.ndarray]], optional
        List or array of cell type labels for real cells (used for color-coded scatter plots), by default None
    fit : bool, optional
        Whether to fit a new UMAP model on the concatenated data (`True`),
        or transform using an existing fitted UMAP model (`False`), by default False
    fake_title : str, optional
        Title used for fake cells in the plots (e.g., "Generated", "Simulated"), by default "Fake"
    case_ctr : str, optional
        Identifier used in the saved filenames (before of after pert), by default "ctr"
    save_path : Optional[str], optional
        If provided, saves the scatter and density plots as PNGs, by default None
    """
    if real_labels is not None:
        real_labels = np.array(real_labels)
        celltypes = set(real_labels)
        n_classes = len(celltypes)
    if fit:
        embedded_cells = UMAP.fit_transform(
            np.concatenate((real, fake), axis=0),
        )
    else:
        embedded_cells = UMAP.transform(np.concatenate((real, fake), axis=0))
    real_embedding = embedded_cells[0 : real.shape[0], :]
    fake_embedding = embedded_cells[real.shape[0] :, :]
    # Figure 1: Scatter Plots
    plt.clf()
    fig1 = plt.figure(figsize=(20, 6))
    ax1 = fig1.add_subplot(1, 3, 1)
    if real_labels is not None:
        # Get the original tab20 colormap
        colormap = cm.get_cmap("tab20")
        # Get the colors from tab20 excluding the red color
        colors = [
            colormap(i) for i in range(colormap.N) if i != 6
        ]  # Remove red color at index 2
        # Create a new colormap without the red color
        colormap = cm.colors.ListedColormap(colors)
        colormap = [colormap(i) for i in np.linspace(0, 1, n_classes)]
        colors = {celltype: colormap[i] for i, celltype in enumerate(set(real_labels))}
        for i in set(real_labels):
            mask = real_labels[:] == i
            ax1.scatter(
                real_embedding[mask, 0],
                real_embedding[mask, 1],
                c=np.array([colors[i]]),
                label=str(i) + " (real)",
                alpha=0.7,
            )
    else:
        ax1.scatter(
            real_embedding[:, 0],
            real_embedding[:, 1],
            c="blue",
            label="real",
            alpha=0.7,
        )
    ax1.set_title("Real Cells")
    ax1.grid(True)
    ax1.set_xlabel("UMAP 1")
    ax1.set_ylabel("UMAP 2")
    # Subplot 2: Fake Cells Scatter Plot
    ax2 = fig1.add_subplot(1, 3, 2)
    ax2.scatter(
        fake_embedding[:, 0],
        fake_embedding[:, 1],
        c="red",
        label="fake",
        alpha=0.7,
    )
    ax2.set_title(fake_title)
    ax2.grid(True)
    ax2.set_xlabel("UMAP 1")
    ax2.set_ylabel("UMAP 2")
    # Subplot 3: Real and Fake Cells Combined Scatter Plot
    ax3 = fig1.add_subplot(1, 3, 3)
    if real_labels is not None:
        for i in set(real_labels):
            mask = real_labels[:] == i
            ax3.scatter(
                real_embedding[mask, 0],
                real_embedding[mask, 1],
                c=np.array([colors[i]]),
                label=str(i) + " (real)",
                alpha=0.7,
            )
    else:
        ax3.scatter(
            real_embedding[:, 0],
            real_embedding[:, 1],
            c="blue",
            label="real",
            alpha=0.7,
        )
    ax3.scatter(
        fake_embedding[:, 0],
        fake_embedding[:, 1],
        c="red",
        label="fake",
        alpha=0.7,
    )
    ax3.set_xlabel("UMAP 1")
    ax3.set_ylabel("UMAP 2")
    ax3.set_title(f"Real and {fake_title} Cells")
    ax3.grid(True)
    # Get handles and labels from ax1
    handles, labels = ax1.get_legend_handles_labels()
    if real_labels is not None:
        # Place shared legend to the far left
        fig1.legend(
            handles,
            labels,
            title="Cell Types",
            loc="center left",
            bbox_to_anchor=(-0.17, 0.5),
            borderaxespad=0.0,
        )
        plt.tight_layout(rect=[0.01, 0, 1, 1])  # Leave space for legend on the left
    # Figure 2: Density Plots
    fig2 = plt.figure(figsize=(20, 6))
    # Subplot 1: Real Cells Density Plot
    ax4 = fig2.add_subplot(1, 3, 1)
    sns.kdeplot(
        real_embedding[:, 0],
        real_embedding[:, 1],
        cmap="Blues",
        shade=True,
        shade_lowest=False,
        ax=ax4,
        cbar=True,
    )
    ax4.set_title("Real Cells Density")
    ax4.grid(True)
    # Subplot 2: Fake Cells Density Plot
    ax5 = fig2.add_subplot(1, 3, 2)
    sns.kdeplot(
        fake_embedding[:, 0],
        fake_embedding[:, 1],
        cmap="Reds",
        shade=True,
        shade_lowest=False,
        ax=ax5,
        cbar=True,
        cbar_kws={"format": "%.3g"},  # Set format to limit significant figures
    )
    ax5.set_title(f"{fake_title} Cells Density")
    ax5.grid(True)
    # Subplot 3: Real and Fake Cells Combined Density Plot
    ax6 = fig2.add_subplot(1, 3, 3)
    sns.kdeplot(
        np.hstack((real_embedding[:, 0], fake_embedding[:, 0])),
        np.hstack((real_embedding[:, 1], fake_embedding[:, 1])),
        cmap="Greys",
        shade=True,
        shade_lowest=False,
        ax=ax6,
        cbar=True,
        cbar_kws={"format": "%.3g"},  # Set format to limit significant figures
    )
    ax6.set_title(f"Real and {fake_title} Cells Density")
    ax6.grid(True)
    plt.tight_layout()
    if save_path is not None:
        fig1.savefig(
            save_path + f"scatter_plot_{case_ctr}.png", dpi=300, bbox_inches="tight"
        )
        fig2.savefig(
            save_path + f"density_plot_{case_ctr}.png", dpi=300, bbox_inches="tight"
        ) 
[docs]def perturb(cfg: ConfigParser) -> None:
    """
    Performs perturbation experiment defined in the configuration file.
    Saves cells and UMAP plots before and after perturbation.
    Parameters
    ----------
    cfg : ConfigParser
        Parser for config file containing program params.
    """
    # use the same number of cels as the test set
    cells_no = cfg.getint("Preprocessing", "test set size")
    test_set = sc.read_h5ad(cfg.get("Data", "test"))
    # Read the GAN
    gan = get_factory(cfg).get_gan()
    print("Loaded GAN")
    checkpoint = cfg.get("EXPERIMENT", "checkpoint")
    print("Using checkpoint at", checkpoint)
    # get real cells
    loader = get_loader(cfg.get("Data", "test"), cells_no)
    real_cells, real_labels = next(iter(loader))
    real_cells = real_cells.cpu().numpy()
    real_cells[:cells_no], real_labels[:cells_no]
    # get fake cells without perturbation
    fake_cells = gan.generate_cells(cells_no, checkpoint)
    #### LET USER DEFINE PATH IN CFG
    if "celltype" in test_set.obs:
        plot_UMAP(
            test_set.X,
            fake_cells,
            real_labels=test_set.obs.celltype.to_numpy(),
            fit=True,
            case_ctr="before_perturbation",
            save_path=cfg.get("Perturbation", "save dir"),
        )
    else:
        plot_UMAP(
            test_set.X,
            fake_cells,
            fit=True,
            case_ctr="before_perturbation",
            save_path=cfg.get("Perturbation", "save dir"),
        )
    gan.gen.tf_expressions = None
    gan.gen.pert_mode = True
    fake_cells = gan.generate_cells(cells_no, checkpoint)
    fake_cells_new = gan.generate_cells(cells_no, checkpoint)
    assert (
        fake_cells == fake_cells_new
    ).all(), "perturbation mode should be deterministic"
    tfs_to_perturb = parse_list(cfg.get("Perturbation", "tfs to perturb"), str)
    pert_values = parse_list(cfg.get("Perturbation", "perturbation values"), float)
    gene_names = list(test_set.var_names)
    tfs_idx = [gene_names.index(tf) for tf in tfs_to_perturb]
    tf_idx = [gan.gen.tfs.index(tf_idx) for tf_idx in tfs_idx]
    print(tf_idx)
    unperturbed_tfs = gan.gen.tf_expressions.clone()
    pert_tensor = torch.tensor(
        pert_values,
        device=gan.gen.tf_expressions.device,
        dtype=gan.gen.tf_expressions.dtype,
    )
    gan.gen.tf_expressions[:, tf_idx] = pert_tensor.unsqueeze(0)
    fake_cells_perturbed = gan.generate_cells(cells_no, checkpoint)
    gan.gen.tf_expressions = unperturbed_tfs
    if "celltype" in test_set.obs:
        plot_UMAP(
            test_set.X,
            fake_cells_perturbed,
            real_labels=test_set.obs.celltype.to_numpy(),
            fit=False,
            case_ctr="after_perturbation",
            save_path=cfg.get("Perturbation", "save dir"),
        )
    else:
        plot_UMAP(
            test_set.X,
            fake_cells_perturbed,
            fit=False,
            case_ctr="after_perturbation",
            save_path=cfg.get("Perturbation", "save dir"),
        )
    fake_cells = sc.AnnData(fake_cells)
    fake_cells.obs_names = np.repeat("ctr_fake", fake_cells.shape[0])
    fake_cells.obs_names_make_unique()
    fake_cells.write(cfg.get("Perturbation", "save dir") + "before_perturbation.h5ad")
    fake_cells_perturbed = sc.AnnData(fake_cells_perturbed)
    fake_cells_perturbed.obs_names = np.repeat(
        "pert_fake", fake_cells_perturbed.shape[0]
    )
    fake_cells_perturbed.obs_names_make_unique()
    fake_cells_perturbed.write(
        cfg.get("Perturbation", "save dir") + "after_perturbation.h5ad"
    )
    print(
        "Saved cells before and after perturbation to",
        cfg.get("Perturbation", "save dir"),
    )