# Source code for dowhy.causal_prediction.models.networks

# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved

""" 
    The MNIST_MLP architecture is borrowed from OoD-Bench:
        @inproceedings{ye2022ood,
         title={OoD-Bench: Quantifying and Understanding Two Dimensions of Out-of-Distribution Generalization},
         author={Ye, Nanyang and Li, Kaican and Bai, Haoyue and Yu, Runpeng and Hong, Lanqing and Zhou, Fengwei and Li, Zhenguo and Zhu, Jun},
         booktitle={CVPR},
         year={2022}
        }
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models


class Identity(nn.Module):
    """A no-op layer whose forward pass returns its input unchanged.

    Handy as a drop-in replacement when an existing layer should be
    disabled without altering the surrounding module structure.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Pass the input straight through, untouched.
        return x
class MLP(nn.Module):
    """A plain multi-layer perceptron.

    Layout: Linear(n_inputs -> mlp_width), then (mlp_depth - 2) hidden
    Linear(mlp_width -> mlp_width) layers, then Linear(mlp_width -> n_outputs).
    Dropout followed by ReLU is applied after the input layer and after
    each hidden layer; the output layer has no activation.

    :param n_inputs: size of each input sample
    :param n_outputs: size of each output sample (also stored as ``self.n_outputs``)
    :param mlp_width: width of every hidden layer
    :param mlp_depth: total number of Linear layers
    :param mlp_dropout: dropout probability shared by all dropout applications
    """

    def __init__(self, n_inputs, n_outputs, mlp_width, mlp_depth, mlp_dropout):
        super().__init__()
        # Keep the construction order (input -> dropout -> hiddens -> output)
        # so parameter initialization draws from the RNG identically.
        self.input = nn.Linear(n_inputs, mlp_width)
        self.dropout = nn.Dropout(mlp_dropout)
        self.hiddens = nn.ModuleList(
            nn.Linear(mlp_width, mlp_width) for _ in range(mlp_depth - 2)
        )
        self.output = nn.Linear(mlp_width, n_outputs)
        self.n_outputs = n_outputs

    def forward(self, x):
        """Map a batch of inputs through the MLP."""
        h = F.relu(self.dropout(self.input(x)))
        for layer in self.hiddens:
            h = F.relu(self.dropout(layer(h)))
        return self.output(h)
class MNIST_MLP(nn.Module):
    """Two-layer fully connected encoder for flattened image input.

    Flattens each input sample and maps it through
    Linear -> ReLU -> Linear -> ReLU to a 390-dimensional feature vector.
    Linear weights use Xavier-uniform init with the ReLU gain; biases are zero.

    :param input_shape: (channels, height, width) of a single input sample;
        the product of the three entries is the flattened input size.
    """

    def __init__(self, input_shape):
        super().__init__()
        self.hdim = hdim = 390
        in_features = input_shape[0] * input_shape[1] * input_shape[2]
        self.encoder = nn.Sequential(
            nn.Linear(in_features, hdim),
            nn.ReLU(True),
            nn.Linear(hdim, hdim),
            nn.ReLU(True),
        )
        self.n_outputs = hdim

        # Xavier-uniform init tuned for ReLU; gain is loop-invariant.
        relu_gain = nn.init.calculate_gain("relu")
        for layer in self.encoder:
            if isinstance(layer, nn.Linear):
                nn.init.xavier_uniform_(layer.weight, gain=relu_gain)
                nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Flatten each sample and return its 390-dim encoding."""
        flat = x.view(x.size(0), -1)
        return self.encoder(flat)
class ResNet(torch.nn.Module):
    """Pretrained torchvision ResNet used as a feature encoder.

    The final fully connected layer is replaced by ``Identity`` so the
    network outputs the penultimate features, and all BatchNorm layers
    are kept frozen in eval mode (including during training).

    :param input_shape: (channels, height, width) of the input; if the
        channel count differs from 3, the first conv is rebuilt and its
        weights are tiled from the pretrained 3-channel kernel.
    :param resnet18: use resnet18 (512-dim features) when True,
        otherwise resnet50 (2048-dim features)
    :param resnet_dropout: dropout probability applied to the features
    """

    def __init__(self, input_shape, resnet18=True, resnet_dropout=0.0):
        super().__init__()
        if resnet18:
            self.network = torchvision.models.resnet18(pretrained=True)
            self.n_outputs = 512
        else:
            self.network = torchvision.models.resnet50(pretrained=True)
            self.n_outputs = 2048

        # Adapt the stem conv when the input is not 3-channel: rebuild it
        # and fill each input channel by cycling over the pretrained RGB kernels.
        nc = input_shape[0]
        if nc != 3:
            pretrained_kernel = self.network.conv1.weight.data.clone()
            self.network.conv1 = nn.Conv2d(
                nc, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
            )
            for channel in range(nc):
                self.network.conv1.weight.data[:, channel, :, :] = pretrained_kernel[:, channel % 3, :, :]

        # Drop the classification head to save memory, then stub it out.
        del self.network.fc
        self.network.fc = Identity()

        self.freeze_bn()
        self.dropout = nn.Dropout(resnet_dropout)

    def forward(self, x):
        """Encode x into a feature vector of size n_outputs."""
        return self.dropout(self.network(x))

    def train(self, mode=True):
        """Switch mode as usual, but keep BatchNorm layers frozen in eval."""
        super().train(mode)
        self.freeze_bn()

    def freeze_bn(self):
        """Put every BatchNorm2d submodule into eval mode."""
        for module in self.network.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.eval()
class MNIST_CNN(nn.Module):
    """Hand-tuned convolutional encoder for MNIST-style images.

    Four 3x3 conv layers (the second with stride 2), each followed by
    ReLU then GroupNorm, ending in a global average pool to a 128-dim
    feature vector.

    Weirdness noticed so far with this architecture: adding a linear
    layer after the mean-pool in features hurts RotatedMNIST-100
    generalization severely.

    :param input_shape: (channels, height, width) of the input images
    """

    # Feature dimension produced by forward().
    n_outputs = 128

    def __init__(self, input_shape):
        super().__init__()
        self.conv1 = nn.Conv2d(input_shape[0], 64, 3, 1, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(128, 128, 3, 1, padding=1)
        self.conv4 = nn.Conv2d(128, 128, 3, 1, padding=1)

        self.bn0 = nn.GroupNorm(8, 64)
        self.bn1 = nn.GroupNorm(8, 128)
        self.bn2 = nn.GroupNorm(8, 128)
        self.bn3 = nn.GroupNorm(8, 128)

        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        """Return a (batch, 128) feature tensor for the input images."""
        h = self.bn0(F.relu(self.conv1(x)))
        h = self.bn1(F.relu(self.conv2(h)))
        h = self.bn2(F.relu(self.conv3(h)))
        h = self.bn3(F.relu(self.conv4(h)))
        h = self.avgpool(h)
        return h.view(len(h), -1)
class ContextNet(nn.Module):
    """Small conv net that maps an image to a one-channel map of the same spatial size.

    Three 5x5 convolutions (64 -> 64 -> 1 channels); the first two are
    followed by BatchNorm and ReLU. Padding is chosen so height and
    width are preserved.

    :param input_shape: (channels, height, width) of the input images
    """

    def __init__(self, input_shape):
        super().__init__()
        kernel_size = 5
        # Same-padding for an odd kernel keeps spatial dimensions unchanged.
        pad = (kernel_size - 1) // 2
        self.context_net = nn.Sequential(
            nn.Conv2d(input_shape[0], 64, kernel_size, padding=pad),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel_size, padding=pad),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 1, kernel_size, padding=pad),
        )

    def forward(self, x):
        """Return the single-channel context map for x."""
        return self.context_net(x)
def Classifier(in_features, out_features, is_nonlinear=False):
    """Build a classifier head on top of encoded features.

    :param in_features: dimension of the input features
    :param out_features: number of output logits
    :param is_nonlinear: when True, return a 3-layer MLP whose hidden
        widths halve at each step (in // 2, then in // 4) with ReLU
        activations; otherwise return a single Linear layer.
    :returns: a ``torch.nn.Module`` mapping (batch, in_features) to
        (batch, out_features)
    """
    if not is_nonlinear:
        return torch.nn.Linear(in_features, out_features)
    return torch.nn.Sequential(
        torch.nn.Linear(in_features, in_features // 2),
        torch.nn.ReLU(),
        torch.nn.Linear(in_features // 2, in_features // 4),
        torch.nn.ReLU(),
        torch.nn.Linear(in_features // 4, out_features),
    )