project-llm-trainer 0.5.11__py3-none-any.whl → 0.5.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of project-llm-trainer might be problematic. Click here for more details.

@@ -271,7 +271,5 @@ class DPOTrainer(Trainer):
271
271
  TrainerTools().parallel.on_epoch_end(epoch)
272
272
  self._on_epoch_end(tag=f'epoch:{epoch}')
273
273
 
274
- # 等待checkpoint保存完成
275
- time.sleep(10)
276
274
  TrainerTools().parallel.destroy()
277
275
 
@@ -2,6 +2,7 @@ import os
2
2
  from glob import glob
3
3
  import shutil
4
4
  from torch import nn
5
+ from .tools import TrainerTools
5
6
 
6
7
  try:
7
8
  import deepspeed
@@ -23,17 +24,20 @@ def save_ds_checkpoint(model: nn.Module):
23
24
  try:
24
25
  # 包括model、optimizer等状态
25
26
  model.save_checkpoint(save_dir=ckpt_dir)
26
- except:
27
- return
28
-
29
- # 删除历史checkpoint
30
- ckpt_paths = glob(os.path.join(ckpt_dir, "global_*"))
31
- if len(ckpt_paths) > 2:
32
- # 按修改时间排序,找到最旧的目录
33
- oldest_ckpt = sorted(ckpt_paths, key=os.path.getmtime)[0]
34
- try:
35
- shutil.rmtree(oldest_ckpt)
36
- except: ...
27
+ except: ...
28
+
29
+ TrainerTools().parallel.wait()
30
+
31
+ # 只在main rank上执行
32
+ if TrainerTools().parallel.is_main_process:
33
+ # 删除历史checkpoint
34
+ ckpt_paths = glob(os.path.join(ckpt_dir, "global_*"))
35
+ if len(ckpt_paths) > 2:
36
+ # 按修改时间排序,找到最旧的目录
37
+ oldest_ckpt = sorted(ckpt_paths, key=os.path.getmtime)[0]
38
+ try:
39
+ shutil.rmtree(oldest_ckpt)
40
+ except: ...
37
41
 
38
42
 
39
43
  def load_ds_checkpoint(
@@ -389,6 +389,4 @@ class GRPOTrainer(Trainer):
389
389
  TrainerTools().parallel.on_epoch_end(epoch)
390
390
  self._on_epoch_end(tag=f'epoch:{epoch}')
391
391
 
392
- # 等待checkpoint保存完成
393
- time.sleep(10)
394
392
  TrainerTools().parallel.destroy()
@@ -136,8 +136,8 @@ class DPOConfig:
136
136
  @dataclass(kw_only=True)
137
137
  class GRPOConfig:
138
138
  grpo_steps: int = 1
139
- clip_eps: float = 0.2
140
- kl_weight: float = 0.01
139
+ clip_eps: float = 0.1
140
+ kl_weight: float = 0.04
141
141
  group_size: int = 12
142
142
  mixup_alpha: float = 1.0
143
143
  gen_max_new_tokens: Optional[int] = None
llm_trainer/trainer.py CHANGED
@@ -597,6 +597,4 @@ class Trainer:
597
597
  TrainerTools().parallel.on_epoch_end(epoch)
598
598
  self._on_epoch_end(tag=f'epoch:{epoch}')
599
599
 
600
- # 等待checkpoint保存完成
601
- time.sleep(10)
602
600
  TrainerTools().parallel.destroy()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: project_llm_trainer
3
- Version: 0.5.11
3
+ Version: 0.5.13
4
4
  Summary: LLM and VLM trainer
5
5
  Author: qibin
6
6
  Author-email: qibin0506@gmail.com
@@ -1,11 +1,11 @@
1
1
  llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
2
2
  llm_trainer/checkpoint.py,sha256=Wh5CwceIajTgJ9i_mH3I1R9N2nOLFqVFmlEMkTiGcD4,4306
3
3
  llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
4
- llm_trainer/dpo_trainer.py,sha256=1A_4QP2_xqM_YeqdXy-0RaMvEL80gim-pgnPQyHww9U,12052
5
- llm_trainer/ds_checkpoint.py,sha256=D092fkS1Up4QmpV9YCpqbSzfX_caCAeX-UiOrhOE1I8,1947
4
+ llm_trainer/dpo_trainer.py,sha256=8LYRxviJKcB-rN_XprVsWr5YshU8KolggMm7irjbXvI,11990
5
+ llm_trainer/ds_checkpoint.py,sha256=-QjtY8JPJMa1IGjMHLJQKrbWRwVqDQSa7OXOKYamwDo,2115
6
6
  llm_trainer/eval.py,sha256=ZyUfSo2Q8P-lrCdPEnGkoo5pGubd0AabREK5eMISRII,1109
7
7
  llm_trainer/generate_utils.py,sha256=wrZoG2g7CsOyG4sb3px9vURHQFV6_9j5kQmpFc5A8yg,15335
8
- llm_trainer/grpo_trainer.py,sha256=sCYjvksdm9f7TpN23KXuCmua_8VFTZEfVEcflL89P_I,16058
8
+ llm_trainer/grpo_trainer.py,sha256=PVTlKOEJpI0AMlh7Siw_MHpLm9CAZepCAMjjSZF6eRU,15996
9
9
  llm_trainer/log.py,sha256=LxqTGRNZUGMTSQCePRpk-rYyxSnSIbT4kOdP8Fbzr0M,462
10
10
  llm_trainer/loss.py,sha256=NZCQeUXnLSj__mmDflE8g89KgE0emAJXIab0IERCLno,6023
11
11
  llm_trainer/parallel.py,sha256=DQu8GqEFxD99HQ6hKuIxxyKi-05dMO33eMhImYlPuOI,4468
@@ -17,17 +17,17 @@ llm_trainer/scheduler.py,sha256=LAI_0VxClsIQkix0bRoduRD4vPfVuIZDhZgTAT_KK8k,4901
17
17
  llm_trainer/sft_trainer.py,sha256=gxQA7T1o1QGUsHp2CX1Qb_fO5LppBJuNbc0H4ixCYUA,1783
18
18
  llm_trainer/tokenizer.py,sha256=SSpgXtb0e1NtQqRW0gCq09TTZi47umggy-Fh5EMHKJg,6708
19
19
  llm_trainer/tools.py,sha256=yF17lp6oOfLe2XJeKDQ1juZcbv-6vFamJSLwEeArduA,2975
20
- llm_trainer/train_configs.py,sha256=guV8xkG5TSGvYwFvsQV_mA8mDHLLVhL5L0xo_WMsMME,7347
21
- llm_trainer/trainer.py,sha256=U26dZc22nByfTZUzKeEiqqYVexBzgw0ep7N0Z2zIcWI,26141
20
+ llm_trainer/train_configs.py,sha256=992wy0YhBG2WvxwdLEPL4_-JUl4NkwMPT-jj_BIHo6A,7347
21
+ llm_trainer/trainer.py,sha256=FF75J-BRUp34No2TvQIgomvNozWYzVhDeOfaBgQLV9g,26079
22
22
  llm_trainer/utils.py,sha256=LWNhyQ0NDEZ9mZtk2Ryvh6EulvHIaUGIflugSpqmeFI,6791
23
- project_llm_trainer-0.5.11.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
24
- project_llm_trainer-0.5.11.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
25
- project_llm_trainer-0.5.11.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
26
- project_llm_trainer-0.5.11.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
27
- project_llm_trainer-0.5.11.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
28
- project_llm_trainer-0.5.11.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
29
- project_llm_trainer-0.5.11.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
30
- project_llm_trainer-0.5.11.dist-info/METADATA,sha256=RAhT8VLTlO4Oyr9ocDxBvhulYed1JXWX_5GWWkiJ7go,196
31
- project_llm_trainer-0.5.11.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
32
- project_llm_trainer-0.5.11.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
33
- project_llm_trainer-0.5.11.dist-info/RECORD,,
23
+ project_llm_trainer-0.5.13.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
24
+ project_llm_trainer-0.5.13.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
25
+ project_llm_trainer-0.5.13.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
26
+ project_llm_trainer-0.5.13.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
27
+ project_llm_trainer-0.5.13.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
28
+ project_llm_trainer-0.5.13.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
29
+ project_llm_trainer-0.5.13.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
30
+ project_llm_trainer-0.5.13.dist-info/METADATA,sha256=ROfaBOhsQK5yC6HMBa_4Tblg90TVadyKbjcNU-s3Imk,196
31
+ project_llm_trainer-0.5.13.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
32
+ project_llm_trainer-0.5.13.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
33
+ project_llm_trainer-0.5.13.dist-info/RECORD,,