autogluon.tabular 1.4.1b20251212__py3-none-any.whl → 1.5.0b20251220__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry, and is provided for informational purposes only.

Files changed (43)
  1. autogluon/tabular/configs/hyperparameter_configs.py +4 -0
  2. autogluon/tabular/configs/presets_configs.py +39 -2
  3. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2025.py +2 -44
  4. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_cpu_2025_12_18.py +2 -0
  5. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_gpu_2025_12_18.py +2 -0
  6. autogluon/tabular/learner/default_learner.py +1 -0
  7. autogluon/tabular/models/__init__.py +3 -1
  8. autogluon/tabular/models/abstract/__init__.py +0 -0
  9. autogluon/tabular/models/abstract/abstract_torch_model.py +148 -0
  10. autogluon/tabular/models/catboost/catboost_model.py +1 -1
  11. autogluon/tabular/models/fastainn/tabular_nn_fastai.py +5 -1
  12. autogluon/tabular/models/lgb/lgb_model.py +58 -8
  13. autogluon/tabular/models/lgb/lgb_utils.py +2 -2
  14. autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py +14 -1
  15. autogluon/tabular/models/mitra/mitra_model.py +53 -22
  16. autogluon/tabular/models/realmlp/realmlp_model.py +8 -2
  17. autogluon/tabular/models/tabdpt/__init__.py +0 -0
  18. autogluon/tabular/models/tabdpt/tabdpt_model.py +253 -0
  19. autogluon/tabular/models/tabicl/tabicl_model.py +15 -2
  20. autogluon/tabular/models/tabm/tabm_model.py +23 -79
  21. autogluon/tabular/models/tabpfnv2/tabpfnv2_5_model.py +451 -0
  22. autogluon/tabular/models/tabpfnv2/tabpfnv2_model.py +86 -8
  23. autogluon/tabular/models/tabprep/__init__.py +0 -0
  24. autogluon/tabular/models/tabprep/prep_lgb_model.py +21 -0
  25. autogluon/tabular/models/tabprep/prep_mixin.py +220 -0
  26. autogluon/tabular/models/tabular_nn/torch/tabular_nn_torch.py +1 -1
  27. autogluon/tabular/models/tabular_nn/utils/data_preprocessor.py +12 -4
  28. autogluon/tabular/models/xgboost/xgboost_model.py +2 -0
  29. autogluon/tabular/predictor/predictor.py +47 -18
  30. autogluon/tabular/registry/_ag_model_registry.py +8 -2
  31. autogluon/tabular/testing/fit_helper.py +33 -0
  32. autogluon/tabular/trainer/abstract_trainer.py +45 -9
  33. autogluon/tabular/trainer/auto_trainer.py +5 -0
  34. autogluon/tabular/version.py +1 -1
  35. {autogluon_tabular-1.4.1b20251212.dist-info → autogluon_tabular-1.5.0b20251220.dist-info}/METADATA +38 -37
  36. {autogluon_tabular-1.4.1b20251212.dist-info → autogluon_tabular-1.5.0b20251220.dist-info}/RECORD +43 -33
  37. /autogluon.tabular-1.4.1b20251212-py3.11-nspkg.pth → /autogluon.tabular-1.5.0b20251220-py3.11-nspkg.pth +0 -0
  38. {autogluon_tabular-1.4.1b20251212.dist-info → autogluon_tabular-1.5.0b20251220.dist-info}/WHEEL +0 -0
  39. {autogluon_tabular-1.4.1b20251212.dist-info → autogluon_tabular-1.5.0b20251220.dist-info}/licenses/LICENSE +0 -0
  40. {autogluon_tabular-1.4.1b20251212.dist-info → autogluon_tabular-1.5.0b20251220.dist-info}/licenses/NOTICE +0 -0
  41. {autogluon_tabular-1.4.1b20251212.dist-info → autogluon_tabular-1.5.0b20251220.dist-info}/namespace_packages.txt +0 -0
  42. {autogluon_tabular-1.4.1b20251212.dist-info → autogluon_tabular-1.5.0b20251220.dist-info}/top_level.txt +0 -0
  43. {autogluon_tabular-1.4.1b20251212.dist-info → autogluon_tabular-1.5.0b20251220.dist-info}/zip-safe +0 -0
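
The dominant structural change in this release is the new autogluon/tabular/models/abstract/abstract_torch_model.py (+148 lines, item 9 above; its hunk is not part of this excerpt). In the hunks below, MitraModel, RealMLPModel, TabICLModel, TabMModel, and the new TabDPTModel all rebase from AbstractModel onto AbstractTorchModel. From the call sites visible in the diff (set_device, to_torch_device, suggest_device_infer, the device property, and the get_device/_set_device pair each subclass implements), the base class's device contract plausibly looks like the sketch below; everything beyond those observed names is an assumption.

    # Hypothetical reconstruction of the AbstractTorchModel device contract,
    # inferred only from call sites in this diff. The fallback logic in
    # to_torch_device and suggest_device_infer is assumed, not confirmed.
    import torch

    from autogluon.core.models import AbstractModel


    class AbstractTorchModel(AbstractModel):
        @property
        def device(self) -> str:
            # Where the underlying torch module currently lives.
            return self.get_device()

        def get_device(self) -> str:
            raise NotImplementedError  # each subclass reports its own device

        def _set_device(self, device: str) -> None:
            raise NotImplementedError  # each subclass moves its own modules

        def set_device(self, device: str) -> None:
            # Public entry point; subclasses only implement _set_device.
            if self.model is not None:
                self._set_device(device)

        def to_torch_device(self, device: str) -> torch.device:
            # Resolve a device string, falling back to CPU when the requested
            # backend is unavailable on this machine.
            if "cuda" in device and not torch.cuda.is_available():
                device = "cpu"
            elif "mps" in device and not torch.backends.mps.is_available():
                device = "cpu"
            return torch.device(device)

        def suggest_device_infer(self) -> str:
            # Pick a device for inference after load.
            return "cuda" if torch.cuda.is_available() else "cpu"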
autogluon/tabular/models/mitra/mitra_model.py

@@ -2,19 +2,21 @@ from __future__ import annotations
 
 import logging
 import os
+from pathlib import Path
 from typing import List, Optional
 
 import pandas as pd
+from typing_extensions import Self
 
 from autogluon.common.utils.resource_utils import ResourceManager
-from autogluon.core.models import AbstractModel
+from autogluon.tabular.models.abstract.abstract_torch_model import AbstractTorchModel
 from autogluon.features.generators import LabelEncoderFeatureGenerator
 from autogluon.tabular import __version__
 
 logger = logging.getLogger(__name__)
 
 
-class MitraModel(AbstractModel):
+class MitraModel(AbstractTorchModel):
     """
     Mitra is a tabular foundation model pre-trained purely on synthetic data with the goal
     of optimizing fine-tuning performance over in-context learning performance.
@@ -161,7 +163,7 @@ class MitraModel(AbstractModel):
         if X_val is not None:
             X_val = self.preprocess(X_val)
 
-        self.model = self.model.fit(
+        model = self.model.fit(
            X=X,
            y=y,
            X_val=X_val,
@@ -169,6 +171,11 @@
            time_limit=time_limit,
        )
 
+        for i in range(len(model.trainers)):
+            model.trainers[i].post_fit_optimize()
+
+        self.model = model
+
         if need_to_reset_torch_threads:
             torch.set_num_threads(torch_threads_og)
 
@@ -190,42 +197,63 @@
         )
         return default_auxiliary_params
 
-    @property
-    def weights_path(self) -> str:
-        return os.path.join(self.path, self.weights_file_name)
+    def weights_path(self, path: str | None = None) -> str:
+        if path is None:
+            path = self.path
+        return str(Path(path) / self.weights_file_name)
 
     def save(self, path: str = None, verbose=True) -> str:
         _model_weights_list = None
         if self.model is not None:
+            self._save_model_artifact(path=path)
             _model_weights_list = []
             for i in range(len(self.model.trainers)):
                 _model_weights_list.append(self.model.trainers[i].model)
-                self.model.trainers[i].checkpoint = None
                 self.model.trainers[i].model = None
-                self.model.trainers[i].optimizer = None
-                self.model.trainers[i].scheduler_warmup = None
-                self.model.trainers[i].scheduler_reduce_on_plateau = None
-            self._weights_saved = True
+
         path = super().save(path=path, verbose=verbose)
         if _model_weights_list is not None:
-            import torch
-
-            os.makedirs(self.path, exist_ok=True)
-            torch.save(_model_weights_list, self.weights_path)
             for i in range(len(self.model.trainers)):
                 self.model.trainers[i].model = _model_weights_list[i]
         return path
 
+    def _save_model_artifact(self, path: str | None):
+        if path is None:
+            path = self.path
+        import torch
+        device_og = self.device
+        self.set_device("cpu")
+
+        _model_weights_list = []
+        for i in range(len(self.model.trainers)):
+            _model_weights_list.append(self.model.trainers[i].model)
+
+        os.makedirs(path, exist_ok=True)
+        torch.save(_model_weights_list, self.weights_path(path=path))
+        self.set_device(device_og)
+        self._weights_saved = True
+
+    def _load_model_artifact(self):
+        import torch
+        device = self.suggest_device_infer()
+        model_weights_list = torch.load(self.weights_path(), weights_only=False)  # nosec B614
+        for i in range(len(self.model.trainers)):
+            self.model.trainers[i].model = model_weights_list[i]
+        self.set_device(device)
+
+    def _set_device(self, device: str):
+        for i in range(len(self.model.trainers)):
+            self.model.trainers[i].set_device(device)
+
+    def get_device(self) -> str:
+        return self.model.trainers[0].device
+
     @classmethod
-    def load(cls, path: str, reset_paths=False, verbose=True):
+    def load(cls, path: str, reset_paths=True, verbose=True) -> Self:
         model: MitraModel = super().load(path=path, reset_paths=reset_paths, verbose=verbose)
 
         if model._weights_saved:
-            import torch
-
-            model_weights_list = torch.load(model.weights_path, weights_only=False)  # nosec B614
-            for i in range(len(model.model.trainers)):
-                model.model.trainers[i].model = model_weights_list[i]
+            model._load_model_artifact()
             model._weights_saved = False
         return model
 
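The rewritten save path above is the standard externalized-weights pattern for torch-backed models: move the network to CPU, torch.save the trainer weights into a separate artifact, detach them so super().save() pickles only a lightweight shell, then reattach them in memory. A hedged sketch of the resulting round trip, where model is a fitted MitraModel:

    # Round-trip sketch; behavior inferred from the hunk above.
    path = model.save()             # _save_model_artifact() writes the weights
                                    # file on CPU, then the weight-stripped
                                    # shell is pickled by super().save()
    loaded = MitraModel.load(path)  # unpickles the shell; _load_model_artifact()
                                    # restores the trainer weights and places
                                    # them via set_device(suggest_device_infer())
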
autogluon/tabular/models/mitra/mitra_model.py (continued)

@@ -370,9 +398,12 @@
         return int(gpu_memory_mb * 1e6)
 
     @classmethod
-    def _class_tags(cls) -> dict:
+    def _class_tags(cls):
         return {
             "can_estimate_memory_usage_static": True,
+            "can_set_device": True,
+            "set_device_on_save_to": None,
+            "set_device_on_load": False,
         }
 
     def _more_tags(self) -> dict:
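
The three new _class_tags entries advertise this device API to the trainer; the consuming side lives in trainer/abstract_trainer.py (+45 -9, item 32 above), which this excerpt does not show. A plausible reading of the tags, with the consumption logic sketched as an assumption:

    # Assumed trainer-side handling of the new tags;
    # `model` is a fitted MitraModel instance.
    tags = MitraModel._class_tags()
    if tags.get("can_set_device"):
        save_device = tags.get("set_device_on_save_to")
        if save_device is not None:       # e.g. "cpu": move weights before pickling
            model.set_device(save_device)
        if tags.get("set_device_on_load"):
            model.set_device(model.suggest_device_infer())
    # Mitra opts out of both hooks (None / False) because _save_model_artifact()
    # and _load_model_artifact() already manage device placement themselves.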
autogluon/tabular/models/realmlp/realmlp_model.py

@@ -16,7 +16,7 @@ from sklearn.impute import SimpleImputer
 
 from autogluon.common.utils.pandas_utils import get_approximate_df_mem_usage
 from autogluon.common.utils.resource_utils import ResourceManager
-from autogluon.core.models import AbstractModel
+from autogluon.tabular.models.abstract.abstract_torch_model import AbstractTorchModel
 from autogluon.tabular import __version__
 
 logger = logging.getLogger(__name__)
@@ -34,7 +34,7 @@ def set_logger_level(logger_name: str, level: int):
 
 
 # pip install pytabkit
-class RealMLPModel(AbstractModel):
+class RealMLPModel(AbstractTorchModel):
     """
     RealMLP is an improved multilayer perception (MLP) model
     through a bag of tricks and better default hyperparameters.
@@ -83,6 +83,12 @@ class RealMLPModel(AbstractModel):
             model_cls = RealMLP_TD_S_Regressor
         return model_cls
 
+    def get_device(self) -> str:
+        return self.model.device
+
+    def _set_device(self, device: str):
+        self.model.to(device)
+
     def _fit(
         self,
         X: pd.DataFrame,
autogluon/tabular/models/tabdpt/__init__.py (new file, no content)

autogluon/tabular/models/tabdpt/tabdpt_model.py (new file)

@@ -0,0 +1,253 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from autogluon.common.utils.resource_utils import ResourceManager
+from autogluon.core.constants import BINARY, MULTICLASS, REGRESSION
+from autogluon.features.generators import LabelEncoderFeatureGenerator
+from autogluon.tabular.models.abstract.abstract_torch_model import AbstractTorchModel
+
+if TYPE_CHECKING:
+    import numpy as np
+    import pandas as pd
+
+
+# FIXME: Nick:
+# TODO: batch_size is linear to memory usage
+# 512 default
+# should be less for very large datasets
+# 128 batch_size on Bioresponse -> 12 GB VRAM
+# Train Data Rows: 2500
+# Train Data Columns: 1776
+# Problem Type: binary
+# FIXME: Just set context_size = infinity, everything is way faster, memory usage is way lower, etc.
+# Train Data Rows: 100000
+# Train Data Columns: 10
+# binary
+# only takes 6.7 GB during inference with batch_size = 512
+# FIXME: Make it work when loading on CPU?
+# FIXME: Can we run 8 in parallel to speed up?
+# TODO: clip_sigma == 1 is terrible, clip_sigma == 16 maybe very good? What about higher values?
+# clip_sigma >= 16 is roughly all equivalent
+# FIXME: TabDPT stores self.X_test for no reason
+# FIXME: TabDPT creates faiss_knn even if it is never used. Better if `context_size=None` means it is never created.
+# TODO: unit test
+# TODO: memory estimate
+class TabDPTModel(AbstractTorchModel):
+    ag_key = "TABDPT"
+    ag_name = "TabDPT"
+    seed_name = "seed"
+    ag_priority = 50
+    default_random_seed = 0
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self._feature_generator = None
+        self._predict_hps = None
+        self._use_flash_og = None
+
+    def _fit(
+        self,
+        X: pd.DataFrame,
+        y: pd.Series,
+        num_cpus: int = 1,
+        num_gpus: int = 0,
+        **kwargs,
+    ):
+        from torch.cuda import is_available
+
+        device = "cuda" if num_gpus != 0 else "cpu"
+        if (device == "cuda") and (not is_available()):
+            raise AssertionError(
+                "Fit specified to use GPU, but CUDA is not available on this machine. "
+                "Please switch to CPU usage instead.",
+            )
+        from tabdpt import TabDPTClassifier, TabDPTRegressor
+
+        model_cls = (
+            TabDPTClassifier
+            if self.problem_type in [BINARY, MULTICLASS]
+            else TabDPTRegressor
+        )
+        fit_params, self._predict_hps = self._get_tabdpt_params(num_gpus=num_gpus)
+
+        X = self.preprocess(X)
+        y = y.to_numpy()
+        self.model = model_cls(
+            device=device,
+            **fit_params,
+        )
+        self.model.fit(X=X, y=y)
+
+    def _get_tabdpt_params(self, num_gpus: float) -> tuple[dict, dict]:
+        model_params = self._get_model_params()
+
+        valid_predict_params = (self.seed_name, "context_size", "permute_classes", "temperature", "n_ensembles")
+
+        predict_params = {}
+        for hp in valid_predict_params:
+            if hp in model_params:
+                predict_params[hp] = model_params.pop(hp)
+        predict_params.setdefault(self.seed_name, self.default_random_seed)
+        predict_params.setdefault("context_size", None)
+
+        supported_predict_params = (
+            (self.seed_name, "context_size", "n_ensembles", "permute_classes", "temperature")
+            if self.problem_type in [BINARY, MULTICLASS]
+            else (self.seed_name, "context_size", "n_ensembles")
+        )
+        predict_params = {key: val for key, val in predict_params.items() if key in supported_predict_params}
+
+        fit_params = model_params
+
+        fit_params.setdefault("verbose", False)
+        fit_params.setdefault("compile", False)
+        if fit_params.get("use_flash", True):
+            fit_params["use_flash"] = self._use_flash(num_gpus=num_gpus)
+        return fit_params, predict_params
+
+    @staticmethod
+    def _use_flash(num_gpus: float) -> bool:
+        """Detect if torch's native flash attention is available on the current machine."""
+        if num_gpus == 0:
+            return False
+
+        import torch
+
+        if not torch.cuda.is_available():
+            return False
+
+        device = torch.device("cuda:0")
+        capability = torch.cuda.get_device_capability(device)
+
+        return capability != (7, 5)
+
+    def _post_fit(self, **kwargs):
+        super()._post_fit(**kwargs)
+        self._use_flash_og = self.model.use_flash
+        return self
+
+    def get_device(self) -> str:
+        return self.model.device
+
+    def _set_device(self, device: str):
+        self.model.to(device)
+        if device == "cpu":
+            self.model.use_flash = False
+            self.model.model.use_flash = False
+        else:
+            self.model.use_flash = self._use_flash_og
+            self.model.model.use_flash = self._use_flash_og
+
+    def _get_default_resources(self) -> tuple[int, int]:
+        # Use only physical cores for better performance based on benchmarks
+        num_cpus = ResourceManager.get_cpu_count(only_physical_cores=True)
+
+        num_gpus = min(1, ResourceManager.get_gpu_count_torch(cuda_only=True))
+
+        return num_cpus, num_gpus
+
+    def get_minimum_resources(
+        self, is_gpu_available: bool = False
+    ) -> dict[str, int | float]:
+        return {
+            "num_cpus": 1,
+            "num_gpus": 0.5 if is_gpu_available else 0,
+        }
+
+    def _predict_proba(self, X, **kwargs) -> np.ndarray:
+        X = self.preprocess(X, **kwargs)
+
+        if self.problem_type in [REGRESSION]:
+            y_pred = self.model.predict(X, **self._predict_hps)
+            return y_pred
+
+        y_pred_proba = self.model.ensemble_predict_proba(X, **self._predict_hps)
+        return self._convert_proba_to_unified_form(y_pred_proba)
+
+    def _preprocess(self, X: pd.DataFrame, **kwargs) -> pd.DataFrame:
+        """TabDPT requires numpy array as input."""
+        X = super()._preprocess(X, **kwargs)
+        if self._feature_generator is None:
+            self._feature_generator = LabelEncoderFeatureGenerator(verbosity=0)
+            self._feature_generator.fit(X=X)
+        if self._feature_generator.features_in:
+            X = X.copy()
+            X[self._feature_generator.features_in] = self._feature_generator.transform(
+                X=X
+            )
+        return X.to_numpy()
+
+    @classmethod
+    def supported_problem_types(cls) -> list[str] | None:
+        return ["binary", "multiclass", "regression"]
+
+    @classmethod
+    def _class_tags(cls):
+        return {"can_estimate_memory_usage_static": True}
+
+    def _more_tags(self) -> dict:
+        return {"can_refit_full": True}
+
+    def _get_default_auxiliary_params(self) -> dict:
+        default_auxiliary_params = super()._get_default_auxiliary_params()
+        default_auxiliary_params.update(
+            {
+                "max_rows": 100000,  # TODO: Test >100k rows
+                "max_features": 2500,  # TODO: Test >2500 features
+                "max_classes": 10,  # TODO: Test >10 classes
+            }
+        )
+        return default_auxiliary_params
+
+    @classmethod
+    def _get_default_ag_args_ensemble(cls, **kwargs) -> dict:
+        default_ag_args_ensemble = super()._get_default_ag_args_ensemble(**kwargs)
+        extra_ag_args_ensemble = {
+            "refit_folds": True,
+        }
+        default_ag_args_ensemble.update(extra_ag_args_ensemble)
+        return default_ag_args_ensemble
+
+    # FIXME: This is copied from TabPFN, but TabDPT is not the same
+    @classmethod
+    def _estimate_memory_usage_static(
+        cls,
+        *,
+        X: pd.DataFrame,
+        hyperparameters: dict | None = None,
+        **kwargs,
+    ) -> int:
+        """Heuristic memory estimate based on TabPFN's memory estimate logic in:
+        https://github.com/PriorLabs/TabPFN/blob/57a2efd3ebdb3886245e4d097cefa73a5261a969/src/tabpfn/model/memory.py#L147.
+
+        This is based on GPU memory usage, but hopefully with overheads it also approximates CPU memory usage.
+        """
+        # TODO: update, this is not correct anymore, consider using internal TabPFN functions directly.
+        features_per_group = 3  # Based on TabPFNv2 default (unused)
+        n_layers = 12  # Based on TabPFNv2 default
+        embedding_size = 192  # Based on TabPFNv2 default
+        dtype_byte_size = 2  # Based on TabPFNv2 default
+
+        model_mem = 14489108  # Based on TabPFNv2 default
+
+        n_samples, n_features = X.shape[0], min(X.shape[1], 500)
+        n_feature_groups = (
+            n_features
+        ) / features_per_group + 1  # TODO: Unsure how to calculate this
+
+        X_mem = n_samples * n_feature_groups * dtype_byte_size
+        activation_mem = (
+            n_samples * n_feature_groups * embedding_size * n_layers * dtype_byte_size
+        )
+
+        baseline_overhead_mem_est = 1e9  # 1 GB generic overhead
+
+        # Add some buffer to each term + 1 GB overhead to be safe
+        memory_estimate = model_mem + 4 * X_mem + 2 * activation_mem + baseline_overhead_mem_est
+
+        # TabDPT memory estimation is very inaccurate because it is using TabPFN memory estimate. Double it to be safe.
+        memory_estimate = memory_estimate * 2
+
+        # Note: This memory estimate is way off if `context_size` is not None
+        return int(memory_estimate)
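
Two details of the new model are worth working through. First, _use_flash returns False exactly for compute capability (7, 5), i.e. Turing-class GPUs such as the T4; torch's flash-attention SDP kernel targets newer architectures, which is presumably the motivation for the special case. Second, the memory heuristic, evaluated for the 100,000-row, 10-feature binary case cited in the FIXME comments at the top of the file:

    # Worked instance of _estimate_memory_usage_static.
    n_samples, n_features = 100_000, 10
    n_feature_groups = n_features / 3 + 1                         # ~4.33
    X_mem = n_samples * n_feature_groups * 2                      # ~0.87 MB
    activation_mem = n_samples * n_feature_groups * 192 * 12 * 2  # ~2.0 GB
    estimate = 2 * (14_489_108 + 4 * X_mem + 2 * activation_mem + 1e9)
    print(f"{estimate / 1e9:.1f} GB")                             # ~10.0 GB

Against the ~6.7 GB the FIXME reports actually being used during inference in that scenario, the doubled estimate leaves roughly 1.5x headroom, consistent with the "to be safe" comments in the code.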
autogluon/tabular/models/tabicl/tabicl_model.py

@@ -10,14 +10,14 @@ import pandas as pd
 
 from autogluon.common.utils.pandas_utils import get_approximate_df_mem_usage
 from autogluon.common.utils.resource_utils import ResourceManager
-from autogluon.core.models import AbstractModel
 from autogluon.tabular import __version__
+from autogluon.tabular.models.abstract.abstract_torch_model import AbstractTorchModel
 
 logger = logging.getLogger(__name__)
 
 
 # TODO: Verify if crashes when weights are not yet downloaded and fit in parallel
-class TabICLModel(AbstractModel):
+class TabICLModel(AbstractTorchModel):
     """
     TabICL is a foundation model for tabular data using in-context learning
     that is scalable to larger datasets than TabPFNv2. It is pretrained purely on synthetic data.
@@ -97,6 +97,19 @@ class TabICLModel(AbstractModel):
             y=y,
         )
 
+    def get_device(self) -> str:
+        return self.model.device_.type
+
+    # TODO: Better to have an official TabICL method for this
+    def _set_device(self, device: str):
+        device = self.to_torch_device(device)
+        self.model.device_ = device
+        self.model.device = self.model.device_.type
+        self.model.model_ = self.model.model_.to(self.model.device_)
+        self.model.inference_config_.COL_CONFIG.device = self.model.device_
+        self.model.inference_config_.ROW_CONFIG.device = self.model.device_
+        self.model.inference_config_.ICL_CONFIG.device = self.model.device_
+
     def _get_default_auxiliary_params(self) -> dict:
         default_auxiliary_params = super()._get_default_auxiliary_params()
         default_auxiliary_params.update(
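
The fitted tabicl estimator caches its device in several places (device_, the model_ module, and the COL/ROW/ICL inference configs), so _set_device must update all of them together; a partial move would mix CPU and CUDA tensors at predict time. Callers then only touch the shared entry point (sketch, assuming a fitted TabICLModel named m and the base-class semantics inferred earlier):

    m.set_device("cpu")             # routes through _set_device above
    assert m.get_device() == "cpu"
    m.set_device("cuda")            # moves everything back; assumed to fall
                                    # back to CPU when CUDA is unavailable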
autogluon/tabular/models/tabm/tabm_model.py

@@ -15,13 +15,13 @@ import time
 import pandas as pd
 
 from autogluon.common.utils.resource_utils import ResourceManager
-from autogluon.core.models import AbstractModel
 from autogluon.tabular import __version__
+from autogluon.tabular.models.abstract.abstract_torch_model import AbstractTorchModel
 
 logger = logging.getLogger(__name__)
 
 
-class TabMModel(AbstractModel):
+class TabMModel(AbstractTorchModel):
     """
     TabM is an efficient ensemble of MLPs that is trained simultaneously with mostly shared parameters.
 
@@ -49,7 +49,6 @@
         self._indicator_columns = None
         self._features_bool = None
         self._bool_to_cat = None
-        self.device = None
 
     def _fit(
         self,
@@ -88,7 +87,7 @@
         if X_val is None:
             from autogluon.core.utils import generate_train_test_split
 
-            X_train, X_val, y_train, y_val = generate_train_test_split(
+            X, X_val, y, y_val = generate_train_test_split(
                 X=X,
                 y=y,
                 problem_type=self.problem_type,
@@ -99,7 +98,7 @@
         hyp = self._get_model_params()
         bool_to_cat = hyp.pop("bool_to_cat", True)
 
-        X = self.preprocess(X, is_train=True, bool_to_cat=bool_to_cat)
+        X = self.preprocess(X, y=y, is_train=True, bool_to_cat=bool_to_cat)
         if X_val is not None:
             X_val = self.preprocess(X_val)
 
@@ -143,80 +142,13 @@
 
         return X
 
-    def save(self, path: str = None, verbose=True) -> str:
-        """
-        Need to set device to CPU to be able to load on a non-GPU environment
-        """
-        import torch
-
-        # Save on CPU to ensure the model can be loaded without GPU
-        if self.model is not None:
-            self.device = self.model.device_
-            device_cpu = torch.device("cpu")
-            self.model.model_ = self.model.model_.to(device_cpu)
-            self.model.device_ = device_cpu
-        path = super().save(path=path, verbose=verbose)
-        # Put the model back to the device after the save
-        if self.model is not None:
-            self.model.model_.to(self.device)
-            self.model.device_ = self.device
-
-        return path
-
-    @classmethod
-    def load(cls, path: str, reset_paths=True, verbose=True):
-        """
-        Loads the model from disk to memory.
-        The loaded model will be on the same device it was trained on (cuda/mps);
-        if the device is not available (trained on GPU, deployed on CPU), then `cpu` will be used.
-
-        Parameters
-        ----------
-        path : str
-            Path to the saved model, minus the file name.
-            This should generally be a directory path ending with a '/' character (or appropriate path separator value depending on OS).
-            The model file is typically located in os.path.join(path, cls.model_file_name).
-        reset_paths : bool, default True
-            Whether to reset the self.path value of the loaded model to be equal to path.
-            It is highly recommended to keep this value as True unless accessing the original self.path value is important.
-            If False, the actual valid path and self.path may differ, leading to strange behaviour and potential exceptions if the model needs to load any other files at a later time.
-        verbose : bool, default True
-            Whether to log the location of the loaded file.
-
-        Returns
-        -------
-        model : cls
-            Loaded model object.
-        """
-        import torch
-
-        model: TabMModel = super().load(path=path, reset_paths=reset_paths, verbose=verbose)
-
-        # Put the model on the same device it was trained on (GPU/MPS) if it is available; otherwise use CPU
-        if model.model is not None:
-            original_device_type = model.device.type
-            if "cuda" in original_device_type:
-                # cuda: nvidia GPU
-                device = torch.device(original_device_type if torch.cuda.is_available() else "cpu")
-            elif "mps" in original_device_type:
-                # mps: Apple Silicon
-                device = torch.device(original_device_type if torch.backends.mps.is_available() else "cpu")
-            else:
-                device = torch.device(original_device_type)
-
-            if verbose and (original_device_type != device.type):
-                logger.log(15, f"Model is trained on {original_device_type}, but the device is not available - loading on {device.type}")
-
-            model.set_device(device=device)
+    def get_device(self) -> str:
+        return self.model.device_.type
 
-        return model
-
-    def set_device(self, device):
-        self.device = device
-        if self.model is not None:
-            self.model.device_ = device
-            if self.model.model_ is not None:
-                self.model.model_ = self.model.model_.to(device)
+    def _set_device(self, device: str):
+        device = self.to_torch_device(device)
+        self.model.device_ = device
+        self.model.model_ = self.model.model_.to(device)
 
     @classmethod
     def supported_problem_types(cls) -> list[str] | None:
@@ -331,6 +263,15 @@
 
         return mem_total
 
+    def _get_default_auxiliary_params(self) -> dict:
+        default_auxiliary_params = super()._get_default_auxiliary_params()
+        default_auxiliary_params.update(
+            {
+                "max_batch_size": 16384,  # avoid excessive VRAM usage
+            }
+        )
+        return default_auxiliary_params
+
     @classmethod
     def get_tabm_auto_batch_size(cls, n_samples: int) -> int:
         # by Yury Gorishniy, inferred from the choices in the TabM paper.
@@ -348,7 +289,10 @@
 
     @classmethod
     def _class_tags(cls):
-        return {"can_estimate_memory_usage_static": True}
+        return {
+            "can_estimate_memory_usage_static": True,
+            "reset_torch_threads": True,
+        }
 
     def _more_tags(self) -> dict:
         # TODO: Need to add train params support, track best epoch
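
With the bespoke save/load pair deleted, TabM keeps only the two device hooks and inherits persistence from the base class; the cuda/mps fallback it used to hand-roll is now centralized. The intended effect, sketched under the contract assumed earlier (TabM does not set any set_device_on_* tags in this diff, so the load-time device policy shown is an assumption; path and X_test are placeholders):

    model = TabMModel.load(path)                    # plain base-class load
    model.set_device(model.suggest_device_infer())  # "cpu" when no accelerator
    y_pred = model.predict(X_test)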