validmind 2.7.5__py3-none-any.whl → 2.7.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- validmind/__version__.py +1 -1
- validmind/datasets/credit_risk/lending_club.py +354 -88
- validmind/tests/data_validation/HighPearsonCorrelation.py +12 -2
- validmind/tests/ongoing_monitoring/CalibrationCurveDrift.py +218 -0
- validmind/tests/ongoing_monitoring/ClassDiscriminationDrift.py +153 -0
- validmind/tests/ongoing_monitoring/ClassImbalanceDrift.py +144 -0
- validmind/tests/ongoing_monitoring/ClassificationAccuracyDrift.py +146 -0
- validmind/tests/ongoing_monitoring/ConfusionMatrixDrift.py +191 -0
- validmind/tests/ongoing_monitoring/CumulativePredictionProbabilitiesDrift.py +176 -0
- validmind/tests/ongoing_monitoring/FeatureDrift.py +120 -121
- validmind/tests/ongoing_monitoring/PredictionAcrossEachFeature.py +18 -23
- validmind/tests/ongoing_monitoring/PredictionCorrelation.py +86 -45
- validmind/tests/ongoing_monitoring/PredictionProbabilitiesHistogramDrift.py +202 -0
- validmind/tests/ongoing_monitoring/PredictionQuantilesAcrossFeatures.py +97 -0
- validmind/tests/ongoing_monitoring/ROCCurveDrift.py +149 -0
- validmind/tests/ongoing_monitoring/ScoreBandsDrift.py +210 -0
- validmind/tests/ongoing_monitoring/ScorecardHistogramDrift.py +207 -0
- validmind/tests/ongoing_monitoring/TargetPredictionDistributionPlot.py +91 -14
- validmind/vm_models/dataset/dataset.py +0 -4
- {validmind-2.7.5.dist-info → validmind-2.7.6.dist-info}/METADATA +2 -2
- {validmind-2.7.5.dist-info → validmind-2.7.6.dist-info}/RECORD +24 -13
- {validmind-2.7.5.dist-info → validmind-2.7.6.dist-info}/LICENSE +0 -0
- {validmind-2.7.5.dist-info → validmind-2.7.6.dist-info}/WHEEL +0 -0
- {validmind-2.7.5.dist-info → validmind-2.7.6.dist-info}/entry_points.txt +0 -0
validmind/tests/ongoing_monitoring/PredictionCorrelation.py

@@ -2,16 +2,14 @@
 # See the LICENSE file in the root of this repository for details.
 # SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
 
-
-import
-import numpy as np
-
+import pandas as pd
+import plotly.graph_objects as go
 from validmind import tags, tasks
 
 
 @tags("visualization")
 @tasks("monitoring")
-def PredictionCorrelation(datasets, model):
+def PredictionCorrelation(datasets, model, drift_pct_threshold=20):
     """
     Assesses correlation changes between model predictions from reference and monitoring datasets to detect potential
     target drift.

@@ -47,55 +45,98 @@ def PredictionCorrelation(datasets, model):
     - Focuses solely on linear relationships, potentially missing non-linear interactions.
     """
 
-
-
+    # Get feature columns and predictions
+    feature_columns = datasets[0].feature_columns
+    y_prob_ref = pd.Series(datasets[0].y_prob(model), index=datasets[0].df.index)
+    y_prob_mon = pd.Series(datasets[1].y_prob(model), index=datasets[1].df.index)
 
-
-
+    # Create dataframes with features and predictions
+    df_ref = datasets[0].df[feature_columns].copy()
+    df_ref["predictions"] = y_prob_ref
 
-
-
+    df_mon = datasets[1].df[feature_columns].copy()
+    df_mon["predictions"] = y_prob_mon
 
-
-
-
+    # Calculate correlations
+    corr_ref = df_ref.corr()["predictions"]
+    corr_mon = df_mon.corr()["predictions"]
 
-
-
-
+    # Combine correlations (excluding the predictions row)
+    corr_final = pd.DataFrame(
+        {
+            "Reference Predictions": corr_ref[feature_columns],
+            "Monitoring Predictions": corr_mon[feature_columns],
+        }
+    )
 
-
+    # Calculate drift percentage with direction
+    corr_final["Drift (%)"] = (
+        (corr_final["Monitoring Predictions"] - corr_final["Reference Predictions"])
+        / corr_final["Reference Predictions"].abs()
+        * 100
+    ).round(2)
+
+    # Add Pass/Fail column based on absolute drift
+    corr_final["Pass/Fail"] = (
+        corr_final["Drift (%)"]
+        .abs()
+        .apply(lambda x: "Pass" if x < drift_pct_threshold else "Fail")
+    )
 
-
-
-
-
-
-
-
+    # Create plotly figure
+    fig = go.Figure()
+
+    # Add reference predictions bar
+    fig.add_trace(
+        go.Bar(
+            name="Reference Prediction Correlation",
+            x=corr_final.index,
+            y=corr_final["Reference Predictions"],
+            marker_color="blue",
+            marker_line_color="black",
+            marker_line_width=1,
+            opacity=0.75,
+        )
     )
-
-
-
-
-
-
-
+
+    # Add monitoring predictions bar
+    fig.add_trace(
+        go.Bar(
+            name="Monitoring Prediction Correlation",
+            x=corr_final.index,
+            y=corr_final["Monitoring Predictions"],
+            marker_color="green",
+            marker_line_color="black",
+            marker_line_width=1,
+            opacity=0.75,
+        )
     )
 
-
-
-
+    # Update layout
+    fig.update_layout(
+        title="Correlation between Predictions and Features",
+        xaxis_title="Features",
+        yaxis_title="Correlation",
+        barmode="group",
+        template="plotly_white",
+        showlegend=True,
+        xaxis_tickangle=-45,
+        yaxis=dict(
+            range=[-1, 1],  # Correlation range is always -1 to 1
+            zeroline=True,
+            zerolinewidth=1,
+            zerolinecolor="grey",
+            gridcolor="lightgrey",
+        ),
+        hoverlabel=dict(bgcolor="white", font_size=12, font_family="Arial"),
+    )
 
-
-
-
-
+    # Ensure Features is the first column
+    corr_final["Feature"] = corr_final.index
+    cols = ["Feature"] + [col for col in corr_final.columns if col != "Feature"]
+    corr_final = corr_final[cols]
 
-
+    # Calculate overall pass/fail
+    pass_fail_bool = (corr_final["Pass/Fail"] == "Pass").all()
 
-
-    corr_final = corr_final[
-        ["Features", "Reference Predictions", "Monitoring Predictions"]
-    ]
-    return ({"Correlation Pair Table": corr_final}, fig)
+    return ({"Correlation Pair Table": corr_final}, fig, pass_fail_bool)
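For orientation on the new Pass/Fail logic above: the updated PredictionCorrelation correlates each feature with the predicted probabilities in both datasets and flags any feature whose correlation shifts by more than drift_pct_threshold percent. A minimal standalone sketch of that calculation follows, using synthetic pandas data rather than the validmind dataset/model objects (column names and coefficients are illustrative assumptions):

import numpy as np
import pandas as pd

rng = np.random.default_rng(0)
ref = pd.DataFrame({"age": rng.normal(size=500), "income": rng.normal(size=500)})
ref["predictions"] = 0.6 * ref["age"] + rng.normal(scale=0.5, size=500)
mon = ref[["age", "income"]].copy()
mon["predictions"] = 0.3 * mon["age"] + rng.normal(scale=0.5, size=500)  # weaker relationship in "production"

corr_ref = ref.corr()["predictions"]
corr_mon = mon.corr()["predictions"]
# Signed drift per feature, relative to the reference correlation magnitude
drift_pct = ((corr_mon - corr_ref) / corr_ref.abs() * 100).round(2).drop("predictions")
print(drift_pct)
print(drift_pct.abs().apply(lambda x: "Pass" if x < 20 else "Fail"))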
validmind/tests/ongoing_monitoring/PredictionProbabilitiesHistogramDrift.py (new file)

@@ -0,0 +1,202 @@
+# Copyright © 2023-2024 ValidMind Inc. All rights reserved.
+# See the LICENSE file in the root of this repository for details.
+# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
+
+import numpy as np
+import pandas as pd
+import plotly.graph_objects as go
+from plotly.subplots import make_subplots
+from scipy import stats
+from typing import List
+from validmind import tags, tasks
+from validmind.vm_models import VMDataset, VMModel
+
+
+@tags("visualization", "credit_risk")
+@tasks("classification")
+def PredictionProbabilitiesHistogramDrift(
+    datasets: List[VMDataset],
+    model: VMModel,
+    title="Prediction Probabilities Histogram Drift",
+    drift_pct_threshold: float = 20.0,
+):
+    """
+    Compares prediction probability distributions between reference and monitoring datasets.
+
+    ### Purpose
+
+    The Prediction Probabilities Histogram Drift test is designed to evaluate changes in the model's
+    probability predictions over time. By comparing probability distributions between reference and
+    monitoring datasets using histograms, this test helps identify whether the model's probability
+    assignments have shifted in production. This is crucial for understanding if the model's risk
+    assessment behavior remains consistent and whether its probability estimates maintain their
+    original distribution patterns.
+
+    ### Test Mechanism
+
+    This test proceeds by generating histograms of prediction probabilities for both reference and
+    monitoring datasets. For each class, it analyzes the distribution shape, central tendency, and
+    spread of probabilities. The test computes distribution moments (mean, variance, skewness,
+    kurtosis) and quantifies their drift between datasets. Visual comparison of overlaid histograms
+    provides immediate insight into distribution changes.
+
+    ### Signs of High Risk
+
+    - Significant shifts in probability distribution shapes
+    - Large drifts in distribution moments exceeding threshold
+    - Appearance of new modes or peaks in monitoring data
+    - Changes in the spread or concentration of probabilities
+    - Systematic shifts in probability assignments
+    - Unexpected changes in distribution characteristics
+
+    ### Strengths
+
+    - Provides intuitive visualization of probability changes
+    - Identifies specific changes in distribution shape
+    - Enables quantitative assessment of distribution drift
+    - Supports analysis across multiple classes
+    - Includes comprehensive moment analysis
+    - Maintains interpretable probability scale
+
+    ### Limitations
+
+    - May be sensitive to binning choices
+    - Requires sufficient samples for reliable histograms
+    - Cannot suggest probability recalibration
+    - Complex interpretation for multiple classes
+    - May not capture subtle distribution changes
+    - Limited to univariate probability analysis
+    """
+    # Get predictions and true values
+    y_prob_ref = datasets[0].y_prob(model)
+    df_ref = datasets[0].df.copy()
+    df_ref["probabilities"] = y_prob_ref
+
+    y_prob_mon = datasets[1].y_prob(model)
+    df_mon = datasets[1].df.copy()
+    df_mon["probabilities"] = y_prob_mon
+
+    # Get unique classes
+    classes = sorted(df_ref[datasets[0].target_column].unique())
+
+    # Create subplots with more horizontal space for legends
+    fig = make_subplots(
+        rows=len(classes),
+        cols=1,
+        subplot_titles=[f"Class {cls}" for cls in classes],
+        horizontal_spacing=0.15,
+    )
+
+    # Define colors
+    ref_color = "rgba(31, 119, 180, 0.8)"  # Blue with 0.8 opacity
+    mon_color = "rgba(255, 127, 14, 0.8)"  # Orange with 0.8 opacity
+
+    # Dictionary to store tables for each class
+    tables = {}
+    all_passed = True  # Track overall pass/fail
+
+    # Add histograms and create tables for each class
+    for i, class_value in enumerate(classes, start=1):
+        # Get probabilities for current class
+        ref_probs = df_ref[df_ref[datasets[0].target_column] == class_value][
+            "probabilities"
+        ]
+        mon_probs = df_mon[df_mon[datasets[1].target_column] == class_value][
+            "probabilities"
+        ]
+
+        # Calculate distribution moments
+        ref_stats = {
+            "Mean": np.mean(ref_probs),
+            "Variance": np.var(ref_probs),
+            "Skewness": stats.skew(ref_probs),
+            "Kurtosis": stats.kurtosis(ref_probs),
+        }
+
+        mon_stats = {
+            "Mean": np.mean(mon_probs),
+            "Variance": np.var(mon_probs),
+            "Skewness": stats.skew(mon_probs),
+            "Kurtosis": stats.kurtosis(mon_probs),
+        }
+
+        # Create table for this class
+        table_data = []
+        class_passed = True  # Track pass/fail for this class
+
+        for stat_name in ["Mean", "Variance", "Skewness", "Kurtosis"]:
+            ref_val = ref_stats[stat_name]
+            mon_val = mon_stats[stat_name]
+            drift = (
+                ((mon_val - ref_val) / abs(ref_val)) * 100 if ref_val != 0 else np.inf
+            )
+            passed = abs(drift) < drift_pct_threshold
+            class_passed &= passed  # Update class pass/fail
+
+            table_data.append(
+                {
+                    "Statistic": stat_name,
+                    "Reference": round(ref_val, 4),
+                    "Monitoring": round(mon_val, 4),
+                    "Drift (%)": round(drift, 2),
+                    "Pass/Fail": "Pass" if passed else "Fail",
+                }
+            )
+
+        tables[f"Class {class_value}"] = pd.DataFrame(table_data)
+        all_passed &= class_passed  # Update overall pass/fail
+
+        # Reference dataset histogram
+        fig.add_trace(
+            go.Histogram(
+                x=ref_probs,
+                name=f"Reference - Class {class_value}",
+                marker_color=ref_color,
+                showlegend=True,
+                legendrank=i * 2 - 1,
+            ),
+            row=i,
+            col=1,
+        )
+
+        # Monitoring dataset histogram
+        fig.add_trace(
+            go.Histogram(
+                x=mon_probs,
+                name=f"Monitoring - Class {class_value}",
+                marker_color=mon_color,
+                showlegend=True,
+                legendrank=i * 2,
+            ),
+            row=i,
+            col=1,
+        )
+
+    # Update layout
+    fig.update_layout(
+        title_text=title,
+        barmode="overlay",
+        height=300 * len(classes),
+        width=1000,
+        showlegend=True,
+    )
+
+    # Update axes labels and add separate legends for each subplot
+    for i in range(len(classes)):
+        fig.update_xaxes(title_text="Probability", row=i + 1, col=1)
+        fig.update_yaxes(title_text="Frequency", row=i + 1, col=1)
+
+        # Add separate legend for each subplot
+        fig.update_layout(
+            **{
+                f'legend{i+1 if i > 0 else ""}': dict(
+                    yanchor="middle",
+                    y=1 - (i / len(classes)) - (0.5 / len(classes)),
+                    xanchor="left",
+                    x=1.05,
+                    tracegroupgap=5,
+                )
+            }
+        )
+
+    return fig, tables, all_passed
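The core of the new test above is the per-class moment comparison: each statistic's relative change is compared against drift_pct_threshold. A minimal sketch of that check, assuming synthetic beta-distributed probabilities in place of real model outputs (the helper name moment_drift is hypothetical, not part of the validmind API):

import numpy as np
from scipy import stats

def moment_drift(ref_probs, mon_probs, drift_pct_threshold=20.0):
    # Per-statistic drift check mirroring the logic added above (illustrative only)
    rows = []
    for name, fn in [("Mean", np.mean), ("Variance", np.var),
                     ("Skewness", stats.skew), ("Kurtosis", stats.kurtosis)]:
        ref_val, mon_val = fn(ref_probs), fn(mon_probs)
        drift = ((mon_val - ref_val) / abs(ref_val)) * 100 if ref_val != 0 else np.inf
        rows.append((name, round(drift, 2), "Pass" if abs(drift) < drift_pct_threshold else "Fail"))
    return rows

rng = np.random.default_rng(1)
print(moment_drift(rng.beta(2, 5, size=1000), rng.beta(2, 3, size=1000)))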
validmind/tests/ongoing_monitoring/PredictionQuantilesAcrossFeatures.py (new file)

@@ -0,0 +1,97 @@
+# Copyright © 2023-2024 ValidMind Inc. All rights reserved.
+# See the LICENSE file in the root of this repository for details.
+# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
+
+import plotly.graph_objects as go
+from plotly.subplots import make_subplots
+from validmind import tags, tasks
+
+
+@tags("visualization")
+@tasks("monitoring")
+def PredictionQuantilesAcrossFeatures(datasets, model):
+    """
+    Assesses differences in model prediction distributions across individual features between reference
+    and monitoring datasets through quantile analysis.
+
+    ### Purpose
+
+    This test aims to visualize how prediction distributions vary across feature values by showing
+    quantile information between reference and monitoring datasets. It helps identify significant
+    shifts in prediction patterns and potential areas of model instability.
+
+    ### Test Mechanism
+
+    The test generates box plots for each feature, comparing prediction probability distributions
+    between the reference and monitoring datasets. Each plot consists of two subplots showing the
+    quantile distribution of predictions: one for reference data and one for monitoring data.
+
+    ### Signs of High Risk
+
+    - Significant differences in prediction distributions between reference and monitoring data
+    - Unexpected shifts in prediction quantiles across feature values
+    - Large changes in prediction variability between datasets
+
+    ### Strengths
+
+    - Provides clear visualization of prediction distribution changes
+    - Shows outliers and variability in predictions across features
+    - Enables quick identification of problematic feature ranges
+
+    ### Limitations
+
+    - May not capture complex relationships between features and predictions
+    - Quantile analysis may smooth over important individual predictions
+    - Requires careful interpretation of distribution changes
+    """
+
+    feature_columns = datasets[0].feature_columns
+    y_prob_reference = datasets[0].y_prob(model)
+    y_prob_monitoring = datasets[1].y_prob(model)
+
+    figures_to_save = []
+    for column in feature_columns:
+        # Create subplot
+        fig = make_subplots(1, 2, subplot_titles=("Reference", "Monitoring"))
+
+        # Add reference box plot
+        fig.add_trace(
+            go.Box(
+                x=datasets[0].df[column],
+                y=y_prob_reference,
+                name="Reference",
+                boxpoints="outliers",
+                marker_color="blue",
+            ),
+            row=1,
+            col=1,
+        )
+
+        # Add monitoring box plot
+        fig.add_trace(
+            go.Box(
+                x=datasets[1].df[column],
+                y=y_prob_monitoring,
+                name="Monitoring",
+                boxpoints="outliers",
+                marker_color="red",
+            ),
+            row=1,
+            col=2,
+        )
+
+        # Update layout
+        fig.update_layout(
+            title=f"Prediction Distributions vs {column}",
+            showlegend=False,
+            width=800,
+            height=400,
+        )
+
+        # Update axes
+        fig.update_xaxes(title=column)
+        fig.update_yaxes(title="Prediction Value")
+
+        figures_to_save.append(fig)
+
+    return tuple(figures_to_save)
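The box plots above summarize prediction quantiles per feature value. The same comparison can be done numerically; here is a rough sketch of a per-feature-value quantile shift using plain pandas, with synthetic data and hypothetical column names ("grade", "y_prob") standing in for a real feature and the model's probabilities:

import numpy as np
import pandas as pd

rng = np.random.default_rng(2)
ref = pd.DataFrame({"grade": rng.integers(0, 4, size=1000), "y_prob": rng.beta(2, 5, size=1000)})
mon = pd.DataFrame({"grade": rng.integers(0, 4, size=1000), "y_prob": rng.beta(2, 4, size=1000)})

quantiles = [0.25, 0.5, 0.75]
ref_q = ref.groupby("grade")["y_prob"].quantile(quantiles).unstack()
mon_q = mon.groupby("grade")["y_prob"].quantile(quantiles).unstack()
print((mon_q - ref_q).round(3))  # shift in prediction quantiles for each feature value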
validmind/tests/ongoing_monitoring/ROCCurveDrift.py (new file)

@@ -0,0 +1,149 @@
+# Copyright © 2023-2024 ValidMind Inc. All rights reserved.
+# See the LICENSE file in the root of this repository for details.
+# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
+
+import numpy as np
+import plotly.graph_objects as go
+from sklearn.metrics import roc_auc_score, roc_curve
+from validmind import tags, tasks
+from validmind.errors import SkipTestError
+from validmind.vm_models import VMDataset, VMModel
+
+from typing import List
+
+
+@tags(
+    "sklearn",
+    "binary_classification",
+    "model_performance",
+    "visualization",
+)
+@tasks("classification", "text_classification")
+def ROCCurveDrift(datasets: List[VMDataset], model: VMModel):
+    """
+    Compares ROC curves between reference and monitoring datasets.
+
+    ### Purpose
+
+    The ROC Curve Drift test is designed to evaluate changes in the model's discriminative ability
+    over time. By comparing Receiver Operating Characteristic (ROC) curves between reference and
+    monitoring datasets, this test helps identify whether the model maintains its ability to
+    distinguish between classes across different decision thresholds. This is crucial for
+    understanding if the model's trade-off between sensitivity and specificity remains stable
+    in production.
+
+    ### Test Mechanism
+
+    This test proceeds by generating ROC curves for both reference and monitoring datasets. For each
+    dataset, it plots the True Positive Rate against the False Positive Rate across all possible
+    classification thresholds. The test also computes AUC scores and visualizes the difference
+    between ROC curves, providing both graphical and numerical assessments of discrimination
+    stability. Special attention is paid to regions where curves diverge significantly.
+
+    ### Signs of High Risk
+
+    - Large differences between reference and monitoring ROC curves
+    - Significant drop in AUC score for monitoring dataset
+    - Systematic differences in specific FPR regions
+    - Changes in optimal operating points
+    - Inconsistent performance across different thresholds
+    - Unexpected crossovers between curves
+
+    ### Strengths
+
+    - Provides comprehensive view of discriminative ability
+    - Identifies specific threshold ranges with drift
+    - Enables visualization of performance differences
+    - Includes AUC comparison for overall assessment
+    - Supports threshold-independent evaluation
+    - Maintains interpretable performance metrics
+
+    ### Limitations
+
+    - Limited to binary classification problems
+    - May be sensitive to class distribution changes
+    - Cannot suggest optimal threshold adjustments
+    - Requires visual inspection for detailed analysis
+    - Complex interpretation of curve differences
+    - May not capture subtle performance changes
+    """
+    # Check for binary classification
+    if len(np.unique(datasets[0].y)) > 2:
+        raise SkipTestError(
+            "ROC Curve Drift is only supported for binary classification models"
+        )
+
+    # Calculate ROC curves for reference dataset
+    y_prob_ref = datasets[0].y_prob(model)
+    y_true_ref = datasets[0].y.astype(y_prob_ref.dtype).flatten()
+    fpr_ref, tpr_ref, _ = roc_curve(y_true_ref, y_prob_ref, drop_intermediate=False)
+    auc_ref = roc_auc_score(y_true_ref, y_prob_ref)
+
+    # Calculate ROC curves for monitoring dataset
+    y_prob_mon = datasets[1].y_prob(model)
+    y_true_mon = datasets[1].y.astype(y_prob_mon.dtype).flatten()
+    fpr_mon, tpr_mon, _ = roc_curve(y_true_mon, y_prob_mon, drop_intermediate=False)
+    auc_mon = roc_auc_score(y_true_mon, y_prob_mon)
+
+    # Create superimposed ROC curves plot
+    fig1 = go.Figure()
+
+    fig1.add_trace(
+        go.Scatter(
+            x=fpr_ref,
+            y=tpr_ref,
+            mode="lines",
+            name=f"Reference (AUC = {auc_ref:.3f})",
+            line=dict(color="blue", width=2),
+        )
+    )
+
+    fig1.add_trace(
+        go.Scatter(
+            x=fpr_mon,
+            y=tpr_mon,
+            mode="lines",
+            name=f"Monitoring (AUC = {auc_mon:.3f})",
+            line=dict(color="red", width=2),
+        )
+    )
+
+    fig1.update_layout(
+        title="ROC Curves Comparison",
+        xaxis=dict(title="False Positive Rate"),
+        yaxis=dict(title="True Positive Rate"),
+        width=700,
+        height=500,
+    )
+
+    # Interpolate monitoring TPR to match reference FPR points
+    tpr_mon_interp = np.interp(fpr_ref, fpr_mon, tpr_mon)
+
+    # Calculate TPR difference
+    tpr_diff = tpr_mon_interp - tpr_ref
+
+    # Create difference plot
+    fig2 = go.Figure()
+
+    fig2.add_trace(
+        go.Scatter(
+            x=fpr_ref,
+            y=tpr_diff,
+            mode="lines",
+            name="TPR Difference",
+            line=dict(color="purple", width=2),
+        )
+    )
+
+    # Add horizontal line at y=0
+    fig2.add_hline(y=0, line=dict(color="grey", dash="dash"), name="No Difference")
+
+    fig2.update_layout(
+        title="ROC Curve Difference (Monitoring - Reference)",
+        xaxis=dict(title="False Positive Rate"),
+        yaxis=dict(title="TPR Difference"),
+        width=700,
+        height=500,
+    )
+
+    return fig1, fig2
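The key numerical step in ROCCurveDrift is interpolating the monitoring TPR onto the reference FPR grid before taking the difference. A minimal sketch of that step with synthetic labels and scores (the label/score generation below is purely illustrative, not part of the package):

import numpy as np
from sklearn.metrics import roc_auc_score, roc_curve

rng = np.random.default_rng(3)
y_ref = rng.integers(0, 2, size=2000)
y_mon = rng.integers(0, 2, size=2000)
p_ref = np.clip(0.3 * y_ref + rng.normal(0.4, 0.2, size=2000), 0, 1)
p_mon = np.clip(0.2 * y_mon + rng.normal(0.45, 0.2, size=2000), 0, 1)  # weaker class separation

fpr_ref, tpr_ref, _ = roc_curve(y_ref, p_ref, drop_intermediate=False)
fpr_mon, tpr_mon, _ = roc_curve(y_mon, p_mon, drop_intermediate=False)
tpr_diff = np.interp(fpr_ref, fpr_mon, tpr_mon) - tpr_ref  # monitoring minus reference
print(f"AUC ref={roc_auc_score(y_ref, p_ref):.3f}, AUC mon={roc_auc_score(y_mon, p_mon):.3f}, "
      f"max |TPR diff|={np.abs(tpr_diff).max():.3f}")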