validmind 2.5.25__py3-none-any.whl → 2.6.7__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to a supported public registry. It is provided for informational purposes only.
Files changed (198)
  1. validmind/__init__.py +8 -17
  2. validmind/__version__.py +1 -1
  3. validmind/ai/test_descriptions.py +66 -85
  4. validmind/ai/test_result_description/context.py +2 -2
  5. validmind/ai/utils.py +26 -1
  6. validmind/api_client.py +43 -79
  7. validmind/client.py +5 -7
  8. validmind/client_config.py +1 -1
  9. validmind/datasets/__init__.py +1 -1
  10. validmind/datasets/classification/customer_churn.py +7 -5
  11. validmind/datasets/nlp/__init__.py +2 -2
  12. validmind/errors.py +6 -10
  13. validmind/html_templates/content_blocks.py +18 -16
  14. validmind/logging.py +21 -16
  15. validmind/tests/__init__.py +28 -5
  16. validmind/tests/__types__.py +186 -170
  17. validmind/tests/_store.py +7 -21
  18. validmind/tests/comparison.py +362 -0
  19. validmind/tests/data_validation/ACFandPACFPlot.py +44 -73
  20. validmind/tests/data_validation/ADF.py +49 -83
  21. validmind/tests/data_validation/AutoAR.py +59 -96
  22. validmind/tests/data_validation/AutoMA.py +59 -96
  23. validmind/tests/data_validation/AutoStationarity.py +66 -114
  24. validmind/tests/data_validation/ClassImbalance.py +48 -117
  25. validmind/tests/data_validation/DatasetDescription.py +180 -209
  26. validmind/tests/data_validation/DatasetSplit.py +50 -75
  27. validmind/tests/data_validation/DescriptiveStatistics.py +59 -85
  28. validmind/tests/data_validation/{DFGLSArch.py → DickeyFullerGLS.py} +44 -76
  29. validmind/tests/data_validation/Duplicates.py +21 -90
  30. validmind/tests/data_validation/EngleGrangerCoint.py +53 -75
  31. validmind/tests/data_validation/HighCardinality.py +32 -80
  32. validmind/tests/data_validation/HighPearsonCorrelation.py +29 -97
  33. validmind/tests/data_validation/IQROutliersBarPlot.py +63 -94
  34. validmind/tests/data_validation/IQROutliersTable.py +40 -80
  35. validmind/tests/data_validation/IsolationForestOutliers.py +41 -63
  36. validmind/tests/data_validation/KPSS.py +33 -81
  37. validmind/tests/data_validation/LaggedCorrelationHeatmap.py +47 -95
  38. validmind/tests/data_validation/MissingValues.py +17 -58
  39. validmind/tests/data_validation/MissingValuesBarPlot.py +61 -87
  40. validmind/tests/data_validation/PhillipsPerronArch.py +56 -79
  41. validmind/tests/data_validation/RollingStatsPlot.py +50 -81
  42. validmind/tests/data_validation/SeasonalDecompose.py +102 -184
  43. validmind/tests/data_validation/Skewness.py +27 -64
  44. validmind/tests/data_validation/SpreadPlot.py +34 -57
  45. validmind/tests/data_validation/TabularCategoricalBarPlots.py +46 -65
  46. validmind/tests/data_validation/TabularDateTimeHistograms.py +23 -45
  47. validmind/tests/data_validation/TabularNumericalHistograms.py +27 -46
  48. validmind/tests/data_validation/TargetRateBarPlots.py +54 -93
  49. validmind/tests/data_validation/TimeSeriesFrequency.py +48 -133
  50. validmind/tests/data_validation/TimeSeriesHistogram.py +24 -3
  51. validmind/tests/data_validation/TimeSeriesLinePlot.py +29 -47
  52. validmind/tests/data_validation/TimeSeriesMissingValues.py +59 -135
  53. validmind/tests/data_validation/TimeSeriesOutliers.py +54 -171
  54. validmind/tests/data_validation/TooManyZeroValues.py +21 -70
  55. validmind/tests/data_validation/UniqueRows.py +23 -62
  56. validmind/tests/data_validation/WOEBinPlots.py +83 -109
  57. validmind/tests/data_validation/WOEBinTable.py +28 -69
  58. validmind/tests/data_validation/ZivotAndrewsArch.py +33 -75
  59. validmind/tests/data_validation/nlp/CommonWords.py +49 -57
  60. validmind/tests/data_validation/nlp/Hashtags.py +27 -49
  61. validmind/tests/data_validation/nlp/LanguageDetection.py +7 -13
  62. validmind/tests/data_validation/nlp/Mentions.py +32 -63
  63. validmind/tests/data_validation/nlp/PolarityAndSubjectivity.py +89 -14
  64. validmind/tests/data_validation/nlp/Punctuations.py +63 -47
  65. validmind/tests/data_validation/nlp/Sentiment.py +4 -0
  66. validmind/tests/data_validation/nlp/StopWords.py +62 -91
  67. validmind/tests/data_validation/nlp/TextDescription.py +116 -159
  68. validmind/tests/data_validation/nlp/Toxicity.py +12 -4
  69. validmind/tests/decorator.py +33 -242
  70. validmind/tests/load.py +212 -153
  71. validmind/tests/model_validation/BertScore.py +13 -7
  72. validmind/tests/model_validation/BleuScore.py +4 -0
  73. validmind/tests/model_validation/ClusterSizeDistribution.py +24 -47
  74. validmind/tests/model_validation/ContextualRecall.py +3 -0
  75. validmind/tests/model_validation/FeaturesAUC.py +43 -74
  76. validmind/tests/model_validation/MeteorScore.py +3 -0
  77. validmind/tests/model_validation/RegardScore.py +5 -1
  78. validmind/tests/model_validation/RegressionResidualsPlot.py +54 -75
  79. validmind/tests/model_validation/embeddings/ClusterDistribution.py +10 -33
  80. validmind/tests/model_validation/embeddings/CosineSimilarityDistribution.py +11 -29
  81. validmind/tests/model_validation/embeddings/DescriptiveAnalytics.py +19 -31
  82. validmind/tests/model_validation/embeddings/EmbeddingsVisualization2D.py +40 -49
  83. validmind/tests/model_validation/embeddings/StabilityAnalysisKeyword.py +29 -15
  84. validmind/tests/model_validation/embeddings/StabilityAnalysisRandomNoise.py +25 -11
  85. validmind/tests/model_validation/embeddings/StabilityAnalysisSynonyms.py +28 -13
  86. validmind/tests/model_validation/embeddings/StabilityAnalysisTranslation.py +67 -38
  87. validmind/tests/model_validation/embeddings/utils.py +53 -0
  88. validmind/tests/model_validation/ragas/AnswerCorrectness.py +37 -32
  89. validmind/tests/model_validation/ragas/{AspectCritique.py → AspectCritic.py} +33 -27
  90. validmind/tests/model_validation/ragas/ContextEntityRecall.py +44 -41
  91. validmind/tests/model_validation/ragas/ContextPrecision.py +40 -35
  92. validmind/tests/model_validation/ragas/ContextPrecisionWithoutReference.py +133 -0
  93. validmind/tests/model_validation/ragas/ContextRecall.py +40 -35
  94. validmind/tests/model_validation/ragas/Faithfulness.py +42 -30
  95. validmind/tests/model_validation/ragas/NoiseSensitivity.py +59 -35
  96. validmind/tests/model_validation/ragas/{AnswerRelevance.py → ResponseRelevancy.py} +52 -41
  97. validmind/tests/model_validation/ragas/{AnswerSimilarity.py → SemanticSimilarity.py} +39 -34
  98. validmind/tests/model_validation/sklearn/AdjustedMutualInformation.py +13 -16
  99. validmind/tests/model_validation/sklearn/AdjustedRandIndex.py +13 -16
  100. validmind/tests/model_validation/sklearn/ClassifierPerformance.py +51 -89
  101. validmind/tests/model_validation/sklearn/ClusterCosineSimilarity.py +31 -61
  102. validmind/tests/model_validation/sklearn/ClusterPerformanceMetrics.py +118 -83
  103. validmind/tests/model_validation/sklearn/CompletenessScore.py +13 -16
  104. validmind/tests/model_validation/sklearn/ConfusionMatrix.py +62 -94
  105. validmind/tests/model_validation/sklearn/FeatureImportance.py +7 -8
  106. validmind/tests/model_validation/sklearn/FowlkesMallowsScore.py +12 -15
  107. validmind/tests/model_validation/sklearn/HomogeneityScore.py +12 -15
  108. validmind/tests/model_validation/sklearn/HyperParametersTuning.py +23 -53
  109. validmind/tests/model_validation/sklearn/KMeansClustersOptimization.py +60 -74
  110. validmind/tests/model_validation/sklearn/MinimumAccuracy.py +16 -84
  111. validmind/tests/model_validation/sklearn/MinimumF1Score.py +22 -72
  112. validmind/tests/model_validation/sklearn/MinimumROCAUCScore.py +29 -78
  113. validmind/tests/model_validation/sklearn/ModelsPerformanceComparison.py +52 -82
  114. validmind/tests/model_validation/sklearn/OverfitDiagnosis.py +51 -145
  115. validmind/tests/model_validation/sklearn/PermutationFeatureImportance.py +60 -78
  116. validmind/tests/model_validation/sklearn/PopulationStabilityIndex.py +130 -172
  117. validmind/tests/model_validation/sklearn/PrecisionRecallCurve.py +26 -55
  118. validmind/tests/model_validation/sklearn/ROCCurve.py +43 -77
  119. validmind/tests/model_validation/sklearn/RegressionPerformance.py +41 -94
  120. validmind/tests/model_validation/sklearn/RobustnessDiagnosis.py +47 -136
  121. validmind/tests/model_validation/sklearn/SHAPGlobalImportance.py +164 -208
  122. validmind/tests/model_validation/sklearn/SilhouettePlot.py +54 -99
  123. validmind/tests/model_validation/sklearn/TrainingTestDegradation.py +50 -124
  124. validmind/tests/model_validation/sklearn/VMeasure.py +12 -15
  125. validmind/tests/model_validation/sklearn/WeakspotsDiagnosis.py +225 -281
  126. validmind/tests/model_validation/statsmodels/AutoARIMA.py +40 -45
  127. validmind/tests/model_validation/statsmodels/KolmogorovSmirnov.py +22 -47
  128. validmind/tests/model_validation/statsmodels/Lilliefors.py +17 -28
  129. validmind/tests/model_validation/statsmodels/RegressionFeatureSignificance.py +37 -81
  130. validmind/tests/model_validation/statsmodels/RegressionModelForecastPlot.py +37 -105
  131. validmind/tests/model_validation/statsmodels/RegressionModelForecastPlotLevels.py +62 -166
  132. validmind/tests/model_validation/statsmodels/RegressionModelSensitivityPlot.py +57 -119
  133. validmind/tests/model_validation/statsmodels/RegressionModelSummary.py +20 -57
  134. validmind/tests/model_validation/statsmodels/RegressionPermutationFeatureImportance.py +47 -80
  135. validmind/tests/ongoing_monitoring/PredictionCorrelation.py +2 -0
  136. validmind/tests/ongoing_monitoring/TargetPredictionDistributionPlot.py +4 -2
  137. validmind/tests/output.py +120 -0
  138. validmind/tests/prompt_validation/Bias.py +55 -98
  139. validmind/tests/prompt_validation/Clarity.py +56 -99
  140. validmind/tests/prompt_validation/Conciseness.py +63 -101
  141. validmind/tests/prompt_validation/Delimitation.py +48 -89
  142. validmind/tests/prompt_validation/NegativeInstruction.py +62 -96
  143. validmind/tests/prompt_validation/Robustness.py +80 -121
  144. validmind/tests/prompt_validation/Specificity.py +61 -95
  145. validmind/tests/prompt_validation/ai_powered_test.py +2 -2
  146. validmind/tests/run.py +314 -496
  147. validmind/tests/test_providers.py +109 -79
  148. validmind/tests/utils.py +91 -0
  149. validmind/unit_metrics/__init__.py +16 -155
  150. validmind/unit_metrics/classification/F1.py +1 -0
  151. validmind/unit_metrics/classification/Precision.py +1 -0
  152. validmind/unit_metrics/classification/ROC_AUC.py +1 -0
  153. validmind/unit_metrics/classification/Recall.py +1 -0
  154. validmind/unit_metrics/regression/AdjustedRSquaredScore.py +1 -0
  155. validmind/unit_metrics/regression/GiniCoefficient.py +1 -0
  156. validmind/unit_metrics/regression/HuberLoss.py +1 -0
  157. validmind/unit_metrics/regression/KolmogorovSmirnovStatistic.py +1 -0
  158. validmind/unit_metrics/regression/MeanAbsoluteError.py +1 -0
  159. validmind/unit_metrics/regression/MeanAbsolutePercentageError.py +1 -0
  160. validmind/unit_metrics/regression/MeanBiasDeviation.py +1 -0
  161. validmind/unit_metrics/regression/MeanSquaredError.py +1 -0
  162. validmind/unit_metrics/regression/QuantileLoss.py +1 -0
  163. validmind/unit_metrics/regression/RSquaredScore.py +2 -1
  164. validmind/unit_metrics/regression/RootMeanSquaredError.py +1 -0
  165. validmind/utils.py +66 -17
  166. validmind/vm_models/__init__.py +2 -17
  167. validmind/vm_models/dataset/dataset.py +31 -4
  168. validmind/vm_models/figure.py +7 -37
  169. validmind/vm_models/model.py +3 -0
  170. validmind/vm_models/result/__init__.py +7 -0
  171. validmind/vm_models/result/result.jinja +21 -0
  172. validmind/vm_models/result/result.py +337 -0
  173. validmind/vm_models/result/utils.py +160 -0
  174. validmind/vm_models/test_suite/runner.py +16 -54
  175. validmind/vm_models/test_suite/summary.py +3 -3
  176. validmind/vm_models/test_suite/test.py +43 -77
  177. validmind/vm_models/test_suite/test_suite.py +8 -40
  178. validmind-2.6.7.dist-info/METADATA +137 -0
  179. {validmind-2.5.25.dist-info → validmind-2.6.7.dist-info}/RECORD +182 -189
  180. validmind/tests/data_validation/AutoSeasonality.py +0 -190
  181. validmind/tests/metadata.py +0 -59
  182. validmind/tests/model_validation/embeddings/StabilityAnalysis.py +0 -176
  183. validmind/tests/model_validation/ragas/ContextUtilization.py +0 -161
  184. validmind/tests/model_validation/sklearn/ClusterPerformance.py +0 -80
  185. validmind/unit_metrics/composite.py +0 -238
  186. validmind/vm_models/test/metric.py +0 -98
  187. validmind/vm_models/test/metric_result.py +0 -61
  188. validmind/vm_models/test/output_template.py +0 -55
  189. validmind/vm_models/test/result_summary.py +0 -76
  190. validmind/vm_models/test/result_wrapper.py +0 -488
  191. validmind/vm_models/test/test.py +0 -103
  192. validmind/vm_models/test/threshold_test.py +0 -106
  193. validmind/vm_models/test/threshold_test_result.py +0 -75
  194. validmind/vm_models/test_context.py +0 -259
  195. validmind-2.5.25.dist-info/METADATA +0 -118
  196. {validmind-2.5.25.dist-info → validmind-2.6.7.dist-info}/LICENSE +0 -0
  197. {validmind-2.5.25.dist-info → validmind-2.6.7.dist-info}/WHEEL +0 -0
  198. {validmind-2.5.25.dist-info → validmind-2.6.7.dist-info}/entry_points.txt +0 -0
validmind/tests/model_validation/sklearn/PermutationFeatureImportance.py

@@ -2,20 +2,33 @@
  # See the LICENSE file in the root of this repository for details.
  # SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial

- from dataclasses import dataclass
+ from typing import Union

  import plotly.graph_objects as go
  from sklearn.inspection import permutation_importance

+ from validmind import tags, tasks
  from validmind.errors import SkipTestError
  from validmind.logging import get_logger
- from validmind.vm_models import Figure, Metric
+ from validmind.vm_models import VMDataset, VMModel

  logger = get_logger(__name__)


- @dataclass
- class PermutationFeatureImportance(Metric):
+ @tags(
+     "sklearn",
+     "binary_classification",
+     "multiclass_classification",
+     "feature_importance",
+     "visualization",
+ )
+ @tasks("classification", "text_classification")
+ def PermutationFeatureImportance(
+     model: VMModel,
+     dataset: VMDataset,
+     fontsize: Union[int, None] = None,
+     figure_height: Union[int, None] = None,
+ ):
      """
      Assesses the significance of each feature in a model by evaluating the impact on model performance when feature
      values are randomly rearranged.
@@ -55,78 +68,47 @@ class PermutationFeatureImportance(Metric):
      allocate importance to one and not the other.
      - Cannot interact with certain libraries like statsmodels, pytorch, catboost, etc., thus limiting its applicability.
      """
-
-     name = "pfi"
-     required_inputs = ["model", "dataset"]
-     default_params = {
-         "fontsize": None,
-         "figure_height": 1000,
-     }
-     tasks = ["classification", "text_classification"]
-     tags = [
-         "sklearn",
-         "binary_classification",
-         "multiclass_classification",
-         "feature_importance",
-         "visualization",
-     ]
-
-     def run(self):
-         x = self.inputs.dataset.x_df()
-         y = self.inputs.dataset.y_df()
-
-         if self.inputs.model.library in [
-             "statsmodels",
-             "pytorch",
-             "catboost",
-             "transformers",
-             "R",
-         ]:
-             raise SkipTestError(f"Skipping PFI for {self.inputs.model.library} models")
-
-         pfi_values = permutation_importance(
-             self.inputs.model.model,
-             x,
-             y,
-             random_state=0,
-             n_jobs=-2,
-         )
-
-         pfi = {}
-         for i, column in enumerate(x.columns):
-             pfi[column] = [pfi_values["importances_mean"][i]], [
-                 pfi_values["importances_std"][i]
-             ]
-
-         sorted_idx = pfi_values.importances_mean.argsort()
-
-         fig = go.Figure()
-         fig.add_trace(
-             go.Bar(
-                 y=x.columns[sorted_idx],
-                 x=pfi_values.importances[sorted_idx].mean(axis=1).T,
-                 orientation="h",
-             )
-         )
-         fig.update_layout(
-             title_text="Permutation Importances",
-             yaxis=dict(
-                 tickmode="linear",  # set tick mode to linear
-                 dtick=1,  # set interval between ticks
-                 tickfont=dict(
-                     size=self.params["fontsize"]
-                 ),  # set the tick label font size
-             ),
-             height=self.params["figure_height"],  # use figure_height parameter here
-         )
-
-         return self.cache_results(
-             metric_value=pfi,
-             figures=[
-                 Figure(
-                     for_object=self,
-                     key=f"pfi_{self.inputs.dataset.input_id}_{self.inputs.model.input_id}",
-                     figure=fig,
-                 ),
-             ],
+     if model.library in [
+         "statsmodels",
+         "pytorch",
+         "catboost",
+         "transformers",
+         "R",
+     ]:
+         raise SkipTestError(f"Skipping PFI for {model.library} models")
+
+     pfi_values = permutation_importance(
+         estimator=model.model,
+         X=dataset.x_df(),
+         y=dataset.y_df(),
+         random_state=0,
+         n_jobs=-2,
+     )
+
+     pfi = {}
+     for i, column in enumerate(dataset.feature_columns):
+         pfi[column] = [pfi_values["importances_mean"][i]], [
+             pfi_values["importances_std"][i]
+         ]
+
+     sorted_idx = pfi_values.importances_mean.argsort()
+
+     fig = go.Figure()
+     fig.add_trace(
+         go.Bar(
+             y=[dataset.feature_columns[i] for i in sorted_idx],
+             x=pfi_values.importances[sorted_idx].mean(axis=1).T,
+             orientation="h",
          )
+     )
+     fig.update_layout(
+         title_text="Permutation Importances",
+         yaxis=dict(
+             tickmode="linear",
+             dtick=1,
+             tickfont=dict(size=fontsize),
+         ),
+         height=figure_height,
+     )
+
+     return fig
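The rewrite above replaces the class-based Metric with a plain function decorated with @tags/@tasks that returns the Plotly figure directly. For readers who want to see the underlying sklearn call on its own, here is a minimal standalone sketch that is not part of the package; the breast-cancer toy dataset and RandomForestClassifier are illustrative stand-ins for a real model and dataset.

    # Standalone sketch (outside validmind): the same sklearn permutation_importance
    # call pattern the rewritten test uses, applied to a toy model.
    import plotly.graph_objects as go
    from sklearn.datasets import load_breast_cancer
    from sklearn.ensemble import RandomForestClassifier
    from sklearn.inspection import permutation_importance

    X, y = load_breast_cancer(return_X_y=True, as_frame=True)
    clf = RandomForestClassifier(n_estimators=50, random_state=0).fit(X, y)

    result = permutation_importance(estimator=clf, X=X, y=y, random_state=0, n_jobs=-2)

    sorted_idx = result.importances_mean.argsort()
    fig = go.Figure(
        go.Bar(
            y=[X.columns[i] for i in sorted_idx],           # features, least to most important
            x=result.importances[sorted_idx].mean(axis=1),  # mean importance over shuffles
            orientation="h",
        )
    )
    fig.update_layout(title_text="Permutation Importances")
    fig.show()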
validmind/tests/model_validation/sklearn/PopulationStabilityIndex.py

@@ -2,26 +2,87 @@
  # See the LICENSE file in the root of this repository for details.
  # SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial

- from dataclasses import dataclass
+ from typing import List

  import numpy as np
  import pandas as pd
  import plotly.graph_objects as go

+ from validmind import tags, tasks
+ from validmind.errors import SkipTestError
  from validmind.logging import get_logger
- from validmind.vm_models import (
-     Figure,
-     Metric,
-     ResultSummary,
-     ResultTable,
-     ResultTableMetadata,
- )
+ from validmind.vm_models import VMDataset, VMModel

  logger = get_logger(__name__)


- @dataclass
- class PopulationStabilityIndex(Metric):
+ def calculate_psi(score_initial, score_new, num_bins=10, mode="fixed"):
+     """
+     Taken from:
+     https://towardsdatascience.com/checking-model-stability-and-population-shift-with-psi-and-csi-6d12af008783
+     """
+     eps = 1e-4
+
+     # Sort the data
+     score_initial.sort()
+     score_new.sort()
+
+     # Prepare the bins
+     min_val = min(min(score_initial), min(score_new))
+     max_val = max(max(score_initial), max(score_new))
+     if mode == "fixed":
+         bins = [
+             min_val + (max_val - min_val) * (i) / num_bins for i in range(num_bins + 1)
+         ]
+     elif mode == "quantile":
+         bins = pd.qcut(score_initial, q=num_bins, retbins=True)[
+             1
+         ]  # Create the quantiles based on the initial population
+     else:
+         raise ValueError(
+             f"Mode '{mode}' not recognized. Allowed options are 'fixed' and 'quantile'"
+         )
+     bins[0] = min_val - eps  # Correct the lower boundary
+     bins[-1] = max_val + eps  # Correct the higher boundary
+
+     # Bucketize the initial population and count the sample inside each bucket
+     bins_initial = pd.cut(score_initial, bins=bins, labels=range(1, num_bins + 1))
+     df_initial = pd.DataFrame({"initial": score_initial, "bin": bins_initial})
+     grp_initial = df_initial.groupby("bin").count()
+     grp_initial["percent_initial"] = grp_initial["initial"] / sum(
+         grp_initial["initial"]
+     )
+
+     # Bucketize the new population and count the sample inside each bucket
+     bins_new = pd.cut(score_new, bins=bins, labels=range(1, num_bins + 1))
+     df_new = pd.DataFrame({"new": score_new, "bin": bins_new})
+     grp_new = df_new.groupby("bin").count()
+     grp_new["percent_new"] = grp_new["new"] / sum(grp_new["new"])
+
+     # Compare the bins to calculate PSI
+     psi_df = grp_initial.join(grp_new, on="bin", how="inner")
+
+     # Add a small value for when the percent is zero
+     psi_df["percent_initial"] = psi_df["percent_initial"].apply(
+         lambda x: eps if x == 0 else x
+     )
+     psi_df["percent_new"] = psi_df["percent_new"].apply(lambda x: eps if x == 0 else x)
+
+     # Calculate the psi
+     psi_df["psi"] = (psi_df["percent_initial"] - psi_df["percent_new"]) * np.log(
+         psi_df["percent_initial"] / psi_df["percent_new"]
+     )
+
+     return psi_df.to_dict(orient="records")
+
+
+ @tags(
+     "sklearn", "binary_classification", "multiclass_classification", "model_performance"
+ )
+ @tasks("classification", "text_classification")
+ def PopulationStabilityIndex(
+     datasets: List[VMDataset], model: VMModel, num_bins: int = 10, mode: str = "fixed"
+ ):
      """
      Assesses the Population Stability Index (PSI) to quantify the stability of an ML model's predictions across
      different datasets.
@@ -72,150 +133,39 @@ class PopulationStabilityIndex(Metric):
      relationships between features and the target variable (concept drift), or both. However, distinguishing between
      these causes is non-trivial.
      """
-
-     name = "psi"
-     required_inputs = ["model", "datasets"]
-     tasks = ["classification", "text_classification"]
-     tags = [
-         "sklearn",
-         "binary_classification",
-         "multiclass_classification",
-         "model_performance",
-     ]
-     default_params = {
-         "num_bins": 10,
-         "mode": "fixed",
-     }
-
-     def summary(self, metric_value):
-         # Add a table with the PSI values for each feature
-         # The data looks like this: [{"initial": 2652, "percent_initial": 0.5525, "new": 830, "percent_new": 0.5188, "psi": 0.0021},...
-         psi_table = [
-             {
-                 "Bin": (
-                     i if i < (len(metric_value) - 1) else "Total"
-                 ),  # The last bin is the "Total" bin
-                 "Count Initial": values["initial"],
-                 "Percent Initial (%)": values["percent_initial"] * 100,
-                 "Count New": values["new"],
-                 "Percent New (%)": values["percent_new"] * 100,
-                 "PSI": values["psi"],
-             }
-             for i, values in enumerate(metric_value)
-         ]
-
-         return ResultSummary(
-             results=[
-                 ResultTable(
-                     data=psi_table,
-                     metadata=ResultTableMetadata(
-                         title="Population Stability Index for Training and Test Datasets"
-                     ),
-                 ),
-             ]
-         )
-
-     def _get_psi(
-         self, score_initial, score_new, num_bins=10, mode="fixed", as_dict=False
-     ):
-         """
-         Taken from:
-         https://towardsdatascience.com/checking-model-stability-and-population-shift-with-psi-and-csi-6d12af008783
-         """
-         eps = 1e-4
-
-         # Sort the data
-         score_initial.sort()
-         score_new.sort()
-
-         # Prepare the bins
-         min_val = min(min(score_initial), min(score_new))
-         max_val = max(max(score_initial), max(score_new))
-         if mode == "fixed":
-             bins = [
-                 min_val + (max_val - min_val) * (i) / num_bins
-                 for i in range(num_bins + 1)
-             ]
-         elif mode == "quantile":
-             bins = pd.qcut(score_initial, q=num_bins, retbins=True)[
-                 1
-             ]  # Create the quantiles based on the initial population
-         else:
-             raise ValueError(
-                 f"Mode '{mode}' not recognized. Allowed options are 'fixed' and 'quantile'"
-             )
-         bins[0] = min_val - eps  # Correct the lower boundary
-         bins[-1] = max_val + eps  # Correct the higher boundary
-
-         # Bucketize the initial population and count the sample inside each bucket
-         bins_initial = pd.cut(score_initial, bins=bins, labels=range(1, num_bins + 1))
-         df_initial = pd.DataFrame({"initial": score_initial, "bin": bins_initial})
-         grp_initial = df_initial.groupby("bin").count()
-         grp_initial["percent_initial"] = grp_initial["initial"] / sum(
-             grp_initial["initial"]
-         )
-
-         # Bucketize the new population and count the sample inside each bucket
-         bins_new = pd.cut(score_new, bins=bins, labels=range(1, num_bins + 1))
-         df_new = pd.DataFrame({"new": score_new, "bin": bins_new})
-         grp_new = df_new.groupby("bin").count()
-         grp_new["percent_new"] = grp_new["new"] / sum(grp_new["new"])
-
-         # Compare the bins to calculate PSI
-         psi_df = grp_initial.join(grp_new, on="bin", how="inner")
-
-         # Add a small value for when the percent is zero
-         psi_df["percent_initial"] = psi_df["percent_initial"].apply(
-             lambda x: eps if x == 0 else x
-         )
-         psi_df["percent_new"] = psi_df["percent_new"].apply(
-             lambda x: eps if x == 0 else x
-         )
-
-         # Calculate the psi
-         psi_df["psi"] = (psi_df["percent_initial"] - psi_df["percent_new"]) * np.log(
-             psi_df["percent_initial"] / psi_df["percent_new"]
-         )
-
-         return psi_df.to_dict(orient="records")
-
-     def run(self):
-         if self.inputs.model.library in ["statsmodels", "pytorch", "catboost"]:
-             logger.info(f"Skiping PSI for {self.inputs.model.library} models")
-             return
-
-         num_bins = self.params["num_bins"]
-         mode = self.params["mode"]
-
-         psi_results = self._get_psi(
-             self.inputs.model.predict_proba(self.inputs.datasets[0].x).copy(),
-             self.inputs.model.predict_proba(self.inputs.datasets[1].x).copy(),
-             num_bins=num_bins,
-             mode=mode,
-         )
-
-         trace1 = go.Bar(
-             x=list(range(len(psi_results))),
-             y=[d["percent_initial"] for d in psi_results],
-             name="Initial",
-             marker=dict(color="#DE257E"),
-         )
-         trace2 = go.Bar(
-             x=list(range(len(psi_results))),
-             y=[d["percent_new"] for d in psi_results],
-             name="New",
-             marker=dict(color="#E8B1F8"),
-         )
-
-         trace3 = go.Scatter(
-             x=list(range(len(psi_results))),
-             y=[d["psi"] for d in psi_results],
-             name="PSI",
-             yaxis="y2",
-             line=dict(color="#257EDE"),
-         )
-
-         layout = go.Layout(
+     if model.library in ["statsmodels", "pytorch", "catboost"]:
+         raise SkipTestError(f"Skiping PSI for {model.library} models")
+
+     psi_results = calculate_psi(
+         datasets[0].y_prob(model).copy(),
+         datasets[1].y_prob(model).copy(),
+         num_bins=num_bins,
+         mode=mode,
+     )
+
+     fig = go.Figure(
+         data=[
+             go.Bar(
+                 x=list(range(len(psi_results))),
+                 y=[d["percent_initial"] for d in psi_results],
+                 name="Initial",
+                 marker=dict(color="#DE257E"),
+             ),
+             go.Bar(
+                 x=list(range(len(psi_results))),
+                 y=[d["percent_new"] for d in psi_results],
+                 name="New",
+                 marker=dict(color="#E8B1F8"),
+             ),
+             go.Scatter(
+                 x=list(range(len(psi_results))),
+                 y=[d["psi"] for d in psi_results],
+                 name="PSI",
+                 yaxis="y2",
+                 line=dict(color="#257EDE"),
+             ),
+         ],
+         layout=go.Layout(
              title="Population Stability Index (PSI) Plot",
              xaxis=dict(title="Bin"),
              yaxis=dict(title="Population Ratio"),
@@ -229,23 +179,31 @@ class PopulationStabilityIndex(Metric):
              ],  # Adjust as needed
          ),
          barmode="group",
-         )
-
-         fig = go.Figure(data=[trace1, trace2, trace3], layout=layout)
-         figure = Figure(
-             for_object=self,
-             key=self.key,
-             figure=fig,
-         )
-
-         # Calculate the sum of each numeric column
-         total_psi = {
-             key: sum(d.get(key, 0) for d in psi_results)
-             for key in psi_results[0].keys()
-             if isinstance(psi_results[0][key], (int, float))
-         }
+         ),
+     )
+
+     # sum up the PSI values to get the total values
+     total_psi = {
+         key: sum(d.get(key, 0) for d in psi_results)
+         for key in psi_results[0].keys()
+         if isinstance(psi_results[0][key], (int, float))
+     }
+     psi_results.append(total_psi)

-         # Add the total PSI dictionary to the list
-         psi_results.append(total_psi)
+     table_title = f"Population Stability Index for {datasets[0].input_id} and {datasets[1].input_id} Datasets"

-         return self.cache_results(metric_value=psi_results, figures=[figure])
+     return {
+         table_title: [
+             {
+                 "Bin": (
+                     i if i < (len(psi_results) - 1) else "Total"
+                 ),  # The last bin is the "Total" bin
+                 "Count Initial": values["initial"],
+                 "Percent Initial (%)": values["percent_initial"] * 100,
+                 "Count New": values["new"],
+                 "Percent New (%)": values["percent_new"] * 100,
+                 "PSI": values["psi"],
+             }
+             for i, values in enumerate(psi_results)
+         ],
+     }, fig
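The new module-level calculate_psi helper implements the standard PSI arithmetic: PSI = sum over bins of (p_i - q_i) * ln(p_i / q_i), where p_i and q_i are the bin proportions of the initial and new score distributions. A commonly cited rule of thumb reads PSI below 0.1 as little shift and above 0.25 as a major shift. Here is a minimal standalone sketch of that formula; the bin proportions are made-up numbers used only for illustration.

    # Standalone sketch of the PSI formula used by calculate_psi above.
    import numpy as np

    percent_initial = np.array([0.25, 0.50, 0.25])  # p_i: hypothetical bin shares, initial scores
    percent_new = np.array([0.20, 0.45, 0.35])      # q_i: hypothetical bin shares, new scores

    psi_per_bin = (percent_initial - percent_new) * np.log(percent_initial / percent_new)
    total_psi = psi_per_bin.sum()
    print(psi_per_bin.round(4), round(total_psi, 3))  # total of roughly 0.05 for these numbers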
validmind/tests/model_validation/sklearn/PrecisionRecallCurve.py

@@ -2,19 +2,19 @@
  # See the LICENSE file in the root of this repository for details.
  # SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial

- from dataclasses import dataclass
-
  import numpy as np
  import plotly.graph_objects as go
  from sklearn.metrics import precision_recall_curve

+ from validmind import tags, tasks
  from validmind.errors import SkipTestError
  from validmind.models import FoundationModel
- from validmind.vm_models import Figure, Metric
+ from validmind.vm_models import VMDataset, VMModel


- @dataclass
- class PrecisionRecallCurve(Metric):
+ @tags("sklearn", "binary_classification", "model_performance", "visualization")
+ @tasks("classification", "text_classification")
+ def PrecisionRecallCurve(model: VMModel, dataset: VMDataset):
      """
      Evaluates the precision-recall trade-off for binary classification models and visualizes the Precision-Recall curve.

@@ -55,59 +55,30 @@ class PrecisionRecallCurve(Metric):
      - It may not fully represent the overall accuracy of the model if the cost of false positives and false negatives
      are extremely different, or if the dataset is heavily imbalanced.
      """
+     if isinstance(model, FoundationModel):
+         raise SkipTestError("Skipping PrecisionRecallCurve for Foundation models")

-     name = "pr_curve"
-     required_inputs = ["model", "dataset"]
-     tasks = ["classification", "text_classification"]
-     tags = [
-         "sklearn",
-         "binary_classification",
-         "multiclass_classification",
-         "model_performance",
-         "visualization",
-     ]
-
-     def run(self):
-         if isinstance(self.inputs.model, FoundationModel):
-             raise SkipTestError("Skipping PrecisionRecallCurve for Foundation models")
-
-         y_true = self.inputs.dataset.y
-         y_pred = self.inputs.dataset.y_prob(self.inputs.model)
-
-         # PR curve is only supported for binary classification
-         if len(np.unique(y_true)) > 2:
-             raise SkipTestError(
-                 "Precision Recall Curve is only supported for binary classification models"
-             )
+     y_true = dataset.y
+     if len(np.unique(y_true)) > 2:
+         raise SkipTestError(
+             "Precision Recall Curve is only supported for binary classification models"
+         )

-         precision, recall, pr_thresholds = precision_recall_curve(y_true, y_pred)
+     precision, recall, _ = precision_recall_curve(y_true, dataset.y_prob(model))

-         trace = go.Scatter(
-             x=recall,
-             y=precision,
-             mode="lines",
-             name="Precision-Recall Curve",
-             line=dict(color="#DE257E"),
-         )
-         layout = go.Layout(
+     return go.Figure(
+         data=[
+             go.Scatter(
+                 x=recall,
+                 y=precision,
+                 mode="lines",
+                 name="Precision-Recall Curve",
+                 line=dict(color="#DE257E"),
+             )
+         ],
+         layout=go.Layout(
              title="Precision-Recall Curve",
              xaxis=dict(title="Recall"),
              yaxis=dict(title="Precision"),
-         )
-
-         fig = go.Figure(data=[trace], layout=layout)
-
-         return self.cache_results(
-             metric_value={
-                 "precision": precision,
-                 "recall": recall,
-                 "thresholds": pr_thresholds,
-             },
-             figures=[
-                 Figure(
-                     for_object=self,
-                     key="pr_curve",
-                     figure=fig,
-                 )
-             ],
-         )
+         ),
+     )
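As with the other rewritten tests, PrecisionRecallCurve is now a decorated function that returns the Plotly figure directly instead of going through cache_results. The sklearn call it wraps can be exercised on its own; the sketch below is a standalone illustration, with a synthetic dataset and LogisticRegression standing in for the model and dataset inputs.

    # Standalone sketch (outside validmind): the sklearn precision_recall_curve call
    # the rewritten test wraps.
    import plotly.graph_objects as go
    from sklearn.datasets import make_classification
    from sklearn.linear_model import LogisticRegression
    from sklearn.metrics import precision_recall_curve
    from sklearn.model_selection import train_test_split

    X, y = make_classification(n_samples=1000, random_state=0)
    X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)
    clf = LogisticRegression(max_iter=1000).fit(X_train, y_train)

    # precision_recall_curve expects true binary labels and positive-class probabilities
    y_prob = clf.predict_proba(X_test)[:, 1]
    precision, recall, _ = precision_recall_curve(y_test, y_prob)

    fig = go.Figure(
        go.Scatter(x=recall, y=precision, mode="lines", name="Precision-Recall Curve")
    )
    fig.update_layout(title="Precision-Recall Curve", xaxis_title="Recall", yaxis_title="Precision")
    fig.show()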