validmind 2.5.25__py3-none-any.whl → 2.6.8__py3-none-any.whl

This diff compares the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
Files changed (198)
  1. validmind/__init__.py +8 -17
  2. validmind/__version__.py +1 -1
  3. validmind/ai/test_descriptions.py +66 -85
  4. validmind/ai/test_result_description/context.py +2 -2
  5. validmind/ai/utils.py +26 -1
  6. validmind/api_client.py +43 -79
  7. validmind/client.py +5 -7
  8. validmind/client_config.py +1 -1
  9. validmind/datasets/__init__.py +1 -1
  10. validmind/datasets/classification/customer_churn.py +7 -5
  11. validmind/datasets/nlp/__init__.py +2 -2
  12. validmind/errors.py +6 -10
  13. validmind/html_templates/content_blocks.py +18 -16
  14. validmind/logging.py +21 -16
  15. validmind/tests/__init__.py +28 -5
  16. validmind/tests/__types__.py +186 -170
  17. validmind/tests/_store.py +7 -21
  18. validmind/tests/comparison.py +362 -0
  19. validmind/tests/data_validation/ACFandPACFPlot.py +44 -73
  20. validmind/tests/data_validation/ADF.py +49 -83
  21. validmind/tests/data_validation/AutoAR.py +59 -96
  22. validmind/tests/data_validation/AutoMA.py +59 -96
  23. validmind/tests/data_validation/AutoStationarity.py +66 -114
  24. validmind/tests/data_validation/ClassImbalance.py +48 -117
  25. validmind/tests/data_validation/DatasetDescription.py +180 -209
  26. validmind/tests/data_validation/DatasetSplit.py +50 -75
  27. validmind/tests/data_validation/DescriptiveStatistics.py +59 -85
  28. validmind/tests/data_validation/{DFGLSArch.py → DickeyFullerGLS.py} +44 -76
  29. validmind/tests/data_validation/Duplicates.py +21 -90
  30. validmind/tests/data_validation/EngleGrangerCoint.py +53 -75
  31. validmind/tests/data_validation/HighCardinality.py +32 -80
  32. validmind/tests/data_validation/HighPearsonCorrelation.py +29 -97
  33. validmind/tests/data_validation/IQROutliersBarPlot.py +63 -94
  34. validmind/tests/data_validation/IQROutliersTable.py +40 -80
  35. validmind/tests/data_validation/IsolationForestOutliers.py +41 -63
  36. validmind/tests/data_validation/KPSS.py +33 -81
  37. validmind/tests/data_validation/LaggedCorrelationHeatmap.py +47 -95
  38. validmind/tests/data_validation/MissingValues.py +17 -58
  39. validmind/tests/data_validation/MissingValuesBarPlot.py +61 -87
  40. validmind/tests/data_validation/PhillipsPerronArch.py +56 -79
  41. validmind/tests/data_validation/RollingStatsPlot.py +50 -81
  42. validmind/tests/data_validation/SeasonalDecompose.py +102 -184
  43. validmind/tests/data_validation/Skewness.py +27 -64
  44. validmind/tests/data_validation/SpreadPlot.py +34 -57
  45. validmind/tests/data_validation/TabularCategoricalBarPlots.py +46 -65
  46. validmind/tests/data_validation/TabularDateTimeHistograms.py +23 -45
  47. validmind/tests/data_validation/TabularNumericalHistograms.py +27 -46
  48. validmind/tests/data_validation/TargetRateBarPlots.py +54 -93
  49. validmind/tests/data_validation/TimeSeriesFrequency.py +48 -133
  50. validmind/tests/data_validation/TimeSeriesHistogram.py +24 -3
  51. validmind/tests/data_validation/TimeSeriesLinePlot.py +29 -47
  52. validmind/tests/data_validation/TimeSeriesMissingValues.py +59 -135
  53. validmind/tests/data_validation/TimeSeriesOutliers.py +54 -171
  54. validmind/tests/data_validation/TooManyZeroValues.py +21 -70
  55. validmind/tests/data_validation/UniqueRows.py +23 -62
  56. validmind/tests/data_validation/WOEBinPlots.py +83 -109
  57. validmind/tests/data_validation/WOEBinTable.py +28 -69
  58. validmind/tests/data_validation/ZivotAndrewsArch.py +33 -75
  59. validmind/tests/data_validation/nlp/CommonWords.py +49 -57
  60. validmind/tests/data_validation/nlp/Hashtags.py +27 -49
  61. validmind/tests/data_validation/nlp/LanguageDetection.py +7 -13
  62. validmind/tests/data_validation/nlp/Mentions.py +32 -63
  63. validmind/tests/data_validation/nlp/PolarityAndSubjectivity.py +89 -14
  64. validmind/tests/data_validation/nlp/Punctuations.py +63 -47
  65. validmind/tests/data_validation/nlp/Sentiment.py +4 -0
  66. validmind/tests/data_validation/nlp/StopWords.py +62 -91
  67. validmind/tests/data_validation/nlp/TextDescription.py +116 -159
  68. validmind/tests/data_validation/nlp/Toxicity.py +12 -4
  69. validmind/tests/decorator.py +33 -242
  70. validmind/tests/load.py +212 -153
  71. validmind/tests/model_validation/BertScore.py +13 -7
  72. validmind/tests/model_validation/BleuScore.py +4 -0
  73. validmind/tests/model_validation/ClusterSizeDistribution.py +24 -47
  74. validmind/tests/model_validation/ContextualRecall.py +3 -0
  75. validmind/tests/model_validation/FeaturesAUC.py +43 -74
  76. validmind/tests/model_validation/MeteorScore.py +3 -0
  77. validmind/tests/model_validation/RegardScore.py +5 -1
  78. validmind/tests/model_validation/RegressionResidualsPlot.py +54 -75
  79. validmind/tests/model_validation/embeddings/ClusterDistribution.py +10 -33
  80. validmind/tests/model_validation/embeddings/CosineSimilarityDistribution.py +11 -29
  81. validmind/tests/model_validation/embeddings/DescriptiveAnalytics.py +19 -31
  82. validmind/tests/model_validation/embeddings/EmbeddingsVisualization2D.py +40 -49
  83. validmind/tests/model_validation/embeddings/StabilityAnalysisKeyword.py +29 -15
  84. validmind/tests/model_validation/embeddings/StabilityAnalysisRandomNoise.py +25 -11
  85. validmind/tests/model_validation/embeddings/StabilityAnalysisSynonyms.py +28 -13
  86. validmind/tests/model_validation/embeddings/StabilityAnalysisTranslation.py +67 -38
  87. validmind/tests/model_validation/embeddings/utils.py +53 -0
  88. validmind/tests/model_validation/ragas/AnswerCorrectness.py +37 -32
  89. validmind/tests/model_validation/ragas/{AspectCritique.py → AspectCritic.py} +33 -27
  90. validmind/tests/model_validation/ragas/ContextEntityRecall.py +44 -41
  91. validmind/tests/model_validation/ragas/ContextPrecision.py +40 -35
  92. validmind/tests/model_validation/ragas/ContextPrecisionWithoutReference.py +133 -0
  93. validmind/tests/model_validation/ragas/ContextRecall.py +40 -35
  94. validmind/tests/model_validation/ragas/Faithfulness.py +42 -30
  95. validmind/tests/model_validation/ragas/NoiseSensitivity.py +59 -35
  96. validmind/tests/model_validation/ragas/{AnswerRelevance.py → ResponseRelevancy.py} +52 -41
  97. validmind/tests/model_validation/ragas/{AnswerSimilarity.py → SemanticSimilarity.py} +39 -34
  98. validmind/tests/model_validation/sklearn/AdjustedMutualInformation.py +13 -16
  99. validmind/tests/model_validation/sklearn/AdjustedRandIndex.py +13 -16
  100. validmind/tests/model_validation/sklearn/ClassifierPerformance.py +51 -89
  101. validmind/tests/model_validation/sklearn/ClusterCosineSimilarity.py +31 -61
  102. validmind/tests/model_validation/sklearn/ClusterPerformanceMetrics.py +118 -83
  103. validmind/tests/model_validation/sklearn/CompletenessScore.py +13 -16
  104. validmind/tests/model_validation/sklearn/ConfusionMatrix.py +62 -94
  105. validmind/tests/model_validation/sklearn/FeatureImportance.py +7 -8
  106. validmind/tests/model_validation/sklearn/FowlkesMallowsScore.py +12 -15
  107. validmind/tests/model_validation/sklearn/HomogeneityScore.py +12 -15
  108. validmind/tests/model_validation/sklearn/HyperParametersTuning.py +23 -53
  109. validmind/tests/model_validation/sklearn/KMeansClustersOptimization.py +60 -74
  110. validmind/tests/model_validation/sklearn/MinimumAccuracy.py +16 -84
  111. validmind/tests/model_validation/sklearn/MinimumF1Score.py +22 -72
  112. validmind/tests/model_validation/sklearn/MinimumROCAUCScore.py +29 -78
  113. validmind/tests/model_validation/sklearn/ModelsPerformanceComparison.py +52 -82
  114. validmind/tests/model_validation/sklearn/OverfitDiagnosis.py +51 -145
  115. validmind/tests/model_validation/sklearn/PermutationFeatureImportance.py +60 -78
  116. validmind/tests/model_validation/sklearn/PopulationStabilityIndex.py +130 -172
  117. validmind/tests/model_validation/sklearn/PrecisionRecallCurve.py +26 -55
  118. validmind/tests/model_validation/sklearn/ROCCurve.py +43 -77
  119. validmind/tests/model_validation/sklearn/RegressionPerformance.py +41 -94
  120. validmind/tests/model_validation/sklearn/RobustnessDiagnosis.py +47 -136
  121. validmind/tests/model_validation/sklearn/SHAPGlobalImportance.py +164 -208
  122. validmind/tests/model_validation/sklearn/SilhouettePlot.py +54 -99
  123. validmind/tests/model_validation/sklearn/TrainingTestDegradation.py +50 -124
  124. validmind/tests/model_validation/sklearn/VMeasure.py +12 -15
  125. validmind/tests/model_validation/sklearn/WeakspotsDiagnosis.py +225 -281
  126. validmind/tests/model_validation/statsmodels/AutoARIMA.py +40 -45
  127. validmind/tests/model_validation/statsmodels/KolmogorovSmirnov.py +22 -47
  128. validmind/tests/model_validation/statsmodels/Lilliefors.py +17 -28
  129. validmind/tests/model_validation/statsmodels/RegressionFeatureSignificance.py +37 -81
  130. validmind/tests/model_validation/statsmodels/RegressionModelForecastPlot.py +37 -105
  131. validmind/tests/model_validation/statsmodels/RegressionModelForecastPlotLevels.py +62 -166
  132. validmind/tests/model_validation/statsmodels/RegressionModelSensitivityPlot.py +57 -119
  133. validmind/tests/model_validation/statsmodels/RegressionModelSummary.py +20 -57
  134. validmind/tests/model_validation/statsmodels/RegressionPermutationFeatureImportance.py +47 -80
  135. validmind/tests/ongoing_monitoring/PredictionCorrelation.py +2 -0
  136. validmind/tests/ongoing_monitoring/TargetPredictionDistributionPlot.py +4 -2
  137. validmind/tests/output.py +120 -0
  138. validmind/tests/prompt_validation/Bias.py +55 -98
  139. validmind/tests/prompt_validation/Clarity.py +56 -99
  140. validmind/tests/prompt_validation/Conciseness.py +63 -101
  141. validmind/tests/prompt_validation/Delimitation.py +48 -89
  142. validmind/tests/prompt_validation/NegativeInstruction.py +62 -96
  143. validmind/tests/prompt_validation/Robustness.py +80 -121
  144. validmind/tests/prompt_validation/Specificity.py +61 -95
  145. validmind/tests/prompt_validation/ai_powered_test.py +2 -2
  146. validmind/tests/run.py +314 -496
  147. validmind/tests/test_providers.py +109 -79
  148. validmind/tests/utils.py +91 -0
  149. validmind/unit_metrics/__init__.py +16 -155
  150. validmind/unit_metrics/classification/F1.py +1 -0
  151. validmind/unit_metrics/classification/Precision.py +1 -0
  152. validmind/unit_metrics/classification/ROC_AUC.py +1 -0
  153. validmind/unit_metrics/classification/Recall.py +1 -0
  154. validmind/unit_metrics/regression/AdjustedRSquaredScore.py +1 -0
  155. validmind/unit_metrics/regression/GiniCoefficient.py +1 -0
  156. validmind/unit_metrics/regression/HuberLoss.py +1 -0
  157. validmind/unit_metrics/regression/KolmogorovSmirnovStatistic.py +1 -0
  158. validmind/unit_metrics/regression/MeanAbsoluteError.py +1 -0
  159. validmind/unit_metrics/regression/MeanAbsolutePercentageError.py +1 -0
  160. validmind/unit_metrics/regression/MeanBiasDeviation.py +1 -0
  161. validmind/unit_metrics/regression/MeanSquaredError.py +1 -0
  162. validmind/unit_metrics/regression/QuantileLoss.py +1 -0
  163. validmind/unit_metrics/regression/RSquaredScore.py +2 -1
  164. validmind/unit_metrics/regression/RootMeanSquaredError.py +1 -0
  165. validmind/utils.py +66 -17
  166. validmind/vm_models/__init__.py +2 -17
  167. validmind/vm_models/dataset/dataset.py +31 -4
  168. validmind/vm_models/figure.py +7 -37
  169. validmind/vm_models/model.py +3 -0
  170. validmind/vm_models/result/__init__.py +7 -0
  171. validmind/vm_models/result/result.jinja +21 -0
  172. validmind/vm_models/result/result.py +337 -0
  173. validmind/vm_models/result/utils.py +160 -0
  174. validmind/vm_models/test_suite/runner.py +16 -54
  175. validmind/vm_models/test_suite/summary.py +3 -3
  176. validmind/vm_models/test_suite/test.py +43 -77
  177. validmind/vm_models/test_suite/test_suite.py +8 -40
  178. validmind-2.6.8.dist-info/METADATA +137 -0
  179. {validmind-2.5.25.dist-info → validmind-2.6.8.dist-info}/RECORD +182 -189
  180. validmind/tests/data_validation/AutoSeasonality.py +0 -190
  181. validmind/tests/metadata.py +0 -59
  182. validmind/tests/model_validation/embeddings/StabilityAnalysis.py +0 -176
  183. validmind/tests/model_validation/ragas/ContextUtilization.py +0 -161
  184. validmind/tests/model_validation/sklearn/ClusterPerformance.py +0 -80
  185. validmind/unit_metrics/composite.py +0 -238
  186. validmind/vm_models/test/metric.py +0 -98
  187. validmind/vm_models/test/metric_result.py +0 -61
  188. validmind/vm_models/test/output_template.py +0 -55
  189. validmind/vm_models/test/result_summary.py +0 -76
  190. validmind/vm_models/test/result_wrapper.py +0 -488
  191. validmind/vm_models/test/test.py +0 -103
  192. validmind/vm_models/test/threshold_test.py +0 -106
  193. validmind/vm_models/test/threshold_test_result.py +0 -75
  194. validmind/vm_models/test_context.py +0 -259
  195. validmind-2.5.25.dist-info/METADATA +0 -118
  196. {validmind-2.5.25.dist-info → validmind-2.6.8.dist-info}/LICENSE +0 -0
  197. {validmind-2.5.25.dist-info → validmind-2.6.8.dist-info}/WHEEL +0 -0
  198. {validmind-2.5.25.dist-info → validmind-2.6.8.dist-info}/entry_points.txt +0 -0
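The common thread across these 198 files, visible in the two diffs reproduced below, is a migration from class-based tests (subclasses of `Metric` with `run()`/`summary()` methods and `self.cache_results(...)`) to plain functions registered with the `@tags` and `@tasks` decorators, which receive `VMDataset`/`VMModel` inputs as arguments and return their tables and figures directly. A minimal sketch of the new test shape; the test name, body, and `threshold` parameter are illustrative and not taken from the package:

from validmind import tags, tasks
from validmind.vm_models import VMDataset, VMModel


@tags("example", "model_performance")
@tasks("classification")
def ExampleAccuracyCheck(model: VMModel, dataset: VMDataset, threshold: float = 0.8):
    """Checks that model accuracy on the dataset meets a minimum threshold."""
    # y_pred(model) returns the predictions previously assigned to this dataset
    accuracy = (dataset.y == dataset.y_pred(model)).mean()

    # results are returned directly (tables, figures, or a tuple of both)
    # instead of being cached on a Metric instance
    return [{"Accuracy": float(accuracy), "Passed": bool(accuracy >= threshold)}]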
validmind/tests/model_validation/sklearn/SHAPGlobalImportance.py

@@ -3,22 +3,122 @@
 # SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial

 import warnings
-from dataclasses import dataclass
+from warnings import filters as _warnings_filters

 import matplotlib.pyplot as plt
 import numpy as np
 import shap

+from validmind import tags, tasks
 from validmind.errors import UnsupportedModelForSHAPError
 from validmind.logging import get_logger
 from validmind.models import CatBoostModel, SKlearnModel, StatsModelsModel
-from validmind.vm_models import Figure, Metric
+from validmind.vm_models import VMDataset, VMModel

 logger = get_logger(__name__)


-@dataclass
-class SHAPGlobalImportance(Metric):
+def select_shap_values(shap_values, class_of_interest):
+    """Selects SHAP values for binary or multiclass classification.
+
+    For regression models, returns the SHAP values directly as there are no classes.
+
+    Args:
+        shap_values: The SHAP values returned by the SHAP explainer. For multiclass
+            classification, this will be a list where each element corresponds to a class.
+            For regression, this will be a single array of SHAP values.
+        class_of_interest: The class index for which to retrieve SHAP values. If None
+            (default), the function will assume binary classification and use class 1
+            by default.
+
+    Returns:
+        The SHAP values for the specified class (classification) or for the regression
+        output.
+
+    Raises:
+        ValueError: If class_of_interest is specified and is out of bounds for the
+            number of classes.
+    """
+    if not isinstance(shap_values, list):
+        # For regression, return the SHAP values as they are
+        # TODO: shap_values is always an array of all predictions, how is the if above supposed to work?
+        # logger.info("Returning SHAP values as-is.")
+        return shap_values
+
+    num_classes = len(shap_values)
+
+    # Default to class 1 for binary classification where no class is specified
+    if num_classes == 2 and class_of_interest is None:
+        logger.debug("Using SHAP values for class 1 (positive class).")
+        return shap_values[1]
+
+    # Otherwise, use the specified class_of_interest
+    if (
+        class_of_interest is None
+        or class_of_interest < 0
+        or class_of_interest >= num_classes
+    ):
+        raise ValueError(
+            f"Invalid class_of_interest: {class_of_interest}. Must be between 0 and {num_classes - 1}."
+        )
+
+    logger.debug(f"Using SHAP values for class {class_of_interest}.")
+    return shap_values[class_of_interest]
+
+
+def generate_shap_plot(type_, shap_values, x_test):
+    """Plots two types of SHAP global importance (SHAP).
+
+    Args:
+        type_: The type of SHAP plot to generate. Must be "mean" or "summary".
+        shap_values: The SHAP values to plot.
+        x_test: The test data used to generate the SHAP values.
+
+    Returns:
+        The generated plot.
+    """
+    ax = plt.axes()
+    ax.set_facecolor("white")
+
+    if type_ == "mean":
+        # Calculate the mean absolute SHAP value for each feature
+        mean_abs_shap = np.abs(shap_values).mean(axis=0)
+        # Find the maximum mean absolute SHAP value
+        max_shap_value = np.max(mean_abs_shap)
+        # Normalize all SHAP values based on the top feature
+        shap_values = shap_values / max_shap_value * 100
+
+        shap.summary_plot(shap_values, x_test, show=False, plot_type="bar")
+
+        # Customize the plot using matplotlib
+        plt.xlabel("Normalized SHAP Value (Percentage)", fontsize=13)
+        plt.ylabel("Features", fontsize=13)
+        plt.title("Normalized Feature Importance", fontsize=13)
+    else:
+        shap.summary_plot(shap_values, x_test, show=False)
+
+    fig = plt.gcf()
+
+    plt.close()
+
+    return fig
+
+
+@tags(
+    "sklearn",
+    "binary_classification",
+    "multiclass_classification",
+    "feature_importance",
+    "visualization",
+)
+@tasks("classification", "text_classification")
+def SHAPGlobalImportance(
+    model: VMModel,
+    dataset: VMDataset,
+    kernel_explainer_samples: int = 10,
+    tree_or_linear_explainer_samples: int = 200,
+    class_of_interest: int = None,
+):
     """
     Evaluates and visualizes global feature importance using SHAP values for model explanation and risk identification.

@@ -44,7 +144,6 @@ class SHAPGlobalImportance(Metric):
     represents a Shapley value for a certain feature in a specific case. The vertical axis is denoted by the feature
     whereas the horizontal one corresponds to the Shapley value. A color gradient indicates the value of the feature,
     gradually changing from low to high. Features are systematically organized in accordance with their importance.
-    These plots are generated by the function `_generate_shap_plot()`.

     ### Signs of High Risk

@@ -64,213 +163,70 @@ class SHAPGlobalImportance(Metric):
     - High-dimensional data can convolute interpretations.
     - Associating importance with tangible real-world impact still involves a certain degree of subjectivity.
     """
-
-    name = "shap"
-    required_inputs = ["model", "dataset"]
-    tasks = ["classification", "text_classification"]
-    tags = [
-        "sklearn",
-        "binary_classification",
-        "multiclass_classification",
-        "feature_importance",
-        "visualization",
-    ]
-    default_params = {
-        "kernel_explainer_samples": 10,
-        "tree_or_linear_explainer_samples": 200,
-        "class_of_interest": None,
-    }
-
-    def _generate_shap_plot(self, type_, shap_values, x_test):
-        """
-        Plots two types of SHAP global importance (SHAP).
-        :params type: mean, summary
-        :params shap_values: a matrix
-        :params x_test:
-        """
-        plt.close("all")
-
-        # preserve styles
-        # mpl.rcParams["grid.color"] = "#CCC"
-        ax = plt.axes()
-        ax.set_facecolor("white")
-
-        summary_plot_extra_args = {}
-        if type_ == "mean":
-            # Calculate the mean absolute SHAP value for each feature
-            mean_abs_shap = np.abs(shap_values).mean(axis=0)
-
-            # Find the maximum mean absolute SHAP value
-            max_shap_value = np.max(mean_abs_shap)
-
-            # Normalize all SHAP values based on the top feature
-            shap_values = (
-                shap_values / max_shap_value * 100
-            )  # scaling factor to make the top feature 100%
-            summary_plot_extra_args = {"plot_type": "bar"}
-
-            shap.summary_plot(
-                shap_values, x_test, show=False, **summary_plot_extra_args
-            )
-
-            # Customize the plot using matplotlib
-            plt.xlabel("Normalized SHAP Value (Percentage)", fontsize=13)
-            plt.ylabel("Features", fontsize=13)
-            plt.title("Normalized Feature Importance", fontsize=13)
-        else:
-            shap.summary_plot(
-                shap_values, x_test, show=False, **summary_plot_extra_args
-            )
-
-        figure = plt.gcf()
-        # avoid displaying on notebooks and clears the canvas for the next plot
-        plt.close()
-
-        return Figure(
-            for_object=self,
-            figure=figure,
-            key=f"shap:{type_}",
-            metadata={"type": type_},
+    if not isinstance(model, SKlearnModel) or isinstance(
+        model, (CatBoostModel, StatsModelsModel)
+    ):
+        raise UnsupportedModelForSHAPError(
+            f"Model {model.class_} is not supported for SHAP importance."
         )

-    def run(self):
-        if not isinstance(self.inputs.model, SKlearnModel) or isinstance(
-            self.inputs.model, (CatBoostModel, StatsModelsModel)
-        ):
-            logger.info(f"Skiping SHAP for {self.inputs.model.library} models")
-            return
-
-        trained_model = self.inputs.model.model
-        model_class = self.inputs.model.class_
-
-        # the shap library generates a bunch of annoying warnings that we don't care about
-        warnings.filterwarnings("ignore", category=UserWarning)
-
-        # Any tree based model can go here
-        if (
-            model_class == "XGBClassifier"
-            or model_class == "RandomForestClassifier"
-            or model_class == "CatBoostClassifier"
-            or model_class == "DecisionTreeClassifier"
-            or model_class == "RandomForestRegressor"
-            or model_class == "GradientBoostingRegressor"
-        ):
-            explainer = shap.TreeExplainer(trained_model)
-        elif (
-            model_class == "LogisticRegression"
-            or model_class == "XGBRegressor"
-            or model_class == "LinearRegression"
-            or model_class == "LinearSVC"
-        ):
-            explainer = shap.LinearExplainer(trained_model, self.inputs.dataset.x)
-        elif model_class == "SVC":
-            # KernelExplainer is slow so we use shap.sample to speed it up
-            explainer = shap.KernelExplainer(
-                trained_model.predict,
-                shap.sample(
-                    self.inputs.dataset.x,
-                    self.params["kernel_explainer_samples"],
-                ),
-            )
-        else:
-            model_class = "<ExternalModel>" if model_class is None else model_class
-            raise UnsupportedModelForSHAPError(
-                f"Model {model_class} not supported for SHAP importance."
-            )
-
+    model_class = model.class_
+
+    # the shap library generates a bunch of annoying warnings that we don't care about
+    warnings.filterwarnings("ignore", category=UserWarning)
+
+    if (
+        model_class == "XGBClassifier"
+        or model_class == "RandomForestClassifier"
+        or model_class == "CatBoostClassifier"
+        or model_class == "DecisionTreeClassifier"
+        or model_class == "RandomForestRegressor"
+        or model_class == "GradientBoostingRegressor"
+    ):
+        explainer = shap.TreeExplainer(model.model)
+    elif (
+        model_class == "LogisticRegression"
+        or model_class == "XGBRegressor"
+        or model_class == "LinearRegression"
+        or model_class == "LinearSVC"
+    ):
+        explainer = shap.LinearExplainer(model.model, dataset.x)
+    elif model_class == "SVC":
         # KernelExplainer is slow so we use shap.sample to speed it up
-        if isinstance(explainer, shap.KernelExplainer):
-            shap_sample = shap.sample(
-                self.inputs.dataset.x_df(),
-                self.params["kernel_explainer_samples"],
-            )
-        else:
-            shap_sample = self.inputs.dataset.x_df().sample(
-                min(
-                    self.params["tree_or_linear_explainer_samples"],
-                    self.inputs.dataset.x_df().shape[0],
-                )
-            )
-
-        shap_values = explainer.shap_values(shap_sample)
-
-        # Select the SHAP values for the specified class (classification) or for the regression output.
-        class_of_interest = self.params["class_of_interest"]
-        shap_values = _select_shap_values(shap_values, class_of_interest)
-
-        figures = [
-            self._generate_shap_plot("mean", shap_values, shap_sample),
-            self._generate_shap_plot("summary", shap_values, shap_sample),
-        ]
-
-        # restore warnings
-        warnings.filterwarnings("default", category=UserWarning)
-
-        return self.cache_results(figures=figures)
-
-    def test(self):
-        """Unit Test for SHAP Global Importance Metric"""
-        # Verify that the result object is not None
-        assert self.result is not None
-
-        # Verify that there are exactly two figures in the figures list
-        assert len(self.result.figures) == 2
-
-        # Verify that each figure is an instance of Figure and has the correct metadata type
-        for fig_num, type_ in enumerate(["mean", "summary"], start=1):
-            assert isinstance(self.result.figures[fig_num - 1], Figure)
-            assert self.result.figures[fig_num - 1].metadata["type"] == type_
-
-
-def _select_shap_values(shap_values, class_of_interest=None):
-    """
-    Selects SHAP values for binary or multiclass classification. For regression models,
-    returns the SHAP values directly as there are no classes.
+        explainer = shap.KernelExplainer(
+            model.model.predict,
+            shap.sample(
+                dataset.x,
+                kernel_explainer_samples,
+            ),
+        )
+    else:
+        model_class = "<ExternalModel>" if model_class is None else model_class
+        raise UnsupportedModelForSHAPError(
+            f"Model {model_class} not supported for SHAP importance."
+        )

-    Parameters:
-    -----------
-    shap_values : list or numpy.ndarray
-        The SHAP values returned by the SHAP explainer. For multiclass classification,
-        this will be a list where each element corresponds to a class. For regression,
-        this will be a single array of SHAP values.
+    # KernelExplainer is slow so we use shap.sample to speed it up
+    if isinstance(explainer, shap.KernelExplainer):
+        shap_sample = shap.sample(
+            dataset.x,
+            kernel_explainer_samples,
+        )
+    else:
+        shap_sample = dataset.x_df().sample(
+            min(
+                tree_or_linear_explainer_samples,
+                dataset.x_df().shape[0],
+            )
+        )

-    class_of_interest : int, optional
-        The class index for which to retrieve SHAP values. If None (default), the function
-        will assume binary classification and use class 1 by default.
+    shap_values = explainer.shap_values(shap_sample)
+    shap_values = select_shap_values(shap_values, class_of_interest)

-    Returns:
-    --------
-    numpy.ndarray
-        The SHAP values for the specified class (classification) or for the regression output.
+    # restore warnings
+    _warnings_filters.pop(0)

-    Raises:
-    -------
-    ValueError
-        If class_of_interest is specified and is out of bounds for the number of classes.
-    """
-    # Check if we are dealing with a multiclass classification
-    if isinstance(shap_values, list):
-        num_classes = len(shap_values)
-
-        # Default to class 1 for binary classification
-        if num_classes == 2 and class_of_interest is None:
-            logger.info(
-                "Binary classification detected: using SHAP values for class 1 (positive class)."
-            )
-            return shap_values[1]
-        else:
-            # Multiclass classification: use the specified class_of_interest
-            if class_of_interest is not None and 0 <= class_of_interest < num_classes:
-                logger.info(
-                    f"Multiclass classification: using SHAP values for class {class_of_interest}."
-                )
-                return shap_values[class_of_interest]
-            else:
-                raise ValueError(
-                    f"Invalid class_of_interest: {class_of_interest}. Must be between 0 and {num_classes - 1}."
-                )
-    else:
-        # For regression, return the SHAP values as they are
-        # TODO: shap_values is always an array of all predictions, how is the if above supposed to work?
-        # logger.info("Regression model detected: returning SHAP values as-is.")
-        return shap_values
+    return (
+        generate_shap_plot("mean", shap_values, shap_sample),
+        generate_shap_plot("summary", shap_values, shap_sample),
+    )
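Two behavioral changes worth noting in this rewrite: unsupported models now raise `UnsupportedModelForSHAPError` up front instead of being skipped with a log message, and the warning cleanup pops the entry that `warnings.filterwarnings("ignore", ...)` prepended to `warnings.filters`, instead of layering a new "default" rule on top as the old code did. The helper `select_shap_values` is also now a public module-level function whose behavior can be exercised directly; a small sketch based on the code above:

import numpy as np

from validmind.tests.model_validation.sklearn.SHAPGlobalImportance import (
    select_shap_values,
)

# Binary classification: a list of two per-class SHAP arrays; with
# class_of_interest=None the helper defaults to class 1 (the positive class)
per_class = [np.zeros((5, 3)), np.ones((5, 3))]
assert select_shap_values(per_class, None) is per_class[1]

# Multiclass: the requested class index is selected; an out-of-range
# or missing index raises ValueError
three_class = [np.zeros((5, 3)) for _ in range(3)]
assert select_shap_values(three_class, 2) is three_class[2]

# Regression: a single ndarray (not a list) is passed through unchanged
regression = np.random.rand(5, 3)
assert select_shap_values(regression, None) is regression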
validmind/tests/model_validation/sklearn/SilhouettePlot.py

@@ -2,23 +2,17 @@
 # See the LICENSE file in the root of this repository for details.
 # SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial

-from dataclasses import dataclass
-
 import matplotlib.pyplot as plt
 import numpy as np
 from sklearn.metrics import silhouette_samples, silhouette_score

-from validmind.vm_models import (
-    Figure,
-    Metric,
-    ResultSummary,
-    ResultTable,
-    ResultTableMetadata,
-)
+from validmind import tags, tasks
+from validmind.vm_models import VMDataset, VMModel


-@dataclass
-class SilhouettePlot(Metric):
+@tags("sklearn", "model_performance")
+@tasks("clustering")
+def SilhouettePlot(model: VMModel, dataset: VMDataset):
     """
     Calculates and visualizes Silhouette Score, assessing the degree of data point suitability to its cluster in ML
     models.
@@ -65,93 +59,54 @@ class SilhouettePlot(Metric):
     assignment nuances, so potentially relevant details may be omitted.
     - Computationally expensive for large datasets, as it requires pairwise distance computations.
     """
-
-    name = "silhouette_plot"
-    required_inputs = ["model", "dataset"]
-    tasks = ["clustering"]
-    tags = [
-        "sklearn",
-        "model_performance",
-    ]
-
-    def run(self):
-        y_pred_train = self.inputs.dataset.y_pred(self.inputs.model)
-        # Calculate the silhouette score
-        silhouette_avg = silhouette_score(
-            self.inputs.dataset.x,
-            y_pred_train,
-            metric="euclidean",
-        )
-        num_clusters = len(np.unique(y_pred_train))
-        # Calculate silhouette coefficients for each data point
-        sample_silhouette_values = silhouette_samples(
-            self.inputs.dataset.x, y_pred_train
-        )
-        # Create a silhouette plot
-        fig, ax = plt.subplots()
-        y_lower = 10
-
-        for i in range(num_clusters):
-            # Aggregate the silhouette scores for samples belonging to cluster i
-            ith_cluster_silhouette_values = sample_silhouette_values[y_pred_train == i]
-            ith_cluster_silhouette_values.sort()
-
-            size_cluster_i = ith_cluster_silhouette_values.shape[0]
-            y_upper = y_lower + size_cluster_i
-            color = plt.cm.viridis(float(i) / num_clusters)
-            ax.fill_betweenx(
-                np.arange(y_lower, y_upper),
-                0,
-                ith_cluster_silhouette_values,
-                facecolor=color,
-                edgecolor=color,
-                alpha=0.7,
-            )
-
-            # Label the silhouette plots with their cluster numbers at the middle
-            ax.text(-0.05, y_lower + 0.5 * size_cluster_i, str(i))
-            # Compute the new y_lower for the next plot
-            y_lower = y_upper + 10
-
-        ax.set_title("Silhouette Plot for Clusters")
-        ax.set_xlabel("Silhouette Coefficient Values")
-        ax.set_ylabel("Cluster Label")
-
-        # The vertical line represents the average silhouette score
-        ax.axvline(x=silhouette_avg, color="red", linestyle="--")
-
-        figures = [
-            Figure(
-                for_object=self,
-                key=self.key,
-                figure=fig,
-            )
-        ]
-        # Close the figure to prevent it from displaying
-        plt.close(fig)
-
-        return self.cache_results(
-            metric_value={
-                "silhouette_score": {
-                    "silhouette_score": silhouette_avg,
-                },
-            },
-            figures=figures,
+    y_pred = dataset.y_pred(model)
+
+    silhouette_avg = silhouette_score(
+        X=dataset.x,
+        labels=y_pred,
+        metric="euclidean",
+    )
+
+    # Calculate silhouette coefficients for each data point
+    sample_silhouette_values = silhouette_samples(dataset.x, y_pred)
+    # Create a silhouette plot
+    fig, ax = plt.subplots()
+
+    y_lower = 10
+    num_clusters = len(np.unique(y_pred))
+    for i in range(num_clusters):
+        # Aggregate the silhouette scores for samples belonging to cluster i
+        ith_cluster_silhouette_values = sample_silhouette_values[y_pred == i]
+        ith_cluster_silhouette_values.sort()
+
+        size_cluster_i = ith_cluster_silhouette_values.shape[0]
+        y_upper = y_lower + size_cluster_i
+        color = plt.cm.viridis(float(i) / num_clusters)
+        ax.fill_betweenx(
+            np.arange(y_lower, y_upper),
+            0,
+            ith_cluster_silhouette_values,
+            facecolor=color,
+            edgecolor=color,
+            alpha=0.7,
         )

-    def summary(self, metric_value):
-        """
-        Build one table for summarizing the Silhouette score results
-        """
-        silhouette_score = metric_value["silhouette_score"]
-
-        table = []
-        table.append(silhouette_score)
-        return ResultSummary(
-            results=[
-                ResultTable(
-                    data=table,
-                    metadata=ResultTableMetadata(title="Silhouette Score"),
-                ),
-            ]
-        )
+        # Label the silhouette plots with their cluster numbers at the middle
+        ax.text(-0.05, y_lower + 0.5 * size_cluster_i, str(i))
+        # Compute the new y_lower for the next plot
+        y_lower = y_upper + 10
+
+    ax.set_title("Silhouette Plot for Clusters")
+    ax.set_xlabel("Silhouette Coefficient Values")
+    ax.set_ylabel("Cluster Label")
+
+    # The vertical line represents the average silhouette score
+    ax.axvline(x=silhouette_avg, color="red", linestyle="--")
+
+    plt.close()
+
+    return [
+        {
+            "Silhouette Score": silhouette_avg,
+        },
+    ], fig
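The same pattern applies here: the decorated function's keyword arguments become the test's params, and its return value (a one-row table plus the matplotlib figure) is wrapped into a result object by the test harness. A sketch of running the rewritten test end to end; the KMeans setup and the `init_model`/`init_dataset`/`assign_predictions` calls follow the usual validmind workflow and are assumptions, not part of this diff:

import pandas as pd
import validmind as vm
from sklearn.cluster import KMeans
from sklearn.datasets import make_blobs

X, _ = make_blobs(n_samples=300, centers=4, random_state=0)
df = pd.DataFrame(X, columns=["x0", "x1"])

# wrap the fitted clusterer and the data as validmind objects
vm_model = vm.init_model(KMeans(n_clusters=4, random_state=0).fit(X), input_id="kmeans")
vm_dataset = vm.init_dataset(df, input_id="blobs", feature_columns=["x0", "x1"])
vm_dataset.assign_predictions(model=vm_model)

# `inputs` is mapped onto the function's model/dataset arguments
result = vm.tests.run_test(
    "validmind.model_validation.sklearn.SilhouettePlot",
    inputs={"model": vm_model, "dataset": vm_dataset},
)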