import numpy as np
from numpy.typing import NDArray
from scipy.stats import ttest_ind
from dodiscover.toporder.score import SCORE
from dodiscover.toporder.utils import full_adj_to_order
[docs]class DAS(SCORE):
"""The DAS (Discovery At Scale) algorithm for causal discovery.
DAS :footcite:`Montagna2023a` infer the topological ordering using
SCORE :footcite:`rolland2022`.
Then it finds edges in the graph by inspection of the non diagonal entries of the
Hessian of the log likelihood.
A final, computationally cheap, pruning step is performed with
CAM pruning :footcite:`Buhlmann2013`.
The method assumes Additive Noise Model and Gaussianity of the noise terms.
DAS is a highly scalable method, allowing to run causal discovery on thousands of nodes.
It reduces the computational complexity of the pruning method of an order of magnitude with
respect to SCORE.
Parameters
----------
eta_G: float, optional
Regularization parameter for Stein gradient estimator, default is 0.001.
eta_H : float, optional
Regularization parameter for Stein Hessian estimator, default is 0.001.
alpha : float, optional
Alpha cutoff value for variable selection with hypothesis testing over regression
coefficients, default is 0.05.
prune : bool, optional
If True (default), apply CAM-pruning after finding the topological order.
If False, DAS is equivalent to SCORE.
das_cutoff : float, optional
Alpha value for hypothesis testing in preliminary DAS pruning.
If None (default), it is set equal to `alpha`.
n_splines : int, optional
Number of splines to use for the feature function, default is 10.
Automatically decreased in case of insufficient samples
splines_degree: int, optional
Order of spline to use for the feature function, default is 3.
min_parents : int, optional
Minimum number of edges retained by DAS preliminary pruning step, default is 5.
min_parents <= 5 doesn't significantly affects execution time, while increasing the
accuracy.
max_parents : int, optional
Maximum number of parents allowed for a single node, default is 20.
Given that CAM pruning is inefficient for > ~20 nodes, larger values are not advised.
The value of max_parents should be decrease under the assumption of sparse graphs.
References
----------
.. footbibliography::
Notes
-----
Prior knowledge about the included and excluded directed edges in the output DAG
is supported. It is not possible to provide explicit constraints on the relative
positions of nodes in the topological ordering. However, explicitly including a
directed edge in the DAG defines an implicit constraint on the relative position
of the nodes in the topological ordering (i.e. if directed edge `(i,j)` is
encoded in the graph, node `i` will precede node `j` in the output order).
"""
def __init__(
self,
eta_G: float = 0.001,
eta_H: float = 0.001,
alpha: float = 0.05,
prune: bool = True,
das_cutoff: float = None,
n_splines: int = 10,
splines_degree: int = 3,
min_parents: int = 5,
max_parents: int = 20,
):
super().__init__(
eta_G, eta_H, alpha, prune, n_splines, splines_degree, estimate_variance=True, pns=False
)
self.min_parents = min_parents
self.max_parents = max_parents
self.das_cutoff = alpha if das_cutoff is None else das_cutoff
def _prune(self, X: NDArray, A_dense: NDArray) -> NDArray:
"""
DAS preliminary pruning of A_dense matrix representation of a fully connected graph.
Parameters
----------
X : np.ndarray
n x d matrix of the data
A_dense : np.ndarray
fully connected matrix corresponding to a topological ordering
Returns
-------
np.ndarray
Sparse adjacency matrix representing the pruned DAG.
"""
_, d = X.shape
order = full_adj_to_order(A_dense)
max_parents = self.max_parents + 1 # +1 to account for A[l, l] = 1
remaining_nodes = list(range(d))
A_das = np.zeros((d, d))
hess = self.hessian(X, eta_G=self.eta_G, eta_H=self.eta_H)
for i in range(d - 1):
leaf = order[::-1][i]
hess_l = hess[:, leaf, :][:, remaining_nodes]
hess_m = np.abs(np.median(hess_l * self.var[leaf], axis=0))
max_parents = min(max_parents, len(remaining_nodes))
# Find index of the reference for the hypothesis testing
topk_indices = np.argsort(hess_m)[::-1][:max_parents]
topk_values = hess_m[topk_indices] # largest
argmin = topk_indices[np.argmin(topk_values)]
# Edges selection step
parents = []
hess_l = np.abs(hess_l)
l_index = remaining_nodes.index(
leaf
) # leaf index in the remaining nodes (from 0 to len(remaining_nodes)-1)
for j in range(max_parents):
node = topk_indices[j]
if node != l_index: # enforce diagonal elements = 0
if j < self.min_parents: # do not filter minimum number of parents
parents.append(remaining_nodes[node])
else: # filter potential parents with hp testing
# Use hess_l[:, argmin] as sample from a zero mean population
# (implicit assumption: argmin corresponds to zero mean hessian entry)
_, pval = ttest_ind(
hess_l[:, node],
hess_l[:, argmin],
alternative="greater",
equal_var=False,
)
if pval < self.das_cutoff:
parents.append(remaining_nodes[node])
A_das[parents, leaf] = 1
remaining_nodes.pop(l_index)
return super()._prune(X, A_das)