validmind 2.3.5__py3-none-any.whl → 2.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- validmind/__version__.py +1 -1
- validmind/ai/test_descriptions.py +8 -1
- validmind/ai/utils.py +2 -1
- validmind/client.py +1 -0
- validmind/template.py +2 -0
- validmind/tests/__init__.py +14 -468
- validmind/tests/_store.py +102 -0
- validmind/tests/data_validation/ACFandPACFPlot.py +7 -9
- validmind/tests/data_validation/ADF.py +8 -10
- validmind/tests/data_validation/ANOVAOneWayTable.py +8 -10
- validmind/tests/data_validation/AutoAR.py +2 -4
- validmind/tests/data_validation/AutoMA.py +2 -4
- validmind/tests/data_validation/AutoSeasonality.py +8 -10
- validmind/tests/data_validation/AutoStationarity.py +8 -10
- validmind/tests/data_validation/BivariateFeaturesBarPlots.py +8 -10
- validmind/tests/data_validation/BivariateHistograms.py +8 -10
- validmind/tests/data_validation/BivariateScatterPlots.py +8 -10
- validmind/tests/data_validation/ChiSquaredFeaturesTable.py +8 -10
- validmind/tests/data_validation/ClassImbalance.py +2 -4
- validmind/tests/data_validation/DFGLSArch.py +2 -4
- validmind/tests/data_validation/DatasetDescription.py +7 -9
- validmind/tests/data_validation/DatasetSplit.py +8 -9
- validmind/tests/data_validation/DescriptiveStatistics.py +2 -4
- validmind/tests/data_validation/Duplicates.py +2 -4
- validmind/tests/data_validation/EngleGrangerCoint.py +2 -4
- validmind/tests/data_validation/FeatureTargetCorrelationPlot.py +2 -4
- validmind/tests/data_validation/HeatmapFeatureCorrelations.py +2 -4
- validmind/tests/data_validation/HighCardinality.py +2 -4
- validmind/tests/data_validation/HighPearsonCorrelation.py +2 -4
- validmind/tests/data_validation/IQROutliersBarPlot.py +2 -4
- validmind/tests/data_validation/IQROutliersTable.py +2 -4
- validmind/tests/data_validation/IsolationForestOutliers.py +2 -4
- validmind/tests/data_validation/KPSS.py +8 -10
- validmind/tests/data_validation/LaggedCorrelationHeatmap.py +2 -4
- validmind/tests/data_validation/MissingValues.py +2 -4
- validmind/tests/data_validation/MissingValuesBarPlot.py +2 -4
- validmind/tests/data_validation/MissingValuesRisk.py +2 -4
- validmind/tests/data_validation/PearsonCorrelationMatrix.py +2 -4
- validmind/tests/data_validation/PhillipsPerronArch.py +7 -9
- validmind/tests/data_validation/RollingStatsPlot.py +2 -4
- validmind/tests/data_validation/ScatterPlot.py +2 -4
- validmind/tests/data_validation/SeasonalDecompose.py +2 -4
- validmind/tests/data_validation/Skewness.py +2 -4
- validmind/tests/data_validation/SpreadPlot.py +2 -4
- validmind/tests/data_validation/TabularCategoricalBarPlots.py +2 -4
- validmind/tests/data_validation/TabularDateTimeHistograms.py +2 -4
- validmind/tests/data_validation/TabularDescriptionTables.py +2 -4
- validmind/tests/data_validation/TabularNumericalHistograms.py +2 -4
- validmind/tests/data_validation/TargetRateBarPlots.py +2 -4
- validmind/tests/data_validation/TimeSeriesFrequency.py +2 -4
- validmind/tests/data_validation/TimeSeriesLinePlot.py +2 -4
- validmind/tests/data_validation/TimeSeriesMissingValues.py +2 -4
- validmind/tests/data_validation/TimeSeriesOutliers.py +2 -4
- validmind/tests/data_validation/TooManyZeroValues.py +2 -4
- validmind/tests/data_validation/UniqueRows.py +2 -4
- validmind/tests/data_validation/WOEBinPlots.py +2 -4
- validmind/tests/data_validation/WOEBinTable.py +2 -4
- validmind/tests/data_validation/ZivotAndrewsArch.py +2 -4
- validmind/tests/data_validation/nlp/CommonWords.py +2 -4
- validmind/tests/data_validation/nlp/Hashtags.py +2 -4
- validmind/tests/data_validation/nlp/Mentions.py +2 -4
- validmind/tests/data_validation/nlp/Punctuations.py +2 -4
- validmind/tests/data_validation/nlp/StopWords.py +2 -4
- validmind/tests/data_validation/nlp/TextDescription.py +2 -4
- validmind/tests/decorator.py +10 -8
- validmind/tests/load.py +264 -0
- validmind/tests/metadata.py +59 -0
- validmind/tests/model_validation/ClusterSizeDistribution.py +5 -7
- validmind/tests/model_validation/FeaturesAUC.py +6 -8
- validmind/tests/model_validation/ModelMetadata.py +8 -9
- validmind/tests/model_validation/RegressionResidualsPlot.py +2 -6
- validmind/tests/model_validation/embeddings/ClusterDistribution.py +2 -4
- validmind/tests/model_validation/embeddings/CosineSimilarityDistribution.py +2 -4
- validmind/tests/model_validation/embeddings/DescriptiveAnalytics.py +2 -4
- validmind/tests/model_validation/embeddings/EmbeddingsVisualization2D.py +2 -4
- validmind/tests/model_validation/embeddings/StabilityAnalysis.py +2 -4
- validmind/tests/model_validation/sklearn/AdjustedMutualInformation.py +5 -7
- validmind/tests/model_validation/sklearn/AdjustedRandIndex.py +5 -7
- validmind/tests/model_validation/sklearn/ClassifierPerformance.py +7 -9
- validmind/tests/model_validation/sklearn/ClusterCosineSimilarity.py +5 -7
- validmind/tests/model_validation/sklearn/ClusterPerformance.py +5 -7
- validmind/tests/model_validation/sklearn/ClusterPerformanceMetrics.py +2 -7
- validmind/tests/model_validation/sklearn/CompletenessScore.py +5 -7
- validmind/tests/model_validation/sklearn/ConfusionMatrix.py +19 -10
- validmind/tests/model_validation/sklearn/FowlkesMallowsScore.py +5 -7
- validmind/tests/model_validation/sklearn/HomogeneityScore.py +5 -7
- validmind/tests/model_validation/sklearn/HyperParametersTuning.py +2 -7
- validmind/tests/model_validation/sklearn/KMeansClustersOptimization.py +4 -7
- validmind/tests/model_validation/sklearn/MinimumAccuracy.py +7 -9
- validmind/tests/model_validation/sklearn/MinimumF1Score.py +7 -9
- validmind/tests/model_validation/sklearn/MinimumROCAUCScore.py +7 -9
- validmind/tests/model_validation/sklearn/ModelsPerformanceComparison.py +8 -10
- validmind/tests/model_validation/sklearn/OverfitDiagnosis.py +7 -9
- validmind/tests/model_validation/sklearn/PermutationFeatureImportance.py +8 -10
- validmind/tests/model_validation/sklearn/PopulationStabilityIndex.py +7 -9
- validmind/tests/model_validation/sklearn/PrecisionRecallCurve.py +8 -10
- validmind/tests/model_validation/sklearn/ROCCurve.py +10 -11
- validmind/tests/model_validation/sklearn/RegressionErrors.py +5 -7
- validmind/tests/model_validation/sklearn/RegressionModelsPerformanceComparison.py +5 -7
- validmind/tests/model_validation/sklearn/RegressionR2Square.py +5 -7
- validmind/tests/model_validation/sklearn/RobustnessDiagnosis.py +10 -14
- validmind/tests/model_validation/sklearn/SHAPGlobalImportance.py +8 -10
- validmind/tests/model_validation/sklearn/SilhouettePlot.py +5 -7
- validmind/tests/model_validation/sklearn/TrainingTestDegradation.py +8 -10
- validmind/tests/model_validation/sklearn/VMeasure.py +5 -7
- validmind/tests/model_validation/sklearn/WeakspotsDiagnosis.py +8 -10
- validmind/tests/model_validation/statsmodels/AutoARIMA.py +2 -4
- validmind/tests/model_validation/statsmodels/BoxPierce.py +2 -4
- validmind/tests/model_validation/statsmodels/CumulativePredictionProbabilities.py +3 -4
- validmind/tests/model_validation/statsmodels/DurbinWatsonTest.py +2 -4
- validmind/tests/model_validation/statsmodels/GINITable.py +2 -4
- validmind/tests/model_validation/statsmodels/JarqueBera.py +7 -9
- validmind/tests/model_validation/statsmodels/KolmogorovSmirnov.py +7 -9
- validmind/tests/model_validation/statsmodels/LJungBox.py +2 -4
- validmind/tests/model_validation/statsmodels/Lilliefors.py +7 -9
- validmind/tests/model_validation/statsmodels/PredictionProbabilitiesHistogram.py +2 -4
- validmind/tests/model_validation/statsmodels/RegressionCoeffsPlot.py +2 -4
- validmind/tests/model_validation/statsmodels/RegressionFeatureSignificance.py +7 -9
- validmind/tests/model_validation/statsmodels/RegressionModelForecastPlot.py +2 -4
- validmind/tests/model_validation/statsmodels/RegressionModelForecastPlotLevels.py +2 -4
- validmind/tests/model_validation/statsmodels/RegressionModelSensitivityPlot.py +2 -4
- validmind/tests/model_validation/statsmodels/RegressionModelSummary.py +2 -4
- validmind/tests/model_validation/statsmodels/RegressionModelsCoeffs.py +2 -4
- validmind/tests/model_validation/statsmodels/RegressionPermutationFeatureImportance.py +6 -8
- validmind/tests/model_validation/statsmodels/RunsTest.py +2 -4
- validmind/tests/model_validation/statsmodels/ScorecardHistogram.py +3 -4
- validmind/tests/model_validation/statsmodels/ShapiroWilk.py +2 -4
- validmind/tests/prompt_validation/Bias.py +2 -4
- validmind/tests/prompt_validation/Clarity.py +2 -4
- validmind/tests/prompt_validation/Conciseness.py +2 -4
- validmind/tests/prompt_validation/Delimitation.py +2 -4
- validmind/tests/prompt_validation/NegativeInstruction.py +2 -4
- validmind/tests/prompt_validation/Robustness.py +2 -4
- validmind/tests/prompt_validation/Specificity.py +2 -4
- validmind/tests/run.py +394 -0
- validmind/tests/test_providers.py +12 -0
- validmind/tests/utils.py +16 -0
- validmind/unit_metrics/__init__.py +12 -4
- validmind/unit_metrics/composite.py +3 -0
- validmind/vm_models/test/metric.py +8 -5
- validmind/vm_models/test/result_wrapper.py +2 -1
- validmind/vm_models/test/test.py +14 -11
- validmind/vm_models/test/threshold_test.py +1 -0
- validmind/vm_models/test_suite/runner.py +1 -0
- {validmind-2.3.5.dist-info → validmind-2.4.1.dist-info}/METADATA +1 -1
- {validmind-2.3.5.dist-info → validmind-2.4.1.dist-info}/RECORD +149 -144
- {validmind-2.3.5.dist-info → validmind-2.4.1.dist-info}/LICENSE +0 -0
- {validmind-2.3.5.dist-info → validmind-2.4.1.dist-info}/WHEEL +0 -0
- {validmind-2.3.5.dist-info → validmind-2.4.1.dist-info}/entry_points.txt +0 -0
validmind/tests/prompt_validation/Bias.py
@@ -75,10 +75,8 @@ class Bias(ThresholdTest):
     name = "bias"
     required_inputs = ["model.prompt"]
     default_params = {"min_threshold": 7}
-    metadata = {
-        "task_types": ["text_classification", "text_summarization"],
-        "tags": ["llm", "few_shot"],
-    }
+    tasks = ["text_classification", "text_summarization"]
+    tags = ["llm", "few_shot"]

     system_prompt = """
 You are a prompt evaluation AI. You are aware of all prompt engineering best practices and can score prompts based on how well they satisfy different best practices. You analyse the prompts step-by-step based on provided documentation and provide a score and an explanation for how you produced that score.
validmind/tests/prompt_validation/Clarity.py
@@ -64,10 +64,8 @@ class Clarity(ThresholdTest):
     name = "clarity"
     required_inputs = ["model.prompt"]
     default_params = {"min_threshold": 7}
-    metadata = {
-        "task_types": ["text_classification", "text_summarization"],
-        "tags": ["llm", "zero_shot", "few_shot"],
-    }
+    tasks = ["text_classification", "text_summarization"]
+    tags = ["llm", "zero_shot", "few_shot"]

     system_prompt = """
 You are a prompt evaluation AI. You are aware of all prompt engineering best practices and can score prompts based on how well they satisfy different metrics. You analyse the prompts step-by-step based on provided documentation and provide a score and an explanation for how you produced that score.
validmind/tests/prompt_validation/Conciseness.py
@@ -64,10 +64,8 @@ class Conciseness(ThresholdTest):
     name = "conciseness"
     required_inputs = ["model.prompt"]
     default_params = {"min_threshold": 7}
-    metadata = {
-        "task_types": ["text_classification", "text_summarization"],
-        "tags": ["llm", "zero_shot", "few_shot"],
-    }
+    tasks = ["text_classification", "text_summarization"]
+    tags = ["llm", "zero_shot", "few_shot"]

     system_prompt = """
 You are a prompt evaluation AI. You are aware of all prompt engineering best practices and can score prompts based on how well they satisfy different metrics. You analyse the prompts step-by-step based on provided documentation and provide a score and an explanation for how you produced that score.
validmind/tests/prompt_validation/Delimitation.py
@@ -66,10 +66,8 @@ class Delimitation(ThresholdTest):
     name = "delimitation"
     required_inputs = ["model.prompt"]
     default_params = {"min_threshold": 7}
-    metadata = {
-        "task_types": ["text_classification", "text_summarization"],
-        "tags": ["llm", "zero_shot", "few_shot"],
-    }
+    tasks = ["text_classification", "text_summarization"]
+    tags = ["llm", "zero_shot", "few_shot"]

     system_prompt = """
 You are a prompt evaluation AI. You are aware of all prompt engineering best practices and can score prompts based on how well they satisfy different metrics. You analyse the prompts step-by-step based on provided documentation and provide a score and an explanation for how you produced that score.
validmind/tests/prompt_validation/NegativeInstruction.py
@@ -70,10 +70,8 @@ class NegativeInstruction(ThresholdTest):
     name = "negative_instruction"
     required_inputs = ["model.prompt"]
     default_params = {"min_threshold": 7}
-    metadata = {
-        "task_types": ["text_classification", "text_summarization"],
-        "tags": ["llm", "zero_shot", "few_shot"],
-    }
+    tasks = ["text_classification", "text_summarization"]
+    tags = ["llm", "zero_shot", "few_shot"]

     system_prompt = """
 You are a prompt evaluation AI. You are aware of all prompt engineering best practices and can score prompts based on how well they satisfy different metrics. You analyse the prompts step-by-step based on provided documentation and provide a score and an explanation for how you produced that score.
validmind/tests/prompt_validation/Robustness.py
@@ -60,10 +60,8 @@ class Robustness(ThresholdTest):
     name = "robustness"
     required_inputs = ["model"]
    default_params = {"num_tests": 10}
-    metadata = {
-        "task_types": ["text_classification", "text_summarization"],
-        "tags": ["llm", "zero_shot", "few_shot"],
-    }
+    tasks = ["text_classification", "text_summarization"]
+    tags = ["llm", "zero_shot", "few_shot"]

     system_prompt = '''
 You are a prompt evaluation researcher AI who is tasked with testing the robustness of LLM prompts.
validmind/tests/prompt_validation/Specificity.py
@@ -66,10 +66,8 @@ class Specificity(ThresholdTest):
     name = "specificity"
     required_inputs = ["model.prompt"]
     default_params = {"min_threshold": 7}
-    metadata = {
-        "task_types": ["text_classification", "text_summarization"],
-        "tags": ["llm", "zero_shot", "few_shot"],
-    }
+    tasks = ["text_classification", "text_summarization"]
+    tags = ["llm", "zero_shot", "few_shot"]

     system_prompt = """
 You are a prompt evaluation AI. You are aware of all prompt engineering best practices and can score prompts based on how well they satisfy different metrics. You analyse the prompts step-by-step based on provided documentation and provide a score and an explanation for how you produced that score.
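The hunks above all apply the same migration: the per-test `metadata` dictionary is replaced by plain `tasks` and `tags` class attributes. A minimal sketch of a custom threshold test using the new attributes, assuming `ThresholdTest` is importable from `validmind.vm_models` as it is for the built-in tests (the class name and body below are placeholders, not part of this package):

from validmind.vm_models import ThresholdTest


class MyPromptCheck(ThresholdTest):
    """Hypothetical custom test illustrating the new class attributes."""

    name = "my_prompt_check"
    required_inputs = ["model.prompt"]
    default_params = {"min_threshold": 7}
    tasks = ["text_classification"]  # replaces metadata["task_types"]
    tags = ["llm", "zero_shot"]      # replaces metadata["tags"]

    def run(self):
        # placeholder: a real test would evaluate the prompt and cache results here
        ...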
validmind/tests/run.py
ADDED
@@ -0,0 +1,394 @@
+# Copyright © 2023-2024 ValidMind Inc. All rights reserved.
+# See the LICENSE file in the root of this repository for details.
+# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
+
+from itertools import product
+from typing import Any, Dict, List, Union
+from uuid import uuid4
+
+import pandas as pd
+
+from validmind.ai.test_descriptions import get_description_metadata
+from validmind.errors import LoadTestError
+from validmind.logging import get_logger
+from validmind.unit_metrics import run_metric
+from validmind.unit_metrics.composite import load_composite_metric
+from validmind.vm_models import (
+    MetricResult,
+    ResultSummary,
+    ResultTable,
+    TestContext,
+    TestInput,
+    ThresholdTestResults,
+)
+from validmind.vm_models.figure import is_matplotlib_figure, is_plotly_figure
+from validmind.vm_models.test.result_wrapper import (
+    MetricResultWrapper,
+    ThresholdTestResultWrapper,
+)
+
+from .__types__ import TestID
+from .load import load_test
+
+logger = get_logger(__name__)
+
+
+def _cartesian_product(input_grid: Dict[str, List[Any]]):
+    """Get all possible combinations for a set of inputs"""
+    return [dict(zip(input_grid, values)) for values in product(*input_grid.values())]
+
+
+def _combine_summaries(summaries: List[Dict[str, Any]]):
+    """Combine the summaries from multiple results
+
+    Args:
+        summaries (List[Dict[str, Any]]): A list of dictionaries where each dictionary
+            has two keys: "inputs" and "summary". The "inputs" key should contain the
+            inputs used for the test and the "summary" key should contain the actual
+            summary object.
+
+    Constraint: The summaries must all have the same structure meaning that each has
+    the same number of tables in the same order with the same columns etc. This
+    should always be the case for comparison tests since its the same test run
+    multiple times with different inputs.
+    """
+    if not summaries[0]["summary"]:
+        return None
+
+    def combine_tables(table_index):
+        combined_df = pd.DataFrame()
+
+        for summary_obj in summaries:
+            serialized = summary_obj["summary"].results[table_index].serialize()
+            summary_df = pd.DataFrame(serialized["data"])
+            summary_df = pd.concat(
+                [
+                    pd.DataFrame(summary_obj["inputs"], index=summary_df.index),
+                    summary_df,
+                ],
+                axis=1,
+            )
+            combined_df = pd.concat([combined_df, summary_df], ignore_index=True)
+
+        return ResultTable(
+            data=combined_df.to_dict(orient="records"),
+            metadata=summaries[0]["summary"].results[table_index].metadata,
+        )
+
+    return ResultSummary(
+        results=[
+            combine_tables(table_index)
+            for table_index in range(len(summaries[0]["summary"].results))
+        ]
+    )
+
+
+def _update_plotly_titles(figures, input_groups, title_template):
+    current_title = figures[0].figure.layout.title.text
+
+    for i, figure in enumerate(figures):
+        figure.figure.layout.title.text = title_template.format(
+            current_title=f"{current_title} " if current_title else "",
+            input_description=", ".join(
+                f"{k}={v if isinstance(v, str) else v.input_id}"
+                for k, v in input_groups[i].items()
+            ),
+        )
+
+
+def _update_matplotlib_titles(figures, input_groups, title_template):
+    current_title = figures[0].figure.get_title()
+
+    for i, figure in enumerate(figures):
+        figure.figure.suptitle(
+            title_template.format(
+                current_title=f"{current_title} " if current_title else "",
+                input_description=" and ".join(
+                    f"{k}: {v if isinstance(v, str) else v.input_id}"
+                    for k, v in input_groups[i].items()
+                ),
+            )
+        )
+
+
+def _combine_figures(figure_lists: List[List[Any]], input_groups: List[Dict[str, Any]]):
+    """Combine the figures from multiple results"""
+    if not figure_lists[0]:
+        return None
+
+    title_template = "{current_title}({input_description})"
+
+    for i, figures in enumerate(list(zip(*figure_lists))):
+        if is_plotly_figure(figures[0].figure):
+            _update_plotly_titles(figures, input_groups, title_template)
+        elif is_matplotlib_figure(figures[0].figure):
+            _update_matplotlib_titles(figures, input_groups, title_template)
+        else:
+            logger.warning("Cannot properly annotate png figures")
+
+    return [figure for figures in figure_lists for figure in figures]
+
+
+def metric_comparison(
+    results: List[MetricResultWrapper],
+    test_id: TestID,
+    input_groups: Union[Dict[str, List[Any]], List[Dict[str, Any]]],
+    output_template: str = None,
+    generate_description: bool = True,
+):
+    """Build a comparison result for multiple metric results"""
+    ref_id = str(uuid4())
+
+    input_group_strings = [
+        {k: v if isinstance(v, str) else v.input_id for k, v in group.items()}
+        for group in input_groups
+    ]
+
+    merged_summary = _combine_summaries(
+        [
+            {"inputs": input_group_strings[i], "summary": result.metric.summary}
+            for i, result in enumerate(results)
+        ]
+    )
+    merged_figures = _combine_figures(
+        [result.figures for result in results], input_groups
+    )
+
+    # Patch figure metadata so they are connected to the comparison result
+    if merged_figures and len(merged_figures):
+        for i, figure in enumerate(merged_figures):
+            figure.key = f"{figure.key}-{i}"
+            figure.metadata["_name"] = test_id
+            figure.metadata["_ref_id"] = ref_id
+
+    return MetricResultWrapper(
+        result_id=test_id,
+        result_metadata=[
+            get_description_metadata(
+                test_id=test_id,
+                default_description=f"Comparison test result for {test_id}",
+                summary=merged_summary.serialize() if merged_summary else None,
+                figures=merged_figures,
+                should_generate=generate_description,
+            ),
+        ],
+        inputs=[
+            input if isinstance(input, str) else input.input_id
+            for group in input_groups
+            for input in group.values()
+        ],
+        output_template=output_template,
+        metric=MetricResult(
+            key=test_id,
+            ref_id=ref_id,
+            value=[],
+            summary=merged_summary,
+        ),
+        figures=merged_figures,
+    )
+
+
+def threshold_test_comparison(
+    results: List[ThresholdTestResultWrapper],
+    test_id: TestID,
+    input_groups: Union[Dict[str, List[Any]], List[Dict[str, Any]]],
+    output_template: str = None,
+    generate_description: bool = True,
+):
+    """Build a comparison result for multiple threshold test results"""
+    ref_id = str(uuid4())
+
+    input_group_strings = [
+        {k: v if isinstance(v, str) else v.input_id for k, v in group.items()}
+        for group in input_groups
+    ]
+
+    merged_summary = _combine_summaries(
+        [
+            {"inputs": input_group_strings[i], "summary": result.test_results.summary}
+            for i, result in enumerate(results)
+        ]
+    )
+    merged_figures = _combine_figures(
+        [result.figures for result in results], input_groups
+    )
+
+    # Patch figure metadata so they are connected to the comparison result
+    if merged_figures and len(merged_figures):
+        for i, figure in enumerate(merged_figures):
+            figure.key = f"{figure.key}-{i}"
+            figure.metadata["_name"] = test_id
+            figure.metadata["_ref_id"] = ref_id
+
+    return ThresholdTestResultWrapper(
+        result_id=test_id,
+        result_metadata=[
+            get_description_metadata(
+                test_id=test_id,
+                default_description=f"Comparison test result for {test_id}",
+                summary=merged_summary.serialize() if merged_summary else None,
+                figures=merged_figures,
+                prefix="test_description",
+                should_generate=generate_description,
+            )
+        ],
+        inputs=[
+            input if isinstance(input, str) else input.input_id
+            for group in input_groups
+            for input in group.values()
+        ],
+        output_template=output_template,
+        test_results=ThresholdTestResults(
+            test_name=test_id,
+            ref_id=ref_id,
+            # TODO: when we have param_grid support, this will need to be updated
+            params=results[0].test_results.params,
+            passed=all(result.test_results.passed for result in results),
+            results=[],
+            summary=merged_summary,
+        ),
+        figures=merged_figures,
+    )
+
+
+def run_comparison_test(
+    test_id: TestID,
+    input_grid: Union[Dict[str, List[Any]], List[Dict[str, Any]]],
+    params: Dict[str, Any] = None,
+    show: bool = True,
+    output_template: str = None,
+    generate_description: bool = True,
+):
+    """Run a comparison test"""
+    if isinstance(input_grid, dict):
+        input_groups = _cartesian_product(input_grid)
+    else:
+        input_groups = input_grid
+
+    results = [
+        run_test(
+            test_id,
+            inputs=inputs,
+            show=False,
+            params=params,
+            __generate_description=False,
+        )
+        for inputs in input_groups
+    ]
+
+    if isinstance(results[0], MetricResultWrapper):
+        func = metric_comparison
+    else:
+        func = threshold_test_comparison
+
+    result = func(results, test_id, input_groups, output_template, generate_description)
+
+    if show:
+        result.show()
+
+    return result
+
+
+def run_test(
+    test_id: TestID = None,
+    params: Dict[str, Any] = None,
+    inputs: Dict[str, Any] = None,
+    input_grid: Union[Dict[str, List[Any]], List[Dict[str, Any]]] = None,
+    name: str = None,
+    unit_metrics: List[TestID] = None,
+    output_template: str = None,
+    show: bool = True,
+    __generate_description: bool = True,
+    **kwargs,
+) -> Union[MetricResultWrapper, ThresholdTestResultWrapper]:
+    """Run a test by test ID
+
+    Args:
+        test_id (TestID, optional): The test ID to run. Not required if `unit_metrics` is provided.
+        params (dict, optional): A dictionary of parameters to pass into the test. Params
+            are used to customize the test behavior and are specific to each test. See the
+            test details for more information on the available parameters. Defaults to None.
+        inputs (Dict[str, Any], optional): A dictionary of test inputs to pass into the
+            test. Inputs are either models or datasets that have been initialized using
+            vm.init_model() or vm.init_dataset(). Defaults to None.
+        input_grid (Union[Dict[str, List[Any]], List[Dict[str, Any]]], optional): To run
+            a comparison test, provide either a dictionary of inputs where the keys are
+            the input names and the values are lists of different inputs, or a list of
+            dictionaries where each dictionary is a set of inputs to run the test with.
+            This will run the test multiple times with different sets of inputs and then
+            combine the results into a single output. When passing a dictionary, the grid
+            will be created by taking the Cartesian product of the input lists. Its simply
+            a more convenient way of forming the input grid as opposed to passing a list of
+            all possible combinations. Defaults to None.
+        name (str, optional): The name of the test (used to create a composite metric
+            out of multiple unit metrics) - required when running multiple unit metrics
+        unit_metrics (list, optional): A list of unit metric IDs to run as a composite
+            metric - required when running multiple unit metrics
+        output_template (str, optional): A jinja2 html template to customize the output
+            of the test. Defaults to None.
+        show (bool, optional): Whether to display the results. Defaults to True.
+        **kwargs: Keyword inputs to pass into the test (same as `inputs` but as keyword
+            args instead of a dictionary):
+            - dataset: A validmind Dataset object or a Pandas DataFrame
+            - model: A model to use for the test
+            - models: A list of models to use for the test
+            - dataset: A validmind Dataset object or a Pandas DataFrame
+    """
+    if not test_id and not name and not unit_metrics:
+        raise ValueError(
+            "`test_id` or `name` and `unit_metrics` must be provided to run a test"
+        )
+
+    if (unit_metrics and not name) or (name and not unit_metrics):
+        raise ValueError("`name` and `unit_metrics` must be provided together")
+
+    if (input_grid and kwargs) or (input_grid and inputs):
+        raise ValueError(
+            "When providing an `input_grid`, you cannot also provide `inputs` or `kwargs`"
+        )
+
+    if input_grid:
+        return run_comparison_test(
+            test_id,
+            input_grid,
+            params=params,
+            output_template=output_template,
+            show=show,
+            generate_description=__generate_description,
+        )
+
+    if test_id and test_id.startswith("validmind.unit_metrics"):
+        # TODO: as we move towards a more unified approach to metrics
+        # we will want to make everything functional and remove the
+        # separation between unit metrics and "normal" metrics
+        return run_metric(test_id, inputs=inputs, params=params, show=show)
+
+    if unit_metrics:
+        metric_id_name = "".join(word[0].upper() + word[1:] for word in name.split())
+        test_id = f"validmind.composite_test.{metric_id_name}"
+
+        error, TestClass = load_composite_metric(
+            unit_metrics=unit_metrics, metric_name=metric_id_name
+        )
+
+        if error:
+            raise LoadTestError(error)
+
+    else:
+        TestClass = load_test(test_id, reload=True)
+
+    test = TestClass(
+        test_id=test_id,
+        context=TestContext(),
+        inputs=TestInput({**kwargs, **(inputs or {})}),
+        output_template=output_template,
+        params=params,
+        generate_description=__generate_description,
+    )
+
+    test.run()
+
+    if show:
+        test.result.show()
+
+    return test.result
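A usage sketch of the new `run_test` entry point (assuming it remains re-exported from `validmind.tests`; the inputs below are placeholder `input_id` strings that must match whatever was passed to `vm.init_model()` / `vm.init_dataset()`):

import validmind as vm

# Single run with explicit inputs and default params
result = vm.tests.run_test(
    "validmind.model_validation.sklearn.ROCCurve",
    inputs={"model": "my_model", "dataset": "test_ds"},
)

# Comparison run: the dict form of input_grid is expanded via _cartesian_product,
# so the test runs once per (model, dataset) pair and the results are merged
comparison = vm.tests.run_test(
    "validmind.model_validation.sklearn.ROCCurve",
    input_grid={
        "model": ["champion_model", "challenger_model"],
        "dataset": ["train_ds", "test_ds"],
    },
)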
validmind/tests/test_providers.py
@@ -9,6 +9,8 @@ from typing import Protocol

 from validmind.logging import get_logger

+from ._store import test_provider_store
+
 logger = get_logger(__name__)


@@ -145,3 +147,13 @@ class LocalTestProvider:
             raise LocalTestProviderLoadTestError(
                 f"Failed to find the test class in the module. Error: {str(e)}"
             )
+
+
+def register_test_provider(namespace: str, test_provider: "TestProvider") -> None:
+    """Register an external test provider
+
+    Args:
+        namespace (str): The namespace of the test provider
+        test_provider (TestProvider): The test provider
+    """
+    test_provider_store.register_test_provider(namespace, test_provider)
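A sketch of registering an external test provider with the new helper (the namespace and folder are placeholders, and `LocalTestProvider` is assumed to take the tests directory as its constructor argument):

from validmind.tests.test_providers import LocalTestProvider, register_test_provider

# Tests under ./my_tests become addressable by IDs prefixed with the "my_org" namespace
register_test_provider(
    namespace="my_org",
    test_provider=LocalTestProvider("./my_tests"),
)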
validmind/tests/utils.py
ADDED
@@ -0,0 +1,16 @@
+# Copyright © 2023-2024 ValidMind Inc. All rights reserved.
+# See the LICENSE file in the root of this repository for details.
+# SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
+
+"""Test Module Utils"""
+
+import inspect
+
+
+def test_description(test_class, truncate=True):
+    description = inspect.getdoc(test_class).strip()
+
+    if truncate and len(description.split("\n")) > 5:
+        return description.strip().split("\n")[0] + "..."
+
+    return description
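A small usage sketch of the new helper (the class below is a stand-in with a long docstring):

from validmind.tests.utils import test_description


class ExampleTest:
    """First line of a long docstring

    Detail line one.
    Detail line two.
    Detail line three.
    Detail line four.
    """


print(test_description(ExampleTest))                  # "First line of a long docstring..."
print(test_description(ExampleTest, truncate=False))  # full, untruncated docstring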
validmind/unit_metrics/__init__.py
@@ -6,8 +6,9 @@ import hashlib
 import json
 from importlib import import_module

-from ...
-from ...
+from validmind.input_registry import input_registry
+from validmind.tests.decorator import _build_result, _inspect_signature
+from validmind.utils import get_model_info, test_id_to_name

 unit_metric_results_cache = {}

@@ -157,7 +158,10 @@ def run_metric(metric_id, inputs=None, params=None, show=True, value_only=False)
         show (bool): Whether to display the results
         value_only (bool): Whether to return only the value
     """
-    inputs = ...
+    inputs = {
+        k: input_registry.get(v) if isinstance(v, str) else v
+        for k, v in (inputs or {}).items()
+    }
     params = params or {}

     cache_key = get_metric_cache_key(metric_id, params, inputs)
@@ -168,7 +172,11 @@ def run_metric(metric_id, inputs=None, params=None, show=True, value_only=False)

     result = metric(
         **{k: v for k, v in inputs.items() if k in _inputs.keys()},
-        **{...},
+        **{
+            k: v
+            for k, v in params.items()
+            if k in _params.keys() or "kwargs" in _params.keys()
+        },
     )
     unit_metric_results_cache[cache_key] = (
         result,
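With the `input_registry` lookup above, `run_metric` can now resolve inputs passed as `input_id` strings rather than the initialized objects. A sketch (the unit metric ID and input IDs are illustrative placeholders):

from validmind.unit_metrics import run_metric

# "test_ds" and "my_model" are whatever input_id values were used at
# init_dataset() / init_model() time; strings are resolved via input_registry.get()
value = run_metric(
    "validmind.unit_metrics.classification.F1",  # placeholder unit metric ID
    inputs={"dataset": "test_ds", "model": "my_model"},
    value_only=True,
)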
validmind/unit_metrics/composite.py
@@ -42,6 +42,7 @@ class CompositeMetric(Metric):
             params=self.params,
             output_template=self.output_template,
             show=False,
+            generate_description=self.generate_description,
         )

         return self.result
@@ -109,6 +110,7 @@ def run_metrics(
     params: dict = None,
     test_id: str = None,
     show: bool = True,
+    generate_description: bool = True,
 ) -> MetricResultWrapper:
     """Run a composite metric

@@ -209,6 +211,7 @@ def run_metrics(
                 test_id=test_id,
                 default_description=description,
                 summary=result_summary.serialize(),
+                should_generate=generate_description,
             ),
             {
                 "content_id": f"composite_metric_def:{test_id}:unit_metrics",
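Since `generate_description` is now threaded through `run_metrics`, composite metrics built from unit metrics go through the same description flow as other tests. A sketch of building one via `run_test` (the unit metric IDs are placeholders; the composite test ID is derived from `name` as shown in run.py above):

import validmind as vm

result = vm.tests.run_test(
    name="F1 and Precision",
    unit_metrics=[
        "validmind.unit_metrics.classification.F1",         # placeholder IDs
        "validmind.unit_metrics.classification.Precision",
    ],
    inputs={"model": "my_model", "dataset": "test_ds"},
)
# logged under the derived test ID "validmind.composite_test.F1AndPrecision"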
validmind/vm_models/test/metric.py
@@ -78,11 +78,14 @@ class Metric(Test):
         self.result = MetricResultWrapper(
             result_id=self.test_id,
             result_metadata=[
-                get_description_metadata(
-                    test_id=self.test_id,
-                    default_description=self.description(),
-                    summary=metric.serialize()["summary"],
-                    figures=figures,
+                (
+                    get_description_metadata(
+                        test_id=self.test_id,
+                        default_description=self.description(),
+                        summary=metric.serialize()["summary"],
+                        figures=figures,
+                        should_generate=self.generate_description,
+                    )
                 )
             ],
             metric=metric,
validmind/vm_models/test/result_wrapper.py
@@ -344,7 +344,8 @@ class MetricResultWrapper(ResultWrapper):
         """Check if the metric summary has columns from input datasets"""
         dataset_columns = set()

-        for input_id in self.inputs:
+        for input in self.inputs:
+            input_id = input if isinstance(input, str) else input.input_id
             input_obj = input_registry.get(input_id)
             if isinstance(input_obj, VMDataset):
                 dataset_columns.update(input_obj.columns)