dowhy.causal_prediction.datasets package
Submodules
dowhy.causal_prediction.datasets.base_dataset module
- The MultipleDomainDataset class in this file is borrowed from DomainBed (facebookresearch/DomainBed):
- @inproceedings{gulrajani2021in,
  title={In Search of Lost Domain Generalization},
  author={Ishaan Gulrajani and David Lopez-Paz},
  booktitle={International Conference on Learning Representations},
  year={2021},
  }
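The container pattern borrowed from DomainBed can be sketched as follows. This is a minimal, standard-library-only illustration rather than the library's code: `DomainCollection` and `ToyDomains` are hypothetical names, and the sketch assumes that a multiple-domain dataset holds one dataset per environment and exposes them by index.

```python
class DomainCollection:
    """Hypothetical stand-in for a multiple-domain dataset container."""

    ENVIRONMENTS = None  # subclasses list one name per environment

    def __init__(self):
        self.datasets = []  # one dataset per environment, in order

    def __getitem__(self, index):
        # Return the dataset that belongs to environment `index`.
        return self.datasets[index]

    def __len__(self):
        # Number of environments, not number of samples.
        return len(self.datasets)


class ToyDomains(DomainCollection):
    """Two toy environments, each a list of (x, y) pairs."""

    ENVIRONMENTS = ["env_a", "env_b"]

    def __init__(self):
        super().__init__()
        self.datasets.append([(0.1, 0), (0.9, 1)])
        self.datasets.append([(0.2, 1), (0.8, 0), (0.5, 1)])


domains = ToyDomains()
sizes = [len(domains[i]) for i in range(len(domains))]
```

Under this assumption, iterating over the container yields per-environment datasets, which is what lets a training loop treat each environment separately.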
dowhy.causal_prediction.datasets.mnist module
- class dowhy.causal_prediction.datasets.mnist.MNISTCausalAttribute(root, download=True)
Bases: MultipleDomainDataset
Class for MNISTCausalAttribute dataset.
- Parameters:
root – The directory where data can be found (or should be downloaded to, if it does not exist).
download – Binary flag indicating whether data should be downloaded
- Returns:
an instance of MultipleDomainDataset class
- CHECKPOINT_FREQ = 500
- ENVIRONMENTS = ['+90%', '+80%', '-90%', '-90%']
- INPUT_SHAPE = (2, 14, 14)
- N_STEPS = 5001
- color_dataset(images, labels, environment)
Transform the MNIST dataset to introduce a correlation between the attribute (color) and the label. There is a direct causal relationship between label Y and color.
- Parameters:
images – original MNIST images
labels – original MNIST labels
environment – Value of correlation between color and label
- Returns:
TensorDataset containing transformed images, labels, and attributes (color)
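The correlation between color and label described above can be illustrated with a small sketch. It uses plain Python lists instead of tensors so that it runs without PyTorch or the MNIST download. The function names, the `agreement` parameter (the probability that the color bit equals the label), the use of binary labels, and the convention that the color selects which of two channels keeps the pixels are all assumptions made for illustration, not the library's implementation.

```python
import random


def sample_colors(labels, agreement, rng):
    """Draw one color bit per binary label.

    The color equals the label with probability `agreement`
    and is flipped otherwise.
    """
    return [y if rng.random() < agreement else 1 - y for y in labels]


def colorize(image, color):
    """Turn a grayscale image (list of rows) into a 2-channel image.

    Only the channel selected by `color` keeps the pixel values;
    the other channel is all zeros.
    """
    kept = [row[:] for row in image]
    blank = [[0] * len(row) for row in image]
    return [kept, blank] if color == 0 else [blank, kept]


rng = random.Random(0)
labels = [rng.randint(0, 1) for _ in range(10_000)]
colors = sample_colors(labels, agreement=0.9, rng=rng)

# Fraction of samples whose color bit matches the label.
match_rate = sum(c == y for c, y in zip(colors, labels)) / len(labels)

image = [[0, 5], [7, 0]]
colored = colorize(image, color=1)
```

With `agreement=0.9` the color predicts the label on roughly nine samples out of ten; an environment with a negative correlation would correspond to a low `agreement` value in this sketch.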
- class dowhy.causal_prediction.datasets.mnist.MNISTCausalIndAttribute(root, download=True)
Bases: MultipleDomainDataset
Class for MNISTCausalIndAttribute dataset.
- Parameters:
root – The directory where data can be found (or should be downloaded to, if it does not exist).
download – Binary flag indicating whether data should be downloaded
- Returns:
an instance of MultipleDomainDataset class
- CHECKPOINT_FREQ = 500
- ENVIRONMENTS = ['+90%, 15', '+80%, 16', '-90%, 90', '-90%, 90']
- INPUT_SHAPE = (2, 14, 14)
- N_STEPS = 5001
- color_dataset(images, labels, environment)
Transform the MNIST dataset to introduce a correlation between the attribute (color) and the label. There is a direct causal relationship between label Y and color.
- Parameters:
images – rotated MNIST images
labels – original MNIST labels
environment – Value of correlation between color and label
- Returns:
transformed images, labels, and attributes (color)
- color_rot_dataset(images, labels, environment, env_id, angle)
Transform the MNIST dataset by (i) applying a rotation to the images, then (ii) introducing a correlation between the attribute (color) and the label. The rotation-angle attribute is independent of label Y; there is a direct causal relationship between label Y and color.
- Parameters:
images – original MNIST images
labels – original MNIST labels
environment – Value of correlation between color and label
env_id – Identifier of the environment
angle – Value of rotation angle used for transforming the image
- Returns:
TensorDataset containing transformed images, labels, and attributes (color, angle)
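The two-step transformation can be sketched as a composition: rotate first, then color. The sketch is a simplified illustration under stated assumptions: it supports only a 90-degree counter-clockwise rotation, uses nested lists rather than tensors, and the names `rotate90_ccw`, `color_rot_example`, and `agreement` are hypothetical.

```python
import random


def rotate90_ccw(image):
    """Rotate a 2D list of rows by 90 degrees counter-clockwise."""
    return [list(col) for col in zip(*image)][::-1]


def color_rot_example(image, label, agreement, angle, rng):
    """Rotate one image, then color it in a label-dependent way.

    The angle is fixed by the caller and never looks at the label,
    so it is independent of Y; the color is drawn from the label.
    """
    if angle != 90:
        raise ValueError("this sketch only supports angle=90")
    rotated = rotate90_ccw(image)
    # Color equals the label with probability `agreement`.
    color = label if rng.random() < agreement else 1 - label
    blank = [[0] * len(row) for row in rotated]
    channels = [rotated, blank] if color == 0 else [blank, rotated]
    return channels, label, (color, angle)


rng = random.Random(0)
x, y, attrs = color_rot_example(
    [[1, 2], [3, 4]], label=1, agreement=1.0, angle=90, rng=rng
)
```

The returned tuple mirrors the structure described above: the transformed image, the label, and a pair of attributes (color, angle).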
- class dowhy.causal_prediction.datasets.mnist.MNISTIndAttribute(root, download=True)
Bases: MultipleDomainDataset
Class for MNISTIndAttribute dataset.
- Parameters:
root – The directory where data can be found (or should be downloaded to, if it does not exist).
download – Binary flag indicating whether data should be downloaded
- Returns:
an instance of MultipleDomainDataset class
- CHECKPOINT_FREQ = 500
- ENVIRONMENTS = ['15', '60', '90', '90']
- INPUT_SHAPE = (1, 14, 14)
- N_STEPS = 5001
- rotate_dataset(images, labels, env_id, angle)
Transform MNIST dataset by applying rotation to images. Attribute (rotation angle) is independent of label Y.
- Parameters:
images – original MNIST images
labels – original MNIST labels
env_id – Identifier of the environment
angle – Value of rotation angle used for transforming the image
- Returns:
TensorDataset containing transformed images, labels, and attributes (angle)
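A rotation of this kind can be sketched with the standard library alone. The code below is an illustrative assumption, not the library's implementation: `rotate_image` and `rotate_examples` are hypothetical names, the rotation is counter-clockwise about the image center with nearest-neighbor sampling, and images are nested lists rather than tensors.

```python
import math


def rotate_image(image, angle_deg):
    """Rotate a 2D list counter-clockwise about its center.

    Uses nearest-neighbor sampling; pixels that map outside the
    source image are filled with 0.
    """
    h, w = len(image), len(image[0])
    cy, cx = (h - 1) / 2, (w - 1) / 2
    cos_t = math.cos(math.radians(angle_deg))
    sin_t = math.sin(math.radians(angle_deg))
    out = [[0] * w for _ in range(h)]
    for y in range(h):
        for x in range(w):
            dx, dy = x - cx, y - cy
            # Inverse mapping: find the source pixel for output (x, y).
            sx = round(cos_t * dx - sin_t * dy + cx)
            sy = round(sin_t * dx + cos_t * dy + cy)
            if 0 <= sx < w and 0 <= sy < h:
                out[y][x] = image[sy][sx]
    return out


def rotate_examples(images, labels, angle_deg):
    """Return (rotated image, label, angle) triples.

    Every sample gets the same angle, so the angle attribute
    carries no information about the label.
    """
    return [
        (rotate_image(img, angle_deg), y, angle_deg)
        for img, y in zip(images, labels)
    ]


image = [[0, 0, 0],
         [0, 0, 1],
         [0, 0, 0]]
samples = rotate_examples([image], [7], angle_deg=90)
```

Because the angle is chosen per environment and applied to every sample regardless of its label, the angle attribute is independent of Y by construction in this sketch.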