validmind 2.5.25__py3-none-any.whl → 2.6.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- validmind/__init__.py +8 -17
- validmind/__version__.py +1 -1
- validmind/ai/test_descriptions.py +66 -85
- validmind/ai/test_result_description/context.py +2 -2
- validmind/ai/utils.py +26 -1
- validmind/api_client.py +43 -79
- validmind/client.py +5 -7
- validmind/client_config.py +1 -1
- validmind/datasets/__init__.py +1 -1
- validmind/datasets/classification/customer_churn.py +7 -5
- validmind/datasets/nlp/__init__.py +2 -2
- validmind/errors.py +6 -10
- validmind/html_templates/content_blocks.py +18 -16
- validmind/logging.py +21 -16
- validmind/tests/__init__.py +28 -5
- validmind/tests/__types__.py +186 -170
- validmind/tests/_store.py +7 -21
- validmind/tests/comparison.py +362 -0
- validmind/tests/data_validation/ACFandPACFPlot.py +44 -73
- validmind/tests/data_validation/ADF.py +49 -83
- validmind/tests/data_validation/AutoAR.py +59 -96
- validmind/tests/data_validation/AutoMA.py +59 -96
- validmind/tests/data_validation/AutoStationarity.py +66 -114
- validmind/tests/data_validation/ClassImbalance.py +48 -117
- validmind/tests/data_validation/DatasetDescription.py +180 -209
- validmind/tests/data_validation/DatasetSplit.py +50 -75
- validmind/tests/data_validation/DescriptiveStatistics.py +59 -85
- validmind/tests/data_validation/{DFGLSArch.py → DickeyFullerGLS.py} +44 -76
- validmind/tests/data_validation/Duplicates.py +21 -90
- validmind/tests/data_validation/EngleGrangerCoint.py +53 -75
- validmind/tests/data_validation/HighCardinality.py +32 -80
- validmind/tests/data_validation/HighPearsonCorrelation.py +29 -97
- validmind/tests/data_validation/IQROutliersBarPlot.py +63 -94
- validmind/tests/data_validation/IQROutliersTable.py +40 -80
- validmind/tests/data_validation/IsolationForestOutliers.py +41 -63
- validmind/tests/data_validation/KPSS.py +33 -81
- validmind/tests/data_validation/LaggedCorrelationHeatmap.py +47 -95
- validmind/tests/data_validation/MissingValues.py +17 -58
- validmind/tests/data_validation/MissingValuesBarPlot.py +61 -87
- validmind/tests/data_validation/PhillipsPerronArch.py +56 -79
- validmind/tests/data_validation/RollingStatsPlot.py +50 -81
- validmind/tests/data_validation/SeasonalDecompose.py +102 -184
- validmind/tests/data_validation/Skewness.py +27 -64
- validmind/tests/data_validation/SpreadPlot.py +34 -57
- validmind/tests/data_validation/TabularCategoricalBarPlots.py +46 -65
- validmind/tests/data_validation/TabularDateTimeHistograms.py +23 -45
- validmind/tests/data_validation/TabularNumericalHistograms.py +27 -46
- validmind/tests/data_validation/TargetRateBarPlots.py +54 -93
- validmind/tests/data_validation/TimeSeriesFrequency.py +48 -133
- validmind/tests/data_validation/TimeSeriesHistogram.py +24 -3
- validmind/tests/data_validation/TimeSeriesLinePlot.py +29 -47
- validmind/tests/data_validation/TimeSeriesMissingValues.py +59 -135
- validmind/tests/data_validation/TimeSeriesOutliers.py +54 -171
- validmind/tests/data_validation/TooManyZeroValues.py +21 -70
- validmind/tests/data_validation/UniqueRows.py +23 -62
- validmind/tests/data_validation/WOEBinPlots.py +83 -109
- validmind/tests/data_validation/WOEBinTable.py +28 -69
- validmind/tests/data_validation/ZivotAndrewsArch.py +33 -75
- validmind/tests/data_validation/nlp/CommonWords.py +49 -57
- validmind/tests/data_validation/nlp/Hashtags.py +27 -49
- validmind/tests/data_validation/nlp/LanguageDetection.py +7 -13
- validmind/tests/data_validation/nlp/Mentions.py +32 -63
- validmind/tests/data_validation/nlp/PolarityAndSubjectivity.py +89 -14
- validmind/tests/data_validation/nlp/Punctuations.py +63 -47
- validmind/tests/data_validation/nlp/Sentiment.py +4 -0
- validmind/tests/data_validation/nlp/StopWords.py +62 -91
- validmind/tests/data_validation/nlp/TextDescription.py +116 -159
- validmind/tests/data_validation/nlp/Toxicity.py +12 -4
- validmind/tests/decorator.py +33 -242
- validmind/tests/load.py +212 -153
- validmind/tests/model_validation/BertScore.py +13 -7
- validmind/tests/model_validation/BleuScore.py +4 -0
- validmind/tests/model_validation/ClusterSizeDistribution.py +24 -47
- validmind/tests/model_validation/ContextualRecall.py +3 -0
- validmind/tests/model_validation/FeaturesAUC.py +43 -74
- validmind/tests/model_validation/MeteorScore.py +3 -0
- validmind/tests/model_validation/RegardScore.py +5 -1
- validmind/tests/model_validation/RegressionResidualsPlot.py +54 -75
- validmind/tests/model_validation/embeddings/ClusterDistribution.py +10 -33
- validmind/tests/model_validation/embeddings/CosineSimilarityDistribution.py +11 -29
- validmind/tests/model_validation/embeddings/DescriptiveAnalytics.py +19 -31
- validmind/tests/model_validation/embeddings/EmbeddingsVisualization2D.py +40 -49
- validmind/tests/model_validation/embeddings/StabilityAnalysisKeyword.py +29 -15
- validmind/tests/model_validation/embeddings/StabilityAnalysisRandomNoise.py +25 -11
- validmind/tests/model_validation/embeddings/StabilityAnalysisSynonyms.py +28 -13
- validmind/tests/model_validation/embeddings/StabilityAnalysisTranslation.py +67 -38
- validmind/tests/model_validation/embeddings/utils.py +53 -0
- validmind/tests/model_validation/ragas/AnswerCorrectness.py +37 -32
- validmind/tests/model_validation/ragas/{AspectCritique.py → AspectCritic.py} +33 -27
- validmind/tests/model_validation/ragas/ContextEntityRecall.py +44 -41
- validmind/tests/model_validation/ragas/ContextPrecision.py +40 -35
- validmind/tests/model_validation/ragas/ContextPrecisionWithoutReference.py +133 -0
- validmind/tests/model_validation/ragas/ContextRecall.py +40 -35
- validmind/tests/model_validation/ragas/Faithfulness.py +42 -30
- validmind/tests/model_validation/ragas/NoiseSensitivity.py +59 -35
- validmind/tests/model_validation/ragas/{AnswerRelevance.py → ResponseRelevancy.py} +52 -41
- validmind/tests/model_validation/ragas/{AnswerSimilarity.py → SemanticSimilarity.py} +39 -34
- validmind/tests/model_validation/sklearn/AdjustedMutualInformation.py +13 -16
- validmind/tests/model_validation/sklearn/AdjustedRandIndex.py +13 -16
- validmind/tests/model_validation/sklearn/ClassifierPerformance.py +51 -89
- validmind/tests/model_validation/sklearn/ClusterCosineSimilarity.py +31 -61
- validmind/tests/model_validation/sklearn/ClusterPerformanceMetrics.py +118 -83
- validmind/tests/model_validation/sklearn/CompletenessScore.py +13 -16
- validmind/tests/model_validation/sklearn/ConfusionMatrix.py +62 -94
- validmind/tests/model_validation/sklearn/FeatureImportance.py +7 -8
- validmind/tests/model_validation/sklearn/FowlkesMallowsScore.py +12 -15
- validmind/tests/model_validation/sklearn/HomogeneityScore.py +12 -15
- validmind/tests/model_validation/sklearn/HyperParametersTuning.py +23 -53
- validmind/tests/model_validation/sklearn/KMeansClustersOptimization.py +60 -74
- validmind/tests/model_validation/sklearn/MinimumAccuracy.py +16 -84
- validmind/tests/model_validation/sklearn/MinimumF1Score.py +22 -72
- validmind/tests/model_validation/sklearn/MinimumROCAUCScore.py +29 -78
- validmind/tests/model_validation/sklearn/ModelsPerformanceComparison.py +52 -82
- validmind/tests/model_validation/sklearn/OverfitDiagnosis.py +51 -145
- validmind/tests/model_validation/sklearn/PermutationFeatureImportance.py +60 -78
- validmind/tests/model_validation/sklearn/PopulationStabilityIndex.py +130 -172
- validmind/tests/model_validation/sklearn/PrecisionRecallCurve.py +26 -55
- validmind/tests/model_validation/sklearn/ROCCurve.py +43 -77
- validmind/tests/model_validation/sklearn/RegressionPerformance.py +41 -94
- validmind/tests/model_validation/sklearn/RobustnessDiagnosis.py +47 -136
- validmind/tests/model_validation/sklearn/SHAPGlobalImportance.py +164 -208
- validmind/tests/model_validation/sklearn/SilhouettePlot.py +54 -99
- validmind/tests/model_validation/sklearn/TrainingTestDegradation.py +50 -124
- validmind/tests/model_validation/sklearn/VMeasure.py +12 -15
- validmind/tests/model_validation/sklearn/WeakspotsDiagnosis.py +225 -281
- validmind/tests/model_validation/statsmodels/AutoARIMA.py +40 -45
- validmind/tests/model_validation/statsmodels/KolmogorovSmirnov.py +22 -47
- validmind/tests/model_validation/statsmodels/Lilliefors.py +17 -28
- validmind/tests/model_validation/statsmodels/RegressionFeatureSignificance.py +37 -81
- validmind/tests/model_validation/statsmodels/RegressionModelForecastPlot.py +37 -105
- validmind/tests/model_validation/statsmodels/RegressionModelForecastPlotLevels.py +62 -166
- validmind/tests/model_validation/statsmodels/RegressionModelSensitivityPlot.py +57 -119
- validmind/tests/model_validation/statsmodels/RegressionModelSummary.py +20 -57
- validmind/tests/model_validation/statsmodels/RegressionPermutationFeatureImportance.py +47 -80
- validmind/tests/ongoing_monitoring/PredictionCorrelation.py +2 -0
- validmind/tests/ongoing_monitoring/TargetPredictionDistributionPlot.py +4 -2
- validmind/tests/output.py +120 -0
- validmind/tests/prompt_validation/Bias.py +55 -98
- validmind/tests/prompt_validation/Clarity.py +56 -99
- validmind/tests/prompt_validation/Conciseness.py +63 -101
- validmind/tests/prompt_validation/Delimitation.py +48 -89
- validmind/tests/prompt_validation/NegativeInstruction.py +62 -96
- validmind/tests/prompt_validation/Robustness.py +80 -121
- validmind/tests/prompt_validation/Specificity.py +61 -95
- validmind/tests/prompt_validation/ai_powered_test.py +2 -2
- validmind/tests/run.py +314 -496
- validmind/tests/test_providers.py +109 -79
- validmind/tests/utils.py +91 -0
- validmind/unit_metrics/__init__.py +16 -155
- validmind/unit_metrics/classification/F1.py +1 -0
- validmind/unit_metrics/classification/Precision.py +1 -0
- validmind/unit_metrics/classification/ROC_AUC.py +1 -0
- validmind/unit_metrics/classification/Recall.py +1 -0
- validmind/unit_metrics/regression/AdjustedRSquaredScore.py +1 -0
- validmind/unit_metrics/regression/GiniCoefficient.py +1 -0
- validmind/unit_metrics/regression/HuberLoss.py +1 -0
- validmind/unit_metrics/regression/KolmogorovSmirnovStatistic.py +1 -0
- validmind/unit_metrics/regression/MeanAbsoluteError.py +1 -0
- validmind/unit_metrics/regression/MeanAbsolutePercentageError.py +1 -0
- validmind/unit_metrics/regression/MeanBiasDeviation.py +1 -0
- validmind/unit_metrics/regression/MeanSquaredError.py +1 -0
- validmind/unit_metrics/regression/QuantileLoss.py +1 -0
- validmind/unit_metrics/regression/RSquaredScore.py +2 -1
- validmind/unit_metrics/regression/RootMeanSquaredError.py +1 -0
- validmind/utils.py +66 -17
- validmind/vm_models/__init__.py +2 -17
- validmind/vm_models/dataset/dataset.py +31 -4
- validmind/vm_models/figure.py +7 -37
- validmind/vm_models/model.py +3 -0
- validmind/vm_models/result/__init__.py +7 -0
- validmind/vm_models/result/result.jinja +21 -0
- validmind/vm_models/result/result.py +337 -0
- validmind/vm_models/result/utils.py +160 -0
- validmind/vm_models/test_suite/runner.py +16 -54
- validmind/vm_models/test_suite/summary.py +3 -3
- validmind/vm_models/test_suite/test.py +43 -77
- validmind/vm_models/test_suite/test_suite.py +8 -40
- validmind-2.6.8.dist-info/METADATA +137 -0
- {validmind-2.5.25.dist-info → validmind-2.6.8.dist-info}/RECORD +182 -189
- validmind/tests/data_validation/AutoSeasonality.py +0 -190
- validmind/tests/metadata.py +0 -59
- validmind/tests/model_validation/embeddings/StabilityAnalysis.py +0 -176
- validmind/tests/model_validation/ragas/ContextUtilization.py +0 -161
- validmind/tests/model_validation/sklearn/ClusterPerformance.py +0 -80
- validmind/unit_metrics/composite.py +0 -238
- validmind/vm_models/test/metric.py +0 -98
- validmind/vm_models/test/metric_result.py +0 -61
- validmind/vm_models/test/output_template.py +0 -55
- validmind/vm_models/test/result_summary.py +0 -76
- validmind/vm_models/test/result_wrapper.py +0 -488
- validmind/vm_models/test/test.py +0 -103
- validmind/vm_models/test/threshold_test.py +0 -106
- validmind/vm_models/test/threshold_test_result.py +0 -75
- validmind/vm_models/test_context.py +0 -259
- validmind-2.5.25.dist-info/METADATA +0 -118
- {validmind-2.5.25.dist-info → validmind-2.6.8.dist-info}/LICENSE +0 -0
- {validmind-2.5.25.dist-info → validmind-2.6.8.dist-info}/WHEEL +0 -0
- {validmind-2.5.25.dist-info → validmind-2.6.8.dist-info}/entry_points.txt +0 -0
@@ -3,22 +3,122 @@
|
|
3
3
|
# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
|
4
4
|
|
5
5
|
import warnings
|
6
|
-
from
|
6
|
+
from warnings import filters as _warnings_filters
|
7
7
|
|
8
8
|
import matplotlib.pyplot as plt
|
9
9
|
import numpy as np
|
10
10
|
import shap
|
11
11
|
|
12
|
+
from validmind import tags, tasks
|
12
13
|
from validmind.errors import UnsupportedModelForSHAPError
|
13
14
|
from validmind.logging import get_logger
|
14
15
|
from validmind.models import CatBoostModel, SKlearnModel, StatsModelsModel
|
15
|
-
from validmind.vm_models import
|
16
|
+
from validmind.vm_models import VMDataset, VMModel
|
16
17
|
|
17
18
|
logger = get_logger(__name__)
|
18
19
|
|
19
20
|
|
20
|
-
|
21
|
-
|
21
|
+
def select_shap_values(shap_values, class_of_interest):
|
22
|
+
"""Selects SHAP values for binary or multiclass classification.
|
23
|
+
|
24
|
+
For regression models, returns the SHAP values directly as there are no classes.
|
25
|
+
|
26
|
+
Args:
|
27
|
+
shap_values: The SHAP values returned by the SHAP explainer. For multiclass
|
28
|
+
classification, this will be a list where each element corresponds to a class.
|
29
|
+
For regression, this will be a single array of SHAP values.
|
30
|
+
class_of_interest: The class index for which to retrieve SHAP values. If None
|
31
|
+
(default), the function will assume binary classification and use class 1
|
32
|
+
by default.
|
33
|
+
|
34
|
+
Returns:
|
35
|
+
The SHAP values for the specified class (classification) or for the regression
|
36
|
+
output.
|
37
|
+
|
38
|
+
Raises:
|
39
|
+
ValueError: If class_of_interest is specified and is out of bounds for the
|
40
|
+
number of classes.
|
41
|
+
"""
|
42
|
+
if not isinstance(shap_values, list):
|
43
|
+
# For regression, return the SHAP values as they are
|
44
|
+
# TODO: shap_values is always an array of all predictions, how is the if above supposed to work?
|
45
|
+
# logger.info("Returning SHAP values as-is.")
|
46
|
+
return shap_values
|
47
|
+
|
48
|
+
num_classes = len(shap_values)
|
49
|
+
|
50
|
+
# Default to class 1 for binary classification where no class is specified
|
51
|
+
if num_classes == 2 and class_of_interest is None:
|
52
|
+
logger.debug("Using SHAP values for class 1 (positive class).")
|
53
|
+
return shap_values[1]
|
54
|
+
|
55
|
+
# Otherwise, use the specified class_of_interest
|
56
|
+
if (
|
57
|
+
class_of_interest is None
|
58
|
+
or class_of_interest < 0
|
59
|
+
or class_of_interest >= num_classes
|
60
|
+
):
|
61
|
+
raise ValueError(
|
62
|
+
f"Invalid class_of_interest: {class_of_interest}. Must be between 0 and {num_classes - 1}."
|
63
|
+
)
|
64
|
+
|
65
|
+
logger.debug(f"Using SHAP values for class {class_of_interest}.")
|
66
|
+
return shap_values[class_of_interest]
|
67
|
+
|
68
|
+
|
69
|
+
def generate_shap_plot(type_, shap_values, x_test):
|
70
|
+
"""Plots two types of SHAP global importance (SHAP).
|
71
|
+
|
72
|
+
Args:
|
73
|
+
type_: The type of SHAP plot to generate. Must be "mean" or "summary".
|
74
|
+
shap_values: The SHAP values to plot.
|
75
|
+
x_test: The test data used to generate the SHAP values.
|
76
|
+
|
77
|
+
Returns:
|
78
|
+
The generated plot.
|
79
|
+
"""
|
80
|
+
ax = plt.axes()
|
81
|
+
ax.set_facecolor("white")
|
82
|
+
|
83
|
+
if type_ == "mean":
|
84
|
+
# Calculate the mean absolute SHAP value for each feature
|
85
|
+
mean_abs_shap = np.abs(shap_values).mean(axis=0)
|
86
|
+
# Find the maximum mean absolute SHAP value
|
87
|
+
max_shap_value = np.max(mean_abs_shap)
|
88
|
+
# Normalize all SHAP values based on the top feature
|
89
|
+
shap_values = shap_values / max_shap_value * 100
|
90
|
+
|
91
|
+
shap.summary_plot(shap_values, x_test, show=False, plot_type="bar")
|
92
|
+
|
93
|
+
# Customize the plot using matplotlib
|
94
|
+
plt.xlabel("Normalized SHAP Value (Percentage)", fontsize=13)
|
95
|
+
plt.ylabel("Features", fontsize=13)
|
96
|
+
plt.title("Normalized Feature Importance", fontsize=13)
|
97
|
+
else:
|
98
|
+
shap.summary_plot(shap_values, x_test, show=False)
|
99
|
+
|
100
|
+
fig = plt.gcf()
|
101
|
+
|
102
|
+
plt.close()
|
103
|
+
|
104
|
+
return fig
|
105
|
+
|
106
|
+
|
107
|
+
@tags(
|
108
|
+
"sklearn",
|
109
|
+
"binary_classification",
|
110
|
+
"multiclass_classification",
|
111
|
+
"feature_importance",
|
112
|
+
"visualization",
|
113
|
+
)
|
114
|
+
@tasks("classification", "text_classification")
|
115
|
+
def SHAPGlobalImportance(
|
116
|
+
model: VMModel,
|
117
|
+
dataset: VMDataset,
|
118
|
+
kernel_explainer_samples: int = 10,
|
119
|
+
tree_or_linear_explainer_samples: int = 200,
|
120
|
+
class_of_interest: int = None,
|
121
|
+
):
|
22
122
|
"""
|
23
123
|
Evaluates and visualizes global feature importance using SHAP values for model explanation and risk identification.
|
24
124
|
|
@@ -44,7 +144,6 @@ class SHAPGlobalImportance(Metric):
|
|
44
144
|
represents a Shapley value for a certain feature in a specific case. The vertical axis is denoted by the feature
|
45
145
|
whereas the horizontal one corresponds to the Shapley value. A color gradient indicates the value of the feature,
|
46
146
|
gradually changing from low to high. Features are systematically organized in accordance with their importance.
|
47
|
-
These plots are generated by the function `_generate_shap_plot()`.
|
48
147
|
|
49
148
|
### Signs of High Risk
|
50
149
|
|
@@ -64,213 +163,70 @@ class SHAPGlobalImportance(Metric):
|
|
64
163
|
- High-dimensional data can convolute interpretations.
|
65
164
|
- Associating importance with tangible real-world impact still involves a certain degree of subjectivity.
|
66
165
|
"""
|
67
|
-
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
"sklearn",
|
73
|
-
"binary_classification",
|
74
|
-
"multiclass_classification",
|
75
|
-
"feature_importance",
|
76
|
-
"visualization",
|
77
|
-
]
|
78
|
-
default_params = {
|
79
|
-
"kernel_explainer_samples": 10,
|
80
|
-
"tree_or_linear_explainer_samples": 200,
|
81
|
-
"class_of_interest": None,
|
82
|
-
}
|
83
|
-
|
84
|
-
def _generate_shap_plot(self, type_, shap_values, x_test):
|
85
|
-
"""
|
86
|
-
Plots two types of SHAP global importance (SHAP).
|
87
|
-
:params type: mean, summary
|
88
|
-
:params shap_values: a matrix
|
89
|
-
:params x_test:
|
90
|
-
"""
|
91
|
-
plt.close("all")
|
92
|
-
|
93
|
-
# preserve styles
|
94
|
-
# mpl.rcParams["grid.color"] = "#CCC"
|
95
|
-
ax = plt.axes()
|
96
|
-
ax.set_facecolor("white")
|
97
|
-
|
98
|
-
summary_plot_extra_args = {}
|
99
|
-
if type_ == "mean":
|
100
|
-
# Calculate the mean absolute SHAP value for each feature
|
101
|
-
mean_abs_shap = np.abs(shap_values).mean(axis=0)
|
102
|
-
|
103
|
-
# Find the maximum mean absolute SHAP value
|
104
|
-
max_shap_value = np.max(mean_abs_shap)
|
105
|
-
|
106
|
-
# Normalize all SHAP values based on the top feature
|
107
|
-
shap_values = (
|
108
|
-
shap_values / max_shap_value * 100
|
109
|
-
) # scaling factor to make the top feature 100%
|
110
|
-
summary_plot_extra_args = {"plot_type": "bar"}
|
111
|
-
|
112
|
-
shap.summary_plot(
|
113
|
-
shap_values, x_test, show=False, **summary_plot_extra_args
|
114
|
-
)
|
115
|
-
|
116
|
-
# Customize the plot using matplotlib
|
117
|
-
plt.xlabel("Normalized SHAP Value (Percentage)", fontsize=13)
|
118
|
-
plt.ylabel("Features", fontsize=13)
|
119
|
-
plt.title("Normalized Feature Importance", fontsize=13)
|
120
|
-
else:
|
121
|
-
shap.summary_plot(
|
122
|
-
shap_values, x_test, show=False, **summary_plot_extra_args
|
123
|
-
)
|
124
|
-
|
125
|
-
figure = plt.gcf()
|
126
|
-
# avoid displaying on notebooks and clears the canvas for the next plot
|
127
|
-
plt.close()
|
128
|
-
|
129
|
-
return Figure(
|
130
|
-
for_object=self,
|
131
|
-
figure=figure,
|
132
|
-
key=f"shap:{type_}",
|
133
|
-
metadata={"type": type_},
|
166
|
+
if not isinstance(model, SKlearnModel) or isinstance(
|
167
|
+
model, (CatBoostModel, StatsModelsModel)
|
168
|
+
):
|
169
|
+
raise UnsupportedModelForSHAPError(
|
170
|
+
f"Model {model.class_} is not supported for SHAP importance."
|
134
171
|
)
|
135
172
|
|
136
|
-
|
137
|
-
|
138
|
-
|
139
|
-
|
140
|
-
|
141
|
-
|
142
|
-
|
143
|
-
|
144
|
-
model_class
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
158
|
-
explainer = shap.TreeExplainer(trained_model)
|
159
|
-
elif (
|
160
|
-
model_class == "LogisticRegression"
|
161
|
-
or model_class == "XGBRegressor"
|
162
|
-
or model_class == "LinearRegression"
|
163
|
-
or model_class == "LinearSVC"
|
164
|
-
):
|
165
|
-
explainer = shap.LinearExplainer(trained_model, self.inputs.dataset.x)
|
166
|
-
elif model_class == "SVC":
|
167
|
-
# KernelExplainer is slow so we use shap.sample to speed it up
|
168
|
-
explainer = shap.KernelExplainer(
|
169
|
-
trained_model.predict,
|
170
|
-
shap.sample(
|
171
|
-
self.inputs.dataset.x,
|
172
|
-
self.params["kernel_explainer_samples"],
|
173
|
-
),
|
174
|
-
)
|
175
|
-
else:
|
176
|
-
model_class = "<ExternalModel>" if model_class is None else model_class
|
177
|
-
raise UnsupportedModelForSHAPError(
|
178
|
-
f"Model {model_class} not supported for SHAP importance."
|
179
|
-
)
|
180
|
-
|
173
|
+
model_class = model.class_
|
174
|
+
|
175
|
+
# the shap library generates a bunch of annoying warnings that we don't care about
|
176
|
+
warnings.filterwarnings("ignore", category=UserWarning)
|
177
|
+
|
178
|
+
if (
|
179
|
+
model_class == "XGBClassifier"
|
180
|
+
or model_class == "RandomForestClassifier"
|
181
|
+
or model_class == "CatBoostClassifier"
|
182
|
+
or model_class == "DecisionTreeClassifier"
|
183
|
+
or model_class == "RandomForestRegressor"
|
184
|
+
or model_class == "GradientBoostingRegressor"
|
185
|
+
):
|
186
|
+
explainer = shap.TreeExplainer(model.model)
|
187
|
+
elif (
|
188
|
+
model_class == "LogisticRegression"
|
189
|
+
or model_class == "XGBRegressor"
|
190
|
+
or model_class == "LinearRegression"
|
191
|
+
or model_class == "LinearSVC"
|
192
|
+
):
|
193
|
+
explainer = shap.LinearExplainer(model.model, dataset.x)
|
194
|
+
elif model_class == "SVC":
|
181
195
|
# KernelExplainer is slow so we use shap.sample to speed it up
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
shap_values = explainer.shap_values(shap_sample)
|
196
|
-
|
197
|
-
# Select the SHAP values for the specified class (classification) or for the regression output.
|
198
|
-
class_of_interest = self.params["class_of_interest"]
|
199
|
-
shap_values = _select_shap_values(shap_values, class_of_interest)
|
200
|
-
|
201
|
-
figures = [
|
202
|
-
self._generate_shap_plot("mean", shap_values, shap_sample),
|
203
|
-
self._generate_shap_plot("summary", shap_values, shap_sample),
|
204
|
-
]
|
205
|
-
|
206
|
-
# restore warnings
|
207
|
-
warnings.filterwarnings("default", category=UserWarning)
|
208
|
-
|
209
|
-
return self.cache_results(figures=figures)
|
210
|
-
|
211
|
-
def test(self):
|
212
|
-
"""Unit Test for SHAP Global Importance Metric"""
|
213
|
-
# Verify that the result object is not None
|
214
|
-
assert self.result is not None
|
215
|
-
|
216
|
-
# Verify that there are exactly two figures in the figures list
|
217
|
-
assert len(self.result.figures) == 2
|
218
|
-
|
219
|
-
# Verify that each figure is an instance of Figure and has the correct metadata type
|
220
|
-
for fig_num, type_ in enumerate(["mean", "summary"], start=1):
|
221
|
-
assert isinstance(self.result.figures[fig_num - 1], Figure)
|
222
|
-
assert self.result.figures[fig_num - 1].metadata["type"] == type_
|
223
|
-
|
224
|
-
|
225
|
-
def _select_shap_values(shap_values, class_of_interest=None):
|
226
|
-
"""
|
227
|
-
Selects SHAP values for binary or multiclass classification. For regression models,
|
228
|
-
returns the SHAP values directly as there are no classes.
|
196
|
+
explainer = shap.KernelExplainer(
|
197
|
+
model.model.predict,
|
198
|
+
shap.sample(
|
199
|
+
dataset.x,
|
200
|
+
kernel_explainer_samples,
|
201
|
+
),
|
202
|
+
)
|
203
|
+
else:
|
204
|
+
model_class = "<ExternalModel>" if model_class is None else model_class
|
205
|
+
raise UnsupportedModelForSHAPError(
|
206
|
+
f"Model {model_class} not supported for SHAP importance."
|
207
|
+
)
|
229
208
|
|
230
|
-
|
231
|
-
|
232
|
-
|
233
|
-
|
234
|
-
|
235
|
-
|
209
|
+
# KernelExplainer is slow so we use shap.sample to speed it up
|
210
|
+
if isinstance(explainer, shap.KernelExplainer):
|
211
|
+
shap_sample = shap.sample(
|
212
|
+
dataset.x,
|
213
|
+
kernel_explainer_samples,
|
214
|
+
)
|
215
|
+
else:
|
216
|
+
shap_sample = dataset.x_df().sample(
|
217
|
+
min(
|
218
|
+
tree_or_linear_explainer_samples,
|
219
|
+
dataset.x_df().shape[0],
|
220
|
+
)
|
221
|
+
)
|
236
222
|
|
237
|
-
|
238
|
-
|
239
|
-
will assume binary classification and use class 1 by default.
|
223
|
+
shap_values = explainer.shap_values(shap_sample)
|
224
|
+
shap_values = select_shap_values(shap_values, class_of_interest)
|
240
225
|
|
241
|
-
|
242
|
-
|
243
|
-
numpy.ndarray
|
244
|
-
The SHAP values for the specified class (classification) or for the regression output.
|
226
|
+
# restore warnings
|
227
|
+
_warnings_filters.pop(0)
|
245
228
|
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
250
|
-
"""
|
251
|
-
# Check if we are dealing with a multiclass classification
|
252
|
-
if isinstance(shap_values, list):
|
253
|
-
num_classes = len(shap_values)
|
254
|
-
|
255
|
-
# Default to class 1 for binary classification
|
256
|
-
if num_classes == 2 and class_of_interest is None:
|
257
|
-
logger.info(
|
258
|
-
"Binary classification detected: using SHAP values for class 1 (positive class)."
|
259
|
-
)
|
260
|
-
return shap_values[1]
|
261
|
-
else:
|
262
|
-
# Multiclass classification: use the specified class_of_interest
|
263
|
-
if class_of_interest is not None and 0 <= class_of_interest < num_classes:
|
264
|
-
logger.info(
|
265
|
-
f"Multiclass classification: using SHAP values for class {class_of_interest}."
|
266
|
-
)
|
267
|
-
return shap_values[class_of_interest]
|
268
|
-
else:
|
269
|
-
raise ValueError(
|
270
|
-
f"Invalid class_of_interest: {class_of_interest}. Must be between 0 and {num_classes - 1}."
|
271
|
-
)
|
272
|
-
else:
|
273
|
-
# For regression, return the SHAP values as they are
|
274
|
-
# TODO: shap_values is always an array of all predictions, how is the if above supposed to work?
|
275
|
-
# logger.info("Regression model detected: returning SHAP values as-is.")
|
276
|
-
return shap_values
|
229
|
+
return (
|
230
|
+
generate_shap_plot("mean", shap_values, shap_sample),
|
231
|
+
generate_shap_plot("summary", shap_values, shap_sample),
|
232
|
+
)
|
@@ -2,23 +2,17 @@
|
|
2
2
|
# See the LICENSE file in the root of this repository for details.
|
3
3
|
# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
|
4
4
|
|
5
|
-
from dataclasses import dataclass
|
6
|
-
|
7
5
|
import matplotlib.pyplot as plt
|
8
6
|
import numpy as np
|
9
7
|
from sklearn.metrics import silhouette_samples, silhouette_score
|
10
8
|
|
11
|
-
from validmind
|
12
|
-
|
13
|
-
Metric,
|
14
|
-
ResultSummary,
|
15
|
-
ResultTable,
|
16
|
-
ResultTableMetadata,
|
17
|
-
)
|
9
|
+
from validmind import tags, tasks
|
10
|
+
from validmind.vm_models import VMDataset, VMModel
|
18
11
|
|
19
12
|
|
20
|
-
@
|
21
|
-
|
13
|
+
@tags("sklearn", "model_performance")
|
14
|
+
@tasks("clustering")
|
15
|
+
def SilhouettePlot(model: VMModel, dataset: VMDataset):
|
22
16
|
"""
|
23
17
|
Calculates and visualizes Silhouette Score, assessing the degree of data point suitability to its cluster in ML
|
24
18
|
models.
|
@@ -65,93 +59,54 @@ class SilhouettePlot(Metric):
|
|
65
59
|
assignment nuances, so potentially relevant details may be omitted.
|
66
60
|
- Computationally expensive for large datasets, as it requires pairwise distance computations.
|
67
61
|
"""
|
68
|
-
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
"
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
|
79
|
-
|
80
|
-
|
81
|
-
|
82
|
-
|
83
|
-
|
84
|
-
|
85
|
-
|
86
|
-
|
87
|
-
|
88
|
-
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
|
93
|
-
|
94
|
-
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
size_cluster_i = ith_cluster_silhouette_values.shape[0]
|
100
|
-
y_upper = y_lower + size_cluster_i
|
101
|
-
color = plt.cm.viridis(float(i) / num_clusters)
|
102
|
-
ax.fill_betweenx(
|
103
|
-
np.arange(y_lower, y_upper),
|
104
|
-
0,
|
105
|
-
ith_cluster_silhouette_values,
|
106
|
-
facecolor=color,
|
107
|
-
edgecolor=color,
|
108
|
-
alpha=0.7,
|
109
|
-
)
|
110
|
-
|
111
|
-
# Label the silhouette plots with their cluster numbers at the middle
|
112
|
-
ax.text(-0.05, y_lower + 0.5 * size_cluster_i, str(i))
|
113
|
-
# Compute the new y_lower for the next plot
|
114
|
-
y_lower = y_upper + 10
|
115
|
-
|
116
|
-
ax.set_title("Silhouette Plot for Clusters")
|
117
|
-
ax.set_xlabel("Silhouette Coefficient Values")
|
118
|
-
ax.set_ylabel("Cluster Label")
|
119
|
-
|
120
|
-
# The vertical line represents the average silhouette score
|
121
|
-
ax.axvline(x=silhouette_avg, color="red", linestyle="--")
|
122
|
-
|
123
|
-
figures = [
|
124
|
-
Figure(
|
125
|
-
for_object=self,
|
126
|
-
key=self.key,
|
127
|
-
figure=fig,
|
128
|
-
)
|
129
|
-
]
|
130
|
-
# Close the figure to prevent it from displaying
|
131
|
-
plt.close(fig)
|
132
|
-
|
133
|
-
return self.cache_results(
|
134
|
-
metric_value={
|
135
|
-
"silhouette_score": {
|
136
|
-
"silhouette_score": silhouette_avg,
|
137
|
-
},
|
138
|
-
},
|
139
|
-
figures=figures,
|
62
|
+
y_pred = dataset.y_pred(model)
|
63
|
+
|
64
|
+
silhouette_avg = silhouette_score(
|
65
|
+
X=dataset.x,
|
66
|
+
labels=y_pred,
|
67
|
+
metric="euclidean",
|
68
|
+
)
|
69
|
+
|
70
|
+
# Calculate silhouette coefficients for each data point
|
71
|
+
sample_silhouette_values = silhouette_samples(dataset.x, y_pred)
|
72
|
+
# Create a silhouette plot
|
73
|
+
fig, ax = plt.subplots()
|
74
|
+
|
75
|
+
y_lower = 10
|
76
|
+
num_clusters = len(np.unique(y_pred))
|
77
|
+
for i in range(num_clusters):
|
78
|
+
# Aggregate the silhouette scores for samples belonging to cluster i
|
79
|
+
ith_cluster_silhouette_values = sample_silhouette_values[y_pred == i]
|
80
|
+
ith_cluster_silhouette_values.sort()
|
81
|
+
|
82
|
+
size_cluster_i = ith_cluster_silhouette_values.shape[0]
|
83
|
+
y_upper = y_lower + size_cluster_i
|
84
|
+
color = plt.cm.viridis(float(i) / num_clusters)
|
85
|
+
ax.fill_betweenx(
|
86
|
+
np.arange(y_lower, y_upper),
|
87
|
+
0,
|
88
|
+
ith_cluster_silhouette_values,
|
89
|
+
facecolor=color,
|
90
|
+
edgecolor=color,
|
91
|
+
alpha=0.7,
|
140
92
|
)
|
141
93
|
|
142
|
-
|
143
|
-
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
157
|
-
|
94
|
+
# Label the silhouette plots with their cluster numbers at the middle
|
95
|
+
ax.text(-0.05, y_lower + 0.5 * size_cluster_i, str(i))
|
96
|
+
# Compute the new y_lower for the next plot
|
97
|
+
y_lower = y_upper + 10
|
98
|
+
|
99
|
+
ax.set_title("Silhouette Plot for Clusters")
|
100
|
+
ax.set_xlabel("Silhouette Coefficient Values")
|
101
|
+
ax.set_ylabel("Cluster Label")
|
102
|
+
|
103
|
+
# The vertical line represents the average silhouette score
|
104
|
+
ax.axvline(x=silhouette_avg, color="red", linestyle="--")
|
105
|
+
|
106
|
+
plt.close()
|
107
|
+
|
108
|
+
return [
|
109
|
+
{
|
110
|
+
"Silhouette Score": silhouette_avg,
|
111
|
+
},
|
112
|
+
], fig
|