# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""
The MNIST_MLP architecture is borrowed from OoD-Bench:
@inproceedings{ye2022ood,
title={OoD-Bench: Quantifying and Understanding Two Dimensions of Out-of-Distribution Generalization},
author={Ye, Nanyang and Li, Kaican and Bai, Haoyue and Yu, Runpeng and Hong, Lanqing and Zhou, Fengwei and Li, Zhenguo and Zhu, Jun},
booktitle={CVPR},
year={2022}
}
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models
class Identity(nn.Module):
    """A pass-through layer: ``forward(x)`` returns ``x`` unchanged.

    Useful for replacing an existing layer (e.g. a ResNet's ``fc`` head)
    without disturbing the surrounding module structure.
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        # Deliberately a no-op.
        return x
class MLP(nn.Module):
    """A plain multi-layer perceptron.

    Every layer except the last is followed by dropout then ReLU.

    Args:
        n_inputs: size of each input sample.
        n_outputs: size of each output sample.
        mlp_width: width of every hidden layer.
        mlp_depth: total number of linear layers (input + hiddens + output).
        mlp_dropout: dropout probability applied after each non-final layer.
    """

    def __init__(self, n_inputs, n_outputs, mlp_width, mlp_depth, mlp_dropout):
        super().__init__()
        self.input = nn.Linear(n_inputs, mlp_width)
        self.dropout = nn.Dropout(mlp_dropout)
        # mlp_depth counts the input and output layers, hence depth - 2 hiddens.
        self.hiddens = nn.ModuleList(
            [nn.Linear(mlp_width, mlp_width) for _ in range(mlp_depth - 2)]
        )
        self.output = nn.Linear(mlp_width, n_outputs)
        self.n_outputs = n_outputs

    def forward(self, x):
        """Map a batch of inputs through the MLP."""
        x = F.relu(self.dropout(self.input(x)))
        for layer in self.hiddens:
            x = F.relu(self.dropout(layer(x)))
        return self.output(x)
class MNIST_MLP(nn.Module):
    """Two-layer fully-connected encoder for MNIST-style image inputs.

    Flattens each image and maps it through two ReLU-activated linear
    layers to an ``hdim``-dimensional feature vector. Linear weights use
    Xavier-uniform initialization with a ReLU gain; biases start at zero.

    Args:
        input_shape: (channels, height, width) of the input images.
        hdim: hidden/output feature dimension. Defaults to 390, matching
            the OoD-Bench reference architecture, so existing callers are
            unaffected.
    """

    def __init__(self, input_shape, hdim=390):
        super(MNIST_MLP, self).__init__()
        self.hdim = hdim
        self.encoder = nn.Sequential(
            nn.Linear(input_shape[0] * input_shape[1] * input_shape[2], hdim),
            nn.ReLU(True),
            nn.Linear(hdim, hdim),
            nn.ReLU(True),
        )
        self.n_outputs = hdim
        # Xavier-uniform init scaled for ReLU, with zero biases.
        for m in self.encoder:
            if isinstance(m, nn.Linear):
                gain = nn.init.calculate_gain("relu")
                nn.init.xavier_uniform_(m.weight, gain=gain)
                nn.init.zeros_(m.bias)

    def forward(self, x):
        """Flatten ``x`` to (batch, -1) and encode it to (batch, hdim)."""
        x = x.view(x.size(0), -1)
        return self.encoder(x)
class ResNet(torch.nn.Module):
    """ResNet feature extractor with the classification head removed and
    every BatchNorm layer kept frozen (permanently in eval mode)."""

    def __init__(self, input_shape, resnet18=True, resnet_dropout=0.0):
        super(ResNet, self).__init__()
        if resnet18:
            self.network = torchvision.models.resnet18(pretrained=True)
            self.n_outputs = 512
        else:
            self.network = torchvision.models.resnet50(pretrained=True)
            self.n_outputs = 2048

        # Adapt the first conv when the input is not 3-channel: rebuild
        # conv1 and tile the pretrained RGB filters across the new channels.
        nc = input_shape[0]
        if nc != 3:
            pretrained_w = self.network.conv1.weight.data.clone()
            self.network.conv1 = nn.Conv2d(
                nc, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False
            )
            for i in range(nc):
                self.network.conv1.weight.data[:, i, :, :] = pretrained_w[:, i % 3, :, :]

        # Drop the classification head (save memory) and replace it with an
        # identity so the network emits raw features.
        del self.network.fc
        self.network.fc = Identity()
        self.freeze_bn()
        self.dropout = nn.Dropout(resnet_dropout)

    def forward(self, x):
        """Encode x into a feature vector of size n_outputs (dropout applied)."""
        return self.dropout(self.network(x))

    def train(self, mode=True):
        """Switch train/eval mode while keeping BatchNorm layers frozen."""
        super().train(mode)
        self.freeze_bn()

    def freeze_bn(self):
        """Put every BatchNorm2d submodule into eval mode."""
        for module in self.network.modules():
            if isinstance(module, nn.BatchNorm2d):
                module.eval()
class MNIST_CNN(nn.Module):
    """
    Hand-tuned CNN encoder for MNIST.

    Four 3x3 conv layers (the second one strided by 2), each followed by
    ReLU and GroupNorm, then a global average pool to a 128-dim vector.

    Weirdness noticed by the original author: adding a linear layer after
    the mean-pool in features hurts RotatedMNIST-100 generalization
    severely.
    """

    # Feature dimension exposed to downstream classifiers.
    n_outputs = 128

    def __init__(self, input_shape):
        super(MNIST_CNN, self).__init__()
        self.conv1 = nn.Conv2d(input_shape[0], 64, 3, 1, padding=1)
        self.conv2 = nn.Conv2d(64, 128, 3, stride=2, padding=1)
        self.conv3 = nn.Conv2d(128, 128, 3, 1, padding=1)
        self.conv4 = nn.Conv2d(128, 128, 3, 1, padding=1)
        self.bn0 = nn.GroupNorm(8, 64)
        self.bn1 = nn.GroupNorm(8, 128)
        self.bn2 = nn.GroupNorm(8, 128)
        self.bn3 = nn.GroupNorm(8, 128)
        self.avgpool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, x):
        """Encode a (B, C, H, W) batch into (B, 128) features."""
        stages = (
            (self.conv1, self.bn0),
            (self.conv2, self.bn1),
            (self.conv3, self.bn2),
            (self.conv4, self.bn3),
        )
        # Each stage: conv -> ReLU -> GroupNorm (order matches original).
        for conv, norm in stages:
            x = norm(F.relu(conv(x)))
        x = self.avgpool(x)
        return x.view(len(x), -1)
class ContextNet(nn.Module):
    """Small conv net producing a single-channel context map with the same
    spatial size as its input (5x5 convs with 'same' padding)."""

    def __init__(self, input_shape):
        super(ContextNet, self).__init__()
        kernel = 5
        same_pad = (kernel - 1) // 2  # keeps H and W unchanged
        self.context_net = nn.Sequential(
            nn.Conv2d(input_shape[0], 64, kernel, padding=same_pad),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 64, kernel, padding=same_pad),
            nn.BatchNorm2d(64),
            nn.ReLU(),
            nn.Conv2d(64, 1, kernel, padding=same_pad),
        )

    def forward(self, x):
        """Return a (B, 1, H, W) context map for input (B, C, H, W)."""
        return self.context_net(x)
def Classifier(in_features, out_features, is_nonlinear=False):
    """Build a classification head.

    Returns a single ``Linear`` layer by default. When ``is_nonlinear`` is
    True, returns a three-layer MLP that halves, then quarters, the feature
    width (with ReLU between layers) before projecting to ``out_features``.
    """
    if not is_nonlinear:
        return torch.nn.Linear(in_features, out_features)

    # Widths of the successive hidden representations.
    widths = [in_features, in_features // 2, in_features // 4]
    layers = []
    for w_in, w_out in zip(widths[:-1], widths[1:]):
        layers.append(torch.nn.Linear(w_in, w_out))
        layers.append(torch.nn.ReLU())
    layers.append(torch.nn.Linear(widths[-1], out_features))
    return torch.nn.Sequential(*layers)