Source code for pywhy_graphs.algorithms.pag

import logging
from collections import deque
from itertools import chain, combinations, permutations
from typing import List, Optional, Set, Tuple

import networkx as nx
import numpy as np

from pywhy_graphs import ADMG, CPDAG, PAG, StationaryTimeSeriesPAG
from pywhy_graphs.algorithms.generic import (
    has_adc,
    inducing_path,
    single_source_shortest_mixed_path,
    valid_mag,
)
from pywhy_graphs.typing import Node, TsNode

logger = logging.getLogger()

__all__ = [
    "possible_ancestors",
    "possible_descendants",
    "discriminating_path",
    "pds",
    "pds_path",
    "uncovered_pd_path",
    "pds_t",
    "pds_t_path",
    "is_definite_noncollider",
    "pag_to_mag",
    "check_pag_definition",
    "valid_pag",
    "mag_to_pag",
]


def _possibly_directed(G: PAG, i: Node, j: Node, reverse: bool = False):
    """Check that edge is possibly directed.

    A possibly directed edge is one of the form:
    - ``i -> j``
    - ``i o-> j``
    - ``i o-o j``

    Parameters
    ----------
    G : PAG
        The graph.
    i : Node
        The first node.
    j : Node
        The second node.
    reverse : bool
        Whether to check the reverse direction for valid path. If true,
        will check for ``i *-> j``. If false (default) will check for
        ``i <-* j``.

    Returns
    -------
    valid : bool
        Whether to path from ``... i *-* j ...`` is a valid path.
    """
    if i not in G.neighbors(j):
        return False

    if reverse:
        # i *-> j is invalid
        direct_check = G.has_edge(i, j, G.directed_edge_name)
    else:
        # i <-* j is invalid
        direct_check = G.has_edge(j, i, G.directed_edge_name)

    # the direct check checks for i *-> j or i <-* j
    # i <-> j is also checked
    # everything else is valid; i.e. i -- j, or i o-o j
    if direct_check or G.has_edge(i, j, G.bidirected_edge_name):
        return False
    return True


[docs] def possible_ancestors(G: PAG, source: Node) -> Set[Node]: """Possible ancestors of a source node. Parameters ---------- G : PAG The graph. source : Node The source node to start at. Returns ------- possible_ancestors : Set[Node] The set of nodes that are possible ancestors. """ valid_path = lambda *args: _possibly_directed(*args, reverse=True) # type: ignore # perform BFS starting at source using neighbors paths = single_source_shortest_mixed_path(G, source, valid_path=valid_path) return set(paths.keys())
[docs] def possible_descendants(G: PAG, source: Node) -> Set[Node]: """Possible descendants of a source node. Parameters ---------- G : PAG The graph. source : Node The source node to start at. Returns ------- possible_descendants : Set[Node] The set of nodes that are possible descendants. """ valid_path = lambda *args: _possibly_directed(*args, reverse=False) # type: ignore # perform BFS starting at source using neighbors paths = single_source_shortest_mixed_path(G, source, valid_path=valid_path) return set(paths.keys())
def is_definite_collider(G: PAG, node1: Node, node2: Node, node3: Node) -> bool: """Check if <node1, node2, node3> path forms a definite collider. I.e. node1 *-> node2 <-* node3. Parameters ---------- node1 : node A node on the path to check. node2 : node A node on the path to check. node3 : node A node on the path to check. Returns ------- is_collider : bool Whether or not the path is a definite collider. """ # check arrow from node1 into node2 condition_one = G.has_edge(node1, node2, G.directed_edge_name) or G.has_edge( node1, node2, G.bidirected_edge_name ) # check arrow from node2 into node1 condition_two = G.has_edge(node3, node2, G.directed_edge_name) or G.has_edge( node3, node2, G.bidirected_edge_name ) return condition_one and condition_two
[docs] def is_definite_noncollider(G: PAG, node1: Node, node2: Node, node3: Node) -> bool: """Check if <node1, node2, node3> path forms a definite non-collider. Definite noncolliders have the form: - node1 *-* node2 -> node3, or - node1 <- node2 *-* node3, or - node1 *-o node2 o-* node3 with node1 and node3 non-adjacent Parameters ---------- node1 : node A node on the path to check. node2 : node A node on the path to check. node3 : node A node on the path to check. Returns ------- is_noncollider : bool Whether or not the path is a definite non-collider. If it is not a definite non-collider, then it may be a definite collider, or uncertain. """ if G.has_edge(node1, node2, G.directed_edge_name) or G.has_edge( node1, node2, G.bidirected_edge_name ): # node1 *-> node2 *-* node3 # or node1 *-* node2 <-* node3 if G.has_edge(node3, node2, G.directed_edge_name) or G.has_edge( node3, node2, G.bidirected_edge_name ): return False elif G.has_edge(node1, node2, G.circle_edge_name) and G.has_edge( node3, node2, G.circle_edge_name ): # node1 *-o node2 o-* node3 if G.has_edge(node1, node3, "any") or G.has_edge(node3, node1, "any"): return False return True
[docs] def discriminating_path( graph: PAG, u: Node, a: Node, c: Node, max_path_length: Optional[int] = None ) -> Tuple[bool, List[Node], Set[Node]]: """Find the discriminating path for <..., a, u, c>. A discriminating path, p = <v, ..., a, u, c>, is one where: - p has at least 3 edges - u is non-endpoint and u is adjacent to c - v is not adjacent to c - every vertex between v and u is a collider on p and parent of c Parameters ---------- graph : PAG PAG to orient. u : node A node in the graph. a : node A node in the graph. c : node A node in the graph. max_path_length : optional, int The maximum distance to check in the graph. By default None, which sets it to 1000. Returns ------- explored_nodes : set A set of explored nodes. disc_path : list The discriminating path starting from node c. found_discriminating_path : bool Whether or not a discriminating path was found. """ if max_path_length is None: max_path_length = 1000 explored_nodes: Set[Node] = set() found_discriminating_path = False disc_path: List[Node] = [] # parents of c form the discriminating path cparents = graph.parents(c) # keep track of the distance searched distance = 0 # keep track of the previous nodes, i.e. to build a path # from node (key) to its child along the path (value) descendant_nodes = dict() descendant_nodes[u] = c descendant_nodes[a] = u # keep track of paths of certain nodes that were already explored # start off with the valid triple <a, u, c> # - u is adjacent to c # - u has an arrow pointing to a # - TBD a is a definite collider # - TBD endpoint is not adjacent to c explored_nodes.add(c) explored_nodes.add(u) explored_nodes.add(a) # a must be a parent of c if not graph.has_edge(a, c, graph.directed_edge_name): return found_discriminating_path, disc_path, explored_nodes # a and u must be connected by a bidirected edge, or with an edge towards a # for a to be a definite collider if not graph.has_edge(a, u, graph.bidirected_edge_name) and not graph.has_edge( u, a, graph.directed_edge_name ): return found_discriminating_path, disc_path, explored_nodes # now add 'a' to the queue and begin exploring # adjacent nodes that are connected with bidirected edges path = deque([a]) while len(path) != 0: this_node = path.popleft() # check distance criterion to prevent checking very long paths distance += 1 if distance > 0 and distance > max_path_length: logger.warning( f"Did not finish checking discriminating path in {graph} because the path " f"length exceeded {max_path_length}." ) return found_discriminating_path, disc_path, explored_nodes # now we check all neighbors to this_node that are pointing to it # either with a direct edge, or a bidirected edge node_iterator = chain(graph.possible_parents(this_node), graph.parents(this_node)) node_iterator = chain(node_iterator, graph.sub_bidirected_graph().neighbors(this_node)) # 'next_node' is either a parent, possible parent, or in a bidrected # edge with 'this_node'. # 'this_node' is a definite collider since there was # confirmed an arrow pointing towards 'this_node' # and 'next_node' is connected to it via a bidirected arrow. for next_node in node_iterator: # if we have already explored this neighbor, then it is # already along the potentially discriminating path if next_node in explored_nodes: continue # keep track of explored_nodes explored_nodes.add(next_node) # Check if 'next_node' is now the end of the discriminating path. # Note we now have 3 edges in the path by construction. if c not in graph.neighbors(next_node) and next_node != c: logger.info(f"Reached the end of the discriminating path with {next_node}.") explored_nodes.add(next_node) descendant_nodes[next_node] = this_node found_discriminating_path = True break # If we didn't reach the end of the discriminating path, # then we can add 'next_node' to the path. This only occurs # if 'next_node' is a valid new entry, which requires it # to be a part of the parents of 'c'. if next_node in cparents and graph.has_edge( this_node, next_node, graph.bidirected_edge_name ): # check that the next_node is a possible collider with at least # this_node -> next_node # since it is a parent, we can now add it to the path queue path.append(next_node) descendant_nodes[next_node] = this_node explored_nodes.add(next_node) # return the actual discriminating path if found_discriminating_path: disc_path = deque([]) # type: ignore disc_path.append(next_node) while disc_path[-1] != c: disc_path.append(descendant_nodes[disc_path[-1]]) return found_discriminating_path, disc_path, explored_nodes
[docs] def uncovered_pd_path( graph: PAG, u: Node, c: Node, max_path_length: Optional[int] = None, first_node: Optional[Node] = None, second_node: Optional[Node] = None, force_circle: bool = False, forbid_node: Optional[Node] = None, ) -> Tuple[List[Node], bool]: """Compute uncovered potentially directed (pd) paths from u to c. In a pd path, the edge between V(i) and V(i+1) is not an arrowhead into V(i) or a tail from V(i+1). An intuitive explanation given in :footcite:`Zhang2008` notes that a pd path could be oriented into a directed path by changing circles into tails or arrowheads. In addition, the path is uncovered, meaning every node beside the endpoints are unshielded, meaning V(i-1) and V(i+1) are not adjacent. A special case of a uncovered pd path is an uncovered circle path, which appears as u o-o ... o-o c. Parameters ---------- graph : PAG PAG to orient. u : node A node in the graph to start the uncovered path. c : node A node in the graph. max_path_length : optional, int The maximum distance to check in the graph. By default None, which sets it to 1000. first_node : node, optional The node previous to 'u'. If it is before 'u', then we will check that 'u' is unshielded. If it is not passed, then 'u' is considered the first node in the path and hence does not need to be unshielded. Both 'first_node' and 'second_node' cannot be passed. second_node : node, optional The node after 'u' that the path must traverse. Both 'first_node' and 'second_node' cannot be passed. force_circle: bool Whether to search for only circle paths (u o-o ... o-o c) or all potentially directed paths. By default False, which searches for all potentially directed paths. forbid_node: node, optional A node after 'u' which is forbidden to immediately traverse when searching for a path. Notes ----- The definition of an uncovered pd path is taken from :footcite:`Zhang2008`. Typically uncovered potentially directed paths are defined by two nodes. However, in one use case within the FCI algorithm, it is defined relative to an adjacent third node that comes before 'u'. In certain cases (e.g. R5 of FCI) an uncovered pd path must be found between two variables, but these variables are already adjacent and connected by a trivial uncovered pd path. To prevent the function from returning this trivial path, the 'forbid_node' argument can be used. References ---------- .. footbibliography:: """ if first_node is not None and second_node is not None: raise RuntimeError( "Both first and second node cannot be set. Only set one of them. " "Read the docstring for more information." ) if ( any(node not in graph for node in (u, c)) or (first_node is not None and first_node not in graph) or (second_node is not None and second_node not in graph) ): raise RuntimeError("Some nodes are not in graph... Double check function arguments.") if max_path_length is None: max_path_length = 1000 explored_nodes: Set[Node] = set() found_uncovered_pd_path = False uncov_pd_path: List[Node] = [] # keep track of the distance searched distance = 0 start_node = u # keep track of the previous nodes, i.e. to build a path # from node (key) to its child along the path (value) descendant_nodes = dict() if first_node is not None: descendant_nodes[u] = first_node if second_node is not None: descendant_nodes[second_node] = u # keep track of paths of certain nodes that were already explored # start off with the valid triple <a, u, c> # - u is adjacent to c # - u has an arrow pointing to a # - TBD a is a definite collider # - TBD endpoint is not adjacent to c explored_nodes.add(u) if first_node is not None: explored_nodes.add(first_node) if second_node is not None: explored_nodes.add(second_node) # we now want to start on the second_node start_node = second_node # now add 'a' to the queue and begin exploring # adjacent nodes that are connected with bidirected edges path = deque([start_node]) while len(path) != 0: this_node = path.popleft() prev_node = descendant_nodes.get(this_node) # check distance criterion to prevent checking very long paths distance += 1 if distance > 0 and distance > max_path_length: logger.warning( f"Did not finish checking discriminating path in {graph} because the path " f"length exceeded {max_path_length}." ) return uncov_pd_path, found_uncovered_pd_path # get all adjacent nodes to 'this_node' for next_node in graph.neighbors(this_node): # check that this is the starting node and whether or not we are on a forbidden path if this_node == start_node and forbid_node is not None and next_node == forbid_node: continue # if we have already explored this neighbor, then ignore if next_node in explored_nodes: continue # now check that the next_node is uncovered by comparing # with the previous node, because the triple is shielded if prev_node is not None and next_node in graph.neighbors(prev_node): continue # now check that the triple is potentially directed, else # we skip this node condition = graph.has_edge(this_node, next_node, graph.circle_edge_name) if not force_circle: # If we do not restrict to circle paths then directed edges are also OK condition = condition or graph.has_edge( this_node, next_node, graph.directed_edge_name ) if not condition: continue # now this next node is potentially directed, does not # form a shielded triple, so we add it to the path explored_nodes.add(next_node) descendant_nodes[next_node] = this_node # if we have reached our end node, then we have found an # uncovered possibly-directed path if next_node == c: logger.info(f"Reached the end of the uncovered pd path with {next_node}.") found_uncovered_pd_path = True break path.append(next_node) # return the actual uncovered pd path if first_node is None: first_node = u if found_uncovered_pd_path: uncov_pd_path_: deque = deque([]) uncov_pd_path_.appendleft(c) while uncov_pd_path_[0] != first_node: uncov_pd_path_.appendleft(descendant_nodes[uncov_pd_path_[0]]) uncov_pd_path = list(uncov_pd_path_) return uncov_pd_path, found_uncovered_pd_path
[docs] def pds( graph: PAG, node_x: Node, node_y: Optional[Node] = None, max_path_length: Optional[int] = None ) -> Set[Node]: """Find all PDS sets between node_x and node_y. Parameters ---------- graph : PAG The graph. node_x : node The node 'x'. node_y : node The node 'y'. max_path_length : optional, int The maximum length of a path to search on. By default None, which sets it to 1000. Returns ------- dsep : set The possibly d-separating set between node_x and node_y. Notes ----- Possibly d-separating (PDS) sets are nodes V, along an adjacency paths from 'node_x' to some 'V', which has the following characteristics for every subpath triple <X, Y, Z> on the path: - Y is a collider, or - Y is a triangle (i.e. X, Y and Z form a complete subgraph) If the path meets these characteristics, then 'V' is in the PDS set. If Y is a triangle, then it will be uncertain with circular edges due to the fact that it is a shielded triple, not allowing us to infer that it is a collider. These are defined in :footcite:`Colombo2012` and :footcite:`Spirtes1993`. References ---------- .. footbibliography:: """ if max_path_length is None: max_path_length = 1000 distance = 0 edge = None # possibly d-sep set dsep: Set[Node] = set() # a queue to q: deque = deque() seen_edges = set() node_list: Optional[List[Node]] = [] # keep track of previous nodes along the path for every node # along a path previous = {node_x: None} # get the adjacency graph to perform path searches over adj_graph = graph.to_undirected() if node_y is not None: # edge case: check that there exists paths between node_x # and node_y if not nx.has_path(adj_graph, node_x, node_y): return dsep # get a list of all neighbors of node_x that is not y # and add these as candidates to explore a path # and also add them to the possibly d-separating set for node_v in graph.neighbors(node_x): # ngbhr cannot be endpoint if node_v == node_y: continue if node_y is not None: # used for RFCI # check that node_b is connected to the endpoint if # the endpoint is passed if not nx.has_path(adj_graph, node_v, node_y): continue # form edge as a tuple edge = (node_x, node_v) # this path from node_x - node_v is a candidate path # that will have a possibly d-separating set q.append(edge) # keep track of the edes seen_edges.add(edge) # all immediately adjacent nodes are part of the pdsep set dsep.add(node_v) while len(q) != 0: this_edge = q.popleft() prev_node, this_node = this_edge # if we get the previous edge, then increment the distance # and if this_edge == edge: edge = None distance += 1 if distance > 0 and distance > max_path_length: break if node_y is not None: # check that node_b is connected to the endpoint if # the endpoint is passed if not nx.has_path(adj_graph, this_node, node_y): continue # now add this_node to the pdsep set, since we have # reached this node dsep.add(this_node) # now we want to check the subpath that is created # using the previous node, the current node and the next node for next_node in graph.neighbors(this_node): # check if 'node_c' in (prev_node, X, Y) if next_node in (prev_node, node_x, node_y): continue # get the previous nodes, and add the previous node # for this next node node_list = previous.get(next_node) if node_list is None: node_list = [] node_list.append(this_node) # check that we have a definite collider # check the edge: prev_node - this_node # check the edge: this_node - next_node is_def_collider = is_definite_collider(graph, prev_node, this_node, next_node) # check that there is a triangle, meaning # prev_node is adjacent to next_node is_triangle = next_node in graph.neighbors(prev_node) # if we have a collider, or triangle, then this edge # is a candidate on a pdsep path if is_def_collider or is_triangle: next_edge = (prev_node, next_node) if next_edge in seen_edges: continue seen_edges.add(next_edge) q.append(next_edge) if edge is None: edge = next_edge return dsep
[docs] def pds_path( graph: PAG, node_x: Node, node_y: Node, max_path_length: Optional[int] = None ) -> Set[Node]: """Compute the possibly-d-separating set path. Returns the PDS_path set defined in definition 3.4 of :footcite:`Colombo2012`. Parameters ---------- graph : PAG The graph. node_x : node The starting node. node_y : node The ending node max_path_length : int, optional The maximum length of a path to search on for PDS set, by default None, which sets it to 1000. Returns ------- pds_path : set The set of nodes in the possibly d-separating path set. Notes ----- This is a smaller subset compared to possibly-d-separating sets. It takes the PDS set and intersects it with the biconnected components of the adjacency graph that contains the edge (node_x, node_y). The current implementation calls `pds` and then restricts the nodes that it returns. """ # get the adjacency graph to perform path searches over adj_graph = graph.to_undirected() # compute all biconnected componnets biconn_comp = nx.biconnected_component_edges(adj_graph) # compute the PDS set pds_set = pds(graph, node_x=node_x, node_y=node_y, max_path_length=max_path_length) # now we intersect the connected component that has the edge found_component: Set = set() for comp in biconn_comp: if (node_x, node_y) in comp or (node_y, node_x) in comp: # add all unique nodes in the biconnected component for x, y in comp: found_component.add(x) found_component.add(y) break # now intersect the pds set with the biconnected component with the edge between # 'x' and 'y' pds_path = pds_set.intersection(found_component) return pds_path
[docs] def pds_t( graph: StationaryTimeSeriesPAG, node_x: TsNode, node_y: TsNode, max_path_length: Optional[int] = None, ) -> Set: """Compute the possibly-d-separating set over time. Returns the 'pdst' set defined in :footcite:`Malinsky18a_svarfci`. Parameters ---------- graph : StationaryTimeSeriesPAG The graph. node_x : node The starting node. node_y : node The ending node max_path_length : int, optional The maximum length of a path to search on for PDS set, by default None, which sets it to 1000. Returns ------- pds_t_set : set The set of nodes in the possibly d-separating path set. Notes ----- This is a smaller subset compared to possibly-d-separating sets. This consists of nodes, 'x', in the PDS set of (node_x, node_y), with the time-lag of 'x' being less than the max time-lag among node_x and and node_y. The current implementation calls `pds` and then restricts the nodes that it returns. """ _check_ts_node(node_x) _check_ts_node(node_y) _, x_lag = node_x _, y_lag = node_y max_lag = max(np.abs(x_lag), np.abs(y_lag)) # compute the PDS set pds_set = pds( graph, node_x=node_x, node_y=node_y, max_path_length=max_path_length ) # type: ignore pds_t_set = set() # only keep nodes with max-lag less than or equal to max(x_lag, y_lag) for node in pds_set: if np.abs(node[1]) <= max_lag: # type: ignore pds_t_set.add(node) return pds_t_set
[docs] def pds_t_path( graph: StationaryTimeSeriesPAG, node_x: TsNode, node_y: TsNode, max_path_length: Optional[int] = None, ) -> Set: """Compute the possibly-d-separating path set over time. Returns the 'pdst_path' set defined in :footcite:`Malinsky18a_svarfci` with the additional restriction that any nodes must be on a path between the two endpoints. Parameters ---------- graph : StationaryTimeSeriesPAG The graph. node_x : node The starting node. node_y : node The ending node max_path_length : int, optional The maximum length of a path to search on for PDS set, by default None, which sets it to 1000. Returns ------- pds_t_set : set The set of nodes in the possibly d-separating path set. Notes ----- This is a smaller subset compared to possibly-d-separating sets. This consists of nodes, 'x', in the PDS set of (node_x, node_y), with the time-lag of 'x' being less than the max time-lag among node_x and and node_y. The current implementation calls `pds` and then restricts the nodes that it returns. """ _check_ts_node(node_x) _check_ts_node(node_y) _, x_lag = node_x _, y_lag = node_y max_lag = max(np.abs(x_lag), np.abs(y_lag)) # compute the PDS set pds_set = pds_path( graph, node_x=node_x, node_y=node_y, max_path_length=max_path_length ) # type: ignore pds_t_set = set() # only keep nodes with max-lag less than or equal to max(x_lag, y_lag) for node in pds_set: if np.abs(node[1]) <= max_lag: # type: ignore pds_t_set.add(node) return pds_t_set
def definite_m_separated( G, x, y, z, bidirected_edge_name="bidirected", directed_edge_name="directed", circle_edge_name="circle", ): """Check definite m-separation among 'x' and 'y' given 'z' in partial ancestral graph G. A partial ancestral graph (PAG) is defined with directed edges (``->``), bidirected edges (``<->``), and circle-endpoint edges (``o-*``, where the ``*`` for example can mean an arrowhead from a directed edge). This algorithm implements the definite m-separation check, which checks for the absence of possibly m-connecting paths between 'x' and 'y' given 'z'. This algorithm first obtains the ancestral subgraph of x | y | z which only requires knowledge of the directed edges. Then, all outgoing directed edges from nodes in z are deleted. After that, an undirected graph composed from the directed and bidirected edges amongst the remaining nodes is created. Then, x is independent of y given z if x is disconnected from y in this new graph. Parameters ---------- G : mixed-edge-graph Mixed edge causal graph. x : set First set of nodes in ``G``. y : set Second set of nodes in ``G``. z : set Set of conditioning nodes in ``G``. Can be empty set. Returns ------- b : bool A boolean that is true if ``x`` is definite m-separated from ``y`` given ``z`` in ``G``. References ---------- .. footbibliography:: See Also -------- d_separated m_separated PAG Notes ----- There is no known optimal algorithm for checking definite m-separation to our knowledge, so the algorithm proceeds by enumerating paths between 'x' and 'y'. This first checks the subgraph comprised of only circle edges. If there is a path """ if not isinstance(G, PAG): raise ValueError("Definite m-separated is only defined for a PAG.") # this proceeds by first removing unnecessary nodes def _check_ts_node(node): if not isinstance(node, tuple) or len(node) != 2: raise ValueError( f"All nodes in time series DAG must be a 2-tuple of the form (<node>, <lag>). " f"You passed in {node}." ) if node[1] > 0: raise ValueError(f"All lag points should be 0, or less. You passed in {node}.") def _apply_meek_rules(graph: CPDAG) -> None: """Orient edges in a skeleton graph to estimate the causal DAG, or CPDAG. These are known as the Meek rules :footcite:`Meek1995`. They are deterministic in the sense that they are logical characterizations of what edges must be present given the rest of the local graph structure. Parameters ---------- graph : CPDAG A graph containing directed and undirected edges. """ # For all the combination of nodes i and j, apply the following # rules. completed = False while not completed: # type: ignore change_flag = False for i in graph.nodes: for j in graph.neighbors(i): if i == j: continue # Rule 1: Orient i-j into i->j whenever there is an arrow k->i # such that k and j are nonadjacent. r1_add = _meek_rule1(graph, i, j) # Rule 2: Orient i-j into i->j whenever there is a chain # i->k->j. r2_add = _meek_rule2(graph, i, j) # Rule 3: Orient i-j into i->j whenever there are two chains # i-k->j and i-l->j such that k and l are nonadjacent. r3_add = _meek_rule3(graph, i, j) # Rule 4: Orient i-j into i->j whenever there are two chains # i-k->l and k->l->j such that k and j are nonadjacent. # r4_add = _meek_rule4(graph, i, j) if any([r1_add, r2_add, r3_add, r4_add]) and not change_flag: change_flag = True if not change_flag: completed = True break def _meek_rule1(graph: CPDAG, i: str, j: str) -> bool: """Apply rule 1 of Meek's rules. Looks for i - j such that k -> i, such that (k,i,j) is an unshielded triple. Then can orient i - j as i -> j. """ added_arrows = False # Check if i-j. if graph.has_edge(i, j, graph.undirected_edge_name): for k in graph.predecessors(i): # Skip if k and j are adjacent because then it is a # shielded triple if j in graph.neighbors(k): continue # check if the triple is in the graph's excluded triples if frozenset((k, i, j)) in graph.excluded_triples: continue # Make i-j into i->j graph.orient_uncertain_edge(i, j) added_arrows = True break return added_arrows def _meek_rule2(graph: CPDAG, i: str, j: str) -> bool: """Apply rule 2 of Meek's rules. Check for i - j, and then looks for i -> k -> j triple, to orient i - j as i -> j. """ added_arrows = False # Check if i-j. if graph.has_edge(i, j, graph.undirected_edge_name): # Find nodes k where k is i->k child_i = set() for k in graph.successors(i): if not graph.has_edge(k, i, graph.directed_edge_name): child_i.add(k) # Find nodes j where j is k->j. parent_j = set() for k in graph.predecessors(j): if not graph.has_edge(j, k, graph.directed_edge_name): parent_j.add(k) # Check if there is any node k where i->k->j. candidate_k = child_i.intersection(parent_j) # if the graph has excluded triples, we would check at this point if graph.excluded_triples: # check if the triple is in the graph's excluded triples # if so, remove them from the candidates for k in candidate_k: if frozenset((i, k, j)) in graph.excluded_triples: candidate_k.remove(k) # if there are candidate 'k' nodes, then orient the edge accordingly if len(candidate_k) > 0: # Make i-j into i->j graph.orient_uncertain_edge(i, j) added_arrows = True return added_arrows def _meek_rule3(graph: CPDAG, i: str, j: str) -> bool: """Apply rule 3 of Meek's rules. Check for i - j, and then looks for k -> j <- l collider, and i - k and i - l, then orient i -> j. """ added_arrows = False # Check if i-j first if graph.has_edge(i, j, graph.undirected_edge_name): # For all the pairs of nodes adjacent to i, # look for (k, l), such that j -> l and k -> l for k, l_node in combinations(graph.neighbors(i), 2): # Skip if k and l are adjacent. if l_node in graph.neighbors(k): continue # Skip if not k->j. if graph.has_edge(j, k, graph.directed_edge_name) or ( not graph.has_edge(k, j, graph.directed_edge_name) ): continue # Skip if not l->j. if graph.has_edge(j, l_node, graph.directed_edge_name) or ( not graph.has_edge(l_node, j, graph.directed_edge_name) ): continue # check if the triple is inside graph's excluded triples if frozenset((l_node, i, k)) in graph.excluded_triples: continue # if i - k and i - l, then at this point, we have a valid path # to orient if graph.has_edge(k, i, graph.undirected_edge_name) and graph.has_edge( l_node, i, graph.undirected_edge_name ): graph.orient_uncertain_edge(i, j) added_arrows = True break return added_arrows def _meek_rule4(graph: CPDAG, i: str, j: str) -> bool: """Apply rule 4 of Meek's rules. Check for i - j, and then looks for i - k -> l -> j, to orient i - j as i -> j. """ added_arrows = False # Check if i-j. if graph.has_edge(i, j, graph.undirected_edge_name): # Find nodes k where k is i-k adj_i = set() for k in graph.neighbors(i): if not graph.has_edge(k, i, graph.directed_edge_name): adj_i.add(k) # Find nodes l where j is l->j. parent_j = set() for k in graph.predecessors(j): if not graph.has_edge(j, k, graph.directed_edge_name): parent_j.add(k) # generate all permutations of sets containing neighbors of i and parents of j permut = permutations(adj_i, len(parent_j)) unq = set() # type: ignore for comb in permut: zipped = zip(comb, parent_j) unq.update(zipped) # check if these pairs have a directed edge between them and that k-j does not exist dedges = set(graph.directed_edges) undedges = set(graph.undirected_edges) candidate_k = set() for pair in unq: if pair in dedges: if (pair[0], j) not in undedges: candidate_k.add(pair) # if there are candidate 'k->l' pairs, then orient the edge accordingly if len(candidate_k) > 0: # Make i-j into i->j # logger.info(f"R2: Removing edge {i}-{j} to form {i}->{j}.") graph.orient_uncertain_edge(i, j) added_arrows = True return added_arrows
[docs] def pag_to_mag(graph): """Sample a MAG from a PAG using Zhang's algorithm. Using the algorithm defined in Theorem 2 of :footcite:`Zhang2008`, which turns all o-> edges to -> and -o edges to ->, then it converts the graph into a DAG with no unshielded colliders using the meek rules. Parameters ---------- G : Graph The PAG. Returns ------- mag : Graph The MAG constructed from the PAG. """ copy_graph = graph.copy() cedges = set(copy_graph.circle_edges) dedges = set(copy_graph.directed_edges) temp_cpdag = CPDAG() to_remove = [] to_reorient = [] to_add = [] for u, v in cedges: if (v, u) in dedges: # remove the circle end from a 'o-->' edge to make a '-->' edge to_remove.append((u, v)) elif (v, u) not in cedges: # reorient a '--o' edge to '-->' to_reorient.append((u, v)) elif (v, u) in cedges and ( v, u, ) not in to_add: # add all 'o--o' edges to the cpdag to_add.append((u, v)) for u, v in to_remove: copy_graph.remove_edge(u, v, copy_graph.circle_edge_name) for u, v in to_reorient: copy_graph.orient_uncertain_edge(u, v) for u, v in to_add: temp_cpdag.add_edge(v, u, temp_cpdag.undirected_edge_name) flag = True # convert the graph into a DAG with no unshielded colliders while flag: undedges = temp_cpdag.undirected_edges if len(undedges) != 0: for u, v in undedges: temp_cpdag.remove_edge(u, v, temp_cpdag.undirected_edge_name) temp_cpdag.add_edge(u, v, temp_cpdag.directed_edge_name) _apply_meek_rules(temp_cpdag) break else: flag = False mag = ADMG() # provisional MAG # construct the final MAG for u, v in copy_graph.directed_edges: mag.add_edge(u, v, mag.directed_edge_name) for u, v in temp_cpdag.directed_edges: mag.add_edge(u, v, mag.directed_edge_name) return mag
[docs] def check_pag_definition(G: PAG, L: Optional[set] = None, S: Optional[set] = None): """Checks if the provided graph is a valid Partial Ancestral Graph (PAG). A valid PAG as defined in :footcite:`Zhang2008` is a mixed edge graph that has no directed or almost directed cycles and no inducing paths between any two non-adjacent pair of nodes. It is graph representing all conditional independence (CI) statements that are present in a DAG, forming an equivalence class of DAGs that encode the same CI statements. The steps involved in this check are as follows: - Check for any directed cycles in the PAG. - Check for any almost directed cycles in the PAG. - For every pair of non-adjacent nodes, check for inducing paths. Parameters ---------- G : Graph The graph. Returns ------- is_valid : bool A boolean indicating whether the provided graph is a valid PAG or not. """ if L is None: L = set() if S is None: S = set() directed_sub_graph = G.sub_directed_graph() all_nodes = set(G.nodes) # check if there are more than one edges b/w two nodes for node in all_nodes: nb = set(G.neighbors(node)) for elem in nb: edge_data = G.get_edge_data(node, elem) if (edge_data["bidirected"] is not None) and (edge_data["directed"] is not None): return False # check if there are any directed cyclces try: nx.find_cycle(directed_sub_graph) # raises a NetworkXNoCycle error return False except nx.NetworkXNoCycle: pass # check if there are any almost directed cycles if has_adc(G): # if there is an ADC, it's not a valid MAG return False # check if there are any inducing paths between non-adjacent nodes in the # non-circle edge sub-graph dedges = list(G.edges()["directed"]) # undedges = list(G.edges()["undirected"]) biedges = list(G.edges()["bidirected"]) temp_pag = PAG() temp_pag.add_edges_from(dedges, temp_pag.directed_edge_name) # can't remember why I only handle directed and bidirected edges # temp_pag.add_edges_from(undedges, temp_pag.undirected_edge_name) temp_pag.add_edges_from(biedges, temp_pag.bidirected_edge_name) all_nodes = set(temp_pag.nodes) for source in all_nodes: nb = set(temp_pag.neighbors(source)) cur_set = all_nodes - nb cur_set.remove(source) for dest in cur_set: out = inducing_path(temp_pag, source, dest, L, S) if out[0] is True: return False return True
[docs] def mag_to_pag(G: PAG): """Converts the provided mag into a pag using the FCI algorithm. The FCI algorithms, as defined in :footcite:`Zhang2008` is a provably complete for learning all the tractable features of an MAG, thus producing a PAG. Parameters ---------- G : MAG The MAG. Returns ------- pag : PAG The PAG constructed from the MAG. """ try: from dodiscover import FCI, make_context from dodiscover.ci import Oracle from dodiscover.constraint.utils import dummy_sample except ImportError as e: raise ImportError("The 'dodiscover' package is required to convert a MAG to a PAG.") data = dummy_sample(G) oracle = Oracle(G) # ci_estimator = GSquareCITest(data_type="discrete") context = make_context().variables(data=data).build() fci = FCI(ci_estimator=oracle) fci.learn_graph(data, context) return fci.graph_
def _check_pag_edges_are_equal(G1: PAG, G2: PAG): """Check if the two provided PAGs are equivalent or not. This function compares the edges in both the graphs to determine equivalency. Parameters ---------- G1 : PAG The first PAG. G2 : PAG The second PAG. Returns ------- is_equivalent : bool A boolean indicating whether the two PAGs are equivalent or not. """ g1_edges = G1.edges() g2_edges = G2.edges() if set(g1_edges["directed"]) != set(g2_edges["directed"]): return False elif set(g1_edges["undirected"]) != set(g2_edges["undirected"]): return False elif set(g1_edges["bidirected"]) != set(g2_edges["bidirected"]): return False elif set(g1_edges["circle"]) != set(g2_edges["circle"]): return False else: return True
[docs] def valid_pag(G: PAG): """Check if the provided PAG is valid or not. The function applies Theorem 2 from :footcite:`Zhang2008`, which constitutes a sufficient check for whether the PAG is valid or not. The function determines the validity by first converting the PAG into an MAG, then checking the validity of the said MAG. After the validity of the MAG has been established, the MAG is converted back into a PAG. Then the function checks to see if the original and the reconverted PAG are equivalent or not. Parameters ---------- G : PAG The PAG. Returns ------- is_valid : bool Boolean indicating whether the provided PAG is valid or not. References ---------- .. footbibliography:: """ interim_bool = False # check if the graph is a vald PAG if not check_pag_definition(G): return False converted_mag = pag_to_mag(G) if valid_mag(converted_mag): interim_bool = True # convert the mag back to a pag rec_pag = mag_to_pag(converted_mag) # check if the converted pag is equivalent to the original if _check_pag_edges_are_equal(rec_pag, G): return interim_bool else: return False