# Source code for pywhy_graphs.algorithms.multidomain

from itertools import combinations
from typing import Optional, Set
from warnings import warn

from pywhy_graphs.classes import AugmentedGraph
from pywhy_graphs.typing import Node

def add_all_snode_combinations(G, n_domains: int):
    """Add all possible S-nodes to the graph given number of domains.

    One S-node is associated with every unordered pair of domains ``(i, j)``
    with ``1 <= i < j <= n_domains``. If ``G`` already contains an S-node whose
    ``domain_ids`` attribute matches a pair, that S-node is reused. Otherwise a
    new S-node is added, named with the next ``("S", idx)`` (counting up from
    0) that is not already a node of ``G``. Existing nodes are therefore never
    overwritten.

    Parameters
    ----------
    G : AugmentedGraph
        The augmented graph. It is not modified; a copy is returned.
    n_domains : int
        The number of domains.

    Returns
    -------
    G : AugmentedGraph
        A copy of the augmented graph with all possible S-nodes added. Note
        that none of the newly added S-nodes have any edges.
    s_node_domains : dict
        Mapping from each sorted domain pair ``(i, j)`` to the S-node that
        represents it.
    """
    G = G.copy()
    node_data = G.nodes(data=True)

    # Index the S-nodes already in the graph by their sorted domain pair so
    # they are reused rather than clobbered. Assumes S-nodes carry a
    # ``domain_ids`` node attribute (as the S-nodes added below do); S-nodes
    # without it are left alone.
    existing = dict()
    for s_node in G.s_nodes:
        domain_ids = node_data[s_node].get("domain_ids")
        if domain_ids is not None:
            existing.setdefault(tuple(sorted(domain_ids)), s_node)

    # Names already used in the graph; a plain ``add_node`` on one of these
    # would update that node's attributes instead of creating a new node.
    taken = set(G.nodes())

    s_node_domains = dict()
    sdx = 0
    # combinations over an increasing range already yields sorted pairs
    for source_domain, target_domain in combinations(range(1, n_domains + 1), 2):
        domain_pair = (source_domain, target_domain)
        if domain_pair in existing:
            s_node_domains[domain_pair] = existing[domain_pair]
            continue

        # advance the S-node counter past any name that is already a node
        while ("S", sdx) in taken:
            sdx += 1
        s_node = ("S", sdx)
        G.add_node(s_node, domain_ids=domain_pair)
        taken.add(s_node)
        s_node_domains[domain_pair] = s_node
        sdx += 1

    return G, s_node_domains
def compute_invariant_domains_per_node(
    G: AugmentedGraph,
    node: Node,
    all_poss_snodes: Optional[Set] = None,
    n_domains: Optional[int] = None,
    inconsistency="raise",
):
    """Compute the invariant domains for a specific node.

    Every unordered pair of domains ``(i, j)`` with ``1 <= i < j <= n_domains``
    (i.e. ``n_domains choose 2`` pairs) is a candidate S-node. A pair that has
    no S-node in ``G`` is taken to mean the mechanism of ``node`` is invariant
    across those two domains. Invariance is then closed under transitivity via
    :func:`find_connected_pairs`: invariance across ``(1, 2)`` and ``(2, 3)``
    implies invariance across ``(1, 3)``.

    Parameters
    ----------
    G : AugmentedGraph
        The augmented graph. It is not modified; a copy is returned.
    node : Node
        The node in G to compute the invariant domains for.
    all_poss_snodes : Optional[Set], optional
        All possible S-nodes, by default None. If given, only the S-nodes of
        ``G`` contained in this set are treated as present; S-nodes of ``G``
        outside it are ignored. If None, every S-node of ``G`` is used. The
        candidate domain pairs are always derived from ``n_domains``.
    n_domains : int, optional
        The number of domains, by default None. If None, will infer based on
        the ``domain_ids`` attribute of G.
    inconsistency : str, optional
        How to handle inconsistencies, by default 'raise'. Can be one of:

        - 'raise': raise an exception
        - 'ignore': ignore the inconsistency.
        - 'warn': raise a warning

        Any other value behaves like 'ignore'. An inconsistency is an S-node
        present in ``G`` for a domain pair that the transitive closure marks
        as invariant. If 'ignore' or 'warn', the inconsistent S-node is
        removed from the returned graph.

    Returns
    -------
    G : AugmentedGraph
        A copy of the augmented graph in which
        ``G.nodes()[node]["invariant_domains"]`` is set to the set of domain
        ids that take part in at least one invariant pair.

    Raises
    ------
    RuntimeError
        If ``inconsistency == 'raise'`` and an inconsistency is found.
    """
    # work on a copy so the caller's graph is never mutated
    G = G.copy()

    # infer the number of domains based on the number of domain IDs in the
    # augmented graph so far
    if n_domains is None:
        n_domains = len(G.domain_ids)

    # Map each sorted domain pair to the S-node(s) of G representing it.
    # NOTE(review): presence of an S-node in G is the "not invariant" signal,
    # as in the original implementation; if S-nodes can exist in G without
    # pointing to ``node``, adjacency should be checked instead -- confirm.
    # Assumes S-nodes carry a ``domain_ids`` node attribute; ones without it
    # cannot be placed and are skipped.
    node_data = G.nodes(data=True)
    pair_to_snodes = dict()
    for s_node in G.s_nodes:
        if all_poss_snodes is not None and s_node not in all_poss_snodes:
            continue
        domain_ids = node_data[s_node].get("domain_ids")
        if domain_ids is None:
            continue
        pair_to_snodes.setdefault(tuple(sorted(domain_ids)), []).append(s_node)

    # domain pairs without any S-node: the node is invariant across them
    invariant_pairs = [
        domain_pair
        for domain_pair in combinations(range(1, n_domains + 1), 2)
        if domain_pair not in pair_to_snodes
    ]

    # close the invariances under transitivity
    connected_pairs = find_connected_pairs(invariant_pairs, n_domains)

    invariant_domains = set()
    for domain_pair in sorted(connected_pairs):
        # an S-node that still exists for a transitively-invariant pair
        # contradicts the other S-nodes
        conflicting = pair_to_snodes.get(domain_pair, [])
        if conflicting:
            msg = f"Inconsistency in S-nodes for node {node}!"
            if inconsistency == "raise":
                raise RuntimeError(msg)
            elif inconsistency == "warn":
                warn(msg)
            for s_node in conflicting:
                G.remove_node(s_node)

        # for each invariant pair, there are invariances in the SCM for this node
        invariant_domains.update(domain_pair)

    # now set the functional relationships based on the invariances
    G.nodes()[node]["invariant_domains"] = invariant_domains
    return G
class DisjointSet:
    """Helper union-find (disjoint-set) structure over hashable items.

    Items are registered lazily the first time ``find`` sees them. ``find``
    compresses the path it walks, and ``union`` attaches the root of the first
    item's set beneath the root of the second item's set.
    """

    def __init__(self):
        # item -> parent item; a root is its own parent
        self.parent = {}

    def find(self, x):
        """Return the root of the set containing ``x``, registering ``x`` if new."""
        parents = self.parent
        if x not in parents:
            parents[x] = x
            return x
        up = parents[x]
        if up == x:
            return x
        # path compression: point x straight at its root
        root = self.find(up)
        parents[x] = root
        return root

    def union(self, x, y):
        """Merge the set containing ``x`` into the set containing ``y``."""
        x_root, y_root = self.find(x), self.find(y)
        if x_root == y_root:
            return
        self.parent[x_root] = y_root

    def get_sets(self):
        """Return a view of the disjoint sets, one ``set`` per root."""
        by_root = {}
        for member in self.parent:
            by_root.setdefault(self.find(member), set()).add(member)
        return by_root.values()
def find_connected_pairs(tuples, max_number):
    """Find connected pairs of domain tuples.

    Each input tuple ``(i, j)`` is treated as an undirected link between
    domain ``i`` and domain ``j``. Every pair of domains that ends up in the
    same connected component is returned, i.e. the transitive closure of the
    links. Tuples that mention a domain id larger than ``max_number`` on
    either side are ignored.

    Parameters
    ----------
    tuples : List of tuples
        List of tuples of domain ids (i, j), in either order.
    max_number : int
        The maximum number that can be in a domain id.

    Returns
    -------
    connected_pairs : set of tuples
        Set of sorted pairs ``(i, j)`` with ``i < j`` of domain ids that are
        connected.
    """
    disjoint_set = DisjointSet()
    for i, j in tuples:
        # skip links involving an out-of-range domain id on either endpoint
        if max(i, j) > max_number:
            continue
        disjoint_set.union(i, j)

    connected_pairs = set()
    for component in disjoint_set.get_sets():
        # sorting first makes each generated pair come out as (smaller, larger)
        connected_pairs.update(combinations(sorted(component), 2))
    return connected_pairs