validmind 2.6.10__py3-none-any.whl → 2.7.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- validmind/__init__.py +2 -0
- validmind/__version__.py +1 -1
- validmind/ai/test_descriptions.py +20 -4
- validmind/ai/test_result_description/user.jinja +5 -0
- validmind/datasets/credit_risk/lending_club.py +444 -14
- validmind/tests/data_validation/MutualInformation.py +129 -0
- validmind/tests/data_validation/ScoreBandDefaultRates.py +139 -0
- validmind/tests/data_validation/TooManyZeroValues.py +6 -5
- validmind/tests/data_validation/UniqueRows.py +3 -1
- validmind/tests/decorator.py +18 -16
- validmind/tests/model_validation/sklearn/CalibrationCurve.py +116 -0
- validmind/tests/model_validation/sklearn/ClassifierThresholdOptimization.py +261 -0
- validmind/tests/model_validation/sklearn/ConfusionMatrix.py +1 -0
- validmind/tests/model_validation/sklearn/HyperParametersTuning.py +144 -56
- validmind/tests/model_validation/sklearn/ModelParameters.py +74 -0
- validmind/tests/model_validation/sklearn/ROCCurve.py +26 -23
- validmind/tests/model_validation/sklearn/ScoreProbabilityAlignment.py +130 -0
- validmind/tests/model_validation/statsmodels/CumulativePredictionProbabilities.py +5 -6
- validmind/tests/model_validation/statsmodels/PredictionProbabilitiesHistogram.py +2 -3
- validmind/tests/output.py +10 -1
- validmind/tests/run.py +52 -54
- validmind/utils.py +34 -7
- validmind/vm_models/figure.py +15 -0
- validmind/vm_models/result/__init__.py +2 -2
- validmind/vm_models/result/result.py +136 -23
- {validmind-2.6.10.dist-info → validmind-2.7.4.dist-info}/METADATA +1 -1
- {validmind-2.6.10.dist-info → validmind-2.7.4.dist-info}/RECORD +30 -24
- {validmind-2.6.10.dist-info → validmind-2.7.4.dist-info}/LICENSE +0 -0
- {validmind-2.6.10.dist-info → validmind-2.7.4.dist-info}/WHEEL +0 -0
- {validmind-2.6.10.dist-info → validmind-2.7.4.dist-info}/entry_points.txt +0 -0
validmind/tests/model_validation/sklearn/ClassifierThresholdOptimization.py (new file, +261 lines)

```python
# Copyright © 2023-2024 ValidMind Inc. All rights reserved.
# See the LICENSE file in the root of this repository for details.
# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial

import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from sklearn.metrics import (
    roc_curve,
    precision_recall_curve,
    confusion_matrix,
)
from validmind import tags, tasks
from validmind.vm_models import VMDataset, VMModel


def find_optimal_threshold(y_true, y_prob, method="youden", target_recall=None):
    """
    Find the optimal classification threshold using various methods.

    Args:
        y_true: True binary labels
        y_prob: Predicted probabilities
        method: Method to use for finding optimal threshold
        target_recall: Required if method='target_recall'

    Returns:
        dict: Dictionary containing threshold and metrics
    """
    # Get ROC and PR curve points
    fpr, tpr, thresholds_roc = roc_curve(y_true, y_prob)
    precision, recall, thresholds_pr = precision_recall_curve(y_true, y_prob)

    # Find optimal threshold based on method
    if method == "naive":
        optimal_threshold = 0.5
    elif method == "youden":
        j_scores = tpr - fpr
        best_idx = np.argmax(j_scores)
        optimal_threshold = thresholds_roc[best_idx]
    elif method == "f1":
        f1_scores = 2 * (precision * recall) / (precision + recall)
        best_idx = np.argmax(f1_scores)
        optimal_threshold = (
            thresholds_pr[best_idx] if best_idx < len(thresholds_pr) else 1.0
        )
    elif method == "precision_recall":
        diff = abs(precision - recall)
        best_idx = np.argmin(diff)
        optimal_threshold = (
            thresholds_pr[best_idx] if best_idx < len(thresholds_pr) else 1.0
        )
    elif method == "target_recall":
        if target_recall is None:
            raise ValueError(
                "target_recall must be specified when method='target_recall'"
            )
        idx = np.argmin(abs(recall - target_recall))
        optimal_threshold = thresholds_pr[idx] if idx < len(thresholds_pr) else 1.0
    else:
        raise ValueError(f"Unknown method: {method}")

    # Calculate predictions with optimal threshold
    y_pred = (y_prob >= optimal_threshold).astype(int)

    # Calculate confusion matrix
    tn, fp, fn, tp = confusion_matrix(y_true, y_pred).ravel()

    # Calculate metrics directly
    metrics = {
        "method": method,
        "threshold": optimal_threshold,
        "precision": tp / (tp + fp) if (tp + fp) > 0 else 0,
        "recall": tp / (tp + fn) if (tp + fn) > 0 else 0,
        "f1_score": 2 * tp / (2 * tp + fp + fn) if (2 * tp + fp + fn) > 0 else 0,
        "accuracy": (tp + tn) / (tp + tn + fp + fn),
    }

    return metrics


@tags("model_validation", "threshold_optimization", "classification_metrics")
@tasks("classification")
def ClassifierThresholdOptimization(
    dataset: VMDataset, model: VMModel, methods=None, target_recall=None
):
    """
    Analyzes and visualizes different threshold optimization methods for binary classification models.

    ### Purpose

    The Classifier Threshold Optimization test identifies optimal decision thresholds using various
    methods to balance different performance metrics. This helps adapt the model's decision boundary
    to specific business requirements, such as minimizing false positives in fraud detection or
    achieving target recall in medical diagnosis.

    ### Test Mechanism

    The test implements multiple threshold optimization methods:
    1. Youden's J statistic (maximizing sensitivity + specificity - 1)
    2. F1-score optimization (balancing precision and recall)
    3. Precision-Recall equality point
    4. Target recall achievement
    5. Naive (0.5) threshold
    For each method, it computes ROC and PR curves, identifies optimal points, and provides
    comprehensive performance metrics at each threshold.

    ### Signs of High Risk

    - Large discrepancies between different optimization methods
    - Optimal thresholds far from the default 0.5
    - Poor performance metrics across all thresholds
    - Significant gap between achieved and target recall
    - Unstable thresholds across different methods
    - Extreme trade-offs between precision and recall
    - Threshold optimization showing minimal impact
    - Business metrics not improving with optimization

    ### Strengths

    - Multiple optimization strategies for different needs
    - Visual and numerical results for comparison
    - Support for business-driven optimization (target recall)
    - Comprehensive performance metrics at each threshold
    - Integration with ROC and PR curves
    - Handles class imbalance through various metrics
    - Enables informed threshold selection
    - Supports cost-sensitive decision making

    ### Limitations

    - Assumes cost of false positives/negatives are known
    - May need adjustment for highly imbalanced datasets
    - Threshold might not be stable across different samples
    - Cannot handle multi-class problems directly
    - Optimization methods may conflict with business needs
    - Requires sufficient validation data
    - May not capture temporal changes in optimal threshold
    - Single threshold may not be optimal for all subgroups

    Args:
        dataset: VMDataset containing features and target
        model: VMModel containing predictions
        methods: List of methods to compare (default: ['youden', 'f1', 'precision_recall'])
        target_recall: Target recall value if using 'target_recall' method

    Returns:
        Dictionary containing:
            - table: DataFrame comparing different threshold optimization methods
              (using weighted averages for precision, recall, and f1)
            - figure: Plotly figure showing ROC and PR curves with optimal thresholds
    """
    # Verify binary classification
    unique_values = np.unique(dataset.y)
    if len(unique_values) != 2:
        raise ValueError("Target variable must be binary")

    if methods is None:
        methods = ["naive", "youden", "f1", "precision_recall"]
    if target_recall is not None:
        methods.append("target_recall")

    y_true = dataset.y
    y_prob = dataset.y_prob(model)

    # Get curve points for plotting
    fpr, tpr, thresholds_roc = roc_curve(y_true, y_prob)
    precision, recall, thresholds_pr = precision_recall_curve(y_true, y_prob)

    # Calculate optimal thresholds and metrics
    results = []
    optimal_points = {}

    for method in methods:
        metrics = find_optimal_threshold(y_true, y_prob, method, target_recall)
        results.append(metrics)

        # Store optimal points for plotting
        if method == "youden":
            idx = np.argmax(tpr - fpr)
            optimal_points[method] = {
                "x": fpr[idx],
                "y": tpr[idx],
                "threshold": thresholds_roc[idx],
            }
        elif method in ["f1", "precision_recall", "target_recall"]:
            idx = np.argmin(abs(thresholds_pr - metrics["threshold"]))
            optimal_points[method] = {
                "x": recall[idx],
                "y": precision[idx],
                "threshold": metrics["threshold"],
            }

    # Create visualization
    fig = make_subplots(
        rows=1, cols=2, subplot_titles=("ROC Curve", "Precision-Recall Curve")
    )

    # Plot ROC curve
    fig.add_trace(
        go.Scatter(x=fpr, y=tpr, name="ROC Curve", line=dict(color="blue")),
        row=1,
        col=1,
    )

    # Plot PR curve
    fig.add_trace(
        go.Scatter(x=recall, y=precision, name="PR Curve", line=dict(color="green")),
        row=1,
        col=2,
    )

    # Add optimal points
    colors = {
        "youden": "red",
        "f1": "orange",
        "precision_recall": "purple",
        "target_recall": "brown",
    }

    for method, points in optimal_points.items():
        if method == "youden":
            fig.add_trace(
                go.Scatter(
                    x=[points["x"]],
                    y=[points["y"]],
                    name=f'{method} (t={points["threshold"]:.2f})',
                    mode="markers",
                    marker=dict(size=10, color=colors[method]),
                ),
                row=1,
                col=1,
            )
        else:
            fig.add_trace(
                go.Scatter(
                    x=[points["x"]],
                    y=[points["y"]],
                    name=f'{method} (t={points["threshold"]:.2f})',
                    mode="markers",
                    marker=dict(size=10, color=colors[method]),
                ),
                row=1,
                col=2,
            )

    # Update layout
    fig.update_layout(
        height=500, title_text="Threshold Optimization Analysis", showlegend=True
    )

    fig.update_xaxes(title_text="False Positive Rate", row=1, col=1)
    fig.update_xaxes(title_text="Recall", row=1, col=2)
    fig.update_yaxes(title_text="True Positive Rate", row=1, col=1)
    fig.update_yaxes(title_text="Precision", row=1, col=2)

    # Create results table and sort by threshold descending
    table = pd.DataFrame(results).sort_values("threshold", ascending=False)

    return fig, table
```
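To make the behaviour of the new `find_optimal_threshold` helper concrete, here is a minimal sketch that fits a classifier on synthetic data and compares the selected cut-offs. The dataset, estimator, and import path are illustrative assumptions (the import path simply mirrors the file location above); inside the ValidMind workflow the test itself would normally be run through `vm.tests.run_test` rather than called directly.

```python
from sklearn.datasets import make_classification
from sklearn.linear_model import LogisticRegression

from validmind.tests.model_validation.sklearn.ClassifierThresholdOptimization import (
    find_optimal_threshold,
)

# Illustrative data and model only; any fitted binary classifier with
# predict_proba would do.
X, y = make_classification(n_samples=1000, weights=[0.8, 0.2], random_state=0)
y_prob = LogisticRegression(max_iter=1000).fit(X, y).predict_proba(X)[:, 1]

# Compare the threshold-selection strategies implemented above.
for method in ["naive", "youden", "f1", "precision_recall"]:
    m = find_optimal_threshold(y, y_prob, method=method)
    print(
        f"{m['method']:>16}: t={m['threshold']:.3f}  "
        f"precision={m['precision']:.3f}  recall={m['recall']:.3f}"
    )
```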
validmind/tests/model_validation/sklearn/HyperParametersTuning.py (+144 -56)

```diff
@@ -2,73 +2,161 @@
 # See the LICENSE file in the root of this repository for details.
 # SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
 
-from typing import Union
-
+from typing import Union, Dict, List
 from sklearn.model_selection import GridSearchCV
+from sklearn.metrics import make_scorer, recall_score
 
 from validmind import tags, tasks
-from validmind.errors import SkipTestError
 from validmind.vm_models import VMDataset, VMModel
 
 
 @tags("sklearn", "model_performance")
 @tasks("classification", "clustering")
+def custom_recall(y_true, y_pred_proba, threshold=0.5):
+    y_pred = (y_pred_proba >= threshold).astype(int)
+    return recall_score(y_true, y_pred)
+
+
+def _get_metrics(scoring):
+    """Convert scoring parameter to list of metrics."""
+    if scoring is None:
+        return ["accuracy"]
+    return (
+        scoring
+        if isinstance(scoring, list)
+        else list(scoring.keys()) if isinstance(scoring, dict) else [scoring]
+    )
+
+
+def _get_thresholds(thresholds):
+    """Convert thresholds parameter to list."""
+    if thresholds is None:
+        return [0.5]
+    return [thresholds] if isinstance(thresholds, (int, float)) else thresholds
+
+
+def _create_scoring_dict(scoring, metrics, threshold):
+    """Create scoring dictionary for GridSearchCV."""
+    if scoring is None:
+        return None
+
+    scoring_dict = {}
+    for metric in metrics:
+        if metric == "recall":
+            scoring_dict[metric] = make_scorer(
+                custom_recall, needs_proba=True, threshold=threshold
+            )
+        elif metric == "roc_auc":
+            scoring_dict[metric] = "roc_auc"
+        else:
+            scoring_dict[metric] = metric
+    return scoring_dict
+
+
+@tags("sklearn", "model_performance")
+@tasks("clustering", "classification")
 def HyperParametersTuning(
     model: VMModel,
     dataset: VMDataset,
-    param_grid:
-    scoring: Union[str,
+    param_grid: dict,
+    scoring: Union[str, List, Dict] = None,
+    thresholds: Union[float, List[float]] = None,
+    fit_params: dict = None,
 ):
     """
-    ...
+    Performs exhaustive grid search over specified parameter ranges to find optimal model configurations
+    across different metrics and decision thresholds.
+
+    ### Purpose
+
+    The Hyperparameter Tuning test systematically explores the model's parameter space to identify optimal
+    configurations. It supports multiple optimization metrics and decision thresholds, providing a comprehensive
+    view of how different parameter combinations affect various aspects of model performance.
+
+    ### Test Mechanism
+
+    The test uses scikit-learn's GridSearchCV to perform cross-validation for each parameter combination.
+    For each specified threshold and optimization metric, it creates a scoring dictionary with
+    threshold-adjusted metrics, performs grid search with cross-validation, records best parameters and
+    corresponding scores, and combines results into a comparative table. This process is repeated for each
+    optimization metric to provide a comprehensive view of model performance under different configurations.
+
+    ### Signs of High Risk
+
+    - Large performance variations across different parameter combinations
+    - Significant discrepancies between different optimization metrics
+    - Best parameters at the edges of the parameter grid
+    - Unstable performance across different thresholds
+    - Overly complex model configurations (risk of overfitting)
+    - Very different optimal parameters for different metrics
+    - Cross-validation scores showing high variance
+    - Extreme parameter values in best configurations
+
+    ### Strengths
+
+    - Comprehensive exploration of parameter space
+    - Supports multiple optimization metrics
+    - Allows threshold optimization
+    - Provides comparative view across different configurations
+    - Uses cross-validation for robust evaluation
+    - Helps understand trade-offs between different metrics
+    - Enables systematic parameter selection
+    - Supports both classification and clustering tasks
+
+    ### Limitations
+
+    - Computationally expensive for large parameter grids
+    - May not find global optimum (limited to grid points)
+    - Cannot handle dependencies between parameters
+    - Memory intensive for large datasets
+    - Limited to scikit-learn compatible models
+    - Cross-validation splits may not preserve time series structure
+    - Grid search may miss optimal values between grid points
+    - Resource intensive for high-dimensional parameter spaces
     """
-    ...
+    fit_params = fit_params or {}
+
+    # Simple case: no scoring and no thresholds
+    if scoring is None and thresholds is None:
+        estimators = GridSearchCV(model.model, param_grid=param_grid, scoring=None)
+        estimators.fit(dataset.x_df(), dataset.y, **fit_params)
+        return [
+            {
+                "Best Model": estimators.best_estimator_,
+                "Best Parameters": estimators.best_params_,
+            }
+        ]
+
+    # Complex case: with scoring or thresholds
+    results = []
+    metrics = _get_metrics(scoring)
+    thresholds = _get_thresholds(thresholds)
+
+    for threshold in thresholds:
+        scoring_dict = _create_scoring_dict(scoring, metrics, threshold)
+
+        for optimize_for in metrics:
+            estimators = GridSearchCV(
+                model.model,
+                param_grid=param_grid,
+                scoring=scoring_dict,
+                refit=optimize_for if scoring is not None else True,
+            )
+
+            estimators.fit(dataset.x_df(), dataset.y, **fit_params)
+
+            best_index = estimators.best_index_
+            row_result = {
+                "Optimized for": optimize_for,
+                "Threshold": threshold,
+                "Best Parameters": estimators.best_params_,
+            }
+
+            score_key = (
+                "mean_test_score" if scoring is None else f"mean_test_{optimize_for}"
+            )
+            row_result[optimize_for] = estimators.cv_results_[score_key][best_index]
+
+            results.append(row_result)
+
+    return results
```
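A hedged sketch of how the extended signature might be exercised through ValidMind's `run_test` entry point, assuming the test ID follows the usual `validmind.<path>.<TestName>` convention and that `"log_reg_model"` and `"train_dataset"` are input_ids registered earlier with `vm.init_model()` / `vm.init_dataset()`:

```python
import validmind as vm

# "log_reg_model" and "train_dataset" are hypothetical input_ids registered
# earlier; the wrapped estimator is assumed to be scikit-learn compatible.
vm.tests.run_test(
    "validmind.model_validation.sklearn.HyperParametersTuning",
    inputs={"model": "log_reg_model", "dataset": "train_dataset"},
    params={
        "param_grid": {"C": [0.01, 0.1, 1, 10]},
        # Each metric gets its own GridSearchCV refit pass...
        "scoring": ["roc_auc", "recall"],
        # ...and "recall" is re-scored at each of these decision thresholds.
        "thresholds": [0.3, 0.5],
    },
).log()
```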
validmind/tests/model_validation/sklearn/ModelParameters.py (new file, +74 lines)

```python
# Copyright © 2023-2024 ValidMind Inc. All rights reserved.
# See the LICENSE file in the root of this repository for details.
# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial

import pandas as pd
from validmind import tags, tasks


@tags("model_training", "metadata")
@tasks("classification", "regression")
def ModelParameters(model, model_params=None):
    """
    Extracts and displays model parameters in a structured format for transparency and reproducibility.

    ### Purpose

    The Model Parameters test is designed to provide transparency into model configuration and ensure
    reproducibility of machine learning models. It accomplishes this by extracting and presenting all
    relevant parameters that define the model's behavior, making it easier to audit, validate, and
    reproduce model training.

    ### Test Mechanism

    The test leverages scikit-learn's API convention of get_params() to extract model parameters. It
    produces a structured DataFrame containing parameter names and their corresponding values. For models
    that follow scikit-learn's API (including XGBoost, RandomForest, and other estimators), all
    parameters are automatically extracted and displayed.

    ### Signs of High Risk

    - Missing crucial parameters that should be explicitly set
    - Extreme parameter values that could indicate overfitting (e.g., unlimited tree depth)
    - Inconsistent parameters across different versions of the same model type
    - Parameter combinations known to cause instability or poor performance
    - Default values used for critical parameters that should be tuned

    ### Strengths

    - Universal compatibility with scikit-learn API-compliant models
    - Ensures transparency in model configuration
    - Facilitates model reproducibility and version control
    - Enables systematic parameter auditing
    - Supports both classification and regression models
    - Helps identify potential configuration issues

    ### Limitations

    - Only works with models implementing scikit-learn's get_params() method
    - Cannot capture dynamic parameters set during model training
    - Does not validate parameter values for model-specific appropriateness
    - Parameter meanings and impacts may vary across different model types
    - Cannot detect indirect parameter interactions or their effects on model performance
    """
    # Check if model implements get_params()
    if not hasattr(model.model, "get_params"):
        return pd.DataFrame()

    # Get all model parameters
    params = model.model.get_params()

    # If model_params is None, use all parameters from get_params()
    if model_params is None:
        model_params = sorted(params.keys())  # Sort for consistent ordering

    # Create DataFrame with parameters and their values
    param_df = pd.DataFrame(
        [
            {"Parameter": param, "Value": str(params.get(param, "Not specified"))}
            for param in model_params
            if params.get(param) is not None
        ]
    )

    return param_df
```
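A minimal usage sketch for the new test, assuming an active `vm.init()` session for logging; the estimator, synthetic data, and `"rf_model"` input_id are illustrative only:

```python
import validmind as vm
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier

# Toy data and estimator, purely for illustration.
X, y = make_classification(n_samples=500, random_state=0)
rf = RandomForestClassifier(n_estimators=200, max_depth=5).fit(X, y)

# Hypothetical input_id; wrapping the estimator gives the test a VMModel.
vm.init_model(rf, input_id="rf_model")

vm.tests.run_test(
    "validmind.model_validation.sklearn.ModelParameters",
    inputs={"model": "rf_model"},
    # Optionally restrict the table to a few parameters of interest.
    params={"model_params": ["n_estimators", "max_depth", "criterion"]},
).log()
```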
validmind/tests/model_validation/sklearn/ROCCurve.py (+26 -23)

```diff
@@ -6,7 +6,7 @@ import numpy as np
 import plotly.graph_objects as go
 from sklearn.metrics import roc_auc_score, roc_curve
 
-from validmind import tags, tasks
+from validmind import RawData, tags, tasks
 from validmind.errors import SkipTestError
 from validmind.vm_models import VMDataset, VMModel
 
@@ -77,28 +77,31 @@ def ROCCurve(model: VMModel, dataset: VMDataset):
     fpr, tpr, _ = roc_curve(y_true, y_prob, drop_intermediate=False)
     auc = roc_auc_score(y_true, y_prob)
 
-    return
-    ...
+    return (
+        RawData(fpr=fpr, tpr=tpr, auc=auc),
+        go.Figure(
+            data=[
+                go.Scatter(
+                    x=fpr,
+                    y=tpr,
+                    mode="lines",
+                    name=f"ROC curve (AUC = {auc:.2f})",
+                    line=dict(color="#DE257E"),
+                ),
+                go.Scatter(
+                    x=[0, 1],
+                    y=[0, 1],
+                    mode="lines",
+                    name="Random (AUC = 0.5)",
+                    line=dict(color="grey", dash="dash"),
+                ),
+            ],
+            layout=go.Layout(
+                title=f"ROC Curve for {model.input_id} on {dataset.input_id}",
+                xaxis=dict(title="False Positive Rate"),
+                yaxis=dict(title="True Positive Rate"),
+                width=700,
+                height=500,
             ),
-            go.Scatter(
-                x=[0, 1],
-                y=[0, 1],
-                mode="lines",
-                name="Random (AUC = 0.5)",
-                line=dict(color="grey", dash="dash"),
-            ),
-        ],
-        layout=go.Layout(
-            title=f"ROC Curve for {model.input_id} on {dataset.input_id}",
-            xaxis=dict(title="False Positive Rate"),
-            yaxis=dict(title="True Positive Rate"),
-            width=700,
-            height=500,
         ),
     )
```
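With this change `ROCCurve` returns a `RawData` object carrying `fpr`, `tpr`, and `auc` alongside the Plotly figure, so downstream code can reuse the curve points without recomputing them. A hedged usage sketch, with hypothetical input_ids and assuming the dataset was linked to the model's predicted probabilities beforehand (for example via `dataset.assign_predictions(model)`):

```python
import validmind as vm

# "xgb_model" and "test_dataset" are hypothetical input_ids registered via
# vm.init_model(...) and vm.init_dataset(...); probability predictions are
# assumed to have been assigned to the dataset already.
result = vm.tests.run_test(
    "validmind.model_validation.sklearn.ROCCurve",
    inputs={"model": "xgb_model", "dataset": "test_dataset"},
)
result.log()
```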
|