tigramite-fast 5.2.10.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38) hide show
  1. tigramite/__init__.py +0 -0
  2. tigramite/causal_effects.py +1525 -0
  3. tigramite/causal_mediation.py +1592 -0
  4. tigramite/data_processing.py +1574 -0
  5. tigramite/graphs.py +1509 -0
  6. tigramite/independence_tests/LBFGS.py +1114 -0
  7. tigramite/independence_tests/__init__.py +0 -0
  8. tigramite/independence_tests/cmiknn.py +661 -0
  9. tigramite/independence_tests/cmiknn_mixed.py +1397 -0
  10. tigramite/independence_tests/cmisymb.py +286 -0
  11. tigramite/independence_tests/gpdc.py +664 -0
  12. tigramite/independence_tests/gpdc_torch.py +820 -0
  13. tigramite/independence_tests/gsquared.py +190 -0
  14. tigramite/independence_tests/independence_tests_base.py +1310 -0
  15. tigramite/independence_tests/oracle_conditional_independence.py +1582 -0
  16. tigramite/independence_tests/pairwise_CI.py +383 -0
  17. tigramite/independence_tests/parcorr.py +369 -0
  18. tigramite/independence_tests/parcorr_mult.py +485 -0
  19. tigramite/independence_tests/parcorr_wls.py +451 -0
  20. tigramite/independence_tests/regressionCI.py +403 -0
  21. tigramite/independence_tests/robust_parcorr.py +403 -0
  22. tigramite/jpcmciplus.py +966 -0
  23. tigramite/lpcmci.py +3649 -0
  24. tigramite/models.py +2257 -0
  25. tigramite/pcmci.py +3935 -0
  26. tigramite/pcmci_base.py +1218 -0
  27. tigramite/plotting.py +4735 -0
  28. tigramite/rpcmci.py +467 -0
  29. tigramite/toymodels/__init__.py +0 -0
  30. tigramite/toymodels/context_model.py +261 -0
  31. tigramite/toymodels/non_additive.py +1231 -0
  32. tigramite/toymodels/structural_causal_processes.py +1201 -0
  33. tigramite/toymodels/surrogate_generator.py +319 -0
  34. tigramite_fast-5.2.10.1.dist-info/METADATA +182 -0
  35. tigramite_fast-5.2.10.1.dist-info/RECORD +38 -0
  36. tigramite_fast-5.2.10.1.dist-info/WHEEL +5 -0
  37. tigramite_fast-5.2.10.1.dist-info/licenses/license.txt +621 -0
  38. tigramite_fast-5.2.10.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1582 @@
1
+ """Tigramite causal discovery for time series."""
2
+
3
+ # Author: Jakob Runge <jakob@jakob-runge.com>
4
+ #
5
+ # License: GNU General Public License v3.0
6
+
7
+ from __future__ import print_function
8
+ import numpy as np
9
+
10
+ from collections import defaultdict, OrderedDict
11
+ from itertools import combinations, permutations
12
+
13
+
14
+ class OracleCI:
15
+ r"""Oracle of conditional independence test X _|_ Y | Z given a graph.
16
+
17
+ Class around link_coeff causal ground truth. X _|_ Y | Z is based on
18
+ assessing whether X and Y are d-separated given Z in the graph.
19
+
20
+ Class can be used just like a Tigramite conditional independence class
21
+ (e.g., ParCorr). The main use is for unit testing of PCMCI methods.
22
+
23
+ Parameters
24
+ ----------
25
+ graph : array of shape [N, N, tau_max+1]
26
+ Causal graph.
27
+ links : dict
28
+ Dictionary of form {0:[(0, -1), ...], 1:[...], ...}.
29
+ Alternatively can also digest {0: [((0, -1), coeff, func)], ...}.
30
+ observed_vars : None or list, optional (default: None)
31
+ Subset of keys in links defining which variables are
32
+ observed. If None, then all variables are observed.
33
+ selection_vars : None or list, optional (default: None)
34
+ Subset of keys in links defining which variables are
35
+ selected (= always conditioned on at every time lag).
36
+ If None, then no variables are selected.
37
+ verbosity : int, optional (default: 0)
38
+ Level of verbosity.
39
+ """
40
+
41
+ # documentation
42
+ @property
43
+ def measure(self):
44
+ """
45
+ Concrete property to return the measure of the independence test
46
+ """
47
+ return self._measure
48
+
49
+ def __init__(self,
50
+ links=None,
51
+ observed_vars=None,
52
+ selection_vars=None,
53
+ graph=None,
54
+ graph_is_mag=False,
55
+ tau_max=None,
56
+ verbosity=0):
57
+
58
+ self.tau_max = tau_max
59
+ self.graph_is_mag = graph_is_mag
60
+
61
+ if links is None:
62
+ if graph is None:
63
+ raise ValueError("Either links or graph must be specified!")
64
+ else:
65
+ # Get canonical DAG from graph, potentially interpreted as MAG
66
+ # self.tau_max = graph.shape[2]
67
+ (links,
68
+ observed_vars,
69
+ selection_vars) = self.get_links_from_graph(graph)
70
+ # # TODO make checks and tau_max?
71
+ # self.graph = graph
72
+
73
+
74
+ self.verbosity = verbosity
75
+ self._measure = 'oracle_ci'
76
+ self.confidence = None
77
+ self.links = links
78
+ self.N = len(links)
79
+ # self.tau_max = self._get_minmax_lag(self.links)
80
+
81
+ # Initialize already computed dsepsets of X, Y, Z
82
+ self.dsepsets = {}
83
+
84
+ # Initialize observed vars
85
+ self.observed_vars = observed_vars
86
+ if self.observed_vars is None:
87
+ self.observed_vars = range(self.N)
88
+ else:
89
+ if not set(self.observed_vars).issubset(set(range(self.N))):
90
+ raise ValueError("observed_vars must be subset of range(N).")
91
+ if self.observed_vars != sorted(self.observed_vars):
92
+ raise ValueError("observed_vars must ordered.")
93
+ if len(self.observed_vars) != len(set(self.observed_vars)):
94
+ raise ValueError("observed_vars must not contain duplicates.")
95
+
96
+ self.selection_vars = selection_vars
97
+
98
+ if self.selection_vars is not None:
99
+ if not set(self.selection_vars).issubset(set(range(self.N))):
100
+ raise ValueError("selection_vars must be subset of range(N).")
101
+ if self.selection_vars != sorted(self.selection_vars):
102
+ raise ValueError("selection_vars must ordered.")
103
+ if len(self.selection_vars) != len(set(self.selection_vars)):
104
+ raise ValueError("selection_vars must not contain duplicates.")
105
+ else:
106
+ self.selection_vars = []
107
+
108
+ # ToDO: maybe allow to use user-tau_max, otherwise deduced from links
109
+ self.graph = self.get_graph_from_links(tau_max=tau_max)
110
+
111
+ self.ci_results = {}
112
+
113
def set_dataframe(self, dataframe):
    """No-op placeholder: the argument is ignored and nothing is stored."""
    return None
116
+
117
+ def _check_XYZ(self, X, Y, Z):
118
+ """Checks variables X, Y, Z.
119
+
120
+ Parameters
121
+ ----------
122
+ X, Y, Z : list of tuples
123
+ For a dependence measure I(X;Y|Z), Y is of the form [(varY, 0)],
124
+ where var specifies the variable index. X typically is of the form
125
+ [(varX, -tau)] with tau denoting the time lag and Z can be
126
+ multivariate [(var1, -lag), (var2, -lag), ...] .
127
+
128
+ Returns
129
+ -------
130
+ X, Y, Z : tuple
131
+ Cleaned X, Y, Z.
132
+ """
133
+ # Get the length in time and the number of nodes
134
+ N = self.N
135
+
136
+ # Remove duplicates in X, Y, Z
137
+ X = list(OrderedDict.fromkeys(X))
138
+ Y = list(OrderedDict.fromkeys(Y))
139
+ Z = list(OrderedDict.fromkeys(Z))
140
+
141
+ # If a node in Z occurs already in X or Y, remove it from Z
142
+ Z = [node for node in Z if (node not in X) and (node not in Y)]
143
+
144
+ # Check that all lags are non-positive and indices are in [0,N-1]
145
+ XYZ = X + Y + Z
146
+ dim = len(XYZ)
147
+ # Ensure that XYZ makes sense
148
+ if np.array(XYZ).shape != (dim, 2):
149
+ raise ValueError("X, Y, Z must be lists of tuples in format"
150
+ " [(var, -lag),...], eg., [(2, -2), (1, 0), ...]")
151
+ if np.any(np.array(XYZ)[:, 1] > 0):
152
+ raise ValueError("nodes are %s, " % str(XYZ) +
153
+ "but all lags must be non-positive")
154
+ if (np.any(np.array(XYZ)[:, 0] >= N)
155
+ or np.any(np.array(XYZ)[:, 0] < 0)):
156
+ raise ValueError("var indices %s," % str(np.array(XYZ)[:, 0]) +
157
+ " but must be in [0, %d]" % (N - 1))
158
+ if np.all(np.array(Y)[:, 1] != 0):
159
+ raise ValueError("Y-nodes are %s, " % str(Y) +
160
+ "but one of the Y-nodes must have zero lag")
161
+
162
+ return (X, Y, Z)
163
+
164
+ def _get_lagged_parents(self, var_lag, exclude_contemp=False,
165
+ only_non_causal_paths=False, X=None, causal_children=None):
166
+ """Helper function to yield lagged parents for var_lag from
167
+ self.links_coeffs.
168
+
169
+ Parameters
170
+ ----------
171
+ var_lag : tuple
172
+ Tuple of variable and lag which is assumed <= 0.
173
+ exclude_contemp : bool
174
+ Whether contemporaneous links should be exluded.
175
+
176
+ Yields
177
+ ------
178
+ Next lagged parent.
179
+ """
180
+
181
+ var, lag = var_lag
182
+
183
+ for link_props in self.links[var]:
184
+ if len(link_props) == 3:
185
+ i, tau = link_props[0]
186
+ coeff = link_props[1]
187
+ else:
188
+ i, tau = link_props
189
+ coeff = 1.
190
+ if coeff != 0.:
191
+ if not (exclude_contemp and lag == 0):
192
+ if only_non_causal_paths:
193
+ if not ((i, lag + tau) in X and var_lag in causal_children):
194
+ yield (i, lag + tau)
195
+ else:
196
+ yield (i, lag + tau)
197
+
198
+ def _get_children(self):
199
+ """Helper function to get children from links.
200
+
201
+ Note that for children the lag is positive.
202
+
203
+ Returns
204
+ -------
205
+ children : dict
206
+ Dictionary of form {0:[(0, 1), (3, 0), ...], 1:[], ...}.
207
+ """
208
+
209
+ N = len(self.links)
210
+ children = dict([(j, []) for j in range(N)])
211
+
212
+ for j in range(N):
213
+ for link_props in self.links[j]:
214
+ if len(link_props) == 3:
215
+ i, tau = link_props[0]
216
+ coeff = link_props[1]
217
+ else:
218
+ i, tau = link_props
219
+ coeff = 1.
220
+ if coeff != 0.:
221
+ children[i].append((j, abs(tau)))
222
+
223
+ return children
224
+
225
+ def _get_lagged_children(self, var_lag, children, exclude_contemp=False,
226
+ only_non_causal_paths=False, X=None, causal_children=None):
227
+ """Helper function to yield lagged children for var_lag from children.
228
+
229
+ Parameters
230
+ ----------
231
+ var_lag : tuple
232
+ Tuple of variable and lag which is assumed <= 0.
233
+ children : dict
234
+ Dictionary of form {0:[(0, 1), (3, 0), ...], 1:[], ...}.
235
+ exclude_contemp : bool
236
+ Whether contemporaneous links should be exluded.
237
+
238
+ Yields
239
+ ------
240
+ Next lagged child.
241
+ """
242
+
243
+ var, lag = var_lag
244
+ # lagged_parents = []
245
+
246
+ for child in children[var]:
247
+ k, tau = child
248
+ if not (exclude_contemp and tau == 0):
249
+ # lagged_parents.append((i, lag + tau))
250
+ if only_non_causal_paths:
251
+ if not (var_lag in X and (k, lag + tau) in causal_children):
252
+ yield (k, lag + tau)
253
+ else:
254
+ yield (k, lag + tau)
255
+
256
+ def _get_non_blocked_ancestors(self, Y, conds=None, mode='non_repeating',
257
+ max_lag=None):
258
+ """Helper function to return the non-blocked ancestors of variables Y.
259
+
260
+ Returns a dictionary of ancestors for every y in Y. y is a tuple (
261
+ var, lag) where lag <= 0. All ancestors with directed paths towards y
262
+ that are not blocked by conditions in conds are included. In mode
263
+ 'non_repeating' an ancestor X^i_{t-\tau_i} with link X^i_{t-\tau_i}
264
+ --> X^j_{ t-\tau_j} is only included if X^i_{t'-\tau_i} --> X^j_{
265
+ t'-\tau_j} is not already part of the ancestors. The most lagged
266
+ ancestor for every variable X^i defines the maximum ancestral time
267
+ lag, which is also returned. In mode 'max_lag' ancestors are included
268
+ up to the maximum time lag max_lag.
269
+
270
+ It's main use is to return the maximum ancestral time lag max_lag of
271
+ y in Y for every variable in self.links_coeffs.
272
+
273
+ Parameters
274
+ ----------
275
+ Y : list of tuples
276
+ Of the form [(var, -tau)], where var specifies the variable
277
+ index and tau the time lag.
278
+ conds : list of tuples
279
+ Of the form [(var, -tau)], where var specifies the variable
280
+ index and tau the time lag.
281
+ mode : {'non_repeating', 'max_lag'}
282
+ Whether repeating links should be excluded or ancestors should be
283
+ followed up to max_lag.
284
+ max_lag : int
285
+ Maximum time lag to include ancestors.
286
+
287
+ Returns
288
+ -------
289
+ ancestors : dict
290
+ Includes ancestors for every y in Y.
291
+ max_lag : int
292
+ Maximum time lag to include ancestors.
293
+ """
294
+
295
+ def _repeating(link, seen_links):
296
+ """Returns True if a link or its time-shifted version is already
297
+ included in seen_links."""
298
+ i, taui = link[0]
299
+ j, tauj = link[1]
300
+
301
+ for seen_link in seen_links:
302
+ seen_i, seen_taui = seen_link[0]
303
+ seen_j, seen_tauj = seen_link[1]
304
+
305
+ if (i == seen_i and j == seen_j
306
+ and abs(tauj-taui) == abs(seen_tauj-seen_taui)):
307
+ return True
308
+
309
+ return False
310
+
311
+ if conds is None:
312
+ conds = []
313
+
314
+ conds = [z for z in conds if z not in Y]
315
+
316
+ N = len(self.links)
317
+
318
+ # Initialize max. ancestral time lag for every N
319
+ if mode == 'non_repeating':
320
+ max_lag = 0
321
+ else:
322
+ if max_lag is None:
323
+ raise ValueError("max_lag must be set in mode = 'max_lag'")
324
+
325
+ if self.selection_vars is not None:
326
+ for selection_var in self.selection_vars:
327
+ # print (selection_var, conds)
328
+ # print([(selection_var, -tau_sel) for tau_sel in range(0, max_lag + 1)])
329
+ conds += [(selection_var, -tau_sel) for tau_sel in range(0, max_lag + 1)]
330
+
331
+ ancestors = dict([(y, []) for y in Y])
332
+
333
+ for y in Y:
334
+ j, tau = y # tau <= 0
335
+ if mode == 'non_repeating':
336
+ max_lag = max(max_lag, abs(tau))
337
+ seen_links = []
338
+ this_level = [y]
339
+ while len(this_level) > 0:
340
+ next_level = []
341
+ for varlag in this_level:
342
+ for par in self._get_lagged_parents(varlag):
343
+ i, tau = par
344
+ if par not in conds and par not in ancestors[y]:
345
+ if ((mode == 'non_repeating' and
346
+ not _repeating((par, varlag), seen_links)) or
347
+ (mode == 'max_lag' and
348
+ abs(tau) <= abs(max_lag))):
349
+ ancestors[y].append(par)
350
+ if mode == 'non_repeating':
351
+ max_lag = max(max_lag,
352
+ abs(tau))
353
+ next_level.append(par)
354
+ seen_links.append((par, varlag))
355
+
356
+ this_level = next_level
357
+
358
+ return ancestors, max_lag
359
+
360
+ def _get_maximum_possible_lag(self, XYZ):
361
+ """Helper function to return the maximum time lag of any confounding path.
362
+
363
+ This is still based on a conjecture!
364
+
365
+ The conjecture states that if and only if X and Y are d-connected given Z
366
+ in a stationary DAG, then there exists a confounding path with a maximal
367
+ time lag (i.e., the node on that path with maximal lag) given as follows:
368
+ For any node in XYZ consider all non-repeating causal paths from the past
369
+ to that node, where non-repeating means that a link X^i_{t-\tau_i}
370
+ --> X^j_{ t-\tau_j} is only traversed if X^i_{t'-\tau_i} --> X^j_{
371
+ t'-\tau_j} is not already part of that path. The most lagged
372
+ ancestor for every variable node in XYZ defines the maximum ancestral time
373
+ lag, which is returned.
374
+
375
+ Parameters
376
+ ----------
377
+ XYZ : list of tuples
378
+ Of the form [(var, -tau)], where var specifies the variable
379
+ index and tau the time lag.
380
+
381
+ Returns
382
+ -------
383
+ max_lag : int
384
+ Maximum time lag of non-repeating causal path ancestors.
385
+ """
386
+
387
+ def _repeating(link, seen_path):
388
+ """Returns True if a link or its time-shifted version is already
389
+ included in seen_links."""
390
+ i, taui = link[0]
391
+ j, tauj = link[1]
392
+
393
+ for index, seen_link in enumerate(seen_path[:-1]):
394
+ seen_i, seen_taui = seen_link
395
+ seen_j, seen_tauj = seen_path[index + 1]
396
+
397
+ if (i == seen_i and j == seen_j
398
+ and abs(tauj-taui) == abs(seen_tauj-seen_taui)):
399
+ return True
400
+
401
+ return False
402
+
403
+ N = len(self.links)
404
+
405
+ # Initialize max. ancestral time lag for every N
406
+ max_lag = 0
407
+
408
+ # Not sure whether this is relevant!
409
+ # if self.selection_vars is not None:
410
+ # for selection_var in self.selection_vars:
411
+ # # print (selection_var, conds)
412
+ # # print([(selection_var, -tau_sel) for tau_sel in range(0, max_lag + 1)])
413
+ # conds += [(selection_var, -tau_sel) for tau_sel in range(0, max_lag + 1)]
414
+
415
+ # ancestors = dict([(y, []) for y in Y])
416
+
417
+ for y in XYZ:
418
+ j, tau = y # tau <= 0
419
+ max_lag = max(max_lag, abs(tau))
420
+
421
+ causal_path = []
422
+ queue = [(y, causal_path)]
423
+
424
+ while queue:
425
+ varlag, causal_path = queue.pop()
426
+ causal_path = [varlag] + causal_path
427
+
428
+ for node in self._get_lagged_parents(varlag):
429
+ i, tau = node
430
+
431
+ if (node not in causal_path):
432
+
433
+ if len(causal_path) == 1:
434
+ queue.append((node, causal_path))
435
+ continue
436
+
437
+ if (len(causal_path) > 1) and not _repeating((node, varlag), causal_path):
438
+
439
+ max_lag = max(max_lag, abs(tau))
440
+ queue.append((node, causal_path))
441
+
442
+ if self.verbosity > 0:
443
+ print("Max. non-repeated ancestral time lag: ", max_lag)
444
+
445
+ # ATTENTION: this may not find correct common ancestors, therefore multiply by 10
446
+ # until the problem is solved
447
+ max_lag *= 10
448
+
449
+ return max_lag
450
+
451
+ def _get_descendants(self, W, children, max_lag, ignore_time_bounds=False):
452
+ """Get descendants of nodes in W up to time t.
453
+
454
+ Includes the nodes themselves.
455
+ """
456
+
457
+ descendants = set(W)
458
+
459
+ for w in W:
460
+ j, tau = w
461
+ this_level = [w]
462
+ while len(this_level) > 0:
463
+ next_level = []
464
+ for varlag in this_level:
465
+ for child in self._get_lagged_children(varlag, children):
466
+ i, tau = child
467
+ if (child not in descendants
468
+ and (-max_lag <= tau <= 0 or ignore_time_bounds)):
469
+ descendants = descendants.union(set([child]))
470
+ next_level.append(child)
471
+
472
+ this_level = next_level
473
+
474
+ return list(descendants)
475
+
476
def _has_any_path(self, X, Y, conds, max_lag=None,
                  starts_with=None, ends_with=None,
                  directed=False,
                  forbidden_nodes=None,
                  only_non_causal_paths=False,
                  check_optimality_cond=False,
                  optimality_cond_des_YM=None,
                  optimality_cond_Y=None,
                  only_collider_paths_with_vancs=False,
                  XYS=None,
                  return_path=False):
    """Returns True if X and Y are d-connected by any open path.

    Does breadth-first search from both X and Y and meets in the middle.
    Paths are walked according to the d-separation rules where paths can
    only traverse motifs <-- v <-- or <-- v --> or --> v --> or
    --> [v] <-- where [.] indicates that v is conditioned on.
    Furthermore, paths nodes (v, t) need to fulfill max_lag <= t <= 0
    and links cannot be traversed backwards.

    Parameters
    ----------
    X, Y : lists of tuples
        Of the form [(var, -tau)], where var specifies the variable
        index and tau the time lag.
    conds : list of tuples
        Of the form [(var, -tau)], where var specifies the variable
        index and tau the time lag. Nodes of X and Y are removed from
        conds and lagged copies of self.selection_vars are appended.
    max_lag : int
        Maximum time lag. If None, it is obtained from
        self._get_maximum_possible_lag(X+Y+conds).
    starts_with : {None, 'tail', 'arrowhead'}
        Whether to only consider paths starting with particular mark at X.
    ends_with : {None, 'tail', 'arrowhead'}
        Whether to only consider paths ending with particular mark at Y.
    directed : bool
        If True, conditioned nodes are not walked through and children
        are not walked from nodes reached with a tail/None mark.
    forbidden_nodes : list of tuples or None
        Nodes that may not be walked into. None is treated as [].
    only_non_causal_paths : bool
        If True, the mediators between X and Y (ancestors of Y that are
        descendants of X) plus Y are passed as causal_children to the
        parent/child generators, which then skip the matching links.
    check_optimality_cond : bool
        If True, an observed, unconditioned node reached with an
        arrowhead is additionally walked to its parents when it is not
        in optimality_cond_des_YM and has no path to optimality_cond_Y
        given conds + X.
        NOTE(review): nesting of this branch is reconstructed from a
        whitespace-stripped listing -- confirm against the real file.
    optimality_cond_des_YM : list of tuples or None
        Only read if check_optimality_cond is True.
    optimality_cond_Y : list of tuples or None
        Only read if check_optimality_cond is True.
    only_collider_paths_with_vancs : bool
        If True, conds is replaced by the max_lag-ancestors of XYS plus
        XYS, and unconditioned nodes are not walked through.
    XYS : list of tuples or None
        Only read if only_collider_paths_with_vancs is True.
    return_path : bool
        If True, return the found path instead of True.

    Returns
    -------
    bool or list of tuples
        False if no open path was found. Otherwise True, or (if
        return_path) the backtraced path restricted to observed
        variables, with variable indices mapped to their position in
        self.observed_vars.
    """
    if max_lag is None:
        if conds is None:
            conds = []
        max_lag = self._get_maximum_possible_lag(X+Y+conds)

    def _walk_to_parents(v, fringe, this_path, other_path):
        """Helper function to update paths when walking to parents."""
        found_connection = False
        for w in self._get_lagged_parents(v,
            only_non_causal_paths=only_non_causal_paths, X=X,
            causal_children=causal_children):
            # Cannot walk into conditioned parents and
            # cannot walk beyond t or max_lag
            i, t = w

            # A path that must start/end with an arrowhead at x/y cannot
            # re-enter x/y through a parent link
            if w == x and starts_with == 'arrowhead':
                continue

            if w == y and ends_with == 'arrowhead':
                continue

            if (w not in conds and w not in forbidden_nodes and
                # (w, v) not in seen_links and
                t <= 0 and abs(t) <= max_lag):
                # Only enter w if it was not yet reached with a tail mark
                # and is not the start/end node of this search direction
                if (w not in this_path or
                    ('tail' not in this_path[w] and None not in this_path[w])):
                    if self.verbosity > 1:
                        print("Walk parent: %s --> %s " %(v, w))
                    fringe.append((w, 'tail'))
                    # Remember how w was reached: from v, leaving v via
                    # an arrowhead
                    if w not in this_path:
                        this_path[w] = {'tail' : (v, 'arrowhead')}
                    else:
                        this_path[w]['tail'] = (v, 'arrowhead')
                    # seen_links.append((v, w))
                    # Determine whether X and Y are connected
                    # (key None in the path dicts marks the start or end
                    # node X/Y)
                    if w in other_path:
                        found_connection = (w, 'tail')
                        if self.verbosity > 1:
                            print("Found connection: ", found_connection)
                        break
        return found_connection, fringe, this_path

    def _walk_to_children(v, fringe, this_path, other_path):
        """Helper function to update paths when walking to children."""
        found_connection = False
        for w in self._get_lagged_children(v, children,
            only_non_causal_paths=only_non_causal_paths, X=X,
            causal_children=causal_children):
            # You can also walk into conditioned children,
            # but cannot walk beyond t or max_lag
            i, t = w

            if w == x and starts_with == 'tail':
                continue

            if w == y and ends_with == 'tail':
                continue

            if (w not in forbidden_nodes and
                # (w, v) not in seen_links and
                t <= 0 and abs(t) <= max_lag):
                # Only enter w if it was not yet reached with an arrowhead
                # and is not the start/end node of this search direction
                if (w not in this_path or
                    ('arrowhead' not in this_path[w] and None not in this_path[w])):
                    if self.verbosity > 1:
                        print("Walk child: %s --> %s " %(v, w))
                    fringe.append((w, 'arrowhead'))
                    if w not in this_path:
                        this_path[w] = {'arrowhead' : (v, 'tail')}
                    else:
                        this_path[w]['arrowhead'] = (v, 'tail')
                    # seen_links.append((v, w))
                    # Determine whether X and Y are connected
                    # If the other_path contains w with a tail, then w must
                    # NOT be conditioned on. Alternatively, if the other_path
                    # contains w with an arrowhead, then w must be
                    # conditioned on.
                    if w in other_path:
                        if (('tail' in other_path[w] and w not in conds) or
                            ('arrowhead' in other_path[w] and w in conds) or
                            (None in other_path[w])):
                            found_connection = (w, 'arrowhead')
                            if self.verbosity > 1:
                                print("Found connection: ", found_connection)
                            break
        return found_connection, fringe, this_path

    def _walk_fringe(this_level, fringe, this_path, other_path):
        """Helper function to walk each fringe, i.e., the path from X and Y,
        respectively."""
        found_connection = False

        # First step from the start node x: restrict the direction if a
        # particular mark at x is requested
        if starts_with == 'arrowhead':
            if len(this_level) == 1 and this_level[0] == (x, None):
                (found_connection, fringe,
                 this_path) = _walk_to_parents(x, fringe,
                                               this_path, other_path)
                return found_connection, fringe, this_path, other_path

        elif starts_with == 'tail':
            if len(this_level) == 1 and this_level[0] == (x, None):
                (found_connection, fringe,
                 this_path) = _walk_to_children(x, fringe,
                                                this_path, other_path)
                return found_connection, fringe, this_path, other_path

        # Same for the first step from the end node y
        if ends_with == 'arrowhead':
            if len(this_level) == 1 and this_level[0] == (y, None):
                (found_connection, fringe,
                 this_path) = _walk_to_parents(y, fringe,
                                               this_path, other_path)
                return found_connection, fringe, this_path, other_path

        elif ends_with == 'tail':
            if len(this_level) == 1 and this_level[0] == (y, None):
                (found_connection, fringe,
                 this_path) = _walk_to_children(y, fringe,
                                                this_path, other_path)
                return found_connection, fringe, this_path, other_path

        for v, mark in this_level:
            if v in conds:
                if (mark == 'arrowhead' or mark == None) and directed is False:
                    # Motif: --> [v] <--
                    # If standing on a condition and coming from an
                    # arrowhead, you can only walk into parents
                    (found_connection, fringe,
                     this_path) = _walk_to_parents(v, fringe,
                                                   this_path, other_path)
                    if found_connection: break
            else:
                # In this mode only conditioned nodes are passed through
                if only_collider_paths_with_vancs:
                    continue

                if (mark == 'tail' or mark == None):
                    # Motif: <-- v <-- or <-- v -->
                    # If NOT standing on a condition and coming from
                    # a tail mark, you can walk into parents or
                    # children
                    (found_connection, fringe,
                     this_path) = _walk_to_parents(v, fringe,
                                                   this_path, other_path)
                    if found_connection: break

                    if not directed:
                        (found_connection, fringe,
                         this_path) = _walk_to_children(v, fringe,
                                                        this_path, other_path)
                        if found_connection: break

                elif mark == 'arrowhead':
                    # Motif: --> v -->
                    # If NOT standing on a condition and coming from
                    # an arrowhead mark, you can only walk into
                    # children
                    (found_connection, fringe,
                     this_path) = _walk_to_children(v, fringe,
                                                    this_path, other_path)
                    if found_connection: break

                    if check_optimality_cond and v[0] in self.observed_vars:
                        # if v is not descendant of YM
                        # and v is not connected to Y given X OS\Cu
                        cond4a = v not in optimality_cond_des_YM
                        cond4b = not self._has_any_path(X=[v], Y=optimality_cond_Y,
                                                conds=conds + X,
                                                max_lag=None,
                                                starts_with=None,
                                                ends_with=None,
                                                forbidden_nodes=None, #list(prelim_Oset),
                                                return_path=False)
                        if cond4a and cond4b:
                            (found_connection, fringe,
                             this_path) = _walk_to_parents(v, fringe,
                                                           this_path, other_path)
                            if found_connection: break

        if self.verbosity > 1:
            print("Updated fringe: ", fringe)
        return found_connection, fringe, this_path, other_path

    def backtrace_path():
        """Helper function to get path from start point, end point,
        and connection found."""

        path = [found_connection[0]]
        node, mark = found_connection

        # Walk back from the meeting node to x using the pred records
        if 'tail' in pred[node]:
            mark = 'tail'
        else:
            mark = 'arrowhead'
        while path[-1] != x:
            prev_node, prev_mark = pred[node][mark]
            path.append(prev_node)
            # Choose the mark under which prev_node itself was reached
            if prev_mark == 'arrowhead':
                if prev_node not in conds:
                    mark = 'tail'
                elif prev_node in conds:
                    mark = 'arrowhead'
            elif prev_mark == 'tail':
                if 'tail' in pred[prev_node] and pred[prev_node]['tail'] != (node, mark):
                    mark = 'tail'
                else:
                    mark = 'arrowhead'
            node = prev_node

        path.reverse()

        # Walk forward from the meeting node to y using the succ records
        node, mark = found_connection
        if 'tail' in succ[node]:
            mark = 'tail'
        else:
            mark = 'arrowhead'

        while path[-1] != y:
            next_node, next_mark = succ[node][mark]
            path.append(next_node)
            if next_mark == 'arrowhead':
                if next_node not in conds:
                    mark = 'tail'
                elif next_node in conds:
                    mark = 'arrowhead'
            elif next_mark == 'tail':
                if 'tail' in succ[next_node] and succ[next_node]['tail'] != (node, mark):
                    mark = 'tail'
                else:
                    mark = 'arrowhead'
            node = next_node

        return path


    if conds is None:
        conds = []

    if forbidden_nodes is None:
        forbidden_nodes = []

    # New list: X and Y themselves are never treated as conditions, and
    # the caller's list is not modified by the += below
    conds = [z for z in conds if z not in Y and z not in X]

    # Selection variables are conditioned on at lags 0..max_lag
    if self.selection_vars is not None:
        for selection_var in self.selection_vars:
            conds += [(selection_var, -tau_sel) for tau_sel in range(0, max_lag + 1)]


    N = len(self.links)
    children = self._get_children()

    if only_non_causal_paths:
        anc_Y_dict = self._get_non_blocked_ancestors(Y=Y, conds=None, mode='max_lag',
                                max_lag=max_lag)[0]
        anc_Y = []
        for y in Y:
            anc_Y += anc_Y_dict[y]
        des_X = self._get_descendants(X, children=children, max_lag=max_lag)
        # Mediators: ancestors of Y that are descendants of X
        mediators = set(anc_Y).intersection(set(des_X)) - set(Y) - set(X)

        causal_children = list(mediators) + Y
    else:
        causal_children = None

    if only_collider_paths_with_vancs:
        vancs_dict = self._get_non_blocked_ancestors(Y=XYS, conds=None, mode='max_lag',
                                max_lag=max_lag)[0]
        vancs = set()
        for xys in XYS:
            vancs = vancs.union(set(vancs_dict[xys]))
        vancs = list(vancs) + XYS
        # The ancestors of XYS (plus XYS) take the role of the conditions
        conds = vancs

    # Iterate through nodes in X and Y
    for x in X:
        for y in Y:

            # predecessor and successors in search
            # {node: {mark: (neighbor, neighbor_mark)}}; key None marks
            # the start/end nodes, later keys 'tail' or 'arrowhead'
            # indicate how a link ends at a node
            pred = {x : {None: None}}
            succ = {y : {None: None}}

            # initialize fringes, start with forward from X
            forward_fringe = [(x, None)]
            reverse_fringe = [(y, None)]

            while forward_fringe and reverse_fringe:
                # Always expand the smaller fringe
                if len(forward_fringe) <= len(reverse_fringe):
                    if self.verbosity > 1:
                        print("Walk from X since len(X_fringe)=%d "
                              "<= len(Y_fringe)=%d" % (len(forward_fringe),
                                len(reverse_fringe)))
                    this_level = forward_fringe
                    forward_fringe = []
                    (found_connection, forward_fringe, pred,
                     succ) = _walk_fringe(this_level, forward_fringe, pred,
                                            succ)

                    if found_connection:
                        if return_path:
                            backtraced_path = backtrace_path()
                            # Keep observed nodes only, re-indexed by
                            # their position in self.observed_vars
                            return [(self.observed_vars.index(node[0]), node[1])
                                    for node in backtraced_path
                                    if node[0] in self.observed_vars]
                        else:
                            return True
                else:
                    if self.verbosity > 1:
                        print("Walk from Y since len(X_fringe)=%d "
                              "> len(Y_fringe)=%d" % (len(forward_fringe),
                                len(reverse_fringe)))
                    this_level = reverse_fringe
                    reverse_fringe = []
                    (found_connection, reverse_fringe, succ,
                     pred) = _walk_fringe(this_level, reverse_fringe, succ,
                                            pred)

                    if found_connection:
                        if return_path:
                            backtraced_path = backtrace_path()
                            return [(self.observed_vars.index(node[0]), node[1])
                                    for node in backtraced_path
                                    if node[0] in self.observed_vars]
                        else:
                            return True

                if self.verbosity > 1:
                    print("X_fringe = %s \n" % str(forward_fringe) +
                          "Y_fringe = %s" % str(reverse_fringe))

    return False
875
+
876
+ def _is_dsep(self, X, Y, Z, max_lag=None):
877
+ """Returns whether X and Y are d-separated given Z in the graph.
878
+
879
+ X, Y, Z are of the form (var, lag) for lag <= 0. D-separation is
880
+ based on:
881
+
882
+ 1. Assessing the maximum time lag max_lag possible for any confounding
883
+ path (see _get_maximum_possible_lag(...)).
884
+
885
+ 2. Using the time series graph truncated at max_lag we then test
886
+ d-separation between X and Y conditional on Z using breadth-first
887
+ search of non-blocked paths according to d-separation rules.
888
+
889
+ Parameters
890
+ ----------
891
+ X, Y, Z : list of tuples
892
+ List of variables chosen for current independence test.
893
+ max_lag : int, optional (default: None)
894
+ Used here to constrain the _is_dsep function to the graph
895
+ truncated at max_lag instead of identifying the max_lag from
896
+ ancestral search.
897
+
898
+ Returns
899
+ -------
900
+ dseparated : bool, or path
901
+ True if X and Y are d-separated given Z in the graph.
902
+ """
903
+
904
+ N = len(self.links)
905
+
906
+ if self.verbosity > 0:
907
+ print("Testing X=%s d-sep Y=%s given Z=%s in TSG" %(X, Y, Z))
908
+
909
+ if Z is None:
910
+ Z = []
911
+
912
+ if max_lag is not None:
913
+ # max_lags = dict([(j, max_lag) for j in range(N)])
914
+ if self.verbosity > 0:
915
+ print("Set max. time lag to: ", max_lag)
916
+ else:
917
+ max_lag = self._get_maximum_possible_lag(X+Y+Z)
918
+
919
+ # Store overall max. lag
920
+ self.max_lag = max_lag
921
+
922
+ # _has_any_path is the main function that searches open paths
923
+ any_path = self._has_any_path(X, Y, conds=Z, max_lag=max_lag)
924
+
925
+ if any_path:
926
+ dseparated = False
927
+ else:
928
+ dseparated = True
929
+
930
+ return dseparated
931
+
932
def check_shortest_path(self, X, Y, Z,
                        max_lag=None,
                        starts_with=None, ends_with=None,
                        forbidden_nodes=None,
                        directed=False,
                        only_non_causal_paths=False,
                        check_optimality_cond=False,
                        optimality_cond_des_YM=None,
                        optimality_cond_Y=None,
                        return_path=False):
    """Returns path between X and Y given Z in the graph.

    X, Y, Z are of the form (var, lag) for lag <= 0, where var indexes
    into ``observed_vars``. The search proceeds as follows:

    1. Unless ``max_lag`` is given, the truncation lag is taken from
       ``_get_maximum_possible_lag`` applied to X + Y + Z.

    2. Using the time series graph truncated at max_lag, ``_has_any_path``
       searches for a path between X and Y that is not blocked by Z.

    Optionally only paths starting/ending with specific marks are
    considered.

    Parameters
    ----------
    X, Y, Z : list of tuples
        List of variables chosen for testing paths. ``Z`` may be None,
        which is treated as the empty conditioning set.
    max_lag : int, optional (default: None)
        Used here to constrain the path search to the graph truncated at
        max_lag instead of identifying max_lag from the query nodes.
    starts_with : {None, 'tail', 'arrowhead'}
        Whether to only consider paths starting with particular mark at X.
    ends_with : {None, 'tail', 'arrowhead'}
        Whether to only consider paths ending with particular mark at Y.
    forbidden_nodes, directed, only_non_causal_paths : optional
        Forwarded unchanged to ``_has_any_path``.
    check_optimality_cond : bool, optional (default: False)
        Forwarded to ``_has_any_path``. If True, ``optimality_cond_des_YM``
        and ``optimality_cond_Y`` must be lists of (var, lag) tuples in
        observed-variable indexing; they are translated to the full
        variable indexing before being forwarded.
    optimality_cond_des_YM, optimality_cond_Y : list of tuples, optional
        See ``check_optimality_cond``.
    return_path : bool, optional (default: False)
        Whether to return the path itself instead of True.

    Returns
    -------
    path : list, True or False
        False if no path exists. Otherwise the path (if ``return_path``)
        or True.
    """

    # Consistent with run_test/_is_dsep: no conditions means empty list
    if Z is None:
        Z = []

    # Translate from observed_vars index to full variable set index
    X = [(self.observed_vars[x[0]], x[1]) for x in X]
    Y = [(self.observed_vars[y[0]], y[1]) for y in Y]
    Z = [(self.observed_vars[z[0]], z[1]) for z in Z]

    if check_optimality_cond:
        optimality_cond_des_YM = [(self.observed_vars[x[0]], x[1])
                                  for x in optimality_cond_des_YM]
        optimality_cond_Y = [(self.observed_vars[x[0]], x[1])
                             for x in optimality_cond_Y]

    # Get the array to test on
    X, Y, Z = self._check_XYZ(X, Y, Z)

    if self.verbosity > 0:
        print("Testing X=%s d-sep Y=%s given Z=%s in TSG" %(X, Y, Z))

    if max_lag is not None:
        if self.verbosity > 0:
            print("Set max. time lag to: ", max_lag)
    else:
        max_lag = self._get_maximum_possible_lag(X+Y+Z)

    # Store overall max. lag
    self.max_lag = max_lag

    # _has_any_path is the main function that searches open paths
    any_path = self._has_any_path(X, Y, conds=Z, max_lag=max_lag,
                                  starts_with=starts_with, ends_with=ends_with,
                                  return_path=return_path,
                                  directed=directed,
                                  only_non_causal_paths=only_non_causal_paths,
                                  check_optimality_cond=check_optimality_cond,
                                  optimality_cond_des_YM=optimality_cond_des_YM,
                                  optimality_cond_Y=optimality_cond_Y,
                                  forbidden_nodes=forbidden_nodes)

    if any_path:
        if return_path:
            # NOTE(review): the return statements of _has_any_path visible
            # in this file already map path nodes to observed_vars indices
            # when return_path=True. Mapping a second time looks redundant
            # and would mis-index whenever observed_vars is not
            # [0, ..., N-1] -- confirm before changing; behaviour is kept.
            any_path_observed = [(self.observed_vars.index(node[0]), node[1])
                                 for node in any_path
                                 if node[0] in self.observed_vars]
        else:
            any_path_observed = True
    else:
        any_path_observed = False

    if self.verbosity > 0:
        print("_has_any_path = ", any_path)
        print("_has_any_path_obs = ", any_path_observed)

    return any_path_observed
1057
+
1058
def run_test(self, X, Y, Z=None, tau_max=0, cut_off='2xtau_max', alpha_or_thres=None,
             verbosity=0):
    """Answer a conditional independence query from the known graph.

    The query is mapped to the full variable indexing, normalised by
    ``_check_XYZ`` and decided by ``_is_dsep``. Decisions are memoised in
    ``self.dsepsets`` under the string form of the normalised query.

    Parameters
    ----------
    X, Y, Z : list of tuples
        X,Y,Z are of the form [(var, -tau)], where var specifies the
        variable index in the observed_vars and tau the time lag.
    tau_max : int, optional (default: 0)
        Not used here.
    cut_off : {'2xtau_max', 'max_lag', 'max_lag_or_tau_max'}
        Not used here.
    alpha_or_thres : float
        Only checked for being None; selects the shape of the return value.
    verbosity : int, optional (default: 0)
        Values above 1 print the result.

    Returns
    -------
    val, pval : Tuple of floats
        (0., 1.) if d-separated, (1., 0.) otherwise. If ``alpha_or_thres``
        is not None a third entry ``dependent`` (bool) is appended.
    """

    def to_full_index(nodes):
        # observed_vars position -> index in the full variable set
        return [(self.observed_vars[node[0]], node[1]) for node in nodes]

    conditions = [] if Z is None else Z

    X = to_full_index(X)
    Y = to_full_index(Y)
    Z = to_full_index(conditions)

    # Get the array to test on
    X, Y, Z = self._check_XYZ(X, Y, Z)

    cache_key = str((X, Y, Z))
    if cache_key not in self.dsepsets:
        self.dsepsets[cache_key] = self._is_dsep(X, Y, Z)

    # d-separated  <=>  independent  <=>  val = 0, pval = 1
    dependent = not self.dsepsets[cache_key]
    val, pval = (1., 0.) if dependent else (0., 1.)

    # Saved here, but not currently used
    self.ci_results[(tuple(X), tuple(Y), tuple(Z))] = (val, pval, dependent)

    if verbosity > 1:
        self._print_cond_ind_results(val=val, pval=pval, cached=False,
                                     conf=None)

    if alpha_or_thres is not None:
        return val, pval, dependent
    return val, pval
1116
+
1117
def get_measure(self, X, Y, Z=None, tau_max=0):
    """Returns dependence measure.

    Returns 0 if X and Y are d-separated given Z in the graph and 1 else.
    Decisions are memoised in ``self.dsepsets`` under the string form of
    the normalised query, the same cache that ``run_test`` uses.

    Parameters
    ----------
    X, Y [, Z] : list of tuples
        X,Y,Z are of the form [(var, -tau)], where var specifies the
        variable index in the observed_vars and tau the time lag.
        ``Z=None`` is treated as the empty conditioning set.

    tau_max : int, optional (default: 0)
        Not used here.

    Returns
    -------
    val : float
        The test statistic value (0. or 1.).

    """

    # The default Z=None must behave like "no conditions", as in run_test
    if Z is None:
        Z = []

    # Translate from observed_vars index to full variable set index
    X = [(self.observed_vars[x[0]], x[1]) for x in X]
    Y = [(self.observed_vars[y[0]], y[1]) for y in Y]
    Z = [(self.observed_vars[z[0]], z[1]) for z in Z]

    # Check XYZ (instance method; a bare _check_XYZ name does not resolve)
    X, Y, Z = self._check_XYZ(X, Y, Z)

    cache_key = str((X, Y, Z))
    if cache_key not in self.dsepsets:
        self.dsepsets[cache_key] = self._is_dsep(X, Y, Z)

    if self.dsepsets[cache_key]:
        return 0.
    else:
        return 1.
1154
+
1155
+ def _print_cond_ind_results(self, val, pval=None, cached=None, conf=None):
1156
+ """Print results from conditional independence test.
1157
+
1158
+ Parameters
1159
+ ----------
1160
+ val : float
1161
+ Test stastistic value.
1162
+ pval : float, optional (default: None)
1163
+ p-value
1164
+ conf : tuple of floats, optional (default: None)
1165
+ Confidence bounds.
1166
+ """
1167
+ printstr = " val = %.3f" % (val)
1168
+ if pval is not None:
1169
+ printstr += " | pval = %.5f" % (pval)
1170
+ if conf is not None:
1171
+ printstr += " | conf bounds = (%.3f, %.3f)" % (
1172
+ conf[0], conf[1])
1173
+ if cached is not None:
1174
+ printstr += " %s" % ({0:"", 1:"[cached]"}[cached])
1175
+
1176
+ print(printstr)
1177
+
1178
def get_model_selection_criterion(self, j, parents, tau_max=0):
    """Model selection is not available for this test.

    Always raises; the arguments are accepted for interface
    compatibility only.

    Raises
    ------
    NotImplementedError
        Unconditionally, naming ``self.measure`` in the message.
    """
    message = "Model selection not implemented for %s" % self.measure
    raise NotImplementedError(message)
1185
+
1186
+ def _reverse_patt(self, patt):
1187
+ """Inverts a link pattern"""
1188
+
1189
+ if patt == "":
1190
+ return ""
1191
+
1192
+ left_mark, middle_mark, right_mark = patt[0], patt[1], patt[2]
1193
+ if left_mark == "<":
1194
+ new_right_mark = ">"
1195
+ else:
1196
+ new_right_mark = left_mark
1197
+ if right_mark == ">":
1198
+ new_left_mark = "<"
1199
+ else:
1200
+ new_left_mark = right_mark
1201
+
1202
+ return new_left_mark + middle_mark + new_right_mark
1203
+
1204
+
1205
def get_links_from_graph(self, graph):
    """
    Build a links dictionary, observed_vars and selection_vars from a
    graph array (MAG, or ADMG if ``self.graph_is_mag`` is false).

    Every "<->" link gets a fresh auxiliary parent node of both endpoints,
    every "---" link a fresh auxiliary child node of both endpoints that
    is recorded as a selection variable. This corresponds to a canonical
    DAG (Richardson Spirtes 2002).

    For ADMGs "---" is not accepted, but "+->" / "<-+" are; they are
    handled as a directed link plus a "<->" link.

    Parameters
    ----------
    graph : array of dtype '<U3', shape (N, N, tau_max + 1)
        Entry [i, j, tau] holds the link mark between i at lag tau and j.

    Returns
    -------
    links : dict
        Maps each node to the list of its parents as (node, -lag) tuples;
        auxiliary nodes are numbered from N upwards.
    observed_vars : list
        ``list(range(N))``.
    selection_vars : list
        Indices of the auxiliary nodes created for "---" links.

    Raises
    ------
    ValueError
        On a wrong dtype, an unknown mark, inconsistent lag-zero entries,
        or a lagged mark that is not allowed.
    """

    if "U3" not in str(graph.dtype):
        raise ValueError("graph must be of type '<U3'!")

    edge_types = (["-->", "<--", "<->", "---"] if self.graph_is_mag
                  else ["-->", "<--", "<->", "+->", "<-+"])

    # Unpacking also enforces a three-dimensional array
    _, n_vars, _ = graph.shape

    observed_vars = list(range(n_vars))
    selection_vars = []
    links = {var: [] for var in observed_vars}

    # Auxiliary (latent / selection) nodes are appended after the observed
    next_aux = n_vars

    for i, j, tau in zip(*np.where(graph)):

        edge_type = graph[i, j, tau]

        if edge_type not in edge_types:
            raise ValueError(
                "Links can only be in %s " %str(edge_types)
                )

        if tau == 0:
            if edge_type != self._reverse_patt(graph[j, i, 0]):
                raise ValueError(
                    "graph needs to have consistent lag-zero patterns (eg"
                    " graph[i,j,0]='-->' requires graph[j,i,0]='<--')"
                    )

            # Each contemporaneous link appears twice; handle it once
            if j > i:
                continue

        elif edge_type not in ["-->", "<->", "---", "+->"]:
            # Lagged links are restricted to forward-in-time marks
            raise ValueError(
                "Lagged links can only be in ['-->', '<->', '---', '+->']"
                )

        # Directed component of the link, if any
        if edge_type in ("-->", "+->"):
            links[j].append((i, -tau))
        elif edge_type in ("<--", "<-+"):
            links[i].append((j, -tau))

        # Auxiliary node standing in for confounding / selection
        if edge_type in ("<->", "+->", "<-+"):
            links[next_aux] = []
            links[i].append((next_aux, 0))
            links[j].append((next_aux, -tau))
            next_aux += 1
        elif edge_type == "---":
            links[next_aux] = []
            selection_vars.append(next_aux)
            links[next_aux].append((i, -tau))
            links[next_aux].append((j, 0))
            next_aux += 1

    return links, observed_vars, selection_vars
1311
+
1312
+ def _get_minmax_lag(self, links):
1313
+ """Helper function to retrieve tau_min and tau_max from links
1314
+ """
1315
+
1316
+ N = len(links)
1317
+
1318
+ # Get maximum time lag
1319
+ min_lag = np.inf
1320
+ max_lag = 0
1321
+ for j in range(N):
1322
+ for link_props in links[j]:
1323
+ if len(link_props) == 3:
1324
+ i, lag = link_props[0]
1325
+ coeff = link_props[1]
1326
+ else:
1327
+ i, lag = link_props
1328
+ coeff = 1.
1329
+ # func = link_props[2]
1330
+ if coeff != 0.:
1331
+ min_lag = min(min_lag, abs(lag))
1332
+ max_lag = max(max_lag, abs(lag))
1333
+ return min_lag, max_lag
1334
+
1335
def get_graph_from_links(self, tau_max=None):
    """
    Constructs graph (DAG or MAG or ADMG) from links, observed_vars,
    and selection_vars.

    For MAGs (``self.graph_is_mag`` true) a link between two observed
    nodes is marked when they are not d-separated by their observed
    ancestors (including those of the selection variables); the marks
    follow from the ancestral relations. For ADMGs the latent projection
    operation (Pearl 2009) is used.

    Parameters
    ----------
    tau_max : int, optional (default: None)
        Maximum lag of the returned graph. If None, the maximum lag found
        in ``self.links`` is used.

    Returns
    -------
    graph : array of dtype '<U3', shape (N, N, tau_max + 1)
        N is the number of observed variables; indices refer to positions
        in ``observed_vars``. Empty strings mark absent links.

    Raises
    ------
    ValueError
        If selection variables are given for an ADMG, if ``tau_max`` is
        smaller than the maximum lag in the links, or if the ADMG
        projection finds directed paths in both directions at lag zero.
    """

    # TODO: use MAG from DAG construction procedure (lecture notes)
    # issues with tau_max?
    if self.graph_is_mag is False and len(self.selection_vars) > 0:
        raise ValueError("ADMGs do not support selection_vars.")

    # If tau_max is None, compute from links_coeffs
    _, max_lag_links = self._get_minmax_lag(self.links)
    if tau_max is None:
        tau_max = max_lag_links
    elif max_lag_links > tau_max:
        raise ValueError("tau_max must be >= maximum lag in links_coeffs; choose tau_max=None")

    observed = self.observed_vars
    N = len(observed)

    # Init graph
    graph = np.zeros((N, N, tau_max + 1), dtype='<U3')
    graph[:] = ""

    # Mark seen from the other endpoint (needed for lag-zero entries)
    mirrored = {"-->": "<--", "<--": "-->", "<->": "<->", "---": "---",
                "+->": "<-+", "<-+": "+->"}

    ancestor_cache = {}

    def ancestors_of(node):
        """DAG ancestors of ``node`` up to tau_max, queried once per node."""
        # The original issued this identical query inside the innermost
        # loop; caching assumes _get_non_blocked_ancestors is a pure
        # query for fixed arguments -- TODO confirm.
        if node not in ancestor_cache:
            found, _ = self._get_non_blocked_ancestors(Y=[node], conds=None,
                                                       mode='max_lag',
                                                       max_lag=tau_max)
            ancestor_cache[node] = found[node]
        return ancestor_cache[node]

    def only_observed(nodes):
        """Restrict a node list to observed variables."""
        return [node for node in nodes if node[0] in observed]

    def mag_mark(node_x, node_y):
        """Mark of the MAG link node_x *-* node_y, or "" if there is none."""
        dag_anc_y = ancestors_of(node_y)
        dag_anc_x = ancestors_of(node_x)

        candidates = only_observed(dag_anc_y) + only_observed(dag_anc_x) + mag_anc_s
        Z = list({z for z in candidates if z != node_y and z != node_x})

        if self._is_dsep(X=[node_x], Y=[node_y], Z=Z, max_lag=None):
            return ""

        # X and Y are connected given Z: orient by ancestral relations
        x_is_anc = node_x in dag_anc_y + dag_anc_s
        y_is_anc = node_y in dag_anc_x + dag_anc_s
        if x_is_anc and not y_is_anc:
            return "-->"
        if not x_is_anc and not y_is_anc:
            return "<->"
        if x_is_anc and y_is_anc:
            return "---"
        return ""

    def admg_mark(node_x, node_y, contemporaneous):
        """Mark of the ADMG link from the latent projection, or ""."""
        # (i) ADMG contains i --> j iff there is a directed path
        #     x --> ... --> y on which every non-endpoint vertex is hidden
        # (ii) ADMG contains i <-> j iff there exists a path of the form
        #     x <-- ... --> y on which every non-endpoint vertex is a
        #     non-collider AND hidden
        forbidden = all_observed_nodes - {node_x, node_y}

        def path_exists(start, end, first_mark, is_directed):
            return self._has_any_path(X=[start], Y=[end],
                                      conds=[],
                                      max_lag=None,
                                      starts_with=first_mark,
                                      ends_with='arrowhead',
                                      directed=is_directed,
                                      forbidden_nodes=list(forbidden),
                                      return_path=False)

        cond_one_xy = path_exists(node_x, node_y, 'tail', True)
        if contemporaneous:
            cond_one_yx = path_exists(node_y, node_x, 'tail', True)
        else:
            cond_one_yx = False
        cond_two = path_exists(node_x, node_y, 'arrowhead', False)

        if cond_one_xy and cond_one_yx:
            raise ValueError("Cyclic graph!")

        if cond_two:
            # (ii) holds, possibly together with (i)
            if cond_one_xy:
                return "+->"
            if cond_one_yx:
                return "<-+"
            return "<->"
        # Only (i) can hold
        if cond_one_xy:
            return "-->"
        if cond_one_yx:
            return "<--"
        return ""

    if self.graph_is_mag:
        # Ancestors of the selection variables do not depend on the pair
        dag_anc_s = set()
        for s in (self.selection_vars or []):
            dag_anc_s.update(ancestors_of((s, 0)))
        dag_anc_s = list(dag_anc_s)
        mag_anc_s = only_observed(dag_anc_s)
    else:
        all_observed_nodes = {(v, -lag) for v in observed
                              for lag in range(0, tau_max + 1)}

    # (i, j) index the returned graph, (x, y) the underlying DAG variables
    for j, y in enumerate(observed):
        for i, x in enumerate(observed):
            for tau in range(0, tau_max + 1):
                node_x, node_y = (x, -tau), (y, 0)
                if node_x == node_y:
                    continue

                if self.graph_is_mag:
                    mark = mag_mark(node_x, node_y)
                else:
                    # Contemporaneous pairs are handled once
                    if tau == 0 and j >= i:
                        continue
                    mark = admg_mark(node_x, node_y, tau == 0)

                if mark:
                    graph[i, j, tau] = mark
                    if tau == 0:
                        graph[j, i, 0] = mirrored[mark]

    return graph
1496
+
1497
def get_confidence(self, X, Y, Z=None, tau_max=0):
    """Interface placeholder, present for compatibility with PCMCI.

    No confidence bounds are computed; all arguments are ignored.

    Returns
    -------
    None
    """
    return None
1505
+
1506
if __name__ == '__main__':
    # Small manual demo: build the oracle for a three-variable process,
    # plot its graph and query one path.

    import tigramite.plotting as tp

    # Define the stationary DAG
    links = {0 : [(0, -3), (1, 0)], 1: [(2, -2)], 2: [(1, -2)]}
    observed_vars = [0, 1, 2]

    oracle = OracleCI(links=links,
                      observed_vars=observed_vars,
                      graph_is_mag=True,
                      )
    graph = oracle.graph
    print(graph[:,:,0])

    tp.plot_time_series_graph(graph=graph, var_names=None, figsize=(5, 5),
                              save_name="tsg.pdf")

    X = [(0, 0)]
    Y = [(2, 0)]
    Z = []
    path = oracle._has_any_path(X=X, Y=Y,
                                conds=Z,
                                max_lag=8,
                                starts_with='arrowhead',
                                ends_with='arrowhead',
                                forbidden_nodes=None,
                                return_path=True)
    print(path)

    print("-------------------------------")
    print(oracle._get_maximum_possible_lag(X+Z))