validmind 2.8.12__py3-none-any.whl → 2.8.22__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. validmind/__init__.py +6 -5
  2. validmind/__version__.py +1 -1
  3. validmind/ai/test_descriptions.py +13 -9
  4. validmind/ai/utils.py +2 -2
  5. validmind/api_client.py +75 -32
  6. validmind/client.py +111 -100
  7. validmind/client_config.py +3 -3
  8. validmind/datasets/classification/__init__.py +7 -3
  9. validmind/datasets/credit_risk/lending_club.py +28 -16
  10. validmind/datasets/nlp/cnn_dailymail.py +10 -4
  11. validmind/datasets/regression/__init__.py +22 -5
  12. validmind/errors.py +17 -7
  13. validmind/input_registry.py +1 -1
  14. validmind/logging.py +44 -35
  15. validmind/models/foundation.py +2 -2
  16. validmind/models/function.py +10 -3
  17. validmind/template.py +33 -24
  18. validmind/test_suites/__init__.py +2 -2
  19. validmind/tests/_store.py +13 -4
  20. validmind/tests/comparison.py +65 -33
  21. validmind/tests/data_validation/ClassImbalance.py +3 -1
  22. validmind/tests/data_validation/DatasetDescription.py +2 -23
  23. validmind/tests/data_validation/DescriptiveStatistics.py +1 -1
  24. validmind/tests/data_validation/Skewness.py +7 -6
  25. validmind/tests/decorator.py +14 -11
  26. validmind/tests/load.py +38 -24
  27. validmind/tests/model_validation/ragas/AnswerCorrectness.py +4 -2
  28. validmind/tests/model_validation/ragas/ContextEntityRecall.py +4 -2
  29. validmind/tests/model_validation/ragas/ContextPrecision.py +4 -2
  30. validmind/tests/model_validation/ragas/ContextPrecisionWithoutReference.py +4 -2
  31. validmind/tests/model_validation/ragas/ContextRecall.py +4 -2
  32. validmind/tests/model_validation/ragas/Faithfulness.py +4 -2
  33. validmind/tests/model_validation/ragas/ResponseRelevancy.py +4 -2
  34. validmind/tests/model_validation/ragas/SemanticSimilarity.py +4 -2
  35. validmind/tests/model_validation/sklearn/ClassifierThresholdOptimization.py +13 -3
  36. validmind/tests/model_validation/sklearn/OverfitDiagnosis.py +3 -1
  37. validmind/tests/model_validation/sklearn/SHAPGlobalImportance.py +28 -25
  38. validmind/tests/model_validation/sklearn/WeakspotsDiagnosis.py +15 -10
  39. validmind/tests/output.py +66 -11
  40. validmind/tests/run.py +28 -14
  41. validmind/tests/test_providers.py +28 -35
  42. validmind/tests/utils.py +17 -4
  43. validmind/unit_metrics/__init__.py +1 -1
  44. validmind/utils.py +295 -31
  45. validmind/vm_models/dataset/dataset.py +83 -43
  46. validmind/vm_models/dataset/utils.py +5 -3
  47. validmind/vm_models/figure.py +6 -6
  48. validmind/vm_models/input.py +6 -5
  49. validmind/vm_models/model.py +5 -5
  50. validmind/vm_models/result/result.py +122 -43
  51. validmind/vm_models/result/utils.py +5 -5
  52. validmind/vm_models/test_suite/__init__.py +5 -0
  53. validmind/vm_models/test_suite/runner.py +5 -5
  54. validmind/vm_models/test_suite/summary.py +20 -2
  55. validmind/vm_models/test_suite/test.py +6 -6
  56. validmind/vm_models/test_suite/test_suite.py +10 -10
  57. {validmind-2.8.12.dist-info → validmind-2.8.22.dist-info}/METADATA +3 -4
  58. {validmind-2.8.12.dist-info → validmind-2.8.22.dist-info}/RECORD +61 -60
  59. {validmind-2.8.12.dist-info → validmind-2.8.22.dist-info}/WHEEL +1 -1
  60. {validmind-2.8.12.dist-info → validmind-2.8.22.dist-info}/LICENSE +0 -0
  61. {validmind-2.8.12.dist-info → validmind-2.8.22.dist-info}/entry_points.txt +0 -0
validmind/errors.py CHANGED
@@ -15,6 +15,8 @@ from typing import Optional
15
15
 
16
16
 
17
17
  class BaseError(Exception):
18
+ """Common base class for all non-exit exceptions."""
19
+
18
20
  def __init__(self, message=""):
19
21
  self.message = message
20
22
  super().__init__(self.message)
@@ -52,7 +54,7 @@ class MissingCacheResultsArgumentsError(BaseError):
52
54
 
53
55
  class MissingOrInvalidModelPredictFnError(BaseError):
54
56
  """
55
- When the pytorch model is missing a predict function or its predict
57
+ When the PyTorch model is missing a predict function or its predict
56
58
  method does not have the expected arguments.
57
59
  """
58
60
 
@@ -71,7 +73,7 @@ class InvalidAPICredentialsError(APIRequestError):
71
73
  def description(self, *args, **kwargs):
72
74
  return (
73
75
  self.message
74
- or "Invalid API credentials. Please ensure that you have provided the correct values for api_key and api_secret."
76
+ or "Invalid API credentials. Please ensure that you have provided the correct values for API_KEY and API_SECRET."
75
77
  )
76
78
 
77
79
 
@@ -115,7 +117,7 @@ class InvalidTestResultsError(APIRequestError):
115
117
 
116
118
  class InvalidTestParametersError(BaseError):
117
119
  """
118
- When an invalid parameters for the test.
120
+ When invalid parameters are provided for the test.
119
121
  """
120
122
 
121
123
  pass
@@ -123,7 +125,15 @@ class InvalidTestParametersError(BaseError):
123
125
 
124
126
  class InvalidInputError(BaseError):
125
127
  """
126
- When an invalid input object.
128
+ When an invalid input object is provided.
129
+ """
130
+
131
+ pass
132
+
133
+
134
+ class InvalidParameterError(BaseError):
135
+ """
136
+ When an invalid parameter is provided.
127
137
  """
128
138
 
129
139
  pass
@@ -131,7 +141,7 @@ class InvalidInputError(BaseError):
131
141
 
132
142
  class InvalidTextObjectError(APIRequestError):
133
143
  """
134
- When an invalid Metadat (Text) object is sent to the API.
144
+ When an invalid Metadata (Text) object is sent to the API.
135
145
  """
136
146
 
137
147
  pass
@@ -155,7 +165,7 @@ class InvalidXGBoostTrainedModelError(BaseError):
155
165
 
156
166
  class LoadTestError(BaseError):
157
167
  """
158
- Exception raised when an error occurs while loading a test
168
+ Exception raised when an error occurs while loading a test.
159
169
  """
160
170
 
161
171
  def __init__(self, message: str, original_error: Optional[Exception] = None):
@@ -323,7 +333,7 @@ class SkipTestError(BaseError):
323
333
  def raise_api_error(error_string):
324
334
  """
325
335
  Safely try to parse JSON from the response message in case the API
326
- returns a non-JSON string or if the API returns a non-standard error
336
+ returns a non-JSON string or if the API returns a non-standard error.
327
337
  """
328
338
  try:
329
339
  json_response = json.loads(error_string)
@@ -29,7 +29,7 @@ class InputRegistry:
29
29
  if not input_obj:
30
30
  raise InvalidInputError(
31
31
  f"There's no such input with given ID '{key}'. "
32
- "Please pass valid input ID"
32
+ "Please pass valid input ID."
33
33
  )
34
34
  return input_obj
35
35
 
validmind/logging.py CHANGED
@@ -2,11 +2,12 @@
2
2
  # See the LICENSE file in the root of this repository for details.
3
3
  # SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
4
4
 
5
- """ValidMind logging module."""
5
+ """ValidMind logging module"""
6
6
 
7
7
  import logging
8
8
  import os
9
9
  import time
10
+ from typing import Any, Awaitable, Callable, Dict, Optional, TypeVar
10
11
 
11
12
  import sentry_sdk
12
13
  from sentry_sdk.utils import event_from_exception, exc_info_from_error
@@ -16,8 +17,8 @@ from .__version__ import __version__
16
17
  __dsn = "https://48f446843657444aa1e2c0d716ef864b@o1241367.ingest.sentry.io/4505239625465856"
17
18
 
18
19
 
19
- def _get_log_level():
20
- """Get the log level from the environment variable"""
20
+ def _get_log_level() -> int:
21
+ """Get the log level from the environment variable."""
21
22
  log_level_str = os.getenv("LOG_LEVEL", "INFO").upper()
22
23
 
23
24
  if log_level_str not in ["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]:
@@ -26,8 +27,10 @@ def _get_log_level():
26
27
  return logging.getLevelName(log_level_str)
27
28
 
28
29
 
29
- def get_logger(name="validmind", log_level=None):
30
- """Get a logger for the given module name"""
30
+ def get_logger(
31
+ name: str = "validmind", log_level: Optional[int] = None
32
+ ) -> logging.Logger:
33
+ """Get a logger for the given module name."""
31
34
  formatter = logging.Formatter(
32
35
  fmt="%(asctime)s - %(levelname)s(%(name)s): %(message)s"
33
36
  )
@@ -52,18 +55,21 @@ def get_logger(name="validmind", log_level=None):
52
55
  return logger
53
56
 
54
57
 
55
- def init_sentry(server_config):
56
- """Initialize Sentry SDK for sending logs back to ValidMind
58
+ def init_sentry(server_config: Dict[str, Any]) -> None:
59
+ """Initialize Sentry SDK for sending logs back to ValidMind.
57
60
 
58
- This will usually only be called by the api_client module to initialize the
59
- sentry connection after the user calls `validmind.init()`. This is because the DSN
61
+ This will usually only be called by the API client module to initialize the
62
+ Sentry connection after the user calls `validmind.init()`. This is because the DSN
60
63
  and other config options will be returned by the API.
61
64
 
62
65
  Args:
63
- config (dict): The config dictionary returned by the API
64
- - send_logs (bool): Whether to send logs to Sentry (gets removed)
65
- - dsn (str): The Sentry DSN
66
- ...: Other config options for Sentry
66
+ server_config (Dict[str, Any]): The config dictionary returned by the API.
67
+ - send_logs (bool): Whether to send logs to Sentry (gets removed).
68
+ - dsn (str): The Sentry DSN.
69
+ ...: Other config options for Sentry.
70
+
71
+ Returns:
72
+ None.
67
73
  """
68
74
  if os.getenv("VM_NO_TELEMETRY", False):
69
75
  return
@@ -88,19 +94,27 @@ def init_sentry(server_config):
88
94
  logger.debug(f"Sentry error: {str(e)}")
89
95
 
90
96
 
91
- def log_performance(name=None, logger=None, force=False):
92
- """Decorator to log the time it takes to run a function
97
+ F = TypeVar("F", bound=Callable[..., Any])
98
+ AF = TypeVar("AF", bound=Callable[..., Awaitable[Any]])
99
+
100
+
101
+ def log_performance(
102
+ name: Optional[str] = None,
103
+ logger: Optional[logging.Logger] = None,
104
+ force: bool = False,
105
+ ) -> Callable[[F], F]:
106
+ """Decorator to log the time it takes to run a function.
93
107
 
94
108
  Args:
95
109
  name (str, optional): The name of the function. Defaults to None.
96
110
  logger (logging.Logger, optional): The logger to use. Defaults to None.
97
- force (bool, optional): Whether to force logging even if env var is off
111
+ force (bool, optional): Whether to force logging even if env var is off.
98
112
 
99
113
  Returns:
100
- function: The decorated function
114
+ Callable: The decorated function.
101
115
  """
102
116
 
103
- def decorator(func):
117
+ def decorator(func: F) -> F:
104
118
  # check if log level is set to debug
105
119
  if _get_log_level() != logging.DEBUG and not force:
106
120
  return func
@@ -113,7 +127,7 @@ def log_performance(name=None, logger=None, force=False):
113
127
  if name is None:
114
128
  name = func.__name__
115
129
 
116
- def wrapped(*args, **kwargs):
130
+ def wrapped(*args: Any, **kwargs: Any) -> Any:
117
131
  time1 = time.perf_counter()
118
132
  return_val = func(*args, **kwargs)
119
133
  time2 = time.perf_counter()
@@ -127,18 +141,13 @@ def log_performance(name=None, logger=None, force=False):
127
141
  return decorator
128
142
 
129
143
 
130
- async def log_performance_async(func, name=None, logger=None, force=False):
131
- """Decorator to log the time it takes to run an async function
132
-
133
- Args:
134
- func (function): The function to decorate
135
- name (str, optional): The name of the function. Defaults to None.
136
- logger (logging.Logger, optional): The logger to use. Defaults to None.
137
- force (bool, optional): Whether to force logging even if env var is off
138
-
139
- Returns:
140
- function: The decorated function
141
- """
144
+ async def log_performance_async(
145
+ func: AF,
146
+ name: Optional[str] = None,
147
+ logger: Optional[logging.Logger] = None,
148
+ force: bool = False,
149
+ ) -> AF:
150
+ """Async version of log_performance decorator"""
142
151
  # check if log level is set to debug
143
152
  if _get_log_level() != logging.DEBUG and not force:
144
153
  return func
@@ -149,7 +158,7 @@ async def log_performance_async(func, name=None, logger=None, force=False):
149
158
  if name is None:
150
159
  name = func.__name__
151
160
 
152
- async def wrap(*args, **kwargs):
161
+ async def wrap(*args: Any, **kwargs: Any) -> Any:
153
162
  time1 = time.perf_counter()
154
163
  return_val = await func(*args, **kwargs)
155
164
  time2 = time.perf_counter()
@@ -161,11 +170,11 @@ async def log_performance_async(func, name=None, logger=None, force=False):
161
170
  return wrap
162
171
 
163
172
 
164
- def send_single_error(error: Exception):
165
- """Send a single error to Sentry
173
+ def send_single_error(error: Exception) -> None:
174
+ """Send a single error to Sentry.
166
175
 
167
176
  Args:
168
- error (Exception): The exception to send
177
+ error (Exception): The exception to send.
169
178
  """
170
179
  event, hint = event_from_exception(exc_info_from_error(error))
171
180
  client = sentry_sdk.Client(__dsn, release=f"validmind-python@{__version__}")
@@ -26,9 +26,9 @@ class FoundationModel(FunctionModel):
26
26
 
27
27
  Attributes:
28
28
  predict_fn (callable): The predict function that should take a prompt as input
29
- and return the result from the model
29
+ and return the result from the model
30
30
  prompt (Prompt): The prompt object that defines the prompt template and the
31
- variables (if any)
31
+ variables (if any)
32
32
  name (str, optional): The name of the model. Defaults to name of the predict_fn
33
33
  """
34
34
 
@@ -2,6 +2,8 @@
2
2
  # See the LICENSE file in the root of this repository for details.
3
3
  # SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
4
4
 
5
+ from typing import Any, Dict, List
6
+
5
7
  from validmind.vm_models.model import VMModel
6
8
 
7
9
 
@@ -18,7 +20,12 @@ class Input(dict):
18
20
  def __delitem__(self, _):
19
21
  raise TypeError("Cannot delete keys from Input")
20
22
 
21
- def get_new(self):
23
+ def get_new(self) -> Dict[str, Any]:
24
+ """Get the newly added key-value pairs.
25
+
26
+ Returns:
27
+ Dict[str, Any]: Dictionary containing only the newly added key-value pairs.
28
+ """
22
29
  return {k: self[k] for k in self._new}
23
30
 
24
31
 
@@ -41,13 +48,13 @@ class FunctionModel(VMModel):
41
48
 
42
49
  self.name = self.name or self.predict_fn.__name__
43
50
 
44
- def predict(self, X):
51
+ def predict(self, X) -> List[Any]:
45
52
  """Compute predictions for the input (X)
46
53
 
47
54
  Args:
48
55
  X (pandas.DataFrame): The input features to predict on
49
56
 
50
57
  Returns:
51
- list: The predictions
58
+ List[Any]: The predictions
52
59
  """
53
60
  return [self.predict_fn(x) for x in X.to_dict(orient="records")]
validmind/template.py CHANGED
@@ -2,7 +2,9 @@
2
2
  # See the LICENSE file in the root of this repository for details.
3
3
  # SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
4
4
 
5
- from ipywidgets import HTML, Accordion, VBox
5
+ from typing import Any, Dict, List, Optional, Type, Union
6
+
7
+ from ipywidgets import HTML, Accordion, VBox, Widget
6
8
 
7
9
  from .html_templates.content_blocks import (
8
10
  failed_content_block_html,
@@ -29,8 +31,10 @@ CONTENT_TYPE_MAP = {
29
31
 
30
32
 
31
33
  def _convert_sections_to_section_tree(
32
- sections, parent_id="_root_", start_section_id=None
33
- ):
34
+ sections: List[Dict[str, Any]],
35
+ parent_id: str = "_root_",
36
+ start_section_id: Optional[str] = None,
37
+ ) -> List[Dict[str, Any]]:
34
38
  section_tree = []
35
39
 
36
40
  for section in sections:
@@ -49,11 +53,12 @@ def _convert_sections_to_section_tree(
49
53
 
50
54
  if start_section_id and not section_tree:
51
55
  raise ValueError(f"Section {start_section_id} not found in template")
52
-
53
- return sorted(section_tree, key=lambda x: x.get("order", 0))
56
+ # sort the section tree by the order of the sections in the template (if provided)
57
+ # set the order to 9999 for the sections that do not have an order
58
+ return sorted(section_tree, key=lambda x: x.get("order", 9999))
54
59
 
55
60
 
56
- def _create_content_widget(content):
61
+ def _create_content_widget(content: Dict[str, Any]) -> Widget:
57
62
  content_type = CONTENT_TYPE_MAP[content["content_type"]]
58
63
 
59
64
  if content["content_type"] not in ["metric", "test"]:
@@ -75,7 +80,9 @@ def _create_content_widget(content):
75
80
  )
76
81
 
77
82
 
78
- def _create_sub_section_widget(sub_sections, section_number):
83
+ def _create_sub_section_widget(
84
+ sub_sections: List[Dict[str, Any]], section_number: str
85
+ ) -> Union[HTML, Accordion]:
79
86
  if not sub_sections:
80
87
  return HTML("<p>Empty Section</p>")
81
88
 
@@ -111,7 +118,7 @@ def _create_sub_section_widget(sub_sections, section_number):
111
118
  return accordion
112
119
 
113
120
 
114
- def _create_section_widget(tree):
121
+ def _create_section_widget(tree: List[Dict[str, Any]]) -> Accordion:
115
122
  widget = Accordion()
116
123
  for i, section in enumerate(tree):
117
124
  sub_widget = None
@@ -139,11 +146,11 @@ def _create_section_widget(tree):
139
146
  return widget
140
147
 
141
148
 
142
- def preview_template(template):
143
- """Preview a template in Jupyter Notebook
149
+ def preview_template(template: str) -> None:
150
+ """Preview a template in Jupyter Notebook.
144
151
 
145
152
  Args:
146
- template (dict): The template to preview
153
+ template (dict): The template to preview.
147
154
  """
148
155
  if not is_notebook():
149
156
  logger.warning("preview_template() only works in Jupyter Notebook")
@@ -154,7 +161,7 @@ def preview_template(template):
154
161
  )
155
162
 
156
163
 
157
- def _get_section_tests(section):
164
+ def _get_section_tests(section: Dict[str, Any]) -> List[str]:
158
165
  """
159
166
  Get all the tests in a section and its subsections.
160
167
 
@@ -179,15 +186,15 @@ def _get_section_tests(section):
179
186
  return tests
180
187
 
181
188
 
182
- def _create_test_suite_section(section):
189
+ def _create_test_suite_section(section: Dict[str, Any]) -> Dict[str, Any]:
183
190
  """Create a section object for a test suite that contains the tests in a section
184
- in the template
191
+ in the template.
185
192
 
186
193
  Args:
187
- section: a section of a template (in tree form)
194
+ section: A section of a template (in tree form).
188
195
 
189
196
  Returns:
190
- A TestSuite section dict
197
+ A TestSuite section dict.
191
198
  """
192
199
  if section_tests := _get_section_tests(section):
193
200
  return {
@@ -197,16 +204,18 @@ def _create_test_suite_section(section):
197
204
  }
198
205
 
199
206
 
200
- def _create_template_test_suite(template, section=None):
207
+ def _create_template_test_suite(
208
+ template: str, section: Optional[str] = None
209
+ ) -> Type[TestSuite]:
201
210
  """
202
211
  Create and run a test suite from a template.
203
212
 
204
213
  Args:
205
- template: A valid flat template
206
- section: The section of the template to run (if not provided, run all sections)
214
+ template: A valid flat template.
215
+ section: The section of the template to run. Runs all sections if not provided.
207
216
 
208
217
  Returns:
209
- A dynamically-create TestSuite Class
218
+ A dynamically-created TestSuite Class.
210
219
  """
211
220
  section_tree = _convert_sections_to_section_tree(
212
221
  sections=template["sections"],
@@ -229,17 +238,17 @@ def _create_template_test_suite(template, section=None):
229
238
  )
230
239
 
231
240
 
232
- def get_template_test_suite(template, section=None):
233
- """Get a TestSuite instance containing all tests in a template
241
+ def get_template_test_suite(template: str, section: Optional[str] = None) -> TestSuite:
242
+ """Get a TestSuite instance containing all tests in a template.
234
243
 
235
244
  This function will collect all tests used in a template into a dynamically-created
236
- TestSuite object
245
+ TestSuite object.
237
246
 
238
247
  Args:
239
248
  template: A valid flat template
240
249
  section: The section of the template to run (if not provided, run all sections)
241
250
 
242
251
  Returns:
243
- The TestSuite instance
252
+ The TestSuite instance.
244
253
  """
245
254
  return _create_template_test_suite(template, section)()
@@ -141,7 +141,7 @@ def list_suites(pretty: bool = True):
141
141
  return format_dataframe(pd.DataFrame(table))
142
142
 
143
143
 
144
- def describe_suite(test_suite_id: str, verbose=False):
144
+ def describe_suite(test_suite_id: str, verbose: bool = False) -> pd.DataFrame:
145
145
  """
146
146
  Describes a Test Suite by ID
147
147
 
@@ -150,7 +150,7 @@ def describe_suite(test_suite_id: str, verbose=False):
150
150
  verbose: If True, describe all plans and tests in the Test Suite
151
151
 
152
152
  Returns:
153
- pandas.DataFrame: A formatted table with the Test Suite description
153
+ pd.DataFrame: A formatted table with the Test Suite description
154
154
  """
155
155
  test_suite = get_by_id(test_suite_id)
156
156
 
validmind/tests/_store.py CHANGED
@@ -5,6 +5,8 @@
5
5
  """Module for storing loaded tests and test providers"""
6
6
 
7
7
 
8
+ from typing import Any, Callable, Optional
9
+
8
10
  from .test_providers import TestProvider, ValidMindTestProvider
9
11
 
10
12
 
@@ -65,19 +67,26 @@ class TestStore:
65
67
  def __init__(self):
66
68
  self.tests = {}
67
69
 
68
- def get_test(self, test_id: str):
70
+ def get_test(self, test_id: str) -> Optional[Callable[..., Any]]:
69
71
  """Get a test by test ID
70
72
 
71
73
  Args:
72
74
  test_id (str): The test ID
73
75
 
74
76
  Returns:
75
- object: The test class or function
77
+ Optional[Callable[..., Any]]: The test function if found, None otherwise
76
78
  """
77
79
  return self.tests.get(test_id)
78
80
 
79
- def register_test(self, test_id: str, test: object = None):
80
- """Register a test"""
81
+ def register_test(
82
+ self, test_id: str, test: Optional[Callable[..., Any]] = None
83
+ ) -> None:
84
+ """Register a test
85
+
86
+ Args:
87
+ test_id (str): The test ID
88
+ test (Optional[Callable[..., Any]], optional): The test function. Defaults to None.
89
+ """
81
90
  self.tests[test_id] = test
82
91
 
83
92
 
@@ -146,7 +146,9 @@ def _combine_tables(results: List[TestResult]) -> List[pd.DataFrame]:
146
146
  return [_combine_single_table(results, i) for i in range(len(results[0].tables))]
147
147
 
148
148
 
149
- def _build_input_param_string(result: TestResult, results: List[TestResult]) -> str:
149
+ def _build_input_param_string(
150
+ result: TestResult, results: List[TestResult], show_params: bool
151
+ ) -> str:
150
152
  """Build a string repr of unique inputs + params for a figure title"""
151
153
  parts = []
152
154
  unique_inputs = _get_unique_inputs(results)
@@ -162,19 +164,29 @@ def _build_input_param_string(result: TestResult, results: List[TestResult]) ->
162
164
  input_val = _get_input_key(input_obj)
163
165
  parts.append(f"{input_name}={input_val}")
164
166
 
165
- # TODO: revisit this when we can create a value/title to show for params
166
- # unique_params = _get_unique_params(results)
167
- # # if theres only one unique value for a param, don't show it
168
- # # however, if there is only one unique value for all params then show it as
169
- # # long as there is no existing inputs in the parts list
170
- # if result.params:
171
- # should_show = (
172
- # all(len(unique_params[param_name]) == 1 for param_name in unique_params)
173
- # and not parts
174
- # )
175
- # for param_name, param_value in result.params.items():
176
- # if should_show or len(unique_params[param_name]) > 1:
177
- # parts.append(f"{param_name}={param_value}")
167
+ # Handle params if show_params is enabled
168
+ if show_params and result.params:
169
+ unique_params = _get_unique_params(results)
170
+ # If there's only one unique value for a param, don't show it
171
+ # unless there is only one unique value for all params and no inputs shown
172
+ should_show = (
173
+ all(len(unique_params[param_name]) == 1 for param_name in unique_params)
174
+ and not parts
175
+ )
176
+ for param_name, param_value in result.params.items():
177
+ if should_show or len(unique_params[param_name]) > 1:
178
+ # Convert the param_value to a string representation
179
+ if isinstance(param_value, list):
180
+ # For lists, join elements with commas
181
+ str_value = ",".join(str(v) for v in param_value)
182
+ elif hasattr(param_value, "__str__"):
183
+ # Use string representation if available
184
+ str_value = str(param_value)
185
+ else:
186
+ # Default fallback
187
+ str_value = repr(param_value)
188
+
189
+ parts.append(f"{param_name}={str_value}")
178
190
 
179
191
  return ", ".join(parts)
180
192
 
@@ -207,7 +219,7 @@ def _update_figure_title(figure: Any, input_param_str: str) -> None:
207
219
  raise ValueError(f"Unsupported figure type: {type(figure)}")
208
220
 
209
221
 
210
- def _combine_figures(results: List[TestResult]) -> List[Any]:
222
+ def _combine_figures(results: List[TestResult], show_params: bool) -> List[Any]:
211
223
  """Combine figures from multiple test results (gets raw figure objects, not vm Figures)"""
212
224
  combined_figures = []
213
225
 
@@ -216,7 +228,7 @@ def _combine_figures(results: List[TestResult]) -> List[Any]:
216
228
  # update the figure object in-place with the new title
217
229
  _update_figure_title(
218
230
  figure=figure.figure,
219
- input_param_str=_build_input_param_string(result, results),
231
+ input_param_str=_build_input_param_string(result, results, show_params),
220
232
  )
221
233
  combined_figures.append(figure)
222
234
 
@@ -279,35 +291,53 @@ def get_comparison_test_configs(
279
291
  A list of test configurations.
280
292
  """
281
293
 
282
- # Convert list of dicts to dict of lists if necessary
294
+ # Convert list of dicts to dict of lists if necessary for input_grid
283
295
  def list_to_dict(grid_list):
284
296
  return {k: [d[k] for d in grid_list] for k in grid_list[0].keys()}
285
297
 
298
+ # Handle input_grid the same way as before
286
299
  if isinstance(input_grid, list):
287
300
  input_grid = list_to_dict(input_grid)
288
301
 
289
- if isinstance(param_grid, list):
290
- param_grid = list_to_dict(param_grid)
291
-
292
302
  test_configs = []
293
303
 
294
- if input_grid and param_grid:
295
- input_combinations = _cartesian_product(input_grid)
296
- param_combinations = _cartesian_product(param_grid)
297
- test_configs = [
298
- {"inputs": i, "params": p}
299
- for i, p in product(input_combinations, param_combinations)
300
- ]
304
+ # Check if param_grid is a list of dictionaries
305
+ is_param_grid_list = isinstance(param_grid, list)
306
+
307
+ # Special handling for list-based param_grid
308
+ if is_param_grid_list:
309
+ if input_grid:
310
+ # Generate all combinations of input_grid and each param dictionary
311
+ input_combinations = _cartesian_product(input_grid)
312
+ test_configs = [
313
+ {"inputs": i, "params": p}
314
+ for i in input_combinations
315
+ for p in param_grid
316
+ ]
317
+ else:
318
+ # Each dictionary in param_grid is a specific test configuration
319
+ test_configs = [{"inputs": inputs or {}, "params": p} for p in param_grid]
320
+
321
+ # Dictionary-based param_grid
322
+ elif param_grid:
323
+ if input_grid:
324
+ input_combinations = _cartesian_product(input_grid)
325
+ param_combinations = _cartesian_product(param_grid)
326
+ test_configs = [
327
+ {"inputs": i, "params": p}
328
+ for i, p in product(input_combinations, param_combinations)
329
+ ]
330
+ else:
331
+ param_combinations = _cartesian_product(param_grid)
332
+ test_configs = [
333
+ {"inputs": inputs or {}, "params": p} for p in param_combinations
334
+ ]
335
+ # Just input_grid, no param_grid
301
336
  elif input_grid:
302
337
  input_combinations = _cartesian_product(input_grid)
303
338
  test_configs = [
304
339
  {"inputs": i, "params": params or {}} for i in input_combinations
305
340
  ]
306
- elif param_grid:
307
- param_combinations = _cartesian_product(param_grid)
308
- test_configs = [
309
- {"inputs": inputs or {}, "params": p} for p in param_combinations
310
- ]
311
341
 
312
342
  return test_configs
313
343
 
@@ -333,12 +363,14 @@ def _combine_raw_data(results: List[TestResult]) -> RawData:
333
363
 
334
364
  def combine_results(
335
365
  results: List[TestResult],
366
+ show_params: bool,
336
367
  ) -> Tuple[List[Any], Dict[str, List[Any]], Dict[str, List[Any]]]:
337
368
  """
338
369
  Combine multiple test results into a single set of outputs.
339
370
 
340
371
  Args:
341
372
  results: A list of TestResult objects to combine.
373
+ show_params: Whether to show parameter values in figure titles.
342
374
 
343
375
  Returns:
344
376
  A tuple containing:
@@ -353,7 +385,7 @@ def combine_results(
353
385
  # handle tables (if any)
354
386
  combined_outputs.extend(_combine_tables(results))
355
387
  # handle figures (if any)
356
- combined_outputs.extend(_combine_figures(results))
388
+ combined_outputs.extend(_combine_figures(results, show_params))
357
389
  # handle threshold tests (i.e. tests that have pass/fail bool status)
358
390
  if results[0].passed is not None:
359
391
  combined_outputs.append(all(result.passed for result in results))