tigramite-fast 5.2.10.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tigramite/__init__.py +0 -0
- tigramite/causal_effects.py +1525 -0
- tigramite/causal_mediation.py +1592 -0
- tigramite/data_processing.py +1574 -0
- tigramite/graphs.py +1509 -0
- tigramite/independence_tests/LBFGS.py +1114 -0
- tigramite/independence_tests/__init__.py +0 -0
- tigramite/independence_tests/cmiknn.py +661 -0
- tigramite/independence_tests/cmiknn_mixed.py +1397 -0
- tigramite/independence_tests/cmisymb.py +286 -0
- tigramite/independence_tests/gpdc.py +664 -0
- tigramite/independence_tests/gpdc_torch.py +820 -0
- tigramite/independence_tests/gsquared.py +190 -0
- tigramite/independence_tests/independence_tests_base.py +1310 -0
- tigramite/independence_tests/oracle_conditional_independence.py +1582 -0
- tigramite/independence_tests/pairwise_CI.py +383 -0
- tigramite/independence_tests/parcorr.py +369 -0
- tigramite/independence_tests/parcorr_mult.py +485 -0
- tigramite/independence_tests/parcorr_wls.py +451 -0
- tigramite/independence_tests/regressionCI.py +403 -0
- tigramite/independence_tests/robust_parcorr.py +403 -0
- tigramite/jpcmciplus.py +966 -0
- tigramite/lpcmci.py +3649 -0
- tigramite/models.py +2257 -0
- tigramite/pcmci.py +3935 -0
- tigramite/pcmci_base.py +1218 -0
- tigramite/plotting.py +4735 -0
- tigramite/rpcmci.py +467 -0
- tigramite/toymodels/__init__.py +0 -0
- tigramite/toymodels/context_model.py +261 -0
- tigramite/toymodels/non_additive.py +1231 -0
- tigramite/toymodels/structural_causal_processes.py +1201 -0
- tigramite/toymodels/surrogate_generator.py +319 -0
- tigramite_fast-5.2.10.1.dist-info/METADATA +182 -0
- tigramite_fast-5.2.10.1.dist-info/RECORD +38 -0
- tigramite_fast-5.2.10.1.dist-info/WHEEL +5 -0
- tigramite_fast-5.2.10.1.dist-info/licenses/license.txt +621 -0
- tigramite_fast-5.2.10.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1582 @@
"""Tigramite causal discovery for time series."""

# Author: Jakob Runge <jakob@jakob-runge.com>
#
# License: GNU General Public License v3.0

from __future__ import print_function
import numpy as np

from collections import defaultdict, OrderedDict
from itertools import combinations, permutations


class OracleCI:
    r"""Oracle of conditional independence test X _|_ Y | Z given a graph.

    Class around link_coeff causal ground truth. X _|_ Y | Z is based on
    assessing whether X and Y are d-separated given Z in the graph.

    Class can be used just like a Tigramite conditional independence class
    (e.g., ParCorr). The main use is for unit testing of PCMCI methods.

    Parameters
    ----------
    graph : array of shape [N, N, tau_max+1]
        Causal graph.
    links : dict
        Dictionary of form {0:[(0, -1), ...], 1:[...], ...}.
        Alternatively can also digest {0: [((0, -1), coeff, func)], ...}.
    observed_vars : None or list, optional (default: None)
        Subset of keys in links defining which variables are
        observed. If None, then all variables are observed.
    selection_vars : None or list, optional (default: None)
        Subset of keys in links defining which variables are
        selected (= always conditioned on at every time lag).
        If None, then no variables are selected.
    verbosity : int, optional (default: 0)
        Level of verbosity.
    """

    # documentation
    @property
    def measure(self):
        """
        Concrete property to return the measure of the independence test
        """
        return self._measure

    def __init__(self,
                 links=None,
                 observed_vars=None,
                 selection_vars=None,
                 graph=None,
                 graph_is_mag=False,
                 tau_max=None,
                 verbosity=0):

        self.tau_max = tau_max
        self.graph_is_mag = graph_is_mag

        if links is None:
            if graph is None:
                raise ValueError("Either links or graph must be specified!")
            else:
                # Get canonical DAG from graph, potentially interpreted as MAG
                # self.tau_max = graph.shape[2]
                (links,
                 observed_vars,
                 selection_vars) = self.get_links_from_graph(graph)
                # # TODO make checks and tau_max?
                # self.graph = graph

        self.verbosity = verbosity
        self._measure = 'oracle_ci'
        self.confidence = None
        self.links = links
        self.N = len(links)
        # self.tau_max = self._get_minmax_lag(self.links)

        # Initialize already computed dsepsets of X, Y, Z
        self.dsepsets = {}

        # Initialize observed vars
        self.observed_vars = observed_vars
        if self.observed_vars is None:
            self.observed_vars = range(self.N)
        else:
            if not set(self.observed_vars).issubset(set(range(self.N))):
                raise ValueError("observed_vars must be subset of range(N).")
            if self.observed_vars != sorted(self.observed_vars):
                raise ValueError("observed_vars must be ordered.")
            if len(self.observed_vars) != len(set(self.observed_vars)):
                raise ValueError("observed_vars must not contain duplicates.")

        self.selection_vars = selection_vars

        if self.selection_vars is not None:
            if not set(self.selection_vars).issubset(set(range(self.N))):
                raise ValueError("selection_vars must be subset of range(N).")
            if self.selection_vars != sorted(self.selection_vars):
                raise ValueError("selection_vars must be ordered.")
            if len(self.selection_vars) != len(set(self.selection_vars)):
                raise ValueError("selection_vars must not contain duplicates.")
        else:
            self.selection_vars = []

        # TODO: maybe allow to use user-tau_max, otherwise deduced from links
        self.graph = self.get_graph_from_links(tau_max=tau_max)

        self.ci_results = {}

    def set_dataframe(self, dataframe):
        """Dummy function."""
        pass

    def _check_XYZ(self, X, Y, Z):
        """Checks variables X, Y, Z.

        Parameters
        ----------
        X, Y, Z : list of tuples
            For a dependence measure I(X;Y|Z), Y is of the form [(varY, 0)],
            where var specifies the variable index. X typically is of the form
            [(varX, -tau)] with tau denoting the time lag and Z can be
            multivariate [(var1, -lag), (var2, -lag), ...].

        Returns
        -------
        X, Y, Z : tuple
            Cleaned X, Y, Z.
        """
        # Get the length in time and the number of nodes
        N = self.N

        # Remove duplicates in X, Y, Z
        X = list(OrderedDict.fromkeys(X))
        Y = list(OrderedDict.fromkeys(Y))
        Z = list(OrderedDict.fromkeys(Z))

        # If a node in Z occurs already in X or Y, remove it from Z
        Z = [node for node in Z if (node not in X) and (node not in Y)]

        # Check that all lags are non-positive and indices are in [0,N-1]
        XYZ = X + Y + Z
        dim = len(XYZ)
        # Ensure that XYZ makes sense
        if np.array(XYZ).shape != (dim, 2):
            raise ValueError("X, Y, Z must be lists of tuples in format"
                             " [(var, -lag),...], e.g., [(2, -2), (1, 0), ...]")
        if np.any(np.array(XYZ)[:, 1] > 0):
            raise ValueError("nodes are %s, " % str(XYZ) +
                             "but all lags must be non-positive")
        if (np.any(np.array(XYZ)[:, 0] >= N)
                or np.any(np.array(XYZ)[:, 0] < 0)):
            raise ValueError("var indices %s," % str(np.array(XYZ)[:, 0]) +
                             " but must be in [0, %d]" % (N - 1))
        if np.all(np.array(Y)[:, 1] != 0):
            raise ValueError("Y-nodes are %s, " % str(Y) +
                             "but one of the Y-nodes must have zero lag")

        return (X, Y, Z)

    def _get_lagged_parents(self, var_lag, exclude_contemp=False,
                            only_non_causal_paths=False, X=None, causal_children=None):
        """Helper function to yield lagged parents for var_lag from
        self.links_coeffs.

        Parameters
        ----------
        var_lag : tuple
            Tuple of variable and lag which is assumed <= 0.
        exclude_contemp : bool
            Whether contemporaneous links should be excluded.

        Yields
        ------
        Next lagged parent.
        """

        var, lag = var_lag

        for link_props in self.links[var]:
            if len(link_props) == 3:
                i, tau = link_props[0]
                coeff = link_props[1]
            else:
                i, tau = link_props
                coeff = 1.
            if coeff != 0.:
                if not (exclude_contemp and lag == 0):
                    if only_non_causal_paths:
                        if not ((i, lag + tau) in X and var_lag in causal_children):
                            yield (i, lag + tau)
                    else:
                        yield (i, lag + tau)

    def _get_children(self):
        """Helper function to get children from links.

        Note that for children the lag is positive.

        Returns
        -------
        children : dict
            Dictionary of form {0:[(0, 1), (3, 0), ...], 1:[], ...}.
        """

        N = len(self.links)
        children = dict([(j, []) for j in range(N)])

        for j in range(N):
            for link_props in self.links[j]:
                if len(link_props) == 3:
                    i, tau = link_props[0]
                    coeff = link_props[1]
                else:
                    i, tau = link_props
                    coeff = 1.
                if coeff != 0.:
                    children[i].append((j, abs(tau)))

        return children
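
    # Illustration (hypothetical values, not part of the original module): for
    # links = {0: [(0, -1)], 1: [(0, -1)], 2: [(1, 0)]}, _get_children() would
    # return {0: [(0, 1), (1, 1)], 1: [(2, 0)], 2: []}, i.e. each parent i is
    # mapped to its children as (j, abs(tau)) with positive lags.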

    def _get_lagged_children(self, var_lag, children, exclude_contemp=False,
                             only_non_causal_paths=False, X=None, causal_children=None):
        """Helper function to yield lagged children for var_lag from children.

        Parameters
        ----------
        var_lag : tuple
            Tuple of variable and lag which is assumed <= 0.
        children : dict
            Dictionary of form {0:[(0, 1), (3, 0), ...], 1:[], ...}.
        exclude_contemp : bool
            Whether contemporaneous links should be excluded.

        Yields
        ------
        Next lagged child.
        """

        var, lag = var_lag
        # lagged_parents = []

        for child in children[var]:
            k, tau = child
            if not (exclude_contemp and tau == 0):
                # lagged_parents.append((i, lag + tau))
                if only_non_causal_paths:
                    if not (var_lag in X and (k, lag + tau) in causal_children):
                        yield (k, lag + tau)
                else:
                    yield (k, lag + tau)

    def _get_non_blocked_ancestors(self, Y, conds=None, mode='non_repeating',
                                   max_lag=None):
        """Helper function to return the non-blocked ancestors of variables Y.

        Returns a dictionary of ancestors for every y in Y. y is a tuple
        (var, lag) where lag <= 0. All ancestors with directed paths towards y
        that are not blocked by conditions in conds are included. In mode
        'non_repeating' an ancestor X^i_{t-\tau_i} with link X^i_{t-\tau_i}
        --> X^j_{t-\tau_j} is only included if X^i_{t'-\tau_i} -->
        X^j_{t'-\tau_j} is not already part of the ancestors. The most lagged
        ancestor for every variable X^i defines the maximum ancestral time
        lag, which is also returned. In mode 'max_lag' ancestors are included
        up to the maximum time lag max_lag.

        Its main use is to return the maximum ancestral time lag max_lag of
        y in Y for every variable in self.links_coeffs.

        Parameters
        ----------
        Y : list of tuples
            Of the form [(var, -tau)], where var specifies the variable
            index and tau the time lag.
        conds : list of tuples
            Of the form [(var, -tau)], where var specifies the variable
            index and tau the time lag.
        mode : {'non_repeating', 'max_lag'}
            Whether repeating links should be excluded or ancestors should be
            followed up to max_lag.
        max_lag : int
            Maximum time lag to include ancestors.

        Returns
        -------
        ancestors : dict
            Includes ancestors for every y in Y.
        max_lag : int
            Maximum time lag to include ancestors.
        """

        def _repeating(link, seen_links):
            """Returns True if a link or its time-shifted version is already
            included in seen_links."""
            i, taui = link[0]
            j, tauj = link[1]

            for seen_link in seen_links:
                seen_i, seen_taui = seen_link[0]
                seen_j, seen_tauj = seen_link[1]

                if (i == seen_i and j == seen_j
                        and abs(tauj - taui) == abs(seen_tauj - seen_taui)):
                    return True

            return False

        if conds is None:
            conds = []

        conds = [z for z in conds if z not in Y]

        N = len(self.links)

        # Initialize max. ancestral time lag for every N
        if mode == 'non_repeating':
            max_lag = 0
        else:
            if max_lag is None:
                raise ValueError("max_lag must be set in mode = 'max_lag'")

        if self.selection_vars is not None:
            for selection_var in self.selection_vars:
                # print(selection_var, conds)
                # print([(selection_var, -tau_sel) for tau_sel in range(0, max_lag + 1)])
                conds += [(selection_var, -tau_sel)
                          for tau_sel in range(0, max_lag + 1)]

        ancestors = dict([(y, []) for y in Y])

        for y in Y:
            j, tau = y   # tau <= 0
            if mode == 'non_repeating':
                max_lag = max(max_lag, abs(tau))
            seen_links = []
            this_level = [y]
            while len(this_level) > 0:
                next_level = []
                for varlag in this_level:
                    for par in self._get_lagged_parents(varlag):
                        i, tau = par
                        if par not in conds and par not in ancestors[y]:
                            if ((mode == 'non_repeating' and
                                 not _repeating((par, varlag), seen_links)) or
                                    (mode == 'max_lag' and
                                     abs(tau) <= abs(max_lag))):
                                ancestors[y].append(par)
                                if mode == 'non_repeating':
                                    max_lag = max(max_lag, abs(tau))
                                next_level.append(par)
                                seen_links.append((par, varlag))

                this_level = next_level

        return ancestors, max_lag
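
    # Illustration (hypothetical values, not part of the original module): with
    # links = {0: [(0, -1)], 1: [(0, -1), (1, -1)]}, calling
    # _get_non_blocked_ancestors(Y=[(1, 0)], mode='non_repeating') collects the
    # parents (0, -1) and (1, -1), then (0, -2) via the autolink of variable 0,
    # and stops once the same link type would repeat at larger lags, so the
    # returned ancestors are [(0, -1), (1, -1), (0, -2)] with max_lag = 2.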

    def _get_maximum_possible_lag(self, XYZ):
        """Helper function to return the maximum time lag of any confounding path.

        This is still based on a conjecture!

        The conjecture states that if and only if X and Y are d-connected given Z
        in a stationary DAG, then there exists a confounding path with a maximal
        time lag (i.e., the node on that path with maximal lag) given as follows:
        For any node in XYZ consider all non-repeating causal paths from the past
        to that node, where non-repeating means that a link X^i_{t-\tau_i}
        --> X^j_{t-\tau_j} is only traversed if X^i_{t'-\tau_i} -->
        X^j_{t'-\tau_j} is not already part of that path. The most lagged
        ancestor for every variable node in XYZ defines the maximum ancestral time
        lag, which is returned.

        Parameters
        ----------
        XYZ : list of tuples
            Of the form [(var, -tau)], where var specifies the variable
            index and tau the time lag.

        Returns
        -------
        max_lag : int
            Maximum time lag of non-repeating causal path ancestors.
        """

        def _repeating(link, seen_path):
            """Returns True if a link or its time-shifted version is already
            included in seen_path."""
            i, taui = link[0]
            j, tauj = link[1]

            for index, seen_link in enumerate(seen_path[:-1]):
                seen_i, seen_taui = seen_link
                seen_j, seen_tauj = seen_path[index + 1]

                if (i == seen_i and j == seen_j
                        and abs(tauj - taui) == abs(seen_tauj - seen_taui)):
                    return True

            return False

        N = len(self.links)

        # Initialize max. ancestral time lag for every N
        max_lag = 0

        # Not sure whether this is relevant!
        # if self.selection_vars is not None:
        #     for selection_var in self.selection_vars:
        #         # print(selection_var, conds)
        #         # print([(selection_var, -tau_sel) for tau_sel in range(0, max_lag + 1)])
        #         conds += [(selection_var, -tau_sel) for tau_sel in range(0, max_lag + 1)]

        # ancestors = dict([(y, []) for y in Y])

        for y in XYZ:
            j, tau = y   # tau <= 0
            max_lag = max(max_lag, abs(tau))

            causal_path = []
            queue = [(y, causal_path)]

            while queue:
                varlag, causal_path = queue.pop()
                causal_path = [varlag] + causal_path

                for node in self._get_lagged_parents(varlag):
                    i, tau = node

                    if (node not in causal_path):

                        if len(causal_path) == 1:
                            queue.append((node, causal_path))
                            continue

                        if (len(causal_path) > 1) and not _repeating((node, varlag), causal_path):

                            max_lag = max(max_lag, abs(tau))
                            queue.append((node, causal_path))

        if self.verbosity > 0:
            print("Max. non-repeated ancestral time lag: ", max_lag)

        # ATTENTION: this may not find correct common ancestors, therefore multiply by 10
        # until the problem is solved
        max_lag *= 10

        return max_lag

    def _get_descendants(self, W, children, max_lag, ignore_time_bounds=False):
        """Get descendants of nodes in W up to time t.

        Includes the nodes themselves.
        """

        descendants = set(W)

        for w in W:
            j, tau = w
            this_level = [w]
            while len(this_level) > 0:
                next_level = []
                for varlag in this_level:
                    for child in self._get_lagged_children(varlag, children):
                        i, tau = child
                        if (child not in descendants
                                and (-max_lag <= tau <= 0 or ignore_time_bounds)):
                            descendants = descendants.union(set([child]))
                            next_level.append(child)

                this_level = next_level

        return list(descendants)

    def _has_any_path(self, X, Y, conds, max_lag=None,
                      starts_with=None, ends_with=None,
                      directed=False,
                      forbidden_nodes=None,
                      only_non_causal_paths=False,
                      check_optimality_cond=False,
                      optimality_cond_des_YM=None,
                      optimality_cond_Y=None,
                      only_collider_paths_with_vancs=False,
                      XYS=None,
                      return_path=False):
        """Returns True if X and Y are d-connected by any open path.

        Does breadth-first search from both X and Y and meets in the middle.
        Paths are walked according to the d-separation rules where paths can
        only traverse motifs <-- v <-- or <-- v --> or --> v --> or
        --> [v] <-- where [.] indicates that v is conditioned on.
        Furthermore, path nodes (v, t) need to fulfill max_lag <= t <= 0
        and links cannot be traversed backwards.

        Parameters
        ----------
        X, Y : lists of tuples
            Of the form [(var, -tau)], where var specifies the variable
            index and tau the time lag.
        conds : list of tuples
            Of the form [(var, -tau)], where var specifies the variable
            index and tau the time lag.
        max_lag : int
            Maximum time lag.
        starts_with : {None, 'tail', 'arrowhead'}
            Whether to only consider paths starting with particular mark at X.
        ends_with : {None, 'tail', 'arrowhead'}
            Whether to only consider paths ending with particular mark at Y.
        """
        if max_lag is None:
            if conds is None:
                conds = []
            max_lag = self._get_maximum_possible_lag(X + Y + conds)

        def _walk_to_parents(v, fringe, this_path, other_path):
            """Helper function to update paths when walking to parents."""
            found_connection = False
            for w in self._get_lagged_parents(v,
                    only_non_causal_paths=only_non_causal_paths, X=X,
                    causal_children=causal_children):
                # Cannot walk into conditioned parents and
                # cannot walk beyond t or max_lag
                i, t = w

                if w == x and starts_with == 'arrowhead':
                    continue

                if w == y and ends_with == 'arrowhead':
                    continue

                if (w not in conds and w not in forbidden_nodes and
                        # (w, v) not in seen_links and
                        t <= 0 and abs(t) <= max_lag):
                    # if ((w, 'tail') not in this_path and
                    #     (w, None) not in this_path):
                    if (w not in this_path or
                            ('tail' not in this_path[w] and None not in this_path[w])):
                        if self.verbosity > 1:
                            print("Walk parent: %s --> %s  " % (v, w))
                        fringe.append((w, 'tail'))
                        if w not in this_path:
                            this_path[w] = {'tail': (v, 'arrowhead')}
                        else:
                            this_path[w]['tail'] = (v, 'arrowhead')
                        # seen_links.append((v, w))
                    # Determine whether X and Y are connected
                    # (w, None) indicates the start or end node X/Y
                    # if ((w, 'tail') in other_path
                    #     or (w, 'arrowhead') in other_path
                    #     or (w, None) in other_path):
                    if w in other_path:
                        found_connection = (w, 'tail')
                        if self.verbosity > 1:
                            print("Found connection: ", found_connection)
                        break
            return found_connection, fringe, this_path

        def _walk_to_children(v, fringe, this_path, other_path):
            """Helper function to update paths when walking to children."""
            found_connection = False
            for w in self._get_lagged_children(v, children,
                    only_non_causal_paths=only_non_causal_paths, X=X,
                    causal_children=causal_children):
                # You can also walk into conditioned children,
                # but cannot walk beyond t or max_lag
                i, t = w

                if w == x and starts_with == 'tail':
                    continue

                if w == y and ends_with == 'tail':
                    continue

                if (w not in forbidden_nodes and
                        # (w, v) not in seen_links and
                        t <= 0 and abs(t) <= max_lag):
                    # if ((w, 'arrowhead') not in this_path and
                    #     (w, None) not in this_path):
                    if (w not in this_path or
                            ('arrowhead' not in this_path[w] and None not in this_path[w])):
                        if self.verbosity > 1:
                            print("Walk child: %s --> %s  " % (v, w))
                        fringe.append((w, 'arrowhead'))
                        # this_path[(w, 'arrowhead')] = (v, 'tail')
                        if w not in this_path:
                            this_path[w] = {'arrowhead': (v, 'tail')}
                        else:
                            this_path[w]['arrowhead'] = (v, 'tail')
                        # seen_links.append((v, w))
                    # Determine whether X and Y are connected
                    # If the other_path contains w with a tail, then w must
                    # NOT be conditioned on. Alternatively, if the other_path
                    # contains w with an arrowhead, then w must be
                    # conditioned on.
                    # if (((w, 'tail') in other_path and w not in conds)
                    #     or ((w, 'arrowhead') in other_path and w in conds)
                    #     or (w, None) in other_path):
                    if w in other_path:
                        if (('tail' in other_path[w] and w not in conds) or
                                ('arrowhead' in other_path[w] and w in conds) or
                                (None in other_path[w])):
                            found_connection = (w, 'arrowhead')
                            if self.verbosity > 1:
                                print("Found connection: ", found_connection)
                            break
            return found_connection, fringe, this_path

        def _walk_fringe(this_level, fringe, this_path, other_path):
            """Helper function to walk each fringe, i.e., the path from X and Y,
            respectively."""
            found_connection = False

            if starts_with == 'arrowhead':
                if len(this_level) == 1 and this_level[0] == (x, None):
                    (found_connection, fringe,
                     this_path) = _walk_to_parents(x, fringe,
                                                   this_path, other_path)
                    return found_connection, fringe, this_path, other_path

            elif starts_with == 'tail':
                if len(this_level) == 1 and this_level[0] == (x, None):
                    (found_connection, fringe,
                     this_path) = _walk_to_children(x, fringe,
                                                    this_path, other_path)
                    return found_connection, fringe, this_path, other_path

            if ends_with == 'arrowhead':
                if len(this_level) == 1 and this_level[0] == (y, None):
                    (found_connection, fringe,
                     this_path) = _walk_to_parents(y, fringe,
                                                   this_path, other_path)
                    return found_connection, fringe, this_path, other_path

            elif ends_with == 'tail':
                if len(this_level) == 1 and this_level[0] == (y, None):
                    (found_connection, fringe,
                     this_path) = _walk_to_children(y, fringe,
                                                    this_path, other_path)
                    return found_connection, fringe, this_path, other_path

            for v, mark in this_level:
                if v in conds:
                    if (mark == 'arrowhead' or mark == None) and directed is False:
                        # Motif: --> [v] <--
                        # If standing on a condition and coming from an
                        # arrowhead, you can only walk into parents
                        (found_connection, fringe,
                         this_path) = _walk_to_parents(v, fringe,
                                                       this_path, other_path)
                        if found_connection: break
                else:
                    if only_collider_paths_with_vancs:
                        continue

                    if (mark == 'tail' or mark == None):
                        # Motif: <-- v <-- or <-- v -->
                        # If NOT standing on a condition and coming from
                        # a tail mark, you can walk into parents or
                        # children
                        (found_connection, fringe,
                         this_path) = _walk_to_parents(v, fringe,
                                                       this_path, other_path)
                        if found_connection: break

                        if not directed:
                            (found_connection, fringe,
                             this_path) = _walk_to_children(v, fringe,
                                                            this_path, other_path)
                            if found_connection: break

                    elif mark == 'arrowhead':
                        # Motif: --> v -->
                        # If NOT standing on a condition and coming from
                        # an arrowhead mark, you can only walk into
                        # children
                        (found_connection, fringe,
                         this_path) = _walk_to_children(v, fringe,
                                                        this_path, other_path)
                        if found_connection: break

                    if check_optimality_cond and v[0] in self.observed_vars:
                        # if v is not descendant of YM
                        # and v is not connected to Y given X OS\Cu
                        # print("v = ", v)
                        cond4a = v not in optimality_cond_des_YM
                        cond4b = not self._has_any_path(X=[v], Y=optimality_cond_Y,
                                                        conds=conds + X,
                                                        max_lag=None,
                                                        starts_with=None,
                                                        ends_with=None,
                                                        forbidden_nodes=None,  # list(prelim_Oset),
                                                        return_path=False)
                        # print(cond4a, cond4b)
                        if cond4a and cond4b:
                            (found_connection, fringe,
                             this_path) = _walk_to_parents(v, fringe,
                                                           this_path, other_path)
                            # print(found_connection)
                            if found_connection: break

            if self.verbosity > 1:
                print("Updated fringe: ", fringe)
            return found_connection, fringe, this_path, other_path

        def backtrace_path():
            """Helper function to get path from start point, end point,
            and connection found."""

            path = [found_connection[0]]
            node, mark = found_connection

            if 'tail' in pred[node]:
                mark = 'tail'
            else:
                mark = 'arrowhead'
            # print(found_connection)
            while path[-1] != x:
                # print(path, node, mark, pred[node])
                prev_node, prev_mark = pred[node][mark]
                path.append(prev_node)
                if prev_mark == 'arrowhead':
                    if prev_node not in conds:
                        # if pass_through_colliders:
                        #     if 'tail' in pred[prev_node] and pred[prev_node]['tail'] != (node, mark):
                        #         mark = 'tail'
                        #     else:
                        #         mark = 'arrowhead'
                        # else:
                        mark = 'tail'
                    elif prev_node in conds:
                        mark = 'arrowhead'
                elif prev_mark == 'tail':
                    if 'tail' in pred[prev_node] and pred[prev_node]['tail'] != (node, mark):
                        mark = 'tail'
                    else:
                        mark = 'arrowhead'
                node = prev_node

            path.reverse()

            node, mark = found_connection
            if 'tail' in succ[node]:
                mark = 'tail'
            else:
                mark = 'arrowhead'

            while path[-1] != y:
                next_node, next_mark = succ[node][mark]
                path.append(next_node)
                if next_mark == 'arrowhead':
                    if next_node not in conds:
                        # if pass_through_colliders:
                        #     if 'tail' in succ[next_node] and succ[next_node]['tail'] != (node, mark):
                        #         mark = 'tail'
                        #     else:
                        #         mark = 'arrowhead'
                        # else:
                        mark = 'tail'
                    elif next_node in conds:
                        mark = 'arrowhead'
                elif next_mark == 'tail':
                    if 'tail' in succ[next_node] and succ[next_node]['tail'] != (node, mark):
                        mark = 'tail'
                    else:
                        mark = 'arrowhead'
                node = next_node

            return path

        if conds is None:
            conds = []

        if forbidden_nodes is None:
            forbidden_nodes = []

        conds = [z for z in conds if z not in Y and z not in X]
        # print(X, Y, conds)

        if self.selection_vars is not None:
            for selection_var in self.selection_vars:
                conds += [(selection_var, -tau_sel)
                          for tau_sel in range(0, max_lag + 1)]

        N = len(self.links)
        children = self._get_children()

        if only_non_causal_paths:
            anc_Y_dict = self._get_non_blocked_ancestors(Y=Y, conds=None, mode='max_lag',
                                                         max_lag=max_lag)[0]
            # print(anc_Y_dict)
            anc_Y = []
            for y in Y:
                anc_Y += anc_Y_dict[y]
            des_X = self._get_descendants(X, children=children, max_lag=max_lag)
            mediators = set(anc_Y).intersection(set(des_X)) - set(Y) - set(X)

            causal_children = list(mediators) + Y
        else:
            causal_children = None

        if only_collider_paths_with_vancs:
            vancs_dict = self._get_non_blocked_ancestors(Y=XYS, conds=None, mode='max_lag',
                                                         max_lag=max_lag)[0]
            vancs = set()
            for xys in XYS:
                vancs = vancs.union(set(vancs_dict[xys]))
            vancs = list(vancs) + XYS
            conds = vancs
        # else:
        #     vancs = None

        # Iterate through nodes in X and Y
        for x in X:
            for y in Y:

                # seen_links = []
                # predecessor and successors in search
                # (x, None) where None indicates start/end nodes, later (v,
                # 'tail') or (w, 'arrowhead') indicate how a link ends at a node
                pred = {x: {None: None}}
                succ = {y: {None: None}}

                # initialize fringes, start with forward from X
                forward_fringe = [(x, None)]
                reverse_fringe = [(y, None)]

                while forward_fringe and reverse_fringe:
                    if len(forward_fringe) <= len(reverse_fringe):
                        if self.verbosity > 1:
                            print("Walk from X since len(X_fringe)=%d "
                                  "<= len(Y_fringe)=%d" % (len(forward_fringe),
                                                           len(reverse_fringe)))
                        this_level = forward_fringe
                        forward_fringe = []
                        (found_connection, forward_fringe, pred,
                         succ) = _walk_fringe(this_level, forward_fringe, pred,
                                              succ)

                        # print(pred)
                        if found_connection:
                            if return_path:
                                backtraced_path = backtrace_path()
                                return [(self.observed_vars.index(node[0]), node[1])
                                        for node in backtraced_path
                                        if node[0] in self.observed_vars]
                            else:
                                return True
                    else:
                        if self.verbosity > 1:
                            print("Walk from Y since len(X_fringe)=%d "
                                  "> len(Y_fringe)=%d" % (len(forward_fringe),
                                                          len(reverse_fringe)))
                        this_level = reverse_fringe
                        reverse_fringe = []
                        (found_connection, reverse_fringe, succ,
                         pred) = _walk_fringe(this_level, reverse_fringe, succ,
                                              pred)

                        if found_connection:
                            if return_path:
                                backtraced_path = backtrace_path()
                                return [(self.observed_vars.index(node[0]), node[1])
                                        for node in backtraced_path
                                        if node[0] in self.observed_vars]
                            else:
                                return True

                    if self.verbosity > 1:
                        print("X_fringe = %s \n" % str(forward_fringe) +
                              "Y_fringe = %s" % str(reverse_fringe))

        return False

    def _is_dsep(self, X, Y, Z, max_lag=None):
        """Returns whether X and Y are d-separated given Z in the graph.

        X, Y, Z are of the form (var, lag) for lag <= 0. D-separation is
        based on:

        1. Assessing the maximum time lag max_lag possible for any confounding
        path (see _get_maximum_possible_lag(...)).

        2. Using the time series graph truncated at max_lag we then test
        d-separation between X and Y conditional on Z using breadth-first
        search of non-blocked paths according to d-separation rules.

        Parameters
        ----------
        X, Y, Z : list of tuples
            List of variables chosen for current independence test.
        max_lag : int, optional (default: None)
            Used here to constrain the _is_dsep function to the graph
            truncated at max_lag instead of identifying the max_lag from
            ancestral search.

        Returns
        -------
        dseparated : bool, or path
            True if X and Y are d-separated given Z in the graph.
        """

        N = len(self.links)

        if self.verbosity > 0:
            print("Testing X=%s d-sep Y=%s given Z=%s in TSG" % (X, Y, Z))

        if Z is None:
            Z = []

        if max_lag is not None:
            # max_lags = dict([(j, max_lag) for j in range(N)])
            if self.verbosity > 0:
                print("Set max. time lag to: ", max_lag)
        else:
            max_lag = self._get_maximum_possible_lag(X + Y + Z)

        # Store overall max. lag
        self.max_lag = max_lag

        # _has_any_path is the main function that searches open paths
        any_path = self._has_any_path(X, Y, conds=Z, max_lag=max_lag)

        if any_path:
            dseparated = False
        else:
            dseparated = True

        return dseparated

    def check_shortest_path(self, X, Y, Z,
                            max_lag=None,  # compute_ancestors=False,
                            starts_with=None, ends_with=None,
                            forbidden_nodes=None,
                            directed=False,
                            only_non_causal_paths=False,
                            check_optimality_cond=False,
                            optimality_cond_des_YM=None,
                            optimality_cond_Y=None,
                            return_path=False):
        """Returns path between X and Y given Z in the graph.

        X, Y, Z are of the form (var, lag) for lag <= 0. D-separation is
        based on:

        1. Assessing maximum time lag max_lag of last ancestor of any X, Y, Z
        with non-blocked (by Z), non-repeating directed path towards X, Y, Z
        in the graph. 'non_repeating' means that an ancestor X^i_{t-\tau_i}
        with link X^i_{t-\tau_i} --> X^j_{t-\tau_j} is only included if
        X^i_{t'-\tau_i} --> X^j_{t'-\tau_j} for t' != t is not already part of
        the ancestors.

        2. Using the time series graph truncated at max_lag we then test
        d-separation between X and Y conditional on Z using breadth-first
        search of non-blocked paths according to d-separation rules including
        selection variables.

        Optionally only considers paths starting/ending with specific marks
        and makes available the ancestors up to max_lag of X, Y, Z. This may
        take a very long time, however.

        Parameters
        ----------
        X, Y, Z : list of tuples
            List of variables chosen for testing paths.
        max_lag : int, optional (default: None)
            Used here to constrain the has_path function to the graph
            truncated at max_lag instead of identifying the max_lag from
            ancestral search.
        compute_ancestors : bool
            Whether to also make available the ancestors for X, Y, Z as
            self.anc_all_x, self.anc_all_y, and self.anc_all_z, respectively.
        starts_with : {None, 'tail', 'arrowhead'}
            Whether to only consider paths starting with particular mark at X.
        ends_with : {None, 'tail', 'arrowhead'}
            Whether to only consider paths ending with particular mark at Y.

        Returns
        -------
        path : list or False
            Returns path or False if no path exists.
        """

        N = len(self.links)

        # Translate from observed_vars index to full variable set index
        X = [(self.observed_vars[x[0]], x[1]) for x in X]
        Y = [(self.observed_vars[y[0]], y[1]) for y in Y]
        Z = [(self.observed_vars[z[0]], z[1]) for z in Z]

        # print(X)
        # print(Y)
        # print(Z)

        if check_optimality_cond:
            optimality_cond_des_YM = [(self.observed_vars[x[0]], x[1])
                                      for x in optimality_cond_des_YM]
            optimality_cond_Y = [(self.observed_vars[x[0]], x[1])
                                 for x in optimality_cond_Y]

        # Get the array to test on
        X, Y, Z = self._check_XYZ(X, Y, Z)

        if self.verbosity > 0:
            print("Testing X=%s d-sep Y=%s given Z=%s in TSG" % (X, Y, Z))

        if max_lag is not None:
            # max_lags = dict([(j, max_lag) for j in range(N)])
            if self.verbosity > 0:
                print("Set max. time lag to: ", max_lag)
        else:
            max_lag = self._get_maximum_possible_lag(X + Y + Z)

        # Store overall max. lag
        self.max_lag = max_lag

        # _has_any_path is the main function that searches open paths
        any_path = self._has_any_path(X, Y, conds=Z, max_lag=max_lag,
                                      starts_with=starts_with, ends_with=ends_with,
                                      return_path=return_path,
                                      directed=directed,
                                      only_non_causal_paths=only_non_causal_paths,
                                      check_optimality_cond=check_optimality_cond,
                                      optimality_cond_des_YM=optimality_cond_des_YM,
                                      optimality_cond_Y=optimality_cond_Y,
                                      forbidden_nodes=forbidden_nodes)

        if any_path:
            if return_path:
                any_path_observed = [(self.observed_vars.index(node[0]), node[1])
                                     for node in any_path
                                     if node[0] in self.observed_vars]
            else:
                any_path_observed = True
        else:
            any_path_observed = False

        if self.verbosity > 0:
            print("_has_any_path     = ", any_path)
            print("_has_any_path_obs = ", any_path_observed)

        # if compute_ancestors:
        #     if self.verbosity > 0:
        #         print("Compute ancestors.")

        #     # Get ancestors up to maximum ancestral time lag incl. repeated
        #     # links
        #     self.anc_all_x, _ = self._get_non_blocked_ancestors(X, conds=Z,
        #                                         mode='max_lag', max_lag=max_lag)
        #     self.anc_all_y, _ = self._get_non_blocked_ancestors(Y, conds=Z,
        #                                         mode='max_lag', max_lag=max_lag)
        #     self.anc_all_z, _ = self._get_non_blocked_ancestors(Z, conds=Z,
        #                                         mode='max_lag', max_lag=max_lag)

        return any_path_observed

    def run_test(self, X, Y, Z=None, tau_max=0, cut_off='2xtau_max', alpha_or_thres=None,
                 verbosity=0):
        """Perform oracle conditional independence test.

        Calls the d-separation function.

        Parameters
        ----------
        X, Y, Z : list of tuples
            X, Y, Z are of the form [(var, -tau)], where var specifies the
            variable index in the observed_vars and tau the time lag.
        tau_max : int, optional (default: 0)
            Not used here.
        cut_off : {'2xtau_max', 'max_lag', 'max_lag_or_tau_max'}
            Not used here.
        alpha_or_thres : float
            Not used here.

        Returns
        -------
        val, pval : Tuple of floats
            The test statistic value and the p-value.
        """

        if Z is None:
            Z = []

        # Translate from observed_vars index to full variable set index
        X = [(self.observed_vars[x[0]], x[1]) for x in X]
        Y = [(self.observed_vars[y[0]], y[1]) for y in Y]
        Z = [(self.observed_vars[z[0]], z[1]) for z in Z]

        # Get the array to test on
        X, Y, Z = self._check_XYZ(X, Y, Z)

        if not str((X, Y, Z)) in self.dsepsets:
            self.dsepsets[str((X, Y, Z))] = self._is_dsep(X, Y, Z)

        if self.dsepsets[str((X, Y, Z))]:
            val = 0.
            pval = 1.
            dependent = False
        else:
            val = 1.
            pval = 0.
            dependent = True

        # Saved here, but not currently used
        self.ci_results[(tuple(X), tuple(Y), tuple(Z))] = (val, pval, dependent)

        if verbosity > 1:
            self._print_cond_ind_results(val=val, pval=pval, cached=False,
                                         conf=None)
        # Return the value and the pvalue
        if alpha_or_thres is None:
            return val, pval
        else:
            return val, pval, dependent
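
    # Usage sketch (illustrative, not part of the original module): because the
    # test is an oracle, run_test returns (val, pval) = (0., 1.) when X and Y
    # are d-separated given Z, and (1., 0.) otherwise; if alpha_or_thres is
    # given, a boolean "dependent" flag is returned as well. The links below
    # are hypothetical.
    #
    #     oracle = OracleCI(links={0: [(0, -1)], 1: [(0, -1)], 2: [(1, -1)]})
    #     val, pval = oracle.run_test(X=[(0, -2)], Y=[(2, 0)], Z=[(1, -1)])
    #     # -> val = 0., pval = 1. since conditioning on X1(t-1) blocks the
    #     #    only path from X0(t-2) to X2(t)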

    def get_measure(self, X, Y, Z=None, tau_max=0):
        """Returns dependence measure.

        Returns 0 if X and Y are d-separated given Z in the graph and 1 else.

        Parameters
        ----------
        X, Y [, Z] : list of tuples
            X, Y, Z are of the form [(var, -tau)], where var specifies the
            variable index in the observed_vars and tau the time lag.

        tau_max : int, optional (default: 0)
            Maximum time lag. This may be used to make sure that estimates for
            different lags in X, Z, all have the same sample size.

        Returns
        -------
        val : float
            The test statistic value.

        """

        # Translate from observed_vars index to full variable set index
        X = [(self.observed_vars[x[0]], x[1]) for x in X]
        Y = [(self.observed_vars[y[0]], y[1]) for y in Y]
        Z = [(self.observed_vars[z[0]], z[1]) for z in Z]

        # Check XYZ
        X, Y, Z = self._check_XYZ(X, Y, Z)

        if not str((X, Y, Z)) in self.dsepsets:
            self.dsepsets[str((X, Y, Z))] = self._is_dsep(X, Y, Z)

        if self.dsepsets[str((X, Y, Z))]:
            return 0.
        else:
            return 1.

    def _print_cond_ind_results(self, val, pval=None, cached=None, conf=None):
        """Print results from conditional independence test.

        Parameters
        ----------
        val : float
            Test statistic value.
        pval : float, optional (default: None)
            p-value
        conf : tuple of floats, optional (default: None)
            Confidence bounds.
        """
        printstr = "        val = %.3f" % (val)
        if pval is not None:
            printstr += " | pval = %.5f" % (pval)
        if conf is not None:
            printstr += " | conf bounds = (%.3f, %.3f)" % (
                conf[0], conf[1])
        if cached is not None:
            printstr += " %s" % ({0: "", 1: "[cached]"}[cached])

        print(printstr)

    def get_model_selection_criterion(self, j, parents, tau_max=0):
        """
        Base class assumption that this is not implemented. Concrete classes
        should override when possible.
        """
        raise NotImplementedError("Model selection not" +
                                  " implemented for %s" % self.measure)

    def _reverse_patt(self, patt):
        """Inverts a link pattern"""

        if patt == "":
            return ""

        left_mark, middle_mark, right_mark = patt[0], patt[1], patt[2]
        if left_mark == "<":
            new_right_mark = ">"
        else:
            new_right_mark = left_mark
        if right_mark == ">":
            new_left_mark = "<"
        else:
            new_left_mark = right_mark

        return new_left_mark + middle_mark + new_right_mark
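
    # Illustration (not part of the original module): _reverse_patt("-->")
    # returns "<--", _reverse_patt("<->") returns "<->", and the empty pattern
    # "" is returned unchanged; this is used to check lag-zero consistency of
    # graph[i, j, 0] against graph[j, i, 0] below.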

    def get_links_from_graph(self, graph):
        """
        Constructs links_coeffs dictionary, observed_vars,
        and selection_vars from graph array (MAG or DAG).

        In the case of MAGs, for every <-> or --- link further
        latent and selection variables, respectively, are added.
        This corresponds to a canonical DAG (Richardson Spirtes 2002).

        For ADMGs "---" are not supported, but also links of type "+->"
        exist, which corresponds to having both "-->" and "<->".

        Can be used to evaluate d-separation in MAG/DAGs.
        """

        if "U3" not in str(graph.dtype):
            raise ValueError("graph must be of type '<U3'!")

        if self.graph_is_mag:
            edge_types = ["-->", "<--", "<->", "---"]
        else:
            edge_types = ["-->", "<--", "<->", "+->", "<-+"]  # , "--+", "+--"]

        N, N, tau_maxplusone = graph.shape
        tau_max = tau_maxplusone - 1

        observed_vars = list(range(N))

        selection_vars = []

        links = {j: [] for j in observed_vars}

        # Add further latent variables to accommodate <-> and --- links
        latent_index = N
        for i, j, tau in zip(*np.where(graph)):

            edge_type = graph[i, j, tau]

            if edge_type not in edge_types:
                raise ValueError(
                    "Links can only be in %s " % str(edge_types)
                )

            if tau == 0:
                if edge_type != self._reverse_patt(graph[j, i, 0]):
                    raise ValueError(
                        "graph needs to have consistent lag-zero patterns (eg"
                        " graph[i,j,0]='-->' requires graph[j,i,0]='<--')"
                    )

                # Consider contemporaneous links only once
                if j > i:
                    continue

            # Restrict lagged links
            else:
                if edge_type not in ["-->", "<->", "---", "+->"]:  # , "--+"]:
                    raise ValueError(
                        "Lagged links can only be in ['-->', '<->', '---', '+->']"
                    )

            if edge_type == "-->":
                links[j].append((i, -tau))
            elif edge_type == "<--":
                links[i].append((j, -tau))
            elif edge_type == "<->":
                links[latent_index] = []
                links[i].append((latent_index, 0))
                links[j].append((latent_index, -tau))
                latent_index += 1
            elif edge_type == "---":
                links[latent_index] = []
                selection_vars.append(latent_index)
                links[latent_index].append((i, -tau))
                links[latent_index].append((j, 0))
                latent_index += 1
            elif edge_type == "+->":
                links[j].append((i, -tau))
                links[latent_index] = []
                links[i].append((latent_index, 0))
                links[j].append((latent_index, -tau))
                latent_index += 1
            elif edge_type == "<-+":
                links[i].append((j, -tau))
                links[latent_index] = []
                links[i].append((latent_index, 0))
                links[j].append((latent_index, -tau))
                latent_index += 1
            # elif edge_type == "+--":
            #     links[i].append((j, -tau))
            #     links[latent_index] = []
            #     selection_vars.append(latent_index)
            #     links[latent_index].append((i, -tau))
            #     links[latent_index].append((j, 0))
            #     latent_index += 1
            # elif edge_type == "--+":
            #     links[j].append((i, -tau))
            #     links[latent_index] = []
            #     selection_vars.append(latent_index)
            #     links[latent_index].append((i, -tau))
            #     links[latent_index].append((j, 0))
            #     latent_index += 1

        return links, observed_vars, selection_vars
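
    # Illustration (hypothetical values, not part of the original module): a
    # lag-1 bidirected edge graph[0, 1, 1] = "<->" between X0(t-1) and X1(t)
    # becomes a canonical-DAG latent confounder: a new variable with index
    # latent_index = N is appended, links[0] receives (latent_index, 0),
    # links[1] receives (latent_index, -1), and the latent variable is not
    # listed in observed_vars.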

    def _get_minmax_lag(self, links):
        """Helper function to retrieve tau_min and tau_max from links
        """

        N = len(links)

        # Get maximum time lag
        min_lag = np.inf
        max_lag = 0
        for j in range(N):
            for link_props in links[j]:
                if len(link_props) == 3:
                    i, lag = link_props[0]
                    coeff = link_props[1]
                else:
                    i, lag = link_props
                    coeff = 1.
                # func = link_props[2]
                if coeff != 0.:
                    min_lag = min(min_lag, abs(lag))
                    max_lag = max(max_lag, abs(lag))
        return min_lag, max_lag
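
    # Illustration (hypothetical values, not part of the original module):
    # _get_minmax_lag({0: [(0, -1)], 1: [(0, -2), (1, -1)]}) returns (1, 2),
    # i.e. the smallest and largest absolute lag over all nonzero links.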
|
|
1334
|
+
|
|
1335
|
+
+    def get_graph_from_links(self, tau_max=None):
+        """
+        Constructs graph (DAG or MAG or ADMG) from links, observed_vars,
+        and selection_vars.
+
+        For ADMGs uses the Latent projection operation (Pearl 2009).
+        """
+
+        # TODO: use MAG from DAG construction procedure (lecture notes)
+        # issues with tau_max?
+        if self.graph_is_mag is False and len(self.selection_vars) > 0:
+            raise ValueError("ADMGs do not support selection_vars.")
+
+        N_all = len(self.links)
+
+        # If tau_max is None, compute from links_coeffs
+        _, max_lag_links = self._get_minmax_lag(self.links)
+        if tau_max is None:
+            tau_max = max_lag_links
+        else:
+            if max_lag_links > tau_max:
+                raise ValueError("tau_max must be >= maximum lag in links_coeffs; choose tau_max=None")
+
+        # print("max_lag_links ", max_lag_links)
+
+        N = len(self.observed_vars)
+
+        # Init graph
+        graph = np.zeros((N, N, tau_max + 1), dtype='<U3')
+        graph[:] = ""
+        # We enumerate the observed variables with (i, j), which refer to the
+        # indices in the MAG graph, while x, y iterate through the variables
+        # of the underlying DAG
+
+        # Loop over the observed variables
+        for j, y in enumerate(self.observed_vars):
+            for i, x in enumerate(self.observed_vars):
+                for tau in range(0, tau_max + 1):
+                    if (x, -tau) != (y, 0):
+
+                        if self.graph_is_mag:
+                            dag_anc_y, _ = self._get_non_blocked_ancestors(Y=[(y, 0)], conds=None,
+                                                                           mode='max_lag',
+                                                                           max_lag=tau_max)
+                            # Only consider observed ancestors
+                            mag_anc_y = [anc for anc in dag_anc_y[(y, 0)]
+                                         if anc[0] in self.observed_vars]
+
+                            dag_anc_x, _ = self._get_non_blocked_ancestors(Y=[(x, -tau)],
+                                                                           conds=None, mode='max_lag',
+                                                                           max_lag=tau_max)
+                            # Only consider observed ancestors
+                            mag_anc_x = [anc for anc in dag_anc_x[(x, -tau)]
+                                         if anc[0] in self.observed_vars]
+
+                            # Add selection variable ancestors
+                            dag_anc_s = set()
+                            for s in self.selection_vars:
+                                dag_anc_s_here, _ = self._get_non_blocked_ancestors(Y=[(s, 0)],
+                                                                                    conds=None, mode='max_lag',
+                                                                                    max_lag=tau_max)
+                                dag_anc_s = dag_anc_s.union(set(dag_anc_s_here[(s, 0)]))
+
+                            dag_anc_s = list(dag_anc_s)
+                            # Only consider observed ancestors
+                            mag_anc_s = [anc for anc in dag_anc_s
+                                         if anc[0] in self.observed_vars]
+
+                            Z = set([z for z in mag_anc_y + mag_anc_x + mag_anc_s if z != (y, 0) and z != (x, -tau)])
+                            Z = list(Z)
+
+                            separated = self._is_dsep(X=[(x, -tau)], Y=[(y, 0)], Z=Z, max_lag=None)
+
+                            # If X and Y are d-connected given Z, mark a link
+                            if not separated:
+                                # (i, -tau) --> j
+                                if (x, -tau) in dag_anc_y[(y, 0)] + dag_anc_s and (y, 0) not in dag_anc_x[(x, -tau)] + dag_anc_s:
+                                    graph[i, j, tau] = "-->"
+                                    if tau == 0:
+                                        graph[j, i, 0] = "<--"
+
+                                elif (x, -tau) not in dag_anc_y[(y, 0)] + dag_anc_s and (y, 0) not in dag_anc_x[(x, -tau)] + dag_anc_s:
+                                    graph[i, j, tau] = "<->"
+                                    if tau == 0:
+                                        graph[j, i, 0] = "<->"
+
+                                elif (x, -tau) in dag_anc_y[(y, 0)] + dag_anc_s and (y, 0) in dag_anc_x[(x, -tau)] + dag_anc_s:
+                                    graph[i, j, tau] = "---"
+                                    if tau == 0:
+                                        graph[j, i, 0] = "---"
+                        else:
+                            if tau == 0 and j >= i:
+                                continue
+                            # edge_types = ["-->", "<->", "+->"]
+                            # Latent projection operation:
+                            # (i) ADMG contains i --> j iff there is a directed path x --> ... --> y on which
+                            #     every non-endpoint vertex is a hidden variable (= not in observed_vars)
+                            # (ii) ADMG contains i <-> j iff there exists a path of the form x <-- ... --> y on
+                            #      which every non-endpoint vertex is a non-collider AND in L (= not in observed_vars)
+                            observed_varslags = set([(v, -lag) for v in self.observed_vars
+                                                     for lag in range(0, tau_max + 1)]) - set([(x, -tau), (y, 0)])
+                            cond_one_xy = self._has_any_path(X=[(x, -tau)], Y=[(y, 0)],
+                                                             conds=[],
+                                                             max_lag=None,
+                                                             starts_with='tail',
+                                                             ends_with='arrowhead',
+                                                             directed=True,
+                                                             forbidden_nodes=list(observed_varslags),
+                                                             return_path=False)
+                            if tau == 0:
+                                cond_one_yx = self._has_any_path(X=[(y, 0)], Y=[(x, 0)],
+                                                                 conds=[],
+                                                                 max_lag=None,
+                                                                 starts_with='tail',
+                                                                 ends_with='arrowhead',
+                                                                 directed=True,
+                                                                 forbidden_nodes=list(observed_varslags),
+                                                                 return_path=False)
+                            else:
+                                cond_one_yx = False
+                            cond_two = self._has_any_path(X=[(x, -tau)], Y=[(y, 0)],
+                                                          conds=[],
+                                                          max_lag=None,
+                                                          starts_with='arrowhead',
+                                                          ends_with='arrowhead',
+                                                          directed=False,
+                                                          forbidden_nodes=list(observed_varslags),
+                                                          return_path=False)
+                            if cond_one_xy and cond_one_yx:
+                                raise ValueError("Cyclic graph!")
+                            # print((x, -tau), y, cond_one_xy, cond_one_yx, cond_two)
+
+                            # Only (i) holds: i --> j
+                            if cond_one_xy and not cond_two:
+                                graph[i, j, tau] = "-->"
+                                if tau == 0:
+                                    graph[j, i, 0] = "<--"
+                            elif cond_one_yx and not cond_two:
+                                graph[i, j, tau] = "<--"
+                                if tau == 0:
+                                    graph[j, i, 0] = "-->"
+
+                            # Only (ii) holds: i <-> j
+                            elif not cond_one_xy and not cond_one_yx and cond_two:
+                                graph[i, j, tau] = "<->"
+                                if tau == 0:
+                                    graph[j, i, 0] = "<->"
+
+                            # Both (i) and (ii) hold: i +-> j
+                            elif cond_one_xy and cond_two:
+                                graph[i, j, tau] = "+->"
+                                if tau == 0:
+                                    graph[j, i, 0] = "<-+"
+                            elif cond_one_yx and cond_two:
+                                graph[i, j, tau] = "<-+"
+                                if tau == 0:
+                                    graph[j, i, 0] = "+->"
+                            # print((i, -tau), j, cond_one_xy, cond_one_yx, cond_two)
+
+        return graph
+
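As a usage sketch of the ADMG branch (the `__main__` block at the end of the file covers the MAG case): a contemporaneous confounder that is dropped from `observed_vars` should be projected onto a bidirected edge between its observed children. The expected entries below are read off rules (i) and (ii) in the code, not recorded output of the package:

    # Underlying DAG: 1 --> 0 and 1 --> 2 (both contemporaneous), with
    # variable 1 left out of observed_vars, i.e. treated as latent.
    links = {0: [(1, 0)], 1: [], 2: [(1, 0)]}
    oracle = OracleCI(links=links, observed_vars=[0, 2], graph_is_mag=False)
    # There is no directed path between the observed variables through the
    # latent (rule (i)), but 0 <-- 1 --> 2 satisfies rule (ii), so one expects
    # graph[0, 1, 0] == graph[1, 0, 0] == "<->"
    print(oracle.graph[:, :, 0])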
+    def get_confidence(self, X, Y, Z=None, tau_max=0):
+        """For compatibility with PCMCI.
+
+        Returns
+        -------
+        None
+        """
+        return None
+
+if __name__ == '__main__':
+
+    import tigramite.plotting as tp
+    from matplotlib import pyplot as plt
+
+    def lin_f(x): return x
+
+    # Define the stationary DAG
+    links = {0: [(0, -3), (1, 0)], 1: [(2, -2)], 2: [(1, -2)]}
+    observed_vars = [0, 1, 2]
+
+    oracle = OracleCI(links=links,
+                      observed_vars=observed_vars,
+                      graph_is_mag=True,
+                      # selection_vars=selection_vars,
+                      # verbosity=2
+                      )
+    graph = oracle.graph
+    print(graph[:, :, 0])
+
+    tp.plot_time_series_graph(graph=graph, var_names=None, figsize=(5, 5),
+                              save_name="tsg.pdf")
+
+    X = [(0, 0)]
+    Y = [(2, 0)]
+    Z = []
+    # node = (3, 0)
+    # prelim_Oset = set([(3, 0)])
+    # S = set([])
+    # collider_path_nodes = set([])
+    path = oracle._has_any_path(X=X, Y=Y,
+                                conds=Z,
+                                max_lag=8,
+                                starts_with='arrowhead',
+                                ends_with='arrowhead',
+                                forbidden_nodes=None,
+                                return_path=True)
+    print(path)
+
+    print("-------------------------------")
+    print(oracle._get_maximum_possible_lag(X + Z))  # (X = X, Y = Y, Z = Z))
+
+    # cond_ind_test = OracleCI(graph=graph)
+    # links, observed_vars, selection_vars = cond_ind_test.get_links_from_graph(graph)
+    # print("{")
+    # for j in links.keys():
+    #     parents = repr([(p, 'coeff', 'lin_f') for p in links[j]])
+    #     print(f"{j: 1d}" ":" f"{parents:s},")
+    # print(repr(observed_vars))
+    # cond_ind_test = OracleCI(graph=graph, verbosity=2)
+
+    # X = [(0, 0)]
+    # Y = [(2, 0)]
+    # Z = [(7, 0), (3, 0), (6, 0), (5, 0), (4, 0)]  # (1, -3), (1, -2), (0, -2), (0, -1), (0, -3)]
+    # # (j, -2) for j in range(N)] + [(j, 0) for j in range(N)]
+
+    # # print(oracle._get_non_blocked_ancestors(Z, Z=None, mode='max_lag',
+    # #                                         max_lag=2))
+    # # cond_ind_test = OracleCI(links, observed_vars=observed_vars, verbosity=2)
+
+    # print(cond_ind_test.get_shortest_path(X=X, Y=Y, Z=Z,
+    #                                       max_lag=None, compute_ancestors=False,
+    #                                       backdoor=True))
+
+    # anc_x = None  # oracle.anc_all_x[X[0]]
+    # anc_y = None  # oracle.anc_all_y[Y[0]]
+    # anc_xy = None  # []
+    # # # for z in Z:
+    # # #     anc_xy += oracle.anc_all_z[z]
+
+    # fig, ax = tp.plot_tsg(links,
+    #                       X=[(observed_vars[x[0]], x[1]) for x in X],
+    #                       Y=[(observed_vars[y[0]], y[1]) for y in Y],
+    #                       Z=[(observed_vars[z[0]], z[1]) for z in Z],
+    #                       anc_x=anc_x, anc_y=anc_y,
+    #                       anc_xy=anc_xy)
+
+    # fig.savefig("/home/rung_ja/Downloads/tsg.pdf")
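The demo above issues a path query directly; the same oracle can also answer d-separation questions about the underlying stationary DAG through `_is_dsep`, the call that `get_graph_from_links` itself relies on. A minimal sketch continuing the `__main__` example, with the expected outcomes read off the DAG by hand rather than recorded from a run:

    # In the demo DAG, 2 at lag 2 drives 1, and 1 contemporaneously drives 0,
    # so (2, -2) and (0, 0) should be dependent unconditionally and
    # separated once (1, 0) is conditioned on.
    print(oracle._is_dsep(X=[(2, -2)], Y=[(0, 0)], Z=[], max_lag=None))        # expected: False
    print(oracle._is_dsep(X=[(2, -2)], Y=[(0, 0)], Z=[(1, 0)], max_lag=None))  # expected: True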