pyAgrum-nightly 2.3.0.9.dev202512061764412981__cp310-abi3-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. pyagrum/__init__.py +165 -0
  2. pyagrum/_pyagrum.so +0 -0
  3. pyagrum/bnmixture/BNMInference.py +268 -0
  4. pyagrum/bnmixture/BNMLearning.py +376 -0
  5. pyagrum/bnmixture/BNMixture.py +464 -0
  6. pyagrum/bnmixture/__init__.py +60 -0
  7. pyagrum/bnmixture/notebook.py +1058 -0
  8. pyagrum/causal/_CausalFormula.py +280 -0
  9. pyagrum/causal/_CausalModel.py +436 -0
  10. pyagrum/causal/__init__.py +81 -0
  11. pyagrum/causal/_causalImpact.py +356 -0
  12. pyagrum/causal/_dSeparation.py +598 -0
  13. pyagrum/causal/_doAST.py +761 -0
  14. pyagrum/causal/_doCalculus.py +361 -0
  15. pyagrum/causal/_doorCriteria.py +374 -0
  16. pyagrum/causal/_exceptions.py +95 -0
  17. pyagrum/causal/_types.py +61 -0
  18. pyagrum/causal/causalEffectEstimation/_CausalEffectEstimation.py +1175 -0
  19. pyagrum/causal/causalEffectEstimation/_IVEstimators.py +718 -0
  20. pyagrum/causal/causalEffectEstimation/_RCTEstimators.py +132 -0
  21. pyagrum/causal/causalEffectEstimation/__init__.py +46 -0
  22. pyagrum/causal/causalEffectEstimation/_backdoorEstimators.py +774 -0
  23. pyagrum/causal/causalEffectEstimation/_causalBNEstimator.py +324 -0
  24. pyagrum/causal/causalEffectEstimation/_frontdoorEstimators.py +396 -0
  25. pyagrum/causal/causalEffectEstimation/_learners.py +118 -0
  26. pyagrum/causal/causalEffectEstimation/_utils.py +466 -0
  27. pyagrum/causal/notebook.py +171 -0
  28. pyagrum/clg/CLG.py +658 -0
  29. pyagrum/clg/GaussianVariable.py +111 -0
  30. pyagrum/clg/SEM.py +312 -0
  31. pyagrum/clg/__init__.py +63 -0
  32. pyagrum/clg/canonicalForm.py +408 -0
  33. pyagrum/clg/constants.py +54 -0
  34. pyagrum/clg/forwardSampling.py +202 -0
  35. pyagrum/clg/learning.py +776 -0
  36. pyagrum/clg/notebook.py +480 -0
  37. pyagrum/clg/variableElimination.py +271 -0
  38. pyagrum/common.py +60 -0
  39. pyagrum/config.py +319 -0
  40. pyagrum/ctbn/CIM.py +513 -0
  41. pyagrum/ctbn/CTBN.py +573 -0
  42. pyagrum/ctbn/CTBNGenerator.py +216 -0
  43. pyagrum/ctbn/CTBNInference.py +459 -0
  44. pyagrum/ctbn/CTBNLearner.py +161 -0
  45. pyagrum/ctbn/SamplesStats.py +671 -0
  46. pyagrum/ctbn/StatsIndepTest.py +355 -0
  47. pyagrum/ctbn/__init__.py +79 -0
  48. pyagrum/ctbn/constants.py +54 -0
  49. pyagrum/ctbn/notebook.py +264 -0
  50. pyagrum/defaults.ini +199 -0
  51. pyagrum/deprecated.py +95 -0
  52. pyagrum/explain/_ComputationCausal.py +75 -0
  53. pyagrum/explain/_ComputationConditional.py +48 -0
  54. pyagrum/explain/_ComputationMarginal.py +48 -0
  55. pyagrum/explain/_CustomShapleyCache.py +110 -0
  56. pyagrum/explain/_Explainer.py +176 -0
  57. pyagrum/explain/_Explanation.py +70 -0
  58. pyagrum/explain/_FIFOCache.py +54 -0
  59. pyagrum/explain/_ShallCausalValues.py +204 -0
  60. pyagrum/explain/_ShallConditionalValues.py +155 -0
  61. pyagrum/explain/_ShallMarginalValues.py +155 -0
  62. pyagrum/explain/_ShallValues.py +296 -0
  63. pyagrum/explain/_ShapCausalValues.py +208 -0
  64. pyagrum/explain/_ShapConditionalValues.py +126 -0
  65. pyagrum/explain/_ShapMarginalValues.py +191 -0
  66. pyagrum/explain/_ShapleyValues.py +298 -0
  67. pyagrum/explain/__init__.py +81 -0
  68. pyagrum/explain/_explGeneralizedMarkovBlanket.py +152 -0
  69. pyagrum/explain/_explIndependenceListForPairs.py +146 -0
  70. pyagrum/explain/_explInformationGraph.py +264 -0
  71. pyagrum/explain/notebook/__init__.py +54 -0
  72. pyagrum/explain/notebook/_bar.py +142 -0
  73. pyagrum/explain/notebook/_beeswarm.py +174 -0
  74. pyagrum/explain/notebook/_showShapValues.py +97 -0
  75. pyagrum/explain/notebook/_waterfall.py +220 -0
  76. pyagrum/explain/shapley.py +225 -0
  77. pyagrum/lib/__init__.py +46 -0
  78. pyagrum/lib/_colors.py +390 -0
  79. pyagrum/lib/bn2graph.py +299 -0
  80. pyagrum/lib/bn2roc.py +1026 -0
  81. pyagrum/lib/bn2scores.py +217 -0
  82. pyagrum/lib/bn_vs_bn.py +605 -0
  83. pyagrum/lib/cn2graph.py +305 -0
  84. pyagrum/lib/discreteTypeProcessor.py +1102 -0
  85. pyagrum/lib/discretizer.py +58 -0
  86. pyagrum/lib/dynamicBN.py +390 -0
  87. pyagrum/lib/explain.py +57 -0
  88. pyagrum/lib/export.py +84 -0
  89. pyagrum/lib/id2graph.py +258 -0
  90. pyagrum/lib/image.py +387 -0
  91. pyagrum/lib/ipython.py +307 -0
  92. pyagrum/lib/mrf2graph.py +471 -0
  93. pyagrum/lib/notebook.py +1821 -0
  94. pyagrum/lib/proba_histogram.py +552 -0
  95. pyagrum/lib/utils.py +138 -0
  96. pyagrum/pyagrum.py +31495 -0
  97. pyagrum/skbn/_MBCalcul.py +242 -0
  98. pyagrum/skbn/__init__.py +49 -0
  99. pyagrum/skbn/_learningMethods.py +282 -0
  100. pyagrum/skbn/_utils.py +297 -0
  101. pyagrum/skbn/bnclassifier.py +1014 -0
  102. pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/LICENSE.md +12 -0
  103. pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/LICENSES/LGPL-3.0-or-later.txt +304 -0
  104. pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/LICENSES/MIT.txt +18 -0
  105. pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/METADATA +145 -0
  106. pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/RECORD +107 -0
  107. pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/WHEEL +4 -0
@@ -0,0 +1,761 @@
1
+ ############################################################################
2
+ # This file is part of the aGrUM/pyAgrum library. #
3
+ # #
4
+ # Copyright (c) 2005-2025 by #
5
+ # - Pierre-Henri WUILLEMIN(_at_LIP6) #
6
+ # - Christophe GONZALES(_at_AMU) #
7
+ # #
8
+ # The aGrUM/pyAgrum library is free software; you can redistribute it #
9
+ # and/or modify it under the terms of either : #
10
+ # #
11
+ # - the GNU Lesser General Public License as published by #
12
+ # the Free Software Foundation, either version 3 of the License, #
13
+ # or (at your option) any later version, #
14
+ # - the MIT license (MIT), #
15
+ # - or both in dual license, as here. #
16
+ # #
17
+ # (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) #
18
+ # #
19
+ # This aGrUM/pyAgrum library is distributed in the hope that it will be #
20
+ # useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, #
21
+ # INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS #
22
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE #
23
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER #
24
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, #
25
+ # ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR #
26
+ # OTHER DEALINGS IN THE SOFTWARE. #
27
+ # #
28
+ # See LICENCES for more details. #
29
+ # #
30
+ # SPDX-FileCopyrightText: Copyright 2005-2025 #
31
+ # - Pierre-Henri WUILLEMIN(_at_LIP6) #
32
+ # - Christophe GONZALES(_at_AMU) #
33
+ # SPDX-License-Identifier: LGPL-3.0-or-later OR MIT #
34
+ # #
35
+ # Contact : info_at_agrum_dot_org #
36
+ # homepage : http://agrum.gitlab.io #
37
+ # gitlab : https://gitlab.com/agrumery/agrum #
38
+ # #
39
+ ############################################################################
40
+
41
+ """
42
+ This file defines the needed class for the representation of an abstract syntax tree for causal formula
43
+ """
44
+
45
+ from collections import defaultdict
46
+ from typing import Union, Dict, Optional, Iterable, List
47
+
48
+ import pyagrum
49
+ from pyagrum.causal._types import NameSet
50
+
51
+ # pylint: disable=unused-import
52
+ import pyagrum.causal # for annotations
53
+
54
+
55
class ASTtree:
    """
    Represents a generic node for the CausalFormula. The type of the node will be registered in a string.

    Parameters
    ----------
    typ: str
      the type of the node (will be specified in concrete children classes).
    verbose: bool
      if True, add some messages
    """

    def __init__(self, typ: str, verbose=False):
        """
        Represents a generic node for the CausalFormula. The type of the node will be registered in a string.

        Parameters
        ----------
        typ: str
          the type of the node (will be specified in concrete children classes).
        verbose: bool
          if True, add some messages
        """
        self._type = typ
        self.__continueNextLine = "| "
        self._verbose = verbose

    @property
    def _continueNextLine(self):
        """Indentation marker appended to the prefix when stringifying child nodes."""
        return self.__continueNextLine

    @property
    def type(self) -> str:
        """
        Returns
        -------
        str
          the type of the node
        """
        return self._type

    def __str__(self, prefix: str = "") -> str:
        """
        stringify a CausalFormula tree

        Parameters
        ----------
        prefix: str
          a prefix for each line of the string representation

        Returns
        -------
        str
          the string version of the tree
        """
        raise NotImplementedError

    def protectToLatex(self, nameOccur: Dict[str, int]) -> str:
        """
        Create a protected LaTeX representation of a ASTtree

        Parameters
        ----------
        nameOccur: Dict[str,int]
          the number of occurrence for each variable

        Returns
        -------
        str
          a protected version of LaTeX representation of the tree
        """
        raise NotImplementedError

    def fastToLatex(self, nameOccur: Dict[str, int]) -> str:
        """
        Internal virtual function to create a LaTeX representation of the ASTtree

        Parameters
        ----------
        nameOccur: Dict[str,int]
          the number of occurrence for each variable

        Returns
        -------
        str
          LaTeX representation of the tree
        """
        raise NotImplementedError

    def toLatex(self, nameOccur: Optional[Dict[str, int]] = None) -> str:
        """
        Create a LaTeX representation of a ASTtree

        Parameters
        ----------
        nameOccur: Dict[str,int] default=None
          the number of occurrence for each variable (missing variables count as 0).
          The given dict is not modified.

        Returns
        -------
        str
          LaTeX representation of the tree
        """
        # The renderers index and increment ``nameOccur`` for names that may be
        # absent (e.g. ASTsum does ``nameOccur[v] += 1``). A plain dict from the
        # caller would therefore raise KeyError, so always work on a defaultdict
        # copy. This also guarantees the caller's dict is left untouched.
        occurrences = defaultdict(int) if nameOccur is None else defaultdict(int, nameOccur)
        return self.fastToLatex(occurrences)

    @staticmethod
    def _latexCorrect(srcName: Union[str, Iterable[str]], nameOccur: Dict[str, int]) -> Union[str, Iterable[str]]:
        """
        Change the latex presentation of variable w.r.t the number of occurrence of this variable : for instance,
        add primes when necessary

        Parameters
        ----------
        srcName: str
          the name or an iterable containing a collection of names
        nameOccur: Dict[str,int]
          the dict that gives the number of occurrence for each variable (default value 0 if the variable
          is not a key in this dict)

        Returns
        -------
        str | Iterable[str]
          the corrected name or the sorted list of corrected names
        """

        def _transform(v: str) -> str:
            # one prime per occurrence beyond the first; ``.get`` honours the
            # documented default of 0 even when ``nameOccur`` is a plain dict
            nbr = max(0, nameOccur.get(v, 0) - 1)
            return v + ("'" * nbr)

        if isinstance(srcName, str):
            return _transform(srcName)

        return sorted(_transform(v) for v in srcName)

    def copy(self) -> "ASTtree":
        """
        Copy an CausalFormula tree

        Returns
        -------
        ASTtree
          the new causal tree
        """
        raise NotImplementedError

    def eval(self, contextual_bn: "pyagrum.BayesNet") -> "pyagrum.Tensor":
        """
        Evaluation of a AST tree from inside a BN

        Parameters
        ----------
        contextual_bn: pyagrum.BayesNet
          the observational Bayesian network in which will be done the computations

        Returns
        -------
        pyagrum.Tensor
          the resulting Tensor
        """
        raise NotImplementedError
217
+
218
+
219
class ASTBinaryOp(ASTtree):
    """
    Generic binary node of a CausalFormula, holding a left operand ``op1`` and a right operand ``op2``.

    Parameters
    ----------
    typ: str
      the type of the node (specified by the concrete children classes)
    op1: ASTtree
      left operand
    op2: ASTtree
      right operand
    """

    def __init__(self, typ: str, op1: ASTtree, op2: ASTtree):
        """
        Generic binary node of a CausalFormula, holding a left operand ``op1`` and a right operand ``op2``.

        Parameters
        ----------
        typ: str
          the type of the node (specified by the concrete children classes)
        op1: ASTtree
          left operand
        op2: ASTtree
          right operand
        """
        super().__init__(typ)
        self._op1: ASTtree = op1
        self._op2: ASTtree = op2

    @property
    def op1(self) -> ASTtree:
        """
        Returns
        -------
        ASTtree
          the left operand
        """
        return self._op1

    @property
    def op2(self) -> ASTtree:
        """
        Returns
        -------
        ASTtree
          the right operand
        """
        return self._op2

    # The four methods below are left to the concrete operators (+, -, *, /).
    def protectToLatex(self, nameOccur: Dict[str, int]) -> str:
        raise NotImplementedError

    def fastToLatex(self, nameOccur: Dict[str, int]) -> str:
        raise NotImplementedError

    def copy(self) -> "ASTtree":
        raise NotImplementedError

    def eval(self, contextual_bn: "pyagrum.BayesNet") -> "pyagrum.Tensor":
        raise NotImplementedError

    def __str__(self, prefix: str = "") -> str:
        # operator on the first line, then each operand indented one level deeper
        childPrefix = prefix + self._continueNextLine
        leftText = self.op1.__str__(childPrefix)
        rightText = self.op2.__str__(childPrefix)
        return f"{prefix}{self.type}\n{leftText}\n{rightText}"
285
+
286
+
287
class ASTplus(ASTBinaryOp):
    """
    Sum node: represents ``op1 + op2`` for two :class:`causal.ASTtree`

    Parameters
    ----------
    op1: ASTtree
      left operand
    op2: ASTtree
      right operand
    """

    def __init__(self, op1: ASTtree, op2: ASTtree):
        """
        Sum node: represents ``op1 + op2`` for two :class:`causal.ASTtree`

        Parameters
        ----------
        op1: ASTtree
          left operand
        op2: ASTtree
          right operand
        """
        super().__init__("+", op1, op2)

    def copy(self) -> "ASTtree":
        leftCopy = self.op1.copy()
        rightCopy = self.op2.copy()
        return ASTplus(leftCopy, rightCopy)

    def protectToLatex(self, nameOccur: Dict[str, int]) -> str:
        # a sum must be parenthesized when embedded in a product
        return "\\left(" + self.fastToLatex(nameOccur) + "\\right)"

    def fastToLatex(self, nameOccur: Dict[str, int]) -> str:
        return f"{self.op1.fastToLatex(nameOccur)}+{self.op2.fastToLatex(nameOccur)}"

    def eval(self, contextual_bn: "pyagrum.BayesNet") -> "pyagrum.Tensor":
        if self._verbose:
            print("EVAL operation + ", flush=True)
        left = self.op1.eval(contextual_bn)
        right = self.op2.eval(contextual_bn)
        res = left + right

        if self._verbose:
            print(f"END OF EVAL operation : {res}", flush=True)

        return res
330
+
331
+
332
class ASTminus(ASTBinaryOp):
    """
    Difference node: represents ``op1 - op2`` for two :class:`causal.ASTtree`

    Parameters
    ----------
    op1: ASTtree
      left operand
    op2: ASTtree
      right operand
    """

    def __init__(self, op1: ASTtree, op2: ASTtree):
        """
        Difference node: represents ``op1 - op2`` for two :class:`causal.ASTtree`

        Parameters
        ----------
        op1: ASTtree
          left operand
        op2: ASTtree
          right operand
        """
        super().__init__("-", op1, op2)

    def copy(self) -> "ASTtree":
        leftCopy = self.op1.copy()
        rightCopy = self.op2.copy()
        return ASTminus(leftCopy, rightCopy)

    def protectToLatex(self, nameOccur: Dict[str, int]) -> str:
        # a difference must be parenthesized when embedded in a product
        return f"\\left({self.fastToLatex(nameOccur)}\\right)"

    def fastToLatex(self, nameOccur: Dict[str, int]) -> str:
        return f"{self.op1.fastToLatex(nameOccur)}-{self.op2.fastToLatex(nameOccur)}"

    def eval(self, contextual_bn: "pyagrum.BayesNet") -> "pyagrum.Tensor":
        if self._verbose:
            print("EVAL operation", flush=True)
        left = self.op1.eval(contextual_bn)
        right = self.op2.eval(contextual_bn)
        res = left - right

        if self._verbose:
            print(f"END OF EVAL operation : {res}", flush=True)

        return res
375
+
376
+
377
class ASTmult(ASTBinaryOp):
    """
    Product node: represents ``op1 * op2`` for two :class:`causal.ASTtree`

    Parameters
    ----------
    op1: ASTtree
      left operand
    op2: ASTtree
      right operand
    """

    def __init__(self, op1: ASTtree, op2: ASTtree):
        """
        Product node: represents ``op1 * op2`` for two :class:`causal.ASTtree`

        Parameters
        ----------
        op1: ASTtree
          left operand
        op2: ASTtree
          right operand
        """
        super().__init__("*", op1, op2)

    def copy(self) -> "ASTtree":
        leftCopy = self.op1.copy()
        rightCopy = self.op2.copy()
        return ASTmult(leftCopy, rightCopy)

    def protectToLatex(self, nameOccur: Dict[str, int]) -> str:
        # a product never needs extra parentheses
        return self.fastToLatex(nameOccur)

    def fastToLatex(self, nameOccur: Dict[str, int]) -> str:
        # operands are rendered in their protected form so that sums/differences get parenthesized
        leftTex = self.op1.protectToLatex(nameOccur)
        rightTex = self.op2.protectToLatex(nameOccur)
        return f"{leftTex} \\cdot {rightTex}"

    def eval(self, contextual_bn: "pyagrum.BayesNet") -> "pyagrum.Tensor":
        if self._verbose:
            print("EVAL operation * in context", flush=True)
        left = self.op1.eval(contextual_bn)
        right = self.op2.eval(contextual_bn)
        res = left * right

        if self._verbose:
            print(f"END OF EVAL operation * : {res}", flush=True)

        return res
420
+
421
+
422
class ASTdiv(ASTBinaryOp):
    """
    Represents the division of 2 :class:`causal.ASTtree`

    Parameters
    ----------
    op1: ASTtree
      left operand (numerator)
    op2: ASTtree
      right operand (denominator)
    """

    def __init__(self, op1: ASTtree, op2: ASTtree):
        """
        Represents the division of 2 :class:`causal.ASTtree`

        Parameters
        ----------
        op1: ASTtree
          left operand (numerator)
        op2: ASTtree
          right operand (denominator)
        """
        super().__init__("/", op1, op2)

    def copy(self) -> "ASTtree":
        # deep copy of both operands (the previous version called ``self.copy(...)``
        # with an argument, which always raised TypeError)
        return ASTdiv(self.op1.copy(), self.op2.copy())

    def protectToLatex(self, nameOccur: Dict[str, int]) -> str:
        # \frac already delimits its operands: no parentheses needed
        return self.fastToLatex(nameOccur)

    def fastToLatex(self, nameOccur: Dict[str, int]) -> str:
        return " \\frac {" + self.op1.fastToLatex(nameOccur) + "}{" + self.op2.fastToLatex(nameOccur) + "}"

    def eval(self, contextual_bn: "pyagrum.BayesNet") -> "pyagrum.Tensor":
        if self._verbose:
            print("EVAL operation / in context", flush=True)
        res = self.op1.eval(contextual_bn) / self.op2.eval(contextual_bn)

        if self._verbose:
            print(f"END OF EVAL operation / : {res}", flush=True)

        return res
465
+
466
+
467
class ASTposteriorProba(ASTtree):
    """
    Represent a conditional probability :math:`P_{bn}(vars|knw)` that can be computed by an inference in a BN.

    Parameters
    ----------
    bn: pyagrum.BayesNet
      the :class:`pyAgrum:pyagrum.BayesNet`
    varset: Set[str]
      a set of variable names (in the BN) conditioned in the posterior
    knw: Set[str]
      a set of variable names (in the BN) conditioning in the posterior
    """

    def __init__(self, bn: "pyagrum.BayesNet", varset: NameSet, knw: NameSet):
        """
        Represent a conditional probability :math:`P_{bn}(vars|knw)` that can be computed by an inference in a BN.

        Parameters
        ----------
        bn: pyagrum.BayesNet
          the :class:`pyAgrum:pyagrum.BayesNet`
        varset: Set[str]
          a set of variable names (in the BN) conditioned in the posterior
        knw: Set[str]
          a set of variable names (in the BN) conditioning in the posterior

        Raises
        ------
        ValueError
          if ``varset`` or ``knw`` is not a ``set``
        """
        super().__init__("_posterior_")
        if not isinstance(varset, set):
            raise ValueError("'varset' must be a set")
        if not isinstance(knw, set):
            raise ValueError("'knw' must be a set")

        self._vars = varset
        self._bn = bn
        # ``knw`` is replaced by the names of the nodes returned by
        # ``bn.minimalCondSet(varset, knw)`` (presumably the minimal subset of
        # ``knw`` that is sufficient to condition ``varset``)
        self._knw = {bn.variable(i).name() for i in bn.minimalCondSet(varset, knw)}

    @property
    def vars(self) -> NameSet:
        """
        Returns
        -------
        Set[str]
          (Conditioned) vars in :math:`P_{bn}(vars|knw)`
        """
        return self._vars

    @property
    def knw(self) -> NameSet:
        """
        Returns
        -------
        Set[str]
          (Conditioning) knw in :math:`P_{bn}(vars|knw)`
        """
        return self._knw

    @property
    def bn(self) -> "pyagrum.BayesNet":
        """
        Returns
        -------
        pyagrum.BayesNet
          the observational BayesNet in :math:`P_{bn}(vars|knw)`
        """
        return self._bn

    def __str__(self, prefix: str = "") -> str:
        s = "P(" + ",".join(sorted(self.vars))
        # ``knw`` is always a set: only print the bar when there is something to condition on
        # (consistent with fastToLatex)
        if self.knw:
            s += "|" + ",".join(sorted(self.knw))
        s += ")"
        return f"{prefix}{s}"

    def protectToLatex(self, nameOccur: Dict[str, int]) -> str:
        return self.fastToLatex(nameOccur)

    def fastToLatex(self, nameOccur: Dict[str, int]) -> str:
        s = "P\\left(" + ",".join(self._latexCorrect(self.vars, nameOccur))
        if self.knw:
            s += "\\mid "
            s += ",".join(self._latexCorrect(self.knw, nameOccur))

        s += "\\right)"

        return s

    def copy(self) -> "ASTtree":
        # fresh sets so that the copy does not share mutable state with the original
        return ASTposteriorProba(self.bn, set(self.vars), set(self.knw))

    def eval(self, contextual_bn: "pyagrum.BayesNet") -> "pyagrum.Tensor":
        if self._verbose:
            print(f"EVAL ${self.fastToLatex(defaultdict(int))}$ in context", flush=True)
        p = None

        # simple case : a single target whose parents in contextual_bn are exactly knw,
        # so its CPT is the requested conditional probability
        if len(self.vars) == 1:
            x = next(iter(self.vars))
            ix = contextual_bn.idFromName(x)
            if {contextual_bn.variable(i).name() for i in contextual_bn.parents(ix)} == self.knw:
                p = contextual_bn.cpt(ix)

        if p is None:
            # inference is only needed when the CPT shortcut does not apply
            ie = pyagrum.LazyPropagation(contextual_bn)
            if len(self.knw) == 0:
                ie.addJointTarget(self.vars)
                ie.makeInference()
                p = ie.jointPosterior(self.vars)
            else:
                joint = self.vars | self.knw
                ie.addJointTarget(joint)
                ie.makeInference()
                # P(vars | knw) = P(vars, knw) / P(knw)
                p = ie.jointPosterior(joint) / ie.jointPosterior(self.knw)

        if self._verbose:
            print(f"END OF EVAL ${self.fastToLatex(defaultdict(int))}$ : {p}", flush=True)

        return p
591
+
592
+
593
class ASTjointProba(ASTtree):
    """
    Represent a joint probability in the base observational part of the :class:`causal.CausalModel`

    Parameters
    ----------
    varNames: Set[str]
      a set of variable names
    """

    def __init__(self, varNames: NameSet):
        """
        Represent a joint probability in the base observational part of the :class:`causal.CausalModel`

        Parameters
        ----------
        varNames: Set[str]
          a set of variable names
        """
        super().__init__("_joint_")
        self._varNames = varNames

    @property
    def varNames(self) -> NameSet:
        """
        Returns
        -------
        Set[str]
          the set of names of var
        """
        return self._varNames

    def __str__(self, prefix: str = "") -> str:
        s = "P("
        s += ",".join(sorted(self._varNames))
        s += ")"
        return f"{prefix}joint {s}"

    def copy(self) -> "ASTtree":
        return ASTjointProba(self.varNames)

    def protectToLatex(self, nameOccur: Dict[str, int]) -> str:
        return self.fastToLatex(nameOccur)

    def fastToLatex(self, nameOccur: Dict[str, int]) -> str:
        return "P\\left(" + ",".join(self._latexCorrect(self.varNames, nameOccur)) + "\\right)"

    def eval(self, contextual_bn: "pyagrum.BayesNet") -> "pyagrum.Tensor":
        """
        Compute the joint posterior of ``varNames`` in ``contextual_bn`` (no evidence).

        Raises
        ------
        ValueError
          if ``varNames`` is empty
        """
        if self._verbose:
            print(f"EVAL ${self.fastToLatex(defaultdict(int))}$ in context", flush=True)
        if len(self.varNames) == 0:
            # previously surfaced as an opaque UnboundLocalError
            raise ValueError("ASTjointProba: cannot evaluate a joint probability over an empty set of variables")

        ie = pyagrum.LazyPropagation(contextual_bn)
        if len(self.varNames) > 1:
            svars = set(self.varNames)
            ie.addJointTarget(svars)
            ie.makeInference()
            res = ie.jointPosterior(svars)
        else:
            # single variable: its (marginal) posterior is the joint
            (name,) = self.varNames
            ie.makeInference()
            res = ie.posterior(name)

        if self._verbose:
            print(f"END OF EVAL ${self.fastToLatex(defaultdict(int))}$ : {res}", flush=True)

        return res
659
+
660
+
661
class ASTsum(ASTtree):
    """
    Represents a sum over a variable of a :class:`causal.ASTtree`.

    Parameters
    ----------
    var: str | List[str]
      name of the variable on which to sum. A list (or tuple) of names is encoded as a chain of
      nested sums, one variable per node.
    term: ASTtree
      the tree to be evaluated
    """

    def __init__(self, var: str, term: ASTtree):
        """
        Represents a sum over a variable of a :class:`causal.ASTtree`.

        Parameters
        ----------
        var: str | List[str]
          name of the variable on which to sum. A list (or tuple) of names is encoded as a chain of
          nested sums, one variable per node.
        term: ASTtree
          the tree to be evaluated
        """
        super().__init__("_sum_")

        names = list(var) if isinstance(var, (list, tuple)) else [var]
        self.var = names[0]

        # remaining variables become nested ASTsum nodes around the term
        if len(names) > 1:
            self._term = ASTsum(names[1:], term)
        else:
            self._term = term

    @property
    def term(self) -> ASTtree:
        """
        Returns
        -------
        ASTtree
          the term to sum
        """
        return self._term

    def _collectSums(self):
        """Walk the chain of nested sum nodes; return (summed variable names, innermost non-sum term)."""
        names = []
        node = self
        while node.type == "_sum_":
            names.append(node.var)
            node = node.term
        return names, node

    def __str__(self, prefix: str = "") -> str:
        names, inner = self._collectSums()
        return f"{prefix}sum on {','.join(sorted(names))} for\n{inner.__str__(prefix + self._continueNextLine)}"

    def copy(self) -> "ASTtree":
        return ASTsum(self.var, self.term.copy())

    def protectToLatex(self, nameOccur: Dict[str, int]) -> str:
        return "\\left(" + self.fastToLatex(nameOccur) + "\\right)"

    def fastToLatex(self, nameOccur: Dict[str, int]) -> str:
        names, inner = self._collectSums()

        # every summed variable is a new bound occurrence: bump its counter so that
        # _latexCorrect adds primes where needed, then restore the counters even if
        # rendering the inner term fails (nameOccur is shared across the traversal)
        for v in names:
            nameOccur[v] = nameOccur.get(v, 0) + 1
        try:
            return "\\sum_{" + (",".join(self._latexCorrect(names, nameOccur))) + "}{" + inner.fastToLatex(nameOccur) + "}"
        finally:
            for v in names:
                nameOccur[v] -= 1

    def eval(self, contextual_bn: "pyagrum.BayesNet") -> "pyagrum.Tensor":
        if self._verbose:
            print(f"EVAL ${self.fastToLatex(defaultdict(int))}$", flush=True)

        res = self.term.eval(contextual_bn).sumOut([self.var])

        if self._verbose:
            print(f"END OF EVAL ${self.fastToLatex(defaultdict(int))}$ : {res}", flush=True)

        return res
742
+
743
+
744
def productOfTrees(lterms: List[ASTtree]) -> ASTtree:
    """
    Build the ASTtree of the product of a sequence of ASTtree.

    The product is right-nested: ``[t0, t1, ..., tn]`` gives ``t0 * (t1 * (... * tn))``.
    A single-element sequence returns its element unchanged.

    Parameters
    ----------
    lterms: List[ASTtree]
      the trees (as ASTtree) to multiply

    Returns
    -------
    ASTtree
      the ASTtree representing the tree of multiplications
    """
    # fold from the right so that the tree shape matches the recursive definition
    product = lterms[-1]
    for factor in reversed(lterms[:-1]):
        product = ASTmult(factor, product)
    return product