validmind 2.3.5__py3-none-any.whl → 2.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (148)
  1. validmind/__version__.py +1 -1
  2. validmind/ai/test_descriptions.py +8 -1
  3. validmind/ai/utils.py +2 -1
  4. validmind/client.py +1 -0
  5. validmind/tests/__init__.py +14 -468
  6. validmind/tests/_store.py +102 -0
  7. validmind/tests/data_validation/ACFandPACFPlot.py +7 -9
  8. validmind/tests/data_validation/ADF.py +8 -10
  9. validmind/tests/data_validation/ANOVAOneWayTable.py +8 -10
  10. validmind/tests/data_validation/AutoAR.py +2 -4
  11. validmind/tests/data_validation/AutoMA.py +2 -4
  12. validmind/tests/data_validation/AutoSeasonality.py +8 -10
  13. validmind/tests/data_validation/AutoStationarity.py +8 -10
  14. validmind/tests/data_validation/BivariateFeaturesBarPlots.py +8 -10
  15. validmind/tests/data_validation/BivariateHistograms.py +8 -10
  16. validmind/tests/data_validation/BivariateScatterPlots.py +8 -10
  17. validmind/tests/data_validation/ChiSquaredFeaturesTable.py +8 -10
  18. validmind/tests/data_validation/ClassImbalance.py +2 -4
  19. validmind/tests/data_validation/DFGLSArch.py +2 -4
  20. validmind/tests/data_validation/DatasetDescription.py +7 -9
  21. validmind/tests/data_validation/DatasetSplit.py +8 -9
  22. validmind/tests/data_validation/DescriptiveStatistics.py +2 -4
  23. validmind/tests/data_validation/Duplicates.py +2 -4
  24. validmind/tests/data_validation/EngleGrangerCoint.py +2 -4
  25. validmind/tests/data_validation/FeatureTargetCorrelationPlot.py +2 -4
  26. validmind/tests/data_validation/HeatmapFeatureCorrelations.py +2 -4
  27. validmind/tests/data_validation/HighCardinality.py +2 -4
  28. validmind/tests/data_validation/HighPearsonCorrelation.py +2 -4
  29. validmind/tests/data_validation/IQROutliersBarPlot.py +2 -4
  30. validmind/tests/data_validation/IQROutliersTable.py +2 -4
  31. validmind/tests/data_validation/IsolationForestOutliers.py +2 -4
  32. validmind/tests/data_validation/KPSS.py +8 -10
  33. validmind/tests/data_validation/LaggedCorrelationHeatmap.py +2 -4
  34. validmind/tests/data_validation/MissingValues.py +2 -4
  35. validmind/tests/data_validation/MissingValuesBarPlot.py +2 -4
  36. validmind/tests/data_validation/MissingValuesRisk.py +2 -4
  37. validmind/tests/data_validation/PearsonCorrelationMatrix.py +2 -4
  38. validmind/tests/data_validation/PhillipsPerronArch.py +7 -9
  39. validmind/tests/data_validation/RollingStatsPlot.py +2 -4
  40. validmind/tests/data_validation/ScatterPlot.py +2 -4
  41. validmind/tests/data_validation/SeasonalDecompose.py +2 -4
  42. validmind/tests/data_validation/Skewness.py +2 -4
  43. validmind/tests/data_validation/SpreadPlot.py +2 -4
  44. validmind/tests/data_validation/TabularCategoricalBarPlots.py +2 -4
  45. validmind/tests/data_validation/TabularDateTimeHistograms.py +2 -4
  46. validmind/tests/data_validation/TabularDescriptionTables.py +2 -4
  47. validmind/tests/data_validation/TabularNumericalHistograms.py +2 -4
  48. validmind/tests/data_validation/TargetRateBarPlots.py +2 -4
  49. validmind/tests/data_validation/TimeSeriesFrequency.py +2 -4
  50. validmind/tests/data_validation/TimeSeriesLinePlot.py +2 -4
  51. validmind/tests/data_validation/TimeSeriesMissingValues.py +2 -4
  52. validmind/tests/data_validation/TimeSeriesOutliers.py +2 -4
  53. validmind/tests/data_validation/TooManyZeroValues.py +2 -4
  54. validmind/tests/data_validation/UniqueRows.py +2 -4
  55. validmind/tests/data_validation/WOEBinPlots.py +2 -4
  56. validmind/tests/data_validation/WOEBinTable.py +2 -4
  57. validmind/tests/data_validation/ZivotAndrewsArch.py +2 -4
  58. validmind/tests/data_validation/nlp/CommonWords.py +2 -4
  59. validmind/tests/data_validation/nlp/Hashtags.py +2 -4
  60. validmind/tests/data_validation/nlp/Mentions.py +2 -4
  61. validmind/tests/data_validation/nlp/Punctuations.py +2 -4
  62. validmind/tests/data_validation/nlp/StopWords.py +2 -4
  63. validmind/tests/data_validation/nlp/TextDescription.py +2 -4
  64. validmind/tests/decorator.py +10 -8
  65. validmind/tests/load.py +264 -0
  66. validmind/tests/metadata.py +59 -0
  67. validmind/tests/model_validation/ClusterSizeDistribution.py +5 -7
  68. validmind/tests/model_validation/FeaturesAUC.py +6 -8
  69. validmind/tests/model_validation/ModelMetadata.py +8 -9
  70. validmind/tests/model_validation/RegressionResidualsPlot.py +2 -6
  71. validmind/tests/model_validation/embeddings/ClusterDistribution.py +2 -4
  72. validmind/tests/model_validation/embeddings/CosineSimilarityDistribution.py +2 -4
  73. validmind/tests/model_validation/embeddings/DescriptiveAnalytics.py +2 -4
  74. validmind/tests/model_validation/embeddings/EmbeddingsVisualization2D.py +2 -4
  75. validmind/tests/model_validation/embeddings/StabilityAnalysis.py +2 -4
  76. validmind/tests/model_validation/sklearn/AdjustedMutualInformation.py +5 -7
  77. validmind/tests/model_validation/sklearn/AdjustedRandIndex.py +5 -7
  78. validmind/tests/model_validation/sklearn/ClassifierPerformance.py +7 -9
  79. validmind/tests/model_validation/sklearn/ClusterCosineSimilarity.py +5 -7
  80. validmind/tests/model_validation/sklearn/ClusterPerformance.py +5 -7
  81. validmind/tests/model_validation/sklearn/ClusterPerformanceMetrics.py +2 -7
  82. validmind/tests/model_validation/sklearn/CompletenessScore.py +5 -7
  83. validmind/tests/model_validation/sklearn/ConfusionMatrix.py +19 -10
  84. validmind/tests/model_validation/sklearn/FowlkesMallowsScore.py +5 -7
  85. validmind/tests/model_validation/sklearn/HomogeneityScore.py +5 -7
  86. validmind/tests/model_validation/sklearn/HyperParametersTuning.py +2 -7
  87. validmind/tests/model_validation/sklearn/KMeansClustersOptimization.py +4 -7
  88. validmind/tests/model_validation/sklearn/MinimumAccuracy.py +7 -9
  89. validmind/tests/model_validation/sklearn/MinimumF1Score.py +7 -9
  90. validmind/tests/model_validation/sklearn/MinimumROCAUCScore.py +7 -9
  91. validmind/tests/model_validation/sklearn/ModelsPerformanceComparison.py +8 -10
  92. validmind/tests/model_validation/sklearn/OverfitDiagnosis.py +7 -9
  93. validmind/tests/model_validation/sklearn/PermutationFeatureImportance.py +8 -10
  94. validmind/tests/model_validation/sklearn/PopulationStabilityIndex.py +7 -9
  95. validmind/tests/model_validation/sklearn/PrecisionRecallCurve.py +8 -10
  96. validmind/tests/model_validation/sklearn/ROCCurve.py +10 -11
  97. validmind/tests/model_validation/sklearn/RegressionErrors.py +5 -7
  98. validmind/tests/model_validation/sklearn/RegressionModelsPerformanceComparison.py +5 -7
  99. validmind/tests/model_validation/sklearn/RegressionR2Square.py +5 -7
  100. validmind/tests/model_validation/sklearn/RobustnessDiagnosis.py +10 -14
  101. validmind/tests/model_validation/sklearn/SHAPGlobalImportance.py +8 -10
  102. validmind/tests/model_validation/sklearn/SilhouettePlot.py +5 -7
  103. validmind/tests/model_validation/sklearn/TrainingTestDegradation.py +8 -10
  104. validmind/tests/model_validation/sklearn/VMeasure.py +5 -7
  105. validmind/tests/model_validation/sklearn/WeakspotsDiagnosis.py +8 -10
  106. validmind/tests/model_validation/statsmodels/AutoARIMA.py +2 -4
  107. validmind/tests/model_validation/statsmodels/BoxPierce.py +2 -4
  108. validmind/tests/model_validation/statsmodels/CumulativePredictionProbabilities.py +3 -4
  109. validmind/tests/model_validation/statsmodels/DurbinWatsonTest.py +2 -4
  110. validmind/tests/model_validation/statsmodels/GINITable.py +2 -4
  111. validmind/tests/model_validation/statsmodels/JarqueBera.py +7 -9
  112. validmind/tests/model_validation/statsmodels/KolmogorovSmirnov.py +7 -9
  113. validmind/tests/model_validation/statsmodels/LJungBox.py +2 -4
  114. validmind/tests/model_validation/statsmodels/Lilliefors.py +7 -9
  115. validmind/tests/model_validation/statsmodels/PredictionProbabilitiesHistogram.py +2 -4
  116. validmind/tests/model_validation/statsmodels/RegressionCoeffsPlot.py +2 -4
  117. validmind/tests/model_validation/statsmodels/RegressionFeatureSignificance.py +7 -9
  118. validmind/tests/model_validation/statsmodels/RegressionModelForecastPlot.py +2 -4
  119. validmind/tests/model_validation/statsmodels/RegressionModelForecastPlotLevels.py +2 -4
  120. validmind/tests/model_validation/statsmodels/RegressionModelSensitivityPlot.py +2 -4
  121. validmind/tests/model_validation/statsmodels/RegressionModelSummary.py +2 -4
  122. validmind/tests/model_validation/statsmodels/RegressionModelsCoeffs.py +2 -4
  123. validmind/tests/model_validation/statsmodels/RegressionPermutationFeatureImportance.py +6 -8
  124. validmind/tests/model_validation/statsmodels/RunsTest.py +2 -4
  125. validmind/tests/model_validation/statsmodels/ScorecardHistogram.py +3 -4
  126. validmind/tests/model_validation/statsmodels/ShapiroWilk.py +2 -4
  127. validmind/tests/prompt_validation/Bias.py +2 -4
  128. validmind/tests/prompt_validation/Clarity.py +2 -4
  129. validmind/tests/prompt_validation/Conciseness.py +2 -4
  130. validmind/tests/prompt_validation/Delimitation.py +2 -4
  131. validmind/tests/prompt_validation/NegativeInstruction.py +2 -4
  132. validmind/tests/prompt_validation/Robustness.py +2 -4
  133. validmind/tests/prompt_validation/Specificity.py +2 -4
  134. validmind/tests/run.py +394 -0
  135. validmind/tests/test_providers.py +12 -0
  136. validmind/tests/utils.py +16 -0
  137. validmind/unit_metrics/__init__.py +12 -4
  138. validmind/unit_metrics/composite.py +3 -0
  139. validmind/vm_models/test/metric.py +8 -5
  140. validmind/vm_models/test/result_wrapper.py +2 -1
  141. validmind/vm_models/test/test.py +14 -11
  142. validmind/vm_models/test/threshold_test.py +1 -0
  143. validmind/vm_models/test_suite/runner.py +1 -0
  144. {validmind-2.3.5.dist-info → validmind-2.4.0.dist-info}/METADATA +1 -1
  145. {validmind-2.3.5.dist-info → validmind-2.4.0.dist-info}/RECORD +148 -143
  146. {validmind-2.3.5.dist-info → validmind-2.4.0.dist-info}/LICENSE +0 -0
  147. {validmind-2.3.5.dist-info → validmind-2.4.0.dist-info}/WHEEL +0 -0
  148. {validmind-2.3.5.dist-info → validmind-2.4.0.dist-info}/entry_points.txt +0 -0
validmind/tests/prompt_validation/Bias.py CHANGED
@@ -75,10 +75,8 @@ class Bias(ThresholdTest):
     name = "bias"
     required_inputs = ["model.prompt"]
     default_params = {"min_threshold": 7}
-    metadata = {
-        "task_types": ["text_classification", "text_summarization"],
-        "tags": ["llm", "few_shot"],
-    }
+    tasks = ["text_classification", "text_summarization"]
+    tags = ["llm", "few_shot"]
 
     system_prompt = """
 You are a prompt evaluation AI. You are aware of all prompt engineering best practices and can score prompts based on how well they satisfy different best practices. You analyse the prompts step-by-step based on provided documentation and provide a score and an explanation for how you produced that score.
validmind/tests/prompt_validation/Clarity.py CHANGED
@@ -64,10 +64,8 @@ class Clarity(ThresholdTest):
     name = "clarity"
     required_inputs = ["model.prompt"]
     default_params = {"min_threshold": 7}
-    metadata = {
-        "task_types": ["text_classification", "text_summarization"],
-        "tags": ["llm", "zero_shot", "few_shot"],
-    }
+    tasks = ["text_classification", "text_summarization"]
+    tags = ["llm", "zero_shot", "few_shot"]
 
     system_prompt = """
 You are a prompt evaluation AI. You are aware of all prompt engineering best practices and can score prompts based on how well they satisfy different metrics. You analyse the prompts step-by-step based on provided documentation and provide a score and an explanation for how you produced that score.
validmind/tests/prompt_validation/Conciseness.py CHANGED
@@ -64,10 +64,8 @@ class Conciseness(ThresholdTest):
     name = "conciseness"
     required_inputs = ["model.prompt"]
     default_params = {"min_threshold": 7}
-    metadata = {
-        "task_types": ["text_classification", "text_summarization"],
-        "tags": ["llm", "zero_shot", "few_shot"],
-    }
+    tasks = ["text_classification", "text_summarization"]
+    tags = ["llm", "zero_shot", "few_shot"]
 
     system_prompt = """
 You are a prompt evaluation AI. You are aware of all prompt engineering best practices and can score prompts based on how well they satisfy different metrics. You analyse the prompts step-by-step based on provided documentation and provide a score and an explanation for how you produced that score.
validmind/tests/prompt_validation/Delimitation.py CHANGED
@@ -66,10 +66,8 @@ class Delimitation(ThresholdTest):
     name = "delimitation"
     required_inputs = ["model.prompt"]
     default_params = {"min_threshold": 7}
-    metadata = {
-        "task_types": ["text_classification", "text_summarization"],
-        "tags": ["llm", "zero_shot", "few_shot"],
-    }
+    tasks = ["text_classification", "text_summarization"]
+    tags = ["llm", "zero_shot", "few_shot"]
 
     system_prompt = """
 You are a prompt evaluation AI. You are aware of all prompt engineering best practices and can score prompts based on how well they satisfy different metrics. You analyse the prompts step-by-step based on provided documentation and provide a score and an explanation for how you produced that score.
validmind/tests/prompt_validation/NegativeInstruction.py CHANGED
@@ -70,10 +70,8 @@ class NegativeInstruction(ThresholdTest):
     name = "negative_instruction"
     required_inputs = ["model.prompt"]
     default_params = {"min_threshold": 7}
-    metadata = {
-        "task_types": ["text_classification", "text_summarization"],
-        "tags": ["llm", "zero_shot", "few_shot"],
-    }
+    tasks = ["text_classification", "text_summarization"]
+    tags = ["llm", "zero_shot", "few_shot"]
 
     system_prompt = """
 You are a prompt evaluation AI. You are aware of all prompt engineering best practices and can score prompts based on how well they satisfy different metrics. You analyse the prompts step-by-step based on provided documentation and provide a score and an explanation for how you produced that score.
validmind/tests/prompt_validation/Robustness.py CHANGED
@@ -60,10 +60,8 @@ class Robustness(ThresholdTest):
     name = "robustness"
     required_inputs = ["model"]
     default_params = {"num_tests": 10}
-    metadata = {
-        "task_types": ["text_classification", "text_summarization"],
-        "tags": ["llm", "zero_shot", "few_shot"],
-    }
+    tasks = ["text_classification", "text_summarization"]
+    tags = ["llm", "zero_shot", "few_shot"]
 
     system_prompt = '''
 You are a prompt evaluation researcher AI who is tasked with testing the robustness of LLM prompts.
validmind/tests/prompt_validation/Specificity.py CHANGED
@@ -66,10 +66,8 @@ class Specificity(ThresholdTest):
     name = "specificity"
     required_inputs = ["model.prompt"]
     default_params = {"min_threshold": 7}
-    metadata = {
-        "task_types": ["text_classification", "text_summarization"],
-        "tags": ["llm", "zero_shot", "few_shot"],
-    }
+    tasks = ["text_classification", "text_summarization"]
+    tags = ["llm", "zero_shot", "few_shot"]
 
     system_prompt = """
 You are a prompt evaluation AI. You are aware of all prompt engineering best practices and can score prompts based on how well they satisfy different metrics. You analyse the prompts step-by-step based on provided documentation and provide a score and an explanation for how you produced that score.
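Editor's note: the pattern above repeats across nearly every test module in this release (the +2 -4 counts in the file list), replacing the class-level metadata dict with flat tasks and tags attributes. A minimal sketch of how a custom test would declare these under 2.4.0; the class name and values here are illustrative, not taken from the package:

from validmind.vm_models import ThresholdTest

class MyPromptCheck(ThresholdTest):
    """Hypothetical custom test showing the 2.4.0 attribute style."""

    name = "my_prompt_check"
    required_inputs = ["model.prompt"]
    default_params = {"min_threshold": 7}

    # 2.3.x declared this information in a single dict:
    #   metadata = {"task_types": [...], "tags": [...]}
    # 2.4.0 uses flat class attributes instead:
    tasks = ["text_classification", "text_summarization"]
    tags = ["llm", "few_shot"]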
validmind/tests/run.py ADDED
@@ -0,0 +1,394 @@
+# Copyright © 2023-2024 ValidMind Inc. All rights reserved.
+# See the LICENSE file in the root of this repository for details.
+# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
+
+from itertools import product
+from typing import Any, Dict, List, Union
+from uuid import uuid4
+
+import pandas as pd
+
+from validmind.ai.test_descriptions import get_description_metadata
+from validmind.errors import LoadTestError
+from validmind.logging import get_logger
+from validmind.unit_metrics import run_metric
+from validmind.unit_metrics.composite import load_composite_metric
+from validmind.vm_models import (
+    MetricResult,
+    ResultSummary,
+    ResultTable,
+    TestContext,
+    TestInput,
+    ThresholdTestResults,
+)
+from validmind.vm_models.figure import is_matplotlib_figure, is_plotly_figure
+from validmind.vm_models.test.result_wrapper import (
+    MetricResultWrapper,
+    ThresholdTestResultWrapper,
+)
+
+from .__types__ import TestID
+from .load import load_test
+
+logger = get_logger(__name__)
+
+
+def _cartesian_product(input_grid: Dict[str, List[Any]]):
+    """Get all possible combinations for a set of inputs"""
+    return [dict(zip(input_grid, values)) for values in product(*input_grid.values())]
+
+
+def _combine_summaries(summaries: List[Dict[str, Any]]):
+    """Combine the summaries from multiple results
+
+    Args:
+        summaries (List[Dict[str, Any]]): A list of dictionaries where each dictionary
+            has two keys: "inputs" and "summary". The "inputs" key should contain the
+            inputs used for the test and the "summary" key should contain the actual
+            summary object.
+
+    Constraint: The summaries must all have the same structure meaning that each has
+        the same number of tables in the same order with the same columns etc. This
+        should always be the case for comparison tests since its the same test run
+        multiple times with different inputs.
+    """
+    if not summaries[0]["summary"]:
+        return None
+
+    def combine_tables(table_index):
+        combined_df = pd.DataFrame()
+
+        for summary_obj in summaries:
+            serialized = summary_obj["summary"].results[table_index].serialize()
+            summary_df = pd.DataFrame(serialized["data"])
+            summary_df = pd.concat(
+                [
+                    pd.DataFrame(summary_obj["inputs"], index=summary_df.index),
+                    summary_df,
+                ],
+                axis=1,
+            )
+            combined_df = pd.concat([combined_df, summary_df], ignore_index=True)
+
+        return ResultTable(
+            data=combined_df.to_dict(orient="records"),
+            metadata=summaries[0]["summary"].results[table_index].metadata,
+        )
+
+    return ResultSummary(
+        results=[
+            combine_tables(table_index)
+            for table_index in range(len(summaries[0]["summary"].results))
+        ]
+    )
+
+
+def _update_plotly_titles(figures, input_groups, title_template):
+    current_title = figures[0].figure.layout.title.text
+
+    for i, figure in enumerate(figures):
+        figure.figure.layout.title.text = title_template.format(
+            current_title=f"{current_title} " if current_title else "",
+            input_description=", ".join(
+                f"{k}={v if isinstance(v, str) else v.input_id}"
+                for k, v in input_groups[i].items()
+            ),
+        )
+
+
+def _update_matplotlib_titles(figures, input_groups, title_template):
+    current_title = figures[0].figure.get_title()
+
+    for i, figure in enumerate(figures):
+        figure.figure.suptitle(
+            title_template.format(
+                current_title=f"{current_title} " if current_title else "",
+                input_description=" and ".join(
+                    f"{k}: {v if isinstance(v, str) else v.input_id}"
+                    for k, v in input_groups[i].items()
+                ),
+            )
+        )
+
+
+def _combine_figures(figure_lists: List[List[Any]], input_groups: List[Dict[str, Any]]):
+    """Combine the figures from multiple results"""
+    if not figure_lists[0]:
+        return None
+
+    title_template = "{current_title}({input_description})"
+
+    for i, figures in enumerate(list(zip(*figure_lists))):
+        if is_plotly_figure(figures[0].figure):
+            _update_plotly_titles(figures, input_groups, title_template)
+        elif is_matplotlib_figure(figures[0].figure):
+            _update_matplotlib_titles(figures, input_groups, title_template)
+        else:
+            logger.warning("Cannot properly annotate png figures")
+
+    return [figure for figures in figure_lists for figure in figures]
+
+
+def metric_comparison(
+    results: List[MetricResultWrapper],
+    test_id: TestID,
+    input_groups: Union[Dict[str, List[Any]], List[Dict[str, Any]]],
+    output_template: str = None,
+    generate_description: bool = True,
+):
+    """Build a comparison result for multiple metric results"""
+    ref_id = str(uuid4())
+
+    input_group_strings = [
+        {k: v if isinstance(v, str) else v.input_id for k, v in group.items()}
+        for group in input_groups
+    ]
+
+    merged_summary = _combine_summaries(
+        [
+            {"inputs": input_group_strings[i], "summary": result.metric.summary}
+            for i, result in enumerate(results)
+        ]
+    )
+    merged_figures = _combine_figures(
+        [result.figures for result in results], input_groups
+    )
+
+    # Patch figure metadata so they are connected to the comparison result
+    if merged_figures and len(merged_figures):
+        for i, figure in enumerate(merged_figures):
+            figure.key = f"{figure.key}-{i}"
+            figure.metadata["_name"] = test_id
+            figure.metadata["_ref_id"] = ref_id
+
+    return MetricResultWrapper(
+        result_id=test_id,
+        result_metadata=[
+            get_description_metadata(
+                test_id=test_id,
+                default_description=f"Comparison test result for {test_id}",
+                summary=merged_summary.serialize() if merged_summary else None,
+                figures=merged_figures,
+                should_generate=generate_description,
+            ),
+        ],
+        inputs=[
+            input if isinstance(input, str) else input.input_id
+            for group in input_groups
+            for input in group.values()
+        ],
+        output_template=output_template,
+        metric=MetricResult(
+            key=test_id,
+            ref_id=ref_id,
+            value=[],
+            summary=merged_summary,
+        ),
+        figures=merged_figures,
+    )
+
+
+def threshold_test_comparison(
+    results: List[ThresholdTestResultWrapper],
+    test_id: TestID,
+    input_groups: Union[Dict[str, List[Any]], List[Dict[str, Any]]],
+    output_template: str = None,
+    generate_description: bool = True,
+):
+    """Build a comparison result for multiple threshold test results"""
+    ref_id = str(uuid4())
+
+    input_group_strings = [
+        {k: v if isinstance(v, str) else v.input_id for k, v in group.items()}
+        for group in input_groups
+    ]
+
+    merged_summary = _combine_summaries(
+        [
+            {"inputs": input_group_strings[i], "summary": result.test_results.summary}
+            for i, result in enumerate(results)
+        ]
+    )
+    merged_figures = _combine_figures(
+        [result.figures for result in results], input_groups
+    )
+
+    # Patch figure metadata so they are connected to the comparison result
+    if merged_figures and len(merged_figures):
+        for i, figure in enumerate(merged_figures):
+            figure.key = f"{figure.key}-{i}"
+            figure.metadata["_name"] = test_id
+            figure.metadata["_ref_id"] = ref_id
+
+    return ThresholdTestResultWrapper(
+        result_id=test_id,
+        result_metadata=[
+            get_description_metadata(
+                test_id=test_id,
+                default_description=f"Comparison test result for {test_id}",
+                summary=merged_summary.serialize() if merged_summary else None,
+                figures=merged_figures,
+                prefix="test_description",
+                should_generate=generate_description,
+            )
+        ],
+        inputs=[
+            input if isinstance(input, str) else input.input_id
+            for group in input_groups
+            for input in group.values()
+        ],
+        output_template=output_template,
+        test_results=ThresholdTestResults(
+            test_name=test_id,
+            ref_id=ref_id,
+            # TODO: when we have param_grid support, this will need to be updated
+            params=results[0].test_results.params,
+            passed=all(result.test_results.passed for result in results),
+            results=[],
+            summary=merged_summary,
+        ),
+        figures=merged_figures,
+    )
+
+
+def run_comparison_test(
+    test_id: TestID,
+    input_grid: Union[Dict[str, List[Any]], List[Dict[str, Any]]],
+    params: Dict[str, Any] = None,
+    show: bool = True,
+    output_template: str = None,
+    generate_description: bool = True,
+):
+    """Run a comparison test"""
+    if isinstance(input_grid, dict):
+        input_groups = _cartesian_product(input_grid)
+    else:
+        input_groups = input_grid
+
+    results = [
+        run_test(
+            test_id,
+            inputs=inputs,
+            show=False,
+            params=params,
+            __generate_description=False,
+        )
+        for inputs in input_groups
+    ]
+
+    if isinstance(results[0], MetricResultWrapper):
+        func = metric_comparison
+    else:
+        func = threshold_test_comparison
+
+    result = func(results, test_id, input_groups, output_template, generate_description)
+
+    if show:
+        result.show()
+
+    return result
+
+
+def run_test(
+    test_id: TestID = None,
+    params: Dict[str, Any] = None,
+    inputs: Dict[str, Any] = None,
+    input_grid: Union[Dict[str, List[Any]], List[Dict[str, Any]]] = None,
+    name: str = None,
+    unit_metrics: List[TestID] = None,
+    output_template: str = None,
+    show: bool = True,
+    __generate_description: bool = True,
+    **kwargs,
+) -> Union[MetricResultWrapper, ThresholdTestResultWrapper]:
+    """Run a test by test ID
+
+    Args:
+        test_id (TestID, optional): The test ID to run. Not required if `unit_metrics` is provided.
+        params (dict, optional): A dictionary of parameters to pass into the test. Params
+            are used to customize the test behavior and are specific to each test. See the
+            test details for more information on the available parameters. Defaults to None.
+        inputs (Dict[str, Any], optional): A dictionary of test inputs to pass into the
+            test. Inputs are either models or datasets that have been initialized using
+            vm.init_model() or vm.init_dataset(). Defaults to None.
+        input_grid (Union[Dict[str, List[Any]], List[Dict[str, Any]]], optional): To run
+            a comparison test, provide either a dictionary of inputs where the keys are
+            the input names and the values are lists of different inputs, or a list of
+            dictionaries where each dictionary is a set of inputs to run the test with.
+            This will run the test multiple times with different sets of inputs and then
+            combine the results into a single output. When passing a dictionary, the grid
+            will be created by taking the Cartesian product of the input lists. Its simply
+            a more convenient way of forming the input grid as opposed to passing a list of
+            all possible combinations. Defaults to None.
+        name (str, optional): The name of the test (used to create a composite metric
+            out of multiple unit metrics) - required when running multiple unit metrics
+        unit_metrics (list, optional): A list of unit metric IDs to run as a composite
+            metric - required when running multiple unit metrics
+        output_template (str, optional): A jinja2 html template to customize the output
+            of the test. Defaults to None.
+        show (bool, optional): Whether to display the results. Defaults to True.
+        **kwargs: Keyword inputs to pass into the test (same as `inputs` but as keyword
+            args instead of a dictionary):
+            - dataset: A validmind Dataset object or a Pandas DataFrame
+            - model: A model to use for the test
+            - models: A list of models to use for the test
+            - dataset: A validmind Dataset object or a Pandas DataFrame
+    """
+    if not test_id and not name and not unit_metrics:
+        raise ValueError(
+            "`test_id` or `name` and `unit_metrics` must be provided to run a test"
+        )
+
+    if (unit_metrics and not name) or (name and not unit_metrics):
+        raise ValueError("`name` and `unit_metrics` must be provided together")
+
+    if (input_grid and kwargs) or (input_grid and inputs):
+        raise ValueError(
+            "When providing an `input_grid`, you cannot also provide `inputs` or `kwargs`"
+        )
+
+    if input_grid:
+        return run_comparison_test(
+            test_id,
+            input_grid,
+            params=params,
+            output_template=output_template,
+            show=show,
+            generate_description=__generate_description,
+        )
+
+    if test_id and test_id.startswith("validmind.unit_metrics"):
+        # TODO: as we move towards a more unified approach to metrics
+        # we will want to make everything functional and remove the
+        # separation between unit metrics and "normal" metrics
+        return run_metric(test_id, inputs=inputs, params=params, show=show)
+
+    if unit_metrics:
+        metric_id_name = "".join(word[0].upper() + word[1:] for word in name.split())
+        test_id = f"validmind.composite_test.{metric_id_name}"
+
+        error, TestClass = load_composite_metric(
+            unit_metrics=unit_metrics, metric_name=metric_id_name
+        )
+
+        if error:
+            raise LoadTestError(error)
+
+    else:
+        TestClass = load_test(test_id, reload=True)
+
+    test = TestClass(
+        test_id=test_id,
+        context=TestContext(),
+        inputs=TestInput({**kwargs, **(inputs or {})}),
+        output_template=output_template,
+        params=params,
+        generate_description=__generate_description,
+    )
+
+    test.run()
+
+    if show:
+        test.result.show()
+
+    return test.result
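Editor's note: the new run_test entry point supports plain runs, comparison runs via input_grid, and composite metrics built from unit metrics. A hedged usage sketch, assuming run_test remains exported from validmind.tests as in earlier releases and that the model and datasets below were initialized beforehand; all IDs are illustrative:

import validmind as vm

# Assumed setup (not shown here):
#   vm_model = vm.init_model(model, input_id="my_model")
#   vm_train_ds = vm.init_dataset(train_df, input_id="train_dataset", ...)
#   vm_test_ds = vm.init_dataset(test_df, input_id="test_dataset", ...)

# Single run with explicit inputs
vm.tests.run_test(
    "validmind.model_validation.sklearn.ConfusionMatrix",
    inputs={"model": vm_model, "dataset": vm_test_ds},
)

# Comparison run: a dict input_grid is expanded by Cartesian product, so the
# test runs once per (model, dataset) combination and the results are merged
vm.tests.run_test(
    "validmind.model_validation.sklearn.ROCCurve",
    input_grid={
        "model": [vm_model],
        "dataset": [vm_train_ds, vm_test_ds],
    },
)

# Composite metric built from unit metrics (metric IDs are illustrative)
vm.tests.run_test(
    name="Model Performance",
    unit_metrics=[
        "validmind.unit_metrics.classification.F1",
        "validmind.unit_metrics.classification.Precision",
    ],
    inputs={"model": vm_model, "dataset": vm_test_ds},
)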
validmind/tests/test_providers.py CHANGED
@@ -9,6 +9,8 @@ from typing import Protocol
 
 from validmind.logging import get_logger
 
+from ._store import test_provider_store
+
 logger = get_logger(__name__)
 
 
@@ -145,3 +147,13 @@ class LocalTestProvider:
             raise LocalTestProviderLoadTestError(
                 f"Failed to find the test class in the module. Error: {str(e)}"
             )
+
+
+def register_test_provider(namespace: str, test_provider: "TestProvider") -> None:
+    """Register an external test provider
+
+    Args:
+        namespace (str): The namespace of the test provider
+        test_provider (TestProvider): The test provider
+    """
+    test_provider_store.register_test_provider(namespace, test_provider)
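Editor's note: together with the new validmind/tests/_store.py module, this gives external test providers a public registration hook. A sketch of how a local folder of custom tests might be wired in, assuming LocalTestProvider still takes a root folder path as in earlier releases; the namespace and path are illustrative:

from validmind.tests.test_providers import LocalTestProvider, register_test_provider

# "my_tests" and the folder path are placeholders for a user's own test library
register_test_provider(
    namespace="my_tests",
    test_provider=LocalTestProvider("/path/to/my_custom_tests"),
)

# Tests in that folder can then be run by ID, e.g. "my_tests.MyCustomTest"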
validmind/tests/utils.py ADDED
@@ -0,0 +1,16 @@
+# Copyright © 2023-2024 ValidMind Inc. All rights reserved.
+# See the LICENSE file in the root of this repository for details.
+# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
+
+"""Test Module Utils"""
+
+import inspect
+
+
+def test_description(test_class, truncate=True):
+    description = inspect.getdoc(test_class).strip()
+
+    if truncate and len(description.split("\n")) > 5:
+        return description.strip().split("\n")[0] + "..."
+
+    return description
+
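Editor's note: for illustration, a hedged sketch of the truncation behavior of this helper; the class below is made up:

from validmind.tests.utils import test_description

class ExampleTest:
    """One-line summary of what the test checks

    A longer body spanning enough lines that the helper
    will truncate it down to just the summary line
    when truncate=True (the default).
    More detail here.
    And here.
    """

print(test_description(ExampleTest))                  # "One-line summary of what the test checks..."
print(test_description(ExampleTest, truncate=False))  # full docstring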
validmind/unit_metrics/__init__.py CHANGED
@@ -6,8 +6,9 @@ import hashlib
 import json
 from importlib import import_module
 
-from ..tests.decorator import _build_result, _inspect_signature
-from ..utils import get_model_info, test_id_to_name
+from validmind.input_registry import input_registry
+from validmind.tests.decorator import _build_result, _inspect_signature
+from validmind.utils import get_model_info, test_id_to_name
 
 unit_metric_results_cache = {}
 
@@ -157,7 +158,10 @@ def run_metric(metric_id, inputs=None, params=None, show=True, value_only=False)
         show (bool): Whether to display the results
         value_only (bool): Whether to return only the value
     """
-    inputs = inputs or {}
+    inputs = {
+        k: input_registry.get(v) if isinstance(v, str) else v
+        for k, v in (inputs or {}).items()
+    }
     params = params or {}
 
     cache_key = get_metric_cache_key(metric_id, params, inputs)
@@ -168,7 +172,11 @@ def run_metric(metric_id, inputs=None, params=None, show=True, value_only=False)
 
     result = metric(
         **{k: v for k, v in inputs.items() if k in _inputs.keys()},
-        **{k: v for k, v in params.items() if k in _params.keys()},
+        **{
+            k: v
+            for k, v in params.items()
+            if k in _params.keys() or "kwargs" in _params.keys()
+        },
     )
     unit_metric_results_cache[cache_key] = (
         result,
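Editor's note: because inputs are now resolved through input_registry, unit metrics accept input_id strings as well as the initialized objects. A hedged sketch; the metric ID and input IDs are illustrative and assume a model and dataset were registered earlier with vm.init_model / vm.init_dataset:

from validmind.unit_metrics import run_metric

result = run_metric(
    "validmind.unit_metrics.classification.F1",  # illustrative unit metric ID
    inputs={
        "model": "my_model",        # string IDs are looked up in input_registry
        "dataset": "test_dataset",  # equivalent to passing the VMDataset object
    },
)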
validmind/unit_metrics/composite.py CHANGED
@@ -42,6 +42,7 @@ class CompositeMetric(Metric):
             params=self.params,
             output_template=self.output_template,
             show=False,
+            generate_description=self.generate_description,
         )
 
         return self.result
@@ -109,6 +110,7 @@ def run_metrics(
     params: dict = None,
     test_id: str = None,
     show: bool = True,
+    generate_description: bool = True,
 ) -> MetricResultWrapper:
     """Run a composite metric
 
@@ -209,6 +211,7 @@ def run_metrics(
             test_id=test_id,
             default_description=description,
             summary=result_summary.serialize(),
+            should_generate=generate_description,
         ),
         {
             "content_id": f"composite_metric_def:{test_id}:unit_metrics",
validmind/vm_models/test/metric.py CHANGED
@@ -78,11 +78,14 @@ class Metric(Test):
         self.result = MetricResultWrapper(
             result_id=self.test_id,
             result_metadata=[
-                get_description_metadata(
-                    test_id=self.test_id,
-                    default_description=self.description(),
-                    summary=metric.serialize()["summary"],
-                    figures=figures,
+                (
+                    get_description_metadata(
+                        test_id=self.test_id,
+                        default_description=self.description(),
+                        summary=metric.serialize()["summary"],
+                        figures=figures,
+                        should_generate=self.generate_description,
+                    )
                 )
             ],
             metric=metric,
validmind/vm_models/test/result_wrapper.py CHANGED
@@ -344,7 +344,8 @@ class MetricResultWrapper(ResultWrapper):
         """Check if the metric summary has columns from input datasets"""
         dataset_columns = set()
 
-        for input_id in self.inputs:
+        for input in self.inputs:
+            input_id = input if isinstance(input, str) else input.input_id
             input_obj = input_registry.get(input_id)
             if isinstance(input_obj, VMDataset):
                 dataset_columns.update(input_obj.columns)