"""Source code for the ``layers.cbn`` module."""

import torch
from torch import nn


class ConditionalBatchNorm(nn.Module):
    """1D conditional batch normalization (CBN) layer.

    Normalizes the input with a non-affine ``BatchNorm1d`` and then applies a
    per-class scale and shift looked up from an embedding table
    (Dumoulin et al., 2016; De Vries et al., 2017).
    """

    def __init__(self, num_features: int, num_classes: int) -> None:
        """
        Parameters
        ----------
        num_features : int
            Number of input features (channels).
        num_classes : int
            Number of classes (i.e., distinct labels).
        """
        super().__init__()
        self.num_features = num_features
        # Regular batch norm without learnable parameters; the affine part is
        # supplied per-label by the embedding below.
        self._batch_norm = nn.BatchNorm1d(num_features, affine=False)
        # Each row stores [scale (num_features) | shift (num_features)].
        self._embed = nn.Embedding(num_classes, num_features * 2)
        with torch.no_grad():
            # Initialise scale at N(1, 0.02) so the layer starts close to identity.
            self._embed.weight[:, :num_features].normal_(1.0, 0.02)
            # Initialise shift at 0.
            self._embed.weight[:, num_features:].zero_()

    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        """Perform CBN given a batch of labels.

        Parameters
        ----------
        x : torch.Tensor
            Tensor on which to perform CBN, of shape ``(N, C)`` or
            ``(N, C, L)`` (the shapes accepted by ``nn.BatchNorm1d``).
        y : torch.Tensor
            A batch of integer labels of shape ``(N,)``.

        Returns
        -------
        torch.Tensor
            Conditionally batch normalized input, same shape as ``x``.
        """
        out = self._batch_norm(x)
        # Separate per-sample scale and shift, each of shape (N, C).
        scale, shift = self._embed(y).chunk(2, dim=1)
        # For (N, C, L) input, add trailing singleton dims so the per-channel
        # modulation broadcasts over the length dimension instead of failing.
        while scale.dim() < out.dim():
            scale = scale.unsqueeze(-1)
            shift = shift.unsqueeze(-1)
        # Shift and scale activations based on the labels provided.
        return scale * out + shift