autogluon.tabular 1.5.1b20260105__py3-none-any.whl → 1.5.1b20260117__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of autogluon.tabular has been flagged as potentially problematic on the registry.
Files changed (135)
  1. autogluon/tabular/__init__.py +1 -0
  2. autogluon/tabular/configs/config_helper.py +18 -6
  3. autogluon/tabular/configs/feature_generator_presets.py +3 -1
  4. autogluon/tabular/configs/hyperparameter_configs.py +42 -9
  5. autogluon/tabular/configs/presets_configs.py +38 -14
  6. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2023.py +84 -14
  7. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2025.py +48 -48
  8. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_cpu_2025_12_18.py +774 -1
  9. autogluon/tabular/configs/zeroshot/zeroshot_portfolio_gpu_2025_12_18.py +421 -1
  10. autogluon/tabular/experimental/_scikit_mixin.py +6 -2
  11. autogluon/tabular/experimental/_tabular_classifier.py +3 -1
  12. autogluon/tabular/experimental/_tabular_regressor.py +3 -1
  13. autogluon/tabular/experimental/plot_leaderboard.py +73 -19
  14. autogluon/tabular/learner/abstract_learner.py +160 -42
  15. autogluon/tabular/learner/default_learner.py +78 -22
  16. autogluon/tabular/models/__init__.py +2 -2
  17. autogluon/tabular/models/_utils/rapids_utils.py +3 -1
  18. autogluon/tabular/models/abstract/abstract_torch_model.py +2 -0
  19. autogluon/tabular/models/automm/automm_model.py +12 -3
  20. autogluon/tabular/models/automm/ft_transformer.py +5 -1
  21. autogluon/tabular/models/catboost/callbacks.py +2 -2
  22. autogluon/tabular/models/catboost/catboost_model.py +93 -29
  23. autogluon/tabular/models/catboost/catboost_softclass_utils.py +4 -1
  24. autogluon/tabular/models/catboost/catboost_utils.py +3 -1
  25. autogluon/tabular/models/ebm/ebm_model.py +8 -13
  26. autogluon/tabular/models/ebm/hyperparameters/parameters.py +1 -0
  27. autogluon/tabular/models/ebm/hyperparameters/searchspaces.py +1 -0
  28. autogluon/tabular/models/fastainn/callbacks.py +20 -3
  29. autogluon/tabular/models/fastainn/hyperparameters/searchspaces.py +11 -1
  30. autogluon/tabular/models/fastainn/quantile_helpers.py +10 -2
  31. autogluon/tabular/models/fastainn/tabular_nn_fastai.py +65 -18
  32. autogluon/tabular/models/fasttext/fasttext_model.py +3 -1
  33. autogluon/tabular/models/image_prediction/image_predictor.py +7 -2
  34. autogluon/tabular/models/knn/knn_model.py +41 -8
  35. autogluon/tabular/models/lgb/callbacks.py +32 -9
  36. autogluon/tabular/models/lgb/hyperparameters/searchspaces.py +3 -1
  37. autogluon/tabular/models/lgb/lgb_model.py +150 -34
  38. autogluon/tabular/models/lgb/lgb_utils.py +12 -4
  39. autogluon/tabular/models/lr/hyperparameters/searchspaces.py +5 -1
  40. autogluon/tabular/models/lr/lr_model.py +40 -10
  41. autogluon/tabular/models/lr/lr_rapids_model.py +22 -13
  42. autogluon/tabular/models/mitra/_internal/__init__.py +1 -1
  43. autogluon/tabular/models/mitra/_internal/config/__init__.py +1 -1
  44. autogluon/tabular/models/mitra/_internal/config/config_pretrain.py +36 -40
  45. autogluon/tabular/models/mitra/_internal/config/config_run.py +2 -14
  46. autogluon/tabular/models/mitra/_internal/config/enums.py +27 -26
  47. autogluon/tabular/models/mitra/_internal/core/__init__.py +1 -1
  48. autogluon/tabular/models/mitra/_internal/core/callbacks.py +14 -21
  49. autogluon/tabular/models/mitra/_internal/core/get_loss.py +10 -12
  50. autogluon/tabular/models/mitra/_internal/core/get_optimizer.py +17 -32
  51. autogluon/tabular/models/mitra/_internal/core/get_scheduler.py +12 -27
  52. autogluon/tabular/models/mitra/_internal/core/prediction_metrics.py +16 -21
  53. autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py +130 -111
  54. autogluon/tabular/models/mitra/_internal/data/__init__.py +1 -1
  55. autogluon/tabular/models/mitra/_internal/data/collator.py +30 -26
  56. autogluon/tabular/models/mitra/_internal/data/dataset_finetune.py +18 -26
  57. autogluon/tabular/models/mitra/_internal/data/dataset_split.py +10 -7
  58. autogluon/tabular/models/mitra/_internal/data/preprocessor.py +70 -100
  59. autogluon/tabular/models/mitra/_internal/models/__init__.py +1 -1
  60. autogluon/tabular/models/mitra/_internal/models/base.py +7 -10
  61. autogluon/tabular/models/mitra/_internal/models/embedding.py +46 -56
  62. autogluon/tabular/models/mitra/_internal/models/tab2d.py +140 -120
  63. autogluon/tabular/models/mitra/_internal/utils/__init__.py +1 -1
  64. autogluon/tabular/models/mitra/_internal/utils/set_seed.py +3 -1
  65. autogluon/tabular/models/mitra/mitra_model.py +16 -11
  66. autogluon/tabular/models/mitra/sklearn_interface.py +178 -162
  67. autogluon/tabular/models/realmlp/realmlp_model.py +28 -15
  68. autogluon/tabular/models/rf/compilers/onnx.py +1 -1
  69. autogluon/tabular/models/rf/rf_model.py +45 -12
  70. autogluon/tabular/models/rf/rf_quantile.py +4 -2
  71. autogluon/tabular/models/tabdpt/tabdpt_model.py +8 -17
  72. autogluon/tabular/models/tabicl/tabicl_model.py +8 -1
  73. autogluon/tabular/models/tabm/_tabm_internal.py +6 -4
  74. autogluon/tabular/models/tabm/rtdl_num_embeddings.py +80 -127
  75. autogluon/tabular/models/tabm/tabm_model.py +8 -4
  76. autogluon/tabular/models/tabm/tabm_reference.py +53 -85
  77. autogluon/tabular/models/tabpfnmix/_internal/core/callbacks.py +7 -16
  78. autogluon/tabular/models/tabpfnmix/_internal/core/collator.py +16 -24
  79. autogluon/tabular/models/tabpfnmix/_internal/core/dataset_split.py +5 -7
  80. autogluon/tabular/models/tabpfnmix/_internal/core/enums.py +0 -2
  81. autogluon/tabular/models/tabpfnmix/_internal/core/get_loss.py +0 -1
  82. autogluon/tabular/models/tabpfnmix/_internal/core/get_optimizer.py +7 -18
  83. autogluon/tabular/models/tabpfnmix/_internal/core/get_scheduler.py +3 -14
  84. autogluon/tabular/models/tabpfnmix/_internal/core/trainer_finetune.py +79 -64
  85. autogluon/tabular/models/tabpfnmix/_internal/core/y_transformer.py +3 -5
  86. autogluon/tabular/models/tabpfnmix/_internal/data/dataset_finetune.py +17 -30
  87. autogluon/tabular/models/tabpfnmix/_internal/data/preprocessor.py +15 -35
  88. autogluon/tabular/models/tabpfnmix/_internal/models/foundation/embedding.py +21 -38
  89. autogluon/tabular/models/tabpfnmix/_internal/models/foundation/foundation_transformer.py +33 -51
  90. autogluon/tabular/models/tabpfnmix/_internal/results/prediction_metrics.py +4 -4
  91. autogluon/tabular/models/tabpfnmix/_internal/tabpfnmix_classifier.py +32 -12
  92. autogluon/tabular/models/tabpfnmix/_internal/tabpfnmix_regressor.py +32 -13
  93. autogluon/tabular/models/tabpfnmix/tabpfnmix_model.py +55 -19
  94. autogluon/tabular/models/tabpfnv2/tabpfnv2_5_model.py +21 -48
  95. autogluon/tabular/models/tabprep/prep_mixin.py +34 -26
  96. autogluon/tabular/models/tabular_nn/compilers/onnx.py +36 -8
  97. autogluon/tabular/models/tabular_nn/torch/tabular_nn_torch.py +130 -36
  98. autogluon/tabular/models/tabular_nn/torch/tabular_torch_dataset.py +8 -4
  99. autogluon/tabular/models/tabular_nn/torch/torch_network_modules.py +26 -5
  100. autogluon/tabular/models/tabular_nn/utils/categorical_encoders.py +41 -24
  101. autogluon/tabular/models/tabular_nn/utils/data_preprocessor.py +33 -8
  102. autogluon/tabular/models/tabular_nn/utils/nn_architecture_utils.py +21 -6
  103. autogluon/tabular/models/xgboost/callbacks.py +9 -3
  104. autogluon/tabular/models/xgboost/xgboost_model.py +59 -11
  105. autogluon/tabular/models/xt/xt_model.py +1 -0
  106. autogluon/tabular/predictor/interpretable_predictor.py +3 -1
  107. autogluon/tabular/predictor/predictor.py +409 -128
  108. autogluon/tabular/registry/__init__.py +1 -1
  109. autogluon/tabular/registry/_ag_model_registry.py +4 -5
  110. autogluon/tabular/registry/_model_registry.py +1 -0
  111. autogluon/tabular/testing/fit_helper.py +55 -15
  112. autogluon/tabular/testing/generate_datasets.py +1 -1
  113. autogluon/tabular/testing/model_fit_helper.py +10 -4
  114. autogluon/tabular/trainer/abstract_trainer.py +644 -230
  115. autogluon/tabular/trainer/auto_trainer.py +19 -8
  116. autogluon/tabular/trainer/model_presets/presets.py +33 -9
  117. autogluon/tabular/trainer/model_presets/presets_distill.py +16 -2
  118. autogluon/tabular/version.py +1 -1
  119. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260117.dist-info}/METADATA +27 -27
  120. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260117.dist-info}/RECORD +127 -135
  121. autogluon/tabular/models/tabpfnv2/rfpfn/__init__.py +0 -20
  122. autogluon/tabular/models/tabpfnv2/rfpfn/configs.py +0 -40
  123. autogluon/tabular/models/tabpfnv2/rfpfn/scoring_utils.py +0 -201
  124. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_decision_tree_tabpfn.py +0 -1464
  125. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_random_forest_tabpfn.py +0 -747
  126. autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_compat.py +0 -863
  127. autogluon/tabular/models/tabpfnv2/rfpfn/utils.py +0 -106
  128. autogluon/tabular/models/tabpfnv2/tabpfnv2_model.py +0 -466
  129. /autogluon.tabular-1.5.1b20260105-py3.11-nspkg.pth → /autogluon.tabular-1.5.1b20260117-py3.11-nspkg.pth +0 -0
  130. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260117.dist-info}/WHEEL +0 -0
  131. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260117.dist-info}/licenses/LICENSE +0 -0
  132. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260117.dist-info}/licenses/NOTICE +0 -0
  133. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260117.dist-info}/namespace_packages.txt +0 -0
  134. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260117.dist-info}/top_level.txt +0 -0
  135. {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260117.dist-info}/zip-safe +0 -0
autogluon/tabular/models/mitra/_internal/core/prediction_metrics.py

@@ -11,23 +11,20 @@ from ..._internal.data.preprocessor import Preprocessor
 
 
 @dataclass
-class PredictionMetrics():
+class PredictionMetrics:
     task: Task
     loss: float
     score: float
     metrics: dict[MetricName, float]
 
-
     @classmethod
     def from_prediction(cls, y_pred: np.ndarray, y_true: np.ndarray, task: Task) -> "PredictionMetrics":
-
         loss, score, metrics = compute_metrics(y_pred, y_true, task)
 
         return cls(task=task, loss=loss, score=score, metrics=metrics)
 
 
 def compute_metrics(y_pred: np.ndarray, y_true: np.ndarray, task: Task) -> tuple[float, float, dict]:
-
     if task == Task.CLASSIFICATION:
         return compute_classification_metrics(y_pred, y_true)
     elif task == Task.REGRESSION:
@@ -39,14 +36,20 @@ def compute_classification_metrics(y_pred: np.ndarray, y_true: np.ndarray) -> tuple[float, float, dict]:
 
     y_pred_class = np.argmax(y_pred, axis=1)
     y_pred_proba = scipy.special.softmax(y_pred, axis=1)
-    y_pred_proba = y_pred_proba / y_pred_proba.sum(axis=1, keepdims=True) # softmax not completely numerically stable, so a small correction is needed
+    y_pred_proba = y_pred_proba / y_pred_proba.sum(
+        axis=1, keepdims=True
+    )  # softmax not completely numerically stable, so a small correction is needed
     labels = np.arange(y_pred_proba.shape[1])
 
     metrics = {
         MetricName.ACCURACY: (y_true == y_pred_class).mean(),
         MetricName.F1: f1_score(y_true, y_pred_class, average="weighted"),
-        MetricName.AUC: roc_auc_score_multiclass(y_true, y_pred_proba, multi_class='ovo', average='macro', labels=labels),
-        MetricName.LOG_LOSS: torch.nn.functional.cross_entropy(torch.from_numpy(y_pred), torch.from_numpy(y_true)).item()
+        MetricName.AUC: roc_auc_score_multiclass(
+            y_true, y_pred_proba, multi_class="ovo", average="macro", labels=labels
+        ),
+        MetricName.LOG_LOSS: torch.nn.functional.cross_entropy(
+            torch.from_numpy(y_pred), torch.from_numpy(y_true)
+        ).item(),
     }
 
     loss = metrics[MetricName.LOG_LOSS]
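The renormalization above exists because softmax in finite precision can produce rows that do not sum to exactly 1, which strict probability checks (for example in scikit-learn's multiclass roc_auc_score) may reject. A minimal standalone sketch of the same correction, independent of the AutoGluon code:

import numpy as np
import scipy.special

logits = np.random.default_rng(0).normal(size=(4, 3)).astype(np.float32)
proba = scipy.special.softmax(logits, axis=1)
proba = proba / proba.sum(axis=1, keepdims=True)  # nudge each row so it sums to 1 up to rounding
print(np.abs(proba.sum(axis=1) - 1.0).max())      # tiny residual error after the correction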
@@ -55,14 +58,14 @@ def compute_classification_metrics(y_pred: np.ndarray, y_true: np.ndarray) -> tuple[float, float, dict]:
     return loss, score, metrics
 
 
-def roc_auc_score_multiclass(y_true, y_pred_proba, multi_class='ovo', average='macro', labels=None) -> float:
-    """
+def roc_auc_score_multiclass(y_true, y_pred_proba, multi_class="ovo", average="macro", labels=None) -> float:
+    """
     The roc_auc_score multi_class is not supported for binary classification
     """
 
     if np.unique(y_true).shape[0] == 1:
         # AUC is not defined if there is only one class
-        return float('nan')
+        return float("nan")
 
     try:
         if y_pred_proba.shape[1] == 2:
@@ -76,12 +79,11 @@ def roc_auc_score_multiclass(y_true, y_pred_proba, multi_class='ovo', average='macro', labels=None) -> float:
 
 
 def compute_regression_metrics(y_pred: np.ndarray, y_true: np.ndarray) -> tuple[float, float, dict]:
-
     metrics = {
         MetricName.RMSE: root_mean_squared_error(y_true, y_pred),
         MetricName.MSE: mean_squared_error(y_true, y_pred),
         MetricName.MAE: np.abs(y_true - y_pred).mean(),
-        MetricName.R2: r2_score(y_true, y_pred)
+        MetricName.R2: r2_score(y_true, y_pred),
     }
 
     loss = metrics[MetricName.MSE]
@@ -90,27 +92,22 @@ def compute_regression_metrics(y_pred: np.ndarray, y_true: np.ndarray) -> tuple[float, float, dict]:
     return loss, score, metrics
 
 
-class PredictionMetricsTracker():
+class PredictionMetricsTracker:
     """
     Prediction metrics tracker that accumulates predictions and true values to compute metrics at the end.
     Uses torch.Tensor for predictions and true values.
     """
 
     def __init__(self, task: Task, preprocessor: Preprocessor) -> None:
-
         self.task = task
         self.preprocessor = preprocessor
         self.reset()
 
-
     def reset(self) -> None:
-
         self.ys_pred: list[np.ndarray] = []
         self.ys_true: list[np.ndarray] = []
 
-
     def update(self, y_pred: torch.Tensor, y_true: torch.Tensor, train: bool) -> None:
-
         y_pred_np = y_pred.detach().cpu().numpy()[0]
         y_pred_ori = self.preprocessor.inverse_transform_y(y_pred_np)
 
@@ -123,10 +120,8 @@ class PredictionMetricsTracker():
         self.ys_pred.append(y_pred_ori)
         self.ys_true.append(y_true_ori)
 
-
     def get_metrics(self) -> PredictionMetrics:
-
         y_pred = np.concatenate(self.ys_pred, axis=0)
         y_true = np.concatenate(self.ys_true, axis=0)
 
-        return PredictionMetrics.from_prediction(y_pred, y_true, self.task)
+        return PredictionMetrics.from_prediction(y_pred, y_true, self.task)
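The tracker above follows an accumulate-then-compute pattern: per-batch predictions and targets are stored as NumPy arrays and concatenated once when metrics are requested. A self-contained miniature of the same pattern (not the AutoGluon class), using RMSE as the only metric:

import numpy as np

class TinyTracker:
    def __init__(self) -> None:
        self.ys_pred: list[np.ndarray] = []
        self.ys_true: list[np.ndarray] = []

    def update(self, y_pred: np.ndarray, y_true: np.ndarray) -> None:
        self.ys_pred.append(y_pred)   # store per-batch arrays instead of running statistics
        self.ys_true.append(y_true)

    def get_rmse(self) -> float:
        y_pred = np.concatenate(self.ys_pred, axis=0)
        y_true = np.concatenate(self.ys_true, axis=0)
        return float(np.sqrt(np.mean((y_true - y_pred) ** 2)))

tracker = TinyTracker()
for lo in range(0, 100, 25):                      # pretend these are validation batches
    y = np.arange(lo, lo + 25, dtype=float)
    tracker.update(y_pred=y + 0.1, y_true=y)
print(tracker.get_rmse())                         # ~0.1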
autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py

@@ -18,17 +18,15 @@ from ..._internal.data.preprocessor import Preprocessor
 
 
 class TrainerFinetune(BaseEstimator):
-
     def __init__(
-            self,
-            cfg: ConfigRun,
-            model: torch.nn.Module,
-            n_classes: int,
-            device: str,
-            rng: np.random.RandomState = None,
-            verbose: bool = True,
+        self,
+        cfg: ConfigRun,
+        model: torch.nn.Module,
+        n_classes: int,
+        device: str,
+        rng: np.random.RandomState = None,
+        verbose: bool = True,
     ):
-
         self.cfg = cfg
         if rng is None:
             rng = np.random.RandomState(self.cfg.seed)
@@ -42,36 +40,36 @@ class TrainerFinetune(BaseEstimator):
         self.optimizer = get_optimizer(self.cfg.hyperparams, self.model)
         self.scheduler_warmup, self.scheduler_reduce_on_plateau = get_scheduler(self.cfg.hyperparams, self.optimizer)
         self.scaler = GradScaler(
-            enabled=self.cfg.hyperparams['grad_scaler_enabled'],
-            scale_init=self.cfg.hyperparams['grad_scaler_scale_init'],
-            scale_min=self.cfg.hyperparams['grad_scaler_scale_min'],
-            growth_interval=self.cfg.hyperparams['grad_scaler_growth_interval'],
-            device=self.device
+            enabled=self.cfg.hyperparams["grad_scaler_enabled"],
+            scale_init=self.cfg.hyperparams["grad_scaler_scale_init"],
+            scale_min=self.cfg.hyperparams["grad_scaler_scale_min"],
+            growth_interval=self.cfg.hyperparams["grad_scaler_growth_interval"],
+            device=self.device,
         )
 
-        self.early_stopping = EarlyStopping(patience=self.cfg.hyperparams['early_stopping_patience'])
+        self.early_stopping = EarlyStopping(patience=self.cfg.hyperparams["early_stopping_patience"])
         self.checkpoint = Checkpoint()
         self.preprocessor = Preprocessor(
-            dim_embedding=self.cfg.hyperparams['dim_embedding'],
+            dim_embedding=self.cfg.hyperparams["dim_embedding"],
             n_classes=self.n_classes,
-            dim_output=self.cfg.hyperparams['dim_output'],
-            use_quantile_transformer=self.cfg.hyperparams['use_quantile_transformer'],
-            use_feature_count_scaling=self.cfg.hyperparams['use_feature_count_scaling'],
-            use_random_transforms=self.cfg.hyperparams['use_random_transforms'],
-            shuffle_classes=self.cfg.hyperparams['shuffle_classes'],
-            shuffle_features=self.cfg.hyperparams['shuffle_features'],
-            random_mirror_x=self.cfg.hyperparams['random_mirror_x'],
-            random_mirror_regression=self.cfg.hyperparams['random_mirror_regression'],
-            task=self.cfg.task
+            dim_output=self.cfg.hyperparams["dim_output"],
+            use_quantile_transformer=self.cfg.hyperparams["use_quantile_transformer"],
+            use_feature_count_scaling=self.cfg.hyperparams["use_feature_count_scaling"],
+            use_random_transforms=self.cfg.hyperparams["use_random_transforms"],
+            shuffle_classes=self.cfg.hyperparams["shuffle_classes"],
+            shuffle_features=self.cfg.hyperparams["shuffle_features"],
+            random_mirror_x=self.cfg.hyperparams["random_mirror_x"],
+            random_mirror_regression=self.cfg.hyperparams["random_mirror_regression"],
+            task=self.cfg.task,
         )
 
         self.checkpoint.reset(self.model)
 
-        if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
-            self.bins = torch.linspace(-0.5, 1.5, self.cfg.hyperparams['dim_output']+1, device=cfg.device)
+        if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams["regression_loss"] == LossName.CROSS_ENTROPY:
+            self.bins = torch.linspace(-0.5, 1.5, self.cfg.hyperparams["dim_output"] + 1, device=cfg.device)
             self.bin_width = self.bins[1] - self.bins[0]
 
-        self.metric = self.cfg.hyperparams['metric']
+        self.metric = self.cfg.hyperparams["metric"]
 
     def set_device(self, device: str):
         self.device = device
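The constructor above wires a gradient scaler driven by the grad_scaler_* hyperparameters, and the train/evaluate/predict methods below run the forward pass under torch.autocast. A minimal, self-contained sketch of that mixed-precision step using the standard torch.amp API (torch >= 2.3 is assumed; the package's GradScaler wrapper and its scale_init/scale_min knobs are not reproduced here):

import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(8, 1).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
# With enabled=False the scaler is a no-op, so the same code also runs on CPU.
scaler = torch.amp.GradScaler(device, init_scale=2.0**16, growth_interval=2000, enabled=(device == "cuda"))

x, y = torch.randn(32, 8, device=device), torch.randn(32, 1, device=device)
with torch.autocast(device_type=device, dtype=torch.float16 if device == "cuda" else torch.bfloat16):
    loss = torch.nn.functional.mse_loss(model(x), y)

scaler.scale(loss).backward()   # scale the loss so fp16 gradients do not underflow
scaler.step(optimizer)          # unscales first; skips the step if inf/NaN gradients appear
scaler.update()                 # adapts the scale for the next iteration
optimizer.zero_grad(set_to_none=True)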
@@ -89,7 +87,6 @@ class TrainerFinetune(BaseEstimator):
         self.metric = None
 
     def train(self, x_train: np.ndarray, y_train: np.ndarray, x_val: np.ndarray, y_val: np.ndarray):
-
         self.preprocessor.fit(x_train, y_train)
 
         x_train_transformed = self.preprocessor.transform_X(x_train)
@@ -97,11 +94,11 @@
 
         dataset_train_generator = DatasetFinetuneGenerator(
             self.cfg,
-            x = x_train_transformed,
-            y = y_train_transformed,
-            task = self.cfg.task,
-            max_samples_support = self.cfg.hyperparams['max_samples_support'],
-            max_samples_query = self.cfg.hyperparams['max_samples_query'],
+            x=x_train_transformed,
+            y=y_train_transformed,
+            task=self.cfg.task,
+            max_samples_support=self.cfg.hyperparams["max_samples_support"],
+            max_samples_query=self.cfg.hyperparams["max_samples_query"],
             rng=self.rng,
         )
 
@@ -114,8 +111,7 @@
 
         start_time = time.time()
 
-        for epoch in range(1, self.cfg.hyperparams['max_epochs']+1):
-
+        for epoch in range(1, self.cfg.hyperparams["max_epochs"] + 1):
             dataset_train = next(dataset_train_generator)
             loader_train = self.make_loader(dataset_train, training=True)
             self.model.train()
@@ -123,31 +119,39 @@ class TrainerFinetune(BaseEstimator):
             prediction_metrics_tracker = PredictionMetricsTracker(task=self.cfg.task, preprocessor=self.preprocessor)
 
             for batch in loader_train:
-
-                with torch.autocast(device_type=self.device, dtype=getattr(torch, self.cfg.hyperparams['precision'])):
-
-                    x_support = batch['x_support'].to(self.device, non_blocking=True)
-                    y_support = batch['y_support'].to(self.device, non_blocking=True)
-                    x_query = batch['x_query'].to(self.device, non_blocking=True)
-                    y_query = batch['y_query'].to(self.device, non_blocking=True)
-                    padding_features = batch['padding_features'].to(self.device, non_blocking=True)
-                    padding_obs_support = batch['padding_obs_support'].to(self.device, non_blocking=True)
-                    padding_obs_query = batch['padding_obs_query'].to(self.device, non_blocking=True)
+                with torch.autocast(device_type=self.device, dtype=getattr(torch, self.cfg.hyperparams["precision"])):
+                    x_support = batch["x_support"].to(self.device, non_blocking=True)
+                    y_support = batch["y_support"].to(self.device, non_blocking=True)
+                    x_query = batch["x_query"].to(self.device, non_blocking=True)
+                    y_query = batch["y_query"].to(self.device, non_blocking=True)
+                    padding_features = batch["padding_features"].to(self.device, non_blocking=True)
+                    padding_obs_support = batch["padding_obs_support"].to(self.device, non_blocking=True)
+                    padding_obs_query = batch["padding_obs_query"].to(self.device, non_blocking=True)
 
                     # Convert numerical y_support to bin ids
-                    if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
+                    if (
+                        self.cfg.task == Task.REGRESSION
+                        and self.cfg.hyperparams["regression_loss"] == LossName.CROSS_ENTROPY
+                    ):
                         y_support = torch.bucketize(y_support, self.bins) - 1
-                        y_support = torch.clamp(y_support, 0, self.cfg.hyperparams['dim_output']-1).to(torch.int64)
+                        y_support = torch.clamp(y_support, 0, self.cfg.hyperparams["dim_output"] - 1).to(torch.int64)
                         y_query_bin_ids = torch.bucketize(y_query, self.bins) - 1
-                        y_query_bin_ids = torch.clamp(y_query_bin_ids, 0, self.cfg.hyperparams['dim_output']-1).to(torch.int64)
+                        y_query_bin_ids = torch.clamp(y_query_bin_ids, 0, self.cfg.hyperparams["dim_output"] - 1).to(
+                            torch.int64
+                        )
 
                     if self.cfg.model_name == ModelName.TABPFN:
                         y_hat = self.model(x_support, y_support, x_query, task=self.cfg.task).squeeze(-1)
                     elif self.cfg.model_name in [ModelName.TAB2D, ModelName.TAB2D_COL_ROW, ModelName.TAB2D_SDPA]:
-                        y_hat = self.model(x_support, y_support, x_query, padding_features, padding_obs_support, padding_obs_query)
+                        y_hat = self.model(
+                            x_support, y_support, x_query, padding_features, padding_obs_support, padding_obs_query
+                        )
 
                     # Convert numerical y_query to bin ids
-                    if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
+                    if (
+                        self.cfg.task == Task.REGRESSION
+                        and self.cfg.hyperparams["regression_loss"] == LossName.CROSS_ENTROPY
+                    ):
                         loss = self.loss(y_hat, y_query_bin_ids)
                     elif self.cfg.task == Task.CLASSIFICATION:
                         # for b in range(y_support.shape[0]):
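Several hunks in this file implement regression through a cross-entropy loss: continuous targets are bucketized into dim_output bins over a fixed range, and bin-id predictions are mapped back to bin centres. A self-contained sketch of that round trip (bin count and target values are illustrative):

import torch

dim_output = 8
bins = torch.linspace(-0.5, 1.5, dim_output + 1)    # dim_output equal-width bins spanning [-0.5, 1.5]
bin_width = bins[1] - bins[0]

y = torch.tensor([0.03, 0.42, 0.97])
y_bin = torch.clamp(torch.bucketize(y, bins) - 1, 0, dim_output - 1).to(torch.int64)

# Pretend the model returned logits over the bins; argmax plus the bin centre recovers a value.
logits = torch.nn.functional.one_hot(y_bin, dim_output).float()
y_hat = bins[torch.argmax(logits, dim=-1)] + bin_width / 2
print(y_bin.tolist(), y_hat.tolist())               # reconstructed values lie within bin_width / 2 of y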
@@ -163,7 +167,10 @@ class TrainerFinetune(BaseEstimator):
                 self.scaler.update()
 
                 # Convert bin id predictions to numerical values
-                if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
+                if (
+                    self.cfg.task == Task.REGRESSION
+                    and self.cfg.hyperparams["regression_loss"] == LossName.CROSS_ENTROPY
+                ):
                     y_hat = torch.argmax(y_hat, dim=-1)
                     y_hat = self.bins[y_hat] + self.bin_width / 2
 
@@ -187,19 +194,24 @@
                 logger.info("Early stopping")
                 break
 
-            if self.cfg.hyperparams["budget"] is not None and self.cfg.hyperparams["budget"] > 0 and time.time() - start_time > self.cfg.hyperparams["budget"]:
+            if (
+                self.cfg.hyperparams["budget"] is not None
+                and self.cfg.hyperparams["budget"] > 0
+                and time.time() - start_time > self.cfg.hyperparams["budget"]
+            ):
                 logger.info("Time limit reached")
                 break
 
-            if epoch < self.cfg.hyperparams['warmup_steps']:
+            if epoch < self.cfg.hyperparams["warmup_steps"]:
                 self.scheduler_warmup.step()
             else:
                 self.scheduler_reduce_on_plateau.step(metrics_valid.loss)
 
         self.checkpoint.set_to_best(self.model)
 
-    def evaluate(self, x_support: np.ndarray, y_support: np.ndarray, x_query: np.ndarray, y_query: np.ndarray) -> PredictionMetrics:
-
+    def evaluate(
+        self, x_support: np.ndarray, y_support: np.ndarray, x_query: np.ndarray, y_query: np.ndarray
+    ) -> PredictionMetrics:
        self.model.eval()
 
        x_support_transformed = self.preprocessor.transform_X(x_support)
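The epoch loop above steps a warmup scheduler for the first warmup_steps epochs and afterwards lets a reduce-on-plateau scheduler react to the validation loss. A minimal sketch of that hand-off with standard torch schedulers (LinearLR and ReduceLROnPlateau are assumptions used for illustration; the actual schedulers come from get_scheduler in the package):

import torch

model = torch.nn.Linear(4, 1)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
warmup_steps = 5
scheduler_warmup = torch.optim.lr_scheduler.LinearLR(optimizer, start_factor=0.1, total_iters=warmup_steps)
scheduler_plateau = torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", factor=0.5, patience=3)

for epoch in range(1, 21):
    loss = model(torch.randn(16, 4)).pow(2).mean()  # placeholder training step
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()
    val_loss = 1.0 / epoch                          # placeholder validation loss
    if epoch < warmup_steps:
        scheduler_warmup.step()                     # ramp the learning rate up first
    else:
        scheduler_plateau.step(val_loss)            # then decay it when validation stops improving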
@@ -209,12 +221,12 @@
 
         dataset = DatasetFinetune(
             self.cfg,
-            x_support = x_support_transformed,
-            y_support = y_support_transformed,
-            x_query = x_query_transformed,
-            y_query = y_query,
-            max_samples_support = self.cfg.hyperparams['max_samples_support'],
-            max_samples_query = self.cfg.hyperparams['max_samples_query'],
+            x_support=x_support_transformed,
+            y_support=y_support_transformed,
+            x_query=x_query_transformed,
+            y_query=y_query,
+            max_samples_support=self.cfg.hyperparams["max_samples_support"],
+            max_samples_query=self.cfg.hyperparams["max_samples_query"],
             rng=self.rng,
         )
 
@@ -223,21 +235,22 @@
         with torch.no_grad():
             for batch in loader:
-
-                with torch.autocast(device_type=self.device, dtype=getattr(torch, self.cfg.hyperparams['precision'])):
-
-                    x_s = batch['x_support'].to(self.device, non_blocking=True)
-                    y_s = batch['y_support'].to(self.device, non_blocking=True)
-                    x_q = batch['x_query'].to(self.device, non_blocking=True)
-                    y_q = batch['y_query'].to(self.device, non_blocking=True)
-                    padding_features = batch['padding_features'].to(self.device, non_blocking=True)
-                    padding_obs_support = batch['padding_obs_support'].to(self.device, non_blocking=True)
-                    padding_obs_query = batch['padding_obs_query'].to(self.device, non_blocking=True)
+                with torch.autocast(device_type=self.device, dtype=getattr(torch, self.cfg.hyperparams["precision"])):
+                    x_s = batch["x_support"].to(self.device, non_blocking=True)
+                    y_s = batch["y_support"].to(self.device, non_blocking=True)
+                    x_q = batch["x_query"].to(self.device, non_blocking=True)
+                    y_q = batch["y_query"].to(self.device, non_blocking=True)
+                    padding_features = batch["padding_features"].to(self.device, non_blocking=True)
+                    padding_obs_support = batch["padding_obs_support"].to(self.device, non_blocking=True)
+                    padding_obs_query = batch["padding_obs_query"].to(self.device, non_blocking=True)
 
                     # Convert numerical y_support to bin ids
-                    if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
+                    if (
+                        self.cfg.task == Task.REGRESSION
+                        and self.cfg.hyperparams["regression_loss"] == LossName.CROSS_ENTROPY
+                    ):
                         y_s = torch.bucketize(y_s, self.bins) - 1
-                        y_s = torch.clamp(y_s, 0, self.cfg.hyperparams['dim_output']-1).to(torch.int64)
+                        y_s = torch.clamp(y_s, 0, self.cfg.hyperparams["dim_output"] - 1).to(torch.int64)
 
                     if self.cfg.model_name == ModelName.TABPFN:
                         y_hat = self.model(x_s, y_s, x_q, task=self.cfg.task).squeeze(-1)
@@ -245,7 +258,10 @@
                         y_hat = self.model(x_s, y_s, x_q, padding_features, padding_obs_support, padding_obs_query)
 
                     # Convert bin id predictions to numerical values
-                    if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
+                    if (
+                        self.cfg.task == Task.REGRESSION
+                        and self.cfg.hyperparams["regression_loss"] == LossName.CROSS_ENTROPY
+                    ):
                         y_hat = torch.argmax(y_hat, dim=-1)
                         y_hat = self.bins[y_hat] + self.bin_width / 2
 
@@ -255,21 +271,19 @@
         metrics_eval = prediction_metrics_tracker.get_metrics()
         return metrics_eval
 
-
     def predict(self, x_support: np.ndarray, y_support: np.ndarray, x_query: np.ndarray) -> np.ndarray:
-
         x_support_transformed = self.preprocessor.transform_X(x_support)
         x_query_transformed = self.preprocessor.transform_X(x_query)
         y_support_transformed = self.preprocessor.transform_y(y_support)
 
         dataset = DatasetFinetune(
             self.cfg,
-            x_support = x_support_transformed,
-            y_support = y_support_transformed,
-            x_query = x_query_transformed,
-            y_query = None,
-            max_samples_support = self.cfg.hyperparams['max_samples_support'],
-            max_samples_query = self.cfg.hyperparams['max_samples_query'],
+            x_support=x_support_transformed,
+            y_support=y_support_transformed,
+            x_query=x_query_transformed,
+            y_query=None,
+            max_samples_support=self.cfg.hyperparams["max_samples_support"],
+            max_samples_query=self.cfg.hyperparams["max_samples_query"],
             rng=self.rng,
         )
 
@@ -280,20 +294,21 @@
 
         with torch.no_grad():
             for batch in loader:
-
-                with torch.autocast(device_type=self.device, dtype=getattr(torch, self.cfg.hyperparams['precision'])):
-
-                    x_s = batch['x_support'].to(self.device, non_blocking=True)
-                    y_s = batch['y_support'].to(self.device, non_blocking=True)
-                    x_q = batch['x_query'].to(self.device, non_blocking=True)
-                    padding_features = batch['padding_features'].to(self.device, non_blocking=True)
-                    padding_obs_support = batch['padding_obs_support'].to(self.device, non_blocking=True)
-                    padding_obs_query = batch['padding_obs_query'].to(self.device, non_blocking=True)
+                with torch.autocast(device_type=self.device, dtype=getattr(torch, self.cfg.hyperparams["precision"])):
+                    x_s = batch["x_support"].to(self.device, non_blocking=True)
+                    y_s = batch["y_support"].to(self.device, non_blocking=True)
+                    x_q = batch["x_query"].to(self.device, non_blocking=True)
+                    padding_features = batch["padding_features"].to(self.device, non_blocking=True)
+                    padding_obs_support = batch["padding_obs_support"].to(self.device, non_blocking=True)
+                    padding_obs_query = batch["padding_obs_query"].to(self.device, non_blocking=True)
 
                     # Convert numerical y_support to bin ids
-                    if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
+                    if (
+                        self.cfg.task == Task.REGRESSION
+                        and self.cfg.hyperparams["regression_loss"] == LossName.CROSS_ENTROPY
+                    ):
                         y_s = torch.bucketize(y_s, self.bins) - 1
-                        y_s = torch.clamp(y_s, 0, self.cfg.hyperparams['dim_output']-1).to(torch.int64)
+                        y_s = torch.clamp(y_s, 0, self.cfg.hyperparams["dim_output"] - 1).to(torch.int64)
 
                     if self.cfg.model_name == ModelName.TABPFN:
                         y_hat = self.model(x_s, y_s, x_q, task=self.cfg.task).squeeze(-1)
@@ -303,7 +318,10 @@
                     y_hat = y_hat[0].float().cpu().numpy()
 
                     # Convert bin id predictions to numerical values
-                    if self.cfg.task == Task.REGRESSION and self.cfg.hyperparams['regression_loss'] == LossName.CROSS_ENTROPY:
+                    if (
+                        self.cfg.task == Task.REGRESSION
+                        and self.cfg.hyperparams["regression_loss"] == LossName.CROSS_ENTROPY
+                    ):
                         y_hat = np.argmax(y_hat, axis=-1)
                         y_hat = (self.bins[y_hat] + self.bin_width / 2).cpu().numpy()
 
@@ -314,13 +332,10 @@
 
         return y_pred
 
-
     def load_params(self, path):
         self.model.load_state_dict(torch.load(path))
 
-
     def make_loader(self, dataset: torch.utils.data.Dataset, training: bool) -> torch.utils.data.DataLoader:
-
         if self.cfg.model_name == ModelName.TABPFN:
             pad_to_max_features = True
         elif self.cfg.model_name in [ModelName.TAB2D, ModelName.TAB2D_COL_ROW, ModelName.TAB2D_SDPA]:
@@ -336,16 +351,14 @@
             num_workers=0,
             drop_last=False,
             collate_fn=CollatorWithPadding(
-                max_features=self.cfg.hyperparams['dim_embedding'],
-                pad_to_max_features=pad_to_max_features
+                max_features=self.cfg.hyperparams["dim_embedding"], pad_to_max_features=pad_to_max_features
             ),
         )
 
-
     def log_start_metrics(self, metrics_valid: PredictionMetrics):
-
         if self.cfg.task == Task.REGRESSION:
-            logger.info((
+            logger.info(
+                (
                     f"Epoch 000 "
                     f"| Train MSE: -.---- "
                     f"| Train MAE: -.---- "
@@ -353,21 +366,24 @@
                     f"| Val MSE: {metrics_valid.metrics[MetricName.MSE]:.4f} "
                     f"| Val MAE: {metrics_valid.metrics[MetricName.MAE]:.4f} "
                     f"| Val r2: {metrics_valid.metrics[MetricName.R2]:.4f}"
-            ))
+                )
+            )
 
         elif self.cfg.task == Task.CLASSIFICATION:
-            logger.info((
+            logger.info(
+                (
                     f"Epoch 000 "
                     f"| Train CE: -.---- "
                     f"| Train acc: -.---- "
                     f"| Val CE: {metrics_valid.metrics[MetricName.LOG_LOSS]:.4f} "
                     f"| Val acc: {metrics_valid.metrics[MetricName.ACCURACY]:.4f}"
-            ))
+                )
+            )
 
     def log_metrics(self, epoch: int, metrics_train: PredictionMetrics, metrics_valid: PredictionMetrics):
-
         if self.cfg.task == Task.REGRESSION:
-            logger.info((
+            logger.info(
+                (
                     f"Epoch {epoch:03d} "
                     f"| Train MSE: {metrics_train.metrics[MetricName.MSE]:.4f} "
                     f"| Train MAE: {metrics_train.metrics[MetricName.MAE]:.4f} "
@@ -375,12 +391,15 @@
                     f"| Val MSE: {metrics_valid.metrics[MetricName.MSE]:.4f} "
                     f"| Val MAE: {metrics_valid.metrics[MetricName.MAE]:.4f} "
                     f"| Val r2: {metrics_valid.metrics[MetricName.R2]:.4f}"
-            ))
+                )
+            )
         elif self.cfg.task == Task.CLASSIFICATION:
-            logger.info((
+            logger.info(
+                (
                     f"Epoch {epoch:03d} "
                     f"| Train CE: {metrics_train.metrics[MetricName.LOG_LOSS]:.4f} "
                     f"| Train acc: {metrics_train.metrics[MetricName.ACCURACY]:.4f} "
                     f"| Val CE: {metrics_valid.metrics[MetricName.LOG_LOSS]:.4f} "
                     f"| Val acc: {metrics_valid.metrics[MetricName.ACCURACY]:.4f}"
-            ))
+                )
+            )
autogluon/tabular/models/mitra/_internal/data/__init__.py

@@ -1 +1 @@
-# Data processing modules for MitraModel
+# Data processing modules for MitraModel
autogluon/tabular/models/mitra/_internal/data/collator.py

@@ -1,23 +1,19 @@
 import torch
 
 
-class CollatorWithPadding():
-
+class CollatorWithPadding:
     def __init__(
-            self,
-            max_features: int,
-            pad_to_max_features: bool,
-        ) -> None:
-
+        self,
+        max_features: int,
+        pad_to_max_features: bool,
+    ) -> None:
         self.max_features = max_features
         self.pad_to_max_features = pad_to_max_features
 
-
     def __call__(self, batch: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]:
-
-        max_support_samples = max(dataset['x_support'].shape[0] for dataset in batch)
-        max_query_samples = max(dataset['x_query'].shape[0] for dataset in batch)
-        max_features = max(dataset['x_support'].shape[1] for dataset in batch)
+        max_support_samples = max(dataset["x_support"].shape[0] for dataset in batch)
+        max_query_samples = max(dataset["x_query"].shape[0] for dataset in batch)
+        max_features = max(dataset["x_support"].shape[1] for dataset in batch)
 
         if self.pad_to_max_features:
             max_features = self.max_features
@@ -25,22 +21,30 @@ class CollatorWithPadding():
         batch_size = len(batch)
 
         tensor_dict = {
-            'x_support': torch.zeros((batch_size, max_support_samples, max_features), dtype=batch[0]['x_support'].dtype),
-            'y_support': torch.full((batch_size, max_support_samples), fill_value=-100, dtype=batch[0]['y_support'].dtype),
-            'x_query': torch.zeros((batch_size, max_query_samples, max_features), dtype=batch[0]['x_query'].dtype),
-            'y_query': torch.full((batch_size, max_query_samples), fill_value=-100, dtype=batch[0]['y_query'].dtype),
-            'padding_features': torch.ones((batch_size, max_features), dtype=torch.bool),
-            'padding_obs_support': torch.ones((batch_size, max_support_samples), dtype=torch.bool),
-            'padding_obs_query': torch.ones((batch_size, max_query_samples), dtype=torch.bool),
+            "x_support": torch.zeros(
+                (batch_size, max_support_samples, max_features), dtype=batch[0]["x_support"].dtype
+            ),
+            "y_support": torch.full(
+                (batch_size, max_support_samples), fill_value=-100, dtype=batch[0]["y_support"].dtype
+            ),
+            "x_query": torch.zeros((batch_size, max_query_samples, max_features), dtype=batch[0]["x_query"].dtype),
+            "y_query": torch.full((batch_size, max_query_samples), fill_value=-100, dtype=batch[0]["y_query"].dtype),
+            "padding_features": torch.ones((batch_size, max_features), dtype=torch.bool),
+            "padding_obs_support": torch.ones((batch_size, max_support_samples), dtype=torch.bool),
+            "padding_obs_query": torch.ones((batch_size, max_query_samples), dtype=torch.bool),
         }
 
         for i, dataset in enumerate(batch):
-            tensor_dict['x_support'][i, :dataset['x_support'].shape[0], :dataset['x_support'].shape[1]] = dataset['x_support']
-            tensor_dict['y_support'][i, :dataset['y_support'].shape[0]] = dataset['y_support']
-            tensor_dict['x_query'][i, :dataset['x_query'].shape[0], :dataset['x_support'].shape[1]] = dataset['x_query']
-            tensor_dict['y_query'][i, :dataset['y_query'].shape[0]] = dataset['y_query']
-            tensor_dict['padding_features'][i, :dataset['x_support'].shape[1]] = False
-            tensor_dict['padding_obs_support'][i, :dataset['x_support'].shape[0]] = False
-            tensor_dict['padding_obs_query'][i, :dataset['x_query'].shape[0]] = False
+            tensor_dict["x_support"][i, : dataset["x_support"].shape[0], : dataset["x_support"].shape[1]] = dataset[
+                "x_support"
+            ]
+            tensor_dict["y_support"][i, : dataset["y_support"].shape[0]] = dataset["y_support"]
+            tensor_dict["x_query"][i, : dataset["x_query"].shape[0], : dataset["x_support"].shape[1]] = dataset[
+                "x_query"
+            ]
+            tensor_dict["y_query"][i, : dataset["y_query"].shape[0]] = dataset["y_query"]
+            tensor_dict["padding_features"][i, : dataset["x_support"].shape[1]] = False
+            tensor_dict["padding_obs_support"][i, : dataset["x_support"].shape[0]] = False
+            tensor_dict["padding_obs_query"][i, : dataset["x_query"].shape[0]] = False
 
         return tensor_dict
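The collator above pads variable-sized support/query tables to a common shape and records boolean masks in which True marks padded positions. A self-contained miniature of the same idea for a single (n_samples, n_features) tensor per dataset (names are illustrative, not the package API):

import torch

def collate_pad(batch: list[torch.Tensor]) -> dict[str, torch.Tensor]:
    max_rows = max(x.shape[0] for x in batch)
    max_cols = max(x.shape[1] for x in batch)
    x_out = torch.zeros((len(batch), max_rows, max_cols), dtype=batch[0].dtype)
    padding_obs = torch.ones((len(batch), max_rows), dtype=torch.bool)   # True = padded row
    for i, x in enumerate(batch):
        x_out[i, : x.shape[0], : x.shape[1]] = x
        padding_obs[i, : x.shape[0]] = False
    return {"x": x_out, "padding_obs": padding_obs}

out = collate_pad([torch.randn(3, 2), torch.randn(5, 4)])
print(out["x"].shape, int(out["padding_obs"].sum()))   # torch.Size([2, 5, 4]) 2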