project-llm-trainer 0.9.1__py3-none-any.whl → 0.9.2__py3-none-any.whl
This diff shows the changes between two publicly available versions of the package, as released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
Potentially problematic release.
This version of project-llm-trainer might be problematic. Click here for more details.
- llm_trainer/train_configs.py +3 -2
- llm_trainer/trainer.py +21 -3
- {project_llm_trainer-0.9.1.dist-info → project_llm_trainer-0.9.2.dist-info}/METADATA +1 -1
- {project_llm_trainer-0.9.1.dist-info → project_llm_trainer-0.9.2.dist-info}/RECORD +13 -13
- {project_llm_trainer-0.9.1.data → project_llm_trainer-0.9.2.data}/scripts/calc_intermediate_size +0 -0
- {project_llm_trainer-0.9.1.data → project_llm_trainer-0.9.2.data}/scripts/ddp_train +0 -0
- {project_llm_trainer-0.9.1.data → project_llm_trainer-0.9.2.data}/scripts/ds_train +0 -0
- {project_llm_trainer-0.9.1.data → project_llm_trainer-0.9.2.data}/scripts/plot_loss +0 -0
- {project_llm_trainer-0.9.1.data → project_llm_trainer-0.9.2.data}/scripts/plot_lr +0 -0
- {project_llm_trainer-0.9.1.data → project_llm_trainer-0.9.2.data}/scripts/py_train +0 -0
- {project_llm_trainer-0.9.1.data → project_llm_trainer-0.9.2.data}/scripts/smart_train +0 -0
- {project_llm_trainer-0.9.1.dist-info → project_llm_trainer-0.9.2.dist-info}/WHEEL +0 -0
- {project_llm_trainer-0.9.1.dist-info → project_llm_trainer-0.9.2.dist-info}/top_level.txt +0 -0
llm_trainer/train_configs.py
CHANGED
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import Optional, Union, Callable, List, Mapping, Any
|
|
1
|
+
from typing import Optional, Union, Callable, List, Mapping, Any, Tuple
|
|
2
2
|
from dataclasses import dataclass, field
|
|
3
3
|
|
|
4
4
|
import torch
|
|
@@ -111,7 +111,8 @@ class OptimConfig:
|
|
|
111
111
|
optim_type: str = 'adam' # or 'lion'
|
|
112
112
|
enable_lr_scheduler: bool = False
|
|
113
113
|
initial_lr: float
|
|
114
|
-
weight_decay: float =
|
|
114
|
+
weight_decay: Optional[float] = None
|
|
115
|
+
betas: Optional[Tuple[float, float]] = None
|
|
115
116
|
warmup_iters: Optional[int] = None
|
|
116
117
|
max_lr: Optional[float] = None
|
|
117
118
|
min_lr: Optional[float] = None
|
llm_trainer/trainer.py
CHANGED
|
@@ -156,14 +156,15 @@ class Trainer:
|
|
|
156
156
|
|
|
157
157
|
model, optim = TrainerTools().parallel.process(
|
|
158
158
|
model=model,
|
|
159
|
-
optimizer=self.
|
|
159
|
+
optimizer=self._config_optim(model, initial_lr),
|
|
160
160
|
kwargs=self.parallel_kwargs
|
|
161
161
|
)
|
|
162
162
|
|
|
163
163
|
return model, optim
|
|
164
164
|
|
|
165
|
-
def
|
|
165
|
+
def _config_optim(self, model, initial_lr):
|
|
166
166
|
optimizer = None
|
|
167
|
+
use_lion_optim = self.train_config.optim_config.optim_type == 'lion'
|
|
167
168
|
|
|
168
169
|
if isinstance(TrainerTools().parallel, DsParallel) and self.parallel_kwargs:
|
|
169
170
|
import deepspeed
|
|
@@ -175,6 +176,7 @@ class Trainer:
|
|
|
175
176
|
optimizer = deepspeed.ops.lion.DeepSpeedCPULion
|
|
176
177
|
else:
|
|
177
178
|
optimizer = deepspeed.ops.adam.DeepSpeedCPUAdam
|
|
179
|
+
use_lion_optim = False
|
|
178
180
|
log('When set offload_optimizer, lion optim is unsupported, so set optim to adam!!!!!')
|
|
179
181
|
else:
|
|
180
182
|
optimizer = deepspeed.ops.adam.DeepSpeedCPUAdam
|
|
@@ -195,10 +197,26 @@ class Trainer:
|
|
|
195
197
|
else:
|
|
196
198
|
optimizer = torch.optim.AdamW
|
|
197
199
|
|
|
200
|
+
betas = self.train_config.optim_config.betas
|
|
201
|
+
weight_decay = self.train_config.optim_config.weight_decay
|
|
202
|
+
|
|
203
|
+
if betas is None:
|
|
204
|
+
if use_lion_optim:
|
|
205
|
+
betas = (0.95, 0.98)
|
|
206
|
+
else:
|
|
207
|
+
betas = (0.9, 0.999)
|
|
208
|
+
|
|
209
|
+
if weight_decay is None:
|
|
210
|
+
if use_lion_optim:
|
|
211
|
+
weight_decay = 0.015
|
|
212
|
+
else:
|
|
213
|
+
weight_decay = 0.01
|
|
214
|
+
|
|
198
215
|
return optimizer(
|
|
199
216
|
self._get_trainable_params(model),
|
|
200
217
|
lr=initial_lr,
|
|
201
|
-
|
|
218
|
+
betas=betas,
|
|
219
|
+
weight_decay=weight_decay
|
|
202
220
|
)
|
|
203
221
|
|
|
204
222
|
def _init_lr_scheduler(self, initial_lr: float) -> LRScheduler:
|
|
@@ -17,17 +17,17 @@ llm_trainer/scheduler.py,sha256=LAI_0VxClsIQkix0bRoduRD4vPfVuIZDhZgTAT_KK8k,4901
|
|
|
17
17
|
llm_trainer/sft_trainer.py,sha256=rSOGZx53jMgOuJdztfxQASYJ62uD0dVaih4IAnSwGBc,1787
|
|
18
18
|
llm_trainer/tokenizer.py,sha256=0-xQCMz1xiPTDAZiYsVsiECSoZ_1eIvW9XsZOoFfakQ,7250
|
|
19
19
|
llm_trainer/tools.py,sha256=5op5qrjjkK-Lr9oes5VxIVnOVYOYGoAdlIJq9mPUf64,2637
|
|
20
|
-
llm_trainer/train_configs.py,sha256=
|
|
21
|
-
llm_trainer/trainer.py,sha256=
|
|
20
|
+
llm_trainer/train_configs.py,sha256=afXUZ7M_Uoj0B3c2Nwf5xE-Lv7QAZZHTdW8LBw-QeWE,7704
|
|
21
|
+
llm_trainer/trainer.py,sha256=bVghqvQY4bvYAZFPgyh2ywX8WanqAC525Lkg8bNv4FQ,29721
|
|
22
22
|
llm_trainer/utils.py,sha256=xC5plG-8-_Al5yIF5xIU5lroOcBBk98TEhtUJrazZPE,12305
|
|
23
|
-
project_llm_trainer-0.9.
|
|
24
|
-
project_llm_trainer-0.9.
|
|
25
|
-
project_llm_trainer-0.9.
|
|
26
|
-
project_llm_trainer-0.9.
|
|
27
|
-
project_llm_trainer-0.9.
|
|
28
|
-
project_llm_trainer-0.9.
|
|
29
|
-
project_llm_trainer-0.9.
|
|
30
|
-
project_llm_trainer-0.9.
|
|
31
|
-
project_llm_trainer-0.9.
|
|
32
|
-
project_llm_trainer-0.9.
|
|
33
|
-
project_llm_trainer-0.9.
|
|
23
|
+
project_llm_trainer-0.9.2.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
|
|
24
|
+
project_llm_trainer-0.9.2.data/scripts/ddp_train,sha256=eZSud6KYQAoKLsYB5QB-FI2zq5AZm6Apq1azKdupV3o,477
|
|
25
|
+
project_llm_trainer-0.9.2.data/scripts/ds_train,sha256=41q4rOxwbvZDUY0FDdAIpG13PEaUWBpthhvFvww8uOc,388
|
|
26
|
+
project_llm_trainer-0.9.2.data/scripts/plot_loss,sha256=O9ooioAJ-79-X06LosgqF8XOqQe-beRxYm3LsLunmoU,908
|
|
27
|
+
project_llm_trainer-0.9.2.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
|
|
28
|
+
project_llm_trainer-0.9.2.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
|
|
29
|
+
project_llm_trainer-0.9.2.data/scripts/smart_train,sha256=N8dp2n7k6bghGczedBVwOdtf1O66oM_cNPh9QmZt0bM,914
|
|
30
|
+
project_llm_trainer-0.9.2.dist-info/METADATA,sha256=hoIO4KbvNU5xaZdzuNljSZcZSb_Iozl_Skp4miE3U6Y,195
|
|
31
|
+
project_llm_trainer-0.9.2.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
|
|
32
|
+
project_llm_trainer-0.9.2.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
|
|
33
|
+
project_llm_trainer-0.9.2.dist-info/RECORD,,
|
{project_llm_trainer-0.9.1.data → project_llm_trainer-0.9.2.data}/scripts/calc_intermediate_size
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|