validmind 2.1.1__py3-none-any.whl → 2.2.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. validmind/__version__.py +1 -1
  2. validmind/ai.py +3 -3
  3. validmind/api_client.py +2 -3
  4. validmind/client.py +68 -25
  5. validmind/datasets/llm/rag/__init__.py +11 -0
  6. validmind/datasets/llm/rag/datasets/rfp_existing_questions_client_1.csv +30 -0
  7. validmind/datasets/llm/rag/datasets/rfp_existing_questions_client_2.csv +30 -0
  8. validmind/datasets/llm/rag/datasets/rfp_existing_questions_client_3.csv +53 -0
  9. validmind/datasets/llm/rag/datasets/rfp_existing_questions_client_4.csv +53 -0
  10. validmind/datasets/llm/rag/datasets/rfp_existing_questions_client_5.csv +53 -0
  11. validmind/datasets/llm/rag/rfp.py +41 -0
  12. validmind/html_templates/__init__.py +0 -0
  13. validmind/html_templates/content_blocks.py +89 -14
  14. validmind/models/__init__.py +7 -4
  15. validmind/models/foundation.py +8 -34
  16. validmind/models/function.py +51 -0
  17. validmind/models/huggingface.py +16 -46
  18. validmind/models/metadata.py +42 -0
  19. validmind/models/pipeline.py +66 -0
  20. validmind/models/pytorch.py +8 -42
  21. validmind/models/r_model.py +33 -82
  22. validmind/models/sklearn.py +39 -38
  23. validmind/template.py +8 -26
  24. validmind/tests/__init__.py +43 -20
  25. validmind/tests/data_validation/ANOVAOneWayTable.py +1 -1
  26. validmind/tests/data_validation/ChiSquaredFeaturesTable.py +1 -1
  27. validmind/tests/data_validation/DescriptiveStatistics.py +2 -4
  28. validmind/tests/data_validation/Duplicates.py +1 -1
  29. validmind/tests/data_validation/IsolationForestOutliers.py +2 -2
  30. validmind/tests/data_validation/LaggedCorrelationHeatmap.py +1 -1
  31. validmind/tests/data_validation/TargetRateBarPlots.py +1 -1
  32. validmind/tests/data_validation/nlp/LanguageDetection.py +59 -0
  33. validmind/tests/data_validation/nlp/PolarityAndSubjectivity.py +48 -0
  34. validmind/tests/data_validation/nlp/Punctuations.py +11 -12
  35. validmind/tests/data_validation/nlp/Sentiment.py +57 -0
  36. validmind/tests/data_validation/nlp/Toxicity.py +45 -0
  37. validmind/tests/decorator.py +2 -2
  38. validmind/tests/model_validation/BertScore.py +100 -98
  39. validmind/tests/model_validation/BleuScore.py +93 -64
  40. validmind/tests/model_validation/ContextualRecall.py +74 -91
  41. validmind/tests/model_validation/MeteorScore.py +86 -74
  42. validmind/tests/model_validation/RegardScore.py +103 -121
  43. validmind/tests/model_validation/RougeScore.py +118 -0
  44. validmind/tests/model_validation/TokenDisparity.py +84 -121
  45. validmind/tests/model_validation/ToxicityScore.py +109 -123
  46. validmind/tests/model_validation/embeddings/CosineSimilarityComparison.py +96 -0
  47. validmind/tests/model_validation/embeddings/CosineSimilarityHeatmap.py +71 -0
  48. validmind/tests/model_validation/embeddings/EuclideanDistanceComparison.py +92 -0
  49. validmind/tests/model_validation/embeddings/EuclideanDistanceHeatmap.py +69 -0
  50. validmind/tests/model_validation/embeddings/PCAComponentsPairwisePlots.py +78 -0
  51. validmind/tests/model_validation/embeddings/StabilityAnalysis.py +35 -23
  52. validmind/tests/model_validation/embeddings/StabilityAnalysisKeyword.py +3 -0
  53. validmind/tests/model_validation/embeddings/StabilityAnalysisRandomNoise.py +7 -1
  54. validmind/tests/model_validation/embeddings/StabilityAnalysisSynonyms.py +3 -0
  55. validmind/tests/model_validation/embeddings/StabilityAnalysisTranslation.py +3 -0
  56. validmind/tests/model_validation/embeddings/TSNEComponentsPairwisePlots.py +99 -0
  57. validmind/tests/model_validation/ragas/AnswerCorrectness.py +131 -0
  58. validmind/tests/model_validation/ragas/AnswerRelevance.py +134 -0
  59. validmind/tests/model_validation/ragas/AnswerSimilarity.py +119 -0
  60. validmind/tests/model_validation/ragas/AspectCritique.py +167 -0
  61. validmind/tests/model_validation/ragas/ContextEntityRecall.py +133 -0
  62. validmind/tests/model_validation/ragas/ContextPrecision.py +123 -0
  63. validmind/tests/model_validation/ragas/ContextRecall.py +123 -0
  64. validmind/tests/model_validation/ragas/ContextRelevancy.py +114 -0
  65. validmind/tests/model_validation/ragas/Faithfulness.py +119 -0
  66. validmind/tests/model_validation/ragas/utils.py +66 -0
  67. validmind/tests/model_validation/sklearn/OverfitDiagnosis.py +3 -7
  68. validmind/tests/model_validation/sklearn/PermutationFeatureImportance.py +8 -9
  69. validmind/tests/model_validation/sklearn/PopulationStabilityIndex.py +5 -10
  70. validmind/tests/model_validation/sklearn/PrecisionRecallCurve.py +3 -2
  71. validmind/tests/model_validation/sklearn/ROCCurve.py +2 -1
  72. validmind/tests/model_validation/sklearn/RegressionR2Square.py +1 -1
  73. validmind/tests/model_validation/sklearn/RobustnessDiagnosis.py +2 -3
  74. validmind/tests/model_validation/sklearn/SHAPGlobalImportance.py +7 -11
  75. validmind/tests/model_validation/sklearn/WeakspotsDiagnosis.py +3 -4
  76. validmind/tests/model_validation/statsmodels/RegressionModelForecastPlot.py +1 -1
  77. validmind/tests/model_validation/statsmodels/RegressionModelForecastPlotLevels.py +1 -1
  78. validmind/tests/model_validation/statsmodels/RegressionModelInsampleComparison.py +1 -1
  79. validmind/tests/model_validation/statsmodels/RegressionModelOutsampleComparison.py +1 -1
  80. validmind/tests/model_validation/statsmodels/RegressionModelSummary.py +1 -1
  81. validmind/tests/model_validation/statsmodels/RegressionModelsCoeffs.py +1 -1
  82. validmind/tests/model_validation/statsmodels/RegressionModelsPerformance.py +1 -1
  83. validmind/tests/model_validation/statsmodels/ScorecardHistogram.py +5 -6
  84. validmind/unit_metrics/__init__.py +26 -49
  85. validmind/unit_metrics/composite.py +5 -1
  86. validmind/unit_metrics/regression/sklearn/AdjustedRSquaredScore.py +1 -1
  87. validmind/utils.py +56 -6
  88. validmind/vm_models/__init__.py +1 -1
  89. validmind/vm_models/dataset/__init__.py +7 -0
  90. validmind/vm_models/dataset/dataset.py +558 -0
  91. validmind/vm_models/dataset/utils.py +146 -0
  92. validmind/vm_models/model.py +97 -72
  93. validmind/vm_models/test/result_wrapper.py +61 -24
  94. validmind/vm_models/test_context.py +1 -1
  95. validmind/vm_models/test_suite/summary.py +3 -4
  96. {validmind-2.1.1.dist-info → validmind-2.2.2.dist-info}/METADATA +5 -3
  97. {validmind-2.1.1.dist-info → validmind-2.2.2.dist-info}/RECORD +100 -75
  98. validmind/models/catboost.py +0 -33
  99. validmind/models/statsmodels.py +0 -50
  100. validmind/models/xgboost.py +0 -30
  101. validmind/tests/model_validation/BertScoreAggregate.py +0 -90
  102. validmind/tests/model_validation/RegardHistogram.py +0 -148
  103. validmind/tests/model_validation/RougeMetrics.py +0 -147
  104. validmind/tests/model_validation/RougeMetricsAggregate.py +0 -133
  105. validmind/tests/model_validation/SelfCheckNLIScore.py +0 -112
  106. validmind/tests/model_validation/ToxicityHistogram.py +0 -136
  107. validmind/vm_models/dataset.py +0 -1303
  108. {validmind-2.1.1.dist-info → validmind-2.2.2.dist-info}/LICENSE +0 -0
  109. {validmind-2.1.1.dist-info → validmind-2.2.2.dist-info}/WHEEL +0 -0
  110. {validmind-2.1.1.dist-info → validmind-2.2.2.dist-info}/entry_points.txt +0 -0
validmind/vm_models/dataset.py
@@ -1,1303 +0,0 @@
- # Copyright © 2023-2024 ValidMind Inc. All rights reserved.
- # See the LICENSE file in the root of this repository for details.
- # SPDX-License-Identifier: AGPL-3.0 AND ValidMind Commercial
-
- """
- Dataset class wrapper
- """
-
- import warnings
- from abc import ABC, abstractmethod
- from dataclasses import dataclass, field
-
- import numpy as np
- import pandas as pd
- import polars as pl
-
- from validmind.errors import MissingOrInvalidModelPredictFnError
- from validmind.logging import get_logger
- from validmind.vm_models.model import VMModel
-
- logger = get_logger(__name__)
-
-
- @dataclass
- class VMDataset(ABC):
-     """
-     Abstract base class for VM datasets.
-     """
-
-     input_id: str = None
-
-     @property
-     @abstractmethod
-     def raw_dataset(self):
-         """
-         Returns the raw dataset.
-         """
-         pass
-
-     @abstractmethod
-     def assign_predictions(
-         self,
-         model,
-         prediction_values: list = None,
-         prediction_probabilities: list = None,
-         prediction_column=None,
-         probability_column=None,
-     ):
-         """
-         Assigns predictions to the dataset for a given model or prediction values.
-         The dataset is updated with a new column containing the predictions.
-         """
-         pass
-
-     @abstractmethod
-     def get_extra_column(self, column_name):
-         """
-         Returns the values of the specified extra column.
-
-         Args:
-             column_name (str): The name of the extra column.
-
-         Returns:
-             np.ndarray: The values of the extra column.
-         """
-         pass
-
-     @abstractmethod
-     def add_extra_column(self, column_name, column_values=None):
-         """
-         Adds an extra column to the dataset without modifying the dataset `features` and `target` columns.
-
-         Args:
-             column_name (str): The name of the extra column.
-             column_values (np.ndarray, optional): The values of the extra column.
-         """
-         pass
-
-     @property
-     @abstractmethod
-     def input_id(self) -> str:
-         """
-         Returns the input id of the dataset.
-
-         Returns:
-             str: input_id.
-         """
-         return self.input_id
-
-     @property
-     @abstractmethod
-     def columns(self) -> list:
-         """
-         Returns the list of columns in the dataset.
-
-         Returns:
-             List[str]: The columns list.
-         """
-         pass
-
-     @property
-     @abstractmethod
-     def target_column(self) -> str:
-         """
-         Returns the target column name of the dataset.
-
-         Returns:
-             str: The target column name.
-         """
-         pass
-
-     @property
-     @abstractmethod
-     def feature_columns(self) -> list:
-         """
-         Returns the feature columns of the dataset. If _feature_columns is None,
-         it returns all columns except the target column.
-
-         Returns:
-             list: The list of feature column names.
-         """
-         pass
-
-     @property
-     @abstractmethod
-     def text_column(self) -> str:
-         """
-         Returns the text column of the dataset.
-
-         Returns:
-             str: The text column name.
-         """
-         pass
-
-     @property
-     @abstractmethod
-     def x(self) -> np.ndarray:
-         """
-         Returns the input features (X) of the dataset.
-
-         Returns:
-             np.ndarray: The input features.
-         """
-         pass
-
-     @property
-     @abstractmethod
-     def y(self) -> np.ndarray:
-         """
-         Returns the target variables (y) of the dataset.
-
-         Returns:
-             np.ndarray: The target variables.
-         """
-         pass
-
-     @abstractmethod
-     def y_pred(self, model) -> np.ndarray:
-         """
-         Returns the prediction values (y_pred) of the dataset for a given model.
-
-         Returns:
-             np.ndarray: The prediction values.
-         """
-         pass
-
-     def y_prob(self, model) -> np.ndarray:
-         """
-         Returns the prediction probabilities (y_prob) of the dataset for a given model.
-
-         Returns:
-             np.ndarray: The prediction probabilities.
-         """
-         pass
-
-     @property
-     @abstractmethod
-     def df(self):
-         """
-         Returns the dataset as a pandas DataFrame.
-
-         Returns:
-             pd.DataFrame: The dataset as a DataFrame.
-         """
-         pass
-
-     @property
-     @abstractmethod
-     def copy(self):
-         """
-         Returns a copy of the raw_dataset dataframe.
-         """
-         pass
-
-     @abstractmethod
-     def x_df(self):
-         """
-         Returns the non-target, non-prediction (feature) columns.
-
-         Returns:
-             pd.DataFrame: The feature columns.
-         """
-         pass
-
-     @abstractmethod
-     def y_df(self):
-         """
-         Returns the target column (y) of the dataset.
-
-         Returns:
-             pd.DataFrame: The target column.
-         """
-         pass
-
-     @abstractmethod
-     def y_pred_df(self, model):
-         """
-         Returns the prediction column (y_pred) of the dataset for a given model.
-
-         Returns:
-             pd.DataFrame: The prediction column.
-         """
-         pass
-
-     @abstractmethod
-     def y_prob_df(self, model):
-         """
-         Returns the probability column (y_prob) of the dataset for a given model.
-
-         Returns:
-             pd.DataFrame: The probability column.
-         """
-         pass
-
-     @abstractmethod
-     def prediction_column(self, model) -> str:
-         """
-         Returns the prediction column name of the dataset.
-
-         Returns:
-             str: The prediction column name.
-         """
-         pass
-
-     def probability_column(self, model) -> str:
-         """
-         Returns the probability column name of the dataset.
-
-         Returns:
-             str: The probability column name.
-         """
-         pass
-
-     @abstractmethod
-     def get_features_columns(self):
-         """
-         Returns the column names of the feature variables.
-
-         Returns:
-             List[str]: The column names of the feature variables.
-         """
-         pass
-
-     @abstractmethod
-     def get_numeric_features_columns(self):
-         """
-         Returns the column names of the numeric feature variables.
-
-         Returns:
-             List[str]: The column names of the numeric feature variables.
-         """
-         pass
-
-     @abstractmethod
-     def get_categorical_features_columns(self):
-         """
-         Returns the column names of the categorical feature variables.
-
-         Returns:
-             List[str]: The column names of the categorical feature variables.
-         """
-         pass
-
-
- @dataclass
- class NumpyDataset(VMDataset):
-     """
-     VM dataset implementation for NumPy arrays.
-     """
-
-     _input_id: str = None
-     _raw_dataset: np.ndarray = None
-     _index: np.ndarray = None
-     _index_name: str = None
-     _columns: list = field(init=True, default=None)
-     _target_column: str = field(init=True, default=None)
-     _feature_columns: list = field(init=True, default=None)
-     _text_column: str = field(init=True, default=None)
-     _type: str = "generic"
-     _target_class_labels: dict = field(init=True, default=None)
-     _df: pd.DataFrame = field(init=True, default=None)
-     _extra_columns: dict = field(
-         default_factory=lambda: {
-             "prediction_columns": {},
-             "probability_columns": {},
-             "group_by_column": None,
-         }
-     )
-
-     def __init__(
-         self,
-         raw_dataset,
-         input_id: str = None,
-         model: VMModel = None,
-         index=None,
-         index_name=None,
-         date_time_index=False,
-         columns=None,
-         target_column: str = None,
-         feature_columns: list = None,
-         text_column=None,
-         extra_columns: dict = None,
-         target_class_labels: dict = None,
-         options: dict = None,
-     ):
-         """
-         Initializes a NumpyDataset instance.
-
-         Args:
-             raw_dataset (np.ndarray): The raw dataset as a NumPy array.
-             index (np.ndarray): The raw dataset index as a NumPy array.
-             index_name (str): The name of the dataset index.
-             date_time_index (bool): Whether the index is a datetime index.
-             columns (List[str], optional): The column names of the dataset. Defaults to None.
-             target_column (str, optional): The target column name of the dataset. Defaults to None.
-             feature_columns (list, optional): The feature column names of the dataset. Defaults to None.
-             text_column (str, optional): The text column name of the dataset for NLP tasks. Defaults to None.
-             target_class_labels (Dict, optional): The class labels for the target columns. Defaults to None.
-             options (Dict, optional): Additional options for the dataset. Defaults to None.
-         """
-         # initialize input_id
-         self._input_id = input_id
-
-         # initialize raw dataset
-         if not isinstance(raw_dataset, np.ndarray):
-             raise ValueError("Expected Numpy array for attribute raw_dataset")
-         self._raw_dataset = raw_dataset
-
-         # initialize index and index name
-         if index is not None and not isinstance(index, np.ndarray):
-             raise ValueError("Expected Numpy array for attribute index")
-         self._index = index
-         self._index_name = index_name
-
-         # initialize columns and df
-         self._columns = columns or []
-         if not self._columns:
-             df = pd.DataFrame(self._raw_dataset).infer_objects()
-             self._columns = df.columns.to_list()
-         else:
-             df = pd.DataFrame(self._raw_dataset, columns=self._columns).infer_objects()
-
-         # set index to dataframe
-         if index is not None:
-             df.set_index(pd.Index(index), inplace=True)
-             df.index.name = index_name
-
-         # attempt to convert index to datetime
-         if date_time_index:
-             df = self.__attempt_convert_index_to_datetime(df)
-
-         # initialize dataframe
-         self._df = df
-
-         # initialize target column
-         self._target_column = target_column
-         # initialize extra columns
-         self.__set_extra_columns(extra_columns)
-         # initialize feature columns
-         self.__set_feature_columns(feature_columns)
-         # initialize text column, target class labels and options
-         self._text_column = text_column
-         self._target_class_labels = target_class_labels
-         self.options = options
-         if model:
-             self.assign_predictions(model)
-
-     def __set_extra_columns(self, extra_columns):
-         if extra_columns is None:
-             extra_columns = {
-                 "prediction_columns": {},
-                 "probability_columns": {},
-                 "group_by_column": None,
-             }
-         self._extra_columns = extra_columns
-
-     def __set_feature_columns(self, feature_columns):
-         ex_columns = []
-
-         if self._extra_columns.get("prediction_columns"):
-             ex_columns.extend(self._extra_columns["prediction_columns"].values())
-
-         if self._extra_columns.get("group_by_column"):
-             ex_columns.extend(self._extra_columns["group_by_column"])
-
-         extra_columns_list = ex_columns if not feature_columns else []
-
-         if not feature_columns:
-             self._feature_columns = [
-                 col
-                 for col in self._columns
-                 if col != self._target_column and col not in extra_columns_list
-             ]
-         else:
-             if not isinstance(feature_columns, list):
-                 raise ValueError("Expected list for attribute feature_columns")
-             self._feature_columns = feature_columns
-
-     def __attempt_convert_index_to_datetime(self, df):
-         """
-         Attempts to convert the index of the dataset to a datetime index
-         and leaves the index unchanged if it fails.
-         """
-         converted_index = pd.to_datetime(df.index, errors="coerce")
-
-         # The conversion was successful if there are no NaT values
-         if not converted_index.isnull().any():
-             df.index = converted_index
-
-         return df
-
-     def __model_id_in_probability_columns(self, model, probability_column):
-         return model.input_id in self._extra_columns.get("probability_columns", {})
-
-     def __model_id_in_prediction_columns(self, model, prediction_column):
-         return model.input_id in self._extra_columns.get("prediction_columns", {})
-
-     def __assign_prediction_values(self, model, pred_column, prediction_values):
-         # Link the prediction column with the model
-         self._extra_columns.setdefault("prediction_columns", {})[
-             model.input_id
-         ] = pred_column
-
-         # Check if the predictions are multi-dimensional (e.g., embeddings)
-         is_multi_dimensional = (
-             isinstance(prediction_values, np.ndarray) and prediction_values.ndim > 1
-         )
-
-         if is_multi_dimensional:
-             # For multi-dimensional outputs, convert to a list of lists to store in DataFrame
-             self._df[pred_column] = list(map(list, prediction_values))
-         else:
-             # If not multi-dimensional, reshape for compatibility
-             self._raw_dataset = np.hstack(
-                 (self._raw_dataset, np.array(prediction_values).reshape(-1, 1))
-             )
-             self._df[pred_column] = prediction_values
-
-         # Update the dataset columns list
-         if pred_column not in self._columns:
-             self._columns.append(pred_column)
-
-     def __assign_prediction_probabilities(
-         self, model, prob_column, prediction_probabilities
-     ):
-         # Link the probability column with the model
-         self._extra_columns.setdefault("probability_columns", {})[
-             model.input_id
-         ] = prob_column
-
-         # Check if the probabilities are multi-dimensional (e.g., embeddings)
-         is_multi_dimensional = (
-             isinstance(prediction_probabilities, np.ndarray)
-             and prediction_probabilities.ndim > 1
-         )
-
-         if is_multi_dimensional:
-             # For multi-dimensional outputs, convert to a list of lists to store in DataFrame
-             self._df[prob_column] = list(map(list, prediction_probabilities))
-         else:
-             # If not multi-dimensional, reshape for compatibility
-             self._raw_dataset = np.hstack(
-                 (self._raw_dataset, np.array(prediction_probabilities).reshape(-1, 1))
-             )
-             self._df[prob_column] = prediction_probabilities
-
-         # Update the dataset columns list
-         if prob_column not in self._columns:
-             self._columns.append(prob_column)
-
-     def assign_predictions(  # noqa: C901 - we need to simplify this method
-         self,
-         model,
-         prediction_values: list = None,
-         prediction_probabilities: list = None,
-         prediction_column=None,
-         probability_column=None,
-     ):
-         def _is_probability(output):
-             """Check if the output from the predict method is probabilities."""
-             # Treat the output as probabilities if all values lie in [0, 1]
-             # and at least one value is strictly between 0 and 1
-             return np.all((output >= 0) & (output <= 1)) and np.any(
-                 (output > 0) & (output < 1)
-             )
-
-         # Step 1: Check for Model Presence
-         if not model:
-             raise ValueError(
-                 "Model must be provided to link prediction column with the dataset"
-             )
-
-         # Step 2: Prediction Column Provided
-         if prediction_column:
-             if prediction_column not in self.columns:
-                 raise ValueError(
-                     f"Prediction column {prediction_column} doesn't exist in the dataset"
-                 )
-             if self.__model_id_in_prediction_columns(
-                 model=model, prediction_column=prediction_column
-             ):
-                 raise ValueError(
-                     f"Prediction column {prediction_column} already linked to the VM model"
-                 )
-             self._extra_columns.setdefault("prediction_columns", {})[
-                 model.input_id
-             ] = prediction_column
-
-         # Step 4: Prediction Values Provided without Specific Column
-         elif prediction_values is not None:
-             if len(prediction_values) != self.df.shape[0]:
-                 raise ValueError(
-                     "Length of prediction values doesn't match number of rows of the dataset"
-                 )
-             pred_column = f"{model.input_id}_prediction"
-             if pred_column in self.columns:
-                 warnings.warn(
-                     f"Prediction column {pred_column} already exists in the dataset, overwriting the existing predictions",
-                     UserWarning,
-                 )
-
-             logger.info(
-                 f"Assigning prediction values to column '{pred_column}' and linking it to model '{model.input_id}'"
-             )
-             self.__assign_prediction_values(model, pred_column, prediction_values)
-
-         # Step 3: Probability Column Provided
-         if probability_column:
-             if probability_column not in self.columns:
-                 raise ValueError(
-                     f"Probability column {probability_column} doesn't exist in the dataset"
-                 )
-             if self.__model_id_in_probability_columns(
-                 model=model, probability_column=probability_column
-             ):
-                 raise ValueError(
-                     f"Probability column {probability_column} already linked to the VM model"
-                 )
-             self._extra_columns.setdefault("probability_columns", {})[
-                 model.input_id
-             ] = probability_column
-
-         # Step 5: Prediction Probabilities Provided without Specific Column
-         elif prediction_probabilities is not None:
-             if len(prediction_probabilities) != self.df.shape[0]:
-                 raise ValueError(
-                     "Length of prediction probabilities doesn't match number of rows of the dataset"
-                 )
-             prob_column = f"{model.input_id}_probabilities"
-             if prob_column in self.columns:
-                 warnings.warn(
-                     f"Probability column {prob_column} already exists in the dataset, overwriting the existing probabilities",
-                     UserWarning,
-                 )
-
-             logger.info(
-                 f"Assigning prediction probabilities to column '{prob_column}' and linking it to model '{model.input_id}'"
-             )
-             self.__assign_prediction_probabilities(
-                 model, prob_column, prediction_probabilities
-             )
-
-         # Step 6: Neither Specific Column Nor Values Provided
-         elif not self.__model_id_in_prediction_columns(
-             model=model, prediction_column=prediction_column
-         ):
-
-             # Compute prediction values directly from the VM model
-             pred_column = f"{model.input_id}_prediction"
-             if pred_column in self.columns:
-                 logger.info(
-                     f"Prediction column {pred_column} already exists in the dataset. Linking the model with the {pred_column} column"
-                 )
-                 return
-
-             logger.info("Running predict()... This may take a while")
-
-             # If the model is a FoundationModel, we need to pass the DataFrame to
-             # the predict method since it requires column names in order to format
-             # the input prompt with its template variables
-             x_only = (
-                 self.x_df() if model.model_library() == "FoundationModel" else self.x
-             )
-
-             prediction_values = np.array(model.predict(x_only))
-
-             # Check if the prediction values are probabilities
-             if _is_probability(prediction_values):
-
-                 threshold = 0.5
-
-                 logger.info(
-                     "Predict method returned probabilities instead of direct labels or regression values. "
-                     + "This implies the model is likely configured for a classification task with probability output."
-                 )
-                 prob_column = f"{model.input_id}_probabilities"
-                 logger.info(
-                     f"Assigning probabilities to column '{prob_column}' and computing class labels using a threshold of {threshold}."
-                 )
-                 self.__assign_prediction_probabilities(
-                     model, prob_column, prediction_values
-                 )
-
-                 # Convert probabilities to class labels based on the threshold
-                 prediction_classes = (prediction_values > threshold).astype(int)
-                 self.__assign_prediction_values(model, pred_column, prediction_classes)
-
-             else:
-
-                 # If not, assign the prediction values directly
-                 pred_column = f"{model.input_id}_prediction"
-                 self.__assign_prediction_values(model, pred_column, prediction_values)
-
-                 try:
-                     logger.info("Running predict_proba()... This may take a while")
-                     prediction_probabilities = np.array(model.predict_proba(x_only))
-                     prob_column = f"{model.input_id}_probabilities"
-                     self.__assign_prediction_probabilities(
-                         model, prob_column, prediction_probabilities
-                     )
-                 except MissingOrInvalidModelPredictFnError:
-                     # Log that predict_proba is not available or failed
-                     logger.warn(
-                         f"Model class '{model.__class__}' does not have a compatible predict_proba implementation."
-                         + " Please assign predictions directly with vm_dataset.assign_predictions(model, prediction_values)"
-                     )
-
-         # Step 7: Prediction Column Already Linked
-         else:
-             logger.info(
-                 f"Prediction column {self._extra_columns['prediction_columns'][model.input_id]} already linked to model '{model.input_id}'"
-             )
-
-     def get_extra_column(self, column_name):
-         """
-         Returns the values of the specified extra column.
-
-         Args:
-             column_name (str): The name of the extra column.
-
-         Returns:
-             np.ndarray: The values of the extra column.
-         """
-         if column_name not in self.extra_columns:
-             raise ValueError(f"Column {column_name} is not an extra column")
-
-         return self._df[column_name]
-
-     def add_extra_column(self, column_name, column_values=None):
-         """
-         Adds an extra column to the dataset without modifying the dataset `features` and `target` columns.
-
-         Args:
-             column_name (str): The name of the extra column.
-             column_values (np.ndarray, optional): The values of the extra column.
-         """
-         if column_name in self.extra_columns:
-             logger.info(f"Column {column_name} already registered as an extra column")
-             return
-
-         # The column already exists in the dataset, so we just register it as an extra column
-         if column_name in self.columns:
-             self._extra_columns[column_name] = column_name
-             logger.info(
-                 f"Column {column_name} exists in the dataset, registering as an extra column"
-             )
-             return
-
-         if column_values is None:
-             raise ValueError(
-                 "Column values must be provided when the column doesn't exist in the dataset"
-             )
-
-         if len(column_values) != self.df.shape[0]:
-             raise ValueError(
-                 "Length of column values doesn't match number of rows of the dataset"
-             )
-
-         self._raw_dataset = np.hstack(
-             (self._raw_dataset, np.array(column_values).reshape(-1, 1))
-         )
-         self._columns.append(column_name)
-         self._df[column_name] = column_values
-         self._extra_columns[column_name] = column_name
-         logger.info(f"Column {column_name} added as an extra column")
-
-     @property
-     def raw_dataset(self) -> np.ndarray:
-         """
-         Returns the raw dataset.
-
-         Returns:
-             np.ndarray: The raw dataset.
-         """
-         return self._raw_dataset
-
-     @property
-     def input_id(self) -> str:
-         """
-         Returns the input id of the dataset.
-
-         Returns:
-             str: input_id.
-         """
-         return self._input_id
-
-     @property
-     def index(self) -> np.ndarray:
-         """
-         Returns the index of the dataset.
-
-         Returns:
-             np.ndarray: The dataset index.
-         """
-         return self._index
-
-     @property
-     def index_name(self) -> str:
-         """
-         Returns the index name of the dataset.
-
-         Returns:
-             str: The dataset index name.
-         """
-         return self._df.index.name
-
-     @property
-     def columns(self) -> list:
-         """
-         Returns the list of columns in the dataset.
-
-         Returns:
-             List[str]: The columns list.
-         """
-         return self._columns
-
-     @property
-     def target_column(self) -> str:
-         """
-         Returns the target column name of the dataset.
-
-         Returns:
-             str: The target column name.
-         """
-         return self._target_column
-
-     @property
-     def extra_columns(self) -> dict:
-         """
-         Returns the extra columns of the dataset.
-
-         Returns:
-             dict: The extra columns mapping.
-         """
-         return self._extra_columns
-
-     @property
-     def group_by_column(self) -> str:
-         """
-         Returns the group by column name of the dataset.
-
-         Returns:
-             str: The group by column name.
-         """
-         return self._extra_columns["group_by_column"]
-
-     @property
-     def feature_columns(self) -> list:
-         """
-         Returns the feature columns of the dataset, or an empty list
-         if _feature_columns is None.
-
-         Returns:
-             list: The list of feature column names.
-         """
-         return self._feature_columns or []
-
-     @property
-     def text_column(self) -> str:
-         """
-         Returns the text column of the dataset.
-
-         Returns:
-             str: The text column name.
-         """
-         return self._text_column
-
-     @property
-     def x(self) -> np.ndarray:
-         """
-         Returns the input features (X) of the dataset.
-
-         Returns:
-             np.ndarray: The input features.
-         """
-         return self.raw_dataset[
-             :,
-             [
-                 self.columns.index(name)
-                 for name in self.columns
-                 if name in self.feature_columns
-             ],
-         ]
-
-     @property
-     def y(self) -> np.ndarray:
-         """
-         Returns the target variables (y) of the dataset.
-
-         Returns:
-             np.ndarray: The target variables.
-         """
-         return self.raw_dataset[
-             :,
-             [
-                 self.columns.index(name)
-                 for name in self.columns
-                 if name == self.target_column
-             ],
-         ]
-
-     def y_pred(self, model) -> np.ndarray:
-         """
-         Returns the prediction values for a given model, accommodating
-         both scalar predictions and multi-dimensional outputs such as embeddings.
-
-         Args:
-             model (VMModel): The model whose predictions are sought.
-
-         Returns:
-             np.ndarray: The prediction values, either as a flattened array for
-             scalar predictions or as an array of arrays for multi-dimensional outputs.
-         """
-         pred_column = self.prediction_column(model)
-
-         # First, attempt to retrieve the prediction data from the DataFrame
-         if hasattr(self, "_df") and pred_column in self._df.columns:
-             predictions = self._df[pred_column].to_numpy()
-
-             # Check if the predictions are stored as objects (e.g., lists for embeddings)
-             if self._df[pred_column].dtype == object:
-                 # Attempt to convert lists to a numpy array
-                 try:
-                     predictions = np.stack(predictions)
-                 except ValueError as e:
-                     # Handle cases where predictions cannot be directly stacked
-                     raise ValueError(f"Error stacking prediction arrays: {e}")
-         else:
-             # Fall back to the raw numpy dataset if the DataFrame is not available or suitable
-             try:
-                 predictions = self.raw_dataset[
-                     :, self.columns.index(pred_column)
-                 ].flatten()
-             except (IndexError, ValueError) as e:
-                 raise ValueError(
-                     f"Prediction column '{pred_column}' not found in raw dataset: {e}"
-                 )
-
-         return predictions
-
-     def y_prob(self, model) -> np.ndarray:
-         """
-         Returns the prediction probabilities for a given model, accommodating
-         both scalar probabilities and multi-dimensional outputs.
-
-         Args:
-             model (VMModel): The model whose prediction probabilities are sought.
-
-         Returns:
-             np.ndarray: The prediction probabilities, either as a flattened array for
-             scalar probabilities or as an array of arrays for multi-dimensional outputs.
-         """
-         prob_column = self.probability_column(model)
-
-         # First, attempt to retrieve the probability data from the DataFrame
-         if hasattr(self, "_df") and prob_column in self._df.columns:
-             probabilities = self._df[prob_column].to_numpy()
-
-             # Check if the probabilities are stored as objects (e.g., lists for embeddings)
-             if self._df[prob_column].dtype == object:
-                 # Attempt to convert lists to a numpy array
-                 try:
-                     probabilities = np.stack(probabilities)
-                 except ValueError as e:
-                     # Handle cases where probabilities cannot be directly stacked
-                     raise ValueError(f"Error stacking probability arrays: {e}")
-         else:
-             # Fall back to the raw numpy dataset if the DataFrame is not available or suitable
-             try:
-                 probabilities = self.raw_dataset[
-                     :, self.columns.index(prob_column)
-                 ].flatten()
-             except (IndexError, ValueError) as e:
-                 raise ValueError(
-                     f"Probability column '{prob_column}' not found in raw dataset: {e}"
-                 )
-
-         return probabilities
-
-     @property
-     def type(self) -> str:
-         """
-         Returns the type of the dataset.
-
-         Returns:
-             str: The dataset type.
-         """
-         return self._type
-
-     @property
-     def df(self):
-         """
-         Returns the dataset as a pandas DataFrame.
-
-         Returns:
-             pd.DataFrame: The dataset as a DataFrame.
-         """
-         return self._df
-
-     @property
-     def copy(self):
-         """
-         Returns a copy of the raw_dataset dataframe.
-         """
-         return self._df.copy()
-
-     def x_df(self):
-         """
-         Returns the non-target, non-prediction (feature) columns.
-
-         Returns:
-             pd.DataFrame: The feature columns.
-         """
-         return self._df[[name for name in self.columns if name in self.feature_columns]]
-
-     def y_df(self):
-         """
-         Returns the target column (y) of the dataset.
-
-         Returns:
-             pd.Series: The target column.
-         """
-         return self._df[self.target_column]
-
-     def y_pred_df(self, model):
-         """
-         Returns the prediction column (y_pred) of the dataset for a given model.
-
-         Returns:
-             pd.Series: The prediction column.
-         """
-         return self._df[self.prediction_column(model)]
-
-     def y_prob_df(self, model):
-         """
-         Returns the probability column (y_prob) of the dataset for a given model.
-
-         Returns:
-             pd.Series: The probability column.
-         """
-         return self._df[self.probability_column(model)]
-
-     def prediction_column(self, model) -> str:
-         """
-         Returns the prediction column name of the dataset.
-
-         Returns:
-             str: The prediction column name.
-         """
-         model_id = model.input_id
-         pred_column = self._extra_columns.get("prediction_columns", {}).get(model_id)
-         if pred_column is None:
-             raise ValueError(
-                 f"Prediction column is not linked with the given model {model_id}"
-             )
-         return pred_column
-
-     def probability_column(self, model) -> str:
-         """
-         Returns the probability column name of the dataset.
-
-         Returns:
-             str: The probability column name.
-         """
-         model_id = model.input_id
-         prob_column = self._extra_columns.get("probability_columns", {}).get(model_id)
-         if prob_column is None:
-             raise ValueError(
-                 f"Probability column is not linked with the given model {model_id}"
-             )
-         return prob_column
-
-     def serialize(self):
-         """
-         Serializes the dataset to a dictionary.
-
-         Returns:
-             Dict: The serialized dataset.
-         """
-         # Datasets with no targets can be logged
-         dataset_dict = {}
-         dataset_dict["targets"] = {
-             "target_column": self.target_column,
-             "class_labels": self._target_class_labels,
-         }
-
-         return dataset_dict
-
-     def get_feature_type(self, feature_id):
-         """
-         Returns the type of the specified feature.
-
-         Args:
-             feature_id (str): The ID of the feature.
-
-         Returns:
-             str: The type of the feature.
-         """
-         feature = self.get_feature_by_id(feature_id)
-         return feature["type"]
-
-     def get_features_columns(self):
-         """
-         Returns the column names of the feature variables.
-
-         Returns:
-             List[str]: The column names of the feature variables.
-         """
-         return self.feature_columns
-
-     def get_numeric_features_columns(self):
-         """
-         Returns the column names of the numeric feature variables.
-
-         Returns:
-             List[str]: The column names of the numeric feature variables.
-         """
-         numerical_columns = (
-             self.df[self.feature_columns]
-             .select_dtypes(include=[np.number])
-             .columns.tolist()
-         )
-
-         return [column for column in numerical_columns if column != self.target_column]
-
-     def get_categorical_features_columns(self):
-         """
-         Returns the column names of the categorical feature variables.
-
-         Returns:
-             List[str]: The column names of the categorical feature variables.
-         """
-
-         # Extract categorical columns from the dataset
-         categorical_columns = (
-             self.df[self.feature_columns]
-             .select_dtypes(include=[object, pd.Categorical])
-             .columns.tolist()
-         )
-
-         return [
-             column for column in categorical_columns if column != self.target_column
-         ]
-
-     def target_classes(self):
-         """
-         Returns the unique class labels of the target (y) variable.
-         """
-         return [str(i) for i in np.unique(self.y)]
-
-     def prediction_classes(self):
-         """
-         Returns the unique predicted class labels.
-         """
-         return [str(i) for i in np.unique(self.y_pred)]
-
-     def __str__(self):
-         return (
-             f"=================\n"
-             f"VMDataset object: \n"
-             f"=================\n"
-             f"Input ID: {self._input_id}\n"
-             f"Target Column: {self._target_column}\n"
-             f"Feature Columns: {self._feature_columns}\n"
-             f"Text Column: {self._text_column}\n"
-             f"Extra Columns: {self._extra_columns}\n"
-             f"Type: {self._type}\n"
-             f"Target Class Labels: {self._target_class_labels}\n"
-             f"Columns: {self._columns}\n"
-             f"Index Name: {self._index_name}\n"
-             f"Index: {self._index}\n"
-             f"=================\n"
-         )
-
-
- @dataclass
- class DataFrameDataset(NumpyDataset):
-     """
-     VM dataset implementation for pandas DataFrame.
-     """
-
-     def __init__(
-         self,
-         raw_dataset: pd.DataFrame,
-         input_id: str = None,
-         model: VMModel = None,
-         target_column: str = None,
-         extra_columns: dict = None,
-         feature_columns: list = None,
-         text_column: str = None,
-         target_class_labels: dict = None,
-         options: dict = None,
-         date_time_index: bool = False,
-     ):
-         """
-         Initializes a DataFrameDataset instance.
-
-         Args:
-             raw_dataset (pd.DataFrame): The raw dataset as a pandas DataFrame.
-             input_id (str, optional): Identifier for the dataset. Defaults to None.
-             model (VMModel, optional): Model associated with the dataset. Defaults to None.
-             target_column (str, optional): The target column of the dataset. Defaults to None.
-             extra_columns (dict, optional): Extra columns to include in the dataset. Defaults to None.
-             feature_columns (list, optional): The feature columns of the dataset. Defaults to None.
-             text_column (str, optional): The text column name of the dataset for NLP tasks. Defaults to None.
-             target_class_labels (dict, optional): The class labels for the target columns. Defaults to None.
-             options (dict, optional): Additional options for the dataset. Defaults to None.
-             date_time_index (bool, optional): Whether to use a date-time index. Defaults to False.
-         """
-         index = None
-         if isinstance(raw_dataset.index, pd.Index):
-             index = raw_dataset.index.values
-
-         super().__init__(
-             raw_dataset=raw_dataset.values,
-             input_id=input_id,
-             model=model,
-             index_name=raw_dataset.index.name,
-             index=index,
-             columns=raw_dataset.columns.to_list(),
-             target_column=target_column,
-             extra_columns=extra_columns,
-             feature_columns=feature_columns,
-             text_column=text_column,
-             target_class_labels=target_class_labels,
-             options=options,
-             date_time_index=date_time_index,
-         )
-
-
- @dataclass
- class PolarsDataset(NumpyDataset):
-     """
-     VM dataset implementation for Polars DataFrame.
-     """
-
-     def __init__(
-         self,
-         raw_dataset: pl.DataFrame,
-         input_id: str = None,
-         model: VMModel = None,
-         target_column: str = None,
-         extra_columns: dict = None,
-         feature_columns: list = None,
-         text_column: str = None,
-         target_class_labels: dict = None,
-         options: dict = None,
-         date_time_index: bool = False,
-     ):
-         """
-         Initializes a PolarsDataset instance.
-
-         Args:
-             raw_dataset (pl.DataFrame): The raw dataset as a Polars DataFrame.
-             input_id (str, optional): Identifier for the dataset. Defaults to None.
-             model (VMModel, optional): Model associated with the dataset. Defaults to None.
-             target_column (str, optional): The target column of the dataset. Defaults to None.
-             extra_columns (dict, optional): Extra columns to include in the dataset. Defaults to None.
-             feature_columns (list, optional): The feature columns of the dataset. Defaults to None.
-             text_column (str, optional): The text column name of the dataset for NLP tasks. Defaults to None.
-             target_class_labels (dict, optional): The class labels for the target columns. Defaults to None.
-             options (dict, optional): Additional options for the dataset. Defaults to None.
-             date_time_index (bool, optional): Whether to use a date-time index. Defaults to False.
-         """
-         super().__init__(
-             raw_dataset=raw_dataset.to_numpy(),
-             input_id=input_id,
-             model=model,
-             index_name=None,
-             index=None,
-             columns=raw_dataset.columns,
-             target_column=target_column,
-             extra_columns=extra_columns,
-             feature_columns=feature_columns,
-             text_column=text_column,
-             target_class_labels=target_class_labels,
-             options=options,
-             date_time_index=date_time_index,
-         )
-
-
- @dataclass
- class TorchDataset(NumpyDataset):
-     """
-     VM dataset implementation for PyTorch Datasets.
-     """
-
-     def __init__(
-         self,
-         raw_dataset,
-         input_id: str = None,
-         model: VMModel = None,
-         index_name=None,
-         index=None,
-         columns=None,
-         target_column: str = None,
-         extra_columns: dict = None,
-         feature_columns: list = None,
-         text_column: str = None,
-         target_class_labels: dict = None,
-         options: dict = None,
-     ):
-         """
-         Initializes a TorchDataset instance.
-
-         Args:
-             raw_dataset (torch.utils.data.Dataset): The raw dataset as a PyTorch Dataset.
-             index_name (str): The raw dataset index name.
-             index (np.ndarray): The raw dataset index as a NumPy array.
-             columns (List[str]): The column names of the dataset.
-             target_column (str, optional): The target column of the dataset. Defaults to None.
-             feature_columns (list, optional): The feature columns of the dataset. Defaults to None.
-             text_column (str, optional): The text column name of the dataset for NLP tasks. Defaults to None.
-             target_class_labels (Dict, optional): The class labels for the target columns. Defaults to None.
-         """
-
-         try:
-             import torch
-         except ImportError:
-             raise ImportError(
-                 "PyTorch is not installed, please run `pip install validmind[pytorch]`"
-             )
-
-         columns = []
-
-         for id, tens in enumerate(raw_dataset.tensors):
-             if id == 0 and feature_columns is None:
-                 n_cols = tens.shape[1]
-                 feature_columns = [
-                     "x" + feature_id
-                     for feature_id in np.linspace(
-                         0, n_cols - 1, num=n_cols, dtype=int
-                     ).astype(str)
-                 ]
-                 columns.append(feature_columns)
-
-             elif id == 1 and target_column is None:
-                 target_column = "y"
-                 columns.append(target_column)
-
-             elif id == 2 and extra_columns is None:
-                 extra_columns.prediction_column = "y_pred"
-                 columns.append(extra_columns.prediction_column)
-
-         merged_tensors = torch.cat(raw_dataset.tensors, dim=1).numpy()
-
-         super().__init__(
-             input_id=input_id,
-             raw_dataset=merged_tensors,
-             model=model,
-             index_name=index_name,
-             index=index,
-             columns=columns,
-             target_column=target_column,
-             feature_columns=feature_columns,
-             text_column=text_column,
-             extra_columns=extra_columns,
-             target_class_labels=target_class_labels,
-             options=options,
-         )
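
For context, the module deleted above (validmind/vm_models/dataset.py) is superseded in 2.2.2 by the new validmind/vm_models/dataset/ package (dataset.py and utils.py in the file list). A minimal sketch of how the removed wrappers were typically driven from the top-level API, based only on the docstrings in the deleted code; the fitted `model` object, the CSV path, and the column names are illustrative assumptions, not taken from this diff:

import pandas as pd
import validmind as vm  # assumes the vm.init_model / vm.init_dataset entry points of validmind 2.x

df = pd.read_csv("churn.csv")  # illustrative dataset

# `model` is assumed to be a fitted scikit-learn-style classifier
vm_model = vm.init_model(model, input_id="churn_model")
vm_dataset = vm.init_dataset(
    dataset=df,
    input_id="test_dataset",
    target_column="churn",  # illustrative column name
)

# Runs predict()/predict_proba() on the wrapped model and stores the results
# in "churn_model_prediction" / "churn_model_probabilities" extra columns
vm_dataset.assign_predictions(model=vm_model)

y_pred = vm_dataset.y_pred(vm_model)  # flattened prediction array
y_prob = vm_dataset.y_prob(vm_model)  # probabilities, when the model supports them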