autogluon.tabular 1.3.2b20250713__py3-none-any.whl → 1.3.2b20250715__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. autogluon/tabular/models/__init__.py +1 -0
  2. autogluon/tabular/models/catboost/catboost_model.py +9 -6
  3. autogluon/tabular/models/catboost/catboost_utils.py +10 -0
  4. autogluon/tabular/models/lgb/lgb_model.py +2 -1
  5. autogluon/tabular/models/mitra/__init__.py +0 -0
  6. autogluon/tabular/models/mitra/_internal/config/config_pretrain.py +190 -0
  7. autogluon/tabular/models/mitra/_internal/config/config_run.py +32 -0
  8. autogluon/tabular/models/mitra/_internal/config/enums.py +145 -0
  9. autogluon/tabular/models/mitra/_internal/core/callbacks.py +94 -0
  10. autogluon/tabular/models/mitra/_internal/core/get_loss.py +55 -0
  11. autogluon/tabular/models/mitra/_internal/core/get_optimizer.py +108 -0
  12. autogluon/tabular/models/mitra/_internal/core/get_scheduler.py +67 -0
  13. autogluon/tabular/models/mitra/_internal/core/prediction_metrics.py +134 -0
  14. autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py +367 -0
  15. autogluon/tabular/models/mitra/_internal/data/collator.py +46 -0
  16. autogluon/tabular/models/mitra/_internal/data/dataset_finetune.py +132 -0
  17. autogluon/tabular/models/mitra/_internal/data/dataset_split.py +53 -0
  18. autogluon/tabular/models/mitra/_internal/data/preprocessor.py +420 -0
  19. autogluon/tabular/models/mitra/_internal/models/base.py +21 -0
  20. autogluon/tabular/models/mitra/_internal/models/embedding.py +182 -0
  21. autogluon/tabular/models/mitra/_internal/models/tab2d.py +667 -0
  22. autogluon/tabular/models/mitra/_internal/utils/set_seed.py +15 -0
  23. autogluon/tabular/models/mitra/mitra_model.py +214 -0
  24. autogluon/tabular/models/mitra/sklearn_interface.py +462 -0
  25. autogluon/tabular/registry/_ag_model_registry.py +2 -0
  26. autogluon/tabular/testing/fit_helper.py +2 -2
  27. autogluon/tabular/version.py +1 -1
  28. {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250715.dist-info}/METADATA +21 -12
  29. {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250715.dist-info}/RECORD +36 -16
  30. /autogluon.tabular-1.3.2b20250713-py3.9-nspkg.pth → /autogluon.tabular-1.3.2b20250715-py3.9-nspkg.pth +0 -0
  31. {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250715.dist-info}/LICENSE +0 -0
  32. {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250715.dist-info}/NOTICE +0 -0
  33. {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250715.dist-info}/WHEEL +0 -0
  34. {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250715.dist-info}/namespace_packages.txt +0 -0
  35. {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250715.dist-info}/top_level.txt +0 -0
  36. {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250715.dist-info}/zip-safe +0 -0
autogluon/tabular/models/mitra/_internal/core/get_optimizer.py
@@ -0,0 +1,108 @@
+ import torch
+ from torch.optim import SGD, Adam, AdamW
+
+ from ..._internal.config.config_pretrain import ConfigPretrain
+
+
+ def get_optimizer(hyperparams: dict, model: torch.nn.Module) -> torch.optim.Optimizer:
+
+     optimizer: torch.optim.Optimizer
+
+     if hyperparams['optimizer'] == "adam":
+         optimizer = Adam(
+             model.parameters(),
+             lr=hyperparams['lr'],
+             betas=(0.9, 0.999),
+             weight_decay=hyperparams['weight_decay']
+         )
+     elif hyperparams['optimizer'] == "adamw":
+         optimizer = AdamW(
+             model.parameters(),
+             lr=hyperparams['lr'],
+             betas=(0.9, 0.999),
+             weight_decay=hyperparams['weight_decay']
+         )
+     elif hyperparams['optimizer'] == "sgd":
+         optimizer = SGD(
+             model.parameters(),
+             lr=hyperparams['lr'],
+             weight_decay=hyperparams['weight_decay']
+         )
+     else:
+         raise ValueError("Optimizer not recognized")
+
+     return optimizer
+
+
+ def get_optimizer_pretrain(cfg: ConfigPretrain, model: torch.nn.Module) -> torch.optim.Optimizer:
+
+     parameters = [(name, param) for name, param in model.named_parameters()]
+
+     parameters_with_weight_decay = []
+     parameters_without_weight_decay = []
+
+     for name, param in parameters:
+         if name.endswith("bias") or 'norm' in name or 'embedding' in name:
+             parameters_without_weight_decay.append(param)
+         else:
+             parameters_with_weight_decay.append(param)
+
+     optimizer_parameters = [
+         {"params": parameters_with_weight_decay, "weight_decay": cfg.optim.weight_decay},
+         {"params": parameters_without_weight_decay, "weight_decay": 0.0},
+     ]
+
+     optimizer = torch.optim.AdamW(
+         optimizer_parameters,
+         lr=cfg.optim.lr,
+         betas=(cfg.optim.beta1, cfg.optim.beta2),
+         weight_decay=cfg.optim.weight_decay
+     )
+
+     return optimizer
+
+
+ class GradScaler(torch.amp.GradScaler):
+
+     def __init__(
+         self,
+         enabled: bool = True,
+         scale_init: float = 2.**16,
+         scale_min: float = 1.,
+         growth_interval: int = 2000,
+         device: str = 'cuda'
+     ):
+         super().__init__(enabled=enabled, device="cpu", init_scale=scale_init, growth_interval=growth_interval) # type: ignore
+         self._enabled = enabled
+         self.scale_min = scale_min
+         self.device = device
+
+         if not self._enabled:
+             # We write scale=1 to log if the scaler is disabled
+             self._scale = torch.tensor((1,), dtype=torch.float32, device=self.device)
+
+
+     def update(self):
+
+         if not self._enabled:
+             return
+
+         super().update()
+
+         if self._scale < self.scale_min:
+             super().update(self.scale_min)
+
+
+ def move_optimizer_to(optim, device):
+     for param in optim.state.values():
+         # Not sure there are any global tensors in the state dict
+         if isinstance(param, torch.Tensor):
+             param.data = param.data.to(device)
+             if param._grad is not None:
+                 param._grad.data = param._grad.data.to(device)
+         elif isinstance(param, dict):
+             for subparam in param.values():
+                 if isinstance(subparam, torch.Tensor):
+                     subparam.data = subparam.data.to(device)
+                     if subparam._grad is not None:
+                         subparam._grad.data = subparam._grad.data.to(device)
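A minimal usage sketch for the helpers above, assuming the module path mirrors the file location in this wheel and using a toy module for illustration:

    import torch
    from autogluon.tabular.models.mitra._internal.core.get_optimizer import get_optimizer, move_optimizer_to

    model = torch.nn.Linear(8, 2)
    # 'optimizer', 'lr' and 'weight_decay' are the keys get_optimizer reads from the hyperparameter dict
    hyperparams = {"optimizer": "adamw", "lr": 1e-4, "weight_decay": 0.01}
    optimizer = get_optimizer(hyperparams, model)

    # Optimizer state tensors can be relocated alongside the model, e.g. after model.to(...)
    move_optimizer_to(optimizer, "cpu")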
autogluon/tabular/models/mitra/_internal/core/get_scheduler.py
@@ -0,0 +1,67 @@
+ import torch
+ from torch.optim.lr_scheduler import ReduceLROnPlateau, LinearLR
+ from transformers import get_constant_schedule_with_warmup
+ from transformers.optimization import get_cosine_with_min_lr_schedule_with_warmup
+
+ from ..._internal.config.config_pretrain import ConfigPretrain
+
+
+ def get_scheduler(hyperparams: dict, optimizer: torch.optim.Optimizer) -> tuple[torch.optim.lr_scheduler.LambdaLR, ReduceLROnPlateau]:
+
+     warmup_steps = hyperparams['warmup_steps']
+
+     # if warmup_steps > 0:
+     #     scheduler_warmup = torch.optim.lr_scheduler.LambdaLR(
+     #         optimizer, lambda step: min((step + 1) / warmup_steps, 1.0)
+     #     )
+     # else:
+     #     scheduler_warmup = torch.optim.lr_scheduler.LambdaLR(
+     #         optimizer, lambda step: 1.0
+     #     )
+
+     if warmup_steps > 0:
+         scheduler_warmup = LinearLR(
+             optimizer,
+             start_factor=1.0 / warmup_steps,
+             end_factor=1.0,
+             total_iters=warmup_steps,
+         )
+     else:
+         scheduler_warmup = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0, total_iters=1)
+
+     if hyperparams['lr_scheduler']:
+         scheduler_reduce_on_plateau = ReduceLROnPlateau(
+             optimizer,
+             patience=hyperparams['lr_scheduler_patience'],
+             min_lr=0,
+             factor=0.2
+         )
+     else:
+         # With ReduceLROnPlateau, the scheduler accepts a metric to monitor, so our dummy metric must also be a ReduceLRonPlateau scheduler
+         scheduler_reduce_on_plateau = ReduceLROnPlateau(
+             optimizer,
+             patience=1000000000,
+             min_lr=0,
+             factor=0.2
+         )
+
+     return scheduler_warmup, scheduler_reduce_on_plateau
+
+
+ def get_scheduler_pretrain(cfg: ConfigPretrain, optimizer: torch.optim.Optimizer):
+
+
+     if cfg.optim.cosine_scheduler:
+         schedule = get_cosine_with_min_lr_schedule_with_warmup(
+             optimizer,
+             num_warmup_steps=cfg.optim.warmup_steps,
+             num_training_steps=cfg.optim.steps,
+             min_lr_rate=0.1
+         )
+     else:
+         schedule = get_constant_schedule_with_warmup(
+             optimizer,
+             num_warmup_steps=cfg.optim.warmup_steps
+         )
+
+     return schedule
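A short sketch of how the scheduler pair returned by get_scheduler might be driven, with illustrative hyperparameter values for the keys the function reads above:

    import torch
    from autogluon.tabular.models.mitra._internal.core.get_scheduler import get_scheduler

    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    hyperparams = {"warmup_steps": 100, "lr_scheduler": True, "lr_scheduler_patience": 5}
    scheduler_warmup, scheduler_plateau = get_scheduler(hyperparams, optimizer)

    # The warmup scheduler is stepped per epoch during warmup; afterwards the plateau
    # scheduler is stepped with the validation loss it should monitor.
    scheduler_warmup.step()
    scheduler_plateau.step(0.42)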
autogluon/tabular/models/mitra/_internal/core/prediction_metrics.py
@@ -0,0 +1,134 @@
+ from dataclasses import dataclass
+ from typing import Self
+
+ import numpy as np
+ import scipy
+ import torch
+ from loguru import logger
+ from sklearn.metrics import f1_score, mean_squared_error, r2_score, roc_auc_score, root_mean_squared_error
+
+ from ..._internal.data.preprocessor import Preprocessor
+ from ..._internal.config.enums import MetricName, Task
+
+
+ @dataclass
+ class PredictionMetrics():
+     task: Task
+     loss: float
+     score: float
+     metrics: dict[MetricName, float]
+
+
+     @classmethod
+     def from_prediction(cls, y_pred: np.ndarray, y_true: np.ndarray, task: Task) -> Self:
+
+         loss, score, metrics = compute_metrics(y_pred, y_true, task)
+
+         return cls(task=task, loss=loss, score=score, metrics=metrics)
+
+
+ def compute_metrics(y_pred: np.ndarray, y_true: np.ndarray, task: Task) -> tuple[float, float, dict]:
+
+     match task:
+         case Task.CLASSIFICATION:
+             return compute_classification_metrics(y_pred, y_true)
+         case Task.REGRESSION:
+             return compute_regression_metrics(y_pred, y_true)
+
+
+ def compute_classification_metrics(y_pred: np.ndarray, y_true: np.ndarray) -> tuple[float, float, dict]:
+     # predictions are assumed to be log-probabilities
+
+     y_pred_class = np.argmax(y_pred, axis=1)
+     y_pred_proba = scipy.special.softmax(y_pred, axis=1)
+     y_pred_proba = y_pred_proba / y_pred_proba.sum(axis=1, keepdims=True) # softmax not completely numerically stable, so a small correction is needed
+     labels = np.arange(y_pred_proba.shape[1])
+
+     metrics = {
+         MetricName.ACCURACY: (y_true == y_pred_class).mean(),
+         MetricName.F1: f1_score(y_true, y_pred_class, average="weighted"),
+         MetricName.AUC: roc_auc_score_multiclass(y_true, y_pred_proba, multi_class='ovo', average='macro', labels=labels),
+         MetricName.LOG_LOSS: torch.nn.functional.cross_entropy(torch.from_numpy(y_pred), torch.from_numpy(y_true)).item()
+     }
+
+     loss = metrics[MetricName.LOG_LOSS]
+     score = metrics[MetricName.ACCURACY]
+
+     return loss, score, metrics
+
+
+ def roc_auc_score_multiclass(y_true, y_pred_proba, multi_class='ovo', average='macro', labels=None) -> float:
+     """
+     The roc_auc_score multi_class is not supported for binary classification
+     """
+
+     if np.unique(y_true).shape[0] == 1:
+         # AUC is not defined if there is only one class
+         return float('nan')
+
+     try:
+         if y_pred_proba.shape[1] == 2:
+             return roc_auc_score(y_true, y_pred_proba[:, 1])
+         else:
+             return roc_auc_score(y_true, y_pred_proba, multi_class=multi_class, average=average, labels=labels)
+     except ValueError as e:
+         logger.error(f"Error computing roc_auc_score: {e}")
+         logger.error(f"Returning {-1}")
+         return -1
+
+
+ def compute_regression_metrics(y_pred: np.ndarray, y_true: np.ndarray) -> tuple[float, float, dict]:
+
+     metrics = {
+         MetricName.RMSE: root_mean_squared_error(y_true, y_pred),
+         MetricName.MSE: mean_squared_error(y_true, y_pred),
+         MetricName.MAE: np.abs(y_true - y_pred).mean(),
+         MetricName.R2: r2_score(y_true, y_pred)
+     }
+
+     loss = metrics[MetricName.MSE]
+     score = metrics[MetricName.R2]
+
+     return loss, score, metrics
+
+
+ class PredictionMetricsTracker():
+     """
+     Prediction metrics tracker that accumulates predictions and true values to compute metrics at the end.
+     Uses torch.Tensor for predictions and true values.
+     """
+
+     def __init__(self, task: Task, preprocessor: Preprocessor) -> None:
+
+         self.task = task
+         self.preprocessor = preprocessor
+         self.reset()
+
+
+     def reset(self) -> None:
+
+         self.ys_pred: list[np.ndarray] = []
+         self.ys_true: list[np.ndarray] = []
+
+
+     def update(self, y_pred: torch.Tensor, y_true: torch.Tensor, train: bool) -> None:
+
+         y_pred_np = y_pred.detach().cpu().numpy()[0]
+         y_pred_ori = self.preprocessor.inverse_transform_y(y_pred_np)
+
+         y_true_np = y_true.detach().cpu().numpy()[0]
+         if train:
+             y_true_ori = self.preprocessor.inverse_transform_y(y_true_np)
+         else:
+             y_true_ori = y_true_np
+
+         self.ys_pred.append(y_pred_ori)
+         self.ys_true.append(y_true_ori)
+
+
+     def get_metrics(self) -> PredictionMetrics:
+
+         y_pred = np.concatenate(self.ys_pred, axis=0)
+         y_true = np.concatenate(self.ys_true, axis=0)
+
+         return PredictionMetrics.from_prediction(y_pred, y_true, self.task)
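A small illustrative example of PredictionMetrics.from_prediction on a classification task, assuming the module paths mirror the file locations in this wheel; the inputs are toy logits, which is what compute_classification_metrics expects:

    import numpy as np
    from autogluon.tabular.models.mitra._internal.config.enums import Task
    from autogluon.tabular.models.mitra._internal.core.prediction_metrics import PredictionMetrics

    # Per-class logits / log-probabilities for four rows, two classes
    y_pred = np.array([[2.0, -1.0], [-0.5, 1.5], [0.1, 0.3], [1.2, -0.2]])
    y_true = np.array([0, 1, 1, 0], dtype=np.int64)  # int64 so the cross-entropy target dtype is valid

    pm = PredictionMetrics.from_prediction(y_pred, y_true, Task.CLASSIFICATION)
    print(pm.loss, pm.score, pm.metrics)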
autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py
@@ -0,0 +1,367 @@
+ import time
+ import numpy as np
+ import torch
+ from loguru import logger
+ from sklearn.base import BaseEstimator
+ import torch.nn.functional as F
+
+ from ..._internal.config.config_run import ConfigRun
+ from ..._internal.core.callbacks import Checkpoint, EarlyStopping
+ from ..._internal.data.collator import CollatorWithPadding
+ from ..._internal.config.enums import MetricName, ModelName, Task, LossName
+ from ..._internal.core.get_loss import get_loss
+ from ..._internal.core.get_optimizer import get_optimizer, GradScaler
+ from ..._internal.core.get_scheduler import get_scheduler
+ from ..._internal.data.dataset_finetune import DatasetFinetune, DatasetFinetuneGenerator
+ from ..._internal.data.preprocessor import Preprocessor
+ from ..._internal.core.prediction_metrics import PredictionMetrics, PredictionMetricsTracker
+
+
+ class TrainerFinetune(BaseEstimator):
+
+     def __init__(
+         self,
+         cfg: ConfigRun,
+         model: torch.nn.Module,
+         n_classes: int,
+         device: str
+     ) -> None:
+
+         self.cfg = cfg
+         self.device = device
+         self.model = model.to(self.device, non_blocking=True)
+         self.n_classes = n_classes
+
+         self.loss = get_loss(self.cfg)
+         self.optimizer = get_optimizer(self.cfg.hyperparams, self.model)
+         self.scheduler_warmup, self.scheduler_reduce_on_plateau = get_scheduler(self.cfg.hyperparams, self.optimizer)
+         self.scaler = GradScaler(
+             enabled=self.cfg.hyperparams['grad_scaler_enabled'],
+             scale_init=self.cfg.hyperparams['grad_scaler_scale_init'],
+             scale_min=self.cfg.hyperparams['grad_scaler_scale_min'],
+             growth_interval=self.cfg.hyperparams['grad_scaler_growth_interval'],
+             device=self.device
+         )
+
+         self.early_stopping = EarlyStopping(patience=self.cfg.hyperparams['early_stopping_patience'])
+         self.checkpoint = Checkpoint()
+         self.preprocessor = Preprocessor(
+             dim_embedding=self.cfg.hyperparams['dim_embedding'],
+             n_classes=self.n_classes,
+             dim_output=self.cfg.hyperparams['dim_output'],
+             use_quantile_transformer=self.cfg.hyperparams['use_quantile_transformer'],
+             use_feature_count_scaling=self.cfg.hyperparams['use_feature_count_scaling'],
+             use_random_transforms=self.cfg.hyperparams['use_random_transforms'],
+             shuffle_classes=self.cfg.hyperparams['shuffle_classes'],
+             shuffle_features=self.cfg.hyperparams['shuffle_features'],
+             random_mirror_x=self.cfg.hyperparams['random_mirror_x'],
+             random_mirror_regression=self.cfg.hyperparams['random_mirror_regression'],
+             task=self.cfg.task
+         )
+
+         self.checkpoint.reset(self.model)
+
+         if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
+             self.bins = torch.linspace(-0.5, 1.5, self.cfg.hyperparams['dim_output']+1, device=cfg.device)
+             self.bin_width = self.bins[1] - self.bins[0]
+
+         self.metric = self.cfg.hyperparams['metric']
+
+
+     def train(self, x_train: np.ndarray, y_train: np.ndarray, x_val: np.ndarray, y_val: np.ndarray):
+
+         self.preprocessor.fit(x_train, y_train)
+
+         x_train_transformed = self.preprocessor.transform_X(x_train)
+         y_train_transformed = self.preprocessor.transform_y(y_train)
+
+         dataset_train_generator = DatasetFinetuneGenerator(
+             self.cfg,
+             x = x_train_transformed,
+             y = y_train_transformed,
+             task = self.cfg.task,
+             max_samples_support = self.cfg.hyperparams['max_samples_support'],
+             max_samples_query = self.cfg.hyperparams['max_samples_query']
+         )
+
+         self.checkpoint.reset(self.model)
+
+         metrics_valid = self.evaluate(x_train, y_train, x_val, y_val)
+         self.log_start_metrics(metrics_valid)
+         self.checkpoint(self.model, metrics_valid.loss)
+
+         start_time = time.time()
+
+         for epoch in range(1, self.cfg.hyperparams['max_epochs']+1):
+
+             dataset_train = next(dataset_train_generator)
+             loader_train = self.make_loader(dataset_train, training=True)
+             self.model.train()
+
+             prediction_metrics_tracker = PredictionMetricsTracker(task=self.cfg.task, preprocessor=self.preprocessor)
+
+             for batch in loader_train:
+
+                 with torch.autocast(device_type=self.device, dtype=getattr(torch, self.cfg.hyperparams['precision'])):
+
+                     x_support = batch['x_support'].to(self.device, non_blocking=True)
+                     y_support = batch['y_support'].to(self.device, non_blocking=True)
+                     x_query = batch['x_query'].to(self.device, non_blocking=True)
+                     y_query = batch['y_query'].to(self.device, non_blocking=True)
+                     padding_features = batch['padding_features'].to(self.device, non_blocking=True)
+                     padding_obs_support = batch['padding_obs_support'].to(self.device, non_blocking=True)
+                     padding_obs_query = batch['padding_obs_query'].to(self.device, non_blocking=True)
+
+                     # Convert numerical y_support to bin ids
+                     if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
+                         y_support = torch.bucketize(y_support, self.bins) - 1
+                         y_support = torch.clamp(y_support, 0, self.cfg.hyperparams['dim_output']-1).to(torch.int64)
+                         y_query_bin_ids = torch.bucketize(y_query, self.bins) - 1
+                         y_query_bin_ids = torch.clamp(y_query_bin_ids, 0, self.cfg.hyperparams['dim_output']-1).to(torch.int64)
+
+                     match self.cfg.model_name:
+                         case ModelName.TABPFN:
+                             y_hat = self.model(x_support, y_support, x_query, task=self.cfg.task).squeeze(-1)
+                         case ModelName.TAB2D | ModelName.TAB2D_COL_ROW | ModelName.TAB2D_SDPA:
+                             y_hat = self.model(x_support, y_support, x_query, padding_features, padding_obs_support, padding_obs_query)
+
+                     # Convert numerical y_query to bin ids
+                     if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
+                         loss = self.loss(y_hat, y_query_bin_ids)
+                     elif self.cfg.task == Task.CLASSIFICATION:
+                         # for b in range(y_support.shape[0]):
+                         #     unique_classes = len(torch.unique(torch.cat((y_support[b], y_query[b]))))
+                         #     y_hat[b, :, unique_classes:] = 0
+                         loss = self.loss(y_hat, y_query)
+                     else:
+                         loss = self.loss(y_hat, y_query)
+
+                 self.optimizer.zero_grad()
+                 self.scaler.scale(loss).backward()
+                 self.scaler.step(self.optimizer)
+                 self.scaler.update()
+
+                 # Convert bin id predictions to numerical values
+                 if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
+                     y_hat = torch.argmax(y_hat, dim=-1)
+                     y_hat = self.bins[y_hat] + self.bin_width / 2
+
+                 y_hat = y_hat.float()
+                 if self.cfg.task == Task.REGRESSION:
+                     prediction_metrics_tracker.update(y_hat, y_query, train=True)
+                 else:
+                     prediction_metrics_tracker.update(y_hat, y_query, train=False)
+
+             metrics_train = prediction_metrics_tracker.get_metrics()
+             metrics_valid = self.evaluate(x_train, y_train, x_val, y_val)
+
+             self.log_metrics(epoch, metrics_train, metrics_valid)
+
+             self.checkpoint(self.model, metrics_valid.loss)
+
+             self.early_stopping(metrics_valid.metrics[self.metric])
+             if self.early_stopping.we_should_stop():
+                 logger.info("Early stopping")
+                 break
+
+             if self.cfg.hyperparams["budget"] is not None and self.cfg.hyperparams["budget"] > 0 and time.time() - start_time > self.cfg.hyperparams["budget"]:
+                 logger.info("Time limit reached")
+                 break
+
+             if epoch < self.cfg.hyperparams['warmup_steps']:
+                 self.scheduler_warmup.step()
+             else:
+                 self.scheduler_reduce_on_plateau.step(metrics_valid.loss)
+
+         self.checkpoint.set_to_best(self.model)
+
+
+     def evaluate(self, x_support: np.ndarray, y_support: np.ndarray, x_query: np.ndarray, y_query: np.ndarray) -> PredictionMetrics:
+
+         self.model.eval()
+
+         x_support_transformed = self.preprocessor.transform_X(x_support)
+         x_query_transformed = self.preprocessor.transform_X(x_query)
+         y_support_transformed = self.preprocessor.transform_y(y_support)
+         # y_query_transformed = self.preprocessor.transform_y(y_query)
+
+         dataset = DatasetFinetune(
+             self.cfg,
+             x_support = x_support_transformed,
+             y_support = y_support_transformed,
+             x_query = x_query_transformed,
+             y_query = y_query,
+             max_samples_support = self.cfg.hyperparams['max_samples_support'],
+             max_samples_query = self.cfg.hyperparams['max_samples_query'],
+         )
+
+         loader = self.make_loader(dataset, training=False)
+         prediction_metrics_tracker = PredictionMetricsTracker(task=self.cfg.task, preprocessor=self.preprocessor)
+
+         with torch.no_grad():
+             for batch in loader:
+
+                 with torch.autocast(device_type=self.device, dtype=getattr(torch, self.cfg.hyperparams['precision'])):
+
+                     x_s = batch['x_support'].to(self.device, non_blocking=True)
+                     y_s = batch['y_support'].to(self.device, non_blocking=True)
+                     x_q = batch['x_query'].to(self.device, non_blocking=True)
+                     y_q = batch['y_query'].to(self.device, non_blocking=True)
+                     padding_features = batch['padding_features'].to(self.device, non_blocking=True)
+                     padding_obs_support = batch['padding_obs_support'].to(self.device, non_blocking=True)
+                     padding_obs_query = batch['padding_obs_query'].to(self.device, non_blocking=True)
+
+                     # Convert numerical y_support to bin ids
+                     if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
+                         y_s = torch.bucketize(y_s, self.bins) - 1
+                         y_s = torch.clamp(y_s, 0, self.cfg.hyperparams['dim_output']-1).to(torch.int64)
+
+                     match self.cfg.model_name:
+                         case ModelName.TABPFN:
+                             y_hat = self.model(x_s, y_s, x_q, task=self.cfg.task).squeeze(-1)
+                         case ModelName.TAB2D | ModelName.TAB2D_COL_ROW | ModelName.TAB2D_SDPA:
+                             y_hat = self.model(x_s, y_s, x_q, padding_features, padding_obs_support, padding_obs_query)
+
+                 # Convert bin id predictions to numerical values
+                 if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
+                     y_hat = torch.argmax(y_hat, dim=-1)
+                     y_hat = self.bins[y_hat] + self.bin_width / 2
+
+                 y_hat = y_hat.float()
+                 prediction_metrics_tracker.update(y_hat, y_q, train=False)
+
+         metrics_eval = prediction_metrics_tracker.get_metrics()
+         return metrics_eval
+
+
+     def predict(self, x_support: np.ndarray, y_support: np.ndarray, x_query: np.ndarray) -> np.ndarray:
+
+         x_support_transformed = self.preprocessor.transform_X(x_support)
+         x_query_transformed = self.preprocessor.transform_X(x_query)
+         y_support_transformed = self.preprocessor.transform_y(y_support)
+
+         dataset = DatasetFinetune(
+             self.cfg,
+             x_support = x_support_transformed,
+             y_support = y_support_transformed,
+             x_query = x_query_transformed,
+             y_query = None,
+             max_samples_support = self.cfg.hyperparams['max_samples_support'],
+             max_samples_query = self.cfg.hyperparams['max_samples_query'],
+         )
+
+         loader = self.make_loader(dataset, training=False)
+         self.model.eval()
+
+         y_pred_list = []
+
+         with torch.no_grad():
+             for batch in loader:
+
+                 with torch.autocast(device_type=self.device, dtype=getattr(torch, self.cfg.hyperparams['precision'])):
+
+                     x_s = batch['x_support'].to(self.device, non_blocking=True)
+                     y_s = batch['y_support'].to(self.device, non_blocking=True)
+                     x_q = batch['x_query'].to(self.device, non_blocking=True)
+                     padding_features = batch['padding_features'].to(self.device, non_blocking=True)
+                     padding_obs_support = batch['padding_obs_support'].to(self.device, non_blocking=True)
+                     padding_obs_query = batch['padding_obs_query'].to(self.device, non_blocking=True)
+
+                     # Convert numerical y_support to bin ids
+                     if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
+                         y_s = torch.bucketize(y_s, self.bins) - 1
+                         y_s = torch.clamp(y_s, 0, self.cfg.hyperparams['dim_output']-1).to(torch.int64)
+
+                     match self.cfg.model_name:
+                         case ModelName.TABPFN:
+                             y_hat = self.model(x_s, y_s, x_q, task=self.cfg.task).squeeze(-1)
+                         case ModelName.TAB2D | ModelName.TAB2D_COL_ROW | ModelName.TAB2D_SDPA:
+                             y_hat = self.model(x_s, y_s, x_q, padding_features, padding_obs_support, padding_obs_query)
+
+                 y_hat = y_hat[0].float().cpu().numpy()
+
+                 # Convert bin id predictions to numerical values
+                 if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
+                     y_hat = np.argmax(y_hat, axis=-1)
+                     y_hat = (self.bins[y_hat] + self.bin_width / 2).cpu().numpy()
+
+                 y_hat = self.preprocessor.inverse_transform_y(y_hat)
+                 y_pred_list.append(y_hat)
+
+         y_pred = np.concatenate(y_pred_list, axis=0)
+
+         return y_pred
+
+
+     def load_params(self, path):
+         self.model.load_state_dict(torch.load(path))
+
+
+     def make_loader(self, dataset: torch.utils.data.Dataset, training: bool) -> torch.utils.data.DataLoader:
+
+         match self.cfg.model_name:
+             case ModelName.TABPFN:
+                 pad_to_max_features = True
+             case ModelName.TAB2D | ModelName.TAB2D_COL_ROW | ModelName.TAB2D_SDPA:
+                 pad_to_max_features = False
+             case _:
+                 raise NotImplementedError(f"Model {self.cfg.model_name} not implemented")
+
+         return torch.utils.data.DataLoader(
+             dataset,
+             batch_size=1,
+             shuffle=training,
+             pin_memory=True,
+             num_workers=0,
+             drop_last=False,
+             collate_fn=CollatorWithPadding(
+                 max_features=self.cfg.hyperparams['dim_embedding'],
+                 pad_to_max_features=pad_to_max_features
+             ),
+         )
+
+
+     def log_start_metrics(self, metrics_valid: PredictionMetrics):
+
+         match self.cfg.task:
+             case Task.REGRESSION:
+                 logger.info((
+                     f"Epoch 000 "
+                     f"| Train MSE: -.---- "
+                     f"| Train MAE: -.---- "
+                     f"| Train r2: -.---- "
+                     f"| Val MSE: {metrics_valid.metrics[MetricName.MSE]:.4f} "
+                     f"| Val MAE: {metrics_valid.metrics[MetricName.MAE]:.4f} "
+                     f"| Val r2: {metrics_valid.metrics[MetricName.R2]:.4f}"
+                 ))
+             case Task.CLASSIFICATION:
+                 logger.info((
+                     f"Epoch 000 "
+                     f"| Train CE: -.---- "
+                     f"| Train acc: -.---- "
+                     f"| Val CE: {metrics_valid.metrics[MetricName.LOG_LOSS]:.4f} "
+                     f"| Val acc: {metrics_valid.metrics[MetricName.ACCURACY]:.4f}"
+                 ))
+
+
+     def log_metrics(self, epoch: int, metrics_train: PredictionMetrics, metrics_valid: PredictionMetrics):
+
+         match self.cfg.task:
+             case Task.REGRESSION:
+                 logger.info((
+                     f"Epoch {epoch:03d} "
+                     f"| Train MSE: {metrics_train.metrics[MetricName.MSE]:.4f} "
+                     f"| Train MAE: {metrics_train.metrics[MetricName.MAE]:.4f} "
+                     f"| Train r2: {metrics_train.metrics[MetricName.R2]:.4f} "
+                     f"| Val MSE: {metrics_valid.metrics[MetricName.MSE]:.4f} "
+                     f"| Val MAE: {metrics_valid.metrics[MetricName.MAE]:.4f} "
+                     f"| Val r2: {metrics_valid.metrics[MetricName.R2]:.4f}"
+                 ))
+             case Task.CLASSIFICATION:
+                 logger.info((
+                     f"Epoch {epoch:03d} "
+                     f"| Train CE: {metrics_train.metrics[MetricName.LOG_LOSS]:.4f} "
+                     f"| Train acc: {metrics_train.metrics[MetricName.ACCURACY]:.4f} "
+                     f"| Val CE: {metrics_valid.metrics[MetricName.LOG_LOSS]:.4f} "
+                     f"| Val acc: {metrics_valid.metrics[MetricName.ACCURACY]:.4f}"
+                 ))
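For reference, the warmup-to-plateau handoff that TrainerFinetune.train performs at the end of each epoch can be reproduced standalone with plain torch objects (the names and numbers below are illustrative, not taken from the config):

    import torch

    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
    warmup_steps = 3
    scheduler_warmup = torch.optim.lr_scheduler.LinearLR(
        optimizer, start_factor=1.0 / warmup_steps, end_factor=1.0, total_iters=warmup_steps
    )
    scheduler_plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, patience=5, factor=0.2)

    for epoch in range(1, 11):
        val_loss = 1.0 / epoch  # stands in for metrics_valid.loss
        if epoch < warmup_steps:
            scheduler_warmup.step()           # ramp the learning rate up during warmup
        else:
            scheduler_plateau.step(val_loss)  # then cut it when validation loss plateaus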