project-llm-trainer 0.4.13__py3-none-any.whl → 0.4.14__py3-none-any.whl
This diff shows the changes between publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their respective public registries.
Potentially problematic release.
This version of project-llm-trainer might be problematic. Click here for more details.
- llm_trainer/checkpoint.py +2 -2
- llm_trainer/ds_checkpoint.py +2 -2
- llm_trainer/fsdp_checkpoint.py +2 -2
- {project_llm_trainer-0.4.13.dist-info → project_llm_trainer-0.4.14.dist-info}/METADATA +1 -1
- {project_llm_trainer-0.4.13.dist-info → project_llm_trainer-0.4.14.dist-info}/RECORD +14 -14
- {project_llm_trainer-0.4.13.data → project_llm_trainer-0.4.14.data}/scripts/calc_intermediate_size +0 -0
- {project_llm_trainer-0.4.13.data → project_llm_trainer-0.4.14.data}/scripts/ddp_train +0 -0
- {project_llm_trainer-0.4.13.data → project_llm_trainer-0.4.14.data}/scripts/ds_train +0 -0
- {project_llm_trainer-0.4.13.data → project_llm_trainer-0.4.14.data}/scripts/plot_loss +0 -0
- {project_llm_trainer-0.4.13.data → project_llm_trainer-0.4.14.data}/scripts/plot_lr +0 -0
- {project_llm_trainer-0.4.13.data → project_llm_trainer-0.4.14.data}/scripts/py_train +0 -0
- {project_llm_trainer-0.4.13.data → project_llm_trainer-0.4.14.data}/scripts/smart_train +0 -0
- {project_llm_trainer-0.4.13.dist-info → project_llm_trainer-0.4.14.dist-info}/WHEEL +0 -0
- {project_llm_trainer-0.4.13.dist-info → project_llm_trainer-0.4.14.dist-info}/top_level.txt +0 -0
llm_trainer/checkpoint.py
CHANGED
|
@@ -129,10 +129,10 @@ def copy_model_params(
|
|
|
129
129
|
|
|
130
130
|
if isinstance(TrainerTools().parallel, DsParallel):
|
|
131
131
|
from .ds_checkpoint import get_ds_model_params
|
|
132
|
-
state_dict = get_ds_model_params(_from)
|
|
132
|
+
state_dict = get_ds_model_params(_from, only_rank0=_to is None)
|
|
133
133
|
elif isinstance(TrainerTools().parallel, FsdpParallel):
|
|
134
134
|
from .fsdp_checkpoint import get_fsdp_model_params
|
|
135
|
-
state_dict = get_fsdp_model_params(_from)
|
|
135
|
+
state_dict = get_fsdp_model_params(_from, only_rank0=_to is None)
|
|
136
136
|
elif isinstance(_from, DDP):
|
|
137
137
|
state_dict = _from.module.state_dict()
|
|
138
138
|
else:
|
llm_trainer/ds_checkpoint.py
CHANGED
|
@@ -105,7 +105,7 @@ def _get_ds_full_state_dict_on_rank0(model: DeepSpeedEngine) -> Optional[dict]:
|
|
|
105
105
|
# return state_dict_on_rank_0 if TrainerTools().parallel.is_main_process else None
|
|
106
106
|
|
|
107
107
|
|
|
108
|
-
def get_ds_model_params(model: nn.Module):
|
|
108
|
+
def get_ds_model_params(model: nn.Module, only_rank0=False):
|
|
109
109
|
"""
|
|
110
110
|
从一个正在运行的 DeepSpeedEngine 中高效地提取完整的 FP32 state_dict,
|
|
111
111
|
兼容 ZeRO Stages 0, 1, 2, 3。
|
|
@@ -117,7 +117,7 @@ def get_ds_model_params(model: nn.Module):
|
|
|
117
117
|
|
|
118
118
|
# 现在,只有 rank 0 上的 state_dict 是一个有效的字典,其他 rank 上是 None。
|
|
119
119
|
# Unless the caller asked for rank 0 only (only_rank0=True), we need to broadcast it to all processes.
|
|
120
|
-
if TrainerTools().parallel.world_size > 1:
|
|
120
|
+
if not only_rank0 and TrainerTools().parallel.world_size > 1:
|
|
121
121
|
# 准备一个列表,rank 0 有数据,其他 rank 是占位符
|
|
122
122
|
object_list = [state_dict] if TrainerTools().parallel.is_main_process else [None]
|
|
123
123
|
# 执行广播,这个操作是阻塞的,会同步所有进程
|
llm_trainer/fsdp_checkpoint.py
CHANGED
|
@@ -66,7 +66,7 @@ def _get_fsdp_full_state_dict_on_rank0(model: nn.Module) -> Optional[dict]:
|
|
|
66
66
|
return None
|
|
67
67
|
|
|
68
68
|
|
|
69
|
-
def get_fsdp_model_params(model: nn.Module):
|
|
69
|
+
def get_fsdp_model_params(model: nn.Module, only_rank0=False):
|
|
70
70
|
"""
|
|
71
71
|
从一个 FSDP 包装的模型中高效地提取完整的 FP32 state_dict。
|
|
72
72
|
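This function gathers the parameters of all shards. By default (only_rank0=False) it makes sure every rank receives a full copy; with only_rank0=True the broadcast is skipped, so only rank 0 is expected to hold a valid state_dict.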
|
|
@@ -76,7 +76,7 @@ def get_fsdp_model_params(model: nn.Module):
|
|
|
76
76
|
|
|
77
77
|
# 现在,只有 rank 0 上的 state_dict 是一个有效的字典,其他 rank 上是 None。
|
|
78
78
|
# Unless the caller asked for rank 0 only (only_rank0=True), we need to broadcast it to all processes.
|
|
79
|
-
if TrainerTools().parallel.world_size > 1:
|
|
79
|
+
if not only_rank0 and TrainerTools().parallel.world_size > 1:
|
|
80
80
|
# 准备一个列表,rank 0 有数据,其他 rank 是占位符
|
|
81
81
|
object_list = [state_dict] if TrainerTools().parallel.is_main_process else [None]
|
|
82
82
|
# 执行广播,这个操作是阻塞的,会同步所有进程
|
|
@@ -1,11 +1,11 @@
|
|
|
1
1
|
llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
|
|
2
|
-
llm_trainer/checkpoint.py,sha256=
|
|
2
|
+
llm_trainer/checkpoint.py,sha256=ItDzuXVikk-0gWSw-IS7SrODEdlJEb5nZs15dBFkPdk,5793
|
|
3
3
|
llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
|
|
4
4
|
llm_trainer/dcp.py,sha256=PkD97DyrOtoTKn4FJsfL3VqAy4dxufgjdzJEz8-Cnoc,3635
|
|
5
5
|
llm_trainer/dpo_trainer.py,sha256=djBhvI_ixTV1nLNg84tgCpfV--pu6IRiOhO28V-aANQ,11425
|
|
6
|
-
llm_trainer/ds_checkpoint.py,sha256=
|
|
6
|
+
llm_trainer/ds_checkpoint.py,sha256=fprJlbSgtyKmmpytyMOZBs3pcjZA13SeWao0llnLpNQ,4962
|
|
7
7
|
llm_trainer/eval.py,sha256=NDm8PbXLch7xT81xPYPRCNrcrB_Xj5GDJSCxyVwUOp4,1524
|
|
8
|
-
llm_trainer/fsdp_checkpoint.py,sha256=
|
|
8
|
+
llm_trainer/fsdp_checkpoint.py,sha256=dAHIGHfuvTA6OC0jV9Ls-oD4ZR9CPGa31mjtoh-2dZE,3229
|
|
9
9
|
llm_trainer/generate_utils.py,sha256=tSbA_tLqSq5qJGHSOlPv5T3iRDZkbFg5ZvDAgJ_i_SE,17946
|
|
10
10
|
llm_trainer/grpo_trainer.py,sha256=bZPrxhyPQLAnFzWhI7hhA6fpuKVNwj7nOm9k0ku9aK4,15977
|
|
11
11
|
llm_trainer/log.py,sha256=LxqTGRNZUGMTSQCePRpk-rYyxSnSIbT4kOdP8Fbzr0M,462
|
|
@@ -22,14 +22,14 @@ llm_trainer/tools.py,sha256=O45-20wRmh-nyTfU-U-XtjbKAoe7boEIsUvWT_NaKx4,3041
|
|
|
22
22
|
llm_trainer/train_configs.py,sha256=HKzH3nfMT1-SW4Htwa0KqYtMd6FAJcthR5IEo6di8us,8168
|
|
23
23
|
llm_trainer/trainer.py,sha256=j5fDqMzvU6SYwxHsv9wX0UVX4JXS-8eP1AkHgVxKf9U,27119
|
|
24
24
|
llm_trainer/utils.py,sha256=-ivhMF0d999va13S1wt2uBvtVw8Nvr3uBzhaUFKL04Q,6826
|
|
25
|
-
project_llm_trainer-0.4.
|
|
26
|
-
project_llm_trainer-0.4.
|
|
27
|
-
project_llm_trainer-0.4.
|
|
28
|
-
project_llm_trainer-0.4.
|
|
29
|
-
project_llm_trainer-0.4.
|
|
30
|
-
project_llm_trainer-0.4.
|
|
31
|
-
project_llm_trainer-0.4.
|
|
32
|
-
project_llm_trainer-0.4.
|
|
33
|
-
project_llm_trainer-0.4.
|
|
34
|
-
project_llm_trainer-0.4.
|
|
35
|
-
project_llm_trainer-0.4.
|
|
25
|
+
project_llm_trainer-0.4.14.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
|
|
26
|
+
project_llm_trainer-0.4.14.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
|
|
27
|
+
project_llm_trainer-0.4.14.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
|
|
28
|
+
project_llm_trainer-0.4.14.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
|
|
29
|
+
project_llm_trainer-0.4.14.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
|
|
30
|
+
project_llm_trainer-0.4.14.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
|
|
31
|
+
project_llm_trainer-0.4.14.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
|
|
32
|
+
project_llm_trainer-0.4.14.dist-info/METADATA,sha256=VMEWVv8pBqFUAhIAiH4_S4ECUHln31gchHLhTtUAM1o,196
|
|
33
|
+
project_llm_trainer-0.4.14.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
|
|
34
|
+
project_llm_trainer-0.4.14.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
|
|
35
|
+
project_llm_trainer-0.4.14.dist-info/RECORD,,
|
{project_llm_trainer-0.4.13.data → project_llm_trainer-0.4.14.data}/scripts/calc_intermediate_size
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|