validmind 2.8.12__py3-none-any.whl → 2.8.22__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (61)
  1. validmind/__init__.py +6 -5
  2. validmind/__version__.py +1 -1
  3. validmind/ai/test_descriptions.py +13 -9
  4. validmind/ai/utils.py +2 -2
  5. validmind/api_client.py +75 -32
  6. validmind/client.py +111 -100
  7. validmind/client_config.py +3 -3
  8. validmind/datasets/classification/__init__.py +7 -3
  9. validmind/datasets/credit_risk/lending_club.py +28 -16
  10. validmind/datasets/nlp/cnn_dailymail.py +10 -4
  11. validmind/datasets/regression/__init__.py +22 -5
  12. validmind/errors.py +17 -7
  13. validmind/input_registry.py +1 -1
  14. validmind/logging.py +44 -35
  15. validmind/models/foundation.py +2 -2
  16. validmind/models/function.py +10 -3
  17. validmind/template.py +33 -24
  18. validmind/test_suites/__init__.py +2 -2
  19. validmind/tests/_store.py +13 -4
  20. validmind/tests/comparison.py +65 -33
  21. validmind/tests/data_validation/ClassImbalance.py +3 -1
  22. validmind/tests/data_validation/DatasetDescription.py +2 -23
  23. validmind/tests/data_validation/DescriptiveStatistics.py +1 -1
  24. validmind/tests/data_validation/Skewness.py +7 -6
  25. validmind/tests/decorator.py +14 -11
  26. validmind/tests/load.py +38 -24
  27. validmind/tests/model_validation/ragas/AnswerCorrectness.py +4 -2
  28. validmind/tests/model_validation/ragas/ContextEntityRecall.py +4 -2
  29. validmind/tests/model_validation/ragas/ContextPrecision.py +4 -2
  30. validmind/tests/model_validation/ragas/ContextPrecisionWithoutReference.py +4 -2
  31. validmind/tests/model_validation/ragas/ContextRecall.py +4 -2
  32. validmind/tests/model_validation/ragas/Faithfulness.py +4 -2
  33. validmind/tests/model_validation/ragas/ResponseRelevancy.py +4 -2
  34. validmind/tests/model_validation/ragas/SemanticSimilarity.py +4 -2
  35. validmind/tests/model_validation/sklearn/ClassifierThresholdOptimization.py +13 -3
  36. validmind/tests/model_validation/sklearn/OverfitDiagnosis.py +3 -1
  37. validmind/tests/model_validation/sklearn/SHAPGlobalImportance.py +28 -25
  38. validmind/tests/model_validation/sklearn/WeakspotsDiagnosis.py +15 -10
  39. validmind/tests/output.py +66 -11
  40. validmind/tests/run.py +28 -14
  41. validmind/tests/test_providers.py +28 -35
  42. validmind/tests/utils.py +17 -4
  43. validmind/unit_metrics/__init__.py +1 -1
  44. validmind/utils.py +295 -31
  45. validmind/vm_models/dataset/dataset.py +83 -43
  46. validmind/vm_models/dataset/utils.py +5 -3
  47. validmind/vm_models/figure.py +6 -6
  48. validmind/vm_models/input.py +6 -5
  49. validmind/vm_models/model.py +5 -5
  50. validmind/vm_models/result/result.py +122 -43
  51. validmind/vm_models/result/utils.py +5 -5
  52. validmind/vm_models/test_suite/__init__.py +5 -0
  53. validmind/vm_models/test_suite/runner.py +5 -5
  54. validmind/vm_models/test_suite/summary.py +20 -2
  55. validmind/vm_models/test_suite/test.py +6 -6
  56. validmind/vm_models/test_suite/test_suite.py +10 -10
  57. {validmind-2.8.12.dist-info → validmind-2.8.22.dist-info}/METADATA +3 -4
  58. {validmind-2.8.12.dist-info → validmind-2.8.22.dist-info}/RECORD +61 -60
  59. {validmind-2.8.12.dist-info → validmind-2.8.22.dist-info}/WHEEL +1 -1
  60. {validmind-2.8.12.dist-info → validmind-2.8.22.dist-info}/LICENSE +0 -0
  61. {validmind-2.8.12.dist-info → validmind-2.8.22.dist-info}/entry_points.txt +0 -0
validmind/tests/model_validation/sklearn/WeakspotsDiagnosis.py CHANGED
@@ -47,7 +47,7 @@ def _compute_metrics(
         None: The computed metrics are appended to the `results` dictionary in-place.
     """
     results["Slice"].append(str(region))
-    results["Shape"].append(df_region.shape[0])
+    results["Number of Records"].append(df_region.shape[0])
     results["Feature"].append(feature_column)

     # Check if df_region is an empty dataframe and if so, append 0 to all metrics
@@ -222,7 +222,7 @@ def WeakspotsDiagnosis(
     thresholds = thresholds or DEFAULT_THRESHOLDS
     thresholds = {k.title(): v for k, v in thresholds.items()}

-    results_headers = ["Slice", "Shape", "Feature"]
+    results_headers = ["Slice", "Number of Records", "Feature"]
     results_headers.extend(metrics.keys())

     figures = []
@@ -236,19 +236,20 @@ def WeakspotsDiagnosis(
         feature_columns
         + [datasets[1].target_column, datasets[1].prediction_column(model)]
     ]
-
+    results_1 = pd.DataFrame()
+    results_2 = pd.DataFrame()
     for feature in feature_columns:
         bins = 10
         if feature in datasets[0].feature_columns_categorical:
             bins = len(df_1[feature].unique())
         df_1["bin"] = pd.cut(df_1[feature], bins=bins)

-        results_1 = {k: [] for k in results_headers}
-        results_2 = {k: [] for k in results_headers}
+        r1 = {k: [] for k in results_headers}
+        r2 = {k: [] for k in results_headers}

         for region, df_region in df_1.groupby("bin"):
             _compute_metrics(
-                results=results_1,
+                results=r1,
                 metrics=metrics,
                 region=region,
                 df_region=df_region,
@@ -260,7 +261,7 @@ def WeakspotsDiagnosis(
                 (df_2[feature] > region.left) & (df_2[feature] <= region.right)
             ]
             _compute_metrics(
-                results=results_2,
+                results=r2,
                 metrics=metrics,
                 region=region,
                 df_region=df_2_region,
@@ -271,8 +272,8 @@ def WeakspotsDiagnosis(

         for metric in metrics.keys():
             fig, df = _plot_weak_spots(
-                results_1=results_1,
-                results_2=results_2,
+                results_1=r1,
+                results_2=r2,
                 feature_column=feature,
                 metric=metric,
                 threshold=thresholds[metric],
@@ -284,6 +285,8 @@ def WeakspotsDiagnosis(
             # rely on visual assessment for this test for now.
             if not df[df[list(thresholds.keys())].lt(thresholds).any(axis=1)].empty:
                 passed = False
+        results_1 = pd.concat([results_1, pd.DataFrame(r1)])
+        results_2 = pd.concat([results_2, pd.DataFrame(r2)])

     return (
         pd.concat(
@@ -291,7 +294,9 @@ def WeakspotsDiagnosis(
                 pd.DataFrame(results_1).assign(Dataset=datasets[0].input_id),
                 pd.DataFrame(results_2).assign(Dataset=datasets[1].input_id),
             ]
-        ).sort_values(["Feature", "Dataset"]),
+        )
+        .reset_index(drop=True)
+        .sort_values(["Feature", "Dataset"]),
         *figures,
         passed,
     )
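The refactor above accumulates the per-feature slice metrics into results_1/results_2 DataFrames and renames the per-slice count column to "Number of Records". A minimal usage sketch (a hedged illustration, not from this diff: vm_train_ds, vm_test_ds, and vm_model stand for inputs already created with vm.init_dataset / vm.init_model):

import validmind as vm

result = vm.tests.run_test(
    "validmind.model_validation.sklearn.WeakspotsDiagnosis",
    inputs={"datasets": [vm_train_ds, vm_test_ds], "model": vm_model},
)
result.show()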
validmind/tests/output.py CHANGED
@@ -9,6 +9,7 @@ from uuid import uuid4
 import numpy as np
 import pandas as pd

+from validmind.utils import is_html, md_to_html
 from validmind.vm_models.figure import (
     Figure,
     is_matplotlib_figure,
@@ -77,30 +78,72 @@ class FigureOutputHandler(OutputHandler):

 class TableOutputHandler(OutputHandler):
     def can_handle(self, item: Any) -> bool:
-        return isinstance(item, (list, pd.DataFrame, dict, ResultTable))
+        return isinstance(item, (list, pd.DataFrame, dict, ResultTable, tuple))
+
+    def _convert_simple_type(self, data: Any) -> pd.DataFrame:
+        """Convert a simple data type to a DataFrame."""
+        if isinstance(data, dict):
+            return pd.DataFrame([data])
+        elif data is None:
+            return pd.DataFrame()
+        else:
+            raise ValueError(f"Cannot convert {type(data)} to DataFrame")
+
+    def _convert_list(self, data_list: List) -> pd.DataFrame:
+        """Convert a list to a DataFrame."""
+        if not data_list:
+            return pd.DataFrame()
+
+        try:
+            return pd.DataFrame(data_list)
+        except Exception as e:
+            # If conversion fails, try to handle common cases
+            if all(
+                isinstance(item, (int, float, str, bool, type(None)))
+                for item in data_list
+            ):
+                return pd.DataFrame({"Values": data_list})
+            else:
+                raise ValueError(f"Could not convert list to DataFrame: {e}")
+
+    def _convert_to_dataframe(self, table_data: Any) -> pd.DataFrame:
+        """Convert various data types to a pandas DataFrame."""
+        # Handle special cases by type
+        if isinstance(table_data, pd.DataFrame):
+            return table_data
+        elif isinstance(table_data, (dict, str, type(None))):
+            return self._convert_simple_type(table_data)
+        elif isinstance(table_data, tuple):
+            return self._convert_list(list(table_data))
+        elif isinstance(table_data, list):
+            return self._convert_list(table_data)
+        else:
+            # If we reach here, we don't know how to handle this type
+            raise ValueError(
+                f"Invalid table format: must be a list of dictionaries or a DataFrame, got {type(table_data)}"
+            )

     def process(
         self,
-        item: Union[List[Dict[str, Any]], pd.DataFrame, Dict[str, Any], ResultTable],
+        item: Union[
+            List[Dict[str, Any]], pd.DataFrame, Dict[str, Any], ResultTable, str, tuple
+        ],
         result: TestResult,
     ) -> None:
+        # Convert to a dictionary of tables if not already
         tables = item if isinstance(item, dict) else {"": item}

         for table_name, table_data in tables.items():
-            # if already a ResultTable, add it directly
+            # If already a ResultTable, add it directly
             if isinstance(table_data, ResultTable):
                 result.add_table(table_data)
                 continue

-            if not isinstance(table_data, (list, pd.DataFrame)):
-                raise ValueError(
-                    "Invalid table format: must be a list of dictionaries or a DataFrame"
-                )
-
-            if isinstance(table_data, list):
-                table_data = pd.DataFrame(table_data)
+            # Convert the data to a DataFrame using our helper method
+            df = self._convert_to_dataframe(table_data)

-            result.add_table(ResultTable(data=table_data, title=table_name or None))
+            # Add the resulting DataFrame as a table to the resul
+            result.add_table(ResultTable(data=df, title=table_name or None))


 class RawDataOutputHandler(OutputHandler):
@@ -111,6 +154,17 @@ class RawDataOutputHandler(OutputHandler):
         result.raw_data = item


+class StringOutputHandler(OutputHandler):
+    def can_handle(self, item: Any) -> bool:
+        return isinstance(item, str)
+
+    def process(self, item: Any, result: TestResult) -> None:
+        if not is_html(item):
+            item = md_to_html(item, mathml=True)
+
+        result.description = item
+
+
 def process_output(item: Any, result: TestResult) -> None:
     """Process a single test output item and update the TestResult."""
     handlers = [
@@ -119,6 +173,7 @@ def process_output(item: Any, result: TestResult) -> None:
         FigureOutputHandler(),
         TableOutputHandler(),
         RawDataOutputHandler(),
+        StringOutputHandler(),
     ]

     for handler in handlers:
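Taken together, these output.py changes let a test return dicts of tables, plain lists or tuples, and markdown strings (which become the result description). A hedged sketch of a custom test exercising the new handlers, assuming the @vm.test decorator registration; the test ID "my_tests.DataOverview" is a placeholder:

import validmind as vm

@vm.test("my_tests.DataOverview")
def DataOverview(dataset):
    df = dataset.df
    tables = {
        "Summary Statistics": df.describe(),  # DataFrame passed through unchanged
        "Column Names": list(df.columns),     # simple list becomes a single "Values" column
    }
    # A markdown string is routed to StringOutputHandler and stored as the
    # result description (converted to HTML via md_to_html when needed).
    narrative = f"The dataset contains **{len(df)}** rows and {len(df.columns)} columns."
    return tables, narrative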
validmind/tests/run.py CHANGED
@@ -76,7 +76,7 @@ def _get_run_metadata(**metadata: Dict[str, Any]) -> Dict[str, Any]:

 def _get_test_kwargs(
     test_func: callable, inputs: Dict[str, Any], params: Dict[str, Any]
-):
+) -> Tuple[Dict[str, Any], Dict[str, Any]]:
     """Insepect function signature to build kwargs to pass the inputs and params
     that the test function expects

@@ -93,7 +93,7 @@ def _get_test_kwargs(
         params (dict): Test parameters e.g. {"param1": 1, "param2": 2}

     Returns:
-        tuple: Tuple of input and param kwargs
+        Tuple[Dict[str, Any], Dict[str, Any]]: Tuple of input and param kwargs
     """
     input_kwargs = {}  # map function inputs (`dataset` etc) to actual objects

@@ -222,6 +222,7 @@
     params: Union[Dict[str, Any], None],
     param_grid: Union[Dict[str, List[Any]], List[Dict[str, Any]], None],
     title: Optional[str] = None,
+    show_params: bool = True,
 ):
     """Run a comparison test i.e. a test that compares multiple outputs of a test across
     different input and/or param combinations"""
@@ -242,6 +243,7 @@
             show=False,
             generate_description=False,
             title=title,
+            show_params=show_params,
         )
         for config in run_test_configs
     ]
@@ -253,7 +255,9 @@
     else:
         test_doc = describe_test(test_id, raw=True)["Description"]

-    combined_outputs, combined_inputs, combined_params = combine_results(results)
+    combined_outputs, combined_inputs, combined_params = combine_results(
+        results, show_params
+    )

     return build_test_result(
         outputs=combined_outputs,
@@ -265,7 +269,12 @@
     )


-def _run_test(test_id: TestID, inputs: Dict[str, Any], params: Dict[str, Any]):
+def _run_test(
+    test_id: TestID,
+    inputs: Dict[str, Any],
+    params: Dict[str, Any],
+    title: Optional[str] = None,
+):
     """Run a standard test and return a TestResult object"""
     test_func = load_test(test_id)
     input_kwargs, param_kwargs = _get_test_kwargs(
@@ -282,6 +291,7 @@ def _run_test(test_id: TestID, inputs: Dict[str, Any], params: Dict[str, Any]):
         test_doc=getdoc(test_func),
         inputs=input_kwargs,
         params=param_kwargs,
+        title=title,
     )


@@ -297,6 +307,7 @@ def run_test( # noqa: C901
     generate_description: bool = True,
     title: Optional[str] = None,
     post_process_fn: Union[Callable[[TestResult], None], None] = None,
+    show_params: bool = True,
     **kwargs,
 ) -> TestResult:
     """Run a ValidMind or custom test
@@ -321,6 +332,7 @@
         generate_description (bool, optional): Whether to generate a description. Defaults to True.
         title (str, optional): Custom title for the test result
         post_process_fn (Callable[[TestResult], None], optional): Function to post-process the test result
+        show_params (bool, optional): Whether to include parameter values in figure titles for comparison tests. Defaults to True.

     Returns:
         TestResult: A TestResult object containing the test results
@@ -358,6 +370,7 @@
             input_grid=input_grid,
             params=params,
             param_grid=param_grid,
+            show_params=show_params,
         )

     elif unit_metrics:
@@ -375,7 +388,7 @@
         )

     else:
-        result = _run_test(test_id, inputs, params)
+        result = _run_test(test_id, inputs, params, title)

     end_time = time.perf_counter()
     result.metadata = _get_run_metadata(duration_seconds=end_time - start_time)
@@ -383,15 +396,16 @@
     if post_process_fn:
         result = post_process_fn(result)

-    result.description = get_result_description(
-        test_id=test_id,
-        test_description=result.doc,
-        tables=result.tables,
-        figures=result.figures,
-        metric=result.metric,
-        should_generate=generate_description,
-        title=title,
-    )
+    if not result.description:
+        result.description = get_result_description(
+            test_id=test_id,
+            test_description=result.doc,
+            tables=result.tables,
+            figures=result.figures,
+            metric=result.metric,
+            should_generate=generate_description,
+            title=title,
+        )

     if show:
         result.show()
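The new show_params flag threads from run_test through _run_comparison_test into combine_results, so comparison runs can omit parameter values from generated figure titles, and a pre-set result.description (for example, one returned by the test as a string) is no longer overwritten. A hedged sketch (the test ID is from this package, but the parameter name, grid values, and input objects are illustrative):

import validmind as vm

result = vm.tests.run_test(
    "validmind.data_validation.ClassImbalance",
    inputs={"dataset": vm_train_ds},
    param_grid={"min_percent_threshold": [5, 10, 20]},  # assumed parameter name
    show_params=False,  # keep the parameter values out of the figure titles
)
result.show()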
validmind/tests/test_providers.py CHANGED
@@ -7,7 +7,7 @@ import os
 import re
 import sys
 from pathlib import Path
-from typing import List, Protocol
+from typing import Any, Callable, List, Protocol

 from validmind.logging import get_logger

@@ -95,45 +95,38 @@ class LocalTestProvider:
         """
         self.root_folder = os.path.abspath(root_folder)

-    def list_tests(self):
+    def list_tests(self) -> List[str]:
         """List all tests in the given namespace

         Returns:
             list: A list of test IDs
         """
-        test_ids = []
-
+        test_files = []
         for root, _, files in os.walk(self.root_folder):
-            for filename in files:
-                if not filename.endswith(".py") or filename.startswith("__"):
-                    continue
-
-                path = Path(root) / filename
-                if not _is_test_file(path):
+            for file in files:
+                if not file.endswith(".py"):
                     continue

-                rel_path = path.relative_to(self.root_folder)
-
-                test_id_parts = [p.stem for p in rel_path.parents if p.stem][::-1]
-                test_id_parts.append(path.stem)
-                test_ids.append(".".join(test_id_parts))
+                path = Path(os.path.join(root, file))
+                if _is_test_file(path):
+                    rel_path = os.path.relpath(path, self.root_folder)
+                    test_id = os.path.splitext(rel_path)[0].replace(os.sep, ".")
+                    test_files.append(test_id)

-        return sorted(test_ids)
+        return test_files

-    def load_test(self, test_id: str):
-        """
-        Load the test identified by the given test_id.
+    def load_test(self, test_id: str) -> Callable[..., Any]:
+        """Load the test function identified by the given test_id

         Args:
-            test_id (str): The identifier of the test. This corresponds to the relative
-                path of the python file from the root folder, with slashes replaced by dots
+            test_id (str): The test ID (does not contain the namespace under which
+                the test is registered)

         Returns:
-            The test class that matches the last part of the test_id.
+            callable: The test function

         Raises:
-            LocalTestProviderLoadModuleError: If the test module cannot be imported
-            LocalTestProviderLoadTestError: If the test class cannot be found in the module
+            FileNotFoundError: If the test is not found
         """
         # Convert test_id to file path
         file_path = os.path.join(self.root_folder, f"{test_id.replace('.', '/')}.py")
@@ -162,28 +155,28 @@


 class ValidMindTestProvider:
-    """Test provider for ValidMind tests"""
+    """Provider for built-in ValidMind tests"""

-    def __init__(self):
+    def __init__(self) -> None:
         # two subproviders: unit_metrics and normal tests
-        self.metrics_provider = LocalTestProvider(
+        self.unit_metrics_provider = LocalTestProvider(
             os.path.join(os.path.dirname(__file__), "..", "unit_metrics")
         )
-        self.tests_provider = LocalTestProvider(os.path.dirname(__file__))
+        self.test_provider = LocalTestProvider(os.path.dirname(__file__))

     def list_tests(self) -> List[str]:
-        """List all tests in the ValidMind test provider"""
+        """List all tests in the given namespace"""
         metric_ids = [
-            f"unit_metrics.{test}" for test in self.metrics_provider.list_tests()
+            f"unit_metrics.{test}" for test in self.unit_metrics_provider.list_tests()
         ]
-        test_ids = self.tests_provider.list_tests()
+        test_ids = self.test_provider.list_tests()

         return metric_ids + test_ids

-    def load_test(self, test_id: str) -> callable:
-        """Load a ValidMind test or unit metric"""
+    def load_test(self, test_id: str) -> Callable[..., Any]:
+        """Load the test function identified by the given test_id"""
         return (
-            self.metrics_provider.load_test(test_id.replace("unit_metrics.", ""))
+            self.unit_metrics_provider.load_test(test_id.replace("unit_metrics.", ""))
             if test_id.startswith("unit_metrics.")
-            else self.tests_provider.load_test(test_id)
+            else self.test_provider.load_test(test_id)
         )
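The provider refactor keeps the external contract (dotted test IDs derived from .py files under the root folder) while renaming the internal sub-providers. A hedged usage sketch; the folder "my_tests", the namespace "my_org", and the test ID are placeholders, and register_test_provider is assumed to be the registration hook exposed by validmind.tests:

import validmind as vm
from validmind.tests import LocalTestProvider

provider = LocalTestProvider("my_tests")  # folder containing custom test .py files
vm.tests.register_test_provider("my_org", provider)

print(provider.list_tests())              # e.g. ["tabular.MyCustomTest", ...]
my_test = provider.load_test("tabular.MyCustomTest")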
validmind/tests/utils.py CHANGED
@@ -5,6 +5,7 @@
 """Test Module Utils"""

 import inspect
+from typing import Any, Optional, Tuple, Type, Union

 import numpy as np
 import pandas as pd
@@ -14,7 +15,7 @@ from validmind.logging import get_logger
 logger = get_logger(__name__)


-def test_description(test_class, truncate=True):
+def test_description(test_class: Type[Any], truncate: bool = True) -> str:
     description = inspect.getdoc(test_class).strip()

     if truncate and len(description.split("\n")) > 5:
@@ -23,7 +24,11 @@ def test_description(test_class, truncate=True):
     return description


-def remove_nan_pairs(y_true, y_pred, dataset_id=None):
+def remove_nan_pairs(
+    y_true: Union[np.ndarray, list],
+    y_pred: Union[np.ndarray, list],
+    dataset_id: Optional[str] = None,
+) -> Tuple[np.ndarray, np.ndarray]:
     """
     Remove pairs where either true or predicted values are NaN/None.
     Args:
@@ -52,7 +57,11 @@ def remove_nan_pairs(y_true, y_pred, dataset_id=None):
     return y_true, y_pred


-def ensure_equal_lengths(y_true, y_pred, dataset_id=None):
+def ensure_equal_lengths(
+    y_true: Union[np.ndarray, list],
+    y_pred: Union[np.ndarray, list],
+    dataset_id: Optional[str] = None,
+) -> Tuple[np.ndarray, np.ndarray]:
     """
     Check if true and predicted values have matching lengths, log warning if they don't,
     and truncate to the shorter length if necessary. Also removes any NaN/None values.
@@ -82,7 +91,11 @@ def ensure_equal_lengths(y_true, y_pred, dataset_id=None):
     return y_true, y_pred


-def validate_prediction(y_true, y_pred, dataset_id=None):
+def validate_prediction(
+    y_true: Union[np.ndarray, list],
+    y_pred: Union[np.ndarray, list],
+    dataset_id: Optional[str] = None,
+) -> Tuple[np.ndarray, np.ndarray]:
     """
     Comprehensive validation of true and predicted value pairs.
     Handles NaN/None values and length mismatches.
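These helpers gain explicit type hints but keep their behavior: dropping NaN/None pairs and truncating to a common length. A minimal sketch of calling the top-level helper (the arrays and dataset_id are illustrative):

import numpy as np
from validmind.tests.utils import validate_prediction

y_true = np.array([1.0, 0.0, 1.0, np.nan, 0.0])
y_pred = np.array([1.0, 0.0, 0.0, 1.0])  # shorter than y_true

# Logs warnings, drops the NaN pair, and truncates both arrays to the same length.
y_true_clean, y_pred_clean = validate_prediction(y_true, y_pred, dataset_id="test_ds")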
validmind/unit_metrics/__init__.py CHANGED
@@ -10,7 +10,7 @@ from validmind.tests.run import run_test
 def list_metrics(**kwargs):
     """List all metrics"""
     vm_provider = test_provider_store.get_test_provider("validmind")
-    vm_metrics_provider = vm_provider.metrics_provider
+    vm_metrics_provider = vm_provider.unit_metrics_provider

     prefix = "validmind.unit_metrics."
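This call site simply follows the attribute rename on ValidMindTestProvider. A hedged sketch of listing unit metrics through the public helper (the printed ID is an example of the expected "validmind.unit_metrics.*" prefix, not a value taken from this diff):

from validmind.unit_metrics import list_metrics

for metric_id in list_metrics():
    print(metric_id)  # e.g. "validmind.unit_metrics.classification.F1"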