autogluon.tabular 1.3.2b20250711__py3-none-any.whl → 1.3.2b20250712__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- autogluon/tabular/models/__init__.py +1 -1
- autogluon/tabular/models/tabpfnv2/__init__.py +0 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/__init__.py +20 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/configs.py +40 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/scoring_utils.py +201 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_decision_tree_tabpfn.py +1464 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_random_forest_tabpfn.py +747 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_compat.py +863 -0
- autogluon/tabular/models/tabpfnv2/rfpfn/utils.py +106 -0
- autogluon/tabular/models/tabpfnv2/tabpfnv2_model.py +376 -0
- autogluon/tabular/registry/_ag_model_registry.py +2 -2
- autogluon/tabular/version.py +1 -1
- {autogluon.tabular-1.3.2b20250711.dist-info → autogluon.tabular-1.3.2b20250712.dist-info}/METADATA +13 -15
- {autogluon.tabular-1.3.2b20250711.dist-info → autogluon.tabular-1.3.2b20250712.dist-info}/RECORD +21 -14
- autogluon/tabular/models/tabpfn/__init__.py +0 -1
- autogluon/tabular/models/tabpfn/tabpfn_model.py +0 -153
- /autogluon.tabular-1.3.2b20250711-py3.9-nspkg.pth → /autogluon.tabular-1.3.2b20250712-py3.9-nspkg.pth +0 -0
- {autogluon.tabular-1.3.2b20250711.dist-info → autogluon.tabular-1.3.2b20250712.dist-info}/LICENSE +0 -0
- {autogluon.tabular-1.3.2b20250711.dist-info → autogluon.tabular-1.3.2b20250712.dist-info}/NOTICE +0 -0
- {autogluon.tabular-1.3.2b20250711.dist-info → autogluon.tabular-1.3.2b20250712.dist-info}/WHEEL +0 -0
- {autogluon.tabular-1.3.2b20250711.dist-info → autogluon.tabular-1.3.2b20250712.dist-info}/namespace_packages.txt +0 -0
- {autogluon.tabular-1.3.2b20250711.dist-info → autogluon.tabular-1.3.2b20250712.dist-info}/top_level.txt +0 -0
- {autogluon.tabular-1.3.2b20250711.dist-info → autogluon.tabular-1.3.2b20250712.dist-info}/zip-safe +0 -0
autogluon/tabular/models/tabpfnv2/rfpfn/utils.py
ADDED
@@ -0,0 +1,106 @@
+"""Copyright 2023.
+
+Author: Lukas Schweizer <schweizer.lukas@web.de>
+"""
+
+# Copyright (c) Prior Labs GmbH 2025.
+# Licensed under the Apache License, Version 2.0
+
+from __future__ import annotations
+
+import numpy as np
+import pandas as pd
+import torch
+# Type checking imports
+from typing import TYPE_CHECKING
+
+if TYPE_CHECKING:
+    from numpy.typing import NDArray
+
+
+def preprocess_data(
+    data,
+    nan_values=True,
+    one_hot_encoding=False,
+    normalization=True,
+    categorical_indices=None,
+):
+    """This method preprocesses data regarding missing values, categorical features,
+    and data normalization (for the kNN model).
+    :param data: Data to preprocess
+    :param nan_values: Preprocesses NaN values if True
+    :param one_hot_encoding: Whether to use OHE for categoricals
+    :param normalization: Normalizes data if True
+    :param categorical_indices: Categorical columns of data
+    :return: Preprocessed version of the data.
+    """
+    data = data.numpy() if torch.is_tensor(data) else data
+    data = data.astype(np.float32)
+    data = pd.DataFrame(data).reset_index().drop("index", axis=1)
+
+    if categorical_indices is None:
+        categorical_indices = []
+    preprocessed_data = data
+    # NaN values (replace NaN with zeros)
+    if nan_values:
+        preprocessed_data = preprocessed_data.fillna(0)
+    # Categorical features (one hot encoding)
+    if one_hot_encoding:
+        # Setting dtypes of categorical data to 'category'
+        for idx in categorical_indices:
+            preprocessed_data[preprocessed_data.columns[idx]] = preprocessed_data[
+                preprocessed_data.columns[idx]
+            ].astype("category")
+        categorical_columns = list(
+            preprocessed_data.select_dtypes(include=["category"]).columns,
+        )
+        preprocessed_data = pd.get_dummies(
+            preprocessed_data,
+            columns=categorical_columns,
+        )
+    # Data normalization from R -> [0, 1]
+    if normalization:
+        if one_hot_encoding:
+            numerical_columns = list(
+                preprocessed_data.select_dtypes(exclude=["category"]).columns,
+            )
+            preprocessed_data[numerical_columns] = preprocessed_data[
+                numerical_columns
+            ].apply(
+                lambda x: (x - x.min()) / (x.max() - x.min())
+                if x.max() != x.min()
+                else x,
+            )
+        else:
+            preprocessed_data = preprocessed_data.apply(
+                lambda x: (x - x.min()) / (x.max() - x.min())
+                if x.max() != x.min()
+                else x,
+            )
+    return preprocessed_data
+
+def softmax(logits: NDArray) -> NDArray:
+    """Apply softmax function to convert logits to probabilities.
+
+    Args:
+        logits: Input logits array of shape (n_samples, n_classes) or (n_classes,)
+
+    Returns:
+        Probabilities where values sum to 1 across the last dimension
+    """
+    # Handle both 2D and 1D inputs: work on a 2D view so the reductions below
+    # can use axis=1, keeping `logits` untouched for the final shape check
+    logits_2d = logits.reshape(1, -1) if logits.ndim == 1 else logits
+
+    # Apply exponential to each logit with numerical stability
+    logits_max = np.max(logits_2d, axis=1, keepdims=True)
+    exp_logits = np.exp(logits_2d - logits_max)  # Subtract max for numerical stability
+
+    # Sum across classes and normalize
+    sum_exp_logits = np.sum(exp_logits, axis=1, keepdims=True)
+    probs = exp_logits / sum_exp_logits
+
+    # Return in the same shape as input
+    if logits.ndim == 1:
+        return probs.reshape(-1)
+    return probs
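For orientation, a minimal sketch of how these two helpers behave with their defaults (NaN replaced by zeros, per-column min-max scaling). The toy DataFrame is invented for illustration, and the import assumes the optional tabpfn dependency is installed so the rfpfn package resolves:

```python
import numpy as np
import pandas as pd

from autogluon.tabular.models.tabpfnv2.rfpfn.utils import preprocess_data, softmax

# Invented toy data: one missing value that will be replaced by 0
raw = pd.DataFrame({"age": [20.0, 35.0, np.nan], "height": [150.0, 180.0, 165.0]})

clean = preprocess_data(raw)  # NaN -> 0, then each column scaled to [0, 1]
print(clean.min().min(), clean.max().max())  # 0.0 1.0

probs = softmax(np.array([1.0, 2.0, 3.0]))
print(probs.sum())  # 1.0 (probabilities over the last dimension)
```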
autogluon/tabular/models/tabpfnv2/tabpfnv2_model.py
ADDED
@@ -0,0 +1,376 @@
+"""
+Code Adapted from TabArena: https://github.com/autogluon/tabrepo/blob/main/tabrepo/benchmark/models/ag/tabpfnv2/tabpfnv2_model.py
+
+Model: TabPFNv2
+Paper: Accurate predictions on small data with a tabular foundation model
+Authors: Noah Hollmann, Samuel Müller, Lennart Purucker, Arjun Krishnakumar, Max Körfer, Shi Bin Hoo, Robin Tibor Schirrmeister & Frank Hutter
+Codebase: https://github.com/PriorLabs/TabPFN
+License: https://github.com/PriorLabs/TabPFN/blob/main/LICENSE
+"""
+
+from __future__ import annotations
+
+import logging
+import warnings
+from typing import TYPE_CHECKING, Any
+
+import numpy as np
+import scipy
+from autogluon.common.utils.resource_utils import ResourceManager
+from autogluon.core.models import AbstractModel
+from autogluon.features.generators import LabelEncoderFeatureGenerator
+from autogluon.tabular import __version__
+from sklearn.preprocessing import PowerTransformer
+
+if TYPE_CHECKING:
+    import pandas as pd
+
+logger = logging.getLogger(__name__)
+
+_HAS_LOGGED_TABPFN_LICENSE: bool = False
+
+
+# TODO: merge into TabPFNv2 codebase
+class FixedSafePowerTransformer(PowerTransformer):
+    """Fixed version of TabPFN's SafePowerTransformer."""
+
+    def __init__(
+        self,
+        variance_threshold: float = 1e-3,
+        large_value_threshold: float = 100,
+        method="yeo-johnson",
+        standardize=True,
+        copy=True,
+    ):
+        super().__init__(method=method, standardize=standardize, copy=copy)
+        self.variance_threshold = variance_threshold
+        self.large_value_threshold = large_value_threshold
+
+        self.revert_indices_ = None
+
+    def _find_features_to_revert_because_of_failure(
+        self,
+        transformed_X: np.ndarray,
+    ) -> None:
+        # Calculate the variance for each feature in the transformed data
+        variances = np.nanvar(transformed_X, axis=0)
+
+        # Identify features where the variance is not close to 1
+        mask = np.abs(variances - 1) > self.variance_threshold
+        non_unit_variance_indices = np.where(mask)[0]
+
+        # Identify features with values greater than the large_value_threshold
+        large_value_indices = np.any(transformed_X > self.large_value_threshold, axis=0)
+        large_value_indices = np.nonzero(large_value_indices)[0]
+
+        # Identify features to revert based on either condition
+        self.revert_indices_ = np.unique(
+            np.concatenate([non_unit_variance_indices, large_value_indices]),
+        )
+
+    def _yeo_johnson_optimize(self, x: np.ndarray) -> float:
+        try:
+            with warnings.catch_warnings():
+                warnings.filterwarnings(
+                    "ignore",
+                    message=r"overflow encountered",
+                    category=RuntimeWarning,
+                )
+                return super()._yeo_johnson_optimize(x)  # type: ignore
+        except scipy.optimize._optimize.BracketError:
+            return np.nan
+
+    def _yeo_johnson_transform(self, x: np.ndarray, lmbda: float) -> np.ndarray:
+        if np.isnan(lmbda):
+            return x
+
+        return super()._yeo_johnson_transform(x, lmbda)  # type: ignore
+
+    def _revert_failed_features(
+        self,
+        transformed_X: np.ndarray,
+        original_X: np.ndarray,
+    ) -> np.ndarray:
+        # Replace these features with the original features
+        if self.revert_indices_ is not None and len(self.revert_indices_) > 0:
+            transformed_X[:, self.revert_indices_] = original_X[:, self.revert_indices_]
+
+        return transformed_X
+
+    def fit(self, X: np.ndarray, y: Any | None = None) -> FixedSafePowerTransformer:
+        super().fit(X, y)
+
+        # Check and revert features as necessary
+        self._find_features_to_revert_because_of_failure(super().transform(X))  # type: ignore
+        return self
+
+    def transform(self, X: np.ndarray) -> np.ndarray:
+        transformed_X = super().transform(X)
+        return self._revert_failed_features(transformed_X, X)  # type: ignore
+
+
+class TabPFNV2Model(AbstractModel):
+    ag_key = "TABPFNV2"
+    ag_name = "TabPFNv2"
+    ag_priority = 105
+
+    def __init__(self, **kwargs):
+        super().__init__(**kwargs)
+        self._feature_generator = None
+        self._cat_features = None
+
+    def _preprocess(self, X: pd.DataFrame, is_train=False, **kwargs) -> pd.DataFrame:
+        X = super()._preprocess(X, **kwargs)
+        self._cat_indices = []
+
+        if is_train:
+            # X will be the training data.
+            self._feature_generator = LabelEncoderFeatureGenerator(verbosity=0)
+            self._feature_generator.fit(X=X)
+
+        # This converts categorical features to numeric via stateful label encoding.
+        if self._feature_generator.features_in:
+            X = X.copy()
+            X[self._feature_generator.features_in] = self._feature_generator.transform(
+                X=X
+            )
+
+        # Detect/set cat features and indices
+        if self._cat_features is None:
+            self._cat_features = self._feature_generator.features_in[:]
+        self._cat_indices = [X.columns.get_loc(col) for col in self._cat_features]
+
+        return X
+
+    # FIXME: Crashes during model download if bagging with parallel fit.
+    # Consider adopting the same download logic as TabPFNMix, which doesn't crash during model download.
+    # FIXME: Maybe support child_oof somehow by using only one model and being smart about inference time?
+    def _fit(
+        self,
+        X: pd.DataFrame,
+        y: pd.Series,
+        num_cpus: int = 1,
+        num_gpus: int = 0,
+        verbosity: int = 2,
+        **kwargs,
+    ):
+        try:
+            from tabpfn.model import preprocessing
+        except ImportError as err:
+            logger.log(
+                40,
+                f"\tFailed to import tabpfn! To use the TabPFNv2 model, "
+                f"do: `pip install autogluon.tabular[tabpfn]=={__version__}`.",
+            )
+            raise err
+
+        preprocessing.SafePowerTransformer = FixedSafePowerTransformer
+
+        from tabpfn import TabPFNClassifier, TabPFNRegressor
+        from tabpfn.model.loading import resolve_model_path
+        from torch.cuda import is_available
+
+        is_classification = self.problem_type in ["binary", "multiclass"]
+
+        model_base = TabPFNClassifier if is_classification else TabPFNRegressor
+
+        device = "cuda" if num_gpus != 0 else "cpu"
+        if (device == "cuda") and (not is_available()):
+            # FIXME: warn instead and switch to CPU.
+            raise AssertionError(
+                "Fit specified to use GPU, but CUDA is not available on this machine. "
+                "Please switch to CPU usage instead.",
+            )
+
+        if verbosity >= 2:
+            # logs "Built with PriorLabs-TabPFN"
+            self._log_license(device=device)
+
+        X = self.preprocess(X, is_train=True)
+
+        hps = self._get_model_params()
+        hps["device"] = device
+        hps["n_jobs"] = num_cpus
+        hps["categorical_features_indices"] = self._cat_indices
+
+        _, model_dir, _, _ = resolve_model_path(
+            model_path=None,
+            which="classifier" if is_classification else "regressor",
+        )
+        if is_classification:
+            if "classification_model_path" in hps:
+                hps["model_path"] = model_dir / hps.pop("classification_model_path")
+            if "regression_model_path" in hps:
+                del hps["regression_model_path"]
+        else:
+            if "regression_model_path" in hps:
+                hps["model_path"] = model_dir / hps.pop("regression_model_path")
+            if "classification_model_path" in hps:
+                del hps["classification_model_path"]
+
+        # Resolve inference_config
+        inference_config = {
+            _k: v
+            for k, v in hps.items()
+            if k.startswith("inference_config/") and (_k := k.split("/")[-1])
+        }
+        if inference_config:
+            hps["inference_config"] = inference_config
+        for k in list(hps.keys()):
+            if k.startswith("inference_config/"):
+                del hps[k]
+
+        # TODO: remove power from search space and TabPFNv2 codebase
+        # Power transform can fail. To avoid this, make all power be safepower instead.
+        if "PREPROCESS_TRANSFORMS" in inference_config:
+            safe_config = []
+            for preprocessing_dict in inference_config["PREPROCESS_TRANSFORMS"]:
+                if preprocessing_dict["name"] == "power":
+                    preprocessing_dict["name"] = "safepower"
+                safe_config.append(preprocessing_dict)
+            inference_config["PREPROCESS_TRANSFORMS"] = safe_config
+        if "REGRESSION_Y_PREPROCESS_TRANSFORMS" in inference_config:
+            safe_config = []
+            for preprocessing_name in inference_config[
+                "REGRESSION_Y_PREPROCESS_TRANSFORMS"
+            ]:
+                if preprocessing_name == "power":
+                    preprocessing_name = "safepower"
+                safe_config.append(preprocessing_name)
+            inference_config["REGRESSION_Y_PREPROCESS_TRANSFORMS"] = safe_config
+
+        # Resolve model_type
+        n_ensemble_repeats = hps.pop("n_ensemble_repeats", None)
+        model_is_rf_pfn = hps.pop("model_type", "no") == "dt_pfn"
+        if model_is_rf_pfn:
+            from .rfpfn import (
+                RandomForestTabPFNClassifier,
+                RandomForestTabPFNRegressor,
+            )
+
+            hps["n_estimators"] = 1
+            rf_model_base = (
+                RandomForestTabPFNClassifier
+                if is_classification
+                else RandomForestTabPFNRegressor
+            )
+            self.model = rf_model_base(
+                tabpfn=model_base(**hps),
+                categorical_features=self._cat_indices,
+                n_estimators=n_ensemble_repeats,
+            )
+        else:
+            if n_ensemble_repeats is not None:
+                hps["n_estimators"] = n_ensemble_repeats
+            self.model = model_base(**hps)
+
+        self.model = self.model.fit(
+            X=X,
+            y=y,
+        )
+
+    def _log_license(self, device: str):
+        global _HAS_LOGGED_TABPFN_LICENSE
+        if not _HAS_LOGGED_TABPFN_LICENSE:
+            logger.log(20, "\tBuilt with PriorLabs-TabPFN")  # Aligning with TabPFNv2 license requirements
+            if device == "cpu":
+                logger.log(
+                    20,
+                    "\tRunning TabPFNv2 on CPU. This can be very slow. "
+                    "It is recommended to run TabPFNv2 on a GPU.",
+                )
+            _HAS_LOGGED_TABPFN_LICENSE = True  # Avoid repeated logging
+
+    def _get_default_resources(self) -> tuple[int, int]:
+        num_cpus = ResourceManager.get_cpu_count(only_physical_cores=True)
+        num_gpus = min(ResourceManager.get_gpu_count_torch(), 1)
+        return num_cpus, num_gpus
+
+    def _set_default_params(self):
+        default_params = {
+            "random_state": 42,
+            "ignore_pretraining_limits": True,  # to ignore warnings and size limits
+        }
+        for param, val in default_params.items():
+            self._set_default_param_value(param, val)
+
+    @classmethod
+    def supported_problem_types(cls) -> list[str] | None:
+        return ["binary", "multiclass", "regression"]
+
+    def _get_default_auxiliary_params(self) -> dict:
+        default_auxiliary_params = super()._get_default_auxiliary_params()
+        default_auxiliary_params.update(
+            {
+                "max_rows": 10000,
+                "max_features": 500,
+                "max_classes": 10,
+            }
+        )
+        return default_auxiliary_params
+
+    @classmethod
+    def _get_default_ag_args_ensemble(cls, **kwargs) -> dict:
+        """Set fold_fitting_strategy to sequential_local,
+        as parallel folding crashes if model weights aren't pre-downloaded.
+        """
+        default_ag_args_ensemble = super()._get_default_ag_args_ensemble(**kwargs)
+        extra_ag_args_ensemble = {
+            # FIXME: Find a work-around to avoid the crash if parallel and weights are not downloaded
+            "fold_fitting_strategy": "sequential_local",
+            "refit_folds": True,  # Better to refit the model for faster inference and similar quality as the bag.
+        }
+        default_ag_args_ensemble.update(extra_ag_args_ensemble)
+        return default_ag_args_ensemble
+
+    def _estimate_memory_usage(self, X: pd.DataFrame, **kwargs) -> int:
+        hyperparameters = self._get_model_params()
+        return self.estimate_memory_usage_static(
+            X=X,
+            problem_type=self.problem_type,
+            num_classes=self.num_classes,
+            hyperparameters=hyperparameters,
+            **kwargs,
+        )
+
+    @classmethod
+    def _estimate_memory_usage_static(
+        cls,
+        *,
+        X: pd.DataFrame,
+        hyperparameters: dict | None = None,
+        **kwargs,
+    ) -> int:
+        """Heuristic memory estimate based on TabPFN's memory estimate logic in:
+        https://github.com/PriorLabs/TabPFN/blob/57a2efd3ebdb3886245e4d097cefa73a5261a969/src/tabpfn/model/memory.py#L147.
+
+        This is based on GPU memory usage, but hopefully with overheads it also approximates CPU memory usage.
+        """
+        # features_per_group = 2  # Based on TabPFNv2 default (unused)
+        n_layers = 12  # Based on TabPFNv2 default
+        embedding_size = 192  # Based on TabPFNv2 default
+        dtype_byte_size = 2  # Based on TabPFNv2 default
+
+        model_mem = 14489108  # Based on TabPFNv2 default
+
+        n_samples, n_features = X.shape[0], X.shape[1]
+        n_feature_groups = n_features + 1  # TODO: Unsure how to calculate this
+
+        X_mem = n_samples * n_feature_groups * dtype_byte_size
+        activation_mem = (
+            n_samples * n_feature_groups * embedding_size * n_layers * dtype_byte_size
+        )
+
+        baseline_overhead_mem_est = 1e9  # 1 GB generic overhead
+
+        # Add some buffer to each term + 1 GB overhead to be safe
+        return int(
+            model_mem + 4 * X_mem + 1.5 * activation_mem + baseline_overhead_mem_est
+        )
+
+    @classmethod
+    def _class_tags(cls):
+        return {"can_estimate_memory_usage_static": True}
+
+    def _more_tags(self) -> dict:
+        return {"can_refit_full": True}
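A minimal usage sketch for the model defined above. The CSV path and label column are placeholders; the "TABPFNV2" key is the ag_key declared in the class, and per the import-error message the tabpfn extra (`pip install autogluon.tabular[tabpfn]`) must be installed:

```python
import pandas as pd
from autogluon.tabular import TabularPredictor

# Placeholder dataset; the auxiliary params above cap TabPFNv2 at
# 10000 rows, 500 features, and 10 classes by default.
train_data = pd.read_csv("train.csv")  # hypothetical file with a "target" column

predictor = TabularPredictor(label="target").fit(
    train_data,
    hyperparameters={"TABPFNV2": {}},  # resolved via the ag_key registered below
)
```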
@@ -25,7 +25,7 @@ from ..models import (
|
|
25
25
|
TabICLModel,
|
26
26
|
TabMModel,
|
27
27
|
TabPFNMixModel,
|
28
|
-
|
28
|
+
TabPFNV2Model,
|
29
29
|
TabularNeuralNetTorchModel,
|
30
30
|
TextPredictorModel,
|
31
31
|
XGBoostModel,
|
@@ -51,8 +51,8 @@ REGISTERED_MODEL_CLS_LST = [
|
|
51
51
|
FTTransformerModel,
|
52
52
|
TabICLModel,
|
53
53
|
TabMModel,
|
54
|
-
TabPFNModel,
|
55
54
|
TabPFNMixModel,
|
55
|
+
TabPFNV2Model,
|
56
56
|
FastTextModel,
|
57
57
|
GreedyWeightedEnsembleModel,
|
58
58
|
SimpleWeightedEnsembleModel,
|
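The net effect of these two hunks is that TabPFNv2 replaces the TabPFNv1 entry in the model registry. A hypothetical smoke check; the `ag_model_registry` accessor and its `model_cls_list` attribute are assumptions, as `registry/__init__.py` itself is unchanged in this diff:

```python
# Assumed accessor names; not shown in this diff.
from autogluon.tabular.models import TabPFNV2Model
from autogluon.tabular.registry import ag_model_registry

assert TabPFNV2Model in ag_model_registry.model_cls_list  # newly registered
```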
autogluon/tabular/version.py
CHANGED
{autogluon.tabular-1.3.2b20250711.dist-info → autogluon.tabular-1.3.2b20250712.dist-info}/METADATA
RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: autogluon.tabular
-Version: 1.3.2b20250711
+Version: 1.3.2b20250712
 Summary: Fast and Accurate ML in 3 Lines of Code
 Home-page: https://github.com/autogluon/autogluon
 Author: AutoGluon Community
@@ -41,20 +41,20 @@ Requires-Dist: scipy<1.17,>=1.5.4
 Requires-Dist: pandas<2.4.0,>=2.0.0
 Requires-Dist: scikit-learn<1.8.0,>=1.4.0
 Requires-Dist: networkx<4,>=3.0
-Requires-Dist: autogluon.core==1.3.2b20250711
-Requires-Dist: autogluon.features==1.3.2b20250711
+Requires-Dist: autogluon.core==1.3.2b20250712
+Requires-Dist: autogluon.features==1.3.2b20250712
 Provides-Extra: all
-Requires-Dist: lightgbm<4.7,>=4.0; extra == "all"
-Requires-Dist: autogluon.core[all]==1.3.2b20250711; extra == "all"
 Requires-Dist: spacy<3.9; extra == "all"
-Requires-Dist: xgboost<3.1,>=2.0; extra == "all"
-Requires-Dist: fastai<2.9,>=2.3.1; extra == "all"
 Requires-Dist: catboost<1.3,>=1.2; extra == "all"
-Requires-Dist: huggingface-hub[torch]; extra == "all"
-Requires-Dist: numpy<2.3.0,>=1.25; extra == "all"
 Requires-Dist: einops<0.9,>=0.7; extra == "all"
-Requires-Dist: torch<2.8,>=2.2; extra == "all"
+Requires-Dist: lightgbm<4.7,>=4.0; extra == "all"
+Requires-Dist: numpy<2.3.0,>=1.25; extra == "all"
+Requires-Dist: autogluon.core[all]==1.3.2b20250712; extra == "all"
+Requires-Dist: fastai<2.9,>=2.3.1; extra == "all"
 Requires-Dist: pytabkit<1.6,>=1.5; extra == "all"
+Requires-Dist: xgboost<3.1,>=2.0; extra == "all"
+Requires-Dist: torch<2.8,>=2.2; extra == "all"
+Requires-Dist: huggingface-hub[torch]; extra == "all"
 Provides-Extra: catboost
 Requires-Dist: numpy<2.3.0,>=1.25; extra == "catboost"
 Requires-Dist: catboost<1.3,>=1.2; extra == "catboost"
@@ -67,7 +67,7 @@ Requires-Dist: imodels<2.1.0,>=1.3.10; extra == "imodels"
 Provides-Extra: lightgbm
 Requires-Dist: lightgbm<4.7,>=4.0; extra == "lightgbm"
 Provides-Extra: ray
-Requires-Dist: autogluon.core[all]==1.3.2b20250711; extra == "ray"
+Requires-Dist: autogluon.core[all]==1.3.2b20250712; extra == "ray"
 Provides-Extra: realmlp
 Requires-Dist: pytabkit<1.6,>=1.5; extra == "realmlp"
 Provides-Extra: skex
@@ -83,15 +83,13 @@ Requires-Dist: tabicl<0.2,>=0.1.3; extra == "tabicl"
 Provides-Extra: tabm
 Requires-Dist: torch<2.8,>=2.2; extra == "tabm"
 Provides-Extra: tabpfn
-Requires-Dist: tabpfn<2.
+Requires-Dist: tabpfn<2.2,>=2.0.9; extra == "tabpfn"
 Provides-Extra: tabpfnmix
 Requires-Dist: torch<2.8,>=2.2; extra == "tabpfnmix"
 Requires-Dist: huggingface-hub[torch]; extra == "tabpfnmix"
 Requires-Dist: einops<0.9,>=0.7; extra == "tabpfnmix"
 Provides-Extra: tests
-Requires-Dist:
-Requires-Dist: huggingface-hub[torch]; extra == "tests"
-Requires-Dist: einops<0.9,>=0.7; extra == "tests"
+Requires-Dist: tabpfn<2.2,>=2.0.9; extra == "tests"
 Requires-Dist: imodels<2.1.0,>=1.3.10; extra == "tests"
 Requires-Dist: skl2onnx<1.18.0,>=1.15.0; extra == "tests"
 Requires-Dist: onnxruntime<1.20.0,>=1.17.0; extra == "tests"
{autogluon.tabular-1.3.2b20250711.dist-info → autogluon.tabular-1.3.2b20250712.dist-info}/RECORD
RENAMED
@@ -1,6 +1,6 @@
-autogluon.tabular-1.3.
+autogluon.tabular-1.3.2b20250712-py3.9-nspkg.pth,sha256=cQGwpuGPqg1GXscIwt-7PmME1OnSpD-7ixkikJ31WAY,554
 autogluon/tabular/__init__.py,sha256=2OXpJCvENRHubBTYNIPpHX93WWuFZzsJBtTZbNVHVas,400
-autogluon/tabular/version.py,sha256=
+autogluon/tabular/version.py,sha256=C6OW_vajErF7r9El7B0X_XkhCzzEn70hhuGbhroLKSU,91
 autogluon/tabular/configs/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 autogluon/tabular/configs/config_helper.py,sha256=JsdVGmpcYL88GPKBznPtqJ1sGaByOSvLn7KWU-HyVoQ,21085
 autogluon/tabular/configs/feature_generator_presets.py,sha256=EV5Ym8VW15q92MwOUpTi7wZFS2QooM51fLg3RdUsn-M,1223
@@ -16,7 +16,7 @@ autogluon/tabular/experimental/plot_leaderboard.py,sha256=BN_kB-zmOZNUYWyI7z9pF6
 autogluon/tabular/learner/__init__.py,sha256=Hhmk5WpKQHohVmI-veOaKMelKJpIdzeXrmw_DPn3DTU,63
 autogluon/tabular/learner/abstract_learner.py,sha256=0kf0huvg0nphe-lrdKtNTzdIFr14jzJPsfZDRBkKo3g,55253
 autogluon/tabular/learner/default_learner.py,sha256=hjdKbcFtIQxQ3-k1LiGOo-w5sLxIIQAyFLs3-R35aw0,24781
-autogluon/tabular/models/__init__.py,sha256=
+autogluon/tabular/models/__init__.py,sha256=x6hmZ0RhFIznnO1UFrHrcu0wRFTV8sGZN_SwoXdW1u8,1174
 autogluon/tabular/models/_utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 autogluon/tabular/models/_utils/rapids_utils.py,sha256=9A2Y10Owva6zhcLkBVQ_T4tOAMDp1idSMzDWhl_QyBI,1083
 autogluon/tabular/models/_utils/torch_utils.py,sha256=dxs_KMMAOmNkRNjYf_hrzqaHIfkqn1xoKRKqCFbQ1Rk,537
@@ -83,8 +83,6 @@ autogluon/tabular/models/tabm/_tabm_internal.py,sha256=LbIohrZYnXiKbD1ZnXWDJQMBL
 autogluon/tabular/models/tabm/rtdl_num_embeddings.py,sha256=omDKJT0MjniUPUnk8tSU-brE8dXIjw27BHFbYc2bswQ,30119
 autogluon/tabular/models/tabm/tabm_model.py,sha256=43I8429yTq5U2IDp6ATZB27lyewAW20VzdbPxS-01sA,10115
 autogluon/tabular/models/tabm/tabm_reference.py,sha256=h9FXzyeu6b4vXg9nnM3L2I8dYbcE39USr9C4uMnt4Ek,21788
-autogluon/tabular/models/tabpfn/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
-autogluon/tabular/models/tabpfn/tabpfn_model.py,sha256=PEYMuIh5TFLIDy3hcjfz1DcvDu77rbwRq0pKWyuUR04,6787
 autogluon/tabular/models/tabpfnmix/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 autogluon/tabular/models/tabpfnmix/tabpfnmix_model.py,sha256=7cLjAfstq6Xb-l2DxBdwtSAIanSJN2sMfKPtijDQwXo,16193
 autogluon/tabular/models/tabpfnmix/_internal/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -111,6 +109,15 @@ autogluon/tabular/models/tabpfnmix/_internal/models/foundation/embedding.py,sha2
 autogluon/tabular/models/tabpfnmix/_internal/models/foundation/foundation_transformer.py,sha256=bhNpGIA5BKqIVX-kDW4bZLgsOB_A8iNsnpgoyyBLR98,5383
 autogluon/tabular/models/tabpfnmix/_internal/results/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 autogluon/tabular/models/tabpfnmix/_internal/results/prediction_metrics.py,sha256=1tRPHyViSSLJ7BkQJi6wai-PwXJ56od86Dy1WWKWZq4,1743
+autogluon/tabular/models/tabpfnv2/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+autogluon/tabular/models/tabpfnv2/tabpfnv2_model.py,sha256=DC4t-woswX3mWYRWF6zMfpjEtPZXJ9qgr-mKxWVFs3w,14254
+autogluon/tabular/models/tabpfnv2/rfpfn/__init__.py,sha256=yE5XAhGxKEFV0JcelZ_JTQZIWGlVEVUQ9a-lxcH_Esc,585
+autogluon/tabular/models/tabpfnv2/rfpfn/configs.py,sha256=lzBY9kKOeBZACVrtRDPHF4ATs9g1rxyNnIs2CMjE20c,1175
+autogluon/tabular/models/tabpfnv2/rfpfn/scoring_utils.py,sha256=uvHsfvnnMdg4tP3_7zAilktkw7nr65LaqfVKXabXAow,6785
+autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_decision_tree_tabpfn.py,sha256=-KQNm_HYWem6HWUsdbnIX4lKe-eW0PQAXZUny2kqego,55582
+autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_random_forest_tabpfn.py,sha256=FRJSelTtDaKnpsKKHphjy2rJrFX302miSdHZ0YqHxCQ,28045
+autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_compat.py,sha256=jv2ZHsGwcO4Inhxtol_tig3NoXZQR649dhmW_Kv69QY,29607
+autogluon/tabular/models/tabpfnv2/rfpfn/utils.py,sha256=vjMQsNaZZcW1BBf0hduSCtrNCtSd467xfkhsbHspUog,3489
 autogluon/tabular/models/tabular_nn/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 autogluon/tabular/models/tabular_nn/compilers/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 autogluon/tabular/models/tabular_nn/compilers/native.py,sha256=W8d8cqBj7U-KVhfGK3hdtGj8JJm3lXr_SecU0615Gbs,1330
@@ -141,7 +148,7 @@ autogluon/tabular/predictor/__init__.py,sha256=zCMgjxQlWpDWnr1l1xjBCiK3rWC3N3RoD
 autogluon/tabular/predictor/interpretable_predictor.py,sha256=5UeKgnMFsfY65tiO3kxfHBPr03lyswLrgdtjPhI0Y7Q,6934
 autogluon/tabular/predictor/predictor.py,sha256=cjszntXs6k5BZMOaLGaMiC1e2sGkCsnXrH9rVI972-0,356548
 autogluon/tabular/registry/__init__.py,sha256=vZpzX4Xve7bfA9crt5LxjgQv9PPfxbi1E1U6Im0Y_xU,93
-autogluon/tabular/registry/_ag_model_registry.py,sha256=
+autogluon/tabular/registry/_ag_model_registry.py,sha256=6Rro0BBN3yb34Ysi2hffuJDdP9eV6el2HjQ-a48N2-E,1518
 autogluon/tabular/registry/_model_registry.py,sha256=Rl8Q7BLzaif4hxNxJF20xGE02vrWwh2ZuUaTmA-UJnE,6824
 autogluon/tabular/testing/__init__.py,sha256=XrEGLmMdmRT6QHNR13M9wna57LO4O3Q4tt27Ca8omAc,79
 autogluon/tabular/testing/fit_helper.py,sha256=dzyzIBD9s7Ekb_inoAE6sep3bW9QKeYqO4WcDzAhAwg,19818
@@ -155,11 +162,11 @@ autogluon/tabular/trainer/model_presets/presets.py,sha256=hoWADaOG576Q_XLV1nY_ju
 autogluon/tabular/trainer/model_presets/presets_distill.py,sha256=MnFC2GJc6RmDBNAGbsO2XMfo3PjR8cUrZoilWW8gTYQ,3295
 autogluon/tabular/tuning/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 autogluon/tabular/tuning/feature_pruner.py,sha256=9iNku8gVbYEkjuKlyITPJDicsNkoraaQOlINQq9iZlQ,6877
-autogluon.tabular-1.3.
-autogluon.tabular-1.3.
-autogluon.tabular-1.3.
-autogluon.tabular-1.3.
-autogluon.tabular-1.3.
-autogluon.tabular-1.3.
-autogluon.tabular-1.3.
-autogluon.tabular-1.3.
+autogluon.tabular-1.3.2b20250712.dist-info/LICENSE,sha256=CeipvOyAZxBGUsFoaFqwkx54aPnIKEtm9a5u2uXxEws,10142
+autogluon.tabular-1.3.2b20250712.dist-info/METADATA,sha256=6abZh-VsVqsBp5JZrRellfiGEvoKRi5z2c3LGRUfJDI,14290
+autogluon.tabular-1.3.2b20250712.dist-info/NOTICE,sha256=7nPQuj8Kp-uXsU0S5so3-2dNU5EctS5hDXvvzzehd7E,114
+autogluon.tabular-1.3.2b20250712.dist-info/WHEEL,sha256=tZoeGjtWxWRfdplE7E3d45VPlLNQnvbKiYnx7gwAy8A,92
+autogluon.tabular-1.3.2b20250712.dist-info/namespace_packages.txt,sha256=giERA4R78OkJf2ijn5slgjURlhRPzfLr7waIcGkzYAo,10
+autogluon.tabular-1.3.2b20250712.dist-info/top_level.txt,sha256=giERA4R78OkJf2ijn5slgjURlhRPzfLr7waIcGkzYAo,10
+autogluon.tabular-1.3.2b20250712.dist-info/zip-safe,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
+autogluon.tabular-1.3.2b20250712.dist-info/RECORD,,
@@ -1 +0,0 @@
|
|
1
|
-
|