tigramite-fast 5.2.10.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- tigramite/__init__.py +0 -0
- tigramite/causal_effects.py +1525 -0
- tigramite/causal_mediation.py +1592 -0
- tigramite/data_processing.py +1574 -0
- tigramite/graphs.py +1509 -0
- tigramite/independence_tests/LBFGS.py +1114 -0
- tigramite/independence_tests/__init__.py +0 -0
- tigramite/independence_tests/cmiknn.py +661 -0
- tigramite/independence_tests/cmiknn_mixed.py +1397 -0
- tigramite/independence_tests/cmisymb.py +286 -0
- tigramite/independence_tests/gpdc.py +664 -0
- tigramite/independence_tests/gpdc_torch.py +820 -0
- tigramite/independence_tests/gsquared.py +190 -0
- tigramite/independence_tests/independence_tests_base.py +1310 -0
- tigramite/independence_tests/oracle_conditional_independence.py +1582 -0
- tigramite/independence_tests/pairwise_CI.py +383 -0
- tigramite/independence_tests/parcorr.py +369 -0
- tigramite/independence_tests/parcorr_mult.py +485 -0
- tigramite/independence_tests/parcorr_wls.py +451 -0
- tigramite/independence_tests/regressionCI.py +403 -0
- tigramite/independence_tests/robust_parcorr.py +403 -0
- tigramite/jpcmciplus.py +966 -0
- tigramite/lpcmci.py +3649 -0
- tigramite/models.py +2257 -0
- tigramite/pcmci.py +3935 -0
- tigramite/pcmci_base.py +1218 -0
- tigramite/plotting.py +4735 -0
- tigramite/rpcmci.py +467 -0
- tigramite/toymodels/__init__.py +0 -0
- tigramite/toymodels/context_model.py +261 -0
- tigramite/toymodels/non_additive.py +1231 -0
- tigramite/toymodels/structural_causal_processes.py +1201 -0
- tigramite/toymodels/surrogate_generator.py +319 -0
- tigramite_fast-5.2.10.1.dist-info/METADATA +182 -0
- tigramite_fast-5.2.10.1.dist-info/RECORD +38 -0
- tigramite_fast-5.2.10.1.dist-info/WHEEL +5 -0
- tigramite_fast-5.2.10.1.dist-info/licenses/license.txt +621 -0
- tigramite_fast-5.2.10.1.dist-info/top_level.txt +1 -0
tigramite/toymodels/non_additive.py
@@ -0,0 +1,1231 @@
"""Tigramite causal discovery for time series."""

# Authors: Martin Rabel, Jakob Runge <jakob@jakob-runge.com>
#
# License: GNU General Public License v3.0


import numpy as np
from tigramite.data_processing import DataFrame
from tigramite.causal_effects import CausalEffects

class VariableDescription:
    r"""Variable description (base-class)

    Used for mixed variable fitting and mediation, see Mediation-tutorial.

    Parameters
    ----------
    name : string
        The display-name of the variable.
    observable : bool
        Is the variable observable or hidden? (used by toy-models)
    exogenous : bool
        Is the variable assumed exogenous (used on noise-terms internally, keep default)
    """

    def __init__(self, name="(unnamed)", observable=True, exogenous=False):
        """Use CategoricalVariable or ContinuousVariable instead."""
        self.noise_term = None
        self.observable = observable
        self.exogenous = exogenous
        self.name = name

    def Noise(self):
        """
        Get the variable-description for the noise-term on this variable.
        Used e.g. for describing SEMs in noise-models

        Returns
        -------
        noise : VariableDescription
            The variable-description for the noise-term on this variable.
        """
        if self.noise_term is None:
            self.noise_term = ExogenousNoiseVariable(self)
        return self.noise_term

    def Lag(self, offset):
        """
        Get the variable-description for the past of this variable.
        Used e.g. for describing SEMs in noise-models

        Parameters
        ----------
        offset : uint
            The lag to introduce (only non-negative values, always in the past).

        Returns
        -------
        lagged-self : VariableDescription
            This variable, at lag "offset"
        """
        if offset == 0:
            return self
        else:
            return LaggedVariable(self, offset)

    def Info(self, detail=0):
        """
        Get an Info-string for this variable.
        (Overridden by derived classes for discrete, continuous, lagged, ...)

        Parameters
        ----------
        detail : uint
            The amount of detail to include.

        Returns
        -------
        info : string
            A description of this variable.
        """
        return self.name

    def PrintInfo(self):
        """Print Info-string (see Info)"""
        print(self.Info())

    def Id(self):
        """Get the 'vanilla' (unlagged etc) description of this variable
        (Overridden by derived classes for discrete, continuous, lagged, ...)

        Returns
        -------
        id : VariableDescription
            A description of the underlying variable of this description.
        """
        return self

    def LagValue(self):
        """Get the time-lag (non-negative, positive values are in the past) of the
        variable described by this indicator.
        (Overridden by derived classes for discrete, continuous, lagged, ...)

        Returns
        -------
        lag : uint
            The time-lag (non-negative, positive values are in the past) of this variable.
        """
        return 0


class LaggedVariable(VariableDescription):
    """Describes a lagged variable (see VariableDescription)"""

    def __init__(self, var, lag):
        """Use var-description.Lag(offset) instead."""
        self.var = var
        self.lag = lag
        self.exogenous = False

    def Id(self):
        return self.var

    def LagValue(self):
        return self.lag

    def Info(self, detail=0):
        return self.var.Info(detail) + " at lag " + str(self.lag)

    def PrintInfo(self):
        print(self.Info())

    def DefaultValue(self):
        return self.var.DefaultValue()


class ExogenousNoiseVariable(VariableDescription):
    """Describes an exogenous noise variable (see VariableDescription)"""

    def __init__(self, attached_to):
        """Use var-description.Noise() instead."""
        super().__init__(name="Noise of " + attached_to.name, observable=False, exogenous=True)
        self.associated_system_variable = attached_to

    def DefaultValue(self):
        return self.associated_system_variable.DefaultValue()


class CategoricalVariable(VariableDescription):
    """Describes a categorical variable (see VariableDescription)

    Parameters
    ----------
    name : string
        The display-name of the variable.
    categories : uint or *iterable*
        The number of categories in which the variable may take values, or
        an iterable containing the possible values.
    dtype : string
        Name of the numpy dtype ('uint32', 'bool', ...) or 'auto'
    """

    def __init__(self, name="(unnamed)", categories=2, dtype="auto"):
        super().__init__(name=name)
        self.is_categorical = True
        if hasattr(categories, '__iter__'):
            self.categories = len(categories)
            self.category_values = categories
        else:
            self.categories = categories
            self.category_values = None
        if dtype == "auto":
            self.dtype = "uint32" if self.categories > 2 else "bool"
        else:
            self.dtype = dtype

    def Info(self, detail=0):
        if detail > 0:
            return self.name
        else:
            return self.name + f"( Categorical variable with {self.categories} categories )"

    def Empty(self, N):
        """Build an empty np.array with N samples."""
        return np.zeros(N, self.dtype)

    def ValidValue(self, x):
        """Validate a value"""
        is_integer = x % 1 == 0.0  # this works for all of: np.[u]int; built-in [u]int; np.bool; built-in boolean
        return is_integer and 0 <= x < self.categories

    def DefaultValue(self):
        """Return a 'default-value' to use in toy-models at the start of a time-series
        (if otherwise undefined by the SEM)"""
        # used for lagged values not available in time-series
        return 0  # could be user specified & per variable ...

    def CopyInfo(self, name="(unnamed)"):
        """Create a new variable-id with the same parameters"""
        return CategoricalVariable(name, self.categories, self.dtype)

class ContinuousVariable(VariableDescription):
    """Describes a continuous variable (see VariableDescription)

    Parameters
    ----------
    name : string
        The display-name of the variable.
    dtype : string
        Name of the numpy dtype ('float32', 'float64', ...)
    dimension : uint
        The number of dimensions, can be >1 for fitting.
    """

    def __init__(self, name="(unnamed)", dtype="float32", dimension=1):
        super().__init__(name=name)
        self.dtype = dtype
        self.is_categorical = False
        self.dimension = dimension  # fitting probabilities for multiple categories might be higher dim (handled by fit)

    def Info(self, detail=0):
        if detail > 0:
            return self.name
        else:
            return self.name + f"( Continuous variable of dimension {self.dimension} )"

    def Empty(self, N):
        """Build an empty np.array with N samples."""
        if self.dimension == 1:
            return np.zeros(N, dtype=self.dtype)
        else:
            return np.zeros([self.dimension, N], dtype=self.dtype)

    def ValidValue(self, x):
        return True

    def DefaultValue(self):
        """Return a 'default-value' to use in toy-models at the start of a time-series
        (if otherwise undefined by the SEM)"""
        # used for lagged values not available in time-series
        return 0.0  # could be user specified & per variable ...

    def CopyInfo(self, name="(unnamed)"):
        """Create a new variable-id with the same parameters"""
        return ContinuousVariable(name, self.dtype, self.dimension)

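The variable-description classes above are the handles used throughout the rest of this module. A minimal usage sketch (the names X and M are illustrative and not part of the packaged file):

    X = ContinuousVariable("X")
    M = CategoricalVariable("M", categories=3)   # dtype "auto" resolves to "uint32" here
    X_past = X.Lag(2)                            # LaggedVariable, X_past.LagValue() == 2
    X_past.Info(detail=1)                        # "X at lag 2"
    X.Noise().name                               # "Noise of X" (hidden, exogenous)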
class Environment:
    """An (exogenous) environment to use with a model

    Counterfactual models may create different worlds with different models from the same environment.

    Parameters
    ----------
    exogenous_noise : dictionary< VariableDescription, *callable* (np.random.generator, sample-count) >
        For each variable, a callable generating noise-values.
    N : uint
        The initial number of samples to generate
    seed : uint or None
        Seed to use for the random-generator.
    """

    def __init__(self, exogenous_noise, N=1000, seed=None):
        self.rng = np.random.default_rng(seed)
        self.can_reset = seed is None
        self.exogenous_noise = exogenous_noise
        self.noise = {}
        self.N = N
        self._ForceReset()

    def ResetWithNewSeed(self, new_seed, N=None):
        """Reset with a new random-seed."""
        self.rng = np.random.default_rng(new_seed)
        self._ForceReset(N)

    def Reset(self, N=None):
        """Reset with a random new random-seed."""
        if not self.can_reset:
            raise Exception("Cannot reset a fixed-seed environment, use seed=None or a factory/lambda instead for "
                            "ensemble-creation.")
        self._ForceReset(N)
        return self

    def _ForceReset(self, N=None):
        """[internal] Enforce a reset."""
        if N is not None:
            self.N = N
        for var, _noise in self.exogenous_noise.items():
            self.noise[var.Noise()] = _noise(self.rng, self.N)

    def GetNoise(self):
        """Get a copy of the noise-values.

        Returns
        -------
        noises : dictionary< Variable-Description, samples >
            Returns a shallow copy of the samples generated.
            Keys are 'var.Noise()' for each var in exogenous_noise (see constructor).
        """
        return self.noise.copy()  # return a SHALLOW copy of the noise-data (copy the dict, not the data)

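A sketch of setting up an Environment for variables like those above; each noise callable receives the environment's numpy Generator and the sample count (X and Y are assumed ContinuousVariable instances, not part of the file):

    Y = ContinuousVariable("Y")
    exogenous_noise = {
        X: lambda rng, n: rng.normal(size=n),
        Y: lambda rng, n: rng.normal(size=n),
    }
    env = Environment(exogenous_noise, N=5000)   # seed=None keeps env.Reset() available for ensembles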
class DataPointForValidation:
    """[INTERNAL] Used to validate causal ordering on SEM for toy-model"""

    def __init__(self):
        self.known = {}
        self.is_timeseries = False
        self.max_lag = 0

    def Set(self, var, value):
        self.known[var] = True

    def __getitem__(self, key):
        if key.__class__ == LaggedVariable:
            self.is_timeseries = True
            self.max_lag = max(self.max_lag, key.lag)
            return key.DefaultValue()

        if not key.exogenous and key not in self.known:
            raise Exception("SEM must be causally ordered, such that contemporaneous parents are listed "
                            "before their children.")
        return key.DefaultValue()


class DataPointForParents:
    """[INTERNAL] Used to extract causal parents (for ground-truth graph) from SEM for toy-model"""

    def __init__(self):
        self.parents = []

    def __getitem__(self, key):
        if not key.exogenous:
            self.parents.append(key)
        return key.DefaultValue()


class DataPointView:
    """[INTERNAL] Used to run SEMs on toy-models."""

    def __init__(self, data, known=None):
        self.data = data
        self.index = 0

    def Next(self):
        self.index += 1

    def Set(self, var, value):
        self.data[var][self.index] = value

    def __getitem__(self, key):
        if key.__class__ == LaggedVariable:
            if self.index < key.lag:
                return key.DefaultValue()
            else:
                return self.data[key.var][self.index - key.lag]
        else:
            return self.data[key][self.index]

class Model:
    """Describes a Toy-model via an SEM.

    Validates only the causal ordering; for more validation, see tigramite.toymodels.
    However, this implementation can also generate non-additive models.

    Parameters
    ----------
    sem : dictionary< VariableDescription, *callable* ( sample-view : var-desc -> np.array or scalar ) >
        For variables in causal order, a *callable* which is passed a
        view v of the data-samples, access data by v[var-description],
        and return the value(s) of the described (key) variable.
        For non-time-series, v[var-description] is an np.array,
        for time-series, v[var-description] is a scalar.
    """

    def __init__(self, sem):
        self.SEM = sem
        self.is_timeseries = False
        self.max_lag = 0
        self.Validate()

    def GetGroundtruthLinks(self):
        """Get the causal links to parents (ground-truth).

        Does not validate faithfulness.

        Returns
        -------
        links, indices : dictionaries< VariableDescription, ...>
            For links: Values are the lists of parents.
            For indices: Values are variable indices (see GetGroundtruthLinksRaw)
        """
        self.Validate()
        links = {}
        indices = {}
        idx = 0
        for var, eq in self.SEM.items():
            data_pt = DataPointForParents()
            self.SEM[var](data_pt)
            links[var] = data_pt.parents
            indices[var] = idx
            idx += 1
        return links, indices

    def GetGroundtruthLinksRaw(self):
        """Get unformatted causal links to parents (ground-truth).

        Does not validate faithfulness.

        Returns
        -------
        links : tigramite-format/raw indices + lags
        """
        links, indices = self.GetGroundtruthLinks()
        links_raw = {}
        for var in self.SEM.keys():
            parents_raw = []
            for p in links[var]:
                if p.__class__ == LaggedVariable:
                    parents_raw.append((indices[p.var], -p.lag))
                else:
                    parents_raw.append((indices[p], 0))
            links_raw[indices[var]] = parents_raw
        return links_raw

    def GetGroundtruthGraph(self):
        """Get the (ground-truth) graph.

        Returns
        -------
        graph, graph-type : see e.g. CausalEffects tutorial, string
        """
        graph = CausalEffects.get_graph_from_dict(self.GetGroundtruthLinksRaw())
        if self.is_timeseries:
            return graph, 'stationary_dag'
        else:
            return graph, 'dag'

    def Validate(self):
        """Validate causal ordering (necessary for consistent data-generation), called automatically."""
        data_pt = DataPointForValidation()
        for var, eq in self.SEM.items():
            data_pt.Set(var, self.SEM[var](data_pt))
        self.max_lag = max(self.max_lag, data_pt.max_lag)
        self.is_timeseries = data_pt.is_timeseries

    def ApplyWithExogenousNoise(self, environment, partial_data=None):
        """Apply to environment

        Parameters
        ----------
        environment : Environment
            Exogenous noise-samples given as environment.
        partial_data : None
            [INTERNAL USE] Leave to default.
        """
        if self.is_timeseries:
            # This may be very slow (as it cannot be parallelized or be dispatched efficiently to native code via numpy)
            return self._GenerateAsTimeseries(environment, partial_data)

        data = environment.GetNoise()
        vars = []
        for var, eq in self.SEM.items():
            if partial_data is not None and var in partial_data:
                data[var] = partial_data[var]
            else:
                vars.append(var)

        for var in vars:
            data[var] = self.SEM[var](data)
        return data

    def _GenerateAsTimeseries(self, environment, partial_data=None):
        """[INTERNAL] Generate time-series data.

        Generation of time-series data cannot be parallelized or efficiently dispatched to native code.
        """
        data = environment.GetNoise()
        vars = []
        for var, eq in self.SEM.items():
            if partial_data is not None and var in partial_data:
                data[var] = partial_data[var]
            else:
                data[var] = var.Empty(environment.N)
                vars.append(var)

        data_pt = DataPointView(data, partial_data)
        for x in range(environment.N):
            for var in vars:
                data_pt.Set(var, self.SEM[var](data_pt))
            data_pt.Next()
        return data

    def Intervene(self, changes):
        """Get an intervened model.

        Parameters
        ----------
        changes : dictionary< VariableDescription, ...>
            For each variable to intervene on, either a scalar (hard intervention),
            or a *callable* replacing the equation in the SEM (see constructor).

        Returns
        -------
        intervened model : Model
            The intervened model.
        """
        # Return a model that describes the intervened system

        new_sem = self.SEM.copy()
        for var, eq in changes.items():
            if callable(eq):
                new_sem[var] = eq  # intervene by function
            else:
                # return a constant; bind eq as a default argument so each lambda keeps
                # its own value instead of the loop variable's final value
                new_sem[var] = lambda data, _value=eq: _value
        return Model(new_sem)

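A sketch of a non-additive SEM over the variables from the previous sketches (X, Y, env are assumed from above); the noise enters Y multiplicatively on one branch, which purely additive toy-model generators cannot express:

    sem = {
        X: lambda v: v[X.Noise()],
        Y: lambda v: np.where(v[X] > 0.0, 2.0 * v[Y.Noise()], v[Y.Noise()]),
    }
    model = Model(sem)                               # validates the causal ordering
    graph, graph_type = model.GetGroundtruthGraph()  # graph_type == 'dag' (no lags used)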
class World:
    """A 'world' instantiation.

    Generate observations from an environment (exogenous noise) and a model (SEM).

    Parameters
    ----------
    environment : Environment
        The environment with exogenous noise samples.
    model : Model
        The SEM describing the system.
    """

    def __init__(self, environment, model):
        self.environment = environment  # to check agreement in counterfactual worlds
        if model is not None:
            self.data = model.ApplyWithExogenousNoise(environment)
        else:
            self.data = {}

    def Observables(self):
        """Get all observables

        Returns
        -------
        observables : dictionary< VariableDescription, np.array( environment.N ) >
            The samples for each observable variable.
        """
        obs = {}
        for var, values in self.data.items():
            if var.observable:
                obs[var] = values
        return obs


class CounterfactualWorld(World):
    """A 'counterfactual' world.

    Generate observations from one environment (exogenous noise) and multiple models (SEM).

    Parameters
    ----------
    environment : Environment
        The environment with exogenous noise samples.
    model : Model
        The base-model to generate data from in the end.
        Call TakeVariablesFromWorld to overwrite some variables with values from another 'world',
        then call Compute.
    """

    def __init__(self, environment, model):
        super().__init__(environment, None)
        self.model = model

    def TakeVariablesFromWorld(self, world, vars):
        """Take variables from a world.

        Parameters
        ----------
        world : World
            Take samples from here.
        vars : VariableDescription or *iterable* <VariableDescription>
            One variable (or an iterable of variables) to overwrite.
        """
        if world.environment != self.environment:
            raise Exception("Counterfactual Worlds must share exogenous noise terms.")

        if hasattr(vars, '__iter__'):
            for var in vars:
                self.TakeVariablesFromWorld(world, var)
        else:
            self.data[vars] = world.data[vars]

    def Compute(self):
        """Compute the counterfactual world observations."""
        self.data = self.model.ApplyWithExogenousNoise(self.environment, partial_data=self.data)

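A sketch of the counterfactual construction these classes support, assuming a mediation model with variables X, M, Y (M a mediator between X and Y) and an Environment env as in the earlier sketches; this mirrors what GroundTruth_NDE further below does in the non-time-series case:

    world_ref = World(env, model.Intervene({X: 0.0}))   # reference world: do(X = 0)
    world_int = World(env, model.Intervene({X: 1.0}))   # intervened world: do(X = 1)
    cf = CounterfactualWorld(env, model)
    cf.TakeVariablesFromWorld(world_ref, [M])            # mediator behaves as if X had stayed at 0
    cf.TakeVariablesFromWorld(world_int, [X])            # X itself is switched to 1
    cf.Compute()
    nde = np.mean(cf.Observables()[Y] - world_ref.Observables()[Y])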
class DataPointView_TS_Window:
    """[INTERNAL] Injected into SEM callbacks to compute values for counterfactual windows in parallel."""

    def __init__(self, env, stationary_data, window_data, offset_in_window, max_lag):
        self.stationary_data = stationary_data
        self.environment = env
        self.window_data = window_data
        self.offset_in_window = offset_in_window
        self.max_lag = max_lag

    def __getitem__(self, key):
        if key.exogenous:
            return self.environment.noise[key][self.max_lag:]
        else:
            if self.offset_in_window < key.LagValue():
                return self.stationary_data[key.Id()][self.max_lag - key.LagValue():-key.LagValue()]
            else:
                return self.window_data[self.offset_in_window - key.LagValue()][key.Id()]


class CounterfactualTimeseries:
    """[INTERNAL] Used to generate ground-truth for time-series NDE.

    Computes the effect of an intervention at a SINGLE point in time.
    Use GroundTruth_* functions instead.
    """

    def __init__(self, environment, base_model, max_lag_in_interventions):
        self.environment = environment
        self.base_model = base_model
        self.max_lag = max_lag_in_interventions

        self.stationary_data = self.base_model.ApplyWithExogenousNoise(self.environment)

    def ComputeIntervention(self, output_var, interventions, take):
        output_windows = []
        for var, value in interventions:
            intervened_window = []
            output_windows.append(intervened_window)
            found_in_window = False
            for delta_t in range(self.max_lag + 1):
                data_pts = DataPointView_TS_Window(self.environment, self.stationary_data,
                                                   intervened_window, delta_t, self.max_lag)
                current_time = {}
                intervened_window.append(current_time)
                for v, eq in self.base_model.SEM.items():
                    if v.Id() == var.Id() and self.max_lag - delta_t == var.LagValue():
                        found_in_window = True
                        current_time[v] = value  # this is a (single) value, but numpy can broadcast it
                    else:
                        current_time[v] = eq(data_pts)
            if not found_in_window:
                raise Exception(f"Intervention {var.Info()}={value} was not found in time-series window.")
        # default to window 0
        cf_window = output_windows[0]
        for m in take:
            cf_window[self.max_lag - m.LagValue()][m.Id()] = output_windows[1][self.max_lag - m.LagValue()][m.Id()]
        data_pts = DataPointView_TS_Window(self.environment, self.stationary_data,
                                           cf_window, self.max_lag, self.max_lag)
        return self.base_model.SEM[output_var](data_pts)

def Ensemble(shared_setup, payloads, runs=1000):
    """Helper to run e.g. estimator vs ground-truth on an ensemble of model-realizations"""

    results = np.zeros([len(payloads), runs])
    get_next = shared_setup
    if not callable(shared_setup):
        environment = shared_setup
        get_next = lambda: environment.Reset()

    for r in range(runs):
        environment = get_next()
        p_idx = 0
        for p in payloads:
            results[p_idx, r] = p(environment)
            p_idx += 1
    return results

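A sketch of driving Ensemble with an environment factory, assuming the model and variables from the sketches above; passing a factory rather than a fixed-seed Environment gives fresh noise on every run:

    def make_env():
        return Environment(exogenous_noise, N=2000)   # seed=None -> new draw every call

    payload = lambda e: GroundTruth_NDE(0.0, 1.0, X, Y, [M], e, model)
    results = Ensemble(make_env, [payload], runs=100)  # results.shape == (1, 100)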
def _Fct_on_grid(fct, list_of_points, cf_delta=0.5, normalize_by_delta=False, **kwargs):
    """Helper to evaluate 'fct' on a grid of points"""

    result = []
    for pt in list_of_points:
        result.append(fct(pt, pt + cf_delta, **kwargs))

    if normalize_by_delta:
        return np.array(result) / cf_delta
    else:
        return np.array(result)


def _Fct_smoothed(fct, min_x, max_x, cf_delta=0.5, steps=100, smoothing_gaussian_sigma_in_steps=5,
                  normalize_by_delta=False, boundary_effects="extend range", **kwargs):
    """Helper to evaluate 'fct' on a grid of points with subsequent Gauß-smoothing"""

    stepsize = (max_x - min_x) / steps
    # Extend the window to run on (numpy would extend by zeros if mode="same")
    if boundary_effects == "extend range":
        steps += 6 * smoothing_gaussian_sigma_in_steps - 1
        min_x -= 3 * smoothing_gaussian_sigma_in_steps * stepsize
    else:
        raise Exception("Currently the only smoothing mode for boundary-effects is 'extend range'")

    x_values = []  # np.arange with floats is unstable wrt len of output
    result = []
    for i in range(steps + 1):
        pt = stepsize * i + min_x
        x_values.append(pt)
        value = fct(pt, pt + cf_delta, **kwargs)
        result.append(value)

    if normalize_by_delta:
        result = np.array(result) / cf_delta
    else:
        result = np.array(result)

    # cut off convolution-kernel at 3 sigma
    gx = np.arange(-3 * smoothing_gaussian_sigma_in_steps, 3 * smoothing_gaussian_sigma_in_steps)
    gaussian = (np.exp(-(gx / smoothing_gaussian_sigma_in_steps) ** 2 / 2)
                / np.sqrt(2.0 * np.pi) / smoothing_gaussian_sigma_in_steps)
    if len(np.shape(result)) == 1:  # values = (samples)
        smoothed = np.convolve(result, gaussian, mode="valid")
    elif len(np.shape(result)) == 3:  # densities = (samples, categories, cf/te)
        smoothed = np.empty_like(
            result[3 * smoothing_gaussian_sigma_in_steps:1 - 3 * smoothing_gaussian_sigma_in_steps, :, :])
        for cf_te in range(2):
            for p_category in range(np.shape(result)[1]):
                smoothed[:, p_category, cf_te] = np.convolve(result[:, p_category, cf_te], gaussian, mode="valid")
    else:
        raise Exception(f"Invalid result-shape {result.shape}")

    return np.array(x_values)[3 * smoothing_gaussian_sigma_in_steps:1 - 3 * smoothing_gaussian_sigma_in_steps], smoothed


def PlotInfo(**va_arg_dict):
    """Helper to add info to plot-setup"""
    return va_arg_dict


def PlotAbsProbabilities(plt, target_var, data, labels):
    """Helper to plot Effects on Categorical variables (typically set plt=your-pyplot-module-name)"""

    fig = plt.figure(figsize=[12.0, 4.8], dpi=75.0, layout='constrained')
    fig.suptitle(labels["title"])
    shared_ax = None
    has_printed_legend = False
    for cY in range(target_var.categories):
        if shared_ax is None:
            shared_ax = plt.subplot(131 + cY)  # digits are rows, cols, index+1
        else:
            plt.subplot(131 + cY, sharey=shared_ax)  # digits are rows, cols, index+1
        for d in data:
            plt.plot(d["x"], d["y"][:, cY, 1], color=d["colorTE"], label=d["labelTE"])
            plt.plot(d["x"], d["y"][:, cY, 0], color=d["colorCF"], label=d["labelCF"])
        plt.xlabel(labels["x"])
        plt.ylabel(labels["y"].format(cY=cY))
        if not has_printed_legend:
            has_printed_legend = True
            fig.legend()
    return fig


def PlotChangeInProbabilities(plt, target_var, data, labels):
    """Helper to plot Effects on Categorical variables (typically set plt=your-pyplot-module-name)"""

    fig = plt.figure(figsize=[12.0, 4.8], dpi=75.0, layout='constrained')
    fig.suptitle(labels["title"])
    shared_ax = None
    has_printed_legend = False
    for cY in range(target_var.categories):
        if shared_ax is None:
            shared_ax = plt.subplot(131 + cY)  # digits are rows, cols, index+1
        else:
            plt.subplot(131 + cY, sharey=shared_ax)  # digits are rows, cols, index+1
        for d in data:
            plt.plot(d["x"], d["y"][:, cY, 0] - d["y"][:, cY, 1], color=d["color"], label=d["label"])
        plt.xlabel(labels["x"])
        plt.ylabel(labels["y"].format(cY=cY))
        if not has_printed_legend:
            has_printed_legend = True
            fig.legend()
    return fig


def FindMaxLag(*va_args):
    """[INTERNAL] Finds the max lag in a collection of variables."""
    max_lag = 0
    for var_group in va_args:
        if hasattr(var_group, '__iter__'):
            for var in var_group:
                max_lag = max(max_lag, var.LagValue())
        else:
            var = var_group
            max_lag = max(max_lag, var.LagValue())
    return max_lag


def GroundTruth_NDE_auto(change_from, change_to, estimator, env, model):
    """Ground-Truth computation from toy-model for NDE.

    GroundTruth_*_auto functions extract source, target, blocked-mediators from an estimator.

    Parameters
    ----------
    change_from : single value of same type as single sample for X (float, int, or bool)
        Reference-value to which X is set by intervention in the world seen by the mediator.
    change_to : single value of same type as single sample for X (float, int, or bool)
        Post-intervention-value to which X is set by intervention in the world seen by the effect (directly).
    estimator : NaturalEffects_GraphMediation
        Extract source, target, blocked-mediators from estimator.
    env : Environment
        The environment used.
    model : Model
        The toy-model.

    Returns
    -------
    NDE : If Y is categorical -> np.array( # categories Y, 2 )
        The probabilities of the categories of Y (after, before) changing the interventional value of X
        as "seen" by Y from change_from to change_to, while keeping M as if X remained at change_from.

    NDE : If Y is continuous -> float
        The change in the expectation-value of Y induced by changing the interventional value of X
        as "seen" by Y from change_from to change_to, while keeping M as if X remained at change_from.
    """
    return GroundTruth_NDE(change_from, change_to, estimator.Source, estimator.Target, estimator.BlockedMediators,
                           env, model)


def GroundTruth_NDE(change_from, change_to, source, target, mediators, env, model):
    """Ground-Truth computation from toy-model for NDE.

    Note: GroundTruth_*_auto functions extract source, target, blocked-mediators from an estimator.

    Parameters
    ----------
    change_from : single value of same type as single sample for X (float, int, or bool)
        Reference-value to which X is set by intervention in the world seen by the mediator.
    change_to : single value of same type as single sample for X (float, int, or bool)
        Post-intervention-value to which X is set by intervention in the world seen by the effect (directly).
    source : VariableDescription
        Effect source.
    target : VariableDescription
        Effect target.
    mediators : *iterable* <VariableDescription>
        Blocked mediators.
    env : Environment
        The environment used.
    model : Model
        The toy-model.

    Returns
    -------
    NDE : If Y is categorical -> np.array( # categories Y, 2 )
        The probabilities of the categories of Y (after, before) changing the interventional value of X
        as "seen" by Y from change_from to change_to, while keeping M as if X remained at change_from.

    NDE : If Y is continuous -> float
        The change in the expectation-value of Y induced by changing the interventional value of X
        as "seen" by Y from change_from to change_to, while keeping M as if X remained at change_from.
    """

    if target.LagValue() != 0:
        raise Exception("Do not use lagged targets, it is always possible to shift everything, so that the "
                        "effect is on an unlagged variable.")

    if model.is_timeseries:
        # Generate groundtruth for intervention at a SINGLE point in time

        # Window-size (at each point) must be at least the maximum lag occurring in source, target and mediators
        max_lag_in_interventions = FindMaxLag(source, target, mediators)
        # shouldn't be necessary, but fixes some issues with offsets in stationary data:
        max_lag_in_interventions = max(max_lag_in_interventions, model.max_lag)

        ts = CounterfactualTimeseries(env, model, max_lag_in_interventions=max_lag_in_interventions)

        y_cf = ts.ComputeIntervention(target, [(source, change_to), (source, change_from)], mediators)
        y_real = ts.ComputeIntervention(target, [(source, change_from)], [])

    else:
        # Ground-Truth for non-timeseries is straightforward:

        modelA = model.Intervene(changes={source.Id(): change_from})
        modelB = model.Intervene(changes={source.Id(): change_to})
        worldA = World(env, modelA)
        worldB = World(env, modelB)
        cf_world = CounterfactualWorld(env, model)
        cf_world.TakeVariablesFromWorld(worldA, mediators)
        cf_world.TakeVariablesFromWorld(worldB, source)
        cf_world.Compute()

        y_cf = cf_world.Observables()[target]
        y_real = worldA.Observables()[target]

    if target.is_categorical:
        result = []
        for cY in range(target.categories):
            result.append([np.count_nonzero(y_cf == cY), np.count_nonzero(y_real == cY)])
        return np.array(result) / env.N
    else:
        return np.mean(y_cf - y_real)

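A sketch of a direct call, assuming the mediation setup from the earlier sketches; for a continuous target the result is a scalar change in expectation, for a categorical target an array of shape (categories, 2) holding the (counterfactual, reference) probabilities per category:

    nde = GroundTruth_NDE(change_from=0.0, change_to=1.0,
                          source=X, target=Y, mediators=[M], env=env, model=model)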
def GroundTruth_NDE_fct_auto(x_min, x_max, estimator, env, model, cf_delta=0.5, normalize_by_delta=False,
                             grid_stepping=0.1):
    """See GroundTruth_NDE_auto and NaturalEffects_GraphMediation.NDE_smoothed."""
    source = estimator.Source
    target = estimator.Target
    blocked_mediators = estimator.BlockedMediators
    grid = np.arange(x_min, x_max, grid_stepping)
    return grid, GroundTruth_NDE_fct(source, target, blocked_mediators, env, model, grid, cf_delta, normalize_by_delta)


def GroundTruth_NDE_fct(source, target, mediators, env, model, list_of_points, cf_delta=0.5, normalize_by_delta=False):
    """See GroundTruth_NDE and NaturalEffects_GraphMediation.NDE_smoothed."""
    return _Fct_on_grid(GroundTruth_NDE, list_of_points, cf_delta, normalize_by_delta,
                        source=source, target=target, mediators=mediators, env=env, model=model)


def GroundTruth_NIE_fct(source, target, mediator, env, model, list_of_points, cf_delta=0.5, normalize_by_delta=False):
    """See GroundTruth_NIE and NaturalEffects_StandardMediation.NIE_smoothed."""
    return _Fct_on_grid(GroundTruth_NIE, list_of_points, cf_delta, normalize_by_delta,
                        source=source, target=target, mediator=mediator, env=env, model=model)


def GroundTruth_NIE(change_from, change_to, source, target, mediator, env, model):
    """Ground-Truth computation from toy-model for NIE.

    Standard-mediation setup only.

    Parameters
    ----------
    change_from : single value of same type as single sample for X (float, int, or bool)
        Reference-value to which X is set by intervention in the world seen by the mediator.
    change_to : single value of same type as single sample for X (float, int, or bool)
        Post-intervention-value to which X is set by intervention in the world seen by the effect (directly).
    source : VariableDescription
        Effect source.
    target : VariableDescription
        Effect target.
    mediator : VariableDescription
        Effect mediator.
    env : Environment
        The environment used.
    model : Model
        The toy-model.

    Returns
    -------
    NIE : If Y is categorical -> np.array( # categories Y, 2 )
        The probabilities of the categories of Y (after, before) changing the interventional value of X
        as "seen" by M from change_from to change_to, while keeping the value as (directly) seen by Y,
        as if X remained at change_from.

    NIE : If Y is continuous -> float
        The change in the expectation-value of Y induced by changing the interventional value of X
        as "seen" by M from change_from to change_to, while keeping the value as (directly) seen by Y,
        as if X remained at change_from.
    """

    modelA = model.Intervene(changes={source: change_from})
    modelB = model.Intervene(changes={source: change_to})
    worldA = World(env, modelA)
    worldB = World(env, modelB)
    cf_world = CounterfactualWorld(env, model)
    # this time, take the mediator from the B-world
    cf_world.TakeVariablesFromWorld(worldA, [source])
    cf_world.TakeVariablesFromWorld(worldB, [mediator])
    cf_world.Compute()

    y_cf = cf_world.Observables()[target]
    y_real = worldA.Observables()[target]

    if target.is_categorical:
        result = []
        for cY in range(target.categories):
            result.append([np.count_nonzero(y_cf == cY), np.count_nonzero(y_real == cY)])
        return np.array(result) / env.N
    else:
        return np.mean(y_cf - y_real)


class DataHandler:
    """[INTERNAL] Implement some helper functions and generate time-series data via tigramite's data-frames."""

    def __init__(self, observables, dataframe_based_preprocessing=True):
        if isinstance(observables, DataFrame):
            self._from_dataframe = True
            self._data = VariablesFromDataframe(observables)
            self._dataframe = observables
            self.use_dataframe_for_preprocessing = dataframe_based_preprocessing
            self.data_selection_virtual_ids = {}
            self._info = {}
            self.preprocessed_data = False
            # filter all data for all fits for missing values, lock after first data-access
            # (see GetVariableAuto and operator [] / __getitem__)
            self.data_selection_frozen = False
        else:
            self._from_dataframe = False
            self._data = observables
            self._dataframe = DataframeFromVariables(observables)
            self.use_dataframe_for_preprocessing = False

        self._indices = {}
        self._keys = list(self._data.keys())
        idx = 0
        for var in self._data.keys():
            self._indices[var] = idx
            idx += 1

    def GetIdFor(self, idx, lag, display_name):
        # self._keys[idx] contains info (e.g. categorical or continuous, 'original name') for variable idx;
        # attach the display name (e.g. Mediator) and add the 'original name' (e.g. 'temperature')
        # in square brackets as well as the lag
        base_id = self._keys[idx]
        return base_id.CopyInfo(display_name + f"[{base_id.Info(detail=1)} at lag {lag}]")

    def RequirePreprocessingFor(self, var, info):
        if self.use_dataframe_for_preprocessing:
            if var not in self.data_selection_virtual_ids:
                # already-registered variables remain OK after freezing
                if self.data_selection_frozen:
                    raise Exception("Cannot add variables to preprocessing after variable-set has been locked-in")
                assert info is not None
                var_id = self.GetIdFor(var[0], var[1], info)
                self.data_selection_virtual_ids[var] = var_id
                self._indices[var_id] = var
                self._info[var_id] = info

    def GetVariableAuto(self, var, info="Other"):
        if self._from_dataframe:
            return self.ReverseLookupSingle(var, info)
        else:
            return var

    def GetVariablesAuto(self, vars, info="Other"):
        result = []
        for var in vars:
            result.append(self.GetVariableAuto(var, info))
        return result

    def DataFrame(self):
        return self._dataframe

    def __getitem__(self, key):
        if hasattr(key, '__iter__'):
            return [self[entry] for entry in key]

        if self.use_dataframe_for_preprocessing:
            return self._indices[key]

        if key.__class__ == LaggedVariable:
            return self._indices[key.var], -key.lag
        else:
            return self._indices[key], 0

    def _PreprocessData(self, **kwargs):
        self.data_selection_frozen = True
        self.preprocessed_data = True
        X = []
        Y = []
        Z = []
        M = []
        for var_lag_index, var_id in self.data_selection_virtual_ids.items():
            info = self._info[var_id]
            if info == "Source":
                X.append(var_lag_index)
            elif info == "Target":
                Y.append(var_lag_index)
            elif info == "Mediator":
                M.append(var_lag_index)
            elif info == "Adjustment":
                Z.append(var_lag_index)
            else:
                raise Exception("Unknown Variable-Interpretation")
        data_preprocessed, xyz, data_type = self.DataFrame().construct_array(X=X, Y=Y, Z=Z, extraZ=M, **kwargs)  # kw-args forward e.g. tau-max
        self._data = {}
        i = 0
        for x in X:
            self._data[self.data_selection_virtual_ids[x]] = data_preprocessed[i]
            assert xyz[i] == 0
            i += 1
        for y in Y:
            self._data[self.data_selection_virtual_ids[y]] = data_preprocessed[i]
            assert xyz[i] == 1
            i += 1
        for z in Z:
            self._data[self.data_selection_virtual_ids[z]] = data_preprocessed[i]
            assert xyz[i] == 2
            i += 1
        for m in M:
            self._data[self.data_selection_virtual_ids[m]] = data_preprocessed[i]
            assert xyz[i] == 3
            i += 1
        assert i == data_preprocessed.shape[0]

    def GetPreprocessed(self, vars, **kwargs):
        if not self.preprocessed_data:
            self._PreprocessData(**kwargs)
        result = {}
        for var in vars:
            assert var in self.data_selection_virtual_ids
            var_id = self.data_selection_virtual_ids[var]
            result[var_id] = self._data[var_id]
        return list(result.keys()), result

    def Get(self, name, vars, **kwargs):
        if self.use_dataframe_for_preprocessing:
            return self.GetPreprocessed(vars, **kwargs)
        ids = []
        i = 1
        for idx, lag in vars:
            # assemble names for this variable-group, e.g. "Mediator5"
            the_name = name if len(vars) == 1 else name + str(i)
            ids.append(self.GetIdFor(idx, lag, the_name))
            i += 1
        data, xyz, data_type = self.DataFrame().construct_array(X=vars, Y=[], Z=[], **kwargs)
        assert data.shape[0] == len(ids)
        result = {}
        i = 0
        for elem in ids:
            result[elem] = data[i].astype(dtype=np.dtype(elem.dtype))
            i += 1
        return ids, result

    def ReverseLookupSingle(self, index, info):
        if self.use_dataframe_for_preprocessing:
            self.RequirePreprocessingFor(index, info)
            return self.data_selection_virtual_ids[index]
        else:
            return self._keys[index[0]].Lag(-index[1])

    def ReverseLookupMulti(self, index_set, info=None):
        return [self.ReverseLookupSingle(index, info) for index in index_set]


def VariablesFromDataframe(dataframe):
    """Extract Category-Information from tigramite::dataframe

    Parameters
    ----------
    dataframe : tigramite.data_processing.DataFrame
        Dataframe to extract data from.

    Returns
    -------
    variables : dictionary< VariableDescription, np.array(N) >
        The variable-meta-data and data.
    """

    data_types = dataframe.data_type
    if dataframe.data_type is not None:
        # Require data-types to be constant
        first_elements = data_types[0, :]
        if not np.all(first_elements == data_types):
            raise NotImplementedError("Natural Effect Framework currently only supports per-variable "
                                      "data-types, this dataframe contains variables with changing "
                                      "(over time) data-types.")

    result = {}  # var i will be in result.items()[i]
    data = dataframe.values[0]
    node_count = np.shape(data)[1]
    for node in range(node_count):
        if data_types is None or data_types[0, node] == 0:
            var = ContinuousVariable(name=str(dataframe.var_names[node]))
            result[var] = data[:, node]
        else:
            labels, transformed = np.unique(data[:, node], return_inverse=True)
            var = CategoricalVariable(name=str(dataframe.var_names[node]), categories=labels)
            result[var] = transformed
    return result


def DataframeFromVariables(data_dict):
    """Convert Category-Information to tigramite::dataframe

    Parameters
    ----------
    data_dict : dictionary< VariableDescription, np.array(N) >
        The variable-meta-data and data.

    Returns
    -------
    dataframe : tigramite.data_processing.DataFrame
        Dataframe containing the raw data.
    """

    data_len = 0
    keys = list(data_dict.keys())
    values = list(data_dict.values())
    if keys[0].is_categorical or keys[0].dimension == 1:
        data_len = np.shape(values[0])[0]
    else:
        data_len = np.shape(values[0])[1]
    dimensions = 0
    for var, data in data_dict.items():
        if var.is_categorical:
            dimensions += 1
        else:
            dimensions += var.dimension

    data_out = np.zeros([data_len, dimensions])
    data_type = np.zeros([data_len, dimensions])
    names = []
    i = 0
    for var, data in data_dict.items():
        if var.is_categorical:
            names.append(var.name)
            if var.category_values is not None:
                data_out[:, i] = np.take(var.category_values, data)
            else:
                data_out[:, i] = data
            data_type[:, i] = np.ones(data_len)
            i += 1
        else:
            if var.dimension > 1:
                for j in range(var.dimension):
                    if var.dimension == 1:
                        names.append(var.name)
                    else:
                        names.append(var.name + "$_" + str(j) + "$")
                    data_out[:, i + j] = data[j]
                    data_type[:, i + j] = np.zeros(data_len)
            else:
                names.append(var.name)
                data_out[:, i] = data
                data_type[:, i] = np.zeros(data_len)
            i += var.dimension
    return DataFrame(data=data_out, data_type=data_type, var_names=names)
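A round-trip sketch of the two conversion helpers, assuming a World built from the earlier model/environment sketches; continuous columns carry data_type 0 and categorical columns data_type 1 in the resulting DataFrame:

    world = World(env, model)
    df = DataframeFromVariables(world.Observables())
    variables = VariablesFromDataframe(df)   # {VariableDescription: np.array(N)}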