validmind 2.4.9__py3-none-any.whl → 2.4.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
validmind/__version__.py CHANGED
@@ -1 +1 @@
1
- __version__ = "2.4.9"
1
+ __version__ = "2.4.13"
validmind/api_client.py CHANGED
@@ -34,6 +34,7 @@ _api_secret = os.getenv("VM_API_SECRET")
34
34
  _api_host = os.getenv("VM_API_HOST")
35
35
  _model_cuid = os.getenv("VM_API_MODEL")
36
36
  _run_cuid = os.getenv("VM_RUN_CUID")
37
+ _monitoring = False
37
38
 
38
39
  __api_session: aiohttp.ClientSession = None
39
40
 
@@ -57,6 +58,7 @@ def get_api_config() -> Dict[str, Optional[str]]:
57
58
  "VM_API_HOST": _api_host,
58
59
  "VM_API_MODEL": _model_cuid,
59
60
  "VM_RUN_CUID": _run_cuid,
61
+ "X-MONITORING": _monitoring,
60
62
  }
61
63
 
62
64
 
@@ -73,6 +75,7 @@ def get_api_headers() -> Dict[str, str]:
73
75
  "X-API-KEY": _api_key,
74
76
  "X-API-SECRET": _api_secret,
75
77
  "X-PROJECT-CUID": _model_cuid,
78
+ "X-MONITORING": _monitoring,
76
79
  }
77
80
 
78
81
 
@@ -82,6 +85,7 @@ def init(
82
85
  api_secret: Optional[str] = None,
83
86
  api_host: Optional[str] = None,
84
87
  model: Optional[str] = None,
88
+ monitoring=False,
85
89
  ):
86
90
  """
87
91
  Initializes the API client instances and calls the /ping endpoint to ensure
@@ -96,11 +100,12 @@ def init(
96
100
  api_key (str, optional): The API key. Defaults to None.
97
101
  api_secret (str, optional): The API secret. Defaults to None.
98
102
  api_host (str, optional): The API host. Defaults to None.
103
+ monitoring (str, optional): The ongoing monitoring flag. Defaults to False.
99
104
 
100
105
  Raises:
101
106
  ValueError: If the API key and secret are not provided
102
107
  """
103
- global _api_key, _api_secret, _api_host, _run_cuid, _model_cuid
108
+ global _api_key, _api_secret, _api_host, _run_cuid, _model_cuid, _monitoring
104
109
 
105
110
  if api_key == "...":
106
111
  # special case to detect when running a notebook with the standard init snippet
@@ -125,6 +130,8 @@ def init(
125
130
 
126
131
  _run_cuid = os.getenv("VM_RUN_CUID", None)
127
132
 
133
+ _monitoring = monitoring
134
+
128
135
  try:
129
136
  __ping()
130
137
  except Exception as e:
@@ -145,6 +152,7 @@ def _get_session() -> aiohttp.ClientSession:
145
152
  "X-API-KEY": _api_key,
146
153
  "X-API-SECRET": _api_secret,
147
154
  "X-PROJECT-CUID": _model_cuid,
155
+ "X-MONITORING": str(_monitoring),
148
156
  }
149
157
  )
150
158
 
@@ -159,6 +167,7 @@ def __ping() -> Dict[str, Any]:
159
167
  "X-API-KEY": _api_key,
160
168
  "X-API-SECRET": _api_secret,
161
169
  "X-PROJECT-CUID": _model_cuid,
170
+ "X-MONITORING": str(_monitoring),
162
171
  },
163
172
  )
164
173
  if r.status_code != 200:
validmind/client.py CHANGED
@@ -48,7 +48,6 @@ def init_dataset(
48
48
  index_name: str = None,
49
49
  date_time_index: bool = False,
50
50
  columns: list = None,
51
- options: dict = None,
52
51
  text_column: str = None,
53
52
  target_column: str = None,
54
53
  feature_columns: list = None,
@@ -72,7 +71,6 @@ def init_dataset(
72
71
  Args:
73
72
  dataset : dataset from various python libraries
74
73
  model (VMModel): ValidMind model object
75
- options (dict): A dictionary of options for the dataset
76
74
  targets (vm.vm.DatasetTargets): A list of target variables
77
75
  target_column (str): The name of the target column in the dataset
78
76
  feature_columns (list): A list of names of feature columns in the dataset
@@ -135,7 +133,8 @@ def init_dataset(
135
133
  model=model,
136
134
  index=index,
137
135
  index_name=index_name,
138
- columns=columns,
136
+ # if no columns are passed, use the index
137
+ columns=columns or [i for i in range(dataset.shape[1])],
139
138
  target_column=target_column,
140
139
  feature_columns=feature_columns,
141
140
  text_column=text_column,
@@ -6,6 +6,8 @@
6
6
  Central class to register inputs
7
7
  """
8
8
 
9
+ from validmind.vm_models.input import VMInput
10
+
9
11
  from .errors import InvalidInputError
10
12
 
11
13
 
@@ -14,6 +16,12 @@ class InputRegistry:
14
16
  self.registry = {}
15
17
 
16
18
  def add(self, key, obj):
19
+ if not isinstance(obj, VMInput):
20
+ raise InvalidInputError(
21
+ f"Input object must be an instance of VMInput. "
22
+ f"Got {type(obj)} instead."
23
+ )
24
+
17
25
  self.registry[key] = obj
18
26
 
19
27
  def get(self, key):
@@ -134,6 +134,7 @@ class DatasetDescription(Metric):
134
134
  )
135
135
  else:
136
136
  vm_dataset_variables[column] = {"id": column, "type": str(type)}
137
+
137
138
  return list(vm_dataset_variables.values())
138
139
 
139
140
  def describe_dataset_field(self, df, field):
validmind/tests/run.py CHANGED
@@ -83,32 +83,47 @@ def _combine_summaries(summaries: List[Dict[str, Any]]):
83
83
  )
84
84
 
85
85
 
86
- def _update_plotly_titles(figures, input_groups, title_template):
87
- current_title = figures[0].figure.layout.title.text
86
+ def _get_input_id(v):
87
+ if isinstance(v, str):
88
+ return v # If v is a string, return it as is.
89
+ elif isinstance(v, list) and all(hasattr(item, "input_id") for item in v):
90
+ # If v is a list and all items have an input_id attribute, join their input_id values.
91
+ return ", ".join(item.input_id for item in v)
92
+ elif hasattr(v, "input_id"):
93
+ return v.input_id # If v has an input_id attribute, return it.
94
+ return str(v) # Otherwise, return the string representation of v.
95
+
96
+
97
+ def _update_plotly_titles(figures, input_group, title_template):
98
+ for figure in figures:
99
+
100
+ current_title = figure.figure.layout.title.text
101
+
102
+ input_description = " and ".join(
103
+ f"{key}: {_get_input_id(value)}" for key, value in input_group.items()
104
+ )
88
105
 
89
- for i, figure in enumerate(figures):
90
106
  figure.figure.layout.title.text = title_template.format(
91
107
  current_title=f"{current_title} " if current_title else "",
92
- input_description=" and ".join(
93
- f"{k}: {v if isinstance(v, str) else ', '.join(item.input_id for item in v) if isinstance(v, list) and all(hasattr(item, 'input_id') for item in v) else v.input_id}"
94
- for k, v in input_groups[i].items()
95
- ),
108
+ input_description=input_description,
96
109
  )
97
110
 
98
111
 
99
- def _update_matplotlib_titles(figures, input_groups, title_template):
100
- current_title = (
101
- figures[0].figure._suptitle.get_text() if figures[0].figure._suptitle else ""
102
- )
112
+ def _update_matplotlib_titles(figures, input_group, title_template):
113
+ for figure in figures:
114
+
115
+ current_title = (
116
+ figure.figure._suptitle.get_text() if figure.figure._suptitle else ""
117
+ )
118
+
119
+ input_description = " and ".join(
120
+ f"{key}: {_get_input_id(value)}" for key, value in input_group.items()
121
+ )
103
122
 
104
- for i, figure in enumerate(figures):
105
123
  figure.figure.suptitle(
106
124
  title_template.format(
107
125
  current_title=f"{current_title} " if current_title else "",
108
- input_description=" and ".join(
109
- f"{k}: {v if isinstance(v, str) else ', '.join(item.input_id for item in v) if isinstance(v, list) and all(hasattr(item, 'input_id') for item in v) else v.input_id}"
110
- for k, v in input_groups[i].items()
111
- ),
126
+ input_description=input_description,
112
127
  )
113
128
  )
114
129
 
@@ -120,11 +135,12 @@ def _combine_figures(figure_lists: List[List[Any]], input_groups: List[Dict[str,
120
135
 
121
136
  title_template = "{current_title}({input_description})"
122
137
 
123
- for figures in list(zip(*figure_lists)):
138
+ for idx, figures in enumerate(figure_lists):
139
+ input_group = input_groups[idx]
124
140
  if is_plotly_figure(figures[0].figure):
125
- _update_plotly_titles(figures, input_groups, title_template)
141
+ _update_plotly_titles(figures, input_group, title_template)
126
142
  elif is_matplotlib_figure(figures[0].figure):
127
- _update_matplotlib_titles(figures, input_groups, title_template)
143
+ _update_matplotlib_titles(figures, input_group, title_template)
128
144
  else:
129
145
  logger.warning("Cannot properly annotate png figures")
130
146
 
@@ -80,7 +80,7 @@ def _serialize_dataset(dataset, model):
80
80
  and pre-computed prediction columns, addressing potential hash collisions.
81
81
  """
82
82
  return _fast_hash(
83
- dataset.df[
83
+ dataset._df[
84
84
  [
85
85
  *dataset.feature_columns,
86
86
  dataset.target_column,
@@ -8,6 +8,7 @@ Models entrypoint
8
8
 
9
9
  from .dataset.dataset import VMDataset
10
10
  from .figure import Figure
11
+ from .input import VMInput
11
12
  from .model import R_MODEL_TYPES, ModelAttributes, VMModel
12
13
  from .test.metric import Metric
13
14
  from .test.metric_result import MetricResult
@@ -20,6 +21,7 @@ from .test_suite.runner import TestSuiteRunner
20
21
  from .test_suite.test_suite import TestSuite
21
22
 
22
23
  __all__ = [
24
+ "VMInput",
23
25
  "VMDataset",
24
26
  "VMModel",
25
27
  "Figure",
@@ -7,6 +7,7 @@ Dataset class wrapper
7
7
  """
8
8
 
9
9
  import warnings
10
+ from copy import deepcopy
10
11
 
11
12
  import numpy as np
12
13
  import pandas as pd
@@ -14,6 +15,7 @@ import polars as pl
14
15
 
15
16
  from validmind.logging import get_logger
16
17
  from validmind.models import FunctionModel, PipelineModel
18
+ from validmind.vm_models.input import VMInput
17
19
  from validmind.vm_models.model import VMModel
18
20
 
19
21
  from .utils import ExtraColumns, as_df, compute_predictions, convert_index_to_datetime
@@ -21,7 +23,7 @@ from .utils import ExtraColumns, as_df, compute_predictions, convert_index_to_da
21
23
  logger = get_logger(__name__)
22
24
 
23
25
 
24
- class VMDataset:
26
+ class VMDataset(VMInput):
25
27
  """Base class for VM datasets
26
28
 
27
29
  Child classes should be used to support new dataset types (tensor, polars etc)
@@ -60,7 +62,6 @@ class VMDataset:
60
62
  text_column: str = None,
61
63
  extra_columns: dict = None,
62
64
  target_class_labels: dict = None,
63
- options: dict = None,
64
65
  ):
65
66
  """
66
67
  Initializes a VMDataset instance.
@@ -77,7 +78,6 @@ class VMDataset:
77
78
  feature_columns (str, optional): The feature column names of the dataset. Defaults to None.
78
79
  text_column (str, optional): The text column name of the dataset for nlp tasks. Defaults to None.
79
80
  target_class_labels (Dict, optional): The class labels for the target columns. Defaults to None.
80
- options (Dict, optional): Additional options for the dataset. Defaults to None.
81
81
  """
82
82
  # initialize input_id
83
83
  self.input_id = input_id
@@ -92,16 +92,14 @@ class VMDataset:
92
92
  raise ValueError("Expected Numpy array for attribute raw_dataset")
93
93
  self.index = index
94
94
 
95
- self.df = pd.DataFrame(self._raw_dataset, columns=columns).infer_objects()
95
+ self._df = pd.DataFrame(self._raw_dataset, columns=columns).infer_objects()
96
96
  # set index to dataframe
97
97
  if index is not None:
98
- self.df.set_index(pd.Index(index), inplace=True)
99
- self.df.index.name = index_name
98
+ self._df.set_index(pd.Index(index), inplace=True)
99
+ self._df.index.name = index_name
100
100
  # attempt to convert index to datatime
101
101
  if date_time_index:
102
- self.df = convert_index_to_datetime(self.df)
103
-
104
- self.options = options
102
+ self._df = convert_index_to_datetime(self._df)
105
103
 
106
104
  self.columns = columns or []
107
105
  self.column_aliases = {}
@@ -128,12 +126,12 @@ class VMDataset:
128
126
  self.feature_columns = [col for col in self.columns if col not in excluded]
129
127
 
130
128
  self.feature_columns_numeric = (
131
- self.df[self.feature_columns]
129
+ self._df[self.feature_columns]
132
130
  .select_dtypes(include=[np.number])
133
131
  .columns.tolist()
134
132
  )
135
133
  self.feature_columns_categorical = (
136
- self.df[self.feature_columns]
134
+ self._df[self.feature_columns]
137
135
  .select_dtypes(include=[object, pd.Categorical])
138
136
  .columns.tolist()
139
137
  )
@@ -142,19 +140,19 @@ class VMDataset:
142
140
  column_values = np.array(column_values)
143
141
 
144
142
  if column_values.ndim == 1:
145
- if len(column_values) != len(self.df):
143
+ if len(column_values) != len(self._df):
146
144
  raise ValueError(
147
145
  "Length of values doesn't match number of rows in the DataFrame."
148
146
  )
149
147
  self.columns.append(column_name)
150
- self.df[column_name] = column_values
148
+ self._df[column_name] = column_values
151
149
  elif column_values.ndim == 2:
152
- if column_values.shape[0] != len(self.df):
150
+ if column_values.shape[0] != len(self._df):
153
151
  raise ValueError(
154
152
  "Number of rows in values doesn't match number of rows in the DataFrame."
155
153
  )
156
154
  self.columns.append(column_name)
157
- self.df[column_name] = column_values.tolist()
155
+ self._df[column_name] = column_values.tolist()
158
156
 
159
157
  else:
160
158
  raise ValueError("Only 1D and 2D arrays are supported for column_values.")
@@ -199,6 +197,56 @@ class VMDataset:
199
197
  "Cannot use precomputed probabilities without precomputed predictions"
200
198
  )
201
199
 
200
+ def with_options(self, **kwargs) -> "VMDataset":
201
+ """Support options provided when passing an input to run_test or run_test_suite
202
+
203
+ Example:
204
+ ```python
205
+ # to only use a certain subset of columns in the dataset:
206
+ run_test(
207
+ "validmind.SomeTestID",
208
+ inputs={
209
+ "dataset": {
210
+ "input_id": "my_dataset_id",
211
+ "columns": ["col1", "col2"],
212
+ }
213
+ }
214
+ )
215
+
216
+ # behind the scenes, this retrieves the dataset object (VMDataset) from the registry
217
+ # and then calls the `with_options()` method and passes `{"columns": ...}`
218
+ ```
219
+
220
+ Args:
221
+ **kwargs: Options:
222
+ - columns: Filter columns in the dataset
223
+
224
+ Returns:
225
+ VMDataset: A new instance of the dataset with only the specified columns
226
+ """
227
+ if "columns" in kwargs:
228
+ # filter columns (create a temp copy of self with only specified columns)
229
+ # TODO: need a more robust mechanism for this as we expand on this feature
230
+ columns = kwargs.pop("columns")
231
+
232
+ new = deepcopy(self)
233
+
234
+ new._set_feature_columns(
235
+ [col for col in new.feature_columns if col in columns]
236
+ )
237
+ new.text_column = new.text_column if new.text_column in columns else None
238
+ new.target_column = (
239
+ new.target_column if new.target_column in columns else None
240
+ )
241
+ new.extra_columns.extras = new.extra_columns.extras.intersection(columns)
242
+
243
+ return new
244
+
245
+ if kwargs:
246
+ raise NotImplementedError(
247
+ f"Options {kwargs} are not supported for this input"
248
+ )
249
+
202
250
  def assign_predictions(
203
251
  self,
204
252
  model: VMModel,
@@ -243,10 +291,10 @@ class VMDataset:
243
291
 
244
292
  # if the user passes a column name, we assume it has precomputed predictions
245
293
  if prediction_column:
246
- prediction_values = self.df[prediction_column].values
294
+ prediction_values = self._df[prediction_column].values
247
295
 
248
296
  if probability_column:
249
- probability_values = self.df[probability_column].values
297
+ probability_values = self._df[probability_column].values
250
298
 
251
299
  if prediction_values is None:
252
300
  X = self.df if isinstance(model, (FunctionModel, PipelineModel)) else self.x
@@ -320,6 +368,33 @@ class VMDataset:
320
368
  f"Extra column {column_name} with {len(column_values)} values added to the dataset"
321
369
  )
322
370
 
371
+ @property
372
+ def df(self) -> pd.DataFrame:
373
+ """
374
+ Returns the dataset as a pandas DataFrame.
375
+
376
+ Returns:
377
+ pd.DataFrame: The dataset as a pandas DataFrame.
378
+ """
379
+ # only include feature, text and target columns
380
+ # don't include internal pred and prob columns
381
+ columns = self.feature_columns.copy()
382
+
383
+ # text column can also be a feature column so don't add it twice
384
+ if self.text_column and self.text_column not in columns:
385
+ columns.append(self.text_column)
386
+
387
+ if self.extra_columns.extras:
388
+ # add user-defined extra columns
389
+ columns.extend(self.extra_columns.extras)
390
+
391
+ if self.target_column:
392
+ # shouldn't be a feature column but add this to be safe
393
+ assert self.target_column not in columns
394
+ columns.append(self.target_column)
395
+
396
+ return as_df(self._df[columns])
397
+
323
398
  @property
324
399
  def x(self) -> np.ndarray:
325
400
  """
@@ -328,7 +403,7 @@ class VMDataset:
328
403
  Returns:
329
404
  np.ndarray: The input features.
330
405
  """
331
- return self.df[self.feature_columns].to_numpy()
406
+ return self._df[self.feature_columns].to_numpy()
332
407
 
333
408
  @property
334
409
  def y(self) -> np.ndarray:
@@ -338,7 +413,7 @@ class VMDataset:
338
413
  Returns:
339
414
  np.ndarray: The target variables.
340
415
  """
341
- return self.df[self.target_column].to_numpy()
416
+ return self._df[self.target_column].to_numpy()
342
417
 
343
418
  def y_pred(self, model) -> np.ndarray:
344
419
  """Returns the predictions for a given model.
@@ -352,7 +427,7 @@ class VMDataset:
352
427
  Returns:
353
428
  np.ndarray: The predictions for the model
354
429
  """
355
- return np.stack(self.df[self.prediction_column(model)].values)
430
+ return np.stack(self._df[self.prediction_column(model)].values)
356
431
 
357
432
  def y_prob(self, model) -> np.ndarray:
358
433
  """Returns the probabilities for a given model.
@@ -363,23 +438,23 @@ class VMDataset:
363
438
  Returns:
364
439
  np.ndarray: The probability variables.
365
440
  """
366
- return self.df[self.probability_column(model)].values
441
+ return self._df[self.probability_column(model)].values
367
442
 
368
443
  def x_df(self):
369
444
  """Returns a dataframe containing only the feature columns"""
370
- return as_df(self.df[self.feature_columns])
445
+ return as_df(self._df[self.feature_columns])
371
446
 
372
447
  def y_df(self) -> pd.DataFrame:
373
448
  """Returns a dataframe containing the target column"""
374
- return as_df(self.df[self.target_column])
449
+ return as_df(self._df[self.target_column])
375
450
 
376
451
  def y_pred_df(self, model) -> pd.DataFrame:
377
452
  """Returns a dataframe containing the predictions for a given model"""
378
- return as_df(self.df[self.prediction_column(model)])
453
+ return as_df(self._df[self.prediction_column(model)])
379
454
 
380
455
  def y_prob_df(self, model) -> pd.DataFrame:
381
456
  """Returns a dataframe containing the probabilities for a given model"""
382
- return as_df(self.df[self.probability_column(model)])
457
+ return as_df(self._df[self.probability_column(model)])
383
458
 
384
459
  def target_classes(self):
385
460
  """Returns the target class labels or unique values of the target column."""
@@ -417,7 +492,6 @@ class DataFrameDataset(VMDataset):
417
492
  feature_columns: list = None,
418
493
  text_column: str = None,
419
494
  target_class_labels: dict = None,
420
- options: dict = None,
421
495
  date_time_index: bool = False,
422
496
  ):
423
497
  """
@@ -432,7 +506,6 @@ class DataFrameDataset(VMDataset):
432
506
  feature_columns (list, optional): The feature columns of the dataset. Defaults to None.
433
507
  text_column (str, optional): The text column name of the dataset for NLP tasks. Defaults to None.
434
508
  target_class_labels (dict, optional): The class labels for the target columns. Defaults to None.
435
- options (dict, optional): Additional options for the dataset. Defaults to None.
436
509
  date_time_index (bool, optional): Whether to use date-time index. Defaults to False.
437
510
  """
438
511
  index = None
@@ -451,7 +524,6 @@ class DataFrameDataset(VMDataset):
451
524
  feature_columns=feature_columns,
452
525
  text_column=text_column,
453
526
  target_class_labels=target_class_labels,
454
- options=options,
455
527
  date_time_index=date_time_index,
456
528
  )
457
529
 
@@ -471,7 +543,6 @@ class PolarsDataset(VMDataset):
471
543
  feature_columns: list = None,
472
544
  text_column: str = None,
473
545
  target_class_labels: dict = None,
474
- options: dict = None,
475
546
  date_time_index: bool = False,
476
547
  ):
477
548
  """
@@ -486,7 +557,6 @@ class PolarsDataset(VMDataset):
486
557
  feature_columns (list, optional): The feature columns of the dataset. Defaults to None.
487
558
  text_column (str, optional): The text column name of the dataset for NLP tasks. Defaults to None.
488
559
  target_class_labels (dict, optional): The class labels for the target columns. Defaults to None.
489
- options (dict, optional): Additional options for the dataset. Defaults to None.
490
560
  date_time_index (bool, optional): Whether to use date-time index. Defaults to False.
491
561
  """
492
562
  super().__init__(
@@ -501,7 +571,6 @@ class PolarsDataset(VMDataset):
501
571
  feature_columns=feature_columns,
502
572
  text_column=text_column,
503
573
  target_class_labels=target_class_labels,
504
- options=options,
505
574
  date_time_index=date_time_index,
506
575
  )
507
576
 
@@ -524,7 +593,6 @@ class TorchDataset(VMDataset):
524
593
  feature_columns: list = None,
525
594
  text_column: str = None,
526
595
  target_class_labels: dict = None,
527
- options: dict = None,
528
596
  ):
529
597
  """
530
598
  Initializes a TorchDataset instance.
@@ -582,5 +650,4 @@ class TorchDataset(VMDataset):
582
650
  text_column=text_column,
583
651
  extra_columns=extra_columns,
584
652
  target_class_labels=target_class_labels,
585
- options=options,
586
653
  )
@@ -16,6 +16,8 @@ logger = get_logger(__name__)
16
16
 
17
17
  @dataclass
18
18
  class ExtraColumns:
19
+ # TODO: this now holds internal (pred, prob and group_by) cols as well as
20
+ # user-defined extra columns. These should probably be separated.
19
21
  """Extra columns for the dataset."""
20
22
 
21
23
  extras: Set[str] = field(default_factory=set)
@@ -0,0 +1,31 @@
1
+ # Copyright © 2023-2024 ValidMind Inc. All rights reserved.
2
+ # See the LICENSE file in the root of this repository for details.
3
+ # SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
4
+
5
+ """Base class for ValidMind Input types"""
6
+
7
+ from abc import ABC
8
+
9
+
10
+ class VMInput(ABC):
11
+ """
12
+ Base class for ValidMind Input types
13
+ """
14
+
15
+ def with_options(self, **kwargs) -> "VMInput":
16
+ """
17
+ Allows for setting options on the input object that are passed by the user
18
+ when using the input to run a test or set of tests
19
+
20
+ To allow options, just override this method in the subclass (see VMDataset)
21
+ and ensure that it returns a new instance of the input with the specified options
22
+ set.
23
+
24
+ Args:
25
+ **kwargs: Arbitrary keyword arguments that will be passed to the input object
26
+
27
+ Returns:
28
+ VMInput: A new instance of the input with the specified options set
29
+ """
30
+ if kwargs:
31
+ raise NotImplementedError("This type of input does not support options")
@@ -7,11 +7,13 @@ Model class wrapper module
7
7
  """
8
8
  import importlib
9
9
  import inspect
10
- from abc import ABC, abstractmethod
10
+ from abc import abstractmethod
11
11
  from dataclasses import dataclass
12
12
 
13
13
  from validmind.errors import MissingOrInvalidModelPredictFnError
14
14
 
15
+ from .input import VMInput
16
+
15
17
  SUPPORTED_LIBRARIES = {
16
18
  "catboost": "CatBoostModel",
17
19
  "xgboost": "XGBoostModel",
@@ -77,7 +79,7 @@ class ModelAttributes:
77
79
  )
78
80
 
79
81
 
80
- class VMModel(ABC):
82
+ class VMModel(VMInput):
81
83
  """
82
84
  An base class that wraps a trained model instance and its associated data.
83
85
 
@@ -78,13 +78,12 @@ class TestInput:
78
78
  ... (any): Any other arbitrary inputs that can be used by tests
79
79
  """
80
80
 
81
- # TODO: we need to look into adding metadata for test inputs and logging that
82
-
83
81
  def __init__(self, inputs):
84
82
  """Initialize with either a dictionary of inputs"""
85
83
  for key, value in inputs.items():
86
84
  # 1) retrieve input object from input registry if an input_id string is provided
87
85
  # 2) check the input_id type if a list of inputs (mix of strings and objects) is provided
86
+ # 3) if its a dict, it should contain the `input_id` key as well as other options
88
87
  if isinstance(value, str):
89
88
  value = input_registry.get(key=value)
90
89
  elif isinstance(value, list) or isinstance(value, tuple):
@@ -92,6 +91,14 @@ class TestInput:
92
91
  input_registry.get(key=v) if isinstance(v, str) else v
93
92
  for v in value
94
93
  ]
94
+ elif isinstance(value, dict):
95
+ assert "input_id" in value, (
96
+ "Input dictionary must contain an 'input_id' key "
97
+ "to retrieve the input object from the input registry."
98
+ )
99
+ value = input_registry.get(key=value.get("input_id")).with_options(
100
+ **{k: v for k, v in value.items() if k != "input_id"}
101
+ )
95
102
 
96
103
  setattr(self, key, value)
97
104
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: validmind
3
- Version: 2.4.9
3
+ Version: 2.4.13
4
4
  Summary: ValidMind Developer Framework
5
5
  License: Commercial License
6
6
  Author: Andres Rodriguez
@@ -1,9 +1,9 @@
1
1
  validmind/__init__.py,sha256=UfmzPwUCdUWbWq3zPqqmq4jw0_kfl3hX4U72p_seE4I,3700
2
- validmind/__version__.py,sha256=jexP_bn42GgaPCRDRYZqRI6UFG4hFgjug7OJbg8uE0Y,22
2
+ validmind/__version__.py,sha256=tz33jXQt_yV5A4oaguyO5xgSigC1jcfj0c7SavSDPo0,23
3
3
  validmind/ai/test_descriptions.py,sha256=Q1Ftus4x5eiVLKWJu7hqPLukBQZzhy-dARqq_6_JWtk,9464
4
4
  validmind/ai/utils.py,sha256=TEXII_S5CpkpczzSyHwTlqLcPMLnPBJWEBR6QFMKh1U,3421
5
- validmind/api_client.py,sha256=xr9VNqCmA_WFf8rVm-0M0pmzVyLAPFOnfEe4dAog1LA,17144
6
- validmind/client.py,sha256=UnsEwWK_s3nuktr6i2U3haLjjlWRGR6H431jsZpKEDA,18649
5
+ validmind/api_client.py,sha256=JZIJWuYtvl-VEVi_AK4c839Fn7cGa40J2d4_4FUZcno,17483
6
+ validmind/client.py,sha256=guXu_9um4caPpepbAsfKgjLc63Ygx07Lgp8wZJD3p6Y,18653
7
7
  validmind/client_config.py,sha256=58L6s6-9vFWC9vkSs_98CjV1YWmlksdhblJtPQxQsAk,1611
8
8
  validmind/datasets/__init__.py,sha256=oYfcvW7BAyUgpghBOnTeGbQF6tpFAWg38rRirdLr8m8,262
9
9
  validmind/datasets/classification/__init__.py,sha256=HlTOBLyb6IorRYmAhP3AIyX-l-NyemyDjV8BBOdrCrY,1787
@@ -60,7 +60,7 @@ validmind/datasets/regression/models/fred_loan_rates_model_5.pkl,sha256=FkNLHq9x
60
60
  validmind/errors.py,sha256=qy7Gp6Uom5J6WmLw-CpE5zaTN96SiN7kJjDGBaJdoxY,8023
61
61
  validmind/html_templates/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
62
62
  validmind/html_templates/content_blocks.py,sha256=LTsv2Hr_drUUZVLEfY2JcT4z0M-45RGYy2sFInt1VKY,3998
63
- validmind/input_registry.py,sha256=zexO3x-vncaoWvQ6VfkvgDLn6x72e2BNel_jCbrVHSE,793
63
+ validmind/input_registry.py,sha256=8C_mrhgLT72hwbt_lo3ZwXb5NCyIcSuCQI1HdJ3bK2A,1042
64
64
  validmind/logging.py,sha256=J1Y1dYCH1dtkoYCHoXMOQH_B7EO4fJytWRDrDqZZz8U,5204
65
65
  validmind/models/__init__.py,sha256=lraTbNwoKckXNP3Dbyj-euI78UTkZ_w5wpUOb8l5nWs,729
66
66
  validmind/models/foundation.py,sha256=ZdVmwwRVbjgqMyfjguyf9Lka_KcgJnDD7ho8zv0gQok,1842
@@ -101,7 +101,7 @@ validmind/tests/data_validation/BivariateScatterPlots.py,sha256=9QcMcbc3yiZl8LbV
101
101
  validmind/tests/data_validation/ChiSquaredFeaturesTable.py,sha256=5lPRnNbHjxEdMGtp6fhg4cYy7FLSHbHxUtpymU8GRO0,5977
102
102
  validmind/tests/data_validation/ClassImbalance.py,sha256=nRNHDtjCAgLQfWbCWP-zX5CFZSjJno8NWLXZBPaG_yA,6882
103
103
  validmind/tests/data_validation/DFGLSArch.py,sha256=D6_kR4AkvctjK-MRUJCc9cELwdmTT_085QIGagliSsA,5365
104
- validmind/tests/data_validation/DatasetDescription.py,sha256=ftB6RrEc4Sfc2qGtOmX7Ogf8yJxGV9QJY29OzjdQCyE,11365
104
+ validmind/tests/data_validation/DatasetDescription.py,sha256=-V8pO260iRs1QzJfJFAP_YwayBQcCwYD_X51NZoLBXA,11366
105
105
  validmind/tests/data_validation/DatasetSplit.py,sha256=4BCeshqxvNSwmRwXw37uIJ3xy8JnU7ZJIWPyzzTdLJ0,5068
106
106
  validmind/tests/data_validation/DescriptiveStatistics.py,sha256=5u1qx-VGD8aVBFEi_4ffCWZfJ17YXbT5PO1Na52KoNo,6333
107
107
  validmind/tests/data_validation/Duplicates.py,sha256=oO1LPMyclfJno1_AAITpKB-fZryU-705BvYCm5StClw,5592
@@ -266,10 +266,10 @@ validmind/tests/prompt_validation/Robustness.py,sha256=fBdkYnO9yoBazz4wD-l62tT8D
266
266
  validmind/tests/prompt_validation/Specificity.py,sha256=h3gKRTTi2rfnGWmGC1YnSt2s_VbZU4KX0iY7LciZ3PU,6068
267
267
  validmind/tests/prompt_validation/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
268
268
  validmind/tests/prompt_validation/ai_powered_test.py,sha256=7TTeIR5GotQosm7oVT8Y3KnwPB3XkVT1Fzhckpr-SgE,1963
269
- validmind/tests/run.py,sha256=24E5pRg6p0dUHoK6shB9KeKHWTOEJa5HymT6tD8Ozl4,15574
269
+ validmind/tests/run.py,sha256=WuLV8iY2xN7bRPu5px75-rgRKeh_XYPtbdLhqG8Dugo,15874
270
270
  validmind/tests/test_providers.py,sha256=47xe5eb5ufvj1jmhdRsbSvDQTXSDpFDFNeXg3xtXwhw,5320
271
271
  validmind/tests/utils.py,sha256=kNrxfUYbj4DwmkZtpp_1rG4GMUGxYEhvqnYR_A7qAKM,471
272
- validmind/unit_metrics/__init__.py,sha256=AlFnWA9pmzVf8xysPxYpQ3kBTQ81-YVxRTJpgC0Q41w,7344
272
+ validmind/unit_metrics/__init__.py,sha256=mFk52eU7bOQKTpruKSrPyzjmxFUpIi5RZuwIE5BVFHU,7345
273
273
  validmind/unit_metrics/classification/sklearn/Accuracy.py,sha256=2Ra_OpKceY01h1dAFCqRFAwe--K2oVbCUiYjM5AH_nQ,480
274
274
  validmind/unit_metrics/classification/sklearn/F1.py,sha256=Uiq5sPyNpALhApTkmLUhh76mF91bLCABB5OVHOlbmGo,437
275
275
  validmind/unit_metrics/classification/sklearn/Precision.py,sha256=8zO5VDZhfT8R2VFYiV-CzsZwhsTwVAKca4nhD-qALLw,458
@@ -288,12 +288,13 @@ validmind/unit_metrics/regression/sklearn/MeanSquaredError.py,sha256=7UQnDTTO7yR
288
288
  validmind/unit_metrics/regression/sklearn/RSquaredScore.py,sha256=h9U5ndtnJfNNtKPZIo5n3KRp-m4akQcEo0t1iSwjVzY,420
289
289
  validmind/unit_metrics/regression/sklearn/RootMeanSquaredError.py,sha256=_5IQIU9jNfmTE4NLJvaRWXbudRGV2PS7nYF5e4fkSMY,556
290
290
  validmind/utils.py,sha256=MQDsW7YuwEJ50tA01n3xb8D_Ihmji_Mn22AlMnJJQT8,15819
291
- validmind/vm_models/__init__.py,sha256=lmWCD2u4tW6_AH39UnJ24sCcMUcsHbUttz7SaZfrh3s,1168
291
+ validmind/vm_models/__init__.py,sha256=V5DH-E1Rkvl-HQEkilppVCHBag9MQXkzyoORLW3LSGQ,1210
292
292
  validmind/vm_models/dataset/__init__.py,sha256=U4CxZjdoc0dd9u2AqBl5PJh1UVbzXWNrmundmjLF-qE,346
293
- validmind/vm_models/dataset/dataset.py,sha256=X0vKp1NuL0k5WLILAAlmnPs_WHR8Ji6ovlfWLjJs3qk,23305
294
- validmind/vm_models/dataset/utils.py,sha256=ygT6hUw0KklKCboo7tqLxh_hf-dEiaccVyCpR9DCPF8,5177
293
+ validmind/vm_models/dataset/dataset.py,sha256=0N1yz4pqvop5FJQdbf2QodLINsNJpI7DQ-ImH-usU6U,25579
294
+ validmind/vm_models/dataset/utils.py,sha256=DRFCg93YE7sTRrWAGt1RIyvzPjINagMk6zUw7z692d0,5325
295
295
  validmind/vm_models/figure.py,sha256=iSrvPcCG5sQrMkX1Fh6c5utRzaroh3bc6IlnGDOK_Eg,6651
296
- validmind/vm_models/model.py,sha256=b-UL73EWOpj-X5aQbHQ3HLkONHCH9hYwUlKxVwPC6gI,6088
296
+ validmind/vm_models/input.py,sha256=qLdqz_bktr4v0YcPha2vFdDvmkC-btT1pH9zBIkt1OY,1046
297
+ validmind/vm_models/model.py,sha256=P-zKbh0TrU_4ZK-bA0l83h6K6nfU6v0lIpC4mfCl6Fw,6115
297
298
  validmind/vm_models/test/metric.py,sha256=DvXMju36JzxArXNWimq3SSrSUoIHkyvDbuhbgBOKxkk,3357
298
299
  validmind/vm_models/test/metric_result.py,sha256=Bak4GDrMlNq5NtgP5exwlPsKZgz3tWgtC6jZqtHjvqM,1987
299
300
  validmind/vm_models/test/output_template.py,sha256=njqCAMyLxwadkCWhACVskyL9-psTgmUysaeeirTVAX4,1500
@@ -302,13 +303,13 @@ validmind/vm_models/test/result_wrapper.py,sha256=Zb2IVjB3UTIMxTjmv9xZ1kaIIAd_dU
302
303
  validmind/vm_models/test/test.py,sha256=2Wbte09E4l7fUXwfQije0LQbPeSuh2Wpbyt4ddwyVks,3419
303
304
  validmind/vm_models/test/threshold_test.py,sha256=xSadM5t9Z-XZjkxu7LKmeljy2bdwTwXrUh-mkdePdLM,3740
304
305
  validmind/vm_models/test/threshold_test_result.py,sha256=EXP-g_e3NsnpkvNgYew030qVUoY6ZTHyuuFUXaq-BuM,1954
305
- validmind/vm_models/test_context.py,sha256=AN7-atBgOcD04MLVitCFJYooxF6_iNmvI2H4nkv32iw,9035
306
+ validmind/vm_models/test_context.py,sha256=SGqoF_OeFC7Fj1jg5CPO1LOpfB7mA1FPwm61SYP8f2o,9475
306
307
  validmind/vm_models/test_suite/runner.py,sha256=aewxadRfoOPH48jes2Gtb3Ju_FWFfVM_9ARIAJHD4wA,6982
307
308
  validmind/vm_models/test_suite/summary.py,sha256=GQRNe2ZvvqjQN0yKmaN7ohAUjRFQIN4YYUYxfOuWN6M,4682
308
309
  validmind/vm_models/test_suite/test.py,sha256=_GfbK36l98SjzgVcucmp0OKBJKqMW3neO7SqJ3EWeps,5049
309
310
  validmind/vm_models/test_suite/test_suite.py,sha256=Cns2wL54v0T5Mv5_HJb3kMeaa4rtycdqT8KxK9_rWEU,6279
310
- validmind-2.4.9.dist-info/LICENSE,sha256=XonPUfwjvrC5Ombl3y-ko0Wubb1xdG_7nzvIbkZRKHw,35772
311
- validmind-2.4.9.dist-info/METADATA,sha256=pZA_sXTyykGQNAwB6yPyMQIjFgvekoh_jdNhqf0eaXY,4250
312
- validmind-2.4.9.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
313
- validmind-2.4.9.dist-info/entry_points.txt,sha256=HuW7YyOv9u_OEWpViQXtv0nfoI67uieJHawKWA4Hv9A,76
314
- validmind-2.4.9.dist-info/RECORD,,
311
+ validmind-2.4.13.dist-info/LICENSE,sha256=XonPUfwjvrC5Ombl3y-ko0Wubb1xdG_7nzvIbkZRKHw,35772
312
+ validmind-2.4.13.dist-info/METADATA,sha256=2GruFcfXawHmaIiQzC9nVgpIu58-m17hWgQSsf3BhNk,4251
313
+ validmind-2.4.13.dist-info/WHEEL,sha256=sP946D7jFCHeNz5Iq4fL4Lu-PrWrFsgfLXbbkciIZwg,88
314
+ validmind-2.4.13.dist-info/entry_points.txt,sha256=HuW7YyOv9u_OEWpViQXtv0nfoI67uieJHawKWA4Hv9A,76
315
+ validmind-2.4.13.dist-info/RECORD,,