project-llm-trainer: project_llm_trainer-0.4.13-py3-none-any.whl → project_llm_trainer-0.4.14-py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of project-llm-trainer might be problematic. See the details linked from the original page for more information.

llm_trainer/checkpoint.py CHANGED
@@ -129,10 +129,10 @@ def copy_model_params(
129
129
 
130
130
  if isinstance(TrainerTools().parallel, DsParallel):
131
131
  from .ds_checkpoint import get_ds_model_params
132
- state_dict = get_ds_model_params(_from)
132
+ state_dict = get_ds_model_params(_from, only_rank0=_to is None)
133
133
  elif isinstance(TrainerTools().parallel, FsdpParallel):
134
134
  from .fsdp_checkpoint import get_fsdp_model_params
135
- state_dict = get_fsdp_model_params(_from)
135
+ state_dict = get_fsdp_model_params(_from, only_rank0=_to is None)
136
136
  elif isinstance(_from, DDP):
137
137
  state_dict = _from.module.state_dict()
138
138
  else:
@@ -105,7 +105,7 @@ def _get_ds_full_state_dict_on_rank0(model: DeepSpeedEngine) -> Optional[dict]:
105
105
  # return state_dict_on_rank_0 if TrainerTools().parallel.is_main_process else None
106
106
 
107
107
 
108
- def get_ds_model_params(model: nn.Module):
108
+ def get_ds_model_params(model: nn.Module, only_rank0=False):
109
109
  """
110
110
  从一个正在运行的 DeepSpeedEngine 中高效地提取完整的 FP32 state_dict,
111
111
  兼容 ZeRO Stages 0, 1, 2, 3。
@@ -117,7 +117,7 @@ def get_ds_model_params(model: nn.Module):
117
117
 
118
118
  # 现在,只有 rank 0 上的 state_dict 是一个有效的字典,其他 rank 上是 None。
119
119
  # 我们需要将其广播给所有进程。
120
- if TrainerTools().parallel.world_size > 1:
120
+ if not only_rank0 and TrainerTools().parallel.world_size > 1:
121
121
  # 准备一个列表,rank 0 有数据,其他 rank 是占位符
122
122
  object_list = [state_dict] if TrainerTools().parallel.is_main_process else [None]
123
123
  # 执行广播,这个操作是阻塞的,会同步所有进程
@@ -66,7 +66,7 @@ def _get_fsdp_full_state_dict_on_rank0(model: nn.Module) -> Optional[dict]:
66
66
  return None
67
67
 
68
68
 
69
- def get_fsdp_model_params(model: nn.Module):
69
+ def get_fsdp_model_params(model: nn.Module, only_rank0=False):
70
70
  """
71
71
  从一个 FSDP 包装的模型中高效地提取完整的 FP32 state_dict。
72
72
  这个函数会聚合所有分片的参数,并确保所有 rank 都收到一个完整的副本。
@@ -76,7 +76,7 @@ def get_fsdp_model_params(model: nn.Module):
76
76
 
77
77
  # 现在,只有 rank 0 上的 state_dict 是一个有效的字典,其他 rank 上是 None。
78
78
  # 我们需要将其广播给所有进程。
79
- if TrainerTools().parallel.world_size > 1:
79
+ if not only_rank0 and TrainerTools().parallel.world_size > 1:
80
80
  # 准备一个列表,rank 0 有数据,其他 rank 是占位符
81
81
  object_list = [state_dict] if TrainerTools().parallel.is_main_process else [None]
82
82
  # 执行广播,这个操作是阻塞的,会同步所有进程
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: project_llm_trainer
3
- Version: 0.4.13
3
+ Version: 0.4.14
4
4
  Summary: LLM and VLM trainer
5
5
  Author: qibin
6
6
  Author-email: qibin0506@gmail.com
@@ -1,11 +1,11 @@
1
1
  llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
2
- llm_trainer/checkpoint.py,sha256=yZcExxneN2yzvWxRiK-pstMWs35LV7GiOfqcLq-S6vc,5745
2
+ llm_trainer/checkpoint.py,sha256=ItDzuXVikk-0gWSw-IS7SrODEdlJEb5nZs15dBFkPdk,5793
3
3
  llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
4
4
  llm_trainer/dcp.py,sha256=PkD97DyrOtoTKn4FJsfL3VqAy4dxufgjdzJEz8-Cnoc,3635
5
5
  llm_trainer/dpo_trainer.py,sha256=djBhvI_ixTV1nLNg84tgCpfV--pu6IRiOhO28V-aANQ,11425
6
- llm_trainer/ds_checkpoint.py,sha256=x_tjgJR47P8gVwV4qAnTUCGwx7eVq2Epw0vOVV7fkYo,4925
6
+ llm_trainer/ds_checkpoint.py,sha256=fprJlbSgtyKmmpytyMOZBs3pcjZA13SeWao0llnLpNQ,4962
7
7
  llm_trainer/eval.py,sha256=NDm8PbXLch7xT81xPYPRCNrcrB_Xj5GDJSCxyVwUOp4,1524
8
- llm_trainer/fsdp_checkpoint.py,sha256=lqZFzHyWyfzuCq_81kQNtJd2qaiMeY1N5BCEMnrJTBw,3192
8
+ llm_trainer/fsdp_checkpoint.py,sha256=dAHIGHfuvTA6OC0jV9Ls-oD4ZR9CPGa31mjtoh-2dZE,3229
9
9
  llm_trainer/generate_utils.py,sha256=tSbA_tLqSq5qJGHSOlPv5T3iRDZkbFg5ZvDAgJ_i_SE,17946
10
10
  llm_trainer/grpo_trainer.py,sha256=bZPrxhyPQLAnFzWhI7hhA6fpuKVNwj7nOm9k0ku9aK4,15977
11
11
  llm_trainer/log.py,sha256=LxqTGRNZUGMTSQCePRpk-rYyxSnSIbT4kOdP8Fbzr0M,462
@@ -22,14 +22,14 @@ llm_trainer/tools.py,sha256=O45-20wRmh-nyTfU-U-XtjbKAoe7boEIsUvWT_NaKx4,3041
22
22
  llm_trainer/train_configs.py,sha256=HKzH3nfMT1-SW4Htwa0KqYtMd6FAJcthR5IEo6di8us,8168
23
23
  llm_trainer/trainer.py,sha256=j5fDqMzvU6SYwxHsv9wX0UVX4JXS-8eP1AkHgVxKf9U,27119
24
24
  llm_trainer/utils.py,sha256=-ivhMF0d999va13S1wt2uBvtVw8Nvr3uBzhaUFKL04Q,6826
25
- project_llm_trainer-0.4.13.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
26
- project_llm_trainer-0.4.13.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
27
- project_llm_trainer-0.4.13.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
28
- project_llm_trainer-0.4.13.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
29
- project_llm_trainer-0.4.13.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
30
- project_llm_trainer-0.4.13.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
31
- project_llm_trainer-0.4.13.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
32
- project_llm_trainer-0.4.13.dist-info/METADATA,sha256=hiW-7qgWuPizKVz4cU8mEHoqiuT6ZqNlCBb7nwVfFQ4,196
33
- project_llm_trainer-0.4.13.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
34
- project_llm_trainer-0.4.13.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
35
- project_llm_trainer-0.4.13.dist-info/RECORD,,
25
+ project_llm_trainer-0.4.14.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
26
+ project_llm_trainer-0.4.14.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
27
+ project_llm_trainer-0.4.14.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
28
+ project_llm_trainer-0.4.14.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
29
+ project_llm_trainer-0.4.14.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
30
+ project_llm_trainer-0.4.14.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
31
+ project_llm_trainer-0.4.14.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
32
+ project_llm_trainer-0.4.14.dist-info/METADATA,sha256=VMEWVv8pBqFUAhIAiH4_S4ECUHln31gchHLhTtUAM1o,196
33
+ project_llm_trainer-0.4.14.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
34
+ project_llm_trainer-0.4.14.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
35
+ project_llm_trainer-0.4.14.dist-info/RECORD,,