import pickle
import typing
from abc import ABC, abstractmethod
from configparser import ConfigParser
import torch
from gans.causal_gan import CausalGAN
from gans.conditional_gan_cat import ConditionalCatGAN
from gans.conditional_gan_proj import ConditionalProjGAN
from gans.gan import GAN
def parse_list(str_list: str, type_: type) -> list:
    """
    Parse a whitespace-separated string into a list of converted values.

    Parameters
    ----------
    str_list : str
        Values separated by any run of whitespace (e.g. ``"256 512 1024"``).
    type_ : type
        Callable applied to every token, e.g. ``int`` or ``float``.

    Returns
    -------
    list
        The converted tokens in their original order. An empty or
        all-whitespace string yields an empty list.

    Raises
    ------
    Exception
        Whatever ``type_`` raises for a token it cannot convert
        (e.g. ``ValueError`` from ``int("x")``) propagates unchanged.
    """
    return [type_(token) for token in str_list.split()]
class IGANFactory(ABC):
    """
    Abstract factory for a GAN model and its training routine.

    This factory does not keep track of the objects it creates.
    """

    def __init__(self, parser: ConfigParser) -> None:
        """
        Initialize the factory.

        Parameters
        ----------
        parser : ConfigParser
            Parser for the config file containing the GAN model and
            training parameters.
        """
        self.parser = parser

    @abstractmethod
    def get_gan(self) -> GAN:
        """
        Return a GAN instance.

        Returns
        -------
        GAN
            GAN instance configured from ``self.parser``.
        """

    @abstractmethod
    def get_trainer(self) -> typing.Callable:
        """
        Return the GAN training function.

        Returns
        -------
        typing.Callable
            Callable that runs the GAN's ``train()`` when invoked.
        """
class GANFactory(IGANFactory):
    """Factory for the unconditional :class:`GAN`."""

    def get_gan(self) -> GAN:
        """
        Build a GAN from the config.

        Reads the ``Data``, ``Training``, ``Model``, ``EXPERIMENT`` and
        ``Preprocessing`` sections. The ``EXPERIMENT/device`` option is
        optional; ``None`` is passed when it is absent.

        Returns
        -------
        GAN
            Newly constructed GAN instance.
        """
        return GAN(
            genes_no=self.parser.getint("Data", "number of genes"),
            batch_size=self.parser.getint("Training", "batch size"),
            latent_dim=self.parser.getint("Model", "latent dim"),
            gen_layers=parse_list(self.parser["Model"]["generator layers"], int),
            crit_layers=parse_list(self.parser["Model"]["critic layers"], int),
            device=self.parser.get("EXPERIMENT", "device", fallback=None),
            library_size=self.parser.getint("Preprocessing", "library size"),
        )

    def get_trainer(self) -> typing.Callable:
        """
        Build the GAN now and return a zero-argument callable that trains it.

        The training hyper-parameters are read from the config when the
        returned callable is invoked, not when this method is called.

        Returns
        -------
        typing.Callable
            Zero-argument callable that runs ``gan.train(...)`` and returns
            its result.
        """
        gan = self.get_gan()
        # ``save_feq`` (sic) is passed verbatim as a keyword of ``train``;
        # presumably it can only be renamed together with that signature.
        return lambda: gan.train(
            train_files=self.parser.get("Data", "train"),
            valid_files=self.parser.get("Data", "validation"),
            critic_iter=self.parser.getint("Training", "critic iterations"),
            max_steps=self.parser.getint("Training", "maximum steps"),
            c_lambda=self.parser.getfloat("Model", "lambda"),
            beta1=self.parser.getfloat("Optimizer", "beta1"),
            beta2=self.parser.getfloat("Optimizer", "beta2"),
            gen_alpha_0=self.parser.getfloat("Learning Rate", "generator initial"),
            gen_alpha_final=self.parser.getfloat("Learning Rate", "generator final"),
            crit_alpha_0=self.parser.getfloat("Learning Rate", "critic initial"),
            crit_alpha_final=self.parser.getfloat("Learning Rate", "critic final"),
            checkpoint=self.parser.get("EXPERIMENT", "checkpoint", fallback=None),
            summary_freq=self.parser.getint("Logging", "summary frequency"),
            plt_freq=self.parser.getint("Logging", "plot frequency"),
            save_feq=self.parser.getint("Logging", "save frequency"),
            output_dir=self.parser.get("EXPERIMENT", "output directory"),
        )
class ConditionalCatGANFactory(IGANFactory):
    """Factory for :class:`ConditionalCatGAN`."""

    def get_gan(self) -> ConditionalCatGAN:
        """
        Build a ConditionalCatGAN from the config.

        In addition to the unconditional GAN options, this reads
        ``Data/number of classes`` and ``Data/label ratios``. The label
        ratios are parsed as whitespace-separated floats into a
        ``torch.Tensor``. NOTE(review): presumably one ratio per class; the
        length is not validated here.

        Returns
        -------
        ConditionalCatGAN
            Newly constructed model instance.
        """
        return ConditionalCatGAN(
            genes_no=self.parser.getint("Data", "number of genes"),
            batch_size=self.parser.getint("Training", "batch size"),
            latent_dim=self.parser.getint("Model", "latent dim"),
            gen_layers=parse_list(self.parser["Model"]["generator layers"], int),
            crit_layers=parse_list(self.parser["Model"]["critic layers"], int),
            num_classes=self.parser.getint("Data", "number of classes"),
            label_ratios=torch.Tensor(
                parse_list(self.parser["Data"]["label ratios"], float)
            ),
            device=self.parser.get("EXPERIMENT", "device", fallback=None),
            library_size=self.parser.getint("Preprocessing", "library size"),
        )

    def get_trainer(self) -> typing.Callable:
        """
        Build the model now and return a zero-argument callable that trains it.

        The training hyper-parameters are read from the config when the
        returned callable is invoked, not when this method is called.

        Returns
        -------
        typing.Callable
            Zero-argument callable that runs ``gan.train(...)`` and returns
            its result.
        """
        gan = self.get_gan()
        # ``save_feq`` (sic) is passed verbatim as a keyword of ``train``;
        # presumably it can only be renamed together with that signature.
        return lambda: gan.train(
            train_files=self.parser.get("Data", "train"),
            valid_files=self.parser.get("Data", "validation"),
            critic_iter=self.parser.getint("Training", "critic iterations"),
            max_steps=self.parser.getint("Training", "maximum steps"),
            c_lambda=self.parser.getfloat("Model", "lambda"),
            beta1=self.parser.getfloat("Optimizer", "beta1"),
            beta2=self.parser.getfloat("Optimizer", "beta2"),
            gen_alpha_0=self.parser.getfloat("Learning Rate", "generator initial"),
            gen_alpha_final=self.parser.getfloat("Learning Rate", "generator final"),
            crit_alpha_0=self.parser.getfloat("Learning Rate", "critic initial"),
            crit_alpha_final=self.parser.getfloat("Learning Rate", "critic final"),
            checkpoint=self.parser.get("EXPERIMENT", "checkpoint", fallback=None),
            summary_freq=self.parser.getint("Logging", "summary frequency"),
            plt_freq=self.parser.getint("Logging", "plot frequency"),
            save_feq=self.parser.getint("Logging", "save frequency"),
            output_dir=self.parser.get("EXPERIMENT", "output directory"),
        )
class ConditionalProjGANFactory(IGANFactory):
    """Factory for :class:`ConditionalProjGAN`."""

    def get_gan(self) -> ConditionalProjGAN:
        """
        Build a ConditionalProjGAN from the config.

        In addition to the unconditional GAN options, this reads
        ``Data/number of classes`` and ``Data/label ratios``. The label
        ratios are parsed as whitespace-separated floats into a
        ``torch.Tensor``. NOTE(review): presumably one ratio per class; the
        length is not validated here.

        Returns
        -------
        ConditionalProjGAN
            Newly constructed model instance.
        """
        return ConditionalProjGAN(
            genes_no=self.parser.getint("Data", "number of genes"),
            batch_size=self.parser.getint("Training", "batch size"),
            latent_dim=self.parser.getint("Model", "latent dim"),
            gen_layers=parse_list(self.parser["Model"]["generator layers"], int),
            crit_layers=parse_list(self.parser["Model"]["critic layers"], int),
            num_classes=self.parser.getint("Data", "number of classes"),
            label_ratios=torch.Tensor(
                parse_list(self.parser["Data"]["label ratios"], float)
            ),
            device=self.parser.get("EXPERIMENT", "device", fallback=None),
            library_size=self.parser.getint("Preprocessing", "library size"),
        )

    def get_trainer(self) -> typing.Callable:
        """
        Build the model now and return a zero-argument callable that trains it.

        The training hyper-parameters are read from the config when the
        returned callable is invoked, not when this method is called.

        Returns
        -------
        typing.Callable
            Zero-argument callable that runs ``gan.train(...)`` and returns
            its result.
        """
        gan = self.get_gan()
        # ``save_feq`` (sic) is passed verbatim as a keyword of ``train``;
        # presumably it can only be renamed together with that signature.
        return lambda: gan.train(
            train_files=self.parser.get("Data", "train"),
            valid_files=self.parser.get("Data", "validation"),
            critic_iter=self.parser.getint("Training", "critic iterations"),
            max_steps=self.parser.getint("Training", "maximum steps"),
            c_lambda=self.parser.getfloat("Model", "lambda"),
            beta1=self.parser.getfloat("Optimizer", "beta1"),
            beta2=self.parser.getfloat("Optimizer", "beta2"),
            gen_alpha_0=self.parser.getfloat("Learning Rate", "generator initial"),
            gen_alpha_final=self.parser.getfloat("Learning Rate", "generator final"),
            crit_alpha_0=self.parser.getfloat("Learning Rate", "critic initial"),
            crit_alpha_final=self.parser.getfloat("Learning Rate", "critic final"),
            checkpoint=self.parser.get("EXPERIMENT", "checkpoint", fallback=None),
            summary_freq=self.parser.getint("Logging", "summary frequency"),
            plt_freq=self.parser.getint("Logging", "plot frequency"),
            save_feq=self.parser.getint("Logging", "save frequency"),
            output_dir=self.parser.get("EXPERIMENT", "output directory"),
        )
class CausalGANFactory(IGANFactory):
    """
    Factory for :class:`CausalGAN`.

    Training runs in two phases:

    1. A causal controller is trained. It is an unconditional :class:`GAN`
       configured by the ``CC *`` config sections.
    2. The CausalGAN is then built and trained. Its
       ``cc_pretrained_checkpoint`` points at
       ``<output directory>_CC/checkpoints/step_<CC maximum steps>.pth``.
       This is presumably the final checkpoint written by phase 1.
    """

    def _cc_output_dir(self) -> str:
        """Return the output directory used for causal-controller training."""
        # Shared by get_gan (checkpoint path) and get_trainer (phase-1
        # output_dir) so the two locations cannot drift apart.
        return self.parser.get("EXPERIMENT", "output directory") + "_CC"

    def get_cc(self) -> GAN:
        """
        Build the causal controller.

        Reads ``Data``, ``CC Training``, ``CC Model``, ``EXPERIMENT`` and
        ``Preprocessing``.

        Returns
        -------
        GAN
            Newly constructed, untrained causal-controller GAN.
        """
        return GAN(
            genes_no=self.parser.getint("Data", "number of genes"),
            batch_size=self.parser.getint("CC Training", "batch size"),
            latent_dim=self.parser.getint("CC Model", "latent dim"),
            gen_layers=parse_list(self.parser["CC Model"]["generator layers"], int),
            crit_layers=parse_list(self.parser["CC Model"]["critic layers"], int),
            device=self.parser.get("EXPERIMENT", "device", fallback=None),
            library_size=self.parser.getint("Preprocessing", "library size"),
        )

    def get_gan(self) -> CausalGAN:
        """
        Build the CausalGAN.

        The causal graph is unpickled from the path in ``Data/causal graph``.

        Returns
        -------
        CausalGAN
            Newly constructed model instance.
        """
        # NOTE(security): pickle.load can execute arbitrary code; the causal
        # graph file must come from a trusted source.
        with open(self.parser.get("Data", "causal graph"), "rb") as fp:
            causal_graph = pickle.load(fp)
        cc_checkpoint = (
            f"{self._cc_output_dir()}/checkpoints/"
            f"step_{self.parser.getint('CC Training', 'maximum steps')}.pth"
        )
        return CausalGAN(
            genes_no=self.parser.getint("Data", "number of genes"),
            batch_size=self.parser.getint("Training", "batch size"),
            latent_dim=self.parser.getint("Model", "latent dim"),
            noise_per_gene=self.parser.getint("Model", "noise per gene"),
            depth_per_gene=self.parser.getint("Model", "depth per gene"),
            width_per_gene=self.parser.getint("Model", "width per gene"),
            cc_latent_dim=self.parser.getint("CC Model", "latent dim"),
            cc_layers=parse_list(self.parser["CC Model"]["generator layers"], int),
            cc_pretrained_checkpoint=cc_checkpoint,
            crit_layers=parse_list(self.parser["Model"]["critic layers"], int),
            causal_graph=causal_graph,
            labeler_layers=parse_list(self.parser["Model"]["labeler layers"], int),
            device=self.parser.get("EXPERIMENT", "device", fallback=None),
            library_size=self.parser.getint("Preprocessing", "library size"),
        )

    def get_trainer(self) -> typing.Callable:
        """
        Return a zero-argument callable that runs both training phases in order.

        The causal controller is built immediately. The CausalGAN is built
        only inside the callable, after the controller has been trained.

        Returns
        -------
        typing.Callable
            Zero-argument callable. It returns the value of the causal
            controller's ``train`` call; the CausalGAN's ``train`` result is
            discarded.
        """
        cc = self.get_cc()

        def train():
            # Phase 1: train the causal controller for the number of steps in
            # the CC Training section.
            cc_result = cc.train(
                train_files=self.parser.get("Data", "train"),
                valid_files=self.parser.get("Data", "validation"),
                critic_iter=self.parser.getint("CC Training", "critic iterations"),
                max_steps=self.parser.getint("CC Training", "maximum steps"),
                c_lambda=self.parser.getfloat("CC Model", "lambda"),
                beta1=self.parser.getfloat("CC Optimizer", "beta1"),
                beta2=self.parser.getfloat("CC Optimizer", "beta2"),
                gen_alpha_0=self.parser.getfloat(
                    "CC Learning Rate", "generator initial"
                ),
                gen_alpha_final=self.parser.getfloat(
                    "CC Learning Rate", "generator final"
                ),
                crit_alpha_0=self.parser.getfloat("CC Learning Rate", "critic initial"),
                crit_alpha_final=self.parser.getfloat(
                    "CC Learning Rate", "critic final"
                ),
                checkpoint=self.parser.get("EXPERIMENT", "checkpoint", fallback=None),
                summary_freq=self.parser.getint("CC Logging", "summary frequency"),
                plt_freq=self.parser.getint("CC Logging", "plot frequency"),
                save_feq=self.parser.getint("CC Logging", "save frequency"),
                output_dir=self._cc_output_dir(),
            )
            # Phase 2: instantiate the CausalGAN only now, so it is built with
            # the pretrained causal controller, then train it.
            self.get_gan().train(
                train_files=self.parser.get("Data", "train"),
                valid_files=self.parser.get("Data", "validation"),
                critic_iter=self.parser.getint("Training", "critic iterations"),
                max_steps=self.parser.getint("Training", "maximum steps"),
                c_lambda=self.parser.getfloat("Model", "lambda"),
                beta1=self.parser.getfloat("Optimizer", "beta1"),
                beta2=self.parser.getfloat("Optimizer", "beta2"),
                gen_alpha_0=self.parser.getfloat("Learning Rate", "generator initial"),
                gen_alpha_final=self.parser.getfloat(
                    "Learning Rate", "generator final"
                ),
                crit_alpha_0=self.parser.getfloat("Learning Rate", "critic initial"),
                crit_alpha_final=self.parser.getfloat("Learning Rate", "critic final"),
                labeler_alpha=self.parser.getfloat("Learning Rate", "labeler"),
                antilabeler_alpha=self.parser.getfloat("Learning Rate", "antilabeler"),
                # NOTE(review): read as a float although it looks like a step
                # interval; confirm whether getint was intended.
                labeler_training_interval=self.parser.getfloat(
                    "Training", "labeler and antilabeler training intervals"
                ),
                checkpoint=self.parser.get("EXPERIMENT", "checkpoint", fallback=None),
                summary_freq=self.parser.getint("Logging", "summary frequency"),
                plt_freq=self.parser.getint("Logging", "plot frequency"),
                save_feq=self.parser.getint("Logging", "save frequency"),
                output_dir=self.parser.get("EXPERIMENT", "output directory"),
            )
            return cc_result

        return train
def get_factory(cfg: ConfigParser) -> IGANFactory:
    """
    Return the factory for the GAN type named in the config.

    The GAN type is read from the ``type`` option of the ``Model`` section.

    Parameters
    ----------
    cfg : ConfigParser
        Parser for the config file containing the GAN model and training
        parameters.

    Returns
    -------
    IGANFactory
        Factory for the specified GAN.

    Raises
    ------
    ValueError
        If the model is unknown or not implemented.
    """
    # Map names to classes so only the selected factory is instantiated.
    factory_classes = {
        "GAN": GANFactory,
        "proj conditional GAN": ConditionalProjGANFactory,
        "cat conditional GAN": ConditionalCatGANFactory,
        "causal GAN": CausalGANFactory,
    }
    model = cfg.get("Model", "type")
    try:
        factory_cls = factory_classes[model]
    except KeyError:
        raise ValueError(f"model '{model}' type is invalid") from None
    return factory_cls(cfg)