import math
import typing
import torch
import torch.nn as nn
class MaskedLinearFunction(torch.autograd.Function):
    """
    autograd function which masks its weights by 'mask'.
    """
    @staticmethod
    def forward(
        ctx: torch.autograd.function.FunctionCtx,
        input: torch.Tensor,
        weight: torch.Tensor,
        bias: typing.Optional[torch.Tensor] = None,
        mask: typing.Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        if mask is not None:
            weight = weight * mask
        output = input.mm(weight.t())
        if bias is not None:
            output += bias.unsqueeze(0).expand_as(output)
        # note: the saved weight is the already-masked weight
        ctx.save_for_backward(input, weight, bias, mask)
        return output
    @staticmethod
    def backward(ctx: torch.autograd.function.FunctionCtx, grad_output: torch.Tensor):
        input, weight, bias, mask = ctx.saved_tensors
        grad_input = grad_weight = grad_bias = grad_mask = None
        if ctx.needs_input_grad[0]:
            grad_input = grad_output.mm(weight)
        if ctx.needs_input_grad[1]:
            grad_weight = grad_output.t().mm(input)
            if mask is not None:
                # change grad_weight to 0 where mask == 0
                grad_weight = grad_weight * mask
        if ctx.needs_input_grad[2]:
            # sum over the batch dimension; shape (out_features,)
            grad_bias = grad_output.sum(0)
        return grad_input, grad_weight, grad_bias, grad_mask  
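

def _demo_masked_linear_function() -> None:
    """Minimal sketch (not part of the original module; the helper name is
    hypothetical) showing what MaskedLinearFunction computes: weights where
    mask == 0 never contribute to the output."""
    inp = torch.randn(2, 4)
    weight = torch.randn(3, 4)
    mask = torch.zeros(3, 4)
    mask[:, 0] = 1.0  # only input feature 0 is connected to every output
    out = MaskedLinearFunction.apply(inp, weight, None, mask)
    # equivalent to a plain linear map with the masked weight matrix
    assert torch.allclose(out, inp.mm((weight * mask).t()))
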
class MaskedLinear(nn.Module):
    def __init__(
        self,
        mask: torch.Tensor,
        bias: bool = True,
        device: typing.Optional[str] = "cuda" if torch.cuda.is_available() else "cpu",
    ):    
        """
        An extension of Pytorch's linear module based on the following thread:
        https://discuss.pytorch.org/t/custom-connections-in-neural-network-layers/3027/13
        Parameters
        ----------
        mask : torch.Tensor
            Mask Tensor with shape (n_input_feature, n_output_feature).
            the elements are 0 or 1 which declare un-connected or
            connected.
            Example: the following mask declares a 4-dim from-layer and 3-dim to-layer.
            Neurons 0, 2, and 3 of the from-layer are connected to neurons 0 and 2 of
            the to-layer. Neuron 1 of the from-layer is connected to neuron 1 of the
            to-layer.
            mask = torch.tensor(
            [[1, 0, 1],
            [0, 1, 0],
            [1, 0, 1],
            [1, 0, 1],]
            )
        bias : bool, optional
            By default True
        device : typing.Optional[str], optional
            Specifies to train on 'cpu' or 'cuda'. Only 'cuda' is supported for training the
            GAN but 'cpu' can be used for inference, by default "cuda" if torch.cuda.is_available() else"cpu".
        """
        super(MaskedLinear, self).__init__()
        self.input_features = mask.shape[0]
        self.output_features = mask.shape[1]
        self.device = device
        if isinstance(mask, torch.Tensor):
            self.mask = mask.type(torch.float).t()
        else:
            self.mask = torch.tensor(mask, dtype=torch.float).t()
        # keep the mask on the same device as the weight; it is fixed, not trained
        self.mask = nn.Parameter(self.mask.to(self.device), requires_grad=False)
        self.weight = nn.Parameter(
            torch.empty(self.output_features, self.input_features).to(self.device)
        )
        if bias:
            self.bias = nn.Parameter(torch.empty(self.output_features).to(self.device))
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()
        # mask weight
        self.weight.data = self.weight.data * self.mask 
    def reapply_mask(self):
        """Function to be called after weights have been initialized
        (e.g., using torch.nn.init) to reapply mask to weight."""
        # mask weight
        self.weight.data = self.weight.data * self.mask 
    def reset_parameters(self):
        stdv = 1.0 / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv) 
    def forward(self, input: torch.Tensor):
        return MaskedLinearFunction.apply(input, self.weight, self.bias, self.mask) 
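

def _demo_masked_linear() -> None:
    """Minimal usage sketch (not part of the original module; the helper name
    is hypothetical): build the 4 -> 3 layer from the docstring example and
    run a forward pass on CPU."""
    mask = torch.tensor(
        [
            [1, 0, 1],
            [0, 1, 0],
            [1, 0, 1],
            [1, 0, 1],
        ]
    )
    layer = MaskedLinear(mask, bias=True, device="cpu")
    x = torch.randn(5, 4)  # batch of 5 samples, 4 input features
    y = layer(x)
    assert y.shape == (5, 3)
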
 
if __name__ == "__main__":
    # run a numerical gradient check on the custom autograd function
    from torch.autograd import gradcheck
    customlinear = MaskedLinearFunction.apply
    input = (
        torch.randn(20, 20, dtype=torch.double, requires_grad=True),
        torch.randn(30, 20, dtype=torch.double, requires_grad=True),
        None,
        None,
    )
    test = gradcheck(customlinear, input, eps=1e-6, atol=1e-4)
    print(test)
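
    # Also check gradients with a mask supplied: a minimal sketch, assuming a
    # fixed (non-differentiable) 0/1 mask of shape (out_features, in_features).
    mask = (torch.rand(30, 20) > 0.5).double()
    input_with_mask = (
        torch.randn(20, 20, dtype=torch.double, requires_grad=True),
        torch.randn(30, 20, dtype=torch.double, requires_grad=True),
        None,
        mask,
    )
    print(gradcheck(customlinear, input_with_mask, eps=1e-6, atol=1e-4))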