tigramite-fast 5.2.10.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38) hide show
  1. tigramite/__init__.py +0 -0
  2. tigramite/causal_effects.py +1525 -0
  3. tigramite/causal_mediation.py +1592 -0
  4. tigramite/data_processing.py +1574 -0
  5. tigramite/graphs.py +1509 -0
  6. tigramite/independence_tests/LBFGS.py +1114 -0
  7. tigramite/independence_tests/__init__.py +0 -0
  8. tigramite/independence_tests/cmiknn.py +661 -0
  9. tigramite/independence_tests/cmiknn_mixed.py +1397 -0
  10. tigramite/independence_tests/cmisymb.py +286 -0
  11. tigramite/independence_tests/gpdc.py +664 -0
  12. tigramite/independence_tests/gpdc_torch.py +820 -0
  13. tigramite/independence_tests/gsquared.py +190 -0
  14. tigramite/independence_tests/independence_tests_base.py +1310 -0
  15. tigramite/independence_tests/oracle_conditional_independence.py +1582 -0
  16. tigramite/independence_tests/pairwise_CI.py +383 -0
  17. tigramite/independence_tests/parcorr.py +369 -0
  18. tigramite/independence_tests/parcorr_mult.py +485 -0
  19. tigramite/independence_tests/parcorr_wls.py +451 -0
  20. tigramite/independence_tests/regressionCI.py +403 -0
  21. tigramite/independence_tests/robust_parcorr.py +403 -0
  22. tigramite/jpcmciplus.py +966 -0
  23. tigramite/lpcmci.py +3649 -0
  24. tigramite/models.py +2257 -0
  25. tigramite/pcmci.py +3935 -0
  26. tigramite/pcmci_base.py +1218 -0
  27. tigramite/plotting.py +4735 -0
  28. tigramite/rpcmci.py +467 -0
  29. tigramite/toymodels/__init__.py +0 -0
  30. tigramite/toymodels/context_model.py +261 -0
  31. tigramite/toymodels/non_additive.py +1231 -0
  32. tigramite/toymodels/structural_causal_processes.py +1201 -0
  33. tigramite/toymodels/surrogate_generator.py +319 -0
  34. tigramite_fast-5.2.10.1.dist-info/METADATA +182 -0
  35. tigramite_fast-5.2.10.1.dist-info/RECORD +38 -0
  36. tigramite_fast-5.2.10.1.dist-info/WHEEL +5 -0
  37. tigramite_fast-5.2.10.1.dist-info/licenses/license.txt +621 -0
  38. tigramite_fast-5.2.10.1.dist-info/top_level.txt +1 -0
tigramite/graphs.py ADDED
@@ -0,0 +1,1509 @@
1
+ """Tigramite causal inference for time series."""
2
+
3
+ # Author: Jakob Runge <jakob@jakob-runge.com>
4
+ #
5
+ # License: GNU General Public License v3.0
6
+
7
+ import numpy as np
8
+ import math
9
+ import itertools
10
+ from copy import deepcopy
11
+ from collections import defaultdict, deque
12
+ from tigramite.models import Models
13
+ import struct
14
+
15
+ class Graphs():
16
+ r"""Graph class.
17
+
18
+ Methods for dealing with causal graphs. Various graph types are
19
+ supported, also including hidden variables.
20
+
21
+
22
+ Parameters
23
+ ----------
24
+ graph : array of either shape [N, N], [N, N, tau_max+1], or [N, N, tau_max+1, tau_max+1]
25
+ Different graph types are supported, see tutorial.
26
+ graph_type : str
27
+ Type of graph.
28
+ tau_max : int, optional (default: 0)
29
+ Maximum time lag of graph.
30
+ hidden_variables : list of tuples
31
+ Hidden variables in format [(i, -tau), ...]. The internal graph is
32
+ constructed by a latent projection.
33
+ verbosity : int, optional (default: 0)
34
+ Level of verbosity.
35
+ """
36
+
37
+ def __init__(self,
38
+ graph,
39
+ graph_type,
40
+ tau_max=0,
41
+ hidden_variables=None,
42
+ verbosity=0):
43
+
44
+ self.verbosity = verbosity
45
+ self.N = graph.shape[0]
46
+ self.tau_max = tau_max
47
+
48
+ #
49
+ # Checks regarding graph type
50
+ #
51
+ supported_graphs = ['dag',
52
+ 'admg',
53
+ 'tsg_dag',
54
+ 'tsg_admg',
55
+ 'stationary_dag',
56
+ 'stationary_admg',
57
+
58
+ 'mag',
59
+ 'tsg_mag',
60
+ # 'stationary_mag',
61
+ # 'pag',
62
+ # 'tsg_pag',
63
+ # 'stationary_pag',
64
+ ]
65
+ if graph_type not in supported_graphs:
66
+ raise ValueError("Only graph types %s supported!" %supported_graphs)
67
+
68
+ # TODO?: check that masking aligns with hidden samples in variables
69
+ if hidden_variables is None:
70
+ hidden_variables = []
71
+
72
+ # Only needed for later extension to MAG/PAGs
73
+ if 'pag' in graph_type:
74
+ self.possible = True
75
+ self.definite_status = True
76
+ else:
77
+ self.possible = False
78
+ self.definite_status = False
79
+
80
+ # Not needed for now...
81
+ # self.ignore_time_bounds = False
82
+
83
+ # Construct internal graph from input graph depending on graph type
84
+ # and hidden variables
85
+ self._construct_graph(graph=graph, graph_type=graph_type,
86
+ hidden_variables=hidden_variables)
87
+
88
+ self._check_graph(self.graph)
89
+
90
+
91
def _construct_graph(self, graph, graph_type, hidden_variables):
    """Construct internal graph object based on input graph and hidden variables.

    Uses the latent projection operation.

    Sets ``self.graph`` and ``self.graph_type`` (always one of the
    ``tsg_*`` types afterwards). For the stationary graph types it also
    sets ``self.stationary_graph``.

    Parameters
    ----------
    graph : array
        Input graph; the required number of dimensions depends on
        ``graph_type`` (2 for 'dag'/'admg'/'mag', 3 for the stationary
        types, 4 for the 'tsg_*' types).
    graph_type : str
        Type of the input graph.
    hidden_variables : list
        Only its length is inspected here.
        NOTE(review): ``_get_latent_projection_graph`` is called without
        arguments, so it presumably reads the hidden variables from
        instance state that is not set in this method -- confirm.

    Raises
    ------
    ValueError
        If ``graph`` has the wrong number of dimensions for
        ``graph_type``, contains an edge string that is not allowed, or
        if hidden variables are combined with a MAG graph type.
    """

    if graph_type in ['dag', 'admg']:
        if graph.ndim != 2:
            raise ValueError("graph_type in ['dag', 'admg'] assumes graph.shape=(N, N).")

        allowed_edges = ["-->", "<--", "<->", "<-+", "+->", ""]
        if np.any(np.isin(graph, allowed_edges) == False):
            raise ValueError("Graph contains invalid graph edge. " +
                "For graph_type = %s only %s are allowed." %(graph_type, str(allowed_edges)))

        # Convert to shape [N, N, 1, 1] with dummy dimension
        # to process as tsg_dag or tsg_admg with potential hidden variables
        self.graph = np.expand_dims(graph, axis=(2, 3))

        # tau_max needed in _get_latent_projection_graph
        # self.tau_max = 0

        if len(hidden_variables) > 0:
            # self.graph must already be assigned at this point
            self.graph = self._get_latent_projection_graph() # stationary=False)
            self.graph_type = "tsg_admg"
        else:
            # graph = self.graph
            self.graph_type = 'tsg_' + graph_type


    elif graph_type in ['tsg_mag', 'mag']:
        allowed_edges = ["-->", "<--", "<->", ""]
        if np.any(np.isin(graph, allowed_edges) == False):
            raise ValueError("Graph contains invalid graph edge. " +
                "For graph_type = %s only %s are allowed." % (graph_type, str(allowed_edges)))

        if len(hidden_variables) > 0:
            raise ValueError(f"Hidden variables can not be combined with {graph_type}.")

        if graph_type == 'mag':
            if graph.ndim != 2:
                raise ValueError("graph_type 'mag' assumes graph.shape=(N, N).")
            # Same dummy-dimension trick as for 'dag'/'admg'
            self.graph = np.expand_dims(graph, axis=(2, 3))
            self.graph_type = 'tsg_' + graph_type
        else:
            if graph.ndim != 4:
                raise ValueError("tsg-graph_type assumes graph.shape=(N, N, tau_max+1, tau_max+1).")
            # NOTE(review): tau_max is *not* re-derived from the
            # dimensions here (the assignment below is commented out);
            # the value stored on the instance is used as is.
            self.graph = graph
            # self.tau_max = graph.shape[2] - 1
            self.graph_type = graph_type

    elif graph_type in ['tsg_dag', 'tsg_admg']:
        if graph.ndim != 4:
            raise ValueError("tsg-graph_type assumes graph.shape=(N, N, tau_max+1, tau_max+1).")

        allowed_edges = ["-->", "<--", "<->", "<-+", "+->", ""]
        if np.any(np.isin(graph, allowed_edges) == False):
            raise ValueError("Graph contains invalid graph edge. " +
                "For graph_type = %s only %s are allowed." %(graph_type, str(allowed_edges)))

        # NOTE(review): tau_max is *not* re-derived from the dimensions
        # here (the assignment below is commented out).
        self.graph = graph
        # self.tau_max = graph.shape[2] - 1

        if len(hidden_variables) > 0:
            self.graph = self._get_latent_projection_graph() #, stationary=False)
            self.graph_type = "tsg_admg"
        else:
            self.graph_type = graph_type

    elif graph_type in ['stationary_dag', 'stationary_admg']:
        # Currently only stationary_dag without hidden variables is supported
        if graph.ndim != 3:
            raise ValueError("stationary graph_type assumes graph.shape=(N, N, tau_max+1).")

        allowed_edges = ["-->", "<--", "<->", "<-+", "+->", ""]
        if np.any(np.isin(graph, allowed_edges) == False):
            raise ValueError("Graph contains invalid graph edge. " +
                "For graph_type = %s only %s are allowed." %(graph_type, str(allowed_edges)))

        # # TODO: remove if theory for stationary ADMGs is clear
        # if graph_type == 'stationary_dag' and len(hidden_variables) > 0:
        #     raise ValueError("Hidden variables currently not supported for "
        #                      "stationary_dag.")

        # For a stationary DAG without hidden variables it's sufficient to consider
        # a tau_max that includes the parents of X, Y, M, and S. A conservative
        # estimate thereof is simply the lag-dimension of the stationary DAG plus
        # the maximum lag of XYS.
        # statgraph_tau_max = graph.shape[2] - 1
        # maxlag_XYS = 0
        # for varlag in self.X.union(self.Y).union(self.S):
        #     maxlag_XYS = max(maxlag_XYS, abs(varlag[1]))

        # self.tau_max = maxlag_XYS + statgraph_tau_max

        # Keep a copy of the input; self.graph is replaced by the
        # projected graph below.
        stat_graph = deepcopy(graph)

        #########################################
        # Use this tau_max and construct ADMG by assuming paths of
        # maximal lag 10*tau_max... TO BE REVISED!
        self.graph = graph
        self.graph = self._get_latent_projection_graph(stationary=True)
        self.graph_type = "tsg_admg"
        #########################################

        # Also create stationary graph extended to tau_max
        # (entries beyond the input's lag dimension stay empty strings)
        self.stationary_graph = np.zeros((self.N, self.N, self.tau_max + 1), dtype='<U3')
        self.stationary_graph[:, :, :stat_graph.shape[2]] = stat_graph
203
+
204
+ # allowed_edges = ["-->", "<--"]
205
+
206
+ # # Construct tsg_graph
207
+ # graph = np.zeros((self.N, self.N, self.tau_max + 1, self.tau_max + 1), dtype='<U3')
208
+ # graph[:] = ""
209
+ # for (i, j) in itertools.product(range(self.N), range(self.N)):
210
+ # for jt, tauj in enumerate(range(0, self.tau_max + 1)):
211
+ # for it, taui in enumerate(range(tauj, self.tau_max + 1)):
212
+ # tau = abs(taui - tauj)
213
+ # if tau == 0 and j == i:
214
+ # continue
215
+ # if tau > statgraph_tau_max:
216
+ # continue
217
+
218
+ # # if tau == 0:
219
+ # # if stat_graph[i, j, tau] == '-->':
220
+ # # graph[i, j, taui, tauj] = "-->"
221
+ # # graph[j, i, tauj, taui] = "<--"
222
+
223
+ # # # elif stat_graph[i, j, tau] == '<--':
224
+ # # # graph[i, j, taui, tauj] = "<--"
225
+ # # # graph[j, i, tauj, taui] = "-->"
226
+ # # else:
227
+ # if stat_graph[i, j, tau] == '-->':
228
+ # graph[i, j, taui, tauj] = "-->"
229
+ # graph[j, i, tauj, taui] = "<--"
230
+ # elif stat_graph[i, j, tau] == '<--':
231
+ # pass
232
+ # elif stat_graph[i, j, tau] == '':
233
+ # pass
234
+ # else:
235
+ # edge = stat_graph[i, j, tau]
236
+ # raise ValueError("Invalid graph edge %s. " %(edge) +
237
+ # "For graph_type = %s only %s are allowed." %(graph_type, str(allowed_edges)))
238
+
239
+ # # elif stat_graph[i, j, tau] == '<--':
240
+ # # graph[i, j, taui, tauj] = "<--"
241
+ # # graph[j, i, tauj, taui] = "-->"
242
+
243
+ # self.graph_type = 'tsg_dag'
244
+ # self.graph = graph
245
+
246
+
247
+ # return (graph, graph_type, self.tau_max, hidden_variables)
248
+
249
+ # max_lag = self._get_maximum_possible_lag(XYZ=list(X.union(Y).union(S)), graph=graph)
250
+
251
+ # stat_mediators = self._get_mediators_stationary_graph(start=X, end=Y, max_lag=max_lag)
252
+ # self.tau_max = self._get_maximum_possible_lag(XYZ=list(X.union(Y).union(S).union(stat_mediators)), graph=graph)
253
+ # self.tau_max = graph_taumax
254
+ # for varlag in X.union(Y).union(S):
255
+ # self.tau_max = max(self.tau_max, abs(varlag[1]))
256
+
257
+ # if verbosity > 0:
258
+ # print("Setting tau_max = ", self.tau_max)
259
+
260
+ # if tau_max is None:
261
+ # self.tau_max = graph_taumax
262
+ # for varlag in X.union(Y).union(S):
263
+ # self.tau_max = max(self.tau_max, abs(varlag[1]))
264
+
265
+ # if verbosity > 0:
266
+ # print("Setting tau_max = ", self.tau_max)
267
+ # else:
268
+ # self.tau_max = graph_taumax
269
+ # # Repeat hidden variable pattern
270
+ # # if larger tau_max is given
271
+ # if self.tau_max > graph_taumax:
272
+ # for lag in range(graph_taumax + 1, self.tau_max + 1):
273
+ # for j in range(self.N):
274
+ # if (j, -(lag % (graph_taumax+1))) in self.hidden_variables:
275
+ # self.hidden_variables.add((j, -lag))
276
+ # print(self.hidden_variables)
277
+
278
+ # self.graph = self._get_latent_projection_graph(self.graph, stationary=True)
279
+ # self.graph_type = "tsg_admg"
280
+ # else:
281
+
282
+ def _check_graph(self, graph):
283
+ """Checks that graph contains no invalid entries/structure.
284
+
285
+ Assumes graph.shape = (N, N, tau_max+1, tau_max+1)
286
+ """
287
+
288
+ allowed_edges = ["-->", "<--"]
289
+ if 'admg' in self.graph_type:
290
+ allowed_edges += ["<->", "<-+", "+->"]
291
+ elif 'mag' in self.graph_type:
292
+ allowed_edges += ["<->"]
293
+ elif 'pag' in self.graph_type:
294
+ allowed_edges += ["<->", "o-o", "o->", "<-o"] # "o--",
295
+ # "--o",
296
+ # "x-o",
297
+ # "o-x",
298
+ # "x--",
299
+ # "--x",
300
+ # "x->",
301
+ # "<-x",
302
+ # "x-x",
303
+ # ]
304
+
305
+ graph_dict = defaultdict(list)
306
+ for i, j, taui, tauj in zip(*np.where(graph)):
307
+ edge = graph[i, j, taui, tauj]
308
+ # print((i, -taui), edge, (j, -tauj), graph[j, i, tauj, taui])
309
+ if edge != self._reverse_link(graph[j, i, tauj, taui]):
310
+ raise ValueError(
311
+ "graph needs to have consistent edges (eg"
312
+ " graph[i,j,taui,tauj]='-->' requires graph[j,i,tauj,taui]='<--')"
313
+ )
314
+
315
+ if edge not in allowed_edges:
316
+ raise ValueError("Invalid graph edge %s. " %(edge) +
317
+ "For graph_type = %s only %s are allowed." %(self.graph_type, str(allowed_edges)))
318
+
319
+ if edge == "-->" or edge == "+->":
320
+ # Map to (i,-taui, j, tauj) graph
321
+ indexi = i * (self.tau_max + 1) + taui
322
+ indexj = j * (self.tau_max + 1) + tauj
323
+ #dictionary containing all links
324
+ graph_dict[indexj].append(indexi)
325
+
326
+ # Check for cycles
327
+ if self._check_cyclic(graph_dict):
328
+ raise ValueError("graph is cyclic.")
329
+
330
+ if 'mag' in self.graph_type:
331
+ self._check_almost_cyclic(graph_dict)
332
+ # if PAG???
333
+
334
+ def _check_cyclic(self, graph_dict):
335
+ """Return True if the graph_dict has a cycle.
336
+
337
+ graph_dict must be represented as a dictionary mapping vertices to
338
+ iterables of neighbouring vertices. For example:
339
+
340
+ >> cyclic({1: (2,), 2: (3,), 3: (1,)})
341
+ True
342
+ >> cyclic({1: (2,), 2: (3,), 3: (4,)})
343
+ False
344
+ """
345
+
346
+ path = set()
347
+ visited = set()
348
+
349
+ def visit(vertex):
350
+ if vertex in visited:
351
+ return False
352
+ visited.add(vertex)
353
+ path.add(vertex)
354
+ for neighbour in graph_dict.get(vertex, ()):
355
+ if neighbour in path or visit(neighbour):
356
+ return True
357
+ path.remove(vertex)
358
+ return False
359
+
360
+ return any(visit(v) for v in graph_dict)
361
+
362
+ def _check_almost_cyclic(self, parents_dict):
363
+ """Check for almost-cycles in MAG.
364
+
365
+ An almost-cycle is a directed cycle with one edge replaced by a bidirected edge. To check
366
+ that no almost-cycles are present in the MAG, we can check for each bidirected edge that
367
+ no endpoint is an ancestor of the other endpoint.
368
+ Since there are no cycles in the graph containing only the directed edges, we first calculate
369
+ a (reversed) topological ordering of the nodes considering only the directed edges, and then it is sufficient
370
+ to check for each bidirected edge that the endpoint with the lower order is not an ancestor
371
+ of the other endpoint using DFS on the nodes inbetween in the topological ordering.
372
+ """
373
+ # Count incoming edges
374
+ child_count = { i: 0 for i in range(self.N * (self.tau_max + 1))}
375
+ for indexi, parents in parents_dict.items():
376
+ for par in parents:
377
+ child_count[par] +=1
378
+ # Get topological ordering of nodes considering only directed edges
379
+ ordering = []
380
+ while child_count:
381
+ for index, count in child_count.items():
382
+ if count == 0:
383
+ for par in parents_dict[index]:
384
+ child_count[par] -= 1
385
+ del child_count[index]
386
+ ordering.append(index)
387
+ break
388
+ order_index = {node: index for index, node in enumerate(ordering)}
389
+ # Check for almost-cycles
390
+ for i, j, taui, tauj in zip(*np.where(self.graph)):
391
+ edge = self.graph[i, j, taui, tauj]
392
+ if edge == "<->" and (i * (self.tau_max + 1) + taui < j * (self.tau_max + 1) + tauj):
393
+ indexi = i * (self.tau_max + 1) + taui
394
+ indexj = j * (self.tau_max + 1) + tauj
395
+
396
+ if order_index[indexi] < order_index[indexj]:
397
+ higher = indexj
398
+ lower = indexi
399
+ else:
400
+ higher = indexi
401
+ lower = indexj
402
+ # DFS from higher to check if lower is reachable
403
+ stack = [lower]
404
+ visited_dfs = set()
405
+ while stack:
406
+ current = stack.pop()
407
+ for par in parents_dict[current]:
408
+ if par == higher:
409
+ raise ValueError("Graph contains an almost-cycle.")
410
+ elif par not in visited_dfs and order_index[lower] < order_index[par] < order_index[higher]:
411
+ stack.append(par)
412
+
413
+ def get_mediators(self, start, end):
414
+ """Returns mediator variables on proper causal paths.
415
+
416
+ Parameters
417
+ ----------
418
+ start : set
419
+ Set of start nodes.
420
+ end : set
421
+ Set of end nodes.
422
+
423
+ Returns
424
+ -------
425
+ mediators : set
426
+ Mediators on causal paths from start to end.
427
+ """
428
+
429
+ des_X = self._get_descendants(start)
430
+
431
+ mediators = set()
432
+
433
+ # Walk along proper causal paths backwards from Y to X
434
+ # potential_mediators = set()
435
+ for y in end:
436
+ j, tau = y
437
+ this_level = [y]
438
+ while len(this_level) > 0:
439
+ next_level = []
440
+ for varlag in this_level:
441
+ for parent in self._get_parents(varlag):
442
+ i, tau = parent
443
+ # print(varlag, parent, des_X)
444
+ if (parent in des_X
445
+ and parent not in mediators
446
+ # and parent not in potential_mediators
447
+ and parent not in start
448
+ and parent not in end
449
+ and (-self.tau_max <= tau <= 0)): # or self.ignore_time_bounds)):
450
+ mediators.add(parent)
451
+ next_level.append(parent)
452
+
453
+ this_level = next_level
454
+
455
+ return mediators
456
+
457
+ def _get_mediators_stationary_graph(self, start, end, max_lag):
458
+ """Returns mediator variables on proper causal paths
459
+ from X to Y in a stationary graph."""
460
+
461
+ des_X = self._get_descendants_stationary_graph(start, max_lag)
462
+
463
+ mediators = set()
464
+
465
+ # Walk along proper causal paths backwards from Y to X
466
+ potential_mediators = set()
467
+ for y in end:
468
+ j, tau = y
469
+ this_level = [y]
470
+ while len(this_level) > 0:
471
+ next_level = []
472
+ for varlag in this_level:
473
+ for _, parent in self._get_adjacents_stationary_graph(graph=self.graph,
474
+ node=varlag, patterns=["<*-", "<*+"], max_lag=max_lag, exclude=None):
475
+ i, tau = parent
476
+ if (parent in des_X
477
+ and parent not in mediators
478
+ # and parent not in potential_mediators
479
+ and parent not in start
480
+ and parent not in end
481
+ # and (-self.tau_max <= tau <= 0 or self.ignore_time_bounds)
482
+ ):
483
+ mediators.add(parent)
484
+ next_level.append(parent)
485
+
486
+ this_level = next_level
487
+
488
+ return mediators
489
+
490
+ def _reverse_link(self, link):
491
+ """Reverse a given link, taking care to replace > with < and vice versa."""
492
+
493
+ if link == "":
494
+ return ""
495
+
496
+ if link[2] == ">":
497
+ left_mark = "<"
498
+ else:
499
+ left_mark = link[2]
500
+
501
+ if link[0] == "<":
502
+ right_mark = ">"
503
+ else:
504
+ right_mark = link[0]
505
+
506
+ return left_mark + link[1] + right_mark
507
+
508
+ def _match_link(self, pattern, link):
509
+ """Matches pattern including wildcards with link.
510
+
511
+ In an ADMG we have edge types ["-->", "<--", "<->", "+->", "<-+"].
512
+ Here +-> corresponds to having both "-->" and "<->".
513
+
514
+ In a MAG we have edge types ["-->", "<--", "<->", "---"].
515
+ """
516
+
517
+ if pattern == '' or link == '':
518
+ return True if pattern == link else False
519
+ else:
520
+ left_mark, middle_mark, right_mark = pattern
521
+ if left_mark != '*':
522
+ # if link[0] != '+':
523
+ if link[0] != left_mark: return False
524
+
525
+ if right_mark != '*':
526
+ # if link[2] != '+':
527
+ if link[2] != right_mark: return False
528
+
529
+ if middle_mark != '*' and link[1] != middle_mark: return False
530
+
531
+ return True
532
+
533
+ def _find_adj(self, node, patterns, exclude=None, return_link=False):
534
+ """Find adjacencies of node that match given patterns."""
535
+
536
+ graph = self.graph
537
+
538
+ if exclude is None:
539
+ exclude = []
540
+ # exclude = self.hidden_variables
541
+ # else:
542
+ # exclude = set(exclude).union(self.hidden_variables)
543
+
544
+ # Setup
545
+ i, lag_i = node
546
+ lag_i = abs(lag_i)
547
+
548
+ if exclude is None: exclude = []
549
+ if type(patterns) == str:
550
+ patterns = [patterns]
551
+
552
+ # Init
553
+ adj = []
554
+ # Find adjacencies going forward/contemp
555
+ for k, lag_ik in zip(*np.where(graph[i,:,lag_i,:])):
556
+ # print((k, lag_ik), graph[i,k,lag_i,lag_ik])
557
+ # matches = [self._match_link(patt, graph[i,k,lag_i,lag_ik]) for patt in patterns]
558
+ # if np.any(matches):
559
+ for patt in patterns:
560
+ if self._match_link(patt, graph[i,k,lag_i,lag_ik]):
561
+ match = (k, -lag_ik)
562
+ if match not in exclude:
563
+ if return_link:
564
+ adj.append((graph[i,k,lag_i,lag_ik], match))
565
+ else:
566
+ adj.append(match)
567
+ break
568
+
569
+
570
+ # Find adjacencies going backward/contemp
571
+ for k, lag_ki in zip(*np.where(graph[:,i,:,lag_i])):
572
+ # print((k, lag_ki), graph[k,i,lag_ki,lag_i])
573
+ # matches = [self._match_link(self._reverse_link(patt), graph[k,i,lag_ki,lag_i]) for patt in patterns]
574
+ # if np.any(matches):
575
+ for patt in patterns:
576
+ if self._match_link(self._reverse_link(patt), graph[k,i,lag_ki,lag_i]):
577
+ match = (k, -lag_ki)
578
+ if match not in exclude:
579
+ if return_link:
580
+ adj.append((self._reverse_link(graph[k,i,lag_ki,lag_i]), match))
581
+ else:
582
+ adj.append(match)
583
+ break
584
+
585
+ adj = list(set(adj))
586
+ return adj
587
+
588
+ def _is_match(self, nodei, nodej, pattern_ij):
589
+ """Check whether the link between X and Y agrees with pattern."""
590
+
591
+ graph = self.graph
592
+
593
+ (i, lag_i) = nodei
594
+ (j, lag_j) = nodej
595
+ tauij = lag_j - lag_i
596
+ if abs(tauij) >= graph.shape[2]:
597
+ return False
598
+ return ((tauij >= 0 and self._match_link(pattern_ij, graph[i, j, tauij])) or
599
+ (tauij < 0 and self._match_link(self._reverse_link(pattern_ij), graph[j, i, abs(tauij)])))
600
+
601
+ def _get_children(self, varlag):
602
+ """Returns set of children (varlag --> ...) for (lagged) varlag."""
603
+ if self.possible:
604
+ patterns=['-*>', 'o*o', 'o*>']
605
+ else:
606
+ patterns=['-*>', '+*>']
607
+ return self._find_adj(node=varlag, patterns=patterns)
608
+
609
+ def _get_parents(self, varlag):
610
+ """Returns set of parents (varlag <-- ...) for (lagged) varlag."""
611
+ if self.possible:
612
+ patterns=['<*-', 'o*o', '<*o']
613
+ else:
614
+ patterns=['<*-', '<*+']
615
+ return self._find_adj(node=varlag, patterns=patterns)
616
+
617
+ def _get_spouses(self, varlag):
618
+ """Returns set of spouses (varlag <-> ...) for (lagged) varlag."""
619
+ return self._find_adj(node=varlag, patterns=['<*>', '+*>', '<*+'])
620
+
621
+ def _get_neighbors(self, varlag):
622
+ """Returns set of neighbors (varlag --- ...) for (lagged) varlag."""
623
+ return self._find_adj(node=varlag, patterns=['-*-'])
624
+
625
+ def _get_ancestors(self, W):
626
+ """Get ancestors of nodes in W up to time tau_max.
627
+
628
+ Includes the nodes themselves.
629
+ """
630
+
631
+ ancestors = set(W)
632
+
633
+ for w in W:
634
+ j, tau = w
635
+ this_level = [w]
636
+ while len(this_level) > 0:
637
+ next_level = []
638
+ for varlag in this_level:
639
+
640
+ for par in self._get_parents(varlag):
641
+ i, tau = par
642
+ if par not in ancestors and -self.tau_max <= tau <= 0:
643
+ ancestors.add(par)
644
+ next_level.append(par)
645
+
646
+ this_level = next_level
647
+
648
+ return ancestors
649
+
650
+ def _get_all_parents(self, W):
651
+ """Get parents of nodes in W up to time tau_max.
652
+
653
+ Includes the nodes themselves.
654
+ """
655
+
656
+ parents = set(W)
657
+
658
+ for w in W:
659
+ j, tau = w
660
+ for par in self._get_parents(w):
661
+ i, tau = par
662
+ if par not in parents and -self.tau_max <= tau <= 0:
663
+ parents.add(par)
664
+
665
+ return parents
666
+
667
+ def _get_all_spouses(self, W):
668
+ """Get spouses of nodes in W up to time tau_max.
669
+
670
+ Includes the nodes themselves.
671
+ """
672
+
673
+ spouses = set(W)
674
+
675
+ for w in W:
676
+ j, tau = w
677
+ for spouse in self._get_spouses(w):
678
+ i, tau = spouse
679
+ if spouse not in spouses and -self.tau_max <= tau <= 0:
680
+ spouses.add(spouse)
681
+
682
+ return spouses
683
+
684
+ def _get_descendants_stationary_graph(self, W, max_lag):
685
+ """Get descendants of nodes in W up to time t in stationary graph.
686
+
687
+ Includes the nodes themselves.
688
+ """
689
+
690
+ descendants = set(W)
691
+
692
+ for w in W:
693
+ j, tau = w
694
+ this_level = [w]
695
+ while len(this_level) > 0:
696
+ next_level = []
697
+ for varlag in this_level:
698
+ for _, child in self._get_adjacents_stationary_graph(graph=self.graph,
699
+ node=varlag, patterns=["-*>", "-*+"], max_lag=max_lag, exclude=None):
700
+ i, tau = child
701
+ if (child not in descendants
702
+ # and (-self.tau_max <= tau <= 0 or self.ignore_time_bounds)
703
+ ):
704
+ descendants.add(child)
705
+ next_level.append(child)
706
+
707
+ this_level = next_level
708
+
709
+ return descendants
710
+
711
+ def _get_descendants(self, W):
712
+ """Get descendants of nodes in W up to time t.
713
+
714
+ Includes the nodes themselves.
715
+ """
716
+
717
+ descendants = set(W)
718
+
719
+ for w in W:
720
+ j, tau = w
721
+ this_level = [w]
722
+ while len(this_level) > 0:
723
+ next_level = []
724
+ for varlag in this_level:
725
+ for child in self._get_children(varlag):
726
+ i, tau = child
727
+ if (child not in descendants
728
+ and (-self.tau_max <= tau <= 0)): # or self.ignore_time_bounds)):
729
+ descendants.add(child)
730
+ next_level.append(child)
731
+
732
+ this_level = next_level
733
+
734
+ return descendants
735
+
736
+ def _get_collider_path_nodes(self, start_nodes, mediators, with_parents = True):
737
+ """Returns the set of all nodes from collider path i.e. paths consisting of bidirected edges only
738
+ a node in start_nodes that only contains nodes in mediators and their parents up to maximum time lag.
739
+
740
+ The function recognizes only collider paths of at least length 1.
741
+
742
+ Parameters
743
+ ----------
744
+ start_nodes : set or list of nodes
745
+ The set of nodes that a collider path may start in.
746
+ mediators : set or list of nodes
747
+ All nodes on the Path that are not start_nodes have to be contained in this set.
748
+ with_parents : bool
749
+ If set to false it will only return the collider path nodes, without its parents.
750
+ """
751
+
752
+ collider_path_nodes = set()
753
+ # print("mediators ", mediators)
754
+ for w in start_nodes:
755
+ # print(w)
756
+ this_level = [w]
757
+ while len(this_level) > 0:
758
+ next_level = []
759
+ for varlag in this_level:
760
+ # print("\t", varlag, self._get_spouses(varlag))
761
+ for spouse in self._get_spouses(varlag):
762
+ # print("\t\t", spouse)
763
+ _, tau = spouse
764
+ if (spouse not in collider_path_nodes
765
+ and spouse in mediators
766
+ and (-self.tau_max <= tau <= 0)): # or self.ignore_time_bounds)):
767
+ collider_path_nodes.add(spouse)
768
+ next_level.append(spouse)
769
+
770
+ this_level = next_level
771
+
772
+ # Add parents
773
+ if with_parents:
774
+ for par in self._get_all_parents(collider_path_nodes):
775
+ _, tau = par
776
+ if (par not in collider_path_nodes
777
+ and par in mediators
778
+ and (-self.tau_max <= tau <= 0)): # or self.ignore_time_bounds)):
779
+ collider_path_nodes = collider_path_nodes.union(set([par]))
780
+
781
+ return collider_path_nodes
782
+
783
+ def _get_adjacents_stationary_graph(self, graph, node, patterns,
784
+ max_lag=0, exclude=None):
785
+ """Find adjacencies of node matching patterns in a stationary graph."""
786
+
787
+ # graph = self.graph
788
+
789
+ # Setup
790
+ i, lag_i = node
791
+ if exclude is None: exclude = []
792
+ if type(patterns) == str:
793
+ patterns = [patterns]
794
+
795
+ # Init
796
+ adj = []
797
+
798
+ # Find adjacencies going forward/contemp
799
+ for k, lag_ik in zip(*np.where(graph[i,:,:])):
800
+ matches = [self._match_link(patt, graph[i, k, lag_ik]) for patt in patterns]
801
+ if np.any(matches):
802
+ match = (k, lag_i + lag_ik)
803
+ if (k, lag_i + lag_ik) not in exclude and (-max_lag <= lag_i + lag_ik <= 0): # or self.ignore_time_bounds):
804
+ adj.append((graph[i, k, lag_ik], match))
805
+
806
+ # Find adjacencies going backward/contemp
807
+ for k, lag_ki in zip(*np.where(graph[:,i,:])):
808
+ matches = [self._match_link(self._reverse_link(patt), graph[k, i, lag_ki]) for patt in patterns]
809
+ if np.any(matches):
810
+ match = (k, lag_i - lag_ki)
811
+ if (k, lag_i - lag_ki) not in exclude and (-max_lag <= lag_i - lag_ki <= 0): # or self.ignore_time_bounds):
812
+ adj.append((self._reverse_link(graph[k, i, lag_ki]), match))
813
+
814
+ adj = list(set(adj))
815
+ return adj
816
+
817
+ def _get_canonical_dag_from_graph(self, graph):
818
+ """Constructs canonical DAG as links_coeffs dictionary from graph.
819
+
820
+ For every <-> link further latent variables are added.
821
+ This corresponds to a canonical DAG (Richardson Spirtes 2002).
822
+
823
+ Can be used to evaluate d-separation.
824
+ """
825
+
826
+ N, N, tau_maxplusone = graph.shape
827
+ tau_max = tau_maxplusone - 1
828
+
829
+ links = {j: [] for j in range(N)}
830
+
831
+ # Add further latent variables to accommodate <-> links
832
+ latent_index = N
833
+ for i, j, tau in zip(*np.where(graph)):
834
+
835
+ edge_type = graph[i, j, tau]
836
+
837
+ # Consider contemporaneous links only once
838
+ if tau == 0 and j > i:
839
+ continue
840
+
841
+ if edge_type == "-->":
842
+ links[j].append((i, -tau))
843
+ elif edge_type == "<--":
844
+ links[i].append((j, -tau))
845
+ elif edge_type == "<->":
846
+ links[latent_index] = []
847
+ links[i].append((latent_index, 0))
848
+ links[j].append((latent_index, -tau))
849
+ latent_index += 1
850
+ # elif edge_type == "---":
851
+ # links[latent_index] = []
852
+ # selection_vars.append(latent_index)
853
+ # links[latent_index].append((i, -tau))
854
+ # links[latent_index].append((j, 0))
855
+ # latent_index += 1
856
+ elif edge_type == "+->":
857
+ links[j].append((i, -tau))
858
+ links[latent_index] = []
859
+ links[i].append((latent_index, 0))
860
+ links[j].append((latent_index, -tau))
861
+ latent_index += 1
862
+ elif edge_type == "<-+":
863
+ links[i].append((j, -tau))
864
+ links[latent_index] = []
865
+ links[i].append((latent_index, 0))
866
+ links[j].append((latent_index, -tau))
867
+ latent_index += 1
868
+
869
+ return links
870
+
871
+
872
def _get_maximum_possible_lag(self, XYZ, graph):
    """Construct maximum relevant time lag for d-separation in stationary graph.

    TO BE REVISED!

    Starting from every node in XYZ, parent links of the canonical DAG
    of ``graph`` are followed backwards in time. The largest absolute
    lag that is recorded along these paths (and among the nodes in XYZ
    themselves) is returned.

    Parameters
    ----------
    XYZ : iterable of tuples
        Nodes in format (var, lag).
    graph : array
        Passed to ``_get_canonical_dag_from_graph``.

    Returns
    -------
    max_lag : int
        Maximum absolute lag found.
    """

    def _repeating(link, seen_path):
        """Returns True if a link or its time-shifted version is already
        included in seen_path (compared on consecutive node pairs)."""
        i, taui = link[0]
        j, tauj = link[1]

        for index, seen_link in enumerate(seen_path[:-1]):
            seen_i, seen_taui = seen_link
            seen_j, seen_tauj = seen_path[index + 1]

            # Same pair of variables with the same lag difference
            if (i == seen_i and j == seen_j
                and abs(tauj-taui) == abs(seen_tauj-seen_taui)):
                return True

        return False

    # TODO: does this work with PAGs?
    # if self.possible:
    #     patterns=['<*-', '<*o', 'o*o']
    # else:
    #     patterns=['<*-']

    canonical_dag_links = self._get_canonical_dag_from_graph(graph)

    max_lag = 0
    for node in XYZ:
        j, tau = node   # tau <= 0
        max_lag = max(max_lag, abs(tau))

        causal_path = []
        # Used as a stack: pop() takes the most recently added entry
        queue = [(node, causal_path)]

        while queue:
            varlag, causal_path = queue.pop()
            # The path is stored with the most recently visited node first
            causal_path = [varlag] + causal_path

            var, lag = varlag
            for partmp in canonical_dag_links[var]:
                i, tautmp = partmp
                # Get shifted lag since canonical_dag_links is at t=0
                tau = tautmp + lag
                par = (i, tau)

                if (par not in causal_path):

                    # Parents of the start node are always followed;
                    # NOTE(review): max_lag is not updated in this
                    # branch -- confirm that this is intended.
                    if len(causal_path) == 1:
                        queue.append((par, causal_path))
                        continue

                    # Further back, a link is only followed if it does
                    # not repeat (up to a time shift) a link already on
                    # the path.
                    if (len(causal_path) > 1) and not _repeating((par, varlag), causal_path):

                        max_lag = max(max_lag, abs(tau))
                        queue.append((par, causal_path))

    return max_lag
934
+
935
def _get_latent_projection_graph(self, stationary=False):
    """For DAGs/ADMGs uses the latent projection operation (Pearl 2009).

    Assumes a normal or stationary graph with potentially unobserved nodes.
    Also allows particular time steps to be unobserved. By stationarity
    that pattern of unobserved nodes is repeated into -infinity.

    Latent projection operation for latents = nodes before t-tau_max or due to <->:
    (i)  auxADMG contains (i, -taui) --> (j, -tauj) iff there is a directed path
         (i, -taui) --> ... --> (j, -tauj) on which
         every non-endpoint vertex is in hidden variables (= not in observed_vars)
         here iff (i, -|taui-tauj|) --> j in graph
    (ii) auxADMG contains (i, -taui) <-> (j, -tauj) iff there exists a path of the
         form (i, -taui) <-- ... --> (j, -tauj) on
         which every non-endpoint vertex is non-collider AND in L (=not in observed_vars)
         here iff (i, -|taui-tauj|) <-> j OR there is path
         (i, -taui) <-- nodes before t-tau_max --> (j, -tauj)

    Both conditions are evaluated with ``self._check_path``. Pairs where
    either node is listed in ``self.hidden_variables`` are left empty.

    Parameters
    ----------
    stationary : bool, optional (default: False)
        Passed to ``_check_path`` as ``stationary_graph``. If True, nodes
        before t-tau_max are additionally treated as hidden when searching
        for the bidirected paths of condition (ii).

    Returns
    -------
    aux_graph : array of shape (N, N, tau_max+1, tau_max+1)
        String array in which entry [i, j, taui, tauj] holds the link between
        (i, -taui) and (j, -tauj): '-->', '<--', '<->', '+->', '<-+' or ''.

    Raises
    ------
    ValueError
        If directed paths exist in both directions between two nodes.
    """

    hidden_variables_here = self.hidden_variables

    # Loop-invariant: only in the stationary case are nodes before
    # t-tau_max treated as additional latents for condition (ii).
    hidden_by_taumax_here = bool(stationary)

    def _directed_path(source, target):
        """Condition (i): directed path source --> ... --> target via hidden nodes."""
        return self._check_path(
            start=[source],
            end=[target],
            conditions=None,
            starts_with=['-*>', '+*>'],
            ends_with=['-*>', '+*>'],
            path_type='causal',
            hidden_by_taumax=False,
            hidden_variables=hidden_variables_here,
            stationary_graph=stationary,
        )

    def _bidirected_path(node_a, node_b):
        """Condition (ii): path node_a <-- ... --> node_b via hidden nodes."""
        return self._check_path(
            start=[node_a],
            end=[node_b],
            conditions=None,
            starts_with=['<**', '+**'],
            ends_with=['**>', '**+'],
            path_type='any',
            hidden_by_taumax=hidden_by_taumax_here,
            hidden_variables=hidden_variables_here,
            stationary_graph=stationary,
        )

    # np.zeros with a string dtype already yields empty strings everywhere
    aux_graph = np.zeros((self.N, self.N, self.tau_max + 1, self.tau_max + 1),
                         dtype='<U3')

    lags = range(self.tau_max + 1)
    for (i, j) in itertools.product(range(self.N), range(self.N)):
        for tauj in lags:
            for taui in lags:
                # Skip the pairing of a node with itself
                if taui == tauj and j == i:
                    continue

                node_i = (i, -taui)
                node_j = (j, -tauj)
                if (node_i in hidden_variables_here
                        or node_j in hidden_variables_here):
                    continue

                cond_i_xy = _directed_path(node_i, node_j)
                cond_i_yx = _directed_path(node_j, node_i)
                cond_ii = _bidirected_path(node_i, node_j)

                if cond_i_xy and cond_i_yx:
                    # str() must receive the node as ONE tuple argument;
                    # str(i, -taui) would itself raise a TypeError.
                    raise ValueError("Cycle between %s and %s!"
                                     % (str(node_i), str(node_j)))

                if cond_i_xy:
                    forward, backward = ('+->', '<-+') if cond_ii else ('-->', '<--')
                elif cond_i_yx:
                    forward, backward = ('<-+', '+->') if cond_ii else ('<--', '-->')
                elif cond_ii:
                    forward, backward = '<->', '<->'
                else:
                    continue

                aux_graph[i, j, taui, tauj] = forward
                aux_graph[j, i, tauj, taui] = backward

    return aux_graph
1056
+
1057
def _check_path(self,
                start, end,
                conditions=None,
                starts_with=None,
                ends_with=None,
                path_type='any',
                stationary_graph=False,
                hidden_by_taumax=False,
                hidden_variables=None,
                ):
    """Check whether an open/active path between start and end given conditions exists.

    Also allows to restrict start and end patterns and to consider
    causal/non-causal paths.

    hidden_by_taumax and hidden_variables are relevant for the latent
    projection operation.

    The search is breadth-first over (link, node) pairs; in every step the
    motif i *-* k *-* j is tested for being open at the middle node k.

    Parameters
    ----------
    start, end : iterable of tuples
        Start and end nodes of the form (variable, lag).
    conditions : iterable of tuples or None
        Conditioned nodes; None is treated as the empty set.
    starts_with, ends_with : str, list of str or None
        Link patterns that the first / last link of the path has to match.
        None is treated as ['***'].
    path_type : str
        With 'causal' the first neighbor must be among the mediators or end
        nodes and every later link must match '-*>' or '+*>'. With
        'non_causal' a first link into a mediator/end node is rejected if it
        matches '-*>' and does not match '+*>'. Any other value imposes no
        restriction.
    stationary_graph : bool
        If True, neighbors are taken from
        ``_get_adjacents_stationary_graph`` with a search horizon of
        10*tau_max (TO BE REVISED!), otherwise from ``_find_adj``.
    hidden_by_taumax : bool
        If True, all nodes with lags beyond tau_max up to the search horizon
        are added to the hidden variables. Without ``stationary_graph`` no
        such nodes exist and nothing is added.
    hidden_variables : iterable of tuples or None
        If not None, every node on the path other than an end node must be
        contained in it. None imposes no restriction.

    Returns
    -------
    bool
        True if a connecting path was found, False otherwise.
    """

    start = set(start)
    end = set(end)
    conditions = set() if conditions is None else set(conditions)

    # Get maximal possible time lag of a connecting path
    # See Thm. XXXX - TO BE REVISED!
    if stationary_graph:
        max_lag = 10 * self.tau_max  # TO BE REVISED! self._get_maximum_possible_lag(XYZ, self.graph)
        causal_children = set(
            self._get_mediators_stationary_graph(start, end, max_lag)).union(end)
    else:
        max_lag = None
        causal_children = set(self.get_mediators(start, end)).union(end)

    # Accept any iterable; None keeps its meaning "no restriction".
    if (hidden_variables is not None
            and not isinstance(hidden_variables, (set, frozenset))):
        hidden_variables = set(hidden_variables)

    if hidden_by_taumax:
        if hidden_variables is None:
            hidden_variables = set()
        # Without a stationary graph there is no search horizon (max_lag is
        # None) and hence no node beyond tau_max that could be hidden.
        horizon = self.tau_max if max_lag is None else max_lag
        hidden_variables = hidden_variables.union(
            (k, -tauk)
            for k in range(self.N)
            for tauk in range(self.tau_max + 1, horizon + 1))

    if starts_with is None:
        starts_with = ['***']
    elif isinstance(starts_with, str):
        starts_with = [starts_with]

    if ends_with is None:
        ends_with = ['***']
    elif isinstance(ends_with, str):
        ends_with = [ends_with]

    def _adjacent(node, patterns):
        """(link, neighbor) pairs of node matching patterns, excluding start nodes."""
        if stationary_graph:
            return self._get_adjacents_stationary_graph(
                graph=self.graph, node=node, patterns=patterns,
                max_lag=max_lag, exclude=list(start))
        return self._find_adj(node=node, patterns=patterns,
                              exclude=list(start), return_link=True)

    def _not_traversable(node):
        """True if node is an interior node that is required to be hidden but is not."""
        return (hidden_variables is not None and node not in end
                and node not in hidden_variables)

    #
    # Breadth-first search to find connection
    #
    start_from = deque()
    for x in start:
        for link, neighbor in _adjacent(x, starts_with):

            if _not_traversable(neighbor):
                continue

            if path_type == 'non_causal':
                if (neighbor in causal_children
                        and self._match_link('-*>', link)
                        and not self._match_link('+*>', link)):
                    continue
            elif path_type == 'causal':
                if neighbor not in causal_children:
                    continue
            start_from.append((x, link, neighbor))

    visited = {(link_ik, varlag_k) for (_, link_ik, varlag_k) in start_from}

    # Traversing through motifs i *-* k *-* j
    while start_from:

        varlag_i, link_ik, varlag_k = start_from.popleft()

        # Check if we reached the end
        if varlag_k in end:
            if any(self._match_link(patt, link_ik) for patt in ends_with):
                return True
            continue

        for link_kj, varlag_j in _adjacent(varlag_k, '***'):

            if (link_kj, varlag_j) in visited:
                continue

            if path_type == 'causal':
                if not (self._match_link('-*>', link_kj)
                        or self._match_link('+*>', link_kj)):
                    continue

            # Marks of the two links at the middle node k
            left_mark = link_ik[2]
            right_mark = link_kj[0]

            if self.definite_status:
                # Exclude paths that are not definite_status implying that
                # any of the following motifs occurs:
                # i *-> k o-* j
                if left_mark == '>' and right_mark == 'o':
                    continue
                # i *-o k <-* j
                if left_mark == 'o' and right_mark == '<':
                    continue
                # i *-o k o-* j and i and j are adjacent
                if (left_mark == 'o' and right_mark == 'o'
                        and self._is_match(varlag_i, varlag_j, "***")):
                    continue

                # If k is in conditions and motif is *-o k o-*, then motif is
                # blocked since i and j are non-adjacent due to the check above
                if (varlag_k in conditions
                        and left_mark == 'o' and right_mark == 'o'):
                    continue

            # If k is in conditions and left or right mark is tail '-',
            # then motif is blocked
            if varlag_k in conditions and (left_mark == '-' or right_mark == '-'):
                continue

            # If k is not in conditions and left and right mark are heads
            # '><', then motif is blocked
            if varlag_k not in conditions and (left_mark == '>' and right_mark == '<'):
                continue

            if _not_traversable(varlag_j):
                continue

            # Motif is open
            visited.add((link_kj, varlag_j))
            start_from.append((varlag_k, link_kj, varlag_j))

    # Separated
    return False
1249
+
1250
+
1251
def _get_causal_paths(self, source_nodes, target_nodes,
                      mediators=None,
                      mediated_through=None,
                      proper_paths=True,
                      ):
    """Collect directed paths from source to target nodes by depth-first search.

    Children are obtained from ``self._get_children`` and only followed if
    they belong to the mediators or target nodes (and, unless
    ``proper_paths`` is set, the source nodes), have a lag within
    [-tau_max, 0] and are not yet on the current path.

    Parameters
    ----------
    source_nodes, target_nodes : iterable of tuples
        Nodes of the form (variable, lag).
    mediators : iterable of tuples or None
        Nodes a path may pass through; None means no mediators.
    mediated_through : iterable of tuples or None
        If non-empty, a path is only recorded when at least one of its nodes
        before the target is contained in this collection.
    proper_paths : bool
        If True, source nodes are excluded from the nodes a path may visit
        after its first node.

    Returns
    -------
    all_causal_paths : dict
        ``all_causal_paths[source][target]`` is the list of recorded paths,
        each a list of nodes from source to target.
    """

    source_nodes = set(source_nodes)
    target_nodes = set(target_nodes)
    mediators = set() if mediators is None else set(mediators)
    mediated_through = set([] if mediated_through is None else mediated_through)

    # Nodes a path is allowed to step onto
    if proper_paths:
        inside_set = mediators.union(target_nodes) - source_nodes
    else:
        inside_set = mediators.union(target_nodes).union(source_nodes)

    all_causal_paths = {
        w: {z: [] for z in target_nodes} for w in source_nodes
    }

    for w in source_nodes:
        # Every stack entry carries the path that led to its node
        stack = [(w, [])]
        while stack:
            current, prefix = stack.pop()
            path = prefix + [current]

            candidates = set(self._get_children(current)).intersection(inside_set)
            for child in candidates:
                _, lag = child
                within_bounds = -self.tau_max <= lag <= 0
                if not within_bounds or child in path:
                    continue

                stack.append((child, path))

                if child not in target_nodes:
                    continue
                if mediated_through and not mediated_through.intersection(path):
                    # Path reaches a target but misses all required nodes
                    continue
                all_causal_paths[w][child].append(path + [child])

    return all_causal_paths
1310
+
1311
+
1312
def _get_adjacency(self, varleg):
    """Return all nodes connected to ``varleg`` by any sort of edge.

    No pattern matching is needed: every non-empty entry of ``self.graph``
    for the given variable and absolute lag counts as an adjacency.

    Parameters
    ----------
    varleg : tuple
        Node of the form (variable, lag).

    Returns
    -------
    adjacency : list of tuples
        Adjacent nodes as (variable, -lag_index) pairs.
    """
    variable, lag = varleg
    # All entries of the graph belonging to this variable at this lag
    links_at_node = self.graph[variable][:, abs(lag)]
    occupied = np.where(links_at_node != '')
    return [(other, -lag_index) for other, lag_index in zip(*occupied)]
1320
+
1321
+
1322
def _edge_is_visible(self, edge):
    """Return True if the edge is visible.

    In a DAG or CPDAG all directed edges are visible. In a MAG/PAG an edge
    X-->Y is visible if there exists a node V not adjacent to Y such that
    there is a collider path from V into X where all (possibly zero)
    mediators are parents of Y.

    Parameters
    ----------
    edge : an ordered pair of nodes X and Y, each consisting of their
        variable index and time lag

    Returns
    -------
    bool
    """
    source, target = edge
    (x, xlag), (y, ylag) = source, target

    # Only directed edges can be visible
    if self.graph[x, y, abs(xlag), abs(ylag)] != '-->':
        return False

    # Currently not used condition for cpdags
    if 'dag' in self.graph_type:
        return True

    # Nodes on collider paths going into X whose nodes are parents of Y
    parents_of_target = set(self._get_parents(target))
    collider_path_nodes = self._get_collider_path_nodes(
        [source], parents_of_target, with_parents=False)
    if collider_path_nodes == set():
        # Fall back to X itself; make sure its lag is given as a negative number
        collider_path_nodes.add((x, -abs(xlag)))

    # Visible as soon as one of these nodes has a spouse or parent that is
    # not adjacent to Y
    return any(
        set(self._get_spouses(node)).union(set(self._get_parents(node)))
        - set(self._get_adjacency(target))
        for node in collider_path_nodes
    )
1353
+
1354
+
1355
def is_amenable(self, source_nodes, target_nodes) -> bool:
    """Return whether the graph is amenable with respect to the given nodes.

    The graph is amenable with respect to source_nodes and target_nodes if
    every first edge on every proper possibly directed path from a source to
    a target node is visible.

    Parameters
    ----------
    source_nodes : list of nodes (tuples containing a variable index in
        [0, self.N) and a time lag in [-self.tau_max, 0])
    target_nodes : list of nodes (tuples containing a variable index in
        [0, self.N) and a time lag in [-self.tau_max, 0])

    Returns
    -------
    bool
    """
    # These graph types are amenable without any further check
    if self.graph_type in ['dag', 'tsg_dag', 'stationary_dag']:
        return True

    # Allow every node within the time window as a mediator
    every_node = [(var, -lag)
                  for var in range(self.N)
                  for lag in range(self.tau_max + 1)]
    paths = self._get_causal_paths(source_nodes, target_nodes,
                                   mediators=every_node)

    # First edge of every path found
    first_edges = {
        (path[0], path[1])
        for source in source_nodes
        for target in target_nodes
        for path in paths[source][target]
    }

    return all(self._edge_is_visible(edge) for edge in first_edges)
1380
+
1381
+
1382
@staticmethod
def get_dict_from_graph(graph, parents_only=False):
    """Convert a graph array into a dictionary of links.

    Parameters
    ----------
    graph : array of shape (N, N, tau_max+1)
        Matrix format of graph in string format.
    parents_only : bool
        Whether to only return parents ('-->' in graph).

    Returns
    -------
    links : dict
        Dictionary of form {0:{(0, -1): o-o, ...}, 1:{...}, ...}, i.e. for
        every variable j a dictionary mapping (i, -tau) to the link string
        stored in graph[i, j, tau].
    """
    n_vars = graph.shape[0]
    links = {j: {} for j in range(n_vars)}

    # Select either the directed links only or every non-empty entry
    selected = (graph == '-->') if parents_only else (graph != '')

    for i, j, tau in zip(*np.where(selected)):
        links[j][(i, -tau)] = graph[i, j, tau]

    return links
1411
+
1412
@staticmethod
def get_graph_from_dict(links, tau_max=None):
    """Convert a dictionary of links into the graph array format.

    Parameters
    ----------
    links : dict
        Dictionary of form {0:[((0, -1), coeff, func), ...], 1:[...], ...}.
        Also format {0:[(0, -1), ...], 1:[...], ...} is allowed. Entries of
        the first format with coeff equal to zero are ignored.
    tau_max : int or None
        Maximum lag. If None, the maximum lag in links is used.

    Returns
    -------
    graph : array of shape (N, N, tau_max+1)
        String array with "-->" at [i, j, |lag|] for every link from
        (i, lag) into j, the mirrored "<--" at [j, i, 0] for lag-zero links
        and empty strings elsewhere.

    Raises
    ------
    ValueError
        If tau_max is smaller than the maximum lag found in links.
    """

    def _iter_parents(entries):
        """Yield (var, lag) for every entry that represents a link."""
        for entry in entries:
            if len(entry) > 2:
                # Format ((var, lag), coeff, func)
                (var, lag), coeff = entry[0], entry[1]
                if coeff != 0.:
                    yield var, lag
            else:
                # Format (var, lag)
                var, lag = entry
                yield var, lag

    N = len(links)

    # Get maximum time lag
    max_lag = 0
    for j in range(N):
        for _, lag in _iter_parents(links[j]):
            max_lag = max(max_lag, abs(lag))

    # Set maximum lag
    if tau_max is None:
        tau_max = max_lag
    elif max_lag > tau_max:
        raise ValueError("tau_max is smaller than maximum lag = %d "
                         "found in links, use tau_max=None or larger "
                         "value" % max_lag)

    graph = np.zeros((N, N, tau_max + 1), dtype='<U3')
    for j in links.keys():
        for var, lag in _iter_parents(links[j]):
            graph[var, j, abs(lag)] = "-->"
            if lag == 0:
                # Contemporaneous links are stored in both orientations
                graph[j, var, 0] = "<--"

    return graph
1486
+
1487
+
1488
if __name__ == '__main__':

    # Minimal smoke test. The toy-model, plotting and matplotlib imports that
    # used to sit here were never used and only made this demo fail on
    # installations without those optional dependencies.

    # Two-variable contemporaneous graph X0 --> X1 in array format
    graph = np.array([['', '-->'],
                      ['<--', '']], dtype='<U3')

    # Initialize class as `dag`
    causal_effects = Graphs(graph, graph_type='dag', tau_max=0,
                            verbosity=1)
1506
+
1507
+
1508
+
1509
+