autogluon.tabular 1.3.2b20250713__py3-none-any.whl → 1.3.2b20250714__py3-none-any.whl
This diff covers publicly available package versions that have been released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
- autogluon/tabular/models/__init__.py +1 -0
- autogluon/tabular/models/mitra/__init__.py +0 -0
- autogluon/tabular/models/mitra/_internal/config/config_pretrain.py +190 -0
- autogluon/tabular/models/mitra/_internal/config/config_run.py +32 -0
- autogluon/tabular/models/mitra/_internal/config/enums.py +145 -0
- autogluon/tabular/models/mitra/_internal/core/callbacks.py +94 -0
- autogluon/tabular/models/mitra/_internal/core/get_loss.py +55 -0
- autogluon/tabular/models/mitra/_internal/core/get_optimizer.py +108 -0
- autogluon/tabular/models/mitra/_internal/core/get_scheduler.py +67 -0
- autogluon/tabular/models/mitra/_internal/core/prediction_metrics.py +134 -0
- autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py +367 -0
- autogluon/tabular/models/mitra/_internal/data/collator.py +46 -0
- autogluon/tabular/models/mitra/_internal/data/dataset_finetune.py +132 -0
- autogluon/tabular/models/mitra/_internal/data/dataset_split.py +53 -0
- autogluon/tabular/models/mitra/_internal/data/preprocessor.py +420 -0
- autogluon/tabular/models/mitra/_internal/models/base.py +21 -0
- autogluon/tabular/models/mitra/_internal/models/embedding.py +182 -0
- autogluon/tabular/models/mitra/_internal/models/tab2d.py +667 -0
- autogluon/tabular/models/mitra/_internal/utils/set_seed.py +15 -0
- autogluon/tabular/models/mitra/mitra_model.py +214 -0
- autogluon/tabular/models/mitra/sklearn_interface.py +462 -0
- autogluon/tabular/registry/_ag_model_registry.py +2 -0
- autogluon/tabular/version.py +1 -1
- {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250714.dist-info}/METADATA +19 -10
- {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250714.dist-info}/RECORD +32 -12
- /autogluon.tabular-1.3.2b20250713-py3.9-nspkg.pth → /autogluon.tabular-1.3.2b20250714-py3.9-nspkg.pth +0 -0
- {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250714.dist-info}/LICENSE +0 -0
- {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250714.dist-info}/NOTICE +0 -0
- {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250714.dist-info}/WHEEL +0 -0
- {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250714.dist-info}/namespace_packages.txt +0 -0
- {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250714.dist-info}/top_level.txt +0 -0
- {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250714.dist-info}/zip-safe +0 -0
autogluon/tabular/models/__init__.py
@@ -23,6 +23,7 @@ from .tabicl.tabicl_model import TabICLModel
 from .tabm.tabm_model import TabMModel
 from .tabpfnv2.tabpfnv2_model import TabPFNV2Model
 from .tabpfnmix.tabpfnmix_model import TabPFNMixModel
+from .mitra.mitra_model import MitraModel
 from .tabular_nn.torch.tabular_nn_torch import TabularNeuralNetTorchModel
 from .text_prediction.text_prediction_v1_model import TextPredictorModel
 from .xgboost.xgboost_model import XGBoostModel

autogluon/tabular/models/mitra/__init__.py: file without changes (new, empty file)
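With the import above, the new model class is exposed from the package's models namespace. A minimal usage sketch (not part of the diff, assuming the 1.3.2b20250714 wheel is installed):

from autogluon.tabular.models import MitraModel  # newly exported model class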
autogluon/tabular/models/mitra/_internal/config/config_pretrain.py (new file)
@@ -0,0 +1,190 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Optional
+import yaml
+import os
+
+import torch
+from omegaconf import DictConfig, OmegaConf
+
+from ..._internal.config.enums import GeneratorName, ModelName, LossName, Task
+
+@dataclass
+class ConfigData():
+    generator: GeneratorName
+    min_samples_support: int
+    max_samples_support: int
+    n_samples_query: int
+    min_features: int
+    max_features: int
+    max_classes: int
+    sample_multinomial_categorical: bool
+    sample_multinomial_label: bool
+    generator_hyperparams: dict
+    task: Task
+
+    def __post_init__(self):
+
+        assert self.min_samples_support <= self.max_samples_support
+        assert self.min_features <= self.max_features
+
+@dataclass
+class ConfigModel():
+    name: ModelName
+    hyperparams: dict
+
+
+@dataclass
+class ConfigPreprocessing():
+    use_quantile_transformer: bool
+    use_feature_count_scaling: bool
+
+@dataclass
+class ConfigGradScaler():
+    enabled: bool
+    scale_init: float
+    scale_min: float
+    growth_interval: int
+
+
+    def __post_init__(self):
+        assert self.scale_init >= self.scale_min, "Scale init must be greater than scale min"
+        assert self.scale_min >= 1, "Scale min lower than 1 makes no sense for mixed precision training"
+        assert type(self.scale_init) == float, "Scale init must be a float, otherwise gradscaler will return an error"
+        assert type(self.scale_min) == float, "Scale min must be a float, otherwise gradscaler will return an error"
+
+@dataclass
+class ConfigOptim():
+    steps: int
+    log_every_n_steps: int
+    eval_every_n_steps: int
+    batch_size: int
+    gradient_accumulation_steps: int
+    lr: float
+    weight_decay: float
+    beta1: float
+    beta2: float
+    warmup_steps: int
+    cosine_scheduler: bool
+    max_grad_norm: float
+    label_smoothing: float
+    regression_loss: LossName
+    use_pretrained_weights: bool
+    path_to_weights: str
+    resume_states: bool
+    path_to_states: str
+    precision: str
+    grad_scaler: ConfigGradScaler
+
+    @classmethod
+    def from_hydra(cls, cfg_hydra: DictConfig) -> Self:
+
+        grad_scaler = ConfigGradScaler(**cfg_hydra.grad_scaler)
+        cfg_dict: dict = OmegaConf.to_container(cfg_hydra) # type: ignore
+        del cfg_dict["grad_scaler"]
+
+        regression_loss = LossName[cfg_dict["regression_loss"]]
+        del cfg_dict["regression_loss"]
+
+        return cls(
+            grad_scaler=grad_scaler,
+            regression_loss=regression_loss,
+            **cfg_dict
+        )
+
+    def __post_init__(self):
+        assert hasattr(torch, self.precision), f"Precision {self.precision} not supported by torch"
+
+class ConfigSaveLoadMixin(yaml.YAMLObject):
+
+    def save(self, path: Path) -> None:
+
+        path.parent.mkdir(parents=True, exist_ok=True)
+
+        with open(path, 'w') as f:
+            yaml.dump(self, f, default_flow_style=False)
+
+
+    @classmethod
+    def load(cls, path: Path) -> Self:
+
+        with open(path, 'r') as f:
+            # It's unsafe, but not unsafer than the pickle module
+            config = yaml.unsafe_load(f)
+
+        return config
+
+@dataclass
+class ConfigPretrain(ConfigSaveLoadMixin):
+    run_name: str
+    output_dir: Path
+    seed: int
+    devices: list[torch.device]
+    device: torch.device
+    max_cpus_per_device: Optional[int]
+    use_ddp: bool
+    workers_per_gpu: int
+    model: ConfigModel
+    data: ConfigData
+    optim: ConfigOptim
+    preprocessing: ConfigPreprocessing
+    load_from_file: bool
+    load_path_x: str
+    load_path_y: str
+    save_file: bool
+    save_file_only: bool
+    save_path_x: str
+    save_path_y: str
+    number_of_runs: int
+
+    @classmethod
+    def from_hydra(cls, cfg_hydra: DictConfig):
+
+        assert not os.path.exists(cfg_hydra.output_dir), f'Output directory {cfg_hydra.output_dir} already exists! Please change to a new folder.'
+
+        output_dir = Path(cfg_hydra.output_dir)
+
+        devices = [torch.device(device) for device in cfg_hydra.devices]
+
+        # Initialize device to cpu, DDP will overwrite this
+        device = torch.device("cpu")
+
+        return cls(
+            run_name=cfg_hydra.run_name,
+            output_dir=output_dir,
+            devices=devices,
+            device=device,
+            max_cpus_per_device=cfg_hydra.max_cpus_per_device,
+            use_ddp=len(devices) > 1,
+            seed=cfg_hydra.seed,
+            workers_per_gpu=cfg_hydra.workers_per_gpu,
+            model = ConfigModel(
+                name = ModelName[cfg_hydra.model.name],
+                hyperparams = OmegaConf.to_container(cfg_hydra.model.hyperparams),
+            ),
+            data = ConfigData(
+                generator=GeneratorName(cfg_hydra.data.generator),
+                min_samples_support=cfg_hydra.data.min_samples_support,
+                max_samples_support=cfg_hydra.data.max_samples_support,
+                n_samples_query=cfg_hydra.data.n_samples_query,
+                min_features=cfg_hydra.data.min_features,
+                max_features=cfg_hydra.data.max_features,
+                max_classes=cfg_hydra.data.max_classes,
+                task=Task[cfg_hydra.data.task],
+                sample_multinomial_categorical=cfg_hydra.data.sample_multinomial_categorical,
+                sample_multinomial_label=cfg_hydra.data.sample_multinomial_label,
+                generator_hyperparams=OmegaConf.to_container(cfg_hydra.data.generator_hyperparams), # type: ignore
+            ),
+            optim = ConfigOptim.from_hydra(cfg_hydra.optim),
+            preprocessing = ConfigPreprocessing(**cfg_hydra.preprocessing),
+            load_from_file = cfg_hydra.load_from_file,
+            load_path_x = cfg_hydra.load_path_x,
+            load_path_y = cfg_hydra.load_path_y,
+            save_file = cfg_hydra.save_file,
+            save_file_only = cfg_hydra.save_file_only,
+            save_path_x = cfg_hydra.save_path_x,
+            save_path_y = cfg_hydra.save_path_y,
+            number_of_runs = cfg_hydra.number_of_runs,
+        )
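ConfigSaveLoadMixin above gives every config dataclass a YAML save/load round trip (yaml.dump on save, yaml.unsafe_load on load). A minimal sketch, assuming the new wheel is installed; DemoConfig is a hypothetical dataclass used only for illustration:

from dataclasses import dataclass
from pathlib import Path

from autogluon.tabular.models.mitra._internal.config.config_pretrain import ConfigSaveLoadMixin

@dataclass
class DemoConfig(ConfigSaveLoadMixin):  # hypothetical config, not part of the package
    run_name: str
    seed: int

cfg = DemoConfig(run_name="demo", seed=0)
cfg.save(Path("out/demo.yaml"))                    # creates parent dirs, dumps the object as YAML
restored = DemoConfig.load(Path("out/demo.yaml"))  # yaml.unsafe_load; like pickle, it trusts the file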
autogluon/tabular/models/mitra/_internal/config/config_run.py (new file)
@@ -0,0 +1,32 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import Self
+
+import torch
+
+from ..._internal.config.config_pretrain import ConfigSaveLoadMixin
+from ..._internal.config.enums import ModelName
+
+@dataclass
+class ConfigRun(ConfigSaveLoadMixin):
+    device: torch.device
+    seed: int
+    model_name: ModelName
+    hyperparams: dict
+
+    @classmethod
+    def create(
+        cls,
+        device: torch.device,
+        seed: int,
+        model_name: ModelName,
+        hyperparams: dict
+    ) -> Self:
+
+        return cls(
+            device=device,
+            seed=seed,
+            model_name=model_name,
+            hyperparams=hyperparams
+        )
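ConfigRun is a thin container for a single run; create simply forwards its arguments. A minimal sketch, assuming the new wheel is installed (the hyperparameter dict shown is hypothetical):

import torch

from autogluon.tabular.models.mitra._internal.config.config_run import ConfigRun
from autogluon.tabular.models.mitra._internal.config.enums import ModelName

cfg = ConfigRun.create(
    device=torch.device("cpu"),
    seed=0,
    model_name=ModelName.TAB2D,
    hyperparams={"lr": 1e-4},  # hypothetical; real runs pass the full hyperparameter dict
)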
autogluon/tabular/models/mitra/_internal/config/enums.py (new file)
@@ -0,0 +1,145 @@
+from enum import IntEnum, StrEnum
+
+
+class Task(StrEnum):
+    CLASSIFICATION = "classification"
+    REGRESSION = "regression"
+
+
+class FeatureType(StrEnum):
+    NUMERICAL = "numerical"
+    CATEGORICAL = "categorical"
+    MIXED = "mixed"
+
+
+class SearchType(StrEnum):
+    DEFAULT = "default"
+    RANDOM = "random"
+
+
+class DatasetSize(IntEnum):
+    SMALL = 1000
+    MEDIUM = 10000
+    LARGE = 50000
+
+
+class DataSplit(StrEnum):
+    TRAIN = "train"
+    VALID = "valid"
+    TEST = "test"
+
+
+class Phase(StrEnum):
+    TRAINING = "training"
+    VALIDATION = "validation"
+    TESTING = "testing"
+
+
+class ModelName(StrEnum):
+    PLACEHOLDER = "_placeholder_" # This is a placeholder for the current running model
+    FT_TRANSFORMER = "FT-Transformer"
+    TABPFN = "TabPFN"
+    FOUNDATION = "Foundation"
+    FOUNDATION_FLASH = "FoundationFlash"
+    TAB2D = "Tab2D"
+    TAB2D_COL_ROW = "Tab2D_COL_ROW"
+    TAB2D_SDPA = "Tab2D_SDPA"
+    SAINT = "SAINT"
+    MLP = "MLP"
+    MLP_RTDL = "MLP-rtdl"
+    RESNET = "Resnet"
+    RANDOM_FOREST = "RandomForest"
+    XGBOOST = "XGBoost"
+    CATBOOST = "CatBoost"
+    LIGHTGBM = "LightGBM"
+    GRADIENT_BOOSTING_TREE = "GradientBoostingTree"
+    HIST_GRADIENT_BOOSTING_TREE = "HistGradientBoostingTree"
+    LOGISTIC_REGRESSION = "LogisticRegression"
+    LINEAR_REGRESSION = "LinearRegression"
+    DECISION_TREE = "DecisionTree"
+    KNN = "KNN"
+    STG = "STG"
+    SVM = "SVM"
+    TABNET = "TabNet"
+    TABTRANSFORMER = "TabTransformer"
+    DEEPFM = "DeepFM"
+    VIME = "VIME"
+    DANET = "DANet"
+    NODE = "NODE"
+    AUTOGLUON = "AutoGluon"
+
+
+class ModelClass(StrEnum):
+    BASE = 'base'
+    GBDT = 'GBDT'
+    NN = 'NN'
+    ICLT = 'ICLT'
+
+
+class DownstreamTask(StrEnum):
+    ZEROSHOT = "zeroshot"
+    FINETUNE = "finetune"
+
+
+
+class BenchmarkName(StrEnum):
+    DEBUG_CLASSIFICATION = "debug_classification"
+    DEBUG_REGRESSION = "debug_regression"
+    DEBUG_TABZILLA = "debug_tabzilla"
+
+    CATEGORICAL_CLASSIFICATION = "categorical_classification"
+    NUMERICAL_CLASSIFICATION = "numerical_classification"
+    CATEGORICAL_REGRESSION = "categorical_regression"
+    NUMERICAL_REGRESSION = "numerical_regression"
+    CATEGORICAL_CLASSIFICATION_LARGE = "categorical_classification_large"
+    NUMERICAL_CLASSIFICATION_LARGE = "numerical_classification_large"
+    CATEGORICAL_REGRESSION_LARGE = "categorical_regression_large"
+    NUMERICAL_REGRESSION_LARGE = "numerical_regression_large"
+
+    TABZILLA_HARD = "tabzilla_hard"
+    TABZILLA_HARD_MAX_TEN_CLASSES = "tabzilla_hard_max_ten_classes"
+    TABZILLA_HAS_COMPLETED_RUNS = "tabzilla_has_completed_runs"
+
+
+class BenchmarkOrigin(StrEnum):
+    TABZILLA = "tabzilla"
+    WHYTREES = "whytrees"
+
+
+class GeneratorName(StrEnum):
+    TABPFN = 'tabpfn'
+    TREE = 'tree'
+    RANDOMFOREST = 'randomforest'
+    NEIGHBOR = 'neighbor'
+    MIX = 'mix'
+    PERLIN = 'perlin'
+    MIX_7 = 'mix_7'
+    MIX_6 = 'mix_6'
+    MIX_5 = 'mix_5'
+    MIX_5_GP = 'mix_5_gp'
+    MIX_4 = 'mix_4'
+    MIX_4_AG = 'mix_4_ag'
+    LR = 'lr'
+    POLY = 'poly'
+    SAMPLE_RF = 'sample_rf'
+    SAMPLE_GP = 'sample_gp'
+    TABREPO = 'tabrepo'
+    MIX_4_TABREPO = 'mix_4_tabrepo'
+    MIX_4_TABPFNV2 = 'mix_4_tabpfnv2'
+
+
+class MetricName(StrEnum):
+    ACCURACY = "accuracy"
+    F1 = "f1"
+    AUC = "auc"
+    MSE = "mse"
+    MAE = "mae"
+    R2 = "r2"
+    LOG_LOSS = "log_loss"
+    RMSE = "rmse"
+
+
+class LossName(StrEnum):
+    CROSS_ENTROPY = "cross_entropy"
+    MSE = "mse"
+    MAE = "mae"
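Note that config_pretrain.py above uses two different lookup styles for these enums: square brackets look a member up by name, while a call looks it up by value, and StrEnum members compare equal to their string values. A small sketch of the difference, assuming the new wheel is installed:

from autogluon.tabular.models.mitra._internal.config.enums import GeneratorName, Task

assert Task["CLASSIFICATION"] is Task.CLASSIFICATION    # lookup by member name, as in Task[cfg.data.task]
assert GeneratorName("tabpfn") is GeneratorName.TABPFN  # lookup by value, as in GeneratorName(cfg.data.generator)
assert Task.CLASSIFICATION == "classification"          # StrEnum members compare equal to their values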
autogluon/tabular/models/mitra/_internal/core/callbacks.py (new file)
@@ -0,0 +1,94 @@
+import numpy as np
+import torch
+
+
+class EarlyStopping():
+
+    def __init__(self, patience=10, delta=0.0001, metric='log_loss'):
+
+        self.patience = patience
+        self.counter = 0
+        self.best_score = None
+        self.early_stop = False
+        self.delta = delta
+        self.metric = metric
+
+
+    def __call__(self, val_loss):
+
+        # smaller is better for these metrics
+        if self.metric in ["log_loss", "mse", "mae", "rmse"]:
+            score = -val_loss
+        # larger is better for these metrics
+        elif self.metric in ["accuracy", "roc_auc", "r2"]:
+            score = val_loss
+        else:
+            raise ValueError(f"Unsupported metric: {self.metric}. Supported metrics are: log_loss, mse, mae, rmse, accuracy, roc_auc, r2.")
+
+        if self.best_score is None:
+            self.best_score = score
+        elif score < self.best_score + self.delta:
+            self.counter += 1
+            if self.counter >= self.patience:
+                self.early_stop = True
+        else:
+            self.best_score = score
+            self.counter = 0
+
+    def we_should_stop(self):
+        return self.early_stop
+
+
+class Checkpoint():
+
+    def __init__(self):
+        self.curr_best_loss = np.inf
+        self.best_model: dict
+
+    def reset(self, net: torch.nn.Module):
+        self.curr_best_loss = np.inf
+        self.best_model = net.state_dict()
+        for key in self.best_model:
+            self.best_model[key] = self.best_model[key].to('cpu')
+
+
+    def __call__(self, net: torch.nn.Module, loss: float):
+
+        if loss < self.curr_best_loss:
+            self.curr_best_loss = loss
+            self.best_model = net.state_dict()
+            for key in self.best_model:
+                self.best_model[key] = self.best_model[key].to('cpu')
+
+
+    def set_to_best(self, net):
+        net.load_state_dict(self.best_model)
+
+
+class EpochStatistics():
+
+    def __init__(self) -> None:
+        self.n = 0
+        self.loss = 0
+        self.score = 0
+
+    def update(self, loss, score, n):
+        self.n += n
+        self.loss += loss * n
+        self.score += score * n
+
+    def get(self):
+        return self.loss / self.n, self.score / self.n
+
+class TrackOutput():
+
+    def __init__(self) -> None:
+        self.y_true: list[np.ndarray] = []
+        self.y_pred: list[np.ndarray] = []
+
+    def update(self, y_true: np.ndarray, y_pred: np.ndarray):
+        self.y_true.append(y_true)
+        self.y_pred.append(y_pred)
+
+    def get(self):
+        return np.concatenate(self.y_true, axis=0), np.concatenate(self.y_pred, axis=0)
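A hypothetical sketch of how these callbacks compose in a validation loop (the model, loop, and loss values are made up for illustration; the import path assumes the new wheel is installed):

import torch

from autogluon.tabular.models.mitra._internal.core.callbacks import Checkpoint, EarlyStopping

net = torch.nn.Linear(4, 2)                      # stand-in for the real model
early_stopping = EarlyStopping(patience=3, metric="log_loss")
checkpoint = Checkpoint()
checkpoint.reset(net)

for val_loss in [0.9, 0.7, 0.71, 0.72, 0.73]:    # pretend validation losses
    checkpoint(net, val_loss)                    # keeps a CPU copy of the best state_dict
    early_stopping(val_loss)                     # "smaller is better", so the score is negated internally
    if early_stopping.we_should_stop():
        break

checkpoint.set_to_best(net)                      # restore the best weights before prediction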
autogluon/tabular/models/mitra/_internal/core/get_loss.py (new file)
@@ -0,0 +1,55 @@
+import torch
+import einops
+
+from ..._internal.config.config_pretrain import ConfigPretrain
+from ..._internal.config.config_run import ConfigRun
+from ..._internal.config.enums import LossName, Task
+
+class CrossEntropyLossExtraBatch(torch.nn.Module):
+
+    def __init__(self, label_smoothing: float):
+        super().__init__()
+
+        self.loss = torch.nn.CrossEntropyLoss(label_smoothing=label_smoothing)
+
+
+    def forward(self, input, target):
+        """
+        Input has shape (batch_size, num_samples, num_classes)
+        Target has shape (batch_size, num_samples)
+
+        Compared to the original CrossEntropyLoss, accepts (batch_size, num_samples) as batch
+        """
+
+        input = einops.rearrange(input, 'b s c -> (b s) c')
+        target = einops.rearrange(target, 'b s -> (b s)')
+
+        return self.loss(input, target)
+
+def get_loss(cfg: ConfigRun):
+
+    match (cfg.task, cfg.hyperparams['regression_loss']):
+        case (Task.REGRESSION, LossName.MSE):
+            return torch.nn.MSELoss()
+        case (Task.REGRESSION, LossName.MAE):
+            return torch.nn.L1Loss()
+        case (Task.REGRESSION, LossName.CROSS_ENTROPY):
+            return CrossEntropyLossExtraBatch(cfg.hyperparams['label_smoothing'])
+        case (Task.CLASSIFICATION, _):
+            return CrossEntropyLossExtraBatch(cfg.hyperparams['label_smoothing'])
+        case (_, _):
+            raise ValueError(f"Unsupported task {cfg.task} and (regression) loss {cfg.hyperparams['regression_loss']}")
+
+def get_loss_pretrain(cfg: ConfigPretrain):
+
+    match (cfg.data.task, cfg.optim.regression_loss):
+        case (Task.REGRESSION, LossName.MSE):
+            return torch.nn.MSELoss()
+        case (Task.REGRESSION, LossName.MAE):
+            return torch.nn.L1Loss()
+        case (Task.REGRESSION, LossName.CROSS_ENTROPY):
+            return CrossEntropyLossExtraBatch(cfg.optim.label_smoothing)
+        case (Task.CLASSIFICATION, _):
+            return CrossEntropyLossExtraBatch(cfg.optim.label_smoothing)
+        case (_, _):
+            raise ValueError(f"Unsupported task {cfg.data.task} and (regression) loss {cfg.optim.regression_loss}")
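CrossEntropyLossExtraBatch simply flattens the extra dataset-batch dimension before calling torch.nn.CrossEntropyLoss. A minimal shape sketch (random tensors, for illustration only; import path assumes the new wheel is installed):

import torch

from autogluon.tabular.models.mitra._internal.core.get_loss import CrossEntropyLossExtraBatch

loss_fn = CrossEntropyLossExtraBatch(label_smoothing=0.0)
logits = torch.randn(2, 8, 5)              # (batch_size, num_samples, num_classes)
targets = torch.randint(0, 5, (2, 8))      # (batch_size, num_samples)
loss = loss_fn(logits, targets)            # rearranged to (16, 5) vs (16,) internally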
autogluon/tabular/models/mitra/_internal/core/get_optimizer.py (new file)
@@ -0,0 +1,108 @@
+import torch
+from torch.optim import SGD, Adam, AdamW
+
+from ..._internal.config.config_pretrain import ConfigPretrain
+
+
+def get_optimizer(hyperparams: dict, model: torch.nn.Module) -> torch.optim.Optimizer:
+
+    optimizer: torch.optim.Optimizer
+
+    if hyperparams['optimizer'] == "adam":
+        optimizer = Adam(
+            model.parameters(),
+            lr=hyperparams['lr'],
+            betas=(0.9, 0.999),
+            weight_decay=hyperparams['weight_decay']
+        )
+    elif hyperparams['optimizer'] == "adamw":
+        optimizer = AdamW(
+            model.parameters(),
+            lr=hyperparams['lr'],
+            betas=(0.9, 0.999),
+            weight_decay=hyperparams['weight_decay']
+        )
+    elif hyperparams['optimizer'] == "sgd":
+        optimizer = SGD(
+            model.parameters(),
+            lr=hyperparams['lr'],
+            weight_decay=hyperparams['weight_decay']
+        )
+    else:
+        raise ValueError("Optimizer not recognized")
+
+    return optimizer
+
+
+def get_optimizer_pretrain(cfg: ConfigPretrain, model: torch.nn.Module) -> torch.optim.Optimizer:
+
+    parameters = [(name, param) for name, param in model.named_parameters()]
+
+    parameters_with_weight_decay = []
+    parameters_without_weight_decay = []
+
+    for name, param in parameters:
+        if name.endswith("bias") or 'norm' in name or 'embedding' in name:
+            parameters_without_weight_decay.append(param)
+        else:
+            parameters_with_weight_decay.append(param)
+
+    optimizer_parameters = [
+        {"params": parameters_with_weight_decay, "weight_decay": cfg.optim.weight_decay},
+        {"params": parameters_without_weight_decay, "weight_decay": 0.0},
+    ]
+
+    optimizer = torch.optim.AdamW(
+        optimizer_parameters,
+        lr=cfg.optim.lr,
+        betas=(cfg.optim.beta1, cfg.optim.beta2),
+        weight_decay=cfg.optim.weight_decay
+    )
+
+    return optimizer
+
+
+class GradScaler(torch.amp.GradScaler):
+
+    def __init__(
+        self,
+        enabled: bool = True,
+        scale_init: float = 2.**16,
+        scale_min: float = 1.,
+        growth_interval: int = 2000,
+        device: str = 'cuda'
+    ):
+        super().__init__(enabled=enabled, device="cpu", init_scale=scale_init, growth_interval=growth_interval) # type: ignore
+        self._enabled = enabled
+        self.scale_min = scale_min
+        self.device = device
+
+        if not self._enabled:
+            # We write scale=1 to log if the scaler is disabled
+            self._scale = torch.tensor((1,), dtype=torch.float32, device=self.device)
+
+
+    def update(self):
+
+        if not self._enabled:
+            return
+
+        super().update()
+
+        if self._scale < self.scale_min:
+            super().update(self.scale_min)
+
+
+def move_optimizer_to(optim, device):
+    for param in optim.state.values():
+        # Not sure there are any global tensors in the state dict
+        if isinstance(param, torch.Tensor):
+            param.data = param.data.to(device)
+            if param._grad is not None:
+                param._grad.data = param._grad.data.to(device)
+        elif isinstance(param, dict):
+            for subparam in param.values():
+                if isinstance(subparam, torch.Tensor):
+                    subparam.data = subparam.data.to(device)
+                    if subparam._grad is not None:
+                        subparam._grad.data = subparam._grad.data.to(device)
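get_optimizer is driven by a plain hyperparameter dict whose keys ('optimizer', 'lr', 'weight_decay') appear in the code above, while get_optimizer_pretrain additionally excludes bias, norm, and embedding parameters from weight decay. A minimal sketch of the fine-tuning path (model and values are hypothetical; import path assumes the new wheel is installed):

import torch

from autogluon.tabular.models.mitra._internal.core.get_optimizer import get_optimizer

model = torch.nn.Linear(10, 2)                                          # stand-in model
hyperparams = {"optimizer": "adamw", "lr": 1e-4, "weight_decay": 0.01}  # hypothetical values
optimizer = get_optimizer(hyperparams, model)                           # AdamW with betas=(0.9, 0.999)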
autogluon/tabular/models/mitra/_internal/core/get_scheduler.py (new file)
@@ -0,0 +1,67 @@
+import torch
+from torch.optim.lr_scheduler import ReduceLROnPlateau, LinearLR
+from transformers import get_constant_schedule_with_warmup
+from transformers.optimization import get_cosine_with_min_lr_schedule_with_warmup
+
+from ..._internal.config.config_pretrain import ConfigPretrain
+
+
+def get_scheduler(hyperparams: dict, optimizer: torch.optim.Optimizer) -> tuple[torch.optim.lr_scheduler.LambdaLR, ReduceLROnPlateau]:
+
+    warmup_steps = hyperparams['warmup_steps']
+
+    # if warmup_steps > 0:
+    #     scheduler_warmup = torch.optim.lr_scheduler.LambdaLR(
+    #         optimizer, lambda step: min((step + 1) / warmup_steps, 1.0)
+    #     )
+    # else:
+    #     scheduler_warmup = torch.optim.lr_scheduler.LambdaLR(
+    #         optimizer, lambda step: 1.0
+    #     )
+
+    if warmup_steps > 0:
+        scheduler_warmup = LinearLR(
+            optimizer,
+            start_factor=1.0 / warmup_steps,
+            end_factor=1.0,
+            total_iters=warmup_steps,
+        )
+    else:
+        scheduler_warmup = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0, total_iters=1)
+
+    if hyperparams['lr_scheduler']:
+        scheduler_reduce_on_plateau = ReduceLROnPlateau(
+            optimizer,
+            patience=hyperparams['lr_scheduler_patience'],
+            min_lr=0,
+            factor=0.2
+        )
+    else:
+        # With ReduceLROnPlateau, the scheduler accepts a metric to monitor, so our dummy metric must also be a ReduceLRonPlateau scheduler
+        scheduler_reduce_on_plateau = ReduceLROnPlateau(
+            optimizer,
+            patience=1000000000,
+            min_lr=0,
+            factor=0.2
+        )
+
+    return scheduler_warmup, scheduler_reduce_on_plateau
+
+
+def get_scheduler_pretrain(cfg: ConfigPretrain, optimizer: torch.optim.Optimizer):
+
+
+    if cfg.optim.cosine_scheduler:
+        schedule = get_cosine_with_min_lr_schedule_with_warmup(
+            optimizer,
+            num_warmup_steps=cfg.optim.warmup_steps,
+            num_training_steps=cfg.optim.steps,
+            min_lr_rate=0.1
+        )
+    else:
+        schedule = get_constant_schedule_with_warmup(
+            optimizer,
+            num_warmup_steps=cfg.optim.warmup_steps
+        )
+
+    return schedule
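get_scheduler returns a pair: a warmup scheduler meant to be stepped every optimizer step, and a ReduceLROnPlateau that expects the monitored validation metric. A minimal sketch with hypothetical hyperparameter values (keys taken from the code above; assumes the new wheel and transformers are installed):

import torch

from autogluon.tabular.models.mitra._internal.core.get_scheduler import get_scheduler

model = torch.nn.Linear(10, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
hyperparams = {"warmup_steps": 100, "lr_scheduler": True, "lr_scheduler_patience": 5}

scheduler_warmup, scheduler_plateau = get_scheduler(hyperparams, optimizer)
scheduler_warmup.step()        # once per optimizer step during warmup
scheduler_plateau.step(0.42)   # once per evaluation, with the monitored validation metric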