pyAgrum-nightly 2.3.0.9.dev202512061764412981__cp310-abi3-macosx_11_0_arm64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- pyagrum/__init__.py +165 -0
- pyagrum/_pyagrum.so +0 -0
- pyagrum/bnmixture/BNMInference.py +268 -0
- pyagrum/bnmixture/BNMLearning.py +376 -0
- pyagrum/bnmixture/BNMixture.py +464 -0
- pyagrum/bnmixture/__init__.py +60 -0
- pyagrum/bnmixture/notebook.py +1058 -0
- pyagrum/causal/_CausalFormula.py +280 -0
- pyagrum/causal/_CausalModel.py +436 -0
- pyagrum/causal/__init__.py +81 -0
- pyagrum/causal/_causalImpact.py +356 -0
- pyagrum/causal/_dSeparation.py +598 -0
- pyagrum/causal/_doAST.py +761 -0
- pyagrum/causal/_doCalculus.py +361 -0
- pyagrum/causal/_doorCriteria.py +374 -0
- pyagrum/causal/_exceptions.py +95 -0
- pyagrum/causal/_types.py +61 -0
- pyagrum/causal/causalEffectEstimation/_CausalEffectEstimation.py +1175 -0
- pyagrum/causal/causalEffectEstimation/_IVEstimators.py +718 -0
- pyagrum/causal/causalEffectEstimation/_RCTEstimators.py +132 -0
- pyagrum/causal/causalEffectEstimation/__init__.py +46 -0
- pyagrum/causal/causalEffectEstimation/_backdoorEstimators.py +774 -0
- pyagrum/causal/causalEffectEstimation/_causalBNEstimator.py +324 -0
- pyagrum/causal/causalEffectEstimation/_frontdoorEstimators.py +396 -0
- pyagrum/causal/causalEffectEstimation/_learners.py +118 -0
- pyagrum/causal/causalEffectEstimation/_utils.py +466 -0
- pyagrum/causal/notebook.py +171 -0
- pyagrum/clg/CLG.py +658 -0
- pyagrum/clg/GaussianVariable.py +111 -0
- pyagrum/clg/SEM.py +312 -0
- pyagrum/clg/__init__.py +63 -0
- pyagrum/clg/canonicalForm.py +408 -0
- pyagrum/clg/constants.py +54 -0
- pyagrum/clg/forwardSampling.py +202 -0
- pyagrum/clg/learning.py +776 -0
- pyagrum/clg/notebook.py +480 -0
- pyagrum/clg/variableElimination.py +271 -0
- pyagrum/common.py +60 -0
- pyagrum/config.py +319 -0
- pyagrum/ctbn/CIM.py +513 -0
- pyagrum/ctbn/CTBN.py +573 -0
- pyagrum/ctbn/CTBNGenerator.py +216 -0
- pyagrum/ctbn/CTBNInference.py +459 -0
- pyagrum/ctbn/CTBNLearner.py +161 -0
- pyagrum/ctbn/SamplesStats.py +671 -0
- pyagrum/ctbn/StatsIndepTest.py +355 -0
- pyagrum/ctbn/__init__.py +79 -0
- pyagrum/ctbn/constants.py +54 -0
- pyagrum/ctbn/notebook.py +264 -0
- pyagrum/defaults.ini +199 -0
- pyagrum/deprecated.py +95 -0
- pyagrum/explain/_ComputationCausal.py +75 -0
- pyagrum/explain/_ComputationConditional.py +48 -0
- pyagrum/explain/_ComputationMarginal.py +48 -0
- pyagrum/explain/_CustomShapleyCache.py +110 -0
- pyagrum/explain/_Explainer.py +176 -0
- pyagrum/explain/_Explanation.py +70 -0
- pyagrum/explain/_FIFOCache.py +54 -0
- pyagrum/explain/_ShallCausalValues.py +204 -0
- pyagrum/explain/_ShallConditionalValues.py +155 -0
- pyagrum/explain/_ShallMarginalValues.py +155 -0
- pyagrum/explain/_ShallValues.py +296 -0
- pyagrum/explain/_ShapCausalValues.py +208 -0
- pyagrum/explain/_ShapConditionalValues.py +126 -0
- pyagrum/explain/_ShapMarginalValues.py +191 -0
- pyagrum/explain/_ShapleyValues.py +298 -0
- pyagrum/explain/__init__.py +81 -0
- pyagrum/explain/_explGeneralizedMarkovBlanket.py +152 -0
- pyagrum/explain/_explIndependenceListForPairs.py +146 -0
- pyagrum/explain/_explInformationGraph.py +264 -0
- pyagrum/explain/notebook/__init__.py +54 -0
- pyagrum/explain/notebook/_bar.py +142 -0
- pyagrum/explain/notebook/_beeswarm.py +174 -0
- pyagrum/explain/notebook/_showShapValues.py +97 -0
- pyagrum/explain/notebook/_waterfall.py +220 -0
- pyagrum/explain/shapley.py +225 -0
- pyagrum/lib/__init__.py +46 -0
- pyagrum/lib/_colors.py +390 -0
- pyagrum/lib/bn2graph.py +299 -0
- pyagrum/lib/bn2roc.py +1026 -0
- pyagrum/lib/bn2scores.py +217 -0
- pyagrum/lib/bn_vs_bn.py +605 -0
- pyagrum/lib/cn2graph.py +305 -0
- pyagrum/lib/discreteTypeProcessor.py +1102 -0
- pyagrum/lib/discretizer.py +58 -0
- pyagrum/lib/dynamicBN.py +390 -0
- pyagrum/lib/explain.py +57 -0
- pyagrum/lib/export.py +84 -0
- pyagrum/lib/id2graph.py +258 -0
- pyagrum/lib/image.py +387 -0
- pyagrum/lib/ipython.py +307 -0
- pyagrum/lib/mrf2graph.py +471 -0
- pyagrum/lib/notebook.py +1821 -0
- pyagrum/lib/proba_histogram.py +552 -0
- pyagrum/lib/utils.py +138 -0
- pyagrum/pyagrum.py +31495 -0
- pyagrum/skbn/_MBCalcul.py +242 -0
- pyagrum/skbn/__init__.py +49 -0
- pyagrum/skbn/_learningMethods.py +282 -0
- pyagrum/skbn/_utils.py +297 -0
- pyagrum/skbn/bnclassifier.py +1014 -0
- pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/LICENSE.md +12 -0
- pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/LICENSES/LGPL-3.0-or-later.txt +304 -0
- pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/LICENSES/MIT.txt +18 -0
- pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/METADATA +145 -0
- pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/RECORD +107 -0
- pyagrum_nightly-2.3.0.9.dev202512061764412981.dist-info/WHEEL +4 -0
pyagrum/lib/bn2roc.py
ADDED
|
@@ -0,0 +1,1026 @@
|
|
|
1
|
+
############################################################################
|
|
2
|
+
# This file is part of the aGrUM/pyAgrum library. #
|
|
3
|
+
# #
|
|
4
|
+
# Copyright (c) 2005-2025 by #
|
|
5
|
+
# - Pierre-Henri WUILLEMIN(_at_LIP6) #
|
|
6
|
+
# - Christophe GONZALES(_at_AMU) #
|
|
7
|
+
# #
|
|
8
|
+
# The aGrUM/pyAgrum library is free software; you can redistribute it #
|
|
9
|
+
# and/or modify it under the terms of either : #
|
|
10
|
+
# #
|
|
11
|
+
# - the GNU Lesser General Public License as published by #
|
|
12
|
+
# the Free Software Foundation, either version 3 of the License, #
|
|
13
|
+
# or (at your option) any later version, #
|
|
14
|
+
# - the MIT license (MIT), #
|
|
15
|
+
# - or both in dual license, as here. #
|
|
16
|
+
# #
|
|
17
|
+
# (see https://agrum.gitlab.io/articles/dual-licenses-lgplv3mit.html) #
|
|
18
|
+
# #
|
|
19
|
+
# This aGrUM/pyAgrum library is distributed in the hope that it will be #
|
|
20
|
+
# useful, but WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, #
|
|
21
|
+
# INCLUDING BUT NOT LIMITED TO THE WARRANTIES MERCHANTABILITY or FITNESS #
|
|
22
|
+
# FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE #
|
|
23
|
+
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER #
|
|
24
|
+
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, #
|
|
25
|
+
# ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR #
|
|
26
|
+
# OTHER DEALINGS IN THE SOFTWARE. #
|
|
27
|
+
# #
|
|
28
|
+
# See LICENCES for more details. #
|
|
29
|
+
# #
|
|
30
|
+
# SPDX-FileCopyrightText: Copyright 2005-2025 #
|
|
31
|
+
# - Pierre-Henri WUILLEMIN(_at_LIP6) #
|
|
32
|
+
# - Christophe GONZALES(_at_AMU) #
|
|
33
|
+
# SPDX-License-Identifier: LGPL-3.0-or-later OR MIT #
|
|
34
|
+
# #
|
|
35
|
+
# Contact : info_at_agrum_dot_org #
|
|
36
|
+
# homepage : http://agrum.gitlab.io #
|
|
37
|
+
# gitlab : https://gitlab.com/agrumery/agrum #
|
|
38
|
+
# #
|
|
39
|
+
############################################################################
|
|
40
|
+
|
|
41
|
+
"""
|
|
42
|
+
The purpose of this module is to provide tools for building ROC and Precision-Recall (PR) curves from a Bayesian network.
|
|
43
|
+
"""
|
|
44
|
+
|
|
45
|
+
import os
|
|
46
|
+
from typing import List, Tuple
|
|
47
|
+
|
|
48
|
+
import numpy as np
|
|
49
|
+
|
|
50
|
+
from matplotlib import pylab
|
|
51
|
+
|
|
52
|
+
import pyagrum as gum
|
|
53
|
+
from pyagrum import skbn
|
|
54
|
+
|
|
55
|
+
CSV_TMP_SUFFIX = ".x.csv"
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def _getFilename(datasrc):
    """
    Return the name used to display a data source.

    A name ending with ``CSV_TMP_SUFFIX`` is the signature of a temporary csv
    file; such a source is displayed as "dataframe".

    Parameters
    ----------
    datasrc : str
      a csv filename

    Returns
    -------
    str
      "dataframe" for a temporary csv file, `datasrc` unchanged otherwise
    """
    return "dataframe" if datasrc.endswith(CSV_TMP_SUFFIX) else datasrc
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
def _lines_count(filename):
|
|
67
|
+
"""
|
|
68
|
+
Parameters
|
|
69
|
+
----------
|
|
70
|
+
filename : str
|
|
71
|
+
a filename
|
|
72
|
+
|
|
73
|
+
Returns
|
|
74
|
+
-------
|
|
75
|
+
int
|
|
76
|
+
the number of lines in the file
|
|
77
|
+
"""
|
|
78
|
+
numlines = 0
|
|
79
|
+
|
|
80
|
+
with open(filename) as f:
|
|
81
|
+
for _ in f.readlines():
|
|
82
|
+
numlines += 1
|
|
83
|
+
|
|
84
|
+
return numlines
|
|
85
|
+
|
|
86
|
+
|
|
87
|
+
def _checkCompatibility(bn, fields, datasrc):
|
|
88
|
+
"""
|
|
89
|
+
check if variables of the bn are in the fields
|
|
90
|
+
|
|
91
|
+
Parameters
|
|
92
|
+
----------
|
|
93
|
+
bn : pyagrum.BayesNet
|
|
94
|
+
a Bayesian network
|
|
95
|
+
fields : list
|
|
96
|
+
a list of fields
|
|
97
|
+
datasrc : str|DataFrame
|
|
98
|
+
a csv filename or a pandas.DataFrame
|
|
99
|
+
|
|
100
|
+
Returns
|
|
101
|
+
-------
|
|
102
|
+
list
|
|
103
|
+
a list of position for variables in fields, None otherwise.
|
|
104
|
+
"""
|
|
105
|
+
res = {}
|
|
106
|
+
isOK = True
|
|
107
|
+
for field in bn.names():
|
|
108
|
+
if field not in fields:
|
|
109
|
+
print(f"** field '{field}' is missing.")
|
|
110
|
+
isOK = False
|
|
111
|
+
else:
|
|
112
|
+
res[bn.idFromName(field)] = fields[field]
|
|
113
|
+
|
|
114
|
+
if not isOK:
|
|
115
|
+
res = None
|
|
116
|
+
|
|
117
|
+
return res
|
|
118
|
+
|
|
119
|
+
|
|
120
|
+
def _computeAUC(points):
|
|
121
|
+
"""
|
|
122
|
+
Given a set of points drawing a ROC/PR curve, compute the AUC value
|
|
123
|
+
|
|
124
|
+
Parameters
|
|
125
|
+
----------
|
|
126
|
+
points : list
|
|
127
|
+
a list of points
|
|
128
|
+
|
|
129
|
+
Returns
|
|
130
|
+
-------
|
|
131
|
+
double
|
|
132
|
+
the AUC value
|
|
133
|
+
|
|
134
|
+
"""
|
|
135
|
+
# computes the integral from 0 to 1
|
|
136
|
+
somme = 0
|
|
137
|
+
for i in range(1, len(points)):
|
|
138
|
+
somme += (points[i][0] - points[i - 1][0]) * (points[i - 1][1] + points[i][1])
|
|
139
|
+
|
|
140
|
+
return somme / 2
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
def _computeFbeta(points, ind, beta=1):
|
|
144
|
+
return (1 + beta**2) * points[ind][1] * points[ind][0] / ((beta**2 * points[ind][1]) + points[ind][0])
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def _computePoints(bn, datasrc, target, label, *, beta=1, show_progress=True, with_labels=True, significant_digits=10):
    """
    Compute, for each line of the data, the probability of the label and its actual class.

    Parameters
    ----------
    bn : pyagrum.BayesNet
      a Bayesian network
    datasrc : str
      a csv filename (it is opened to count its lines when `show_progress` is True)
    target : str
      the target
    label : str
      the target's label or id
    beta : float
      the value of beta for the F-beta score
    show_progress : bool
      indicates if a progress bar must be shown (tqdm is then required)
    with_labels: bool
      whether we use label or id (especially for parameter label)
    significant_digits:
      number of significant digits when computing probabilities

    Returns
    -------
    tuple (res, totalP, totalN)
      where res is a list of (proba, y) for each line of datasrc, totalP is the
      number of non-zero values in y and totalN the number of the other ones.

    Raises
    ------
    ValueError
      if `with_labels` is False and `label` is not a label of the target variable.
    """
    idTarget = bn.idFromName(target)
    label = str(label)

    if with_labels:
        idLabel = label
    else:
        # look for the position of the label in the domain of the target
        variable = bn.variable(idTarget)
        idLabel = -1
        for i in range(variable.domainSize()):
            if variable.label(i) == label:
                idLabel = i
                break
        if idLabel < 0:
            # explicit error rather than an assert, which is stripped with python -O
            raise ValueError(f"'{label}' is not a label of variable '{target}'")

    classifier = skbn.BNClassifier(beta=beta, significant_digit=significant_digits)

    pbar = None
    if show_progress:
        # tqdm is optional:
        # pylint: disable=import-outside-toplevel
        from tqdm import tqdm

        # the first line of the csv is not counted in the total
        pbar = tqdm(
            total=_lines_count(datasrc) - 1,
            desc=_getFilename(datasrc),
            bar_format="{desc}: {percentage:3.0f}%|{bar}|",
        )

    try:
        classifier.fromTrainedModel(bn, target, idLabel)
        # as a Binary classifier, y will be a list of True (good classification) and False (bad one)
        X, y = classifier.XYfromCSV(datasrc, with_labels=with_labels, target=target)
        predictions = classifier.predict_proba(X)

        totalP = np.count_nonzero(y)
        totalN = len(y) - totalP

        res = []
        for i in range(len(X)):
            # column 1 of the predictions is the probability paired with y[i]
            res.append((predictions[i][1], y[i]))

            if pbar is not None:
                pbar.update()
    finally:
        # the progress bar is closed even if the classifier raises
        if pbar is not None:
            pbar.close()

    return res, totalP, totalN
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
def _computeROC_PR(values, totalP, totalN, beta):
    """
    Compute the ROC and PR curves (points, best thresholds, AUC, F-beta) from scored samples.

    Parameters
    ----------
    values :
      the curve values : a sequence of (proba, isPositive), one for each sample
    totalP : int
      the number of positive values
    totalN : int
      the number of negative values
    beta : float
      the value of beta for the F-beta score

    Returns
    -------
    tuple
      (points_ROC, ind_ROC, threshold_ROC, AUC_ROC, fbeta_ROC,
      points_PR, ind_PR, threshold_PR, AUC_PR, fbeta_PR, thresholds)

    Raises
    ------
    IndexError
      if `values` is empty.

    Notes
    -----
    When `totalN` (resp. `totalP`) is 0, the false (resp. true) positive rate is
    undefined; it is set to 0.0 instead of raising ZeroDivisionError.
    """

    res = sorted(values, key=lambda t: t[0], reverse=True)

    vp = 0.0  # Number of True Positives
    fp = 0.0  # Number of False Positives

    ind_ROC = 0
    dmin_ROC = 100.0  # smallest squared distance to the corner (fpr=0, tpr=1) found so far
    threshopt_ROC = 0  # best threshold (euclidean distance)

    ind_PR = 0
    fmax_PR = 0.0  # largest F-beta found so far
    threshopt_PR = 0  # threshold of the largest F-beta

    pointsROC = [(0, 0)]  # first one
    pointsPR = [(0, 1)]
    thresholds = [1]

    old_threshold = res[0][0]
    for r_i in res:
        # we add a point only if the threshold has changed
        cur_threshold = r_i[0]
        if cur_threshold < old_threshold:
            fpr = fp / totalN if totalN else 0.0  # false positives rate
            tpr = vp / totalP if totalP else 0.0  # true positives rate and recall
            # at least one sample has been counted here, hence vp + fp > 0
            prec = vp / (vp + fp)  # precision

            # squared euclidean distance to know the best threshold
            d = fpr * fpr + (1 - tpr) * (1 - tpr)
            if d < dmin_ROC:
                dmin_ROC = d
                ind_ROC = len(pointsROC)
                threshopt_ROC = (cur_threshold + old_threshold) / 2

            if prec + tpr > 0:
                f = (1 + beta**2) * prec * tpr / ((beta**2 * prec) + tpr)

                if f > fmax_PR:
                    fmax_PR = f
                    ind_PR = len(pointsPR)
                    threshopt_PR = (cur_threshold + old_threshold) / 2

            pointsROC.append((fpr, tpr))
            pointsPR.append((tpr, prec))
            thresholds.append(cur_threshold)

            old_threshold = cur_threshold

        correct_prediction = r_i[1]
        if correct_prediction:
            vp += 1.0
        else:
            fp += 1.0

    # last ones
    thresholds.append(0)
    pointsROC.append((1, 1))
    pointsPR.append((1, 0))

    AUC_ROC = _computeAUC(pointsROC)
    AUC_PR = _computeAUC(pointsPR)

    # both F-beta scores are read on the PR curve, at the index of each best point
    fbeta_ROC = _computeFbeta(pointsPR, ind_ROC, beta)
    fbeta_PR = _computeFbeta(pointsPR, ind_PR, beta)

    return (
        pointsROC,
        ind_ROC,
        threshopt_ROC,
        AUC_ROC,
        fbeta_ROC,
        pointsPR,
        ind_PR,
        threshopt_PR,
        AUC_PR,
        fbeta_PR,
        thresholds,
    )
|
|
316
|
+
|
|
317
|
+
|
|
318
|
+
def getROCpoints(bn, datasrc, target, label, with_labels=True, significant_digits=10):
    """
    Compute the points of the ROC curve

    Parameters
    ----------
    bn : pyagrum.BayesNet
      a Bayesian network
    datasrc : str | DataFrame
      a csv filename or a DataFrame (any object providing `to_csv` or `write_csv`)
    target : str
      the target
    label : str
      the target's label
    with_labels: bool
      whether we use label or id (especially for parameter label)
    significant_digits:
      number of significant digits when computing probabilities

    Returns
    -------
    List[Tuple[float,float]]
      the list of points (FalsePositifRate,TruePositifRate)

    Raises
    ------
    TypeError
      if `datasrc` is neither a string nor a DataFrame.
    """
    if not isinstance(datasrc, str):
        has_to_csv = hasattr(datasrc, "to_csv")
        if not (has_to_csv or hasattr(datasrc, "write_csv")):
            raise TypeError("first argument must be a string or a DataFrame")

        import tempfile

        # mkstemp creates a single file whose name ends with CSV_TMP_SUFFIX
        fd, csvfilename = tempfile.mkstemp(suffix=CSV_TMP_SUFFIX)
        os.close(fd)
        try:
            if has_to_csv:
                datasrc.to_csv(csvfilename, na_rep="?", index=False)
            else:
                # NOTE(review): write_csv is presumably the polars API, which has no
                # na_rep/index parameters; missing values are set with null_value.
                datasrc.write_csv(csvfilename, null_value="?")

            return getROCpoints(
                bn, csvfilename, target, label, with_labels=with_labels, significant_digits=significant_digits
            )
        finally:
            # the temporary csv is removed even if the computation fails
            os.remove(csvfilename)

    (res, totalP, totalN) = _computePoints(
        bn, datasrc, target, label, show_progress=False, with_labels=with_labels, significant_digits=significant_digits
    )

    # only the ROC points (first item of the result) are needed here
    pointsROC, *_ = _computeROC_PR(res, totalP, totalN, beta=1)

    return pointsROC
|
|
380
|
+
|
|
381
|
+
|
|
382
|
+
def getPRpoints(bn, datasrc, target, label, with_labels=True, significant_digits=10):
    """
    Compute the points of the PR curve

    Parameters
    ----------
    bn : pyagrum.BayesNet
      a Bayesian network
    datasrc : str|DataFrame
      a csv filename or a DataFrame (any object providing `to_csv` or `write_csv`)
    target : str
      the target
    label : str
      the target's label
    with_labels: bool
      whether we use label or id (especially for parameter label)
    significant_digits:
      number of significant digits when computing probabilities

    Returns
    -------
    List[Tuple[float,float]]
      the list of points (recall,precision)

    Raises
    ------
    TypeError
      if `datasrc` is neither a string nor a DataFrame.
    """
    if not isinstance(datasrc, str):
        has_to_csv = hasattr(datasrc, "to_csv")
        if not (has_to_csv or hasattr(datasrc, "write_csv")):
            raise TypeError("first argument must be a string or a DataFrame")

        import tempfile

        # mkstemp creates a single file whose name ends with CSV_TMP_SUFFIX
        fd, csvfilename = tempfile.mkstemp(suffix=CSV_TMP_SUFFIX)
        os.close(fd)
        try:
            if has_to_csv:
                datasrc.to_csv(csvfilename, na_rep="?", index=False)
            else:
                # NOTE(review): write_csv is presumably the polars API, which has no
                # na_rep/index parameters; missing values are set with null_value.
                datasrc.write_csv(csvfilename, null_value="?")

            return getPRpoints(
                bn, csvfilename, target, label, with_labels=with_labels, significant_digits=significant_digits
            )
        finally:
            # the temporary csv is removed even if the computation fails
            os.remove(csvfilename)

    (res, totalP, totalN) = _computePoints(
        bn,
        datasrc,
        target,
        label,
        show_progress=False,
        with_labels=with_labels,
        significant_digits=significant_digits,
    )

    # only the PR points (sixth item of the result) are needed here
    (_, _, _, _, _, pointsPR, *_) = _computeROC_PR(res, totalP, totalN, beta=1)

    return pointsPR
|
|
451
|
+
|
|
452
|
+
|
|
453
|
+
def _getPoint(threshold: float, thresholds: List[float], points: List[Tuple[float, float]]) -> Tuple[float, float]:
|
|
454
|
+
"""
|
|
455
|
+
|
|
456
|
+
Find the point corresponding to threshold in points (annotated by thresholds)
|
|
457
|
+
|
|
458
|
+
Parameters
|
|
459
|
+
----------
|
|
460
|
+
threshold : float
|
|
461
|
+
the threshold to find
|
|
462
|
+
thresholds: list[float]
|
|
463
|
+
the list of thresholds
|
|
464
|
+
points : list[tuple]
|
|
465
|
+
the list of points
|
|
466
|
+
|
|
467
|
+
Returns
|
|
468
|
+
-------
|
|
469
|
+
the point corresponding to threshold
|
|
470
|
+
"""
|
|
471
|
+
|
|
472
|
+
def _dichot(mi, ma, tab, v):
|
|
473
|
+
mid = (mi + ma) // 2
|
|
474
|
+
if mid == mi:
|
|
475
|
+
return mi
|
|
476
|
+
|
|
477
|
+
if tab[mid] == v:
|
|
478
|
+
return mid
|
|
479
|
+
elif tab[mid] > v:
|
|
480
|
+
return _dichot(mid, ma, tab, v)
|
|
481
|
+
else:
|
|
482
|
+
return _dichot(mi, mid, tab, v)
|
|
483
|
+
|
|
484
|
+
ind = _dichot(0, len(thresholds), thresholds, threshold)
|
|
485
|
+
if ind == len(points) - 1:
|
|
486
|
+
return points[ind]
|
|
487
|
+
else: # a threshold is between 2 points
|
|
488
|
+
return (points[ind][0] + points[ind + 1][0]) / 2, (points[ind][1] + points[ind + 1][1]) / 2
|
|
489
|
+
|
|
490
|
+
|
|
491
|
+
def _basicDraw(
    ax,
    points,
    thresholds,
    fbeta,
    beta,
    AUC,
    main_color,
    secondary_color,
    last_color="black",
    thresholds_to_show=None,
    align_threshold="left",
):
    """
    Draw a curve (ROC or PR) in the unit square of `ax`, with its AUC and F-beta values.

    Parameters
    ----------
    ax : matplotlib axes
      the axes where the curve is drawn (frame and spines are set on these axes too)
    points : list[tuple]
      the points (x, y) of the curve
    thresholds : list[float]
      the thresholds annotating the points
    fbeta : float
      the F-beta value to display
    beta : float
      the value of beta for the F-beta score
    AUC : float
      the AUC value to display
    main_color :
      color of the text and of the first threshold to show
    secondary_color :
      color of the second threshold to show
    last_color :
      color of the other thresholds to show
    thresholds_to_show : list[float]
      the thresholds to highlight on the curve (None or empty : no threshold shown)
    align_threshold : str
      "left" or "right" : horizontal alignment of the text for the thresholds.
      The AUC text is put on the opposite side.
    """
    ax.grid(color="#aaaaaa", linestyle="-", linewidth=1, alpha=0.5)

    xs = [p[0] for p in points]
    ys = [p[1] for p in points]
    ax.plot(xs, ys, "-", linewidth=3, color=gum.config["ROC", "draw_color"], zorder=3)
    ax.fill_between(xs, ys, 0, color=gum.config["ROC", "fill_color"])

    ax.set_ylim((-0.01, 1.01))
    ax.set_xlim((-0.01, 1.01))
    ax.set_xticks(pylab.arange(0, 1.1, 0.1))
    ax.set_yticks(pylab.arange(0, 1.1, 0.1))
    ax.grid(True)

    # the frame is drawn on the given axes (and not on the current axes, which may differ)
    frame = pylab.Rectangle((0, 0), 1, 1, edgecolor="#444444", facecolor="none", zorder=1)
    ax.add_patch(frame)
    for spine in ax.spines.values():
        spine.set_visible(False)

    # with only a few points, the intermediate ones are marked
    if len(points) < 10:
        for point in points[1:-1]:
            ax.plot(point[0], point[1], "o", color="#55DD55", zorder=6)

    def _show_point_from_thresh(thresh, col, shape):
        fontsize = 10 if shape == "o" else 7
        inc_threshold = 0.01 if align_threshold == "left" else -0.01
        point = _getPoint(thresh, thresholds, points)
        ax.plot(point[0], point[1], shape, color=col, zorder=6)
        ax.text(
            point[0] + inc_threshold,
            point[1] - 0.01,
            f"{thresh:.4f}",
            {"color": col, "fontsize": fontsize},
            horizontalalignment=align_threshold,
            verticalalignment="top",
            rotation=0,
            clip_on=False,
        )

    # first threshold : main color and big marker, second : secondary color, others : last color
    for rank, thresh in enumerate(thresholds_to_show or ()):
        if rank == 0:
            _show_point_from_thresh(thresh, main_color, shape="o")
        elif rank == 1:
            _show_point_from_thresh(thresh, secondary_color, shape=".")
        else:
            _show_point_from_thresh(thresh, last_color, shape=".")

    if align_threshold == "left":
        AUC_x = 0.95
        AUC_halign = "right"
    else:
        AUC_x = 0.05
        AUC_halign = "left"

    score_name = "F1" if beta == 1 else f"F-{beta:g}"
    ax.text(
        AUC_x,
        0.0,
        f"AUC={AUC:.4f}\n{score_name}={fbeta:.4f}",
        {"color": main_color, "fontsize": 18},
        horizontalalignment=AUC_halign,
        verticalalignment="bottom",
        fontsize=18,
    )
|
|
578
|
+
|
|
579
|
+
|
|
580
|
+
def _drawROC(points, zeTitle, fbeta_ROC, beta, AUC_ROC, thresholds, thresholds_to_show, ax=None):
    """
    Draw a ROC curve.

    Parameters
    ----------
    points : list[tuple]
      the points of the ROC curve
    zeTitle : str
      the title of the figure
    fbeta_ROC : float
      the F-beta value to display
    beta : float
      the value of beta for the F-beta score
    AUC_ROC : float
      the AUC value to display
    thresholds : list[float]
      the thresholds annotating the points
    thresholds_to_show : list[float]
      the thresholds to highlight on the curve
    ax : matplotlib axes
      the axes to draw in (the current axes if not given)
    """
    if not ax:
        ax = pylab.gca()

    roc_style = {
        "main_color": "#DD5555",
        "secondary_color": "#120af7",
        "align_threshold": "left",
    }
    _basicDraw(
        ax,
        points,
        thresholds,
        fbeta=fbeta_ROC,
        beta=beta,
        AUC=AUC_ROC,
        thresholds_to_show=thresholds_to_show,
        **roc_style,
    )

    # diagonal of the unit square as a reference line
    ax.plot([0.0, 1.0], [0.0, 1.0], "-", color="#AAAAAA")

    ax.set_xlabel("False positive rate")
    ax.set_ylabel("True positive rate")
    ax.set_title(zeTitle)
|
|
600
|
+
|
|
601
|
+
|
|
602
|
+
def _drawPR(points, zeTitle, fbeta_PR, beta, AUC_PR, thresholds, thresholds_to_show, rate, ax=None):
    """
    Draw a Precision-Recall curve.

    Parameters
    ----------
    points : list[tuple]
      the points of the PR curve
    zeTitle : str
      the title of the figure
    fbeta_PR : float
      the F-beta value to display
    beta : float
      the value of beta for the F-beta score
    AUC_PR : float
      the AUC value to display
    thresholds : list[float]
      the thresholds annotating the points
    thresholds_to_show : list[float]
      the thresholds to highlight on the curve
    rate : float
      the ordinate of the horizontal reference line
    ax : matplotlib axes
      the axes to draw in (the current axes if not given)
    """
    if not ax:
        ax = pylab.gca()

    pr_style = {
        "main_color": "#120af7",
        "secondary_color": "#DD5555",
        "align_threshold": "right",
    }
    _basicDraw(
        ax,
        points,
        thresholds,
        fbeta=fbeta_PR,
        beta=beta,
        AUC=AUC_PR,
        thresholds_to_show=thresholds_to_show,
        **pr_style,
    )

    # horizontal reference line at the given rate
    ax.plot([0.0, 1.0], [rate, rate], "-", color="#AAAAAA")

    ax.set_xlabel("Recall")
    ax.set_ylabel("Precision")
    ax.set_title(zeTitle)
|
|
622
|
+
|
|
623
|
+
|
|
624
|
+
def showROC_PR(
    bn,
    datasrc,
    target,
    label,
    *,
    beta=1,
    show_progress=True,
    show_fig=True,
    save_fig=False,
    with_labels=True,
    show_ROC=True,
    show_PR=True,
    significant_digits=10,
    bgcolor=None,
):
    """
    Compute the ROC and/or PR curves and, optionally, show them and save the figure.

    The figure is saved in a png file whose name starts with the name of the csv file.

    Parameters
    ----------
    bn : pyagrum.BayesNet
      a Bayesian network
    datasrc : str|DataFrame
      a csv filename or a DataFrame (any object providing `to_csv` or `write_csv`)
    target : str
      the target
    label : str
      the target label
    beta : float
      the value of beta for the F-beta score
    show_progress : bool
      indicates if the progress bar must be printed
    save_fig:
      save the result
    show_fig:
      plot the results
    with_labels:
      labels in csv
    show_ROC: bool
      whether we show the ROC figure
    show_PR: bool
      whether we show the PR figure
    significant_digits:
      number of significant digits when computing probabilities
    bgcolor:
      HTML background color for the figure (default: None if transparent)

    Returns
    -------
    tuple
      (AUC_ROC, thresholdROC, AUC_PR, thresholdPR)

    Raises
    ------
    TypeError
      if `datasrc` is neither a string nor a DataFrame.
    """
    if not isinstance(datasrc, str):
        has_to_csv = hasattr(datasrc, "to_csv")
        if not (has_to_csv or hasattr(datasrc, "write_csv")):
            raise TypeError("first argument must be a string or a DataFrame")

        import tempfile

        # mkstemp creates a single file whose name ends with CSV_TMP_SUFFIX
        fd, csvfilename = tempfile.mkstemp(suffix=CSV_TMP_SUFFIX)
        os.close(fd)
        try:
            if has_to_csv:
                datasrc.to_csv(csvfilename, na_rep="?", index=False)
            else:
                # NOTE(review): write_csv is presumably the polars API, which has no
                # na_rep/index parameters; missing values are set with null_value.
                datasrc.write_csv(csvfilename, null_value="?")

            # same result (and same options, bgcolor included) as with a csv file
            return showROC_PR(
                bn,
                csvfilename,
                target,
                label,
                beta=beta,
                show_progress=show_progress,
                show_fig=show_fig,
                save_fig=save_fig,
                with_labels=with_labels,
                show_ROC=show_ROC,
                show_PR=show_PR,
                significant_digits=significant_digits,
                bgcolor=bgcolor,
            )
        finally:
            # the temporary csv is removed even if the computation fails
            os.remove(csvfilename)

    if bgcolor is not None:
        oldcol = gum.config["notebook", "figure_facecolor"]
        gum.config["notebook", "figure_facecolor"] = bgcolor

    try:
        filename = _getFilename(datasrc)
        (res, totalP, totalN) = _computePoints(
            bn,
            datasrc,
            target,
            label,
            beta=beta,
            show_progress=show_progress,
            with_labels=with_labels,
            significant_digits=significant_digits,
        )
        (
            pointsROC,
            ind_ROC,
            thresholdROC,
            AUC_ROC,
            fbeta_ROC,
            pointsPR,
            ind_PR,
            thresholdPR,
            AUC_PR,
            fbeta_PR,
            thresholds,
        ) = _computeROC_PR(res, totalP, totalN, beta)
        try:
            shortname = os.path.basename(bn.property("name"))
        except gum.NotFound:
            shortname = "unnamed"
        title = shortname + " vs " + filename + " - " + target + "=" + str(label)

        # proportion of positive samples : reference line of the PR curve
        rate = totalP / (totalP + totalN)

        figname = None  # stays None when no curve is drawn
        if show_ROC and show_PR:
            figname = f"{filename}-ROCandPR_{shortname}-{target}-{label}.png"
            fig = pylab.figure(figsize=(10, 4))
            fig.suptitle(title)
            pylab.gcf().subplots_adjust(wspace=0.1)

            ax1 = fig.add_subplot(1, 2, 1)
            _drawROC(
                points=pointsROC,
                zeTitle="ROC",
                fbeta_ROC=fbeta_ROC,
                beta=beta,
                AUC_ROC=AUC_ROC,
                thresholds=thresholds,
                thresholds_to_show=[thresholdROC, thresholdPR],
                ax=ax1,
            )

            ax2 = fig.add_subplot(1, 2, 2)
            ax2.yaxis.tick_right()
            ax2.yaxis.set_label_position("right")
            _drawPR(
                points=pointsPR,
                zeTitle="Precision-Recall",
                fbeta_PR=fbeta_PR,
                beta=beta,
                AUC_PR=AUC_PR,
                thresholds=thresholds,
                thresholds_to_show=[thresholdPR, thresholdROC],
                rate=rate,
                ax=ax2,
            )
        elif show_ROC:
            figname = f"{filename}-ROC_{shortname}-{target}-{label}.png"

            _drawROC(
                points=pointsROC,
                zeTitle=title,
                fbeta_ROC=fbeta_ROC,
                beta=beta,
                AUC_ROC=AUC_ROC,
                thresholds=thresholds,
                thresholds_to_show=[thresholdROC],
            )
        elif show_PR:
            figname = f"{filename}-PR_{shortname}-{target}-{label}.png"
            _drawPR(
                points=pointsPR,
                zeTitle=title,
                fbeta_PR=fbeta_PR,
                beta=beta,
                AUC_PR=AUC_PR,
                thresholds=thresholds,
                thresholds_to_show=[thresholdPR],
                rate=rate,
            )

        # nothing to save when neither the ROC nor the PR curve has been drawn
        if save_fig and figname is not None:
            pylab.savefig(figname, dpi=300, transparent=(bgcolor is None))

        if show_fig:
            pylab.show()
    finally:
        # the configuration is restored even if the computation or the drawing fails
        if bgcolor is not None:
            gum.config["notebook", "figure_facecolor"] = oldcol

    return AUC_ROC, thresholdROC, AUC_PR, thresholdPR
|
|
814
|
+
|
|
815
|
+
|
|
816
|
+
def showROC(
    bn,
    datasrc,
    target,
    label,
    show_progress=True,
    show_fig=True,
    save_fig=False,
    with_labels=True,
    significant_digits=10,
):
    """
    Draw only the ROC curve of a Bayesian network used as a classifier.

    This is a convenience wrapper around ``showROC_PR`` with ``show_ROC=True``
    and ``show_PR=False``; every other option is forwarded unchanged.

    Parameters
    ----------
    bn : pyagrum.BayesNet
      a Bayesian network
    datasrc : str|DataFrame
      a csv filename or a pandas.DataFrame
    target : str
      the target
    label : str
      the target label
    show_progress : bool
      indicates if the progress bar must be printed
    show_fig : bool
      plot the results
    save_fig : bool
      save the result (the figure is written as a png file by ``showROC_PR``)
    with_labels : bool
      labels in csv
    significant_digits : int
      number of significant digits when computing probabilities

    Returns
    -------
    tuple
      whatever ``showROC_PR`` returns, i.e. (AUC_ROC, thresholdROC, AUC_PR, thresholdPR)
    """
    # Only the ROC panel is requested; the remaining options mirror the arguments.
    forwarded = {
        "show_progress": show_progress,
        "show_fig": show_fig,
        "save_fig": save_fig,
        "with_labels": with_labels,
        "show_ROC": True,
        "show_PR": False,
        "significant_digits": significant_digits,
    }
    return showROC_PR(bn, datasrc, target, label, **forwarded)
|
|
857
|
+
|
|
858
|
+
|
|
859
|
+
def showPR(
    bn,
    datasrc,
    target,
    label,
    *,
    beta=1,
    show_progress=True,
    show_fig=True,
    save_fig=False,
    with_labels=True,
    significant_digits=10,
):
    """
    Compute the Precision-Recall curve and, if asked, save the resulting figure.

    This is a convenience wrapper around ``showROC_PR`` with ``show_ROC=False``
    and ``show_PR=True``; every other option is forwarded unchanged.

    Parameters
    ----------
    bn : pyagrum.BayesNet
      a Bayesian network
    datasrc : str|DataFrame
      a csv filename or a pandas.DataFrame
    target : str
      the target
    label : str
      the target label
    beta : float
      forwarded to ``showROC_PR`` (presumably the beta of the F-beta score used
      to select the threshold -- confirm against ``showROC_PR``)
    show_progress : bool
      indicates if the progress bar must be printed
    show_fig : bool
      plot the results ?
    save_fig : bool
      save the result ? (the figure is written as a png file by ``showROC_PR``)
    with_labels : bool
      labels in csv ?
    significant_digits : int
      number of significant digits when computing probabilities

    Returns
    -------
    tuple
      whatever ``showROC_PR`` returns, i.e. (AUC_ROC, thresholdROC, AUC_PR, thresholdPR)
    """
    return showROC_PR(
        bn,
        datasrc,
        target,
        label,
        beta=beta,
        show_progress=show_progress,
        show_fig=show_fig,
        save_fig=save_fig,
        with_labels=with_labels,
        show_ROC=False,  # only the Precision-Recall panel is requested
        show_PR=True,
        significant_digits=significant_digits,
    )
|
|
911
|
+
|
|
912
|
+
|
|
913
|
+
def animROC(bn, datasrc, target="Y", label="1"):
    """
    Interactive selection of a threshold using TPR and FPR for BN and data

    A slider (0 to 100, step 1) selects a position along the list of points
    returned by ``getROCpoints``. For each position, two stacked plots are drawn:
    both components of the points as curves with a vertical marker at the
    selected position, and a horizontal bar chart of the FPR and TPR at that point.

    Parameters
    ----------
    bn : pyagrum.BayesNet
      a Bayesian network
    datasrc : str|DataFrame
      a csv filename or a pandas.DataFrame
    target : str
      the target
    label : str
      the target label

    Returns
    -------
    ipywidgets.interactive
      the interactive widget (slider + output area) to display in a notebook
    """
    import ipywidgets as widgets
    import matplotlib.pyplot as plt
    import matplotlib.ticker as mtick

    class DisplayROC:
        """Hold the ROC points and redraw the two plots for a given slider value."""

        def __init__(self, points):
            nb_points = len(points)
            # abscissa: position of each point in the list, as a fraction in [0, 1)
            self._x = [i / nb_points for i in range(nb_points)]
            # first and second component of every point, as two parallel series
            self._y1, self._y2 = zip(*points)
            self._points = points

        def display(self, threshold):
            """Draw the curves and the bars for a slider value given in percent."""
            rate = threshold / 100.0
            index = int((len(self._points) - 1) * rate)
            current = self._points[index]

            # The size is given for this figure only: writing it into plt.rcParams
            # would silently resize every later figure of the session.
            _, (ax1, ax2) = plt.subplots(nrows=2, figsize=(4, 3))

            # Read the series from self rather than from the enclosing-scope
            # instance, so that the class does not rely on an outer variable.
            ax1.plot(self._x, self._y1, "g")
            ax1.plot(self._x, self._y2, "r")
            ax1.plot([rate, rate], [0, 1])  # vertical marker at the selected position
            ax1.xaxis.set_major_formatter(mtick.PercentFormatter(1.0))

            ax2.barh([0, 1], current, color=["g", "r"])
            ax2.set_yticks(ticks=[0, 1], labels=["FPR", "TPR"])
            ax2.annotate(f" {current[0]:.1%}", xy=(1, 0), xytext=(1, -0.2))
            ax2.annotate(f" {current[1]:.1%}", xy=(1, 1), xytext=(1, 0.8))
            ax2.set_xlim(0, 1)

            plt.tight_layout()
            plt.show()

    viewer = DisplayROC(getROCpoints(bn, datasrc, target=target, label=label))

    # The parameter must be named `rate`: it is matched by keyword with the
    # slider specification given to widgets.interactive below.
    def interactive_view(rate: float):
        viewer.display(rate)

    interactive_plot = widgets.interactive(interactive_view, rate=(0, 100, 1))
    # The last child of an `interactive` widget is its output area; it is given
    # a fixed height.
    output = interactive_plot.children[-1]
    output.layout.height = "250px"
    return interactive_plot
|
|
969
|
+
|
|
970
|
+
|
|
971
|
+
def animPR(bn, datasrc, target="Y", label="1"):
    """
    Interactive selection of a threshold using precision and recall for BN and data

    A slider (0 to 100, step 1) selects a position along the list of points
    returned by ``getPRpoints``. For each position, two stacked plots are drawn:
    both components of the points as curves with a vertical marker at the
    selected position, and a horizontal bar chart of the precision and recall
    at that point.

    Parameters
    ----------
    bn : pyagrum.BayesNet
      a Bayesian network
    datasrc : str|DataFrame
      a csv filename or a pandas.DataFrame
    target : str
      the target
    label : str
      the target label

    Returns
    -------
    ipywidgets.interactive
      the interactive widget (slider + output area) to display in a notebook
    """
    import ipywidgets as widgets
    import matplotlib.pyplot as plt
    import matplotlib.ticker as mtick

    class DisplayPR:
        """Hold the PR points and redraw the two plots for a given slider value."""

        def __init__(self, points):
            nb_points = len(points)
            # abscissa: position of each point in the list, as a fraction in [0, 1)
            self._x = [i / nb_points for i in range(nb_points)]
            # first and second component of every point, as two parallel series
            self._y1, self._y2 = zip(*points)
            self._points = points

        def display(self, threshold):
            """Draw the curves and the bars for a slider value given in percent."""
            rate = threshold / 100.0
            index = int((len(self._points) - 1) * rate)
            current = self._points[index]

            # The size is given for this figure only: writing it into plt.rcParams
            # would silently resize every later figure of the session.
            _, (ax1, ax2) = plt.subplots(nrows=2, figsize=(4, 3))

            # Read the series from self rather than from the enclosing-scope
            # instance, so that the class does not rely on an outer variable.
            ax1.plot(self._x, self._y1, "r")
            ax1.plot(self._x, self._y2, "g")
            ax1.plot([rate, rate], [0, 1])  # vertical marker at the selected position
            ax1.xaxis.set_major_formatter(mtick.PercentFormatter(1.0))

            # The first component is drawn at y=1 (tick "Recall") and the second
            # at y=0 (tick "Precision"); the annotations follow the same mapping.
            ax2.barh([1, 0], current, color=["r", "g"])
            ax2.set_yticks(ticks=[0, 1], labels=["Precision", "Recall"])
            ax2.annotate(f" {current[1]:.1%}", xy=(1, 0), xytext=(1, -0.2))
            ax2.annotate(f" {current[0]:.1%}", xy=(1, 1), xytext=(1, 0.8))
            ax2.set_xlim(0, 1)

            plt.tight_layout()
            plt.show()

    viewer = DisplayPR(getPRpoints(bn, datasrc, target=target, label=label))

    # The parameter must be named `rate`: it is matched by keyword with the
    # slider specification given to widgets.interactive below.
    def interactive_view(rate: float):
        viewer.display(rate)

    interactive_plot = widgets.interactive(interactive_view, rate=(0, 100, 1))
    # The last child of an `interactive` widget is its output area; it is given
    # a fixed height.
    output = interactive_plot.children[-1]
    output.layout.height = "250px"
    return interactive_plot
|