Source code for pywhy_graphs.classes.timeseries.pag

from copy import deepcopy
from typing import Dict, FrozenSet, Iterator, Mapping

import networkx as nx

from pywhy_graphs.classes.base import AncestralMixin, ConservativeMixin
from pywhy_graphs.typing import Node

from .digraph import StationaryTimeSeriesDiGraph
from .graph import StationaryTimeSeriesGraph
from .mixededge import StationaryTimeSeriesMixedEdgeGraph


class StationaryTimeSeriesPAG(
    StationaryTimeSeriesMixedEdgeGraph, AncestralMixin, ConservativeMixin
):
    """A stationary time-series partial ancestral graph (PAG).

    The graph is a mixed-edge graph made of four internal edge graphs:

    - directed edges, stored in a :class:`StationaryTimeSeriesDiGraph`;
    - circle edges, stored in a :class:`StationaryTimeSeriesDiGraph` built with
      ``check_time_direction=False``;
    - undirected edges, stored in a :class:`StationaryTimeSeriesGraph`;
    - bidirected edges, stored in a :class:`StationaryTimeSeriesGraph`.

    Parameters
    ----------
    incoming_directed_edges : optional
        Initial data for the directed-edge internal graph, by default None.
    incoming_circle_edges : optional
        Initial data for the circle-edge internal graph, by default None.
    incoming_bidirected_edges : optional
        Initial data for the bidirected-edge internal graph, by default None.
    incoming_undirected_edges : optional
        Initial data for the undirected-edge internal graph, by default None.
    circle_edge_name : str
        Name of the circle edge internal graph, by default "circle".
    directed_edge_name : str
        Name of the directed edge internal graph, by default "directed".
    bidirected_edge_name : str
        Name of the bidirected edge internal graph, by default "bidirected".
    undirected_edge_name : str
        Name of the undirected edge internal graph, by default "undirected".
    stationary : bool
        Stored as ``self.stationary`` and forwarded to every internal graph,
        by default True.
    **attr
        Keyword arguments forwarded to the parent constructor and to every
        internal edge graph.

    Notes
    -----
    After construction the graph is passed to
    ``pywhy_graphs.is_valid_mec_graph``. That helper presumably raises on an
    invalid PAG; confirm against its implementation.
    """

    def __init__(
        self,
        incoming_directed_edges=None,
        incoming_circle_edges=None,
        incoming_bidirected_edges=None,
        incoming_undirected_edges=None,
        circle_edge_name: str = "circle",
        directed_edge_name: str = "directed",
        bidirected_edge_name: str = "bidirected",
        undirected_edge_name: str = "undirected",
        stationary: bool = True,
        **attr,
    ):
        self.stationary = stationary
        super().__init__(**attr)
        self.add_edge_type(
            StationaryTimeSeriesDiGraph(incoming_directed_edges, stationary=stationary, **attr),
            directed_edge_name,
        )
        # Circle edges are stored in a directed graph, but without enforcing
        # that edges point forward in time.
        self.add_edge_type(
            StationaryTimeSeriesDiGraph(
                incoming_circle_edges, stationary=stationary, check_time_direction=False, **attr
            ),
            circle_edge_name,
        )
        self.add_edge_type(
            StationaryTimeSeriesGraph(incoming_undirected_edges, stationary=stationary, **attr),
            undirected_edge_name,
        )
        self.add_edge_type(
            StationaryTimeSeriesGraph(incoming_bidirected_edges, stationary=stationary, **attr),
            bidirected_edge_name,
        )

        self._directed_name = directed_edge_name
        self._undirected_name = undirected_edge_name
        self._circle_name = circle_edge_name
        self._bidirected_name = bidirected_edge_name

        # Imported here rather than at module level, presumably to avoid a
        # circular import with the top-level package.
        from pywhy_graphs import is_valid_mec_graph

        # check that construction of PAG was valid
        is_valid_mec_graph(self)

        # extended patterns store unfaithful triples
        # these can be used for conservative structure learning algorithm
        self._unfaithful_triples: Dict[FrozenSet[Node], None] = dict()

    @property
    def undirected_edge_name(self) -> str:
        """Name of the undirected edge internal graph."""
        return self._undirected_name

    @property
    def directed_edge_name(self) -> str:
        """Name of the directed edge internal graph."""
        return self._directed_name

    @property
    def bidirected_edge_name(self) -> str:
        """Name of the bidirected edge internal graph."""
        return self._bidirected_name

    @property
    def circle_edge_name(self) -> str:
        """Name of the circle edge internal graph."""
        return self._circle_name

    @property
    def undirected_edges(self) -> Mapping:
        """``EdgeView`` of the undirected edges."""
        return self.get_graphs(self._undirected_name).edges

    @property
    def bidirected_edges(self) -> Mapping:
        """``EdgeView`` of the bidirected edges."""
        return self.get_graphs(self._bidirected_name).edges

    @property
    def directed_edges(self) -> Mapping:
        """``EdgeView`` of the directed edges."""
        return self.get_graphs(self._directed_name).edges

    @property
    def circle_edges(self) -> Mapping:
        """``EdgeView`` of the circle edges."""
        return self.get_graphs(self.circle_edge_name).edges

    def sub_directed_graph(self) -> nx.DiGraph:
        """Sub-graph of just the directed edges."""
        return self._get_internal_graph(self._directed_name)

    def sub_undirected_graph(self) -> nx.Graph:
        """Sub-graph of just the undirected edges."""
        return self._get_internal_graph(self._undirected_name)

    def sub_bidirected_graph(self) -> nx.Graph:
        """Sub-graph of just the bidirected edges."""
        return self._get_internal_graph(self._bidirected_name)

    def sub_circle_graph(self) -> nx.Graph:
        """Sub-graph of just the circle edges."""
        return self._get_internal_graph(self.circle_edge_name)

    def orient_uncertain_edge(self, u: Node, v: Node) -> None:
        """Replace a circle edge with a directed edge.

        The method requires a circle edge ``(u, v)`` in the circle-edge
        internal graph. It sorts the two endpoints by the second element of
        each node, presumably the time lag. It then removes the circle edge
        between them and adds a directed edge from the endpoint with the
        smaller second element to the one with the larger. The sort is stable,
        so on a tie the directed edge is ``u -> v``.

        Parameters
        ----------
        u : node
            One endpoint of the circle edge; the intended parent when no
            reordering by lag takes place.
        v : node
            The other endpoint of the circle edge.

        Raises
        ------
        RuntimeError
            If there is no circle edge ``(u, v)``.
        """
        if not self.has_edge(u, v, self.circle_edge_name):
            raise RuntimeError(f"There is no circle edge between {u} and {v}.")

        # Orient along time: the node with the smaller lag value becomes the tail.
        # NOTE(review): when this sort swaps u and v, the removal below targets the
        # circle edge stored in the opposite direction from the one checked above;
        # confirm that edge is guaranteed to exist in that case.
        u, v = sorted([u, v], key=lambda x: x[1])  # type: ignore
        self.remove_edge(u, v, self.circle_edge_name)
        self.add_edge(u, v, self._directed_name)  # type: ignore

    def possible_children(self, n: Node) -> Iterator[Node]:
        """Return an iterator over the possible children of node n.

        A neighbor ``nbr`` of ``n`` is yielded unless it is connected to ``n``
        by a directed edge ``nbr -> n``, a bidirected edge, or an undirected
        edge. The neighbors that remain are connected only through a directed
        edge ``n -> nbr`` and/or circle edges. For example, 'n' -> 'x' yields
        'x', while 'n' <-> 'y' does not yield 'y'.

        Parameters
        ----------
        n : node
            A node in the graph.

        Returns
        -------
        children : Iterator
            An iterator of the possible children of node 'n'.
        """
        for nbr in self.neighbors(n):
            if (
                not self.has_edge(nbr, n, self.directed_edge_name)
                and not self.has_edge(nbr, n, self.bidirected_edge_name)
                and not self.has_edge(nbr, n, self.undirected_edge_name)
            ):
                yield nbr

    def possible_parents(self, n: Node) -> Iterator[Node]:
        """Return an iterator over the possible parents of node n.

        A neighbor ``nbr`` of ``n`` is yielded unless it is connected to ``n``
        by a directed edge ``n -> nbr``, a bidirected edge, or an undirected
        edge. The neighbors that remain are connected only through a directed
        edge ``nbr -> n`` and/or circle edges. For example, 'x' -> 'n' yields
        'x', while 'n' <-> 'y' does not yield 'y'.

        Parameters
        ----------
        n : node
            A node in the graph.

        Returns
        -------
        parents : Iterator
            An iterator of the possible parents of node 'n'.
        """
        # A leftover debug ``print`` of the per-neighbor edge checks was removed
        # here; this method no longer writes anything to stdout.
        for nbr in self.neighbors(n):
            if (
                not self.has_edge(n, nbr, self.directed_edge_name)
                and not self.has_edge(nbr, n, self.bidirected_edge_name)
                and not self.has_edge(nbr, n, self.undirected_edge_name)
            ):
                yield nbr

    def to_ts_undirected(self) -> StationaryTimeSeriesGraph:
        """Collapse every edge type into a single undirected time-series graph.

        Returns
        -------
        G : StationaryTimeSeriesGraph
            A new graph, built with default constructor arguments, with the
            following contents:

            - a deep copy of ``self.graph`` (graph-level attributes);
            - a node ``(variable, 0)`` for each entry of ``self.variables``;
            - one undirected edge per adjacency entry across all internal edge
              graphs (directed, circle, undirected and bidirected alike), with
              the edge data deep-copied.
        """
        graph_class = StationaryTimeSeriesGraph

        # deepcopy when not a view
        G = graph_class()
        G.graph.update(deepcopy(self.graph))
        G.add_nodes_from((n, 0) for n in self.variables)
        # ``self.adj`` maps each edge-type name to that internal graph's
        # adjacency; the edge-type name itself is discarded.
        G.add_edges_from(
            (u, v, deepcopy(d))
            for _, edge_adj in self.adj.items()
            for u, nbrs in edge_adj.items()
            for v, d in nbrs.items()
        )
        return G