autogluon.tabular 1.3.2b20250713__py3-none-any.whl → 1.3.2b20250714__py3-none-any.whl

This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (32)
  1. autogluon/tabular/models/__init__.py +1 -0
  2. autogluon/tabular/models/mitra/__init__.py +0 -0
  3. autogluon/tabular/models/mitra/_internal/config/config_pretrain.py +190 -0
  4. autogluon/tabular/models/mitra/_internal/config/config_run.py +32 -0
  5. autogluon/tabular/models/mitra/_internal/config/enums.py +145 -0
  6. autogluon/tabular/models/mitra/_internal/core/callbacks.py +94 -0
  7. autogluon/tabular/models/mitra/_internal/core/get_loss.py +55 -0
  8. autogluon/tabular/models/mitra/_internal/core/get_optimizer.py +108 -0
  9. autogluon/tabular/models/mitra/_internal/core/get_scheduler.py +67 -0
  10. autogluon/tabular/models/mitra/_internal/core/prediction_metrics.py +134 -0
  11. autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py +367 -0
  12. autogluon/tabular/models/mitra/_internal/data/collator.py +46 -0
  13. autogluon/tabular/models/mitra/_internal/data/dataset_finetune.py +132 -0
  14. autogluon/tabular/models/mitra/_internal/data/dataset_split.py +53 -0
  15. autogluon/tabular/models/mitra/_internal/data/preprocessor.py +420 -0
  16. autogluon/tabular/models/mitra/_internal/models/base.py +21 -0
  17. autogluon/tabular/models/mitra/_internal/models/embedding.py +182 -0
  18. autogluon/tabular/models/mitra/_internal/models/tab2d.py +667 -0
  19. autogluon/tabular/models/mitra/_internal/utils/set_seed.py +15 -0
  20. autogluon/tabular/models/mitra/mitra_model.py +214 -0
  21. autogluon/tabular/models/mitra/sklearn_interface.py +462 -0
  22. autogluon/tabular/registry/_ag_model_registry.py +2 -0
  23. autogluon/tabular/version.py +1 -1
  24. {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250714.dist-info}/METADATA +19 -10
  25. {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250714.dist-info}/RECORD +32 -12
  26. /autogluon.tabular-1.3.2b20250713-py3.9-nspkg.pth → /autogluon.tabular-1.3.2b20250714-py3.9-nspkg.pth +0 -0
  27. {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250714.dist-info}/LICENSE +0 -0
  28. {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250714.dist-info}/NOTICE +0 -0
  29. {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250714.dist-info}/WHEEL +0 -0
  30. {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250714.dist-info}/namespace_packages.txt +0 -0
  31. {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250714.dist-info}/top_level.txt +0 -0
  32. {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250714.dist-info}/zip-safe +0 -0
@@ -0,0 +1,134 @@
+ from dataclasses import dataclass
+ from typing import Self
+
+ import numpy as np
+ import scipy
+ import torch
+ from loguru import logger
+ from sklearn.metrics import f1_score, mean_squared_error, r2_score, roc_auc_score, root_mean_squared_error
+
+ from ..._internal.data.preprocessor import Preprocessor
+ from ..._internal.config.enums import MetricName, Task
+
+
+ @dataclass
+ class PredictionMetrics():
+     task: Task
+     loss: float
+     score: float
+     metrics: dict[MetricName, float]
+
+
+     @classmethod
+     def from_prediction(cls, y_pred: np.ndarray, y_true: np.ndarray, task: Task) -> Self:
+
+         loss, score, metrics = compute_metrics(y_pred, y_true, task)
+
+         return cls(task=task, loss=loss, score=score, metrics=metrics)
+
+
+ def compute_metrics(y_pred: np.ndarray, y_true: np.ndarray, task: Task) -> tuple[float, float, dict]:
+
+     match task:
+         case Task.CLASSIFICATION:
+             return compute_classification_metrics(y_pred, y_true)
+         case Task.REGRESSION:
+             return compute_regression_metrics(y_pred, y_true)
+
+
+ def compute_classification_metrics(y_pred: np.ndarray, y_true: np.ndarray) -> tuple[float, float, dict]:
+     # predictions are assumed to be log-probabilities
+
+     y_pred_class = np.argmax(y_pred, axis=1)
+     y_pred_proba = scipy.special.softmax(y_pred, axis=1)
+     y_pred_proba = y_pred_proba / y_pred_proba.sum(axis=1, keepdims=True) # softmax not completely numerically stable, so a small correction is needed
+     labels = np.arange(y_pred_proba.shape[1])
+
+     metrics = {
+         MetricName.ACCURACY: (y_true == y_pred_class).mean(),
+         MetricName.F1: f1_score(y_true, y_pred_class, average="weighted"),
+         MetricName.AUC: roc_auc_score_multiclass(y_true, y_pred_proba, multi_class='ovo', average='macro', labels=labels),
+         MetricName.LOG_LOSS: torch.nn.functional.cross_entropy(torch.from_numpy(y_pred), torch.from_numpy(y_true)).item()
+     }
+
+     loss = metrics[MetricName.LOG_LOSS]
+     score = metrics[MetricName.ACCURACY]
+
+     return loss, score, metrics
+
+
+ def roc_auc_score_multiclass(y_true, y_pred_proba, multi_class='ovo', average='macro', labels=None) -> float:
+     """
+     The roc_auc_score multi_class is not supported for binary classification
+     """
+
+     if np.unique(y_true).shape[0] == 1:
+         # AUC is not defined if there is only one class
+         return float('nan')
+
+     try:
+         if y_pred_proba.shape[1] == 2:
+             return roc_auc_score(y_true, y_pred_proba[:, 1])
+         else:
+             return roc_auc_score(y_true, y_pred_proba, multi_class=multi_class, average=average, labels=labels)
+     except ValueError as e:
+         logger.error(f"Error computing roc_auc_score: {e}")
+         logger.error(f"Returning {-1}")
+         return -1
+
+
+ def compute_regression_metrics(y_pred: np.ndarray, y_true: np.ndarray) -> tuple[float, float, dict]:
+
+     metrics = {
+         MetricName.RMSE: root_mean_squared_error(y_true, y_pred),
+         MetricName.MSE: mean_squared_error(y_true, y_pred),
+         MetricName.MAE: np.abs(y_true - y_pred).mean(),
+         MetricName.R2: r2_score(y_true, y_pred)
+     }
+
+     loss = metrics[MetricName.MSE]
+     score = metrics[MetricName.R2]
+
+     return loss, score, metrics
+
+
+ class PredictionMetricsTracker():
+     """
+     Prediction metrics tracker that accumulates predictions and true values to compute metrics at the end.
+     Uses torch.Tensor for predictions and true values.
+     """
+
+     def __init__(self, task: Task, preprocessor: Preprocessor) -> None:
+
+         self.task = task
+         self.preprocessor = preprocessor
+         self.reset()
+
+
+     def reset(self) -> None:
+
+         self.ys_pred: list[np.ndarray] = []
+         self.ys_true: list[np.ndarray] = []
+
+
+     def update(self, y_pred: torch.Tensor, y_true: torch.Tensor, train: bool) -> None:
+
+         y_pred_np = y_pred.detach().cpu().numpy()[0]
+         y_pred_ori = self.preprocessor.inverse_transform_y(y_pred_np)
+
+         y_true_np = y_true.detach().cpu().numpy()[0]
+         if train:
+             y_true_ori = self.preprocessor.inverse_transform_y(y_true_np)
+         else:
+             y_true_ori = y_true_np
+
+         self.ys_pred.append(y_pred_ori)
+         self.ys_true.append(y_true_ori)
+
+
+     def get_metrics(self) -> PredictionMetrics:
+
+         y_pred = np.concatenate(self.ys_pred, axis=0)
+         y_true = np.concatenate(self.ys_true, axis=0)
+
+         return PredictionMetrics.from_prediction(y_pred, y_true, self.task)
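For context on how these helpers fit together, here is a minimal standalone sketch that mirrors what compute_classification_metrics does on synthetic logits; the shapes and random data are illustrative assumptions, not values taken from the package.

import numpy as np
import scipy.special
import torch
from sklearn.metrics import f1_score

# Synthetic 3-class logits for 6 samples (rows are unnormalized scores).
rng = np.random.default_rng(0)
y_true = np.array([0, 1, 2, 1, 0, 2])
y_pred_logits = rng.normal(size=(6, 3))

y_pred_class = np.argmax(y_pred_logits, axis=1)
y_pred_proba = scipy.special.softmax(y_pred_logits, axis=1)
y_pred_proba = y_pred_proba / y_pred_proba.sum(axis=1, keepdims=True)  # renormalize for numerical safety

accuracy = (y_true == y_pred_class).mean()
f1 = f1_score(y_true, y_pred_class, average="weighted")
log_loss = torch.nn.functional.cross_entropy(
    torch.from_numpy(y_pred_logits), torch.from_numpy(y_true)
).item()
print(accuracy, f1, log_loss)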
@@ -0,0 +1,367 @@
+ import time
+ import numpy as np
+ import torch
+ from loguru import logger
+ from sklearn.base import BaseEstimator
+ import torch.nn.functional as F
+
+ from ..._internal.config.config_run import ConfigRun
+ from ..._internal.core.callbacks import Checkpoint, EarlyStopping
+ from ..._internal.data.collator import CollatorWithPadding
+ from ..._internal.config.enums import MetricName, ModelName, Task, LossName
+ from ..._internal.core.get_loss import get_loss
+ from ..._internal.core.get_optimizer import get_optimizer, GradScaler
+ from ..._internal.core.get_scheduler import get_scheduler
+ from ..._internal.data.dataset_finetune import DatasetFinetune, DatasetFinetuneGenerator
+ from ..._internal.data.preprocessor import Preprocessor
+ from ..._internal.core.prediction_metrics import PredictionMetrics, PredictionMetricsTracker
+
+
+ class TrainerFinetune(BaseEstimator):
+
+     def __init__(
+         self,
+         cfg: ConfigRun,
+         model: torch.nn.Module,
+         n_classes: int,
+         device: str
+     ) -> None:
+
+         self.cfg = cfg
+         self.device = device
+         self.model = model.to(self.device, non_blocking=True)
+         self.n_classes = n_classes
+
+         self.loss = get_loss(self.cfg)
+         self.optimizer = get_optimizer(self.cfg.hyperparams, self.model)
+         self.scheduler_warmup, self.scheduler_reduce_on_plateau = get_scheduler(self.cfg.hyperparams, self.optimizer)
+         self.scaler = GradScaler(
+             enabled=self.cfg.hyperparams['grad_scaler_enabled'],
+             scale_init=self.cfg.hyperparams['grad_scaler_scale_init'],
+             scale_min=self.cfg.hyperparams['grad_scaler_scale_min'],
+             growth_interval=self.cfg.hyperparams['grad_scaler_growth_interval'],
+             device=self.device
+         )
+
+         self.early_stopping = EarlyStopping(patience=self.cfg.hyperparams['early_stopping_patience'])
+         self.checkpoint = Checkpoint()
+         self.preprocessor = Preprocessor(
+             dim_embedding=self.cfg.hyperparams['dim_embedding'],
+             n_classes=self.n_classes,
+             dim_output=self.cfg.hyperparams['dim_output'],
+             use_quantile_transformer=self.cfg.hyperparams['use_quantile_transformer'],
+             use_feature_count_scaling=self.cfg.hyperparams['use_feature_count_scaling'],
+             use_random_transforms=self.cfg.hyperparams['use_random_transforms'],
+             shuffle_classes=self.cfg.hyperparams['shuffle_classes'],
+             shuffle_features=self.cfg.hyperparams['shuffle_features'],
+             random_mirror_x=self.cfg.hyperparams['random_mirror_x'],
+             random_mirror_regression=self.cfg.hyperparams['random_mirror_regression'],
+             task=self.cfg.task
+         )
+
+         self.checkpoint.reset(self.model)
+
+         if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
+             self.bins = torch.linspace(-0.5, 1.5, self.cfg.hyperparams['dim_output']+1, device=cfg.device)
+             self.bin_width = self.bins[1] - self.bins[0]
+
+         self.metric = self.cfg.hyperparams['metric']
+
+
+     def train(self, x_train: np.ndarray, y_train: np.ndarray, x_val: np.ndarray, y_val: np.ndarray):
+
+         self.preprocessor.fit(x_train, y_train)
+
+         x_train_transformed = self.preprocessor.transform_X(x_train)
+         y_train_transformed = self.preprocessor.transform_y(y_train)
+
+         dataset_train_generator = DatasetFinetuneGenerator(
+             self.cfg,
+             x = x_train_transformed,
+             y = y_train_transformed,
+             task = self.cfg.task,
+             max_samples_support = self.cfg.hyperparams['max_samples_support'],
+             max_samples_query = self.cfg.hyperparams['max_samples_query']
+         )
+
+         self.checkpoint.reset(self.model)
+
+         metrics_valid = self.evaluate(x_train, y_train, x_val, y_val)
+         self.log_start_metrics(metrics_valid)
+         self.checkpoint(self.model, metrics_valid.loss)
+
+         start_time = time.time()
+
+         for epoch in range(1, self.cfg.hyperparams['max_epochs']+1):
+
+             dataset_train = next(dataset_train_generator)
+             loader_train = self.make_loader(dataset_train, training=True)
+             self.model.train()
+
+             prediction_metrics_tracker = PredictionMetricsTracker(task=self.cfg.task, preprocessor=self.preprocessor)
+
+             for batch in loader_train:
+
+                 with torch.autocast(device_type=self.device, dtype=getattr(torch, self.cfg.hyperparams['precision'])):
+
+                     x_support = batch['x_support'].to(self.device, non_blocking=True)
+                     y_support = batch['y_support'].to(self.device, non_blocking=True)
+                     x_query = batch['x_query'].to(self.device, non_blocking=True)
+                     y_query = batch['y_query'].to(self.device, non_blocking=True)
+                     padding_features = batch['padding_features'].to(self.device, non_blocking=True)
+                     padding_obs_support = batch['padding_obs_support'].to(self.device, non_blocking=True)
+                     padding_obs_query = batch['padding_obs_query'].to(self.device, non_blocking=True)
+
+                     # Convert numerical y_support to bin ids
+                     if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
+                         y_support = torch.bucketize(y_support, self.bins) - 1
+                         y_support = torch.clamp(y_support, 0, self.cfg.hyperparams['dim_output']-1).to(torch.int64)
+                         y_query_bin_ids = torch.bucketize(y_query, self.bins) - 1
+                         y_query_bin_ids = torch.clamp(y_query_bin_ids, 0, self.cfg.hyperparams['dim_output']-1).to(torch.int64)
+
+                     match self.cfg.model_name:
+                         case ModelName.TABPFN:
+                             y_hat = self.model(x_support, y_support, x_query, task=self.cfg.task).squeeze(-1)
+                         case ModelName.TAB2D | ModelName.TAB2D_COL_ROW | ModelName.TAB2D_SDPA:
+                             y_hat = self.model(x_support, y_support, x_query, padding_features, padding_obs_support, padding_obs_query)
+
+                     # Convert numerical y_query to bin ids
+                     if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
+                         loss = self.loss(y_hat, y_query_bin_ids)
+                     elif self.cfg.task == Task.CLASSIFICATION:
+                         # for b in range(y_support.shape[0]):
+                         #     unique_classes = len(torch.unique(torch.cat((y_support[b], y_query[b]))))
+                         #     y_hat[b, :, unique_classes:] = 0
+                         loss = self.loss(y_hat, y_query)
+                     else:
+                         loss = self.loss(y_hat, y_query)
+
+                 self.optimizer.zero_grad()
+                 self.scaler.scale(loss).backward()
+                 self.scaler.step(self.optimizer)
+                 self.scaler.update()
+
+                 # Convert bin id predictions to numerical values
+                 if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
+                     y_hat = torch.argmax(y_hat, dim=-1)
+                     y_hat = self.bins[y_hat] + self.bin_width / 2
+
+                 y_hat = y_hat.float()
+                 if self.cfg.task == Task.REGRESSION:
+                     prediction_metrics_tracker.update(y_hat, y_query, train=True)
+                 else:
+                     prediction_metrics_tracker.update(y_hat, y_query, train=False)
+
+             metrics_train = prediction_metrics_tracker.get_metrics()
+             metrics_valid = self.evaluate(x_train, y_train, x_val, y_val)
+
+             self.log_metrics(epoch, metrics_train, metrics_valid)
+
+             self.checkpoint(self.model, metrics_valid.loss)
+
+             self.early_stopping(metrics_valid.metrics[self.metric])
+             if self.early_stopping.we_should_stop():
+                 logger.info("Early stopping")
+                 break
+
+             if self.cfg.hyperparams["budget"] is not None and self.cfg.hyperparams["budget"] > 0 and time.time() - start_time > self.cfg.hyperparams["budget"]:
+                 logger.info("Time limit reached")
+                 break
+
+             if epoch < self.cfg.hyperparams['warmup_steps']:
+                 self.scheduler_warmup.step()
+             else:
+                 self.scheduler_reduce_on_plateau.step(metrics_valid.loss)
+
+         self.checkpoint.set_to_best(self.model)
+
+
+     def evaluate(self, x_support: np.ndarray, y_support: np.ndarray, x_query: np.ndarray, y_query: np.ndarray) -> PredictionMetrics:
+
+         self.model.eval()
+
+         x_support_transformed = self.preprocessor.transform_X(x_support)
+         x_query_transformed = self.preprocessor.transform_X(x_query)
+         y_support_transformed = self.preprocessor.transform_y(y_support)
+         # y_query_transformed = self.preprocessor.transform_y(y_query)
+
+         dataset = DatasetFinetune(
+             self.cfg,
+             x_support = x_support_transformed,
+             y_support = y_support_transformed,
+             x_query = x_query_transformed,
+             y_query = y_query,
+             max_samples_support = self.cfg.hyperparams['max_samples_support'],
+             max_samples_query = self.cfg.hyperparams['max_samples_query'],
+         )
+
+         loader = self.make_loader(dataset, training=False)
+         prediction_metrics_tracker = PredictionMetricsTracker(task=self.cfg.task, preprocessor=self.preprocessor)
+
+         with torch.no_grad():
+             for batch in loader:
+
+                 with torch.autocast(device_type=self.device, dtype=getattr(torch, self.cfg.hyperparams['precision'])):
+
+                     x_s = batch['x_support'].to(self.device, non_blocking=True)
+                     y_s = batch['y_support'].to(self.device, non_blocking=True)
+                     x_q = batch['x_query'].to(self.device, non_blocking=True)
+                     y_q = batch['y_query'].to(self.device, non_blocking=True)
+                     padding_features = batch['padding_features'].to(self.device, non_blocking=True)
+                     padding_obs_support = batch['padding_obs_support'].to(self.device, non_blocking=True)
+                     padding_obs_query = batch['padding_obs_query'].to(self.device, non_blocking=True)
+
+                     # Convert numerical y_support to bin ids
+                     if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
+                         y_s = torch.bucketize(y_s, self.bins) - 1
+                         y_s = torch.clamp(y_s, 0, self.cfg.hyperparams['dim_output']-1).to(torch.int64)
+
+                     match self.cfg.model_name:
+                         case ModelName.TABPFN:
+                             y_hat = self.model(x_s, y_s, x_q, task=self.cfg.task).squeeze(-1)
+                         case ModelName.TAB2D | ModelName.TAB2D_COL_ROW | ModelName.TAB2D_SDPA:
+                             y_hat = self.model(x_s, y_s, x_q, padding_features, padding_obs_support, padding_obs_query)
+
+                 # Convert bin id predictions to numerical values
+                 if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
+                     y_hat = torch.argmax(y_hat, dim=-1)
+                     y_hat = self.bins[y_hat] + self.bin_width / 2
+
+                 y_hat = y_hat.float()
+                 prediction_metrics_tracker.update(y_hat, y_q, train=False)
+
+         metrics_eval = prediction_metrics_tracker.get_metrics()
+         return metrics_eval
+
+
+     def predict(self, x_support: np.ndarray, y_support: np.ndarray, x_query: np.ndarray) -> np.ndarray:
+
+         x_support_transformed = self.preprocessor.transform_X(x_support)
+         x_query_transformed = self.preprocessor.transform_X(x_query)
+         y_support_transformed = self.preprocessor.transform_y(y_support)
+
+         dataset = DatasetFinetune(
+             self.cfg,
+             x_support = x_support_transformed,
+             y_support = y_support_transformed,
+             x_query = x_query_transformed,
+             y_query = None,
+             max_samples_support = self.cfg.hyperparams['max_samples_support'],
+             max_samples_query = self.cfg.hyperparams['max_samples_query'],
+         )
+
+         loader = self.make_loader(dataset, training=False)
+         self.model.eval()
+
+         y_pred_list = []
+
+         with torch.no_grad():
+             for batch in loader:
+
+                 with torch.autocast(device_type=self.device, dtype=getattr(torch, self.cfg.hyperparams['precision'])):
+
+                     x_s = batch['x_support'].to(self.device, non_blocking=True)
+                     y_s = batch['y_support'].to(self.device, non_blocking=True)
+                     x_q = batch['x_query'].to(self.device, non_blocking=True)
+                     padding_features = batch['padding_features'].to(self.device, non_blocking=True)
+                     padding_obs_support = batch['padding_obs_support'].to(self.device, non_blocking=True)
+                     padding_obs_query = batch['padding_obs_query'].to(self.device, non_blocking=True)
+
+                     # Convert numerical y_support to bin ids
+                     if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
+                         y_s = torch.bucketize(y_s, self.bins) - 1
+                         y_s = torch.clamp(y_s, 0, self.cfg.hyperparams['dim_output']-1).to(torch.int64)
+
+                     match self.cfg.model_name:
+                         case ModelName.TABPFN:
+                             y_hat = self.model(x_s, y_s, x_q, task=self.cfg.task).squeeze(-1)
+                         case ModelName.TAB2D | ModelName.TAB2D_COL_ROW | ModelName.TAB2D_SDPA:
+                             y_hat = self.model(x_s, y_s, x_q, padding_features, padding_obs_support, padding_obs_query)
+
+                 y_hat = y_hat[0].float().cpu().numpy()
+
+                 # Convert bin id predictions to numerical values
+                 if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
+                     y_hat = np.argmax(y_hat, axis=-1)
+                     y_hat = (self.bins[y_hat] + self.bin_width / 2).cpu().numpy()
+
+                 y_hat = self.preprocessor.inverse_transform_y(y_hat)
+                 y_pred_list.append(y_hat)
+
+         y_pred = np.concatenate(y_pred_list, axis=0)
+
+         return y_pred
+
+
+     def load_params(self, path):
+         self.model.load_state_dict(torch.load(path))
+
+
+     def make_loader(self, dataset: torch.utils.data.Dataset, training: bool) -> torch.utils.data.DataLoader:
+
+         match self.cfg.model_name:
+             case ModelName.TABPFN:
+                 pad_to_max_features = True
+             case ModelName.TAB2D | ModelName.TAB2D_COL_ROW | ModelName.TAB2D_SDPA:
+                 pad_to_max_features = False
+             case _:
+                 raise NotImplementedError(f"Model {self.cfg.model_name} not implemented")
+
+         return torch.utils.data.DataLoader(
+             dataset,
+             batch_size=1,
+             shuffle=training,
+             pin_memory=True,
+             num_workers=0,
+             drop_last=False,
+             collate_fn=CollatorWithPadding(
+                 max_features=self.cfg.hyperparams['dim_embedding'],
+                 pad_to_max_features=pad_to_max_features
+             ),
+         )
+
+
+     def log_start_metrics(self, metrics_valid: PredictionMetrics):
+
+         match self.cfg.task:
+             case Task.REGRESSION:
+                 logger.info((
+                     f"Epoch 000 "
+                     f"| Train MSE: -.---- "
+                     f"| Train MAE: -.---- "
+                     f"| Train r2: -.---- "
+                     f"| Val MSE: {metrics_valid.metrics[MetricName.MSE]:.4f} "
+                     f"| Val MAE: {metrics_valid.metrics[MetricName.MAE]:.4f} "
+                     f"| Val r2: {metrics_valid.metrics[MetricName.R2]:.4f}"
+                 ))
+             case Task.CLASSIFICATION:
+                 logger.info((
+                     f"Epoch 000 "
+                     f"| Train CE: -.---- "
+                     f"| Train acc: -.---- "
+                     f"| Val CE: {metrics_valid.metrics[MetricName.LOG_LOSS]:.4f} "
+                     f"| Val acc: {metrics_valid.metrics[MetricName.ACCURACY]:.4f}"
+                 ))
+
+
+     def log_metrics(self, epoch: int, metrics_train: PredictionMetrics, metrics_valid: PredictionMetrics):
+
+         match self.cfg.task:
+             case Task.REGRESSION:
+                 logger.info((
+                     f"Epoch {epoch:03d} "
+                     f"| Train MSE: {metrics_train.metrics[MetricName.MSE]:.4f} "
+                     f"| Train MAE: {metrics_train.metrics[MetricName.MAE]:.4f} "
+                     f"| Train r2: {metrics_train.metrics[MetricName.R2]:.4f} "
+                     f"| Val MSE: {metrics_valid.metrics[MetricName.MSE]:.4f} "
+                     f"| Val MAE: {metrics_valid.metrics[MetricName.MAE]:.4f} "
+                     f"| Val r2: {metrics_valid.metrics[MetricName.R2]:.4f}"
+                 ))
+             case Task.CLASSIFICATION:
+                 logger.info((
+                     f"Epoch {epoch:03d} "
+                     f"| Train CE: {metrics_train.metrics[MetricName.LOG_LOSS]:.4f} "
+                     f"| Train acc: {metrics_train.metrics[MetricName.ACCURACY]:.4f} "
+                     f"| Val CE: {metrics_valid.metrics[MetricName.LOG_LOSS]:.4f} "
+                     f"| Val acc: {metrics_valid.metrics[MetricName.ACCURACY]:.4f}"
+                 ))
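One detail worth calling out: when the task is regression but regression_loss is cross-entropy, the trainer discretizes targets into fixed bins over [-0.5, 1.5] with torch.bucketize and maps argmax predictions back to bin centers. A small self-contained sketch of that round trip follows; dim_output=8 and the target values are illustrative assumptions.

import torch

# Bin edges as in TrainerFinetune: dim_output bins over [-0.5, 1.5].
dim_output = 8
bins = torch.linspace(-0.5, 1.5, dim_output + 1)
bin_width = bins[1] - bins[0]

# Targets (already scaled to roughly [0, 1]) -> bin ids for the cross-entropy loss.
y = torch.tensor([0.05, 0.4, 0.97])
bin_ids = torch.clamp(torch.bucketize(y, bins) - 1, 0, dim_output - 1).to(torch.int64)

# Pretend the model produced logits over the bins; argmax back to a numeric prediction.
logits = torch.randn(3, dim_output)
y_hat = bins[torch.argmax(logits, dim=-1)] + bin_width / 2
print(bin_ids, y_hat)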
@@ -0,0 +1,46 @@
+ import torch
+
+
+ class CollatorWithPadding():
+
+     def __init__(
+         self,
+         max_features: int,
+         pad_to_max_features: bool,
+     ) -> None:
+
+         self.max_features = max_features
+         self.pad_to_max_features = pad_to_max_features
+
+
+     def __call__(self, batch: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]:
+
+         max_support_samples = max(dataset['x_support'].shape[0] for dataset in batch)
+         max_query_samples = max(dataset['x_query'].shape[0] for dataset in batch)
+         max_features = max(dataset['x_support'].shape[1] for dataset in batch)
+
+         if self.pad_to_max_features:
+             max_features = self.max_features
+
+         batch_size = len(batch)
+
+         tensor_dict = {
+             'x_support': torch.zeros((batch_size, max_support_samples, max_features), dtype=batch[0]['x_support'].dtype),
+             'y_support': torch.full((batch_size, max_support_samples), fill_value=-100, dtype=batch[0]['y_support'].dtype),
+             'x_query': torch.zeros((batch_size, max_query_samples, max_features), dtype=batch[0]['x_query'].dtype),
+             'y_query': torch.full((batch_size, max_query_samples), fill_value=-100, dtype=batch[0]['y_query'].dtype),
+             'padding_features': torch.ones((batch_size, max_features), dtype=torch.bool),
+             'padding_obs_support': torch.ones((batch_size, max_support_samples), dtype=torch.bool),
+             'padding_obs_query': torch.ones((batch_size, max_query_samples), dtype=torch.bool),
+         }
+
+         for i, dataset in enumerate(batch):
+             tensor_dict['x_support'][i, :dataset['x_support'].shape[0], :dataset['x_support'].shape[1]] = dataset['x_support']
+             tensor_dict['y_support'][i, :dataset['y_support'].shape[0]] = dataset['y_support']
+             tensor_dict['x_query'][i, :dataset['x_query'].shape[0], :dataset['x_support'].shape[1]] = dataset['x_query']
+             tensor_dict['y_query'][i, :dataset['y_query'].shape[0]] = dataset['y_query']
+             tensor_dict['padding_features'][i, :dataset['x_support'].shape[1]] = False
+             tensor_dict['padding_obs_support'][i, :dataset['x_support'].shape[0]] = False
+             tensor_dict['padding_obs_query'][i, :dataset['x_query'].shape[0]] = False
+
+         return tensor_dict
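As a quick illustration of the episode format this collator expects, the toy example below collates two episodes of different sizes into padded tensors plus boolean padding masks. It assumes the new wheel is installed and that the internal module path shown in the file list above is importable; the episode shapes are invented.

import torch
from autogluon.tabular.models.mitra._internal.data.collator import CollatorWithPadding

collate = CollatorWithPadding(max_features=10, pad_to_max_features=False)
batch = collate([
    {   # episode 1: 4 support rows, 2 query rows, 3 features
        'x_support': torch.zeros(4, 3), 'y_support': torch.zeros(4, dtype=torch.long),
        'x_query': torch.zeros(2, 3),   'y_query': torch.zeros(2, dtype=torch.long),
    },
    {   # episode 2: 6 support rows, 1 query row, 5 features
        'x_support': torch.zeros(6, 5), 'y_support': torch.zeros(6, dtype=torch.long),
        'x_query': torch.zeros(1, 5),   'y_query': torch.zeros(1, dtype=torch.long),
    },
])
print(batch['x_support'].shape)      # torch.Size([2, 6, 5]): padded to the largest shapes present
print(batch['padding_obs_support'])  # True where a support row is padding, False where it is real data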
@@ -0,0 +1,132 @@
+ from typing import Optional
+
+ import numpy as np
+ import torch
+
+ from ..._internal.config.config_run import ConfigRun
+ from ..._internal.data.dataset_split import make_dataset_split
+ from ..._internal.config.enums import Task
+
+
+ class DatasetFinetune(torch.utils.data.Dataset):
+     """
+     The main goal of this class is to generate a dataset for fine-tuning.
+     The input data are the full (x_support, y_support, x_query, y_query)
+     But these arrays are too large to be pushed through the model at once.
+     So here we split query the data into chunks if the query data is too large.
+     If the support data is too large, we randomly sample from it.
+     Furthermore, we transition from numpy to tensors.
+     """
+
+     def __init__(
+         self,
+         cfg: ConfigRun,
+         x_support: np.ndarray,
+         y_support: np.ndarray,
+         x_query: np.ndarray,
+         y_query: Optional[np.ndarray],
+         max_samples_support: int,
+         max_samples_query: int
+     ):
+         """
+         :param: max_features: number of features the tab pfn model has been trained on
+         """
+
+         self.cfg = cfg
+
+         self.x_support = x_support
+         self.y_support = y_support
+         self.x_query = x_query
+         self.y_query = y_query
+
+         if self.y_query is None:
+             self.y_query = np.zeros((self.x_query.shape[0],)) - 1
+
+         self.max_samples_support = max_samples_support
+         self.max_samples_query = max_samples_query
+
+         self.x_queries = self.split_in_chunks(self.x_query, max_samples_query)
+         self.y_queries = self.split_in_chunks(self.y_query, max_samples_query)
+
+         self.n_samples_support = self.x_support.shape[0]
+
+         # We push the whole training data through the model, unless it is too large
+         self.support_size = min(self.max_samples_support, self.n_samples_support)
+
+
+     def __len__(self):
+         return len(self.x_queries)
+
+     def __getitem__(self, idx):
+
+         support_indices = np.random.choice(
+             self.n_samples_support,
+             size=self.support_size,
+             replace=False
+         )
+
+         x_support = self.x_support[support_indices]
+         y_support = self.y_support[support_indices]
+
+         x_support_tensor = torch.as_tensor(x_support)
+         y_support_tensor = torch.as_tensor(y_support)
+         x_query_tensor = torch.as_tensor(self.x_queries[idx])
+         y_query_tensor = torch.as_tensor(self.y_queries[idx])
+
+         return {
+             'x_support': x_support_tensor,
+             'y_support': y_support_tensor,
+             'x_query': x_query_tensor,
+             'y_query': y_query_tensor,
+         }
+
+
+
+     def split_in_chunks(self, x: np.ndarray, batch_size: int) -> list[np.ndarray]:
+         """
+         Splits the data into chunks of size batch_size
+         """
+
+         n_chunks = int(np.ceil(x.shape[0] / batch_size))
+         x_chunks = []
+
+         for i in range(n_chunks):
+             x_chunks.append(x[i * batch_size: (i + 1) * batch_size])
+
+         return x_chunks
+
+ def DatasetFinetuneGenerator(
+     cfg: ConfigRun,
+     x: np.ndarray,
+     y: np.ndarray,
+     task: Task,
+     max_samples_support: int,
+     max_samples_query: int
+ ):
+     """
+     The dataset fine-tune generator is a generator that yields a dataset for fine-tuning.
+     The idea is to split the training dataset into a support and query set.
+     Every single iteration, the generator yields a different support and query set split.
+     The dataset made always has exactly one batch.
+     """
+
+     while True:
+
+         x_support, x_query, y_support, y_query = make_dataset_split(x=x, y=y, task=task, seed=cfg.seed)
+         n_samples_support = x_support.shape[0]
+         n_samples_query = x_query.shape[0]
+
+         support_size = min(max_samples_support, n_samples_support)
+         query_size = min(max_samples_query, n_samples_query)
+
+         dataset_finetune = DatasetFinetune(
+             cfg=cfg,
+             x_support=x_support[:support_size],
+             y_support=y_support[:support_size],
+             x_query=x_query[:query_size],
+             y_query=y_query[:query_size],
+             max_samples_support=max_samples_support,
+             max_samples_query=max_samples_query,
+         )
+
+         yield dataset_finetune
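To make the chunking concrete: DatasetFinetune splits the query data into ceil(n_query / max_samples_query) pieces, and each __getitem__ call re-samples up to max_samples_support support rows to pair with one query chunk. A tiny sketch of that arithmetic, with made-up sizes:

import numpy as np

# Chunking used by DatasetFinetune.split_in_chunks (sizes are illustrative).
n_query, max_samples_query = 2_500, 1_000
n_chunks = int(np.ceil(n_query / max_samples_query))
chunk_sizes = [min(max_samples_query, n_query - i * max_samples_query) for i in range(n_chunks)]
print(n_chunks, chunk_sizes)  # 3 [1000, 1000, 500] -> len(dataset) would be 3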