autogluon.tabular 1.4.1b20251014__py3-none-any.whl → 1.5.0b20251222__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- autogluon/tabular/configs/hyperparameter_configs.py +4 -0
- autogluon/tabular/configs/presets_configs.py +39 -2
- autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2025.py +2 -44
- autogluon/tabular/configs/zeroshot/zeroshot_portfolio_cpu_2025_12_18.py +2 -0
- autogluon/tabular/configs/zeroshot/zeroshot_portfolio_gpu_2025_12_18.py +2 -0
- autogluon/tabular/learner/default_learner.py +1 -0
- autogluon/tabular/models/__init__.py +3 -1
- autogluon/tabular/models/abstract/__init__.py +0 -0
- autogluon/tabular/models/abstract/abstract_torch_model.py +148 -0
- autogluon/tabular/models/catboost/catboost_model.py +2 -5
- autogluon/tabular/models/ebm/ebm_model.py +2 -6
- autogluon/tabular/models/fastainn/tabular_nn_fastai.py +9 -3
- autogluon/tabular/models/lgb/lgb_model.py +60 -17
- autogluon/tabular/models/lgb/lgb_utils.py +2 -2
- autogluon/tabular/models/lr/lr_model.py +2 -4
- autogluon/tabular/models/lr/lr_preprocessing_utils.py +6 -7
- autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py +14 -1
- autogluon/tabular/models/mitra/mitra_model.py +55 -29
- autogluon/tabular/models/realmlp/realmlp_model.py +8 -5
- autogluon/tabular/models/rf/rf_model.py +6 -8
- autogluon/tabular/models/tabdpt/__init__.py +0 -0
- autogluon/tabular/models/tabdpt/tabdpt_model.py +253 -0
- autogluon/tabular/models/tabicl/tabicl_model.py +15 -5
- autogluon/tabular/models/tabm/tabm_model.py +25 -8
- autogluon/tabular/models/tabpfnmix/tabpfnmix_model.py +7 -5
- autogluon/tabular/models/tabpfnv2/tabpfnv2_5_model.py +451 -0
- autogluon/tabular/models/tabpfnv2/tabpfnv2_model.py +87 -12
- autogluon/tabular/models/tabprep/__init__.py +0 -0
- autogluon/tabular/models/tabprep/prep_lgb_model.py +21 -0
- autogluon/tabular/models/tabprep/prep_mixin.py +220 -0
- autogluon/tabular/models/tabular_nn/torch/tabular_nn_torch.py +3 -6
- autogluon/tabular/models/tabular_nn/utils/data_preprocessor.py +12 -4
- autogluon/tabular/models/xgboost/xgboost_model.py +3 -4
- autogluon/tabular/predictor/predictor.py +50 -20
- autogluon/tabular/registry/_ag_model_registry.py +8 -2
- autogluon/tabular/testing/fit_helper.py +61 -0
- autogluon/tabular/trainer/abstract_trainer.py +45 -9
- autogluon/tabular/trainer/auto_trainer.py +5 -0
- autogluon/tabular/version.py +1 -1
- autogluon.tabular-1.5.0b20251222-py3.11-nspkg.pth +1 -0
- {autogluon.tabular-1.4.1b20251014.dist-info → autogluon_tabular-1.5.0b20251222.dist-info}/METADATA +97 -87
- {autogluon.tabular-1.4.1b20251014.dist-info → autogluon_tabular-1.5.0b20251222.dist-info}/RECORD +48 -38
- {autogluon.tabular-1.4.1b20251014.dist-info → autogluon_tabular-1.5.0b20251222.dist-info}/WHEEL +1 -1
- autogluon.tabular-1.4.1b20251014-py3.9-nspkg.pth +0 -1
- {autogluon.tabular-1.4.1b20251014.dist-info → autogluon_tabular-1.5.0b20251222.dist-info/licenses}/LICENSE +0 -0
- {autogluon.tabular-1.4.1b20251014.dist-info → autogluon_tabular-1.5.0b20251222.dist-info/licenses}/NOTICE +0 -0
- {autogluon.tabular-1.4.1b20251014.dist-info → autogluon_tabular-1.5.0b20251222.dist-info}/namespace_packages.txt +0 -0
- {autogluon.tabular-1.4.1b20251014.dist-info → autogluon_tabular-1.5.0b20251222.dist-info}/top_level.txt +0 -0
- {autogluon.tabular-1.4.1b20251014.dist-info → autogluon_tabular-1.5.0b20251222.dist-info}/zip-safe +0 -0
|
@@ -0,0 +1,253 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
|
|
5
|
+
from autogluon.common.utils.resource_utils import ResourceManager
|
|
6
|
+
from autogluon.core.constants import BINARY, MULTICLASS, REGRESSION
|
|
7
|
+
from autogluon.features.generators import LabelEncoderFeatureGenerator
|
|
8
|
+
from autogluon.tabular.models.abstract.abstract_torch_model import AbstractTorchModel
|
|
9
|
+
|
|
10
|
+
if TYPE_CHECKING:
|
|
11
|
+
import numpy as np
|
|
12
|
+
import pandas as pd
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
# FIXME: Nick:
|
|
16
|
+
# TODO: batch_size is linear to memory usage
|
|
17
|
+
# 512 default
|
|
18
|
+
# should be less for very large datasets
|
|
19
|
+
# 128 batch_size on Bioresponse -> 12 GB VRAM
|
|
20
|
+
# Train Data Rows: 2500
|
|
21
|
+
# Train Data Columns: 1776
|
|
22
|
+
# Problem Type: binary
|
|
23
|
+
# FIXME: Just set context_size = infinity, everything is way faster, memory usage is way lower, etc.
|
|
24
|
+
# Train Data Rows: 100000
|
|
25
|
+
# Train Data Columns: 10
|
|
26
|
+
# binary
|
|
27
|
+
# only takes 6.7 GB during inference with batch_size = 512
|
|
28
|
+
# FIXME: Make it work when loading on CPU?
|
|
29
|
+
# FIXME: Can we run 8 in parallel to speed up?
|
|
30
|
+
# TODO: clip_sigma == 1 is terrible, clip_sigma == 16 maybe very good? What about higher values?
|
|
31
|
+
# clip_sigma >= 16 is roughly all equivalent
|
|
32
|
+
# FIXME: TabDPT stores self.X_test for no reason
|
|
33
|
+
# FIXME: TabDPT creates faiss_knn even if it is never used. Better if `context_size=None` means it is never created.
|
|
34
|
+
# TODO: unit test
|
|
35
|
+
# TODO: memory estimate
|
|
36
|
+
class TabDPTModel(AbstractTorchModel):
|
|
37
|
+
ag_key = "TABDPT"
|
|
38
|
+
ag_name = "TabDPT"
|
|
39
|
+
seed_name = "seed"
|
|
40
|
+
ag_priority = 50
|
|
41
|
+
default_random_seed = 0
|
|
42
|
+
|
|
43
|
+
def __init__(self, **kwargs):
|
|
44
|
+
super().__init__(**kwargs)
|
|
45
|
+
self._feature_generator = None
|
|
46
|
+
self._predict_hps = None
|
|
47
|
+
self._use_flash_og = None
|
|
48
|
+
|
|
49
|
+
def _fit(
|
|
50
|
+
self,
|
|
51
|
+
X: pd.DataFrame,
|
|
52
|
+
y: pd.Series,
|
|
53
|
+
num_cpus: int = 1,
|
|
54
|
+
num_gpus: int = 0,
|
|
55
|
+
**kwargs,
|
|
56
|
+
):
|
|
57
|
+
from torch.cuda import is_available
|
|
58
|
+
|
|
59
|
+
device = "cuda" if num_gpus != 0 else "cpu"
|
|
60
|
+
if (device == "cuda") and (not is_available()):
|
|
61
|
+
raise AssertionError(
|
|
62
|
+
"Fit specified to use GPU, but CUDA is not available on this machine. "
|
|
63
|
+
"Please switch to CPU usage instead.",
|
|
64
|
+
)
|
|
65
|
+
from tabdpt import TabDPTClassifier, TabDPTRegressor
|
|
66
|
+
|
|
67
|
+
model_cls = (
|
|
68
|
+
TabDPTClassifier
|
|
69
|
+
if self.problem_type in [BINARY, MULTICLASS]
|
|
70
|
+
else TabDPTRegressor
|
|
71
|
+
)
|
|
72
|
+
fit_params, self._predict_hps = self._get_tabdpt_params(num_gpus=num_gpus)
|
|
73
|
+
|
|
74
|
+
X = self.preprocess(X)
|
|
75
|
+
y = y.to_numpy()
|
|
76
|
+
self.model = model_cls(
|
|
77
|
+
device=device,
|
|
78
|
+
**fit_params,
|
|
79
|
+
)
|
|
80
|
+
self.model.fit(X=X, y=y)
|
|
81
|
+
|
|
82
|
+
def _get_tabdpt_params(self, num_gpus: float) -> tuple[dict, dict]:
|
|
83
|
+
model_params = self._get_model_params()
|
|
84
|
+
|
|
85
|
+
valid_predict_params = (self.seed_name, "context_size", "permute_classes", "temperature", "n_ensembles")
|
|
86
|
+
|
|
87
|
+
predict_params = {}
|
|
88
|
+
for hp in valid_predict_params:
|
|
89
|
+
if hp in model_params:
|
|
90
|
+
predict_params[hp] = model_params.pop(hp)
|
|
91
|
+
predict_params.setdefault(self.seed_name, self.default_random_seed)
|
|
92
|
+
predict_params.setdefault("context_size", None)
|
|
93
|
+
|
|
94
|
+
supported_predict_params = (
|
|
95
|
+
(self.seed_name, "context_size", "n_ensembles", "permute_classes", "temperature")
|
|
96
|
+
if self.problem_type in [BINARY, MULTICLASS]
|
|
97
|
+
else (self.seed_name, "context_size", "n_ensembles")
|
|
98
|
+
)
|
|
99
|
+
predict_params = {key: val for key, val in predict_params.items() if key in supported_predict_params}
|
|
100
|
+
|
|
101
|
+
fit_params = model_params
|
|
102
|
+
|
|
103
|
+
fit_params.setdefault("verbose", False)
|
|
104
|
+
fit_params.setdefault("compile", False)
|
|
105
|
+
if fit_params.get("use_flash", True):
|
|
106
|
+
fit_params["use_flash"] = self._use_flash(num_gpus=num_gpus)
|
|
107
|
+
return fit_params, predict_params
|
|
108
|
+
|
|
109
|
+
@staticmethod
|
|
110
|
+
def _use_flash(num_gpus: float) -> bool:
|
|
111
|
+
"""Detect if torch's native flash attention is available on the current machine."""
|
|
112
|
+
if num_gpus == 0:
|
|
113
|
+
return False
|
|
114
|
+
|
|
115
|
+
import torch
|
|
116
|
+
|
|
117
|
+
if not torch.cuda.is_available():
|
|
118
|
+
return False
|
|
119
|
+
|
|
120
|
+
device = torch.device("cuda:0")
|
|
121
|
+
capability = torch.cuda.get_device_capability(device)
|
|
122
|
+
|
|
123
|
+
return capability != (7, 5)
|
|
124
|
+
|
|
125
|
+
def _post_fit(self, **kwargs):
|
|
126
|
+
super()._post_fit(**kwargs)
|
|
127
|
+
self._use_flash_og = self.model.use_flash
|
|
128
|
+
return self
|
|
129
|
+
|
|
130
|
+
def get_device(self) -> str:
|
|
131
|
+
return self.model.device
|
|
132
|
+
|
|
133
|
+
def _set_device(self, device: str):
|
|
134
|
+
self.model.to(device)
|
|
135
|
+
if device == "cpu":
|
|
136
|
+
self.model.use_flash = False
|
|
137
|
+
self.model.model.use_flash = False
|
|
138
|
+
else:
|
|
139
|
+
self.model.use_flash = self._use_flash_og
|
|
140
|
+
self.model.model.use_flash = self._use_flash_og
|
|
141
|
+
|
|
142
|
+
def _get_default_resources(self) -> tuple[int, int]:
|
|
143
|
+
# Use only physical cores for better performance based on benchmarks
|
|
144
|
+
num_cpus = ResourceManager.get_cpu_count(only_physical_cores=True)
|
|
145
|
+
|
|
146
|
+
num_gpus = min(1, ResourceManager.get_gpu_count_torch(cuda_only=True))
|
|
147
|
+
|
|
148
|
+
return num_cpus, num_gpus
|
|
149
|
+
|
|
150
|
+
def get_minimum_resources(
|
|
151
|
+
self, is_gpu_available: bool = False
|
|
152
|
+
) -> dict[str, int | float]:
|
|
153
|
+
return {
|
|
154
|
+
"num_cpus": 1,
|
|
155
|
+
"num_gpus": 0.5 if is_gpu_available else 0,
|
|
156
|
+
}
|
|
157
|
+
|
|
158
|
+
def _predict_proba(self, X, **kwargs) -> np.ndarray:
|
|
159
|
+
X = self.preprocess(X, **kwargs)
|
|
160
|
+
|
|
161
|
+
if self.problem_type in [REGRESSION]:
|
|
162
|
+
y_pred = self.model.predict(X, **self._predict_hps)
|
|
163
|
+
return y_pred
|
|
164
|
+
|
|
165
|
+
y_pred_proba = self.model.ensemble_predict_proba(X, **self._predict_hps)
|
|
166
|
+
return self._convert_proba_to_unified_form(y_pred_proba)
|
|
167
|
+
|
|
168
|
+
def _preprocess(self, X: pd.DataFrame, **kwargs) -> pd.DataFrame:
|
|
169
|
+
"""TabDPT requires numpy array as input."""
|
|
170
|
+
X = super()._preprocess(X, **kwargs)
|
|
171
|
+
if self._feature_generator is None:
|
|
172
|
+
self._feature_generator = LabelEncoderFeatureGenerator(verbosity=0)
|
|
173
|
+
self._feature_generator.fit(X=X)
|
|
174
|
+
if self._feature_generator.features_in:
|
|
175
|
+
X = X.copy()
|
|
176
|
+
X[self._feature_generator.features_in] = self._feature_generator.transform(
|
|
177
|
+
X=X
|
|
178
|
+
)
|
|
179
|
+
return X.to_numpy()
|
|
180
|
+
|
|
181
|
+
@classmethod
|
|
182
|
+
def supported_problem_types(cls) -> list[str] | None:
|
|
183
|
+
return ["binary", "multiclass", "regression"]
|
|
184
|
+
|
|
185
|
+
@classmethod
|
|
186
|
+
def _class_tags(cls):
|
|
187
|
+
return {"can_estimate_memory_usage_static": True}
|
|
188
|
+
|
|
189
|
+
def _more_tags(self) -> dict:
|
|
190
|
+
return {"can_refit_full": True}
|
|
191
|
+
|
|
192
|
+
def _get_default_auxiliary_params(self) -> dict:
|
|
193
|
+
default_auxiliary_params = super()._get_default_auxiliary_params()
|
|
194
|
+
default_auxiliary_params.update(
|
|
195
|
+
{
|
|
196
|
+
"max_rows": 100000, # TODO: Test >100k rows
|
|
197
|
+
"max_features": 2500, # TODO: Test >2500 features
|
|
198
|
+
"max_classes": 10, # TODO: Test >10 classes
|
|
199
|
+
}
|
|
200
|
+
)
|
|
201
|
+
return default_auxiliary_params
|
|
202
|
+
|
|
203
|
+
@classmethod
|
|
204
|
+
def _get_default_ag_args_ensemble(cls, **kwargs) -> dict:
|
|
205
|
+
default_ag_args_ensemble = super()._get_default_ag_args_ensemble(**kwargs)
|
|
206
|
+
extra_ag_args_ensemble = {
|
|
207
|
+
"refit_folds": True,
|
|
208
|
+
}
|
|
209
|
+
default_ag_args_ensemble.update(extra_ag_args_ensemble)
|
|
210
|
+
return default_ag_args_ensemble
|
|
211
|
+
|
|
212
|
+
# FIXME: This is copied from TabPFN, but TabDPT is not the same
|
|
213
|
+
@classmethod
|
|
214
|
+
def _estimate_memory_usage_static(
|
|
215
|
+
cls,
|
|
216
|
+
*,
|
|
217
|
+
X: pd.DataFrame,
|
|
218
|
+
hyperparameters: dict | None = None,
|
|
219
|
+
**kwargs,
|
|
220
|
+
) -> int:
|
|
221
|
+
"""Heuristic memory estimate based on TabPFN's memory estimate logic in:
|
|
222
|
+
https://github.com/PriorLabs/TabPFN/blob/57a2efd3ebdb3886245e4d097cefa73a5261a969/src/tabpfn/model/memory.py#L147.
|
|
223
|
+
|
|
224
|
+
This is based on GPU memory usage, but hopefully with overheads it also approximates CPU memory usage.
|
|
225
|
+
"""
|
|
226
|
+
# TODO: update, this is not correct anymore, consider using internal TabPFN functions directly.
|
|
227
|
+
features_per_group = 3 # Based on TabPFNv2 default (unused)
|
|
228
|
+
n_layers = 12 # Based on TabPFNv2 default
|
|
229
|
+
embedding_size = 192 # Based on TabPFNv2 default
|
|
230
|
+
dtype_byte_size = 2 # Based on TabPFNv2 default
|
|
231
|
+
|
|
232
|
+
model_mem = 14489108 # Based on TabPFNv2 default
|
|
233
|
+
|
|
234
|
+
n_samples, n_features = X.shape[0], min(X.shape[1], 500)
|
|
235
|
+
n_feature_groups = (
|
|
236
|
+
n_features
|
|
237
|
+
) / features_per_group + 1 # TODO: Unsure how to calculate this
|
|
238
|
+
|
|
239
|
+
X_mem = n_samples * n_feature_groups * dtype_byte_size
|
|
240
|
+
activation_mem = (
|
|
241
|
+
n_samples * n_feature_groups * embedding_size * n_layers * dtype_byte_size
|
|
242
|
+
)
|
|
243
|
+
|
|
244
|
+
baseline_overhead_mem_est = 1e9 # 1 GB generic overhead
|
|
245
|
+
|
|
246
|
+
# Add some buffer to each term + 1 GB overhead to be safe
|
|
247
|
+
memory_estimate = model_mem + 4 * X_mem + 2 * activation_mem + baseline_overhead_mem_est
|
|
248
|
+
|
|
249
|
+
# TabDPT memory estimation is very inaccurate because it is using TabPFN memory estimate. Double it to be safe.
|
|
250
|
+
memory_estimate = memory_estimate * 2
|
|
251
|
+
|
|
252
|
+
# Note: This memory estimate is way off if `context_size` is not None
|
|
253
|
+
return int(memory_estimate)
|
|
@@ -10,14 +10,14 @@ import pandas as pd
|
|
|
10
10
|
|
|
11
11
|
from autogluon.common.utils.pandas_utils import get_approximate_df_mem_usage
|
|
12
12
|
from autogluon.common.utils.resource_utils import ResourceManager
|
|
13
|
-
from autogluon.core.models import AbstractModel
|
|
14
13
|
from autogluon.tabular import __version__
|
|
14
|
+
from autogluon.tabular.models.abstract.abstract_torch_model import AbstractTorchModel
|
|
15
15
|
|
|
16
16
|
logger = logging.getLogger(__name__)
|
|
17
17
|
|
|
18
18
|
|
|
19
19
|
# TODO: Verify if crashes when weights are not yet downloaded and fit in parallel
|
|
20
|
-
class TabICLModel(
|
|
20
|
+
class TabICLModel(AbstractTorchModel):
|
|
21
21
|
"""
|
|
22
22
|
TabICL is a foundation model for tabular data using in-context learning
|
|
23
23
|
that is scalable to larger datasets than TabPFNv2. It is pretrained purely on synthetic data.
|
|
@@ -35,6 +35,7 @@ class TabICLModel(AbstractModel):
|
|
|
35
35
|
ag_key = "TABICL"
|
|
36
36
|
ag_name = "TabICL"
|
|
37
37
|
ag_priority = 65
|
|
38
|
+
seed_name = "random_state"
|
|
38
39
|
|
|
39
40
|
def get_model_cls(self):
|
|
40
41
|
from tabicl import TabICLClassifier
|
|
@@ -89,7 +90,6 @@ class TabICLModel(AbstractModel):
|
|
|
89
90
|
**hyp,
|
|
90
91
|
device=device,
|
|
91
92
|
n_jobs=num_cpus,
|
|
92
|
-
random_state=self.random_seed,
|
|
93
93
|
)
|
|
94
94
|
X = self.preprocess(X)
|
|
95
95
|
self.model = self.model.fit(
|
|
@@ -97,8 +97,18 @@ class TabICLModel(AbstractModel):
|
|
|
97
97
|
y=y,
|
|
98
98
|
)
|
|
99
99
|
|
|
100
|
-
def
|
|
101
|
-
return
|
|
100
|
+
def get_device(self) -> str:
    """Return the ``type`` of the fitted estimator's ``device_`` attribute."""
    fitted_device = self.model.device_
    return fitted_device.type
|
|
102
|
+
|
|
103
|
+
# TODO: Better to have an official TabICL method for this
|
|
104
|
+
def _set_device(self, device: str):
|
|
105
|
+
device = self.to_torch_device(device)
|
|
106
|
+
self.model.device_ = device
|
|
107
|
+
self.model.device = self.model.device_.type
|
|
108
|
+
self.model.model_ = self.model.model_.to(self.model.device_)
|
|
109
|
+
self.model.inference_config_.COL_CONFIG.device = self.model.device_
|
|
110
|
+
self.model.inference_config_.ROW_CONFIG.device = self.model.device_
|
|
111
|
+
self.model.inference_config_.ICL_CONFIG.device = self.model.device_
|
|
102
112
|
|
|
103
113
|
def _get_default_auxiliary_params(self) -> dict:
|
|
104
114
|
default_auxiliary_params = super()._get_default_auxiliary_params()
|
|
@@ -15,13 +15,13 @@ import time
|
|
|
15
15
|
import pandas as pd
|
|
16
16
|
|
|
17
17
|
from autogluon.common.utils.resource_utils import ResourceManager
|
|
18
|
-
from autogluon.core.models import AbstractModel
|
|
19
18
|
from autogluon.tabular import __version__
|
|
19
|
+
from autogluon.tabular.models.abstract.abstract_torch_model import AbstractTorchModel
|
|
20
20
|
|
|
21
21
|
logger = logging.getLogger(__name__)
|
|
22
22
|
|
|
23
23
|
|
|
24
|
-
class TabMModel(
|
|
24
|
+
class TabMModel(AbstractTorchModel):
|
|
25
25
|
"""
|
|
26
26
|
TabM is an efficient ensemble of MLPs that is trained simultaneously with mostly shared parameters.
|
|
27
27
|
|
|
@@ -39,6 +39,7 @@ class TabMModel(AbstractModel):
|
|
|
39
39
|
ag_key = "TABM"
|
|
40
40
|
ag_name = "TabM"
|
|
41
41
|
ag_priority = 85
|
|
42
|
+
seed_name = "random_state"
|
|
42
43
|
|
|
43
44
|
def __init__(self, **kwargs):
|
|
44
45
|
super().__init__(**kwargs)
|
|
@@ -86,7 +87,7 @@ class TabMModel(AbstractModel):
|
|
|
86
87
|
if X_val is None:
|
|
87
88
|
from autogluon.core.utils import generate_train_test_split
|
|
88
89
|
|
|
89
|
-
|
|
90
|
+
X, X_val, y, y_val = generate_train_test_split(
|
|
90
91
|
X=X,
|
|
91
92
|
y=y,
|
|
92
93
|
problem_type=self.problem_type,
|
|
@@ -97,7 +98,7 @@ class TabMModel(AbstractModel):
|
|
|
97
98
|
hyp = self._get_model_params()
|
|
98
99
|
bool_to_cat = hyp.pop("bool_to_cat", True)
|
|
99
100
|
|
|
100
|
-
X = self.preprocess(X, is_train=True, bool_to_cat=bool_to_cat)
|
|
101
|
+
X = self.preprocess(X, y=y, is_train=True, bool_to_cat=bool_to_cat)
|
|
101
102
|
if X_val is not None:
|
|
102
103
|
X_val = self.preprocess(X_val)
|
|
103
104
|
|
|
@@ -106,7 +107,6 @@ class TabMModel(AbstractModel):
|
|
|
106
107
|
device=device,
|
|
107
108
|
problem_type=self.problem_type,
|
|
108
109
|
early_stopping_metric=self.stopping_metric,
|
|
109
|
-
random_state=self.random_seed,
|
|
110
110
|
**hyp,
|
|
111
111
|
)
|
|
112
112
|
|
|
@@ -142,8 +142,13 @@ class TabMModel(AbstractModel):
|
|
|
142
142
|
|
|
143
143
|
return X
|
|
144
144
|
|
|
145
|
-
def
|
|
146
|
-
return
|
|
145
|
+
def get_device(self) -> str:
    """Return the ``type`` of the fitted TabM estimator's ``device_`` attribute."""
    current_device = self.model.device_
    return current_device.type
|
|
147
|
+
|
|
148
|
+
def _set_device(self, device: str):
|
|
149
|
+
device = self.to_torch_device(device)
|
|
150
|
+
self.model.device_ = device
|
|
151
|
+
self.model.model_ = self.model.model_.to(device)
|
|
147
152
|
|
|
148
153
|
@classmethod
|
|
149
154
|
def supported_problem_types(cls) -> list[str] | None:
|
|
@@ -258,6 +263,15 @@ class TabMModel(AbstractModel):
|
|
|
258
263
|
|
|
259
264
|
return mem_total
|
|
260
265
|
|
|
266
|
+
def _get_default_auxiliary_params(self) -> dict:
    """Return the parent's default auxiliary params with ``max_batch_size`` set to 16384."""
    aux_params = super()._get_default_auxiliary_params()
    # avoid excessive VRAM usage
    aux_params["max_batch_size"] = 16384
    return aux_params
|
|
274
|
+
|
|
261
275
|
@classmethod
|
|
262
276
|
def get_tabm_auto_batch_size(cls, n_samples: int) -> int:
|
|
263
277
|
# by Yury Gorishniy, inferred from the choices in the TabM paper.
|
|
@@ -275,7 +289,10 @@ class TabMModel(AbstractModel):
|
|
|
275
289
|
|
|
276
290
|
@classmethod
|
|
277
291
|
def _class_tags(cls):
|
|
278
|
-
return {
|
|
292
|
+
return {
|
|
293
|
+
"can_estimate_memory_usage_static": True,
|
|
294
|
+
"reset_torch_threads": True,
|
|
295
|
+
}
|
|
279
296
|
|
|
280
297
|
def _more_tags(self) -> dict:
|
|
281
298
|
# TODO: Need to add train params support, track best epoch
|
|
@@ -42,6 +42,7 @@ class TabPFNMixModel(AbstractModel):
|
|
|
42
42
|
ag_key = "TABPFNMIX"
|
|
43
43
|
ag_name = "TabPFNMix"
|
|
44
44
|
ag_priority = 45
|
|
45
|
+
seed_name = "random_state"
|
|
45
46
|
|
|
46
47
|
weights_file_name = "model.pt"
|
|
47
48
|
|
|
@@ -123,6 +124,7 @@ class TabPFNMixModel(AbstractModel):
|
|
|
123
124
|
raise AssertionError(f"Max allowed classes for the model is {max_classes}, " f"but found {self.num_classes} classes.")
|
|
124
125
|
|
|
125
126
|
params = self._get_model_params()
|
|
127
|
+
random_state = params.pop(self.seed_name, self.default_random_seed)
|
|
126
128
|
sample_rows = ag_params.get("sample_rows", None)
|
|
127
129
|
sample_rows_val = ag_params.get("sample_rows_val", None)
|
|
128
130
|
max_rows = ag_params.get("max_rows", None)
|
|
@@ -133,11 +135,11 @@ class TabPFNMixModel(AbstractModel):
|
|
|
133
135
|
|
|
134
136
|
# TODO: Make sample_rows generic
|
|
135
137
|
if sample_rows is not None and isinstance(sample_rows, int) and len(X) > sample_rows:
|
|
136
|
-
X, y = self._subsample_data(X=X, y=y, num_rows=sample_rows)
|
|
138
|
+
X, y = self._subsample_data(X=X, y=y, num_rows=sample_rows, random_state=random_state)
|
|
137
139
|
|
|
138
140
|
# TODO: Make sample_rows generic
|
|
139
141
|
if X_val is not None and y_val is not None and sample_rows_val is not None and isinstance(sample_rows_val, int) and len(X_val) > sample_rows_val:
|
|
140
|
-
X_val, y_val = self._subsample_data(X=X_val, y=y_val, num_rows=sample_rows_val)
|
|
142
|
+
X_val, y_val = self._subsample_data(X=X_val, y=y_val, num_rows=sample_rows_val, random_state=random_state)
|
|
141
143
|
|
|
142
144
|
from ._internal.core.enums import Task
|
|
143
145
|
if self.problem_type in [REGRESSION, QUANTILE]:
|
|
@@ -178,7 +180,7 @@ class TabPFNMixModel(AbstractModel):
|
|
|
178
180
|
elif weights_path is not None:
|
|
179
181
|
logger.log(15, f'\tLoading pre-trained weights from file... (weights_path="{weights_path}")')
|
|
180
182
|
|
|
181
|
-
cfg = ConfigRun(hyperparams=params, task=task, device=device, seed=
|
|
183
|
+
cfg = ConfigRun(hyperparams=params, task=task, device=device, seed=random_state)
|
|
182
184
|
|
|
183
185
|
if cfg.hyperparams["max_epochs"] == 0 and cfg.hyperparams["n_ensembles"] != 1:
|
|
184
186
|
logger.log(
|
|
@@ -242,14 +244,14 @@ class TabPFNMixModel(AbstractModel):
|
|
|
242
244
|
return self
|
|
243
245
|
|
|
244
246
|
# TODO: Make this generic by creating a generic `preprocess_train` and putting this logic prior to `_preprocess`.
|
|
245
|
-
def _subsample_data(self, X: pd.DataFrame, y: pd.Series, num_rows: int) -> (pd.DataFrame, pd.Series):
|
|
247
|
+
def _subsample_data(self, X: pd.DataFrame, y: pd.Series, num_rows: int, random_state: int | None = 0) -> (pd.DataFrame, pd.Series):
|
|
246
248
|
num_rows_to_drop = len(X) - num_rows
|
|
247
249
|
X, _, y, _ = generate_train_test_split(
|
|
248
250
|
X=X,
|
|
249
251
|
y=y,
|
|
250
252
|
problem_type=self.problem_type,
|
|
251
253
|
test_size=num_rows_to_drop,
|
|
252
|
-
random_state=
|
|
254
|
+
random_state=random_state,
|
|
253
255
|
min_cls_count_train=1,
|
|
254
256
|
)
|
|
255
257
|
return X, y
|