validmind 2.7.5__py3-none-any.whl → 2.7.7__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. The information is provided for informational purposes only.
- validmind/__init__.py +2 -0
- validmind/__version__.py +1 -1
- validmind/api_client.py +8 -1
- validmind/datasets/credit_risk/lending_club.py +352 -87
- validmind/html_templates/content_blocks.py +1 -1
- validmind/tests/__types__.py +17 -0
- validmind/tests/data_validation/ACFandPACFPlot.py +6 -2
- validmind/tests/data_validation/AutoMA.py +2 -2
- validmind/tests/data_validation/BivariateScatterPlots.py +4 -2
- validmind/tests/data_validation/BoxPierce.py +2 -2
- validmind/tests/data_validation/ClassImbalance.py +2 -1
- validmind/tests/data_validation/DatasetDescription.py +11 -2
- validmind/tests/data_validation/DatasetSplit.py +2 -2
- validmind/tests/data_validation/DickeyFullerGLS.py +2 -2
- validmind/tests/data_validation/FeatureTargetCorrelationPlot.py +8 -2
- validmind/tests/data_validation/HighCardinality.py +9 -2
- validmind/tests/data_validation/HighPearsonCorrelation.py +18 -4
- validmind/tests/data_validation/IQROutliersBarPlot.py +9 -2
- validmind/tests/data_validation/LaggedCorrelationHeatmap.py +2 -2
- validmind/tests/data_validation/MissingValuesBarPlot.py +12 -9
- validmind/tests/data_validation/MutualInformation.py +6 -8
- validmind/tests/data_validation/PearsonCorrelationMatrix.py +2 -2
- validmind/tests/data_validation/ProtectedClassesCombination.py +6 -1
- validmind/tests/data_validation/ProtectedClassesDescription.py +1 -1
- validmind/tests/data_validation/ProtectedClassesDisparity.py +4 -5
- validmind/tests/data_validation/ProtectedClassesThresholdOptimizer.py +1 -4
- validmind/tests/data_validation/RollingStatsPlot.py +21 -10
- validmind/tests/data_validation/ScatterPlot.py +3 -5
- validmind/tests/data_validation/ScoreBandDefaultRates.py +2 -1
- validmind/tests/data_validation/SeasonalDecompose.py +12 -2
- validmind/tests/data_validation/Skewness.py +6 -3
- validmind/tests/data_validation/SpreadPlot.py +8 -3
- validmind/tests/data_validation/TabularCategoricalBarPlots.py +4 -2
- validmind/tests/data_validation/TabularDateTimeHistograms.py +2 -2
- validmind/tests/data_validation/TargetRateBarPlots.py +4 -3
- validmind/tests/data_validation/TimeSeriesFrequency.py +7 -2
- validmind/tests/data_validation/TimeSeriesMissingValues.py +14 -10
- validmind/tests/data_validation/TimeSeriesOutliers.py +1 -5
- validmind/tests/data_validation/WOEBinPlots.py +2 -2
- validmind/tests/data_validation/WOEBinTable.py +11 -9
- validmind/tests/data_validation/nlp/CommonWords.py +2 -2
- validmind/tests/data_validation/nlp/Hashtags.py +2 -2
- validmind/tests/data_validation/nlp/LanguageDetection.py +9 -6
- validmind/tests/data_validation/nlp/Mentions.py +9 -6
- validmind/tests/data_validation/nlp/PolarityAndSubjectivity.py +2 -2
- validmind/tests/data_validation/nlp/Punctuations.py +4 -2
- validmind/tests/data_validation/nlp/Sentiment.py +2 -2
- validmind/tests/data_validation/nlp/StopWords.py +5 -4
- validmind/tests/data_validation/nlp/TextDescription.py +2 -2
- validmind/tests/data_validation/nlp/Toxicity.py +2 -2
- validmind/tests/model_validation/BertScore.py +2 -2
- validmind/tests/model_validation/BleuScore.py +2 -2
- validmind/tests/model_validation/ClusterSizeDistribution.py +2 -2
- validmind/tests/model_validation/ContextualRecall.py +2 -2
- validmind/tests/model_validation/FeaturesAUC.py +2 -2
- validmind/tests/model_validation/MeteorScore.py +2 -2
- validmind/tests/model_validation/ModelPredictionResiduals.py +2 -2
- validmind/tests/model_validation/RegardScore.py +6 -2
- validmind/tests/model_validation/RegressionResidualsPlot.py +4 -3
- validmind/tests/model_validation/RougeScore.py +6 -5
- validmind/tests/model_validation/TimeSeriesPredictionWithCI.py +11 -2
- validmind/tests/model_validation/TokenDisparity.py +2 -2
- validmind/tests/model_validation/ToxicityScore.py +10 -2
- validmind/tests/model_validation/embeddings/ClusterDistribution.py +9 -3
- validmind/tests/model_validation/embeddings/CosineSimilarityComparison.py +16 -2
- validmind/tests/model_validation/embeddings/CosineSimilarityDistribution.py +5 -3
- validmind/tests/model_validation/embeddings/CosineSimilarityHeatmap.py +2 -2
- validmind/tests/model_validation/embeddings/DescriptiveAnalytics.py +14 -4
- validmind/tests/model_validation/embeddings/EmbeddingsVisualization2D.py +2 -2
- validmind/tests/model_validation/embeddings/EuclideanDistanceComparison.py +16 -2
- validmind/tests/model_validation/embeddings/EuclideanDistanceHeatmap.py +2 -2
- validmind/tests/model_validation/embeddings/PCAComponentsPairwisePlots.py +4 -5
- validmind/tests/model_validation/embeddings/StabilityAnalysisKeyword.py +4 -2
- validmind/tests/model_validation/embeddings/StabilityAnalysisRandomNoise.py +4 -2
- validmind/tests/model_validation/embeddings/StabilityAnalysisSynonyms.py +4 -2
- validmind/tests/model_validation/embeddings/StabilityAnalysisTranslation.py +4 -2
- validmind/tests/model_validation/embeddings/TSNEComponentsPairwisePlots.py +8 -6
- validmind/tests/model_validation/embeddings/utils.py +11 -1
- validmind/tests/model_validation/ragas/AnswerCorrectness.py +2 -1
- validmind/tests/model_validation/ragas/AspectCritic.py +11 -7
- validmind/tests/model_validation/ragas/ContextEntityRecall.py +2 -1
- validmind/tests/model_validation/ragas/ContextPrecision.py +2 -1
- validmind/tests/model_validation/ragas/ContextPrecisionWithoutReference.py +2 -1
- validmind/tests/model_validation/ragas/ContextRecall.py +2 -1
- validmind/tests/model_validation/ragas/Faithfulness.py +2 -1
- validmind/tests/model_validation/ragas/NoiseSensitivity.py +2 -1
- validmind/tests/model_validation/ragas/ResponseRelevancy.py +2 -1
- validmind/tests/model_validation/ragas/SemanticSimilarity.py +2 -1
- validmind/tests/model_validation/sklearn/CalibrationCurve.py +3 -2
- validmind/tests/model_validation/sklearn/ClassifierThresholdOptimization.py +2 -5
- validmind/tests/model_validation/sklearn/ClusterCosineSimilarity.py +5 -2
- validmind/tests/model_validation/sklearn/ConfusionMatrix.py +2 -2
- validmind/tests/model_validation/sklearn/FeatureImportance.py +1 -14
- validmind/tests/model_validation/sklearn/HyperParametersTuning.py +6 -3
- validmind/tests/model_validation/sklearn/KMeansClustersOptimization.py +2 -2
- validmind/tests/model_validation/sklearn/MinimumROCAUCScore.py +8 -4
- validmind/tests/model_validation/sklearn/ModelParameters.py +1 -0
- validmind/tests/model_validation/sklearn/OverfitDiagnosis.py +3 -3
- validmind/tests/model_validation/sklearn/PermutationFeatureImportance.py +2 -2
- validmind/tests/model_validation/sklearn/PopulationStabilityIndex.py +20 -16
- validmind/tests/model_validation/sklearn/PrecisionRecallCurve.py +4 -2
- validmind/tests/model_validation/sklearn/ROCCurve.py +1 -1
- validmind/tests/model_validation/sklearn/RegressionR2Square.py +7 -9
- validmind/tests/model_validation/sklearn/RegressionR2SquareComparison.py +1 -3
- validmind/tests/model_validation/sklearn/SHAPGlobalImportance.py +2 -1
- validmind/tests/model_validation/sklearn/ScoreProbabilityAlignment.py +2 -1
- validmind/tests/model_validation/sklearn/SilhouettePlot.py +5 -3
- validmind/tests/model_validation/sklearn/TrainingTestDegradation.py +9 -1
- validmind/tests/model_validation/sklearn/WeakspotsDiagnosis.py +1 -1
- validmind/tests/model_validation/statsmodels/CumulativePredictionProbabilities.py +11 -4
- validmind/tests/model_validation/statsmodels/DurbinWatsonTest.py +1 -3
- validmind/tests/model_validation/statsmodels/GINITable.py +7 -15
- validmind/tests/model_validation/statsmodels/Lilliefors.py +2 -2
- validmind/tests/model_validation/statsmodels/RegressionCoeffs.py +1 -1
- validmind/tests/model_validation/statsmodels/RegressionFeatureSignificance.py +2 -2
- validmind/tests/model_validation/statsmodels/RegressionModelForecastPlotLevels.py +5 -2
- validmind/tests/model_validation/statsmodels/RegressionModelSensitivityPlot.py +5 -2
- validmind/tests/model_validation/statsmodels/RegressionModelSummary.py +7 -7
- validmind/tests/model_validation/statsmodels/RegressionPermutationFeatureImportance.py +2 -2
- validmind/tests/ongoing_monitoring/CalibrationCurveDrift.py +220 -0
- validmind/tests/ongoing_monitoring/ClassDiscriminationDrift.py +155 -0
- validmind/tests/ongoing_monitoring/ClassImbalanceDrift.py +146 -0
- validmind/tests/ongoing_monitoring/ClassificationAccuracyDrift.py +148 -0
- validmind/tests/ongoing_monitoring/ConfusionMatrixDrift.py +193 -0
- validmind/tests/ongoing_monitoring/CumulativePredictionProbabilitiesDrift.py +178 -0
- validmind/tests/ongoing_monitoring/FeatureDrift.py +120 -120
- validmind/tests/ongoing_monitoring/PredictionAcrossEachFeature.py +18 -23
- validmind/tests/ongoing_monitoring/PredictionCorrelation.py +86 -44
- validmind/tests/ongoing_monitoring/PredictionProbabilitiesHistogramDrift.py +204 -0
- validmind/tests/ongoing_monitoring/PredictionQuantilesAcrossFeatures.py +98 -0
- validmind/tests/ongoing_monitoring/ROCCurveDrift.py +150 -0
- validmind/tests/ongoing_monitoring/ScoreBandsDrift.py +212 -0
- validmind/tests/ongoing_monitoring/ScorecardHistogramDrift.py +209 -0
- validmind/tests/ongoing_monitoring/TargetPredictionDistributionPlot.py +91 -13
- validmind/tests/prompt_validation/Bias.py +13 -9
- validmind/tests/prompt_validation/Clarity.py +13 -9
- validmind/tests/prompt_validation/Conciseness.py +13 -9
- validmind/tests/prompt_validation/Delimitation.py +13 -9
- validmind/tests/prompt_validation/NegativeInstruction.py +14 -11
- validmind/tests/prompt_validation/Robustness.py +6 -2
- validmind/tests/prompt_validation/Specificity.py +13 -9
- validmind/tests/run.py +6 -0
- validmind/utils.py +7 -8
- validmind/vm_models/dataset/dataset.py +0 -4
- {validmind-2.7.5.dist-info → validmind-2.7.7.dist-info}/METADATA +2 -3
- {validmind-2.7.5.dist-info → validmind-2.7.7.dist-info}/RECORD +149 -138
- {validmind-2.7.5.dist-info → validmind-2.7.7.dist-info}/WHEEL +1 -1
- {validmind-2.7.5.dist-info → validmind-2.7.7.dist-info}/LICENSE +0 -0
- {validmind-2.7.5.dist-info → validmind-2.7.7.dist-info}/entry_points.txt +0 -0
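The bulk of the new code in this release is the set of ongoing_monitoring tests listed above; selected hunks are reproduced below. As orientation only, here is a minimal sketch of how one of these tests might be invoked through the library's run_test entry point. The test ID prefix, input keys, and parameter name are inferred from the file paths and function signatures shown in this diff, and the dataset/model objects are placeholders (normally created with vm.init_dataset / vm.init_model), so treat the exact names as assumptions rather than documented API.

import validmind as vm

# Assumed test ID, derived from the new file path in this diff; verify against
# the installed package before relying on it.
result = vm.tests.run_test(
    "validmind.ongoing_monitoring.ConfusionMatrixDrift",
    inputs={
        "datasets": [vm_reference_dataset, vm_monitoring_dataset],  # placeholder VMDataset objects
        "model": vm_model,  # placeholder VMModel object
    },
    params={"drift_pct_threshold": 20},  # matches the default in the new test
)
result.log()  # send the result to the ValidMind platform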
validmind/tests/ongoing_monitoring/ConfusionMatrixDrift.py
@@ -0,0 +1,193 @@
+# Copyright © 2023-2024 ValidMind Inc. All rights reserved.
+# See the LICENSE file in the root of this repository for details.
+# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
+
+from typing import List
+
+import numpy as np
+import pandas as pd
+from sklearn.metrics import confusion_matrix
+
+from validmind import tags, tasks
+from validmind.vm_models import VMDataset, VMModel
+
+
+@tags(
+    "sklearn", "binary_classification", "multiclass_classification", "model_performance"
+)
+@tasks("classification", "text_classification")
+def ConfusionMatrixDrift(
+    datasets: List[VMDataset], model: VMModel, drift_pct_threshold=20
+):
+    """
+    Compares confusion matrix metrics between reference and monitoring datasets.
+
+    ### Purpose
+
+    The Confusion Matrix Drift test is designed to evaluate changes in the model's error patterns
+    over time. By comparing confusion matrix elements between reference and monitoring datasets, this
+    test helps identify whether the model maintains consistent prediction behavior in production. This
+    is crucial for understanding if the model's error patterns have shifted and whether specific types
+    of misclassifications have become more prevalent.
+
+    ### Test Mechanism
+
+    This test proceeds by generating confusion matrices for both reference and monitoring datasets.
+    For binary classification, it tracks True Positives, True Negatives, False Positives, and False
+    Negatives as percentages of total predictions. For multiclass problems, it analyzes per-class
+    metrics including true positives and error rates. The test quantifies drift as percentage changes
+    in these metrics between datasets, providing detailed insight into shifting prediction patterns.
+
+    ### Signs of High Risk
+
+    - Large drifts in confusion matrix elements exceeding threshold
+    - Systematic changes in false positive or false negative rates
+    - Inconsistent changes across different classes
+    - Significant shifts in error patterns for specific classes
+    - Unexpected improvements in certain metrics
+    - Divergent trends between different types of errors
+
+    ### Strengths
+
+    - Provides detailed analysis of prediction behavior
+    - Identifies specific types of prediction changes
+    - Enables early detection of systematic errors
+    - Includes comprehensive error pattern analysis
+    - Supports both binary and multiclass problems
+    - Maintains interpretable percentage-based metrics
+
+    ### Limitations
+
+    - May be sensitive to class distribution changes
+    - Cannot identify root causes of prediction drift
+    - Requires sufficient samples for reliable comparison
+    - Limited to hard predictions (not probabilities)
+    - May not capture subtle changes in decision boundaries
+    - Complex interpretation for multiclass problems
+    """
+    # Get predictions and true values for reference dataset
+    y_pred_ref = datasets[0].y_pred(model)
+    y_true_ref = datasets[0].y.astype(y_pred_ref.dtype)
+
+    # Get predictions and true values for monitoring dataset
+    y_pred_mon = datasets[1].y_pred(model)
+    y_true_mon = datasets[1].y.astype(y_pred_mon.dtype)
+
+    # Get unique labels from reference dataset
+    labels = np.unique(y_true_ref)
+    labels = sorted(labels.tolist())
+
+    # Calculate confusion matrices
+    cm_ref = confusion_matrix(y_true_ref, y_pred_ref, labels=labels)
+    cm_mon = confusion_matrix(y_true_mon, y_pred_mon, labels=labels)
+
+    # Get total counts
+    total_ref = len(y_true_ref)
+    total_mon = len(y_true_mon)
+
+    # Create sample counts table
+    counts_data = {
+        "Dataset": ["Reference", "Monitoring"],
+        "Total": [total_ref, total_mon],
+    }
+
+    # Add per-class counts
+    for label in labels:
+        label_str = f"Class_{label}"
+        counts_data[label_str] = [
+            np.sum(y_true_ref == label),
+            np.sum(y_true_mon == label),
+        ]
+
+    counts_df = pd.DataFrame(counts_data)
+
+    # Create confusion matrix metrics
+    metrics = []
+
+    if len(labels) == 2:
+        # Binary classification
+        tn_ref, fp_ref, fn_ref, tp_ref = cm_ref.ravel()
+        tn_mon, fp_mon, fn_mon, tp_mon = cm_mon.ravel()
+
+        confusion_elements = [
+            ("True Negatives (%)", tn_ref / total_ref * 100, tn_mon / total_mon * 100),
+            ("False Positives (%)", fp_ref / total_ref * 100, fp_mon / total_mon * 100),
+            ("False Negatives (%)", fn_ref / total_ref * 100, fn_mon / total_mon * 100),
+            ("True Positives (%)", tp_ref / total_ref * 100, tp_mon / total_mon * 100),
+        ]
+
+        for name, ref_val, mon_val in confusion_elements:
+            metrics.append(
+                {
+                    "Metric": name,
+                    "Reference": round(ref_val, 2),
+                    "Monitoring": round(mon_val, 2),
+                }
+            )
+
+    else:
+        # Multiclass - calculate per-class metrics
+        for i, label in enumerate(labels):
+            # True Positives for this class
+            tp_ref = cm_ref[i, i]
+            tp_mon = cm_mon[i, i]
+
+            # False Positives (sum of column minus TP)
+            fp_ref = cm_ref[:, i].sum() - tp_ref
+            fp_mon = cm_mon[:, i].sum() - tp_mon
+
+            # False Negatives (sum of row minus TP)
+            fn_ref = cm_ref[i, :].sum() - tp_ref
+            fn_mon = cm_mon[i, :].sum() - tp_mon
+
+            class_metrics = [
+                (
+                    f"True Positives_{label} (%)",
+                    tp_ref / total_ref * 100,
+                    tp_mon / total_mon * 100,
+                ),
+                (
+                    f"False Positives_{label} (%)",
+                    fp_ref / total_ref * 100,
+                    fp_mon / total_mon * 100,
+                ),
+                (
+                    f"False Negatives_{label} (%)",
+                    fn_ref / total_ref * 100,
+                    fn_mon / total_mon * 100,
+                ),
+            ]
+
+            for name, ref_val, mon_val in class_metrics:
+                metrics.append(
+                    {
+                        "Metric": name,
+                        "Reference": round(ref_val, 2),
+                        "Monitoring": round(mon_val, 2),
+                    }
+                )
+
+    # Create metrics DataFrame
+    metrics_df = pd.DataFrame(metrics)
+
+    # Calculate drift percentage with direction
+    metrics_df["Drift (%)"] = (
+        (metrics_df["Monitoring"] - metrics_df["Reference"])
+        / metrics_df["Reference"].abs()
+        * 100
+    ).round(2)
+
+    # Add Pass/Fail column based on absolute drift
+    metrics_df["Pass/Fail"] = (
+        metrics_df["Drift (%)"]
+        .abs()
+        .apply(lambda x: "Pass" if x < drift_pct_threshold else "Fail")
+    )
+
+    # Calculate overall pass/fail
+    pass_fail_bool = (metrics_df["Pass/Fail"] == "Pass").all()
+
+    return (
+        {"Confusion Matrix Metrics": metrics_df, "Sample Counts": counts_df},
+        pass_fail_bool,
+    )
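Illustration, not from the package: the drift metric above is a per-cell relative change with a pass/fail cut-off. A minimal standalone sketch of the same arithmetic on made-up numbers:

import pandas as pd

metrics_df = pd.DataFrame(
    {
        "Metric": ["True Negatives (%)", "False Positives (%)"],
        "Reference": [62.0, 8.0],
        "Monitoring": [55.0, 10.5],
    }
)

drift_pct_threshold = 20
# Relative change of each cell, in percent, keeping the sign
metrics_df["Drift (%)"] = (
    (metrics_df["Monitoring"] - metrics_df["Reference"])
    / metrics_df["Reference"].abs()
    * 100
).round(2)
# Pass when the absolute drift stays under the threshold
metrics_df["Pass/Fail"] = (
    metrics_df["Drift (%)"].abs().apply(lambda x: "Pass" if x < drift_pct_threshold else "Fail")
)
print(metrics_df)
# True Negatives: (55 - 62) / 62 * 100 ≈ -11.29 -> Pass
# False Positives: (10.5 - 8) / 8 * 100 = 31.25 -> Fail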
validmind/tests/ongoing_monitoring/CumulativePredictionProbabilitiesDrift.py
@@ -0,0 +1,178 @@
+# Copyright © 2023-2024 ValidMind Inc. All rights reserved.
+# See the LICENSE file in the root of this repository for details.
+# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
+
+from typing import List
+
+import numpy as np
+import plotly.graph_objects as go
+from plotly.subplots import make_subplots
+
+from validmind import tags, tasks
+from validmind.vm_models import VMDataset, VMModel
+
+
+@tags("visualization", "credit_risk")
+@tasks("classification")
+def CumulativePredictionProbabilitiesDrift(
+    datasets: List[VMDataset],
+    model: VMModel,
+):
+    """
+    Compares cumulative prediction probability distributions between reference and monitoring datasets.
+
+    ### Purpose
+
+    The Cumulative Prediction Probabilities Drift test is designed to evaluate changes in the model's
+    probability predictions over time. By comparing cumulative distribution functions of predicted
+    probabilities between reference and monitoring datasets, this test helps identify whether the
+    model's probability assignments remain stable in production. This is crucial for understanding if
+    the model's risk assessment behavior has shifted and whether its probability calibration remains
+    consistent.
+
+    ### Test Mechanism
+
+    This test proceeds by generating cumulative distribution functions (CDFs) of predicted probabilities
+    for both reference and monitoring datasets. For each class, it plots the cumulative proportion of
+    predictions against probability values, enabling direct comparison of probability distributions.
+    The test visualizes both the CDFs and their differences, providing insight into how probability
+    assignments have shifted across the entire probability range.
+
+    ### Signs of High Risk
+
+    - Large gaps between reference and monitoring CDFs
+    - Systematic shifts in probability assignments
+    - Concentration of differences in specific probability ranges
+    - Changes in the shape of probability distributions
+    - Unexpected patterns in cumulative differences
+    - Significant shifts in probability thresholds
+
+    ### Strengths
+
+    - Provides comprehensive view of probability changes
+    - Identifies specific probability ranges with drift
+    - Enables visualization of distribution differences
+    - Supports analysis across multiple classes
+    - Maintains interpretable probability scale
+    - Captures subtle changes in probability assignments
+
+    ### Limitations
+
+    - Does not provide single drift metric
+    - May be complex to interpret for multiple classes
+    - Cannot suggest probability recalibration
+    - Requires visual inspection for assessment
+    - Sensitive to sample size differences
+    - May not capture class-specific calibration issues
+    """
+    # Get predictions and true values
+    y_prob_ref = datasets[0].y_prob(model)
+    df_ref = datasets[0].df.copy()
+    df_ref["probabilities"] = y_prob_ref
+
+    y_prob_mon = datasets[1].y_prob(model)
+    df_mon = datasets[1].df.copy()
+    df_mon["probabilities"] = y_prob_mon
+
+    # Get unique classes
+    classes = sorted(df_ref[datasets[0].target_column].unique())
+
+    # Define colors
+    ref_color = "rgba(31, 119, 180, 0.8)"  # Blue with 0.8 opacity
+    mon_color = "rgba(255, 127, 14, 0.8)"  # Orange with 0.8 opacity
+    diff_color = "rgba(148, 103, 189, 0.8)"  # Purple with 0.8 opacity
+
+    figures = []
+    for class_value in classes:
+        # Create figure with secondary y-axis
+        fig = make_subplots(
+            rows=2,
+            cols=1,
+            subplot_titles=[
+                f"Cumulative Distributions - Class {class_value}",
+                "Difference (Monitoring - Reference)",
+            ],
+            vertical_spacing=0.15,
+            shared_xaxes=True,
+        )
+
+        # Get probabilities for current class
+        ref_probs = df_ref[df_ref[datasets[0].target_column] == class_value][
+            "probabilities"
+        ]
+        mon_probs = df_mon[df_mon[datasets[1].target_column] == class_value][
+            "probabilities"
+        ]
+
+        # Calculate cumulative distributions
+        ref_sorted = np.sort(ref_probs)
+        ref_cumsum = np.arange(len(ref_sorted)) / float(len(ref_sorted))
+
+        mon_sorted = np.sort(mon_probs)
+        mon_cumsum = np.arange(len(mon_sorted)) / float(len(mon_sorted))
+
+        # Reference dataset cumulative curve
+        fig.add_trace(
+            go.Scatter(
+                x=ref_sorted,
+                y=ref_cumsum,
+                mode="lines",
+                name="Reference",
+                line=dict(color=ref_color, width=2),
+            ),
+            row=1,
+            col=1,
+        )
+
+        # Monitoring dataset cumulative curve
+        fig.add_trace(
+            go.Scatter(
+                x=mon_sorted,
+                y=mon_cumsum,
+                mode="lines",
+                name="Monitoring",
+                line=dict(color=mon_color, width=2),
+            ),
+            row=1,
+            col=1,
+        )
+
+        # Calculate and plot difference
+        # Interpolate monitoring values to match reference x-points
+        mon_interp = np.interp(ref_sorted, mon_sorted, mon_cumsum)
+        difference = mon_interp - ref_cumsum
+
+        fig.add_trace(
+            go.Scatter(
+                x=ref_sorted,
+                y=difference,
+                mode="lines",
+                name="Difference",
+                line=dict(color=diff_color, width=2),
+            ),
+            row=2,
+            col=1,
+        )
+
+        # Add horizontal line at y=0 for difference plot
+        fig.add_hline(y=0, line=dict(color="grey", dash="dash"), row=2, col=1)
+
+        # Update layout
+        fig.update_layout(
+            height=600,
+            width=800,
+            showlegend=True,
+            legend=dict(yanchor="middle", y=0.9, xanchor="left", x=1.05),
+        )
+
+        # Update axes
+        fig.update_xaxes(title_text="Probability", range=[0, 1], row=2, col=1)
+        fig.update_xaxes(range=[0, 1], row=1, col=1)
+        fig.update_yaxes(
+            title_text="Cumulative Distribution", range=[0, 1], row=1, col=1
+        )
+        fig.update_yaxes(title_text="Difference", row=2, col=1)
+
+        figures.append(fig)
+
+    return tuple(figures)
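Illustration, not from the package: the comparison step above interpolates the monitoring CDF onto the reference probabilities with np.interp and takes the gap. A toy, self-contained version of that step:

import numpy as np

ref_probs = np.array([0.05, 0.20, 0.35, 0.60, 0.90])  # made-up reference probabilities
mon_probs = np.array([0.10, 0.40, 0.55, 0.70, 0.95])  # made-up monitoring probabilities

# Empirical CDFs: sorted values against cumulative proportions
ref_sorted = np.sort(ref_probs)
ref_cumsum = np.arange(len(ref_sorted)) / float(len(ref_sorted))
mon_sorted = np.sort(mon_probs)
mon_cumsum = np.arange(len(mon_sorted)) / float(len(mon_sorted))

# Evaluate the monitoring CDF at the reference x-points, then take the difference
mon_interp = np.interp(ref_sorted, mon_sorted, mon_cumsum)
difference = mon_interp - ref_cumsum
print(difference)  # values near 0 mean the two distributions agree at those probabilities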
validmind/tests/ongoing_monitoring/FeatureDrift.py
@@ -2,18 +2,100 @@
 # See the LICENSE file in the root of this repository for details.
 # SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
 
-
-import matplotlib.pyplot as plt
 import numpy as np
 import pandas as pd
+import plotly.graph_objects as go
 
 from validmind import tags, tasks
 
 
+def calculate_psi_score(actual, expected):
+    """Calculate PSI score for a single bucket."""
+    return (actual - expected) * np.log((actual + 1e-6) / (expected + 1e-6))
+
+
+def calculate_feature_distributions(
+    reference_data, monitoring_data, feature_columns, bins
+):
+    """Calculate population distributions for each feature."""
+    # Calculate quantiles from reference data
+    quantiles = reference_data[feature_columns].quantile(
+        bins, method="single", interpolation="nearest"
+    )
+
+    distributions = {}
+    for dataset_name, data in [
+        ("reference", reference_data),
+        ("monitoring", monitoring_data),
+    ]:
+        for feature in feature_columns:
+            for bin_idx, threshold in enumerate(quantiles[feature]):
+                if bin_idx == 0:
+                    mask = data[feature] < threshold
+                else:
+                    prev_threshold = quantiles[feature][bins[bin_idx - 1]]
+                    mask = (data[feature] >= prev_threshold) & (
+                        data[feature] < threshold
+                    )
+
+                count = mask.sum()
+                proportion = count / len(data)
+                distributions[(dataset_name, feature, bins[bin_idx])] = proportion
+
+    return distributions
+
+
+def create_distribution_plot(feature_name, reference_dist, monitoring_dist, bins):
+    """Create population distribution plot for a feature."""
+    fig = go.Figure()
+
+    # Add reference distribution
+    fig.add_trace(
+        go.Bar(
+            x=list(range(len(bins))),
+            y=reference_dist,
+            name="Reference",
+            marker_color="blue",
+            marker_line_color="black",
+            marker_line_width=1,
+            opacity=0.75,
+        )
+    )
+
+    # Add monitoring distribution
+    fig.add_trace(
+        go.Bar(
+            x=list(range(len(bins))),
+            y=monitoring_dist,
+            name="Monitoring",
+            marker_color="green",
+            marker_line_color="black",
+            marker_line_width=1,
+            opacity=0.75,
+        )
+    )
+
+    fig.update_layout(
+        title=f"Population Distribution: {feature_name}",
+        xaxis_title="Bin",
+        yaxis_title="Population %",
+        barmode="group",
+        template="plotly_white",
+        showlegend=True,
+        width=800,
+        height=400,
+    )
+
+    return fig
+
+
 @tags("visualization")
 @tasks("monitoring")
 def FeatureDrift(
-    datasets,
+    datasets,
+    bins=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
+    feature_columns=None,
+    psi_threshold=0.2,
 ):
     """
     Evaluates changes in feature distribution over time to identify potential model drift.
@@ -57,130 +139,48 @@ def FeatureDrift(
     - PSI score interpretation can be overly simplistic for complex datasets.
     """
 
-    #
-
-    feature_columns = feature_columns or default_feature_columns
+    # Get feature columns
+    feature_columns = feature_columns or datasets[0].feature_columns
 
-
-
+    # Get data
+    reference_data = datasets[0].df
+    monitoring_data = datasets[1].df
 
-
-
+    # Calculate distributions
+    distributions = calculate_feature_distributions(
+        reference_data, monitoring_data, feature_columns, bins
     )
-    PSI_QUANTILES = quantiles_train.to_dict()
-
-    PSI_BUCKET_FRAC, col, n = get_psi_buckets(
-        x_test_df, x_train_df, feature_columns, bins, PSI_QUANTILES
-    )
-
-    def nest(d: dict) -> dict:
-        result = {}
-        for key, value in d.items():
-            target = result
-            for k in key[:-1]:  # traverse all keys but the last
-                target = target.setdefault(k, {})
-            target[key[-1]] = value
-        return result
-
-    PSI_BUCKET_FRAC = nest(PSI_BUCKET_FRAC)
 
-
-
+    # Calculate PSI scores
+    psi_scores = {}
+    for feature in feature_columns:
         psi = 0
-        for
-
-
-
-
-
-
-
-
-
+        for bin_val in bins:
+            reference_prop = distributions[("reference", feature, bin_val)]
+            monitoring_prop = distributions[("monitoring", feature, bin_val)]
+            psi += calculate_psi_score(monitoring_prop, reference_prop)
+        psi_scores[feature] = psi
+
+    # Create PSI score dataframe
+    psi_df = pd.DataFrame(list(psi_scores.items()), columns=["Feature", "PSI Score"])
+
+    # Add Pass/Fail column
+    psi_df["Pass/Fail"] = psi_df["PSI Score"].apply(
+        lambda x: "Pass" if x < psi_threshold else "Fail"
+    )
 
+    # Sort by PSI Score
     psi_df.sort_values(by=["PSI Score"], inplace=True, ascending=False)
 
-
-
-
-
-
-
-
-    final_psi = pd.DataFrame(psi_table)
-
-    return (final_psi, *save_fig)
-
-
-def get_psi_buckets(x_test_df, x_train_df, feature_columns, bins, PSI_QUANTILES):
-    DATA = {"test": x_test_df, "train": x_train_df}
-    PSI_BUCKET_FRAC = {}
-    for table in DATA.keys():
-        total_count = DATA[table].shape[0]
-        for col in feature_columns:
-            count_sum = 0
-            for n in bins:
-                if n == 0:
-                    bucket_count = (DATA[table][col] < PSI_QUANTILES[col][n]).sum()
-                elif n < 9:
-                    bucket_count = (
-                        total_count
-                        - count_sum
-                        - ((DATA[table][col] >= PSI_QUANTILES[col][n]).sum())
-                    )
-                elif n == 9:
-                    bucket_count = total_count - count_sum
-                count_sum += bucket_count
-                PSI_BUCKET_FRAC[table, col, n] = bucket_count / total_count
-    return PSI_BUCKET_FRAC, col, n
-
-
-def plot_hist(PSI_BUCKET_FRAC, bins):
-    bin_table_psi = pd.DataFrame(PSI_BUCKET_FRAC)
-    save_fig = []
-    for i in range(len(bin_table_psi)):
+    # Create distribution plots
+    figures = []
+    for feature in feature_columns:
+        reference_dist = [distributions[("reference", feature, b)] for b in bins]
+        monitoring_dist = [distributions[("monitoring", feature, b)] for b in bins]
+        fig = create_distribution_plot(feature, reference_dist, monitoring_dist, bins)
+        figures.append(fig)
 
-
-
-            columns=["Bin", "Population % Reference"],
-        )
-        y = pd.DataFrame(
-            bin_table_psi.iloc[i]["train"].items(),
-            columns=["Bin", "Population % Monitoring"],
-        )
-        xy = x.merge(y, on="Bin")
-        xy.index = xy["Bin"]
-        xy = xy.drop(columns="Bin", axis=1)
-        feature_name = bin_table_psi.index[i]
-
-        n = len(bins)
-        r = np.arange(n)
-        width = 0.25
-
-        fig = plt.figure()
-
-        plt.bar(
-            r,
-            xy["Population % Reference"],
-            color="b",
-            width=width,
-            edgecolor="black",
-            label="Reference {0}".format(feature_name),
-        )
-        plt.bar(
-            r + width,
-            xy["Population % Monitoring"],
-            color="g",
-            width=width,
-            edgecolor="black",
-            label="Monitoring {0}".format(feature_name),
-        )
+    # Calculate overall pass/fail
+    pass_fail_bool = (psi_df["Pass/Fail"] == "Pass").all()
 
-
-        plt.ylabel("Population %")
-        plt.title("Histogram of Population Differences {0}".format(feature_name))
-        plt.legend()
-        plt.tight_layout()
-        plt.close()
-        save_fig.append(fig)
-        return save_fig
+    return ({"PSI Scores": psi_df}, *figures, pass_fail_bool)
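Illustration, not from the package: the rewritten FeatureDrift reduces to the standard Population Stability Index summed over quantile buckets, using the calculate_psi_score formula shown above. A small standalone sketch with made-up bucket proportions:

import numpy as np

def calculate_psi_score(actual, expected):
    """PSI contribution of one bucket, as defined in the hunk above."""
    return (actual - expected) * np.log((actual + 1e-6) / (expected + 1e-6))

# Made-up bucket proportions for a single feature (reference vs. monitoring)
reference_props = [0.10, 0.15, 0.25, 0.30, 0.20]
monitoring_props = [0.05, 0.10, 0.25, 0.35, 0.25]

# Total PSI is the sum of per-bucket contributions
psi = sum(
    calculate_psi_score(mon, ref)
    for ref, mon in zip(reference_props, monitoring_props)
)
print(round(psi, 4))  # ≈ 0.074; below the default psi_threshold of 0.2 -> "Pass"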
validmind/tests/ongoing_monitoring/PredictionAcrossEachFeature.py
@@ -53,30 +53,25 @@ def PredictionAcrossEachFeature(datasets, model):
     observed during the training of the model.
     """
 
-
-
+    y_prob_reference = datasets[0].y_prob(model)
+    y_prob_monitoring = datasets[1].y_prob(model)
 
     figures_to_save = []
-    for column in
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-        ax2.set_title("Monitoring")
-        ax2.set_xlabel(column)
-        figures_to_save.append(fig)
-        plt.close()
+    for column in datasets[0].feature_columns:
+        fig, axs = plt.subplots(1, 2, figsize=(20, 10), sharey="row")
+
+        ax1, ax2 = axs
+
+        ax1.scatter(datasets[0].df[column], y_prob_reference)
+        ax2.scatter(datasets[1].df[column], y_prob_monitoring)
+
+        ax1.set_title("Reference")
+        ax1.set_xlabel(column)
+        ax1.set_ylabel("Prediction Value")
+
+        ax2.set_title("Monitoring")
+        ax2.set_xlabel(column)
+        figures_to_save.append(fig)
+        plt.close()
 
     return tuple(figures_to_save)