tigramite-fast 5.2.10.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tigramite/__init__.py +0 -0
- tigramite/causal_effects.py +1525 -0
- tigramite/causal_mediation.py +1592 -0
- tigramite/data_processing.py +1574 -0
- tigramite/graphs.py +1509 -0
- tigramite/independence_tests/LBFGS.py +1114 -0
- tigramite/independence_tests/__init__.py +0 -0
- tigramite/independence_tests/cmiknn.py +661 -0
- tigramite/independence_tests/cmiknn_mixed.py +1397 -0
- tigramite/independence_tests/cmisymb.py +286 -0
- tigramite/independence_tests/gpdc.py +664 -0
- tigramite/independence_tests/gpdc_torch.py +820 -0
- tigramite/independence_tests/gsquared.py +190 -0
- tigramite/independence_tests/independence_tests_base.py +1310 -0
- tigramite/independence_tests/oracle_conditional_independence.py +1582 -0
- tigramite/independence_tests/pairwise_CI.py +383 -0
- tigramite/independence_tests/parcorr.py +369 -0
- tigramite/independence_tests/parcorr_mult.py +485 -0
- tigramite/independence_tests/parcorr_wls.py +451 -0
- tigramite/independence_tests/regressionCI.py +403 -0
- tigramite/independence_tests/robust_parcorr.py +403 -0
- tigramite/jpcmciplus.py +966 -0
- tigramite/lpcmci.py +3649 -0
- tigramite/models.py +2257 -0
- tigramite/pcmci.py +3935 -0
- tigramite/pcmci_base.py +1218 -0
- tigramite/plotting.py +4735 -0
- tigramite/rpcmci.py +467 -0
- tigramite/toymodels/__init__.py +0 -0
- tigramite/toymodels/context_model.py +261 -0
- tigramite/toymodels/non_additive.py +1231 -0
- tigramite/toymodels/structural_causal_processes.py +1201 -0
- tigramite/toymodels/surrogate_generator.py +319 -0
- tigramite_fast-5.2.10.1.dist-info/METADATA +182 -0
- tigramite_fast-5.2.10.1.dist-info/RECORD +38 -0
- tigramite_fast-5.2.10.1.dist-info/WHEEL +5 -0
- tigramite_fast-5.2.10.1.dist-info/licenses/license.txt +621 -0
- tigramite_fast-5.2.10.1.dist-info/top_level.txt +1 -0
tigramite/graphs.py
ADDED
|
@@ -0,0 +1,1509 @@
|
|
|
1
|
+
"""Tigramite causal inference for time series."""
|
|
2
|
+
|
|
3
|
+
# Author: Jakob Runge <jakob@jakob-runge.com>
|
|
4
|
+
#
|
|
5
|
+
# License: GNU General Public License v3.0
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
import math
|
|
9
|
+
import itertools
|
|
10
|
+
from copy import deepcopy
|
|
11
|
+
from collections import defaultdict, deque
|
|
12
|
+
from tigramite.models import Models
|
|
13
|
+
import struct
|
|
14
|
+
|
|
15
|
+
class Graphs():
|
|
16
|
+
r"""Graph class.
|
|
17
|
+
|
|
18
|
+
Methods for dealing with causal graphs. Various graph types are
|
|
19
|
+
supported, also including hidden variables.
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
Parameters
|
|
23
|
+
----------
|
|
24
|
+
graph : array of either shape [N, N], [N, N, tau_max+1], or [N, N, tau_max+1, tau_max+1]
|
|
25
|
+
Different graph types are supported, see tutorial.
|
|
26
|
+
graph_type : str
|
|
27
|
+
Type of graph.
|
|
28
|
+
tau_max : int, optional (default: 0)
|
|
29
|
+
Maximum time lag of graph.
|
|
30
|
+
hidden_variables : list of tuples
|
|
31
|
+
Hidden variables in format [(i, -tau), ...]. The internal graph is
|
|
32
|
+
constructed by a latent projection.
|
|
33
|
+
verbosity : int, optional (default: 0)
|
|
34
|
+
Level of verbosity.
|
|
35
|
+
"""
|
|
36
|
+
|
|
37
|
+
def __init__(self,
|
|
38
|
+
graph,
|
|
39
|
+
graph_type,
|
|
40
|
+
tau_max=0,
|
|
41
|
+
hidden_variables=None,
|
|
42
|
+
verbosity=0):
|
|
43
|
+
|
|
44
|
+
self.verbosity = verbosity
|
|
45
|
+
self.N = graph.shape[0]
|
|
46
|
+
self.tau_max = tau_max
|
|
47
|
+
|
|
48
|
+
#
|
|
49
|
+
# Checks regarding graph type
|
|
50
|
+
#
|
|
51
|
+
supported_graphs = ['dag',
|
|
52
|
+
'admg',
|
|
53
|
+
'tsg_dag',
|
|
54
|
+
'tsg_admg',
|
|
55
|
+
'stationary_dag',
|
|
56
|
+
'stationary_admg',
|
|
57
|
+
|
|
58
|
+
'mag',
|
|
59
|
+
'tsg_mag',
|
|
60
|
+
# 'stationary_mag',
|
|
61
|
+
# 'pag',
|
|
62
|
+
# 'tsg_pag',
|
|
63
|
+
# 'stationary_pag',
|
|
64
|
+
]
|
|
65
|
+
if graph_type not in supported_graphs:
|
|
66
|
+
raise ValueError("Only graph types %s supported!" %supported_graphs)
|
|
67
|
+
|
|
68
|
+
# TODO?: check that masking aligns with hidden samples in variables
|
|
69
|
+
if hidden_variables is None:
|
|
70
|
+
hidden_variables = []
|
|
71
|
+
|
|
72
|
+
# Only needed for later extension to MAG/PAGs
|
|
73
|
+
if 'pag' in graph_type:
|
|
74
|
+
self.possible = True
|
|
75
|
+
self.definite_status = True
|
|
76
|
+
else:
|
|
77
|
+
self.possible = False
|
|
78
|
+
self.definite_status = False
|
|
79
|
+
|
|
80
|
+
# Not needed for now...
|
|
81
|
+
# self.ignore_time_bounds = False
|
|
82
|
+
|
|
83
|
+
# Construct internal graph from input graph depending on graph type
|
|
84
|
+
# and hidden variables
|
|
85
|
+
self._construct_graph(graph=graph, graph_type=graph_type,
|
|
86
|
+
hidden_variables=hidden_variables)
|
|
87
|
+
|
|
88
|
+
self._check_graph(self.graph)
|
|
89
|
+
|
|
90
|
+
|
|
91
|
+
def _construct_graph(self, graph, graph_type, hidden_variables):
|
|
92
|
+
"""Construct internal graph object based on input graph and hidden variables.
|
|
93
|
+
|
|
94
|
+
Uses the latent projection operation.
|
|
95
|
+
"""
|
|
96
|
+
|
|
97
|
+
if graph_type in ['dag', 'admg']:
|
|
98
|
+
if graph.ndim != 2:
|
|
99
|
+
raise ValueError("graph_type in ['dag', 'admg'] assumes graph.shape=(N, N).")
|
|
100
|
+
|
|
101
|
+
allowed_edges = ["-->", "<--", "<->", "<-+", "+->", ""]
|
|
102
|
+
if np.any(np.isin(graph, allowed_edges) == False):
|
|
103
|
+
raise ValueError("Graph contains invalid graph edge. " +
|
|
104
|
+
"For graph_type = %s only %s are allowed." %(graph_type, str(allowed_edges)))
|
|
105
|
+
|
|
106
|
+
# Convert to shape [N, N, 1, 1] with dummy dimension
|
|
107
|
+
# to process as tsg_dag or tsg_admg with potential hidden variables
|
|
108
|
+
self.graph = np.expand_dims(graph, axis=(2, 3))
|
|
109
|
+
|
|
110
|
+
# tau_max needed in _get_latent_projection_graph
|
|
111
|
+
# self.tau_max = 0
|
|
112
|
+
|
|
113
|
+
if len(hidden_variables) > 0:
|
|
114
|
+
self.graph = self._get_latent_projection_graph() # stationary=False)
|
|
115
|
+
self.graph_type = "tsg_admg"
|
|
116
|
+
else:
|
|
117
|
+
# graph = self.graph
|
|
118
|
+
self.graph_type = 'tsg_' + graph_type
|
|
119
|
+
|
|
120
|
+
|
|
121
|
+
elif graph_type in ['tsg_mag', 'mag']:
|
|
122
|
+
allowed_edges = ["-->", "<--", "<->", ""]
|
|
123
|
+
if np.any(np.isin(graph, allowed_edges) == False):
|
|
124
|
+
raise ValueError("Graph contains invalid graph edge. " +
|
|
125
|
+
"For graph_type = %s only %s are allowed." % (graph_type, str(allowed_edges)))
|
|
126
|
+
|
|
127
|
+
if len(hidden_variables) > 0:
|
|
128
|
+
raise ValueError(f"Hidden variables can not be combined with {graph_type}.")
|
|
129
|
+
|
|
130
|
+
if graph_type == 'mag':
|
|
131
|
+
if graph.ndim != 2:
|
|
132
|
+
raise ValueError("graph_type 'mag' assumes graph.shape=(N, N).")
|
|
133
|
+
self.graph = np.expand_dims(graph, axis=(2, 3))
|
|
134
|
+
self.graph_type = 'tsg_' + graph_type
|
|
135
|
+
else:
|
|
136
|
+
if graph.ndim != 4:
|
|
137
|
+
raise ValueError("tsg-graph_type assumes graph.shape=(N, N, tau_max+1, tau_max+1).")
|
|
138
|
+
# Then tau_max is implicitely derived from
|
|
139
|
+
# the dimensions
|
|
140
|
+
self.graph = graph
|
|
141
|
+
# self.tau_max = graph.shape[2] - 1
|
|
142
|
+
self.graph_type = graph_type
|
|
143
|
+
|
|
144
|
+
elif graph_type in ['tsg_dag', 'tsg_admg']:
|
|
145
|
+
if graph.ndim != 4:
|
|
146
|
+
raise ValueError("tsg-graph_type assumes graph.shape=(N, N, tau_max+1, tau_max+1).")
|
|
147
|
+
|
|
148
|
+
allowed_edges = ["-->", "<--", "<->", "<-+", "+->", ""]
|
|
149
|
+
if np.any(np.isin(graph, allowed_edges) == False):
|
|
150
|
+
raise ValueError("Graph contains invalid graph edge. " +
|
|
151
|
+
"For graph_type = %s only %s are allowed." %(graph_type, str(allowed_edges)))
|
|
152
|
+
|
|
153
|
+
# Then tau_max is implicitely derived from
|
|
154
|
+
# the dimensions
|
|
155
|
+
self.graph = graph
|
|
156
|
+
# self.tau_max = graph.shape[2] - 1
|
|
157
|
+
|
|
158
|
+
if len(hidden_variables) > 0:
|
|
159
|
+
self.graph = self._get_latent_projection_graph() #, stationary=False)
|
|
160
|
+
self.graph_type = "tsg_admg"
|
|
161
|
+
else:
|
|
162
|
+
self.graph_type = graph_type
|
|
163
|
+
|
|
164
|
+
elif graph_type in ['stationary_dag', 'stationary_admg']:
|
|
165
|
+
# Currently only stationary_dag without hidden variables is supported
|
|
166
|
+
if graph.ndim != 3:
|
|
167
|
+
raise ValueError("stationary graph_type assumes graph.shape=(N, N, tau_max+1).")
|
|
168
|
+
|
|
169
|
+
allowed_edges = ["-->", "<--", "<->", "<-+", "+->", ""]
|
|
170
|
+
if np.any(np.isin(graph, allowed_edges) == False):
|
|
171
|
+
raise ValueError("Graph contains invalid graph edge. " +
|
|
172
|
+
"For graph_type = %s only %s are allowed." %(graph_type, str(allowed_edges)))
|
|
173
|
+
|
|
174
|
+
# # TODO: remove if theory for stationary ADMGs is clear
|
|
175
|
+
# if graph_type == 'stationary_dag' and len(hidden_variables) > 0:
|
|
176
|
+
# raise ValueError("Hidden variables currently not supported for "
|
|
177
|
+
# "stationary_dag.")
|
|
178
|
+
|
|
179
|
+
# For a stationary DAG without hidden variables it's sufficient to consider
|
|
180
|
+
# a tau_max that includes the parents of X, Y, M, and S. A conservative
|
|
181
|
+
# estimate thereof is simply the lag-dimension of the stationary DAG plus
|
|
182
|
+
# the maximum lag of XYS.
|
|
183
|
+
# statgraph_tau_max = graph.shape[2] - 1
|
|
184
|
+
# maxlag_XYS = 0
|
|
185
|
+
# for varlag in self.X.union(self.Y).union(self.S):
|
|
186
|
+
# maxlag_XYS = max(maxlag_XYS, abs(varlag[1]))
|
|
187
|
+
|
|
188
|
+
# self.tau_max = maxlag_XYS + statgraph_tau_max
|
|
189
|
+
|
|
190
|
+
stat_graph = deepcopy(graph)
|
|
191
|
+
|
|
192
|
+
#########################################
|
|
193
|
+
# Use this tau_max and construct ADMG by assuming paths of
|
|
194
|
+
# maximal lag 10*tau_max... TO BE REVISED!
|
|
195
|
+
self.graph = graph
|
|
196
|
+
self.graph = self._get_latent_projection_graph(stationary=True)
|
|
197
|
+
self.graph_type = "tsg_admg"
|
|
198
|
+
#########################################
|
|
199
|
+
|
|
200
|
+
# Also create stationary graph extended to tau_max
|
|
201
|
+
self.stationary_graph = np.zeros((self.N, self.N, self.tau_max + 1), dtype='<U3')
|
|
202
|
+
self.stationary_graph[:, :, :stat_graph.shape[2]] = stat_graph
|
|
203
|
+
|
|
204
|
+
# allowed_edges = ["-->", "<--"]
|
|
205
|
+
|
|
206
|
+
# # Construct tsg_graph
|
|
207
|
+
# graph = np.zeros((self.N, self.N, self.tau_max + 1, self.tau_max + 1), dtype='<U3')
|
|
208
|
+
# graph[:] = ""
|
|
209
|
+
# for (i, j) in itertools.product(range(self.N), range(self.N)):
|
|
210
|
+
# for jt, tauj in enumerate(range(0, self.tau_max + 1)):
|
|
211
|
+
# for it, taui in enumerate(range(tauj, self.tau_max + 1)):
|
|
212
|
+
# tau = abs(taui - tauj)
|
|
213
|
+
# if tau == 0 and j == i:
|
|
214
|
+
# continue
|
|
215
|
+
# if tau > statgraph_tau_max:
|
|
216
|
+
# continue
|
|
217
|
+
|
|
218
|
+
# # if tau == 0:
|
|
219
|
+
# # if stat_graph[i, j, tau] == '-->':
|
|
220
|
+
# # graph[i, j, taui, tauj] = "-->"
|
|
221
|
+
# # graph[j, i, tauj, taui] = "<--"
|
|
222
|
+
|
|
223
|
+
# # # elif stat_graph[i, j, tau] == '<--':
|
|
224
|
+
# # # graph[i, j, taui, tauj] = "<--"
|
|
225
|
+
# # # graph[j, i, tauj, taui] = "-->"
|
|
226
|
+
# # else:
|
|
227
|
+
# if stat_graph[i, j, tau] == '-->':
|
|
228
|
+
# graph[i, j, taui, tauj] = "-->"
|
|
229
|
+
# graph[j, i, tauj, taui] = "<--"
|
|
230
|
+
# elif stat_graph[i, j, tau] == '<--':
|
|
231
|
+
# pass
|
|
232
|
+
# elif stat_graph[i, j, tau] == '':
|
|
233
|
+
# pass
|
|
234
|
+
# else:
|
|
235
|
+
# edge = stat_graph[i, j, tau]
|
|
236
|
+
# raise ValueError("Invalid graph edge %s. " %(edge) +
|
|
237
|
+
# "For graph_type = %s only %s are allowed." %(graph_type, str(allowed_edges)))
|
|
238
|
+
|
|
239
|
+
# # elif stat_graph[i, j, tau] == '<--':
|
|
240
|
+
# # graph[i, j, taui, tauj] = "<--"
|
|
241
|
+
# # graph[j, i, tauj, taui] = "-->"
|
|
242
|
+
|
|
243
|
+
# self.graph_type = 'tsg_dag'
|
|
244
|
+
# self.graph = graph
|
|
245
|
+
|
|
246
|
+
|
|
247
|
+
# return (graph, graph_type, self.tau_max, hidden_variables)
|
|
248
|
+
|
|
249
|
+
# max_lag = self._get_maximum_possible_lag(XYZ=list(X.union(Y).union(S)), graph=graph)
|
|
250
|
+
|
|
251
|
+
# stat_mediators = self._get_mediators_stationary_graph(start=X, end=Y, max_lag=max_lag)
|
|
252
|
+
# self.tau_max = self._get_maximum_possible_lag(XYZ=list(X.union(Y).union(S).union(stat_mediators)), graph=graph)
|
|
253
|
+
# self.tau_max = graph_taumax
|
|
254
|
+
# for varlag in X.union(Y).union(S):
|
|
255
|
+
# self.tau_max = max(self.tau_max, abs(varlag[1]))
|
|
256
|
+
|
|
257
|
+
# if verbosity > 0:
|
|
258
|
+
# print("Setting tau_max = ", self.tau_max)
|
|
259
|
+
|
|
260
|
+
# if tau_max is None:
|
|
261
|
+
# self.tau_max = graph_taumax
|
|
262
|
+
# for varlag in X.union(Y).union(S):
|
|
263
|
+
# self.tau_max = max(self.tau_max, abs(varlag[1]))
|
|
264
|
+
|
|
265
|
+
# if verbosity > 0:
|
|
266
|
+
# print("Setting tau_max = ", self.tau_max)
|
|
267
|
+
# else:
|
|
268
|
+
# self.tau_max = graph_taumax
|
|
269
|
+
# # Repeat hidden variable pattern
|
|
270
|
+
# # if larger tau_max is given
|
|
271
|
+
# if self.tau_max > graph_taumax:
|
|
272
|
+
# for lag in range(graph_taumax + 1, self.tau_max + 1):
|
|
273
|
+
# for j in range(self.N):
|
|
274
|
+
# if (j, -(lag % (graph_taumax+1))) in self.hidden_variables:
|
|
275
|
+
# self.hidden_variables.add((j, -lag))
|
|
276
|
+
# print(self.hidden_variables)
|
|
277
|
+
|
|
278
|
+
# self.graph = self._get_latent_projection_graph(self.graph, stationary=True)
|
|
279
|
+
# self.graph_type = "tsg_admg"
|
|
280
|
+
# else:
|
|
281
|
+
|
|
282
|
+
def _check_graph(self, graph):
|
|
283
|
+
"""Checks that graph contains no invalid entries/structure.
|
|
284
|
+
|
|
285
|
+
Assumes graph.shape = (N, N, tau_max+1, tau_max+1)
|
|
286
|
+
"""
|
|
287
|
+
|
|
288
|
+
allowed_edges = ["-->", "<--"]
|
|
289
|
+
if 'admg' in self.graph_type:
|
|
290
|
+
allowed_edges += ["<->", "<-+", "+->"]
|
|
291
|
+
elif 'mag' in self.graph_type:
|
|
292
|
+
allowed_edges += ["<->"]
|
|
293
|
+
elif 'pag' in self.graph_type:
|
|
294
|
+
allowed_edges += ["<->", "o-o", "o->", "<-o"] # "o--",
|
|
295
|
+
# "--o",
|
|
296
|
+
# "x-o",
|
|
297
|
+
# "o-x",
|
|
298
|
+
# "x--",
|
|
299
|
+
# "--x",
|
|
300
|
+
# "x->",
|
|
301
|
+
# "<-x",
|
|
302
|
+
# "x-x",
|
|
303
|
+
# ]
|
|
304
|
+
|
|
305
|
+
graph_dict = defaultdict(list)
|
|
306
|
+
for i, j, taui, tauj in zip(*np.where(graph)):
|
|
307
|
+
edge = graph[i, j, taui, tauj]
|
|
308
|
+
# print((i, -taui), edge, (j, -tauj), graph[j, i, tauj, taui])
|
|
309
|
+
if edge != self._reverse_link(graph[j, i, tauj, taui]):
|
|
310
|
+
raise ValueError(
|
|
311
|
+
"graph needs to have consistent edges (eg"
|
|
312
|
+
" graph[i,j,taui,tauj]='-->' requires graph[j,i,tauj,taui]='<--')"
|
|
313
|
+
)
|
|
314
|
+
|
|
315
|
+
if edge not in allowed_edges:
|
|
316
|
+
raise ValueError("Invalid graph edge %s. " %(edge) +
|
|
317
|
+
"For graph_type = %s only %s are allowed." %(self.graph_type, str(allowed_edges)))
|
|
318
|
+
|
|
319
|
+
if edge == "-->" or edge == "+->":
|
|
320
|
+
# Map to (i,-taui, j, tauj) graph
|
|
321
|
+
indexi = i * (self.tau_max + 1) + taui
|
|
322
|
+
indexj = j * (self.tau_max + 1) + tauj
|
|
323
|
+
#dictionary containing all links
|
|
324
|
+
graph_dict[indexj].append(indexi)
|
|
325
|
+
|
|
326
|
+
# Check for cycles
|
|
327
|
+
if self._check_cyclic(graph_dict):
|
|
328
|
+
raise ValueError("graph is cyclic.")
|
|
329
|
+
|
|
330
|
+
if 'mag' in self.graph_type:
|
|
331
|
+
self._check_almost_cyclic(graph_dict)
|
|
332
|
+
# if PAG???
|
|
333
|
+
|
|
334
|
+
def _check_cyclic(self, graph_dict):
|
|
335
|
+
"""Return True if the graph_dict has a cycle.
|
|
336
|
+
|
|
337
|
+
graph_dict must be represented as a dictionary mapping vertices to
|
|
338
|
+
iterables of neighbouring vertices. For example:
|
|
339
|
+
|
|
340
|
+
>> cyclic({1: (2,), 2: (3,), 3: (1,)})
|
|
341
|
+
True
|
|
342
|
+
>> cyclic({1: (2,), 2: (3,), 3: (4,)})
|
|
343
|
+
False
|
|
344
|
+
"""
|
|
345
|
+
|
|
346
|
+
path = set()
|
|
347
|
+
visited = set()
|
|
348
|
+
|
|
349
|
+
def visit(vertex):
|
|
350
|
+
if vertex in visited:
|
|
351
|
+
return False
|
|
352
|
+
visited.add(vertex)
|
|
353
|
+
path.add(vertex)
|
|
354
|
+
for neighbour in graph_dict.get(vertex, ()):
|
|
355
|
+
if neighbour in path or visit(neighbour):
|
|
356
|
+
return True
|
|
357
|
+
path.remove(vertex)
|
|
358
|
+
return False
|
|
359
|
+
|
|
360
|
+
return any(visit(v) for v in graph_dict)
|
|
361
|
+
|
|
362
|
+
def _check_almost_cyclic(self, parents_dict):
|
|
363
|
+
"""Check for almost-cycles in MAG.
|
|
364
|
+
|
|
365
|
+
An almost-cycle is a directed cycle with one edge replaced by a bidirected edge. To check
|
|
366
|
+
that no almost-cycles are present in the MAG, we can check for each bidirected edge that
|
|
367
|
+
no endpoint is an ancestor of the other endpoint.
|
|
368
|
+
Since there are no cycles in the graph containing only the directed edges, we first calculate
|
|
369
|
+
a (reversed) topological ordering of the nodes considering only the directed edges, and then it is sufficient
|
|
370
|
+
to check for each bidirected edge that the endpoint with the lower order is not an ancestor
|
|
371
|
+
of the other endpoint using DFS on the nodes inbetween in the topological ordering.
|
|
372
|
+
"""
|
|
373
|
+
# Count incoming edges
|
|
374
|
+
child_count = { i: 0 for i in range(self.N * (self.tau_max + 1))}
|
|
375
|
+
for indexi, parents in parents_dict.items():
|
|
376
|
+
for par in parents:
|
|
377
|
+
child_count[par] +=1
|
|
378
|
+
# Get topological ordering of nodes considering only directed edges
|
|
379
|
+
ordering = []
|
|
380
|
+
while child_count:
|
|
381
|
+
for index, count in child_count.items():
|
|
382
|
+
if count == 0:
|
|
383
|
+
for par in parents_dict[index]:
|
|
384
|
+
child_count[par] -= 1
|
|
385
|
+
del child_count[index]
|
|
386
|
+
ordering.append(index)
|
|
387
|
+
break
|
|
388
|
+
order_index = {node: index for index, node in enumerate(ordering)}
|
|
389
|
+
# Check for almost-cycles
|
|
390
|
+
for i, j, taui, tauj in zip(*np.where(self.graph)):
|
|
391
|
+
edge = self.graph[i, j, taui, tauj]
|
|
392
|
+
if edge == "<->" and (i * (self.tau_max + 1) + taui < j * (self.tau_max + 1) + tauj):
|
|
393
|
+
indexi = i * (self.tau_max + 1) + taui
|
|
394
|
+
indexj = j * (self.tau_max + 1) + tauj
|
|
395
|
+
|
|
396
|
+
if order_index[indexi] < order_index[indexj]:
|
|
397
|
+
higher = indexj
|
|
398
|
+
lower = indexi
|
|
399
|
+
else:
|
|
400
|
+
higher = indexi
|
|
401
|
+
lower = indexj
|
|
402
|
+
# DFS from higher to check if lower is reachable
|
|
403
|
+
stack = [lower]
|
|
404
|
+
visited_dfs = set()
|
|
405
|
+
while stack:
|
|
406
|
+
current = stack.pop()
|
|
407
|
+
for par in parents_dict[current]:
|
|
408
|
+
if par == higher:
|
|
409
|
+
raise ValueError("Graph contains an almost-cycle.")
|
|
410
|
+
elif par not in visited_dfs and order_index[lower] < order_index[par] < order_index[higher]:
|
|
411
|
+
stack.append(par)
|
|
412
|
+
|
|
413
|
+
def get_mediators(self, start, end):
|
|
414
|
+
"""Returns mediator variables on proper causal paths.
|
|
415
|
+
|
|
416
|
+
Parameters
|
|
417
|
+
----------
|
|
418
|
+
start : set
|
|
419
|
+
Set of start nodes.
|
|
420
|
+
end : set
|
|
421
|
+
Set of end nodes.
|
|
422
|
+
|
|
423
|
+
Returns
|
|
424
|
+
-------
|
|
425
|
+
mediators : set
|
|
426
|
+
Mediators on causal paths from start to end.
|
|
427
|
+
"""
|
|
428
|
+
|
|
429
|
+
des_X = self._get_descendants(start)
|
|
430
|
+
|
|
431
|
+
mediators = set()
|
|
432
|
+
|
|
433
|
+
# Walk along proper causal paths backwards from Y to X
|
|
434
|
+
# potential_mediators = set()
|
|
435
|
+
for y in end:
|
|
436
|
+
j, tau = y
|
|
437
|
+
this_level = [y]
|
|
438
|
+
while len(this_level) > 0:
|
|
439
|
+
next_level = []
|
|
440
|
+
for varlag in this_level:
|
|
441
|
+
for parent in self._get_parents(varlag):
|
|
442
|
+
i, tau = parent
|
|
443
|
+
# print(varlag, parent, des_X)
|
|
444
|
+
if (parent in des_X
|
|
445
|
+
and parent not in mediators
|
|
446
|
+
# and parent not in potential_mediators
|
|
447
|
+
and parent not in start
|
|
448
|
+
and parent not in end
|
|
449
|
+
and (-self.tau_max <= tau <= 0)): # or self.ignore_time_bounds)):
|
|
450
|
+
mediators.add(parent)
|
|
451
|
+
next_level.append(parent)
|
|
452
|
+
|
|
453
|
+
this_level = next_level
|
|
454
|
+
|
|
455
|
+
return mediators
|
|
456
|
+
|
|
457
|
+
def _get_mediators_stationary_graph(self, start, end, max_lag):
|
|
458
|
+
"""Returns mediator variables on proper causal paths
|
|
459
|
+
from X to Y in a stationary graph."""
|
|
460
|
+
|
|
461
|
+
des_X = self._get_descendants_stationary_graph(start, max_lag)
|
|
462
|
+
|
|
463
|
+
mediators = set()
|
|
464
|
+
|
|
465
|
+
# Walk along proper causal paths backwards from Y to X
|
|
466
|
+
potential_mediators = set()
|
|
467
|
+
for y in end:
|
|
468
|
+
j, tau = y
|
|
469
|
+
this_level = [y]
|
|
470
|
+
while len(this_level) > 0:
|
|
471
|
+
next_level = []
|
|
472
|
+
for varlag in this_level:
|
|
473
|
+
for _, parent in self._get_adjacents_stationary_graph(graph=self.graph,
|
|
474
|
+
node=varlag, patterns=["<*-", "<*+"], max_lag=max_lag, exclude=None):
|
|
475
|
+
i, tau = parent
|
|
476
|
+
if (parent in des_X
|
|
477
|
+
and parent not in mediators
|
|
478
|
+
# and parent not in potential_mediators
|
|
479
|
+
and parent not in start
|
|
480
|
+
and parent not in end
|
|
481
|
+
# and (-self.tau_max <= tau <= 0 or self.ignore_time_bounds)
|
|
482
|
+
):
|
|
483
|
+
mediators.add(parent)
|
|
484
|
+
next_level.append(parent)
|
|
485
|
+
|
|
486
|
+
this_level = next_level
|
|
487
|
+
|
|
488
|
+
return mediators
|
|
489
|
+
|
|
490
|
+
def _reverse_link(self, link):
|
|
491
|
+
"""Reverse a given link, taking care to replace > with < and vice versa."""
|
|
492
|
+
|
|
493
|
+
if link == "":
|
|
494
|
+
return ""
|
|
495
|
+
|
|
496
|
+
if link[2] == ">":
|
|
497
|
+
left_mark = "<"
|
|
498
|
+
else:
|
|
499
|
+
left_mark = link[2]
|
|
500
|
+
|
|
501
|
+
if link[0] == "<":
|
|
502
|
+
right_mark = ">"
|
|
503
|
+
else:
|
|
504
|
+
right_mark = link[0]
|
|
505
|
+
|
|
506
|
+
return left_mark + link[1] + right_mark
|
|
507
|
+
|
|
508
|
+
def _match_link(self, pattern, link):
|
|
509
|
+
"""Matches pattern including wildcards with link.
|
|
510
|
+
|
|
511
|
+
In an ADMG we have edge types ["-->", "<--", "<->", "+->", "<-+"].
|
|
512
|
+
Here +-> corresponds to having both "-->" and "<->".
|
|
513
|
+
|
|
514
|
+
In a MAG we have edge types ["-->", "<--", "<->", "---"].
|
|
515
|
+
"""
|
|
516
|
+
|
|
517
|
+
if pattern == '' or link == '':
|
|
518
|
+
return True if pattern == link else False
|
|
519
|
+
else:
|
|
520
|
+
left_mark, middle_mark, right_mark = pattern
|
|
521
|
+
if left_mark != '*':
|
|
522
|
+
# if link[0] != '+':
|
|
523
|
+
if link[0] != left_mark: return False
|
|
524
|
+
|
|
525
|
+
if right_mark != '*':
|
|
526
|
+
# if link[2] != '+':
|
|
527
|
+
if link[2] != right_mark: return False
|
|
528
|
+
|
|
529
|
+
if middle_mark != '*' and link[1] != middle_mark: return False
|
|
530
|
+
|
|
531
|
+
return True
|
|
532
|
+
|
|
533
|
+
def _find_adj(self, node, patterns, exclude=None, return_link=False):
|
|
534
|
+
"""Find adjacencies of node that match given patterns."""
|
|
535
|
+
|
|
536
|
+
graph = self.graph
|
|
537
|
+
|
|
538
|
+
if exclude is None:
|
|
539
|
+
exclude = []
|
|
540
|
+
# exclude = self.hidden_variables
|
|
541
|
+
# else:
|
|
542
|
+
# exclude = set(exclude).union(self.hidden_variables)
|
|
543
|
+
|
|
544
|
+
# Setup
|
|
545
|
+
i, lag_i = node
|
|
546
|
+
lag_i = abs(lag_i)
|
|
547
|
+
|
|
548
|
+
if exclude is None: exclude = []
|
|
549
|
+
if type(patterns) == str:
|
|
550
|
+
patterns = [patterns]
|
|
551
|
+
|
|
552
|
+
# Init
|
|
553
|
+
adj = []
|
|
554
|
+
# Find adjacencies going forward/contemp
|
|
555
|
+
for k, lag_ik in zip(*np.where(graph[i,:,lag_i,:])):
|
|
556
|
+
# print((k, lag_ik), graph[i,k,lag_i,lag_ik])
|
|
557
|
+
# matches = [self._match_link(patt, graph[i,k,lag_i,lag_ik]) for patt in patterns]
|
|
558
|
+
# if np.any(matches):
|
|
559
|
+
for patt in patterns:
|
|
560
|
+
if self._match_link(patt, graph[i,k,lag_i,lag_ik]):
|
|
561
|
+
match = (k, -lag_ik)
|
|
562
|
+
if match not in exclude:
|
|
563
|
+
if return_link:
|
|
564
|
+
adj.append((graph[i,k,lag_i,lag_ik], match))
|
|
565
|
+
else:
|
|
566
|
+
adj.append(match)
|
|
567
|
+
break
|
|
568
|
+
|
|
569
|
+
|
|
570
|
+
# Find adjacencies going backward/contemp
|
|
571
|
+
for k, lag_ki in zip(*np.where(graph[:,i,:,lag_i])):
|
|
572
|
+
# print((k, lag_ki), graph[k,i,lag_ki,lag_i])
|
|
573
|
+
# matches = [self._match_link(self._reverse_link(patt), graph[k,i,lag_ki,lag_i]) for patt in patterns]
|
|
574
|
+
# if np.any(matches):
|
|
575
|
+
for patt in patterns:
|
|
576
|
+
if self._match_link(self._reverse_link(patt), graph[k,i,lag_ki,lag_i]):
|
|
577
|
+
match = (k, -lag_ki)
|
|
578
|
+
if match not in exclude:
|
|
579
|
+
if return_link:
|
|
580
|
+
adj.append((self._reverse_link(graph[k,i,lag_ki,lag_i]), match))
|
|
581
|
+
else:
|
|
582
|
+
adj.append(match)
|
|
583
|
+
break
|
|
584
|
+
|
|
585
|
+
adj = list(set(adj))
|
|
586
|
+
return adj
|
|
587
|
+
|
|
588
|
+
def _is_match(self, nodei, nodej, pattern_ij):
|
|
589
|
+
"""Check whether the link between X and Y agrees with pattern."""
|
|
590
|
+
|
|
591
|
+
graph = self.graph
|
|
592
|
+
|
|
593
|
+
(i, lag_i) = nodei
|
|
594
|
+
(j, lag_j) = nodej
|
|
595
|
+
tauij = lag_j - lag_i
|
|
596
|
+
if abs(tauij) >= graph.shape[2]:
|
|
597
|
+
return False
|
|
598
|
+
return ((tauij >= 0 and self._match_link(pattern_ij, graph[i, j, tauij])) or
|
|
599
|
+
(tauij < 0 and self._match_link(self._reverse_link(pattern_ij), graph[j, i, abs(tauij)])))
|
|
600
|
+
|
|
601
|
+
def _get_children(self, varlag):
|
|
602
|
+
"""Returns set of children (varlag --> ...) for (lagged) varlag."""
|
|
603
|
+
if self.possible:
|
|
604
|
+
patterns=['-*>', 'o*o', 'o*>']
|
|
605
|
+
else:
|
|
606
|
+
patterns=['-*>', '+*>']
|
|
607
|
+
return self._find_adj(node=varlag, patterns=patterns)
|
|
608
|
+
|
|
609
|
+
def _get_parents(self, varlag):
|
|
610
|
+
"""Returns set of parents (varlag <-- ...) for (lagged) varlag."""
|
|
611
|
+
if self.possible:
|
|
612
|
+
patterns=['<*-', 'o*o', '<*o']
|
|
613
|
+
else:
|
|
614
|
+
patterns=['<*-', '<*+']
|
|
615
|
+
return self._find_adj(node=varlag, patterns=patterns)
|
|
616
|
+
|
|
617
|
+
def _get_spouses(self, varlag):
|
|
618
|
+
"""Returns set of spouses (varlag <-> ...) for (lagged) varlag."""
|
|
619
|
+
return self._find_adj(node=varlag, patterns=['<*>', '+*>', '<*+'])
|
|
620
|
+
|
|
621
|
+
def _get_neighbors(self, varlag):
|
|
622
|
+
"""Returns set of neighbors (varlag --- ...) for (lagged) varlag."""
|
|
623
|
+
return self._find_adj(node=varlag, patterns=['-*-'])
|
|
624
|
+
|
|
625
|
+
def _get_ancestors(self, W):
|
|
626
|
+
"""Get ancestors of nodes in W up to time tau_max.
|
|
627
|
+
|
|
628
|
+
Includes the nodes themselves.
|
|
629
|
+
"""
|
|
630
|
+
|
|
631
|
+
ancestors = set(W)
|
|
632
|
+
|
|
633
|
+
for w in W:
|
|
634
|
+
j, tau = w
|
|
635
|
+
this_level = [w]
|
|
636
|
+
while len(this_level) > 0:
|
|
637
|
+
next_level = []
|
|
638
|
+
for varlag in this_level:
|
|
639
|
+
|
|
640
|
+
for par in self._get_parents(varlag):
|
|
641
|
+
i, tau = par
|
|
642
|
+
if par not in ancestors and -self.tau_max <= tau <= 0:
|
|
643
|
+
ancestors.add(par)
|
|
644
|
+
next_level.append(par)
|
|
645
|
+
|
|
646
|
+
this_level = next_level
|
|
647
|
+
|
|
648
|
+
return ancestors
|
|
649
|
+
|
|
650
|
+
def _get_all_parents(self, W):
|
|
651
|
+
"""Get parents of nodes in W up to time tau_max.
|
|
652
|
+
|
|
653
|
+
Includes the nodes themselves.
|
|
654
|
+
"""
|
|
655
|
+
|
|
656
|
+
parents = set(W)
|
|
657
|
+
|
|
658
|
+
for w in W:
|
|
659
|
+
j, tau = w
|
|
660
|
+
for par in self._get_parents(w):
|
|
661
|
+
i, tau = par
|
|
662
|
+
if par not in parents and -self.tau_max <= tau <= 0:
|
|
663
|
+
parents.add(par)
|
|
664
|
+
|
|
665
|
+
return parents
|
|
666
|
+
|
|
667
|
+
def _get_all_spouses(self, W):
|
|
668
|
+
"""Get spouses of nodes in W up to time tau_max.
|
|
669
|
+
|
|
670
|
+
Includes the nodes themselves.
|
|
671
|
+
"""
|
|
672
|
+
|
|
673
|
+
spouses = set(W)
|
|
674
|
+
|
|
675
|
+
for w in W:
|
|
676
|
+
j, tau = w
|
|
677
|
+
for spouse in self._get_spouses(w):
|
|
678
|
+
i, tau = spouse
|
|
679
|
+
if spouse not in spouses and -self.tau_max <= tau <= 0:
|
|
680
|
+
spouses.add(spouse)
|
|
681
|
+
|
|
682
|
+
return spouses
|
|
683
|
+
|
|
684
|
+
def _get_descendants_stationary_graph(self, W, max_lag):
|
|
685
|
+
"""Get descendants of nodes in W up to time t in stationary graph.
|
|
686
|
+
|
|
687
|
+
Includes the nodes themselves.
|
|
688
|
+
"""
|
|
689
|
+
|
|
690
|
+
descendants = set(W)
|
|
691
|
+
|
|
692
|
+
for w in W:
|
|
693
|
+
j, tau = w
|
|
694
|
+
this_level = [w]
|
|
695
|
+
while len(this_level) > 0:
|
|
696
|
+
next_level = []
|
|
697
|
+
for varlag in this_level:
|
|
698
|
+
for _, child in self._get_adjacents_stationary_graph(graph=self.graph,
|
|
699
|
+
node=varlag, patterns=["-*>", "-*+"], max_lag=max_lag, exclude=None):
|
|
700
|
+
i, tau = child
|
|
701
|
+
if (child not in descendants
|
|
702
|
+
# and (-self.tau_max <= tau <= 0 or self.ignore_time_bounds)
|
|
703
|
+
):
|
|
704
|
+
descendants.add(child)
|
|
705
|
+
next_level.append(child)
|
|
706
|
+
|
|
707
|
+
this_level = next_level
|
|
708
|
+
|
|
709
|
+
return descendants
|
|
710
|
+
|
|
711
|
+
def _get_descendants(self, W):
|
|
712
|
+
"""Get descendants of nodes in W up to time t.
|
|
713
|
+
|
|
714
|
+
Includes the nodes themselves.
|
|
715
|
+
"""
|
|
716
|
+
|
|
717
|
+
descendants = set(W)
|
|
718
|
+
|
|
719
|
+
for w in W:
|
|
720
|
+
j, tau = w
|
|
721
|
+
this_level = [w]
|
|
722
|
+
while len(this_level) > 0:
|
|
723
|
+
next_level = []
|
|
724
|
+
for varlag in this_level:
|
|
725
|
+
for child in self._get_children(varlag):
|
|
726
|
+
i, tau = child
|
|
727
|
+
if (child not in descendants
|
|
728
|
+
and (-self.tau_max <= tau <= 0)): # or self.ignore_time_bounds)):
|
|
729
|
+
descendants.add(child)
|
|
730
|
+
next_level.append(child)
|
|
731
|
+
|
|
732
|
+
this_level = next_level
|
|
733
|
+
|
|
734
|
+
return descendants
|
|
735
|
+
|
|
736
|
+
def _get_collider_path_nodes(self, start_nodes, mediators, with_parents = True):
|
|
737
|
+
"""Returns the set of all nodes from collider path i.e. paths consisting of bidirected edges only
|
|
738
|
+
a node in start_nodes that only contains nodes in mediators and their parents up to maximum time lag.
|
|
739
|
+
|
|
740
|
+
The function recognizes only collider paths of at least length 1.
|
|
741
|
+
|
|
742
|
+
Parameters
|
|
743
|
+
----------
|
|
744
|
+
start_nodes : set or list of nodes
|
|
745
|
+
The set of nodes that a collider path may start in.
|
|
746
|
+
mediators : set or list of nodes
|
|
747
|
+
All nodes on the Path that are not start_nodes have to be contained in this set.
|
|
748
|
+
with_parents : bool
|
|
749
|
+
If set to false it will only return the collider path nodes, without its parents.
|
|
750
|
+
"""
|
|
751
|
+
|
|
752
|
+
collider_path_nodes = set()
|
|
753
|
+
# print("mediators ", mediators)
|
|
754
|
+
for w in start_nodes:
|
|
755
|
+
# print(w)
|
|
756
|
+
this_level = [w]
|
|
757
|
+
while len(this_level) > 0:
|
|
758
|
+
next_level = []
|
|
759
|
+
for varlag in this_level:
|
|
760
|
+
# print("\t", varlag, self._get_spouses(varlag))
|
|
761
|
+
for spouse in self._get_spouses(varlag):
|
|
762
|
+
# print("\t\t", spouse)
|
|
763
|
+
_, tau = spouse
|
|
764
|
+
if (spouse not in collider_path_nodes
|
|
765
|
+
and spouse in mediators
|
|
766
|
+
and (-self.tau_max <= tau <= 0)): # or self.ignore_time_bounds)):
|
|
767
|
+
collider_path_nodes.add(spouse)
|
|
768
|
+
next_level.append(spouse)
|
|
769
|
+
|
|
770
|
+
this_level = next_level
|
|
771
|
+
|
|
772
|
+
# Add parents
|
|
773
|
+
if with_parents:
|
|
774
|
+
for par in self._get_all_parents(collider_path_nodes):
|
|
775
|
+
_, tau = par
|
|
776
|
+
if (par not in collider_path_nodes
|
|
777
|
+
and par in mediators
|
|
778
|
+
and (-self.tau_max <= tau <= 0)): # or self.ignore_time_bounds)):
|
|
779
|
+
collider_path_nodes = collider_path_nodes.union(set([par]))
|
|
780
|
+
|
|
781
|
+
return collider_path_nodes
|
|
782
|
+
|
|
783
|
+
def _get_adjacents_stationary_graph(self, graph, node, patterns,
                                    max_lag=0, exclude=None):
    """Find adjacencies of node matching patterns in a stationary graph.

    Parameters
    ----------
    graph : array
        String graph indexed as ``graph[i, k, lag]``; empty strings mark
        absent links.
    node : tuple
        Node ``(i, lag_i)`` whose adjacencies are searched.
    patterns : str or list of str
        Link pattern(s) passed to ``self._match_link``. A single string
        (including ``str`` subclasses such as ``numpy.str_``) is treated
        as a one-element list.
    max_lag : int
        Only neighbors with lag in ``[-max_lag, 0]`` are returned.
    exclude : list of nodes or None
        Neighbors contained in this collection are skipped.

    Returns
    -------
    adj : list of tuples
        Deduplicated ``(link, (k, lag_k))`` pairs, where link is written
        from the perspective of node. The order is that of a set and
        therefore not meaningful.
    """
    i, lag_i = node
    if exclude is None:
        exclude = []
    # isinstance (rather than an exact type comparison) so that str
    # subclasses are not iterated character by character below.
    if isinstance(patterns, str):
        patterns = [patterns]

    adj = set()

    # Entries graph[i, k, lag_ik]: the neighbor is (k, lag_i + lag_ik)
    for k, lag_ik in zip(*np.where(graph[i, :, :])):
        link = graph[i, k, lag_ik]
        if any(self._match_link(patt, link) for patt in patterns):
            match = (k, lag_i + lag_ik)
            if match not in exclude and (-max_lag <= match[1] <= 0):
                adj.add((link, match))

    # Entries graph[k, i, lag_ki]: the neighbor is (k, lag_i - lag_ki);
    # patterns and the stored link are reversed so that the returned
    # link again reads from node towards the neighbor.
    for k, lag_ki in zip(*np.where(graph[:, i, :])):
        link = graph[k, i, lag_ki]
        if any(self._match_link(self._reverse_link(patt), link)
               for patt in patterns):
            match = (k, lag_i - lag_ki)
            if match not in exclude and (-max_lag <= match[1] <= 0):
                adj.add((self._reverse_link(link), match))

    return list(adj)
|
|
816
|
+
|
|
817
|
+
def _get_canonical_dag_from_graph(self, graph):
|
|
818
|
+
"""Constructs canonical DAG as links_coeffs dictionary from graph.
|
|
819
|
+
|
|
820
|
+
For every <-> link further latent variables are added.
|
|
821
|
+
This corresponds to a canonical DAG (Richardson Spirtes 2002).
|
|
822
|
+
|
|
823
|
+
Can be used to evaluate d-separation.
|
|
824
|
+
"""
|
|
825
|
+
|
|
826
|
+
N, N, tau_maxplusone = graph.shape
|
|
827
|
+
tau_max = tau_maxplusone - 1
|
|
828
|
+
|
|
829
|
+
links = {j: [] for j in range(N)}
|
|
830
|
+
|
|
831
|
+
# Add further latent variables to accommodate <-> links
|
|
832
|
+
latent_index = N
|
|
833
|
+
for i, j, tau in zip(*np.where(graph)):
|
|
834
|
+
|
|
835
|
+
edge_type = graph[i, j, tau]
|
|
836
|
+
|
|
837
|
+
# Consider contemporaneous links only once
|
|
838
|
+
if tau == 0 and j > i:
|
|
839
|
+
continue
|
|
840
|
+
|
|
841
|
+
if edge_type == "-->":
|
|
842
|
+
links[j].append((i, -tau))
|
|
843
|
+
elif edge_type == "<--":
|
|
844
|
+
links[i].append((j, -tau))
|
|
845
|
+
elif edge_type == "<->":
|
|
846
|
+
links[latent_index] = []
|
|
847
|
+
links[i].append((latent_index, 0))
|
|
848
|
+
links[j].append((latent_index, -tau))
|
|
849
|
+
latent_index += 1
|
|
850
|
+
# elif edge_type == "---":
|
|
851
|
+
# links[latent_index] = []
|
|
852
|
+
# selection_vars.append(latent_index)
|
|
853
|
+
# links[latent_index].append((i, -tau))
|
|
854
|
+
# links[latent_index].append((j, 0))
|
|
855
|
+
# latent_index += 1
|
|
856
|
+
elif edge_type == "+->":
|
|
857
|
+
links[j].append((i, -tau))
|
|
858
|
+
links[latent_index] = []
|
|
859
|
+
links[i].append((latent_index, 0))
|
|
860
|
+
links[j].append((latent_index, -tau))
|
|
861
|
+
latent_index += 1
|
|
862
|
+
elif edge_type == "<-+":
|
|
863
|
+
links[i].append((j, -tau))
|
|
864
|
+
links[latent_index] = []
|
|
865
|
+
links[i].append((latent_index, 0))
|
|
866
|
+
links[j].append((latent_index, -tau))
|
|
867
|
+
latent_index += 1
|
|
868
|
+
|
|
869
|
+
return links
|
|
870
|
+
|
|
871
|
+
|
|
872
|
+
def _get_maximum_possible_lag(self, XYZ, graph):
    """Construct maximum relevant time lag for d-separation in stationary graph.

    Walks backwards through the parents given by the canonical DAG of
    graph, starting from every node in XYZ, and records the largest
    absolute lag encountered. A branch is abandoned once the link that
    would be added (or a time-shifted copy of it) already occurs on the
    current path.

    TO BE REVISED!

    Parameters
    ----------
    XYZ : iterable of nodes
        (variable, lag) tuples with lag <= 0.
    graph : array of shape (N, N, tau_max+1)
        String graph handed to ``self._get_canonical_dag_from_graph``.

    Returns
    -------
    max_lag : int
        Largest absolute lag found.
    """

    def _repeating(link, seen_path):
        """Returns True if a link or its time-shifted version is already
        contained in seen_path (consecutive entries form the links)."""
        (i, taui), (j, tauj) = link
        lag_diff = abs(tauj - taui)
        for (seen_i, seen_taui), (seen_j, seen_tauj) in zip(seen_path[:-1],
                                                            seen_path[1:]):
            if (i == seen_i and j == seen_j
                    and lag_diff == abs(seen_tauj - seen_taui)):
                return True
        return False

    canonical_dag_links = self._get_canonical_dag_from_graph(graph)

    max_lag = 0
    for node in XYZ:
        _, node_lag = node   # node_lag <= 0
        max_lag = max(max_lag, abs(node_lag))

        # Depth-first traversal; every stack entry carries the path that
        # leads from the popped node down to the starting node.
        stack = [(node, [])]
        while stack:
            varlag, tail = stack.pop()
            causal_path = [varlag] + tail
            var, lag = varlag

            for par_var, par_shift in canonical_dag_links[var]:
                # canonical_dag_links is given relative to t=0: shift it
                par_lag = par_shift + lag
                par = (par_var, par_lag)

                if par in causal_path:
                    continue

                if len(causal_path) == 1:
                    # First step away from the start node: always follow
                    stack.append((par, causal_path))
                elif not _repeating((par, varlag), causal_path):
                    max_lag = max(max_lag, abs(par_lag))
                    stack.append((par, causal_path))

    return max_lag
|
|
934
|
+
|
|
935
|
+
def _get_latent_projection_graph(self, stationary=False):
    """For DAGs/ADMGs uses the Latent projection operation (Pearl 2009).

    Assumes a normal or stationary graph with potentially unobserved nodes.
    Also allows particular time steps to be unobserved. By stationarity
    that pattern of unobserved nodes is repeated into -infinity.

    Latent projection operation for latents = nodes before t-tau_max or
    due to <->:
    (i)  auxADMG contains (i, -taui) --> (j, -tauj) iff there is a
         directed path (i, -taui) --> ... --> (j, -tauj) on which every
         non-endpoint vertex is hidden (= not in observed_vars).
    (ii) auxADMG contains (i, -taui) <-> (j, -tauj) iff there exists a
         path of the form (i, -taui) <-- ... --> (j, -tauj) on which
         every non-endpoint vertex is a non-collider AND hidden.

    Both conditions are evaluated with ``self._check_path``; if (i) and
    (ii) hold simultaneously the combined marks '+->' / '<-+' are used.

    Parameters
    ----------
    stationary : bool
        Passed to ``_check_path`` as ``stationary_graph``. If True,
        condition (ii) additionally treats nodes before t-tau_max as
        hidden (``hidden_by_taumax=True``).

    Returns
    -------
    aux_graph : array of shape (N, N, tau_max+1, tau_max+1)
        String graph with entry ``[i, j, taui, tauj]`` describing the
        link between (i, -taui) and (j, -tauj); '' means no link.

    Raises
    ------
    ValueError
        If directed paths exist in both directions between two nodes.
    """
    hidden_variables_here = self.hidden_variables

    aux_graph = np.zeros((self.N, self.N, self.tau_max + 1, self.tau_max + 1),
                         dtype='<U3')
    aux_graph[:] = ""

    # Nodes beyond tau_max only count as hidden in the stationary case
    hidden_by_taumax_here = bool(stationary)

    # (directed i->j, directed j->i, confounded) -> (mark at [i, j], mark at [j, i])
    marks = {
        (True, False, False): ("-->", "<--"),
        (False, True, False): ("<--", "-->"),
        (False, False, True): ("<->", "<->"),
        (True, False, True): ("+->", "<-+"),
        (False, True, True): ("<-+", "+->"),
    }

    lags = range(0, self.tau_max + 1)
    for (i, j) in itertools.product(range(self.N), range(self.N)):
        for tauj in lags:
            for taui in lags:
                # No link of a node with itself
                if taui == tauj and j == i:
                    continue
                # Links are only drawn between observed nodes
                if ((i, -taui) in hidden_variables_here
                        or (j, -tauj) in hidden_variables_here):
                    continue

                node_i = (i, -taui)
                node_j = (j, -tauj)

                # Condition (i): directed path node_i --> ... --> node_j
                cond_i_xy = self._check_path(
                    start=[node_i],
                    end=[node_j],
                    conditions=None,
                    starts_with=['-*>', '+*>'],
                    ends_with=['-*>', '+*>'],
                    path_type='causal',
                    hidden_by_taumax=False,
                    hidden_variables=hidden_variables_here,
                    stationary_graph=stationary,
                )
                # Condition (i) in the reverse direction
                cond_i_yx = self._check_path(
                    start=[node_j],
                    end=[node_i],
                    conditions=None,
                    starts_with=['-*>', '+*>'],
                    ends_with=['-*>', '+*>'],
                    path_type='causal',
                    hidden_by_taumax=False,
                    hidden_variables=hidden_variables_here,
                    stationary_graph=stationary,
                )
                # Condition (ii): path node_i <-- ... --> node_j
                cond_ii = self._check_path(
                    start=[node_i],
                    end=[node_j],
                    conditions=None,
                    starts_with=['<**', '+**'],
                    ends_with=['**>', '**+'],
                    path_type='any',
                    hidden_by_taumax=hidden_by_taumax_here,
                    hidden_variables=hidden_variables_here,
                    stationary_graph=stationary,
                )

                if cond_i_xy and cond_i_yx:
                    # The nodes are tuples and have to be wrapped before
                    # being passed to str(); str(i, -taui) would raise
                    # a TypeError instead of the intended ValueError.
                    raise ValueError("Cycle between %s and %s!"
                                     % (str(node_i), str(node_j)))

                key = (bool(cond_i_xy), bool(cond_i_yx), bool(cond_ii))
                if key in marks:
                    mark_ij, mark_ji = marks[key]
                    aux_graph[i, j, taui, tauj] = mark_ij
                    aux_graph[j, i, tauj, taui] = mark_ji

    return aux_graph
|
|
1056
|
+
|
|
1057
|
+
def _check_path(self,
    start, end,
    conditions=None,
    starts_with=None,
    ends_with=None,
    path_type='any',
    stationary_graph=False,
    hidden_by_taumax=False,
    hidden_variables=None,
    ):
    """Check whether an open/active path between start and end given conditions exists.

    Also allows to restrict start and end patterns and to consider
    causal/non-causal paths. The search is a breadth-first traversal
    over (link, node) pairs that checks every motif i *-* k *-* j for
    being open.

    hidden_by_taumax and hidden_variables are relevant for the latent
    projection operation.

    Parameters
    ----------
    start, end : iterable of nodes
        (variable, lag) tuples between which a path is searched.
    conditions : iterable of nodes or None
        Conditioning set; None is treated as empty.
    starts_with : str, list of str or None
        Pattern(s) the first link, read from the start node, has to
        match. None means '***'.
    ends_with : str, list of str or None
        Pattern(s) the last link, read towards the end node, has to
        match. None means '***'.
    path_type : str
        'any', 'causal' or 'non_causal'. For 'causal' the first neighbor
        has to be a causal child (mediators plus end) and all further
        links have to match '-*>' or '+*>'. For 'non_causal' a first link
        into a causal child that matches '-*>' but not '+*>' is skipped.
    stationary_graph : bool
        If True, adjacencies are taken from
        ``self._get_adjacents_stationary_graph`` with a maximal lag of
        ``10*self.tau_max``; otherwise from ``self._find_adj``.
    hidden_by_taumax : bool
        If True, all nodes with lags from tau_max+1 to the maximal lag
        are added to hidden_variables. The maximal lag is only defined
        for stationary_graph=True; with stationary_graph=False it is
        None and constructing the lag range raises a TypeError.
    hidden_variables : iterable of nodes or None
        If not None, only nodes contained in it (or in end) may be
        visited.

    Returns
    -------
    bool
        True if a connecting path was found, False otherwise.
    """

    if conditions is None:
        conditions = set()

    start = set(start)
    end = set(end)
    conditions = set(conditions)

    # Get maximal possible time lag of a connecting path
    # See Thm. XXXX - TO BE REVISED!
    if stationary_graph:
        max_lag = 10*self.tau_max   # TO BE REVISED! self._get_maximum_possible_lag(XYZ, self.graph)
        causal_children = self._get_mediators_stationary_graph(start, end, max_lag).union(end)
    else:
        max_lag = None
        causal_children = self.get_mediators(start, end).union(end)

    if hidden_by_taumax:
        if hidden_variables is None:
            hidden_variables = set()
        hidden_variables = set(hidden_variables).union(
            [(k, -tauk) for k in range(self.N)
             for tauk in range(self.tau_max+1, max_lag + 1)])

    # isinstance also covers str subclasses such as numpy.str_
    if starts_with is None:
        starts_with = ['***']
    elif isinstance(starts_with, str):
        starts_with = [starts_with]

    if ends_with is None:
        ends_with = ['***']
    elif isinstance(ends_with, str):
        ends_with = [ends_with]

    #
    # Breadth-first search to find connection
    #
    start_from = deque()
    for x in start:
        if stationary_graph:
            link_neighbors = self._get_adjacents_stationary_graph(
                graph=self.graph, node=x, patterns=starts_with,
                max_lag=max_lag, exclude=list(start))
        else:
            link_neighbors = self._find_adj(
                node=x, patterns=starts_with, exclude=list(start),
                return_link=True)

        for link, neighbor in link_neighbors:

            # With hidden variables only hidden nodes or end nodes
            # may be entered
            if (hidden_variables is not None and neighbor not in end
                    and neighbor not in hidden_variables):
                continue

            if path_type == 'non_causal':
                if (neighbor in causal_children
                        and self._match_link('-*>', link)
                        and not self._match_link('+*>', link)):
                    continue
            elif path_type == 'causal':
                if neighbor not in causal_children:
                    continue
            start_from.append((x, link, neighbor))

    visited = set()
    for (varlag_i, link_ik, varlag_k) in start_from:
        visited.add((link_ik, varlag_k))

    # Traversing through motifs i *-* k *-* j
    while start_from:

        varlag_i, link_ik, varlag_k = start_from.popleft()

        # Check if we reached the end
        if varlag_k in end:
            if any(self._match_link(patt, link_ik) for patt in ends_with):
                return True
            else:
                continue

        if stationary_graph:
            link_neighbors = self._get_adjacents_stationary_graph(
                graph=self.graph, node=varlag_k, patterns='***',
                max_lag=max_lag, exclude=list(start))
        else:
            link_neighbors = self._find_adj(
                node=varlag_k, patterns='***', exclude=list(start),
                return_link=True)

        for link_kj, varlag_j in link_neighbors:

            if (link_kj, varlag_j) in visited:
                continue

            if path_type == 'causal':
                if not (self._match_link('-*>', link_kj)
                        or self._match_link('+*>', link_kj)):
                    continue

            # If motif i *-* k *-* j is open,
            # then add link_kj, varlag_j to visited and start_from
            left_mark = link_ik[2]
            right_mark = link_kj[0]

            if self.definite_status:
                # Exclude paths that are not definite_status implying
                # that any of the following motifs occurs:
                # i *-> k o-* j
                if (left_mark == '>' and right_mark == 'o'):
                    continue
                # i *-o k <-* j
                if (left_mark == 'o' and right_mark == '<'):
                    continue
                # i *-o k o-* j and i and j are adjacent
                if (left_mark == 'o' and right_mark == 'o'
                        and self._is_match(varlag_i, varlag_j, "***")):
                    continue

                # If k is in conditions and motif is *-o k o-*, then
                # motif is blocked since i and j are non-adjacent due
                # to the check above
                if varlag_k in conditions and (left_mark == 'o' and right_mark == 'o'):
                    continue

            # If k is in conditions and left or right mark is tail '-',
            # then motif is blocked
            if varlag_k in conditions and (left_mark == '-' or right_mark == '-'):
                continue

            # If k is not in conditions and left and right mark are
            # heads '><', then motif is blocked
            if varlag_k not in conditions and (left_mark == '>' and right_mark == '<'):
                continue

            if (hidden_variables is not None and varlag_j not in end
                    and varlag_j not in hidden_variables):
                continue

            # Motif is open
            visited.add((link_kj, varlag_j))
            start_from.append((varlag_k, link_kj, varlag_j))

    # No connecting path found
    return False
|
|
1249
|
+
|
|
1250
|
+
|
|
1251
|
+
def _get_causal_paths(self, source_nodes, target_nodes,
    mediators=None,
    mediated_through=None,
    proper_paths=True,
    ):
    """Returns causal paths via depth-first search.

    Starting at every source node, children (``self._get_children``)
    are followed as long as they lie in the allowed node set, have a lag
    within [-self.tau_max, 0] and are not yet on the current path.

    Parameters
    ----------
    source_nodes, target_nodes : iterable of nodes
        Start and end nodes of the paths.
    mediators : iterable of nodes or None
        Nodes that may be visited in between; None means none.
    mediated_through : iterable of nodes or None
        If non-empty, only paths whose nodes before the target include
        at least one of these nodes are recorded.
    proper_paths : bool
        If True, source nodes are removed from the allowed node set, so
        that paths do not pass through another source node.

    Returns
    -------
    all_causal_paths : dict
        ``all_causal_paths[w][z]`` is the list of paths (lists of nodes,
        including both endpoints) from source w to target z.
    """
    source_nodes = set(source_nodes)
    target_nodes = set(target_nodes)
    mediators = set() if mediators is None else set(mediators)
    mediated_through = (set() if mediated_through is None
                        else set(mediated_through))

    inside_set = mediators.union(target_nodes)
    if proper_paths:
        inside_set = inside_set - source_nodes
    else:
        inside_set = inside_set.union(source_nodes)

    all_causal_paths = {w: {z: [] for z in target_nodes}
                        for w in source_nodes}

    for w in source_nodes:
        stack = [(w, [])]

        while stack:
            varlag, prefix = stack.pop()
            causal_path = prefix + [varlag]

            suitable_nodes = set(self._get_children(varlag)
                                 ).intersection(inside_set)
            for node in suitable_nodes:
                _, tau = node
                if not (-self.tau_max <= tau <= 0):
                    continue
                if node in causal_path:
                    continue

                stack.append((node, causal_path))

                if node not in target_nodes:
                    continue
                # Skip paths that avoid all required mediators
                if mediated_through and mediated_through.isdisjoint(causal_path):
                    continue
                all_causal_paths[w][node].append(causal_path + [node])

    return all_causal_paths
|
|
1310
|
+
|
|
1311
|
+
|
|
1312
|
+
def _get_adjacency(self, varleg):
|
|
1313
|
+
"""
|
|
1314
|
+
Get all the nodes that are connect to varleg by any sort of edge. No pattern matching needed.
|
|
1315
|
+
"""
|
|
1316
|
+
var, leg = varleg
|
|
1317
|
+
adjacency_matrix = self.graph[var][:, abs(leg)]
|
|
1318
|
+
adjacency = [(i, -j) for i, j in zip(*np.where(adjacency_matrix != ''))]
|
|
1319
|
+
return adjacency
|
|
1320
|
+
|
|
1321
|
+
|
|
1322
|
+
def _edge_is_visible(self, edge):
|
|
1323
|
+
"""
|
|
1324
|
+
This function returns true if the edge is visible. In a DAG or CPDAG all directed edges are visible. In a MAG
|
|
1325
|
+
PAG an edge X-->Y is visible if there exists a node V not adjacent to Y such that there is a collider path from
|
|
1326
|
+
V into X where all (possibly zero) mediators are parents of Y.
|
|
1327
|
+
|
|
1328
|
+
Parameters
|
|
1329
|
+
----------
|
|
1330
|
+
edge : an ordered pair of nodes X and Y, each consisting of their variable index and time lag
|
|
1331
|
+
"""
|
|
1332
|
+
(x, xlag), (y, ylag) = edge
|
|
1333
|
+
# print(x, xlag, y, ylag)
|
|
1334
|
+
# only directed edges can be visible
|
|
1335
|
+
if self.graph[x][y][abs(xlag)][abs(ylag)] != '-->':
|
|
1336
|
+
return False
|
|
1337
|
+
# currently not used condition for cpdags
|
|
1338
|
+
if 'dag' in self.graph_type:
|
|
1339
|
+
return True
|
|
1340
|
+
# get all nodes on collider paths in the parents of Y and going into X
|
|
1341
|
+
y_parents = set(self._get_parents(edge[1]))
|
|
1342
|
+
collider_path_nodes = self._get_collider_path_nodes([edge[0]], y_parents, with_parents=False)
|
|
1343
|
+
if collider_path_nodes == set():
|
|
1344
|
+
#make sure that the lag of y is given as a negative number
|
|
1345
|
+
collider_path_nodes.add((x,- abs(xlag)))
|
|
1346
|
+
# chack if non-adjacent exists with edge into X or the collider path
|
|
1347
|
+
for node in collider_path_nodes:
|
|
1348
|
+
if not (set(self._get_spouses(node)).union(set(self._get_parents(node))) - set(self._get_adjacency(edge[1]))
|
|
1349
|
+
== set()):
|
|
1350
|
+
return True
|
|
1351
|
+
#chack if such a node is not a parent of Y
|
|
1352
|
+
return False
|
|
1353
|
+
|
|
1354
|
+
|
|
1355
|
+
def is_amenable(self, source_nodes, target_nodes) -> bool:
    """
    Return whether the graph is amenable with respect to source_nodes
    and target_nodes.

    The graph is amenable if the first edge of every proper possibly
    directed path from a source node to a target node is visible.
    Graphs of type 'dag', 'tsg_dag' and 'stationary_dag' are always
    amenable.

    Parameters
    ----------
    source_nodes : list of nodes (tuples containing a variable index in
        [0, self.N) and a time lag in [-self.tau_max, 0])
    target_nodes : list of nodes (tuples containing a variable index in
        [0, self.N) and a time lag in [-self.tau_max, 0])

    Returns
    -------
    bool
    """
    if self.graph_type in ['dag', 'tsg_dag', 'stationary_dag']:
        return True

    # Every node of the graph may act as a mediator
    all_nodes = [(i, -tau) for i in range(self.N)
                 for tau in range(self.tau_max + 1)]
    paths = self._get_causal_paths(source_nodes, target_nodes,
                                   mediators=all_nodes)

    # Distinct first edges over all source-target paths
    first_edges = {(path[0], path[1])
                   for source in source_nodes
                   for target in target_nodes
                   for path in paths[source][target]}

    return all(self._edge_is_visible(edge) for edge in first_edges)
|
|
1380
|
+
|
|
1381
|
+
|
|
1382
|
+
@staticmethod
|
|
1383
|
+
def get_dict_from_graph(graph, parents_only=False):
|
|
1384
|
+
"""Helper function to convert graph to dictionary of links.
|
|
1385
|
+
|
|
1386
|
+
Parameters
|
|
1387
|
+
---------
|
|
1388
|
+
graph : array of shape (N, N, tau_max+1)
|
|
1389
|
+
Matrix format of graph in string format.
|
|
1390
|
+
|
|
1391
|
+
parents_only : bool
|
|
1392
|
+
Whether to only return parents ('-->' in graph)
|
|
1393
|
+
|
|
1394
|
+
Returns
|
|
1395
|
+
-------
|
|
1396
|
+
links : dict
|
|
1397
|
+
Dictionary of form {0:{(0, -1): o-o, ...}, 1:{...}, ...}.
|
|
1398
|
+
"""
|
|
1399
|
+
N = graph.shape[0]
|
|
1400
|
+
|
|
1401
|
+
links = dict([(j, {}) for j in range(N)])
|
|
1402
|
+
|
|
1403
|
+
if parents_only:
|
|
1404
|
+
for (i, j, tau) in zip(*np.where(graph=='-->')):
|
|
1405
|
+
links[j][(i, -tau)] = graph[i,j,tau]
|
|
1406
|
+
else:
|
|
1407
|
+
for (i, j, tau) in zip(*np.where(graph!='')):
|
|
1408
|
+
links[j][(i, -tau)] = graph[i,j,tau]
|
|
1409
|
+
|
|
1410
|
+
return links
|
|
1411
|
+
|
|
1412
|
+
@staticmethod
|
|
1413
|
+
def get_graph_from_dict(links, tau_max=None):
|
|
1414
|
+
"""Helper function to convert dictionary of links to graph array format.
|
|
1415
|
+
|
|
1416
|
+
Parameters
|
|
1417
|
+
---------
|
|
1418
|
+
links : dict
|
|
1419
|
+
Dictionary of form {0:[((0, -1), coeff, func), ...], 1:[...], ...}.
|
|
1420
|
+
Also format {0:[(0, -1), ...], 1:[...], ...} is allowed.
|
|
1421
|
+
tau_max : int or None
|
|
1422
|
+
Maximum lag. If None, the maximum lag in links is used.
|
|
1423
|
+
|
|
1424
|
+
Returns
|
|
1425
|
+
-------
|
|
1426
|
+
graph : array of shape (N, N, tau_max+1)
|
|
1427
|
+
Matrix format of graph with 1 for true links and 0 else.
|
|
1428
|
+
"""
|
|
1429
|
+
|
|
1430
|
+
def _get_minmax_lag(links):
|
|
1431
|
+
"""Helper function to retrieve tau_min and tau_max from links.
|
|
1432
|
+
"""
|
|
1433
|
+
|
|
1434
|
+
N = len(links)
|
|
1435
|
+
|
|
1436
|
+
# Get maximum time lag
|
|
1437
|
+
min_lag = np.inf
|
|
1438
|
+
max_lag = 0
|
|
1439
|
+
for j in range(N):
|
|
1440
|
+
for link_props in links[j]:
|
|
1441
|
+
if len(link_props) > 2:
|
|
1442
|
+
var, lag = link_props[0]
|
|
1443
|
+
coeff = link_props[1]
|
|
1444
|
+
# func = link_props[2]
|
|
1445
|
+
if coeff != 0.:
|
|
1446
|
+
min_lag = min(min_lag, abs(lag))
|
|
1447
|
+
max_lag = max(max_lag, abs(lag))
|
|
1448
|
+
else:
|
|
1449
|
+
var, lag = link_props
|
|
1450
|
+
min_lag = min(min_lag, abs(lag))
|
|
1451
|
+
max_lag = max(max_lag, abs(lag))
|
|
1452
|
+
|
|
1453
|
+
return min_lag, max_lag
|
|
1454
|
+
|
|
1455
|
+
N = len(links)
|
|
1456
|
+
|
|
1457
|
+
# Get maximum time lag
|
|
1458
|
+
min_lag, max_lag = _get_minmax_lag(links)
|
|
1459
|
+
|
|
1460
|
+
# Set maximum lag
|
|
1461
|
+
if tau_max is None:
|
|
1462
|
+
tau_max = max_lag
|
|
1463
|
+
else:
|
|
1464
|
+
if max_lag > tau_max:
|
|
1465
|
+
raise ValueError("tau_max is smaller than maximum lag = %d "
|
|
1466
|
+
"found in links, use tau_max=None or larger "
|
|
1467
|
+
"value" % max_lag)
|
|
1468
|
+
|
|
1469
|
+
graph = np.zeros((N, N, tau_max + 1), dtype='<U3')
|
|
1470
|
+
for j in links.keys():
|
|
1471
|
+
for link_props in links[j]:
|
|
1472
|
+
if len(link_props) > 2:
|
|
1473
|
+
var, lag = link_props[0]
|
|
1474
|
+
coeff = link_props[1]
|
|
1475
|
+
if coeff != 0.:
|
|
1476
|
+
graph[var, j, abs(lag)] = "-->"
|
|
1477
|
+
if lag == 0:
|
|
1478
|
+
graph[j, var, 0] = "<--"
|
|
1479
|
+
else:
|
|
1480
|
+
var, lag = link_props
|
|
1481
|
+
graph[var, j, abs(lag)] = "-->"
|
|
1482
|
+
if lag == 0:
|
|
1483
|
+
graph[j, var, 0] = "<--"
|
|
1484
|
+
|
|
1485
|
+
return graph
|
|
1486
|
+
|
|
1487
|
+
|
|
1488
|
+
if __name__ == '__main__':

    # Imports used when running this module as a script
    import tigramite
    import tigramite.toymodels.structural_causal_processes as toys
    import tigramite.data_processing as pp
    import tigramite.plotting as tp
    from matplotlib import pyplot as plt
    import sys

    # Toy graph with two variables and the single link 0 --> 1,
    # stored in both directions
    graph = np.zeros((2, 2), dtype='<U3')
    graph[0, 1] = '-->'
    graph[1, 0] = '<--'

    # Initialize the class with graph_type 'dag'
    causal_effects = Graphs(
        graph,
        graph_type='dag',
        tau_max=0,
        verbosity=1,
    )
|
|
1506
|
+
|
|
1507
|
+
|
|
1508
|
+
|
|
1509
|
+
|