import collections
import functools
from abc import abstractmethod
from typing import Iterable, List, Optional, Set, Tuple

from networkx.classes.reportviews import NodeView

from pywhy_graphs.typing import Node

from .admg import ADMG
from .pag import PAG
class AugmentedNodeMixin:
    """Mixin that adds F-node and S-node bookkeeping to a causal graph class.

    The host class is expected to provide ``graph`` (the graph-attribute dict),
    ``nodes``, ``add_node``, ``add_edge`` and ``directed_edge_name``, and to call
    ``_verify_augmentednode_dict`` while constructing the graph.

    Bookkeeping lives in the graph-attribute dict:

    - ``graph["F-nodes"]`` maps each F-node to a dict holding its ``"targets"``
      (intervention set) and ``"domain"``.
    - ``graph["S-nodes"]`` maps each S-node to its ``(source, target)`` pair of
      domain ids.
    """

    graph: dict
    nodes: NodeView
    # Class-level fallback only. Instances get their own set in
    # ``_verify_augmentednode_dict``, and ``add_s_node`` rebinds rather than
    # mutates, so the domains of one graph never leak into another.
    domains: Set[int] = set()

    @abstractmethod
    def add_edge(self, u_of_edge, v_of_edge, edge_type="all", **attr):
        pass

    @abstractmethod
    def add_node(self, u, **attrs):
        pass

    @property
    @abstractmethod
    def directed_edge_name(self) -> str:
        pass

    def _verify_augmentednode_dict(self):
        """Create or validate the F-node and S-node containers in ``self.graph``.

        Missing containers are created empty. This also gives the instance its
        own ``domains`` set.

        Raises
        ------
        RuntimeError
            If ``graph["F-nodes"]`` or ``graph["S-nodes"]`` already exists and is
            not a dict.
        """
        # verify validity of F nodes
        if "F-nodes" not in self.graph:
            # ``functools.partial`` instead of a lambda keeps the nested
            # defaultdict (and therefore the whole graph) picklable.
            self.graph["F-nodes"] = collections.defaultdict(
                functools.partial(collections.defaultdict, set)
            )
        elif not isinstance(self.graph["F-nodes"], dict):
            raise RuntimeError(
                "There is a graph property named F-nodes already that is not of type dict."
            )
        if "S-nodes" not in self.graph:
            self.graph["S-nodes"] = dict()
        elif not isinstance(self.graph["S-nodes"], dict):
            raise RuntimeError(
                "There is a graph property named S-nodes already that is not of type dict."
            )
        # Per-instance copy so graphs never share the class-level ``domains`` set.
        self.domains = set(self.domains)

    def _next_augmented_node_name(self, prefix: str, start: int) -> Tuple[str, int]:
        """Return ``(prefix, idx)`` for the smallest ``idx >= start`` not already a node.

        Starting at the current count keeps the usual naming (``('F', 0)``,
        ``('F', 1)``, ...). Skipping taken names prevents a new augmented node
        from silently reusing the name of an existing one after another
        augmented node was removed.
        """
        idx = start
        while (prefix, idx) in self.nodes:
            idx += 1
        return (prefix, idx)

    def add_f_node(self, intervention_set: Set[Node], require_unique=True, domain=None):
        """Add an F-node to the graph.

        The new F-node is named ``('F', <index>)`` and gets a directed edge into
        every node of ``intervention_set``.

        Parameters
        ----------
        intervention_set : Set[Node]
            A set of regular nodes that already exist in the causal graph. Any
            non-string iterable of nodes is accepted.
        require_unique : bool, optional
            Whether or not to require that the intervention set is unique. If False,
            then the intervention set is added to the graph, even if it is already
            an F-node. The default is True.
        domain : Optional[Set[int]], optional
            The domain of the F-node. If None, then the domain is just set to 1
            (i.e. ``{1}``).

        Raises
        ------
        RuntimeError
            If ``intervention_set`` is a string or not iterable, contains duplicate
            nodes, matches an existing intervention set while ``require_unique``
            is True, or contains a node that is not in the graph.
        """
        if isinstance(intervention_set, str) or not isinstance(intervention_set, Iterable):
            raise RuntimeError("The intervention set nodes must be an iterable set of node(s).")
        if domain is None:
            domain = set([1])

        # Materialize once so one-shot iterables (e.g. generators) can be both
        # counted and converted; a size drop after set conversion means duplicates.
        members = list(intervention_set)
        intervention_set = frozenset(members)  # type: ignore
        if len(intervention_set) != len(members):
            raise RuntimeError("The intervention set must be a set of unique nodes.")

        if require_unique and intervention_set in self.intervention_sets:
            raise RuntimeError(
                f"You cannot add an F-node for {intervention_set} because "
                f"there is already an F-node."
            )

        # check that the F-node intervention set has variables within the graph
        for node in intervention_set:
            if node not in self.nodes:
                raise RuntimeError(
                    f"All intervention sets must be nodes already in the graph. {node} is not."
                )

        # add a new F-node into the graph, skipping names still in use after removals
        f_node_name = self._next_augmented_node_name("F", len(self.f_nodes))
        self.add_node(f_node_name)

        # add edge between the F-node and its intervention set
        for intervened_node in intervention_set:
            self.add_edge(f_node_name, intervened_node, self.directed_edge_name)

        # adding nodes to F-node container occurs last, because of the error checks
        # that occur in adding edges
        self.graph["F-nodes"][f_node_name]["targets"] = intervention_set
        self.graph["F-nodes"][f_node_name]["domain"] = domain

    def add_f_nodes_from(self, intervention_sets: List[Set[Node]]):
        """Add one F-node per intervention set, using ``add_f_node`` defaults."""
        for intervention_set in intervention_sets:
            self.add_f_node(intervention_set)

    def set_f_node(self, f_node, targets: Optional[Set] = None):
        """Record ``targets`` as the intervention targets of ``f_node``.

        ``f_node`` must already be a node in the graph. It is registered in
        ``graph["F-nodes"]`` if it is not there yet. No edges are added.

        Parameters
        ----------
        f_node : Node
            Existing node to mark as an F-node.
        targets : Optional[Set], optional
            Nodes of the graph targeted by ``f_node``. They are stored as a
            frozenset so that ``intervention_sets`` can hash them. ``None`` is
            stored as-is.

        Raises
        ------
        RuntimeError
            If ``f_node`` or any of ``targets`` is not in the graph.
        """
        if f_node not in self.nodes:
            raise RuntimeError(f"{f_node} is not a node in the existing graph.")
        if targets is not None:
            target_set = frozenset(targets)
            if not all(target in self.nodes for target in target_set):
                raise RuntimeError(f"Not all targets {targets} are in the existing graph.")
            targets = target_set
        self.graph["F-nodes"][f_node]["targets"] = targets

    @property
    def augmented_nodes(self):
        """Return list of augmented nodes: F-nodes followed by S-nodes."""
        return self.f_nodes + self.s_nodes

    @property
    def f_nodes(self) -> List[Node]:
        """Return list of F-nodes."""
        return list(self.graph["F-nodes"].keys())

    @property
    def non_augmented_nodes(self):
        """Return set of non augmented-nodes."""
        return set(self.nodes).difference(self.f_nodes).difference(self.s_nodes)

    @property
    def intervention_sets(self):
        """Return set of intervention-sets, one frozenset per F-node.

        An F-node whose targets are ``None`` or were never recorded contributes
        ``None``.
        """
        targets = set()
        for f_node in self.f_nodes:
            # ``.get`` avoids inserting an (unhashable) empty set via the defaultdict.
            target_set = self.graph["F-nodes"][f_node].get("targets")
            targets.add(None if target_set is None else frozenset(target_set))
        return targets

    @property
    def intervened_nodes(self):
        """Return set of nodes targeted by at least one F-node."""
        nodes = set()
        for iset in self.intervention_sets:
            # F-nodes without recorded targets contribute nothing.
            if iset is not None:
                nodes.update(iset)
        return nodes

    @property
    def domain_ids(self) -> List[int]:
        """Return list of distinct domain ids referenced by any S-node."""
        domain_ids = set()
        for src, target in self.graph["S-nodes"].values():
            domain_ids.add(src)
            domain_ids.add(target)
        return list(domain_ids)

    @property
    def s_nodes(self) -> List[Node]:
        """Return list of S-nodes."""
        return list(self.graph["S-nodes"].keys())

    def add_s_node(self, domain_ids: Tuple, node_changes: Optional[Set[Node]] = None):
        """Add an S-node relating two domains to the graph.

        The new S-node is named ``('S', <index>)``. It stores ``domain_ids`` both as
        its ``domain_ids`` node attribute and in ``graph["S-nodes"]``, and it gets
        a directed edge into every node of ``node_changes``.

        Parameters
        ----------
        domain_ids : Tuple
            The ``(source, target)`` pair of domain ids the S-node relates.
        node_changes : Optional[Set[Node]], optional
            Nodes that differ between the two domains. Any non-string iterable is
            accepted. ``None`` (the default) is treated as no changed nodes, and
            the S-node is added without edges.

        Raises
        ------
        RuntimeError
            If ``node_changes`` is a string or not iterable, contains duplicate
            nodes, ``domain_ids`` is not a pair, or an S-node with the same
            ``domain_ids`` already exists.
        """
        if node_changes is None:
            node_changes = frozenset()
        if isinstance(node_changes, str) or not isinstance(node_changes, Iterable):
            raise RuntimeError("The intervention set nodes must be an iterable set of node(s).")

        # check that there are no duplicates and perform set conversion
        members = list(node_changes)
        node_changes = frozenset(members)  # type: ignore
        if len(node_changes) != len(members):
            raise RuntimeError("The set must be a set of unique nodes.")

        # The ``domain_ids`` property unpacks every stored value as a pair.
        if len(domain_ids) != 2:
            raise RuntimeError(f"The domain ids must be a pair of domains, got {domain_ids}.")

        # Compare with the stored pairs. The ``domain_ids`` property is a flat list
        # of individual ids, which a tuple can never be a member of.
        existing_pairs = {tuple(pair) for pair in self.graph["S-nodes"].values()}
        if tuple(domain_ids) in existing_pairs:
            raise RuntimeError(
                f"You cannot add an S-node for domains {domain_ids} because "
                f"there is already an S-node for them."
            )

        # add a new S-node into the graph, skipping names still in use after removals
        s_node_name = self._next_augmented_node_name("S", len(self.s_nodes))
        self.add_node(s_node_name, domain_ids=domain_ids)

        # add edge between the S-node and the nodes that change across domains
        for perturbed_node in node_changes:
            self.add_edge(s_node_name, perturbed_node, self.directed_edge_name)

        # register the S-node and its domains last, so a failure while adding
        # edges leaves no bookkeeping behind
        self.graph["S-nodes"][s_node_name] = domain_ids
        # rebind instead of mutating, so a shared class-level set is never modified
        self.domains = self.domains | set(domain_ids)
class AugmentedGraph(ADMG, AugmentedNodeMixin):
    """An augmented causal diagram.

    An augmented graph is one where interventions are represented by F-nodes.
    See :footcite:`pearl_aspects_1993`, where they were first introduced. They
    allow one to model hard and soft interventions as an explicit "F-node" added
    to the existing causal graph. For more information, see <TBD user guide>.

    Parameters
    ----------
    incoming_directed_edges : input directed edges (optional, default: None)
        Data to initialize directed edges. All arguments that are accepted
        by `networkx.DiGraph` are accepted.
    incoming_bidirected_edges : input bidirected edges (optional, default: None)
        Data to initialize bidirected edges. All arguments that are accepted
        by `networkx.Graph` are accepted.
    incoming_undirected_edges : input undirected edges (optional, default: None)
        Data to initialize undirected edges. All arguments that are accepted
        by `networkx.Graph` are accepted.
    directed_edge_name : str
        The name for the directed edges. By default 'directed'.
    bidirected_edge_name : str
        The name for the bidirected edges. By default 'bidirected'.
    undirected_edge_name : str
        The name for the undirected edges. By default 'undirected'.
    attr : keyword arguments, optional (default= no attributes)
        Attributes to add to graph as key=value pairs.

    See Also
    --------
    networkx.DiGraph
    networkx.Graph
    ADMG
    pywhy_graphs.networkx.MixedEdgeGraph

    Notes
    -----
    **Edge Type Subgraphs**

    Different edge types in an AugmentedGraph are represented exactly as they are in a
    :class:`pywhy_graphs.ADMG`.

    **F-nodes**

    F-nodes are represented in pywhy-graphs as a tuple as ``('F', <index>)``, where ``index``
    is an integer index assigned when the F-node is added. Each F-node is mapped to the
    intervention-set that they are applied on. For example in the graph
    :math:`('F', 0) \\rightarrow X \\rightarrow Y`, ``('F', 0)`` is the F-node added that
    models an intervention on ``X``. Each intervention-set is a set of regular nodes in the
    causal graph.

    S-nodes, added with ``add_s_node``, are likewise represented as ``('S', <index>)``.

    References
    ----------
    .. footbibliography::
    """

    def __init__(
        self,
        incoming_directed_edges=None,
        incoming_bidirected_edges=None,
        incoming_undirected_edges=None,
        directed_edge_name: str = "directed",
        bidirected_edge_name: str = "bidirected",
        undirected_edge_name: str = "undirected",
        **attr,
    ):
        super().__init__(
            incoming_directed_edges,
            incoming_bidirected_edges,
            incoming_undirected_edges,
            directed_edge_name,
            bidirected_edge_name,
            undirected_edge_name,
            **attr,
        )

        # create or validate the F-node / S-node containers in ``self.graph``
        self._verify_augmentednode_dict()

    def remove_node(self, n):
        """Remove node ``n``, also dropping its F-node or S-node bookkeeping.

        S-node entries are cleared as well because ``add_s_node`` is available on
        this class. Otherwise a removed S-node would linger in
        ``graph["S-nodes"]``.
        """
        if n in self.f_nodes:
            del self.graph["F-nodes"][n]
        if n in self.s_nodes:
            del self.graph["S-nodes"][n]
        return super().remove_node(n)
class AugmentedPAG(PAG, AugmentedNodeMixin):
    """An augmented PAG.

    An augmented PAG is a PAG that has been augmented with either F-nodes or
    S-nodes, or both. It is a Markov equivalence class of causal diagrams.

    Parameters
    ----------
    incoming_directed_edges : input directed edges (optional, default: None)
        Data to initialize directed edges. All arguments that are accepted
        by `networkx.DiGraph` are accepted.
    incoming_undirected_edges : input undirected edges (optional, default: None)
        Data to initialize undirected edges. All arguments that are accepted
        by `networkx.Graph` are accepted.
    incoming_bidirected_edges : input bidirected edges (optional, default: None)
        Data to initialize bidirected edges. All arguments that are accepted
        by `networkx.Graph` are accepted.
    incoming_circle_edges : input circular endpoint edges (optional, default: None)
        Data to initialize edges with circle endpoints. All arguments that are accepted
        by `networkx.DiGraph` are accepted.
    directed_edge_name : str
        The name for the directed edges. By default 'directed'.
    undirected_edge_name : str
        The name for the undirected edges. By default 'undirected'.
    bidirected_edge_name : str
        The name for the bidirected edges. By default 'bidirected'.
    circle_edge_name : str
        The name for the circle edges. By default 'circle'.
    attr : keyword arguments, optional (default= no attributes)
        Attributes to add to graph as key=value pairs.

    Notes
    -----
    F-nodes are just nodes that are added to a causal graph, and
    represent an "augmentation" of the original causal graph to handle
    interventions. Each F-node is mapped to a 2-tuple representing the
    index pair of intervention-targets.

    If the intervention targets are unknown, then the 2-tuple contains
    integer indices representing the index of an interventional distribution.
    This is called :math:`\\sigma` in :footcite:`Jaber2020causal`.

    **Edge Type Subgraphs**

    Different edge types in an AugmentedPAG are represented exactly as they are in a
    :class:`pywhy_graphs.PAG`.

    **F-nodes**

    Interventions are represented by special nodes, known as F-nodes. See
    :footcite:`Jaber2020causal`, or :footcite:`Kocaoglu2019characterization` for details.

    F-nodes are represented in pywhy-graphs as a tuple as ``('F', <index>)``, where ``index``
    is an integer index assigned when the F-node is added. Each F-node is mapped to the
    intervention-set that they are applied on. For example in the graph
    :math:`('F', 0) \\rightarrow X \\rightarrow Y`, ``('F', 0)`` is the F-node added that
    models an intervention on ``X``. Each intervention-set is a set of regular nodes in the
    causal graph.

    **S-nodes**

    Different domains and environments are represented by special nodes, known as S-nodes. See
    :footcite:`bareinboim_causal_2016` for details.

    S-nodes are represented in pywhy-graphs as a tuple as ``('S', <index>)``, where ``index``
    is an integer index assigned when the S-node is added. Each S-node is mapped to the pair
    of domain ids it relates, and has a directed edge into every node that differs between
    those two domains. For example in the graph :math:`('S', 0) \\rightarrow X \\rightarrow Y`,
    ``('S', 0)`` is the S-node added that models a change in ``X`` between the two domains.

    References
    ----------
    .. footbibliography::
    """

    def __init__(
        self,
        incoming_directed_edges=None,
        incoming_undirected_edges=None,
        incoming_bidirected_edges=None,
        incoming_circle_edges=None,
        directed_edge_name: str = "directed",
        undirected_edge_name: str = "undirected",
        bidirected_edge_name: str = "bidirected",
        circle_edge_name: str = "circle",
        **attr,
    ):
        super().__init__(
            incoming_directed_edges,
            incoming_undirected_edges,
            incoming_bidirected_edges,
            incoming_circle_edges,
            directed_edge_name,
            undirected_edge_name,
            bidirected_edge_name,
            circle_edge_name,
            **attr,
        )

        # create or validate the F-node / S-node containers in ``self.graph``
        self._verify_augmentednode_dict()

    def remove_node(self, n):
        """Remove node ``n``, also dropping its F-node or S-node bookkeeping."""
        if n in self.f_nodes:
            del self.graph["F-nodes"][n]
        if n in self.s_nodes:
            del self.graph["S-nodes"][n]
        return super().remove_node(n)