autogluon.tabular 1.3.2b20250722__py3-none-any.whl → 1.4.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (32)
  1. autogluon/tabular/configs/config_helper.py +1 -1
  2. autogluon/tabular/configs/hyperparameter_configs.py +2 -265
  3. autogluon/tabular/configs/presets_configs.py +51 -23
  4. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2023.py +0 -1
  5. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2025.py +310 -0
  6. autogluon/tabular/models/automm/automm_model.py +2 -0
  7. autogluon/tabular/models/automm/ft_transformer.py +4 -1
  8. autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py +18 -6
  9. autogluon/tabular/models/mitra/_internal/data/dataset_finetune.py +8 -4
  10. autogluon/tabular/models/mitra/_internal/data/dataset_split.py +5 -1
  11. autogluon/tabular/models/mitra/_internal/models/tab2d.py +3 -0
  12. autogluon/tabular/models/mitra/mitra_model.py +74 -21
  13. autogluon/tabular/models/mitra/sklearn_interface.py +15 -13
  14. autogluon/tabular/models/realmlp/realmlp_model.py +13 -6
  15. autogluon/tabular/models/tabicl/tabicl_model.py +17 -8
  16. autogluon/tabular/models/tabm/rtdl_num_embeddings.py +3 -0
  17. autogluon/tabular/models/tabm/tabm_model.py +14 -6
  18. autogluon/tabular/models/tabm/tabm_reference.py +2 -0
  19. autogluon/tabular/models/tabpfnmix/tabpfnmix_model.py +4 -0
  20. autogluon/tabular/models/tabpfnv2/tabpfnv2_model.py +29 -12
  21. autogluon/tabular/predictor/predictor.py +79 -26
  22. autogluon/tabular/trainer/abstract_trainer.py +2 -0
  23. autogluon/tabular/version.py +1 -1
  24. {autogluon.tabular-1.3.2b20250722.dist-info → autogluon.tabular-1.4.0.dist-info}/METADATA +42 -20
  25. {autogluon.tabular-1.3.2b20250722.dist-info → autogluon.tabular-1.4.0.dist-info}/RECORD +32 -31
  26. /autogluon.tabular-1.3.2b20250722-py3.9-nspkg.pth → /autogluon.tabular-1.4.0-py3.9-nspkg.pth +0 -0
  27. {autogluon.tabular-1.3.2b20250722.dist-info → autogluon.tabular-1.4.0.dist-info}/LICENSE +0 -0
  28. {autogluon.tabular-1.3.2b20250722.dist-info → autogluon.tabular-1.4.0.dist-info}/NOTICE +0 -0
  29. {autogluon.tabular-1.3.2b20250722.dist-info → autogluon.tabular-1.4.0.dist-info}/WHEEL +0 -0
  30. {autogluon.tabular-1.3.2b20250722.dist-info → autogluon.tabular-1.4.0.dist-info}/namespace_packages.txt +0 -0
  31. {autogluon.tabular-1.3.2b20250722.dist-info → autogluon.tabular-1.4.0.dist-info}/top_level.txt +0 -0
  32. {autogluon.tabular-1.3.2b20250722.dist-info → autogluon.tabular-1.4.0.dist-info}/zip-safe +0 -0
autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2025.py
@@ -0,0 +1,310 @@
+ # optimized for <=10000 samples and <=500 features, with a GPU present
+ hyperparameter_portfolio_zeroshot_2025_small = {
+     "TABPFNV2": [
+         {
+             "ag_args": {'name_suffix': '_r143', 'priority': -1},
+             "average_before_softmax": False,
+             "classification_model_path": 'tabpfn-v2-classifier-od3j1g5m.ckpt',
+             "inference_config/FINGERPRINT_FEATURE": False,
+             "inference_config/OUTLIER_REMOVAL_STD": None,
+             "inference_config/POLYNOMIAL_FEATURES": 'no',
+             "inference_config/PREPROCESS_TRANSFORMS": [{'append_original': True, 'categorical_name': 'ordinal_very_common_categories_shuffled', 'global_transformer_name': None, 'name': 'safepower', 'subsample_features': -1}, {'append_original': True, 'categorical_name': 'ordinal_very_common_categories_shuffled', 'global_transformer_name': None, 'name': 'quantile_uni', 'subsample_features': -1}],
+             "inference_config/REGRESSION_Y_PREPROCESS_TRANSFORMS": [None, 'power'],
+             "inference_config/SUBSAMPLE_SAMPLES": 0.99,
+             "model_type": 'single',
+             "n_ensemble_repeats": 4,
+             "regression_model_path": 'tabpfn-v2-regressor-wyl4o83o.ckpt',
+             "softmax_temperature": 0.75,
+         },
+         {
+             "ag_args": {'name_suffix': '_r94', 'priority': -3},
+             "average_before_softmax": True,
+             "classification_model_path": 'tabpfn-v2-classifier-vutqq28w.ckpt',
+             "inference_config/FINGERPRINT_FEATURE": True,
+             "inference_config/OUTLIER_REMOVAL_STD": None,
+             "inference_config/POLYNOMIAL_FEATURES": 'no',
+             "inference_config/PREPROCESS_TRANSFORMS": [{'append_original': True, 'categorical_name': 'ordinal_very_common_categories_shuffled', 'global_transformer_name': None, 'name': 'quantile_uni', 'subsample_features': 0.99}],
+             "inference_config/REGRESSION_Y_PREPROCESS_TRANSFORMS": [None],
+             "inference_config/SUBSAMPLE_SAMPLES": None,
+             "model_type": 'single',
+             "n_ensemble_repeats": 4,
+             "regression_model_path": 'tabpfn-v2-regressor-5wof9ojf.ckpt',
+             "softmax_temperature": 0.9,
+         },
+         {
+             "ag_args": {'name_suffix': '_r181', 'priority': -4},
+             "average_before_softmax": False,
+             "classification_model_path": 'tabpfn-v2-classifier-llderlii.ckpt',
+             "inference_config/FINGERPRINT_FEATURE": False,
+             "inference_config/OUTLIER_REMOVAL_STD": 9.0,
+             "inference_config/POLYNOMIAL_FEATURES": 50,
+             "inference_config/PREPROCESS_TRANSFORMS": [{'append_original': True, 'categorical_name': 'onehot', 'global_transformer_name': 'svd', 'name': 'quantile_uni_coarse', 'subsample_features': 0.99}],
+             "inference_config/REGRESSION_Y_PREPROCESS_TRANSFORMS": ['power'],
+             "inference_config/SUBSAMPLE_SAMPLES": None,
+             "model_type": 'single',
+             "n_ensemble_repeats": 4,
+             "regression_model_path": 'tabpfn-v2-regressor.ckpt',
+             "softmax_temperature": 0.95,
+         },
+     ],
+     "GBM": [
+         {
+             "ag_args": {'name_suffix': '_r33', 'priority': -2},
+             "bagging_fraction": 0.9625293420216,
+             "bagging_freq": 1,
+             "cat_l2": 0.1236875455555,
+             "cat_smooth": 68.8584757332856,
+             "extra_trees": False,
+             "feature_fraction": 0.6189215809382,
+             "lambda_l1": 0.1641757352921,
+             "lambda_l2": 0.6937755557881,
+             "learning_rate": 0.0154031028561,
+             "max_cat_to_onehot": 17,
+             "min_data_in_leaf": 1,
+             "min_data_per_group": 30,
+             "num_leaves": 68,
+         },
+         {
+             "ag_args": {'name_suffix': '_r21', 'priority': -16},
+             "bagging_fraction": 0.7218730663234,
+             "bagging_freq": 1,
+             "cat_l2": 0.0296205152578,
+             "cat_smooth": 0.0010255271303,
+             "extra_trees": False,
+             "feature_fraction": 0.4557131604374,
+             "lambda_l1": 0.5219704038237,
+             "lambda_l2": 0.1070959487853,
+             "learning_rate": 0.0055891584996,
+             "max_cat_to_onehot": 71,
+             "min_data_in_leaf": 50,
+             "min_data_per_group": 10,
+             "num_leaves": 30,
+         },
+         {
+             "ag_args": {'name_suffix': '_r11', 'priority': -19},
+             "bagging_fraction": 0.775784726514,
+             "bagging_freq": 1,
+             "cat_l2": 0.3888471449178,
+             "cat_smooth": 0.0057144748021,
+             "extra_trees": True,
+             "feature_fraction": 0.7732354787904,
+             "lambda_l1": 0.2211002452568,
+             "lambda_l2": 1.1318405980187,
+             "learning_rate": 0.0090151778542,
+             "max_cat_to_onehot": 15,
+             "min_data_in_leaf": 4,
+             "min_data_per_group": 15,
+             "num_leaves": 2,
+         },
+     ],
+     "CAT": [
+         {
+             "ag_args": {'priority': -5},
+         },
+         {
+             "ag_args": {'name_suffix': '_r51', 'priority': -10},
+             "boosting_type": 'Plain',
+             "bootstrap_type": 'Bernoulli',
+             "colsample_bylevel": 0.8771035272558,
+             "depth": 7,
+             "grow_policy": 'SymmetricTree',
+             "l2_leaf_reg": 2.0107286863021,
+             "leaf_estimation_iterations": 2,
+             "learning_rate": 0.0058424016622,
+             "max_bin": 254,
+             "max_ctr_complexity": 4,
+             "model_size_reg": 0.1307400355809,
+             "one_hot_max_size": 23,
+             "subsample": 0.809527841437,
+         },
+         {
+             "ag_args": {'name_suffix': '_r10', 'priority': -12},
+             "boosting_type": 'Plain',
+             "bootstrap_type": 'Bernoulli',
+             "colsample_bylevel": 0.8994502668431,
+             "depth": 6,
+             "grow_policy": 'Depthwise',
+             "l2_leaf_reg": 1.8187025215896,
+             "leaf_estimation_iterations": 7,
+             "learning_rate": 0.005177304142,
+             "max_bin": 254,
+             "max_ctr_complexity": 4,
+             "model_size_reg": 0.5247386875068,
+             "one_hot_max_size": 53,
+             "subsample": 0.8705228845742,
+         },
+         {
+             "ag_args": {'name_suffix': '_r24', 'priority': -15},
+             "boosting_type": 'Plain',
+             "bootstrap_type": 'Bernoulli',
+             "colsample_bylevel": 0.8597809376276,
+             "depth": 8,
+             "grow_policy": 'Depthwise',
+             "l2_leaf_reg": 0.3628261923976,
+             "leaf_estimation_iterations": 5,
+             "learning_rate": 0.016851077771,
+             "max_bin": 254,
+             "max_ctr_complexity": 4,
+             "model_size_reg": 0.1253820547902,
+             "one_hot_max_size": 20,
+             "subsample": 0.8120271122061,
+         },
+         {
+             "ag_args": {'name_suffix': '_r91', 'priority': -17},
+             "boosting_type": 'Plain',
+             "bootstrap_type": 'Bernoulli',
+             "colsample_bylevel": 0.8959275863514,
+             "depth": 4,
+             "grow_policy": 'SymmetricTree',
+             "l2_leaf_reg": 0.0026915894253,
+             "leaf_estimation_iterations": 12,
+             "learning_rate": 0.0475233791203,
+             "max_bin": 254,
+             "max_ctr_complexity": 5,
+             "model_size_reg": 0.1633175256924,
+             "one_hot_max_size": 11,
+             "subsample": 0.798554178926,
+         },
+     ],
+     "TABM": [
+         {
+             "ag_args": {'name_suffix': '_r184', 'priority': -6},
+             "amp": False,
+             "arch_type": 'tabm-mini',
+             "batch_size": 'auto',
+             "d_block": 864,
+             "d_embedding": 24,
+             "dropout": 0.0,
+             "gradient_clipping_norm": 1.0,
+             "lr": 0.0019256819924656217,
+             "n_blocks": 3,
+             "num_emb_n_bins": 3,
+             "num_emb_type": 'pwl',
+             "patience": 16,
+             "share_training_batches": False,
+             "tabm_k": 32,
+             "weight_decay": 0.0,
+         },
+         {
+             "ag_args": {'name_suffix': '_r69', 'priority': -7},
+             "amp": False,
+             "arch_type": 'tabm-mini',
+             "batch_size": 'auto',
+             "d_block": 848,
+             "d_embedding": 28,
+             "dropout": 0.40215621636031007,
+             "gradient_clipping_norm": 1.0,
+             "lr": 0.0010413640454559532,
+             "n_blocks": 3,
+             "num_emb_n_bins": 18,
+             "num_emb_type": 'pwl',
+             "patience": 16,
+             "share_training_batches": False,
+             "tabm_k": 32,
+             "weight_decay": 0.0,
+         },
+         {
+             "ag_args": {'name_suffix': '_r52', 'priority': -11},
+             "amp": False,
+             "arch_type": 'tabm-mini',
+             "batch_size": 'auto',
+             "d_block": 1024,
+             "d_embedding": 32,
+             "dropout": 0.0,
+             "gradient_clipping_norm": 1.0,
+             "lr": 0.0006297851297842611,
+             "n_blocks": 4,
+             "num_emb_n_bins": 22,
+             "num_emb_type": 'pwl',
+             "patience": 16,
+             "share_training_batches": False,
+             "tabm_k": 32,
+             "weight_decay": 0.06900108498839816,
+         },
+         {
+             "ag_args": {'priority': -13},
+         },
+         {
+             "ag_args": {'name_suffix': '_r191', 'priority': -14},
+             "amp": False,
+             "arch_type": 'tabm-mini',
+             "batch_size": 'auto',
+             "d_block": 864,
+             "d_embedding": 8,
+             "dropout": 0.45321529282058803,
+             "gradient_clipping_norm": 1.0,
+             "lr": 0.0003781238075322413,
+             "n_blocks": 4,
+             "num_emb_n_bins": 27,
+             "num_emb_type": 'pwl',
+             "patience": 16,
+             "share_training_batches": False,
+             "tabm_k": 32,
+             "weight_decay": 0.01766851962579851,
+         },
+         {
+             "ag_args": {'name_suffix': '_r49', 'priority': -20},
+             "amp": False,
+             "arch_type": 'tabm-mini',
+             "batch_size": 'auto',
+             "d_block": 640,
+             "d_embedding": 28,
+             "dropout": 0.15296207419190627,
+             "gradient_clipping_norm": 1.0,
+             "lr": 0.002277678490593717,
+             "n_blocks": 3,
+             "num_emb_n_bins": 48,
+             "num_emb_type": 'pwl',
+             "patience": 16,
+             "share_training_batches": False,
+             "tabm_k": 32,
+             "weight_decay": 0.0578159148243893,
+         },
+     ],
+     "TABICL": [
+         {
+             "ag_args": {'priority': -8},
+         },
+     ],
+     "XGB": [
+         {
+             "ag_args": {'name_suffix': '_r171', 'priority': -9},
+             "colsample_bylevel": 0.9213705632288,
+             "colsample_bynode": 0.6443385965381,
+             "enable_categorical": True,
+             "grow_policy": 'lossguide',
+             "learning_rate": 0.0068171645251,
+             "max_cat_to_onehot": 8,
+             "max_depth": 6,
+             "max_leaves": 10,
+             "min_child_weight": 0.0507304250576,
+             "reg_alpha": 4.2446346389037,
+             "reg_lambda": 1.4800570021253,
+             "subsample": 0.9656290596647,
+         },
+         {
+             "ag_args": {'name_suffix': '_r40', 'priority': -18},
+             "colsample_bylevel": 0.6377491713202,
+             "colsample_bynode": 0.9237625621103,
+             "enable_categorical": True,
+             "grow_policy": 'lossguide',
+             "learning_rate": 0.0112462621131,
+             "max_cat_to_onehot": 33,
+             "max_depth": 10,
+             "max_leaves": 35,
+             "min_child_weight": 0.1403464856034,
+             "reg_alpha": 3.4960653958503,
+             "reg_lambda": 1.3062320805235,
+             "subsample": 0.6948898835178,
+         },
+     ],
+     "MITRA": [
+         {
+             "n_estimators": 1,
+             "fine_tune": True,
+             "fine_tune_steps": 50,
+             "ag.num_gpus": 1,
+             "ag_args": {'priority': -21},
+         },
+     ],
+ }
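
For context: a zeroshot portfolio like the one above is an ordinary AutoGluon `hyperparameters` dict, so it can be passed straight to `TabularPredictor.fit()`. A minimal sketch, not part of this diff (the CSV path and label column are assumptions, and only two portfolio entries are excerpted):

```python
from autogluon.tabular import TabularDataset, TabularPredictor

train_data = TabularDataset("train.csv")  # assumed: a local CSV with a "class" column

# Two-entry excerpt of the portfolio above. Models with higher (less negative)
# ag_args priority are trained first, so the GBM "_r33" config runs before TabICL.
portfolio_excerpt = {
    "GBM": [{"ag_args": {"name_suffix": "_r33", "priority": -2}, "learning_rate": 0.0154031028561}],
    "TABICL": [{"ag_args": {"priority": -8}}],
}

predictor = TabularPredictor(label="class").fit(
    train_data,
    hyperparameters=portfolio_excerpt,
    time_limit=3600,
)
```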
autogluon/tabular/models/automm/automm_model.py
@@ -65,6 +65,8 @@ class MultiModalPredictorModel(AbstractModel):
              Names of the features.
          feature_metadata
              The feature metadata.
+
+         .. versionadded:: 0.3.0
          """
          super().__init__(**kwargs)
          self._label_column_name = None
autogluon/tabular/models/automm/ft_transformer.py
@@ -17,7 +17,8 @@ class FTTransformerModel(MultiModalPredictorModel):
      ag_name = "FTTransformer"

      def __init__(self, **kwargs):
-         """Wrapper of autogluon.multimodal.MultiModalPredictor.
+         """
+         FT-Transformer model.

          The features can be a mix of
          - categorical column
@@ -48,6 +49,8 @@ class FTTransformerModel(MultiModalPredictorModel):
              Names of the features.
          feature_metadata
              The feature metadata.
+
+         .. versionadded:: 0.6.0
          """
          super().__init__(**kwargs)

autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py
@@ -24,10 +24,16 @@ class TrainerFinetune(BaseEstimator):
          cfg: ConfigRun,
          model: torch.nn.Module,
          n_classes: int,
-         device: str
-     ) -> None:
+         device: str,
+         rng: np.random.RandomState = None,
+         verbose: bool = True,
+     ):

          self.cfg = cfg
+         if rng is None:
+             rng = np.random.RandomState(self.cfg.seed)
+         self.rng = rng
+         self.verbose = verbose
          self.device = device
          self.model = model.to(self.device, non_blocking=True)
          self.n_classes = n_classes
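
The `rng` parameter added above (and threaded through the dataset helpers below) replaces draws from the global `np.random` state with a single generator seeded from `cfg.seed`. A minimal sketch of the difference, with illustrative names not taken from the diff:

```python
import numpy as np

# Global state: any np.random.* call made elsewhere between draws shifts the
# sequence, so results depend on unrelated code paths.
np.random.seed(0)
a = np.random.choice(100, size=5, replace=False)

# Dedicated generator: the sequence depends only on this object's seed and
# its own draws, mirroring TrainerFinetune's self.rng.
rng = np.random.RandomState(0)
b = rng.choice(100, size=5, replace=False)

assert (a == b).all()  # identical first draw; only rng stays isolated afterwards
```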
@@ -81,13 +87,15 @@ class TrainerFinetune(BaseEstimator):
              y = y_train_transformed,
              task = self.cfg.task,
              max_samples_support = self.cfg.hyperparams['max_samples_support'],
-             max_samples_query = self.cfg.hyperparams['max_samples_query']
+             max_samples_query = self.cfg.hyperparams['max_samples_query'],
+             rng=self.rng,
          )

          self.checkpoint.reset(self.model)

          metrics_valid = self.evaluate(x_train, y_train, x_val, y_val)
-         self.log_start_metrics(metrics_valid)
+         if self.verbose:
+             self.log_start_metrics(metrics_valid)
          self.checkpoint(self.model, metrics_valid.loss)

          start_time = time.time()
@@ -154,13 +162,15 @@ class TrainerFinetune(BaseEstimator):
          metrics_train = prediction_metrics_tracker.get_metrics()
          metrics_valid = self.evaluate(x_train, y_train, x_val, y_val)

-         self.log_metrics(epoch, metrics_train, metrics_valid)
+         if self.verbose:
+             self.log_metrics(epoch, metrics_train, metrics_valid)

          self.checkpoint(self.model, metrics_valid.loss)

          self.early_stopping(metrics_valid.metrics[self.metric])
          if self.early_stopping.we_should_stop():
-             logger.info("Early stopping")
+             if self.verbose:
+                 logger.info("Early stopping")
              break

          if self.cfg.hyperparams["budget"] is not None and self.cfg.hyperparams["budget"] > 0 and time.time() - start_time > self.cfg.hyperparams["budget"]:
@@ -192,6 +202,7 @@ class TrainerFinetune(BaseEstimator):
              y_query = y_query,
              max_samples_support = self.cfg.hyperparams['max_samples_support'],
              max_samples_query = self.cfg.hyperparams['max_samples_query'],
+             rng=self.rng,
          )

          loader = self.make_loader(dataset, training=False)
@@ -246,6 +257,7 @@ class TrainerFinetune(BaseEstimator):
              y_query = None,
              max_samples_support = self.cfg.hyperparams['max_samples_support'],
              max_samples_query = self.cfg.hyperparams['max_samples_query'],
+             rng=self.rng,
          )

          loader = self.make_loader(dataset, training=False)
autogluon/tabular/models/mitra/_internal/data/dataset_finetune.py
@@ -26,13 +26,15 @@ class DatasetFinetune(torch.utils.data.Dataset):
          x_query: np.ndarray,
          y_query: Optional[np.ndarray],
          max_samples_support: int,
-         max_samples_query: int
+         max_samples_query: int,
+         rng: np.random.RandomState,
      ):
          """
          :param: max_features: number of features the tab pfn model has been trained on
          """

          self.cfg = cfg
+         self.rng = rng

          self.x_support = x_support
          self.y_support = y_support
@@ -59,7 +61,7 @@ class DatasetFinetune(torch.utils.data.Dataset):

      def __getitem__(self, idx):

-         support_indices = np.random.choice(
+         support_indices = self.rng.choice(
              self.n_samples_support,
              size=self.support_size,
              replace=False
@@ -101,7 +103,8 @@ def DatasetFinetuneGenerator(
      y: np.ndarray,
      task: Task,
      max_samples_support: int,
-     max_samples_query: int
+     max_samples_query: int,
+     rng: np.random.RandomState,
  ):
      """
      The dataset fine-tune generator is a generator that yields a dataset for fine-tuning.
@@ -112,7 +115,7 @@ def DatasetFinetuneGenerator(

      while True:

-         x_support, x_query, y_support, y_query = make_dataset_split(x=x, y=y, task=task, seed=cfg.seed)
+         x_support, x_query, y_support, y_query = make_dataset_split(x=x, y=y, task=task, seed=rng)
          n_samples_support = x_support.shape[0]
          n_samples_query = x_query.shape[0]

@@ -127,6 +130,7 @@ def DatasetFinetuneGenerator(
              y_query=y_query[:query_size],
              max_samples_support=max_samples_support,
              max_samples_query=max_samples_query,
+             rng=rng,
          )

          yield dataset_finetune
autogluon/tabular/models/mitra/_internal/data/dataset_split.py
@@ -1,3 +1,5 @@
+ from __future__ import annotations
+
  import numpy as np
  from sklearn.model_selection import StratifiedKFold, train_test_split

@@ -19,9 +21,11 @@ def make_dataset_split(x: np.ndarray, y: np.ndarray, task: Task, seed: int) -> t


  def make_stratified_dataset_split(x, y, n_splits=5, seed=0):
+     if isinstance(seed, int):
+         seed = np.random.RandomState(seed)

      # Stratify doesn't shuffle the data, so we shuffle it first
-     permutation = np.random.permutation(len(y))
+     permutation = seed.permutation(len(y))
      x, y = x[permutation], y[permutation]

      min_samples_per_class = np.min(np.bincount(y))
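
`make_stratified_dataset_split` now accepts either an `int` or an `np.random.RandomState` as `seed`. A small standalone sketch of that normalization idiom (the function name here is illustrative, not from the diff):

```python
import numpy as np

def as_random_state(seed) -> np.random.RandomState:
    # Accept an int seed or an already-constructed generator,
    # as the isinstance check above does.
    if isinstance(seed, int):
        seed = np.random.RandomState(seed)
    return seed

rs = as_random_state(0)
assert as_random_state(rs) is rs  # generators pass through unchanged
print(rs.permutation(10))         # reproducible shuffle, no global state touched
```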
autogluon/tabular/models/mitra/_internal/models/tab2d.py
@@ -1,4 +1,5 @@
  import json
+ import logging
  import os
  from typing import Optional, Union

@@ -29,6 +30,8 @@ from ..._internal.models.embedding import (
      Tab2DQuantileEmbeddingX,
  )

+ logger = logging.getLogger(__name__)
+

  class Tab2D(BaseModel):

autogluon/tabular/models/mitra/mitra_model.py
@@ -1,49 +1,56 @@
- # TODO: To ensure deterministic operations we need to set torch.use_deterministic_algorithms(True)
- # and os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'. The CUBLAS environment variable configures
- # the workspace size for certain CUBLAS operations to ensure reproducibility when using CUDA >= 10.2.
- # Both settings are required to ensure deterministic behavior in operations such as matrix multiplications.
- import os
-
- os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
+ from __future__ import annotations

+ import logging
  import os
  from typing import List, Optional

  import pandas as pd
- import torch
- import logging

  from autogluon.common.utils.resource_utils import ResourceManager
  from autogluon.core.models import AbstractModel
+ from autogluon.features.generators import LabelEncoderFeatureGenerator
+ from autogluon.tabular import __version__

  logger = logging.getLogger(__name__)


- # TODO: Needs memory usage estimate method
  class MitraModel(AbstractModel):
+     """
+     Mitra is a tabular foundation model pre-trained purely on synthetic data with the goal
+     of optimizing fine-tuning performance over in-context learning performance.
+     Mitra was developed by the AutoGluon team @ AWS AI.
+
+     Mitra's default hyperparameters outperform all methods for small datasets on TabArena-v0.1 (excluding ensembling): https://tabarena.ai
+
+     Authors: Xiyuan Zhang, Danielle C. Maddix, Junming Yin, Nick Erickson, Abdul Fatir Ansari, Boran Han, Shuai Zhang, Leman Akoglu, Christos Faloutsos, Michael W. Mahoney, Cuixiong Hu, Huzefa Rangwala, George Karypis, Bernie Wang
+     Blog Post: https://www.amazon.science/blog/mitra-mixed-synthetic-priors-for-enhancing-tabular-foundation-models
+     License: Apache-2.0
+
+     .. versionadded:: 1.4.0
+     """
      ag_key = "MITRA"
      ag_name = "Mitra"
      weights_file_name = "model.pt"
      ag_priority = 55

-     def __init__(self, problem_type=None, **kwargs):
+     def __init__(self, **kwargs):
          super().__init__(**kwargs)
-         self.problem_type = problem_type
          self._weights_saved = False
+         self._feature_generator = None

      @staticmethod
      def _get_default_device():
          """Get the best available device for the current system."""
          if ResourceManager.get_gpu_count_torch(cuda_only=True) > 0:
-             logger.info("Using CUDA GPU")
+             logger.log(15, "Using CUDA GPU")
              return "cuda"
          else:
              return "cpu"

      def get_model_cls(self):
-         from .sklearn_interface import MitraClassifier
-
          if self.problem_type in ["binary", "multiclass"]:
+             from .sklearn_interface import MitraClassifier
+
              model_cls = MitraClassifier
          elif self.problem_type == "regression":
              from .sklearn_interface import MitraRegressor
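
Per the new docstring and the install hint added further down in this file, Mitra is selected like any other model key. A hedged usage sketch (the `[mitra]` extra and the hyperparameter names appear in this diff; the dataset is an assumption):

```python
# pip install "autogluon.tabular[mitra]==1.4.0"
from autogluon.tabular import TabularDataset, TabularPredictor

train_data = TabularDataset("train.csv")  # assumed: a small dataset with a "class" column

predictor = TabularPredictor(label="class").fit(
    train_data,
    # Mirrors the MITRA entry of the 2025 portfolio above; fine-tuning on CPU
    # works but is slow, hence requesting a GPU via "ag.num_gpus".
    hyperparameters={
        "MITRA": {"n_estimators": 1, "fine_tune": True, "fine_tune_steps": 50, "ag.num_gpus": 1},
    },
)
```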
@@ -53,6 +60,23 @@ class MitraModel(AbstractModel):
              raise AssertionError(f"Unsupported problem_type: {self.problem_type}")
          return model_cls

+     def _preprocess(self, X: pd.DataFrame, is_train: bool = False, **kwargs) -> pd.DataFrame:
+         X = super()._preprocess(X, **kwargs)
+
+         if is_train:
+             # X will be the training data.
+             self._feature_generator = LabelEncoderFeatureGenerator(verbosity=0)
+             self._feature_generator.fit(X=X)
+
+         # This converts categorical features to numeric via stateful label encoding.
+         if self._feature_generator.features_in:
+             X = X.copy()
+             X[self._feature_generator.features_in] = self._feature_generator.transform(
+                 X=X
+             )
+
+         return X
+
      def _fit(
          self,
          X: pd.DataFrame,
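
The new `_preprocess` uses `LabelEncoderFeatureGenerator` to turn categorical columns into integer codes before Mitra sees them. A standalone sketch of that fit/transform contract, on assumed toy data:

```python
import pandas as pd
from autogluon.features.generators import LabelEncoderFeatureGenerator

X_train = pd.DataFrame({
    "color": pd.Categorical(["red", "blue", "red"]),
    "size": [1.0, 2.0, 3.0],
})

gen = LabelEncoderFeatureGenerator(verbosity=0)
gen.fit(X=X_train)

# features_in lists the categorical columns the generator encodes; numeric
# columns such as "size" are left untouched, exactly as in _preprocess above.
X_enc = X_train.copy()
X_enc[gen.features_in] = gen.transform(X=X_train)
print(X_enc.dtypes)
```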
@@ -61,11 +85,25 @@ class MitraModel(AbstractModel):
          y_val: pd.Series = None,
          time_limit: float = None,
          num_cpus: int = 1,
+         num_gpus: float = 0,
+         verbosity: int = 2,
          **kwargs,
      ):
          # TODO: Reset the number of threads based on the specified num_cpus
          need_to_reset_torch_threads = False
          torch_threads_og = None
+
+         try:
+             model_cls = self.get_model_cls()
+             import torch
+         except ImportError as err:
+             logger.log(
+                 40,
+                 f"\tFailed to import Mitra! To use the Mitra model, "
+                 f"do: `pip install autogluon.tabular[mitra]=={__version__}`.",
+             )
+             raise err
+
          if num_cpus is not None and isinstance(num_cpus, (int, float)):
              torch_threads_og = torch.get_num_threads()
              if torch_threads_og != num_cpus:
@@ -73,9 +111,21 @@ class MitraModel(AbstractModel):
                  torch.set_num_threads(num_cpus)
                  need_to_reset_torch_threads = True

-         model_cls = self.get_model_cls()
-
          hyp = self._get_model_params()
+
+         if hyp.get("device", None) is None:
+             if num_gpus == 0:
+                 hyp["device"] = "cpu"
+             else:
+                 hyp["device"] = self._get_default_device()
+
+         if hyp["device"] == "cpu" and hyp.get("fine_tune", True):
+             logger.log(
+                 30,
+                 f"\tWarning: Attempting to fine-tune Mitra on CPU. This will be very slow. "
+                 f"We strongly recommend using a GPU instance to fine-tune Mitra."
+             )
+
          if "state_dict_classification" in hyp:
              state_dict_classification = hyp.pop("state_dict_classification")
              if self.problem_type in ["binary", "multiclass"]:
@@ -85,11 +135,14 @@ class MitraModel(AbstractModel):
          if self.problem_type in ["regression"]:
              hyp["state_dict"] = state_dict_regression

+         if "verbose" not in hyp:
+             hyp["verbose"] = verbosity >= 3
+
          self.model = model_cls(
              **hyp,
          )

-         X = self.preprocess(X)
+         X = self.preprocess(X, is_train=True)
          if X_val is not None:
              X_val = self.preprocess(X_val)

@@ -106,7 +159,6 @@ class MitraModel(AbstractModel):

      def _set_default_params(self):
          default_params = {
-             "device": self._get_default_device(),
              "n_estimators": 1,
          }
          for param, val in default_params.items():
@@ -196,12 +248,13 @@ class MitraModel(AbstractModel):
          X: pd.DataFrame,
          **kwargs,
      ) -> int:
-         return max(
+         # Multiply by 0.9 as currently this is overly safe
+         return int(0.9 * max(
              cls._estimate_memory_usage_static_cpu_icl(X=X, **kwargs),
              cls._estimate_memory_usage_static_cpu_ft_icl(X=X, **kwargs),
              cls._estimate_memory_usage_static_gpu_cpu(X=X, **kwargs),
              cls._estimate_memory_usage_static_gpu_gpu(X=X, **kwargs),
-         )
+         ))

      @classmethod
      def _estimate_memory_usage_static_cpu_icl(