autogluon.tabular 1.3.2b20250713__py3-none-any.whl → 1.3.2b20250715__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- autogluon/tabular/models/__init__.py +1 -0
- autogluon/tabular/models/catboost/catboost_model.py +9 -6
- autogluon/tabular/models/catboost/catboost_utils.py +10 -0
- autogluon/tabular/models/lgb/lgb_model.py +2 -1
- autogluon/tabular/models/mitra/__init__.py +0 -0
- autogluon/tabular/models/mitra/_internal/config/config_pretrain.py +190 -0
- autogluon/tabular/models/mitra/_internal/config/config_run.py +32 -0
- autogluon/tabular/models/mitra/_internal/config/enums.py +145 -0
- autogluon/tabular/models/mitra/_internal/core/callbacks.py +94 -0
- autogluon/tabular/models/mitra/_internal/core/get_loss.py +55 -0
- autogluon/tabular/models/mitra/_internal/core/get_optimizer.py +108 -0
- autogluon/tabular/models/mitra/_internal/core/get_scheduler.py +67 -0
- autogluon/tabular/models/mitra/_internal/core/prediction_metrics.py +134 -0
- autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py +367 -0
- autogluon/tabular/models/mitra/_internal/data/collator.py +46 -0
- autogluon/tabular/models/mitra/_internal/data/dataset_finetune.py +132 -0
- autogluon/tabular/models/mitra/_internal/data/dataset_split.py +53 -0
- autogluon/tabular/models/mitra/_internal/data/preprocessor.py +420 -0
- autogluon/tabular/models/mitra/_internal/models/base.py +21 -0
- autogluon/tabular/models/mitra/_internal/models/embedding.py +182 -0
- autogluon/tabular/models/mitra/_internal/models/tab2d.py +667 -0
- autogluon/tabular/models/mitra/_internal/utils/set_seed.py +15 -0
- autogluon/tabular/models/mitra/mitra_model.py +214 -0
- autogluon/tabular/models/mitra/sklearn_interface.py +462 -0
- autogluon/tabular/registry/_ag_model_registry.py +2 -0
- autogluon/tabular/testing/fit_helper.py +2 -2
- autogluon/tabular/version.py +1 -1
- {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250715.dist-info}/METADATA +21 -12
- {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250715.dist-info}/RECORD +36 -16
- /autogluon.tabular-1.3.2b20250713-py3.9-nspkg.pth → /autogluon.tabular-1.3.2b20250715-py3.9-nspkg.pth +0 -0
- {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250715.dist-info}/LICENSE +0 -0
- {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250715.dist-info}/NOTICE +0 -0
- {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250715.dist-info}/WHEEL +0 -0
- {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250715.dist-info}/namespace_packages.txt +0 -0
- {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250715.dist-info}/top_level.txt +0 -0
- {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250715.dist-info}/zip-safe +0 -0
@@ -0,0 +1,108 @@
|
|
1
|
+
import torch
|
2
|
+
from torch.optim import SGD, Adam, AdamW
|
3
|
+
|
4
|
+
from ..._internal.config.config_pretrain import ConfigPretrain
|
5
|
+
|
6
|
+
|
7
|
+
def get_optimizer(hyperparams: dict, model: torch.nn.Module) -> torch.optim.Optimizer:
    """Build the fine-tuning optimizer selected by ``hyperparams['optimizer']``.

    Supported names are ``"adam"``, ``"adamw"`` (both with betas ``(0.9, 0.999)``)
    and ``"sgd"``. Every optimizer receives ``hyperparams['lr']`` and
    ``hyperparams['weight_decay']`` and optimizes all of ``model.parameters()``.

    Raises:
        ValueError: if the optimizer name is not one of the supported values.
    """
    name = hyperparams['optimizer']
    # Validate the name before reading any other keys, so an unknown optimizer
    # is reported as such even when lr / weight_decay are missing.
    if name not in ("adam", "adamw", "sgd"):
        raise ValueError("Optimizer not recognized")

    shared_kwargs = dict(lr=hyperparams['lr'], weight_decay=hyperparams['weight_decay'])

    if name == "sgd":
        return SGD(model.parameters(), **shared_kwargs)

    adam_family = Adam if name == "adam" else AdamW
    return adam_family(model.parameters(), betas=(0.9, 0.999), **shared_kwargs)
|
35
|
+
|
36
|
+
|
37
|
+
def get_optimizer_pretrain(cfg: ConfigPretrain, model: torch.nn.Module) -> torch.optim.Optimizer:
    """Build the pre-training AdamW optimizer with two parameter groups.

    Parameters whose name ends in ``bias`` or contains ``norm`` or ``embedding``
    are exempt from weight decay (group weight_decay 0.0); all others use
    ``cfg.optim.weight_decay``. Learning rate and betas come from ``cfg.optim``.
    """
    def _skip_decay(param_name: str) -> bool:
        return param_name.endswith("bias") or 'norm' in param_name or 'embedding' in param_name

    named = list(model.named_parameters())

    decayed: list = []
    not_decayed: list = []
    for param_name, tensor in named:
        target = not_decayed if _skip_decay(param_name) else decayed
        target.append(tensor)

    param_groups = [
        {"params": decayed, "weight_decay": cfg.optim.weight_decay},
        {"params": not_decayed, "weight_decay": 0.0},
    ]

    # The top-level weight_decay is only a default; each group above overrides it.
    return torch.optim.AdamW(
        param_groups,
        lr=cfg.optim.lr,
        betas=(cfg.optim.beta1, cfg.optim.beta2),
        weight_decay=cfg.optim.weight_decay,
    )
|
63
|
+
|
64
|
+
|
65
|
+
class GradScaler(torch.amp.GradScaler):
|
66
|
+
|
67
|
+
def __init__(
|
68
|
+
self,
|
69
|
+
enabled: bool = True,
|
70
|
+
scale_init: float = 2.**16,
|
71
|
+
scale_min: float = 1.,
|
72
|
+
growth_interval: int = 2000,
|
73
|
+
device: str = 'cuda'
|
74
|
+
):
|
75
|
+
super().__init__(enabled=enabled, device="cpu", init_scale=scale_init, growth_interval=growth_interval) # type: ignore
|
76
|
+
self._enabled = enabled
|
77
|
+
self.scale_min = scale_min
|
78
|
+
self.device = device
|
79
|
+
|
80
|
+
if not self._enabled:
|
81
|
+
# We write scale=1 to log if the scaler is disabled
|
82
|
+
self._scale = torch.tensor((1,), dtype=torch.float32, device=self.device)
|
83
|
+
|
84
|
+
|
85
|
+
def update(self):
|
86
|
+
|
87
|
+
if not self._enabled:
|
88
|
+
return
|
89
|
+
|
90
|
+
super().update()
|
91
|
+
|
92
|
+
if self._scale < self.scale_min:
|
93
|
+
super().update(self.scale_min)
|
94
|
+
|
95
|
+
|
96
|
+
def move_optimizer_to(optim, device):
    """Move every tensor held in ``optim.state`` (and its ``_grad``, if any) to ``device``, in place.

    Handles both state entries that are tensors themselves and the usual
    per-parameter dicts of tensors; other values are left untouched.
    """
    def _relocate(tensor: torch.Tensor) -> None:
        tensor.data = tensor.data.to(device)
        if tensor._grad is not None:
            tensor._grad.data = tensor._grad.data.to(device)

    for entry in optim.state.values():
        # Not sure there are any global tensors in the state dict
        if isinstance(entry, torch.Tensor):
            _relocate(entry)
            continue
        if not isinstance(entry, dict):
            continue
        for value in entry.values():
            if isinstance(value, torch.Tensor):
                _relocate(value)
|
@@ -0,0 +1,67 @@
|
|
1
|
+
import torch
|
2
|
+
from torch.optim.lr_scheduler import ReduceLROnPlateau, LinearLR
|
3
|
+
from transformers import get_constant_schedule_with_warmup
|
4
|
+
from transformers.optimization import get_cosine_with_min_lr_schedule_with_warmup
|
5
|
+
|
6
|
+
from ..._internal.config.config_pretrain import ConfigPretrain
|
7
|
+
|
8
|
+
|
9
|
+
def get_scheduler(
    hyperparams: dict,
    optimizer: torch.optim.Optimizer,
) -> tuple[torch.optim.lr_scheduler.LRScheduler, ReduceLROnPlateau]:
    """Create the fine-tuning learning-rate schedulers.

    Args:
        hyperparams: Must contain ``warmup_steps`` and ``lr_scheduler``; when
            ``lr_scheduler`` is truthy it must also contain ``lr_scheduler_patience``.
        optimizer: Optimizer whose learning rate is scheduled.

    Returns:
        ``(scheduler_warmup, scheduler_reduce_on_plateau)``. The warmup scheduler is a
        ``LinearLR`` going from ``lr / warmup_steps`` to ``lr`` over ``warmup_steps``
        calls to ``step()``, or a no-op ``ConstantLR`` when ``warmup_steps <= 0``.
        The plateau scheduler multiplies the learning rate by 0.2 (never below 0)
        after ``lr_scheduler_patience`` steps without improvement; when
        ``lr_scheduler`` is falsy its patience is so large that it never fires.
    """
    warmup_steps = hyperparams['warmup_steps']

    if warmup_steps > 0:
        scheduler_warmup = LinearLR(
            optimizer,
            start_factor=1.0 / warmup_steps,
            end_factor=1.0,
            total_iters=warmup_steps,
        )
    else:
        scheduler_warmup = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0, total_iters=1)

    # A disabled plateau scheduler is still a ReduceLROnPlateau, so callers can always
    # call .step(metric); the huge patience keeps it from ever reducing the LR.
    patience = hyperparams['lr_scheduler_patience'] if hyperparams['lr_scheduler'] else 1_000_000_000
    scheduler_reduce_on_plateau = ReduceLROnPlateau(
        optimizer,
        patience=patience,
        min_lr=0,
        factor=0.2,
    )

    return scheduler_warmup, scheduler_reduce_on_plateau
|
49
|
+
|
50
|
+
|
51
|
+
def get_scheduler_pretrain(cfg: ConfigPretrain, optimizer: torch.optim.Optimizer):
    """Return the pre-training LR schedule built with the transformers helpers.

    With ``cfg.optim.cosine_scheduler`` set, a warmup + cosine schedule over
    ``cfg.optim.steps`` steps with ``min_lr_rate=0.1``; otherwise a warmup
    followed by a constant learning rate. Both use ``cfg.optim.warmup_steps``.
    """
    use_cosine = cfg.optim.cosine_scheduler

    if not use_cosine:
        return get_constant_schedule_with_warmup(
            optimizer,
            num_warmup_steps=cfg.optim.warmup_steps,
        )

    return get_cosine_with_min_lr_schedule_with_warmup(
        optimizer,
        num_warmup_steps=cfg.optim.warmup_steps,
        num_training_steps=cfg.optim.steps,
        min_lr_rate=0.1,
    )
|
@@ -0,0 +1,134 @@
|
|
1
|
+
from dataclasses import dataclass
|
2
|
+
from typing import Self
|
3
|
+
|
4
|
+
import numpy as np
|
5
|
+
import scipy
|
6
|
+
import torch
|
7
|
+
from loguru import logger
|
8
|
+
from sklearn.metrics import f1_score, mean_squared_error, r2_score, roc_auc_score, root_mean_squared_error
|
9
|
+
|
10
|
+
from ..._internal.data.preprocessor import Preprocessor
|
11
|
+
from ..._internal.config.enums import MetricName, Task
|
12
|
+
|
13
|
+
|
14
|
+
@dataclass
class PredictionMetrics:
    """Headline loss/score plus the full metric table for one set of predictions.

    ``loss`` and ``score`` are the values selected by ``compute_metrics``
    (log loss / accuracy for classification, MSE / R2 for regression).
    """

    task: Task
    loss: float
    score: float
    metrics: dict[MetricName, float]

    @classmethod
    def from_prediction(cls, y_pred: np.ndarray, y_true: np.ndarray, task: Task) -> Self:
        """Compute all metrics for ``y_pred`` vs ``y_true`` and wrap them in an instance."""
        # compute_metrics returns (loss, score, metrics), matching the field order after ``task``.
        return cls(task, *compute_metrics(y_pred, y_true, task))
|
28
|
+
|
29
|
+
|
30
|
+
def compute_metrics(y_pred: np.ndarray, y_true: np.ndarray, task: Task) -> tuple[float, float, dict]:
|
31
|
+
|
32
|
+
match task:
|
33
|
+
case Task.CLASSIFICATION:
|
34
|
+
return compute_classification_metrics(y_pred, y_true)
|
35
|
+
case Task.REGRESSION:
|
36
|
+
return compute_regression_metrics(y_pred, y_true)
|
37
|
+
|
38
|
+
|
39
|
+
def compute_classification_metrics(y_pred: np.ndarray, y_true: np.ndarray) -> tuple[float, float, dict]:
    """Compute accuracy, weighted F1, ROC AUC and log loss for class scores.

    Args:
        y_pred: Per-class scores with classes along axis 1, treated as logits
            (unnormalised log-probabilities): softmax over axis 1 gives the
            probabilities, and ``cross_entropy`` consumes them directly.
        y_true: Class indices, one per row of ``y_pred``.

    Returns:
        ``(loss, score, metrics)`` where ``loss`` is the log loss, ``score`` is
        the accuracy, and ``metrics`` maps each ``MetricName`` to its value.
    """
    y_pred_class = np.argmax(y_pred, axis=1)
    y_pred_proba = scipy.special.softmax(y_pred, axis=1)
    # softmax not completely numerically stable, so a small correction is needed
    y_pred_proba = y_pred_proba / y_pred_proba.sum(axis=1, keepdims=True)
    labels = np.arange(y_pred_proba.shape[1])

    # cross_entropy needs int64 class-index targets; y_true is not guaranteed to be
    # int64 here (e.g. int32 or float label arrays), which would make it raise.
    y_true_tensor = torch.from_numpy(np.asarray(y_true, dtype=np.int64))

    metrics = {
        MetricName.ACCURACY: (y_true == y_pred_class).mean(),
        MetricName.F1: f1_score(y_true, y_pred_class, average="weighted"),
        MetricName.AUC: roc_auc_score_multiclass(y_true, y_pred_proba, multi_class='ovo', average='macro', labels=labels),
        MetricName.LOG_LOSS: torch.nn.functional.cross_entropy(torch.from_numpy(y_pred), y_true_tensor).item(),
    }

    loss = metrics[MetricName.LOG_LOSS]
    score = metrics[MetricName.ACCURACY]

    return loss, score, metrics
|
58
|
+
|
59
|
+
|
60
|
+
def roc_auc_score_multiclass(y_true, y_pred_proba, multi_class='ovo', average='macro', labels=None) -> float:
    """ROC AUC that works for both binary and multiclass probability matrices.

    The roc_auc_score multi_class option is not supported for binary
    classification, so with exactly two probability columns only the second
    (positive-class) column is scored; otherwise the multiclass arguments are
    forwarded as given.

    Returns:
        ``nan`` if ``y_true`` holds a single distinct class (AUC undefined),
        ``-1`` if sklearn raises ``ValueError`` (the error is logged), else the AUC.
    """
    distinct_classes = np.unique(y_true)
    if len(distinct_classes) == 1:
        # AUC is not defined if there is only one class
        return float('nan')

    two_columns = y_pred_proba.shape[1] == 2
    try:
        if two_columns:
            return roc_auc_score(y_true, y_pred_proba[:, 1])
        return roc_auc_score(y_true, y_pred_proba, multi_class=multi_class, average=average, labels=labels)
    except ValueError as err:
        logger.error(f"Error computing roc_auc_score: {err}")
        logger.error(f"Returning {-1}")
        return -1
|
78
|
+
|
79
|
+
|
80
|
+
def compute_regression_metrics(y_pred: np.ndarray, y_true: np.ndarray) -> tuple[float, float, dict]:
    """Compute RMSE, MSE, MAE and R2 for regression predictions.

    Returns:
        ``(loss, score, metrics)`` with ``loss`` = MSE, ``score`` = R2 and
        ``metrics`` mapping each ``MetricName`` to its value.
    """
    rmse = root_mean_squared_error(y_true, y_pred)
    mse = mean_squared_error(y_true, y_pred)
    mae = np.abs(y_true - y_pred).mean()
    r2 = r2_score(y_true, y_pred)

    metrics = {
        MetricName.RMSE: rmse,
        MetricName.MSE: mse,
        MetricName.MAE: mae,
        MetricName.R2: r2,
    }

    return mse, r2, metrics
|
93
|
+
|
94
|
+
|
95
|
+
class PredictionMetricsTracker:
    """
    Prediction metrics tracker that accumulates predictions and true values to compute metrics at the end.
    Uses torch.Tensor for predictions and true values.

    Predictions are always mapped back to the original target scale with
    ``preprocessor.inverse_transform_y``; true values are only mapped back when
    ``update`` is called with ``train=True`` (i.e. when they are still in the
    preprocessed scale).
    """

    def __init__(self, task: Task, preprocessor: Preprocessor) -> None:
        self.task = task
        self.preprocessor = preprocessor
        self.reset()

    def reset(self) -> None:
        """Drop all accumulated predictions and targets."""
        self.ys_pred: list[np.ndarray] = []
        self.ys_true: list[np.ndarray] = []

    def update(self, y_pred: torch.Tensor, y_true: torch.Tensor, train: bool) -> None:
        """Record one batch.

        Only index 0 of the leading (batch) dimension is kept. NOTE(review): this
        assumes batches of size 1, as produced by ``TrainerFinetune.make_loader``.
        """
        inverse = self.preprocessor.inverse_transform_y

        pred_first = y_pred.detach().cpu().numpy()[0]
        pred_original_scale = inverse(pred_first)

        true_first = y_true.detach().cpu().numpy()[0]
        true_original_scale = inverse(true_first) if train else true_first

        self.ys_pred.append(pred_original_scale)
        self.ys_true.append(true_original_scale)

    def get_metrics(self) -> PredictionMetrics:
        """Concatenate everything recorded so far and compute metrics for ``self.task``."""
        all_pred = np.concatenate(self.ys_pred, axis=0)
        all_true = np.concatenate(self.ys_true, axis=0)
        return PredictionMetrics.from_prediction(all_pred, all_true, self.task)
|
@@ -0,0 +1,367 @@
|
|
1
|
+
import time
|
2
|
+
import numpy as np
|
3
|
+
import torch
|
4
|
+
from loguru import logger
|
5
|
+
from sklearn.base import BaseEstimator
|
6
|
+
import torch.nn.functional as F
|
7
|
+
|
8
|
+
from ..._internal.config.config_run import ConfigRun
|
9
|
+
from ..._internal.core.callbacks import Checkpoint, EarlyStopping
|
10
|
+
from ..._internal.data.collator import CollatorWithPadding
|
11
|
+
from ..._internal.config.enums import MetricName, ModelName, Task, LossName
|
12
|
+
from ..._internal.core.get_loss import get_loss
|
13
|
+
from ..._internal.core.get_optimizer import get_optimizer, GradScaler
|
14
|
+
from ..._internal.core.get_scheduler import get_scheduler
|
15
|
+
from ..._internal.data.dataset_finetune import DatasetFinetune, DatasetFinetuneGenerator
|
16
|
+
from ..._internal.data.preprocessor import Preprocessor
|
17
|
+
from ..._internal.core.prediction_metrics import PredictionMetrics, PredictionMetricsTracker
|
18
|
+
|
19
|
+
|
20
|
+
class TrainerFinetune(BaseEstimator):
|
21
|
+
|
22
|
+
def __init__(
|
23
|
+
self,
|
24
|
+
cfg: ConfigRun,
|
25
|
+
model: torch.nn.Module,
|
26
|
+
n_classes: int,
|
27
|
+
device: str
|
28
|
+
) -> None:
|
29
|
+
|
30
|
+
self.cfg = cfg
|
31
|
+
self.device = device
|
32
|
+
self.model = model.to(self.device, non_blocking=True)
|
33
|
+
self.n_classes = n_classes
|
34
|
+
|
35
|
+
self.loss = get_loss(self.cfg)
|
36
|
+
self.optimizer = get_optimizer(self.cfg.hyperparams, self.model)
|
37
|
+
self.scheduler_warmup, self.scheduler_reduce_on_plateau = get_scheduler(self.cfg.hyperparams, self.optimizer)
|
38
|
+
self.scaler = GradScaler(
|
39
|
+
enabled=self.cfg.hyperparams['grad_scaler_enabled'],
|
40
|
+
scale_init=self.cfg.hyperparams['grad_scaler_scale_init'],
|
41
|
+
scale_min=self.cfg.hyperparams['grad_scaler_scale_min'],
|
42
|
+
growth_interval=self.cfg.hyperparams['grad_scaler_growth_interval'],
|
43
|
+
device=self.device
|
44
|
+
)
|
45
|
+
|
46
|
+
self.early_stopping = EarlyStopping(patience=self.cfg.hyperparams['early_stopping_patience'])
|
47
|
+
self.checkpoint = Checkpoint()
|
48
|
+
self.preprocessor = Preprocessor(
|
49
|
+
dim_embedding=self.cfg.hyperparams['dim_embedding'],
|
50
|
+
n_classes=self.n_classes,
|
51
|
+
dim_output=self.cfg.hyperparams['dim_output'],
|
52
|
+
use_quantile_transformer=self.cfg.hyperparams['use_quantile_transformer'],
|
53
|
+
use_feature_count_scaling=self.cfg.hyperparams['use_feature_count_scaling'],
|
54
|
+
use_random_transforms=self.cfg.hyperparams['use_random_transforms'],
|
55
|
+
shuffle_classes=self.cfg.hyperparams['shuffle_classes'],
|
56
|
+
shuffle_features=self.cfg.hyperparams['shuffle_features'],
|
57
|
+
random_mirror_x=self.cfg.hyperparams['random_mirror_x'],
|
58
|
+
random_mirror_regression=self.cfg.hyperparams['random_mirror_regression'],
|
59
|
+
task=self.cfg.task
|
60
|
+
)
|
61
|
+
|
62
|
+
self.checkpoint.reset(self.model)
|
63
|
+
|
64
|
+
if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
|
65
|
+
self.bins = torch.linspace(-0.5, 1.5, self.cfg.hyperparams['dim_output']+1, device=cfg.device)
|
66
|
+
self.bin_width = self.bins[1] - self.bins[0]
|
67
|
+
|
68
|
+
self.metric = self.cfg.hyperparams['metric']
|
69
|
+
|
70
|
+
|
71
|
+
def train(self, x_train: np.ndarray, y_train: np.ndarray, x_val: np.ndarray, y_val: np.ndarray):
|
72
|
+
|
73
|
+
self.preprocessor.fit(x_train, y_train)
|
74
|
+
|
75
|
+
x_train_transformed = self.preprocessor.transform_X(x_train)
|
76
|
+
y_train_transformed = self.preprocessor.transform_y(y_train)
|
77
|
+
|
78
|
+
dataset_train_generator = DatasetFinetuneGenerator(
|
79
|
+
self.cfg,
|
80
|
+
x = x_train_transformed,
|
81
|
+
y = y_train_transformed,
|
82
|
+
task = self.cfg.task,
|
83
|
+
max_samples_support = self.cfg.hyperparams['max_samples_support'],
|
84
|
+
max_samples_query = self.cfg.hyperparams['max_samples_query']
|
85
|
+
)
|
86
|
+
|
87
|
+
self.checkpoint.reset(self.model)
|
88
|
+
|
89
|
+
metrics_valid = self.evaluate(x_train, y_train, x_val, y_val)
|
90
|
+
self.log_start_metrics(metrics_valid)
|
91
|
+
self.checkpoint(self.model, metrics_valid.loss)
|
92
|
+
|
93
|
+
start_time = time.time()
|
94
|
+
|
95
|
+
for epoch in range(1, self.cfg.hyperparams['max_epochs']+1):
|
96
|
+
|
97
|
+
dataset_train = next(dataset_train_generator)
|
98
|
+
loader_train = self.make_loader(dataset_train, training=True)
|
99
|
+
self.model.train()
|
100
|
+
|
101
|
+
prediction_metrics_tracker = PredictionMetricsTracker(task=self.cfg.task, preprocessor=self.preprocessor)
|
102
|
+
|
103
|
+
for batch in loader_train:
|
104
|
+
|
105
|
+
with torch.autocast(device_type=self.device, dtype=getattr(torch, self.cfg.hyperparams['precision'])):
|
106
|
+
|
107
|
+
x_support = batch['x_support'].to(self.device, non_blocking=True)
|
108
|
+
y_support = batch['y_support'].to(self.device, non_blocking=True)
|
109
|
+
x_query = batch['x_query'].to(self.device, non_blocking=True)
|
110
|
+
y_query = batch['y_query'].to(self.device, non_blocking=True)
|
111
|
+
padding_features = batch['padding_features'].to(self.device, non_blocking=True)
|
112
|
+
padding_obs_support = batch['padding_obs_support'].to(self.device, non_blocking=True)
|
113
|
+
padding_obs_query = batch['padding_obs_query'].to(self.device, non_blocking=True)
|
114
|
+
|
115
|
+
# Convert numerical y_support to bin ids
|
116
|
+
if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
|
117
|
+
y_support = torch.bucketize(y_support, self.bins) - 1
|
118
|
+
y_support = torch.clamp(y_support, 0, self.cfg.hyperparams['dim_output']-1).to(torch.int64)
|
119
|
+
y_query_bin_ids = torch.bucketize(y_query, self.bins) - 1
|
120
|
+
y_query_bin_ids = torch.clamp(y_query_bin_ids, 0, self.cfg.hyperparams['dim_output']-1).to(torch.int64)
|
121
|
+
|
122
|
+
match self.cfg.model_name:
|
123
|
+
case ModelName.TABPFN:
|
124
|
+
y_hat = self.model(x_support, y_support, x_query, task=self.cfg.task).squeeze(-1)
|
125
|
+
case ModelName.TAB2D | ModelName.TAB2D_COL_ROW | ModelName.TAB2D_SDPA:
|
126
|
+
y_hat = self.model(x_support, y_support, x_query, padding_features, padding_obs_support, padding_obs_query)
|
127
|
+
|
128
|
+
# Convert numerical y_query to bin ids
|
129
|
+
if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
|
130
|
+
loss = self.loss(y_hat, y_query_bin_ids)
|
131
|
+
elif self.cfg.task == Task.CLASSIFICATION:
|
132
|
+
# for b in range(y_support.shape[0]):
|
133
|
+
# unique_classes = len(torch.unique(torch.cat((y_support[b], y_query[b]))))
|
134
|
+
# y_hat[b, :, unique_classes:] = 0
|
135
|
+
loss = self.loss(y_hat, y_query)
|
136
|
+
else:
|
137
|
+
loss = self.loss(y_hat, y_query)
|
138
|
+
|
139
|
+
self.optimizer.zero_grad()
|
140
|
+
self.scaler.scale(loss).backward()
|
141
|
+
self.scaler.step(self.optimizer)
|
142
|
+
self.scaler.update()
|
143
|
+
|
144
|
+
# Convert bin id predictions to numerical values
|
145
|
+
if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
|
146
|
+
y_hat = torch.argmax(y_hat, dim=-1)
|
147
|
+
y_hat = self.bins[y_hat] + self.bin_width / 2
|
148
|
+
|
149
|
+
y_hat = y_hat.float()
|
150
|
+
if self.cfg.task == Task.REGRESSION:
|
151
|
+
prediction_metrics_tracker.update(y_hat, y_query, train=True)
|
152
|
+
else:
|
153
|
+
prediction_metrics_tracker.update(y_hat, y_query, train=False)
|
154
|
+
|
155
|
+
metrics_train = prediction_metrics_tracker.get_metrics()
|
156
|
+
metrics_valid = self.evaluate(x_train, y_train, x_val, y_val)
|
157
|
+
|
158
|
+
self.log_metrics(epoch, metrics_train, metrics_valid)
|
159
|
+
|
160
|
+
self.checkpoint(self.model, metrics_valid.loss)
|
161
|
+
|
162
|
+
self.early_stopping(metrics_valid.metrics[self.metric])
|
163
|
+
if self.early_stopping.we_should_stop():
|
164
|
+
logger.info("Early stopping")
|
165
|
+
break
|
166
|
+
|
167
|
+
if self.cfg.hyperparams["budget"] is not None and self.cfg.hyperparams["budget"] > 0 and time.time() - start_time > self.cfg.hyperparams["budget"]:
|
168
|
+
logger.info("Time limit reached")
|
169
|
+
break
|
170
|
+
|
171
|
+
if epoch < self.cfg.hyperparams['warmup_steps']:
|
172
|
+
self.scheduler_warmup.step()
|
173
|
+
else:
|
174
|
+
self.scheduler_reduce_on_plateau.step(metrics_valid.loss)
|
175
|
+
|
176
|
+
self.checkpoint.set_to_best(self.model)
|
177
|
+
|
178
|
+
|
179
|
+
def evaluate(self, x_support: np.ndarray, y_support: np.ndarray, x_query: np.ndarray, y_query: np.ndarray) -> PredictionMetrics:
|
180
|
+
|
181
|
+
self.model.eval()
|
182
|
+
|
183
|
+
x_support_transformed = self.preprocessor.transform_X(x_support)
|
184
|
+
x_query_transformed = self.preprocessor.transform_X(x_query)
|
185
|
+
y_support_transformed = self.preprocessor.transform_y(y_support)
|
186
|
+
# y_query_transformed = self.preprocessor.transform_y(y_query)
|
187
|
+
|
188
|
+
dataset = DatasetFinetune(
|
189
|
+
self.cfg,
|
190
|
+
x_support = x_support_transformed,
|
191
|
+
y_support = y_support_transformed,
|
192
|
+
x_query = x_query_transformed,
|
193
|
+
y_query = y_query,
|
194
|
+
max_samples_support = self.cfg.hyperparams['max_samples_support'],
|
195
|
+
max_samples_query = self.cfg.hyperparams['max_samples_query'],
|
196
|
+
)
|
197
|
+
|
198
|
+
loader = self.make_loader(dataset, training=False)
|
199
|
+
prediction_metrics_tracker = PredictionMetricsTracker(task=self.cfg.task, preprocessor=self.preprocessor)
|
200
|
+
|
201
|
+
with torch.no_grad():
|
202
|
+
for batch in loader:
|
203
|
+
|
204
|
+
with torch.autocast(device_type=self.device, dtype=getattr(torch, self.cfg.hyperparams['precision'])):
|
205
|
+
|
206
|
+
x_s = batch['x_support'].to(self.device, non_blocking=True)
|
207
|
+
y_s = batch['y_support'].to(self.device, non_blocking=True)
|
208
|
+
x_q = batch['x_query'].to(self.device, non_blocking=True)
|
209
|
+
y_q = batch['y_query'].to(self.device, non_blocking=True)
|
210
|
+
padding_features = batch['padding_features'].to(self.device, non_blocking=True)
|
211
|
+
padding_obs_support = batch['padding_obs_support'].to(self.device, non_blocking=True)
|
212
|
+
padding_obs_query = batch['padding_obs_query'].to(self.device, non_blocking=True)
|
213
|
+
|
214
|
+
# Convert numerical y_support to bin ids
|
215
|
+
if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
|
216
|
+
y_s = torch.bucketize(y_s, self.bins) - 1
|
217
|
+
y_s = torch.clamp(y_s, 0, self.cfg.hyperparams['dim_output']-1).to(torch.int64)
|
218
|
+
|
219
|
+
match self.cfg.model_name:
|
220
|
+
case ModelName.TABPFN:
|
221
|
+
y_hat = self.model(x_s, y_s, x_q, task=self.cfg.task).squeeze(-1)
|
222
|
+
case ModelName.TAB2D | ModelName.TAB2D_COL_ROW | ModelName.TAB2D_SDPA:
|
223
|
+
y_hat = self.model(x_s, y_s, x_q, padding_features, padding_obs_support, padding_obs_query)
|
224
|
+
|
225
|
+
# Convert bin id predictions to numerical values
|
226
|
+
if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
|
227
|
+
y_hat = torch.argmax(y_hat, dim=-1)
|
228
|
+
y_hat = self.bins[y_hat] + self.bin_width / 2
|
229
|
+
|
230
|
+
y_hat = y_hat.float()
|
231
|
+
prediction_metrics_tracker.update(y_hat, y_q, train=False)
|
232
|
+
|
233
|
+
metrics_eval = prediction_metrics_tracker.get_metrics()
|
234
|
+
return metrics_eval
|
235
|
+
|
236
|
+
|
237
|
+
def predict(self, x_support: np.ndarray, y_support: np.ndarray, x_query: np.ndarray) -> np.ndarray:
|
238
|
+
|
239
|
+
x_support_transformed = self.preprocessor.transform_X(x_support)
|
240
|
+
x_query_transformed = self.preprocessor.transform_X(x_query)
|
241
|
+
y_support_transformed = self.preprocessor.transform_y(y_support)
|
242
|
+
|
243
|
+
dataset = DatasetFinetune(
|
244
|
+
self.cfg,
|
245
|
+
x_support = x_support_transformed,
|
246
|
+
y_support = y_support_transformed,
|
247
|
+
x_query = x_query_transformed,
|
248
|
+
y_query = None,
|
249
|
+
max_samples_support = self.cfg.hyperparams['max_samples_support'],
|
250
|
+
max_samples_query = self.cfg.hyperparams['max_samples_query'],
|
251
|
+
)
|
252
|
+
|
253
|
+
loader = self.make_loader(dataset, training=False)
|
254
|
+
self.model.eval()
|
255
|
+
|
256
|
+
y_pred_list = []
|
257
|
+
|
258
|
+
with torch.no_grad():
|
259
|
+
for batch in loader:
|
260
|
+
|
261
|
+
with torch.autocast(device_type=self.device, dtype=getattr(torch, self.cfg.hyperparams['precision'])):
|
262
|
+
|
263
|
+
x_s = batch['x_support'].to(self.device, non_blocking=True)
|
264
|
+
y_s = batch['y_support'].to(self.device, non_blocking=True)
|
265
|
+
x_q = batch['x_query'].to(self.device, non_blocking=True)
|
266
|
+
padding_features = batch['padding_features'].to(self.device, non_blocking=True)
|
267
|
+
padding_obs_support = batch['padding_obs_support'].to(self.device, non_blocking=True)
|
268
|
+
padding_obs_query = batch['padding_obs_query'].to(self.device, non_blocking=True)
|
269
|
+
|
270
|
+
# Convert numerical y_support to bin ids
|
271
|
+
if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
|
272
|
+
y_s = torch.bucketize(y_s, self.bins) - 1
|
273
|
+
y_s = torch.clamp(y_s, 0, self.cfg.hyperparams['dim_output']-1).to(torch.int64)
|
274
|
+
|
275
|
+
match self.cfg.model_name:
|
276
|
+
case ModelName.TABPFN:
|
277
|
+
y_hat = self.model(x_s, y_s, x_q, task=self.cfg.task).squeeze(-1)
|
278
|
+
case ModelName.TAB2D | ModelName.TAB2D_COL_ROW | ModelName.TAB2D_SDPA:
|
279
|
+
y_hat = self.model(x_s, y_s, x_q, padding_features, padding_obs_support, padding_obs_query)
|
280
|
+
|
281
|
+
y_hat = y_hat[0].float().cpu().numpy()
|
282
|
+
|
283
|
+
# Convert bin id predictions to numerical values
|
284
|
+
if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
|
285
|
+
y_hat = np.argmax(y_hat, axis=-1)
|
286
|
+
y_hat = (self.bins[y_hat] + self.bin_width / 2).cpu().numpy()
|
287
|
+
|
288
|
+
y_hat = self.preprocessor.inverse_transform_y(y_hat)
|
289
|
+
y_pred_list.append(y_hat)
|
290
|
+
|
291
|
+
y_pred = np.concatenate(y_pred_list, axis=0)
|
292
|
+
|
293
|
+
return y_pred
|
294
|
+
|
295
|
+
|
296
|
+
def load_params(self, path):
|
297
|
+
self.model.load_state_dict(torch.load(path))
|
298
|
+
|
299
|
+
|
300
|
+
def make_loader(self, dataset: torch.utils.data.Dataset, training: bool) -> torch.utils.data.DataLoader:
|
301
|
+
|
302
|
+
match self.cfg.model_name:
|
303
|
+
case ModelName.TABPFN:
|
304
|
+
pad_to_max_features = True
|
305
|
+
case ModelName.TAB2D | ModelName.TAB2D_COL_ROW | ModelName.TAB2D_SDPA:
|
306
|
+
pad_to_max_features = False
|
307
|
+
case _:
|
308
|
+
raise NotImplementedError(f"Model {self.cfg.model_name} not implemented")
|
309
|
+
|
310
|
+
return torch.utils.data.DataLoader(
|
311
|
+
dataset,
|
312
|
+
batch_size=1,
|
313
|
+
shuffle=training,
|
314
|
+
pin_memory=True,
|
315
|
+
num_workers=0,
|
316
|
+
drop_last=False,
|
317
|
+
collate_fn=CollatorWithPadding(
|
318
|
+
max_features=self.cfg.hyperparams['dim_embedding'],
|
319
|
+
pad_to_max_features=pad_to_max_features
|
320
|
+
),
|
321
|
+
)
|
322
|
+
|
323
|
+
|
324
|
+
def log_start_metrics(self, metrics_valid: PredictionMetrics):
|
325
|
+
|
326
|
+
match self.cfg.task:
|
327
|
+
case Task.REGRESSION:
|
328
|
+
logger.info((
|
329
|
+
f"Epoch 000 "
|
330
|
+
f"| Train MSE: -.---- "
|
331
|
+
f"| Train MAE: -.---- "
|
332
|
+
f"| Train r2: -.---- "
|
333
|
+
f"| Val MSE: {metrics_valid.metrics[MetricName.MSE]:.4f} "
|
334
|
+
f"| Val MAE: {metrics_valid.metrics[MetricName.MAE]:.4f} "
|
335
|
+
f"| Val r2: {metrics_valid.metrics[MetricName.R2]:.4f}"
|
336
|
+
))
|
337
|
+
case Task.CLASSIFICATION:
|
338
|
+
logger.info((
|
339
|
+
f"Epoch 000 "
|
340
|
+
f"| Train CE: -.---- "
|
341
|
+
f"| Train acc: -.---- "
|
342
|
+
f"| Val CE: {metrics_valid.metrics[MetricName.LOG_LOSS]:.4f} "
|
343
|
+
f"| Val acc: {metrics_valid.metrics[MetricName.ACCURACY]:.4f}"
|
344
|
+
))
|
345
|
+
|
346
|
+
|
347
|
+
def log_metrics(self, epoch: int, metrics_train: PredictionMetrics, metrics_valid: PredictionMetrics):
|
348
|
+
|
349
|
+
match self.cfg.task:
|
350
|
+
case Task.REGRESSION:
|
351
|
+
logger.info((
|
352
|
+
f"Epoch {epoch:03d} "
|
353
|
+
f"| Train MSE: {metrics_train.metrics[MetricName.MSE]:.4f} "
|
354
|
+
f"| Train MAE: {metrics_train.metrics[MetricName.MAE]:.4f} "
|
355
|
+
f"| Train r2: {metrics_train.metrics[MetricName.R2]:.4f} "
|
356
|
+
f"| Val MSE: {metrics_valid.metrics[MetricName.MSE]:.4f} "
|
357
|
+
f"| Val MAE: {metrics_valid.metrics[MetricName.MAE]:.4f} "
|
358
|
+
f"| Val r2: {metrics_valid.metrics[MetricName.R2]:.4f}"
|
359
|
+
))
|
360
|
+
case Task.CLASSIFICATION:
|
361
|
+
logger.info((
|
362
|
+
f"Epoch {epoch:03d} "
|
363
|
+
f"| Train CE: {metrics_train.metrics[MetricName.LOG_LOSS]:.4f} "
|
364
|
+
f"| Train acc: {metrics_train.metrics[MetricName.ACCURACY]:.4f} "
|
365
|
+
f"| Val CE: {metrics_valid.metrics[MetricName.LOG_LOSS]:.4f} "
|
366
|
+
f"| Val acc: {metrics_valid.metrics[MetricName.ACCURACY]:.4f}"
|
367
|
+
))
|