dowhy.causal_prediction.datasets package

Submodules

dowhy.causal_prediction.datasets.base_dataset module

The MultipleDomainDataset class in this file is borrowed from DomainBed (facebookresearch/DomainBed):

@inproceedings{gulrajani2021in,
  title={In Search of Lost Domain Generalization},
  author={Ishaan Gulrajani and David Lopez-Paz},
  booktitle={International Conference on Learning Representations},
  year={2021},
}

class dowhy.causal_prediction.datasets.base_dataset.MultipleDomainDataset

Bases: object

CHECKPOINT_FREQ = 100
ENVIRONMENTS = None
INPUT_SHAPE = None
N_STEPS = 5001
N_WORKERS = 8
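A minimal sketch of how these class-level defaults are meant to be used, assuming (as in DomainBed) that a subclass overrides them and stores one dataset per environment in a `datasets` list that the container indexes by environment. `MultipleDomainDatasetSketch` and `ToyTwoDomainDataset` are hypothetical names, and plain Python lists stand in for PyTorch datasets.

```python
class MultipleDomainDatasetSketch:
    """Hypothetical stand-in mirroring the documented class attributes."""

    N_STEPS = 5001          # default number of training steps
    CHECKPOINT_FREQ = 100   # default checkpoint interval, in steps
    N_WORKERS = 8           # default number of data-loading workers
    ENVIRONMENTS = None     # subclasses override with environment names
    INPUT_SHAPE = None      # subclasses override with the per-sample shape

    def __init__(self):
        # Assumed convention: one dataset per environment.
        self.datasets = []

    def __getitem__(self, index):
        # Return the dataset belonging to environment `index`.
        return self.datasets[index]

    def __len__(self):
        # Number of environments, not number of samples.
        return len(self.datasets)


class ToyTwoDomainDataset(MultipleDomainDatasetSketch):
    """Hypothetical subclass that overrides some of the defaults."""

    CHECKPOINT_FREQ = 500
    ENVIRONMENTS = ["env_a", "env_b"]
    INPUT_SHAPE = (1, 2, 2)

    def __init__(self):
        super().__init__()
        for env_id in range(len(self.ENVIRONMENTS)):
            # One toy (x, y) sample per environment; x has shape (1, 2, 2).
            x = [[[float(env_id) for _ in range(2)] for _ in range(2)]]
            self.datasets.append([(x, env_id)])


ds = ToyTwoDomainDataset()
print(len(ds))             # 2 environments
print(ds.N_STEPS)          # 5001, inherited default
print(ds.CHECKPOINT_FREQ)  # 500, overridden by the subclass
```

Keeping the defaults as class attributes lets a training loop read settings such as `N_STEPS` from the dataset class without instantiating it, while each subclass overrides only what differs.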

dowhy.causal_prediction.datasets.mnist module

class dowhy.causal_prediction.datasets.mnist.MNISTCausalAttribute(root, download=True)

Bases: MultipleDomainDataset

Class for MNISTCausalAttribute dataset.

Parameters:
  • root – The directory where data can be found (or should be downloaded to, if it does not exist).

  • download – Boolean flag indicating whether the data should be downloaded (default: True)

Returns:

an instance of the MultipleDomainDataset class

CHECKPOINT_FREQ = 500
ENVIRONMENTS = ['+90%', '+80%', '-90%', '-90%']
INPUT_SHAPE = (2, 14, 14)
N_STEPS = 5001
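The (2, 14, 14) input shape is consistent with downsampling each 28x28 MNIST digit by a factor of two along each axis and then stacking two color channels, one of which is zeroed out. The sketch below assumes that construction, which is how DomainBed's ColoredMNIST builds its inputs; the helpers `subsample`, `to_two_channel`, and `shape` are hypothetical, and nested lists stand in for tensors.

```python
def subsample(image, step=2):
    """Keep every `step`-th row and column (28x28 -> 14x14 for step=2)."""
    return [row[::step] for row in image[::step]]


def to_two_channel(image, color):
    """Stack the image into two channels and zero out the channel not
    selected by `color` (0 or 1), giving shape (2, H, W)."""
    zeros = [[0 for _ in row] for row in image]
    copy = [list(row) for row in image]
    return [copy, zeros] if color == 0 else [zeros, copy]


def shape(nested):
    """Shape of a non-empty rectangular nested list."""
    dims = []
    while isinstance(nested, list):
        dims.append(len(nested))
        nested = nested[0]
    return tuple(dims)


# Synthetic 28x28 "digit" with pixel values in [0, 255].
digit = [[(r * 28 + c) % 256 for c in range(28)] for r in range(28)]
colored = to_two_channel(subsample(digit), color=1)
print(shape(colored))  # (2, 14, 14)
```

The single-channel MNISTIndAttribute variant, with INPUT_SHAPE = (1, 14, 14), would correspond to the same subsampling without the channel stacking.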
color_dataset(images, labels, environment)

Transform the MNIST dataset to introduce a correlation between an attribute (color) and the label. There is a direct causal relationship between label Y and color.

Parameters:
  • images – original MNIST images

  • labels – original MNIST labels

  • environment – Value of correlation between color and label

Returns:

TensorDataset containing transformed images, labels, and attributes (color)
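A sketch of the label and color assignment, assuming it follows DomainBed's ColoredMNIST: each digit is binarized (digit < 5), the binary label is flipped with probability 0.25, and the color is the label flipped with an environment-specific probability. How the `environment` argument maps to that probability is an assumption; the sketch takes the probability directly as the hypothetical parameter `color_flip_prob`. Under this reading, '+90%' would correspond to a flip probability of 0.1 and '-90%' to 0.9. Python's `random` module stands in for tensor sampling.

```python
import random


def color_labels(digits, color_flip_prob, label_flip_prob=0.25, seed=0):
    """Return (labels, colors) as lists of 0/1 ints.

    The color is computed from the (noisy) label, so Y -> color is a direct
    causal edge; a small color_flip_prob yields a strong positive correlation.
    """
    rng = random.Random(seed)
    labels, colors = [], []
    for d in digits:
        y = int(d < 5)                               # binarize the digit
        y ^= int(rng.random() < label_flip_prob)     # label noise
        c = y ^ int(rng.random() < color_flip_prob)  # color derived from label
        labels.append(y)
        colors.append(c)
    return labels, colors


def agreement(labels, colors):
    """Fraction of samples whose color equals their label."""
    return sum(y == c for y, c in zip(labels, colors)) / len(labels)


digits = [i % 10 for i in range(20000)]
for p in (0.1, 0.2, 0.9):
    ys, cs = color_labels(digits, p)
    print(p, round(agreement(ys, cs), 3))  # expected to be close to 1 - p
```

Because the color depends only on the label and an environment-specific coin flip, the strength and even the sign of the color-label correlation can differ across environments while the label mechanism stays fixed.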

torch_bernoulli_(p, size)
torch_xor_(a, b)
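These two helpers carry no docstrings here. Assuming they mirror the helpers in DomainBed's ColoredMNIST, `torch_bernoulli_` draws `size` samples that are 1 with probability `p` and 0 otherwise, and `torch_xor_` computes XOR on 0/1 values as an absolute difference. The dependency-free sketch below, with hypothetical names `bernoulli` and `xor`, illustrates that idea on plain lists; it is not the library's tensor implementation.

```python
import random


def bernoulli(p, size, rng=random):
    """Return `size` floats, each 1.0 with probability p and 0.0 otherwise."""
    return [1.0 if rng.random() < p else 0.0 for _ in range(size)]


def xor(a, b):
    """Elementwise XOR of 0/1 values via |a - b|, which is 1 exactly when a != b."""
    return [abs(x - y) for x, y in zip(a, b)]


# Truth table: |a - b| reproduces XOR on {0, 1}.
print(xor([0.0, 0.0, 1.0, 1.0], [0.0, 1.0, 0.0, 1.0]))  # [0.0, 1.0, 1.0, 0.0]

# Flip each label independently with probability 0.25 (a noise channel).
labels = [1.0, 0.0, 1.0, 0.0]
noisy = xor(labels, bernoulli(0.25, len(labels), rng=random.Random(42)))
```

Expressing XOR as an absolute difference keeps everything in floating-point arithmetic, which is convenient when labels and masks are float tensors.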
class dowhy.causal_prediction.datasets.mnist.MNISTCausalIndAttribute(root, download=True)

Bases: MultipleDomainDataset

Class for MNISTCausalIndAttribute dataset.

Parameters:
  • root – The directory where data can be found (or should be downloaded to, if it does not exist).

  • download – Boolean flag indicating whether the data should be downloaded (default: True)

Returns:

an instance of the MultipleDomainDataset class

CHECKPOINT_FREQ = 500
ENVIRONMENTS = ['+90%, 15', '+80%, 16', '-90%, 90', '-90%, 90']
INPUT_SHAPE = (2, 14, 14)
N_STEPS = 5001
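Each environment name here pairs a color-label correlation tag with what appears to be a rotation angle. The hypothetical helper `parse_environment` below splits a name into those two parts, assuming the '<correlation tag>, <angle in degrees>' format shown in ENVIRONMENTS.

```python
def parse_environment(name):
    """Split an environment name such as '+90%, 15' into ('+90%', 15.0).

    Assumes the '<correlation tag>, <angle in degrees>' format used in
    MNISTCausalIndAttribute.ENVIRONMENTS.
    """
    tag, angle = (part.strip() for part in name.split(","))
    return tag, float(angle)


ENVIRONMENTS = ["+90%, 15", "+80%, 16", "-90%, 90", "-90%, 90"]
for name in ENVIRONMENTS:
    print(parse_environment(name))
# ('+90%', 15.0)
# ('+80%', 16.0)
# ('-90%', 90.0)
# ('-90%', 90.0)
```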
color_dataset(images, labels, environment)

Transform the MNIST dataset to introduce a correlation between an attribute (color) and the label. There is a direct causal relationship between label Y and color.

Parameters:
  • images – rotated MNIST images

  • labels – original MNIST labels

  • environment – Value of correlation between color and label

Returns:

transformed images, labels, and attributes (color)

color_rot_dataset(images, labels, environment, env_id, angle)

Transform the MNIST dataset by (i) rotating the images, then (ii) introducing a correlation between an attribute (color) and the label. The rotation-angle attribute is independent of label Y, whereas there is a direct causal relationship between label Y and color.

Parameters:
  • images – original MNIST images

  • labels – original MNIST labels

  • environment – Value of correlation between color and label

  • angle – Value of rotation angle used for transforming the image

Returns:

TensorDataset containing transformed images, labels, and attributes (color, angle)
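The two-step transformation can be sketched as a composition: rotate every image in an environment by that environment's fixed angle, derive a color from each label, and record both attributes per sample. Because the angle is fixed per environment rather than computed from the label, it carries no information about Y within an environment. The helper names, the restriction to multiples of 90 degrees (chosen so the rotation is exact), the already-binary labels, and the use of plain lists are all assumptions for illustration; the sketch records the color attribute without building the two-channel image.

```python
import random


def rotate90(image):
    """Rotate a square nested-list image 90 degrees counterclockwise."""
    n = len(image)
    return [[image[c][n - 1 - r] for c in range(n)] for r in range(n)]


def rotate_multiple_of_90(image, angle):
    """Rotate counterclockwise by `angle` degrees (multiples of 90 only)."""
    if angle % 90 != 0:
        raise ValueError("this sketch only supports multiples of 90 degrees")
    for _ in range((angle // 90) % 4):
        image = rotate90(image)
    return image


def color_rot_environment(images, labels, angle, color_flip_prob, rotate, seed=0):
    """Return (xs, ys, attrs) for one environment.

    Step (i) rotates every image by the environment's fixed angle; step (ii)
    derives a 0/1 color from the binary label, flipping it with probability
    color_flip_prob. Each attribute row is [color, angle].
    """
    rng = random.Random(seed)
    xs, ys, attrs = [], [], []
    for image, y in zip(images, labels):
        xs.append(rotate(image, angle))                  # step (i)
        color = y ^ int(rng.random() < color_flip_prob)  # step (ii)
        ys.append(y)
        attrs.append([color, angle])                     # angle does not depend on y
    return xs, ys, attrs


images = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]
xs, ys, attrs = color_rot_environment(
    images, [1, 0], angle=90, color_flip_prob=0.0, rotate=rotate_multiple_of_90
)
print(xs[0])   # [[2, 4], [1, 3]]
print(attrs)   # [[1, 90], [0, 90]]
```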

rotate_dataset(images, angle)

Transform the MNIST dataset by rotating the images. The attribute (rotation angle) is independent of label Y.

Parameters:
  • images – original MNIST images

  • angle – Value of rotation angle used for transforming the image

Returns:

transformed images
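Rotating by an arbitrary angle such as 15 or 60 degrees requires resampling. The sketch below rotates a square image counterclockwise about its center using nearest-neighbor sampling and fills pixels that map outside the source with 0. The interpolation scheme, fill value, and the helper name `rotate_image` are assumptions, not necessarily what the library does.

```python
import math


def rotate_image(image, angle_deg, fill=0):
    """Rotate a square nested-list image counterclockwise by angle_deg
    about its center, using nearest-neighbor sampling."""
    n = len(image)
    center = (n - 1) / 2.0
    theta = math.radians(angle_deg)
    cos_t, sin_t = math.cos(theta), math.sin(theta)
    out = [[fill] * n for _ in range(n)]
    for r in range(n):
        for c in range(n):
            # Output pixel in Cartesian coordinates (y axis pointing up).
            x, y = c - center, center - r
            # Inverse rotation finds where this pixel came from.
            src_x = cos_t * x + sin_t * y
            src_y = -sin_t * x + cos_t * y
            src_c = round(src_x + center)
            src_r = round(center - src_y)
            if 0 <= src_r < n and 0 <= src_c < n:
                out[r][c] = image[src_r][src_c]
    return out


img = [[1, 2, 3], [4, 5, 6], [7, 8, 9]]
print(rotate_image(img, 90))  # [[3, 6, 9], [2, 5, 8], [1, 4, 7]]
```

Mapping each output pixel back through the inverse rotation, rather than pushing input pixels forward, guarantees that every output pixel is assigned exactly once and leaves no holes.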

torch_bernoulli_(p, size)
torch_xor_(a, b)
class dowhy.causal_prediction.datasets.mnist.MNISTIndAttribute(root, download=True)

Bases: MultipleDomainDataset

Class for MNISTIndAttribute dataset.

Parameters:
  • root – The directory where data can be found (or should be downloaded to, if it does not exist).

  • download – Boolean flag indicating whether the data should be downloaded (default: True)

Returns:

an instance of the MultipleDomainDataset class

CHECKPOINT_FREQ = 500
ENVIRONMENTS = ['15', '60', '90', '90']
INPUT_SHAPE = (1, 14, 14)
N_STEPS = 5001
rotate_dataset(images, labels, env_id, angle)

Transform the MNIST dataset by rotating the images. The attribute (rotation angle) is independent of label Y.

Parameters:
  • images – original MNIST images

  • labels – original MNIST labels

  • angle – Value of rotation angle used for transforming the image

Returns:

TensorDataset containing transformed images, labels, and attributes (angle)
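The claim that the rotation angle is independent of label Y can be checked empirically. The sketch below assumes that a shared, shuffled pool of samples is split evenly across environments and that every sample in an environment receives that environment's angle from ENVIRONMENTS, so the estimated P(Y = 1 | angle) should be about the same for every angle. The helpers `assign_angles` and `label_rate_by_angle` are hypothetical, and the binary labels are synthetic.

```python
import random
from collections import defaultdict


def assign_angles(labels, angles, seed=0):
    """Split shuffled labels evenly across environments; every sample in an
    environment gets that environment's angle, regardless of its label."""
    rng = random.Random(seed)
    labels = labels[:]  # do not mutate the caller's list
    rng.shuffle(labels)
    samples = []
    for env_id, angle in enumerate(angles):
        for y in labels[env_id::len(angles)]:
            samples.append((y, angle))
    return samples


def label_rate_by_angle(samples):
    """Estimate P(Y = 1 | angle) for each distinct angle."""
    counts = defaultdict(lambda: [0, 0])  # angle -> [positives, total]
    for y, angle in samples:
        counts[angle][0] += y
        counts[angle][1] += 1
    return {a: pos / tot for a, (pos, tot) in counts.items()}


ENVIRONMENTS = ["15", "60", "90", "90"]
angles = [float(a) for a in ENVIRONMENTS]
labels = [int(i % 10 < 5) for i in range(40000)]  # synthetic, P(Y = 1) = 0.5
rates = label_rate_by_angle(assign_angles(labels, angles))
print(rates)  # each rate should be close to 0.5
```

The two environments that share the 90-degree angle are pooled by `label_rate_by_angle`, which does not affect the independence check.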

torch_bernoulli_(p, size)
torch_xor_(a, b)

Module contents