validmind 2.7.2__py3-none-any.whl → 2.7.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- validmind/__version__.py +1 -1
- validmind/ai/test_descriptions.py +20 -4
- validmind/ai/test_result_description/user.jinja +5 -0
- validmind/datasets/credit_risk/lending_club.py +444 -14
- validmind/tests/data_validation/MutualInformation.py +129 -0
- validmind/tests/data_validation/ScoreBandDefaultRates.py +139 -0
- validmind/tests/data_validation/TooManyZeroValues.py +6 -5
- validmind/tests/data_validation/UniqueRows.py +3 -1
- validmind/tests/decorator.py +18 -16
- validmind/tests/model_validation/sklearn/CalibrationCurve.py +116 -0
- validmind/tests/model_validation/sklearn/ClassifierThresholdOptimization.py +261 -0
- validmind/tests/model_validation/sklearn/ConfusionMatrix.py +1 -0
- validmind/tests/model_validation/sklearn/HyperParametersTuning.py +144 -56
- validmind/tests/model_validation/sklearn/ModelParameters.py +74 -0
- validmind/tests/model_validation/sklearn/ScoreProbabilityAlignment.py +130 -0
- validmind/tests/model_validation/statsmodels/CumulativePredictionProbabilities.py +5 -6
- validmind/tests/model_validation/statsmodels/PredictionProbabilitiesHistogram.py +2 -3
- validmind/tests/run.py +43 -72
- validmind/utils.py +23 -7
- validmind/vm_models/result/result.py +18 -17
- {validmind-2.7.2.dist-info → validmind-2.7.4.dist-info}/METADATA +1 -1
- {validmind-2.7.2.dist-info → validmind-2.7.4.dist-info}/RECORD +25 -19
- {validmind-2.7.2.dist-info → validmind-2.7.4.dist-info}/LICENSE +0 -0
- {validmind-2.7.2.dist-info → validmind-2.7.4.dist-info}/WHEEL +0 -0
- {validmind-2.7.2.dist-info → validmind-2.7.4.dist-info}/entry_points.txt +0 -0
validmind/tests/data_validation/ScoreBandDefaultRates.py
ADDED
@@ -0,0 +1,139 @@
+# Copyright © 2023-2024 ValidMind Inc. All rights reserved.
+# See the LICENSE file in the root of this repository for details.
+# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
+
+import pandas as pd
+import numpy as np
+from validmind import tags, tasks
+from validmind.vm_models import VMDataset, VMModel
+
+
+@tags("visualization", "credit_risk", "scorecard")
+@tasks("classification")
+def ScoreBandDefaultRates(
+    dataset: VMDataset,
+    model: VMModel,
+    score_column: str = "score",
+    score_bands: list = None,
+):
+    """
+    Analyzes default rates and population distribution across credit score bands.
+
+    ### Purpose
+
+    The Score Band Default Rates test evaluates the discriminatory power of credit scores by analyzing
+    default rates across different score bands. This helps validate score effectiveness, supports
+    policy decisions, and provides insights into portfolio risk distribution.
+
+    ### Test Mechanism
+
+    The test segments the score distribution into bands and calculates key metrics for each band:
+    1. Population count and percentage in each band
+    2. Default rate within each band
+    3. Cumulative statistics across bands
+    The results show how well the scores separate good and bad accounts.
+
+    ### Signs of High Risk
+
+    - Non-monotonic default rates across score bands
+    - Insufficient population in critical score bands
+    - Unexpected default rates for score ranges
+    - High concentration in specific score bands
+    - Similar default rates across adjacent bands
+    - Unstable default rates in key decision bands
+    - Extreme population skewness
+    - Poor risk separation between bands
+
+    ### Strengths
+
+    - Clear view of score effectiveness
+    - Supports policy threshold decisions
+    - Easy to interpret and communicate
+    - Directly links to business decisions
+    - Shows risk segmentation power
+    - Identifies potential score issues
+    - Helps validate scoring model
+    - Supports portfolio monitoring
+
+    ### Limitations
+
+    - Sensitive to band definition choices
+    - May mask within-band variations
+    - Requires sufficient data in each band
+    - Cannot capture non-linear patterns
+    - Point-in-time analysis only
+    - No temporal trend information
+    - Assumes band boundaries are appropriate
+    - May oversimplify risk patterns
+    """
+
+    if score_column not in dataset.df.columns:
+        raise ValueError(
+            f"The required column '{score_column}' is not present in the dataset with input_id {dataset.input_id}"
+        )
+
+    df = dataset._df.copy()
+
+    # Default score bands if none provided
+    if score_bands is None:
+        score_bands = [410, 440, 470]
+
+    # Create band labels
+    band_labels = [
+        f"{score_bands[i]}-{score_bands[i+1]}" for i in range(len(score_bands) - 1)
+    ]
+    band_labels.insert(0, f"<{score_bands[0]}")
+    band_labels.append(f">{score_bands[-1]}")
+
+    # Bin the scores with infinite upper bound
+    df["score_band"] = pd.cut(
+        df[score_column], bins=[-np.inf] + score_bands + [np.inf], labels=band_labels
+    )
+
+    # Calculate min and max scores for the total row
+    min_score = df[score_column].min()
+    max_score = df[score_column].max()
+
+    # Get predicted classes (0/1)
+    y_pred = dataset.y_pred(model)
+
+    # Calculate metrics by band using target_column name
+    results = []
+    for band in band_labels:
+        band_mask = df["score_band"] == band
+        population = band_mask.sum()
+        observed_defaults = df[band_mask][dataset.target_column].sum()
+        predicted_defaults = y_pred[
+            band_mask
+        ].sum()  # Sum of 1s gives number of predicted defaults
+
+        results.append(
+            {
+                "Score Band": band,
+                "Population Count": population,
+                "Population (%)": population / len(df) * 100,
+                "Predicted Default Rate (%)": (
+                    predicted_defaults / population * 100 if population > 0 else 0
+                ),
+                "Observed Default Rate (%)": (
+                    observed_defaults / population * 100 if population > 0 else 0
+                ),
+            }
+        )
+
+    # Add total row
+    total_population = len(df)
+    total_observed = df[dataset.target_column].sum()
+    total_predicted = y_pred.sum()  # Total number of predicted defaults
+
+    results.append(
+        {
+            "Score Band": f"Total ({min_score:.0f}-{max_score:.0f})",
+            "Population Count": total_population,
+            "Population (%)": sum(r["Population (%)"] for r in results),
+            "Predicted Default Rate (%)": total_predicted / total_population * 100,
+            "Observed Default Rate (%)": total_observed / total_population * 100,
+        }
+    )
+
+    return pd.DataFrame(results)
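For orientation, a minimal usage sketch of how the new ScoreBandDefaultRates test might be run through validmind's run_test helper. The lending_df frame, clf model, target column, and input IDs below are illustrative assumptions, not part of this release:

import validmind as vm
from validmind.tests import run_test

# Assumes vm.init(...) has already been called, lending_df holds a "score"
# column plus the binary target, and clf is a fitted classifier.
vm_dataset = vm.init_dataset(
    dataset=lending_df, target_column="loan_default", input_id="lending_ds"
)
vm_model = vm.init_model(clf, input_id="credit_model")
vm_dataset.assign_predictions(model=vm_model)

# Run the new test with explicit band edges (the test defaults to [410, 440, 470])
result = run_test(
    "validmind.data_validation.ScoreBandDefaultRates",
    inputs={"dataset": vm_dataset, "model": vm_model},
    params={"score_column": "score", "score_bands": [410, 440, 470]},
)
result.log()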
validmind/tests/data_validation/TooManyZeroValues.py
CHANGED
@@ -61,24 +61,25 @@ def TooManyZeroValues(dataset: VMDataset, max_percent_threshold: float = 0.03):
     issues.
     """
     df = dataset.df
-
    table = []

    for col in dataset.feature_columns_numeric:
        value_counts = df[col].value_counts()
+        row_count = df.shape[0]

        if 0 not in value_counts.index:
            continue

        n_zeros = value_counts[0]
-        p_zeros = n_zeros /
+        p_zeros = (n_zeros / row_count) * 100

        table.append(
            {
-                "
+                "Variable": col,
+                "Row Count": row_count,
                "Number of Zero Values": n_zeros,
-                "Percentage of Zero Values (%)": p_zeros
-                "Pass/Fail": "Pass" if p_zeros < max_percent_threshold else "Fail",
+                "Percentage of Zero Values (%)": p_zeros,
+                "Pass/Fail": ("Pass" if p_zeros < (max_percent_threshold) else "Fail"),
            }
        )

validmind/tests/data_validation/UniqueRows.py
CHANGED
@@ -61,7 +61,9 @@ def UniqueRows(dataset: VMDataset, min_percent_threshold: float = 1):
            "Number of Unique Values": unique_rows[col],
            "Percentage of Unique Values (%)": unique_rows[col] / rows * 100,
            "Pass/Fail": (
-                "Pass"
+                "Pass"
+                if (unique_rows[col] / rows * 100) >= min_percent_threshold
+                else "Fail"
            ),
        }
        for col in unique_rows.index
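The two data-quality tests above keep their public signatures, so runs with explicit thresholds are unchanged; a short sketch, assuming vm_dataset was initialized as in the earlier example:

from validmind.tests import run_test

# Threshold parameters mirror the function signatures shown in the hunks above.
run_test(
    "validmind.data_validation.TooManyZeroValues",
    inputs={"dataset": vm_dataset},
    params={"max_percent_threshold": 0.03},
)
run_test(
    "validmind.data_validation.UniqueRows",
    inputs={"dataset": vm_dataset},
    params={"min_percent_threshold": 1},
)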
validmind/tests/decorator.py
CHANGED
@@ -24,6 +24,11 @@ def _get_save_func(func, test_id):
    test library.
    """

+    # get og source before its wrapped by the test decorator
+    source = inspect.getsource(func)
+    # remove decorator line
+    source = source.split("\n", 1)[1]
+
    def save(root_folder=".", imports=None):
        parts = test_id.split(".")

@@ -41,35 +46,32 @@ def _get_save_func(func, test_id):

        full_path = os.path.join(path, f"{test_name}.py")

-
-
-        source = source.split("\n", 1)[1]
+        _source = source.replace(f"def {func.__name__}", f"def {test_name}")
+
        if imports:
            imports = "\n".join(imports)
-
+            _source = f"{imports}\n\n\n{_source}"
+
        # add comment to the top of the file
-
+        _source = f"""
 # Saved from {func.__module__}.{func.__name__}
 # Original Test ID: {test_id}
 # New Test ID: {new_test_id}

-{
+{_source}
 """

-        # ensure that the function name matches the test name
-        source = source.replace(f"def {func.__name__}", f"def {test_name}")
-
        # use black to format the code
        try:
            import black

-
+            _source = black.format_str(_source, mode=black.FileMode())
        except ImportError:
            # ignore if not available
            pass

        with open(full_path, "w") as file:
-            file.writelines(
+            file.writelines(_source)

        logger.info(
            f"Saved to {os.path.abspath(full_path)}!"
@@ -119,12 +121,12 @@ def test(func_or_id):
        test_func = load_test(test_id, func, reload=True)
        test_store.register_test(test_id, test_func)

-        @wraps(test_func)
-        def wrapper(*args, **kwargs):
-            return test_func(*args, **kwargs)
-
        # special function to allow the function to be saved to a file
-
+        save_func = _get_save_func(func, test_id)
+
+        wrapper = wraps(func)(test_func)
+        wrapper.test_id = test_id
+        wrapper.save = save_func

        return wrapper

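The decorator rework captures the undecorated source up front and attaches test_id and save directly to the wrapped test function. A hedged sketch of the intended usage follows; the test ID, column name, and output folder are made-up examples:

import validmind as vm

@vm.test("my_custom_tests.MeanOfScore")
def MeanOfScore(dataset, column: str = "score"):
    """Returns the mean of a single column as a one-row table."""
    return [{"Column": column, "Mean": dataset.df[column].mean()}]

# The wrapper now carries the test ID and a save() helper that writes the
# undecorated source (renamed to the test name) out to a .py file.
print(MeanOfScore.test_id)  # "my_custom_tests.MeanOfScore"
MeanOfScore.save("./my_tests", imports=["import pandas as pd"])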
validmind/tests/model_validation/sklearn/CalibrationCurve.py
ADDED
@@ -0,0 +1,116 @@
+# Copyright © 2023-2024 ValidMind Inc. All rights reserved.
+# See the LICENSE file in the root of this repository for details.
+# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
+
+from sklearn.calibration import calibration_curve
+import plotly.graph_objects as go
+from validmind import tags, tasks
+from validmind.vm_models import VMModel, VMDataset
+from validmind.vm_models.result import RawData
+
+
+@tags("sklearn", "model_performance", "classification")
+@tasks("classification")
+def CalibrationCurve(model: VMModel, dataset: VMDataset, n_bins: int = 10):
+    """
+    Evaluates the calibration of probability estimates by comparing predicted probabilities against observed
+    frequencies.
+
+    ### Purpose
+
+    The Calibration Curve test assesses how well a model's predicted probabilities align with actual
+    observed frequencies. This is crucial for applications requiring accurate probability estimates,
+    such as risk assessment, decision-making systems, and cost-sensitive applications where probability
+    calibration directly impacts business decisions.
+
+    ### Test Mechanism
+
+    The test uses sklearn's calibration_curve function to:
+    1. Sort predictions into bins based on predicted probabilities
+    2. Calculate the mean predicted probability in each bin
+    3. Compare against the observed frequency of positive cases
+    4. Plot the results against the perfect calibration line (y=x)
+    The resulting curve shows how well the predicted probabilities match empirical probabilities.
+
+    ### Signs of High Risk
+
+    - Significant deviation from the perfect calibration line
+    - Systematic overconfidence (predictions too close to 0 or 1)
+    - Systematic underconfidence (predictions clustered around 0.5)
+    - Empty or sparse bins indicating poor probability coverage
+    - Sharp discontinuities in the calibration curve
+    - Different calibration patterns across different probability ranges
+    - Consistent over/under estimation in critical probability regions
+    - Large confidence intervals in certain probability ranges
+
+    ### Strengths
+
+    - Visual and intuitive interpretation of probability quality
+    - Identifies systematic biases in probability estimates
+    - Supports probability threshold selection
+    - Helps understand model confidence patterns
+    - Applicable across different classification models
+    - Enables comparison between different models
+    - Guides potential need for recalibration
+    - Critical for risk-sensitive applications
+
+    ### Limitations
+
+    - Sensitive to the number of bins chosen
+    - Requires sufficient samples in each bin for reliable estimates
+    - May mask local calibration issues within bins
+    - Does not account for feature-dependent calibration issues
+    - Limited to binary classification problems
+    - Cannot detect all forms of miscalibration
+    - Assumes bin boundaries are appropriate for the problem
+    - May be affected by class imbalance
+    """
+    prob_true, prob_pred = calibration_curve(
+        dataset.y, dataset.y_prob(model), n_bins=n_bins
+    )
+
+    # Create DataFrame for raw data
+    raw_data = RawData(
+        mean_predicted_probability=prob_pred, observed_frequency=prob_true
+    )
+
+    # Create Plotly figure
+    fig = go.Figure()
+
+    # Add perfect calibration line
+    fig.add_trace(
+        go.Scatter(
+            x=[0, 1],
+            y=[0, 1],
+            mode="lines",
+            name="Perfect Calibration",
+            line=dict(dash="dash", color="gray"),
+        )
+    )
+
+    # Add calibration curve
+    fig.add_trace(
+        go.Scatter(
+            x=prob_pred,
+            y=prob_true,
+            mode="lines+markers",
+            name="Model Calibration",
+            line=dict(color="blue"),
+            marker=dict(size=8),
+        )
+    )
+
+    # Update layout
+    fig.update_layout(
+        title="Calibration Curve",
+        xaxis_title="Mean Predicted Probability",
+        yaxis_title="Observed Frequency",
+        xaxis=dict(range=[0, 1]),
+        yaxis=dict(range=[0, 1]),
+        width=800,
+        height=600,
+        showlegend=True,
+        template="plotly_white",
+    )
+
+    return raw_data, fig
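A hedged run sketch for the new CalibrationCurve test; it assumes vm_dataset already has probability predictions assigned for vm_model, since the test reads them via dataset.y_prob(model):

from validmind.tests import run_test

# n_bins matches the parameter exposed in the function signature above.
result = run_test(
    "validmind.model_validation.sklearn.CalibrationCurve",
    inputs={"model": vm_model, "dataset": vm_dataset},
    params={"n_bins": 10},
)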
validmind/tests/model_validation/sklearn/ClassifierThresholdOptimization.py
ADDED
@@ -0,0 +1,261 @@
+# Copyright © 2023-2024 ValidMind Inc. All rights reserved.
+# See the LICENSE file in the root of this repository for details.
+# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
+
+import numpy as np
+import pandas as pd
+import plotly.graph_objects as go
+from plotly.subplots import make_subplots
+from sklearn.metrics import (
+    roc_curve,
+    precision_recall_curve,
+    confusion_matrix,
+)
+from validmind import tags, tasks
+from validmind.vm_models import VMDataset, VMModel
+
+
+def find_optimal_threshold(y_true, y_prob, method="youden", target_recall=None):
+    """
+    Find the optimal classification threshold using various methods.
+
+    Args:
+        y_true: True binary labels
+        y_prob: Predicted probabilities
+        method: Method to use for finding optimal threshold
+        target_recall: Required if method='target_recall'
+
+    Returns:
+        dict: Dictionary containing threshold and metrics
+    """
+    # Get ROC and PR curve points
+    fpr, tpr, thresholds_roc = roc_curve(y_true, y_prob)
+    precision, recall, thresholds_pr = precision_recall_curve(y_true, y_prob)
+
+    # Find optimal threshold based on method
+    if method == "naive":
+        optimal_threshold = 0.5
+    elif method == "youden":
+        j_scores = tpr - fpr
+        best_idx = np.argmax(j_scores)
+        optimal_threshold = thresholds_roc[best_idx]
+    elif method == "f1":
+        f1_scores = 2 * (precision * recall) / (precision + recall)
+        best_idx = np.argmax(f1_scores)
+        optimal_threshold = (
+            thresholds_pr[best_idx] if best_idx < len(thresholds_pr) else 1.0
+        )
+    elif method == "precision_recall":
+        diff = abs(precision - recall)
+        best_idx = np.argmin(diff)
+        optimal_threshold = (
+            thresholds_pr[best_idx] if best_idx < len(thresholds_pr) else 1.0
+        )
+    elif method == "target_recall":
+        if target_recall is None:
+            raise ValueError(
+                "target_recall must be specified when method='target_recall'"
+            )
+        idx = np.argmin(abs(recall - target_recall))
+        optimal_threshold = thresholds_pr[idx] if idx < len(thresholds_pr) else 1.0
+    else:
+        raise ValueError(f"Unknown method: {method}")
+
+    # Calculate predictions with optimal threshold
+    y_pred = (y_prob >= optimal_threshold).astype(int)
+
+    # Calculate confusion matrix
+    tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()
+
+    # Calculate metrics directly
+    metrics = {
+        "method": method,
+        "threshold": optimal_threshold,
+        "precision": tp / (tp + fp) if (tp + fp) > 0 else 0,
+        "recall": tp / (tp + fn) if (tp + fn) > 0 else 0,
+        "f1_score": 2 * tp / (2 * tp + fp + fn) if (2 * tp + fp + fn) > 0 else 0,
+        "accuracy": (tp + tn) / (tp + tn + fp + fn),
+    }
+
+    return metrics
+
+
+@tags("model_validation", "threshold_optimization", "classification_metrics")
+@tasks("classification")
+def ClassifierThresholdOptimization(
+    dataset: VMDataset, model: VMModel, methods=None, target_recall=None
+):
+    """
+    Analyzes and visualizes different threshold optimization methods for binary classification models.
+
+    ### Purpose
+
+    The Classifier Threshold Optimization test identifies optimal decision thresholds using various
+    methods to balance different performance metrics. This helps adapt the model's decision boundary
+    to specific business requirements, such as minimizing false positives in fraud detection or
+    achieving target recall in medical diagnosis.
+
+    ### Test Mechanism
+
+    The test implements multiple threshold optimization methods:
+    1. Youden's J statistic (maximizing sensitivity + specificity - 1)
+    2. F1-score optimization (balancing precision and recall)
+    3. Precision-Recall equality point
+    4. Target recall achievement
+    5. Naive (0.5) threshold
+    For each method, it computes ROC and PR curves, identifies optimal points, and provides
+    comprehensive performance metrics at each threshold.
+
+    ### Signs of High Risk
+
+    - Large discrepancies between different optimization methods
+    - Optimal thresholds far from the default 0.5
+    - Poor performance metrics across all thresholds
+    - Significant gap between achieved and target recall
+    - Unstable thresholds across different methods
+    - Extreme trade-offs between precision and recall
+    - Threshold optimization showing minimal impact
+    - Business metrics not improving with optimization
+
+    ### Strengths
+
+    - Multiple optimization strategies for different needs
+    - Visual and numerical results for comparison
+    - Support for business-driven optimization (target recall)
+    - Comprehensive performance metrics at each threshold
+    - Integration with ROC and PR curves
+    - Handles class imbalance through various metrics
+    - Enables informed threshold selection
+    - Supports cost-sensitive decision making
+
+    ### Limitations
+
+    - Assumes cost of false positives/negatives are known
+    - May need adjustment for highly imbalanced datasets
+    - Threshold might not be stable across different samples
+    - Cannot handle multi-class problems directly
+    - Optimization methods may conflict with business needs
+    - Requires sufficient validation data
+    - May not capture temporal changes in optimal threshold
+    - Single threshold may not be optimal for all subgroups
+
+    Args:
+        dataset: VMDataset containing features and target
+        model: VMModel containing predictions
+        methods: List of methods to compare (default: ['youden', 'f1', 'precision_recall'])
+        target_recall: Target recall value if using 'target_recall' method
+
+    Returns:
+        Dictionary containing:
+            - table: DataFrame comparing different threshold optimization methods
+              (using weighted averages for precision, recall, and f1)
+            - figure: Plotly figure showing ROC and PR curves with optimal thresholds
+    """
+    # Verify binary classification
+    unique_values = np.unique(dataset.y)
+    if len(unique_values) != 2:
+        raise ValueError("Target variable must be binary")
+
+    if methods is None:
+        methods = ["naive", "youden", "f1", "precision_recall"]
+        if target_recall is not None:
+            methods.append("target_recall")
+
+    y_true = dataset.y
+    y_prob = dataset.y_prob(model)
+
+    # Get curve points for plotting
+    fpr, tpr, thresholds_roc = roc_curve(y_true, y_prob)
+    precision, recall, thresholds_pr = precision_recall_curve(y_true, y_prob)
+
+    # Calculate optimal thresholds and metrics
+    results = []
+    optimal_points = {}
+
+    for method in methods:
+        metrics = find_optimal_threshold(y_true, y_prob, method, target_recall)
+        results.append(metrics)
+
+        # Store optimal points for plotting
+        if method == "youden":
+            idx = np.argmax(tpr - fpr)
+            optimal_points[method] = {
+                "x": fpr[idx],
+                "y": tpr[idx],
+                "threshold": thresholds_roc[idx],
+            }
+        elif method in ["f1", "precision_recall", "target_recall"]:
+            idx = np.argmin(abs(thresholds_pr - metrics["threshold"]))
+            optimal_points[method] = {
+                "x": recall[idx],
+                "y": precision[idx],
+                "threshold": metrics["threshold"],
+            }
+
+    # Create visualization
+    fig = make_subplots(
+        rows=1, cols=2, subplot_titles=("ROC Curve", "Precision-Recall Curve")
+    )
+
+    # Plot ROC curve
+    fig.add_trace(
+        go.Scatter(x=fpr, y=tpr, name="ROC Curve", line=dict(color="blue")),
+        row=1,
+        col=1,
+    )
+
+    # Plot PR curve
+    fig.add_trace(
+        go.Scatter(x=recall, y=precision, name="PR Curve", line=dict(color="green")),
+        row=1,
+        col=2,
+    )
+
+    # Add optimal points
+    colors = {
+        "youden": "red",
+        "f1": "orange",
+        "precision_recall": "purple",
+        "target_recall": "brown",
+    }
+
+    for method, points in optimal_points.items():
+        if method == "youden":
+            fig.add_trace(
+                go.Scatter(
+                    x=[points["x"]],
+                    y=[points["y"]],
+                    name=f'{method} (t={points["threshold"]:.2f})',
+                    mode="markers",
+                    marker=dict(size=10, color=colors[method]),
+                ),
+                row=1,
+                col=1,
+            )
+        else:
+            fig.add_trace(
+                go.Scatter(
+                    x=[points["x"]],
+                    y=[points["y"]],
+                    name=f'{method} (t={points["threshold"]:.2f})',
+                    mode="markers",
+                    marker=dict(size=10, color=colors[method]),
+                ),
+                row=1,
+                col=2,
+            )
+
+    # Update layout
+    fig.update_layout(
+        height=500, title_text="Threshold Optimization Analysis", showlegend=True
+    )
+
+    fig.update_xaxes(title_text="False Positive Rate", row=1, col=1)
+    fig.update_xaxes(title_text="Recall", row=1, col=2)
+    fig.update_yaxes(title_text="True Positive Rate", row=1, col=1)
+    fig.update_yaxes(title_text="Precision", row=1, col=2)
+
+    # Create results table and sort by threshold descending
+    table = pd.DataFrame(results).sort_values("threshold", ascending=False)
+
+    return fig, table
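And a hedged run sketch for ClassifierThresholdOptimization; when target_recall is supplied, the code above adds the "target_recall" method alongside the defaults. The vm_dataset, vm_model, and the 0.85 recall target are illustrative:

from validmind.tests import run_test

# Compare naive, Youden, F1 and precision-recall thresholds, plus a
# business-driven recall target (illustrative value).
result = run_test(
    "validmind.model_validation.sklearn.ClassifierThresholdOptimization",
    inputs={"dataset": vm_dataset, "model": vm_model},
    params={"target_recall": 0.85},
)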