pyAgrum-nightly 2.3.1.9.dev202512261765915415__cp310-abi3-macosx_10_15_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107) hide show
  1. pyagrum/__init__.py +165 -0
  2. pyagrum/_pyagrum.so +0 -0
  3. pyagrum/bnmixture/BNMInference.py +268 -0
  4. pyagrum/bnmixture/BNMLearning.py +376 -0
  5. pyagrum/bnmixture/BNMixture.py +464 -0
  6. pyagrum/bnmixture/__init__.py +60 -0
  7. pyagrum/bnmixture/notebook.py +1058 -0
  8. pyagrum/causal/_CausalFormula.py +280 -0
  9. pyagrum/causal/_CausalModel.py +436 -0
  10. pyagrum/causal/__init__.py +81 -0
  11. pyagrum/causal/_causalImpact.py +356 -0
  12. pyagrum/causal/_dSeparation.py +598 -0
  13. pyagrum/causal/_doAST.py +761 -0
  14. pyagrum/causal/_doCalculus.py +361 -0
  15. pyagrum/causal/_doorCriteria.py +374 -0
  16. pyagrum/causal/_exceptions.py +95 -0
  17. pyagrum/causal/_types.py +61 -0
  18. pyagrum/causal/causalEffectEstimation/_CausalEffectEstimation.py +1175 -0
  19. pyagrum/causal/causalEffectEstimation/_IVEstimators.py +718 -0
  20. pyagrum/causal/causalEffectEstimation/_RCTEstimators.py +132 -0
  21. pyagrum/causal/causalEffectEstimation/__init__.py +46 -0
  22. pyagrum/causal/causalEffectEstimation/_backdoorEstimators.py +774 -0
  23. pyagrum/causal/causalEffectEstimation/_causalBNEstimator.py +324 -0
  24. pyagrum/causal/causalEffectEstimation/_frontdoorEstimators.py +396 -0
  25. pyagrum/causal/causalEffectEstimation/_learners.py +118 -0
  26. pyagrum/causal/causalEffectEstimation/_utils.py +466 -0
  27. pyagrum/causal/notebook.py +172 -0
  28. pyagrum/clg/CLG.py +658 -0
  29. pyagrum/clg/GaussianVariable.py +111 -0
  30. pyagrum/clg/SEM.py +312 -0
  31. pyagrum/clg/__init__.py +63 -0
  32. pyagrum/clg/canonicalForm.py +408 -0
  33. pyagrum/clg/constants.py +54 -0
  34. pyagrum/clg/forwardSampling.py +202 -0
  35. pyagrum/clg/learning.py +776 -0
  36. pyagrum/clg/notebook.py +480 -0
  37. pyagrum/clg/variableElimination.py +271 -0
  38. pyagrum/common.py +60 -0
  39. pyagrum/config.py +319 -0
  40. pyagrum/ctbn/CIM.py +513 -0
  41. pyagrum/ctbn/CTBN.py +573 -0
  42. pyagrum/ctbn/CTBNGenerator.py +216 -0
  43. pyagrum/ctbn/CTBNInference.py +459 -0
  44. pyagrum/ctbn/CTBNLearner.py +161 -0
  45. pyagrum/ctbn/SamplesStats.py +671 -0
  46. pyagrum/ctbn/StatsIndepTest.py +355 -0
  47. pyagrum/ctbn/__init__.py +79 -0
  48. pyagrum/ctbn/constants.py +54 -0
  49. pyagrum/ctbn/notebook.py +264 -0
  50. pyagrum/defaults.ini +199 -0
  51. pyagrum/deprecated.py +95 -0
  52. pyagrum/explain/_ComputationCausal.py +75 -0
  53. pyagrum/explain/_ComputationConditional.py +48 -0
  54. pyagrum/explain/_ComputationMarginal.py +48 -0
  55. pyagrum/explain/_CustomShapleyCache.py +110 -0
  56. pyagrum/explain/_Explainer.py +176 -0
  57. pyagrum/explain/_Explanation.py +70 -0
  58. pyagrum/explain/_FIFOCache.py +54 -0
  59. pyagrum/explain/_ShallCausalValues.py +204 -0
  60. pyagrum/explain/_ShallConditionalValues.py +155 -0
  61. pyagrum/explain/_ShallMarginalValues.py +155 -0
  62. pyagrum/explain/_ShallValues.py +296 -0
  63. pyagrum/explain/_ShapCausalValues.py +208 -0
  64. pyagrum/explain/_ShapConditionalValues.py +126 -0
  65. pyagrum/explain/_ShapMarginalValues.py +191 -0
  66. pyagrum/explain/_ShapleyValues.py +298 -0
  67. pyagrum/explain/__init__.py +81 -0
  68. pyagrum/explain/_explGeneralizedMarkovBlanket.py +152 -0
  69. pyagrum/explain/_explIndependenceListForPairs.py +146 -0
  70. pyagrum/explain/_explInformationGraph.py +264 -0
  71. pyagrum/explain/notebook/__init__.py +54 -0
  72. pyagrum/explain/notebook/_bar.py +142 -0
  73. pyagrum/explain/notebook/_beeswarm.py +174 -0
  74. pyagrum/explain/notebook/_showShapValues.py +97 -0
  75. pyagrum/explain/notebook/_waterfall.py +220 -0
  76. pyagrum/explain/shapley.py +225 -0
  77. pyagrum/lib/__init__.py +46 -0
  78. pyagrum/lib/_colors.py +390 -0
  79. pyagrum/lib/bn2graph.py +299 -0
  80. pyagrum/lib/bn2roc.py +1026 -0
  81. pyagrum/lib/bn2scores.py +217 -0
  82. pyagrum/lib/bn_vs_bn.py +605 -0
  83. pyagrum/lib/cn2graph.py +305 -0
  84. pyagrum/lib/discreteTypeProcessor.py +1102 -0
  85. pyagrum/lib/discretizer.py +58 -0
  86. pyagrum/lib/dynamicBN.py +390 -0
  87. pyagrum/lib/explain.py +57 -0
  88. pyagrum/lib/export.py +84 -0
  89. pyagrum/lib/id2graph.py +258 -0
  90. pyagrum/lib/image.py +387 -0
  91. pyagrum/lib/ipython.py +307 -0
  92. pyagrum/lib/mrf2graph.py +471 -0
  93. pyagrum/lib/notebook.py +1821 -0
  94. pyagrum/lib/proba_histogram.py +552 -0
  95. pyagrum/lib/utils.py +138 -0
  96. pyagrum/pyagrum.py +31495 -0
  97. pyagrum/skbn/_MBCalcul.py +242 -0
  98. pyagrum/skbn/__init__.py +49 -0
  99. pyagrum/skbn/_learningMethods.py +282 -0
  100. pyagrum/skbn/_utils.py +297 -0
  101. pyagrum/skbn/bnclassifier.py +1014 -0
  102. pyagrum_nightly-2.3.1.9.dev202512261765915415.dist-info/LICENSE.md +12 -0
  103. pyagrum_nightly-2.3.1.9.dev202512261765915415.dist-info/LICENSES/LGPL-3.0-or-later.txt +304 -0
  104. pyagrum_nightly-2.3.1.9.dev202512261765915415.dist-info/LICENSES/MIT.txt +18 -0
  105. pyagrum_nightly-2.3.1.9.dev202512261765915415.dist-info/METADATA +145 -0
  106. pyagrum_nightly-2.3.1.9.dev202512261765915415.dist-info/RECORD +107 -0
  107. pyagrum_nightly-2.3.1.9.dev202512261765915415.dist-info/WHEEL +4 -0
@@ -0,0 +1,1821 @@
1
+ ############################################################################
2
+ # This file is part of the aGrUM/pyAgrum library. #
3
+ # #
4
+ # Copyright (c) 2005-2025 by #
5
+ # - Pierre-Henri WUILLEMIN(_at_LIP6) #
6
+ # - Christophe GONZALES(_at_AMU) #
7
+ # #
8
+ # The aGrUM/pyAgrum library is free software; you can redistribute it #
9
+ # and/or modify it under the terms of either : #
10
+ # #
11
+ # - the GNU Lesser General Public License as published by #
12
+ # the Free Software Foundation, either version 3 of the License, #
13
+ # or (at your option) any later version, #
14
+ # - the MIT license (MIT), #
15
+ # - or both in dual license, as here. #
16
+ # #
17
+ # (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) #
18
+ # #
19
+ # This aGrUM/pyAgrum library is distributed in the hope that it will be #
20
+ # useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, #
21
+ # INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS #
22
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE #
23
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER #
24
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, #
25
+ # ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR #
26
+ # OTHER DEALINGS IN THE SOFTWARE. #
27
+ # #
28
+ # See LICENCES for more details. #
29
+ # #
30
+ # SPDX-FileCopyrightText: Copyright 2005-2025 #
31
+ # - Pierre-Henri WUILLEMIN(_at_LIP6) #
32
+ # - Christophe GONZALES(_at_AMU) #
33
+ # SPDX-License-Identifier: LGPL-3.0-or-later OR MIT #
34
+ # #
35
+ # Contact : info_at_agrum_dot_org #
36
+ # homepage : http://agrum.gitlab.io #
37
+ # gitlab : https://gitlab.com/agrumery/agrum #
38
+ # #
39
+ ############################################################################
40
+
41
+ """
42
+ tools for BN in jupyter notebook
43
+ """
44
+
45
+ import time
46
+ import warnings
47
+
48
+ # fix DeprecationWarning of base64.encodestring()
49
+ try:
50
+ from base64 import encodebytes
51
+ except ImportError: # 3+
52
+ from base64 import encodestring as encodebytes
53
+
54
+ import io
55
+ import base64
56
+
57
+ import matplotlib as mpl
58
+ import matplotlib.pyplot as plt
59
+ from matplotlib_inline.backend_inline import set_matplotlib_formats
60
+
61
+ import numpy as np
62
+ import pydot as dot
63
+
64
+ try:
65
+ dot.call_graphviz("dot", ["--help"], ".")
66
+ except FileNotFoundError:
67
+ print("""Graphviz is not installed.
68
+ Please install this program in order to visualize graphical models in pyagrum.
69
+ See https://graphviz.org/download/""")
70
+
71
+ import IPython.core.display
72
+ import IPython.core.pylabtools
73
+ import IPython.display
74
+
75
+ import pyagrum as gum
76
+ from pyagrum.lib.bn2graph import BN2dot
77
+ from pyagrum.lib.cn2graph import CN2dot
78
+ from pyagrum.lib.id2graph import ID2dot
79
+ from pyagrum.lib.mrf2graph import MRF2UGdot
80
+ from pyagrum.lib.mrf2graph import MRF2FactorGraphdot
81
+ from pyagrum.lib.bn_vs_bn import graphDiff
82
+ from pyagrum.lib.proba_histogram import proba2histo, probaMinMaxH
83
+ from pyagrum.lib.image import prepareShowInference, prepareLinksForSVG
84
+
85
+ import pyagrum.lib._colors as gumcols
86
+
87
+
88
class FlowLayout(object):
    """
    A class / object to display plots in a horizontal / flow layout below a cell

    based on : https://stackoverflow.com/questions/21754976/ipython-notebook-arrange-plots-horizontally

    The layout accumulates HTML fragments in ``self.sHtml``: a CSS preamble followed by
    one ``floating-box`` div per element. Most methods return ``self`` so that calls can
    be chained, e.g. ``flow.add(a).add(b).display()``.
    """

    def __init__(self):
        self.clear()

    def clear(self):
        """
        clear the flow: reset the HTML buffer to the CSS preamble (border width/color and
        background color are read from ``pyagrum.config``, section ``notebook``).

        Returns
        -------
        FlowLayout
          self, for chaining
        """
        # string buffer for the HTML: initially some CSS; images to be appended
        self.sHtml = f"""
<style>
.floating-box {{
display: inline-block;
margin: 7px;
padding : 3px;
border: {gum.config.asInt["notebook", "flow_border_width"]}px solid {gum.config["notebook", "flow_border_color"]};
valign:middle;
background-color: {gum.config["notebook", "flow_background_color"]};
}}
</style>
"""
        return self

    @staticmethod
    def _resolveCaption(caption, title):
        """
        Return the caption to use, honouring the obsolete `title` alias.

        `caption` wins when given; otherwise `title` is used (after printing an
        obsolescence message); otherwise the caption is the empty string.
        """
        if caption is not None:
            return caption
        if title is None:
            return ""
        print("`title` is obsolete since `0.22.8`. Please use `caption`.")
        return title

    def _getCaption(self, caption):
        """
        Return the HTML fragment rendering `caption` below an element ("" if there is no caption).
        """
        if caption is None or caption == "":
            return ""
        return f"<br><center><small><em>{caption}</em></small></center>"

    def add_html(self, html, caption=None, title=None):
        """
        add an html element in the row (title is an obsolete parameter)

        Parameters
        ----------
        html: str
          the HTML fragment to add (any object is formatted with ``str``)
        caption: str
          optional caption shown below the element
        title: str
          obsolete alias for `caption`

        Returns
        -------
        FlowLayout
          self, for chaining
        """
        cap = self._resolveCaption(caption, title)
        self.sHtml += f'<div class="floating-box">{html}{self._getCaption(cap)}</div>'
        return self

    def add_separator(self, size=3):
        """
        add a (poor) separation between elements in a row: `size` non-breaking spaces

        Returns
        -------
        FlowLayout
          self, for chaining
        """
        self.add_html("&nbsp;" * size)
        return self

    def add_plot(self, oAxes, caption=None, title=None):
        """
        Add a PNG representation of a Matplotlib Axes object
        (title is an obsolete parameter)

        The figure owning `oAxes` is rendered to PNG, inlined as base64, and then closed.

        Returns
        -------
        FlowLayout
          self, for chaining
        """
        cap = self._resolveCaption(caption, title)

        fig = oAxes.get_figure()
        with io.BytesIO() as buffer:  # bytes buffer for the plot
            fig.canvas.print_png(buffer)  # make a png of the plot in the buffer
            # encode the bytes as string using base 64
            sB64Img = base64.b64encode(buffer.getvalue()).decode()

        self.sHtml += (
            f'<div class="floating-box"><img src="data:image/png;base64,{sB64Img}\n">{self._getCaption(cap)}</div>'
        )
        # close the plotted figure itself: a bare plt.close() closes the *current*
        # figure, which need not be the one that owns `oAxes`
        plt.close(fig)
        return self

    def new_line(self):
        """
        add a breakline (a new row)

        Returns
        -------
        FlowLayout
          self, for chaining
        """
        self.sHtml += "<br/>"
        return self

    def html(self):
        """
        Returns its content as HTML object (an ``IPython.display.HTML``)
        """
        return IPython.display.HTML(self.sHtml)

    def display(self):
        """
        Display the accumulated HTML, then clear the flow
        """
        IPython.display.display(self.html())
        self.clear()

    def add(self, obj, caption=None, title=None):
        """
        add an element in the row by trying to treat it as plot or html if possible.
        (title is an obsolete parameter)

        Objects with ``get_figure`` are added as plots, objects with ``_repr_html_`` through
        their HTML representation, anything else as raw HTML.

        Returns
        -------
        FlowLayout
          self, for chaining
        """
        cap = self._resolveCaption(caption, title)

        if hasattr(obj, "get_figure"):
            self.add_plot(obj, cap)
        elif hasattr(obj, "_repr_html_"):
            self.add_html(obj._repr_html_(), cap)
        else:
            self.add_html(obj, cap)

        return self

    def row(self, *args, captions=None):
        """
        Create a row with flow with the same syntax as `pyagrum.lib.notebook.sideBySide`.

        The flow is cleared, every element of `args` is added (with ``captions[i]`` when
        `captions` is given), and the row is displayed.
        """
        self.clear()
        for i, arg in enumerate(args):
            if captions is None:
                self.add(arg)
            else:
                self.add(arg, captions[i])

        self.display()
226
+
227
+
228
# Module-level FlowLayout instance shared by users of this module (e.g. `flow.row(...)`).
# Constructing it calls FlowLayout.clear(), which reads pyagrum.config at import time.
flow = FlowLayout()
229
+
230
+
231
def configuration():
    """
    Display the collection of dependencies and their versions

    Renders, in the notebook, an HTML table with the OS, Python, IPython, Matplotlib,
    Numpy, pyDot and pyAgrum versions, followed by the current timestamp.
    """
    import sys
    import os

    # plain dicts keep insertion order (Python 3.7+), so no OrderedDict is needed
    packages = {
        "OS": f"{os.name} [{sys.platform}]",
        "Python": sys.version,
        "IPython": IPython.__version__,
        "Matplotlib": mpl.__version__,
        "Numpy": np.__version__,
        "pyDot": dot.__version__,
        "pyAgrum": gum.__version__,
    }

    rows = "".join(f"<tr><td>{name}</td><td>{version}</td></tr>" for name, version in packages.items())
    timestamp = time.strftime("%a %b %d %H:%M:%S %Y %Z")
    res = (
        "<table><tr><th>Library</th><th>Version</th></tr>"
        + rows
        + f"</table><div align='right'><small>{timestamp}</small></div>"
    )

    IPython.display.display(IPython.display.HTML(res))
256
+
257
+
258
def _reprGraph(gr, size, asString, graph_format=None):
    """
    Render a pydot graph for a notebook: either display it or return it as an HTML fragment.

    Parameters
    ----------
    gr : dot.Dot
      the dot representation of the graph (styled in place by ``gumcols.prepareDot``)
    size: int | str
      the size argument for the representation
    asString : bool
      True to return an HTML fragment, False to display the graph
    graph_format: str
      "svg" or "png" (anything other than "svg" renders PNG); defaults to
      ``gum.config["notebook", "graph_format"]``

    Returns
    -------
    str | None
      the HTML fragment when `asString` is True, otherwise None
    """
    gumcols.prepareDot(gr, size=size)

    fmt = gum.config["notebook", "graph_format"] if graph_format is None else graph_format

    if fmt != "svg":
        png = IPython.core.display.Image(format="png", data=gr.create_png())
        if not asString:
            IPython.core.display.display_png(png)
            return None
        encoded = encodebytes(png.data).decode()
        return f'<img style="margin:0" src="data:image/png;base64,{encoded}"/>'

    svg = IPython.display.SVG(prepareLinksForSVG(gr.create_svg(encoding="utf-8").decode("utf-8")))
    if asString:
        return svg.data
    IPython.display.display(svg)
    return None
295
+
296
+
297
def showGraph(gr: dot.Dot, size=None):
    """
    Display a pydot graph in the notebook.

    Parameters
    ----------
    gr: pydot.Dot
      the graph
    size: float|str
      the size (for graphviz) of the rendered graph; defaults to
      ``gum.config["notebook", "default_graph_size"]``
    """
    graphSize = gum.config["notebook", "default_graph_size"] if size is None else size
    return _reprGraph(gr, graphSize, asString=False)
312
+
313
+
314
def getGraph(gr: dot.Dot, size=None) -> str:
    """
    Build the HTML representation of a pydot graph.

    Parameters
    ----------
    gr: pydot.Dot
      the graph
    size: float|str
      the size (for graphviz) of the rendered graph; defaults to
      ``gum.config["notebook", "default_graph_size"]``

    Returns
    -------
    str
      the HTML representation of the graph (as a string)
    """
    graphSize = gum.config["notebook", "default_graph_size"] if size is None else size
    return _reprGraph(gr, graphSize, asString=True)
334
+
335
+
336
def _from_dotstring(dotstring):
    """
    Parse a dot string and return the first graph it defines.

    Parameters
    ----------
    dotstring: str
      the dot source

    Returns
    -------
    pydot.Dot
      the first graph parsed from `dotstring`

    Raises
    ------
    ValueError
      if no graph could be parsed from `dotstring`
    """
    graphs = dot.graph_from_dot_data(dotstring)
    # graph_from_dot_data may yield None (or an empty list) on a parse failure;
    # report it clearly instead of failing later on a subscript of None
    if not graphs:
        raise ValueError("unable to parse a graph from the given dot string")
    return graphs[0]
339
+
340
+
341
def showDot(dotstring: str, size=None):
    """
    Parse a dot string and display the resulting graph.

    Parameters
    ----------
    dotstring:str
      the dot string
    size: float|str
      size (for graphviz) of the rendered graph; defaults to
      ``gum.config["notebook", "default_graph_size"]``
    """
    graphSize = gum.config["notebook", "default_graph_size"] if size is None else size
    graph = _from_dotstring(dotstring)
    showGraph(graph, graphSize)
355
+
356
+
357
def getDot(dotstring: str, size=None) -> str:
    """
    Parse a dot string and return the HTML representation of the resulting graph.

    Parameters
    ----------
    dotstring:str
      the dot string
    size: float|str
      size (for graphviz) of the rendered graph; defaults to
      ``gum.config["notebook", "default_graph_size"]``

    Returns
    -------
    str
      the HTML representation of the dot string
    """
    graphSize = gum.config["notebook", "default_graph_size"] if size is None else size
    graph = _from_dotstring(dotstring)
    return getGraph(graph, graphSize)
376
+
377
+
378
def getBNDiff(bn1, bn2, size=None, noStyle=False):
    """
    get an HTML string representation of a graphical diff between the arcs of `bn1` (reference)
    and those of `bn2`.

    if `noStyle` is False use 4 styles (fixed in pyagrum.config) :
      - the arc is common for both
      - the arc is common but inverted in `bn2`
      - the arc is added in `bn2`
      - the arc is removed in `bn2`

    Parameters
    ----------
    bn1: pyagrum.BayesNet
      the reference
    bn2: pyagrum.BayesNet
      the compared one
    size: float|str
      size (for graphviz) of the rendered graph; defaults to
      ``gum.config["notebook", "default_graph_size"]``
    noStyle: bool
      with style or not.

    Returns
    -------
    str
      the HTML representation of the comparison
    """
    graphSize = gum.config["notebook", "default_graph_size"] if size is None else size
    diff = graphDiff(bn1, bn2, noStyle)
    return getGraph(diff, graphSize)
408
+
409
+
410
def showBNDiff(bn1, bn2, size=None, noStyle=False):
    """
    show a graphical diff between the arcs of `bn1` (reference) and those of `bn2`.

    if `noStyle` is False use 4 styles (fixed in pyagrum.config) :
      - the arc is common for both
      - the arc is common but inverted in `bn2`
      - the arc is added in `bn2`
      - the arc is removed in `bn2`

    Parameters
    ----------
    bn1: pyagrum.BayesNet
      the reference
    bn2: pyagrum.BayesNet
      the compared one
    size: float|str
      size (for graphviz) of the rendered graph; defaults to
      ``gum.config["notebook", "default_graph_size"]``
    noStyle: bool
      with style or not.
    """
    graphSize = gum.config["notebook", "default_graph_size"] if size is None else size
    diff = graphDiff(bn1, bn2, noStyle)
    showGraph(diff, graphSize)
435
+
436
+
437
def getJunctionTreeMap(
    bn,
    size: str = None,
    scaleClique: float = None,
    scaleSep: float = None,
    lenEdge: float = None,
    colorClique: str = None,
    colorSep: str = None,
):
    """
    Return a representation of the map of the junction tree of a Bayesian network

    Parameters
    ----------
    bn: pyagrum.BayesNet
      the model
    size: str
      size (for graphviz) of the rendered graph; defaults to
      ``gum.config["notebook", "junctiontree_map_size"]``
    scaleClique: float
      the scale for the size of the clique nodes (depending on the number of nodes in the clique)
    scaleSep: float
      the scale for the size of the separator nodes (depending on the number of nodes in the clique)
    lenEdge: float
      the desired length of edges
    colorClique: str
      color for the clique nodes
    colorSep: str
      color for the separator nodes

    Returns
    -------
    str
      the HTML representation of the junction tree map
    """
    # the generator is kept in a named local so that it stays alive while the tree is used
    generator = gum.JunctionTreeGenerator()
    tree = generator.junctionTree(bn)

    graphSize = gum.config["notebook", "junctiontree_map_size"] if size is None else size
    treeMap = tree.map(scaleClique, scaleSep, lenEdge, colorClique, colorSep)
    return getGraph(treeMap, graphSize)
470
+
471
+
472
def showJunctionTreeMap(
    bn,
    size: str = None,
    scaleClique: float = None,
    scaleSep: float = None,
    lenEdge: float = None,
    colorClique: str = None,
    colorSep: str = None,
):
    """
    Show the map of the junction tree of a Bayesian network

    Parameters
    ----------
    bn: pyagrum.BayesNet
      the model
    size: str
      size (for graphviz) of the rendered graph; defaults to
      ``gum.config["notebook", "junctiontree_map_size"]``
    scaleClique: float
      the scale for the size of the clique nodes (depending on the number of nodes in the clique)
    scaleSep: float
      the scale for the size of the separator nodes (depending on the number of nodes in the clique)
    lenEdge: float
      the desired length of edges
    colorClique: str
      color for the clique nodes
    colorSep: str
      color for the separator nodes
    """
    # the generator is kept in a named local so that it stays alive while the tree is used
    generator = gum.JunctionTreeGenerator()
    tree = generator.junctionTree(bn)

    graphSize = gum.config["notebook", "junctiontree_map_size"] if size is None else size
    treeMap = tree.map(scaleClique, scaleSep, lenEdge, colorClique, colorSep)
    showGraph(treeMap, graphSize)
505
+
506
+
507
def showJunctionTree(bn, withNames=True, size=None):
    """
    Show a junction tree of a Bayesian network

    Parameters
    ----------
    bn: pyagrum.BayesNet
      the model
    withNames: bool
      names or id in the graph (names can create very large nodes)
    size: float|str
      size (for graphviz) of the rendered graph; defaults to
      ``gum.config["notebook", "junctiontree_graph_size"]``, the same default as `getJunctionTree`
    """
    if size is None:
        # use the junction-tree specific default so that showing and getting the same
        # junction tree render it at the same size
        size = gum.config["notebook", "junctiontree_graph_size"]

    jtg = gum.JunctionTreeGenerator()
    jt = jtg.junctionTree(bn)

    # NOTE(review): these cross-references keep the generator (and the BN) reachable from
    # the tree -- presumably so the underlying objects outlive `jt`; confirm.
    jt._engine = jtg
    jtg._model = bn

    if withNames:
        showDot(jt.toDotWithNames(bn), size)
    else:
        showDot(jt.toDot(), size)
533
+
534
+
535
def getJunctionTree(bn, withNames=True, size=None):
    """
    get a HTML string for a junction tree (more specifically a join tree)

    Parameters
    ----------
    bn: "pyagrum.BayesNet"
      the Bayesian network
    withNames: Boolean
      display the variable names (True) or the node ids (False) in the cliques
    size: str
      size (for graphviz) of the rendered graph; defaults to
      ``gum.config["notebook", "junctiontree_graph_size"]``

    Returns
    -------
    str
      the HTML representation of the graph
    """
    graphSize = gum.config["notebook", "junctiontree_graph_size"] if size is None else size

    generator = gum.JunctionTreeGenerator()
    tree = generator.junctionTree(bn)

    # NOTE(review): cross-references keep the generator (and the BN) reachable from the
    # tree -- presumably so the underlying objects outlive `tree`; confirm.
    tree._engine = generator
    generator._model = bn

    dotSource = tree.toDotWithNames(bn) if withNames else tree.toDot()
    return getDot(dotSource, graphSize)
565
+
566
+
567
def showProba(p, scale=None):
    """
    Show a mono-dim Tensor (a marginal) as a histogram.

    Parameters
    ----------
    p: pyagrum.Tensor
      the marginal to show
    scale: float
      the zoom factor
    """
    proba2histo(p, scale)
    fmt = gum.config["notebook", "graph_format"]
    set_matplotlib_formats(fmt)
    plt.show()
581
+
582
+
583
def _getMatplotFig(fig):
    """
    Render a matplotlib figure as an inline HTML ``<img>`` (base64-encoded PNG), then close it.

    Parameters
    ----------
    fig: matplotlib.figure.Figure
      the figure to render

    Returns
    -------
    str
      an ``<img>`` tag embedding the PNG as a data URL
    """
    with io.BytesIO() as bio:  # bytes buffer for the plot
        fig.savefig(bio, format="png", bbox_inches="tight")
        # encode the bytes as string using base 64
        sB64Img = base64.b64encode(bio.getvalue()).decode()

    # close this very figure: a bare plt.close() closes the *current* figure, which is not
    # necessarily `fig` (callers such as getProba close `fig` before calling this helper,
    # so a bare close would then destroy an unrelated figure)
    plt.close(fig)
    return f'<img src="data:image/png;base64,{sB64Img}\n">'
593
+
594
+
595
def getProba(p, scale=None) -> str:
    """
    get a mono-dim Tensor as html (png/svg) image

    Parameters
    ----------
    p: pyagrum.Tensor
      the marginal to show
    scale: float
      the zoom factor

    Returns
    -------
    str
      the HTML representation of the marginal
    """
    fmt = gum.config["notebook", "graph_format"]
    set_matplotlib_formats(fmt)

    figure = proba2histo(p, scale)
    # close the (current) figure before rendering it to HTML -- presumably so that the
    # notebook backend does not display it a second time
    plt.close()
    return _getMatplotFig(figure)
616
+
617
+
618
def showProbaMinMax(pmin, pmax, scale=1.0):
    """
    Show a bi-Tensor (min,max) as a histogram.

    Parameters
    ----------
    pmin: pyagrum.Tensor
      the min marginal to show
    pmax: pyagrum.Tensor
      the max marginal to show
    scale: float
      the zoom factor
    """
    probaMinMaxH(pmin, pmax, scale)
    fmt = gum.config["notebook", "graph_format"]
    set_matplotlib_formats(fmt)
    plt.show()
634
+
635
+
636
def getProbaMinMax(pmin, pmax, scale=1.0) -> str:
    """
    get a bi-Tensor (min,max) as html (png/svg) img

    Parameters
    ----------
    pmin: pyagrum.Tensor
      the min marginal to show
    pmax: pyagrum.Tensor
      the max marginal to show
    scale: float
      the zoom factor

    Returns
    -------
    str
      the HTML representation of the marginal min,max
    """
    fmt = gum.config["notebook", "graph_format"]
    set_matplotlib_formats(fmt)
    figure = probaMinMaxH(pmin, pmax, scale)
    return _getMatplotFig(figure)
656
+
657
+
658
def getPosterior(bn, evs, target):
    """
    get the posterior of `target` given `evs` as an HTML image

    The posterior ``gum.getPosterior(bn, evs=evs, target=target)`` is plotted with
    ``proba2histo`` and converted to HTML by ``_getMatplotFig``.

    Parameters
    ----------
    bn: "pyagrum.BayesNet"
      the BayesNet
    evs: Dict[str|int:int|str|List[float]]
      map of evidence
    target: str|int
      name of target variable

    Returns
    -------
    str
      the HTML fragment (an ``<img>`` tag) representing the posterior
    """
    posterior = gum.getPosterior(bn, evs=evs, target=target)
    figure = proba2histo(posterior)
    plt.close()
    return _getMatplotFig(figure)
678
+
679
+
680
def showPosterior(bn, evs, target):
    """
    Display the posterior of `target` given `evs`
    (shortcut for ``showProba(gum.getPosterior(bn, evs=evs, target=target))``).

    Parameters
    ----------
    bn: "pyagrum.BayesNet"
      the BayesNet
    evs: Dict[str|int:int|str|List[float]]
      map of evidence
    target: str|int
      name of target variable
    """
    posterior = gum.getPosterior(bn, evs=evs, target=target)
    showProba(posterior)
694
+
695
+
696
def animApproximationScheme(apsc, scale=np.log10):
    """
    show an animated version of an approximation algorithm

    Attaches a ``gum.PythonApproximationListener`` to `apsc` (and switches its verbosity on).
    On each progress notification, the history of the scheme is re-plotted in the figure that
    was current when this function was called. When the algorithm stops, the plot title
    reports the stop message, time, number of iterations and epsilon.

    Parameters
    ----------
    apsc
      the approximation algorithm
    scale
      a function applied to the history before plotting (default: ``np.log10``)
    """
    figure = plt.gcf()

    listener = gum.PythonApproximationListener(apsc._asIApproximationSchemeConfiguration())
    apsc.setVerbosity(True)
    apsc.listener = listener

    def _onStop(message):
        IPython.display.clear_output(True)
        plt.title(
            f"{message} \n Time : {apsc.currentTime()}s | Iterations : {apsc.nbrIterations()} | Epsilon : {apsc.epsilon()}"
        )

    def _onProgress(_a, _b, _c):
        history = apsc.history()
        # x axis spans at least 10 iterations
        plt.xlim(1, max(10, len(history)))
        plt.plot(scale(history), "g")
        IPython.display.clear_output(True)
        IPython.display.display(figure)

    listener.setWhenStop(_onStop)
    listener.setWhenProgress(_onProgress)
728
+
729
+
730
def showApproximationScheme(apsc, scale=np.log10):
    """
    show the state of an approximation algorithm

    Only plots when the scheme is verbose (i.e. when it recorded a history).

    Parameters
    ----------
    apsc
      the approximation algorithm
    scale
      a function applied to the history before plotting (default: ``np.log10``)
    """
    if not apsc.verbosity():
        return

    history = apsc.history()
    # x axis spans at least 10 iterations
    plt.xlim(1, max(10, len(history)))
    plt.title(f"Time : {apsc.currentTime()}s | Iterations : {apsc.nbrIterations()} | Epsilon : {apsc.epsilon()}")
    plt.plot(scale(history), "g")
748
+
749
+
750
def showMRF(
    mrf,
    view=None,
    size=None,
    nodeColor=None,
    factorColor=None,
    edgeWidth=None,
    edgeColor=None,
    cmapNode=None,
    cmapEdge=None,
):
    """
    show a Markov random field

    Parameters
    ----------
    mrf : "pyagrum.MarkovRandomField"
      the Markov random field
    view : str
      'graph' | 'factorgraph' | None (default: ``gum.config["notebook", "default_markovrandomfield_view"]``);
      any value other than 'graph' draws the factor graph
    size : str
      size (for graphviz) of the rendered graph
    nodeColor: Dict[int,float]
      a nodeMap of values (between 0 and 1) to be shown as color of nodes (with special colors for 0 and 1)
    factorColor
      values (between 0 and 1) to be shown as colors of factors (used when view='factorgraph');
      NOTE(review): documented elsewhere both as a dict and as a function -- see MRF2FactorGraphdot
    edgeWidth: Dict[Tuple[int,int],float]
      an edgeMap of values to be shown as width of edges (used when view='graph')
    edgeColor: Dict[Tuple[int,int],float]
      an edgeMap of values (between 0 and 1) to be shown as color of edges (used when view='graph')
    cmapNode: matplotlib.colors.Colormap
      color map to show the colors (if cmapEdge is None, this color map is used also for edges)
    cmapEdge: matplotlib.colors.Colormap
      color map to show the edge color if distinction is needed

    Returns
    -------
    the result of ``showGraph`` (the graph is displayed)
    """
    if view is None:
        view = gum.config["notebook", "default_markovrandomfield_view"]
    if size is None:
        size = gum.config["notebook", "default_graph_size"]

    edgeCmap = cmapNode if cmapEdge is None else cmapEdge

    if view == "graph":
        graph = MRF2UGdot(
            mrf, size, nodeColor=nodeColor, edgeWidth=edgeWidth, edgeColor=edgeColor, cmapNode=cmapNode, cmapEdge=edgeCmap
        )
    else:
        graph = MRF2FactorGraphdot(mrf, size, nodeColor=nodeColor, factorColor=factorColor, cmapNode=cmapNode)

    return showGraph(graph, size)
806
+
807
+
808
def showInfluenceDiagram(diag, size=None):
    """
    Display an influence diagram as a graph.

    Parameters
    ----------
    diag : "pyagrum.InfluenceDiagram"
      the influence diagram
    size : str
      size (for graphviz) of the rendered graph; defaults to
      ``gum.config["influenceDiagram", "default_id_size"]``

    Returns
    -------
    the result of ``showGraph`` (the diagram is displayed)
    """
    graphSize = gum.config["influenceDiagram", "default_id_size"] if size is None else size
    return showGraph(ID2dot(diag), graphSize)
827
+
828
+
829
def getInfluenceDiagram(diag, size=None):
    """
    Build an HTML string for an influence diagram rendered as a graph.

    Parameters
    ----------
    diag : "pyagrum.InfluenceDiagram"
      the influence diagram
    size : str
      size (for graphviz) of the rendered graph; defaults to
      ``gum.config["influenceDiagram", "default_id_size"]``

    Returns
    -------
    str
      the HTML representation of the influence diagram
    """
    graphSize = gum.config["influenceDiagram", "default_id_size"] if size is None else size
    return getGraph(ID2dot(diag), graphSize)
849
+
850
+
851
def showBN(bn, size=None, nodeColor=None, arcWidth=None, arcLabel=None, arcColor=None, cmapNode=None, cmapArc=None):
    """
    show a Bayesian network

    Parameters
    ----------
    bn : pyagrum.BayesNet
      the Bayesian network
    size: str
      size (for graphviz) of the rendered graph; defaults to
      ``gum.config["notebook", "default_graph_size"]``
    nodeColor: dict[int,float]
      a nodeMap of values to be shown as color nodes (with special color for 0 and 1)
    arcWidth: dict[Tuple(int,int),float]
      an arcMap of values to be shown as bold arcs
    arcLabel: dict[Tuple(int,int),str]
      an arcMap of labels to be shown next to arcs
    arcColor: dict[Tuple(int,int),float]
      an arcMap of values (between 0 and 1) to be shown as color of arcs
    cmapNode: ColorMap
      color map to show the vals of Nodes (if cmapArc is None, this color map is used also for arcs)
    cmapArc: ColorMap
      color map to show the vals of Arcs

    Returns
    -------
    the result of ``showGraph`` (the network is displayed)
    """
    if size is None:
        size = gum.config["notebook", "default_graph_size"]

    graph = BN2dot(
        bn,
        size=size,
        nodeColor=nodeColor,
        arcWidth=arcWidth,
        arcLabel=arcLabel,
        arcColor=arcColor,
        cmapNode=cmapNode,
        cmapArc=cmapNode if cmapArc is None else cmapArc,
    )
    return showGraph(graph, size)
895
+
896
+
897
def showCN(cn, size=None, nodeColor=None, arcWidth=None, arcLabel=None, arcColor=None, cmapNode=None, cmapArc=None):
    """
    show a credal network

    Parameters
    ----------
    cn : pyagrum.CredalNet
      the Credal network
    size: str
      size (for graphviz) of the rendered graph; defaults to
      ``gum.config["notebook", "default_graph_size"]``
    nodeColor: dict[int,float]
      a nodeMap of values to be shown as color nodes (with special color for 0 and 1)
    arcWidth: dict[Tuple(int,int),float]
      an arcMap of values to be shown as bold arcs
    arcLabel: dict[Tuple(int,int),str]
      an arcMap of labels to be shown next to arcs
    arcColor: dict[Tuple(int,int),float]
      an arcMap of values (between 0 and 1) to be shown as color of arcs
    cmapNode: matplotlib.colors.Colormap
      color map to show the vals of Nodes (if cmapArc is None, this color map is used also for arcs)
    cmapArc: matplotlib.colors.Colormap
      color map to show the vals of Arcs

    Returns
    -------
    the result of ``showGraph`` (the network is displayed)
    """
    if size is None:
        size = gum.config["notebook", "default_graph_size"]

    graph = CN2dot(
        cn,
        size=size,
        nodeColor=nodeColor,
        arcWidth=arcWidth,
        arcLabel=arcLabel,
        arcColor=arcColor,
        cmapNode=cmapNode,
        cmapArc=cmapNode if cmapArc is None else cmapArc,
    )
    return showGraph(graph, size)
945
+
946
+
947
+ def getMRF(
948
+ mrf,
949
+ view=None,
950
+ size=None,
951
+ nodeColor=None,
952
+ factorColor=None,
953
+ edgeWidth=None,
954
+ edgeColor=None,
955
+ cmapNode=None,
956
+ cmapEdge=None,
957
+ ):
958
+ """
959
+ get an HTML string for a Markov random field
960
+
961
+ Parameters
962
+ ----------
963
+ mrf : "pyagrum.MarkovRandomField"
964
+ the Markov random field
965
+ view: str
966
+ 'graph' | 'factorgraph’ | None (default)
967
+ size: str
968
+ size (for graphviz) of the rendered graph
969
+ nodeColor: Dict[str,float]
970
+ a nodeMap of values (between 0 and 1) to be shown as color of nodes (with special colors for 0 and 1)
971
+ factorColor: Dict[str,float]
972
+ a function returning a value (beeween 0 and 1) to be shown as a color of factor. (used when view='factorgraph')
973
+ edgeWidth: Dict[Tuple[str,str],float]
974
+ a edgeMap of values to be shown as width of edges (used when view='graph')
975
+ edgeColor: Dict[Tuple[str,str],float]
976
+ a edgeMap of values (between 0 and 1) to be shown as color of edges (used when view='graph')
977
+ cmapNode: matplotlib.ColorMap
978
+ color map to show the colors (if cmapEdge is None, cmapNode is used for edges)
979
+ cmapEdge: matplotlib.ColorMap
980
+ color map to show the edge color if distinction is needed
981
+
982
+ Returns
983
+ -------
984
+ the graph
985
+ """
986
+ if size is None:
987
+ size = gum.config["notebook", "default_graph_size"]
988
+
989
+ if cmapEdge is None:
990
+ cmapEdge = cmapNode
991
+
992
+ if view is None:
993
+ view = gum.config["notebook", "default_markovrandomfield_view"]
994
+
995
+ if view == "graph":
996
+ dottxt = MRF2UGdot(mrf, size, nodeColor, edgeWidth, edgeColor, cmapNode, cmapEdge)
997
+ else:
998
+ dottxt = MRF2FactorGraphdot(mrf, size, nodeColor, factorColor, cmapNode=cmapNode)
999
+
1000
+ return getGraph(dottxt, size)
1001
+
1002
+
1003
+ def getBN(bn, size=None, nodeColor=None, arcWidth=None, arcLabel=None, arcColor=None, cmapNode=None, cmapArc=None):
1004
+ """
1005
+ get a HTML string for a Bayesian network
1006
+
1007
+ Parameters
1008
+ ----------
1009
+ bn : pyagrum.BayesNet
1010
+ the Bayesian network
1011
+ size: str
1012
+ size (for graphviz) of the rendered graph
1013
+ nodeColor: dict[Tuple(int,int),float]
1014
+ a nodeMap of values to be shown as color nodes (with special color for 0 and 1)
1015
+ arcWidth: dict[Tuple(int,int),float]
1016
+ an arcMap of values to be shown as bold arcs
1017
+ arcLabel: dict[Tuple(int,int),str]
1018
+ an arcMap of labels to be shown next to arcs
1019
+ arcColor: dict[Tuple(int,int),float]
1020
+ an arcMap of values (between 0 and 1) to be shown as color of arcs
1021
+ cmapNode: ColorMap
1022
+ color map to show the vals of Nodes
1023
+ cmapArc: ColorMap
1024
+ color map to show the vals of Arcs
1025
+ showMsg: dict
1026
+ a nodeMap of values to be shown as tooltip
1027
+
1028
+ Returns
1029
+ -------
1030
+ pydot.Dot
1031
+ the desired representation of the Bayesian network
1032
+ """
1033
+ if size is None:
1034
+ size = gum.config["notebook", "default_graph_size"]
1035
+
1036
+ if cmapArc is None:
1037
+ cmapArc = cmapNode
1038
+
1039
+ return getGraph(
1040
+ BN2dot(
1041
+ bn,
1042
+ size=size,
1043
+ nodeColor=nodeColor,
1044
+ arcWidth=arcWidth,
1045
+ arcLabel=arcLabel,
1046
+ arcColor=arcColor,
1047
+ cmapNode=cmapNode,
1048
+ cmapArc=cmapArc,
1049
+ ),
1050
+ size,
1051
+ )
1052
+
1053
+
1054
+ def getCN(cn, size=None, nodeColor=None, arcWidth=None, arcLabel=None, arcColor=None, cmapNode=None, cmapArc=None):
1055
+ """
1056
+ get a HTML string for a credal network
1057
+
1058
+ Parameters
1059
+ ----------
1060
+ cn : pyagrum.CredalNet
1061
+ the Credal network
1062
+ size: str
1063
+ size (for graphviz) of the rendered graph
1064
+ nodeColor: dict[int,float]
1065
+ a nodeMap of values to be shown as color nodes (with special color for 0 and 1)
1066
+ arcWidth: dict[Tuple(int,int),float]
1067
+ an arcMap of values to be shown as bold arcs
1068
+ arcLabel: dict[Tuple(int,int),float]
1069
+ an arcMap of labels to be shown next to arcs
1070
+ arcColor: dict[Tuple(int,int),float]
1071
+ an arcMap of values (between 0 and 1) to be shown as color of arcs
1072
+ cmapNode: matplotlib.color.colormap
1073
+ color map to show the vals of Nodes
1074
+ cmapArc: matplotlib.color.colormap
1075
+ color map to show the vals of Arcs
1076
+ showMsg : dict[int,str]
1077
+ a nodeMap of values to be shown as tooltip
1078
+
1079
+ Returns
1080
+ -------
1081
+ pydot.Dot
1082
+ the desired representation of the Credal Network
1083
+ """
1084
+ if size is None:
1085
+ size = gum.config["notebook", "default_graph_size"]
1086
+
1087
+ if cmapArc is None:
1088
+ cmapArc = cmapNode
1089
+
1090
+ return getGraph(
1091
+ CN2dot(
1092
+ cn,
1093
+ size=size,
1094
+ nodeColor=nodeColor,
1095
+ arcWidth=arcWidth,
1096
+ arcLabel=arcLabel,
1097
+ arcColor=arcColor,
1098
+ cmapNode=cmapNode,
1099
+ cmapArc=cmapArc,
1100
+ ),
1101
+ size,
1102
+ )
1103
+
1104
+
1105
+ def showInference(model, **kwargs):
1106
+ """
1107
+ show pydot graph for an inference in a notebook
1108
+
1109
+ Parameters
1110
+ ----------
1111
+ model: pyagrum.GraphicalModel
1112
+ the model in which to infer (pyagrum.BayesNet, pyagrum.MarkovRandomField or pyagrum.InfluenceDiagram)
1113
+ engine: gum.Inference
1114
+ inference algorithm used. If None, gum.LazyPropagation will be used for BayesNet, gum.ShaferShenoy for gum.MarkovRandomField and gum.ShaferShenoyLIMIDInference for gum.InfluenceDiagram.
1115
+ evs: Dict[int|str,int|str|List[float]]
1116
+ map of evidence
1117
+ targets: Set[str]
1118
+ set of targets
1119
+ size: string
1120
+ size (for graphviz) of the rendered graph
1121
+ nodeColor: Dict[str,float]
1122
+ a nodeMap of values (between 0 and 1) to be shown as color of nodes (with special colors for 0 and 1)
1123
+ factorColor: Dict[int,float]
1124
+ a nodeMap of values (between 0 and 1) to be shown as color of factors (in MarkovRandomField representation)
1125
+ arcWidth: : Dict[(str,str),float]
1126
+ an arcMap of values to be shown as width of arcs
1127
+ arcColor: : Dict[(str,str),float]
1128
+ a arcMap of values (between 0 and 1) to be shown as color of arcs
1129
+ cmapNode: matplotlib.ColorMap
1130
+ map to show the color of nodes and arcs
1131
+ cmapArc: matplotlib.ColorMap
1132
+ color map to show the vals of Arcs.
1133
+ graph: pyagrum.Graph
1134
+ only shows nodes that have their id in the graph (and not in the whole BN)
1135
+ view: str
1136
+ graph | factorgraph | None (default) for Markov random field
1137
+
1138
+ Returns
1139
+ -------
1140
+ the desired representation of the inference
1141
+ """
1142
+ if "size" in kwargs:
1143
+ size = kwargs["size"]
1144
+ else:
1145
+ size = gum.config["notebook", "default_graph_inference_size"]
1146
+
1147
+ showGraph(prepareShowInference(model, **kwargs), size)
1148
+
1149
+
1150
+ def getInference(model, **kwargs):
1151
+ """
1152
+ get a HTML string for an inference in a notebook
1153
+
1154
+ Parameters
1155
+ ----------
1156
+ model: pyagrum.GraphicalModel
1157
+ the model in which to infer (pyagrum.BayesNet, pyagrum.MarkovRandomField or pyagrum.InfluenceDiagram)
1158
+ engine: gum.Inference
1159
+ inference algorithm used. If None, gum.LazyPropagation will be used for BayesNet, gum.ShaferShenoy for gum.MarkovRandomField and gum.ShaferShenoyLIMIDInference for gum.InfluenceDiagram.
1160
+ evs: Dict[int|str,int|str|List[float]]
1161
+ map of evidence
1162
+ targets: Set[str]
1163
+ set of targets
1164
+ size: string
1165
+ size (for graphviz) of the rendered graph
1166
+ nodeColor: Dict[str,float]
1167
+ a nodeMap of values (between 0 and 1) to be shown as color of nodes (with special colors for 0 and 1)
1168
+ factorColor: Dict[int,float]
1169
+ a nodeMap of values (between 0 and 1) to be shown as color of factors (in MarkovRandomField representation)
1170
+ arcWidth: : Dict[(str,str),float]
1171
+ an arcMap of values to be shown as width of arcs
1172
+ arcColor: : Dict[(str,str),float]
1173
+ a arcMap of values (between 0 and 1) to be shown as color of arcs
1174
+ cmapNode: matplotlib.ColorMap
1175
+ map to show the color of nodes and arcs
1176
+ cmapArc: matplotlib.ColorMap
1177
+ color map to show the vals of Arcs.
1178
+ graph: pyagrum.Graph
1179
+ only shows nodes that have their id in the graph (and not in the whole BN)
1180
+ view: str
1181
+ graph | factorgraph | None (default) for Markov random field
1182
+
1183
+ Returns
1184
+ -------
1185
+ the desired representation of the inference
1186
+ """
1187
+ if "size" in kwargs:
1188
+ size = kwargs["size"]
1189
+ else:
1190
+ size = gum.config["notebook", "default_graph_inference_size"]
1191
+
1192
+ grinf = prepareShowInference(model, **kwargs)
1193
+ return getGraph(grinf, size)
1194
+
1195
+
1196
+ def _reprTensor(pot, digits=None, withColors=None, varnames=None, asString=False):
1197
+ """
1198
+ return a representation of a gum.Tensor as a HTML table.
1199
+ The first dimension is special (horizontal) due to the representation of conditional probability table
1200
+
1201
+ Parameters
1202
+ ---------
1203
+ pot: gum.Tensor
1204
+ the tensor to get
1205
+ digits: int
1206
+ number of digits to show
1207
+ withColors: Boolean
1208
+ background color for proba cells or not
1209
+ varnames: Dict[str,str]
1210
+ a mapping that gives the aliases for variables name in the table
1211
+ asString: Boolean
1212
+ display the table or a HTML string
1213
+
1214
+ Returns
1215
+ -------
1216
+ str
1217
+ the HTML table that represents the tensor
1218
+ """
1219
+ from fractions import Fraction
1220
+
1221
+ r0, g0, b0 = gumcols.hex2rgb(gum.config["notebook", "tensor_color_0"])
1222
+ r1, g1, b1 = gumcols.hex2rgb(gum.config["notebook", "tensor_color_1"])
1223
+
1224
+ if digits is None:
1225
+ digits = gum.config.asInt["notebook", "tensor_visible_digits"]
1226
+
1227
+ if withColors is None:
1228
+ withColors = gum.config.asBool["notebook", "tensor_with_colors"]
1229
+
1230
+ with_fraction = gum.config["notebook", "tensor_with_fraction"] == "True"
1231
+ if with_fraction:
1232
+ fraction_limit = int(gum.config["notebook", "tensor_fraction_limit"])
1233
+ fraction_round_error = float(gum.config["notebook", "tensor_fraction_round_error"])
1234
+ fraction_with_latex = gum.config["notebook", "tensor_fraction_with_latex"] == "True"
1235
+
1236
+ def _rgb(r, g, b):
1237
+ return "#%02x%02x%02x" % (r, g, b)
1238
+
1239
+ def _mkCell(val):
1240
+ s = "<td style='"
1241
+ if withColors and (0 <= val <= 1):
1242
+ r = int(r0 + val * (r1 - r0))
1243
+ g = int(g0 + val * (g1 - g0))
1244
+ b = int(b0 + val * (b1 - b0))
1245
+
1246
+ tx = gumcols.rgb2brightness(r, g, b)
1247
+
1248
+ s += "color:" + tx + ";background-color:" + _rgb(r, g, b) + ";"
1249
+
1250
+ str_val = ""
1251
+ if with_fraction:
1252
+ frac_val = Fraction(val).limit_denominator(fraction_limit)
1253
+ val_app = frac_val.numerator / frac_val.denominator
1254
+ if abs(val_app - val) < fraction_round_error:
1255
+ str_val = "text-align:center;'>"
1256
+ if fraction_with_latex:
1257
+ str_val += "$$"
1258
+ if frac_val.denominator > 1:
1259
+ str_val += f"\\frac{{{frac_val.numerator}}}{{{frac_val.denominator}}}"
1260
+ else:
1261
+ str_val += f"{frac_val.numerator}"
1262
+ str_val += "$$"
1263
+ else:
1264
+ str_val += f"{frac_val}"
1265
+ str_val += "</td>"
1266
+ if str_val == "":
1267
+ str_val = f"text-align:right;padding: 3px;'>{val:.{digits}f}</td>"
1268
+
1269
+ return s + str_val
1270
+
1271
+ html = list()
1272
+ html.append('<table style="border:1px solid black;border-collapse: collapse;">')
1273
+ if pot.empty():
1274
+ html.append("<tr><th>&nbsp;</th></tr>")
1275
+ html.append("<tr>" + _mkCell(pot.get(gum.Instantiation())) + "</tr>")
1276
+ else:
1277
+ if varnames is not None and len(varnames) != pot.nbrDim():
1278
+ raise ValueError(f"varnames contains {len(varnames)} value(s) instead of the needed {pot.nbrDim()} value(s).")
1279
+
1280
+ nparents = pot.nbrDim() - 1
1281
+ var = pot.variable(0)
1282
+ varname = var.name() if varnames is None else varnames[0]
1283
+
1284
+ # first line
1285
+ if nparents > 0:
1286
+ html.append(f"""<tr><th colspan='{nparents}'></th>
1287
+ <th colspan='{var.domainSize()}' style='border:1px solid black;color:black;background-color:#808080;'><center>{varname}</center>
1288
+ </th></tr>""")
1289
+ else:
1290
+ html.append(f"""<tr style='border:1px solid black;color:black;background-color:#808080'>
1291
+ <th colspan='{var.domainSize()}'><center>{varname}</center></th></tr>""")
1292
+
1293
+ # second line
1294
+ s = "<tr>"
1295
+ if nparents > 0:
1296
+ # parents order
1297
+ if gum.config["notebook", "tensor_parent_values"] == "revmerge":
1298
+ pmin, pmax, pinc = nparents - 1, 0 - 1, -1
1299
+ else:
1300
+ pmin, pmax, pinc = 0, nparents, 1
1301
+
1302
+ if varnames is None:
1303
+ varnames = list(reversed(pot.names))
1304
+ for par in range(pmin, pmax, pinc):
1305
+ parent = varnames[par]
1306
+ s += f"<th style='border:1px solid black;color:black;background-color:#808080'><center>{parent}</center></th>"
1307
+
1308
+ for label in var.labels():
1309
+ s += f"""<th style='border:1px solid black;border-bottom-style: double;color:black;background-color:#BBBBBB'>
1310
+ <center>{label}</center></th>"""
1311
+ s += "</tr>"
1312
+
1313
+ html.append(s)
1314
+
1315
+ inst = gum.Instantiation(pot)
1316
+ off = 1
1317
+ offset = dict()
1318
+ for i in range(1, nparents + 1):
1319
+ offset[i] = off
1320
+ off *= inst.variable(i).domainSize()
1321
+
1322
+ inst.setFirst()
1323
+ while not inst.end():
1324
+ s = "<tr>"
1325
+ # parents order
1326
+ if gum.config["notebook", "tensor_parent_values"] == "revmerge":
1327
+ pmin, pmax, pinc = 1, nparents + 1, 1
1328
+ else:
1329
+ pmin, pmax, pinc = nparents, 0, -1
1330
+ for par in range(pmin, pmax, pinc):
1331
+ label = inst.variable(par).label(inst.val(par))
1332
+ if par == 1 or gum.config["notebook", "tensor_parent_values"] == "nomerge":
1333
+ s += f"<th style='border:1px solid black;color:black;background-color:#BBBBBB'><center>{label}</center></th>"
1334
+ else:
1335
+ if sum([inst.val(i) for i in range(1, par)]) == 0:
1336
+ s += f"""<th style='border:1px solid black;color:black;background-color:#BBBBBB;' rowspan = '{offset[par]}'>
1337
+ <center>{label}</center></th>"""
1338
+ for _ in range(pot.variable(0).domainSize()):
1339
+ s += _mkCell(pot.get(inst))
1340
+ inst.inc()
1341
+ s += "</tr>"
1342
+ html.append(s)
1343
+
1344
+ html.append("</table>")
1345
+
1346
+ if asString:
1347
+ return "\n".join(html)
1348
+ else:
1349
+ return IPython.display.HTML("".join(html))
1350
+
1351
+
1352
+ def __isKindOfProba(pot):
1353
+ """
1354
+ check if pot is a joint proba or a CPT
1355
+
1356
+ Parameters
1357
+ ----------
1358
+ pot: gum.Tensor
1359
+ the tensor
1360
+
1361
+ Returns
1362
+ -------
1363
+ True or False
1364
+ """
1365
+ epsilon = 1e-5
1366
+ if pot.min() < -epsilon:
1367
+ return False
1368
+ if pot.max() > 1 + epsilon:
1369
+ return False
1370
+
1371
+ # is it a joint proba ?
1372
+ if abs(pot.sum() - 1) < epsilon:
1373
+ return True
1374
+
1375
+ # marginal and then not proba (because of the test just above)
1376
+ if pot.nbrDim() < 2:
1377
+ return False
1378
+
1379
+ # is is a CPT ?
1380
+ q = pot.sumOut([pot.variable(0).name()])
1381
+ if abs(q.max() - 1) > epsilon:
1382
+ return False
1383
+ if abs(q.min() - 1) > epsilon:
1384
+ return False
1385
+
1386
+ return True
1387
+
1388
+
1389
+ def showPotential(pot, digits=None, withColors=None, varnames=None):
1390
+ warnings.warn("showPotential is deprecated since pyAgrum 2.0.0. Use showTensor instead", DeprecationWarning)
1391
+ showTensor(pot, digits, withColors, varnames)
1392
+
1393
+
1394
+ def showTensor(pot, digits=None, withColors=None, varnames=None):
1395
+ """
1396
+ show a gum.Tensor as a HTML table.
1397
+ The first dimension is special (horizontal) due to the representation of conditional probability table
1398
+
1399
+ Parameters
1400
+ ----------
1401
+ pot : gum.Tensor
1402
+ the tensor to show
1403
+ digits : int
1404
+ number of digits to show
1405
+ withColors : bool
1406
+ background color for proba cells or not
1407
+ varnames : List[str]
1408
+ the aliases for variables name in the table
1409
+ """
1410
+ if withColors is None:
1411
+ withColors = gum.config.asBool["notebook", "tensor_with_colors"]
1412
+
1413
+ if withColors:
1414
+ withColors = __isKindOfProba(pot)
1415
+
1416
+ s = _reprTensor(pot, digits, withColors, varnames, asString=False)
1417
+ IPython.display.display(s)
1418
+
1419
+
1420
+ def getPotential(pot, digits=None, withColors=None, varnames=None):
1421
+ warnings.warn("getPotential is deprecated since pyAgrum 2.0.0. Use getTensor instead", DeprecationWarning)
1422
+ return getTensor(pot, digits, withColors, varnames)
1423
+
1424
+
1425
+ def getTensor(pot, digits=None, withColors=None, varnames=None):
1426
+ """
1427
+ return a HTML string of a gum.Tensor as a HTML table.
1428
+ The first dimension is special (horizontal) due to the representation of conditional probability table
1429
+
1430
+ Parameters
1431
+ ----------
1432
+ pot : gum.Tensor
1433
+ the tensor to show
1434
+ digits : int
1435
+ number of digits to show
1436
+ withColors : bool
1437
+ background for proba cells or not
1438
+ varnames : List[str]
1439
+ the aliases for variables name in the table
1440
+
1441
+ Returns
1442
+ -------
1443
+ str
1444
+ the html representation of the Tensor (as a string)
1445
+ """
1446
+ if withColors is None:
1447
+ withColors = gum.config.asBool["notebook", "tensor_with_colors"]
1448
+
1449
+ if withColors:
1450
+ withColors = __isKindOfProba(pot)
1451
+
1452
+ return _reprTensor(pot, digits, withColors, varnames, asString=True)
1453
+
1454
+
1455
+ def showCPTs(bn):
1456
+ flow.clear()
1457
+ for i in bn.names():
1458
+ flow.add_html(getTensor(bn.cpt(i)))
1459
+ flow.display()
1460
+
1461
+
1462
+ def getSideBySide(*args, **kwargs):
1463
+ """
1464
+ create an HTML table for args as string (using string, _repr_html_() or str())
1465
+
1466
+ Parameters
1467
+ ----------
1468
+ args: str
1469
+ HMTL fragments as string arg, arg._repr_html_() or str(arg)
1470
+ captions: List[str], optional
1471
+ list of captions
1472
+ valign: str
1473
+ vertical position in the row (top|middle|bottom, middle by default)
1474
+ ncols: int
1475
+ number of columns (infinite by default)
1476
+
1477
+ Returns
1478
+ -------
1479
+ str
1480
+ a string representing the table
1481
+ """
1482
+ vals = {"captions", "valign", "ncols"}
1483
+ if not set(kwargs.keys()).issubset(vals):
1484
+ raise TypeError(f"sideBySide() got unexpected keyword argument(s) : '{set(kwargs.keys()).difference(vals)}'")
1485
+
1486
+ if "captions" in kwargs:
1487
+ captions = kwargs["captions"]
1488
+ else:
1489
+ captions = None
1490
+
1491
+ if "valign" in kwargs and kwargs["valign"] in ["top", "middle", "bottom"]:
1492
+ v_align = f"vertical-align:{kwargs['valign']};"
1493
+ else:
1494
+ v_align = "vertical-align:middle;"
1495
+
1496
+ ncols = None
1497
+ if "ncols" in kwargs:
1498
+ ncols = int(kwargs["ncols"])
1499
+ if ncols < 1:
1500
+ ncols = 1
1501
+
1502
+ def reprHTML(s):
1503
+ if isinstance(s, str):
1504
+ return s
1505
+ elif hasattr(s, "_repr_html_"):
1506
+ return s._repr_html_()
1507
+ else:
1508
+ return str(s)
1509
+
1510
+ s = '<table style="border-style: hidden; border-collapse: collapse;" width="100%"><tr>'
1511
+ for i in range(len(args)):
1512
+ s += f'<td style="border-top:hidden;border-bottom:hidden;{v_align}"><div align="center" style="{v_align}">'
1513
+ s += reprHTML(args[i])
1514
+ if captions is not None:
1515
+ s += f"<br><small><i>{captions[i]}</i></small>"
1516
+ s += "</div></td>"
1517
+ if ncols is not None and (i + 1) % ncols == 0:
1518
+ s += "</tr><tr>"
1519
+ s += "</tr></table>"
1520
+ return s
1521
+
1522
+
1523
+ def sideBySide(*args, **kwargs):
1524
+ """
1525
+ display side by side args as HMTL fragment (using string, _repr_html_() or str())
1526
+
1527
+ Parameters
1528
+ ----------
1529
+ args: str
1530
+ HMTL fragments as string arg, arg._repr_html_() or str(arg)
1531
+ captions: List[str], optional
1532
+ list of captions
1533
+ valign: str
1534
+ vertical position in the row (top|middle|bottom, middle by default)
1535
+ ncols: int
1536
+ number of columns (infinite by default)
1537
+ """
1538
+ IPython.display.display(IPython.display.HTML(getSideBySide(*args, **kwargs)))
1539
+
1540
+
1541
+ def getInferenceEngine(ie, inferenceCaption):
1542
+ """
1543
+ display an inference as a BN+ lists of hard/soft evidence and list of targets
1544
+
1545
+ Parameters
1546
+ ---------
1547
+ ie : "pyagrum.InferenceEngine"
1548
+ Inference engine
1549
+ caption: str
1550
+ inferenceCaption: caption for the inference
1551
+
1552
+ Returns
1553
+ -------
1554
+ str
1555
+ the HTML representation
1556
+ """
1557
+ t = '<div align="left"><ul>'
1558
+ if ie.nbrHardEvidence() > 0:
1559
+ t += "<li><b>hard evidence</b><br/>"
1560
+ t += ", ".join([ie.BN().variable(n).name() for n in ie.hardEvidenceNodes()])
1561
+ t += "</li>"
1562
+ if ie.nbrSoftEvidence() > 0:
1563
+ t += "<li><b>soft evidence</b><br/>"
1564
+ t += ", ".join([ie.BN().variable(n).name() for n in ie.softEvidenceNodes()])
1565
+ t += "</li>"
1566
+ if ie.nbrTargets() > 0:
1567
+ t += "<li><b>target(s)</b><br/>"
1568
+ if ie.nbrTargets() == ie.BN().size():
1569
+ t += " all"
1570
+ else:
1571
+ t += ", ".join([ie.BN().variable(n).name() for n in ie.targets()])
1572
+ t += "</li>"
1573
+
1574
+ if hasattr(ie, "nbrJointTargets") and ie.nbrJointTargets() > 0:
1575
+ t += "<li><b>Joint target(s)</b><br/>"
1576
+ t += ", ".join(["[" + (", ".join([ie.BN().variable(n).name() for n in ns])) + "]" for ns in ie.jointTargets()])
1577
+ t += "</li>"
1578
+ t += "</ul></div>"
1579
+ return getSideBySide(getBN(ie.BN()), t, captions=[inferenceCaption, "Evidence and targets"])
1580
+
1581
+
1582
+ def getJT(jt, size=None):
1583
+ """
1584
+ returns the representation of a junction tree as a HTML string
1585
+
1586
+ Parameters
1587
+ ----------
1588
+ jt: pyagrum.JunctionTree
1589
+ the junction tree
1590
+ size: str
1591
+ the size (for graphviz) of the graph
1592
+
1593
+ Returns
1594
+ -------
1595
+ str
1596
+ the representation of a junction tree as a HTML string
1597
+
1598
+ """
1599
+ if gum.config.asBool["notebook", "junctiontree_with_names"]:
1600
+
1601
+ def cliqlabels(c):
1602
+ labels = ",".join(sorted([model.variable(n).name() for n in jt.clique(c)]))
1603
+ return f"({c}):{labels}"
1604
+
1605
+ def cliqnames(c):
1606
+ return "-".join(sorted([model.variable(n).name() for n in jt.clique(c)]))
1607
+
1608
+ def seplabels(c1, c2):
1609
+ return ",".join(sorted([model.variable(n).name() for n in jt.separator(c1, c2)]))
1610
+
1611
+ def sepnames(c1, c2):
1612
+ return cliqnames(c1) + "+" + cliqnames(c2)
1613
+ else:
1614
+
1615
+ def cliqlabels(c):
1616
+ ids = ",".join([str(n) for n in sorted(jt.clique(c))])
1617
+ return f"({c}):{ids}"
1618
+
1619
+ def cliqnames(c):
1620
+ return "-".join([str(n) for n in sorted(jt.clique(c))])
1621
+
1622
+ def seplabels(c1, c2):
1623
+ return ",".join([str(n) for n in sorted(jt.separator(c1, c2))])
1624
+
1625
+ def sepnames(c1, c2):
1626
+ return cliqnames(c1) + "^" + cliqnames(c2)
1627
+
1628
+ model = jt._engine._model
1629
+ name = model.propertyWithDefault("name", str(type(model)).split(".")[-1][:-2])
1630
+ graph = dot.Dot(graph_type="graph", graph_name=name, bgcolor="transparent")
1631
+ for c in jt.nodes():
1632
+ graph.add_node(
1633
+ dot.Node(
1634
+ '"' + cliqnames(c) + '"',
1635
+ label='"' + cliqlabels(c) + '"',
1636
+ style="filled",
1637
+ fillcolor=gum.config["notebook", "junctiontree_clique_bgcolor"],
1638
+ fontcolor=gum.config["notebook", "junctiontree_clique_fgcolor"],
1639
+ fontsize=gum.config["notebook", "junctiontree_clique_fontsize"],
1640
+ )
1641
+ )
1642
+ for c1, c2 in jt.edges():
1643
+ graph.add_node(
1644
+ dot.Node(
1645
+ '"' + sepnames(c1, c2) + '"',
1646
+ label='"' + seplabels(c1, c2) + '"',
1647
+ style="filled",
1648
+ shape="box",
1649
+ width="0",
1650
+ height="0",
1651
+ margin="0.02",
1652
+ fillcolor=gum.config["notebook", "junctiontree_separator_bgcolor"],
1653
+ fontcolor=gum.config["notebook", "junctiontree_separator_fgcolor"],
1654
+ fontsize=gum.config["notebook", "junctiontree_separator_fontsize"],
1655
+ )
1656
+ )
1657
+ for c1, c2 in jt.edges():
1658
+ graph.add_edge(dot.Edge('"' + cliqnames(c1) + '"', '"' + sepnames(c1, c2) + '"'))
1659
+ graph.add_edge(dot.Edge('"' + sepnames(c1, c2) + '"', '"' + cliqnames(c2) + '"'))
1660
+
1661
+ graph.set_size(gum.config["notebook", "junctiontree_graph_size"])
1662
+
1663
+ return graph.to_string()
1664
+
1665
+
1666
+ def getCliqueGraph(cg, size=None):
1667
+ """get a representation for clique graph. Special case for junction tree
1668
+ (clique graph with an attribute `_engine`)
1669
+
1670
+ Parameters
1671
+ cg (gum.CliqueGraph): the clique graph (maybe junction tree for a _model) to
1672
+ represent
1673
+
1674
+ Returns
1675
+ -------
1676
+ pydot.Dot
1677
+ the dot representation of the graph
1678
+ """
1679
+ if hasattr(cg, "_engine"):
1680
+ return getDot(getJT(cg), size)
1681
+ else:
1682
+ return getDot(cg.toDot())
1683
+
1684
+
1685
+ def show(model, **kwargs):
1686
+ """
1687
+ propose a (visual) representation of a graphical model or a graph or a Tensor in a notebook
1688
+
1689
+ Parameters
1690
+ ----------
1691
+ model
1692
+ the model to show (pyagrum.BayesNet, pyagrum.MarkovRandomField, pyagrum.InfluenceDiagram or pyagrum.Tensor) or a dot string, or a `pydot.Dot` or even just an object with a method `toDot()`.
1693
+
1694
+ size: str
1695
+ size (for graphviz) to represent the graphical model (no effect for Tensor)
1696
+ nodeColor: Dict[str,float]
1697
+ a nodeMap of values (between 0 and 1) to be shown as color of nodes (with special colors for 0 and 1)
1698
+ factorColor: Dict[int,float]
1699
+ a nodeMap of values (between 0 and 1) to be shown as color of factors (in MarkovRandomField representation)
1700
+ arcWidth: : Dict[(str,str),float]
1701
+ an arcMap of values to be shown as width of arcs
1702
+ arcColor: : Dict[(str,str),float]
1703
+ a arcMap of values (between 0 and 1) to be shown as color of arcs
1704
+ cmapNode: matplotlib.ColorMap
1705
+ map to show the color of nodes and arcs
1706
+ cmapArc: matplotlib.ColorMap
1707
+ color map to show the vals of Arcs.
1708
+ graph: pyagrum.Graph
1709
+ only shows nodes that have their id in the graph (and not in the whole BN)
1710
+ view: str
1711
+ graph | factorgraph | None (default) for Markov random field
1712
+ """
1713
+ if isinstance(model, gum.BayesNet):
1714
+ showBN(model, **kwargs)
1715
+ elif isinstance(model, gum.MarkovRandomField):
1716
+ showMRF(model, **kwargs)
1717
+ elif isinstance(model, gum.InfluenceDiagram):
1718
+ showInfluenceDiagram(model, **kwargs)
1719
+ elif isinstance(model, gum.CredalNet):
1720
+ showCN(model, **kwargs)
1721
+ elif isinstance(model, gum.Tensor):
1722
+ showTensor(model)
1723
+ elif hasattr(model, "toDot"):
1724
+ showDot(model.toDot(), **kwargs)
1725
+ elif isinstance(model, dot.Dot):
1726
+ showGraph(model, **kwargs)
1727
+ else:
1728
+ raise gum.InvalidArgument(
1729
+ "Argument model should be a PGM (BayesNet, MarkovRandomField, Influence Diagram or Tensor or ..."
1730
+ )
1731
+
1732
+
1733
+ def inspectBN(bn):
1734
+ """
1735
+ inspect a BN (graph and CPTs) in a notebook (using flow)
1736
+ Parameters
1737
+ ----------
1738
+ bn
1739
+ """
1740
+ flow.row(bn, *[bn.cpt(c) for c in sorted(bn.names())])
1741
+
1742
+
1743
+ def _update_config_notebooks():
1744
+ # hook to control some parameters for notebook when config changes
1745
+ mpl.rcParams["figure.facecolor"] = gum.config["notebook", "figure_facecolor"]
1746
+ set_matplotlib_formats(gum.config["notebook", "graph_format"])
1747
+
1748
+
1749
+ # check if an instance of ipython exists
1750
+ try:
1751
+ get_ipython
1752
+ except NameError:
1753
+ import warnings
1754
+
1755
+ warnings.warn("""
1756
+ ** pyagrum.lib.notebook has to be imported from an IPython's instance (mainly notebook).
1757
+ """)
1758
+ else:
1759
+
1760
+ def map(
1761
+ self,
1762
+ scaleClique: float = None,
1763
+ scaleSep: float = None,
1764
+ lenEdge: float = None,
1765
+ colorClique: str = None,
1766
+ colorSep: str = None,
1767
+ ) -> dot.Dot:
1768
+ """
1769
+ show the map of the junction tree.
1770
+
1771
+ Parameters
1772
+ ----------
1773
+ scaleClique: float
1774
+ the scale for the size of the clique nodes (depending on the number of nodes in the clique)
1775
+ scaleSep: float
1776
+ the scale for the size of the separator nodes (depending on the number of nodes in the clique)
1777
+ lenEdge: float
1778
+ the desired length of edges
1779
+ colorClique: str
1780
+ color for the clique nodes
1781
+ colorSep: str
1782
+ color for the separator nodes
1783
+ """
1784
+ if scaleClique is None:
1785
+ scaleClique = float(gum.config["notebook", "junctiontree_map_cliquescale"])
1786
+ if scaleSep is None:
1787
+ scaleSep = float(gum.config["notebook", "junctiontree_map_sepscale"])
1788
+ if lenEdge is None:
1789
+ lenEdge = float(gum.config["notebook", "junctiontree_map_edgelen"])
1790
+ if colorClique is None:
1791
+ colorClique = gum.config["notebook", "junctiontree_clique_bgcolor"]
1792
+ if colorSep is None:
1793
+ colorSep = gum.config["notebook", "junctiontree_separator_bgcolor"]
1794
+ return _from_dotstring(self.__map_str__(scaleClique, scaleSep, lenEdge, colorClique, colorSep))
1795
+
1796
+ setattr(gum.CliqueGraph, "map", map)
1797
+
1798
+ gum.config.add_hook(_update_config_notebooks)
1799
+ gum.config.run_hooks()
1800
+
1801
+ # adding _repr_html_ to some pyAgrum classes !
1802
+ gum.BayesNet._repr_html_ = lambda self: getBN(self)
1803
+ gum.BayesNetFragment._repr_html_ = lambda self: getBN(self)
1804
+ gum.MarkovRandomField._repr_html_ = lambda self: getMRF(self)
1805
+ gum.BayesNetFragment._repr_html_ = lambda self: getBN(self)
1806
+ gum.InfluenceDiagram._repr_html_ = lambda self: getInfluenceDiagram(self)
1807
+ gum.CredalNet._repr_html_ = lambda self: getCN(self)
1808
+
1809
+ gum.CliqueGraph._repr_html_ = lambda self: getCliqueGraph(self)
1810
+
1811
+ gum.Tensor._repr_html_ = lambda self: getTensor(self)
1812
+ gum.LazyPropagation._repr_html_ = lambda self: getInferenceEngine(self, "Lazy Propagation on this BN")
1813
+
1814
+ gum.UndiGraph._repr_html_ = lambda self: getDot(self.toDot())
1815
+ gum.DiGraph._repr_html_ = lambda self: getDot(self.toDot())
1816
+ gum.MixedGraph._repr_html_ = lambda self: getDot(self.toDot())
1817
+ gum.DAG._repr_html_ = lambda self: getDot(self.toDot())
1818
+ gum.EssentialGraph._repr_html_ = lambda self: getDot(self.toDot())
1819
+ gum.MarkovBlanket._repr_html_ = lambda self: getDot(self.toDot())
1820
+
1821
+ dot.Dot._repr_html_ = lambda self: getGraph(self)