# Source code for pywhy_graphs.config

import os.path as op
import platform
import re
import sys
from enum import Enum, EnumMeta
from functools import partial

import numpy as np

# Internal encoding of edge types as integers, used by the numpy-array
# representation of a graph. ``None`` (no edge) is encoded as 0.
EDGE_TO_VALUE_MAPPING = {
    None: 0,
    "directed": 1,
    "circle": 2,
    "undirected": 10,
    "bidirected": 20,
}
# Inverse lookup (integer code -> edge-type name), in the same insertion order.
VALUE_TO_EDGE_MAPPING = dict(
    zip(EDGE_TO_VALUE_MAPPING.values(), EDGE_TO_VALUE_MAPPING.keys())
)


class MetaEnum(EnumMeta):
    """Meta enumeration to make the ``in`` keyword test member values.

    With this metaclass, ``item in SomeEnum`` is True whenever
    ``SomeEnum(item)`` succeeds. That is, ``item`` is either a member of
    ``SomeEnum`` or the value of one of its members. ``str(SomeEnum)``
    returns the name of the enumeration class.
    """

    def __contains__(cls, item):
        """Return True if ``cls(item)`` resolves to a member of ``cls``."""
        try:
            cls(item)
        except ValueError:
            # Enum lookup signals "no member with this value" via ValueError.
            return False
        return True

    # Prints out the name of the type
    def __str__(cls):
        """Return the name of the enumeration class, e.g. ``"EdgeType"``."""
        # ``name`` only exists on enum *members*, not on the enum class itself,
        # so ``cls.name`` raised AttributeError. The class name is ``__name__``.
        return cls.__name__


class EdgeType(Enum, metaclass=MetaEnum):
    """Enumeration of different causal edge endpoints.

    Categories
    ----------
    all : str
        The value ``"all"``. Presumably used to request every edge type at
        once — TODO confirm against callers.
    directed : str
        Signifies arrowhead ("->") edges.
    bidirected : str
        Signifies bidirected ("<->") edges.
    circle : str
        Signifies a circle ("*-o") endpoint. That is an uncertain edge,
        which is either circle with directed edge (``o->``),
        circle with undirected edge (``o-``), or
        circle with circle edge (``o-o``).
    undirected : str
        Signifies an undirected ("-") edge. That is either a plain undirected
        edge (``-``), or an undirected edge with a circle endpoint (``-o``).

    Notes
    -----
    The possible edges between two nodes thus are:

    ->, <-, <->, -, o->, <-o, o-, -o, o-o

    In general, among all possible causal graphs, arrowheads depict
    non-descendant relationships. In DAGs, arrowheads depict direct
    causal relationships (i.e. parents/children). In ADMGs, arrowheads
    can come from directed edges, or bidirected edges.

    Because the metaclass is ``MetaEnum``, ``"directed" in EdgeType`` tests
    the member values.
    """

    ALL = "all"
    DIRECTED = "directed"
    BIDIRECTED = "bidirected"
    CIRCLE = "circle"
    UNDIRECTED = "undirected"


class TetradEndpoint(Enum, metaclass=MetaEnum):
    """Enumeration of tetrad endpoints.

    Each member's value is a one-character endpoint mark. Presumably these
    are the marks Tetrad uses in its textual edge representation — confirm
    against the Tetrad reader/writer.
    """

    TAIL = "-"  # tail mark
    ARROW = ">"  # arrowhead mark
    CIRCLE = "o"  # circle mark


class PCAlgPAGEndpoint(Enum, metaclass=MetaEnum):
    """Enumeration of pcalg PAG endpoints.

    Integer endpoint codes. Presumably they mirror the endpoint coding of
    the PAG adjacency matrices produced by the R package ``pcalg`` — confirm
    against the pcalg documentation.
    """

    NULL = 0  # presumably "no edge / no endpoint" — confirm
    CIRCLE = 1
    ARROW = 2
    TAIL = 3


class PCAlgCPDAGEndpoint(Enum, metaclass=MetaEnum):
    """Enumeration of pcalg CPDAG endpoints.

    Only two codes are defined: 0 (``NULL``) and 1 (``ARROW``). Presumably
    these mirror the 0/1 entries of ``pcalg`` CPDAG adjacency matrices —
    confirm against the pcalg documentation.
    """

    NULL = 0  # presumably "no edge" — confirm
    ARROW = 1


# Taken from causal-learn Endpoint.py
# A typesafe enumeration of the types of endpoints that are permitted in
# Tetrad-style graphs: tail (--), null (-), arrow (->), circle (-o) and star (-*).
# 'TAIL_AND_ARROW' and 'ARROW_AND_ARROW' mean there are two types of edges (<-> and -->)
# between two nodes.
# 'TAIL_AND_TAIL' means there are two types of edges with two tails ending on this endpoint.
class CLearnEndpoint(Enum, metaclass=MetaEnum):
    """Enumeration of causal-learn endpoints.

    Integer endpoint codes. All members except ``TAIL_AND_TAIL`` are taken
    from causal-learn's ``Endpoint`` enumeration. This enumeration is
    registered in ``ARRAY_ENUMS`` under the key ``"clearn"``.
    """

    TAIL = -1
    NULL = 0
    ARROW = 1
    CIRCLE = 2
    STAR = 3
    TAIL_AND_ARROW = 4  # both a tail and an arrow end on this endpoint
    ARROW_AND_ARROW = 5  # two arrows end on this endpoint
    TAIL_AND_TAIL = 6  # added by pywhy.


class TigramiteEndpoint(Enum, metaclass=MetaEnum):
    """Enumeration of tigramite endpoints.

    Each member's value is a string edge mark. The member names parallel
    those of ``CLearnEndpoint``.
    """

    TAIL = "--"
    NULL = ""  # empty string: no mark
    ARROW = "->"
    CIRCLE = "-o"
    STAR = "-*"
    TAIL_AND_ARROW = "+->"
    # NOTE(review): unlike CLearnEndpoint there is no ARROW_AND_ARROW member;
    # confirm whether tigramite has an equivalent mark before adding one.


# Registry of array-based endpoint encodings, keyed by short format name.
ARRAY_ENUMS = dict(clearn=CLearnEndpoint)


def _pl(x, non_pl="", pl="s"):
    """Determine if plural should be used."""
    len_x = x if isinstance(x, (int, np.generic)) else len(x)
    return non_pl if len_x == 1 else pl


def _get_numpy_libs():
    """Describe the OpenBLAS/MKL thread pool found by threadpoolctl.

    Returns
    -------
    str
        ``"<Library> <version> with <n> thread(s)"`` for the first pool whose
        ``internal_api`` is ``"openblas"`` or ``"mkl"``. Otherwise
        ``"unknown linalg bindings"``, with the import error appended when
        threadpoolctl cannot be imported.
    """
    fallback = "unknown linalg bindings"
    try:
        from threadpoolctl import threadpool_info
    except Exception as err:
        # Best-effort report: any import problem just degrades the message.
        return fallback + f" (threadpoolctl module not found: {err})"
    pool_infos = threadpool_info()
    pretty_names = {"openblas": "OpenBLAS", "mkl": "MKL"}
    for info in pool_infos:
        api = info["internal_api"]
        if api not in pretty_names:
            continue
        n_threads = info["num_threads"]
        return (
            f"{pretty_names[api]} {info['version']} with "
            f"{n_threads} thread{_pl(n_threads)}"
        )
    return fallback


def sys_info(fid=None, show_paths=False, *, dependencies="user"):
    """Print the system information for debugging.

    This function is useful for printing system information
    to help triage bugs.

    Parameters
    ----------
    fid : file-like | None
        The file to write to. Will be passed to :func:`print()`.
        Can be None to use :data:`sys.stdout`.
    show_paths : bool
        If True, print paths for each module.
    dependencies : 'user' | 'developer'
        Show dependencies relevant for users (default) or
        for developers (i.e., output includes additional dependencies).
        Any value other than ``'developer'`` is treated like ``'user'``.

    Notes
    -----
    A dependency that cannot be imported is reported as ``Not found``. A
    dependency that imports but has no ``__version__`` attribute is reported
    as ``no version info``.

    Examples
    --------
    Running this function with no arguments prints an output that is
    useful when submitting bug reports::

        >>> import pywhy_graphs
        >>> pywhy_graphs.sys_info()  # doctest: +SKIP
        Platform:         Linux-4.15.0-1067-aws-x86_64-with-glibc2.2.5
        Python:           3.8.1 (default, Feb  2 2020, 08:37:37)  [GCC 8.3.0]
        Executable:       /usr/local/bin/python
        CPU:              : 36 cores
        Memory:           68.7 GB

        numpy:            1.21.5 {OpenBLAS 0.3.17 with 8 threads}
        scipy:            1.8.0
        networkx:         2.8.8

        sklearn:          1.2.0
        matplotlib:       3.6.2 {backend=MacOSX}
        pandas:           1.5.2
        pygraphviz:       Not found
        causallearn:      no version info
        joblib:           1.2.0

        pywhy_graphs:     0.0.0
        dodiscover:       Not found
        dowhy:            0.8
    """  # noqa: E501
    # The developer list contains longer names, so widen the label column.
    ljust = 21 if dependencies == "developer" else 18
    platform_str = platform.platform()
    if platform.system() == "Darwin" and sys.version_info[:2] < (3, 8):
        # platform.platform() in Python < 3.8 doesn't call
        # platform.mac_ver() if we're on Darwin, so we don't get a nice macOS
        # version number. Therefore, let's do this manually here.
        macos_ver = platform.mac_ver()[0]
        macos_architecture = re.findall("Darwin-.*?-(.*)", platform_str)
        if macos_architecture:
            macos_architecture = macos_architecture[0]
            platform_str = f"macOS-{macos_ver}-{macos_architecture}"
        del macos_ver, macos_architecture

    # print() without an implicit newline: every line terminator is explicit.
    out = partial(print, end="", file=fid)
    out("Platform:".ljust(ljust) + platform_str + "\n")
    out("Python:".ljust(ljust) + str(sys.version).replace("\n", " ") + "\n")
    out("Executable:".ljust(ljust) + sys.executable + "\n")
    out("CPU:".ljust(ljust) + f"{platform.processor()}: ")
    try:
        import multiprocessing
    except ImportError:
        out(
            "number of processors unavailable "
            '(requires "multiprocessing" package)\n'
        )
    else:
        out(f"{multiprocessing.cpu_count()} cores\n")
    out("Memory:".ljust(ljust))
    try:
        import psutil
    except ImportError:
        # Terminate this line too. Previously the newline was missing here, so
        # the blank separator before the module list was swallowed.
        out('Unavailable (requires "psutil" package)\n')
    else:
        out(f"{psutil.virtual_memory().total / float(2 ** 30):0.1f} GB\n")
    out("\n")

    libs = _get_numpy_libs()
    # An empty string marks a blank separator line between groups.
    use_mod_names = (
        "numpy",
        "scipy",
        "networkx",
        "",
        "sklearn",
        "matplotlib",
        "pandas",
        "pygraphviz",
        "causallearn",  # no version
        # "tigramite",  # no version
        "joblib",
        "",
        "pywhy_graphs",
        "dodiscover",
        "dowhy",
    )
    if dependencies == "developer":
        use_mod_names += (
            "",
            "sphinx",
            "sphinx_gallery",
            "numpydoc",
            "pydata_sphinx_theme",
            "pytest",
            "nbclient",
            "poetry",
            "poethepoet",
        )
    for mod_name in use_mod_names:
        if mod_name == "":
            out("\n")
            continue
        out(f"{mod_name}:".ljust(ljust))
        try:
            mod = __import__(mod_name)
        except Exception:
            # Best-effort: any import failure is reported, never raised.
            out("Not found\n")
            continue

        # Some packages (causallearn is marked as such above) have no
        # ``__version__``. Report that instead of aborting with AttributeError.
        version = getattr(mod, "__version__", None)
        if version is None:
            out("no version info")
        else:
            out(version)
        if mod_name == "numpy":
            out(f" {{{libs}}}")
        elif mod_name == "matplotlib":
            out(f" {{backend={mod.get_backend()}}}")
        # Namespace packages can have ``__file__`` set to None; skip their path.
        mod_file = getattr(mod, "__file__", None)
        if show_paths and mod_file:
            out(f'\n{" " * ljust}{op.dirname(mod_file)}')
        out("\n")