autogluon.tabular 1.4.1b20251212__py3-none-any.whl → 1.5.0b20251220__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of autogluon.tabular might be problematic; see the package registry page for more details.

Files changed (43)
  1. autogluon/tabular/configs/hyperparameter_configs.py +4 -0
  2. autogluon/tabular/configs/presets_configs.py +39 -2
  3. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2025.py +2 -44
  4. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_cpu_2025_12_18.py +2 -0
  5. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_gpu_2025_12_18.py +2 -0
  6. autogluon/tabular/learner/default_learner.py +1 -0
  7. autogluon/tabular/models/__init__.py +3 -1
  8. autogluon/tabular/models/abstract/__init__.py +0 -0
  9. autogluon/tabular/models/abstract/abstract_torch_model.py +148 -0
  10. autogluon/tabular/models/catboost/catboost_model.py +1 -1
  11. autogluon/tabular/models/fastainn/tabular_nn_fastai.py +5 -1
  12. autogluon/tabular/models/lgb/lgb_model.py +58 -8
  13. autogluon/tabular/models/lgb/lgb_utils.py +2 -2
  14. autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py +14 -1
  15. autogluon/tabular/models/mitra/mitra_model.py +53 -22
  16. autogluon/tabular/models/realmlp/realmlp_model.py +8 -2
  17. autogluon/tabular/models/tabdpt/__init__.py +0 -0
  18. autogluon/tabular/models/tabdpt/tabdpt_model.py +253 -0
  19. autogluon/tabular/models/tabicl/tabicl_model.py +15 -2
  20. autogluon/tabular/models/tabm/tabm_model.py +23 -79
  21. autogluon/tabular/models/tabpfnv2/tabpfnv2_5_model.py +451 -0
  22. autogluon/tabular/models/tabpfnv2/tabpfnv2_model.py +86 -8
  23. autogluon/tabular/models/tabprep/__init__.py +0 -0
  24. autogluon/tabular/models/tabprep/prep_lgb_model.py +21 -0
  25. autogluon/tabular/models/tabprep/prep_mixin.py +220 -0
  26. autogluon/tabular/models/tabular_nn/torch/tabular_nn_torch.py +1 -1
  27. autogluon/tabular/models/tabular_nn/utils/data_preprocessor.py +12 -4
  28. autogluon/tabular/models/xgboost/xgboost_model.py +2 -0
  29. autogluon/tabular/predictor/predictor.py +47 -18
  30. autogluon/tabular/registry/_ag_model_registry.py +8 -2
  31. autogluon/tabular/testing/fit_helper.py +33 -0
  32. autogluon/tabular/trainer/abstract_trainer.py +45 -9
  33. autogluon/tabular/trainer/auto_trainer.py +5 -0
  34. autogluon/tabular/version.py +1 -1
  35. {autogluon_tabular-1.4.1b20251212.dist-info → autogluon_tabular-1.5.0b20251220.dist-info}/METADATA +38 -37
  36. {autogluon_tabular-1.4.1b20251212.dist-info → autogluon_tabular-1.5.0b20251220.dist-info}/RECORD +43 -33
  37. /autogluon.tabular-1.4.1b20251212-py3.11-nspkg.pth → /autogluon.tabular-1.5.0b20251220-py3.11-nspkg.pth +0 -0
  38. {autogluon_tabular-1.4.1b20251212.dist-info → autogluon_tabular-1.5.0b20251220.dist-info}/WHEEL +0 -0
  39. {autogluon_tabular-1.4.1b20251212.dist-info → autogluon_tabular-1.5.0b20251220.dist-info}/licenses/LICENSE +0 -0
  40. {autogluon_tabular-1.4.1b20251212.dist-info → autogluon_tabular-1.5.0b20251220.dist-info}/licenses/NOTICE +0 -0
  41. {autogluon_tabular-1.4.1b20251212.dist-info → autogluon_tabular-1.5.0b20251220.dist-info}/namespace_packages.txt +0 -0
  42. {autogluon_tabular-1.4.1b20251212.dist-info → autogluon_tabular-1.5.0b20251220.dist-info}/top_level.txt +0 -0
  43. {autogluon_tabular-1.4.1b20251212.dist-info → autogluon_tabular-1.5.0b20251220.dist-info}/zip-safe +0 -0
@@ -0,0 +1,451 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ import os
5
+ from pathlib import Path
6
+ from typing import TYPE_CHECKING
7
+
8
+ from autogluon.common.utils.resource_utils import ResourceManager
9
+ from autogluon.features.generators import LabelEncoderFeatureGenerator
10
+ from autogluon.tabular.models.abstract.abstract_torch_model import AbstractTorchModel
11
+
12
+ if TYPE_CHECKING:
13
+ import numpy as np
14
+ import pandas as pd
15
+
16
logger = logging.getLogger(name=__name__)

# One-shot flags: each is flipped to True after its notice is logged, so every
# TabPFN notice is emitted at most once per Python process.
_HAS_LOGGED_TABPFN_LICENSE: bool = False
# NOTE: "NONCOMMERICAL" spelling is kept as-is because the name is referenced by the model classes.
_HAS_LOGGED_TABPFN_NONCOMMERICAL: bool = False
_HAS_LOGGED_TABPFN_CPU_WARNING: bool = False
21
+
22
+
23
class TabPFNModel(AbstractTorchModel):
    """TabPFN-2.5 is a tabular foundation model that is developed and maintained by PriorLabs: https://priorlabs.ai/.

    This class is an abstract template for various TabPFN versions as subclasses.
    Subclasses set ``ag_key``, ``ag_name``, ``default_classification_model`` and
    ``default_regression_model``, and override ``extra_checkpoints_for_tuning`` and
    ``_log_license``.

    Paper: Accurate predictions on small data with a tabular foundation model
    Authors: Noah Hollmann, Samuel Müller, Lennart Purucker, Arjun Krishnakumar, Max Körfer, Shi Bin Hoo, Robin Tibor Schirrmeister & Frank Hutter
    Codebase: https://github.com/PriorLabs/TabPFN
    License: https://github.com/PriorLabs/TabPFN/blob/main/LICENSE

    .. versionadded:: 1.5.0
    """

    ag_key = "NOTSET"
    ag_name = "NOTSET"
    ag_priority = 40
    seed_name = "random_state"

    # Directory containing the checkpoint files. When None, TabPFN's own model directory
    # (from ``resolve_model_path``) is used.
    custom_model_dir: str | None = None
    # Checkpoint file names, resolved relative to the model directory; overridden by subclasses.
    default_classification_model: str | None = "NOTSET"
    default_regression_model: str | None = "NOTSET"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        self._feature_generator = None
        self._cat_features = None
        self._cat_indices = None

    def _preprocess(self, X: pd.DataFrame, is_train=False, **kwargs) -> pd.DataFrame:
        """Label-encode categorical columns and, when training, record their positional indices.

        Args:
            X: Input features.
            is_train: If True, fit a new ``LabelEncoderFeatureGenerator`` on ``X`` and compute
                ``self._cat_indices`` (the column positions passed to TabPFN as categorical).

        Returns:
            ``X`` with the label-encoded columns replaced (a copy when any column is encoded).
        """
        X = super()._preprocess(X, **kwargs)

        if is_train:
            self._cat_indices = []

            # X will be the training data.
            self._feature_generator = LabelEncoderFeatureGenerator(verbosity=0)
            self._feature_generator.fit(X=X)

        # This converts categorical features to numeric via stateful label encoding.
        if self._feature_generator.features_in:
            X = X.copy()
            X[self._feature_generator.features_in] = self._feature_generator.transform(X=X)

        if is_train:
            # Detect/set cat features and indices
            if self._cat_features is None:
                self._cat_features = self._feature_generator.features_in[:]
            self._cat_indices = [X.columns.get_loc(col) for col in self._cat_features]

        return X

    def _fit(
        self,
        X: pd.DataFrame,
        y: pd.Series,
        num_cpus: int = 1,
        num_gpus: int = 0,
        time_limit: float | None = None,
        verbosity: int = 2,
        **kwargs,
    ):
        """Build the TabPFN estimator from the model's hyperparameters and fit it on ``X``/``y``.

        Raises:
            AssertionError: If a GPU is requested but CUDA is not available.
        """
        if not self.params_aux.get("model_telemetry", False):
            self.disable_tabpfn_telemetry()

        from tabpfn import TabPFNClassifier, TabPFNRegressor
        from tabpfn.model.loading import resolve_model_path
        from torch.cuda import is_available

        is_classification = self.problem_type in ["binary", "multiclass"]

        model_base = TabPFNClassifier if is_classification else TabPFNRegressor

        device = "cuda" if num_gpus != 0 else "cpu"
        if (device == "cuda") and (not is_available()):
            raise AssertionError(
                "Fit specified to use GPU, but CUDA is not available on this machine. "
                "Please switch to CPU usage instead.",
            )

        if verbosity >= 2:
            # logs "Built with PriorLabs-TabPFN"
            self._log_license(device=device)
            self._log_cpu_warning(device=device)

        X = self.preprocess(X, is_train=True)

        hps = self._get_model_params()
        hps["device"] = device
        hps["n_jobs"] = num_cpus  # FIXME: remove this, it doesn't do anything, use n_preprocessing_jobs??
        hps["categorical_features_indices"] = self._cat_indices

        # Resolve preprocessing
        self._resolve_preprocessing_hps(hps)

        # Remove task specific HPs
        if is_classification:
            hps.pop("inference_config/REGRESSION_Y_PREPROCESS_TRANSFORMS", None)
        else:
            hps.pop("balance_probabilities", None)

        # Resolve model_path
        if self.custom_model_dir is not None:
            model_dir = Path(self.custom_model_dir)
        else:
            _, model_dir, _, _ = resolve_model_path(
                model_path=None,
                which="classifier" if is_classification else "regressor",
            )
            model_dir = model_dir[0]
        clf_path, reg_path = hps.pop(
            "zip_model_path",
            [self.default_classification_model, self.default_regression_model],
        )
        model_path = clf_path if is_classification else reg_path
        if model_path is not None:
            hps["model_path"] = model_dir / model_path

        # Resolve inference_config
        self._resolve_inference_config(hps)

        # Model and fit
        self.model = model_base(**hps)
        self.model = self.model.fit(
            X=X,
            y=y,
        )

    @staticmethod
    def _resolve_preprocessing_hps(hps: dict) -> None:
        """Fold the flat ``preprocessing/*`` hyperparameters into ``inference_config/PREPROCESS_TRANSFORMS``.

        One transform config is created per entry of ``preprocessing/scaling``. The
        ``preprocessing/global``, ``preprocessing/categoricals`` and
        ``preprocessing/append_original`` settings are shared by every transform. All
        ``preprocessing/*`` keys are removed from ``hps``, which is modified in place.
        """
        if "preprocessing/scaling" in hps:
            # Pop the shared settings once, *before* building the per-scaler configs. Popping
            # inside the comprehension would give only the first scaler the configured values
            # and silently fall back to the defaults for every later scaler.
            global_transformer_name = hps.pop("preprocessing/global", None)
            categorical_name = hps.pop("preprocessing/categoricals", "numeric")
            append_original = hps.pop("preprocessing/append_original", True)
            hps["inference_config/PREPROCESS_TRANSFORMS"] = [
                {
                    "name": scaler,
                    "global_transformer_name": global_transformer_name,
                    "categorical_name": categorical_name,
                    "append_original": append_original,
                }
                for scaler in hps["preprocessing/scaling"]
            ]
        for k in [
            "preprocessing/scaling",
            "preprocessing/categoricals",
            "preprocessing/append_original",
            "preprocessing/global",
        ]:
            hps.pop(k, None)

    @staticmethod
    def _resolve_inference_config(hps: dict) -> None:
        """Move every ``inference_config/<NAME>`` entry of ``hps`` into a nested ``inference_config`` dict.

        The nested key is the last ``/``-separated component of the flat key. Entries whose
        nested key would be empty are dropped. ``hps["inference_config"]`` is only set when at
        least one entry remains. All ``inference_config/*`` keys are removed from ``hps``,
        which is modified in place.
        """
        flat_keys = [k for k in hps if k.startswith("inference_config/")]
        inference_config = {}
        for k in flat_keys:
            value = hps.pop(k)
            sub_key = k.split("/")[-1]
            if sub_key:
                inference_config[sub_key] = value
        if inference_config:
            hps["inference_config"] = inference_config

    def _predict_proba(self, X, **kwargs) -> np.ndarray:
        if not self.params_aux.get("model_telemetry", False):
            self.disable_tabpfn_telemetry()
        # Forward the caller's options individually. The previous ``kwargs=kwargs`` handed the
        # parent a single keyword literally named "kwargs", so every option was dropped.
        return super()._predict_proba(X=X, **kwargs)

    def _get_default_resources(self) -> tuple[int, int]:
        # Use only physical cores for better performance based on benchmarks
        num_cpus = ResourceManager.get_cpu_count(only_physical_cores=True)

        num_gpus = min(1, ResourceManager.get_gpu_count_torch(cuda_only=True))

        return num_cpus, num_gpus

    def get_minimum_resources(
        self, is_gpu_available: bool = False
    ) -> dict[str, int | float]:
        return {
            "num_cpus": 1,
            "num_gpus": 1 if is_gpu_available else 0,
        }

    def _set_default_params(self):
        default_params = {
            "ignore_pretraining_limits": True,  # to ignore warnings and size limits
        }
        for param, val in default_params.items():
            self._set_default_param_value(param, val)

    def get_device(self) -> str:
        return self.model.devices_[0].type

    def _set_device(self, device: str):
        self.model.to(device)

    @classmethod
    def supported_problem_types(cls) -> list[str] | None:
        return ["binary", "multiclass", "regression"]

    def _get_default_auxiliary_params(self) -> dict:
        default_auxiliary_params = super()._get_default_auxiliary_params()
        default_auxiliary_params.update(
            {
                "max_rows": 100_000,
                "max_features": 2000,
                "max_classes": 10,
                "model_telemetry": False,
            }
        )
        return default_auxiliary_params

    @classmethod
    def _get_default_ag_args_ensemble(cls, **kwargs) -> dict:
        """Set fold_fitting_strategy to sequential_local,
        as parallel folding crashes if model weights aren't pre-downloaded.
        """
        default_ag_args_ensemble = super()._get_default_ag_args_ensemble(**kwargs)
        extra_ag_args_ensemble = {
            # FIXME: Find a work-around to avoid crash if parallel and weights are not downloaded
            "fold_fitting_strategy": "sequential_local",
            "refit_folds": True,  # Better to refit the model for faster inference and similar quality as the bag.
        }
        default_ag_args_ensemble.update(extra_ag_args_ensemble)
        return default_ag_args_ensemble

    def _estimate_memory_usage(self, X: pd.DataFrame, **kwargs) -> int:
        hyperparameters = self._get_model_params()
        return self.estimate_memory_usage_static(
            X=X,
            problem_type=self.problem_type,
            num_classes=self.num_classes,
            hyperparameters=hyperparameters,
            **kwargs,
        )

    @classmethod
    def disable_tabpfn_telemetry(cls):
        os.environ["TABPFN_DISABLE_TELEMETRY"] = "1"

    @classmethod
    def _estimate_memory_usage_static(
        cls,
        *,
        X: pd.DataFrame,
        hyperparameters: dict | None = None,
        **kwargs,
    ) -> int:
        """Heuristic memory estimate based on TabPFN's memory estimate logic in:
        https://github.com/PriorLabs/TabPFN/blob/57a2efd3ebdb3886245e4d097cefa73a5261a969/src/tabpfn/model/memory.py#L147.

        This is based on GPU memory usage, but hopefully with overheads it also approximates CPU memory usage.
        The feature count is capped at 2000 before estimating.
        """
        # TODO: update, this is not correct anymore, consider using internal TabPFN functions directly.
        features_per_group = 3  # Based on TabPFNv2 default; used to approximate the number of feature groups
        n_layers = 12  # Based on TabPFNv2 default
        embedding_size = 192  # Based on TabPFNv2 default
        dtype_byte_size = 2  # Based on TabPFNv2 default

        model_mem = 14489108  # Based on TabPFNv2 default

        n_samples, n_features = X.shape[0], min(X.shape[1], 2000)
        n_feature_groups = n_features / features_per_group + 1  # TODO: Unsure how to calculate this

        X_mem = n_samples * n_feature_groups * dtype_byte_size
        activation_mem = (
            n_samples * n_feature_groups * embedding_size * n_layers * dtype_byte_size
        )

        baseline_overhead_mem_est = 1e9  # 1 GB generic overhead

        # Add some buffer to each term + 1 GB overhead to be safe
        return int(
            model_mem + 4 * X_mem + 2 * activation_mem + baseline_overhead_mem_est
        )

    @classmethod
    def _class_tags(cls):
        return {"can_estimate_memory_usage_static": True}

    def _more_tags(self) -> dict:
        return {"can_refit_full": True}

    @staticmethod
    def extra_checkpoints_for_tuning(problem_type: str) -> list[str]:
        raise NotImplementedError("This method must be implemented in the subclass.")

    def _log_license(self, device: str):
        pass

    def _log_cpu_warning(self, device: str):
        """Log a one-time (per process) notice when TabPFN is fit on CPU."""
        global _HAS_LOGGED_TABPFN_CPU_WARNING
        if not _HAS_LOGGED_TABPFN_CPU_WARNING:
            if device == "cpu":
                logger.log(
                    20,
                    "\tRunning TabPFN on CPU. This can be very slow. "
                    "It is recommended to run TabPFN on a GPU."
                )
                _HAS_LOGGED_TABPFN_CPU_WARNING = True
321
+
322
class RealTabPFNv25Model(TabPFNModel):
    """RealTabPFN-v2.5 version: https://priorlabs.ai/technical-reports/tabpfn-2-5-model-report.

    We name this model RealTabPFN-v2.5 as its default checkpoints were trained on
    real-world datasets, following the naming conventions of Prior Labs.
    The extra checkpoints include models trained on only synthetic datasets as well.

    .. versionadded:: 1.5.0
    """

    ag_key = "REALTABPFN-V2.5"
    ag_name = "RealTabPFN-v2.5"

    default_classification_model: str | None = (
        "tabpfn-v2.5-classifier-v2.5_default.ckpt"
    )
    default_regression_model: str | None = "tabpfn-v2.5-regressor-v2.5_default.ckpt"

    @staticmethod
    def extra_checkpoints_for_tuning(problem_type: str) -> list[str]:
        """The list of checkpoints to use for hyperparameter tuning.

        Args:
            problem_type: AutoGluon problem type ("binary", "multiclass", "regression").
                The generic name "classification" is also accepted.

        Returns:
            Classifier checkpoint file names for classification problems, otherwise
            regressor checkpoint file names.
        """
        # AutoGluon reports classification tasks as "binary"/"multiclass" (see
        # ``supported_problem_types``). Comparing only against "classification" would hand
        # regressor checkpoints to every classification task.
        if problem_type in ("classification", "binary", "multiclass"):
            return [
                "tabpfn-v2.5-classifier-v2.5_default-2.ckpt",
                "tabpfn-v2.5-classifier-v2.5_large-features-L.ckpt",
                "tabpfn-v2.5-classifier-v2.5_large-features-XL.ckpt",
                "tabpfn-v2.5-classifier-v2.5_large-samples.ckpt",
                "tabpfn-v2.5-classifier-v2.5_real-large-features.ckpt",
                "tabpfn-v2.5-classifier-v2.5_real-large-samples-and-features.ckpt",
                "tabpfn-v2.5-classifier-v2.5_real.ckpt",
                "tabpfn-v2.5-classifier-v2.5_variant.ckpt",
            ]

        return [
            "tabpfn-v2.5-regressor-v2.5_low-skew.ckpt",
            "tabpfn-v2.5-regressor-v2.5_quantiles.ckpt",
            "tabpfn-v2.5-regressor-v2.5_real-variant.ckpt",
            "tabpfn-v2.5-regressor-v2.5_real.ckpt",
            "tabpfn-v2.5-regressor-v2.5_small-samples.ckpt",
            "tabpfn-v2.5-regressor-v2.5_variant.ckpt",
        ]

    def _log_license(self, device: str):
        """Log the TabPFN-2.5 non-commercial license warning once per process."""
        global _HAS_LOGGED_TABPFN_NONCOMMERICAL
        if not _HAS_LOGGED_TABPFN_NONCOMMERICAL:
            logger.log(
                30,
                "\tWarning: TabPFN-2.5 is a NONCOMMERCIAL model. "
                "Usage of this artifact (including through AutoGluon) is not permitted "
                "for commercial tasks unless granted explicit permission "
                "by the model authors (PriorLabs)."
            )  # Aligning with TabPFNv25 license
            _HAS_LOGGED_TABPFN_NONCOMMERICAL = True  # Avoid repeated logging
375
+
376
+
377
class RealTabPFNv2Model(TabPFNModel):
    """RealTabPFN-v2 version

    We name this model RealTabPFN-v2 as its default checkpoints were trained on
    real-world datasets, following the naming conventions of Prior Labs.
    The extra checkpoints include models trained on only synthetic datasets as well.

    Compared to the TabPFN-2.5 defaults, the auxiliary limits are lowered to 10,000 rows
    and 500 features, ``max_batch_size`` is set to 10,000, and the memory estimate caps
    the feature count at 500.

    .. versionadded:: 1.5.0
    """

    ag_key = "REALTABPFN-V2"
    ag_name = "RealTabPFN-v2"

    # TODO: Verify if this is the same as the "default" ckpt
    default_classification_model: str | None = (
        "tabpfn-v2-classifier-finetuned-zk73skhh.ckpt"
    )
    default_regression_model: str | None = "tabpfn-v2-regressor-v2_default.ckpt"

    def _get_default_auxiliary_params(self) -> dict:
        default_auxiliary_params = super()._get_default_auxiliary_params()
        default_auxiliary_params.update(
            {
                "max_rows": 10_000,
                "max_features": 500,
                "max_classes": 10,
                "max_batch_size": 10000,  # TabPFN seems to cryptically error if predicting on 100,000 samples.
            }
        )
        return default_auxiliary_params

    def _log_license(self, device: str):
        """Log the TabPFNv2 attribution notice once per process."""
        global _HAS_LOGGED_TABPFN_LICENSE
        if not _HAS_LOGGED_TABPFN_LICENSE:
            logger.log(20, "\tBuilt with PriorLabs-TabPFN")  # Aligning with TabPFNv2 license requirements
            _HAS_LOGGED_TABPFN_LICENSE = True  # Avoid repeated logging

    # FIXME: Avoid code dupe. This one has 500 features max, 2.5 has 2000.
    @classmethod
    def _estimate_memory_usage_static(
        cls,
        *,
        X: pd.DataFrame,
        hyperparameters: dict | None = None,
        **kwargs,
    ) -> int:
        """Heuristic memory estimate based on TabPFN's memory estimate logic in:
        https://github.com/PriorLabs/TabPFN/blob/57a2efd3ebdb3886245e4d097cefa73a5261a969/src/tabpfn/model/memory.py#L147.

        This is based on GPU memory usage, but hopefully with overheads it also approximates CPU memory usage.
        The feature count is capped at 500 before estimating.
        """
        # TODO: update, this is not correct anymore, consider using internal TabPFN functions directly.
        features_per_group = 3  # Based on TabPFNv2 default; used to approximate the number of feature groups
        n_layers = 12  # Based on TabPFNv2 default
        embedding_size = 192  # Based on TabPFNv2 default
        dtype_byte_size = 2  # Based on TabPFNv2 default

        model_mem = 14489108  # Based on TabPFNv2 default

        n_samples, n_features = X.shape[0], min(X.shape[1], 500)
        n_feature_groups = n_features / features_per_group + 1  # TODO: Unsure how to calculate this

        X_mem = n_samples * n_feature_groups * dtype_byte_size
        activation_mem = (
            n_samples * n_feature_groups * embedding_size * n_layers * dtype_byte_size
        )

        baseline_overhead_mem_est = 1e9  # 1 GB generic overhead

        # Add some buffer to each term + 1 GB overhead to be safe
        return int(
            model_mem + 4 * X_mem + 2 * activation_mem + baseline_overhead_mem_est
        )
@@ -6,14 +6,16 @@ from __future__ import annotations
6
6
 
7
7
  import logging
8
8
  import warnings
9
+ from pathlib import Path
9
10
  from typing import TYPE_CHECKING, Any
10
11
 
11
12
  import numpy as np
12
13
  import scipy
13
14
  from sklearn.preprocessing import PowerTransformer
15
+ from typing_extensions import Self
14
16
 
15
17
  from autogluon.common.utils.resource_utils import ResourceManager
16
- from autogluon.core.models import AbstractModel
18
+ from autogluon.tabular.models.abstract.abstract_torch_model import AbstractTorchModel
17
19
  from autogluon.features.generators import LabelEncoderFeatureGenerator
18
20
  from autogluon.tabular import __version__
19
21
 
@@ -104,7 +106,8 @@ class FixedSafePowerTransformer(PowerTransformer):
104
106
  return self._revert_failed_features(transformed_X, X) # type: ignore
105
107
 
106
108
 
107
- class TabPFNV2Model(AbstractModel):
109
+ # FIXME: Need to take this logic into v6 for loading on CPU
110
+ class TabPFNV2Model(AbstractTorchModel):
108
111
  """
109
112
  TabPFNv2 is a tabular foundation model pre-trained purely on synthetic data that achieves
110
113
  state-of-the-art results with in-context learning on small datasets with <=10000 samples and <=500 features.
@@ -126,6 +129,7 @@ class TabPFNV2Model(AbstractModel):
126
129
 
127
130
  def __init__(self, **kwargs):
128
131
  super().__init__(**kwargs)
132
+ self._cached_model = False
129
133
  self._feature_generator = None
130
134
  self._cat_features = None
131
135
  self._cat_indices = None
@@ -155,6 +159,12 @@ class TabPFNV2Model(AbstractModel):
155
159
 
156
160
  return X
157
161
 
162
def _get_model_cls(self):
    """Return the TabPFN estimator class matching this model's problem type.

    ``TabPFNClassifier`` is returned for "binary"/"multiclass" problems and
    ``TabPFNRegressor`` for everything else.
    """
    from tabpfn import TabPFNClassifier, TabPFNRegressor

    if self.problem_type in ["binary", "multiclass"]:
        return TabPFNClassifier
    return TabPFNRegressor
167
+
158
168
  # FIXME: Crashes during model download if bagging with parallel fit.
159
169
  # Consider adopting same download logic as TabPFNMix which doesn't crash during model download.
160
170
  # FIXME: Maybe support child_oof somehow with using only one model and being smart about inference time?
@@ -179,13 +189,12 @@ class TabPFNV2Model(AbstractModel):
179
189
 
180
190
  preprocessing.SafePowerTransformer = FixedSafePowerTransformer
181
191
 
182
- from tabpfn import TabPFNClassifier, TabPFNRegressor
183
- from tabpfn.model.loading import resolve_model_path
184
- from torch.cuda import is_available
185
-
186
192
  is_classification = self.problem_type in ["binary", "multiclass"]
187
193
 
188
- model_base = TabPFNClassifier if is_classification else TabPFNRegressor
194
+ model_base = self._get_model_cls()
195
+
196
+ from tabpfn.model.loading import resolve_model_path
197
+ from torch.cuda import is_available
189
198
 
190
199
  device = "cuda" if num_gpus != 0 else "cpu"
191
200
  if (device == "cuda") and (not is_available()):
@@ -279,6 +288,69 @@ class TabPFNV2Model(AbstractModel):
279
288
  y=y,
280
289
  )
281
290
 
291
def get_device(self) -> str:
    """Return the ``type`` of the fitted estimator's ``device_`` (e.g. ``"cpu"`` or ``"cuda"``)."""
    fitted_device = self.model.device_
    return fitted_device.type
293
+
294
+ def _set_device(self, device: str):
295
+ pass # TODO: Unknown how to properly set device for TabPFN after loading. Refer to `_set_device_tabpfn`.
296
+
297
+ # FIXME: This is not comprehensive. Need model authors to add an official API set_device
298
+ def _set_device_tabpfn(self, device: str):
299
+ import torch
300
+ # Move all torch components to the target device
301
+ device = self.to_torch_device(device)
302
+ self.model.device_ = device
303
+ if hasattr(self.model.executor_, "model") and self.model.executor_.model is not None:
304
+ self.model.executor_.model.to(self.model.device_)
305
+ if hasattr(self.model.executor_, "models"):
306
+ self.model.executor_.models = [m.to(self.model.device_) for m in self.model.executor_.models]
307
+
308
+ # Restore other potential torch objects from fitted_attrs
309
+ for key, value in vars(self.model).items():
310
+ if key.endswith("_") and hasattr(value, "to"):
311
+ setattr(self.model, key, value.to(self.model.device_))
312
+
313
def model_weights_path(self, path: str | None = None) -> Path:
    """Return the path of the saved TabPFN fit state inside ``path`` (defaults to ``self.path``)."""
    base_dir = self.path if path is None else path
    return Path(base_dir).joinpath("config.tabpfn_fit")
317
+
318
def save(self, path: str = None, verbose=True) -> str:
    """Save the model, writing a fitted estimator as a separate fit-state artifact.

    When the model is fit, the estimator is written via ``_save_model_artifact`` and
    ``_cached_model`` is set so ``load`` knows to restore it. The estimator is detached
    from ``self.model`` while the parent ``save`` runs, so it is not written a second time.
    It is always re-attached afterwards, even if the parent ``save`` raises, so the
    in-memory model remains usable.

    Returns:
        The path returned by the parent ``save``.
    """
    fitted_model = self.model
    is_fit = self.is_fit()
    if is_fit:
        self._save_model_artifact(path=path)
        self._cached_model = True
        self.model = None
    try:
        path = super().save(path=path, verbose=verbose)
    finally:
        if is_fit:
            self.model = fitted_model
    return path
329
+
330
+ # TODO: It is required to do this because it is unknown how to otherwise save TabPFN in CPU-only mode.
331
+ # Even though we would generally prefer to save it in the pkl for better insurance
332
+ # that the model will work in future (self-contained)
333
+ def _save_model_artifact(self, path: str | None = None):
334
+ # save with CPU device so it can be loaded on a CPU only machine
335
+ device_og = self.device
336
+ self._set_device_tabpfn(device="cpu")
337
+ self.model.save_fit_state(path=self.model_weights_path(path=path))
338
+ self._set_device_tabpfn(device=device_og)
339
+
340
@classmethod
def load(cls, path: str, reset_paths=True, verbose=True) -> Self:
    """Load the model and, if its estimator was stored as a separate artifact, restore it.

    When ``_cached_model`` is set on the loaded object, ``_load_model_artifact`` rebuilds the
    estimator and the flag is cleared.
    """
    loaded = super().load(path=path, reset_paths=reset_paths, verbose=verbose)
    if not loaded._cached_model:
        return loaded
    loaded._load_model_artifact()
    loaded._cached_model = False
    return loaded
347
+
348
+ def _load_model_artifact(self):
349
+ model_cls = self._get_model_cls()
350
+ device = self.suggest_device_infer()
351
+ self.model = model_cls.load_from_fit_state(path=self.model_weights_path(), device=device)
352
+ self.device = device
353
+
282
354
  def _log_license(self, device: str):
283
355
  global _HAS_LOGGED_TABPFN_LICENSE
284
356
  if not _HAS_LOGGED_TABPFN_LICENSE:
@@ -317,6 +389,7 @@ class TabPFNV2Model(AbstractModel):
317
389
  "max_rows": 10000,
318
390
  "max_features": 500,
319
391
  "max_classes": 10,
392
+ "max_batch_size": 10000, # TabPFN seems to cryptically error if predicting on 100,000 samples.
320
393
  }
321
394
  )
322
395
  return default_auxiliary_params
@@ -382,7 +455,12 @@ class TabPFNV2Model(AbstractModel):
382
455
 
383
456
  @classmethod
384
457
  def _class_tags(cls):
385
- return {"can_estimate_memory_usage_static": True}
458
+ return {
459
+ "can_estimate_memory_usage_static": True,
460
+ "can_set_device": True,
461
+ "set_device_on_save_to": None,
462
+ "set_device_on_load": False,
463
+ }
386
464
 
387
465
  def _more_tags(self) -> dict:
388
466
  return {"can_refit_full": True}
File without changes
@@ -0,0 +1,21 @@
1
+ from __future__ import annotations
2
+
3
+ from ..lgb.lgb_model import LGBModel
4
+ from .prep_mixin import ModelAgnosticPrepMixin
5
+
6
+
7
class PrepLGBModel(ModelAgnosticPrepMixin, LGBModel):
    """LightGBM model combined with ``ModelAgnosticPrepMixin`` preprocessing.

    Both memory estimates of the parent are scaled by ``_memory_estimate_multiplier`` and
    returned as ``int``, matching the declared return type.
    """

    ag_key = "GBM_PREP"
    ag_name = "LightGBMPrep"

    # FIXME: 1.5 runs OOM on kddcup09_appetency fold 2 repeat 0 prep_LightGBM_r49_BAG_L1
    # FIXME: For some reason the parent estimate underestimates mem usage without this factor.
    _memory_estimate_multiplier: float = 2.0

    @classmethod
    def _estimate_memory_usage_static(cls, **kwargs) -> int:
        memory_usage = super()._estimate_memory_usage_static(**kwargs)
        return int(memory_usage * cls._memory_estimate_multiplier)

    @classmethod
    def _estimate_memory_usage_static_lite(cls, **kwargs) -> int:
        memory_usage = super()._estimate_memory_usage_static_lite(**kwargs)
        return int(memory_usage * cls._memory_estimate_multiplier)