Source code for dowhy.graph_learners

from importlib import import_module

from dowhy.graph_learner import GraphLearner


[docs]def get_discovery_class_object(method_name, *args, **kwargs): """ Import class from graph_learners. """ # from https://www.bnmetrics.com/blog/factory-pattern-in-python3-simple-version try: module_name = method_name class_name = module_name.upper() discovery_module = import_module("." + module_name, package="dowhy.graph_learners") discovery_class = getattr(discovery_module, class_name) assert issubclass(discovery_class, GraphLearner) except (AttributeError, AssertionError, ImportError): raise ImportError("{} is not an existing causal discovery method.".format(method_name)) return discovery_class
[docs]def get_library_class_object(module_method_name, *args, **kwargs): """ Import library for causal inference. """ # from https://www.bnmetrics.com/blog/factory-pattern-in-python3-simple-version try: (module_name, _, class_name) = module_method_name.rpartition(".") discovery_module = import_module(module_name) discovery_class = getattr(discovery_module, class_name) except (AttributeError, AssertionError, ImportError): raise ImportError( "Error loading {}.{}. Double-check the method name and ensure that all library dependencies are installed.".format( module_name, class_name ) ) return discovery_class