import pickle
import typing
from abc import ABC, abstractmethod
from configparser import ConfigParser
import torch
from gans.causal_gan import CausalGAN
from gans.conditional_gan_cat import ConditionalCatGAN
from gans.conditional_gan_proj import ConditionalProjGAN
from gans.gan import GAN
def parse_list(str_list: str, type_: type) -> list:
    """
    Split a whitespace-separated string and convert every token.

    Parameters
    ----------
    str_list : str
        Whitespace-separated values, e.g. ``"256 256 256"``.
    type_ : type
        Callable applied to each token (e.g. ``int`` or ``float``).

    Returns
    -------
    list
        The converted tokens; empty for an empty/blank string.
    """
    return [type_(token) for token in str_list.split()]
class IGANFactory(ABC):
    """
    Abstract factory that represents a GAN.

    This factory does not keep track of created references: every call to
    :meth:`get_gan` builds a fresh instance from the config.
    """

    def __init__(self, parser: ConfigParser) -> None:
        """
        Initialize the factory.

        Parameters
        ----------
        parser : ConfigParser
            Parser for config file containing GAN model and training params.
        """
        self.parser = parser

    @abstractmethod
    def get_gan(self) -> GAN:
        """
        Return a GAN instance built from the config.

        Returns
        -------
        GAN
            GAN instance.
        """

    @abstractmethod
    def get_trainer(self) -> typing.Callable:
        """
        Return a zero-argument callable that runs the GAN's training.

        Returns
        -------
        typing.Callable
            Wrapper around the GAN's ``train()`` with config-derived args.
        """
class GANFactory(IGANFactory):
    """Factory for the plain (unconditional) GAN."""

    def get_gan(self) -> GAN:
        """
        Build an unconditional GAN from the config.

        Returns
        -------
        GAN
            GAN instance.
        """
        return GAN(
            genes_no=self.parser.getint("Data", "number of genes"),
            batch_size=self.parser.getint("Training", "batch size"),
            latent_dim=self.parser.getint("Model", "latent dim"),
            gen_layers=parse_list(self.parser["Model"]["generator layers"], int),
            crit_layers=parse_list(self.parser["Model"]["critic layers"], int),
            device=self.parser.get("EXPERIMENT", "device", fallback=None),
            library_size=self.parser.getint("Preprocessing", "library size"),
        )

    def get_trainer(self) -> typing.Callable:
        """
        Build a zero-argument closure that trains the GAN.

        Returns
        -------
        typing.Callable
            Callable invoking ``GAN.train()`` with config-derived arguments.
        """
        gan = self.get_gan()
        # NOTE(review): 'save_feq' mirrors the (apparently misspelled) keyword
        # of GAN.train(); renaming it here would break the call.
        return lambda: gan.train(
            train_files=self.parser.get("Data", "train"),
            valid_files=self.parser.get("Data", "validation"),
            critic_iter=self.parser.getint("Training", "critic iterations"),
            max_steps=self.parser.getint("Training", "maximum steps"),
            c_lambda=self.parser.getfloat("Model", "lambda"),
            beta1=self.parser.getfloat("Optimizer", "beta1"),
            beta2=self.parser.getfloat("Optimizer", "beta2"),
            gen_alpha_0=self.parser.getfloat("Learning Rate", "generator initial"),
            gen_alpha_final=self.parser.getfloat("Learning Rate", "generator final"),
            crit_alpha_0=self.parser.getfloat("Learning Rate", "critic initial"),
            crit_alpha_final=self.parser.getfloat("Learning Rate", "critic final"),
            checkpoint=self.parser.get("EXPERIMENT", "checkpoint", fallback=None),
            summary_freq=self.parser.getint("Logging", "summary frequency"),
            plt_freq=self.parser.getint("Logging", "plot frequency"),
            save_feq=self.parser.getint("Logging", "save frequency"),
            output_dir=self.parser.get("EXPERIMENT", "output directory"),
        )
class ConditionalCatGANFactory(IGANFactory):
    """Factory for the conditional GAN with concatenated labels."""

    def get_gan(self) -> ConditionalCatGAN:
        """
        Build a concatenation-conditioned GAN from the config.

        Returns
        -------
        ConditionalCatGAN
            Conditional GAN instance.
        """
        return ConditionalCatGAN(
            genes_no=self.parser.getint("Data", "number of genes"),
            batch_size=self.parser.getint("Training", "batch size"),
            latent_dim=self.parser.getint("Model", "latent dim"),
            gen_layers=parse_list(self.parser["Model"]["generator layers"], int),
            crit_layers=parse_list(self.parser["Model"]["critic layers"], int),
            num_classes=self.parser.getint("Data", "number of classes"),
            label_ratios=torch.Tensor(
                parse_list(self.parser["Data"]["label ratios"], float)
            ),
            device=self.parser.get("EXPERIMENT", "device", fallback=None),
            library_size=self.parser.getint("Preprocessing", "library size"),
        )

    def get_trainer(self) -> typing.Callable:
        """
        Build a zero-argument closure that trains the conditional GAN.

        Returns
        -------
        typing.Callable
            Callable invoking ``train()`` with config-derived arguments.
        """
        gan = self.get_gan()
        # NOTE(review): 'save_feq' mirrors the (apparently misspelled) keyword
        # of the train() signature; renaming it here would break the call.
        return lambda: gan.train(
            train_files=self.parser.get("Data", "train"),
            valid_files=self.parser.get("Data", "validation"),
            critic_iter=self.parser.getint("Training", "critic iterations"),
            max_steps=self.parser.getint("Training", "maximum steps"),
            c_lambda=self.parser.getfloat("Model", "lambda"),
            beta1=self.parser.getfloat("Optimizer", "beta1"),
            beta2=self.parser.getfloat("Optimizer", "beta2"),
            gen_alpha_0=self.parser.getfloat("Learning Rate", "generator initial"),
            gen_alpha_final=self.parser.getfloat("Learning Rate", "generator final"),
            crit_alpha_0=self.parser.getfloat("Learning Rate", "critic initial"),
            crit_alpha_final=self.parser.getfloat("Learning Rate", "critic final"),
            checkpoint=self.parser.get("EXPERIMENT", "checkpoint", fallback=None),
            summary_freq=self.parser.getint("Logging", "summary frequency"),
            plt_freq=self.parser.getint("Logging", "plot frequency"),
            save_feq=self.parser.getint("Logging", "save frequency"),
            output_dir=self.parser.get("EXPERIMENT", "output directory"),
        )
class ConditionalProjGANFactory(IGANFactory):
    """Factory for the conditional GAN with projection discriminator."""

    def get_gan(self) -> ConditionalProjGAN:
        """
        Build a projection-conditioned GAN from the config.

        Returns
        -------
        ConditionalProjGAN
            Conditional GAN instance.
        """
        return ConditionalProjGAN(
            genes_no=self.parser.getint("Data", "number of genes"),
            batch_size=self.parser.getint("Training", "batch size"),
            latent_dim=self.parser.getint("Model", "latent dim"),
            gen_layers=parse_list(self.parser["Model"]["generator layers"], int),
            crit_layers=parse_list(self.parser["Model"]["critic layers"], int),
            num_classes=self.parser.getint("Data", "number of classes"),
            label_ratios=torch.Tensor(
                parse_list(self.parser["Data"]["label ratios"], float)
            ),
            device=self.parser.get("EXPERIMENT", "device", fallback=None),
            library_size=self.parser.getint("Preprocessing", "library size"),
        )

    def get_trainer(self) -> typing.Callable:
        """
        Build a zero-argument closure that trains the conditional GAN.

        Returns
        -------
        typing.Callable
            Callable invoking ``train()`` with config-derived arguments.
        """
        gan = self.get_gan()
        # NOTE(review): 'save_feq' mirrors the (apparently misspelled) keyword
        # of the train() signature; renaming it here would break the call.
        return lambda: gan.train(
            train_files=self.parser.get("Data", "train"),
            valid_files=self.parser.get("Data", "validation"),
            critic_iter=self.parser.getint("Training", "critic iterations"),
            max_steps=self.parser.getint("Training", "maximum steps"),
            c_lambda=self.parser.getfloat("Model", "lambda"),
            beta1=self.parser.getfloat("Optimizer", "beta1"),
            beta2=self.parser.getfloat("Optimizer", "beta2"),
            gen_alpha_0=self.parser.getfloat("Learning Rate", "generator initial"),
            gen_alpha_final=self.parser.getfloat("Learning Rate", "generator final"),
            crit_alpha_0=self.parser.getfloat("Learning Rate", "critic initial"),
            crit_alpha_final=self.parser.getfloat("Learning Rate", "critic final"),
            checkpoint=self.parser.get("EXPERIMENT", "checkpoint", fallback=None),
            summary_freq=self.parser.getint("Logging", "summary frequency"),
            plt_freq=self.parser.getint("Logging", "plot frequency"),
            save_feq=self.parser.getint("Logging", "save frequency"),
            output_dir=self.parser.get("EXPERIMENT", "output directory"),
        )
class CausalGANFactory(IGANFactory):
    """Factory for the causal GAN and its causal-controller (CC) pretraining."""

    def get_cc(self) -> GAN:
        """
        Build the causal controller (a plain GAN) from the 'CC *' sections.

        Returns
        -------
        GAN
            Causal controller instance.
        """
        return GAN(
            genes_no=self.parser.getint("Data", "number of genes"),
            batch_size=self.parser.getint("CC Training", "batch size"),
            latent_dim=self.parser.getint("CC Model", "latent dim"),
            gen_layers=parse_list(self.parser["CC Model"]["generator layers"], int),
            crit_layers=parse_list(self.parser["CC Model"]["critic layers"], int),
            device=self.parser.get("EXPERIMENT", "device", fallback=None),
            library_size=self.parser.getint("Preprocessing", "library size"),
        )

    def get_gan(self) -> CausalGAN:
        """
        Build the causal GAN, wired to the pretrained CC checkpoint.

        The CC checkpoint path is derived from the experiment output directory
        and the CC's configured maximum training step, so :meth:`get_trainer`
        must have pretrained the CC before the returned GAN can load it.

        Returns
        -------
        CausalGAN
            Causal GAN instance.
        """
        # SECURITY: pickle.load executes arbitrary code from the file; only
        # use causal-graph files from a trusted source.
        with open(self.parser.get("Data", "causal graph"), "rb") as fp:
            causal_graph = pickle.load(fp)
        cc_checkpoint = (
            self.parser.get("EXPERIMENT", "output directory")
            + f"_CC/checkpoints/step_{self.parser.getint('CC Training', 'maximum steps')}.pth"
        )
        return CausalGAN(
            genes_no=self.parser.getint("Data", "number of genes"),
            batch_size=self.parser.getint("Training", "batch size"),
            latent_dim=self.parser.getint("Model", "latent dim"),
            noise_per_gene=self.parser.getint("Model", "noise per gene"),
            depth_per_gene=self.parser.getint("Model", "depth per gene"),
            width_per_gene=self.parser.getint("Model", "width per gene"),
            cc_latent_dim=self.parser.getint("CC Model", "latent dim"),
            cc_layers=parse_list(self.parser["CC Model"]["generator layers"], int),
            cc_pretrained_checkpoint=cc_checkpoint,
            crit_layers=parse_list(self.parser["Model"]["critic layers"], int),
            causal_graph=causal_graph,
            labeler_layers=parse_list(self.parser["Model"]["labeler layers"], int),
            device=self.parser.get("EXPERIMENT", "device", fallback=None),
            library_size=self.parser.getint("Preprocessing", "library size"),
        )

    def get_trainer(self) -> typing.Callable:
        """
        Build a zero-argument closure running the two-stage training.

        Stage 1 trains the causal controller for the maximum steps given in
        the 'CC Training' section; stage 2 then instantiates the causal GAN
        with the pretrained controller checkpoint and trains it.

        Returns
        -------
        typing.Callable
            Callable that runs both stages and returns the CC train() result.
        """
        cc = self.get_cc()

        def train():
            # Stage 1: pretrain the causal controller; its checkpoints land in
            # '<output directory>_CC' where get_gan() expects them.
            cc_result = cc.train(
                train_files=self.parser.get("Data", "train"),
                valid_files=self.parser.get("Data", "validation"),
                critic_iter=self.parser.getint("CC Training", "critic iterations"),
                max_steps=self.parser.getint("CC Training", "maximum steps"),
                c_lambda=self.parser.getfloat("CC Model", "lambda"),
                beta1=self.parser.getfloat("CC Optimizer", "beta1"),
                beta2=self.parser.getfloat("CC Optimizer", "beta2"),
                gen_alpha_0=self.parser.getfloat("CC Learning Rate", "generator initial"),
                gen_alpha_final=self.parser.getfloat("CC Learning Rate", "generator final"),
                crit_alpha_0=self.parser.getfloat("CC Learning Rate", "critic initial"),
                crit_alpha_final=self.parser.getfloat("CC Learning Rate", "critic final"),
                checkpoint=self.parser.get("EXPERIMENT", "checkpoint", fallback=None),
                summary_freq=self.parser.getint("CC Logging", "summary frequency"),
                plt_freq=self.parser.getint("CC Logging", "plot frequency"),
                save_feq=self.parser.getint("CC Logging", "save frequency"),
                output_dir=self.parser.get("EXPERIMENT", "output directory") + "_CC",
            )
            # Stage 2: build the causal GAN (loads the CC checkpoint) and train.
            self.get_gan().train(
                train_files=self.parser.get("Data", "train"),
                valid_files=self.parser.get("Data", "validation"),
                critic_iter=self.parser.getint("Training", "critic iterations"),
                max_steps=self.parser.getint("Training", "maximum steps"),
                c_lambda=self.parser.getfloat("Model", "lambda"),
                beta1=self.parser.getfloat("Optimizer", "beta1"),
                beta2=self.parser.getfloat("Optimizer", "beta2"),
                gen_alpha_0=self.parser.getfloat("Learning Rate", "generator initial"),
                gen_alpha_final=self.parser.getfloat("Learning Rate", "generator final"),
                crit_alpha_0=self.parser.getfloat("Learning Rate", "critic initial"),
                crit_alpha_final=self.parser.getfloat("Learning Rate", "critic final"),
                labeler_alpha=self.parser.getfloat("Learning Rate", "labeler"),
                antilabeler_alpha=self.parser.getfloat("Learning Rate", "antilabeler"),
                labeler_training_interval=self.parser.getfloat(
                    "Training", "labeler and antilabeler training intervals"
                ),
                checkpoint=self.parser.get("EXPERIMENT", "checkpoint", fallback=None),
                summary_freq=self.parser.getint("Logging", "summary frequency"),
                plt_freq=self.parser.getint("Logging", "plot frequency"),
                save_feq=self.parser.getint("Logging", "save frequency"),
                output_dir=self.parser.get("EXPERIMENT", "output directory"),
            )
            # Match the original lambda's `(...)[0]`: return the CC result.
            return cc_result

        return train
def get_factory(cfg: ConfigParser) -> IGANFactory:
    """
    Return the factory for the GAN type based on 'type' key in the [Model]
    section of the parser.

    Parameters
    ----------
    cfg : ConfigParser
        Parser for config file containing GAN model and training params.

    Returns
    -------
    IGANFactory
        Factory for the specified GAN.

    Raises
    ------
    ValueError
        If the model is unknown or not implemented.
    """
    # Map model names to factory classes; only the requested factory is
    # instantiated (the original built all four up front).
    factories = {
        "GAN": GANFactory,
        "proj conditional GAN": ConditionalProjGANFactory,
        "cat conditional GAN": ConditionalCatGANFactory,
        "causal GAN": CausalGANFactory,
    }
    # read the desired GAN
    model = cfg.get("Model", "type")
    try:
        return factories[model](cfg)
    except KeyError:
        raise ValueError(f"model '{model}' type is invalid") from None