project-llm-trainer 0.4__py3-none-any.whl → 0.4.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of project-llm-trainer might be problematic; see the package's registry page for more details.
- llm_trainer/checkpoint.py +6 -2
- llm_trainer/ds_checkpoint.py +26 -24
- llm_trainer/eval.py +0 -9
- llm_trainer/fsdp_checkpoint.py +13 -17
- llm_trainer/trainer.py +5 -2
- {project_llm_trainer-0.4.dist-info → project_llm_trainer-0.4.1.dist-info}/METADATA +1 -1
- {project_llm_trainer-0.4.dist-info → project_llm_trainer-0.4.1.dist-info}/RECORD +16 -16
- {project_llm_trainer-0.4.data → project_llm_trainer-0.4.1.data}/scripts/calc_intermediate_size +0 -0
- {project_llm_trainer-0.4.data → project_llm_trainer-0.4.1.data}/scripts/ddp_train +0 -0
- {project_llm_trainer-0.4.data → project_llm_trainer-0.4.1.data}/scripts/ds_train +0 -0
- {project_llm_trainer-0.4.data → project_llm_trainer-0.4.1.data}/scripts/plot_loss +0 -0
- {project_llm_trainer-0.4.data → project_llm_trainer-0.4.1.data}/scripts/plot_lr +0 -0
- {project_llm_trainer-0.4.data → project_llm_trainer-0.4.1.data}/scripts/py_train +0 -0
- {project_llm_trainer-0.4.data → project_llm_trainer-0.4.1.data}/scripts/smart_train +0 -0
- {project_llm_trainer-0.4.dist-info → project_llm_trainer-0.4.1.dist-info}/WHEEL +0 -0
- {project_llm_trainer-0.4.dist-info → project_llm_trainer-0.4.1.dist-info}/top_level.txt +0 -0
llm_trainer/checkpoint.py
CHANGED
|
@@ -121,8 +121,12 @@ def load_checkpoint_for_eval(
|
|
|
121
121
|
|
|
122
122
|
def copy_model_params(
|
|
123
123
|
_from: nn.Module,
|
|
124
|
-
_to: nn.Module
|
|
124
|
+
_to: Optional[nn.Module]
|
|
125
125
|
):
|
|
126
|
+
"""
|
|
127
|
+
必须在所有rank上调用,非rank0, _to可以设置为None
|
|
128
|
+
"""
|
|
129
|
+
|
|
126
130
|
if isinstance(TrainerTools().parallel, DsParallel):
|
|
127
131
|
from .ds_checkpoint import get_ds_model_params
|
|
128
132
|
state_dict = get_ds_model_params(_from)
|
|
@@ -134,7 +138,7 @@ def copy_model_params(
|
|
|
134
138
|
else:
|
|
135
139
|
state_dict = _from.state_dict()
|
|
136
140
|
|
|
137
|
-
if state_dict:
|
|
141
|
+
if _to and state_dict:
|
|
138
142
|
_to.load_state_dict(state_dict)
|
|
139
143
|
|
|
140
144
|
|
llm_trainer/ds_checkpoint.py
CHANGED
|
@@ -67,45 +67,47 @@ def load_ds_checkpoint_for_eval(model: nn.Module):
|
|
|
67
67
|
model.load_state_dict(state_dict)
|
|
68
68
|
|
|
69
69
|
|
|
70
|
+
def _get_ds_full_state_dict_on_rank0(model: DeepSpeedEngine) -> Optional[dict]:
|
|
71
|
+
"""
|
|
72
|
+
可以在任意rank上调用,然后只有rank0有值
|
|
73
|
+
"""
|
|
74
|
+
|
|
75
|
+
if model.zero_optimization_stage() != 3:
|
|
76
|
+
if TrainerTools().parallel.is_main_process:
|
|
77
|
+
return {k: v.cpu().clone() for k, v in model.module.state_dict().items()}
|
|
78
|
+
return None
|
|
79
|
+
|
|
80
|
+
# ZeRO-3
|
|
81
|
+
state_dict_on_rank_0 = {}
|
|
82
|
+
for param_name, param in model.module.named_parameters():
|
|
83
|
+
if hasattr(param, 'ds_id'):
|
|
84
|
+
with deepspeed.zero.GatheredParameters(param, modifier_rank=0):
|
|
85
|
+
if TrainerTools().parallel.is_main_process:
|
|
86
|
+
state_dict_on_rank_0[param_name] = param.data.to(torch.float32).cpu().clone()
|
|
87
|
+
else:
|
|
88
|
+
if TrainerTools().parallel.is_main_process:
|
|
89
|
+
state_dict_on_rank_0[param_name] = param.data.to(torch.float32).cpu().clone()
|
|
90
|
+
|
|
91
|
+
return state_dict_on_rank_0 if TrainerTools().parallel.is_main_process else None
|
|
92
|
+
|
|
93
|
+
|
|
70
94
|
def get_ds_model_params(model: nn.Module):
|
|
71
95
|
"""
|
|
72
96
|
从一个正在运行的 DeepSpeedEngine 中高效地提取完整的 FP32 state_dict,
|
|
73
97
|
兼容 ZeRO Stages 0, 1, 2, 3。
|
|
74
|
-
|
|
98
|
+
包含了对 ZeRO-3 中分片参数的正确处理。
|
|
75
99
|
"""
|
|
76
100
|
|
|
77
101
|
assert isinstance(model, DeepSpeedEngine)
|
|
78
|
-
|
|
79
|
-
state_dict = None
|
|
80
|
-
|
|
81
|
-
if TrainerTools().parallel.is_main_process:
|
|
82
|
-
if zero_stage == 3:
|
|
83
|
-
# ZeRO-3: Rank 0 聚合参数来构建完整的 state_dict
|
|
84
|
-
state_dict = {}
|
|
85
|
-
for param in model.module.parameters():
|
|
86
|
-
# 关键检查:判断参数是否被 ZeRO-3 分片管理
|
|
87
|
-
if hasattr(param, 'ds_id'):
|
|
88
|
-
# 这是被分片的参数,使用 GatheredParameters 聚合
|
|
89
|
-
with deepspeed.zero.GatheredParameters(param, modifier_rank=0):
|
|
90
|
-
# .clone() 创建一个独立副本, .to('cpu') 移动到CPU, .to(torch.float32) 确保类型
|
|
91
|
-
state_dict[param.ds_name] = param.data.to(torch.float32).cpu().clone()
|
|
92
|
-
else:
|
|
93
|
-
# 这是未被分片的参数 (e.g., tied weights, buffers), 直接从 Rank 0 复制
|
|
94
|
-
state_dict[param.ds_name] = param.data.to(torch.float32).cpu().clone()
|
|
95
|
-
else: # zero_stage in [0, 1, 2]
|
|
96
|
-
# 在这些 stage,rank 0 已经有完整的模型。
|
|
97
|
-
# 我们从 model_engine.module 获取原始模型状态。
|
|
98
|
-
state_dict = {k: v.cpu().clone() for k, v in model.module.state_dict().items()}
|
|
102
|
+
state_dict = _get_ds_full_state_dict_on_rank0(model)
|
|
99
103
|
|
|
100
104
|
# 现在,只有 rank 0 上的 state_dict 是一个有效的字典,其他 rank 上是 None。
|
|
101
105
|
# 我们需要将其广播给所有进程。
|
|
102
106
|
if TrainerTools().parallel.world_size > 1:
|
|
103
107
|
# 准备一个列表,rank 0 有数据,其他 rank 是占位符
|
|
104
108
|
object_list = [state_dict] if TrainerTools().parallel.is_main_process else [None]
|
|
105
|
-
|
|
106
109
|
# 执行广播,这个操作是阻塞的,会同步所有进程
|
|
107
110
|
dist.broadcast_object_list(object_list, src=0)
|
|
108
|
-
|
|
109
111
|
# 所有进程从列表中获取广播后的 state_dict 副本
|
|
110
112
|
state_dict = object_list[0]
|
|
111
113
|
|
llm_trainer/eval.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
import torch
|
|
2
2
|
|
|
3
3
|
from .generate_utils import generate
|
|
4
|
-
from .checkpoint import copy_model_params
|
|
5
4
|
from .log import get_log_dir
|
|
6
5
|
from .tools import TrainerTools
|
|
7
6
|
from .train_configs import EvalConfig
|
|
@@ -37,7 +36,6 @@ def _eval_task(
|
|
|
37
36
|
|
|
38
37
|
|
|
39
38
|
def submit_gen_task(
|
|
40
|
-
train_model: torch.nn.Module,
|
|
41
39
|
eval_model: torch.nn.Module,
|
|
42
40
|
eval_config: EvalConfig,
|
|
43
41
|
tag,
|
|
@@ -46,13 +44,6 @@ def submit_gen_task(
|
|
|
46
44
|
max_position_embeddings,
|
|
47
45
|
tokens_per_image
|
|
48
46
|
):
|
|
49
|
-
try:
|
|
50
|
-
copy_model_params(_from=train_model, _to=eval_model)
|
|
51
|
-
except Exception as e:
|
|
52
|
-
if isinstance(e, KeyboardInterrupt):
|
|
53
|
-
raise e
|
|
54
|
-
return
|
|
55
|
-
|
|
56
47
|
eval_model.to(TrainerTools().parallel.device)
|
|
57
48
|
_eval_task(
|
|
58
49
|
eval_model=eval_model,
|
llm_trainer/fsdp_checkpoint.py
CHANGED
|
@@ -53,6 +53,18 @@ def load_fsdp_checkpoint(
|
|
|
53
53
|
optimizer.load_state_dict(state_dict['optim_state_dict'])
|
|
54
54
|
|
|
55
55
|
|
|
56
|
+
def _get_fsdp_full_state_dict_on_rank0(model: nn.Module) -> Optional[dict]:
|
|
57
|
+
"""
|
|
58
|
+
可以在任意rank上调用,然后只有rank0有值
|
|
59
|
+
"""
|
|
60
|
+
|
|
61
|
+
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
|
62
|
+
with FSDP.summon_full_params(model, writeback=False, offload_to_cpu=True):
|
|
63
|
+
if TrainerTools().parallel.is_main_process:
|
|
64
|
+
return {k: v.clone() for k, v in model.state_dict().items()}
|
|
65
|
+
|
|
66
|
+
return None
|
|
67
|
+
|
|
56
68
|
|
|
57
69
|
def get_fsdp_model_params(model: nn.Module):
|
|
58
70
|
"""
|
|
@@ -60,31 +72,15 @@ def get_fsdp_model_params(model: nn.Module):
|
|
|
60
72
|
这个函数会聚合所有分片的参数,并确保所有 rank 都收到一个完整的副本。
|
|
61
73
|
"""
|
|
62
74
|
|
|
63
|
-
|
|
64
|
-
# writeback=False: 我们只读取参数,不写回,可以节省开销。
|
|
65
|
-
# offload_to_cpu=True: 直接将聚合后的参数卸载到 CPU,避免在 GPU 上产生大的峰值内存,
|
|
66
|
-
# 并为我们省去了 .cpu() 的步骤。这是一个非常有用的优化。
|
|
67
|
-
# rank0_only=False: 为了让 offload_to_cpu 在所有 rank 上都生效,这里通常设为 False。
|
|
68
|
-
# 我们稍后通过 get_rank() 来确保只有 rank 0 实际构建字典。
|
|
69
|
-
with FSDP.summon_full_params(model, writeback=False, offload_to_cpu=True):
|
|
70
|
-
|
|
71
|
-
state_dict = None
|
|
72
|
-
if TrainerTools().parallel.is_main_process:
|
|
73
|
-
# 在这个 with 块内部, model.state_dict() 会返回一个在 CPU 上的、完整的状态字典。
|
|
74
|
-
# 因为我们设置了 offload_to_cpu=True。
|
|
75
|
-
# 我们使用 .clone() 来确保我们得到的是一个独立的副本,
|
|
76
|
-
# 尽管 offload_to_cpu 已经帮我们处理了大部分情况。
|
|
77
|
-
state_dict = {k: v.clone() for k, v in model.state_dict().items()}
|
|
75
|
+
state_dict = _get_fsdp_full_state_dict_on_rank0(model)
|
|
78
76
|
|
|
79
77
|
# 现在,只有 rank 0 上的 state_dict 是一个有效的字典,其他 rank 上是 None。
|
|
80
78
|
# 我们需要将其广播给所有进程。
|
|
81
79
|
if TrainerTools().parallel.world_size > 1:
|
|
82
80
|
# 准备一个列表,rank 0 有数据,其他 rank 是占位符
|
|
83
81
|
object_list = [state_dict] if TrainerTools().parallel.is_main_process else [None]
|
|
84
|
-
|
|
85
82
|
# 执行广播,这个操作是阻塞的,会同步所有进程
|
|
86
83
|
dist.broadcast_object_list(object_list, src=0)
|
|
87
|
-
|
|
88
84
|
# 所有进程从列表中获取广播后的 state_dict 副本
|
|
89
85
|
state_dict = object_list[0]
|
|
90
86
|
|
llm_trainer/trainer.py
CHANGED
|
@@ -31,6 +31,7 @@ from .scheduler import (
|
|
|
31
31
|
from .checkpoint import (
|
|
32
32
|
load_checkpoint,
|
|
33
33
|
save_checkpoint,
|
|
34
|
+
copy_model_params,
|
|
34
35
|
load_steps,
|
|
35
36
|
save_steps,
|
|
36
37
|
)
|
|
@@ -416,6 +417,8 @@ class Trainer:
|
|
|
416
417
|
self,
|
|
417
418
|
tag: str
|
|
418
419
|
):
|
|
420
|
+
copy_model_params(_from=self.train_model, _to=self.eval_model)
|
|
421
|
+
|
|
419
422
|
if TrainerTools().parallel.is_main_process:
|
|
420
423
|
eval_prompt, eval_image_tag = self._get_eval_data()
|
|
421
424
|
if isinstance(self.train_model, VlmModel) and self.pixel_values_provider and eval_image_tag:
|
|
@@ -424,7 +427,6 @@ class Trainer:
|
|
|
424
427
|
eval_pixel_values = None
|
|
425
428
|
|
|
426
429
|
submit_gen_task(
|
|
427
|
-
self.train_model,
|
|
428
430
|
self.eval_model,
|
|
429
431
|
self.train_config.eval_config,
|
|
430
432
|
tag=f'sign:batch/{tag}',
|
|
@@ -439,6 +441,8 @@ class Trainer:
|
|
|
439
441
|
self,
|
|
440
442
|
tag: str
|
|
441
443
|
):
|
|
444
|
+
copy_model_params(_from=self.train_model, _to=self.eval_model)
|
|
445
|
+
|
|
442
446
|
if TrainerTools().parallel.is_main_process:
|
|
443
447
|
eval_prompt, eval_image_tag = self._get_eval_data()
|
|
444
448
|
if isinstance(self.train_model, VlmModel) and self.pixel_values_provider and eval_image_tag:
|
|
@@ -447,7 +451,6 @@ class Trainer:
|
|
|
447
451
|
eval_pixel_values = None
|
|
448
452
|
|
|
449
453
|
submit_gen_task(
|
|
450
|
-
self.train_model,
|
|
451
454
|
self.eval_model,
|
|
452
455
|
self.train_config.eval_config,
|
|
453
456
|
tag=f'sign:epoch/{tag}',
|
|
@@ -1,11 +1,11 @@
|
|
|
1
1
|
llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
|
|
2
|
-
llm_trainer/checkpoint.py,sha256=
|
|
2
|
+
llm_trainer/checkpoint.py,sha256=yZcExxneN2yzvWxRiK-pstMWs35LV7GiOfqcLq-S6vc,5745
|
|
3
3
|
llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
|
|
4
4
|
llm_trainer/dcp.py,sha256=PkD97DyrOtoTKn4FJsfL3VqAy4dxufgjdzJEz8-Cnoc,3635
|
|
5
5
|
llm_trainer/dpo_trainer.py,sha256=rC_I5ipesSlP3gFK_SG2GB8NbgJAMu4K7KLxkAS-aRY,13406
|
|
6
|
-
llm_trainer/ds_checkpoint.py,sha256=
|
|
7
|
-
llm_trainer/eval.py,sha256=
|
|
8
|
-
llm_trainer/fsdp_checkpoint.py,sha256=
|
|
6
|
+
llm_trainer/ds_checkpoint.py,sha256=nchGocJE2oJnQ_KNN1kw-BkOAEIyTtO8SJt41cuN_xM,4232
|
|
7
|
+
llm_trainer/eval.py,sha256=NDm8PbXLch7xT81xPYPRCNrcrB_Xj5GDJSCxyVwUOp4,1524
|
|
8
|
+
llm_trainer/fsdp_checkpoint.py,sha256=lqZFzHyWyfzuCq_81kQNtJd2qaiMeY1N5BCEMnrJTBw,3192
|
|
9
9
|
llm_trainer/generate_utils.py,sha256=4iM0vyc_1C_iTL31GlS9PR4eZtYaELPRZ02KDSPZA9U,15158
|
|
10
10
|
llm_trainer/grpo_trainer.py,sha256=fqLT48ORSCece_e8dpyt8J7EarDuTnGoJ_eHk7Oy-1k,16177
|
|
11
11
|
llm_trainer/log.py,sha256=LxqTGRNZUGMTSQCePRpk-rYyxSnSIbT4kOdP8Fbzr0M,462
|
|
@@ -20,16 +20,16 @@ llm_trainer/sft_trainer.py,sha256=gxQA7T1o1QGUsHp2CX1Qb_fO5LppBJuNbc0H4ixCYUA,17
|
|
|
20
20
|
llm_trainer/tokenizer.py,sha256=A7TYYUbtPf75kjCvWP7yBui4xZBObMk2aPem62YpwpY,6776
|
|
21
21
|
llm_trainer/tools.py,sha256=O45-20wRmh-nyTfU-U-XtjbKAoe7boEIsUvWT_NaKx4,3041
|
|
22
22
|
llm_trainer/train_configs.py,sha256=arnet3tIzgVnwshod08F1jE7r4I7e-SIgMy55IagPnE,15971
|
|
23
|
-
llm_trainer/trainer.py,sha256=
|
|
23
|
+
llm_trainer/trainer.py,sha256=hOn-z8kOd67RTuaaNMmdQjlw7N5LIZRHjSt5frpA1xI,25355
|
|
24
24
|
llm_trainer/utils.py,sha256=-ivhMF0d999va13S1wt2uBvtVw8Nvr3uBzhaUFKL04Q,6826
|
|
25
|
-
project_llm_trainer-0.4.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
|
|
26
|
-
project_llm_trainer-0.4.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
|
|
27
|
-
project_llm_trainer-0.4.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
|
|
28
|
-
project_llm_trainer-0.4.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
|
|
29
|
-
project_llm_trainer-0.4.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
|
|
30
|
-
project_llm_trainer-0.4.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
|
|
31
|
-
project_llm_trainer-0.4.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
|
|
32
|
-
project_llm_trainer-0.4.dist-info/METADATA,sha256
|
|
33
|
-
project_llm_trainer-0.4.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
|
|
34
|
-
project_llm_trainer-0.4.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
|
|
35
|
-
project_llm_trainer-0.4.dist-info/RECORD,,
|
|
25
|
+
project_llm_trainer-0.4.1.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
|
|
26
|
+
project_llm_trainer-0.4.1.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
|
|
27
|
+
project_llm_trainer-0.4.1.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
|
|
28
|
+
project_llm_trainer-0.4.1.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
|
|
29
|
+
project_llm_trainer-0.4.1.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
|
|
30
|
+
project_llm_trainer-0.4.1.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
|
|
31
|
+
project_llm_trainer-0.4.1.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
|
|
32
|
+
project_llm_trainer-0.4.1.dist-info/METADATA,sha256=9z1AB745r7BzQHNc3j-3N2nOdB9ZRUYsxcM42QoSb1o,195
|
|
33
|
+
project_llm_trainer-0.4.1.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
|
|
34
|
+
project_llm_trainer-0.4.1.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
|
|
35
|
+
project_llm_trainer-0.4.1.dist-info/RECORD,,
|
{project_llm_trainer-0.4.data → project_llm_trainer-0.4.1.data}/scripts/calc_intermediate_size
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|