autogluon.tabular 1.3.2b20250610__py3-none-any.whl → 1.4.1b20251214__py3-none-any.whl
This diff compares the contents of two publicly available package versions as released to their public registry. It is provided for informational purposes only.
- autogluon/tabular/configs/config_helper.py +1 -1
- autogluon/tabular/configs/hyperparameter_configs.py +2 -265
- autogluon/tabular/configs/pipeline_presets.py +130 -0
- autogluon/tabular/configs/presets_configs.py +51 -26
- autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2023.py +0 -1
- autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2025.py +310 -0
- autogluon/tabular/models/__init__.py +6 -1
- autogluon/tabular/models/_utils/rapids_utils.py +1 -1
- autogluon/tabular/models/automm/automm_model.py +2 -0
- autogluon/tabular/models/automm/ft_transformer.py +4 -1
- autogluon/tabular/models/catboost/callbacks.py +3 -2
- autogluon/tabular/models/catboost/catboost_model.py +15 -9
- autogluon/tabular/models/catboost/catboost_utils.py +17 -3
- autogluon/tabular/models/ebm/__init__.py +0 -0
- autogluon/tabular/models/ebm/ebm_model.py +259 -0
- autogluon/tabular/models/ebm/hyperparameters/__init__.py +0 -0
- autogluon/tabular/models/ebm/hyperparameters/parameters.py +39 -0
- autogluon/tabular/models/ebm/hyperparameters/searchspaces.py +72 -0
- autogluon/tabular/models/fastainn/tabular_nn_fastai.py +7 -5
- autogluon/tabular/models/knn/knn_model.py +7 -3
- autogluon/tabular/models/lgb/lgb_model.py +60 -21
- autogluon/tabular/models/lr/lr_model.py +6 -1
- autogluon/tabular/models/lr/lr_preprocessing_utils.py +6 -7
- autogluon/tabular/models/lr/lr_rapids_model.py +45 -5
- autogluon/tabular/models/mitra/__init__.py +0 -0
- autogluon/tabular/models/mitra/_internal/__init__.py +1 -0
- autogluon/tabular/models/mitra/_internal/config/__init__.py +1 -0
- autogluon/tabular/models/mitra/_internal/config/config_pretrain.py +190 -0
- autogluon/tabular/models/mitra/_internal/config/config_run.py +32 -0
- autogluon/tabular/models/mitra/_internal/config/enums.py +162 -0
- autogluon/tabular/models/mitra/_internal/core/__init__.py +1 -0
- autogluon/tabular/models/mitra/_internal/core/callbacks.py +94 -0
- autogluon/tabular/models/mitra/_internal/core/get_loss.py +54 -0
- autogluon/tabular/models/mitra/_internal/core/get_optimizer.py +108 -0
- autogluon/tabular/models/mitra/_internal/core/get_scheduler.py +67 -0
- autogluon/tabular/models/mitra/_internal/core/prediction_metrics.py +132 -0
- autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py +373 -0
- autogluon/tabular/models/mitra/_internal/data/__init__.py +1 -0
- autogluon/tabular/models/mitra/_internal/data/collator.py +46 -0
- autogluon/tabular/models/mitra/_internal/data/dataset_finetune.py +136 -0
- autogluon/tabular/models/mitra/_internal/data/dataset_split.py +57 -0
- autogluon/tabular/models/mitra/_internal/data/preprocessor.py +420 -0
- autogluon/tabular/models/mitra/_internal/models/__init__.py +1 -0
- autogluon/tabular/models/mitra/_internal/models/base.py +21 -0
- autogluon/tabular/models/mitra/_internal/models/embedding.py +182 -0
- autogluon/tabular/models/mitra/_internal/models/tab2d.py +667 -0
- autogluon/tabular/models/mitra/_internal/utils/__init__.py +1 -0
- autogluon/tabular/models/mitra/_internal/utils/set_seed.py +15 -0
- autogluon/tabular/models/mitra/mitra_model.py +380 -0
- autogluon/tabular/models/mitra/sklearn_interface.py +494 -0
- autogluon/tabular/models/realmlp/__init__.py +0 -0
- autogluon/tabular/models/realmlp/realmlp_model.py +360 -0
- autogluon/tabular/models/rf/rf_model.py +11 -6
- autogluon/tabular/models/tabicl/__init__.py +0 -0
- autogluon/tabular/models/tabicl/tabicl_model.py +179 -0
- autogluon/tabular/models/tabm/__init__.py +0 -0
- autogluon/tabular/models/tabm/_tabm_internal.py +545 -0
- autogluon/tabular/models/tabm/rtdl_num_embeddings.py +810 -0
- autogluon/tabular/models/tabm/tabm_model.py +356 -0
- autogluon/tabular/models/tabm/tabm_reference.py +631 -0
- autogluon/tabular/models/tabpfnmix/tabpfnmix_model.py +13 -7
- autogluon/tabular/models/tabpfnv2/__init__.py +0 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/__init__.py +20 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/configs.py +40 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/scoring_utils.py +201 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_decision_tree_tabpfn.py +1464 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_random_forest_tabpfn.py +747 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_compat.py +863 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/utils.py +106 -0
- autogluon/tabular/models/tabpfnv2/tabpfnv2_model.py +388 -0
- autogluon/tabular/models/tabular_nn/hyperparameters/parameters.py +1 -3
- autogluon/tabular/models/tabular_nn/torch/tabular_nn_torch.py +5 -5
- autogluon/tabular/models/xgboost/xgboost_model.py +10 -3
- autogluon/tabular/predictor/predictor.py +147 -84
- autogluon/tabular/registry/_ag_model_registry.py +12 -2
- autogluon/tabular/testing/fit_helper.py +57 -27
- autogluon/tabular/testing/generate_datasets.py +7 -0
- autogluon/tabular/trainer/abstract_trainer.py +3 -1
- autogluon/tabular/trainer/model_presets/presets.py +10 -1
- autogluon/tabular/version.py +1 -1
- autogluon.tabular-1.4.1b20251214-py3.11-nspkg.pth +1 -0
- {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/METADATA +112 -57
- {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/RECORD +89 -40
- {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/WHEEL +1 -1
- autogluon/tabular/models/tabpfn/__init__.py +0 -1
- autogluon/tabular/models/tabpfn/tabpfn_model.py +0 -153
- autogluon.tabular-1.3.2b20250610-py3.9-nspkg.pth +0 -1
- {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info/licenses}/LICENSE +0 -0
- {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info/licenses}/NOTICE +0 -0
- {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/namespace_packages.txt +0 -0
- {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/top_level.txt +0 -0
- {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/zip-safe +0 -0
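This release adds six new model families (EBM, Mitra, RealMLP, TabICL, TabM, TabPFNv2) and removes the legacy TabPFN model. The new TabM model file, autogluon/tabular/models/tabm/tabm_model.py, is reproduced in full below. As a minimal usage sketch (not part of the diff; the dataset URL and the empty hyperparameter dict are illustrative), the model can be requested through the standard TabularPredictor API via the "TABM" key it registers:

from autogluon.tabular import TabularDataset, TabularPredictor

# Sketch: train only the new TabM model via its registered "TABM" key.
# Assumes the optional dependencies are installed: pip install autogluon.tabular[tabm]
train_data = TabularDataset("https://autogluon.s3.amazonaws.com/datasets/Inc/train.csv")
predictor = TabularPredictor(label="class").fit(
    train_data,
    hyperparameters={"TABM": {}},  # empty dict -> TabM's default hyperparameters
)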
autogluon/tabular/models/tabm/tabm_model.py (new file, +356 lines)
@@ -0,0 +1,356 @@
+"""
+Code Adapted from TabArena: https://github.com/autogluon/tabrepo/blob/main/tabrepo/benchmark/models/ag/tabm/tabm_model.py
+Note: This is a custom implementation of TabM based on TabArena. Because the AutoGluon 1.4 release occurred at nearly
+the same time as TabM became available on PyPi, we chose to use TabArena's implementation
+for the AutoGluon 1.4 release as it has already been benchmarked.
+
+Partially adapted from pytabkit's TabM implementation.
+"""
+
+from __future__ import annotations
+
+import logging
+import time
+
+import pandas as pd
+
+from autogluon.common.utils.resource_utils import ResourceManager
+from autogluon.core.models import AbstractModel
+from autogluon.tabular import __version__
+
+logger = logging.getLogger(__name__)
+
+
+class TabMModel(AbstractModel):
+    """
+    TabM is an efficient ensemble of MLPs that is trained simultaneously with mostly shared parameters.
+
+    TabM is one of the top performing methods overall on TabArena-v0.1: https://tabarena.ai
+
+    Paper: TabM: Advancing Tabular Deep Learning with Parameter-Efficient Ensembling
+    Authors: Yury Gorishniy, Akim Kotelnikov, Artem Babenko
+    Codebase: https://github.com/yandex-research/tabm
+    License: Apache-2.0
+
+    Partially adapted from pytabkit's TabM implementation.
+
+    .. versionadded:: 1.4.0
+    """
+    ag_key = "TABM"
+    ag_name = "TabM"
+    ag_priority = 85
+    seed_name = "random_state"
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self._imputer = None
+        self._features_to_impute = None
+        self._features_to_keep = None
+        self._indicator_columns = None
+        self._features_bool = None
+        self._bool_to_cat = None
+        self.device = None
+
+    def _fit(
+        self,
+        X: pd.DataFrame,
+        y: pd.Series,
+        X_val: pd.DataFrame = None,
+        y_val: pd.Series = None,
+        time_limit: float | None = None,
+        num_cpus: int = 1,
+        num_gpus: float = 0,
+        **kwargs,
+    ):
+        start_time = time.time()
+
+        try:
+            # imports various dependencies such as torch
+            from torch.cuda import is_available
+
+            from ._tabm_internal import TabMImplementation
+        except ImportError as err:
+            logger.log(
+                40,
+                f"\tFailed to import tabm! To use the TabM model, "
+                f"do: `pip install autogluon.tabular[tabm]=={__version__}`.",
+            )
+            raise err
+
+        device = "cpu" if num_gpus == 0 else "cuda"
+        if (device == "cuda") and (not is_available()):
+            # FIXME: warn instead and switch to CPU.
+            raise AssertionError(
+                "Fit specified to use GPU, but CUDA is not available on this machine. "
+                "Please switch to CPU usage instead.",
+            )
+
+        if X_val is None:
+            from autogluon.core.utils import generate_train_test_split
+
+            X_train, X_val, y_train, y_val = generate_train_test_split(
+                X=X,
+                y=y,
+                problem_type=self.problem_type,
+                test_size=0.2,
+                random_state=0,
+            )
+
+        hyp = self._get_model_params()
+        bool_to_cat = hyp.pop("bool_to_cat", True)
+
+        X = self.preprocess(X, is_train=True, bool_to_cat=bool_to_cat)
+        if X_val is not None:
+            X_val = self.preprocess(X_val)
+
+        self.model = TabMImplementation(
+            n_threads=num_cpus,
+            device=device,
+            problem_type=self.problem_type,
+            early_stopping_metric=self.stopping_metric,
+            **hyp,
+        )
+
+        self.model.fit(
+            X_train=X,
+            y_train=y,
+            X_val=X_val,
+            y_val=y_val,
+            cat_col_names=X.select_dtypes(include="category").columns.tolist(),
+            time_to_fit_in_seconds=time_limit - (time.time() - start_time) if time_limit is not None else None,
+        )
+
+    # FIXME: bool_to_cat is a hack: Maybe move to abstract model?
+    def _preprocess(
+        self,
+        X: pd.DataFrame,
+        is_train: bool = False,
+        bool_to_cat: bool = False,
+        **kwargs,
+    ) -> pd.DataFrame:
+        """Imputes missing values via the mean and adds indicator columns for numerical features.
+        Converts indicator columns to categorical features to avoid them being treated as numerical by RealMLP.
+        """
+        X = super()._preprocess(X, **kwargs)
+
+        if is_train:
+            self._bool_to_cat = bool_to_cat
+            self._features_bool = self._feature_metadata.get_features(required_special_types=["bool"])
+        if self._bool_to_cat and self._features_bool:
+            # FIXME: Use CategoryFeatureGenerator? Or tell the model which is category
+            X = X.copy(deep=True)
+            X[self._features_bool] = X[self._features_bool].astype("category")
+
+        return X
+
+    def save(self, path: str = None, verbose=True) -> str:
+        """
+        Need to set device to CPU to be able to load on a non-GPU environment
+        """
+        import torch
+
+        # Save on CPU to ensure the model can be loaded without GPU
+        if self.model is not None:
+            self.device = self.model.device_
+            device_cpu = torch.device("cpu")
+            self.model.model_ = self.model.model_.to(device_cpu)
+            self.model.device_ = device_cpu
+        path = super().save(path=path, verbose=verbose)
+        # Put the model back to the device after the save
+        if self.model is not None:
+            self.model.model_.to(self.device)
+            self.model.device_ = self.device
+
+        return path
+
+    @classmethod
+    def load(cls, path: str, reset_paths=True, verbose=True):
+        """
+        Loads the model from disk to memory.
+        The loaded model will be on the same device it was trained on (cuda/mps);
+        if the device is not available (trained on GPU, deployed on CPU), then `cpu` will be used.
+
+        Parameters
+        ----------
+        path : str
+            Path to the saved model, minus the file name.
+            This should generally be a directory path ending with a '/' character (or appropriate path separator value depending on OS).
+            The model file is typically located in os.path.join(path, cls.model_file_name).
+        reset_paths : bool, default True
+            Whether to reset the self.path value of the loaded model to be equal to path.
+            It is highly recommended to keep this value as True unless accessing the original self.path value is important.
+            If False, the actual valid path and self.path may differ, leading to strange behaviour and potential exceptions if the model needs to load any other files at a later time.
+        verbose : bool, default True
+            Whether to log the location of the loaded file.
+
+        Returns
+        -------
+        model : cls
+            Loaded model object.
+        """
+        import torch
+
+        model: TabMModel = super().load(path=path, reset_paths=reset_paths, verbose=verbose)
+
+        # Put the model on the same device it was trained on (GPU/MPS) if it is available; otherwise use CPU
+        if model.model is not None:
+            original_device_type = model.device.type
+            if "cuda" in original_device_type:
+                # cuda: nvidia GPU
+                device = torch.device(original_device_type if torch.cuda.is_available() else "cpu")
+            elif "mps" in original_device_type:
+                # mps: Apple Silicon
+                device = torch.device(original_device_type if torch.backends.mps.is_available() else "cpu")
+            else:
+                device = torch.device(original_device_type)
+
+            if verbose and (original_device_type != device.type):
+                logger.log(15, f"Model is trained on {original_device_type}, but the device is not available - loading on {device.type}")
+
+            model.set_device(device=device)
+
+        return model
+
+    def set_device(self, device):
+        self.device = device
+        if self.model is not None:
+            self.model.device_ = device
+            if self.model.model_ is not None:
+                self.model.model_ = self.model.model_.to(device)
+
+    @classmethod
+    def supported_problem_types(cls) -> list[str] | None:
+        return ["binary", "multiclass", "regression"]
+
+    def _get_default_stopping_metric(self):
+        return self.eval_metric
+
+    def _get_default_resources(self) -> tuple[int, int]:
+        # Use only physical cores for better performance based on benchmarks
+        num_cpus = ResourceManager.get_cpu_count(only_physical_cores=True)
+
+        num_gpus = min(1, ResourceManager.get_gpu_count_torch(cuda_only=True))
+        return num_cpus, num_gpus
+
+    def _estimate_memory_usage(self, X: pd.DataFrame, **kwargs) -> int:
+        hyperparameters = self._get_model_params()
+        return self.estimate_memory_usage_static(
+            X=X,
+            problem_type=self.problem_type,
+            num_classes=self.num_classes,
+            hyperparameters=hyperparameters,
+            **kwargs,
+        )
+
+    @classmethod
+    def _estimate_memory_usage_static(
+        cls,
+        *,
+        X: pd.DataFrame,
+        hyperparameters: dict = None,
+        num_classes: int | None = 1,
+        **kwargs,
+    ) -> int:
+        """
+        Heuristic memory estimate that correlates strongly with RealMLP
+        """
+        if num_classes is None:
+            num_classes = 1
+        if hyperparameters is None:
+            hyperparameters = {}
+
+        cat_sizes = []
+        for col in X.select_dtypes(include=["category", "object"]):
+            if isinstance(X[col], pd.CategoricalDtype):
+                # Use .cat.codes for category dtype
+                unique_codes = X[col].cat.codes.unique()
+            else:
+                # For object dtype, treat unique strings as codes
+                unique_codes = X[col].astype("category").cat.codes.unique()
+            cat_sizes.append(len(unique_codes))
+
+        n_numerical = len(X.select_dtypes(include=["number"]).columns)
+
+        # TODO: This estimates very high memory usage,
+        #  we probably need to adjust batch size automatically to compensate
+        mem_estimate_bytes = cls._estimate_tabm_ram(
+            hyperparameters=hyperparameters,
+            n_numerical=n_numerical,
+            cat_sizes=cat_sizes,
+            n_classes=num_classes,
+            n_samples=len(X),
+        )
+
+        return mem_estimate_bytes
+
+    @classmethod
+    def _estimate_tabm_ram(
+        cls,
+        hyperparameters: dict,
+        n_numerical: int,
+        cat_sizes: list[int],
+        n_classes: int,
+        n_samples: int,
+    ) -> int:
+        num_emb_n_bins = hyperparameters.get("num_emb_n_bins", 48)
+        d_embedding = hyperparameters.get("d_embedding", 16)
+        d_block = hyperparameters.get("d_block", 512)
+        # not completely sure if this is hidden blocks or all blocks, taking the safe option below
+        n_blocks = hyperparameters.get("n_blocks", "auto")
+        if isinstance(n_blocks, str) and n_blocks == "auto":
+            n_blocks = 3
+        batch_size = hyperparameters.get("batch_size", "auto")
+        if isinstance(batch_size, str) and batch_size == "auto":
+            batch_size = cls.get_tabm_auto_batch_size(n_samples=n_samples)
+        tabm_k = hyperparameters.get("tabm_k", 32)
+        predict_batch_size = hyperparameters.get("eval_batch_size", 1024)
+
+        # not completely sure
+        n_params_num_emb = n_numerical * (num_emb_n_bins + 1) * d_embedding
+        n_params_mlp = (n_numerical + sum(cat_sizes)) * d_embedding * (d_block + tabm_k) \
+                       + (n_blocks - 1) * d_block ** 2 \
+                       + n_blocks * d_block + d_block * (1 + max(1, n_classes))
+        # 4 bytes per float, up to 5 copies of parameters (1 standard, 1 .grad, 2 adam, 1 best_epoch)
+        mem_params = 4 * 5 * (n_params_num_emb + n_params_mlp)
+
+        # compute number of floats in forward pass (per batch element)
+        # todo: numerical embedding layer (not sure if this is entirely correct)
+        n_floats_forward = n_numerical * (num_emb_n_bins + d_embedding)
+        # before and after scale
+        n_floats_forward += 2 * (sum(cat_sizes) + n_numerical * d_embedding)
+        # 2 for pre-act, post-act
+        n_floats_forward += n_blocks * 2 * d_block + 2 * max(1, n_classes)
+        # 2 for forward and backward, 4 bytes per float
+        mem_forward_backward = 4 * max(batch_size * 2, predict_batch_size) * n_floats_forward * tabm_k
+        # * 8 is pessimistic for the long tensors in the forward pass, 4 would probably suffice
+
+        mem_ds = n_samples * (4 * n_numerical + 8 * len(cat_sizes))
+
+        # some safety constants and offsets (the 5 is probably excessive)
+        mem_total = 5 * mem_ds + 1.2 * mem_forward_backward + 1.2 * mem_params + 0.3 * (1024 ** 3)
+
+        return mem_total
+
+    @classmethod
+    def get_tabm_auto_batch_size(cls, n_samples: int) -> int:
+        # by Yury Gorishniy, inferred from the choices in the TabM paper.
+        if n_samples < 2_800:
+            return 32
+        if n_samples < 4_500:
+            return 64
+        if n_samples < 6_400:
+            return 128
+        if n_samples < 32_000:
+            return 256
+        if n_samples < 108_000:
+            return 512
+        return 1024
+
+    @classmethod
+    def _class_tags(cls):
+        return {"can_estimate_memory_usage_static": True}
+
+    def _more_tags(self) -> dict:
+        # TODO: Need to add train params support, track best epoch
+        #  How to force stopping at a specific epoch?
+        return {"can_refit_full": False}
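For orientation, a worked sketch of the two heuristics at the end of the file (the numbers are illustrative; _estimate_tabm_ram is an internal classmethod, called here only to show the arithmetic, and the import path assumes the new wheel is installed): a hypothetical dataset of 10,000 rows with 20 numerical features, three categorical features of cardinalities 4, 7, and 12, and a binary target.

from autogluon.tabular.models.tabm.tabm_model import TabMModel

# "auto" batch size: 10_000 falls in the [6_400, 32_000) bucket -> 256
batch_size = TabMModel.get_tabm_auto_batch_size(n_samples=10_000)

# RAM heuristic with all-default hyperparameters ({} -> the .get(...) defaults)
mem_bytes = TabMModel._estimate_tabm_ram(
    hyperparameters={},
    n_numerical=20,
    cat_sizes=[4, 7, 12],
    n_classes=2,
    n_samples=10_000,
)
print(batch_size)                     # 256
print(round(mem_bytes / 1024**3, 2))  # ~1.06 GiB: dominated by the forward/backward term plus the fixed 0.3 GiB offset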