pyAgrum-nightly 2.3.1.9.dev202512261765915415__cp310-abi3-macosx_10_15_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyagrum/__init__.py +165 -0
- pyagrum/_pyagrum.so +0 -0
- pyagrum/bnmixture/BNMInference.py +268 -0
- pyagrum/bnmixture/BNMLearning.py +376 -0
- pyagrum/bnmixture/BNMixture.py +464 -0
- pyagrum/bnmixture/__init__.py +60 -0
- pyagrum/bnmixture/notebook.py +1058 -0
- pyagrum/causal/_CausalFormula.py +280 -0
- pyagrum/causal/_CausalModel.py +436 -0
- pyagrum/causal/__init__.py +81 -0
- pyagrum/causal/_causalImpact.py +356 -0
- pyagrum/causal/_dSeparation.py +598 -0
- pyagrum/causal/_doAST.py +761 -0
- pyagrum/causal/_doCalculus.py +361 -0
- pyagrum/causal/_doorCriteria.py +374 -0
- pyagrum/causal/_exceptions.py +95 -0
- pyagrum/causal/_types.py +61 -0
- pyagrum/causal/causalEffectEstimation/_CausalEffectEstimation.py +1175 -0
- pyagrum/causal/causalEffectEstimation/_IVEstimators.py +718 -0
- pyagrum/causal/causalEffectEstimation/_RCTEstimators.py +132 -0
- pyagrum/causal/causalEffectEstimation/__init__.py +46 -0
- pyagrum/causal/causalEffectEstimation/_backdoorEstimators.py +774 -0
- pyagrum/causal/causalEffectEstimation/_causalBNEstimator.py +324 -0
- pyagrum/causal/causalEffectEstimation/_frontdoorEstimators.py +396 -0
- pyagrum/causal/causalEffectEstimation/_learners.py +118 -0
- pyagrum/causal/causalEffectEstimation/_utils.py +466 -0
- pyagrum/causal/notebook.py +172 -0
- pyagrum/clg/CLG.py +658 -0
- pyagrum/clg/GaussianVariable.py +111 -0
- pyagrum/clg/SEM.py +312 -0
- pyagrum/clg/__init__.py +63 -0
- pyagrum/clg/canonicalForm.py +408 -0
- pyagrum/clg/constants.py +54 -0
- pyagrum/clg/forwardSampling.py +202 -0
- pyagrum/clg/learning.py +776 -0
- pyagrum/clg/notebook.py +480 -0
- pyagrum/clg/variableElimination.py +271 -0
- pyagrum/common.py +60 -0
- pyagrum/config.py +319 -0
- pyagrum/ctbn/CIM.py +513 -0
- pyagrum/ctbn/CTBN.py +573 -0
- pyagrum/ctbn/CTBNGenerator.py +216 -0
- pyagrum/ctbn/CTBNInference.py +459 -0
- pyagrum/ctbn/CTBNLearner.py +161 -0
- pyagrum/ctbn/SamplesStats.py +671 -0
- pyagrum/ctbn/StatsIndepTest.py +355 -0
- pyagrum/ctbn/__init__.py +79 -0
- pyagrum/ctbn/constants.py +54 -0
- pyagrum/ctbn/notebook.py +264 -0
- pyagrum/defaults.ini +199 -0
- pyagrum/deprecated.py +95 -0
- pyagrum/explain/_ComputationCausal.py +75 -0
- pyagrum/explain/_ComputationConditional.py +48 -0
- pyagrum/explain/_ComputationMarginal.py +48 -0
- pyagrum/explain/_CustomShapleyCache.py +110 -0
- pyagrum/explain/_Explainer.py +176 -0
- pyagrum/explain/_Explanation.py +70 -0
- pyagrum/explain/_FIFOCache.py +54 -0
- pyagrum/explain/_ShallCausalValues.py +204 -0
- pyagrum/explain/_ShallConditionalValues.py +155 -0
- pyagrum/explain/_ShallMarginalValues.py +155 -0
- pyagrum/explain/_ShallValues.py +296 -0
- pyagrum/explain/_ShapCausalValues.py +208 -0
- pyagrum/explain/_ShapConditionalValues.py +126 -0
- pyagrum/explain/_ShapMarginalValues.py +191 -0
- pyagrum/explain/_ShapleyValues.py +298 -0
- pyagrum/explain/__init__.py +81 -0
- pyagrum/explain/_explGeneralizedMarkovBlanket.py +152 -0
- pyagrum/explain/_explIndependenceListForPairs.py +146 -0
- pyagrum/explain/_explInformationGraph.py +264 -0
- pyagrum/explain/notebook/__init__.py +54 -0
- pyagrum/explain/notebook/_bar.py +142 -0
- pyagrum/explain/notebook/_beeswarm.py +174 -0
- pyagrum/explain/notebook/_showShapValues.py +97 -0
- pyagrum/explain/notebook/_waterfall.py +220 -0
- pyagrum/explain/shapley.py +225 -0
- pyagrum/lib/__init__.py +46 -0
- pyagrum/lib/_colors.py +390 -0
- pyagrum/lib/bn2graph.py +299 -0
- pyagrum/lib/bn2roc.py +1026 -0
- pyagrum/lib/bn2scores.py +217 -0
- pyagrum/lib/bn_vs_bn.py +605 -0
- pyagrum/lib/cn2graph.py +305 -0
- pyagrum/lib/discreteTypeProcessor.py +1102 -0
- pyagrum/lib/discretizer.py +58 -0
- pyagrum/lib/dynamicBN.py +390 -0
- pyagrum/lib/explain.py +57 -0
- pyagrum/lib/export.py +84 -0
- pyagrum/lib/id2graph.py +258 -0
- pyagrum/lib/image.py +387 -0
- pyagrum/lib/ipython.py +307 -0
- pyagrum/lib/mrf2graph.py +471 -0
- pyagrum/lib/notebook.py +1821 -0
- pyagrum/lib/proba_histogram.py +552 -0
- pyagrum/lib/utils.py +138 -0
- pyagrum/pyagrum.py +31495 -0
- pyagrum/skbn/_MBCalcul.py +242 -0
- pyagrum/skbn/__init__.py +49 -0
- pyagrum/skbn/_learningMethods.py +282 -0
- pyagrum/skbn/_utils.py +297 -0
- pyagrum/skbn/bnclassifier.py +1014 -0
- pyagrum_nightly-2.3.1.9.dev202512261765915415.dist-info/LICENSE.md +12 -0
- pyagrum_nightly-2.3.1.9.dev202512261765915415.dist-info/LICENSES/LGPL-3.0-or-later.txt +304 -0
- pyagrum_nightly-2.3.1.9.dev202512261765915415.dist-info/LICENSES/MIT.txt +18 -0
- pyagrum_nightly-2.3.1.9.dev202512261765915415.dist-info/METADATA +145 -0
- pyagrum_nightly-2.3.1.9.dev202512261765915415.dist-info/RECORD +107 -0
- pyagrum_nightly-2.3.1.9.dev202512261765915415.dist-info/WHEEL +4 -0
|
@@ -0,0 +1,1058 @@
|
|
|
1
|
+
############################################################################
|
|
2
|
+
# This file is part of the aGrUM/pyAgrum library. #
|
|
3
|
+
# #
|
|
4
|
+
# Copyright (c) 2005-2025 by #
|
|
5
|
+
# - Pierre-Henri WUILLEMIN(_at_LIP6) #
|
|
6
|
+
# - Christophe GONZALES(_at_AMU) #
|
|
7
|
+
# #
|
|
8
|
+
# The aGrUM/pyAgrum library is free software; you can redistribute it #
|
|
9
|
+
# and/or modify it under the terms of either : #
|
|
10
|
+
# #
|
|
11
|
+
# - the GNU Lesser General Public License as published by #
|
|
12
|
+
# the Free Software Foundation, either version 3 of the License, #
|
|
13
|
+
# or (at your option) any later version, #
|
|
14
|
+
# - the MIT license (MIT), #
|
|
15
|
+
# - or both in dual license, as here. #
|
|
16
|
+
# #
|
|
17
|
+
# (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) #
|
|
18
|
+
# #
|
|
19
|
+
# This aGrUM/pyAgrum library is distributed in the hope that it will be #
|
|
20
|
+
# useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, #
|
|
21
|
+
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS #
|
|
22
|
+
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE #
|
|
23
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER #
|
|
24
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, #
|
|
25
|
+
# ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR #
|
|
26
|
+
# OTHER DEALINGS IN THE SOFTWARE. #
|
|
27
|
+
# #
|
|
28
|
+
# See LICENCES for more details. #
|
|
29
|
+
# #
|
|
30
|
+
# SPDX-FileCopyrightText: Copyright 2005-2025 #
|
|
31
|
+
# - Pierre-Henri WUILLEMIN(_at_LIP6) #
|
|
32
|
+
# - Christophe GONZALES(_at_AMU) #
|
|
33
|
+
# SPDX-License-Identifier: LGPL-3.0-or-later OR MIT #
|
|
34
|
+
# #
|
|
35
|
+
# Contact : info_at_agrum_dot_org #
|
|
36
|
+
# homepage : http://agrum.gitlab.io #
|
|
37
|
+
# gitlab : https://gitlab.com/agrumery/agrum #
|
|
38
|
+
# #
|
|
39
|
+
############################################################################
|
|
40
|
+
|
|
41
|
+
import IPython
|
|
42
|
+
import hashlib
|
|
43
|
+
import time
|
|
44
|
+
import csv
|
|
45
|
+
import math
|
|
46
|
+
|
|
47
|
+
import pyagrum as gum
|
|
48
|
+
import pyagrum.lib.notebook as gnb
|
|
49
|
+
import pyagrum.lib._colors as gumcols
|
|
50
|
+
import pydot as dot
|
|
51
|
+
import matplotlib.pyplot as plt
|
|
52
|
+
import matplotlib as mpl
|
|
53
|
+
import pyagrum.bnmixture as BNM
|
|
54
|
+
import numpy as np
|
|
55
|
+
|
|
56
|
+
from tempfile import mkdtemp
|
|
57
|
+
from matplotlib.colors import LinearSegmentedColormap
|
|
58
|
+
from matplotlib.backends.backend_agg import FigureCanvasAgg as fc
|
|
59
|
+
from base64 import encodebytes
|
|
60
|
+
from math import ceil
|
|
61
|
+
|
|
62
|
+
# colour ramp (light to dark green) used by the mixture notebook helpers to colour arcs by confidence
_colors = ["lightgreen", "lightseagreen", "green"]
_cmap1 = LinearSegmentedColormap.from_list(name="mycmap", colors=_colors)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def _compareBN(dotref: dot.Dot, bncmp: gum.BayesNet) -> dot.Dot:
  """
  Allow to create a dot representation of a BN with the nodes at the same position as in ``dotref``.

  Notes
  -----
  Considering that all the BNs of a BNMixture have the same variables, it is not necessary to have a reference BN in parameters.

  Parameters
  ----------
  dotref : pydot.Dot
      pydot graph to get positions from.
  bncmp : pyagrum.BayesNet
      BN to get pydot representation for.

  Returns
  -------
  pydot.Dot
      pydot representation of Bayesian Network ``bncmp`` with nodes at same position as ``dotref`` nodes.
  """
  # Graphviz "plain" output has one "node <name> <x> <y> ..." line per node; the trailing "!"
  # asks Graphviz to keep each node fixed at the position read from the reference layout.
  positions = {
    l[1]: f"{l[2]},{l[3]}!"
    for l in csv.reader(dotref.create(format="plain").decode("utf8").split("\n"), delimiter=" ", quotechar='"')
    if len(l) > 3 and l[0] == "node"
  }

  res = dot.Dot(graph_type="digraph", bgcolor="transparent", layout="fdp", splines=True)

  # one filled node per variable of bncmp, placed where the same variable sits in dotref
  for nid in bncmp.nodes():
    vname = bncmp.variable(nid).name()
    res.add_node(
      dot.Node(
        f'"{vname}"',
        style="filled",
        fillcolor=gum.config["bnmixture", "default_node_bgcolor"],
        fontcolor=gum.config["bnmixture", "default_node_fgcolor"],
        pos=positions[vname],
      )
    )

  # arcs of bncmp, drawn with the bnmixture default arc style and colour
  for tail, head in bncmp.arcs():
    res.add_edge(
      dot.Edge(
        f'"{bncmp.variable(tail).name()}"',
        f'"{bncmp.variable(head).name()}"',
        style=gum.config["bnmixture", "default_arc_style"],
        color=gum.config["bnmixture", "default_arc_color"],
        constraint="false",
      )
    )

  return res
|
|
124
|
+
|
|
125
|
+
|
|
126
|
+
def _compareBNinf(bnref: gum.BayesNet, refdot: dot.Dot, cmpDot: dot.Dot, scale=1.0):
  """
  Allow to modify a pydot graph of inference to have its nodes at the same positions as the nodes in the reference graph.

  ``cmpDot`` is modified in place (layout, splines and the ``pos`` of its nodes); nothing is returned.

  Parameters
  ----------
  bnref : pyagrum.BayesNet
      BN used as a reference to get name of variables.
  refdot : pydot.Dot
      dot used as a reference to get the position of the nodes.
  cmpDot : pydot.Dot
      dot of the graph to modify to have the nodes at same position as the nodes in ``refdot``.
  scale : float
      extra multiplicative factor applied to the reference coordinates (on top of the fixed x/y spreading factors),
      so that larger node images do not overlap.
  """
  # nop2 layout allows to give positions to nodes
  cmpDot.set("layout", "nop2")
  cmpDot.set("splines", "true")

  # loading positions of reference
  # positions are scaled
  x_scale = 80 * scale * 1.5  # 50 * scale * 2
  y_scale = 60 * scale * 1.3  # 50 * scale * 1.5
  positions = {
    l[1]: f"{float(l[2]) * x_scale},{float(l[3]) * y_scale}!"
    for l in csv.reader(refdot.create(format="plain").decode("utf8").split("\n"), delimiter=" ", quotechar='"')
    if len(l) > 3 and l[0] == "node"
  }

  # computed once instead of for every node of cmpDot
  known_names = set(bnref.names())

  # modifying positions
  for node in cmpDot.get_nodes():
    name = node.get_name()
    # Variable nodes are written quoted ('"<name>"') and pydot keeps the quotes.
    # Unquoted entries (e.g. the default-attribute statement "node") are skipped.
    # Cutting their first/last characters could otherwise collide with a variable name.
    if len(name) < 2 or not (name.startswith('"') and name.endswith('"')):
      continue
    namecut = name[1:-1]
    if namecut not in known_names:
      continue
    node.set("pos", positions[namecut])
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def getMixtureGraph(bnm: BNM.IMixture, size=None, ref=False):
  """
  HTML representation of a Mixture.

  The shared ``gnb.flow`` is cleared, then filled with one graph per BN of the mixture (optionally preceded by the
  reference BN), every graph using the node positions of the reference BN's layout.

  Parameters
  ----------
  bnm : pyagrum.bnmixture.IMixture
      Mixture to get graph from.
  size : str | int
      Size of the graph. If None, ``gum.config["bnmixture", "default_graph_size"]`` is used.
  ref : bool
      if True, the representation will contain the reference BN's graph.

  Returns
  -------
  the HTML produced by ``gnb.flow.html()``.
  """
  gnb.flow.clear()
  if size is None:
    size = gum.config["bnmixture", "default_graph_size"]

  # the reference BN's layout provides the node positions reused for every BN of the mixture
  dotref = gnb.BN2dot(bnm._refBN, size=size)
  if ref:
    # same size as the mixture members, so the reference is rendered consistently with them
    gnb.flow.add(gnb.getGraph(dotref, size=size), caption=f"reference BN : {bnm._refName}")

  for bn in bnm.BNs():
    name = bn.property("name")
    gnb.flow.add(gnb.getGraph(_compareBN(dotref, bn), size=size), caption=f"{name}, w={bnm.weight(name)}")

  return gnb.flow.html()
|
|
190
|
+
|
|
191
|
+
|
|
192
|
+
def showMixtureGraph(bnm: BNM.IMixture, size=None, ref=False):
  """
  Display a HTML representation of a Mixture (see ``getMixtureGraph``).

  Parameters
  ----------
  bnm : pyagrum.bnmixture.IMixture
      Mixture to get graph from.
  size : str | int
      Size of the graph.
  ref : bool
      if True, the representation will contain the reference BN's graph.
  """
  IPython.display.display(getMixtureGraph(bnm, size=size, ref=ref))
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
def BNMixtureInference2dot(
  bnm: BNM.BNMixture,
  engine=None,
  size=None,
  evs=None,
  targets=None,
  nodeColor=None,
  arcWidth=None,
  arcColor=None,
  cmapNode=None,
  cmapArc=None,
  dag=None,
):
  """
  Create a pydot representation of the inference graph of a BNM.BNMixture (average of all posteriors in the BNMixture).

  Parameters
  ----------
  bnm : pyagrum.bnmixture.BNMixture
      The Bayesian Net Mixture used.
  engine : pyagrum.Inference
      inference algorithm used. If None, LazyPropagation will be used. Note : this is an uninitialized class object.
  size : str
      size of the rendered graph
  evs : dict
      map of evidence
  targets : set
      set of targets. If targets={} then each node is a target
  nodeColor : dict
      a nodeMap of values to be shown as color nodes (with special color for 0 and 1); keys may be names or node ids
  arcWidth : dict
      a arcMap of values to be shown as bold arcs
  arcColor : dict
      a arcMap of values (between 0 and 1) to be shown as color of arcs
  cmapNode : ColorMap
      color map to show the vals of Nodes
  cmapArc : ColorMap
      color map to show the vals of Arcs
  dag : pyagrum.DAG
      only shows nodes that have their id in the dag.

  Returns
  -------
  pydot.Dot
      the inference graph (target nodes show a histogram image). Its ``temp_dir`` attribute holds the temporary
      directory created with ``mkdtemp``.
  """
  if evs is None:
    evs = {}
  if targets is None:
    targets = {}
  if cmapNode is None:
    cmapNode = plt.get_cmap(gum.config["notebook", "default_node_cmap"])

  if cmapArc is None:
    cmapArc = plt.get_cmap(gum.config["notebook", "default_arc_cmap"])

  # default range of arc values; an empty arcWidth must not reach min()/max() (ValueError)
  maxarcs = 100
  minarcs = 0

  if arcWidth:
    minarcs = min(arcWidth.values())
    maxarcs = max(arcWidth.values())

  startTime = time.time()
  if engine is None:
    ie = BNM.BNMixtureInference(bnm)
  else:
    ie = BNM.BNMixtureInference(bnm, engine=engine)

  ie.setEvidence(evs)
  ie.makeInference()
  stopTime = time.time()

  # NOTE(review): image filenames below are temp_dir + <hash> with no path separator.
  # mkdtemp returns a path without a trailing separator, so images are written beside temp_dir, not inside it.
  # Confirm this is intended before changing it.
  temp_dir = mkdtemp("", "tmp", None)  # with TemporaryDirectory() as temp_dir:

  dotstr = 'digraph structs {\n fontcolor="' + gumcols.getBlackInTheme() + '";bgcolor="transparent";'

  if gum.config.asBool["notebook", "show_inference_time"]:
    dotstr += f' label="Inference in {1000 * (stopTime - startTime):6.2f}ms";\n'

  fontname, fontsize = gumcols.fontFromMatplotlib()
  dotstr += f' node [fillcolor="{gum.config["notebook", "default_node_bgcolor"]}", style=filled,color="{gum.config["notebook", "default_node_fgcolor"]}",fontname="{fontname}",fontsize="{fontsize}"];\n'
  dotstr += f' edge [color="{gumcols.getBlackInTheme()}"];\n'

  showdag = bnm._refBN.dag() if dag is None else dag

  for nid in showdag.nodes():
    name = bnm._refBN.variable(nid).name()
    is_target = len(targets) == 0 or name in targets or nid in targets

    # defaults
    bgcol = gum.config["notebook", "default_node_bgcolor"]
    fgcol = gum.config["notebook", "default_node_fgcolor"]
    if is_target:
      bgcol = gum.config["notebook", "figure_facecolor"]

    if nodeColor is not None and (name in nodeColor or nid in nodeColor):
      # nodeColor may be keyed by name or by node id: read whichever key is present
      value = nodeColor[name] if name in nodeColor else nodeColor[nid]
      bgcol = gumcols.proba2bgcolor(value, cmapNode)
      fgcol = gumcols.proba2fgcolor(value, cmapNode)

    # 'hard' colour for evidence
    if name in evs or nid in evs:
      bgcol = gum.config["notebook", "evidence_bgcolor"]
      fgcol = gum.config["notebook", "evidence_fgcolor"]

    colorattribute = f'fillcolor="{bgcol}", fontcolor="{fgcol}", color="#000000"'
    if is_target:
      filename = temp_dir + hashlib.md5(name.encode()).hexdigest() + "." + gum.config["notebook", "graph_format"]
      saveFigProba(ie, name, filename, bgcolor=bgcol, scale=float(gum.config["bnmixture", "default_histo_scale"]))
      dotstr += f' "{name}" [shape=rectangle,image="{filename}",label="", {colorattribute}];\n'
    else:
      dotstr += f' "{name}" [{colorattribute}];\n'

  for a in showdag.arcs():
    (n, j) = a
    pw = 1
    av = f"{n} → {j}"
    col = gumcols.getBlackInTheme()

    if arcWidth is not None and a in arcWidth:
      if maxarcs != minarcs:
        pw = 0.1 + 5 * (arcWidth[a] - minarcs) / (maxarcs - minarcs)
      av = f"{n} → {j} : {arcWidth[a]}"

    if arcColor is not None and a in arcColor:
      col = gumcols.proba2color(arcColor[a], cmapArc)

    dotstr += f' "{bnm._refBN.variable(n).name()}"->"{bnm._refBN.variable(j).name()}" [penwidth="{pw}",tooltip="{av}",color="{col}"];'

  dotstr += "}"

  g = dot.graph_from_dot_data(dotstr)[0]

  # workaround for some badly parsed graph (pyparsing>=3.03)
  g.del_node('"\\n"')

  if size is None:
    size = gum.config["notebook", "default_graph_inference_size"]
  g.set_size(size)
  g.temp_dir = temp_dir

  return g
|
|
348
|
+
|
|
349
|
+
|
|
350
|
+
def BootstrapInference2dot(
  bnm: BNM.BootstrapMixture,
  size=None,
  engine=None,
  evs=None,
  targets=None,
  nodeColor=None,
  arcWidth=None,
  arcColor=None,
  cmapNode=None,
  cmapArc=None,
  dag=None,
  quantiles=False,
  show_mu_sigma=False,
):
  """
  create a pydot representation of an inference in a BootstrapMixture (reference BN's posterior is used, while other BNs are used to compute stats).

  Parameters
  ----------
  bnm : pyagrum.bnmixture.BootstrapMixture
      the Mixture.
  size: str
      size of the rendered graph
  engine: pyagrum.Inference
      inference algorithm used. If None, LazyPropagation will be used. This is the class, not the initialized object.
  evs: dict
      map of evidence
  targets: set
      set of targets. If targets={} then each node is a target
  nodeColor: dict
      a nodeMap of values to be shown as color nodes (with special color for 0 and 1); keys may be names or node ids
  arcWidth: dict
      a arcMap of values to be shown as bold arcs
  arcColor: dict
      a arcMap of values (between 0 and 1) to be shown as color of arcs
  cmapNode: ColorMap
      color map to show the vals of Nodes
  cmapArc: ColorMap
      color map to show the vals of Arcs
  dag : pyagrum.DAG
      only shows nodes that have their id in the dag (and not in the whole BN)
  quantiles : bool
      if True, shows quantiles on tensors. Quantiles default values can be set using pyagrum.config.
  show_mu_sigma : bool
      forwarded to ``saveFigProba`` for every target histogram (presumably adds mean / standard deviation).

  Returns
  -------
  pydot.Dot
      the desired representation of the inference. Its ``temp_dir`` attribute holds the temporary directory created
      with ``mkdtemp``.
  """
  if evs is None:
    evs = {}
  if targets is None:
    targets = {}
  if cmapNode is None:
    cmapNode = plt.get_cmap(gum.config["notebook", "default_node_cmap"])

  if cmapArc is None:
    cmapArc = plt.get_cmap(gum.config["notebook", "default_arc_cmap"])

  # default range of arc values; an empty arcWidth must not reach min()/max() (ValueError)
  maxarcs = 100
  minarcs = 0

  if arcWidth:
    minarcs = min(arcWidth.values())
    maxarcs = max(arcWidth.values())

  startTime = time.time()
  if engine is None:
    ie = BNM.BootstrapMixtureInference(bnm)
  else:
    ie = BNM.BootstrapMixtureInference(bnm, engine=engine)
  ie.setEvidence(evs)
  ie.makeInference()
  stopTime = time.time()

  # NOTE(review): image filenames below are temp_dir + <hash> with no path separator.
  # mkdtemp returns a path without a trailing separator, so images are written beside temp_dir, not inside it.
  # Confirm this is intended before changing it.
  temp_dir = mkdtemp("", "tmp", None)  # with TemporaryDirectory() as temp_dir:

  dotstr = 'digraph structs {\n fontcolor="' + gumcols.getBlackInTheme() + '";bgcolor="transparent";'

  # graph label: optional inference time, then optional quantile bounds (in percent)
  lab = ""
  if gum.config.asBool["notebook", "show_inference_time"]:
    lab += f"Inference in {1000 * (stopTime - startTime):6.2f}ms"

  if quantiles:
    q1 = float(gum.config["bnmixture", "left_quantile"]) * 100
    q2 = float(gum.config["bnmixture", "right_quantile"]) * 100
    title = f"\nquantiles=[{q1:.1f}%, {q2:.1f}%]"
    lab += f"\n{title}"
  dotstr += f' label="{lab}";\n'

  fontname, fontsize = gumcols.fontFromMatplotlib()
  dotstr += f' node [fillcolor="{gum.config["notebook", "default_node_bgcolor"]}", style=filled,color="{gum.config["notebook", "default_node_fgcolor"]}",fontname="{fontname}",fontsize="{fontsize}"];\n'
  dotstr += f' edge [color="{gumcols.getBlackInTheme()}"];\n'

  showdag = bnm._refBN.dag() if dag is None else dag

  for nid in showdag.nodes():
    name = bnm.variable(nid).name()
    is_target = len(targets) == 0 or name in targets or nid in targets

    # defaults
    bgcol = gum.config["notebook", "default_node_bgcolor"]
    fgcol = gum.config["notebook", "default_node_fgcolor"]
    if is_target:
      bgcol = gum.config["notebook", "figure_facecolor"]

    if nodeColor is not None and (name in nodeColor or nid in nodeColor):
      # nodeColor may be keyed by name or by node id: read whichever key is present
      value = nodeColor[name] if name in nodeColor else nodeColor[nid]
      bgcol = gumcols.proba2bgcolor(value, cmapNode)
      fgcol = gumcols.proba2fgcolor(value, cmapNode)

    # 'hard' colour for evidence
    if name in evs or nid in evs:
      bgcol = gum.config["notebook", "evidence_bgcolor"]
      fgcol = gum.config["notebook", "evidence_fgcolor"]

    colorattribute = f'fillcolor="{bgcol}", fontcolor="{fgcol}", color="#000000"'
    if is_target:
      filename = temp_dir + hashlib.md5(name.encode()).hexdigest() + "." + gum.config["notebook", "graph_format"]
      saveFigProba(
        ie,
        name,
        filename,
        bgcolor=bgcol,
        quantiles=quantiles,
        scale=float(gum.config["bnmixture", "default_boot_histo_scale"]),
        show_mu_sigma=show_mu_sigma,
      )
      dotstr += f' "{name}" [shape=rectangle,image="{filename}",label="", {colorattribute}];\n'
    else:
      dotstr += f' "{name}" [{colorattribute}];\n'

  for a in showdag.arcs():
    (n, j) = a
    pw = 1
    av = f"{n} → {j}"
    col = gumcols.getBlackInTheme()

    if arcWidth is not None and a in arcWidth:
      if maxarcs != minarcs:
        pw = 0.1 + 5 * (arcWidth[a] - minarcs) / (maxarcs - minarcs)
      av = f"{n} → {j} : {arcWidth[a]}"

    if arcColor is not None and a in arcColor:
      col = gumcols.proba2color(arcColor[a], cmapArc)

    dotstr += f' "{bnm.variable(n).name()}"->"{bnm.variable(j).name()}" [penwidth="{pw}",tooltip="{av}",color="{col}"];'

  dotstr += "}"

  g = dot.graph_from_dot_data(dotstr)[0]

  # workaround for some badly parsed graph (pyparsing>=3.03)
  g.del_node('"\\n"')

  if size is None:
    size = gum.config["notebook", "default_graph_inference_size"]
  g.set_size(size)
  g.temp_dir = temp_dir

  return g
|
|
510
|
+
|
|
511
|
+
|
|
512
|
+
def showBNMixtureInference(
  bnm: BNM.BNMixture,
  engine=None,
  size=None,
  evs=None,
  targets=None,
  nodeColor=None,
  arcWidth=None,
  arcColor=None,
  cmapNode=None,
  cmapArc=None,
  dag=None,
):
  """
  Displays a HTML representation of the inference graph of a BNM.BNMixture (average of all posteriors in the BNMixture).

  The graph built by ``BNMixtureInference2dot`` is repositioned on the reference BN's layout before being displayed.

  Parameters
  ----------
  bnm : pyagrum.bnmixture.BNMixture
      The Bayesian Net Mixture used.
  size : str
      size of the rendered graph
  engine : pyagrum.Inference
      inference algorithm used. If None, LazyPropagation will be used. Note : this is an uninitialized class object.
  evs : dict
      map of evidence
  targets : set
      set of targets. If targets={} then each node is a target
  nodeColor : dict
      a nodeMap of values to be shown as color nodes (with special color for 0 and 1)
  arcWidth : dict
      a arcMap of values to be shown as bold arcs
  arcColor : dict
      a arcMap of values (between 0 and 1) to be shown as color of arcs
  cmapNode : ColorMap
      color map to show the vals of Nodes
  cmapArc : ColorMap
      color map to show the vals of Arcs
  dag : pyagrum.DAG
      only shows nodes that have their id in the dag.
  """
  graph = BNMixtureInference2dot(
    bnm,
    engine=engine,
    size=size,
    evs=evs,
    targets=targets,
    nodeColor=nodeColor,
    arcWidth=arcWidth,
    arcColor=arcColor,
    cmapNode=cmapNode,
    cmapArc=cmapArc,
    dag=dag,  # previously dropped: the documented dag filter was silently ignored
  )
  refdot = gnb.BN2dot(bnm._refBN)
  _compareBNinf(bnm._refBN, refdot, graph, scale=float(gum.config["bnmixture", "default_histo_scale"]))
  IPython.display.display(graph)
|
|
568
|
+
|
|
569
|
+
|
|
570
|
+
def showBootstrapMixtureInference(
  bnm: BNM.BootstrapMixture,
  engine=None,
  size=None,
  evs=None,
  targets=None,
  nodeColor=None,
  arcWidth=None,
  arcColor=None,
  cmapNode=None,
  cmapArc=None,
  dag=None,
  quantiles=False,
  show_mu_sigma=False,
):
  """
  Displays a HTML representation of the inference graph of a BNM.BootstrapMixture (see ``BootstrapInference2dot``).

  The graph built by ``BootstrapInference2dot`` is repositioned on the reference BN's layout before being displayed.

  Parameters
  ----------
  bnm : pyagrum.bnmixture.BootstrapMixture
      The Bayesian Net Mixture used.
  size : str
      size of the rendered graph
  engine : pyagrum.Inference
      inference algorithm used. If None, LazyPropagation will be used. Note : this is an uninitialized class object.
  evs : dict
      map of evidence
  targets : set
      set of targets. If targets={} then each node is a target
  nodeColor : dict
      a nodeMap of values to be shown as color nodes (with special color for 0 and 1)
  arcWidth : dict
      a arcMap of values to be shown as bold arcs
  arcColor : dict
      a arcMap of values (between 0 and 1) to be shown as color of arcs
  cmapNode : ColorMap
      color map to show the vals of Nodes
  cmapArc : ColorMap
      color map to show the vals of Arcs
  dag : pyagrum.DAG
      only shows nodes that have their id in the dag.
  quantiles : bool
      if True, shows quantiles on tensors. Quantiles default values can be set using pyagrum.config.
  show_mu_sigma : bool
      forwarded to ``BootstrapInference2dot``.
  """
  graph = BootstrapInference2dot(
    bnm,
    engine=engine,
    size=size,
    evs=evs,
    targets=targets,
    nodeColor=nodeColor,
    arcWidth=arcWidth,
    arcColor=arcColor,
    cmapNode=cmapNode,
    cmapArc=cmapArc,
    dag=dag,  # previously dropped: the documented dag filter was silently ignored
    quantiles=quantiles,
    show_mu_sigma=show_mu_sigma,
  )
  refdot = gnb.BN2dot(bnm._refBN)
  _compareBNinf(bnm._refBN, refdot, graph, scale=float(gum.config["bnmixture", "default_boot_histo_scale"]))
  IPython.display.display(graph)
|
|
632
|
+
|
|
633
|
+
|
|
634
|
+
def _normalizedArcsWeight(bnm: BNM.IMixture):
  """
  Weighted arc frequencies over the BNs of a mixture.

  Every arc found in a BN of the mixture contributes that BN's weight; each total is then divided by the sum of
  the mixture's weights (``bnm._weights``).

  Parameters
  ----------
  bnm : pyagrum.bnmixture.IMixture
      the mixture whose BNs are scanned.

  Returns
  -------
  tuple
      ``(countArcs, mi, ma)``:

      - ``countArcs[tail][head]`` is the normalized weight of the arc ``tail -> head``, with one entry for every
        ordered pair of distinct variable names of the reference BN;
      - ``mi`` is the smallest non-zero normalized weight (it stays 1 if no value is lower);
      - ``ma`` is the largest normalized weight (it stays 0 when no arc is found).
  """
  ref_names = bnm._refBN.names()
  countArcs = {}
  for tail_name in ref_names:
    countArcs[tail_name] = {head_name: 0.0 for head_name in ref_names if head_name != tail_name}
  total_weight = sum(bnm._weights.values())

  # accumulate, for each arc, the weights of the BNs that contain it
  for member in bnm.names():
    member_bn = bnm.BN(member)
    member_weight = bnm.weight(member)
    for tail_id, head_id in member_bn.arcs():
      countArcs[member_bn.variable(tail_id).name()][member_bn.variable(head_id).name()] += member_weight

  # normalize in place while tracking the extreme values
  lowest, highest = 1, 0
  for heads in countArcs.values():
    for head_name in heads:
      ratio = heads[head_name] / total_weight
      heads[head_name] = ratio
      if ratio != 0 and ratio < lowest:
        lowest = ratio
      if ratio > highest:
        highest = ratio

  return countArcs, lowest, highest
|
|
663
|
+
|
|
664
|
+
|
|
665
|
+
def _compareArcs2dot(bnm: BNM.IMixture, size=None, refStruct=False):
    """
    Pydot representation of a graph that shows confidence value for every arc in the mixture.

    Parameters
    ----------
    bnm : BNM.IMixture
        The mixture.
    size : str or None
        Size forwarded to ``gnb.BN2dot`` when computing the layout of the reference BN
        (only used when ``refStruct`` is True).
    refStruct : bool
        If True, every node is pinned at twice its coordinates in the layout of the reference BN.

    Returns
    -------
    tuple(pydot.Dot, float, float)
        The graph, then the minimum non-zero and the maximum normalized arc weights.
    """
    countArcs, mi, ma = _normalizedArcsWeight(bnm)
    layout = gum.config["bnmixture", "default_layout"]

    if refStruct:
        # graphviz is only run on the reference BN when its node positions are actually needed
        g = gnb.BN2dot(bnm._refBN, size=size)
        plain_lines = g.create(format="plain").decode("utf8").split("\n")
        positions = {
            row[1]: f"{float(row[2]) * 2},{float(row[3]) * 2}!"
            for row in csv.reader(plain_lines, delimiter=" ", quotechar='"')
            if len(row) > 3 and row[0] == "node"
        }
        res = dot.Dot(graph_type="digraph", bgcolor="transparent", layout=layout, splines=True)
    else:
        positions = None
        res = dot.Dot(
            graph_type="digraph",
            bgcolor="transparent",
            layout=layout,
            splines=True,
            overlap_scaling=gum.config["bnmixture", "default_overlap"],
            sep=3,
        )

    node_bgcolor = gum.config["bnmixture", "default_node_bgcolor"]
    node_fgcolor = gum.config["bnmixture", "default_node_fgcolor"]
    for vname in bnm._refBN.names():
        pos = positions[vname] if refStruct else ""
        res.add_node(dot.Node(f'"{vname}"', style="filled", fillcolor=node_bgcolor, fontcolor=node_fgcolor, pos=pos))

    correct_style = gum.config["bnmixture", "correct_arc_style"]
    incorrect_style = gum.config["bnmixture", "incorrect_arc_style"]
    arrow_type = gum.config["bnmixture", "default_arrow_type"]
    # config values are strings: convert before scaling, otherwise ``str * int`` repeats the text
    head_size = float(gum.config["bnmixture", "default_head_size"])

    for n1 in bnm._refBN.names():
        for n2 in bnm._refBN.names():
            if n1 == n2 or countArcs[n1][n2] == 0:
                continue
            weight = countArcs[n1][n2]
            style = correct_style if bnm._refBN.existsArc(n1, n2) else incorrect_style
            # the weight is capped at 0.99 before being mapped to a color
            col = gumcols.proba2color(min(weight, 0.99), _cmap1)
            width = ceil(weight * 6)
            res.add_edge(
                dot.Edge(
                    f'"{n1}"',
                    f'"{n2}"',
                    style=style,
                    color=col,
                    penwidth=width,
                    arrowhead=arrow_type,
                    arrowsize=head_size * width,
                    constraint="false",
                )
            )

    return res, mi, ma
|
|
731
|
+
|
|
732
|
+
|
|
733
|
+
def getComparison(bnm, size=None, refStruct=False):
    """
    Build the dot graph showing the confidence value of every arc in the mixture.

    Parameters
    ----------
    bnm : the mixture
    size : the size of the dot figure
    refStruct : do we use a reference structure

    Returns
    -------
    the dot representation (the min/max confidence values are discarded)
    """
    return _compareArcs2dot(bnm, size=size, refStruct=refStruct)[0]
|
|
749
|
+
|
|
750
|
+
|
|
751
|
+
def showComparison(bnm, size=None, refStruct=False):
    """
    Display the graph showing the confidence value of every arc in the mixture.

    Parameters
    ----------
    bnm : the mixture
    size : the size of the dot figure
    refStruct : do we use a reference structure
    """
    graph = getComparison(bnm, size=size, refStruct=refStruct)
    gnb.show(graph)
|
|
762
|
+
|
|
763
|
+
|
|
764
|
+
def getArcsComparison(bnm, size=None, refStruct=False):
    """
    HTML representation of the graph showing the confidence value of every arc in the mixture.

    Below the graph, a horizontal "Confidence" colorbar marks (in red) the minimum non-zero and the
    maximum confidence values, which are also written as text.

    Parameters
    ----------
    bnm : the mixture
    size : the size of the dot figure
    refStruct : do we use a reference structure

    Returns
    -------
    str
        A centered HTML ``<div>`` holding the SVG graph and the PNG colorbar legend.
    """
    graph, lowest, highest = _compareArcs2dot(bnm, size=size, refStruct=refStruct)
    svg = IPython.display.SVG(graph.create_svg(encoding="utf-8"))

    # width of the SVG (read from its ``width="...pt"`` attribute) converted with the figure dpi, at least 5
    svg_width = int(svg.data.split("width=")[1].split('"')[1].split("pt")[0])
    width = max(svg_width / mpl.pyplot.rcParams["figure.dpi"], 5)

    fig = mpl.figure.Figure(figsize=(width, 1))
    fig.patch.set_alpha(0)
    canvas = fc(fig)
    colorbar_axis = fig.add_axes([0.05, 0.80, 0.9, 0.15])
    colorbar = mpl.colorbar.ColorbarBase(
        colorbar_axis, cmap=_cmap1, norm=mpl.colors.Normalize(), orientation="horizontal"
    )
    colorbar.set_label("Confidence")

    low_color = gumcols.proba2bgcolor(max(lowest, 0.01), _cmap1)
    high_color = gumcols.proba2bgcolor(min(highest, 0.99), _cmap1)
    colorbar.ax.text(0.1, -3, f"min {lowest:.4f}", ha="left", va="top", color=low_color)
    colorbar.ax.text(0.9, -3, f"max {highest:.4f}", ha="right", va="top", color=high_color)
    for mark in (lowest, highest):
        colorbar.ax.text(mark, 1, "|", ha="center", va="top", color="red")

    png = IPython.core.pylabtools.print_figure(canvas.figure, "png")
    encoded = encodebytes(png).decode("ascii")
    legend = f"<img style='vertical-align:middle' src='data:image/png;base64,{encoded}'>"

    return f"<div align='center'>{svg.data}<br/>{legend}</div>"
|
|
793
|
+
|
|
794
|
+
|
|
795
|
+
def showArcsComparison(bnm, size=None, refStruct=False):
    """
    Display the graph showing the confidence value of every arc in the mixture, together with a
    confidence axis on which the minimum and maximum values are marked.
    """
    html = IPython.display.HTML(getArcsComparison(bnm, size=size, refStruct=refStruct))
    return IPython.display.display(html)
|
|
802
|
+
|
|
803
|
+
|
|
804
|
+
def arcsCompLegend():
    """
    Dot legend for the arc-comparison graphs.

    Returns
    -------
    pydot.Dot
        Two labelled edges between invisible nodes, laid out left to right: "Present in reference"
        drawn with the ``correct_arc_style`` style and "Absent from reference" drawn with the
        ``incorrect_arc_style`` style.
    """
    legend = dot.Dot(graph_type="digraph", bgcolor="transparent", rankdir="LR")
    for node_name in ("a", "b", "c", "d"):
        legend.add_node(dot.Node(node_name, style="invis"))

    entries = (
        ("a", "b", "Present in reference", "correct_arc_style"),
        ("c", "d", "Absent from reference", "incorrect_arc_style"),
    )
    for tail, head, label, style_key in entries:
        legend.add_edge(
            dot.Edge(
                tail,
                head,
                label=label,
                style=gum.config["bnmixture", style_key],
                # both entries use the same color: only the line style distinguishes them
                color=gum.config["bnmixture", "correct_arc_color"],
            )
        )

    return legend
|
|
828
|
+
|
|
829
|
+
|
|
830
|
+
####################### tool box for quantiles #######################
|
|
831
|
+
def tensor2ref(ref, tens) -> "pyagrum.Tensor":
    """
    Returns a copy of ``tens`` whose variables are the ones of ``ref`` (looked up by name).
    This allows summing tensors that have the same variables but with different instances of them.

    Parameters
    ----------
    ref : pyagrum.Tensor
      Tensor containing variables of reference.
    tens : pyagrum.Tensor
      Tensor to convert.

    Returns
    -------
    pyagrum.Tensor
      The converted tensor with values of ``tens`` and variable references of ``ref``.
    """
    # same variable order as ``tens``, but taken from ``ref``
    ref_variables = [ref.variable(name) for name in tens.names]
    converted = pyagrum.Tensor()
    for variable in ref_variables:
        converted.add(variable)
    return converted.fillWith(tens)
|
|
852
|
+
|
|
853
|
+
|
|
854
|
+
def _stats(tens):
    """
    Returns mean and standard deviation of the distribution held by ``tens``.

    Parameters
    ----------
    tens : pyagrum.Tensor
      Tensor expected to be over a single variable; the value of its i-th label is given by
      ``tens.variable(0).numerical(i)``.

    Returns
    -------
    tuple(float, float)
      ``(mu, sigma)``.
    """
    var = tens.variable(0)
    mu = 0.0
    mu2 = 0.0
    for i, p in enumerate(tens.tolist()):
        x = var.numerical(i)
        mu += p * x
        mu2 += p * x * x
    # E[X^2] - E[X]^2 may round to a tiny negative number for (almost) degenerate distributions,
    # which would make math.sqrt raise a "math domain error"
    return mu, math.sqrt(max(mu2 - mu * mu, 0.0))
|
|
866
|
+
|
|
867
|
+
|
|
868
|
+
def _getTitleHisto(p, show_mu_sigma=False):
    """
    Return title of a histogram: the name of the first variable of ``p``, followed, when
    ``show_mu_sigma`` is True and the standard deviation is positive, by a second line giving
    the mean and the standard deviation.
    """
    title = p.variable(0).name()
    if show_mu_sigma:
        mu, std = _stats(p)
        if std > 0.0:
            return f"{title}\n$\\mu={mu:.2f}$; $\\sigma={std:.2f}$"
    return title
|
|
881
|
+
|
|
882
|
+
|
|
883
|
+
def _getProbaH(ie, var_name, scale=1.0, txtcolor="black", quantiles=False, show_mu_sigma=False):
    """
    Compute the representation of a horizontal histogram of a variable posterior

    Parameters
    ----------
    ie : IMixtureInference
      Inference of a mixture.
    var_name : str
      Name of the variable to get histogram for.
    scale : float
      scale for the size of histograms
    txtcolor : str
      color for text
    quantiles : bool
      shows quantiles (as asymmetric error bars around the mean posterior).
    show_mu_sigma : bool
      shows mu and sigma.

    Returns
    -------
    matplotlib.Figure
      A matplotlib histogram of the posterior.
    """
    # every posterior of the mixture is normalized in place; the tensors are not used afterwards.
    # NOTE(review): presumably this matters for ie.posterior/ie.quantiles below -- confirm
    pots = ie._posteriors(var_name)
    for tens in pots.values():
        tens.normalize()
    avrg = ie.posterior(var_name)
    var = avrg.variable(0)
    ra = np.arange(var.domainSize())

    # bars are drawn bottom-up: labels and values are reversed so that the first label is at the top
    vx = [var.label(int(i)) for i in np.arange(var.domainSize() - 1, -1, -1)]

    fig = plt.figure()
    fig.set_figheight(scale * var.domainSize() / 4.0)
    fig.set_figwidth(scale * 2)

    ax = fig.add_subplot(111)
    ax.set_facecolor("white")

    if gum.config.asBool["notebook", "histogram_use_percent"]:
        perc, suffix = 100, "%"
    else:
        perc, suffix = 1, ""

    bar_options = {}
    if quantiles:
        pmin, pmax = ie.quantiles(var_name)
        vmin = pmin.tolist()
        vmin.reverse()
        vmax = pmax.tolist()
        vmax.reverse()
    vmean = avrg.tolist()
    vmean.reverse()
    if quantiles:
        # matplotlib expects asymmetric errors with shape (2, N): first the lower errors, then the upper ones
        bar_options["xerr"] = [
            [abs(mean - lo) for lo, mean in zip(vmin, vmean)],
            [abs(hi - mean) for hi, mean in zip(vmax, vmean)],
        ]
        bar_options["capsize"] = float(gum.config["bnmixture", "default_bar_capsize"]) * scale

    barmean = ax.barh(
        ra,
        vmean,
        align="center",
        height=float(gum.config["bnmixture", "default_bar_height"]),
        color=gum.config["notebook", "histogram_color"],
        **bar_options,
    )

    digits = gum.config.asInt["notebook", "histogram_horizontal_visible_digits"]
    for b in barmean:
        bar_width = b.get_width()
        txt = f"{bar_width * perc:.{digits}f}{suffix}"
        # wide bars get their value right-aligned inside the bar (in white), narrow ones just after the bar
        if bar_width >= 0.2 * (2 / scale):
            ax.text(bar_width, b.get_y(), txt, ha="right", va="bottom", fontsize=10, color="white")
        else:
            ax.text(bar_width, b.get_y(), txt, ha="left", va="bottom", fontsize=10)

    ax.set_xlim(0, 1)
    ax.set_yticks(np.arange(var.domainSize()))
    ax.set_yticklabels(vx, color=txtcolor)
    ax.set_xticklabels([])
    ax.set_title(_getTitleHisto(avrg, show_mu_sigma=show_mu_sigma), color=txtcolor)
    ax.get_xaxis().grid(True)
    ax.margins(0)

    return fig
|
|
988
|
+
|
|
989
|
+
|
|
990
|
+
def proba2histo(ie, var_name, scale=1.0, txtcolor="Black", quantiles=False, show_mu_sigma=False):
    """
    Build a horizontal histogram of the posterior of a variable in a mixture inference.

    Parameters
    ----------
    ie : IMixtureInference
      Inference of a mixture.
    var_name : str
      Name of the variable to get histogram for.
    scale : float
      scale for the size of histograms
    txtcolor : str
      color for text
    quantiles : bool
      shows quantiles.
    show_mu_sigma : bool
      shows mu and sigma.

    Returns
    -------
    matplotlib.Figure
      A matplotlib histogram of the posterior.
    """
    options = {
        "scale": scale,
        "txtcolor": txtcolor,
        "quantiles": quantiles,
        "show_mu_sigma": show_mu_sigma,
    }
    return _getProbaH(ie, var_name, **options)
|
|
1015
|
+
|
|
1016
|
+
|
|
1017
|
+
def saveFigProba(
    ie, var_name, filename, bgcolor=None, txtcolor="Black", quantiles=False, scale=1.0, show_mu_sigma=False
):
    """
    Saves a figure which is the representation of a histogram for a posterior.

    Parameters
    ----------
    ie : IMixtureInference
      Inference of a mixture.
    var_name : str
      Name of the variable to get histogram for.
    filename: str
      the name of the saved file
    bgcolor: str
      color for background (``gum.config["notebook", "figure_facecolor"]`` if None)
    txtcolor : str
      color for text
    scale : float
      scale for the size of histograms
    quantiles : bool
      shows quantiles.
    show_mu_sigma : bool
      shows mu and sigma.
    """
    fig = proba2histo(ie, var_name, txtcolor=txtcolor, quantiles=quantiles, scale=scale, show_mu_sigma=show_mu_sigma)

    # named ``facecolor`` rather than ``fc`` so the module-level ``fc`` (the figure canvas used by
    # getArcsComparison) is not shadowed
    facecolor = gum.config["notebook", "figure_facecolor"] if bgcolor is None else bgcolor

    try:
        fig.savefig(
            filename,
            bbox_inches="tight",
            transparent=False,
            facecolor=facecolor,
            pad_inches=0.05,
            dpi=fig.dpi,
            format=gum.config["notebook", "graph_format"],
        )
    finally:
        # close the figure even when saving fails, so it does not stay registered with pyplot
        plt.close(fig)
|