tigramite-fast 5.2.10.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (38)
  1. tigramite/__init__.py +0 -0
  2. tigramite/causal_effects.py +1525 -0
  3. tigramite/causal_mediation.py +1592 -0
  4. tigramite/data_processing.py +1574 -0
  5. tigramite/graphs.py +1509 -0
  6. tigramite/independence_tests/LBFGS.py +1114 -0
  7. tigramite/independence_tests/__init__.py +0 -0
  8. tigramite/independence_tests/cmiknn.py +661 -0
  9. tigramite/independence_tests/cmiknn_mixed.py +1397 -0
  10. tigramite/independence_tests/cmisymb.py +286 -0
  11. tigramite/independence_tests/gpdc.py +664 -0
  12. tigramite/independence_tests/gpdc_torch.py +820 -0
  13. tigramite/independence_tests/gsquared.py +190 -0
  14. tigramite/independence_tests/independence_tests_base.py +1310 -0
  15. tigramite/independence_tests/oracle_conditional_independence.py +1582 -0
  16. tigramite/independence_tests/pairwise_CI.py +383 -0
  17. tigramite/independence_tests/parcorr.py +369 -0
  18. tigramite/independence_tests/parcorr_mult.py +485 -0
  19. tigramite/independence_tests/parcorr_wls.py +451 -0
  20. tigramite/independence_tests/regressionCI.py +403 -0
  21. tigramite/independence_tests/robust_parcorr.py +403 -0
  22. tigramite/jpcmciplus.py +966 -0
  23. tigramite/lpcmci.py +3649 -0
  24. tigramite/models.py +2257 -0
  25. tigramite/pcmci.py +3935 -0
  26. tigramite/pcmci_base.py +1218 -0
  27. tigramite/plotting.py +4735 -0
  28. tigramite/rpcmci.py +467 -0
  29. tigramite/toymodels/__init__.py +0 -0
  30. tigramite/toymodels/context_model.py +261 -0
  31. tigramite/toymodels/non_additive.py +1231 -0
  32. tigramite/toymodels/structural_causal_processes.py +1201 -0
  33. tigramite/toymodels/surrogate_generator.py +319 -0
  34. tigramite_fast-5.2.10.1.dist-info/METADATA +182 -0
  35. tigramite_fast-5.2.10.1.dist-info/RECORD +38 -0
  36. tigramite_fast-5.2.10.1.dist-info/WHEEL +5 -0
  37. tigramite_fast-5.2.10.1.dist-info/licenses/license.txt +621 -0
  38. tigramite_fast-5.2.10.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1231 @@
1
+ """Tigramite causal discovery for time series."""
2
+
3
+ # Authors: Martin Rabel, Jakob Runge <jakob@jakob-runge.com>
4
+ #
5
+ # License: GNU General Public License v3.0
6
+
7
+
8
+ import numpy as np
9
+ from tigramite.data_processing import DataFrame
10
+ from tigramite.causal_effects import CausalEffects
11
+
12
+
13
+ class VariableDescription:
14
+ r"""Variable descrption (base-class)
15
+
16
+ Used for mixed variable fitting and mediation, see Mediation-tutorial.
17
+
18
+ Parameters
19
+ ----------
20
+ name : string
21
+ The display-name of the variable
22
+ observable : bool
23
+ Is the variable observable or hidden? (used by toy-models)
24
+ exogenous : bool
25
+ Is the variable assumed exogenous (used on noise-terms internally, keep default)
26
+ """
27
+
28
+ def __init__(self, name="(unnamed)", observable=True, exogenous=False):
29
+ """Use CategoricalVariable or ContinuousVariable instead."""
30
+ self.noise_term = None
31
+ self.observable = observable
32
+ self.exogenous = exogenous
33
+ self.name = name
34
+
35
+ def Noise(self):
36
+ """
37
+ Get the variable-description for the noise-term on this variable.
38
+ Used e.g. for describing SEMs in noise-models
39
+
40
+ Returns
41
+ -------
42
+ noise : VariableDescription
43
+ The variable-description for the noise-term on this variable.
44
+ """
45
+ if self.noise_term is None:
46
+ self.noise_term = ExogenousNoiseVariable(self)
47
+ return self.noise_term
48
+
49
+ def Lag(self, offset):
50
+ """
51
+ Get the variable-description for the past of this variable.
52
+ Used e.g. for describing SEMs in noise-models
53
+
54
+ Parameters
55
+ ----------
56
+ offset : uint
57
+ The lag to introduce (only non-negative values, always in the past).
58
+
59
+ Returns
60
+ -------
61
+ lagged-self : VariableDescription
62
+ This variable, at lag "offset"
63
+ """
64
+ if offset == 0:
65
+ return self
66
+ else:
67
+ return LaggedVariable(self, offset)
68
+
69
+ def Info(self, detail=0):
70
+ """
71
+ Get an Info-string for this variable.
72
+ (Overridden by derived classes for discrete, continuous, lagged, ...)
73
+
74
+ Parameters
75
+ ----------
76
+ detail : uint
77
+ The amount of detail to include.
78
+
79
+ Returns
80
+ -------
81
+ info : string
82
+ A description of this variable.
83
+ """
84
+ return self.name
85
+
86
+ def PrintInfo(self):
87
+ """Print Info-string (see Info)"""
88
+ print(self.Info())
89
+
90
+ def Id(self):
91
+ """Get the 'vanilla' (unlagged etc) description of this variable
92
+ (Overridden by derived classes for discrete, continuous, lagged, ...)
93
+
94
+ Returns
95
+ -------
96
+ id : VariableDescription
97
+ A description of the underlying variable of this description.
98
+ """
99
+ return self
100
+
101
+ def LagValue(self):
102
+ """Get the time-lag (non-negative, positive values are in the past) of the
103
+ variable described by this indicator.
104
+ (Overridden by derived classes for discrete, continuous, lagged, ...)
105
+
106
+ Returns
107
+ -------
108
+ lag : uint
109
+ The time-lag (non-negative, positive values are in the past) of this variable.
110
+ """
111
+ return 0
112
+
113
+
114
+ class LaggedVariable(VariableDescription):
115
+ """Describes a lagged variable (see VariableDescription)"""
116
+
117
+ def __init__(self, var, lag):
118
+ """Use var-description.Lag(offset) instead."""
119
+ self.var = var
120
+ self.lag = lag
121
+ self.exogenous = False
122
+
123
+ def Id(self):
124
+ return self.var
125
+
126
+ def LagValue(self):
127
+ return self.lag
128
+
129
+ def Info(self, detail=0):
130
+ return self.var.Info(detail) + " at lag " + str(self.lag)
131
+
132
+ def PrintInfo(self):
133
+ print(self.Info())
134
+
135
+ def DefaultValue(self):
136
+ return self.var.DefaultValue()
137
+
138
+
139
+ class ExogenousNoiseVariable(VariableDescription):
140
+ """Describes an exogenous noise variable (see VariableDescription)"""
141
+
142
+ def __init__(self, attached_to):
143
+ """Use var-description.Noise() instead."""
144
+ super().__init__(name="Noise of " + attached_to.name, observable=False, exogenous=True)
145
+ self.associated_system_variable = attached_to
146
+
147
+ def DefaultValue(self):
148
+ return self.associated_system_variable.DefaultValue()
149
+
150
+
151
+ class CategoricalVariable(VariableDescription):
152
+ """Describes a categorical variable (see VariableDescription)
153
+
154
+ Parameters
155
+ ----------
156
+ name : string
157
+ The display-name of the variable.
158
+ categories : uint or *iterable*
159
+ The number of categories in which the variable may take values, or
160
+ an iterable containing the possible values.
161
+ dtype : string
162
+ Name of the numpy dtype ('uint32', 'bool', ...) or 'auto'
163
+ """
164
+
165
+ def __init__(self, name="(unnamed)", categories=2, dtype="auto"):
166
+ super().__init__(name=name)
167
+ self.is_categorical = True
168
+ if hasattr(categories, '__iter__'):
169
+ self.categories = len(categories)
170
+ self.category_values = categories
171
+ else:
172
+ self.categories = categories
173
+ self.category_values = None
174
+ if dtype == "auto":
175
+ self.dtype = "uint32" if self.categories > 2 else "bool"
176
+ else:
177
+ self.dtype = dtype
178
+
179
+ def Info(self, detail=0):
180
+ if detail > 0:
181
+ return self.name
182
+ else:
183
+ return self.name + f"( Categorical variable with {self.categories} categories )"
184
+
185
+ def Empty(self, N):
186
+ """Build an empty np.array with N samples."""
187
+ return np.zeros(N, self.dtype)
188
+
189
+ def ValidValue(self, x):
190
+ """Validate a value"""
191
+ is_integer = x % 1 == 0.0 # this works for all of: np.[u]int; built-in [u]int; np.bool; built-in boolean
192
+ return is_integer and 0 <= x < self.categories
193
+
194
+ def DefaultValue(self):
195
+ """Return a 'default-value' to use in toy-models at the start of a time-series
196
+ (if otherwise undefined by the SEM)"""
197
+ # used for lagged values not available in time-series
198
+ return 0 # could be user specified & per variable ...
199
+
200
+ def CopyInfo(self, name="(unnamed)"):
201
+ """Create a new variable-id with the same parameters"""
202
+ return CategoricalVariable(name, self.categories, self.dtype)
203
+
204
+
205
+ class ContinuousVariable(VariableDescription):
206
+ """Describes a continuous variable (see VariableDescription)
207
+
208
+ Parameters
209
+ ----------
210
+ name : string
211
+ The display-name of the variable.
212
+ dtype : string
213
+ Name of the numpy dtype ('float32', 'float64', ...)
214
+ dimension : uint
215
+ The number of dimensions, can be >1 for fitting.
216
+ """
217
+
218
+ def __init__(self, name="(unnamed)", dtype="float32", dimension=1):
219
+ super().__init__(name=name)
220
+ self.dtype = dtype
221
+ self.is_categorical = False
222
+ self.dimension = dimension # fitting probabilities for multiple categories might be higher dim (handled by fit)
223
+
224
+ def Info(self, detail=0):
225
+ if detail > 0:
226
+ return self.name
227
+ else:
228
+ return self.name + f"( Continuous variable of dimension {self.dimension} )"
229
+
230
+ def Empty(self, N):
231
+ """Build an empty np.array with N samples."""
232
+ if self.dimension == 1:
233
+ return np.zeros(N, dtype=self.dtype)
234
+ else:
235
+ return np.zeros([self.dimension, N], dtype=self.dtype)
236
+
237
+ def ValidValue(self, x):
238
+ return True
239
+
240
+ def DefaultValue(self):
241
+ """Return a 'default-value' to use in toy-models at the start of a time-series
242
+ (if otherwise undefined by the SEM)"""
243
+ # used for lagged values not available in time-series
244
+ return 0.0 # could be user specified & per variable ...
245
+
246
+ def CopyInfo(self, name="(unnamed)"):
247
+ """Create a new variable-id with the same parameters"""
248
+ return ContinuousVariable(name, self.dtype, self.dimension)
249
+
250
+
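# A minimal usage sketch for the variable descriptions above, assuming this
# module is importable as tigramite.toymodels.non_additive:
from tigramite.toymodels.non_additive import CategoricalVariable, ContinuousVariable

X = ContinuousVariable(name="X")                       # 1-dimensional, float32
S = CategoricalVariable(name="Season", categories=4)   # 4 categories, stored as uint32
X_past = X.Lag(2)     # X two steps in the past (a LaggedVariable)
N_X = X.Noise()       # the exogenous noise term attached to X
X_past.PrintInfo()    # "X( Continuous variable of dimension 1 ) at lag 2"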
251
+ class Environment:
252
+ """An (exogenous) environment to use with a model
253
+
254
+ Counterfactual models may create different worlds with different models from the same environment.
255
+
256
+ Parameters
257
+ ----------
258
+ exogenous_noise : dictionary< VariableDescription, *callable* (np.random.generator, sample-count) >
259
+ For each variable, a callable generating noise-values.
260
+ N : uint
261
+ The initial number of samples to generate
262
+ seed : uint or None
263
+ Seed to use for the random-generator.
264
+ """
265
+
266
+ def __init__(self, exogenous_noise, N=1000, seed=None):
267
+ self.rng = np.random.default_rng(seed)
268
+ self.can_reset = seed is None
269
+ self.exogenous_noise = exogenous_noise
270
+ self.noise = {}
271
+ self.N = N
272
+ self._ForceReset()
273
+
274
+ def ResetWithNewSeed(self, new_seed, N=None):
275
+ """Reset with a new random-seed."""
276
+ self.rng = np.random.default_rng(new_seed)
277
+ self._ForceReset(N)
278
+
279
+ def Reset(self, N=None):
280
+ """Reset with a random new random-seed."""
281
+ if not self.can_reset:
282
+ raise Exception("Cannot reset a fixed-seed environment, use seed=None or factory/lambda instead for"
283
+ "ensemble-creation.")
284
+ self._ForceReset(N)
285
+ return self
286
+
287
+ def _ForceReset(self, N=None):
288
+ """[internal] Enfore a reset."""
289
+ if N is not None:
290
+ self.N = N
291
+ for var, _noise in self.exogenous_noise.items():
292
+ self.noise[var.Noise()] = _noise(self.rng, self.N)
293
+
294
+ def GetNoise(self):
295
+ """Get a copy of the noise-values.
296
+
297
+ Returns
298
+ -------
299
+ noises : dictionary< Variable-Description, samples >
300
+ Returns a shallow copy of the samples generated.
301
+ Ids are 'exogenous_noise.Noise()' (see constructor).
302
+ """
303
+ return self.noise.copy() # return a SHALLOW copy of the noise-data (copy the dict, not the data)
304
+
305
+
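# A minimal sketch of building an Environment, assuming this module is
# importable as tigramite.toymodels.non_additive. Each dictionary value is a
# callable taking the numpy random generator and the sample count:
from tigramite.toymodels.non_additive import ContinuousVariable, Environment

X = ContinuousVariable(name="X")
Y = ContinuousVariable(name="Y")
env = Environment(exogenous_noise={X: lambda rng, n: rng.normal(size=n),
                                   Y: lambda rng, n: 0.5 * rng.normal(size=n)},
                  N=500, seed=42)
noise = env.GetNoise()   # shallow copy, keyed by X.Noise() and Y.Noise()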
306
+ class DataPointForValidation:
307
+ """[INTERNAL] Used to validate causal ordering on SEM for toy-model"""
308
+
309
+ def __init__(self):
310
+ self.known = {}
311
+ self.is_timeseries = False
312
+ self.max_lag = 0
313
+
314
+ def Set(self, var, value):
315
+ self.known[var] = True
316
+
317
+ def __getitem__(self, key):
318
+ if key.__class__ == LaggedVariable:
319
+ self.is_timeseries = True
320
+ self.max_lag = max(self.max_lag, key.lag)
321
+ return key.DefaultValue()
322
+
323
+ if not key.exogenous and key not in self.known:
324
+ raise Exception("SEM must be causally ordered, such that contemporaneous parents are listed "
325
+ "before their children.")
326
+ return key.DefaultValue()
327
+
328
+
329
+ class DataPointForParents:
330
+ """[INTERNAL] Used to extract causal parents (for ground-truth graph) from SEM for toy-model"""
331
+
332
+ def __init__(self):
333
+ self.parents = []
334
+
335
+ def __getitem__(self, key):
336
+ if not key.exogenous:
337
+ self.parents.append(key)
338
+ return key.DefaultValue()
339
+
340
+
341
+ class DataPointView:
342
+ """[INTERNAL] Used to run SEMs on toy-models."""
343
+
344
+ def __init__(self, data, known=None):
345
+ self.data = data
346
+ self.index = 0
347
+
348
+ def Next(self):
349
+ self.index += 1
350
+
351
+ def Set(self, var, value):
352
+ self.data[var][self.index] = value
353
+
354
+ def __getitem__(self, key):
355
+ if key.__class__ == LaggedVariable:
356
+ if self.index < key.lag:
357
+ return key.DefaultValue()
358
+ else:
359
+ return self.data[key.var][self.index - key.lag]
360
+ else:
361
+ return self.data[key][self.index]
362
+
363
+
364
+ class Model:
365
+ """Describes a Toy-model via an SEM.
366
+
367
+ Validates only the causal ordering; for more validation, see tigramite.toymodels.
368
+ However, this implementation can also generate non-additive models.
369
+
370
+ Parameters
371
+ ----------
372
+ sem : dictionary< VariableDescription, *callable* ( sample-view : var-desc -> np.array or scalar ) >
373
+ For each variable, in causal order, a *callable* that is passed a
374
+ view v of the data-samples; access data via v[var-description]
375
+ and return the value(s) of the described (key) variable.
376
+ For non-time-series, v[var-description] is an np.array,
377
+ for time-series, v[var-description] is a scalar.
378
+ """
379
+
380
+ def __init__(self, sem):
381
+ self.SEM = sem
382
+ self.is_timeseries = False
383
+ self.max_lag = 0
384
+ self.Validate()
385
+
386
+ def GetGroundtruthLinks(self):
387
+ """Get the causal links to parents (ground-truth).
388
+
389
+ Does not validate faithfulness.
390
+
391
+ Returns
392
+ -------
393
+ links, indices : dictionaries< VariableDescription, ...>
394
+ For links: Values are the lists of parents.
395
+ For indices: Values are variable indices (see GetGroundtruthLinksRaw)
396
+ """
397
+ self.Validate()
398
+ links = {}
399
+ indices = {}
400
+ idx = 0
401
+ for var, eq in self.SEM.items():
402
+ data_pt = DataPointForParents()
403
+ self.SEM[var](data_pt)
404
+ links[var] = data_pt.parents
405
+ indices[var] = idx
406
+ idx += 1
407
+ return links, indices
408
+
409
+ def GetGroundtruthLinksRaw(self):
410
+ """Get unformatted causal links to parents (ground-truth).
411
+
412
+ Does not validate faithfulness.
413
+
414
+ Returns
415
+ -------
416
+ links : dictionary, tigramite raw format: variable index -> list of (parent index, -lag) tuples
417
+ """
418
+ links, indices = self.GetGroundtruthLinks()
419
+ links_raw = {}
420
+ for var in self.SEM.keys():
421
+ parents_raw = []
422
+ for p in links[var]:
423
+ if p.__class__ == LaggedVariable:
424
+ parents_raw.append((indices[p.var], -p.lag))
425
+ else:
426
+ parents_raw.append((indices[p], 0))
427
+ links_raw[indices[var]] = parents_raw
428
+ return links_raw
429
+
430
+ def GetGroundtruthGraph(self):
431
+ """Get the (ground-truth) graph.
432
+
433
+ Returns
434
+ -------
435
+ graph, graph_type : graph array (see e.g. the CausalEffects tutorial) and a string, 'dag' or 'stationary_dag'
436
+ """
437
+ graph = CausalEffects.get_graph_from_dict(self.GetGroundtruthLinksRaw())
438
+ if self.is_timeseries:
439
+ return graph, 'stationary_dag'
440
+ else:
441
+ return graph, 'dag'
442
+
443
+ def Validate(self):
444
+ """Validate causal ordering (necessary for consistent data-generation), called automatically."""
445
+ data_pt = DataPointForValidation()
446
+ for var, eq in self.SEM.items():
447
+ data_pt.Set(var, self.SEM[var](data_pt))
448
+ self.max_lag = max(self.max_lag, data_pt.max_lag)
449
+ self.is_timeseries = data_pt.is_timeseries
450
+
451
+ def ApplyWithExogenousNoise(self, environment, partial_data=None):
452
+ """Apply to environment
453
+
454
+ Parameters
455
+ ----------
456
+ environment : Environment
457
+ Exogenous noise-samples given as environment.
458
+ partial_data : None
459
+ [INTERNAL USE] Leave to default.
460
+ """
461
+ if self.is_timeseries:
462
+ # This may be very slow (as it cannot be parallelized or be dispatched efficiently to native code via numpy)
463
+ return self._GenerateAsTimeseries(environment, partial_data)
464
+
465
+ data = environment.GetNoise()
466
+ vars = []
467
+ for var, eq in self.SEM.items():
468
+ if partial_data is not None and var in partial_data:
469
+ data[var] = partial_data[var]
470
+ else:
471
+ vars.append(var)
472
+
473
+ for var in vars:
474
+ data[var] = self.SEM[var](data)
475
+ return data
476
+
477
+ def _GenerateAsTimeseries(self, environment, partial_data=None):
478
+ """[INTERNAL] Generate time-series data.
479
+
480
+ Generation of time-series data cannot be parallelized or efficiently dispatched to native code.
481
+ """
482
+ data = environment.GetNoise()
483
+ vars = []
484
+ for var, eq in self.SEM.items():
485
+ if partial_data is not None and var in partial_data:
486
+ data[var] = partial_data[var]
487
+ else:
488
+ data[var] = var.Empty(environment.N)
489
+ vars.append(var)
490
+
491
+ data_pt = DataPointView(data, partial_data)
492
+ for x in range(environment.N):
493
+ for var in vars:
494
+ data_pt.Set(var, self.SEM[var](data_pt))
495
+ data_pt.Next()
496
+ return data
497
+
498
+ def Intervene(self, changes):
499
+ """Get and intervened model.
500
+
501
+ Parameters
502
+ ----------
503
+ changes : dictionary< VariableDescription, ...>
504
+ For each variable to intervene on, either a scalar (hard intervention),
505
+ or a *callable* replacing the equation in the SEM (see constructor).
506
+
507
+ Returns
508
+ -------
509
+ intervened model : Model
510
+ The intervened model.
511
+ """
512
+ # Return a model, that describes the intervened system
513
+
514
+ new_sem = self.SEM.copy()
515
+ for var, eq in changes.items():
516
+ if callable(eq):
517
+ new_sem[var] = eq # intervene by function
518
+ else:
519
+ new_sem[var] = lambda data, eq=eq: eq # return a constant (bind the current eq, not the loop variable)
520
+ return Model(new_sem)
521
+
522
+
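# A minimal sketch of defining an SEM, reading off the ground-truth graph, and
# intervening, assuming this module is importable as tigramite.toymodels.non_additive:
from tigramite.toymodels.non_additive import ContinuousVariable, Model

X = ContinuousVariable(name="X")
Y = ContinuousVariable(name="Y")
# Causal order: X before Y; noise terms are accessed via v[<var>.Noise()].
model = Model({X: lambda v: v[X.Noise()],
               Y: lambda v: 0.8 * v[X] + v[Y.Noise()]})
graph, graph_type = model.GetGroundtruthGraph()   # graph_type == 'dag' here
model_do_x = model.Intervene(changes={X: 1.0})    # hard intervention do(X = 1)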
523
+ class World:
524
+ """A 'world' instantiation.
525
+
526
+ Generate observations from an environment (exogenous noise) and a model (SEM).
527
+
528
+ Parameters
529
+ ----------
530
+ environment : Environment
531
+ The environment with exogenous noise samples.
532
+ model : Model
533
+ The SEM describing the system.
534
+ """
535
+
536
+ def __init__(self, environment, model):
537
+ self.environment = environment # to check agreement in counterfactual worlds
538
+ if model is not None:
539
+ self.data = model.ApplyWithExogenousNoise(environment)
540
+ else:
541
+ self.data = {}
542
+
543
+ def Observables(self):
544
+ """Get all observables
545
+
546
+ Returns
547
+ -------
548
+ observables : dictionary< VariableDescription, np.array( environment.N ) >
549
+ The samples for each observable variable.
550
+ """
551
+ obs = {}
552
+ for var, values in self.data.items():
553
+ if var.observable:
554
+ obs[var] = values
555
+ return obs
556
+
557
+
558
+ class CounterfactualWorld(World):
559
+ """A 'counterfactual' world.
560
+
561
+ Generate observations from one environment (exogenous noise) and multiple models (SEM).
562
+
563
+ Parameters
564
+ ----------
565
+ environment : Environment
566
+ The environment with exogenous noise samples.
567
+ model : Model
568
+ The base-model to generate data from in the end.
569
+ Call TakeVariablesFromWorld to overwrite some variables with values from another 'world',
570
+ then call Compute.
571
+ """
572
+
573
+ def __init__(self, environment, model):
574
+ super().__init__(environment, None)
575
+ self.model = model
576
+
577
+ def TakeVariablesFromWorld(self, world, vars):
578
+ """Take variables from a world.
579
+
580
+ Parameters
581
+ ----------
582
+ world : World
583
+ Take samples from here.
584
+ vars : VariableDescription or *iterable* <VariableDescription>
585
+ One variable (or an iterable of variables) to overwrite.
586
+ """
587
+ if world.environment != self.environment:
588
+ raise Exception("Counterfactual Worlds must share exogenous noise terms.")
589
+
590
+ if hasattr(vars, '__iter__'):
591
+ for var in vars:
592
+ self.TakeVariablesFromWorld(world, var)
593
+ else:
594
+ self.data[vars] = world.data[vars]
595
+
596
+ def Compute(self):
597
+ """Compute the counterfactual world observations."""
598
+ self.data = self.model.ApplyWithExogenousNoise(self.environment, partial_data=self.data)
599
+
600
+
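# A minimal counterfactual-world sketch (the same pattern GroundTruth_NDE uses
# below), assuming this module is importable as tigramite.toymodels.non_additive:
import numpy as np
from tigramite.toymodels.non_additive import (ContinuousVariable, Environment, Model,
                                              World, CounterfactualWorld)

X = ContinuousVariable(name="X")
M = ContinuousVariable(name="M")
Y = ContinuousVariable(name="Y")
model = Model({X: lambda v: v[X.Noise()],
               M: lambda v: v[X] + v[M.Noise()],
               Y: lambda v: v[X] + 2.0 * v[M] + v[Y.Noise()]})
env = Environment({X: lambda rng, n: rng.normal(size=n),
                   M: lambda rng, n: rng.normal(size=n),
                   Y: lambda rng, n: rng.normal(size=n)}, N=2000)

world_a = World(env, model.Intervene(changes={X: 0.0}))   # reference world do(X=0)
world_b = World(env, model.Intervene(changes={X: 1.0}))   # intervened world do(X=1)
cf = CounterfactualWorld(env, model)
cf.TakeVariablesFromWorld(world_a, [M])   # mediator behaves as if X had stayed at 0
cf.TakeVariablesFromWorld(world_b, [X])   # but X itself is set to 1
cf.Compute()
nde = np.mean(cf.Observables()[Y] - world_a.Observables()[Y])   # exactly 1.0 here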
601
+ class DataPointView_TS_Window:
602
+ """[INTERNAL] Injected into SEM callbacks to compute values for counterfactual windows in parallel."""
603
+
604
+ def __init__(self, env, stationary_data, window_data, offset_in_window, max_lag):
605
+ self.stationary_data = stationary_data
606
+ self.environment = env
607
+ self.window_data = window_data
608
+ self.offset_in_window = offset_in_window
609
+ self.max_lag = max_lag
610
+
611
+ def __getitem__(self, key):
612
+ if key.exogenous:
613
+ return self.environment.noise[key][self.max_lag:]
614
+ else:
615
+ if self.offset_in_window < key.LagValue():
616
+ return self.stationary_data[key.Id()][self.max_lag - key.LagValue():-key.LagValue()]
617
+ else:
618
+ return self.window_data[self.offset_in_window - key.LagValue()][key.Id()]
619
+
620
+
621
+ class CounterfactualTimeseries:
622
+ """[INTERNAL] Used to generate ground-truth for time-series NDE.
623
+
624
+ Computes the effect of an intervention at a SINGLE point in time.
625
+ Use GroundTruth_* functions instead.
626
+ """
627
+
628
+ def __init__(self, environment, base_model, max_lag_in_interventions):
629
+ self.environment = environment
630
+ self.base_model = base_model
631
+ self.max_lag = max_lag_in_interventions
632
+
633
+ self.stationary_data = self.base_model.ApplyWithExogenousNoise(self.environment)
634
+
635
+ def ComputeIntervention(self, output_var, interventions, take):
636
+ output_windows = []
637
+ for var, value in interventions:
638
+ intervened_window = []
639
+ output_windows.append(intervened_window)
640
+ found_in_window = False
641
+ for delta_t in range(self.max_lag + 1):
642
+ data_pts = DataPointView_TS_Window(self.environment, self.stationary_data,
643
+ intervened_window, delta_t, self.max_lag)
644
+ current_time = {}
645
+ intervened_window.append(current_time)
646
+ for v, eq in self.base_model.SEM.items():
647
+ if v.Id() == var.Id() and self.max_lag - delta_t == var.LagValue():
648
+ found_in_window = True
649
+ current_time[v] = value # this is a (single) value, but numpy can broadcast it
650
+ else:
651
+ current_time[v] = eq(data_pts)
652
+ if not found_in_window:
653
+ raise Exception(f"Intervention {var.Info()}={value} was not found in time-series window.")
654
+ # default to window 0
655
+ cf_window = output_windows[0]
656
+ for m in take:
657
+ cf_window[self.max_lag - m.LagValue()][m.Id()] = output_windows[1][self.max_lag - m.LagValue()][m.Id()]
658
+ data_pts = DataPointView_TS_Window(self.environment, self.stationary_data,
659
+ cf_window, self.max_lag, self.max_lag)
660
+ return self.base_model.SEM[output_var](data_pts)
661
+
662
+
663
+ def Ensemble(shared_setup, payloads, runs=1000):
664
+ """Helper to run e.g. estimator vs ground-truth on an ensemble of model-realizations"""
665
+
666
+ results = np.zeros([len(payloads), runs])
667
+ get_next = shared_setup
668
+ if not callable(shared_setup):
669
+ environment = shared_setup
670
+ get_next = lambda: environment.Reset()
671
+
672
+ for r in range(runs):
673
+ environment = get_next()
674
+ p_idx = 0
675
+ for p in payloads:
676
+ results[p_idx, r] = p(environment)
677
+ p_idx += 1
678
+ return results
679
+
680
+
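# A minimal sketch of running payloads over an ensemble of environment
# realizations via Ensemble, assuming this module is importable as
# tigramite.toymodels.non_additive. A factory callable yields a fresh
# Environment per run:
import numpy as np
from tigramite.toymodels.non_additive import ContinuousVariable, Environment, Model, Ensemble

X = ContinuousVariable(name="X")
Y = ContinuousVariable(name="Y")
model = Model({X: lambda v: v[X.Noise()],
               Y: lambda v: 0.8 * v[X] + v[Y.Noise()]})

def make_env():
    return Environment({X: lambda rng, n: rng.normal(size=n),
                        Y: lambda rng, n: rng.normal(size=n)}, N=500)

def sample_covariance(env):
    data = model.ApplyWithExogenousNoise(env)
    return np.cov(data[X], data[Y])[0, 1]

# One row per payload, one column per run; the first row should average ~0.8.
results = Ensemble(make_env, [sample_covariance, lambda env: 0.8], runs=100)
print(results.mean(axis=1))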
681
+ def _Fct_on_grid(fct, list_of_points, cf_delta=0.5, normalize_by_delta=False, **kwargs):
682
+ """Helper to evaluate 'fct' on a grid of points"""
683
+
684
+ result = []
685
+ for pt in list_of_points:
686
+ result.append(fct(pt, pt + cf_delta, **kwargs))
687
+
688
+ if normalize_by_delta:
689
+ return np.array(result) / cf_delta
690
+ else:
691
+ return np.array(result)
692
+
693
+
694
+ def _Fct_smoothed(fct, min_x, max_x, cf_delta=0.5, steps=100, smoothing_gaussian_sigma_in_steps=5,
695
+ normalize_by_delta=False, boundary_effects="extend range", **kwargs):
696
+ """Helper to evaluate 'fct' on a grid of points with subsequent Gauß-smoothing"""
697
+
698
+ stepsize = (max_x - min_x) / steps
699
+ # Extend the window to run on (numpy would extend by zeros if mode="same")
700
+ if boundary_effects == "extend range":
701
+ steps += 6 * smoothing_gaussian_sigma_in_steps - 1
702
+ min_x -= 3 * smoothing_gaussian_sigma_in_steps * stepsize
703
+ else:
704
+ raise Exception("Currently only smoothing mode for boundary-effects is 'extend range'")
705
+
706
+ x_values = [] # np.arange with floats is unstable wrt len of output
707
+ result = []
708
+ for i in range(steps + 1):
709
+ pt = stepsize * i + min_x
710
+ x_values.append(pt)
711
+ value = fct(pt, pt + cf_delta, **kwargs)
712
+ result.append(value)
713
+
714
+ if normalize_by_delta:
715
+ result = np.array(result) / cf_delta
716
+ else:
717
+ result = np.array(result)
718
+
719
+ # cut off convolution-kernel at 3 sigma
720
+ gx = np.arange(-3 * smoothing_gaussian_sigma_in_steps, 3 * smoothing_gaussian_sigma_in_steps)
721
+ gaussian = (np.exp(-(gx / smoothing_gaussian_sigma_in_steps) ** 2 / 2)
722
+ / np.sqrt(2.0 * np.pi) / smoothing_gaussian_sigma_in_steps)
723
+ if len(np.shape(result)) == 1: # values = (samples)
724
+ smoothed = np.convolve(result, gaussian, mode="valid")
725
+ elif len(np.shape(result)) == 3: # densities = (samples, categories, cf/te)
726
+ smoothed = np.empty_like(
727
+ result[3 * smoothing_gaussian_sigma_in_steps:1 - 3 * smoothing_gaussian_sigma_in_steps, :, :])
728
+ for cf_te in range(2):
729
+ for p_category in range(np.shape(result)[1]):
730
+ smoothed[:, p_category, cf_te] = np.convolve(result[:, p_category, cf_te], gaussian, mode="valid")
731
+ else:
732
+ raise Exception(f"Invalid result-shape {result.shape}")
733
+
734
+ return np.array(x_values)[3 * smoothing_gaussian_sigma_in_steps:1 - 3 * smoothing_gaussian_sigma_in_steps], smoothed
735
+
736
+
737
+ def PlotInfo(**va_arg_dict):
738
+ """Helper to add info to plot-setup"""
739
+ return va_arg_dict
740
+
741
+
742
+ def PlotAbsProbabilities(plt, target_var, data, labels):
743
+ """Helper to plot Effects on Categorical variables (typically set plt=your-pylot-module-name)"""
744
+
745
+ fig = plt.figure(figsize=[12.0, 4.8], dpi=75.0, layout='constrained')
746
+ fig.suptitle(labels["title"])
747
+ shared_ax = None
748
+ has_printed_legend = False
749
+ for cY in range(target_var.categories):
750
+ if shared_ax is None:
751
+ shared_ax = plt.subplot(1, target_var.categories, 1 + cY) # one row, one column per category
752
+ else:
753
+ plt.subplot(1, target_var.categories, 1 + cY, sharey=shared_ax) # one row, one column per category
754
+ for d in data:
755
+ plt.plot(d["x"], d["y"][:, cY, 1], color=d["colorTE"], label=d["labelTE"])
756
+ plt.plot(d["x"], d["y"][:, cY, 0], color=d["colorCF"], label=d["labelCF"])
757
+ plt.xlabel(labels["x"])
758
+ plt.ylabel(labels["y"].format(cY=cY))
759
+ if not has_printed_legend:
760
+ has_printed_legend = True
761
+ fig.legend()
762
+ return fig
763
+
764
+
765
+ def PlotChangeInProbabilities(plt, target_var, data, labels):
766
+ """Helper to plot Effects on Categorical variables (typically set plt=your-pylot-module-name)"""
767
+
768
+ fig = plt.figure(figsize=[12.0, 4.8], dpi=75.0, layout='constrained')
769
+ fig.suptitle(labels["title"])
770
+ shared_ax = None
771
+ has_printed_legend = False
772
+ for cY in range(target_var.categories):
773
+ if shared_ax is None:
774
+ shared_ax = plt.subplot(1, target_var.categories, 1 + cY) # one row, one column per category
775
+ else:
776
+ plt.subplot(1, target_var.categories, 1 + cY, sharey=shared_ax) # one row, one column per category
777
+ for d in data:
778
+ plt.plot(d["x"], d["y"][:, cY, 0] - d["y"][:, cY, 1], color=d["color"], label=d["label"])
779
+ plt.xlabel(labels["x"])
780
+ plt.ylabel(labels["y"].format(cY=cY))
781
+ if not has_printed_legend:
782
+ has_printed_legend = True
783
+ fig.legend()
784
+ return fig
785
+
786
+
787
+ def FindMaxLag(*va_args):
788
+ """[INTERNAL] Finds the max lag in a collection of variables."""
789
+ max_lag = 0
790
+ for var_group in va_args:
791
+ if hasattr(var_group, '__iter__'):
792
+ for var in var_group:
793
+ max_lag = max(max_lag, var.LagValue())
794
+ else:
795
+ var = var_group
796
+ max_lag = max(max_lag, var.LagValue())
797
+ return max_lag
798
+
799
+
800
+ def GroundTruth_NDE_auto(change_from, change_to, estimator, env, model):
801
+ """Ground-Truth computation from toy-model for NDE.
802
+
803
+ GroundTruth_*_auto functions extract source, target, blocked-mediators from an estimator.
804
+
805
+ Parameters
806
+ ----------
807
+ change_from : single value of same type as single sample for X (float, int, or bool)
808
+ Reference-value to which X is set by intervention in the world seen by the mediator.
809
+ change_to : single value of same type as single sample for X (float, int, or bool)
810
+ Post-intervention-value to which X is set by intervention in the world seen by the effect (directly).
811
+ estimator : NaturalEffects_GraphMediation
812
+ Extract source, target, blocked-mediators from estimator.
813
+ env : Environment
814
+ The environment used.
815
+ model : Model
816
+ The toy-model.
817
+
818
+ Returns
819
+ -------
820
+ NDE : If Y is categorical -> np.array( # categories Y, 2 )
821
+ The probabilities of the categories of Y (after, before) changing the interventional value of X
822
+ as "seen" by Y from change_from to change_to, while keeping M as if X remained at change_from.
823
+
824
+ NDE : If Y is continuous -> float
825
+ The change in the expectation-value of Y induced by changing the interventional value of X
826
+ as "seen" by Y from change_from to change_to, while keeping M as if X remained at change_from.
827
+ """
828
+ return GroundTruth_NDE(change_from, change_to, estimator.Source, estimator.Target, estimator.BlockedMediators,
829
+ env, model)
830
+
831
+
832
+ def GroundTruth_NDE(change_from, change_to, source, target, mediators, env, model):
833
+ """Ground-Truth computation from toy-model for NDE.
834
+
835
+ Note: GroundTruth_*_auto functions extract source, target, blocked-mediators from an estimator.
836
+
837
+ Parameters
838
+ ----------
839
+ change_from : single value of same type as single sample for X (float, int, or bool)
840
+ Reference-value to which X is set by intervention in the world seen by the mediator.
841
+ change_to : single value of same type as single sample for X (float, int, or bool)
842
+ Post-intervention-value to which X is set by intervention in the world seen by the effect (directly).
843
+ source : VariableDescription
844
+ Effect source.
845
+ target : VariableDescription
846
+ Effect target.
847
+ mediators : *iterable* <VariableDescription>
848
+ Blocked mediators.
849
+ env : Environment
850
+ The environment used.
851
+ model : Model
852
+ The toy-model.
853
+
854
+ Returns
855
+ -------
856
+ NDE : If Y is categorical -> np.array( # categories Y, 2 )
857
+ The probabilities of the categories of Y (after, before) changing the interventional value of X
858
+ as "seen" by Y from change_from to change_to, while keeping M as if X remained at change_from.
859
+
860
+ NDE : If Y is continuous -> float
861
+ The change in the expectation-value of Y induced by changing the interventional value of X
862
+ as "seen" by Y from change_from to change_to, while keeping M as if X remained at change_from.
863
+ """
864
+
865
+ if target.LagValue() != 0:
866
+ raise Exception("Do not use lagged targets, it is always possible to shift everything, so that the"
867
+ "effect is on an unlagged variable.")
868
+
869
+ if model.is_timeseries:
870
+ # Generate groundtruth for intervention at a SINGLE point in time
871
+
872
+ # Window-size (at each point) must be at least the maximum lag appearing in source, target and mediators
873
+ max_lag_in_interventions = FindMaxLag(source, target, mediators)
874
+ # shouldn't be necessary, but fixes some issues with offsets in stationary data:
875
+ max_lag_in_interventions = max(max_lag_in_interventions, model.max_lag)
876
+
877
+ ts = CounterfactualTimeseries(env, model, max_lag_in_interventions=max_lag_in_interventions)
878
+
879
+ y_cf = ts.ComputeIntervention(target, [(source, change_to), (source, change_from)], mediators)
880
+ y_real = ts.ComputeIntervention(target, [(source, change_from)], [])
881
+
882
+ else:
883
+ # Ground-Truth for non-timeseries is straight-forward:
884
+
885
+ modelA = model.Intervene(changes={source.Id(): change_from})
886
+ modelB = model.Intervene(changes={source.Id(): change_to})
887
+ worldA = World(env, modelA)
888
+ worldB = World(env, modelB)
889
+ cf_world = CounterfactualWorld(env, model)
890
+ cf_world.TakeVariablesFromWorld(worldA, mediators)
891
+ cf_world.TakeVariablesFromWorld(worldB, source)
892
+ cf_world.Compute()
893
+
894
+ y_cf = cf_world.Observables()[target]
895
+ y_real = worldA.Observables()[target]
896
+
897
+ if target.is_categorical:
898
+ result = []
899
+ for cY in range(target.categories):
900
+ result.append([np.count_nonzero(y_cf == cY), np.count_nonzero(y_real == cY)])
901
+ return np.array(result) / env.N
902
+ else:
903
+ return np.mean(y_cf - y_real)
904
+
905
+
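# A minimal sketch calling GroundTruth_NDE with a categorical target (the
# continuous case returns a single float instead), assuming this module is
# importable as tigramite.toymodels.non_additive:
from tigramite.toymodels.non_additive import (CategoricalVariable, ContinuousVariable,
                                              Environment, Model, GroundTruth_NDE)

X = ContinuousVariable(name="X")
M = ContinuousVariable(name="M")
Y = CategoricalVariable(name="Y", categories=2)
model = Model({X: lambda v: v[X.Noise()],
               M: lambda v: v[X] + v[M.Noise()],
               Y: lambda v: v[X] + v[M] + v[Y.Noise()] > 0.0})   # non-additive threshold
env = Environment({X: lambda rng, n: rng.normal(size=n),
                   M: lambda rng, n: rng.normal(size=n),
                   Y: lambda rng, n: rng.normal(size=n)}, N=5000)

nde = GroundTruth_NDE(change_from=0.0, change_to=1.0,
                      source=X, target=Y, mediators=[M], env=env, model=model)
# nde has shape (2, 2): rows are the categories of Y, columns are (after, before).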
906
+ def GroundTruth_NDE_fct_auto(x_min, x_max, estimator, env, model, cf_delta=0.5, normalize_by_delta=False,
907
+ grid_stepping=0.1):
908
+ """See GroundTruth_NDE_auto and NaturalEffects_GraphMediation.NDE_smoothed."""
909
+ source = estimator.Source
910
+ target = estimator.Target
911
+ blocked_mediators = estimator.BlockedMediators
912
+ grid = np.arange(x_min, x_max, grid_stepping)
913
+ return grid, GroundTruth_NDE_fct(source, target, blocked_mediators, env, model, grid, cf_delta, normalize_by_delta)
914
+
915
+
916
+ def GroundTruth_NDE_fct(source, target, mediators, env, model, list_of_points, cf_delta=0.5, normalize_by_delta=False):
917
+ """See GroundTruth_NDE and NaturalEffects_GraphMediation.NDE_smoothed."""
918
+ return _Fct_on_grid(GroundTruth_NDE, list_of_points, cf_delta, normalize_by_delta,
919
+ source=source, target=target, mediators=mediators, env=env, model=model)
920
+
921
+
922
+ def GroundTruth_NIE_fct(source, target, mediator, env, model, list_of_points, cf_delta=0.5, normalize_by_delta=False):
923
+ """See GroundTruth_NIE and NaturalEffects_StandardMediation.NIE_smoothed."""
924
+ return _Fct_on_grid(GroundTruth_NIE, list_of_points, cf_delta, normalize_by_delta,
925
+ source=source, target=target, mediator=mediator, env=env, model=model)
926
+
927
+
928
+ def GroundTruth_NIE(change_from, change_to, source, target, mediator, env, model):
929
+ """Ground-Truth computation from toy-model for NIE.
930
+
931
+ Standard-mediation setup only.
932
+
933
+ Parameters
934
+ ----------
935
+ change_from : single value of same type as single sample for X (float, int, or bool)
936
+ Reference-value to which X is set by intervention in the world seen by the mediator.
937
+ change_to : single value of same type as single sample for X (float, int, or bool)
938
+ Post-intervention-value to which X is set by intervention in the world seen by the effect (directly).
939
+ source : VariableDescription
940
+ Effect source.
941
+ target : VariableDescription
942
+ Effect target.
943
+ mediator : VariableDescription
944
+ Effect mediator.
945
+ env : Environment
946
+ The environment used.
947
+ model : Model
948
+ The toy-model.
949
+
950
+ Returns
951
+ -------
952
+ NIE : If Y is categorical -> np.array( # categories Y, 2 )
953
+ The probabilities of the categories of Y (after, before) changing the interventional value of X
954
+ as "seen" by M from change_from to change_to, while keeping the value as (directly) seen by Y,
955
+ as if X remained at change_from.
956
+
957
+ NIE : If Y is continuous -> float
958
+ The change in the expectation-value of Y induced by changing the interventional value of X
959
+ as "seen" by M from change_from to change_to, while keeping the value as (directly) seen by Y,
960
+ as if X remained at change_from.
961
+ """
962
+
963
+ modelA = model.Intervene(changes={source: change_from})
964
+ modelB = model.Intervene(changes={source: change_to})
965
+ worldA = World(env, modelA)
966
+ worldB = World(env, modelB)
967
+ cf_world = CounterfactualWorld(env, model)
968
+ # this time, take the mediator from the B-world
969
+ cf_world.TakeVariablesFromWorld(worldA, [source])
970
+ cf_world.TakeVariablesFromWorld(worldB, [mediator])
971
+ cf_world.Compute()
972
+
973
+ y_cf = cf_world.Observables()[target]
974
+ y_real = worldA.Observables()[target]
975
+
976
+ if target.is_categorical:
977
+ result = []
978
+ for cY in range(target.categories):
979
+ result.append([np.count_nonzero(y_cf == cY), np.count_nonzero(y_real == cY)])
980
+ return np.array(result) / env.N
981
+ else:
982
+ return np.mean(y_cf - y_real)
983
+
984
+
985
+ class DataHandler:
986
+ """[INTERNAL] Implement some helper functions and generate time-series data via tigramite's data-frames."""
987
+
988
+ def __init__(self, observables, dataframe_based_preprocessing=True):
989
+ if isinstance(observables, DataFrame):
990
+ self._from_dataframe = True
991
+ self._data = VariablesFromDataframe(observables)
992
+ self._dataframe = observables
993
+ self.use_dataframe_for_preprocessing = dataframe_based_preprocessing
994
+ self.data_selection_virtual_ids = {}
995
+ self._info = {}
996
+ self.preprocessed_data = False
997
+ self.data_selection_frozen = False # all fits filter the same data for missing values; the variable-set is locked after the first data-access (see GetVariableAuto and operator [] / __getitem__)
998
+ else:
999
+ self._from_dataframe = False
1000
+ self._data = observables
1001
+ self._dataframe = DataframeFromVariables(observables)
1002
+ self.use_dataframe_for_preprocessing = False
1003
+
1004
+ self._indices = {}
1005
+ self._keys = list(self._data.keys())
1006
+ idx = 0
1007
+ for var in self._data.keys():
1008
+ self._indices[var] = idx
1009
+ idx += 1
1010
+
1011
+
1012
+ def GetIdFor(self, idx, lag, display_name):
1013
+ # self._keys[idx] contains info (e.g. categorical or continuous, 'original name') for variable idx; attach the display name (e.g. Mediator)
1014
+ # also add 'original name' eg 'temperature' in square brackets as well as lag
1015
+ base_id = self._keys[idx]
1016
+ return base_id.CopyInfo(display_name + f"[{base_id.Info(detail=1)} at lag {lag}]")
1017
+
1018
+ def RequirePreprocessingFor(self, var, info):
1019
+ if self.use_dataframe_for_preprocessing:
1020
+ if var not in self.data_selection_virtual_ids:
1021
+ # variables that are already registered stay valid even after the set has been frozen
1022
+ if self.data_selection_frozen:
1023
+ raise Exception("Cannot add variables to preprocessing after variable-set has been locked-in")
1024
+ assert info is not None
1025
+ var_id = self.GetIdFor(var[0], var[1], info)
1026
+ self.data_selection_virtual_ids[var] = var_id
1027
+ self._indices[var_id] = var
1028
+ self._info[var_id] = info
1029
+
1030
+ def GetVariableAuto(self, var, info="Other"):
1031
+ if self._from_dataframe:
1032
+ return self.ReverseLookupSingle(var, info)
1033
+ else:
1034
+ return var
1035
+
1036
+ def GetVariablesAuto(self, vars, info="Other"):
1037
+ result = []
1038
+ for var in vars:
1039
+ result.append(self.GetVariableAuto(var, info))
1040
+ return result
1041
+
1042
+ def DataFrame(self):
1043
+ return self._dataframe
1044
+
1045
+ def __getitem__(self, key):
1046
+ if hasattr(key, '__iter__'):
1047
+ return [self[entry] for entry in key]
1048
+
1049
+ if self.use_dataframe_for_preprocessing:
1050
+ return self._indices[key]
1051
+
1052
+ if key.__class__ == LaggedVariable:
1053
+ return self._indices[key.var], -key.lag
1054
+ else:
1055
+ return self._indices[key], 0
1056
+
1057
+ def _PreprocessData(self, **kwargs):
1058
+ self.data_selection_frozen = True
1059
+ self.preprocessed_data = True
1060
+ X = []
1061
+ Y = []
1062
+ Z = []
1063
+ M = []
1064
+ for var_lag_index, var_id in self.data_selection_virtual_ids.items():
1065
+ info = self._info[var_id]
1066
+ if info == "Source":
1067
+ X.append(var_lag_index)
1068
+ elif info == "Target":
1069
+ Y.append(var_lag_index)
1070
+ elif info == "Mediator":
1071
+ M.append(var_lag_index)
1072
+ elif info == "Adjustment":
1073
+ Z.append(var_lag_index)
1074
+ else:
1075
+ raise Exception("Unknown Variable-Interpretation")
1076
+ data_preprocessed, xyz, data_type = self.DataFrame().construct_array(X=X, Y=Y, Z=Z, extraZ=M, **kwargs) # kw-args forwards eg tau-max
1077
+ self._data = {}
1078
+ i = 0
1079
+ for x in X:
1080
+ self._data[self.data_selection_virtual_ids[x]] = data_preprocessed[i]
1081
+ assert xyz[i] == 0
1082
+ i += 1
1083
+ for y in Y:
1084
+ self._data[self.data_selection_virtual_ids[y]] = data_preprocessed[i]
1085
+ assert xyz[i] == 1
1086
+ i += 1
1087
+ for z in Z:
1088
+ self._data[self.data_selection_virtual_ids[z]] = data_preprocessed[i]
1089
+ assert xyz[i] == 2
1090
+ i += 1
1091
+ for m in M:
1092
+ self._data[self.data_selection_virtual_ids[m]] = data_preprocessed[i]
1093
+ assert xyz[i] == 3
1094
+ i += 1
1095
+ assert i == data_preprocessed.shape[0]
1096
+
1097
+ def GetPreprocessed(self, vars, **kwargs):
1098
+ if not self.preprocessed_data:
1099
+ self._PreprocessData(**kwargs)
1100
+ result = {}
1101
+ for var in vars:
1102
+ assert var in self.data_selection_virtual_ids
1103
+ var_id = self.data_selection_virtual_ids[var]
1104
+ result[var_id] = self._data[var_id]
1105
+ return list(result.keys()), result
1106
+
1107
+
1108
+ def Get(self, name, vars, **kwargs):
1109
+ if self.use_dataframe_for_preprocessing:
1110
+ return self.GetPreprocessed(vars, **kwargs)
1111
+ ids = []
1112
+ i = 1
1113
+ for idx, lag in vars:
1114
+ # assemble names for this variable-group, eg "Mediator5"
1115
+ the_name = name if len(vars) == 1 else name + str(i)
1116
+ ids.append( self.GetIdFor(idx, lag, the_name) )
1117
+ i += 1
1118
+ data, xyz, data_type = self.DataFrame().construct_array(X=vars, Y=[], Z=[], **kwargs)
1119
+ assert data.shape[0] == len(ids)
1120
+ result = {}
1121
+ i = 0
1122
+ for elem in ids:
1123
+ result[elem] = data[i].astype(dtype=np.dtype(elem.dtype))
1124
+ i += 1
1125
+ return ids, result
1126
+
1127
+ def ReverseLookupSingle(self, index, info):
1128
+ if self.use_dataframe_for_preprocessing:
1129
+ self.RequirePreprocessingFor(index, info)
1130
+ return self.data_selection_virtual_ids[index]
1131
+ else:
1132
+ return self._keys[index[0]].Lag(-index[1])
1133
+
1134
+ def ReverseLookupMulti(self, index_set, info=None):
1135
+ return [self.ReverseLookupSingle(index, info) for index in index_set]
1136
+
1137
+
1138
+ def VariablesFromDataframe(dataframe):
1139
+ """Extract Category-Information from tigramite::dataframe
1140
+
1141
+ Parameters
1142
+ ----------
1143
+ dataframe : tigramite.data_processing.DataFrame
1144
+ Dataframe to extract data from.
1145
+
1146
+
1147
+ Returns
1148
+ -------
1149
+ variables : dictionary< VariableDescription, np.array(N) >
1150
+ The variable-meta-data and data.
1151
+ """
1152
+
1153
+ data_types = dataframe.data_type
1154
+ if dataframe.data_type is not None:
1155
+ # Require data-types to be constant
1156
+ first_elements = data_types[0, :]
1157
+ if not np.all(first_elements == data_types):
1158
+ raise NotImplementedError("Natural Effect Framework currently only supports per variable"
1159
+ "data-types, this dataframe contains variables with changing"
1160
+ "(over time) data-types.")
1161
+
1162
+ result = {} # variable i will be the i-th entry of result (dict insertion order is preserved)
1163
+ data = dataframe.values[0]
1164
+ node_count = np.shape(data)[1]
1165
+ for node in range(node_count):
1166
+ if data_types is None or data_types[0, node] == 0:
1167
+ var = ContinuousVariable(name=str(dataframe.var_names[node]))
1168
+ result[var] = data[:, node]
1169
+ else:
1170
+ labels, transformed = np.unique(data[:, node], return_inverse=True)
1171
+ var = CategoricalVariable(name=str(dataframe.var_names[node]), categories=labels)
1172
+ result[var] = transformed
1173
+ return result
1174
+
1175
+
1176
+ def DataframeFromVariables(data_dict):
1177
+ """Convert Category-Information to tigramite::dataframe
1178
+
1179
+ Parameters
1180
+ ----------
1181
+ data_dict : dictionary< VariableDescription, np.array(N) >
1182
+ The variable-meta-data and data.
1183
+
1184
+ Returns
1185
+ -------
1186
+ dataframe : tigramite.data_processing.DataFrame
1187
+ Dataframe containing the raw data.
1188
+ """
1189
+
1190
+ data_len = 0
1191
+ keys = list(data_dict.keys())
1192
+ values = list(data_dict.values())
1193
+ if keys[0].is_categorical or keys[0].dimension == 1:
1194
+ data_len = np.shape(values[0])[0]
1195
+ else:
1196
+ data_len = np.shape(values[0])[1]
1197
+ dimensions = 0
1198
+ for var, data in data_dict.items():
1199
+ if var.is_categorical:
1200
+ dimensions += 1
1201
+ else:
1202
+ dimensions += var.dimension
1203
+
1204
+ data_out = np.zeros([data_len, dimensions])
1205
+ data_type = np.zeros([data_len, dimensions])
1206
+ names = []
1207
+ i = 0
1208
+ for var, data in data_dict.items():
1209
+ if var.is_categorical:
1210
+ names.append(var.name)
1211
+ if var.category_values is not None:
1212
+ data_out[:, i] = np.take(var.category_values, data)
1213
+ else:
1214
+ data_out[:, i] = data
1215
+ data_type[:, i] = np.ones(data_len)
1216
+ i += 1
1217
+ else:
1218
+ if var.dimension > 1:
1219
+ for j in range(var.dimension):
1220
+ if var.dimension == 1:
1221
+ names.append(var.name)
1222
+ else:
1223
+ names.append(var.name + "$_" + str(j) + "$")
1224
+ data_out[:, i + j] = data[j]
1225
+ data_type[:, i + j] = np.zeros(data_len)
1226
+ else:
1227
+ names.append(var.name)
1228
+ data_out[:, i] = data
1229
+ data_type[:, i] = np.zeros(data_len)
1230
+ i += var.dimension
1231
+ return DataFrame(data=data_out, data_type=data_type, var_names=names)
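# A minimal sketch of converting a variable-description dictionary into a
# tigramite DataFrame, assuming this module is importable as
# tigramite.toymodels.non_additive (VariablesFromDataframe above performs the
# reverse direction):
import numpy as np
from tigramite.toymodels.non_additive import (CategoricalVariable, ContinuousVariable,
                                              DataframeFromVariables)

rng = np.random.default_rng(0)
X = ContinuousVariable(name="X")
S = CategoricalVariable(name="Season", categories=4)
data = {X: rng.normal(size=200).astype("float32"),
        S: rng.integers(0, 4, size=200).astype("uint32")}

df = DataframeFromVariables(data)   # DataFrame with data_type flags: 0=continuous, 1=discrete
print(df.var_names)                 # ['X', 'Season']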