validmind 2.5.24__py3-none-any.whl → 2.6.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (198)
  1. validmind/__init__.py +8 -17
  2. validmind/__version__.py +1 -1
  3. validmind/ai/test_descriptions.py +66 -85
  4. validmind/ai/test_result_description/context.py +2 -2
  5. validmind/ai/utils.py +26 -1
  6. validmind/api_client.py +43 -79
  7. validmind/client.py +5 -7
  8. validmind/client_config.py +1 -1
  9. validmind/datasets/__init__.py +1 -1
  10. validmind/datasets/classification/customer_churn.py +7 -5
  11. validmind/datasets/nlp/__init__.py +2 -2
  12. validmind/errors.py +6 -10
  13. validmind/html_templates/content_blocks.py +18 -16
  14. validmind/logging.py +21 -16
  15. validmind/tests/__init__.py +28 -5
  16. validmind/tests/__types__.py +186 -170
  17. validmind/tests/_store.py +7 -21
  18. validmind/tests/comparison.py +362 -0
  19. validmind/tests/data_validation/ACFandPACFPlot.py +44 -73
  20. validmind/tests/data_validation/ADF.py +49 -83
  21. validmind/tests/data_validation/AutoAR.py +59 -96
  22. validmind/tests/data_validation/AutoMA.py +59 -96
  23. validmind/tests/data_validation/AutoStationarity.py +66 -114
  24. validmind/tests/data_validation/ClassImbalance.py +48 -117
  25. validmind/tests/data_validation/DatasetDescription.py +180 -209
  26. validmind/tests/data_validation/DatasetSplit.py +50 -75
  27. validmind/tests/data_validation/DescriptiveStatistics.py +59 -85
  28. validmind/tests/data_validation/{DFGLSArch.py → DickeyFullerGLS.py} +44 -76
  29. validmind/tests/data_validation/Duplicates.py +21 -90
  30. validmind/tests/data_validation/EngleGrangerCoint.py +53 -75
  31. validmind/tests/data_validation/HighCardinality.py +32 -80
  32. validmind/tests/data_validation/HighPearsonCorrelation.py +29 -97
  33. validmind/tests/data_validation/IQROutliersBarPlot.py +63 -94
  34. validmind/tests/data_validation/IQROutliersTable.py +40 -80
  35. validmind/tests/data_validation/IsolationForestOutliers.py +41 -63
  36. validmind/tests/data_validation/KPSS.py +33 -81
  37. validmind/tests/data_validation/LaggedCorrelationHeatmap.py +47 -95
  38. validmind/tests/data_validation/MissingValues.py +17 -58
  39. validmind/tests/data_validation/MissingValuesBarPlot.py +61 -87
  40. validmind/tests/data_validation/PhillipsPerronArch.py +56 -79
  41. validmind/tests/data_validation/RollingStatsPlot.py +50 -81
  42. validmind/tests/data_validation/SeasonalDecompose.py +102 -184
  43. validmind/tests/data_validation/Skewness.py +27 -64
  44. validmind/tests/data_validation/SpreadPlot.py +34 -57
  45. validmind/tests/data_validation/TabularCategoricalBarPlots.py +46 -65
  46. validmind/tests/data_validation/TabularDateTimeHistograms.py +23 -45
  47. validmind/tests/data_validation/TabularNumericalHistograms.py +27 -46
  48. validmind/tests/data_validation/TargetRateBarPlots.py +54 -93
  49. validmind/tests/data_validation/TimeSeriesFrequency.py +48 -133
  50. validmind/tests/data_validation/TimeSeriesHistogram.py +24 -3
  51. validmind/tests/data_validation/TimeSeriesLinePlot.py +29 -47
  52. validmind/tests/data_validation/TimeSeriesMissingValues.py +59 -135
  53. validmind/tests/data_validation/TimeSeriesOutliers.py +54 -171
  54. validmind/tests/data_validation/TooManyZeroValues.py +21 -70
  55. validmind/tests/data_validation/UniqueRows.py +23 -62
  56. validmind/tests/data_validation/WOEBinPlots.py +83 -109
  57. validmind/tests/data_validation/WOEBinTable.py +28 -69
  58. validmind/tests/data_validation/ZivotAndrewsArch.py +33 -75
  59. validmind/tests/data_validation/nlp/CommonWords.py +49 -57
  60. validmind/tests/data_validation/nlp/Hashtags.py +27 -49
  61. validmind/tests/data_validation/nlp/LanguageDetection.py +7 -13
  62. validmind/tests/data_validation/nlp/Mentions.py +32 -63
  63. validmind/tests/data_validation/nlp/PolarityAndSubjectivity.py +89 -14
  64. validmind/tests/data_validation/nlp/Punctuations.py +63 -47
  65. validmind/tests/data_validation/nlp/Sentiment.py +4 -0
  66. validmind/tests/data_validation/nlp/StopWords.py +62 -91
  67. validmind/tests/data_validation/nlp/TextDescription.py +116 -159
  68. validmind/tests/data_validation/nlp/Toxicity.py +12 -4
  69. validmind/tests/decorator.py +33 -242
  70. validmind/tests/load.py +212 -153
  71. validmind/tests/model_validation/BertScore.py +13 -7
  72. validmind/tests/model_validation/BleuScore.py +4 -0
  73. validmind/tests/model_validation/ClusterSizeDistribution.py +24 -47
  74. validmind/tests/model_validation/ContextualRecall.py +3 -0
  75. validmind/tests/model_validation/FeaturesAUC.py +43 -74
  76. validmind/tests/model_validation/MeteorScore.py +3 -0
  77. validmind/tests/model_validation/RegardScore.py +5 -1
  78. validmind/tests/model_validation/RegressionResidualsPlot.py +54 -75
  79. validmind/tests/model_validation/embeddings/ClusterDistribution.py +10 -33
  80. validmind/tests/model_validation/embeddings/CosineSimilarityDistribution.py +11 -29
  81. validmind/tests/model_validation/embeddings/DescriptiveAnalytics.py +19 -31
  82. validmind/tests/model_validation/embeddings/EmbeddingsVisualization2D.py +40 -49
  83. validmind/tests/model_validation/embeddings/StabilityAnalysisKeyword.py +29 -15
  84. validmind/tests/model_validation/embeddings/StabilityAnalysisRandomNoise.py +25 -11
  85. validmind/tests/model_validation/embeddings/StabilityAnalysisSynonyms.py +28 -13
  86. validmind/tests/model_validation/embeddings/StabilityAnalysisTranslation.py +67 -38
  87. validmind/tests/model_validation/embeddings/utils.py +53 -0
  88. validmind/tests/model_validation/ragas/AnswerCorrectness.py +37 -32
  89. validmind/tests/model_validation/ragas/{AspectCritique.py → AspectCritic.py} +33 -27
  90. validmind/tests/model_validation/ragas/ContextEntityRecall.py +44 -41
  91. validmind/tests/model_validation/ragas/ContextPrecision.py +40 -35
  92. validmind/tests/model_validation/ragas/ContextPrecisionWithoutReference.py +133 -0
  93. validmind/tests/model_validation/ragas/ContextRecall.py +40 -35
  94. validmind/tests/model_validation/ragas/Faithfulness.py +42 -30
  95. validmind/tests/model_validation/ragas/NoiseSensitivity.py +59 -35
  96. validmind/tests/model_validation/ragas/{AnswerRelevance.py → ResponseRelevancy.py} +52 -41
  97. validmind/tests/model_validation/ragas/{AnswerSimilarity.py → SemanticSimilarity.py} +39 -34
  98. validmind/tests/model_validation/sklearn/AdjustedMutualInformation.py +13 -16
  99. validmind/tests/model_validation/sklearn/AdjustedRandIndex.py +13 -16
  100. validmind/tests/model_validation/sklearn/ClassifierPerformance.py +51 -89
  101. validmind/tests/model_validation/sklearn/ClusterCosineSimilarity.py +31 -61
  102. validmind/tests/model_validation/sklearn/ClusterPerformanceMetrics.py +118 -83
  103. validmind/tests/model_validation/sklearn/CompletenessScore.py +13 -16
  104. validmind/tests/model_validation/sklearn/ConfusionMatrix.py +62 -94
  105. validmind/tests/model_validation/sklearn/FeatureImportance.py +7 -8
  106. validmind/tests/model_validation/sklearn/FowlkesMallowsScore.py +12 -15
  107. validmind/tests/model_validation/sklearn/HomogeneityScore.py +12 -15
  108. validmind/tests/model_validation/sklearn/HyperParametersTuning.py +23 -53
  109. validmind/tests/model_validation/sklearn/KMeansClustersOptimization.py +60 -74
  110. validmind/tests/model_validation/sklearn/MinimumAccuracy.py +16 -84
  111. validmind/tests/model_validation/sklearn/MinimumF1Score.py +22 -72
  112. validmind/tests/model_validation/sklearn/MinimumROCAUCScore.py +29 -78
  113. validmind/tests/model_validation/sklearn/ModelsPerformanceComparison.py +52 -82
  114. validmind/tests/model_validation/sklearn/OverfitDiagnosis.py +51 -145
  115. validmind/tests/model_validation/sklearn/PermutationFeatureImportance.py +60 -78
  116. validmind/tests/model_validation/sklearn/PopulationStabilityIndex.py +130 -172
  117. validmind/tests/model_validation/sklearn/PrecisionRecallCurve.py +26 -55
  118. validmind/tests/model_validation/sklearn/ROCCurve.py +43 -77
  119. validmind/tests/model_validation/sklearn/RegressionPerformance.py +41 -94
  120. validmind/tests/model_validation/sklearn/RobustnessDiagnosis.py +47 -136
  121. validmind/tests/model_validation/sklearn/SHAPGlobalImportance.py +164 -208
  122. validmind/tests/model_validation/sklearn/SilhouettePlot.py +54 -99
  123. validmind/tests/model_validation/sklearn/TrainingTestDegradation.py +50 -124
  124. validmind/tests/model_validation/sklearn/VMeasure.py +12 -15
  125. validmind/tests/model_validation/sklearn/WeakspotsDiagnosis.py +225 -281
  126. validmind/tests/model_validation/statsmodels/AutoARIMA.py +40 -45
  127. validmind/tests/model_validation/statsmodels/KolmogorovSmirnov.py +22 -47
  128. validmind/tests/model_validation/statsmodels/Lilliefors.py +17 -28
  129. validmind/tests/model_validation/statsmodels/RegressionFeatureSignificance.py +37 -81
  130. validmind/tests/model_validation/statsmodels/RegressionModelForecastPlot.py +37 -105
  131. validmind/tests/model_validation/statsmodels/RegressionModelForecastPlotLevels.py +62 -166
  132. validmind/tests/model_validation/statsmodels/RegressionModelSensitivityPlot.py +57 -119
  133. validmind/tests/model_validation/statsmodels/RegressionModelSummary.py +20 -57
  134. validmind/tests/model_validation/statsmodels/RegressionPermutationFeatureImportance.py +47 -80
  135. validmind/tests/ongoing_monitoring/PredictionCorrelation.py +2 -0
  136. validmind/tests/ongoing_monitoring/TargetPredictionDistributionPlot.py +4 -2
  137. validmind/tests/output.py +120 -0
  138. validmind/tests/prompt_validation/Bias.py +55 -98
  139. validmind/tests/prompt_validation/Clarity.py +56 -99
  140. validmind/tests/prompt_validation/Conciseness.py +63 -101
  141. validmind/tests/prompt_validation/Delimitation.py +48 -89
  142. validmind/tests/prompt_validation/NegativeInstruction.py +62 -96
  143. validmind/tests/prompt_validation/Robustness.py +80 -121
  144. validmind/tests/prompt_validation/Specificity.py +61 -95
  145. validmind/tests/prompt_validation/ai_powered_test.py +2 -2
  146. validmind/tests/run.py +314 -496
  147. validmind/tests/test_providers.py +109 -79
  148. validmind/tests/utils.py +91 -0
  149. validmind/unit_metrics/__init__.py +16 -155
  150. validmind/unit_metrics/classification/F1.py +1 -0
  151. validmind/unit_metrics/classification/Precision.py +1 -0
  152. validmind/unit_metrics/classification/ROC_AUC.py +1 -0
  153. validmind/unit_metrics/classification/Recall.py +1 -0
  154. validmind/unit_metrics/regression/AdjustedRSquaredScore.py +1 -0
  155. validmind/unit_metrics/regression/GiniCoefficient.py +1 -0
  156. validmind/unit_metrics/regression/HuberLoss.py +1 -0
  157. validmind/unit_metrics/regression/KolmogorovSmirnovStatistic.py +1 -0
  158. validmind/unit_metrics/regression/MeanAbsoluteError.py +1 -0
  159. validmind/unit_metrics/regression/MeanAbsolutePercentageError.py +1 -0
  160. validmind/unit_metrics/regression/MeanBiasDeviation.py +1 -0
  161. validmind/unit_metrics/regression/MeanSquaredError.py +1 -0
  162. validmind/unit_metrics/regression/QuantileLoss.py +1 -0
  163. validmind/unit_metrics/regression/RSquaredScore.py +2 -1
  164. validmind/unit_metrics/regression/RootMeanSquaredError.py +1 -0
  165. validmind/utils.py +66 -17
  166. validmind/vm_models/__init__.py +2 -17
  167. validmind/vm_models/dataset/dataset.py +31 -4
  168. validmind/vm_models/figure.py +7 -37
  169. validmind/vm_models/model.py +3 -0
  170. validmind/vm_models/result/__init__.py +7 -0
  171. validmind/vm_models/result/result.jinja +21 -0
  172. validmind/vm_models/result/result.py +337 -0
  173. validmind/vm_models/result/utils.py +160 -0
  174. validmind/vm_models/test_suite/runner.py +16 -54
  175. validmind/vm_models/test_suite/summary.py +3 -3
  176. validmind/vm_models/test_suite/test.py +43 -77
  177. validmind/vm_models/test_suite/test_suite.py +8 -40
  178. validmind-2.6.7.dist-info/METADATA +137 -0
  179. {validmind-2.5.24.dist-info → validmind-2.6.7.dist-info}/RECORD +182 -189
  180. validmind/tests/data_validation/AutoSeasonality.py +0 -190
  181. validmind/tests/metadata.py +0 -59
  182. validmind/tests/model_validation/embeddings/StabilityAnalysis.py +0 -176
  183. validmind/tests/model_validation/ragas/ContextUtilization.py +0 -161
  184. validmind/tests/model_validation/sklearn/ClusterPerformance.py +0 -80
  185. validmind/unit_metrics/composite.py +0 -238
  186. validmind/vm_models/test/metric.py +0 -98
  187. validmind/vm_models/test/metric_result.py +0 -61
  188. validmind/vm_models/test/output_template.py +0 -55
  189. validmind/vm_models/test/result_summary.py +0 -76
  190. validmind/vm_models/test/result_wrapper.py +0 -488
  191. validmind/vm_models/test/test.py +0 -103
  192. validmind/vm_models/test/threshold_test.py +0 -106
  193. validmind/vm_models/test/threshold_test_result.py +0 -75
  194. validmind/vm_models/test_context.py +0 -259
  195. validmind-2.5.24.dist-info/METADATA +0 -118
  196. {validmind-2.5.24.dist-info → validmind-2.6.7.dist-info}/LICENSE +0 -0
  197. {validmind-2.5.24.dist-info → validmind-2.6.7.dist-info}/WHEEL +0 -0
  198. {validmind-2.5.24.dist-info → validmind-2.6.7.dist-info}/entry_points.txt +0 -0
validmind/tests/model_validation/statsmodels/KolmogorovSmirnov.py
@@ -2,16 +2,16 @@
 # See the LICENSE file in the root of this repository for details.
 # SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial

-from dataclasses import dataclass
-
 from statsmodels.stats.diagnostic import kstest_normal

+from validmind import tags, tasks
 from validmind.errors import InvalidTestParametersError
-from validmind.vm_models import Metric, ResultSummary, ResultTable, ResultTableMetadata
+from validmind.vm_models import VMDataset, VMModel


-@dataclass
-class KolmogorovSmirnov(Metric):
+@tags("tabular_data", "data_distribution", "statistical_test", "statsmodels")
+@tasks("classification", "regression")
+def KolmogorovSmirnov(model: VMModel, dataset: VMDataset, dist: str = "norm"):
     """
     Assesses whether each feature in the dataset aligns with a normal distribution using the Kolmogorov-Smirnov test.

@@ -47,48 +47,23 @@ class KolmogorovSmirnov(Metric):
     - Less effective for multivariate distributions, as it is designed for univariate distributions.
     - Does not identify specific types of non-normality, such as skewness or kurtosis, which could impact model fitting.
     """
+    if dist not in ["norm", "exp"]:
+        raise InvalidTestParametersError(
+            "'dist' parameter must be either 'norm' or 'exp'"
+        )

-    name = "kolmogorov_smirnov"
-    required_inputs = ["dataset"]
-    default_params = {"dist": "norm"}
-    tasks = ["classification", "regression"]
-    tags = [
-        "tabular_data",
-        "data_distribution",
-        "statistical_test",
-        "statsmodels",
-    ]
+    df = dataset.df[dataset.feature_columns_numeric]

-    def summary(self, metric_value):
-        results_table = metric_value["metrics_summary"]
-
-        results_table = [
-            {"Column": k, "stat": result["stat"], "pvalue": result["pvalue"]}
-            for k, result in results_table.items()
-        ]
-
-        return ResultSummary(
-            results=[
-                ResultTable(
-                    data=results_table,
-                    metadata=ResultTableMetadata(title="KS Test results"),
-                )
-            ]
-        )
+    ks_values = {}
+    for col in df.columns:
+        ks_stat, p_value = kstest_normal(df[col].values, dist)
+        ks_values[col] = {"stat": ks_stat, "pvalue": p_value}

-    def run(self):
-        """
-        Calculates KS for each of the dataset features
-        """
-        data_distribution = self.params["dist"]
-        if data_distribution not in ["norm" or "exp"]:
-            InvalidTestParametersError("Dist parameter must be either 'norm' or 'exp'")
-
-        x_train = self.inputs.dataset.df[self.inputs.dataset.feature_columns_numeric]
-        ks_values = {}
-        for col in x_train.columns:
-            ks_stat, p_value = kstest_normal(x_train[col].values, data_distribution)
-            ks_values[col] = {"stat": ks_stat, "pvalue": p_value}
-
-        print(ks_values)
-        return self.cache_results({"metrics_summary": ks_values})
+    return [
+        {
+            "Column": k,
+            "Statistic": result["stat"],
+            "P-Value": result["pvalue"],
+        }
+        for k, result in ks_values.items()
+    ]
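The refactor replaces the `Metric` subclass with a plain decorated function that returns a list of row dictionaries, which the 2.6.x result pipeline renders as a table. A minimal sketch of the computation the new function performs, using only pandas and statsmodels; the DataFrame and its column names are illustrative, not part of the package:

import pandas as pd
from statsmodels.stats.diagnostic import kstest_normal

# Illustrative numeric features; the real test iterates over
# dataset.feature_columns_numeric instead.
df = pd.DataFrame({
    "age": [23, 45, 31, 62, 38, 27, 54, 41],
    "balance": [1200.0, 340.5, 987.1, 4500.0, 210.9, 765.3, 1890.2, 99.9],
})

# Same loop body as the refactored test: one KS statistic per column.
rows = []
for col in df.columns:
    stat, pvalue = kstest_normal(df[col].values, dist="norm")
    rows.append({"Column": col, "Statistic": stat, "P-Value": pvalue})

print(pd.DataFrame(rows))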
validmind/tests/model_validation/statsmodels/Lilliefors.py
@@ -2,15 +2,15 @@
 # See the LICENSE file in the root of this repository for details.
 # SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial

-from dataclasses import dataclass
-
 from statsmodels.stats.diagnostic import lilliefors

-from validmind.vm_models import Metric
+from validmind import tags, tasks
+from validmind.vm_models import VMDataset, VMModel


-@dataclass
-class Lilliefors(Metric):
+@tags("tabular_data", "data_distribution", "statistical_test", "statsmodels")
+@tasks("classification", "regression")
+def Lilliefors(model: VMModel, dataset: VMDataset):
     """
     Assesses the normality of feature distributions in an ML model's training dataset using the Lilliefors test.

@@ -56,29 +56,18 @@ class Lilliefors(Metric):
     - Like any other statistical test, Lilliefors test may also produce false positives or negatives. Hence, banking
     solely on this test, without considering other characteristics of the data, may give rise to risks.
     """
+    df = dataset.df[dataset.feature_columns_numeric]
+
+    table = []

-    name = "lilliefors_test"
-    required_inputs = ["dataset"]
-    tasks = ["classification", "regression"]
-    tags = [
-        "tabular_data",
-        "data_distribution",
-        "statistical_test",
-        "statsmodels",
-    ]
-
-    def run(self):
-        """
-        Calculates Lilliefors test for each of the dataset features
-        """
-        x_train = self.inputs.dataset.df[self.inputs.dataset.feature_columns_numeric]
-
-        lilliefors_values = {}
-        for col in x_train.columns:
-            l_stat, p_value = lilliefors(x_train[col].values)
-            lilliefors_values[col] = {
-                "stat": l_stat,
-                "pvalue": p_value,
+    for col in df.columns:
+        l_stat, p_value = lilliefors(df[col].values)
+        table.append(
+            {
+                "Column": col,
+                "Statistic": l_stat,
+                "P-Value": p_value,
             }
+        )

-        return self.cache_results(lilliefors_values)
+    return table
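Lilliefors receives the same class-to-function migration and now returns its table directly. As a hedged sketch (not taken from this diff), a functional test like this is typically invoked through the harness roughly as follows; the `run_test` call shape and `result.log()` reflect the 2.6.x public API, while the input IDs are placeholders for objects previously registered with `vm.init_model` / `vm.init_dataset`:

import validmind as vm

# "my_model" and "train_ds" are placeholder input IDs, not names from this diff.
result = vm.tests.run_test(
    "validmind.model_validation.statsmodels.Lilliefors",
    inputs={"model": "my_model", "dataset": "train_ds"},
)
result.log()  # send the rendered table to the ValidMind platform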
validmind/tests/model_validation/statsmodels/RegressionFeatureSignificance.py
@@ -2,36 +2,37 @@
 # See the LICENSE file in the root of this repository for details.
 # SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial

-from dataclasses import dataclass
-
 import matplotlib.pyplot as plt
 import seaborn as sns

+from validmind import tags, tasks
 from validmind.errors import SkipTestError
 from validmind.logging import get_logger
-from validmind.vm_models import Figure, Metric
+from validmind.vm_models import VMModel

 logger = get_logger(__name__)


-@dataclass
-class RegressionFeatureSignificance(Metric):
+@tags("statistical_test", "model_interpretation", "visualization", "feature_importance")
+@tasks("regression")
+def RegressionFeatureSignificance(
+    model: VMModel, fontsize: int = 10, p_threshold: float = 0.05
+):
     """
-    Assesses and visualizes the statistical significance of features in a set of regression models.
+    Assesses and visualizes the statistical significance of features in a regression model.

     ### Purpose

     The Regression Feature Significance metric assesses the significance of each feature in a given set of regression
-    models. It creates a visualization displaying p-values for every feature of each model, assisting model developers
-    in understanding which features are most influential in their models.
+    model. It creates a visualization displaying p-values for every feature of the model, assisting model developers
+    in understanding which features are most influential in their model.

     ### Test Mechanism

-    The test mechanism involves going through each fitted regression model in a given list, extracting the model
-    coefficients and p-values for each feature, and then plotting these values. The x-axis on the plot contains the
-    p-values while the y-axis denotes the coefficients of each feature. A vertical red line is drawn at the threshold
-    for p-value significance, which is 0.05 by default. Any features with p-values to the left of this line are
-    considered statistically significant at the chosen level.
+    The test mechanism involves extracting the model's coefficients and p-values for each feature, and then plotting these
+    values. The x-axis on the plot contains the p-values while the y-axis denotes the coefficients of each feature. A
+    vertical red line is drawn at the threshold for p-value significance, which is 0.05 by default. Any features with
+    p-values to the left of this line are considered statistically significant at the chosen level.

     ### Signs of High Risk

@@ -45,7 +46,6 @@ class RegressionFeatureSignificance(Metric):
     - Helps identify the features that significantly contribute to a model's prediction, providing insights into the
     feature importance.
     - Provides tangible, easy-to-understand visualizations to interpret the feature significance.
-    - Facilitates comparison of feature importance across multiple models.

     ### Limitations

@@ -57,81 +57,37 @@ class RegressionFeatureSignificance(Metric):
     - P-value thresholds are somewhat arbitrary and do not always indicate practical significance, only statistical
     significance.
     """
+    if model.library != "statsmodels":
+        raise SkipTestError("Only statsmodels are supported for this metric")

-    name = "regression_feature_significance"
-    required_inputs = ["model"]
-
-    default_params = {"fontsize": 10, "p_threshold": 0.05}
-    tasks = ["regression"]
-    tags = [
-        "statistical_test",
-        "model_interpretation",
-        "visualization",
-        "feature_importance",
-    ]
-
-    def run(self):
-        fontsize = self.params["fontsize"]
-        p_threshold = self.params["p_threshold"]
-
-        # Check models list is not empty
-        if not self.inputs.model:
-            raise ValueError("Model must be provided in the models parameter")
-
-        figures = self._plot_pvalues(self.inputs.model, fontsize, p_threshold)
-
-        return self.cache_results(figures=figures)
-
-    def _plot_pvalues(self, model_list, fontsize, p_threshold):
-        # Initialize a list to store figures
-        figures = []
-
-        for i, model in enumerate(model_list):
-
-            if model.library != "statsmodels":
-                raise SkipTestError("Only statsmodels are supported for this metric")
+    coefficients = model.model.params
+    pvalues = model.model.pvalues

-            # Get the coefficients and p-values from the model
-            coefficients = model.model.params
-            pvalues = model.model.pvalues
+    # Sort the variables by p-value in ascending order
+    sorted_idx = pvalues.argsort()
+    coefficients = coefficients.iloc[sorted_idx]
+    pvalues = pvalues.iloc[sorted_idx]

-            # Sort the variables by p-value in ascending order
-            sorted_idx = pvalues.argsort()
-            coefficients = coefficients.iloc[sorted_idx]
-            pvalues = pvalues.iloc[sorted_idx]
+    fig, ax = plt.subplots()

-            # Increase the height of the figure
-            fig, ax = plt.subplots()
+    sns.barplot(x=pvalues, y=coefficients.index, ax=ax, color="skyblue")

-            # Create a horizontal bar plot with wider bars using Seaborn
-            sns.barplot(x=pvalues, y=coefficients.index, ax=ax, color="skyblue")
+    # Add a threshold line at p-value = p_threshold
+    threshold_line = ax.axvline(x=p_threshold, color="red", linestyle="--")

-            # Add a threshold line at p-value = p_threshold
-            threshold_line = ax.axvline(x=p_threshold, color="red", linestyle="--")
+    # Set labels and title
+    ax.set_xlabel("P-value")
+    ax.set_ylabel(None)
+    ax.set_title(f"Feature Significance for {model.input_id}")

-            # Set labels and title
-            ax.set_xlabel("P-value")
-            ax.set_ylabel(None)
-            ax.set_title(f"Feature Significance for Model {i + 1}")
+    plt.tight_layout()

-            # Adjust the layout to prevent overlapping of variable names
-            plt.tight_layout()
+    ax.set_yticklabels(ax.get_yticklabels(), fontsize=fontsize)

-            # Set the fontsize of y-axis tick labels
-            ax.set_yticklabels(ax.get_yticklabels(), fontsize=fontsize)
+    # Add a legend for the threshold line
+    legend_label = f"p_threshold {p_threshold}"
+    ax.legend([threshold_line], [legend_label])

-            # Add a legend for the threshold line
-            legend_label = f"p_threshold {p_threshold}"
-            ax.legend([threshold_line], [legend_label])
+    plt.close()

-            # Add to the figures list
-            figures.append(
-                Figure(
-                    for_object=self,
-                    key=f"{self.key}:{i}",
-                    figure=fig,
-                    metadata={"model": str(model.model)},
-                )
-            )
-        plt.close("all")
-        return figures
+    return fig
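The new function body assumes a fitted statsmodels result exposed as `model.model`, since `params` and `pvalues` are attributes of statsmodels regression results. A self-contained sketch of that extraction-and-sort step under the same assumptions; the synthetic data and column names are illustrative:

import numpy as np
import pandas as pd
import statsmodels.api as sm

rng = np.random.default_rng(0)
X = pd.DataFrame(rng.normal(size=(200, 3)), columns=["x1", "x2", "x3"])
y = 2.0 * X["x1"] + 0.1 * X["x2"] + rng.normal(size=200)

fitted = sm.OLS(y, sm.add_constant(X)).fit()

# Same attributes the refactored test reads from model.model
coefficients = fitted.params  # pandas Series indexed by feature name
pvalues = fitted.pvalues

# Sort features by p-value, most significant first, as the test does before plotting
sorted_idx = pvalues.argsort()
print(pvalues.iloc[sorted_idx])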
validmind/tests/model_validation/statsmodels/RegressionModelForecastPlot.py
@@ -2,38 +2,43 @@
 # See the LICENSE file in the root of this repository for details.
 # SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial

-from dataclasses import dataclass
+from typing import Union

 import matplotlib.pyplot as plt
 import pandas as pd

+from validmind import tags, tasks
 from validmind.logging import get_logger
-from validmind.vm_models import Figure, Metric
+from validmind.vm_models import VMDataset, VMModel

 logger = get_logger(__name__)


-@dataclass
-class RegressionModelForecastPlot(Metric):
+@tags("time_series_data", "forecasting", "visualization")
+@tasks("regression")
+def RegressionModelForecastPlot(
+    model: VMModel,
+    dataset: VMDataset,
+    start_date: Union[str, None] = None,
+    end_date: Union[str, None] = None,
+):
     """
-    Generates plots to visually compare the forecasted outcomes of one or more regression models against actual
-    observed values over a specified date range.
+    Generates plots to visually compare the forecasted outcomes of a regression model against actual observed values over
+    a specified date range.

     ### Purpose

-    The "regression_forecast_plot" is intended to visually depict the performance of one or more regression models by
-    comparing the model's forecasted outcomes against actual observed values within a specified date range. This metric
-    is especially useful in time-series models or any model where the outcome changes over time, allowing direct
-    comparison of predicted vs actual values.
+    This metric is useful for time-series models or any model where the outcome changes over time, allowing direct
+    comparison of predicted vs actual values. It can help identify overfitting or underfitting situations as well as
+    general model performance.

     ### Test Mechanism

-    This test generates a plot for each fitted model in the list. The x-axis represents the date ranging from the
-    specified "start_date" to the "end_date", while the y-axis shows the value of the outcome variable. Two lines are
-    plotted: one representing the forecasted values and the other representing the observed values. The "start_date"
-    and "end_date" can be parameters of this test; if these parameters are not provided, they are set to the minimum
-    and maximum date available in the dataset. The test verifies that the provided date range is within the limits of
-    the available data.
+    This test generates a plot with the x-axis representing the date ranging from the specified "start_date" to the
+    "end_date", while the y-axis shows the value of the outcome variable. Two lines are plotted: one representing the
+    forecasted values and the other representing the observed values. The "start_date" and "end_date" can be parameters
+    of this test; if these parameters are not provided, they are set to the minimum and maximum date available in the
+    dataset.

     ### Signs of High Risk

@@ -58,101 +63,28 @@ class RegressionModelForecastPlot(Metric):
     - Inapplicability: Limited to cases where the order of data points (time-series) matters, it might not be of much
     use in problems that are not related to time series prediction.
     """
+    index = dataset.df.index

-    name = "regression_forecast_plot"
-    required_inputs = ["models", "datasets"]
-    default_params = {"start_date": None, "end_date": None}
-    tasks = ["regression"]
-    tags = ["forecasting", "visualization"]
-
-    def run(self):
-        start_date = self.params["start_date"]
-        end_date = self.params["end_date"]
-
-        # Check models list is not empty
-        if not self.inputs.models:
-            raise ValueError("List of models must be provided in the models parameter")
-        all_models = []
-        for model in self.inputs.models:
-            all_models.append(model)
-
-        figures = self._plot_forecast(
-            all_models, self.inputs.datasets, start_date, end_date
-        )
-
-        return self.cache_results(figures=figures)
-
-    def _plot_forecast(self, model_list, datasets, start_date=None, end_date=None):
-        # Convert start_date and end_date to pandas Timestamp for comparison
-        start_date = pd.Timestamp(start_date)
-        end_date = pd.Timestamp(end_date)
-
-        # Initialize a list to store figures
-        figures = []
-
-        for i, fitted_model in enumerate(model_list):
-            feature_columns = datasets[0].feature_columns
-
-            train_ds = datasets[0]
-            test_ds = datasets[1]
-
-            y_pred = train_ds.y_pred(fitted_model)
-            y_pred_test = test_ds.y_pred(fitted_model)
+    start_date = index.min() if start_date is None else pd.Timestamp(start_date)
+    end_date = index.max() if end_date is None else pd.Timestamp(end_date)

-            # Check that start_date and end_date are within the data range
-            all_dates = pd.concat([pd.Series(train_ds.index), pd.Series(test_ds.index)])
-
-            # If start_date or end_date are None, set them to the min/max of all_dates
-            if start_date is None:
-                start_date = all_dates.min()
-            else:
-                start_date = pd.Timestamp(start_date)
-
-            if end_date is None:
-                end_date = all_dates.max()
-            else:
-                end_date = pd.Timestamp(end_date)
-
-            # If start_date or end_date are None, set them to the min/max of all_dates
-            if start_date is None:
-                start_date = all_dates.min()
-            else:
-                start_date = pd.Timestamp(start_date)
-
-            if end_date is None:
-                end_date = all_dates.max()
-            else:
-                end_date = pd.Timestamp(end_date)
-
-            if start_date < all_dates.min() or end_date > all_dates.max():
-                raise ValueError(
-                    "start_date and end_date must be within the range of dates in the data"
-                )
+    if start_date < index.min() or end_date > index.max():
+        raise ValueError(
+            "start_date and end_date must be within the range of dates in the data"
+        )

-            fig, ax = plt.subplots()
-            ax.plot(train_ds.index, train_ds.y, label="Train Forecast")
-            ax.plot(test_ds.index, test_ds.y, label="Test Forecast")
-            ax.plot(train_ds.index, y_pred, label="Train Dataset", color="grey")
-            ax.plot(test_ds.index, y_pred_test, label="Test Dataset", color="black")
+    fig, ax = plt.subplots()

-            plt.title(f"Forecast vs Observed for features {feature_columns}")
+    ax.plot(index, dataset.y, label="Observed")
+    ax.plot(index, dataset.y_pred(model), label="Forecast", color="grey")

-            # Set the x-axis limits to zoom in/out
-            plt.xlim(start_date, end_date)
+    plt.title("Forecast vs Observed")

-            plt.legend()
-            # TODO: define a proper key for each plot
-            logger.info(f"Plotting forecast vs observed for model {fitted_model.model}")
+    # Set the x-axis limits to zoom in/out
+    plt.xlim(start_date, end_date)

-            plt.close("all")
+    plt.legend()

-            figures.append(
-                Figure(
-                    for_object=self,
-                    key=f"{self.key}:{i}",
-                    figure=fig,
-                    metadata={"model": str(feature_columns)},
-                )
-            )
+    plt.close()

-        return figures
+    return fig
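Because the signature drops the old `models`/`datasets` lists in favor of a single `model` and `dataset`, reproducing the previous train-versus-test comparison now means running the test once per dataset. A hedged sketch of that pattern, assuming the 2.6.x `run_test` API; the input IDs and dates below are placeholders, not values from this diff:

import validmind as vm

# Placeholder input IDs for previously registered inputs; adjust to your project.
for dataset_id in ["train_ds", "test_ds"]:
    vm.tests.run_test(
        "validmind.model_validation.statsmodels.RegressionModelForecastPlot",
        inputs={"model": "my_model", "dataset": dataset_id},
        params={"start_date": "2020-01-01", "end_date": "2021-12-31"},
    )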