GLDF 0.9.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- GLDF/__init__.py +2 -0
- GLDF/bridges/__init__.py +0 -0
- GLDF/bridges/causal_learn.py +185 -0
- GLDF/bridges/tigramite.py +143 -0
- GLDF/bridges/tigramite_plotting_modified.py +4764 -0
- GLDF/cit.py +274 -0
- GLDF/data_management.py +588 -0
- GLDF/data_processing.py +754 -0
- GLDF/frontend.py +537 -0
- GLDF/hccd.py +403 -0
- GLDF/hyperparams.py +205 -0
- GLDF/independence_atoms.py +78 -0
- GLDF/state_space_construction.py +288 -0
- GLDF/tutorials/01_preconfigured_quickstart.ipynb +302 -0
- GLDF/tutorials/02_detailed_configuration.ipynb +394 -0
- GLDF/tutorials/03_custom_patterns.ipynb +447 -0
- gldf-0.9.0.dist-info/METADATA +101 -0
- gldf-0.9.0.dist-info/RECORD +20 -0
- gldf-0.9.0.dist-info/WHEEL +4 -0
- gldf-0.9.0.dist-info/licenses/LICENSE +621 -0
|
@@ -0,0 +1,288 @@
|
|
|
1
|
+
from .hccd import IRepresentState, IRepresentStateSpace, IConstructStateSpace, IPresentResult, IProvideIndependenceAtoms, IResolveRegimeStructure, graph_t
|
|
2
|
+
from .data_management import CI_Identifier, BlockView
|
|
3
|
+
|
|
4
|
+
from typing import Callable
|
|
5
|
+
|
|
6
|
+
import numpy as np
|
|
7
|
+
from dataclasses import dataclass
|
|
8
|
+
from itertools import chain, combinations, product
|
|
9
|
+
# cf https://docs.python.org/3/library/itertools.html#itertools-recipes:
|
|
10
|
+
def powerset(iterable):
    """
    Enumerate all sub-selections of the given elements.

    The input is materialized once. The result is a lazy iterator of tuples,
    ordered by increasing length and starting with the empty tuple.

    :param iterable: elements to choose from
    :return: iterator over the combinations of every length from ``0`` to ``len``
    """
    pool = tuple(iterable)
    subset_sizes = range(len(pool) + 1)
    return chain.from_iterable(combinations(pool, size) for size in subset_sizes)
|
|
13
|
+
def powerset_as_list_of_sets(iterable):
    """
    Materialize the powerset of the given elements as a list of ``set`` objects.

    The order is the one produced by :func:`powerset`, so the empty set is
    the first entry of the returned list.

    :param iterable: elements to choose from
    :return: list with one set per element of the powerset
    """
    return list(map(set, powerset(iterable)))
|
|
15
|
+
def all_binary_combinations(bit_count):
    """
    Enumerate every 0/1 assignment for the given number of bits.

    :param bit_count: number of positions in each assignment
    :return: iterator over tuples of length ``bit_count`` with entries 0 or 1
    """
    binary_digits = (0, 1)
    return product(binary_digits, repeat=bit_count)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
@dataclass
class ModelIndicator:
    """
    A model-indicator of the current baseline implementation.

    Hashing and equality only take :attr:`undirected_link` into account.
    Comparing against an object that is not a :class:`ModelIndicator`
    compares that object with the link itself, so a model-indicator can be
    looked up by its plain link in lists, sets and dictionaries.
    """
    undirected_link: tuple #: The link in the model on which a change was detected.
    representor: CI_Identifier #: A representing test (this test is independent iff the model-indicator is zero).

    def __eq__(self, other: 'ModelIndicator | tuple'):
        # Unwrap other model-indicators; anything else is compared as-is.
        other_link = other.undirected_link if isinstance(other, ModelIndicator) else other
        return other_link == self.undirected_link

    def __hash__(self):
        # Must agree with __eq__: only the link identifies a model-indicator.
        return hash(self.undirected_link)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class State(IRepresentState):
    """
    A single state of the current baseline implementation.

    A state stores, per model-indicator, whether that indicator is active,
    plus the set of CI tests that were registered as implied in this state.
    """

    def __init__(self, state_space: 'StateSpace', model_indicator_active: dict[ModelIndicator, bool]):
        """
        :param state_space: the state-space this state belongs to
        :param model_indicator_active: activity value for every model-indicator
        """
        self.model_indicator_active = model_indicator_active
        self.implied = set()
        self._state_space = state_space

    def state_space(self) -> IConstructStateSpace:
        """Return the state-space object handed in at construction."""
        return self._state_space

    def add_implication(self, ci: CI_Identifier):
        """Register ``ci`` as implied in this state."""
        self.implied.add(ci)

    def overwrites_ci(self, ci: CI_Identifier):
        """Delegate to the owning state-space's ``controls_ci`` for ``ci``."""
        owner = self._state_space
        return owner.controls_ci(ci)

    def get_ci_pseudo_value(self, ci: CI_Identifier):
        """Return ``False`` if ``ci`` was registered as implied, ``True`` otherwise."""
        return not (ci in self.implied)

    def _implies_all_conditions(self, set_of_conditions: set[tuple]) -> bool:
        """Check that every listed condition is active in this state."""
        # Every entry is looked up before evaluation (no short-circuit).
        activity = [self.model_indicator_active[condition] for condition in set_of_conditions]
        return all(activity)

    def _implies_all_conditions_in_at_least_one_listed_set(self, list_of_sets_of_conditions: list[set[tuple]]) -> bool:
        """Check that at least one of the listed condition sets is fully active."""
        per_set_verdicts = [self._implies_all_conditions(conditions)
                            for conditions in list_of_sets_of_conditions]
        return any(per_set_verdicts)
|
|
64
|
+
|
|
65
|
+
class StateSpace(IRepresentStateSpace):
    """
    Represents the state-space in the current baseline implementation.

    One :class:`State` is built for every 0/1 assignment to the given
    model-indicators, i.e. ``2 ** len(model_indicators)`` states in total.
    """
    def __init__(self, model_indicators: 'list[ModelIndicator] | None' = None, marked_ci: 'set[CI_Identifier] | None' = None):
        """
        :param model_indicators: model-indicators spanning the state-space;
            ``None`` (the default) means no model-indicators
        :param marked_ci: marked CI tests controlled by this state-space;
            ``None`` (the default) means no marked tests
        """
        # Use None sentinels instead of mutable defaults: the arguments are
        # stored directly on the instance, so a shared default list/set would
        # be aliased by every default-constructed state-space.
        self.model_indicators = [] if model_indicators is None else model_indicators
        self.marked_ci = set() if marked_ci is None else marked_ci
        self._states = self._build_states()

    def _fold_model_indicator_activity_from_list_into_dict(self, model_indicator_activity: tuple) -> 'dict[ModelIndicator, bool]':
        """Pair each model-indicator with its entry of the activity tuple (positional match)."""
        return dict(zip(self.model_indicators, model_indicator_activity))

    def _build_states(self) -> list[State]:
        """Create one state per binary activity assignment of the model-indicators."""
        return [State(state_space=self, model_indicator_active=self._fold_model_indicator_activity_from_list_into_dict(model_indicator_activity))
                for model_indicator_activity
                in all_binary_combinations(len(self.model_indicators))]

    def is_trivial(self) -> bool:
        """Return ``True`` if there are no model-indicators."""
        return len(self.model_indicators) == 0

    def controls_ci(self, ci: CI_Identifier):
        """Return whether ``ci`` is one of the marked CI tests."""
        return ci in self.marked_ci

    def states(self) -> list[State]:
        """Return the list of states built at construction."""
        return self._states

    def finalize(self, graphs: dict[IRepresentState,graph_t]) -> IPresentResult:
        """
        Combine the per-state graphs into the result object.

        :param graphs: graph found for each state
        :return: result produced by ``Unionize_Translate_Transfer_NoUnionCylces``
        """
        return Unionize_Translate_Transfer_NoUnionCylces.obtain_translation_result(state_space=self, graphs=graphs)

    def states_which_imply(self, list_of_conditions: list[set[tuple]]) -> list[State]:
        """
        Select the states in which at least one listed condition set is fully active.

        :param list_of_conditions: alternative sets of conditions
        :return: matching states, in the order of :meth:`states`
        """
        return [state for state in self.states()
                if state._implies_all_conditions_in_at_least_one_listed_set(list_of_conditions)]
|
|
98
|
+
|
|
99
|
+
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
|
|
103
|
+
class ModelIndicators_NoUnionCycles:
    """
    First phase of the current baseline implementation: model indicators are
    found as a maximum by semi-ordering based on indicator-implications.

    *Implements phase I of Algo. 3 from* [RR25]_\\ *.*
    """

    def __init__(self, testing_backend: IProvideIndependenceAtoms, marked_tests: set[CI_Identifier]):
        """
        :param testing_backend: backend queried for global independence and regime implications
        :param marked_tests: the marked CI tests
        """
        self.testing_backend = testing_backend
        self.marked_tests = marked_tests
        self.model_indicators = list(self._initialize_model_indicators())
        self.state_space = StateSpace(model_indicators=self.model_indicators, marked_ci=self.marked_tests)

    def _initialize_model_indicators(self) -> list[ModelIndicator]:
        """
        Build one model-indicator per link of a marked test, skipping links the
        backend reports as globally independent for some conditioning set.
        """
        backend = self.testing_backend
        candidate_links = {
            marked.undirected_link()
            for marked in self.marked_tests
            if not backend.found_globally_independent_for_some_Z(marked.undirected_link())
        }
        return [
            ModelIndicator(undirected_link=candidate,
                           representor=self._model_indicator_representator(link=candidate))
            for candidate in candidate_links
        ]

    def _model_indicator_representator(self, link) -> CI_Identifier:
        """
        Pick the representing test for ``link`` among the marked tests on that link.

        :param link: undirected link of the model-indicator
        :return: the chosen marked test
        """
        tests_on_link = [marked for marked in self.marked_tests
                         if marked.undirected_link() == link]
        chosen = None
        for contender in tests_on_link:
            if chosen is None:
                chosen = contender
                continue
            replaces_chosen = (
                self.testing_backend.regime_implication([contender], chosen)
                and not self.testing_backend.regime_implication([chosen], contender)
            )
            if replaces_chosen:
                # switch conservatively (ordering is such that small conditioning sets are first)
                chosen = contender
        assert chosen is not None, "This should always find a representor for each model indicator!"
        return chosen
|
|
136
|
+
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
|
|
140
|
+
|
|
141
|
+
class NoUnionCycles(IConstructStateSpace):
    """
    Current baseline implementation of the state-space construction.

    *Implements Algo. 3 from* [RR25]_\\ *.*
    """

    def __init__(self):
        # All three attributes are filled in by construct_statespace().
        self.state_space = None
        self.model_indicators = None
        self.testing_backend = None

    def construct_statespace(self, testing_backend: IProvideIndependenceAtoms, marked_tests: set[CI_Identifier], previous_graphs) -> IRepresentStateSpace:
        """
        Build the state-space for the given marked tests.

        :param testing_backend: backend used for regime-implication queries
        :param marked_tests: the marked CI tests
        :param previous_graphs: not used by this implementation
        :return: the state-space with every marked test translated into it
        """
        self.testing_backend = testing_backend

        # Phase I: find the model-indicators and the state-space spanned by them.
        phase_one = ModelIndicators_NoUnionCycles(testing_backend, marked_tests)
        self.model_indicators = phase_one.model_indicators
        self.state_space = phase_one.state_space

        # Phase II: translate every marked test into implications on states.
        for marked_test in marked_tests:
            self.translate_ci(marked_test)

        return self.state_space

    def store_translated_ci(self, ci: CI_Identifier, implied_by: list[set[tuple]]) -> None:
        """
        Register ``ci`` as implied on every state that satisfies at least one
        of the condition sets in ``implied_by``.
        """
        matching_states = self.state_space.states_which_imply(implied_by)
        for matching_state in matching_states:
            # state => implied_by => ci
            matching_state.add_implication(ci)

    def translate_ci(self, ci: CI_Identifier) -> None:
        """
        Translate detected indicators.

        *Implements phase II of Algo. 3 in* [RR25]_\\ *.*

        :param ci: marked CI test
        :type ci: CI_Identifier
        """
        remaining = powerset_as_list_of_sets(self.model_indicators)
        leading_entry = remaining.pop(0)  # pop empty set at index 0
        assert len(leading_entry) == 0

        confirmed = []

        if ci.undirected_link() in self.model_indicators:
            # the X and Y are always dependent if the direct link is there
            remaining = [subset for subset in remaining if ci.undirected_link() in subset]

        # The work-list is rebuilt inside the loop, so it is consumed via
        # pop(0) rather than iterated directly.
        while remaining:
            if len(remaining) == 1 and not confirmed:
                # Only one candidate left and nothing confirmed: take it without testing.
                return self.store_translated_ci(ci, implied_by=remaining)
            current: set[ModelIndicator] = remaining.pop(0)
            current_representors = [indicator.representor for indicator in current]
            if not self.testing_backend.regime_implication(current_representors, ci):
                continue
            # Drop all supersets of the confirmed set from the work-list.
            remaining = [other for other in remaining if not current.issubset(other)]
            confirmed.append(current)

        return self.store_translated_ci(ci, implied_by=confirmed)
|
|
200
|
+
|
|
201
|
+
|
|
202
|
+
|
|
203
|
+
from .data_processing import ITestCI
|
|
204
|
+
from .data_management import IManageData
|
|
205
|
+
|
|
206
|
+
class StructuredResultWithTranslation(IPresentResult):
    """
    Represents the (translated and transferred) result/labeled union graph
    in the current baseline implementation.
    """

    def __init__(self, union_graph: graph_t, model_indicators: list[ModelIndicator], graphs: dict[IRepresentState, graph_t]):
        """
        :param union_graph: the union graph over all states
        :param model_indicators: the model-indicators of the state-space
        :param graphs: graph per state; only the values are kept, in dict order
        """
        self._union_graph = union_graph
        self._model_indicators = model_indicators
        self._state_graphs = list(graphs.values())

    def union_graph(self) -> graph_t:
        """Return the union graph handed in at construction."""
        return self._union_graph

    def state_graphs(self) -> list[graph_t]:
        """Return the list of per-state graphs."""
        # Fixed: this used to `raise` the list, which fails with a TypeError
        # because a list is not an exception.
        return self._state_graphs

    def model_indicators(self) -> list:
        """Return the model-indicators handed in at construction."""
        return self._model_indicators
|
|
225
|
+
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
class Unionize_Translate_Transfer_NoUnionCylces:
    """
    Namespace for the helpers that construct a labeled union-graph.
    """
    @staticmethod
    def unionize_edgemark(a, b):
        """
        Merge two edge marks into one.

        Equal marks are kept, ``'o'`` gives way to the other mark, and any
        other disagreement (or an ``'x'`` on either side) results in ``'x'``.
        """
        if a == b:
            return a
        if a == 'x' or b == 'x':
            return 'x'
        if a == 'o':
            return b
        return a if b == 'o' else 'x'

    @classmethod
    def unionize_edge(cls, a, b):
        """
        Merge two edge strings.

        An empty string acts as the neutral element. Otherwise the marks at
        index 0 and index 2 are merged separately and joined by ``"-"``.
        """
        if a == '':
            return b
        if b == '':
            return a
        return cls.unionize_edgemark(a[0], b[0]) + "-" + cls.unionize_edgemark(a[2], b[2])

    @classmethod
    def unionize_and_transfer(cls, graphs):
        """
        Fold all graphs element-wise into one graph with :meth:`unionize_edge`.

        :param graphs: iterable of equally shaped arrays of edge strings
        :return: the merged array, the sole graph if only one was given,
            or ``None`` if ``graphs`` is empty
        """
        merged = None
        for graph in graphs:
            if merged is None:
                merged = graph
                continue
            edge_pairs = zip(merged.flatten(), graph.flatten())
            combined = [cls.unionize_edge(left, right) for left, right in edge_pairs]
            merged = np.array(combined).reshape(merged.shape)
        return merged

    @classmethod
    def obtain_translation_result(cls, state_space: StateSpace, graphs: dict[IRepresentState,graph_t]):
        """
        Bundle union graph, model-indicators and per-state graphs into a
        ``StructuredResultWithTranslation``.
        """
        # make results ready for serialization and easy plotting etc
        merged_graph = cls.unionize_and_transfer(graphs.values())
        return StructuredResultWithTranslation(
            union_graph=merged_graph,
            model_indicators=state_space.model_indicators,
            graphs=graphs,
        )
|
|
273
|
+
|
|
274
|
+
|
|
275
|
+
class ResolveByRepresentor(IResolveRegimeStructure):
    """
    Primitive approximate resolution of model-indicators via their representors.
    """

    def __init__(self, indicator_resolution_score: Callable[[BlockView], np.ndarray], data_mgr: IManageData, block_size: int):
        """
        :param indicator_resolution_score: callable applied to the block view of the patterned data
        :param data_mgr: data manager providing patterned data and block reprojection
        :param block_size: block size passed to ``view_blocks``
        """
        self.block_size = block_size
        self.data_mgr = data_mgr
        self.indicator_resolution_score = indicator_resolution_score

    def resolve_model_indicator(self, model_indicator: ModelIndicator) -> np.ndarray:
        """
        Score the representor of ``model_indicator`` block-wise and reproject the scores.

        :param model_indicator: the model-indicator to resolve
        :return: whatever the data manager's ``reproject_blocks`` returns for the scores
        """
        representor_data = self.data_mgr.get_patterned_data(model_indicator.representor)
        blocks = representor_data.view_blocks(self.block_size)
        block_scores = self.indicator_resolution_score(blocks)
        return self.data_mgr.reproject_blocks(block_scores, block_configuration=blocks)
|