autogluon.tabular 1.3.2b20250722__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- autogluon/tabular/configs/config_helper.py +1 -1
- autogluon/tabular/configs/hyperparameter_configs.py +2 -265
- autogluon/tabular/configs/presets_configs.py +51 -23
- autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2023.py +0 -1
- autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2025.py +310 -0
- autogluon/tabular/models/automm/automm_model.py +2 -0
- autogluon/tabular/models/automm/ft_transformer.py +4 -1
- autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py +18 -6
- autogluon/tabular/models/mitra/_internal/data/dataset_finetune.py +8 -4
- autogluon/tabular/models/mitra/_internal/data/dataset_split.py +5 -1
- autogluon/tabular/models/mitra/_internal/models/tab2d.py +3 -0
- autogluon/tabular/models/mitra/mitra_model.py +74 -21
- autogluon/tabular/models/mitra/sklearn_interface.py +15 -13
- autogluon/tabular/models/realmlp/realmlp_model.py +13 -6
- autogluon/tabular/models/tabicl/tabicl_model.py +17 -8
- autogluon/tabular/models/tabm/rtdl_num_embeddings.py +3 -0
- autogluon/tabular/models/tabm/tabm_model.py +14 -6
- autogluon/tabular/models/tabm/tabm_reference.py +2 -0
- autogluon/tabular/models/tabpfnmix/tabpfnmix_model.py +4 -0
- autogluon/tabular/models/tabpfnv2/tabpfnv2_model.py +29 -12
- autogluon/tabular/predictor/predictor.py +79 -26
- autogluon/tabular/trainer/abstract_trainer.py +2 -0
- autogluon/tabular/version.py +1 -1
- {autogluon.tabular-1.3.2b20250722.dist-info → autogluon.tabular-1.4.0.dist-info}/METADATA +42 -20
- {autogluon.tabular-1.3.2b20250722.dist-info → autogluon.tabular-1.4.0.dist-info}/RECORD +32 -31
- /autogluon.tabular-1.3.2b20250722-py3.9-nspkg.pth → /autogluon.tabular-1.4.0-py3.9-nspkg.pth +0 -0
- {autogluon.tabular-1.3.2b20250722.dist-info → autogluon.tabular-1.4.0.dist-info}/LICENSE +0 -0
- {autogluon.tabular-1.3.2b20250722.dist-info → autogluon.tabular-1.4.0.dist-info}/NOTICE +0 -0
- {autogluon.tabular-1.3.2b20250722.dist-info → autogluon.tabular-1.4.0.dist-info}/WHEEL +0 -0
- {autogluon.tabular-1.3.2b20250722.dist-info → autogluon.tabular-1.4.0.dist-info}/namespace_packages.txt +0 -0
- {autogluon.tabular-1.3.2b20250722.dist-info → autogluon.tabular-1.4.0.dist-info}/top_level.txt +0 -0
- {autogluon.tabular-1.3.2b20250722.dist-info → autogluon.tabular-1.4.0.dist-info}/zip-safe +0 -0
|
@@ -0,0 +1,310 @@
|
|
|
1
|
+
# optimized for <=10000 samples and <=500 features, with a GPU present
|
|
2
|
+
hyperparameter_portfolio_zeroshot_2025_small = {
|
|
3
|
+
"TABPFNV2": [
|
|
4
|
+
{
|
|
5
|
+
"ag_args": {'name_suffix': '_r143', 'priority': -1},
|
|
6
|
+
"average_before_softmax": False,
|
|
7
|
+
"classification_model_path": 'tabpfn-v2-classifier-od3j1g5m.ckpt',
|
|
8
|
+
"inference_config/FINGERPRINT_FEATURE": False,
|
|
9
|
+
"inference_config/OUTLIER_REMOVAL_STD": None,
|
|
10
|
+
"inference_config/POLYNOMIAL_FEATURES": 'no',
|
|
11
|
+
"inference_config/PREPROCESS_TRANSFORMS": [{'append_original': True, 'categorical_name': 'ordinal_very_common_categories_shuffled', 'global_transformer_name': None, 'name': 'safepower', 'subsample_features': -1}, {'append_original': True, 'categorical_name': 'ordinal_very_common_categories_shuffled', 'global_transformer_name': None, 'name': 'quantile_uni', 'subsample_features': -1}],
|
|
12
|
+
"inference_config/REGRESSION_Y_PREPROCESS_TRANSFORMS": [None, 'power'],
|
|
13
|
+
"inference_config/SUBSAMPLE_SAMPLES": 0.99,
|
|
14
|
+
"model_type": 'single',
|
|
15
|
+
"n_ensemble_repeats": 4,
|
|
16
|
+
"regression_model_path": 'tabpfn-v2-regressor-wyl4o83o.ckpt',
|
|
17
|
+
"softmax_temperature": 0.75,
|
|
18
|
+
},
|
|
19
|
+
{
|
|
20
|
+
"ag_args": {'name_suffix': '_r94', 'priority': -3},
|
|
21
|
+
"average_before_softmax": True,
|
|
22
|
+
"classification_model_path": 'tabpfn-v2-classifier-vutqq28w.ckpt',
|
|
23
|
+
"inference_config/FINGERPRINT_FEATURE": True,
|
|
24
|
+
"inference_config/OUTLIER_REMOVAL_STD": None,
|
|
25
|
+
"inference_config/POLYNOMIAL_FEATURES": 'no',
|
|
26
|
+
"inference_config/PREPROCESS_TRANSFORMS": [{'append_original': True, 'categorical_name': 'ordinal_very_common_categories_shuffled', 'global_transformer_name': None, 'name': 'quantile_uni', 'subsample_features': 0.99}],
|
|
27
|
+
"inference_config/REGRESSION_Y_PREPROCESS_TRANSFORMS": [None],
|
|
28
|
+
"inference_config/SUBSAMPLE_SAMPLES": None,
|
|
29
|
+
"model_type": 'single',
|
|
30
|
+
"n_ensemble_repeats": 4,
|
|
31
|
+
"regression_model_path": 'tabpfn-v2-regressor-5wof9ojf.ckpt',
|
|
32
|
+
"softmax_temperature": 0.9,
|
|
33
|
+
},
|
|
34
|
+
{
|
|
35
|
+
"ag_args": {'name_suffix': '_r181', 'priority': -4},
|
|
36
|
+
"average_before_softmax": False,
|
|
37
|
+
"classification_model_path": 'tabpfn-v2-classifier-llderlii.ckpt',
|
|
38
|
+
"inference_config/FINGERPRINT_FEATURE": False,
|
|
39
|
+
"inference_config/OUTLIER_REMOVAL_STD": 9.0,
|
|
40
|
+
"inference_config/POLYNOMIAL_FEATURES": 50,
|
|
41
|
+
"inference_config/PREPROCESS_TRANSFORMS": [{'append_original': True, 'categorical_name': 'onehot', 'global_transformer_name': 'svd', 'name': 'quantile_uni_coarse', 'subsample_features': 0.99}],
|
|
42
|
+
"inference_config/REGRESSION_Y_PREPROCESS_TRANSFORMS": ['power'],
|
|
43
|
+
"inference_config/SUBSAMPLE_SAMPLES": None,
|
|
44
|
+
"model_type": 'single',
|
|
45
|
+
"n_ensemble_repeats": 4,
|
|
46
|
+
"regression_model_path": 'tabpfn-v2-regressor.ckpt',
|
|
47
|
+
"softmax_temperature": 0.95,
|
|
48
|
+
},
|
|
49
|
+
],
|
|
50
|
+
"GBM": [
|
|
51
|
+
{
|
|
52
|
+
"ag_args": {'name_suffix': '_r33', 'priority': -2},
|
|
53
|
+
"bagging_fraction": 0.9625293420216,
|
|
54
|
+
"bagging_freq": 1,
|
|
55
|
+
"cat_l2": 0.1236875455555,
|
|
56
|
+
"cat_smooth": 68.8584757332856,
|
|
57
|
+
"extra_trees": False,
|
|
58
|
+
"feature_fraction": 0.6189215809382,
|
|
59
|
+
"lambda_l1": 0.1641757352921,
|
|
60
|
+
"lambda_l2": 0.6937755557881,
|
|
61
|
+
"learning_rate": 0.0154031028561,
|
|
62
|
+
"max_cat_to_onehot": 17,
|
|
63
|
+
"min_data_in_leaf": 1,
|
|
64
|
+
"min_data_per_group": 30,
|
|
65
|
+
"num_leaves": 68,
|
|
66
|
+
},
|
|
67
|
+
{
|
|
68
|
+
"ag_args": {'name_suffix': '_r21', 'priority': -16},
|
|
69
|
+
"bagging_fraction": 0.7218730663234,
|
|
70
|
+
"bagging_freq": 1,
|
|
71
|
+
"cat_l2": 0.0296205152578,
|
|
72
|
+
"cat_smooth": 0.0010255271303,
|
|
73
|
+
"extra_trees": False,
|
|
74
|
+
"feature_fraction": 0.4557131604374,
|
|
75
|
+
"lambda_l1": 0.5219704038237,
|
|
76
|
+
"lambda_l2": 0.1070959487853,
|
|
77
|
+
"learning_rate": 0.0055891584996,
|
|
78
|
+
"max_cat_to_onehot": 71,
|
|
79
|
+
"min_data_in_leaf": 50,
|
|
80
|
+
"min_data_per_group": 10,
|
|
81
|
+
"num_leaves": 30,
|
|
82
|
+
},
|
|
83
|
+
{
|
|
84
|
+
"ag_args": {'name_suffix': '_r11', 'priority': -19},
|
|
85
|
+
"bagging_fraction": 0.775784726514,
|
|
86
|
+
"bagging_freq": 1,
|
|
87
|
+
"cat_l2": 0.3888471449178,
|
|
88
|
+
"cat_smooth": 0.0057144748021,
|
|
89
|
+
"extra_trees": True,
|
|
90
|
+
"feature_fraction": 0.7732354787904,
|
|
91
|
+
"lambda_l1": 0.2211002452568,
|
|
92
|
+
"lambda_l2": 1.1318405980187,
|
|
93
|
+
"learning_rate": 0.0090151778542,
|
|
94
|
+
"max_cat_to_onehot": 15,
|
|
95
|
+
"min_data_in_leaf": 4,
|
|
96
|
+
"min_data_per_group": 15,
|
|
97
|
+
"num_leaves": 2,
|
|
98
|
+
},
|
|
99
|
+
],
|
|
100
|
+
"CAT": [
|
|
101
|
+
{
|
|
102
|
+
"ag_args": {'priority': -5},
|
|
103
|
+
},
|
|
104
|
+
{
|
|
105
|
+
"ag_args": {'name_suffix': '_r51', 'priority': -10},
|
|
106
|
+
"boosting_type": 'Plain',
|
|
107
|
+
"bootstrap_type": 'Bernoulli',
|
|
108
|
+
"colsample_bylevel": 0.8771035272558,
|
|
109
|
+
"depth": 7,
|
|
110
|
+
"grow_policy": 'SymmetricTree',
|
|
111
|
+
"l2_leaf_reg": 2.0107286863021,
|
|
112
|
+
"leaf_estimation_iterations": 2,
|
|
113
|
+
"learning_rate": 0.0058424016622,
|
|
114
|
+
"max_bin": 254,
|
|
115
|
+
"max_ctr_complexity": 4,
|
|
116
|
+
"model_size_reg": 0.1307400355809,
|
|
117
|
+
"one_hot_max_size": 23,
|
|
118
|
+
"subsample": 0.809527841437,
|
|
119
|
+
},
|
|
120
|
+
{
|
|
121
|
+
"ag_args": {'name_suffix': '_r10', 'priority': -12},
|
|
122
|
+
"boosting_type": 'Plain',
|
|
123
|
+
"bootstrap_type": 'Bernoulli',
|
|
124
|
+
"colsample_bylevel": 0.8994502668431,
|
|
125
|
+
"depth": 6,
|
|
126
|
+
"grow_policy": 'Depthwise',
|
|
127
|
+
"l2_leaf_reg": 1.8187025215896,
|
|
128
|
+
"leaf_estimation_iterations": 7,
|
|
129
|
+
"learning_rate": 0.005177304142,
|
|
130
|
+
"max_bin": 254,
|
|
131
|
+
"max_ctr_complexity": 4,
|
|
132
|
+
"model_size_reg": 0.5247386875068,
|
|
133
|
+
"one_hot_max_size": 53,
|
|
134
|
+
"subsample": 0.8705228845742,
|
|
135
|
+
},
|
|
136
|
+
{
|
|
137
|
+
"ag_args": {'name_suffix': '_r24', 'priority': -15},
|
|
138
|
+
"boosting_type": 'Plain',
|
|
139
|
+
"bootstrap_type": 'Bernoulli',
|
|
140
|
+
"colsample_bylevel": 0.8597809376276,
|
|
141
|
+
"depth": 8,
|
|
142
|
+
"grow_policy": 'Depthwise',
|
|
143
|
+
"l2_leaf_reg": 0.3628261923976,
|
|
144
|
+
"leaf_estimation_iterations": 5,
|
|
145
|
+
"learning_rate": 0.016851077771,
|
|
146
|
+
"max_bin": 254,
|
|
147
|
+
"max_ctr_complexity": 4,
|
|
148
|
+
"model_size_reg": 0.1253820547902,
|
|
149
|
+
"one_hot_max_size": 20,
|
|
150
|
+
"subsample": 0.8120271122061,
|
|
151
|
+
},
|
|
152
|
+
{
|
|
153
|
+
"ag_args": {'name_suffix': '_r91', 'priority': -17},
|
|
154
|
+
"boosting_type": 'Plain',
|
|
155
|
+
"bootstrap_type": 'Bernoulli',
|
|
156
|
+
"colsample_bylevel": 0.8959275863514,
|
|
157
|
+
"depth": 4,
|
|
158
|
+
"grow_policy": 'SymmetricTree',
|
|
159
|
+
"l2_leaf_reg": 0.0026915894253,
|
|
160
|
+
"leaf_estimation_iterations": 12,
|
|
161
|
+
"learning_rate": 0.0475233791203,
|
|
162
|
+
"max_bin": 254,
|
|
163
|
+
"max_ctr_complexity": 5,
|
|
164
|
+
"model_size_reg": 0.1633175256924,
|
|
165
|
+
"one_hot_max_size": 11,
|
|
166
|
+
"subsample": 0.798554178926,
|
|
167
|
+
},
|
|
168
|
+
],
|
|
169
|
+
"TABM": [
|
|
170
|
+
{
|
|
171
|
+
"ag_args": {'name_suffix': '_r184', 'priority': -6},
|
|
172
|
+
"amp": False,
|
|
173
|
+
"arch_type": 'tabm-mini',
|
|
174
|
+
"batch_size": 'auto',
|
|
175
|
+
"d_block": 864,
|
|
176
|
+
"d_embedding": 24,
|
|
177
|
+
"dropout": 0.0,
|
|
178
|
+
"gradient_clipping_norm": 1.0,
|
|
179
|
+
"lr": 0.0019256819924656217,
|
|
180
|
+
"n_blocks": 3,
|
|
181
|
+
"num_emb_n_bins": 3,
|
|
182
|
+
"num_emb_type": 'pwl',
|
|
183
|
+
"patience": 16,
|
|
184
|
+
"share_training_batches": False,
|
|
185
|
+
"tabm_k": 32,
|
|
186
|
+
"weight_decay": 0.0,
|
|
187
|
+
},
|
|
188
|
+
{
|
|
189
|
+
"ag_args": {'name_suffix': '_r69', 'priority': -7},
|
|
190
|
+
"amp": False,
|
|
191
|
+
"arch_type": 'tabm-mini',
|
|
192
|
+
"batch_size": 'auto',
|
|
193
|
+
"d_block": 848,
|
|
194
|
+
"d_embedding": 28,
|
|
195
|
+
"dropout": 0.40215621636031007,
|
|
196
|
+
"gradient_clipping_norm": 1.0,
|
|
197
|
+
"lr": 0.0010413640454559532,
|
|
198
|
+
"n_blocks": 3,
|
|
199
|
+
"num_emb_n_bins": 18,
|
|
200
|
+
"num_emb_type": 'pwl',
|
|
201
|
+
"patience": 16,
|
|
202
|
+
"share_training_batches": False,
|
|
203
|
+
"tabm_k": 32,
|
|
204
|
+
"weight_decay": 0.0,
|
|
205
|
+
},
|
|
206
|
+
{
|
|
207
|
+
"ag_args": {'name_suffix': '_r52', 'priority': -11},
|
|
208
|
+
"amp": False,
|
|
209
|
+
"arch_type": 'tabm-mini',
|
|
210
|
+
"batch_size": 'auto',
|
|
211
|
+
"d_block": 1024,
|
|
212
|
+
"d_embedding": 32,
|
|
213
|
+
"dropout": 0.0,
|
|
214
|
+
"gradient_clipping_norm": 1.0,
|
|
215
|
+
"lr": 0.0006297851297842611,
|
|
216
|
+
"n_blocks": 4,
|
|
217
|
+
"num_emb_n_bins": 22,
|
|
218
|
+
"num_emb_type": 'pwl',
|
|
219
|
+
"patience": 16,
|
|
220
|
+
"share_training_batches": False,
|
|
221
|
+
"tabm_k": 32,
|
|
222
|
+
"weight_decay": 0.06900108498839816,
|
|
223
|
+
},
|
|
224
|
+
{
|
|
225
|
+
"ag_args": {'priority': -13},
|
|
226
|
+
},
|
|
227
|
+
{
|
|
228
|
+
"ag_args": {'name_suffix': '_r191', 'priority': -14},
|
|
229
|
+
"amp": False,
|
|
230
|
+
"arch_type": 'tabm-mini',
|
|
231
|
+
"batch_size": 'auto',
|
|
232
|
+
"d_block": 864,
|
|
233
|
+
"d_embedding": 8,
|
|
234
|
+
"dropout": 0.45321529282058803,
|
|
235
|
+
"gradient_clipping_norm": 1.0,
|
|
236
|
+
"lr": 0.0003781238075322413,
|
|
237
|
+
"n_blocks": 4,
|
|
238
|
+
"num_emb_n_bins": 27,
|
|
239
|
+
"num_emb_type": 'pwl',
|
|
240
|
+
"patience": 16,
|
|
241
|
+
"share_training_batches": False,
|
|
242
|
+
"tabm_k": 32,
|
|
243
|
+
"weight_decay": 0.01766851962579851,
|
|
244
|
+
},
|
|
245
|
+
{
|
|
246
|
+
"ag_args": {'name_suffix': '_r49', 'priority': -20},
|
|
247
|
+
"amp": False,
|
|
248
|
+
"arch_type": 'tabm-mini',
|
|
249
|
+
"batch_size": 'auto',
|
|
250
|
+
"d_block": 640,
|
|
251
|
+
"d_embedding": 28,
|
|
252
|
+
"dropout": 0.15296207419190627,
|
|
253
|
+
"gradient_clipping_norm": 1.0,
|
|
254
|
+
"lr": 0.002277678490593717,
|
|
255
|
+
"n_blocks": 3,
|
|
256
|
+
"num_emb_n_bins": 48,
|
|
257
|
+
"num_emb_type": 'pwl',
|
|
258
|
+
"patience": 16,
|
|
259
|
+
"share_training_batches": False,
|
|
260
|
+
"tabm_k": 32,
|
|
261
|
+
"weight_decay": 0.0578159148243893,
|
|
262
|
+
},
|
|
263
|
+
],
|
|
264
|
+
"TABICL": [
|
|
265
|
+
{
|
|
266
|
+
"ag_args": {'priority': -8},
|
|
267
|
+
},
|
|
268
|
+
],
|
|
269
|
+
"XGB": [
|
|
270
|
+
{
|
|
271
|
+
"ag_args": {'name_suffix': '_r171', 'priority': -9},
|
|
272
|
+
"colsample_bylevel": 0.9213705632288,
|
|
273
|
+
"colsample_bynode": 0.6443385965381,
|
|
274
|
+
"enable_categorical": True,
|
|
275
|
+
"grow_policy": 'lossguide',
|
|
276
|
+
"learning_rate": 0.0068171645251,
|
|
277
|
+
"max_cat_to_onehot": 8,
|
|
278
|
+
"max_depth": 6,
|
|
279
|
+
"max_leaves": 10,
|
|
280
|
+
"min_child_weight": 0.0507304250576,
|
|
281
|
+
"reg_alpha": 4.2446346389037,
|
|
282
|
+
"reg_lambda": 1.4800570021253,
|
|
283
|
+
"subsample": 0.9656290596647,
|
|
284
|
+
},
|
|
285
|
+
{
|
|
286
|
+
"ag_args": {'name_suffix': '_r40', 'priority': -18},
|
|
287
|
+
"colsample_bylevel": 0.6377491713202,
|
|
288
|
+
"colsample_bynode": 0.9237625621103,
|
|
289
|
+
"enable_categorical": True,
|
|
290
|
+
"grow_policy": 'lossguide',
|
|
291
|
+
"learning_rate": 0.0112462621131,
|
|
292
|
+
"max_cat_to_onehot": 33,
|
|
293
|
+
"max_depth": 10,
|
|
294
|
+
"max_leaves": 35,
|
|
295
|
+
"min_child_weight": 0.1403464856034,
|
|
296
|
+
"reg_alpha": 3.4960653958503,
|
|
297
|
+
"reg_lambda": 1.3062320805235,
|
|
298
|
+
"subsample": 0.6948898835178,
|
|
299
|
+
},
|
|
300
|
+
],
|
|
301
|
+
"MITRA": [
|
|
302
|
+
{
|
|
303
|
+
"n_estimators": 1,
|
|
304
|
+
"fine_tune": True,
|
|
305
|
+
"fine_tune_steps": 50,
|
|
306
|
+
"ag.num_gpus": 1,
|
|
307
|
+
"ag_args": {'priority': -21},
|
|
308
|
+
},
|
|
309
|
+
],
|
|
310
|
+
}
|
|
@@ -17,7 +17,8 @@ class FTTransformerModel(MultiModalPredictorModel):
|
|
|
17
17
|
ag_name = "FTTransformer"
|
|
18
18
|
|
|
19
19
|
def __init__(self, **kwargs):
|
|
20
|
-
"""
|
|
20
|
+
"""
|
|
21
|
+
FT-Transformer model.
|
|
21
22
|
|
|
22
23
|
The features can be a mix of
|
|
23
24
|
- categorical column
|
|
@@ -48,6 +49,8 @@ class FTTransformerModel(MultiModalPredictorModel):
|
|
|
48
49
|
Names of the features.
|
|
49
50
|
feature_metadata
|
|
50
51
|
The feature metadata.
|
|
52
|
+
|
|
53
|
+
.. versionadded:: 0.6.0
|
|
51
54
|
"""
|
|
52
55
|
super().__init__(**kwargs)
|
|
53
56
|
|
|
@@ -24,10 +24,16 @@ class TrainerFinetune(BaseEstimator):
|
|
|
24
24
|
cfg: ConfigRun,
|
|
25
25
|
model: torch.nn.Module,
|
|
26
26
|
n_classes: int,
|
|
27
|
-
device: str
|
|
28
|
-
|
|
27
|
+
device: str,
|
|
28
|
+
rng: np.random.RandomState = None,
|
|
29
|
+
verbose: bool = True,
|
|
30
|
+
):
|
|
29
31
|
|
|
30
32
|
self.cfg = cfg
|
|
33
|
+
if rng is None:
|
|
34
|
+
rng = np.random.RandomState(self.cfg.seed)
|
|
35
|
+
self.rng = rng
|
|
36
|
+
self.verbose = verbose
|
|
31
37
|
self.device = device
|
|
32
38
|
self.model = model.to(self.device, non_blocking=True)
|
|
33
39
|
self.n_classes = n_classes
|
|
@@ -81,13 +87,15 @@ class TrainerFinetune(BaseEstimator):
|
|
|
81
87
|
y = y_train_transformed,
|
|
82
88
|
task = self.cfg.task,
|
|
83
89
|
max_samples_support = self.cfg.hyperparams['max_samples_support'],
|
|
84
|
-
max_samples_query = self.cfg.hyperparams['max_samples_query']
|
|
90
|
+
max_samples_query = self.cfg.hyperparams['max_samples_query'],
|
|
91
|
+
rng=self.rng,
|
|
85
92
|
)
|
|
86
93
|
|
|
87
94
|
self.checkpoint.reset(self.model)
|
|
88
95
|
|
|
89
96
|
metrics_valid = self.evaluate(x_train, y_train, x_val, y_val)
|
|
90
|
-
self.
|
|
97
|
+
if self.verbose:
|
|
98
|
+
self.log_start_metrics(metrics_valid)
|
|
91
99
|
self.checkpoint(self.model, metrics_valid.loss)
|
|
92
100
|
|
|
93
101
|
start_time = time.time()
|
|
@@ -154,13 +162,15 @@ class TrainerFinetune(BaseEstimator):
|
|
|
154
162
|
metrics_train = prediction_metrics_tracker.get_metrics()
|
|
155
163
|
metrics_valid = self.evaluate(x_train, y_train, x_val, y_val)
|
|
156
164
|
|
|
157
|
-
self.
|
|
165
|
+
if self.verbose:
|
|
166
|
+
self.log_metrics(epoch, metrics_train, metrics_valid)
|
|
158
167
|
|
|
159
168
|
self.checkpoint(self.model, metrics_valid.loss)
|
|
160
169
|
|
|
161
170
|
self.early_stopping(metrics_valid.metrics[self.metric])
|
|
162
171
|
if self.early_stopping.we_should_stop():
|
|
163
|
-
|
|
172
|
+
if self.verbose:
|
|
173
|
+
logger.info("Early stopping")
|
|
164
174
|
break
|
|
165
175
|
|
|
166
176
|
if self.cfg.hyperparams["budget"] is not None and self.cfg.hyperparams["budget"] > 0 and time.time() - start_time > self.cfg.hyperparams["budget"]:
|
|
@@ -192,6 +202,7 @@ class TrainerFinetune(BaseEstimator):
|
|
|
192
202
|
y_query = y_query,
|
|
193
203
|
max_samples_support = self.cfg.hyperparams['max_samples_support'],
|
|
194
204
|
max_samples_query = self.cfg.hyperparams['max_samples_query'],
|
|
205
|
+
rng=self.rng,
|
|
195
206
|
)
|
|
196
207
|
|
|
197
208
|
loader = self.make_loader(dataset, training=False)
|
|
@@ -246,6 +257,7 @@ class TrainerFinetune(BaseEstimator):
|
|
|
246
257
|
y_query = None,
|
|
247
258
|
max_samples_support = self.cfg.hyperparams['max_samples_support'],
|
|
248
259
|
max_samples_query = self.cfg.hyperparams['max_samples_query'],
|
|
260
|
+
rng=self.rng,
|
|
249
261
|
)
|
|
250
262
|
|
|
251
263
|
loader = self.make_loader(dataset, training=False)
|
|
@@ -26,13 +26,15 @@ class DatasetFinetune(torch.utils.data.Dataset):
|
|
|
26
26
|
x_query: np.ndarray,
|
|
27
27
|
y_query: Optional[np.ndarray],
|
|
28
28
|
max_samples_support: int,
|
|
29
|
-
max_samples_query: int
|
|
29
|
+
max_samples_query: int,
|
|
30
|
+
rng: np.random.RandomState,
|
|
30
31
|
):
|
|
31
32
|
"""
|
|
32
33
|
:param: max_features: number of features the tab pfn model has been trained on
|
|
33
34
|
"""
|
|
34
35
|
|
|
35
36
|
self.cfg = cfg
|
|
37
|
+
self.rng = rng
|
|
36
38
|
|
|
37
39
|
self.x_support = x_support
|
|
38
40
|
self.y_support = y_support
|
|
@@ -59,7 +61,7 @@ class DatasetFinetune(torch.utils.data.Dataset):
|
|
|
59
61
|
|
|
60
62
|
def __getitem__(self, idx):
|
|
61
63
|
|
|
62
|
-
support_indices =
|
|
64
|
+
support_indices = self.rng.choice(
|
|
63
65
|
self.n_samples_support,
|
|
64
66
|
size=self.support_size,
|
|
65
67
|
replace=False
|
|
@@ -101,7 +103,8 @@ def DatasetFinetuneGenerator(
|
|
|
101
103
|
y: np.ndarray,
|
|
102
104
|
task: Task,
|
|
103
105
|
max_samples_support: int,
|
|
104
|
-
max_samples_query: int
|
|
106
|
+
max_samples_query: int,
|
|
107
|
+
rng: np.random.RandomState,
|
|
105
108
|
):
|
|
106
109
|
"""
|
|
107
110
|
The dataset fine-tune generator is a generator that yields a dataset for fine-tuning.
|
|
@@ -112,7 +115,7 @@ def DatasetFinetuneGenerator(
|
|
|
112
115
|
|
|
113
116
|
while True:
|
|
114
117
|
|
|
115
|
-
x_support, x_query, y_support, y_query = make_dataset_split(x=x, y=y, task=task, seed=
|
|
118
|
+
x_support, x_query, y_support, y_query = make_dataset_split(x=x, y=y, task=task, seed=rng)
|
|
116
119
|
n_samples_support = x_support.shape[0]
|
|
117
120
|
n_samples_query = x_query.shape[0]
|
|
118
121
|
|
|
@@ -127,6 +130,7 @@ def DatasetFinetuneGenerator(
|
|
|
127
130
|
y_query=y_query[:query_size],
|
|
128
131
|
max_samples_support=max_samples_support,
|
|
129
132
|
max_samples_query=max_samples_query,
|
|
133
|
+
rng=rng,
|
|
130
134
|
)
|
|
131
135
|
|
|
132
136
|
yield dataset_finetune
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import numpy as np
|
|
2
4
|
from sklearn.model_selection import StratifiedKFold, train_test_split
|
|
3
5
|
|
|
@@ -19,9 +21,11 @@ def make_dataset_split(x: np.ndarray, y: np.ndarray, task: Task, seed: int) -> t
|
|
|
19
21
|
|
|
20
22
|
|
|
21
23
|
def make_stratified_dataset_split(x, y, n_splits=5, seed=0):
|
|
24
|
+
if isinstance(seed, int):
|
|
25
|
+
seed = np.random.RandomState(seed)
|
|
22
26
|
|
|
23
27
|
# Stratify doesn't shuffle the data, so we shuffle it first
|
|
24
|
-
permutation =
|
|
28
|
+
permutation = seed.permutation(len(y))
|
|
25
29
|
x, y = x[permutation], y[permutation]
|
|
26
30
|
|
|
27
31
|
min_samples_per_class = np.min(np.bincount(y))
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import json
|
|
2
|
+
import logging
|
|
2
3
|
import os
|
|
3
4
|
from typing import Optional, Union
|
|
4
5
|
|
|
@@ -29,6 +30,8 @@ from ..._internal.models.embedding import (
|
|
|
29
30
|
Tab2DQuantileEmbeddingX,
|
|
30
31
|
)
|
|
31
32
|
|
|
33
|
+
logger = logging.getLogger(__name__)
|
|
34
|
+
|
|
32
35
|
|
|
33
36
|
class Tab2D(BaseModel):
|
|
34
37
|
|
|
@@ -1,49 +1,56 @@
|
|
|
1
|
-
|
|
2
|
-
# and os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'. The CUBLAS environment variable configures
|
|
3
|
-
# the workspace size for certain CUBLAS operations to ensure reproducibility when using CUDA >= 10.2.
|
|
4
|
-
# Both settings are required to ensure deterministic behavior in operations such as matrix multiplications.
|
|
5
|
-
import os
|
|
6
|
-
|
|
7
|
-
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
|
|
1
|
+
from __future__ import annotations
|
|
8
2
|
|
|
3
|
+
import logging
|
|
9
4
|
import os
|
|
10
5
|
from typing import List, Optional
|
|
11
6
|
|
|
12
7
|
import pandas as pd
|
|
13
|
-
import torch
|
|
14
|
-
import logging
|
|
15
8
|
|
|
16
9
|
from autogluon.common.utils.resource_utils import ResourceManager
|
|
17
10
|
from autogluon.core.models import AbstractModel
|
|
11
|
+
from autogluon.features.generators import LabelEncoderFeatureGenerator
|
|
12
|
+
from autogluon.tabular import __version__
|
|
18
13
|
|
|
19
14
|
logger = logging.getLogger(__name__)
|
|
20
15
|
|
|
21
16
|
|
|
22
|
-
# TODO: Needs memory usage estimate method
|
|
23
17
|
class MitraModel(AbstractModel):
|
|
18
|
+
"""
|
|
19
|
+
Mitra is a tabular foundation model pre-trained purely on synthetic data with the goal
|
|
20
|
+
of optimizing fine-tuning performance over in-context learning performance.
|
|
21
|
+
Mitra was developed by the AutoGluon team @ AWS AI.
|
|
22
|
+
|
|
23
|
+
Mitra's default hyperparameters outperforms all methods for small datasets on TabArena-v0.1 (excluding ensembling): https://tabarena.ai
|
|
24
|
+
|
|
25
|
+
Authors: Xiyuan Zhang, Danielle C. Maddix, Junming Yin, Nick Erickson, Abdul Fatir Ansari, Boran Han, Shuai Zhang, Leman Akoglu, Christos Faloutsos, Michael W. Mahoney, Cuixiong Hu, Huzefa Rangwala, George Karypis, Bernie Wang
|
|
26
|
+
Blog Post: https://www.amazon.science/blog/mitra-mixed-synthetic-priors-for-enhancing-tabular-foundation-models
|
|
27
|
+
License: Apache-2.0
|
|
28
|
+
|
|
29
|
+
.. versionadded:: 1.4.0
|
|
30
|
+
"""
|
|
24
31
|
ag_key = "MITRA"
|
|
25
32
|
ag_name = "Mitra"
|
|
26
33
|
weights_file_name = "model.pt"
|
|
27
34
|
ag_priority = 55
|
|
28
35
|
|
|
29
|
-
def __init__(self,
|
|
36
|
+
def __init__(self, **kwargs):
|
|
30
37
|
super().__init__(**kwargs)
|
|
31
|
-
self.problem_type = problem_type
|
|
32
38
|
self._weights_saved = False
|
|
39
|
+
self._feature_generator = None
|
|
33
40
|
|
|
34
41
|
@staticmethod
|
|
35
42
|
def _get_default_device():
|
|
36
43
|
"""Get the best available device for the current system."""
|
|
37
44
|
if ResourceManager.get_gpu_count_torch(cuda_only=True) > 0:
|
|
38
|
-
logger.
|
|
45
|
+
logger.log(15, "Using CUDA GPU")
|
|
39
46
|
return "cuda"
|
|
40
47
|
else:
|
|
41
48
|
return "cpu"
|
|
42
49
|
|
|
43
50
|
def get_model_cls(self):
|
|
44
|
-
from .sklearn_interface import MitraClassifier
|
|
45
|
-
|
|
46
51
|
if self.problem_type in ["binary", "multiclass"]:
|
|
52
|
+
from .sklearn_interface import MitraClassifier
|
|
53
|
+
|
|
47
54
|
model_cls = MitraClassifier
|
|
48
55
|
elif self.problem_type == "regression":
|
|
49
56
|
from .sklearn_interface import MitraRegressor
|
|
@@ -53,6 +60,23 @@ class MitraModel(AbstractModel):
|
|
|
53
60
|
raise AssertionError(f"Unsupported problem_type: {self.problem_type}")
|
|
54
61
|
return model_cls
|
|
55
62
|
|
|
63
|
+
def _preprocess(self, X: pd.DataFrame, is_train: bool = False, **kwargs) -> pd.DataFrame:
|
|
64
|
+
X = super()._preprocess(X, **kwargs)
|
|
65
|
+
|
|
66
|
+
if is_train:
|
|
67
|
+
# X will be the training data.
|
|
68
|
+
self._feature_generator = LabelEncoderFeatureGenerator(verbosity=0)
|
|
69
|
+
self._feature_generator.fit(X=X)
|
|
70
|
+
|
|
71
|
+
# This converts categorical features to numeric via stateful label encoding.
|
|
72
|
+
if self._feature_generator.features_in:
|
|
73
|
+
X = X.copy()
|
|
74
|
+
X[self._feature_generator.features_in] = self._feature_generator.transform(
|
|
75
|
+
X=X
|
|
76
|
+
)
|
|
77
|
+
|
|
78
|
+
return X
|
|
79
|
+
|
|
56
80
|
def _fit(
|
|
57
81
|
self,
|
|
58
82
|
X: pd.DataFrame,
|
|
@@ -61,11 +85,25 @@ class MitraModel(AbstractModel):
|
|
|
61
85
|
y_val: pd.Series = None,
|
|
62
86
|
time_limit: float = None,
|
|
63
87
|
num_cpus: int = 1,
|
|
88
|
+
num_gpus: float = 0,
|
|
89
|
+
verbosity: int = 2,
|
|
64
90
|
**kwargs,
|
|
65
91
|
):
|
|
66
92
|
# TODO: Reset the number of threads based on the specified num_cpus
|
|
67
93
|
need_to_reset_torch_threads = False
|
|
68
94
|
torch_threads_og = None
|
|
95
|
+
|
|
96
|
+
try:
|
|
97
|
+
model_cls = self.get_model_cls()
|
|
98
|
+
import torch
|
|
99
|
+
except ImportError as err:
|
|
100
|
+
logger.log(
|
|
101
|
+
40,
|
|
102
|
+
f"\tFailed to import Mitra! To use the Mitra model, "
|
|
103
|
+
f"do: `pip install autogluon.tabular[mitra]=={__version__}`.",
|
|
104
|
+
)
|
|
105
|
+
raise err
|
|
106
|
+
|
|
69
107
|
if num_cpus is not None and isinstance(num_cpus, (int, float)):
|
|
70
108
|
torch_threads_og = torch.get_num_threads()
|
|
71
109
|
if torch_threads_og != num_cpus:
|
|
@@ -73,9 +111,21 @@ class MitraModel(AbstractModel):
|
|
|
73
111
|
torch.set_num_threads(num_cpus)
|
|
74
112
|
need_to_reset_torch_threads = True
|
|
75
113
|
|
|
76
|
-
model_cls = self.get_model_cls()
|
|
77
|
-
|
|
78
114
|
hyp = self._get_model_params()
|
|
115
|
+
|
|
116
|
+
if hyp.get("device", None) is None:
|
|
117
|
+
if num_gpus == 0:
|
|
118
|
+
hyp["device"] = "cpu"
|
|
119
|
+
else:
|
|
120
|
+
hyp["device"] = self._get_default_device()
|
|
121
|
+
|
|
122
|
+
if hyp["device"] == "cpu" and hyp.get("fine_tune", True):
|
|
123
|
+
logger.log(
|
|
124
|
+
30,
|
|
125
|
+
f"\tWarning: Attempting to fine-tune Mitra on CPU. This will be very slow. "
|
|
126
|
+
f"We strongly recommend using a GPU instance to fine-tune Mitra."
|
|
127
|
+
)
|
|
128
|
+
|
|
79
129
|
if "state_dict_classification" in hyp:
|
|
80
130
|
state_dict_classification = hyp.pop("state_dict_classification")
|
|
81
131
|
if self.problem_type in ["binary", "multiclass"]:
|
|
@@ -85,11 +135,14 @@ class MitraModel(AbstractModel):
|
|
|
85
135
|
if self.problem_type in ["regression"]:
|
|
86
136
|
hyp["state_dict"] = state_dict_regression
|
|
87
137
|
|
|
138
|
+
if "verbose" not in hyp:
|
|
139
|
+
hyp["verbose"] = verbosity >= 3
|
|
140
|
+
|
|
88
141
|
self.model = model_cls(
|
|
89
142
|
**hyp,
|
|
90
143
|
)
|
|
91
144
|
|
|
92
|
-
X = self.preprocess(X)
|
|
145
|
+
X = self.preprocess(X, is_train=True)
|
|
93
146
|
if X_val is not None:
|
|
94
147
|
X_val = self.preprocess(X_val)
|
|
95
148
|
|
|
@@ -106,7 +159,6 @@ class MitraModel(AbstractModel):
|
|
|
106
159
|
|
|
107
160
|
def _set_default_params(self):
|
|
108
161
|
default_params = {
|
|
109
|
-
"device": self._get_default_device(),
|
|
110
162
|
"n_estimators": 1,
|
|
111
163
|
}
|
|
112
164
|
for param, val in default_params.items():
|
|
@@ -196,12 +248,13 @@ class MitraModel(AbstractModel):
|
|
|
196
248
|
X: pd.DataFrame,
|
|
197
249
|
**kwargs,
|
|
198
250
|
) -> int:
|
|
199
|
-
|
|
251
|
+
# Multiply by 0.9 as currently this is overly safe
|
|
252
|
+
return int(0.9 * max(
|
|
200
253
|
cls._estimate_memory_usage_static_cpu_icl(X=X, **kwargs),
|
|
201
254
|
cls._estimate_memory_usage_static_cpu_ft_icl(X=X, **kwargs),
|
|
202
255
|
cls._estimate_memory_usage_static_gpu_cpu(X=X, **kwargs),
|
|
203
256
|
cls._estimate_memory_usage_static_gpu_gpu(X=X, **kwargs),
|
|
204
|
-
)
|
|
257
|
+
))
|
|
205
258
|
|
|
206
259
|
@classmethod
|
|
207
260
|
def _estimate_memory_usage_static_cpu_icl(
|