# Source code for dowhy.causal_prediction.dataloaders.misc
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
"""Miscellaneous helper functions: class-balancing weights, dataset splitting, seed hashing.
"""
import hashlib
from collections import Counter
import numpy as np
import torch
def make_weights_for_balanced_classes(dataset):
    """Compute per-sample weights that give every class equal total weight.

    A sample of class ``c`` receives weight ``1 / (count(c) * n_classes)``,
    where ``count(c)`` is the number of samples of that class and
    ``n_classes`` is the number of distinct classes seen. Each class therefore
    contributes ``1 / n_classes`` in total (up to floating-point rounding).

    Args:
        dataset: Sized iterable yielding ``(x, y)`` pairs; ``int(y)`` is used
            as the class label.

    Returns:
        A 1-D ``torch.Tensor`` of length ``len(dataset)`` whose ``i``-th entry
        is the weight of the ``i``-th sample in iteration order.
    """
    # Single pass over the dataset: iterating may be expensive (it can load
    # and transform every sample), so labels are collected once and reused.
    labels = [int(y) for _, y in dataset]
    counts = Counter(labels)
    n_classes = len(counts)

    weight_per_class = {
        label: 1 / (count * n_classes) for label, count in counts.items()
    }

    weights = torch.zeros(len(dataset))
    for i, label in enumerate(labels):
        weights[i] = weight_per_class[label]
    return weights
class _SplitDataset(torch.utils.data.Dataset):
    """Subset view over another dataset; used by split_dataset.

    Item ``i`` of this dataset is ``underlying_dataset[keys[i]]``. Only a
    reference to the underlying dataset and the key sequence are stored.
    """

    def __init__(self, underlying_dataset, keys):
        """Wrap ``underlying_dataset``, exposing only the entries in ``keys``.

        Args:
            underlying_dataset: Indexable dataset to draw items from.
            keys: Sequence of indices into ``underlying_dataset``; its order
                defines the order of this dataset.
        """
        super().__init__()
        self.underlying_dataset = underlying_dataset
        self.keys = keys

    def __getitem__(self, key):
        """Return the underlying item stored at index ``self.keys[key]``."""
        return self.underlying_dataset[self.keys[key]]

    def __len__(self):
        """Return the number of items in this subset."""
        return len(self.keys)
def split_dataset(dataset, n, seed=0):
    """Randomly split ``dataset`` into two disjoint datasets.

    The indices ``0 .. len(dataset) - 1`` are shuffled with a NumPy
    ``RandomState`` seeded by ``seed``; the first ``n`` shuffled indices form
    the first dataset and the remaining ones form the second.

    Args:
        dataset: Sized, indexable dataset to split.
        n: Number of datapoints to place in the first dataset; must not
            exceed ``len(dataset)``.
        seed: Seed for the shuffle, so the same seed yields the same split.

    Returns:
        A pair ``(first, second)`` of ``_SplitDataset`` views over ``dataset``.

    Raises:
        AssertionError: If ``n`` is greater than ``len(dataset)``.
    """
    assert n <= len(dataset), "n must not exceed the size of the dataset"
    keys = list(range(len(dataset)))
    # A dedicated RandomState keeps the split reproducible without touching
    # NumPy's global random state.
    np.random.RandomState(seed).shuffle(keys)
    keys_1 = keys[:n]
    keys_2 = keys[n:]
    return _SplitDataset(dataset, keys_1), _SplitDataset(dataset, keys_2)
def seed_hash(*args):
    """Derive an integer hash from all args, for use as a random seed.

    The positional arguments are rendered with ``str`` as a tuple, hashed with
    MD5, and reduced modulo ``2**31``. The result is therefore in
    ``[0, 2**31)`` and is the same for arguments with the same ``str`` form.
    MD5 is used here only to spread values, not for security.

    Args:
        *args: Any values; only the ``str`` representation of the argument
            tuple is used.

    Returns:
        An ``int`` in the range ``[0, 2**31)``.
    """
    args_str = str(args)
    digest = hashlib.md5(args_str.encode("utf-8")).hexdigest()
    return int(digest, 16) % (2**31)