validmind 2.7.5__py3-none-any.whl → 2.7.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (24)
  1. validmind/__version__.py +1 -1
  2. validmind/datasets/credit_risk/lending_club.py +354 -88
  3. validmind/tests/data_validation/HighPearsonCorrelation.py +12 -2
  4. validmind/tests/ongoing_monitoring/CalibrationCurveDrift.py +218 -0
  5. validmind/tests/ongoing_monitoring/ClassDiscriminationDrift.py +153 -0
  6. validmind/tests/ongoing_monitoring/ClassImbalanceDrift.py +144 -0
  7. validmind/tests/ongoing_monitoring/ClassificationAccuracyDrift.py +146 -0
  8. validmind/tests/ongoing_monitoring/ConfusionMatrixDrift.py +191 -0
  9. validmind/tests/ongoing_monitoring/CumulativePredictionProbabilitiesDrift.py +176 -0
  10. validmind/tests/ongoing_monitoring/FeatureDrift.py +120 -121
  11. validmind/tests/ongoing_monitoring/PredictionAcrossEachFeature.py +18 -23
  12. validmind/tests/ongoing_monitoring/PredictionCorrelation.py +86 -45
  13. validmind/tests/ongoing_monitoring/PredictionProbabilitiesHistogramDrift.py +202 -0
  14. validmind/tests/ongoing_monitoring/PredictionQuantilesAcrossFeatures.py +97 -0
  15. validmind/tests/ongoing_monitoring/ROCCurveDrift.py +149 -0
  16. validmind/tests/ongoing_monitoring/ScoreBandsDrift.py +210 -0
  17. validmind/tests/ongoing_monitoring/ScorecardHistogramDrift.py +207 -0
  18. validmind/tests/ongoing_monitoring/TargetPredictionDistributionPlot.py +91 -14
  19. validmind/vm_models/dataset/dataset.py +0 -4
  20. {validmind-2.7.5.dist-info → validmind-2.7.6.dist-info}/METADATA +2 -2
  21. {validmind-2.7.5.dist-info → validmind-2.7.6.dist-info}/RECORD +24 -13
  22. {validmind-2.7.5.dist-info → validmind-2.7.6.dist-info}/LICENSE +0 -0
  23. {validmind-2.7.5.dist-info → validmind-2.7.6.dist-info}/WHEEL +0 -0
  24. {validmind-2.7.5.dist-info → validmind-2.7.6.dist-info}/entry_points.txt +0 -0
validmind/tests/ongoing_monitoring/CalibrationCurveDrift.py
@@ -0,0 +1,218 @@
+ # Copyright © 2023-2024 ValidMind Inc. All rights reserved.
+ # See the LICENSE file in the root of this repository for details.
+ # SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
+
+ import numpy as np
+ import pandas as pd
+ import plotly.graph_objects as go
+ from sklearn.calibration import calibration_curve
+ from typing import List
+ from validmind import tags, tasks
+ from validmind.errors import SkipTestError
+ from validmind.vm_models import VMDataset, VMModel
+
+
+ @tags(
+     "sklearn",
+     "binary_classification",
+     "model_performance",
+     "visualization",
+ )
+ @tasks("classification", "text_classification")
+ def CalibrationCurveDrift(
+     datasets: List[VMDataset],
+     model: VMModel,
+     n_bins: int = 10,
+     drift_pct_threshold: float = 20,
+ ):
+     """
+     Evaluates changes in probability calibration between reference and monitoring datasets.
+
+     ### Purpose
+
+     The Calibration Curve Drift test is designed to assess changes in the model's probability calibration
+     over time. By comparing calibration curves between reference and monitoring datasets, this test helps
+     identify whether the model's probability estimates remain reliable in production. This is crucial for
+     understanding if the model's risk predictions maintain their intended interpretation and whether
+     recalibration might be necessary.
+
+     ### Test Mechanism
+
+     This test proceeds by generating calibration curves for both reference and monitoring datasets. For each
+     dataset, it bins the predicted probabilities and calculates the actual fraction of positives within each
+     bin. It then compares these values between datasets to identify significant shifts in calibration.
+     The test quantifies drift as percentage changes in both mean predicted probabilities and actual fractions
+     of positives per bin, providing both visual and numerical assessments of calibration stability.
+
+     ### Signs of High Risk
+
+     - Large differences between reference and monitoring calibration curves
+     - Systematic over-estimation or under-estimation in monitoring dataset
+     - Significant drift percentages exceeding the threshold in multiple bins
+     - Changes in calibration concentrated in specific probability ranges
+     - Inconsistent drift patterns across the probability spectrum
+     - Empty or sparse bins indicating insufficient data for reliable comparison
+
+     ### Strengths
+
+     - Provides visual and quantitative assessment of calibration changes
+     - Identifies specific probability ranges where calibration has shifted
+     - Enables early detection of systematic prediction biases
+     - Includes detailed bin-by-bin comparison of calibration metrics
+     - Handles edge cases with insufficient data in certain bins
+     - Supports both binary and probabilistic interpretation of results
+
+     ### Limitations
+
+     - Requires sufficient data in each probability bin for reliable comparison
+     - Sensitive to choice of number of bins and binning strategy
+     - May not capture complex changes in probability distributions
+     - Cannot directly suggest recalibration parameters
+     - Limited to assessing probability calibration aspects
+     - Results may be affected by class imbalance changes
+     """
+
+     # Check for binary classification
+     if len(np.unique(datasets[0].y)) > 2:
+         raise SkipTestError(
+             "Calibration Curve Drift is only supported for binary classification models"
+         )
+
+     # Calculate calibration for reference dataset
+     y_prob_ref = datasets[0].y_prob(model)
+     y_true_ref = datasets[0].y.astype(y_prob_ref.dtype).flatten()
+     prob_true_ref, prob_pred_ref = calibration_curve(
+         y_true_ref, y_prob_ref, n_bins=n_bins, strategy="uniform"
+     )
+
+     # Calculate calibration for monitoring dataset
+     y_prob_mon = datasets[1].y_prob(model)
+     y_true_mon = datasets[1].y.astype(y_prob_mon.dtype).flatten()
+     prob_true_mon, prob_pred_mon = calibration_curve(
+         y_true_mon, y_prob_mon, n_bins=n_bins, strategy="uniform"
+     )
+
+     # Create bin labels
+     bin_edges = np.linspace(0, 1, n_bins + 1)
+     bin_labels = [f"{bin_edges[i]:.1f}-{bin_edges[i+1]:.1f}" for i in range(n_bins)]
+
+     # Create predicted probabilities table
+     pred_metrics = []
+     for i in range(n_bins):
+         ref_val = "no data" if i >= len(prob_pred_ref) else round(prob_pred_ref[i], 3)
+         mon_val = "no data" if i >= len(prob_pred_mon) else round(prob_pred_mon[i], 3)
+
+         pred_metrics.append(
+             {"Bin": bin_labels[i], "Reference": ref_val, "Monitoring": mon_val}
+         )
+
+     pred_df = pd.DataFrame(pred_metrics)
+
+     # Calculate drift only for bins with data
+     mask = (pred_df["Reference"] != "no data") & (pred_df["Monitoring"] != "no data")
+     pred_df["Drift (%)"] = None
+     pred_df.loc[mask, "Drift (%)"] = (
+         (
+             pd.to_numeric(pred_df.loc[mask, "Monitoring"])
+             - pd.to_numeric(pred_df.loc[mask, "Reference"])
+         )
+         / pd.to_numeric(pred_df.loc[mask, "Reference"]).abs()
+         * 100
+     ).round(2)
+
+     pred_df["Pass/Fail"] = None
+     pred_df.loc[mask, "Pass/Fail"] = (
+         pred_df.loc[mask, "Drift (%)"]
+         .abs()
+         .apply(lambda x: "Pass" if x < drift_pct_threshold else "Fail")
+     )
+     pred_df.loc[~mask, "Pass/Fail"] = "N/A"
+
+     # Create fraction of positives table
+     true_metrics = []
+     for i in range(n_bins):
+         ref_val = "no data" if i >= len(prob_true_ref) else round(prob_true_ref[i], 3)
+         mon_val = "no data" if i >= len(prob_true_mon) else round(prob_true_mon[i], 3)
+
+         true_metrics.append(
+             {"Bin": bin_labels[i], "Reference": ref_val, "Monitoring": mon_val}
+         )
+
+     true_df = pd.DataFrame(true_metrics)
+
+     # Calculate drift only for bins with data
+     mask = (true_df["Reference"] != "no data") & (true_df["Monitoring"] != "no data")
+     true_df["Drift (%)"] = None
+     true_df.loc[mask, "Drift (%)"] = (
+         (
+             pd.to_numeric(true_df.loc[mask, "Monitoring"])
+             - pd.to_numeric(true_df.loc[mask, "Reference"])
+         )
+         / pd.to_numeric(true_df.loc[mask, "Reference"]).abs()
+         * 100
+     ).round(2)
+
+     true_df["Pass/Fail"] = None
+     true_df.loc[mask, "Pass/Fail"] = (
+         true_df.loc[mask, "Drift (%)"]
+         .abs()
+         .apply(lambda x: "Pass" if x < drift_pct_threshold else "Fail")
+     )
+     true_df.loc[~mask, "Pass/Fail"] = "N/A"
+
+     # Create figure
+     fig = go.Figure()
+
+     # Add perfect calibration line
+     fig.add_trace(
+         go.Scatter(
+             x=[0, 1],
+             y=[0, 1],
+             mode="lines",
+             name="Perfect Calibration",
+             line=dict(color="grey", dash="dash"),
+         )
+     )
+
+     # Add reference calibration curve
+     fig.add_trace(
+         go.Scatter(
+             x=prob_pred_ref,
+             y=prob_true_ref,
+             mode="lines+markers",
+             name="Reference",
+             line=dict(color="blue", width=2),
+             marker=dict(size=8),
+         )
+     )
+
+     # Add monitoring calibration curve
+     fig.add_trace(
+         go.Scatter(
+             x=prob_pred_mon,
+             y=prob_true_mon,
+             mode="lines+markers",
+             name="Monitoring",
+             line=dict(color="red", width=2),
+             marker=dict(size=8),
+         )
+     )
+
+     fig.update_layout(
+         title="Calibration Curves Comparison",
+         xaxis=dict(title="Mean Predicted Probability", range=[0, 1]),
+         yaxis=dict(title="Fraction of Positives", range=[0, 1]),
+         width=700,
+         height=500,
+     )
+
+     # Calculate overall pass/fail (only for bins with data)
+     pass_fail_bool = (pred_df.loc[mask, "Pass/Fail"] == "Pass").all() and (
+         true_df.loc[mask, "Pass/Fail"] == "Pass"
+     ).all()
+
+     return (
+         fig,
+         {"Mean Predicted Probabilities": pred_df, "Fraction of Positives": true_df},
+         pass_fail_bool,
+     )
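Editor's note: a minimal sketch of calling the new test directly, for orientation only. It assumes `ref_ds` and `mon_ds` are already-initialized `VMDataset` objects (reference and monitoring splits with predictions assigned for the model) and `vm_model` is a `VMModel`; those names are not part of this diff, and in practice the test would normally be run through ValidMind's test harness by its test ID.

```python
# Hypothetical usage sketch -- `ref_ds`, `mon_ds` (VMDataset) and `vm_model`
# (VMModel) are assumed to exist and are not defined in this diff.
from validmind.tests.ongoing_monitoring.CalibrationCurveDrift import (
    CalibrationCurveDrift,
)

# Returns the comparison figure, the per-bin tables, and an overall pass/fail flag.
fig, tables, passed = CalibrationCurveDrift(
    datasets=[ref_ds, mon_ds],   # [reference, monitoring]
    model=vm_model,
    n_bins=10,                   # uniform probability bins (default)
    drift_pct_threshold=20,      # a bin fails if |drift| >= 20% (default)
)

fig.show()
print(tables["Mean Predicted Probabilities"])
print("Calibration stable:", passed)
```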
validmind/tests/ongoing_monitoring/ClassDiscriminationDrift.py
@@ -0,0 +1,153 @@
+ # Copyright © 2023-2024 ValidMind Inc. All rights reserved.
+ # See the LICENSE file in the root of this repository for details.
+ # SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
+
+ import numpy as np
+ import pandas as pd
+ from sklearn.metrics import roc_auc_score
+ from sklearn.preprocessing import LabelBinarizer
+ from scipy import stats
+ from typing import List
+ from validmind import tags, tasks
+ from validmind.vm_models import VMDataset, VMModel
+
+
+ def multiclass_roc_auc_score(y_test, y_pred, average="macro"):
+     lb = LabelBinarizer()
+     lb.fit(y_test)
+     return roc_auc_score(lb.transform(y_test), lb.transform(y_pred), average=average)
+
+
+ def calculate_gini(y_true, y_prob):
+     """Calculate Gini coefficient (2*AUC - 1)"""
+     return 2 * roc_auc_score(y_true, y_prob) - 1
+
+
+ def calculate_ks_statistic(y_true, y_prob):
+     """Calculate Kolmogorov-Smirnov statistic"""
+     pos_scores = y_prob[y_true == 1]
+     neg_scores = y_prob[y_true == 0]
+     return stats.ks_2samp(pos_scores, neg_scores).statistic
+
+
+ @tags(
+     "sklearn", "binary_classification", "multiclass_classification", "model_performance"
+ )
+ @tasks("classification", "text_classification")
+ def ClassDiscriminationDrift(
+     datasets: List[VMDataset], model: VMModel, drift_pct_threshold=20
+ ):
+     """
+     Compares classification discrimination metrics between reference and monitoring datasets.
+
+     ### Purpose
+
+     The Class Discrimination Drift test is designed to evaluate changes in the model's discriminative power
+     over time. By comparing key discrimination metrics between reference and monitoring datasets, this test
+     helps identify whether the model maintains its ability to separate classes in production. This is crucial
+     for understanding if the model's predictive power remains stable and whether its decision boundaries
+     continue to effectively distinguish between different classes.
+
+     ### Test Mechanism
+
+     This test proceeds by calculating three key discrimination metrics for both reference and monitoring
+     datasets: ROC AUC (Area Under the Curve), GINI coefficient, and KS (Kolmogorov-Smirnov) statistic.
+     For binary classification, it computes all three metrics. For multiclass problems, it focuses on
+     macro-averaged ROC AUC. The test quantifies drift as percentage changes in these metrics between
+     datasets, providing a comprehensive assessment of discrimination stability.
+
+     ### Signs of High Risk
+
+     - Large drifts in discrimination metrics exceeding the threshold
+     - Significant drops in ROC AUC indicating reduced ranking ability
+     - Decreased GINI coefficients showing diminished separation power
+     - Reduced KS statistics suggesting weaker class distinction
+     - Inconsistent changes across different metrics
+     - Systematic degradation in discriminative performance
+
+     ### Strengths
+
+     - Combines multiple complementary discrimination metrics
+     - Handles both binary and multiclass classification
+     - Provides clear quantitative drift assessment
+     - Enables early detection of model degradation
+     - Includes standardized drift threshold evaluation
+     - Supports comprehensive performance monitoring
+
+     ### Limitations
+
+     - Does not identify root causes of discrimination drift
+     - May be sensitive to changes in class distribution
+     - Cannot suggest optimal decision threshold adjustments
+     - Limited to discrimination aspects of performance
+     - Requires sufficient data for reliable metric calculation
+     - May not capture subtle changes in decision boundaries
+     """
+     # Get predictions and true values
+     y_true_ref = datasets[0].y
+     y_true_mon = datasets[1].y
+
+     metrics = []
+
+     # Handle binary vs multiclass
+     if len(np.unique(y_true_ref)) == 2:
+         # Binary classification
+         y_prob_ref = datasets[0].y_prob(model)
+         y_prob_mon = datasets[1].y_prob(model)
+
+         # ROC AUC
+         roc_auc_ref = roc_auc_score(y_true_ref, y_prob_ref)
+         roc_auc_mon = roc_auc_score(y_true_mon, y_prob_mon)
+         metrics.append(
+             {"Metric": "ROC_AUC", "Reference": roc_auc_ref, "Monitoring": roc_auc_mon}
+         )
+
+         # GINI
+         gini_ref = calculate_gini(y_true_ref, y_prob_ref)
+         gini_mon = calculate_gini(y_true_mon, y_prob_mon)
+         metrics.append(
+             {"Metric": "GINI", "Reference": gini_ref, "Monitoring": gini_mon}
+         )
+
+         # KS Statistic
+         ks_ref = calculate_ks_statistic(y_true_ref, y_prob_ref)
+         ks_mon = calculate_ks_statistic(y_true_mon, y_prob_mon)
+         metrics.append(
+             {"Metric": "KS_Statistic", "Reference": ks_ref, "Monitoring": ks_mon}
+         )
+
+     else:
+         # Multiclass
+         y_pred_ref = datasets[0].y_pred(model)
+         y_pred_mon = datasets[1].y_pred(model)
+
+         # Only ROC AUC for multiclass
+         roc_auc_ref = multiclass_roc_auc_score(y_true_ref, y_pred_ref)
+         roc_auc_mon = multiclass_roc_auc_score(y_true_mon, y_pred_mon)
+         metrics.append(
+             {
+                 "Metric": "ROC_AUC_Macro",
+                 "Reference": roc_auc_ref,
+                 "Monitoring": roc_auc_mon,
+             }
+         )
+
+     # Create DataFrame
+     df = pd.DataFrame(metrics)
+
+     # Calculate drift percentage with direction
+     df["Drift (%)"] = (
+         (df["Monitoring"] - df["Reference"]) / df["Reference"].abs() * 100
+     ).round(2)
+
+     # Add Pass/Fail column based on absolute drift
+     df["Pass/Fail"] = (
+         df["Drift (%)"]
+         .abs()
+         .apply(lambda x: "Pass" if x < drift_pct_threshold else "Fail")
+     )
+
+     # Calculate overall pass/fail
+     pass_fail_bool = (df["Pass/Fail"] == "Pass").all()
+
+     return ({"Classification Discrimination Metrics": df}, pass_fail_bool)
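Editor's note: the discrimination metrics above all use the same drift rule, a signed relative change against the reference value that fails once its absolute value reaches the threshold. A short worked illustration with invented numbers:

```python
# Worked illustration of the drift rule used by ClassDiscriminationDrift
# (the AUC values below are invented; the test computes them from the datasets).
reference_auc, monitoring_auc = 0.82, 0.74

drift_pct = (monitoring_auc - reference_auc) / abs(reference_auc) * 100  # ~ -9.76
verdict = "Pass" if abs(drift_pct) < 20 else "Fail"  # default drift_pct_threshold = 20

print(f"ROC_AUC drift: {drift_pct:.2f}% -> {verdict}")  # ROC_AUC drift: -9.76% -> Pass
```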
validmind/tests/ongoing_monitoring/ClassImbalanceDrift.py
@@ -0,0 +1,144 @@
+ # Copyright © 2023-2024 ValidMind Inc. All rights reserved.
+ # See the LICENSE file in the root of this repository for details.
+ # SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
+
+ import pandas as pd
+ import plotly.graph_objs as go
+ from typing import List
+ from validmind import tags, tasks
+ from validmind.vm_models import VMDataset
+ from validmind.errors import SkipTestError
+
+
+ @tags("tabular_data", "binary_classification", "multiclass_classification")
+ @tasks("classification")
+ def ClassImbalanceDrift(
+     datasets: List[VMDataset],
+     drift_pct_threshold: float = 5.0,
+     title: str = "Class Distribution Drift",
+ ):
+     """
+     Evaluates drift in class distribution between reference and monitoring datasets.
+
+     ### Purpose
+
+     The Class Imbalance Drift test is designed to detect changes in the distribution of target classes
+     over time. By comparing class proportions between reference and monitoring datasets, this test helps
+     identify whether the population structure remains stable in production. This is crucial for
+     understanding if the model continues to operate under similar class distribution assumptions and
+     whether retraining might be necessary due to significant shifts in class balance.
+
+     ### Test Mechanism
+
+     This test proceeds by calculating class percentages for both reference and monitoring datasets.
+     It computes the proportion of each class and quantifies drift as the percentage difference in these
+     proportions between datasets. The test provides both visual and numerical comparisons of class
+     distributions, with special attention to changes that exceed the specified drift threshold.
+     Population stability is assessed on a class-by-class basis.
+
+     ### Signs of High Risk
+
+     - Large shifts in class proportions exceeding the threshold
+     - Systematic changes affecting multiple classes
+     - Appearance of new classes or disappearance of existing ones
+     - Significant changes in minority class representation
+     - Reversal of majority-minority class relationships
+     - Unexpected changes in class ratios
+
+     ### Strengths
+
+     - Provides clear visualization of distribution changes
+     - Identifies specific classes experiencing drift
+     - Enables early detection of population shifts
+     - Includes standardized drift threshold evaluation
+     - Supports both binary and multiclass problems
+     - Maintains interpretable percentage-based metrics
+
+     ### Limitations
+
+     - Does not account for feature distribution changes
+     - Cannot identify root causes of class drift
+     - May be sensitive to small sample sizes
+     - Limited to target variable distribution only
+     - Requires sufficient samples per class
+     - May not capture subtle distribution changes
+     """
+     # Validate inputs
+     if not datasets[0].target_column or not datasets[1].target_column:
+         raise SkipTestError("No target column provided")
+
+     # Calculate class distributions
+     ref_dist = (
+         datasets[0].df[datasets[0].target_column].value_counts(normalize=True) * 100
+     )
+     mon_dist = (
+         datasets[1].df[datasets[1].target_column].value_counts(normalize=True) * 100
+     )
+
+     # Get all unique classes
+     all_classes = sorted(set(ref_dist.index) | set(mon_dist.index))
+
+     if len(all_classes) > 10:
+         raise SkipTestError("Skipping target column with more than 10 classes")
+
+     # Create comparison table
+     rows = []
+     all_passed = True
+
+     for class_label in all_classes:
+         ref_percent = ref_dist.get(class_label, 0)
+         mon_percent = mon_dist.get(class_label, 0)
+
+         # Calculate drift (preserving sign)
+         drift = mon_percent - ref_percent
+         passed = abs(drift) < drift_pct_threshold
+         all_passed &= passed
+
+         rows.append(
+             {
+                 datasets[0].target_column: class_label,
+                 "Reference (%)": round(ref_percent, 4),
+                 "Monitoring (%)": round(mon_percent, 4),
+                 "Drift (%)": round(drift, 4),
+                 "Pass/Fail": "Pass" if passed else "Fail",
+             }
+         )
+
+     comparison_df = pd.DataFrame(rows)
+
+     # Create named tables dictionary
+     tables = {"Class Distribution (%)": comparison_df}
+
+     # Create visualization
+     fig = go.Figure()
+
+     # Add reference distribution bar
+     fig.add_trace(
+         go.Bar(
+             name="Reference",
+             x=[str(c) for c in all_classes],
+             y=comparison_df["Reference (%)"],
+             marker_color="rgba(31, 119, 180, 0.8)",  # Blue with 0.8 opacity
+         )
+     )
+
+     # Add monitoring distribution bar
+     fig.add_trace(
+         go.Bar(
+             name="Monitoring",
+             x=[str(c) for c in all_classes],
+             y=comparison_df["Monitoring (%)"],
+             marker_color="rgba(255, 127, 14, 0.8)",  # Orange with 0.8 opacity
+         )
+     )
+
+     # Update layout
+     fig.update_layout(
+         title=title,
+         xaxis_title="Class",
+         yaxis_title="Percentage (%)",
+         barmode="group",
+         showlegend=True,
+     )
+
+     return fig, tables, all_passed
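Editor's note: unlike the other drift tests in this release, which use relative percentage change, this test measures drift as the difference in class share in percentage points, against a default threshold of 5. A short illustration with invented shares:

```python
# Hypothetical illustration of how ClassImbalanceDrift scores a single class;
# the percentages are invented (the test derives them from the target column).
ref_pct, mon_pct = 12.5, 16.0              # class share in reference vs. monitoring (%)

drift = mon_pct - ref_pct                  # signed difference in percentage points: 3.5
verdict = "Pass" if abs(drift) < 5.0 else "Fail"  # default drift_pct_threshold = 5.0

print(f"Drift: {drift:+.1f} pp -> {verdict}")  # Drift: +3.5 pp -> Pass
```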
validmind/tests/ongoing_monitoring/ClassificationAccuracyDrift.py
@@ -0,0 +1,146 @@
+ # Copyright © 2023-2024 ValidMind Inc. All rights reserved.
+ # See the LICENSE file in the root of this repository for details.
+ # SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
+
+ import numpy as np
+ import pandas as pd
+ from sklearn.metrics import classification_report
+ from typing import List
+ from validmind import tags, tasks
+ from validmind.vm_models import VMDataset, VMModel
+
+
+ @tags(
+     "sklearn", "binary_classification", "multiclass_classification", "model_performance"
+ )
+ @tasks("classification", "text_classification")
+ def ClassificationAccuracyDrift(
+     datasets: List[VMDataset], model: VMModel, drift_pct_threshold=20
+ ):
+     """
+     Compares classification accuracy metrics between reference and monitoring datasets.
+
+     ### Purpose
+
+     The Classification Accuracy Drift test is designed to evaluate changes in the model's predictive accuracy
+     over time. By comparing key accuracy metrics between reference and monitoring datasets, this test helps
+     identify whether the model maintains its performance levels in production. This is crucial for
+     understanding if the model's predictions remain reliable and whether its overall effectiveness has
+     degraded significantly.
+
+     ### Test Mechanism
+
+     This test proceeds by calculating comprehensive accuracy metrics for both reference and monitoring
+     datasets. It computes overall accuracy, per-label precision, recall, and F1 scores, as well as
+     macro-averaged metrics. The test quantifies drift as percentage changes in these metrics between
+     datasets, providing both granular and aggregate views of accuracy changes. Special attention is paid
+     to per-label performance to identify class-specific degradation.
+
+     ### Signs of High Risk
+
+     - Large drifts in accuracy metrics exceeding the threshold
+     - Inconsistent changes across different labels
+     - Significant drops in macro-averaged metrics
+     - Systematic degradation in specific class performance
+     - Unexpected improvements suggesting data quality issues
+     - Divergent trends between precision and recall
+
+     ### Strengths
+
+     - Provides comprehensive accuracy assessment
+     - Identifies class-specific performance changes
+     - Enables early detection of model degradation
+     - Includes both micro and macro perspectives
+     - Supports multi-class classification evaluation
+     - Maintains interpretable drift thresholds
+
+     ### Limitations
+
+     - May be sensitive to class distribution changes
+     - Does not account for prediction confidence
+     - Cannot identify root causes of accuracy drift
+     - Limited to accuracy-based metrics only
+     - Requires sufficient samples per class
+     - May not capture subtle performance changes
+     """
+     # Get predictions and true values
+     y_true_ref = datasets[0].y
+     y_pred_ref = datasets[0].y_pred(model)
+
+     y_true_mon = datasets[1].y
+     y_pred_mon = datasets[1].y_pred(model)
+
+     # Get unique labels from reference dataset
+     labels = np.unique(y_true_ref)
+     labels = sorted(labels.tolist())
+
+     # Calculate classification reports
+     report_ref = classification_report(
+         y_true=y_true_ref,
+         y_pred=y_pred_ref,
+         output_dict=True,
+         zero_division=0,
+     )
+
+     report_mon = classification_report(
+         y_true=y_true_mon,
+         y_pred=y_pred_mon,
+         output_dict=True,
+         zero_division=0,
+     )
+
+     # Create metrics dataframe
+     metrics = []
+
+     # Add accuracy
+     metrics.append(
+         {
+             "Metric": "Accuracy",
+             "Reference": report_ref["accuracy"],
+             "Monitoring": report_mon["accuracy"],
+         }
+     )
+
+     # Add per-label metrics
+     for label in labels:
+         label_str = str(label)
+         for metric in ["precision", "recall", "f1-score"]:
+             metric_name = f"{metric.title()}_{label_str}"
+             metrics.append(
+                 {
+                     "Metric": metric_name,
+                     "Reference": report_ref[label_str][metric],
+                     "Monitoring": report_mon[label_str][metric],
+                 }
+             )
+
+     # Add macro averages
+     for metric in ["precision", "recall", "f1-score"]:
+         metric_name = f"{metric.title()}_Macro"
+         metrics.append(
+             {
+                 "Metric": metric_name,
+                 "Reference": report_ref["macro avg"][metric],
+                 "Monitoring": report_mon["macro avg"][metric],
+             }
+         )
+
+     # Create DataFrame
+     df = pd.DataFrame(metrics)
+
+     # Calculate drift percentage with direction
+     df["Drift (%)"] = (
+         (df["Monitoring"] - df["Reference"]) / df["Reference"].abs() * 100
+     ).round(2)
+
+     # Add Pass/Fail column based on absolute drift
+     df["Pass/Fail"] = (
+         df["Drift (%)"]
+         .abs()
+         .apply(lambda x: "Pass" if x < drift_pct_threshold else "Fail")
+     )
+
+     # Calculate overall pass/fail
+     pass_fail_bool = (df["Pass/Fail"] == "Pass").all()
+
+     return ({"Classification Accuracy Metrics": df}, pass_fail_bool)
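Editor's note: as with the calibration test, a minimal hypothetical sketch of calling this test directly, again assuming pre-initialized `ref_ds`, `mon_ds`, and `vm_model` objects that are not part of this diff.

```python
# Hypothetical usage sketch -- `ref_ds`, `mon_ds` (VMDataset) and `vm_model`
# (VMModel) are assumed to exist and are not defined in this diff.
from validmind.tests.ongoing_monitoring.ClassificationAccuracyDrift import (
    ClassificationAccuracyDrift,
)

# Returns the metric comparison table and an overall pass/fail flag.
tables, passed = ClassificationAccuracyDrift(
    datasets=[ref_ds, mon_ds],  # [reference, monitoring]
    model=vm_model,
    drift_pct_threshold=20,     # a metric fails if |relative drift| >= 20% (default)
)

print(tables["Classification Accuracy Metrics"])
print("Accuracy stable:", passed)
```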