validmind 2.8.12__py3-none-any.whl → 2.8.22__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- validmind/__init__.py +6 -5
- validmind/__version__.py +1 -1
- validmind/ai/test_descriptions.py +13 -9
- validmind/ai/utils.py +2 -2
- validmind/api_client.py +75 -32
- validmind/client.py +111 -100
- validmind/client_config.py +3 -3
- validmind/datasets/classification/__init__.py +7 -3
- validmind/datasets/credit_risk/lending_club.py +28 -16
- validmind/datasets/nlp/cnn_dailymail.py +10 -4
- validmind/datasets/regression/__init__.py +22 -5
- validmind/errors.py +17 -7
- validmind/input_registry.py +1 -1
- validmind/logging.py +44 -35
- validmind/models/foundation.py +2 -2
- validmind/models/function.py +10 -3
- validmind/template.py +33 -24
- validmind/test_suites/__init__.py +2 -2
- validmind/tests/_store.py +13 -4
- validmind/tests/comparison.py +65 -33
- validmind/tests/data_validation/ClassImbalance.py +3 -1
- validmind/tests/data_validation/DatasetDescription.py +2 -23
- validmind/tests/data_validation/DescriptiveStatistics.py +1 -1
- validmind/tests/data_validation/Skewness.py +7 -6
- validmind/tests/decorator.py +14 -11
- validmind/tests/load.py +38 -24
- validmind/tests/model_validation/ragas/AnswerCorrectness.py +4 -2
- validmind/tests/model_validation/ragas/ContextEntityRecall.py +4 -2
- validmind/tests/model_validation/ragas/ContextPrecision.py +4 -2
- validmind/tests/model_validation/ragas/ContextPrecisionWithoutReference.py +4 -2
- validmind/tests/model_validation/ragas/ContextRecall.py +4 -2
- validmind/tests/model_validation/ragas/Faithfulness.py +4 -2
- validmind/tests/model_validation/ragas/ResponseRelevancy.py +4 -2
- validmind/tests/model_validation/ragas/SemanticSimilarity.py +4 -2
- validmind/tests/model_validation/sklearn/ClassifierThresholdOptimization.py +13 -3
- validmind/tests/model_validation/sklearn/OverfitDiagnosis.py +3 -1
- validmind/tests/model_validation/sklearn/SHAPGlobalImportance.py +28 -25
- validmind/tests/model_validation/sklearn/WeakspotsDiagnosis.py +15 -10
- validmind/tests/output.py +66 -11
- validmind/tests/run.py +28 -14
- validmind/tests/test_providers.py +28 -35
- validmind/tests/utils.py +17 -4
- validmind/unit_metrics/__init__.py +1 -1
- validmind/utils.py +295 -31
- validmind/vm_models/dataset/dataset.py +83 -43
- validmind/vm_models/dataset/utils.py +5 -3
- validmind/vm_models/figure.py +6 -6
- validmind/vm_models/input.py +6 -5
- validmind/vm_models/model.py +5 -5
- validmind/vm_models/result/result.py +122 -43
- validmind/vm_models/result/utils.py +5 -5
- validmind/vm_models/test_suite/__init__.py +5 -0
- validmind/vm_models/test_suite/runner.py +5 -5
- validmind/vm_models/test_suite/summary.py +20 -2
- validmind/vm_models/test_suite/test.py +6 -6
- validmind/vm_models/test_suite/test_suite.py +10 -10
- {validmind-2.8.12.dist-info → validmind-2.8.22.dist-info}/METADATA +3 -4
- {validmind-2.8.12.dist-info → validmind-2.8.22.dist-info}/RECORD +61 -60
- {validmind-2.8.12.dist-info → validmind-2.8.22.dist-info}/WHEEL +1 -1
- {validmind-2.8.12.dist-info → validmind-2.8.22.dist-info}/LICENSE +0 -0
- {validmind-2.8.12.dist-info → validmind-2.8.22.dist-info}/entry_points.txt +0 -0
validmind/tests/data_validation/ClassImbalance.py
CHANGED

```diff
@@ -14,7 +14,9 @@ from validmind.errors import SkipTestError
 from validmind.vm_models import VMDataset
 
 
-@tags(
+@tags(
+    "tabular_data", "binary_classification", "multiclass_classification", "data_quality"
+)
 @tasks("classification")
 def ClassImbalance(
     dataset: VMDataset, min_percent_threshold: int = 10
```
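The change above only adds the `data_quality` tag to the existing decorator. For context, a minimal sketch of how this test is typically run through the test harness; it assumes a pandas DataFrame `df` with a `target` column and the standard `vm.init_dataset` / `vm.tests.run_test` workflow:

```python
import validmind as vm

# Assumed setup: `df` is a pandas DataFrame with a binary/multiclass "target" column
vm_dataset = vm.init_dataset(dataset=df, target_column="target", input_id="raw_dataset")

# Run the data-quality check, flagging classes that fall below 10% of the rows
result = vm.tests.run_test(
    "validmind.data_validation.ClassImbalance",
    inputs={"dataset": vm_dataset},
    params={"min_percent_threshold": 10},
)
result.log()  # optionally send the result to the ValidMind platform
```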
validmind/tests/data_validation/DatasetDescription.py
CHANGED

```diff
@@ -6,12 +6,10 @@ import re
 from collections import Counter
 
 import numpy as np
-from ydata_profiling.config import Settings
-from ydata_profiling.model.typeset import ProfilingTypeSet
 
 from validmind import RawData, tags, tasks
-from validmind.errors import UnsupportedColumnTypeError
 from validmind.logging import get_logger
+from validmind.utils import infer_datatypes
 from validmind.vm_models import VMDataset
 
 DEFAULT_HISTOGRAM_BINS = 10
@@ -20,25 +18,6 @@ DEFAULT_HISTOGRAM_BIN_SIZES = [5, 10, 20, 50]
 logger = get_logger(__name__)
 
 
-def infer_datatypes(df):
-    column_type_mappings = {}
-    typeset = ProfilingTypeSet(Settings())
-    variable_types = typeset.infer_type(df)
-
-    for column, type in variable_types.items():
-        if str(type) == "Unsupported":
-            if df[column].isnull().all():
-                column_type_mappings[column] = {"id": column, "type": "Null"}
-            else:
-                raise UnsupportedColumnTypeError(
-                    f"Unsupported type for column {column}. Please review all values in this dataset column."
-                )
-        else:
-            column_type_mappings[column] = {"id": column, "type": str(type)}
-
-    return list(column_type_mappings.values())
-
-
 def get_numerical_histograms(df, column):
     """
     Returns a collection of histograms for a numerical column, each one
@@ -50,7 +29,7 @@ def get_numerical_histograms(df, column):
     # bins='sturges'. Cannot use 'auto' until we review and fix its performance
     # on datasets with too many unique values
     #
-    # 'sturges': R
+    # 'sturges': R's default method, only accounts for data size. Only optimal
     # for gaussian data and underestimates number of bins for large non-gaussian datasets.
     default_hist = np.histogram(values_cleaned, bins="sturges")
 
```
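The local `infer_datatypes` helper removed above is consolidated into `validmind.utils` (note the large `validmind/utils.py +295 -31` entry in the file list). The shared version is not shown in this diff; as a reference point, here is a sketch based on the implementation deleted from this file (variable names adjusted slightly, and the packaged helper may differ):

```python
from ydata_profiling.config import Settings
from ydata_profiling.model.typeset import ProfilingTypeSet

from validmind.errors import UnsupportedColumnTypeError


def infer_datatypes(df):
    """Infer a type for each column, returned as [{"id": <column>, "type": <type>}, ...]."""
    column_type_mappings = {}
    typeset = ProfilingTypeSet(Settings())
    variable_types = typeset.infer_type(df)

    for column, column_type in variable_types.items():
        if str(column_type) == "Unsupported":
            if df[column].isnull().all():
                # All-null columns are tolerated and typed as "Null"
                column_type_mappings[column] = {"id": column, "type": "Null"}
            else:
                raise UnsupportedColumnTypeError(
                    f"Unsupported type for column {column}. Please review all values in this dataset column."
                )
        else:
            column_type_mappings[column] = {"id": column, "type": str(column_type)}

    return list(column_type_mappings.values())
```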
validmind/tests/data_validation/DescriptiveStatistics.py
CHANGED

```diff
@@ -44,7 +44,7 @@ def get_summary_statistics_categorical(df, categorical_fields):
     return summary_stats
 
 
-@tags("tabular_data", "time_series_data")
+@tags("tabular_data", "time_series_data", "data_quality")
 @tasks("classification", "regression")
 def DescriptiveStatistics(dataset: VMDataset):
     """
```
validmind/tests/data_validation/Skewness.py
CHANGED

```diff
@@ -2,10 +2,8 @@
 # See the LICENSE file in the root of this repository for details.
 # SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
 
-from ydata_profiling.config import Settings
-from ydata_profiling.model.typeset import ProfilingTypeSet
-
 from validmind import tags, tasks
+from validmind.utils import infer_datatypes
 
 
 @tags("data_quality", "tabular_data")
@@ -49,8 +47,11 @@ def Skewness(dataset, max_threshold=1):
     - Subjective threshold for risk grading, requiring expert input and recurrent iterations for refinement.
     """
 
-
-    dataset_types =
+    # Use the imported infer_datatypes function
+    dataset_types = infer_datatypes(dataset.df)
+
+    # Convert the list of dictionaries to a dictionary for easy access
+    dataset_types_dict = {item["id"]: item["type"] for item in dataset_types}
 
     skewness = dataset.df.skew(numeric_only=True)
 
@@ -58,7 +59,7 @@ def Skewness(dataset, max_threshold=1):
     passed = True
 
     for col in skewness.index:
-        if
+        if dataset_types_dict.get(col) != "Numeric":
            continue
 
         col_skewness = skewness[col]
```
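Outside the test harness, the same type gating can be reproduced directly. A small sketch of the check this test performs, with the assumption (not shown in the hunk) that a column passes when its absolute skewness stays below `max_threshold`:

```python
import pandas as pd

from validmind.utils import infer_datatypes


def skewness_check(df: pd.DataFrame, max_threshold: float = 1):
    # Map each column to its inferred type, e.g. {"age": "Numeric", "city": "Categorical"}
    types = {item["id"]: item["type"] for item in infer_datatypes(df)}
    skewness = df.skew(numeric_only=True)

    results = {}
    for col in skewness.index:
        if types.get(col) != "Numeric":
            continue  # only grade columns inferred as Numeric
        # Assumed pass criterion: absolute skewness below the threshold
        results[col] = {
            "skewness": float(skewness[col]),
            "passed": abs(skewness[col]) < max_threshold,
        }
    return results
```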
validmind/tests/decorator.py
CHANGED
```diff
@@ -7,6 +7,7 @@
 import inspect
 import os
 from functools import wraps
+from typing import Any, Callable, List, Optional, TypeVar, Union
 
 from validmind.logging import get_logger
 
@@ -15,8 +16,10 @@ from .load import load_test
 
 logger = get_logger(__name__)
 
+F = TypeVar("F", bound=Callable[..., Any])
 
-def _get_save_func(func, test_id):
+
+def _get_save_func(func: Callable[..., Any], test_id: str) -> Callable[..., None]:
     """Helper function to save a decorated function to a file
 
     Useful when a custom test function has been created inline in a notebook or
@@ -29,7 +32,7 @@ def _get_save_func(func, test_id):
     # remove decorator line
     source = source.split("\n", 1)[1]
 
-    def save(root_folder=".", imports=None):
+    def save(root_folder: str = ".", imports: Optional[List[str]] = None) -> None:
         parts = test_id.split(".")
 
         if len(parts) > 1:
@@ -84,7 +87,7 @@ def _get_save_func(func, test_id):
     return save
 
 
-def test(func_or_id):
+def test(func_or_id: Union[Callable[..., Any], str, None]) -> Callable[[F], F]:
     """Decorator for creating and registering custom tests
 
     This decorator registers the function it wraps as a test function within ValidMind
@@ -109,14 +112,14 @@ def test(func_or_id):
     as the metric's description.
 
     Args:
-
-
+        func_or_id (Union[Callable[..., Any], str, None]): Either the function to decorate
+            or the test ID. If None, the function name is used.
 
     Returns:
-        The decorated function.
+        Callable[[F], F]: The decorated function.
     """
 
-    def decorator(func):
+    def decorator(func: F) -> F:
         test_id = func_or_id or f"validmind.custom_metrics.{func.__name__}"
         test_func = load_test(test_id, func, reload=True)
         test_store.register_test(test_id, test_func)
@@ -136,28 +139,28 @@ def test(func_or_id):
     return decorator
 
 
-def tasks(*tasks):
+def tasks(*tasks: str) -> Callable[[F], F]:
     """Decorator for specifying the task types that a test is designed for.
 
     Args:
         *tasks: The task types that the test is designed for.
     """
 
-    def decorator(func):
+    def decorator(func: F) -> F:
         func.__tasks__ = list(tasks)
         return func
 
     return decorator
 
 
-def tags(*tags):
+def tags(*tags: str) -> Callable[[F], F]:
     """Decorator for specifying tags for a test.
 
     Args:
         *tags: The tags to apply to the test.
     """
 
-    def decorator(func):
+    def decorator(func: F) -> F:
         func.__tags__ = list(tags)
         return func
 
```
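The typed decorators keep the same usage pattern as before. A minimal sketch of registering an inline custom test; the test ID, metric, and dataset accessors are illustrative, and the `.save()` call assumes the helper returned by `_get_save_func` is attached to the decorated function, as its docstring suggests:

```python
import numpy as np
import validmind as vm
from validmind import tags, tasks


@vm.test("my_custom_tests.MeanAbsoluteError")  # hypothetical test ID
@tags("tabular_data", "custom")
@tasks("regression")
def MeanAbsoluteError(model, dataset):
    """Computes the mean absolute error of the model on the given dataset."""
    return {"mae": float(np.mean(np.abs(dataset.y - dataset.y_pred(model))))}


# Persist the inline test to disk so a test provider can load it later
MeanAbsoluteError.save("./my_tests", imports=["import numpy as np"])
```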
validmind/tests/load.py
CHANGED
```diff
@@ -7,7 +7,7 @@
 import inspect
 import json
 from pprint import pformat
-from typing import List
+from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 from uuid import uuid4
 
 import pandas as pd
@@ -32,7 +32,10 @@ INPUT_TYPE_MAP = {
 }
 
 
-def _inspect_signature(test_func: callable):
+def _inspect_signature(
+    test_func: Callable[..., Any],
+) -> Tuple[Dict[str, Dict[str, Any]], Dict[str, Dict[str, Any]]]:
+    """Inspect a test function's signature to get inputs and parameters"""
     inputs = {}
     params = {}
 
@@ -56,7 +59,7 @@ def _inspect_signature(test_func: callable):
     return inputs, params
 
 
-def load_test(test_id: str, test_func: callable = None, reload: bool = False):
+def load_test(
+    test_id: str, test_func: Optional[Callable[..., Any]] = None, reload: bool = False
+) -> Callable[..., Any]:
     """Load a test by test ID
 
     Test IDs are in the format `namespace.path_to_module.TestClassOrFuncName[:tag]`.
@@ -67,6 +72,8 @@ def load_test(test_id: str, test_func: callable = None, reload: bool = False):
         test_id (str): The test ID in the format `namespace.path_to_module.TestName[:tag]`
         test_func (callable, optional): The test function to load. If not provided, the
             test will be loaded from the test provider. Defaults to None.
+        reload (bool, optional): If True, reload the test even if it's already loaded.
+            Defaults to False.
     """
     # remove tag if present
     test_id = test_id.split(":", 1)[0]
@@ -109,7 +116,8 @@ def load_test(test_id: str, test_func: callable = None, reload: bool = False):
     return test_store.get_test(test_id)
 
 
-def _list_test_ids():
+def _list_test_ids() -> List[str]:
+    """List all available test IDs"""
     test_ids = []
 
     for namespace, test_provider in test_provider_store.test_providers.items():
@@ -120,7 +128,7 @@ def _list_test_ids():
     return test_ids
 
 
-def _load_tests(test_ids):
+def _load_tests(test_ids: List[str]) -> Dict[str, Callable[..., Any]]:
     """Load a set of tests, handling missing dependencies."""
     tests = {}
 
@@ -138,12 +146,12 @@ def _load_tests(test_ids):
             logger.debug(str(e))
 
             if e.extra:
-                logger.
+                logger.debug(
                     f"Skipping `{test_id}` as it requires extra dependencies: {e.required_dependencies}."
                     f" Please run `pip install validmind[{e.extra}]` to view and run this test."
                 )
             else:
-                logger.
+                logger.debug(
                     f"Skipping `{test_id}` as it requires missing dependencies: {e.required_dependencies}."
                     " Please install the missing dependencies to view and run this test."
                 )
@@ -151,7 +159,8 @@ def _load_tests(test_ids):
     return tests
 
 
-def _test_description(test_description: str, num_lines: int = 5):
+def _test_description(test_description: str, num_lines: int = 5) -> str:
+    """Format a test description"""
     description = test_description.strip("\n").strip()
 
     if len(description.split("\n")) > num_lines:
@@ -160,7 +169,10 @@ def _test_description(test_description: str, num_lines: int = 5):
     return description
 
 
-def _pretty_list_tests(tests, truncate=True):
+def _pretty_list_tests(
+    tests: Dict[str, Callable[..., Any]], truncate: bool = True
+) -> None:
+    """Pretty print a list of tests"""
     table = [
         {
             "ID": test_id,
@@ -171,6 +183,8 @@ def _pretty_list_tests(tests, truncate=True):
             ),
             "Required Inputs": list(test.inputs.keys()),
             "Params": test.params,
+            "Tags": test.__tags__,
+            "Tasks": test.__tasks__,
         }
         for test_id, test in tests.items()
     ]
@@ -178,10 +192,8 @@ def _pretty_list_tests(tests, truncate=True):
     return format_dataframe(pd.DataFrame(table))
 
 
-def list_tags():
-    """
-    List unique tags from all test classes.
-    """
+def list_tags() -> List[str]:
+    """List all unique available tags"""
 
     unique_tags = set()
 
@@ -191,7 +203,7 @@ def list_tags():
     return list(unique_tags)
 
 
-def list_tasks_and_tags(as_json=False):
+def list_tasks_and_tags(as_json: bool = False) -> Union[str, Dict[str, List[str]]]:
     """
     List all task types and their associated tags, with one row per task type and
     all tags for a task type in one row.
@@ -218,11 +230,8 @@ def list_tasks_and_tags(as_json=False):
     )
 
 
-def list_tasks():
-    """
-    List unique tasks from all test classes.
-    """
-
+def list_tasks() -> List[str]:
+    """List all unique available tasks"""
     unique_tasks = set()
 
     for test in _load_tests(list_tests(pretty=False)).values():
@@ -231,7 +240,13 @@ def list_tasks():
     return list(unique_tasks)
 
 
-def list_tests(filter=None, task=None, tags=None, pretty=True, truncate=True):
+def list_tests(
+    filter: Optional[str] = None,
+    task: Optional[str] = None,
+    tags: Optional[List[str]] = None,
+    pretty: bool = True,
+    truncate: bool = True,
+) -> Union[List[str], None]:
     """List all tests in the tests directory.
 
     Args:
@@ -245,9 +260,6 @@ def list_tests(filter=None, task=None, tags=None, pretty=True, truncate=True):
             formatted table. Defaults to True.
         truncate (bool, optional): If True, truncates the test description to the first
             line. Defaults to True. (only used if pretty=True)
-
-    Returns:
-        list or pandas.DataFrame: A list of all tests or a formatted table.
     """
     test_ids = _list_test_ids()
 
@@ -286,7 +298,9 @@ def list_tests(filter=None, task=None, tags=None, pretty=True, truncate=True):
     return _pretty_list_tests(tests, truncate=truncate)
 
 
-def describe_test(
+def describe_test(
+    test_id: Optional[TestID] = None, raw: bool = False, show: bool = True
+) -> Union[str, HTML, Dict[str, Any]]:
     """Get or show details about the test
 
     This function can be used to see test details including the test name, description,
```
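The typed signatures above do not change how the discovery helpers are called. A small usage sketch, assuming these functions are re-exported from `validmind.tests` as in current releases; the filter string and tag name are illustrative:

```python
from validmind.tests import describe_test, list_tags, list_tasks_and_tags, list_tests

# Browse the catalogue as a formatted table, narrowed by a name filter, task, and tags
list_tests(filter="sklearn", task="classification", tags=["model_performance"])

# Or work with the raw metadata
all_tags = list_tags()
tasks_to_tags = list_tasks_and_tags(as_json=True)

# Inspect a single test's name, description, required inputs, and params
describe_test("validmind.model_validation.sklearn.SHAPGlobalImportance")
```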
validmind/tests/model_validation/ragas/AnswerCorrectness.py
CHANGED

```diff
@@ -123,8 +123,10 @@ def AnswerCorrectness(
 
     score_column = "answer_correctness"
 
-    fig_histogram = px.histogram(
-
+    fig_histogram = px.histogram(
+        x=result_df[score_column].to_list(), nbins=10, title="Answer Correctness"
+    )
+    fig_box = px.box(x=result_df[score_column].to_list(), title="Answer Correctness")
 
     return (
         {
```
validmind/tests/model_validation/ragas/ContextEntityRecall.py
CHANGED

```diff
@@ -118,8 +118,10 @@ def ContextEntityRecall(
 
     score_column = "context_entity_recall"
 
-    fig_histogram = px.histogram(
-
+    fig_histogram = px.histogram(
+        x=result_df[score_column].to_list(), nbins=10, title="Context Entity Recall"
+    )
+    fig_box = px.box(x=result_df[score_column].to_list(), title="Context Entity Recall")
 
     return (
         {
```
validmind/tests/model_validation/ragas/ContextPrecision.py
CHANGED

```diff
@@ -114,8 +114,10 @@ def ContextPrecision(
 
     score_column = "llm_context_precision_with_reference"
 
-    fig_histogram = px.histogram(
-
+    fig_histogram = px.histogram(
+        x=result_df[score_column].to_list(), nbins=10, title="Context Precision"
+    )
+    fig_box = px.box(x=result_df[score_column].to_list(), title="Context Precision")
 
     return (
         {
```
validmind/tests/model_validation/ragas/ContextPrecisionWithoutReference.py
CHANGED

```diff
@@ -109,8 +109,10 @@ def ContextPrecisionWithoutReference(
 
     score_column = "llm_context_precision_without_reference"
 
-    fig_histogram = px.histogram(
-
+    fig_histogram = px.histogram(
+        x=result_df[score_column].to_list(), nbins=10, title="Context Precision"
+    )
+    fig_box = px.box(x=result_df[score_column].to_list(), title="Context Precision")
 
     return (
         {
```
validmind/tests/model_validation/ragas/ContextRecall.py
CHANGED

```diff
@@ -114,8 +114,10 @@ def ContextRecall(
 
     score_column = "context_recall"
 
-    fig_histogram = px.histogram(
-
+    fig_histogram = px.histogram(
+        x=result_df[score_column].to_list(), nbins=10, title="Context Recall"
+    )
+    fig_box = px.box(x=result_df[score_column].to_list(), title="Context Recall")
 
     return (
         {
```
validmind/tests/model_validation/ragas/Faithfulness.py
CHANGED

```diff
@@ -119,8 +119,10 @@ def Faithfulness(
 
     score_column = "faithfulness"
 
-    fig_histogram = px.histogram(
-
+    fig_histogram = px.histogram(
+        x=result_df[score_column].to_list(), nbins=10, title="Faithfulness"
+    )
+    fig_box = px.box(x=result_df[score_column].to_list(), title="Faithfulness")
 
     return (
         {
```
validmind/tests/model_validation/ragas/ResponseRelevancy.py
CHANGED

```diff
@@ -133,8 +133,10 @@ def ResponseRelevancy(
 
     score_column = "answer_relevancy"
 
-    fig_histogram = px.histogram(
-
+    fig_histogram = px.histogram(
+        x=result_df[score_column].to_list(), nbins=10, title="Response Relevancy"
+    )
+    fig_box = px.box(x=result_df[score_column].to_list(), title="Response Relevancy")
 
     return (
         {
```
validmind/tests/model_validation/ragas/SemanticSimilarity.py
CHANGED

```diff
@@ -112,8 +112,10 @@ def SemanticSimilarity(
 
     score_column = "semantic_similarity"
 
-    fig_histogram = px.histogram(
-
+    fig_histogram = px.histogram(
+        x=result_df[score_column].to_list(), nbins=10, title="Semantic Similarity"
+    )
+    fig_box = px.box(x=result_df[score_column].to_list(), title="Semantic Similarity")
 
     return (
         {
```
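All eight ragas tests receive the same change: the single-line figure calls are rewritten as an explicit histogram plus box plot per score column. A standalone sketch of that pattern, with dummy scores standing in for the ragas `result_df`:

```python
import pandas as pd
import plotly.express as px

# Dummy scores standing in for the ragas evaluation result dataframe
result_df = pd.DataFrame({"faithfulness": [0.35, 0.6, 0.72, 0.88, 0.95, 1.0]})
score_column = "faithfulness"

fig_histogram = px.histogram(
    x=result_df[score_column].to_list(), nbins=10, title="Faithfulness"
)
fig_box = px.box(x=result_df[score_column].to_list(), title="Faithfulness")

fig_histogram.show()
fig_box.show()
```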
validmind/tests/model_validation/sklearn/ClassifierThresholdOptimization.py
CHANGED

```diff
@@ -2,6 +2,8 @@
 # See the LICENSE file in the root of this repository for details.
 # SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
 
+from typing import Dict, List, Optional, Union
+
 import numpy as np
 import pandas as pd
 import plotly.graph_objects as go
@@ -12,7 +14,12 @@ from validmind import RawData, tags, tasks
 from validmind.vm_models import VMDataset, VMModel
 
 
-def find_optimal_threshold(y_true, y_prob, method="youden", target_recall=None):
+def find_optimal_threshold(
+    y_true: np.ndarray,
+    y_prob: np.ndarray,
+    method: str = "youden",
+    target_recall: Optional[float] = None,
+) -> Dict[str, Union[str, float]]:
     """
     Find the optimal classification threshold using various methods.
 
@@ -80,8 +87,11 @@ def find_optimal_threshold(y_true, y_prob, method="youden", target_recall=None):
 @tags("model_validation", "threshold_optimization", "classification_metrics")
 @tasks("classification")
 def ClassifierThresholdOptimization(
-    dataset: VMDataset,
-
+    dataset: VMDataset,
+    model: VMModel,
+    methods: Optional[List[str]] = None,
+    target_recall: Optional[float] = None,
+) -> Dict[str, Union[pd.DataFrame, go.Figure]]:
     """
     Analyzes and visualizes different threshold optimization methods for binary classification models.
 
```
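The body of `find_optimal_threshold` is unchanged and not shown here. For reference, the standard formulation of the default "youden" method picks the ROC threshold that maximizes Youden's J = TPR - FPR; a sketch using scikit-learn, not necessarily identical to the library's implementation:

```python
import numpy as np
from sklearn.metrics import roc_curve


def youden_threshold(y_true: np.ndarray, y_prob: np.ndarray) -> float:
    """Return the probability threshold that maximizes Youden's J statistic."""
    fpr, tpr, thresholds = roc_curve(y_true, y_prob)
    return float(thresholds[np.argmax(tpr - fpr)])


y_true = np.array([0, 0, 0, 1, 1, 1])
y_prob = np.array([0.1, 0.4, 0.35, 0.45, 0.8, 0.9])
print(youden_threshold(y_true, y_prob))  # 0.45 for this toy data
```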
validmind/tests/model_validation/sklearn/OverfitDiagnosis.py
CHANGED

```diff
@@ -73,6 +73,7 @@ def _prepare_results(
         columns={"shape": "training records", f"{metric}": f"training {metric}"},
         inplace=True,
     )
+    results["test records"] = results_test["shape"]
     results[f"test {metric}"] = results_test[metric]
 
     # Adjust gap calculation based on metric directionality
@@ -292,7 +293,8 @@ def OverfitDiagnosis(
                 {
                     "Feature": feature_column,
                     "Slice": row["slice"],
-                    "Number of Records": row["training records"],
+                    "Number of Training Records": row["training records"],
+                    "Number of Test Records": row["test records"],
                     f"Training {metric.upper()}": row[f"training {metric}"],
                     f"Test {metric.upper()}": row[f"test {metric}"],
                     "Gap": row["gap"],
```
validmind/tests/model_validation/sklearn/SHAPGlobalImportance.py
CHANGED

```diff
@@ -3,10 +3,12 @@
 # SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
 
 import warnings
+from typing import Dict, List, Optional, Union
 from warnings import filters as _warnings_filters
 
 import matplotlib.pyplot as plt
 import numpy as np
+import pandas as pd
 import shap
 
 from validmind import RawData, tags, tasks
@@ -18,7 +20,10 @@ from validmind.vm_models import VMDataset, VMModel
 logger = get_logger(__name__)
 
 
-def select_shap_values(shap_values, class_of_interest):
+def select_shap_values(
+    shap_values: Union[np.ndarray, List[np.ndarray]],
+    class_of_interest: Optional[int] = None,
+) -> np.ndarray:
     """Selects SHAP values for binary or multiclass classification.
 
     For regression models, returns the SHAP values directly as there are no classes.
@@ -41,32 +46,30 @@ def select_shap_values(shap_values, class_of_interest):
     """
     if not isinstance(shap_values, list):
         # For regression, return the SHAP values as they are
-
-
-
-
-
-
-
-
-
-
+        selected_values = shap_values
+    else:
+        num_classes = len(shap_values)
+        # Default to class 1 for binary classification where no class is specified
+        if num_classes == 2 and class_of_interest is None:
+            selected_values = shap_values[1]
+        # Otherwise, use the specified class_of_interest
+        elif class_of_interest is not None and 0 <= class_of_interest < num_classes:
+            selected_values = shap_values[class_of_interest]
+        else:
+            raise ValueError(
+                f"Invalid class_of_interest: {class_of_interest}. Must be between 0 and {num_classes - 1}."
+            )
 
-    #
-    if (
-
-        or class_of_interest < 0
-        or class_of_interest >= num_classes
-    ):
-        raise ValueError(
-            f"Invalid class_of_interest: {class_of_interest}. Must be between 0 and {num_classes - 1}."
-        )
+    # Add type conversion here to ensure proper float array
+    if hasattr(selected_values, "dtype"):
+        selected_values = np.array(selected_values, dtype=np.float64)
 
-
-    return shap_values[class_of_interest]
+    return selected_values
 
 
-def generate_shap_plot(
+def generate_shap_plot(
+    type_: str, shap_values: np.ndarray, x_test: Union[np.ndarray, pd.DataFrame]
+) -> plt.Figure:
     """Plots two types of SHAP global importance (SHAP).
 
     Args:
@@ -117,8 +120,8 @@ def SHAPGlobalImportance(
     dataset: VMDataset,
     kernel_explainer_samples: int = 10,
     tree_or_linear_explainer_samples: int = 200,
-    class_of_interest: int = None,
-):
+    class_of_interest: Optional[int] = None,
+) -> Dict[str, Union[plt.Figure, Dict[str, float]]]:
     """
     Evaluates and visualizes global feature importance using SHAP values for model explanation and risk identification.
 
```
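A quick usage sketch of the rewritten `select_shap_values` helper; the import path mirrors the module's location in the wheel, and the behavior shown follows the new code above:

```python
import numpy as np

from validmind.tests.model_validation.sklearn.SHAPGlobalImportance import (
    select_shap_values,
)

# Classifier-style output: one SHAP array per class
shap_values_per_class = [np.zeros((5, 3)), np.ones((5, 3))]

pos_class = select_shap_values(shap_values_per_class)  # defaults to class 1
neg_class = select_shap_values(shap_values_per_class, class_of_interest=0)

# Regression-style output (a single array) is passed through, cast to float64
reg_values = select_shap_values(np.random.rand(5, 3))
print(pos_class.dtype, reg_values.dtype)  # float64 float64
```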