pyAgrum-nightly 2.3.0.9.dev202512061764412981__cp310-abi3-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyagrum/__init__.py +165 -0
- pyagrum/_pyagrum.so +0 -0
- pyagrum/bnmixture/BNMInference.py +268 -0
- pyagrum/bnmixture/BNMLearning.py +376 -0
- pyagrum/bnmixture/BNMixture.py +464 -0
- pyagrum/bnmixture/__init__.py +60 -0
- pyagrum/bnmixture/notebook.py +1058 -0
- pyagrum/causal/_CausalFormula.py +280 -0
- pyagrum/causal/_CausalModel.py +436 -0
- pyagrum/causal/__init__.py +81 -0
- pyagrum/causal/_causalImpact.py +356 -0
- pyagrum/causal/_dSeparation.py +598 -0
- pyagrum/causal/_doAST.py +761 -0
- pyagrum/causal/_doCalculus.py +361 -0
- pyagrum/causal/_doorCriteria.py +374 -0
- pyagrum/causal/_exceptions.py +95 -0
- pyagrum/causal/_types.py +61 -0
- pyagrum/causal/causalEffectEstimation/_CausalEffectEstimation.py +1175 -0
- pyagrum/causal/causalEffectEstimation/_IVEstimators.py +718 -0
- pyagrum/causal/causalEffectEstimation/_RCTEstimators.py +132 -0
- pyagrum/causal/causalEffectEstimation/__init__.py +46 -0
- pyagrum/causal/causalEffectEstimation/_backdoorEstimators.py +774 -0
- pyagrum/causal/causalEffectEstimation/_causalBNEstimator.py +324 -0
- pyagrum/causal/causalEffectEstimation/_frontdoorEstimators.py +396 -0
- pyagrum/causal/causalEffectEstimation/_learners.py +118 -0
- pyagrum/causal/causalEffectEstimation/_utils.py +466 -0
- pyagrum/causal/notebook.py +171 -0
- pyagrum/clg/CLG.py +658 -0
- pyagrum/clg/GaussianVariable.py +111 -0
- pyagrum/clg/SEM.py +312 -0
- pyagrum/clg/__init__.py +63 -0
- pyagrum/clg/canonicalForm.py +408 -0
- pyagrum/clg/constants.py +54 -0
- pyagrum/clg/forwardSampling.py +202 -0
- pyagrum/clg/learning.py +776 -0
- pyagrum/clg/notebook.py +480 -0
- pyagrum/clg/variableElimination.py +271 -0
- pyagrum/common.py +60 -0
- pyagrum/config.py +319 -0
- pyagrum/ctbn/CIM.py +513 -0
- pyagrum/ctbn/CTBN.py +573 -0
- pyagrum/ctbn/CTBNGenerator.py +216 -0
- pyagrum/ctbn/CTBNInference.py +459 -0
- pyagrum/ctbn/CTBNLearner.py +161 -0
- pyagrum/ctbn/SamplesStats.py +671 -0
- pyagrum/ctbn/StatsIndepTest.py +355 -0
- pyagrum/ctbn/__init__.py +79 -0
- pyagrum/ctbn/constants.py +54 -0
- pyagrum/ctbn/notebook.py +264 -0
- pyagrum/defaults.ini +199 -0
- pyagrum/deprecated.py +95 -0
- pyagrum/explain/_ComputationCausal.py +75 -0
- pyagrum/explain/_ComputationConditional.py +48 -0
- pyagrum/explain/_ComputationMarginal.py +48 -0
- pyagrum/explain/_CustomShapleyCache.py +110 -0
- pyagrum/explain/_Explainer.py +176 -0
- pyagrum/explain/_Explanation.py +70 -0
- pyagrum/explain/_FIFOCache.py +54 -0
- pyagrum/explain/_ShallCausalValues.py +204 -0
- pyagrum/explain/_ShallConditionalValues.py +155 -0
- pyagrum/explain/_ShallMarginalValues.py +155 -0
- pyagrum/explain/_ShallValues.py +296 -0
- pyagrum/explain/_ShapCausalValues.py +208 -0
- pyagrum/explain/_ShapConditionalValues.py +126 -0
- pyagrum/explain/_ShapMarginalValues.py +191 -0
- pyagrum/explain/_ShapleyValues.py +298 -0
- pyagrum/explain/__init__.py +81 -0
- pyagrum/explain/_explGeneralizedMarkovBlanket.py +152 -0
- pyagrum/explain/_explIndependenceListForPairs.py +146 -0
- pyagrum/explain/_explInformationGraph.py +264 -0
- pyagrum/explain/notebook/__init__.py +54 -0
- pyagrum/explain/notebook/_bar.py +142 -0
- pyagrum/explain/notebook/_beeswarm.py +174 -0
- pyagrum/explain/notebook/_showShapValues.py +97 -0
- pyagrum/explain/notebook/_waterfall.py +220 -0
- pyagrum/explain/shapley.py +225 -0
- pyagrum/lib/__init__.py +46 -0
- pyagrum/lib/_colors.py +390 -0
- pyagrum/lib/bn2graph.py +299 -0
- pyagrum/lib/bn2roc.py +1026 -0
- pyagrum/lib/bn2scores.py +217 -0
- pyagrum/lib/bn_vs_bn.py +605 -0
- pyagrum/lib/cn2graph.py +305 -0
- pyagrum/lib/discreteTypeProcessor.py +1102 -0
- pyagrum/lib/discretizer.py +58 -0
- pyagrum/lib/dynamicBN.py +390 -0
- pyagrum/lib/explain.py +57 -0
- pyagrum/lib/export.py +84 -0
- pyagrum/lib/id2graph.py +258 -0
- pyagrum/lib/image.py +387 -0
- pyagrum/lib/ipython.py +307 -0
- pyagrum/lib/mrf2graph.py +471 -0
- pyagrum/lib/notebook.py +1821 -0
- pyagrum/lib/proba_histogram.py +552 -0
- pyagrum/lib/utils.py +138 -0
- pyagrum/pyagrum.py +31495 -0
- pyagrum/skbn/_MBCalcul.py +242 -0
- pyagrum/skbn/__init__.py +49 -0
- pyagrum/skbn/_learningMethods.py +282 -0
- pyagrum/skbn/_utils.py +297 -0
- pyagrum/skbn/bnclassifier.py +1014 -0
- pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/LICENSE.md +12 -0
- pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/LICENSES/LGPL-3.0-or-later.txt +304 -0
- pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/LICENSES/MIT.txt +18 -0
- pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/METADATA +145 -0
- pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/RECORD +107 -0
- pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,671 @@
|
|
|
1
|
+
############################################################################
|
|
2
|
+
# This file is part of the aGrUM/pyAgrum library. #
|
|
3
|
+
# #
|
|
4
|
+
# Copyright (c) 2005-2025 by #
|
|
5
|
+
# - Pierre-Henri WUILLEMIN(_at_LIP6) #
|
|
6
|
+
# - Christophe GONZALES(_at_AMU) #
|
|
7
|
+
# #
|
|
8
|
+
# The aGrUM/pyAgrum library is free software; you can redistribute it #
|
|
9
|
+
# and/or modify it under the terms of either : #
|
|
10
|
+
# #
|
|
11
|
+
# - the GNU Lesser General Public License as published by #
|
|
12
|
+
# the Free Software Foundation, either version 3 of the License, #
|
|
13
|
+
# or (at your option) any later version, #
|
|
14
|
+
# - the MIT license (MIT), #
|
|
15
|
+
# - or both in dual license, as here. #
|
|
16
|
+
# #
|
|
17
|
+
# (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) #
|
|
18
|
+
# #
|
|
19
|
+
# This aGrUM/pyAgrum library is distributed in the hope that it will be #
|
|
20
|
+
# useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, #
|
|
21
|
+
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS #
|
|
22
|
+
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE #
|
|
23
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER #
|
|
24
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, #
|
|
25
|
+
# ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR #
|
|
26
|
+
# OTHER DEALINGS IN THE SOFTWARE. #
|
|
27
|
+
# #
|
|
28
|
+
# See LICENCES for more details. #
|
|
29
|
+
# #
|
|
30
|
+
# SPDX-FileCopyrightText: Copyright 2005-2025 #
|
|
31
|
+
# - Pierre-Henri WUILLEMIN(_at_LIP6) #
|
|
32
|
+
# - Christophe GONZALES(_at_AMU) #
|
|
33
|
+
# SPDX-License-Identifier: LGPL-3.0-or-later OR MIT #
|
|
34
|
+
# #
|
|
35
|
+
# Contact : info_at_agrum_dot_org #
|
|
36
|
+
# homepage : http://agrum.gitlab.io #
|
|
37
|
+
# gitlab : https://gitlab.com/agrumery/agrum #
|
|
38
|
+
# #
|
|
39
|
+
############################################################################
|
|
40
|
+
|
|
41
|
+
from typing import Dict, Tuple, List
|
|
42
|
+
|
|
43
|
+
import csv
|
|
44
|
+
import math
|
|
45
|
+
import matplotlib.pyplot as plt
|
|
46
|
+
import pyagrum
|
|
47
|
+
|
|
48
|
+
from pyagrum.ctbn import CIM
|
|
49
|
+
from pyagrum.ctbn import CTBN
|
|
50
|
+
|
|
51
|
+
"""
|
|
52
|
+
Contains the Trajectory class, functions to compute statistics from samples, and tools for trajectory plotting.
|
|
53
|
+
Also contains the Stats class, which stores the statistics used for independence tests.
|
|
54
|
+
"""
|
|
55
|
+
|
|
56
|
+
|
|
57
|
+
def readTrajectoryCSV(filename: str) -> Dict[int, List[Tuple[float, str, str]]]:
    """
    Reads trajectories from a csv file.
    Storing format : {IdSample, time, var, state}

    Parameters
    ----------
    filename : str
        Path to the file. Its header must provide the columns ``IdSample``, ``time``, ``var`` and ``state``.

    Returns
    -------
    Dict[int, List[Tuple[float, str, str]]]
        The trajectories, indexed by ``int(IdSample)``. Each trajectory is the list of
        ``(float(time), var, state)`` records, in file order.
    """
    trajectories: Dict[int, List[Tuple[float, str, str]]] = {}
    with open(filename, newline="") as handle:
        for record in csv.DictReader(handle):
            # read every column first so that a missing column is reported before any conversion error
            rawId, rawTime, var, state = (record[key] for key in ("IdSample", "time", "var", "state"))
            sampleId = int(rawId)
            trajectories.setdefault(sampleId, []).append((float(rawTime), var, state))
    return trajectories
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def plotTrajectory(
    v: pyagrum.DiscreteVariable, traj: List[Tuple[float, str, str]], timeHorizon: float = None, plotname: str = None
):
    """
    Plot a variable's trajectory using matplotlib.pyplot.

    Every record ``(t, var, state)`` of ``v`` with ``t > 0`` is drawn as a horizontal segment at the height of
    ``state`` (its index in ``v``'s domain), from the previous record of ``v`` up to ``t``. Records at time 0
    are skipped and plotting stops at the first record later than the time limit.

    Parameters
    ----------
    v : pyagrum.DiscreteVariable
        Variable to follow. Every state found in its records must be one of ``v.labels()``.
    traj : List[Tuple[float, str, str]]
        Trajectory to plot.
    timeHorizon : float
        Maximum time length to show. If None, the time of the last record of ``traj`` is used.
    plotname : str
        Name of the plot. If None, the picked name is "trajectory of {v.name()}".
    """
    name = f"trajectory of {v.name()}" if plotname is None else plotname
    Tlim = traj[-1][0] if timeHorizon is None else timeHorizon

    vname = v.name()
    labels = v.labels()
    # Plot the index of each state rather than its label: with string values matplotlib builds a
    # categorical axis ordered by first appearance, which would not match the yticks set below.
    indexOf = {label: idx for idx, label in enumerate(labels)}

    XAxis = []
    YAxis = []

    prevtimepoint = 0
    for timepoint, var, state in traj:
        if timepoint == 0:
            continue
        if timepoint > Tlim:
            break
        if var == vname:
            height = indexOf[state]
            XAxis.extend((prevtimepoint, timepoint))
            YAxis.extend((height, height))
            prevtimepoint = timepoint

    plt.plot(XAxis, YAxis)
    plt.yticks(range(v.domainSize()), labels)
    plt.title(name)
    plt.xlabel("time")
    plt.ylabel("state")
    plt.show()
|
|
133
|
+
|
|
134
|
+
|
|
135
|
+
def plotFollowVar(
    v: pyagrum.DiscreteVariable,
    trajectories: Dict[int, List[Tuple[float, str, str]]],
    timeHorizon: float = None,
    N: int = None,
    plotname: str = None,
):
    """
    Plot the evolution (the proportions of the states the variable transitions into) of a variable over time.

    Parameters
    ----------
    v : pyagrum.DiscreteVariable
        The variable to follow.
    trajectories : Dict[int, List[Tuple[float, str, str]]]
        Contains trajectories. The result is a mean over all the trajectories.
    timeHorizon : float
        Maximum time length to plot. If None, then the entire time length of the trajectories is used
        (read from the last record of trajectory ``0``).
    N : int
        Number of divisions of the interval [0, timeHorizon]. If None, steps of length 1 are used.
    plotname : str
        Name of the plot. If None, the picked name is "Proportions for each state of {v.name()}".

    """
    if plotname is None:
        name = f"Proportions for each state of {v.name()}"
    else:
        name = plotname

    if timeHorizon is None:
        # timeHorizon is always found at the end of a trajectory.
        T = trajectories[0][-1][0]
    else:
        T = timeHorizon

    # Step width. Uses T (not timeHorizon) so that N also works when timeHorizon is None.
    h = 1 if N is None else T / N

    vname = v.name()
    domain = sorted(v.labels())
    # ceil (not floor) so that the last, possibly partial, step still covers every record with time <= T
    nbSteps = math.ceil(T / h)
    division = [h * i for i in range(nbSteps + 1)]
    count = {label: {step: 0 for step in division} for label in domain}
    total = {step: 0 for step in division}

    # Counting transitions
    for traj in trajectories.values():
        for time, var, state in traj:
            if var == vname and time > 0:
                # if a transition appears at time t, then we count it for every step >= t until the end
                # (division[k] == k * h, so the steps >= ceil(t / h) * h are exactly division[ceil(t / h):])
                first = math.ceil(time / h)
                for step in division[first:]:
                    count[state][step] += 1
                    total[step] += 1

    # Compute average of all sums
    for label in domain:
        for step in division:
            if total[step] != 0:
                count[label][step] /= total[step]

    YAxisList = [[count[lab][step] for step in division] for lab in domain]

    _, ax = plt.subplots()
    plt.xlim(left=0, right=T)
    plt.ylim(top=1, bottom=0)

    plt.xlabel("time")
    plt.ylabel("state proportion")
    ax.stackplot(division, YAxisList)
    plt.title(name)
    plt.legend(domain)
    plt.show()
|
|
211
|
+
|
|
212
|
+
|
|
213
|
+
def CTBNFromData(data: Dict[int, List[Tuple[float, str, str]]]) -> CTBN:
    """
    Constructs a CTBN and adds the corresponding variables found in the trajectories.

    Variables are added in order of first appearance in ``data``. The labels of each variable are the
    sorted set of the states observed for it.

    Warning
    -------
    If data is too short, some variables or state labels might be missed.

    Parameters
    ----------
    data : Dict[int, List[Tuple[float, str, str]]]
        The trajectories used to look for variables. Keys do not need to be contiguous.

    Returns
    -------
    CTBN
        The resulting CTBN.
    """
    # A dict keeps insertion order, so variables are created in a deterministic order
    # (iterating a set of names would depend on string hash randomization).
    labels: dict[str, set[str]] = dict()

    for traj in data.values():
        for _, var, state in traj:
            labels.setdefault(var, set()).add(state)

    ctbn = CTBN()

    for name, states in labels.items():
        ctbn.add(pyagrum.LabelizedVariable(name, name, sorted(states)))

    return ctbn
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
def computeCIMFromStats(X: str, M: pyagrum.Tensor, T: pyagrum.Tensor) -> "pyagrum.Tensor":
    """
    Computes a CIM (Conditional Intensity Matrix) using stats from a trajectory. Variables in the tensor
    are not copied but directly used in the result to avoid memory issues.

    For every instantiation of the result (which has the variables of ``M``), writing ``i`` for the value of
    ``CIM.varI(X)`` and ``j`` for the value of ``CIM.varJ(X)``:

    - if ``i != j``: the entry is ``M[i, j] / T[i]``, rounded to 3 decimals;
    - if ``i == j``: the entry is ``-(sum over k of M[i, k]) / T[i]``, rounded to 3 decimals.

    Entries for which the corresponding value of ``T`` is 0 are left at 0.

    Parameters
    ----------
    X : str
        Name of the variable to compute CIM for.
    M : pyagrum.Tensor
        Tensor containing the number of transitions for each pair of ``X``'s states.
    T : pyagrum.Tensor
        Tensor containing the time spent to transition from every state of ``X``.

    Returns
    -------
    pyagrum.Tensor
        The resulting tensor, ``X``'s CIM.
    """

    # result: same variables as M, every entry reset to 0
    res: pyagrum.Tensor = pyagrum.Tensor(M)
    res.fillWith(0)

    i = pyagrum.Instantiation(res)
    iTime = pyagrum.Instantiation(T)

    i.setFirst()
    iTime.setFirst()

    # positions in ``i`` of X's "from" (varI) and "to" (varJ) variables
    posI = i.pos(res.variable(CIM.varI(X)))
    posJ = i.pos(res.variable(CIM.varJ(X)))

    while not i.end():
        # copy into iTime the values ``i`` gives to the variables it shares with T
        iTime.setVals(i)

        # sum M over every value of the "to" variable, the other variables being fixed as in ``i``
        iSum = pyagrum.Instantiation(i)
        sumCIM = 0
        iSum.setFirstVar(i.variable(posJ))
        while not iSum.end():
            sumCIM += M.get(iSum)
            iSum.incVar(i.variable(posJ))

        if i.val(posI) == i.val(posJ):
            # diagonal entry: minus the total exit rate
            if T.get(iTime) != 0:
                res.set(i, -round(sumCIM / T.get(iTime), 3))
        else:
            # off-diagonal entry: transition count divided by time spent
            if T.get(iTime) != 0:
                res.set(i, round((M.get(i) / T.get(iTime)), 3))
        i.inc()

    return res
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
class Trajectory:
    """
    Tools to extract useful information from a trajectory. It is used for parameters/graph learning.
    It can be created from a trajectory (a dict of trajectories) or from a file that contains one.

    Each trajectory is a list of ``(time, var, state)`` records. The records at time 0 give the initial
    values of the variables, and the last record of a trajectory carries its time horizon.

    Parameters
    ----------
    source : str|Dict[int, List[Tuple[float, str, str]]]
        The path to a csv file containing the samples or the dict of trajectories itself.
    ctbn : CTBN
        To link the variables' names in the trajectory to their pyAgrum variables. If not given, a new CTBN is created
        with the variables and labels found in the trajectory. (warning : if the trajectory is short, all of the variables
        may not be found correctly).

    Attributes
    ----------
    data : Dict[int, List[Tuple[float, str, str]]]
        The samples.
    ctbn : CTBN
        The CTBN used to link the names in the trajectory to pyAgrum variables.
    timeHorizon : float
        The time length of the trajectory (time of the last record of trajectory ``0``).
    """

    def __init__(self, source, ctbn: CTBN = None):
        if isinstance(source, str):
            self.data = readTrajectoryCSV(source)
        else:
            self.data = source

        if ctbn is None:
            self.ctbn = CTBNFromData(self.data)
        else:
            self.ctbn = ctbn
        # to assert: every trajectory is assumed to share the horizon of trajectory 0
        self.timeHorizon = self.data[0][-1][0]

    def _trajectories(self):
        """Iterates over the stored trajectories (``self.data`` may be a dict or a sequence of trajectories)."""
        if isinstance(self.data, dict):
            return self.data.values()
        return self.data

    @staticmethod
    def _checkAllValues(curr, goal):
        """Returns True when every value already known in ``curr`` is equal to the one in ``goal``."""
        return all(curr[e] == goal[e] for e in curr)

    @staticmethod
    def _nextStates(traj):
        """
        For every index ``idx`` of ``traj``, finds the state of the next record (after ``idx``) that concerns
        the same variable as ``traj[idx]``, or None if there is none.

        Single backward pass: replaces a forward scan of the trajectory for every record.
        """
        nexts = [None] * len(traj)
        lastSeen = dict()
        for idx in range(len(traj) - 1, -1, -1):
            _, var, state = traj[idx]
            nexts[idx] = lastSeen.get(var)
            lastSeen[var] = state
        return nexts

    def setStatValues(self, X: str, inst_u: Dict[str, str], Txu: pyagrum.Tensor, Mxu: pyagrum.Tensor):
        """
        Fills the tensors given.

        Only the time intervals during which the conditioning variables match ``inst_u`` are accounted for.

        Parameters
        ----------
        X : str
            Name of the variable.
        inst_u : Dict[str, str]
            Instance of conditioning variables.
        Txu : pyagrum.Tensor
            Tensor to fill. Contains the time spent in each state.
        Mxu : pyagrum.Tensor
            Tensor to fill. Contains the number of transitions from any pair of states.
        """
        X_i = CIM.varI(X)
        X_j = CIM.varJ(X)
        for traj in self._trajectories():
            nextStates = self._nextStates(traj)
            prev_duration = 0
            duration = 0

            # store current values at duration
            u_values = dict()
            X_value = None

            # init (find initial values)
            for time, var, state in traj:
                if time != 0:
                    break
                if var == X:
                    X_value = state
                elif var in inst_u:
                    u_values[var] = state

            # loop over the transitions
            for idx, (time, var, _) in enumerate(traj):
                if time == 0:
                    continue

                prev_duration = duration
                duration = time

                ########## begin check
                if self._checkAllValues(u_values, inst_u):
                    # update T
                    inst1 = u_values.copy()
                    inst1[X_i] = X_value
                    Txu[inst1] += duration - prev_duration

                    # update M
                    if time < self.timeHorizon and var == X:
                        nextv = nextStates[idx]
                        if nextv != X_value:
                            inst2 = u_values.copy()
                            inst2[X_i] = X_value
                            inst2[X_j] = nextv
                            Mxu[inst2] += 1
                ########## end check

                if time < self.timeHorizon and var == X:
                    X_value = nextStates[idx]
                elif time < self.timeHorizon and var in inst_u:
                    u_values[var] = nextStates[idx]

    def computeStats(self, X: str, U: List[str]) -> Tuple[pyagrum.Tensor, pyagrum.Tensor]:
        """
        Computes time spent and number of transitions values of ``X`` and returns them as ``pyagrum.Tensor``.

        Both tensors are divided by the number of trajectories.

        Parameters
        ----------
        X : str
            Name of the variable.
        U : List[str]
            List of conditioning variable's name.

        Returns
        -------
        Tuple[pyagrum.Tensor, pyagrum.Tensor]
            The resulting tensors.
        """
        par = [self.ctbn.variable(nv) for nv in U]
        n = len(self.data)

        Txu = pyagrum.Tensor()
        Mxu = pyagrum.Tensor()
        current = pyagrum.Instantiation()

        X_from = self.ctbn.CIM(X).findVar(CIM.varI(X))
        X_to = self.ctbn.CIM(X).findVar(CIM.varJ(X))

        Txu.add(X_from)
        Mxu.add(X_from)
        Mxu.add(X_to)
        for v in par:
            current.add(v)
            Txu.add(v)
            Mxu.add(v)

        Txu.fillWith(0)
        Mxu.fillWith(0)

        # looping over all possible instances
        current.setFirst()
        while not current.end():
            currentLabels = current.todict(withLabels=True)
            inst_par = {vname: currentLabels[vname] for vname in U}
            self.setStatValues(X, inst_par, Txu, Mxu)
            current.inc()

        Mxu = Mxu.putFirst(X_to.name())
        return (Txu / n, Mxu / n)

    def computeAllCIMs(self):
        """
        Computes the CIMs of the variables in ``self.ctbn``. Conditioning is given by the graph of ``self.ctbn``.
        Each computed tensor replaces the one stored in the variable's CIM.
        """
        for var in self.ctbn.variables():
            T, M = self.computeStats(var.name(), self.ctbn.parentNames(var.name()))
            self.ctbn.CIM(var.name())._pot = pyagrum.Tensor(
                computeCIMFromStats(var.name(), M.putFirst(CIM.varJ(var.name())), T)
            )

    def setStatsForTests(
        self, X: str, Y: str, inst_u: Dict[str, str], Txu: pyagrum.Tensor, Txyu: pyagrum.Tensor, Mxyu: pyagrum.Tensor
    ):
        """
        Fills the tensors given. They are used for independence testing.

        Only the time intervals during which the conditioning variables match ``inst_u`` are accounted for.

        Parameters
        ----------
        X : str
            Name of the variable.
        Y : str
            Name of a conditioning variable.
        inst_u : Dict[str, str]
            Instance of conditioning variables.
        Txu : pyagrum.Tensor
            Tensor to fill. Contains the time spent in each state. Conditioned by variables in ``inst_u``.
        Txyu : pyagrum.Tensor
            Tensor to fill. Contains the time spent in each state. Conditioned by ``Y`` and variables in ``inst_u``.
        Mxyu : pyagrum.Tensor
            Tensor to fill. Contains the number of transitions from any pair of states. Conditioned by ``Y`` and variables in ``inst_u``.
        """
        X_i = CIM.varI(X)
        X_j = CIM.varJ(X)
        for traj in self._trajectories():
            nextStates = self._nextStates(traj)
            prev_duration = 0
            duration = 0

            # store current values at duration
            u_values = dict()
            X_value = None
            Y_value = None

            # init (find initial values)
            for time, var, state in traj:
                if time != 0:
                    break
                if var == X:
                    X_value = state
                elif var == Y:
                    Y_value = state
                elif var in inst_u:
                    u_values[var] = state

            # loop over the transitions
            for idx, (time, var, _) in enumerate(traj):
                if time == 0:
                    continue

                prev_duration = duration
                duration = time

                ########## begin check
                if self._checkAllValues(u_values, inst_u):
                    # update T
                    inst1 = u_values.copy()
                    inst1[X_i] = X_value
                    Txu[inst1] += duration - prev_duration

                    inst2 = inst1.copy()
                    inst2[Y] = Y_value
                    Txyu[inst2] += duration - prev_duration

                    # update M
                    if time < self.timeHorizon and var == X:
                        nextv = nextStates[idx]
                        if nextv != X_value:
                            inst3 = u_values.copy()
                            inst3[X_i] = X_value
                            inst3[X_j] = nextv
                            inst3[Y] = Y_value
                            Mxyu[inst3] += 1
                ########## end check

                if time < self.timeHorizon and var == X:
                    X_value = nextStates[idx]
                elif time < self.timeHorizon and var == Y:
                    Y_value = nextStates[idx]
                elif time < self.timeHorizon and var in inst_u:
                    u_values[var] = nextStates[idx]

    def computeStatsForTests(self, X: str, Y: str, U: List[str]) -> Tuple[pyagrum.Tensor, pyagrum.Tensor, pyagrum.Tensor]:
        """
        Computes time spent and number of transitions values of ``X`` when conditioned by ``Y`` and ``U`` and
        returns them as ``pyagrum.Tensor``. Used for independence testing.

        All tensors are divided by the number of trajectories.

        Parameters
        ----------
        X : str
            Name of the variable.
        Y : str
            Name of a conditioning variable not in ``U``.
        U : List[str]
            List of conditioning variable's name.

        Returns
        -------
        Tuple[pyagrum.Tensor, pyagrum.Tensor, pyagrum.Tensor]
            The resulting tensors.
        """
        par = [self.ctbn.variable(nv) for nv in U]
        n = len(self.data)

        Txu = pyagrum.Tensor()
        Txyu = pyagrum.Tensor()
        Mxyu = pyagrum.Tensor()
        current = pyagrum.Instantiation()

        X_from = self.ctbn.CIM(X).findVar(CIM.varI(X))
        X_to = self.ctbn.CIM(X).findVar(CIM.varJ(X))
        varY = self.ctbn.variable(Y)

        Txu.add(X_from)
        Txyu.add(X_from)
        Txyu.add(varY)
        Mxyu.add(varY)
        Mxyu.add(X_from)
        Mxyu.add(X_to)

        for v in par:
            current.add(v)
            Txu.add(v)
            Txyu.add(v)
            Mxyu.add(v)

        Txu.fillWith(0)
        Txyu.fillWith(0)
        Mxyu.fillWith(0)

        # looping over all possible instances
        current.setFirst()
        while not current.end():
            currentLabels = current.todict(withLabels=True)
            inst_par = {vname: currentLabels[vname] for vname in U}
            self.setStatsForTests(X, Y, inst_par, Txu, Txyu, Mxyu)
            current.inc()

        # putFirst returns a new tensor: its result must be kept (it was previously discarded for Txu/Txyu)
        Txu = Txu.putFirst(X_from.name())
        Txyu = Txyu.putFirst(X_from.name())
        Mxyu = Mxyu.putFirst(X_to.name())
        return (Txu / n, Txyu / n, Mxyu / n)
|
|
632
|
+
|
|
633
|
+
|
|
634
|
+
class Stats:
    """
    Stores all tensors used for learning.

    Parameters
    ----------
    trajectory : Trajectory
        Samples used to find stats.
    X : str
        Name of the variable to study.
    Y : str
        Name of the variable used for conditioning variable ``X``.
    par : List[str]
        List of conditioning variables of ``X``.

    Attributes
    ----------
    Mxy : pyagrum.Tensor
        Tensor containing the number of transitions the variable ``X`` does from any
        of its states for any instance of its parents and of variable ``Y``.
    Mx : pyagrum.Tensor
        Tensor containing the number of transitions the variable ``X`` does from any
        of its states for any instance of its parents (``Mxy`` with ``Y`` summed out).
    Tx : pyagrum.Tensor
        Tensor containing the time spent by ``X`` to transition from a state to another for any instance of its parents.
    Txy : pyagrum.Tensor
        Tensor containing the time spent by ``X`` to transition from a state to another for any instance of its parents and of ``Y``.
    Qx : pyagrum.Tensor
        Conditional Intensity Matrix (CIM) of ``X``.
    Qxy : pyagrum.Tensor
        Conditional Intensity Matrix (CIM) of ``X`` that includes the conditioning variable ``Y``.
    """

    def __init__(self, trajectory: Trajectory, X: str, Y: str, par: List[str]):
        self.Tx, self.Txy, self.Mxy = trajectory.computeStatsForTests(X, Y, par)
        # Mx: a copy of Mxy with Y summed out
        self.Mx = pyagrum.Tensor(self.Mxy).sumOut([Y])
        self.Qx = computeCIMFromStats(X, self.Mx, self.Tx)
        self.Qxy = computeCIMFromStats(X, self.Mxy, self.Txy)
|