tigramite-fast 5.2.10.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tigramite/__init__.py +0 -0
- tigramite/causal_effects.py +1525 -0
- tigramite/causal_mediation.py +1592 -0
- tigramite/data_processing.py +1574 -0
- tigramite/graphs.py +1509 -0
- tigramite/independence_tests/LBFGS.py +1114 -0
- tigramite/independence_tests/__init__.py +0 -0
- tigramite/independence_tests/cmiknn.py +661 -0
- tigramite/independence_tests/cmiknn_mixed.py +1397 -0
- tigramite/independence_tests/cmisymb.py +286 -0
- tigramite/independence_tests/gpdc.py +664 -0
- tigramite/independence_tests/gpdc_torch.py +820 -0
- tigramite/independence_tests/gsquared.py +190 -0
- tigramite/independence_tests/independence_tests_base.py +1310 -0
- tigramite/independence_tests/oracle_conditional_independence.py +1582 -0
- tigramite/independence_tests/pairwise_CI.py +383 -0
- tigramite/independence_tests/parcorr.py +369 -0
- tigramite/independence_tests/parcorr_mult.py +485 -0
- tigramite/independence_tests/parcorr_wls.py +451 -0
- tigramite/independence_tests/regressionCI.py +403 -0
- tigramite/independence_tests/robust_parcorr.py +403 -0
- tigramite/jpcmciplus.py +966 -0
- tigramite/lpcmci.py +3649 -0
- tigramite/models.py +2257 -0
- tigramite/pcmci.py +3935 -0
- tigramite/pcmci_base.py +1218 -0
- tigramite/plotting.py +4735 -0
- tigramite/rpcmci.py +467 -0
- tigramite/toymodels/__init__.py +0 -0
- tigramite/toymodels/context_model.py +261 -0
- tigramite/toymodels/non_additive.py +1231 -0
- tigramite/toymodels/structural_causal_processes.py +1201 -0
- tigramite/toymodels/surrogate_generator.py +319 -0
- tigramite_fast-5.2.10.1.dist-info/METADATA +182 -0
- tigramite_fast-5.2.10.1.dist-info/RECORD +38 -0
- tigramite_fast-5.2.10.1.dist-info/WHEEL +5 -0
- tigramite_fast-5.2.10.1.dist-info/licenses/license.txt +621 -0
- tigramite_fast-5.2.10.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1525 @@
"""Tigramite causal inference for time series."""

# Author: Jakob Runge <jakob@jakob-runge.com>
#
# License: GNU General Public License v3.0

import numpy as np
import math
import itertools
from copy import deepcopy
from collections import defaultdict
from tigramite.models import Models
from tigramite.graphs import Graphs
import struct

class CausalEffects(Graphs):
    r"""Causal effect estimation.

    Methods for the estimation of linear or non-parametric causal effects
    between (potentially multivariate) X and Y (potentially conditional
    on S) by (generalized) backdoor adjustment. Various graph types are
    supported, also including hidden variables.

    Linear and non-parametric estimators are based on sklearn. For the
    linear case without hidden variables, an efficient estimation based
    on Wright's path coefficients is also available. This estimator
    also allows estimating mediation effects.

    See the corresponding paper [6]_ and the tigramite tutorial for an
    in-depth introduction.

    References
    ----------

    .. [6] J. Runge, Necessary and sufficient graphical conditions for
           optimal adjustment sets in causal graphical models with
           hidden variables, Advances in Neural Information Processing
           Systems, 2021, 34
           https://proceedings.neurips.cc/paper/2021/hash/8485ae387a981d783f8764e508151cd9-Abstract.html


    Parameters
    ----------
    graph : array of either shape [N, N], [N, N, tau_max+1], or [N, N, tau_max+1, tau_max+1]
        Different graph types are supported, see tutorial.
    X : list of tuples
        List of tuples [(i, -tau), ...] containing cause variables.
    Y : list of tuples
        List of tuples [(j, 0), ...] containing effect variables.
    S : list of tuples
        List of tuples [(i, -tau), ...] containing conditioned variables.
    graph_type : str
        Type of graph.
    hidden_variables : list of tuples
        Hidden variables in format [(i, -tau), ...]. The internal graph is
        constructed by a latent projection.
    check_SM_overlap : bool
        Whether to check whether S overlaps with M.
    verbosity : int, optional (default: 0)
        Level of verbosity.
    """

    def __init__(self,
                 graph,
                 graph_type,
                 X,
                 Y,
                 S=None,
                 hidden_variables=None,
                 check_SM_overlap=True,
                 verbosity=0):

        self.verbosity = verbosity
        self.N = graph.shape[0]

        if S is None:
            S = []

        self.listX = list(X)
        self.listY = list(Y)
        self.listS = list(S)

        self.X = set(X)
        self.Y = set(Y)
        self.S = set(S)

        # TODO?: check that masking aligns with hidden samples in variables
        if hidden_variables is None:
            hidden_variables = []

        #
        # Checks regarding graph type
        #
        supported_graphs = ['dag',
                            'admg',
                            'tsg_dag',
                            'tsg_admg',
                            'stationary_dag',
                            'stationary_admg',

                            # 'mag',
                            # 'tsg_mag',
                            # 'stationary_mag',
                            # 'pag',
                            # 'tsg_pag',
                            # 'stationary_pag',
                            ]
        if graph_type not in supported_graphs:
            raise ValueError("Only graph types %s supported!" % supported_graphs)

        # Determine tau_max
        if graph_type in ['dag', 'admg']:
            self.tau_max = 0

        elif graph_type in ['tsg_dag', 'tsg_admg']:
            # tau_max is implicitly derived from
            # the dimensions
            self.tau_max = graph.shape[2] - 1

        elif graph_type in ['stationary_dag', 'stationary_admg']:
            # For a stationary DAG without hidden variables it's sufficient to consider
            # a tau_max that includes the parents of X, Y, M, and S. A conservative
            # estimate thereof is simply the lag-dimension of the stationary DAG plus
            # the maximum lag of XYS.
            statgraph_tau_max = graph.shape[2] - 1
            maxlag_XYS = 0
            for varlag in self.X.union(self.Y).union(self.S):
                maxlag_XYS = max(maxlag_XYS, abs(varlag[1]))
            self.tau_max = maxlag_XYS + statgraph_tau_max
        else:
            raise ValueError("graph_type invalid.")

        self.hidden_variables = set(hidden_variables)
        if len(self.hidden_variables.intersection(self.X.union(self.Y).union(self.S))) > 0:
            raise ValueError("XYS overlaps with hidden_variables!")

        # self.tau_max is needed in the Graphs class
        Graphs.__init__(self,
                        graph=graph,
                        graph_type=graph_type,
                        tau_max=self.tau_max,
                        hidden_variables=self.hidden_variables,
                        verbosity=verbosity)

        self._check_XYS()

        self.ancX = self._get_ancestors(X)
        self.ancY = self._get_ancestors(Y)
        self.ancS = self._get_ancestors(S)

        # If X is not in anc(Y), then no causal link exists
        if self.ancY.intersection(set(X)) == set():
            self.no_causal_path = True
            if self.verbosity > 0:
                print("No causal path from X to Y exists.")
        else:
            self.no_causal_path = False

        # Get mediators
        mediators = self.get_mediators(start=self.X, end=self.Y)

        M = set(mediators)
        self.M = M

        self.listM = list(self.M)

        for varlag in self.X.union(self.Y).union(self.S):
            if abs(varlag[1]) > self.tau_max:
                raise ValueError("X, Y, S must have time lags inside graph.")

        if len(self.X.intersection(self.Y)) > 0:
            raise ValueError("Overlap between X and Y.")

        if len(self.S.intersection(self.Y.union(self.X))) > 0:
            raise ValueError("Conditions S overlap with X or Y.")

        # # TODO: need to prove that this is sufficient for non-identifiability!
        # if len(self.X.intersection(self._get_descendants(self.M))) > 0:
        #     raise ValueError("Not identifiable: Overlap between X and des(M)")

        if check_SM_overlap and len(self.S.intersection(self.M)) > 0:
            raise ValueError("Conditions S overlap with mediators M.")

        self.desX = self._get_descendants(self.X)
        self.desY = self._get_descendants(self.Y)
        self.desM = self._get_descendants(self.M)
        self.descendants = self.desY.union(self.desM)

        # Define forb as X and descendants of YM
        self.forbidden_nodes = self.descendants.union(self.X)  # .union(S)

        # Define valid ancestors
        self.vancs = self.ancX.union(self.ancY).union(self.ancS) - self.forbidden_nodes

        if self.verbosity > 0:
            if len(self.S.intersection(self.desX)) > 0:
                print("Warning: Potentially outside assumptions: Conditions S overlap with des(X)")

        # Here only check whether S overlaps with des(Y); the option that S
        # contains variables in des(M) is left to the user
        if len(self.S.intersection(self.desY)) > 0:
            raise ValueError("Not identifiable: Conditions S overlap with des(Y).")

        if self.verbosity > 0:
            print("\n##\n## Initializing CausalEffects class\n##"
                  "\n\nInput:")
            print("\ngraph_type = %s" % graph_type
                  + "\nX = %s" % self.listX
                  + "\nY = %s" % self.listY
                  + "\nS = %s" % self.listS
                  + "\nM = %s" % self.listM
                  )
            if len(self.hidden_variables) > 0:
                print("\nhidden_variables = %s" % self.hidden_variables
                      )
            print("\n\n")
            if self.no_causal_path:
                print("No causal path from X to Y exists!")


    def _check_XYS(self):
        """Check whether XYS are well-defined.
        """

        XYS = self.X.union(self.Y).union(self.S)
        for xys in XYS:
            var, lag = xys
            if var < 0 or var >= self.N:
                raise ValueError("XYS vars must be in [0...N-1]")
            if lag < -self.tau_max or lag > 0:
                raise ValueError("XYS lags must be in [-taumax...0]")


    def check_XYS_paths(self):
        """Check whether one can remove nodes from X and Y with no proper causal paths.

        Returns
        -------
        X, Y : cleaned lists of X and Y with irrelevant nodes removed.
        """

        # TODO: Also check S...
        oldX = self.X.copy()
        oldY = self.Y.copy()

        # anc_Y = self._get_ancestors(self.Y)
        # anc_S = self._get_ancestors(self.S)

        # First remove from X those nodes with no causal path to Y or S
        X = set([x for x in self.X if x in self.ancY.union(self.ancS)])

        # Remove from Y those nodes with no causal path from X
        # des_X = self._get_descendants(X)

        Y = set([y for y in self.Y if y in self.desX])

        # Also require that all x in X have a proper path to Y or S,
        # that is, the first link goes out of x
        # and into path nodes
        mediators_S = self.get_mediators(start=self.X, end=self.S)
        path_nodes = list(self.M.union(Y).union(mediators_S))
        X = X.intersection(self._get_all_parents(path_nodes))

        if set(oldX) != set(X) and self.verbosity > 0:
            print("Consider pruning X = %s to X = %s " % (oldX, X) +
                  "since only these have causal paths to Y")

        if set(oldY) != set(Y) and self.verbosity > 0:
            print("Consider pruning Y = %s to Y = %s " % (oldY, Y) +
                  "since only these have causal paths from X")

        return (list(X), list(Y))


    def get_optimal_set(self,
                        alternative_conditions=None,
                        minimize=False,
                        return_separate_sets=False,
                        ):
        """Returns the optimal adjustment set.

        See Runge NeurIPS 2021.

        Parameters
        ----------
        alternative_conditions : set of tuples
            Used only internally in the optimality theorem. If None, self.S is used.
        minimize : {False, True, 'colliders_only'}
            Minimize optimal set. If True, minimize such that no subset
            can be removed without making it invalid. If 'colliders_only',
            only colliders are minimized.
        return_separate_sets : bool
            Whether to return tuple of parents, colliders, collider_parents, and S.

        Returns
        -------
        Oset_S : False or list or tuple of lists
            Returns optimal adjustment set if a valid set exists, otherwise False.
        """

        # Needed for optimality theorem where Osets for alternative S are tested
        if alternative_conditions is None:
            S = self.S.copy()
            vancs = self.vancs.copy()
        else:
            S = alternative_conditions
            newancS = self._get_ancestors(S)
            vancs = self.ancX.union(self.ancY).union(newancS) - self.forbidden_nodes

            # vancs = self._get_ancestors(list(self.X.union(self.Y).union(S))) - self.forbidden_nodes

        # descendants = self._get_descendants(self.Y.union(self.M))

        # Sufficient condition for non-identifiability
        if len(self.X.intersection(self.descendants)) > 0:
            return False  # raise ValueError("Not identifiable: Overlap between X and des(M)")

        ##
        ## Construct O-set
        ##

        # Start with parents
        parents = self._get_all_parents(self.Y.union(self.M))  # set()

        # Remove forbidden nodes
        parents = parents - self.forbidden_nodes

        # Construct valid collider path nodes
        colliders = set()
        for w in self.Y.union(self.M):
            j, tau = w
            this_level = [w]
            non_suitable_nodes = []
            while len(this_level) > 0:
                next_level = []
                for varlag in this_level:
                    suitable_spouses = set(self._get_spouses(varlag)) - set(non_suitable_nodes)
                    for spouse in suitable_spouses:
                        i, tau = spouse
                        if spouse in self.X:
                            return False

                        if (# Node not already in set
                            spouse not in colliders  # .union(parents)
                            # not forbidden
                            and spouse not in self.forbidden_nodes
                            # in time bounds
                            and (-self.tau_max <= tau <= 0)  # or self.ignore_time_bounds)
                            and (spouse in vancs
                                 or not self._check_path(  # graph=self.graph,
                                        start=self.X, end=[spouse],
                                        conditions=list(parents.union(vancs)) + list(S),
                                        ))
                            ):
                            colliders = colliders | {spouse}
                            next_level.append(spouse)
                        else:
                            if spouse not in colliders:
                                non_suitable_nodes.append(spouse)

                this_level = set(next_level) - set(non_suitable_nodes)

        # Add parents and raise Error if not identifiable
        collider_parents = self._get_all_parents(colliders)
        if len(self.X.intersection(collider_parents)) > 0:
            return False

        colliders_and_their_parents = colliders.union(collider_parents)

        # Add valid collider path nodes and their parents
        Oset = parents.union(colliders_and_their_parents)


        if minimize:
            removable = []
            # First remove all those that have no path from X
            sorted_Oset = Oset
            if minimize == 'colliders_only':
                sorted_Oset = [node for node in sorted_Oset if node not in parents]

            for node in sorted_Oset:
                if (not self._check_path(  # graph=self.graph,
                        start=self.X, end=[node],
                        conditions=list(Oset - {node}) + list(S))):
                    removable.append(node)

            Oset = Oset - set(removable)
            if minimize == 'colliders_only':
                sorted_Oset = [node for node in Oset if node not in parents]

            removable = []
            # Next remove all those with no direct connection to Y
            for node in sorted_Oset:
                if (not self._check_path(  # graph=self.graph,
                        start=[node], end=self.Y,
                        conditions=list(Oset - {node}) + list(S) + list(self.X),
                        ends_with=['**>', '**+'])):
                    removable.append(node)

            Oset = Oset - set(removable)

        Oset_S = Oset.union(S)

        if return_separate_sets:
            return parents, colliders, collider_parents, S
        else:
            return list(Oset_S)
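
    # A minimal sketch of querying the O-set (hypothetical names, commented
    # out): get_optimal_set() returns False when no valid adjustment set
    # exists, so the return value should be checked before use.
    #
    #   opt_set = causal_effects.get_optimal_set()
    #   if opt_set is False:
    #       print("Effect not identifiable via adjustment.")
    #   else:
    #       print("Optimal adjustment set:", opt_set)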


    def _get_collider_paths_optimality(self, source_nodes, target_nodes,
                                       condition,
                                       inside_set=None,
                                       start_with_tail_or_head=False,
                                       ):
        """Returns relevant collider paths to check optimality.

        Iterates over collider paths within O-set via depth-first search.

        """

        for w in source_nodes:
            # Only used to return *all* collider paths
            # (needed in optimality theorem)

            coll_path = []

            queue = [(w, coll_path)]

            non_valid_subsets = []

            while queue:

                varlag, coll_path = queue.pop()

                coll_path = coll_path + [varlag]
                coll_path_set = set(coll_path)

                suitable_nodes = set(self._get_spouses(varlag))

                if start_with_tail_or_head and coll_path == [w]:
                    children = set(self._get_children(varlag))
                    suitable_nodes = suitable_nodes.union(children)

                for node in suitable_nodes:
                    i, tau = node
                    if ((-self.tau_max <= tau <= 0)  # or self.ignore_time_bounds)
                            and node not in coll_path_set):

                        if condition == 'II' and node not in target_nodes and node not in self.vancs:
                            continue

                        if node in inside_set:
                            if condition == 'I':
                                non_valid = False
                                extended_set = coll_path_set | {node}
                                for pathset in non_valid_subsets[::-1]:
                                    if pathset.issubset(extended_set):
                                        non_valid = True
                                        break
                                if non_valid is False:
                                    queue.append((node, coll_path))
                                else:
                                    continue
                            elif condition == 'II':
                                queue.append((node, coll_path))

                        if node in target_nodes:
                            # yield coll_path
                            # collider_paths[node].append(coll_path)
                            if condition == 'I':
                                # Construct OπiN
                                Sprime = self.S.union(coll_path_set)
                                OpiN = self.get_optimal_set(alternative_conditions=Sprime)
                                if OpiN is False:
                                    queue = [(q_node, q_path) for (q_node, q_path) in queue
                                             if not coll_path_set.issubset(set(q_path + [q_node]))]
                                    non_valid_subsets.append(coll_path_set)
                                else:
                                    return False

                            elif condition == 'II':
                                return True
                                # yield coll_path

        if condition == 'I':
            return True
        elif condition == 'II':
            return False
        # return collider_paths


    def check_optimality(self):
        """Check whether an optimal adjustment set exists according to Thm. 3 in Runge NeurIPS 2021.

        Returns
        -------
        optimality : bool
            Returns True if an optimal adjustment set exists, otherwise False.
        """

        # Cond. 0: Exactly one valid adjustment set exists
        cond_0 = (self._get_all_valid_adjustment_sets(check_one_set_exists=True))

        #
        # Cond. I
        #
        parents, colliders, collider_parents, _ = self.get_optimal_set(return_separate_sets=True)
        Oset = parents.union(colliders).union(collider_parents)
        n_nodes = self._get_all_spouses(self.Y.union(self.M).union(colliders)) \
            - self.forbidden_nodes - Oset - self.S - self.Y - self.M - colliders

        if (len(n_nodes) == 0):
            # (1) There are no spouses N ∈ sp(YMC) \ (forbOS)
            cond_I = True
        else:

            # (2) For all N ∈ N and all its collider paths i it holds that
            # OπiN does not block all non-causal paths from X to Y
            # cond_I = True
            cond_I = self._get_collider_paths_optimality(
                source_nodes=list(n_nodes), target_nodes=list(self.Y.union(self.M)),
                condition='I',
                inside_set=Oset.union(self.S), start_with_tail_or_head=False,
                )

        #
        # Cond. II
        #
        e_nodes = Oset.difference(parents)
        cond_II = True
        for E in e_nodes:
            Oset_minusE = Oset - {E}
            if self._check_path(  # graph=self.graph,
                    start=list(self.X), end=[E],
                    conditions=list(self.S) + list(Oset_minusE)):

                cond_II = self._get_collider_paths_optimality(
                    target_nodes=self.Y.union(self.M),
                    source_nodes=[E],
                    condition='II',
                    inside_set=list(Oset.union(self.S)),
                    start_with_tail_or_head=True)

                if cond_II is False:
                    if self.verbosity > 1:
                        print("Non-optimal due to E = ", E)
                    break

        optimality = (cond_0 or (cond_I and cond_II))
        if self.verbosity > 0:
            print("Optimality = %s with cond_0 = %s, cond_I = %s, cond_II = %s"
                  % (optimality, cond_0, cond_I, cond_II))
        return optimality
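
    # A minimal sketch (hypothetical, commented out): check_optimality()
    # complements get_optimal_set() by testing whether the O-set is not only
    # valid but also optimal in the sense of Thm. 3.
    #
    #   if causal_effects.check_optimality():
    #       print("O-set is optimal among valid adjustment sets.")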

    def _check_validity(self, Z):
        """Checks whether Z is a valid adjustment set."""

        # causal_children = list(self.M.union(self.Y))
        backdoor_path = self._check_path(  # graph=self.graph,
            start=list(self.X), end=list(self.Y),
            conditions=list(Z),
            # causal_children=causal_children,
            path_type='non_causal')

        if backdoor_path:
            return False
        else:
            return True

    def _get_adjust_set(self,
                        minimize=False,
                        ):
        """Returns the Adjust-set.

        See van der Zander, B.; Liśkiewicz, M. & Textor, J.
        Separators and adjustment sets in causal graphs: Complete
        criteria and an algorithmic framework
        Artificial Intelligence, Elsevier, 2019, 270, 1-40

        """

        vancs = self.vancs.copy()

        if minimize:
            # Get removable nodes by computing minimal valid set from Z
            if minimize == 'keep_parentsYM':
                minimize_nodes = vancs - self._get_all_parents(list(self.Y.union(self.M)))

            else:
                minimize_nodes = vancs

            # Zprime2 = Zprime
            # First remove all nodes that have no unique path to X given Oset
            for node in minimize_nodes:
                # path = self.oracle.check_shortest_path(X=X, Y=[node],
                #     Z=list(vancs - {node}),
                #     max_lag=None,
                #     starts_with=None,  # 'arrowhead',
                #     forbidden_nodes=None,  # list(Zprime - {node}),
                #     return_path=False)
                path = self._check_path(  # graph=self.graph,
                    start=self.X, end=[node],
                    conditions=list(vancs - {node}),
                    )

                if path is False:
                    vancs = vancs - {node}

            if minimize == 'keep_parentsYM':
                minimize_nodes = vancs - self._get_all_parents(list(self.Y.union(self.M)))
            else:
                minimize_nodes = vancs

            # print(Zprime2)
            # Next remove all nodes that have no unique path to Y given Oset_min
            # Z = Zprime2
            for node in minimize_nodes:

                path = self._check_path(  # graph=self.graph,
                    start=[node], end=self.Y,
                    conditions=list(vancs - {node}) + list(self.X),
                    )

                if path is False:
                    vancs = vancs - {node}

        if self._check_validity(list(vancs)) is False:
            return False
        else:
            return list(vancs)


    def _get_all_valid_adjustment_sets(self,
                                       check_one_set_exists=False, yield_index=None):
        """Constructs all valid adjustment sets or just checks whether one exists.

        See van der Zander, B.; Liśkiewicz, M. & Textor, J.
        Separators and adjustment sets in causal graphs: Complete
        criteria and an algorithmic framework
        Artificial Intelligence, Elsevier, 2019, 270, 1-40

        """

        cond_set = set(self.S)
        all_vars = [(i, -tau) for i in range(self.N)
                    for tau in range(0, self.tau_max + 1)]

        all_vars_set = set(all_vars) - self.forbidden_nodes


        def find_sep(I, R):
            Rprime = R - self.X - self.Y
            # TODO: anteriors and NOT ancestors where
            # anteriors include --- links in causal paths
            # print(I)
            XYI = list(self.X.union(self.Y).union(I))
            # print(XYI)
            ancs = self._get_ancestors(list(XYI))
            Z = ancs.intersection(Rprime)
            if self._check_validity(Z) is False:
                return False
            else:
                return Z


        def list_sep(I, R):
            # print(find_sep(X, Y, I, R))
            if find_sep(I, R) is not False:
                # print(I, R)
                if I == R:
                    # print('--->', I)
                    yield I
                else:
                    # Pick an arbitrary node from R - I
                    RminusI = list(R - I)
                    # print(R, I, RminusI)
                    v = RminusI[0]
                    # print("here ", X, Y, I.union(set([v])), R)
                    yield from list_sep(I | {v}, R)
                    yield from list_sep(I, R - {v})

        # print("all ", X, Y, cond_set, all_vars_set)
        all_sets = []
        I = cond_set
        R = all_vars_set
        for index, valid_set in enumerate(list_sep(I, R)):
            # print(valid_set)
            all_sets.append(list(valid_set))
            if check_one_set_exists and index > 0:
                break

            if yield_index is not None and index == yield_index:
                return valid_set

        if yield_index is not None:
            return None

        if check_one_set_exists:
            if len(all_sets) == 1:
                return True
            else:
                return False

        return all_sets


    def fit_total_effect(self,
                         dataframe,
                         estimator,
                         adjustment_set='optimal',
                         conditional_estimator=None,
                         data_transform=None,
                         mask_type=None,
                         ignore_identifiability=False,
                         ):
        """Returns a fitted model for the total causal effect of X on Y
        conditional on S.

        Parameters
        ----------
        dataframe : data object
            Tigramite dataframe object. It must have the attributes dataframe.values
            yielding a numpy array of shape (observations T, variables N) and
            optionally a mask of the same shape and a missing values flag.
        estimator : sklearn model object
            For example, sklearn.linear_model.LinearRegression() for a linear
            regression model.
        adjustment_set : str or list of tuples
            If 'optimal' the Oset is used, if 'minimized_optimal' the minimized Oset,
            and if 'colliders_minimized_optimal', the colliders-minimized Oset.
            If a list of tuples is passed, this set is used.
        conditional_estimator : sklearn model object, optional (default: None)
            Used to fit conditional causal effects in nested regression.
            If None, the same model as for estimator is used.
        data_transform : sklearn preprocessing object, optional (default: None)
            Used to transform data prior to fitting. For example,
            sklearn.preprocessing.StandardScaler for simple standardization. The
            fitted parameters are stored.
        mask_type : {None, 'y','x','z','xy','xz','yz','xyz'}
            Masking mode: Indicators for which variables in the dependence
            measure I(X; Y | Z) the samples should be masked. If None, the mask
            is not used. Explained in the tutorial on masking and missing values.
        ignore_identifiability : bool
            Only applies to adjustment sets supplied by the user. Ignores whether
            that set leads to a non-identifiable effect.
        """

        if self.no_causal_path:
            if self.verbosity > 0:
                print("No causal path from X to Y exists.")
            return self

        self.dataframe = dataframe
        self.conditional_estimator = conditional_estimator

        # if self.dataframe.has_vector_data:
        #     raise ValueError("vector_vars in DataFrame cannot be used together with CausalEffects!"
        #                      " You can estimate vector-valued effects by using multivariate X, Y, S."
        #                      " Note, however, that this requires assuming a graph at the level "
        #                      "of the components of X, Y, S, ...")

        if self.N != self.dataframe.N:
            raise ValueError("Dataset dimensions inconsistent with number of variables in graph.")

        if adjustment_set == 'optimal':
            # Check optimality and use either optimal or colliders_only set
            adjustment_set = self.get_optimal_set()
        elif adjustment_set == 'colliders_minimized_optimal':
            adjustment_set = self.get_optimal_set(minimize='colliders_only')
        elif adjustment_set == 'minimized_optimal':
            adjustment_set = self.get_optimal_set(minimize=True)
        else:
            if ignore_identifiability is False and self._check_validity(adjustment_set) is False:
                raise ValueError("Chosen adjustment_set is not valid.")

        if adjustment_set is False:
            raise ValueError("Causal effect not identifiable via adjustment.")

        self.adjustment_set = adjustment_set

        # Fit model of Y on X and Z (and conditions)
        # Build the model
        self.model = Models(
            dataframe=dataframe,
            model=estimator,
            conditional_model=conditional_estimator,
            data_transform=data_transform,
            mask_type=mask_type,
            verbosity=self.verbosity)

        self.model.get_general_fitted_model(
            Y=self.listY, X=self.listX, Z=list(self.adjustment_set),
            conditions=self.listS,
            tau_max=self.tau_max,
            cut_off='tau_max',
            return_data=False)

        return self
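
    # A minimal end-to-end sketch (hypothetical data and names, commented
    # out): fit with the default optimal adjustment set and a linear model.
    #
    #   from sklearn.linear_model import LinearRegression
    #   causal_effects.fit_total_effect(dataframe=dataframe,
    #                                   estimator=LinearRegression())
    #   print(causal_effects.adjustment_set)  # the set actually used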

    def predict_total_effect(self,
                             intervention_data,
                             conditions_data=None,
                             pred_params=None,
                             return_further_pred_results=False,
                             aggregation_func=np.mean,
                             transform_interventions_and_prediction=False,
                             intervention_type='hard',
                             ):
        """Predict effect of intervention with fitted model.

        Uses the model.predict() function of the sklearn model.

        Parameters
        ----------
        intervention_data : numpy array
            Numpy array of shape (n_interventions, len(X)) that contains the do(X) values.
        conditions_data : data object, optional
            Numpy array of shape (n_interventions, len(S)) that contains the S=s values.
        pred_params : dict, optional
            Optional parameters passed on to sklearn prediction function.
        return_further_pred_results : bool, optional (default: False)
            In case the predictor class returns more than just the expected value,
            the entire results can be returned.
        aggregation_func : callable
            Callable applied to output of 'predict'. Default is 'np.mean'.
        transform_interventions_and_prediction : bool (default: False)
            Whether to perform the inverse data_transform on prediction results.
        intervention_type : {'hard', 'soft'}
            Specify whether intervention is 'hard' (set value) or 'soft'
            (add value to observed data).

        Returns
        -------
        Results from prediction: an array of shape (n_interventions, len(Y)).
        If estimate_confidence = True, then a tuple is returned.
        """

        def get_vectorized_length(W):
            return sum([len(self.dataframe.vector_vars[w[0]]) for w in W])

        # lenX = len(self.listX)
        # lenS = len(self.listS)

        lenX = get_vectorized_length(self.listX)
        lenS = get_vectorized_length(self.listS)

        if intervention_data.shape[1] != lenX:
            raise ValueError("intervention_data.shape[1] must be len(X).")

        if intervention_type not in {'hard', 'soft'}:
            raise ValueError("intervention_type must be 'hard' or 'soft'.")

        if conditions_data is not None and lenS > 0:
            if conditions_data.shape[1] != lenS:
                raise ValueError("conditions_data.shape[1] must be len(S).")
            if conditions_data.shape[0] != intervention_data.shape[0]:
                raise ValueError("conditions_data.shape[0] must match intervention_data.shape[0].")
        elif conditions_data is not None and lenS == 0:
            raise ValueError("conditions_data specified, but S=None or empty.")
        elif conditions_data is None and lenS > 0:
            raise ValueError("S specified, but conditions_data is None.")


        if self.no_causal_path:
            if self.verbosity > 0:
                print("No causal path from X to Y exists.")
            return np.zeros((len(intervention_data), len(self.listY)))

        effect = self.model.get_general_prediction(
            intervention_data=intervention_data,
            conditions_data=conditions_data,
            pred_params=pred_params,
            return_further_pred_results=return_further_pred_results,
            transform_interventions_and_prediction=transform_interventions_and_prediction,
            aggregation_func=aggregation_func,
            intervention_type=intervention_type,)

        return effect
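
    # A minimal sketch (hypothetical, commented out): the total effect of
    # do(X=1) versus do(X=0) follows from two calls for a univariate X, Y.
    #
    #   import numpy as np
    #   y1 = causal_effects.predict_total_effect(np.array([[1.]]))
    #   y0 = causal_effects.predict_total_effect(np.array([[0.]]))
    #   print("Estimated total effect:", y1 - y0)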

    def fit_wright_effect(self,
                          dataframe,
                          mediation=None,
                          method='parents',
                          links_coeffs=None,
                          data_transform=None,
                          mask_type=None,
                          ):
        """Returns a fitted model for the total or mediated causal effect of X on Y
        potentially through mediator variables.

        Parameters
        ----------
        dataframe : data object
            Tigramite dataframe object. It must have the attributes dataframe.values
            yielding a numpy array of shape (observations T, variables N) and
            optionally a mask of the same shape and a missing values flag.
        mediation : None, 'direct', or list of tuples
            If None, the total effect is estimated; if 'direct', only the direct
            effect is estimated; else only those causal paths are considered that
            pass through at least one of these mediator nodes.
        method : {'parents', 'links_coeffs', 'optimal'}
            Method to use for estimating Wright's path coefficients. If 'optimal',
            the Oset is used, if 'links_coeffs', the coefficients in links_coeffs are used,
            if 'parents', the parents are used (only valid for DAGs).
        links_coeffs : dict
            Only used if method = 'links_coeffs'.
            Dictionary of format: {0:[((i, -tau), coeff),...], 1:[...],
            ...} for all variables where i must be in [0..N-1] and tau >= 0 with
            number of variables N. coeff must be a float.
        data_transform : None
            Not implemented for the Wright estimator. Complicated for missing samples.
        mask_type : {None, 'y','x','z','xy','xz','yz','xyz'}
            Masking mode: Indicators for which variables in the dependence
            measure I(X; Y | Z) the samples should be masked. If None, the mask
            is not used. Explained in the tutorial on masking and missing values.
        """

        if self.no_causal_path:
            if self.verbosity > 0:
                print("No causal path from X to Y exists.")
            return self

        if data_transform is not None:
            raise ValueError("data_transform not implemented for Wright estimator."
                             " You can preprocess data yourself beforehand.")

        import sklearn.linear_model

        self.dataframe = dataframe
        if self.dataframe.has_vector_data:
            raise ValueError("vector_vars in DataFrame cannot be used together with the Wright method!"
                             " You can either 1) estimate vector-valued effects by using multivariate (X, Y, S)"
                             " together with assuming a graph at the level of the components of (X, Y, S),"
                             " or 2) use vector_vars together with fit_total_effect and an estimator"
                             " that supports multiple outputs.")

        estimator = sklearn.linear_model.LinearRegression()

        # Fit model of Y on X and Z (and conditions)
        # Build the model
        self.model = Models(
            dataframe=dataframe,
            model=estimator,
            data_transform=None,  # data_transform,
            mask_type=mask_type,
            verbosity=self.verbosity)

        mediators = self.M  # self.get_mediators(start=self.X, end=self.Y)

        if mediation == 'direct':
            causal_paths = {}
            for w in self.X:
                causal_paths[w] = {}
                for z in self.Y:
                    if w in self._get_parents(z):
                        causal_paths[w][z] = [[w, z]]
                    else:
                        causal_paths[w][z] = []
        else:
            causal_paths = self._get_causal_paths(source_nodes=self.X,
                                                  target_nodes=self.Y, mediators=mediators,
                                                  mediated_through=mediation, proper_paths=True)

        if method == 'links_coeffs':
            coeffs = {}
            max_lag = 0
            for medy in [med for med in mediators] + [y for y in self.listY]:
                coeffs[medy] = {}
                j, tauj = medy
                for ipar, par_coeff in enumerate(links_coeffs[medy[0]]):
                    par, coeff, _ = par_coeff
                    i, taui = par
                    taui_shifted = taui + tauj
                    max_lag = max(abs(par[1]), max_lag)
                    coeffs[medy][(i, taui_shifted)] = coeff  # self.fit_results[j][(j, 0)]['model'].coef_[ipar]

            self.model.tau_max = max_lag
            # print(coeffs)

        elif method == 'optimal':
            # all_parents = {}
            coeffs = {}
            for medy in [med for med in mediators] + [y for y in self.listY]:
                coeffs[medy] = {}
                mediator_parents = self._get_all_parents([medy]).intersection(
                    mediators.union(self.X).union(self.Y)) - {medy}
                all_parents = self._get_all_parents([medy]) - {medy}
                for par in mediator_parents:
                    Sprime = set(all_parents) - {par, medy}
                    causal_effects = CausalEffects(graph=self.graph,
                                                   X=[par], Y=[medy], S=Sprime,
                                                   graph_type=self.graph_type,
                                                   check_SM_overlap=False,
                                                   )
                    oset = causal_effects.get_optimal_set()
                    # print(medy, par, list(set(all_parents)), oset)
                    if oset is False:
                        raise ValueError("Not identifiable via Wright's method.")
                    fit_res = self.model.get_general_fitted_model(
                        Y=[medy], X=[par], Z=oset,
                        tau_max=self.tau_max,
                        cut_off='tau_max',
                        return_data=False)
                    coeffs[medy][par] = fit_res['model'].coef_[0]

        elif method == 'parents':
            coeffs = {}
            for medy in [med for med in mediators] + [y for y in self.listY]:
                coeffs[medy] = {}
                # mediator_parents = self._get_all_parents([medy]).intersection(mediators.union(self.X)) - {medy}
                all_parents = self._get_all_parents([medy]) - {medy}
                if 'dag' not in self.graph_type:
                    spouses = self._get_all_spouses([medy]) - {medy}
                    if len(spouses) != 0:
                        raise ValueError("method == 'parents' only possible for "
                                         "causal paths without adjacent bi-directed links!")

                # print(j, all_parents[j])
                # if len(all_parents[j]) > 0:
                # print(medy, list(all_parents))
                fit_res = self.model.get_general_fitted_model(
                    Y=[medy], X=list(all_parents), Z=[],
                    conditions=None,
                    tau_max=self.tau_max,
                    cut_off='tau_max',
                    return_data=False)

                for ipar, par in enumerate(list(all_parents)):
                    # print(par, fit_res['model'].coef_)
                    coeffs[medy][par] = fit_res['model'].coef_[0][ipar]

        else:
            raise ValueError("method must be 'optimal', 'links_coeffs', or 'parents'.")

        # Effect is the sum over products of all path coefficients
        # from x in X to y in Y
        effect = {}
        for (x, y) in itertools.product(self.listX, self.listY):
            effect[(x, y)] = 0.
            for causal_path in causal_paths[x][y]:
                effect_here = 1.
                # print(x, y, causal_path)
                for index, node in enumerate(causal_path[:-1]):
                    i, taui = node
                    j, tauj = causal_path[index + 1]
                    # tau_ij = abs(tauj - taui)
                    # print((j, tauj), (i, taui))
                    effect_here *= coeffs[(j, tauj)][(i, taui)]

                effect[(x, y)] += effect_here

        # Make fitted coefficients available as attribute
        self.coeffs = coeffs

        # Modify and overwrite variables in self.model
        self.model.Y = self.listY
        self.model.X = self.listX
        self.model.Z = []
        self.model.conditions = []
        self.model.cut_off = 'tau_max'  # 'max_lag_or_tau_max'

        class dummy_fit_class():
            def __init__(self, y_here, listX_here, effect_here):
                dim = len(listX_here)
                self.coeff_array = np.array([effect_here[(x, y_here)]
                                             for x in listX_here]).reshape(dim, 1)
            def predict(self, X):
                return np.dot(X, self.coeff_array).squeeze()

        fit_results = {}
        for y in self.listY:
            fit_results[y] = {}
            fit_results[y]['model'] = dummy_fit_class(y, self.listX, effect)
            fit_results[y]['data_transform'] = deepcopy(data_transform)

        # self.effect = effect
        self.model.fit_results = fit_results
        return self
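
    # A minimal sketch (hypothetical, commented out): for linear DAGs the
    # Wright estimator needs no estimator argument since path coefficients
    # are fitted by linear regression on parents by default.
    #
    #   causal_effects.fit_wright_effect(dataframe=dataframe)
    #   # or restrict to paths through a given mediator node:
    #   # causal_effects.fit_wright_effect(dataframe=dataframe,
    #   #                                  mediation=[(2, -1)])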

    def predict_wright_effect(self,
                              intervention_data,
                              pred_params=None,
                              ):
        """Predict linear effect of intervention with fitted Wright model.

        Parameters
        ----------
        intervention_data : numpy array
            Numpy array of shape (n_interventions, len(X)) that contains the do(X) values.
        pred_params : dict, optional
            Optional parameters passed on to sklearn prediction function.

        Returns
        -------
        Results from prediction: an array of shape (n_interventions, len(Y)).
        """

        lenX = len(self.listX)
        lenY = len(self.listY)

        if intervention_data.shape[1] != lenX:
            raise ValueError("intervention_data.shape[1] must be len(X).")

        if self.no_causal_path:
            if self.verbosity > 0:
                print("No causal path from X to Y exists.")
            return np.zeros((len(intervention_data), len(self.Y)))

        n_interventions, _ = intervention_data.shape


        predicted_array = np.zeros((n_interventions, lenY))
        pred_dict = {}
        for iy, y in enumerate(self.listY):
            # Print message
            if self.verbosity > 1:
                print("\n## Predicting target %s" % str(y))
                if pred_params is not None:
                    for key in list(pred_params):
                        print("%s = %s" % (key, pred_params[key]))
            # Default value for pred_params
            if pred_params is None:
                pred_params = {}
            # Check this is a valid target
            if y not in self.model.fit_results:
                raise ValueError("y = %s not yet fitted" % str(y))

            # data_transform is too complicated for the Wright estimator
            # Transform the data if needed
            # fitted_data_transform = self.model.fit_results[y]['fitted_data_transform']
            # if fitted_data_transform is not None:
            #     intervention_data = fitted_data_transform['X'].transform(X=intervention_data)

            # Now iterate through interventions (and potentially S)
            for index, dox_vals in enumerate(intervention_data):
                # Construct XZS-array
                intervention_array = dox_vals.reshape(1, lenX)
                predictor_array = intervention_array

                predicted_vals = self.model.fit_results[y]['model'].predict(
                    X=predictor_array, **pred_params)
                predicted_array[index, iy] = predicted_vals.mean()

                # data_transform is too complicated for the Wright estimator
                # if fitted_data_transform is not None:
                #     rescaled = fitted_data_transform['Y'].inverse_transform(X=predicted_array[index, iy].reshape(-1, 1))
                #     predicted_array[index, iy] = rescaled.squeeze()

        return predicted_array
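
    # A minimal sketch (hypothetical, commented out): since the Wright model
    # is linear, the unit effect equals the prediction difference between
    # do(X=1) and do(X=0) for a univariate X.
    #
    #   import numpy as np
    #   beta = (causal_effects.predict_wright_effect(np.ones((1, 1)))
    #           - causal_effects.predict_wright_effect(np.zeros((1, 1))))
    #   print("Path-coefficient effect: %.5f" % beta)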


    def fit_bootstrap_of(self, method, method_args,
                         boot_samples=100,
                         boot_blocklength=1,
                         seed=None):
        """Runs chosen method on bootstrap samples drawn from DataFrame.

        Bootstraps for tau=0 are drawn from [max_lag, ..., T] and all lagged
        variables constructed in DataFrame.construct_array are consistently
        shifted with respect to this bootstrap sample to ensure that lagged
        relations in the bootstrap sample are preserved.

        This function fits the models; predict_bootstrap_of can then be used
        to get confidence intervals for the effect of interventions.

        Parameters
        ----------
        method : str
            Chosen method among valid functions in this class.
        method_args : dict
            Arguments passed to method.
        boot_samples : int
            Number of bootstrap samples to draw.
        boot_blocklength : int, optional (default: 1)
            Block length for block-bootstrap.
        seed : int, optional (default: None)
            Seed for RandomState (default_rng)
        """

        # if dataframe.analysis_mode != 'single':
        #     raise ValueError("CausalEffects class currently only supports single "
        #                      "datasets.")

        valid_methods = ['fit_total_effect',
                         'fit_wright_effect',
                         ]

        if method not in valid_methods:
            raise ValueError("method must be one of %s" % str(valid_methods))

        # First call the method on the original dataframe
        # to make the adjustment set etc. available
        getattr(self, method)(**method_args)

        self.original_model = deepcopy(self.model)

        if self.verbosity > 0:
            print("\n##\n## Running Bootstrap of %s " % method +
                  "\n##\n" +
                  "\nboot_samples = %s \n" % boot_samples +
                  "\nboot_blocklength = %s \n" % boot_blocklength
                  )

        method_args_bootstrap = deepcopy(method_args)
        self.bootstrap_results = {}

        for b in range(boot_samples):
            # # Replace dataframe in method args by bootstrapped dataframe
            # method_args_bootstrap['dataframe'].bootstrap = boot_draw
            if seed is None:
                random_state = np.random.default_rng(None)
            else:
                random_state = np.random.default_rng(seed*boot_samples + b)

            method_args_bootstrap['dataframe'].bootstrap = {'boot_blocklength': boot_blocklength,
                                                            'random_state': random_state}

            # Call method and save fitted model
            getattr(self, method)(**method_args_bootstrap)
            self.bootstrap_results[b] = deepcopy(self.model)

        # Reset model
        self.model = self.original_model

        return self


    def predict_bootstrap_of(self, method, method_args,
                             conf_lev=0.9,
                             return_individual_bootstrap_results=False):
        """Predicts with fitted bootstraps.

        To be used after fitting with fit_bootstrap_of. Only uses the
        expected values of the predict function, not potential other output.

        Parameters
        ----------
        method : str
            Chosen method among valid functions in this class.
        method_args : dict
            Arguments passed to method.
        conf_lev : float, optional (default: 0.9)
            Two-sided confidence interval.
        return_individual_bootstrap_results : bool
            Returns the individual bootstrap predictions.

        Returns
        -------
        confidence_intervals : numpy array
        """

        valid_methods = ['predict_total_effect',
                         'predict_wright_effect',
                         ]

        if method not in valid_methods:
            raise ValueError("method must be one of %s" % str(valid_methods))

        # def get_vectorized_length(W):
        #     return sum([len(self.dataframe.vector_vars[w[0]]) for w in W])

        lenX = len(self.listX)
        lenS = len(self.listS)
        lenY = len(self.listY)

        n_interventions, _ = method_args['intervention_data'].shape

        boot_samples = len(self.bootstrap_results)
        # bootstrap_predicted_array = np.zeros((boot_samples, n_interventions, lenY))

        for b in range(boot_samples):  # self.bootstrap_results.keys():
            self.model = self.bootstrap_results[b]
            boot_effect = getattr(self, method)(**method_args)

            if isinstance(boot_effect, tuple):
                boot_effect = boot_effect[0]

            if b == 0:
                bootstrap_predicted_array = np.zeros((boot_samples, ) + boot_effect.shape,
                                                     dtype=boot_effect.dtype)
            bootstrap_predicted_array[b] = boot_effect

        # Reset model
        self.model = self.original_model

        # Confidence intervals for val_matrix; interval is two-sided
        c_int = (1. - (1. - conf_lev)/2.)
        confidence_interval = np.percentile(
            bootstrap_predicted_array, axis=0,
            q=[100*(1. - c_int), 100*c_int])  # [:,:,0]

        if return_individual_bootstrap_results:
            return bootstrap_predicted_array, confidence_interval

        return confidence_interval
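
    # A minimal sketch (hypothetical, commented out): bootstrap confidence
    # intervals pair fit_bootstrap_of with predict_bootstrap_of.
    #
    #   causal_effects.fit_bootstrap_of(
    #       method='fit_total_effect',
    #       method_args={'dataframe': dataframe,
    #                    'estimator': LinearRegression()},
    #       boot_samples=100, boot_blocklength=5, seed=42)
    #   conf = causal_effects.predict_bootstrap_of(
    #       method='predict_total_effect',
    #       method_args={'intervention_data': intervention_data},
    #       conf_lev=0.9)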


if __name__ == '__main__':

    # Consider some toy data
    import tigramite
    import tigramite.toymodels.structural_causal_processes as toys
    import tigramite.data_processing as pp
    import tigramite.plotting as tp
    from matplotlib import pyplot as plt
    import sys

    import sklearn
    from sklearn.linear_model import LinearRegression, LogisticRegression
    from sklearn.preprocessing import StandardScaler
    from sklearn.neural_network import MLPRegressor


    # def lin_f(x): return x
    # coeff = .5

    # links_coeffs = {0: [((0, -1), 0.5, lin_f)],
    #                 1: [((1, -1), 0.5, lin_f), ((0, -1), 0.5, lin_f)],
    #                 2: [((2, -1), 0.5, lin_f), ((1, 0), 0.5, lin_f)]
    #                 }
    # T = 1000
    # data, nonstat = toys.structural_causal_process(
    #     links_coeffs, T=T, noises=None, seed=7)
    # dataframe = pp.DataFrame(data)

    # graph = CausalEffects.get_graph_from_dict(links_coeffs)

    # original_graph = np.array([[['', ''],
    #                             ['-->', ''],
    #                             ['-->', ''],
    #                             ['', '']],

    #                            [['<--', ''],
    #                             ['', '-->'],
    #                             ['-->', ''],
    #                             ['-->', '']],

    #                            [['<--', ''],
    #                             ['<--', ''],
    #                             ['', '-->'],
    #                             ['-->', '']],

    #                            [['', ''],
    #                             ['<--', ''],
    #                             ['<--', ''],
    #                             ['', '-->']]], dtype='<U3')
    # graph = np.copy(original_graph)

    # # Add T <-> Reco and T
    # graph[2,3,0] = '+->' ; graph[3,2,0] = '<-+'
    # graph[1,3,1] = '<->'  # ; graph[2,1,0] = '<--'

    # added = np.zeros((4, 4, 1), dtype='<U3')
    # added[:] = ""
    # graph = np.append(graph, added, axis=2)


    # X = [(1, 0)]
    # Y = [(3, 0)]

    # # Initialize class as `stationary_admg`
    # causal_effects = CausalEffects(graph, graph_type='stationary_admg',
    #                                X=X, Y=Y, S=None,
    #                                hidden_variables=None,
    #                                verbosity=0)

    # print(causal_effects.get_optimal_set())

    # tp.plot_time_series_graph(
    #     graph=graph,
    #     save_name='Example_graph_in.pdf',
    #     # special_nodes=special_nodes,
    #     # var_names=var_names,
    #     figsize=(6, 4),
    #     )

    # tp.plot_time_series_graph(
    #     graph=causal_effects.graph,
    #     save_name='Example_graph_out.pdf',
    #     # special_nodes=special_nodes,
    #     # var_names=var_names,
    #     figsize=(6, 4),
    #     )

    # causal_effects.fit_wright_effect(dataframe=dataframe,
    #                                  # links_coeffs=links_coeffs,
    #                                  # mediation=[(1, 0), (1, -1), (1, -2)]
    #                                  )

    # intervention_data = 1.*np.ones((1, 1))
    # y1 = causal_effects.predict_wright_effect(
    #     intervention_data=intervention_data,
    #     )

    # intervention_data = 0.*np.ones((1, 1))
    # y2 = causal_effects.predict_wright_effect(
    #     intervention_data=intervention_data,
    #     )

    # beta = (y1 - y2)
    # print("Causal effect is %.5f" % (beta))

    # tp.plot_time_series_graph(
    #     graph=causal_effects.graph,
    #     save_name='Example_graph.pdf',
    #     # special_nodes=special_nodes,
    #     var_names=var_names,
    #     figsize=(8, 4),
    #     )

    T = 10000
    def lin_f(x): return x

    auto_coeff = 0.
    coeff = 2.

    links = {
        0: [((0, -1), auto_coeff, lin_f)],
        1: [((1, -1), auto_coeff, lin_f)],
        2: [((2, -1), auto_coeff, lin_f), ((0, 0), coeff, lin_f)],
        3: [((3, -1), auto_coeff, lin_f)],
        }
    data, nonstat = toys.structural_causal_process(links, T=T,
                                                   noises=None, seed=7)


    # # Create some missing values
    # data[-10:,:] = 999.
    # var_names = range(2)

    dataframe = pp.DataFrame(data,
                             vector_vars={0: [(0, 0), (1, 0)],
                                          1: [(2, 0), (3, 0)]}
                             )

    # Construct expert knowledge graph from links here
    aux_links = {0: [(0, -1)],
                 1: [(1, -1), (0, 0)],
                 }
    # Use staticmethod to get graph
    graph = CausalEffects.get_graph_from_dict(aux_links, tau_max=2)
    # graph = np.array([['', '-->'],
    #                   ['<--', '']], dtype='<U3')

    # We are interested in the lagged total effect of X on Y
    X = [(0, 0), (0, -1)]
    Y = [(1, 0), (1, -1)]

    # Initialize class as `stationary_dag`
    causal_effects = CausalEffects(graph, graph_type='stationary_dag',
                                   X=X, Y=Y, S=None,
                                   hidden_variables=None,
                                   verbosity=1)

    # print(data)
    # # Optimal adjustment set (is used by default)
    # print(causal_effects.get_optimal_set())

    # Fit causal effect model from observational data
    causal_effects.fit_total_effect(
        dataframe=dataframe,
        # mask_type='y',
        estimator=LinearRegression(),
        )

    # # Fit bootstrap of causal effect model from observational data
    # causal_effects.fit_bootstrap_of(
    #     method='fit_total_effect',
    #     method_args={'dataframe': dataframe,
    #                  # 'mask_type': 'y',
    #                  'estimator': LinearRegression()
    #                  },
    #     boot_samples=3,
    #     boot_blocklength=1,
    #     seed=5
    #     )


    # Predict effect of interventions do(X=0.), ..., do(X=1.) in one go
    lenX = 4  # len(dataframe.vector_vars[X[0][0]])
    dox_vals = np.linspace(0., 1., 3)
    intervention_data = np.tile(dox_vals.reshape(len(dox_vals), 1), lenX)

    intervention_data = np.array([[1., 0., 0., 0.]])

    print(intervention_data)

    pred_Y = causal_effects.predict_total_effect(
        intervention_data=intervention_data)
    print(pred_Y, pred_Y.shape)


    # # Predict effect of interventions do(X=0.), ..., do(X=1.) in one go
    # # dox_vals = np.array([1.])  # np.linspace(0., 1., 1)
    # intervention_data = np.tile(dox_vals.reshape(len(dox_vals), 1), len(X))
    # conf = causal_effects.predict_bootstrap_of(
    #     method='predict_total_effect',
    #     method_args={'intervention_data': intervention_data})
    # print(conf, conf.shape)


    # # # Predict effect of interventions do(X=0.), ..., do(X=1.) in one go
    # # dox_vals = np.array([1.])  # np.linspace(0., 1., 1)
    # # intervention_data = dox_vals.reshape(len(dox_vals), len(X))
    # # pred_Y = causal_effects.predict_total_effect(
    # #     intervention_data=intervention_data)
    # # print(pred_Y)


    # # Fit Wright causal effect model from observational data
    # causal_effects.fit_wright_effect(
    #     dataframe=dataframe,
    #     # mask_type='y',
    #     # estimator=LinearRegression(),
    #     # data_transform=StandardScaler(),
    #     )

    # # Predict effect of interventions do(X=0.), ..., do(X=1.) in one go
    # dox_vals = np.linspace(0., 1., 5)
    # intervention_data = dox_vals.reshape(len(dox_vals), len(X))
    # pred_Y = causal_effects.predict_wright_effect(
    #     intervention_data=intervention_data)
    # print(pred_Y)