validmind 2.0.7__py3-none-any.whl → 2.1.0__py3-none-any.whl

This diff compares two publicly released versions of the package as published to their public registry. It is provided for informational purposes only and reflects the changes between the two versions exactly as they appear in the registry.
Files changed (108)
  1. validmind/__init__.py +3 -3
  2. validmind/__version__.py +1 -1
  3. validmind/ai.py +7 -11
  4. validmind/api_client.py +29 -27
  5. validmind/client.py +10 -3
  6. validmind/datasets/credit_risk/__init__.py +11 -0
  7. validmind/datasets/credit_risk/datasets/lending_club_loan_data_2007_2014_clean.csv.gz +0 -0
  8. validmind/datasets/credit_risk/lending_club.py +394 -0
  9. validmind/logging.py +9 -2
  10. validmind/template.py +2 -2
  11. validmind/test_suites/__init__.py +4 -2
  12. validmind/tests/__init__.py +97 -50
  13. validmind/tests/data_validation/FeatureTargetCorrelationPlot.py +3 -1
  14. validmind/tests/data_validation/PiTCreditScoresHistogram.py +1 -1
  15. validmind/tests/data_validation/ScatterPlot.py +8 -2
  16. validmind/tests/decorator.py +138 -14
  17. validmind/tests/model_validation/BertScore.py +1 -1
  18. validmind/tests/model_validation/BertScoreAggregate.py +1 -1
  19. validmind/tests/model_validation/BleuScore.py +1 -1
  20. validmind/tests/model_validation/ClusterSizeDistribution.py +1 -1
  21. validmind/tests/model_validation/ContextualRecall.py +1 -1
  22. validmind/tests/model_validation/FeaturesAUC.py +110 -0
  23. validmind/tests/model_validation/MeteorScore.py +1 -1
  24. validmind/tests/model_validation/RegardHistogram.py +1 -1
  25. validmind/tests/model_validation/RegardScore.py +1 -1
  26. validmind/tests/model_validation/RegressionResidualsPlot.py +127 -0
  27. validmind/tests/model_validation/RougeMetrics.py +1 -1
  28. validmind/tests/model_validation/RougeMetricsAggregate.py +1 -1
  29. validmind/tests/model_validation/SelfCheckNLIScore.py +1 -1
  30. validmind/tests/model_validation/TokenDisparity.py +1 -1
  31. validmind/tests/model_validation/ToxicityHistogram.py +1 -1
  32. validmind/tests/model_validation/ToxicityScore.py +1 -1
  33. validmind/tests/model_validation/embeddings/ClusterDistribution.py +1 -1
  34. validmind/tests/model_validation/embeddings/CosineSimilarityDistribution.py +1 -3
  35. validmind/tests/model_validation/embeddings/DescriptiveAnalytics.py +1 -1
  36. validmind/tests/model_validation/embeddings/EmbeddingsVisualization2D.py +1 -1
  37. validmind/tests/model_validation/sklearn/ClassifierPerformance.py +15 -18
  38. validmind/tests/model_validation/sklearn/ClusterCosineSimilarity.py +1 -1
  39. validmind/tests/model_validation/sklearn/ClusterPerformance.py +2 -2
  40. validmind/tests/model_validation/sklearn/ConfusionMatrix.py +21 -3
  41. validmind/tests/model_validation/sklearn/MinimumAccuracy.py +1 -1
  42. validmind/tests/model_validation/sklearn/MinimumF1Score.py +1 -1
  43. validmind/tests/model_validation/sklearn/MinimumROCAUCScore.py +1 -1
  44. validmind/tests/model_validation/sklearn/ModelsPerformanceComparison.py +5 -4
  45. validmind/tests/model_validation/sklearn/OverfitDiagnosis.py +2 -2
  46. validmind/tests/model_validation/sklearn/ROCCurve.py +6 -12
  47. validmind/tests/model_validation/sklearn/RegressionErrors.py +2 -2
  48. validmind/tests/model_validation/sklearn/RegressionModelsPerformanceComparison.py +6 -4
  49. validmind/tests/model_validation/sklearn/RegressionR2Square.py +2 -2
  50. validmind/tests/model_validation/sklearn/SHAPGlobalImportance.py +27 -3
  51. validmind/tests/model_validation/sklearn/SilhouettePlot.py +1 -1
  52. validmind/tests/model_validation/sklearn/TrainingTestDegradation.py +2 -2
  53. validmind/tests/model_validation/sklearn/WeakspotsDiagnosis.py +2 -2
  54. validmind/tests/model_validation/statsmodels/CumulativePredictionProbabilities.py +140 -0
  55. validmind/tests/model_validation/statsmodels/GINITable.py +22 -45
  56. validmind/tests/model_validation/statsmodels/{LogisticRegPredictionHistogram.py → PredictionProbabilitiesHistogram.py} +67 -92
  57. validmind/tests/model_validation/statsmodels/RegressionModelForecastPlot.py +2 -2
  58. validmind/tests/model_validation/statsmodels/RegressionModelForecastPlotLevels.py +2 -2
  59. validmind/tests/model_validation/statsmodels/RegressionModelInsampleComparison.py +1 -1
  60. validmind/tests/model_validation/statsmodels/RegressionModelOutsampleComparison.py +1 -1
  61. validmind/tests/model_validation/statsmodels/RegressionModelSummary.py +1 -1
  62. validmind/tests/model_validation/statsmodels/RegressionModelsPerformance.py +1 -1
  63. validmind/tests/model_validation/statsmodels/RegressionPermutationFeatureImportance.py +128 -0
  64. validmind/tests/model_validation/statsmodels/ScorecardHistogram.py +70 -103
  65. validmind/tests/test_providers.py +14 -124
  66. validmind/unit_metrics/__init__.py +76 -69
  67. validmind/unit_metrics/classification/sklearn/Accuracy.py +14 -0
  68. validmind/unit_metrics/classification/sklearn/F1.py +13 -0
  69. validmind/unit_metrics/classification/sklearn/Precision.py +13 -0
  70. validmind/unit_metrics/classification/sklearn/ROC_AUC.py +13 -0
  71. validmind/unit_metrics/classification/sklearn/Recall.py +13 -0
  72. validmind/unit_metrics/composite.py +24 -71
  73. validmind/unit_metrics/regression/GiniCoefficient.py +20 -26
  74. validmind/unit_metrics/regression/HuberLoss.py +12 -16
  75. validmind/unit_metrics/regression/KolmogorovSmirnovStatistic.py +18 -24
  76. validmind/unit_metrics/regression/MeanAbsolutePercentageError.py +7 -13
  77. validmind/unit_metrics/regression/MeanBiasDeviation.py +5 -14
  78. validmind/unit_metrics/regression/QuantileLoss.py +6 -16
  79. validmind/unit_metrics/regression/sklearn/AdjustedRSquaredScore.py +12 -18
  80. validmind/unit_metrics/regression/sklearn/MeanAbsoluteError.py +6 -15
  81. validmind/unit_metrics/regression/sklearn/MeanSquaredError.py +5 -14
  82. validmind/unit_metrics/regression/sklearn/RSquaredScore.py +6 -15
  83. validmind/unit_metrics/regression/sklearn/RootMeanSquaredError.py +11 -14
  84. validmind/utils.py +18 -45
  85. validmind/vm_models/__init__.py +0 -2
  86. validmind/vm_models/dataset.py +255 -16
  87. validmind/vm_models/test/metric.py +1 -2
  88. validmind/vm_models/test/result_wrapper.py +12 -13
  89. validmind/vm_models/test/test.py +2 -1
  90. validmind/vm_models/test/threshold_test.py +1 -2
  91. validmind/vm_models/test_suite/summary.py +3 -3
  92. validmind/vm_models/test_suite/test_suite.py +2 -1
  93. {validmind-2.0.7.dist-info → validmind-2.1.0.dist-info}/METADATA +10 -6
  94. {validmind-2.0.7.dist-info → validmind-2.1.0.dist-info}/RECORD +97 -96
  95. validmind/tests/__types__.py +0 -62
  96. validmind/tests/model_validation/statsmodels/LogRegressionConfusionMatrix.py +0 -128
  97. validmind/tests/model_validation/statsmodels/LogisticRegCumulativeProb.py +0 -172
  98. validmind/tests/model_validation/statsmodels/ScorecardBucketHistogram.py +0 -181
  99. validmind/tests/model_validation/statsmodels/ScorecardProbabilitiesHistogram.py +0 -175
  100. validmind/unit_metrics/sklearn/classification/Accuracy.py +0 -22
  101. validmind/unit_metrics/sklearn/classification/F1.py +0 -24
  102. validmind/unit_metrics/sklearn/classification/Precision.py +0 -24
  103. validmind/unit_metrics/sklearn/classification/ROC_AUC.py +0 -22
  104. validmind/unit_metrics/sklearn/classification/Recall.py +0 -22
  105. validmind/vm_models/test/unit_metric.py +0 -88
  106. {validmind-2.0.7.dist-info → validmind-2.1.0.dist-info}/LICENSE +0 -0
  107. {validmind-2.0.7.dist-info → validmind-2.1.0.dist-info}/WHEEL +0 -0
  108. {validmind-2.0.7.dist-info → validmind-2.1.0.dist-info}/entry_points.txt +0 -0
validmind/tests/__init__.py
@@ -5,24 +5,26 @@
 """All Tests for ValidMind"""

 import importlib
+import inspect
 import sys
 from pathlib import Path
 from pprint import pformat
 from typing import Dict

+import mistune
 import pandas as pd
 from IPython.display import display
 from ipywidgets import HTML
-from markdown import markdown

 from ..errors import LoadTestError
 from ..html_templates.content_blocks import test_content_block_html
 from ..logging import get_logger
+from ..unit_metrics import run_metric
 from ..unit_metrics.composite import load_composite_metric
-from ..utils import clean_docstring, format_dataframe, fuzzy_match, test_id_to_name
+from ..utils import format_dataframe, fuzzy_match, test_id_to_name
 from ..vm_models import TestContext, TestInput
-from .__types__ import ExternalTestProvider
-from .test_providers import GithubTestProvider, LocalTestProvider
+from .decorator import metric, tags, tasks
+from .test_providers import LocalTestProvider, TestProvider

 logger = get_logger(__name__)

@@ -35,23 +37,28 @@ __all__ = [
     "load_test",
     "describe_test",
     "register_test_provider",
-    "GithubTestProvider",
     "LoadTestError",
     "LocalTestProvider",
+    # Decorators for functional metrics
+    "metric",
+    "tags",
+    "tasks",
 ]

 __tests = None
 __test_classes = None

-__test_providers: Dict[str, ExternalTestProvider] = {}
+__test_providers: Dict[str, TestProvider] = {}
 __custom_tests: Dict[str, object] = {}


 def _test_description(test_class, truncate=True):
-    if truncate and len(test_class.__doc__.split("\n")) > 5:
-        return test_class.__doc__.strip().split("\n")[0] + "..."
+    description = inspect.getdoc(test_class).strip()

-    return test_class.__doc__
+    if truncate and len(description.split("\n")) > 5:
+        return description.strip().split("\n")[0] + "..."
+
+    return description


 def _load_tests(test_ids):
@@ -251,61 +258,83 @@ def list_tests(filter=None, task=None, tags=None, pretty=True, truncate=True):
     return tests


-def load_test(test_id, reload=False):  # noqa: C901
-    # Extract the test ID extension from the actual test ID when loading
-    # the test class. This enables us to generate multiple results for
-    # the same tests within the document. For instance, consider the
-    # test ID "validmind.data_validation.ClassImbalance:data_id_1,"
-    # where the test ID extension is "data_id_1".
+def _load_validmind_test(test_id, reload=False):
     parts = test_id.split(":")[0].split(".")

+    test_module = ".".join(parts[1:-1])
+    test_class = parts[-1]
+
+    error = None
+    test = None
+
+    try:
+        full_path = f"validmind.tests.{test_module}.{test_class}"
+
+        if reload and full_path in sys.modules:
+            module = importlib.reload(sys.modules[full_path])
+        else:
+            module = importlib.import_module(full_path)
+
+        test = getattr(module, test_class)
+    except ModuleNotFoundError as e:
+        error = f"Unable to load test {test_id}. {e}"
+    except AttributeError:
+        error = f"Unable to load test {test_id}. Test not in module: {test_class}"
+
+    return error, test
+
+
+def load_test(test_id: str, reload=False):
+    """Load a test by test ID
+
+    Test IDs are in the format `namespace.path_to_module.TestClassOrFuncName[:result_id]`.
+    The result ID is optional and is used to distinguish between multiple results from the
+    running the same test.
+
+    Args:
+        test_id (str): The test ID in the format `namespace.path_to_module.TestName[:result_id]`
+        reload (bool, optional): Whether to reload the test module. Defaults to False.
+    """
+    # TODO: we should use a dedicated class for test IDs to handle this consistently
+    test_id, result_id = test_id.split(":", 1) if ":" in test_id else (test_id, None)
+
     error = None
-    namespace = parts[0]
+    namespace = test_id.split(".", 1)[0]

-    if test_id.split(":")[0] in __custom_tests:
-        test = __custom_tests[test_id.split(":")[0]]
+    # TODO: lets implement an extensible loading system instead of this ugly if/else
+    if test_id in __custom_tests:
+        test = __custom_tests[test_id]

     elif test_id.startswith("validmind.composite_metric"):
-        test = load_composite_metric(test_id)
+        error, test = load_composite_metric(test_id)

     elif namespace == "validmind":
-        test_module = ".".join(parts[1:-1])
-        test_class = parts[-1]
-
-        try:
-            full_path = f"validmind.tests.{test_module}.{test_class}"
-
-            if reload and full_path in sys.modules:
-                module = importlib.reload(sys.modules[full_path])
-            else:
-                module = importlib.import_module(full_path)
-
-            test = getattr(module, test_class)
-        except ModuleNotFoundError as e:
-            error = f"Unable to load test {test_id}. {e}"
-        except AttributeError:
-            error = f"Unable to load test {test_id}. Class not in module: {test_class}"
-
-    elif namespace != "validmind" and namespace not in __test_providers:
-        error = (
-            f"Unable to load test {test_id}. "
-            f"No Test Provider found for the namespace: {namespace}."
-        )
+        error, test = _load_validmind_test(test_id, reload=reload)

     elif namespace in __test_providers:
         try:
             test = __test_providers[namespace].load_test(test_id.split(".", 1)[1])
         except Exception as e:
             error = (
-                f"Unable to load test {test_id} from test provider: "
+                f"Unable to load test {test_id} from test provider: "
                 f"{__test_providers[namespace]}\n Got Exception: {e}"
             )

+    else:
+        error = f"Unable to load test {test_id}. No test provider found."
+
     if error:
         logger.error(error)
         raise LoadTestError(error)

-    test.test_id = test_id
+    if inspect.isfunction(test):
+        # if its a function, we decorate it and then load the class
+        # TODO: simplify this as we move towards all functional metrics
+        # "_" is used here so it doesn't conflict with other test ids
+        metric("_")(test)
+        test = __custom_tests["_"]
+
+    test.test_id = f"{test_id}:{result_id}" if result_id else test_id

     return test

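The `:result_id` suffix that `load_test` now parses lets the same test be logged more than once in a document, each run tracked as its own result. A minimal usage sketch of how that might look from a notebook, assuming `vm_train_ds` and `vm_test_ds` are ValidMind dataset objects that were initialized elsewhere (the IDs and names here are illustrative, not prescribed by the release):

    from validmind.tests import run_test

    # Same underlying test, two result IDs, two independently logged results
    run_test(
        "validmind.data_validation.ClassImbalance:training_data",
        inputs={"dataset": vm_train_ds},
    )
    run_test(
        "validmind.data_validation.ClassImbalance:test_data",
        inputs={"dataset": vm_test_ds},
    )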
@@ -330,7 +359,7 @@ def describe_test(test_id: str = None, raw: bool = False):
         "Test Type": test.test_type,
         "Required Inputs": test.required_inputs,
         "Params": test.default_params or {},
-        "Description": clean_docstring(test.__doc__),
+        "Description": inspect.getdoc(test).strip() or "",
     }

     if raw:
@@ -340,7 +369,7 @@
         HTML(
             test_content_block_html.format(
                 title=f'{details["Name"]}',
-                description=markdown(details["Description"]),
+                description=mistune.html(details["Description"].strip()),
                 required_inputs=", ".join(details["Required Inputs"] or ["None"]),
                 params_table="\n".join(
                     [
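Test descriptions in this file are now read with `inspect.getdoc` and rendered to HTML via `mistune` rather than the `markdown` package. A quick sketch of inspecting a built-in test (the test ID is just an example):

    from validmind.tests import describe_test

    describe_test("validmind.data_validation.ClassImbalance")  # displays an HTML widget
    details = describe_test("validmind.data_validation.ClassImbalance", raw=True)  # plain dict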
@@ -361,6 +390,7 @@ def run_test(
     params: dict = None,
     inputs=None,
     output_template=None,
+    show=True,
     **kwargs,
 ):
     """Run a test by test ID
@@ -375,6 +405,7 @@
         params (dict, optional): A dictionary of params to override the default params
         inputs: A dictionary of test inputs to pass to the Test
         output_template (str, optional): A template to use for customizing the output
+        show (bool, optional): Whether to display the results. Defaults to True.
         **kwargs: Any extra arguments will be passed in via the TestInput object. i.e.:
             - dataset: A validmind Dataset object or a Pandas DataFrame
             - model: A model to use for the test
@@ -389,9 +420,23 @@
     if (unit_metrics and not name) or (name and not unit_metrics):
         raise ValueError("`name` and `unit_metrics` must be provided together")

+    if test_id and test_id.startswith("validmind.unit_metrics"):
+        # TODO: as we move towards a more unified approach to metrics
+        # we will want to make everything functional and remove the
+        # separation between unit metrics and "normal" metrics
+        return run_metric(test_id, inputs=inputs, params=params, show=show)
+
     if unit_metrics:
-        TestClass = load_composite_metric(unit_metrics=unit_metrics, metric_name=name)
-        test_id = f"validmind.composite_metric.{name}"
+        metric_id_name = "".join(word[0].upper() + word[1:] for word in name.split())
+        test_id = f"validmind.composite_metric.{metric_id_name}"
+
+        error, TestClass = load_composite_metric(
+            unit_metrics=unit_metrics, metric_name=metric_id_name
+        )
+
+        if error:
+            raise LoadTestError(error)
+
     else:
         TestClass = load_test(test_id, reload=True)

@@ -404,17 +449,19 @@
     )

     test.run()
-    test.result.show()
+
+    if show:
+        test.result.show()

     return test.result


-def register_test_provider(namespace: str, test_provider: ExternalTestProvider) -> None:
+def register_test_provider(namespace: str, test_provider: TestProvider) -> None:
     """Register an external test provider

     Args:
         namespace (str): The namespace of the test provider
-        test_provider (ExternalTestProvider): The test provider
+        test_provider (TestProvider): The test provider
     """
     __test_providers[namespace] = test_provider

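With these changes, `run_test` also routes `validmind.unit_metrics.*` IDs through `run_metric` and can compose several unit metrics into one composite metric result, while `show=False` suppresses the inline display. A rough sketch, assuming `vm_model` and `vm_test_ds` are already-initialized ValidMind objects and that the unit-metric IDs follow the relocated module paths in the file list above:

    from validmind.tests import run_test

    # Run a single unit metric (dispatched to run_metric under the hood)
    f1 = run_test(
        "validmind.unit_metrics.classification.sklearn.F1",
        inputs={"model": vm_model, "dataset": vm_test_ds},
        show=False,
    )

    # Compose unit metrics into one composite metric; the name is camel-cased
    # into the test ID validmind.composite_metric.ClassificationPerformance
    perf = run_test(
        name="classification performance",
        unit_metrics=[
            "validmind.unit_metrics.classification.sklearn.Accuracy",
            "validmind.unit_metrics.classification.sklearn.Precision",
            "validmind.unit_metrics.classification.sklearn.Recall",
        ],
        inputs={"model": vm_model, "dataset": vm_test_ds},
    )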
validmind/tests/data_validation/FeatureTargetCorrelationPlot.py
@@ -74,7 +74,9 @@ class FeatureTargetCorrelationPlot(Metric):

     def visualize_feature_target_correlation(self, df, target_column, fig_height):
         # Compute correlations with the target variable
-        correlations = df.corr(numeric_only=True)[target_column].drop(target_column)
+        correlations = (
+            df.corr(numeric_only=True)[target_column].drop(target_column).to_frame()
+        )
         correlations = correlations.loc[:, ~correlations.columns.duplicated()]

         correlations = correlations.sort_values(by=target_column, ascending=True)
validmind/tests/data_validation/PiTCreditScoresHistogram.py
@@ -113,7 +113,7 @@ class PiTCreditScoresHistogram(Metric):
         )
         predicted_default_column = (
             self.params.get("predicted_default_column")
-            or self.inputs.dataset.y_pred(self.inputs.model.input_id),
+            or self.inputs.dataset.y_pred(self.inputs.model),
         )
         scores_column = self.params["scores_column"]
         point_in_time_column = self.params["point_in_time_column"]
validmind/tests/data_validation/ScatterPlot.py
@@ -65,8 +65,14 @@ class ScatterPlot(Metric):
         if not set(columns).issubset(set(df.columns)):
             raise ValueError("Provided 'columns' must exist in the dataset")

-        sns.pairplot(data=df, diag_kind="kde")
-
+        g = sns.pairplot(data=df, diag_kind="kde")
+        for ax in g.axes.flatten():
+            # rotate x axis labels
+            ax.set_xlabel(ax.get_xlabel(), rotation=45)
+            # rotate y axis labels
+            ax.set_ylabel(ax.get_ylabel(), rotation=45)
+            # set y labels alignment
+            ax.yaxis.get_label().set_horizontalalignment("right")
         # Get the current figure
         fig = plt.gcf()

validmind/tests/decorator.py
@@ -4,13 +4,17 @@

 """Decorators for creating and registering metrics with the ValidMind framework."""

+# TODO: as we move entirely to a functional approach a lot of this logic
+# should be moved into the __init__ to replace the old class-based stuff
+
 import inspect
+import os
 from uuid import uuid4

 import pandas as pd

+from validmind.errors import MissingRequiredTestInputError
 from validmind.logging import get_logger
-from validmind.utils import clean_docstring
 from validmind.vm_models import (
     Metric,
     MetricResult,
@@ -26,8 +30,6 @@ from validmind.vm_models.figure import (
 )
 from validmind.vm_models.test.result_wrapper import MetricResultWrapper

-from . import _register_custom_test
-
 logger = get_logger(__name__)


@@ -53,7 +55,7 @@ def _inspect_signature(test_func: callable):
     return inputs, params


-def _build_result(results, test_id, description, output_template):
+def _build_result(results, test_id, description, output_template, inputs):  # noqa: C901
     ref_id = str(uuid4())
     figure_metadata = {
         "_type": "metric",
@@ -65,7 +67,17 @@ def _build_result(results, test_id, description, output_template):
     figures = []

     def process_item(item):
-        if is_matplotlib_figure(item) or is_plotly_figure(item) or is_png_image(item):
+        # TOOD: build out a more robust/extensible system for this
+        # TODO: custom type handlers would be really cool
+
+        # unit metrics (scalar values) - show in a simple table for now
+        if isinstance(item, int) or isinstance(item, float) or isinstance(item, str):
+            tables.append(ResultTable(data=[{test_id.split(".")[-1]: item}]))
+
+        # plots
+        elif isinstance(item, Figure):
+            figures.append(item)
+        elif is_matplotlib_figure(item) or is_plotly_figure(item) or is_png_image(item):
             figures.append(
                 Figure(
                     key=f"{test_id}:{len(figures) + 1}",
@@ -73,18 +85,24 @@
                     metadata=figure_metadata,
                 )
             )
-        elif isinstance(item, list):
-            tables.append(ResultTable(data=item))
-        elif isinstance(item, pd.DataFrame):
+
+        # tables
+        elif isinstance(item, list) or isinstance(item, pd.DataFrame):
             tables.append(ResultTable(data=item))
         elif isinstance(item, dict):
             for table_name, table in item.items():
+                if not isinstance(table, list) and not isinstance(table, pd.DataFrame):
+                    raise ValueError(
+                        f"Invalid table format: {table_name} must be a list or DataFrame"
+                    )
+
                 tables.append(
                     ResultTable(
                         data=table,
                         metadata=ResultTableMetadata(title=table_name),
                     )
                 )
+
         else:
             raise ValueError(f"Invalid return type: {type(item)}")
@@ -107,17 +125,23 @@
         result_metadata=[
             {
                 "content_id": f"metric_description:{test_id}",
-                "text": clean_docstring(description),
+                "text": description,
             }
         ],
-        inputs=[],
+        inputs=inputs,
         output_template=output_template,
     )


-def get_run_method(func, inputs, params):
+def _get_run_method(func, inputs, params):
     def run(self: Metric):
-        input_kwargs = {k: getattr(self.inputs, k) for k in inputs.keys()}
+        input_kwargs = {}
+        for k in inputs.keys():
+            try:
+                input_kwargs[k] = getattr(self.inputs, k)
+            except AttributeError:
+                raise MissingRequiredTestInputError(f"Missing required input: {k}.")
+
         param_kwargs = {
             k: self.params.get(k, params[k]["default"]) for k in params.keys()
         }
@@ -127,8 +151,9 @@ def get_run_method(func, inputs, params):
         self.result = _build_result(
             results=raw_results,
             test_id=self.test_id,
-            description=self.__doc__,
+            description=inspect.getdoc(self),
             output_template=self.output_template,
+            inputs=list(inputs.keys()),
         )

         return self.result
@@ -136,6 +161,65 @@ def get_run_method(func, inputs, params):
     return run


+def _get_save_func(func, test_id):
+    def save(root_folder=".", imports=None):
+        parts = test_id.split(".")
+
+        if len(parts) > 1:
+            path = os.path.join(root_folder, *parts[1:-1])
+            test_name = parts[-1]
+            new_test_id = f"<test_provider_namespace>.{'.'.join(parts[1:])}"
+        else:
+            path = root_folder
+            test_name = parts[0]
+            new_test_id = f"<test_provider_namespace>.{test_name}"
+
+        if not os.path.exists(path):
+            os.makedirs(path, exist_ok=True)
+
+        full_path = os.path.join(path, f"{test_name}.py")
+
+        source = inspect.getsource(func)
+        # remove decorator line
+        source = source.split("\n", 1)[1]
+        if imports:
+            imports = "\n".join(imports)
+            source = f"{imports}\n\n\n{source}"
+        # add comment to the top of the file
+        source = f"""
+# Saved from {func.__module__}.{func.__name__}
+# Original Test ID: {test_id}
+# New Test ID: {new_test_id}
+
+{source}
+"""
+
+        # ensure that the function name matches the test name
+        source = source.replace(f"def {func.__name__}", f"def {test_name}")
+
+        # use black to format the code
+        try:
+            import black
+
+            source = black.format_str(source, mode=black.FileMode())
+        except ImportError:
+            # ignore if not available
+            pass
+
+        with open(full_path, "w") as file:
+            file.writelines(source)
+
+        logger.info(
+            f"Saved to {os.path.abspath(full_path)}!"
+            "Be sure to add any necessary imports to the top of the file."
+        )
+        logger.info(
+            f"This metric can be run with the ID: {new_test_id}",
+        )
+
+    return save
+
+
 def metric(func_or_id):
     """Decorator for creating and registering metrics with the ValidMind framework.

@@ -151,6 +235,7 @@

     - Table: Either a list of dictionaries or a pandas DataFrame
     - Plot: Either a matplotlib figure or a plotly figure
+    - Scalar: A single number or string

     The function may also include a docstring. This docstring will be used and logged
     as the metric's description.
@@ -163,27 +248,66 @@
         The decorated function.
     """

+    from . import _register_custom_test
+
     def decorator(func):
         test_id = func_or_id or f"validmind.custom_metrics.{func.__name__}"

         inputs, params = _inspect_signature(func)
         description = inspect.getdoc(func)
+        tasks = getattr(func, "__tasks__", [])
+        tags = getattr(func, "__tags__", [])

         metric_class = type(
             func.__name__,
             (Metric,),
             {
-                "run": get_run_method(func, inputs, params),
+                "run": _get_run_method(func, inputs, params),
                 "required_inputs": list(inputs.keys()),
                 "default_parameters": params,
                 "__doc__": description,
+                "metadata": {
+                    "task_types": tasks,
+                    "tags": tags,
+                },
             },
         )
         _register_custom_test(test_id, metric_class)

+        # special function to allow the function to be saved to a file
+        func.save = _get_save_func(func, test_id)
+
         return func

     if callable(func_or_id):
         return decorator(func_or_id)

     return decorator
+
+
+def tasks(*tasks):
+    """Decorator for specifying the task types that a metric is designed for.
+
+    Args:
+        *tasks: The task types that the metric is designed for.
+    """
+
+    def decorator(func):
+        func.__tasks__ = list(tasks)
+        return func
+
+    return decorator
+
+
+def tags(*tags):
+    """Decorator for specifying tags for a metric.
+
+    Args:
+        *tags: The tags to apply to the metric.
+    """
+
+    def decorator(func):
+        func.__tags__ = list(tags)
+        return func
+
+    return decorator
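Taken together, the decorator changes add `@tags`/`@tasks` metadata, scalar return values, clearer input validation (a missing input now raises `MissingRequiredTestInputError`), and a `.save()` helper for exporting a functional metric to a test-provider folder. A minimal sketch of the workflow under the assumptions that `vm_dataset` is an initialized ValidMind dataset exposing a `df` DataFrame and that `my_tests.MissingValueRate` is a made-up test ID:

    from validmind.tests import metric, run_test, tags, tasks

    @metric("my_tests.MissingValueRate")
    @tags("tabular_data", "data_quality")
    @tasks("classification", "regression")
    def MissingValueRate(dataset):
        """Fraction of missing values across the entire dataset."""
        df = dataset.df
        return float(df.isna().sum().sum() / df.size)  # scalar -> one-cell result table

    result = run_test("my_tests.MissingValueRate", inputs={"dataset": vm_dataset})

    # Write the function out as a standalone file for a local test provider;
    # an imports=[...] list can be passed for any modules the saved file needs.
    MissingValueRate.save("./my_tests")

Note that `@tags` and `@tasks` sit below `@metric` so their attributes are set before `@metric` reads them.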
validmind/tests/model_validation/BertScore.py
@@ -57,7 +57,7 @@ class BertScore(Metric):

     def run(self):
         y_true = list(itertools.chain.from_iterable(self.inputs.dataset.y))
-        y_pred = self.inputs.dataset.y_pred(self.inputs.model.input_id)
+        y_pred = self.inputs.dataset.y_pred(self.inputs.model)

         # Load the bert evaluation metric
         bert = evaluate.load("bertscore")
validmind/tests/model_validation/BertScoreAggregate.py
@@ -50,7 +50,7 @@ class BertScoreAggregate(Metric):

     def run(self):
         y_true = list(itertools.chain.from_iterable(self.inputs.dataset.y))
-        y_pred = self.inputs.dataset.y_pred(self.inputs.model.input_id)
+        y_pred = self.inputs.dataset.y_pred(self.inputs.model)

         bert = evaluate.load("bertscore")
         bert_s = bert.compute(predictions=y_pred, references=y_true, lang="en")
validmind/tests/model_validation/BleuScore.py
@@ -55,7 +55,7 @@ class BleuScore(Metric):

         # Compute the BLEU score
         bleu = bleu.compute(
-            predictions=self.inputs.dataset.y_pred(self.inputs.model.input_id),
+            predictions=self.inputs.dataset.y_pred(self.inputs.model),
             references=self.inputs.dataset.y,
         )
         return self.cache_results(metric_value={"blue_score_metric": bleu})
validmind/tests/model_validation/ClusterSizeDistribution.py
@@ -61,7 +61,7 @@ class ClusterSizeDistribution(Metric):

     def run(self):
         y_true_train = self.inputs.dataset.y
-        y_pred_train = self.inputs.dataset.y_pred(self.inputs.model.input_id)
+        y_pred_train = self.inputs.dataset.y_pred(self.inputs.model)
         y_true_train = y_true_train.astype(y_pred_train.dtype)
         df = pd.DataFrame(
             {"Actual": y_true_train.ravel(), "Prediction": y_pred_train.ravel()}
validmind/tests/model_validation/ContextualRecall.py
@@ -66,7 +66,7 @@ class ContextualRecall(Metric):

     def run(self):
         y_true = list(itertools.chain.from_iterable(self.inputs.dataset.y))
-        y_pred = self.inputs.dataset.y_pred(self.inputs.model.input_id)
+        y_pred = self.inputs.dataset.y_pred(self.inputs.model)

         score_list = []
         for y_t, y_p in zip(y_true, y_pred):
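The hunks in this last group are the same one-line migration repeated across tests: `dataset.y_pred()` is now keyed by the model object itself rather than by its `input_id` string. Roughly, for code that calls the accessor directly (the `vm_dataset` and `vm_model` names below are placeholders for already-initialized ValidMind objects):

    y_pred = vm_dataset.y_pred(vm_model.input_id)  # 2.0.7 style
    y_pred = vm_dataset.y_pred(vm_model)           # 2.1.0 style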