GLDF 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- GLDF/__init__.py +2 -0
- GLDF/bridges/__init__.py +0 -0
- GLDF/bridges/causal_learn.py +185 -0
- GLDF/bridges/tigramite.py +143 -0
- GLDF/bridges/tigramite_plotting_modified.py +4764 -0
- GLDF/cit.py +274 -0
- GLDF/data_management.py +588 -0
- GLDF/data_processing.py +754 -0
- GLDF/frontend.py +537 -0
- GLDF/hccd.py +403 -0
- GLDF/hyperparams.py +205 -0
- GLDF/independence_atoms.py +78 -0
- GLDF/state_space_construction.py +288 -0
- GLDF/tutorials/01_preconfigured_quickstart.ipynb +302 -0
- GLDF/tutorials/02_detailed_configuration.ipynb +394 -0
- GLDF/tutorials/03_custom_patterns.ipynb +447 -0
- gldf-0.9.0.dist-info/METADATA +101 -0
- gldf-0.9.0.dist-info/RECORD +20 -0
- gldf-0.9.0.dist-info/WHEEL +4 -0
- gldf-0.9.0.dist-info/licenses/LICENSE +621 -0
GLDF/data_management.py
ADDED
|
@@ -0,0 +1,588 @@
|
|
|
1
|
+
from dataclasses import dataclass
|
|
2
|
+
import numpy as np
|
|
3
|
+
from typing import TypeVar, Generic
|
|
4
|
+
|
|
5
|
+
# For python >= 3.12, there is a new simplified syntax for generics.
# Unfortunately this new syntax is not backwards-compatible
# (it does not parse to a valid expression for older versions ...),
# so the classic TypeVar/Generic spelling is used throughout this module.
var_index = TypeVar('var_index')  # , default=int) 3.13 # , SupportsRichComparisonT, hashable)


class CI_Identifier(Generic[var_index]):
    """
    A multi-index defining a conditional independence-statement

    .. note::

        The index type :py:obj:`var_index` must be comparable (totally ordered) and hashable, eg int or tuple of int,
        type-annotations are given for the type-var var_index; for time-series, see also :py:class:`CI_Identifier_TimeSeries`.
    """
    # Class-level alias of the module TypeVar, so ``CI_Identifier.var_index`` remains available.
    # Deliberately NOT a PEP 695 ``type`` statement: that syntax fails to parse on Python < 3.12.
    var_index = var_index

    def __init__(self, idx_x: var_index, idx_y: var_index, idx_list_z: list[var_index]):
        """Construct representation of CIT for single-variables X, Y and a list of indices for the variables in
        the set of conditions Z. The representation is undirected and for a (logical) set Z:
        Reordering X and Y or changing the order in the list of indices for Z will result in the same
        representation, see also :py:meth:`__hash__` and :py:meth:`__eq__`.

        :param idx_x: Index of variable X
        :type idx_x: var_index
        :param idx_y: Index of variable Y
        :type idx_y: var_index
        :param idx_list_z: List of indices of variables in set of conditions Z.
        :type idx_list_z: list[var_index]
        """
        # X/Y are stored in canonical (sorted) order, which makes the identifier undirected.
        self.idx_x = min(idx_x, idx_y)
        self.idx_y = max(idx_x, idx_y)
        # Z is stored as a freshly sorted list, which makes the order of conditions irrelevant.
        self.idx_list_z = sorted(idx_list_z)

    def undirected_link(self) -> tuple[var_index, var_index]:
        """Get the associated (undirected) link as a tuple.

        :return: The associated undirected link (smaller index first).
        :rtype: tuple[var_index,var_index]
        """
        return self.idx_x, self.idx_y

    def _as_tuple(self) -> tuple[var_index, var_index, tuple[var_index, ...]]:
        """Transcode to a tuple-representation. Used for hashing and comparison-operations.

        :return: ``(idx_x, idx_y, tuple(idx_list_z))`` in canonical (sorted) order.
        :rtype: tuple
        """
        return self.idx_x, self.idx_y, tuple(self.idx_list_z)

    def __hash__(self) -> int:
        """Hash for unordered containers (dict, set).

        :return: A hash value consistent with :py:meth:`__eq__`.
        :rtype: int
        """
        return hash(self._as_tuple())

    def __eq__(self, other: object) -> bool:
        """Equality compare two CI-identifiers as undirected (X and Y are exchangeable) test
        with a set (order does not matter) of conditions.

        Comparing against an object that is not a :py:class:`CI_Identifier` returns
        ``NotImplemented``, so e.g. ``ci == None`` evaluates to ``False`` instead of raising.

        :param other: Other CI-Identifier to compare to.
        :type other: CI_Identifier
        :return: Equality
        :rtype: bool
        """
        if not isinstance(other, CI_Identifier):
            return NotImplemented
        return self._as_tuple() == other._as_tuple()

    def z_dim(self) -> int:
        """Get dimension of (number of variables in) conditioning set Z.

        :return: dim(Z)
        :rtype: int
        """
        return len(self.idx_list_z)

    def conditioning_set(self) -> set[var_index]:
        """Get the conditioning set Z (as set).

        :return: Z
        :rtype: set[var_index]
        """
        return set(self.idx_list_z)
|
|
88
|
+
|
|
89
|
+
|
|
90
|
+
class CI_Identifier_TimeSeries(CI_Identifier[tuple[int, int]]):
    """
    A multi-index defining a conditional independence-statement for time-series.

    Nodes follow tigramite's indexing convention: each node is the pair
    ``(variable index, -lag)``.
    """

    def max_abs_timelag(self) -> int:
        """Get the largest absolute time-lag among all nodes of the test.

        Because nodes store ``-lag`` in their second entry, this is the maximum of the
        negated second entries over X, Y and every member of Z.

        :return: Maximum absolute time-lag.
        :rtype: int
        """
        negated_lags = [-self.idx_x[1], -self.idx_y[1]]
        negated_lags.extend(-z_lag for _, z_lag in self.idx_list_z)
        return max(negated_lags)
|
|
109
|
+
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
@dataclass
|
|
116
|
+
class BlockView:
|
|
117
|
+
"""
|
|
118
|
+
View data as pattern-aligned blocks.
|
|
119
|
+
"""
|
|
120
|
+
|
|
121
|
+
pattern_provider: 'CIT_DataPatterned' #: Pattern-provider used to generate this view. Primarily for internal use in convenience-functions like :py:meth:`match_blocksize`.
|
|
122
|
+
cache_id : object|None #: unique identifier associated to the data by the data-manager, to be used for caching results of tests. None to disable caching (eg for bootstrap)
|
|
123
|
+
x_blocks: np.ndarray #: shape=(n,B) with n the block-count, B the block-size
|
|
124
|
+
y_blocks: np.ndarray #: shape=(n,B) with n the block-count, B the block-size
|
|
125
|
+
z_blocks: np.ndarray #: shape=(n,B,k) with n the block-count, B the block-size, k=dim(Z)
|
|
126
|
+
|
|
127
|
+
def copy_and_center(self) -> 'BlockView':
|
|
128
|
+
"""Copy and subtract mean. (Used internally by cit to structure
|
|
129
|
+
residuals correctly, no cach-id required.)
|
|
130
|
+
|
|
131
|
+
:return: Centered copy
|
|
132
|
+
:rtype: BlockView
|
|
133
|
+
"""
|
|
134
|
+
return BlockView(
|
|
135
|
+
self.pattern_provider,
|
|
136
|
+
None,
|
|
137
|
+
self.x_blocks-np.mean(self.x_blocks, axis=1).reshape(-1,1),
|
|
138
|
+
self.y_blocks-np.mean(self.y_blocks, axis=1).reshape(-1,1),
|
|
139
|
+
self.z_blocks-np.mean(self.z_blocks, axis=1).reshape(self.block_count(),1,self.z_dim()) if self.z_blocks is not None else None
|
|
140
|
+
)
|
|
141
|
+
|
|
142
|
+
def block_size(self) -> int:
|
|
143
|
+
"""Get block-size (total number of samples per block).
|
|
144
|
+
|
|
145
|
+
:return: block-size
|
|
146
|
+
:rtype: int
|
|
147
|
+
"""
|
|
148
|
+
return self.x_blocks.shape[1]
|
|
149
|
+
|
|
150
|
+
def block_count(self) -> int:
|
|
151
|
+
"""Get block-count.
|
|
152
|
+
|
|
153
|
+
:return: block-count
|
|
154
|
+
:rtype: int
|
|
155
|
+
"""
|
|
156
|
+
return self.x_blocks.shape[0]
|
|
157
|
+
|
|
158
|
+
def sample_count_used(self) -> int:
|
|
159
|
+
"""Get number of used (contained in a block) data-points.
|
|
160
|
+
|
|
161
|
+
:return: used sample-size
|
|
162
|
+
:rtype: int
|
|
163
|
+
"""
|
|
164
|
+
return self.block_size() * self.block_count()
|
|
165
|
+
|
|
166
|
+
def z_dim(self) -> int:
|
|
167
|
+
"""Get z-dimension (number of variables in the conditioning set Z).
|
|
168
|
+
|
|
169
|
+
:return: dim(Z)
|
|
170
|
+
:rtype: int
|
|
171
|
+
"""
|
|
172
|
+
if self.z_blocks is None:
|
|
173
|
+
return 0
|
|
174
|
+
else:
|
|
175
|
+
return self.z_blocks.shape[2]
|
|
176
|
+
|
|
177
|
+
|
|
178
|
+
def trivialize(self) -> 'BlockView':
|
|
179
|
+
"""View data as trivial (a single block of block-size=sampe-size) blocks.
|
|
180
|
+
|
|
181
|
+
:return: trivial block-view
|
|
182
|
+
:rtype: BlockView
|
|
183
|
+
"""
|
|
184
|
+
return self.pattern_provider.view_blocks_trivial()
|
|
185
|
+
|
|
186
|
+
def match_blocksize(self, other: 'BlockView') -> 'BlockView':
|
|
187
|
+
"""View data as blocks with size matching another block-view.
|
|
188
|
+
|
|
189
|
+
:param other: block-view whose block-size should be matched
|
|
190
|
+
:type other: BlockView
|
|
191
|
+
:return: block-view of matching size
|
|
192
|
+
:rtype: BlockView
|
|
193
|
+
"""
|
|
194
|
+
return self.pattern_provider.view_blocks_match(other)
|
|
195
|
+
|
|
196
|
+
def apply_blockformat(self, X: np.ndarray, Y: np.ndarray, Z: np.ndarray|None=None) -> 'BlockView':
|
|
197
|
+
"""Given data for X, Y, Z, match to current block-size settings. (Used internally by cit to structure
|
|
198
|
+
residuals correctly, no cach-id required.)
|
|
199
|
+
|
|
200
|
+
:param X: X
|
|
201
|
+
:type X: np.ndarray
|
|
202
|
+
:param Y: Y
|
|
203
|
+
:type Y: np.ndarray
|
|
204
|
+
:param Z: Z, defaults to None
|
|
205
|
+
:type Z: np.ndarray | None, optional
|
|
206
|
+
:return: block-view of given data matching current size
|
|
207
|
+
:rtype: BlockView
|
|
208
|
+
"""
|
|
209
|
+
return self.pattern_provider.clone_from_data(X, Y, Z).view_blocks_match(self)
|
|
210
|
+
|
|
211
|
+
|
|
212
|
+
def bootstrap_unaligned_blocks(self, rng: np.random.Generator, bootstrap_block_count: int) -> 'BlockView':
|
|
213
|
+
"""Bootstrap random (unaligned) blocks.
|
|
214
|
+
|
|
215
|
+
:param rng: A random number generator.
|
|
216
|
+
:type rng: np.random.Generator (or similar, requires .integers behaving as numpy.random.Generator)
|
|
217
|
+
:param bootstrap_block_count: Number of blocks to bootstrap
|
|
218
|
+
:type bootstrap_block_count: int
|
|
219
|
+
:return: The bootstrapped blocks (actually a copy, not a view, despite being typed as block-"view").
|
|
220
|
+
:rtype: BlockView
|
|
221
|
+
"""
|
|
222
|
+
return self.pattern_provider.bootstrap_unaligned_blocks(rng, bootstrap_block_count, self.block_size())
|
|
223
|
+
|
|
224
|
+
@dataclass
|
|
225
|
+
class CIT_Data:
|
|
226
|
+
"""
|
|
227
|
+
Data for CIT.
|
|
228
|
+
|
|
229
|
+
.. seealso::
|
|
230
|
+
Used through :py:class:`CIT_DataPatterned`.
|
|
231
|
+
"""
|
|
232
|
+
|
|
233
|
+
x_data: np.ndarray #: Shape specified by data-manager/pattern-provider. See :py:class:`CIT_DataPatterned`.
|
|
234
|
+
y_data: np.ndarray #: Same shape as x_data.
|
|
235
|
+
z_data: np.ndarray #: Shape=(shape_xy,k), where shape_xy is the shape of x_data/y_data and k=dim(Z) is the size of the conditioning set.
|
|
236
|
+
cache_id : tuple|None #: Unique identifier associated to the data by the data-manager, to be used for caching results. None to disable caching.
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
|
|
240
|
+
class CIT_DataPatterned(CIT_Data):
    """
    Patterned data for mCIT.

    .. seealso::
        Pattern-related aspects are to be overwritten by custom pattern providers,
        for example :py:class:`CIT_DataPatterned_PersistentInTime` or
        :py:class:`CIT_DataPatterned_PesistentInSpace`.
    """

    def view_blocks(self, block_size: int) -> BlockView:
        """View as blocks of given size. The layout of blocks encodes the (prior) knowledge about patterns.

        :param block_size: requested block-size (the block-size of the result may not exactly match this number,
            if the underlying pattern provider cannot construct arbitrary block-sizes).
        :type block_size: int
        :return: view as pattern-aligned blocks
        :rtype: BlockView
        """
        raise NotImplementedError()

    @staticmethod
    def get_actual_block_format(requested_size: int) -> int | tuple[int, ...]:
        """Get the actual (possibly multi-dimensional) format of blocks produced. Used for plotting.

        :param requested_size: The size of blocks requested.
        :type requested_size: int
        :return: Format of blocks produced.
        :rtype: int|tuple[int,...]
        """
        raise NotImplementedError()

    @staticmethod
    def reproject_blocks(value_per_block: np.ndarray, block_configuration: BlockView, data_configuration: tuple[int, ...]) -> np.ndarray:
        """Reproject a function :math:`f` on blocks to the original index-set layout (for example time, space etc). Used for plotting.

        :param value_per_block: values of :math:`f` for each block
        :type value_per_block: np.ndarray
        :param block_configuration: the block-configuration (eg block-size) used
        :type block_configuration: BlockView
        :param data_configuration: the data-shape (per-variable) in the original data
        :type data_configuration: tuple[int,...]
        :return: plottable layout of :math:`f` as function of the original index-space
        :rtype: np.ndarray
        """
        raise NotImplementedError()

    def sample_count(self) -> int:
        """Get sample size.

        :return: sample-size N (total number of entries of x_data)
        :rtype: int
        """
        return self.x_data.size

    def z_dim(self) -> int:
        """Get dimension (number of variables) of conditioning set Z.

        :return: dim(Z), 0 if z_data is None
        :rtype: int
        """
        if self.z_data is None:
            return 0
        return self.z_data.shape[-1]

    def copy_and_center(self) -> 'CIT_DataPatterned':
        """Copy and subtract mean.

        X and Y are centered by their overall mean. Each Z-variable is centered by its own
        mean over the sample axes (the axes of x_data). The copy is created through
        :py:meth:`clone_from_data`.

        :return: centered copy
        :rtype: CIT_DataPatterned
        """
        z_centered = None
        # Previously this tested the nonexistent attribute ``z_blocks`` and always raised AttributeError.
        if self.z_data is not None:
            # Average over the sample axes only; the trailing axis enumerates the Z-variables.
            sample_axes = tuple(range(self.x_data.ndim))
            z_centered = self.z_data - np.mean(self.z_data, axis=sample_axes, keepdims=True)
        return self.clone_from_data(
            self.x_data - np.mean(self.x_data),
            self.y_data - np.mean(self.y_data),
            z_centered,
        )

    def view_blocks_trivial(self) -> BlockView:
        """View as trivial blocks (a single block of size N).

        :return: view by trivial blocks
        :rtype: BlockView
        """
        return BlockView(
            pattern_provider=self,
            # -1 marks the trivial block layout within the cache-id.
            cache_id=None if self.cache_id is None else (*self.cache_id, -1),
            x_blocks=self.x_data.reshape((1, -1)),
            y_blocks=self.y_data.reshape((1, -1)),
            z_blocks=self.z_data.reshape((1, -1, self.z_dim())) if self.z_dim() > 0 else None,
        )

    def view_blocks_match(self, other: BlockView) -> BlockView:
        """View as blocks matching configuration (block-sizes etc) of another block-view.

        :param other: The other block-view, whose configuration should be copied.
        :type other: BlockView
        :return: A block-view of the data with same configuration as :param other:.
        :rtype: BlockView
        """
        return self.view_blocks(other.block_size())

    def bootstrap_unaligned_blocks(self, rng: np.random.Generator, bootstrap_block_count: int, block_size: int) -> BlockView:
        """Bootstrap random (unaligned) blocks.

        Samples are drawn with replacement from the flattened data. The same sample indices
        are used for X, Y and Z.

        :param rng: A random number generator.
        :type rng: np.random.Generator (or similar, requires .integers of numpy.random.Generator)
        :param bootstrap_block_count: Number of blocks to bootstrap
        :type bootstrap_block_count: int
        :param block_size: Size per block
        :type block_size: int
        :return: The bootstrapped blocks (actually a copy, not a view), without cache-id.
        :rtype: BlockView
        """
        indices = rng.integers(0, self.sample_count(), (bootstrap_block_count, block_size))
        # Empty Z has no data to resample (z_data may be None), matching view_blocks_trivial.
        z_boot = self.z_data.reshape(-1, self.z_dim())[indices, :] if self.z_dim() > 0 else None
        return BlockView(self, None, self.x_data.reshape(-1)[indices], self.y_data.reshape(-1)[indices], z_boot)

    @classmethod
    def clone_from_data(cls, X: np.ndarray, Y: np.ndarray, Z: np.ndarray | None = None) -> 'CIT_DataPatterned':
        """Attach the currently used pattern-provider to given data.

        :param X: X-data
        :type X: np.ndarray
        :param Y: Y-data
        :type Y: np.ndarray
        :param Z: Z-data, or None if Z is empty
        :type Z: np.ndarray | None
        :return: Patterned data (without cache-id).
        :rtype: decltype(self), a type derived from :py:class:`CIT_DataPatterned`
        """
        return cls(X, Y, Z, cache_id=None)
|
|
375
|
+
|
|
376
|
+
|
|
377
|
+
|
|
378
|
+
|
|
379
|
+
class CIT_DataPatterned_PersistentInTime(CIT_DataPatterned):
    """
    Patterned data for mCIT. The implemented pattern captures persistent regimes in a single (eg time) direction.

    | x_data has shape=(N), where N is sample-size
    | y_data has shape=(N), where N is sample-size
    | z_data has shape=(N,k), where N is sample-size and k=dim(Z)

    .. seealso::
        See :ref:`overview on custom patterns <label-patterns>`.
        Methods are specified and documented on :py:class:`CIT_DataPatterned`.
    """

    def view_blocks(self, block_size: int) -> BlockView:
        # Consecutive samples are grouped into full blocks. A trailing remainder
        # shorter than block_size is left out of the view.
        n_blocks = int(self.sample_count() / block_size)
        n_used = block_size * n_blocks
        grid = (n_blocks, block_size)

        if self.cache_id is None:
            view_id = None
        else:
            view_id = (*self.cache_id, block_size)
        x_view = self.x_data[:n_used].reshape(grid)
        y_view = self.y_data[:n_used].reshape(grid)
        if self.z_dim() > 0:
            z_view = self.z_data[:n_used, :].reshape((*grid, -1))
        else:
            z_view = None

        return BlockView(pattern_provider=self, cache_id=view_id, x_blocks=x_view, y_blocks=y_view, z_blocks=z_view)

    @staticmethod
    def get_actual_block_format(requested_size: int) -> int:
        # Blocks are one-dimensional, so the requested size is produced unchanged.
        return requested_size

    @staticmethod
    def reproject_blocks(value_per_block: np.ndarray, block_configuration: BlockView, data_configuration: tuple[int, ...]) -> np.ndarray:
        # Each block value is repeated once per sample inside that block; data_configuration is unused here.
        samples_per_block = block_configuration.block_size()
        return value_per_block.repeat(samples_per_block)
|
|
412
|
+
|
|
413
|
+
|
|
414
|
+
class CIT_DataPatterned_PesistentInSpace(CIT_DataPatterned):
    """
    Patterned data for mCIT. The implemented pattern captures persistent regimes in two (eg spatial) directions.

    | x_data has shape=(w,h), where w, h are the width and height of the sample-grid.
    | y_data has shape=(w,h), where w, h are the width and height of the sample-grid.
    | z_data has shape=(w,h,k), where w, h are the width and height of the sample-grid, and k=dim(Z)

    .. seealso::
        See :ref:`overview on custom patterns <label-patterns>`.
        Methods are specified and documented on :py:class:`CIT_DataPatterned`.
    """

    def _get_full_size(self) -> tuple[int, int]:
        # The grid extent is the shape of the X-data, i.e. (w, h).
        return self.x_data.shape

    @staticmethod
    def get_actual_block_format(requested_size: int) -> tuple[int, int]:
        # Square tiles whose side is ceil(sqrt(requested_size)); the small epsilon keeps
        # exact squares (e.g. 16 -> 4) from being rounded up by floating-point noise.
        side = int(np.ceil(np.sqrt(requested_size) - 0.001))
        return (side, side)

    def view_blocks(self, block_size: int) -> BlockView:
        tile_shape = self.get_actual_block_format(block_size)
        tile_area = np.prod(tile_shape)
        tiles_per_axis = [int(v) for v in np.divide(self._get_full_size(), tile_shape)]
        tile_total = np.prod(tiles_per_axis)
        covered = np.multiply(tiles_per_axis, tile_shape)

        def to_tiles(grid):
            # Crop to whole tiles, split each spatial axis into (tile index, offset in tile),
            # bring both tile indices to the front, then flatten tile index and in-tile offset.
            # Trailing axes (the Z-variables, if present) are carried along untouched.
            trailing = grid.shape[2:]
            split = grid[:covered[0], :covered[1]].reshape(
                (tiles_per_axis[0], tile_shape[0], tiles_per_axis[1], tile_shape[1], *trailing))
            order = (0, 2, 1, 3, *range(4, 4 + len(trailing)))
            return split.transpose(order).reshape((tile_total, tile_area, *trailing))

        if self.z_dim() > 0:
            z_view = to_tiles(self.z_data)
        else:
            z_view = None

        return BlockView(
            pattern_provider=self,
            cache_id=None if self.cache_id is None else (*self.cache_id, tile_shape),
            x_blocks=to_tiles(self.x_data),
            y_blocks=to_tiles(self.y_data),
            z_blocks=z_view,
        )

    @staticmethod
    def reproject_blocks(value_per_block: np.ndarray, block_configuration: BlockView, data_configuration: tuple[int, ...]) -> np.ndarray:
        tile_shape = block_configuration.pattern_provider.get_actual_block_format(block_configuration.block_size())
        tiles_per_axis = np.floor(np.asarray(data_configuration) / np.asarray(tile_shape)).astype(int)
        # Lay the per-tile values out on the tile grid, then stretch every axis by the tile extent.
        image = value_per_block.reshape(tiles_per_axis)
        for axis, extent in enumerate(tile_shape):
            image = image.repeat(extent, axis=axis)
        return image
|
|
464
|
+
|
|
465
|
+
|
|
466
|
+
|
|
467
|
+
|
|
468
|
+
|
|
469
|
+
|
|
470
|
+
|
|
471
|
+
class IManageData:
    """Specification of the data-manager interface. Implement this to provide a custom data-manager.

    .. note::
        The default :py:meth:`reproject_blocks` reads ``self.pattern`` and ``self._data``,
        which this interface itself does not define; implementations are expected to
        provide both (as the numpy-array data-managers in this module do).
    """

    def get_patterned_data(self, ci: CI_Identifier) -> CIT_DataPatterned:
        """Get CIT-data with attached pattern-information.

        .. seealso::
            Details on patterns are provided at :ref:`label-patterns`.
            Details on cache-IDs are given at :ref:`label-cache-ids`.

        :param ci: The CI identified by its variable indices.
        :type ci: CI_Identifier
        :return: The CIT-data with attached pattern-provider.
        :rtype: CIT_DataPatterned
        """
        raise NotImplementedError()

    def number_of_variables(self) -> int:
        """Get the number of variables (as used e.g. by PCMCI) in the current data-set.

        :return: Number of (contemporaneous) variables.
        :rtype: int
        """
        raise NotImplementedError()

    def total_sample_size(self) -> int:
        """Get the total sample-size.

        :return: sample-size
        :rtype: int
        """
        raise NotImplementedError()

    def reproject_blocks(self, value_per_block: np.ndarray, block_configuration: BlockView) -> np.ndarray:
        """Project function-values given on blocks back to original data-layout for plotting.

        :param value_per_block: function-values taken on blocks
        :type value_per_block: np.ndarray
        :param block_configuration: the block-layout (e.g. block-size)
        :type block_configuration: BlockView
        :return: the function-values taken in the original index-space.
        :rtype: np.ndarray
        """
        # The stored data's trailing axis enumerates variables; the rest is the per-variable sample layout.
        per_variable_layout = self._data.shape[:-1]
        return self.pattern.reproject_blocks(value_per_block, block_configuration, per_variable_layout)
|
|
516
|
+
|
|
517
|
+
|
|
518
|
+
|
|
519
|
+
|
|
520
|
+
|
|
521
|
+
class DataManager_NumpyArray_IID(IManageData):
    """Data-manager designed for use with IID data.

    The data array is indexed as ``[..., variable]``: the last axis enumerates variables
    and all leading axes enumerate samples.
    """

    def __init__(self, data_indexed_by_sampleidx_variableidx: np.ndarray, copy_data: bool = True, pattern=CIT_DataPatterned_PersistentInTime, reproject_pattern_for_plotting=None):
        self._data = data_indexed_by_sampleidx_variableidx.copy() if copy_data else data_indexed_by_sampleidx_variableidx
        # protect against accidental modification:
        # NOTE(review): with copy_data=False this also makes the caller's array read-only.
        self._data.flags['WRITEABLE'] = False
        self.pattern = pattern
        # NOTE(review): stored, but not read by any code in this module.
        self.reproject_pattern_for_plotting = reproject_pattern_for_plotting

    def get_patterned_data(self, ci: CI_Identifier[int]) -> CIT_DataPatterned:
        # Multi-indexing (Z) in numpy will copy (ie data_z has its own malloc and memcpy),
        # thus we may not want to store data_z with the query in cache?
        # Note: x[...,k] accesses index k in the last axis ie [:,k] or [:,:,k] etc
        data_z = self._data[..., ci.idx_list_z] if len(ci.idx_list_z) > 0 else None
        data_x = self._data[..., ci.idx_x]
        data_y = self._data[..., ci.idx_y]
        return self.pattern(x_data=data_x, y_data=data_y, z_data=data_z, cache_id=(self, ci))

    def number_of_variables(self) -> int:
        return self._data.shape[-1]

    def total_sample_size(self) -> int:
        # All axes except the trailing variable axis are sample axes. (Previously ``shape[:-2]``,
        # which gave np.prod(()) == 1.0 for (N, vars) data and dropped an axis for grids.)
        return int(np.prod(self._data.shape[:-1]))
|
|
546
|
+
|
|
547
|
+
|
|
548
|
+
|
|
549
|
+
|
|
550
|
+
class DataManager_NumpyArray_Timeseries(IManageData):
    """Data-manager designed for use with time-series data.

    The data array is unpacked as ``(T, variable_count)``; nodes of a CI are addressed
    as ``(variable, lag)`` with ``lag <= 0``.
    """

    def __init__(self, data_indexed_by_sampleidx_variableidx: np.ndarray, copy_data: bool = True, pattern=CIT_DataPatterned_PersistentInTime, reproject_pattern_for_plotting=None):
        if copy_data:
            stored = data_indexed_by_sampleidx_variableidx.copy()
        else:
            stored = data_indexed_by_sampleidx_variableidx
        # Guard against accidental modification (with copy_data=False this also freezes the caller's array).
        stored.flags['WRITEABLE'] = False
        self._data = stored
        self.pattern = pattern
        self.reproject_pattern_for_plotting = reproject_pattern_for_plotting

    def get_patterned_data(self, ci: CI_Identifier_TimeSeries) -> CIT_DataPatterned:
        series_length, variable_total = self._data.shape
        lag_span = ci.max_abs_timelag()
        window_length = lag_span + 1
        # windows[t, i, v] == self._data[t + i, v]; shape (window count, window length, variables).
        windows = np.lib.stride_tricks.sliding_window_view(self._data, [window_length, variable_total]) \
            .reshape(series_length - lag_span, window_length, variable_total)

        # Position -1 in a window is lag 0, -2 is lag -1, ... so a node with lag L sits at L - 1.
        x_var, x_lag = ci.idx_x
        y_var, y_lag = ci.idx_y
        data_x = windows[:, x_lag - 1, x_var]
        data_y = windows[:, y_lag - 1, y_var]

        if ci.z_dim() == 0:
            data_z = None
        else:
            z_vars, z_lags = (np.array(column) for column in zip(*ci.idx_list_z))
            # Fancy indexing copies here (data_z owns its memory), unlike the views for X and Y.
            data_z = windows[:, z_lags - 1, z_vars]

        return self.pattern(x_data=data_x, y_data=data_y, z_data=data_z, cache_id=(self, ci))

    def number_of_variables(self) -> int:
        _, variable_total = self._data.shape[:2]
        return variable_total

    def total_sample_size(self) -> int:
        return len(self._data)
|
|
586
|
+
|
|
587
|
+
|
|
588
|
+
|