dowhy.causal_prediction.dataloaders package

Submodules

dowhy.causal_prediction.dataloaders.fast_data_loader module

class dowhy.causal_prediction.dataloaders.fast_data_loader.FastDataLoader(dataset, batch_size, num_workers)

Bases: object

A DataLoader wrapper that is slightly faster because it does not respawn worker processes at every epoch.
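
The speed-up comes from creating the underlying iterator (and, with a PyTorch DataLoader, its worker processes) once and then drawing a fixed number of batches from it per epoch. The stdlib-only sketch below illustrates that pattern with a hypothetical EpochWindowLoader class; it spawns no workers and is not the dowhy implementation.

```python
import random
from typing import Iterator, List, Sequence


class EpochWindowLoader:
    """Hypothetical sketch: one long-lived iterator serves every epoch.

    In a PyTorch setup the long-lived object would be a DataLoader iterator
    whose worker processes stay alive; here it is a plain generator.
    """

    def __init__(self, dataset: Sequence, batch_size: int, seed: int = 0):
        self._dataset = dataset
        self._batch_size = batch_size
        self._rng = random.Random(seed)
        # Number of batches per epoch (the last batch may be smaller).
        self._length = (len(dataset) + batch_size - 1) // batch_size
        # Created once and never recreated between epochs.
        self._stream = self._infinite_batches()

    def _infinite_batches(self) -> Iterator[List]:
        while True:
            # Reshuffle at the start of each pass; sample without replacement.
            order = list(range(len(self._dataset)))
            self._rng.shuffle(order)
            for start in range(0, len(order), self._batch_size):
                yield [self._dataset[i] for i in order[start:start + self._batch_size]]

    def __len__(self) -> int:
        return self._length

    def __iter__(self) -> Iterator[List]:
        # One "epoch" is exactly len(self) batches pulled from the shared stream.
        for _ in range(len(self)):
            yield next(self._stream)
```

Iterating over the loader a second time resumes the same stream instead of building a new one. PyTorch's own DataLoader also offers a persistent_workers option aimed at the same cost.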

class dowhy.causal_prediction.dataloaders.fast_data_loader.InfiniteDataLoader(dataset, weights, batch_size, num_workers)

Bases: object
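
No description is given for InfiniteDataLoader. Going by its name and its weights argument, it presumably yields batches without end, sampling examples according to per-example weights (for instance those returned by make_weights_for_balanced_classes). The sketch below, built around a hypothetical infinite_batches generator, shows that assumed behavior using random.choices, which samples with replacement; it is not the class's actual code.

```python
import itertools
import random
from typing import Iterator, List, Optional, Sequence


def infinite_batches(
    dataset: Sequence,
    weights: Optional[Sequence[float]],
    batch_size: int,
    seed: int = 0,
) -> Iterator[List]:
    """Hypothetical sketch: yield batches forever, sampling with replacement.

    If weights is None, every example is equally likely.
    """
    rng = random.Random(seed)
    if weights is None:
        weights = [1.0] * len(dataset)
    while True:
        # random.choices draws k items with replacement, proportional to weights.
        yield rng.choices(dataset, weights=weights, k=batch_size)


# A training loop takes a fixed number of steps from the endless stream.
stream = infinite_batches(["a", "b", "c"], weights=[0.0, 1.0, 0.0], batch_size=4)
first_three = list(itertools.islice(stream, 3))
```

Because such a stream has no natural end, the number of training steps has to be controlled by the caller (here with itertools.islice) rather than by epoch length.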

dowhy.causal_prediction.dataloaders.get_data_loader module

dowhy.causal_prediction.dataloaders.get_data_loader.get_eval_loader(dataset, envs, batch_size, class_balanced=False)

Return evaluation dataloaders (test/validation).

Parameters:
  • dataset – dataset object containing a list of environments

  • envs – list of indices of the validation/test domains in the dataset

  • batch_size – batch size to use for the dataloaders

  • class_balanced – Boolean flag indicating whether to use class-balanced sampling

Returns:

list of dataloaders

dowhy.causal_prediction.dataloaders.get_data_loader.get_loaders(dataset, train_envs, batch_size, val_envs=None, test_envs=None, class_balanced=False, holdout_fraction=0.2, trial_seed=0)

Return training, validation, and test dataloaders.

Parameters:
  • dataset – dataset object containing a list of environments

  • train_envs – list of indices of the training domains in the dataset

  • batch_size – batch size to use for the dataloaders

  • val_envs – list of indices of the validation domains in the dataset. If None, a fraction (holdout_fraction) of the training data is held out to create the validation set.

  • test_envs – list of indices of the test domains in the dataset

  • class_balanced – Boolean flag indicating whether to use class-balanced sampling

  • holdout_fraction – fraction of the training data held out to create the validation set; used when val_envs is None.

  • trial_seed – seed used to generate the validation split from the training data; used when val_envs is None.

Returns:

dictionary of lists of dataloaders in the format {'train_loaders': [train_dataloader_1, train_dataloader_2, ...], 'val_loaders': [val_dataloader_1, val_dataloader_2, ...], 'test_loaders': [test_dataloader_1, test_dataloader_2, ...]}
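
The interaction between val_envs, holdout_fraction, and trial_seed can be summarized in a small stdlib sketch. Here each environment is a plain list and each "loader" is just the list of examples it would serve. The names sketch_get_loaders and holdout_split, the string seed built from trial_seed and the environment index, and returning an empty test_loaders list when test_envs is None are all assumptions for illustration, not dowhy's behavior.

```python
import random
from typing import Dict, List, Optional, Sequence, Tuple


def holdout_split(env: Sequence, holdout_fraction: float, seed: str) -> Tuple[list, list]:
    """Shuffle indices with a seeded RNG; the first part becomes validation data."""
    idx = list(range(len(env)))
    random.Random(seed).shuffle(idx)
    n_val = int(len(env) * holdout_fraction)
    val = [env[i] for i in idx[:n_val]]
    train = [env[i] for i in idx[n_val:]]
    return train, val


def sketch_get_loaders(
    envs: Sequence[Sequence],
    train_envs: List[int],
    val_envs: Optional[List[int]] = None,
    test_envs: Optional[List[int]] = None,
    holdout_fraction: float = 0.2,
    trial_seed: int = 0,
) -> Dict[str, List[list]]:
    """Hypothetical sketch of the dispatch logic and the returned dict shape."""
    loaders: Dict[str, List[list]] = {}
    if val_envs is None:
        # Carve a validation set out of every training environment.
        pairs = [
            holdout_split(envs[i], holdout_fraction, seed=f"{trial_seed}-{i}")
            for i in train_envs
        ]
        loaders["train_loaders"] = [train for train, _ in pairs]
        loaders["val_loaders"] = [val for _, val in pairs]
    else:
        # Whole environments are used for training and for validation.
        loaders["train_loaders"] = [list(envs[i]) for i in train_envs]
        loaders["val_loaders"] = [list(envs[i]) for i in val_envs]
    loaders["test_loaders"] = [list(envs[i]) for i in (test_envs or [])]
    return loaders
```

With ten examples per environment and the default holdout_fraction of 0.2, each training environment contributes two validation examples and keeps eight for training; the same trial_seed reproduces the same split.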

dowhy.causal_prediction.dataloaders.get_data_loader.get_train_eval_loader(dataset, envs, batch_size, class_balanced, holdout_fraction, trial_seed)

Return training and validation dataloaders.

Parameters:
  • dataset – dataset object containing a list of environments

  • envs – list of indices of the training domains in the dataset

  • batch_size – batch size to use for the dataloaders

  • class_balanced – Boolean flag indicating whether to use class-balanced sampling

  • holdout_fraction – fraction of the training data held out to create the validation set

  • trial_seed – seed used to generate the validation split from the training data

Returns:

two lists of dataloaders for training (train_loaders) and validation (val_loaders) respectively

dowhy.causal_prediction.dataloaders.get_data_loader.get_train_loader(dataset, envs, batch_size, class_balanced=False)

Return training dataloaders.

Parameters:
  • dataset – dataset object containing a list of environments

  • envs – list of indices of the training domains in the dataset

  • batch_size – batch size to use for the dataloaders

  • class_balanced – Boolean flag indicating whether to use class-balanced sampling

Returns:

list of dataloaders

dowhy.causal_prediction.dataloaders.misc module

Miscellaneous helper functions.

dowhy.causal_prediction.dataloaders.misc.make_weights_for_balanced_classes(dataset)
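
No description is provided for make_weights_for_balanced_classes. Judging from its name and the class_balanced flag of the loader functions, it presumably assigns each example a sampling weight inversely proportional to the frequency of its class, so that weighted sampling draws every class equally often. The stdlib sketch below works under that assumption; the function name balanced_class_weights and the (x, y) pair input format are hypothetical.

```python
from collections import Counter
from typing import List, Sequence, Tuple


def balanced_class_weights(dataset: Sequence[Tuple[object, int]]) -> List[float]:
    """Hypothetical sketch: weight each example by 1 / (count(y) * n_classes).

    With these weights every class carries a total weight of 1 / n_classes,
    so weighted sampling picks each class with equal probability.
    """
    labels = [int(y) for _, y in dataset]
    counts = Counter(labels)
    n_classes = len(counts)
    return [1.0 / (counts[y] * n_classes) for y in labels]


data = [("x0", 0), ("x1", 0), ("x2", 0), ("x3", 1)]
weights = balanced_class_weights(data)
# Class 0: three examples at 1/6 each; class 1: one example at 1/2.
```
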

dowhy.causal_prediction.dataloaders.misc.seed_hash(*args)

Derive an integer hash from all args, for use as a random seed.
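
One way to derive such a seed, sketched below under the assumption that the arguments have stable str() representations, is to hash str(args) with MD5 and reduce the digest modulo 2**31 so it fits common seed ranges. Python's built-in hash() is unsuitable for string arguments because it is randomized per process (see PYTHONHASHSEED). The name seed_hash_sketch is hypothetical, and this is not necessarily how dowhy computes it.

```python
import hashlib


def seed_hash_sketch(*args) -> int:
    """Hypothetical sketch: deterministic integer seed from arbitrary args."""
    digest = hashlib.md5(str(args).encode("utf-8")).hexdigest()
    # Reduce the 128-bit digest to a non-negative 31-bit integer.
    return int(digest, 16) % (2**31)
```

A call such as seed_hash_sketch(trial_seed, env_index) returns the same value in every process, which makes per-environment splits reproducible across runs.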

dowhy.causal_prediction.dataloaders.misc.split_dataset(dataset, n, seed=0)

Return a pair of datasets corresponding to a random split of the given dataset, with n datapoints in the first dataset and the rest in the second, using the given random seed.
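
Such a split can be produced without copying any data by shuffling the index range with a seeded RNG and wrapping each part in an index-mapping view. The sketch below uses the stdlib random.Random and the hypothetical names IndexView and split_dataset_sketch; the actual function may use a different RNG, so the specific split it produces for a given seed can differ.

```python
import random
from typing import Sequence, Tuple


class IndexView:
    """Hypothetical read-only view of dataset[keys[i]], without copying data."""

    def __init__(self, dataset: Sequence, keys: Sequence[int]):
        self.dataset = dataset
        self.keys = list(keys)

    def __len__(self) -> int:
        return len(self.keys)

    def __getitem__(self, i: int):
        return self.dataset[self.keys[i]]


def split_dataset_sketch(dataset: Sequence, n: int, seed: int = 0) -> Tuple[IndexView, IndexView]:
    """Shuffle indices with a seeded RNG; the first n go to the first split."""
    if not 0 <= n <= len(dataset):
        raise ValueError("n must be between 0 and len(dataset)")
    keys = list(range(len(dataset)))
    random.Random(seed).shuffle(keys)
    return IndexView(dataset, keys[:n]), IndexView(dataset, keys[n:])
```

Because only index lists are stored, both halves stay cheap even for large datasets, and the same seed always yields the same partition.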

Module contents