dowhy.causal_prediction.algorithms package
Submodules
dowhy.causal_prediction.algorithms.base_algorithm module
- class dowhy.causal_prediction.algorithms.base_algorithm.PredictionAlgorithm(model, optimizer, lr, weight_decay, betas, momentum)
Bases: LightningModule
This class implements the default methods for a PyTorch Lightning module, pl.LightningModule. Its methods are called when the fit() method is called. To learn more about these methods, refer to https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html.
- Parameters:
model – Neural network modules used for training
optimizer – Optimization algorithm used for training. Currently supports “Adam” and “SGD”.
lr – Value of learning rate
weight_decay – Value of weight decay for optimizer
betas – Adam configuration parameters (beta1, beta2), exponential decay rate for the first moment and second-moment estimates, respectively.
momentum – Value of momentum for SGD optimizer
- configure_optimizers()
Initializes the optimizer using the parameters passed when the PredictionAlgorithm class is initialized.
- test_step(batch, batch_idx, dataloader_idx=0)
Implements a single step of the test loop for the pl.LightningModule.
The test loop runs only when test() is called.
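The parameter routing described above (betas for Adam, momentum for SGD) can be sketched without any framework. This is a minimal sketch, assuming betas only matter for Adam and momentum only for SGD; select_optimizer_kwargs is a hypothetical helper, not part of DoWhy.

```python
def select_optimizer_kwargs(optimizer, lr, weight_decay, betas=(0.9, 0.999), momentum=0.9):
    """Return (name, kwargs) for the requested optimizer.

    Hypothetical helper: `betas` is only forwarded for "Adam" and
    `momentum` is only forwarded for "SGD".
    """
    if optimizer == "Adam":
        return "Adam", {"lr": lr, "weight_decay": weight_decay, "betas": tuple(betas)}
    if optimizer == "SGD":
        return "SGD", {"lr": lr, "weight_decay": weight_decay, "momentum": momentum}
    # Only "Adam" and "SGD" are documented as supported.
    raise ValueError(f"Unsupported optimizer {optimizer!r}; expected 'Adam' or 'SGD'")


name, kwargs = select_optimizer_kwargs("SGD", lr=0.01, weight_decay=1e-4, momentum=0.5)
print(name, kwargs)
```

In a real module, the returned kwargs could be passed along with the model parameters to torch.optim.Adam or torch.optim.SGD, both of which accept lr and weight_decay, plus betas (Adam) or momentum (SGD).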
dowhy.causal_prediction.algorithms.cacm module
- class dowhy.causal_prediction.algorithms.cacm.CACM(model, optimizer='Adam', lr=0.001, weight_decay=0.0, betas=(0.9, 0.999), momentum=0.9, kernel_type='gaussian', ci_test='mmd', attr_types=[], E_conditioned=True, E_eq_A=[], gamma=1e-06, lambda_causal=1.0, lambda_conf=1.0, lambda_ind=1.0, lambda_sel=1.0)
Bases: PredictionAlgorithm
- Class for Causally Adaptive Constraint Minimization (CACM) Algorithm.
- @article{Kaur2022ModelingTD,
title={Modeling the Data-Generating Process is Necessary for Out-of-Distribution Generalization}, author={Jivat Neet Kaur and Emre Kıcıman and Amit Sharma}, journal={ArXiv}, year={2022}, volume={abs/2206.07837}, url={https://arxiv.org/abs/2206.07837}
}
- Parameters:
model – Networks used for training. The expected model type is torch.nn.Sequential(featurizer, classifier), where featurizer and classifier are of type torch.nn.Module.
optimizer – Optimization algorithm used for training. Currently supports “Adam” and “SGD”.
lr – learning rate for CACM
weight_decay – Value of weight decay for optimizer
betas – Adam configuration parameters (beta1, beta2), exponential decay rate for the first moment and second-moment estimates, respectively.
momentum – Value of momentum for SGD optimizer
kernel_type – Kernel type for MMD penalty. Currently supports “gaussian” (RBF). If None, the distance between the means and second-order statistics (covariances) is used instead.
ci_test – Conditional independence metric used for regularization penalty. Currently, MMD is supported.
attr_types – list of attribute types (based on their relationship with label Y), ordered according to the attribute order in the loaded dataset. Currently, ‘causal’ (Causal), ‘conf’ (Confounded), ‘ind’ (Independent) and ‘sel’ (Selected) are supported. For single-shift datasets, use [‘causal’] or [‘ind’]; for multi-shift datasets, use [‘causal’, ‘ind’].
E_conditioned – Binary flag indicating whether E-conditioned regularization has to be applied
E_eq_A – list indicating indices of attributes that coincide with environment (E) definition; default is empty.
gamma – kernel bandwidth parameter for MMD. Due to the implementation, the effective kernel bandwidth is the reciprocal of gamma, i.e., gamma=1e-6 implies a kernel bandwidth of 1e6 (see mmd_compute in utils.py).
lambda_causal – MMD penalty hyperparameter for Causal shift
lambda_conf – MMD penalty hyperparameter for Confounded shift
lambda_ind – MMD penalty hyperparameter for Independent shift
lambda_sel – MMD penalty hyperparameter for Selected shift
- Returns:
an instance of PredictionAlgorithm class
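CACM adds regularization penalties to the classification loss, with one weight per attribute type (lambda_causal, lambda_conf, lambda_ind, lambda_sel). The sketch below assumes the total objective is the classification loss plus a λ-weighted sum of per-attribute penalties, with each weight looked up from attr_types; cacm_objective is a hypothetical helper, and in practice the penalty values would come from the Regularizer class.

```python
def cacm_objective(erm_loss, penalties, attr_types, lambdas):
    """Combine a classification loss with per-attribute penalties.

    Hypothetical sketch: penalties[i] is the already computed penalty for
    the i-th attribute, and attr_types[i] selects which weight applies.
    """
    if len(penalties) != len(attr_types):
        raise ValueError("Need exactly one penalty per entry in attr_types")
    total = erm_loss
    for penalty, attr_type in zip(penalties, attr_types):
        if attr_type not in lambdas:
            raise ValueError(f"Unknown attribute type {attr_type!r}")
        total += lambdas[attr_type] * penalty
    return total


# Weights keyed by the supported attribute types.
lambdas = {"causal": 1.0, "conf": 1.0, "ind": 2.0, "sel": 1.0}
print(cacm_objective(0.5, [0.25, 0.125], ["causal", "ind"], lambdas))
```

Keeping one weight per attribute type lets a multi-shift dataset such as attr_types=['causal', 'ind'] trade off the two constraints independently.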
dowhy.causal_prediction.algorithms.erm module
- class dowhy.causal_prediction.algorithms.erm.ERM(model, optimizer='Adam', lr=0.001, weight_decay=0.0, betas=(0.9, 0.999), momentum=0.9)
Bases: PredictionAlgorithm
Empirical Risk Minimization (ERM) algorithm. This class implements the default methods for a PyTorch Lightning module, pl.LightningModule. Its methods are called when the fit() method is called. To learn more about these methods, refer to https://pytorch-lightning.readthedocs.io/en/stable/common/lightning_module.html.
- Parameters:
model – Neural network modules used for training
optimizer – Optimization algorithm used for training. Currently supports “Adam” and “SGD”.
lr – Value of learning rate
weight_decay – Value of weight decay for optimizer
betas – Adam configuration parameters (beta1, beta2), exponential decay rate for the first moment and second-moment estimates, respectively.
momentum – Value of momentum for SGD optimizer
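ERM conventionally stands for Empirical Risk Minimization: the model is trained to minimize the average loss over the training examples, without the attribute-based penalties that CACM adds. As an illustration of that objective only, the sketch below fits a one-dimensional logistic model by gradient descent in plain Python. It is a toy under stated assumptions, not DoWhy's implementation, which trains a torch model inside a LightningModule; empirical_risk and fit_erm are hypothetical names.

```python
import math


def empirical_risk(w, b, xs, ys):
    """Average logistic loss over the training set (the empirical risk)."""
    total = 0.0
    for x, y in zip(xs, ys):
        p = 1.0 / (1.0 + math.exp(-(w * x + b)))
        total += -(y * math.log(p) + (1 - y) * math.log(1 - p))
    return total / len(xs)


def fit_erm(xs, ys, lr=0.5, steps=200):
    """Minimize the empirical risk of a 1-D logistic model by gradient descent."""
    w, b = 0.0, 0.0
    n = len(xs)
    for _ in range(steps):
        grad_w = grad_b = 0.0
        for x, y in zip(xs, ys):
            p = 1.0 / (1.0 + math.exp(-(w * x + b)))
            # Gradient of the averaged logistic loss with respect to w and b.
            grad_w += (p - y) * x / n
            grad_b += (p - y) / n
        w -= lr * grad_w
        b -= lr * grad_b
    return w, b


xs = [-2.0, -1.0, -0.5, 0.5, 1.0, 2.0]
ys = [0, 0, 0, 1, 1, 1]
w, b = fit_erm(xs, ys)
print(empirical_risk(0.0, 0.0, xs, ys), empirical_risk(w, b, xs, ys))
```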
dowhy.causal_prediction.algorithms.regularization module
- class dowhy.causal_prediction.algorithms.regularization.Regularizer(E_conditioned, ci_test, kernel_type, gamma)
Bases: object
Implements methods for applying unconditional and conditional regularization.
- Parameters:
E_conditioned – Binary flag indicating whether E-conditioned regularization has to be applied
ci_test – Conditional independence metric used for regularization penalty. Currently, MMD is supported.
kernel_type – Kernel type for MMD penalty. Currently supports “gaussian” (RBF). If None, the distance between the means and second-order statistics (covariances) is used instead.
gamma – kernel bandwidth parameter for MMD. Due to the implementation, the effective kernel bandwidth is the reciprocal of gamma, i.e., gamma=1e-6 implies a kernel bandwidth of 1e6 (see mmd_compute in utils.py).
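The link between gamma and the kernel bandwidth, and the kernel_type=None fallback, can be illustrated with a plain-Python MMD estimate. This is a sketch, assuming a Gaussian kernel of the form exp(-gamma * ||x - y||^2), so a smaller gamma means a wider kernel, and assuming the fallback compares mean vectors and covariance matrices in the style of DomainBed. mmd and its helpers are hypothetical names, not DoWhy's mmd_compute.

```python
import math


def _mean_vector(xs):
    n = len(xs)
    return [sum(col) / n for col in zip(*xs)]


def _covariance(xs):
    """Sample covariance matrix (rows are samples), normalized by n - 1."""
    n = len(xs)
    mu = _mean_vector(xs)
    d = len(mu)
    return [[sum((x[i] - mu[i]) * (x[j] - mu[j]) for x in xs) / (n - 1)
             for j in range(d)] for i in range(d)]


def _gaussian_kernel_mean(xs, ys, gamma):
    """Mean of exp(-gamma * ||x - y||^2) over all pairs (x, y)."""
    total = 0.0
    for x in xs:
        for y in ys:
            sq_dist = sum((a - b) ** 2 for a, b in zip(x, y))
            total += math.exp(-gamma * sq_dist)
    return total / (len(xs) * len(ys))


def mmd(xs, ys, kernel_type="gaussian", gamma=1.0):
    """Biased MMD^2 estimate, or a moment-matching distance if kernel_type is None."""
    if kernel_type == "gaussian":
        return (_gaussian_kernel_mean(xs, xs, gamma)
                + _gaussian_kernel_mean(ys, ys, gamma)
                - 2.0 * _gaussian_kernel_mean(xs, ys, gamma))
    # kernel_type=None: squared differences of means plus covariances.
    d = len(xs[0])
    mean_diff = sum((a - b) ** 2 for a, b in zip(_mean_vector(xs), _mean_vector(ys))) / d
    cov_x, cov_y = _covariance(xs), _covariance(ys)
    cov_diff = sum((cov_x[i][j] - cov_y[i][j]) ** 2 for i in range(d) for j in range(d)) / (d * d)
    return mean_diff + cov_diff


x = [[0.0, 0.0], [1.0, 0.0], [0.0, 1.0]]
y = [[3.0, 3.0], [4.0, 3.0], [3.0, 4.0]]
print(mmd(x, y), mmd(x, y, gamma=1e-6), mmd(x, y, kernel_type=None))
```

With gamma=1e-6 the kernel is nearly constant over this data, so the penalty is tiny even though the two samples are clearly shifted; this is the practical effect of the very wide bandwidth noted above.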
- conditional_reg(classifs, attribute_labels, conditioning_subset, num_envs, E_eq_A=False)
Implement conditional regularization φ(x) ⊥⊥ A_i | A_s
- Parameters:
classifs – feature representations output from classifier layer (gφ(x))
attribute_labels – attribute labels loaded with the dataset for attribute A_i
conditioning_subset – list containing the subset A_s of observed variables (attributes + targets) such that (X_c, A_i) are d-separated conditioned on this subset
num_envs – number of environments/domains
E_eq_A – Binary flag indicating whether attribute (A_i) coincides with environment (E) definition
Group indices for conditional regularization are found by taking all possible combinations of values of the conditioning subset. For example, if conditioning_subset = [A1, Y], where A1 is in {0, 1} and Y is in {0, 1, 2}, groups are assigned as follows:
A1 = 0, Y = 0 -> group 0
A1 = 1, Y = 0 -> group 1
A1 = 0, Y = 1 -> group 2
A1 = 1, Y = 1 -> group 3
A1 = 0, Y = 2 -> group 4
A1 = 1, Y = 2 -> group 5
- Code snippet for computing group indices adapted from WILDS: https://github.com/p-lambda/wilds
- @inproceedings{wilds2021,
title = {{WILDS}: A Benchmark of in-the-Wild Distribution Shifts}, author = {Pang Wei Koh and Shiori Sagawa and Henrik Marklund and Sang Michael Xie and Marvin Zhang and Akshay Balsubramani and Weihua Hu and Michihiro Yasunaga and Richard Lanas Phillips and Irena Gao and Tony Lee and Etienne David and Ian Stavness and Wei Guo and Berton A. Earnshaw and Imran S. Haque and Sara Beery and Jure Leskovec and Anshul Kundaje and Emma Pierson and Sergey Levine and Chelsea Finn and Percy Liang}, booktitle = {International Conference on Machine Learning (ICML)}, year = {2021}
}
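The group assignment in the table above is a mixed-radix encoding in which the first variable of conditioning_subset varies fastest, so for [A1, Y] the group is A1 + 2 * Y. A minimal sketch, assuming every conditioning variable takes integer values 0, ..., k-1; group_index and enumerate_groups are hypothetical helpers, not the WILDS or DoWhy code.

```python
from itertools import product


def group_index(values, cardinalities):
    """Map one combination of conditioning-variable values to a group id.

    The first variable varies fastest: for [A1, Y] with |A1| = 2,
    the group id is A1 + 2 * Y.
    """
    index, stride = 0, 1
    for value, card in zip(values, cardinalities):
        if not 0 <= value < card:
            raise ValueError(f"value {value} out of range for cardinality {card}")
        index += value * stride
        stride *= card
    return index


def enumerate_groups(cardinalities):
    """List (values, group id) for every combination of conditioning values."""
    combos = []
    # product() varies its LAST argument fastest, so iterate over the reversed
    # cardinalities and flip each tuple back to make the first variable fastest.
    for reversed_values in product(*(range(c) for c in reversed(cardinalities))):
        values = tuple(reversed(reversed_values))
        combos.append((values, group_index(values, cardinalities)))
    return combos


# A1 in {0, 1}, Y in {0, 1, 2}
for (a1, y), group in enumerate_groups([2, 3]):
    print(f"A1 = {a1}, Y = {y} -> group {group}")
```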
- unconditional_reg(classifs, attribute_labels, num_envs, E_eq_A=False)
Implement unconditional regularization φ(x) ⊥⊥ A_i
- Parameters:
classifs – feature representations output from classifier layer (gφ(x))
attribute_labels – attribute labels loaded with the dataset for attribute A_i
num_envs – number of environments/domains
E_eq_A – Binary flag indicating whether attribute (A_i) coincides with environment (E) definition
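One way to turn the constraint φ(x) ⊥⊥ A_i into a penalty is to compare the feature distributions for each pair of attribute values with MMD. If the features do not depend on A_i, those distributions match and the penalty is near zero. The sketch below assumes a Gaussian-kernel MMD and ignores environments (num_envs, E_eq_A), which the DoWhy method also handles; unconditional_penalty and gaussian_mmd are hypothetical names.

```python
import math
from itertools import combinations


def gaussian_mmd(xs, ys, gamma=1.0):
    """Biased MMD^2 estimate with kernel exp(-gamma * ||x - y||^2)."""
    def k_mean(a, b):
        return sum(math.exp(-gamma * sum((p - q) ** 2 for p, q in zip(u, v)))
                   for u in a for v in b) / (len(a) * len(b))
    return k_mean(xs, xs) + k_mean(ys, ys) - 2.0 * k_mean(xs, ys)


def unconditional_penalty(features, attribute_labels, gamma=1.0):
    """Sum of pairwise MMDs between features grouped by attribute value.

    A small penalty means the feature distribution barely changes with
    the attribute, i.e. the representation is close to independent of A_i.
    """
    by_value = {}
    for feat, label in zip(features, attribute_labels):
        by_value.setdefault(label, []).append(feat)
    penalty = 0.0
    for a, b in combinations(sorted(by_value), 2):
        penalty += gaussian_mmd(by_value[a], by_value[b], gamma)
    return penalty


labels = [0, 0, 1, 1]
independent = [[0.0], [1.0], [0.0], [1.0]]  # same feature values for both attribute values
dependent = [[0.0], [0.0], [5.0], [5.0]]    # features shift with the attribute
print(unconditional_penalty(independent, labels), unconditional_penalty(dependent, labels))
```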
dowhy.causal_prediction.algorithms.utils module
- The functions in this file are borrowed from DomainBed: https://github.com/facebookresearch/DomainBed
- @inproceedings{gulrajani2021in,
title={In Search of Lost Domain Generalization}, author={Ishaan Gulrajani and David Lopez-Paz}, booktitle={International Conference on Learning Representations}, year={2021},
}