import os.path as op
import platform
import re
import sys
from enum import Enum, EnumMeta
from functools import partial
import numpy as np
# Internal lookup table: edge type name (or ``None``) -> integer code used in
# the numpy array representation.
EDGE_TO_VALUE_MAPPING = {
    None: 0,
    "directed": 1,
    "circle": 2,
    "undirected": 10,
    "bidirected": 20,
}
# Reverse lookup table: integer code -> edge type name (or ``None``).
VALUE_TO_EDGE_MAPPING = dict(
    zip(EDGE_TO_VALUE_MAPPING.values(), EDGE_TO_VALUE_MAPPING.keys())
)
class MetaEnum(EnumMeta):
    """Meta enumeration to make 'in' keyword work.

    Enumerations created with this metaclass accept both members and raw
    member values on the left-hand side of ``in``, and ``str()`` of the
    enumeration class itself yields the class name.
    """

    def __contains__(cls, item):
        """Return True if ``item`` is a member of ``cls`` or a valid member value.

        Parameters
        ----------
        item : object
            A member of the enumeration or a candidate member value.

        Returns
        -------
        bool
            False when looking ``item`` up via ``cls(item)`` raises
            ``ValueError``; True otherwise.
        """
        try:
            cls(item)
        except ValueError:
            return False
        return True

    # Prints out the name of the type
    def __str__(cls):
        """Return the name of the enumeration class.

        ``cls`` is the enumeration class here (this is a metaclass method),
        and enumeration classes expose their name as ``__name__``; they have
        no ``name`` attribute, so reading ``cls.name`` raises
        ``AttributeError``.
        """
        return cls.__name__
class EdgeType(Enum, metaclass=MetaEnum):
    """Enumeration of different causal edge endpoints.

    Categories
    ----------
    all : str
        Member with the value ``"all"``. Presumably a wildcard standing for
        every edge type listed below; confirm against callers.
    directed : str
        Signifies arrowhead ("->") edges.
    bidirected : str
        Signifies a bidirected ("<->") edge.
    circle : str
        Signifies a circle ("*-o") endpoint. That is an uncertain edge,
        which is either circle with directed edge (``o->``),
        circle with undirected edge (``o-``), or
        circle with circle edge (``o-o``).
    undirected : str
        Signifies an undirected ("-") edge. That is an undirected edge (``-``),
        or tail with circle edge (``-o``).

    Notes
    -----
    The possible edges between two nodes thus are:

    ->, <-, <->, o->, <-o, o-o

    In general, among all possible causal graphs, arrowheads depict
    non-descendant relationships. In DAGs, arrowheads depict direct
    causal relationships (i.e. parents/children). In ADMGs, arrowheads
    can come from directed edges, or bidirected edges.
    """

    ALL = "all"
    DIRECTED = "directed"
    BIDIRECTED = "bidirected"
    CIRCLE = "circle"
    UNDIRECTED = "undirected"
class TetradEndpoint(Enum, metaclass=MetaEnum):
    """Enumeration of tetrad endpoints.

    Each member's value is a single-character string for that endpoint mark.
    """

    TAIL = "-"
    ARROW = ">"
    CIRCLE = "o"
class PCAlgPAGEndpoint(Enum, metaclass=MetaEnum):
    """Enumeration of pcalg PAG endpoints.

    Each member's value is the integer code for that endpoint mark.
    """

    NULL = 0
    CIRCLE = 1
    ARROW = 2
    TAIL = 3
class PCAlgCPDAGEndpoint(Enum, metaclass=MetaEnum):
    """Enumeration of pcalg CPDAG endpoints.

    Each member's value is the integer code for that endpoint mark.
    """

    NULL = 0
    ARROW = 1
# Taken from causal-learn Endpoint.py
# A typesafe enumeration of the types of endpoints that are permitted in
# Tetrad-style graphs: tail (--), null (-), arrow (->), circle (-o) and star (-*).
# 'TAIL_AND_ARROW' and 'ARROW_AND_ARROW' mean there are two types of edges
# (<-> and -->) between two nodes.
# 'TAIL_AND_TAIL' means there are two types of edges with two tails ending on
# this endpoint.
class CLearnEndpoint(Enum, metaclass=MetaEnum):
    """Enumeration of causal-learn endpoints.

    Each member's value is the integer code for that endpoint mark.
    """

    TAIL = -1
    NULL = 0
    ARROW = 1
    CIRCLE = 2
    STAR = 3
    TAIL_AND_ARROW = 4
    ARROW_AND_ARROW = 5
    TAIL_AND_TAIL = 6  # added by pywhy.
class TigramiteEndpoint(Enum, metaclass=MetaEnum):
    """Enumeration of tigramite endpoints.

    Each member's value is the string symbol for that endpoint mark.
    """

    TAIL = "--"
    NULL = ""
    ARROW = "->"
    CIRCLE = "-o"
    STAR = "-*"
    TAIL_AND_ARROW = "+->"
    # ARROW_AND_ARROW
    # NOTE(review): no ARROW_AND_ARROW member is defined for this enumeration;
    # confirm whether one is needed.
# Endpoint enumerations addressable by a short string key.
ARRAY_ENUMS = dict(clearn=CLearnEndpoint)
def _pl(x, non_pl="", pl="s"):
    """Pick the singular or plural suffix for a count.

    Parameters
    ----------
    x : int | numpy scalar | sized object
        Either the count itself (an ``int`` or numpy scalar) or an object
        whose ``len()`` gives the count.
    non_pl : str
        Suffix returned when the count equals 1.
    pl : str
        Suffix returned for every other count.

    Returns
    -------
    str
        ``non_pl`` if the count is exactly 1, otherwise ``pl``.
    """
    if isinstance(x, (int, np.generic)):
        count = x
    else:
        count = len(x)
    if count == 1:
        return non_pl
    return pl
def _get_numpy_libs():
    """Describe the first OpenBLAS or MKL thread pool reported by threadpoolctl.

    Returns
    -------
    str
        ``"<library> <version> with <n> thread(s)"`` for the first pool whose
        ``internal_api`` is ``"openblas"`` or ``"mkl"``. If no such pool is
        reported, or ``threadpoolctl`` cannot be imported, a string starting
        with ``"unknown linalg bindings"`` is returned instead (with the
        import error appended in the latter case).
    """
    fallback = "unknown linalg bindings"
    try:
        from threadpoolctl import threadpool_info
    except Exception as exc:
        return fallback + f" (threadpoolctl module not found: {exc})"

    pool_infos = threadpool_info()
    display_names = {"openblas": "OpenBLAS", "mkl": "MKL"}
    for info in pool_infos:
        api = info["internal_api"]
        if api not in ("openblas", "mkl"):
            continue
        version = info["version"]
        n_threads = info["num_threads"]
        return f"{display_names[api]} {version} with {n_threads} thread{_pl(n_threads)}"
    return fallback
def sys_info(fid=None, show_paths=False, *, dependencies="user"):
    """Print the system information for debugging.

    This function is useful for printing system information
    to help triage bugs.

    Parameters
    ----------
    fid : file-like | None
        The file to write to. Will be passed to :func:`print()`.
        Can be None to use :data:`sys.stdout`.
    show_paths : bool
        If True, print paths for each module.
    dependencies : 'user' | 'developer'
        Show dependencies relevant for users (default) or for developers
        (i.e., output includes additional dependencies).

    Examples
    --------
    Running this function with no arguments prints an output that is
    useful when submitting bug reports::

        >>> import pywhy_graphs
        >>> pywhy_graphs.sys_info() # doctest: +SKIP
        Platform:         Linux-4.15.0-1067-aws-x86_64-with-glibc2.2.5
        Python:           3.8.1 (default, Feb 2 2020, 08:37:37) [GCC 8.3.0]
        Executable:       /usr/local/bin/python
        CPU:              : 36 cores
        Memory:           68.7 GB
        numpy:            1.21.5 {OpenBLAS 0.3.17 with 8 threads}
        scipy:            1.8.0
        networkx:         2.8.8
        sklearn:          1.2.0
        matplotlib:       3.6.2 {backend=MacOSX}
        pandas:           1.5.2
        pygraphviz:       Not found
        causallearn:      no version info
        joblib:           1.2.0
        pywhy_graphs:     0.0.0
        dodiscover:       Not found
        dowhy:            0.8
    """  # noqa: E501
    # The developer listing contains longer module names, so it needs a wider
    # label column.
    ljust = 21 if dependencies == "developer" else 18
    platform_str = platform.platform()
    if platform.system() == "Darwin" and sys.version_info[:2] < (3, 8):
        # platform.platform() in Python < 3.8 doesn't call
        # platform.mac_ver() if we're on Darwin, so we don't get a nice macOS
        # version number. Therefore, let's do this manually here.
        macos_ver = platform.mac_ver()[0]
        macos_architecture = re.findall("Darwin-.*?-(.*)", platform_str)
        if macos_architecture:
            macos_architecture = macos_architecture[0]
            platform_str = f"macOS-{macos_ver}-{macos_architecture}"
        del macos_ver, macos_architecture

    # Every write goes through ``out``; newlines are emitted explicitly.
    out = partial(print, end="", file=fid)
    out("Platform:".ljust(ljust) + platform_str + "\n")
    out("Python:".ljust(ljust) + str(sys.version).replace("\n", " ") + "\n")
    # sys.executable may be None when the interpreter path cannot be
    # determined, so convert it before concatenating.
    out("Executable:".ljust(ljust) + str(sys.executable) + "\n")
    out("CPU:".ljust(ljust) + f"{platform.processor()}: ")
    try:
        import multiprocessing
    except ImportError:
        out("number of processors unavailable " '(requires "multiprocessing" package)\n')
    else:
        out(f"{multiprocessing.cpu_count()} cores\n")
    out("Memory:".ljust(ljust))
    try:
        import psutil
    except ImportError:
        # Terminate the line here as well, so the separator written below
        # produces a blank line in both branches.
        out('Unavailable (requires "psutil" package)\n')
    else:
        out(f"{psutil.virtual_memory().total / float(2 ** 30):0.1f} GB\n")
    out("\n")

    libs = _get_numpy_libs()
    # An empty string in the listing stands for a blank separator line.
    use_mod_names = (
        "numpy",
        "scipy",
        "networkx",
        "",
        "sklearn",
        "matplotlib",
        "pandas",
        "pygraphviz",
        "causallearn",  # no version
        # "tigramite",  # no version
        "joblib",
        "",
        "pywhy_graphs",
        "dodiscover",
        "dowhy",
    )
    if dependencies == "developer":
        use_mod_names += (
            "",
            "sphinx",
            "sphinx_gallery",
            "numpydoc",
            "pydata_sphinx_theme",
            "pytest",
            "nbclient",
            "poetry",
            "poethepoet",
        )
    for mod_name in use_mod_names:
        if mod_name == "":
            out("\n")
            continue
        out(f"{mod_name}:".ljust(ljust))
        try:
            mod = __import__(mod_name)
        except Exception:
            # Best effort: any failure while importing is reported as a
            # missing module instead of aborting the report.
            out("Not found\n")
            continue
        if mod_name == "causallearn":
            out("no version info")
        else:
            # Fall back instead of raising when a module has no __version__.
            out(getattr(mod, "__version__", "no version info"))
        if mod_name == "numpy":
            out(f" {{{libs}}}")
        elif mod_name == "matplotlib":
            out(f" {{backend={mod.get_backend()}}}")
        if show_paths:
            # __file__ can be missing or None (e.g. namespace packages).
            mod_file = getattr(mod, "__file__", None)
            mod_dir = op.dirname(mod_file) if mod_file else "path unavailable"
            out(f'\n{" " * ljust}•{mod_dir}')
        out("\n")