pyAgrum-nightly 2.3.0.9.dev202512061764412981__cp310-abi3-macosx_11_0_arm64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107)
  1. pyagrum/__init__.py +165 -0
  2. pyagrum/_pyagrum.so +0 -0
  3. pyagrum/bnmixture/BNMInference.py +268 -0
  4. pyagrum/bnmixture/BNMLearning.py +376 -0
  5. pyagrum/bnmixture/BNMixture.py +464 -0
  6. pyagrum/bnmixture/__init__.py +60 -0
  7. pyagrum/bnmixture/notebook.py +1058 -0
  8. pyagrum/causal/_CausalFormula.py +280 -0
  9. pyagrum/causal/_CausalModel.py +436 -0
  10. pyagrum/causal/__init__.py +81 -0
  11. pyagrum/causal/_causalImpact.py +356 -0
  12. pyagrum/causal/_dSeparation.py +598 -0
  13. pyagrum/causal/_doAST.py +761 -0
  14. pyagrum/causal/_doCalculus.py +361 -0
  15. pyagrum/causal/_doorCriteria.py +374 -0
  16. pyagrum/causal/_exceptions.py +95 -0
  17. pyagrum/causal/_types.py +61 -0
  18. pyagrum/causal/causalEffectEstimation/_CausalEffectEstimation.py +1175 -0
  19. pyagrum/causal/causalEffectEstimation/_IVEstimators.py +718 -0
  20. pyagrum/causal/causalEffectEstimation/_RCTEstimators.py +132 -0
  21. pyagrum/causal/causalEffectEstimation/__init__.py +46 -0
  22. pyagrum/causal/causalEffectEstimation/_backdoorEstimators.py +774 -0
  23. pyagrum/causal/causalEffectEstimation/_causalBNEstimator.py +324 -0
  24. pyagrum/causal/causalEffectEstimation/_frontdoorEstimators.py +396 -0
  25. pyagrum/causal/causalEffectEstimation/_learners.py +118 -0
  26. pyagrum/causal/causalEffectEstimation/_utils.py +466 -0
  27. pyagrum/causal/notebook.py +171 -0
  28. pyagrum/clg/CLG.py +658 -0
  29. pyagrum/clg/GaussianVariable.py +111 -0
  30. pyagrum/clg/SEM.py +312 -0
  31. pyagrum/clg/__init__.py +63 -0
  32. pyagrum/clg/canonicalForm.py +408 -0
  33. pyagrum/clg/constants.py +54 -0
  34. pyagrum/clg/forwardSampling.py +202 -0
  35. pyagrum/clg/learning.py +776 -0
  36. pyagrum/clg/notebook.py +480 -0
  37. pyagrum/clg/variableElimination.py +271 -0
  38. pyagrum/common.py +60 -0
  39. pyagrum/config.py +319 -0
  40. pyagrum/ctbn/CIM.py +513 -0
  41. pyagrum/ctbn/CTBN.py +573 -0
  42. pyagrum/ctbn/CTBNGenerator.py +216 -0
  43. pyagrum/ctbn/CTBNInference.py +459 -0
  44. pyagrum/ctbn/CTBNLearner.py +161 -0
  45. pyagrum/ctbn/SamplesStats.py +671 -0
  46. pyagrum/ctbn/StatsIndepTest.py +355 -0
  47. pyagrum/ctbn/__init__.py +79 -0
  48. pyagrum/ctbn/constants.py +54 -0
  49. pyagrum/ctbn/notebook.py +264 -0
  50. pyagrum/defaults.ini +199 -0
  51. pyagrum/deprecated.py +95 -0
  52. pyagrum/explain/_ComputationCausal.py +75 -0
  53. pyagrum/explain/_ComputationConditional.py +48 -0
  54. pyagrum/explain/_ComputationMarginal.py +48 -0
  55. pyagrum/explain/_CustomShapleyCache.py +110 -0
  56. pyagrum/explain/_Explainer.py +176 -0
  57. pyagrum/explain/_Explanation.py +70 -0
  58. pyagrum/explain/_FIFOCache.py +54 -0
  59. pyagrum/explain/_ShallCausalValues.py +204 -0
  60. pyagrum/explain/_ShallConditionalValues.py +155 -0
  61. pyagrum/explain/_ShallMarginalValues.py +155 -0
  62. pyagrum/explain/_ShallValues.py +296 -0
  63. pyagrum/explain/_ShapCausalValues.py +208 -0
  64. pyagrum/explain/_ShapConditionalValues.py +126 -0
  65. pyagrum/explain/_ShapMarginalValues.py +191 -0
  66. pyagrum/explain/_ShapleyValues.py +298 -0
  67. pyagrum/explain/__init__.py +81 -0
  68. pyagrum/explain/_explGeneralizedMarkovBlanket.py +152 -0
  69. pyagrum/explain/_explIndependenceListForPairs.py +146 -0
  70. pyagrum/explain/_explInformationGraph.py +264 -0
  71. pyagrum/explain/notebook/__init__.py +54 -0
  72. pyagrum/explain/notebook/_bar.py +142 -0
  73. pyagrum/explain/notebook/_beeswarm.py +174 -0
  74. pyagrum/explain/notebook/_showShapValues.py +97 -0
  75. pyagrum/explain/notebook/_waterfall.py +220 -0
  76. pyagrum/explain/shapley.py +225 -0
  77. pyagrum/lib/__init__.py +46 -0
  78. pyagrum/lib/_colors.py +390 -0
  79. pyagrum/lib/bn2graph.py +299 -0
  80. pyagrum/lib/bn2roc.py +1026 -0
  81. pyagrum/lib/bn2scores.py +217 -0
  82. pyagrum/lib/bn_vs_bn.py +605 -0
  83. pyagrum/lib/cn2graph.py +305 -0
  84. pyagrum/lib/discreteTypeProcessor.py +1102 -0
  85. pyagrum/lib/discretizer.py +58 -0
  86. pyagrum/lib/dynamicBN.py +390 -0
  87. pyagrum/lib/explain.py +57 -0
  88. pyagrum/lib/export.py +84 -0
  89. pyagrum/lib/id2graph.py +258 -0
  90. pyagrum/lib/image.py +387 -0
  91. pyagrum/lib/ipython.py +307 -0
  92. pyagrum/lib/mrf2graph.py +471 -0
  93. pyagrum/lib/notebook.py +1821 -0
  94. pyagrum/lib/proba_histogram.py +552 -0
  95. pyagrum/lib/utils.py +138 -0
  96. pyagrum/pyagrum.py +31495 -0
  97. pyagrum/skbn/_MBCalcul.py +242 -0
  98. pyagrum/skbn/__init__.py +49 -0
  99. pyagrum/skbn/_learningMethods.py +282 -0
  100. pyagrum/skbn/_utils.py +297 -0
  101. pyagrum/skbn/bnclassifier.py +1014 -0
  102. pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/LICENSE.md +12 -0
  103. pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/LICENSES/LGPL-3.0-or-later.txt +304 -0
  104. pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/LICENSES/MIT.txt +18 -0
  105. pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/METADATA +145 -0
  106. pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/RECORD +107 -0
  107. pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/WHEEL +4 -0
pyagrum/lib/bn2roc.py ADDED
@@ -0,0 +1,1026 @@
1
+ ############################################################################
2
+ # This file is part of the aGrUM/pyAgrum library. #
3
+ # #
4
+ # Copyright (c) 2005-2025 by #
5
+ # - Pierre-Henri WUILLEMIN(_at_LIP6) #
6
+ # - Christophe GONZALES(_at_AMU) #
7
+ # #
8
+ # The aGrUM/pyAgrum library is free software; you can redistribute it #
9
+ # and/or modify it under the terms of either : #
10
+ # #
11
+ # - the GNU Lesser General Public License as published by #
12
+ # the Free Software Foundation, either version 3 of the License, #
13
+ # or (at your option) any later version, #
14
+ # - the MIT license (MIT), #
15
+ # - or both in dual license, as here. #
16
+ # #
17
+ # (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) #
18
+ # #
19
+ # This aGrUM/pyAgrum library is distributed in the hope that it will be #
20
+ # useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, #
21
+ # INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS #
22
+ # FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE #
23
+ # AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER #
24
+ # LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, #
25
+ # ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR #
26
+ # OTHER DEALINGS IN THE SOFTWARE. #
27
+ # #
28
+ # See LICENCES for more details. #
29
+ # #
30
+ # SPDX-FileCopyrightText: Copyright 2005-2025 #
31
+ # - Pierre-Henri WUILLEMIN(_at_LIP6) #
32
+ # - Christophe GONZALES(_at_AMU) #
33
+ # SPDX-License-Identifier: LGPL-3.0-or-later OR MIT #
34
+ # #
35
+ # Contact : info_at_agrum_dot_org #
36
+ # homepage : http://agrum.gitlab.io #
37
+ # gitlab : https://gitlab.com/agrumery/agrum #
38
+ # #
39
+ ############################################################################
40
+
41
+ """
42
+ The purpose of this module is to provide tools for building ROC and PR from Bayesian Network.
43
+ """
44
+
45
+ import os
46
+ from typing import List, Tuple
47
+
48
+ import numpy as np
49
+
50
+ from matplotlib import pylab
51
+
52
+ import pyagrum as gum
53
+ from pyagrum import skbn
54
+
55
+ CSV_TMP_SUFFIX = ".x.csv"
56
+
57
+
58
def _getFilename(datasrc):
    """
    Return the name used to refer to ``datasrc`` in titles and figure filenames.

    Temporary csv files created from a DataFrame end with ``CSV_TMP_SUFFIX``; they are
    reported as ``"dataframe"`` rather than by their random temporary path.

    Parameters
    ----------
    datasrc : str
      a csv filename

    Returns
    -------
    str
      ``"dataframe"`` for a temporary csv file, ``datasrc`` itself otherwise
    """
    is_temporary_csv = datasrc.endswith(CSV_TMP_SUFFIX)
    return "dataframe" if is_temporary_csv else datasrc
64
+
65
+
66
def _lines_count(filename):
  """
  Count the lines of a text file.

  The file is streamed line by line instead of being loaded in memory with
  ``readlines()``. Undecodable bytes are replaced instead of raising: the count is only
  used to size a progress bar, so an exotic encoding must not make it crash.

  Parameters
  ----------
  filename : str
    a filename

  Returns
  -------
  int
    the number of lines in the file
  """
  with open(filename, errors="replace") as f:
    return sum(1 for _ in f)
85
+
86
+
87
def _checkCompatibility(bn, fields, datasrc):
  """
  Check that every variable of ``bn`` appears in ``fields``.

  Each missing variable is reported on stdout.

  Parameters
  ----------
  bn : pyagrum.BayesNet
    a Bayesian network
  fields : dict
    mapping from a field name to its position (presumably the column index in the csv)
  datasrc : str|DataFrame
    not used by this function

  Returns
  -------
  dict | None
    ``{node id: fields[name]}`` for every variable of ``bn``, or None if at least one
    variable is missing from ``fields``.
  """
  names = list(bn.names())
  missing = [name for name in names if name not in fields]
  for name in missing:
    print(f"** field '{name}' is missing.")

  if missing:
    return None
  return {bn.idFromName(name): fields[name] for name in names}
118
+
119
+
120
def _computeAUC(points):
  """
  Area under a piecewise-linear ROC/PR curve, computed with the trapezoidal rule.

  Parameters
  ----------
  points : list
    the (x, y) points of the curve, ordered by x

  Returns
  -------
  float
    the AUC value (0 for fewer than two points)
  """
  # sum of (x_i - x_{i-1}) * (y_{i-1} + y_i), halved once at the end
  doubled_area = sum((cur[0] - prev[0]) * (prev[1] + cur[1]) for prev, cur in zip(points, points[1:]))
  return doubled_area / 2
141
+
142
+
143
def _computeFbeta(points, ind, beta=1):
  """
  F-beta score of a point of a Precision-Recall curve.

  Parameters
  ----------
  points : list
    the (recall, precision) points of the PR curve
  ind : int
    index of the point in ``points``
  beta : float
    the value of beta for the F-beta score

  Returns
  -------
  float
    ``(1 + beta^2) * precision * recall / (beta^2 * precision + recall)``, or 0.0 when the
    denominator is 0 (e.g. precision = recall = 0) instead of raising ZeroDivisionError.
  """
  recall = points[ind][0]
  precision = points[ind][1]
  denominator = (beta**2 * precision) + recall
  if denominator == 0:
    return 0.0
  return (1 + beta**2) * precision * recall / denominator
145
+
146
+
147
def _computePoints(bn, datasrc, target, label, *, beta=1, show_progress=True, with_labels=True, significant_digits=10):
  """
  Score every line of a csv file with the Bayesian network used as a binary classifier.

  Parameters
  ----------
  bn : pyagrum.BayesNet
    a Bayesian network
  datasrc : str
    a csv filename
  target : str
    the name of the target variable
  label : str
    the target's label or id (converted with ``str``)
  beta : float
    the value of beta for the F-beta score (passed to the classifier)
  show_progress : bool
    whether a (tqdm) progress bar is displayed while the rows are collected
  with_labels: bool
    whether we use label or id (especially for parameter label)
  significant_digits:
    number of significant digits when computing probabilities

  Returns
  -------
  tuple (res, totalP, totalN)
    ``res`` is a list of ``(proba, isPositive)`` for each line of ``datasrc``, where
    ``proba`` is column 1 of ``predict_proba``; ``totalP`` / ``totalN`` are the numbers of
    true / false ``isPositive`` values.

  Raises
  ------
  ValueError
    if ``with_labels`` is False and ``label`` is not a label of the target variable
  """
  idTarget = bn.idFromName(target)
  label = str(label)

  if with_labels:
    idLabel = label
  else:
    targetVar = bn.variable(idTarget)
    idLabel = next((i for i in range(targetVar.domainSize()) if targetVar.label(i) == label), -1)
    if idLabel < 0:
      # explicit error: an `assert` would silently vanish under `python -O`
      raise ValueError(f"'{label}' is not a label of the variable '{target}'")

  classifier = skbn.BNClassifier(beta=beta, significant_digit=significant_digits)

  pbar = None
  if show_progress:
    # tqdm is optional:
    # pylint: disable=import-outside-toplevel
    from tqdm import tqdm

    # "- 1": presumably the header line of the csv
    pbar = tqdm(
      total=_lines_count(datasrc) - 1, desc=_getFilename(datasrc), bar_format="{desc}: {percentage:3.0f}%|{bar}|"
    )

  try:
    classifier.fromTrainedModel(bn, target, idLabel)
    # as a binary classifier, y holds booleans (presumably True when the row's target equals `label`)
    X, y = classifier.XYfromCSV(datasrc, with_labels=with_labels, target=target)
    predictions = classifier.predict_proba(X)

    totalP = np.count_nonzero(y)
    totalN = len(y) - totalP
    res = []
    for row, isPositive in zip(predictions, y):
      res.append((row[1], isPositive))
      if pbar is not None:
        pbar.update()
  finally:
    # the progress bar is closed even if the classification fails
    if pbar is not None:
      pbar.close()

  return res, totalP, totalN
218
+
219
+
220
def _computeROC_PR(values, totalP, totalN, beta):
  """
  Sweep the decision threshold over the scores and build the ROC and PR curves.

  A new point is emitted each time the score strictly decreases, so rows with equal
  scores are grouped into one step of the curves.

  Parameters
  ----------
  values :
    the (score, isPositive) pairs, one per row
  totalP : int
    the number of positive values
  totalN : int
    the number of negative values
  beta : float
    the value of beta for the F-beta score

  Returns
  -------
  tuple
    (points_ROC, ind_ROC, threshold_ROC, AUC_ROC, fbeta_ROC,
     points_PR, ind_PR, threshold_PR, AUC_PR, fbeta_PR, thresholds)
    where ``ind_ROC`` is the point closest to (0, 1), ``ind_PR`` the point with the
    best F-beta score, and the thresholds are the midpoints of the score steps.
  """
  ranked = sorted(values, key=lambda pair: pair[0], reverse=True)

  true_pos = 0.0
  false_pos = 0.0

  best_roc_dist = 100.0  # any real squared distance is <= 2
  best_roc_index = 0
  best_roc_threshold = 0

  best_pr_f = 0.0
  best_pr_index = 0
  best_pr_threshold = 0

  roc_points = [(0, 0)]
  pr_points = [(0, 1)]
  thresholds = [1]

  previous = ranked[0][0]
  for entry in ranked:
    score = entry[0]
    if score < previous:
      # the point counts every row scored strictly above `score`
      fpr = false_pos / totalN
      tpr = true_pos / totalP  # also the recall
      precision = true_pos / (true_pos + false_pos)
      midpoint = (score + previous) / 2

      # squared euclidean distance to the ideal corner (0, 1)
      dist = fpr * fpr + (1 - tpr) * (1 - tpr)
      if dist < best_roc_dist:
        best_roc_dist = dist
        best_roc_index = len(roc_points)
        best_roc_threshold = midpoint

      if precision + tpr > 0:
        f_score = (1 + beta**2) * precision * tpr / ((beta**2 * precision) + tpr)
        if f_score > best_pr_f:
          best_pr_f = f_score
          best_pr_index = len(pr_points)
          best_pr_threshold = midpoint

      roc_points.append((fpr, tpr))
      pr_points.append((tpr, precision))
      thresholds.append(score)
      previous = score

    if entry[1]:
      true_pos += 1.0
    else:
      false_pos += 1.0

  # closing points of both curves
  thresholds.append(0)
  roc_points.append((1, 1))
  pr_points.append((1, 0))

  roc_auc = _computeAUC(roc_points)
  pr_auc = _computeAUC(pr_points)

  # both F-beta scores are read on the PR curve (pr_points is parallel to roc_points)
  roc_fbeta = _computeFbeta(pr_points, best_roc_index, beta)
  pr_fbeta = _computeFbeta(pr_points, best_pr_index, beta)

  return (
    roc_points,
    best_roc_index,
    best_roc_threshold,
    roc_auc,
    roc_fbeta,
    pr_points,
    best_pr_index,
    best_pr_threshold,
    pr_auc,
    pr_fbeta,
    thresholds,
  )
316
+
317
+
318
def getROCpoints(bn, datasrc, target, label, with_labels=True, significant_digits=10):
  """
  Compute the points of the ROC curve

  Parameters
  ----------
  bn : pyagrum.BayesNet
    a Bayesian network
  datasrc : str | DataFrame
    a csv filename or a DataFrame (pandas: ``to_csv``; polars: ``write_csv``)
  target : str
    the target
  label : str
    the target's label
  with_labels: bool
    whether we use label or id (especially for parameter label)
  significant_digits:
    number of significant digits when computing probabilities

  Returns
  -------
  List[Tuple[float,float]]
    the list of points (FalsePositiveRate, TruePositiveRate)

  Raises
  ------
  TypeError
    if ``datasrc`` is neither a filename nor a DataFrame
  """
  if not isinstance(datasrc, str):
    if not (hasattr(datasrc, "to_csv") or hasattr(datasrc, "write_csv")):
      raise TypeError("datasrc must be a csv filename (str) or a DataFrame")
    import tempfile

    # one temporary file whose name ends with CSV_TMP_SUFFIX (recognized by _getFilename);
    # it is always removed, even if the computation fails
    fd, csvfilename = tempfile.mkstemp(suffix=CSV_TMP_SUFFIX)
    os.close(fd)
    try:
      if hasattr(datasrc, "to_csv"):  # pandas
        datasrc.to_csv(csvfilename, na_rep="?", index=False)
      else:  # polars: write_csv has no `na_rep`/`index` arguments
        datasrc.write_csv(csvfilename, null_value="?")
      return getROCpoints(bn, csvfilename, target, label, with_labels=with_labels, significant_digits=significant_digits)
    finally:
      os.remove(csvfilename)

  res, totalP, totalN = _computePoints(
    bn, datasrc, target, label, show_progress=False, with_labels=with_labels, significant_digits=significant_digits
  )
  # first element of the tuple: the ROC points
  return _computeROC_PR(res, totalP, totalN, beta=1)[0]
380
+
381
+
382
def getPRpoints(bn, datasrc, target, label, with_labels=True, significant_digits=10):
  """
  Compute the points of the PR curve

  Parameters
  ----------
  bn : pyagrum.BayesNet
    a Bayesian network
  datasrc : str|DataFrame
    a csv filename or a DataFrame (pandas: ``to_csv``; polars: ``write_csv``)
  target : str
    the target
  label : str
    the target's label
  with_labels: bool
    whether we use label or id (especially for parameter label)
  significant_digits:
    number of significant digits when computing probabilities

  Returns
  -------
  List[Tuple[float,float]]
    the list of points (recall, precision)

  Raises
  ------
  TypeError
    if ``datasrc`` is neither a filename nor a DataFrame
  """
  if not isinstance(datasrc, str):
    if not (hasattr(datasrc, "to_csv") or hasattr(datasrc, "write_csv")):
      raise TypeError("datasrc must be a csv filename (str) or a DataFrame")
    import tempfile

    # one temporary file whose name ends with CSV_TMP_SUFFIX (recognized by _getFilename);
    # it is always removed, even if the computation fails
    fd, csvfilename = tempfile.mkstemp(suffix=CSV_TMP_SUFFIX)
    os.close(fd)
    try:
      if hasattr(datasrc, "to_csv"):  # pandas
        datasrc.to_csv(csvfilename, na_rep="?", index=False)
      else:  # polars: write_csv has no `na_rep`/`index` arguments
        datasrc.write_csv(csvfilename, null_value="?")
      return getPRpoints(bn, csvfilename, target, label, with_labels=with_labels, significant_digits=significant_digits)
    finally:
      os.remove(csvfilename)

  res, totalP, totalN = _computePoints(
    bn,
    datasrc,
    target,
    label,
    show_progress=False,
    with_labels=with_labels,
    significant_digits=significant_digits,
  )
  # sixth element of the tuple: the PR points
  return _computeROC_PR(res, totalP, totalN, beta=1)[5]
451
+
452
+
453
def _getPoint(threshold: float, thresholds: List[float], points: List[Tuple[float, float]]) -> Tuple[float, float]:
  """
  Locate on a curve the point corresponding to ``threshold``.

  ``thresholds`` is parallel to ``points`` and assumed sorted in decreasing order (as
  built by ``_computeROC_PR``). A binary search finds an index ``ind`` bracketing
  ``threshold``.

  Parameters
  ----------
  threshold : float
    the threshold to find
  thresholds: list[float]
    the (decreasing) list of thresholds
  points : list[tuple]
    the list of points

  Returns
  -------
  tuple
    ``points[ind]`` when ``ind`` is the last point; otherwise the middle of
    ``points[ind]`` and ``points[ind + 1]`` (the threshold lies between them).
  """
  low, high = 0, len(thresholds)
  while True:
    middle = (low + high) // 2
    if middle == low or thresholds[middle] == threshold:
      ind = middle
      break
    if thresholds[middle] > threshold:
      # decreasing order: smaller thresholds are further right
      low = middle
    else:
      high = middle

  if ind == len(points) - 1:
    return points[ind]

  here, after = points[ind], points[ind + 1]
  return (here[0] + after[0]) / 2, (here[1] + after[1]) / 2
489
+
490
+
491
def _basicDraw(
  ax,
  points,
  thresholds,
  fbeta,
  beta,
  AUC,
  main_color,
  secondary_color,
  last_color="black",
  thresholds_to_show=None,
  align_threshold="left",
):
  """
  Draw a ROC or PR curve, its AUC and its F-beta score into ``ax``.

  Parameters
  ----------
  ax : matplotlib.axes.Axes
    the axes to draw into (all drawing, including the frame, goes to this axes)
  points : list[tuple[float, float]]
    the (x, y) points of the curve
  thresholds : list[float]
    the threshold associated to each point (parallel to ``points``)
  fbeta : float
    the F-beta score to display
  beta : float
    the beta of the F-beta score (displayed as "F1" when beta == 1)
  AUC : float
    the area under the curve to display
  main_color : str
    color of the first threshold of ``thresholds_to_show`` and of the AUC/F-beta text
  secondary_color : str
    color of the second threshold of ``thresholds_to_show``
  last_color : str
    color of the other thresholds of ``thresholds_to_show``
  thresholds_to_show : list[float] | None
    thresholds whose point is marked and annotated on the curve
  align_threshold : str
    "left" or "right": alignment of the threshold annotations; the AUC/F-beta text is
    put on the opposite side
  """
  xs = [p[0] for p in points]
  ys = [p[1] for p in points]

  ax.grid(color="#aaaaaa", linestyle="-", linewidth=1, alpha=0.5)
  ax.plot(xs, ys, "-", linewidth=3, color=gum.config["ROC", "draw_color"], zorder=3)
  ax.fill_between(xs, ys, 0, color=gum.config["ROC", "fill_color"])

  ax.set_ylim((-0.01, 1.01))
  ax.set_xlim((-0.01, 1.01))
  ax.set_xticks(pylab.arange(0, 1.1, 0.1))
  ax.set_yticks(pylab.arange(0, 1.1, 0.1))
  ax.grid(True)

  # frame the unit square on `ax` itself (pylab.gca() may be another subplot)
  ax.add_patch(pylab.Rectangle((0, 0), 1, 1, edgecolor="#444444", facecolor="none", zorder=1))
  for spine in ax.spines.values():
    spine.set_visible(False)

  # with few points, mark the intermediate ones
  if len(points) < 10:
    for point in points[1:-1]:
      ax.plot(point[0], point[1], "o", color="#55DD55", zorder=6)

  def _show_point_from_thresh(thresh, col, shape):
    fontsize = 10 if shape == "o" else 7
    inc_threshold = 0.01 if align_threshold == "left" else -0.01
    point = _getPoint(thresh, thresholds, points)
    ax.plot(point[0], point[1], shape, color=col, zorder=6)
    ax.text(
      point[0] + inc_threshold,
      point[1] - 0.01,
      f"{thresh:.4f}",
      {"color": col, "fontsize": fontsize},
      horizontalalignment=align_threshold,
      verticalalignment="top",
      rotation=0,
      clip_on=False,
    )

  if thresholds_to_show:
    for rank, thresh in enumerate(thresholds_to_show):
      if rank == 0:
        _show_point_from_thresh(thresh, main_color, shape="o")
      elif rank == 1:
        _show_point_from_thresh(thresh, secondary_color, shape=".")
      else:
        _show_point_from_thresh(thresh, last_color, shape=".")

  # the AUC/F-beta text goes on the side opposite to the threshold annotations
  if align_threshold == "left":
    AUC_x, AUC_halign = 0.95, "right"
  else:
    AUC_x, AUC_halign = 0.05, "left"

  score_name = "F1" if beta == 1 else f"F-{beta:g}"
  ax.text(
    AUC_x,
    0.0,
    f"AUC={AUC:.4f}\n{score_name}={fbeta:.4f}",
    {"color": main_color, "fontsize": 18},
    horizontalalignment=AUC_halign,
    verticalalignment="bottom",
  )
578
+
579
+
580
def _drawROC(points, zeTitle, fbeta_ROC, beta, AUC_ROC, thresholds, thresholds_to_show, ax=None):
  """
  Draw a ROC curve, the diagonal y = x and the axis labels.

  Parameters
  ----------
  points : list[tuple[float, float]]
    the (false positive rate, true positive rate) points
  zeTitle : str
    the title of the axes
  fbeta_ROC : float
    the F-beta score to display
  beta : float
    the beta of the F-beta score
  AUC_ROC : float
    the area under the ROC curve
  thresholds : list[float]
    the thresholds parallel to ``points``
  thresholds_to_show : list[float]
    the thresholds to mark on the curve
  ax : matplotlib.axes.Axes, optional
    the axes to draw into (current axes if not given)
  """
  target_ax = ax if ax else pylab.gca()

  _basicDraw(
    target_ax,
    points,
    thresholds,
    fbeta=fbeta_ROC,
    beta=beta,
    AUC=AUC_ROC,
    main_color="#DD5555",
    secondary_color="#120af7",
    thresholds_to_show=thresholds_to_show,
    align_threshold="left",
  )

  # the diagonal y = x as a reference line
  target_ax.plot([0.0, 1.0], [0.0, 1.0], "-", color="#AAAAAA")
  target_ax.set_xlabel("False positive rate")
  target_ax.set_ylabel("True positive rate")
  target_ax.set_title(zeTitle)
600
+
601
+
602
def _drawPR(points, zeTitle, fbeta_PR, beta, AUC_PR, thresholds, thresholds_to_show, rate, ax=None):
  """
  Draw a Precision-Recall curve, a horizontal baseline at ``rate`` and the axis labels.

  Parameters
  ----------
  points : list[tuple[float, float]]
    the (recall, precision) points
  zeTitle : str
    the title of the axes
  fbeta_PR : float
    the F-beta score to display
  beta : float
    the beta of the F-beta score
  AUC_PR : float
    the area under the PR curve
  thresholds : list[float]
    the thresholds parallel to ``points``
  thresholds_to_show : list[float]
    the thresholds to mark on the curve
  rate : float
    height of the horizontal baseline (``showROC_PR`` passes the proportion of positives)
  ax : matplotlib.axes.Axes, optional
    the axes to draw into (current axes if not given)
  """
  target_ax = ax if ax else pylab.gca()

  _basicDraw(
    target_ax,
    points,
    thresholds,
    fbeta=fbeta_PR,
    beta=beta,
    AUC=AUC_PR,
    main_color="#120af7",
    secondary_color="#DD5555",
    thresholds_to_show=thresholds_to_show,
    align_threshold="right",
  )

  target_ax.plot([0.0, 1.0], [rate, rate], "-", color="#AAAAAA")
  target_ax.set_xlabel("Recall")
  target_ax.set_ylabel("Precision")
  target_ax.set_title(zeTitle)
622
+
623
+
624
def showROC_PR(
  bn,
  datasrc,
  target,
  label,
  *,
  beta=1,
  show_progress=True,
  show_fig=True,
  save_fig=False,
  with_labels=True,
  show_ROC=True,
  show_PR=True,
  significant_digits=10,
  bgcolor=None,
):
  """
  Compute the ROC and/or Precision-Recall curves of ``bn`` on ``datasrc``, then plot and/or save them.

  With ``save_fig``, the figure is saved as
  ``<datasrc>-ROCandPR_<bn name>-<target>-<label>.png`` (``-ROC_`` / ``-PR_`` when only one
  curve is drawn), i.e. next to the csv file, or as ``dataframe-...`` in the current
  directory when ``datasrc`` is a DataFrame.

  Parameters
  ----------
  bn : pyagrum.BayesNet
    a Bayesian network
  datasrc : str|DataFrame
    a csv filename or a DataFrame (pandas: ``to_csv``; polars: ``write_csv``)
  target : str
    the target
  label : str
    the target label
  beta : float
    the value of beta for the F-beta score
  show_progress : bool
    indicates if the progress bar must be printed
  save_fig:
    save the result
  show_fig:
    plot the results
  with_labels:
    labels in csv
  show_ROC: bool
    whether we show the ROC figure
  show_PR: bool
    whether we show the PR figure
  significant_digits:
    number of significant digits when computing probabilities
  bgcolor:
    HTML background color for the figure (default: None if transparent)

  Returns
  -------
  tuple
    (AUC_ROC, thresholdROC, AUC_PR, thresholdPR), for a filename as well as for a DataFrame

  Raises
  ------
  TypeError
    if ``datasrc`` is neither a filename nor a DataFrame
  """
  if not isinstance(datasrc, str):
    if not (hasattr(datasrc, "to_csv") or hasattr(datasrc, "write_csv")):
      raise TypeError("datasrc must be a csv filename (str) or a DataFrame")
    import tempfile

    # one temporary file whose name ends with CSV_TMP_SUFFIX (recognized by _getFilename);
    # it is always removed, even if the computation fails
    fd, csvfilename = tempfile.mkstemp(suffix=CSV_TMP_SUFFIX)
    os.close(fd)
    try:
      if hasattr(datasrc, "to_csv"):  # pandas
        datasrc.to_csv(csvfilename, na_rep="?", index=False)
      else:  # polars: write_csv has no `na_rep`/`index` arguments
        datasrc.write_csv(csvfilename, null_value="?")
      return showROC_PR(
        bn,
        csvfilename,
        target,
        label,
        beta=beta,
        show_progress=show_progress,
        show_fig=show_fig,
        save_fig=save_fig,
        with_labels=with_labels,
        show_ROC=show_ROC,
        show_PR=show_PR,
        significant_digits=significant_digits,
        bgcolor=bgcolor,
      )
    finally:
      os.remove(csvfilename)

  if bgcolor is not None:
    oldcol = gum.config["notebook", "figure_facecolor"]
    gum.config["notebook", "figure_facecolor"] = bgcolor

  try:
    filename = _getFilename(datasrc)
    (res, totalP, totalN) = _computePoints(
      bn,
      datasrc,
      target,
      label,
      beta=beta,
      show_progress=show_progress,
      with_labels=with_labels,
      significant_digits=significant_digits,
    )
    (
      pointsROC,
      _,
      thresholdROC,
      AUC_ROC,
      fbeta_ROC,
      pointsPR,
      _,
      thresholdPR,
      AUC_PR,
      fbeta_PR,
      thresholds,
    ) = _computeROC_PR(res, totalP, totalN, beta)

    try:
      shortname = os.path.basename(bn.property("name"))
    except gum.NotFound:
      shortname = "unnamed"
    title = shortname + " vs " + filename + " - " + target + "=" + str(label)

    # proportion of positives: baseline of the PR curve
    rate = totalP / (totalP + totalN)

    figname = None  # stays None when no curve is requested: nothing to save
    if show_ROC and show_PR:
      figname = f"{filename}-ROCandPR_{shortname}-{target}-{label}.png"
      fig = pylab.figure(figsize=(10, 4))
      fig.suptitle(title)
      pylab.gcf().subplots_adjust(wspace=0.1)

      ax1 = fig.add_subplot(1, 2, 1)
      _drawROC(
        points=pointsROC,
        zeTitle="ROC",
        fbeta_ROC=fbeta_ROC,
        beta=beta,
        AUC_ROC=AUC_ROC,
        thresholds=thresholds,
        thresholds_to_show=[thresholdROC, thresholdPR],
        ax=ax1,
      )

      ax2 = fig.add_subplot(1, 2, 2)
      ax2.yaxis.tick_right()
      ax2.yaxis.set_label_position("right")
      _drawPR(
        points=pointsPR,
        zeTitle="Precision-Recall",
        fbeta_PR=fbeta_PR,
        beta=beta,
        AUC_PR=AUC_PR,
        thresholds=thresholds,
        thresholds_to_show=[thresholdPR, thresholdROC],
        rate=rate,
        ax=ax2,
      )
    elif show_ROC:
      figname = f"{filename}-ROC_{shortname}-{target}-{label}.png"
      _drawROC(
        points=pointsROC,
        zeTitle=title,
        fbeta_ROC=fbeta_ROC,
        beta=beta,
        AUC_ROC=AUC_ROC,
        thresholds=thresholds,
        thresholds_to_show=[thresholdROC],
      )
    elif show_PR:
      figname = f"{filename}-PR_{shortname}-{target}-{label}.png"
      _drawPR(
        points=pointsPR,
        zeTitle=title,
        fbeta_PR=fbeta_PR,
        beta=beta,
        AUC_PR=AUC_PR,
        thresholds=thresholds,
        thresholds_to_show=[thresholdPR],
        rate=rate,
      )

    if save_fig and figname is not None:
      pylab.savefig(figname, dpi=300, transparent=(bgcolor is None))

    if show_fig:
      pylab.show()

    return AUC_ROC, thresholdROC, AUC_PR, thresholdPR
  finally:
    # the user's configuration is restored even if something fails
    if bgcolor is not None:
      gum.config["notebook", "figure_facecolor"] = oldcol
814
+
815
+
816
def showROC(
  bn,
  datasrc,
  target,
  label,
  show_progress=True,
  show_fig=True,
  save_fig=False,
  with_labels=True,
  significant_digits=10,
  *,
  beta=1,
  bgcolor=None,
):
  """
  Compute the ROC curve and save the result in the folder of the csv file.

  Parameters
  ----------
  bn : pyagrum.BayesNet
    a Bayesian network
  datasrc : str|DataFrame
    a csv filename or a pandas.DataFrame
  target : str
    the target
  label : str
    the target label
  show_progress : bool
    indicates if the progress bar must be printed
  save_fig:
    save the result
  show_fig:
    plot the results
  with_labels:
    labels in csv
  significant_digits:
    number of significant digits when computing probabilities
  beta : float
    the value of beta for the F-beta score displayed on the curve (keyword-only)
  bgcolor:
    HTML background color for the figure (default: None if transparent) (keyword-only)

  Returns
  -------
  tuple
    the value returned by ``showROC_PR`` (AUC and best thresholds of the ROC and PR curves)
  """
  return showROC_PR(
    bn,
    datasrc,
    target,
    label,
    beta=beta,
    show_progress=show_progress,
    show_fig=show_fig,
    save_fig=save_fig,
    with_labels=with_labels,
    show_ROC=True,
    show_PR=False,
    significant_digits=significant_digits,
    bgcolor=bgcolor,
  )
857
+
858
+
859
def showPR(
  bn,
  datasrc,
  target,
  label,
  *,
  beta=1,
  show_progress=True,
  show_fig=True,
  save_fig=False,
  with_labels=True,
  significant_digits=10,
  bgcolor=None,
):
  """
  Compute the Precision-Recall curve and save the result in the folder of the csv file.

  Parameters
  ----------
  bn : pyagrum.BayesNet
    a Bayesian network
  datasrc : str|DataFrame
    a csv filename or a pandas.DataFrame
  target : str
    the target
  label : str
    the target label
  beta : float
    the value of beta for the F-beta score
  show_progress : bool
    indicates if the progress bar must be printed
  save_fig:
    save the result ?
  show_fig:
    plot the results ?
  with_labels:
    labels in csv ?
  significant_digits:
    number of significant digits when computing probabilities
  bgcolor:
    HTML background color for the figure (default: None if transparent)

  Returns
  -------
  tuple
    the value returned by ``showROC_PR`` (AUC and best thresholds of the ROC and PR curves)
  """
  return showROC_PR(
    bn,
    datasrc,
    target,
    label,
    beta=beta,
    show_progress=show_progress,
    show_fig=show_fig,
    save_fig=save_fig,
    with_labels=with_labels,
    show_ROC=False,
    show_PR=True,
    significant_digits=significant_digits,
    bgcolor=bgcolor,
  )
911
+
912
+
913
def animROC(bn, datasrc, target="Y", label="1"):
  """
  Interactive exploration of the ROC points (FPR and TPR) of a BN on data.

  The slider "rate" (0-100 %) selects a point by its relative position in the list of
  ROC points; the upper plot shows FPR (green) and TPR (red) along that list, the lower
  plot the values of the selected point.

  Parameters
  ----------
  bn : pyagrum.BayesNet
    a Bayesian network
  datasrc : str|DataFrame
    a csv filename or a pandas.DataFrame
  target : str
    the target
  label : str
    the target label

  Returns
  -------
  ipywidgets.interactive
    the interactive widget
  """
  import ipywidgets as widgets
  import matplotlib.pyplot as plt
  import matplotlib.ticker as mtick

  class DisplayROC:
    def __init__(self, points):
      self._points = points
      # x in [0, 1] so that the last point lies at 100 %, aligned with the slider marker
      last = max(len(points) - 1, 1)
      self._x = [i / last for i in range(len(points))]
      self._y1, self._y2 = zip(*points)  # FPR, TPR

    def display(self, threshold):
      rate = threshold / 100.0
      index = int((len(self._points) - 1) * rate)
      fpr, tpr = self._points[index]

      # per-figure size: do not modify the global matplotlib rcParams
      _, (ax1, ax2) = plt.subplots(nrows=2, figsize=(4, 3))
      ax1.plot(self._x, self._y1, "g")
      ax1.plot(self._x, self._y2, "r")
      ax1.plot([rate, rate], [0, 1])
      ax1.xaxis.set_major_formatter(mtick.PercentFormatter(1.0))

      ax2.barh([0, 1], [fpr, tpr], color=["g", "r"])
      ax2.set_yticks(ticks=[0, 1], labels=["FPR", "TPR"])
      ax2.annotate(f" {fpr:.1%}", xy=(1, 0), xytext=(1, -0.2))
      ax2.annotate(f" {tpr:.1%}", xy=(1, 1), xytext=(1, 0.8))
      ax2.set_xlim(0, 1)

      plt.tight_layout()
      plt.show()

  viewer = DisplayROC(getROCpoints(bn, datasrc, target=target, label=label))

  def interactive_view(rate: float):
    viewer.display(rate)

  interactive_plot = widgets.interactive(interactive_view, rate=(0, 100, 1))
  output = interactive_plot.children[-1]
  output.layout.height = "250px"
  return interactive_plot
969
+
970
+
971
def animPR(bn, datasrc, target="Y", label="1"):
  """
  Interactive exploration of the Precision-Recall points of a BN on data.

  The slider "rate" (0-100 %) selects a point by its relative position in the list of
  PR points; the upper plot shows recall (red) and precision (green) along that list,
  the lower plot the values of the selected point.

  Parameters
  ----------
  bn : pyagrum.BayesNet
    a Bayesian network
  datasrc : str|DataFrame
    a csv filename or a pandas.DataFrame
  target : str
    the target
  label : str
    the target label

  Returns
  -------
  ipywidgets.interactive
    the interactive widget
  """
  import ipywidgets as widgets
  import matplotlib.pyplot as plt
  import matplotlib.ticker as mtick

  class DisplayPR:
    def __init__(self, points):
      self._points = points
      # x in [0, 1] so that the last point lies at 100 %, aligned with the slider marker
      last = max(len(points) - 1, 1)
      self._x = [i / last for i in range(len(points))]
      self._y1, self._y2 = zip(*points)  # recall, precision

    def display(self, threshold):
      rate = threshold / 100.0
      index = int((len(self._points) - 1) * rate)
      recall, precision = self._points[index]

      # per-figure size: do not modify the global matplotlib rcParams
      _, (ax1, ax2) = plt.subplots(nrows=2, figsize=(4, 3))
      ax1.plot(self._x, self._y1, "r")
      ax1.plot(self._x, self._y2, "g")
      ax1.plot([rate, rate], [0, 1])
      ax1.xaxis.set_major_formatter(mtick.PercentFormatter(1.0))

      ax2.barh([1, 0], [recall, precision], color=["r", "g"])
      ax2.set_yticks(ticks=[0, 1], labels=["Precision", "Recall"])
      ax2.annotate(f" {precision:.1%}", xy=(1, 0), xytext=(1, -0.2))
      ax2.annotate(f" {recall:.1%}", xy=(1, 1), xytext=(1, 0.8))
      ax2.set_xlim(0, 1)

      plt.tight_layout()
      plt.show()

  viewer = DisplayPR(getPRpoints(bn, datasrc, target=target, label=label))

  def interactive_view(rate: float):
    viewer.display(rate)

  interactive_plot = widgets.interactive(interactive_view, rate=(0, 100, 1))
  output = interactive_plot.children[-1]
  output.layout.height = "250px"
  return interactive_plot