from copy import deepcopy
from typing import Dict, FrozenSet, Iterator, Mapping
import networkx as nx
from pywhy_graphs.classes.base import AncestralMixin, ConservativeMixin
from pywhy_graphs.typing import Node
from .digraph import StationaryTimeSeriesDiGraph
from .graph import StationaryTimeSeriesGraph
from .mixededge import StationaryTimeSeriesMixedEdgeGraph
class StationaryTimeSeriesPAG(
    StationaryTimeSeriesMixedEdgeGraph, AncestralMixin, ConservativeMixin
):
    """Mixed-edge time-series graph with directed, circle, undirected and bidirected edges.

    Each edge type is stored in its own internal graph, registered under a
    configurable name via ``add_edge_type``:

    - directed edges in a ``StationaryTimeSeriesDiGraph``,
    - circle edges in a ``StationaryTimeSeriesDiGraph`` built with
      ``check_time_direction=False``,
    - undirected edges in a ``StationaryTimeSeriesGraph``,
    - bidirected edges in a ``StationaryTimeSeriesGraph``.

    After the internal graphs are registered, the instance is passed to
    ``pywhy_graphs.is_valid_mec_graph``.

    Parameters
    ----------
    incoming_directed_edges : optional
        Input forwarded as the first argument of the directed-edge
        ``StationaryTimeSeriesDiGraph``. By default None.
    incoming_circle_edges : optional
        Input forwarded as the first argument of the circle-edge
        ``StationaryTimeSeriesDiGraph``. By default None.
    incoming_bidirected_edges : optional
        Input forwarded as the first argument of the bidirected-edge
        ``StationaryTimeSeriesGraph``. By default None.
    incoming_undirected_edges : optional
        Input forwarded as the first argument of the undirected-edge
        ``StationaryTimeSeriesGraph``. By default None.
    circle_edge_name : str
        Name under which the circle-edge graph is registered, by default "circle".
    directed_edge_name : str
        Name under which the directed-edge graph is registered, by default "directed".
    bidirected_edge_name : str
        Name under which the bidirected-edge graph is registered, by default
        "bidirected".
    undirected_edge_name : str
        Name under which the undirected-edge graph is registered, by default
        "undirected".
    stationary : bool
        Stored on the instance and forwarded to every internal graph,
        by default True.
    **attr
        Keyword arguments forwarded to the parent constructor and to every
        internal graph constructor.
    """

    def __init__(
        self,
        incoming_directed_edges=None,
        incoming_circle_edges=None,
        incoming_bidirected_edges=None,
        incoming_undirected_edges=None,
        circle_edge_name: str = "circle",
        directed_edge_name: str = "directed",
        bidirected_edge_name: str = "bidirected",
        undirected_edge_name: str = "undirected",
        stationary: bool = True,
        **attr,
    ):
        # set before the parent constructor runs, as in the original ordering
        self.stationary = stationary
        super().__init__(**attr)

        self.add_edge_type(
            StationaryTimeSeriesDiGraph(incoming_directed_edges, stationary=stationary, **attr),
            directed_edge_name,
        )
        # circle edges live in a directed graph, but without the
        # time-direction check that the directed-edge graph uses by default
        self.add_edge_type(
            StationaryTimeSeriesDiGraph(
                incoming_circle_edges, stationary=stationary, check_time_direction=False, **attr
            ),
            circle_edge_name,
        )
        self.add_edge_type(
            StationaryTimeSeriesGraph(incoming_undirected_edges, stationary=stationary, **attr),
            undirected_edge_name,
        )
        self.add_edge_type(
            StationaryTimeSeriesGraph(incoming_bidirected_edges, stationary=stationary, **attr),
            bidirected_edge_name,
        )

        self._directed_name = directed_edge_name
        self._undirected_name = undirected_edge_name
        self._circle_name = circle_edge_name
        self._bidirected_name = bidirected_edge_name

        # imported here rather than at module level
        # (presumably to avoid a circular import -- confirm)
        from pywhy_graphs import is_valid_mec_graph

        # check that construction of PAG was valid
        is_valid_mec_graph(self)

        # extended patterns store unfaithful triples
        # these can be used for conservative structure learning algorithm
        self._unfaithful_triples: Dict[FrozenSet[Node], None] = dict()

    @property
    def undirected_edge_name(self) -> str:
        """Name of the undirected edge internal graph."""
        return self._undirected_name

    @property
    def directed_edge_name(self) -> str:
        """Name of the directed edge internal graph."""
        return self._directed_name

    @property
    def bidirected_edge_name(self) -> str:
        """Name of the bidirected edge internal graph."""
        return self._bidirected_name

    @property
    def circle_edge_name(self) -> str:
        """Name of the circle edge internal graph."""
        return self._circle_name

    @property
    def undirected_edges(self) -> Mapping:
        """``EdgeView`` of the undirected edges."""
        return self.get_graphs(self._undirected_name).edges

    @property
    def bidirected_edges(self) -> Mapping:
        """``EdgeView`` of the bidirected edges."""
        return self.get_graphs(self._bidirected_name).edges

    @property
    def directed_edges(self) -> Mapping:
        """``EdgeView`` of the directed edges."""
        return self.get_graphs(self._directed_name).edges

    @property
    def circle_edges(self) -> Mapping:
        """``EdgeView`` of the circle edges."""
        return self.get_graphs(self.circle_edge_name).edges

    def sub_directed_graph(self) -> nx.DiGraph:
        """Sub-graph of just the directed edges."""
        return self._get_internal_graph(self._directed_name)

    def sub_undirected_graph(self) -> nx.Graph:
        """Sub-graph of just the undirected edges."""
        return self._get_internal_graph(self._undirected_name)

    def sub_bidirected_graph(self) -> nx.Graph:
        """Sub-graph of just the bidirected edges."""
        return self._get_internal_graph(self._bidirected_name)

    def sub_circle_graph(self) -> nx.Graph:
        """Sub-graph of just the circle edges."""
        return self._get_internal_graph(self.circle_edge_name)

    def orient_uncertain_edge(self, u: Node, v: Node) -> None:
        """Replace a circle edge between two nodes with a directed edge.

        The circle edge ``(u, v)`` must exist. The two nodes are then ordered
        by their second element (``node[1]``; presumably the lag of a
        ``(variable, lag)`` node -- confirm), the circle edge in that order is
        removed, and a directed edge is added from the node with the smaller
        second element to the node with the larger one. When both second
        elements are equal the argument order is kept, because ``sorted`` is
        stable.

        Parameters
        ----------
        u : node
            One endpoint of the circle edge. Must be indexable, since
            ``u[1]`` is used as the sort key.
        v : node
            The other endpoint of the circle edge. Must be indexable, since
            ``v[1]`` is used as the sort key.

        Raises
        ------
        RuntimeError
            If ``has_edge(u, v, circle_edge_name)`` is false for the nodes in
            the order given.
        """
        if not self.has_edge(u, v, self.circle_edge_name):
            raise RuntimeError(f"There is no circle edge between {u} and {v}.")

        # NOTE(review): existence is checked in the caller's order, but the
        # removal below uses the sorted order -- confirm the circle edge is
        # always present in the sorted orientation as well.
        u, v = sorted([u, v], key=lambda x: x[1])  # type: ignore
        self.remove_edge(u, v, self.circle_edge_name)
        self.add_edge(u, v, self._directed_name)  # type: ignore

    def possible_children(self, n: Node) -> Iterator[Node]:
        """Return an iterator over the possible children of node n.

        A neighbor ``nbr`` of 'n' is yielded unless one of the following
        edges is present: a directed edge ``nbr -> n``, a bidirected edge
        between ``nbr`` and 'n', or an undirected edge between ``nbr``
        and 'n'.

        Parameters
        ----------
        n : node
            A node in the graph.

        Returns
        -------
        children : Iterator
            An iterator of the possible children of node 'n'.
        """
        for nbr in self.neighbors(n):
            if (
                not self.has_edge(nbr, n, self.directed_edge_name)
                and not self.has_edge(nbr, n, self.bidirected_edge_name)
                and not self.has_edge(nbr, n, self.undirected_edge_name)
            ):
                yield nbr

    def possible_parents(self, n: Node) -> Iterator[Node]:
        """Return an iterator over the possible parents of node n.

        A neighbor ``nbr`` of 'n' is yielded unless one of the following
        edges is present: a directed edge ``n -> nbr``, a bidirected edge
        between ``nbr`` and 'n', or an undirected edge between ``nbr``
        and 'n'.

        Parameters
        ----------
        n : node
            A node in the graph.

        Returns
        -------
        parents : Iterator
            An iterator of the possible parents of node 'n'.
        """
        for nbr in self.neighbors(n):
            if (
                not self.has_edge(n, nbr, self.directed_edge_name)
                and not self.has_edge(nbr, n, self.bidirected_edge_name)
                and not self.has_edge(nbr, n, self.undirected_edge_name)
            ):
                yield nbr

    def to_ts_undirected(self) -> StationaryTimeSeriesGraph:
        """Build a ``StationaryTimeSeriesGraph`` containing every edge of this graph.

        The graph-level attributes are deep-copied, one node ``(variable, 0)``
        is added for each entry of ``self.variables``, and every ``(u, v)``
        pair found in ``self.adj`` is added as an edge with a deep copy of its
        data. The outer keys of ``self.adj`` (presumably the edge-type
        names -- confirm) are discarded, so all edge types end up in the
        single returned graph.

        Returns
        -------
        G : StationaryTimeSeriesGraph
            The newly constructed graph.
        """
        graph_class = StationaryTimeSeriesGraph

        # deepcopy when not a view
        G = graph_class()
        G.graph.update(deepcopy(self.graph))
        G.add_nodes_from((n, 0) for n in self.variables)
        G.add_edges_from(
            (u, v, deepcopy(d))
            for _, edge_adj in self.adj.items()
            for u, nbrs in edge_adj.items()
            for v, d in nbrs.items()
        )
        return G