autogluon.tabular 1.3.2b20250715__py3-none-any.whl → 1.3.2b20250716__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (31)
  1. autogluon/tabular/models/mitra/_internal/__init__.py +1 -0
  2. autogluon/tabular/models/mitra/_internal/config/__init__.py +1 -0
  3. autogluon/tabular/models/mitra/_internal/config/config_run.py +3 -3
  4. autogluon/tabular/models/mitra/_internal/config/enums.py +20 -3
  5. autogluon/tabular/models/mitra/_internal/core/__init__.py +1 -0
  6. autogluon/tabular/models/mitra/_internal/core/get_loss.py +22 -23
  7. autogluon/tabular/models/mitra/_internal/core/prediction_metrics.py +11 -13
  8. autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py +69 -75
  9. autogluon/tabular/models/mitra/_internal/data/__init__.py +1 -0
  10. autogluon/tabular/models/mitra/_internal/data/preprocessor.py +57 -57
  11. autogluon/tabular/models/mitra/_internal/models/__init__.py +1 -0
  12. autogluon/tabular/models/mitra/_internal/models/tab2d.py +23 -26
  13. autogluon/tabular/models/mitra/_internal/utils/__init__.py +1 -0
  14. autogluon/tabular/models/mitra/mitra_model.py +64 -24
  15. autogluon/tabular/models/mitra/sklearn_interface.py +52 -42
  16. autogluon/tabular/models/realmlp/realmlp_model.py +11 -3
  17. autogluon/tabular/models/tabicl/tabicl_model.py +4 -1
  18. autogluon/tabular/models/tabm/_tabm_internal.py +4 -3
  19. autogluon/tabular/models/tabm/tabm_model.py +7 -3
  20. autogluon/tabular/models/tabm/tabm_reference.py +21 -19
  21. autogluon/tabular/models/tabpfnv2/tabpfnv2_model.py +10 -9
  22. autogluon/tabular/version.py +1 -1
  23. {autogluon.tabular-1.3.2b20250715.dist-info → autogluon.tabular-1.3.2b20250716.dist-info}/METADATA +11 -11
  24. {autogluon.tabular-1.3.2b20250715.dist-info → autogluon.tabular-1.3.2b20250716.dist-info}/RECORD +31 -25
  25. /autogluon.tabular-1.3.2b20250715-py3.9-nspkg.pth → /autogluon.tabular-1.3.2b20250716-py3.9-nspkg.pth +0 -0
  26. {autogluon.tabular-1.3.2b20250715.dist-info → autogluon.tabular-1.3.2b20250716.dist-info}/LICENSE +0 -0
  27. {autogluon.tabular-1.3.2b20250715.dist-info → autogluon.tabular-1.3.2b20250716.dist-info}/NOTICE +0 -0
  28. {autogluon.tabular-1.3.2b20250715.dist-info → autogluon.tabular-1.3.2b20250716.dist-info}/WHEEL +0 -0
  29. {autogluon.tabular-1.3.2b20250715.dist-info → autogluon.tabular-1.3.2b20250716.dist-info}/namespace_packages.txt +0 -0
  30. {autogluon.tabular-1.3.2b20250715.dist-info → autogluon.tabular-1.3.2b20250716.dist-info}/top_level.txt +0 -0
  31. {autogluon.tabular-1.3.2b20250715.dist-info → autogluon.tabular-1.3.2b20250716.dist-info}/zip-safe +0 -0
autogluon/tabular/models/mitra/_internal/__init__.py
@@ -0,0 +1 @@
+ # Internal modules for MitraModel
autogluon/tabular/models/mitra/_internal/config/__init__.py
@@ -0,0 +1 @@
+ # Configuration modules for MitraModel
autogluon/tabular/models/mitra/_internal/config/config_run.py
@@ -1,13 +1,13 @@
  from __future__ import annotations

  from dataclasses import dataclass
- from typing import Self

  import torch

  from ..._internal.config.config_pretrain import ConfigSaveLoadMixin
  from ..._internal.config.enums import ModelName

+
  @dataclass
  class ConfigRun(ConfigSaveLoadMixin):
      device: torch.device
@@ -22,11 +22,11 @@ class ConfigRun(ConfigSaveLoadMixin):
          seed: int,
          model_name: ModelName,
          hyperparams: dict
-     ) -> Self:
+     ) -> "ConfigRun":

          return cls(
              device=device,
              seed=seed,
              model_name=model_name,
              hyperparams=hyperparams
-         )
+         )
autogluon/tabular/models/mitra/_internal/config/enums.py
@@ -1,4 +1,21 @@
- from enum import IntEnum, StrEnum
+ from enum import IntEnum
+
+ try:
+     from enum import StrEnum
+ except ImportError:
+     # StrEnum is not available in Python < 3.11, so we create a compatible version
+     from enum import Enum
+     class StrEnum(str, Enum):
+         """
+         Enum where members are also (and must be) strings
+         """
+         def __new__(cls, value):
+             if not isinstance(value, str):
+                 raise TypeError(f"{value!r} is not a string")
+             return super().__new__(cls, value)
+
+         def __str__(self):
+             return self.value


  class Task(StrEnum):
@@ -95,7 +112,7 @@ class BenchmarkName(StrEnum):
      NUMERICAL_CLASSIFICATION_LARGE = "numerical_classification_large"
      CATEGORICAL_REGRESSION_LARGE = "categorical_regression_large"
      NUMERICAL_REGRESSION_LARGE = "numerical_regression_large"
-
+
      TABZILLA_HARD = "tabzilla_hard"
      TABZILLA_HARD_MAX_TEN_CLASSES = "tabzilla_hard_max_ten_classes"
      TABZILLA_HAS_COMPLETED_RUNS = "tabzilla_has_completed_runs"
@@ -142,4 +159,4 @@ class MetricName(StrEnum):
  class LossName(StrEnum):
      CROSS_ENTROPY = "cross_entropy"
      MSE = "mse"
-     MAE = "mae"
+     MAE = "mae"
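Note on the hunk above: on Python 3.11+ the standard library enum.StrEnum is imported, while on older interpreters the fallback class defined in the except branch provides the same string-like behavior. A minimal sanity check, assuming this wheel is installed (Task is one of the StrEnum subclasses defined further down in enums.py):

    from autogluon.tabular.models.mitra._internal.config.enums import Task

    # Holds on both code paths: members are plain strings and str() yields the value.
    assert isinstance(Task.REGRESSION, str)
    assert str(Task.REGRESSION) == Task.REGRESSION.value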
autogluon/tabular/models/mitra/_internal/core/__init__.py
@@ -0,0 +1 @@
+ # Core modules for MitraModel
autogluon/tabular/models/mitra/_internal/core/get_loss.py
@@ -1,10 +1,11 @@
- import torch
  import einops
+ import torch

  from ..._internal.config.config_pretrain import ConfigPretrain
  from ..._internal.config.config_run import ConfigRun
  from ..._internal.config.enums import LossName, Task

+
  class CrossEntropyLossExtraBatch(torch.nn.Module):

      def __init__(self, label_smoothing: float):
@@ -28,28 +29,26 @@ class CrossEntropyLossExtraBatch(torch.nn.Module):

  def get_loss(cfg: ConfigRun):

-     match (cfg.task, cfg.hyperparams['regression_loss']):
-         case (Task.REGRESSION, LossName.MSE):
-             return torch.nn.MSELoss()
-         case (Task.REGRESSION, LossName.MAE):
-             return torch.nn.L1Loss()
-         case (Task.REGRESSION, LossName.CROSS_ENTROPY):
-             return CrossEntropyLossExtraBatch(cfg.hyperparams['label_smoothing'])
-         case (Task.CLASSIFICATION, _):
-             return CrossEntropyLossExtraBatch(cfg.hyperparams['label_smoothing'])
-         case (_, _):
-             raise ValueError(f"Unsupported task {cfg.task} and (regression) loss {cfg.hyperparams['regression_loss']}")
+     if cfg.task == Task.REGRESSION and cfg.hyperparams['regression_loss'] == LossName.MSE:
+         return torch.nn.MSELoss()
+     elif cfg.task == Task.REGRESSION and cfg.hyperparams['regression_loss'] == LossName.MAE:
+         return torch.nn.L1Loss()
+     elif cfg.task == Task.REGRESSION and cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
+         return CrossEntropyLossExtraBatch(cfg.hyperparams['label_smoothing'])
+     elif cfg.task == Task.CLASSIFICATION:
+         return CrossEntropyLossExtraBatch(cfg.hyperparams['label_smoothing'])
+     else:
+         raise ValueError(f"Unsupported task {cfg.task} and (regression) loss {cfg.hyperparams['regression_loss']}")

  def get_loss_pretrain(cfg: ConfigPretrain):

-     match (cfg.data.task, cfg.optim.regression_loss):
-         case (Task.REGRESSION, LossName.MSE):
-             return torch.nn.MSELoss()
-         case (Task.REGRESSION, LossName.MAE):
-             return torch.nn.L1Loss()
-         case (Task.REGRESSION, LossName.CROSS_ENTROPY):
-             return CrossEntropyLossExtraBatch(cfg.optim.label_smoothing)
-         case (Task.CLASSIFICATION, _):
-             return CrossEntropyLossExtraBatch(cfg.optim.label_smoothing)
-         case (_, _):
-             raise ValueError(f"Unsupported task {cfg.data.task} and (regression) loss {cfg.optim.regression_loss}")
+     if cfg.data.task == Task.REGRESSION and cfg.optim.regression_loss == LossName.MSE:
+         return torch.nn.MSELoss()
+     elif cfg.data.task == Task.REGRESSION and cfg.optim.regression_loss == LossName.MAE:
+         return torch.nn.L1Loss()
+     elif cfg.data.task == Task.REGRESSION and cfg.optim.regression_loss == LossName.CROSS_ENTROPY:
+         return CrossEntropyLossExtraBatch(cfg.optim.label_smoothing)
+     elif cfg.data.task == Task.CLASSIFICATION:
+         return CrossEntropyLossExtraBatch(cfg.optim.label_smoothing)
+     else:
+         raise ValueError(f"Unsupported task {cfg.data.task} and (regression) loss {cfg.optim.regression_loss}")
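The if/elif chain above preserves the dispatch of the removed match statement (match requires Python 3.10+). As a rough illustration of the selection logic, the stand-in config below is hypothetical and only mimics the two attributes get_loss reads; it is not the package's ConfigRun:

    import torch
    from types import SimpleNamespace

    from autogluon.tabular.models.mitra._internal.config.enums import LossName, Task
    from autogluon.tabular.models.mitra._internal.core.get_loss import get_loss

    # Hypothetical config stand-in with just the fields get_loss inspects.
    cfg = SimpleNamespace(
        task=Task.REGRESSION,
        hyperparams={"regression_loss": LossName.MSE, "label_smoothing": 0.0},
    )
    assert isinstance(get_loss(cfg), torch.nn.MSELoss)  # regression + mse -> MSELoss

    cfg.hyperparams["regression_loss"] = LossName.MAE
    assert isinstance(get_loss(cfg), torch.nn.L1Loss)   # regression + mae -> L1Loss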
autogluon/tabular/models/mitra/_internal/core/prediction_metrics.py
@@ -1,14 +1,13 @@
  from dataclasses import dataclass
- from typing import Self

  import numpy as np
- import scipy
+ import scipy.special
  import torch
  from loguru import logger
  from sklearn.metrics import f1_score, mean_squared_error, r2_score, roc_auc_score, root_mean_squared_error

- from ..._internal.data.preprocessor import Preprocessor
  from ..._internal.config.enums import MetricName, Task
+ from ..._internal.data.preprocessor import Preprocessor


  @dataclass
@@ -20,21 +19,20 @@ class PredictionMetrics():


      @classmethod
-     def from_prediction(cls, y_pred: np.ndarray, y_true: np.ndarray, task: Task) -> Self:
+     def from_prediction(cls, y_pred: np.ndarray, y_true: np.ndarray, task: Task) -> "PredictionMetrics":

          loss, score, metrics = compute_metrics(y_pred, y_true, task)

          return cls(task=task, loss=loss, score=score, metrics=metrics)

-
+
  def compute_metrics(y_pred: np.ndarray, y_true: np.ndarray, task: Task) -> tuple[float, float, dict]:

-     match task:
-         case Task.CLASSIFICATION:
-             return compute_classification_metrics(y_pred, y_true)
-         case Task.REGRESSION:
-             return compute_regression_metrics(y_pred, y_true)
-
+     if task == Task.CLASSIFICATION:
+         return compute_classification_metrics(y_pred, y_true)
+     elif task == Task.REGRESSION:
+         return compute_regression_metrics(y_pred, y_true)
+

  def compute_classification_metrics(y_pred: np.ndarray, y_true: np.ndarray) -> tuple[float, float, dict]:
      # predictions are assumed to be log-probabilities
@@ -121,7 +119,7 @@ class PredictionMetricsTracker():
              y_true_ori = self.preprocessor.inverse_transform_y(y_true_np)
          else:
              y_true_ori = y_true_np
-
+
          self.ys_pred.append(y_pred_ori)
          self.ys_true.append(y_true_ori)

@@ -131,4 +129,4 @@
          y_pred = np.concatenate(self.ys_pred, axis=0)
          y_true = np.concatenate(self.ys_true, axis=0)

-         return PredictionMetrics.from_prediction(y_pred, y_true, self.task)
+         return PredictionMetrics.from_prediction(y_pred, y_true, self.task)
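For context, compute_metrics in this file dispatches on the task the same way before and after the rewrite. A small usage sketch, assuming this wheel is installed; the toy arrays are made up for illustration:

    import numpy as np

    from autogluon.tabular.models.mitra._internal.config.enums import Task
    from autogluon.tabular.models.mitra._internal.core.prediction_metrics import compute_metrics

    # Hypothetical regression targets and predictions.
    y_true = np.array([0.0, 1.0, 2.0, 3.0])
    y_pred = np.array([0.1, 0.9, 2.2, 2.8])

    loss, score, metrics = compute_metrics(y_pred, y_true, Task.REGRESSION)
    print(loss, score, metrics)  # loss/score are floats, metrics is a dict keyed by MetricName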
autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py
@@ -1,26 +1,26 @@
  import time
+
  import numpy as np
  import torch
  from loguru import logger
  from sklearn.base import BaseEstimator
- import torch.nn.functional as F

  from ..._internal.config.config_run import ConfigRun
+ from ..._internal.config.enums import LossName, MetricName, ModelName, Task
  from ..._internal.core.callbacks import Checkpoint, EarlyStopping
- from ..._internal.data.collator import CollatorWithPadding
- from ..._internal.config.enums import MetricName, ModelName, Task, LossName
  from ..._internal.core.get_loss import get_loss
- from ..._internal.core.get_optimizer import get_optimizer, GradScaler
+ from ..._internal.core.get_optimizer import GradScaler, get_optimizer
  from ..._internal.core.get_scheduler import get_scheduler
+ from ..._internal.core.prediction_metrics import PredictionMetrics, PredictionMetricsTracker
+ from ..._internal.data.collator import CollatorWithPadding
  from ..._internal.data.dataset_finetune import DatasetFinetune, DatasetFinetuneGenerator
  from ..._internal.data.preprocessor import Preprocessor
- from ..._internal.core.prediction_metrics import PredictionMetrics, PredictionMetricsTracker


  class TrainerFinetune(BaseEstimator):

      def __init__(
-         self,
+         self,
          cfg: ConfigRun,
          model: torch.nn.Module,
          n_classes: int,
@@ -31,7 +31,7 @@ class TrainerFinetune(BaseEstimator):
          self.device = device
          self.model = model.to(self.device, non_blocking=True)
          self.n_classes = n_classes
-
+
          self.loss = get_loss(self.cfg)
          self.optimizer = get_optimizer(self.cfg.hyperparams, self.model)
          self.scheduler_warmup, self.scheduler_reduce_on_plateau = get_scheduler(self.cfg.hyperparams, self.optimizer)
@@ -51,7 +51,7 @@ class TrainerFinetune(BaseEstimator):
              dim_output=self.cfg.hyperparams['dim_output'],
              use_quantile_transformer=self.cfg.hyperparams['use_quantile_transformer'],
              use_feature_count_scaling=self.cfg.hyperparams['use_feature_count_scaling'],
-             use_random_transforms=self.cfg.hyperparams['use_random_transforms'],
+             use_random_transforms=self.cfg.hyperparams['use_random_transforms'],
              shuffle_classes=self.cfg.hyperparams['shuffle_classes'],
              shuffle_features=self.cfg.hyperparams['shuffle_features'],
              random_mirror_x=self.cfg.hyperparams['random_mirror_x'],
@@ -62,19 +62,19 @@ class TrainerFinetune(BaseEstimator):
          self.checkpoint.reset(self.model)

          if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
-             self.bins = torch.linspace(-0.5, 1.5, self.cfg.hyperparams['dim_output']+1, device=cfg.device)
+             self.bins = torch.linspace(-0.5, 1.5, self.cfg.hyperparams['dim_output']+1, device=cfg.device)
              self.bin_width = self.bins[1] - self.bins[0]
-
+
          self.metric = self.cfg.hyperparams['metric']


      def train(self, x_train: np.ndarray, y_train: np.ndarray, x_val: np.ndarray, y_val: np.ndarray):

-         self.preprocessor.fit(x_train, y_train)
+         self.preprocessor.fit(x_train, y_train)

-         x_train_transformed = self.preprocessor.transform_X(x_train)
+         x_train_transformed = self.preprocessor.transform_X(x_train)
          y_train_transformed = self.preprocessor.transform_y(y_train)
-
+
          dataset_train_generator = DatasetFinetuneGenerator(
              self.cfg,
              x = x_train_transformed,
@@ -94,16 +94,16 @@ class TrainerFinetune(BaseEstimator):

          for epoch in range(1, self.cfg.hyperparams['max_epochs']+1):

-             dataset_train = next(dataset_train_generator)
+             dataset_train = next(dataset_train_generator)
              loader_train = self.make_loader(dataset_train, training=True)
              self.model.train()
-
+
              prediction_metrics_tracker = PredictionMetricsTracker(task=self.cfg.task, preprocessor=self.preprocessor)

              for batch in loader_train:
-
+
                  with torch.autocast(device_type=self.device, dtype=getattr(torch, self.cfg.hyperparams['precision'])):
-
+
                      x_support = batch['x_support'].to(self.device, non_blocking=True)
                      y_support = batch['y_support'].to(self.device, non_blocking=True)
                      x_query = batch['x_query'].to(self.device, non_blocking=True)
@@ -118,13 +118,12 @@ class TrainerFinetune(BaseEstimator):
                          y_support = torch.clamp(y_support, 0, self.cfg.hyperparams['dim_output']-1).to(torch.int64)
                          y_query_bin_ids = torch.bucketize(y_query, self.bins) - 1
                          y_query_bin_ids = torch.clamp(y_query_bin_ids, 0, self.cfg.hyperparams['dim_output']-1).to(torch.int64)
-
-                     match self.cfg.model_name:
-                         case ModelName.TABPFN:
-                             y_hat = self.model(x_support, y_support, x_query, task=self.cfg.task).squeeze(-1)
-                         case ModelName.TAB2D | ModelName.TAB2D_COL_ROW | ModelName.TAB2D_SDPA:
-                             y_hat = self.model(x_support, y_support, x_query, padding_features, padding_obs_support, padding_obs_query)
-
+
+                     if self.cfg.model_name == ModelName.TABPFN:
+                         y_hat = self.model(x_support, y_support, x_query, task=self.cfg.task).squeeze(-1)
+                     elif self.cfg.model_name in [ModelName.TAB2D, ModelName.TAB2D_COL_ROW, ModelName.TAB2D_SDPA]:
+                         y_hat = self.model(x_support, y_support, x_query, padding_features, padding_obs_support, padding_obs_query)
+
                      # Convert numerical y_query to bin ids
                      if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
                          loss = self.loss(y_hat, y_query_bin_ids)
@@ -143,7 +142,7 @@ class TrainerFinetune(BaseEstimator):

                  # Convert bin id predictions to numerical values
                  if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
-                     y_hat = torch.argmax(y_hat, dim=-1)
+                     y_hat = torch.argmax(y_hat, dim=-1)
                      y_hat = self.bins[y_hat] + self.bin_width / 2

                  y_hat = y_hat.float()
@@ -153,12 +152,12 @@ class TrainerFinetune(BaseEstimator):
                  prediction_metrics_tracker.update(y_hat, y_query, train=False)

              metrics_train = prediction_metrics_tracker.get_metrics()
-             metrics_valid = self.evaluate(x_train, y_train, x_val, y_val)
+             metrics_valid = self.evaluate(x_train, y_train, x_val, y_val)

              self.log_metrics(epoch, metrics_train, metrics_valid)

              self.checkpoint(self.model, metrics_valid.loss)
-
+
              self.early_stopping(metrics_valid.metrics[self.metric])
              if self.early_stopping.we_should_stop():
                  logger.info("Early stopping")
@@ -175,9 +174,9 @@ class TrainerFinetune(BaseEstimator):

          self.checkpoint.set_to_best(self.model)

-
+
      def evaluate(self, x_support: np.ndarray, y_support: np.ndarray, x_query: np.ndarray, y_query: np.ndarray) -> PredictionMetrics:
-
+
          self.model.eval()

          x_support_transformed = self.preprocessor.transform_X(x_support)
@@ -186,9 +185,9 @@ class TrainerFinetune(BaseEstimator):
          # y_query_transformed = self.preprocessor.transform_y(y_query)

          dataset = DatasetFinetune(
-             self.cfg,
-             x_support = x_support_transformed,
-             y_support = y_support_transformed,
+             self.cfg,
+             x_support = x_support_transformed,
+             y_support = y_support_transformed,
              x_query = x_query_transformed,
              y_query = y_query,
              max_samples_support = self.cfg.hyperparams['max_samples_support'],
@@ -202,7 +201,7 @@ class TrainerFinetune(BaseEstimator):
          for batch in loader:

              with torch.autocast(device_type=self.device, dtype=getattr(torch, self.cfg.hyperparams['precision'])):
-
+
                  x_s = batch['x_support'].to(self.device, non_blocking=True)
                  y_s = batch['y_support'].to(self.device, non_blocking=True)
                  x_q = batch['x_query'].to(self.device, non_blocking=True)
@@ -215,16 +214,15 @@ class TrainerFinetune(BaseEstimator):
                  if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
                      y_s = torch.bucketize(y_s, self.bins) - 1
                      y_s = torch.clamp(y_s, 0, self.cfg.hyperparams['dim_output']-1).to(torch.int64)
-
-                 match self.cfg.model_name:
-                     case ModelName.TABPFN:
-                         y_hat = self.model(x_s, y_s, x_q, task=self.cfg.task).squeeze(-1)
-                     case ModelName.TAB2D | ModelName.TAB2D_COL_ROW | ModelName.TAB2D_SDPA:
-                         y_hat = self.model(x_s, y_s, x_q, padding_features, padding_obs_support, padding_obs_query)
+
+                 if self.cfg.model_name == ModelName.TABPFN:
+                     y_hat = self.model(x_s, y_s, x_q, task=self.cfg.task).squeeze(-1)
+                 elif self.cfg.model_name in [ModelName.TAB2D, ModelName.TAB2D_COL_ROW, ModelName.TAB2D_SDPA]:
+                     y_hat = self.model(x_s, y_s, x_q, padding_features, padding_obs_support, padding_obs_query)

              # Convert bin id predictions to numerical values
              if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
-                 y_hat = torch.argmax(y_hat, dim=-1)
+                 y_hat = torch.argmax(y_hat, dim=-1)
                  y_hat = self.bins[y_hat] + self.bin_width / 2

              y_hat = y_hat.float()
@@ -232,7 +230,7 @@ class TrainerFinetune(BaseEstimator):

          metrics_eval = prediction_metrics_tracker.get_metrics()
          return metrics_eval
-
+

      def predict(self, x_support: np.ndarray, y_support: np.ndarray, x_query: np.ndarray) -> np.ndarray:

@@ -241,9 +239,9 @@ class TrainerFinetune(BaseEstimator):
          y_support_transformed = self.preprocessor.transform_y(y_support)

          dataset = DatasetFinetune(
-             self.cfg,
-             x_support = x_support_transformed,
-             y_support = y_support_transformed,
+             self.cfg,
+             x_support = x_support_transformed,
+             y_support = y_support_transformed,
              x_query = x_query_transformed,
              y_query = None,
              max_samples_support = self.cfg.hyperparams['max_samples_support'],
@@ -259,7 +257,7 @@ class TrainerFinetune(BaseEstimator):
          for batch in loader:

              with torch.autocast(device_type=self.device, dtype=getattr(torch, self.cfg.hyperparams['precision'])):
-
+
                  x_s = batch['x_support'].to(self.device, non_blocking=True)
                  y_s = batch['y_support'].to(self.device, non_blocking=True)
                  x_q = batch['x_query'].to(self.device, non_blocking=True)
@@ -271,18 +269,17 @@ class TrainerFinetune(BaseEstimator):
                  if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
                      y_s = torch.bucketize(y_s, self.bins) - 1
                      y_s = torch.clamp(y_s, 0, self.cfg.hyperparams['dim_output']-1).to(torch.int64)
-
-                 match self.cfg.model_name:
-                     case ModelName.TABPFN:
-                         y_hat = self.model(x_s, y_s, x_q, task=self.cfg.task).squeeze(-1)
-                     case ModelName.TAB2D | ModelName.TAB2D_COL_ROW | ModelName.TAB2D_SDPA:
-                         y_hat = self.model(x_s, y_s, x_q, padding_features, padding_obs_support, padding_obs_query)
-
+
+                 if self.cfg.model_name == ModelName.TABPFN:
+                     y_hat = self.model(x_s, y_s, x_q, task=self.cfg.task).squeeze(-1)
+                 elif self.cfg.model_name in [ModelName.TAB2D, ModelName.TAB2D_COL_ROW, ModelName.TAB2D_SDPA]:
+                     y_hat = self.model(x_s, y_s, x_q, padding_features, padding_obs_support, padding_obs_query)
+
              y_hat = y_hat[0].float().cpu().numpy()

              # Convert bin id predictions to numerical values
              if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
-                 y_hat = np.argmax(y_hat, axis=-1)
+                 y_hat = np.argmax(y_hat, axis=-1)
                  y_hat = (self.bins[y_hat] + self.bin_width / 2).cpu().numpy()

              y_hat = self.preprocessor.inverse_transform_y(y_hat)
@@ -291,21 +288,20 @@ class TrainerFinetune(BaseEstimator):
          y_pred = np.concatenate(y_pred_list, axis=0)

          return y_pred
-
+

      def load_params(self, path):
          self.model.load_state_dict(torch.load(path))
-
+

      def make_loader(self, dataset: torch.utils.data.Dataset, training: bool) -> torch.utils.data.DataLoader:

-         match self.cfg.model_name:
-             case ModelName.TABPFN:
-                 pad_to_max_features = True
-             case ModelName.TAB2D | ModelName.TAB2D_COL_ROW | ModelName.TAB2D_SDPA:
-                 pad_to_max_features = False
-             case _:
-                 raise NotImplementedError(f"Model {self.cfg.model_name} not implemented")
+         if self.cfg.model_name == ModelName.TABPFN:
+             pad_to_max_features = True
+         elif self.cfg.model_name in [ModelName.TAB2D, ModelName.TAB2D_COL_ROW, ModelName.TAB2D_SDPA]:
+             pad_to_max_features = False
+         else:
+             raise NotImplementedError(f"Model {self.cfg.model_name} not implemented")

          return torch.utils.data.DataLoader(
              dataset,
@@ -319,13 +315,12 @@ class TrainerFinetune(BaseEstimator):
                  pad_to_max_features=pad_to_max_features
              ),
          )
-
+

      def log_start_metrics(self, metrics_valid: PredictionMetrics):

-         match self.cfg.task:
-             case Task.REGRESSION:
-                 logger.info((
+         if self.cfg.task == Task.REGRESSION:
+             logger.info((
                  f"Epoch 000 "
                  f"| Train MSE: -.---- "
                  f"| Train MAE: -.---- "
@@ -334,21 +329,20 @@ class TrainerFinetune(BaseEstimator):
                  f"| Val MAE: {metrics_valid.metrics[MetricName.MAE]:.4f} "
                  f"| Val r2: {metrics_valid.metrics[MetricName.R2]:.4f}"
              ))
-             case Task.CLASSIFICATION:
-                 logger.info((
+
+         elif self.cfg.task == Task.CLASSIFICATION:
+             logger.info((
                  f"Epoch 000 "
                  f"| Train CE: -.---- "
                  f"| Train acc: -.---- "
                  f"| Val CE: {metrics_valid.metrics[MetricName.LOG_LOSS]:.4f} "
                  f"| Val acc: {metrics_valid.metrics[MetricName.ACCURACY]:.4f}"
              ))
-

      def log_metrics(self, epoch: int, metrics_train: PredictionMetrics, metrics_valid: PredictionMetrics):

-         match self.cfg.task:
-             case Task.REGRESSION:
-                 logger.info((
+         if self.cfg.task == Task.REGRESSION:
+             logger.info((
                  f"Epoch {epoch:03d} "
                  f"| Train MSE: {metrics_train.metrics[MetricName.MSE]:.4f} "
                  f"| Train MAE: {metrics_train.metrics[MetricName.MAE]:.4f} "
@@ -357,11 +351,11 @@ class TrainerFinetune(BaseEstimator):
                  f"| Val MAE: {metrics_valid.metrics[MetricName.MAE]:.4f} "
                  f"| Val r2: {metrics_valid.metrics[MetricName.R2]:.4f}"
              ))
-             case Task.CLASSIFICATION:
-                 logger.info((
+         elif self.cfg.task == Task.CLASSIFICATION:
+             logger.info((
                  f"Epoch {epoch:03d} "
                  f"| Train CE: {metrics_train.metrics[MetricName.LOG_LOSS]:.4f} "
                  f"| Train acc: {metrics_train.metrics[MetricName.ACCURACY]:.4f} "
                  f"| Val CE: {metrics_valid.metrics[MetricName.LOG_LOSS]:.4f} "
                  f"| Val acc: {metrics_valid.metrics[MetricName.ACCURACY]:.4f}"
-             ))
+             ))
autogluon/tabular/models/mitra/_internal/data/__init__.py
@@ -0,0 +1 @@
+ # Data processing modules for MitraModel