pyAgrum-nightly 2.3.0.9.dev202512061764412981__cp310-abi3-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyagrum/__init__.py +165 -0
- pyagrum/_pyagrum.so +0 -0
- pyagrum/bnmixture/BNMInference.py +268 -0
- pyagrum/bnmixture/BNMLearning.py +376 -0
- pyagrum/bnmixture/BNMixture.py +464 -0
- pyagrum/bnmixture/__init__.py +60 -0
- pyagrum/bnmixture/notebook.py +1058 -0
- pyagrum/causal/_CausalFormula.py +280 -0
- pyagrum/causal/_CausalModel.py +436 -0
- pyagrum/causal/__init__.py +81 -0
- pyagrum/causal/_causalImpact.py +356 -0
- pyagrum/causal/_dSeparation.py +598 -0
- pyagrum/causal/_doAST.py +761 -0
- pyagrum/causal/_doCalculus.py +361 -0
- pyagrum/causal/_doorCriteria.py +374 -0
- pyagrum/causal/_exceptions.py +95 -0
- pyagrum/causal/_types.py +61 -0
- pyagrum/causal/causalEffectEstimation/_CausalEffectEstimation.py +1175 -0
- pyagrum/causal/causalEffectEstimation/_IVEstimators.py +718 -0
- pyagrum/causal/causalEffectEstimation/_RCTEstimators.py +132 -0
- pyagrum/causal/causalEffectEstimation/__init__.py +46 -0
- pyagrum/causal/causalEffectEstimation/_backdoorEstimators.py +774 -0
- pyagrum/causal/causalEffectEstimation/_causalBNEstimator.py +324 -0
- pyagrum/causal/causalEffectEstimation/_frontdoorEstimators.py +396 -0
- pyagrum/causal/causalEffectEstimation/_learners.py +118 -0
- pyagrum/causal/causalEffectEstimation/_utils.py +466 -0
- pyagrum/causal/notebook.py +171 -0
- pyagrum/clg/CLG.py +658 -0
- pyagrum/clg/GaussianVariable.py +111 -0
- pyagrum/clg/SEM.py +312 -0
- pyagrum/clg/__init__.py +63 -0
- pyagrum/clg/canonicalForm.py +408 -0
- pyagrum/clg/constants.py +54 -0
- pyagrum/clg/forwardSampling.py +202 -0
- pyagrum/clg/learning.py +776 -0
- pyagrum/clg/notebook.py +480 -0
- pyagrum/clg/variableElimination.py +271 -0
- pyagrum/common.py +60 -0
- pyagrum/config.py +319 -0
- pyagrum/ctbn/CIM.py +513 -0
- pyagrum/ctbn/CTBN.py +573 -0
- pyagrum/ctbn/CTBNGenerator.py +216 -0
- pyagrum/ctbn/CTBNInference.py +459 -0
- pyagrum/ctbn/CTBNLearner.py +161 -0
- pyagrum/ctbn/SamplesStats.py +671 -0
- pyagrum/ctbn/StatsIndepTest.py +355 -0
- pyagrum/ctbn/__init__.py +79 -0
- pyagrum/ctbn/constants.py +54 -0
- pyagrum/ctbn/notebook.py +264 -0
- pyagrum/defaults.ini +199 -0
- pyagrum/deprecated.py +95 -0
- pyagrum/explain/_ComputationCausal.py +75 -0
- pyagrum/explain/_ComputationConditional.py +48 -0
- pyagrum/explain/_ComputationMarginal.py +48 -0
- pyagrum/explain/_CustomShapleyCache.py +110 -0
- pyagrum/explain/_Explainer.py +176 -0
- pyagrum/explain/_Explanation.py +70 -0
- pyagrum/explain/_FIFOCache.py +54 -0
- pyagrum/explain/_ShallCausalValues.py +204 -0
- pyagrum/explain/_ShallConditionalValues.py +155 -0
- pyagrum/explain/_ShallMarginalValues.py +155 -0
- pyagrum/explain/_ShallValues.py +296 -0
- pyagrum/explain/_ShapCausalValues.py +208 -0
- pyagrum/explain/_ShapConditionalValues.py +126 -0
- pyagrum/explain/_ShapMarginalValues.py +191 -0
- pyagrum/explain/_ShapleyValues.py +298 -0
- pyagrum/explain/__init__.py +81 -0
- pyagrum/explain/_explGeneralizedMarkovBlanket.py +152 -0
- pyagrum/explain/_explIndependenceListForPairs.py +146 -0
- pyagrum/explain/_explInformationGraph.py +264 -0
- pyagrum/explain/notebook/__init__.py +54 -0
- pyagrum/explain/notebook/_bar.py +142 -0
- pyagrum/explain/notebook/_beeswarm.py +174 -0
- pyagrum/explain/notebook/_showShapValues.py +97 -0
- pyagrum/explain/notebook/_waterfall.py +220 -0
- pyagrum/explain/shapley.py +225 -0
- pyagrum/lib/__init__.py +46 -0
- pyagrum/lib/_colors.py +390 -0
- pyagrum/lib/bn2graph.py +299 -0
- pyagrum/lib/bn2roc.py +1026 -0
- pyagrum/lib/bn2scores.py +217 -0
- pyagrum/lib/bn_vs_bn.py +605 -0
- pyagrum/lib/cn2graph.py +305 -0
- pyagrum/lib/discreteTypeProcessor.py +1102 -0
- pyagrum/lib/discretizer.py +58 -0
- pyagrum/lib/dynamicBN.py +390 -0
- pyagrum/lib/explain.py +57 -0
- pyagrum/lib/export.py +84 -0
- pyagrum/lib/id2graph.py +258 -0
- pyagrum/lib/image.py +387 -0
- pyagrum/lib/ipython.py +307 -0
- pyagrum/lib/mrf2graph.py +471 -0
- pyagrum/lib/notebook.py +1821 -0
- pyagrum/lib/proba_histogram.py +552 -0
- pyagrum/lib/utils.py +138 -0
- pyagrum/pyagrum.py +31495 -0
- pyagrum/skbn/_MBCalcul.py +242 -0
- pyagrum/skbn/__init__.py +49 -0
- pyagrum/skbn/_learningMethods.py +282 -0
- pyagrum/skbn/_utils.py +297 -0
- pyagrum/skbn/bnclassifier.py +1014 -0
- pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/LICENSE.md +12 -0
- pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/LICENSES/LGPL-3.0-or-later.txt +304 -0
- pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/LICENSES/MIT.txt +18 -0
- pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/METADATA +145 -0
- pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/RECORD +107 -0
- pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/WHEEL +4 -0
pyagrum/lib/bn_vs_bn.py
ADDED
|
@@ -0,0 +1,605 @@
|
|
|
1
|
+
############################################################################
|
|
2
|
+
# This file is part of the aGrUM/pyAgrum library. #
|
|
3
|
+
# #
|
|
4
|
+
# Copyright (c) 2005-2025 by #
|
|
5
|
+
# - Pierre-Henri WUILLEMIN(_at_LIP6) #
|
|
6
|
+
# - Christophe GONZALES(_at_AMU) #
|
|
7
|
+
# #
|
|
8
|
+
# The aGrUM/pyAgrum library is free software; you can redistribute it #
|
|
9
|
+
# and/or modify it under the terms of either : #
|
|
10
|
+
# #
|
|
11
|
+
# - the GNU Lesser General Public License as published by #
|
|
12
|
+
# the Free Software Foundation, either version 3 of the License, #
|
|
13
|
+
# or (at your option) any later version, #
|
|
14
|
+
# - the MIT license (MIT), #
|
|
15
|
+
# - or both in dual license, as here. #
|
|
16
|
+
# #
|
|
17
|
+
# (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) #
|
|
18
|
+
# #
|
|
19
|
+
# This aGrUM/pyAgrum library is distributed in the hope that it will be #
|
|
20
|
+
# useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, #
|
|
21
|
+
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS #
|
|
22
|
+
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE #
|
|
23
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER #
|
|
24
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, #
|
|
25
|
+
# ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR #
|
|
26
|
+
# OTHER DEALINGS IN THE SOFTWARE. #
|
|
27
|
+
# #
|
|
28
|
+
# See LICENCES for more details. #
|
|
29
|
+
# #
|
|
30
|
+
# SPDX-FileCopyrightText: Copyright 2005-2025 #
|
|
31
|
+
# - Pierre-Henri WUILLEMIN(_at_LIP6) #
|
|
32
|
+
# - Christophe GONZALES(_at_AMU) #
|
|
33
|
+
# SPDX-License-Identifier: LGPL-3.0-or-later OR MIT #
|
|
34
|
+
# #
|
|
35
|
+
# Contact : info_at_agrum_dot_org #
|
|
36
|
+
# homepage : http://agrum.gitlab.io #
|
|
37
|
+
# gitlab : https://gitlab.com/agrumery/agrum #
|
|
38
|
+
# #
|
|
39
|
+
############################################################################
|
|
40
|
+
|
|
41
|
+
"""
|
|
42
|
+
The purpose of this module is to provide tools for comparing different BNs.
|
|
43
|
+
"""
|
|
44
|
+
|
|
45
|
+
import os
|
|
46
|
+
import math
|
|
47
|
+
from itertools import product, combinations
|
|
48
|
+
|
|
49
|
+
import pydot as dot
|
|
50
|
+
|
|
51
|
+
import pyagrum as gum
|
|
52
|
+
import pyagrum.lib.bn2graph as ggr
|
|
53
|
+
import pyagrum.lib.utils as gutils
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
STRUCTURAL_HAMMING = "structural hamming"
|
|
57
|
+
PURE_HAMMING = "hamming"
|
|
58
|
+
|
|
59
|
+
|
|
60
|
+
class GraphicalBNComparator:
|
|
61
|
+
"""
|
|
62
|
+
GraphicalBNComparator allows to compare 2 BNs in multiple ways...
|
|
63
|
+
The smallest assumption is that the names of the variables are the same in
|
|
64
|
+
the 2 BNs. But some comparisons will have also to check the type and
|
|
65
|
+
domainSize of the variables.
|
|
66
|
+
|
|
67
|
+
The bns have not exactly the same role : _bn1 is rather the referent model
|
|
68
|
+
for the comparison whereas _bn2 is the compared one to the referent model.
|
|
69
|
+
|
|
70
|
+
Parameters
|
|
71
|
+
----------
|
|
72
|
+
bn1 : str or pyagrum.BayesNet
|
|
73
|
+
a BN or a filename for reference
|
|
74
|
+
bn2 : str or pyagrum.BayesNet
|
|
75
|
+
another BN or another filename for comparison
|
|
76
|
+
"""
|
|
77
|
+
|
|
78
|
+
def __init__(self, bn1: str | gum.BayesNet, bn2: str | gum.BayesNet, delta=1e-6):
    """
    Store the two networks to compare and check that they share the same variable names.

    Parameters
    ----------
    bn1 : str or pyagrum.BayesNet
      the reference BN, or the name of a file loaded with pyagrum.loadBN
    bn2 : str or pyagrum.BayesNet
      the compared BN, or the name of a file loaded with pyagrum.loadBN
    delta : float
      tolerance stored in DELTA_ERROR and used when comparing CPT values

    Raises
    ------
    ValueError
      if the sets of variable names of the two BNs differ
    """
    self.DELTA_ERROR = delta

    # Reference model: used as-is unless a filename was given.
    if not isinstance(bn1, str):
        self._bn1 = bn1
    else:
        self._bn1 = gum.loadBN(bn1)
        # a BN loaded from a file gets its "name" property reduced to a quoted base name
        quoted_ref = '"' + os.path.basename(self._bn1.property("name") + '"')
        self._bn1.setProperty("name", quoted_ref)

    # Compared model: same treatment.
    if not isinstance(bn2, str):
        self._bn2 = bn2
    else:
        self._bn2 = gum.loadBN(bn2)
        quoted_cmp = '"' + os.path.basename(self._bn2.property("name") + '"')
        self._bn2.setProperty("name", quoted_cmp)

    names_ref: set[str] = set(self._bn1.names())
    names_cmp: set[str] = set(self._bn2.names())
    if names_ref == names_cmp:
        return

    # every other comparison of this class relies on identical name sets
    raise ValueError(
        "The 2 BNs are not comparable! There are names not present in the 2 BNs : "
        + str(names_ref.symmetric_difference(names_cmp))
    )
|
|
99
|
+
|
|
100
|
+
def _compareBNVariables(self):
|
|
101
|
+
"""
|
|
102
|
+
Checks if the two BNs have the same set of variables
|
|
103
|
+
|
|
104
|
+
Returns
|
|
105
|
+
-------
|
|
106
|
+
str
|
|
107
|
+
'OK' if the BNs have composed of the same variables, indicates problematic variables otherwise
|
|
108
|
+
|
|
109
|
+
"""
|
|
110
|
+
# it is assumed (checked by the constructor) that _bn1 and _bn2 share the same set of variable names
|
|
111
|
+
for i in self._bn1.nodes():
|
|
112
|
+
v1 = self._bn1.variable(i)
|
|
113
|
+
v2 = self._bn2.variableFromName(v1.name())
|
|
114
|
+
if v2.domainSize() != v1.domainSize():
|
|
115
|
+
return v1.name() + " has not the same domain size in the two bns"
|
|
116
|
+
|
|
117
|
+
return "OK"
|
|
118
|
+
|
|
119
|
+
@staticmethod
|
|
120
|
+
def _parents_name(bn, n):
|
|
121
|
+
return {bn.variable(p).name() for p in bn.parents(n)}
|
|
122
|
+
|
|
123
|
+
def _compareBNParents(self):
|
|
124
|
+
"""
|
|
125
|
+
Returns
|
|
126
|
+
-------
|
|
127
|
+
str
|
|
128
|
+
'OK' if _bn2 have (at least) the same variable as b1 and their parents are the same.
|
|
129
|
+
|
|
130
|
+
"""
|
|
131
|
+
for id1 in self._bn1.nodes():
|
|
132
|
+
id2 = self._bn2.idFromName(self._bn1.variable(id1).name())
|
|
133
|
+
|
|
134
|
+
p1 = self._parents_name(self._bn1, id1)
|
|
135
|
+
p2 = self._parents_name(self._bn2, id2)
|
|
136
|
+
if p1 != p2:
|
|
137
|
+
return (
|
|
138
|
+
self._bn1.variable(id1).name()
|
|
139
|
+
+ " has different parents in the two bns whose names are in "
|
|
140
|
+
+ str(p1.symmetric_difference(p2))
|
|
141
|
+
)
|
|
142
|
+
|
|
143
|
+
return "OK"
|
|
144
|
+
|
|
145
|
+
def _compareTensors(self, pot1, pot2):
    """
    Compare two tensors value by value, one taken from each Bayesian network.

    Parameters
    ----------
    pot1 : pyagrum.Tensor
      one of _bn1's cpts
    pot2 : pyagrum.Tensor
      the matching cpt in _bn2

    Returns
    -------
    str
      'OK' if no pair of values differs by more than DELTA_ERROR, otherwise
      a message naming the first variable of pot1

    Notes
    -----
    The previous documentation of this method states that gum.KeyError is
    raised when the two cpts are not from the same variable; this cannot be
    confirmed from this module alone.
    """
    cursor_ref = gum.Instantiation(pot1)
    cursor_cmp = gum.Instantiation(pot2)

    cursor_ref.setFirst()
    while not cursor_ref.end():
        # align the second cursor on the same values, matched by variable name
        cursor_cmp.fromdict(cursor_ref.todict())
        gap = abs(pot1.get(cursor_ref) - pot2.get(cursor_cmp))
        if gap > self.DELTA_ERROR:
            return "Different CPTs for " + pot1.variable(0).name()
        cursor_ref += 1

    return "OK"
|
|
175
|
+
|
|
176
|
+
def _compareBNCPT(self):
|
|
177
|
+
"""
|
|
178
|
+
Returns
|
|
179
|
+
-------
|
|
180
|
+
str
|
|
181
|
+
'OK' if _bn2 have (at least) the same variable as b1 and their cpts are the same
|
|
182
|
+
"""
|
|
183
|
+
for i in self._bn1.nodes():
|
|
184
|
+
res = self._compareTensors(self._bn1.cpt(i), self._bn2.cpt(self._bn1.variable(i).name()))
|
|
185
|
+
if res != "OK":
|
|
186
|
+
return res
|
|
187
|
+
return "OK"
|
|
188
|
+
|
|
189
|
+
def equivalentBNs(self):
    """
    Check if the 2 BNs are equivalent :

    * same variables
    * same graphical structure
    * same parameters

    The checks run in that order and stop at the first failure.

    Returns
    -------
    str
      "OK" if bn are the same, a description of the first error found otherwise
    """
    # variables first, then parents; the CPT check only makes sense once both passed
    for check in (self._compareBNVariables, self._compareBNParents):
        outcome = check()
        if outcome != "OK":
            return outcome

    return self._compareBNCPT()
|
|
214
|
+
|
|
215
|
+
def dotDiff(self):
    """Return a pydot graph that compares the arcs of _bn1 (reference) with those of _bn2.

    This simply delegates to graphDiff(_bn1, _bn2) with its default styling,
    which distinguishes four kinds of arcs:
      - the arc is common for both
      - the arc is common but inverted in _bn2
      - the arc is added in _bn2
      - the arc is removed in _bn2

    The actual styles and colors are read by graphDiff from pyagrum.config
    (section "notebook", keys "graphdiff_*").

    Returns
    -------
    pydot.Dot
      the result dot graph
    """
    reference = self._bn1
    compared = self._bn2
    return graphDiff(reference, compared)
|
|
232
|
+
|
|
233
|
+
def skeletonScores(self):
    """
    Compute Precision, Recall, F-score for the skeleton of _bn2 compared to the skeleton of _bn1

    A pair of variables is "positive" when it is linked by an arc, whatever its
    orientation. Precision and recall are computed considering _bn1 as the reference.

    Fscore is 2*(recall*precision)/(recall+precision), i.e. the harmonic mean of precision and recall.

    dist2opt=square root of (1-precision)^2+(1-recall)^2 and represents the euclidian distance
    to the ideal point (precision=1, recall=1)

    Each ratio is set to 0.0 when its denominator is zero.

    Returns
    -------
    dict[str, dict[str, int] | float]
      A dictionary with the keys 'count' (the numbers of 'tp', 'tn', 'fp', 'fn'),
      'recall', 'precision', 'fscore' and 'dist2opt'.
    """
    names = self._bn1.names()

    # dag() does not depend on the pair being examined: fetch both DAGs once
    # instead of on every existsArc() call.
    dag_ref = self._bn1.dag()
    dag_cmp = self._bn2.dag()

    # t: True, f: False, p: Positive, n: Negative
    count = {"tp": 0, "tn": 0, "fp": 0, "fn": 0}

    # Loop on unordered pairs of variables
    for head, tail in combinations(names, 2):
        # node ids may differ between the two BNs: always go through the names
        head_ref = self._bn1.idFromName(head)
        tail_ref = self._bn1.idFromName(tail)
        head_cmp = self._bn2.idFromName(head)
        tail_cmp = self._bn2.idFromName(tail)

        linked_ref = dag_ref.existsArc(head_ref, tail_ref) or dag_ref.existsArc(tail_ref, head_ref)
        linked_cmp = dag_cmp.existsArc(head_cmp, tail_cmp) or dag_cmp.existsArc(tail_cmp, head_cmp)

        if linked_ref:
            count["tp" if linked_cmp else "fn"] += 1
        else:
            count["fp" if linked_cmp else "tn"] += 1

    # Compute the scores (0.0 whenever a denominator vanishes)
    positives_ref = count["tp"] + count["fn"]
    recall = count["tp"] / positives_ref if positives_ref != 0 else 0.0

    positives_cmp = count["tp"] + count["fp"]
    precision = count["tp"] / positives_cmp if positives_cmp != 0 else 0.0

    if precision + recall != 0.0:
        fscore = (2 * recall * precision) / (recall + precision)
    else:
        fscore = 0.0

    return {
        "count": count,
        "recall": recall,
        "precision": precision,
        "fscore": fscore,
        "dist2opt": math.sqrt((1 - precision) ** 2 + (1 - recall) ** 2),
    }
|
|
298
|
+
|
|
299
|
+
def scores(self):
    """
    Compute Precision, Recall, F-score for the arcs of _bn2 compared to the arcs of _bn1

    Every ordered pair of distinct variables is examined, so the orientation of
    the arcs matters. Precision and recall are computed considering _bn1 as the reference.

    Fscore is 2*(recall*precision)/(recall+precision), i.e. the harmonic mean of precision and recall.

    dist2opt=square root of (1-precision)^2+(1-recall)^2 and represents the euclidian distance
    to the ideal point (precision=1, recall=1)

    Each ratio is set to 0.0 when its denominator is zero.

    Returns
    -------
    dict[str, dict[str, int] | float]
      A dictionary with the keys 'count' (the numbers of 'tp', 'tn', 'fp', 'fn'),
      'recall', 'precision', 'fscore' and 'dist2opt'.
    """
    names = self._bn1.names()

    # dag() does not depend on the pair being examined: fetch both DAGs once
    # instead of on every existsArc() call.
    dag_ref = self._bn1.dag()
    dag_cmp = self._bn2.dag()

    # t: True, f: False, p: Positive, n: Negative
    count = {"tp": 0, "tn": 0, "fp": 0, "fn": 0}

    # Loop on oriented pairs of variables
    for head, tail in product(names, names):
        if head == tail:
            continue

        # node ids may differ between the two BNs: always go through the names
        arc_in_ref = dag_ref.existsArc(self._bn1.idFromName(head), self._bn1.idFromName(tail))
        arc_in_cmp = dag_cmp.existsArc(self._bn2.idFromName(head), self._bn2.idFromName(tail))

        if arc_in_ref:
            count["tp" if arc_in_cmp else "fn"] += 1
        else:
            count["fp" if arc_in_cmp else "tn"] += 1

    # Compute the scores (0.0 whenever a denominator vanishes)
    positives_ref = count["tp"] + count["fn"]
    recall = count["tp"] / positives_ref if positives_ref != 0 else 0.0

    positives_cmp = count["tp"] + count["fp"]
    precision = count["tp"] / positives_cmp if positives_cmp != 0 else 0.0

    if precision + recall != 0.0:
        fscore = (2 * recall * precision) / (recall + precision)
    else:
        fscore = 0.0

    return {
        "count": count,
        "recall": recall,
        "precision": precision,
        "fscore": fscore,
        "dist2opt": math.sqrt((1 - precision) ** 2 + (1 - recall) ** 2),
    }
|
|
363
|
+
|
|
364
|
+
def hamming(self):
    """
    Compute hamming and structural hamming distance

    Hamming distance is the difference of edges comparing the 2 skeletons, and Structural Hamming difference is the
    difference comparing the cpdags, including the arcs' orientation.

    Returns
    -------
    dict[str, int]
      A dictionary whose keys are PURE_HAMMING and STRUCTURAL_HAMMING
    """
    # convert graphs to cpdags
    ref_cpdag = gum.EssentialGraph(self._bn1).pdag()
    cmp_cpdag = gum.EssentialGraph(self._bn2).pdag()

    distances = {PURE_HAMMING: 0, STRUCTURAL_HAMMING: 0}

    # We look at all unordered pairs of variables
    for head, tail in combinations(self._bn1.names(), 2):
        head_ref = self._bn1.idFromName(head)
        tail_ref = self._bn1.idFromName(tail)

        head_cmp = self._bn2.idFromName(head)
        tail_cmp = self._bn2.idFromName(tail)

        # how the pair is linked in the compared cpdag
        forward_cmp = cmp_cpdag.existsArc(head_cmp, tail_cmp)
        backward_cmp = cmp_cpdag.existsArc(tail_cmp, head_cmp)
        undirected_cmp = cmp_cpdag.existsEdge(tail_cmp, head_cmp)
        unlinked_cmp = not backward_cmp and not forward_cmp and not undirected_cmp

        # "conflict": the pair is linked in the compared cpdag, but by another kind of link than in the reference
        if ref_cpdag.existsArc(head_ref, tail_ref):  # arc head->tail in the reference
            conflict = backward_cmp or undirected_cmp
        elif ref_cpdag.existsArc(tail_ref, head_ref):  # arc tail->head in the reference
            conflict = forward_cmp or undirected_cmp
        elif ref_cpdag.existsEdge(tail_ref, head_ref):  # undirected edge in the reference
            conflict = forward_cmp or backward_cmp
        else:
            # no edge or arc on the ref graph: any link on the other graph counts for both distances
            if forward_cmp or cmp_cpdag.existsEdge(head_cmp, tail_cmp) or backward_cmp:
                distances[STRUCTURAL_HAMMING] += 1
                distances[PURE_HAMMING] += 1
            continue

        if conflict:
            # same skeleton, different orientation
            distances[STRUCTURAL_HAMMING] += 1
        elif unlinked_cmp:
            # the link of the reference is missing in the compared graph
            distances[STRUCTURAL_HAMMING] += 1
            distances[PURE_HAMMING] += 1

    return distances
|
|
434
|
+
|
|
435
|
+
|
|
436
|
+
def graphDiff(bnref, bncmp, noStyle=False):
    """Return a pydot graph that compares the arcs of bnref to bncmp.
    graphDiff allows bncmp to have less nodes than bnref. (this is not the case in GraphicalBNComparator.dotDiff())

    if noStyle is False use 4 styles (fixed in pyagrum.config, section "notebook") :
      - "correct" : the arc is common for both
      - "reversed" : the arc is common but inverted in bncmp
      - "overflow" : the arc is added in bncmp
      - "missing" : the arc is removed in bncmp (or one of its extremities is not in bncmp)

    if noStyle is True, only the arcs of bncmp are drawn, all of them with the "correct"
    style, and the nodes of bnref that are not in bncmp are not drawn.

    The nodes are positioned using the layout computed for bnref.

    See graphDiffLegend() to add a legend to the graph.

    Parameters
    ----------
    bnref : pyagrum.BayesNet
      the reference BN
    bncmp : pyagrum.BayesNet
      the BN compared to the reference
    noStyle : bool
      if True, do not show the differences (see above)

    Returns
    -------
    pydot.Dot
      the result dot graph
    """

    def _styled_edge(src, dst, kind):
        # arc src->dst drawn with the "graphdiff_<kind>_style/_color" entries of pyagrum.config
        return dot.Edge(
            f'"{src}"',
            f'"{dst}"',
            style=gum.config["notebook", f"graphdiff_{kind}_style"],
            color=gum.config["notebook", f"graphdiff_{kind}_color"],
        )

    g = ggr.BN2dot(bnref)
    positions = gutils.dot_layout(g)

    # names() does not change during the comparison: build the membership set once
    # instead of recomputing it for every node and every arc.
    cmp_names = set(bncmp.names())

    res = dot.Dot(graph_type="digraph", bgcolor="transparent", layout="fdp", splines=True)
    for node in bnref.nodes():
        name = bnref.variable(node).name()
        if name in cmp_names:
            node_style = "filled"
        elif not noStyle:
            # node of the reference that does not exist in bncmp
            node_style = "dashed"
        else:
            continue
        res.add_node(
            dot.Node(
                f'"{name}"',
                style=node_style,
                fillcolor=gum.config["notebook", "graphdiff_correct_color"],
                color=gutils.getBlackInTheme(),
            )
        )

    if noStyle:
        for i1, i2 in bncmp.arcs():
            res.add_edge(_styled_edge(bncmp.variable(i1).name(), bncmp.variable(i2).name(), "correct"))
    else:
        for i1, i2 in bnref.arcs():
            n1 = bnref.variable(i1).name()
            n2 = bnref.variable(i2).name()

            if not (n1 in cmp_names and n2 in cmp_names):  # a node is missing
                res.add_edge(_styled_edge(n1, n2, "missing"))
            elif bncmp.existsArc(n1, n2):  # arc is OK in bncmp
                res.add_edge(_styled_edge(n1, n2, "correct"))
            elif bncmp.existsArc(n2, n1):  # arc is reversed in bncmp
                # invisible arc in the reference direction, added before the visible reversed one
                res.add_edge(dot.Edge(f'"{n1}"', f'"{n2}"', style="invis"))
                res.add_edge(_styled_edge(n2, n1, "reversed"))
            else:  # arc is missing in bncmp
                res.add_edge(_styled_edge(n1, n2, "missing"))

        for i1, i2 in bncmp.arcs():
            n1 = bncmp.variable(i1).name()
            n2 = bncmp.variable(i2).name()
            if not bnref.existsArc(n1, n2) and not bnref.existsArc(n2, n1):  # arc only in bncmp
                res.add_edge(_styled_edge(n1, n2, "overflow"))

    gutils.apply_dot_layout(res, positions)

    return res
|
|
555
|
+
|
|
556
|
+
|
|
557
|
+
def graphDiffLegend():
    """Return a pydot graph usable as a legend for the graphs built by graphDiff().

    The legend is made of 4 labelled arcs between invisible nodes, one for each
    arc style read from pyagrum.config (section "notebook") : overflow, missing,
    reversed and correct.

    Returns
    -------
    pydot.Dot
      the legend, or None if pydot can not be imported
    """
    try:
        # pydot is optional
        # pylint: disable=import-outside-toplevel
        import pydot as dot
    except ImportError:
        return None

    # (source, target, label, config key of the style, config key of the color)
    entries = (
        ("a", "b", "overflow", "graphdiff_overflow_style", "graphdiff_overflow_color"),
        ("c", "d", "Missing", "graphdiff_missing_style", "graphdiff_missing_color"),
        ("e", "f", "reversed", "graphdiff_reversed_style", "graphdiff_reversed_color"),
        ("g", "h", "Correct", "graphdiff_correct_style", "graphdiff_correct_color"),
    )

    legend = dot.Dot(graph_type="digraph", bgcolor="transparent", rankdir="LR")

    # the extremities of the sample arcs are not displayed
    for anchor in "abcdefgh":
        legend.add_node(dot.Node(anchor, style="invis"))

    for source, target, caption, style_key, color_key in entries:
        legend.add_edge(
            dot.Edge(
                source,
                target,
                label=caption,
                style=gum.config["notebook", style_key],
                color=gum.config["notebook", color_key],
            )
        )

    return legend
|