pyAgrum-nightly 2.3.0.9.dev202512061764412981__cp310-abi3-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyagrum/__init__.py +165 -0
- pyagrum/_pyagrum.so +0 -0
- pyagrum/bnmixture/BNMInference.py +268 -0
- pyagrum/bnmixture/BNMLearning.py +376 -0
- pyagrum/bnmixture/BNMixture.py +464 -0
- pyagrum/bnmixture/__init__.py +60 -0
- pyagrum/bnmixture/notebook.py +1058 -0
- pyagrum/causal/_CausalFormula.py +280 -0
- pyagrum/causal/_CausalModel.py +436 -0
- pyagrum/causal/__init__.py +81 -0
- pyagrum/causal/_causalImpact.py +356 -0
- pyagrum/causal/_dSeparation.py +598 -0
- pyagrum/causal/_doAST.py +761 -0
- pyagrum/causal/_doCalculus.py +361 -0
- pyagrum/causal/_doorCriteria.py +374 -0
- pyagrum/causal/_exceptions.py +95 -0
- pyagrum/causal/_types.py +61 -0
- pyagrum/causal/causalEffectEstimation/_CausalEffectEstimation.py +1175 -0
- pyagrum/causal/causalEffectEstimation/_IVEstimators.py +718 -0
- pyagrum/causal/causalEffectEstimation/_RCTEstimators.py +132 -0
- pyagrum/causal/causalEffectEstimation/__init__.py +46 -0
- pyagrum/causal/causalEffectEstimation/_backdoorEstimators.py +774 -0
- pyagrum/causal/causalEffectEstimation/_causalBNEstimator.py +324 -0
- pyagrum/causal/causalEffectEstimation/_frontdoorEstimators.py +396 -0
- pyagrum/causal/causalEffectEstimation/_learners.py +118 -0
- pyagrum/causal/causalEffectEstimation/_utils.py +466 -0
- pyagrum/causal/notebook.py +171 -0
- pyagrum/clg/CLG.py +658 -0
- pyagrum/clg/GaussianVariable.py +111 -0
- pyagrum/clg/SEM.py +312 -0
- pyagrum/clg/__init__.py +63 -0
- pyagrum/clg/canonicalForm.py +408 -0
- pyagrum/clg/constants.py +54 -0
- pyagrum/clg/forwardSampling.py +202 -0
- pyagrum/clg/learning.py +776 -0
- pyagrum/clg/notebook.py +480 -0
- pyagrum/clg/variableElimination.py +271 -0
- pyagrum/common.py +60 -0
- pyagrum/config.py +319 -0
- pyagrum/ctbn/CIM.py +513 -0
- pyagrum/ctbn/CTBN.py +573 -0
- pyagrum/ctbn/CTBNGenerator.py +216 -0
- pyagrum/ctbn/CTBNInference.py +459 -0
- pyagrum/ctbn/CTBNLearner.py +161 -0
- pyagrum/ctbn/SamplesStats.py +671 -0
- pyagrum/ctbn/StatsIndepTest.py +355 -0
- pyagrum/ctbn/__init__.py +79 -0
- pyagrum/ctbn/constants.py +54 -0
- pyagrum/ctbn/notebook.py +264 -0
- pyagrum/defaults.ini +199 -0
- pyagrum/deprecated.py +95 -0
- pyagrum/explain/_ComputationCausal.py +75 -0
- pyagrum/explain/_ComputationConditional.py +48 -0
- pyagrum/explain/_ComputationMarginal.py +48 -0
- pyagrum/explain/_CustomShapleyCache.py +110 -0
- pyagrum/explain/_Explainer.py +176 -0
- pyagrum/explain/_Explanation.py +70 -0
- pyagrum/explain/_FIFOCache.py +54 -0
- pyagrum/explain/_ShallCausalValues.py +204 -0
- pyagrum/explain/_ShallConditionalValues.py +155 -0
- pyagrum/explain/_ShallMarginalValues.py +155 -0
- pyagrum/explain/_ShallValues.py +296 -0
- pyagrum/explain/_ShapCausalValues.py +208 -0
- pyagrum/explain/_ShapConditionalValues.py +126 -0
- pyagrum/explain/_ShapMarginalValues.py +191 -0
- pyagrum/explain/_ShapleyValues.py +298 -0
- pyagrum/explain/__init__.py +81 -0
- pyagrum/explain/_explGeneralizedMarkovBlanket.py +152 -0
- pyagrum/explain/_explIndependenceListForPairs.py +146 -0
- pyagrum/explain/_explInformationGraph.py +264 -0
- pyagrum/explain/notebook/__init__.py +54 -0
- pyagrum/explain/notebook/_bar.py +142 -0
- pyagrum/explain/notebook/_beeswarm.py +174 -0
- pyagrum/explain/notebook/_showShapValues.py +97 -0
- pyagrum/explain/notebook/_waterfall.py +220 -0
- pyagrum/explain/shapley.py +225 -0
- pyagrum/lib/__init__.py +46 -0
- pyagrum/lib/_colors.py +390 -0
- pyagrum/lib/bn2graph.py +299 -0
- pyagrum/lib/bn2roc.py +1026 -0
- pyagrum/lib/bn2scores.py +217 -0
- pyagrum/lib/bn_vs_bn.py +605 -0
- pyagrum/lib/cn2graph.py +305 -0
- pyagrum/lib/discreteTypeProcessor.py +1102 -0
- pyagrum/lib/discretizer.py +58 -0
- pyagrum/lib/dynamicBN.py +390 -0
- pyagrum/lib/explain.py +57 -0
- pyagrum/lib/export.py +84 -0
- pyagrum/lib/id2graph.py +258 -0
- pyagrum/lib/image.py +387 -0
- pyagrum/lib/ipython.py +307 -0
- pyagrum/lib/mrf2graph.py +471 -0
- pyagrum/lib/notebook.py +1821 -0
- pyagrum/lib/proba_histogram.py +552 -0
- pyagrum/lib/utils.py +138 -0
- pyagrum/pyagrum.py +31495 -0
- pyagrum/skbn/_MBCalcul.py +242 -0
- pyagrum/skbn/__init__.py +49 -0
- pyagrum/skbn/_learningMethods.py +282 -0
- pyagrum/skbn/_utils.py +297 -0
- pyagrum/skbn/bnclassifier.py +1014 -0
- pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/LICENSE.md +12 -0
- pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/LICENSES/LGPL-3.0-or-later.txt +304 -0
- pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/LICENSES/MIT.txt +18 -0
- pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/METADATA +145 -0
- pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/RECORD +107 -0
- pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/WHEEL +4 -0
pyagrum/causal/_doAST.py
ADDED
|
@@ -0,0 +1,761 @@
|
|
|
1
|
+
############################################################################
|
|
2
|
+
# This file is part of the aGrUM/pyAgrum library. #
|
|
3
|
+
# #
|
|
4
|
+
# Copyright (c) 2005-2025 by #
|
|
5
|
+
# - Pierre-Henri WUILLEMIN(_at_LIP6) #
|
|
6
|
+
# - Christophe GONZALES(_at_AMU) #
|
|
7
|
+
# #
|
|
8
|
+
# The aGrUM/pyAgrum library is free software; you can redistribute it #
|
|
9
|
+
# and/or modify it under the terms of either : #
|
|
10
|
+
# #
|
|
11
|
+
# - the GNU Lesser General Public License as published by #
|
|
12
|
+
# the Free Software Foundation, either version 3 of the License, #
|
|
13
|
+
# or (at your option) any later version, #
|
|
14
|
+
# - the MIT license (MIT), #
|
|
15
|
+
# - or both in dual license, as here. #
|
|
16
|
+
# #
|
|
17
|
+
# (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) #
|
|
18
|
+
# #
|
|
19
|
+
# This aGrUM/pyAgrum library is distributed in the hope that it will be #
|
|
20
|
+
# useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, #
|
|
21
|
+
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS #
|
|
22
|
+
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE #
|
|
23
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER #
|
|
24
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, #
|
|
25
|
+
# ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR #
|
|
26
|
+
# OTHER DEALINGS IN THE SOFTWARE. #
|
|
27
|
+
# #
|
|
28
|
+
# See LICENCES for more details. #
|
|
29
|
+
# #
|
|
30
|
+
# SPDX-FileCopyrightText: Copyright 2005-2025 #
|
|
31
|
+
# - Pierre-Henri WUILLEMIN(_at_LIP6) #
|
|
32
|
+
# - Christophe GONZALES(_at_AMU) #
|
|
33
|
+
# SPDX-License-Identifier: LGPL-3.0-or-later OR MIT #
|
|
34
|
+
# #
|
|
35
|
+
# Contact : info_at_agrum_dot_org #
|
|
36
|
+
# homepage : http://agrum.gitlab.io #
|
|
37
|
+
# gitlab : https://gitlab.com/agrumery/agrum #
|
|
38
|
+
# #
|
|
39
|
+
############################################################################
|
|
40
|
+
|
|
41
|
+
"""
|
|
42
|
+
This file defines the classes needed to represent the abstract syntax tree (AST) of a causal formula
|
|
43
|
+
"""
|
|
44
|
+
|
|
45
|
+
from collections import defaultdict
|
|
46
|
+
from typing import Union, Dict, Optional, Iterable, List
|
|
47
|
+
|
|
48
|
+
import pyagrum
|
|
49
|
+
from pyagrum.causal._types import NameSet
|
|
50
|
+
|
|
51
|
+
# pylint: disable=unused-import
|
|
52
|
+
import pyagrum.causal # for annotations
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class ASTtree:
    """
    Represents a generic node for the CausalFormula. The type of the node is registered in a string.

    Parameters
    ----------
    typ: str
      the type of the node (will be specified in concrete children classes).
    verbose: bool
      if True, add some messages
    """

    def __init__(self, typ: str, verbose=False):
        """
        Represents a generic node for the CausalFormula. The type of the node is registered in a string.

        Parameters
        ----------
        typ: str
          the type of the node (will be specified in concrete children classes).
        verbose: bool
          if True, add some messages
        """
        self._type = typ
        # indentation marker prepended to the lines of the children in __str__
        self.__continueNextLine = "| "
        self._verbose = verbose

    @property
    def _continueNextLine(self):
        """
        Returns
        -------
        str
          the marker added to the prefix of each nested level in the string representation
        """
        return self.__continueNextLine

    @property
    def type(self) -> str:
        """
        Returns
        -------
        str
          the type of the node
        """
        return self._type

    def __str__(self, prefix: str = "") -> str:
        """
        stringify a CausalFormula tree

        Parameters
        ----------
        prefix: str
          a prefix for each line of the string representation

        Returns
        -------
        str
          the string version of the tree
        """
        raise NotImplementedError

    def protectToLatex(self, nameOccur: Dict[str, int]) -> str:
        """
        Create a protected LaTeX representation of a ASTtree

        Parameters
        ----------
        nameOccur: Dict[str,int]
          the number of occurrence for each variable

        Returns
        -------
        str
          a protected version of LaTeX representation of the tree
        """
        raise NotImplementedError

    def fastToLatex(self, nameOccur: Dict[str, int]) -> str:
        """
        Internal virtual function to create a LaTeX representation of the ASTtree

        Parameters
        ----------
        nameOccur: Dict[str,int]
          the number of occurrence for each variable

        Returns
        -------
        str
          LaTeX representation of the tree
        """
        raise NotImplementedError

    def toLatex(self, nameOccur: Optional[Dict[str, int]] = None) -> str:
        """
        Create a LaTeX representation of a ASTtree

        Parameters
        ----------
        nameOccur: Dict[str,int] default=None
          the number of occurrence for each variable (a fresh ``defaultdict(int)`` is used when None)

        Returns
        -------
        str
          LaTeX representation of the tree
        """
        if nameOccur is None:
            nameOccur = defaultdict(int)
        return self.fastToLatex(nameOccur)

    @staticmethod
    def _latexCorrect(srcName: Union[str, Iterable[str]], nameOccur: Dict[str, int]) -> Union[str, Iterable[str]]:
        """
        Change the latex presentation of variable w.r.t the number of occurrence of this variable : for instance,
        add primes when necessary

        Parameters
        ----------
        srcName: str | Iterable[str]
          the name or an iterable containing a collection of names
        nameOccur: Dict[str,int]
          the dict that gives the number of occurrence for each variable (default value 0 if the variable
          is not a key in this dict)

        Returns
        -------
        str | Iterable[str]
          the corrected name or the sorted list of corrected names
        """

        def _primed(v: str) -> str:
            # .get() keeps the documented "0 if missing" contract for plain dicts too
            # (indexing would raise KeyError on anything but a defaultdict)
            nbr = max(0, nameOccur.get(v, 0) - 1)
            return v + ("'" * nbr)

        if isinstance(srcName, str):
            return _primed(srcName)

        return sorted(_primed(v) for v in srcName)

    def copy(self) -> "ASTtree":
        """
        Copy an CausalFormula tree

        Returns
        -------
        ASTtree
          the new causal tree
        """
        raise NotImplementedError

    def eval(self, contextual_bn: "pyagrum.BayesNet") -> "pyagrum.Tensor":
        """
        Evaluation of a AST tree from inside a BN

        Parameters
        ----------
        contextual_bn: pyagrum.BayesNet
          the observational Bayesian network in which will be done the computations

        Returns
        -------
        pyagrum.Tensor
          the resulting Tensor
        """
        raise NotImplementedError
|
|
217
|
+
|
|
218
|
+
|
|
219
|
+
class ASTBinaryOp(ASTtree):
    """
    Generic binary node of a CausalFormula : an operator type and its two operands op1 and op2.

    Parameters
    ----------
    typ: str
      the type of the node (given by the concrete children classes)
    op1: ASTtree
      left operand
    op2: ASTtree
      right operand
    """

    def __init__(self, typ: str, op1: ASTtree, op2: ASTtree):
        """
        Generic binary node of a CausalFormula : an operator type and its two operands op1 and op2.

        Parameters
        ----------
        typ: str
          the type of the node (given by the concrete children classes)
        op1: ASTtree
          left operand
        op2: ASTtree
          right operand
        """
        super().__init__(typ)
        self._op1: ASTtree = op1
        self._op2: ASTtree = op2

    @property
    def op1(self) -> ASTtree:
        """
        Returns
        -------
        ASTtree
          the left operand
        """
        return self._op1

    @property
    def op2(self) -> ASTtree:
        """
        Returns
        -------
        ASTtree
          the right operand
        """
        return self._op2

    def __str__(self, prefix: str = "") -> str:
        # one line for the operator, then each operand on its own line(s), one level deeper
        nested = prefix + self._continueNextLine
        left = self.op1.__str__(nested)
        right = self.op2.__str__(nested)
        return f"{prefix}{self.type}\n{left}\n{right}"

    # the following operations depend on the operator : they are left to the concrete classes

    def copy(self) -> "ASTtree":
        raise NotImplementedError

    def eval(self, contextual_bn: "pyagrum.BayesNet") -> "pyagrum.Tensor":
        raise NotImplementedError

    def fastToLatex(self, nameOccur: Dict[str, int]) -> str:
        raise NotImplementedError

    def protectToLatex(self, nameOccur: Dict[str, int]) -> str:
        raise NotImplementedError
|
|
285
|
+
|
|
286
|
+
|
|
287
|
+
class ASTplus(ASTBinaryOp):
    """
    Node for the sum of 2 :class:`causal.ASTtree`

    Parameters
    ----------
    op1: ASTtree
      left operand
    op2: ASTtree
      right operand
    """

    def __init__(self, op1: ASTtree, op2: ASTtree):
        """
        Node for the sum of 2 :class:`causal.ASTtree`

        Parameters
        ----------
        op1: ASTtree
          left operand
        op2: ASTtree
          right operand
        """
        super().__init__("+", op1, op2)

    def copy(self) -> "ASTtree":
        left = self.op1.copy()
        right = self.op2.copy()
        return ASTplus(left, right)

    def protectToLatex(self, nameOccur: Dict[str, int]) -> str:
        # a sum used as an operand is wrapped in parentheses
        return "\\left(" + self.fastToLatex(nameOccur) + "\\right)"

    def fastToLatex(self, nameOccur: Dict[str, int]) -> str:
        left = self.op1.fastToLatex(nameOccur)
        right = self.op2.fastToLatex(nameOccur)
        return f"{left}+{right}"

    def eval(self, contextual_bn: "pyagrum.BayesNet") -> "pyagrum.Tensor":
        if self._verbose:
            print("EVAL operation + ", flush=True)

        left = self.op1.eval(contextual_bn)
        right = self.op2.eval(contextual_bn)
        res = left + right

        if self._verbose:
            print(f"END OF EVAL operation : {res}", flush=True)

        return res
|
|
330
|
+
|
|
331
|
+
|
|
332
|
+
class ASTminus(ASTBinaryOp):
    """
    Represents the subtraction of 2 :class:`causal.ASTtree`

    Parameters
    ----------
    op1: ASTtree
      left operand
    op2: ASTtree
      right operand
    """

    def __init__(self, op1: ASTtree, op2: ASTtree):
        """
        Represents the subtraction of 2 :class:`causal.ASTtree`

        Parameters
        ----------
        op1: ASTtree
          left operand
        op2: ASTtree
          right operand
        """
        super().__init__("-", op1, op2)

    def copy(self) -> "ASTtree":
        return ASTminus(self.op1.copy(), self.op2.copy())

    def protectToLatex(self, nameOccur: Dict[str, int]) -> str:
        # a difference used as an operand is wrapped in parentheses
        return "\\left(" + self.fastToLatex(nameOccur) + "\\right)"

    def fastToLatex(self, nameOccur: Dict[str, int]) -> str:
        # the right operand must be protected : subtraction is not associative, so
        # a-(b+c) or a-(b-c) would otherwise be rendered as a-b+c / a-b-c
        return self.op1.fastToLatex(nameOccur) + "-" + self.op2.protectToLatex(nameOccur)

    def eval(self, contextual_bn: "pyagrum.BayesNet") -> "pyagrum.Tensor":
        if self._verbose:
            print("EVAL operation", flush=True)
        res = self.op1.eval(contextual_bn) - self.op2.eval(contextual_bn)

        if self._verbose:
            print(f"END OF EVAL operation : {res}", flush=True)

        return res
|
|
375
|
+
|
|
376
|
+
|
|
377
|
+
class ASTmult(ASTBinaryOp):
    """
    Node for the multiplication of 2 :class:`causal.ASTtree`

    Parameters
    ----------
    op1: ASTtree
      left operand
    op2: ASTtree
      right operand
    """

    def __init__(self, op1: ASTtree, op2: ASTtree):
        """
        Node for the multiplication of 2 :class:`causal.ASTtree`

        Parameters
        ----------
        op1: ASTtree
          left operand
        op2: ASTtree
          right operand
        """
        super().__init__("*", op1, op2)

    def copy(self) -> "ASTtree":
        left = self.op1.copy()
        right = self.op2.copy()
        return ASTmult(left, right)

    def protectToLatex(self, nameOccur: Dict[str, int]) -> str:
        # a product does not need extra parentheses when used as an operand
        return self.fastToLatex(nameOccur)

    def fastToLatex(self, nameOccur: Dict[str, int]) -> str:
        # both factors are rendered in their protected form
        left = self.op1.protectToLatex(nameOccur)
        right = self.op2.protectToLatex(nameOccur)
        return f"{left} \\cdot {right}"

    def eval(self, contextual_bn: "pyagrum.BayesNet") -> "pyagrum.Tensor":
        if self._verbose:
            print("EVAL operation * in context", flush=True)

        left = self.op1.eval(contextual_bn)
        right = self.op2.eval(contextual_bn)
        res = left * right

        if self._verbose:
            print(f"END OF EVAL operation * : {res}", flush=True)

        return res
|
|
420
|
+
|
|
421
|
+
|
|
422
|
+
class ASTdiv(ASTBinaryOp):
    """
    Represents the division of 2 :class:`causal.ASTtree`

    Parameters
    ----------
    op1: ASTtree
      left operand (numerator)
    op2: ASTtree
      right operand (denominator)
    """

    def __init__(self, op1: ASTtree, op2: ASTtree):
        """
        Represents the division of 2 :class:`causal.ASTtree`

        Parameters
        ----------
        op1: ASTtree
          left operand (numerator)
        op2: ASTtree
          right operand (denominator)
        """
        super().__init__("/", op1, op2)

    def copy(self) -> "ASTtree":
        # copy() takes no argument : each operand is copied on its own and a new node is built
        return ASTdiv(self.op1.copy(), self.op2.copy())

    def protectToLatex(self, nameOccur: Dict[str, int]) -> str:
        # \frac already groups its operands : no extra parentheses needed
        return self.fastToLatex(nameOccur)

    def fastToLatex(self, nameOccur: Dict[str, int]) -> str:
        return " \\frac {" + self.op1.fastToLatex(nameOccur) + "}{" + self.op2.fastToLatex(nameOccur) + "}"

    def eval(self, contextual_bn: "pyagrum.BayesNet") -> "pyagrum.Tensor":
        if self._verbose:
            print("EVAL operation / in context", flush=True)
        res = self.op1.eval(contextual_bn) / self.op2.eval(contextual_bn)

        if self._verbose:
            print(f"END OF EVAL operation / : {res}", flush=True)

        return res
|
|
465
|
+
|
|
466
|
+
|
|
467
|
+
class ASTposteriorProba(ASTtree):
    """
    Represent a conditional probability :math:`P_{bn}(vars|knw)` that can be computed by an inference in a BN.

    Parameters
    ----------
    bn: pyagrum.BayesNet
      the :class:`pyAgrum:pyagrum.BayesNet`
    varset: Set[str]
      a set of variable names (in the BN) conditioned in the posterior
    knw: Set[str]
      a set of variable names (in the BN) conditioning in the posterior
    """

    def __init__(self, bn: "pyagrum.BayesNet", varset: NameSet, knw: NameSet):
        """
        Represent a conditional probability :math:`P_{bn}(vars|knw)` that can be computed by an inference in a BN.

        Parameters
        ----------
        bn: pyagrum.BayesNet
          the :class:`pyAgrum:pyagrum.BayesNet`
        varset: Set[str]
          a set of variable names (in the BN) conditioned in the posterior
        knw: Set[str]
          a set of variable names (in the BN) conditioning in the posterior

        Raises
        ------
        ValueError
          if varset or knw is not a set
        """
        super().__init__("_posterior_")
        if not isinstance(varset, set):
            raise ValueError("'varset' must be a set")
        if not isinstance(knw, set):
            raise ValueError("'knw' must be a set")

        self._vars = varset
        self._bn = bn
        # the stored conditioning set is the one returned by bn.minimalCondSet(varset, knw),
        # translated from node ids to names
        minKnames = {bn.variable(i).name() for i in bn.minimalCondSet(varset, knw)}
        self._knw = minKnames

    @property
    def vars(self) -> NameSet:
        """
        Returns
        -------
        Set[str]
          (Conditioned) vars in :math:`P_{bn}(vars|knw)`
        """
        return self._vars

    @property
    def knw(self) -> NameSet:
        """
        Returns
        -------
        Set[str]
          (Conditioning) knw in :math:`P_{bn}(vars|knw)`
        """
        return self._knw

    @property
    def bn(self) -> "pyagrum.BayesNet":
        """
        Returns
        -------
        pyagrum.BayesNet
          the observational BayesNet in :math:`P_{bn}(vars|knw)`
        """
        return self._bn

    def __str__(self, prefix: str = "") -> str:
        s = "P("
        s += ",".join(sorted(self.vars))
        # NOTE(review): self.knw is always a set here, so the "|" is written even when the
        # conditioning set is empty ("P(x|)"); kept as is since callers may rely on this text
        if self.knw is not None:
            s += "|"
            s += ",".join(sorted(self.knw))
        s += ")"
        return f"{prefix}{s}"

    def protectToLatex(self, nameOccur: Dict[str, int]) -> str:
        return self.fastToLatex(nameOccur)

    def fastToLatex(self, nameOccur: Dict[str, int]) -> str:
        s = "P\\left(" + ",".join(self._latexCorrect(self.vars, nameOccur))
        # the conditioning bar is only written when there is something to condition on
        if self.knw is not None and len(self.knw) > 0:
            s += "\\mid "
            s += ",".join(self._latexCorrect(self.knw, nameOccur))

        s += "\\right)"

        return s

    def copy(self) -> "ASTtree":
        return ASTposteriorProba(self.bn, self.vars, self.knw)

    def eval(self, contextual_bn: "pyagrum.BayesNet") -> "pyagrum.Tensor":
        """
        Evaluate :math:`P(vars|knw)` in the given Bayesian network

        Parameters
        ----------
        contextual_bn: pyagrum.BayesNet
          the observational Bayesian network in which will be done the computations

        Returns
        -------
        pyagrum.Tensor
          the CPT of the variable when vars is a single variable whose parents in contextual_bn are
          exactly knw; otherwise the result of a LazyPropagation inference
        """
        if self._verbose:
            print(f"EVAL ${self.fastToLatex(defaultdict(int))}$ in context", flush=True)
        p = None

        # simple case : we just need a CPT from the BN
        if len(self.vars) == 1:
            x = next(iter(self.vars))  # the first one and only one
            ix = contextual_bn.idFromName(x)
            parentNames = {contextual_bn.variable(i).name() for i in contextual_bn.parents(ix)}
            if parentNames == self.knw:
                p = contextual_bn.cpt(ix)

        if p is None:
            # the inference engine is only built when the CPT shortcut did not apply
            ie = pyagrum.LazyPropagation(contextual_bn)
            if len(self.knw) == 0:
                ie.addJointTarget(self.vars)
                ie.makeInference()
                p = ie.jointPosterior(self.vars)
            else:
                # P(vars|knw) = P(vars,knw) / P(knw)
                ie.addJointTarget(self.vars | self.knw)
                ie.makeInference()
                p = ie.jointPosterior(self.vars | self.knw) / ie.jointPosterior(self.knw)

        if self._verbose:
            print(f"END OF EVAL ${self.fastToLatex(defaultdict(int))}$ : {p}", flush=True)

        return p
|
|
591
|
+
|
|
592
|
+
|
|
593
|
+
class ASTjointProba(ASTtree):
    """
    Represent a joint probability in the base observational part of the :class:`causal.CausalModel`

    Parameters
    ----------
    varNames: Set[str]
      a set of variable names
    """

    def __init__(self, varNames: NameSet):
        """
        Represent a joint probability in the base observational part of the :class:`causal.CausalModel`

        Parameters
        ----------
        varNames: Set[str]
          a set of variable names
        """
        super().__init__("_joint_")
        self._varNames = varNames

    @property
    def varNames(self) -> NameSet:
        """
        Returns
        -------
        Set[str]
          the set of names of var
        """
        return self._varNames

    def __str__(self, prefix: str = "") -> str:
        s = "P("
        s += ",".join(sorted(self._varNames))
        s += ")"
        return f"{prefix}joint {s}"

    def copy(self) -> "ASTtree":
        return ASTjointProba(self.varNames)

    def protectToLatex(self, nameOccur: Dict[str, int]) -> str:
        return self.fastToLatex(nameOccur)

    def fastToLatex(self, nameOccur: Dict[str, int]) -> str:
        return "P\\left(" + ",".join(self._latexCorrect(self.varNames, nameOccur)) + "\\right)"

    def eval(self, contextual_bn: "pyagrum.BayesNet") -> "pyagrum.Tensor":
        """
        Evaluate the joint probability of varNames in the given Bayesian network

        Parameters
        ----------
        contextual_bn: pyagrum.BayesNet
          the observational Bayesian network in which will be done the computations

        Returns
        -------
        pyagrum.Tensor
          the joint posterior of varNames (the posterior of the variable if there is only one)

        Raises
        ------
        ValueError
          if varNames is empty
        """
        # fail early and explicitly : an empty set used to end in an unbound local variable
        if len(self.varNames) == 0:
            raise ValueError("'varNames' must contain at least one variable name")

        if self._verbose:
            print(f"EVAL ${self.fastToLatex(defaultdict(int))}$ in context", flush=True)
        ie = pyagrum.LazyPropagation(contextual_bn)
        if len(self.varNames) > 1:
            svars = set(self.varNames)
            ie.addJointTarget(svars)
            ie.makeInference()
            res = ie.jointPosterior(svars)
        else:
            name = next(iter(self.varNames))  # the first and only one name in varNames
            ie.makeInference()
            res = ie.posterior(name)

        if self._verbose:
            print(f"END OF EVAL ${self.fastToLatex(defaultdict(int))}$ : {res}", flush=True)

        return res
|
|
659
|
+
|
|
660
|
+
|
|
661
|
+
class ASTsum(ASTtree):
    """
    Represents a sum over a variable of a :class:`causal.ASTtree`.

    Parameters
    ----------
    var: str | List[str]
      name of the variable on which to sum (a list of names builds one nested sum per name)
    term: ASTtree
      the tree to be evaluated
    """

    def __init__(self, var: Union[str, List[str]], term: ASTtree):
        """
        Represents a sum over a variable of a :class:`causal.ASTtree`.

        Parameters
        ----------
        var: str | List[str]
          name of the variable on which to sum (a list of names builds one nested sum per name)
        term: ASTtree
          the tree to be evaluated
        """
        super().__init__("_sum_")

        va = var if isinstance(var, list) else [var]
        self.var = va[0]

        # this node sums on the first name; the remaining names go to a nested ASTsum
        if len(va) > 1:
            self._term = ASTsum(va[1:], term)
        else:
            self._term = term

    @property
    def term(self) -> ASTtree:
        """
        Returns
        -------
        ASTtree
          the term to sum
        """
        return self._term

    def __str__(self, prefix: str = "") -> str:
        # consecutive nested sums are displayed as a single sum on several variables
        names = []
        a = self
        while a.type == "_sum_":
            names.append(a.var)
            a = a.term
        return f"{prefix}sum on {','.join(sorted(names))} for\n{a.__str__(prefix + self._continueNextLine)}"

    def copy(self) -> "ASTtree":
        return ASTsum(self.var, self.term.copy())

    def protectToLatex(self, nameOccur: Dict[str, int]) -> str:
        return "\\left(" + self.fastToLatex(nameOccur) + "\\right)"

    def fastToLatex(self, nameOccur: Dict[str, int]) -> str:
        # consecutive nested sums are merged into a single \sum on several variables
        la = []
        a = self
        while a.type == "_sum_":
            la.append(a.var)
            a = a.term

        # one more occurrence of each summed variable while the inner term is rendered;
        # .get() creates the counter when nameOccur is a plain dict without this key
        for v in la:
            nameOccur[v] = nameOccur.get(v, 0) + 1

        try:
            res = "\\sum_{" + (",".join(self._latexCorrect(la, nameOccur))) + "}{" + a.fastToLatex(nameOccur) + "}"
        finally:
            # the counters are restored even if the rendering of the inner term fails
            for v in la:
                nameOccur[v] -= 1

        return res

    def eval(self, contextual_bn: "pyagrum.BayesNet") -> "pyagrum.Tensor":
        if self._verbose:
            print(f"EVAL ${self.fastToLatex(defaultdict(int))}$", flush=True)

        res = self.term.eval(contextual_bn).sumOut([self.var])

        if self._verbose:
            print(f"END OF EVAL ${self.fastToLatex(defaultdict(int))}$ : {res}", flush=True)

        return res
|
|
742
|
+
|
|
743
|
+
|
|
744
|
+
def productOfTrees(lterms: List[ASTtree]) -> ASTtree:
    """
    create an ASTtree for a sequence of multiplications of ASTtree

    Parameters
    ----------
    lterms: List[ASTtree]
      the trees (as ASTtree) to multiply

    Returns
    -------
    ASTtree
      the ASTtree representing the tree of multiplications : the single term itself if there is only
      one, otherwise the right-nested product ``t0 * (t1 * (... * tn))``

    Raises
    ------
    IndexError
      if lterms is empty
    """
    if len(lterms) == 0:
        raise IndexError("productOfTrees needs at least one term")

    # built from the last term backwards : same right-nested tree as a recursive construction,
    # without copying the list at each step nor being bounded by the recursion limit
    res = lterms[-1]
    for term in reversed(lterms[:-1]):
        res = ASTmult(term, res)
    return res
|