Source code for dowhy.causal_refuters.overrule.utils

"""Utilities for learning boolean rules.

This module implements the boolean ruleset estimator from OverRule [1]. Code is adapted (with some simplifications)
from https://github.com/clinicalml/overlap-code, under the MIT License.

[1] Oberst, M., Johansson, F., Wei, D., Gao, T., Brat, G., Sontag, D., & Varshney, K. (2020). Characterization of
Overlap in Observational Studies. In S. Chiappa & R. Calandra (Eds.), Proceedings of the Twenty Third International
Conference on Artificial Intelligence and Statistics (Vol. 108, pp. 788–798). PMLR. https://arxiv.org/abs/1907.04138
"""

import warnings
from typing import Dict, List, Optional, Union

import numpy as np
import pandas as pd


def sampleUnif(x, n: int = 10000, seed: Optional[int] = None):
    """
    Generate samples from a uniform distribution over the max / min of each column of the sample X.

    These are used for estimation of support, as the number of samples included under the rules gives a measure of
    volume. This function is specialized to continuous variables, while `sample_reference` handles the general case,
    calling this function where necessary.

    Missing values (NaN) are ignored when computing each column's min / max.

    :param x: 2D array-like of samples, where each column corresponds to a feature.
    :type x: Pandas Dataframe, Numpy Array, or nested list
    :param n: Number of samples to draw, defaults to 10000
    :type n: int, optional
    :param seed: Random seed for uniform sampling, defaults to None
    :type seed: int, optional
    :return: Array of shape ``(n, number of columns of x)``; column ``j`` is drawn uniformly between the observed
        minimum and maximum of column ``j`` of ``x``.
    :rtype: np.ndarray
    """
    # Convert once so that nested lists work too (the original relied on ``x.shape``).
    values = np.asarray(x)
    rng = np.random.default_rng(seed)
    x_min, x_max = np.nanmin(values, axis=0), np.nanmax(values, axis=0)
    ref_samples = rng.uniform(low=x_min, high=x_max, size=(n, x_min.shape[0]))
    assert ref_samples.shape[1] == values.shape[1]
    return ref_samples
def sample_reference(
    x,
    n: Optional[int] = None,
    cat_cols: Optional[List[str]] = None,
    seed: Optional[int] = None,
    ref_range: Optional[Dict] = None,
):
    """
    Generate reference samples, drawing each column of X independently.

    Each column is handled by the first matching rule:

    * columns listed in ``ref_range`` follow the manual specification (a draw from ``{0, 1}`` when ``is_binary`` is
      true, otherwise uniform between ``min`` and ``max``);
    * columns with fewer than two distinct non-missing values are repeated as a constant (the first non-missing
      value);
    * columns with exactly two distinct non-missing values, listed in ``cat_cols``, or of ``object`` dtype are sampled
      from ``Series.unique()`` (which includes a missing value if the column has one);
    * remaining integer / floating-point columns are sampled uniformly between their observed min and max via
      `sampleUnif`.

    Columns matching none of these rules (e.g. a datetime column with more than two values) are omitted from the
    output, and a ``UserWarning`` is emitted for each.

    :param x: 2D array of samples, where each column corresponds to a feature.
    :type x: Pandas Dataframe or Numpy Array
    :param n: Number of samples to draw, defaults to the same number as the samples provided.
    :type n: Optional[int], optional
    :param cat_cols: Set of categorical columns, defaults to None (no columns forced to be categorical)
    :type cat_cols: List[str], optional
    :param seed: Random seed for uniform sampling, defaults to None
    :type seed: int, optional
    :param ref_range: Manual override of the range for reference samples, given as a dictionary of the form
        `ref_range = {c: {"is_binary": True/False, "min": min_value, "max": max_value}}`
    :type ref_range: Optional[Dict], optional
    :return: DataFrame with ``n`` rows and one column per sampled column of ``x``.
    :rtype: pd.DataFrame
    """
    if n is None:
        n = x.shape[0]
    # Avoid a shared mutable default; None means "no extra categorical columns".
    if cat_cols is None:
        cat_cols = []
    rng = np.random.default_rng(seed)
    data = x if isinstance(x, pd.DataFrame) else pd.DataFrame(x)
    if ref_range is not None:
        assert isinstance(ref_range, dict)
    else:
        ref_range = {}

    ref_cols = {}
    # Seed handed to sampleUnif; incremented after each continuous column so columns get different seeds.
    counter = seed
    for c in data:
        if c in ref_range:
            spec = ref_range[c]
            if spec["is_binary"]:
                ref_cols[c] = rng.choice([0, 1], n)
            else:
                ref_cols[c] = rng.uniform(low=spec["min"], high=spec["max"], size=(n, 1)).ravel()
            continue

        col = data[c]
        # nunique() ignores missing values.
        n_unique = col.nunique()

        if n_unique < 2:
            # Constant column: repeat the first non-missing value (falling back to the raw first value for an
            # all-missing column, and NaN for an empty one) rather than blindly using values[0], which may be NaN.
            non_null = col.dropna()
            source = non_null if len(non_null) > 0 else col
            fill = source.values[0] if len(source) > 0 else np.nan
            ref_cols[c] = np.array([fill] * n)
        elif n_unique == 2 or c in cat_cols or col.dtype == "object":
            # Binary / categorical column: sample from the observed values.
            ref_cols[c] = rng.choice(col.unique(), n)
        elif np.issubdtype(col.dtype, np.integer) or np.issubdtype(col.dtype, np.floating):
            # Ordinal / continuous column of any integer or float width (not only the platform default int/float).
            ref_cols[c] = sampleUnif(data[[c]].values, n, seed=counter).ravel()
            if counter is not None:
                counter += 1
        else:
            warnings.warn(
                f"Column {c!r} with dtype {col.dtype} is not constant, binary, categorical or numeric; "
                "it is omitted from the reference samples.",
                stacklevel=2,
            )
    return pd.DataFrame(ref_cols)
def fatom(f: str, o: str, v: Optional[Union[str, float]], fmt: str = "%.3f") -> str:
    """
    Format an "atom", i.e., a single literal in a Boolean Rule.

    Comparison atoms are rendered as ``[f o v]``, with float values formatted by ``fmt`` and string values inserted
    verbatim. ``"not"`` atoms are rendered as ``not f``. Any other operator yields the feature name alone.

    :param f: Feature name (non-string names, e.g. integer column labels, are converted with ``str``)
    :type f: str
    :param o: Operator, one of ["<=", ">", ">=", "<", "==", "not", ""]
    :type o: str
    :param v: Value of comparison for ["<=", ">", ">=", "<", "=="]; ignored for "not" and ""
    :type v: Optional[Union[str, float]]
    :param fmt: Formatting string for floats, defaults to "%.3f"
    :type fmt: str
    :return: Formatted atom
    :rtype: str
    """
    if o in ["<=", ">", ">=", "<", "=="]:
        if isinstance(v, str):
            return "[%s %s %s]" % (f, o, v)
        return ("[%s %s " + fmt + "]") % (f, o, v)
    if o == "not":
        # Wrap f in a 1-tuple so a tuple-valued feature name is formatted instead of unpacked by %.
        return "not %s" % (f,)
    # Always return a str (the annotated contract), so that non-string feature names such as the integer labels
    # pandas assigns to columns of a NumPy array can still be joined into a rule string.
    return str(f)
def rule_str(C: List, fmt: str = "%.3f") -> str:
    """
    Render a disjunction of conjunctive rules as a human-readable string.

    Each rule in ``C`` becomes its formatted atoms joined by " ∧ " and wrapped in parentheses. The rules are joined
    by a newline followed by "∨ ", and the result is prefixed with a single space.

    :param C: List of rules, where each element is a list (a single rule) of atoms. Each atom is indexed as
        ``atom[0]`` (feature), ``atom[1]`` (operator), ``atom[2]`` (value) and formatted with `fatom`.
    :type C: List
    :param fmt: Formatting string for floats, defaults to "%.3f"
    :type fmt: str
    :return: Formatted rule
    :rtype: str
    """
    clauses = []
    for conjunction in C:
        literals = [fatom(atom[0], atom[1], atom[2], fmt=fmt) for atom in conjunction]
        clauses.append("(%s)" % " ∧ ".join(literals))
    return " " + "\n∨ ".join(clauses)