Source code for dowhy.causal_prediction.algorithms.cacm

import torch
from torch.nn import functional as F

from dowhy.causal_prediction.algorithms.base_algorithm import PredictionAlgorithm
from dowhy.causal_prediction.algorithms.regularization import Regularizer


[docs]class CACM(PredictionAlgorithm): def __init__( self, model, optimizer="Adam", lr=1e-3, weight_decay=0.0, betas=(0.9, 0.999), momentum=0.9, kernel_type="gaussian", ci_test="mmd", attr_types=[], E_conditioned=True, E_eq_A=[], gamma=1e-6, lambda_causal=1.0, lambda_conf=1.0, lambda_ind=1.0, lambda_sel=1.0, ): """Class for Causally Adaptive Constraint Minimization (CACM) Algorithm. @article{Kaur2022ModelingTD, title={Modeling the Data-Generating Process is Necessary for Out-of-Distribution Generalization}, author={Jivat Neet Kaur and Emre Kıcıman and Amit Sharma}, journal={ArXiv}, year={2022}, volume={abs/2206.07837}, url={https://arxiv.org/abs/2206.07837} } :param model: Networks used for training. `model` type expected is torch.nn.Sequential(featurizer, classifier) where featurizer and classifier are of type torch.nn.Module. :param optimizer: Optimization algorithm used for training. Currently supports "Adam" and "SGD". :param lr: learning rate for CACM :param weight_decay: Value of weight decay for optimizer :param betas: Adam configuration parameters (beta1, beta2), exponential decay rate for the first moment and second-moment estimates, respectively. :param momentum: Value of momentum for SGD optimzer :param kernel_type: Kernel type for MMD penalty. Currently, supports "gaussian" (RBF). If None, distance between mean and second-order statistics (covariances) is used. :param ci_test: Conditional independence metric used for regularization penalty. Currently, MMD is supported. :param attr_types: list of attribute types (based on relationship with label Y); should be ordered according to attribute order in loaded dataset. Currently, 'causal' (Causal), 'conf' (Confounded), 'ind' (Independent) and 'sel' (Selected) are supported. For single-shift datasets, use: ['causal'], ['ind'] For multi-shift datasets, use: ['causal', 'ind'] :param E_conditioned: Binary flag indicating whether E-conditioned regularization has to be applied :param E_eq_A: list indicating indices of attributes that coincide with environment (E) definition; default is empty. :param gamma: kernel bandwidth for MMD (due to implementation, the kernel bandwdith will actually be the reciprocal of gamma i.e., gamma=1e-6 implies kernel bandwidth=1e6. See `mmd_compute` in utils.py) :param lambda_causal: MMD penalty hyperparameter for Causal shift :param lambda_conf: MMD penalty hyperparameter for Confounded shift :param lambda_ind: MMD penalty hyperparameter for Independent shift :param lambda_sel: MMD penalty hyperparameter for Selected shift :returns: an instance of PredictionAlgorithm class """ super().__init__(model, optimizer, lr, weight_decay, betas, momentum) self.CACMRegularizer = Regularizer(E_conditioned, ci_test, kernel_type, gamma) self.attr_types = attr_types self.E_eq_A = E_eq_A self.lambda_causal = lambda_causal self.lambda_conf = lambda_conf self.lambda_ind = lambda_ind self.lambda_sel = lambda_sel
[docs] def training_step(self, train_batch, batch_idx): """ Override `training_step` from PredictionAlgorithm class for CACM-specific training loop. """ self.featurizer = self.model[0] self.classifier = self.model[1] minibatches = train_batch objective = 0 correct, total = 0, 0 penalty_causal, penalty_conf, penalty_ind, penalty_sel = 0, 0, 0, 0 nmb = len(minibatches) features = [self.featurizer(xi) for xi, _, _ in minibatches] classifs = [self.classifier(fi) for fi in features] targets = [yi for _, yi, _ in minibatches] for i in range(nmb): objective += F.cross_entropy(classifs[i], targets[i]) correct += (torch.argmax(classifs[i], dim=1) == targets[i]).float().sum().item() total += classifs[i].shape[0] objective /= nmb loss = objective if self.attr_types != []: for attr_type_idx, attr_type in enumerate(self.attr_types): attribute_labels = [ ai for _, _, ai in minibatches ] # [(batch_size, num_attrs)_1, batch_size, num_attrs)_2, ..., (batch_size, num_attrs)_(num_environments)] E_eq_A_attr = attr_type_idx in self.E_eq_A # Acause regularization if attr_type == "causal": penalty_causal += self.CACMRegularizer.conditional_reg( classifs, [a[:, attr_type_idx] for a in attribute_labels], [targets], nmb, E_eq_A_attr ) # Aconf regularization elif attr_type == "conf": penalty_conf += self.CACMRegularizer.unconditional_reg( classifs, [a[:, attr_type_idx] for a in attribute_labels], nmb, E_eq_A_attr ) # Aind regularization elif attr_type == "ind": penalty_ind += self.CACMRegularizer.unconditional_reg( classifs, [a[:, attr_type_idx] for a in attribute_labels], nmb, E_eq_A_attr ) # Asel regularization elif attr_type == "sel": penalty_sel += self.CACMRegularizer.conditional_reg( classifs, [a[:, attr_type_idx] for a in attribute_labels], [targets], nmb, E_eq_A_attr ) if nmb > 1: penalty_causal /= nmb * (nmb - 1) / 2 penalty_conf /= nmb * (nmb - 1) / 2 penalty_ind /= nmb * (nmb - 1) / 2 penalty_sel /= nmb * (nmb - 1) / 2 # Compile loss loss += self.lambda_causal * penalty_causal loss += self.lambda_conf * penalty_conf loss += self.lambda_ind * penalty_ind loss += self.lambda_sel * penalty_sel if torch.is_tensor(penalty_causal): penalty_causal = penalty_causal.item() self.log("penalty_causal", penalty_causal, on_step=False, on_epoch=True, prog_bar=True) if torch.is_tensor(penalty_conf): penalty_conf = penalty_conf.item() self.log("penalty_conf", penalty_conf, on_step=False, on_epoch=True, prog_bar=True) if torch.is_tensor(penalty_ind): penalty_ind = penalty_ind.item() self.log("penalty_ind", penalty_ind, on_step=False, on_epoch=True, prog_bar=True) if torch.is_tensor(penalty_sel): penalty_sel = penalty_sel.item() self.log("penalty_sel", penalty_sel, on_step=False, on_epoch=True, prog_bar=True) elif self.graph is not None: pass # TODO else: raise ValueError("No attribute types or graph provided.") acc = correct / total metrics = {"train_acc": acc, "train_loss": loss} self.log_dict(metrics, on_step=False, on_epoch=True, prog_bar=True) return loss