autogluon.tabular 1.3.2b20250723__py3-none-any.whl → 1.4.0b20250725__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of autogluon.tabular might be problematic. Click here for more details.
- autogluon/tabular/configs/hyperparameter_configs.py +2 -265
- autogluon/tabular/configs/presets_configs.py +51 -23
- autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2023.py +0 -1
- autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2025.py +309 -0
- autogluon/tabular/models/automm/automm_model.py +2 -0
- autogluon/tabular/models/automm/ft_transformer.py +4 -1
- autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py +18 -6
- autogluon/tabular/models/mitra/_internal/data/dataset_finetune.py +8 -4
- autogluon/tabular/models/mitra/_internal/data/dataset_split.py +5 -1
- autogluon/tabular/models/mitra/_internal/models/tab2d.py +3 -0
- autogluon/tabular/models/mitra/mitra_model.py +85 -21
- autogluon/tabular/models/mitra/sklearn_interface.py +15 -13
- autogluon/tabular/models/realmlp/realmlp_model.py +13 -6
- autogluon/tabular/models/tabicl/tabicl_model.py +17 -8
- autogluon/tabular/models/tabm/rtdl_num_embeddings.py +3 -0
- autogluon/tabular/models/tabm/tabm_model.py +14 -6
- autogluon/tabular/models/tabm/tabm_reference.py +2 -0
- autogluon/tabular/models/tabpfnmix/tabpfnmix_model.py +4 -0
- autogluon/tabular/models/tabpfnv2/tabpfnv2_model.py +29 -12
- autogluon/tabular/predictor/predictor.py +45 -5
- autogluon/tabular/trainer/abstract_trainer.py +2 -0
- autogluon/tabular/version.py +1 -1
- {autogluon.tabular-1.3.2b20250723.dist-info → autogluon.tabular-1.4.0b20250725.dist-info}/METADATA +40 -18
- {autogluon.tabular-1.3.2b20250723.dist-info → autogluon.tabular-1.4.0b20250725.dist-info}/RECORD +31 -30
- /autogluon.tabular-1.3.2b20250723-py3.9-nspkg.pth → /autogluon.tabular-1.4.0b20250725-py3.9-nspkg.pth +0 -0
- {autogluon.tabular-1.3.2b20250723.dist-info → autogluon.tabular-1.4.0b20250725.dist-info}/LICENSE +0 -0
- {autogluon.tabular-1.3.2b20250723.dist-info → autogluon.tabular-1.4.0b20250725.dist-info}/NOTICE +0 -0
- {autogluon.tabular-1.3.2b20250723.dist-info → autogluon.tabular-1.4.0b20250725.dist-info}/WHEEL +0 -0
- {autogluon.tabular-1.3.2b20250723.dist-info → autogluon.tabular-1.4.0b20250725.dist-info}/namespace_packages.txt +0 -0
- {autogluon.tabular-1.3.2b20250723.dist-info → autogluon.tabular-1.4.0b20250725.dist-info}/top_level.txt +0 -0
- {autogluon.tabular-1.3.2b20250723.dist-info → autogluon.tabular-1.4.0b20250725.dist-info}/zip-safe +0 -0
|
@@ -0,0 +1,309 @@
|
|
|
1
|
+
# optimized for <=10000 samples and <=500 features, with a GPU present
|
|
2
|
+
hyperparameter_portfolio_zeroshot_2025_small = {
|
|
3
|
+
"TABPFNV2": [
|
|
4
|
+
{
|
|
5
|
+
"ag_args": {'name_suffix': '_r143', 'priority': -1},
|
|
6
|
+
"average_before_softmax": False,
|
|
7
|
+
"classification_model_path": 'tabpfn-v2-classifier-od3j1g5m.ckpt',
|
|
8
|
+
"inference_config/FINGERPRINT_FEATURE": False,
|
|
9
|
+
"inference_config/OUTLIER_REMOVAL_STD": None,
|
|
10
|
+
"inference_config/POLYNOMIAL_FEATURES": 'no',
|
|
11
|
+
"inference_config/PREPROCESS_TRANSFORMS": [{'append_original': True, 'categorical_name': 'ordinal_very_common_categories_shuffled', 'global_transformer_name': None, 'name': 'safepower', 'subsample_features': -1}, {'append_original': True, 'categorical_name': 'ordinal_very_common_categories_shuffled', 'global_transformer_name': None, 'name': 'quantile_uni', 'subsample_features': -1}],
|
|
12
|
+
"inference_config/REGRESSION_Y_PREPROCESS_TRANSFORMS": [None, 'power'],
|
|
13
|
+
"inference_config/SUBSAMPLE_SAMPLES": 0.99,
|
|
14
|
+
"model_type": 'single',
|
|
15
|
+
"n_ensemble_repeats": 4,
|
|
16
|
+
"regression_model_path": 'tabpfn-v2-regressor-wyl4o83o.ckpt',
|
|
17
|
+
"softmax_temperature": 0.75,
|
|
18
|
+
},
|
|
19
|
+
{
|
|
20
|
+
"ag_args": {'name_suffix': '_r94', 'priority': -3},
|
|
21
|
+
"average_before_softmax": True,
|
|
22
|
+
"classification_model_path": 'tabpfn-v2-classifier-vutqq28w.ckpt',
|
|
23
|
+
"inference_config/FINGERPRINT_FEATURE": True,
|
|
24
|
+
"inference_config/OUTLIER_REMOVAL_STD": None,
|
|
25
|
+
"inference_config/POLYNOMIAL_FEATURES": 'no',
|
|
26
|
+
"inference_config/PREPROCESS_TRANSFORMS": [{'append_original': True, 'categorical_name': 'ordinal_very_common_categories_shuffled', 'global_transformer_name': None, 'name': 'quantile_uni', 'subsample_features': 0.99}],
|
|
27
|
+
"inference_config/REGRESSION_Y_PREPROCESS_TRANSFORMS": [None],
|
|
28
|
+
"inference_config/SUBSAMPLE_SAMPLES": None,
|
|
29
|
+
"model_type": 'single',
|
|
30
|
+
"n_ensemble_repeats": 4,
|
|
31
|
+
"regression_model_path": 'tabpfn-v2-regressor-5wof9ojf.ckpt',
|
|
32
|
+
"softmax_temperature": 0.9,
|
|
33
|
+
},
|
|
34
|
+
{
|
|
35
|
+
"ag_args": {'name_suffix': '_r181', 'priority': -4},
|
|
36
|
+
"average_before_softmax": False,
|
|
37
|
+
"classification_model_path": 'tabpfn-v2-classifier-llderlii.ckpt',
|
|
38
|
+
"inference_config/FINGERPRINT_FEATURE": False,
|
|
39
|
+
"inference_config/OUTLIER_REMOVAL_STD": 9.0,
|
|
40
|
+
"inference_config/POLYNOMIAL_FEATURES": 50,
|
|
41
|
+
"inference_config/PREPROCESS_TRANSFORMS": [{'append_original': True, 'categorical_name': 'onehot', 'global_transformer_name': 'svd', 'name': 'quantile_uni_coarse', 'subsample_features': 0.99}],
|
|
42
|
+
"inference_config/REGRESSION_Y_PREPROCESS_TRANSFORMS": ['power'],
|
|
43
|
+
"inference_config/SUBSAMPLE_SAMPLES": None,
|
|
44
|
+
"model_type": 'single',
|
|
45
|
+
"n_ensemble_repeats": 4,
|
|
46
|
+
"regression_model_path": 'tabpfn-v2-regressor.ckpt',
|
|
47
|
+
"softmax_temperature": 0.95,
|
|
48
|
+
},
|
|
49
|
+
],
|
|
50
|
+
"GBM": [
|
|
51
|
+
{
|
|
52
|
+
"ag_args": {'name_suffix': '_r33', 'priority': -2},
|
|
53
|
+
"bagging_fraction": 0.9625293420216,
|
|
54
|
+
"bagging_freq": 1,
|
|
55
|
+
"cat_l2": 0.1236875455555,
|
|
56
|
+
"cat_smooth": 68.8584757332856,
|
|
57
|
+
"extra_trees": False,
|
|
58
|
+
"feature_fraction": 0.6189215809382,
|
|
59
|
+
"lambda_l1": 0.1641757352921,
|
|
60
|
+
"lambda_l2": 0.6937755557881,
|
|
61
|
+
"learning_rate": 0.0154031028561,
|
|
62
|
+
"max_cat_to_onehot": 17,
|
|
63
|
+
"min_data_in_leaf": 1,
|
|
64
|
+
"min_data_per_group": 30,
|
|
65
|
+
"num_leaves": 68,
|
|
66
|
+
},
|
|
67
|
+
{
|
|
68
|
+
"ag_args": {'name_suffix': '_r21', 'priority': -16},
|
|
69
|
+
"bagging_fraction": 0.7218730663234,
|
|
70
|
+
"bagging_freq": 1,
|
|
71
|
+
"cat_l2": 0.0296205152578,
|
|
72
|
+
"cat_smooth": 0.0010255271303,
|
|
73
|
+
"extra_trees": False,
|
|
74
|
+
"feature_fraction": 0.4557131604374,
|
|
75
|
+
"lambda_l1": 0.5219704038237,
|
|
76
|
+
"lambda_l2": 0.1070959487853,
|
|
77
|
+
"learning_rate": 0.0055891584996,
|
|
78
|
+
"max_cat_to_onehot": 71,
|
|
79
|
+
"min_data_in_leaf": 50,
|
|
80
|
+
"min_data_per_group": 10,
|
|
81
|
+
"num_leaves": 30,
|
|
82
|
+
},
|
|
83
|
+
{
|
|
84
|
+
"ag_args": {'name_suffix': '_r11', 'priority': -19},
|
|
85
|
+
"bagging_fraction": 0.775784726514,
|
|
86
|
+
"bagging_freq": 1,
|
|
87
|
+
"cat_l2": 0.3888471449178,
|
|
88
|
+
"cat_smooth": 0.0057144748021,
|
|
89
|
+
"extra_trees": True,
|
|
90
|
+
"feature_fraction": 0.7732354787904,
|
|
91
|
+
"lambda_l1": 0.2211002452568,
|
|
92
|
+
"lambda_l2": 1.1318405980187,
|
|
93
|
+
"learning_rate": 0.0090151778542,
|
|
94
|
+
"max_cat_to_onehot": 15,
|
|
95
|
+
"min_data_in_leaf": 4,
|
|
96
|
+
"min_data_per_group": 15,
|
|
97
|
+
"num_leaves": 2,
|
|
98
|
+
},
|
|
99
|
+
],
|
|
100
|
+
"CAT": [
|
|
101
|
+
{
|
|
102
|
+
"ag_args": {'priority': -5},
|
|
103
|
+
},
|
|
104
|
+
{
|
|
105
|
+
"ag_args": {'name_suffix': '_r51', 'priority': -10},
|
|
106
|
+
"boosting_type": 'Plain',
|
|
107
|
+
"bootstrap_type": 'Bernoulli',
|
|
108
|
+
"colsample_bylevel": 0.8771035272558,
|
|
109
|
+
"depth": 7,
|
|
110
|
+
"grow_policy": 'SymmetricTree',
|
|
111
|
+
"l2_leaf_reg": 2.0107286863021,
|
|
112
|
+
"leaf_estimation_iterations": 2,
|
|
113
|
+
"learning_rate": 0.0058424016622,
|
|
114
|
+
"max_bin": 254,
|
|
115
|
+
"max_ctr_complexity": 4,
|
|
116
|
+
"model_size_reg": 0.1307400355809,
|
|
117
|
+
"one_hot_max_size": 23,
|
|
118
|
+
"subsample": 0.809527841437,
|
|
119
|
+
},
|
|
120
|
+
{
|
|
121
|
+
"ag_args": {'name_suffix': '_r10', 'priority': -12},
|
|
122
|
+
"boosting_type": 'Plain',
|
|
123
|
+
"bootstrap_type": 'Bernoulli',
|
|
124
|
+
"colsample_bylevel": 0.8994502668431,
|
|
125
|
+
"depth": 6,
|
|
126
|
+
"grow_policy": 'Depthwise',
|
|
127
|
+
"l2_leaf_reg": 1.8187025215896,
|
|
128
|
+
"leaf_estimation_iterations": 7,
|
|
129
|
+
"learning_rate": 0.005177304142,
|
|
130
|
+
"max_bin": 254,
|
|
131
|
+
"max_ctr_complexity": 4,
|
|
132
|
+
"model_size_reg": 0.5247386875068,
|
|
133
|
+
"one_hot_max_size": 53,
|
|
134
|
+
"subsample": 0.8705228845742,
|
|
135
|
+
},
|
|
136
|
+
{
|
|
137
|
+
"ag_args": {'name_suffix': '_r24', 'priority': -15},
|
|
138
|
+
"boosting_type": 'Plain',
|
|
139
|
+
"bootstrap_type": 'Bernoulli',
|
|
140
|
+
"colsample_bylevel": 0.8597809376276,
|
|
141
|
+
"depth": 8,
|
|
142
|
+
"grow_policy": 'Depthwise',
|
|
143
|
+
"l2_leaf_reg": 0.3628261923976,
|
|
144
|
+
"leaf_estimation_iterations": 5,
|
|
145
|
+
"learning_rate": 0.016851077771,
|
|
146
|
+
"max_bin": 254,
|
|
147
|
+
"max_ctr_complexity": 4,
|
|
148
|
+
"model_size_reg": 0.1253820547902,
|
|
149
|
+
"one_hot_max_size": 20,
|
|
150
|
+
"subsample": 0.8120271122061,
|
|
151
|
+
},
|
|
152
|
+
{
|
|
153
|
+
"ag_args": {'name_suffix': '_r91', 'priority': -17},
|
|
154
|
+
"boosting_type": 'Plain',
|
|
155
|
+
"bootstrap_type": 'Bernoulli',
|
|
156
|
+
"colsample_bylevel": 0.8959275863514,
|
|
157
|
+
"depth": 4,
|
|
158
|
+
"grow_policy": 'SymmetricTree',
|
|
159
|
+
"l2_leaf_reg": 0.0026915894253,
|
|
160
|
+
"leaf_estimation_iterations": 12,
|
|
161
|
+
"learning_rate": 0.0475233791203,
|
|
162
|
+
"max_bin": 254,
|
|
163
|
+
"max_ctr_complexity": 5,
|
|
164
|
+
"model_size_reg": 0.1633175256924,
|
|
165
|
+
"one_hot_max_size": 11,
|
|
166
|
+
"subsample": 0.798554178926,
|
|
167
|
+
},
|
|
168
|
+
],
|
|
169
|
+
"TABM": [
|
|
170
|
+
{
|
|
171
|
+
"ag_args": {'name_suffix': '_r184', 'priority': -6},
|
|
172
|
+
"amp": False,
|
|
173
|
+
"arch_type": 'tabm-mini',
|
|
174
|
+
"batch_size": 'auto',
|
|
175
|
+
"d_block": 864,
|
|
176
|
+
"d_embedding": 24,
|
|
177
|
+
"dropout": 0.0,
|
|
178
|
+
"gradient_clipping_norm": 1.0,
|
|
179
|
+
"lr": 0.0019256819924656217,
|
|
180
|
+
"n_blocks": 3,
|
|
181
|
+
"num_emb_n_bins": 3,
|
|
182
|
+
"num_emb_type": 'pwl',
|
|
183
|
+
"patience": 16,
|
|
184
|
+
"share_training_batches": False,
|
|
185
|
+
"tabm_k": 32,
|
|
186
|
+
"weight_decay": 0.0,
|
|
187
|
+
},
|
|
188
|
+
{
|
|
189
|
+
"ag_args": {'name_suffix': '_r69', 'priority': -7},
|
|
190
|
+
"amp": False,
|
|
191
|
+
"arch_type": 'tabm-mini',
|
|
192
|
+
"batch_size": 'auto',
|
|
193
|
+
"d_block": 848,
|
|
194
|
+
"d_embedding": 28,
|
|
195
|
+
"dropout": 0.40215621636031007,
|
|
196
|
+
"gradient_clipping_norm": 1.0,
|
|
197
|
+
"lr": 0.0010413640454559532,
|
|
198
|
+
"n_blocks": 3,
|
|
199
|
+
"num_emb_n_bins": 18,
|
|
200
|
+
"num_emb_type": 'pwl',
|
|
201
|
+
"patience": 16,
|
|
202
|
+
"share_training_batches": False,
|
|
203
|
+
"tabm_k": 32,
|
|
204
|
+
"weight_decay": 0.0,
|
|
205
|
+
},
|
|
206
|
+
{
|
|
207
|
+
"ag_args": {'name_suffix': '_r52', 'priority': -11},
|
|
208
|
+
"amp": False,
|
|
209
|
+
"arch_type": 'tabm-mini',
|
|
210
|
+
"batch_size": 'auto',
|
|
211
|
+
"d_block": 1024,
|
|
212
|
+
"d_embedding": 32,
|
|
213
|
+
"dropout": 0.0,
|
|
214
|
+
"gradient_clipping_norm": 1.0,
|
|
215
|
+
"lr": 0.0006297851297842611,
|
|
216
|
+
"n_blocks": 4,
|
|
217
|
+
"num_emb_n_bins": 22,
|
|
218
|
+
"num_emb_type": 'pwl',
|
|
219
|
+
"patience": 16,
|
|
220
|
+
"share_training_batches": False,
|
|
221
|
+
"tabm_k": 32,
|
|
222
|
+
"weight_decay": 0.06900108498839816,
|
|
223
|
+
},
|
|
224
|
+
{
|
|
225
|
+
"ag_args": {'priority': -13},
|
|
226
|
+
},
|
|
227
|
+
{
|
|
228
|
+
"ag_args": {'name_suffix': '_r191', 'priority': -14},
|
|
229
|
+
"amp": False,
|
|
230
|
+
"arch_type": 'tabm-mini',
|
|
231
|
+
"batch_size": 'auto',
|
|
232
|
+
"d_block": 864,
|
|
233
|
+
"d_embedding": 8,
|
|
234
|
+
"dropout": 0.45321529282058803,
|
|
235
|
+
"gradient_clipping_norm": 1.0,
|
|
236
|
+
"lr": 0.0003781238075322413,
|
|
237
|
+
"n_blocks": 4,
|
|
238
|
+
"num_emb_n_bins": 27,
|
|
239
|
+
"num_emb_type": 'pwl',
|
|
240
|
+
"patience": 16,
|
|
241
|
+
"share_training_batches": False,
|
|
242
|
+
"tabm_k": 32,
|
|
243
|
+
"weight_decay": 0.01766851962579851,
|
|
244
|
+
},
|
|
245
|
+
{
|
|
246
|
+
"ag_args": {'name_suffix': '_r49', 'priority': -20},
|
|
247
|
+
"amp": False,
|
|
248
|
+
"arch_type": 'tabm-mini',
|
|
249
|
+
"batch_size": 'auto',
|
|
250
|
+
"d_block": 640,
|
|
251
|
+
"d_embedding": 28,
|
|
252
|
+
"dropout": 0.15296207419190627,
|
|
253
|
+
"gradient_clipping_norm": 1.0,
|
|
254
|
+
"lr": 0.002277678490593717,
|
|
255
|
+
"n_blocks": 3,
|
|
256
|
+
"num_emb_n_bins": 48,
|
|
257
|
+
"num_emb_type": 'pwl',
|
|
258
|
+
"patience": 16,
|
|
259
|
+
"share_training_batches": False,
|
|
260
|
+
"tabm_k": 32,
|
|
261
|
+
"weight_decay": 0.0578159148243893,
|
|
262
|
+
},
|
|
263
|
+
],
|
|
264
|
+
"TABICL": [
|
|
265
|
+
{
|
|
266
|
+
"ag_args": {'priority': -8},
|
|
267
|
+
},
|
|
268
|
+
],
|
|
269
|
+
"XGB": [
|
|
270
|
+
{
|
|
271
|
+
"ag_args": {'name_suffix': '_r171', 'priority': -9},
|
|
272
|
+
"colsample_bylevel": 0.9213705632288,
|
|
273
|
+
"colsample_bynode": 0.6443385965381,
|
|
274
|
+
"enable_categorical": True,
|
|
275
|
+
"grow_policy": 'lossguide',
|
|
276
|
+
"learning_rate": 0.0068171645251,
|
|
277
|
+
"max_cat_to_onehot": 8,
|
|
278
|
+
"max_depth": 6,
|
|
279
|
+
"max_leaves": 10,
|
|
280
|
+
"min_child_weight": 0.0507304250576,
|
|
281
|
+
"reg_alpha": 4.2446346389037,
|
|
282
|
+
"reg_lambda": 1.4800570021253,
|
|
283
|
+
"subsample": 0.9656290596647,
|
|
284
|
+
},
|
|
285
|
+
{
|
|
286
|
+
"ag_args": {'name_suffix': '_r40', 'priority': -18},
|
|
287
|
+
"colsample_bylevel": 0.6377491713202,
|
|
288
|
+
"colsample_bynode": 0.9237625621103,
|
|
289
|
+
"enable_categorical": True,
|
|
290
|
+
"grow_policy": 'lossguide',
|
|
291
|
+
"learning_rate": 0.0112462621131,
|
|
292
|
+
"max_cat_to_onehot": 33,
|
|
293
|
+
"max_depth": 10,
|
|
294
|
+
"max_leaves": 35,
|
|
295
|
+
"min_child_weight": 0.1403464856034,
|
|
296
|
+
"reg_alpha": 3.4960653958503,
|
|
297
|
+
"reg_lambda": 1.3062320805235,
|
|
298
|
+
"subsample": 0.6948898835178,
|
|
299
|
+
},
|
|
300
|
+
],
|
|
301
|
+
"MITRA": [
|
|
302
|
+
{
|
|
303
|
+
"n_estimators": 1,
|
|
304
|
+
"fine_tune": True,
|
|
305
|
+
"fine_tune_steps": 50,
|
|
306
|
+
"ag_args": {'priority': -21},
|
|
307
|
+
},
|
|
308
|
+
],
|
|
309
|
+
}
|
|
@@ -17,7 +17,8 @@ class FTTransformerModel(MultiModalPredictorModel):
|
|
|
17
17
|
ag_name = "FTTransformer"
|
|
18
18
|
|
|
19
19
|
def __init__(self, **kwargs):
|
|
20
|
-
"""
|
|
20
|
+
"""
|
|
21
|
+
FT-Transformer model.
|
|
21
22
|
|
|
22
23
|
The features can be a mix of
|
|
23
24
|
- categorical column
|
|
@@ -48,6 +49,8 @@ class FTTransformerModel(MultiModalPredictorModel):
|
|
|
48
49
|
Names of the features.
|
|
49
50
|
feature_metadata
|
|
50
51
|
The feature metadata.
|
|
52
|
+
|
|
53
|
+
.. versionadded:: 0.6.0
|
|
51
54
|
"""
|
|
52
55
|
super().__init__(**kwargs)
|
|
53
56
|
|
|
@@ -24,10 +24,16 @@ class TrainerFinetune(BaseEstimator):
|
|
|
24
24
|
cfg: ConfigRun,
|
|
25
25
|
model: torch.nn.Module,
|
|
26
26
|
n_classes: int,
|
|
27
|
-
device: str
|
|
28
|
-
|
|
27
|
+
device: str,
|
|
28
|
+
rng: np.random.RandomState = None,
|
|
29
|
+
verbose: bool = True,
|
|
30
|
+
):
|
|
29
31
|
|
|
30
32
|
self.cfg = cfg
|
|
33
|
+
if rng is None:
|
|
34
|
+
rng = np.random.RandomState(self.cfg.seed)
|
|
35
|
+
self.rng = rng
|
|
36
|
+
self.verbose = verbose
|
|
31
37
|
self.device = device
|
|
32
38
|
self.model = model.to(self.device, non_blocking=True)
|
|
33
39
|
self.n_classes = n_classes
|
|
@@ -81,13 +87,15 @@ class TrainerFinetune(BaseEstimator):
|
|
|
81
87
|
y = y_train_transformed,
|
|
82
88
|
task = self.cfg.task,
|
|
83
89
|
max_samples_support = self.cfg.hyperparams['max_samples_support'],
|
|
84
|
-
max_samples_query = self.cfg.hyperparams['max_samples_query']
|
|
90
|
+
max_samples_query = self.cfg.hyperparams['max_samples_query'],
|
|
91
|
+
rng=self.rng,
|
|
85
92
|
)
|
|
86
93
|
|
|
87
94
|
self.checkpoint.reset(self.model)
|
|
88
95
|
|
|
89
96
|
metrics_valid = self.evaluate(x_train, y_train, x_val, y_val)
|
|
90
|
-
self.
|
|
97
|
+
if self.verbose:
|
|
98
|
+
self.log_start_metrics(metrics_valid)
|
|
91
99
|
self.checkpoint(self.model, metrics_valid.loss)
|
|
92
100
|
|
|
93
101
|
start_time = time.time()
|
|
@@ -154,13 +162,15 @@ class TrainerFinetune(BaseEstimator):
|
|
|
154
162
|
metrics_train = prediction_metrics_tracker.get_metrics()
|
|
155
163
|
metrics_valid = self.evaluate(x_train, y_train, x_val, y_val)
|
|
156
164
|
|
|
157
|
-
self.
|
|
165
|
+
if self.verbose:
|
|
166
|
+
self.log_metrics(epoch, metrics_train, metrics_valid)
|
|
158
167
|
|
|
159
168
|
self.checkpoint(self.model, metrics_valid.loss)
|
|
160
169
|
|
|
161
170
|
self.early_stopping(metrics_valid.metrics[self.metric])
|
|
162
171
|
if self.early_stopping.we_should_stop():
|
|
163
|
-
|
|
172
|
+
if self.verbose:
|
|
173
|
+
logger.info("Early stopping")
|
|
164
174
|
break
|
|
165
175
|
|
|
166
176
|
if self.cfg.hyperparams["budget"] is not None and self.cfg.hyperparams["budget"] > 0 and time.time() - start_time > self.cfg.hyperparams["budget"]:
|
|
@@ -192,6 +202,7 @@ class TrainerFinetune(BaseEstimator):
|
|
|
192
202
|
y_query = y_query,
|
|
193
203
|
max_samples_support = self.cfg.hyperparams['max_samples_support'],
|
|
194
204
|
max_samples_query = self.cfg.hyperparams['max_samples_query'],
|
|
205
|
+
rng=self.rng,
|
|
195
206
|
)
|
|
196
207
|
|
|
197
208
|
loader = self.make_loader(dataset, training=False)
|
|
@@ -246,6 +257,7 @@ class TrainerFinetune(BaseEstimator):
|
|
|
246
257
|
y_query = None,
|
|
247
258
|
max_samples_support = self.cfg.hyperparams['max_samples_support'],
|
|
248
259
|
max_samples_query = self.cfg.hyperparams['max_samples_query'],
|
|
260
|
+
rng=self.rng,
|
|
249
261
|
)
|
|
250
262
|
|
|
251
263
|
loader = self.make_loader(dataset, training=False)
|
|
@@ -26,13 +26,15 @@ class DatasetFinetune(torch.utils.data.Dataset):
|
|
|
26
26
|
x_query: np.ndarray,
|
|
27
27
|
y_query: Optional[np.ndarray],
|
|
28
28
|
max_samples_support: int,
|
|
29
|
-
max_samples_query: int
|
|
29
|
+
max_samples_query: int,
|
|
30
|
+
rng: np.random.RandomState,
|
|
30
31
|
):
|
|
31
32
|
"""
|
|
32
33
|
:param: max_features: number of features the tab pfn model has been trained on
|
|
33
34
|
"""
|
|
34
35
|
|
|
35
36
|
self.cfg = cfg
|
|
37
|
+
self.rng = rng
|
|
36
38
|
|
|
37
39
|
self.x_support = x_support
|
|
38
40
|
self.y_support = y_support
|
|
@@ -59,7 +61,7 @@ class DatasetFinetune(torch.utils.data.Dataset):
|
|
|
59
61
|
|
|
60
62
|
def __getitem__(self, idx):
|
|
61
63
|
|
|
62
|
-
support_indices =
|
|
64
|
+
support_indices = self.rng.choice(
|
|
63
65
|
self.n_samples_support,
|
|
64
66
|
size=self.support_size,
|
|
65
67
|
replace=False
|
|
@@ -101,7 +103,8 @@ def DatasetFinetuneGenerator(
|
|
|
101
103
|
y: np.ndarray,
|
|
102
104
|
task: Task,
|
|
103
105
|
max_samples_support: int,
|
|
104
|
-
max_samples_query: int
|
|
106
|
+
max_samples_query: int,
|
|
107
|
+
rng: np.random.RandomState,
|
|
105
108
|
):
|
|
106
109
|
"""
|
|
107
110
|
The dataset fine-tune generator is a generator that yields a dataset for fine-tuning.
|
|
@@ -112,7 +115,7 @@ def DatasetFinetuneGenerator(
|
|
|
112
115
|
|
|
113
116
|
while True:
|
|
114
117
|
|
|
115
|
-
x_support, x_query, y_support, y_query = make_dataset_split(x=x, y=y, task=task, seed=
|
|
118
|
+
x_support, x_query, y_support, y_query = make_dataset_split(x=x, y=y, task=task, seed=rng)
|
|
116
119
|
n_samples_support = x_support.shape[0]
|
|
117
120
|
n_samples_query = x_query.shape[0]
|
|
118
121
|
|
|
@@ -127,6 +130,7 @@ def DatasetFinetuneGenerator(
|
|
|
127
130
|
y_query=y_query[:query_size],
|
|
128
131
|
max_samples_support=max_samples_support,
|
|
129
132
|
max_samples_query=max_samples_query,
|
|
133
|
+
rng=rng,
|
|
130
134
|
)
|
|
131
135
|
|
|
132
136
|
yield dataset_finetune
|
|
@@ -1,3 +1,5 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
1
3
|
import numpy as np
|
|
2
4
|
from sklearn.model_selection import StratifiedKFold, train_test_split
|
|
3
5
|
|
|
@@ -19,9 +21,11 @@ def make_dataset_split(x: np.ndarray, y: np.ndarray, task: Task, seed: int) -> t
|
|
|
19
21
|
|
|
20
22
|
|
|
21
23
|
def make_stratified_dataset_split(x, y, n_splits=5, seed=0):
|
|
24
|
+
if isinstance(seed, int):
|
|
25
|
+
seed = np.random.RandomState(seed)
|
|
22
26
|
|
|
23
27
|
# Stratify doesn't shuffle the data, so we shuffle it first
|
|
24
|
-
permutation =
|
|
28
|
+
permutation = seed.permutation(len(y))
|
|
25
29
|
x, y = x[permutation], y[permutation]
|
|
26
30
|
|
|
27
31
|
min_samples_per_class = np.min(np.bincount(y))
|
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
import json
|
|
2
|
+
import logging
|
|
2
3
|
import os
|
|
3
4
|
from typing import Optional, Union
|
|
4
5
|
|
|
@@ -29,6 +30,8 @@ from ..._internal.models.embedding import (
|
|
|
29
30
|
Tab2DQuantileEmbeddingX,
|
|
30
31
|
)
|
|
31
32
|
|
|
33
|
+
logger = logging.getLogger(__name__)
|
|
34
|
+
|
|
32
35
|
|
|
33
36
|
class Tab2D(BaseModel):
|
|
34
37
|
|