autogluon.tabular 1.3.2b20250723__py3-none-any.whl → 1.4.0b20250725__py3-none-any.whl

This diff shows the content of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only and reflects the changes between those versions.

Potentially problematic release: this version of autogluon.tabular might be problematic.

Files changed (31)
  1. autogluon/tabular/configs/hyperparameter_configs.py +2 -265
  2. autogluon/tabular/configs/presets_configs.py +51 -23
  3. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2023.py +0 -1
  4. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2025.py +309 -0
  5. autogluon/tabular/models/automm/automm_model.py +2 -0
  6. autogluon/tabular/models/automm/ft_transformer.py +4 -1
  7. autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py +18 -6
  8. autogluon/tabular/models/mitra/_internal/data/dataset_finetune.py +8 -4
  9. autogluon/tabular/models/mitra/_internal/data/dataset_split.py +5 -1
  10. autogluon/tabular/models/mitra/_internal/models/tab2d.py +3 -0
  11. autogluon/tabular/models/mitra/mitra_model.py +85 -21
  12. autogluon/tabular/models/mitra/sklearn_interface.py +15 -13
  13. autogluon/tabular/models/realmlp/realmlp_model.py +13 -6
  14. autogluon/tabular/models/tabicl/tabicl_model.py +17 -8
  15. autogluon/tabular/models/tabm/rtdl_num_embeddings.py +3 -0
  16. autogluon/tabular/models/tabm/tabm_model.py +14 -6
  17. autogluon/tabular/models/tabm/tabm_reference.py +2 -0
  18. autogluon/tabular/models/tabpfnmix/tabpfnmix_model.py +4 -0
  19. autogluon/tabular/models/tabpfnv2/tabpfnv2_model.py +29 -12
  20. autogluon/tabular/predictor/predictor.py +45 -5
  21. autogluon/tabular/trainer/abstract_trainer.py +2 -0
  22. autogluon/tabular/version.py +1 -1
  23. {autogluon.tabular-1.3.2b20250723.dist-info → autogluon.tabular-1.4.0b20250725.dist-info}/METADATA +40 -18
  24. {autogluon.tabular-1.3.2b20250723.dist-info → autogluon.tabular-1.4.0b20250725.dist-info}/RECORD +31 -30
  25. /autogluon.tabular-1.3.2b20250723-py3.9-nspkg.pth → /autogluon.tabular-1.4.0b20250725-py3.9-nspkg.pth +0 -0
  26. {autogluon.tabular-1.3.2b20250723.dist-info → autogluon.tabular-1.4.0b20250725.dist-info}/LICENSE +0 -0
  27. {autogluon.tabular-1.3.2b20250723.dist-info → autogluon.tabular-1.4.0b20250725.dist-info}/NOTICE +0 -0
  28. {autogluon.tabular-1.3.2b20250723.dist-info → autogluon.tabular-1.4.0b20250725.dist-info}/WHEEL +0 -0
  29. {autogluon.tabular-1.3.2b20250723.dist-info → autogluon.tabular-1.4.0b20250725.dist-info}/namespace_packages.txt +0 -0
  30. {autogluon.tabular-1.3.2b20250723.dist-info → autogluon.tabular-1.4.0b20250725.dist-info}/top_level.txt +0 -0
  31. {autogluon.tabular-1.3.2b20250723.dist-info → autogluon.tabular-1.4.0b20250725.dist-info}/zip-safe +0 -0
autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2025.py (new file)
@@ -0,0 +1,309 @@
+# optimized for <=10000 samples and <=500 features, with a GPU present
+hyperparameter_portfolio_zeroshot_2025_small = {
+    "TABPFNV2": [
+        {
+            "ag_args": {'name_suffix': '_r143', 'priority': -1},
+            "average_before_softmax": False,
+            "classification_model_path": 'tabpfn-v2-classifier-od3j1g5m.ckpt',
+            "inference_config/FINGERPRINT_FEATURE": False,
+            "inference_config/OUTLIER_REMOVAL_STD": None,
+            "inference_config/POLYNOMIAL_FEATURES": 'no',
+            "inference_config/PREPROCESS_TRANSFORMS": [{'append_original': True, 'categorical_name': 'ordinal_very_common_categories_shuffled', 'global_transformer_name': None, 'name': 'safepower', 'subsample_features': -1}, {'append_original': True, 'categorical_name': 'ordinal_very_common_categories_shuffled', 'global_transformer_name': None, 'name': 'quantile_uni', 'subsample_features': -1}],
+            "inference_config/REGRESSION_Y_PREPROCESS_TRANSFORMS": [None, 'power'],
+            "inference_config/SUBSAMPLE_SAMPLES": 0.99,
+            "model_type": 'single',
+            "n_ensemble_repeats": 4,
+            "regression_model_path": 'tabpfn-v2-regressor-wyl4o83o.ckpt',
+            "softmax_temperature": 0.75,
+        },
+        {
+            "ag_args": {'name_suffix': '_r94', 'priority': -3},
+            "average_before_softmax": True,
+            "classification_model_path": 'tabpfn-v2-classifier-vutqq28w.ckpt',
+            "inference_config/FINGERPRINT_FEATURE": True,
+            "inference_config/OUTLIER_REMOVAL_STD": None,
+            "inference_config/POLYNOMIAL_FEATURES": 'no',
+            "inference_config/PREPROCESS_TRANSFORMS": [{'append_original': True, 'categorical_name': 'ordinal_very_common_categories_shuffled', 'global_transformer_name': None, 'name': 'quantile_uni', 'subsample_features': 0.99}],
+            "inference_config/REGRESSION_Y_PREPROCESS_TRANSFORMS": [None],
+            "inference_config/SUBSAMPLE_SAMPLES": None,
+            "model_type": 'single',
+            "n_ensemble_repeats": 4,
+            "regression_model_path": 'tabpfn-v2-regressor-5wof9ojf.ckpt',
+            "softmax_temperature": 0.9,
+        },
+        {
+            "ag_args": {'name_suffix': '_r181', 'priority': -4},
+            "average_before_softmax": False,
+            "classification_model_path": 'tabpfn-v2-classifier-llderlii.ckpt',
+            "inference_config/FINGERPRINT_FEATURE": False,
+            "inference_config/OUTLIER_REMOVAL_STD": 9.0,
+            "inference_config/POLYNOMIAL_FEATURES": 50,
+            "inference_config/PREPROCESS_TRANSFORMS": [{'append_original': True, 'categorical_name': 'onehot', 'global_transformer_name': 'svd', 'name': 'quantile_uni_coarse', 'subsample_features': 0.99}],
+            "inference_config/REGRESSION_Y_PREPROCESS_TRANSFORMS": ['power'],
+            "inference_config/SUBSAMPLE_SAMPLES": None,
+            "model_type": 'single',
+            "n_ensemble_repeats": 4,
+            "regression_model_path": 'tabpfn-v2-regressor.ckpt',
+            "softmax_temperature": 0.95,
+        },
+    ],
+    "GBM": [
+        {
+            "ag_args": {'name_suffix': '_r33', 'priority': -2},
+            "bagging_fraction": 0.9625293420216,
+            "bagging_freq": 1,
+            "cat_l2": 0.1236875455555,
+            "cat_smooth": 68.8584757332856,
+            "extra_trees": False,
+            "feature_fraction": 0.6189215809382,
+            "lambda_l1": 0.1641757352921,
+            "lambda_l2": 0.6937755557881,
+            "learning_rate": 0.0154031028561,
+            "max_cat_to_onehot": 17,
+            "min_data_in_leaf": 1,
+            "min_data_per_group": 30,
+            "num_leaves": 68,
+        },
+        {
+            "ag_args": {'name_suffix': '_r21', 'priority': -16},
+            "bagging_fraction": 0.7218730663234,
+            "bagging_freq": 1,
+            "cat_l2": 0.0296205152578,
+            "cat_smooth": 0.0010255271303,
+            "extra_trees": False,
+            "feature_fraction": 0.4557131604374,
+            "lambda_l1": 0.5219704038237,
+            "lambda_l2": 0.1070959487853,
+            "learning_rate": 0.0055891584996,
+            "max_cat_to_onehot": 71,
+            "min_data_in_leaf": 50,
+            "min_data_per_group": 10,
+            "num_leaves": 30,
+        },
+        {
+            "ag_args": {'name_suffix': '_r11', 'priority': -19},
+            "bagging_fraction": 0.775784726514,
+            "bagging_freq": 1,
+            "cat_l2": 0.3888471449178,
+            "cat_smooth": 0.0057144748021,
+            "extra_trees": True,
+            "feature_fraction": 0.7732354787904,
+            "lambda_l1": 0.2211002452568,
+            "lambda_l2": 1.1318405980187,
+            "learning_rate": 0.0090151778542,
+            "max_cat_to_onehot": 15,
+            "min_data_in_leaf": 4,
+            "min_data_per_group": 15,
+            "num_leaves": 2,
+        },
+    ],
+    "CAT": [
+        {
+            "ag_args": {'priority': -5},
+        },
+        {
+            "ag_args": {'name_suffix': '_r51', 'priority': -10},
+            "boosting_type": 'Plain',
+            "bootstrap_type": 'Bernoulli',
+            "colsample_bylevel": 0.8771035272558,
+            "depth": 7,
+            "grow_policy": 'SymmetricTree',
+            "l2_leaf_reg": 2.0107286863021,
+            "leaf_estimation_iterations": 2,
+            "learning_rate": 0.0058424016622,
+            "max_bin": 254,
+            "max_ctr_complexity": 4,
+            "model_size_reg": 0.1307400355809,
+            "one_hot_max_size": 23,
+            "subsample": 0.809527841437,
+        },
+        {
+            "ag_args": {'name_suffix': '_r10', 'priority': -12},
+            "boosting_type": 'Plain',
+            "bootstrap_type": 'Bernoulli',
+            "colsample_bylevel": 0.8994502668431,
+            "depth": 6,
+            "grow_policy": 'Depthwise',
+            "l2_leaf_reg": 1.8187025215896,
+            "leaf_estimation_iterations": 7,
+            "learning_rate": 0.005177304142,
+            "max_bin": 254,
+            "max_ctr_complexity": 4,
+            "model_size_reg": 0.5247386875068,
+            "one_hot_max_size": 53,
+            "subsample": 0.8705228845742,
+        },
+        {
+            "ag_args": {'name_suffix': '_r24', 'priority': -15},
+            "boosting_type": 'Plain',
+            "bootstrap_type": 'Bernoulli',
+            "colsample_bylevel": 0.8597809376276,
+            "depth": 8,
+            "grow_policy": 'Depthwise',
+            "l2_leaf_reg": 0.3628261923976,
+            "leaf_estimation_iterations": 5,
+            "learning_rate": 0.016851077771,
+            "max_bin": 254,
+            "max_ctr_complexity": 4,
+            "model_size_reg": 0.1253820547902,
+            "one_hot_max_size": 20,
+            "subsample": 0.8120271122061,
+        },
+        {
+            "ag_args": {'name_suffix': '_r91', 'priority': -17},
+            "boosting_type": 'Plain',
+            "bootstrap_type": 'Bernoulli',
+            "colsample_bylevel": 0.8959275863514,
+            "depth": 4,
+            "grow_policy": 'SymmetricTree',
+            "l2_leaf_reg": 0.0026915894253,
+            "leaf_estimation_iterations": 12,
+            "learning_rate": 0.0475233791203,
+            "max_bin": 254,
+            "max_ctr_complexity": 5,
+            "model_size_reg": 0.1633175256924,
+            "one_hot_max_size": 11,
+            "subsample": 0.798554178926,
+        },
+    ],
+    "TABM": [
+        {
+            "ag_args": {'name_suffix': '_r184', 'priority': -6},
+            "amp": False,
+            "arch_type": 'tabm-mini',
+            "batch_size": 'auto',
+            "d_block": 864,
+            "d_embedding": 24,
+            "dropout": 0.0,
+            "gradient_clipping_norm": 1.0,
+            "lr": 0.0019256819924656217,
+            "n_blocks": 3,
+            "num_emb_n_bins": 3,
+            "num_emb_type": 'pwl',
+            "patience": 16,
+            "share_training_batches": False,
+            "tabm_k": 32,
+            "weight_decay": 0.0,
+        },
+        {
+            "ag_args": {'name_suffix': '_r69', 'priority': -7},
+            "amp": False,
+            "arch_type": 'tabm-mini',
+            "batch_size": 'auto',
+            "d_block": 848,
+            "d_embedding": 28,
+            "dropout": 0.40215621636031007,
+            "gradient_clipping_norm": 1.0,
+            "lr": 0.0010413640454559532,
+            "n_blocks": 3,
+            "num_emb_n_bins": 18,
+            "num_emb_type": 'pwl',
+            "patience": 16,
+            "share_training_batches": False,
+            "tabm_k": 32,
+            "weight_decay": 0.0,
+        },
+        {
+            "ag_args": {'name_suffix': '_r52', 'priority': -11},
+            "amp": False,
+            "arch_type": 'tabm-mini',
+            "batch_size": 'auto',
+            "d_block": 1024,
+            "d_embedding": 32,
+            "dropout": 0.0,
+            "gradient_clipping_norm": 1.0,
+            "lr": 0.0006297851297842611,
+            "n_blocks": 4,
+            "num_emb_n_bins": 22,
+            "num_emb_type": 'pwl',
+            "patience": 16,
+            "share_training_batches": False,
+            "tabm_k": 32,
+            "weight_decay": 0.06900108498839816,
+        },
+        {
+            "ag_args": {'priority': -13},
+        },
+        {
+            "ag_args": {'name_suffix': '_r191', 'priority': -14},
+            "amp": False,
+            "arch_type": 'tabm-mini',
+            "batch_size": 'auto',
+            "d_block": 864,
+            "d_embedding": 8,
+            "dropout": 0.45321529282058803,
+            "gradient_clipping_norm": 1.0,
+            "lr": 0.0003781238075322413,
+            "n_blocks": 4,
+            "num_emb_n_bins": 27,
+            "num_emb_type": 'pwl',
+            "patience": 16,
+            "share_training_batches": False,
+            "tabm_k": 32,
+            "weight_decay": 0.01766851962579851,
+        },
+        {
+            "ag_args": {'name_suffix': '_r49', 'priority': -20},
+            "amp": False,
+            "arch_type": 'tabm-mini',
+            "batch_size": 'auto',
+            "d_block": 640,
+            "d_embedding": 28,
+            "dropout": 0.15296207419190627,
+            "gradient_clipping_norm": 1.0,
+            "lr": 0.002277678490593717,
+            "n_blocks": 3,
+            "num_emb_n_bins": 48,
+            "num_emb_type": 'pwl',
+            "patience": 16,
+            "share_training_batches": False,
+            "tabm_k": 32,
+            "weight_decay": 0.0578159148243893,
+        },
+    ],
+    "TABICL": [
+        {
+            "ag_args": {'priority': -8},
+        },
+    ],
+    "XGB": [
+        {
+            "ag_args": {'name_suffix': '_r171', 'priority': -9},
+            "colsample_bylevel": 0.9213705632288,
+            "colsample_bynode": 0.6443385965381,
+            "enable_categorical": True,
+            "grow_policy": 'lossguide',
+            "learning_rate": 0.0068171645251,
+            "max_cat_to_onehot": 8,
+            "max_depth": 6,
+            "max_leaves": 10,
+            "min_child_weight": 0.0507304250576,
+            "reg_alpha": 4.2446346389037,
+            "reg_lambda": 1.4800570021253,
+            "subsample": 0.9656290596647,
+        },
+        {
+            "ag_args": {'name_suffix': '_r40', 'priority': -18},
+            "colsample_bylevel": 0.6377491713202,
+            "colsample_bynode": 0.9237625621103,
+            "enable_categorical": True,
+            "grow_policy": 'lossguide',
+            "learning_rate": 0.0112462621131,
+            "max_cat_to_onehot": 33,
+            "max_depth": 10,
+            "max_leaves": 35,
+            "min_child_weight": 0.1403464856034,
+            "reg_alpha": 3.4960653958503,
+            "reg_lambda": 1.3062320805235,
+            "subsample": 0.6948898835178,
+        },
+    ],
+    "MITRA": [
+        {
+            "n_estimators": 1,
+            "fine_tune": True,
+            "fine_tune_steps": 50,
+            "ag_args": {'priority': -21},
+        },
+    ],
+}
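The portfolio above is an ordinary AutoGluon hyperparameters dict: model keys such as "GBM", "CAT", and "TABPFNV2" map to lists of configs, and each config's ag_args 'priority' controls training order (higher first, so priority -1 trains before -21). It can therefore be passed straight to TabularPredictor.fit. A minimal sketch, assuming the dict is importable from the new module path listed above; the sample CSV is the dataset used in the AutoGluon quickstart docs:

import pandas as pd
from autogluon.tabular import TabularDataset, TabularPredictor
from autogluon.tabular.configs.zeroshot.zeroshot_portfolio_2025 import (
    hyperparameter_portfolio_zeroshot_2025_small,
)

# Quickstart sample data; any DataFrame with a label column works.
train_data = TabularDataset("https://autogluon.s3.amazonaws.com/datasets/Inc/train.csv")

# Train the zeroshot-2025 portfolio on a small dataset (<=10000 samples,
# <=500 features, GPU present, per the comment at the top of the file).
predictor = TabularPredictor(label="class").fit(
    train_data,
    hyperparameters=hyperparameter_portfolio_zeroshot_2025_small,
)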
autogluon/tabular/models/automm/automm_model.py
@@ -65,6 +65,8 @@ class MultiModalPredictorModel(AbstractModel):
             Names of the features.
         feature_metadata
             The feature metadata.
+
+        .. versionadded:: 0.3.0
         """
         super().__init__(**kwargs)
         self._label_column_name = None
autogluon/tabular/models/automm/ft_transformer.py
@@ -17,7 +17,8 @@ class FTTransformerModel(MultiModalPredictorModel):
     ag_name = "FTTransformer"
 
     def __init__(self, **kwargs):
-        """Wrapper of autogluon.multimodal.MultiModalPredictor.
+        """
+        FT-Transformer model.
 
         The features can be a mix of
         - categorical column
@@ -48,6 +49,8 @@ class FTTransformerModel(MultiModalPredictorModel):
             Names of the features.
         feature_metadata
             The feature metadata.
+
+        .. versionadded:: 0.6.0
         """
         super().__init__(**kwargs)
 
autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py
@@ -24,10 +24,16 @@ class TrainerFinetune(BaseEstimator):
         cfg: ConfigRun,
         model: torch.nn.Module,
         n_classes: int,
-        device: str
-    ) -> None:
+        device: str,
+        rng: np.random.RandomState = None,
+        verbose: bool = True,
+    ):
 
         self.cfg = cfg
+        if rng is None:
+            rng = np.random.RandomState(self.cfg.seed)
+        self.rng = rng
+        self.verbose = verbose
         self.device = device
         self.model = model.to(self.device, non_blocking=True)
         self.n_classes = n_classes
@@ -81,13 +87,15 @@ class TrainerFinetune(BaseEstimator):
             y = y_train_transformed,
             task = self.cfg.task,
             max_samples_support = self.cfg.hyperparams['max_samples_support'],
-            max_samples_query = self.cfg.hyperparams['max_samples_query']
+            max_samples_query = self.cfg.hyperparams['max_samples_query'],
+            rng=self.rng,
         )
 
         self.checkpoint.reset(self.model)
 
         metrics_valid = self.evaluate(x_train, y_train, x_val, y_val)
-        self.log_start_metrics(metrics_valid)
+        if self.verbose:
+            self.log_start_metrics(metrics_valid)
         self.checkpoint(self.model, metrics_valid.loss)
 
         start_time = time.time()
@@ -154,13 +162,15 @@ class TrainerFinetune(BaseEstimator):
             metrics_train = prediction_metrics_tracker.get_metrics()
             metrics_valid = self.evaluate(x_train, y_train, x_val, y_val)
 
-            self.log_metrics(epoch, metrics_train, metrics_valid)
+            if self.verbose:
+                self.log_metrics(epoch, metrics_train, metrics_valid)
 
             self.checkpoint(self.model, metrics_valid.loss)
 
             self.early_stopping(metrics_valid.metrics[self.metric])
             if self.early_stopping.we_should_stop():
-                logger.info("Early stopping")
+                if self.verbose:
+                    logger.info("Early stopping")
                 break
 
             if self.cfg.hyperparams["budget"] is not None and self.cfg.hyperparams["budget"] > 0 and time.time() - start_time > self.cfg.hyperparams["budget"]:
@@ -192,6 +202,7 @@ class TrainerFinetune(BaseEstimator):
             y_query = y_query,
             max_samples_support = self.cfg.hyperparams['max_samples_support'],
             max_samples_query = self.cfg.hyperparams['max_samples_query'],
+            rng=self.rng,
         )
 
         loader = self.make_loader(dataset, training=False)
@@ -246,6 +257,7 @@ class TrainerFinetune(BaseEstimator):
             y_query = None,
             max_samples_support = self.cfg.hyperparams['max_samples_support'],
             max_samples_query = self.cfg.hyperparams['max_samples_query'],
+            rng=self.rng,
        )
 
         loader = self.make_loader(dataset, training=False)
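The constructor change threads a single np.random.RandomState through the trainer and, via the new rng= keyword, into the dataset generators, instead of drawing from the global np.random state; the verbose flag gates the start, per-epoch, and early-stopping log lines. A minimal sketch of the seeding idiom the diff uses (resolve_rng is a hypothetical stand-in for the constructor logic):

import numpy as np

def resolve_rng(rng=None, seed=0):
    # Hypothetical stand-in for the constructor logic in the diff:
    # an explicit RandomState wins; otherwise derive one from the config seed.
    if rng is None:
        rng = np.random.RandomState(seed)
    return rng

# One shared instance passed to the trainer, the fine-tune dataset, and the
# split helper gives a single reproducible random stream.
shared_rng = resolve_rng(seed=42)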
autogluon/tabular/models/mitra/_internal/data/dataset_finetune.py
@@ -26,13 +26,15 @@ class DatasetFinetune(torch.utils.data.Dataset):
         x_query: np.ndarray,
         y_query: Optional[np.ndarray],
         max_samples_support: int,
-        max_samples_query: int
+        max_samples_query: int,
+        rng: np.random.RandomState,
     ):
         """
         :param: max_features: number of features the tab pfn model has been trained on
         """
 
         self.cfg = cfg
+        self.rng = rng
 
         self.x_support = x_support
         self.y_support = y_support
@@ -59,7 +61,7 @@ class DatasetFinetune(torch.utils.data.Dataset):
 
     def __getitem__(self, idx):
 
-        support_indices = np.random.choice(
+        support_indices = self.rng.choice(
             self.n_samples_support,
             size=self.support_size,
             replace=False
@@ -101,7 +103,8 @@ def DatasetFinetuneGenerator(
     y: np.ndarray,
     task: Task,
     max_samples_support: int,
-    max_samples_query: int
+    max_samples_query: int,
+    rng: np.random.RandomState,
 ):
     """
     The dataset fine-tune generator is a generator that yields a dataset for fine-tuning.
@@ -112,7 +115,7 @@ def DatasetFinetuneGenerator(
 
     while True:
 
-        x_support, x_query, y_support, y_query = make_dataset_split(x=x, y=y, task=task, seed=cfg.seed)
+        x_support, x_query, y_support, y_query = make_dataset_split(x=x, y=y, task=task, seed=rng)
         n_samples_support = x_support.shape[0]
         n_samples_query = x_query.shape[0]
 
@@ -127,6 +130,7 @@ def DatasetFinetuneGenerator(
             y_query=y_query[:query_size],
             max_samples_support=max_samples_support,
             max_samples_query=max_samples_query,
+            rng=rng,
        )
 
        yield dataset_finetune
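Note the behavioral fix in DatasetFinetuneGenerator: previously each pass through the while-True loop called make_dataset_split with the same fixed cfg.seed, reproducing an identical support/query split on every iteration. A shared RandomState advances its internal state on each call, so every yielded dataset gets a fresh split. A two-line illustration:

import numpy as np

rng = np.random.RandomState(0)
print(rng.permutation(6))  # first split order
print(rng.permutation(6))  # state advanced: almost surely a different order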
autogluon/tabular/models/mitra/_internal/data/dataset_split.py
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import numpy as np
 from sklearn.model_selection import StratifiedKFold, train_test_split
 
@@ -19,9 +21,11 @@ def make_dataset_split(x: np.ndarray, y: np.ndarray, task: Task, seed: int) -> t
 
 
 def make_stratified_dataset_split(x, y, n_splits=5, seed=0):
+    if isinstance(seed, int):
+        seed = np.random.RandomState(seed)
 
     # Stratify doesn't shuffle the data, so we shuffle it first
-    permutation = np.random.permutation(len(y))
+    permutation = seed.permutation(len(y))
     x, y = x[permutation], y[permutation]
 
     min_samples_per_class = np.min(np.bincount(y))
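With this change, make_stratified_dataset_split accepts either an int or a RandomState through its seed parameter. A small sketch of the promotion idiom (to_random_state is a hypothetical helper mirroring the two added lines):

import numpy as np

def to_random_state(seed):
    # Mirror of the diff's idiom: promote ints, pass RandomState through.
    if isinstance(seed, int):
        return np.random.RandomState(seed)
    return seed

assert isinstance(to_random_state(0), np.random.RandomState)
rng = np.random.RandomState(7)
assert to_random_state(rng) is rng  # existing instances are used as-is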
autogluon/tabular/models/mitra/_internal/models/tab2d.py
@@ -1,4 +1,5 @@
 import json
+import logging
 import os
 from typing import Optional, Union
 
@@ -29,6 +30,8 @@ from ..._internal.models.embedding import (
     Tab2DQuantileEmbeddingX,
 )
 
+logger = logging.getLogger(__name__)
+
 
 class Tab2D(BaseModel):
 
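The new module-level logger follows the standard library pattern: library code logs through logging.getLogger(__name__) and leaves handler and level configuration to the application. A minimal sketch (fit_step is a hypothetical caller):

import logging

logger = logging.getLogger(__name__)  # same pattern the diff adds to tab2d.py

def fit_step():
    # Emitted only if the application configured logging, e.g. via
    # logging.basicConfig(level=logging.INFO).
    logger.info("starting fit step")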