project-llm-trainer 0.5.2__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of project-llm-trainer might be problematic. Click here for more details.
- llm_trainer/checkpoint.py +20 -29
- llm_trainer/dpo_trainer.py +1 -1
- llm_trainer/ds_checkpoint.py +2 -11
- llm_trainer/grpo_trainer.py +1 -1
- llm_trainer/trainer.py +1 -1
- {project_llm_trainer-0.5.2.dist-info → project_llm_trainer-0.5.4.dist-info}/METADATA +1 -1
- {project_llm_trainer-0.5.2.dist-info → project_llm_trainer-0.5.4.dist-info}/RECORD +16 -16
- {project_llm_trainer-0.5.2.data → project_llm_trainer-0.5.4.data}/scripts/calc_intermediate_size +0 -0
- {project_llm_trainer-0.5.2.data → project_llm_trainer-0.5.4.data}/scripts/ddp_train +0 -0
- {project_llm_trainer-0.5.2.data → project_llm_trainer-0.5.4.data}/scripts/ds_train +0 -0
- {project_llm_trainer-0.5.2.data → project_llm_trainer-0.5.4.data}/scripts/plot_loss +0 -0
- {project_llm_trainer-0.5.2.data → project_llm_trainer-0.5.4.data}/scripts/plot_lr +0 -0
- {project_llm_trainer-0.5.2.data → project_llm_trainer-0.5.4.data}/scripts/py_train +0 -0
- {project_llm_trainer-0.5.2.data → project_llm_trainer-0.5.4.data}/scripts/smart_train +0 -0
- {project_llm_trainer-0.5.2.dist-info → project_llm_trainer-0.5.4.dist-info}/WHEEL +0 -0
- {project_llm_trainer-0.5.2.dist-info → project_llm_trainer-0.5.4.dist-info}/top_level.txt +0 -0
llm_trainer/checkpoint.py
CHANGED
|
@@ -14,17 +14,14 @@ DEFAULT_CHECKPOINT_NAME = "checkpoint.pth"
|
|
|
14
14
|
|
|
15
15
|
def save_checkpoint(
|
|
16
16
|
model: nn.Module,
|
|
17
|
-
optimizer: Optional[Optimizer] = None
|
|
18
|
-
suffix: Optional[str] = None
|
|
17
|
+
optimizer: Optional[Optimizer] = None
|
|
19
18
|
):
|
|
20
19
|
if isinstance(TrainerTools().parallel, DsParallel):
|
|
21
20
|
from .ds_checkpoint import save_ds_checkpoint
|
|
22
|
-
save_ds_checkpoint(model
|
|
21
|
+
save_ds_checkpoint(model)
|
|
23
22
|
else:
|
|
24
23
|
if TrainerTools().parallel.is_main_process:
|
|
25
24
|
checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
|
|
26
|
-
if suffix:
|
|
27
|
-
checkpoint_name = f"{checkpoint_name}_{suffix}"
|
|
28
25
|
|
|
29
26
|
raw_model = model if not isinstance(model, DDP) else model.module
|
|
30
27
|
ckpt = {'model_state_dict': raw_model.state_dict()}
|
|
@@ -37,29 +34,27 @@ def save_checkpoint(
|
|
|
37
34
|
|
|
38
35
|
def save_best_checkpoint(
|
|
39
36
|
current_loss: float,
|
|
40
|
-
last_best_checkpoint_loss: float
|
|
41
|
-
suffix: Optional[str] = None
|
|
37
|
+
last_best_checkpoint_loss: Optional[float] = None
|
|
42
38
|
) -> bool:
|
|
43
|
-
need_replace = current_loss <= last_best_checkpoint_loss
|
|
39
|
+
need_replace = not last_best_checkpoint_loss or current_loss <= last_best_checkpoint_loss
|
|
44
40
|
if need_replace and TrainerTools().parallel.is_main_process:
|
|
45
41
|
if isinstance(TrainerTools().parallel, DsParallel):
|
|
46
|
-
|
|
47
|
-
if suffix:
|
|
48
|
-
checkpoint_name = f"{checkpoint_name}_{suffix}"
|
|
42
|
+
checkpoint_dir = os.environ.get('DIST_CHECKPOINT_DIR', 'checkpoint')
|
|
49
43
|
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
44
|
+
if checkpoint_dir.endswith('/'):
|
|
45
|
+
best_checkpoint_dir = f'{checkpoint_dir[:-1]}_best'
|
|
46
|
+
else:
|
|
47
|
+
best_checkpoint_dir = f'{checkpoint_dir}_best'
|
|
53
48
|
|
|
54
|
-
if os.path.exists(
|
|
55
|
-
|
|
56
|
-
|
|
49
|
+
if not os.path.exists(best_checkpoint_dir):
|
|
50
|
+
os.makedirs(best_checkpoint_dir)
|
|
51
|
+
|
|
52
|
+
if os.path.exists(checkpoint_dir):
|
|
53
|
+
shutil.rmtree(best_checkpoint_dir)
|
|
54
|
+
shutil.copytree(checkpoint_dir, best_checkpoint_dir)
|
|
57
55
|
else:
|
|
58
56
|
checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
|
|
59
|
-
|
|
60
|
-
checkpoint_name = f"{checkpoint_name}_{suffix}"
|
|
61
|
-
|
|
62
|
-
best_checkpoint_name = f'best_{checkpoint_name}'
|
|
57
|
+
best_checkpoint_name = f'{checkpoint_name}_best'
|
|
63
58
|
|
|
64
59
|
if os.path.exists(checkpoint_name):
|
|
65
60
|
if os.path.exists(best_checkpoint_name):
|
|
@@ -75,16 +70,13 @@ def load_checkpoint(
|
|
|
75
70
|
model: nn.Module,
|
|
76
71
|
optimizer: Optional[Optimizer] = None,
|
|
77
72
|
device: Optional[Union[torch.device, str]] = None,
|
|
78
|
-
load_module_only: bool = False
|
|
79
|
-
suffix: Optional[str] = None
|
|
73
|
+
load_module_only: bool = False
|
|
80
74
|
):
|
|
81
75
|
if isinstance(TrainerTools().parallel, DsParallel):
|
|
82
76
|
from .ds_checkpoint import load_ds_checkpoint
|
|
83
|
-
load_ds_checkpoint(model, load_module_only=load_module_only
|
|
77
|
+
load_ds_checkpoint(model, load_module_only=load_module_only)
|
|
84
78
|
else:
|
|
85
79
|
checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
|
|
86
|
-
if suffix:
|
|
87
|
-
checkpoint_name = f"{checkpoint_name}_{suffix}"
|
|
88
80
|
|
|
89
81
|
state_dict = torch.load(checkpoint_name, weights_only=True, map_location=device)
|
|
90
82
|
raw_model = model.module if isinstance(model, DDP) else model
|
|
@@ -96,14 +88,13 @@ def load_checkpoint(
|
|
|
96
88
|
|
|
97
89
|
def load_checkpoint_for_eval(
|
|
98
90
|
model: nn.Module,
|
|
99
|
-
device: Optional[Union[torch.device, str]] = None
|
|
100
|
-
suffix: Optional[str] = None
|
|
91
|
+
device: Optional[Union[torch.device, str]] = None
|
|
101
92
|
):
|
|
102
93
|
if isinstance(TrainerTools().parallel, DsParallel):
|
|
103
94
|
from .ds_checkpoint import load_ds_checkpoint_for_eval
|
|
104
95
|
load_ds_checkpoint_for_eval(model)
|
|
105
96
|
else:
|
|
106
|
-
load_checkpoint(model, None, device
|
|
97
|
+
load_checkpoint(model, None, device)
|
|
107
98
|
|
|
108
99
|
|
|
109
100
|
def save_steps(global_steps: int, lr_scheduler: Optional[LRScheduler] = None):
|
llm_trainer/dpo_trainer.py
CHANGED
|
@@ -141,7 +141,7 @@ class DPOTrainer(Trainer):
|
|
|
141
141
|
skipping_train = False
|
|
142
142
|
|
|
143
143
|
current_loss: float = 0.0
|
|
144
|
-
last_best_checkpoint_loss: float =
|
|
144
|
+
last_best_checkpoint_loss: Optional[float] = None
|
|
145
145
|
|
|
146
146
|
aux_loss_coef = self.train_config.loss_config.aux_loss_coef
|
|
147
147
|
|
llm_trainer/ds_checkpoint.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
|
1
1
|
import os
|
|
2
|
-
from typing import Optional
|
|
3
2
|
from glob import glob
|
|
4
3
|
import shutil
|
|
5
4
|
from torch import nn
|
|
@@ -17,14 +16,9 @@ load_state_dict_from_zero_checkpoint 从 ZeRO 检查点加载模型和优化器
|
|
|
17
16
|
convert_zero_checkpoint_to_fp32_state_dict 将 ZeRO 检查点转换为独立的 FP32 状态字典文件 否 是 创建可移植的 FP32 权重文件,用于部署、分享等
|
|
18
17
|
"""
|
|
19
18
|
|
|
20
|
-
def save_ds_checkpoint(
|
|
21
|
-
model: nn.Module,
|
|
22
|
-
suffix: Optional[str] = None
|
|
23
|
-
):
|
|
19
|
+
def save_ds_checkpoint(model: nn.Module):
|
|
24
20
|
assert isinstance(model, DeepSpeedEngine)
|
|
25
21
|
ckpt_dir = os.environ.get('DIST_CHECKPOINT_DIR', 'checkpoint')
|
|
26
|
-
if suffix:
|
|
27
|
-
ckpt_dir = f"{ckpt_dir}_{suffix}"
|
|
28
22
|
|
|
29
23
|
try:
|
|
30
24
|
# 包括model、optimizer等状态
|
|
@@ -44,13 +38,10 @@ def save_ds_checkpoint(
|
|
|
44
38
|
|
|
45
39
|
def load_ds_checkpoint(
|
|
46
40
|
model: nn.Module,
|
|
47
|
-
load_module_only: bool = False
|
|
48
|
-
suffix: Optional[str] = None
|
|
41
|
+
load_module_only: bool = False
|
|
49
42
|
):
|
|
50
43
|
assert isinstance(model, DeepSpeedEngine)
|
|
51
44
|
ckpt_dir = os.environ.get('DIST_CHECKPOINT_DIR', 'checkpoint')
|
|
52
|
-
if suffix:
|
|
53
|
-
ckpt_dir = f"{ckpt_dir}_{suffix}"
|
|
54
45
|
|
|
55
46
|
# 包括model、optimizer等状态
|
|
56
47
|
if os.path.exists(ckpt_dir):
|
llm_trainer/grpo_trainer.py
CHANGED
|
@@ -283,7 +283,7 @@ class GRPOTrainer(Trainer):
|
|
|
283
283
|
skipping_train = False
|
|
284
284
|
|
|
285
285
|
current_loss: float = 0.0
|
|
286
|
-
last_best_checkpoint_loss: float =
|
|
286
|
+
last_best_checkpoint_loss: Optional[float] = None
|
|
287
287
|
|
|
288
288
|
aux_loss_coef = self.train_config.loss_config.aux_loss_coef
|
|
289
289
|
|
llm_trainer/trainer.py
CHANGED
|
@@ -469,7 +469,7 @@ class Trainer:
|
|
|
469
469
|
skipping_train = False
|
|
470
470
|
|
|
471
471
|
current_loss: float = 0.0
|
|
472
|
-
last_best_checkpoint_loss: float =
|
|
472
|
+
last_best_checkpoint_loss: Optional[float] = None
|
|
473
473
|
|
|
474
474
|
for epoch in range(self.train_config.n_epochs):
|
|
475
475
|
self.train_model.train()
|
|
@@ -1,11 +1,11 @@
|
|
|
1
1
|
llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
|
|
2
|
-
llm_trainer/checkpoint.py,sha256=
|
|
2
|
+
llm_trainer/checkpoint.py,sha256=UVjOaDsiSIzRJ5VJZib6iXrdKv2A7K_gtJw3a9wNyoM,4293
|
|
3
3
|
llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
|
|
4
|
-
llm_trainer/dpo_trainer.py,sha256=
|
|
5
|
-
llm_trainer/ds_checkpoint.py,sha256=
|
|
4
|
+
llm_trainer/dpo_trainer.py,sha256=3hCjX06W6nt-8lio0YyGzXdI_saa5QlypXehQTtacO4,11871
|
|
5
|
+
llm_trainer/ds_checkpoint.py,sha256=D092fkS1Up4QmpV9YCpqbSzfX_caCAeX-UiOrhOE1I8,1947
|
|
6
6
|
llm_trainer/eval.py,sha256=fjASCILU3fSPJxo9cP3rIXEEnkc5ZlUyHqXlZtUiHrw,888
|
|
7
7
|
llm_trainer/generate_utils.py,sha256=CbJ3mfAD6DkQ0GUHcJQ1AK02m-ocwmd-BPXEpiwvNNQ,14933
|
|
8
|
-
llm_trainer/grpo_trainer.py,sha256=
|
|
8
|
+
llm_trainer/grpo_trainer.py,sha256=VltU9AngecJ5PEtpn_ToXfiLIc6VwwgOVwlLh-X2Je8,16001
|
|
9
9
|
llm_trainer/log.py,sha256=LxqTGRNZUGMTSQCePRpk-rYyxSnSIbT4kOdP8Fbzr0M,462
|
|
10
10
|
llm_trainer/loss.py,sha256=NZCQeUXnLSj__mmDflE8g89KgE0emAJXIab0IERCLno,6023
|
|
11
11
|
llm_trainer/parallel.py,sha256=DQu8GqEFxD99HQ6hKuIxxyKi-05dMO33eMhImYlPuOI,4468
|
|
@@ -18,16 +18,16 @@ llm_trainer/sft_trainer.py,sha256=gxQA7T1o1QGUsHp2CX1Qb_fO5LppBJuNbc0H4ixCYUA,17
|
|
|
18
18
|
llm_trainer/tokenizer.py,sha256=SSpgXtb0e1NtQqRW0gCq09TTZi47umggy-Fh5EMHKJg,6708
|
|
19
19
|
llm_trainer/tools.py,sha256=yF17lp6oOfLe2XJeKDQ1juZcbv-6vFamJSLwEeArduA,2975
|
|
20
20
|
llm_trainer/train_configs.py,sha256=c6bgivkkWRYcPD3NzI5uRItAUhZiIBgKVMuMgVFRnFo,7336
|
|
21
|
-
llm_trainer/trainer.py,sha256=
|
|
21
|
+
llm_trainer/trainer.py,sha256=VHnuL8rZgxj2ewBVxmbN0jwVEPRVhbK2V2DYCtm_FxI,25886
|
|
22
22
|
llm_trainer/utils.py,sha256=LWNhyQ0NDEZ9mZtk2Ryvh6EulvHIaUGIflugSpqmeFI,6791
|
|
23
|
-
project_llm_trainer-0.5.
|
|
24
|
-
project_llm_trainer-0.5.
|
|
25
|
-
project_llm_trainer-0.5.
|
|
26
|
-
project_llm_trainer-0.5.
|
|
27
|
-
project_llm_trainer-0.5.
|
|
28
|
-
project_llm_trainer-0.5.
|
|
29
|
-
project_llm_trainer-0.5.
|
|
30
|
-
project_llm_trainer-0.5.
|
|
31
|
-
project_llm_trainer-0.5.
|
|
32
|
-
project_llm_trainer-0.5.
|
|
33
|
-
project_llm_trainer-0.5.
|
|
23
|
+
project_llm_trainer-0.5.4.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
|
|
24
|
+
project_llm_trainer-0.5.4.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
|
|
25
|
+
project_llm_trainer-0.5.4.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
|
|
26
|
+
project_llm_trainer-0.5.4.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
|
|
27
|
+
project_llm_trainer-0.5.4.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
|
|
28
|
+
project_llm_trainer-0.5.4.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
|
|
29
|
+
project_llm_trainer-0.5.4.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
|
|
30
|
+
project_llm_trainer-0.5.4.dist-info/METADATA,sha256=lZWPvJQFiqZTl9b1FUSFl1Fl6berO7DKGZ-A_7ZkidE,195
|
|
31
|
+
project_llm_trainer-0.5.4.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
|
|
32
|
+
project_llm_trainer-0.5.4.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
|
|
33
|
+
project_llm_trainer-0.5.4.dist-info/RECORD,,
|
{project_llm_trainer-0.5.2.data → project_llm_trainer-0.5.4.data}/scripts/calc_intermediate_size
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|