autogluon.tabular 1.3.2b20250722__py3-none-any.whl → 1.4.0b20250724__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Files changed (26)
  1. autogluon/tabular/configs/hyperparameter_configs.py +2 -265
  2. autogluon/tabular/configs/presets_configs.py +47 -21
  3. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2023.py +0 -1
  4. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2025.py +309 -0
  5. autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py +18 -6
  6. autogluon/tabular/models/mitra/_internal/data/dataset_finetune.py +8 -4
  7. autogluon/tabular/models/mitra/_internal/data/dataset_split.py +5 -1
  8. autogluon/tabular/models/mitra/_internal/models/tab2d.py +3 -0
  9. autogluon/tabular/models/mitra/mitra_model.py +72 -21
  10. autogluon/tabular/models/mitra/sklearn_interface.py +15 -13
  11. autogluon/tabular/models/tabicl/tabicl_model.py +3 -3
  12. autogluon/tabular/models/tabm/rtdl_num_embeddings.py +3 -0
  13. autogluon/tabular/models/tabm/tabm_reference.py +2 -0
  14. autogluon/tabular/models/tabpfnv2/tabpfnv2_model.py +15 -6
  15. autogluon/tabular/predictor/predictor.py +41 -1
  16. autogluon/tabular/trainer/abstract_trainer.py +2 -0
  17. autogluon/tabular/version.py +1 -1
  18. {autogluon.tabular-1.3.2b20250722.dist-info → autogluon.tabular-1.4.0b20250724.dist-info}/METADATA +37 -15
  19. {autogluon.tabular-1.3.2b20250722.dist-info → autogluon.tabular-1.4.0b20250724.dist-info}/RECORD +26 -25
  20. /autogluon.tabular-1.3.2b20250722-py3.9-nspkg.pth → /autogluon.tabular-1.4.0b20250724-py3.9-nspkg.pth +0 -0
  21. {autogluon.tabular-1.3.2b20250722.dist-info → autogluon.tabular-1.4.0b20250724.dist-info}/LICENSE +0 -0
  22. {autogluon.tabular-1.3.2b20250722.dist-info → autogluon.tabular-1.4.0b20250724.dist-info}/NOTICE +0 -0
  23. {autogluon.tabular-1.3.2b20250722.dist-info → autogluon.tabular-1.4.0b20250724.dist-info}/WHEEL +0 -0
  24. {autogluon.tabular-1.3.2b20250722.dist-info → autogluon.tabular-1.4.0b20250724.dist-info}/namespace_packages.txt +0 -0
  25. {autogluon.tabular-1.3.2b20250722.dist-info → autogluon.tabular-1.4.0b20250724.dist-info}/top_level.txt +0 -0
  26. {autogluon.tabular-1.3.2b20250722.dist-info → autogluon.tabular-1.4.0b20250724.dist-info}/zip-safe +0 -0
autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2025.py

@@ -0,0 +1,309 @@
+# optimized for <=10000 samples and <=500 features, with a GPU present
+hyperparameter_portfolio_zeroshot_2025_small = {
+    "TABPFNV2": [
+        {
+            "ag_args": {'name_suffix': '_r143', 'priority': -1},
+            "average_before_softmax": False,
+            "classification_model_path": 'tabpfn-v2-classifier-od3j1g5m.ckpt',
+            "inference_config/FINGERPRINT_FEATURE": False,
+            "inference_config/OUTLIER_REMOVAL_STD": None,
+            "inference_config/POLYNOMIAL_FEATURES": 'no',
+            "inference_config/PREPROCESS_TRANSFORMS": [{'append_original': True, 'categorical_name': 'ordinal_very_common_categories_shuffled', 'global_transformer_name': None, 'name': 'safepower', 'subsample_features': -1}, {'append_original': True, 'categorical_name': 'ordinal_very_common_categories_shuffled', 'global_transformer_name': None, 'name': 'quantile_uni', 'subsample_features': -1}],
+            "inference_config/REGRESSION_Y_PREPROCESS_TRANSFORMS": [None, 'power'],
+            "inference_config/SUBSAMPLE_SAMPLES": 0.99,
+            "model_type": 'single',
+            "n_ensemble_repeats": 4,
+            "regression_model_path": 'tabpfn-v2-regressor-wyl4o83o.ckpt',
+            "softmax_temperature": 0.75,
+        },
+        {
+            "ag_args": {'name_suffix': '_r94', 'priority': -3},
+            "average_before_softmax": True,
+            "classification_model_path": 'tabpfn-v2-classifier-vutqq28w.ckpt',
+            "inference_config/FINGERPRINT_FEATURE": True,
+            "inference_config/OUTLIER_REMOVAL_STD": None,
+            "inference_config/POLYNOMIAL_FEATURES": 'no',
+            "inference_config/PREPROCESS_TRANSFORMS": [{'append_original': True, 'categorical_name': 'ordinal_very_common_categories_shuffled', 'global_transformer_name': None, 'name': 'quantile_uni', 'subsample_features': 0.99}],
+            "inference_config/REGRESSION_Y_PREPROCESS_TRANSFORMS": [None],
+            "inference_config/SUBSAMPLE_SAMPLES": None,
+            "model_type": 'single',
+            "n_ensemble_repeats": 4,
+            "regression_model_path": 'tabpfn-v2-regressor-5wof9ojf.ckpt',
+            "softmax_temperature": 0.9,
+        },
+        {
+            "ag_args": {'name_suffix': '_r181', 'priority': -4},
+            "average_before_softmax": False,
+            "classification_model_path": 'tabpfn-v2-classifier-llderlii.ckpt',
+            "inference_config/FINGERPRINT_FEATURE": False,
+            "inference_config/OUTLIER_REMOVAL_STD": 9.0,
+            "inference_config/POLYNOMIAL_FEATURES": 50,
+            "inference_config/PREPROCESS_TRANSFORMS": [{'append_original': True, 'categorical_name': 'onehot', 'global_transformer_name': 'svd', 'name': 'quantile_uni_coarse', 'subsample_features': 0.99}],
+            "inference_config/REGRESSION_Y_PREPROCESS_TRANSFORMS": ['power'],
+            "inference_config/SUBSAMPLE_SAMPLES": None,
+            "model_type": 'single',
+            "n_ensemble_repeats": 4,
+            "regression_model_path": 'tabpfn-v2-regressor.ckpt',
+            "softmax_temperature": 0.95,
+        },
+    ],
+    "GBM": [
+        {
+            "ag_args": {'name_suffix': '_r33', 'priority': -2},
+            "bagging_fraction": 0.9625293420216,
+            "bagging_freq": 1,
+            "cat_l2": 0.1236875455555,
+            "cat_smooth": 68.8584757332856,
+            "extra_trees": False,
+            "feature_fraction": 0.6189215809382,
+            "lambda_l1": 0.1641757352921,
+            "lambda_l2": 0.6937755557881,
+            "learning_rate": 0.0154031028561,
+            "max_cat_to_onehot": 17,
+            "min_data_in_leaf": 1,
+            "min_data_per_group": 30,
+            "num_leaves": 68,
+        },
+        {
+            "ag_args": {'name_suffix': '_r21', 'priority': -16},
+            "bagging_fraction": 0.7218730663234,
+            "bagging_freq": 1,
+            "cat_l2": 0.0296205152578,
+            "cat_smooth": 0.0010255271303,
+            "extra_trees": False,
+            "feature_fraction": 0.4557131604374,
+            "lambda_l1": 0.5219704038237,
+            "lambda_l2": 0.1070959487853,
+            "learning_rate": 0.0055891584996,
+            "max_cat_to_onehot": 71,
+            "min_data_in_leaf": 50,
+            "min_data_per_group": 10,
+            "num_leaves": 30,
+        },
+        {
+            "ag_args": {'name_suffix': '_r11', 'priority': -19},
+            "bagging_fraction": 0.775784726514,
+            "bagging_freq": 1,
+            "cat_l2": 0.3888471449178,
+            "cat_smooth": 0.0057144748021,
+            "extra_trees": True,
+            "feature_fraction": 0.7732354787904,
+            "lambda_l1": 0.2211002452568,
+            "lambda_l2": 1.1318405980187,
+            "learning_rate": 0.0090151778542,
+            "max_cat_to_onehot": 15,
+            "min_data_in_leaf": 4,
+            "min_data_per_group": 15,
+            "num_leaves": 2,
+        },
+    ],
+    "CAT": [
+        {
+            "ag_args": {'priority': -5},
+        },
+        {
+            "ag_args": {'name_suffix': '_r51', 'priority': -10},
+            "boosting_type": 'Plain',
+            "bootstrap_type": 'Bernoulli',
+            "colsample_bylevel": 0.8771035272558,
+            "depth": 7,
+            "grow_policy": 'SymmetricTree',
+            "l2_leaf_reg": 2.0107286863021,
+            "leaf_estimation_iterations": 2,
+            "learning_rate": 0.0058424016622,
+            "max_bin": 254,
+            "max_ctr_complexity": 4,
+            "model_size_reg": 0.1307400355809,
+            "one_hot_max_size": 23,
+            "subsample": 0.809527841437,
+        },
+        {
+            "ag_args": {'name_suffix': '_r10', 'priority': -12},
+            "boosting_type": 'Plain',
+            "bootstrap_type": 'Bernoulli',
+            "colsample_bylevel": 0.8994502668431,
+            "depth": 6,
+            "grow_policy": 'Depthwise',
+            "l2_leaf_reg": 1.8187025215896,
+            "leaf_estimation_iterations": 7,
+            "learning_rate": 0.005177304142,
+            "max_bin": 254,
+            "max_ctr_complexity": 4,
+            "model_size_reg": 0.5247386875068,
+            "one_hot_max_size": 53,
+            "subsample": 0.8705228845742,
+        },
+        {
+            "ag_args": {'name_suffix': '_r24', 'priority': -15},
+            "boosting_type": 'Plain',
+            "bootstrap_type": 'Bernoulli',
+            "colsample_bylevel": 0.8597809376276,
+            "depth": 8,
+            "grow_policy": 'Depthwise',
+            "l2_leaf_reg": 0.3628261923976,
+            "leaf_estimation_iterations": 5,
+            "learning_rate": 0.016851077771,
+            "max_bin": 254,
+            "max_ctr_complexity": 4,
+            "model_size_reg": 0.1253820547902,
+            "one_hot_max_size": 20,
+            "subsample": 0.8120271122061,
+        },
+        {
+            "ag_args": {'name_suffix': '_r91', 'priority': -17},
+            "boosting_type": 'Plain',
+            "bootstrap_type": 'Bernoulli',
+            "colsample_bylevel": 0.8959275863514,
+            "depth": 4,
+            "grow_policy": 'SymmetricTree',
+            "l2_leaf_reg": 0.0026915894253,
+            "leaf_estimation_iterations": 12,
+            "learning_rate": 0.0475233791203,
+            "max_bin": 254,
+            "max_ctr_complexity": 5,
+            "model_size_reg": 0.1633175256924,
+            "one_hot_max_size": 11,
+            "subsample": 0.798554178926,
+        },
+    ],
+    "TABM": [
+        {
+            "ag_args": {'name_suffix': '_r184', 'priority': -6},
+            "amp": False,
+            "arch_type": 'tabm-mini',
+            "batch_size": 'auto',
+            "d_block": 864,
+            "d_embedding": 24,
+            "dropout": 0.0,
+            "gradient_clipping_norm": 1.0,
+            "lr": 0.0019256819924656217,
+            "n_blocks": 3,
+            "num_emb_n_bins": 3,
+            "num_emb_type": 'pwl',
+            "patience": 16,
+            "share_training_batches": False,
+            "tabm_k": 32,
+            "weight_decay": 0.0,
+        },
+        {
+            "ag_args": {'name_suffix': '_r69', 'priority': -7},
+            "amp": False,
+            "arch_type": 'tabm-mini',
+            "batch_size": 'auto',
+            "d_block": 848,
+            "d_embedding": 28,
+            "dropout": 0.40215621636031007,
+            "gradient_clipping_norm": 1.0,
+            "lr": 0.0010413640454559532,
+            "n_blocks": 3,
+            "num_emb_n_bins": 18,
+            "num_emb_type": 'pwl',
+            "patience": 16,
+            "share_training_batches": False,
+            "tabm_k": 32,
+            "weight_decay": 0.0,
+        },
+        {
+            "ag_args": {'name_suffix': '_r52', 'priority': -11},
+            "amp": False,
+            "arch_type": 'tabm-mini',
+            "batch_size": 'auto',
+            "d_block": 1024,
+            "d_embedding": 32,
+            "dropout": 0.0,
+            "gradient_clipping_norm": 1.0,
+            "lr": 0.0006297851297842611,
+            "n_blocks": 4,
+            "num_emb_n_bins": 22,
+            "num_emb_type": 'pwl',
+            "patience": 16,
+            "share_training_batches": False,
+            "tabm_k": 32,
+            "weight_decay": 0.06900108498839816,
+        },
+        {
+            "ag_args": {'priority': -13},
+        },
+        {
+            "ag_args": {'name_suffix': '_r191', 'priority': -14},
+            "amp": False,
+            "arch_type": 'tabm-mini',
+            "batch_size": 'auto',
+            "d_block": 864,
+            "d_embedding": 8,
+            "dropout": 0.45321529282058803,
+            "gradient_clipping_norm": 1.0,
+            "lr": 0.0003781238075322413,
+            "n_blocks": 4,
+            "num_emb_n_bins": 27,
+            "num_emb_type": 'pwl',
+            "patience": 16,
+            "share_training_batches": False,
+            "tabm_k": 32,
+            "weight_decay": 0.01766851962579851,
+        },
+        {
+            "ag_args": {'name_suffix': '_r49', 'priority': -20},
+            "amp": False,
+            "arch_type": 'tabm-mini',
+            "batch_size": 'auto',
+            "d_block": 640,
+            "d_embedding": 28,
+            "dropout": 0.15296207419190627,
+            "gradient_clipping_norm": 1.0,
+            "lr": 0.002277678490593717,
+            "n_blocks": 3,
+            "num_emb_n_bins": 48,
+            "num_emb_type": 'pwl',
+            "patience": 16,
+            "share_training_batches": False,
+            "tabm_k": 32,
+            "weight_decay": 0.0578159148243893,
+        },
+    ],
+    "TABICL": [
+        {
+            "ag_args": {'priority': -8},
+        },
+    ],
+    "XGB": [
+        {
+            "ag_args": {'name_suffix': '_r171', 'priority': -9},
+            "colsample_bylevel": 0.9213705632288,
+            "colsample_bynode": 0.6443385965381,
+            "enable_categorical": True,
+            "grow_policy": 'lossguide',
+            "learning_rate": 0.0068171645251,
+            "max_cat_to_onehot": 8,
+            "max_depth": 6,
+            "max_leaves": 10,
+            "min_child_weight": 0.0507304250576,
+            "reg_alpha": 4.2446346389037,
+            "reg_lambda": 1.4800570021253,
+            "subsample": 0.9656290596647,
+        },
+        {
+            "ag_args": {'name_suffix': '_r40', 'priority': -18},
+            "colsample_bylevel": 0.6377491713202,
+            "colsample_bynode": 0.9237625621103,
+            "enable_categorical": True,
+            "grow_policy": 'lossguide',
+            "learning_rate": 0.0112462621131,
+            "max_cat_to_onehot": 33,
+            "max_depth": 10,
+            "max_leaves": 35,
+            "min_child_weight": 0.1403464856034,
+            "reg_alpha": 3.4960653958503,
+            "reg_lambda": 1.3062320805235,
+            "subsample": 0.6948898835178,
+        },
+    ],
+    "MITRA": [
+        {
+            "n_estimators": 1,
+            "fine_tune": True,
+            "fine_tune_steps": 50,
+            "ag_args": {'priority': -21},
+        },
+    ],
+}
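
Note: the portfolio is a plain AutoGluon hyperparameters dict, so it should be usable directly with TabularPredictor. A minimal usage sketch (not part of the diff; `train_data` and the "target" label column are assumptions for illustration):

    # Sketch: run the new 2025 zeroshot portfolio via the standard fit() API.
    from autogluon.tabular import TabularPredictor
    from autogluon.tabular.configs.zeroshot.zeroshot_portfolio_2025 import (
        hyperparameter_portfolio_zeroshot_2025_small,
    )

    predictor = TabularPredictor(label="target").fit(
        train_data,  # assumed: a pandas DataFrame containing the "target" column
        hyperparameters=hyperparameter_portfolio_zeroshot_2025_small,
    )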
autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py

@@ -24,10 +24,16 @@ class TrainerFinetune(BaseEstimator):
         cfg: ConfigRun,
         model: torch.nn.Module,
         n_classes: int,
-        device: str
-    ) -> None:
+        device: str,
+        rng: np.random.RandomState = None,
+        verbose: bool = True,
+    ):
 
         self.cfg = cfg
+        if rng is None:
+            rng = np.random.RandomState(self.cfg.seed)
+        self.rng = rng
+        self.verbose = verbose
         self.device = device
         self.model = model.to(self.device, non_blocking=True)
         self.n_classes = n_classes
@@ -81,13 +87,15 @@ class TrainerFinetune(BaseEstimator):
             y = y_train_transformed,
             task = self.cfg.task,
             max_samples_support = self.cfg.hyperparams['max_samples_support'],
-            max_samples_query = self.cfg.hyperparams['max_samples_query']
+            max_samples_query = self.cfg.hyperparams['max_samples_query'],
+            rng=self.rng,
         )
 
         self.checkpoint.reset(self.model)
 
         metrics_valid = self.evaluate(x_train, y_train, x_val, y_val)
-        self.log_start_metrics(metrics_valid)
+        if self.verbose:
+            self.log_start_metrics(metrics_valid)
         self.checkpoint(self.model, metrics_valid.loss)
 
         start_time = time.time()
@@ -154,13 +162,15 @@ class TrainerFinetune(BaseEstimator):
             metrics_train = prediction_metrics_tracker.get_metrics()
             metrics_valid = self.evaluate(x_train, y_train, x_val, y_val)
 
-            self.log_metrics(epoch, metrics_train, metrics_valid)
+            if self.verbose:
+                self.log_metrics(epoch, metrics_train, metrics_valid)
 
             self.checkpoint(self.model, metrics_valid.loss)
 
             self.early_stopping(metrics_valid.metrics[self.metric])
             if self.early_stopping.we_should_stop():
-                logger.info("Early stopping")
+                if self.verbose:
+                    logger.info("Early stopping")
                 break
 
             if self.cfg.hyperparams["budget"] is not None and self.cfg.hyperparams["budget"] > 0 and time.time() - start_time > self.cfg.hyperparams["budget"]:
@@ -192,6 +202,7 @@ class TrainerFinetune(BaseEstimator):
             y_query = y_query,
             max_samples_support = self.cfg.hyperparams['max_samples_support'],
             max_samples_query = self.cfg.hyperparams['max_samples_query'],
+            rng=self.rng,
         )
 
         loader = self.make_loader(dataset, training=False)
@@ -246,6 +257,7 @@ class TrainerFinetune(BaseEstimator):
             y_query = None,
             max_samples_support = self.cfg.hyperparams['max_samples_support'],
             max_samples_query = self.cfg.hyperparams['max_samples_query'],
+            rng=self.rng,
         )
 
         loader = self.make_loader(dataset, training=False)
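
The common thread in these trainer changes: a single np.random.RandomState, seeded from cfg.seed, is created once and threaded through every dataset/sampling call instead of relying on the global np.random state. A minimal sketch of the pattern (names simplified, values illustrative):

    import numpy as np

    rng = np.random.RandomState(0)  # one shared generator, seeded once

    # Every sampling site draws from the same generator, so a fixed seed
    # reproduces the exact same support/query subsets across runs.
    support_indices = rng.choice(1000, size=256, replace=False)
    permutation = rng.permutation(1000)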
autogluon/tabular/models/mitra/_internal/data/dataset_finetune.py

@@ -26,13 +26,15 @@ class DatasetFinetune(torch.utils.data.Dataset):
         x_query: np.ndarray,
         y_query: Optional[np.ndarray],
         max_samples_support: int,
-        max_samples_query: int
+        max_samples_query: int,
+        rng: np.random.RandomState,
     ):
         """
         :param: max_features: number of features the tab pfn model has been trained on
         """
 
         self.cfg = cfg
+        self.rng = rng
 
         self.x_support = x_support
         self.y_support = y_support
@@ -59,7 +61,7 @@ class DatasetFinetune(torch.utils.data.Dataset):
 
     def __getitem__(self, idx):
 
-        support_indices = np.random.choice(
+        support_indices = self.rng.choice(
            self.n_samples_support,
            size=self.support_size,
            replace=False
@@ -101,7 +103,8 @@ def DatasetFinetuneGenerator(
    y: np.ndarray,
    task: Task,
    max_samples_support: int,
-    max_samples_query: int
+    max_samples_query: int,
+    rng: np.random.RandomState,
 ):
    """
    The dataset fine-tune generator is a generator that yields a dataset for fine-tuning.
@@ -112,7 +115,7 @@ def DatasetFinetuneGenerator(
 
    while True:
 
-        x_support, x_query, y_support, y_query = make_dataset_split(x=x, y=y, task=task, seed=cfg.seed)
+        x_support, x_query, y_support, y_query = make_dataset_split(x=x, y=y, task=task, seed=rng)
        n_samples_support = x_support.shape[0]
        n_samples_query = x_query.shape[0]
 
@@ -127,6 +130,7 @@ def DatasetFinetuneGenerator(
            y_query=y_query[:query_size],
            max_samples_support=max_samples_support,
            max_samples_query=max_samples_query,
+            rng=rng,
        )
 
        yield dataset_finetune
autogluon/tabular/models/mitra/_internal/data/dataset_split.py

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import numpy as np
 from sklearn.model_selection import StratifiedKFold, train_test_split
 
@@ -19,9 +21,11 @@ def make_dataset_split(x: np.ndarray, y: np.ndarray, task: Task, seed: int) -> t
 
 
 def make_stratified_dataset_split(x, y, n_splits=5, seed=0):
+    if isinstance(seed, int):
+        seed = np.random.RandomState(seed)
 
     # Stratify doesn't shuffle the data, so we shuffle it first
-    permutation = np.random.permutation(len(y))
+    permutation = seed.permutation(len(y))
     x, y = x[permutation], y[permutation]
 
     min_samples_per_class = np.min(np.bincount(y))
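
With this change, make_stratified_dataset_split accepts either an int or an already-constructed RandomState in its `seed` argument. The normalization idiom in isolation (helper name hypothetical; the diff inlines the check):

    import numpy as np

    def as_random_state(seed):
        # Hypothetical helper illustrating the diff's inline isinstance check.
        if isinstance(seed, int):
            return np.random.RandomState(seed)
        return seed  # already a RandomState; reuse it so its state advances across calls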
autogluon/tabular/models/mitra/_internal/models/tab2d.py

@@ -1,4 +1,5 @@
 import json
+import logging
 import os
 from typing import Optional, Union
 
@@ -29,6 +30,8 @@ from ..._internal.models.embedding import (
     Tab2DQuantileEmbeddingX,
 )
 
+logger = logging.getLogger(__name__)
+
 
 class Tab2D(BaseModel):
 
autogluon/tabular/models/mitra/mitra_model.py

@@ -1,49 +1,43 @@
-# TODO: To ensure deterministic operations we need to set torch.use_deterministic_algorithms(True)
-# and os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'. The CUBLAS environment variable configures
-# the workspace size for certain CUBLAS operations to ensure reproducibility when using CUDA >= 10.2.
-# Both settings are required to ensure deterministic behavior in operations such as matrix multiplications.
-import os
-
-os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
+from __future__ import annotations
 
+import logging
 import os
 from typing import List, Optional
 
 import pandas as pd
-import torch
-import logging
 
 from autogluon.common.utils.resource_utils import ResourceManager
 from autogluon.core.models import AbstractModel
+from autogluon.features.generators import LabelEncoderFeatureGenerator
+from autogluon.tabular import __version__
 
 logger = logging.getLogger(__name__)
 
 
-# TODO: Needs memory usage estimate method
 class MitraModel(AbstractModel):
     ag_key = "MITRA"
     ag_name = "Mitra"
     weights_file_name = "model.pt"
     ag_priority = 55
 
-    def __init__(self, problem_type=None, **kwargs):
+    def __init__(self, **kwargs):
         super().__init__(**kwargs)
-        self.problem_type = problem_type
         self._weights_saved = False
+        self._feature_generator = None
 
     @staticmethod
     def _get_default_device():
         """Get the best available device for the current system."""
         if ResourceManager.get_gpu_count_torch(cuda_only=True) > 0:
-            logger.info("Using CUDA GPU")
+            logger.log(15, "Using CUDA GPU")
             return "cuda"
         else:
             return "cpu"
 
     def get_model_cls(self):
-        from .sklearn_interface import MitraClassifier
-
         if self.problem_type in ["binary", "multiclass"]:
+            from .sklearn_interface import MitraClassifier
+
             model_cls = MitraClassifier
         elif self.problem_type == "regression":
             from .sklearn_interface import MitraRegressor
@@ -53,6 +47,23 @@ class MitraModel(AbstractModel):
             raise AssertionError(f"Unsupported problem_type: {self.problem_type}")
         return model_cls
 
+    def _preprocess(self, X: pd.DataFrame, is_train: bool = False, **kwargs) -> pd.DataFrame:
+        X = super()._preprocess(X, **kwargs)
+
+        if is_train:
+            # X will be the training data.
+            self._feature_generator = LabelEncoderFeatureGenerator(verbosity=0)
+            self._feature_generator.fit(X=X)
+
+        # This converts categorical features to numeric via stateful label encoding.
+        if self._feature_generator.features_in:
+            X = X.copy()
+            X[self._feature_generator.features_in] = self._feature_generator.transform(
+                X=X
+            )
+
+        return X
+
     def _fit(
         self,
         X: pd.DataFrame,
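
The new `_preprocess` fits a LabelEncoderFeatureGenerator on the training data and reapplies it at predict time, replacing categorical columns with integer codes. A self-contained sketch of that fit/transform flow (toy DataFrame is an assumption; the generator calls mirror the diff):

    import pandas as pd
    from autogluon.features.generators import LabelEncoderFeatureGenerator

    X_train = pd.DataFrame({
        "color": pd.Categorical(["red", "blue", "red"]),  # categorical column
        "size": [1.0, 2.0, 3.0],                          # numeric, left untouched
    })

    gen = LabelEncoderFeatureGenerator(verbosity=0)
    gen.fit(X=X_train)

    # Only the columns the generator tracked (`features_in`) are label-encoded.
    if gen.features_in:
        X_train = X_train.copy()
        X_train[gen.features_in] = gen.transform(X=X_train)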
@@ -61,11 +72,25 @@ class MitraModel(AbstractModel):
         y_val: pd.Series = None,
         time_limit: float = None,
         num_cpus: int = 1,
+        num_gpus: float = 0,
+        verbosity: int = 2,
         **kwargs,
     ):
         # TODO: Reset the number of threads based on the specified num_cpus
         need_to_reset_torch_threads = False
         torch_threads_og = None
+
+        try:
+            model_cls = self.get_model_cls()
+            import torch
+        except ImportError as err:
+            logger.log(
+                40,
+                f"\tFailed to import Mitra! To use the Mitra model, "
+                f"do: `pip install autogluon.tabular[mitra]=={__version__}`.",
+            )
+            raise err
+
         if num_cpus is not None and isinstance(num_cpus, (int, float)):
             torch_threads_og = torch.get_num_threads()
             if torch_threads_og != num_cpus:
@@ -73,9 +98,14 @@ class MitraModel(AbstractModel):
             torch.set_num_threads(num_cpus)
             need_to_reset_torch_threads = True
 
-        model_cls = self.get_model_cls()
-
         hyp = self._get_model_params()
+
+        if hyp.get("device", None) is None:
+            if num_gpus == 0:
+                hyp["device"] = "cpu"
+            else:
+                hyp["device"] = self._get_default_device()
+
         if "state_dict_classification" in hyp:
             state_dict_classification = hyp.pop("state_dict_classification")
             if self.problem_type in ["binary", "multiclass"]:
@@ -85,11 +115,14 @@ class MitraModel(AbstractModel):
             if self.problem_type in ["regression"]:
                 hyp["state_dict"] = state_dict_regression
 
+        if "verbose" not in hyp:
+            hyp["verbose"] = verbosity >= 3
+
         self.model = model_cls(
             **hyp,
         )
 
-        X = self.preprocess(X)
+        X = self.preprocess(X, is_train=True)
         if X_val is not None:
             X_val = self.preprocess(X_val)
 
@@ -106,7 +139,6 @@ class MitraModel(AbstractModel):
 
     def _set_default_params(self):
         default_params = {
-            "device": self._get_default_device(),
             "n_estimators": 1,
         }
         for param, val in default_params.items():
@@ -184,6 +216,24 @@ class MitraModel(AbstractModel):
 
         return num_cpus, num_gpus
 
+    def get_minimum_resources(self, is_gpu_available: bool = False) -> dict[str, int | float]:
+        """
+        Parameters
+        ----------
+        is_gpu_available : bool, default = False
+            Whether gpu is available in the system.
+            Model that can be trained both on cpu and gpu can decide the minimum resources based on this.
+
+        Returns a dictionary of minimum resource requirements to fit the model.
+        Subclass should consider overriding this method if it requires more resources to train.
+        If a resource is not part of the output dictionary, it is considered unnecessary.
+        Valid keys: 'num_cpus', 'num_gpus'.
+        """
+        return {
+            "num_cpus": 1,
+            "num_gpus": 0.5,
+        }
+
     def _estimate_memory_usage(self, X: pd.DataFrame, **kwargs) -> int:
         return self.estimate_memory_usage_static(
             X=X, problem_type=self.problem_type, num_classes=self.num_classes, **kwargs
@@ -196,12 +246,13 @@ class MitraModel(AbstractModel):
         X: pd.DataFrame,
         **kwargs,
     ) -> int:
-        return max(
+        # Multiply by 0.9 as currently this is overly safe
+        return int(0.9 * max(
             cls._estimate_memory_usage_static_cpu_icl(X=X, **kwargs),
             cls._estimate_memory_usage_static_cpu_ft_icl(X=X, **kwargs),
             cls._estimate_memory_usage_static_gpu_cpu(X=X, **kwargs),
             cls._estimate_memory_usage_static_gpu_gpu(X=X, **kwargs),
-        )
+        ))
 
     @classmethod
     def _estimate_memory_usage_static_cpu_icl(