pyAgrum-nightly 2.3.0.9.dev202512061764412981__cp310-abi3-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyagrum/__init__.py +165 -0
- pyagrum/_pyagrum.so +0 -0
- pyagrum/bnmixture/BNMInference.py +268 -0
- pyagrum/bnmixture/BNMLearning.py +376 -0
- pyagrum/bnmixture/BNMixture.py +464 -0
- pyagrum/bnmixture/__init__.py +60 -0
- pyagrum/bnmixture/notebook.py +1058 -0
- pyagrum/causal/_CausalFormula.py +280 -0
- pyagrum/causal/_CausalModel.py +436 -0
- pyagrum/causal/__init__.py +81 -0
- pyagrum/causal/_causalImpact.py +356 -0
- pyagrum/causal/_dSeparation.py +598 -0
- pyagrum/causal/_doAST.py +761 -0
- pyagrum/causal/_doCalculus.py +361 -0
- pyagrum/causal/_doorCriteria.py +374 -0
- pyagrum/causal/_exceptions.py +95 -0
- pyagrum/causal/_types.py +61 -0
- pyagrum/causal/causalEffectEstimation/_CausalEffectEstimation.py +1175 -0
- pyagrum/causal/causalEffectEstimation/_IVEstimators.py +718 -0
- pyagrum/causal/causalEffectEstimation/_RCTEstimators.py +132 -0
- pyagrum/causal/causalEffectEstimation/__init__.py +46 -0
- pyagrum/causal/causalEffectEstimation/_backdoorEstimators.py +774 -0
- pyagrum/causal/causalEffectEstimation/_causalBNEstimator.py +324 -0
- pyagrum/causal/causalEffectEstimation/_frontdoorEstimators.py +396 -0
- pyagrum/causal/causalEffectEstimation/_learners.py +118 -0
- pyagrum/causal/causalEffectEstimation/_utils.py +466 -0
- pyagrum/causal/notebook.py +171 -0
- pyagrum/clg/CLG.py +658 -0
- pyagrum/clg/GaussianVariable.py +111 -0
- pyagrum/clg/SEM.py +312 -0
- pyagrum/clg/__init__.py +63 -0
- pyagrum/clg/canonicalForm.py +408 -0
- pyagrum/clg/constants.py +54 -0
- pyagrum/clg/forwardSampling.py +202 -0
- pyagrum/clg/learning.py +776 -0
- pyagrum/clg/notebook.py +480 -0
- pyagrum/clg/variableElimination.py +271 -0
- pyagrum/common.py +60 -0
- pyagrum/config.py +319 -0
- pyagrum/ctbn/CIM.py +513 -0
- pyagrum/ctbn/CTBN.py +573 -0
- pyagrum/ctbn/CTBNGenerator.py +216 -0
- pyagrum/ctbn/CTBNInference.py +459 -0
- pyagrum/ctbn/CTBNLearner.py +161 -0
- pyagrum/ctbn/SamplesStats.py +671 -0
- pyagrum/ctbn/StatsIndepTest.py +355 -0
- pyagrum/ctbn/__init__.py +79 -0
- pyagrum/ctbn/constants.py +54 -0
- pyagrum/ctbn/notebook.py +264 -0
- pyagrum/defaults.ini +199 -0
- pyagrum/deprecated.py +95 -0
- pyagrum/explain/_ComputationCausal.py +75 -0
- pyagrum/explain/_ComputationConditional.py +48 -0
- pyagrum/explain/_ComputationMarginal.py +48 -0
- pyagrum/explain/_CustomShapleyCache.py +110 -0
- pyagrum/explain/_Explainer.py +176 -0
- pyagrum/explain/_Explanation.py +70 -0
- pyagrum/explain/_FIFOCache.py +54 -0
- pyagrum/explain/_ShallCausalValues.py +204 -0
- pyagrum/explain/_ShallConditionalValues.py +155 -0
- pyagrum/explain/_ShallMarginalValues.py +155 -0
- pyagrum/explain/_ShallValues.py +296 -0
- pyagrum/explain/_ShapCausalValues.py +208 -0
- pyagrum/explain/_ShapConditionalValues.py +126 -0
- pyagrum/explain/_ShapMarginalValues.py +191 -0
- pyagrum/explain/_ShapleyValues.py +298 -0
- pyagrum/explain/__init__.py +81 -0
- pyagrum/explain/_explGeneralizedMarkovBlanket.py +152 -0
- pyagrum/explain/_explIndependenceListForPairs.py +146 -0
- pyagrum/explain/_explInformationGraph.py +264 -0
- pyagrum/explain/notebook/__init__.py +54 -0
- pyagrum/explain/notebook/_bar.py +142 -0
- pyagrum/explain/notebook/_beeswarm.py +174 -0
- pyagrum/explain/notebook/_showShapValues.py +97 -0
- pyagrum/explain/notebook/_waterfall.py +220 -0
- pyagrum/explain/shapley.py +225 -0
- pyagrum/lib/__init__.py +46 -0
- pyagrum/lib/_colors.py +390 -0
- pyagrum/lib/bn2graph.py +299 -0
- pyagrum/lib/bn2roc.py +1026 -0
- pyagrum/lib/bn2scores.py +217 -0
- pyagrum/lib/bn_vs_bn.py +605 -0
- pyagrum/lib/cn2graph.py +305 -0
- pyagrum/lib/discreteTypeProcessor.py +1102 -0
- pyagrum/lib/discretizer.py +58 -0
- pyagrum/lib/dynamicBN.py +390 -0
- pyagrum/lib/explain.py +57 -0
- pyagrum/lib/export.py +84 -0
- pyagrum/lib/id2graph.py +258 -0
- pyagrum/lib/image.py +387 -0
- pyagrum/lib/ipython.py +307 -0
- pyagrum/lib/mrf2graph.py +471 -0
- pyagrum/lib/notebook.py +1821 -0
- pyagrum/lib/proba_histogram.py +552 -0
- pyagrum/lib/utils.py +138 -0
- pyagrum/pyagrum.py +31495 -0
- pyagrum/skbn/_MBCalcul.py +242 -0
- pyagrum/skbn/__init__.py +49 -0
- pyagrum/skbn/_learningMethods.py +282 -0
- pyagrum/skbn/_utils.py +297 -0
- pyagrum/skbn/bnclassifier.py +1014 -0
- pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/LICENSE.md +12 -0
- pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/LICENSES/LGPL-3.0-or-later.txt +304 -0
- pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/LICENSES/MIT.txt +18 -0
- pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/METADATA +145 -0
- pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/RECORD +107 -0
- pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/WHEEL +4 -0
pyagrum/ctbn/CTBN.py
ADDED
|
@@ -0,0 +1,573 @@
|
|
|
1
|
+
############################################################################
|
|
2
|
+
# This file is part of the aGrUM/pyAgrum library. #
|
|
3
|
+
# #
|
|
4
|
+
# Copyright (c) 2005-2025 by #
|
|
5
|
+
# - Pierre-Henri WUILLEMIN(_at_LIP6) #
|
|
6
|
+
# - Christophe GONZALES(_at_AMU) #
|
|
7
|
+
# #
|
|
8
|
+
# The aGrUM/pyAgrum library is free software; you can redistribute it #
|
|
9
|
+
# and/or modify it under the terms of either : #
|
|
10
|
+
# #
|
|
11
|
+
# - the GNU Lesser General Public License as published by #
|
|
12
|
+
# the Free Software Foundation, either version 3 of the License, #
|
|
13
|
+
# or (at your option) any later version, #
|
|
14
|
+
# - the MIT license (MIT), #
|
|
15
|
+
# - or both in dual license, as here. #
|
|
16
|
+
# #
|
|
17
|
+
# (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) #
|
|
18
|
+
# #
|
|
19
|
+
# This aGrUM/pyAgrum library is distributed in the hope that it will be #
|
|
20
|
+
# useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, #
|
|
21
|
+
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS #
|
|
22
|
+
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE #
|
|
23
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER #
|
|
24
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, #
|
|
25
|
+
# ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR #
|
|
26
|
+
# OTHER DEALINGS IN THE SOFTWARE. #
|
|
27
|
+
# #
|
|
28
|
+
# See LICENCES for more details. #
|
|
29
|
+
# #
|
|
30
|
+
# SPDX-FileCopyrightText: Copyright 2005-2025 #
|
|
31
|
+
# - Pierre-Henri WUILLEMIN(_at_LIP6) #
|
|
32
|
+
# - Christophe GONZALES(_at_AMU) #
|
|
33
|
+
# SPDX-License-Identifier: LGPL-3.0-or-later OR MIT #
|
|
34
|
+
# #
|
|
35
|
+
# Contact : info_at_agrum_dot_org #
|
|
36
|
+
# homepage : http://agrum.gitlab.io #
|
|
37
|
+
# gitlab : https://gitlab.com/agrumery/agrum #
|
|
38
|
+
# #
|
|
39
|
+
############################################################################
|
|
40
|
+
|
|
41
|
+
from typing import Dict, Tuple, List, Set
|
|
42
|
+
|
|
43
|
+
import pyagrum
|
|
44
|
+
|
|
45
|
+
from pyagrum.ctbn import CIM
|
|
46
|
+
from pyagrum.ctbn.constants import NodeId, NameOrId
|
|
47
|
+
|
|
48
|
+
import pyagrum.ctbn
|
|
49
|
+
|
|
50
|
+
"""
|
|
51
|
+
This file contains the CTBN class.
|
|
52
|
+
CTBN : Continuous-Time Bayesian Network
|
|
53
|
+
"""
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
class CTBN:
    """
    This class is used to represent a CTBN (Continuous-Time Bayesian Network).
    A CTBN is: a set of random variables, a CIM for each one of them and a pyagrum.DiGraph
    to represent dependency relations.

    Attributes
    ----------
    _graph : pyagrum.DiGraph
        Graph representing dependency relations between variables. Also used to link a variable with an id.
    _cim : Dict[NodeId, CIM]
        Dict containing a CIM for each NodeId (the integer given to a variable).
    _id2var : Dict[NodeId, pyagrum.DiscreteVariable]
        Dict containing the variable associated to a node id.
    _name2id : Dict[str, NodeId]
        Dict containing the NodeId associated to a variable's name.
    """

    # NOTE: these annotations are evaluated before the ``CIM`` method below is defined,
    # so ``CIM`` here still refers to the CIM class imported at module level.
    _graph: pyagrum.DiGraph
    _cim: Dict[NodeId, CIM]
    _id2var: Dict[NodeId, pyagrum.DiscreteVariable]
    _name2id: Dict[str, NodeId]

    def __init__(self):
        """Create an empty CTBN (no variable, no arc)."""
        self._graph = pyagrum.DiGraph()
        self._cim = {}
        self._id2var = {}
        self._name2id = {}

    def add(self, var: pyagrum.DiscreteVariable) -> NodeId:
        """
        Add a new variable to the CTBN.

        Parameters
        ----------
        var : pyagrum.DiscreteVariable
            The variable to add to the CTBN.

        Returns
        -------
        NodeId
            The id given to the variable.

        Raises
        ------
        NameError
            If the name of the variable is empty or rejected by ``CIM.isParent``,
            or if a variable with the same name already exists.
        ValueError
            If the variable is None.
        """
        if var is None:
            raise ValueError("The var cannot be None.")
        if var.name() == "" or not CIM.isParent(var):
            raise NameError(f"The name '{var.name()}' is not correct.")
        if var.name() in self._name2id:
            raise NameError(f"A variable with the same name ({var.name()}) already exists in this CTBN.")

        # Build the CIM first so that a failure here leaves the CTBN untouched.
        # Each CIM starts with two renamed copies of the variable (from/to states).
        v_i = var.clone()
        v_i.setName(CIM.varI(var.name()))
        v_j = var.clone()
        v_j.setName(CIM.varJ(var.name()))
        # v_j is added before v_i: variable(0) of the CIM is v_j, variable(1) is v_i
        # (assuming CIM.add appends in insertion order).
        cim = CIM().add(v_j).add(v_i)

        # link variable to its name and NodeId
        n = NodeId(self._graph.addNode())
        self._id2var[n] = var
        self._name2id[var.name()] = n
        self._cim[n] = cim

        return n

    def _nameOrId(self, val: NameOrId) -> NodeId:
        """
        Resolve a variable's name or id into its id.

        Parameters
        ----------
        val : NameOrId
            The name (str) or id (int) of a variable.

        Returns
        -------
        NodeId
            The id of the variable.

        Raises
        ------
        pyagrum.NotFound
            If the variable (by id or by name) isn't in the CTBN.
        """
        if isinstance(val, int):
            if val not in self._id2var:
                raise pyagrum.NotFound("the variable isn't in the ctbn")
            return val
        if val not in self._name2id:
            raise pyagrum.NotFound("the variable isn't in the ctbn")
        return self._name2id[val]

    def addArc(self, val1: NameOrId, val2: NameOrId) -> Tuple[NodeId, NodeId]:
        """
        Adds an arc ``val1`` -> ``val2``.

        Parameters
        ----------
        val1 : NameOrId
            The name or id of the first variable.
        val2 : NameOrId
            The name or id of the second variable.

        Returns
        -------
        Tuple[NodeId, NodeId]
            The created arc (``val1``, ``val2``).

        Raises
        ------
        pyagrum.NotFound
            If one the variables is not in the CTBN.
        """
        n1 = self._nameOrId(val1)
        n2 = self._nameOrId(val2)
        self._graph.addArc(n1, n2)

        # adding n1 as a parent of n2 in the CIM
        self._cim[n2].add(self._id2var[n1])

        return (n1, n2)

    def eraseArc(self, val1: NameOrId, val2: NameOrId):
        """
        Erases an arc from the graph.

        Parameters
        ----------
        val1 : NameOrId
            The name or id of the first variable.
        val2 : NameOrId
            The name or id of the second variable.

        Raises
        ------
        pyagrum.NotFound
            If a variable isn't in the CTBN.
        pyagrum.InvalidArgument
            If a variable isn't a parent in the CIM.
        """
        n1 = self._nameOrId(val1)
        n2 = self._nameOrId(val2)
        self._graph.eraseArc(n1, n2)
        # removing n1 from the parents of n2 in the CIM
        self._cim[n2].remove(self._id2var[n1])

    def name(self, node: NodeId) -> str:
        """
        Parameters
        ----------
        node : NodeId
            The id of the variable.

        Returns
        -------
        str
            The variable's name linked to the NodeId.

        Raises
        ------
        pyagrum.NotFound
            If the variable is not found in the CTBN.
        """
        if node not in self._id2var:
            raise pyagrum.NotFound("The node isn't in the ctbn")
        return self._id2var[node].name()

    def node(self, name: str) -> NodeId:
        """
        Parameters
        ----------
        name : str
            The name of the variable.

        Returns
        -------
        NodeId
            The id of the variable.

        Raises
        ------
        pyagrum.NotFound
            If the variable is not found in the CTBN.
        """
        if name not in self._name2id:
            raise pyagrum.NotFound("the variable isn't in the ctbn")
        return self._name2id[name]

    def labels(self, val: NameOrId) -> tuple:
        """
        Parameters
        ----------
        val : NameOrId
            The name or id of the variable.

        Returns
        -------
        tuple
            A tuple containing the labels of the variable.

        Raises
        ------
        pyagrum.NotFound
            If the variable is not found in the CTBN.
        """
        return self._id2var[self._nameOrId(val)].labels()

    def variable(self, val: NameOrId) -> "pyagrum.DiscreteVariable":
        """
        Parameters
        ----------
        val : NameOrId
            The name or id of the variable.

        Returns
        -------
        pyagrum.DiscreteVariable
            The corresponding variable.

        Raises
        ------
        pyagrum.NotFound
            If the variable is not found in the CTBN.
        """
        return self._id2var[self._nameOrId(val)]

    def variables(self) -> List[pyagrum.DiscreteVariable]:
        """
        Returns
        -------
        List[pyagrum.DiscreteVariable]
            The list of variables in the CTBN, in insertion order.
        """
        return list(self._id2var.values())

    def nodes(self) -> List[NodeId]:
        """
        Returns
        -------
        List[NodeId]
            The list of variable ids in the CTBN, in insertion order.
        """
        return list(self._id2var.keys())

    def names(self) -> List[str]:
        """
        Returns
        -------
        List[str]
            The list of variable names in the CTBN, in insertion order.
        """
        return list(self._name2id.keys())

    def arcs(self) -> Set[Tuple[NodeId, NodeId]]:
        """
        Returns
        -------
        Set[Tuple[NodeId, NodeId]]
            The set of arcs as a set of couple of NodeIds in the CTBN.
        """
        return self._graph.arcs()

    def parents(self, val: NameOrId) -> Set[NodeId]:
        """
        Parameters
        ----------
        val : NameOrId
            The variable's name or id.

        Returns
        -------
        Set[NodeId]
            A set containing the id of the variable's parents in the CTBN.

        Raises
        ------
        pyagrum.NotFound
            If the variable isn't found in the CTBN.
        """
        return self._graph.parents(self._nameOrId(val))

    def parentNames(self, val: NameOrId) -> List[str]:
        """
        Parameters
        ----------
        val : NameOrId
            The variable's name or id.

        Returns
        -------
        List[str]
            A list containing the names of the variable's parents.

        Raises
        ------
        pyagrum.NotFound
            If the variable isn't in the CTBN.
        """
        return [self.name(n) for n in self.parents(val)]

    def children(self, val: NameOrId) -> Set[NodeId]:
        """
        Parameters
        ----------
        val : NameOrId
            The variable's name or id.

        Returns
        -------
        Set[NodeId]
            A set containing the ids of the variable's children.

        Raises
        ------
        pyagrum.NotFound
            If the variable isn't in the CTBN.
        """
        return self._graph.children(self._nameOrId(val))

    def childrenNames(self, val: NameOrId) -> List[str]:
        """
        Parameters
        ----------
        val : NameOrId
            The variable's name or id.

        Returns
        -------
        List[str]
            A list containing the names of a variable's children.

        Raises
        ------
        pyagrum.NotFound
            If the variable isn't in the CTBN.
        """
        return [self.name(n) for n in self.children(val)]

    def CIM(self, val: NameOrId) -> CIM:
        """
        Parameters
        ----------
        val : NameOrId
            The variable's name or id.

        Returns
        -------
        CIM
            The variable's CIM.

        Raises
        ------
        pyagrum.NotFound
            If the variable isn't in the CTBN.
        """
        return self._cim[self._nameOrId(val)]

    # NOTE: from here on, ``CIM`` inside the class body names the method above,
    # so no annotation below refers to the CIM class.

    def completeInstantiation(self):
        """
        Returns
        -------
        pyagrum.Instantiation
            An instantiation of the variables in the CTBN.
        """
        res = pyagrum.Instantiation()
        for var in self._id2var.values():
            res.add(var)
        return res

    def fullInstantiation(self):
        """
        Returns
        -------
        pyagrum.Instantiation
            An instantiation of the variables in the CTBN including, for each of them,
            the corresponding starting and ending (i.e. from/to) variables of its CIM.
        """
        res = pyagrum.Instantiation()
        for nod, var in self._id2var.items():
            cim = self._cim[nod]
            res.add(var)
            # The first two variables of a CIM are the renamed copies created in add():
            # v_j is added first, then v_i.
            res.add(cim.variable(0))  # v_j
            res.add(cim.variable(1))  # v_i
        return res

    @staticmethod
    def _dotEscape(label) -> str:
        """Escape double quotes so ``label`` can be embedded in a DOT quoted string."""
        return str(label).replace('"', '\\"')

    def toDot(self) -> str:
        """
        Create a display of the graph representing the CTBN.

        Returns
        -------
        str
            A display of the graph, in the graphviz DOT language.
        """
        lines = [
            'digraph "ctbn" {',
            '  graph [bgcolor=transparent,label=""];',
            '  node [style=filled fillcolor="#ffffaa"];',
            "",
        ]

        # Add the name of the variables
        lines.extend(f'  "{self._dotEscape(nomVar)}";' for nomVar in self.names())

        # Adding arcs
        for arc in self._graph.arcs():
            nom1 = self._dotEscape(self.name(arc[0]))
            nom2 = self._dotEscape(self.name(arc[1]))
            lines.append(f'  "{nom1}"->"{nom2}";')

        lines.append("}")
        return "\n".join(lines)

    def equals(self, ctbn: "CTBN") -> bool:
        """
        Tests the topological equality with another CTBN.

        Two CTBNs are considered equal when they contain the same variable names and
        the same arcs between those names (node ids may differ).

        Parameters
        ----------
        ctbn : CTBN
            CTBN to test equality with.

        Returns
        -------
        bool
            True if they are equal, False if not.
        """
        # Names are unique keys, so set equality covers both size and membership.
        if set(self.names()) != set(ctbn.names()):
            return False

        arcs1 = self.arcs()
        # Checks the number of arcs
        if len(arcs1) != len(ctbn.arcs()):
            return False

        # Checks if all arcs from current CTBN are in the other one (matched by names,
        # since node ids may differ between the two CTBNs).
        return all(
            ctbn._graph.existsArc(ctbn._name2id[self.name(arc[0])], ctbn._name2id[self.name(arc[1])])
            for arc in arcs1
        )

    def _amalgamatedCIM(self):
        """
        Amalgamate the CIMs of all the variables of this CTBN into a single CIM.

        Raises
        ------
        IndexError
            If the CTBN has no variable.
        """
        cims = [self._cim[i] for i in self.nodes()]
        res = cims[0]
        for cim in cims[1:]:
            res = res.amalgamate(cim)
        return res

    def _compareCIM(self, ctbn: "CTBN") -> float:
        """
        Compute the relative equality of the CIMs with another CTBN's CIMs.

        Parameters
        ----------
        ctbn : CTBN
            CTBN to compare with.

        Returns
        -------
        float
            The maximum squared difference between the amalgamated CIMs of both CTBNs.

        Raises
        ------
        IndexError
            If one of the CTBNs has no variable.
        """
        P = self._amalgamatedCIM()._pot
        Q = ctbn._amalgamatedCIM()._pot

        # Copy P's structure, filled with Q's values, so both tensors can be subtracted.
        P1 = pyagrum.Tensor(P)
        P1.fillWith(Q)

        return (P - P1).sq().max()

    def __getstate__(self):
        """Serialize the CTBN as plain python data (used by pickle)."""
        state = {
            "nodes": [self.variable(i).toFast() for i in self.nodes()],
            # The first two variables of each CIM are the node's own from/to copies
            # (#j then #i, see add()); the remaining ones are its parents.
            "parents": {self.variable(i).name(): list(self.CIM(i).varNames)[2:] for i in self.nodes()},
            "cim": {self.variable(i).name(): self.CIM(i)._pot[:].flatten().tolist() for i in self.nodes()},
        }
        return state

    def __setstate__(self, state):
        """Rebuild the CTBN from the data produced by ``__getstate__`` (used by pickle)."""
        self.__init__()
        for node in state["nodes"]:
            self.add(pyagrum.fastVariable(node))
        for node, parents in state["parents"].items():
            for parent in parents:
                self.addArc(parent, node)
        for node, cim in state["cim"].items():
            self.CIM(node)._pot.fillWith(cim)
        return self
|