pyAgrum-nightly 2.3.0.9.dev202512061764412981__cp310-abi3-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyagrum/__init__.py +165 -0
- pyagrum/_pyagrum.so +0 -0
- pyagrum/bnmixture/BNMInference.py +268 -0
- pyagrum/bnmixture/BNMLearning.py +376 -0
- pyagrum/bnmixture/BNMixture.py +464 -0
- pyagrum/bnmixture/__init__.py +60 -0
- pyagrum/bnmixture/notebook.py +1058 -0
- pyagrum/causal/_CausalFormula.py +280 -0
- pyagrum/causal/_CausalModel.py +436 -0
- pyagrum/causal/__init__.py +81 -0
- pyagrum/causal/_causalImpact.py +356 -0
- pyagrum/causal/_dSeparation.py +598 -0
- pyagrum/causal/_doAST.py +761 -0
- pyagrum/causal/_doCalculus.py +361 -0
- pyagrum/causal/_doorCriteria.py +374 -0
- pyagrum/causal/_exceptions.py +95 -0
- pyagrum/causal/_types.py +61 -0
- pyagrum/causal/causalEffectEstimation/_CausalEffectEstimation.py +1175 -0
- pyagrum/causal/causalEffectEstimation/_IVEstimators.py +718 -0
- pyagrum/causal/causalEffectEstimation/_RCTEstimators.py +132 -0
- pyagrum/causal/causalEffectEstimation/__init__.py +46 -0
- pyagrum/causal/causalEffectEstimation/_backdoorEstimators.py +774 -0
- pyagrum/causal/causalEffectEstimation/_causalBNEstimator.py +324 -0
- pyagrum/causal/causalEffectEstimation/_frontdoorEstimators.py +396 -0
- pyagrum/causal/causalEffectEstimation/_learners.py +118 -0
- pyagrum/causal/causalEffectEstimation/_utils.py +466 -0
- pyagrum/causal/notebook.py +171 -0
- pyagrum/clg/CLG.py +658 -0
- pyagrum/clg/GaussianVariable.py +111 -0
- pyagrum/clg/SEM.py +312 -0
- pyagrum/clg/__init__.py +63 -0
- pyagrum/clg/canonicalForm.py +408 -0
- pyagrum/clg/constants.py +54 -0
- pyagrum/clg/forwardSampling.py +202 -0
- pyagrum/clg/learning.py +776 -0
- pyagrum/clg/notebook.py +480 -0
- pyagrum/clg/variableElimination.py +271 -0
- pyagrum/common.py +60 -0
- pyagrum/config.py +319 -0
- pyagrum/ctbn/CIM.py +513 -0
- pyagrum/ctbn/CTBN.py +573 -0
- pyagrum/ctbn/CTBNGenerator.py +216 -0
- pyagrum/ctbn/CTBNInference.py +459 -0
- pyagrum/ctbn/CTBNLearner.py +161 -0
- pyagrum/ctbn/SamplesStats.py +671 -0
- pyagrum/ctbn/StatsIndepTest.py +355 -0
- pyagrum/ctbn/__init__.py +79 -0
- pyagrum/ctbn/constants.py +54 -0
- pyagrum/ctbn/notebook.py +264 -0
- pyagrum/defaults.ini +199 -0
- pyagrum/deprecated.py +95 -0
- pyagrum/explain/_ComputationCausal.py +75 -0
- pyagrum/explain/_ComputationConditional.py +48 -0
- pyagrum/explain/_ComputationMarginal.py +48 -0
- pyagrum/explain/_CustomShapleyCache.py +110 -0
- pyagrum/explain/_Explainer.py +176 -0
- pyagrum/explain/_Explanation.py +70 -0
- pyagrum/explain/_FIFOCache.py +54 -0
- pyagrum/explain/_ShallCausalValues.py +204 -0
- pyagrum/explain/_ShallConditionalValues.py +155 -0
- pyagrum/explain/_ShallMarginalValues.py +155 -0
- pyagrum/explain/_ShallValues.py +296 -0
- pyagrum/explain/_ShapCausalValues.py +208 -0
- pyagrum/explain/_ShapConditionalValues.py +126 -0
- pyagrum/explain/_ShapMarginalValues.py +191 -0
- pyagrum/explain/_ShapleyValues.py +298 -0
- pyagrum/explain/__init__.py +81 -0
- pyagrum/explain/_explGeneralizedMarkovBlanket.py +152 -0
- pyagrum/explain/_explIndependenceListForPairs.py +146 -0
- pyagrum/explain/_explInformationGraph.py +264 -0
- pyagrum/explain/notebook/__init__.py +54 -0
- pyagrum/explain/notebook/_bar.py +142 -0
- pyagrum/explain/notebook/_beeswarm.py +174 -0
- pyagrum/explain/notebook/_showShapValues.py +97 -0
- pyagrum/explain/notebook/_waterfall.py +220 -0
- pyagrum/explain/shapley.py +225 -0
- pyagrum/lib/__init__.py +46 -0
- pyagrum/lib/_colors.py +390 -0
- pyagrum/lib/bn2graph.py +299 -0
- pyagrum/lib/bn2roc.py +1026 -0
- pyagrum/lib/bn2scores.py +217 -0
- pyagrum/lib/bn_vs_bn.py +605 -0
- pyagrum/lib/cn2graph.py +305 -0
- pyagrum/lib/discreteTypeProcessor.py +1102 -0
- pyagrum/lib/discretizer.py +58 -0
- pyagrum/lib/dynamicBN.py +390 -0
- pyagrum/lib/explain.py +57 -0
- pyagrum/lib/export.py +84 -0
- pyagrum/lib/id2graph.py +258 -0
- pyagrum/lib/image.py +387 -0
- pyagrum/lib/ipython.py +307 -0
- pyagrum/lib/mrf2graph.py +471 -0
- pyagrum/lib/notebook.py +1821 -0
- pyagrum/lib/proba_histogram.py +552 -0
- pyagrum/lib/utils.py +138 -0
- pyagrum/pyagrum.py +31495 -0
- pyagrum/skbn/_MBCalcul.py +242 -0
- pyagrum/skbn/__init__.py +49 -0
- pyagrum/skbn/_learningMethods.py +282 -0
- pyagrum/skbn/_utils.py +297 -0
- pyagrum/skbn/bnclassifier.py +1014 -0
- pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/LICENSE.md +12 -0
- pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/LICENSES/LGPL-3.0-or-later.txt +304 -0
- pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/LICENSES/MIT.txt +18 -0
- pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/METADATA +145 -0
- pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/RECORD +107 -0
- pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/WHEEL +4 -0
pyagrum/lib/notebook.py
ADDED
|
@@ -0,0 +1,1821 @@
|
|
|
1
|
+
############################################################################
|
|
2
|
+
# This file is part of the aGrUM/pyAgrum library. #
|
|
3
|
+
# #
|
|
4
|
+
# Copyright (c) 2005-2025 by #
|
|
5
|
+
# - Pierre-Henri WUILLEMIN(_at_LIP6) #
|
|
6
|
+
# - Christophe GONZALES(_at_AMU) #
|
|
7
|
+
# #
|
|
8
|
+
# The aGrUM/pyAgrum library is free software; you can redistribute it #
|
|
9
|
+
# and/or modify it under the terms of either : #
|
|
10
|
+
# #
|
|
11
|
+
# - the GNU Lesser General Public License as published by #
|
|
12
|
+
# the Free Software Foundation, either version 3 of the License, #
|
|
13
|
+
# or (at your option) any later version, #
|
|
14
|
+
# - the MIT license (MIT), #
|
|
15
|
+
# - or both in dual license, as here. #
|
|
16
|
+
# #
|
|
17
|
+
# (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) #
|
|
18
|
+
# #
|
|
19
|
+
# This aGrUM/pyAgrum library is distributed in the hope that it will be #
|
|
20
|
+
# useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, #
|
|
21
|
+
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS #
|
|
22
|
+
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE #
|
|
23
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER #
|
|
24
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, #
|
|
25
|
+
# ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR #
|
|
26
|
+
# OTHER DEALINGS IN THE SOFTWARE. #
|
|
27
|
+
# #
|
|
28
|
+
# See LICENCES for more details. #
|
|
29
|
+
# #
|
|
30
|
+
# SPDX-FileCopyrightText: Copyright 2005-2025 #
|
|
31
|
+
# - Pierre-Henri WUILLEMIN(_at_LIP6) #
|
|
32
|
+
# - Christophe GONZALES(_at_AMU) #
|
|
33
|
+
# SPDX-License-Identifier: LGPL-3.0-or-later OR MIT #
|
|
34
|
+
# #
|
|
35
|
+
# Contact : info_at_agrum_dot_org #
|
|
36
|
+
# homepage : http://agrum.gitlab.io #
|
|
37
|
+
# gitlab : https://gitlab.com/agrumery/agrum #
|
|
38
|
+
# #
|
|
39
|
+
############################################################################
|
|
40
|
+
|
|
41
|
+
"""
|
|
42
|
+
tools for BN in jupyter notebook
|
|
43
|
+
"""
|
|
44
|
+
|
|
45
|
+
import time
|
|
46
|
+
import warnings
|
|
47
|
+
|
|
48
|
+
# fix DeprecationWarning of base64.encodestring()
|
|
49
|
+
try:
|
|
50
|
+
from base64 import encodebytes
|
|
51
|
+
except ImportError: # 3+
|
|
52
|
+
from base64 import encodestring as encodebytes
|
|
53
|
+
|
|
54
|
+
import io
|
|
55
|
+
import base64
|
|
56
|
+
|
|
57
|
+
import matplotlib as mpl
|
|
58
|
+
import matplotlib.pyplot as plt
|
|
59
|
+
from matplotlib_inline.backend_inline import set_matplotlib_formats
|
|
60
|
+
|
|
61
|
+
import numpy as np
|
|
62
|
+
import pydot as dot
|
|
63
|
+
|
|
64
|
+
# Check once, at import time, that the graphviz `dot` executable can be run: the graph
# rendering of this module goes through pydot (create_svg/create_png), which calls it.
# A missing executable is reported as a (filterable) warning instead of a print on stdout.
try:
    dot.call_graphviz("dot", ["--help"], ".")
except FileNotFoundError:
    warnings.warn(
        "Graphviz is not installed.\n"
        "Please install this program in order to visualize graphical models in pyagrum.\n"
        "See https://graphviz.org/download/"
    )
|
|
70
|
+
|
|
71
|
+
import IPython.core.display
|
|
72
|
+
import IPython.core.pylabtools
|
|
73
|
+
import IPython.display
|
|
74
|
+
|
|
75
|
+
import pyagrum as gum
|
|
76
|
+
from pyagrum.lib.bn2graph import BN2dot
|
|
77
|
+
from pyagrum.lib.cn2graph import CN2dot
|
|
78
|
+
from pyagrum.lib.id2graph import ID2dot
|
|
79
|
+
from pyagrum.lib.mrf2graph import MRF2UGdot
|
|
80
|
+
from pyagrum.lib.mrf2graph import MRF2FactorGraphdot
|
|
81
|
+
from pyagrum.lib.bn_vs_bn import graphDiff
|
|
82
|
+
from pyagrum.lib.proba_histogram import proba2histo, probaMinMaxH
|
|
83
|
+
from pyagrum.lib.image import prepareShowInference, prepareLinksForSVG
|
|
84
|
+
|
|
85
|
+
import pyagrum.lib._colors as gumcols
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
class FlowLayout(object):
    """
    A class / object to display plots in a horizontal / flow layout below a cell.

    based on : https://stackoverflow.com/questions/21754976/ipython-notebook-arrange-plots-horizontally

    Elements (HTML fragments or matplotlib plots) are accumulated in an HTML buffer
    (`self.sHtml`), each one wrapped in a ``div.floating-box``. The buffer is rendered
    by :meth:`display`, which then clears it.
    """

    def __init__(self):
        self.clear()

    def clear(self):
        """
        clear the flow: reset the HTML buffer to the CSS style block only

        Returns
        -------
        FlowLayout
          self (for chaining)
        """
        # string buffer for the HTML: initially some CSS; elements are appended to it.
        # Border and background are read from the ["notebook", "flow_*"] entries of pyagrum.config.
        self.sHtml = f"""
<style>
.floating-box {{
display: inline-block;
margin: 7px;
padding : 3px;
border: {gum.config.asInt["notebook", "flow_border_width"]}px solid {gum.config["notebook", "flow_border_color"]};
vertical-align: middle;
background-color: {gum.config["notebook", "flow_background_color"]};
}}
</style>
"""
        return self

    def _getCaption(self, caption):
        """HTML fragment for a caption ("" when there is no caption)."""
        if caption == "":
            return ""
        return f"<br><center><small><em>{caption}</em></small></center>"

    @staticmethod
    def _resolveCaption(caption, title):
        """
        Return the caption to use, honouring the obsolete `title` parameter.

        `caption` wins when given; otherwise `title` is used (with a deprecation warning);
        otherwise the caption is empty.
        """
        if caption is not None:
            return caption
        if title is None:
            return ""
        # stacklevel=3: point at the user code calling add/add_html/add_plot
        warnings.warn("`title` is obsolete since `0.22.8`. Please use `caption`.", DeprecationWarning, stacklevel=3)
        return title

    def add_html(self, html, caption=None, title=None):
        """
        add an html element in the row (title is an obsolete parameter)

        Parameters
        ----------
        html
          the HTML fragment (converted with str() when inserted)
        caption: str
          optional caption displayed under the element
        title: str
          obsolete alias of `caption`

        Returns
        -------
        FlowLayout
          self (for chaining)
        """
        cap = self._resolveCaption(caption, title)
        self.sHtml += f'<div class="floating-box">{html}{self._getCaption(cap)}</div>'
        return self

    def add_separator(self, size=3):
        """
        add a (poor) separation between elements in a row

        Parameters
        ----------
        size: int
          number of non-breaking spaces in the separator
        """
        # plain spaces would be collapsed by the HTML renderer
        self.add_html("&nbsp;" * size)
        return self

    def add_plot(self, oAxes, caption=None, title=None):
        """
        Add a PNG representation of a Matplotlib Axes object
        (title is an obsolete parameter)

        The figure of `oAxes` is rendered then closed.

        Returns
        -------
        FlowLayout
          self (for chaining)
        """
        cap = self._resolveCaption(caption, title)

        fig = oAxes.get_figure()
        Bio = io.BytesIO()  # bytes buffer for the plot
        fig.canvas.print_png(Bio)  # make a png of the plot in the buffer

        # encode the bytes as string using base 64
        sB64Img = base64.b64encode(Bio.getvalue()).decode()
        self.sHtml += (
            f'<div class="floating-box"><img src="data:image/png;base64,{sB64Img}\n">{self._getCaption(cap)}</div>'
        )
        # close the rendered figure itself: plt.close() would close the *current* figure,
        # which is not necessarily the one of oAxes
        plt.close(fig)
        return self

    def new_line(self):
        """
        add a breakline (a new row)
        """
        self.sHtml += "<br/>"
        return self

    def html(self):
        """
        Returns its content as HTML object
        """
        return IPython.display.HTML(self.sHtml)

    def display(self):
        """
        Display the accumulated HTML, then clear the flow
        """
        IPython.display.display(self.html())
        self.clear()

    def add(self, obj, caption=None, title=None):
        """
        add an element in the row by trying to treat it as plot or html if possible.
        (title is an obsolete parameter)

        Objects with a `get_figure` method are added as plots, objects with a `_repr_html_`
        method through their HTML representation, anything else as raw HTML.
        """
        cap = self._resolveCaption(caption, title)

        if hasattr(obj, "get_figure"):
            self.add_plot(obj, cap)
        elif hasattr(obj, "_repr_html_"):
            self.add_html(obj._repr_html_(), cap)
        else:
            self.add_html(obj, cap)

        return self

    def row(self, *args, captions=None):
        """
        Create a row with flow with the same syntax as `pyagrum.lib.notebook.sideBySide`.

        The flow is cleared, every element of `args` is added (with `captions[i]` if
        `captions` is given) and the row is displayed.
        """
        self.clear()
        for i, arg in enumerate(args):
            if captions is None:
                self.add(arg)
            else:
                self.add(arg, captions[i])

        self.display()
|
|
226
|
+
|
|
227
|
+
|
|
228
|
+
# module-level FlowLayout instance, ready to use as `pyagrum.lib.notebook.flow`
flow = FlowLayout()
|
|
229
|
+
|
|
230
|
+
|
|
231
|
+
def configuration():
    """
    Display (as an HTML table) the OS and the versions of Python and of the libraries
    used by pyagrum.lib.notebook, followed by the current date and time.
    """
    import os
    import sys

    versions = [
        ("OS", f"{os.name} [{sys.platform}]"),
        ("Python", sys.version),
        ("IPython", IPython.__version__),
        ("Matplotlib", mpl.__version__),
        ("Numpy", np.__version__),
        ("pyDot", dot.__version__),
        ("pyAgrum", gum.__version__),
    ]

    body = "".join(f"<tr><td>{library}</td><td>{version}</td></tr>" for library, version in versions)
    stamp = time.strftime("%a %b %d %H:%M:%S %Y %Z")
    table = (
        "<table><tr><th>Library</th><th>Version</th></tr>"
        + body
        + f"</table><div align='right'><small>{stamp}</small></div>"
    )

    IPython.display.display(IPython.display.HTML(table))
|
|
256
|
+
|
|
257
|
+
|
|
258
|
+
def _reprGraph(gr, size, asString, graph_format=None):
    """
    Render a pydot graph for a notebook, either displaying it or returning its HTML.

    Parameters
    ----------
    gr : dot.Dot
      the dot representation of the graph (prepared in place with the given size)
    size: int | str
      the size argument for the representation
    asString : bool
      if True, return the HTML fragment; otherwise display the graph
    graph_format: str
      "svg" or anything else (treated as "png"); defaults to pyagrum.config["notebook", "graph_format"]

    Returns
    -------
    str | None
      the HTML representation (svg markup or <img> with inline PNG) if `asString`, else None
    """
    gumcols.prepareDot(gr, size=size)

    chosen_format = gum.config["notebook", "graph_format"] if graph_format is None else graph_format

    if chosen_format != "svg":
        png_image = IPython.core.display.Image(format="png", data=gr.create_png())
        if not asString:
            IPython.core.display.display_png(png_image)
            return None
        return f'<img style="margin:0" src="data:image/png;base64,{encodebytes(png_image.data).decode()}"/>'

    svg_markup = prepareLinksForSVG(gr.create_svg(encoding="utf-8").decode("utf-8"))
    svg_image = IPython.display.SVG(svg_markup)
    if asString:
        return svg_image.data
    IPython.display.display(svg_image)
    return None
|
|
295
|
+
|
|
296
|
+
|
|
297
|
+
def showGraph(gr: dot.Dot, size=None):
    """
    Display a pydot graph in a notebook.

    Parameters
    ----------
    gr: pydot.Dot
      the graph
    size: float|str
      the size (for graphviz) of the rendered graph; defaults to pyagrum.config["notebook", "default_graph_size"]
    """
    effective_size = gum.config["notebook", "default_graph_size"] if size is None else size
    return _reprGraph(gr, effective_size, asString=False)
|
|
312
|
+
|
|
313
|
+
|
|
314
|
+
def getGraph(gr: dot.Dot, size=None) -> str:
    """
    Build the HTML representation of a pydot graph.

    Parameters
    ----------
    gr: pydot.Dot
      the graph
    size: float|str
      the size (for graphviz) of the rendered graph; defaults to pyagrum.config["notebook", "default_graph_size"]

    Returns
    -------
    str
      the HTML representation of the graph
    """
    effective_size = gum.config["notebook", "default_graph_size"] if size is None else size
    return _reprGraph(gr, effective_size, asString=True)
|
|
334
|
+
|
|
335
|
+
|
|
336
|
+
def _from_dotstring(dotstring):
    """
    Parse a dot string with pydot and return the first graph it defines.

    Parameters
    ----------
    dotstring: str
      the dot string

    Returns
    -------
    pydot.Dot
      the first graph found in `dotstring`

    Raises
    ------
    ValueError
      if pydot could not parse `dotstring` or found no graph in it
    """
    graphs = dot.graph_from_dot_data(dotstring)
    # pydot may return None (parse failure) or an empty list: indexing it would raise
    # an unhelpful TypeError/IndexError
    if not graphs:
        raise ValueError("pyagrum.lib.notebook: unable to parse the dot string (no graph found).")
    return graphs[0]
|
|
339
|
+
|
|
340
|
+
|
|
341
|
+
def showDot(dotstring: str, size=None):
    """
    Parse a dot string and display the resulting graph.

    Parameters
    ----------
    dotstring:str
      the dot string
    size: float|str
      size (for graphviz) of the rendered graph; defaults to pyagrum.config["notebook", "default_graph_size"]
    """
    effective_size = size if size is not None else gum.config["notebook", "default_graph_size"]
    parsed = _from_dotstring(dotstring)
    showGraph(parsed, effective_size)
|
|
355
|
+
|
|
356
|
+
|
|
357
|
+
def getDot(dotstring: str, size=None) -> str:
    """
    Parse a dot string and return the HTML representation of the resulting graph.

    Parameters
    ----------
    dotstring:str
      the dot string
    size: float|str
      size (for graphviz) of the rendered graph; defaults to pyagrum.config["notebook", "default_graph_size"]

    Returns
    -------
    str
      the HTML representation of the dot string
    """
    effective_size = size if size is not None else gum.config["notebook", "default_graph_size"]
    parsed = _from_dotstring(dotstring)
    return getGraph(parsed, effective_size)
|
|
376
|
+
|
|
377
|
+
|
|
378
|
+
def getBNDiff(bn1, bn2, size=None, noStyle=False):
    """
    Get an HTML string showing a graphical diff between the arcs of `bn1` (reference) and those of `bn2`.

    If `noStyle` is False, 4 styles (fixed in pyagrum.config) are used:
      - the arc is common for both
      - the arc is common but inverted in `bn2`
      - the arc is added in `bn2`
      - the arc is removed in `bn2`

    Parameters
    ----------
    bn1: pyagrum.BayesNet
      the reference
    bn2: pyagrum.BayesNet
      the compared one
    size: float|str
      size (for graphviz) of the rendered graph; defaults to pyagrum.config["notebook", "default_graph_size"]
    noStyle: bool
      with style or not.

    Returns
    -------
    str
      the HTML representation of the comparison
    """
    effective_size = gum.config["notebook", "default_graph_size"] if size is None else size
    diff = graphDiff(bn1, bn2, noStyle)
    return getGraph(diff, effective_size)
|
|
408
|
+
|
|
409
|
+
|
|
410
|
+
def showBNDiff(bn1, bn2, size=None, noStyle=False):
    """
    Display a graphical diff between the arcs of `bn1` (reference) and those of `bn2`.

    If `noStyle` is False, 4 styles (fixed in pyagrum.config) are used:
      - the arc is common for both
      - the arc is common but inverted in `bn2`
      - the arc is added in `bn2`
      - the arc is removed in `bn2`

    Parameters
    ----------
    bn1: pyagrum.BayesNet
      the reference
    bn2: pyagrum.BayesNet
      the compared one
    size: float|str
      size (for graphviz) of the rendered graph; defaults to pyagrum.config["notebook", "default_graph_size"]
    noStyle: bool
      with style or not.
    """
    effective_size = gum.config["notebook", "default_graph_size"] if size is None else size
    diff = graphDiff(bn1, bn2, noStyle)
    showGraph(diff, effective_size)
|
|
435
|
+
|
|
436
|
+
|
|
437
|
+
def getJunctionTreeMap(
    bn,
    size: str = None,
    scaleClique: float = None,
    scaleSep: float = None,
    lenEdge: float = None,
    colorClique: str = None,
    colorSep: str = None,
):
    """
    Return an HTML representation of the map of the junction tree of a Bayesian network.

    Parameters
    ----------
    bn: pyagrum.BayesNet
      the model
    size: str
      size (for graphviz) of the rendered graph; defaults to pyagrum.config["notebook", "junctiontree_map_size"]
    scaleClique: float
      the scale for the size of the clique nodes (depending on the number of nodes in the clique)
    scaleSep: float
      the scale for the size of the separator nodes (depending on the number of nodes in the clique)
    lenEdge: float
      the desired length of edges
    colorClique: str
      color for the clique nodes
    colorSep: str
      color for the separator nodes

    The last five parameters are passed unchanged to the `map` method of the junction tree.

    Returns
    -------
    str
      the HTML representation of the map
    """
    generator = gum.JunctionTreeGenerator()
    junction_tree = generator.junctionTree(bn)

    effective_size = gum.config["notebook", "junctiontree_map_size"] if size is None else size
    tree_map = junction_tree.map(scaleClique, scaleSep, lenEdge, colorClique, colorSep)
    return getGraph(tree_map, effective_size)
|
|
470
|
+
|
|
471
|
+
|
|
472
|
+
def showJunctionTreeMap(
    bn,
    size: str = None,
    scaleClique: float = None,
    scaleSep: float = None,
    lenEdge: float = None,
    colorClique: str = None,
    colorSep: str = None,
):
    """
    Display the map of the junction tree of a Bayesian network.

    Parameters
    ----------
    bn: pyagrum.BayesNet
      the model
    size: str
      size (for graphviz) of the rendered graph; defaults to pyagrum.config["notebook", "junctiontree_map_size"]
    scaleClique: float
      the scale for the size of the clique nodes (depending on the number of nodes in the clique)
    scaleSep: float
      the scale for the size of the separator nodes (depending on the number of nodes in the clique)
    lenEdge: float
      the desired length of edges
    colorClique: str
      color for the clique nodes
    colorSep: str
      color for the separator nodes

    The last five parameters are passed unchanged to the `map` method of the junction tree.
    """
    generator = gum.JunctionTreeGenerator()
    junction_tree = generator.junctionTree(bn)

    effective_size = gum.config["notebook", "junctiontree_map_size"] if size is None else size
    tree_map = junction_tree.map(scaleClique, scaleSep, lenEdge, colorClique, colorSep)
    showGraph(tree_map, effective_size)
|
|
505
|
+
|
|
506
|
+
|
|
507
|
+
def showJunctionTree(bn, withNames=True, size=None):
    """
    Show a junction tree of a Bayesian network

    Parameters
    ----------
    bn: pyagrum.BayesNet
      the model
    withNames: bool
      names or ids in the cliques (names can create very large nodes)
    size: float|str
      size (for graphviz) of the rendered graph; defaults to
      pyagrum.config["notebook", "junctiontree_graph_size"] (same default as getJunctionTree)
    """
    if size is None:
        # same default as getJunctionTree, so that showing and getting a junction tree agree
        size = gum.config["notebook", "junctiontree_graph_size"]

    jtg = gum.JunctionTreeGenerator()
    jt = jtg.junctionTree(bn)

    # NOTE(review): cross-references between the junction tree, its generator and the model,
    # presumably to keep these objects alive while jt is used — confirm
    jt._engine = jtg
    jtg._model = bn

    if withNames:
        showDot(jt.toDotWithNames(bn), size)
    else:
        showDot(jt.toDot(), size)
|
|
533
|
+
|
|
534
|
+
|
|
535
|
+
def getJunctionTree(bn, withNames=True, size=None):
    """
    Get an HTML string for a junction tree (more specifically a join tree).

    Parameters
    ----------
    bn: "pyagrum.BayesNet"
      the Bayesian network
    withNames: Boolean
      display the variable names or the node ids in the cliques
    size: str
      size (for graphviz) of the rendered graph; defaults to pyagrum.config["notebook", "junctiontree_graph_size"]

    Returns
    -------
    str
      the HTML representation of the graph
    """
    effective_size = gum.config["notebook", "junctiontree_graph_size"] if size is None else size

    generator = gum.JunctionTreeGenerator()
    junction_tree = generator.junctionTree(bn)

    # NOTE(review): cross-references presumably keeping the generator and the model alive
    # alongside the junction tree — confirm
    junction_tree._engine = generator
    generator._model = bn

    dotstring = junction_tree.toDotWithNames(bn) if withNames else junction_tree.toDot()
    return getDot(dotstring, effective_size)
|
|
565
|
+
|
|
566
|
+
|
|
567
|
+
def showProba(p, scale=None):
    """
    Display a mono-dimensional Tensor (a marginal) as a histogram.

    Parameters
    ----------
    p: pyagrum.Tensor
      the marginal to show
    scale: float
      the zoom factor
    """
    proba2histo(p, scale)
    chosen_format = gum.config["notebook", "graph_format"]
    set_matplotlib_formats(chosen_format)
    plt.show()
|
|
581
|
+
|
|
582
|
+
|
|
583
|
+
def _getMatplotFig(fig):
    """
    Render a matplotlib figure as an HTML <img> holding an inline base64 PNG, then close it.

    Parameters
    ----------
    fig: matplotlib.figure.Figure
      the figure to render

    Returns
    -------
    str
      the HTML <img> tag
    """
    try:
        with io.BytesIO() as bio:  # bytes buffer for the plot
            fig.savefig(bio, format="png", bbox_inches="tight")
            # encode the bytes as string using base 64
            sB64Img = base64.b64encode(bio.getvalue()).decode()
    finally:
        # close this very figure: plt.close() would close the *current* figure, which may be
        # an unrelated one (e.g. when the caller already closed `fig`)
        plt.close(fig)
    return f'<img src="data:image/png;base64,{sB64Img}\n">'
|
|
593
|
+
|
|
594
|
+
|
|
595
|
+
def getProba(p, scale=None) -> str:
    """
    get a mono-dim Tensor as an html image (inline PNG)

    Parameters
    ----------
    p: pyagrum.Tensor
      the marginal to show
    scale: float
      the zoom factor

    Returns
    -------
    str
      the HTML representation of the marginal
    """
    # only affects how matplotlib figures are displayed inline; the returned image is a PNG
    set_matplotlib_formats(gum.config["notebook", "graph_format"])
    fig = proba2histo(p, scale)
    try:
        return _getMatplotFig(fig)
    finally:
        # close exactly this figure (after rendering) so it is neither leaked nor displayed again,
        # without touching any other open figure
        plt.close(fig)
|
|
616
|
+
|
|
617
|
+
|
|
618
|
+
def showProbaMinMax(pmin, pmax, scale=1.0):
    """
    Display a bi-Tensor (min, max) as a histogram.

    Parameters
    ----------
    pmin: pyagrum.Tensor
      the min marginal to show
    pmax: pyagrum.Tensor
      the max marginal to show
    scale: float
      the zoom factor
    """
    probaMinMaxH(pmin, pmax, scale)
    chosen_format = gum.config["notebook", "graph_format"]
    set_matplotlib_formats(chosen_format)
    plt.show()
|
|
634
|
+
|
|
635
|
+
|
|
636
|
+
def getProbaMinMax(pmin, pmax, scale=1.0) -> str:
    """
    get a bi-Tensor (min,max) as an html image (inline PNG)

    Parameters
    ----------
    pmin: pyagrum.Tensor
      the min marginal to show
    pmax: pyagrum.Tensor
      the max marginal to show
    scale: float
      the zoom factor

    Returns
    -------
    str
      the HTML representation of the marginal min,max
    """
    # only affects how matplotlib figures are displayed inline; the returned image is a PNG
    set_matplotlib_formats(gum.config["notebook", "graph_format"])
    fig = probaMinMaxH(pmin, pmax, scale)
    try:
        return _getMatplotFig(fig)
    finally:
        # close exactly the rendered figure, whatever the current pyplot figure is
        plt.close(fig)
|
|
656
|
+
|
|
657
|
+
|
|
658
|
+
def getPosterior(bn, evs, target):
    """
    get the posterior of `target` given `evs` as an html image (inline PNG) of
    proba2histo(gum.getPosterior(bn, evs=evs, target=target))

    Parameters
    ----------
    bn: "pyagrum.BayesNet"
      the BayesNet
    evs: Dict[str|int:int|str|List[float]]
      map of evidence
    target: str|int
      name of target variable

    Returns
    -------
    str
      the HTML representation of the posterior
    """
    fig = proba2histo(gum.getPosterior(bn, evs=evs, target=target))
    try:
        return _getMatplotFig(fig)
    finally:
        # close exactly this figure (after rendering) without touching any other open figure
        plt.close(fig)
|
|
678
|
+
|
|
679
|
+
|
|
680
|
+
def showPosterior(bn, evs, target):
    """
    Display the posterior of `target` given `evs`:
    shortcut for showProba(gum.getPosterior(bn, evs=evs, target=target))

    Parameters
    ----------
    bn: "pyagrum.BayesNet"
      the BayesNet
    evs: Dict[str|int:int|str|List[float]]
      map of evidence
    target: str|int
      name of target variable
    """
    posterior = gum.getPosterior(bn, evs=evs, target=target)
    showProba(posterior)
|
|
694
|
+
|
|
695
|
+
|
|
696
|
+
def animApproximationScheme(apsc, scale=np.log10):
    """
    show an animated version of an approximation algorithm

    A listener is attached to `apsc` (and its verbosity is switched on). While the algorithm
    runs, at each progress notification the history is re-plotted in the current matplotlib
    figure and that figure is re-displayed. When the algorithm stops, the cell output is
    cleared and the figure title is set to the stop argument followed by time, iterations
    and epsilon.

    Parameters
    ----------
    apsc
      the approximation algorithm
    scale
      a function applied to the history before plotting (default: np.log10)
    """
    # the figure that is updated and re-displayed by the callbacks
    f = plt.gcf()

    h = gum.PythonApproximationListener(apsc._asIApproximationSchemeConfiguration())
    # NOTE(review): history() is presumably only recorded in verbose mode — confirm
    apsc.setVerbosity(True)
    # NOTE(review): presumably keeps a Python reference to the listener as long as apsc lives
    apsc.listener = h

    def stopper(x):
        # called when the algorithm stops; x is written in the title (presumably the stop message)
        IPython.display.clear_output(True)
        plt.title(f"{x} \n Time : {apsc.currentTime()}s | Iterations : {apsc.nbrIterations()} | Epsilon : {apsc.epsilon()}")

    def progresser(x, y, z):
        # called at each progress notification; its arguments are not used here
        # the x axis spans at least [1, 10], then grows with the history
        if len(apsc.history()) < 10:
            plt.xlim(1, 10)
        else:
            plt.xlim(1, len(apsc.history()))
        # NOTE(review): each call adds a new line (the whole history) to the same axes without
        # removing the previous ones — cheap for short runs, increasingly costly for long ones
        plt.plot(scale(apsc.history()), "g")
        IPython.display.clear_output(True)
        IPython.display.display(f)

    h.setWhenStop(stopper)
    h.setWhenProgress(progresser)
|
|
728
|
+
|
|
729
|
+
|
|
730
|
+
def showApproximationScheme(apsc, scale=np.log10):
    """
    Plot the current state (history) of an approximation algorithm.

    Nothing is drawn when the algorithm is not verbose.

    Parameters
    ----------
    apsc
      the approximation algorithm
    scale
      a function applied to the history before plotting (default: np.log10)
    """
    if not apsc.verbosity():
        return

    history = apsc.history()
    # the x axis spans at least [1, 10], then grows with the history
    plt.xlim(1, max(10, len(history)))
    plt.title(f"Time : {apsc.currentTime()}s | Iterations : {apsc.nbrIterations()} | Epsilon : {apsc.epsilon()}")
    plt.plot(scale(history), "g")
|
|
748
|
+
|
|
749
|
+
|
|
750
|
+
def showMRF(
|
|
751
|
+
mrf,
|
|
752
|
+
view=None,
|
|
753
|
+
size=None,
|
|
754
|
+
nodeColor=None,
|
|
755
|
+
factorColor=None,
|
|
756
|
+
edgeWidth=None,
|
|
757
|
+
edgeColor=None,
|
|
758
|
+
cmapNode=None,
|
|
759
|
+
cmapEdge=None,
|
|
760
|
+
):
|
|
761
|
+
"""
|
|
762
|
+
show a Markov random field
|
|
763
|
+
|
|
764
|
+
Parameters
|
|
765
|
+
----------
|
|
766
|
+
mrf : "pyagrum.MarkovRandomField"
|
|
767
|
+
the Markov random field
|
|
768
|
+
view : str
|
|
769
|
+
'graph' | 'factorgraph’ | None (default)
|
|
770
|
+
size : str
|
|
771
|
+
size (for graphviz) of the rendered graph
|
|
772
|
+
nodeColor: Dict[int,float]
|
|
773
|
+
a nodeMap of values (between 0 and 1) to be shown as color of nodes (with special colors for 0 and 1)
|
|
774
|
+
factorColor: Dict[int,float]
|
|
775
|
+
a function returning a value (between 0 and 1) to be shown as a color of factor. (used when view='factorgraph')
|
|
776
|
+
edgeWidth: Dict[Tuple[int,int],float]
|
|
777
|
+
a edgeMap of values to be shown as width of edges (used when view='graph')
|
|
778
|
+
edgeColor: Dict[int,float]
|
|
779
|
+
a edgeMap of values (between 0 and 1) to be shown as color of edges (used when view='graph')
|
|
780
|
+
cmapNode: Dict[Tuple[int,int],float]
|
|
781
|
+
color map to show the colors (if cmapEdge is None, this color map is used also for edges)
|
|
782
|
+
cmapEdge: "matplotlib.ColorMap"
|
|
783
|
+
color map to show the edge color if distinction is needed
|
|
784
|
+
|
|
785
|
+
Returns
|
|
786
|
+
------
|
|
787
|
+
the graph
|
|
788
|
+
"""
|
|
789
|
+
if view is None:
|
|
790
|
+
view = gum.config["notebook", "default_markovrandomfield_view"]
|
|
791
|
+
|
|
792
|
+
if size is None:
|
|
793
|
+
size = gum.config["notebook", "default_graph_size"]
|
|
794
|
+
|
|
795
|
+
if cmapEdge is None:
|
|
796
|
+
cmapEdge = cmapNode
|
|
797
|
+
|
|
798
|
+
if view == "graph":
|
|
799
|
+
dottxt = MRF2UGdot(
|
|
800
|
+
mrf, size, nodeColor=nodeColor, edgeWidth=edgeWidth, edgeColor=edgeColor, cmapNode=cmapNode, cmapEdge=cmapEdge
|
|
801
|
+
)
|
|
802
|
+
else:
|
|
803
|
+
dottxt = MRF2FactorGraphdot(mrf, size, nodeColor=nodeColor, factorColor=factorColor, cmapNode=cmapNode)
|
|
804
|
+
|
|
805
|
+
return showGraph(dottxt, size)
|
|
806
|
+
|
|
807
|
+
|
|
808
|
+
def showInfluenceDiagram(diag, size=None):
|
|
809
|
+
"""
|
|
810
|
+
show an influence diagram as a graph
|
|
811
|
+
|
|
812
|
+
Parameters
|
|
813
|
+
----------
|
|
814
|
+
diag : "pyagrum.InfluenceDiagram"
|
|
815
|
+
the influence diagram
|
|
816
|
+
size : str
|
|
817
|
+
size (for graphviz) of the rendered graph
|
|
818
|
+
|
|
819
|
+
Returns
|
|
820
|
+
-------
|
|
821
|
+
the representation of the influence diagram
|
|
822
|
+
"""
|
|
823
|
+
if size is None:
|
|
824
|
+
size = gum.config["influenceDiagram", "default_id_size"]
|
|
825
|
+
|
|
826
|
+
return showGraph(ID2dot(diag), size)
|
|
827
|
+
|
|
828
|
+
|
|
829
|
+
def getInfluenceDiagram(diag, size=None):
|
|
830
|
+
"""
|
|
831
|
+
get a HTML string for an influence diagram as a graph
|
|
832
|
+
|
|
833
|
+
Parameters
|
|
834
|
+
----------
|
|
835
|
+
diag : "pyagrum.InfluenceDiagram"
|
|
836
|
+
the influence diagram
|
|
837
|
+
size : str
|
|
838
|
+
size (for graphviz) of the rendered graph
|
|
839
|
+
|
|
840
|
+
Returns
|
|
841
|
+
-------
|
|
842
|
+
str
|
|
843
|
+
the HTML representation of the influence diagram
|
|
844
|
+
"""
|
|
845
|
+
if size is None:
|
|
846
|
+
size = gum.config["influenceDiagram", "default_id_size"]
|
|
847
|
+
|
|
848
|
+
return getGraph(ID2dot(diag), size)
|
|
849
|
+
|
|
850
|
+
|
|
851
|
+
def showBN(bn, size=None, nodeColor=None, arcWidth=None, arcLabel=None, arcColor=None, cmapNode=None, cmapArc=None):
|
|
852
|
+
"""
|
|
853
|
+
show a Bayesian network
|
|
854
|
+
|
|
855
|
+
Parameters
|
|
856
|
+
----------
|
|
857
|
+
bn : pyagrum.BayesNet
|
|
858
|
+
the Bayesian network
|
|
859
|
+
size: str
|
|
860
|
+
size (for graphviz) of the rendered graph
|
|
861
|
+
nodeColor: dict[Tuple(int,int),float]
|
|
862
|
+
a nodeMap of values to be shown as color nodes (with special color for 0 and 1)
|
|
863
|
+
arcWidth: dict[Tuple(int,int),float]
|
|
864
|
+
an arcMap of values to be shown as bold arcs
|
|
865
|
+
arcLabel: dict[Tuple(int,int),str]
|
|
866
|
+
an arcMap of labels to be shown next to arcs
|
|
867
|
+
arcColor: dict[Tuple(int,int),float]
|
|
868
|
+
an arcMap of values (between 0 and 1) to be shown as color of arcs
|
|
869
|
+
cmapNode: ColorMap
|
|
870
|
+
color map to show the vals of Nodes ( (if cmapEdge is None, this color map is used also for edges)
|
|
871
|
+
cmapArc: ColorMap
|
|
872
|
+
color map to show the vals of Arcs
|
|
873
|
+
showMsg: dict
|
|
874
|
+
a nodeMap of values to be shown as tooltip
|
|
875
|
+
"""
|
|
876
|
+
if size is None:
|
|
877
|
+
size = gum.config["notebook", "default_graph_size"]
|
|
878
|
+
|
|
879
|
+
if cmapArc is None:
|
|
880
|
+
cmapArc = cmapNode
|
|
881
|
+
|
|
882
|
+
return showGraph(
|
|
883
|
+
BN2dot(
|
|
884
|
+
bn,
|
|
885
|
+
size=size,
|
|
886
|
+
nodeColor=nodeColor,
|
|
887
|
+
arcWidth=arcWidth,
|
|
888
|
+
arcLabel=arcLabel,
|
|
889
|
+
arcColor=arcColor,
|
|
890
|
+
cmapNode=cmapNode,
|
|
891
|
+
cmapArc=cmapArc,
|
|
892
|
+
),
|
|
893
|
+
size,
|
|
894
|
+
)
|
|
895
|
+
|
|
896
|
+
|
|
897
|
+
def showCN(cn, size=None, nodeColor=None, arcWidth=None, arcLabel=None, arcColor=None, cmapNode=None, cmapArc=None):
|
|
898
|
+
"""
|
|
899
|
+
show a credal network
|
|
900
|
+
|
|
901
|
+
Parameters
|
|
902
|
+
----------
|
|
903
|
+
cn : pyagrum.CredalNet
|
|
904
|
+
the Credal network
|
|
905
|
+
size: str
|
|
906
|
+
size (for graphviz) of the rendered graph
|
|
907
|
+
nodeColor: dict[int,float]
|
|
908
|
+
a nodeMap of values to be shown as color nodes (with special color for 0 and 1)
|
|
909
|
+
arcWidth: dict[Tuple(int,int),float]
|
|
910
|
+
an arcMap of values to be shown as bold arcs
|
|
911
|
+
arcLabel: dict[Tuple(int,int),float]
|
|
912
|
+
an arcMap of labels to be shown next to arcs
|
|
913
|
+
arcColor: dict[Tuple(int,int),float]
|
|
914
|
+
an arcMap of values (between 0 and 1) to be shown as color of arcs
|
|
915
|
+
cmapNode: matplotlib.color.colormap
|
|
916
|
+
color map to show the vals of Nodes
|
|
917
|
+
cmapArc: matplotlib.color.colormap
|
|
918
|
+
color map to show the vals of Arcs
|
|
919
|
+
showMsg : dict[int,str]
|
|
920
|
+
a nodeMap of values to be shown as tooltip
|
|
921
|
+
|
|
922
|
+
Returns
|
|
923
|
+
-------
|
|
924
|
+
the graph
|
|
925
|
+
"""
|
|
926
|
+
if size is None:
|
|
927
|
+
size = gum.config["notebook", "default_graph_size"]
|
|
928
|
+
|
|
929
|
+
if cmapArc is None:
|
|
930
|
+
cmapArc = cmapNode
|
|
931
|
+
|
|
932
|
+
return showGraph(
|
|
933
|
+
CN2dot(
|
|
934
|
+
cn,
|
|
935
|
+
size=size,
|
|
936
|
+
nodeColor=nodeColor,
|
|
937
|
+
arcWidth=arcWidth,
|
|
938
|
+
arcLabel=arcLabel,
|
|
939
|
+
arcColor=arcColor,
|
|
940
|
+
cmapNode=cmapNode,
|
|
941
|
+
cmapArc=cmapArc,
|
|
942
|
+
),
|
|
943
|
+
size,
|
|
944
|
+
)
|
|
945
|
+
|
|
946
|
+
|
|
947
|
+
def getMRF(
|
|
948
|
+
mrf,
|
|
949
|
+
view=None,
|
|
950
|
+
size=None,
|
|
951
|
+
nodeColor=None,
|
|
952
|
+
factorColor=None,
|
|
953
|
+
edgeWidth=None,
|
|
954
|
+
edgeColor=None,
|
|
955
|
+
cmapNode=None,
|
|
956
|
+
cmapEdge=None,
|
|
957
|
+
):
|
|
958
|
+
"""
|
|
959
|
+
get an HTML string for a Markov random field
|
|
960
|
+
|
|
961
|
+
Parameters
|
|
962
|
+
----------
|
|
963
|
+
mrf : "pyagrum.MarkovRandomField"
|
|
964
|
+
the Markov random field
|
|
965
|
+
view: str
|
|
966
|
+
'graph' | 'factorgraph’ | None (default)
|
|
967
|
+
size: str
|
|
968
|
+
size (for graphviz) of the rendered graph
|
|
969
|
+
nodeColor: Dict[str,float]
|
|
970
|
+
a nodeMap of values (between 0 and 1) to be shown as color of nodes (with special colors for 0 and 1)
|
|
971
|
+
factorColor: Dict[str,float]
|
|
972
|
+
a function returning a value (beeween 0 and 1) to be shown as a color of factor. (used when view='factorgraph')
|
|
973
|
+
edgeWidth: Dict[Tuple[str,str],float]
|
|
974
|
+
a edgeMap of values to be shown as width of edges (used when view='graph')
|
|
975
|
+
edgeColor: Dict[Tuple[str,str],float]
|
|
976
|
+
a edgeMap of values (between 0 and 1) to be shown as color of edges (used when view='graph')
|
|
977
|
+
cmapNode: matplotlib.ColorMap
|
|
978
|
+
color map to show the colors (if cmapEdge is None, cmapNode is used for edges)
|
|
979
|
+
cmapEdge: matplotlib.ColorMap
|
|
980
|
+
color map to show the edge color if distinction is needed
|
|
981
|
+
|
|
982
|
+
Returns
|
|
983
|
+
-------
|
|
984
|
+
the graph
|
|
985
|
+
"""
|
|
986
|
+
if size is None:
|
|
987
|
+
size = gum.config["notebook", "default_graph_size"]
|
|
988
|
+
|
|
989
|
+
if cmapEdge is None:
|
|
990
|
+
cmapEdge = cmapNode
|
|
991
|
+
|
|
992
|
+
if view is None:
|
|
993
|
+
view = gum.config["notebook", "default_markovrandomfield_view"]
|
|
994
|
+
|
|
995
|
+
if view == "graph":
|
|
996
|
+
dottxt = MRF2UGdot(mrf, size, nodeColor, edgeWidth, edgeColor, cmapNode, cmapEdge)
|
|
997
|
+
else:
|
|
998
|
+
dottxt = MRF2FactorGraphdot(mrf, size, nodeColor, factorColor, cmapNode=cmapNode)
|
|
999
|
+
|
|
1000
|
+
return getGraph(dottxt, size)
|
|
1001
|
+
|
|
1002
|
+
|
|
1003
|
+
def getBN(bn, size=None, nodeColor=None, arcWidth=None, arcLabel=None, arcColor=None, cmapNode=None, cmapArc=None):
|
|
1004
|
+
"""
|
|
1005
|
+
get a HTML string for a Bayesian network
|
|
1006
|
+
|
|
1007
|
+
Parameters
|
|
1008
|
+
----------
|
|
1009
|
+
bn : pyagrum.BayesNet
|
|
1010
|
+
the Bayesian network
|
|
1011
|
+
size: str
|
|
1012
|
+
size (for graphviz) of the rendered graph
|
|
1013
|
+
nodeColor: dict[Tuple(int,int),float]
|
|
1014
|
+
a nodeMap of values to be shown as color nodes (with special color for 0 and 1)
|
|
1015
|
+
arcWidth: dict[Tuple(int,int),float]
|
|
1016
|
+
an arcMap of values to be shown as bold arcs
|
|
1017
|
+
arcLabel: dict[Tuple(int,int),str]
|
|
1018
|
+
an arcMap of labels to be shown next to arcs
|
|
1019
|
+
arcColor: dict[Tuple(int,int),float]
|
|
1020
|
+
an arcMap of values (between 0 and 1) to be shown as color of arcs
|
|
1021
|
+
cmapNode: ColorMap
|
|
1022
|
+
color map to show the vals of Nodes
|
|
1023
|
+
cmapArc: ColorMap
|
|
1024
|
+
color map to show the vals of Arcs
|
|
1025
|
+
showMsg: dict
|
|
1026
|
+
a nodeMap of values to be shown as tooltip
|
|
1027
|
+
|
|
1028
|
+
Returns
|
|
1029
|
+
-------
|
|
1030
|
+
pydot.Dot
|
|
1031
|
+
the desired representation of the Bayesian network
|
|
1032
|
+
"""
|
|
1033
|
+
if size is None:
|
|
1034
|
+
size = gum.config["notebook", "default_graph_size"]
|
|
1035
|
+
|
|
1036
|
+
if cmapArc is None:
|
|
1037
|
+
cmapArc = cmapNode
|
|
1038
|
+
|
|
1039
|
+
return getGraph(
|
|
1040
|
+
BN2dot(
|
|
1041
|
+
bn,
|
|
1042
|
+
size=size,
|
|
1043
|
+
nodeColor=nodeColor,
|
|
1044
|
+
arcWidth=arcWidth,
|
|
1045
|
+
arcLabel=arcLabel,
|
|
1046
|
+
arcColor=arcColor,
|
|
1047
|
+
cmapNode=cmapNode,
|
|
1048
|
+
cmapArc=cmapArc,
|
|
1049
|
+
),
|
|
1050
|
+
size,
|
|
1051
|
+
)
|
|
1052
|
+
|
|
1053
|
+
|
|
1054
|
+
def getCN(cn, size=None, nodeColor=None, arcWidth=None, arcLabel=None, arcColor=None, cmapNode=None, cmapArc=None):
|
|
1055
|
+
"""
|
|
1056
|
+
get a HTML string for a credal network
|
|
1057
|
+
|
|
1058
|
+
Parameters
|
|
1059
|
+
----------
|
|
1060
|
+
cn : pyagrum.CredalNet
|
|
1061
|
+
the Credal network
|
|
1062
|
+
size: str
|
|
1063
|
+
size (for graphviz) of the rendered graph
|
|
1064
|
+
nodeColor: dict[int,float]
|
|
1065
|
+
a nodeMap of values to be shown as color nodes (with special color for 0 and 1)
|
|
1066
|
+
arcWidth: dict[Tuple(int,int),float]
|
|
1067
|
+
an arcMap of values to be shown as bold arcs
|
|
1068
|
+
arcLabel: dict[Tuple(int,int),float]
|
|
1069
|
+
an arcMap of labels to be shown next to arcs
|
|
1070
|
+
arcColor: dict[Tuple(int,int),float]
|
|
1071
|
+
an arcMap of values (between 0 and 1) to be shown as color of arcs
|
|
1072
|
+
cmapNode: matplotlib.color.colormap
|
|
1073
|
+
color map to show the vals of Nodes
|
|
1074
|
+
cmapArc: matplotlib.color.colormap
|
|
1075
|
+
color map to show the vals of Arcs
|
|
1076
|
+
showMsg : dict[int,str]
|
|
1077
|
+
a nodeMap of values to be shown as tooltip
|
|
1078
|
+
|
|
1079
|
+
Returns
|
|
1080
|
+
-------
|
|
1081
|
+
pydot.Dot
|
|
1082
|
+
the desired representation of the Credal Network
|
|
1083
|
+
"""
|
|
1084
|
+
if size is None:
|
|
1085
|
+
size = gum.config["notebook", "default_graph_size"]
|
|
1086
|
+
|
|
1087
|
+
if cmapArc is None:
|
|
1088
|
+
cmapArc = cmapNode
|
|
1089
|
+
|
|
1090
|
+
return getGraph(
|
|
1091
|
+
CN2dot(
|
|
1092
|
+
cn,
|
|
1093
|
+
size=size,
|
|
1094
|
+
nodeColor=nodeColor,
|
|
1095
|
+
arcWidth=arcWidth,
|
|
1096
|
+
arcLabel=arcLabel,
|
|
1097
|
+
arcColor=arcColor,
|
|
1098
|
+
cmapNode=cmapNode,
|
|
1099
|
+
cmapArc=cmapArc,
|
|
1100
|
+
),
|
|
1101
|
+
size,
|
|
1102
|
+
)
|
|
1103
|
+
|
|
1104
|
+
|
|
1105
|
+
def showInference(model, **kwargs):
|
|
1106
|
+
"""
|
|
1107
|
+
show pydot graph for an inference in a notebook
|
|
1108
|
+
|
|
1109
|
+
Parameters
|
|
1110
|
+
----------
|
|
1111
|
+
model: pyagrum.GraphicalModel
|
|
1112
|
+
the model in which to infer (pyagrum.BayesNet, pyagrum.MarkovRandomField or pyagrum.InfluenceDiagram)
|
|
1113
|
+
engine: gum.Inference
|
|
1114
|
+
inference algorithm used. If None, gum.LazyPropagation will be used for BayesNet, gum.ShaferShenoy for gum.MarkovRandomField and gum.ShaferShenoyLIMIDInference for gum.InfluenceDiagram.
|
|
1115
|
+
evs: Dict[int|str,int|str|List[float]]
|
|
1116
|
+
map of evidence
|
|
1117
|
+
targets: Set[str]
|
|
1118
|
+
set of targets
|
|
1119
|
+
size: string
|
|
1120
|
+
size (for graphviz) of the rendered graph
|
|
1121
|
+
nodeColor: Dict[str,float]
|
|
1122
|
+
a nodeMap of values (between 0 and 1) to be shown as color of nodes (with special colors for 0 and 1)
|
|
1123
|
+
factorColor: Dict[int,float]
|
|
1124
|
+
a nodeMap of values (between 0 and 1) to be shown as color of factors (in MarkovRandomField representation)
|
|
1125
|
+
arcWidth: : Dict[(str,str),float]
|
|
1126
|
+
an arcMap of values to be shown as width of arcs
|
|
1127
|
+
arcColor: : Dict[(str,str),float]
|
|
1128
|
+
a arcMap of values (between 0 and 1) to be shown as color of arcs
|
|
1129
|
+
cmapNode: matplotlib.ColorMap
|
|
1130
|
+
map to show the color of nodes and arcs
|
|
1131
|
+
cmapArc: matplotlib.ColorMap
|
|
1132
|
+
color map to show the vals of Arcs.
|
|
1133
|
+
graph: pyagrum.Graph
|
|
1134
|
+
only shows nodes that have their id in the graph (and not in the whole BN)
|
|
1135
|
+
view: str
|
|
1136
|
+
graph | factorgraph | None (default) for Markov random field
|
|
1137
|
+
|
|
1138
|
+
Returns
|
|
1139
|
+
-------
|
|
1140
|
+
the desired representation of the inference
|
|
1141
|
+
"""
|
|
1142
|
+
if "size" in kwargs:
|
|
1143
|
+
size = kwargs["size"]
|
|
1144
|
+
else:
|
|
1145
|
+
size = gum.config["notebook", "default_graph_inference_size"]
|
|
1146
|
+
|
|
1147
|
+
showGraph(prepareShowInference(model, **kwargs), size)
|
|
1148
|
+
|
|
1149
|
+
|
|
1150
|
+
def getInference(model, **kwargs):
|
|
1151
|
+
"""
|
|
1152
|
+
get a HTML string for an inference in a notebook
|
|
1153
|
+
|
|
1154
|
+
Parameters
|
|
1155
|
+
----------
|
|
1156
|
+
model: pyagrum.GraphicalModel
|
|
1157
|
+
the model in which to infer (pyagrum.BayesNet, pyagrum.MarkovRandomField or pyagrum.InfluenceDiagram)
|
|
1158
|
+
engine: gum.Inference
|
|
1159
|
+
inference algorithm used. If None, gum.LazyPropagation will be used for BayesNet, gum.ShaferShenoy for gum.MarkovRandomField and gum.ShaferShenoyLIMIDInference for gum.InfluenceDiagram.
|
|
1160
|
+
evs: Dict[int|str,int|str|List[float]]
|
|
1161
|
+
map of evidence
|
|
1162
|
+
targets: Set[str]
|
|
1163
|
+
set of targets
|
|
1164
|
+
size: string
|
|
1165
|
+
size (for graphviz) of the rendered graph
|
|
1166
|
+
nodeColor: Dict[str,float]
|
|
1167
|
+
a nodeMap of values (between 0 and 1) to be shown as color of nodes (with special colors for 0 and 1)
|
|
1168
|
+
factorColor: Dict[int,float]
|
|
1169
|
+
a nodeMap of values (between 0 and 1) to be shown as color of factors (in MarkovRandomField representation)
|
|
1170
|
+
arcWidth: : Dict[(str,str),float]
|
|
1171
|
+
an arcMap of values to be shown as width of arcs
|
|
1172
|
+
arcColor: : Dict[(str,str),float]
|
|
1173
|
+
a arcMap of values (between 0 and 1) to be shown as color of arcs
|
|
1174
|
+
cmapNode: matplotlib.ColorMap
|
|
1175
|
+
map to show the color of nodes and arcs
|
|
1176
|
+
cmapArc: matplotlib.ColorMap
|
|
1177
|
+
color map to show the vals of Arcs.
|
|
1178
|
+
graph: pyagrum.Graph
|
|
1179
|
+
only shows nodes that have their id in the graph (and not in the whole BN)
|
|
1180
|
+
view: str
|
|
1181
|
+
graph | factorgraph | None (default) for Markov random field
|
|
1182
|
+
|
|
1183
|
+
Returns
|
|
1184
|
+
-------
|
|
1185
|
+
the desired representation of the inference
|
|
1186
|
+
"""
|
|
1187
|
+
if "size" in kwargs:
|
|
1188
|
+
size = kwargs["size"]
|
|
1189
|
+
else:
|
|
1190
|
+
size = gum.config["notebook", "default_graph_inference_size"]
|
|
1191
|
+
|
|
1192
|
+
grinf = prepareShowInference(model, **kwargs)
|
|
1193
|
+
return getGraph(grinf, size)
|
|
1194
|
+
|
|
1195
|
+
|
|
1196
|
+
def _reprTensor(pot, digits=None, withColors=None, varnames=None, asString=False):
|
|
1197
|
+
"""
|
|
1198
|
+
return a representation of a gum.Tensor as a HTML table.
|
|
1199
|
+
The first dimension is special (horizontal) due to the representation of conditional probability table
|
|
1200
|
+
|
|
1201
|
+
Parameters
|
|
1202
|
+
---------
|
|
1203
|
+
pot: gum.Tensor
|
|
1204
|
+
the tensor to get
|
|
1205
|
+
digits: int
|
|
1206
|
+
number of digits to show
|
|
1207
|
+
withColors: Boolean
|
|
1208
|
+
background color for proba cells or not
|
|
1209
|
+
varnames: Dict[str,str]
|
|
1210
|
+
a mapping that gives the aliases for variables name in the table
|
|
1211
|
+
asString: Boolean
|
|
1212
|
+
display the table or a HTML string
|
|
1213
|
+
|
|
1214
|
+
Returns
|
|
1215
|
+
-------
|
|
1216
|
+
str
|
|
1217
|
+
the HTML table that represents the tensor
|
|
1218
|
+
"""
|
|
1219
|
+
from fractions import Fraction
|
|
1220
|
+
|
|
1221
|
+
r0, g0, b0 = gumcols.hex2rgb(gum.config["notebook", "tensor_color_0"])
|
|
1222
|
+
r1, g1, b1 = gumcols.hex2rgb(gum.config["notebook", "tensor_color_1"])
|
|
1223
|
+
|
|
1224
|
+
if digits is None:
|
|
1225
|
+
digits = gum.config.asInt["notebook", "tensor_visible_digits"]
|
|
1226
|
+
|
|
1227
|
+
if withColors is None:
|
|
1228
|
+
withColors = gum.config.asBool["notebook", "tensor_with_colors"]
|
|
1229
|
+
|
|
1230
|
+
with_fraction = gum.config["notebook", "tensor_with_fraction"] == "True"
|
|
1231
|
+
if with_fraction:
|
|
1232
|
+
fraction_limit = int(gum.config["notebook", "tensor_fraction_limit"])
|
|
1233
|
+
fraction_round_error = float(gum.config["notebook", "tensor_fraction_round_error"])
|
|
1234
|
+
fraction_with_latex = gum.config["notebook", "tensor_fraction_with_latex"] == "True"
|
|
1235
|
+
|
|
1236
|
+
def _rgb(r, g, b):
|
|
1237
|
+
return "#%02x%02x%02x" % (r, g, b)
|
|
1238
|
+
|
|
1239
|
+
def _mkCell(val):
|
|
1240
|
+
s = "<td style='"
|
|
1241
|
+
if withColors and (0 <= val <= 1):
|
|
1242
|
+
r = int(r0 + val * (r1 - r0))
|
|
1243
|
+
g = int(g0 + val * (g1 - g0))
|
|
1244
|
+
b = int(b0 + val * (b1 - b0))
|
|
1245
|
+
|
|
1246
|
+
tx = gumcols.rgb2brightness(r, g, b)
|
|
1247
|
+
|
|
1248
|
+
s += "color:" + tx + ";background-color:" + _rgb(r, g, b) + ";"
|
|
1249
|
+
|
|
1250
|
+
str_val = ""
|
|
1251
|
+
if with_fraction:
|
|
1252
|
+
frac_val = Fraction(val).limit_denominator(fraction_limit)
|
|
1253
|
+
val_app = frac_val.numerator / frac_val.denominator
|
|
1254
|
+
if abs(val_app - val) < fraction_round_error:
|
|
1255
|
+
str_val = "text-align:center;'>"
|
|
1256
|
+
if fraction_with_latex:
|
|
1257
|
+
str_val += "$$"
|
|
1258
|
+
if frac_val.denominator > 1:
|
|
1259
|
+
str_val += f"\\frac{{{frac_val.numerator}}}{{{frac_val.denominator}}}"
|
|
1260
|
+
else:
|
|
1261
|
+
str_val += f"{frac_val.numerator}"
|
|
1262
|
+
str_val += "$$"
|
|
1263
|
+
else:
|
|
1264
|
+
str_val += f"{frac_val}"
|
|
1265
|
+
str_val += "</td>"
|
|
1266
|
+
if str_val == "":
|
|
1267
|
+
str_val = f"text-align:right;padding: 3px;'>{val:.{digits}f}</td>"
|
|
1268
|
+
|
|
1269
|
+
return s + str_val
|
|
1270
|
+
|
|
1271
|
+
html = list()
|
|
1272
|
+
html.append('<table style="border:1px solid black;border-collapse: collapse;">')
|
|
1273
|
+
if pot.empty():
|
|
1274
|
+
html.append("<tr><th> </th></tr>")
|
|
1275
|
+
html.append("<tr>" + _mkCell(pot.get(gum.Instantiation())) + "</tr>")
|
|
1276
|
+
else:
|
|
1277
|
+
if varnames is not None and len(varnames) != pot.nbrDim():
|
|
1278
|
+
raise ValueError(f"varnames contains {len(varnames)} value(s) instead of the needed {pot.nbrDim()} value(s).")
|
|
1279
|
+
|
|
1280
|
+
nparents = pot.nbrDim() - 1
|
|
1281
|
+
var = pot.variable(0)
|
|
1282
|
+
varname = var.name() if varnames is None else varnames[0]
|
|
1283
|
+
|
|
1284
|
+
# first line
|
|
1285
|
+
if nparents > 0:
|
|
1286
|
+
html.append(f"""<tr><th colspan='{nparents}'></th>
|
|
1287
|
+
<th colspan='{var.domainSize()}' style='border:1px solid black;color:black;background-color:#808080;'><center>{varname}</center>
|
|
1288
|
+
</th></tr>""")
|
|
1289
|
+
else:
|
|
1290
|
+
html.append(f"""<tr style='border:1px solid black;color:black;background-color:#808080'>
|
|
1291
|
+
<th colspan='{var.domainSize()}'><center>{varname}</center></th></tr>""")
|
|
1292
|
+
|
|
1293
|
+
# second line
|
|
1294
|
+
s = "<tr>"
|
|
1295
|
+
if nparents > 0:
|
|
1296
|
+
# parents order
|
|
1297
|
+
if gum.config["notebook", "tensor_parent_values"] == "revmerge":
|
|
1298
|
+
pmin, pmax, pinc = nparents - 1, 0 - 1, -1
|
|
1299
|
+
else:
|
|
1300
|
+
pmin, pmax, pinc = 0, nparents, 1
|
|
1301
|
+
|
|
1302
|
+
if varnames is None:
|
|
1303
|
+
varnames = list(reversed(pot.names))
|
|
1304
|
+
for par in range(pmin, pmax, pinc):
|
|
1305
|
+
parent = varnames[par]
|
|
1306
|
+
s += f"<th style='border:1px solid black;color:black;background-color:#808080'><center>{parent}</center></th>"
|
|
1307
|
+
|
|
1308
|
+
for label in var.labels():
|
|
1309
|
+
s += f"""<th style='border:1px solid black;border-bottom-style: double;color:black;background-color:#BBBBBB'>
|
|
1310
|
+
<center>{label}</center></th>"""
|
|
1311
|
+
s += "</tr>"
|
|
1312
|
+
|
|
1313
|
+
html.append(s)
|
|
1314
|
+
|
|
1315
|
+
inst = gum.Instantiation(pot)
|
|
1316
|
+
off = 1
|
|
1317
|
+
offset = dict()
|
|
1318
|
+
for i in range(1, nparents + 1):
|
|
1319
|
+
offset[i] = off
|
|
1320
|
+
off *= inst.variable(i).domainSize()
|
|
1321
|
+
|
|
1322
|
+
inst.setFirst()
|
|
1323
|
+
while not inst.end():
|
|
1324
|
+
s = "<tr>"
|
|
1325
|
+
# parents order
|
|
1326
|
+
if gum.config["notebook", "tensor_parent_values"] == "revmerge":
|
|
1327
|
+
pmin, pmax, pinc = 1, nparents + 1, 1
|
|
1328
|
+
else:
|
|
1329
|
+
pmin, pmax, pinc = nparents, 0, -1
|
|
1330
|
+
for par in range(pmin, pmax, pinc):
|
|
1331
|
+
label = inst.variable(par).label(inst.val(par))
|
|
1332
|
+
if par == 1 or gum.config["notebook", "tensor_parent_values"] == "nomerge":
|
|
1333
|
+
s += f"<th style='border:1px solid black;color:black;background-color:#BBBBBB'><center>{label}</center></th>"
|
|
1334
|
+
else:
|
|
1335
|
+
if sum([inst.val(i) for i in range(1, par)]) == 0:
|
|
1336
|
+
s += f"""<th style='border:1px solid black;color:black;background-color:#BBBBBB;' rowspan = '{offset[par]}'>
|
|
1337
|
+
<center>{label}</center></th>"""
|
|
1338
|
+
for _ in range(pot.variable(0).domainSize()):
|
|
1339
|
+
s += _mkCell(pot.get(inst))
|
|
1340
|
+
inst.inc()
|
|
1341
|
+
s += "</tr>"
|
|
1342
|
+
html.append(s)
|
|
1343
|
+
|
|
1344
|
+
html.append("</table>")
|
|
1345
|
+
|
|
1346
|
+
if asString:
|
|
1347
|
+
return "\n".join(html)
|
|
1348
|
+
else:
|
|
1349
|
+
return IPython.display.HTML("".join(html))
|
|
1350
|
+
|
|
1351
|
+
|
|
1352
|
+
def __isKindOfProba(pot):
|
|
1353
|
+
"""
|
|
1354
|
+
check if pot is a joint proba or a CPT
|
|
1355
|
+
|
|
1356
|
+
Parameters
|
|
1357
|
+
----------
|
|
1358
|
+
pot: gum.Tensor
|
|
1359
|
+
the tensor
|
|
1360
|
+
|
|
1361
|
+
Returns
|
|
1362
|
+
-------
|
|
1363
|
+
True or False
|
|
1364
|
+
"""
|
|
1365
|
+
epsilon = 1e-5
|
|
1366
|
+
if pot.min() < -epsilon:
|
|
1367
|
+
return False
|
|
1368
|
+
if pot.max() > 1 + epsilon:
|
|
1369
|
+
return False
|
|
1370
|
+
|
|
1371
|
+
# is it a joint proba ?
|
|
1372
|
+
if abs(pot.sum() - 1) < epsilon:
|
|
1373
|
+
return True
|
|
1374
|
+
|
|
1375
|
+
# marginal and then not proba (because of the test just above)
|
|
1376
|
+
if pot.nbrDim() < 2:
|
|
1377
|
+
return False
|
|
1378
|
+
|
|
1379
|
+
# is is a CPT ?
|
|
1380
|
+
q = pot.sumOut([pot.variable(0).name()])
|
|
1381
|
+
if abs(q.max() - 1) > epsilon:
|
|
1382
|
+
return False
|
|
1383
|
+
if abs(q.min() - 1) > epsilon:
|
|
1384
|
+
return False
|
|
1385
|
+
|
|
1386
|
+
return True
|
|
1387
|
+
|
|
1388
|
+
|
|
1389
|
+
def showPotential(pot, digits=None, withColors=None, varnames=None):
|
|
1390
|
+
warnings.warn("showPotential is deprecated since pyAgrum 2.0.0. Use showTensor instead", DeprecationWarning)
|
|
1391
|
+
showTensor(pot, digits, withColors, varnames)
|
|
1392
|
+
|
|
1393
|
+
|
|
1394
|
+
def showTensor(pot, digits=None, withColors=None, varnames=None):
|
|
1395
|
+
"""
|
|
1396
|
+
show a gum.Tensor as a HTML table.
|
|
1397
|
+
The first dimension is special (horizontal) due to the representation of conditional probability table
|
|
1398
|
+
|
|
1399
|
+
Parameters
|
|
1400
|
+
----------
|
|
1401
|
+
pot : gum.Tensor
|
|
1402
|
+
the tensor to show
|
|
1403
|
+
digits : int
|
|
1404
|
+
number of digits to show
|
|
1405
|
+
withColors : bool
|
|
1406
|
+
background color for proba cells or not
|
|
1407
|
+
varnames : List[str]
|
|
1408
|
+
the aliases for variables name in the table
|
|
1409
|
+
"""
|
|
1410
|
+
if withColors is None:
|
|
1411
|
+
withColors = gum.config.asBool["notebook", "tensor_with_colors"]
|
|
1412
|
+
|
|
1413
|
+
if withColors:
|
|
1414
|
+
withColors = __isKindOfProba(pot)
|
|
1415
|
+
|
|
1416
|
+
s = _reprTensor(pot, digits, withColors, varnames, asString=False)
|
|
1417
|
+
IPython.display.display(s)
|
|
1418
|
+
|
|
1419
|
+
|
|
1420
|
+
def getPotential(pot, digits=None, withColors=None, varnames=None):
|
|
1421
|
+
warnings.warn("getPotential is deprecated since pyAgrum 2.0.0. Use getTensor instead", DeprecationWarning)
|
|
1422
|
+
return getTensor(pot, digits, withColors, varnames)
|
|
1423
|
+
|
|
1424
|
+
|
|
1425
|
+
def getTensor(pot, digits=None, withColors=None, varnames=None):
|
|
1426
|
+
"""
|
|
1427
|
+
return a HTML string of a gum.Tensor as a HTML table.
|
|
1428
|
+
The first dimension is special (horizontal) due to the representation of conditional probability table
|
|
1429
|
+
|
|
1430
|
+
Parameters
|
|
1431
|
+
----------
|
|
1432
|
+
pot : gum.Tensor
|
|
1433
|
+
the tensor to show
|
|
1434
|
+
digits : int
|
|
1435
|
+
number of digits to show
|
|
1436
|
+
withColors : bool
|
|
1437
|
+
background for proba cells or not
|
|
1438
|
+
varnames : List[str]
|
|
1439
|
+
the aliases for variables name in the table
|
|
1440
|
+
|
|
1441
|
+
Returns
|
|
1442
|
+
-------
|
|
1443
|
+
str
|
|
1444
|
+
the html representation of the Tensor (as a string)
|
|
1445
|
+
"""
|
|
1446
|
+
if withColors is None:
|
|
1447
|
+
withColors = gum.config.asBool["notebook", "tensor_with_colors"]
|
|
1448
|
+
|
|
1449
|
+
if withColors:
|
|
1450
|
+
withColors = __isKindOfProba(pot)
|
|
1451
|
+
|
|
1452
|
+
return _reprTensor(pot, digits, withColors, varnames, asString=True)
|
|
1453
|
+
|
|
1454
|
+
|
|
1455
|
+
def showCPTs(bn):
|
|
1456
|
+
flow.clear()
|
|
1457
|
+
for i in bn.names():
|
|
1458
|
+
flow.add_html(getTensor(bn.cpt(i)))
|
|
1459
|
+
flow.display()
|
|
1460
|
+
|
|
1461
|
+
|
|
1462
|
+
def getSideBySide(*args, **kwargs):
|
|
1463
|
+
"""
|
|
1464
|
+
create an HTML table for args as string (using string, _repr_html_() or str())
|
|
1465
|
+
|
|
1466
|
+
Parameters
|
|
1467
|
+
----------
|
|
1468
|
+
args: str
|
|
1469
|
+
HMTL fragments as string arg, arg._repr_html_() or str(arg)
|
|
1470
|
+
captions: List[str], optional
|
|
1471
|
+
list of captions
|
|
1472
|
+
valign: str
|
|
1473
|
+
vertical position in the row (top|middle|bottom, middle by default)
|
|
1474
|
+
ncols: int
|
|
1475
|
+
number of columns (infinite by default)
|
|
1476
|
+
|
|
1477
|
+
Returns
|
|
1478
|
+
-------
|
|
1479
|
+
str
|
|
1480
|
+
a string representing the table
|
|
1481
|
+
"""
|
|
1482
|
+
vals = {"captions", "valign", "ncols"}
|
|
1483
|
+
if not set(kwargs.keys()).issubset(vals):
|
|
1484
|
+
raise TypeError(f"sideBySide() got unexpected keyword argument(s) : '{set(kwargs.keys()).difference(vals)}'")
|
|
1485
|
+
|
|
1486
|
+
if "captions" in kwargs:
|
|
1487
|
+
captions = kwargs["captions"]
|
|
1488
|
+
else:
|
|
1489
|
+
captions = None
|
|
1490
|
+
|
|
1491
|
+
if "valign" in kwargs and kwargs["valign"] in ["top", "middle", "bottom"]:
|
|
1492
|
+
v_align = f"vertical-align:{kwargs['valign']};"
|
|
1493
|
+
else:
|
|
1494
|
+
v_align = "vertical-align:middle;"
|
|
1495
|
+
|
|
1496
|
+
ncols = None
|
|
1497
|
+
if "ncols" in kwargs:
|
|
1498
|
+
ncols = int(kwargs["ncols"])
|
|
1499
|
+
if ncols < 1:
|
|
1500
|
+
ncols = 1
|
|
1501
|
+
|
|
1502
|
+
def reprHTML(s):
|
|
1503
|
+
if isinstance(s, str):
|
|
1504
|
+
return s
|
|
1505
|
+
elif hasattr(s, "_repr_html_"):
|
|
1506
|
+
return s._repr_html_()
|
|
1507
|
+
else:
|
|
1508
|
+
return str(s)
|
|
1509
|
+
|
|
1510
|
+
s = '<table style="border-style: hidden; border-collapse: collapse;" width="100%"><tr>'
|
|
1511
|
+
for i in range(len(args)):
|
|
1512
|
+
s += f'<td style="border-top:hidden;border-bottom:hidden;{v_align}"><div align="center" style="{v_align}">'
|
|
1513
|
+
s += reprHTML(args[i])
|
|
1514
|
+
if captions is not None:
|
|
1515
|
+
s += f"<br><small><i>{captions[i]}</i></small>"
|
|
1516
|
+
s += "</div></td>"
|
|
1517
|
+
if ncols is not None and (i + 1) % ncols == 0:
|
|
1518
|
+
s += "</tr><tr>"
|
|
1519
|
+
s += "</tr></table>"
|
|
1520
|
+
return s
|
|
1521
|
+
|
|
1522
|
+
|
|
1523
|
+
def sideBySide(*args, **kwargs):
|
|
1524
|
+
"""
|
|
1525
|
+
display side by side args as HMTL fragment (using string, _repr_html_() or str())
|
|
1526
|
+
|
|
1527
|
+
Parameters
|
|
1528
|
+
----------
|
|
1529
|
+
args: str
|
|
1530
|
+
HMTL fragments as string arg, arg._repr_html_() or str(arg)
|
|
1531
|
+
captions: List[str], optional
|
|
1532
|
+
list of captions
|
|
1533
|
+
valign: str
|
|
1534
|
+
vertical position in the row (top|middle|bottom, middle by default)
|
|
1535
|
+
ncols: int
|
|
1536
|
+
number of columns (infinite by default)
|
|
1537
|
+
"""
|
|
1538
|
+
IPython.display.display(IPython.display.HTML(getSideBySide(*args, **kwargs)))
|
|
1539
|
+
|
|
1540
|
+
|
|
1541
|
+
def getInferenceEngine(ie, inferenceCaption):
|
|
1542
|
+
"""
|
|
1543
|
+
display an inference as a BN+ lists of hard/soft evidence and list of targets
|
|
1544
|
+
|
|
1545
|
+
Parameters
|
|
1546
|
+
---------
|
|
1547
|
+
ie : "pyagrum.InferenceEngine"
|
|
1548
|
+
Inference engine
|
|
1549
|
+
caption: str
|
|
1550
|
+
inferenceCaption: caption for the inference
|
|
1551
|
+
|
|
1552
|
+
Returns
|
|
1553
|
+
-------
|
|
1554
|
+
str
|
|
1555
|
+
the HTML representation
|
|
1556
|
+
"""
|
|
1557
|
+
t = '<div align="left"><ul>'
|
|
1558
|
+
if ie.nbrHardEvidence() > 0:
|
|
1559
|
+
t += "<li><b>hard evidence</b><br/>"
|
|
1560
|
+
t += ", ".join([ie.BN().variable(n).name() for n in ie.hardEvidenceNodes()])
|
|
1561
|
+
t += "</li>"
|
|
1562
|
+
if ie.nbrSoftEvidence() > 0:
|
|
1563
|
+
t += "<li><b>soft evidence</b><br/>"
|
|
1564
|
+
t += ", ".join([ie.BN().variable(n).name() for n in ie.softEvidenceNodes()])
|
|
1565
|
+
t += "</li>"
|
|
1566
|
+
if ie.nbrTargets() > 0:
|
|
1567
|
+
t += "<li><b>target(s)</b><br/>"
|
|
1568
|
+
if ie.nbrTargets() == ie.BN().size():
|
|
1569
|
+
t += " all"
|
|
1570
|
+
else:
|
|
1571
|
+
t += ", ".join([ie.BN().variable(n).name() for n in ie.targets()])
|
|
1572
|
+
t += "</li>"
|
|
1573
|
+
|
|
1574
|
+
if hasattr(ie, "nbrJointTargets") and ie.nbrJointTargets() > 0:
|
|
1575
|
+
t += "<li><b>Joint target(s)</b><br/>"
|
|
1576
|
+
t += ", ".join(["[" + (", ".join([ie.BN().variable(n).name() for n in ns])) + "]" for ns in ie.jointTargets()])
|
|
1577
|
+
t += "</li>"
|
|
1578
|
+
t += "</ul></div>"
|
|
1579
|
+
return getSideBySide(getBN(ie.BN()), t, captions=[inferenceCaption, "Evidence and targets"])
|
|
1580
|
+
|
|
1581
|
+
|
|
1582
|
+
def getJT(jt, size=None):
  """
  returns the representation of a junction tree as a graphviz (dot language) string

  Parameters
  ----------
  jt: pyagrum.JunctionTree
    the junction tree (the model it was built from is read from ``jt._engine._model``)
  size: str
    the size (for graphviz) of the graph. If None (default), the configuration value
    ``gum.config["notebook", "junctiontree_graph_size"]`` is used.

  Returns
  -------
  str
    the representation of the junction tree in the dot language
  """
  model = jt._engine._model

  if gum.config.asBool["notebook", "junctiontree_with_names"]:
    # cliques and separators are labelled with the sorted names of their variables

    def cliqlabels(c):
      labels = ",".join(sorted([model.variable(n).name() for n in jt.clique(c)]))
      return f"({c}):{labels}"

    def cliqnames(c):
      return "-".join(sorted([model.variable(n).name() for n in jt.clique(c)]))

    def seplabels(c1, c2):
      return ",".join(sorted([model.variable(n).name() for n in jt.separator(c1, c2)]))

    def sepnames(c1, c2):
      return cliqnames(c1) + "+" + cliqnames(c2)
  else:
    # cliques and separators are labelled with the sorted ids of their nodes

    def cliqlabels(c):
      ids = ",".join([str(n) for n in sorted(jt.clique(c))])
      return f"({c}):{ids}"

    def cliqnames(c):
      return "-".join([str(n) for n in sorted(jt.clique(c))])

    def seplabels(c1, c2):
      return ",".join([str(n) for n in sorted(jt.separator(c1, c2))])

    def sepnames(c1, c2):
      return cliqnames(c1) + "^" + cliqnames(c2)

  # default graph name: the class name taken from str(type(model)), dropping its trailing "'>"
  name = model.propertyWithDefault("name", str(type(model)).split(".")[-1][:-2])
  graph = dot.Dot(graph_type="graph", graph_name=name, bgcolor="transparent")

  # one node per clique
  for c in jt.nodes():
    graph.add_node(
      dot.Node(
        '"' + cliqnames(c) + '"',
        label='"' + cliqlabels(c) + '"',
        style="filled",
        fillcolor=gum.config["notebook", "junctiontree_clique_bgcolor"],
        fontcolor=gum.config["notebook", "junctiontree_clique_fgcolor"],
        fontsize=gum.config["notebook", "junctiontree_clique_fontsize"],
      )
    )

  # one (small, boxed) node per separator, i.e. per edge of the junction tree
  for c1, c2 in jt.edges():
    graph.add_node(
      dot.Node(
        '"' + sepnames(c1, c2) + '"',
        label='"' + seplabels(c1, c2) + '"',
        style="filled",
        shape="box",
        width="0",
        height="0",
        margin="0.02",
        fillcolor=gum.config["notebook", "junctiontree_separator_bgcolor"],
        fontcolor=gum.config["notebook", "junctiontree_separator_fgcolor"],
        fontsize=gum.config["notebook", "junctiontree_separator_fontsize"],
      )
    )

  # each junction-tree edge c1--c2 becomes c1--separator--c2
  for c1, c2 in jt.edges():
    graph.add_edge(dot.Edge('"' + cliqnames(c1) + '"', '"' + sepnames(c1, c2) + '"'))
    graph.add_edge(dot.Edge('"' + sepnames(c1, c2) + '"', '"' + cliqnames(c2) + '"'))

  # honour an explicit size; fall back on the configuration otherwise
  graph.set_size(gum.config["notebook", "junctiontree_graph_size"] if size is None else size)

  return graph.to_string()
|
|
1664
|
+
|
|
1665
|
+
|
|
1666
|
+
def getCliqueGraph(cg, size=None):
  """get a representation for clique graph. Special case for junction tree
  (clique graph with an attribute `_engine`)

  Parameters
  ----------
  cg : pyagrum.CliqueGraph
    the clique graph (maybe junction tree for a _model) to represent
  size : str
    the size (for graphviz) of the graph; None (default) keeps the default size

  Returns
  -------
  str
    the HTML representation of the graph (as produced by getDot)
  """
  if hasattr(cg, "_engine"):
    return getDot(getJT(cg), size)
  # plain clique graph: pass size as well, consistently with the junction tree branch
  return getDot(cg.toDot(), size)
|
|
1683
|
+
|
|
1684
|
+
|
|
1685
|
+
def show(model, **kwargs):
  """
  propose a (visual) representation of a graphical model or a graph or a Tensor in a notebook

  Parameters
  ----------
  model
    the model to show (pyagrum.BayesNet, pyagrum.MarkovRandomField, pyagrum.InfluenceDiagram,
    pyagrum.CredalNet or pyagrum.Tensor) or a dot string, or a `pydot.Dot` or even just an
    object with a method `toDot()`.

  size: str
    size (for graphviz) to represent the graphical model (no effect for Tensor)
  nodeColor: Dict[str,float]
    a nodeMap of values (between 0 and 1) to be shown as color of nodes (with special colors for 0 and 1)
  factorColor: Dict[int,float]
    a nodeMap of values (between 0 and 1) to be shown as color of factors (in MarkovRandomField representation)
  arcWidth: Dict[(str,str),float]
    an arcMap of values to be shown as width of arcs
  arcColor: Dict[(str,str),float]
    an arcMap of values (between 0 and 1) to be shown as color of arcs
  cmapNode: matplotlib.ColorMap
    map to show the color of nodes and arcs
  cmapArc: matplotlib.ColorMap
    color map to show the vals of Arcs.
  graph: pyagrum.Graph
    only shows nodes that have their id in the graph (and not in the whole BN)
  view: str
    graph | factorgraph | None (default) for Markov random field

  Raises
  ------
  pyagrum.InvalidArgument
    if model is none of the supported kinds
  """
  if isinstance(model, gum.BayesNet):
    showBN(model, **kwargs)
  elif isinstance(model, gum.MarkovRandomField):
    showMRF(model, **kwargs)
  elif isinstance(model, gum.InfluenceDiagram):
    showInfluenceDiagram(model, **kwargs)
  elif isinstance(model, gum.CredalNet):
    showCN(model, **kwargs)
  elif isinstance(model, gum.Tensor):
    # kwargs (e.g. size) have no meaning for a Tensor
    showTensor(model)
  elif hasattr(model, "toDot"):
    showDot(model.toDot(), **kwargs)
  elif isinstance(model, dot.Dot):
    showGraph(model, **kwargs)
  elif isinstance(model, str):
    # a string is taken as a graph written in the dot language, as documented above
    showDot(model, **kwargs)
  else:
    raise gum.InvalidArgument(
      "Argument model should be a PGM (BayesNet, MarkovRandomField, Influence Diagram or Tensor or ..."
    )
|
|
1731
|
+
|
|
1732
|
+
|
|
1733
|
+
def inspectBN(bn):
  """
  inspect a BN in a notebook: its graph followed by the CPT of each variable
  (variables ordered by name), all laid out with ``flow.row``

  Parameters
  ----------
  bn : pyagrum.BayesNet
    the Bayesian network to inspect
  """
  ordered_names = sorted(bn.names())
  cpts = [bn.cpt(varname) for varname in ordered_names]
  flow.row(bn, *cpts)
|
|
1741
|
+
|
|
1742
|
+
|
|
1743
|
+
def _update_config_notebooks():
  """
  configuration hook: propagate the notebook section of ``gum.config`` to matplotlib
  (figure face color) and to the inline figure format
  """
  facecolor = gum.config["notebook", "figure_facecolor"]
  mpl.rcParams["figure.facecolor"] = facecolor

  figure_format = gum.config["notebook", "graph_format"]
  set_matplotlib_formats(figure_format)
|
|
1747
|
+
|
|
1748
|
+
|
|
1749
|
+
# check if an instance of ipython exists
try:
  get_ipython
except NameError:
  import warnings

  warnings.warn("""
** pyagrum.lib.notebook has to be imported from an IPython's instance (mainly notebook).
""")
else:
  # NOTE: this function was previously named `map` at module level, which shadowed the builtin
  # `map` for the whole module (and for `from ... import *` users). It keeps a private name
  # here and is still exposed as the method `pyagrum.CliqueGraph.map`.
  def _cliqueGraphMap(
    self,
    scaleClique: float = None,
    scaleSep: float = None,
    lenEdge: float = None,
    colorClique: str = None,
    colorSep: str = None,
  ) -> dot.Dot:
    """
    show the map of the junction tree.

    Parameters
    ----------
    scaleClique: float
      the scale for the size of the clique nodes (depending on the number of nodes in the clique)
    scaleSep: float
      the scale for the size of the separator nodes (depending on the number of nodes in the clique)
    lenEdge: float
      the desired length of edges
    colorClique: str
      color for the clique nodes
    colorSep: str
      color for the separator nodes

    Every parameter left to None is read from the "notebook" section of ``gum.config``.
    """
    if scaleClique is None:
      scaleClique = float(gum.config["notebook", "junctiontree_map_cliquescale"])
    if scaleSep is None:
      scaleSep = float(gum.config["notebook", "junctiontree_map_sepscale"])
    if lenEdge is None:
      lenEdge = float(gum.config["notebook", "junctiontree_map_edgelen"])
    if colorClique is None:
      colorClique = gum.config["notebook", "junctiontree_clique_bgcolor"]
    if colorSep is None:
      colorSep = gum.config["notebook", "junctiontree_separator_bgcolor"]
    return _from_dotstring(self.__map_str__(scaleClique, scaleSep, lenEdge, colorClique, colorSep))

  setattr(gum.CliqueGraph, "map", _cliqueGraphMap)

  # re-apply notebook settings whenever the configuration changes, and once now
  gum.config.add_hook(_update_config_notebooks)
  gum.config.run_hooks()
|
|
1800
|
+
|
|
1801
|
+
# adding _repr_html_ to some pyAgrum classes !
# (the duplicated registration for gum.BayesNetFragment has been removed)
gum.BayesNet._repr_html_ = lambda self: getBN(self)
gum.BayesNetFragment._repr_html_ = lambda self: getBN(self)
gum.MarkovRandomField._repr_html_ = lambda self: getMRF(self)
gum.InfluenceDiagram._repr_html_ = lambda self: getInfluenceDiagram(self)
gum.CredalNet._repr_html_ = lambda self: getCN(self)

gum.CliqueGraph._repr_html_ = lambda self: getCliqueGraph(self)

gum.Tensor._repr_html_ = lambda self: getTensor(self)
gum.LazyPropagation._repr_html_ = lambda self: getInferenceEngine(self, "Lazy Propagation on this BN")

# plain graph classes are all rendered from their dot representation
for _graphClass in (
  gum.UndiGraph,
  gum.DiGraph,
  gum.MixedGraph,
  gum.DAG,
  gum.EssentialGraph,
  gum.MarkovBlanket,
):
  _graphClass._repr_html_ = lambda self: getDot(self.toDot())
del _graphClass

dot.Dot._repr_html_ = lambda self: getGraph(self)
|