pyAgrum-nightly 2.3.1.9.dev202512261765915415__cp310-abi3-macosx_10_15_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107) hide show
  1. pyagrum/__init__.py +165 -0
  2. pyagrum/_pyagrum.so +0 -0
  3. pyagrum/bnmixture/BNMInference.py +268 -0
  4. pyagrum/bnmixture/BNMLearning.py +376 -0
  5. pyagrum/bnmixture/BNMixture.py +464 -0
  6. pyagrum/bnmixture/__init__.py +60 -0
  7. pyagrum/bnmixture/notebook.py +1058 -0
  8. pyagrum/causal/_CausalFormula.py +280 -0
  9. pyagrum/causal/_CausalModel.py +436 -0
  10. pyagrum/causal/__init__.py +81 -0
  11. pyagrum/causal/_causalImpact.py +356 -0
  12. pyagrum/causal/_dSeparation.py +598 -0
  13. pyagrum/causal/_doAST.py +761 -0
  14. pyagrum/causal/_doCalculus.py +361 -0
  15. pyagrum/causal/_doorCriteria.py +374 -0
  16. pyagrum/causal/_exceptions.py +95 -0
  17. pyagrum/causal/_types.py +61 -0
  18. pyagrum/causal/causalEffectEstimation/_CausalEffectEstimation.py +1175 -0
  19. pyagrum/causal/causalEffectEstimation/_IVEstimators.py +718 -0
  20. pyagrum/causal/causalEffectEstimation/_RCTEstimators.py +132 -0
  21. pyagrum/causal/causalEffectEstimation/__init__.py +46 -0
  22. pyagrum/causal/causalEffectEstimation/_backdoorEstimators.py +774 -0
  23. pyagrum/causal/causalEffectEstimation/_causalBNEstimator.py +324 -0
  24. pyagrum/causal/causalEffectEstimation/_frontdoorEstimators.py +396 -0
  25. pyagrum/causal/causalEffectEstimation/_learners.py +118 -0
  26. pyagrum/causal/causalEffectEstimation/_utils.py +466 -0
  27. pyagrum/causal/notebook.py +172 -0
  28. pyagrum/clg/CLG.py +658 -0
  29. pyagrum/clg/GaussianVariable.py +111 -0
  30. pyagrum/clg/SEM.py +312 -0
  31. pyagrum/clg/__init__.py +63 -0
  32. pyagrum/clg/canonicalForm.py +408 -0
  33. pyagrum/clg/constants.py +54 -0
  34. pyagrum/clg/forwardSampling.py +202 -0
  35. pyagrum/clg/learning.py +776 -0
  36. pyagrum/clg/notebook.py +480 -0
  37. pyagrum/clg/variableElimination.py +271 -0
  38. pyagrum/common.py +60 -0
  39. pyagrum/config.py +319 -0
  40. pyagrum/ctbn/CIM.py +513 -0
  41. pyagrum/ctbn/CTBN.py +573 -0
  42. pyagrum/ctbn/CTBNGenerator.py +216 -0
  43. pyagrum/ctbn/CTBNInference.py +459 -0
  44. pyagrum/ctbn/CTBNLearner.py +161 -0
  45. pyagrum/ctbn/SamplesStats.py +671 -0
  46. pyagrum/ctbn/StatsIndepTest.py +355 -0
  47. pyagrum/ctbn/__init__.py +79 -0
  48. pyagrum/ctbn/constants.py +54 -0
  49. pyagrum/ctbn/notebook.py +264 -0
  50. pyagrum/defaults.ini +199 -0
  51. pyagrum/deprecated.py +95 -0
  52. pyagrum/explain/_ComputationCausal.py +75 -0
  53. pyagrum/explain/_ComputationConditional.py +48 -0
  54. pyagrum/explain/_ComputationMarginal.py +48 -0
  55. pyagrum/explain/_CustomShapleyCache.py +110 -0
  56. pyagrum/explain/_Explainer.py +176 -0
  57. pyagrum/explain/_Explanation.py +70 -0
  58. pyagrum/explain/_FIFOCache.py +54 -0
  59. pyagrum/explain/_ShallCausalValues.py +204 -0
  60. pyagrum/explain/_ShallConditionalValues.py +155 -0
  61. pyagrum/explain/_ShallMarginalValues.py +155 -0
  62. pyagrum/explain/_ShallValues.py +296 -0
  63. pyagrum/explain/_ShapCausalValues.py +208 -0
  64. pyagrum/explain/_ShapConditionalValues.py +126 -0
  65. pyagrum/explain/_ShapMarginalValues.py +191 -0
  66. pyagrum/explain/_ShapleyValues.py +298 -0
  67. pyagrum/explain/__init__.py +81 -0
  68. pyagrum/explain/_explGeneralizedMarkovBlanket.py +152 -0
  69. pyagrum/explain/_explIndependenceListForPairs.py +146 -0
  70. pyagrum/explain/_explInformationGraph.py +264 -0
  71. pyagrum/explain/notebook/__init__.py +54 -0
  72. pyagrum/explain/notebook/_bar.py +142 -0
  73. pyagrum/explain/notebook/_beeswarm.py +174 -0
  74. pyagrum/explain/notebook/_showShapValues.py +97 -0
  75. pyagrum/explain/notebook/_waterfall.py +220 -0
  76. pyagrum/explain/shapley.py +225 -0
  77. pyagrum/lib/__init__.py +46 -0
  78. pyagrum/lib/_colors.py +390 -0
  79. pyagrum/lib/bn2graph.py +299 -0
  80. pyagrum/lib/bn2roc.py +1026 -0
  81. pyagrum/lib/bn2scores.py +217 -0
  82. pyagrum/lib/bn_vs_bn.py +605 -0
  83. pyagrum/lib/cn2graph.py +305 -0
  84. pyagrum/lib/discreteTypeProcessor.py +1102 -0
  85. pyagrum/lib/discretizer.py +58 -0
  86. pyagrum/lib/dynamicBN.py +390 -0
  87. pyagrum/lib/explain.py +57 -0
  88. pyagrum/lib/export.py +84 -0
  89. pyagrum/lib/id2graph.py +258 -0
  90. pyagrum/lib/image.py +387 -0
  91. pyagrum/lib/ipython.py +307 -0
  92. pyagrum/lib/mrf2graph.py +471 -0
  93. pyagrum/lib/notebook.py +1821 -0
  94. pyagrum/lib/proba_histogram.py +552 -0
  95. pyagrum/lib/utils.py +138 -0
  96. pyagrum/pyagrum.py +31495 -0
  97. pyagrum/skbn/_MBCalcul.py +242 -0
  98. pyagrum/skbn/__init__.py +49 -0
  99. pyagrum/skbn/_learningMethods.py +282 -0
  100. pyagrum/skbn/_utils.py +297 -0
  101. pyagrum/skbn/bnclassifier.py +1014 -0
  102. pyagrum_nightly-2.3.1.9.dev202512261765915415.dist-info/LICENSE.md +12 -0
  103. pyagrum_nightly-2.3.1.9.dev202512261765915415.dist-info/LICENSES/LGPL-3.0-or-later.txt +304 -0
  104. pyagrum_nightly-2.3.1.9.dev202512261765915415.dist-info/LICENSES/MIT.txt +18 -0
  105. pyagrum_nightly-2.3.1.9.dev202512261765915415.dist-info/METADATA +145 -0
  106. pyagrum_nightly-2.3.1.9.dev202512261765915415.dist-info/RECORD +107 -0
  107. pyagrum_nightly-2.3.1.9.dev202512261765915415.dist-info/WHEEL +4 -0
@@ -0,0 +1,1058 @@
1
+ ############################################################################
2
+ # This file is part of the aGrUM/pyAgrum library. #
3
+ # #
4
+ # Copyright (c) 2005-2025 by #
5
+ # - Pierre-Henri WUILLEMIN(_at_LIP6) #
6
+ # - Christophe GONZALES(_at_AMU) #
7
+ # #
8
+ # The aGrUM/pyAgrum library is free software; you can redistribute it #
9
+ # and/or modify it under the terms of either : #
10
+ # #
11
+ # - the GNU Lesser General Public License as published by #
12
+ # the Free Software Foundation, either version 3 of the License, #
13
+ # or (at your option) any later version, #
14
+ # - the MIT license (MIT), #
15
+ # - or both in dual license, as here. #
16
+ # #
17
+ # (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) #
18
+ # #
19
+ # This aGrUM/pyAgrum library is distributed in the hope that it will be #
20
+ # useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, #
21
+ # INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS #
22
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE #
23
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER #
24
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, #
25
+ # ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR #
26
+ # OTHER DEALINGS IN THE SOFTWARE. #
27
+ # #
28
+ # See LICENCES for more details. #
29
+ # #
30
+ # SPDX-FileCopyrightText: Copyright 2005-2025 #
31
+ # - Pierre-Henri WUILLEMIN(_at_LIP6) #
32
+ # - Christophe GONZALES(_at_AMU) #
33
+ # SPDX-License-Identifier: LGPL-3.0-or-later OR MIT #
34
+ # #
35
+ # Contact : info_at_agrum_dot_org #
36
+ # homepage : http://agrum.gitlab.io #
37
+ # gitlab : https://gitlab.com/agrumery/agrum #
38
+ # #
39
+ ############################################################################
40
+
41
+ import IPython
42
+ import hashlib
43
+ import time
44
+ import csv
45
+ import math
46
+
47
+ import pyagrum as gum
48
+ import pyagrum.lib.notebook as gnb
49
+ import pyagrum.lib._colors as gumcols
50
+ import pydot as dot
51
+ import matplotlib.pyplot as plt
52
+ import matplotlib as mpl
53
+ import pyagrum.bnmixture as BNM
54
+ import numpy as np
55
+
56
+ from tempfile import mkdtemp
57
+ from matplotlib.colors import LinearSegmentedColormap
58
+ from matplotlib.backends.backend_agg import FigureCanvasAgg as fc
59
+ from base64 import encodebytes
60
+ from math import ceil
61
+
62
# Colour stops (light to dark green) of the colormap used in this module to
# render the confidence of arcs.
_colors = [
    "lightgreen",
    "lightseagreen",
    "green",
]
_cmap1 = LinearSegmentedColormap.from_list("mycmap", _colors)
64
+
65
+
66
def _compareBN(dotref: dot.Dot, bncmp: gum.BayesNet) -> dot.Dot:
    """
    Build a dot representation of a BN with the nodes at the same position as in ``dotref``.

    Notes
    -----
    Considering that all the BNs of a BNMixture have the same variables, it is not necessary to have a reference BN in parameters.

    Parameters
    ----------
    dotref : pydot.Dot
        pydot graph to get positions from.
    bncmp : pyagrum.BayesNet
        BN to get pydot representation for.

    Returns
    -------
    pydot.Dot
        pydot representation of Bayesian Network ``bncmp`` with nodes at same position as ``dotref`` nodes.

    Raises
    ------
    KeyError
        if a variable of ``bncmp`` has no node of the same name in ``dotref``.
    """
    # Render the reference graph in the "plain" text format and read one
    # position per "node" record: fields 1, 2 and 3 are name, x and y.
    # The trailing "!" is appended to every position, as in the original code.
    plain = dotref.create(format="plain").decode("utf8")
    positions = {
        row[1]: f"{row[2]},{row[3]}!"
        for row in csv.reader(plain.split("\n"), delimiter=" ", quotechar='"')
        if len(row) > 3 and row[0] == "node"
    }

    res = dot.Dot(graph_type="digraph", bgcolor="transparent", layout="fdp", splines=True)

    # adding nodes
    for nodeId in bncmp.nodes():
        varName = bncmp.variable(nodeId).name()
        res.add_node(
            dot.Node(
                f'"{varName}"',
                style="filled",
                fillcolor=gum.config["bnmixture", "default_node_bgcolor"],
                fontcolor=gum.config["bnmixture", "default_node_fgcolor"],
                pos=positions[varName],
            )
        )

    # adding arcs
    for tailId, headId in bncmp.arcs():
        tailName = bncmp.variable(tailId).name()
        headName = bncmp.variable(headId).name()
        res.add_edge(
            dot.Edge(
                f'"{tailName}"',
                f'"{headName}"',
                style=gum.config["bnmixture", "default_arc_style"],
                color=gum.config["bnmixture", "default_arc_color"],
                constraint="false",
            )
        )

    return res
124
+
125
+
126
def _compareBNinf(bnref: gum.BayesNet, refdot: dot.Dot, cmpDot: dot.Dot, scale=1.0):
    """
    Modify a pydot graph of inference in place so that its nodes are at the same positions as the nodes in the reference graph.

    Parameters
    ----------
    bnref : pyagrum.BayesNet
        BN used as a reference to get name of variables.
    refdot : pydot.Dot
        dot used as a reference to get the position of the nodes.
    cmpDot : pydot.Dot
        dot of the graph to modify to have the nodes at same position as the nodes in ``refdot``.
    scale : float
        multiplier applied to the reference coordinates (on top of fixed horizontal and vertical factors).

    Raises
    ------
    KeyError
        if a node of ``cmpDot`` names a variable of ``bnref`` that has no node in ``refdot``.
    """
    # nop2 layout allows to give positions to nodes
    cmpDot.set("layout", "nop2")
    cmpDot.set("splines", "true")

    # loading positions of reference; positions are scaled
    x_scale = 80 * scale * 1.5
    y_scale = 60 * scale * 1.3
    plain = refdot.create(format="plain").decode("utf8")
    positions = {
        row[1]: f"{float(row[2]) * x_scale},{float(row[3]) * y_scale}!"
        for row in csv.reader(plain.split("\n"), delimiter=" ", quotechar='"')
        if len(row) > 3 and row[0] == "node"
    }

    # Computed once: the set of variable names does not change while looping.
    knownNames = set(bnref.names())

    # modifying positions
    for node in cmpDot.get_nodes():
        name = node.get_name()
        # Names written as "<name>" in the dot source keep their quotes; strip
        # them only when they are really there, so that an unquoted name is
        # compared as it is instead of losing its first and last characters.
        if len(name) >= 2 and name[0] == '"' and name[-1] == '"':
            name = name[1:-1]
        if name not in knownNames:
            continue
        node.set("pos", positions[name])
162
+
163
+
164
def getMixtureGraph(bnm: BNM.IMixture, size=None, ref=False):
    """
    Build the HTML representation of a Mixture: one graph per BN, laid out side by side.

    Parameters
    ----------
    bnm : pyagrum.bnmixture.IMixture
        Mixture to get graph from.
    size : str | int
        Size of the graphs. If None, the ``bnmixture / default_graph_size`` configuration value is used.
    ref : bool
        if True, the representation starts with the reference BN's graph.

    Returns
    -------
    the content of ``pyagrum.lib.notebook.flow`` after all the graphs have been added.
    """
    flow = gnb.flow
    flow.clear()

    graphSize = gum.config["bnmixture", "default_graph_size"] if size is None else size

    # The layout of the reference BN gives the node positions of every other graph.
    referenceDot = gnb.BN2dot(bnm._refBN, size=graphSize)
    if ref:
        flow.add(gnb.getGraph(referenceDot), caption=f"reference BN : {bnm._refName}")

    for member in bnm.BNs():
        name = member.property("name")
        aligned = _compareBN(referenceDot, member)
        flow.add(gnb.getGraph(aligned, size=graphSize), caption=f"{name}, w={bnm.weight(name)}")

    return flow.html()
190
+
191
+
192
def showMixtureGraph(bnm: BNM.IMixture, size=None, ref=False):
    """
    Display the HTML representation of a Mixture built by ``getMixtureGraph``.

    Parameters
    ----------
    bnm : pyagrum.bnmixture.IMixture
        Mixture to get graph from.
    size : str | int
        Size of the graph.
    ref : bool
        if True, the representation will contain the reference BN's graph.
    """
    IPython.display.display(getMixtureGraph(bnm, size=size, ref=ref))
207
+
208
+
209
def BNMixtureInference2dot(
    bnm: BNM.BNMixture,
    engine=None,
    size=None,
    evs=None,
    targets=None,
    nodeColor=None,
    arcWidth=None,
    arcColor=None,
    cmapNode=None,
    cmapArc=None,
    dag=None,
):
    """
    Create a pydot representation of the inference graph of a BNM.BNMixture (average of all posteriors in the BNMixture).

    Parameters
    ----------
    bnm : pyagrum.bnmixture.BNMixture
        The Bayesian Net Mixture used.
    size : str
        size of the rendered graph
    engine : pyagrum.Inference
        inference algorithm used. If None, LazyPropagation will be used. Note : this is an uninitialized class object.
    evs : dict
        map of evidence
    targets : set
        set of targets. If targets is empty then each node is a target
    nodeColor : dict
        a nodeMap of values to be shown as color nodes (with special color for 0 and 1). Keys may be variable names or node ids.
    arcWidth : dict
        a arcMap of values to be shown as bold arcs
    arcColor : dict
        a arcMap of values (between 0 and 1) to be shown as color of arcs
    cmapNode : ColorMap
        color map to show the vals of Nodes
    cmapArc : ColorMap
        color map to show the vals of Arcs
    dag : pyagrum.DAG
        only shows nodes that have their id in the dag.

    Returns
    -------
    pydot.Dot
        the graph of the inference. The directory that holds the images of the
        posteriors is stored in its ``temp_dir`` attribute; it is not removed here.
    """
    import os  # function-scope import: only needed to build the image paths

    if evs is None:
        evs = {}
    if targets is None:
        targets = {}
    if cmapNode is None:
        cmapNode = plt.get_cmap(gum.config["notebook", "default_node_cmap"])
    if cmapArc is None:
        cmapArc = plt.get_cmap(gum.config["notebook", "default_arc_cmap"])

    # Bounds used to rescale the arc widths. The defaults are kept when
    # arcWidth is None or empty (min()/max() would fail on an empty map).
    maxarcs = 100
    minarcs = 0
    if arcWidth:
        minarcs = min(arcWidth.values())
        maxarcs = max(arcWidth.values())

    startTime = time.time()
    if engine is None:
        ie = BNM.BNMixtureInference(bnm)
    else:
        ie = BNM.BNMixtureInference(bnm, engine=engine)
    ie.setEvidence(evs)
    ie.makeInference()
    stopTime = time.time()

    # The images must outlive this function (they are read when the graph is
    # rendered), hence a plain directory rather than a context manager.
    temp_dir = mkdtemp("", "tmp", None)

    dotstr = 'digraph structs {\n fontcolor="' + gumcols.getBlackInTheme() + '";bgcolor="transparent";'

    if gum.config.asBool["notebook", "show_inference_time"]:
        dotstr += f' label="Inference in {1000 * (stopTime - startTime):6.2f}ms";\n'

    fontname, fontsize = gumcols.fontFromMatplotlib()
    defaultBg = gum.config["notebook", "default_node_bgcolor"]
    defaultFg = gum.config["notebook", "default_node_fgcolor"]
    dotstr += f' node [fillcolor="{defaultBg}", style=filled,color="{defaultFg}",fontname="{fontname}",fontsize="{fontsize}"];\n'
    dotstr += f' edge [color="{gumcols.getBlackInTheme()}"];\n'

    showdag = bnm._refBN.dag() if dag is None else dag

    for nid in showdag.nodes():
        name = bnm._refBN.variable(nid).name()
        isTarget = len(targets) == 0 or name in targets or nid in targets

        # defaults
        bgcol = defaultBg
        fgcol = defaultFg
        if isTarget:
            bgcol = gum.config["notebook", "figure_facecolor"]

        if nodeColor is not None and (name in nodeColor or nid in nodeColor):
            # the map may be keyed by name or by id: read the key that exists
            colorValue = nodeColor[name] if name in nodeColor else nodeColor[nid]
            bgcol = gumcols.proba2bgcolor(colorValue, cmapNode)
            fgcol = gumcols.proba2fgcolor(colorValue, cmapNode)

        # evidence colours override every other colour
        if name in evs or nid in evs:
            bgcol = gum.config["notebook", "evidence_bgcolor"]
            fgcol = gum.config["notebook", "evidence_fgcolor"]

        colorattribute = f'fillcolor="{bgcol}", fontcolor="{fgcol}", color="#000000"'
        if isTarget:
            # one image per target, stored INSIDE the temporary directory
            imageName = hashlib.md5(name.encode()).hexdigest() + "." + gum.config["notebook", "graph_format"]
            filename = os.path.join(temp_dir, imageName)
            saveFigProba(ie, name, filename, bgcolor=bgcol, scale=float(gum.config["bnmixture", "default_histo_scale"]))
            dotstr += f' "{name}" [shape=rectangle,image="{filename}",label="", {colorattribute}];\n'
        else:
            dotstr += f' "{name}" [{colorattribute}];\n'

    for a in showdag.arcs():
        (n, j) = a
        pw = 1
        av = f"{n}&nbsp;&rarr;&nbsp;{j}"
        col = gumcols.getBlackInTheme()

        if arcWidth is not None and a in arcWidth:
            # when every width is the same there is nothing to rescale
            if maxarcs != minarcs:
                pw = 0.1 + 5 * (arcWidth[a] - minarcs) / (maxarcs - minarcs)
            av = f"{n}&nbsp;&rarr;&nbsp;{j} : {arcWidth[a]}"

        if arcColor is not None and a in arcColor:
            col = gumcols.proba2color(arcColor[a], cmapArc)

        dotstr += f' "{bnm._refBN.variable(n).name()}"->"{bnm._refBN.variable(j).name()}" [penwidth="{pw}",tooltip="{av}",color="{col}"];'

    dotstr += "}"

    g = dot.graph_from_dot_data(dotstr)[0]

    # workaround for some badly parsed graph (pyparsing>=3.03)
    g.del_node('"\\n"')

    if size is None:
        size = gum.config["notebook", "default_graph_inference_size"]
    g.set_size(size)
    g.temp_dir = temp_dir

    return g
348
+
349
+
350
def BootstrapInference2dot(
    bnm: BNM.BootstrapMixture,
    size=None,
    engine=None,
    evs=None,
    targets=None,
    nodeColor=None,
    arcWidth=None,
    arcColor=None,
    cmapNode=None,
    cmapArc=None,
    dag=None,
    quantiles=False,
    show_mu_sigma=False,
):
    """
    Create a pydot representation of an inference in a BootstrapMixture (reference BN's posterior is used, while other BNs are used to compute stats).

    Parameters
    ----------
    bnm : pyagrum.bnmixture.BootstrapMixture
        the Mixture.
    size : str
        size of the rendered graph
    engine : pyagrum.Inference
        inference algorithm used. If None, LazyPropagation will be used. This is the class, not the initialized object.
    evs : dict
        map of evidence
    targets : set
        set of targets. If targets is empty then each node is a target
    nodeColor : dict
        a nodeMap of values to be shown as color nodes (with special color for 0 and 1). Keys may be variable names or node ids.
    arcWidth : dict
        a arcMap of values to be shown as bold arcs
    arcColor : dict
        a arcMap of values (between 0 and 1) to be shown as color of arcs
    cmapNode : ColorMap
        color map to show the vals of Nodes
    cmapArc : ColorMap
        color map to show the vals of Arcs
    dag : pyagrum.DAG
        only shows nodes that have their id in the dag (and not in the whole BN)
    quantiles : bool
        if True, shows quantiles on tensors. Quantiles default values can be set using pyagrum.config.
    show_mu_sigma : bool
        forwarded unchanged to ``saveFigProba`` (presumably toggles the display of mean and standard deviation -- to be confirmed against that function).

    Returns
    -------
    pydot.Dot
        the desired representation of the inference. The directory that holds
        the images of the posteriors is stored in its ``temp_dir`` attribute;
        it is not removed here.
    """
    import os  # function-scope import: only needed to build the image paths

    if evs is None:
        evs = {}
    if targets is None:
        targets = {}
    if cmapNode is None:
        cmapNode = plt.get_cmap(gum.config["notebook", "default_node_cmap"])
    if cmapArc is None:
        cmapArc = plt.get_cmap(gum.config["notebook", "default_arc_cmap"])

    # Bounds used to rescale the arc widths. The defaults are kept when
    # arcWidth is None or empty (min()/max() would fail on an empty map).
    maxarcs = 100
    minarcs = 0
    if arcWidth:
        minarcs = min(arcWidth.values())
        maxarcs = max(arcWidth.values())

    startTime = time.time()
    if engine is None:
        ie = BNM.BootstrapMixtureInference(bnm)
    else:
        ie = BNM.BootstrapMixtureInference(bnm, engine=engine)
    ie.setEvidence(evs)
    ie.makeInference()
    stopTime = time.time()

    # The images must outlive this function (they are read when the graph is
    # rendered), hence a plain directory rather than a context manager.
    temp_dir = mkdtemp("", "tmp", None)

    dotstr = 'digraph structs {\n fontcolor="' + gumcols.getBlackInTheme() + '";bgcolor="transparent";'

    # The graph label gathers the inference time and the quantile bounds.
    lab = ""
    if gum.config.asBool["notebook", "show_inference_time"]:
        lab += f"Inference in {1000 * (stopTime - startTime):6.2f}ms"

    if quantiles:
        q1 = float(gum.config["bnmixture", "left_quantile"]) * 100
        q2 = float(gum.config["bnmixture", "right_quantile"]) * 100
        title = f"\nquantiles=[{q1:.1f}%, {q2:.1f}%]"
        lab += f"\n{title}"
    dotstr += f' label="{lab}";\n'

    fontname, fontsize = gumcols.fontFromMatplotlib()
    defaultBg = gum.config["notebook", "default_node_bgcolor"]
    defaultFg = gum.config["notebook", "default_node_fgcolor"]
    dotstr += f' node [fillcolor="{defaultBg}", style=filled,color="{defaultFg}",fontname="{fontname}",fontsize="{fontsize}"];\n'
    dotstr += f' edge [color="{gumcols.getBlackInTheme()}"];\n'

    showdag = bnm._refBN.dag() if dag is None else dag

    for nid in showdag.nodes():
        name = bnm.variable(nid).name()
        isTarget = len(targets) == 0 or name in targets or nid in targets

        # defaults
        bgcol = defaultBg
        fgcol = defaultFg
        if isTarget:
            bgcol = gum.config["notebook", "figure_facecolor"]

        if nodeColor is not None and (name in nodeColor or nid in nodeColor):
            # the map may be keyed by name or by id: read the key that exists
            colorValue = nodeColor[name] if name in nodeColor else nodeColor[nid]
            bgcol = gumcols.proba2bgcolor(colorValue, cmapNode)
            fgcol = gumcols.proba2fgcolor(colorValue, cmapNode)

        # evidence colours override every other colour
        if name in evs or nid in evs:
            bgcol = gum.config["notebook", "evidence_bgcolor"]
            fgcol = gum.config["notebook", "evidence_fgcolor"]

        colorattribute = f'fillcolor="{bgcol}", fontcolor="{fgcol}", color="#000000"'
        if isTarget:
            # one image per target, stored INSIDE the temporary directory
            imageName = hashlib.md5(name.encode()).hexdigest() + "." + gum.config["notebook", "graph_format"]
            filename = os.path.join(temp_dir, imageName)
            saveFigProba(
                ie,
                name,
                filename,
                bgcolor=bgcol,
                quantiles=quantiles,
                scale=float(gum.config["bnmixture", "default_boot_histo_scale"]),
                show_mu_sigma=show_mu_sigma,
            )
            dotstr += f' "{name}" [shape=rectangle,image="{filename}",label="", {colorattribute}];\n'
        else:
            dotstr += f' "{name}" [{colorattribute}];\n'

    for a in showdag.arcs():
        (n, j) = a
        pw = 1
        av = f"{n}&nbsp;&rarr;&nbsp;{j}"
        col = gumcols.getBlackInTheme()

        if arcWidth is not None and a in arcWidth:
            # when every width is the same there is nothing to rescale
            if maxarcs != minarcs:
                pw = 0.1 + 5 * (arcWidth[a] - minarcs) / (maxarcs - minarcs)
            av = f"{n}&nbsp;&rarr;&nbsp;{j} : {arcWidth[a]}"

        if arcColor is not None and a in arcColor:
            col = gumcols.proba2color(arcColor[a], cmapArc)

        dotstr += f' "{bnm.variable(n).name()}"->"{bnm.variable(j).name()}" [penwidth="{pw}",tooltip="{av}",color="{col}"];'

    dotstr += "}"

    g = dot.graph_from_dot_data(dotstr)[0]

    # workaround for some badly parsed graph (pyparsing>=3.03)
    g.del_node('"\\n"')

    if size is None:
        size = gum.config["notebook", "default_graph_inference_size"]
    g.set_size(size)
    g.temp_dir = temp_dir

    return g
510
+
511
+
512
def showBNMixtureInference(
    bnm: BNM.BNMixture,
    engine=None,
    size=None,
    evs=None,
    targets=None,
    nodeColor=None,
    arcWidth=None,
    arcColor=None,
    cmapNode=None,
    cmapArc=None,
    dag=None,
):
    """
    Display the inference graph of a BNM.BNMixture (average of all posteriors in the BNMixture).

    The graph is built by ``BNMixtureInference2dot``; its nodes are then moved
    to the positions they have in the graph of the reference BN.

    Parameters
    ----------
    bnm : pyagrum.bnmixture.BNMixture
        The Bayesian Net Mixture used.
    size : str
        size of the rendered graph
    engine : pyagrum.Inference
        inference algorithm used. If None, LazyPropagation will be used. Note : this is an uninitialized class object.
    evs : dict
        map of evidence
    targets : set
        set of targets. If targets is empty then each node is a target
    nodeColor : dict
        a nodeMap of values to be shown as color nodes (with special color for 0 and 1)
    arcWidth : dict
        a arcMap of values to be shown as bold arcs
    arcColor : dict
        a arcMap of values (between 0 and 1) to be shown as color of arcs
    cmapNode : ColorMap
        color map to show the vals of Nodes
    cmapArc : ColorMap
        color map to show the vals of Arcs
    dag : pyagrum.DAG
        only shows nodes that have their id in the dag.
    """
    inferenceDot = BNMixtureInference2dot(
        bnm,
        engine=engine,
        size=size,
        evs=evs,
        targets=targets,
        nodeColor=nodeColor,
        arcWidth=arcWidth,
        arcColor=arcColor,
        cmapNode=cmapNode,
        cmapArc=cmapArc,
        dag=dag,  # was silently dropped: the documented filter had no effect
    )
    refdot = gnb.BN2dot(bnm._refBN)
    _compareBNinf(bnm._refBN, refdot, inferenceDot, scale=float(gum.config["bnmixture", "default_histo_scale"]))
    IPython.display.display(inferenceDot)
568
+
569
+
570
def showBootstrapMixtureInference(
    bnm: BNM.BootstrapMixture,
    engine=None,
    size=None,
    evs=None,
    targets=None,
    nodeColor=None,
    arcWidth=None,
    arcColor=None,
    cmapNode=None,
    cmapArc=None,
    dag=None,
    quantiles=False,
    show_mu_sigma=False,
):
    """
    Display the inference graph of a BNM.BootstrapMixture.

    The graph is built by ``BootstrapInference2dot``; its nodes are then moved
    to the positions they have in the graph of the reference BN.

    Parameters
    ----------
    bnm : pyagrum.bnmixture.BootstrapMixture
        The Bayesian Net Mixture used.
    size : str
        size of the rendered graph
    engine : pyagrum.Inference
        inference algorithm used. If None, LazyPropagation will be used. Note : this is an uninitialized class object.
    evs : dict
        map of evidence
    targets : set
        set of targets. If targets is empty then each node is a target
    nodeColor : dict
        a nodeMap of values to be shown as color nodes (with special color for 0 and 1)
    arcWidth : dict
        a arcMap of values to be shown as bold arcs
    arcColor : dict
        a arcMap of values (between 0 and 1) to be shown as color of arcs
    cmapNode : ColorMap
        color map to show the vals of Nodes
    cmapArc : ColorMap
        color map to show the vals of Arcs
    dag : pyagrum.DAG
        only shows nodes that have their id in the dag.
    quantiles : bool
        if True, shows quantiles on tensors. Quantiles default values can be set using pyagrum.config.
    show_mu_sigma : bool
        forwarded unchanged to ``BootstrapInference2dot``.
    """
    inferenceDot = BootstrapInference2dot(
        bnm,
        engine=engine,
        size=size,
        evs=evs,
        targets=targets,
        nodeColor=nodeColor,
        arcWidth=arcWidth,
        arcColor=arcColor,
        cmapNode=cmapNode,
        cmapArc=cmapArc,
        dag=dag,  # was silently dropped: the documented filter had no effect
        quantiles=quantiles,
        show_mu_sigma=show_mu_sigma,
    )
    refdot = gnb.BN2dot(bnm._refBN)
    _compareBNinf(bnm._refBN, refdot, inferenceDot, scale=float(gum.config["bnmixture", "default_boot_histo_scale"]))
    IPython.display.display(inferenceDot)
632
+
633
+
634
def _normalizedArcsWeight(bnm: BNM.IMixture):
    """
    Accumulate, for every ordered pair of distinct variables, the weight of the BNs of the mixture that contain the arc, then divide by the sum of the weights.

    Returns
    -------
    tuple
        ``(countArcs, mi, ma)`` where ``countArcs[tail][head]`` is the normalized
        weight of the arc, ``mi`` the smallest non-zero value found (1 if there
        is none below 1) and ``ma`` the largest value found (0 if there is none above 0).
    """
    names = bnm._refBN.names()
    total = sum(bnm._weights.values())

    # one entry per ordered pair of distinct variables, all starting at zero
    countArcs = {}
    for tail in names:
        countArcs[tail] = {head: 0.0 for head in names if head != tail}

    # every arc of every BN contributes the weight of its BN
    for bnName in bnm.names():
        member = bnm.BN(bnName)
        weight = bnm.weight(bnName)
        for tailId, headId in member.arcs():
            countArcs[member.variable(tailId).name()][member.variable(headId).name()] += weight

    # normalize in place and track the extreme values on the way
    lowest = 1
    highest = 0
    for row in countArcs.values():
        for head in row:
            normalized = row[head] / total
            row[head] = normalized
            if normalized != 0 and normalized < lowest:
                lowest = normalized
            if normalized > highest:
                highest = normalized

    return countArcs, lowest, highest
663
+
664
+
665
def _compareArcs2dot(bnm: BNM.IMixture, size=None, refStruct=False):
    """
    Pydot representation of a graph that shows confidence value for every arc in the mixture.

    Parameters
    ----------
    bnm : pyagrum.bnmixture.IMixture
        the mixture.
    size : str | int
        size given to the layout of the reference BN (only used when ``refStruct`` is True).
    refStruct : bool
        if True, the nodes are pinned at the positions they have in the layout of the reference BN.

    Returns
    -------
    tuple
        ``(graph, mi, ma)``: the pydot graph, then the minimum and maximum
        values returned by ``_normalizedArcsWeight``.
    """
    countArcs, mi, ma = _normalizedArcsWeight(bnm)
    names = bnm._refBN.names()
    layout = gum.config["bnmixture", "default_layout"]

    if refStruct:
        # The reference layout is only needed (and only computed) in this branch.
        g = gnb.BN2dot(bnm._refBN, size=size)
        plain = g.create(format="plain").decode("utf8")
        positions = {
            row[1]: f"{float(row[2]) * 2},{float(row[3]) * 2}!"
            for row in csv.reader(plain.split("\n"), delimiter=" ", quotechar='"')
            if len(row) > 3 and row[0] == "node"
        }
        res = dot.Dot(graph_type="digraph", bgcolor="transparent", layout=layout, splines=True)
    else:
        positions = None
        res = dot.Dot(
            graph_type="digraph",
            bgcolor="transparent",
            layout=layout,
            splines=True,
            overlap_scaling=gum.config["bnmixture", "default_overlap"],
            sep=3,
        )

    for vname in names:
        res.add_node(
            dot.Node(
                f'"{vname}"',
                style="filled",
                fillcolor=gum.config["bnmixture", "default_node_bgcolor"],
                fontcolor=gum.config["bnmixture", "default_node_fgcolor"],
                pos=positions[vname] if refStruct else "",
            )
        )

    # Configuration values are converted with float() wherever a number is
    # needed in this module; without it a string value would be repeated by
    # the multiplication below instead of being scaled.
    headSize = float(gum.config["bnmixture", "default_head_size"])

    for n1 in names:
        for n2 in names:
            if n1 == n2 or countArcs[n1][n2] == 0:
                continue
            if bnm._refBN.existsArc(n1, n2):
                style = gum.config["bnmixture", "correct_arc_style"]
            else:
                style = gum.config["bnmixture", "incorrect_arc_style"]

            confidence = countArcs[n1][n2]
            # integer thickness, growing with the confidence of the arc
            thickness = ceil(confidence * 6)
            col = gumcols.proba2color(min(confidence, 0.99), _cmap1)
            res.add_edge(
                dot.Edge(
                    f'"{n1}"',
                    f'"{n2}"',
                    style=style,
                    color=col,
                    penwidth=thickness,
                    arrowhead=gum.config["bnmixture", "default_arrow_type"],
                    arrowsize=headSize * thickness,
                    constraint="false",
                )
            )

    return res, mi, ma
731
+
732
+
733
def getComparison(bnm, size=None, refStruct=False):
    """
    Return the dot representation of the graph that shows the confidence value of every arc in the mixture.

    Parameters
    ----------
    bnm : the mixture
    size : the size of the dot figure
    refStruct : do we use a reference structure

    Returns
    -------
    the dot representation
    """
    # only the graph is wanted here, not the minimum and maximum confidences
    return _compareArcs2dot(bnm, size=size, refStruct=refStruct)[0]
749
+
750
+
751
def showComparison(bnm, size=None, refStruct=False):
    """
    Draw the graph that shows the confidence of every arc in the mixture.

    Parameters
    ----------
    bnm : the mixture
    size : the size of the dot figure
    refStruct : do we use a reference structure
    """
    comparison = getComparison(bnm, size, refStruct=refStruct)
    gnb.show(comparison)
762
+
763
+
764
def getArcsComparison(bnm, size=None, refStruct=False):
    """
    Build the html representation of the graph that shows the confidence value of every arc in the mixture.

    A horizontal colour bar is appended under the graph; the minimum and
    maximum confidence values are written and marked on it.

    Returns
    -------
    str
        an HTML fragment holding the SVG of the graph and the PNG of the legend.
    """
    graph, lowest, highest = _compareArcs2dot(bnm, size=size, refStruct=refStruct)
    svg = IPython.display.SVG(graph.create_svg(encoding="utf-8"))

    # Width of the legend: the value of the first width="...pt" attribute of
    # the SVG divided by the figure dpi, with a floor of 5.
    svgWidth = int(svg.data.split("width=")[1].split('"')[1].split("pt")[0])
    width = max(svgWidth / mpl.pyplot.rcParams["figure.dpi"], 5)

    figure = mpl.figure.Figure(figsize=(width, 1))
    figure.patch.set_alpha(0)
    canvas = fc(figure)
    axes = figure.add_axes([0.05, 0.80, 0.9, 0.15])
    colorbar = mpl.colorbar.ColorbarBase(axes, cmap=_cmap1, norm=mpl.colors.Normalize(), orientation="horizontal")
    colorbar.set_label("Confidence")

    barAxes = colorbar.ax
    # text of the extreme values, clamped to stay strictly inside the colormap
    barAxes.text(0.1, -3, f"min {lowest:.4f}", ha="left", va="top", color=gumcols.proba2bgcolor(max(lowest, 0.01), _cmap1))
    barAxes.text(0.9, -3, f"max {highest:.4f}", ha="right", va="top", color=gumcols.proba2bgcolor(min(highest, 0.99), _cmap1))
    # red ticks at the extreme values
    for value in (lowest, highest):
        barAxes.text(value, 1, "|", ha="center", va="top", color="red")

    legendPng = IPython.core.pylabtools.print_figure(canvas.figure, "png")
    png_legend = f"<img style='vertical-align:middle' src='data:image/png;base64,{encodebytes(legendPng).decode('ascii')}'>"

    return f"<div align='center'>{svg.data}<br/>{png_legend}</div>"
793
+
794
+
795
def showArcsComparison(bnm, size=None, refStruct=False):
    """
    Display the graph that shows a confidence value for every arc in the mixture,
    together with the confidence axis marking the minimum and maximum values.

    Parameters
    ----------
    bnm : the mixture
    size : the size of the dot figure
    refStruct : whether a reference structure is used

    Returns
    -------
    Whatever ``IPython.display.display`` returns for the generated HTML.
    """
    html = getArcsComparison(bnm, size=size, refStruct=refStruct)
    rendered = IPython.display.HTML(html)
    return IPython.display.display(rendered)
802
+
803
+
804
def arcsCompLegend():
    """
    Build the dot legend for arc comparisons.

    The legend is a left-to-right digraph on a transparent background with four
    invisible nodes and two labelled edges: one drawn with the style configured for
    arcs present in the reference, one with the style configured for arcs absent from it.

    Returns
    -------
    the dot representation of the legend
    """
    legend = dot.Dot(graph_type="digraph", bgcolor="transparent", rankdir="LR")
    for node_name in ("a", "b", "c", "d"):
        legend.add_node(dot.Node(node_name, style="invis"))

    present_edge = dot.Edge(
        "a",
        "b",
        label="Present in reference",
        style=gum.config["bnmixture", "correct_arc_style"],
        color=gum.config["bnmixture", "correct_arc_color"],
    )
    legend.add_edge(present_edge)

    # NOTE(review): this edge uses "incorrect_arc_style" but "correct_arc_color";
    # confirm whether reusing the "correct" color here is intentional.
    absent_edge = dot.Edge(
        "c",
        "d",
        label="Absent from reference",
        style=gum.config["bnmixture", "incorrect_arc_style"],
        color=gum.config["bnmixture", "correct_arc_color"],
    )
    legend.add_edge(absent_edge)

    return legend
828
+
829
+
830
+ ####################### tool box for quantiles #######################
831
def tensor2ref(ref, tens) -> "pyagrum.Tensor":
    """
    Copy ``tens`` into a new tensor built on the variables of ``ref``.

    This allows summing tensors that have the same variables but hold different
    instantiations of them.

    Parameters
    ----------
    ref : pyagrum.Tensor
      Tensor containing the variables of reference.
    tens : pyagrum.Tensor
      Tensor to convert.

    Returns
    -------
    pyagrum.Tensor
      The converted tensor with the values of ``tens`` and the variable references of ``ref``.
    """
    converted = pyagrum.Tensor()
    # Variables are looked up in ``ref`` by name, in the order they appear in ``tens``.
    for var_name in tens.names:
        reference_var = ref.variable(var_name)
        converted.add(reference_var)
    return converted.fillWith(tens)
852
+
853
+
854
def _stats(tens):
    """
    Return the mean and the standard deviation of a distribution.

    The values of ``tens.tolist()`` are used as weights and
    ``tens.variable(0).numerical(i)`` as the numerical value of the i-th entry.

    Parameters
    ----------
    tens : tensor-like
      Object exposing ``variable(0)`` and ``tolist()``.

    Returns
    -------
    tuple
      ``(mean, standard deviation)``; note that the second item is the square root
      of the variance, not the variance itself.
    """
    var = tens.variable(0)
    mean = 0.0
    second_moment = 0.0
    for i, p in enumerate(tens.tolist()):
        x = var.numerical(i)
        mean += p * x
        second_moment += p * x * x

    # Floating-point cancellation can make E[x^2] - E[x]^2 slightly negative for
    # (near-)degenerate distributions, which would make math.sqrt raise ValueError.
    # The variance is listed first so that a NaN still propagates through max().
    variance = max(second_moment - mean * mean, 0.0)
    return mean, math.sqrt(variance)
866
+
867
+
868
def _getTitleHisto(p, show_mu_sigma=False):
    """
    Return the title of a histogram.

    The title is the name of the first variable of ``p``. When ``show_mu_sigma`` is
    set and the standard deviation is strictly positive, a second line with the mean
    and the standard deviation (two decimals each) is appended.
    """
    title = p.variable(0).name()
    if not show_mu_sigma:
        return title

    mean, sigma = _stats(p)
    if sigma > 0.0:
        title += f"\n$\\mu={mean:.2f}$; $\\sigma={sigma:.2f}$"
    return title
881
+
882
+
883
def _getProbaH(ie, var_name, scale=1.0, txtcolor="black", quantiles=False, show_mu_sigma=False):
    """
    Compute the representation of a horizontal histogram of a variable posterior.

    Parameters
    ----------
    ie : IMixtureInference
      Inference of a mixture.
    var_name : str
      Name of the variable to get histogram for.
    scale : float
      scale for the size of histograms
    txtcolor : str
      color for text
    quantiles : bool
      shows quantiles (drawn as asymmetric error bars around each bar).
    show_mu_sigma : bool
      shows mu and sigma.

    Returns
    -------
    matplotlib.Figure
      A matplotlib histogram of the posterior.
    """
    # Normalize every tensor returned by ie._posteriors before querying the
    # posterior (presumably in place, since the return value is not used).
    for tens in ie._posteriors(var_name).values():
        tens.normalize()
    avrg = ie.posterior(var_name)
    var = avrg.variable(0)
    nb_labels = var.domainSize()
    ra = np.arange(nb_labels)

    # Labels are listed from last to first, matching the reversed values below.
    vx = [var.label(i) for i in range(nb_labels - 1, -1, -1)]

    fig = plt.figure()
    fig.set_figheight(scale * nb_labels / 4.0)
    fig.set_figwidth(scale * 2)

    ax = fig.add_subplot(111)
    ax.set_facecolor("white")

    if gum.config.asBool["notebook", "histogram_use_percent"]:
        perc, suffix = 100, "%"
    else:
        perc, suffix = 1, ""

    vmean = avrg.tolist()
    vmean.reverse()

    bar_options = {
        "align": "center",
        "height": float(gum.config["bnmixture", "default_bar_height"]),
        "color": gum.config["notebook", "histogram_color"],
    }
    if quantiles:
        pmin, pmax = ie.quantiles(var_name)
        vmin = pmin.tolist()
        vmin.reverse()
        vmax = pmax.tolist()
        vmax.reverse()
        # matplotlib expects asymmetric errors as a (2, N) array-like: the first row
        # holds the lower errors, the second row the upper errors. Passing N
        # (lower, upper) pairs (shape (N, 2)) is wrong for N == 2 and rejected otherwise.
        lower = [abs(mean - mi) for mi, mean in zip(vmin, vmean)]
        upper = [abs(ma - mean) for ma, mean in zip(vmax, vmean)]
        bar_options["xerr"] = [lower, upper]
        bar_options["capsize"] = float(gum.config["bnmixture", "default_bar_capsize"]) * scale

    barmean = ax.barh(ra, vmean, **bar_options)

    # Value labels: written in white inside the bar when the bar is wide enough,
    # otherwise written right after the end of the bar.
    digits = gum.config.asInt["notebook", "histogram_horizontal_visible_digits"]
    inside_threshold = 0.2 * (2 / scale)
    for b in barmean:
        bar_width = b.get_width()
        txt = f"{bar_width * perc:.{digits}f}{suffix}"
        if bar_width >= inside_threshold:
            ax.text(bar_width, b.get_y(), txt, ha="right", va="bottom", fontsize=10, color="white")
        else:
            ax.text(bar_width, b.get_y(), txt, ha="left", va="bottom", fontsize=10)

    ax.set_xlim(0, 1)
    ax.set_yticks(ra)
    ax.set_yticklabels(vx, color=txtcolor)
    ax.set_xticklabels([])
    # The title carries the variable name and, on request, mean and sigma.
    ax.set_title(_getTitleHisto(avrg, show_mu_sigma=show_mu_sigma), color=txtcolor)
    ax.get_xaxis().grid(True)
    ax.margins(0)

    return fig
988
+
989
+
990
def proba2histo(ie, var_name, scale=1.0, txtcolor="Black", quantiles=False, show_mu_sigma=False):
    """
    Compute the representation of a horizontal histogram of a variable posterior.

    Public entry point: every argument is forwarded unchanged to ``_getProbaH``.

    Parameters
    ----------
    ie : IMixtureInference
      Inference of a mixture.
    var_name : str
      Name of the variable to get histogram for.
    scale : float
      scale for the size of histograms
    txtcolor : str
      color for text
    quantiles : bool
      shows quantiles.
    show_mu_sigma : bool
      shows mu and sigma.

    Returns
    -------
    matplotlib.Figure
      A matplotlib histogram of the posterior.
    """
    histogram = _getProbaH(
        ie,
        var_name,
        scale=scale,
        txtcolor=txtcolor,
        quantiles=quantiles,
        show_mu_sigma=show_mu_sigma,
    )
    return histogram
1015
+
1016
+
1017
def saveFigProba(
    ie, var_name, filename, bgcolor=None, txtcolor="Black", quantiles=False, scale=1.0, show_mu_sigma=False
):
    """
    Saves a figure which is the representation of a histogram for a posterior.

    The file format is read from the ``notebook`` / ``graph_format`` configuration entry.

    Parameters
    ----------
    ie : IMixtureInference
      Inference of a mixture.
    var_name : str
      Name of the variable to get histogram for.
    filename: str
      the name of the saved file
    bgcolor: str
      color for background; if None, the ``notebook`` / ``figure_facecolor``
      configuration entry is used
    txtcolor : str
      color for text
    scale : float
      scale for the size of histograms
    quantiles : bool
      shows quantiles.
    show_mu_sigma : bool
      shows mu and sigma.
    """
    fig = proba2histo(ie, var_name, txtcolor=txtcolor, quantiles=quantiles, scale=scale, show_mu_sigma=show_mu_sigma)

    # Named ``facecolor`` rather than ``fc`` because another function of this file
    # calls a module-level ``fc``.
    facecolor = gum.config["notebook", "figure_facecolor"] if bgcolor is None else bgcolor

    try:
        fig.savefig(
            filename,
            bbox_inches="tight",
            transparent=False,
            facecolor=facecolor,
            pad_inches=0.05,
            dpi=fig.dpi,
            format=gum.config["notebook", "graph_format"],
        )
    finally:
        # Always release the pyplot figure, even when saving fails (e.g. bad path).
        plt.close(fig)