autogluon.tabular 1.3.2b20250610__py3-none-any.whl → 1.4.1b20251214__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. autogluon/tabular/configs/config_helper.py +1 -1
  2. autogluon/tabular/configs/hyperparameter_configs.py +2 -265
  3. autogluon/tabular/configs/pipeline_presets.py +130 -0
  4. autogluon/tabular/configs/presets_configs.py +51 -26
  5. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2023.py +0 -1
  6. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2025.py +310 -0
  7. autogluon/tabular/models/__init__.py +6 -1
  8. autogluon/tabular/models/_utils/rapids_utils.py +1 -1
  9. autogluon/tabular/models/automm/automm_model.py +2 -0
  10. autogluon/tabular/models/automm/ft_transformer.py +4 -1
  11. autogluon/tabular/models/catboost/callbacks.py +3 -2
  12. autogluon/tabular/models/catboost/catboost_model.py +15 -9
  13. autogluon/tabular/models/catboost/catboost_utils.py +17 -3
  14. autogluon/tabular/models/ebm/__init__.py +0 -0
  15. autogluon/tabular/models/ebm/ebm_model.py +259 -0
  16. autogluon/tabular/models/ebm/hyperparameters/__init__.py +0 -0
  17. autogluon/tabular/models/ebm/hyperparameters/parameters.py +39 -0
  18. autogluon/tabular/models/ebm/hyperparameters/searchspaces.py +72 -0
  19. autogluon/tabular/models/fastainn/tabular_nn_fastai.py +7 -5
  20. autogluon/tabular/models/knn/knn_model.py +7 -3
  21. autogluon/tabular/models/lgb/lgb_model.py +60 -21
  22. autogluon/tabular/models/lr/lr_model.py +6 -1
  23. autogluon/tabular/models/lr/lr_preprocessing_utils.py +6 -7
  24. autogluon/tabular/models/lr/lr_rapids_model.py +45 -5
  25. autogluon/tabular/models/mitra/__init__.py +0 -0
  26. autogluon/tabular/models/mitra/_internal/__init__.py +1 -0
  27. autogluon/tabular/models/mitra/_internal/config/__init__.py +1 -0
  28. autogluon/tabular/models/mitra/_internal/config/config_pretrain.py +190 -0
  29. autogluon/tabular/models/mitra/_internal/config/config_run.py +32 -0
  30. autogluon/tabular/models/mitra/_internal/config/enums.py +162 -0
  31. autogluon/tabular/models/mitra/_internal/core/__init__.py +1 -0
  32. autogluon/tabular/models/mitra/_internal/core/callbacks.py +94 -0
  33. autogluon/tabular/models/mitra/_internal/core/get_loss.py +54 -0
  34. autogluon/tabular/models/mitra/_internal/core/get_optimizer.py +108 -0
  35. autogluon/tabular/models/mitra/_internal/core/get_scheduler.py +67 -0
  36. autogluon/tabular/models/mitra/_internal/core/prediction_metrics.py +132 -0
  37. autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py +373 -0
  38. autogluon/tabular/models/mitra/_internal/data/__init__.py +1 -0
  39. autogluon/tabular/models/mitra/_internal/data/collator.py +46 -0
  40. autogluon/tabular/models/mitra/_internal/data/dataset_finetune.py +136 -0
  41. autogluon/tabular/models/mitra/_internal/data/dataset_split.py +57 -0
  42. autogluon/tabular/models/mitra/_internal/data/preprocessor.py +420 -0
  43. autogluon/tabular/models/mitra/_internal/models/__init__.py +1 -0
  44. autogluon/tabular/models/mitra/_internal/models/base.py +21 -0
  45. autogluon/tabular/models/mitra/_internal/models/embedding.py +182 -0
  46. autogluon/tabular/models/mitra/_internal/models/tab2d.py +667 -0
  47. autogluon/tabular/models/mitra/_internal/utils/__init__.py +1 -0
  48. autogluon/tabular/models/mitra/_internal/utils/set_seed.py +15 -0
  49. autogluon/tabular/models/mitra/mitra_model.py +380 -0
  50. autogluon/tabular/models/mitra/sklearn_interface.py +494 -0
  51. autogluon/tabular/models/realmlp/__init__.py +0 -0
  52. autogluon/tabular/models/realmlp/realmlp_model.py +360 -0
  53. autogluon/tabular/models/rf/rf_model.py +11 -6
  54. autogluon/tabular/models/tabicl/__init__.py +0 -0
  55. autogluon/tabular/models/tabicl/tabicl_model.py +179 -0
  56. autogluon/tabular/models/tabm/__init__.py +0 -0
  57. autogluon/tabular/models/tabm/_tabm_internal.py +545 -0
  58. autogluon/tabular/models/tabm/rtdl_num_embeddings.py +810 -0
  59. autogluon/tabular/models/tabm/tabm_model.py +356 -0
  60. autogluon/tabular/models/tabm/tabm_reference.py +631 -0
  61. autogluon/tabular/models/tabpfnmix/tabpfnmix_model.py +13 -7
  62. autogluon/tabular/models/tabpfnv2/__init__.py +0 -0
  63. autogluon/tabular/models/tabpfnv2/rfpfn/__init__.py +20 -0
  64. autogluon/tabular/models/tabpfnv2/rfpfn/configs.py +40 -0
  65. autogluon/tabular/models/tabpfnv2/rfpfn/scoring_utils.py +201 -0
  66. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_decision_tree_tabpfn.py +1464 -0
  67. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_random_forest_tabpfn.py +747 -0
  68. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_compat.py +863 -0
  69. autogluon/tabular/models/tabpfnv2/rfpfn/utils.py +106 -0
  70. autogluon/tabular/models/tabpfnv2/tabpfnv2_model.py +388 -0
  71. autogluon/tabular/models/tabular_nn/hyperparameters/parameters.py +1 -3
  72. autogluon/tabular/models/tabular_nn/torch/tabular_nn_torch.py +5 -5
  73. autogluon/tabular/models/xgboost/xgboost_model.py +10 -3
  74. autogluon/tabular/predictor/predictor.py +147 -84
  75. autogluon/tabular/registry/_ag_model_registry.py +12 -2
  76. autogluon/tabular/testing/fit_helper.py +57 -27
  77. autogluon/tabular/testing/generate_datasets.py +7 -0
  78. autogluon/tabular/trainer/abstract_trainer.py +3 -1
  79. autogluon/tabular/trainer/model_presets/presets.py +10 -1
  80. autogluon/tabular/version.py +1 -1
  81. autogluon.tabular-1.4.1b20251214-py3.11-nspkg.pth +1 -0
  82. {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/METADATA +112 -57
  83. {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/RECORD +89 -40
  84. {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/WHEEL +1 -1
  85. autogluon/tabular/models/tabpfn/__init__.py +0 -1
  86. autogluon/tabular/models/tabpfn/tabpfn_model.py +0 -153
  87. autogluon.tabular-1.3.2b20250610-py3.9-nspkg.pth +0 -1
  88. {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info/licenses}/LICENSE +0 -0
  89. {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info/licenses}/NOTICE +0 -0
  90. {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/namespace_packages.txt +0 -0
  91. {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/top_level.txt +0 -0
  92. {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/zip-safe +0 -0
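Several new model families ship in this release: EBM, Mitra, RealMLP, TabICL, TabM, and TabPFNv2, alongside a 2025 zeroshot portfolio and reworked presets. As a hedged sketch of how the new models can be requested through `TabularPredictor` (the registry keys below are inferred from the new model directories and `_ag_model_registry.py`, not spelled out in this listing):

    import numpy as np
    import pandas as pd
    from autogluon.tabular import TabularPredictor

    # Tiny synthetic dataset, purely for illustration.
    rng = np.random.default_rng(0)
    train_data = pd.DataFrame(rng.normal(size=(500, 4)), columns=list("abcd"))
    train_data["label"] = (train_data["a"] + train_data["b"] > 0).astype(int)

    # Assumed registry keys ("MITRA", "TABM", "REALMLP"); models left out of the
    # hyperparameters dict are simply not trained.
    predictor = TabularPredictor(label="label").fit(
        train_data,
        hyperparameters={"MITRA": {}, "TABM": {}, "REALMLP": {}},
    )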
autogluon/tabular/models/mitra/sklearn_interface.py
@@ -0,0 +1,494 @@
+ from __future__ import annotations
+
+ import os
+ import time
+ from pathlib import Path
+ import contextlib
+
+ import numpy as np
+ import pandas as pd
+ import torch
+ from sklearn.base import BaseEstimator, ClassifierMixin, RegressorMixin
+
+ from ._internal.config.config_run import ConfigRun
+ from ._internal.config.enums import ModelName
+ from ._internal.core.trainer_finetune import TrainerFinetune
+ from ._internal.data.dataset_split import make_stratified_dataset_split
+ from ._internal.models.tab2d import Tab2D
+ from ._internal.utils.set_seed import set_seed
+
+ # Hyperparameter search space
+ DEFAULT_FINE_TUNE = True  # [True, False]
+ DEFAULT_FINE_TUNE_STEPS = 50  # [50, 60, 70, 80, 90, 100]
+ DEFAULT_CLS_METRIC = 'log_loss'  # ['log_loss', 'accuracy', 'auc']
+ DEFAULT_REG_METRIC = 'mse'  # ['mse', 'mae', 'rmse', 'r2']
+ SHUFFLE_CLASSES = False  # [True, False]
+ SHUFFLE_FEATURES = False  # [True, False]
+ USE_RANDOM_TRANSFORMS = False  # [True, False]
+ RANDOM_MIRROR_REGRESSION = True  # [True, False]
+ RANDOM_MIRROR_X = True  # [True, False]
+ LR = 0.0001  # [0.00001, 0.000025, 0.00005, 0.000075, 0.0001, 0.00025, 0.0005, 0.00075, 0.001]
+ PATIENCE = 40  # [30, 35, 40, 45, 50]
+ WARMUP_STEPS = 1000  # [500, 750, 1000, 1250, 1500]
+ DEFAULT_CLS_MODEL = 'autogluon/mitra-classifier'
+ DEFAULT_REG_MODEL = 'autogluon/mitra-regressor'
+
+ # Constants
+ SEED = 0
+ DEFAULT_MODEL_TYPE = "Tab2D"
+
+ def _get_default_device():
+     """Get the best available device for the current system."""
+     if torch.cuda.is_available():
+         return "cuda"
+     elif hasattr(torch.backends, 'mps') and torch.backends.mps.is_available():
+         return "mps"  # Apple silicon
+     else:
+         return "cpu"
+
+ DEFAULT_DEVICE = _get_default_device()
+ DEFAULT_ENSEMBLE = 1
+ DEFAULT_DIM = 512
+ DEFAULT_LAYERS = 12
+ DEFAULT_HEADS = 4
+ DEFAULT_CLASSES = 10
+ DEFAULT_VALIDATION_SPLIT = 0.2
+ USE_HF = True  # Use Hugging Face pretrained models if available
+
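The inline comments on the defaults above record the intended tuning ranges. Purely as an illustration (none of this HPO wiring is in the file), the same ranges expressed as AutoGluon search spaces:

    from autogluon.common import space

    # Hypothetical search space mirroring the commented ranges above.
    mitra_search_space = {
        "fine_tune": space.Categorical(True, False),
        "fine_tune_steps": space.Categorical(50, 60, 70, 80, 90, 100),
        "lr": space.Categorical(1e-5, 2.5e-5, 5e-5, 7.5e-5, 1e-4, 2.5e-4, 5e-4, 7.5e-4, 1e-3),
        "patience": space.Categorical(30, 35, 40, 45, 50),
        "warmup_steps": space.Categorical(500, 750, 1000, 1250, 1500),
    }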
+ class MitraBase(BaseEstimator):
+     """Base class for Mitra models with common functionality."""
+
+     def __init__(self,
+                  model_type=DEFAULT_MODEL_TYPE,
+                  n_estimators=DEFAULT_ENSEMBLE,
+                  device=DEFAULT_DEVICE,
+                  fine_tune=DEFAULT_FINE_TUNE,
+                  fine_tune_steps=DEFAULT_FINE_TUNE_STEPS,
+                  metric=DEFAULT_CLS_METRIC,
+                  state_dict=None,
+                  hf_model=None,
+                  patience=PATIENCE,
+                  lr=LR,
+                  warmup_steps=WARMUP_STEPS,
+                  shuffle_classes=SHUFFLE_CLASSES,
+                  shuffle_features=SHUFFLE_FEATURES,
+                  use_random_transforms=USE_RANDOM_TRANSFORMS,
+                  random_mirror_regression=RANDOM_MIRROR_REGRESSION,
+                  random_mirror_x=RANDOM_MIRROR_X,
+                  seed=SEED,
+                  verbose=True,
+                  ):
+         """
+         Initialize the base Mitra model.
+
+         Parameters
+         ----------
+         model_type : str, default="Tab2D"
+             The type of model to use. Options: "Tab2D", "Tab2D_COL_ROW"
+         n_estimators : int, default=1
+             Number of models in the ensemble
+         device : str, default="cuda"
+             Device to run the model on
+         fine_tune_steps : int, default=50
+             Number of fine-tuning epochs to train for
+         state_dict : str, optional
+             Path to the pretrained weights
+         """
+         self.model_type = model_type
+         self.n_estimators = n_estimators
+         self.device = device
+         self.fine_tune = fine_tune
+         self.fine_tune_steps = fine_tune_steps
+         self.metric = metric
+         self.state_dict = state_dict
+         self.hf_model = hf_model
+         self.patience = patience
+         self.lr = lr
+         self.warmup_steps = warmup_steps
+         self.shuffle_classes = shuffle_classes
+         self.shuffle_features = shuffle_features
+         self.use_random_transforms = use_random_transforms
+         self.random_mirror_regression = random_mirror_regression
+         self.random_mirror_x = random_mirror_x
+         self.trainers = []
+         self.train_time = 0
+         self.seed = seed
+         self.verbose = verbose
+
+         # FIXME: set_seed was removed in v1.4 because a quality and speed reduction was observed when setting the seed.
+         # This should be investigated and fixed for v1.5
+         # set_seed(self.seed)
+
+     def _create_config(self, task, dim_output, time_limit=None):
+         cfg = ConfigRun(
+             device=self.device,
+             model_name=ModelName.TAB2D,
+             seed=self.seed,
+             hyperparams={
+                 'dim_embedding': None,
+                 'early_stopping_data_split': 'VALID',
+                 'early_stopping_max_samples': 2048,
+                 'early_stopping_patience': self.patience,
+                 'grad_scaler_enabled': False,
+                 'grad_scaler_growth_interval': 1000,
+                 'grad_scaler_scale_init': 65536.0,
+                 'grad_scaler_scale_min': 65536.0,
+                 'label_smoothing': 0.0,
+                 'lr_scheduler': False,
+                 'lr_scheduler_patience': 25,
+                 'max_epochs': self.fine_tune_steps if self.fine_tune else 0,
+                 'max_samples_query': 1024,
+                 'max_samples_support': 8192,
+                 'optimizer': 'adamw',
+                 'lr': self.lr,
+                 'weight_decay': 0.1,
+                 'warmup_steps': self.warmup_steps,
+                 'path_to_weights': self.state_dict,
+                 'precision': 'bfloat16',
+                 'random_mirror_regression': self.random_mirror_regression,
+                 'random_mirror_x': self.random_mirror_x,
+                 'shuffle_classes': self.shuffle_classes,
+                 'shuffle_features': self.shuffle_features,
+                 'use_random_transforms': self.use_random_transforms,
+                 'use_feature_count_scaling': False,
+                 'use_pretrained_weights': False,
+                 'use_quantile_transformer': False,
+                 'budget': time_limit,
+                 'metric': self.metric,
+             },
+         )
+
+         cfg.task = task
+         cfg.hyperparams.update({
+             'n_ensembles': self.n_estimators,
+             'dim': DEFAULT_DIM,
+             'dim_output': dim_output,
+             'n_layers': DEFAULT_LAYERS,
+             'n_heads': DEFAULT_HEADS,
+             'regression_loss': 'mse',
+         })
+
+         return cfg, Tab2D
+
+     def _split_data(self, X, y):
+         """Split data into training and validation sets."""
+         if hasattr(self, 'task') and self.task == 'classification':
+             return make_stratified_dataset_split(X, y, seed=self.seed)
+         else:
+             # For regression, use a random split
+             val_indices = np.random.choice(range(len(X)), int(DEFAULT_VALIDATION_SPLIT * len(X)), replace=False).tolist()
+             train_indices = [i for i in range(len(X)) if i not in val_indices]
+             return X[train_indices], X[val_indices], y[train_indices], y[val_indices]
+
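`_split_data` hands classification to the seeded stratified helper but draws the regression split from the global `np.random` state, so regression splits do not honor `self.seed`. A minimal seeded equivalent, sketched here with scikit-learn's `train_test_split` rather than the shipped helper:

    from sklearn.model_selection import train_test_split

    def split_data_seeded(X, y, task, seed, val_frac=0.2):
        # Same return contract as _split_data (X_train, X_valid, y_train, y_valid),
        # but reproducible for regression as well.
        stratify = y if task == "classification" else None
        return train_test_split(X, y, test_size=val_frac, random_state=seed, stratify=stratify)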
+     def _train_ensemble(self, X_train, y_train, X_valid, y_valid, task, dim_output, n_classes=0, time_limit=None):
+         """Train the ensemble of models."""
+
+         cfg, Tab2D = self._create_config(task, dim_output, time_limit)
+         rng = np.random.RandomState(cfg.seed)
+
+         success = False
+         while not (success and cfg.hyperparams["max_samples_support"] > 0 and cfg.hyperparams["max_samples_query"] > 0):
+             try:
+                 self.trainers.clear()
+
+                 self.train_time = 0
+                 for _ in range(self.n_estimators):
+                     if USE_HF:
+                         assert self.hf_model is not None, "hf_model must not be None."
+                         model = Tab2D.from_pretrained(self.hf_model, device=self.device)
+                     else:
+                         model = Tab2D(
+                             dim=cfg.hyperparams['dim'],
+                             dim_output=dim_output,
+                             n_layers=cfg.hyperparams['n_layers'],
+                             n_heads=cfg.hyperparams['n_heads'],
+                             task=task.upper(),
+                             use_pretrained_weights=True,
+                             path_to_weights=Path(self.state_dict),
+                             device=self.device,
+                         )
+                     trainer = TrainerFinetune(cfg, model, n_classes=n_classes, device=self.device, rng=rng, verbose=self.verbose)
+
+                     start_time = time.time()
+                     trainer.train(X_train, y_train, X_valid, y_valid)
+                     end_time = time.time()
+
+                     self.trainers.append(trainer)
+                     self.train_time += end_time - start_time
+
+                 success = True
+
+             except torch.cuda.OutOfMemoryError:
+                 if cfg.hyperparams["max_samples_support"] >= 2048:
+                     cfg.hyperparams["max_samples_support"] = int(cfg.hyperparams["max_samples_support"] // 2)
+                     print(f"Reducing max_samples_support from {cfg.hyperparams['max_samples_support'] * 2} "
+                           f"to {cfg.hyperparams['max_samples_support']} due to OOM error.")
+                 else:
+                     cfg.hyperparams["max_samples_support"] = int(cfg.hyperparams["max_samples_support"] // 2)
+                     print(f"Reducing max_samples_support from {cfg.hyperparams['max_samples_support'] * 2} "
+                           f"to {cfg.hyperparams['max_samples_support']} due to OOM error.")
+                     cfg.hyperparams["max_samples_query"] = int(cfg.hyperparams["max_samples_query"] // 2)
+                     print(f"Reducing max_samples_query from {cfg.hyperparams['max_samples_query'] * 2} "
+                           f"to {cfg.hyperparams['max_samples_query']} due to OOM error.")
+
+         if not success:
+             raise RuntimeError(
+                 "Failed to train Mitra model after multiple attempts due to an out-of-memory error."
+             )
+
+         return self
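On CUDA OOM the retry loop above shrinks the in-context budgets geometrically: the support set halves alone while it is at least 2048, after which support and query halve together. Replaying the arithmetic from the defaults in `_create_config` (8192 support / 1024 query):

    # Illustration only: the backoff schedule from _train_ensemble's except branch.
    support, query = 8192, 1024
    schedule = []
    for _ in range(6):  # six hypothetical consecutive OOMs
        if support >= 2048:
            support //= 2          # large support: shrink support only
        else:
            support //= 2          # otherwise shrink both budgets
            query //= 2
        schedule.append((support, query))
    # schedule == [(4096, 1024), (2048, 1024), (1024, 1024), (512, 512), (256, 256), (128, 128)]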
+
+
+ class MitraClassifier(MitraBase, ClassifierMixin):
+     """Classifier implementation of the Mitra model."""
+
+     def __init__(self,
+                  model_type=DEFAULT_MODEL_TYPE,
+                  n_estimators=DEFAULT_ENSEMBLE,
+                  device=DEFAULT_DEVICE,
+                  fine_tune=DEFAULT_FINE_TUNE,
+                  fine_tune_steps=DEFAULT_FINE_TUNE_STEPS,
+                  metric=DEFAULT_CLS_METRIC,
+                  state_dict=None,
+                  hf_model=DEFAULT_CLS_MODEL,
+                  patience=PATIENCE,
+                  lr=LR,
+                  warmup_steps=WARMUP_STEPS,
+                  shuffle_classes=SHUFFLE_CLASSES,
+                  shuffle_features=SHUFFLE_FEATURES,
+                  use_random_transforms=USE_RANDOM_TRANSFORMS,
+                  random_mirror_regression=RANDOM_MIRROR_REGRESSION,
+                  random_mirror_x=RANDOM_MIRROR_X,
+                  seed=SEED,
+                  verbose=True,
+                  ):
+         """Initialize the classifier."""
+         super().__init__(
+             model_type,
+             n_estimators,
+             device,
+             fine_tune,
+             fine_tune_steps,
+             metric,
+             state_dict,
+             hf_model=hf_model,
+             patience=patience,
+             lr=lr,
+             warmup_steps=warmup_steps,
+             shuffle_classes=shuffle_classes,
+             shuffle_features=shuffle_features,
+             use_random_transforms=use_random_transforms,
+             random_mirror_regression=random_mirror_regression,
+             random_mirror_x=random_mirror_x,
+             seed=seed,
+             verbose=verbose,
+         )
+         self.task = 'classification'
+
+     def fit(self, X, y, X_val=None, y_val=None, time_limit=None):
+         """
+         Fit the ensemble of models.
+
+         Parameters
+         ----------
+         X : array-like of shape (n_samples, n_features)
+             Training data
+         y : array-like of shape (n_samples,)
+             Target values
+
+         Returns
+         -------
+         self : object
+             Returns self
+         """
+         with mitra_deterministic_context():
+             if isinstance(X, pd.DataFrame):
+                 X = X.values
+             if isinstance(y, pd.Series):
+                 y = y.values
+
+             self.X, self.y = X, y
+
+             if X_val is not None and y_val is not None:
+                 if isinstance(X_val, pd.DataFrame):
+                     X_val = X_val.values
+                 if isinstance(y_val, pd.Series):
+                     y_val = y_val.values
+                 X_train, X_valid, y_train, y_valid = X, X_val, y, y_val
+             else:
+                 X_train, X_valid, y_train, y_valid = self._split_data(X, y)
+
+             return self._train_ensemble(X_train, y_train, X_valid, y_valid, self.task, DEFAULT_CLASSES, n_classes=DEFAULT_CLASSES, time_limit=time_limit)
+
+     def predict(self, X):
+         """
+         Predict class labels for samples in X.
+
+         Parameters
+         ----------
+         X : array-like of shape (n_samples, n_features)
+             The input samples
+
+         Returns
+         -------
+         y : ndarray of shape (n_samples,)
+             The predicted classes
+         """
+         if isinstance(X, pd.DataFrame):
+             X = X.values
+
+         return self.predict_proba(X).argmax(axis=1)
+
+     def predict_proba(self, X):
+         """
+         Predict class probabilities for samples in X.
+
+         Parameters
+         ----------
+         X : array-like of shape (n_samples, n_features)
+             The input samples
+
+         Returns
+         -------
+         p : ndarray of shape (n_samples, n_classes)
+             The class probabilities of the input samples
+         """
+         with mitra_deterministic_context():
+             if isinstance(X, pd.DataFrame):
+                 X = X.values
+
+             preds = []
+             for trainer in self.trainers:
+                 logits = trainer.predict(self.X, self.y, X)[..., :len(np.unique(self.y))]  # Remove extra classes
+                 preds.append(np.exp(logits) / np.exp(logits).sum(axis=1, keepdims=True))  # Softmax
+             preds = sum(preds) / len(preds)  # Averaging ensemble predictions
+
+         return preds
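Because the pretrained head is sized for `DEFAULT_CLASSES` (10) outputs, `predict_proba` slices the logits down to the classes observed in `self.y` before normalizing. The plain `np.exp(logits) / np.exp(logits).sum(...)` softmax can overflow for large logits; a numerically stable drop-in (the standard max-shift identity, not the shipped code) would be:

    import numpy as np

    def stable_softmax(logits: np.ndarray) -> np.ndarray:
        # Subtracting the row max leaves the softmax unchanged but prevents exp() overflow.
        shifted = logits - logits.max(axis=1, keepdims=True)
        exp = np.exp(shifted)
        return exp / exp.sum(axis=1, keepdims=True)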
+
+
+ class MitraRegressor(MitraBase, RegressorMixin):
+     """Regressor implementation of the Mitra model."""
+
+     def __init__(self,
+                  model_type=DEFAULT_MODEL_TYPE,
+                  n_estimators=DEFAULT_ENSEMBLE,
+                  device=DEFAULT_DEVICE,
+                  fine_tune=DEFAULT_FINE_TUNE,
+                  fine_tune_steps=DEFAULT_FINE_TUNE_STEPS,
+                  metric=DEFAULT_REG_METRIC,
+                  state_dict=None,
+                  hf_model=DEFAULT_REG_MODEL,
+                  patience=PATIENCE,
+                  lr=LR,
+                  warmup_steps=WARMUP_STEPS,
+                  shuffle_classes=SHUFFLE_CLASSES,
+                  shuffle_features=SHUFFLE_FEATURES,
+                  use_random_transforms=USE_RANDOM_TRANSFORMS,
+                  random_mirror_regression=RANDOM_MIRROR_REGRESSION,
+                  random_mirror_x=RANDOM_MIRROR_X,
+                  seed=SEED,
+                  verbose=True,
+                  ):
+         """Initialize the regressor."""
+         super().__init__(
+             model_type,
+             n_estimators,
+             device,
+             fine_tune,
+             fine_tune_steps,
+             metric,
+             state_dict,
+             hf_model=hf_model,
+             patience=patience,
+             lr=lr,
+             warmup_steps=warmup_steps,
+             shuffle_classes=shuffle_classes,
+             shuffle_features=shuffle_features,
+             use_random_transforms=use_random_transforms,
+             random_mirror_regression=random_mirror_regression,
+             random_mirror_x=random_mirror_x,
+             seed=seed,
+             verbose=verbose,
+         )
+         self.task = 'regression'
+
+     def fit(self, X, y, X_val=None, y_val=None, time_limit=None):
+         """
+         Fit the ensemble of models.
+
+         Parameters
+         ----------
+         X : array-like of shape (n_samples, n_features)
+             Training data
+         y : array-like of shape (n_samples,)
+             Target values
+
+         Returns
+         -------
+         self : object
+             Returns self
+         """
+         with mitra_deterministic_context():
+             if isinstance(X, pd.DataFrame):
+                 X = X.values
+             if isinstance(y, pd.Series):
+                 y = y.values
+
+             self.X, self.y = X, y
+
+             if X_val is not None and y_val is not None:
+                 if isinstance(X_val, pd.DataFrame):
+                     X_val = X_val.values
+                 if isinstance(y_val, pd.Series):
+                     y_val = y_val.values
+                 X_train, X_valid, y_train, y_valid = X, X_val, y, y_val
+             else:
+                 X_train, X_valid, y_train, y_valid = self._split_data(X, y)
+
+             return self._train_ensemble(X_train, y_train, X_valid, y_valid, self.task, 1, time_limit=time_limit)
+
+     def predict(self, X):
+         """
+         Predict regression target for samples in X.
+
+         Parameters
+         ----------
+         X : array-like of shape (n_samples, n_features)
+             The input samples
+
+         Returns
+         -------
+         y : ndarray of shape (n_samples,)
+             The predicted values
+         """
+         with mitra_deterministic_context():
+             if isinstance(X, pd.DataFrame):
+                 X = X.values
+
+             preds = []
+             for trainer in self.trainers:
+                 preds.append(trainer.predict(self.X, self.y, X))
+
+         return sum(preds) / len(preds)  # Averaging ensemble predictions
+
+
+ @contextlib.contextmanager
+ def mitra_deterministic_context():
+     """Context manager to set deterministic settings only for Mitra operations."""
+     yield
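Both estimators follow the scikit-learn fit/predict contract, wrapped in the (currently no-op) deterministic context defined at the end of the file. A hedged end-to-end sketch on synthetic data, assuming the `autogluon/mitra-classifier` weights can be fetched from Hugging Face and a supported device is available:

    import numpy as np
    from autogluon.tabular.models.mitra.sklearn_interface import MitraClassifier

    rng = np.random.default_rng(0)
    X = rng.normal(size=(200, 8))
    y = (X[:, 0] + X[:, 1] > 0).astype(int)

    clf = MitraClassifier(fine_tune=False, verbose=False)  # in-context only, no fine-tuning
    clf.fit(X[:160], y[:160])             # without X_val/y_val an internal split is made
    proba = clf.predict_proba(X[160:])    # trimmed to the 2 observed classes
    labels = clf.predict(X[160:])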