validmind 2.8.12__py3-none-any.whl → 2.8.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. validmind/__init__.py +6 -5
  2. validmind/__version__.py +1 -1
  3. validmind/ai/test_descriptions.py +13 -9
  4. validmind/ai/utils.py +2 -2
  5. validmind/api_client.py +75 -32
  6. validmind/client.py +111 -100
  7. validmind/client_config.py +3 -3
  8. validmind/datasets/classification/__init__.py +7 -3
  9. validmind/datasets/credit_risk/lending_club.py +28 -16
  10. validmind/datasets/nlp/cnn_dailymail.py +10 -4
  11. validmind/datasets/regression/__init__.py +22 -5
  12. validmind/errors.py +17 -7
  13. validmind/input_registry.py +1 -1
  14. validmind/logging.py +44 -35
  15. validmind/models/foundation.py +2 -2
  16. validmind/models/function.py +10 -3
  17. validmind/template.py +33 -24
  18. validmind/test_suites/__init__.py +2 -2
  19. validmind/tests/_store.py +13 -4
  20. validmind/tests/comparison.py +65 -33
  21. validmind/tests/data_validation/ClassImbalance.py +3 -1
  22. validmind/tests/data_validation/DatasetDescription.py +2 -23
  23. validmind/tests/data_validation/DescriptiveStatistics.py +1 -1
  24. validmind/tests/data_validation/Skewness.py +7 -6
  25. validmind/tests/decorator.py +14 -11
  26. validmind/tests/load.py +38 -24
  27. validmind/tests/model_validation/ragas/AnswerCorrectness.py +4 -2
  28. validmind/tests/model_validation/ragas/ContextEntityRecall.py +4 -2
  29. validmind/tests/model_validation/ragas/ContextPrecision.py +4 -2
  30. validmind/tests/model_validation/ragas/ContextPrecisionWithoutReference.py +4 -2
  31. validmind/tests/model_validation/ragas/ContextRecall.py +4 -2
  32. validmind/tests/model_validation/ragas/Faithfulness.py +4 -2
  33. validmind/tests/model_validation/ragas/ResponseRelevancy.py +4 -2
  34. validmind/tests/model_validation/ragas/SemanticSimilarity.py +4 -2
  35. validmind/tests/model_validation/sklearn/ClassifierThresholdOptimization.py +13 -3
  36. validmind/tests/model_validation/sklearn/OverfitDiagnosis.py +3 -1
  37. validmind/tests/model_validation/sklearn/SHAPGlobalImportance.py +28 -25
  38. validmind/tests/model_validation/sklearn/WeakspotsDiagnosis.py +15 -10
  39. validmind/tests/output.py +66 -11
  40. validmind/tests/run.py +28 -14
  41. validmind/tests/test_providers.py +28 -35
  42. validmind/tests/utils.py +17 -4
  43. validmind/unit_metrics/__init__.py +1 -1
  44. validmind/utils.py +295 -31
  45. validmind/vm_models/dataset/dataset.py +83 -43
  46. validmind/vm_models/dataset/utils.py +5 -3
  47. validmind/vm_models/figure.py +6 -6
  48. validmind/vm_models/input.py +6 -5
  49. validmind/vm_models/model.py +5 -5
  50. validmind/vm_models/result/result.py +122 -43
  51. validmind/vm_models/result/utils.py +5 -5
  52. validmind/vm_models/test_suite/__init__.py +5 -0
  53. validmind/vm_models/test_suite/runner.py +5 -5
  54. validmind/vm_models/test_suite/summary.py +20 -2
  55. validmind/vm_models/test_suite/test.py +6 -6
  56. validmind/vm_models/test_suite/test_suite.py +10 -10
  57. {validmind-2.8.12.dist-info → validmind-2.8.22.dist-info}/METADATA +3 -4
  58. {validmind-2.8.12.dist-info → validmind-2.8.22.dist-info}/RECORD +61 -60
  59. {validmind-2.8.12.dist-info → validmind-2.8.22.dist-info}/WHEEL +1 -1
  60. {validmind-2.8.12.dist-info → validmind-2.8.22.dist-info}/LICENSE +0 -0
  61. {validmind-2.8.12.dist-info → validmind-2.8.22.dist-info}/entry_points.txt +0 -0
validmind/tests/data_validation/ClassImbalance.py CHANGED
@@ -14,7 +14,9 @@ from validmind.errors import SkipTestError
  from validmind.vm_models import VMDataset


- @tags("tabular_data", "binary_classification", "multiclass_classification")
+ @tags(
+     "tabular_data", "binary_classification", "multiclass_classification", "data_quality"
+ )
  @tasks("classification")
  def ClassImbalance(
      dataset: VMDataset, min_percent_threshold: int = 10
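Note: ClassImbalance (and DescriptiveStatistics below) is now tagged "data_quality", so it can be discovered through the tag filter of list_tests, whose annotated signature appears later in this diff. A minimal sketch, assuming list_tests is re-exported from validmind.tests as in earlier releases:

    # Sketch: discover the newly tagged data-quality tests
    from validmind.tests import list_tests

    # pretty=False returns a plain list of test IDs instead of a formatted table
    data_quality_tests = list_tests(tags=["data_quality"], pretty=False)
    print(data_quality_tests)  # should now include validmind.data_validation.ClassImbalance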
validmind/tests/data_validation/DatasetDescription.py CHANGED
@@ -6,12 +6,10 @@ import re
  from collections import Counter

  import numpy as np
- from ydata_profiling.config import Settings
- from ydata_profiling.model.typeset import ProfilingTypeSet

  from validmind import RawData, tags, tasks
- from validmind.errors import UnsupportedColumnTypeError
  from validmind.logging import get_logger
+ from validmind.utils import infer_datatypes
  from validmind.vm_models import VMDataset

  DEFAULT_HISTOGRAM_BINS = 10
@@ -20,25 +18,6 @@ DEFAULT_HISTOGRAM_BIN_SIZES = [5, 10, 20, 50]
  logger = get_logger(__name__)


- def infer_datatypes(df):
-     column_type_mappings = {}
-     typeset = ProfilingTypeSet(Settings())
-     variable_types = typeset.infer_type(df)
-
-     for column, type in variable_types.items():
-         if str(type) == "Unsupported":
-             if df[column].isnull().all():
-                 column_type_mappings[column] = {"id": column, "type": "Null"}
-             else:
-                 raise UnsupportedColumnTypeError(
-                     f"Unsupported type for column {column}. Please review all values in this dataset column."
-                 )
-         else:
-             column_type_mappings[column] = {"id": column, "type": str(type)}
-
-     return list(column_type_mappings.values())
-
-
  def get_numerical_histograms(df, column):
      """
      Returns a collection of histograms for a numerical column, each one
@@ -50,7 +29,7 @@ def get_numerical_histograms(df, column):
      # bins='sturges'. Cannot use 'auto' until we review and fix its performance
      # on datasets with too many unique values
      #
-     # 'sturges': Rs default method, only accounts for data size. Only optimal
+     # 'sturges': R's default method, only accounts for data size. Only optimal
      # for gaussian data and underestimates number of bins for large non-gaussian datasets.
      default_hist = np.histogram(values_cleaned, bins="sturges")

validmind/tests/data_validation/DescriptiveStatistics.py CHANGED
@@ -44,7 +44,7 @@ def get_summary_statistics_categorical(df, categorical_fields):
      return summary_stats


- @tags("tabular_data", "time_series_data")
+ @tags("tabular_data", "time_series_data", "data_quality")
  @tasks("classification", "regression")
  def DescriptiveStatistics(dataset: VMDataset):
      """
validmind/tests/data_validation/Skewness.py CHANGED
@@ -2,10 +2,8 @@
  # See the LICENSE file in the root of this repository for details.
  # SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial

- from ydata_profiling.config import Settings
- from ydata_profiling.model.typeset import ProfilingTypeSet
-
  from validmind import tags, tasks
+ from validmind.utils import infer_datatypes


  @tags("data_quality", "tabular_data")
@@ -49,8 +47,11 @@ def Skewness(dataset, max_threshold=1):
      - Subjective threshold for risk grading, requiring expert input and recurrent iterations for refinement.
      """

-     typeset = ProfilingTypeSet(Settings())
-     dataset_types = typeset.infer_type(dataset.df)
+     # Use the imported infer_datatypes function
+     dataset_types = infer_datatypes(dataset.df)
+
+     # Convert the list of dictionaries to a dictionary for easy access
+     dataset_types_dict = {item["id"]: item["type"] for item in dataset_types}

      skewness = dataset.df.skew(numeric_only=True)

@@ -58,7 +59,7 @@ def Skewness(dataset, max_threshold=1):
      passed = True

      for col in skewness.index:
-         if str(dataset_types[col]) != "Numeric":
+         if dataset_types_dict.get(col) != "Numeric":
              continue

          col_skewness = skewness[col]
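The infer_datatypes helper removed from DatasetDescription above now lives in validmind.utils and is consumed here by Skewness. Judging from the removed implementation and the dict conversion above, it returns a list of {"id", "type"} dictionaries, one per column. A rough sketch of that contract (the example DataFrame is illustrative):

    import pandas as pd
    from validmind.utils import infer_datatypes  # new home of the helper in 2.8.22

    df = pd.DataFrame({"age": [25, 32, 47], "segment": ["a", "b", "a"]})
    column_types = infer_datatypes(df)  # e.g. [{"id": "age", "type": "Numeric"}, ...]

    # Convert to a lookup table, as Skewness now does
    type_by_column = {item["id"]: item["type"] for item in column_types}
    numeric_columns = [col for col, t in type_by_column.items() if t == "Numeric"]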
validmind/tests/decorator.py CHANGED
@@ -7,6 +7,7 @@
  import inspect
  import os
  from functools import wraps
+ from typing import Any, Callable, List, Optional, TypeVar, Union

  from validmind.logging import get_logger

@@ -15,8 +16,10 @@ from .load import load_test

  logger = get_logger(__name__)

+ F = TypeVar("F", bound=Callable[..., Any])

- def _get_save_func(func, test_id):
+
+ def _get_save_func(func: Callable[..., Any], test_id: str) -> Callable[..., None]:
      """Helper function to save a decorated function to a file

      Useful when a custom test function has been created inline in a notebook or
@@ -29,7 +32,7 @@ def _get_save_func(func, test_id):
      # remove decorator line
      source = source.split("\n", 1)[1]

-     def save(root_folder=".", imports=None):
+     def save(root_folder: str = ".", imports: Optional[List[str]] = None) -> None:
          parts = test_id.split(".")

          if len(parts) > 1:
@@ -84,7 +87,7 @@ def _get_save_func(func, test_id):
      return save


- def test(func_or_id):
+ def test(func_or_id: Union[Callable[..., Any], str, None]) -> Callable[[F], F]:
      """Decorator for creating and registering custom tests

      This decorator registers the function it wraps as a test function within ValidMind
@@ -109,14 +112,14 @@ def test(func_or_id):
      as the metric's description.

      Args:
-         func: The function to decorate
-         test_id: The identifier for the metric. If not provided, the function name is used.
+         func_or_id (Union[Callable[..., Any], str, None]): Either the function to decorate
+             or the test ID. If None, the function name is used.

      Returns:
-         The decorated function.
+         Callable[[F], F]: The decorated function.
      """

-     def decorator(func):
+     def decorator(func: F) -> F:
          test_id = func_or_id or f"validmind.custom_metrics.{func.__name__}"
          test_func = load_test(test_id, func, reload=True)
          test_store.register_test(test_id, test_func)
@@ -136,28 +139,28 @@ def test(func_or_id):
      return decorator


- def tasks(*tasks):
+ def tasks(*tasks: str) -> Callable[[F], F]:
      """Decorator for specifying the task types that a test is designed for.

      Args:
          *tasks: The task types that the test is designed for.
      """

-     def decorator(func):
+     def decorator(func: F) -> F:
          func.__tasks__ = list(tasks)
          return func

      return decorator


- def tags(*tags):
+ def tags(*tags: str) -> Callable[[F], F]:
      """Decorator for specifying tags for a test.

      Args:
          *tags: The tags to apply to the test.
      """

-     def decorator(func):
+     def decorator(func: F) -> F:
          func.__tags__ = list(tags)
          return func

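Taken together, the decorator changes are type-annotation only; registering a custom test is unchanged. A minimal sketch, assuming the decorators are re-exported at the package root as in the documented custom-test workflow (the test ID and function below are hypothetical; omitting the ID falls back to validmind.custom_metrics.<function name> per the code above):

    import pandas as pd
    import validmind as vm

    @vm.test("my_custom_tests.ConstantColumns")  # hypothetical test ID
    @vm.tags("tabular_data", "data_quality")
    @vm.tasks("classification")
    def ConstantColumns(dataset):
        """Reports the number of unique values per column of the input dataset."""
        counts = dataset.df.nunique()
        return pd.DataFrame({"Column": counts.index, "Unique Values": counts.values})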
validmind/tests/load.py CHANGED
@@ -7,7 +7,7 @@
  import inspect
  import json
  from pprint import pformat
- from typing import List
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
  from uuid import uuid4

  import pandas as pd
@@ -32,7 +32,10 @@ INPUT_TYPE_MAP = {
  }


- def _inspect_signature(test_func: callable):
+ def _inspect_signature(
+     test_func: Callable[..., Any],
+ ) -> Tuple[Dict[str, Dict[str, Any]], Dict[str, Dict[str, Any]]]:
+     """Inspect a test function's signature to get inputs and parameters"""
      inputs = {}
      params = {}

@@ -56,7 +59,9 @@ def _inspect_signature(test_func: callable):
      return inputs, params


- def load_test(test_id: str, test_func: callable = None, reload: bool = False):
+ def load_test(
+     test_id: str, test_func: Optional[Callable[..., Any]] = None, reload: bool = False
+ ) -> Callable[..., Any]:
      """Load a test by test ID

      Test IDs are in the format `namespace.path_to_module.TestClassOrFuncName[:tag]`.
@@ -67,6 +72,8 @@ def load_test(test_id: str, test_func: callable = None, reload: bool = False):
          test_id (str): The test ID in the format `namespace.path_to_module.TestName[:tag]`
          test_func (callable, optional): The test function to load. If not provided, the
              test will be loaded from the test provider. Defaults to None.
+         reload (bool, optional): If True, reload the test even if it's already loaded.
+             Defaults to False.
      """
      # remove tag if present
      test_id = test_id.split(":", 1)[0]
@@ -109,7 +116,8 @@ def load_test(test_id: str, test_func: callable = None, reload: bool = False):
      return test_store.get_test(test_id)


- def _list_test_ids():
+ def _list_test_ids() -> List[str]:
+     """List all available test IDs"""
      test_ids = []

      for namespace, test_provider in test_provider_store.test_providers.items():
@@ -120,7 +128,7 @@
      return test_ids


- def _load_tests(test_ids):
+ def _load_tests(test_ids: List[str]) -> Dict[str, Callable[..., Any]]:
      """Load a set of tests, handling missing dependencies."""
      tests = {}

@@ -138,12 +146,12 @@
              logger.debug(str(e))

              if e.extra:
-                 logger.info(
+                 logger.debug(
                      f"Skipping `{test_id}` as it requires extra dependencies: {e.required_dependencies}."
                      f" Please run `pip install validmind[{e.extra}]` to view and run this test."
                  )
              else:
-                 logger.info(
+                 logger.debug(
                      f"Skipping `{test_id}` as it requires missing dependencies: {e.required_dependencies}."
                      " Please install the missing dependencies to view and run this test."
                  )
@@ -151,7 +159,8 @@
      return tests


- def _test_description(test_description: str, num_lines: int = 5):
+ def _test_description(test_description: str, num_lines: int = 5) -> str:
+     """Format a test description"""
      description = test_description.strip("\n").strip()

      if len(description.split("\n")) > num_lines:
@@ -160,7 +169,10 @@
      return description


- def _pretty_list_tests(tests, truncate=True):
+ def _pretty_list_tests(
+     tests: Dict[str, Callable[..., Any]], truncate: bool = True
+ ) -> None:
+     """Pretty print a list of tests"""
      table = [
          {
              "ID": test_id,
@@ -171,6 +183,8 @@
              ),
              "Required Inputs": list(test.inputs.keys()),
              "Params": test.params,
+             "Tags": test.__tags__,
+             "Tasks": test.__tasks__,
          }
          for test_id, test in tests.items()
      ]
@@ -178,10 +192,8 @@
      return format_dataframe(pd.DataFrame(table))


- def list_tags():
-     """
-     List unique tags from all test classes.
-     """
+ def list_tags() -> List[str]:
+     """List all unique available tags"""

      unique_tags = set()

@@ -191,7 +203,7 @@
      return list(unique_tags)


- def list_tasks_and_tags(as_json=False):
+ def list_tasks_and_tags(as_json: bool = False) -> Union[str, Dict[str, List[str]]]:
      """
      List all task types and their associated tags, with one row per task type and
      all tags for a task type in one row.
@@ -218,11 +230,8 @@
      )


- def list_tasks():
-     """
-     List unique tasks from all test classes.
-     """
-
+ def list_tasks() -> List[str]:
+     """List all unique available tasks"""
      unique_tasks = set()

      for test in _load_tests(list_tests(pretty=False)).values():
@@ -231,7 +240,13 @@
      return list(unique_tasks)


- def list_tests(filter=None, task=None, tags=None, pretty=True, truncate=True):
+ def list_tests(
+     filter: Optional[str] = None,
+     task: Optional[str] = None,
+     tags: Optional[List[str]] = None,
+     pretty: bool = True,
+     truncate: bool = True,
+ ) -> Union[List[str], None]:
      """List all tests in the tests directory.

      Args:
@@ -245,9 +260,6 @@ def list_tests(filter=None, task=None, tags=None, pretty=True, truncate=True):
              formatted table. Defaults to True.
          truncate (bool, optional): If True, truncates the test description to the first
              line. Defaults to True. (only used if pretty=True)
-
-     Returns:
-         list or pandas.DataFrame: A list of all tests or a formatted table.
      """
      test_ids = _list_test_ids()

@@ -286,7 +298,9 @@
      return _pretty_list_tests(tests, truncate=truncate)


- def describe_test(test_id: TestID = None, raw: bool = False, show: bool = True):
+ def describe_test(
+     test_id: Optional[TestID] = None, raw: bool = False, show: bool = True
+ ) -> Union[str, HTML, Dict[str, Any]]:
      """Get or show details about the test

      This function can be used to see test details including the test name, description,
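One behavioral change worth noting above: tests with missing optional dependencies are now skipped at DEBUG rather than INFO level, so the catalogue helpers are quieter by default. A short sketch of the helpers whose signatures were annotated above, assuming they are re-exported from validmind.tests:

    from validmind.tests import describe_test, list_tags, list_tasks_and_tags

    print(list_tags())                          # unique tags across all loadable tests
    print(list_tasks_and_tags(as_json=True))    # {task type: [tags, ...]} mapping
    describe_test("validmind.data_validation.Skewness")  # renders name, description, inputs and params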
validmind/tests/model_validation/ragas/AnswerCorrectness.py CHANGED
@@ -123,8 +123,10 @@ def AnswerCorrectness(

      score_column = "answer_correctness"

-     fig_histogram = px.histogram(x=result_df[score_column].to_list(), nbins=10)
-     fig_box = px.box(x=result_df[score_column].to_list())
+     fig_histogram = px.histogram(
+         x=result_df[score_column].to_list(), nbins=10, title="Answer Correctness"
+     )
+     fig_box = px.box(x=result_df[score_column].to_list(), title="Answer Correctness")

      return (
          {
validmind/tests/model_validation/ragas/ContextEntityRecall.py CHANGED
@@ -118,8 +118,10 @@ def ContextEntityRecall(

      score_column = "context_entity_recall"

-     fig_histogram = px.histogram(x=result_df[score_column].to_list(), nbins=10)
-     fig_box = px.box(x=result_df[score_column].to_list())
+     fig_histogram = px.histogram(
+         x=result_df[score_column].to_list(), nbins=10, title="Context Entity Recall"
+     )
+     fig_box = px.box(x=result_df[score_column].to_list(), title="Context Entity Recall")

      return (
          {
validmind/tests/model_validation/ragas/ContextPrecision.py CHANGED
@@ -114,8 +114,10 @@ def ContextPrecision(

      score_column = "llm_context_precision_with_reference"

-     fig_histogram = px.histogram(x=result_df[score_column].to_list(), nbins=10)
-     fig_box = px.box(x=result_df[score_column].to_list())
+     fig_histogram = px.histogram(
+         x=result_df[score_column].to_list(), nbins=10, title="Context Precision"
+     )
+     fig_box = px.box(x=result_df[score_column].to_list(), title="Context Precision")

      return (
          {
validmind/tests/model_validation/ragas/ContextPrecisionWithoutReference.py CHANGED
@@ -109,8 +109,10 @@ def ContextPrecisionWithoutReference(

      score_column = "llm_context_precision_without_reference"

-     fig_histogram = px.histogram(x=result_df[score_column].to_list(), nbins=10)
-     fig_box = px.box(x=result_df[score_column].to_list())
+     fig_histogram = px.histogram(
+         x=result_df[score_column].to_list(), nbins=10, title="Context Precision"
+     )
+     fig_box = px.box(x=result_df[score_column].to_list(), title="Context Precision")

      return (
          {
validmind/tests/model_validation/ragas/ContextRecall.py CHANGED
@@ -114,8 +114,10 @@ def ContextRecall(

      score_column = "context_recall"

-     fig_histogram = px.histogram(x=result_df[score_column].to_list(), nbins=10)
-     fig_box = px.box(x=result_df[score_column].to_list())
+     fig_histogram = px.histogram(
+         x=result_df[score_column].to_list(), nbins=10, title="Context Recall"
+     )
+     fig_box = px.box(x=result_df[score_column].to_list(), title="Context Recall")

      return (
          {
validmind/tests/model_validation/ragas/Faithfulness.py CHANGED
@@ -119,8 +119,10 @@ def Faithfulness(

      score_column = "faithfulness"

-     fig_histogram = px.histogram(x=result_df[score_column].to_list(), nbins=10)
-     fig_box = px.box(x=result_df[score_column].to_list())
+     fig_histogram = px.histogram(
+         x=result_df[score_column].to_list(), nbins=10, title="Faithfulness"
+     )
+     fig_box = px.box(x=result_df[score_column].to_list(), title="Faithfulness")

      return (
          {
validmind/tests/model_validation/ragas/ResponseRelevancy.py CHANGED
@@ -133,8 +133,10 @@ def ResponseRelevancy(

      score_column = "answer_relevancy"

-     fig_histogram = px.histogram(x=result_df[score_column].to_list(), nbins=10)
-     fig_box = px.box(x=result_df[score_column].to_list())
+     fig_histogram = px.histogram(
+         x=result_df[score_column].to_list(), nbins=10, title="Response Relevancy"
+     )
+     fig_box = px.box(x=result_df[score_column].to_list(), title="Response Relevancy")

      return (
          {
validmind/tests/model_validation/ragas/SemanticSimilarity.py CHANGED
@@ -112,8 +112,10 @@ def SemanticSimilarity(

      score_column = "semantic_similarity"

-     fig_histogram = px.histogram(x=result_df[score_column].to_list(), nbins=10)
-     fig_box = px.box(x=result_df[score_column].to_list())
+     fig_histogram = px.histogram(
+         x=result_df[score_column].to_list(), nbins=10, title="Semantic Similarity"
+     )
+     fig_box = px.box(x=result_df[score_column].to_list(), title="Semantic Similarity")

      return (
          {
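All eight RAGAS tests receive the same two-line change: the score histogram and box plot now carry chart titles. Outside of these tests the pattern is ordinary plotly.express with a title keyword; a minimal, self-contained sketch (the score values are made up for illustration):

    # Sketch of the charting pattern the ragas tests now use: histogram + box plot with a title
    import plotly.express as px

    scores = [0.91, 0.72, 0.88, 0.64, 0.95, 0.81]  # illustrative metric scores in [0, 1]

    fig_histogram = px.histogram(x=scores, nbins=10, title="Faithfulness")
    fig_box = px.box(x=scores, title="Faithfulness")
    # fig_histogram.show(); fig_box.show()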
validmind/tests/model_validation/sklearn/ClassifierThresholdOptimization.py CHANGED
@@ -2,6 +2,8 @@
  # See the LICENSE file in the root of this repository for details.
  # SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial

+ from typing import Dict, List, Optional, Union
+
  import numpy as np
  import pandas as pd
  import plotly.graph_objects as go
@@ -12,7 +14,12 @@ from validmind import RawData, tags, tasks
  from validmind.vm_models import VMDataset, VMModel


- def find_optimal_threshold(y_true, y_prob, method="youden", target_recall=None):
+ def find_optimal_threshold(
+     y_true: np.ndarray,
+     y_prob: np.ndarray,
+     method: str = "youden",
+     target_recall: Optional[float] = None,
+ ) -> Dict[str, Union[str, float]]:
      """
      Find the optimal classification threshold using various methods.

@@ -80,8 +87,11 @@ def find_optimal_threshold(y_true, y_prob, method="youden", target_recall=None):
  @tags("model_validation", "threshold_optimization", "classification_metrics")
  @tasks("classification")
  def ClassifierThresholdOptimization(
-     dataset: VMDataset, model: VMModel, methods=None, target_recall=None
- ):
+     dataset: VMDataset,
+     model: VMModel,
+     methods: Optional[List[str]] = None,
+     target_recall: Optional[float] = None,
+ ) -> Dict[str, Union[pd.DataFrame, go.Figure]]:
      """
      Analyzes and visualizes different threshold optimization methods for binary classification models.

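find_optimal_threshold defaults to method="youden". The general technique (illustrated below, not necessarily ValidMind's exact implementation) picks the ROC threshold that maximizes Youden's J = TPR - FPR; a hedged sketch using scikit-learn:

    # Illustrative sketch of Youden's J threshold selection
    import numpy as np
    from sklearn.metrics import roc_curve

    def youden_threshold(y_true, y_prob):
        # Youden's J = TPR - FPR; pick the ROC threshold that maximizes it
        fpr, tpr, thresholds = roc_curve(y_true, y_prob)
        return float(thresholds[np.argmax(tpr - fpr)])

    y_true = np.array([0, 0, 1, 1, 1, 0])
    y_prob = np.array([0.20, 0.35, 0.65, 0.80, 0.90, 0.40])
    print(youden_threshold(y_true, y_prob))  # a cut-off separating the two classes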
validmind/tests/model_validation/sklearn/OverfitDiagnosis.py CHANGED
@@ -73,6 +73,7 @@ def _prepare_results(
          columns={"shape": "training records", f"{metric}": f"training {metric}"},
          inplace=True,
      )
+     results["test records"] = results_test["shape"]
      results[f"test {metric}"] = results_test[metric]

      # Adjust gap calculation based on metric directionality
@@ -292,7 +293,8 @@ def OverfitDiagnosis(
              {
                  "Feature": feature_column,
                  "Slice": row["slice"],
-                 "Number of Records": row["training records"],
+                 "Number of Training Records": row["training records"],
+                 "Number of Test Records": row["test records"],
                  f"Training {metric.upper()}": row[f"training {metric}"],
                  f"Test {metric.upper()}": row[f"test {metric}"],
                  "Gap": row["gap"],
validmind/tests/model_validation/sklearn/SHAPGlobalImportance.py CHANGED
@@ -3,10 +3,12 @@
  # SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial

  import warnings
+ from typing import Dict, List, Optional, Union
  from warnings import filters as _warnings_filters

  import matplotlib.pyplot as plt
  import numpy as np
+ import pandas as pd
  import shap

  from validmind import RawData, tags, tasks
@@ -18,7 +20,10 @@ from validmind.vm_models import VMDataset, VMModel
  logger = get_logger(__name__)


- def select_shap_values(shap_values, class_of_interest):
+ def select_shap_values(
+     shap_values: Union[np.ndarray, List[np.ndarray]],
+     class_of_interest: Optional[int] = None,
+ ) -> np.ndarray:
      """Selects SHAP values for binary or multiclass classification.

      For regression models, returns the SHAP values directly as there are no classes.
@@ -41,32 +46,30 @@
      """
      if not isinstance(shap_values, list):
          # For regression, return the SHAP values as they are
-         # TODO: shap_values is always an array of all predictions, how is the if above supposed to work?
-         # logger.info("Returning SHAP values as-is.")
-         return shap_values
-
-     num_classes = len(shap_values)
-
-     # Default to class 1 for binary classification where no class is specified
-     if num_classes == 2 and class_of_interest is None:
-         logger.debug("Using SHAP values for class 1 (positive class).")
-         return shap_values[1]
+         selected_values = shap_values
+     else:
+         num_classes = len(shap_values)
+         # Default to class 1 for binary classification where no class is specified
+         if num_classes == 2 and class_of_interest is None:
+             selected_values = shap_values[1]
+         # Otherwise, use the specified class_of_interest
+         elif class_of_interest is not None and 0 <= class_of_interest < num_classes:
+             selected_values = shap_values[class_of_interest]
+         else:
+             raise ValueError(
+                 f"Invalid class_of_interest: {class_of_interest}. Must be between 0 and {num_classes - 1}."
+             )

-     # Otherwise, use the specified class_of_interest
-     if (
-         class_of_interest is None
-         or class_of_interest < 0
-         or class_of_interest >= num_classes
-     ):
-         raise ValueError(
-             f"Invalid class_of_interest: {class_of_interest}. Must be between 0 and {num_classes - 1}."
-         )
+     # Add type conversion here to ensure proper float array
+     if hasattr(selected_values, "dtype"):
+         selected_values = np.array(selected_values, dtype=np.float64)

-     logger.debug(f"Using SHAP values for class {class_of_interest}.")
-     return shap_values[class_of_interest]
+     return selected_values


- def generate_shap_plot(type_, shap_values, x_test):
+ def generate_shap_plot(
+     type_: str, shap_values: np.ndarray, x_test: Union[np.ndarray, pd.DataFrame]
+ ) -> plt.Figure:
      """Plots two types of SHAP global importance (SHAP).

      Args:
@@ -117,8 +120,8 @@ def SHAPGlobalImportance(
      dataset: VMDataset,
      kernel_explainer_samples: int = 10,
      tree_or_linear_explainer_samples: int = 200,
-     class_of_interest: int = None,
-     ):
+     class_of_interest: Optional[int] = None,
+ ) -> Dict[str, Union[plt.Figure, Dict[str, float]]]:
      """
      Evaluates and visualizes global feature importance using SHAP values for model explanation and risk identification.
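The refactored select_shap_values keeps the same selection rules, now with an explicit float64 conversion: a single array passes through (regression), a two-element list defaults to the positive class, and a longer list requires a valid class_of_interest. A small illustration with synthetic arrays, assuming the optional shap dependency is installed so the module imports:

    # Illustration of the selection rules shown above (synthetic arrays, not a real SHAP computation)
    import numpy as np
    from validmind.tests.model_validation.sklearn.SHAPGlobalImportance import select_shap_values

    regression_values = np.random.rand(5, 3)                   # single array -> returned as float64
    binary_values = [np.random.rand(5, 3) for _ in range(2)]   # list of 2 -> class 1 by default
    multiclass_values = [np.random.rand(5, 3) for _ in range(3)]

    assert select_shap_values(regression_values).dtype == np.float64
    assert np.allclose(select_shap_values(binary_values), binary_values[1])
    assert np.allclose(select_shap_values(multiclass_values, class_of_interest=2), multiclass_values[2])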