validmind-2.7.5-py3-none-any.whl → validmind-2.7.7-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (149)
  1. validmind/__init__.py +2 -0
  2. validmind/__version__.py +1 -1
  3. validmind/api_client.py +8 -1
  4. validmind/datasets/credit_risk/lending_club.py +352 -87
  5. validmind/html_templates/content_blocks.py +1 -1
  6. validmind/tests/__types__.py +17 -0
  7. validmind/tests/data_validation/ACFandPACFPlot.py +6 -2
  8. validmind/tests/data_validation/AutoMA.py +2 -2
  9. validmind/tests/data_validation/BivariateScatterPlots.py +4 -2
  10. validmind/tests/data_validation/BoxPierce.py +2 -2
  11. validmind/tests/data_validation/ClassImbalance.py +2 -1
  12. validmind/tests/data_validation/DatasetDescription.py +11 -2
  13. validmind/tests/data_validation/DatasetSplit.py +2 -2
  14. validmind/tests/data_validation/DickeyFullerGLS.py +2 -2
  15. validmind/tests/data_validation/FeatureTargetCorrelationPlot.py +8 -2
  16. validmind/tests/data_validation/HighCardinality.py +9 -2
  17. validmind/tests/data_validation/HighPearsonCorrelation.py +18 -4
  18. validmind/tests/data_validation/IQROutliersBarPlot.py +9 -2
  19. validmind/tests/data_validation/LaggedCorrelationHeatmap.py +2 -2
  20. validmind/tests/data_validation/MissingValuesBarPlot.py +12 -9
  21. validmind/tests/data_validation/MutualInformation.py +6 -8
  22. validmind/tests/data_validation/PearsonCorrelationMatrix.py +2 -2
  23. validmind/tests/data_validation/ProtectedClassesCombination.py +6 -1
  24. validmind/tests/data_validation/ProtectedClassesDescription.py +1 -1
  25. validmind/tests/data_validation/ProtectedClassesDisparity.py +4 -5
  26. validmind/tests/data_validation/ProtectedClassesThresholdOptimizer.py +1 -4
  27. validmind/tests/data_validation/RollingStatsPlot.py +21 -10
  28. validmind/tests/data_validation/ScatterPlot.py +3 -5
  29. validmind/tests/data_validation/ScoreBandDefaultRates.py +2 -1
  30. validmind/tests/data_validation/SeasonalDecompose.py +12 -2
  31. validmind/tests/data_validation/Skewness.py +6 -3
  32. validmind/tests/data_validation/SpreadPlot.py +8 -3
  33. validmind/tests/data_validation/TabularCategoricalBarPlots.py +4 -2
  34. validmind/tests/data_validation/TabularDateTimeHistograms.py +2 -2
  35. validmind/tests/data_validation/TargetRateBarPlots.py +4 -3
  36. validmind/tests/data_validation/TimeSeriesFrequency.py +7 -2
  37. validmind/tests/data_validation/TimeSeriesMissingValues.py +14 -10
  38. validmind/tests/data_validation/TimeSeriesOutliers.py +1 -5
  39. validmind/tests/data_validation/WOEBinPlots.py +2 -2
  40. validmind/tests/data_validation/WOEBinTable.py +11 -9
  41. validmind/tests/data_validation/nlp/CommonWords.py +2 -2
  42. validmind/tests/data_validation/nlp/Hashtags.py +2 -2
  43. validmind/tests/data_validation/nlp/LanguageDetection.py +9 -6
  44. validmind/tests/data_validation/nlp/Mentions.py +9 -6
  45. validmind/tests/data_validation/nlp/PolarityAndSubjectivity.py +2 -2
  46. validmind/tests/data_validation/nlp/Punctuations.py +4 -2
  47. validmind/tests/data_validation/nlp/Sentiment.py +2 -2
  48. validmind/tests/data_validation/nlp/StopWords.py +5 -4
  49. validmind/tests/data_validation/nlp/TextDescription.py +2 -2
  50. validmind/tests/data_validation/nlp/Toxicity.py +2 -2
  51. validmind/tests/model_validation/BertScore.py +2 -2
  52. validmind/tests/model_validation/BleuScore.py +2 -2
  53. validmind/tests/model_validation/ClusterSizeDistribution.py +2 -2
  54. validmind/tests/model_validation/ContextualRecall.py +2 -2
  55. validmind/tests/model_validation/FeaturesAUC.py +2 -2
  56. validmind/tests/model_validation/MeteorScore.py +2 -2
  57. validmind/tests/model_validation/ModelPredictionResiduals.py +2 -2
  58. validmind/tests/model_validation/RegardScore.py +6 -2
  59. validmind/tests/model_validation/RegressionResidualsPlot.py +4 -3
  60. validmind/tests/model_validation/RougeScore.py +6 -5
  61. validmind/tests/model_validation/TimeSeriesPredictionWithCI.py +11 -2
  62. validmind/tests/model_validation/TokenDisparity.py +2 -2
  63. validmind/tests/model_validation/ToxicityScore.py +10 -2
  64. validmind/tests/model_validation/embeddings/ClusterDistribution.py +9 -3
  65. validmind/tests/model_validation/embeddings/CosineSimilarityComparison.py +16 -2
  66. validmind/tests/model_validation/embeddings/CosineSimilarityDistribution.py +5 -3
  67. validmind/tests/model_validation/embeddings/CosineSimilarityHeatmap.py +2 -2
  68. validmind/tests/model_validation/embeddings/DescriptiveAnalytics.py +14 -4
  69. validmind/tests/model_validation/embeddings/EmbeddingsVisualization2D.py +2 -2
  70. validmind/tests/model_validation/embeddings/EuclideanDistanceComparison.py +16 -2
  71. validmind/tests/model_validation/embeddings/EuclideanDistanceHeatmap.py +2 -2
  72. validmind/tests/model_validation/embeddings/PCAComponentsPairwisePlots.py +4 -5
  73. validmind/tests/model_validation/embeddings/StabilityAnalysisKeyword.py +4 -2
  74. validmind/tests/model_validation/embeddings/StabilityAnalysisRandomNoise.py +4 -2
  75. validmind/tests/model_validation/embeddings/StabilityAnalysisSynonyms.py +4 -2
  76. validmind/tests/model_validation/embeddings/StabilityAnalysisTranslation.py +4 -2
  77. validmind/tests/model_validation/embeddings/TSNEComponentsPairwisePlots.py +8 -6
  78. validmind/tests/model_validation/embeddings/utils.py +11 -1
  79. validmind/tests/model_validation/ragas/AnswerCorrectness.py +2 -1
  80. validmind/tests/model_validation/ragas/AspectCritic.py +11 -7
  81. validmind/tests/model_validation/ragas/ContextEntityRecall.py +2 -1
  82. validmind/tests/model_validation/ragas/ContextPrecision.py +2 -1
  83. validmind/tests/model_validation/ragas/ContextPrecisionWithoutReference.py +2 -1
  84. validmind/tests/model_validation/ragas/ContextRecall.py +2 -1
  85. validmind/tests/model_validation/ragas/Faithfulness.py +2 -1
  86. validmind/tests/model_validation/ragas/NoiseSensitivity.py +2 -1
  87. validmind/tests/model_validation/ragas/ResponseRelevancy.py +2 -1
  88. validmind/tests/model_validation/ragas/SemanticSimilarity.py +2 -1
  89. validmind/tests/model_validation/sklearn/CalibrationCurve.py +3 -2
  90. validmind/tests/model_validation/sklearn/ClassifierThresholdOptimization.py +2 -5
  91. validmind/tests/model_validation/sklearn/ClusterCosineSimilarity.py +5 -2
  92. validmind/tests/model_validation/sklearn/ConfusionMatrix.py +2 -2
  93. validmind/tests/model_validation/sklearn/FeatureImportance.py +1 -14
  94. validmind/tests/model_validation/sklearn/HyperParametersTuning.py +6 -3
  95. validmind/tests/model_validation/sklearn/KMeansClustersOptimization.py +2 -2
  96. validmind/tests/model_validation/sklearn/MinimumROCAUCScore.py +8 -4
  97. validmind/tests/model_validation/sklearn/ModelParameters.py +1 -0
  98. validmind/tests/model_validation/sklearn/OverfitDiagnosis.py +3 -3
  99. validmind/tests/model_validation/sklearn/PermutationFeatureImportance.py +2 -2
  100. validmind/tests/model_validation/sklearn/PopulationStabilityIndex.py +20 -16
  101. validmind/tests/model_validation/sklearn/PrecisionRecallCurve.py +4 -2
  102. validmind/tests/model_validation/sklearn/ROCCurve.py +1 -1
  103. validmind/tests/model_validation/sklearn/RegressionR2Square.py +7 -9
  104. validmind/tests/model_validation/sklearn/RegressionR2SquareComparison.py +1 -3
  105. validmind/tests/model_validation/sklearn/SHAPGlobalImportance.py +2 -1
  106. validmind/tests/model_validation/sklearn/ScoreProbabilityAlignment.py +2 -1
  107. validmind/tests/model_validation/sklearn/SilhouettePlot.py +5 -3
  108. validmind/tests/model_validation/sklearn/TrainingTestDegradation.py +9 -1
  109. validmind/tests/model_validation/sklearn/WeakspotsDiagnosis.py +1 -1
  110. validmind/tests/model_validation/statsmodels/CumulativePredictionProbabilities.py +11 -4
  111. validmind/tests/model_validation/statsmodels/DurbinWatsonTest.py +1 -3
  112. validmind/tests/model_validation/statsmodels/GINITable.py +7 -15
  113. validmind/tests/model_validation/statsmodels/Lilliefors.py +2 -2
  114. validmind/tests/model_validation/statsmodels/RegressionCoeffs.py +1 -1
  115. validmind/tests/model_validation/statsmodels/RegressionFeatureSignificance.py +2 -2
  116. validmind/tests/model_validation/statsmodels/RegressionModelForecastPlotLevels.py +5 -2
  117. validmind/tests/model_validation/statsmodels/RegressionModelSensitivityPlot.py +5 -2
  118. validmind/tests/model_validation/statsmodels/RegressionModelSummary.py +7 -7
  119. validmind/tests/model_validation/statsmodels/RegressionPermutationFeatureImportance.py +2 -2
  120. validmind/tests/ongoing_monitoring/CalibrationCurveDrift.py +220 -0
  121. validmind/tests/ongoing_monitoring/ClassDiscriminationDrift.py +155 -0
  122. validmind/tests/ongoing_monitoring/ClassImbalanceDrift.py +146 -0
  123. validmind/tests/ongoing_monitoring/ClassificationAccuracyDrift.py +148 -0
  124. validmind/tests/ongoing_monitoring/ConfusionMatrixDrift.py +193 -0
  125. validmind/tests/ongoing_monitoring/CumulativePredictionProbabilitiesDrift.py +178 -0
  126. validmind/tests/ongoing_monitoring/FeatureDrift.py +120 -120
  127. validmind/tests/ongoing_monitoring/PredictionAcrossEachFeature.py +18 -23
  128. validmind/tests/ongoing_monitoring/PredictionCorrelation.py +86 -44
  129. validmind/tests/ongoing_monitoring/PredictionProbabilitiesHistogramDrift.py +204 -0
  130. validmind/tests/ongoing_monitoring/PredictionQuantilesAcrossFeatures.py +98 -0
  131. validmind/tests/ongoing_monitoring/ROCCurveDrift.py +150 -0
  132. validmind/tests/ongoing_monitoring/ScoreBandsDrift.py +212 -0
  133. validmind/tests/ongoing_monitoring/ScorecardHistogramDrift.py +209 -0
  134. validmind/tests/ongoing_monitoring/TargetPredictionDistributionPlot.py +91 -13
  135. validmind/tests/prompt_validation/Bias.py +13 -9
  136. validmind/tests/prompt_validation/Clarity.py +13 -9
  137. validmind/tests/prompt_validation/Conciseness.py +13 -9
  138. validmind/tests/prompt_validation/Delimitation.py +13 -9
  139. validmind/tests/prompt_validation/NegativeInstruction.py +14 -11
  140. validmind/tests/prompt_validation/Robustness.py +6 -2
  141. validmind/tests/prompt_validation/Specificity.py +13 -9
  142. validmind/tests/run.py +6 -0
  143. validmind/utils.py +7 -8
  144. validmind/vm_models/dataset/dataset.py +0 -4
  145. {validmind-2.7.5.dist-info → validmind-2.7.7.dist-info}/METADATA +2 -3
  146. {validmind-2.7.5.dist-info → validmind-2.7.7.dist-info}/RECORD +149 -138
  147. {validmind-2.7.5.dist-info → validmind-2.7.7.dist-info}/WHEEL +1 -1
  148. {validmind-2.7.5.dist-info → validmind-2.7.7.dist-info}/LICENSE +0 -0
  149. {validmind-2.7.5.dist-info → validmind-2.7.7.dist-info}/entry_points.txt +0 -0
validmind/tests/ongoing_monitoring/PredictionCorrelation.py

@@ -2,16 +2,15 @@
 # See the LICENSE file in the root of this repository for details.
 # SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
 
-
-import matplotlib.pyplot as plt
-import numpy as np
+import pandas as pd
+import plotly.graph_objects as go
 
 from validmind import tags, tasks
 
 
 @tags("visualization")
 @tasks("monitoring")
-def PredictionCorrelation(datasets, model):
+def PredictionCorrelation(datasets, model, drift_pct_threshold=20):
     """
     Assesses correlation changes between model predictions from reference and monitoring datasets to detect potential
     target drift.
@@ -47,55 +46,98 @@ def PredictionCorrelation(datasets, model):
     - Focuses solely on linear relationships, potentially missing non-linear interactions.
     """
 
-    prediction_prob_column = f"{model.input_id}_probabilities"
-    prediction_column = f"{model.input_id}_prediction"
+    # Get feature columns and predictions
+    feature_columns = datasets[0].feature_columns
+    y_prob_ref = pd.Series(datasets[0].y_prob(model), index=datasets[0].df.index)
+    y_prob_mon = pd.Series(datasets[1].y_prob(model), index=datasets[1].df.index)
 
-    df_corr = datasets[0]._df.corr()
-    df_corr = df_corr[[prediction_prob_column]]
+    # Create dataframes with features and predictions
+    df_ref = datasets[0].df[feature_columns].copy()
+    df_ref["predictions"] = y_prob_ref
 
-    df_corr2 = datasets[1]._df.corr()
-    df_corr2 = df_corr2[[prediction_prob_column]]
+    df_mon = datasets[1].df[feature_columns].copy()
+    df_mon["predictions"] = y_prob_mon
 
-    corr_final = df_corr.merge(df_corr2, left_index=True, right_index=True)
-    corr_final.columns = ["Reference Predictions", "Monitoring Predictions"]
-    corr_final = corr_final.drop(index=[prediction_column, prediction_prob_column])
+    # Calculate correlations
+    corr_ref = df_ref.corr()["predictions"]
+    corr_mon = df_mon.corr()["predictions"]
 
-    n = len(corr_final)
-    r = np.arange(n)
-    width = 0.25
+    # Combine correlations (excluding the predictions row)
+    corr_final = pd.DataFrame(
+        {
+            "Reference Predictions": corr_ref[feature_columns],
+            "Monitoring Predictions": corr_mon[feature_columns],
+        }
+    )
 
-    fig = plt.figure()
+    # Calculate drift percentage with direction
+    corr_final["Drift (%)"] = (
+        (corr_final["Monitoring Predictions"] - corr_final["Reference Predictions"])
+        / corr_final["Reference Predictions"].abs()
+        * 100
+    ).round(2)
+
+    # Add Pass/Fail column based on absolute drift
+    corr_final["Pass/Fail"] = (
+        corr_final["Drift (%)"]
+        .abs()
+        .apply(lambda x: "Pass" if x < drift_pct_threshold else "Fail")
+    )
 
-    plt.bar(
-        r,
-        corr_final["Reference Predictions"],
-        color="b",
-        width=width,
-        edgecolor="black",
-        label="Reference Prediction Correlation",
+    # Create plotly figure
+    fig = go.Figure()
+
+    # Add reference predictions bar
+    fig.add_trace(
+        go.Bar(
+            name="Reference Prediction Correlation",
+            x=corr_final.index,
+            y=corr_final["Reference Predictions"],
+            marker_color="blue",
+            marker_line_color="black",
+            marker_line_width=1,
+            opacity=0.75,
+        )
     )
-    plt.bar(
-        r + width,
-        corr_final["Monitoring Predictions"],
-        color="g",
-        width=width,
-        edgecolor="black",
-        label="Monitoring Prediction Correlation",
+
+    # Add monitoring predictions bar
+    fig.add_trace(
+        go.Bar(
+            name="Monitoring Prediction Correlation",
+            x=corr_final.index,
+            y=corr_final["Monitoring Predictions"],
+            marker_color="green",
+            marker_line_color="black",
+            marker_line_width=1,
+            opacity=0.75,
+        )
     )
 
-    plt.xlabel("Features")
-    plt.ylabel("Correlation")
-    plt.title("Correlation between Predictions and Features")
+    # Update layout
+    fig.update_layout(
+        title="Correlation between Predictions and Features",
+        xaxis_title="Features",
+        yaxis_title="Correlation",
+        barmode="group",
+        template="plotly_white",
+        showlegend=True,
+        xaxis_tickangle=-45,
+        yaxis=dict(
+            range=[-1, 1],  # Correlation range is always -1 to 1
+            zeroline=True,
+            zerolinewidth=1,
+            zerolinecolor="grey",
+            gridcolor="lightgrey",
+        ),
+        hoverlabel=dict(bgcolor="white", font_size=12, font_family="Arial"),
+    )
 
-    features = corr_final.index.to_list()
-    plt.xticks(r + width / 2, features, rotation=45)
-    plt.legend()
-    plt.tight_layout()
+    # Ensure Features is the first column
+    corr_final["Feature"] = corr_final.index
+    cols = ["Feature"] + [col for col in corr_final.columns if col != "Feature"]
+    corr_final = corr_final[cols]
 
-    plt.close()
+    # Calculate overall pass/fail
+    pass_fail_bool = (corr_final["Pass/Fail"] == "Pass").all()
 
-    corr_final["Features"] = corr_final.index
-    corr_final = corr_final[
-        ["Features", "Reference Predictions", "Monitoring Predictions"]
-    ]
-    return ({"Correlation Pair Table": corr_final}, fig)
+    return ({"Correlation Pair Table": corr_final}, fig, pass_fail_bool)
validmind/tests/ongoing_monitoring/PredictionProbabilitiesHistogramDrift.py (new file)

@@ -0,0 +1,204 @@
+# Copyright © 2023-2024 ValidMind Inc. All rights reserved.
+# See the LICENSE file in the root of this repository for details.
+# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
+
+from typing import List
+
+import numpy as np
+import pandas as pd
+import plotly.graph_objects as go
+from plotly.subplots import make_subplots
+from scipy import stats
+
+from validmind import tags, tasks
+from validmind.vm_models import VMDataset, VMModel
+
+
+@tags("visualization", "credit_risk")
+@tasks("classification")
+def PredictionProbabilitiesHistogramDrift(
+    datasets: List[VMDataset],
+    model: VMModel,
+    title="Prediction Probabilities Histogram Drift",
+    drift_pct_threshold: float = 20.0,
+):
+    """
+    Compares prediction probability distributions between reference and monitoring datasets.
+
+    ### Purpose
+
+    The Prediction Probabilities Histogram Drift test is designed to evaluate changes in the model's
+    probability predictions over time. By comparing probability distributions between reference and
+    monitoring datasets using histograms, this test helps identify whether the model's probability
+    assignments have shifted in production. This is crucial for understanding if the model's risk
+    assessment behavior remains consistent and whether its probability estimates maintain their
+    original distribution patterns.
+
+    ### Test Mechanism
+
+    This test proceeds by generating histograms of prediction probabilities for both reference and
+    monitoring datasets. For each class, it analyzes the distribution shape, central tendency, and
+    spread of probabilities. The test computes distribution moments (mean, variance, skewness,
+    kurtosis) and quantifies their drift between datasets. Visual comparison of overlaid histograms
+    provides immediate insight into distribution changes.
+
+    ### Signs of High Risk
+
+    - Significant shifts in probability distribution shapes
+    - Large drifts in distribution moments exceeding threshold
+    - Appearance of new modes or peaks in monitoring data
+    - Changes in the spread or concentration of probabilities
+    - Systematic shifts in probability assignments
+    - Unexpected changes in distribution characteristics
+
+    ### Strengths
+
+    - Provides intuitive visualization of probability changes
+    - Identifies specific changes in distribution shape
+    - Enables quantitative assessment of distribution drift
+    - Supports analysis across multiple classes
+    - Includes comprehensive moment analysis
+    - Maintains interpretable probability scale
+
+    ### Limitations
+
+    - May be sensitive to binning choices
+    - Requires sufficient samples for reliable histograms
+    - Cannot suggest probability recalibration
+    - Complex interpretation for multiple classes
+    - May not capture subtle distribution changes
+    - Limited to univariate probability analysis
+    """
+    # Get predictions and true values
+    y_prob_ref = datasets[0].y_prob(model)
+    df_ref = datasets[0].df.copy()
+    df_ref["probabilities"] = y_prob_ref
+
+    y_prob_mon = datasets[1].y_prob(model)
+    df_mon = datasets[1].df.copy()
+    df_mon["probabilities"] = y_prob_mon
+
+    # Get unique classes
+    classes = sorted(df_ref[datasets[0].target_column].unique())
+
+    # Create subplots with more horizontal space for legends
+    fig = make_subplots(
+        rows=len(classes),
+        cols=1,
+        subplot_titles=[f"Class {cls}" for cls in classes],
+        horizontal_spacing=0.15,
+    )
+
+    # Define colors
+    ref_color = "rgba(31, 119, 180, 0.8)"  # Blue with 0.8 opacity
+    mon_color = "rgba(255, 127, 14, 0.8)"  # Orange with 0.8 opacity
+
+    # Dictionary to store tables for each class
+    tables = {}
+    all_passed = True  # Track overall pass/fail
+
+    # Add histograms and create tables for each class
+    for i, class_value in enumerate(classes, start=1):
+        # Get probabilities for current class
+        ref_probs = df_ref[df_ref[datasets[0].target_column] == class_value][
+            "probabilities"
+        ]
+        mon_probs = df_mon[df_mon[datasets[1].target_column] == class_value][
+            "probabilities"
+        ]
+
+        # Calculate distribution moments
+        ref_stats = {
+            "Mean": np.mean(ref_probs),
+            "Variance": np.var(ref_probs),
+            "Skewness": stats.skew(ref_probs),
+            "Kurtosis": stats.kurtosis(ref_probs),
+        }
+
+        mon_stats = {
+            "Mean": np.mean(mon_probs),
+            "Variance": np.var(mon_probs),
+            "Skewness": stats.skew(mon_probs),
+            "Kurtosis": stats.kurtosis(mon_probs),
+        }
+
+        # Create table for this class
+        table_data = []
+        class_passed = True  # Track pass/fail for this class
+
+        for stat_name in ["Mean", "Variance", "Skewness", "Kurtosis"]:
+            ref_val = ref_stats[stat_name]
+            mon_val = mon_stats[stat_name]
+            drift = (
+                ((mon_val - ref_val) / abs(ref_val)) * 100 if ref_val != 0 else np.inf
+            )
+            passed = abs(drift) < drift_pct_threshold
+            class_passed &= passed  # Update class pass/fail
+
+            table_data.append(
+                {
+                    "Statistic": stat_name,
+                    "Reference": round(ref_val, 4),
+                    "Monitoring": round(mon_val, 4),
+                    "Drift (%)": round(drift, 2),
+                    "Pass/Fail": "Pass" if passed else "Fail",
+                }
+            )
+
+        tables[f"Class {class_value}"] = pd.DataFrame(table_data)
+        all_passed &= class_passed  # Update overall pass/fail
+
+        # Reference dataset histogram
+        fig.add_trace(
+            go.Histogram(
+                x=ref_probs,
+                name=f"Reference - Class {class_value}",
+                marker_color=ref_color,
+                showlegend=True,
+                legendrank=i * 2 - 1,
+            ),
+            row=i,
+            col=1,
+        )
+
+        # Monitoring dataset histogram
+        fig.add_trace(
+            go.Histogram(
+                x=mon_probs,
+                name=f"Monitoring - Class {class_value}",
+                marker_color=mon_color,
+                showlegend=True,
+                legendrank=i * 2,
+            ),
+            row=i,
+            col=1,
+        )
+
+    # Update layout
+    fig.update_layout(
+        title_text=title,
+        barmode="overlay",
+        height=300 * len(classes),
+        width=1000,
+        showlegend=True,
+    )
+
+    # Update axes labels and add separate legends for each subplot
+    for i in range(len(classes)):
+        fig.update_xaxes(title_text="Probability", row=i + 1, col=1)
+        fig.update_yaxes(title_text="Frequency", row=i + 1, col=1)
+
+        # Add separate legend for each subplot
+        fig.update_layout(
+            **{
+                f'legend{i+1 if i > 0 else ""}': dict(
+                    yanchor="middle",
+                    y=1 - (i / len(classes)) - (0.5 / len(classes)),
+                    xanchor="left",
+                    x=1.05,
+                    tracegroupgap=5,
+                )
+            }
+        )
+
+    return fig, tables, all_passed
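
The drift rule in this new test reduces to a relative change in four distribution moments: drift % = (monitoring − reference) / |reference| × 100, failing when its absolute value reaches drift_pct_threshold. A minimal self-contained sketch of that arithmetic on synthetic probabilities (the beta-distributed samples and variable names are illustrative, not part of the validmind API):

    import numpy as np
    from scipy import stats

    rng = np.random.default_rng(0)
    ref_probs = rng.beta(2, 5, size=1000)    # stand-in for reference probabilities
    mon_probs = rng.beta(2.5, 5, size=1000)  # stand-in for monitoring probabilities

    threshold = 20.0  # mirrors drift_pct_threshold above
    for name, fn in [("Mean", np.mean), ("Variance", np.var),
                     ("Skewness", stats.skew), ("Kurtosis", stats.kurtosis)]:
        ref_val, mon_val = fn(ref_probs), fn(mon_probs)
        drift = ((mon_val - ref_val) / abs(ref_val)) * 100 if ref_val != 0 else np.inf
        print(f"{name}: drift {drift:+.2f} % ->", "Pass" if abs(drift) < threshold else "Fail")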
validmind/tests/ongoing_monitoring/PredictionQuantilesAcrossFeatures.py (new file)

@@ -0,0 +1,98 @@
+# Copyright © 2023-2024 ValidMind Inc. All rights reserved.
+# See the LICENSE file in the root of this repository for details.
+# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
+
+import plotly.graph_objects as go
+from plotly.subplots import make_subplots
+
+from validmind import tags, tasks
+
+
+@tags("visualization")
+@tasks("monitoring")
+def PredictionQuantilesAcrossFeatures(datasets, model):
+    """
+    Assesses differences in model prediction distributions across individual features between reference
+    and monitoring datasets through quantile analysis.
+
+    ### Purpose
+
+    This test aims to visualize how prediction distributions vary across feature values by showing
+    quantile information between reference and monitoring datasets. It helps identify significant
+    shifts in prediction patterns and potential areas of model instability.
+
+    ### Test Mechanism
+
+    The test generates box plots for each feature, comparing prediction probability distributions
+    between the reference and monitoring datasets. Each plot consists of two subplots showing the
+    quantile distribution of predictions: one for reference data and one for monitoring data.
+
+    ### Signs of High Risk
+
+    - Significant differences in prediction distributions between reference and monitoring data
+    - Unexpected shifts in prediction quantiles across feature values
+    - Large changes in prediction variability between datasets
+
+    ### Strengths
+
+    - Provides clear visualization of prediction distribution changes
+    - Shows outliers and variability in predictions across features
+    - Enables quick identification of problematic feature ranges
+
+    ### Limitations
+
+    - May not capture complex relationships between features and predictions
+    - Quantile analysis may smooth over important individual predictions
+    - Requires careful interpretation of distribution changes
+    """
+
+    feature_columns = datasets[0].feature_columns
+    y_prob_reference = datasets[0].y_prob(model)
+    y_prob_monitoring = datasets[1].y_prob(model)
+
+    figures_to_save = []
+    for column in feature_columns:
+        # Create subplot
+        fig = make_subplots(1, 2, subplot_titles=("Reference", "Monitoring"))
+
+        # Add reference box plot
+        fig.add_trace(
+            go.Box(
+                x=datasets[0].df[column],
+                y=y_prob_reference,
+                name="Reference",
+                boxpoints="outliers",
+                marker_color="blue",
+            ),
+            row=1,
+            col=1,
+        )
+
+        # Add monitoring box plot
+        fig.add_trace(
+            go.Box(
+                x=datasets[1].df[column],
+                y=y_prob_monitoring,
+                name="Monitoring",
+                boxpoints="outliers",
+                marker_color="red",
+            ),
+            row=1,
+            col=2,
+        )
+
+        # Update layout
+        fig.update_layout(
+            title=f"Prediction Distributions vs {column}",
+            showlegend=False,
+            width=800,
+            height=400,
+        )
+
+        # Update axes
+        fig.update_xaxes(title=column)
+        fig.update_yaxes(title="Prediction Value")
+
+        figures_to_save.append(fig)
+
+    return tuple(figures_to_save)
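
Unlike the drift tests above, this test returns only figures, one box-plot pair per feature, with no table or threshold. Where a numeric counterpart is wanted, the same idea can be approximated with a quantile table over feature bins; the sketch below is an illustrative stand-in on synthetic data, not part of this module:

    import numpy as np
    import pandas as pd

    rng = np.random.default_rng(1)
    feature = pd.Series(rng.normal(size=500), name="feature")
    preds = pd.Series(rng.uniform(size=500), name="prediction")

    bins = pd.qcut(feature, q=4)  # quartile bins of the feature
    summary = preds.groupby(bins).quantile([0.25, 0.5, 0.75]).unstack()
    print(summary)  # rows: feature quartiles; columns: prediction quantiles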
validmind/tests/ongoing_monitoring/ROCCurveDrift.py (new file)

@@ -0,0 +1,150 @@
+# Copyright © 2023-2024 ValidMind Inc. All rights reserved.
+# See the LICENSE file in the root of this repository for details.
+# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
+
+from typing import List
+
+import numpy as np
+import plotly.graph_objects as go
+from sklearn.metrics import roc_auc_score, roc_curve
+
+from validmind import tags, tasks
+from validmind.errors import SkipTestError
+from validmind.vm_models import VMDataset, VMModel
+
+
+@tags(
+    "sklearn",
+    "binary_classification",
+    "model_performance",
+    "visualization",
+)
+@tasks("classification", "text_classification")
+def ROCCurveDrift(datasets: List[VMDataset], model: VMModel):
+    """
+    Compares ROC curves between reference and monitoring datasets.
+
+    ### Purpose
+
+    The ROC Curve Drift test is designed to evaluate changes in the model's discriminative ability
+    over time. By comparing Receiver Operating Characteristic (ROC) curves between reference and
+    monitoring datasets, this test helps identify whether the model maintains its ability to
+    distinguish between classes across different decision thresholds. This is crucial for
+    understanding if the model's trade-off between sensitivity and specificity remains stable
+    in production.
+
+    ### Test Mechanism
+
+    This test proceeds by generating ROC curves for both reference and monitoring datasets. For each
+    dataset, it plots the True Positive Rate against the False Positive Rate across all possible
+    classification thresholds. The test also computes AUC scores and visualizes the difference
+    between ROC curves, providing both graphical and numerical assessments of discrimination
+    stability. Special attention is paid to regions where curves diverge significantly.
+
+    ### Signs of High Risk
+
+    - Large differences between reference and monitoring ROC curves
+    - Significant drop in AUC score for monitoring dataset
+    - Systematic differences in specific FPR regions
+    - Changes in optimal operating points
+    - Inconsistent performance across different thresholds
+    - Unexpected crossovers between curves
+
+    ### Strengths
+
+    - Provides comprehensive view of discriminative ability
+    - Identifies specific threshold ranges with drift
+    - Enables visualization of performance differences
+    - Includes AUC comparison for overall assessment
+    - Supports threshold-independent evaluation
+    - Maintains interpretable performance metrics
+
+    ### Limitations
+
+    - Limited to binary classification problems
+    - May be sensitive to class distribution changes
+    - Cannot suggest optimal threshold adjustments
+    - Requires visual inspection for detailed analysis
+    - Complex interpretation of curve differences
+    - May not capture subtle performance changes
+    """
+    # Check for binary classification
+    if len(np.unique(datasets[0].y)) > 2:
+        raise SkipTestError(
+            "ROC Curve Drift is only supported for binary classification models"
+        )
+
+    # Calculate ROC curves for reference dataset
+    y_prob_ref = datasets[0].y_prob(model)
+    y_true_ref = datasets[0].y.astype(y_prob_ref.dtype).flatten()
+    fpr_ref, tpr_ref, _ = roc_curve(y_true_ref, y_prob_ref, drop_intermediate=False)
+    auc_ref = roc_auc_score(y_true_ref, y_prob_ref)
+
+    # Calculate ROC curves for monitoring dataset
+    y_prob_mon = datasets[1].y_prob(model)
+    y_true_mon = datasets[1].y.astype(y_prob_mon.dtype).flatten()
+    fpr_mon, tpr_mon, _ = roc_curve(y_true_mon, y_prob_mon, drop_intermediate=False)
+    auc_mon = roc_auc_score(y_true_mon, y_prob_mon)
+
+    # Create superimposed ROC curves plot
+    fig1 = go.Figure()
+
+    fig1.add_trace(
+        go.Scatter(
+            x=fpr_ref,
+            y=tpr_ref,
+            mode="lines",
+            name=f"Reference (AUC = {auc_ref:.3f})",
+            line=dict(color="blue", width=2),
+        )
+    )
+
+    fig1.add_trace(
+        go.Scatter(
+            x=fpr_mon,
+            y=tpr_mon,
+            mode="lines",
+            name=f"Monitoring (AUC = {auc_mon:.3f})",
+            line=dict(color="red", width=2),
+        )
+    )
+
+    fig1.update_layout(
+        title="ROC Curves Comparison",
+        xaxis=dict(title="False Positive Rate"),
+        yaxis=dict(title="True Positive Rate"),
+        width=700,
+        height=500,
+    )
+
+    # Interpolate monitoring TPR to match reference FPR points
+    tpr_mon_interp = np.interp(fpr_ref, fpr_mon, tpr_mon)
+
+    # Calculate TPR difference
+    tpr_diff = tpr_mon_interp - tpr_ref
+
+    # Create difference plot
+    fig2 = go.Figure()
+
+    fig2.add_trace(
+        go.Scatter(
+            x=fpr_ref,
+            y=tpr_diff,
+            mode="lines",
+            name="TPR Difference",
+            line=dict(color="purple", width=2),
+        )
+    )
+
+    # Add horizontal line at y=0
+    fig2.add_hline(y=0, line=dict(color="grey", dash="dash"), name="No Difference")
+
+    fig2.update_layout(
+        title="ROC Curve Difference (Monitoring - Reference)",
+        xaxis=dict(title="False Positive Rate"),
+        yaxis=dict(title="TPR Difference"),
+        width=700,
+        height=500,
+    )
+
+    return fig1, fig2
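
The one subtle step in this new test is aligning the two curves before subtracting: np.interp resamples the monitoring TPR onto the reference FPR grid, which is valid because roc_curve returns FPR in increasing order. Below is a self-contained sketch of that alignment on synthetic scores; the max-|ΔTPR| summary at the end is an illustrative scalar, not something the test itself reports:

    import numpy as np
    from sklearn.metrics import roc_curve

    rng = np.random.default_rng(2)
    y_true = rng.integers(0, 2, size=1000)
    # Two score vectors with slightly different signal strength
    score_ref = np.clip(0.3 * y_true + 0.7 * rng.uniform(size=1000), 0, 1)
    score_mon = np.clip(0.2 * y_true + 0.8 * rng.uniform(size=1000), 0, 1)

    fpr_ref, tpr_ref, _ = roc_curve(y_true, score_ref, drop_intermediate=False)
    fpr_mon, tpr_mon, _ = roc_curve(y_true, score_mon, drop_intermediate=False)

    # Resample monitoring TPR onto the reference FPR grid, then difference
    tpr_mon_interp = np.interp(fpr_ref, fpr_mon, tpr_mon)
    print("max |TPR difference|:", np.abs(tpr_mon_interp - tpr_ref).max())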