autogluon.tabular 1.5.1b20260105-py3-none-any.whl → 1.5.1b20260116-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of autogluon.tabular might be problematic.

Files changed (135)
  1. autogluon/tabular/__init__.py +1 -0
  2. autogluon/tabular/configs/config_helper.py +18 -6
  3. autogluon/tabular/configs/feature_generator_presets.py +3 -1
  4. autogluon/tabular/configs/hyperparameter_configs.py +42 -9
  5. autogluon/tabular/configs/presets_configs.py +38 -14
  6. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2023.py +84 -14
  7. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2025.py +48 -48
  8. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_cpu_2025_12_18.py +774 -1
  9. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_gpu_2025_12_18.py +421 -1
  10. autogluon/tabular/experimental/_scikit_mixin.py +6 -2
  11. autogluon/tabular/experimental/_tabular_classifier.py +3 -1
  12. autogluon/tabular/experimental/_tabular_regressor.py +3 -1
  13. autogluon/tabular/experimental/plot_leaderboard.py +73 -19
  14. autogluon/tabular/learner/abstract_learner.py +160 -42
  15. autogluon/tabular/learner/default_learner.py +78 -22
  16. autogluon/tabular/models/__init__.py +2 -2
  17. autogluon/tabular/models/_utils/rapids_utils.py +3 -1
  18. autogluon/tabular/models/abstract/abstract_torch_model.py +2 -0
  19. autogluon/tabular/models/automm/automm_model.py +12 -3
  20. autogluon/tabular/models/automm/ft_transformer.py +5 -1
  21. autogluon/tabular/models/catboost/callbacks.py +2 -2
  22. autogluon/tabular/models/catboost/catboost_model.py +93 -29
  23. autogluon/tabular/models/catboost/catboost_softclass_utils.py +4 -1
  24. autogluon/tabular/models/catboost/catboost_utils.py +3 -1
  25. autogluon/tabular/models/ebm/ebm_model.py +8 -13
  26. autogluon/tabular/models/ebm/hyperparameters/parameters.py +1 -0
  27. autogluon/tabular/models/ebm/hyperparameters/searchspaces.py +1 -0
  28. autogluon/tabular/models/fastainn/callbacks.py +20 -3
  29. autogluon/tabular/models/fastainn/hyperparameters/searchspaces.py +11 -1
  30. autogluon/tabular/models/fastainn/quantile_helpers.py +10 -2
  31. autogluon/tabular/models/fastainn/tabular_nn_fastai.py +65 -18
  32. autogluon/tabular/models/fasttext/fasttext_model.py +3 -1
  33. autogluon/tabular/models/image_prediction/image_predictor.py +7 -2
  34. autogluon/tabular/models/knn/knn_model.py +41 -8
  35. autogluon/tabular/models/lgb/callbacks.py +32 -9
  36. autogluon/tabular/models/lgb/hyperparameters/searchspaces.py +3 -1
  37. autogluon/tabular/models/lgb/lgb_model.py +150 -34
  38. autogluon/tabular/models/lgb/lgb_utils.py +12 -4
  39. autogluon/tabular/models/lr/hyperparameters/searchspaces.py +5 -1
  40. autogluon/tabular/models/lr/lr_model.py +40 -10
  41. autogluon/tabular/models/lr/lr_rapids_model.py +22 -13
  42. autogluon/tabular/models/mitra/_internal/__init__.py +1 -1
  43. autogluon/tabular/models/mitra/_internal/config/__init__.py +1 -1
  44. autogluon/tabular/models/mitra/_internal/config/config_pretrain.py +36 -40
  45. autogluon/tabular/models/mitra/_internal/config/config_run.py +2 -14
  46. autogluon/tabular/models/mitra/_internal/config/enums.py +27 -26
  47. autogluon/tabular/models/mitra/_internal/core/__init__.py +1 -1
  48. autogluon/tabular/models/mitra/_internal/core/callbacks.py +14 -21
  49. autogluon/tabular/models/mitra/_internal/core/get_loss.py +10 -12
  50. autogluon/tabular/models/mitra/_internal/core/get_optimizer.py +17 -32
  51. autogluon/tabular/models/mitra/_internal/core/get_scheduler.py +12 -27
  52. autogluon/tabular/models/mitra/_internal/core/prediction_metrics.py +16 -21
  53. autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py +130 -111
  54. autogluon/tabular/models/mitra/_internal/data/__init__.py +1 -1
  55. autogluon/tabular/models/mitra/_internal/data/collator.py +30 -26
  56. autogluon/tabular/models/mitra/_internal/data/dataset_finetune.py +18 -26
  57. autogluon/tabular/models/mitra/_internal/data/dataset_split.py +10 -7
  58. autogluon/tabular/models/mitra/_internal/data/preprocessor.py +70 -100
  59. autogluon/tabular/models/mitra/_internal/models/__init__.py +1 -1
  60. autogluon/tabular/models/mitra/_internal/models/base.py +7 -10
  61. autogluon/tabular/models/mitra/_internal/models/embedding.py +46 -56
  62. autogluon/tabular/models/mitra/_internal/models/tab2d.py +140 -120
  63. autogluon/tabular/models/mitra/_internal/utils/__init__.py +1 -1
  64. autogluon/tabular/models/mitra/_internal/utils/set_seed.py +3 -1
  65. autogluon/tabular/models/mitra/mitra_model.py +16 -11
  66. autogluon/tabular/models/mitra/sklearn_interface.py +178 -162
  67. autogluon/tabular/models/realmlp/realmlp_model.py +28 -15
  68. autogluon/tabular/models/rf/compilers/onnx.py +1 -1
  69. autogluon/tabular/models/rf/rf_model.py +45 -12
  70. autogluon/tabular/models/rf/rf_quantile.py +4 -2
  71. autogluon/tabular/models/tabdpt/tabdpt_model.py +8 -17
  72. autogluon/tabular/models/tabicl/tabicl_model.py +8 -1
  73. autogluon/tabular/models/tabm/_tabm_internal.py +6 -4
  74. autogluon/tabular/models/tabm/rtdl_num_embeddings.py +80 -127
  75. autogluon/tabular/models/tabm/tabm_model.py +8 -4
  76. autogluon/tabular/models/tabm/tabm_reference.py +53 -85
  77. autogluon/tabular/models/tabpfnmix/_internal/core/callbacks.py +7 -16
  78. autogluon/tabular/models/tabpfnmix/_internal/core/collator.py +16 -24
  79. autogluon/tabular/models/tabpfnmix/_internal/core/dataset_split.py +5 -7
  80. autogluon/tabular/models/tabpfnmix/_internal/core/enums.py +0 -2
  81. autogluon/tabular/models/tabpfnmix/_internal/core/get_loss.py +0 -1
  82. autogluon/tabular/models/tabpfnmix/_internal/core/get_optimizer.py +7 -18
  83. autogluon/tabular/models/tabpfnmix/_internal/core/get_scheduler.py +3 -14
  84. autogluon/tabular/models/tabpfnmix/_internal/core/trainer_finetune.py +79 -64
  85. autogluon/tabular/models/tabpfnmix/_internal/core/y_transformer.py +3 -5
  86. autogluon/tabular/models/tabpfnmix/_internal/data/dataset_finetune.py +17 -30
  87. autogluon/tabular/models/tabpfnmix/_internal/data/preprocessor.py +15 -35
  88. autogluon/tabular/models/tabpfnmix/_internal/models/foundation/embedding.py +21 -38
  89. autogluon/tabular/models/tabpfnmix/_internal/models/foundation/foundation_transformer.py +33 -51
  90. autogluon/tabular/models/tabpfnmix/_internal/results/prediction_metrics.py +4 -4
  91. autogluon/tabular/models/tabpfnmix/_internal/tabpfnmix_classifier.py +32 -12
  92. autogluon/tabular/models/tabpfnmix/_internal/tabpfnmix_regressor.py +32 -13
  93. autogluon/tabular/models/tabpfnmix/tabpfnmix_model.py +55 -19
  94. autogluon/tabular/models/tabpfnv2/tabpfnv2_5_model.py +21 -48
  95. autogluon/tabular/models/tabprep/prep_mixin.py +34 -26
  96. autogluon/tabular/models/tabular_nn/compilers/onnx.py +36 -8
  97. autogluon/tabular/models/tabular_nn/torch/tabular_nn_torch.py +130 -36
  98. autogluon/tabular/models/tabular_nn/torch/tabular_torch_dataset.py +8 -4
  99. autogluon/tabular/models/tabular_nn/torch/torch_network_modules.py +26 -5
  100. autogluon/tabular/models/tabular_nn/utils/categorical_encoders.py +41 -24
  101. autogluon/tabular/models/tabular_nn/utils/data_preprocessor.py +33 -8
  102. autogluon/tabular/models/tabular_nn/utils/nn_architecture_utils.py +21 -6
  103. autogluon/tabular/models/xgboost/callbacks.py +9 -3
  104. autogluon/tabular/models/xgboost/xgboost_model.py +59 -11
  105. autogluon/tabular/models/xt/xt_model.py +1 -0
  106. autogluon/tabular/predictor/interpretable_predictor.py +3 -1
  107. autogluon/tabular/predictor/predictor.py +409 -128
  108. autogluon/tabular/registry/__init__.py +1 -1
  109. autogluon/tabular/registry/_ag_model_registry.py +4 -5
  110. autogluon/tabular/registry/_model_registry.py +1 -0
  111. autogluon/tabular/testing/fit_helper.py +55 -15
  112. autogluon/tabular/testing/generate_datasets.py +1 -1
  113. autogluon/tabular/testing/model_fit_helper.py +10 -4
  114. autogluon/tabular/trainer/abstract_trainer.py +644 -230
  115. autogluon/tabular/trainer/auto_trainer.py +19 -8
  116. autogluon/tabular/trainer/model_presets/presets.py +33 -9
  117. autogluon/tabular/trainer/model_presets/presets_distill.py +16 -2
  118. autogluon/tabular/version.py +1 -1
  119. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/METADATA +26 -26
  120. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/RECORD +127 -135
  121. autogluon/tabular/models/tabpfnv2/rfpfn/__init__.py +0 -20
  122. autogluon/tabular/models/tabpfnv2/rfpfn/configs.py +0 -40
  123. autogluon/tabular/models/tabpfnv2/rfpfn/scoring_utils.py +0 -201
  124. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_decision_tree_tabpfn.py +0 -1464
  125. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_random_forest_tabpfn.py +0 -747
  126. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_compat.py +0 -863
  127. autogluon/tabular/models/tabpfnv2/rfpfn/utils.py +0 -106
  128. autogluon/tabular/models/tabpfnv2/tabpfnv2_model.py +0 -466
  129. /autogluon.tabular-1.5.1b20260105-py3.11-nspkg.pth → /autogluon.tabular-1.5.1b20260116-py3.11-nspkg.pth +0 -0
  130. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/WHEEL +0 -0
  131. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/licenses/LICENSE +0 -0
  132. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/licenses/NOTICE +0 -0
  133. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/namespace_packages.txt +0 -0
  134. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/top_level.txt +0 -0
  135. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260116.dist-info}/zip-safe +0 -0
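
The hunks below cover only a handful of the 135 changed files (the tabpfnmix trainer, y-transformer, fine-tune dataset, and preprocessor modules). The full comparison can be reproduced locally from the two wheels; the following is a minimal sketch using only the Python standard library. The wheel filenames are assumptions taken from the page header and may need adjusting to whatever the registry actually serves.

# Minimal sketch: re-derive per-file diffs between the two wheels.
# Assumes both wheels are already downloaded into the working directory;
# the filenames below are inferred from this page and are not guaranteed.
import difflib
import zipfile

OLD_WHEEL = "autogluon.tabular-1.5.1b20260105-py3-none-any.whl"
NEW_WHEEL = "autogluon.tabular-1.5.1b20260116-py3-none-any.whl"


def read_py_members(wheel_path: str) -> dict[str, list[str]]:
    """Return {archive member name: lines} for every .py file inside a wheel (a zip archive)."""
    with zipfile.ZipFile(wheel_path) as zf:
        return {
            name: zf.read(name).decode("utf-8", errors="replace").splitlines(keepends=True)
            for name in zf.namelist()
            if name.endswith(".py")
        }


old_files = read_py_members(OLD_WHEEL)
new_files = read_py_members(NEW_WHEEL)

# Print a unified diff for every module present in both wheels whose content changed.
for name in sorted(old_files.keys() & new_files.keys()):
    if old_files[name] != new_files[name]:
        print("".join(difflib.unified_diff(old_files[name], new_files[name], fromfile=name, tofile=name)))
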
autogluon/tabular/models/tabpfnmix/_internal/core/trainer_finetune.py
@@ -6,12 +6,15 @@ import time
 import einops
 import numpy as np
 import torch
-from sklearn.base import BaseEstimator
-
 from numpy.random import Generator
+from sklearn.base import BaseEstimator
 
 from autogluon.core.metrics import Scorer
 
+from ..config.config_run import ConfigRun
+from ..data.dataset_finetune import DatasetFinetune, DatasetFinetuneGenerator
+from ..data.preprocessor import Preprocessor
+from ..results.prediction_metrics import PredictionMetrics
 from .callbacks import Checkpoint, EarlyStopping, TrackOutput
 from .collator import CollatorWithPadding
 from .enums import Task
@@ -19,16 +22,11 @@ from .get_loss import get_loss
 from .get_optimizer import get_optimizer
 from .get_scheduler import get_scheduler
 from .y_transformer import create_y_transformer
-from ..config.config_run import ConfigRun
-from ..data.dataset_finetune import DatasetFinetune, DatasetFinetuneGenerator
-from ..data.preprocessor import Preprocessor
-from ..results.prediction_metrics import PredictionMetrics
 
 logger = logging.getLogger(__name__)
 
 
 class TrainerFinetune(BaseEstimator):
-
     def __init__(
         self,
         cfg: ConfigRun,
@@ -42,18 +40,18 @@ class TrainerFinetune(BaseEstimator):
         self.model = model
         self.model.to(self.cfg.device)
         self.n_classes = n_classes
-
+
         self.loss = get_loss(self.cfg.task)
         self.optimizer = get_optimizer(self.cfg.hyperparams, self.model)
         self.scheduler = get_scheduler(self.cfg.hyperparams, self.optimizer)
         self.use_best_epoch = use_best_epoch
         self.compute_train_metrics = compute_train_metrics
 
-        self.early_stopping = EarlyStopping(patience=self.cfg.hyperparams['early_stopping_patience'])
-        self.preprocessor = Preprocessor(
-            use_quantile_transformer=self.cfg.hyperparams['use_quantile_transformer'],
-            use_feature_count_scaling=self.cfg.hyperparams['use_feature_count_scaling'],
-            max_features=self.cfg.hyperparams['n_features'],
+        self.early_stopping = EarlyStopping(patience=self.cfg.hyperparams["early_stopping_patience"])
+        self.preprocessor = Preprocessor(
+            use_quantile_transformer=self.cfg.hyperparams["use_quantile_transformer"],
+            use_feature_count_scaling=self.cfg.hyperparams["use_feature_count_scaling"],
+            max_features=self.cfg.hyperparams["n_features"],
             task=self.cfg.task,
         )
 
@@ -70,7 +68,14 @@ class TrainerFinetune(BaseEstimator):
         self.optimizer = get_optimizer(self.cfg.hyperparams, self.model)
         self.scheduler = get_scheduler(self.cfg.hyperparams, self.optimizer)
 
-    def train(self, x_train: np.ndarray, y_train: np.ndarray, x_val: np.ndarray = None, y_val: np.ndarray = None, time_limit: float = None):
+    def train(
+        self,
+        x_train: np.ndarray,
+        y_train: np.ndarray,
+        x_val: np.ndarray = None,
+        y_val: np.ndarray = None,
+        time_limit: float = None,
+    ):
         time_start = time.time()
         if self.optimizer is None:
             self.reset_optimizer()
@@ -86,27 +91,27 @@ class TrainerFinetune(BaseEstimator):
         if use_val:
             x_val = self.preprocessor.transform(x_val)
         self.y_transformer = create_y_transformer(y_train, self.cfg.task)
-
+
         dataset_train_generator = DatasetFinetuneGenerator(
             self.cfg,
-            x = x_train,
-            y = self.y_transformer.transform(y_train),
-            task = self.cfg.task,
-            max_samples_support = self.cfg.hyperparams['max_samples_support'],
-            max_samples_query = self.cfg.hyperparams['max_samples_query'],
-            split = 0.8,
+            x=x_train,
+            y=self.y_transformer.transform(y_train),
+            task=self.cfg.task,
+            max_samples_support=self.cfg.hyperparams["max_samples_support"],
+            max_samples_query=self.cfg.hyperparams["max_samples_query"],
+            split=0.8,
             random_state=rng,
         )
 
         if use_val:
             dataset_valid = DatasetFinetune(
                 self.cfg,
-                x_support = x_train,
-                y_support = self.y_transformer.transform(y_train),
-                x_query = x_val,
-                y_query = y_val,
-                max_samples_support = self.cfg.hyperparams['max_samples_support'],
-                max_samples_query = self.cfg.hyperparams['max_samples_query'],
+                x_support=x_train,
+                y_support=self.y_transformer.transform(y_train),
+                x_query=x_val,
+                y_query=y_val,
+                max_samples_support=self.cfg.hyperparams["max_samples_support"],
+                max_samples_query=self.cfg.hyperparams["max_samples_query"],
             )
             loader_valid = self.make_loader(dataset_valid, training=False)
         else:
@@ -115,7 +120,7 @@
         if use_val and self.use_best_epoch:
             checkpoint.reset()
 
-        max_epochs = self.cfg.hyperparams['max_epochs']
+        max_epochs = self.cfg.hyperparams["max_epochs"]
 
         epoch = 0
         if max_epochs != 0 and use_val:
@@ -135,15 +140,15 @@
                 time_cur = time.time()
                 time_elapsed = time_cur - time_start
                 time_left = time_limit - time_elapsed
-                if time_left < (time_elapsed*3+3):
+                if time_left < (time_elapsed * 3 + 3):
                     # Fine-tuning an epoch will take longer than this, so triple the time required
                     logger.log(15, "Early stopping due to running out of time...")
                     max_epochs = 0
 
-        for epoch in range(1, max_epochs+1):
-            dataset_train = next(dataset_train_generator)
+        for epoch in range(1, max_epochs + 1):
+            dataset_train = next(dataset_train_generator)
             loader_train = self.make_loader(dataset_train, training=True)
-
+
             metrics_train = self.train_epoch(loader_train, return_metrics=self.compute_train_metrics)
             if use_val:
                 metrics_valid = self.test_epoch(loader_valid, y_val)
@@ -165,7 +170,9 @@
                 if self.early_stopping.we_should_stop():
                     logger.info("Early stopping")
                     break
-                self.scheduler.step(metrics_valid.loss) # TODO: Make scheduler work properly during refit with no val data, to mimic scheduler in OG fit
+                self.scheduler.step(
+                    metrics_valid.loss
+                ) # TODO: Make scheduler work properly during refit with no val data, to mimic scheduler in OG fit
 
             if time_limit is not None:
                 time_cur = time.time()
@@ -173,7 +180,7 @@
 
                 time_per_epoch = time_elapsed / epoch
                 time_left = time_limit - time_elapsed
-                if time_left < (time_per_epoch+3):
+                if time_left < (time_per_epoch + 3):
                     logger.log(15, "Early stopping due to running out of time...")
                     break
 
@@ -189,7 +196,9 @@
         self.optimizer = None
         self.scheduler = None
 
-    def train_epoch(self, dataloader: torch.utils.data.DataLoader, return_metrics: bool = False) -> PredictionMetrics | None:
+    def train_epoch(
+        self, dataloader: torch.utils.data.DataLoader, return_metrics: bool = False
+    ) -> PredictionMetrics | None:
         """
 
         Parameters
@@ -214,20 +223,25 @@
         for batch in dataloader:
             self.optimizer.zero_grad()
 
-            x_support = batch['x_support'].to(self.cfg.device)
-            y_support = batch['y_support'].to(self.cfg.device)
-            x_query = batch['x_query'].to(self.cfg.device)
-            y_query = batch['y_query'].to(self.cfg.device)
+            x_support = batch["x_support"].to(self.cfg.device)
+            y_support = batch["y_support"].to(self.cfg.device)
+            x_query = batch["x_query"].to(self.cfg.device)
+            y_query = batch["y_query"].to(self.cfg.device)
 
             if self.cfg.task == Task.REGRESSION:
-                x_support, y_support, x_query, y_query = x_support.float(), y_support.float(), x_query.float(), y_query.float()
+                x_support, y_support, x_query, y_query = (
+                    x_support.float(),
+                    y_support.float(),
+                    x_query.float(),
+                    y_query.float(),
+                )
 
             y_hat = self.model(x_support, y_support, x_query)
 
             if self.cfg.task == Task.REGRESSION:
                 y_hat = y_hat[0, :, 0]
             else:
-                y_hat = y_hat[0, :, :self.n_classes]
+                y_hat = y_hat[0, :, : self.n_classes]
 
             y_query = y_query[0, :]
 
@@ -241,7 +255,9 @@
         if return_metrics:
             y_true, y_pred = output_tracker.get()
             y_pred = self.y_transformer.inverse_transform(y_pred)
-            prediction_metrics = PredictionMetrics.from_prediction(y_pred, y_true, self.cfg.task, metric=self.stopping_metric)
+            prediction_metrics = PredictionMetrics.from_prediction(
+                y_pred, y_true, self.cfg.task, metric=self.stopping_metric
+            )
             return prediction_metrics
         else:
             return None
@@ -251,13 +267,16 @@
         y_hat = self.predict_epoch(dataloader)
         y_hat_finish = self.y_transformer.inverse_transform(y_hat)
 
-        prediction_metrics = PredictionMetrics.from_prediction(y_hat_finish, y_test, self.cfg.task, metric=self.stopping_metric)
+        prediction_metrics = PredictionMetrics.from_prediction(
+            y_hat_finish, y_test, self.cfg.task, metric=self.stopping_metric
+        )
         return prediction_metrics
 
     def _get_memory_size(self) -> int:
         import gc
-        import sys
         import pickle
+        import sys
+
         gc.collect() # Try to avoid OOM error
         return sys.getsizeof(pickle.dumps(self, protocol=4))
 
@@ -270,24 +289,24 @@
         x_query = self.preprocessor.transform(x_query)
 
         dataset = DatasetFinetune(
-            self.cfg,
-            x_support = x_support,
-            y_support = self.y_transformer.transform(y_support),
-            x_query = x_query,
-            y_query = None,
-            max_samples_support = self.cfg.hyperparams['max_samples_support'],
-            max_samples_query = self.cfg.hyperparams['max_samples_query'],
+            self.cfg,
+            x_support=x_support,
+            y_support=self.y_transformer.transform(y_support),
+            x_query=x_query,
+            y_query=None,
+            max_samples_support=self.cfg.hyperparams["max_samples_support"],
+            max_samples_query=self.cfg.hyperparams["max_samples_query"],
        )
 
         loader = self.make_loader(dataset, training=False)
 
         y_hat_ensembles = []
 
-        for _ in range(self.cfg.hyperparams['n_ensembles']):
+        for _ in range(self.cfg.hyperparams["n_ensembles"]):
             y_hat = self.predict_epoch(loader)
             y_hat_ensembles.append(y_hat)
 
-        y_hat_ensembled = sum(y_hat_ensembles) / self.cfg.hyperparams['n_ensembles']
+        y_hat_ensembled = sum(y_hat_ensembles) / self.cfg.hyperparams["n_ensembles"]
         y_hat_finish = self.y_transformer.inverse_transform(y_hat_ensembled)
 
         return y_hat_finish
@@ -304,20 +323,19 @@
 
         with torch.no_grad():
             for batch in dataloader:
-
-                x_support = batch['x_support'].to(self.cfg.device)
-                y_support = batch['y_support'].to(self.cfg.device)
-                x_query = batch['x_query'].to(self.cfg.device)
+                x_support = batch["x_support"].to(self.cfg.device)
+                y_support = batch["y_support"].to(self.cfg.device)
+                x_query = batch["x_query"].to(self.cfg.device)
 
                 if self.cfg.task == Task.REGRESSION:
                     y_support = y_support.float()
 
                 y_hat = self.model(x_support, y_support, x_query)
-
+
                 if self.cfg.task == Task.REGRESSION:
                     y_hat = y_hat[0, :, 0]
                 else:
-                    y_hat = y_hat[0, :, :self.n_classes]
+                    y_hat = y_hat[0, :, : self.n_classes]
 
                 y_hat_list.append(einops.asnumpy(y_hat))
 
@@ -325,7 +343,6 @@
         return y_hat
 
     def make_loader(self, dataset, training):
-
         return torch.utils.data.DataLoader(
             dataset,
             batch_size=1,
@@ -333,7 +350,5 @@
             pin_memory=True,
             num_workers=0,
             drop_last=False,
-            collate_fn=CollatorWithPadding(
-                pad_to_n_support_samples=None
-            )
-        )
+            collate_fn=CollatorWithPadding(pad_to_n_support_samples=None),
+        )
autogluon/tabular/models/tabpfnmix/_internal/core/y_transformer.py
@@ -1,5 +1,4 @@
 import numpy as np
-
 from sklearn.base import BaseEstimator, TransformerMixin
 from sklearn.pipeline import FunctionTransformer
 from sklearn.preprocessing import QuantileTransformer
@@ -22,16 +21,15 @@ def create_y_transformer(y_train: np.ndarray, task: Task) -> TransformerMixin:
 
 
 class QuantileTransformer1D(BaseEstimator, TransformerMixin):
-
     def __init__(self, output_distribution="normal") -> None:
         self.quantile_transformer = QuantileTransformer(output_distribution=output_distribution)
 
     def fit(self, x: np.ndarray):
         self.quantile_transformer.fit(x[:, None])
         return self
-
+
     def transform(self, x: np.ndarray):
         return self.quantile_transformer.transform(x[:, None])[:, 0]
-
+
     def inverse_transform(self, x: np.ndarray):
-        return self.quantile_transformer.inverse_transform(x[:, None])[:, 0]
+        return self.quantile_transformer.inverse_transform(x[:, None])[:, 0]
autogluon/tabular/models/tabpfnmix/_internal/data/dataset_finetune.py
@@ -19,24 +19,24 @@ class DatasetFinetune(torch.utils.data.Dataset):
     """
 
     def __init__(
-        self,
+        self,
         cfg: ConfigRun,
-        x_support: np.ndarray,
-        y_support: np.ndarray,
-        x_query: np.ndarray,
+        x_support: np.ndarray,
+        y_support: np.ndarray,
+        x_query: np.ndarray,
         y_query: Optional[np.ndarray],
         max_samples_support: int,
-        max_samples_query: int
+        max_samples_query: int,
     ):
         """
         :param: max_features: number of features the tab pfn model has been trained on
         """
 
         self.cfg = cfg
-
+
         self.x_support = x_support
         self.y_support = y_support
-        self.x_query = x_query
+        self.x_query = x_query
         self.y_query = y_query
 
         if self.y_query is None:
@@ -53,17 +53,11 @@ class DatasetFinetune(torch.utils.data.Dataset):
         # We push the whole training data through the model, unless it is too large
         self.support_size = min(self.max_samples_support, self.n_samples_support)
 
-
     def __len__(self):
         return len(self.x_queries)
 
     def __getitem__(self, idx):
-
-        support_indices = np.random.choice(
-            self.n_samples_support,
-            size=self.support_size,
-            replace=False
-        )
+        support_indices = np.random.choice(self.n_samples_support, size=self.support_size, replace=False)
 
         x_support = self.x_support[support_indices]
         y_support = self.y_support[support_indices]
@@ -74,13 +68,11 @@ class DatasetFinetune(torch.utils.data.Dataset):
         y_query_tensor = torch.as_tensor(self.y_queries[idx])
 
         return {
-            'x_support': x_support_tensor,
-            'y_support': y_support_tensor,
-            'x_query': x_query_tensor,
-            'y_query': y_query_tensor,
+            "x_support": x_support_tensor,
+            "y_support": y_support_tensor,
+            "x_query": x_query_tensor,
+            "y_query": y_query_tensor,
         }
-
-
 
     def split_in_chunks(self, x: np.ndarray, batch_size: int) -> list[np.ndarray]:
         """
@@ -91,19 +83,15 @@
         x_chunks = []
 
         for i in range(n_chunks):
-            x_chunks.append(x[i * batch_size: (i + 1) * batch_size])
+            x_chunks.append(x[i * batch_size : (i + 1) * batch_size])
 
         return x_chunks
-
-
-
-
 
 
 def DatasetFinetuneGenerator(
     cfg: ConfigRun,
-    x: np.ndarray,
-    y: np.ndarray,
+    x: np.ndarray,
+    y: np.ndarray,
     task: Task,
     max_samples_support: int,
     max_samples_query: int,
@@ -116,9 +104,8 @@ def DatasetFinetuneGenerator(
    Every single iteration, the generator yields a different support and query set split.
    The dataset made always has exactly one batch.
    """
-
-    while True:
 
+    while True:
         x_support, x_query, y_support, y_query = make_dataset_split(x=x, y=y, task=task, random_state=random_state)
         n_samples_support = x_support.shape[0]
         n_samples_query = x_query.shape[0]
@@ -136,4 +123,4 @@
             max_samples_query=max_samples_query,
         )
 
-        yield dataset_finetune
+        yield dataset_finetune
autogluon/tabular/models/tabpfnmix/_internal/data/preprocessor.py
@@ -14,26 +14,24 @@ class Preprocessor(TransformerMixin, BaseEstimator):
     """
     This class is used to preprocess the data before it is pushed through the model.
     The preprocessor assures that the data has the right shape and is normalized,
-    This way the model always gets the same input distribution,
+    This way the model always gets the same input distribution,
     no matter whether the input data is synthetic or real.
 
     """
 
     def __init__(
-            self,
-            max_features: int,
-            use_quantile_transformer: bool,
-            use_feature_count_scaling: bool,
-            task: Task,
-        ):
-
+        self,
+        max_features: int,
+        use_quantile_transformer: bool,
+        use_feature_count_scaling: bool,
+        task: Task,
+    ):
         self.max_features = max_features
         self.use_quantile_transformer = use_quantile_transformer
         self.use_feature_count_scaling = use_feature_count_scaling
         self.task = task
-
-    def fit(self, X: np.ndarray, y: np.ndarray):
 
+    def fit(self, X: np.ndarray, y: np.ndarray):
         self.compute_pre_nan_mean(X)
         X = self.impute_nan_features_with_mean(X)
 
@@ -46,26 +44,24 @@ class Preprocessor(TransformerMixin, BaseEstimator):
         if self.use_quantile_transformer:
             n_obs, n_features = X.shape
             n_quantiles = min(n_obs, 1000)
-            self.quantile_transformer = QuantileTransformer(n_quantiles=n_quantiles, output_distribution='normal')
+            self.quantile_transformer = QuantileTransformer(n_quantiles=n_quantiles, output_distribution="normal")
             X = self.quantile_transformer.fit_transform(X)
-
+
         self.mean, self.std = self.calc_mean_std(X)
         X = self.normalize_by_mean_std(X, self.mean, self.std)
 
         assert np.isnan(X).sum() == 0, "There are NaNs in the data after preprocessing"
 
         return self
-
 
     def transform(self, X: np.ndarray):
-
         X = self.cutoff_singular_features(X, self.singular_features)
         X = self.impute_nan_features_with_mean(X)
         X = self.select_features(X)
 
         if self.use_quantile_transformer:
             X = self.quantile_transformer.transform(X)
-
+
         X = self.normalize_by_mean_std(X, self.mean, self.std)
 
         if self.use_feature_count_scaling:
@@ -76,16 +72,15 @@ class Preprocessor(TransformerMixin, BaseEstimator):
         assert np.isnan(X).sum() == 0, "There are NaNs in the data after preprocessing"
 
         return X
-
 
     def determine_which_features_are_singular(self, x: np.ndarray) -> None:
-
-        self.singular_features = np.array([ len(np.unique(x_col)) for x_col in x.T ]) == 1
+        self.singular_features = np.array([len(np.unique(x_col)) for x_col in x.T]) == 1
 
     def determine_which_features_to_select(self, x: np.ndarray, y: np.ndarray) -> None:
-
         if x.shape[1] > self.max_features:
-            logger.info(f"A maximum of {self.max_features} features are allowed, but the dataset has {x.shape[1]} features. A subset of {self.max_features} are selected using SelectKBest")
+            logger.info(
+                f"A maximum of {self.max_features} features are allowed, but the dataset has {x.shape[1]} features. A subset of {self.max_features} are selected using SelectKBest"
+            )
 
         if self.task == Task.CLASSIFICATION:
             self.select_k_best = SelectKBest(k=self.max_features, score_func=f_classif)
@@ -99,30 +94,23 @@ class Preprocessor(TransformerMixin, BaseEstimator):
         """
         self.pre_nan_mean = np.nanmean(x, axis=0)
 
-
     def impute_nan_features_with_mean(self, x: np.ndarray) -> np.ndarray:
-
         inds = np.where(np.isnan(x))
         x[inds] = np.take(self.pre_nan_mean, inds[1])
         return x
 
-
     def select_features(self, x: np.ndarray) -> np.ndarray:
-
         if x.shape[1] > self.max_features:
             x = self.select_k_best.transform(x)
 
         return x
-
 
     def cutoff_singular_features(self, x: np.ndarray, singular_features: np.ndarray) -> np.ndarray:
-
         if singular_features.any():
             x = x[:, ~singular_features]
 
         return x
 
-
     def calc_mean_std(self, x: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
         """
         Calculates the mean and std of the training data
@@ -130,7 +118,6 @@ class Preprocessor(TransformerMixin, BaseEstimator):
         mean = x.mean(axis=0)
         std = x.std(axis=0)
         return mean, std
-
 
     def normalize_by_mean_std(self, x: np.ndarray, mean: np.ndarray, std: np.ndarray) -> np.ndarray:
         """
@@ -140,7 +127,6 @@ class Preprocessor(TransformerMixin, BaseEstimator):
         x = (x - mean) / std
         return x
 
-
     def normalize_by_feature_count(self, x: np.ndarray, max_features) -> np.ndarray:
         """
         An interesting way of normalization by the tabPFN paper
@@ -149,8 +135,6 @@ class Preprocessor(TransformerMixin, BaseEstimator):
         x = x * max_features / x.shape[1]
         return x
 
-
-
     def extend_feature_dim_to_max_features(self, x: np.ndarray, max_features) -> np.ndarray:
         """
         Increases the number of features to the number of features the model has been trained on
@@ -158,7 +142,3 @@ class Preprocessor(TransformerMixin, BaseEstimator):
         added_zeros = np.zeros((x.shape[0], max_features - x.shape[1]), dtype=np.float32)
         x = np.concatenate([x, added_zeros], axis=1)
         return x
-
-
-
-