autogluon.tabular 1.3.2b20250715__py3-none-any.whl → 1.3.2b20250716__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- autogluon/tabular/models/mitra/_internal/__init__.py +1 -0
- autogluon/tabular/models/mitra/_internal/config/__init__.py +1 -0
- autogluon/tabular/models/mitra/_internal/config/config_run.py +3 -3
- autogluon/tabular/models/mitra/_internal/config/enums.py +20 -3
- autogluon/tabular/models/mitra/_internal/core/__init__.py +1 -0
- autogluon/tabular/models/mitra/_internal/core/get_loss.py +22 -23
- autogluon/tabular/models/mitra/_internal/core/prediction_metrics.py +11 -13
- autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py +69 -75
- autogluon/tabular/models/mitra/_internal/data/__init__.py +1 -0
- autogluon/tabular/models/mitra/_internal/data/preprocessor.py +57 -57
- autogluon/tabular/models/mitra/_internal/models/__init__.py +1 -0
- autogluon/tabular/models/mitra/_internal/models/tab2d.py +23 -26
- autogluon/tabular/models/mitra/_internal/utils/__init__.py +1 -0
- autogluon/tabular/models/mitra/mitra_model.py +64 -24
- autogluon/tabular/models/mitra/sklearn_interface.py +52 -42
- autogluon/tabular/models/realmlp/realmlp_model.py +11 -3
- autogluon/tabular/models/tabicl/tabicl_model.py +4 -1
- autogluon/tabular/models/tabm/_tabm_internal.py +4 -3
- autogluon/tabular/models/tabm/tabm_model.py +7 -3
- autogluon/tabular/models/tabm/tabm_reference.py +21 -19
- autogluon/tabular/models/tabpfnv2/tabpfnv2_model.py +10 -9
- autogluon/tabular/version.py +1 -1
- {autogluon.tabular-1.3.2b20250715.dist-info → autogluon.tabular-1.3.2b20250716.dist-info}/METADATA +11 -11
- {autogluon.tabular-1.3.2b20250715.dist-info → autogluon.tabular-1.3.2b20250716.dist-info}/RECORD +31 -25
- /autogluon.tabular-1.3.2b20250715-py3.9-nspkg.pth → /autogluon.tabular-1.3.2b20250716-py3.9-nspkg.pth +0 -0
- {autogluon.tabular-1.3.2b20250715.dist-info → autogluon.tabular-1.3.2b20250716.dist-info}/LICENSE +0 -0
- {autogluon.tabular-1.3.2b20250715.dist-info → autogluon.tabular-1.3.2b20250716.dist-info}/NOTICE +0 -0
- {autogluon.tabular-1.3.2b20250715.dist-info → autogluon.tabular-1.3.2b20250716.dist-info}/WHEEL +0 -0
- {autogluon.tabular-1.3.2b20250715.dist-info → autogluon.tabular-1.3.2b20250716.dist-info}/namespace_packages.txt +0 -0
- {autogluon.tabular-1.3.2b20250715.dist-info → autogluon.tabular-1.3.2b20250716.dist-info}/top_level.txt +0 -0
- {autogluon.tabular-1.3.2b20250715.dist-info → autogluon.tabular-1.3.2b20250716.dist-info}/zip-safe +0 -0
autogluon/tabular/models/mitra/_internal/__init__.py
@@ -0,0 +1 @@
+# Internal modules for MitraModel
autogluon/tabular/models/mitra/_internal/config/__init__.py
@@ -0,0 +1 @@
+# Configuration modules for MitraModel
autogluon/tabular/models/mitra/_internal/config/config_run.py
@@ -1,13 +1,13 @@
 from __future__ import annotations
 
 from dataclasses import dataclass
-from typing import Self
 
 import torch
 
 from ..._internal.config.config_pretrain import ConfigSaveLoadMixin
 from ..._internal.config.enums import ModelName
 
+
 @dataclass
 class ConfigRun(ConfigSaveLoadMixin):
     device: torch.device
@@ -22,11 +22,11 @@ class ConfigRun(ConfigSaveLoadMixin):
         seed: int,
         model_name: ModelName,
         hyperparams: dict
-    ) -> Self:
+    ) -> "ConfigRun":
 
         return cls(
            device=device,
            seed=seed,
            model_name=model_name,
            hyperparams=hyperparams
-        )
+        )
autogluon/tabular/models/mitra/_internal/config/enums.py
@@ -1,4 +1,21 @@
-from enum import IntEnum
+from enum import IntEnum
+
+try:
+    from enum import StrEnum
+except ImportError:
+    # StrEnum is not available in Python < 3.11, so we create a compatible version
+    from enum import Enum
+    class StrEnum(str, Enum):
+        """
+        Enum where members are also (and must be) strings
+        """
+        def __new__(cls, value):
+            if not isinstance(value, str):
+                raise TypeError(f"{value!r} is not a string")
+            return super().__new__(cls, value)
+
+        def __str__(self):
+            return self.value
 
 
 class Task(StrEnum):
@@ -95,7 +112,7 @@ class BenchmarkName(StrEnum):
     NUMERICAL_CLASSIFICATION_LARGE = "numerical_classification_large"
     CATEGORICAL_REGRESSION_LARGE = "categorical_regression_large"
     NUMERICAL_REGRESSION_LARGE = "numerical_regression_large"
-
+
     TABZILLA_HARD = "tabzilla_hard"
     TABZILLA_HARD_MAX_TEN_CLASSES = "tabzilla_hard_max_ten_classes"
     TABZILLA_HAS_COMPLETED_RUNS = "tabzilla_has_completed_runs"
@@ -142,4 +159,4 @@ class MetricName(StrEnum):
 class LossName(StrEnum):
     CROSS_ENTROPY = "cross_entropy"
     MSE = "mse"
-    MAE = "mae"
+    MAE = "mae"
autogluon/tabular/models/mitra/_internal/core/__init__.py
@@ -0,0 +1 @@
+# Core modules for MitraModel
autogluon/tabular/models/mitra/_internal/core/get_loss.py
@@ -1,10 +1,11 @@
-import torch
 import einops
+import torch
 
 from ..._internal.config.config_pretrain import ConfigPretrain
 from ..._internal.config.config_run import ConfigRun
 from ..._internal.config.enums import LossName, Task
 
+
 class CrossEntropyLossExtraBatch(torch.nn.Module):
 
     def __init__(self, label_smoothing: float):
@@ -28,28 +29,26 @@ class CrossEntropyLossExtraBatch(torch.nn.Module):
 
 def get_loss(cfg: ConfigRun):
 
-
-
-
-
-
-
-
-
-
-    raise ValueError(f"Unsupported task {cfg.task} and (regression) loss {cfg.hyperparams['regression_loss']}")
+    if cfg.task == Task.REGRESSION and cfg.hyperparams['regression_loss'] == LossName.MSE:
+        return torch.nn.MSELoss()
+    elif cfg.task == Task.REGRESSION and cfg.hyperparams['regression_loss'] == LossName.MAE:
+        return torch.nn.L1Loss()
+    elif cfg.task == Task.REGRESSION and cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
+        return CrossEntropyLossExtraBatch(cfg.hyperparams['label_smoothing'])
+    elif cfg.task == Task.CLASSIFICATION:
+        return CrossEntropyLossExtraBatch(cfg.hyperparams['label_smoothing'])
+    else:
+        raise ValueError(f"Unsupported task {cfg.task} and (regression) loss {cfg.hyperparams['regression_loss']}")
 
 def get_loss_pretrain(cfg: ConfigPretrain):
 
-
-
-
-
-
-
-
-
-
-    raise ValueError(f"Unsupported task {cfg.data.task} and (regression) loss {cfg.optim.regression_loss}")
+    if cfg.data.task == Task.REGRESSION and cfg.optim.regression_loss == LossName.MSE:
+        return torch.nn.MSELoss()
+    elif cfg.data.task == Task.REGRESSION and cfg.optim.regression_loss == LossName.MAE:
+        return torch.nn.L1Loss()
+    elif cfg.data.task == Task.REGRESSION and cfg.optim.regression_loss == LossName.CROSS_ENTROPY:
+        return CrossEntropyLossExtraBatch(cfg.optim.label_smoothing)
+    elif cfg.data.task == Task.CLASSIFICATION:
+        return CrossEntropyLossExtraBatch(cfg.optim.label_smoothing)
+    else:
+        raise ValueError(f"Unsupported task {cfg.data.task} and (regression) loss {cfg.optim.regression_loss}")
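The selection logic above returns plain PyTorch losses for the MSE and MAE regression branches and the file's `CrossEntropyLossExtraBatch` otherwise. A tiny standalone check of what the two plain branches compute, on made-up tensors:

import torch

y_pred = torch.tensor([0.2, 0.8, 1.1])
y_true = torch.tensor([0.0, 1.0, 1.0])

mse = torch.nn.MSELoss()(y_pred, y_true)  # mean of squared errors -> 0.03
mae = torch.nn.L1Loss()(y_pred, y_true)   # mean of absolute errors -> ~0.1667
print(mse.item(), mae.item())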
|
@@ -1,14 +1,13 @@
|
|
1
1
|
from dataclasses import dataclass
|
2
|
-
from typing import Self
|
3
2
|
|
4
3
|
import numpy as np
|
5
|
-
import scipy
|
4
|
+
import scipy.special
|
6
5
|
import torch
|
7
6
|
from loguru import logger
|
8
7
|
from sklearn.metrics import f1_score, mean_squared_error, r2_score, roc_auc_score, root_mean_squared_error
|
9
8
|
|
10
|
-
from ..._internal.data.preprocessor import Preprocessor
|
11
9
|
from ..._internal.config.enums import MetricName, Task
|
10
|
+
from ..._internal.data.preprocessor import Preprocessor
|
12
11
|
|
13
12
|
|
14
13
|
@dataclass
|
@@ -20,21 +19,20 @@ class PredictionMetrics():
|
|
20
19
|
|
21
20
|
|
22
21
|
@classmethod
|
23
|
-
def from_prediction(cls, y_pred: np.ndarray, y_true: np.ndarray, task: Task) ->
|
22
|
+
def from_prediction(cls, y_pred: np.ndarray, y_true: np.ndarray, task: Task) -> "PredictionMetrics":
|
24
23
|
|
25
24
|
loss, score, metrics = compute_metrics(y_pred, y_true, task)
|
26
25
|
|
27
26
|
return cls(task=task, loss=loss, score=score, metrics=metrics)
|
28
27
|
|
29
|
-
|
28
|
+
|
30
29
|
def compute_metrics(y_pred: np.ndarray, y_true: np.ndarray, task: Task) -> tuple[float, float, dict]:
|
31
30
|
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
31
|
+
if task == Task.CLASSIFICATION:
|
32
|
+
return compute_classification_metrics(y_pred, y_true)
|
33
|
+
elif task == Task.REGRESSION:
|
34
|
+
return compute_regression_metrics(y_pred, y_true)
|
35
|
+
|
38
36
|
|
39
37
|
def compute_classification_metrics(y_pred: np.ndarray, y_true: np.ndarray) -> tuple[float, float, dict]:
|
40
38
|
# predictions are assumed to be log-probabilities
|
@@ -121,7 +119,7 @@ class PredictionMetricsTracker():
|
|
121
119
|
y_true_ori = self.preprocessor.inverse_transform_y(y_true_np)
|
122
120
|
else:
|
123
121
|
y_true_ori = y_true_np
|
124
|
-
|
122
|
+
|
125
123
|
self.ys_pred.append(y_pred_ori)
|
126
124
|
self.ys_true.append(y_true_ori)
|
127
125
|
|
@@ -131,4 +129,4 @@ class PredictionMetricsTracker():
|
|
131
129
|
y_pred = np.concatenate(self.ys_pred, axis=0)
|
132
130
|
y_true = np.concatenate(self.ys_true, axis=0)
|
133
131
|
|
134
|
-
return PredictionMetrics.from_prediction(y_pred, y_true, self.task)
|
132
|
+
return PredictionMetrics.from_prediction(y_pred, y_true, self.task)
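The switch from `import scipy` to `import scipy.special` matters because importing the top-level `scipy` package does not load its submodules. A small sketch of the explicit-import pattern, using `softmax` as an example of the kind of log-probability conversion such metrics code typically needs (the exact scipy.special function used by this module is not shown in the diff):

import numpy as np
import scipy.special  # `import scipy` alone would not make scipy.special usable

# Normalized probabilities survive a log -> softmax round trip.
log_probs = np.log(np.array([[0.1, 0.9], [0.5, 0.5]]))
probs = scipy.special.softmax(log_probs, axis=1)
print(probs)  # recovers [[0.1, 0.9], [0.5, 0.5]] up to floating point error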
|
@@ -1,26 +1,26 @@
|
|
1
1
|
import time
|
2
|
+
|
2
3
|
import numpy as np
|
3
4
|
import torch
|
4
5
|
from loguru import logger
|
5
6
|
from sklearn.base import BaseEstimator
|
6
|
-
import torch.nn.functional as F
|
7
7
|
|
8
8
|
from ..._internal.config.config_run import ConfigRun
|
9
|
+
from ..._internal.config.enums import LossName, MetricName, ModelName, Task
|
9
10
|
from ..._internal.core.callbacks import Checkpoint, EarlyStopping
|
10
|
-
from ..._internal.data.collator import CollatorWithPadding
|
11
|
-
from ..._internal.config.enums import MetricName, ModelName, Task, LossName
|
12
11
|
from ..._internal.core.get_loss import get_loss
|
13
|
-
from ..._internal.core.get_optimizer import
|
12
|
+
from ..._internal.core.get_optimizer import GradScaler, get_optimizer
|
14
13
|
from ..._internal.core.get_scheduler import get_scheduler
|
14
|
+
from ..._internal.core.prediction_metrics import PredictionMetrics, PredictionMetricsTracker
|
15
|
+
from ..._internal.data.collator import CollatorWithPadding
|
15
16
|
from ..._internal.data.dataset_finetune import DatasetFinetune, DatasetFinetuneGenerator
|
16
17
|
from ..._internal.data.preprocessor import Preprocessor
|
17
|
-
from ..._internal.core.prediction_metrics import PredictionMetrics, PredictionMetricsTracker
|
18
18
|
|
19
19
|
|
20
20
|
class TrainerFinetune(BaseEstimator):
|
21
21
|
|
22
22
|
def __init__(
|
23
|
-
self,
|
23
|
+
self,
|
24
24
|
cfg: ConfigRun,
|
25
25
|
model: torch.nn.Module,
|
26
26
|
n_classes: int,
|
@@ -31,7 +31,7 @@ class TrainerFinetune(BaseEstimator):
|
|
31
31
|
self.device = device
|
32
32
|
self.model = model.to(self.device, non_blocking=True)
|
33
33
|
self.n_classes = n_classes
|
34
|
-
|
34
|
+
|
35
35
|
self.loss = get_loss(self.cfg)
|
36
36
|
self.optimizer = get_optimizer(self.cfg.hyperparams, self.model)
|
37
37
|
self.scheduler_warmup, self.scheduler_reduce_on_plateau = get_scheduler(self.cfg.hyperparams, self.optimizer)
|
@@ -51,7 +51,7 @@ class TrainerFinetune(BaseEstimator):
|
|
51
51
|
dim_output=self.cfg.hyperparams['dim_output'],
|
52
52
|
use_quantile_transformer=self.cfg.hyperparams['use_quantile_transformer'],
|
53
53
|
use_feature_count_scaling=self.cfg.hyperparams['use_feature_count_scaling'],
|
54
|
-
use_random_transforms=self.cfg.hyperparams['use_random_transforms'],
|
54
|
+
use_random_transforms=self.cfg.hyperparams['use_random_transforms'],
|
55
55
|
shuffle_classes=self.cfg.hyperparams['shuffle_classes'],
|
56
56
|
shuffle_features=self.cfg.hyperparams['shuffle_features'],
|
57
57
|
random_mirror_x=self.cfg.hyperparams['random_mirror_x'],
|
@@ -62,19 +62,19 @@ class TrainerFinetune(BaseEstimator):
|
|
62
62
|
self.checkpoint.reset(self.model)
|
63
63
|
|
64
64
|
if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
|
65
|
-
self.bins = torch.linspace(-0.5, 1.5, self.cfg.hyperparams['dim_output']+1, device=cfg.device)
|
65
|
+
self.bins = torch.linspace(-0.5, 1.5, self.cfg.hyperparams['dim_output']+1, device=cfg.device)
|
66
66
|
self.bin_width = self.bins[1] - self.bins[0]
|
67
|
-
|
67
|
+
|
68
68
|
self.metric = self.cfg.hyperparams['metric']
|
69
69
|
|
70
70
|
|
71
71
|
def train(self, x_train: np.ndarray, y_train: np.ndarray, x_val: np.ndarray, y_val: np.ndarray):
|
72
72
|
|
73
|
-
self.preprocessor.fit(x_train, y_train)
|
73
|
+
self.preprocessor.fit(x_train, y_train)
|
74
74
|
|
75
|
-
x_train_transformed = self.preprocessor.transform_X(x_train)
|
75
|
+
x_train_transformed = self.preprocessor.transform_X(x_train)
|
76
76
|
y_train_transformed = self.preprocessor.transform_y(y_train)
|
77
|
-
|
77
|
+
|
78
78
|
dataset_train_generator = DatasetFinetuneGenerator(
|
79
79
|
self.cfg,
|
80
80
|
x = x_train_transformed,
|
@@ -94,16 +94,16 @@ class TrainerFinetune(BaseEstimator):
|
|
94
94
|
|
95
95
|
for epoch in range(1, self.cfg.hyperparams['max_epochs']+1):
|
96
96
|
|
97
|
-
dataset_train = next(dataset_train_generator)
|
97
|
+
dataset_train = next(dataset_train_generator)
|
98
98
|
loader_train = self.make_loader(dataset_train, training=True)
|
99
99
|
self.model.train()
|
100
|
-
|
100
|
+
|
101
101
|
prediction_metrics_tracker = PredictionMetricsTracker(task=self.cfg.task, preprocessor=self.preprocessor)
|
102
102
|
|
103
103
|
for batch in loader_train:
|
104
|
-
|
104
|
+
|
105
105
|
with torch.autocast(device_type=self.device, dtype=getattr(torch, self.cfg.hyperparams['precision'])):
|
106
|
-
|
106
|
+
|
107
107
|
x_support = batch['x_support'].to(self.device, non_blocking=True)
|
108
108
|
y_support = batch['y_support'].to(self.device, non_blocking=True)
|
109
109
|
x_query = batch['x_query'].to(self.device, non_blocking=True)
|
@@ -118,13 +118,12 @@ class TrainerFinetune(BaseEstimator):
|
|
118
118
|
y_support = torch.clamp(y_support, 0, self.cfg.hyperparams['dim_output']-1).to(torch.int64)
|
119
119
|
y_query_bin_ids = torch.bucketize(y_query, self.bins) - 1
|
120
120
|
y_query_bin_ids = torch.clamp(y_query_bin_ids, 0, self.cfg.hyperparams['dim_output']-1).to(torch.int64)
|
121
|
-
|
122
|
-
|
123
|
-
|
124
|
-
|
125
|
-
|
126
|
-
|
127
|
-
|
121
|
+
|
122
|
+
if self.cfg.model_name == ModelName.TABPFN:
|
123
|
+
y_hat = self.model(x_support, y_support, x_query, task=self.cfg.task).squeeze(-1)
|
124
|
+
elif self.cfg.model_name in [ModelName.TAB2D, ModelName.TAB2D_COL_ROW, ModelName.TAB2D_SDPA]:
|
125
|
+
y_hat = self.model(x_support, y_support, x_query, padding_features, padding_obs_support, padding_obs_query)
|
126
|
+
|
128
127
|
# Convert numerical y_query to bin ids
|
129
128
|
if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
|
130
129
|
loss = self.loss(y_hat, y_query_bin_ids)
|
@@ -143,7 +142,7 @@ class TrainerFinetune(BaseEstimator):
|
|
143
142
|
|
144
143
|
# Convert bin id predictions to numerical values
|
145
144
|
if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
|
146
|
-
y_hat = torch.argmax(y_hat, dim=-1)
|
145
|
+
y_hat = torch.argmax(y_hat, dim=-1)
|
147
146
|
y_hat = self.bins[y_hat] + self.bin_width / 2
|
148
147
|
|
149
148
|
y_hat = y_hat.float()
|
@@ -153,12 +152,12 @@ class TrainerFinetune(BaseEstimator):
|
|
153
152
|
prediction_metrics_tracker.update(y_hat, y_query, train=False)
|
154
153
|
|
155
154
|
metrics_train = prediction_metrics_tracker.get_metrics()
|
156
|
-
metrics_valid = self.evaluate(x_train, y_train, x_val, y_val)
|
155
|
+
metrics_valid = self.evaluate(x_train, y_train, x_val, y_val)
|
157
156
|
|
158
157
|
self.log_metrics(epoch, metrics_train, metrics_valid)
|
159
158
|
|
160
159
|
self.checkpoint(self.model, metrics_valid.loss)
|
161
|
-
|
160
|
+
|
162
161
|
self.early_stopping(metrics_valid.metrics[self.metric])
|
163
162
|
if self.early_stopping.we_should_stop():
|
164
163
|
logger.info("Early stopping")
|
@@ -175,9 +174,9 @@ class TrainerFinetune(BaseEstimator):
|
|
175
174
|
|
176
175
|
self.checkpoint.set_to_best(self.model)
|
177
176
|
|
178
|
-
|
177
|
+
|
179
178
|
def evaluate(self, x_support: np.ndarray, y_support: np.ndarray, x_query: np.ndarray, y_query: np.ndarray) -> PredictionMetrics:
|
180
|
-
|
179
|
+
|
181
180
|
self.model.eval()
|
182
181
|
|
183
182
|
x_support_transformed = self.preprocessor.transform_X(x_support)
|
@@ -186,9 +185,9 @@ class TrainerFinetune(BaseEstimator):
|
|
186
185
|
# y_query_transformed = self.preprocessor.transform_y(y_query)
|
187
186
|
|
188
187
|
dataset = DatasetFinetune(
|
189
|
-
self.cfg,
|
190
|
-
x_support = x_support_transformed,
|
191
|
-
y_support = y_support_transformed,
|
188
|
+
self.cfg,
|
189
|
+
x_support = x_support_transformed,
|
190
|
+
y_support = y_support_transformed,
|
192
191
|
x_query = x_query_transformed,
|
193
192
|
y_query = y_query,
|
194
193
|
max_samples_support = self.cfg.hyperparams['max_samples_support'],
|
@@ -202,7 +201,7 @@ class TrainerFinetune(BaseEstimator):
|
|
202
201
|
for batch in loader:
|
203
202
|
|
204
203
|
with torch.autocast(device_type=self.device, dtype=getattr(torch, self.cfg.hyperparams['precision'])):
|
205
|
-
|
204
|
+
|
206
205
|
x_s = batch['x_support'].to(self.device, non_blocking=True)
|
207
206
|
y_s = batch['y_support'].to(self.device, non_blocking=True)
|
208
207
|
x_q = batch['x_query'].to(self.device, non_blocking=True)
|
@@ -215,16 +214,15 @@ class TrainerFinetune(BaseEstimator):
|
|
215
214
|
if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
|
216
215
|
y_s = torch.bucketize(y_s, self.bins) - 1
|
217
216
|
y_s = torch.clamp(y_s, 0, self.cfg.hyperparams['dim_output']-1).to(torch.int64)
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
y_hat = self.model(x_s, y_s, x_q, padding_features, padding_obs_support, padding_obs_query)
|
217
|
+
|
218
|
+
if self.cfg.model_name == ModelName.TABPFN:
|
219
|
+
y_hat = self.model(x_s, y_s, x_q, task=self.cfg.task).squeeze(-1)
|
220
|
+
elif self.cfg.model_name in [ModelName.TAB2D, ModelName.TAB2D_COL_ROW, ModelName.TAB2D_SDPA]:
|
221
|
+
y_hat = self.model(x_s, y_s, x_q, padding_features, padding_obs_support, padding_obs_query)
|
224
222
|
|
225
223
|
# Convert bin id predictions to numerical values
|
226
224
|
if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
|
227
|
-
y_hat = torch.argmax(y_hat, dim=-1)
|
225
|
+
y_hat = torch.argmax(y_hat, dim=-1)
|
228
226
|
y_hat = self.bins[y_hat] + self.bin_width / 2
|
229
227
|
|
230
228
|
y_hat = y_hat.float()
|
@@ -232,7 +230,7 @@ class TrainerFinetune(BaseEstimator):
|
|
232
230
|
|
233
231
|
metrics_eval = prediction_metrics_tracker.get_metrics()
|
234
232
|
return metrics_eval
|
235
|
-
|
233
|
+
|
236
234
|
|
237
235
|
def predict(self, x_support: np.ndarray, y_support: np.ndarray, x_query: np.ndarray) -> np.ndarray:
|
238
236
|
|
@@ -241,9 +239,9 @@ class TrainerFinetune(BaseEstimator):
|
|
241
239
|
y_support_transformed = self.preprocessor.transform_y(y_support)
|
242
240
|
|
243
241
|
dataset = DatasetFinetune(
|
244
|
-
self.cfg,
|
245
|
-
x_support = x_support_transformed,
|
246
|
-
y_support = y_support_transformed,
|
242
|
+
self.cfg,
|
243
|
+
x_support = x_support_transformed,
|
244
|
+
y_support = y_support_transformed,
|
247
245
|
x_query = x_query_transformed,
|
248
246
|
y_query = None,
|
249
247
|
max_samples_support = self.cfg.hyperparams['max_samples_support'],
|
@@ -259,7 +257,7 @@ class TrainerFinetune(BaseEstimator):
|
|
259
257
|
for batch in loader:
|
260
258
|
|
261
259
|
with torch.autocast(device_type=self.device, dtype=getattr(torch, self.cfg.hyperparams['precision'])):
|
262
|
-
|
260
|
+
|
263
261
|
x_s = batch['x_support'].to(self.device, non_blocking=True)
|
264
262
|
y_s = batch['y_support'].to(self.device, non_blocking=True)
|
265
263
|
x_q = batch['x_query'].to(self.device, non_blocking=True)
|
@@ -271,18 +269,17 @@ class TrainerFinetune(BaseEstimator):
|
|
271
269
|
if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
|
272
270
|
y_s = torch.bucketize(y_s, self.bins) - 1
|
273
271
|
y_s = torch.clamp(y_s, 0, self.cfg.hyperparams['dim_output']-1).to(torch.int64)
|
274
|
-
|
275
|
-
|
276
|
-
|
277
|
-
|
278
|
-
|
279
|
-
|
280
|
-
|
272
|
+
|
273
|
+
if self.cfg.model_name == ModelName.TABPFN:
|
274
|
+
y_hat = self.model(x_s, y_s, x_q, task=self.cfg.task).squeeze(-1)
|
275
|
+
elif self.cfg.model_name in [ModelName.TAB2D, ModelName.TAB2D_COL_ROW, ModelName.TAB2D_SDPA]:
|
276
|
+
y_hat = self.model(x_s, y_s, x_q, padding_features, padding_obs_support, padding_obs_query)
|
277
|
+
|
281
278
|
y_hat = y_hat[0].float().cpu().numpy()
|
282
279
|
|
283
280
|
# Convert bin id predictions to numerical values
|
284
281
|
if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
|
285
|
-
y_hat = np.argmax(y_hat, axis=-1)
|
282
|
+
y_hat = np.argmax(y_hat, axis=-1)
|
286
283
|
y_hat = (self.bins[y_hat] + self.bin_width / 2).cpu().numpy()
|
287
284
|
|
288
285
|
y_hat = self.preprocessor.inverse_transform_y(y_hat)
|
@@ -291,21 +288,20 @@ class TrainerFinetune(BaseEstimator):
|
|
291
288
|
y_pred = np.concatenate(y_pred_list, axis=0)
|
292
289
|
|
293
290
|
return y_pred
|
294
|
-
|
291
|
+
|
295
292
|
|
296
293
|
def load_params(self, path):
|
297
294
|
self.model.load_state_dict(torch.load(path))
|
298
|
-
|
295
|
+
|
299
296
|
|
300
297
|
def make_loader(self, dataset: torch.utils.data.Dataset, training: bool) -> torch.utils.data.DataLoader:
|
301
298
|
|
302
|
-
|
303
|
-
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
raise NotImplementedError(f"Model {self.cfg.model_name} not implemented")
|
299
|
+
if self.cfg.model_name == ModelName.TABPFN:
|
300
|
+
pad_to_max_features = True
|
301
|
+
elif self.cfg.model_name in [ModelName.TAB2D, ModelName.TAB2D_COL_ROW, ModelName.TAB2D_SDPA]:
|
302
|
+
pad_to_max_features = False
|
303
|
+
else:
|
304
|
+
raise NotImplementedError(f"Model {self.cfg.model_name} not implemented")
|
309
305
|
|
310
306
|
return torch.utils.data.DataLoader(
|
311
307
|
dataset,
|
@@ -319,13 +315,12 @@ class TrainerFinetune(BaseEstimator):
|
|
319
315
|
pad_to_max_features=pad_to_max_features
|
320
316
|
),
|
321
317
|
)
|
322
|
-
|
318
|
+
|
323
319
|
|
324
320
|
def log_start_metrics(self, metrics_valid: PredictionMetrics):
|
325
321
|
|
326
|
-
|
327
|
-
|
328
|
-
logger.info((
|
322
|
+
if self.cfg.task == Task.REGRESSION:
|
323
|
+
logger.info((
|
329
324
|
f"Epoch 000 "
|
330
325
|
f"| Train MSE: -.---- "
|
331
326
|
f"| Train MAE: -.---- "
|
@@ -334,21 +329,20 @@ class TrainerFinetune(BaseEstimator):
|
|
334
329
|
f"| Val MAE: {metrics_valid.metrics[MetricName.MAE]:.4f} "
|
335
330
|
f"| Val r2: {metrics_valid.metrics[MetricName.R2]:.4f}"
|
336
331
|
))
|
337
|
-
|
338
|
-
|
332
|
+
|
333
|
+
elif self.cfg.task == Task.CLASSIFICATION:
|
334
|
+
logger.info((
|
339
335
|
f"Epoch 000 "
|
340
336
|
f"| Train CE: -.---- "
|
341
337
|
f"| Train acc: -.---- "
|
342
338
|
f"| Val CE: {metrics_valid.metrics[MetricName.LOG_LOSS]:.4f} "
|
343
339
|
f"| Val acc: {metrics_valid.metrics[MetricName.ACCURACY]:.4f}"
|
344
340
|
))
|
345
|
-
|
346
341
|
|
347
342
|
def log_metrics(self, epoch: int, metrics_train: PredictionMetrics, metrics_valid: PredictionMetrics):
|
348
343
|
|
349
|
-
|
350
|
-
|
351
|
-
logger.info((
|
344
|
+
if self.cfg.task == Task.REGRESSION:
|
345
|
+
logger.info((
|
352
346
|
f"Epoch {epoch:03d} "
|
353
347
|
f"| Train MSE: {metrics_train.metrics[MetricName.MSE]:.4f} "
|
354
348
|
f"| Train MAE: {metrics_train.metrics[MetricName.MAE]:.4f} "
|
@@ -357,11 +351,11 @@ class TrainerFinetune(BaseEstimator):
|
|
357
351
|
f"| Val MAE: {metrics_valid.metrics[MetricName.MAE]:.4f} "
|
358
352
|
f"| Val r2: {metrics_valid.metrics[MetricName.R2]:.4f}"
|
359
353
|
))
|
360
|
-
|
361
|
-
|
354
|
+
elif self.cfg.task == Task.CLASSIFICATION:
|
355
|
+
logger.info((
|
362
356
|
f"Epoch {epoch:03d} "
|
363
357
|
f"| Train CE: {metrics_train.metrics[MetricName.LOG_LOSS]:.4f} "
|
364
358
|
f"| Train acc: {metrics_train.metrics[MetricName.ACCURACY]:.4f} "
|
365
359
|
f"| Val CE: {metrics_valid.metrics[MetricName.LOG_LOSS]:.4f} "
|
366
360
|
f"| Val acc: {metrics_valid.metrics[MetricName.ACCURACY]:.4f}"
|
367
|
-
))
|
361
|
+
))
|
autogluon/tabular/models/mitra/_internal/data/__init__.py
@@ -0,0 +1 @@
+# Data processing modules for MitraModel