Source code for pywhy_graphs.export.pcalg

from typing import Dict, List

import numpy as np

import pywhy_graphs
from pywhy_graphs.config import CLearnEndpoint, PCAlgCPDAGEndpoint, PCAlgPAGEndpoint
from pywhy_graphs.typing import Node

from .causallearn import graph_to_clearn


def _add_pag_edge(graph, u, v, mark_at_v, mark_at_u):
    """Add the edge(s) between ``u`` and ``v`` encoded by a pair of pcalg PAG marks.

    ``mark_at_v`` is the value read from ``arr[u, v]`` and ``mark_at_u`` the value read
    from ``arr[v, u]`` (pcalg's PAG coding attaches the mark to the column node).
    Pairs that match none of the cases below add nothing to ``graph``.
    """
    arrow = PCAlgPAGEndpoint.ARROW.value
    circle = PCAlgPAGEndpoint.CIRCLE.value
    tail = PCAlgPAGEndpoint.TAIL.value

    if mark_at_v == arrow:
        if mark_at_u == arrow:
            # u <-> v
            graph.add_edge(u, v, edge_type=graph.bidirected_edge_name)
        elif mark_at_u == circle:
            # u o-> v
            graph.add_edge(u, v, edge_type=graph.directed_edge_name)
            graph.add_edge(v, u, edge_type=graph.circle_edge_name)
        elif mark_at_u == tail:
            # u --> v
            graph.add_edge(u, v, edge_type=graph.directed_edge_name)
    elif mark_at_v == tail:
        if mark_at_u == tail:
            # u --- v
            graph.add_edge(u, v, edge_type=graph.undirected_edge_name)
        elif mark_at_u == circle:
            # u o-- v
            graph.add_edge(v, u, edge_type=graph.circle_edge_name)
        elif mark_at_u == arrow:
            # u <-- v
            graph.add_edge(v, u, edge_type=graph.directed_edge_name)
    elif mark_at_v == circle:
        if mark_at_u == tail:
            # u --o v
            graph.add_edge(u, v, edge_type=graph.circle_edge_name)
        elif mark_at_u == circle:
            # u o-o v
            graph.add_edge(u, v, edge_type=graph.circle_edge_name)
            graph.add_edge(v, u, edge_type=graph.circle_edge_name)
        elif mark_at_u == arrow:
            # u <-o v
            graph.add_edge(u, v, edge_type=graph.circle_edge_name)
            graph.add_edge(v, u, edge_type=graph.directed_edge_name)


def _add_cpdag_edge(graph, u, v, mark_uv, mark_vu):
    """Add the CPDAG edge between ``u`` and ``v`` given ``arr[u, v]`` and ``arr[v, u]``."""
    arrow = PCAlgCPDAGEndpoint.ARROW.value

    if mark_uv == arrow:
        if mark_vu == arrow:
            # arrowheads in both cells encode an undirected edge u --- v
            graph.add_edge(u, v, edge_type=graph.undirected_edge_name)
        else:
            # a lone arrowhead in arr[u, v] is read as u --> v
            graph.add_edge(u, v, edge_type=graph.directed_edge_name)
    elif mark_uv == PCAlgCPDAGEndpoint.NULL.value:
        # NOTE(review): the caller only visits non-zero cells, so this branch is
        # reachable only if NULL.value is non-zero -- confirm against the enum.
        if mark_vu == arrow:
            graph.add_edge(v, u, edge_type=graph.directed_edge_name)


def pcalg_to_graph(arr, arr_idx: List[Node], amat_type: str):
    """Convert an array from R's pcalg into causal graph.

    Parameters
    ----------
    arr : array-like of shape (n_nodes, n_nodes)
        The array representing the causal graph with enumerations following
        that of `PCAlgPAGEndpoint`. Rows represent the starting node and the
        columns represent the ending node. Anything accepted by
        :func:`numpy.asarray` (e.g. a nested list) may be passed.
    arr_idx : List[Node] of length (n_nodes,)
        The names of the nodes that are assigned to the graph in order of the
        rows/columns of ``arr``.
    amat_type : str
        The type of graph in pcalg. One of ``{"pag", "cpdag"}``.

    Returns
    -------
    graph : causal graph
        The causal graph: a ``pywhy_graphs.CPDAG`` when ``amat_type="cpdag"``
        and a ``pywhy_graphs.PAG`` when ``amat_type="pag"``.

    Raises
    ------
    RuntimeError
        If ``amat_type`` is neither ``"pag"`` nor ``"cpdag"``.

    Notes
    -----
    Each unordered pair of nodes is processed once, from whichever of its two
    cells is encountered first among the non-zero entries of ``arr``.

    For ``amat_type="cpdag"`` this function reads ``arr[i, j] = 1`` together
    with ``arr[j, i] = 0`` as the directed edge ``i --> j``.
    NOTE(review): that is the transpose of the pcalg coding quoted below;
    confirm whether callers are expected to transpose pcalg's ``amat`` first.

    See pcalg documentation for ``amatType``
    https://cran.r-project.org/web/packages/pcalg/pcalg.pdf. Copied here for
    convenience.

    Coding for type amat.cpdag:

    - 0: No edge or tail
    - 1: Arrowhead

    Note that the edgemark-code refers to the row index (as opposed adjacency
    matrices of type mag or pag). E.g.::

        amat[a,b] = 0 and amat[b,a] = 1 implies a --> b.
        amat[a,b] = 1 and amat[b,a] = 0 implies a <-- b.
        amat[a,b] = 0 and amat[b,a] = 0 implies a     b.
        amat[a,b] = 1 and amat[b,a] = 1 implies a --- b.

    Coding for type amat.pag:

    - 0: No edge
    - 1: Circle
    - 2: Arrowhead
    - 3: Tail

    Note that the edgemark-code refers to the column index (as opposed
    adjacency matrices of type dag or cpdag). E.g.::

        amat[a,b] = 2 and amat[b,a] = 3 implies a --> b.
        amat[a,b] = 3 and amat[b,a] = 2 implies a <-- b.
        amat[a,b] = 2 and amat[b,a] = 2 implies a <-> b.
        amat[a,b] = 1 and amat[b,a] = 3 implies a --o b.
        amat[a,b] = 0 and amat[b,a] = 0 implies a     b.
    """
    if amat_type not in ("pag", "cpdag"):
        raise RuntimeError(f'Only amat_types are "pag" and "cpdag", not {amat_type}')

    # Honour the documented "array-like" contract: a nested list supports neither the
    # elementwise ``!= 0`` nor the ``arr[i, j]`` indexing used below.
    arr = np.asarray(arr)

    # instantiate the type of causal graph and pick the matching edge decoder
    if amat_type == "cpdag":
        graph = pywhy_graphs.CPDAG()  # type: ignore
        add_edge = _add_cpdag_edge
    else:
        graph = pywhy_graphs.PAG()
        add_edge = _add_pag_edge
    graph.add_nodes_from(arr_idx)

    # Walk every non-zero cell; both orientations of a pair are marked as visited so
    # that the mirrored cell does not add the same edge a second time.
    visited = set()
    for idx, jdx in np.argwhere(arr != 0):
        if (idx, jdx) in visited:
            continue
        visited.add((idx, jdx))
        visited.add((jdx, idx))

        add_edge(graph, arr_idx[idx], arr_idx[jdx], arr[idx, jdx], arr[jdx, idx])
    return graph
def graph_to_pcalg(causal_graph):
    """Convert causal graph to a pcalg type adjacency array.

    The graph is first exported with ``graph_to_clearn``; the resulting array is
    transposed and every pair of endpoint codes is then rewritten in place from
    the causal-learn coding (`CLearnEndpoint`) to the pcalg coding
    (`PCAlgCPDAGEndpoint` or `PCAlgPAGEndpoint`).

    Parameters
    ----------
    causal_graph : instance of causal graph
        The causal graph that is represented in pywhy. Either CPDAG, or PAG.

    Returns
    -------
    numpy_graph : array-like of shape (n_nodes, n_nodes)
        The numpy array that represents the graph. The values representing
        edges are mapped according to a pre-defined set of values. See Notes.

    Raises
    ------
    RuntimeError
        If ``causal_graph`` is neither a ``pywhy_graphs.CPDAG`` nor a
        ``pywhy_graphs.PAG``.

    See Also
    --------
    pcalg_to_graph

    Notes
    -----
    See :func:`pcalg_to_graph` for information on how edges are represented in
    pcalg.

    Pairs of endpoint codes that match none of the known combinations are left
    unchanged in the returned array.
    """
    if not isinstance(causal_graph, (pywhy_graphs.CPDAG, pywhy_graphs.PAG)):
        raise RuntimeError(f"Causal graph must be one of CPDAG or PAG, not {causal_graph}.")

    c_null = CLearnEndpoint.NULL.value
    c_arrow = CLearnEndpoint.ARROW.value
    c_tail = CLearnEndpoint.TAIL.value
    c_circle = CLearnEndpoint.CIRCLE.value

    # Each rule reads: (code at [i, j], code at [j, i]) in causal-learn terms, followed
    # by the pcalg codes written back to those same two cells. First match wins.
    if isinstance(causal_graph, pywhy_graphs.CPDAG):
        p_null = PCAlgCPDAGEndpoint.NULL.value
        p_arrow = PCAlgCPDAGEndpoint.ARROW.value
        rules = (
            (c_tail, c_tail, p_arrow, p_arrow),  # --
            (c_arrow, c_tail, p_arrow, p_null),  # ->
            (c_tail, c_arrow, p_null, p_arrow),  # <-
        )
    else:
        p_null = PCAlgPAGEndpoint.NULL.value
        p_arrow = PCAlgPAGEndpoint.ARROW.value
        p_tail = PCAlgPAGEndpoint.TAIL.value
        p_circle = PCAlgPAGEndpoint.CIRCLE.value
        rules = (
            (c_arrow, c_null, p_arrow, p_null),  # ->
            (c_null, c_arrow, p_null, p_arrow),  # <-
            (c_arrow, c_arrow, p_arrow, p_arrow),  # <->
            (c_arrow, c_circle, p_arrow, p_circle),  # o->
            (c_circle, c_arrow, p_circle, p_arrow),  # <-o
            (c_tail, c_tail, p_tail, p_tail),  # --
            (c_circle, c_circle, p_circle, p_circle),  # o-o
            (c_circle, c_tail, p_circle, p_tail),  # o--
            (c_tail, c_circle, p_tail, p_circle),  # --o
            # Directed edge stored with an explicit tail. The CPDAG rules above already
            # rely on graph_to_clearn emitting ARROW/TAIL for a directed edge; without
            # these two rules such an edge would keep its causal-learn codes.
            # NOTE(review): assumes graph_to_clearn uses the same encoding for PAGs.
            (c_arrow, c_tail, p_arrow, p_tail),  # ->
            (c_tail, c_arrow, p_tail, p_arrow),  # <-
        )

    # first convert to causal-learn array
    clearn_arr, _ = graph_to_clearn(causal_graph)

    # the pcalg array is in a different format
    clearn_arr = clearn_arr.T

    # Now map all values to their respective pcalg values. The array is rewritten in
    # place, so each unordered pair must be handled exactly once: after rewriting, the
    # mirrored cell already holds a pcalg code and must not be re-interpreted.
    seen = set()
    for idx, jdx in np.argwhere(clearn_arr != 0):
        if (idx, jdx) in seen or (jdx, idx) in seen:
            continue
        seen.add((idx, jdx))

        here = clearn_arr[idx, jdx]
        there = clearn_arr[jdx, idx]
        for src_here, src_there, dst_here, dst_there in rules:
            if here == src_here and there == src_there:
                clearn_arr[idx, jdx] = dst_here
                clearn_arr[jdx, idx] = dst_there
                break
    return clearn_arr