validmind 2.7.5__py3-none-any.whl → 2.7.7__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to their public registry. It is provided for informational purposes only.
- validmind/__init__.py +2 -0
- validmind/__version__.py +1 -1
- validmind/api_client.py +8 -1
- validmind/datasets/credit_risk/lending_club.py +352 -87
- validmind/html_templates/content_blocks.py +1 -1
- validmind/tests/__types__.py +17 -0
- validmind/tests/data_validation/ACFandPACFPlot.py +6 -2
- validmind/tests/data_validation/AutoMA.py +2 -2
- validmind/tests/data_validation/BivariateScatterPlots.py +4 -2
- validmind/tests/data_validation/BoxPierce.py +2 -2
- validmind/tests/data_validation/ClassImbalance.py +2 -1
- validmind/tests/data_validation/DatasetDescription.py +11 -2
- validmind/tests/data_validation/DatasetSplit.py +2 -2
- validmind/tests/data_validation/DickeyFullerGLS.py +2 -2
- validmind/tests/data_validation/FeatureTargetCorrelationPlot.py +8 -2
- validmind/tests/data_validation/HighCardinality.py +9 -2
- validmind/tests/data_validation/HighPearsonCorrelation.py +18 -4
- validmind/tests/data_validation/IQROutliersBarPlot.py +9 -2
- validmind/tests/data_validation/LaggedCorrelationHeatmap.py +2 -2
- validmind/tests/data_validation/MissingValuesBarPlot.py +12 -9
- validmind/tests/data_validation/MutualInformation.py +6 -8
- validmind/tests/data_validation/PearsonCorrelationMatrix.py +2 -2
- validmind/tests/data_validation/ProtectedClassesCombination.py +6 -1
- validmind/tests/data_validation/ProtectedClassesDescription.py +1 -1
- validmind/tests/data_validation/ProtectedClassesDisparity.py +4 -5
- validmind/tests/data_validation/ProtectedClassesThresholdOptimizer.py +1 -4
- validmind/tests/data_validation/RollingStatsPlot.py +21 -10
- validmind/tests/data_validation/ScatterPlot.py +3 -5
- validmind/tests/data_validation/ScoreBandDefaultRates.py +2 -1
- validmind/tests/data_validation/SeasonalDecompose.py +12 -2
- validmind/tests/data_validation/Skewness.py +6 -3
- validmind/tests/data_validation/SpreadPlot.py +8 -3
- validmind/tests/data_validation/TabularCategoricalBarPlots.py +4 -2
- validmind/tests/data_validation/TabularDateTimeHistograms.py +2 -2
- validmind/tests/data_validation/TargetRateBarPlots.py +4 -3
- validmind/tests/data_validation/TimeSeriesFrequency.py +7 -2
- validmind/tests/data_validation/TimeSeriesMissingValues.py +14 -10
- validmind/tests/data_validation/TimeSeriesOutliers.py +1 -5
- validmind/tests/data_validation/WOEBinPlots.py +2 -2
- validmind/tests/data_validation/WOEBinTable.py +11 -9
- validmind/tests/data_validation/nlp/CommonWords.py +2 -2
- validmind/tests/data_validation/nlp/Hashtags.py +2 -2
- validmind/tests/data_validation/nlp/LanguageDetection.py +9 -6
- validmind/tests/data_validation/nlp/Mentions.py +9 -6
- validmind/tests/data_validation/nlp/PolarityAndSubjectivity.py +2 -2
- validmind/tests/data_validation/nlp/Punctuations.py +4 -2
- validmind/tests/data_validation/nlp/Sentiment.py +2 -2
- validmind/tests/data_validation/nlp/StopWords.py +5 -4
- validmind/tests/data_validation/nlp/TextDescription.py +2 -2
- validmind/tests/data_validation/nlp/Toxicity.py +2 -2
- validmind/tests/model_validation/BertScore.py +2 -2
- validmind/tests/model_validation/BleuScore.py +2 -2
- validmind/tests/model_validation/ClusterSizeDistribution.py +2 -2
- validmind/tests/model_validation/ContextualRecall.py +2 -2
- validmind/tests/model_validation/FeaturesAUC.py +2 -2
- validmind/tests/model_validation/MeteorScore.py +2 -2
- validmind/tests/model_validation/ModelPredictionResiduals.py +2 -2
- validmind/tests/model_validation/RegardScore.py +6 -2
- validmind/tests/model_validation/RegressionResidualsPlot.py +4 -3
- validmind/tests/model_validation/RougeScore.py +6 -5
- validmind/tests/model_validation/TimeSeriesPredictionWithCI.py +11 -2
- validmind/tests/model_validation/TokenDisparity.py +2 -2
- validmind/tests/model_validation/ToxicityScore.py +10 -2
- validmind/tests/model_validation/embeddings/ClusterDistribution.py +9 -3
- validmind/tests/model_validation/embeddings/CosineSimilarityComparison.py +16 -2
- validmind/tests/model_validation/embeddings/CosineSimilarityDistribution.py +5 -3
- validmind/tests/model_validation/embeddings/CosineSimilarityHeatmap.py +2 -2
- validmind/tests/model_validation/embeddings/DescriptiveAnalytics.py +14 -4
- validmind/tests/model_validation/embeddings/EmbeddingsVisualization2D.py +2 -2
- validmind/tests/model_validation/embeddings/EuclideanDistanceComparison.py +16 -2
- validmind/tests/model_validation/embeddings/EuclideanDistanceHeatmap.py +2 -2
- validmind/tests/model_validation/embeddings/PCAComponentsPairwisePlots.py +4 -5
- validmind/tests/model_validation/embeddings/StabilityAnalysisKeyword.py +4 -2
- validmind/tests/model_validation/embeddings/StabilityAnalysisRandomNoise.py +4 -2
- validmind/tests/model_validation/embeddings/StabilityAnalysisSynonyms.py +4 -2
- validmind/tests/model_validation/embeddings/StabilityAnalysisTranslation.py +4 -2
- validmind/tests/model_validation/embeddings/TSNEComponentsPairwisePlots.py +8 -6
- validmind/tests/model_validation/embeddings/utils.py +11 -1
- validmind/tests/model_validation/ragas/AnswerCorrectness.py +2 -1
- validmind/tests/model_validation/ragas/AspectCritic.py +11 -7
- validmind/tests/model_validation/ragas/ContextEntityRecall.py +2 -1
- validmind/tests/model_validation/ragas/ContextPrecision.py +2 -1
- validmind/tests/model_validation/ragas/ContextPrecisionWithoutReference.py +2 -1
- validmind/tests/model_validation/ragas/ContextRecall.py +2 -1
- validmind/tests/model_validation/ragas/Faithfulness.py +2 -1
- validmind/tests/model_validation/ragas/NoiseSensitivity.py +2 -1
- validmind/tests/model_validation/ragas/ResponseRelevancy.py +2 -1
- validmind/tests/model_validation/ragas/SemanticSimilarity.py +2 -1
- validmind/tests/model_validation/sklearn/CalibrationCurve.py +3 -2
- validmind/tests/model_validation/sklearn/ClassifierThresholdOptimization.py +2 -5
- validmind/tests/model_validation/sklearn/ClusterCosineSimilarity.py +5 -2
- validmind/tests/model_validation/sklearn/ConfusionMatrix.py +2 -2
- validmind/tests/model_validation/sklearn/FeatureImportance.py +1 -14
- validmind/tests/model_validation/sklearn/HyperParametersTuning.py +6 -3
- validmind/tests/model_validation/sklearn/KMeansClustersOptimization.py +2 -2
- validmind/tests/model_validation/sklearn/MinimumROCAUCScore.py +8 -4
- validmind/tests/model_validation/sklearn/ModelParameters.py +1 -0
- validmind/tests/model_validation/sklearn/OverfitDiagnosis.py +3 -3
- validmind/tests/model_validation/sklearn/PermutationFeatureImportance.py +2 -2
- validmind/tests/model_validation/sklearn/PopulationStabilityIndex.py +20 -16
- validmind/tests/model_validation/sklearn/PrecisionRecallCurve.py +4 -2
- validmind/tests/model_validation/sklearn/ROCCurve.py +1 -1
- validmind/tests/model_validation/sklearn/RegressionR2Square.py +7 -9
- validmind/tests/model_validation/sklearn/RegressionR2SquareComparison.py +1 -3
- validmind/tests/model_validation/sklearn/SHAPGlobalImportance.py +2 -1
- validmind/tests/model_validation/sklearn/ScoreProbabilityAlignment.py +2 -1
- validmind/tests/model_validation/sklearn/SilhouettePlot.py +5 -3
- validmind/tests/model_validation/sklearn/TrainingTestDegradation.py +9 -1
- validmind/tests/model_validation/sklearn/WeakspotsDiagnosis.py +1 -1
- validmind/tests/model_validation/statsmodels/CumulativePredictionProbabilities.py +11 -4
- validmind/tests/model_validation/statsmodels/DurbinWatsonTest.py +1 -3
- validmind/tests/model_validation/statsmodels/GINITable.py +7 -15
- validmind/tests/model_validation/statsmodels/Lilliefors.py +2 -2
- validmind/tests/model_validation/statsmodels/RegressionCoeffs.py +1 -1
- validmind/tests/model_validation/statsmodels/RegressionFeatureSignificance.py +2 -2
- validmind/tests/model_validation/statsmodels/RegressionModelForecastPlotLevels.py +5 -2
- validmind/tests/model_validation/statsmodels/RegressionModelSensitivityPlot.py +5 -2
- validmind/tests/model_validation/statsmodels/RegressionModelSummary.py +7 -7
- validmind/tests/model_validation/statsmodels/RegressionPermutationFeatureImportance.py +2 -2
- validmind/tests/ongoing_monitoring/CalibrationCurveDrift.py +220 -0
- validmind/tests/ongoing_monitoring/ClassDiscriminationDrift.py +155 -0
- validmind/tests/ongoing_monitoring/ClassImbalanceDrift.py +146 -0
- validmind/tests/ongoing_monitoring/ClassificationAccuracyDrift.py +148 -0
- validmind/tests/ongoing_monitoring/ConfusionMatrixDrift.py +193 -0
- validmind/tests/ongoing_monitoring/CumulativePredictionProbabilitiesDrift.py +178 -0
- validmind/tests/ongoing_monitoring/FeatureDrift.py +120 -120
- validmind/tests/ongoing_monitoring/PredictionAcrossEachFeature.py +18 -23
- validmind/tests/ongoing_monitoring/PredictionCorrelation.py +86 -44
- validmind/tests/ongoing_monitoring/PredictionProbabilitiesHistogramDrift.py +204 -0
- validmind/tests/ongoing_monitoring/PredictionQuantilesAcrossFeatures.py +98 -0
- validmind/tests/ongoing_monitoring/ROCCurveDrift.py +150 -0
- validmind/tests/ongoing_monitoring/ScoreBandsDrift.py +212 -0
- validmind/tests/ongoing_monitoring/ScorecardHistogramDrift.py +209 -0
- validmind/tests/ongoing_monitoring/TargetPredictionDistributionPlot.py +91 -13
- validmind/tests/prompt_validation/Bias.py +13 -9
- validmind/tests/prompt_validation/Clarity.py +13 -9
- validmind/tests/prompt_validation/Conciseness.py +13 -9
- validmind/tests/prompt_validation/Delimitation.py +13 -9
- validmind/tests/prompt_validation/NegativeInstruction.py +14 -11
- validmind/tests/prompt_validation/Robustness.py +6 -2
- validmind/tests/prompt_validation/Specificity.py +13 -9
- validmind/tests/run.py +6 -0
- validmind/utils.py +7 -8
- validmind/vm_models/dataset/dataset.py +0 -4
- {validmind-2.7.5.dist-info → validmind-2.7.7.dist-info}/METADATA +2 -3
- {validmind-2.7.5.dist-info → validmind-2.7.7.dist-info}/RECORD +149 -138
- {validmind-2.7.5.dist-info → validmind-2.7.7.dist-info}/WHEEL +1 -1
- {validmind-2.7.5.dist-info → validmind-2.7.7.dist-info}/LICENSE +0 -0
- {validmind-2.7.5.dist-info → validmind-2.7.7.dist-info}/entry_points.txt +0 -0
validmind/tests/ongoing_monitoring/CalibrationCurveDrift.py (new file)
@@ -0,0 +1,220 @@
+# Copyright © 2023-2024 ValidMind Inc. All rights reserved.
+# See the LICENSE file in the root of this repository for details.
+# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
+
+from typing import List
+
+import numpy as np
+import pandas as pd
+import plotly.graph_objects as go
+from sklearn.calibration import calibration_curve
+
+from validmind import tags, tasks
+from validmind.errors import SkipTestError
+from validmind.vm_models import VMDataset, VMModel
+
+
+@tags(
+    "sklearn",
+    "binary_classification",
+    "model_performance",
+    "visualization",
+)
+@tasks("classification", "text_classification")
+def CalibrationCurveDrift(
+    datasets: List[VMDataset],
+    model: VMModel,
+    n_bins: int = 10,
+    drift_pct_threshold: float = 20,
+):
+    """
+    Evaluates changes in probability calibration between reference and monitoring datasets.
+
+    ### Purpose
+
+    The Calibration Curve Drift test is designed to assess changes in the model's probability calibration
+    over time. By comparing calibration curves between reference and monitoring datasets, this test helps
+    identify whether the model's probability estimates remain reliable in production. This is crucial for
+    understanding if the model's risk predictions maintain their intended interpretation and whether
+    recalibration might be necessary.
+
+    ### Test Mechanism
+
+    This test proceeds by generating calibration curves for both reference and monitoring datasets. For each
+    dataset, it bins the predicted probabilities and calculates the actual fraction of positives within each
+    bin. It then compares these values between datasets to identify significant shifts in calibration.
+    The test quantifies drift as percentage changes in both mean predicted probabilities and actual fractions
+    of positives per bin, providing both visual and numerical assessments of calibration stability.
+
+    ### Signs of High Risk
+
+    - Large differences between reference and monitoring calibration curves
+    - Systematic over-estimation or under-estimation in monitoring dataset
+    - Significant drift percentages exceeding the threshold in multiple bins
+    - Changes in calibration concentrated in specific probability ranges
+    - Inconsistent drift patterns across the probability spectrum
+    - Empty or sparse bins indicating insufficient data for reliable comparison
+
+    ### Strengths
+
+    - Provides visual and quantitative assessment of calibration changes
+    - Identifies specific probability ranges where calibration has shifted
+    - Enables early detection of systematic prediction biases
+    - Includes detailed bin-by-bin comparison of calibration metrics
+    - Handles edge cases with insufficient data in certain bins
+    - Supports both binary and probabilistic interpretation of results
+
+    ### Limitations
+
+    - Requires sufficient data in each probability bin for reliable comparison
+    - Sensitive to choice of number of bins and binning strategy
+    - May not capture complex changes in probability distributions
+    - Cannot directly suggest recalibration parameters
+    - Limited to assessing probability calibration aspects
+    - Results may be affected by class imbalance changes
+    """
+
+    # Check for binary classification
+    if len(np.unique(datasets[0].y)) > 2:
+        raise SkipTestError(
+            "Calibration Curve Drift is only supported for binary classification models"
+        )
+
+    # Calculate calibration for reference dataset
+    y_prob_ref = datasets[0].y_prob(model)
+    y_true_ref = datasets[0].y.astype(y_prob_ref.dtype).flatten()
+    prob_true_ref, prob_pred_ref = calibration_curve(
+        y_true_ref, y_prob_ref, n_bins=n_bins, strategy="uniform"
+    )
+
+    # Calculate calibration for monitoring dataset
+    y_prob_mon = datasets[1].y_prob(model)
+    y_true_mon = datasets[1].y.astype(y_prob_mon.dtype).flatten()
+    prob_true_mon, prob_pred_mon = calibration_curve(
+        y_true_mon, y_prob_mon, n_bins=n_bins, strategy="uniform"
+    )
+
+    # Create bin labels
+    bin_edges = np.linspace(0, 1, n_bins + 1)
+    bin_labels = [f"{bin_edges[i]:.1f}-{bin_edges[i+1]:.1f}" for i in range(n_bins)]
+
+    # Create predicted probabilities table
+    pred_metrics = []
+    for i in range(n_bins):
+        ref_val = "no data" if i >= len(prob_pred_ref) else round(prob_pred_ref[i], 3)
+        mon_val = "no data" if i >= len(prob_pred_mon) else round(prob_pred_mon[i], 3)
+
+        pred_metrics.append(
+            {"Bin": bin_labels[i], "Reference": ref_val, "Monitoring": mon_val}
+        )
+
+    pred_df = pd.DataFrame(pred_metrics)
+
+    # Calculate drift only for bins with data
+    mask = (pred_df["Reference"] != "no data") & (pred_df["Monitoring"] != "no data")
+    pred_df["Drift (%)"] = None
+    pred_df.loc[mask, "Drift (%)"] = (
+        (
+            pd.to_numeric(pred_df.loc[mask, "Monitoring"])
+            - pd.to_numeric(pred_df.loc[mask, "Reference"])
+        )
+        / pd.to_numeric(pred_df.loc[mask, "Reference"]).abs()
+        * 100
+    ).round(2)
+
+    pred_df["Pass/Fail"] = None
+    pred_df.loc[mask, "Pass/Fail"] = (
+        pred_df.loc[mask, "Drift (%)"]
+        .abs()
+        .apply(lambda x: "Pass" if x < drift_pct_threshold else "Fail")
+    )
+    pred_df.loc[~mask, "Pass/Fail"] = "N/A"
+
+    # Create fraction of positives table
+    true_metrics = []
+    for i in range(n_bins):
+        ref_val = "no data" if i >= len(prob_true_ref) else round(prob_true_ref[i], 3)
+        mon_val = "no data" if i >= len(prob_true_mon) else round(prob_true_mon[i], 3)
+
+        true_metrics.append(
+            {"Bin": bin_labels[i], "Reference": ref_val, "Monitoring": mon_val}
+        )
+
+    true_df = pd.DataFrame(true_metrics)
+
+    # Calculate drift only for bins with data
+    mask = (true_df["Reference"] != "no data") & (true_df["Monitoring"] != "no data")
+    true_df["Drift (%)"] = None
+    true_df.loc[mask, "Drift (%)"] = (
+        (
+            pd.to_numeric(true_df.loc[mask, "Monitoring"])
+            - pd.to_numeric(true_df.loc[mask, "Reference"])
+        )
+        / pd.to_numeric(true_df.loc[mask, "Reference"]).abs()
+        * 100
+    ).round(2)
+
+    true_df["Pass/Fail"] = None
+    true_df.loc[mask, "Pass/Fail"] = (
+        true_df.loc[mask, "Drift (%)"]
+        .abs()
+        .apply(lambda x: "Pass" if x < drift_pct_threshold else "Fail")
+    )
+    true_df.loc[~mask, "Pass/Fail"] = "N/A"
+
+    # Create figure
+    fig = go.Figure()
+
+    # Add perfect calibration line
+    fig.add_trace(
+        go.Scatter(
+            x=[0, 1],
+            y=[0, 1],
+            mode="lines",
+            name="Perfect Calibration",
+            line=dict(color="grey", dash="dash"),
+        )
+    )
+
+    # Add reference calibration curve
+    fig.add_trace(
+        go.Scatter(
+            x=prob_pred_ref,
+            y=prob_true_ref,
+            mode="lines+markers",
+            name="Reference",
+            line=dict(color="blue", width=2),
+            marker=dict(size=8),
+        )
+    )
+
+    # Add monitoring calibration curve
+    fig.add_trace(
+        go.Scatter(
+            x=prob_pred_mon,
+            y=prob_true_mon,
+            mode="lines+markers",
+            name="Monitoring",
+            line=dict(color="red", width=2),
+            marker=dict(size=8),
+        )
+    )
+
+    fig.update_layout(
+        title="Calibration Curves Comparison",
+        xaxis=dict(title="Mean Predicted Probability", range=[0, 1]),
+        yaxis=dict(title="Fraction of Positives", range=[0, 1]),
+        width=700,
+        height=500,
+    )
+
+    # Calculate overall pass/fail (only for bins with data)
+    pass_fail_bool = (pred_df.loc[mask, "Pass/Fail"] == "Pass").all() and (
+        true_df.loc[mask, "Pass/Fail"] == "Pass"
+    ).all()
+
+    return (
+        fig,
+        {"Mean Predicted Probabilities": pred_df, "Fraction of Positives": true_df},
+        pass_fail_bool,
+    )
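The new `CalibrationCurveDrift` test is essentially a per-bin comparison of `sklearn.calibration.calibration_curve` outputs. For reviewers who want to see that drift arithmetic in isolation, here is a minimal, self-contained sketch on synthetic data; the `make_split` helper, sample sizes, and seed are illustrative assumptions only, whereas the shipped test reads probabilities from `VMDataset`/`VMModel` inputs instead.

```python
# Minimal sketch of the per-bin calibration drift check on synthetic data.
import numpy as np
import pandas as pd
from sklearn.calibration import calibration_curve

rng = np.random.default_rng(0)


def make_split(n, overconfidence):
    """Predicted probabilities p, with true event rate p * (1 - overconfidence)."""
    p = rng.uniform(0, 1, n)
    y = rng.binomial(1, p * (1 - overconfidence))
    return y, p


y_ref, p_ref = make_split(5_000, 0.0)  # reference period: well calibrated
y_mon, p_mon = make_split(5_000, 0.3)  # monitoring period: model now overestimates risk

n_bins, drift_pct_threshold = 10, 20  # same defaults as the new test
frac_pos_ref, mean_pred_ref = calibration_curve(y_ref, p_ref, n_bins=n_bins, strategy="uniform")
frac_pos_mon, mean_pred_mon = calibration_curve(y_mon, p_mon, n_bins=n_bins, strategy="uniform")

# With uniform probabilities every bin is populated, so the arrays align 1:1;
# the shipped test additionally marks unpopulated bins as "no data".
df = pd.DataFrame({"Reference": frac_pos_ref, "Monitoring": frac_pos_mon})
df["Drift (%)"] = ((df["Monitoring"] - df["Reference"]) / df["Reference"].abs() * 100).round(2)
df["Pass/Fail"] = np.where(df["Drift (%)"].abs() < drift_pct_threshold, "Pass", "Fail")
print(df)
```

In the packaged test, bins without data are excluded from the pass/fail decision, and the same drift table is produced for both the mean predicted probability and the fraction of positives per bin.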
validmind/tests/ongoing_monitoring/ClassDiscriminationDrift.py (new file)
@@ -0,0 +1,155 @@
+# Copyright © 2023-2024 ValidMind Inc. All rights reserved.
+# See the LICENSE file in the root of this repository for details.
+# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
+
+from typing import List
+
+import numpy as np
+import pandas as pd
+from scipy import stats
+from sklearn.metrics import roc_auc_score
+from sklearn.preprocessing import LabelBinarizer
+
+from validmind import tags, tasks
+from validmind.vm_models import VMDataset, VMModel
+
+
+def multiclass_roc_auc_score(y_test, y_pred, average="macro"):
+    lb = LabelBinarizer()
+    lb.fit(y_test)
+    return roc_auc_score(lb.transform(y_test), lb.transform(y_pred), average=average)
+
+
+def calculate_gini(y_true, y_prob):
+    """Calculate Gini coefficient (2*AUC - 1)"""
+    return 2 * roc_auc_score(y_true, y_prob) - 1
+
+
+def calculate_ks_statistic(y_true, y_prob):
+    """Calculate Kolmogorov-Smirnov statistic"""
+    pos_scores = y_prob[y_true == 1]
+    neg_scores = y_prob[y_true == 0]
+    return stats.ks_2samp(pos_scores, neg_scores).statistic
+
+
+@tags(
+    "sklearn", "binary_classification", "multiclass_classification", "model_performance"
+)
+@tasks("classification", "text_classification")
+def ClassDiscriminationDrift(
+    datasets: List[VMDataset], model: VMModel, drift_pct_threshold=20
+):
+    """
+    Compares classification discrimination metrics between reference and monitoring datasets.
+
+    ### Purpose
+
+    The Class Discrimination Drift test is designed to evaluate changes in the model's discriminative power
+    over time. By comparing key discrimination metrics between reference and monitoring datasets, this test
+    helps identify whether the model maintains its ability to separate classes in production. This is crucial
+    for understanding if the model's predictive power remains stable and whether its decision boundaries
+    continue to effectively distinguish between different classes.

+    ### Test Mechanism
+
+    This test proceeds by calculating three key discrimination metrics for both reference and monitoring
+    datasets: ROC AUC (Area Under the Curve), GINI coefficient, and KS (Kolmogorov-Smirnov) statistic.
+    For binary classification, it computes all three metrics. For multiclass problems, it focuses on
+    macro-averaged ROC AUC. The test quantifies drift as percentage changes in these metrics between
+    datasets, providing a comprehensive assessment of discrimination stability.
+
+    ### Signs of High Risk
+
+    - Large drifts in discrimination metrics exceeding the threshold
+    - Significant drops in ROC AUC indicating reduced ranking ability
+    - Decreased GINI coefficients showing diminished separation power
+    - Reduced KS statistics suggesting weaker class distinction
+    - Inconsistent changes across different metrics
+    - Systematic degradation in discriminative performance
+
+    ### Strengths
+
+    - Combines multiple complementary discrimination metrics
+    - Handles both binary and multiclass classification
+    - Provides clear quantitative drift assessment
+    - Enables early detection of model degradation
+    - Includes standardized drift threshold evaluation
+    - Supports comprehensive performance monitoring
+
+    ### Limitations
+
+    - Does not identify root causes of discrimination drift
+    - May be sensitive to changes in class distribution
+    - Cannot suggest optimal decision threshold adjustments
+    - Limited to discrimination aspects of performance
+    - Requires sufficient data for reliable metric calculation
+    - May not capture subtle changes in decision boundaries
+    """
+    # Get predictions and true values
+    y_true_ref = datasets[0].y
+    y_true_mon = datasets[1].y
+
+    metrics = []
+
+    # Handle binary vs multiclass
+    if len(np.unique(y_true_ref)) == 2:
+        # Binary classification
+        y_prob_ref = datasets[0].y_prob(model)
+        y_prob_mon = datasets[1].y_prob(model)
+
+        # ROC AUC
+        roc_auc_ref = roc_auc_score(y_true_ref, y_prob_ref)
+        roc_auc_mon = roc_auc_score(y_true_mon, y_prob_mon)
+        metrics.append(
+            {"Metric": "ROC_AUC", "Reference": roc_auc_ref, "Monitoring": roc_auc_mon}
+        )
+
+        # GINI
+        gini_ref = calculate_gini(y_true_ref, y_prob_ref)
+        gini_mon = calculate_gini(y_true_mon, y_prob_mon)
+        metrics.append(
+            {"Metric": "GINI", "Reference": gini_ref, "Monitoring": gini_mon}
+        )
+
+        # KS Statistic
+        ks_ref = calculate_ks_statistic(y_true_ref, y_prob_ref)
+        ks_mon = calculate_ks_statistic(y_true_mon, y_prob_mon)
+        metrics.append(
+            {"Metric": "KS_Statistic", "Reference": ks_ref, "Monitoring": ks_mon}
+        )
+
+    else:
+        # Multiclass
+        y_pred_ref = datasets[0].y_pred(model)
+        y_pred_mon = datasets[1].y_pred(model)
+
+        # Only ROC AUC for multiclass
+        roc_auc_ref = multiclass_roc_auc_score(y_true_ref, y_pred_ref)
+        roc_auc_mon = multiclass_roc_auc_score(y_true_mon, y_pred_mon)
+        metrics.append(
+            {
+                "Metric": "ROC_AUC_Macro",
+                "Reference": roc_auc_ref,
+                "Monitoring": roc_auc_mon,
+            }
+        )
+
+    # Create DataFrame
+    df = pd.DataFrame(metrics)
+
+    # Calculate drift percentage with direction
+    df["Drift (%)"] = (
+        (df["Monitoring"] - df["Reference"]) / df["Reference"].abs() * 100
+    ).round(2)
+
+    # Add Pass/Fail column based on absolute drift
+    df["Pass/Fail"] = (
+        df["Drift (%)"]
+        .abs()
+        .apply(lambda x: "Pass" if x < drift_pct_threshold else "Fail")
+    )
+
+    # Calculate overall pass/fail
+    pass_fail_bool = (df["Pass/Fail"] == "Pass").all()
+
+    return ({"Classification Discrimination Metrics": df}, pass_fail_bool)
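To illustrate what `ClassDiscriminationDrift` reports for the binary case, here is a minimal sketch of the ROC AUC / Gini / KS drift calculation on synthetic scores; `make_scores`, the separation values, and the sample sizes are illustrative assumptions, whereas the shipped test obtains probabilities from `VMDataset.y_prob(model)`.

```python
# Minimal sketch of the binary discrimination-drift calculation (ROC AUC, Gini, KS).
import numpy as np
import pandas as pd
from scipy import stats
from sklearn.metrics import roc_auc_score

rng = np.random.default_rng(0)


def make_scores(n, separation):
    """Labels plus scores whose class separation is controlled by `separation`."""
    y = rng.binomial(1, 0.3, n)
    scores = rng.normal(loc=y * separation, scale=1.0)
    return y, scores


y_ref, s_ref = make_scores(5_000, 1.5)  # reference period
y_mon, s_mon = make_scores(5_000, 1.0)  # monitoring period (weaker separation)


def discrimination(y, s):
    auc = roc_auc_score(y, s)
    return {
        "ROC_AUC": auc,
        "GINI": 2 * auc - 1,
        "KS_Statistic": stats.ks_2samp(s[y == 1], s[y == 0]).statistic,
    }


df = pd.DataFrame({"Reference": discrimination(y_ref, s_ref), "Monitoring": discrimination(y_mon, s_mon)})
df["Drift (%)"] = ((df["Monitoring"] - df["Reference"]) / df["Reference"].abs() * 100).round(2)
df["Pass/Fail"] = np.where(df["Drift (%)"].abs() < 20, "Pass", "Fail")
print(df)
```

One property worth noting when setting `drift_pct_threshold`: since Gini = 2·AUC − 1, the same absolute drop produces a larger relative drift in Gini than in AUC, so the Gini row will typically breach the 20% default before the ROC AUC row does.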
validmind/tests/ongoing_monitoring/ClassImbalanceDrift.py (new file)
@@ -0,0 +1,146 @@
+# Copyright © 2023-2024 ValidMind Inc. All rights reserved.
+# See the LICENSE file in the root of this repository for details.
+# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
+
+from typing import List
+
+import pandas as pd
+import plotly.graph_objs as go
+
+from validmind import tags, tasks
+from validmind.errors import SkipTestError
+from validmind.vm_models import VMDataset
+
+
+@tags("tabular_data", "binary_classification", "multiclass_classification")
+@tasks("classification")
+def ClassImbalanceDrift(
+    datasets: List[VMDataset],
+    drift_pct_threshold: float = 5.0,
+    title: str = "Class Distribution Drift",
+):
+    """
+    Evaluates drift in class distribution between reference and monitoring datasets.
+
+    ### Purpose
+
+    The Class Imbalance Drift test is designed to detect changes in the distribution of target classes
+    over time. By comparing class proportions between reference and monitoring datasets, this test helps
+    identify whether the population structure remains stable in production. This is crucial for
+    understanding if the model continues to operate under similar class distribution assumptions and
+    whether retraining might be necessary due to significant shifts in class balance.
+
+    ### Test Mechanism
+
+    This test proceeds by calculating class percentages for both reference and monitoring datasets.
+    It computes the proportion of each class and quantifies drift as the percentage difference in these
+    proportions between datasets. The test provides both visual and numerical comparisons of class
+    distributions, with special attention to changes that exceed the specified drift threshold.
+    Population stability is assessed on a class-by-class basis.
+
+    ### Signs of High Risk
+
+    - Large shifts in class proportions exceeding the threshold
+    - Systematic changes affecting multiple classes
+    - Appearance of new classes or disappearance of existing ones
+    - Significant changes in minority class representation
+    - Reversal of majority-minority class relationships
+    - Unexpected changes in class ratios
+
+    ### Strengths
+
+    - Provides clear visualization of distribution changes
+    - Identifies specific classes experiencing drift
+    - Enables early detection of population shifts
+    - Includes standardized drift threshold evaluation
+    - Supports both binary and multiclass problems
+    - Maintains interpretable percentage-based metrics
+
+    ### Limitations
+
+    - Does not account for feature distribution changes
+    - Cannot identify root causes of class drift
+    - May be sensitive to small sample sizes
+    - Limited to target variable distribution only
+    - Requires sufficient samples per class
+    - May not capture subtle distribution changes
+    """
+    # Validate inputs
+    if not datasets[0].target_column or not datasets[1].target_column:
+        raise SkipTestError("No target column provided")
+
+    # Calculate class distributions
+    ref_dist = (
+        datasets[0].df[datasets[0].target_column].value_counts(normalize=True) * 100
+    )
+    mon_dist = (
+        datasets[1].df[datasets[1].target_column].value_counts(normalize=True) * 100
+    )
+
+    # Get all unique classes
+    all_classes = sorted(set(ref_dist.index) | set(mon_dist.index))
+
+    if len(all_classes) > 10:
+        raise SkipTestError("Skipping target column with more than 10 classes")
+
+    # Create comparison table
+    rows = []
+    all_passed = True
+
+    for class_label in all_classes:
+        ref_percent = ref_dist.get(class_label, 0)
+        mon_percent = mon_dist.get(class_label, 0)
+
+        # Calculate drift (preserving sign)
+        drift = mon_percent - ref_percent
+        passed = abs(drift) < drift_pct_threshold
+        all_passed &= passed
+
+        rows.append(
+            {
+                datasets[0].target_column: class_label,
+                "Reference (%)": round(ref_percent, 4),
+                "Monitoring (%)": round(mon_percent, 4),
+                "Drift (%)": round(drift, 4),
+                "Pass/Fail": "Pass" if passed else "Fail",
+            }
+        )
+
+    comparison_df = pd.DataFrame(rows)
+
+    # Create named tables dictionary
+    tables = {"Class Distribution (%)": comparison_df}
+
+    # Create visualization
+    fig = go.Figure()
+
+    # Add reference distribution bar
+    fig.add_trace(
+        go.Bar(
+            name="Reference",
+            x=[str(c) for c in all_classes],
+            y=comparison_df["Reference (%)"],
+            marker_color="rgba(31, 119, 180, 0.8)",  # Blue with 0.8 opacity
+        )
+    )
+
+    # Add monitoring distribution bar
+    fig.add_trace(
+        go.Bar(
+            name="Monitoring",
+            x=[str(c) for c in all_classes],
+            y=comparison_df["Monitoring (%)"],
+            marker_color="rgba(255, 127, 14, 0.8)",  # Orange with 0.8 opacity
+        )
+    )
+
+    # Update layout
+    fig.update_layout(
+        title=title,
+        xaxis_title="Class",
+        yaxis_title="Percentage (%)",
+        barmode="group",
+        showlegend=True,
+    )
+
+    return fig, tables, all_passed
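The core of `ClassImbalanceDrift` is a class-by-class comparison of `value_counts(normalize=True)` outputs. A minimal sketch of that comparison on two plain label Series follows; the label names, mix proportions, and sample sizes are illustrative assumptions, whereas the shipped test reads the target column from each `VMDataset`.

```python
# Minimal sketch of the class-proportion drift check on two label Series.
import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
ref_labels = pd.Series(rng.choice(["good", "bad"], size=10_000, p=[0.80, 0.20]))
mon_labels = pd.Series(rng.choice(["good", "bad"], size=10_000, p=[0.72, 0.28]))

drift_pct_threshold = 5.0  # same default as the new test
ref_dist = ref_labels.value_counts(normalize=True) * 100
mon_dist = mon_labels.value_counts(normalize=True) * 100

rows = []
for cls in sorted(set(ref_dist.index) | set(mon_dist.index)):
    # Signed drift in percentage points, not a relative change
    drift = mon_dist.get(cls, 0) - ref_dist.get(cls, 0)
    rows.append(
        {
            "class": cls,
            "Reference (%)": round(ref_dist.get(cls, 0), 4),
            "Monitoring (%)": round(mon_dist.get(cls, 0), 4),
            "Drift (%)": round(drift, 4),
            "Pass/Fail": "Pass" if abs(drift) < drift_pct_threshold else "Fail",
        }
    )

print(pd.DataFrame(rows))
```

Note the design difference from the other new monitoring tests: drift here is an absolute difference in percentage points (monitoring % minus reference %) checked against a 5-point default, rather than a relative percentage change against a 20% default.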
validmind/tests/ongoing_monitoring/ClassificationAccuracyDrift.py (new file)
@@ -0,0 +1,148 @@
+# Copyright © 2023-2024 ValidMind Inc. All rights reserved.
+# See the LICENSE file in the root of this repository for details.
+# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
+
+from typing import List
+
+import numpy as np
+import pandas as pd
+from sklearn.metrics import classification_report
+
+from validmind import tags, tasks
+from validmind.vm_models import VMDataset, VMModel
+
+
+@tags(
+    "sklearn", "binary_classification", "multiclass_classification", "model_performance"
+)
+@tasks("classification", "text_classification")
+def ClassificationAccuracyDrift(
+    datasets: List[VMDataset], model: VMModel, drift_pct_threshold=20
+):
+    """
+    Compares classification accuracy metrics between reference and monitoring datasets.
+
+    ### Purpose
+
+    The Classification Accuracy Drift test is designed to evaluate changes in the model's predictive accuracy
+    over time. By comparing key accuracy metrics between reference and monitoring datasets, this test helps
+    identify whether the model maintains its performance levels in production. This is crucial for
+    understanding if the model's predictions remain reliable and whether its overall effectiveness has
+    degraded significantly.
+
+    ### Test Mechanism
+
+    This test proceeds by calculating comprehensive accuracy metrics for both reference and monitoring
+    datasets. It computes overall accuracy, per-label precision, recall, and F1 scores, as well as
+    macro-averaged metrics. The test quantifies drift as percentage changes in these metrics between
+    datasets, providing both granular and aggregate views of accuracy changes. Special attention is paid
+    to per-label performance to identify class-specific degradation.
+
+    ### Signs of High Risk
+
+    - Large drifts in accuracy metrics exceeding the threshold
+    - Inconsistent changes across different labels
+    - Significant drops in macro-averaged metrics
+    - Systematic degradation in specific class performance
+    - Unexpected improvements suggesting data quality issues
+    - Divergent trends between precision and recall
+
+    ### Strengths
+
+    - Provides comprehensive accuracy assessment
+    - Identifies class-specific performance changes
+    - Enables early detection of model degradation
+    - Includes both micro and macro perspectives
+    - Supports multi-class classification evaluation
+    - Maintains interpretable drift thresholds
+
+    ### Limitations
+
+    - May be sensitive to class distribution changes
+    - Does not account for prediction confidence
+    - Cannot identify root causes of accuracy drift
+    - Limited to accuracy-based metrics only
+    - Requires sufficient samples per class
+    - May not capture subtle performance changes
+    """
+    # Get predictions and true values
+    y_true_ref = datasets[0].y
+    y_pred_ref = datasets[0].y_pred(model)
+
+    y_true_mon = datasets[1].y
+    y_pred_mon = datasets[1].y_pred(model)
+
+    # Get unique labels from reference dataset
+    labels = np.unique(y_true_ref)
+    labels = sorted(labels.tolist())
+
+    # Calculate classification reports
+    report_ref = classification_report(
+        y_true=y_true_ref,
+        y_pred=y_pred_ref,
+        output_dict=True,
+        zero_division=0,
+    )
+
+    report_mon = classification_report(
+        y_true=y_true_mon,
+        y_pred=y_pred_mon,
+        output_dict=True,
+        zero_division=0,
+    )
+
+    # Create metrics dataframe
+    metrics = []
+
+    # Add accuracy
+    metrics.append(
+        {
+            "Metric": "Accuracy",
+            "Reference": report_ref["accuracy"],
+            "Monitoring": report_mon["accuracy"],
+        }
+    )
+
+    # Add per-label metrics
+    for label in labels:
+        label_str = str(label)
+        for metric in ["precision", "recall", "f1-score"]:
+            metric_name = f"{metric.title()}_{label_str}"
+            metrics.append(
+                {
+                    "Metric": metric_name,
+                    "Reference": report_ref[label_str][metric],
+                    "Monitoring": report_mon[label_str][metric],
+                }
+            )
+
+    # Add macro averages
+    for metric in ["precision", "recall", "f1-score"]:
+        metric_name = f"{metric.title()}_Macro"
+        metrics.append(
+            {
+                "Metric": metric_name,
+                "Reference": report_ref["macro avg"][metric],
+                "Monitoring": report_mon["macro avg"][metric],
+            }
+        )
+
+    # Create DataFrame
+    df = pd.DataFrame(metrics)
+
+    # Calculate drift percentage with direction
+    df["Drift (%)"] = (
+        (df["Monitoring"] - df["Reference"]) / df["Reference"].abs() * 100
+    ).round(2)
+
+    # Add Pass/Fail column based on absolute drift
+    df["Pass/Fail"] = (
+        df["Drift (%)"]
+        .abs()
+        .apply(lambda x: "Pass" if x < drift_pct_threshold else "Fail")
+    )
+
+    # Calculate overall pass/fail
+    pass_fail_bool = (df["Pass/Fail"] == "Pass").all()
+
+    return ({"Classification Accuracy Metrics": df}, pass_fail_bool)
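Finally, a minimal sketch of the accuracy-drift comparison that `ClassificationAccuracyDrift` builds from two `classification_report(output_dict=True)` dictionaries. The `make_predictions` helper, error rates, and sample sizes are illustrative assumptions, and the sketch keeps only the aggregate rows; the shipped test also adds per-label precision, recall, and F1 rows and takes its predictions from `VMDataset.y_pred(model)`.

```python
# Minimal sketch of the accuracy-drift comparison on synthetic predictions.
import numpy as np
import pandas as pd
from sklearn.metrics import classification_report

rng = np.random.default_rng(0)


def make_predictions(n, error_rate):
    """True labels plus predictions that are wrong with probability `error_rate`."""
    y_true = rng.integers(0, 2, n)
    flip = rng.random(n) < error_rate
    return y_true, np.where(flip, 1 - y_true, y_true)


y_true_ref, y_pred_ref = make_predictions(5_000, 0.10)  # reference period
y_true_mon, y_pred_mon = make_predictions(5_000, 0.18)  # monitoring period (degraded)

rep_ref = classification_report(y_true_ref, y_pred_ref, output_dict=True, zero_division=0)
rep_mon = classification_report(y_true_mon, y_pred_mon, output_dict=True, zero_division=0)

metrics = [{"Metric": "Accuracy", "Reference": rep_ref["accuracy"], "Monitoring": rep_mon["accuracy"]}]
for metric in ["precision", "recall", "f1-score"]:
    metrics.append(
        {
            "Metric": f"{metric.title()}_Macro",
            "Reference": rep_ref["macro avg"][metric],
            "Monitoring": rep_mon["macro avg"][metric],
        }
    )

df = pd.DataFrame(metrics)
df["Drift (%)"] = ((df["Monitoring"] - df["Reference"]) / df["Reference"].abs() * 100).round(2)
df["Pass/Fail"] = np.where(df["Drift (%)"].abs() < 20, "Pass", "Fail")
print(df)
```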