{ "cells": [ { "cell_type": "markdown", "id": "b490c539", "metadata": {}, "source": [ "# Falsification of User-Given Directed Acyclic Graphs\n", "This notebook demonstrates a tool for falsifying a user-given DAG using observational data. The main function is `falsify_graph()`, which takes a DAG and data as input and returns an evaluation result.\n", "For more details about this method, please read the related paper:\n", "\n", "> Eulig, E., Mastakouri, A. A., Blöbaum, P., Hardt, M., & Janzing, D. (2023). Toward Falsifying Causal Graphs Using a Permutation-Based Test. https://arxiv.org/abs/2305.09565" ] }, { "cell_type": "code", "execution_count": null, "id": "13a5f7f6", "metadata": {}, "outputs": [], "source": [ "# Import the necessary libraries and functions for this demo\n", "import numpy as np\n", "import pandas as pd\n", "import networkx as nx\n", "from sklearn.ensemble import GradientBoostingRegressor\n", "from dowhy.gcm.falsify import FalsifyConst, falsify_graph, plot_local_insights, run_validations, apply_suggestions\n", "from dowhy.gcm.independence_test.generalised_cov_measure import generalised_cov_based\n", "from dowhy.gcm.util import plot\n", "from dowhy.gcm.util.general import set_random_seed\n", "from dowhy.gcm.ml import SklearnRegressionModel\n", "\n", "# Set random seed\n", "set_random_seed(1332)" ] }, { "cell_type": "markdown", "id": "79cc1ace", "metadata": {}, "source": [ "### Synthetic Data\n", "We will first demonstrate the tool on synthetic data. To this end, we generated a random DAG with 5 nodes (`falsify_g_true.gml`) and some data from a random SCM with nonlinear conditionals (`falsify_data_nonlinear.csv`)." 
] }, { "cell_type": "code", "execution_count": null, "id": "08d67fd1", "metadata": {}, "outputs": [], "source": [ "# Load example graph and data\n", "g_true = nx.read_gml(\"falsify_g_true.gml\")\n", "data = pd.read_csv(\"falsify_data_nonlinear.csv\")\n", "\n", "# Plot true DAG\n", "print(\"True DAG\")\n", "plot(g_true)" ] }, { "cell_type": "markdown", "id": "8f1509b8", "metadata": {}, "source": [ "Let's first evaluate the true DAG on that data (the following cell will take approximately 20s to run)." ] }, { "cell_type": "code", "execution_count": null, "id": "953753d0", "metadata": {}, "outputs": [], "source": [ "result = falsify_graph(g_true, data, plot_histogram=True)\n", "# Summarize the result\n", "print(result)" ] }, { "cell_type": "markdown", "id": "e0c1d8f3", "metadata": {}, "source": [ "As expected, we do not reject the true DAG. Let's understand what `falsify_graph()` does exactly: when we provide a DAG to `falsify_graph()`, we test for violations of the local Markov conditions (LMCs) by running conditional independence (CI) tests. That is, for each node in the graph we test whether\n", "$$ X_i \\perp \\!\\!\\! \\perp_P X_j \\in \\text{NonDesc}_{X_i} \\ | \\ \\text{PA}_{X_i} $$\n", "We then randomly permute the nodes of the given DAG and test for violations of LMCs again. We can do this for either a fixed number of permutations or for all $n!$ permutations, where $n$ is the number of nodes in the given DAG. We can then use the probability of a random node permutation (the null) having as few or fewer violations as the given DAG (the test statistic) as a measure to validate the given DAG (the p-value reported in the upper-right corner of the plot above).\n", "\n", "Similarly, we can run an oracle test for each permutation w.r.t. the given DAG, i.e. if the given DAG were the true DAG, how many violations of LMCs do we expect for some permutation. 
Note that asking about the number of permutations violating zero LMCs is identical to asking how many DAGs lie in the same Markov equivalence class (MEC) as the given DAG. In our method we use the number of permuted DAGs lying in the same MEC as the given DAG (with 0 #Violations of tPA) as a measure of how informative the given DAG is. Only if few permutations lie in the same MEC are the independences entailed by the given DAG 'characteristic' in the sense that the given DAG is falsifiable by testing implied CIs.\n", "\n", "In the plot above we see histograms of the LMC violations of permuted DAGs (blue) and d-separation (oracle, orange) violations of the permuted DAGs. The dashed blue and orange lines indicate the number of violations of LMC (blue) / d-separation (orange) of the given DAG. As expected for the true DAG, both histograms are broadly overlapping (except for statistical errors in the CI tests).\n", "\n", "If we are not interested in the plot and just wish to know whether a given DAG is falsified by our test, we can use the `falsified` attribute of the `EvaluationResult` object returned by `falsify_graph()` instead." ] }, { "cell_type": "code", "execution_count": null, "id": "529de3f5", "metadata": {}, "outputs": [], "source": [ "print(f\"Graph is falsifiable: {result.falsifiable}, Graph is falsified: {result.falsified}\")" ] }, { "cell_type": "markdown", "id": "43929df3", "metadata": {}, "source": [ "Now, let's modify the true DAG to simulate a DAG where a domain expert knows some of the edges but removes a true one and introduces a wrong one." 
] }, { "cell_type": "code", "execution_count": null, "id": "4623dbc3", "metadata": {}, "outputs": [], "source": [ "# Simulate a domain expert with knowledge of some of the edges in the system\n", "g_given = g_true.copy()\n", "g_given.add_edges_from([('X4', 'X1')]) # Add wrong edge from X4 -> X1\n", "g_given.remove_edge('X2', 'X0') # Remove true edge from X2 -> X0\n", "plot(g_given)" ] }, { "cell_type": "code", "execution_count": null, "id": "fcd847fa", "metadata": {}, "outputs": [], "source": [ "# Run evaluation and plot the result using `plot_histogram=True`\n", "result = falsify_graph(g_given, data, plot_histogram=True)\n", "# Summarize the result\n", "print(result)" ] }, { "cell_type": "markdown", "id": "f5b515ed", "metadata": {}, "source": [ "Here, we observe two things. First, the given DAG violates 2 more LMCs than the true DAG. Second, there are many permuted DAGs that violate as many or fewer LMCs than the given DAG. This is reflected in the p-value LMC, which is much higher than before. Based on the default significance level of 0.05 we would therefore reject the given DAG.\n", "\n", "We can gain additional insight by highlighting the nodes for which the violations of LMCs occurred in the given DAG." ] }, { "cell_type": "code", "execution_count": null, "id": "3fe68e25", "metadata": {}, "outputs": [], "source": [ "# Plot nodes for which violations of LMCs occurred\n", "print('Violations of LMCs')\n", "plot_local_insights(g_given, result, method=FalsifyConst.VALIDATE_LMC)" ] }, { "cell_type": "markdown", "id": "ebe3227f", "metadata": {}, "source": [ "### Real Data (Protein Network dataset by Sachs et al., 2005)\n", "Let's try our evaluation method on some real data, the protein network data from Sachs et al., 2005 [1]. This dataset contains quantitative measurements of the expression levels of $n=11$ phosphorylated proteins and phospholipids in the human primary T cell signaling network. 
The $N=7,466$ measurements, corresponding to individual cells, were acquired via intracellular multicolor flow cytometry. The consensus network contains 19 directed edges with no cycles.\n", "\n", "[1] https://www.science.org/doi/abs/10.1126/science.1105809" ] }, { "cell_type": "code", "execution_count": null, "id": "03e8ffa3", "metadata": {}, "outputs": [], "source": [ "# Load the data and consensus DAG\n", "data_url = \"https://raw.githubusercontent.com/FenTechSolutions/CausalDiscoveryToolbox/master/cdt/data/resources/cyto_full_data.csv\"\n", "data_sachs = pd.read_csv(data_url)\n", "g_sachs = nx.read_gml('falsify_sachs.gml')" ] }, { "cell_type": "code", "execution_count": null, "id": "3d30a01b", "metadata": {}, "outputs": [], "source": [ "plot(g_sachs)" ] }, { "cell_type": "markdown", "id": "4dab580c", "metadata": {}, "source": [ "Because of the large number of samples, evaluation using the kernel test above takes too long for this demo. Therefore, we'll use a test based on the generalized covariance measure (GCM) instead. We'll use scikit-learn's gradient boosted decision trees as regressors." ] }, { "cell_type": "code", "execution_count": null, "id": "072544c1", "metadata": {}, "outputs": [], "source": [ "# Define independence test based on the generalised covariance measure with gradient boosted decision trees as models\n", "def create_gradient_boost_regressor(**kwargs) -> SklearnRegressionModel:\n", "    return SklearnRegressionModel(GradientBoostingRegressor(**kwargs))\n", "\n", "def gcm(X, Y, Z=None):\n", "    return generalised_cov_based(X, Y, Z=Z, prediction_model_X=create_gradient_boost_regressor,\n", "                                 prediction_model_Y=create_gradient_boost_regressor)" ] }, { "cell_type": "markdown", "id": "4363cab3", "metadata": {}, "source": [ "It is infeasible (and unnecessary) to run our baseline on all 11! node-permutations of the graph. We therefore set `n_permutations=100` to evaluate using 100 random permutations. 
To use the GCM test defined above, we pass `independence_test=gcm` (unconditional independence testing) and `conditional_independence_test=gcm` (conditional independence testing).\n", "\n", "The following cell will take about 3min to run." ] }, { "cell_type": "code", "execution_count": null, "id": "bbfe658d", "metadata": {}, "outputs": [], "source": [ "# Run evaluation for consensus graph and data.\n", "result_sachs = falsify_graph(g_sachs, data_sachs, n_permutations=100,\n", "                             independence_test=gcm,\n", "                             conditional_independence_test=gcm,\n", "                             plot_histogram=True)\n", "print(result_sachs)" ] }, { "cell_type": "markdown", "id": "14b9c1e4", "metadata": {}, "source": [ "We observe that the consensus DAG is both informative (0/100 permutations lie in the same MEC) and significantly better than random in terms of the CIs it entails. Note that the fraction of LMC violations of the given DAG is much higher than the expected type I error rate of the CI tests for the default significance level `significance_ci=0.05` used here. The naive approach of rejecting a DAG with more than 5% violations of LMC would thus falsely reject this DAG." ] }, { "cell_type": "markdown", "id": "9f647056", "metadata": {}, "source": [ "### Edge Suggestions\n", "Beyond the falsification of a given DAG shown above, we can also run additional tests using `suggestions=True` and report those back to the user. To demonstrate this we will use the synthetic DAG and data from before." ] }, { "cell_type": "code", "execution_count": null, "id": "58650ea8", "metadata": {}, "outputs": [], "source": [ "result = falsify_graph(g_given, data, plot_histogram=True, suggestions=True)\n", "print(result)" ] }, { "cell_type": "markdown", "id": "d9a198d5", "metadata": {}, "source": [ "Compared to the output above we now see the additional row `Suggestions` in the print representation of the evaluation summary. 
We used a test of causal minimality to report suggestions to the user and would correctly suggest the removal of the edge $X4 \\to X1$, which was wrongly added by the domain expert. We can also plot those suggestions using `plot_local_insights`:" ] }, { "cell_type": "code", "execution_count": null, "id": "ab8ef360", "metadata": {}, "outputs": [], "source": [ "# Plot suggestions\n", "plot_local_insights(g_given, result, method=FalsifyConst.VALIDATE_CM)" ] }, { "cell_type": "markdown", "id": "44e451e9", "metadata": {}, "source": [ "We can apply those suggestions using `apply_suggestions`. If there are edges we do not want to remove, we can list them in the additional parameter `edges_to_keep`." ] }, { "cell_type": "code", "execution_count": null, "id": "073a54c4", "metadata": {}, "outputs": [], "source": [ "# Apply all suggestions (we could exclude suggestions via `edges_to_keep=[('X3', 'X4')]`)\n", "g_given_pruned = apply_suggestions(g_given, result)\n", "# Plot pruned DAG\n", "plot(g_given_pruned)" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.18" } }, "nbformat": 4, "nbformat_minor": 5 }