Demo for DoWhy Causal Prediction on MNIST#

The goal of this notebook is to demonstrate an example of causal prediction using Causally Adaptive Constraint Minimization (CACM) (https://arxiv.org/abs/2206.07837) [1].

Multi-attribute distribution shift datasets#

Domain generalization literature has largely focused on datasets with a single kind of distribution shift over one attribute. Using MNIST as an example, domains are created either by adding new values of a spurious attribute like rotation (e.g., the Rotated-MNIST dataset [2]) or by varying the correlation between the class label and a spurious attribute like color (e.g., Colored-MNIST [3]). However, real-world data often has multiple distribution shifts over different attributes. For example, satellite imagery data demonstrates distribution shifts over time as well as over the region captured.

Multi-attribute MNIST#

We create multi-attribute shift variants of MNIST, where both the color and rotation angle of digits can shift across data distributions. Specifically, we create three variants of MNIST: MNISTCausalAttribute (single-attribute shift), MNISTIndAttribute (single-attribute shift), and MNISTCausalIndAttribute (multi-attribute shift). To describe the Causal, Ind, and CausalInd datasets better, consider the causal graph for the data-generating process below:

[Figure: main_fig_mnist.drawio.png — causal graph for the data-generating process]

Distribution shifts are characterized by the relationship between a spurious attribute A and the classification label Y:

1. Causal: the attribute has a direct causal relationship with the class label, i.e., Y causes the attribute (e.g., Color here).
2. Ind: the attribute is Independent of the class label (e.g., Rotation here).
3. CausalInd: attributes with Causal and Independent relationships with Y co-exist in the data.

Domains in multi-attribute MNIST#

We describe the domains for our multi-attribute shift dataset MNISTCausalIndAttribute. Each domain E_i has a specific rotation angle r_i and a specific correlation corr_i between color C and label Y. Our setup consists of 3 domains: E1 and E2 are training domains, and E3 is the test domain. We define corr_i = P(Y = 1 | C = 1) = P(Y = 0 | C = 0) in E_i. In our setup, r1 = 15°, r2 = 60°, r3 = 90° and corr1 = 0.9, corr2 = 0.8, corr3 = 0.1. All environments have 25% label noise, as in [3].
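For intuition, the snippet below sketches how a single domain could be generated under this setup: binarize the digit label, flip it with 25% probability, sample the color to agree with the (noisy) label with probability corr_i, and rotate every image by r_i. The make_domain helper is hypothetical and only illustrative; the actual construction lives in dowhy.causal_prediction.datasets.mnist.

import torch
from torchvision.transforms.functional import rotate

def make_domain(images, digits, angle, corr, label_noise=0.25):
    # Hypothetical sketch of one domain; not the dowhy implementation.
    # Binarize digits into labels (0-4 -> 0, 5-9 -> 1), as in Colored-MNIST [3]
    y = (digits >= 5).long()
    # Flip each label with probability `label_noise` (25% here)
    flip = torch.bernoulli(torch.full(y.shape, label_noise)).long()
    y = y ^ flip
    # Color agrees with the (noisy) label with probability `corr`
    agree = torch.bernoulli(torch.full(y.shape, corr)).bool()
    color = torch.where(agree, y, 1 - y)
    # All images in the domain share the same rotation angle r_i
    images = rotate(images, angle)
    return images, y, color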

Other dataset-related details can be found in dowhy.causal_prediction.datasets.

[Figure: MNIST_visualize.drawio.png — visualization of the domains in multi-attribute MNIST]

[1]:
import torch
import pytorch_lightning as pl

Initialize dataset#

[2]:
from dowhy.causal_prediction.datasets.mnist import MNISTCausalAttribute

# dataset class initialization requires mandatory param `data_dir`
# `download` is passed to torchvision.datasets.MNIST and downloads data if not present
data_dir = 'data'
dataset = MNISTCausalAttribute(data_dir, download=True)
Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-images-idx3-ubyte.gz to data/MNIST/raw/train-images-idx3-ubyte.gz
Extracting data/MNIST/raw/train-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/train-labels-idx1-ubyte.gz to data/MNIST/raw/train-labels-idx1-ubyte.gz
Extracting data/MNIST/raw/train-labels-idx1-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-images-idx3-ubyte.gz to data/MNIST/raw/t10k-images-idx3-ubyte.gz
Extracting data/MNIST/raw/t10k-images-idx3-ubyte.gz to data/MNIST/raw

Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz
Failed to download (trying next):
HTTP Error 403: Forbidden

Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz
Downloading https://ossci-datasets.s3.amazonaws.com/mnist/t10k-labels-idx1-ubyte.gz to data/MNIST/raw/t10k-labels-idx1-ubyte.gz
Extracting data/MNIST/raw/t10k-labels-idx1-ubyte.gz to data/MNIST/raw

Initialize data loaders#

get_loaders returns data loaders for training, validation, and testing. The returned loaders object is a dictionary with keys train_loaders, val_loaders, and test_loaders. Two scenarios are currently supported for initializing validation domains:

Method 1: one or more domains from the dataset are explicitly specified as the validation domain(s).
Method 2: when no specific validation domain is present, a subset of the training domain(s) is used to create the validation set.

Run either the cell under Method 1 or the cell under Method 2 below, as required.

[3]:
from dowhy.causal_prediction.dataloaders.get_data_loader import get_loaders

Method 1: Provide validation domain explicitly#

Provide the indices of the validation domains as val_envs. test_envs is an optional parameter.

[4]:
loaders = get_loaders(dataset, train_envs=[0, 1], batch_size=64,
            val_envs=[2], test_envs=[3])
/github/home/.cache/pypoetry/virtualenvs/dowhy-oN2hW5jr-py3.8/lib/python3.8/site-packages/torch/utils/data/dataloader.py:554: UserWarning: This DataLoader will create 8 worker processes in total. Our suggested max number of worker in current system is 4, which is smaller than what this DataLoader is going to create. Please be aware that excessive worker creation might get DataLoader running slow or even freeze, lower the worker number to avoid potential slowness/freeze if necessary.
  warnings.warn(_create_warning_msg(

Method 2: Validation set using subset of training data#

val_envs, test_envs are optional parameters. If val_envs is not provided, a subset of training data is used for creating the validation set. The fraction of training data used is determined by holdout_fraction.

[5]:
loaders = get_loaders(dataset, train_envs=[0, 1], batch_size=64,
            holdout_fraction=0.2, test_envs=[3])
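
Whichever method you used, a quick sanity check of the returned dictionary can be helpful; the snippet below assumes only the three keys described earlier.

# Inspect the dictionary returned by `get_loaders` (illustrative)
for split in ('train_loaders', 'val_loaders', 'test_loaders'):
    if split in loaders:
        print(split, type(loaders[split]))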

The code below handles more than one validation or test domain, if present. Run the cell below regardless of whether Method 1 or Method 2 was used above.

[6]:
# handle multiple validation and test domains if present
from pytorch_lightning.trainer.supporters import CombinedLoader

if len(loaders['val_loaders']) > 1:
    val_loaders = loaders['val_loaders']
    loaders['val_loaders'] = CombinedLoader(val_loaders)

if len(loaders['test_loaders']) > 1:
    test_loaders = loaders['test_loaders']
    loaders['test_loaders'] = CombinedLoader(test_loaders)

Initialize model and algorithm#

[7]:
from dowhy.causal_prediction.models.networks import MNIST_MLP, Classifier

The model below is expected to be of type torch.nn.Sequential with two torch.nn.Module elements (a feature extractor followed by a classifier). We provide sample networks (MLP, ResNet) in dowhy.causal_prediction.models.networks, but the user can flexibly use any model.

[8]:
featurizer = MNIST_MLP(dataset.input_shape)
classifier = Classifier(
    featurizer.n_outputs,
    dataset.num_classes)

model = torch.nn.Sequential(featurizer, classifier)
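
Since any model can be used, a custom featurizer can be swapped in. The SmallCNN below is a hypothetical example; the only contract assumed here, based on the cell above, is that the featurizer exposes an n_outputs attribute for the classifier head.

import torch.nn as nn

class SmallCNN(nn.Module):
    # Hypothetical featurizer; assumes only that `n_outputs` is exposed,
    # mirroring how MNIST_MLP is used with Classifier above.
    def __init__(self, input_shape, n_outputs=128):
        super().__init__()
        self.net = nn.Sequential(
            nn.Conv2d(input_shape[0], 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool2d(4),
            nn.Flatten(),
            nn.Linear(32 * 4 * 4, n_outputs),
            nn.ReLU(),
        )
        self.n_outputs = n_outputs

    def forward(self, x):
        return self.net(x)

# Drop-in replacement for the cell above:
# featurizer = SmallCNN(dataset.input_shape)
# model = torch.nn.Sequential(featurizer, Classifier(featurizer.n_outputs, dataset.num_classes))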

Initialize algorithm class: ERM#

We have implemented Empirical Risk Minimization (ERM) in dowhy.causal_prediction.algorithms as a baseline.

[9]:
from dowhy.causal_prediction.algorithms.erm import ERM
[10]:
algorithm = ERM(model, lr=1e-3)

Fit predictor and start training#

Note: The optimal accuracy for MNISTCausalAttribute (and the other MNIST variants introduced) is 75%: following previous work, we flip 25% of the labels, so even a perfect classifier can match at most 75% of the (noisy) test labels.

[11]:
trainer = pl.Trainer(devices=1, max_epochs=5)

# val_loaders is optional param
trainer.fit(algorithm, loaders['train_loaders'], loaders['val_loaders'])
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/github/home/.cache/pypoetry/virtualenvs/dowhy-oN2hW5jr-py3.8/lib/python3.8/site-packages/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py:67: UserWarning: Starting from v1.9.0, `tensorboardX` has been removed as a dependency of the `pytorch_lightning` package, due to potential conflicts with other packages in the ML ecosystem. For this reason, `logger=True` will use `CSVLogger` as the default logger, unless the `tensorboard` or `tensorboardX` packages are found. Please `pip install lightning[extra]` or one of them to enable TensorBoard support by default
  warning_cache.warn(
Missing logger folder: /__w/dowhy/dowhy/docs/source/example_notebooks/prediction/lightning_logs

  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 306 K
-------------------------------------
306 K     Trainable params
0         Non-trainable params
306 K     Total params
1.226     Total estimated model params size (MB)
`Trainer.fit` stopped: `max_epochs=5` reached.

Evaluate on test domain#

Perform an evaluation epoch over the test set using trainer.test. ckpt_path determines the model to be used for evaluation – ‘best’, ‘last’, or a path to a specific checkpoint. If ckpt_path is not passed, the best model checkpoint from the previous trainer.fit call is loaded (https://pytorch-lightning.readthedocs.io/en/stable/_modules/pytorch_lightning/trainer/trainer.html#Trainer.test).

We report accuracy (test_acc) and cross-entropy loss (test_loss) on the test domains/test set.

[12]:
if 'test_loaders' in loaders:
    trainer.test(dataloaders=loaders['test_loaders'], ckpt_path='best')
Restoring states from the checkpoint path at /__w/dowhy/dowhy/docs/source/example_notebooks/prediction/lightning_logs/version_0/checkpoints/epoch=4-step=1560.ckpt
Loaded model weights from checkpoint at /__w/dowhy/dowhy/docs/source/example_notebooks/prediction/lightning_logs/version_0/checkpoints/epoch=4-step=1560.ckpt
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc            0.19349999725818634
        test_loss           1.6614809036254883
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
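
If you want to verify test_acc independently of the Trainer, a manual evaluation loop is straightforward. The sketch below assumes each batch unpacks as (x, y, ...) and that test_loaders is a list of data loaders; both assumptions should be checked against the dataset class.

import torch

@torch.no_grad()
def manual_accuracy(model, loader):
    # Hypothetical sanity check for `test_acc`; assumes (x, y, ...) batches
    model.eval()
    correct, total = 0, 0
    for batch in loader:
        x, y = batch[0], batch[1]
        correct += (model(x).argmax(dim=1) == y).sum().item()
        total += y.shape[0]
    return correct / total

# e.g., manual_accuracy(model, loaders['test_loaders'][0])  # if test_loaders is a list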

Prediction with CACM#

We now train and evaluate on the above dataset with CACM. We specify the types of shifts present using the list attr_types, provided as input to CACM. Further instructions on using CACM with multi-attribute shifts are provided in the next section.

[13]:
from dowhy.causal_prediction.algorithms.cacm import CACM
[14]:
# `attr_types` list contains the types of attributes present (currently supports 'causal', 'conf', 'ind', and 'sel')
algorithm = CACM(model, lr=1e-3, gamma=1e-2, attr_types=['causal'], lambda_causal=100.)
[15]:
trainer = pl.Trainer(devices=1, max_epochs=5)

trainer.fit(algorithm, loaders['train_loaders'], loaders['val_loaders'])
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs

  | Name  | Type       | Params
-------------------------------------
0 | model | Sequential | 306 K
-------------------------------------
306 K     Trainable params
0         Non-trainable params
306 K     Total params
1.226     Total estimated model params size (MB)
`Trainer.fit` stopped: `max_epochs=5` reached.
[16]:
if 'test_loaders' in loaders:
    trainer.test(dataloaders=loaders['test_loaders'], ckpt_path='best')
Restoring states from the checkpoint path at /__w/dowhy/dowhy/docs/source/example_notebooks/prediction/lightning_logs/version_1/checkpoints/epoch=4-step=1560.ckpt
Loaded model weights from checkpoint at /__w/dowhy/dowhy/docs/source/example_notebooks/prediction/lightning_logs/version_1/checkpoints/epoch=4-step=1560.ckpt
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
       Test metric             DataLoader 0
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────
        test_acc             0.614799976348877
        test_loss           0.6902953386306763
────────────────────────────────────────────────────────────────────────────────────────────────────────────────────────

Extending to different datasets and algorithms#

MNIST Independent and Causal+Independent datasets#

We show how to perform the above evaluation for the MNISTIndAttribute and MNISTCausalIndAttribute datasets. Additional attr_types should be provided to the CACM algorithm for handling multiple shifts. We currently support Causal, Confounded, Independent, and Selected distribution shifts in the data.

MNISTIndAttribute: Single-attribute Independent shift#

[17]:
from dowhy.causal_prediction.datasets.mnist import MNISTIndAttribute

data_dir = 'data'
dataset = MNISTIndAttribute(data_dir)
[18]:
algorithm = CACM(model, lr=1e-3, gamma=1e-2, attr_types=['ind'], lambda_ind=10., E_eq_A=[0])

MNISTCausalIndAttribute: Multi-attribute Causal+Independent shift#

[19]:
from dowhy.causal_prediction.datasets.mnist import MNISTCausalIndAttribute

data_dir = 'data'
dataset = MNISTCausalIndAttribute(data_dir)
[20]:
# `attr_types` should be ordered consistently with the attribute order in the dataset class
algorithm = CACM(model, lr=1e-3, gamma=1e-2, attr_types=['causal', 'ind'], lambda_causal=100., lambda_ind=10., E_eq_A=[1])

Additional datasets and algorithms#

We provide our demo on MNIST using the ERM and CACM algorithms. It is possible to extend the evaluation to new datasets and algorithms.

New datasets can be added to dowhy.causal_prediction.datasets and imported here, as we did for MNIST. We provide a description of the MNIST dataset (and its variants) in dowhy.causal_prediction.datasets.mnist that will be helpful in creating new dataset classes. We currently support Causal, Confounded, Independent, and Selected distribution shifts in the data.
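
Based on how dataset is used in this notebook (input_shape, num_classes, and indexing by environment in get_loaders), a new dataset class could look roughly like the hypothetical skeleton below; check dowhy.causal_prediction.datasets.mnist for the actual required interface.

import torch
from torch.utils.data import TensorDataset

class MyMultiAttributeDataset:
    # Hypothetical skeleton inferred from this notebook's usage;
    # the real contract is defined in dowhy.causal_prediction.datasets
    input_shape = (2, 28, 28)   # assumption: channels-first image shape
    num_classes = 2

    def __init__(self, num_envs=3, n=100):
        # one per-environment dataset of (image, label, attribute) tuples (layout assumed)
        self.datasets = [
            TensorDataset(
                torch.randn(n, *self.input_shape),
                torch.randint(0, self.num_classes, (n,)),
                torch.randint(0, 2, (n,)),
            )
            for _ in range(num_envs)
        ]

    def __getitem__(self, env_idx):
        return self.datasets[env_idx]

    def __len__(self):
        return len(self.datasets)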

We have implemented ERM in dowhy.causal_prediction.algorithms as a baseline. Additional algorithms can be added by overriding the training_step function of the base class PredictionAlgorithm.
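
As a sketch of this extension point, the hypothetical subclass below overrides training_step. The base-class import path and the per-environment batch layout are assumptions; verify both against your installed dowhy version.

import torch
import torch.nn.functional as F
from dowhy.causal_prediction.algorithms.base_algorithm import PredictionAlgorithm  # path assumed

class AveragedERM(PredictionAlgorithm):
    # Hypothetical algorithm: average the cross-entropy loss across
    # training environments (batch layout assumed: one (x, y, ...) per env)
    def training_step(self, train_batch, batch_idx):
        losses = []
        for batch in train_batch:
            x, y = batch[0], batch[1]
            losses.append(F.cross_entropy(self.model(x), y))
        loss = torch.stack(losses).mean()
        self.log('train_loss', loss, on_epoch=True)
        return loss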

References#

[1] Kaur, J.N., Kıcıman, E., & Sharma, A. (2022). Modeling the Data-Generating Process is Necessary for Out-of-Distribution Generalization. ArXiv, abs/2206.07837.

[2] Ghifary, M., Kleijn, W., Zhang, M., & Balduzzi, D. (2015). Domain Generalization for Object Recognition with Multi-task Autoencoders. 2015 IEEE International Conference on Computer Vision (ICCV), 2551-2559.

[3] Arjovsky, M., Bottou, L., Gulrajani, I., & Lopez-Paz, D. (2019). Invariant Risk Minimization. ArXiv, abs/1907.02893.