validmind 2.8.12__py3-none-any.whl → 2.8.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- validmind/__init__.py +6 -5
- validmind/__version__.py +1 -1
- validmind/ai/test_descriptions.py +13 -9
- validmind/ai/utils.py +2 -2
- validmind/api_client.py +75 -32
- validmind/client.py +111 -100
- validmind/client_config.py +3 -3
- validmind/datasets/classification/__init__.py +7 -3
- validmind/datasets/credit_risk/lending_club.py +28 -16
- validmind/datasets/nlp/cnn_dailymail.py +10 -4
- validmind/datasets/regression/__init__.py +22 -5
- validmind/errors.py +17 -7
- validmind/input_registry.py +1 -1
- validmind/logging.py +44 -35
- validmind/models/foundation.py +2 -2
- validmind/models/function.py +10 -3
- validmind/template.py +33 -24
- validmind/test_suites/__init__.py +2 -2
- validmind/tests/_store.py +13 -4
- validmind/tests/comparison.py +65 -33
- validmind/tests/data_validation/ClassImbalance.py +3 -1
- validmind/tests/data_validation/DatasetDescription.py +2 -23
- validmind/tests/data_validation/DescriptiveStatistics.py +1 -1
- validmind/tests/data_validation/Skewness.py +7 -6
- validmind/tests/decorator.py +14 -11
- validmind/tests/load.py +38 -24
- validmind/tests/model_validation/ragas/AnswerCorrectness.py +4 -2
- validmind/tests/model_validation/ragas/ContextEntityRecall.py +4 -2
- validmind/tests/model_validation/ragas/ContextPrecision.py +4 -2
- validmind/tests/model_validation/ragas/ContextPrecisionWithoutReference.py +4 -2
- validmind/tests/model_validation/ragas/ContextRecall.py +4 -2
- validmind/tests/model_validation/ragas/Faithfulness.py +4 -2
- validmind/tests/model_validation/ragas/ResponseRelevancy.py +4 -2
- validmind/tests/model_validation/ragas/SemanticSimilarity.py +4 -2
- validmind/tests/model_validation/sklearn/ClassifierThresholdOptimization.py +13 -3
- validmind/tests/model_validation/sklearn/OverfitDiagnosis.py +3 -1
- validmind/tests/model_validation/sklearn/SHAPGlobalImportance.py +28 -25
- validmind/tests/model_validation/sklearn/WeakspotsDiagnosis.py +15 -10
- validmind/tests/output.py +66 -11
- validmind/tests/run.py +28 -14
- validmind/tests/test_providers.py +28 -35
- validmind/tests/utils.py +17 -4
- validmind/unit_metrics/__init__.py +1 -1
- validmind/utils.py +295 -31
- validmind/vm_models/dataset/dataset.py +83 -43
- validmind/vm_models/dataset/utils.py +5 -3
- validmind/vm_models/figure.py +6 -6
- validmind/vm_models/input.py +6 -5
- validmind/vm_models/model.py +5 -5
- validmind/vm_models/result/result.py +122 -43
- validmind/vm_models/result/utils.py +5 -5
- validmind/vm_models/test_suite/__init__.py +5 -0
- validmind/vm_models/test_suite/runner.py +5 -5
- validmind/vm_models/test_suite/summary.py +20 -2
- validmind/vm_models/test_suite/test.py +6 -6
- validmind/vm_models/test_suite/test_suite.py +10 -10
- {validmind-2.8.12.dist-info → validmind-2.8.22.dist-info}/METADATA +3 -4
- {validmind-2.8.12.dist-info → validmind-2.8.22.dist-info}/RECORD +61 -60
- {validmind-2.8.12.dist-info → validmind-2.8.22.dist-info}/WHEEL +1 -1
- {validmind-2.8.12.dist-info → validmind-2.8.22.dist-info}/LICENSE +0 -0
- {validmind-2.8.12.dist-info → validmind-2.8.22.dist-info}/entry_points.txt +0 -0
validmind/tests/model_validation/sklearn/WeakspotsDiagnosis.py
CHANGED
@@ -47,7 +47,7 @@ def _compute_metrics(
         None: The computed metrics are appended to the `results` dictionary in-place.
     """
     results["Slice"].append(str(region))
-    results["
+    results["Number of Records"].append(df_region.shape[0])
     results["Feature"].append(feature_column)

     # Check if df_region is an empty dataframe and if so, append 0 to all metrics
@@ -222,7 +222,7 @@ def WeakspotsDiagnosis(
     thresholds = thresholds or DEFAULT_THRESHOLDS
     thresholds = {k.title(): v for k, v in thresholds.items()}

-    results_headers = ["Slice", "
+    results_headers = ["Slice", "Number of Records", "Feature"]
     results_headers.extend(metrics.keys())

     figures = []
@@ -236,19 +236,20 @@ def WeakspotsDiagnosis(
         feature_columns
         + [datasets[1].target_column, datasets[1].prediction_column(model)]
     ]
-
+    results_1 = pd.DataFrame()
+    results_2 = pd.DataFrame()
     for feature in feature_columns:
         bins = 10
         if feature in datasets[0].feature_columns_categorical:
             bins = len(df_1[feature].unique())
         df_1["bin"] = pd.cut(df_1[feature], bins=bins)

-
-
+        r1 = {k: [] for k in results_headers}
+        r2 = {k: [] for k in results_headers}

         for region, df_region in df_1.groupby("bin"):
             _compute_metrics(
-                results=
+                results=r1,
                 metrics=metrics,
                 region=region,
                 df_region=df_region,
@@ -260,7 +261,7 @@ def WeakspotsDiagnosis(
                 (df_2[feature] > region.left) & (df_2[feature] <= region.right)
             ]
             _compute_metrics(
-                results=
+                results=r2,
                 metrics=metrics,
                 region=region,
                 df_region=df_2_region,
@@ -271,8 +272,8 @@ def WeakspotsDiagnosis(

         for metric in metrics.keys():
             fig, df = _plot_weak_spots(
-                results_1=
-                results_2=
+                results_1=r1,
+                results_2=r2,
                 feature_column=feature,
                 metric=metric,
                 threshold=thresholds[metric],
@@ -284,6 +285,8 @@ def WeakspotsDiagnosis(
             # rely on visual assessment for this test for now.
             if not df[df[list(thresholds.keys())].lt(thresholds).any(axis=1)].empty:
                 passed = False
+        results_1 = pd.concat([results_1, pd.DataFrame(r1)])
+        results_2 = pd.concat([results_2, pd.DataFrame(r2)])

     return (
         pd.concat(
@@ -291,7 +294,9 @@ def WeakspotsDiagnosis(
                 pd.DataFrame(results_1).assign(Dataset=datasets[0].input_id),
                 pd.DataFrame(results_2).assign(Dataset=datasets[1].input_id),
             ]
-        )
+        )
+        .reset_index(drop=True)
+        .sort_values(["Feature", "Dataset"]),
         *figures,
         passed,
     )
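To see the effect of these changes from the user side, here is a minimal, hypothetical sketch of running the updated test through `run_test`; the `datasets` and `model` input names come from the test signature above, while `vm_train_ds`, `vm_test_ds`, and `vm_model` are placeholders for inputs created with `vm.init_dataset` and `vm.init_model`. The combined results table now uses the "Number of Records" column and is sorted by Feature and Dataset.

import validmind as vm

# Placeholders: vm_train_ds, vm_test_ds and vm_model would normally come from
# vm.init_dataset(...) and vm.init_model(...)
result = vm.tests.run_test(
    "validmind.model_validation.sklearn.WeakspotsDiagnosis",
    inputs={"datasets": [vm_train_ds, vm_test_ds], "model": vm_model},
)
result.show()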
validmind/tests/output.py
CHANGED
@@ -9,6 +9,7 @@ from uuid import uuid4
 import numpy as np
 import pandas as pd

+from validmind.utils import is_html, md_to_html
 from validmind.vm_models.figure import (
     Figure,
     is_matplotlib_figure,
@@ -77,30 +78,72 @@ class FigureOutputHandler(OutputHandler):

 class TableOutputHandler(OutputHandler):
     def can_handle(self, item: Any) -> bool:
-        return isinstance(item, (list, pd.DataFrame, dict, ResultTable))
+        return isinstance(item, (list, pd.DataFrame, dict, ResultTable, tuple))
+
+    def _convert_simple_type(self, data: Any) -> pd.DataFrame:
+        """Convert a simple data type to a DataFrame."""
+        if isinstance(data, dict):
+            return pd.DataFrame([data])
+        elif data is None:
+            return pd.DataFrame()
+        else:
+            raise ValueError(f"Cannot convert {type(data)} to DataFrame")
+
+    def _convert_list(self, data_list: List) -> pd.DataFrame:
+        """Convert a list to a DataFrame."""
+        if not data_list:
+            return pd.DataFrame()
+
+        try:
+            return pd.DataFrame(data_list)
+        except Exception as e:
+            # If conversion fails, try to handle common cases
+            if all(
+                isinstance(item, (int, float, str, bool, type(None)))
+                for item in data_list
+            ):
+                return pd.DataFrame({"Values": data_list})
+            else:
+                raise ValueError(f"Could not convert list to DataFrame: {e}")
+
+    def _convert_to_dataframe(self, table_data: Any) -> pd.DataFrame:
+        """Convert various data types to a pandas DataFrame."""
+        # Handle special cases by type
+        if isinstance(table_data, pd.DataFrame):
+            return table_data
+        elif isinstance(table_data, (dict, str, type(None))):
+            return self._convert_simple_type(table_data)
+        elif isinstance(table_data, tuple):
+            return self._convert_list(list(table_data))
+        elif isinstance(table_data, list):
+            return self._convert_list(table_data)
+        else:
+            # If we reach here, we don't know how to handle this type
+            raise ValueError(
+                f"Invalid table format: must be a list of dictionaries or a DataFrame, got {type(table_data)}"
+            )

     def process(
         self,
-        item: Union[
+        item: Union[
+            List[Dict[str, Any]], pd.DataFrame, Dict[str, Any], ResultTable, str, tuple
+        ],
         result: TestResult,
     ) -> None:
+        # Convert to a dictionary of tables if not already
         tables = item if isinstance(item, dict) else {"": item}

         for table_name, table_data in tables.items():
-            #
+            # If already a ResultTable, add it directly
             if isinstance(table_data, ResultTable):
                 result.add_table(table_data)
                 continue

-
-
-                "Invalid table format: must be a list of dictionaries or a DataFrame"
-            )
-
-            if isinstance(table_data, list):
-                table_data = pd.DataFrame(table_data)
+            # Convert the data to a DataFrame using our helper method
+            df = self._convert_to_dataframe(table_data)

-
+            # Add the resulting DataFrame as a table to the resul
+            result.add_table(ResultTable(data=df, title=table_name or None))


 class RawDataOutputHandler(OutputHandler):
@@ -111,6 +154,17 @@ class RawDataOutputHandler(OutputHandler):
         result.raw_data = item


+class StringOutputHandler(OutputHandler):
+    def can_handle(self, item: Any) -> bool:
+        return isinstance(item, str)
+
+    def process(self, item: Any, result: TestResult) -> None:
+        if not is_html(item):
+            item = md_to_html(item, mathml=True)
+
+        result.description = item
+
+
 def process_output(item: Any, result: TestResult) -> None:
     """Process a single test output item and update the TestResult."""
     handlers = [
@@ -119,6 +173,7 @@ def process_output(item: Any, result: TestResult) -> None:
         FigureOutputHandler(),
         TableOutputHandler(),
         RawDataOutputHandler(),
+        StringOutputHandler(),
     ]

     for handler in handlers:
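Taken together, the new `_convert_to_dataframe` helpers and the `StringOutputHandler` mean a test can now return plain lists, tuples, or a Markdown string in addition to DataFrames and `ResultTable` objects. A minimal sketch, assuming the `@vm.test` decorator API for custom tests (the test ID `my_tests.OutputHandlerDemo` is hypothetical):

import validmind as vm


@vm.test("my_tests.OutputHandlerDemo")  # hypothetical custom test ID
def output_handler_demo():
    # list of dicts -> handled by TableOutputHandler as a regular table
    rows = [{"Fold": 1, "Accuracy": 0.91}, {"Fold": 2, "Accuracy": 0.87}]
    # Markdown string -> converted to HTML by StringOutputHandler and stored
    # as the result description
    note = "**Both folds** scored above the 0.85 threshold."
    return rows, note


result = vm.tests.run_test("my_tests.OutputHandlerDemo")
result.show()

Because the handler only sets `result.description` when the test supplies a string, the generated description path in `run_test` is skipped for such results (see the `if not result.description:` branch in run.py below).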
validmind/tests/run.py
CHANGED
@@ -76,7 +76,7 @@ def _get_run_metadata(**metadata: Dict[str, Any]) -> Dict[str, Any]:

 def _get_test_kwargs(
     test_func: callable, inputs: Dict[str, Any], params: Dict[str, Any]
-):
+) -> Tuple[Dict[str, Any], Dict[str, Any]]:
     """Insepect function signature to build kwargs to pass the inputs and params
     that the test function expects

@@ -93,7 +93,7 @@ def _get_test_kwargs(
         params (dict): Test parameters e.g. {"param1": 1, "param2": 2}

     Returns:
-
+        Tuple[Dict[str, Any], Dict[str, Any]]: Tuple of input and param kwargs
     """
     input_kwargs = {}  # map function inputs (`dataset` etc) to actual objects

@@ -222,6 +222,7 @@ def _run_comparison_test(
     params: Union[Dict[str, Any], None],
     param_grid: Union[Dict[str, List[Any]], List[Dict[str, Any]], None],
     title: Optional[str] = None,
+    show_params: bool = True,
 ):
     """Run a comparison test i.e. a test that compares multiple outputs of a test across
     different input and/or param combinations"""
@@ -242,6 +243,7 @@ def _run_comparison_test(
             show=False,
             generate_description=False,
             title=title,
+            show_params=show_params,
         )
         for config in run_test_configs
     ]
@@ -253,7 +255,9 @@ def _run_comparison_test(
     else:
         test_doc = describe_test(test_id, raw=True)["Description"]

-    combined_outputs, combined_inputs, combined_params = combine_results(
+    combined_outputs, combined_inputs, combined_params = combine_results(
+        results, show_params
+    )

     return build_test_result(
         outputs=combined_outputs,
@@ -265,7 +269,12 @@ def _run_comparison_test(
     )


-def _run_test(
+def _run_test(
+    test_id: TestID,
+    inputs: Dict[str, Any],
+    params: Dict[str, Any],
+    title: Optional[str] = None,
+):
     """Run a standard test and return a TestResult object"""
     test_func = load_test(test_id)
     input_kwargs, param_kwargs = _get_test_kwargs(
@@ -282,6 +291,7 @@ def _run_test(test_id: TestID, inputs: Dict[str, Any], params: Dict[str, Any]):
         test_doc=getdoc(test_func),
         inputs=input_kwargs,
         params=param_kwargs,
+        title=title,
     )


@@ -297,6 +307,7 @@ def run_test(  # noqa: C901
     generate_description: bool = True,
     title: Optional[str] = None,
     post_process_fn: Union[Callable[[TestResult], None], None] = None,
+    show_params: bool = True,
     **kwargs,
 ) -> TestResult:
     """Run a ValidMind or custom test
@@ -321,6 +332,7 @@ def run_test(  # noqa: C901
         generate_description (bool, optional): Whether to generate a description. Defaults to True.
         title (str, optional): Custom title for the test result
         post_process_fn (Callable[[TestResult], None], optional): Function to post-process the test result
+        show_params (bool, optional): Whether to include parameter values in figure titles for comparison tests. Defaults to True.

     Returns:
         TestResult: A TestResult object containing the test results
@@ -358,6 +370,7 @@ def run_test(  # noqa: C901
             input_grid=input_grid,
             params=params,
             param_grid=param_grid,
+            show_params=show_params,
         )

     elif unit_metrics:
@@ -375,7 +388,7 @@ def run_test(  # noqa: C901
         )

     else:
-        result = _run_test(test_id, inputs, params)
+        result = _run_test(test_id, inputs, params, title)

     end_time = time.perf_counter()
     result.metadata = _get_run_metadata(duration_seconds=end_time - start_time)
@@ -383,15 +396,16 @@ def run_test(  # noqa: C901
     if post_process_fn:
         result = post_process_fn(result)

-    result.description
-
-
-
-
-
-
-
-
+    if not result.description:
+        result.description = get_result_description(
+            test_id=test_id,
+            test_description=result.doc,
+            tables=result.tables,
+            figures=result.figures,
+            metric=result.metric,
+            should_generate=generate_description,
+            title=title,
+        )

     if show:
         result.show()
validmind/tests/test_providers.py
CHANGED
@@ -7,7 +7,7 @@ import os
 import re
 import sys
 from pathlib import Path
-from typing import List, Protocol
+from typing import Any, Callable, List, Protocol

 from validmind.logging import get_logger

@@ -95,45 +95,38 @@ class LocalTestProvider:
         """
         self.root_folder = os.path.abspath(root_folder)

-    def list_tests(self):
+    def list_tests(self) -> List[str]:
         """List all tests in the given namespace

         Returns:
             list: A list of test IDs
         """
-
-
+        test_files = []
         for root, _, files in os.walk(self.root_folder):
-            for
-                if not
-                    continue
-
-                path = Path(root) / filename
-                if not _is_test_file(path):
+            for file in files:
+                if not file.endswith(".py"):
                     continue

-
-
-
-
-
+                path = Path(os.path.join(root, file))
+                if _is_test_file(path):
+                    rel_path = os.path.relpath(path, self.root_folder)
+                    test_id = os.path.splitext(rel_path)[0].replace(os.sep, ".")
+                    test_files.append(test_id)

-        return
+        return test_files

-    def load_test(self, test_id: str):
-        """
-        Load the test identified by the given test_id.
+    def load_test(self, test_id: str) -> Callable[..., Any]:
+        """Load the test function identified by the given test_id

         Args:
-            test_id (str): The
-
+            test_id (str): The test ID (does not contain the namespace under which
+                the test is registered)

         Returns:
-            The test
+            callable: The test function

         Raises:
-
-            LocalTestProviderLoadTestError: If the test class cannot be found in the module
+            FileNotFoundError: If the test is not found
         """
         # Convert test_id to file path
         file_path = os.path.join(self.root_folder, f"{test_id.replace('.', '/')}.py")
@@ -162,28 +155,28 @@ class LocalTestProvider:


 class ValidMindTestProvider:
-    """
+    """Provider for built-in ValidMind tests"""

-    def __init__(self):
+    def __init__(self) -> None:
         # two subproviders: unit_metrics and normal tests
-        self.
+        self.unit_metrics_provider = LocalTestProvider(
             os.path.join(os.path.dirname(__file__), "..", "unit_metrics")
         )
-        self.
+        self.test_provider = LocalTestProvider(os.path.dirname(__file__))

     def list_tests(self) -> List[str]:
-        """List all tests in the
+        """List all tests in the given namespace"""
         metric_ids = [
-            f"unit_metrics.{test}" for test in self.
+            f"unit_metrics.{test}" for test in self.unit_metrics_provider.list_tests()
         ]
-        test_ids = self.
+        test_ids = self.test_provider.list_tests()

         return metric_ids + test_ids

-    def load_test(self, test_id: str) ->
-        """Load
+    def load_test(self, test_id: str) -> Callable[..., Any]:
+        """Load the test function identified by the given test_id"""
         return (
-            self.
+            self.unit_metrics_provider.load_test(test_id.replace("unit_metrics.", ""))
             if test_id.startswith("unit_metrics.")
-            else self.
+            else self.test_provider.load_test(test_id)
         )
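For reference, a minimal sketch of the tightened `LocalTestProvider` interface; `./my_tests` is a hypothetical folder of custom test modules. `list_tests()` now returns dotted test IDs relative to the root folder, and `load_test()` returns the test function itself, raising `FileNotFoundError` if the module is missing.

from validmind.tests.test_providers import LocalTestProvider

provider = LocalTestProvider("./my_tests")          # hypothetical folder of custom tests
print(provider.list_tests())                        # e.g. ["my_custom_test", "subfolder.another_test"]

test_func = provider.load_test("my_custom_test")    # callable; FileNotFoundError if not found

In practice the provider is registered under a namespace so that `run_test` can resolve IDs like `my_tests.my_custom_test`.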
validmind/tests/utils.py
CHANGED
@@ -5,6 +5,7 @@
 """Test Module Utils"""

 import inspect
+from typing import Any, Optional, Tuple, Type, Union

 import numpy as np
 import pandas as pd
@@ -14,7 +15,7 @@ from validmind.logging import get_logger
 logger = get_logger(__name__)


-def test_description(test_class, truncate=True):
+def test_description(test_class: Type[Any], truncate: bool = True) -> str:
     description = inspect.getdoc(test_class).strip()

     if truncate and len(description.split("\n")) > 5:
@@ -23,7 +24,11 @@ def test_description(test_class, truncate=True):
     return description


-def remove_nan_pairs(
+def remove_nan_pairs(
+    y_true: Union[np.ndarray, list],
+    y_pred: Union[np.ndarray, list],
+    dataset_id: Optional[str] = None,
+) -> Tuple[np.ndarray, np.ndarray]:
     """
     Remove pairs where either true or predicted values are NaN/None.
     Args:
@@ -52,7 +57,11 @@ def remove_nan_pairs(y_true, y_pred, dataset_id=None):
     return y_true, y_pred


-def ensure_equal_lengths(
+def ensure_equal_lengths(
+    y_true: Union[np.ndarray, list],
+    y_pred: Union[np.ndarray, list],
+    dataset_id: Optional[str] = None,
+) -> Tuple[np.ndarray, np.ndarray]:
     """
     Check if true and predicted values have matching lengths, log warning if they don't,
     and truncate to the shorter length if necessary. Also removes any NaN/None values.
@@ -82,7 +91,11 @@ def ensure_equal_lengths(y_true, y_pred, dataset_id=None):
     return y_true, y_pred


-def validate_prediction(
+def validate_prediction(
+    y_true: Union[np.ndarray, list],
+    y_pred: Union[np.ndarray, list],
+    dataset_id: Optional[str] = None,
+) -> Tuple[np.ndarray, np.ndarray]:
     """
     Comprehensive validation of true and predicted value pairs.
     Handles NaN/None values and length mismatches.
validmind/unit_metrics/__init__.py
CHANGED
@@ -10,7 +10,7 @@ from validmind.tests.run import run_test
 def list_metrics(**kwargs):
     """List all metrics"""
     vm_provider = test_provider_store.get_test_provider("validmind")
-    vm_metrics_provider = vm_provider.
+    vm_metrics_provider = vm_provider.unit_metrics_provider

     prefix = "validmind.unit_metrics."
