tigramite-fast 5.2.10.1 (py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. tigramite/__init__.py +0 -0
  2. tigramite/causal_effects.py +1525 -0
  3. tigramite/causal_mediation.py +1592 -0
  4. tigramite/data_processing.py +1574 -0
  5. tigramite/graphs.py +1509 -0
  6. tigramite/independence_tests/LBFGS.py +1114 -0
  7. tigramite/independence_tests/__init__.py +0 -0
  8. tigramite/independence_tests/cmiknn.py +661 -0
  9. tigramite/independence_tests/cmiknn_mixed.py +1397 -0
  10. tigramite/independence_tests/cmisymb.py +286 -0
  11. tigramite/independence_tests/gpdc.py +664 -0
  12. tigramite/independence_tests/gpdc_torch.py +820 -0
  13. tigramite/independence_tests/gsquared.py +190 -0
  14. tigramite/independence_tests/independence_tests_base.py +1310 -0
  15. tigramite/independence_tests/oracle_conditional_independence.py +1582 -0
  16. tigramite/independence_tests/pairwise_CI.py +383 -0
  17. tigramite/independence_tests/parcorr.py +369 -0
  18. tigramite/independence_tests/parcorr_mult.py +485 -0
  19. tigramite/independence_tests/parcorr_wls.py +451 -0
  20. tigramite/independence_tests/regressionCI.py +403 -0
  21. tigramite/independence_tests/robust_parcorr.py +403 -0
  22. tigramite/jpcmciplus.py +966 -0
  23. tigramite/lpcmci.py +3649 -0
  24. tigramite/models.py +2257 -0
  25. tigramite/pcmci.py +3935 -0
  26. tigramite/pcmci_base.py +1218 -0
  27. tigramite/plotting.py +4735 -0
  28. tigramite/rpcmci.py +467 -0
  29. tigramite/toymodels/__init__.py +0 -0
  30. tigramite/toymodels/context_model.py +261 -0
  31. tigramite/toymodels/non_additive.py +1231 -0
  32. tigramite/toymodels/structural_causal_processes.py +1201 -0
  33. tigramite/toymodels/surrogate_generator.py +319 -0
  34. tigramite_fast-5.2.10.1.dist-info/METADATA +182 -0
  35. tigramite_fast-5.2.10.1.dist-info/RECORD +38 -0
  36. tigramite_fast-5.2.10.1.dist-info/WHEEL +5 -0
  37. tigramite_fast-5.2.10.1.dist-info/licenses/license.txt +621 -0
  38. tigramite_fast-5.2.10.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1525 @@
+ """Tigramite causal inference for time series."""
+
+ # Author: Jakob Runge <jakob@jakob-runge.com>
+ #
+ # License: GNU General Public License v3.0
+
+ import numpy as np
+ import math
+ import itertools
+ from copy import deepcopy
+ from collections import defaultdict
+ from tigramite.models import Models
+ from tigramite.graphs import Graphs
+ import struct
+
+ class CausalEffects(Graphs):
+     r"""Causal effect estimation.
+
+     Methods for the estimation of linear or non-parametric causal effects
+     between (potentially multivariate) X and Y (potentially conditional
+     on S) by (generalized) backdoor adjustment. Various graph types are
+     supported, including graphs with hidden variables.
+
+     Linear and non-parametric estimators are based on sklearn. For the
+     linear case without hidden variables, an efficient estimator based on
+     Wright's path coefficients is also available. This estimator also
+     allows estimating mediation effects.
+
+     See the corresponding paper [6]_ and the tigramite tutorial for an
+     in-depth introduction.
+
+     References
+     ----------
+
+     .. [6] J. Runge, Necessary and sufficient graphical conditions for
+            optimal adjustment sets in causal graphical models with
+            hidden variables, Advances in Neural Information Processing
+            Systems, 2021, 34
+            https://proceedings.neurips.cc/paper/2021/hash/8485ae387a981d783f8764e508151cd9-Abstract.html
+
+
+     Parameters
+     ----------
+     graph : array of either shape [N, N], [N, N, tau_max+1], or [N, N, tau_max+1, tau_max+1]
+         Different graph types are supported, see tutorial.
+     X : list of tuples
+         List of tuples [(i, -tau), ...] containing cause variables.
+     Y : list of tuples
+         List of tuples [(j, 0), ...] containing effect variables.
+     S : list of tuples
+         List of tuples [(i, -tau), ...] containing conditioned variables.
+     graph_type : str
+         Type of graph.
+     hidden_variables : list of tuples
+         Hidden variables in format [(i, -tau), ...]. The internal graph is
+         constructed by a latent projection.
+     check_SM_overlap : bool
+         Whether to check if S overlaps with the mediators M.
+     verbosity : int, optional (default: 0)
+         Level of verbosity.
+     """
+
+     def __init__(self,
+                  graph,
+                  graph_type,
+                  X,
+                  Y,
+                  S=None,
+                  hidden_variables=None,
+                  check_SM_overlap=True,
+                  verbosity=0):
+
+         self.verbosity = verbosity
+         self.N = graph.shape[0]
+
+         if S is None:
+             S = []
+
+         self.listX = list(X)
+         self.listY = list(Y)
+         self.listS = list(S)
+
+         self.X = set(X)
+         self.Y = set(Y)
+         self.S = set(S)
+
+         # TODO?: check that masking aligns with hidden samples in variables
+         if hidden_variables is None:
+             hidden_variables = []
+
+         #
+         # Checks regarding graph type
+         #
+         supported_graphs = ['dag',
+                             'admg',
+                             'tsg_dag',
+                             'tsg_admg',
+                             'stationary_dag',
+                             'stationary_admg',
+
+                             # 'mag',
+                             # 'tsg_mag',
+                             # 'stationary_mag',
+                             # 'pag',
+                             # 'tsg_pag',
+                             # 'stationary_pag',
+                             ]
+         if graph_type not in supported_graphs:
+             raise ValueError("Only graph types %s supported!" % supported_graphs)
+
+         # Determine tau_max
+         if graph_type in ['dag', 'admg']:
+             self.tau_max = 0
+
+         elif graph_type in ['tsg_dag', 'tsg_admg']:
+             # tau_max is implicitly derived from
+             # the dimensions
+             self.tau_max = graph.shape[2] - 1
+
+         elif graph_type in ['stationary_dag', 'stationary_admg']:
+             # For a stationary DAG without hidden variables it's sufficient to consider
+             # a tau_max that includes the parents of X, Y, M, and S. A conservative
+             # estimate thereof is simply the lag-dimension of the stationary DAG plus
+             # the maximum lag of XYS.
+             statgraph_tau_max = graph.shape[2] - 1
+             maxlag_XYS = 0
+             for varlag in self.X.union(self.Y).union(self.S):
+                 maxlag_XYS = max(maxlag_XYS, abs(varlag[1]))
+             self.tau_max = maxlag_XYS + statgraph_tau_max
+         else:
+             raise ValueError("graph_type invalid.")
+
+         self.hidden_variables = set(hidden_variables)
+         if len(self.hidden_variables.intersection(self.X.union(self.Y).union(self.S))) > 0:
+             raise ValueError("XYS overlaps with hidden_variables!")
+
+         # self.tau_max is needed in the Graphs class
+         Graphs.__init__(self,
+                         graph=graph,
+                         graph_type=graph_type,
+                         tau_max=self.tau_max,
+                         hidden_variables=self.hidden_variables,
+                         verbosity=verbosity)
+
+         self._check_XYS()
+
+         self.ancX = self._get_ancestors(X)
+         self.ancY = self._get_ancestors(Y)
+         self.ancS = self._get_ancestors(S)
+
+         # If X is not in anc(Y), then no causal link exists
+         if self.ancY.intersection(set(X)) == set():
+             self.no_causal_path = True
+             if self.verbosity > 0:
+                 print("No causal path from X to Y exists.")
+         else:
+             self.no_causal_path = False
+
+         # Get mediators
+         mediators = self.get_mediators(start=self.X, end=self.Y)
+
+         M = set(mediators)
+         self.M = M
+
+         self.listM = list(self.M)
+
+         for varlag in self.X.union(self.Y).union(self.S):
+             if abs(varlag[1]) > self.tau_max:
+                 raise ValueError("X, Y, S must have time lags inside graph.")
+
+         if len(self.X.intersection(self.Y)) > 0:
+             raise ValueError("Overlap between X and Y.")
+
+         if len(self.S.intersection(self.Y.union(self.X))) > 0:
+             raise ValueError("Conditions S overlap with X or Y.")
+
+         # # TODO: need to prove that this is sufficient for non-identifiability!
+         # if len(self.X.intersection(self._get_descendants(self.M))) > 0:
+         #     raise ValueError("Not identifiable: Overlap between X and des(M)")
+
+         if check_SM_overlap and len(self.S.intersection(self.M)) > 0:
+             raise ValueError("Conditions S overlap with mediators M.")
+
+         self.desX = self._get_descendants(self.X)
+         self.desY = self._get_descendants(self.Y)
+         self.desM = self._get_descendants(self.M)
+         self.descendants = self.desY.union(self.desM)
+
+         # Define forbidden nodes as X and the descendants of Y and M
+         self.forbidden_nodes = self.descendants.union(self.X)  # .union(S)
+
+         # Define valid ancestors
+         self.vancs = self.ancX.union(self.ancY).union(self.ancS) - self.forbidden_nodes
+
+         if self.verbosity > 0:
+             if len(self.S.intersection(self.desX)) > 0:
+                 print("Warning: Potentially outside assumptions: Conditions S overlap with des(X)")
+
+         # Here only check whether S overlaps with des(Y); the option that S
+         # contains variables in des(M) is left to the user
+         if len(self.S.intersection(self.desY)) > 0:
+             raise ValueError("Not identifiable: Conditions S overlap with des(Y).")
+
+         if self.verbosity > 0:
+             print("\n##\n## Initializing CausalEffects class\n##"
+                   "\n\nInput:")
+             print("\ngraph_type = %s" % graph_type
+                   + "\nX = %s" % self.listX
+                   + "\nY = %s" % self.listY
+                   + "\nS = %s" % self.listS
+                   + "\nM = %s" % self.listM
+                   )
+             if len(self.hidden_variables) > 0:
+                 print("\nhidden_variables = %s" % self.hidden_variables
+                       )
+             print("\n\n")
+             if self.no_causal_path:
+                 print("No causal path from X to Y exists!")
+
+
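+     # Illustrative construction sketch (added note, not part of the original
+     # source; it mirrors the __main__ example below, and the links dictionary
+     # is an assumption for illustration):
+     #
+     #     aux_links = {0: [(0, -1)], 1: [(1, -1), (0, 0)]}
+     #     graph = CausalEffects.get_graph_from_dict(aux_links, tau_max=2)
+     #     causal_effects = CausalEffects(graph, graph_type='stationary_dag',
+     #                                    X=[(0, -1)], Y=[(1, 0)], S=None,
+     #                                    hidden_variables=None, verbosity=1)
+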
+     def _check_XYS(self):
+         """Check whether X, Y, S are well-defined.
+         """
+
+         XYS = self.X.union(self.Y).union(self.S)
+         for xys in XYS:
+             var, lag = xys
+             if var < 0 or var >= self.N:
+                 raise ValueError("XYS vars must be in [0...N-1]")
+             if lag < -self.tau_max or lag > 0:
+                 raise ValueError("XYS lags must be in [-taumax...0]")
+
+
+     def check_XYS_paths(self):
+         """Check whether one can remove nodes from X and Y with no proper causal paths.
+
+         Returns
+         -------
+         X, Y : cleaned lists of X and Y with irrelevant nodes removed.
+         """
+
+         # TODO: Also check S...
+         oldX = self.X.copy()
+         oldY = self.Y.copy()
+
+         # anc_Y = self._get_ancestors(self.Y)
+         # anc_S = self._get_ancestors(self.S)
+
+         # First remove from X those nodes with no causal path to Y or S
+         X = set([x for x in self.X if x in self.ancY.union(self.ancS)])
+
+         # Remove from Y those nodes with no causal path from X
+         # des_X = self._get_descendants(X)
+
+         Y = set([y for y in self.Y if y in self.desX])
+
+         # Also require that all x in X have a proper path to Y or S,
+         # that is, the first link goes out of x
+         # and into path nodes
+         mediators_S = self.get_mediators(start=self.X, end=self.S)
+         path_nodes = list(self.M.union(Y).union(mediators_S))
+         X = X.intersection(self._get_all_parents(path_nodes))
+
+         if set(oldX) != set(X) and self.verbosity > 0:
+             print("Consider pruning X = %s to X = %s " % (oldX, X) +
+                   "since only these have a causal path to Y")
+
+         if set(oldY) != set(Y) and self.verbosity > 0:
+             print("Consider pruning Y = %s to Y = %s " % (oldY, Y) +
+                   "since only these have a causal path from X")
+
+         return (list(X), list(Y))
+
+
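+     # Illustrative sketch (added note, an assumption rather than original
+     # source): pruning nodes without proper causal paths can reduce the
+     # dimensionality of the subsequent estimation:
+     #
+     #     X_pruned, Y_pruned = causal_effects.check_XYS_paths()
+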
+     def get_optimal_set(self,
+                         alternative_conditions=None,
+                         minimize=False,
+                         return_separate_sets=False,
+                         ):
+         """Returns optimal adjustment set.
+
+         See Runge NeurIPS 2021.
+
+         Parameters
+         ----------
+         alternative_conditions : set of tuples
+             Used only internally in optimality theorem. If None, self.S is used.
+         minimize : {False, True, 'colliders_only'}
+             Minimize optimal set. If True, minimize such that no subset
+             can be removed without making it invalid. If 'colliders_only',
+             only colliders are minimized.
+         return_separate_sets : bool
+             Whether to return tuple of parents, colliders, collider_parents, and S.
+
+         Returns
+         -------
+         Oset_S : False or list or tuple of lists
+             Returns optimal adjustment set if a valid set exists, otherwise False.
+         """
+
+         # Needed for optimality theorem where Osets for alternative S are tested
+         if alternative_conditions is None:
+             S = self.S.copy()
+             vancs = self.vancs.copy()
+         else:
+             S = alternative_conditions
+             newancS = self._get_ancestors(S)
+             vancs = self.ancX.union(self.ancY).union(newancS) - self.forbidden_nodes
+
+             # vancs = self._get_ancestors(list(self.X.union(self.Y).union(S))) - self.forbidden_nodes
+
+         # Sufficient condition for non-identifiability
+         if len(self.X.intersection(self.descendants)) > 0:
+             return False  # Not identifiable: Overlap between X and des(M)
+
+         ##
+         ## Construct O-set
+         ##
+
+         # Start with parents
+         parents = self._get_all_parents(self.Y.union(self.M))
+
+         # Remove forbidden nodes
+         parents = parents - self.forbidden_nodes
+
+         # Construct valid collider path nodes
+         colliders = set()
+         for w in self.Y.union(self.M):
+             j, tau = w
+             this_level = [w]
+             non_suitable_nodes = []
+             while len(this_level) > 0:
+                 next_level = []
+                 for varlag in this_level:
+                     suitable_spouses = set(self._get_spouses(varlag)) - set(non_suitable_nodes)
+                     for spouse in suitable_spouses:
+                         i, tau = spouse
+                         if spouse in self.X:
+                             return False
+
+                         if (# Node not already in set
+                                 spouse not in colliders  # .union(parents)
+                                 # not forbidden
+                                 and spouse not in self.forbidden_nodes
+                                 # within time bounds
+                                 and (-self.tau_max <= tau <= 0)  # or self.ignore_time_bounds
+                                 and (spouse in vancs
+                                      or not self._check_path(
+                                          start=self.X, end=[spouse],
+                                          conditions=list(parents.union(vancs)) + list(S),
+                                          ))
+                                 ):
+                             colliders = colliders | {spouse}
+                             next_level.append(spouse)
+                         else:
+                             if spouse not in colliders:
+                                 non_suitable_nodes.append(spouse)
+
+                 this_level = set(next_level) - set(non_suitable_nodes)
+
+         # Add collider parents and return False if not identifiable
+         collider_parents = self._get_all_parents(colliders)
+         if len(self.X.intersection(collider_parents)) > 0:
+             return False
+
+         colliders_and_their_parents = colliders.union(collider_parents)
+
+         # Add valid collider path nodes and their parents
+         Oset = parents.union(colliders_and_their_parents)
+
+         if minimize:
+             removable = []
+             # First remove all those that have no path from X
+             sorted_Oset = Oset
+             if minimize == 'colliders_only':
+                 sorted_Oset = [node for node in sorted_Oset if node not in parents]
+
+             for node in sorted_Oset:
+                 if (not self._check_path(
+                         start=self.X, end=[node],
+                         conditions=list(Oset - {node}) + list(S))):
+                     removable.append(node)
+
+             Oset = Oset - set(removable)
+             if minimize == 'colliders_only':
+                 sorted_Oset = [node for node in Oset if node not in parents]
+
+             removable = []
+             # Next remove all those with no direct connection to Y
+             for node in sorted_Oset:
+                 if (not self._check_path(
+                         start=[node], end=self.Y,
+                         conditions=list(Oset - {node}) + list(S) + list(self.X),
+                         ends_with=['**>', '**+'])):
+                     removable.append(node)
+
+             Oset = Oset - set(removable)
+
+         Oset_S = Oset.union(S)
+
+         if return_separate_sets:
+             return parents, colliders, collider_parents, S
+         else:
+             return list(Oset_S)
+
+
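+     # Illustrative sketch (added note, not original source): querying the
+     # O-set; a return value of False signals that the effect is not
+     # identifiable via adjustment:
+     #
+     #     opt = causal_effects.get_optimal_set()
+     #     opt_min = causal_effects.get_optimal_set(minimize=True)
+     #     if opt is False:
+     #         print("Not identifiable via adjustment.")
+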
+     def _get_collider_paths_optimality(self, source_nodes, target_nodes,
+                                        condition,
+                                        inside_set=None,
+                                        start_with_tail_or_head=False,
+                                        ):
+         """Returns relevant collider paths to check optimality.
+
+         Iterates over collider paths within O-set via depth-first search.
+         """
+
+         for w in source_nodes:
+             # Only used to return *all* collider paths
+             # (needed in optimality theorem)
+
+             coll_path = []
+
+             queue = [(w, coll_path)]
+
+             non_valid_subsets = []
+
+             while queue:
+
+                 varlag, coll_path = queue.pop()
+
+                 coll_path = coll_path + [varlag]
+                 coll_path_set = set(coll_path)
+
+                 suitable_nodes = set(self._get_spouses(varlag))
+
+                 if start_with_tail_or_head and coll_path == [w]:
+                     children = set(self._get_children(varlag))
+                     suitable_nodes = suitable_nodes.union(children)
+
+                 for node in suitable_nodes:
+                     i, tau = node
+                     if ((-self.tau_max <= tau <= 0)  # or self.ignore_time_bounds
+                             and node not in coll_path_set):
+
+                         if condition == 'II' and node not in target_nodes and node not in self.vancs:
+                             continue
+
+                         if node in inside_set:
+                             if condition == 'I':
+                                 non_valid = False
+                                 extended_set = coll_path_set | {node}
+                                 for pathset in non_valid_subsets[::-1]:
+                                     if pathset.issubset(extended_set):
+                                         non_valid = True
+                                         break
+                                 if non_valid is False:
+                                     queue.append((node, coll_path))
+                                 else:
+                                     continue
+                             elif condition == 'II':
+                                 queue.append((node, coll_path))
+
+                         if node in target_nodes:
+                             if condition == 'I':
+                                 # Construct OπiN
+                                 Sprime = self.S.union(coll_path_set)
+                                 OpiN = self.get_optimal_set(alternative_conditions=Sprime)
+                                 if OpiN is False:
+                                     queue = [(q_node, q_path) for (q_node, q_path) in queue
+                                              if not coll_path_set.issubset(set(q_path + [q_node]))]
+                                     non_valid_subsets.append(coll_path_set)
+                                 else:
+                                     return False
+
+                             elif condition == 'II':
+                                 return True
+
+         if condition == 'I':
+             return True
+         elif condition == 'II':
+             return False
+
+
+     def check_optimality(self):
+         """Check whether an optimal adjustment set exists according to Thm. 3 in Runge NeurIPS 2021.
+
+         Returns
+         -------
+         optimality : bool
+             Returns True if an optimal adjustment set exists, otherwise False.
+         """
+
+         # Cond. 0: Exactly one valid adjustment set exists
+         cond_0 = self._get_all_valid_adjustment_sets(check_one_set_exists=True)
+
+         #
+         # Cond. I
+         #
+         parents, colliders, collider_parents, _ = self.get_optimal_set(return_separate_sets=True)
+         Oset = parents.union(colliders).union(collider_parents)
+         n_nodes = self._get_all_spouses(self.Y.union(self.M).union(colliders)) \
+             - self.forbidden_nodes - Oset - self.S - self.Y - self.M - colliders
+
+         if (len(n_nodes) == 0):
+             # (1) There are no spouses N ∈ sp(YMC) \ (forb ∪ O ∪ S)
+             cond_I = True
+         else:
+
+             # (2) For all N in n_nodes and all of their collider paths πi it
+             # holds that OπiN does not block all non-causal paths from X to Y
+             cond_I = self._get_collider_paths_optimality(
+                 source_nodes=list(n_nodes), target_nodes=list(self.Y.union(self.M)),
+                 condition='I',
+                 inside_set=Oset.union(self.S), start_with_tail_or_head=False,
+                 )
+
+         #
+         # Cond. II
+         #
+         e_nodes = Oset.difference(parents)
+         cond_II = True
+         for E in e_nodes:
+             Oset_minusE = Oset - {E}
+             if self._check_path(
+                     start=list(self.X), end=[E],
+                     conditions=list(self.S) + list(Oset_minusE)):
+
+                 cond_II = self._get_collider_paths_optimality(
+                     target_nodes=self.Y.union(self.M),
+                     source_nodes=[E],
+                     condition='II',
+                     inside_set=list(Oset.union(self.S)),
+                     start_with_tail_or_head=True)
+
+                 if cond_II is False:
+                     if self.verbosity > 1:
+                         print("Non-optimal due to E = ", E)
+                     break
+
+         optimality = (cond_0 or (cond_I and cond_II))
+         if self.verbosity > 0:
+             print("Optimality = %s with cond_0 = %s, cond_I = %s, cond_II = %s"
+                   % (optimality, cond_0, cond_I, cond_II))
+         return optimality
+
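+     # Illustrative sketch (added note, an assumption): together with
+     # get_optimal_set, this indicates whether the O-set is guaranteed to be
+     # optimal among valid adjustment sets:
+     #
+     #     if causal_effects.check_optimality():
+     #         print("An optimal adjustment set exists; use get_optimal_set().")
+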
+     def _check_validity(self, Z):
+         """Checks whether Z is a valid adjustment set."""
+
+         # causal_children = list(self.M.union(self.Y))
+         backdoor_path = self._check_path(
+             start=list(self.X), end=list(self.Y),
+             conditions=list(Z),
+             # causal_children=causal_children,
+             path_type='non_causal')
+
+         if backdoor_path:
+             return False
+         else:
+             return True
+
+     def _get_adjust_set(self,
+                         minimize=False,
+                         ):
+         """Returns Adjust-set.
+
+         See van der Zander, B.; Liśkiewicz, M. & Textor, J.
+         Separators and adjustment sets in causal graphs: Complete
+         criteria and an algorithmic framework
+         Artificial Intelligence, Elsevier, 2019, 270, 1-40
+         """
+
+         vancs = self.vancs.copy()
+
+         if minimize:
+             # Get removable nodes by computing minimal valid set from Z
+             if minimize == 'keep_parentsYM':
+                 minimize_nodes = vancs - self._get_all_parents(list(self.Y.union(self.M)))
+
+             else:
+                 minimize_nodes = vancs
+
+             # First remove all nodes that have no unique path to X given Oset
+             for node in minimize_nodes:
+                 # path = self.oracle.check_shortest_path(X=X, Y=[node],
+                 #     Z=list(vancs - {node}),
+                 #     max_lag=None,
+                 #     starts_with=None,  # 'arrowhead',
+                 #     forbidden_nodes=None,  # list(Zprime - {node}),
+                 #     return_path=False)
+                 path = self._check_path(
+                     start=self.X, end=[node],
+                     conditions=list(vancs - {node}),
+                     )
+
+                 if path is False:
+                     vancs = vancs - {node}
+
+             if minimize == 'keep_parentsYM':
+                 minimize_nodes = vancs - self._get_all_parents(list(self.Y.union(self.M)))
+             else:
+                 minimize_nodes = vancs
+
+             # Next remove all nodes that have no unique path to Y given Oset_min
+             for node in minimize_nodes:
+
+                 path = self._check_path(
+                     start=[node], end=self.Y,
+                     conditions=list(vancs - {node}) + list(self.X),
+                     )
+
+                 if path is False:
+                     vancs = vancs - {node}
+
+         if self._check_validity(list(vancs)) is False:
+             return False
+         else:
+             return list(vancs)
+
+
+     def _get_all_valid_adjustment_sets(self,
+                                        check_one_set_exists=False, yield_index=None):
+         """Constructs all valid adjustment sets or just checks whether one exists.
+
+         See van der Zander, B.; Liśkiewicz, M. & Textor, J.
+         Separators and adjustment sets in causal graphs: Complete
+         criteria and an algorithmic framework
+         Artificial Intelligence, Elsevier, 2019, 270, 1-40
+         """
+
+         cond_set = set(self.S)
+         all_vars = [(i, -tau) for i in range(self.N)
+                     for tau in range(0, self.tau_max + 1)]
+
+         all_vars_set = set(all_vars) - self.forbidden_nodes
+
+
+         def find_sep(I, R):
+             Rprime = R - self.X - self.Y
+             # TODO: anteriors and NOT ancestors where
+             # anteriors include --- links in causal paths
+             XYI = list(self.X.union(self.Y).union(I))
+             ancs = self._get_ancestors(list(XYI))
+             Z = ancs.intersection(Rprime)
+             if self._check_validity(Z) is False:
+                 return False
+             else:
+                 return Z
+
+
+         def list_sep(I, R):
+             if find_sep(I, R) is not False:
+                 if I == R:
+                     yield I
+                 else:
+                     # Pick an arbitrary node from R - I
+                     RminusI = list(R - I)
+                     v = RminusI[0]
+                     yield from list_sep(I | {v}, R)
+                     yield from list_sep(I, R - {v})
+
+         all_sets = []
+         I = cond_set
+         R = all_vars_set
+         for index, valid_set in enumerate(list_sep(I, R)):
+             all_sets.append(list(valid_set))
+             if check_one_set_exists and index > 0:
+                 break
+
+             if yield_index is not None and index == yield_index:
+                 return valid_set
+
+         if yield_index is not None:
+             return None
+
+         if check_one_set_exists:
+             if len(all_sets) == 1:
+                 return True
+             else:
+                 return False
+
+         return all_sets
+
+
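+     # Note (added descriptive comment): list_sep above implements the
+     # recursive ListSep enumeration of van der Zander et al. (2019). It
+     # branches on an arbitrary node v in R - I, once forcing v into the
+     # candidate set (I | {v}) and once excluding it (R - {v}), and prunes any
+     # branch on which find_sep already fails, so each leaf with I == R yields
+     # exactly one valid adjustment set.
+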
+     def fit_total_effect(self,
+                          dataframe,
+                          estimator,
+                          adjustment_set='optimal',
+                          conditional_estimator=None,
+                          data_transform=None,
+                          mask_type=None,
+                          ignore_identifiability=False,
+                          ):
+         """Returns a fitted model for the total causal effect of X on Y
+         conditional on S.
+
+         Parameters
+         ----------
+         dataframe : data object
+             Tigramite dataframe object. It must have the attributes dataframe.values
+             yielding a numpy array of shape (observations T, variables N) and
+             optionally a mask of the same shape and a missing values flag.
+         estimator : sklearn model object
+             For example, sklearn.linear_model.LinearRegression() for a linear
+             regression model.
+         adjustment_set : str or list of tuples
+             If 'optimal' the Oset is used, if 'minimized_optimal' the minimized Oset,
+             and if 'colliders_minimized_optimal', the colliders-minimized Oset.
+             If a list of tuples is passed, this set is used.
+         conditional_estimator : sklearn model object, optional (default: None)
+             Used to fit conditional causal effects in nested regression.
+             If None, the same model as for estimator is used.
+         data_transform : sklearn preprocessing object, optional (default: None)
+             Used to transform data prior to fitting. For example,
+             sklearn.preprocessing.StandardScaler for simple standardization. The
+             fitted parameters are stored.
+         mask_type : {None, 'y','x','z','xy','xz','yz','xyz'}
+             Masking mode: Indicators for which variables in the dependence
+             measure I(X; Y | Z) the samples should be masked. If None, the mask
+             is not used. Explained in tutorial on masking and missing values.
+         ignore_identifiability : bool
+             Only applies to adjustment sets supplied by the user. If True, the
+             check whether the supplied set yields an identifiable effect is skipped.
+         """
+
+         if self.no_causal_path:
+             if self.verbosity > 0:
+                 print("No causal path from X to Y exists.")
+             return self
+
+         self.dataframe = dataframe
+         self.conditional_estimator = conditional_estimator
+
+         # if self.dataframe.has_vector_data:
+         #     raise ValueError("vector_vars in DataFrame cannot be used together with CausalEffects!"
+         #                      " You can estimate vector-valued effects by using multivariate X, Y, S."
+         #                      " Note, however, that this requires assuming a graph at the level "
+         #                      "of the components of X, Y, S, ...")
+
+         if self.N != self.dataframe.N:
+             raise ValueError("Dataset dimensions inconsistent with number of variables in graph.")
+
+         if adjustment_set == 'optimal':
+             # Check optimality and use either optimal or colliders_only set
+             adjustment_set = self.get_optimal_set()
+         elif adjustment_set == 'colliders_minimized_optimal':
+             adjustment_set = self.get_optimal_set(minimize='colliders_only')
+         elif adjustment_set == 'minimized_optimal':
+             adjustment_set = self.get_optimal_set(minimize=True)
+         else:
+             if ignore_identifiability is False and self._check_validity(adjustment_set) is False:
+                 raise ValueError("Chosen adjustment_set is not valid.")
+
+         if adjustment_set is False:
+             raise ValueError("Causal effect not identifiable via adjustment.")
+
+         self.adjustment_set = adjustment_set
+
+         # Fit model of Y on X and Z (and conditions)
+         # Build the model
+         self.model = Models(
+             dataframe=dataframe,
+             model=estimator,
+             conditional_model=conditional_estimator,
+             data_transform=data_transform,
+             mask_type=mask_type,
+             verbosity=self.verbosity)
+
+         self.model.get_general_fitted_model(
+             Y=self.listY, X=self.listX, Z=list(self.adjustment_set),
+             conditions=self.listS,
+             tau_max=self.tau_max,
+             cut_off='tau_max',
+             return_data=False)
+
+         return self
+
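+     # Illustrative usage sketch (added note; the estimator choice is only an
+     # example of a compatible sklearn model):
+     #
+     #     from sklearn.linear_model import LinearRegression
+     #     causal_effects.fit_total_effect(dataframe=dataframe,
+     #                                     estimator=LinearRegression(),
+     #                                     adjustment_set='optimal')
+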
800
+ def predict_total_effect(self,
801
+ intervention_data,
802
+ conditions_data=None,
803
+ pred_params=None,
804
+ return_further_pred_results=False,
805
+ aggregation_func=np.mean,
806
+ transform_interventions_and_prediction=False,
807
+ intervention_type='hard',
808
+ ):
809
+ """Predict effect of intervention with fitted model.
810
+
811
+ Uses the model.predict() function of the sklearn model.
812
+
813
+ Parameters
814
+ ----------
815
+ intervention_data : numpy array
816
+ Numpy array of shape (n_interventions, len(X)) that contains the do(X) values.
817
+ conditions_data : data object, optional
818
+ Numpy array of shape (n_interventions, len(S)) that contains the S=s values.
819
+ pred_params : dict, optional
820
+ Optional parameters passed on to sklearn prediction function.
821
+ return_further_pred_results : bool, optional (default: False)
822
+ In case the predictor class returns more than just the expected value,
823
+ the entire results can be returned.
824
+ aggregation_func : callable
825
+ Callable applied to output of 'predict'. Default is 'np.mean'.
826
+ transform_interventions_and_prediction : bool (default: False)
827
+ Whether to perform the inverse data_transform on prediction results.
828
+ intervention_type : {'hard', 'soft'}
829
+ Specify whether intervention is 'hard' (set value) or 'soft'
830
+ (add value to observed data).
831
+
832
+ Returns
833
+ -------
834
+ Results from prediction: an array of shape (n_interventions, len(Y)).
835
+ If estimate_confidence = True, then a tuple is returned.
836
+ """
837
+
838
+ def get_vectorized_length(W):
839
+ return sum([len(self.dataframe.vector_vars[w[0]]) for w in W])
840
+
841
+ # lenX = len(self.listX)
842
+ # lenS = len(self.listS)
843
+
844
+ lenX = get_vectorized_length(self.listX)
845
+ lenS = get_vectorized_length(self.listS)
846
+
847
+ if intervention_data.shape[1] != lenX:
848
+ raise ValueError("intervention_data.shape[1] must be len(X).")
849
+
850
+ if intervention_type not in {'hard', 'soft'}:
851
+ raise ValueError("intervention_type must be 'hard' or 'soft'.")
852
+
853
+ if conditions_data is not None and lenS > 0:
854
+ if conditions_data.shape[1] != lenS:
855
+ raise ValueError("conditions_data.shape[1] must be len(S).")
856
+ if conditions_data.shape[0] != intervention_data.shape[0]:
857
+ raise ValueError("conditions_data.shape[0] must match intervention_data.shape[0].")
858
+ elif conditions_data is not None and lenS == 0:
859
+ raise ValueError("conditions_data specified, but S=None or empty.")
860
+ elif conditions_data is None and lenS > 0:
861
+ raise ValueError("S specified, but conditions_data is None.")
862
+
863
+
864
+ if self.no_causal_path:
865
+ if self.verbosity > 0:
866
+ print("No causal path from X to Y exists.")
867
+ return np.zeros((len(intervention_data), len(self.listY)))
868
+
869
+ effect = self.model.get_general_prediction(
870
+ intervention_data=intervention_data,
871
+ conditions_data=conditions_data,
872
+ pred_params=pred_params,
873
+ return_further_pred_results=return_further_pred_results,
874
+ transform_interventions_and_prediction=transform_interventions_and_prediction,
875
+ aggregation_func=aggregation_func,
876
+ intervention_type=intervention_type,)
877
+
878
+ return effect
879
+
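+     # Illustrative sketch (added note): for a univariate X, the average total
+     # effect can be estimated as the difference between two hard interventions:
+     #
+     #     y1 = causal_effects.predict_total_effect(np.array([[1.]]))
+     #     y0 = causal_effects.predict_total_effect(np.array([[0.]]))
+     #     ate = y1 - y0
+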
880
+ def fit_wright_effect(self,
881
+ dataframe,
882
+ mediation=None,
883
+ method='parents',
884
+ links_coeffs=None,
885
+ data_transform=None,
886
+ mask_type=None,
887
+ ):
888
+ """Returns a fitted model for the total or mediated causal effect of X on Y
889
+ potentially through mediator variables.
890
+
891
+ Parameters
892
+ ----------
893
+ dataframe : data object
894
+ Tigramite dataframe object. It must have the attributes dataframe.values
895
+ yielding a numpy array of shape (observations T, variables N) and
896
+ optionally a mask of the same shape and a missing values flag.
897
+ mediation : None, 'direct', or list of tuples
898
+ If None, total effect is estimated, if 'direct' then only the direct effect is estimated,
899
+ else only those causal paths are considerd that pass at least through one of these mediator nodes.
900
+ method : {'parents', 'links_coeffs', 'optimal'}
901
+ Method to use for estimating Wright's path coefficients. If 'optimal',
902
+ the Oset is used, if 'links_coeffs', the coefficients in links_coeffs are used,
903
+ if 'parents', the parents are used (only valid for DAGs).
904
+ links_coeffs : dict
905
+ Only used if method = 'links_coeffs'.
906
+ Dictionary of format: {0:[((i, -tau), coeff),...], 1:[...],
907
+ ...} for all variables where i must be in [0..N-1] and tau >= 0 with
908
+ number of variables N. coeff must be a float.
909
+ data_transform : None
910
+ Not implemented for Wright estimator. Complicated for missing samples.
911
+ mask_type : {None, 'y','x','z','xy','xz','yz','xyz'}
912
+ Masking mode: Indicators for which variables in the dependence
913
+ measure I(X; Y | Z) the samples should be masked. If None, the mask
914
+ is not used. Explained in tutorial on masking and missing values.
915
+ """
916
+
917
+ if self.no_causal_path:
918
+ if self.verbosity > 0:
919
+ print("No causal path from X to Y exists.")
920
+ return self
921
+
922
+ if data_transform is not None:
923
+ raise ValueError("data_transform not implemented for Wright estimator."
924
+ " You can preprocess data yourself beforehand.")
925
+
926
+ import sklearn.linear_model
927
+
928
+ self.dataframe = dataframe
929
+ if self.dataframe.has_vector_data:
930
+ raise ValueError("vector_vars in DataFrame cannot be used together with Wright method!"
931
+ " You can either 1) estimate vector-valued effects by using multivariate (X, Y, S)"
932
+ " together with assuming a graph at the level of the components of (X, Y, S), "
933
+ " or 2) use vector_vars together with fit_total_effect and an estimator"
934
+ " that supports multiple outputs.")
935
+
936
+ estimator = sklearn.linear_model.LinearRegression()
937
+
938
+ # Fit model of Y on X and Z (and conditions)
939
+ # Build the model
940
+ self.model = Models(
941
+ dataframe=dataframe,
942
+ model=estimator,
943
+ data_transform=None, #data_transform,
944
+ mask_type=mask_type,
945
+ verbosity=self.verbosity)
946
+
947
+ mediators = self.M # self.get_mediators(start=self.X, end=self.Y)
948
+
949
+ if mediation == 'direct':
950
+ causal_paths = {}
951
+ for w in self.X:
952
+ causal_paths[w] = {}
953
+ for z in self.Y:
954
+ if w in self._get_parents(z):
955
+ causal_paths[w][z] = [[w, z]]
956
+ else:
957
+ causal_paths[w][z] = []
958
+ else:
959
+ causal_paths = self._get_causal_paths(source_nodes=self.X,
960
+ target_nodes=self.Y, mediators=mediators,
961
+ mediated_through=mediation, proper_paths=True)
962
+
963
+ if method == 'links_coeffs':
964
+ coeffs = {}
965
+ max_lag = 0
966
+ for medy in [med for med in mediators] + [y for y in self.listY]:
967
+ coeffs[medy] = {}
968
+ j, tauj = medy
969
+ for ipar, par_coeff in enumerate(links_coeffs[medy[0]]):
970
+ par, coeff, _ = par_coeff
971
+ i, taui = par
972
+ taui_shifted = taui + tauj
973
+ max_lag = max(abs(par[1]), max_lag)
974
+ coeffs[medy][(i, taui_shifted)] = coeff #self.fit_results[j][(j, 0)]['model'].coef_[ipar]
975
+
976
+ self.model.tau_max = max_lag
977
+ # print(coeffs)
978
+
979
+ elif method == 'optimal':
980
+ # all_parents = {}
981
+ coeffs = {}
982
+ for medy in [med for med in mediators] + [y for y in self.listY]:
983
+ coeffs[medy] = {}
984
+ mediator_parents = self._get_all_parents([medy]).intersection(mediators.union(self.X).union(self.Y)) - {medy}
985
+ all_parents = self._get_all_parents([medy]) - {medy}
986
+ for par in mediator_parents:
987
+ Sprime = set(all_parents) - {par, medy}
988
+ causal_effects = CausalEffects(graph=self.graph,
989
+ X=[par], Y=[medy], S=Sprime,
990
+ graph_type=self.graph_type,
991
+ check_SM_overlap=False,
992
+ )
993
+ oset = causal_effects.get_optimal_set()
994
+ # print(medy, par, list(set(all_parents)), oset)
995
+ if oset is False:
996
+ raise ValueError("Not identifiable via Wright's method.")
997
+ fit_res = self.model.get_general_fitted_model(
998
+ Y=[medy], X=[par], Z=oset,
999
+ tau_max=self.tau_max,
1000
+ cut_off='tau_max',
1001
+ return_data=False)
1002
+ coeffs[medy][par] = fit_res['model'].coef_[0]
1003
+
1004
+ elif method == 'parents':
1005
+ coeffs = {}
1006
+ for medy in [med for med in mediators] + [y for y in self.listY]:
1007
+ coeffs[medy] = {}
1008
+ # mediator_parents = self._get_all_parents([medy]).intersection(mediators.union(self.X)) - {medy}
1009
+ all_parents = self._get_all_parents([medy]) - {medy}
1010
+ if 'dag' not in self.graph_type:
1011
+ spouses = self._get_all_spouses([medy]) - {medy}
1012
+ if len(spouses) != 0:
1013
+ raise ValueError("method == 'parents' only possible for "
1014
+ "causal paths without adjacent bi-directed links!")
1015
+
1016
+ # print(j, all_parents[j])
1017
+ # if len(all_parents[j]) > 0:
1018
+ # print(medy, list(all_parents))
1019
+ fit_res = self.model.get_general_fitted_model(
1020
+ Y=[medy], X=list(all_parents), Z=[],
1021
+ conditions=None,
1022
+ tau_max=self.tau_max,
1023
+ cut_off='tau_max',
1024
+ return_data=False)
1025
+
1026
+ for ipar, par in enumerate(list(all_parents)):
1027
+ # print(par, fit_res['model'].coef_)
1028
+ coeffs[medy][par] = fit_res['model'].coef_[0][ipar]
1029
+
1030
+ else:
1031
+ raise ValueError("method must be 'optimal', 'links_coeffs', or 'parents'.")
1032
+
1033
+ # Effect is sum over products over all path coefficients
1034
+ # from x in X to y in Y
1035
+ effect = {}
1036
+ for (x, y) in itertools.product(self.listX, self.listY):
1037
+ effect[(x, y)] = 0.
1038
+ for causal_path in causal_paths[x][y]:
1039
+ effect_here = 1.
1040
+ # print(x, y, causal_path)
1041
+ for index, node in enumerate(causal_path[:-1]):
1042
+ i, taui = node
1043
+ j, tauj = causal_path[index + 1]
1044
+ # tau_ij = abs(tauj - taui)
1045
+ # print((j, tauj), (i, taui))
1046
+ effect_here *= coeffs[(j, tauj)][(i, taui)]
1047
+
1048
+ effect[(x, y)] += effect_here
1049
+
1050
+ # Make fitted coefficients available as attribute
1051
+ self.coeffs = coeffs
1052
+
1053
+ # Modify and overwrite variables in self.model
1054
+ self.model.Y = self.listY
1055
+ self.model.X = self.listX
1056
+ self.model.Z = []
1057
+ self.model.conditions = []
1058
+ self.model.cut_off = 'tau_max' # 'max_lag_or_tau_max'
1059
+
1060
+ class dummy_fit_class():
1061
+ def __init__(self, y_here, listX_here, effect_here):
1062
+ dim = len(listX_here)
1063
+ self.coeff_array = np.array([effect_here[(x, y_here)] for x in listX_here]).reshape(dim, 1)
1064
+ def predict(self, X):
1065
+ return np.dot(X, self.coeff_array).squeeze()
1066
+
1067
+ fit_results = {}
1068
+ for y in self.listY:
1069
+ fit_results[y] = {}
1070
+ fit_results[y]['model'] = dummy_fit_class(y, self.listX, effect)
1071
+ fit_results[y]['data_transform'] = deepcopy(data_transform)
1072
+
1073
+ # self.effect = effect
1074
+ self.model.fit_results = fit_results
1075
+ return self
1076
+
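+     # Illustrative sketch (added note). The default method 'parents' requires
+     # a DAG; mediation can restrict the effect to paths through given nodes,
+     # e.g. a hypothetical mediator (2, -1):
+     #
+     #     causal_effects.fit_wright_effect(dataframe=dataframe)
+     #     causal_effects.fit_wright_effect(dataframe=dataframe,
+     #                                      mediation=[(2, -1)])
+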
1077
+ def predict_wright_effect(self,
1078
+ intervention_data,
1079
+ pred_params=None,
1080
+ ):
1081
+ """Predict linear effect of intervention with fitted Wright-model.
1082
+
1083
+ Parameters
1084
+ ----------
1085
+ intervention_data : numpy array
1086
+ Numpy array of shape (n_interventions, len(X)) that contains the do(X) values.
1087
+ pred_params : dict, optional
1088
+ Optional parameters passed on to sklearn prediction function.
1089
+
1090
+ Returns
1091
+ -------
1092
+ Results from prediction: an array of shape (n_interventions, len(Y)).
1093
+ """
1094
+
1095
+ lenX = len(self.listX)
1096
+ lenY = len(self.listY)
1097
+
1098
+ if intervention_data.shape[1] != lenX:
1099
+ raise ValueError("intervention_data.shape[1] must be len(X).")
1100
+
1101
+ if self.no_causal_path:
1102
+ if self.verbosity > 0:
1103
+ print("No causal path from X to Y exists.")
1104
+ return np.zeros((len(intervention_data), len(self.Y)))
1105
+
1106
+ n_interventions, _ = intervention_data.shape
1107
+
1108
+
1109
+ predicted_array = np.zeros((n_interventions, lenY))
1110
+ pred_dict = {}
1111
+ for iy, y in enumerate(self.listY):
1112
+ # Print message
1113
+ if self.verbosity > 1:
1114
+ print("\n## Predicting target %s" % str(y))
1115
+ if pred_params is not None:
1116
+ for key in list(pred_params):
1117
+ print("%s = %s" % (key, pred_params[key]))
1118
+ # Default value for pred_params
1119
+ if pred_params is None:
1120
+ pred_params = {}
1121
+ # Check this is a valid target
1122
+ if y not in self.model.fit_results:
1123
+ raise ValueError("y = %s not yet fitted" % str(y))
1124
+
1125
+ # data_transform is too complicated for Wright estimator
1126
+ # Transform the data if needed
1127
+ # fitted_data_transform = self.model.fit_results[y]['fitted_data_transform']
1128
+ # if fitted_data_transform is not None:
1129
+ # intervention_data = fitted_data_transform['X'].transform(X=intervention_data)
1130
+
1131
+ # Now iterate through interventions (and potentially S)
1132
+ for index, dox_vals in enumerate(intervention_data):
1133
+ # Construct XZS-array
1134
+ intervention_array = dox_vals.reshape(1, lenX)
1135
+ predictor_array = intervention_array
1136
+
1137
+ predicted_vals = self.model.fit_results[y]['model'].predict(
1138
+ X=predictor_array, **pred_params)
1139
+ predicted_array[index, iy] = predicted_vals.mean()
1140
+
1141
+ # data_transform is too complicated for Wright estimator
1142
+ # if fitted_data_transform is not None:
1143
+ # rescaled = fitted_data_transform['Y'].inverse_transform(X=predicted_array[index, iy].reshape(-1, 1))
1144
+ # predicted_array[index, iy] = rescaled.squeeze()
1145
+
1146
+ return predicted_array
1147
+
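+     # Illustrative sketch (added note). Since the Wright model is linear, a
+     # unit intervention recovers the summed products of path coefficients:
+     #
+     #     beta = (causal_effects.predict_wright_effect(np.array([[1.]]))
+     #             - causal_effects.predict_wright_effect(np.array([[0.]])))
+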
+
+     def fit_bootstrap_of(self, method, method_args,
+                          boot_samples=100,
+                          boot_blocklength=1,
+                          seed=None):
+         """Runs chosen method on bootstrap samples drawn from DataFrame.
+
+         Bootstraps for tau=0 are drawn from [max_lag, ..., T] and all lagged
+         variables constructed in DataFrame.construct_array are consistently
+         shifted with respect to this bootstrap sample to ensure that lagged
+         relations in the bootstrap sample are preserved.
+
+         This function fits the models; predict_bootstrap_of can then be used
+         to get confidence intervals for the effect of interventions.
+
+         Parameters
+         ----------
+         method : str
+             Chosen method among valid functions in this class.
+         method_args : dict
+             Arguments passed to method.
+         boot_samples : int
+             Number of bootstrap samples to draw.
+         boot_blocklength : int, optional (default: 1)
+             Block length for block-bootstrap.
+         seed : int, optional (default: None)
+             Seed for RandomState (default_rng)
+         """
+
+         valid_methods = ['fit_total_effect',
+                          'fit_wright_effect',
+                          ]
+
+         if method not in valid_methods:
+             raise ValueError("method must be one of %s" % str(valid_methods))
+
+         # First call the method on the original dataframe
+         # to make the adjustment set etc. available
+         getattr(self, method)(**method_args)
+
+         self.original_model = deepcopy(self.model)
+
+         if self.verbosity > 0:
+             print("\n##\n## Running Bootstrap of %s " % method +
+                   "\n##\n" +
+                   "\nboot_samples = %s \n" % boot_samples +
+                   "\nboot_blocklength = %s \n" % boot_blocklength
+                   )
+
+         method_args_bootstrap = deepcopy(method_args)
+         self.bootstrap_results = {}
+
+         for b in range(boot_samples):
+             # Replace the dataframe in method_args by a bootstrapped dataframe
+             if seed is None:
+                 random_state = np.random.default_rng(None)
+             else:
+                 random_state = np.random.default_rng(seed * boot_samples + b)
+
+             method_args_bootstrap['dataframe'].bootstrap = {'boot_blocklength': boot_blocklength,
+                                                             'random_state': random_state}
+
+             # Call method and save fitted model
+             getattr(self, method)(**method_args_bootstrap)
+             self.bootstrap_results[b] = deepcopy(self.model)
+
+         # Reset model
+         self.model = self.original_model
+
+         return self
+
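+     # Illustrative sketch (added note), mirroring the commented example in
+     # __main__ below:
+     #
+     #     causal_effects.fit_bootstrap_of(
+     #         method='fit_total_effect',
+     #         method_args={'dataframe': dataframe,
+     #                      'estimator': LinearRegression()},
+     #         boot_samples=100, boot_blocklength=1, seed=42)
+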
+
+     def predict_bootstrap_of(self, method, method_args,
+                              conf_lev=0.9,
+                              return_individual_bootstrap_results=False):
+         """Predicts with fitted bootstraps.
+
+         To be used after fitting with fit_bootstrap_of. Only uses the
+         expected values of the predict function, not potential other output.
+
+         Parameters
+         ----------
+         method : str
+             Chosen method among valid functions in this class.
+         method_args : dict
+             Arguments passed to method.
+         conf_lev : float, optional (default: 0.9)
+             Two-sided confidence interval.
+         return_individual_bootstrap_results : bool
+             Whether to also return the individual bootstrap predictions.
+
+         Returns
+         -------
+         confidence_intervals : numpy array
+         """
+
+         valid_methods = ['predict_total_effect',
+                          'predict_wright_effect',
+                          ]
+
+         if method not in valid_methods:
+             raise ValueError("method must be one of %s" % str(valid_methods))
+
+         lenX = len(self.listX)
+         lenS = len(self.listS)
+         lenY = len(self.listY)
+
+         n_interventions, _ = method_args['intervention_data'].shape
+
+         boot_samples = len(self.bootstrap_results)
+
+         for b in range(boot_samples):
+             self.model = self.bootstrap_results[b]
+             boot_effect = getattr(self, method)(**method_args)
+
+             if isinstance(boot_effect, tuple):
+                 boot_effect = boot_effect[0]
+
+             if b == 0:
+                 bootstrap_predicted_array = np.zeros((boot_samples, ) + boot_effect.shape,
+                                                      dtype=boot_effect.dtype)
+             bootstrap_predicted_array[b] = boot_effect
+
+         # Reset model
+         self.model = self.original_model
+
+         # Confidence intervals for the predictions; interval is two-sided
+         c_int = (1. - (1. - conf_lev) / 2.)
+         confidence_interval = np.percentile(
+             bootstrap_predicted_array, axis=0,
+             q=[100 * (1. - c_int), 100 * c_int])
+
+         if return_individual_bootstrap_results:
+             return bootstrap_predicted_array, confidence_interval
+
+         return confidence_interval
+
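+     # Illustrative sketch (added note): two-sided 90% bootstrap confidence
+     # intervals for predicted intervention effects:
+     #
+     #     conf = causal_effects.predict_bootstrap_of(
+     #         method='predict_total_effect',
+     #         method_args={'intervention_data': intervention_data},
+     #         conf_lev=0.9)
+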
+
+ if __name__ == '__main__':
+
+     # Consider some toy data
+     import tigramite
+     import tigramite.toymodels.structural_causal_processes as toys
+     import tigramite.data_processing as pp
+     import tigramite.plotting as tp
+     from matplotlib import pyplot as plt
+     import sys
+
+     import sklearn
+     from sklearn.linear_model import LinearRegression, LogisticRegression
+     from sklearn.preprocessing import StandardScaler
+     from sklearn.neural_network import MLPRegressor
+
+     # def lin_f(x): return x
+     # coeff = .5
+
+     # links_coeffs = {0: [((0, -1), 0.5, lin_f)],
+     #                 1: [((1, -1), 0.5, lin_f), ((0, -1), 0.5, lin_f)],
+     #                 2: [((2, -1), 0.5, lin_f), ((1, 0), 0.5, lin_f)]
+     #                 }
+     # T = 1000
+     # data, nonstat = toys.structural_causal_process(
+     #     links_coeffs, T=T, noises=None, seed=7)
+     # dataframe = pp.DataFrame(data)
+
+     # graph = CausalEffects.get_graph_from_dict(links_coeffs)
+
+     # original_graph = np.array([[['', ''],
+     #                             ['-->', ''],
+     #                             ['-->', ''],
+     #                             ['', '']],
+
+     #                            [['<--', ''],
+     #                             ['', '-->'],
+     #                             ['-->', ''],
+     #                             ['-->', '']],
+
+     #                            [['<--', ''],
+     #                             ['<--', ''],
+     #                             ['', '-->'],
+     #                             ['-->', '']],
+
+     #                            [['', ''],
+     #                             ['<--', ''],
+     #                             ['<--', ''],
+     #                             ['', '-->']]], dtype='<U3')
+     # graph = np.copy(original_graph)
+
+     # # Add T <-> Reco and T
+     # graph[2, 3, 0] = '+->'; graph[3, 2, 0] = '<-+'
+     # graph[1, 3, 1] = '<->'  # ; graph[2, 1, 0] = '<--'
+
+     # added = np.zeros((4, 4, 1), dtype='<U3')
+     # added[:] = ""
+     # graph = np.append(graph, added, axis=2)
+
+     # X = [(1, 0)]
+     # Y = [(3, 0)]
+
+     # # Initialize class as `stationary_admg`
+     # causal_effects = CausalEffects(graph, graph_type='stationary_admg',
+     #                                X=X, Y=Y, S=None,
+     #                                hidden_variables=None,
+     #                                verbosity=0)
+
+     # print(causal_effects.get_optimal_set())
+
+     # tp.plot_time_series_graph(
+     #     graph=graph,
+     #     save_name='Example_graph_in.pdf',
+     #     # special_nodes=special_nodes,
+     #     # var_names=var_names,
+     #     figsize=(6, 4),
+     # )
+
+     # tp.plot_time_series_graph(
+     #     graph=causal_effects.graph,
+     #     save_name='Example_graph_out.pdf',
+     #     # special_nodes=special_nodes,
+     #     # var_names=var_names,
+     #     figsize=(6, 4),
+     # )
+
+     # causal_effects.fit_wright_effect(dataframe=dataframe,
+     #                                  # links_coeffs=links_coeffs,
+     #                                  # mediation=[(1, 0), (1, -1), (1, -2)]
+     #                                  )
+
+     # intervention_data = 1. * np.ones((1, 1))
+     # y1 = causal_effects.predict_wright_effect(
+     #     intervention_data=intervention_data,
+     # )
+
+     # intervention_data = 0. * np.ones((1, 1))
+     # y2 = causal_effects.predict_wright_effect(
+     #     intervention_data=intervention_data,
+     # )
+
+     # beta = (y1 - y2)
+     # print("Causal effect is %.5f" % (beta))
+
+     # tp.plot_time_series_graph(
+     #     graph=causal_effects.graph,
+     #     save_name='Example_graph.pdf',
+     #     # special_nodes=special_nodes,
+     #     var_names=var_names,
+     #     figsize=(8, 4),
+     # )
+
+     T = 10000
+
+     def lin_f(x): return x
+
+     auto_coeff = 0.
+     coeff = 2.
+
+     links = {
+         0: [((0, -1), auto_coeff, lin_f)],
+         1: [((1, -1), auto_coeff, lin_f)],
+         2: [((2, -1), auto_coeff, lin_f), ((0, 0), coeff, lin_f)],
+         3: [((3, -1), auto_coeff, lin_f)],
+     }
+     data, nonstat = toys.structural_causal_process(links, T=T,
+                                                    noises=None, seed=7)
+
+     # # Create some missing values
+     # data[-10:, :] = 999.
+     # var_names = range(2)
+
+     dataframe = pp.DataFrame(data,
+                              vector_vars={0: [(0, 0), (1, 0)],
+                                           1: [(2, 0), (3, 0)]}
+                              )
+
+     # Construct expert knowledge graph from links here
+     aux_links = {0: [(0, -1)],
+                  1: [(1, -1), (0, 0)],
+                  }
+     # Use staticmethod to get graph
+     graph = CausalEffects.get_graph_from_dict(aux_links, tau_max=2)
+     # graph = np.array([['', '-->'],
+     #                   ['<--', '']], dtype='<U3')
+
+     # We are interested in the lagged total effect of X on Y
+     X = [(0, 0), (0, -1)]
+     Y = [(1, 0), (1, -1)]
+
+     # Initialize class as `stationary_dag`
+     causal_effects = CausalEffects(graph, graph_type='stationary_dag',
+                                    X=X, Y=Y, S=None,
+                                    hidden_variables=None,
+                                    verbosity=1)
+
+     # Optimal adjustment set (is used by default)
+     # print(causal_effects.get_optimal_set())
+
+     # Fit causal effect model from observational data
+     causal_effects.fit_total_effect(
+         dataframe=dataframe,
+         # mask_type='y',
+         estimator=LinearRegression(),
+     )
+
+     # Fit bootstrap of causal effect model from observational data
+     # causal_effects.fit_bootstrap_of(
+     #     method='fit_total_effect',
+     #     method_args={'dataframe': dataframe,
+     #                  # 'mask_type': 'y',
+     #                  'estimator': LinearRegression()
+     #                  },
+     #     boot_samples=3,
+     #     boot_blocklength=1,
+     #     seed=5
+     # )
+
+     # Predict effect of interventions do(X=0.), ..., do(X=1.) in one go
+     lenX = 4  # len(dataframe.vector_vars[X[0][0]])
+     dox_vals = np.linspace(0., 1., 3)
+     intervention_data = np.tile(dox_vals.reshape(len(dox_vals), 1), lenX)
+
+     intervention_data = np.array([[1., 0., 0., 0.]])
+
+     print(intervention_data)
+
+     pred_Y = causal_effects.predict_total_effect(
+         intervention_data=intervention_data)
+     print(pred_Y, pred_Y.shape)
+
+     # Bootstrap confidence intervals for the predicted effects
+     # intervention_data = np.tile(dox_vals.reshape(len(dox_vals), 1), len(X))
+     # conf = causal_effects.predict_bootstrap_of(
+     #     method='predict_total_effect',
+     #     method_args={'intervention_data': intervention_data})
+     # print(conf, conf.shape)
+
+     # Fit Wright effect model from observational data
+     # causal_effects.fit_wright_effect(
+     #     dataframe=dataframe,
+     #     # mask_type='y',
+     #     # data_transform=StandardScaler(),
+     # )
+
+     # Predict effect of interventions do(X=0.), ..., do(X=1.) in one go
+     # dox_vals = np.linspace(0., 1., 5)
+     # intervention_data = dox_vals.reshape(len(dox_vals), len(X))
+     # pred_Y = causal_effects.predict_wright_effect(
+     #     intervention_data=intervention_data)
+     # print(pred_Y)