autogluon.tabular 1.3.2b20250610__py3-none-any.whl → 1.4.1b20251214__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. autogluon/tabular/configs/config_helper.py +1 -1
  2. autogluon/tabular/configs/hyperparameter_configs.py +2 -265
  3. autogluon/tabular/configs/pipeline_presets.py +130 -0
  4. autogluon/tabular/configs/presets_configs.py +51 -26
  5. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2023.py +0 -1
  6. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2025.py +310 -0
  7. autogluon/tabular/models/__init__.py +6 -1
  8. autogluon/tabular/models/_utils/rapids_utils.py +1 -1
  9. autogluon/tabular/models/automm/automm_model.py +2 -0
  10. autogluon/tabular/models/automm/ft_transformer.py +4 -1
  11. autogluon/tabular/models/catboost/callbacks.py +3 -2
  12. autogluon/tabular/models/catboost/catboost_model.py +15 -9
  13. autogluon/tabular/models/catboost/catboost_utils.py +17 -3
  14. autogluon/tabular/models/ebm/__init__.py +0 -0
  15. autogluon/tabular/models/ebm/ebm_model.py +259 -0
  16. autogluon/tabular/models/ebm/hyperparameters/__init__.py +0 -0
  17. autogluon/tabular/models/ebm/hyperparameters/parameters.py +39 -0
  18. autogluon/tabular/models/ebm/hyperparameters/searchspaces.py +72 -0
  19. autogluon/tabular/models/fastainn/tabular_nn_fastai.py +7 -5
  20. autogluon/tabular/models/knn/knn_model.py +7 -3
  21. autogluon/tabular/models/lgb/lgb_model.py +60 -21
  22. autogluon/tabular/models/lr/lr_model.py +6 -1
  23. autogluon/tabular/models/lr/lr_preprocessing_utils.py +6 -7
  24. autogluon/tabular/models/lr/lr_rapids_model.py +45 -5
  25. autogluon/tabular/models/mitra/__init__.py +0 -0
  26. autogluon/tabular/models/mitra/_internal/__init__.py +1 -0
  27. autogluon/tabular/models/mitra/_internal/config/__init__.py +1 -0
  28. autogluon/tabular/models/mitra/_internal/config/config_pretrain.py +190 -0
  29. autogluon/tabular/models/mitra/_internal/config/config_run.py +32 -0
  30. autogluon/tabular/models/mitra/_internal/config/enums.py +162 -0
  31. autogluon/tabular/models/mitra/_internal/core/__init__.py +1 -0
  32. autogluon/tabular/models/mitra/_internal/core/callbacks.py +94 -0
  33. autogluon/tabular/models/mitra/_internal/core/get_loss.py +54 -0
  34. autogluon/tabular/models/mitra/_internal/core/get_optimizer.py +108 -0
  35. autogluon/tabular/models/mitra/_internal/core/get_scheduler.py +67 -0
  36. autogluon/tabular/models/mitra/_internal/core/prediction_metrics.py +132 -0
  37. autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py +373 -0
  38. autogluon/tabular/models/mitra/_internal/data/__init__.py +1 -0
  39. autogluon/tabular/models/mitra/_internal/data/collator.py +46 -0
  40. autogluon/tabular/models/mitra/_internal/data/dataset_finetune.py +136 -0
  41. autogluon/tabular/models/mitra/_internal/data/dataset_split.py +57 -0
  42. autogluon/tabular/models/mitra/_internal/data/preprocessor.py +420 -0
  43. autogluon/tabular/models/mitra/_internal/models/__init__.py +1 -0
  44. autogluon/tabular/models/mitra/_internal/models/base.py +21 -0
  45. autogluon/tabular/models/mitra/_internal/models/embedding.py +182 -0
  46. autogluon/tabular/models/mitra/_internal/models/tab2d.py +667 -0
  47. autogluon/tabular/models/mitra/_internal/utils/__init__.py +1 -0
  48. autogluon/tabular/models/mitra/_internal/utils/set_seed.py +15 -0
  49. autogluon/tabular/models/mitra/mitra_model.py +380 -0
  50. autogluon/tabular/models/mitra/sklearn_interface.py +494 -0
  51. autogluon/tabular/models/realmlp/__init__.py +0 -0
  52. autogluon/tabular/models/realmlp/realmlp_model.py +360 -0
  53. autogluon/tabular/models/rf/rf_model.py +11 -6
  54. autogluon/tabular/models/tabicl/__init__.py +0 -0
  55. autogluon/tabular/models/tabicl/tabicl_model.py +179 -0
  56. autogluon/tabular/models/tabm/__init__.py +0 -0
  57. autogluon/tabular/models/tabm/_tabm_internal.py +545 -0
  58. autogluon/tabular/models/tabm/rtdl_num_embeddings.py +810 -0
  59. autogluon/tabular/models/tabm/tabm_model.py +356 -0
  60. autogluon/tabular/models/tabm/tabm_reference.py +631 -0
  61. autogluon/tabular/models/tabpfnmix/tabpfnmix_model.py +13 -7
  62. autogluon/tabular/models/tabpfnv2/__init__.py +0 -0
  63. autogluon/tabular/models/tabpfnv2/rfpfn/__init__.py +20 -0
  64. autogluon/tabular/models/tabpfnv2/rfpfn/configs.py +40 -0
  65. autogluon/tabular/models/tabpfnv2/rfpfn/scoring_utils.py +201 -0
  66. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_decision_tree_tabpfn.py +1464 -0
  67. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_random_forest_tabpfn.py +747 -0
  68. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_compat.py +863 -0
  69. autogluon/tabular/models/tabpfnv2/rfpfn/utils.py +106 -0
  70. autogluon/tabular/models/tabpfnv2/tabpfnv2_model.py +388 -0
  71. autogluon/tabular/models/tabular_nn/hyperparameters/parameters.py +1 -3
  72. autogluon/tabular/models/tabular_nn/torch/tabular_nn_torch.py +5 -5
  73. autogluon/tabular/models/xgboost/xgboost_model.py +10 -3
  74. autogluon/tabular/predictor/predictor.py +147 -84
  75. autogluon/tabular/registry/_ag_model_registry.py +12 -2
  76. autogluon/tabular/testing/fit_helper.py +57 -27
  77. autogluon/tabular/testing/generate_datasets.py +7 -0
  78. autogluon/tabular/trainer/abstract_trainer.py +3 -1
  79. autogluon/tabular/trainer/model_presets/presets.py +10 -1
  80. autogluon/tabular/version.py +1 -1
  81. autogluon.tabular-1.4.1b20251214-py3.11-nspkg.pth +1 -0
  82. {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/METADATA +112 -57
  83. {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/RECORD +89 -40
  84. {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/WHEEL +1 -1
  85. autogluon/tabular/models/tabpfn/__init__.py +0 -1
  86. autogluon/tabular/models/tabpfn/tabpfn_model.py +0 -153
  87. autogluon.tabular-1.3.2b20250610-py3.9-nspkg.pth +0 -1
  88. {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info/licenses}/LICENSE +0 -0
  89. {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info/licenses}/NOTICE +0 -0
  90. {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/namespace_packages.txt +0 -0
  91. {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/top_level.txt +0 -0
  92. {autogluon.tabular-1.3.2b20250610.dist-info → autogluon_tabular-1.4.1b20251214.dist-info}/zip-safe +0 -0
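This release adds several new model families (EBM, Mitra, RealMLP, TabICL, TabM, TabPFNv2) and removes the original TabPFN model. As a hedged illustration of how such models are typically enabled, the sketch below passes per-model entries to TabularPredictor.fit; the hyperparameter keys shown are assumptions, and the authoritative keys are defined in autogluon/tabular/registry/_ag_model_registry.py.

# Illustrative sketch, not part of the diff. Model keys below are assumed;
# check autogluon/tabular/registry/_ag_model_registry.py for the exact names.
from autogluon.tabular import TabularPredictor

# train_data: a pandas DataFrame containing a "class" label column (assumed to exist).
predictor = TabularPredictor(label="class").fit(
    train_data,
    hyperparameters={
        "EBM": {},      # Explainable Boosting Machine (new in this release)
        "MITRA": {},    # Mitra tabular foundation model (new in this release)
        "TABM": {},     # TabM (new in this release)
        "REALMLP": {},  # RealMLP (new in this release)
    },
)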
autogluon/tabular/models/mitra/_internal/core/get_optimizer.py
@@ -0,0 +1,108 @@
+import torch
+from torch.optim import SGD, Adam, AdamW
+
+from ..._internal.config.config_pretrain import ConfigPretrain
+
+
+def get_optimizer(hyperparams: dict, model: torch.nn.Module) -> torch.optim.Optimizer:
+
+    optimizer: torch.optim.Optimizer
+
+    if hyperparams['optimizer'] == "adam":
+        optimizer = Adam(
+            model.parameters(),
+            lr=hyperparams['lr'],
+            betas=(0.9, 0.999),
+            weight_decay=hyperparams['weight_decay']
+        )
+    elif hyperparams['optimizer'] == "adamw":
+        optimizer = AdamW(
+            model.parameters(),
+            lr=hyperparams['lr'],
+            betas=(0.9, 0.999),
+            weight_decay=hyperparams['weight_decay']
+        )
+    elif hyperparams['optimizer'] == "sgd":
+        optimizer = SGD(
+            model.parameters(),
+            lr=hyperparams['lr'],
+            weight_decay=hyperparams['weight_decay']
+        )
+    else:
+        raise ValueError("Optimizer not recognized")
+
+    return optimizer
+
+
+def get_optimizer_pretrain(cfg: ConfigPretrain, model: torch.nn.Module) -> torch.optim.Optimizer:
+
+    parameters = [(name, param) for name, param in model.named_parameters()]
+
+    parameters_with_weight_decay = []
+    parameters_without_weight_decay = []
+
+    for name, param in parameters:
+        if name.endswith("bias") or 'norm' in name or 'embedding' in name:
+            parameters_without_weight_decay.append(param)
+        else:
+            parameters_with_weight_decay.append(param)
+
+    optimizer_parameters = [
+        {"params": parameters_with_weight_decay, "weight_decay": cfg.optim.weight_decay},
+        {"params": parameters_without_weight_decay, "weight_decay": 0.0},
+    ]
+
+    optimizer = torch.optim.AdamW(
+        optimizer_parameters,
+        lr=cfg.optim.lr,
+        betas=(cfg.optim.beta1, cfg.optim.beta2),
+        weight_decay=cfg.optim.weight_decay
+    )
+
+    return optimizer
+
+
+class GradScaler(torch.amp.GradScaler):
+
+    def __init__(
+        self,
+        enabled: bool = True,
+        scale_init: float = 2.**16,
+        scale_min: float = 1.,
+        growth_interval: int = 2000,
+        device: str = 'cuda'
+    ):
+        super().__init__(enabled=enabled, device="cpu", init_scale=scale_init, growth_interval=growth_interval) # type: ignore
+        self._enabled = enabled
+        self.scale_min = scale_min
+        self.device = device
+
+        if not self._enabled:
+            # We write scale=1 to log if the scaler is disabled
+            self._scale = torch.tensor((1,), dtype=torch.float32, device=self.device)
+
+
+    def update(self):
+
+        if not self._enabled:
+            return
+
+        super().update()
+
+        if self._scale < self.scale_min:
+            super().update(self.scale_min)
+
+
+def move_optimizer_to(optim, device):
+    for param in optim.state.values():
+        # Not sure there are any global tensors in the state dict
+        if isinstance(param, torch.Tensor):
+            param.data = param.data.to(device)
+            if param._grad is not None:
+                param._grad.data = param._grad.data.to(device)
+        elif isinstance(param, dict):
+            for subparam in param.values():
+                if isinstance(subparam, torch.Tensor):
+                    subparam.data = subparam.data.to(device)
+                    if subparam._grad is not None:
+                        subparam._grad.data = subparam._grad.data.to(device)
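For orientation, get_optimizer above only consumes the 'optimizer', 'lr', and 'weight_decay' keys of the hyperparameter dict. A minimal usage sketch (not part of the diff), using a placeholder model:

import torch
from autogluon.tabular.models.mitra._internal.core.get_optimizer import get_optimizer

model = torch.nn.Linear(8, 2)  # placeholder model for illustration
hyperparams = {"optimizer": "adamw", "lr": 1e-4, "weight_decay": 0.01}
optimizer = get_optimizer(hyperparams, model)  # returns an AdamW optimizer configured from the dict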
autogluon/tabular/models/mitra/_internal/core/get_scheduler.py
@@ -0,0 +1,67 @@
+import torch
+from torch.optim.lr_scheduler import ReduceLROnPlateau, LinearLR
+from transformers import get_constant_schedule_with_warmup
+from transformers.optimization import get_cosine_with_min_lr_schedule_with_warmup
+
+from ..._internal.config.config_pretrain import ConfigPretrain
+
+
+def get_scheduler(hyperparams: dict, optimizer: torch.optim.Optimizer) -> tuple[torch.optim.lr_scheduler.LambdaLR, ReduceLROnPlateau]:
+
+    warmup_steps = hyperparams['warmup_steps']
+
+    # if warmup_steps > 0:
+    #     scheduler_warmup = torch.optim.lr_scheduler.LambdaLR(
+    #         optimizer, lambda step: min((step + 1) / warmup_steps, 1.0)
+    #     )
+    # else:
+    #     scheduler_warmup = torch.optim.lr_scheduler.LambdaLR(
+    #         optimizer, lambda step: 1.0
+    #     )
+
+    if warmup_steps > 0:
+        scheduler_warmup = LinearLR(
+            optimizer,
+            start_factor=1.0 / warmup_steps,
+            end_factor=1.0,
+            total_iters=warmup_steps,
+        )
+    else:
+        scheduler_warmup = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0, total_iters=1)
+
+    if hyperparams['lr_scheduler']:
+        scheduler_reduce_on_plateau = ReduceLROnPlateau(
+            optimizer,
+            patience=hyperparams['lr_scheduler_patience'],
+            min_lr=0,
+            factor=0.2
+        )
+    else:
+        # With ReduceLROnPlateau, the scheduler accepts a metric to monitor, so our dummy metric must also be a ReduceLRonPlateau scheduler
+        scheduler_reduce_on_plateau = ReduceLROnPlateau(
+            optimizer,
+            patience=1000000000,
+            min_lr=0,
+            factor=0.2
+        )
+
+    return scheduler_warmup, scheduler_reduce_on_plateau
+
+
+def get_scheduler_pretrain(cfg: ConfigPretrain, optimizer: torch.optim.Optimizer):
+
+
+    if cfg.optim.cosine_scheduler:
+        schedule = get_cosine_with_min_lr_schedule_with_warmup(
+            optimizer,
+            num_warmup_steps=cfg.optim.warmup_steps,
+            num_training_steps=cfg.optim.steps,
+            min_lr_rate=0.1
+        )
+    else:
+        schedule = get_constant_schedule_with_warmup(
+            optimizer,
+            num_warmup_steps=cfg.optim.warmup_steps
+        )
+
+    return schedule
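get_scheduler returns a warmup scheduler plus a ReduceLROnPlateau scheduler that are stepped in different phases of training. The sketch below (not part of the diff) mirrors how TrainerFinetune.train drives the pair further down in this diff; train_one_epoch stands in for an assumed training routine that returns the validation loss.

import torch
from autogluon.tabular.models.mitra._internal.core.get_optimizer import get_optimizer
from autogluon.tabular.models.mitra._internal.core.get_scheduler import get_scheduler

model = torch.nn.Linear(8, 2)  # placeholder model for illustration
hyperparams = {
    "optimizer": "adamw", "lr": 1e-4, "weight_decay": 0.01,
    "warmup_steps": 5, "lr_scheduler": True, "lr_scheduler_patience": 3,
}
optimizer = get_optimizer(hyperparams, model)
scheduler_warmup, scheduler_plateau = get_scheduler(hyperparams, optimizer)

for epoch in range(1, 31):
    val_loss = train_one_epoch()  # assumed routine returning the validation loss
    if epoch < hyperparams["warmup_steps"]:
        scheduler_warmup.step()           # linear warmup phase
    else:
        scheduler_plateau.step(val_loss)  # reduce LR when validation loss plateaus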
autogluon/tabular/models/mitra/_internal/core/prediction_metrics.py
@@ -0,0 +1,132 @@
+from dataclasses import dataclass
+
+import numpy as np
+import scipy.special
+import torch
+from loguru import logger
+from sklearn.metrics import f1_score, mean_squared_error, r2_score, roc_auc_score, root_mean_squared_error
+
+from ..._internal.config.enums import MetricName, Task
+from ..._internal.data.preprocessor import Preprocessor
+
+
+@dataclass
+class PredictionMetrics():
+    task: Task
+    loss: float
+    score: float
+    metrics: dict[MetricName, float]
+
+
+    @classmethod
+    def from_prediction(cls, y_pred: np.ndarray, y_true: np.ndarray, task: Task) -> "PredictionMetrics":
+
+        loss, score, metrics = compute_metrics(y_pred, y_true, task)
+
+        return cls(task=task, loss=loss, score=score, metrics=metrics)
+
+
+def compute_metrics(y_pred: np.ndarray, y_true: np.ndarray, task: Task) -> tuple[float, float, dict]:
+
+    if task == Task.CLASSIFICATION:
+        return compute_classification_metrics(y_pred, y_true)
+    elif task == Task.REGRESSION:
+        return compute_regression_metrics(y_pred, y_true)
+
+
+def compute_classification_metrics(y_pred: np.ndarray, y_true: np.ndarray) -> tuple[float, float, dict]:
+    # predictions are assumed to be log-probabilities
+
+    y_pred_class = np.argmax(y_pred, axis=1)
+    y_pred_proba = scipy.special.softmax(y_pred, axis=1)
+    y_pred_proba = y_pred_proba / y_pred_proba.sum(axis=1, keepdims=True) # softmax not completely numerically stable, so a small correction is needed
+    labels = np.arange(y_pred_proba.shape[1])
+
+    metrics = {
+        MetricName.ACCURACY: (y_true == y_pred_class).mean(),
+        MetricName.F1: f1_score(y_true, y_pred_class, average="weighted"),
+        MetricName.AUC: roc_auc_score_multiclass(y_true, y_pred_proba, multi_class='ovo', average='macro', labels=labels),
+        MetricName.LOG_LOSS: torch.nn.functional.cross_entropy(torch.from_numpy(y_pred), torch.from_numpy(y_true)).item()
+    }
+
+    loss = metrics[MetricName.LOG_LOSS]
+    score = metrics[MetricName.ACCURACY]
+
+    return loss, score, metrics
+
+
+def roc_auc_score_multiclass(y_true, y_pred_proba, multi_class='ovo', average='macro', labels=None) -> float:
+    """
+    The roc_auc_score multi_class is not supported for binary classification
+    """
+
+    if np.unique(y_true).shape[0] == 1:
+        # AUC is not defined if there is only one class
+        return float('nan')
+
+    try:
+        if y_pred_proba.shape[1] == 2:
+            return roc_auc_score(y_true, y_pred_proba[:, 1])
+        else:
+            return roc_auc_score(y_true, y_pred_proba, multi_class=multi_class, average=average, labels=labels)
+    except ValueError as e:
+        logger.error(f"Error computing roc_auc_score: {e}")
+        logger.error(f"Returning {-1}")
+        return -1
+
+
+def compute_regression_metrics(y_pred: np.ndarray, y_true: np.ndarray) -> tuple[float, float, dict]:
+
+    metrics = {
+        MetricName.RMSE: root_mean_squared_error(y_true, y_pred),
+        MetricName.MSE: mean_squared_error(y_true, y_pred),
+        MetricName.MAE: np.abs(y_true - y_pred).mean(),
+        MetricName.R2: r2_score(y_true, y_pred)
+    }
+
+    loss = metrics[MetricName.MSE]
+    score = metrics[MetricName.R2]
+
+    return loss, score, metrics
+
+
+class PredictionMetricsTracker():
+    """
+    Prediction metrics tracker that accumulates predictions and true values to compute metrics at the end.
+    Uses torch.Tensor for predictions and true values.
+    """
+
+    def __init__(self, task: Task, preprocessor: Preprocessor) -> None:
+
+        self.task = task
+        self.preprocessor = preprocessor
+        self.reset()
+
+
+    def reset(self) -> None:
+
+        self.ys_pred: list[np.ndarray] = []
+        self.ys_true: list[np.ndarray] = []
+
+
+    def update(self, y_pred: torch.Tensor, y_true: torch.Tensor, train: bool) -> None:
+
+        y_pred_np = y_pred.detach().cpu().numpy()[0]
+        y_pred_ori = self.preprocessor.inverse_transform_y(y_pred_np)
+
+        y_true_np = y_true.detach().cpu().numpy()[0]
+        if train:
+            y_true_ori = self.preprocessor.inverse_transform_y(y_true_np)
+        else:
+            y_true_ori = y_true_np
+
+        self.ys_pred.append(y_pred_ori)
+        self.ys_true.append(y_true_ori)
+
+
+    def get_metrics(self) -> PredictionMetrics:
+
+        y_pred = np.concatenate(self.ys_pred, axis=0)
+        y_true = np.concatenate(self.ys_true, axis=0)
+
+        return PredictionMetrics.from_prediction(y_pred, y_true, self.task)
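PredictionMetrics.from_prediction expects raw per-class logits for classification (softmax and argmax are applied internally) and untransformed predictions for regression. An illustrative sketch, not part of the diff:

import numpy as np
from autogluon.tabular.models.mitra._internal.config.enums import MetricName, Task
from autogluon.tabular.models.mitra._internal.core.prediction_metrics import PredictionMetrics

y_pred = np.array([[2.0, -1.0], [0.1, 0.3], [-0.5, 1.5]])  # logits: 3 rows, 2 classes
y_true = np.array([0, 1, 1])
pm = PredictionMetrics.from_prediction(y_pred, y_true, Task.CLASSIFICATION)
print(pm.loss, pm.score, pm.metrics[MetricName.AUC])  # log loss, accuracy, AUC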
autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py
@@ -0,0 +1,373 @@
+import time
+
+import numpy as np
+import torch
+from loguru import logger
+from sklearn.base import BaseEstimator
+
+from ..._internal.config.config_run import ConfigRun
+from ..._internal.config.enums import LossName, MetricName, ModelName, Task
+from ..._internal.core.callbacks import Checkpoint, EarlyStopping
+from ..._internal.core.get_loss import get_loss
+from ..._internal.core.get_optimizer import GradScaler, get_optimizer
+from ..._internal.core.get_scheduler import get_scheduler
+from ..._internal.core.prediction_metrics import PredictionMetrics, PredictionMetricsTracker
+from ..._internal.data.collator import CollatorWithPadding
+from ..._internal.data.dataset_finetune import DatasetFinetune, DatasetFinetuneGenerator
+from ..._internal.data.preprocessor import Preprocessor
+
+
+class TrainerFinetune(BaseEstimator):
+
+    def __init__(
+        self,
+        cfg: ConfigRun,
+        model: torch.nn.Module,
+        n_classes: int,
+        device: str,
+        rng: np.random.RandomState = None,
+        verbose: bool = True,
+    ):
+
+        self.cfg = cfg
+        if rng is None:
+            rng = np.random.RandomState(self.cfg.seed)
+        self.rng = rng
+        self.verbose = verbose
+        self.device = device
+        self.model = model.to(self.device, non_blocking=True)
+        self.n_classes = n_classes
+
+        self.loss = get_loss(self.cfg)
+        self.optimizer = get_optimizer(self.cfg.hyperparams, self.model)
+        self.scheduler_warmup, self.scheduler_reduce_on_plateau = get_scheduler(self.cfg.hyperparams, self.optimizer)
+        self.scaler = GradScaler(
+            enabled=self.cfg.hyperparams['grad_scaler_enabled'],
+            scale_init=self.cfg.hyperparams['grad_scaler_scale_init'],
+            scale_min=self.cfg.hyperparams['grad_scaler_scale_min'],
+            growth_interval=self.cfg.hyperparams['grad_scaler_growth_interval'],
+            device=self.device
+        )
+
+        self.early_stopping = EarlyStopping(patience=self.cfg.hyperparams['early_stopping_patience'])
+        self.checkpoint = Checkpoint()
+        self.preprocessor = Preprocessor(
+            dim_embedding=self.cfg.hyperparams['dim_embedding'],
+            n_classes=self.n_classes,
+            dim_output=self.cfg.hyperparams['dim_output'],
+            use_quantile_transformer=self.cfg.hyperparams['use_quantile_transformer'],
+            use_feature_count_scaling=self.cfg.hyperparams['use_feature_count_scaling'],
+            use_random_transforms=self.cfg.hyperparams['use_random_transforms'],
+            shuffle_classes=self.cfg.hyperparams['shuffle_classes'],
+            shuffle_features=self.cfg.hyperparams['shuffle_features'],
+            random_mirror_x=self.cfg.hyperparams['random_mirror_x'],
+            random_mirror_regression=self.cfg.hyperparams['random_mirror_regression'],
+            task=self.cfg.task
+        )
+
+        self.checkpoint.reset(self.model)
+
+        if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
+            self.bins = torch.linspace(-0.5, 1.5, self.cfg.hyperparams['dim_output']+1, device=cfg.device)
+            self.bin_width = self.bins[1] - self.bins[0]
+
+        self.metric = self.cfg.hyperparams['metric']
+
+
+    def train(self, x_train: np.ndarray, y_train: np.ndarray, x_val: np.ndarray, y_val: np.ndarray):
+
+        self.preprocessor.fit(x_train, y_train)
+
+        x_train_transformed = self.preprocessor.transform_X(x_train)
+        y_train_transformed = self.preprocessor.transform_y(y_train)
+
+        dataset_train_generator = DatasetFinetuneGenerator(
+            self.cfg,
+            x = x_train_transformed,
+            y = y_train_transformed,
+            task = self.cfg.task,
+            max_samples_support = self.cfg.hyperparams['max_samples_support'],
+            max_samples_query = self.cfg.hyperparams['max_samples_query'],
+            rng=self.rng,
+        )
+
+        self.checkpoint.reset(self.model)
+
+        metrics_valid = self.evaluate(x_train, y_train, x_val, y_val)
+        if self.verbose:
+            self.log_start_metrics(metrics_valid)
+        self.checkpoint(self.model, metrics_valid.loss)
+
+        start_time = time.time()
+
+        for epoch in range(1, self.cfg.hyperparams['max_epochs']+1):
+
+            dataset_train = next(dataset_train_generator)
+            loader_train = self.make_loader(dataset_train, training=True)
+            self.model.train()
+
+            prediction_metrics_tracker = PredictionMetricsTracker(task=self.cfg.task, preprocessor=self.preprocessor)
+
+            for batch in loader_train:
+
+                with torch.autocast(device_type=self.device, dtype=getattr(torch, self.cfg.hyperparams['precision'])):
+
+                    x_support = batch['x_support'].to(self.device, non_blocking=True)
+                    y_support = batch['y_support'].to(self.device, non_blocking=True)
+                    x_query = batch['x_query'].to(self.device, non_blocking=True)
+                    y_query = batch['y_query'].to(self.device, non_blocking=True)
+                    padding_features = batch['padding_features'].to(self.device, non_blocking=True)
+                    padding_obs_support = batch['padding_obs_support'].to(self.device, non_blocking=True)
+                    padding_obs_query = batch['padding_obs_query'].to(self.device, non_blocking=True)
+
+                    # Convert numerical y_support to bin ids
+                    if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
+                        y_support = torch.bucketize(y_support, self.bins) - 1
+                        y_support = torch.clamp(y_support, 0, self.cfg.hyperparams['dim_output']-1).to(torch.int64)
+                        y_query_bin_ids = torch.bucketize(y_query, self.bins) - 1
+                        y_query_bin_ids = torch.clamp(y_query_bin_ids, 0, self.cfg.hyperparams['dim_output']-1).to(torch.int64)
+
+                    if self.cfg.model_name == ModelName.TABPFN:
+                        y_hat = self.model(x_support, y_support, x_query, task=self.cfg.task).squeeze(-1)
+                    elif self.cfg.model_name in [ModelName.TAB2D, ModelName.TAB2D_COL_ROW, ModelName.TAB2D_SDPA]:
+                        y_hat = self.model(x_support, y_support, x_query, padding_features, padding_obs_support, padding_obs_query)
+
+                    # Convert numerical y_query to bin ids
+                    if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
+                        loss = self.loss(y_hat, y_query_bin_ids)
+                    elif self.cfg.task == Task.CLASSIFICATION:
+                        # for b in range(y_support.shape[0]):
+                        #     unique_classes = len(torch.unique(torch.cat((y_support[b], y_query[b]))))
+                        #     y_hat[b, :, unique_classes:] = 0
+                        loss = self.loss(y_hat, y_query)
+                    else:
+                        loss = self.loss(y_hat, y_query)
+
+                self.optimizer.zero_grad()
+                self.scaler.scale(loss).backward()
+                self.scaler.step(self.optimizer)
+                self.scaler.update()
+
+                # Convert bin id predictions to numerical values
+                if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
+                    y_hat = torch.argmax(y_hat, dim=-1)
+                    y_hat = self.bins[y_hat] + self.bin_width / 2
+
+                y_hat = y_hat.float()
+                if self.cfg.task == Task.REGRESSION:
+                    prediction_metrics_tracker.update(y_hat, y_query, train=True)
+                else:
+                    prediction_metrics_tracker.update(y_hat, y_query, train=False)
+
+            metrics_train = prediction_metrics_tracker.get_metrics()
+            metrics_valid = self.evaluate(x_train, y_train, x_val, y_val)
+
+            if self.verbose:
+                self.log_metrics(epoch, metrics_train, metrics_valid)
+
+            self.checkpoint(self.model, metrics_valid.loss)
+
+            self.early_stopping(metrics_valid.metrics[self.metric])
+            if self.early_stopping.we_should_stop():
+                if self.verbose:
+                    logger.info("Early stopping")
+                break
+
+            if self.cfg.hyperparams["budget"] is not None and self.cfg.hyperparams["budget"] > 0 and time.time() - start_time > self.cfg.hyperparams["budget"]:
+                logger.info("Time limit reached")
+                break
+
+            if epoch < self.cfg.hyperparams['warmup_steps']:
+                self.scheduler_warmup.step()
+            else:
+                self.scheduler_reduce_on_plateau.step(metrics_valid.loss)
+
+        self.checkpoint.set_to_best(self.model)
+
+
+    def evaluate(self, x_support: np.ndarray, y_support: np.ndarray, x_query: np.ndarray, y_query: np.ndarray) -> PredictionMetrics:
+
+        self.model.eval()
+
+        x_support_transformed = self.preprocessor.transform_X(x_support)
+        x_query_transformed = self.preprocessor.transform_X(x_query)
+        y_support_transformed = self.preprocessor.transform_y(y_support)
+        # y_query_transformed = self.preprocessor.transform_y(y_query)
+
+        dataset = DatasetFinetune(
+            self.cfg,
+            x_support = x_support_transformed,
+            y_support = y_support_transformed,
+            x_query = x_query_transformed,
+            y_query = y_query,
+            max_samples_support = self.cfg.hyperparams['max_samples_support'],
+            max_samples_query = self.cfg.hyperparams['max_samples_query'],
+            rng=self.rng,
+        )
+
+        loader = self.make_loader(dataset, training=False)
+        prediction_metrics_tracker = PredictionMetricsTracker(task=self.cfg.task, preprocessor=self.preprocessor)
+
+        with torch.no_grad():
+            for batch in loader:
+
+                with torch.autocast(device_type=self.device, dtype=getattr(torch, self.cfg.hyperparams['precision'])):
+
+                    x_s = batch['x_support'].to(self.device, non_blocking=True)
+                    y_s = batch['y_support'].to(self.device, non_blocking=True)
+                    x_q = batch['x_query'].to(self.device, non_blocking=True)
+                    y_q = batch['y_query'].to(self.device, non_blocking=True)
+                    padding_features = batch['padding_features'].to(self.device, non_blocking=True)
+                    padding_obs_support = batch['padding_obs_support'].to(self.device, non_blocking=True)
+                    padding_obs_query = batch['padding_obs_query'].to(self.device, non_blocking=True)
+
+                    # Convert numerical y_support to bin ids
+                    if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
+                        y_s = torch.bucketize(y_s, self.bins) - 1
+                        y_s = torch.clamp(y_s, 0, self.cfg.hyperparams['dim_output']-1).to(torch.int64)
+
+                    if self.cfg.model_name == ModelName.TABPFN:
+                        y_hat = self.model(x_s, y_s, x_q, task=self.cfg.task).squeeze(-1)
+                    elif self.cfg.model_name in [ModelName.TAB2D, ModelName.TAB2D_COL_ROW, ModelName.TAB2D_SDPA]:
+                        y_hat = self.model(x_s, y_s, x_q, padding_features, padding_obs_support, padding_obs_query)
+
+                # Convert bin id predictions to numerical values
+                if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
+                    y_hat = torch.argmax(y_hat, dim=-1)
+                    y_hat = self.bins[y_hat] + self.bin_width / 2
+
+                y_hat = y_hat.float()
+                prediction_metrics_tracker.update(y_hat, y_q, train=False)
+
+        metrics_eval = prediction_metrics_tracker.get_metrics()
+        return metrics_eval
+
+
+    def predict(self, x_support: np.ndarray, y_support: np.ndarray, x_query: np.ndarray) -> np.ndarray:
+
+        x_support_transformed = self.preprocessor.transform_X(x_support)
+        x_query_transformed = self.preprocessor.transform_X(x_query)
+        y_support_transformed = self.preprocessor.transform_y(y_support)
+
+        dataset = DatasetFinetune(
+            self.cfg,
+            x_support = x_support_transformed,
+            y_support = y_support_transformed,
+            x_query = x_query_transformed,
+            y_query = None,
+            max_samples_support = self.cfg.hyperparams['max_samples_support'],
+            max_samples_query = self.cfg.hyperparams['max_samples_query'],
+            rng=self.rng,
+        )
+
+        loader = self.make_loader(dataset, training=False)
+        self.model.eval()
+
+        y_pred_list = []
+
+        with torch.no_grad():
+            for batch in loader:
+
+                with torch.autocast(device_type=self.device, dtype=getattr(torch, self.cfg.hyperparams['precision'])):
+
+                    x_s = batch['x_support'].to(self.device, non_blocking=True)
+                    y_s = batch['y_support'].to(self.device, non_blocking=True)
+                    x_q = batch['x_query'].to(self.device, non_blocking=True)
+                    padding_features = batch['padding_features'].to(self.device, non_blocking=True)
+                    padding_obs_support = batch['padding_obs_support'].to(self.device, non_blocking=True)
+                    padding_obs_query = batch['padding_obs_query'].to(self.device, non_blocking=True)
+
+                    # Convert numerical y_support to bin ids
+                    if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
+                        y_s = torch.bucketize(y_s, self.bins) - 1
+                        y_s = torch.clamp(y_s, 0, self.cfg.hyperparams['dim_output']-1).to(torch.int64)
+
+                    if self.cfg.model_name == ModelName.TABPFN:
+                        y_hat = self.model(x_s, y_s, x_q, task=self.cfg.task).squeeze(-1)
+                    elif self.cfg.model_name in [ModelName.TAB2D, ModelName.TAB2D_COL_ROW, ModelName.TAB2D_SDPA]:
+                        y_hat = self.model(x_s, y_s, x_q, padding_features, padding_obs_support, padding_obs_query)
+
+                y_hat = y_hat[0].float().cpu().numpy()
+
+                # Convert bin id predictions to numerical values
+                if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
+                    y_hat = np.argmax(y_hat, axis=-1)
+                    y_hat = (self.bins[y_hat] + self.bin_width / 2).cpu().numpy()
+
+                y_hat = self.preprocessor.inverse_transform_y(y_hat)
+                y_pred_list.append(y_hat)
+
+        y_pred = np.concatenate(y_pred_list, axis=0)
+
+        return y_pred
+
+
+    def load_params(self, path):
+        self.model.load_state_dict(torch.load(path))
+
+
+    def make_loader(self, dataset: torch.utils.data.Dataset, training: bool) -> torch.utils.data.DataLoader:
+
+        if self.cfg.model_name == ModelName.TABPFN:
+            pad_to_max_features = True
+        elif self.cfg.model_name in [ModelName.TAB2D, ModelName.TAB2D_COL_ROW, ModelName.TAB2D_SDPA]:
+            pad_to_max_features = False
+        else:
+            raise NotImplementedError(f"Model {self.cfg.model_name} not implemented")
+
+        return torch.utils.data.DataLoader(
+            dataset,
+            batch_size=1,
+            shuffle=training,
+            pin_memory=True,
+            num_workers=0,
+            drop_last=False,
+            collate_fn=CollatorWithPadding(
+                max_features=self.cfg.hyperparams['dim_embedding'],
+                pad_to_max_features=pad_to_max_features
+            ),
+        )
+
+
+    def log_start_metrics(self, metrics_valid: PredictionMetrics):
+
+        if self.cfg.task == Task.REGRESSION:
+            logger.info((
+                f"Epoch 000 "
+                f"| Train MSE: -.---- "
+                f"| Train MAE: -.---- "
+                f"| Train r2: -.---- "
+                f"| Val MSE: {metrics_valid.metrics[MetricName.MSE]:.4f} "
+                f"| Val MAE: {metrics_valid.metrics[MetricName.MAE]:.4f} "
+                f"| Val r2: {metrics_valid.metrics[MetricName.R2]:.4f}"
+            ))
+
+        elif self.cfg.task == Task.CLASSIFICATION:
+            logger.info((
+                f"Epoch 000 "
+                f"| Train CE: -.---- "
+                f"| Train acc: -.---- "
+                f"| Val CE: {metrics_valid.metrics[MetricName.LOG_LOSS]:.4f} "
+                f"| Val acc: {metrics_valid.metrics[MetricName.ACCURACY]:.4f}"
+            ))
+
+    def log_metrics(self, epoch: int, metrics_train: PredictionMetrics, metrics_valid: PredictionMetrics):
+
+        if self.cfg.task == Task.REGRESSION:
+            logger.info((
+                f"Epoch {epoch:03d} "
+                f"| Train MSE: {metrics_train.metrics[MetricName.MSE]:.4f} "
+                f"| Train MAE: {metrics_train.metrics[MetricName.MAE]:.4f} "
+                f"| Train r2: {metrics_train.metrics[MetricName.R2]:.4f} "
+                f"| Val MSE: {metrics_valid.metrics[MetricName.MSE]:.4f} "
+                f"| Val MAE: {metrics_valid.metrics[MetricName.MAE]:.4f} "
+                f"| Val r2: {metrics_valid.metrics[MetricName.R2]:.4f}"
+            ))
+        elif self.cfg.task == Task.CLASSIFICATION:
+            logger.info((
+                f"Epoch {epoch:03d} "
+                f"| Train CE: {metrics_train.metrics[MetricName.LOG_LOSS]:.4f} "
+                f"| Train acc: {metrics_train.metrics[MetricName.ACCURACY]:.4f} "
+                f"| Val CE: {metrics_valid.metrics[MetricName.LOG_LOSS]:.4f} "
+                f"| Val acc: {metrics_valid.metrics[MetricName.ACCURACY]:.4f}"
+            ))
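TrainerFinetune ties the pieces together: train() fits the preprocessor, fine-tunes with gradient scaling, warmup and plateau scheduling, early stopping, and a time budget, then restores the best checkpoint; predict() runs support/query episodes over the fitted model. A hedged sketch of the driver flow, assuming cfg (a ConfigRun) and model have been constructed as in mitra/sklearn_interface.py and mitra_model.py elsewhere in this diff:

from autogluon.tabular.models.mitra._internal.core.trainer_finetune import TrainerFinetune

# x_train, y_train, x_val, y_val, x_test: numpy arrays matching the signatures above (assumed to exist).
trainer = TrainerFinetune(cfg=cfg, model=model, n_classes=2, device="cuda")
trainer.train(x_train, y_train, x_val, y_val)        # fine-tune; best checkpoint is restored at the end
y_pred = trainer.predict(x_train, y_train, x_test)   # support = training rows, query = test rows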
autogluon/tabular/models/mitra/_internal/data/__init__.py
@@ -0,0 +1 @@
+# Data processing modules for MitraModel