Source code for dowhy.gcm.causal_models

"""This module defines the fundamental classes for graphical causal models (GCMs)."""

from typing import Any, Callable, Optional, Union

import networkx as nx

from dowhy.gcm.causal_mechanisms import (
    ConditionalStochasticModel,
    FunctionalCausalModel,
    InvertibleFunctionalCausalModel,
    StochasticModel,
)
from dowhy.graph import (
    DirectedGraph,
    HasNodes,
    get_ordered_predecessors,
    is_root_node,
    validate_acyclic,
    validate_node_in_graph,
)

# Node-attribute key under which a node's causal mechanism (model) is stored in, and read from, the causal graph.
CAUSAL_MECHANISM: str = "causal_mechanism"

# Node-attribute key under which the parents of a node are persisted while its mechanism is fitted. It is read again
# during validation to detect that the graph structure changed after fitting.
PARENTS_DURING_FIT: str = "parents_during_fit"


class ProbabilisticCausalModel:
    """Represents a probabilistic graphical causal model, i.e. it combines a graphical representation of causal
    relationships and corresponding causal mechanism for each node describing the data generation process. The causal
    mechanisms can be any general stochastic models."""

    def __init__(
        self, graph: Optional[DirectedGraph] = None, graph_copier: Callable[[DirectedGraph], DirectedGraph] = nx.DiGraph
    ):
        """
        :param graph: Optional graph object to be used as causal graph. If ``None``, an empty ``networkx.DiGraph`` is
                      created. A legacy ``dowhy.causal_model.CausalModel`` or ``dowhy.causal_graph.CausalGraph`` is
                      also accepted; its underlying graph is copied via ``graph_copier``. Any other graph object is
                      stored as-is (it is not copied).
        :param graph_copier: Optional function that can copy a causal graph. Defaults to a networkx.DiGraph
                             constructor. It is also used by :meth:`clone` and passed on to the cloned model.
        """
        # Todo: Remove after https://github.com/py-why/dowhy/pull/943.
        # NOTE(review): imported locally, presumably to avoid an import cycle at module load time — confirm.
        from dowhy.causal_graph import CausalGraph
        from dowhy.causal_model import CausalModel

        if graph is None:
            graph = nx.DiGraph()
        elif isinstance(graph, CausalModel):
            graph = graph_copier(graph._graph._graph)
        elif isinstance(graph, CausalGraph):
            graph = graph_copier(graph._graph)

        self.graph = graph
        self.graph_copier = graph_copier

    def set_causal_mechanism(self, node: Any, mechanism: Union[StochasticModel, ConditionalStochasticModel]) -> None:
        """Assigns the generative causal model of node in the causal graph.

        :param node: Target node whose causal model is to be assigned.
        :param mechanism: Causal mechanism to be assigned. A root node must be a
                          :class:`~dowhy.gcm.causal_mechanisms.StochasticModel`, whereas a non-root node must be a
                          :class:`~dowhy.gcm.causal_mechanisms.ConditionalStochasticModel`. The type is not checked
                          here; see :func:`validate_causal_model_assignment`.
        :raises ValueError: If ``node`` is not part of the graph.
        """
        if node not in self.graph.nodes:
            # ``(node,)`` instead of ``node``: a tuple-valued node name (valid in networkx) would otherwise be
            # unpacked by ``%`` and raise a TypeError instead of this ValueError.
            raise ValueError("Node %s can not be found in the given graph!" % (node,))

        self.graph.nodes[node][CAUSAL_MECHANISM] = mechanism

    def causal_mechanism(self, node: Any) -> Union[StochasticModel, ConditionalStochasticModel]:
        """Returns the generative causal model of node in the causal graph.

        :param node: Target node whose causal model is to be returned.
        :returns: The causal mechanism for this node. A root node is of type
                  :class:`~dowhy.gcm.causal_mechanisms.StochasticModel`, whereas a non-root node is of type
                  :class:`~dowhy.gcm.causal_mechanisms.ConditionalStochasticModel`.
        :raises KeyError: If no causal mechanism is assigned to ``node`` (assuming dict-like node attributes, as in
                          networkx).
        """
        return self.graph.nodes[node][CAUSAL_MECHANISM]

    def clone(self):
        """Clones the causal model, but keeps causal mechanisms untrained.

        The graph is copied with ``self.graph_copier`` and every assigned causal mechanism is replaced by the result of
        its ``clone()`` method. The returned model is of the same class and uses the same ``graph_copier``.
        """
        graph_copy = self.graph_copier(self.graph)
        clone_causal_models(self.graph, graph_copy)
        # Forward the copier so that a clone (and clones of that clone) keep the configured copy strategy instead of
        # silently falling back to the ``nx.DiGraph`` default.
        return self.__class__(graph_copy, graph_copier=self.graph_copier)
class StructuralCausalModel(ProbabilisticCausalModel):
    """Represents a structural causal model (SCM), as required e.g. by
    :func:`~dowhy.gcm.whatif.counterfactual_samples`. As compared to a
    :class:`~dowhy.gcm.causal_models.ProbabilisticCausalModel`, an SCM describes the data generation process in
    non-root nodes by functional causal models.
    """

    def set_causal_mechanism(self, node: Any, mechanism: Union[StochasticModel, FunctionalCausalModel]) -> None:
        """Assigns the causal mechanism of ``node``.

        Delegates to :meth:`ProbabilisticCausalModel.set_causal_mechanism`; this override only narrows the documented
        type of non-root mechanisms to :class:`~dowhy.gcm.causal_mechanisms.FunctionalCausalModel`. The type itself
        is not checked at assignment time.

        :param node: Target node whose causal model is to be assigned.
        :param mechanism: A :class:`~dowhy.gcm.causal_mechanisms.StochasticModel` for root nodes, or a
                          :class:`~dowhy.gcm.causal_mechanisms.FunctionalCausalModel` for non-root nodes.
        """
        super().set_causal_mechanism(node, mechanism)

    def causal_mechanism(self, node: Any) -> Union[StochasticModel, FunctionalCausalModel]:
        """Returns the causal mechanism assigned to ``node``.

        Delegates to :meth:`ProbabilisticCausalModel.causal_mechanism`; this override only narrows the documented
        return type.

        :param node: Target node whose causal model is to be returned.
        :returns: The mechanism stored for ``node``.
        """
        return super().causal_mechanism(node)
class InvertibleStructuralCausalModel(StructuralCausalModel):
    """Represents an invertible structural graphical causal model, as required e.g. by
    :func:`~dowhy.gcm.whatif.counterfactual_samples`. This is a subclass of
    :class:`~dowhy.gcm.causal_models.StructuralCausalModel` and has further restrictions on the class of causal
    mechanisms. Here, the mechanisms of non-root nodes need to be invertible with respect to the noise,
    such as :class:`~dowhy.gcm.causal_mechanisms.PostNonlinearModel`.
    """

    def set_causal_mechanism(
        self, target_node: Any, mechanism: Union[StochasticModel, InvertibleFunctionalCausalModel]
    ) -> None:
        """Assigns the causal mechanism of ``target_node``.

        Delegates to :meth:`StructuralCausalModel.set_causal_mechanism`; this override only narrows the documented type
        of non-root mechanisms to :class:`~dowhy.gcm.causal_mechanisms.InvertibleFunctionalCausalModel`.

        :param target_node: Target node whose causal model is to be assigned. Note: this parameter is called ``node``
                            in the base classes. The name is kept for backward compatibility, so pass it positionally
                            when the concrete model class is not known.
        :param mechanism: A :class:`~dowhy.gcm.causal_mechanisms.StochasticModel` for root nodes, or a
                          :class:`~dowhy.gcm.causal_mechanisms.InvertibleFunctionalCausalModel` for non-root nodes.
        """
        super().set_causal_mechanism(target_node, mechanism)

    def causal_mechanism(self, node: Any) -> Union[StochasticModel, InvertibleFunctionalCausalModel]:
        """Returns the causal mechanism assigned to ``node``.

        Delegates to :meth:`StructuralCausalModel.causal_mechanism`; this override only narrows the documented return
        type.

        :param node: Target node whose causal model is to be returned.
        :returns: The mechanism stored for ``node``.
        """
        return super().causal_mechanism(node)
def validate_causal_model_assignment(causal_graph: DirectedGraph, target_node: Any) -> None:
    """Checks that the mechanism assigned to ``target_node`` has the kind its position in the graph requires.

    Root nodes (per :func:`~dowhy.graph.is_root_node`) need a
    :class:`~dowhy.gcm.causal_mechanisms.StochasticModel`; every other node needs a
    :class:`~dowhy.gcm.causal_mechanisms.ConditionalStochasticModel`.

    :param causal_graph: Graph whose node attributes hold the mechanisms.
    :param target_node: Node to check.
    :raises ValueError: If the node has no assigned mechanism (raised by :func:`validate_node_has_causal_model`).
    :raises RuntimeError: If the assigned mechanism has the wrong kind for the node.
    """
    validate_node_has_causal_model(causal_graph, target_node)
    assigned = causal_graph.nodes[target_node][CAUSAL_MECHANISM]

    if not is_root_node(causal_graph, target_node):
        # Nodes with parents must be modelled conditionally on those parents.
        if isinstance(assigned, ConditionalStochasticModel):
            return
        raise RuntimeError(
            "Node %s has parents and, thus, requires a ConditionalStochasticModel, "
            "but a %s was found!" % (target_node, assigned)
        )

    if not isinstance(assigned, StochasticModel):
        raise RuntimeError(
            "Node %s is a root node and, thus, requires a StochasticModel, "
            "but a %s was found!" % (target_node, assigned)
        )
def validate_node_has_causal_model(causal_graph: HasNodes, node: Any) -> None:
    """Checks that ``node`` is part of ``causal_graph`` and has a causal mechanism assigned.

    :param causal_graph: Graph whose node attributes are inspected.
    :param node: Node to check.
    :raises ValueError: If nothing is stored under ``CAUSAL_MECHANISM`` for ``node``. A node that is missing from the
                        graph is reported by :func:`~dowhy.graph.validate_node_in_graph` before this check.
    """
    validate_node_in_graph(causal_graph, node)

    if CAUSAL_MECHANISM not in causal_graph.nodes[node]:
        # ``(node,)`` keeps tuple-valued node names from being unpacked by ``%`` (which would raise a TypeError).
        raise ValueError("Node %s has no assigned causal mechanism!" % (node,))
def validate_causal_dag(causal_graph: DirectedGraph) -> None:
    """Validates ``causal_graph`` as a causal DAG.

    Acyclicity is checked first (:func:`~dowhy.graph.validate_acyclic`); afterwards every node is validated through
    :func:`validate_causal_graph`. Exceptions raised by either check propagate unchanged.
    """
    for check in (validate_acyclic, validate_causal_graph):
        check(causal_graph)
def validate_causal_graph(causal_graph: DirectedGraph) -> None:
    """Runs :func:`validate_node` on every node of ``causal_graph``; the first failing node raises."""
    all_nodes = causal_graph.nodes
    for current_node in all_nodes:
        validate_node(causal_graph, current_node)
def validate_node(causal_graph: DirectedGraph, node: Any) -> None:
    """Validates a single node of the causal graph.

    First checks that the assigned mechanism has the right kind for the node (:func:`validate_causal_model_assignment`),
    then that it was fitted to the node's current parents (:func:`validate_local_structure`).
    """
    checks = (validate_causal_model_assignment, validate_local_structure)
    for check in checks:
        check(causal_graph, node)
def validate_local_structure(causal_graph: DirectedGraph, node: Any) -> None:
    """Checks that the mechanism of ``node`` was fitted to the node's current parents.

    The parents persisted under ``PARENTS_DURING_FIT`` are compared (with ``!=``) against
    :func:`~dowhy.graph.get_ordered_predecessors` of the current graph.

    :param causal_graph: Graph whose node attributes hold the persisted parents.
    :param node: Node to check.
    :raises RuntimeError: If no parents were persisted for ``node`` or they differ from the current parents.
    """
    node_attributes = causal_graph.nodes[node]

    if PARENTS_DURING_FIT not in node_attributes or node_attributes[PARENTS_DURING_FIT] != get_ordered_predecessors(
        causal_graph, node
    ):
        # ``(node,)`` keeps tuple-valued node names from being unpacked by ``%`` (which would raise a TypeError).
        raise RuntimeError(
            "The causal mechanism of node %s is not fitted to the graphical structure! Fit all "
            "causal models in the graph first. If the mechanism is already fitted based on the causal "
            "parents, consider to update the persisted parents for that node manually." % (node,)
        )
def clone_causal_models(source: HasNodes, destination: HasNodes):
    """Copies cloned causal mechanisms from ``source`` into ``destination``.

    For every node of ``destination`` whose counterpart in ``source`` has a mechanism stored under
    ``CAUSAL_MECHANISM``, the result of that mechanism's ``clone()`` is stored on the destination node. Nodes without a
    mechanism are left untouched. Every destination node is looked up in ``source`` (a missing one makes the lookup
    fail).
    """
    source_nodes = source.nodes
    for node_id in destination.nodes:
        source_attributes = source_nodes[node_id]
        if CAUSAL_MECHANISM not in source_attributes:
            continue
        destination.nodes[node_id][CAUSAL_MECHANISM] = source_attributes[CAUSAL_MECHANISM].clone()