pyAgrum-nightly 2.3.1.9.dev202512261765915415__cp310-abi3-macosx_10_15_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107) hide show
  1. pyagrum/__init__.py +165 -0
  2. pyagrum/_pyagrum.so +0 -0
  3. pyagrum/bnmixture/BNMInference.py +268 -0
  4. pyagrum/bnmixture/BNMLearning.py +376 -0
  5. pyagrum/bnmixture/BNMixture.py +464 -0
  6. pyagrum/bnmixture/__init__.py +60 -0
  7. pyagrum/bnmixture/notebook.py +1058 -0
  8. pyagrum/causal/_CausalFormula.py +280 -0
  9. pyagrum/causal/_CausalModel.py +436 -0
  10. pyagrum/causal/__init__.py +81 -0
  11. pyagrum/causal/_causalImpact.py +356 -0
  12. pyagrum/causal/_dSeparation.py +598 -0
  13. pyagrum/causal/_doAST.py +761 -0
  14. pyagrum/causal/_doCalculus.py +361 -0
  15. pyagrum/causal/_doorCriteria.py +374 -0
  16. pyagrum/causal/_exceptions.py +95 -0
  17. pyagrum/causal/_types.py +61 -0
  18. pyagrum/causal/causalEffectEstimation/_CausalEffectEstimation.py +1175 -0
  19. pyagrum/causal/causalEffectEstimation/_IVEstimators.py +718 -0
  20. pyagrum/causal/causalEffectEstimation/_RCTEstimators.py +132 -0
  21. pyagrum/causal/causalEffectEstimation/__init__.py +46 -0
  22. pyagrum/causal/causalEffectEstimation/_backdoorEstimators.py +774 -0
  23. pyagrum/causal/causalEffectEstimation/_causalBNEstimator.py +324 -0
  24. pyagrum/causal/causalEffectEstimation/_frontdoorEstimators.py +396 -0
  25. pyagrum/causal/causalEffectEstimation/_learners.py +118 -0
  26. pyagrum/causal/causalEffectEstimation/_utils.py +466 -0
  27. pyagrum/causal/notebook.py +172 -0
  28. pyagrum/clg/CLG.py +658 -0
  29. pyagrum/clg/GaussianVariable.py +111 -0
  30. pyagrum/clg/SEM.py +312 -0
  31. pyagrum/clg/__init__.py +63 -0
  32. pyagrum/clg/canonicalForm.py +408 -0
  33. pyagrum/clg/constants.py +54 -0
  34. pyagrum/clg/forwardSampling.py +202 -0
  35. pyagrum/clg/learning.py +776 -0
  36. pyagrum/clg/notebook.py +480 -0
  37. pyagrum/clg/variableElimination.py +271 -0
  38. pyagrum/common.py +60 -0
  39. pyagrum/config.py +319 -0
  40. pyagrum/ctbn/CIM.py +513 -0
  41. pyagrum/ctbn/CTBN.py +573 -0
  42. pyagrum/ctbn/CTBNGenerator.py +216 -0
  43. pyagrum/ctbn/CTBNInference.py +459 -0
  44. pyagrum/ctbn/CTBNLearner.py +161 -0
  45. pyagrum/ctbn/SamplesStats.py +671 -0
  46. pyagrum/ctbn/StatsIndepTest.py +355 -0
  47. pyagrum/ctbn/__init__.py +79 -0
  48. pyagrum/ctbn/constants.py +54 -0
  49. pyagrum/ctbn/notebook.py +264 -0
  50. pyagrum/defaults.ini +199 -0
  51. pyagrum/deprecated.py +95 -0
  52. pyagrum/explain/_ComputationCausal.py +75 -0
  53. pyagrum/explain/_ComputationConditional.py +48 -0
  54. pyagrum/explain/_ComputationMarginal.py +48 -0
  55. pyagrum/explain/_CustomShapleyCache.py +110 -0
  56. pyagrum/explain/_Explainer.py +176 -0
  57. pyagrum/explain/_Explanation.py +70 -0
  58. pyagrum/explain/_FIFOCache.py +54 -0
  59. pyagrum/explain/_ShallCausalValues.py +204 -0
  60. pyagrum/explain/_ShallConditionalValues.py +155 -0
  61. pyagrum/explain/_ShallMarginalValues.py +155 -0
  62. pyagrum/explain/_ShallValues.py +296 -0
  63. pyagrum/explain/_ShapCausalValues.py +208 -0
  64. pyagrum/explain/_ShapConditionalValues.py +126 -0
  65. pyagrum/explain/_ShapMarginalValues.py +191 -0
  66. pyagrum/explain/_ShapleyValues.py +298 -0
  67. pyagrum/explain/__init__.py +81 -0
  68. pyagrum/explain/_explGeneralizedMarkovBlanket.py +152 -0
  69. pyagrum/explain/_explIndependenceListForPairs.py +146 -0
  70. pyagrum/explain/_explInformationGraph.py +264 -0
  71. pyagrum/explain/notebook/__init__.py +54 -0
  72. pyagrum/explain/notebook/_bar.py +142 -0
  73. pyagrum/explain/notebook/_beeswarm.py +174 -0
  74. pyagrum/explain/notebook/_showShapValues.py +97 -0
  75. pyagrum/explain/notebook/_waterfall.py +220 -0
  76. pyagrum/explain/shapley.py +225 -0
  77. pyagrum/lib/__init__.py +46 -0
  78. pyagrum/lib/_colors.py +390 -0
  79. pyagrum/lib/bn2graph.py +299 -0
  80. pyagrum/lib/bn2roc.py +1026 -0
  81. pyagrum/lib/bn2scores.py +217 -0
  82. pyagrum/lib/bn_vs_bn.py +605 -0
  83. pyagrum/lib/cn2graph.py +305 -0
  84. pyagrum/lib/discreteTypeProcessor.py +1102 -0
  85. pyagrum/lib/discretizer.py +58 -0
  86. pyagrum/lib/dynamicBN.py +390 -0
  87. pyagrum/lib/explain.py +57 -0
  88. pyagrum/lib/export.py +84 -0
  89. pyagrum/lib/id2graph.py +258 -0
  90. pyagrum/lib/image.py +387 -0
  91. pyagrum/lib/ipython.py +307 -0
  92. pyagrum/lib/mrf2graph.py +471 -0
  93. pyagrum/lib/notebook.py +1821 -0
  94. pyagrum/lib/proba_histogram.py +552 -0
  95. pyagrum/lib/utils.py +138 -0
  96. pyagrum/pyagrum.py +31495 -0
  97. pyagrum/skbn/_MBCalcul.py +242 -0
  98. pyagrum/skbn/__init__.py +49 -0
  99. pyagrum/skbn/_learningMethods.py +282 -0
  100. pyagrum/skbn/_utils.py +297 -0
  101. pyagrum/skbn/bnclassifier.py +1014 -0
  102. pyagrum_nightly-2.3.1.9.dev202512261765915415.dist-info/LICENSE.md +12 -0
  103. pyagrum_nightly-2.3.1.9.dev202512261765915415.dist-info/LICENSES/LGPL-3.0-or-later.txt +304 -0
  104. pyagrum_nightly-2.3.1.9.dev202512261765915415.dist-info/LICENSES/MIT.txt +18 -0
  105. pyagrum_nightly-2.3.1.9.dev202512261765915415.dist-info/METADATA +145 -0
  106. pyagrum_nightly-2.3.1.9.dev202512261765915415.dist-info/RECORD +107 -0
  107. pyagrum_nightly-2.3.1.9.dev202512261765915415.dist-info/WHEEL +4 -0
@@ -0,0 +1,605 @@
1
+ ############################################################################
2
+ # This file is part of the aGrUM/pyAgrum library. #
3
+ # #
4
+ # Copyright (c) 2005-2025 by #
5
+ # - Pierre-Henri WUILLEMIN(_at_LIP6) #
6
+ # - Christophe GONZALES(_at_AMU) #
7
+ # #
8
+ # The aGrUM/pyAgrum library is free software; you can redistribute it #
9
+ # and/or modify it under the terms of either : #
10
+ # #
11
+ # - the GNU Lesser General Public License as published by #
12
+ # the Free Software Foundation, either version 3 of the License, #
13
+ # or (at your option) any later version, #
14
+ # - the MIT license (MIT), #
15
+ # - or both in dual license, as here. #
16
+ # #
17
+ # (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) #
18
+ # #
19
+ # This aGrUM/pyAgrum library is distributed in the hope that it will be #
20
+ # useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, #
21
+ # INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS #
22
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE #
23
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER #
24
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, #
25
+ # ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR #
26
+ # OTHER DEALINGS IN THE SOFTWARE. #
27
+ # #
28
+ # See LICENCES for more details. #
29
+ # #
30
+ # SPDX-FileCopyrightText: Copyright 2005-2025 #
31
+ # - Pierre-Henri WUILLEMIN(_at_LIP6) #
32
+ # - Christophe GONZALES(_at_AMU) #
33
+ # SPDX-License-Identifier: LGPL-3.0-or-later OR MIT #
34
+ # #
35
+ # Contact : info_at_agrum_dot_org #
36
+ # homepage : http://agrum.gitlab.io #
37
+ # gitlab : https://gitlab.com/agrumery/agrum #
38
+ # #
39
+ ############################################################################
40
+
41
+ """
42
+ The purpose of this module is to provide tools for comaring different BNs.
43
+ """
44
+
45
+ import os
46
+ import math
47
+ from itertools import product, combinations
48
+
49
+ import pydot as dot
50
+
51
+ import pyagrum as gum
52
+ import pyagrum.lib.bn2graph as ggr
53
+ import pyagrum.lib.utils as gutils
54
+
55
+
56
# Keys of the dict returned by GraphicalBNComparator.hamming().
STRUCTURAL_HAMMING = "structural hamming"
PURE_HAMMING = "hamming"
58
+
59
+
60
class GraphicalBNComparator:
    """
    GraphicalBNComparator allows to compare 2 BNs in multiple ways.

    The smallest assumption is that the names of the variables are the same in
    the 2 BNs. But some comparisons also have to check the type and the
    domainSize of the variables.

    The 2 BNs do not have exactly the same role: _bn1 is the reference model
    for the comparison whereas _bn2 is the model compared to the reference.

    Parameters
    ----------
    bn1 : str or pyagrum.BayesNet
      a BN or a filename for reference
    bn2 : str or pyagrum.BayesNet
      another BN or another filename for comparison
    delta : float
      absolute tolerance used when comparing the values of two CPTs (default 1e-6)

    Raises
    ------
    ValueError
      if the 2 BNs do not share exactly the same set of variable names
    """

    def __init__(self, bn1: str | gum.BayesNet, bn2: str | gum.BayesNet, delta=1e-6):
        # absolute tolerance used by _compareTensors
        self.DELTA_ERROR = delta
        self._bn1 = self._asBN(bn1)
        self._bn2 = self._asBN(bn2)

        s1: set[str] = set(self._bn1.names())
        s2: set[str] = set(self._bn2.names())

        if s1 != s2:
            raise ValueError(
                "The 2 BNs are not comparable! There are names not present in the 2 BNs : " + str(s1.symmetric_difference(s2))
            )

    @staticmethod
    def _asBN(bn):
        """Return `bn` itself, or the BN loaded from the file `bn` when it is a str.

        A loaded BN has its "name" property replaced by the quoted basename of its previous value.
        """
        if not isinstance(bn, str):
            return bn
        loaded = gum.loadBN(bn)
        # the quotes wrap the basename (previously the closing quote was appended *before* taking the basename)
        loaded.setProperty("name", '"' + os.path.basename(loaded.property("name")) + '"')
        return loaded

    def _compareBNVariables(self):
        """
        Checks that every variable has the same domain size in the two BNs.

        Returns
        -------
        str
          'OK' if every variable has the same domain size in both BNs, a message naming the first problematic
          variable otherwise
        """
        # it is assumed (checked by the constructor) that _bn1 and _bn2 share the same set of variable names
        for i in self._bn1.nodes():
            v1 = self._bn1.variable(i)
            v2 = self._bn2.variableFromName(v1.name())
            if v2.domainSize() != v1.domainSize():
                return v1.name() + " has not the same domain size in the two bns"

        return "OK"

    @staticmethod
    def _parents_name(bn, n):
        """Return the set of the names of the parents of node `n` in `bn`."""
        return {bn.variable(p).name() for p in bn.parents(n)}

    def _compareBNParents(self):
        """
        Checks that every variable has the same parents (by name) in the two BNs.

        Returns
        -------
        str
          'OK' if every variable of _bn1 has the same parents in _bn2, a message naming the first problematic
          variable otherwise
        """
        for id1 in self._bn1.nodes():
            id2 = self._bn2.idFromName(self._bn1.variable(id1).name())

            p1 = self._parents_name(self._bn1, id1)
            p2 = self._parents_name(self._bn2, id2)
            if p1 != p2:
                return (
                    self._bn1.variable(id1).name()
                    + " has different parents in the two bns whose names are in "
                    + str(p1.symmetric_difference(p2))
                )

        return "OK"

    def _compareTensors(self, pot1, pot2):
        """
        Compare 2 tensors, one in each Bayesian network, value by value (tolerance: self.DELTA_ERROR).

        Parameters
        ----------
        pot1 : pyagrum.Tensor
          one of _bn1's cpts
        pot2 : pyagrum.Tensor
          one of _bn2's cpts

        Returns
        -------
        str
          'OK' if CPTs are the same, "Different CPTs for <name of the first variable of pot1>" otherwise

        Raises
        ------
        gum.KeyError
          presumably raised (by Instantiation.fromdict) if cpts are not over the same variables
        """
        I1 = gum.Instantiation(pot1)
        I2 = gum.Instantiation(pot2)
        I1.setFirst()
        while not I1.end():
            I2.fromdict(I1.todict())  # copy value on the base of names
            if abs(pot1.get(I1) - pot2.get(I2)) > self.DELTA_ERROR:
                return "Different CPTs for " + pot1.variable(0).name()
            I1 += 1
        return "OK"

    def _compareBNCPT(self):
        """
        Returns
        -------
        str
          'OK' if the cpt of each variable of _bn1 is the same as the cpt of the same variable in _bn2,
          the first error message of _compareTensors otherwise
        """
        for i in self._bn1.nodes():
            res = self._compareTensors(self._bn1.cpt(i), self._bn2.cpt(self._bn1.variable(i).name()))
            if res != "OK":
                return res
        return "OK"

    def equivalentBNs(self):
        """
        Check if the 2 BNs are equivalent :

        * same variables
        * same graphical structure
        * same parameters

        Returns
        -------
        str
          "OK" if bn are the same, a description of the (first) error otherwise
        """
        ret = self._compareBNVariables()
        if ret != "OK":
            return ret

        ret = self._compareBNParents()
        if ret != "OK":
            return ret

        return self._compareBNCPT()

    def dotDiff(self):
        """Return a pydot graph that compares the arcs of _bn1 (reference) with those of _bn2.

        Each arc is drawn with the style and color configured in pyagrum.config (section "notebook"):

        * graphdiff_correct_* : the arc is common to both BNs
        * graphdiff_reversed_* : the arc is common but inverted in _bn2
        * graphdiff_overflow_* : the arc is added in _bn2
        * graphdiff_missing_* : the arc is removed in _bn2

        Returns
        -------
        pydot.Dot
          the resulting dot graph (see graphDiff())
        """
        return graphDiff(self._bn1, self._bn2)

    @staticmethod
    def _arc_names(bn):
        """Return the arcs of `bn` as a set of (tail name, head name) pairs."""
        return {(bn.variable(tail).name(), bn.variable(head).name()) for tail, head in bn.arcs()}

    @staticmethod
    def _scores_from_count(count):
        """Compute recall, precision, F-score and dist2opt from a confusion count {"tp", "tn", "fp", "fn"}.

        Any ratio whose denominator is 0 is set to 0.0.
        """
        tp, fp, fn = count["tp"], count["fp"], count["fn"]
        recall = tp / (tp + fn) if tp + fn != 0 else 0.0
        precision = tp / (tp + fp) if tp + fp != 0 else 0.0
        if precision + recall != 0.0:
            fscore = (2 * recall * precision) / (recall + precision)
        else:
            fscore = 0.0

        return {
            "count": count,
            "recall": recall,
            "precision": precision,
            "fscore": fscore,
            "dist2opt": math.sqrt((1 - precision) ** 2 + (1 - recall) ** 2),
        }

    def skeletonScores(self):
        """
        Compute Precision, Recall, F-score for the skeleton of self._bn2 compared to the skeleton of self._bn1.

        precision and recall are computed considering BN1 as the reference: precision = tp/(tp+fp),
        recall = tp/(tp+fn) (0.0 when the denominator is 0), counted over unordered pairs of nodes.

        Fscore is 2*(recall* precision)/(recall+precision), the harmonic mean of Precision and Recall.

        dist2opt=square root of (1-precision)^2+(1-recall)^2 and represents the Euclidean distance to the ideal
        point (precision=1, recall=1)

        Returns
        -------
        dict[str,double]
          A dictionary containing 'count' (tp, tn, fp, fn), 'recall', 'precision', 'fscore' and 'dist2opt'.
        """
        nb_nodes = len(self._bn1.names())
        # the arcs a->b and b->a give the same skeleton edge {a, b}
        edges1 = {frozenset(arc) for arc in self._arc_names(self._bn1)}
        edges2 = {frozenset(arc) for arc in self._arc_names(self._bn2)}

        # t: True, f: False, p: Positive, n: Negative
        count = {"tp": len(edges1 & edges2), "tn": 0, "fp": len(edges2 - edges1), "fn": len(edges1 - edges2)}
        # every other unordered pair of distinct nodes is an edge absent from both skeletons
        count["tn"] = nb_nodes * (nb_nodes - 1) // 2 - count["tp"] - count["fp"] - count["fn"]

        return self._scores_from_count(count)

    def scores(self):
        """
        Compute Precision, Recall, F-score for self._bn2 compared to self._bn1 (arcs orientation is considered)

        precision and recall are computed considering BN1 as the reference: precision = tp/(tp+fp),
        recall = tp/(tp+fn) (0.0 when the denominator is 0), counted over ordered pairs of distinct nodes.

        Fscore is 2*(recall* precision)/(recall+precision), the harmonic mean of Precision and Recall.

        dist2opt=square root of (1-precision)^2+(1-recall)^2 and represents the Euclidean distance to the ideal
        point (precision=1, recall=1)

        Returns
        -------
        dict[str,double]
          A dictionary containing 'count' (tp, tn, fp, fn), 'recall', 'precision', 'fscore' and 'dist2opt'.
        """
        nb_nodes = len(self._bn1.names())
        arcs1 = self._arc_names(self._bn1)
        arcs2 = self._arc_names(self._bn2)

        # t: True, f: False, p: Positive, n: Negative
        count = {"tp": len(arcs1 & arcs2), "tn": 0, "fp": len(arcs2 - arcs1), "fn": len(arcs1 - arcs2)}
        # every other ordered pair of distinct nodes is an arc absent from both BNs
        count["tn"] = nb_nodes * (nb_nodes - 1) - count["tp"] - count["fp"] - count["fn"]

        return self._scores_from_count(count)

    @staticmethod
    def _pdag_relation(pdag, a, b):
        """Return how nodes `a` and `b` are linked in `pdag`: "->" (a->b), "<-" (b->a), "-" (edge) or None."""
        if pdag.existsArc(a, b):
            return "->"
        if pdag.existsArc(b, a):
            return "<-"
        if pdag.existsEdge(a, b):
            return "-"
        return None

    def hamming(self):
        """
        Compute hamming and structural hamming distance

        Hamming distance is the number of pairs of nodes that are adjacent in exactly one of the 2 skeletons, and
        Structural Hamming distance is the number of pairs of nodes whose link (none, arc in either orientation,
        undirected edge) differs between the 2 cpdags.

        Returns
        -------
        dict[str,int]
          A dictionary containing PURE_HAMMING,STRUCTURAL_HAMMING
        """
        # convert graphs to cpdags
        cpdag1 = gum.EssentialGraph(self._bn1).pdag()
        cpdag2 = gum.EssentialGraph(self._bn2).pdag()

        hamming_dico = {PURE_HAMMING: 0, STRUCTURAL_HAMMING: 0}

        for head, tail in combinations(self._bn1.names(), 2):
            rel1 = self._pdag_relation(cpdag1, self._bn1.idFromName(head), self._bn1.idFromName(tail))
            rel2 = self._pdag_relation(cpdag2, self._bn2.idFromName(head), self._bn2.idFromName(tail))
            if rel1 != rel2:
                hamming_dico[STRUCTURAL_HAMMING] += 1
                # the skeletons differ only if the pair is adjacent in exactly one of the cpdags
                if rel1 is None or rel2 is None:
                    hamming_dico[PURE_HAMMING] += 1

        return hamming_dico
434
+
435
+
436
def graphDiff(bnref, bncmp, noStyle=False):
    """Return a pydot graph that compares the arcs of bnref to bncmp.

    graphDiff allows bncmp to have fewer nodes than bnref (this is not the case in
    GraphicalBNComparator.dotDiff()). Nodes are placed using the layout computed for bnref.

    If noStyle is False, arcs use 4 styles (fixed in pyagrum.config, section "notebook") :

    - graphdiff_correct_* : the arc is common for both
    - graphdiff_reversed_* : the arc is common but inverted in bncmp
    - graphdiff_overflow_* : the arc is added in bncmp
    - graphdiff_missing_* : the arc is removed in bncmp (or one of its nodes is missing in bncmp)

    and the nodes of bnref missing in bncmp are drawn dashed.

    If noStyle is True, only the nodes of bnref present in bncmp and the arcs of bncmp are drawn,
    all with the "correct" style.

    See graphDiffLegend() to add a legend to the graph.

    Parameters
    ----------
    bnref : pyagrum.BayesNet
      the reference BN
    bncmp : pyagrum.BayesNet
      the BN compared to bnref (its nodes are presumably expected to be nodes of bnref)
    noStyle : bool
      if True, the differences are not highlighted

    Returns
    -------
    pydot.Dot
      the resulting dot graph
    """

    def _edge(tail, head, kind):
        # arc between 2 named nodes, styled with the "graphdiff_<kind>_style/_color" entries of the config
        return dot.Edge(
            f'"{tail}"',
            f'"{head}"',
            style=gum.config["notebook", f"graphdiff_{kind}_style"],
            color=gum.config["notebook", f"graphdiff_{kind}_color"],
        )

    positions = gutils.dot_layout(ggr.BN2dot(bnref))

    # membership in bncmp is tested many times: build the set of its names only once
    cmp_names = set(bncmp.names())

    res = dot.Dot(graph_type="digraph", bgcolor="transparent", layout="fdp", splines=True)
    for i1 in bnref.nodes():
        name = bnref.variable(i1).name()
        if name in cmp_names:
            node_style = "filled"
        elif noStyle:
            continue
        else:
            node_style = "dashed"
        res.add_node(
            dot.Node(
                f'"{name}"',
                style=node_style,
                fillcolor=gum.config["notebook", "graphdiff_correct_color"],
                color=gutils.getBlackInTheme(),
            )
        )

    if noStyle:
        for i1, i2 in bncmp.arcs():
            res.add_edge(_edge(bncmp.variable(i1).name(), bncmp.variable(i2).name(), "correct"))
    else:
        for i1, i2 in bnref.arcs():
            n1 = bnref.variable(i1).name()
            n2 = bnref.variable(i2).name()

            if n1 not in cmp_names or n2 not in cmp_names:  # a node is missing in bncmp
                res.add_edge(_edge(n1, n2, "missing"))
            elif bncmp.existsArc(n1, n2):  # arc is OK in bncmp
                res.add_edge(_edge(n1, n2, "correct"))
            elif bncmp.existsArc(n2, n1):  # arc is reversed in bncmp
                # an invisible n1->n2 arc is also added (presumably to keep the reference orientation in the layout)
                res.add_edge(dot.Edge(f'"{n1}"', f'"{n2}"', style="invis"))
                res.add_edge(_edge(n2, n1, "reversed"))
            else:  # arc is missing in bncmp
                res.add_edge(_edge(n1, n2, "missing"))

        for i1, i2 in bncmp.arcs():
            n1 = bncmp.variable(i1).name()
            n2 = bncmp.variable(i2).name()
            if not bnref.existsArc(n1, n2) and not bnref.existsArc(n2, n1):  # arc only in bncmp
                res.add_edge(_edge(n1, n2, "overflow"))

    gutils.apply_dot_layout(res, positions)

    return res
555
+
556
+
557
def graphDiffLegend():
    """Build a small legend graph showing the 4 arc styles used by graphDiff().

    The 4 labelled arcs (overflow, missing, reversed, correct) use the corresponding
    graphdiff_*_style / graphdiff_*_color entries of pyagrum.config (section "notebook");
    their end nodes are invisible.

    Returns
    -------
    pydot.Dot | None
      the legend graph, or None if pydot cannot be imported
    """
    try:
        # pydot is optional
        # pylint: disable=import-outside-toplevel
        import pydot as dot
    except ImportError:
        return None

    legend = dot.Dot(graph_type="digraph", bgcolor="transparent", rankdir="LR")
    for node_name in "abcdefgh":
        legend.add_node(dot.Node(node_name, style="invis"))

    # (tail, head, label, style key, color key) in the order the arcs are added
    entries = (
        ("a", "b", "overflow", "graphdiff_overflow_style", "graphdiff_overflow_color"),
        ("c", "d", "Missing", "graphdiff_missing_style", "graphdiff_missing_color"),
        ("e", "f", "reversed", "graphdiff_reversed_style", "graphdiff_reversed_color"),
        ("g", "h", "Correct", "graphdiff_correct_style", "graphdiff_correct_color"),
    )
    for tail, head, text, style_key, color_key in entries:
        legend.add_edge(
            dot.Edge(
                tail,
                head,
                label=text,
                style=gum.config["notebook", style_key],
                color=gum.config["notebook", color_key],
            )
        )

    return legend