validmind 2.7.5__py3-none-any.whl → 2.7.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (149)
  1. validmind/__init__.py +2 -0
  2. validmind/__version__.py +1 -1
  3. validmind/api_client.py +8 -1
  4. validmind/datasets/credit_risk/lending_club.py +352 -87
  5. validmind/html_templates/content_blocks.py +1 -1
  6. validmind/tests/__types__.py +17 -0
  7. validmind/tests/data_validation/ACFandPACFPlot.py +6 -2
  8. validmind/tests/data_validation/AutoMA.py +2 -2
  9. validmind/tests/data_validation/BivariateScatterPlots.py +4 -2
  10. validmind/tests/data_validation/BoxPierce.py +2 -2
  11. validmind/tests/data_validation/ClassImbalance.py +2 -1
  12. validmind/tests/data_validation/DatasetDescription.py +11 -2
  13. validmind/tests/data_validation/DatasetSplit.py +2 -2
  14. validmind/tests/data_validation/DickeyFullerGLS.py +2 -2
  15. validmind/tests/data_validation/FeatureTargetCorrelationPlot.py +8 -2
  16. validmind/tests/data_validation/HighCardinality.py +9 -2
  17. validmind/tests/data_validation/HighPearsonCorrelation.py +18 -4
  18. validmind/tests/data_validation/IQROutliersBarPlot.py +9 -2
  19. validmind/tests/data_validation/LaggedCorrelationHeatmap.py +2 -2
  20. validmind/tests/data_validation/MissingValuesBarPlot.py +12 -9
  21. validmind/tests/data_validation/MutualInformation.py +6 -8
  22. validmind/tests/data_validation/PearsonCorrelationMatrix.py +2 -2
  23. validmind/tests/data_validation/ProtectedClassesCombination.py +6 -1
  24. validmind/tests/data_validation/ProtectedClassesDescription.py +1 -1
  25. validmind/tests/data_validation/ProtectedClassesDisparity.py +4 -5
  26. validmind/tests/data_validation/ProtectedClassesThresholdOptimizer.py +1 -4
  27. validmind/tests/data_validation/RollingStatsPlot.py +21 -10
  28. validmind/tests/data_validation/ScatterPlot.py +3 -5
  29. validmind/tests/data_validation/ScoreBandDefaultRates.py +2 -1
  30. validmind/tests/data_validation/SeasonalDecompose.py +12 -2
  31. validmind/tests/data_validation/Skewness.py +6 -3
  32. validmind/tests/data_validation/SpreadPlot.py +8 -3
  33. validmind/tests/data_validation/TabularCategoricalBarPlots.py +4 -2
  34. validmind/tests/data_validation/TabularDateTimeHistograms.py +2 -2
  35. validmind/tests/data_validation/TargetRateBarPlots.py +4 -3
  36. validmind/tests/data_validation/TimeSeriesFrequency.py +7 -2
  37. validmind/tests/data_validation/TimeSeriesMissingValues.py +14 -10
  38. validmind/tests/data_validation/TimeSeriesOutliers.py +1 -5
  39. validmind/tests/data_validation/WOEBinPlots.py +2 -2
  40. validmind/tests/data_validation/WOEBinTable.py +11 -9
  41. validmind/tests/data_validation/nlp/CommonWords.py +2 -2
  42. validmind/tests/data_validation/nlp/Hashtags.py +2 -2
  43. validmind/tests/data_validation/nlp/LanguageDetection.py +9 -6
  44. validmind/tests/data_validation/nlp/Mentions.py +9 -6
  45. validmind/tests/data_validation/nlp/PolarityAndSubjectivity.py +2 -2
  46. validmind/tests/data_validation/nlp/Punctuations.py +4 -2
  47. validmind/tests/data_validation/nlp/Sentiment.py +2 -2
  48. validmind/tests/data_validation/nlp/StopWords.py +5 -4
  49. validmind/tests/data_validation/nlp/TextDescription.py +2 -2
  50. validmind/tests/data_validation/nlp/Toxicity.py +2 -2
  51. validmind/tests/model_validation/BertScore.py +2 -2
  52. validmind/tests/model_validation/BleuScore.py +2 -2
  53. validmind/tests/model_validation/ClusterSizeDistribution.py +2 -2
  54. validmind/tests/model_validation/ContextualRecall.py +2 -2
  55. validmind/tests/model_validation/FeaturesAUC.py +2 -2
  56. validmind/tests/model_validation/MeteorScore.py +2 -2
  57. validmind/tests/model_validation/ModelPredictionResiduals.py +2 -2
  58. validmind/tests/model_validation/RegardScore.py +6 -2
  59. validmind/tests/model_validation/RegressionResidualsPlot.py +4 -3
  60. validmind/tests/model_validation/RougeScore.py +6 -5
  61. validmind/tests/model_validation/TimeSeriesPredictionWithCI.py +11 -2
  62. validmind/tests/model_validation/TokenDisparity.py +2 -2
  63. validmind/tests/model_validation/ToxicityScore.py +10 -2
  64. validmind/tests/model_validation/embeddings/ClusterDistribution.py +9 -3
  65. validmind/tests/model_validation/embeddings/CosineSimilarityComparison.py +16 -2
  66. validmind/tests/model_validation/embeddings/CosineSimilarityDistribution.py +5 -3
  67. validmind/tests/model_validation/embeddings/CosineSimilarityHeatmap.py +2 -2
  68. validmind/tests/model_validation/embeddings/DescriptiveAnalytics.py +14 -4
  69. validmind/tests/model_validation/embeddings/EmbeddingsVisualization2D.py +2 -2
  70. validmind/tests/model_validation/embeddings/EuclideanDistanceComparison.py +16 -2
  71. validmind/tests/model_validation/embeddings/EuclideanDistanceHeatmap.py +2 -2
  72. validmind/tests/model_validation/embeddings/PCAComponentsPairwisePlots.py +4 -5
  73. validmind/tests/model_validation/embeddings/StabilityAnalysisKeyword.py +4 -2
  74. validmind/tests/model_validation/embeddings/StabilityAnalysisRandomNoise.py +4 -2
  75. validmind/tests/model_validation/embeddings/StabilityAnalysisSynonyms.py +4 -2
  76. validmind/tests/model_validation/embeddings/StabilityAnalysisTranslation.py +4 -2
  77. validmind/tests/model_validation/embeddings/TSNEComponentsPairwisePlots.py +8 -6
  78. validmind/tests/model_validation/embeddings/utils.py +11 -1
  79. validmind/tests/model_validation/ragas/AnswerCorrectness.py +2 -1
  80. validmind/tests/model_validation/ragas/AspectCritic.py +11 -7
  81. validmind/tests/model_validation/ragas/ContextEntityRecall.py +2 -1
  82. validmind/tests/model_validation/ragas/ContextPrecision.py +2 -1
  83. validmind/tests/model_validation/ragas/ContextPrecisionWithoutReference.py +2 -1
  84. validmind/tests/model_validation/ragas/ContextRecall.py +2 -1
  85. validmind/tests/model_validation/ragas/Faithfulness.py +2 -1
  86. validmind/tests/model_validation/ragas/NoiseSensitivity.py +2 -1
  87. validmind/tests/model_validation/ragas/ResponseRelevancy.py +2 -1
  88. validmind/tests/model_validation/ragas/SemanticSimilarity.py +2 -1
  89. validmind/tests/model_validation/sklearn/CalibrationCurve.py +3 -2
  90. validmind/tests/model_validation/sklearn/ClassifierThresholdOptimization.py +2 -5
  91. validmind/tests/model_validation/sklearn/ClusterCosineSimilarity.py +5 -2
  92. validmind/tests/model_validation/sklearn/ConfusionMatrix.py +2 -2
  93. validmind/tests/model_validation/sklearn/FeatureImportance.py +1 -14
  94. validmind/tests/model_validation/sklearn/HyperParametersTuning.py +6 -3
  95. validmind/tests/model_validation/sklearn/KMeansClustersOptimization.py +2 -2
  96. validmind/tests/model_validation/sklearn/MinimumROCAUCScore.py +8 -4
  97. validmind/tests/model_validation/sklearn/ModelParameters.py +1 -0
  98. validmind/tests/model_validation/sklearn/OverfitDiagnosis.py +3 -3
  99. validmind/tests/model_validation/sklearn/PermutationFeatureImportance.py +2 -2
  100. validmind/tests/model_validation/sklearn/PopulationStabilityIndex.py +20 -16
  101. validmind/tests/model_validation/sklearn/PrecisionRecallCurve.py +4 -2
  102. validmind/tests/model_validation/sklearn/ROCCurve.py +1 -1
  103. validmind/tests/model_validation/sklearn/RegressionR2Square.py +7 -9
  104. validmind/tests/model_validation/sklearn/RegressionR2SquareComparison.py +1 -3
  105. validmind/tests/model_validation/sklearn/SHAPGlobalImportance.py +2 -1
  106. validmind/tests/model_validation/sklearn/ScoreProbabilityAlignment.py +2 -1
  107. validmind/tests/model_validation/sklearn/SilhouettePlot.py +5 -3
  108. validmind/tests/model_validation/sklearn/TrainingTestDegradation.py +9 -1
  109. validmind/tests/model_validation/sklearn/WeakspotsDiagnosis.py +1 -1
  110. validmind/tests/model_validation/statsmodels/CumulativePredictionProbabilities.py +11 -4
  111. validmind/tests/model_validation/statsmodels/DurbinWatsonTest.py +1 -3
  112. validmind/tests/model_validation/statsmodels/GINITable.py +7 -15
  113. validmind/tests/model_validation/statsmodels/Lilliefors.py +2 -2
  114. validmind/tests/model_validation/statsmodels/RegressionCoeffs.py +1 -1
  115. validmind/tests/model_validation/statsmodels/RegressionFeatureSignificance.py +2 -2
  116. validmind/tests/model_validation/statsmodels/RegressionModelForecastPlotLevels.py +5 -2
  117. validmind/tests/model_validation/statsmodels/RegressionModelSensitivityPlot.py +5 -2
  118. validmind/tests/model_validation/statsmodels/RegressionModelSummary.py +7 -7
  119. validmind/tests/model_validation/statsmodels/RegressionPermutationFeatureImportance.py +2 -2
  120. validmind/tests/ongoing_monitoring/CalibrationCurveDrift.py +220 -0
  121. validmind/tests/ongoing_monitoring/ClassDiscriminationDrift.py +155 -0
  122. validmind/tests/ongoing_monitoring/ClassImbalanceDrift.py +146 -0
  123. validmind/tests/ongoing_monitoring/ClassificationAccuracyDrift.py +148 -0
  124. validmind/tests/ongoing_monitoring/ConfusionMatrixDrift.py +193 -0
  125. validmind/tests/ongoing_monitoring/CumulativePredictionProbabilitiesDrift.py +178 -0
  126. validmind/tests/ongoing_monitoring/FeatureDrift.py +120 -120
  127. validmind/tests/ongoing_monitoring/PredictionAcrossEachFeature.py +18 -23
  128. validmind/tests/ongoing_monitoring/PredictionCorrelation.py +86 -44
  129. validmind/tests/ongoing_monitoring/PredictionProbabilitiesHistogramDrift.py +204 -0
  130. validmind/tests/ongoing_monitoring/PredictionQuantilesAcrossFeatures.py +98 -0
  131. validmind/tests/ongoing_monitoring/ROCCurveDrift.py +150 -0
  132. validmind/tests/ongoing_monitoring/ScoreBandsDrift.py +212 -0
  133. validmind/tests/ongoing_monitoring/ScorecardHistogramDrift.py +209 -0
  134. validmind/tests/ongoing_monitoring/TargetPredictionDistributionPlot.py +91 -13
  135. validmind/tests/prompt_validation/Bias.py +13 -9
  136. validmind/tests/prompt_validation/Clarity.py +13 -9
  137. validmind/tests/prompt_validation/Conciseness.py +13 -9
  138. validmind/tests/prompt_validation/Delimitation.py +13 -9
  139. validmind/tests/prompt_validation/NegativeInstruction.py +14 -11
  140. validmind/tests/prompt_validation/Robustness.py +6 -2
  141. validmind/tests/prompt_validation/Specificity.py +13 -9
  142. validmind/tests/run.py +6 -0
  143. validmind/utils.py +7 -8
  144. validmind/vm_models/dataset/dataset.py +0 -4
  145. {validmind-2.7.5.dist-info → validmind-2.7.7.dist-info}/METADATA +2 -3
  146. {validmind-2.7.5.dist-info → validmind-2.7.7.dist-info}/RECORD +149 -138
  147. {validmind-2.7.5.dist-info → validmind-2.7.7.dist-info}/WHEEL +1 -1
  148. {validmind-2.7.5.dist-info → validmind-2.7.7.dist-info}/LICENSE +0 -0
  149. {validmind-2.7.5.dist-info → validmind-2.7.7.dist-info}/entry_points.txt +0 -0
validmind/tests/ongoing_monitoring/ConfusionMatrixDrift.py (new file)
@@ -0,0 +1,193 @@
+ # Copyright © 2023-2024 ValidMind Inc. All rights reserved.
+ # See the LICENSE file in the root of this repository for details.
+ # SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
+
+ from typing import List
+
+ import numpy as np
+ import pandas as pd
+ from sklearn.metrics import confusion_matrix
+
+ from validmind import tags, tasks
+ from validmind.vm_models import VMDataset, VMModel
+
+
+ @tags(
+     "sklearn", "binary_classification", "multiclass_classification", "model_performance"
+ )
+ @tasks("classification", "text_classification")
+ def ConfusionMatrixDrift(
+     datasets: List[VMDataset], model: VMModel, drift_pct_threshold=20
+ ):
+     """
+     Compares confusion matrix metrics between reference and monitoring datasets.
+
+     ### Purpose
+
+     The Confusion Matrix Drift test is designed to evaluate changes in the model's error patterns
+     over time. By comparing confusion matrix elements between reference and monitoring datasets, this
+     test helps identify whether the model maintains consistent prediction behavior in production. This
+     is crucial for understanding if the model's error patterns have shifted and whether specific types
+     of misclassifications have become more prevalent.
+
+     ### Test Mechanism
+
+     This test proceeds by generating confusion matrices for both reference and monitoring datasets.
+     For binary classification, it tracks True Positives, True Negatives, False Positives, and False
+     Negatives as percentages of total predictions. For multiclass problems, it analyzes per-class
+     metrics including true positives and error rates. The test quantifies drift as percentage changes
+     in these metrics between datasets, providing detailed insight into shifting prediction patterns.
+
+     ### Signs of High Risk
+
+     - Large drifts in confusion matrix elements exceeding threshold
+     - Systematic changes in false positive or false negative rates
+     - Inconsistent changes across different classes
+     - Significant shifts in error patterns for specific classes
+     - Unexpected improvements in certain metrics
+     - Divergent trends between different types of errors
+
+     ### Strengths
+
+     - Provides detailed analysis of prediction behavior
+     - Identifies specific types of prediction changes
+     - Enables early detection of systematic errors
+     - Includes comprehensive error pattern analysis
+     - Supports both binary and multiclass problems
+     - Maintains interpretable percentage-based metrics
+
+     ### Limitations
+
+     - May be sensitive to class distribution changes
+     - Cannot identify root causes of prediction drift
+     - Requires sufficient samples for reliable comparison
+     - Limited to hard predictions (not probabilities)
+     - May not capture subtle changes in decision boundaries
+     - Complex interpretation for multiclass problems
+     """
+     # Get predictions and true values for reference dataset
+     y_pred_ref = datasets[0].y_pred(model)
+     y_true_ref = datasets[0].y.astype(y_pred_ref.dtype)
+
+     # Get predictions and true values for monitoring dataset
+     y_pred_mon = datasets[1].y_pred(model)
+     y_true_mon = datasets[1].y.astype(y_pred_mon.dtype)
+
+     # Get unique labels from reference dataset
+     labels = np.unique(y_true_ref)
+     labels = sorted(labels.tolist())
+
+     # Calculate confusion matrices
+     cm_ref = confusion_matrix(y_true_ref, y_pred_ref, labels=labels)
+     cm_mon = confusion_matrix(y_true_mon, y_pred_mon, labels=labels)
+
+     # Get total counts
+     total_ref = len(y_true_ref)
+     total_mon = len(y_true_mon)
+
+     # Create sample counts table
+     counts_data = {
+         "Dataset": ["Reference", "Monitoring"],
+         "Total": [total_ref, total_mon],
+     }
+
+     # Add per-class counts
+     for label in labels:
+         label_str = f"Class_{label}"
+         counts_data[label_str] = [
+             np.sum(y_true_ref == label),
+             np.sum(y_true_mon == label),
+         ]
+
+     counts_df = pd.DataFrame(counts_data)
+
+     # Create confusion matrix metrics
+     metrics = []
+
+     if len(labels) == 2:
+         # Binary classification
+         tn_ref, fp_ref, fn_ref, tp_ref = cm_ref.ravel()
+         tn_mon, fp_mon, fn_mon, tp_mon = cm_mon.ravel()
+
+         confusion_elements = [
+             ("True Negatives (%)", tn_ref / total_ref * 100, tn_mon / total_mon * 100),
+             ("False Positives (%)", fp_ref / total_ref * 100, fp_mon / total_mon * 100),
+             ("False Negatives (%)", fn_ref / total_ref * 100, fn_mon / total_mon * 100),
+             ("True Positives (%)", tp_ref / total_ref * 100, tp_mon / total_mon * 100),
+         ]
+
+         for name, ref_val, mon_val in confusion_elements:
+             metrics.append(
+                 {
+                     "Metric": name,
+                     "Reference": round(ref_val, 2),
+                     "Monitoring": round(mon_val, 2),
+                 }
+             )
+
+     else:
+         # Multiclass - calculate per-class metrics
+         for i, label in enumerate(labels):
+             # True Positives for this class
+             tp_ref = cm_ref[i, i]
+             tp_mon = cm_mon[i, i]
+
+             # False Positives (sum of column minus TP)
+             fp_ref = cm_ref[:, i].sum() - tp_ref
+             fp_mon = cm_mon[:, i].sum() - tp_mon
+
+             # False Negatives (sum of row minus TP)
+             fn_ref = cm_ref[i, :].sum() - tp_ref
+             fn_mon = cm_mon[i, :].sum() - tp_mon
+
+             class_metrics = [
+                 (
+                     f"True Positives_{label} (%)",
+                     tp_ref / total_ref * 100,
+                     tp_mon / total_mon * 100,
+                 ),
+                 (
+                     f"False Positives_{label} (%)",
+                     fp_ref / total_ref * 100,
+                     fp_mon / total_mon * 100,
+                 ),
+                 (
+                     f"False Negatives_{label} (%)",
+                     fn_ref / total_ref * 100,
+                     fn_mon / total_mon * 100,
+                 ),
+             ]
+
+             for name, ref_val, mon_val in class_metrics:
+                 metrics.append(
+                     {
+                         "Metric": name,
+                         "Reference": round(ref_val, 2),
+                         "Monitoring": round(mon_val, 2),
+                     }
+                 )
+
+     # Create metrics DataFrame
+     metrics_df = pd.DataFrame(metrics)
+
+     # Calculate drift percentage with direction
+     metrics_df["Drift (%)"] = (
+         (metrics_df["Monitoring"] - metrics_df["Reference"])
+         / metrics_df["Reference"].abs()
+         * 100
+     ).round(2)
+
+     # Add Pass/Fail column based on absolute drift
+     metrics_df["Pass/Fail"] = (
+         metrics_df["Drift (%)"]
+         .abs()
+         .apply(lambda x: "Pass" if x < drift_pct_threshold else "Fail")
+     )
+
+     # Calculate overall pass/fail
+     pass_fail_bool = (metrics_df["Pass/Fail"] == "Pass").all()
+
+     return (
+         {"Confusion Matrix Metrics": metrics_df, "Sample Counts": counts_df},
+         pass_fail_bool,
+     )
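To make the drift arithmetic above concrete, here is a minimal sketch of how the Drift (%) and Pass/Fail values are derived for a single metric. The numbers are illustrative only, not taken from any real run.

# Illustrative values: True Positives as a percentage of all predictions
reference_tp_pct = 45.0   # reference dataset
monitoring_tp_pct = 36.0  # monitoring dataset

# Drift (%) keeps the sign so the direction of the shift is visible
drift_pct = (monitoring_tp_pct - reference_tp_pct) / abs(reference_tp_pct) * 100
print(round(drift_pct, 2))  # -20.0

# Pass/Fail compares the absolute drift against drift_pct_threshold (default 20)
drift_pct_threshold = 20
print("Pass" if abs(drift_pct) < drift_pct_threshold else "Fail")  # Fail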
validmind/tests/ongoing_monitoring/CumulativePredictionProbabilitiesDrift.py (new file)
@@ -0,0 +1,178 @@
+ # Copyright © 2023-2024 ValidMind Inc. All rights reserved.
+ # See the LICENSE file in the root of this repository for details.
+ # SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
+
+ from typing import List
+
+ import numpy as np
+ import plotly.graph_objects as go
+ from plotly.subplots import make_subplots
+
+ from validmind import tags, tasks
+ from validmind.vm_models import VMDataset, VMModel
+
+
+ @tags("visualization", "credit_risk")
+ @tasks("classification")
+ def CumulativePredictionProbabilitiesDrift(
+     datasets: List[VMDataset],
+     model: VMModel,
+ ):
+     """
+     Compares cumulative prediction probability distributions between reference and monitoring datasets.
+
+     ### Purpose
+
+     The Cumulative Prediction Probabilities Drift test is designed to evaluate changes in the model's
+     probability predictions over time. By comparing cumulative distribution functions of predicted
+     probabilities between reference and monitoring datasets, this test helps identify whether the
+     model's probability assignments remain stable in production. This is crucial for understanding if
+     the model's risk assessment behavior has shifted and whether its probability calibration remains
+     consistent.
+
+     ### Test Mechanism
+
+     This test proceeds by generating cumulative distribution functions (CDFs) of predicted probabilities
+     for both reference and monitoring datasets. For each class, it plots the cumulative proportion of
+     predictions against probability values, enabling direct comparison of probability distributions.
+     The test visualizes both the CDFs and their differences, providing insight into how probability
+     assignments have shifted across the entire probability range.
+
+     ### Signs of High Risk
+
+     - Large gaps between reference and monitoring CDFs
+     - Systematic shifts in probability assignments
+     - Concentration of differences in specific probability ranges
+     - Changes in the shape of probability distributions
+     - Unexpected patterns in cumulative differences
+     - Significant shifts in probability thresholds
+
+     ### Strengths
+
+     - Provides comprehensive view of probability changes
+     - Identifies specific probability ranges with drift
+     - Enables visualization of distribution differences
+     - Supports analysis across multiple classes
+     - Maintains interpretable probability scale
+     - Captures subtle changes in probability assignments
+
+     ### Limitations
+
+     - Does not provide single drift metric
+     - May be complex to interpret for multiple classes
+     - Cannot suggest probability recalibration
+     - Requires visual inspection for assessment
+     - Sensitive to sample size differences
+     - May not capture class-specific calibration issues
+     """
+     # Get predictions and true values
+     y_prob_ref = datasets[0].y_prob(model)
+     df_ref = datasets[0].df.copy()
+     df_ref["probabilities"] = y_prob_ref
+
+     y_prob_mon = datasets[1].y_prob(model)
+     df_mon = datasets[1].df.copy()
+     df_mon["probabilities"] = y_prob_mon
+
+     # Get unique classes
+     classes = sorted(df_ref[datasets[0].target_column].unique())
+
+     # Define colors
+     ref_color = "rgba(31, 119, 180, 0.8)"  # Blue with 0.8 opacity
+     mon_color = "rgba(255, 127, 14, 0.8)"  # Orange with 0.8 opacity
+     diff_color = "rgba(148, 103, 189, 0.8)"  # Purple with 0.8 opacity
+
+     figures = []
+     for class_value in classes:
+         # Create figure with secondary y-axis
+         fig = make_subplots(
+             rows=2,
+             cols=1,
+             subplot_titles=[
+                 f"Cumulative Distributions - Class {class_value}",
+                 "Difference (Monitoring - Reference)",
+             ],
+             vertical_spacing=0.15,
+             shared_xaxes=True,
+         )
+
+         # Get probabilities for current class
+         ref_probs = df_ref[df_ref[datasets[0].target_column] == class_value][
+             "probabilities"
+         ]
+         mon_probs = df_mon[df_mon[datasets[1].target_column] == class_value][
+             "probabilities"
+         ]
+
+         # Calculate cumulative distributions
+         ref_sorted = np.sort(ref_probs)
+         ref_cumsum = np.arange(len(ref_sorted)) / float(len(ref_sorted))
+
+         mon_sorted = np.sort(mon_probs)
+         mon_cumsum = np.arange(len(mon_sorted)) / float(len(mon_sorted))
+
+         # Reference dataset cumulative curve
+         fig.add_trace(
+             go.Scatter(
+                 x=ref_sorted,
+                 y=ref_cumsum,
+                 mode="lines",
+                 name="Reference",
+                 line=dict(color=ref_color, width=2),
+             ),
+             row=1,
+             col=1,
+         )
+
+         # Monitoring dataset cumulative curve
+         fig.add_trace(
+             go.Scatter(
+                 x=mon_sorted,
+                 y=mon_cumsum,
+                 mode="lines",
+                 name="Monitoring",
+                 line=dict(color=mon_color, width=2),
+             ),
+             row=1,
+             col=1,
+         )
+
+         # Calculate and plot difference
+         # Interpolate monitoring values to match reference x-points
+         mon_interp = np.interp(ref_sorted, mon_sorted, mon_cumsum)
+         difference = mon_interp - ref_cumsum
+
+         fig.add_trace(
+             go.Scatter(
+                 x=ref_sorted,
+                 y=difference,
+                 mode="lines",
+                 name="Difference",
+                 line=dict(color=diff_color, width=2),
+             ),
+             row=2,
+             col=1,
+         )
+
+         # Add horizontal line at y=0 for difference plot
+         fig.add_hline(y=0, line=dict(color="grey", dash="dash"), row=2, col=1)
+
+         # Update layout
+         fig.update_layout(
+             height=600,
+             width=800,
+             showlegend=True,
+             legend=dict(yanchor="middle", y=0.9, xanchor="left", x=1.05),
+         )
+
+         # Update axes
+         fig.update_xaxes(title_text="Probability", range=[0, 1], row=2, col=1)
+         fig.update_xaxes(range=[0, 1], row=1, col=1)
+         fig.update_yaxes(
+             title_text="Cumulative Distribution", range=[0, 1], row=1, col=1
+         )
+         fig.update_yaxes(title_text="Difference", row=2, col=1)
+
+         figures.append(fig)
+
+     return tuple(figures)
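As a rough illustration of the CDF comparison performed above, the following minimal numpy sketch (toy probabilities, not taken from any dataset) mirrors the sort, cumulative-proportion, and interpolation steps used in the test.

import numpy as np

# Toy predicted probabilities for one class (illustrative only)
ref_probs = np.array([0.10, 0.25, 0.40, 0.60, 0.90])
mon_probs = np.array([0.20, 0.35, 0.55, 0.70, 0.95])

# Empirical CDFs: sort, then assign cumulative proportions
ref_sorted = np.sort(ref_probs)
ref_cdf = np.arange(len(ref_sorted)) / float(len(ref_sorted))
mon_sorted = np.sort(mon_probs)
mon_cdf = np.arange(len(mon_sorted)) / float(len(mon_sorted))

# Evaluate the monitoring CDF at the reference x-points and take the gap
mon_interp = np.interp(ref_sorted, mon_sorted, mon_cdf)
difference = mon_interp - ref_cdf  # negative values mean the monitoring CDF lags the reference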
validmind/tests/ongoing_monitoring/FeatureDrift.py
@@ -2,18 +2,100 @@
  # See the LICENSE file in the root of this repository for details.
  # SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial

-
- import matplotlib.pyplot as plt
  import numpy as np
  import pandas as pd
+ import plotly.graph_objects as go

  from validmind import tags, tasks


+ def calculate_psi_score(actual, expected):
+     """Calculate PSI score for a single bucket."""
+     return (actual - expected) * np.log((actual + 1e-6) / (expected + 1e-6))
+
+
+ def calculate_feature_distributions(
+     reference_data, monitoring_data, feature_columns, bins
+ ):
+     """Calculate population distributions for each feature."""
+     # Calculate quantiles from reference data
+     quantiles = reference_data[feature_columns].quantile(
+         bins, method="single", interpolation="nearest"
+     )
+
+     distributions = {}
+     for dataset_name, data in [
+         ("reference", reference_data),
+         ("monitoring", monitoring_data),
+     ]:
+         for feature in feature_columns:
+             for bin_idx, threshold in enumerate(quantiles[feature]):
+                 if bin_idx == 0:
+                     mask = data[feature] < threshold
+                 else:
+                     prev_threshold = quantiles[feature][bins[bin_idx - 1]]
+                     mask = (data[feature] >= prev_threshold) & (
+                         data[feature] < threshold
+                     )
+
+                 count = mask.sum()
+                 proportion = count / len(data)
+                 distributions[(dataset_name, feature, bins[bin_idx])] = proportion
+
+     return distributions
+
+
+ def create_distribution_plot(feature_name, reference_dist, monitoring_dist, bins):
+     """Create population distribution plot for a feature."""
+     fig = go.Figure()
+
+     # Add reference distribution
+     fig.add_trace(
+         go.Bar(
+             x=list(range(len(bins))),
+             y=reference_dist,
+             name="Reference",
+             marker_color="blue",
+             marker_line_color="black",
+             marker_line_width=1,
+             opacity=0.75,
+         )
+     )
+
+     # Add monitoring distribution
+     fig.add_trace(
+         go.Bar(
+             x=list(range(len(bins))),
+             y=monitoring_dist,
+             name="Monitoring",
+             marker_color="green",
+             marker_line_color="black",
+             marker_line_width=1,
+             opacity=0.75,
+         )
+     )
+
+     fig.update_layout(
+         title=f"Population Distribution: {feature_name}",
+         xaxis_title="Bin",
+         yaxis_title="Population %",
+         barmode="group",
+         template="plotly_white",
+         showlegend=True,
+         width=800,
+         height=400,
+     )
+
+     return fig
+
+
  @tags("visualization")
  @tasks("monitoring")
  def FeatureDrift(
-     datasets, bins=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9], feature_columns=None
+     datasets,
+     bins=[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9],
+     feature_columns=None,
+     psi_threshold=0.2,
  ):
      """
      Evaluates changes in feature distribution over time to identify potential model drift.
@@ -57,130 +139,48 @@ def FeatureDrift(
      - PSI score interpretation can be overly simplistic for complex datasets.
      """

-     # Feature columns for both datasets should be the same if not given
-     default_feature_columns = datasets[0].feature_columns
-     feature_columns = feature_columns or default_feature_columns
+     # Get feature columns
+     feature_columns = feature_columns or datasets[0].feature_columns

-     x_train_df = datasets[0].x_df()
-     x_test_df = datasets[1].x_df()
+     # Get data
+     reference_data = datasets[0].df
+     monitoring_data = datasets[1].df

-     quantiles_train = x_train_df[feature_columns].quantile(
-         bins, method="single", interpolation="nearest"
+     # Calculate distributions
+     distributions = calculate_feature_distributions(
+         reference_data, monitoring_data, feature_columns, bins
      )
-     PSI_QUANTILES = quantiles_train.to_dict()
-
-     PSI_BUCKET_FRAC, col, n = get_psi_buckets(
-         x_test_df, x_train_df, feature_columns, bins, PSI_QUANTILES
-     )
-
-     def nest(d: dict) -> dict:
-         result = {}
-         for key, value in d.items():
-             target = result
-             for k in key[:-1]:  # traverse all keys but the last
-                 target = target.setdefault(k, {})
-             target[key[-1]] = value
-         return result
-
-     PSI_BUCKET_FRAC = nest(PSI_BUCKET_FRAC)

-     PSI_SCORES = {}
-     for col in feature_columns:
+     # Calculate PSI scores
+     psi_scores = {}
+     for feature in feature_columns:
          psi = 0
-         for n in bins:
-             actual = PSI_BUCKET_FRAC["test"][col][n]
-             expected = PSI_BUCKET_FRAC["train"][col][n]
-             psi_of_bucket = (actual - expected) * np.log(
-                 (actual + 1e-6) / (expected + 1e-6)
-             )
-             psi += psi_of_bucket
-         PSI_SCORES[col] = psi
-
-     psi_df = pd.DataFrame(list(PSI_SCORES.items()), columns=["Features", "PSI Score"])
+         for bin_val in bins:
+             reference_prop = distributions[("reference", feature, bin_val)]
+             monitoring_prop = distributions[("monitoring", feature, bin_val)]
+             psi += calculate_psi_score(monitoring_prop, reference_prop)
+         psi_scores[feature] = psi
+
+     # Create PSI score dataframe
+     psi_df = pd.DataFrame(list(psi_scores.items()), columns=["Feature", "PSI Score"])
+
+     # Add Pass/Fail column
+     psi_df["Pass/Fail"] = psi_df["PSI Score"].apply(
+         lambda x: "Pass" if x < psi_threshold else "Fail"
+     )

+     # Sort by PSI Score
      psi_df.sort_values(by=["PSI Score"], inplace=True, ascending=False)

-     psi_table = [
-         {"Features": values["Features"], "PSI Score": values["PSI Score"]}
-         for i, values in enumerate(psi_df.to_dict(orient="records"))
-     ]
-
-     save_fig = plot_hist(PSI_BUCKET_FRAC, bins)
-
-     final_psi = pd.DataFrame(psi_table)
-
-     return (final_psi, *save_fig)
-
-
- def get_psi_buckets(x_test_df, x_train_df, feature_columns, bins, PSI_QUANTILES):
-     DATA = {"test": x_test_df, "train": x_train_df}
-     PSI_BUCKET_FRAC = {}
-     for table in DATA.keys():
-         total_count = DATA[table].shape[0]
-         for col in feature_columns:
-             count_sum = 0
-             for n in bins:
-                 if n == 0:
-                     bucket_count = (DATA[table][col] < PSI_QUANTILES[col][n]).sum()
-                 elif n < 9:
-                     bucket_count = (
-                         total_count
-                         - count_sum
-                         - ((DATA[table][col] >= PSI_QUANTILES[col][n]).sum())
-                     )
-                 elif n == 9:
-                     bucket_count = total_count - count_sum
-                 count_sum += bucket_count
-                 PSI_BUCKET_FRAC[table, col, n] = bucket_count / total_count
-     return PSI_BUCKET_FRAC, col, n
-
-
- def plot_hist(PSI_BUCKET_FRAC, bins):
-     bin_table_psi = pd.DataFrame(PSI_BUCKET_FRAC)
-     save_fig = []
-     for i in range(len(bin_table_psi)):
+     # Create distribution plots
+     figures = []
+     for feature in feature_columns:
+         reference_dist = [distributions[("reference", feature, b)] for b in bins]
+         monitoring_dist = [distributions[("monitoring", feature, b)] for b in bins]
+         fig = create_distribution_plot(feature, reference_dist, monitoring_dist, bins)
+         figures.append(fig)

-         x = pd.DataFrame(
-             bin_table_psi.iloc[i]["test"].items(),
-             columns=["Bin", "Population % Reference"],
-         )
-         y = pd.DataFrame(
-             bin_table_psi.iloc[i]["train"].items(),
-             columns=["Bin", "Population % Monitoring"],
-         )
-         xy = x.merge(y, on="Bin")
-         xy.index = xy["Bin"]
-         xy = xy.drop(columns="Bin", axis=1)
-         feature_name = bin_table_psi.index[i]
-
-         n = len(bins)
-         r = np.arange(n)
-         width = 0.25
-
-         fig = plt.figure()
-
-         plt.bar(
-             r,
-             xy["Population % Reference"],
-             color="b",
-             width=width,
-             edgecolor="black",
-             label="Reference {0}".format(feature_name),
-         )
-         plt.bar(
-             r + width,
-             xy["Population % Monitoring"],
-             color="g",
-             width=width,
-             edgecolor="black",
-             label="Monitoring {0}".format(feature_name),
-         )
+     # Calculate overall pass/fail
+     pass_fail_bool = (psi_df["Pass/Fail"] == "Pass").all()

-         plt.xlabel("Bin")
-         plt.ylabel("Population %")
-         plt.title("Histogram of Population Differences {0}".format(feature_name))
-         plt.legend()
-         plt.tight_layout()
-         plt.close()
-         save_fig.append(fig)
-     return save_fig
+     return ({"PSI Scores": psi_df}, *figures, pass_fail_bool)
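For reference, the PSI summed above follows the standard form PSI = Σ (actual_i − expected_i) · ln(actual_i / expected_i) over the population shares of each bin, with the 1e-6 terms guarding against empty buckets; the default psi_threshold of 0.2 lines up with the commonly cited rule of thumb that PSI above 0.2 indicates a significant population shift. A tiny hand-calculable sketch with made-up bin shares:

import numpy as np

# Made-up population shares per bin (each set sums to 1.0) - illustrative only
expected = [0.25, 0.25, 0.25, 0.25]   # reference
actual = [0.10, 0.20, 0.30, 0.40]     # monitoring

psi = sum(
    (a - e) * np.log((a + 1e-6) / (e + 1e-6))
    for a, e in zip(actual, expected)
)
print(round(psi, 3))  # ~0.228, which would be "Fail" against the default psi_threshold of 0.2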
validmind/tests/ongoing_monitoring/PredictionAcrossEachFeature.py
@@ -53,30 +53,25 @@ def PredictionAcrossEachFeature(datasets, model):
      observed during the training of the model.
      """

-     df_reference = datasets[0]._df
-     df_monitoring = datasets[1]._df
+     y_prob_reference = datasets[0].y_prob(model)
+     y_prob_monitoring = datasets[1].y_prob(model)

      figures_to_save = []
-     for column in df_reference:
-         prediction_prob_column = f"{model.input_id}_probabilities"
-         prediction_column = f"{model.input_id}_prediction"
-         if column == prediction_prob_column or column == prediction_column:
-             pass
-         else:
-             fig, axs = plt.subplots(1, 2, figsize=(20, 10), sharey="row")
-
-             ax1, ax2 = axs
-
-             ax1.scatter(df_reference[column], df_reference[prediction_prob_column])
-             ax2.scatter(df_monitoring[column], df_monitoring[prediction_prob_column])
-
-             ax1.set_title("Reference")
-             ax1.set_xlabel(column)
-             ax1.set_ylabel("Prediction Value")
-
-             ax2.set_title("Monitoring")
-             ax2.set_xlabel(column)
-             figures_to_save.append(fig)
-             plt.close()
+     for column in datasets[0].feature_columns:
+         fig, axs = plt.subplots(1, 2, figsize=(20, 10), sharey="row")
+
+         ax1, ax2 = axs
+
+         ax1.scatter(datasets[0].df[column], y_prob_reference)
+         ax2.scatter(datasets[1].df[column], y_prob_monitoring)
+
+         ax1.set_title("Reference")
+         ax1.set_xlabel(column)
+         ax1.set_ylabel("Prediction Value")
+
+         ax2.set_title("Monitoring")
+         ax2.set_xlabel(column)
+         figures_to_save.append(fig)
+         plt.close()

      return tuple(figures_to_save)
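A minimal usage sketch for these ongoing-monitoring tests, assuming the usual validmind pattern of invoking tests by ID through validmind.tests.run_test with "inputs" and "params" keyword arguments; the dataset and model variable names below are placeholders for previously initialized VMDataset/VMModel objects, and exact argument names may vary between versions.

import validmind as vm

# Placeholders: vm_reference_ds and vm_monitoring_ds are VMDataset objects with
# predictions assigned for vm_model (a VMModel object).
result = vm.tests.run_test(
    "validmind.ongoing_monitoring.PredictionAcrossEachFeature",
    inputs={"datasets": [vm_reference_ds, vm_monitoring_ds], "model": vm_model},
)
# Logging the result to the ValidMind platform is typically done via result.log(),
# if available in your version.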