autogluon.tabular 1.3.2b20250610__py3-none-any.whl → 1.4.1b20251214__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- autogluon/tabular/configs/config_helper.py +1 -1
- autogluon/tabular/configs/hyperparameter_configs.py +2 -265
- autogluon/tabular/configs/pipeline_presets.py +130 -0
- autogluon/tabular/configs/presets_configs.py +51 -26
- autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2023.py +0 -1
- autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2025.py +310 -0
- autogluon/tabular/models/__init__.py +6 -1
- autogluon/tabular/models/_utils/rapids_utils.py +1 -1
- autogluon/tabular/models/automm/automm_model.py +2 -0
- autogluon/tabular/models/automm/ft_transformer.py +4 -1
- autogluon/tabular/models/catboost/callbacks.py +3 -2
- autogluon/tabular/models/catboost/catboost_model.py +15 -9
- autogluon/tabular/models/catboost/catboost_utils.py +17 -3
- autogluon/tabular/models/ebm/__init__.py +0 -0
- autogluon/tabular/models/ebm/ebm_model.py +259 -0
- autogluon/tabular/models/ebm/hyperparameters/__init__.py +0 -0
- autogluon/tabular/models/ebm/hyperparameters/parameters.py +39 -0
- autogluon/tabular/models/ebm/hyperparameters/searchspaces.py +72 -0
- autogluon/tabular/models/fastainn/tabular_nn_fastai.py +7 -5
- autogluon/tabular/models/knn/knn_model.py +7 -3
- autogluon/tabular/models/lgb/lgb_model.py +60 -21
- autogluon/tabular/models/lr/lr_model.py +6 -1
- autogluon/tabular/models/lr/lr_preprocessing_utils.py +6 -7
- autogluon/tabular/models/lr/lr_rapids_model.py +45 -5
- autogluon/tabular/models/mitra/__init__.py +0 -0
- autogluon/tabular/models/mitra/_internal/__init__.py +1 -0
- autogluon/tabular/models/mitra/_internal/config/__init__.py +1 -0
- autogluon/tabular/models/mitra/_internal/config/config_pretrain.py +190 -0
- autogluon/tabular/models/mitra/_internal/config/config_run.py +32 -0
- autogluon/tabular/models/mitra/_internal/config/enums.py +162 -0
- autogluon/tabular/models/mitra/_internal/core/__init__.py +1 -0
- autogluon/tabular/models/mitra/_internal/core/callbacks.py +94 -0
- autogluon/tabular/models/mitra/_internal/core/get_loss.py +54 -0
- autogluon/tabular/models/mitra/_internal/core/get_optimizer.py +108 -0
- autogluon/tabular/models/mitra/_internal/core/get_scheduler.py +67 -0
- autogluon/tabular/models/mitra/_internal/core/prediction_metrics.py +132 -0
- autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py +373 -0
- autogluon/tabular/models/mitra/_internal/data/__init__.py +1 -0
- autogluon/tabular/models/mitra/_internal/data/collator.py +46 -0
- autogluon/tabular/models/mitra/_internal/data/dataset_finetune.py +136 -0
- autogluon/tabular/models/mitra/_internal/data/dataset_split.py +57 -0
- autogluon/tabular/models/mitra/_internal/data/preprocessor.py +420 -0
- autogluon/tabular/models/mitra/_internal/models/__init__.py +1 -0
- autogluon/tabular/models/mitra/_internal/models/base.py +21 -0
- autogluon/tabular/models/mitra/_internal/models/embedding.py +182 -0
- autogluon/tabular/models/mitra/_internal/models/tab2d.py +667 -0
- autogluon/tabular/models/mitra/_internal/utils/__init__.py +1 -0
- autogluon/tabular/models/mitra/_internal/utils/set_seed.py +15 -0
- autogluon/tabular/models/mitra/mitra_model.py +380 -0
- autogluon/tabular/models/mitra/sklearn_interface.py +494 -0
- autogluon/tabular/models/realmlp/__init__.py +0 -0
- autogluon/tabular/models/realmlp/realmlp_model.py +360 -0
- autogluon/tabular/models/rf/rf_model.py +11 -6
- autogluon/tabular/models/tabicl/__init__.py +0 -0
- autogluon/tabular/models/tabicl/tabicl_model.py +179 -0
- autogluon/tabular/models/tabm/__init__.py +0 -0
- autogluon/tabular/models/tabm/_tabm_internal.py +545 -0
- autogluon/tabular/models/tabm/rtdl_num_embeddings.py +810 -0
- autogluon/tabular/models/tabm/tabm_model.py +356 -0
- autogluon/tabular/models/tabm/tabm_reference.py +631 -0
- autogluon/tabular/models/tabpfnmix/tabpfnmix_model.py +13 -7
- autogluon/tabular/models/tabpfnv2/__init__.py +0 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/__init__.py +20 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/configs.py +40 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/scoring_utils.py +201 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_decision_tree_tabpfn.py +1464 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_random_forest_tabpfn.py +747 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_compat.py +863 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/utils.py +106 -0
- autogluon/tabular/models/tabpfnv2/tabpfnv2_model.py +388 -0
- autogluon/tabular/models/tabular_nn/hyperparameters/parameters.py +1 -3
- autogluon/tabular/models/tabular_nn/torch/tabular_nn_torch.py +5 -5
- autogluon/tabular/models/xgboost/xgboost_model.py +10 -3
- autogluon/tabular/predictor/predictor.py +147 -84
- autogluon/tabular/registry/_ag_model_registry.py +12 -2
- autogluon/tabular/testing/fit_helper.py +57 -27
- autogluon/tabular/testing/generate_datasets.py +7 -0
- autogluon/tabular/trainer/abstract_trainer.py +3 -1
- autogluon/tabular/trainer/model_presets/presets.py +10 -1
- autogluon/tabular/version.py +1 -1
- autogluon.tabular-1.4.1b20251214-py3.11-nspkg.pth +1 -0
- {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/METADATA +112 -57
- {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/RECORD +89 -40
- {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/WHEEL +1 -1
- autogluon/tabular/models/tabpfn/__init__.py +0 -1
- autogluon/tabular/models/tabpfn/tabpfn_model.py +0 -153
- autogluon.tabular-1.3.2b20250610-py3.9-nspkg.pth +0 -1
- {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info/licenses}/LICENSE +0 -0
- {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info/licenses}/NOTICE +0 -0
- {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/namespace_packages.txt +0 -0
- {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/top_level.txt +0 -0
- {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/zip-safe +0 -0
|
@@ -0,0 +1,380 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import logging
|
|
4
|
+
import os
|
|
5
|
+
from typing import List, Optional
|
|
6
|
+
|
|
7
|
+
import pandas as pd
|
|
8
|
+
|
|
9
|
+
from autogluon.common.utils.resource_utils import ResourceManager
|
|
10
|
+
from autogluon.core.models import AbstractModel
|
|
11
|
+
from autogluon.features.generators import LabelEncoderFeatureGenerator
|
|
12
|
+
from autogluon.tabular import __version__
|
|
13
|
+
|
|
14
|
+
# Module-level logger named after this module; all Mitra log records are routed through it.
logger = logging.getLogger(name=__name__)
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class MitraModel(AbstractModel):
    """
    Mitra is a tabular foundation model pre-trained purely on synthetic data with the goal
    of optimizing fine-tuning performance over in-context learning performance.
    Mitra was developed by the AutoGluon team @ AWS AI.

    Mitra's default hyperparameters outperform all methods for small datasets on TabArena-v0.1 (excluding ensembling): https://tabarena.ai

    Authors: Xiyuan Zhang, Danielle C. Maddix, Junming Yin, Nick Erickson, Abdul Fatir Ansari, Boran Han, Shuai Zhang, Leman Akoglu, Christos Faloutsos, Michael W. Mahoney, Cuixiong Hu, Huzefa Rangwala, George Karypis, Bernie Wang
    Blog Post: https://www.amazon.science/blog/mitra-mixed-synthetic-priors-for-enhancing-tabular-foundation-models
    License: Apache-2.0

    .. versionadded:: 1.4.0
    """

    ag_key = "MITRA"
    ag_name = "Mitra"
    weights_file_name = "model.pt"
    ag_priority = 55
    seed_name = "seed"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # Set to True by `save` when the trainers' models are written to `weights_file_name`
        # instead of the pickle; `load` restores them and resets this to False.
        self._weights_saved = False
        # LabelEncoderFeatureGenerator fit on the training data in `_preprocess(is_train=True)`.
        self._feature_generator = None

    @staticmethod
    def _get_default_device():
        """Return "cuda" if at least one CUDA GPU is visible to torch, otherwise "cpu"."""
        if ResourceManager.get_gpu_count_torch(cuda_only=True) > 0:
            logger.log(15, "Using CUDA GPU")
            return "cuda"
        else:
            return "cpu"

    def get_model_cls(self):
        """
        Return the sklearn-style Mitra estimator class for ``self.problem_type``.

        Raises
        ------
        AssertionError
            If ``problem_type`` is not binary, multiclass or regression.
        ImportError
            If the Mitra sklearn interface (or its dependencies) cannot be imported.
        """
        if self.problem_type in ["binary", "multiclass"]:
            from .sklearn_interface import MitraClassifier

            model_cls = MitraClassifier
        elif self.problem_type == "regression":
            from .sklearn_interface import MitraRegressor

            model_cls = MitraRegressor
        else:
            raise AssertionError(f"Unsupported problem_type: {self.problem_type}")
        return model_cls

    def _preprocess(self, X: pd.DataFrame, is_train: bool = False, **kwargs) -> pd.DataFrame:
        """
        Apply base preprocessing, then label-encode categorical features.

        When ``is_train`` is True a new LabelEncoderFeatureGenerator is fit on ``X``.
        Later calls reuse it, so encodings stay consistent between train and inference.
        """
        X = super()._preprocess(X, **kwargs)

        if is_train:
            # X will be the training data.
            self._feature_generator = LabelEncoderFeatureGenerator(verbosity=0)
            self._feature_generator.fit(X=X)

        # This converts categorical features to numeric via stateful label encoding.
        if self._feature_generator.features_in:
            X = X.copy()
            X[self._feature_generator.features_in] = self._feature_generator.transform(X=X)

        return X

    def _fit(
        self,
        X: pd.DataFrame,
        y: pd.Series,
        X_val: pd.DataFrame = None,
        y_val: pd.Series = None,
        time_limit: float = None,
        num_cpus: int = 1,
        num_gpus: float = 0,
        verbosity: int = 2,
        **kwargs,
    ):
        """
        Construct the Mitra estimator from the hyperparameters and fit it.

        Hyperparameter handling:
        - The model checkpoint is chosen by precedence: the task-specific ``hf_cls_model`` /
          ``hf_reg_model``, then ``hf_general_model``, then ``hf_model``. The winner is passed to
          the estimator as ``hf_model``; all alias keys are removed from the kwargs.
        - ``state_dict_classification`` / ``state_dict_regression`` are removed. The one matching
          the problem type is passed as ``state_dict``.
        - ``device`` defaults to "cpu" when ``num_gpus == 0``, otherwise to the best available device.
        - ``verbose`` defaults to ``verbosity >= 3``.

        torch's thread count is set to ``num_cpus`` for the duration of the call. It is restored
        afterwards, even if fitting raises.
        """
        try:
            model_cls = self.get_model_cls()
            import torch
        except ImportError:
            logger.log(
                40,
                f"\tFailed to import Mitra! To use the Mitra model, "
                f"do: `pip install autogluon.tabular[mitra]=={__version__}`.",
            )
            raise

        # Limit torch's thread pool to the allotted CPUs; remember the old value to restore it.
        torch_threads_og = None
        if num_cpus is not None and isinstance(num_cpus, (int, float)):
            # torch.set_num_threads only accepts a positive int.
            num_threads = max(1, int(num_cpus))
            current_threads = torch.get_num_threads()
            if current_threads != num_threads:
                torch_threads_og = current_threads
                torch.set_num_threads(num_threads)

        try:
            hyp = self._get_model_params()

            # Pop every checkpoint alias unconditionally so none of them leaks into the
            # estimator constructor as an unexpected keyword argument.
            hf_cls_model = hyp.pop("hf_cls_model", None)
            hf_reg_model = hyp.pop("hf_reg_model", None)
            hf_general_model = hyp.pop("hf_general_model", None)
            hf_model_fallback = hyp.pop("hf_model", None)
            if self.problem_type in ["binary", "multiclass"]:
                hf_model = hf_cls_model
            elif self.problem_type == "regression":
                hf_model = hf_reg_model
            else:
                raise AssertionError(f"Unsupported problem_type: {self.problem_type}")
            if hf_model is None:
                hf_model = hf_general_model
            if hf_model is None:
                hf_model = hf_model_fallback
            if hf_model is not None:
                logger.log(30, f"\tCustom hf_model specified: {hf_model}")
                hyp["hf_model"] = hf_model

            if hyp.get("device", None) is None:
                if num_gpus == 0:
                    hyp["device"] = "cpu"
                else:
                    hyp["device"] = self._get_default_device()

            if hyp["device"] == "cpu" and hyp.get("fine_tune", True):
                logger.log(
                    30,
                    f"\tWarning: Attempting to fine-tune Mitra on CPU. This will be very slow. "
                    f"We strongly recommend using a GPU instance to fine-tune Mitra.",
                )

            # Task-specific state dicts: always removed, forwarded as `state_dict` only if they match.
            if "state_dict_classification" in hyp:
                state_dict_classification = hyp.pop("state_dict_classification")
                if self.problem_type in ["binary", "multiclass"]:
                    hyp["state_dict"] = state_dict_classification
            if "state_dict_regression" in hyp:
                state_dict_regression = hyp.pop("state_dict_regression")
                if self.problem_type in ["regression"]:
                    hyp["state_dict"] = state_dict_regression

            if "verbose" not in hyp:
                hyp["verbose"] = verbosity >= 3

            self.model = model_cls(**hyp)

            X = self.preprocess(X, is_train=True)
            if X_val is not None:
                X_val = self.preprocess(X_val)

            self.model = self.model.fit(
                X=X,
                y=y,
                X_val=X_val,
                y_val=y_val,
                time_limit=time_limit,
            )
        finally:
            # Restore the process-wide torch thread setting even if anything above raised.
            if torch_threads_og is not None:
                torch.set_num_threads(torch_threads_og)

    def _set_default_params(self):
        default_params = {
            "n_estimators": 1,
        }
        for param, val in default_params.items():
            self._set_default_param_value(param, val)

    def _get_default_auxiliary_params(self) -> dict:
        # Data-size limits for Mitra: at most 10k rows, 500 features and 10 classes.
        default_auxiliary_params = super()._get_default_auxiliary_params()
        default_auxiliary_params.update(
            {
                "max_rows": 10000,
                "max_features": 500,
                "max_classes": 10,
            }
        )
        return default_auxiliary_params

    @property
    def weights_path(self) -> str:
        """Path of the separately saved trainer models: ``<self.path>/model.pt``."""
        return os.path.join(self.path, self.weights_file_name)

    def save(self, path: str = None, verbose=True) -> str:
        """
        Save the model, writing each trainer's ``model`` to ``weights_path`` via ``torch.save``.

        The trainers' models are detached before pickling and reattached afterwards. This
        happens even if saving raises. Each trainer's ``checkpoint``, ``optimizer``,
        ``scheduler_warmup`` and ``scheduler_reduce_on_plateau`` are cleared to shrink the
        artifact and are NOT restored.
        """
        _model_weights_list = None
        if self.model is not None:
            _model_weights_list = []
            for trainer in self.model.trainers:
                _model_weights_list.append(trainer.model)
                trainer.checkpoint = None
                trainer.model = None
                trainer.optimizer = None
                trainer.scheduler_warmup = None
                trainer.scheduler_reduce_on_plateau = None
            self._weights_saved = True
        try:
            path = super().save(path=path, verbose=verbose)
            if _model_weights_list is not None:
                import torch

                os.makedirs(self.path, exist_ok=True)
                torch.save(_model_weights_list, self.weights_path)
        finally:
            # Reattach the detached models so the in-memory object stays usable.
            if _model_weights_list is not None:
                for trainer, trainer_model in zip(self.model.trainers, _model_weights_list):
                    trainer.model = trainer_model
        return path

    @classmethod
    def load(cls, path: str, reset_paths=False, verbose=True):
        """Load the pickled model and, if needed, reattach trainer models from ``weights_path``."""
        model: MitraModel = super().load(path=path, reset_paths=reset_paths, verbose=verbose)

        if model._weights_saved:
            import torch

            # weights_only=False unpickles arbitrary objects: only load trusted artifacts.
            model_weights_list = torch.load(model.weights_path, weights_only=False)  # nosec B614
            for i, trainer in enumerate(model.model.trainers):
                trainer.model = model_weights_list[i]
        model._weights_saved = False
        return model

    @classmethod
    def download_weights(cls, repo_id: str):
        """
        Download weights for Mitra from HuggingFace from `repo_id`.
        Requires an internet connection.
        """
        from huggingface_hub import hf_hub_download

        hf_hub_download(repo_id=repo_id, filename="config.json")
        hf_hub_download(repo_id=repo_id, filename="model.safetensors")

    @classmethod
    def download_default_weights(cls):
        """
        Download default weights for Mitra from HuggingFace.
        Includes both classifier and regressor weights.

        This is useful to call when building a docker image to avoid having to download Mitra weights for each instance.
        This is also useful for benchmarking as a first sanity check
        to avoid HuggingFace potentially blocking the download.

        Requires an internet connection.
        """
        cls.download_weights(repo_id="autogluon/mitra-classifier")
        cls.download_weights(repo_id="autogluon/mitra-regressor")

    @classmethod
    def supported_problem_types(cls) -> Optional[List[str]]:
        return ["binary", "multiclass", "regression"]

    @classmethod
    def _get_default_ag_args_ensemble(cls, **kwargs) -> dict:
        default_ag_args_ensemble = super()._get_default_ag_args_ensemble(**kwargs)
        # FIXME: Test if it works with parallel, need to enable n_cpus support
        extra_ag_args_ensemble = {
            "fold_fitting_strategy": "sequential_local",  # FIXME: Comment out after debugging for large speedup
        }
        default_ag_args_ensemble.update(extra_ag_args_ensemble)
        return default_ag_args_ensemble

    def _get_default_resources(self) -> tuple[int, int]:
        """Return ``(num_cpus, num_gpus)``: all physical cores and at most one CUDA GPU."""
        # Use only physical cores for better performance based on benchmarks
        num_cpus = ResourceManager.get_cpu_count(only_physical_cores=True)

        num_gpus = min(1, ResourceManager.get_gpu_count_torch(cuda_only=True))

        return num_cpus, num_gpus

    def _estimate_memory_usage(self, X: pd.DataFrame, **kwargs) -> int:
        return self.estimate_memory_usage_static(
            X=X, problem_type=self.problem_type, num_classes=self.num_classes, **kwargs
        )

    @classmethod
    def _estimate_memory_usage_static(
        cls,
        *,
        X: pd.DataFrame,
        **kwargs,
    ) -> int:
        """Return 0.9 x the largest of the four per-mode memory estimates, in bytes."""
        # Multiply by 0.9 as currently this is overly safe
        return int(
            0.9
            * max(
                cls._estimate_memory_usage_static_cpu_icl(X=X, **kwargs),
                cls._estimate_memory_usage_static_cpu_ft_icl(X=X, **kwargs),
                cls._estimate_memory_usage_static_gpu_cpu(X=X, **kwargs),
                cls._estimate_memory_usage_static_gpu_gpu(X=X, **kwargs),
            )
        )

    @classmethod
    def _estimate_memory_usage_static_cpu_icl(
        cls,
        *,
        X: pd.DataFrame,
        **kwargs,
    ) -> int:
        """Estimated CPU memory (bytes) for in-context learning on CPU; formulas are in KB."""
        rows, features = X.shape[0], X.shape[1]

        # For very small datasets (rows * features < 100), use a smaller linear estimate
        if rows * features < 100:  # Small dataset threshold
            cpu_memory_kb = 1.3 * (100 * rows * features + 1000000)  # 1GB base + linear scaling
        else:
            # Fitted formula for larger datasets
            cpu_memory_kb = 1.3 * (
                0.001748 * (rows**2) * features + 0.001206 * rows * (features**2) + 10.3482 * rows * features + 6409698
            )
        return int(cpu_memory_kb * 1e3)

    @classmethod
    def _estimate_memory_usage_static_cpu_ft_icl(
        cls,
        *,
        X: pd.DataFrame,
        **kwargs,
    ) -> int:
        """Estimated CPU memory (bytes) for fine-tuning + in-context learning on CPU; formulas are in KB."""
        rows, features = X.shape[0], X.shape[1]

        # For very small datasets (rows * features < 100), use a smaller linear estimate
        if rows * features < 100:  # Small dataset threshold
            cpu_memory_kb = 1.3 * (200 * rows * features + 2000000)  # 2GB base + linear scaling
        else:
            # Fitted formula for larger datasets
            cpu_memory_kb = 1.3 * (
                0.001 * (rows**2) * features + 0.004541 * rows * (features**2) + 46.2974 * rows * features + 5605681
            )
        return int(cpu_memory_kb * 1e3)

    @classmethod
    def _estimate_memory_usage_static_gpu_cpu(
        cls,
        *,
        X: pd.DataFrame,
        **kwargs,
    ) -> int:
        """Estimated host (CPU) memory in bytes when running on GPU: fixed 2.5GB or 5GB."""
        rows, features = X.shape[0], X.shape[1]

        if rows * features < 100:  # Small dataset threshold
            return int(2.5 * 1e9)  # 2.5GB for small datasets
        else:
            return int(5 * 1e9)  # 5GB for larger datasets

    @classmethod
    def _estimate_memory_usage_static_gpu_gpu(
        cls,
        *,
        X: pd.DataFrame,
        **kwargs,
    ) -> int:
        """Estimated GPU memory (bytes) when running on GPU; formulas are in MB."""
        rows, features = X.shape[0], X.shape[1]

        # For very small datasets (rows * features < 100), use a simple linear estimate
        if rows * features < 100:  # Small dataset threshold
            gpu_memory_mb = 1.3 * (10 * rows * features + 2000)  # 2GB base + linear scaling
        else:
            # Fitted formula for larger datasets
            gpu_memory_mb = 1.3 * (0.05676 * rows * features + 3901)
        return int(gpu_memory_mb * 1e6)

    @classmethod
    def _class_tags(cls) -> dict:
        return {
            "can_estimate_memory_usage_static": True,
        }

    def _more_tags(self) -> dict:
        tags = {"can_refit_full": True}
        return tags