autogluon.tabular 1.5.1b20260105__py3-none-any.whl → 1.5.1b20260117__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of autogluon.tabular might be problematic.
- autogluon/tabular/__init__.py +1 -0
- autogluon/tabular/configs/config_helper.py +18 -6
- autogluon/tabular/configs/feature_generator_presets.py +3 -1
- autogluon/tabular/configs/hyperparameter_configs.py +42 -9
- autogluon/tabular/configs/presets_configs.py +38 -14
- autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2023.py +84 -14
- autogluon/tabular/configs/zeroshot/zeroshot_portfolio_2025.py +48 -48
- autogluon/tabular/configs/zeroshot/zeroshot_portfolio_cpu_2025_12_18.py +774 -1
- autogluon/tabular/configs/zeroshot/zeroshot_portfolio_gpu_2025_12_18.py +421 -1
- autogluon/tabular/experimental/_scikit_mixin.py +6 -2
- autogluon/tabular/experimental/_tabular_classifier.py +3 -1
- autogluon/tabular/experimental/_tabular_regressor.py +3 -1
- autogluon/tabular/experimental/plot_leaderboard.py +73 -19
- autogluon/tabular/learner/abstract_learner.py +160 -42
- autogluon/tabular/learner/default_learner.py +78 -22
- autogluon/tabular/models/__init__.py +2 -2
- autogluon/tabular/models/_utils/rapids_utils.py +3 -1
- autogluon/tabular/models/abstract/abstract_torch_model.py +2 -0
- autogluon/tabular/models/automm/automm_model.py +12 -3
- autogluon/tabular/models/automm/ft_transformer.py +5 -1
- autogluon/tabular/models/catboost/callbacks.py +2 -2
- autogluon/tabular/models/catboost/catboost_model.py +93 -29
- autogluon/tabular/models/catboost/catboost_softclass_utils.py +4 -1
- autogluon/tabular/models/catboost/catboost_utils.py +3 -1
- autogluon/tabular/models/ebm/ebm_model.py +8 -13
- autogluon/tabular/models/ebm/hyperparameters/parameters.py +1 -0
- autogluon/tabular/models/ebm/hyperparameters/searchspaces.py +1 -0
- autogluon/tabular/models/fastainn/callbacks.py +20 -3
- autogluon/tabular/models/fastainn/hyperparameters/searchspaces.py +11 -1
- autogluon/tabular/models/fastainn/quantile_helpers.py +10 -2
- autogluon/tabular/models/fastainn/tabular_nn_fastai.py +65 -18
- autogluon/tabular/models/fasttext/fasttext_model.py +3 -1
- autogluon/tabular/models/image_prediction/image_predictor.py +7 -2
- autogluon/tabular/models/knn/knn_model.py +41 -8
- autogluon/tabular/models/lgb/callbacks.py +32 -9
- autogluon/tabular/models/lgb/hyperparameters/searchspaces.py +3 -1
- autogluon/tabular/models/lgb/lgb_model.py +150 -34
- autogluon/tabular/models/lgb/lgb_utils.py +12 -4
- autogluon/tabular/models/lr/hyperparameters/searchspaces.py +5 -1
- autogluon/tabular/models/lr/lr_model.py +40 -10
- autogluon/tabular/models/lr/lr_rapids_model.py +22 -13
- autogluon/tabular/models/mitra/_internal/__init__.py +1 -1
- autogluon/tabular/models/mitra/_internal/config/__init__.py +1 -1
- autogluon/tabular/models/mitra/_internal/config/config_pretrain.py +36 -40
- autogluon/tabular/models/mitra/_internal/config/config_run.py +2 -14
- autogluon/tabular/models/mitra/_internal/config/enums.py +27 -26
- autogluon/tabular/models/mitra/_internal/core/__init__.py +1 -1
- autogluon/tabular/models/mitra/_internal/core/callbacks.py +14 -21
- autogluon/tabular/models/mitra/_internal/core/get_loss.py +10 -12
- autogluon/tabular/models/mitra/_internal/core/get_optimizer.py +17 -32
- autogluon/tabular/models/mitra/_internal/core/get_scheduler.py +12 -27
- autogluon/tabular/models/mitra/_internal/core/prediction_metrics.py +16 -21
- autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py +130 -111
- autogluon/tabular/models/mitra/_internal/data/__init__.py +1 -1
- autogluon/tabular/models/mitra/_internal/data/collator.py +30 -26
- autogluon/tabular/models/mitra/_internal/data/dataset_finetune.py +18 -26
- autogluon/tabular/models/mitra/_internal/data/dataset_split.py +10 -7
- autogluon/tabular/models/mitra/_internal/data/preprocessor.py +70 -100
- autogluon/tabular/models/mitra/_internal/models/__init__.py +1 -1
- autogluon/tabular/models/mitra/_internal/models/base.py +7 -10
- autogluon/tabular/models/mitra/_internal/models/embedding.py +46 -56
- autogluon/tabular/models/mitra/_internal/models/tab2d.py +140 -120
- autogluon/tabular/models/mitra/_internal/utils/__init__.py +1 -1
- autogluon/tabular/models/mitra/_internal/utils/set_seed.py +3 -1
- autogluon/tabular/models/mitra/mitra_model.py +16 -11
- autogluon/tabular/models/mitra/sklearn_interface.py +178 -162
- autogluon/tabular/models/realmlp/realmlp_model.py +28 -15
- autogluon/tabular/models/rf/compilers/onnx.py +1 -1
- autogluon/tabular/models/rf/rf_model.py +45 -12
- autogluon/tabular/models/rf/rf_quantile.py +4 -2
- autogluon/tabular/models/tabdpt/tabdpt_model.py +8 -17
- autogluon/tabular/models/tabicl/tabicl_model.py +8 -1
- autogluon/tabular/models/tabm/_tabm_internal.py +6 -4
- autogluon/tabular/models/tabm/rtdl_num_embeddings.py +80 -127
- autogluon/tabular/models/tabm/tabm_model.py +8 -4
- autogluon/tabular/models/tabm/tabm_reference.py +53 -85
- autogluon/tabular/models/tabpfnmix/_internal/core/callbacks.py +7 -16
- autogluon/tabular/models/tabpfnmix/_internal/core/collator.py +16 -24
- autogluon/tabular/models/tabpfnmix/_internal/core/dataset_split.py +5 -7
- autogluon/tabular/models/tabpfnmix/_internal/core/enums.py +0 -2
- autogluon/tabular/models/tabpfnmix/_internal/core/get_loss.py +0 -1
- autogluon/tabular/models/tabpfnmix/_internal/core/get_optimizer.py +7 -18
- autogluon/tabular/models/tabpfnmix/_internal/core/get_scheduler.py +3 -14
- autogluon/tabular/models/tabpfnmix/_internal/core/trainer_finetune.py +79 -64
- autogluon/tabular/models/tabpfnmix/_internal/core/y_transformer.py +3 -5
- autogluon/tabular/models/tabpfnmix/_internal/data/dataset_finetune.py +17 -30
- autogluon/tabular/models/tabpfnmix/_internal/data/preprocessor.py +15 -35
- autogluon/tabular/models/tabpfnmix/_internal/models/foundation/embedding.py +21 -38
- autogluon/tabular/models/tabpfnmix/_internal/models/foundation/foundation_transformer.py +33 -51
- autogluon/tabular/models/tabpfnmix/_internal/results/prediction_metrics.py +4 -4
- autogluon/tabular/models/tabpfnmix/_internal/tabpfnmix_classifier.py +32 -12
- autogluon/tabular/models/tabpfnmix/_internal/tabpfnmix_regressor.py +32 -13
- autogluon/tabular/models/tabpfnmix/tabpfnmix_model.py +55 -19
- autogluon/tabular/models/tabpfnv2/tabpfnv2_5_model.py +21 -48
- autogluon/tabular/models/tabprep/prep_mixin.py +34 -26
- autogluon/tabular/models/tabular_nn/compilers/onnx.py +36 -8
- autogluon/tabular/models/tabular_nn/torch/tabular_nn_torch.py +130 -36
- autogluon/tabular/models/tabular_nn/torch/tabular_torch_dataset.py +8 -4
- autogluon/tabular/models/tabular_nn/torch/torch_network_modules.py +26 -5
- autogluon/tabular/models/tabular_nn/utils/categorical_encoders.py +41 -24
- autogluon/tabular/models/tabular_nn/utils/data_preprocessor.py +33 -8
- autogluon/tabular/models/tabular_nn/utils/nn_architecture_utils.py +21 -6
- autogluon/tabular/models/xgboost/callbacks.py +9 -3
- autogluon/tabular/models/xgboost/xgboost_model.py +59 -11
- autogluon/tabular/models/xt/xt_model.py +1 -0
- autogluon/tabular/predictor/interpretable_predictor.py +3 -1
- autogluon/tabular/predictor/predictor.py +409 -128
- autogluon/tabular/registry/__init__.py +1 -1
- autogluon/tabular/registry/_ag_model_registry.py +4 -5
- autogluon/tabular/registry/_model_registry.py +1 -0
- autogluon/tabular/testing/fit_helper.py +55 -15
- autogluon/tabular/testing/generate_datasets.py +1 -1
- autogluon/tabular/testing/model_fit_helper.py +10 -4
- autogluon/tabular/trainer/abstract_trainer.py +644 -230
- autogluon/tabular/trainer/auto_trainer.py +19 -8
- autogluon/tabular/trainer/model_presets/presets.py +33 -9
- autogluon/tabular/trainer/model_presets/presets_distill.py +16 -2
- autogluon/tabular/version.py +1 -1
- {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260117.dist-info}/METADATA +27 -27
- {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260117.dist-info}/RECORD +127 -135
- autogluon/tabular/models/tabpfnv2/rfpfn/__init__.py +0 -20
- autogluon/tabular/models/tabpfnv2/rfpfn/configs.py +0 -40
- autogluon/tabular/models/tabpfnv2/rfpfn/scoring_utils.py +0 -201
- autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_decision_tree_tabpfn.py +0 -1464
- autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_based_random_forest_tabpfn.py +0 -747
- autogluon/tabular/models/tabpfnv2/rfpfn/sklearn_compat.py +0 -863
- autogluon/tabular/models/tabpfnv2/rfpfn/utils.py +0 -106
- autogluon/tabular/models/tabpfnv2/tabpfnv2_model.py +0 -466
- /autogluon.tabular-1.5.1b20260105-py3.11-nspkg.pth → /autogluon.tabular-1.5.1b20260117-py3.11-nspkg.pth +0 -0
- {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260117.dist-info}/WHEEL +0 -0
- {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260117.dist-info}/licenses/LICENSE +0 -0
- {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260117.dist-info}/licenses/NOTICE +0 -0
- {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260117.dist-info}/namespace_packages.txt +0 -0
- {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260117.dist-info}/top_level.txt +0 -0
- {autogluon_tabular-1.5.1b20260105.dist-info → autogluon_tabular-1.5.1b20260117.dist-info}/zip-safe +0 -0

autogluon/tabular/models/mitra/_internal/config/config_pretrain.py

@@ -1,18 +1,19 @@
 from __future__ import annotations
 
+import os
 from dataclasses import dataclass
 from pathlib import Path
 from typing import Optional
-import yaml
-import os
 
 import torch
+import yaml
 from omegaconf import DictConfig, OmegaConf
 
-from ..._internal.config.enums import GeneratorName,
+from ..._internal.config.enums import GeneratorName, LossName, ModelName, Task
+
 
 @dataclass
-class ConfigData
+class ConfigData:
     generator: GeneratorName
     min_samples_support: int
     max_samples_support: int
@@ -26,37 +27,38 @@ class ConfigData():
     task: Task
 
     def __post_init__(self):
-
         assert self.min_samples_support <= self.max_samples_support
         assert self.min_features <= self.max_features
 
+
 @dataclass
-class ConfigModel
+class ConfigModel:
     name: ModelName
     hyperparams: dict
 
 
 @dataclass
-class ConfigPreprocessing
+class ConfigPreprocessing:
     use_quantile_transformer: bool
     use_feature_count_scaling: bool
 
+
 @dataclass
-class ConfigGradScaler
+class ConfigGradScaler:
     enabled: bool
     scale_init: float
     scale_min: float
     growth_interval: int
 
-
     def __post_init__(self):
         assert self.scale_init >= self.scale_min, "Scale init must be greater than scale min"
         assert self.scale_min >= 1, "Scale min lower than 1 makes no sense for mixed precision training"
         assert type(self.scale_init) == float, "Scale init must be a float, otherwise gradscaler will return an error"
         assert type(self.scale_min) == float, "Scale min must be a float, otherwise gradscaler will return an error"
 
+
 @dataclass
-class ConfigOptim
+class ConfigOptim:
     steps: int
     log_every_n_steps: int
     eval_every_n_steps: int
@@ -80,42 +82,35 @@ class ConfigOptim():
 
     @classmethod
     def from_hydra(cls, cfg_hydra: DictConfig) -> Self:
-
         grad_scaler = ConfigGradScaler(**cfg_hydra.grad_scaler)
-        cfg_dict: dict = OmegaConf.to_container(cfg_hydra)
+        cfg_dict: dict = OmegaConf.to_container(cfg_hydra)  # type: ignore
         del cfg_dict["grad_scaler"]
 
         regression_loss = LossName[cfg_dict["regression_loss"]]
         del cfg_dict["regression_loss"]
 
-        return cls(
-            grad_scaler=grad_scaler,
-            regression_loss=regression_loss,
-            **cfg_dict
-        )
+        return cls(grad_scaler=grad_scaler, regression_loss=regression_loss, **cfg_dict)
 
     def __post_init__(self):
         assert hasattr(torch, self.precision), f"Precision {self.precision} not supported by torch"
 
-class ConfigSaveLoadMixin(yaml.YAMLObject):
 
+class ConfigSaveLoadMixin(yaml.YAMLObject):
     def save(self, path: Path) -> None:
-
         path.parent.mkdir(parents=True, exist_ok=True)
 
-        with open(path,
+        with open(path, "w") as f:
             yaml.dump(self, f, default_flow_style=False)
 
-
     @classmethod
     def load(cls, path: Path) -> Self:
-
-        with open(path, 'r') as f:
+        with open(path, "r") as f:
             # It's unsafe, but not unsafer than the pickle module
             config = yaml.unsafe_load(f)
 
         return config
 
+
 @dataclass
 class ConfigPretrain(ConfigSaveLoadMixin):
     run_name: str
@@ -141,8 +136,9 @@ class ConfigPretrain(ConfigSaveLoadMixin):
 
     @classmethod
     def from_hydra(cls, cfg_hydra: DictConfig):
-
-
+        assert not os.path.exists(cfg_hydra.output_dir), (
+            f"Output directory {cfg_hydra.output_dir} already exists! Please change to a new folder."
+        )
 
         output_dir = Path(cfg_hydra.output_dir)
 
@@ -160,11 +156,11 @@ class ConfigPretrain(ConfigSaveLoadMixin):
             use_ddp=len(devices) > 1,
             seed=cfg_hydra.seed,
             workers_per_gpu=cfg_hydra.workers_per_gpu,
-            model
-                name
-                hyperparams
+            model=ConfigModel(
+                name=ModelName[cfg_hydra.model.name],
+                hyperparams=OmegaConf.to_container(cfg_hydra.model.hyperparams),
             ),
-            data
+            data=ConfigData(
                 generator=GeneratorName(cfg_hydra.data.generator),
                 min_samples_support=cfg_hydra.data.min_samples_support,
                 max_samples_support=cfg_hydra.data.max_samples_support,
@@ -175,16 +171,16 @@ class ConfigPretrain(ConfigSaveLoadMixin):
                 task=Task[cfg_hydra.data.task],
                 sample_multinomial_categorical=cfg_hydra.data.sample_multinomial_categorical,
                 sample_multinomial_label=cfg_hydra.data.sample_multinomial_label,
-                generator_hyperparams=OmegaConf.to_container(cfg_hydra.data.generator_hyperparams),
+                generator_hyperparams=OmegaConf.to_container(cfg_hydra.data.generator_hyperparams),  # type: ignore
             ),
-            optim
-            preprocessing
-            load_from_file
-            load_path_x
-            load_path_y
-            save_file
-            save_file_only
-            save_path_x
-            save_path_y
-            number_of_runs
+            optim=ConfigOptim.from_hydra(cfg_hydra.optim),
+            preprocessing=ConfigPreprocessing(**cfg_hydra.preprocessing),
+            load_from_file=cfg_hydra.load_from_file,
+            load_path_x=cfg_hydra.load_path_x,
+            load_path_y=cfg_hydra.load_path_y,
+            save_file=cfg_hydra.save_file,
+            save_file_only=cfg_hydra.save_file_only,
+            save_path_x=cfg_hydra.save_path_x,
+            save_path_y=cfg_hydra.save_path_y,
+            number_of_runs=cfg_hydra.number_of_runs,
         )
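
The hunks above keep the YAML round trip that Mitra's pretraining configs use for persistence: yaml.dump on save and yaml.unsafe_load on load. Below is a minimal standalone sketch of that pattern, not taken from the package; the DemoConfig class and its fields are hypothetical.

    # Illustrative sketch of the yaml.dump / yaml.unsafe_load round trip used by
    # ConfigSaveLoadMixin above. DemoConfig and its fields are hypothetical.
    from dataclasses import dataclass
    from pathlib import Path

    import yaml


    @dataclass
    class DemoConfig(yaml.YAMLObject):
        run_name: str
        seed: int

        def save(self, path: Path) -> None:
            path.parent.mkdir(parents=True, exist_ok=True)
            with open(path, "w") as f:
                yaml.dump(self, f, default_flow_style=False)

        @classmethod
        def load(cls, path: Path) -> "DemoConfig":
            with open(path, "r") as f:
                # unsafe_load rebuilds the Python object, as in the mixin above
                return yaml.unsafe_load(f)


    cfg = DemoConfig(run_name="demo", seed=0)
    cfg.save(Path("out/config.yaml"))
    print(DemoConfig.load(Path("out/config.yaml")))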

autogluon/tabular/models/mitra/_internal/config/config_run.py

@@ -16,17 +16,5 @@ class ConfigRun(ConfigSaveLoadMixin):
     hyperparams: dict
 
     @classmethod
-    def create(
-        cls,
-        device: torch.device,
-        seed: int,
-        model_name: ModelName,
-        hyperparams: dict
-    ) -> "ConfigRun":
-
-        return cls(
-            device=device,
-            seed=seed,
-            model_name=model_name,
-            hyperparams=hyperparams
-        )
+    def create(cls, device: torch.device, seed: int, model_name: ModelName, hyperparams: dict) -> "ConfigRun":
+        return cls(device=device, seed=seed, model_name=model_name, hyperparams=hyperparams)

autogluon/tabular/models/mitra/_internal/config/enums.py

@@ -5,10 +5,12 @@ try:
 except ImportError:
     # StrEnum is not available in Python < 3.11, so we create a compatible version
     from enum import Enum
+
     class StrEnum(str, Enum):
         """
         Enum where members are also (and must be) strings
         """
+
         def __new__(cls, value):
             if not isinstance(value, str):
                 raise TypeError(f"{value!r} is not a string")
@@ -53,7 +55,7 @@ class Phase(StrEnum):
 
 
 class ModelName(StrEnum):
-    PLACEHOLDER = "_placeholder_"
+    PLACEHOLDER = "_placeholder_"  # This is a placeholder for the current running model
     FT_TRANSFORMER = "FT-Transformer"
     TABPFN = "TabPFN"
     FOUNDATION = "Foundation"
@@ -87,10 +89,10 @@ class ModelName(StrEnum):
 
 
 class ModelClass(StrEnum):
-    BASE =
-    GBDT =
-    NN =
-    ICLT =
+    BASE = "base"
+    GBDT = "GBDT"
+    NN = "NN"
+    ICLT = "ICLT"
 
 
 class DownstreamTask(StrEnum):
@@ -98,7 +100,6 @@ class DownstreamTask(StrEnum):
     FINETUNE = "finetune"
 
 
-
 class BenchmarkName(StrEnum):
     DEBUG_CLASSIFICATION = "debug_classification"
     DEBUG_REGRESSION = "debug_regression"
@@ -124,25 +125,25 @@ class BenchmarkOrigin(StrEnum):
 
 
 class GeneratorName(StrEnum):
-    TABPFN =
-    TREE =
-    RANDOMFOREST =
-    NEIGHBOR =
-    MIX =
-    PERLIN =
-    MIX_7 =
-    MIX_6 =
-    MIX_5 =
-    MIX_5_GP =
-    MIX_4 =
-    MIX_4_AG =
-    LR =
-    POLY =
-    SAMPLE_RF =
-    SAMPLE_GP =
-    TABREPO =
-    MIX_4_TABREPO =
-    MIX_4_TABPFNV2 =
+    TABPFN = "tabpfn"
+    TREE = "tree"
+    RANDOMFOREST = "randomforest"
+    NEIGHBOR = "neighbor"
+    MIX = "mix"
+    PERLIN = "perlin"
+    MIX_7 = "mix_7"
+    MIX_6 = "mix_6"
+    MIX_5 = "mix_5"
+    MIX_5_GP = "mix_5_gp"
+    MIX_4 = "mix_4"
+    MIX_4_AG = "mix_4_ag"
+    LR = "lr"
+    POLY = "poly"
+    SAMPLE_RF = "sample_rf"
+    SAMPLE_GP = "sample_gp"
+    TABREPO = "tabrepo"
+    MIX_4_TABREPO = "mix_4_tabrepo"
+    MIX_4_TABPFNV2 = "mix_4_tabpfnv2"
 
 
 class MetricName(StrEnum):
@@ -159,4 +160,4 @@ class MetricName(StrEnum):
 class LossName(StrEnum):
     CROSS_ENTROPY = "cross_entropy"
     MSE = "mse"
-    MAE = "mae"
+    MAE = "mae"

autogluon/tabular/models/mitra/_internal/core/__init__.py

@@ -1 +1 @@
-# Core modules for MitraModel
+# Core modules for MitraModel

autogluon/tabular/models/mitra/_internal/core/callbacks.py

@@ -2,10 +2,8 @@ import numpy as np
 import torch
 
 
-class EarlyStopping
-
-    def __init__(self, patience=10, delta=0.0001, metric='log_loss'):
-
+class EarlyStopping:
+    def __init__(self, patience=10, delta=0.0001, metric="log_loss"):
         self.patience = patience
         self.counter = 0
         self.best_score = None
@@ -13,9 +11,7 @@ class EarlyStopping():
         self.delta = delta
         self.metric = metric
 
-
     def __call__(self, val_loss):
-
         # smaller is better for these metrics
         if self.metric in ["log_loss", "mse", "mae", "rmse"]:
             score = -val_loss
@@ -23,7 +19,9 @@ class EarlyStopping():
         elif self.metric in ["accuracy", "roc_auc", "r2"]:
             score = val_loss
         else:
-            raise ValueError(
+            raise ValueError(
+                f"Unsupported metric: {self.metric}. Supported metrics are: log_loss, mse, mae, rmse, accuracy, roc_auc, r2."
+            )
 
         if self.best_score is None:
             self.best_score = score
@@ -39,39 +37,34 @@ class EarlyStopping():
         return self.early_stop
 
 
-class Checkpoint
-
+class Checkpoint:
     def __init__(self):
         self.curr_best_loss = np.inf
         self.best_model: dict
-
+
     def reset(self, net: torch.nn.Module):
         self.curr_best_loss = np.inf
         self.best_model = net.state_dict()
         for key in self.best_model:
-            self.best_model[key] = self.best_model[key].to(
-
+            self.best_model[key] = self.best_model[key].to("cpu")
 
     def __call__(self, net: torch.nn.Module, loss: float):
-
         if loss < self.curr_best_loss:
             self.curr_best_loss = loss
             self.best_model = net.state_dict()
             for key in self.best_model:
-                self.best_model[key] = self.best_model[key].to(
-
+                self.best_model[key] = self.best_model[key].to("cpu")
 
     def set_to_best(self, net):
         net.load_state_dict(self.best_model)
 
 
-class EpochStatistics
-
+class EpochStatistics:
     def __init__(self) -> None:
         self.n = 0
         self.loss = 0
         self.score = 0
-
+
     def update(self, loss, score, n):
         self.n += n
         self.loss += loss * n
@@ -79,9 +72,9 @@ class EpochStatistics():
 
     def get(self):
         return self.loss / self.n, self.score / self.n
-
-class TrackOutput():
 
+
+class TrackOutput:
     def __init__(self) -> None:
         self.y_true: list[np.ndarray] = []
         self.y_pred: list[np.ndarray] = []
@@ -91,4 +84,4 @@ class TrackOutput():
         self.y_pred.append(y_pred)
 
     def get(self):
-        return np.concatenate(self.y_true, axis=0), np.concatenate(self.y_pred, axis=0)
+        return np.concatenate(self.y_true, axis=0), np.concatenate(self.y_pred, axis=0)
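
EarlyStopping above stops training once the monitored metric has failed to improve for `patience` checks. A small standalone sketch of that pattern follows; the best-score bookkeeping between the hunks shown here is not part of this diff, so it is reimplemented conventionally, and the validation losses are made-up numbers.

    # Standalone sketch of the patience-based early stopping shown above. The
    # best-score update rule is a conventional reimplementation (the lines
    # between the displayed hunks are not in this diff); losses are made up.
    class PatienceStopper:
        def __init__(self, patience=2, delta=0.0001, metric="log_loss"):
            self.patience = patience
            self.delta = delta
            self.metric = metric
            self.counter = 0
            self.best_score = None
            self.early_stop = False

        def __call__(self, val_loss):
            # smaller is better for log_loss-style metrics, larger for accuracy-style ones
            score = -val_loss if self.metric in ["log_loss", "mse", "mae", "rmse"] else val_loss
            if self.best_score is None or score > self.best_score + self.delta:
                self.best_score = score
                self.counter = 0
            else:
                self.counter += 1
                if self.counter >= self.patience:
                    self.early_stop = True
            return self.early_stop


    stopper = PatienceStopper(patience=2, metric="log_loss")
    for epoch, val_loss in enumerate([0.70, 0.65, 0.66, 0.67, 0.68]):
        if stopper(val_loss):
            print(f"stopping at epoch {epoch}")  # fires after two epochs without improvement
            break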

autogluon/tabular/models/mitra/_internal/core/get_loss.py

@@ -7,13 +7,11 @@ from ..._internal.config.enums import LossName, Task
 
 
 class CrossEntropyLossExtraBatch(torch.nn.Module):
-
     def __init__(self, label_smoothing: float):
         super().__init__()
 
         self.loss = torch.nn.CrossEntropyLoss(label_smoothing=label_smoothing)
 
-
     def forward(self, input, target):
         """
         Input has shape (batch_size, num_samples, num_classes)
@@ -22,26 +20,26 @@ class CrossEntropyLossExtraBatch(torch.nn.Module):
         Compared to the original CrossEntropyLoss, accepts (batch_size, num_samples) as batch
         """
 
-        input = einops.rearrange(input,
-        target = einops.rearrange(target,
+        input = einops.rearrange(input, "b s c -> (b s) c")
+        target = einops.rearrange(target, "b s -> (b s)")
 
         return self.loss(input, target)
 
-def get_loss(cfg: ConfigRun):
 
-
+def get_loss(cfg: ConfigRun):
+    if cfg.task == Task.REGRESSION and cfg.hyperparams["regression_loss"] == LossName.MSE:
         return torch.nn.MSELoss()
-    elif cfg.task == Task.REGRESSION and cfg.hyperparams[
+    elif cfg.task == Task.REGRESSION and cfg.hyperparams["regression_loss"] == LossName.MAE:
         return torch.nn.L1Loss()
-    elif cfg.task == Task.REGRESSION and cfg.hyperparams[
-        return CrossEntropyLossExtraBatch(cfg.hyperparams[
+    elif cfg.task == Task.REGRESSION and cfg.hyperparams["regression_loss"] == LossName.CROSS_ENTROPY:
+        return CrossEntropyLossExtraBatch(cfg.hyperparams["label_smoothing"])
     elif cfg.task == Task.CLASSIFICATION:
-        return CrossEntropyLossExtraBatch(cfg.hyperparams[
+        return CrossEntropyLossExtraBatch(cfg.hyperparams["label_smoothing"])
     else:
         raise ValueError(f"Unsupported task {cfg.task} and (regression) loss {cfg.hyperparams['regression_loss']}")
 
-def get_loss_pretrain(cfg: ConfigPretrain):
 
+def get_loss_pretrain(cfg: ConfigPretrain):
     if cfg.data.task == Task.REGRESSION and cfg.optim.regression_loss == LossName.MSE:
         return torch.nn.MSELoss()
     elif cfg.data.task == Task.REGRESSION and cfg.optim.regression_loss == LossName.MAE:
@@ -51,4 +49,4 @@ def get_loss_pretrain(cfg: ConfigPretrain):
     elif cfg.data.task == Task.CLASSIFICATION:
         return CrossEntropyLossExtraBatch(cfg.optim.label_smoothing)
     else:
-        raise ValueError(f"Unsupported task {cfg.data.task} and (regression) loss {cfg.optim.regression_loss}")
+        raise ValueError(f"Unsupported task {cfg.data.task} and (regression) loss {cfg.optim.regression_loss}")
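
CrossEntropyLossExtraBatch above accepts logits of shape (batch_size, num_samples, num_classes) and folds the sample dimension into the batch before applying the standard cross entropy. A self-contained sketch of that reshaping, with illustrative shapes and label smoothing, follows.

    # Sketch of flattening (batch, samples, classes) logits before the standard
    # cross entropy, as CrossEntropyLossExtraBatch does above. Shapes are illustrative.
    import einops
    import torch

    batch_size, num_samples, num_classes = 2, 5, 3
    logits = torch.randn(batch_size, num_samples, num_classes)
    target = torch.randint(num_classes, (batch_size, num_samples))

    loss_fn = torch.nn.CrossEntropyLoss(label_smoothing=0.1)
    flat_logits = einops.rearrange(logits, "b s c -> (b s) c")
    flat_target = einops.rearrange(target, "b s -> (b s)")
    print(loss_fn(flat_logits, flat_target))  # one scalar loss over all batch*samples rows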

autogluon/tabular/models/mitra/_internal/core/get_optimizer.py

@@ -5,44 +5,32 @@ from ..._internal.config.config_pretrain import ConfigPretrain
 
 
 def get_optimizer(hyperparams: dict, model: torch.nn.Module) -> torch.optim.Optimizer:
-
     optimizer: torch.optim.Optimizer
 
-    if hyperparams[
+    if hyperparams["optimizer"] == "adam":
         optimizer = Adam(
-            model.parameters(),
-            lr=hyperparams['lr'],
-            betas=(0.9, 0.999),
-            weight_decay=hyperparams['weight_decay']
+            model.parameters(), lr=hyperparams["lr"], betas=(0.9, 0.999), weight_decay=hyperparams["weight_decay"]
         )
-    elif hyperparams[
+    elif hyperparams["optimizer"] == "adamw":
         optimizer = AdamW(
-            model.parameters(),
-            lr=hyperparams['lr'],
-            betas=(0.9, 0.999),
-            weight_decay=hyperparams['weight_decay']
-        )
-    elif hyperparams['optimizer'] == "sgd":
-        optimizer = SGD(
-            model.parameters(),
-            lr=hyperparams['lr'],
-            weight_decay=hyperparams['weight_decay']
+            model.parameters(), lr=hyperparams["lr"], betas=(0.9, 0.999), weight_decay=hyperparams["weight_decay"]
         )
+    elif hyperparams["optimizer"] == "sgd":
+        optimizer = SGD(model.parameters(), lr=hyperparams["lr"], weight_decay=hyperparams["weight_decay"])
     else:
         raise ValueError("Optimizer not recognized")
-
+
     return optimizer
 
 
 def get_optimizer_pretrain(cfg: ConfigPretrain, model: torch.nn.Module) -> torch.optim.Optimizer:
-
     parameters = [(name, param) for name, param in model.named_parameters()]
 
     parameters_with_weight_decay = []
     parameters_without_weight_decay = []
 
     for name, param in parameters:
-        if name.endswith("bias") or
+        if name.endswith("bias") or "norm" in name or "embedding" in name:
             parameters_without_weight_decay.append(param)
         else:
             parameters_with_weight_decay.append(param)
@@ -53,26 +41,25 @@ def get_optimizer_pretrain(cfg: ConfigPretrain, model: torch.nn.Module) -> torch
     ]
 
     optimizer = torch.optim.AdamW(
-        optimizer_parameters,
+        optimizer_parameters,
         lr=cfg.optim.lr,
         betas=(cfg.optim.beta1, cfg.optim.beta2),
-        weight_decay=cfg.optim.weight_decay
+        weight_decay=cfg.optim.weight_decay,
    )
-
+
     return optimizer
 
 
 class GradScaler(torch.amp.GradScaler):
-
     def __init__(
-        self,
+        self,
         enabled: bool = True,
-        scale_init: float = 2
-        scale_min: float = 1
+        scale_init: float = 2.0**16,
+        scale_min: float = 1.0,
         growth_interval: int = 2000,
-        device: str =
+        device: str = "cuda",
     ):
-        super().__init__(enabled=enabled, device="cpu", init_scale=scale_init, growth_interval=growth_interval)
+        super().__init__(enabled=enabled, device="cpu", init_scale=scale_init, growth_interval=growth_interval)  # type: ignore
         self._enabled = enabled
         self.scale_min = scale_min
         self.device = device
@@ -81,9 +68,7 @@ class GradScaler(torch.amp.GradScaler):
         # We write scale=1 to log if the scaler is disabled
         self._scale = torch.tensor((1,), dtype=torch.float32, device=self.device)
 
-
     def update(self):
-
         if not self._enabled:
             return
 
@@ -105,4 +90,4 @@ def move_optimizer_to(optim, device):
                 if isinstance(subparam, torch.Tensor):
                     subparam.data = subparam.data.to(device)
                     if subparam._grad is not None:
-                        subparam._grad.data = subparam._grad.data.to(device)
+                        subparam._grad.data = subparam._grad.data.to(device)
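
get_optimizer_pretrain above splits parameters by name so that biases, normalization layers, and embeddings receive no weight decay. The sketch below reproduces that grouping on a toy model; TinyModel and the hyperparameter values are illustrative, not taken from the package.

    # Sketch of the by-name weight-decay grouping used in get_optimizer_pretrain
    # above. TinyModel and the hyperparameter values are illustrative.
    import torch


    class TinyModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.embedding = torch.nn.Embedding(10, 8)
            self.linear = torch.nn.Linear(8, 8)
            self.norm = torch.nn.LayerNorm(8)


    model = TinyModel()
    decay, no_decay = [], []
    for name, param in model.named_parameters():
        # same test as the diff: biases, norm layers and embeddings get no weight decay
        if name.endswith("bias") or "norm" in name or "embedding" in name:
            no_decay.append(param)
        else:
            decay.append(param)

    optimizer = torch.optim.AdamW(
        [
            {"params": decay, "weight_decay": 0.01},
            {"params": no_decay, "weight_decay": 0.0},
        ],
        lr=1e-4,
        betas=(0.9, 0.95),
    )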

autogluon/tabular/models/mitra/_internal/core/get_scheduler.py

@@ -1,14 +1,15 @@
 import torch
-from torch.optim.lr_scheduler import
+from torch.optim.lr_scheduler import LinearLR, ReduceLROnPlateau
 from transformers import get_constant_schedule_with_warmup
 from transformers.optimization import get_cosine_with_min_lr_schedule_with_warmup
 
 from ..._internal.config.config_pretrain import ConfigPretrain
 
 
-def get_scheduler(
-
-
+def get_scheduler(
+    hyperparams: dict, optimizer: torch.optim.Optimizer
+) -> tuple[torch.optim.lr_scheduler.LambdaLR, ReduceLROnPlateau]:
+    warmup_steps = hyperparams["warmup_steps"]
 
     # if warmup_steps > 0:
     #     scheduler_warmup = torch.optim.lr_scheduler.LambdaLR(
@@ -18,7 +19,7 @@ def get_scheduler(hyperparams: dict, optimizer: torch.optim.Optimizer) -> tuple[
     #     scheduler_warmup = torch.optim.lr_scheduler.LambdaLR(
     #         optimizer, lambda step: 1.0
     #     )
-
+
     if warmup_steps > 0:
         scheduler_warmup = LinearLR(
             optimizer,
@@ -29,39 +30,23 @@ def get_scheduler(hyperparams: dict, optimizer: torch.optim.Optimizer) -> tuple[
     else:
         scheduler_warmup = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0, total_iters=1)
 
-    if hyperparams[
+    if hyperparams["lr_scheduler"]:
         scheduler_reduce_on_plateau = ReduceLROnPlateau(
-            optimizer,
-            patience=hyperparams['lr_scheduler_patience'],
-            min_lr=0,
-            factor=0.2
+            optimizer, patience=hyperparams["lr_scheduler_patience"], min_lr=0, factor=0.2
         )
     else:
         # With ReduceLROnPlateau, the scheduler accepts a metric to monitor, so our dummy metric must also be a ReduceLRonPlateau scheduler
-        scheduler_reduce_on_plateau = ReduceLROnPlateau(
-            optimizer,
-            patience=1000000000,
-            min_lr=0,
-            factor=0.2
-        )
+        scheduler_reduce_on_plateau = ReduceLROnPlateau(optimizer, patience=1000000000, min_lr=0, factor=0.2)
 
     return scheduler_warmup, scheduler_reduce_on_plateau
 
 
 def get_scheduler_pretrain(cfg: ConfigPretrain, optimizer: torch.optim.Optimizer):
-
-
     if cfg.optim.cosine_scheduler:
         schedule = get_cosine_with_min_lr_schedule_with_warmup(
-            optimizer,
-            num_warmup_steps=cfg.optim.warmup_steps,
-            num_training_steps=cfg.optim.steps,
-            min_lr_rate=0.1
+            optimizer, num_warmup_steps=cfg.optim.warmup_steps, num_training_steps=cfg.optim.steps, min_lr_rate=0.1
         )
     else:
-        schedule = get_constant_schedule_with_warmup(
-            optimizer,
-            num_warmup_steps=cfg.optim.warmup_steps
-        )
+        schedule = get_constant_schedule_with_warmup(optimizer, num_warmup_steps=cfg.optim.warmup_steps)
 
-    return schedule
+    return schedule
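
get_scheduler above returns a warmup scheduler paired with a ReduceLROnPlateau scheduler (a dummy plateau scheduler with a huge patience when lr_scheduler is disabled, so call sites can step both unconditionally). The loop below shows one way such a pair is typically driven; the model, data, and start_factor value are illustrative stand-ins, not taken from the package.

    # Illustrative driving loop for a (warmup, reduce-on-plateau) scheduler pair
    # like the one built above. Model, data and start_factor are stand-ins.
    import torch
    from torch.optim.lr_scheduler import LinearLR, ReduceLROnPlateau

    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

    warmup_steps = 10
    scheduler_warmup = LinearLR(optimizer, start_factor=0.1, total_iters=warmup_steps)
    scheduler_reduce_on_plateau = ReduceLROnPlateau(optimizer, patience=5, min_lr=0, factor=0.2)

    for step in range(100):
        x, y = torch.randn(8, 4), torch.randn(8, 1)
        loss = torch.nn.functional.mse_loss(model(x), y)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        scheduler_warmup.step()  # linear warmup over the first steps
        scheduler_reduce_on_plateau.step(loss.item())  # plateau scheduler monitors the loss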