project-llm-trainer 0.5.3__py3-none-any.whl → 0.5.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of project-llm-trainer might be problematic.
- llm_trainer/checkpoint.py +19 -28
- llm_trainer/dpo_trainer.py +8 -5
- llm_trainer/ds_checkpoint.py +2 -11
- llm_trainer/grpo_trainer.py +4 -2
- llm_trainer/trainer.py +6 -3
- {project_llm_trainer-0.5.3.dist-info → project_llm_trainer-0.5.5.dist-info}/METADATA +1 -1
- {project_llm_trainer-0.5.3.dist-info → project_llm_trainer-0.5.5.dist-info}/RECORD +16 -16
- {project_llm_trainer-0.5.3.data → project_llm_trainer-0.5.5.data}/scripts/calc_intermediate_size +0 -0
- {project_llm_trainer-0.5.3.data → project_llm_trainer-0.5.5.data}/scripts/ddp_train +0 -0
- {project_llm_trainer-0.5.3.data → project_llm_trainer-0.5.5.data}/scripts/ds_train +0 -0
- {project_llm_trainer-0.5.3.data → project_llm_trainer-0.5.5.data}/scripts/plot_loss +0 -0
- {project_llm_trainer-0.5.3.data → project_llm_trainer-0.5.5.data}/scripts/plot_lr +0 -0
- {project_llm_trainer-0.5.3.data → project_llm_trainer-0.5.5.data}/scripts/py_train +0 -0
- {project_llm_trainer-0.5.3.data → project_llm_trainer-0.5.5.data}/scripts/smart_train +0 -0
- {project_llm_trainer-0.5.3.dist-info → project_llm_trainer-0.5.5.dist-info}/WHEEL +0 -0
- {project_llm_trainer-0.5.3.dist-info → project_llm_trainer-0.5.5.dist-info}/top_level.txt +0 -0
llm_trainer/checkpoint.py
CHANGED
@@ -14,17 +14,14 @@ DEFAULT_CHECKPOINT_NAME = "checkpoint.pth"
 
 def save_checkpoint(
     model: nn.Module,
-    optimizer: Optional[Optimizer] = None,
-    suffix: Optional[str] = None
+    optimizer: Optional[Optimizer] = None
 ):
     if isinstance(TrainerTools().parallel, DsParallel):
         from .ds_checkpoint import save_ds_checkpoint
-        save_ds_checkpoint(model, suffix=suffix)
+        save_ds_checkpoint(model)
     else:
         if TrainerTools().parallel.is_main_process:
             checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
-            if suffix:
-                checkpoint_name = f"{checkpoint_name}_{suffix}"
 
             raw_model = model if not isinstance(model, DDP) else model.module
             ckpt = {'model_state_dict': raw_model.state_dict()}
@@ -37,28 +34,26 @@ def save_checkpoint(
 
 def save_best_checkpoint(
     current_loss: float,
-    last_best_checkpoint_loss: float,
-    suffix: Optional[str] = None
+    last_best_checkpoint_loss: Optional[float] = None
 ) -> bool:
-    need_replace = current_loss <= last_best_checkpoint_loss
+    need_replace = not last_best_checkpoint_loss or current_loss <= last_best_checkpoint_loss
     if need_replace and TrainerTools().parallel.is_main_process:
         if isinstance(TrainerTools().parallel, DsParallel):
-
-            if suffix:
-                checkpoint_name = f"{checkpoint_name}_{suffix}"
+            checkpoint_dir = os.environ.get('DIST_CHECKPOINT_DIR', 'checkpoint')
 
-
-
-
+            if checkpoint_dir.endswith('/'):
+                best_checkpoint_dir = f'{checkpoint_dir[:-1]}_best'
+            else:
+                best_checkpoint_dir = f'{checkpoint_dir}_best'
 
-            if os.path.exists(
-
-
+            if not os.path.exists(best_checkpoint_dir):
+                os.makedirs(best_checkpoint_dir)
+
+            if os.path.exists(checkpoint_dir):
+                shutil.rmtree(best_checkpoint_dir)
+                shutil.copytree(checkpoint_dir, best_checkpoint_dir)
         else:
             checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
-            if suffix:
-                checkpoint_name = f"{checkpoint_name}_{suffix}"
-
             best_checkpoint_name = f'{checkpoint_name}_best'
 
             if os.path.exists(checkpoint_name):
@@ -75,16 +70,13 @@ def load_checkpoint(
     model: nn.Module,
     optimizer: Optional[Optimizer] = None,
     device: Optional[Union[torch.device, str]] = None,
-    load_module_only: bool = False,
-    suffix: Optional[str] = None
+    load_module_only: bool = False
 ):
     if isinstance(TrainerTools().parallel, DsParallel):
         from .ds_checkpoint import load_ds_checkpoint
-        load_ds_checkpoint(model, load_module_only=load_module_only, suffix=suffix)
+        load_ds_checkpoint(model, load_module_only=load_module_only)
     else:
         checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
-        if suffix:
-            checkpoint_name = f"{checkpoint_name}_{suffix}"
 
         state_dict = torch.load(checkpoint_name, weights_only=True, map_location=device)
         raw_model = model.module if isinstance(model, DDP) else model
@@ -96,14 +88,13 @@ def load_checkpoint(
 
 def load_checkpoint_for_eval(
     model: nn.Module,
-    device: Optional[Union[torch.device, str]] = None,
-    suffix: Optional[str] = None
+    device: Optional[Union[torch.device, str]] = None
 ):
     if isinstance(TrainerTools().parallel, DsParallel):
         from .ds_checkpoint import load_ds_checkpoint_for_eval
         load_ds_checkpoint_for_eval(model)
     else:
-        load_checkpoint(model, None, device, suffix=suffix)
+        load_checkpoint(model, None, device)
 
 
 def save_steps(global_steps: int, lr_scheduler: Optional[LRScheduler] = None):
llm_trainer/dpo_trainer.py
CHANGED
@@ -137,11 +137,12 @@ class DPOTrainer(Trainer):
         # gradient accumulation steps
         gradient_accumulation_steps = self.train_config.gradient_accumulation_steps
         global_steps = 0
-        loss_accumulation = 0.0
         skipping_train = False
 
+        loss_accumulation = 0.0
+        batches_accumulated = 0
         current_loss: float = 0.0
-        last_best_checkpoint_loss: float =
+        last_best_checkpoint_loss: Optional[float] = None
 
         aux_loss_coef = self.train_config.loss_config.aux_loss_coef
 
@@ -214,14 +215,15 @@ class DPOTrainer(Trainer):
 
                     loss_accumulation += loss.detach().item()
                     self._backward_loss(loss)
+                    batches_accumulated += 1
 
                     if need_update_grad:
-                        loss_tensor = torch.tensor(loss_accumulation, device=TrainerTools().parallel.device)
+                        loss_tensor = torch.tensor(loss_accumulation * gradient_accumulation_steps / batches_accumulated, device=TrainerTools().parallel.device)
 
                         if TrainerTools().parallel.parallel_train:
                             dist.all_reduce(loss_tensor, dist.ReduceOp.AVG)
 
-
+                        current_loss = loss_tensor.item()
 
                         # ds mode already integrates gradient_clipping
                         if not isinstance(TrainerTools().parallel, DsParallel) and self.lr_scheduler.can_clip_grad():
@@ -235,10 +237,11 @@ class DPOTrainer(Trainer):
                             epoch_tag=f'epoch: {epoch}',
                             file_tag=f'file: {file_idx + 1}/{file_count}',
                             batch_tag=f'batch: {batch}/{batch_count_per_file}',
-                            loss=
+                            loss=current_loss
                         )
                         # reset to default
                         loss_accumulation = 0.0
+                        batches_accumulated = 0
                 except Exception as e:
                     self._on_exception(e, epoch, batch)
                 finally:
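The new batches_accumulated counter fixes the reported loss when an optimizer step fires on a short accumulation window (for example at a file boundary). Assuming each micro-batch loss is scaled by 1/gradient_accumulation_steps before backward, which the correction factor implies, a quick arithmetic check with made-up numbers:

# Illustrative numbers only: an update fires after 3 of 4 planned micro-batches.
gradient_accumulation_steps = 4
raw_losses = [2.0, 2.4, 2.2]

loss_accumulation = sum(l / gradient_accumulation_steps for l in raw_losses)
batches_accumulated = len(raw_losses)

reported = loss_accumulation * gradient_accumulation_steps / batches_accumulated
mean_loss = sum(raw_losses) / len(raw_losses)
assert abs(reported - mean_loss) < 1e-9  # the scaling recovers the mean micro-batch loss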
llm_trainer/ds_checkpoint.py
CHANGED
@@ -1,5 +1,4 @@
 import os
-from typing import Optional
 from glob import glob
 import shutil
 from torch import nn
@@ -17,14 +16,9 @@ load_state_dict_from_zero_checkpoint loads model and optimizer state from a ZeRO checkpoint
 convert_zero_checkpoint_to_fp32_state_dict  converts a ZeRO checkpoint into a standalone FP32 state-dict file  No  Yes  creates portable FP32 weight files for deployment, sharing, etc.
 """
 
-def save_ds_checkpoint(
-    model: nn.Module,
-    suffix: Optional[str] = None
-):
+def save_ds_checkpoint(model: nn.Module):
     assert isinstance(model, DeepSpeedEngine)
     ckpt_dir = os.environ.get('DIST_CHECKPOINT_DIR', 'checkpoint')
-    if suffix:
-        ckpt_dir = f"{ckpt_dir}_{suffix}"
 
     try:
         # includes model, optimizer and other state
@@ -44,13 +38,10 @@ def save_ds_checkpoint(
 
 def load_ds_checkpoint(
     model: nn.Module,
-    load_module_only: bool = False,
-    suffix: Optional[str] = None
+    load_module_only: bool = False
 ):
     assert isinstance(model, DeepSpeedEngine)
     ckpt_dir = os.environ.get('DIST_CHECKPOINT_DIR', 'checkpoint')
-    if suffix:
-        ckpt_dir = f"{ckpt_dir}_{suffix}"
 
     # includes model, optimizer and other state
     if os.path.exists(ckpt_dir):
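With the suffix parameter gone, save and load are guaranteed to resolve the same directory from DIST_CHECKPOINT_DIR. A hedged sketch of the round trip using the DeepSpeedEngine checkpoint API (engine stands for an already-initialized engine; the function name is illustrative):

import os

def ds_checkpoint_roundtrip(engine, load_module_only: bool = False):
    # Save and load now share one fixed directory; no suffix variants exist.
    ckpt_dir = os.environ.get('DIST_CHECKPOINT_DIR', 'checkpoint')

    # Persists module, optimizer and scheduler state under ckpt_dir.
    engine.save_checkpoint(ckpt_dir)

    # load_module_only=True restores only module weights and skips
    # optimizer/scheduler state, matching the load_ds_checkpoint flag.
    engine.load_checkpoint(ckpt_dir, load_module_only=load_module_only)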
llm_trainer/grpo_trainer.py
CHANGED
@@ -283,7 +283,7 @@ class GRPOTrainer(Trainer):
         skipping_train = False
 
         current_loss: float = 0.0
-        last_best_checkpoint_loss: float =
+        last_best_checkpoint_loss: Optional[float] = None
 
         aux_loss_coef = self.train_config.loss_config.aux_loss_coef
 
@@ -345,6 +345,8 @@ class GRPOTrainer(Trainer):
                     if TrainerTools().parallel.parallel_train:
                         dist.all_reduce(loss, dist.ReduceOp.AVG)
 
+                    current_loss = loss.detach().item()
+
                     # ds mode already integrates gradient_clipping
                     if not isinstance(TrainerTools().parallel, DsParallel) and self.lr_scheduler.can_clip_grad():
                         # clip grad
@@ -357,7 +359,7 @@ class GRPOTrainer(Trainer):
                         epoch_tag=f'epoch: {epoch}',
                         file_tag=f'file: {file_idx + 1}/{file_count}',
                         batch_tag=f'batch: {batch}/{batch_count_per_file}',
-                        loss=
+                        loss=current_loss
                     )
                 except Exception as e:
                     self._on_exception(e, epoch, batch)
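GRPOTrainer now materializes current_loss right after the cross-rank average, so every process logs the same number. A minimal sketch of that reduction (assumes an initialized process group; ReduceOp.AVG needs a backend that supports it, such as NCCL):

import torch
import torch.distributed as dist

def loss_for_logging(loss: torch.Tensor, parallel_train: bool) -> float:
    # Average in place across ranks, then convert to a plain float once.
    if parallel_train:
        dist.all_reduce(loss, dist.ReduceOp.AVG)
    return loss.detach().item()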
llm_trainer/trainer.py
CHANGED
@@ -465,11 +465,12 @@ class Trainer:
         # gradient accumulation steps
         gradient_accumulation_steps = self.train_config.gradient_accumulation_steps
         global_steps = 0
-        loss_accumulation = 0.0
         skipping_train = False
 
+        loss_accumulation = 0.0
+        batches_accumulated = 0
         current_loss: float = 0.0
-        last_best_checkpoint_loss: float =
+        last_best_checkpoint_loss: Optional[float] = None
 
         for epoch in range(self.train_config.n_epochs):
             self.train_model.train()
@@ -536,9 +537,10 @@ class Trainer:
 
                     loss_accumulation += loss.detach().item()
                     self._backward_loss(loss)
+                    batches_accumulated += 1
 
                     if need_update_grad:
-                        loss_tensor = torch.tensor(loss_accumulation, device=TrainerTools().parallel.device)
+                        loss_tensor = torch.tensor(loss_accumulation * gradient_accumulation_steps / batches_accumulated, device=TrainerTools().parallel.device)
 
                         if TrainerTools().parallel.parallel_train:
                             dist.all_reduce(loss_tensor, dist.ReduceOp.AVG)
@@ -561,6 +563,7 @@ class Trainer:
                         )
                         # reset to default
                         loss_accumulation = 0.0
+                        batches_accumulated = 0
                 except Exception as e:
                     self._on_exception(e, epoch, batch)
                 finally:
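Trainer and DPOTrainer share the same accumulation-window bookkeeping: count micro-batches, rescale at the update, reset both counters. A self-contained toy loop showing the window mechanics (loss values and window size are made up):

GRAD_ACCUM_STEPS = 4
raw_losses = [2.0, 2.4, 2.2, 2.1, 1.9]  # one full window plus a short tail

loss_accumulation = 0.0
batches_accumulated = 0

for step, raw in enumerate(raw_losses, start=1):
    loss_accumulation += raw / GRAD_ACCUM_STEPS  # scaled like the training loop
    batches_accumulated += 1

    if step % GRAD_ACCUM_STEPS == 0 or step == len(raw_losses):
        # Rescale so short tail windows still report a mean micro-batch loss.
        current_loss = loss_accumulation * GRAD_ACCUM_STEPS / batches_accumulated
        print(f'step {step}: loss={current_loss:.3f}')
        loss_accumulation = 0.0   # reset to default
        batches_accumulated = 0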
{project_llm_trainer-0.5.3.dist-info → project_llm_trainer-0.5.5.dist-info}/RECORD
CHANGED
@@ -1,11 +1,11 @@
 llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
-llm_trainer/checkpoint.py,sha256=
+llm_trainer/checkpoint.py,sha256=UVjOaDsiSIzRJ5VJZib6iXrdKv2A7K_gtJw3a9wNyoM,4293
 llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
-llm_trainer/dpo_trainer.py,sha256=
-llm_trainer/ds_checkpoint.py,sha256=
+llm_trainer/dpo_trainer.py,sha256=1A_4QP2_xqM_YeqdXy-0RaMvEL80gim-pgnPQyHww9U,12052
+llm_trainer/ds_checkpoint.py,sha256=D092fkS1Up4QmpV9YCpqbSzfX_caCAeX-UiOrhOE1I8,1947
 llm_trainer/eval.py,sha256=fjASCILU3fSPJxo9cP3rIXEEnkc5ZlUyHqXlZtUiHrw,888
 llm_trainer/generate_utils.py,sha256=CbJ3mfAD6DkQ0GUHcJQ1AK02m-ocwmd-BPXEpiwvNNQ,14933
-llm_trainer/grpo_trainer.py,sha256=
+llm_trainer/grpo_trainer.py,sha256=sCYjvksdm9f7TpN23KXuCmua_8VFTZEfVEcflL89P_I,16058
 llm_trainer/log.py,sha256=LxqTGRNZUGMTSQCePRpk-rYyxSnSIbT4kOdP8Fbzr0M,462
 llm_trainer/loss.py,sha256=NZCQeUXnLSj__mmDflE8g89KgE0emAJXIab0IERCLno,6023
 llm_trainer/parallel.py,sha256=DQu8GqEFxD99HQ6hKuIxxyKi-05dMO33eMhImYlPuOI,4468
@@ -18,16 +18,16 @@ llm_trainer/sft_trainer.py,sha256=gxQA7T1o1QGUsHp2CX1Qb_fO5LppBJuNbc0H4ixCYUA,17
 llm_trainer/tokenizer.py,sha256=SSpgXtb0e1NtQqRW0gCq09TTZi47umggy-Fh5EMHKJg,6708
 llm_trainer/tools.py,sha256=yF17lp6oOfLe2XJeKDQ1juZcbv-6vFamJSLwEeArduA,2975
 llm_trainer/train_configs.py,sha256=c6bgivkkWRYcPD3NzI5uRItAUhZiIBgKVMuMgVFRnFo,7336
-llm_trainer/trainer.py,sha256=
+llm_trainer/trainer.py,sha256=YW59dJWTyQy77cLDGzBHhfinGyfkvmWCkl1SR9hM6a8,26071
 llm_trainer/utils.py,sha256=LWNhyQ0NDEZ9mZtk2Ryvh6EulvHIaUGIflugSpqmeFI,6791
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
+project_llm_trainer-0.5.5.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
+project_llm_trainer-0.5.5.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
+project_llm_trainer-0.5.5.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
+project_llm_trainer-0.5.5.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
+project_llm_trainer-0.5.5.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
+project_llm_trainer-0.5.5.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
+project_llm_trainer-0.5.5.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
+project_llm_trainer-0.5.5.dist-info/METADATA,sha256=ajxfapuo4Q2xfdJ3kjZoCzs7Q5ynGp6BssXRFOIbF7Y,195
+project_llm_trainer-0.5.5.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
+project_llm_trainer-0.5.5.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
+project_llm_trainer-0.5.5.dist-info/RECORD,,
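The RECORD entries follow the wheel convention (PEP 376/427): each hash is a SHA-256 digest, urlsafe-base64-encoded with the trailing '=' padding stripped, followed by the file size in bytes. A small sketch for reproducing an entry locally:

import base64
import hashlib
import os

def record_entry(path: str) -> str:
    # RECORD lines are "path,sha256=<urlsafe b64 digest, no padding>,<size>".
    with open(path, 'rb') as f:
        digest = hashlib.sha256(f.read()).digest()
    b64 = base64.urlsafe_b64encode(digest).rstrip(b'=').decode('ascii')
    return f'{path},sha256={b64},{os.path.getsize(path)}'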
{project_llm_trainer-0.5.3.data → project_llm_trainer-0.5.5.data}/scripts/calc_intermediate_size
RENAMED
File without changes
{project_llm_trainer-0.5.3.data → project_llm_trainer-0.5.5.data}/scripts/ddp_train
RENAMED
File without changes
{project_llm_trainer-0.5.3.data → project_llm_trainer-0.5.5.data}/scripts/ds_train
RENAMED
File without changes
{project_llm_trainer-0.5.3.data → project_llm_trainer-0.5.5.data}/scripts/plot_loss
RENAMED
File without changes
{project_llm_trainer-0.5.3.data → project_llm_trainer-0.5.5.data}/scripts/plot_lr
RENAMED
File without changes
{project_llm_trainer-0.5.3.data → project_llm_trainer-0.5.5.data}/scripts/py_train
RENAMED
File without changes
{project_llm_trainer-0.5.3.data → project_llm_trainer-0.5.5.data}/scripts/smart_train
RENAMED
File without changes
{project_llm_trainer-0.5.3.dist-info → project_llm_trainer-0.5.5.dist-info}/WHEEL
RENAMED
File without changes
{project_llm_trainer-0.5.3.dist-info → project_llm_trainer-0.5.5.dist-info}/top_level.txt
RENAMED
File without changes