validmind 2.5.25__py3-none-any.whl → 2.6.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries.
- validmind/__init__.py +8 -17
- validmind/__version__.py +1 -1
- validmind/ai/test_descriptions.py +66 -85
- validmind/ai/test_result_description/context.py +2 -2
- validmind/ai/utils.py +26 -1
- validmind/api_client.py +43 -79
- validmind/client.py +5 -7
- validmind/client_config.py +1 -1
- validmind/datasets/__init__.py +1 -1
- validmind/datasets/classification/customer_churn.py +7 -5
- validmind/datasets/nlp/__init__.py +2 -2
- validmind/errors.py +6 -10
- validmind/html_templates/content_blocks.py +18 -16
- validmind/logging.py +21 -16
- validmind/tests/__init__.py +28 -5
- validmind/tests/__types__.py +186 -170
- validmind/tests/_store.py +7 -21
- validmind/tests/comparison.py +362 -0
- validmind/tests/data_validation/ACFandPACFPlot.py +44 -73
- validmind/tests/data_validation/ADF.py +49 -83
- validmind/tests/data_validation/AutoAR.py +59 -96
- validmind/tests/data_validation/AutoMA.py +59 -96
- validmind/tests/data_validation/AutoStationarity.py +66 -114
- validmind/tests/data_validation/ClassImbalance.py +48 -117
- validmind/tests/data_validation/DatasetDescription.py +180 -209
- validmind/tests/data_validation/DatasetSplit.py +50 -75
- validmind/tests/data_validation/DescriptiveStatistics.py +59 -85
- validmind/tests/data_validation/{DFGLSArch.py → DickeyFullerGLS.py} +44 -76
- validmind/tests/data_validation/Duplicates.py +21 -90
- validmind/tests/data_validation/EngleGrangerCoint.py +53 -75
- validmind/tests/data_validation/HighCardinality.py +32 -80
- validmind/tests/data_validation/HighPearsonCorrelation.py +29 -97
- validmind/tests/data_validation/IQROutliersBarPlot.py +63 -94
- validmind/tests/data_validation/IQROutliersTable.py +40 -80
- validmind/tests/data_validation/IsolationForestOutliers.py +41 -63
- validmind/tests/data_validation/KPSS.py +33 -81
- validmind/tests/data_validation/LaggedCorrelationHeatmap.py +47 -95
- validmind/tests/data_validation/MissingValues.py +17 -58
- validmind/tests/data_validation/MissingValuesBarPlot.py +61 -87
- validmind/tests/data_validation/PhillipsPerronArch.py +56 -79
- validmind/tests/data_validation/RollingStatsPlot.py +50 -81
- validmind/tests/data_validation/SeasonalDecompose.py +102 -184
- validmind/tests/data_validation/Skewness.py +27 -64
- validmind/tests/data_validation/SpreadPlot.py +34 -57
- validmind/tests/data_validation/TabularCategoricalBarPlots.py +46 -65
- validmind/tests/data_validation/TabularDateTimeHistograms.py +23 -45
- validmind/tests/data_validation/TabularNumericalHistograms.py +27 -46
- validmind/tests/data_validation/TargetRateBarPlots.py +54 -93
- validmind/tests/data_validation/TimeSeriesFrequency.py +48 -133
- validmind/tests/data_validation/TimeSeriesHistogram.py +24 -3
- validmind/tests/data_validation/TimeSeriesLinePlot.py +29 -47
- validmind/tests/data_validation/TimeSeriesMissingValues.py +59 -135
- validmind/tests/data_validation/TimeSeriesOutliers.py +54 -171
- validmind/tests/data_validation/TooManyZeroValues.py +21 -70
- validmind/tests/data_validation/UniqueRows.py +23 -62
- validmind/tests/data_validation/WOEBinPlots.py +83 -109
- validmind/tests/data_validation/WOEBinTable.py +28 -69
- validmind/tests/data_validation/ZivotAndrewsArch.py +33 -75
- validmind/tests/data_validation/nlp/CommonWords.py +49 -57
- validmind/tests/data_validation/nlp/Hashtags.py +27 -49
- validmind/tests/data_validation/nlp/LanguageDetection.py +7 -13
- validmind/tests/data_validation/nlp/Mentions.py +32 -63
- validmind/tests/data_validation/nlp/PolarityAndSubjectivity.py +89 -14
- validmind/tests/data_validation/nlp/Punctuations.py +63 -47
- validmind/tests/data_validation/nlp/Sentiment.py +4 -0
- validmind/tests/data_validation/nlp/StopWords.py +62 -91
- validmind/tests/data_validation/nlp/TextDescription.py +116 -159
- validmind/tests/data_validation/nlp/Toxicity.py +12 -4
- validmind/tests/decorator.py +33 -242
- validmind/tests/load.py +212 -153
- validmind/tests/model_validation/BertScore.py +13 -7
- validmind/tests/model_validation/BleuScore.py +4 -0
- validmind/tests/model_validation/ClusterSizeDistribution.py +24 -47
- validmind/tests/model_validation/ContextualRecall.py +3 -0
- validmind/tests/model_validation/FeaturesAUC.py +43 -74
- validmind/tests/model_validation/MeteorScore.py +3 -0
- validmind/tests/model_validation/RegardScore.py +5 -1
- validmind/tests/model_validation/RegressionResidualsPlot.py +54 -75
- validmind/tests/model_validation/embeddings/ClusterDistribution.py +10 -33
- validmind/tests/model_validation/embeddings/CosineSimilarityDistribution.py +11 -29
- validmind/tests/model_validation/embeddings/DescriptiveAnalytics.py +19 -31
- validmind/tests/model_validation/embeddings/EmbeddingsVisualization2D.py +40 -49
- validmind/tests/model_validation/embeddings/StabilityAnalysisKeyword.py +29 -15
- validmind/tests/model_validation/embeddings/StabilityAnalysisRandomNoise.py +25 -11
- validmind/tests/model_validation/embeddings/StabilityAnalysisSynonyms.py +28 -13
- validmind/tests/model_validation/embeddings/StabilityAnalysisTranslation.py +67 -38
- validmind/tests/model_validation/embeddings/utils.py +53 -0
- validmind/tests/model_validation/ragas/AnswerCorrectness.py +37 -32
- validmind/tests/model_validation/ragas/{AspectCritique.py → AspectCritic.py} +33 -27
- validmind/tests/model_validation/ragas/ContextEntityRecall.py +44 -41
- validmind/tests/model_validation/ragas/ContextPrecision.py +40 -35
- validmind/tests/model_validation/ragas/ContextPrecisionWithoutReference.py +133 -0
- validmind/tests/model_validation/ragas/ContextRecall.py +40 -35
- validmind/tests/model_validation/ragas/Faithfulness.py +42 -30
- validmind/tests/model_validation/ragas/NoiseSensitivity.py +59 -35
- validmind/tests/model_validation/ragas/{AnswerRelevance.py → ResponseRelevancy.py} +52 -41
- validmind/tests/model_validation/ragas/{AnswerSimilarity.py → SemanticSimilarity.py} +39 -34
- validmind/tests/model_validation/sklearn/AdjustedMutualInformation.py +13 -16
- validmind/tests/model_validation/sklearn/AdjustedRandIndex.py +13 -16
- validmind/tests/model_validation/sklearn/ClassifierPerformance.py +51 -89
- validmind/tests/model_validation/sklearn/ClusterCosineSimilarity.py +31 -61
- validmind/tests/model_validation/sklearn/ClusterPerformanceMetrics.py +118 -83
- validmind/tests/model_validation/sklearn/CompletenessScore.py +13 -16
- validmind/tests/model_validation/sklearn/ConfusionMatrix.py +62 -94
- validmind/tests/model_validation/sklearn/FeatureImportance.py +7 -8
- validmind/tests/model_validation/sklearn/FowlkesMallowsScore.py +12 -15
- validmind/tests/model_validation/sklearn/HomogeneityScore.py +12 -15
- validmind/tests/model_validation/sklearn/HyperParametersTuning.py +23 -53
- validmind/tests/model_validation/sklearn/KMeansClustersOptimization.py +60 -74
- validmind/tests/model_validation/sklearn/MinimumAccuracy.py +16 -84
- validmind/tests/model_validation/sklearn/MinimumF1Score.py +22 -72
- validmind/tests/model_validation/sklearn/MinimumROCAUCScore.py +29 -78
- validmind/tests/model_validation/sklearn/ModelsPerformanceComparison.py +52 -82
- validmind/tests/model_validation/sklearn/OverfitDiagnosis.py +51 -145
- validmind/tests/model_validation/sklearn/PermutationFeatureImportance.py +60 -78
- validmind/tests/model_validation/sklearn/PopulationStabilityIndex.py +130 -172
- validmind/tests/model_validation/sklearn/PrecisionRecallCurve.py +26 -55
- validmind/tests/model_validation/sklearn/ROCCurve.py +43 -77
- validmind/tests/model_validation/sklearn/RegressionPerformance.py +41 -94
- validmind/tests/model_validation/sklearn/RobustnessDiagnosis.py +47 -136
- validmind/tests/model_validation/sklearn/SHAPGlobalImportance.py +164 -208
- validmind/tests/model_validation/sklearn/SilhouettePlot.py +54 -99
- validmind/tests/model_validation/sklearn/TrainingTestDegradation.py +50 -124
- validmind/tests/model_validation/sklearn/VMeasure.py +12 -15
- validmind/tests/model_validation/sklearn/WeakspotsDiagnosis.py +225 -281
- validmind/tests/model_validation/statsmodels/AutoARIMA.py +40 -45
- validmind/tests/model_validation/statsmodels/KolmogorovSmirnov.py +22 -47
- validmind/tests/model_validation/statsmodels/Lilliefors.py +17 -28
- validmind/tests/model_validation/statsmodels/RegressionFeatureSignificance.py +37 -81
- validmind/tests/model_validation/statsmodels/RegressionModelForecastPlot.py +37 -105
- validmind/tests/model_validation/statsmodels/RegressionModelForecastPlotLevels.py +62 -166
- validmind/tests/model_validation/statsmodels/RegressionModelSensitivityPlot.py +57 -119
- validmind/tests/model_validation/statsmodels/RegressionModelSummary.py +20 -57
- validmind/tests/model_validation/statsmodels/RegressionPermutationFeatureImportance.py +47 -80
- validmind/tests/ongoing_monitoring/PredictionCorrelation.py +2 -0
- validmind/tests/ongoing_monitoring/TargetPredictionDistributionPlot.py +4 -2
- validmind/tests/output.py +120 -0
- validmind/tests/prompt_validation/Bias.py +55 -98
- validmind/tests/prompt_validation/Clarity.py +56 -99
- validmind/tests/prompt_validation/Conciseness.py +63 -101
- validmind/tests/prompt_validation/Delimitation.py +48 -89
- validmind/tests/prompt_validation/NegativeInstruction.py +62 -96
- validmind/tests/prompt_validation/Robustness.py +80 -121
- validmind/tests/prompt_validation/Specificity.py +61 -95
- validmind/tests/prompt_validation/ai_powered_test.py +2 -2
- validmind/tests/run.py +314 -496
- validmind/tests/test_providers.py +109 -79
- validmind/tests/utils.py +91 -0
- validmind/unit_metrics/__init__.py +16 -155
- validmind/unit_metrics/classification/F1.py +1 -0
- validmind/unit_metrics/classification/Precision.py +1 -0
- validmind/unit_metrics/classification/ROC_AUC.py +1 -0
- validmind/unit_metrics/classification/Recall.py +1 -0
- validmind/unit_metrics/regression/AdjustedRSquaredScore.py +1 -0
- validmind/unit_metrics/regression/GiniCoefficient.py +1 -0
- validmind/unit_metrics/regression/HuberLoss.py +1 -0
- validmind/unit_metrics/regression/KolmogorovSmirnovStatistic.py +1 -0
- validmind/unit_metrics/regression/MeanAbsoluteError.py +1 -0
- validmind/unit_metrics/regression/MeanAbsolutePercentageError.py +1 -0
- validmind/unit_metrics/regression/MeanBiasDeviation.py +1 -0
- validmind/unit_metrics/regression/MeanSquaredError.py +1 -0
- validmind/unit_metrics/regression/QuantileLoss.py +1 -0
- validmind/unit_metrics/regression/RSquaredScore.py +2 -1
- validmind/unit_metrics/regression/RootMeanSquaredError.py +1 -0
- validmind/utils.py +66 -17
- validmind/vm_models/__init__.py +2 -17
- validmind/vm_models/dataset/dataset.py +31 -4
- validmind/vm_models/figure.py +7 -37
- validmind/vm_models/model.py +3 -0
- validmind/vm_models/result/__init__.py +7 -0
- validmind/vm_models/result/result.jinja +21 -0
- validmind/vm_models/result/result.py +337 -0
- validmind/vm_models/result/utils.py +160 -0
- validmind/vm_models/test_suite/runner.py +16 -54
- validmind/vm_models/test_suite/summary.py +3 -3
- validmind/vm_models/test_suite/test.py +43 -77
- validmind/vm_models/test_suite/test_suite.py +8 -40
- validmind-2.6.8.dist-info/METADATA +137 -0
- {validmind-2.5.25.dist-info → validmind-2.6.8.dist-info}/RECORD +182 -189
- validmind/tests/data_validation/AutoSeasonality.py +0 -190
- validmind/tests/metadata.py +0 -59
- validmind/tests/model_validation/embeddings/StabilityAnalysis.py +0 -176
- validmind/tests/model_validation/ragas/ContextUtilization.py +0 -161
- validmind/tests/model_validation/sklearn/ClusterPerformance.py +0 -80
- validmind/unit_metrics/composite.py +0 -238
- validmind/vm_models/test/metric.py +0 -98
- validmind/vm_models/test/metric_result.py +0 -61
- validmind/vm_models/test/output_template.py +0 -55
- validmind/vm_models/test/result_summary.py +0 -76
- validmind/vm_models/test/result_wrapper.py +0 -488
- validmind/vm_models/test/test.py +0 -103
- validmind/vm_models/test/threshold_test.py +0 -106
- validmind/vm_models/test/threshold_test_result.py +0 -75
- validmind/vm_models/test_context.py +0 -259
- validmind-2.5.25.dist-info/METADATA +0 -118
- {validmind-2.5.25.dist-info → validmind-2.6.8.dist-info}/LICENSE +0 -0
- {validmind-2.5.25.dist-info → validmind-2.6.8.dist-info}/WHEEL +0 -0
- {validmind-2.5.25.dist-info → validmind-2.6.8.dist-info}/entry_points.txt +0 -0
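Several tests were renamed in this release (DFGLSArch → DickeyFullerGLS, AspectCritique → AspectCritic, AnswerRelevance → ResponseRelevancy, AnswerSimilarity → SemanticSimilarity), so the fully qualified test IDs that callers pass to the test runner change with the upgrade. A minimal sketch of updating a caller, assuming validmind's documented run_test entry point; my_dataset is a hypothetical VMDataset created elsewhere (e.g. via vm.init_dataset):

# Hedged sketch: adapting a 2.5.x caller to the 2.6.x test ID renames above.
# Assumes validmind.tests.run_test (the package's documented runner);
# `my_dataset` is a hypothetical input object, not part of this diff.
from validmind.tests import run_test

# The 2.5.x ID "validmind.data_validation.DFGLSArch" no longer resolves:
result = run_test(
    "validmind.data_validation.DickeyFullerGLS",  # renamed in 2.6.x
    inputs={"dataset": my_dataset},
)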
validmind/tests/model_validation/sklearn/PermutationFeatureImportance.py
@@ -2,20 +2,33 @@
 # See the LICENSE file in the root of this repository for details.
 # SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
 
-from …
+from typing import Union
 
 import plotly.graph_objects as go
 from sklearn.inspection import permutation_importance
 
+from validmind import tags, tasks
 from validmind.errors import SkipTestError
 from validmind.logging import get_logger
-from validmind.vm_models import …
+from validmind.vm_models import VMDataset, VMModel
 
 logger = get_logger(__name__)
 
 
-@…
-[old line 18 not rendered in this view]
+@tags(
+    "sklearn",
+    "binary_classification",
+    "multiclass_classification",
+    "feature_importance",
+    "visualization",
+)
+@tasks("classification", "text_classification")
+def PermutationFeatureImportance(
+    model: VMModel,
+    dataset: VMDataset,
+    fontsize: Union[int, None] = None,
+    figure_height: Union[int, None] = None,
+):
     """
     Assesses the significance of each feature in a model by evaluating the impact on model performance when feature
     values are randomly rearranged.
@@ -55,78 +68,47 @@ class PermutationFeatureImportance(Metric):
     allocate importance to one and not the other.
     - Cannot interact with certain libraries like statsmodels, pytorch, catboost, etc., thus limiting its applicability.
     """
-[old lines 58-88 not rendered in this view]
-            x,
-            y,
-            random_state=0,
-            n_jobs=-2,
-        )
-
-        pfi = {}
-        for i, column in enumerate(x.columns):
-            pfi[column] = [pfi_values["importances_mean"][i]], [
-                pfi_values["importances_std"][i]
-            ]
-
-        sorted_idx = pfi_values.importances_mean.argsort()
-
-        fig = go.Figure()
-        fig.add_trace(
-            go.Bar(
-                y=x.columns[sorted_idx],
-                x=pfi_values.importances[sorted_idx].mean(axis=1).T,
-                orientation="h",
-            )
-        )
-        fig.update_layout(
-            title_text="Permutation Importances",
-            yaxis=dict(
-                tickmode="linear",  # set tick mode to linear
-                dtick=1,  # set interval between ticks
-                tickfont=dict(
-                    size=self.params["fontsize"]
-                ),  # set the tick label font size
-            ),
-            height=self.params["figure_height"],  # use figure_height parameter here
-        )
-
-        return self.cache_results(
-            metric_value=pfi,
-            figures=[
-                Figure(
-                    for_object=self,
-                    key=f"pfi_{self.inputs.dataset.input_id}_{self.inputs.model.input_id}",
-                    figure=fig,
-                ),
-            ],
+    if model.library in [
+        "statsmodels",
+        "pytorch",
+        "catboost",
+        "transformers",
+        "R",
+    ]:
+        raise SkipTestError(f"Skipping PFI for {model.library} models")
+
+    pfi_values = permutation_importance(
+        estimator=model.model,
+        X=dataset.x_df(),
+        y=dataset.y_df(),
+        random_state=0,
+        n_jobs=-2,
+    )
+
+    pfi = {}
+    for i, column in enumerate(dataset.feature_columns):
+        pfi[column] = [pfi_values["importances_mean"][i]], [
+            pfi_values["importances_std"][i]
+        ]
+
+    sorted_idx = pfi_values.importances_mean.argsort()
+
+    fig = go.Figure()
+    fig.add_trace(
+        go.Bar(
+            y=[dataset.feature_columns[i] for i in sorted_idx],
+            x=pfi_values.importances[sorted_idx].mean(axis=1).T,
+            orientation="h",
         )
+    )
+    fig.update_layout(
+        title_text="Permutation Importances",
+        yaxis=dict(
+            tickmode="linear",
+            dtick=1,
+            tickfont=dict(size=fontsize),
+        ),
+        height=figure_height,
+    )
+
+    return fig
validmind/tests/model_validation/sklearn/PopulationStabilityIndex.py
@@ -2,26 +2,87 @@
 # See the LICENSE file in the root of this repository for details.
 # SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
 
-from …
+from typing import List
 
 import numpy as np
 import pandas as pd
 import plotly.graph_objects as go
 
+from validmind import tags, tasks
+from validmind.errors import SkipTestError
 from validmind.logging import get_logger
-from validmind.vm_models import (
-    Figure,
-    Metric,
-    ResultSummary,
-    ResultTable,
-    ResultTableMetadata,
-)
+from validmind.vm_models import VMDataset, VMModel
 
 logger = get_logger(__name__)
 
 
-[old lines 23-24 not rendered in this view]
+def calculate_psi(score_initial, score_new, num_bins=10, mode="fixed"):
+    """
+    Taken from:
+    https://towardsdatascience.com/checking-model-stability-and-population-shift-with-psi-and-csi-6d12af008783
+    """
+    eps = 1e-4
+
+    # Sort the data
+    score_initial.sort()
+    score_new.sort()
+
+    # Prepare the bins
+    min_val = min(min(score_initial), min(score_new))
+    max_val = max(max(score_initial), max(score_new))
+    if mode == "fixed":
+        bins = [
+            min_val + (max_val - min_val) * (i) / num_bins for i in range(num_bins + 1)
+        ]
+    elif mode == "quantile":
+        bins = pd.qcut(score_initial, q=num_bins, retbins=True)[
+            1
+        ]  # Create the quantiles based on the initial population
+    else:
+        raise ValueError(
+            f"Mode '{mode}' not recognized. Allowed options are 'fixed' and 'quantile'"
+        )
+    bins[0] = min_val - eps  # Correct the lower boundary
+    bins[-1] = max_val + eps  # Correct the higher boundary
+
+    # Bucketize the initial population and count the sample inside each bucket
+    bins_initial = pd.cut(score_initial, bins=bins, labels=range(1, num_bins + 1))
+    df_initial = pd.DataFrame({"initial": score_initial, "bin": bins_initial})
+    grp_initial = df_initial.groupby("bin").count()
+    grp_initial["percent_initial"] = grp_initial["initial"] / sum(
+        grp_initial["initial"]
+    )
+
+    # Bucketize the new population and count the sample inside each bucket
+    bins_new = pd.cut(score_new, bins=bins, labels=range(1, num_bins + 1))
+    df_new = pd.DataFrame({"new": score_new, "bin": bins_new})
+    grp_new = df_new.groupby("bin").count()
+    grp_new["percent_new"] = grp_new["new"] / sum(grp_new["new"])
+
+    # Compare the bins to calculate PSI
+    psi_df = grp_initial.join(grp_new, on="bin", how="inner")
+
+    # Add a small value for when the percent is zero
+    psi_df["percent_initial"] = psi_df["percent_initial"].apply(
+        lambda x: eps if x == 0 else x
+    )
+    psi_df["percent_new"] = psi_df["percent_new"].apply(lambda x: eps if x == 0 else x)
+
+    # Calculate the psi
+    psi_df["psi"] = (psi_df["percent_initial"] - psi_df["percent_new"]) * np.log(
+        psi_df["percent_initial"] / psi_df["percent_new"]
+    )
+
+    return psi_df.to_dict(orient="records")
+
+
+@tags(
+    "sklearn", "binary_classification", "multiclass_classification", "model_performance"
+)
+@tasks("classification", "text_classification")
+def PopulationStabilityIndex(
+    datasets: List[VMDataset], model: VMModel, num_bins: int = 10, mode: str = "fixed"
+):
     """
     Assesses the Population Stability Index (PSI) to quantify the stability of an ML model's predictions across
     different datasets.
@@ -72,150 +133,39 @@ class PopulationStabilityIndex(Metric):
     relationships between features and the target variable (concept drift), or both. However, distinguishing between
     these causes is non-trivial.
     """
-[old lines 75-107 not rendered in this view]
-            results=[
-                ResultTable(
-                    data=psi_table,
-                    metadata=ResultTableMetadata(
-                        title="Population Stability Index for Training and Test Datasets"
-                    ),
-                ),
-            ]
-        )
-
-    def _get_psi(
-        self, score_initial, score_new, num_bins=10, mode="fixed", as_dict=False
-    ):
-        """
-        Taken from:
-        https://towardsdatascience.com/checking-model-stability-and-population-shift-with-psi-and-csi-6d12af008783
-        """
-        eps = 1e-4
-
-        # Sort the data
-        score_initial.sort()
-        score_new.sort()
-
-        # Prepare the bins
-        min_val = min(min(score_initial), min(score_new))
-        max_val = max(max(score_initial), max(score_new))
-        if mode == "fixed":
-            bins = [
-                min_val + (max_val - min_val) * (i) / num_bins
-                for i in range(num_bins + 1)
-            ]
-        elif mode == "quantile":
-            bins = pd.qcut(score_initial, q=num_bins, retbins=True)[
-                1
-            ]  # Create the quantiles based on the initial population
-        else:
-            raise ValueError(
-                f"Mode '{mode}' not recognized. Allowed options are 'fixed' and 'quantile'"
-            )
-        bins[0] = min_val - eps  # Correct the lower boundary
-        bins[-1] = max_val + eps  # Correct the higher boundary
-
-        # Bucketize the initial population and count the sample inside each bucket
-        bins_initial = pd.cut(score_initial, bins=bins, labels=range(1, num_bins + 1))
-        df_initial = pd.DataFrame({"initial": score_initial, "bin": bins_initial})
-        grp_initial = df_initial.groupby("bin").count()
-        grp_initial["percent_initial"] = grp_initial["initial"] / sum(
-            grp_initial["initial"]
-        )
-
-        # Bucketize the new population and count the sample inside each bucket
-        bins_new = pd.cut(score_new, bins=bins, labels=range(1, num_bins + 1))
-        df_new = pd.DataFrame({"new": score_new, "bin": bins_new})
-        grp_new = df_new.groupby("bin").count()
-        grp_new["percent_new"] = grp_new["new"] / sum(grp_new["new"])
-
-        # Compare the bins to calculate PSI
-        psi_df = grp_initial.join(grp_new, on="bin", how="inner")
-
-        # Add a small value for when the percent is zero
-        psi_df["percent_initial"] = psi_df["percent_initial"].apply(
-            lambda x: eps if x == 0 else x
-        )
-        psi_df["percent_new"] = psi_df["percent_new"].apply(
-            lambda x: eps if x == 0 else x
-        )
-
-        # Calculate the psi
-        psi_df["psi"] = (psi_df["percent_initial"] - psi_df["percent_new"]) * np.log(
-            psi_df["percent_initial"] / psi_df["percent_new"]
-        )
-
-        return psi_df.to_dict(orient="records")
-
-    def run(self):
-        if self.inputs.model.library in ["statsmodels", "pytorch", "catboost"]:
-            logger.info(f"Skiping PSI for {self.inputs.model.library} models")
-            return
-
-        num_bins = self.params["num_bins"]
-        mode = self.params["mode"]
-
-        psi_results = self._get_psi(
-            self.inputs.model.predict_proba(self.inputs.datasets[0].x).copy(),
-            self.inputs.model.predict_proba(self.inputs.datasets[1].x).copy(),
-            num_bins=num_bins,
-            mode=mode,
-        )
-
-        trace1 = go.Bar(
-            x=list(range(len(psi_results))),
-            y=[d["percent_initial"] for d in psi_results],
-            name="Initial",
-            marker=dict(color="#DE257E"),
-        )
-        trace2 = go.Bar(
-            x=list(range(len(psi_results))),
-            y=[d["percent_new"] for d in psi_results],
-            name="New",
-            marker=dict(color="#E8B1F8"),
-        )
-
-        trace3 = go.Scatter(
-            x=list(range(len(psi_results))),
-            y=[d["psi"] for d in psi_results],
-            name="PSI",
-            yaxis="y2",
-            line=dict(color="#257EDE"),
-        )
-
-        layout = go.Layout(
+    if model.library in ["statsmodels", "pytorch", "catboost"]:
+        raise SkipTestError(f"Skiping PSI for {model.library} models")
+
+    psi_results = calculate_psi(
+        datasets[0].y_prob(model).copy(),
+        datasets[1].y_prob(model).copy(),
+        num_bins=num_bins,
+        mode=mode,
+    )
+
+    fig = go.Figure(
+        data=[
+            go.Bar(
+                x=list(range(len(psi_results))),
+                y=[d["percent_initial"] for d in psi_results],
+                name="Initial",
+                marker=dict(color="#DE257E"),
+            ),
+            go.Bar(
+                x=list(range(len(psi_results))),
+                y=[d["percent_new"] for d in psi_results],
+                name="New",
+                marker=dict(color="#E8B1F8"),
+            ),
+            go.Scatter(
+                x=list(range(len(psi_results))),
+                y=[d["psi"] for d in psi_results],
+                name="PSI",
+                yaxis="y2",
+                line=dict(color="#257EDE"),
+            ),
+        ],
+        layout=go.Layout(
             title="Population Stability Index (PSI) Plot",
             xaxis=dict(title="Bin"),
             yaxis=dict(title="Population Ratio"),
@@ -229,23 +179,31 @@ class PopulationStabilityIndex(Metric):
             ],  # Adjust as needed
         ),
         barmode="group",
-        )
-[old lines 233-238 not rendered in this view]
-        )
-
-
-        total_psi = {
-            key: sum(d.get(key, 0) for d in psi_results)
-            for key in psi_results[0].keys()
-            if isinstance(psi_results[0][key], (int, float))
-        }
+        ),
+    )
+
+    # sum up the PSI values to get the total values
+    total_psi = {
+        key: sum(d.get(key, 0) for d in psi_results)
+        for key in psi_results[0].keys()
+        if isinstance(psi_results[0][key], (int, float))
+    }
+    psi_results.append(total_psi)
 
-
-        psi_results.append(total_psi)
+    table_title = f"Population Stability Index for {datasets[0].input_id} and {datasets[1].input_id} Datasets"
 
-
+    return {
+        table_title: [
+            {
+                "Bin": (
+                    i if i < (len(psi_results) - 1) else "Total"
+                ),  # The last bin is the "Total" bin
+                "Count Initial": values["initial"],
+                "Percent Initial (%)": values["percent_initial"] * 100,
+                "Count New": values["new"],
+                "Percent New (%)": values["percent_new"] * 100,
+                "PSI": values["psi"],
+            }
+            for i, values in enumerate(psi_results)
+        ],
+    }, fig
validmind/tests/model_validation/sklearn/PrecisionRecallCurve.py
@@ -2,19 +2,19 @@
 # See the LICENSE file in the root of this repository for details.
 # SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
 
-from dataclasses import dataclass
-
 import numpy as np
 import plotly.graph_objects as go
 from sklearn.metrics import precision_recall_curve
 
+from validmind import tags, tasks
 from validmind.errors import SkipTestError
 from validmind.models import FoundationModel
-from validmind.vm_models import …
+from validmind.vm_models import VMDataset, VMModel
 
 
-@…
-[old line 17 not rendered in this view]
+@tags("sklearn", "binary_classification", "model_performance", "visualization")
+@tasks("classification", "text_classification")
+def PrecisionRecallCurve(model: VMModel, dataset: VMDataset):
     """
     Evaluates the precision-recall trade-off for binary classification models and visualizes the Precision-Recall curve.
 
@@ -55,59 +55,30 @@ class PrecisionRecallCurve(Metric):
     - It may not fully represent the overall accuracy of the model if the cost of false positives and false negatives
     are extremely different, or if the dataset is heavily imbalanced.
     """
+    if isinstance(model, FoundationModel):
+        raise SkipTestError("Skipping PrecisionRecallCurve for Foundation models")
 
-[old lines 59-63 not rendered in this view]
-        "binary_classification",
-        "multiclass_classification",
-        "model_performance",
-        "visualization",
-    ]
-
-    def run(self):
-        if isinstance(self.inputs.model, FoundationModel):
-            raise SkipTestError("Skipping PrecisionRecallCurve for Foundation models")
-
-        y_true = self.inputs.dataset.y
-        y_pred = self.inputs.dataset.y_prob(self.inputs.model)
-
-        # PR curve is only supported for binary classification
-        if len(np.unique(y_true)) > 2:
-            raise SkipTestError(
-                "Precision Recall Curve is only supported for binary classification models"
-            )
+    y_true = dataset.y
+    if len(np.unique(y_true)) > 2:
+        raise SkipTestError(
+            "Precision Recall Curve is only supported for binary classification models"
+        )
 
-[old line 83 not rendered in this view]
+    precision, recall, _ = precision_recall_curve(y_true, dataset.y_prob(model))
 
-[old lines 85-92 not rendered in this view]
+    return go.Figure(
+        data=[
+            go.Scatter(
+                x=recall,
+                y=precision,
+                mode="lines",
+                name="Precision-Recall Curve",
+                line=dict(color="#DE257E"),
+            )
+        ],
+        layout=go.Layout(
             title="Precision-Recall Curve",
             xaxis=dict(title="Recall"),
             yaxis=dict(title="Precision"),
-        )
-
-        fig = go.Figure(data=[trace], layout=layout)
-
-        return self.cache_results(
-            metric_value={
-                "precision": precision,
-                "recall": recall,
-                "thresholds": pr_thresholds,
-            },
-            figures=[
-                Figure(
-                    for_object=self,
-                    key="pr_curve",
-                    figure=fig,
-                )
-            ],
-        )
+        ),
+    )