autogluon.tabular 1.3.2b20250610__py3-none-any.whl → 1.4.1b20251214__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. autogluon/tabular/configs/config_helper.py +1 -1
  2. autogluon/tabular/configs/hyperparameter_configs.py +2 -265
  3. autogluon/tabular/configs/pipeline_presets.py +130 -0
  4. autogluon/tabular/configs/presets_configs.py +51 -26
  5. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2023.py +0 -1
  6. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2025.py +310 -0
  7. autogluon/tabular/models/__init__.py +6 -1
  8. autogluon/tabular/models/_utils/rapids_utils.py +1 -1
  9. autogluon/tabular/models/automm/automm_model.py +2 -0
  10. autogluon/tabular/models/automm/ft_transformer.py +4 -1
  11. autogluon/tabular/models/catboost/callbacks.py +3 -2
  12. autogluon/tabular/models/catboost/catboost_model.py +15 -9
  13. autogluon/tabular/models/catboost/catboost_utils.py +17 -3
  14. autogluon/tabular/models/ebm/__init__.py +0 -0
  15. autogluon/tabular/models/ebm/ebm_model.py +259 -0
  16. autogluon/tabular/models/ebm/hyperparameters/__init__.py +0 -0
  17. autogluon/tabular/models/ebm/hyperparameters/parameters.py +39 -0
  18. autogluon/tabular/models/ebm/hyperparameters/searchspaces.py +72 -0
  19. autogluon/tabular/models/fastainn/tabular_nn_fastai.py +7 -5
  20. autogluon/tabular/models/knn/knn_model.py +7 -3
  21. autogluon/tabular/models/lgb/lgb_model.py +60 -21
  22. autogluon/tabular/models/lr/lr_model.py +6 -1
  23. autogluon/tabular/models/lr/lr_preprocessing_utils.py +6 -7
  24. autogluon/tabular/models/lr/lr_rapids_model.py +45 -5
  25. autogluon/tabular/models/mitra/__init__.py +0 -0
  26. autogluon/tabular/models/mitra/_internal/__init__.py +1 -0
  27. autogluon/tabular/models/mitra/_internal/config/__init__.py +1 -0
  28. autogluon/tabular/models/mitra/_internal/config/config_pretrain.py +190 -0
  29. autogluon/tabular/models/mitra/_internal/config/config_run.py +32 -0
  30. autogluon/tabular/models/mitra/_internal/config/enums.py +162 -0
  31. autogluon/tabular/models/mitra/_internal/core/__init__.py +1 -0
  32. autogluon/tabular/models/mitra/_internal/core/callbacks.py +94 -0
  33. autogluon/tabular/models/mitra/_internal/core/get_loss.py +54 -0
  34. autogluon/tabular/models/mitra/_internal/core/get_optimizer.py +108 -0
  35. autogluon/tabular/models/mitra/_internal/core/get_scheduler.py +67 -0
  36. autogluon/tabular/models/mitra/_internal/core/prediction_metrics.py +132 -0
  37. autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py +373 -0
  38. autogluon/tabular/models/mitra/_internal/data/__init__.py +1 -0
  39. autogluon/tabular/models/mitra/_internal/data/collator.py +46 -0
  40. autogluon/tabular/models/mitra/_internal/data/dataset_finetune.py +136 -0
  41. autogluon/tabular/models/mitra/_internal/data/dataset_split.py +57 -0
  42. autogluon/tabular/models/mitra/_internal/data/preprocessor.py +420 -0
  43. autogluon/tabular/models/mitra/_internal/models/__init__.py +1 -0
  44. autogluon/tabular/models/mitra/_internal/models/base.py +21 -0
  45. autogluon/tabular/models/mitra/_internal/models/embedding.py +182 -0
  46. autogluon/tabular/models/mitra/_internal/models/tab2d.py +667 -0
  47. autogluon/tabular/models/mitra/_internal/utils/__init__.py +1 -0
  48. autogluon/tabular/models/mitra/_internal/utils/set_seed.py +15 -0
  49. autogluon/tabular/models/mitra/mitra_model.py +380 -0
  50. autogluon/tabular/models/mitra/sklearn_interface.py +494 -0
  51. autogluon/tabular/models/realmlp/__init__.py +0 -0
  52. autogluon/tabular/models/realmlp/realmlp_model.py +360 -0
  53. autogluon/tabular/models/rf/rf_model.py +11 -6
  54. autogluon/tabular/models/tabicl/__init__.py +0 -0
  55. autogluon/tabular/models/tabicl/tabicl_model.py +179 -0
  56. autogluon/tabular/models/tabm/__init__.py +0 -0
  57. autogluon/tabular/models/tabm/_tabm_internal.py +545 -0
  58. autogluon/tabular/models/tabm/rtdl_num_embeddings.py +810 -0
  59. autogluon/tabular/models/tabm/tabm_model.py +356 -0
  60. autogluon/tabular/models/tabm/tabm_reference.py +631 -0
  61. autogluon/tabular/models/tabpfnmix/tabpfnmix_model.py +13 -7
  62. autogluon/tabular/models/tabpfnv2/__init__.py +0 -0
  63. autogluon/tabular/models/tabpfnv2/rfpfn/__init__.py +20 -0
  64. autogluon/tabular/models/tabpfnv2/rfpfn/configs.py +40 -0
  65. autogluon/tabular/models/tabpfnv2/rfpfn/scoring_utils.py +201 -0
  66. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_decision_tree_tabpfn.py +1464 -0
  67. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_random_forest_tabpfn.py +747 -0
  68. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_compat.py +863 -0
  69. autogluon/tabular/models/tabpfnv2/rfpfn/utils.py +106 -0
  70. autogluon/tabular/models/tabpfnv2/tabpfnv2_model.py +388 -0
  71. autogluon/tabular/models/tabular_nn/hyperparameters/parameters.py +1 -3
  72. autogluon/tabular/models/tabular_nn/torch/tabular_nn_torch.py +5 -5
  73. autogluon/tabular/models/xgboost/xgboost_model.py +10 -3
  74. autogluon/tabular/predictor/predictor.py +147 -84
  75. autogluon/tabular/registry/_ag_model_registry.py +12 -2
  76. autogluon/tabular/testing/fit_helper.py +57 -27
  77. autogluon/tabular/testing/generate_datasets.py +7 -0
  78. autogluon/tabular/trainer/abstract_trainer.py +3 -1
  79. autogluon/tabular/trainer/model_presets/presets.py +10 -1
  80. autogluon/tabular/version.py +1 -1
  81. autogluon.tabular-1.4.1b20251214-py3.11-nspkg.pth +1 -0
  82. {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/METADATA +112 -57
  83. {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/RECORD +89 -40
  84. {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/WHEEL +1 -1
  85. autogluon/tabular/models/tabpfn/__init__.py +0 -1
  86. autogluon/tabular/models/tabpfn/tabpfn_model.py +0 -153
  87. autogluon.tabular-1.3.2b20250610-py3.9-nspkg.pth +0 -1
  88. {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info/licenses}/LICENSE +0 -0
  89. {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info/licenses}/NOTICE +0 -0
  90. {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/namespace_packages.txt +0 -0
  91. {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/top_level.txt +0 -0
  92. {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/zip-safe +0 -0
autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2025.py
@@ -0,0 +1,310 @@
+# optimized for <=10000 samples and <=500 features, with a GPU present
+hyperparameter_portfolio_zeroshot_2025_small = {
+    "TABPFNV2": [
+        {
+            "ag_args": {'name_suffix': '_r143', 'priority': -1},
+            "average_before_softmax": False,
+            "classification_model_path": 'tabpfn-v2-classifier-od3j1g5m.ckpt',
+            "inference_config/FINGERPRINT_FEATURE": False,
+            "inference_config/OUTLIER_REMOVAL_STD": None,
+            "inference_config/POLYNOMIAL_FEATURES": 'no',
+            "inference_config/PREPROCESS_TRANSFORMS": [{'append_original': True, 'categorical_name': 'ordinal_very_common_categories_shuffled', 'global_transformer_name': None, 'name': 'safepower', 'subsample_features': -1}, {'append_original': True, 'categorical_name': 'ordinal_very_common_categories_shuffled', 'global_transformer_name': None, 'name': 'quantile_uni', 'subsample_features': -1}],
+            "inference_config/REGRESSION_Y_PREPROCESS_TRANSFORMS": [None, 'power'],
+            "inference_config/SUBSAMPLE_SAMPLES": 0.99,
+            "model_type": 'single',
+            "n_ensemble_repeats": 4,
+            "regression_model_path": 'tabpfn-v2-regressor-wyl4o83o.ckpt',
+            "softmax_temperature": 0.75,
+        },
+        {
+            "ag_args": {'name_suffix': '_r94', 'priority': -3},
+            "average_before_softmax": True,
+            "classification_model_path": 'tabpfn-v2-classifier-vutqq28w.ckpt',
+            "inference_config/FINGERPRINT_FEATURE": True,
+            "inference_config/OUTLIER_REMOVAL_STD": None,
+            "inference_config/POLYNOMIAL_FEATURES": 'no',
+            "inference_config/PREPROCESS_TRANSFORMS": [{'append_original': True, 'categorical_name': 'ordinal_very_common_categories_shuffled', 'global_transformer_name': None, 'name': 'quantile_uni', 'subsample_features': 0.99}],
+            "inference_config/REGRESSION_Y_PREPROCESS_TRANSFORMS": [None],
+            "inference_config/SUBSAMPLE_SAMPLES": None,
+            "model_type": 'single',
+            "n_ensemble_repeats": 4,
+            "regression_model_path": 'tabpfn-v2-regressor-5wof9ojf.ckpt',
+            "softmax_temperature": 0.9,
+        },
+        {
+            "ag_args": {'name_suffix': '_r181', 'priority': -4},
+            "average_before_softmax": False,
+            "classification_model_path": 'tabpfn-v2-classifier-llderlii.ckpt',
+            "inference_config/FINGERPRINT_FEATURE": False,
+            "inference_config/OUTLIER_REMOVAL_STD": 9.0,
+            "inference_config/POLYNOMIAL_FEATURES": 50,
+            "inference_config/PREPROCESS_TRANSFORMS": [{'append_original': True, 'categorical_name': 'onehot', 'global_transformer_name': 'svd', 'name': 'quantile_uni_coarse', 'subsample_features': 0.99}],
+            "inference_config/REGRESSION_Y_PREPROCESS_TRANSFORMS": ['power'],
+            "inference_config/SUBSAMPLE_SAMPLES": None,
+            "model_type": 'single',
+            "n_ensemble_repeats": 4,
+            "regression_model_path": 'tabpfn-v2-regressor.ckpt',
+            "softmax_temperature": 0.95,
+        },
+    ],
+    "GBM": [
+        {
+            "ag_args": {'name_suffix': '_r33', 'priority': -2},
+            "bagging_fraction": 0.9625293420216,
+            "bagging_freq": 1,
+            "cat_l2": 0.1236875455555,
+            "cat_smooth": 68.8584757332856,
+            "extra_trees": False,
+            "feature_fraction": 0.6189215809382,
+            "lambda_l1": 0.1641757352921,
+            "lambda_l2": 0.6937755557881,
+            "learning_rate": 0.0154031028561,
+            "max_cat_to_onehot": 17,
+            "min_data_in_leaf": 1,
+            "min_data_per_group": 30,
+            "num_leaves": 68,
+        },
+        {
+            "ag_args": {'name_suffix': '_r21', 'priority': -16},
+            "bagging_fraction": 0.7218730663234,
+            "bagging_freq": 1,
+            "cat_l2": 0.0296205152578,
+            "cat_smooth": 0.0010255271303,
+            "extra_trees": False,
+            "feature_fraction": 0.4557131604374,
+            "lambda_l1": 0.5219704038237,
+            "lambda_l2": 0.1070959487853,
+            "learning_rate": 0.0055891584996,
+            "max_cat_to_onehot": 71,
+            "min_data_in_leaf": 50,
+            "min_data_per_group": 10,
+            "num_leaves": 30,
+        },
+        {
+            "ag_args": {'name_suffix': '_r11', 'priority': -19},
+            "bagging_fraction": 0.775784726514,
+            "bagging_freq": 1,
+            "cat_l2": 0.3888471449178,
+            "cat_smooth": 0.0057144748021,
+            "extra_trees": True,
+            "feature_fraction": 0.7732354787904,
+            "lambda_l1": 0.2211002452568,
+            "lambda_l2": 1.1318405980187,
+            "learning_rate": 0.0090151778542,
+            "max_cat_to_onehot": 15,
+            "min_data_in_leaf": 4,
+            "min_data_per_group": 15,
+            "num_leaves": 2,
+        },
+    ],
+    "CAT": [
+        {
+            "ag_args": {'priority': -5},
+        },
+        {
+            "ag_args": {'name_suffix': '_r51', 'priority': -10},
+            "boosting_type": 'Plain',
+            "bootstrap_type": 'Bernoulli',
+            "colsample_bylevel": 0.8771035272558,
+            "depth": 7,
+            "grow_policy": 'SymmetricTree',
+            "l2_leaf_reg": 2.0107286863021,
+            "leaf_estimation_iterations": 2,
+            "learning_rate": 0.0058424016622,
+            "max_bin": 254,
+            "max_ctr_complexity": 4,
+            "model_size_reg": 0.1307400355809,
+            "one_hot_max_size": 23,
+            "subsample": 0.809527841437,
+        },
+        {
+            "ag_args": {'name_suffix': '_r10', 'priority': -12},
+            "boosting_type": 'Plain',
+            "bootstrap_type": 'Bernoulli',
+            "colsample_bylevel": 0.8994502668431,
+            "depth": 6,
+            "grow_policy": 'Depthwise',
+            "l2_leaf_reg": 1.8187025215896,
+            "leaf_estimation_iterations": 7,
+            "learning_rate": 0.005177304142,
+            "max_bin": 254,
+            "max_ctr_complexity": 4,
+            "model_size_reg": 0.5247386875068,
+            "one_hot_max_size": 53,
+            "subsample": 0.8705228845742,
+        },
+        {
+            "ag_args": {'name_suffix': '_r24', 'priority': -15},
+            "boosting_type": 'Plain',
+            "bootstrap_type": 'Bernoulli',
+            "colsample_bylevel": 0.8597809376276,
+            "depth": 8,
+            "grow_policy": 'Depthwise',
+            "l2_leaf_reg": 0.3628261923976,
+            "leaf_estimation_iterations": 5,
+            "learning_rate": 0.016851077771,
+            "max_bin": 254,
+            "max_ctr_complexity": 4,
+            "model_size_reg": 0.1253820547902,
+            "one_hot_max_size": 20,
+            "subsample": 0.8120271122061,
+        },
+        {
+            "ag_args": {'name_suffix': '_r91', 'priority': -17},
+            "boosting_type": 'Plain',
+            "bootstrap_type": 'Bernoulli',
+            "colsample_bylevel": 0.8959275863514,
+            "depth": 4,
+            "grow_policy": 'SymmetricTree',
+            "l2_leaf_reg": 0.0026915894253,
+            "leaf_estimation_iterations": 12,
+            "learning_rate": 0.0475233791203,
+            "max_bin": 254,
+            "max_ctr_complexity": 5,
+            "model_size_reg": 0.1633175256924,
+            "one_hot_max_size": 11,
+            "subsample": 0.798554178926,
+        },
+    ],
+    "TABM": [
+        {
+            "ag_args": {'name_suffix': '_r184', 'priority': -6},
+            "amp": False,
+            "arch_type": 'tabm-mini',
+            "batch_size": 'auto',
+            "d_block": 864,
+            "d_embedding": 24,
+            "dropout": 0.0,
+            "gradient_clipping_norm": 1.0,
+            "lr": 0.0019256819924656217,
+            "n_blocks": 3,
+            "num_emb_n_bins": 3,
+            "num_emb_type": 'pwl',
+            "patience": 16,
+            "share_training_batches": False,
+            "tabm_k": 32,
+            "weight_decay": 0.0,
+        },
+        {
+            "ag_args": {'name_suffix': '_r69', 'priority': -7},
+            "amp": False,
+            "arch_type": 'tabm-mini',
+            "batch_size": 'auto',
+            "d_block": 848,
+            "d_embedding": 28,
+            "dropout": 0.40215621636031007,
+            "gradient_clipping_norm": 1.0,
+            "lr": 0.0010413640454559532,
+            "n_blocks": 3,
+            "num_emb_n_bins": 18,
+            "num_emb_type": 'pwl',
+            "patience": 16,
+            "share_training_batches": False,
+            "tabm_k": 32,
+            "weight_decay": 0.0,
+        },
+        {
+            "ag_args": {'name_suffix': '_r52', 'priority': -11},
+            "amp": False,
+            "arch_type": 'tabm-mini',
+            "batch_size": 'auto',
+            "d_block": 1024,
+            "d_embedding": 32,
+            "dropout": 0.0,
+            "gradient_clipping_norm": 1.0,
+            "lr": 0.0006297851297842611,
+            "n_blocks": 4,
+            "num_emb_n_bins": 22,
+            "num_emb_type": 'pwl',
+            "patience": 16,
+            "share_training_batches": False,
+            "tabm_k": 32,
+            "weight_decay": 0.06900108498839816,
+        },
+        {
+            "ag_args": {'priority': -13},
+        },
+        {
+            "ag_args": {'name_suffix': '_r191', 'priority': -14},
+            "amp": False,
+            "arch_type": 'tabm-mini',
+            "batch_size": 'auto',
+            "d_block": 864,
+            "d_embedding": 8,
+            "dropout": 0.45321529282058803,
+            "gradient_clipping_norm": 1.0,
+            "lr": 0.0003781238075322413,
+            "n_blocks": 4,
+            "num_emb_n_bins": 27,
+            "num_emb_type": 'pwl',
+            "patience": 16,
+            "share_training_batches": False,
+            "tabm_k": 32,
+            "weight_decay": 0.01766851962579851,
+        },
+        {
+            "ag_args": {'name_suffix': '_r49', 'priority': -20},
+            "amp": False,
+            "arch_type": 'tabm-mini',
+            "batch_size": 'auto',
+            "d_block": 640,
+            "d_embedding": 28,
+            "dropout": 0.15296207419190627,
+            "gradient_clipping_norm": 1.0,
+            "lr": 0.002277678490593717,
+            "n_blocks": 3,
+            "num_emb_n_bins": 48,
+            "num_emb_type": 'pwl',
+            "patience": 16,
+            "share_training_batches": False,
+            "tabm_k": 32,
+            "weight_decay": 0.0578159148243893,
+        },
+    ],
+    "TABICL": [
+        {
+            "ag_args": {'priority': -8},
+        },
+    ],
+    "XGB": [
+        {
+            "ag_args": {'name_suffix': '_r171', 'priority': -9},
+            "colsample_bylevel": 0.9213705632288,
+            "colsample_bynode": 0.6443385965381,
+            "enable_categorical": True,
+            "grow_policy": 'lossguide',
+            "learning_rate": 0.0068171645251,
+            "max_cat_to_onehot": 8,
+            "max_depth": 6,
+            "max_leaves": 10,
+            "min_child_weight": 0.0507304250576,
+            "reg_alpha": 4.2446346389037,
+            "reg_lambda": 1.4800570021253,
+            "subsample": 0.9656290596647,
+        },
+        {
+            "ag_args": {'name_suffix': '_r40', 'priority': -18},
+            "colsample_bylevel": 0.6377491713202,
+            "colsample_bynode": 0.9237625621103,
+            "enable_categorical": True,
+            "grow_policy": 'lossguide',
+            "learning_rate": 0.0112462621131,
+            "max_cat_to_onehot": 33,
+            "max_depth": 10,
+            "max_leaves": 35,
+            "min_child_weight": 0.1403464856034,
+            "reg_alpha": 3.4960653958503,
+            "reg_lambda": 1.3062320805235,
+            "subsample": 0.6948898835178,
+        },
+    ],
+    "MITRA": [
+        {
+            "n_estimators": 1,
+            "fine_tune": True,
+            "fine_tune_steps": 50,
+            "ag.num_gpus": 1,
+            "ag_args": {'priority': -21},
+        },
+    ],
+}
autogluon/tabular/models/__init__.py
@@ -3,6 +3,7 @@ from autogluon.core.models.abstract.abstract_model import AbstractModel
 from .automm.automm_model import MultiModalPredictorModel
 from .automm.ft_transformer import FTTransformerModel
 from .catboost.catboost_model import CatBoostModel
+from .ebm.ebm_model import EBMModel
 from .fastainn.tabular_nn_fastai import NNFastAiTabularModel
 from .fasttext.fasttext_model import FastTextModel
 from .image_prediction.image_predictor import ImagePredictorModel
@@ -17,9 +18,13 @@ from .imodels.imodels_models import (
 from .knn.knn_model import KNNModel
 from .lgb.lgb_model import LGBModel
 from .lr.lr_model import LinearModel
+from .realmlp.realmlp_model import RealMLPModel
 from .rf.rf_model import RFModel
-from .tabpfn.tabpfn_model import TabPFNModel
+from .tabicl.tabicl_model import TabICLModel
+from .tabm.tabm_model import TabMModel
+from .tabpfnv2.tabpfnv2_model import TabPFNV2Model
 from .tabpfnmix.tabpfnmix_model import TabPFNMixModel
+from .mitra.mitra_model import MitraModel
 from .tabular_nn.torch.tabular_nn_torch import TabularNeuralNetTorchModel
 from .text_prediction.text_prediction_v1_model import TextPredictorModel
 from .xgboost.xgboost_model import XGBoostModel
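Net effect: six new model classes become importable from autogluon.tabular.models, and the removed TabPFN (v1) import is gone. A quick sanity check, using the class names exactly as added above:

# Class names taken verbatim from the import lines above.
from autogluon.tabular.models import (
    EBMModel,
    MitraModel,
    RealMLPModel,
    TabICLModel,
    TabMModel,
    TabPFNV2Model,
)

The matching hyperparameter keys for four of these ("TABPFNV2", "TABM", "TABICL", "MITRA") appear in the 2025 zeroshot portfolio earlier in this diff.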
autogluon/tabular/models/_utils/rapids_utils.py
@@ -10,7 +10,7 @@ class RapidsModelMixin:
     @classmethod
     def _get_default_ag_args_ensemble(cls, **kwargs) -> dict:
         default_ag_args_ensemble = super()._get_default_ag_args_ensemble(**kwargs)
-        extra_ag_args_ensemble = {"use_child_oof": False}
+        extra_ag_args_ensemble = {"use_child_oof": False, "fold_fitting_strategy": "sequential_local"}
         default_ag_args_ensemble.update(extra_ag_args_ensemble)
         return default_ag_args_ensemble
 
autogluon/tabular/models/automm/automm_model.py
@@ -65,6 +65,8 @@ class MultiModalPredictorModel(AbstractModel):
             Names of the features.
         feature_metadata
             The feature metadata.
+
+        .. versionadded:: 0.3.0
         """
         super().__init__(**kwargs)
         self._label_column_name = None
autogluon/tabular/models/automm/ft_transformer.py
@@ -17,7 +17,8 @@ class FTTransformerModel(MultiModalPredictorModel):
     ag_name = "FTTransformer"
 
     def __init__(self, **kwargs):
-        """Wrapper of autogluon.multimodal.MultiModalPredictor.
+        """
+        FT-Transformer model.
 
         The features can be a mix of
         - categorical column
@@ -48,6 +49,8 @@ class FTTransformerModel(MultiModalPredictorModel):
             Names of the features.
         feature_metadata
             The feature metadata.
+
+        .. versionadded:: 0.6.0
         """
         super().__init__(**kwargs)
 
autogluon/tabular/models/catboost/callbacks.py
@@ -170,14 +170,15 @@ class EarlyStoppingCallback:
 
         self.eval_metric_name = eval_metric_name
         self.is_max_optimal = is_max_optimal
-        self.is_quantile = self.eval_metric_name.startswith(CATBOOST_QUANTILE_PREFIX)
+        self.is_quantile = CATBOOST_QUANTILE_PREFIX in self.eval_metric_name
 
     def after_iteration(self, info):
         is_best_iter = False
         if self.is_quantile:
             # FIXME: CatBoost adds extra ',' in the metric name if quantile levels are not balanced
             # e.g., 'MultiQuantile:alpha=0.1,0.25,0.5,0.95' becomes 'MultiQuantile:alpha=0.1,,0.25,0.5,0.95'
-            eval_metric_name = [k for k in info.metrics[self.compare_key] if k.startswith(CATBOOST_QUANTILE_PREFIX)][0]
+            # `'Quantile:' in k` catches both multiquantile (MultiQuantile:) and single-quantile mode (Quantile:)
+            eval_metric_name = [k for k in info.metrics[self.compare_key] if CATBOOST_QUANTILE_PREFIX in k][0]
         else:
             eval_metric_name = self.eval_metric_name
         cur_score = info.metrics[self.compare_key][eval_metric_name][-1]
autogluon/tabular/models/catboost/catboost_model.py
@@ -13,13 +13,13 @@ from autogluon.common.features.types import R_BOOL, R_CATEGORY, R_FLOAT, R_INT
 from autogluon.common.utils.pandas_utils import get_approximate_df_mem_usage
 from autogluon.common.utils.resource_utils import ResourceManager
 from autogluon.common.utils.try_import import try_import_catboost
-from autogluon.core.constants import MULTICLASS, PROBLEM_TYPES_CLASSIFICATION, QUANTILE, SOFTCLASS
+from autogluon.core.constants import MULTICLASS, PROBLEM_TYPES_CLASSIFICATION, REGRESSION, QUANTILE, SOFTCLASS
 from autogluon.core.models import AbstractModel
 from autogluon.core.models._utils import get_early_stopping_rounds
 from autogluon.core.utils.exceptions import TimeLimitExceeded
 
 from .callbacks import EarlyStoppingCallback, MemoryCheckCallback, TimeCheckCallback
-from .catboost_utils import get_catboost_metric_from_ag_metric
+from .catboost_utils import get_catboost_metric_from_ag_metric, CATBOOST_EVAL_METRIC_TO_LOSS_FUNCTION
 from .hyperparameters.parameters import get_param_baseline
 from .hyperparameters.searchspaces import get_default_searchspace
 
@@ -39,6 +39,7 @@ class CatBoostModel(AbstractModel):
     ag_priority_by_problem_type = MappingProxyType({
         SOFTCLASS: 60
     })
+    seed_name = "random_seed"
 
     def __init__(self, **kwargs):
         super().__init__(**kwargs)
@@ -48,7 +49,6 @@ class CatBoostModel(AbstractModel):
         default_params = get_param_baseline(problem_type=self.problem_type)
         for param, val in default_params.items():
             self._set_default_param_value(param, val)
-        self._set_default_param_value("random_seed", 0)  # Remove randomness for reproducibility
         # Set 'allow_writing_files' to True in order to keep log files created by catboost during training (these will be saved in the directory where AutoGluon stores this model)
         self._set_default_param_value("allow_writing_files", False)  # Disables creation of catboost logging files during training by default
         if self.problem_type != SOFTCLASS:  # TODO: remove this after catboost 0.24
@@ -126,16 +126,20 @@ class CatBoostModel(AbstractModel):
 
         ag_params = self._get_ag_params()
         params = self._get_model_params()
+
         params["thread_count"] = num_cpus
         if self.problem_type == SOFTCLASS:
             # FIXME: This is extremely slow due to unoptimized metric / objective sent to CatBoost
             from .catboost_softclass_utils import SoftclassCustomMetric, SoftclassObjective
 
-            params["loss_function"] = SoftclassObjective.SoftLogLossObjective()
+            params.setdefault("loss_function", SoftclassObjective.SoftLogLossObjective())
             params["eval_metric"] = SoftclassCustomMetric.SoftLogLossMetric()
-        elif self.problem_type == QUANTILE:
-            # FIXME: Unless specified, CatBoost defaults to loss_function='MultiQuantile' and raises an exception
-            params["loss_function"] = params["eval_metric"]
+        elif self.problem_type in [REGRESSION, QUANTILE]:
+            # Choose appropriate loss_function that is as close as possible to the eval_metric
+            params.setdefault(
+                "loss_function",
+                CATBOOST_EVAL_METRIC_TO_LOSS_FUNCTION.get(params["eval_metric"], params["eval_metric"])
+            )
 
         model_type = CatBoostClassifier if self.problem_type in PROBLEM_TYPES_CLASSIFICATION else CatBoostRegressor
         num_rows_train = len(X)
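Two behavior changes here are worth spelling out: `setdefault` means a user-supplied `loss_function` is no longer overwritten, and regression eval metrics that CatBoost cannot optimize directly now fall back to a nearby optimizable loss via `CATBOOST_EVAL_METRIC_TO_LOSS_FUNCTION` (added in catboost_utils.py later in this diff). A standalone sketch of the resolution logic; the `Huber:delta=1.0` value is only an illustrative user choice:

# Standalone illustration; the mapping mirrors catboost_utils.py below.
CATBOOST_EVAL_METRIC_TO_LOSS_FUNCTION = {
    "MedianAbsoluteError": "MAE",
    "SMAPE": "MAPE",
    "R2": "RMSE",
}

def resolve_loss_function(params: dict) -> dict:
    # setdefault: an explicit user-provided loss_function wins over the fallback
    params.setdefault(
        "loss_function",
        CATBOOST_EVAL_METRIC_TO_LOSS_FUNCTION.get(params["eval_metric"], params["eval_metric"]),
    )
    return params

print(resolve_loss_function({"eval_metric": "R2"})["loss_function"])   # RMSE
print(resolve_loss_function({"eval_metric": "MAE"})["loss_function"])  # MAE
print(resolve_loss_function({"eval_metric": "R2", "loss_function": "Huber:delta=1.0"})["loss_function"])  # Huber:delta=1.0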
@@ -307,6 +311,8 @@ class CatBoostModel(AbstractModel):
         max_memory_iters = math.floor(available_mem * max_memory_proportion / mem_usage_per_iter)
 
         final_iters = min(default_iters, min(max_memory_iters, estimated_iters_in_time))
+        if final_iters < 1:
+            raise TimeLimitExceeded
         return final_iters
 
     def _predict_proba(self, X, **kwargs):
@@ -350,8 +356,8 @@ class CatBoostModel(AbstractModel):
         return minimum_resources
 
     def _get_default_resources(self):
-        # logical=False is faster in training
-        num_cpus = ResourceManager.get_cpu_count_psutil(logical=False)
+        # only_physical_cores=True is faster in training
+        num_cpus = ResourceManager.get_cpu_count(only_physical_cores=True)
         num_gpus = 0
         return num_cpus, num_gpus
 
autogluon/tabular/models/catboost/catboost_utils.py
@@ -5,7 +5,14 @@ from autogluon.core.constants import BINARY, MULTICLASS, QUANTILE, REGRESSION, SOFTCLASS
 logger = logging.getLogger(__name__)
 
 
-CATBOOST_QUANTILE_PREFIX = "MultiQuantile:"
+CATBOOST_QUANTILE_PREFIX = "Quantile:"
+# Mapping from non-optimizable eval_metric to optimizable loss_function.
+# See https://catboost.ai/docs/en/concepts/loss-functions-regression#usage-information
+CATBOOST_EVAL_METRIC_TO_LOSS_FUNCTION = {
+    "MedianAbsoluteError": "MAE",
+    "SMAPE": "MAPE",
+    "R2": "RMSE",
+}
 
 
 # TODO: Add weight support?
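Relaxing the prefix to "Quantile:" and matching with `in` (as the callback change above does) covers both the single-quantile and multi-quantile metric names, which `startswith` did not. A short demonstration:

CATBOOST_QUANTILE_PREFIX = "Quantile:"
for name in ("Quantile:alpha=0.5", "MultiQuantile:alpha=0.1,0.5,0.9"):
    print(name.startswith(CATBOOST_QUANTILE_PREFIX), CATBOOST_QUANTILE_PREFIX in name)
# Quantile:alpha=0.5              -> True, True
# MultiQuantile:alpha=0.1,0.5,0.9 -> False, True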
@@ -65,7 +72,10 @@ def get_catboost_metric_from_ag_metric(metric, problem_type, quantile_levels=None):
             mean_squared_error="RMSE",
             root_mean_squared_error="RMSE",
             mean_absolute_error="MAE",
+            mean_absolute_percentage_error="MAPE",
+            # Non-optimizable metrics, see CATBOOST_EVAL_METRIC_TO_LOSS_FUNCTION
             median_absolute_error="MedianAbsoluteError",
+            symmetric_mean_absolute_percentage_error="SMAPE",
             r2="R2",
         )
         metric_class = metric_map.get(metric.name, "RMSE")
@@ -74,8 +84,12 @@ def get_catboost_metric_from_ag_metric(metric, problem_type, quantile_levels=None):
             raise AssertionError(f"quantile_levels must be provided for problem_type = {problem_type}")
         if not all(0 < q < 1 for q in quantile_levels):
             raise AssertionError(f"quantile_levels must fulfill 0 < q < 1, provided quantile_levels: {quantile_levels}")
-        quantile_string = ",".join(str(q) for q in quantile_levels)
-        metric_class = f"{CATBOOST_QUANTILE_PREFIX}alpha={quantile_string}"
+        # Loss function MultiQuantile: can only be used if len(quantile_levels) >= 2, otherwise we must use Quantile:
+        if len(quantile_levels) == 1:
+            metric_class = f"{CATBOOST_QUANTILE_PREFIX}alpha={quantile_levels[0]}"
+        else:
+            quantile_string = ",".join(str(q) for q in quantile_levels)
+            metric_class = f"Multi{CATBOOST_QUANTILE_PREFIX}alpha={quantile_string}"
     else:
         raise AssertionError(f"CatBoost does not support {problem_type} problem type.")
 
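Restated as a standalone function mirroring the branch above, the new logic yields the following metric strings; per the diff comment, MultiQuantile: is only valid with two or more levels, so a single level must use Quantile::

CATBOOST_QUANTILE_PREFIX = "Quantile:"

def quantile_metric_class(quantile_levels):
    # MultiQuantile: requires >= 2 levels; a single level uses Quantile:
    if len(quantile_levels) == 1:
        return f"{CATBOOST_QUANTILE_PREFIX}alpha={quantile_levels[0]}"
    quantile_string = ",".join(str(q) for q in quantile_levels)
    return f"Multi{CATBOOST_QUANTILE_PREFIX}alpha={quantile_string}"

print(quantile_metric_class([0.5]))             # Quantile:alpha=0.5
print(quantile_metric_class([0.25, 0.5, 0.75])) # MultiQuantile:alpha=0.25,0.5,0.75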