autogluon.tabular 1.3.2b20250713__py3-none-any.whl → 1.3.2b20250714__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- autogluon/tabular/models/__init__.py +1 -0
- autogluon/tabular/models/mitra/__init__.py +0 -0
- autogluon/tabular/models/mitra/_internal/config/config_pretrain.py +190 -0
- autogluon/tabular/models/mitra/_internal/config/config_run.py +32 -0
- autogluon/tabular/models/mitra/_internal/config/enums.py +145 -0
- autogluon/tabular/models/mitra/_internal/core/callbacks.py +94 -0
- autogluon/tabular/models/mitra/_internal/core/get_loss.py +55 -0
- autogluon/tabular/models/mitra/_internal/core/get_optimizer.py +108 -0
- autogluon/tabular/models/mitra/_internal/core/get_scheduler.py +67 -0
- autogluon/tabular/models/mitra/_internal/core/prediction_metrics.py +134 -0
- autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py +367 -0
- autogluon/tabular/models/mitra/_internal/data/collator.py +46 -0
- autogluon/tabular/models/mitra/_internal/data/dataset_finetune.py +132 -0
- autogluon/tabular/models/mitra/_internal/data/dataset_split.py +53 -0
- autogluon/tabular/models/mitra/_internal/data/preprocessor.py +420 -0
- autogluon/tabular/models/mitra/_internal/models/base.py +21 -0
- autogluon/tabular/models/mitra/_internal/models/embedding.py +182 -0
- autogluon/tabular/models/mitra/_internal/models/tab2d.py +667 -0
- autogluon/tabular/models/mitra/_internal/utils/set_seed.py +15 -0
- autogluon/tabular/models/mitra/mitra_model.py +214 -0
- autogluon/tabular/models/mitra/sklearn_interface.py +462 -0
- autogluon/tabular/registry/_ag_model_registry.py +2 -0
- autogluon/tabular/version.py +1 -1
- {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250714.dist-info}/METADATA +19 -10
- {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250714.dist-info}/RECORD +32 -12
- /autogluon.tabular-1.3.2b20250713-py3.9-nspkg.pth → /autogluon.tabular-1.3.2b20250714-py3.9-nspkg.pth +0 -0
- {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250714.dist-info}/LICENSE +0 -0
- {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250714.dist-info}/NOTICE +0 -0
- {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250714.dist-info}/WHEEL +0 -0
- {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250714.dist-info}/namespace_packages.txt +0 -0
- {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250714.dist-info}/top_level.txt +0 -0
- {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250714.dist-info}/zip-safe +0 -0
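
The new Mitra tabular foundation model ships as mitra_model.py plus an _internal package, and _ag_model_registry.py gains two lines to register it with the rest of the model zoo. A hedged usage sketch follows, assuming the model is exposed to TabularPredictor under a "MITRA" hyperparameter key; this diff registers the model but does not show the key, so treat the name as an assumption.

import pandas as pd
from autogluon.tabular import TabularPredictor

# Tiny synthetic frame purely for illustration.
train_data = pd.DataFrame({
    "f1": [1.0, 2.0, 3.0, 4.0],
    "f2": [0.1, 0.2, 0.3, 0.4],
    "target": [0, 1, 0, 1],
})

# "MITRA" as the hyperparameter key is an assumption; check the entry added to
# _ag_model_registry.py for the exact name.
predictor = TabularPredictor(label="target").fit(train_data, hyperparameters={"MITRA": {}})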
autogluon/tabular/models/mitra/_internal/core/prediction_metrics.py
@@ -0,0 +1,134 @@
+from dataclasses import dataclass
+from typing import Self
+
+import numpy as np
+import scipy
+import torch
+from loguru import logger
+from sklearn.metrics import f1_score, mean_squared_error, r2_score, roc_auc_score, root_mean_squared_error
+
+from ..._internal.data.preprocessor import Preprocessor
+from ..._internal.config.enums import MetricName, Task
+
+
+@dataclass
+class PredictionMetrics():
+    task: Task
+    loss: float
+    score: float
+    metrics: dict[MetricName, float]
+
+
+    @classmethod
+    def from_prediction(cls, y_pred: np.ndarray, y_true: np.ndarray, task: Task) -> Self:
+
+        loss, score, metrics = compute_metrics(y_pred, y_true, task)
+
+        return cls(task=task, loss=loss, score=score, metrics=metrics)
+
+
+def compute_metrics(y_pred: np.ndarray, y_true: np.ndarray, task: Task) -> tuple[float, float, dict]:
+
+    match task:
+        case Task.CLASSIFICATION:
+            return compute_classification_metrics(y_pred, y_true)
+        case Task.REGRESSION:
+            return compute_regression_metrics(y_pred, y_true)
+
+
+def compute_classification_metrics(y_pred: np.ndarray, y_true: np.ndarray) -> tuple[float, float, dict]:
+    # predictions are assumed to be log-probabilities
+
+    y_pred_class = np.argmax(y_pred, axis=1)
+    y_pred_proba = scipy.special.softmax(y_pred, axis=1)
+    y_pred_proba = y_pred_proba / y_pred_proba.sum(axis=1, keepdims=True) # softmax not completely numerically stable, so a small correction is needed
+    labels = np.arange(y_pred_proba.shape[1])
+
+    metrics = {
+        MetricName.ACCURACY: (y_true == y_pred_class).mean(),
+        MetricName.F1: f1_score(y_true, y_pred_class, average="weighted"),
+        MetricName.AUC: roc_auc_score_multiclass(y_true, y_pred_proba, multi_class='ovo', average='macro', labels=labels),
+        MetricName.LOG_LOSS: torch.nn.functional.cross_entropy(torch.from_numpy(y_pred), torch.from_numpy(y_true)).item()
+    }
+
+    loss = metrics[MetricName.LOG_LOSS]
+    score = metrics[MetricName.ACCURACY]
+
+    return loss, score, metrics
+
+
+def roc_auc_score_multiclass(y_true, y_pred_proba, multi_class='ovo', average='macro', labels=None) -> float:
+    """
+    The roc_auc_score multi_class is not supported for binary classification
+    """
+
+    if np.unique(y_true).shape[0] == 1:
+        # AUC is not defined if there is only one class
+        return float('nan')
+
+    try:
+        if y_pred_proba.shape[1] == 2:
+            return roc_auc_score(y_true, y_pred_proba[:, 1])
+        else:
+            return roc_auc_score(y_true, y_pred_proba, multi_class=multi_class, average=average, labels=labels)
+    except ValueError as e:
+        logger.error(f"Error computing roc_auc_score: {e}")
+        logger.error(f"Returning {-1}")
+        return -1
+
+
+def compute_regression_metrics(y_pred: np.ndarray, y_true: np.ndarray) -> tuple[float, float, dict]:
+
+    metrics = {
+        MetricName.RMSE: root_mean_squared_error(y_true, y_pred),
+        MetricName.MSE: mean_squared_error(y_true, y_pred),
+        MetricName.MAE: np.abs(y_true - y_pred).mean(),
+        MetricName.R2: r2_score(y_true, y_pred)
+    }
+
+    loss = metrics[MetricName.MSE]
+    score = metrics[MetricName.R2]
+
+    return loss, score, metrics
+
+
+class PredictionMetricsTracker():
+    """
+    Prediction metrics tracker that accumulates predictions and true values to compute metrics at the end.
+    Uses torch.Tensor for predictions and true values.
+    """
+
+    def __init__(self, task: Task, preprocessor: Preprocessor) -> None:
+
+        self.task = task
+        self.preprocessor = preprocessor
+        self.reset()
+
+
+    def reset(self) -> None:
+
+        self.ys_pred: list[np.ndarray] = []
+        self.ys_true: list[np.ndarray] = []
+
+
+    def update(self, y_pred: torch.Tensor, y_true: torch.Tensor, train: bool) -> None:
+
+        y_pred_np = y_pred.detach().cpu().numpy()[0]
+        y_pred_ori = self.preprocessor.inverse_transform_y(y_pred_np)
+
+        y_true_np = y_true.detach().cpu().numpy()[0]
+        if train:
+            y_true_ori = self.preprocessor.inverse_transform_y(y_true_np)
+        else:
+            y_true_ori = y_true_np
+
+        self.ys_pred.append(y_pred_ori)
+        self.ys_true.append(y_true_ori)
+
+
+    def get_metrics(self) -> PredictionMetrics:
+
+        y_pred = np.concatenate(self.ys_pred, axis=0)
+        y_true = np.concatenate(self.ys_true, axis=0)
+
+        return PredictionMetrics.from_prediction(y_pred, y_true, self.task)
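
In compute_classification_metrics above, y_pred is treated as raw logits: argmax yields the predicted class, a re-normalized softmax feeds roc_auc_score, and torch.nn.functional.cross_entropy supplies the log loss that doubles as the reported loss. A minimal sketch of calling the dataclass entry point directly; the import paths are taken from the file list above and the random data is purely illustrative.

import numpy as np

from autogluon.tabular.models.mitra._internal.config.enums import Task
from autogluon.tabular.models.mitra._internal.core.prediction_metrics import PredictionMetrics

rng = np.random.default_rng(0)
y_true = rng.integers(0, 3, size=100)    # 3-class target
logits = rng.normal(size=(100, 3))       # raw model scores (log-probabilities / logits)

metrics = PredictionMetrics.from_prediction(logits, y_true, Task.CLASSIFICATION)
print(metrics.loss)     # cross-entropy (LOG_LOSS)
print(metrics.score)    # accuracy
print(metrics.metrics)  # full dict: ACCURACY, F1, AUC, LOG_LOSS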
autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py
@@ -0,0 +1,367 @@
+import time
+import numpy as np
+import torch
+from loguru import logger
+from sklearn.base import BaseEstimator
+import torch.nn.functional as F
+
+from ..._internal.config.config_run import ConfigRun
+from ..._internal.core.callbacks import Checkpoint, EarlyStopping
+from ..._internal.data.collator import CollatorWithPadding
+from ..._internal.config.enums import MetricName, ModelName, Task, LossName
+from ..._internal.core.get_loss import get_loss
+from ..._internal.core.get_optimizer import get_optimizer, GradScaler
+from ..._internal.core.get_scheduler import get_scheduler
+from ..._internal.data.dataset_finetune import DatasetFinetune, DatasetFinetuneGenerator
+from ..._internal.data.preprocessor import Preprocessor
+from ..._internal.core.prediction_metrics import PredictionMetrics, PredictionMetricsTracker
+
+
+class TrainerFinetune(BaseEstimator):
+
+    def __init__(
+        self,
+        cfg: ConfigRun,
+        model: torch.nn.Module,
+        n_classes: int,
+        device: str
+    ) -> None:
+
+        self.cfg = cfg
+        self.device = device
+        self.model = model.to(self.device, non_blocking=True)
+        self.n_classes = n_classes
+
+        self.loss = get_loss(self.cfg)
+        self.optimizer = get_optimizer(self.cfg.hyperparams, self.model)
+        self.scheduler_warmup, self.scheduler_reduce_on_plateau = get_scheduler(self.cfg.hyperparams, self.optimizer)
+        self.scaler = GradScaler(
+            enabled=self.cfg.hyperparams['grad_scaler_enabled'],
+            scale_init=self.cfg.hyperparams['grad_scaler_scale_init'],
+            scale_min=self.cfg.hyperparams['grad_scaler_scale_min'],
+            growth_interval=self.cfg.hyperparams['grad_scaler_growth_interval'],
+            device=self.device
+        )
+
+        self.early_stopping = EarlyStopping(patience=self.cfg.hyperparams['early_stopping_patience'])
+        self.checkpoint = Checkpoint()
+        self.preprocessor = Preprocessor(
+            dim_embedding=self.cfg.hyperparams['dim_embedding'],
+            n_classes=self.n_classes,
+            dim_output=self.cfg.hyperparams['dim_output'],
+            use_quantile_transformer=self.cfg.hyperparams['use_quantile_transformer'],
+            use_feature_count_scaling=self.cfg.hyperparams['use_feature_count_scaling'],
+            use_random_transforms=self.cfg.hyperparams['use_random_transforms'],
+            shuffle_classes=self.cfg.hyperparams['shuffle_classes'],
+            shuffle_features=self.cfg.hyperparams['shuffle_features'],
+            random_mirror_x=self.cfg.hyperparams['random_mirror_x'],
+            random_mirror_regression=self.cfg.hyperparams['random_mirror_regression'],
+            task=self.cfg.task
+        )
+
+        self.checkpoint.reset(self.model)
+
+        if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
+            self.bins = torch.linspace(-0.5, 1.5, self.cfg.hyperparams['dim_output']+1, device=cfg.device)
+            self.bin_width = self.bins[1] - self.bins[0]
+
+        self.metric = self.cfg.hyperparams['metric']
+
+
+    def train(self, x_train: np.ndarray, y_train: np.ndarray, x_val: np.ndarray, y_val: np.ndarray):
+
+        self.preprocessor.fit(x_train, y_train)
+
+        x_train_transformed = self.preprocessor.transform_X(x_train)
+        y_train_transformed = self.preprocessor.transform_y(y_train)
+
+        dataset_train_generator = DatasetFinetuneGenerator(
+            self.cfg,
+            x = x_train_transformed,
+            y = y_train_transformed,
+            task = self.cfg.task,
+            max_samples_support = self.cfg.hyperparams['max_samples_support'],
+            max_samples_query = self.cfg.hyperparams['max_samples_query']
+        )
+
+        self.checkpoint.reset(self.model)
+
+        metrics_valid = self.evaluate(x_train, y_train, x_val, y_val)
+        self.log_start_metrics(metrics_valid)
+        self.checkpoint(self.model, metrics_valid.loss)
+
+        start_time = time.time()
+
+        for epoch in range(1, self.cfg.hyperparams['max_epochs']+1):
+
+            dataset_train = next(dataset_train_generator)
+            loader_train = self.make_loader(dataset_train, training=True)
+            self.model.train()
+
+            prediction_metrics_tracker = PredictionMetricsTracker(task=self.cfg.task, preprocessor=self.preprocessor)
+
+            for batch in loader_train:
+
+                with torch.autocast(device_type=self.device, dtype=getattr(torch, self.cfg.hyperparams['precision'])):
+
+                    x_support = batch['x_support'].to(self.device, non_blocking=True)
+                    y_support = batch['y_support'].to(self.device, non_blocking=True)
+                    x_query = batch['x_query'].to(self.device, non_blocking=True)
+                    y_query = batch['y_query'].to(self.device, non_blocking=True)
+                    padding_features = batch['padding_features'].to(self.device, non_blocking=True)
+                    padding_obs_support = batch['padding_obs_support'].to(self.device, non_blocking=True)
+                    padding_obs_query = batch['padding_obs_query'].to(self.device, non_blocking=True)
+
+                    # Convert numerical y_support to bin ids
+                    if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
+                        y_support = torch.bucketize(y_support, self.bins) - 1
+                        y_support = torch.clamp(y_support, 0, self.cfg.hyperparams['dim_output']-1).to(torch.int64)
+                        y_query_bin_ids = torch.bucketize(y_query, self.bins) - 1
+                        y_query_bin_ids = torch.clamp(y_query_bin_ids, 0, self.cfg.hyperparams['dim_output']-1).to(torch.int64)
+
+                    match self.cfg.model_name:
+                        case ModelName.TABPFN:
+                            y_hat = self.model(x_support, y_support, x_query, task=self.cfg.task).squeeze(-1)
+                        case ModelName.TAB2D | ModelName.TAB2D_COL_ROW | ModelName.TAB2D_SDPA:
+                            y_hat = self.model(x_support, y_support, x_query, padding_features, padding_obs_support, padding_obs_query)
+
+                    # Convert numerical y_query to bin ids
+                    if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
+                        loss = self.loss(y_hat, y_query_bin_ids)
+                    elif self.cfg.task == Task.CLASSIFICATION:
+                        # for b in range(y_support.shape[0]):
+                        #     unique_classes = len(torch.unique(torch.cat((y_support[b], y_query[b]))))
+                        #     y_hat[b, :, unique_classes:] = 0
+                        loss = self.loss(y_hat, y_query)
+                    else:
+                        loss = self.loss(y_hat, y_query)
+
+                self.optimizer.zero_grad()
+                self.scaler.scale(loss).backward()
+                self.scaler.step(self.optimizer)
+                self.scaler.update()
+
+                # Convert bin id predictions to numerical values
+                if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
+                    y_hat = torch.argmax(y_hat, dim=-1)
+                    y_hat = self.bins[y_hat] + self.bin_width / 2
+
+                y_hat = y_hat.float()
+                if self.cfg.task == Task.REGRESSION:
+                    prediction_metrics_tracker.update(y_hat, y_query, train=True)
+                else:
+                    prediction_metrics_tracker.update(y_hat, y_query, train=False)
+
+            metrics_train = prediction_metrics_tracker.get_metrics()
+            metrics_valid = self.evaluate(x_train, y_train, x_val, y_val)
+
+            self.log_metrics(epoch, metrics_train, metrics_valid)
+
+            self.checkpoint(self.model, metrics_valid.loss)
+
+            self.early_stopping(metrics_valid.metrics[self.metric])
+            if self.early_stopping.we_should_stop():
+                logger.info("Early stopping")
+                break
+
+            if self.cfg.hyperparams["budget"] is not None and self.cfg.hyperparams["budget"] > 0 and time.time() - start_time > self.cfg.hyperparams["budget"]:
+                logger.info("Time limit reached")
+                break
+
+            if epoch < self.cfg.hyperparams['warmup_steps']:
+                self.scheduler_warmup.step()
+            else:
+                self.scheduler_reduce_on_plateau.step(metrics_valid.loss)
+
+        self.checkpoint.set_to_best(self.model)
+
+
+    def evaluate(self, x_support: np.ndarray, y_support: np.ndarray, x_query: np.ndarray, y_query: np.ndarray) -> PredictionMetrics:
+
+        self.model.eval()
+
+        x_support_transformed = self.preprocessor.transform_X(x_support)
+        x_query_transformed = self.preprocessor.transform_X(x_query)
+        y_support_transformed = self.preprocessor.transform_y(y_support)
+        # y_query_transformed = self.preprocessor.transform_y(y_query)
+
+        dataset = DatasetFinetune(
+            self.cfg,
+            x_support = x_support_transformed,
+            y_support = y_support_transformed,
+            x_query = x_query_transformed,
+            y_query = y_query,
+            max_samples_support = self.cfg.hyperparams['max_samples_support'],
+            max_samples_query = self.cfg.hyperparams['max_samples_query'],
+        )
+
+        loader = self.make_loader(dataset, training=False)
+        prediction_metrics_tracker = PredictionMetricsTracker(task=self.cfg.task, preprocessor=self.preprocessor)
+
+        with torch.no_grad():
+            for batch in loader:
+
+                with torch.autocast(device_type=self.device, dtype=getattr(torch, self.cfg.hyperparams['precision'])):
+
+                    x_s = batch['x_support'].to(self.device, non_blocking=True)
+                    y_s = batch['y_support'].to(self.device, non_blocking=True)
+                    x_q = batch['x_query'].to(self.device, non_blocking=True)
+                    y_q = batch['y_query'].to(self.device, non_blocking=True)
+                    padding_features = batch['padding_features'].to(self.device, non_blocking=True)
+                    padding_obs_support = batch['padding_obs_support'].to(self.device, non_blocking=True)
+                    padding_obs_query = batch['padding_obs_query'].to(self.device, non_blocking=True)
+
+                    # Convert numerical y_support to bin ids
+                    if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
+                        y_s = torch.bucketize(y_s, self.bins) - 1
+                        y_s = torch.clamp(y_s, 0, self.cfg.hyperparams['dim_output']-1).to(torch.int64)
+
+                    match self.cfg.model_name:
+                        case ModelName.TABPFN:
+                            y_hat = self.model(x_s, y_s, x_q, task=self.cfg.task).squeeze(-1)
+                        case ModelName.TAB2D | ModelName.TAB2D_COL_ROW | ModelName.TAB2D_SDPA:
+                            y_hat = self.model(x_s, y_s, x_q, padding_features, padding_obs_support, padding_obs_query)
+
+                # Convert bin id predictions to numerical values
+                if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
+                    y_hat = torch.argmax(y_hat, dim=-1)
+                    y_hat = self.bins[y_hat] + self.bin_width / 2
+
+                y_hat = y_hat.float()
+                prediction_metrics_tracker.update(y_hat, y_q, train=False)
+
+        metrics_eval = prediction_metrics_tracker.get_metrics()
+        return metrics_eval
+
+
+    def predict(self, x_support: np.ndarray, y_support: np.ndarray, x_query: np.ndarray) -> np.ndarray:
+
+        x_support_transformed = self.preprocessor.transform_X(x_support)
+        x_query_transformed = self.preprocessor.transform_X(x_query)
+        y_support_transformed = self.preprocessor.transform_y(y_support)
+
+        dataset = DatasetFinetune(
+            self.cfg,
+            x_support = x_support_transformed,
+            y_support = y_support_transformed,
+            x_query = x_query_transformed,
+            y_query = None,
+            max_samples_support = self.cfg.hyperparams['max_samples_support'],
+            max_samples_query = self.cfg.hyperparams['max_samples_query'],
+        )
+
+        loader = self.make_loader(dataset, training=False)
+        self.model.eval()
+
+        y_pred_list = []
+
+        with torch.no_grad():
+            for batch in loader:
+
+                with torch.autocast(device_type=self.device, dtype=getattr(torch, self.cfg.hyperparams['precision'])):
+
+                    x_s = batch['x_support'].to(self.device, non_blocking=True)
+                    y_s = batch['y_support'].to(self.device, non_blocking=True)
+                    x_q = batch['x_query'].to(self.device, non_blocking=True)
+                    padding_features = batch['padding_features'].to(self.device, non_blocking=True)
+                    padding_obs_support = batch['padding_obs_support'].to(self.device, non_blocking=True)
+                    padding_obs_query = batch['padding_obs_query'].to(self.device, non_blocking=True)
+
+                    # Convert numerical y_support to bin ids
+                    if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
+                        y_s = torch.bucketize(y_s, self.bins) - 1
+                        y_s = torch.clamp(y_s, 0, self.cfg.hyperparams['dim_output']-1).to(torch.int64)
+
+                    match self.cfg.model_name:
+                        case ModelName.TABPFN:
+                            y_hat = self.model(x_s, y_s, x_q, task=self.cfg.task).squeeze(-1)
+                        case ModelName.TAB2D | ModelName.TAB2D_COL_ROW | ModelName.TAB2D_SDPA:
+                            y_hat = self.model(x_s, y_s, x_q, padding_features, padding_obs_support, padding_obs_query)
+
+                y_hat = y_hat[0].float().cpu().numpy()
+
+                # Convert bin id predictions to numerical values
+                if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
+                    y_hat = np.argmax(y_hat, axis=-1)
+                    y_hat = (self.bins[y_hat] + self.bin_width / 2).cpu().numpy()
+
+                y_hat = self.preprocessor.inverse_transform_y(y_hat)
+                y_pred_list.append(y_hat)
+
+        y_pred = np.concatenate(y_pred_list, axis=0)
+
+        return y_pred
+
+
+    def load_params(self, path):
+        self.model.load_state_dict(torch.load(path))
+
+
+    def make_loader(self, dataset: torch.utils.data.Dataset, training: bool) -> torch.utils.data.DataLoader:
+
+        match self.cfg.model_name:
+            case ModelName.TABPFN:
+                pad_to_max_features = True
+            case ModelName.TAB2D | ModelName.TAB2D_COL_ROW | ModelName.TAB2D_SDPA:
+                pad_to_max_features = False
+            case _:
+                raise NotImplementedError(f"Model {self.cfg.model_name} not implemented")
+
+        return torch.utils.data.DataLoader(
+            dataset,
+            batch_size=1,
+            shuffle=training,
+            pin_memory=True,
+            num_workers=0,
+            drop_last=False,
+            collate_fn=CollatorWithPadding(
+                max_features=self.cfg.hyperparams['dim_embedding'],
+                pad_to_max_features=pad_to_max_features
+            ),
+        )
+
+
+    def log_start_metrics(self, metrics_valid: PredictionMetrics):
+
+        match self.cfg.task:
+            case Task.REGRESSION:
+                logger.info((
+                    f"Epoch 000 "
+                    f"| Train MSE: -.---- "
+                    f"| Train MAE: -.---- "
+                    f"| Train r2: -.---- "
+                    f"| Val MSE: {metrics_valid.metrics[MetricName.MSE]:.4f} "
+                    f"| Val MAE: {metrics_valid.metrics[MetricName.MAE]:.4f} "
+                    f"| Val r2: {metrics_valid.metrics[MetricName.R2]:.4f}"
+                ))
+            case Task.CLASSIFICATION:
+                logger.info((
+                    f"Epoch 000 "
+                    f"| Train CE: -.---- "
+                    f"| Train acc: -.---- "
+                    f"| Val CE: {metrics_valid.metrics[MetricName.LOG_LOSS]:.4f} "
+                    f"| Val acc: {metrics_valid.metrics[MetricName.ACCURACY]:.4f}"
+                ))
+
+
+    def log_metrics(self, epoch: int, metrics_train: PredictionMetrics, metrics_valid: PredictionMetrics):
+
+        match self.cfg.task:
+            case Task.REGRESSION:
+                logger.info((
+                    f"Epoch {epoch:03d} "
+                    f"| Train MSE: {metrics_train.metrics[MetricName.MSE]:.4f} "
+                    f"| Train MAE: {metrics_train.metrics[MetricName.MAE]:.4f} "
+                    f"| Train r2: {metrics_train.metrics[MetricName.R2]:.4f} "
+                    f"| Val MSE: {metrics_valid.metrics[MetricName.MSE]:.4f} "
+                    f"| Val MAE: {metrics_valid.metrics[MetricName.MAE]:.4f} "
+                    f"| Val r2: {metrics_valid.metrics[MetricName.R2]:.4f}"
+                ))
+            case Task.CLASSIFICATION:
+                logger.info((
+                    f"Epoch {epoch:03d} "
+                    f"| Train CE: {metrics_train.metrics[MetricName.LOG_LOSS]:.4f} "
+                    f"| Train acc: {metrics_train.metrics[MetricName.ACCURACY]:.4f} "
+                    f"| Val CE: {metrics_valid.metrics[MetricName.LOG_LOSS]:.4f} "
+                    f"| Val acc: {metrics_valid.metrics[MetricName.ACCURACY]:.4f}"
+                ))
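
When the task is regression and regression_loss is LossName.CROSS_ENTROPY, the trainer above recasts regression as classification: targets (already scaled to roughly [0, 1] by the preprocessor) are bucketized into dim_output equal-width bins spanning [-0.5, 1.5], the model is trained with cross-entropy against the bin ids, and predicted bins are mapped back to bin centers. A standalone sketch of that conversion; the values are illustrative.

import torch

dim_output = 8                                       # stands in for cfg.hyperparams['dim_output']
bins = torch.linspace(-0.5, 1.5, dim_output + 1)     # bin edges, as built in __init__
bin_width = bins[1] - bins[0]

y = torch.tensor([0.03, 0.50, 0.97])                 # regression targets after preprocessor scaling
bin_ids = torch.clamp(torch.bucketize(y, bins) - 1, 0, dim_output - 1).to(torch.int64)

# The model is trained with cross-entropy against bin_ids; at prediction time the
# argmax bin is mapped back to its center:
y_back = bins[bin_ids] + bin_width / 2
print(bin_ids)   # bin indices in [0, dim_output - 1]
print(y_back)    # numeric predictions recovered from the bins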
autogluon/tabular/models/mitra/_internal/data/collator.py
@@ -0,0 +1,46 @@
+import torch
+
+
+class CollatorWithPadding():
+
+    def __init__(
+        self,
+        max_features: int,
+        pad_to_max_features: bool,
+    ) -> None:
+
+        self.max_features = max_features
+        self.pad_to_max_features = pad_to_max_features
+
+
+    def __call__(self, batch: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]:
+
+        max_support_samples = max(dataset['x_support'].shape[0] for dataset in batch)
+        max_query_samples = max(dataset['x_query'].shape[0] for dataset in batch)
+        max_features = max(dataset['x_support'].shape[1] for dataset in batch)
+
+        if self.pad_to_max_features:
+            max_features = self.max_features
+
+        batch_size = len(batch)
+
+        tensor_dict = {
+            'x_support': torch.zeros((batch_size, max_support_samples, max_features), dtype=batch[0]['x_support'].dtype),
+            'y_support': torch.full((batch_size, max_support_samples), fill_value=-100, dtype=batch[0]['y_support'].dtype),
+            'x_query': torch.zeros((batch_size, max_query_samples, max_features), dtype=batch[0]['x_query'].dtype),
+            'y_query': torch.full((batch_size, max_query_samples), fill_value=-100, dtype=batch[0]['y_query'].dtype),
+            'padding_features': torch.ones((batch_size, max_features), dtype=torch.bool),
+            'padding_obs_support': torch.ones((batch_size, max_support_samples), dtype=torch.bool),
+            'padding_obs_query': torch.ones((batch_size, max_query_samples), dtype=torch.bool),
+        }
+
+        for i, dataset in enumerate(batch):
+            tensor_dict['x_support'][i, :dataset['x_support'].shape[0], :dataset['x_support'].shape[1]] = dataset['x_support']
+            tensor_dict['y_support'][i, :dataset['y_support'].shape[0]] = dataset['y_support']
+            tensor_dict['x_query'][i, :dataset['x_query'].shape[0], :dataset['x_support'].shape[1]] = dataset['x_query']
+            tensor_dict['y_query'][i, :dataset['y_query'].shape[0]] = dataset['y_query']
+            tensor_dict['padding_features'][i, :dataset['x_support'].shape[1]] = False
+            tensor_dict['padding_obs_support'][i, :dataset['x_support'].shape[0]] = False
+            tensor_dict['padding_obs_query'][i, :dataset['x_query'].shape[0]] = False
+
+        return tensor_dict
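
The collator right-pads every item in the batch to the largest support size, query size, and feature count present (or to max_features when pad_to_max_features is set), and records which positions are padding via boolean masks, where True means padded. A small usage sketch with two differently sized items; the import path is taken from the file list above.

import torch
from autogluon.tabular.models.mitra._internal.data.collator import CollatorWithPadding

collate = CollatorWithPadding(max_features=4, pad_to_max_features=False)
batch = [
    {'x_support': torch.zeros(3, 2), 'y_support': torch.zeros(3),
     'x_query': torch.zeros(5, 2), 'y_query': torch.zeros(5)},
    {'x_support': torch.zeros(6, 3), 'y_support': torch.zeros(6),
     'x_query': torch.zeros(2, 3), 'y_query': torch.zeros(2)},
]
out = collate(batch)
print(out['x_support'].shape)       # torch.Size([2, 6, 3]): padded to the largest item
print(out['padding_obs_support'])   # True marks padded (absent) observations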
autogluon/tabular/models/mitra/_internal/data/dataset_finetune.py
@@ -0,0 +1,132 @@
+from typing import Optional
+
+import numpy as np
+import torch
+
+from ..._internal.config.config_run import ConfigRun
+from ..._internal.data.dataset_split import make_dataset_split
+from ..._internal.config.enums import Task
+
+
+class DatasetFinetune(torch.utils.data.Dataset):
+    """
+    The main goal of this class is to generate a dataset for fine-tuning.
+    The input data are the full (x_support, y_support, x_query, y_query)
+    But these arrays are too large to be pushed through the model at once.
+    So here we split query the data into chunks if the query data is too large.
+    If the support data is too large, we randomly sample from it.
+    Furthermore, we transition from numpy to tensors.
+    """
+
+    def __init__(
+        self,
+        cfg: ConfigRun,
+        x_support: np.ndarray,
+        y_support: np.ndarray,
+        x_query: np.ndarray,
+        y_query: Optional[np.ndarray],
+        max_samples_support: int,
+        max_samples_query: int
+    ):
+        """
+        :param: max_features: number of features the tab pfn model has been trained on
+        """
+
+        self.cfg = cfg
+
+        self.x_support = x_support
+        self.y_support = y_support
+        self.x_query = x_query
+        self.y_query = y_query
+
+        if self.y_query is None:
+            self.y_query = np.zeros((self.x_query.shape[0],)) - 1
+
+        self.max_samples_support = max_samples_support
+        self.max_samples_query = max_samples_query
+
+        self.x_queries = self.split_in_chunks(self.x_query, max_samples_query)
+        self.y_queries = self.split_in_chunks(self.y_query, max_samples_query)
+
+        self.n_samples_support = self.x_support.shape[0]
+
+        # We push the whole training data through the model, unless it is too large
+        self.support_size = min(self.max_samples_support, self.n_samples_support)
+
+
+    def __len__(self):
+        return len(self.x_queries)
+
+    def __getitem__(self, idx):
+
+        support_indices = np.random.choice(
+            self.n_samples_support,
+            size=self.support_size,
+            replace=False
+        )
+
+        x_support = self.x_support[support_indices]
+        y_support = self.y_support[support_indices]
+
+        x_support_tensor = torch.as_tensor(x_support)
+        y_support_tensor = torch.as_tensor(y_support)
+        x_query_tensor = torch.as_tensor(self.x_queries[idx])
+        y_query_tensor = torch.as_tensor(self.y_queries[idx])
+
+        return {
+            'x_support': x_support_tensor,
+            'y_support': y_support_tensor,
+            'x_query': x_query_tensor,
+            'y_query': y_query_tensor,
+        }
+
+
+
+    def split_in_chunks(self, x: np.ndarray, batch_size: int) -> list[np.ndarray]:
+        """
+        Splits the data into chunks of size batch_size
+        """
+
+        n_chunks = int(np.ceil(x.shape[0] / batch_size))
+        x_chunks = []
+
+        for i in range(n_chunks):
+            x_chunks.append(x[i * batch_size: (i + 1) * batch_size])
+
+        return x_chunks
+
+def DatasetFinetuneGenerator(
+    cfg: ConfigRun,
+    x: np.ndarray,
+    y: np.ndarray,
+    task: Task,
+    max_samples_support: int,
+    max_samples_query: int
+):
+    """
+    The dataset fine-tune generator is a generator that yields a dataset for fine-tuning.
+    The idea is to split the training dataset into a support and query set.
+    Every single iteration, the generator yields a different support and query set split.
+    The dataset made always has exactly one batch.
+    """
+
+    while True:
+
+        x_support, x_query, y_support, y_query = make_dataset_split(x=x, y=y, task=task, seed=cfg.seed)
+        n_samples_support = x_support.shape[0]
+        n_samples_query = x_query.shape[0]
+
+        support_size = min(max_samples_support, n_samples_support)
+        query_size = min(max_samples_query, n_samples_query)
+
+        dataset_finetune = DatasetFinetune(
+            cfg=cfg,
+            x_support=x_support[:support_size],
+            y_support=y_support[:support_size],
+            x_query=x_query[:query_size],
+            y_query=y_query[:query_size],
+            max_samples_support=max_samples_support,
+            max_samples_query=max_samples_query,
+        )
+
+        yield dataset_finetune
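
DatasetFinetune subsamples the support set down to max_samples_support rows per item and splits the query set into chunks of at most max_samples_query rows, so len(dataset) equals the number of query chunks; DatasetFinetuneGenerator then yields a fresh support/query split of the training data on every next(). A minimal sketch of the chunking behaviour; the import path comes from the file list above, the synthetic data is illustrative, and cfg is only stored by this class, so None suffices here.

import numpy as np
from autogluon.tabular.models.mitra._internal.data.dataset_finetune import DatasetFinetune

rng = np.random.default_rng(0)
x = rng.random((1000, 10)).astype(np.float32)
y = rng.random(1000).astype(np.float32)

dataset = DatasetFinetune(
    cfg=None,                     # only stored, never used by this class
    x_support=x[:800], y_support=y[:800],
    x_query=x[800:], y_query=y[800:],
    max_samples_support=500,      # each item randomly subsamples 500 of the 800 support rows
    max_samples_query=64,         # 200 query rows -> ceil(200 / 64) = 4 chunks
)
print(len(dataset))                                                 # 4
print(dataset[0]['x_support'].shape, dataset[0]['x_query'].shape)   # (500, 10) and (64, 10)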