validmind 2.5.19__py3-none-any.whl → 2.5.23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- validmind/__init__.py +7 -46
- validmind/__version__.py +1 -1
- validmind/ai/test_result_description/context.py +2 -2
- validmind/api_client.py +131 -266
- validmind/client_config.py +1 -3
- validmind/datasets/__init__.py +1 -1
- validmind/datasets/nlp/__init__.py +1 -1
- validmind/errors.py +3 -30
- validmind/tests/load.py +4 -0
- validmind/tests/model_validation/ragas/AnswerCorrectness.py +0 -1
- validmind/tests/model_validation/sklearn/ClassifierPerformance.py +5 -2
- validmind/tests/run.py +219 -116
- validmind/vm_models/test/result_wrapper.py +4 -4
- {validmind-2.5.19.dist-info → validmind-2.5.23.dist-info}/METADATA +12 -12
- {validmind-2.5.19.dist-info → validmind-2.5.23.dist-info}/RECORD +18 -18
- {validmind-2.5.19.dist-info → validmind-2.5.23.dist-info}/LICENSE +0 -0
- {validmind-2.5.19.dist-info → validmind-2.5.23.dist-info}/WHEEL +0 -0
- {validmind-2.5.19.dist-info → validmind-2.5.23.dist-info}/entry_points.txt +0 -0
validmind/errors.py
CHANGED
@@ -3,7 +3,7 @@
 # SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial

 """
-This module contains all the custom errors that are used in the
+This module contains all the custom errors that are used in the library.

 The following base errors are defined for others:
 - BaseError
@@ -236,14 +236,6 @@ class MissingRExtrasError(BaseError):
         )


-class MissingRunCUIDError(APIRequestError):
-    """
-    When data is being sent to the API but the run_cuid is missing.
-    """
-
-    pass
-
-
 class MissingTextContentIdError(APIRequestError):
     """
     When a Text object is sent to the API without a content_id.
@@ -260,30 +252,14 @@ class MissingTextContentsError(APIRequestError):
     pass


-class
+class MissingModelIdError(BaseError):
     def description(self, *args, **kwargs):
         return (
             self.message
-            or "
+            or "Model ID must be provided either as an environment variable or as an argument to init."
         )


-class StartTestRunFailedError(APIRequestError):
-    """
-    When the API was not able to start a test run.
-    """
-
-    pass
-
-
-class TestRunNotFoundError(APIRequestError):
-    """
-    When a test run is not found in the API.
-    """
-
-    pass
-
-
 class TestInputInvalidDatasetError(BaseError):
     """
     When an invalid dataset is used in a test context.
@@ -369,11 +345,8 @@ def raise_api_error(error_string):
         "missing_text": MissingTextContentsError,
         "invalid_text_object": InvalidTextObjectError,
         "invalid_content_id_prefix": InvalidContentIdPrefixError,
-        "missing_run_cuid": MissingRunCUIDError,
-        "test_run_not_found": TestRunNotFoundError,
         "invalid_metric_results": InvalidMetricResultsError,
         "invalid_test_results": InvalidTestResultsError,
-        "start_test_run_failed": StartTestRunFailedError,
     }

     error_class = error_map.get(api_code, APIRequestError)
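For context, the surviving `error_map` entries in `raise_api_error` keep the same dispatch pattern: an API error code is looked up in a dict of exception classes, with `APIRequestError` as the fallback for unknown codes. Below is a minimal, self-contained sketch of that pattern; the class definitions here are stand-ins for illustration only (the real ones live in `validmind/errors.py`), not the library's actual implementation.

```python
# Hypothetical, standalone sketch of the code-to-exception dispatch shown in
# the raise_api_error hunk above.

class APIRequestError(Exception):
    """Generic API request error (fallback for unknown codes)."""


class MissingTextContentsError(APIRequestError):
    """When a Text object is sent to the API without contents."""


ERROR_MAP = {
    "missing_text": MissingTextContentsError,
    # ...the other codes kept in the diff map to their own subclasses
}


def raise_api_error_sketch(api_code: str, message: str) -> None:
    # Unknown codes fall back to the generic APIRequestError, as in the diff
    error_class = ERROR_MAP.get(api_code, APIRequestError)
    raise error_class(message)


try:
    raise_api_error_sketch("missing_text", "text contents are required")
except APIRequestError as exc:
    print(type(exc).__name__, exc)
```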
validmind/tests/load.py
CHANGED
@@ -88,6 +88,10 @@ def list_tests(
     Returns:
         list or pandas.DataFrame: A list of all tests or a formatted table.
     """
+    # tests = {
+    #     test_id: load_test(test_id, reload=True)
+    #     for test_id in test_store.get_test_ids()
+    # }
     tests = {}
     for test_id in test_store.get_test_ids():
         try:
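For reference, `list_tests` is the public entry point this hunk touches. A typical call, assuming the defaults in its signature are usable as-is and the package is installed, would look like:

```python
import validmind as vm

# Returns either a list of test IDs or a formatted table, per the docstring
# in the hunk above.
all_tests = vm.tests.list_tests()
```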
validmind/tests/model_validation/sklearn/ClassifierPerformance.py
CHANGED
@@ -67,6 +67,7 @@ class ClassifierPerformance(Metric):
         "multiclass_classification",
         "model_performance",
     ]
+    default_params = {"average": "macro"}

     def summary(self, metric_value: dict):
         """
@@ -134,11 +135,13 @@ class ClassifierPerformance(Metric):
         if len(np.unique(y_true)) > 2:
             y_pred = self.inputs.dataset.y_pred(self.inputs.model)
             y_true = y_true.astype(y_pred.dtype)
-            roc_auc = multiclass_roc_auc_score(
+            roc_auc = multiclass_roc_auc_score(
+                y_true, y_pred, average=self.params["average"]
+            )
         else:
             y_prob = self.inputs.dataset.y_prob(self.inputs.model)
             y_true = y_true.astype(y_prob.dtype).flatten()
-            roc_auc = roc_auc_score(y_true, y_prob)
+            roc_auc = roc_auc_score(y_true, y_prob, average=self.params["average"])

         report["roc_auc"] = roc_auc

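With `default_params` in place, the averaging strategy used for the ROC AUC can now be overridden through `params`. A hedged sketch of how that might be invoked via `run_test` follows; the test ID is assumed to follow the usual `validmind.<module path>` convention, and the input IDs are placeholders for a dataset and model registered with `vm.init_dataset()` / `vm.init_model()`.

```python
import validmind as vm

# Hypothetical invocation: override the default "macro" averaging with "micro".
# "test_ds" and "clf_model" are assumed input_ids registered beforehand.
result = vm.tests.run_test(
    "validmind.model_validation.sklearn.ClassifierPerformance",
    inputs={"dataset": "test_ds", "model": "clf_model"},
    params={"average": "micro"},
)
```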
validmind/tests/run.py
CHANGED
@@ -2,6 +2,7 @@
 # See the LICENSE file in the root of this repository for details.
 # SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial

+import itertools
 from itertools import product
 from typing import Any, Dict, List, Union
 from uuid import uuid4
@@ -137,7 +138,7 @@ def _combine_figures(figure_lists: List[List[Any]], input_groups: List[Dict[str,
     title_template = "{current_title}({input_description})"

     for idx, figures in enumerate(figure_lists):
-        input_group = input_groups[idx]
+        input_group = input_groups[idx]["inputs"]
         if is_plotly_figure(figures[0].figure):
             _update_plotly_titles(figures, input_group, title_template)
         elif is_matplotlib_figure(figures[0].figure):
@@ -171,63 +172,55 @@ def _combine_unit_metrics(results: List[MetricResultWrapper]):
 def metric_comparison(
     results: List[MetricResultWrapper],
     test_id: TestID,
-
+    input_params_groups: Union[Dict[str, List[Any]], List[Dict[str, Any]]],
     output_template: str = None,
     generate_description: bool = True,
 ):
     """Build a comparison result for multiple metric results"""
     ref_id = str(uuid4())

+    # Treat param_groups and input_groups as empty lists if they are None or empty
+    input_params_groups = input_params_groups or [{}]
+
     input_group_strings = []

-    for
+    for input_params in input_params_groups:
         new_group = {}
-        for
-
-
-
-
-
-
+        for param_k, param_v in input_params["params"].items():
+            new_group[param_k] = param_v
+        for metric_k, metric_v in input_params["inputs"].items():
+            # Process values in the input group
+            if isinstance(metric_v, str):
+                new_group[metric_k] = metric_v
+            elif hasattr(metric_v, "input_id"):
+                new_group[metric_k] = metric_v.input_id
+            elif isinstance(metric_v, list) and all(
+                hasattr(item, "input_id") for item in metric_v
+            ):
+                new_group[metric_k] = ", ".join([item.input_id for item in metric_v])
             else:
-                raise ValueError(f"Unsupported type for value: {
+                raise ValueError(f"Unsupported type for value: {metric_v}")
         input_group_strings.append(new_group)

     # handle unit metrics (scalar values) by adding it to the summary
     _combine_unit_metrics(results)

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-    # Check if the results list contains a result object with figures
-    if any(hasattr(result, "figures") and result.figures for result in results):
-        # Compute merged figures only if there is at least one result with figures
-        merged_figures = _combine_figures(
-            [result.figures for result in results],
-            input_groups,
-        )
-        # Patch figure metadata so they are connected to the comparison result
-        if merged_figures and len(merged_figures):
-            for i, figure in enumerate(merged_figures):
-                figure.key = f"{figure.key}-{i}"
-                figure.metadata["_name"] = test_id
-                figure.metadata["_ref_id"] = ref_id
-    else:
-        merged_figures = None
+    merged_summary = _combine_summaries(
+        [
+            {"inputs": input_group_strings[i], "summary": result.metric.summary}
+            for i, result in enumerate(results)
+        ]
+    )
+    merged_figures = _combine_figures(
+        [result.figures for result in results], input_params_groups
+    )
+
+    # Patch figure metadata so they are connected to the comparison result
+    if merged_figures and len(merged_figures):
+        for i, figure in enumerate(merged_figures):
+            figure.key = f"{figure.key}-{i}"
+            figure.metadata["_name"] = test_id
+            figure.metadata["_ref_id"] = ref_id

     return MetricResultWrapper(
         result_id=test_id,
@@ -236,14 +229,14 @@ def metric_comparison(
             test_id=test_id,
             default_description=f"Comparison test result for {test_id}",
             summary=merged_summary.serialize() if merged_summary else None,
-            figures=merged_figures
+            figures=merged_figures,
             should_generate=generate_description,
         ),
     ],
     inputs=[
         item.input_id if hasattr(item, "input_id") else item
-        for group in
-        for input in group.values()
+        for group in input_params_groups
+        for input in group["inputs"].values()
         for item in (input if isinstance(input, list) else [input])
         if hasattr(item, "input_id") or isinstance(item, str)
     ],
@@ -333,39 +326,63 @@ def threshold_test_comparison(

 def run_comparison_test(
     test_id: TestID,
-    input_grid: Union[Dict[str, List[Any]], List[Dict[str, Any]]],
+    input_grid: Union[Dict[str, List[Any]], List[Dict[str, Any]]] = None,
+    inputs: Dict[str, Any] = None,
     name: str = None,
     unit_metrics: List[TestID] = None,
+    param_grid: Union[Dict[str, List[Any]], List[Dict[str, Any]]] = None,
     params: Dict[str, Any] = None,
     show: bool = True,
     output_template: str = None,
     generate_description: bool = True,
 ):
     """Run a comparison test"""
-    if
-
+    if input_grid:
+        if isinstance(input_grid, dict):
+            input_groups = _cartesian_product(input_grid)
+        else:
+            input_groups = input_grid
     else:
-        input_groups =
+        input_groups = list(inputs) if inputs else []

+    if param_grid:
+        if isinstance(param_grid, dict):
+            param_groups = _cartesian_product(param_grid)
+        else:
+            param_groups = param_grid
+    else:
+        param_groups = list(params) if inputs else []
+
+    input_groups = input_groups or [{}]
+    param_groups = param_groups or [{}]
+    # Use itertools.product to compute the Cartesian product
+    inputs_params_product = [
+        {
+            "inputs": item1,
+            "params": item2,
+        }  # Merge dictionaries from input_groups and param_groups
+        for item1, item2 in itertools.product(input_groups, param_groups)
+    ]
     results = [
         run_test(
             test_id,
             name=name,
             unit_metrics=unit_metrics,
-            inputs=inputs,
+            inputs=inputs_params["inputs"],
             show=False,
-            params=params,
+            params=inputs_params["params"],
             __generate_description=False,
         )
-        for
+        for inputs_params in (inputs_params_product or [{}])
     ]
-
     if isinstance(results[0], MetricResultWrapper):
         func = metric_comparison
     else:
         func = threshold_test_comparison

-    result = func(
+    result = func(
+        results, test_id, inputs_params_product, output_template, generate_description
+    )

     if show:
         result.show()
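The grid expansion added here reduces to an `itertools.product` over the input and parameter groups. A small standalone illustration of the structure `inputs_params_product` ends up holding (the group values below are placeholders):

```python
import itertools

# Placeholder groups; in run_comparison_test these come from input_grid/param_grid
input_groups = [{"dataset": "train_ds"}, {"dataset": "test_ds"}]
param_groups = [{"average": "macro"}, {"average": "micro"}]

# Mirrors inputs_params_product above: one dict per (inputs, params) combination
inputs_params_product = [
    {"inputs": inputs, "params": params}
    for inputs, params in itertools.product(input_groups, param_groups)
]
print(len(inputs_params_product))  # 4 combinations -> 4 run_test calls
```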
@@ -376,6 +393,7 @@ def run_comparison_test(
 def run_test(
     test_id: TestID = None,
     params: Dict[str, Any] = None,
+    param_grid: Union[Dict[str, List[Any]], List[Dict[str, Any]]] = None,
     inputs: Dict[str, Any] = None,
     input_grid: Union[Dict[str, List[Any]], List[Dict[str, Any]]] = None,
     name: str = None,
@@ -385,83 +403,81 @@ def run_test(
     __generate_description: bool = True,
     **kwargs,
 ) -> Union[MetricResultWrapper, ThresholdTestResultWrapper]:
-    """Run a test by test ID
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    """Run a test by test ID.
+    test_id (TestID, optional): The test ID to run. Not required if `unit_metrics` is provided.
+    params (dict, optional): A dictionary of parameters to pass into the test. Params
+        are used to customize the test behavior and are specific to each test. See the
+        test details for more information on the available parameters. Defaults to None.
+    param_grid (Union[Dict[str, List[Any]], List[Dict[str, Any]]], optional): To run
+        a comparison test, provide either a dictionary of parameters where the keys are
+        the parameter names and the values are lists of different parameters, or a list of
+        dictionaries where each dictionary is a set of parameters to run the test with.
+        This will run the test multiple times with different sets of parameters and then
+        combine the results into a single output. When passing a dictionary, the grid
+        will be created by taking the Cartesian product of the parameter lists. Its simply
+        a more convenient way of forming the param grid as opposed to passing a list of
+        all possible combinations. Defaults to None.
+    inputs (Dict[str, Any], optional): A dictionary of test inputs to pass into the
+        test. Inputs are either models or datasets that have been initialized using
+        vm.init_model() or vm.init_dataset(). Defaults to None.
+    input_grid (Union[Dict[str, List[Any]], List[Dict[str, Any]]], optional): To run
+        a comparison test, provide either a dictionary of inputs where the keys are
+        the input names and the values are lists of different inputs, or a list of
+        dictionaries where each dictionary is a set of inputs to run the test with.
+        This will run the test multiple times with different sets of inputs and then
+        combine the results into a single output. When passing a dictionary, the grid
+        will be created by taking the Cartesian product of the input lists. Its simply
+        a more convenient way of forming the input grid as opposed to passing a list of
+        all possible combinations. Defaults to None.
+    name (str, optional): The name of the test (used to create a composite metric
+        out of multiple unit metrics) - required when running multiple unit metrics
+    unit_metrics (list, optional): A list of unit metric IDs to run as a composite
+        metric - required when running multiple unit metrics
+    output_template (str, optional): A jinja2 html template to customize the output
+        of the test. Defaults to None.
+    show (bool, optional): Whether to display the results. Defaults to True.
+    **kwargs: Keyword inputs to pass into the test (same as `inputs` but as keyword
+        args instead of a dictionary):
+        - dataset: A validmind Dataset object or a Pandas DataFrame
+        - model: A model to use for the test
+        - models: A list of models to use for the test
+        - dataset: A validmind Dataset object or a Pandas DataFrame
     """
-    if not test_id and not name and not unit_metrics:
-        raise ValueError(
-            "`test_id` or `name` and `unit_metrics` must be provided to run a test"
-        )

-
-
-
-    if (input_grid and kwargs) or (input_grid and inputs):
-        raise ValueError(
-            "When providing an `input_grid`, you cannot also provide `inputs` or `kwargs`"
-        )
+    # Validate input arguments with helper functions
+    validate_test_inputs(test_id, name, unit_metrics)
+    validate_grid_inputs(input_grid, kwargs, inputs, param_grid, params)

+    # Handle composite metric creation
     if unit_metrics:
-
-        test_id = f"validmind.composite_metric.{metric_id_name}" or test_id
+        test_id = generate_composite_test_id(name, test_id)

-    if
-
+    # Run comparison tests if applicable
+    if input_grid or param_grid:
+        return run_comparison_test_with_grids(
            test_id,
+            inputs,
            input_grid,
-
-
-
-
-
-
+            param_grid,
+            name,
+            unit_metrics,
+            params,
+            output_template,
+            show,
+            __generate_description,
        )

+    # Run unit metric tests
     if test_id.startswith("validmind.unit_metrics"):
         # TODO: as we move towards a more unified approach to metrics
         # we will want to make everything functional and remove the
         # separation between unit metrics and "normal" metrics
         return run_metric(test_id, inputs=inputs, params=params, show=show)

-
-
-        unit_metrics=unit_metrics, metric_name=metric_id_name
-    )
-    if error:
-        raise LoadTestError(error)
-    else:
-        TestClass = load_test(test_id, reload=True)
+    # Load the appropriate test class
+    TestClass = load_test_class(test_id, unit_metrics, name)

+    # Create and run the test
     test = TestClass(
         test_id=test_id,
         context=TestContext(),
@@ -477,3 +493,90 @@ def run_test(
     test.result.show()

     return test.result
+
+
+def validate_test_inputs(test_id, name, unit_metrics):
+    """Validate the main test inputs for `test_id`, `name`, and `unit_metrics`."""
+    if not test_id and not (name and unit_metrics):
+        raise ValueError(
+            "`test_id` or both `name` and `unit_metrics` must be provided to run a test"
+        )
+
+    if bool(unit_metrics) != bool(name):
+        raise ValueError("`name` and `unit_metrics` must be provided together")
+
+
+def validate_grid_inputs(input_grid, kwargs, inputs, param_grid, params):
+    """Validate the grid inputs to avoid conflicting parameters."""
+    if input_grid and (kwargs or inputs):
+        raise ValueError("Cannot provide `input_grid` along with `inputs` or `kwargs`")
+
+    if param_grid and (kwargs or params):
+        raise ValueError("Cannot provide `param_grid` along with `params` or `kwargs`")
+
+
+def generate_composite_test_id(name, test_id):
+    """Generate a composite test ID if unit metrics are provided."""
+    metric_id_name = "".join(word.capitalize() for word in name.split())
+    return f"validmind.composite_metric.{metric_id_name}" or test_id
+
+
+def run_comparison_test_with_grids(
+    test_id,
+    inputs,
+    input_grid,
+    param_grid,
+    name,
+    unit_metrics,
+    params,
+    output_template,
+    show,
+    generate_description,
+):
+    """Run a comparison test based on the presence of input and param grids."""
+    if input_grid and param_grid:
+        return run_comparison_test(
+            test_id,
+            input_grid,
+            name=name,
+            unit_metrics=unit_metrics,
+            param_grid=param_grid,
+            output_template=output_template,
+            show=show,
+            generate_description=generate_description,
+        )
+    if input_grid:
+        return run_comparison_test(
+            test_id,
+            input_grid,
+            name=name,
+            unit_metrics=unit_metrics,
+            params=params,
+            output_template=output_template,
+            show=show,
+            generate_description=generate_description,
+        )
+    if param_grid:
+        return run_comparison_test(
+            test_id,
+            inputs=inputs,
+            name=name,
+            unit_metrics=unit_metrics,
+            param_grid=param_grid,
+            output_template=output_template,
+            show=show,
+            generate_description=generate_description,
+        )
+
+
+def load_test_class(test_id, unit_metrics, name):
+    """Load the appropriate test class based on `test_id` and unit metrics."""
+    if unit_metrics:
+        metric_id_name = "".join(word.capitalize() for word in name.split())
+        error, TestClass = load_composite_metric(
+            unit_metrics=unit_metrics, metric_name=metric_id_name
+        )
+        if error:
+            raise LoadTestError(error)
+        return TestClass
+    return load_test(test_id, reload=True)
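Putting the pieces together, the new `param_grid` and `input_grid` plumbing lets a single `run_test` call fan out over the Cartesian product of inputs and parameters and merge everything into one comparison result. A hedged usage sketch based on the docstring above; the test ID follows the usual `validmind.<module path>` convention and the input IDs are placeholders for registered datasets and a model.

```python
import validmind as vm

# Hypothetical comparison run: 2 datasets x 2 averaging strategies = 4 runs,
# combined into a single comparison result (show=True displays it by default).
result = vm.tests.run_test(
    "validmind.model_validation.sklearn.ClassifierPerformance",
    input_grid={
        "dataset": ["train_ds", "test_ds"],  # placeholder input IDs
        "model": ["clf_model"],
    },
    param_grid={"average": ["macro", "micro"]},
)
```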
validmind/vm_models/test/result_wrapper.py
CHANGED
@@ -378,8 +378,8 @@ class MetricResultWrapper(ResultWrapper):
             self.metric.summary = self._get_filtered_summary()

         tasks.append(
-            api_client.
-
+            api_client.log_metric_result(
+                metric=self.metric,
                 inputs=self.inputs,
                 output_template=self.output_template,
                 section_id=section_id,
@@ -388,7 +388,7 @@ class MetricResultWrapper(ResultWrapper):
         )

         if self.figures:
-            tasks.
+            tasks.extend([api_client.log_figure(figure) for figure in self.figures])

         if hasattr(self, "result_metadata") and self.result_metadata:
             description = self.result_metadata[0].get("text", "")
@@ -474,7 +474,7 @@ class ThresholdTestResultWrapper(ResultWrapper):
         ]

         if self.figures:
-            tasks.
+            tasks.extend([api_client.log_figure(figure) for figure in self.figures])

         if hasattr(self, "result_metadata") and self.result_metadata:
             description = self.result_metadata[0].get("text", "")
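Both wrappers now queue one `log_figure` call per figure through a single `tasks.extend(...)`. A trivial standalone illustration of that batching pattern; `log_figure` below is a stub, since the real `api_client.log_figure` is not shown in this diff.

```python
# Stand-in for api_client.log_figure; only illustrates the extend(...) pattern
def log_figure(figure):
    return f"queued {figure}"

figures = ["fig-1", "fig-2", "fig-3"]
tasks = []
tasks.extend([log_figure(figure) for figure in figures])  # one task per figure
print(tasks)
```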