dowhy.causal_prediction.dataloaders package
Submodules
dowhy.causal_prediction.dataloaders.fast_data_loader module
dowhy.causal_prediction.dataloaders.get_data_loader module
- dowhy.causal_prediction.dataloaders.get_data_loader.get_eval_loader(dataset, envs, batch_size, class_balanced=False)
Return evaluation dataloaders (test/validation).
- Parameters:
dataset – dataset object containing a list of environments
envs – list of indices of the validation/test domains in the dataset
batch_size – batch size used by the dataloaders
class_balanced – Boolean flag indicating whether examples are sampled so that classes are balanced
- Returns:
list of dataloaders
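The class_balanced flag changes how examples are drawn, not the data itself. The following is a minimal sketch, assuming the inverse-class-frequency weighting used in DomainBed: each example is weighted by 1 / (count of its class times the number of classes), so every class carries the same total probability mass. `balanced_class_weights` and `balanced_sample` are hypothetical helpers, not DoWhy functions. In PyTorch, weights like these would typically be passed to `torch.utils.data.WeightedRandomSampler`.

```python
import random
from collections import Counter


def balanced_class_weights(labels):
    """Weight each example by 1 / (count(its class) * number_of_classes).

    Every class then receives the same total weight, and all weights sum to 1.
    """
    counts = Counter(labels)
    n_classes = len(counts)
    return [1.0 / (counts[y] * n_classes) for y in labels]


def balanced_sample(labels, k, seed=0):
    """Draw k example indices with replacement using the balanced weights."""
    rng = random.Random(seed)
    weights = balanced_class_weights(labels)
    return rng.choices(range(len(labels)), weights=weights, k=k)


labels = [0, 0, 0, 0, 0, 0, 1, 1]  # imbalanced: six 0s, two 1s
weights = balanced_class_weights(labels)
indices = balanced_sample(labels, k=1000, seed=0)
```

With six examples of class 0 and two of class 1, each class ends up with a total weight of 0.5, so roughly half of the sampled indices point at class-1 examples.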
- dowhy.causal_prediction.dataloaders.get_data_loader.get_loaders(dataset, train_envs, batch_size, val_envs=None, test_envs=None, class_balanced=False, holdout_fraction=0.2, trial_seed=0)
Return training, validation, and test dataloaders.
- Parameters:
dataset – dataset object containing a list of environments
train_envs – list of indices of the training domains in the dataset
batch_size – batch size used by the dataloaders
val_envs – list of indices of the validation domains in the dataset. If None, a fraction (holdout_fraction) of the training data is held out to create the validation set.
test_envs – list of indices of the test domains in the dataset
class_balanced – Boolean flag indicating whether examples are sampled so that classes are balanced
holdout_fraction – fraction of the training data held out to create the validation set; used only when val_envs is None
trial_seed – seed used to generate the validation split from the training data; used only when val_envs is None
- Returns:
dictionary of lists of dataloaders in the format
{'train_loaders': [train_dataloader_1, train_dataloader_2, ...],
 'val_loaders': [val_dataloader_1, val_dataloader_2, ...],
 'test_loaders': [test_dataloader_1, test_dataloader_2, ...]}
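The branching on val_envs and the shape of the returned dictionary can be illustrated with plain Python lists standing in for PyTorch dataloaders. This is a minimal sketch under stated assumptions, not DoWhy's implementation: `sketch_get_loaders`, `make_batches` and `holdout_split` are hypothetical, each "dataloader" is just a list of batches, class_balanced is omitted, and the per-domain holdout seed (`trial_seed * 1000 + i`) is an arbitrary illustrative choice.

```python
import random


def make_batches(examples, batch_size):
    """Stand-in for a dataloader: split a list into consecutive batches."""
    return [examples[i:i + batch_size] for i in range(0, len(examples), batch_size)]


def holdout_split(examples, holdout_fraction, seed):
    """Shuffle a copy deterministically, then return (train_part, val_part)."""
    shuffled = list(examples)
    random.Random(seed).shuffle(shuffled)
    n_val = int(len(shuffled) * holdout_fraction)
    return shuffled[n_val:], shuffled[:n_val]


def sketch_get_loaders(envs, train_envs, batch_size, val_envs=None,
                       test_envs=None, holdout_fraction=0.2, trial_seed=0):
    """Build a dict with 'train_loaders', 'val_loaders' and 'test_loaders'."""
    train_loaders, val_loaders = [], []
    for i in train_envs:
        if val_envs is None:
            # No validation domains given: hold out part of each training domain.
            train_part, val_part = holdout_split(envs[i], holdout_fraction,
                                                 seed=trial_seed * 1000 + i)
            train_loaders.append(make_batches(train_part, batch_size))
            val_loaders.append(make_batches(val_part, batch_size))
        else:
            # Validation domains given: each training domain is used in full.
            train_loaders.append(make_batches(envs[i], batch_size))
    if val_envs is not None:
        val_loaders = [make_batches(envs[i], batch_size) for i in val_envs]
    test_loaders = [make_batches(envs[i], batch_size) for i in (test_envs or [])]
    return {"train_loaders": train_loaders,
            "val_loaders": val_loaders,
            "test_loaders": test_loaders}


# Three toy "domains" of ten examples each (hypothetical data).
envs = [list(range(0, 10)), list(range(10, 20)), list(range(20, 30))]
loaders = sketch_get_loaders(envs, train_envs=[0, 1], batch_size=4, test_envs=[2])
```

Here val_envs is None, so each of the two training domains contributes eight examples to its training loader and two to its validation loader, while the test loader covers domain 2 in full.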
- dowhy.causal_prediction.dataloaders.get_data_loader.get_train_eval_loader(dataset, envs, batch_size, class_balanced, holdout_fraction, trial_seed)
Return training and validation dataloaders.
- Parameters:
dataset – dataset object containing a list of environments
envs – list of indices of the training domains in the dataset
batch_size – batch size used by the dataloaders
class_balanced – Boolean flag indicating whether examples are sampled so that classes are balanced
holdout_fraction – fraction of the training data held out to create the validation set
trial_seed – seed used to generate the validation split from the training data
- Returns:
two lists of dataloaders, for training (train_loaders) and for validation (val_loaders), respectively
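The trial_seed parameter makes the held-out validation split reproducible. The following is a minimal sketch, assuming a DomainBed-style scheme in which the shuffle seed is derived by hashing the trial seed together with the domain index, so that different domains are split differently but the same trial always yields the same split. `seed_hash` and `train_val_split` are hypothetical here; DoWhy's own helpers may differ.

```python
import hashlib
import random


def seed_hash(*args):
    """Map any tuple of arguments to a reproducible integer seed in [0, 2**31)."""
    digest = hashlib.md5(str(args).encode("utf-8")).hexdigest()
    return int(digest, 16) % (2 ** 31)


def train_val_split(examples, holdout_fraction, trial_seed, env_index):
    """Split one training domain into disjoint (train, val) parts.

    The shuffle seed depends on both the trial and the domain, so each domain
    is split differently, yet rerunning with the same trial_seed reproduces
    exactly the same split.
    """
    indices = list(range(len(examples)))
    random.Random(seed_hash(trial_seed, env_index)).shuffle(indices)
    n_val = int(len(examples) * holdout_fraction)
    val = [examples[i] for i in indices[:n_val]]
    train = [examples[i] for i in indices[n_val:]]
    return train, val


domain = [f"x{i}" for i in range(20)]  # hypothetical training domain
train_a, val_a = train_val_split(domain, 0.2, trial_seed=0, env_index=0)
train_b, val_b = train_val_split(domain, 0.2, trial_seed=0, env_index=0)
```

Python's built-in `hash()` is avoided because string hashing is salted per process unless PYTHONHASHSEED is set, which would make the split change between runs.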
- dowhy.causal_prediction.dataloaders.get_data_loader.get_train_loader(dataset, envs, batch_size, class_balanced=False)
Return training dataloaders.
- Parameters:
dataset – dataset object containing a list of environments
envs – list of indices of the training domains in the dataset
batch_size – batch size used by the dataloaders
class_balanced – Boolean flag indicating whether examples are sampled so that classes are balanced
- Returns:
list of dataloaders
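A list of training dataloaders, one per domain, is commonly consumed in lockstep by domain-generalization training loops, drawing one minibatch from every domain at each optimization step. The following is a minimal sketch of that pattern, using plain lists of batches in place of PyTorch dataloaders; `lockstep_minibatches` is a hypothetical helper, and whether DoWhy's training loaders are finite or infinite is an assumption left open here.

```python
from itertools import cycle


def lockstep_minibatches(train_loaders, n_steps):
    """Yield, per step, a list holding one batch from every training domain.

    Each loader is cycled, so shorter domains restart instead of ending the
    loop early (a stand-in for an infinite data loader).
    """
    iterators = [cycle(loader) for loader in train_loaders]
    for _ in range(n_steps):
        yield [next(it) for it in iterators]


# Hypothetical loaders: domain 0 has three batches, domain 1 has two.
train_loaders = [
    [["a0", "a1"], ["a2", "a3"], ["a4"]],
    [["b0", "b1"], ["b2"]],
]
steps = list(lockstep_minibatches(train_loaders, n_steps=4))
```

Cycling rather than zipping the raw lists means the number of steps is set by the caller, not by the smallest domain.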
dowhy.causal_prediction.dataloaders.misc module
Miscellaneous helper functions.