autogluon.tabular 1.3.2b20250713__py3-none-any.whl → 1.3.2b20250714__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (32)
  1. autogluon/tabular/models/__init__.py +1 -0
  2. autogluon/tabular/models/mitra/__init__.py +0 -0
  3. autogluon/tabular/models/mitra/_internal/config/config_pretrain.py +190 -0
  4. autogluon/tabular/models/mitra/_internal/config/config_run.py +32 -0
  5. autogluon/tabular/models/mitra/_internal/config/enums.py +145 -0
  6. autogluon/tabular/models/mitra/_internal/core/callbacks.py +94 -0
  7. autogluon/tabular/models/mitra/_internal/core/get_loss.py +55 -0
  8. autogluon/tabular/models/mitra/_internal/core/get_optimizer.py +108 -0
  9. autogluon/tabular/models/mitra/_internal/core/get_scheduler.py +67 -0
  10. autogluon/tabular/models/mitra/_internal/core/prediction_metrics.py +134 -0
  11. autogluon/tabular/models/mitra/_internal/core/trainer_finetune.py +367 -0
  12. autogluon/tabular/models/mitra/_internal/data/collator.py +46 -0
  13. autogluon/tabular/models/mitra/_internal/data/dataset_finetune.py +132 -0
  14. autogluon/tabular/models/mitra/_internal/data/dataset_split.py +53 -0
  15. autogluon/tabular/models/mitra/_internal/data/preprocessor.py +420 -0
  16. autogluon/tabular/models/mitra/_internal/models/base.py +21 -0
  17. autogluon/tabular/models/mitra/_internal/models/embedding.py +182 -0
  18. autogluon/tabular/models/mitra/_internal/models/tab2d.py +667 -0
  19. autogluon/tabular/models/mitra/_internal/utils/set_seed.py +15 -0
  20. autogluon/tabular/models/mitra/mitra_model.py +214 -0
  21. autogluon/tabular/models/mitra/sklearn_interface.py +462 -0
  22. autogluon/tabular/registry/_ag_model_registry.py +2 -0
  23. autogluon/tabular/version.py +1 -1
  24. {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250714.dist-info}/METADATA +19 -10
  25. {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250714.dist-info}/RECORD +32 -12
  26. /autogluon.tabular-1.3.2b20250713-py3.9-nspkg.pth → /autogluon.tabular-1.3.2b20250714-py3.9-nspkg.pth +0 -0
  27. {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250714.dist-info}/LICENSE +0 -0
  28. {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250714.dist-info}/NOTICE +0 -0
  29. {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250714.dist-info}/WHEEL +0 -0
  30. {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250714.dist-info}/namespace_packages.txt +0 -0
  31. {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250714.dist-info}/top_level.txt +0 -0
  32. {autogluon.tabular-1.3.2b20250713.dist-info → autogluon.tabular-1.3.2b20250714.dist-info}/zip-safe +0 -0
autogluon/tabular/models/__init__.py
@@ -23,6 +23,7 @@ from .tabicl.tabicl_model import TabICLModel
  from .tabm.tabm_model import TabMModel
  from .tabpfnv2.tabpfnv2_model import TabPFNV2Model
  from .tabpfnmix.tabpfnmix_model import TabPFNMixModel
+ from .mitra.mitra_model import MitraModel
  from .tabular_nn.torch.tabular_nn_torch import TabularNeuralNetTorchModel
  from .text_prediction.text_prediction_v1_model import TextPredictorModel
  from .xgboost.xgboost_model import XGBoostModel
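With this import in place, the new model class is exposed alongside the other tabular models. A minimal smoke test of the added export (not part of the diff), assuming the 1.3.2b20250714 wheel is installed; MitraModel's training API lives in mitra_model.py, which is not shown in this hunk:

    # Hypothetical check that the new export resolves.
    from autogluon.tabular.models import MitraModel

    print(MitraModel.__module__)  # expected: autogluon.tabular.models.mitra.mitra_model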
autogluon/tabular/models/mitra/__init__.py: File without changes
autogluon/tabular/models/mitra/_internal/config/config_pretrain.py
@@ -0,0 +1,190 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Optional
+ import yaml
+ import os
+
+ import torch
+ from omegaconf import DictConfig, OmegaConf
+
+ from ..._internal.config.enums import GeneratorName, ModelName, LossName, Task
+
+ @dataclass
+ class ConfigData():
+     generator: GeneratorName
+     min_samples_support: int
+     max_samples_support: int
+     n_samples_query: int
+     min_features: int
+     max_features: int
+     max_classes: int
+     sample_multinomial_categorical: bool
+     sample_multinomial_label: bool
+     generator_hyperparams: dict
+     task: Task
+
+     def __post_init__(self):
+
+         assert self.min_samples_support <= self.max_samples_support
+         assert self.min_features <= self.max_features
+
+ @dataclass
+ class ConfigModel():
+     name: ModelName
+     hyperparams: dict
+
+
+ @dataclass
+ class ConfigPreprocessing():
+     use_quantile_transformer: bool
+     use_feature_count_scaling: bool
+
+ @dataclass
+ class ConfigGradScaler():
+     enabled: bool
+     scale_init: float
+     scale_min: float
+     growth_interval: int
+
+
+     def __post_init__(self):
+         assert self.scale_init >= self.scale_min, "Scale init must be greater than scale min"
+         assert self.scale_min >= 1, "Scale min lower than 1 makes no sense for mixed precision training"
+         assert type(self.scale_init) == float, "Scale init must be a float, otherwise gradscaler will return an error"
+         assert type(self.scale_min) == float, "Scale min must be a float, otherwise gradscaler will return an error"
+
+ @dataclass
+ class ConfigOptim():
+     steps: int
+     log_every_n_steps: int
+     eval_every_n_steps: int
+     batch_size: int
+     gradient_accumulation_steps: int
+     lr: float
+     weight_decay: float
+     beta1: float
+     beta2: float
+     warmup_steps: int
+     cosine_scheduler: bool
+     max_grad_norm: float
+     label_smoothing: float
+     regression_loss: LossName
+     use_pretrained_weights: bool
+     path_to_weights: str
+     resume_states: bool
+     path_to_states: str
+     precision: str
+     grad_scaler: ConfigGradScaler
+
+     @classmethod
+     def from_hydra(cls, cfg_hydra: DictConfig) -> Self:
+
+         grad_scaler = ConfigGradScaler(**cfg_hydra.grad_scaler)
+         cfg_dict: dict = OmegaConf.to_container(cfg_hydra) # type: ignore
+         del cfg_dict["grad_scaler"]
+
+         regression_loss = LossName[cfg_dict["regression_loss"]]
+         del cfg_dict["regression_loss"]
+
+         return cls(
+             grad_scaler=grad_scaler,
+             regression_loss=regression_loss,
+             **cfg_dict
+         )
+
+     def __post_init__(self):
+         assert hasattr(torch, self.precision), f"Precision {self.precision} not supported by torch"
+
+ class ConfigSaveLoadMixin(yaml.YAMLObject):
+
+     def save(self, path: Path) -> None:
+
+         path.parent.mkdir(parents=True, exist_ok=True)
+
+         with open(path, 'w') as f:
+             yaml.dump(self, f, default_flow_style=False)
+
+
+     @classmethod
+     def load(cls, path: Path) -> Self:
+
+         with open(path, 'r') as f:
+             # It's unsafe, but not unsafer than the pickle module
+             config = yaml.unsafe_load(f)
+
+         return config
+
+ @dataclass
+ class ConfigPretrain(ConfigSaveLoadMixin):
+     run_name: str
+     output_dir: Path
+     seed: int
+     devices: list[torch.device]
+     device: torch.device
+     max_cpus_per_device: Optional[int]
+     use_ddp: bool
+     workers_per_gpu: int
+     model: ConfigModel
+     data: ConfigData
+     optim: ConfigOptim
+     preprocessing: ConfigPreprocessing
+     load_from_file: bool
+     load_path_x: str
+     load_path_y: str
+     save_file: bool
+     save_file_only: bool
+     save_path_x: str
+     save_path_y: str
+     number_of_runs: int
+
+     @classmethod
+     def from_hydra(cls, cfg_hydra: DictConfig):
+
+         assert not os.path.exists(cfg_hydra.output_dir), f'Output directory {cfg_hydra.output_dir} already exists! Please change to a new folder.'
+
+         output_dir = Path(cfg_hydra.output_dir)
+
+         devices = [torch.device(device) for device in cfg_hydra.devices]
+
+         # Initialize device to cpu, DDP will overwrite this
+         device = torch.device("cpu")
+
+         return cls(
+             run_name=cfg_hydra.run_name,
+             output_dir=output_dir,
+             devices=devices,
+             device=device,
+             max_cpus_per_device=cfg_hydra.max_cpus_per_device,
+             use_ddp=len(devices) > 1,
+             seed=cfg_hydra.seed,
+             workers_per_gpu=cfg_hydra.workers_per_gpu,
+             model = ConfigModel(
+                 name = ModelName[cfg_hydra.model.name],
+                 hyperparams = OmegaConf.to_container(cfg_hydra.model.hyperparams),
+             ),
+             data = ConfigData(
+                 generator=GeneratorName(cfg_hydra.data.generator),
+                 min_samples_support=cfg_hydra.data.min_samples_support,
+                 max_samples_support=cfg_hydra.data.max_samples_support,
+                 n_samples_query=cfg_hydra.data.n_samples_query,
+                 min_features=cfg_hydra.data.min_features,
+                 max_features=cfg_hydra.data.max_features,
+                 max_classes=cfg_hydra.data.max_classes,
+                 task=Task[cfg_hydra.data.task],
+                 sample_multinomial_categorical=cfg_hydra.data.sample_multinomial_categorical,
+                 sample_multinomial_label=cfg_hydra.data.sample_multinomial_label,
+                 generator_hyperparams=OmegaConf.to_container(cfg_hydra.data.generator_hyperparams), # type: ignore
+             ),
+             optim = ConfigOptim.from_hydra(cfg_hydra.optim),
+             preprocessing = ConfigPreprocessing(**cfg_hydra.preprocessing),
+             load_from_file = cfg_hydra.load_from_file,
+             load_path_x = cfg_hydra.load_path_x,
+             load_path_y = cfg_hydra.load_path_y,
+             save_file = cfg_hydra.save_file,
+             save_file_only = cfg_hydra.save_file_only,
+             save_path_x = cfg_hydra.save_path_x,
+             save_path_y = cfg_hydra.save_path_y,
+             number_of_runs = cfg_hydra.number_of_runs,
+         )
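The dataclasses above validate themselves in __post_init__. A small illustration of the ConfigGradScaler checks (not code from the package; values are made up, and it assumes the wheel plus its torch/omegaconf dependencies on a Python version where the enums module's StrEnum is available, i.e. 3.11+):

    # Hypothetical demonstration of ConfigGradScaler validation.
    from autogluon.tabular.models.mitra._internal.config.config_pretrain import ConfigGradScaler

    ok = ConfigGradScaler(enabled=True, scale_init=65536.0, scale_min=1.0, growth_interval=2000)

    try:
        # Passing an int for scale_init trips the type assertion in __post_init__.
        ConfigGradScaler(enabled=True, scale_init=65536, scale_min=1.0, growth_interval=2000)
    except AssertionError as e:
        print(e)  # "Scale init must be a float, otherwise gradscaler will return an error"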
autogluon/tabular/models/mitra/_internal/config/config_run.py
@@ -0,0 +1,32 @@
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import Self
+
+ import torch
+
+ from ..._internal.config.config_pretrain import ConfigSaveLoadMixin
+ from ..._internal.config.enums import ModelName
+
+ @dataclass
+ class ConfigRun(ConfigSaveLoadMixin):
+     device: torch.device
+     seed: int
+     model_name: ModelName
+     hyperparams: dict
+
+     @classmethod
+     def create(
+         cls,
+         device: torch.device,
+         seed: int,
+         model_name: ModelName,
+         hyperparams: dict
+     ) -> Self:
+
+         return cls(
+             device=device,
+             seed=seed,
+             model_name=model_name,
+             hyperparams=hyperparams
+         )
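ConfigRun.create is just a keyword-friendly constructor. A minimal sketch of calling it with illustrative values (the hyperparams dict content is made up, not the package's defaults):

    import torch
    from autogluon.tabular.models.mitra._internal.config.config_run import ConfigRun
    from autogluon.tabular.models.mitra._internal.config.enums import ModelName

    cfg = ConfigRun.create(
        device=torch.device("cpu"),
        seed=0,
        model_name=ModelName.TAB2D,
        hyperparams={"lr": 1e-4},  # illustrative only
    )
    print(cfg.model_name, cfg.seed)  # Tab2D 0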
autogluon/tabular/models/mitra/_internal/config/enums.py
@@ -0,0 +1,145 @@
+ from enum import IntEnum, StrEnum
+
+
+ class Task(StrEnum):
+     CLASSIFICATION = "classification"
+     REGRESSION = "regression"
+
+
+ class FeatureType(StrEnum):
+     NUMERICAL = "numerical"
+     CATEGORICAL = "categorical"
+     MIXED = "mixed"
+
+
+ class SearchType(StrEnum):
+     DEFAULT = "default"
+     RANDOM = "random"
+
+
+ class DatasetSize(IntEnum):
+     SMALL = 1000
+     MEDIUM = 10000
+     LARGE = 50000
+
+
+ class DataSplit(StrEnum):
+     TRAIN = "train"
+     VALID = "valid"
+     TEST = "test"
+
+
+ class Phase(StrEnum):
+     TRAINING = "training"
+     VALIDATION = "validation"
+     TESTING = "testing"
+
+
+ class ModelName(StrEnum):
+     PLACEHOLDER = "_placeholder_" # This is a placeholder for the current running model
+     FT_TRANSFORMER = "FT-Transformer"
+     TABPFN = "TabPFN"
+     FOUNDATION = "Foundation"
+     FOUNDATION_FLASH = "FoundationFlash"
+     TAB2D = "Tab2D"
+     TAB2D_COL_ROW = "Tab2D_COL_ROW"
+     TAB2D_SDPA = "Tab2D_SDPA"
+     SAINT = "SAINT"
+     MLP = "MLP"
+     MLP_RTDL = "MLP-rtdl"
+     RESNET = "Resnet"
+     RANDOM_FOREST = "RandomForest"
+     XGBOOST = "XGBoost"
+     CATBOOST = "CatBoost"
+     LIGHTGBM = "LightGBM"
+     GRADIENT_BOOSTING_TREE = "GradientBoostingTree"
+     HIST_GRADIENT_BOOSTING_TREE = "HistGradientBoostingTree"
+     LOGISTIC_REGRESSION = "LogisticRegression"
+     LINEAR_REGRESSION = "LinearRegression"
+     DECISION_TREE = "DecisionTree"
+     KNN = "KNN"
+     STG = "STG"
+     SVM = "SVM"
+     TABNET = "TabNet"
+     TABTRANSFORMER = "TabTransformer"
+     DEEPFM = "DeepFM"
+     VIME = "VIME"
+     DANET = "DANet"
+     NODE = "NODE"
+     AUTOGLUON = "AutoGluon"
+
+
+ class ModelClass(StrEnum):
+     BASE = 'base'
+     GBDT = 'GBDT'
+     NN = 'NN'
+     ICLT = 'ICLT'
+
+
+ class DownstreamTask(StrEnum):
+     ZEROSHOT = "zeroshot"
+     FINETUNE = "finetune"
+
+
+
+ class BenchmarkName(StrEnum):
+     DEBUG_CLASSIFICATION = "debug_classification"
+     DEBUG_REGRESSION = "debug_regression"
+     DEBUG_TABZILLA = "debug_tabzilla"
+
+     CATEGORICAL_CLASSIFICATION = "categorical_classification"
+     NUMERICAL_CLASSIFICATION = "numerical_classification"
+     CATEGORICAL_REGRESSION = "categorical_regression"
+     NUMERICAL_REGRESSION = "numerical_regression"
+     CATEGORICAL_CLASSIFICATION_LARGE = "categorical_classification_large"
+     NUMERICAL_CLASSIFICATION_LARGE = "numerical_classification_large"
+     CATEGORICAL_REGRESSION_LARGE = "categorical_regression_large"
+     NUMERICAL_REGRESSION_LARGE = "numerical_regression_large"
+
+     TABZILLA_HARD = "tabzilla_hard"
+     TABZILLA_HARD_MAX_TEN_CLASSES = "tabzilla_hard_max_ten_classes"
+     TABZILLA_HAS_COMPLETED_RUNS = "tabzilla_has_completed_runs"
+
+
+ class BenchmarkOrigin(StrEnum):
+     TABZILLA = "tabzilla"
+     WHYTREES = "whytrees"
+
+
+ class GeneratorName(StrEnum):
+     TABPFN = 'tabpfn'
+     TREE = 'tree'
+     RANDOMFOREST = 'randomforest'
+     NEIGHBOR = 'neighbor'
+     MIX = 'mix'
+     PERLIN = 'perlin'
+     MIX_7 = 'mix_7'
+     MIX_6 = 'mix_6'
+     MIX_5 = 'mix_5'
+     MIX_5_GP = 'mix_5_gp'
+     MIX_4 = 'mix_4'
+     MIX_4_AG = 'mix_4_ag'
+     LR = 'lr'
+     POLY = 'poly'
+     SAMPLE_RF = 'sample_rf'
+     SAMPLE_GP = 'sample_gp'
+     TABREPO = 'tabrepo'
+     MIX_4_TABREPO = 'mix_4_tabrepo'
+     MIX_4_TABPFNV2 = 'mix_4_tabpfnv2'
+
+
+ class MetricName(StrEnum):
+     ACCURACY = "accuracy"
+     F1 = "f1"
+     AUC = "auc"
+     MSE = "mse"
+     MAE = "mae"
+     R2 = "r2"
+     LOG_LOSS = "log_loss"
+     RMSE = "rmse"
+
+
+ class LossName(StrEnum):
+     CROSS_ENTROPY = "cross_entropy"
+     MSE = "mse"
+     MAE = "mae"
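These are StrEnum members (Python 3.11+), so they compare equal to their string values; config_pretrain.py above relies on both name lookup (ModelName[...], Task[...]) and value lookup (GeneratorName(...)). A small illustration, not part of the package:

    from autogluon.tabular.models.mitra._internal.config.enums import GeneratorName, ModelName, Task

    assert Task.CLASSIFICATION == "classification"               # StrEnum members behave like str
    assert ModelName["TAB2D"] is ModelName.TAB2D                  # lookup by member name
    assert GeneratorName("mix_4_ag") is GeneratorName.MIX_4_AG    # lookup by value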
autogluon/tabular/models/mitra/_internal/core/callbacks.py
@@ -0,0 +1,94 @@
+ import numpy as np
+ import torch
+
+
+ class EarlyStopping():
+
+     def __init__(self, patience=10, delta=0.0001, metric='log_loss'):
+
+         self.patience = patience
+         self.counter = 0
+         self.best_score = None
+         self.early_stop = False
+         self.delta = delta
+         self.metric = metric
+
+
+     def __call__(self, val_loss):
+
+         # smaller is better for these metrics
+         if self.metric in ["log_loss", "mse", "mae", "rmse"]:
+             score = -val_loss
+         # larger is better for these metrics
+         elif self.metric in ["accuracy", "roc_auc", "r2"]:
+             score = val_loss
+         else:
+             raise ValueError(f"Unsupported metric: {self.metric}. Supported metrics are: log_loss, mse, mae, rmse, accuracy, roc_auc, r2.")
+
+         if self.best_score is None:
+             self.best_score = score
+         elif score < self.best_score + self.delta:
+             self.counter += 1
+             if self.counter >= self.patience:
+                 self.early_stop = True
+         else:
+             self.best_score = score
+             self.counter = 0
+
+     def we_should_stop(self):
+         return self.early_stop
+
+
+ class Checkpoint():
+
+     def __init__(self):
+         self.curr_best_loss = np.inf
+         self.best_model: dict
+
+     def reset(self, net: torch.nn.Module):
+         self.curr_best_loss = np.inf
+         self.best_model = net.state_dict()
+         for key in self.best_model:
+             self.best_model[key] = self.best_model[key].to('cpu')
+
+
+     def __call__(self, net: torch.nn.Module, loss: float):
+
+         if loss < self.curr_best_loss:
+             self.curr_best_loss = loss
+             self.best_model = net.state_dict()
+             for key in self.best_model:
+                 self.best_model[key] = self.best_model[key].to('cpu')
+
+
+     def set_to_best(self, net):
+         net.load_state_dict(self.best_model)
+
+
+ class EpochStatistics():
+
+     def __init__(self) -> None:
+         self.n = 0
+         self.loss = 0
+         self.score = 0
+
+     def update(self, loss, score, n):
+         self.n += n
+         self.loss += loss * n
+         self.score += score * n
+
+     def get(self):
+         return self.loss / self.n, self.score / self.n
+
+ class TrackOutput():
+
+     def __init__(self) -> None:
+         self.y_true: list[np.ndarray] = []
+         self.y_pred: list[np.ndarray] = []
+
+     def update(self, y_true: np.ndarray, y_pred: np.ndarray):
+         self.y_true.append(y_true)
+         self.y_pred.append(y_pred)
+
+     def get(self):
+         return np.concatenate(self.y_true, axis=0), np.concatenate(self.y_pred, axis=0)
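A short sketch of how the EarlyStopping callback behaves, using made-up validation losses (not code from the package; "log_loss" is a smaller-is-better metric, so the score is negated internally):

    from autogluon.tabular.models.mitra._internal.core.callbacks import EarlyStopping

    early_stopping = EarlyStopping(patience=2, delta=0.0001, metric="log_loss")

    # After the best value 0.65, two consecutive calls fail to improve by more than delta.
    for val_loss in [0.70, 0.65, 0.66, 0.66]:
        early_stopping(val_loss)

    print(early_stopping.we_should_stop())  # True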
autogluon/tabular/models/mitra/_internal/core/get_loss.py
@@ -0,0 +1,55 @@
+ import torch
+ import einops
+
+ from ..._internal.config.config_pretrain import ConfigPretrain
+ from ..._internal.config.config_run import ConfigRun
+ from ..._internal.config.enums import LossName, Task
+
+ class CrossEntropyLossExtraBatch(torch.nn.Module):
+
+     def __init__(self, label_smoothing: float):
+         super().__init__()
+
+         self.loss = torch.nn.CrossEntropyLoss(label_smoothing=label_smoothing)
+
+
+     def forward(self, input, target):
+         """
+         Input has shape (batch_size, num_samples, num_classes)
+         Target has shape (batch_size, num_samples)
+
+         Compared to the original CrossEntropyLoss, accepts (batch_size, num_samples) as batch
+         """
+
+         input = einops.rearrange(input, 'b s c -> (b s) c')
+         target = einops.rearrange(target, 'b s -> (b s)')
+
+         return self.loss(input, target)
+
+ def get_loss(cfg: ConfigRun):
+
+     match (cfg.task, cfg.hyperparams['regression_loss']):
+         case (Task.REGRESSION, LossName.MSE):
+             return torch.nn.MSELoss()
+         case (Task.REGRESSION, LossName.MAE):
+             return torch.nn.L1Loss()
+         case (Task.REGRESSION, LossName.CROSS_ENTROPY):
+             return CrossEntropyLossExtraBatch(cfg.hyperparams['label_smoothing'])
+         case (Task.CLASSIFICATION, _):
+             return CrossEntropyLossExtraBatch(cfg.hyperparams['label_smoothing'])
+         case (_, _):
+             raise ValueError(f"Unsupported task {cfg.task} and (regression) loss {cfg.hyperparams['regression_loss']}")
+
+ def get_loss_pretrain(cfg: ConfigPretrain):
+
+     match (cfg.data.task, cfg.optim.regression_loss):
+         case (Task.REGRESSION, LossName.MSE):
+             return torch.nn.MSELoss()
+         case (Task.REGRESSION, LossName.MAE):
+             return torch.nn.L1Loss()
+         case (Task.REGRESSION, LossName.CROSS_ENTROPY):
+             return CrossEntropyLossExtraBatch(cfg.optim.label_smoothing)
+         case (Task.CLASSIFICATION, _):
+             return CrossEntropyLossExtraBatch(cfg.optim.label_smoothing)
+         case (_, _):
+             raise ValueError(f"Unsupported task {cfg.data.task} and (regression) loss {cfg.optim.regression_loss}")
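A quick shape check of CrossEntropyLossExtraBatch with random tensors (illustrative only; assumes torch and einops are installed, as the module imports both):

    import torch
    from autogluon.tabular.models.mitra._internal.core.get_loss import CrossEntropyLossExtraBatch

    loss_fn = CrossEntropyLossExtraBatch(label_smoothing=0.0)

    logits = torch.randn(4, 16, 10)          # (batch_size, num_samples, num_classes)
    targets = torch.randint(0, 10, (4, 16))  # (batch_size, num_samples)

    loss = loss_fn(logits, targets)          # flattened internally to (64, 10) vs (64,)
    print(loss.shape)                        # torch.Size([]) — a scalar loss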
autogluon/tabular/models/mitra/_internal/core/get_optimizer.py
@@ -0,0 +1,108 @@
+ import torch
+ from torch.optim import SGD, Adam, AdamW
+
+ from ..._internal.config.config_pretrain import ConfigPretrain
+
+
+ def get_optimizer(hyperparams: dict, model: torch.nn.Module) -> torch.optim.Optimizer:
+
+     optimizer: torch.optim.Optimizer
+
+     if hyperparams['optimizer'] == "adam":
+         optimizer = Adam(
+             model.parameters(),
+             lr=hyperparams['lr'],
+             betas=(0.9, 0.999),
+             weight_decay=hyperparams['weight_decay']
+         )
+     elif hyperparams['optimizer'] == "adamw":
+         optimizer = AdamW(
+             model.parameters(),
+             lr=hyperparams['lr'],
+             betas=(0.9, 0.999),
+             weight_decay=hyperparams['weight_decay']
+         )
+     elif hyperparams['optimizer'] == "sgd":
+         optimizer = SGD(
+             model.parameters(),
+             lr=hyperparams['lr'],
+             weight_decay=hyperparams['weight_decay']
+         )
+     else:
+         raise ValueError("Optimizer not recognized")
+
+     return optimizer
+
+
+ def get_optimizer_pretrain(cfg: ConfigPretrain, model: torch.nn.Module) -> torch.optim.Optimizer:
+
+     parameters = [(name, param) for name, param in model.named_parameters()]
+
+     parameters_with_weight_decay = []
+     parameters_without_weight_decay = []
+
+     for name, param in parameters:
+         if name.endswith("bias") or 'norm' in name or 'embedding' in name:
+             parameters_without_weight_decay.append(param)
+         else:
+             parameters_with_weight_decay.append(param)
+
+     optimizer_parameters = [
+         {"params": parameters_with_weight_decay, "weight_decay": cfg.optim.weight_decay},
+         {"params": parameters_without_weight_decay, "weight_decay": 0.0},
+     ]
+
+     optimizer = torch.optim.AdamW(
+         optimizer_parameters,
+         lr=cfg.optim.lr,
+         betas=(cfg.optim.beta1, cfg.optim.beta2),
+         weight_decay=cfg.optim.weight_decay
+     )
+
+     return optimizer
+
+
+ class GradScaler(torch.amp.GradScaler):
+
+     def __init__(
+         self,
+         enabled: bool = True,
+         scale_init: float = 2.**16,
+         scale_min: float = 1.,
+         growth_interval: int = 2000,
+         device: str = 'cuda'
+     ):
+         super().__init__(enabled=enabled, device="cpu", init_scale=scale_init, growth_interval=growth_interval) # type: ignore
+         self._enabled = enabled
+         self.scale_min = scale_min
+         self.device = device
+
+         if not self._enabled:
+             # We write scale=1 to log if the scaler is disabled
+             self._scale = torch.tensor((1,), dtype=torch.float32, device=self.device)
+
+
+     def update(self):
+
+         if not self._enabled:
+             return
+
+         super().update()
+
+         if self._scale < self.scale_min:
+             super().update(self.scale_min)
+
+
+ def move_optimizer_to(optim, device):
+     for param in optim.state.values():
+         # Not sure there are any global tensors in the state dict
+         if isinstance(param, torch.Tensor):
+             param.data = param.data.to(device)
+             if param._grad is not None:
+                 param._grad.data = param._grad.data.to(device)
+         elif isinstance(param, dict):
+             for subparam in param.values():
+                 if isinstance(subparam, torch.Tensor):
+                     subparam.data = subparam.data.to(device)
+                     if subparam._grad is not None:
+                         subparam._grad.data = subparam._grad.data.to(device)
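get_optimizer expects the 'optimizer', 'lr', and 'weight_decay' keys to be present in the hyperparameter dict. A minimal sketch with a toy module and illustrative values (not the package's defaults):

    import torch
    from autogluon.tabular.models.mitra._internal.core.get_optimizer import get_optimizer

    model = torch.nn.Linear(8, 2)
    hyperparams = {"optimizer": "adamw", "lr": 1e-4, "weight_decay": 0.01}  # illustrative values

    optimizer = get_optimizer(hyperparams, model)
    print(type(optimizer).__name__)  # AdamW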
autogluon/tabular/models/mitra/_internal/core/get_scheduler.py
@@ -0,0 +1,67 @@
+ import torch
+ from torch.optim.lr_scheduler import ReduceLROnPlateau, LinearLR
+ from transformers import get_constant_schedule_with_warmup
+ from transformers.optimization import get_cosine_with_min_lr_schedule_with_warmup
+
+ from ..._internal.config.config_pretrain import ConfigPretrain
+
+
+ def get_scheduler(hyperparams: dict, optimizer: torch.optim.Optimizer) -> tuple[torch.optim.lr_scheduler.LambdaLR, ReduceLROnPlateau]:
+
+     warmup_steps = hyperparams['warmup_steps']
+
+     # if warmup_steps > 0:
+     #     scheduler_warmup = torch.optim.lr_scheduler.LambdaLR(
+     #         optimizer, lambda step: min((step + 1) / warmup_steps, 1.0)
+     #     )
+     # else:
+     #     scheduler_warmup = torch.optim.lr_scheduler.LambdaLR(
+     #         optimizer, lambda step: 1.0
+     #     )
+
+     if warmup_steps > 0:
+         scheduler_warmup = LinearLR(
+             optimizer,
+             start_factor=1.0 / warmup_steps,
+             end_factor=1.0,
+             total_iters=warmup_steps,
+         )
+     else:
+         scheduler_warmup = torch.optim.lr_scheduler.ConstantLR(optimizer, factor=1.0, total_iters=1)
+
+     if hyperparams['lr_scheduler']:
+         scheduler_reduce_on_plateau = ReduceLROnPlateau(
+             optimizer,
+             patience=hyperparams['lr_scheduler_patience'],
+             min_lr=0,
+             factor=0.2
+         )
+     else:
+         # With ReduceLROnPlateau, the scheduler accepts a metric to monitor, so our dummy metric must also be a ReduceLRonPlateau scheduler
+         scheduler_reduce_on_plateau = ReduceLROnPlateau(
+             optimizer,
+             patience=1000000000,
+             min_lr=0,
+             factor=0.2
+         )
+
+     return scheduler_warmup, scheduler_reduce_on_plateau
+
+
+ def get_scheduler_pretrain(cfg: ConfigPretrain, optimizer: torch.optim.Optimizer):
+
+
+     if cfg.optim.cosine_scheduler:
+         schedule = get_cosine_with_min_lr_schedule_with_warmup(
+             optimizer,
+             num_warmup_steps=cfg.optim.warmup_steps,
+             num_training_steps=cfg.optim.steps,
+             min_lr_rate=0.1
+         )
+     else:
+         schedule = get_constant_schedule_with_warmup(
+             optimizer,
+             num_warmup_steps=cfg.optim.warmup_steps
+         )
+
+     return schedule
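get_scheduler likewise reads its settings from the hyperparameter dict and returns the warmup scheduler together with a ReduceLROnPlateau scheduler. A minimal sketch with illustrative values (assumes the transformers dependency the module imports is installed):

    import torch
    from autogluon.tabular.models.mitra._internal.core.get_scheduler import get_scheduler

    model = torch.nn.Linear(8, 2)
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

    hyperparams = {"warmup_steps": 100, "lr_scheduler": True, "lr_scheduler_patience": 5}  # illustrative
    scheduler_warmup, scheduler_plateau = get_scheduler(hyperparams, optimizer)

    print(type(scheduler_warmup).__name__)   # LinearLR, since warmup_steps > 0
    print(type(scheduler_plateau).__name__)  # ReduceLROnPlateau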