autogluon.tabular 1.5.1b20260105__py3-none-any.whl → 1.5.1b20260116__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of autogluon.tabular might be problematic.

Files changed (135)
  1. autogluon/tabular/__init__.py +1 -0
  2. autogluon/tabular/configs/config_helper.py +18 -6
  3. autogluon/tabular/configs/feature_generator_presets.py +3 -1
  4. autogluon/tabular/configs/hyperparameter_configs.py +42 -9
  5. autogluon/tabular/configs/presets_configs.py +38 -14
  6. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2023.py +84 -14
  7. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2025.py +48 -48
  8. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_cpu_2025_12_18.py +774 -1
  9. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_gpu_2025_12_18.py +421 -1
  10. autogluon/tabular/experimental/_scikit_mixin.py +6 -2
  11. autogluon/tabular/experimental/_tabular_classifier.py +3 -1
  12. autogluon/tabular/experimental/_tabular_regressor.py +3 -1
  13. autogluon/tabular/experimental/plot_leaderboard.py +73 -19
  14. autogluon/tabular/learner/abstract_learner.py +160 -42
  15. autogluon/tabular/learner/default_learner.py +78 -22
  16. autogluon/tabular/models/__init__.py +2 -2
  17. autogluon/tabular/models/_utils/rapids_utils.py +3 -1
  18. autogluon/tabular/models/abstract/abstract_torch_model.py +2 -0
  19. autogluon/tabular/models/automm/automm_model.py +12 -3
  20. autogluon/tabular/models/automm/ft_transformer.py +5 -1
  21. autogluon/tabular/models/catboost/callbacks.py +2 -2
  22. autogluon/tabular/models/catboost/catboost_model.py +93 -29
  23. autogluon/tabular/models/catboost/catboost_softclass_utils.py +4 -1
  24. autogluon/tabular/models/catboost/catboost_utils.py +3 -1
  25. autogluon/tabular/models/ebm/ebm_model.py +8 -13
  26. autogluon/tabular/models/ebm/hyperparameters/parameters.py +1 -0
  27. autogluon/tabular/models/ebm/hyperparameters/searchspaces.py +1 -0
  28. autogluon/tabular/models/fastainn/callbacks.py +20 -3
  29. autogluon/tabular/models/fastainn/hyperparameters/searchspaces.py +11 -1
  30. autogluon/tabular/models/fastainn/quantile_helpers.py +10 -2
  31. autogluon/tabular/models/fastainn/tabular_nn_fastai.py +65 -18
  32. autogluon/tabular/models/fasttext/fasttext_model.py +3 -1
  33. autogluon/tabular/models/image_prediction/image_predictor.py +7 -2
  34. autogluon/tabular/models/knn/knn_model.py +41 -8
  35. autogluon/tabular/models/lgb/callbacks.py +32 -9
  36. autogluon/tabular/models/lgb/hyperparameters/searchspaces.py +3 -1
  37. autogluon/tabular/models/lgb/lgb_model.py +150 -34
  38. autogluon/tabular/models/lgb/lgb_utils.py +12 -4
  39. autogluon/tabular/models/lr/hyperparameters/searchspaces.py +5 -1
  40. autogluon/tabular/models/lr/lr_model.py +40 -10
  41. autogluon/tabular/models/lr/lr_rapids_model.py +22 -13
  42. autogluon/tabular/models/mitra/_internal/__init__.py +1 -1
  43. autogluon/tabular/models/mitra/_internal/config/__init__.py +1 -1
  44. autogluon/tabular/models/mitra/_internal/config/config_pretrain.py +36 -40
  45. autogluon/tabular/models/mitra/_internal/config/config_run.py +2 -14
  46. autogluon/tabular/models/mitra/_internal/config/enums.py +27 -26
  47. autogluon/tabular/models/mitra/_internal/core/__init__.py +1 -1
  48. autogluon/tabular/models/mitra/_internal/core/callbacks.py +14 -21
  49. autogluon/tabular/models/mitra/_internal/core/get_loss.py +10 -12
  50. autogluon/tabular/models/mitra/_internal/core/get_optimizer.py +17 -32
  51. autogluon/tabular/models/mitra/_internal/core/get_scheduler.py +12 -27
  52. autogluon/tabular/models/mitra/_internal/core/prediction_metrics.py +16 -21
  53. autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py +130 -111
  54. autogluon/tabular/models/mitra/_internal/data/__init__.py +1 -1
  55. autogluon/tabular/models/mitra/_internal/data/collator.py +30 -26
  56. autogluon/tabular/models/mitra/_internal/data/dataset_finetune.py +18 -26
  57. autogluon/tabular/models/mitra/_internal/data/dataset_split.py +10 -7
  58. autogluon/tabular/models/mitra/_internal/data/preprocessor.py +70 -100
  59. autogluon/tabular/models/mitra/_internal/models/__init__.py +1 -1
  60. autogluon/tabular/models/mitra/_internal/models/base.py +7 -10
  61. autogluon/tabular/models/mitra/_internal/models/embedding.py +46 -56
  62. autogluon/tabular/models/mitra/_internal/models/tab2d.py +140 -120
  63. autogluon/tabular/models/mitra/_internal/utils/__init__.py +1 -1
  64. autogluon/tabular/models/mitra/_internal/utils/set_seed.py +3 -1
  65. autogluon/tabular/models/mitra/mitra_model.py +16 -11
  66. autogluon/tabular/models/mitra/sklearn_interface.py +178 -162
  67. autogluon/tabular/models/realmlp/realmlp_model.py +28 -15
  68. autogluon/tabular/models/rf/compilers/onnx.py +1 -1
  69. autogluon/tabular/models/rf/rf_model.py +45 -12
  70. autogluon/tabular/models/rf/rf_quantile.py +4 -2
  71. autogluon/tabular/models/tabdpt/tabdpt_model.py +8 -17
  72. autogluon/tabular/models/tabicl/tabicl_model.py +8 -1
  73. autogluon/tabular/models/tabm/_tabm_internal.py +6 -4
  74. autogluon/tabular/models/tabm/rtdl_num_embeddings.py +80 -127
  75. autogluon/tabular/models/tabm/tabm_model.py +8 -4
  76. autogluon/tabular/models/tabm/tabm_reference.py +53 -85
  77. autogluon/tabular/models/tabpfnmix/_internal/core/callbacks.py +7 -16
  78. autogluon/tabular/models/tabpfnmix/_internal/core/collator.py +16 -24
  79. autogluon/tabular/models/tabpfnmix/_internal/core/dataset_split.py +5 -7
  80. autogluon/tabular/models/tabpfnmix/_internal/core/enums.py +0 -2
  81. autogluon/tabular/models/tabpfnmix/_internal/core/get_loss.py +0 -1
  82. autogluon/tabular/models/tabpfnmix/_internal/core/get_optimizer.py +7 -18
  83. autogluon/tabular/models/tabpfnmix/_internal/core/get_scheduler.py +3 -14
  84. autogluon/tabular/models/tabpfnmix/_internal/core/trainer_finetune.py +79 -64
  85. autogluon/tabular/models/tabpfnmix/_internal/core/y_transformer.py +3 -5
  86. autogluon/tabular/models/tabpfnmix/_internal/data/dataset_finetune.py +17 -30
  87. autogluon/tabular/models/tabpfnmix/_internal/data/preprocessor.py +15 -35
  88. autogluon/tabular/models/tabpfnmix/_internal/models/foundation/embedding.py +21 -38
  89. autogluon/tabular/models/tabpfnmix/_internal/models/foundation/foundation_transformer.py +33 -51
  90. autogluon/tabular/models/tabpfnmix/_internal/results/prediction_metrics.py +4 -4
  91. autogluon/tabular/models/tabpfnmix/_internal/tabpfnmix_classifier.py +32 -12
  92. autogluon/tabular/models/tabpfnmix/_internal/tabpfnmix_regressor.py +32 -13
  93. autogluon/tabular/models/tabpfnmix/tabpfnmix_model.py +55 -19
  94. autogluon/tabular/models/tabpfnv2/tabpfnv2_5_model.py +21 -48
  95. autogluon/tabular/models/tabprep/prep_mixin.py +34 -26
  96. autogluon/tabular/models/tabular_nn/compilers/onnx.py +36 -8
  97. autogluon/tabular/models/tabular_nn/torch/tabular_nn_torch.py +130 -36
  98. autogluon/tabular/models/tabular_nn/torch/tabular_torch_dataset.py +8 -4
  99. autogluon/tabular/models/tabular_nn/torch/torch_network_modules.py +26 -5
  100. autogluon/tabular/models/tabular_nn/utils/categorical_encoders.py +41 -24
  101. autogluon/tabular/models/tabular_nn/utils/data_preprocessor.py +33 -8
  102. autogluon/tabular/models/tabular_nn/utils/nn_architecture_utils.py +21 -6
  103. autogluon/tabular/models/xgboost/callbacks.py +9 -3
  104. autogluon/tabular/models/xgboost/xgboost_model.py +59 -11
  105. autogluon/tabular/models/xt/xt_model.py +1 -0
  106. autogluon/tabular/predictor/interpretable_predictor.py +3 -1
  107. autogluon/tabular/predictor/predictor.py +409 -128
  108. autogluon/tabular/registry/__init__.py +1 -1
  109. autogluon/tabular/registry/_ag_model_registry.py +4 -5
  110. autogluon/tabular/registry/_model_registry.py +1 -0
  111. autogluon/tabular/testing/fit_helper.py +55 -15
  112. autogluon/tabular/testing/generate_datasets.py +1 -1
  113. autogluon/tabular/testing/model_fit_helper.py +10 -4
  114. autogluon/tabular/trainer/abstract_trainer.py +644 -230
  115. autogluon/tabular/trainer/auto_trainer.py +19 -8
  116. autogluon/tabular/trainer/model_presets/presets.py +33 -9
  117. autogluon/tabular/trainer/model_presets/presets_distill.py +16 -2
  118. autogluon/tabular/version.py +1 -1
  119. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/METADATA +26 -26
  120. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/RECORD +127 -135
  121. autogluon/tabular/models/tabpfnv2/rfpfn/__init__.py +0 -20
  122. autogluon/tabular/models/tabpfnv2/rfpfn/configs.py +0 -40
  123. autogluon/tabular/models/tabpfnv2/rfpfn/scoring_utils.py +0 -201
  124. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_decision_tree_tabpfn.py +0 -1464
  125. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_random_forest_tabpfn.py +0 -747
  126. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_compat.py +0 -863
  127. autogluon/tabular/models/tabpfnv2/rfpfn/utils.py +0 -106
  128. autogluon/tabular/models/tabpfnv2/tabpfnv2_model.py +0 -466
  129. /autogluon.tabular-1.5.1b20260105-py3.11-nspkg.pth → /autogluon.tabular-1.5.1b20260116-py3.11-nspkg.pth +0 -0
  130. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/WHEEL +0 -0
  131. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/licenses/LICENSE +0 -0
  132. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/licenses/NOTICE +0 -0
  133. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/namespace_packages.txt +0 -0
  134. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/top_level.txt +0 -0
  135. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/zip-safe +0 -0
--- a/autogluon/tabular/models/tabpfnv2/rfpfn/utils.py
+++ /dev/null
@@ -1,106 +0,0 @@
-"""Copyright 2023.
-
-Author: Lukas Schweizer <schweizer.lukas@web.de>
-"""
-
-# Copyright (c) Prior Labs GmbH 2025.
-# Licensed under the Apache License, Version 2.0
-
-from __future__ import annotations
-
-import numpy as np
-import pandas as pd
-import torch
-# Type checking imports
-from typing import TYPE_CHECKING
-
-if TYPE_CHECKING:
-    from numpy.typing import NDArray
-
-
-def preprocess_data(
-    data,
-    nan_values=True,
-    one_hot_encoding=False,
-    normalization=True,
-    categorical_indices=None,
-):
-    """This method preprocesses data regarding missing values, categorical features
-    and data normalization (for the kNN Model)
-    :param data: Data to preprocess
-    :param nan_values: Preprocesses nan values if True
-    :param one_hot_encoding: Whether use OHE for categoricals
-    :param normalization: Normalizes data if True
-    :param categorical_indices: Categorical columns of data
-    :return: Preprocessed version of the data.
-    """
-    data = data.numpy() if torch.is_tensor(data) else data
-    data = data.astype(np.float32)
-    data = pd.DataFrame(data).reset_index().drop("index", axis=1)
-
-    if categorical_indices is None:
-        categorical_indices = []
-    preprocessed_data = data
-    # NaN values (replace NaN with zeros)
-    if nan_values:
-        preprocessed_data = preprocessed_data.fillna(0)
-    # Categorical Features (One Hot Encoding)
-    if one_hot_encoding:
-        # Setting dtypes of categorical data to 'category'
-        for idx in categorical_indices:
-            preprocessed_data[preprocessed_data.columns[idx]] = preprocessed_data[
-                preprocessed_data.columns[idx]
-            ].astype("category")
-        categorical_columns = list(
-            preprocessed_data.select_dtypes(include=["category"]).columns,
-        )
-        preprocessed_data = pd.get_dummies(
-            preprocessed_data,
-            columns=categorical_columns,
-        )
-    # Data normalization from R -> [0, 1]
-    if normalization:
-        if one_hot_encoding:
-            numerical_columns = list(
-                preprocessed_data.select_dtypes(exclude=["category"]).columns,
-            )
-            preprocessed_data[numerical_columns] = preprocessed_data[
-                numerical_columns
-            ].apply(
-                lambda x: (x - x.min()) / (x.max() - x.min())
-                if x.max() != x.min()
-                else x,
-            )
-        else:
-            preprocessed_data = preprocessed_data.apply(
-                lambda x: (x - x.min()) / (x.max() - x.min())
-                if x.max() != x.min()
-                else x,
-            )
-    return preprocessed_data
-
-def softmax(logits: NDArray) -> NDArray:
-    """Apply softmax function to convert logits to probabilities.
-
-    Args:
-        logits: Input logits array of shape (n_samples, n_classes) or (n_classes,)
-
-    Returns:
-        Probabilities where values sum to 1 across the last dimension
-    """
-    # Handle both 2D and 1D inputs
-    if logits.ndim == 1:
-        logits = logits.reshape(1, -1)
-
-    # Apply exponential to each logit with numerical stability
-    logits_max = np.max(logits, axis=1, keepdims=True)
-    exp_logits = np.exp(logits - logits_max)  # Subtract max for numerical stability
-
-    # Sum across classes and normalize
-    sum_exp_logits = np.sum(exp_logits, axis=1, keepdims=True)
-    probs = exp_logits / sum_exp_logits
-
-    # Return in the same shape as input
-    if logits.ndim == 1:
-        return probs.reshape(-1)
-    return probs
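The removed softmax helper above relies on the standard max-subtraction trick: subtracting the row-wise maximum before exponentiating is mathematically a no-op (the shift cancels during normalization) but keeps np.exp from overflowing on large logits. Note that since logits is rebound to its 2-D reshape, the final ndim check can no longer fire, so 1-D inputs come back with shape (1, n_classes). A minimal standalone sketch of the same trick (illustrative only, the stable_softmax name is not part of the package):

    import numpy as np

    def stable_softmax(logits: np.ndarray) -> np.ndarray:
        # Shift by the max along the last axis; exp() then stays in a safe range.
        shifted = logits - np.max(logits, axis=-1, keepdims=True)
        exp_shifted = np.exp(shifted)
        return exp_shifted / np.sum(exp_shifted, axis=-1, keepdims=True)

    # A naive np.exp([1000., 1001., 1002.]) overflows to inf; the shifted version does not.
    print(stable_softmax(np.array([1000.0, 1001.0, 1002.0])))  # ~[0.090, 0.245, 0.665]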
--- a/autogluon/tabular/models/tabpfnv2/tabpfnv2_model.py
+++ /dev/null
@@ -1,466 +0,0 @@
-"""
-Code Adapted from TabArena: https://github.com/autogluon/tabrepo/blob/main/tabrepo/benchmark/models/ag/tabpfnv2/tabpfnv2_model.py
-"""
-
-from __future__ import annotations
-
-import logging
-import warnings
-from pathlib import Path
-from typing import TYPE_CHECKING, Any
-
-import numpy as np
-import scipy
-from sklearn.preprocessing import PowerTransformer
-from typing_extensions import Self
-
-from autogluon.common.utils.resource_utils import ResourceManager
-from autogluon.tabular.models.abstract.abstract_torch_model import AbstractTorchModel
-from autogluon.features.generators import LabelEncoderFeatureGenerator
-from autogluon.tabular import __version__
-
-if TYPE_CHECKING:
-    import pandas as pd
-
-logger = logging.getLogger(__name__)
-
-_HAS_LOGGED_TABPFN_LICENSE: bool = False
-
-
-# TODO: merge into TabPFnv2 codebase
-class FixedSafePowerTransformer(PowerTransformer):
-    """Fixed version of safe power."""
-
-    def __init__(
-        self,
-        variance_threshold: float = 1e-3,
-        large_value_threshold: float = 100,
-        method="yeo-johnson",
-        standardize=True,
-        copy=True,
-    ):
-        super().__init__(method=method, standardize=standardize, copy=copy)
-        self.variance_threshold = variance_threshold
-        self.large_value_threshold = large_value_threshold
-
-        self.revert_indices_ = None
-
-    def _find_features_to_revert_because_of_failure(
-        self,
-        transformed_X: np.ndarray,
-    ) -> None:
-        # Calculate the variance for each feature in the transformed data
-        variances = np.nanvar(transformed_X, axis=0)
-
-        # Identify features where the variance is not close to 1
-        mask = np.abs(variances - 1) > self.variance_threshold
-        non_unit_variance_indices = np.where(mask)[0]
-
-        # Identify features with values greater than the large_value_threshold
-        large_value_indices = np.any(transformed_X > self.large_value_threshold, axis=0)
-        large_value_indices = np.nonzero(large_value_indices)[0]
-
-        # Identify features to revert based on either condition
-        self.revert_indices_ = np.unique(
-            np.concatenate([non_unit_variance_indices, large_value_indices]),
-        )
-
-    def _yeo_johnson_optimize(self, x: np.ndarray) -> float:
-        try:
-            with warnings.catch_warnings():
-                warnings.filterwarnings(
-                    "ignore",
-                    message=r"overflow encountered",
-                    category=RuntimeWarning,
-                )
-                return super()._yeo_johnson_optimize(x)  # type: ignore
-        except scipy.optimize._optimize.BracketError:
-            return np.nan
-
-    def _yeo_johnson_transform(self, x: np.ndarray, lmbda: float) -> np.ndarray:
-        if np.isnan(lmbda):
-            return x
-
-        return super()._yeo_johnson_transform(x, lmbda)  # type: ignore
-
-    def _revert_failed_features(
-        self,
-        transformed_X: np.ndarray,
-        original_X: np.ndarray,
-    ) -> np.ndarray:
-        # Replace these features with the original features
-        if self.revert_indices_ is not None and len(self.revert_indices_) > 0:
-            transformed_X[:, self.revert_indices_] = original_X[:, self.revert_indices_]
-
-        return transformed_X
-
-    def fit(self, X: np.ndarray, y: Any | None = None) -> FixedSafePowerTransformer:
-        super().fit(X, y)
-
-        # Check and revert features as necessary
-        self._find_features_to_revert_because_of_failure(super().transform(X))  # type: ignore
-        return self
-
-    def transform(self, X: np.ndarray) -> np.ndarray:
-        transformed_X = super().transform(X)
-        return self._revert_failed_features(transformed_X, X)  # type: ignore
-
-
-# FIXME: Need to take this logic into v6 for loading on CPU
-class TabPFNV2Model(AbstractTorchModel):
-    """
-    TabPFNv2 is a tabular foundation model pre-trained purely on synthetic data that achieves
-    state-of-the-art results with in-context learning on small datasets with <=10000 samples and <=500 features.
-    TabPFNv2 is developed and maintained by PriorLabs: https://priorlabs.ai/
-
-    TabPFNv2 is the top performing method for small datasets on TabArena-v0.1: https://tabarena.ai
-
-    Paper: Accurate predictions on small data with a tabular foundation model
-    Authors: Noah Hollmann, Samuel Müller, Lennart Purucker, Arjun Krishnakumar, Max Körfer, Shi Bin Hoo, Robin Tibor Schirrmeister & Frank Hutter
-    Codebase: https://github.com/PriorLabs/TabPFN
-    License: https://github.com/PriorLabs/TabPFN/blob/main/LICENSE
-
-    .. versionadded:: 1.4.0
-    """
-    ag_key = "TABPFNV2"
-    ag_name = "TabPFNv2"
-    ag_priority = 105
-    seed_name = "random_state"
-
-    def __init__(self, **kwargs):
-        super().__init__(**kwargs)
-        self._cached_model = False
-        self._feature_generator = None
-        self._cat_features = None
-        self._cat_indices = None
-
-    def _preprocess(self, X: pd.DataFrame, is_train=False, **kwargs) -> pd.DataFrame:
-        X = super()._preprocess(X, **kwargs)
-
-        if is_train:
-            self._cat_indices = []
-
-            # X will be the training data.
-            self._feature_generator = LabelEncoderFeatureGenerator(verbosity=0)
-            self._feature_generator.fit(X=X)
-
-        # This converts categorical features to numeric via stateful label encoding.
-        if self._feature_generator.features_in:
-            X = X.copy()
-            X[self._feature_generator.features_in] = self._feature_generator.transform(
-                X=X
-            )
-
-        if is_train:
-            # Detect/set cat features and indices
-            if self._cat_features is None:
-                self._cat_features = self._feature_generator.features_in[:]
-            self._cat_indices = [X.columns.get_loc(col) for col in self._cat_features]
-
-        return X
-
-    def _get_model_cls(self):
-        from tabpfn import TabPFNClassifier, TabPFNRegressor
-        is_classification = self.problem_type in ["binary", "multiclass"]
-        model_base = TabPFNClassifier if is_classification else TabPFNRegressor
-        return model_base
-
-    # FIXME: Crashes during model download if bagging with parallel fit.
-    # Consider adopting same download logic as TabPFNMix which doesn't crash during model download.
-    # FIXME: Maybe support child_oof somehow with using only one model and being smart about inference time?
-    def _fit(
-        self,
-        X: pd.DataFrame,
-        y: pd.Series,
-        num_cpus: int = 1,
-        num_gpus: int = 0,
-        verbosity: int = 2,
-        **kwargs,
-    ):
-        try:
-            from tabpfn.model import preprocessing
-        except ImportError as err:
-            logger.log(
-                40,
-                f"\tFailed to import tabpfn! To use the TabPFNv2 model, "
-                f"do: `pip install autogluon.tabular[tabpfn]=={__version__}`.",
-            )
-            raise err
-
-        preprocessing.SafePowerTransformer = FixedSafePowerTransformer
-
-        is_classification = self.problem_type in ["binary", "multiclass"]
-
-        model_base = self._get_model_cls()
-
-        from tabpfn.model.loading import resolve_model_path
-        from torch.cuda import is_available
-
-        device = "cuda" if num_gpus != 0 else "cpu"
-        if (device == "cuda") and (not is_available()):
-            # FIXME: warn instead and switch to CPU.
-            raise AssertionError(
-                "Fit specified to use GPU, but CUDA is not available on this machine. "
-                "Please switch to CPU usage instead.",
-            )
-
-        if verbosity >= 2:
-            # logs "Built with PriorLabs-TabPFN"
-            self._log_license(device=device)
-
-        X = self.preprocess(X, is_train=True)
-
-        hps = self._get_model_params()
-        hps["device"] = device
-        hps["n_jobs"] = num_cpus
-        hps["categorical_features_indices"] = self._cat_indices
-
-        _, model_dir, _, _ = resolve_model_path(
-            model_path=None,
-            which="classifier" if is_classification else "regressor",
-        )
-        if is_classification:
-            if "classification_model_path" in hps:
-                hps["model_path"] = model_dir / hps.pop("classification_model_path")
-            if "regression_model_path" in hps:
-                del hps["regression_model_path"]
-        else:
-            if "regression_model_path" in hps:
-                hps["model_path"] = model_dir / hps.pop("regression_model_path")
-            if "classification_model_path" in hps:
-                del hps["classification_model_path"]
-
-        # Resolve inference_config
-        inference_config = {
-            _k: v
-            for k, v in hps.items()
-            if k.startswith("inference_config/") and (_k := k.split("/")[-1])
-        }
-        if inference_config:
-            hps["inference_config"] = inference_config
-            for k in list(hps.keys()):
-                if k.startswith("inference_config/"):
-                    del hps[k]
-
-        # TODO: remove power from search space and TabPFNv2 codebase
-        # Power transform can fail. To avoid this, make all power be safepower instead.
-        if "PREPROCESS_TRANSFORMS" in inference_config:
-            safe_config = []
-            for preprocessing_dict in inference_config["PREPROCESS_TRANSFORMS"]:
-                if preprocessing_dict["name"] == "power":
-                    preprocessing_dict["name"] = "safepower"
-                safe_config.append(preprocessing_dict)
-            inference_config["PREPROCESS_TRANSFORMS"] = safe_config
-        if "REGRESSION_Y_PREPROCESS_TRANSFORMS" in inference_config:
-            safe_config = []
-            for preprocessing_name in inference_config[
-                "REGRESSION_Y_PREPROCESS_TRANSFORMS"
-            ]:
-                if preprocessing_name == "power":
-                    preprocessing_name = "safepower"
-                safe_config.append(preprocessing_name)
-            inference_config["REGRESSION_Y_PREPROCESS_TRANSFORMS"] = safe_config
-
-        # Resolve model_type
-        n_ensemble_repeats = hps.pop("n_ensemble_repeats", None)
-        model_is_rf_pfn = hps.pop("model_type", "no") == "dt_pfn"
-        if model_is_rf_pfn:
-            from .rfpfn import RandomForestTabPFNClassifier, RandomForestTabPFNRegressor
-
-            hps["n_estimators"] = 1
-            rf_model_base = (
-                RandomForestTabPFNClassifier
-                if is_classification
-                else RandomForestTabPFNRegressor
-            )
-            self.model = rf_model_base(
-                tabpfn=model_base(**hps),
-                categorical_features=self._cat_indices,
-                n_estimators=n_ensemble_repeats,
-            )
-        else:
-            if n_ensemble_repeats is not None:
-                hps["n_estimators"] = n_ensemble_repeats
-            self.model = model_base(**hps)
-
-        self.model = self.model.fit(
-            X=X,
-            y=y,
-        )
-
-    def get_device(self) -> str:
-        return self.model.device_.type
-
-    def _set_device(self, device: str):
-        pass  # TODO: Unknown how to properly set device for TabPFN after loading. Refer to `_set_device_tabpfn`.
-
-    # FIXME: This is not comprehensive. Need model authors to add an official API set_device
-    def _set_device_tabpfn(self, device: str):
-        import torch
-        # Move all torch components to the target device
-        device = self.to_torch_device(device)
-        self.model.device_ = device
-        if hasattr(self.model.executor_, "model") and self.model.executor_.model is not None:
-            self.model.executor_.model.to(self.model.device_)
-        if hasattr(self.model.executor_, "models"):
-            self.model.executor_.models = [m.to(self.model.device_) for m in self.model.executor_.models]
-
-        # Restore other potential torch objects from fitted_attrs
-        for key, value in vars(self.model).items():
-            if key.endswith("_") and hasattr(value, "to"):
-                setattr(self.model, key, value.to(self.model.device_))
-
-    def model_weights_path(self, path: str | None = None) -> Path:
-        if path is None:
-            path = self.path
-        return Path(path) / "config.tabpfn_fit"
-
-    def save(self, path: str = None, verbose=True) -> str:
-        _model = self.model
-        is_fit = self.is_fit()
-        if is_fit:
-            self._save_model_artifact(path=path)
-            self._cached_model = True
-            self.model = None
-        path = super().save(path=path, verbose=verbose)
-        if is_fit:
-            self.model = _model
-        return path
-
-    # TODO: It is required to do this because it is unknown how to otherwise save TabPFN in CPU-only mode.
-    # Even though we would generally prefer to save it in the pkl for better insurance
-    # that the model will work in future (self-contained)
-    def _save_model_artifact(self, path: str | None = None):
-        # save with CPU device so it can be loaded on a CPU only machine
-        device_og = self.device
-        self._set_device_tabpfn(device="cpu")
-        self.model.save_fit_state(path=self.model_weights_path(path=path))
-        self._set_device_tabpfn(device=device_og)
-
-    @classmethod
-    def load(cls, path: str, reset_paths=True, verbose=True) -> Self:
-        model = super().load(path=path, reset_paths=reset_paths, verbose=verbose)
-        if model._cached_model:
-            model._load_model_artifact()
-            model._cached_model = False
-        return model
-
-    def _load_model_artifact(self):
-        model_cls = self._get_model_cls()
-        device = self.suggest_device_infer()
-        self.model = model_cls.load_from_fit_state(path=self.model_weights_path(), device=device)
-        self.device = device
-
-    def _log_license(self, device: str):
-        global _HAS_LOGGED_TABPFN_LICENSE
-        if not _HAS_LOGGED_TABPFN_LICENSE:
-            logger.log(20, "\tBuilt with PriorLabs-TabPFN")  # Aligning with TabPFNv2 license requirements
-            if device == "cpu":
-                logger.log(
-                    20,
-                    "\tRunning TabPFNv2 on CPU. This can be very slow. "
-                    "It is recommended to run TabPFNv2 on a GPU."
-                )
-            _HAS_LOGGED_TABPFN_LICENSE = True  # Avoid repeated logging
-
-    def _get_default_resources(self) -> tuple[int, int]:
-        # Use only physical cores for better performance based on benchmarks
-        num_cpus = ResourceManager.get_cpu_count(only_physical_cores=True)
-
-        num_gpus = min(1, ResourceManager.get_gpu_count_torch(cuda_only=True))
-
-        return num_cpus, num_gpus
-
-    def _set_default_params(self):
-        default_params = {
-            "ignore_pretraining_limits": True,  # to ignore warnings and size limits
-        }
-        for param, val in default_params.items():
-            self._set_default_param_value(param, val)
-
-    @classmethod
-    def supported_problem_types(cls) -> list[str] | None:
-        return ["binary", "multiclass", "regression"]
-
-    def _get_default_auxiliary_params(self) -> dict:
-        default_auxiliary_params = super()._get_default_auxiliary_params()
-        default_auxiliary_params.update(
-            {
-                "max_rows": 10000,
-                "max_features": 500,
-                "max_classes": 10,
-                "max_batch_size": 10000,  # TabPFN seems to cryptically error if predicting on 100,000 samples.
-            }
-        )
-        return default_auxiliary_params
-
-    @classmethod
-    def _get_default_ag_args_ensemble(cls, **kwargs) -> dict:
-        """Set fold_fitting_strategy to sequential_local,
-        as parallel folding crashes if model weights aren't pre-downloaded.
-        """
-        default_ag_args_ensemble = super()._get_default_ag_args_ensemble(**kwargs)
-        extra_ag_args_ensemble = {
-            # FIXME: Find a work-around to avoid crash if parallel and weights are not downloaded
-            "fold_fitting_strategy": "sequential_local",
-            "refit_folds": True,  # Better to refit the model for faster inference and similar quality as the bag.
-        }
-        default_ag_args_ensemble.update(extra_ag_args_ensemble)
-        return default_ag_args_ensemble
-
-    def _estimate_memory_usage(self, X: pd.DataFrame, **kwargs) -> int:
-        hyperparameters = self._get_model_params()
-        return self.estimate_memory_usage_static(
-            X=X,
-            problem_type=self.problem_type,
-            num_classes=self.num_classes,
-            hyperparameters=hyperparameters,
-            **kwargs,
-        )
-
-    @classmethod
-    def _estimate_memory_usage_static(
-        cls,
-        *,
-        X: pd.DataFrame,
-        hyperparameters: dict | None = None,
-        **kwargs,
-    ) -> int:
-        """Heuristic memory estimate based on TabPFN's memory estimate logic in:
-        https://github.com/PriorLabs/TabPFN/blob/57a2efd3ebdb3886245e4d097cefa73a5261a969/src/tabpfn/model/memory.py#L147.
-
-        This is based on GPU memory usage, but hopefully with overheads it also approximates CPU memory usage.
-        """
-        # features_per_group = 2  # Based on TabPFNv2 default (unused)
-        n_layers = 12  # Based on TabPFNv2 default
-        embedding_size = 192  # Based on TabPFNv2 default
-        dtype_byte_size = 2  # Based on TabPFNv2 default
-
-        model_mem = 14489108  # Based on TabPFNv2 default
-
-        n_samples, n_features = X.shape[0], X.shape[1]
-        n_feature_groups = n_features + 1  # TODO: Unsure how to calculate this
-
-        X_mem = n_samples * n_feature_groups * dtype_byte_size
-        activation_mem = (
-            n_samples * n_feature_groups * embedding_size * n_layers * dtype_byte_size
-        )
-
-        baseline_overhead_mem_est = 1e9  # 1 GB generic overhead
-
-        # Add some buffer to each term + 1 GB overhead to be safe
-        return int(
-            model_mem + 4 * X_mem + 2 * activation_mem + baseline_overhead_mem_est
-        )
-
-    @classmethod
-    def _class_tags(cls):
-        return {
-            "can_estimate_memory_usage_static": True,
-            "can_set_device": True,
-            "set_device_on_save_to": None,
-            "set_device_on_load": False,
-        }
-
-    def _more_tags(self) -> dict:
-        return {"can_refit_full": True}