"""Balance-plot interpreter for propensity score stratification (dowhy.interpreters.propensity_balance_interpreter)."""

import numpy as np
import pandas as pd

from dowhy.causal_estimator import CausalEstimate
from dowhy.causal_estimators.propensity_score_stratification_estimator import PropensityScoreStratificationEstimator
from dowhy.interpreters.visual_interpreter import VisualInterpreter


class PropensityBalanceInterpreter(VisualInterpreter):
    """Visual interpreter comparing covariate balance before and after propensity score stratification.

    For every observed common cause it computes a standardized mean difference between treatment
    groups twice: once on the raw data ("Unadjusted") and once as a stratum-size-weighted sum over
    the propensity score strata ("PropensityAdjusted"). It then plots the two values side by side.
    """

    SUPPORTED_ESTIMATORS = [
        PropensityScoreStratificationEstimator,
    ]

    def __init__(self, estimate, **kwargs):
        """Validate that ``estimate`` can be interpreted by this class.

        :param estimate: a CausalEstimate whose ``estimator`` is an instance of one of
            ``SUPPORTED_ESTIMATORS``.
        :param kwargs: forwarded unchanged to the ``VisualInterpreter`` constructor.
        :raises ValueError: if ``estimate`` is not a CausalEstimate, or if its estimator is not
            a supported type.
        """
        # NOTE(review): self.estimate and self.logger are presumably set by the base-class
        # constructor (not visible here); both are used below.
        super().__init__(estimate, **kwargs)
        if not isinstance(estimate, CausalEstimate):
            error_msg = "The interpreter method expects a CausalEstimate object."
            self.logger.error(error_msg)
            raise ValueError(error_msg)
        self.estimator = self.estimate.estimator
        if not isinstance(self.estimator, tuple(PropensityBalanceInterpreter.SUPPORTED_ESTIMATORS)):
            error_msg = "The interpreter method only supports propensity score stratification estimator."
            self.logger.error(error_msg)
            raise ValueError(error_msg)

    def interpret(self, data: pd.DataFrame):
        """Balance plot that shows the change in standardized mean differences for each covariate
        after propensity score stratification.

        :param data: DataFrame containing the observed common causes, the treatment column(s),
            and the ``strata`` and ``propensity_score`` columns. Common-cause columns must be
            named ``W<digits>`` (``W0``, ``W1``, ...) because they are reshaped with
            ``pd.wide_to_long(stubnames=["W"])``. ``strata``/``propensity_score`` are presumably
            written by the stratification estimator; confirm against the caller.
        :returns: the plotted DataFrame, with columns ``common_cause_id`` (the numeric suffix of
            ``W<digits>``), ``std_mean_diff`` and ``sample``. The ``"Unadjusted"`` rows come
            first, followed by the ``"PropensityAdjusted"`` rows.

        Side effects: draws on a new matplotlib figure and changes the global matplotlib style.
        """
        treatment_names = self.estimate._treatment_name
        cols = self.estimator._observed_common_causes_names + treatment_names + ["strata", "propensity_score"]
        df_long = self._to_long_format(data[cols])

        mean_diff_strata = self._adjusted_std_mean_diff(df_long, treatment_names)
        mean_diff_overall = self._unadjusted_std_mean_diff(df_long, treatment_names)
        # .assign returns new frames, avoiding SettingWithCopyWarning on the column slices.
        plot_df = pd.concat(
            [
                mean_diff_overall.assign(sample="Unadjusted"),
                mean_diff_strata.assign(sample="PropensityAdjusted"),
            ]
        )
        self._plot_balance(plot_df)
        return plot_df

    @staticmethod
    def _to_long_format(df: pd.DataFrame) -> pd.DataFrame:
        """Melt the ``W<digits>`` columns into one row per (input row, common cause), with W as float64."""
        # wide_to_long needs an id column that is unique and has a known name. The caller's
        # index guarantees neither (it may have duplicates, a name, or several levels), so use a
        # fresh positional id. The id is not used after reshaping.
        wide = df.reset_index(drop=True).rename_axis("index").reset_index()
        return (
            pd.wide_to_long(wide, stubnames=["W"], i="index", j="common_cause_id")
            .reset_index()
            .astype({"W": "float64"})
        )

    @staticmethod
    def _treated_mean_spread(df_long: pd.DataFrame, treatment_names, keys) -> pd.DataFrame:
        """Return ``mean_w``, the max-min spread of per-treatment-group means of W within each ``keys`` group.

        Every treatment-group row of a ``keys`` group carries the same spread, so only rows where the
        first treatment equals True are kept (one row per group for a single binary treatment).
        """
        # String aggregations keep pandas' own implementations (ddof=1 for std). Passing NumPy
        # callables is deprecated and will change semantics in future pandas versions.
        group_means = df_long.groupby(treatment_names + keys).agg(mean_w=("W", "mean"))
        spread = group_means.groupby(keys).transform(lambda x: x.max() - x.min()).reset_index()
        # Use the actual treatment column name rather than a hard-coded one.
        return spread[spread[treatment_names[0]].eq(True)]

    @staticmethod
    def _adjusted_std_mean_diff(df_long: pd.DataFrame, treatment_names) -> pd.DataFrame:
        """Per common cause, sum over strata of (mean spread / stratum std) * (stratum size / total size)."""
        keys = ["common_cause_id", "strata"]
        mean_diff = PropensityBalanceInterpreter._treated_mean_spread(df_long, treatment_names, keys)
        # Stratum sample std of W is taken over all rows of the stratum (both treatment groups).
        stratum_stats = (
            df_long.groupby(keys)
            .agg(stratum_size=("propensity_score", "size"), stddev=("W", "std"))
            .reset_index()
        )
        total_size = df_long.groupby("common_cause_id").size().rename("total_size").reset_index()
        merged = mean_diff.merge(stratum_stats, on=keys).merge(total_size, on="common_cause_id")
        merged["scaled_mean"] = (merged["mean_w"] / merged["stddev"]) * (
            merged["stratum_size"] / merged["total_size"]
        )
        # "sum" skips NaN terms, e.g. strata with a single row, whose std is NaN.
        return merged.groupby("common_cause_id").agg(std_mean_diff=("scaled_mean", "sum")).reset_index()

    @staticmethod
    def _unadjusted_std_mean_diff(df_long: pd.DataFrame, treatment_names) -> pd.DataFrame:
        """Per common cause, mean spread between treatment groups divided by the overall std of W."""
        mean_diff = PropensityBalanceInterpreter._treated_mean_spread(df_long, treatment_names, ["common_cause_id"])
        stddev = df_long.groupby("common_cause_id").agg(stddev=("W", "std")).reset_index()
        merged = mean_diff.merge(stddev, on="common_cause_id")
        merged["std_mean_diff"] = merged["mean_w"] / merged["stddev"]
        return merged[["common_cause_id", "std_mean_diff"]]

    @staticmethod
    def _plot_balance(plot_df: pd.DataFrame) -> None:
        """Draw one line per common cause from the Unadjusted to the PropensityAdjusted std mean difference."""
        import matplotlib.pyplot as plt

        # matplotlib 3.6 renamed its bundled seaborn styles to "seaborn-v0_8-*". The old names were
        # deprecated and later removed, so prefer the new name and fall back for old matplotlib.
        style = "seaborn-v0_8-white" if "seaborn-v0_8-white" in plt.style.available else "seaborn-white"
        plt.style.use(style)
        _, ax = plt.subplots(1, 1)
        for label, subdf in plot_df.groupby("common_cause_id"):
            subdf.plot(kind="line", x="sample", y="std_mean_diff", ax=ax, label=label)
        plt.legend(title="Common causes")
        plt.ylabel("Standardized mean difference between treatment and control")
        plt.xlabel("")
        plt.xticks(rotation=45)