{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Causal Discovery example\n", "\n", "The goal of this notebook is to show how causal discovery methods can work with DoWhy. We use discovery methods from the [causal-learn](https://github.com/py-why/causal-learn) library. As we will see, the correctness guarantees of causal discovery methods hold only under appropriate assumptions, so different methods often return different results in practice. These methods can nevertheless be combined with domain knowledge to construct the final causal graph." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import dowhy\n", "from dowhy import CausalModel\n", "\n", "import numpy as np\n", "import pandas as pd\n", "import graphviz\n", "import networkx as nx\n", "\n", "np.set_printoptions(precision=3, suppress=True)\n", "np.random.seed(0)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Utility function\n", "We define utility functions to draw the directed acyclic graph and to convert it to DOT format." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def make_graph(adjacency_matrix, labels=None):\n", "    idx = np.abs(adjacency_matrix) > 0.01\n", "    dirs = np.where(idx)\n", "    d = graphviz.Digraph(engine='dot')\n", "    names = labels if labels else [f'x{i}' for i in range(len(adjacency_matrix))]\n", "    for name in names:\n", "        d.node(name)\n", "    for to, from_, coef in zip(dirs[0], dirs[1], adjacency_matrix[idx]):\n", "        d.edge(names[from_], names[to], label=str(coef))\n", "    return d\n", "\n", "def str_to_dot(string):\n", "    '''\n", "    Converts input string from graphviz library to valid DOT graph format.\n", "    '''\n", "    graph = string.strip().replace('\\n', ';').replace('\\t','')\n", "    graph = graph[:9] + graph[10:-2] + graph[-1] # Removing unnecessary characters from string\n", "    return graph" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Experiments on the Auto-MPG dataset\n", "\n", "In this section, we use a dataset on the technical specifications of cars. The dataset is downloaded from the UCI Machine Learning Repository and contains 9 attributes and 398 instances. We do not know the true causal graph for this dataset and will use causal-learn to discover it. The discovered causal graph is then used to estimate a causal effect.\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. Load the data" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Load a modified version of the Auto MPG data: Quinlan, R. (1993). Auto MPG. UCI Machine Learning Repository. https://doi.org/10.24432/C5859H.\n", "data_mpg = pd.read_csv(\"datasets/auto_mpg.csv\", index_col=0)\n", "\n", "print(data_mpg.shape)\n", "data_mpg.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Causal Discovery with causal-learn\n", "\n", "We use the causal-learn library to perform causal discovery on the Auto-MPG dataset. 
We use three methods for causal discovery here: PC, GES, and LiNGAM. These methods are widely used and run quickly, which makes them well suited to an introduction to the topic. The causal-learn library provides many other well-tested causal discovery methods, which readers are welcome to explore.\n", "\n", "The documentation for each method is linked below:\n", "- PC [[link]](https://causal-learn.readthedocs.io/en/latest/search_methods_index/Constraint-based%20causal%20discovery%20methods/PC.html)\n", "- GES [[link]](https://causal-learn.readthedocs.io/en/latest/search_methods_index/Score-based%20causal%20discovery%20methods/GES.html)\n", "- LiNGAM [[link]](https://causal-learn.readthedocs.io/en/latest/search_methods_index/Causal%20discovery%20methods%20based%20on%20constrained%20functional%20causal%20models/lingam.html#ica-based-lingam)\n", "\n", "More methods can be found in the causal-learn documentation [[link]](https://causal-learn.readthedocs.io/en/latest/)." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We first try the PC algorithm with default parameters." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from causallearn.search.ConstraintBased.PC import pc\n", "\n", "labels = [str(col) for col in data_mpg.columns]\n", "data = data_mpg.to_numpy()\n", "\n", "cg = pc(data)\n", "\n", "# Visualization using pydot\n", "from causallearn.utils.GraphUtils import GraphUtils\n", "import matplotlib.image as mpimg\n", "import matplotlib.pyplot as plt\n", "import io\n", "\n", "pyd = GraphUtils.to_pydot(cg.G, labels=labels)\n", "tmp_png = pyd.create_png(f=\"png\")\n", "fp = io.BytesIO(tmp_png)\n", "img = mpimg.imread(fp, format='png')\n", "plt.axis('off')\n", "plt.imshow(img)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "We now have a causal graph discovered by PC. Let us also try GES to see its result." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from causallearn.search.ScoreBased.GES import ges\n", "\n", "# default parameters\n", "Record = ges(data)\n", "\n", "# Visualization using pydot\n", "from causallearn.utils.GraphUtils import GraphUtils\n", "import matplotlib.image as mpimg\n", "import matplotlib.pyplot as plt\n", "import io\n", "\n", "pyd = GraphUtils.to_pydot(Record['G'], labels=labels)\n", "tmp_png = pyd.create_png(f=\"png\")\n", "fp = io.BytesIO(tmp_png)\n", "img = mpimg.imread(fp, format='png')\n", "plt.axis('off')\n", "plt.imshow(img)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "These two results differ, which is common when applying causal discovery to real-world datasets, since the required assumptions on the data-generating process are hard to verify.\n", "\n", "In addition, the graphs returned by PC and GES are CPDAGs rather than DAGs, so they may contain undirected edges (e.g., the result returned by GES). Causal effect estimation is therefore difficult with these methods, since backdoor, instrumental, or frontdoor variables may be absent. To obtain a DAG, we try LiNGAM on our dataset." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from causallearn.search.FCMBased import lingam\n", "model_lingam = lingam.ICALiNGAM()\n", "model_lingam.fit(data)\n", "\n", "from causallearn.search.FCMBased.lingam.utils import make_dot\n", "make_dot(model_lingam.adjacency_matrix_, labels=labels)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Now we have a DAG and are ready to estimate causal effects based on it." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Estimate causal effects using Linear Regression\n", "\n", "Now let us estimate the causal effect of *mpg* on *weight*." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Obtain valid dot format\n", "graph_dot = make_graph(model_lingam.adjacency_matrix_, labels=labels)\n", "\n", "# Define Causal Model\n", "model = CausalModel(\n", "    data=data_mpg,\n", "    treatment='mpg',\n", "    outcome='weight',\n", "    graph=str_to_dot(graph_dot.source))\n", "\n", "# Identification\n", "identified_estimand = model.identify_effect(proceed_when_unidentifiable=True)\n", "print(identified_estimand)\n", "\n", "# Estimation\n", "estimate = model.estimate_effect(identified_estimand,\n", "                                 method_name=\"backdoor.linear_regression\",\n", "                                 control_value=0,\n", "                                 treatment_value=1,\n", "                                 confidence_intervals=True,\n", "                                 test_significance=True)\n", "print(\"Causal Estimate is \" + str(estimate.value))" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Experiments on the Sachs dataset\n", "\n", "The dataset consists of the simultaneous measurements of 11 phosphorylated proteins and phospholipids derived from thousands of individual primary immune system cells, subjected to both general and specific molecular interventions (Sachs et al., 2005).\n", "\n", "The specifications of the dataset are as follows:\n", "- Number of nodes: 11\n", "- Number of arcs: 17\n", "- Number of parameters: 178\n", "- Average Markov blanket size: 3.09\n", "- Average degree: 3.09\n", "- Maximum in-degree: 3\n", "- Number of instances: 7466" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1. 
Load the data" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from causallearn.utils.Dataset import load_dataset\n", "\n", "data_sachs, labels = load_dataset(\"sachs\")\n", "\n", "print(data_sachs.shape)\n", "print(labels)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Causal Discovery with causal-learn\n", "\n", "We use the three causal discovery methods mentioned above (PC, GES, and LiNGAM) to find the causal graphs." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "First, let us take a look at how PC works." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "graphs = {}\n", "graphs_nx = {}\n", "labels = [f'{col}' for i, col in enumerate(labels)]\n", "data = data_sachs\n", "\n", "from causallearn.search.ConstraintBased.PC import pc\n", "\n", "cg = pc(data)\n", "\n", "# Visualization using pydot\n", "from causallearn.utils.GraphUtils import GraphUtils\n", "import matplotlib.image as mpimg\n", "import matplotlib.pyplot as plt\n", "import io\n", "\n", "pyd = GraphUtils.to_pydot(cg.G, labels=labels)\n", "tmp_png = pyd.create_png(f=\"png\")\n", "fp = io.BytesIO(tmp_png)\n", "img = mpimg.imread(fp, format='png')\n", "plt.axis('off')\n", "plt.imshow(img)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then, let us try GES." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from causallearn.search.ScoreBased.GES import ges\n", "\n", "# default parameters\n", "Record = ges(data)\n", "\n", "# Visualization using pydot\n", "from causallearn.utils.GraphUtils import GraphUtils\n", "import matplotlib.image as mpimg\n", "import matplotlib.pyplot as plt\n", "import io\n", "\n", "pyd = GraphUtils.to_pydot(Record['G'], labels=labels)\n", "tmp_png = pyd.create_png(f=\"png\")\n", "fp = io.BytesIO(tmp_png)\n", "img = mpimg.imread(fp, format='png')\n", "plt.axis('off')\n", "plt.imshow(img)\n", "plt.show()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "And also LiNGAM." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from causallearn.search.FCMBased import lingam\n", "model = lingam.ICALiNGAM()\n", "model.fit(data)\n", "\n", "from causallearn.search.FCMBased.lingam.utils import make_dot\n", "make_dot(model.adjacency_matrix_, labels=labels)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Estimate effects using Linear Regression\n", "\n", "Similarly, let us use the DAG returned by LiNGAM to estimate the causal effect of *PIP2* on *PKC*." 
] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Obtain valid dot format\n", "graph_dot = make_graph(model.adjacency_matrix_, labels=labels)\n", "\n", "data_df = pd.DataFrame(data=data, columns=labels)\n", "\n", "# Define Causal Model\n", "model_est=CausalModel(\n", " data = data_df,\n", " treatment='pip2',\n", " outcome='pkc',\n", " graph=str_to_dot(graph_dot.source))\n", "\n", "# Identification\n", "identified_estimand = model_est.identify_effect(proceed_when_unidentifiable=False)\n", "print(identified_estimand)\n", "\n", "# Estimation\n", "estimate = model_est.estimate_effect(identified_estimand,\n", " method_name=\"backdoor.linear_regression\",\n", " control_value=0,\n", " treatment_value=1,\n", " confidence_intervals=True,\n", " test_significance=True)\n", "print(\"Causal Estimate is \" + str(estimate.value))" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.18" }, "metadata": { "interpreter": { "hash": "31f2aee4e71d21fbe5cf8b01ff0e069b9275f58929596ceb00d14d90e3e16cd6" } }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": false, "sideBar": true, "skip_h1_title": true, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 4 }