pyAgrum-nightly 2.3.0.9.dev202512061764412981__cp310-abi3-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyagrum/__init__.py +165 -0
- pyagrum/_pyagrum.so +0 -0
- pyagrum/bnmixture/BNMInference.py +268 -0
- pyagrum/bnmixture/BNMLearning.py +376 -0
- pyagrum/bnmixture/BNMixture.py +464 -0
- pyagrum/bnmixture/__init__.py +60 -0
- pyagrum/bnmixture/notebook.py +1058 -0
- pyagrum/causal/_CausalFormula.py +280 -0
- pyagrum/causal/_CausalModel.py +436 -0
- pyagrum/causal/__init__.py +81 -0
- pyagrum/causal/_causalImpact.py +356 -0
- pyagrum/causal/_dSeparation.py +598 -0
- pyagrum/causal/_doAST.py +761 -0
- pyagrum/causal/_doCalculus.py +361 -0
- pyagrum/causal/_doorCriteria.py +374 -0
- pyagrum/causal/_exceptions.py +95 -0
- pyagrum/causal/_types.py +61 -0
- pyagrum/causal/causalEffectEstimation/_CausalEffectEstimation.py +1175 -0
- pyagrum/causal/causalEffectEstimation/_IVEstimators.py +718 -0
- pyagrum/causal/causalEffectEstimation/_RCTEstimators.py +132 -0
- pyagrum/causal/causalEffectEstimation/__init__.py +46 -0
- pyagrum/causal/causalEffectEstimation/_backdoorEstimators.py +774 -0
- pyagrum/causal/causalEffectEstimation/_causalBNEstimator.py +324 -0
- pyagrum/causal/causalEffectEstimation/_frontdoorEstimators.py +396 -0
- pyagrum/causal/causalEffectEstimation/_learners.py +118 -0
- pyagrum/causal/causalEffectEstimation/_utils.py +466 -0
- pyagrum/causal/notebook.py +171 -0
- pyagrum/clg/CLG.py +658 -0
- pyagrum/clg/GaussianVariable.py +111 -0
- pyagrum/clg/SEM.py +312 -0
- pyagrum/clg/__init__.py +63 -0
- pyagrum/clg/canonicalForm.py +408 -0
- pyagrum/clg/constants.py +54 -0
- pyagrum/clg/forwardSampling.py +202 -0
- pyagrum/clg/learning.py +776 -0
- pyagrum/clg/notebook.py +480 -0
- pyagrum/clg/variableElimination.py +271 -0
- pyagrum/common.py +60 -0
- pyagrum/config.py +319 -0
- pyagrum/ctbn/CIM.py +513 -0
- pyagrum/ctbn/CTBN.py +573 -0
- pyagrum/ctbn/CTBNGenerator.py +216 -0
- pyagrum/ctbn/CTBNInference.py +459 -0
- pyagrum/ctbn/CTBNLearner.py +161 -0
- pyagrum/ctbn/SamplesStats.py +671 -0
- pyagrum/ctbn/StatsIndepTest.py +355 -0
- pyagrum/ctbn/__init__.py +79 -0
- pyagrum/ctbn/constants.py +54 -0
- pyagrum/ctbn/notebook.py +264 -0
- pyagrum/defaults.ini +199 -0
- pyagrum/deprecated.py +95 -0
- pyagrum/explain/_ComputationCausal.py +75 -0
- pyagrum/explain/_ComputationConditional.py +48 -0
- pyagrum/explain/_ComputationMarginal.py +48 -0
- pyagrum/explain/_CustomShapleyCache.py +110 -0
- pyagrum/explain/_Explainer.py +176 -0
- pyagrum/explain/_Explanation.py +70 -0
- pyagrum/explain/_FIFOCache.py +54 -0
- pyagrum/explain/_ShallCausalValues.py +204 -0
- pyagrum/explain/_ShallConditionalValues.py +155 -0
- pyagrum/explain/_ShallMarginalValues.py +155 -0
- pyagrum/explain/_ShallValues.py +296 -0
- pyagrum/explain/_ShapCausalValues.py +208 -0
- pyagrum/explain/_ShapConditionalValues.py +126 -0
- pyagrum/explain/_ShapMarginalValues.py +191 -0
- pyagrum/explain/_ShapleyValues.py +298 -0
- pyagrum/explain/__init__.py +81 -0
- pyagrum/explain/_explGeneralizedMarkovBlanket.py +152 -0
- pyagrum/explain/_explIndependenceListForPairs.py +146 -0
- pyagrum/explain/_explInformationGraph.py +264 -0
- pyagrum/explain/notebook/__init__.py +54 -0
- pyagrum/explain/notebook/_bar.py +142 -0
- pyagrum/explain/notebook/_beeswarm.py +174 -0
- pyagrum/explain/notebook/_showShapValues.py +97 -0
- pyagrum/explain/notebook/_waterfall.py +220 -0
- pyagrum/explain/shapley.py +225 -0
- pyagrum/lib/__init__.py +46 -0
- pyagrum/lib/_colors.py +390 -0
- pyagrum/lib/bn2graph.py +299 -0
- pyagrum/lib/bn2roc.py +1026 -0
- pyagrum/lib/bn2scores.py +217 -0
- pyagrum/lib/bn_vs_bn.py +605 -0
- pyagrum/lib/cn2graph.py +305 -0
- pyagrum/lib/discreteTypeProcessor.py +1102 -0
- pyagrum/lib/discretizer.py +58 -0
- pyagrum/lib/dynamicBN.py +390 -0
- pyagrum/lib/explain.py +57 -0
- pyagrum/lib/export.py +84 -0
- pyagrum/lib/id2graph.py +258 -0
- pyagrum/lib/image.py +387 -0
- pyagrum/lib/ipython.py +307 -0
- pyagrum/lib/mrf2graph.py +471 -0
- pyagrum/lib/notebook.py +1821 -0
- pyagrum/lib/proba_histogram.py +552 -0
- pyagrum/lib/utils.py +138 -0
- pyagrum/pyagrum.py +31495 -0
- pyagrum/skbn/_MBCalcul.py +242 -0
- pyagrum/skbn/__init__.py +49 -0
- pyagrum/skbn/_learningMethods.py +282 -0
- pyagrum/skbn/_utils.py +297 -0
- pyagrum/skbn/bnclassifier.py +1014 -0
- pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/LICENSE.md +12 -0
- pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/LICENSES/LGPL-3.0-or-later.txt +304 -0
- pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/LICENSES/MIT.txt +18 -0
- pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/METADATA +145 -0
- pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/RECORD +107 -0
- pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/WHEEL +4 -0
pyagrum/lib/mrf2graph.py
ADDED
|
@@ -0,0 +1,471 @@
|
|
|
1
|
+
############################################################################
|
|
2
|
+
# This file is part of the aGrUM/pyAgrum library. #
|
|
3
|
+
# #
|
|
4
|
+
# Copyright (c) 2005-2025 by #
|
|
5
|
+
# - Pierre-Henri WUILLEMIN(_at_LIP6) #
|
|
6
|
+
# - Christophe GONZALES(_at_AMU) #
|
|
7
|
+
# #
|
|
8
|
+
# The aGrUM/pyAgrum library is free software; you can redistribute it #
|
|
9
|
+
# and/or modify it under the terms of either : #
|
|
10
|
+
# #
|
|
11
|
+
# - the GNU Lesser General Public License as published by #
|
|
12
|
+
# the Free Software Foundation, either version 3 of the License, #
|
|
13
|
+
# or (at your option) any later version, #
|
|
14
|
+
# - the MIT license (MIT), #
|
|
15
|
+
# - or both in dual license, as here. #
|
|
16
|
+
# #
|
|
17
|
+
# (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) #
|
|
18
|
+
# #
|
|
19
|
+
# This aGrUM/pyAgrum library is distributed in the hope that it will be #
|
|
20
|
+
# useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, #
|
|
21
|
+
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS #
|
|
22
|
+
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE #
|
|
23
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER #
|
|
24
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, #
|
|
25
|
+
# ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR #
|
|
26
|
+
# OTHER DEALINGS IN THE SOFTWARE. #
|
|
27
|
+
# #
|
|
28
|
+
# See LICENCES for more details. #
|
|
29
|
+
# #
|
|
30
|
+
# SPDX-FileCopyrightText: Copyright 2005-2025 #
|
|
31
|
+
# - Pierre-Henri WUILLEMIN(_at_LIP6) #
|
|
32
|
+
# - Christophe GONZALES(_at_AMU) #
|
|
33
|
+
# SPDX-License-Identifier: LGPL-3.0-or-later OR MIT #
|
|
34
|
+
# #
|
|
35
|
+
# Contact : info_at_agrum_dot_org #
|
|
36
|
+
# homepage : http://agrum.gitlab.io #
|
|
37
|
+
# gitlab : https://gitlab.com/agrumery/agrum #
|
|
38
|
+
# #
|
|
39
|
+
############################################################################
|
|
40
|
+
|
|
41
|
+
"""
|
|
42
|
+
The purpose of this module is to provide tools for mapping Markov random fields (and inference results on them) to the dot language in order to
|
|
43
|
+
be displayed or saved as images.
|
|
44
|
+
"""
|
|
45
|
+
|
|
46
|
+
import time
|
|
47
|
+
import hashlib
|
|
48
|
+
|
|
49
|
+
import matplotlib.pyplot as plt
|
|
50
|
+
import pydot as dot
|
|
51
|
+
|
|
52
|
+
import pyagrum as gum
|
|
53
|
+
import pyagrum.lib._colors as gumcols
|
|
54
|
+
|
|
55
|
+
from .proba_histogram import saveFigProba
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def MRF2UGdot(
    mrf,
    size="4",
    nodeColor=None,
    edgeWidth=None,
    edgeLabel=None,
    edgeColor=None,
    cmapNode=None,
    cmapEdge=None,
    showMsg=None,
):
    """
    Create a pydot representation of the Markov random field as an undirected graph

    Parameters
    ----------
    mrf: pyagrum.MarkovRandomField
      The Markov random field
    size : int|str
      Size of the rendered graph. If None, the configuration value ``notebook/default_graph_size`` is used.
    nodeColor : Dict[str,float]
      a nodeMap (keyed by variable name) of values (between 0 and 1) to be shown as color of nodes
      (with special colors for 0 and 1)
    edgeWidth : Dict[Tuple(int,int),float]
      an edgeMap of values to be shown as width of edges
    edgeLabel: Dict[Tuple(int,int),str]
      an edgeMap of labels to be shown next to edges
    edgeColor: Dict[Tuple(int,int),float]
      an edgeMap of values to be shown as color of edges
    cmapNode : matplotlib.colors.Colormap
      color map to show the vals of Nodes
    cmapEdge : matplotlib.colors.Colormap
      color map to show the vals of Edges.
    showMsg : Dict[str,float|str]
      a nodeMap (keyed by variable name) of values shown in the tooltip instead of the nodeColor value.
      Numbers are shown with 5 decimals, anything else is shown as-is. Only nodes present in nodeColor
      get a value in their tooltip; a node missing from showMsg falls back to its nodeColor value.

    Notes
    -----
    The edge maps are keyed by pairs of node ids. Since edges are undirected, an edge ``(i, j)`` of the
    MRF is matched by either key ``(i, j)`` or ``(j, i)`` (``(i, j)`` takes precedence).

    Returns
    -------
    pydot.Dot
      the desired representation of the MRF as a dot graph
    """
    missing = object()  # sentinel meaning "no value for this edge"

    def edgeValue(mapping, edge):
        # An undirected edge has no canonical orientation for the caller: try (i,j) then (j,i).
        if not mapping:
            return missing
        if edge in mapping:
            return mapping[edge]
        return mapping.get((edge[1], edge[0]), missing)

    if cmapNode is None:
        cmapNode = plt.get_cmap(gum.config["notebook", "default_node_cmap"])

    if cmapEdge is None:
        cmapEdge = plt.get_cmap(gum.config["notebook", "default_edge_cmap"])

    # default range of edge widths; an empty edgeWidth keeps it (min()/max() of an empty sequence would raise)
    minedges, maxedges = 0, 100
    if edgeWidth:
        minedges = min(edgeWidth.values())
        maxedges = max(edgeWidth.values())

    graph = dot.Dot(graph_type="graph", bgcolor="transparent")

    for n in mrf.names():
        if nodeColor is None or n not in nodeColor:
            bgcol = gum.config["notebook", "default_node_bgcolor"]
            fgcol = gum.config["notebook", "default_node_fgcolor"]
            res = ""
        else:
            bgcol = gumcols.proba2bgcolor(nodeColor[n], cmapNode)
            fgcol = gumcols.proba2fgcolor(nodeColor[n], cmapNode)
            msg = nodeColor[n] if showMsg is None else showMsg.get(n, nodeColor[n])
            try:
                res = f" : {msg:2.5f}"
            except (TypeError, ValueError):
                # the float format spec cannot be applied to a textual message
                res = f" : {msg}"

        node = dot.Node(
            '"' + n + '"', style="filled", fillcolor=bgcol, fontcolor=fgcol, tooltip=f'"({mrf.idFromName(n)}) {n}{res}"'
        )
        graph.add_node(node)

    black = gumcols.getBlackInTheme()
    for a in mrf.edges():
        n, j = a
        pw = 1
        av = f"{n} — {j}"
        col = black

        width = edgeValue(edgeWidth, a)
        if width is not missing:
            if maxedges != minedges:
                pw = 0.1 + 5 * (width - minedges) / (maxedges - minedges)
            av = f"{n} — {j} : {width}"

        color = edgeValue(edgeColor, a)
        if color is not missing:
            col = gumcols.proba2color(color, cmapEdge)

        label = edgeValue(edgeLabel, a)
        lb = "" if label is missing else label

        edge = dot.Edge(
            '"' + mrf.variable(n).name() + '"',
            '"' + mrf.variable(j).name() + '"',
            label=lb,
            fontsize="10",
            penwidth=pw,
            color=col,
            tooltip=av,
        )
        graph.add_edge(edge)

    if size is None:
        size = gum.config["notebook", "default_graph_size"]
    graph.set_size(size)
    return graph
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def MRF2FactorGraphdot(mrf, size=None, nodeColor=None, factorColor=None, cmapNode=None, showMsg=None):
    """
    Create a pydot representation of the Markov random field as a factor graph

    Parameters
    ----------
    mrf: pyagrum.MarkovRandomField
      the model
    size: float|str
      the size of the rendered graph. If None, the configuration value ``notebook/default_graph_size`` is used.
    nodeColor: Dict[str,float]
      a nodeMap (keyed by variable name) of values (between 0 and 1) to be shown as color of nodes
      (with special colors for 0 and 1)
    factorColor: Callable
      a function called with a factor (as returned by ``mrf.factors()``) and returning a value (between 0 and 1)
      to be shown as the color of that factor.
    cmapNode: matplotlib.colors.Colormap
      colormap for nodes (also used for factors when factorColor is given)
    showMsg: Dict[str,float|str]
      a nodeMap (keyed by variable name) of values shown in the tooltip instead of the nodeColor value.
      Numbers are shown with 5 decimals, anything else is shown as-is. Only nodes present in nodeColor
      get a value in their tooltip; a node missing from showMsg falls back to its nodeColor value.

    Returns
    -------
    pydot.Dot
      the desired representation of the MRF as a dot graph
    """
    if cmapNode is None:
        cmapNode = plt.get_cmap(gum.config["notebook", "default_node_cmap"])

    graph = dot.Dot(graph_type="graph", bgcolor="transparent", layout=gum.config["factorgraph", "graph_layout"])

    # variable nodes
    for n in mrf.names():
        if nodeColor is None or n not in nodeColor:
            bgcol = gum.config["factorgraph", "default_node_bgcolor"]
            fgcol = gum.config["factorgraph", "default_node_fgcolor"]
            res = ""
        else:
            bgcol = gumcols.proba2bgcolor(nodeColor[n], cmapNode)
            fgcol = gumcols.proba2fgcolor(nodeColor[n], cmapNode)
            msg = nodeColor[n] if showMsg is None else showMsg.get(n, nodeColor[n])
            try:
                res = f" : {msg:2.5f}"
            except (TypeError, ValueError):
                # the float format spec cannot be applied to a textual message
                res = f" : {msg}"

        node = dot.Node(
            '"' + n + '"',
            style="filled",
            fillcolor=bgcol,
            fontcolor=fgcol,
            shape="rectangle",
            margin=0.04,
            width=0,
            height=0,
            tooltip=f'"({mrf.idFromName(n)}) {n}{res}"',
        )
        graph.add_node(node)

    def factorname(factor):
        # node identifier of a factor: its sorted variable ids joined by '#'
        return '"f' + "#".join(map(str, sorted(factor))) + '"'

    # factor nodes (drawn as small points)
    for f in mrf.factors():
        if factorColor is None:
            bgcol = gum.config["factorgraph", "default_factor_bgcolor"]
        else:
            bgcol = gumcols.proba2bgcolor(factorColor(f), cmapNode)
        node = dot.Node(factorname(f), style="filled", fillcolor=bgcol, shape="point", width=0.1, height=0.1)
        graph.add_node(node)

    # one edge between each factor and each of its variables
    black = gumcols.getBlackInTheme()
    edgeLength = gum.config["factorgraph", "edge_length"]
    for f in mrf.factors():
        for n in f:
            edge = dot.Edge(
                factorname(f),
                '"' + mrf.variable(n).name() + '"',
                color=black,
                len=edgeLength,
            )
            graph.add_edge(edge)

    if size is None:
        size = gum.config["notebook", "default_graph_size"]
    graph.set_size(size)
    return graph
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
def MRFinference2UGdot(
    mrf,
    size=None,
    engine=None,
    evs=None,
    targets=None,
    nodeColor=None,
    factorColor=None,
    arcWidth=None,
    arcColor=None,
    cmapNode=None,
    cmapArc=None,
    view=None,
):
    """
    create a pydot representation of an inference in a MRF as an UG

    :param pyagrum.MarkovRandomField mrf: the model
    :param string size: size of the rendered graph (if None, ``notebook/default_graph_inference_size`` is used)
    :param pyAgrum Inference engine: inference algorithm used. If None, ShaferShenoyMRFInference will be used
    :param dictionary evs: map of evidence
    :param set targets: set of targets (names or ids). If targets is empty then each node is a target
    :param nodeColor: a nodeMap (keyed by name or id) of values to be shown as color of nodes (with special color for 0 and 1)
    :param factorColor: not used by this function
    :param arcWidth: an arcMap (keyed by pairs of ids, in either orientation) of values to be shown as bold arcs
    :param arcColor: an arcMap (keyed by pairs of ids, in either orientation) of values (between 0 and 1) to be shown as color of arcs
    :param cmapNode: color map to show the vals of Nodes
    :param cmapArc: color map to show the vals of Arcs
    :param view: not used by this function

    :return: the desired representation of the inference. The posterior images of the targets are written in a
      fresh temporary directory whose path is stored in the ``temp_dir`` attribute of the returned graph.
    """
    import os
    from tempfile import mkdtemp

    missing = object()  # sentinel meaning "no value for this arc"

    def edgeValue(mapping, edge):
        # An undirected edge has no canonical orientation for the caller: try (i,j) then (j,i).
        if not mapping:
            return missing
        if edge in mapping:
            return mapping[edge]
        return mapping.get((edge[1], edge[0]), missing)

    if evs is None:
        evs = {}
    if targets is None:
        targets = {}

    if cmapNode is None:
        cmapNode = plt.get_cmap(gum.config["notebook", "default_node_cmap"])

    if cmapArc is None:
        cmapArc = plt.get_cmap(gum.config["notebook", "default_arc_cmap"])

    # default range of arc widths; an empty arcWidth keeps it (min()/max() of an empty sequence would raise)
    minarcs, maxarcs = 0, 100
    if arcWidth:
        minarcs = min(arcWidth.values())
        maxarcs = max(arcWidth.values())

    startTime = time.perf_counter()
    ie = gum.ShaferShenoyMRFInference(mrf) if engine is None else engine
    ie.setEvidence(evs)
    ie.makeInference()
    stopTime = time.perf_counter()

    # Posterior images go INSIDE this directory (mkdtemp returns a path without trailing separator,
    # hence os.path.join); it is exposed as g.temp_dir so callers can remove it afterwards.
    temp_dir = mkdtemp("", "tmp", None)
    graphFormat = gum.config["notebook", "graph_format"]
    black = gumcols.getBlackInTheme()
    defaultBg = gum.config["notebook", "default_node_bgcolor"]
    defaultFg = gum.config["notebook", "default_node_fgcolor"]

    parts = ['graph structs {\n fontcolor="' + black + '";bgcolor="transparent";']

    if gum.config.asBool["notebook", "show_inference_time"]:
        parts.append(f' label="Inference in {1000 * (stopTime - startTime):6.2f}ms";\n')

    fontname, fontsize = gumcols.fontFromMatplotlib()
    parts.append(
        f' node [fillcolor="{defaultBg}", style=filled,color="{defaultFg}",fontname="{fontname}",fontsize="{fontsize}"];\n'
    )
    parts.append(f' edge [color="{black}"];\n')

    # the evidence does not change while the graph is built
    evidenceNodes = ie.hardEvidenceNodes() | ie.softEvidenceNodes()

    for nid in mrf.nodes():
        name = mrf.variable(nid).name()
        isTarget = len(targets) == 0 or name in targets or nid in targets

        # defaults
        bgcol = gum.config["notebook", "figure_facecolor"] if isTarget else defaultBg
        fgcol = defaultFg

        # nodeColor may be keyed by name or by id
        colorKey = None
        if nodeColor is not None:
            if name in nodeColor:
                colorKey = name
            elif nid in nodeColor:
                colorKey = nid
        if colorKey is not None:
            bgcol = gumcols.proba2bgcolor(nodeColor[colorKey], cmapNode)
            fgcol = gumcols.proba2fgcolor(nodeColor[colorKey], cmapNode)

        # evidence colors take precedence
        if nid in evidenceNodes:
            bgcol = gum.config["notebook", "evidence_bgcolor"]
            fgcol = gum.config["notebook", "evidence_fgcolor"]

        colorattribute = f'fillcolor="{bgcol}", fontcolor="{fgcol}", color="#000000"'

        if isTarget:
            filename = os.path.join(temp_dir, hashlib.md5(name.encode()).hexdigest() + "." + graphFormat)
            saveFigProba(ie.posterior(name), filename, bgcolor=bgcol)
            parts.append(f' "{name}" [shape=rectangle,image="{filename}",label="", {colorattribute}];\n')
        else:
            parts.append(f' "{name}" [{colorattribute}];\n')

    for a in mrf.edges():
        n, j = a
        pw = 1
        av = f"{n} — {j}"
        col = black

        width = edgeValue(arcWidth, a)
        if width is not missing:
            if maxarcs != minarcs:
                pw = 0.1 + 5 * (width - minarcs) / (maxarcs - minarcs)
            av = f"{n} — {j} : {width}"

        color = edgeValue(arcColor, a)
        if color is not missing:
            col = gumcols.proba2color(color, cmapArc)

        parts.append(
            f' "{mrf.variable(n).name()}"--"{mrf.variable(j).name()}" [penwidth="{pw}",tooltip="{av}",color="{col}"];\n'
        )
    parts.append("}")

    g = dot.graph_from_dot_data("".join(parts))[0]

    if size is None:
        size = gum.config["notebook", "default_graph_inference_size"]
    g.set_size(size)
    g.temp_dir = temp_dir

    return g
|
|
366
|
+
|
|
367
|
+
|
|
368
|
+
def MRFinference2FactorGraphdot(
    mrf, size=None, engine=None, evs=None, targets=None, nodeColor=None, factorColor=None, cmapNode=None
):
    """
    create a pydot representation of an inference in a MRF as a factor graph

    :param pyagrum.MarkovRandomField mrf: the model
    :param string size: size of the rendered graph (if None, ``notebook/default_graph_inference_size`` is used)
    :param pyAgrum Inference engine: inference algorithm used. If None, ShaferShenoyMRFInference will be used
    :param dictionary evs: map of evidence
    :param set targets: set of targets (names or ids). If targets is empty then each node is a target
    :param nodeColor: a nodeMap (keyed by name or id) of values to be shown as color of nodes (with special color for 0 and 1)
    :param factorColor: a function called with a factor (as returned by ``mrf.factors()``) and returning a value
      (between 0 and 1) to be shown as the color of that factor
    :param cmapNode: color map to show the vals of Nodes (also used for factors when factorColor is given)

    :return: the desired representation of the inference. The posterior images of the targets are written in a
      fresh temporary directory whose path is stored in the ``temp_dir`` attribute of the returned graph.
    """
    import os
    from tempfile import mkdtemp

    if evs is None:
        evs = {}
    if targets is None:
        targets = {}
    if cmapNode is None:
        cmapNode = plt.get_cmap(gum.config["notebook", "default_node_cmap"])

    startTime = time.perf_counter()
    ie = gum.ShaferShenoyMRFInference(mrf) if engine is None else engine
    ie.setEvidence(evs)
    ie.makeInference()
    stopTime = time.perf_counter()

    # Posterior images go INSIDE this directory (mkdtemp returns a path without trailing separator,
    # hence os.path.join); it is exposed as g.temp_dir so callers can remove it afterwards.
    temp_dir = mkdtemp("", "tmp", None)
    graphFormat = gum.config["notebook", "graph_format"]
    layout = gum.config["factorgraph", "graph_layout"]
    edgeLength = gum.config["factorgraph", "edge_length_inference"]
    black = gumcols.getBlackInTheme()
    defaultBg = gum.config["notebook", "default_node_bgcolor"]
    defaultFg = gum.config["notebook", "default_node_fgcolor"]

    parts = [f'graph{{\nlayout="{layout}";\nfontcolor="{black}";\nbgcolor="transparent";\n']

    if gum.config.asBool["notebook", "show_inference_time"]:
        parts.append(f' label="Inference in {1000 * (stopTime - startTime):6.2f}ms";\n')

    parts.append(f' node [fillcolor="{defaultBg}", style=filled,color="{defaultFg}"];\n')
    parts.append(f' edge [color="{black}"];\n')

    # the evidence does not change while the graph is built
    evidenceNodes = ie.hardEvidenceNodes() | ie.softEvidenceNodes()

    for nid in mrf.nodes():
        name = mrf.variable(nid).name()
        isTarget = len(targets) == 0 or name in targets or nid in targets

        # defaults
        bgcol = gum.config["notebook", "figure_facecolor"] if isTarget else defaultBg
        fgcol = defaultFg

        # nodeColor may be keyed by name or by id
        colorKey = None
        if nodeColor is not None:
            if name in nodeColor:
                colorKey = name
            elif nid in nodeColor:
                colorKey = nid
        if colorKey is not None:
            bgcol = gumcols.proba2bgcolor(nodeColor[colorKey], cmapNode)
            fgcol = gumcols.proba2fgcolor(nodeColor[colorKey], cmapNode)

        # evidence colors take precedence
        if nid in evidenceNodes:
            bgcol = gum.config["notebook", "evidence_bgcolor"]
            fgcol = gum.config["notebook", "evidence_fgcolor"]

        colorattribute = f'fillcolor="{bgcol}", fontcolor="{fgcol}", color="#000000"'
        if isTarget:
            filename = os.path.join(temp_dir, hashlib.md5(name.encode()).hexdigest() + "." + graphFormat)
            saveFigProba(ie.posterior(name), filename, bgcolor=bgcol)
            parts.append(f' "{name}" [shape=rectangle,image="{filename}",label="", {colorattribute}];\n')
        else:
            parts.append(f' "{name}" [shape=rectangle,margin=0.04,width=0,height=0,{colorattribute}];\n')

    def factorname(f):
        # node identifier of a factor: its sorted variable ids joined by '#'
        return '"f' + "#".join(map(str, sorted(f))) + '"'

    # factor nodes; the color is quoted because proba2bgcolor may return '#RRGGBB', not a valid bare DOT ID
    for f in mrf.factors():
        if factorColor is None:
            factorBg = gum.config["factorgraph", "default_factor_bgcolor"]
        else:
            factorBg = gumcols.proba2bgcolor(factorColor(f), cmapNode)
        parts.append(f' {factorname(f)} [style=filled,fillcolor="{factorBg}",shape=point,width=0.1,height=0.1];\n')

    # undirected graph: edges must use the '--' operator
    for f in mrf.factors():
        for n in f:
            parts.append(
                f' {factorname(f)}--"{mrf.variable(n).name()}" [tooltip="{f}:{n}",color="{black}",len="{edgeLength}"];\n'
            )
    parts.append("}")

    g = dot.graph_from_dot_data("".join(parts))[0]

    if size is None:
        size = gum.config["notebook", "default_graph_inference_size"]
    g.set_size(size)
    g.temp_dir = temp_dir

    return g
|