project-llm-trainer 0.5.14__py3-none-any.whl → 0.5.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of project-llm-trainer might be problematic. Click here for more details.

llm_trainer/checkpoint.py CHANGED
@@ -65,7 +65,7 @@ def save_best_checkpoint(
65
65
  except:
66
66
  pass
67
67
 
68
- TrainerTools().parallel.wait()
68
+ TrainerTools().parallel.wait('save best checkpoint')
69
69
  return need_replace
70
70
 
71
71
 
@@ -170,14 +170,19 @@ class DPOTrainer(Trainer):
170
170
  skipping_train = True
171
171
  continue
172
172
 
173
- skipping_train = False
174
-
175
173
  # 是否需要更新梯度
176
- if gradient_accumulation_steps > 1:
174
+ if skipping_train:
175
+ need_update_grad = False
176
+ elif gradient_accumulation_steps > 1:
177
177
  need_update_grad = (batch + 1) % gradient_accumulation_steps == 0 or batch == batch_count_per_file - 1
178
178
  else:
179
179
  need_update_grad = True
180
180
 
181
+ # 要放在need_update_grad赋值下面,解决在继续训练时未知原因的卡死现象
182
+ if skipping_train:
183
+ TrainerTools().parallel.wait('skip train')
184
+ skipping_train = False
185
+
181
186
  try:
182
187
  chosen_inputs: torch.Tensor = batch_data['chosen_inputs'].to(TrainerTools().parallel.device)
183
188
  chosen_labels: torch.Tensor = batch_data['chosen_labels'].to(TrainerTools().parallel.device)
@@ -26,8 +26,6 @@ def save_ds_checkpoint(model: nn.Module):
26
26
  model.save_checkpoint(save_dir=ckpt_dir)
27
27
  except: ...
28
28
 
29
- TrainerTools().parallel.wait()
30
-
31
29
  # 只在main rank上执行
32
30
  if TrainerTools().parallel.is_main_process:
33
31
  # 删除历史checkpoint
@@ -39,7 +37,7 @@ def save_ds_checkpoint(model: nn.Module):
39
37
  shutil.rmtree(oldest_ckpt)
40
38
  except: ...
41
39
 
42
- TrainerTools().parallel.wait()
40
+ TrainerTools().parallel.wait('remove old ds checkpoint')
43
41
 
44
42
 
45
43
  def load_ds_checkpoint(
@@ -317,7 +317,9 @@ class GRPOTrainer(Trainer):
317
317
  skipping_train = True
318
318
  continue
319
319
 
320
- skipping_train = False
320
+ if skipping_train:
321
+ TrainerTools().parallel.wait('skip train')
322
+ skipping_train = False
321
323
 
322
324
  # start generate
323
325
  if TrainerTools().parallel.is_main_process:
llm_trainer/parallel.py CHANGED
@@ -139,9 +139,8 @@ class Parallel(ABC):
139
139
  return dist.get_world_size()
140
140
  return 1
141
141
 
142
- def wait(self):
143
- try:
144
- log(f'wait at {self.device}')
145
- dist.barrier()
146
- except: ...
147
- log(f'continue at {self.device}')
142
+ def wait(self, msg=None):
143
+ msg = f' for {msg}' if msg else None
144
+ log(f'wait at {self.device}{msg}')
145
+ dist.barrier()
146
+ log(f'continue at {self.device}{msg}')
llm_trainer/trainer.py CHANGED
@@ -449,7 +449,7 @@ class Trainer:
449
449
  )
450
450
  generate_model.train()
451
451
 
452
- TrainerTools().parallel.wait()
452
+ TrainerTools().parallel.wait('eval')
453
453
 
454
454
  def _on_batch_end(self, tag: str):
455
455
  self._eval(f'sign:batch/{tag}')
@@ -500,14 +500,19 @@ class Trainer:
500
500
  skipping_train = True
501
501
  continue
502
502
 
503
- skipping_train = False
504
-
505
503
  # 是否需要更新梯度
506
- if gradient_accumulation_steps > 1:
504
+ if skipping_train:
505
+ need_update_grad = False
506
+ elif gradient_accumulation_steps > 1:
507
507
  need_update_grad = (batch + 1) % gradient_accumulation_steps == 0 or batch == batch_count_per_file - 1
508
508
  else:
509
509
  need_update_grad = True
510
510
 
511
+ # 要放在need_update_grad赋值下面,解决在继续训练时未知原因的卡死现象
512
+ if skipping_train:
513
+ TrainerTools().parallel.wait('skip train')
514
+ skipping_train = False
515
+
511
516
  inputs = batch_data['inputs']
512
517
  labels = batch_data['labels']
513
518
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: project_llm_trainer
3
- Version: 0.5.14
3
+ Version: 0.5.16
4
4
  Summary: LLM and VLM trainer
5
5
  Author: qibin
6
6
  Author-email: qibin0506@gmail.com
@@ -1,14 +1,14 @@
1
1
  llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
2
- llm_trainer/checkpoint.py,sha256=Wh5CwceIajTgJ9i_mH3I1R9N2nOLFqVFmlEMkTiGcD4,4306
2
+ llm_trainer/checkpoint.py,sha256=gz31pZbbQvRTYrBhxV-MFaBAIFeqpe7rM6nFsjwT9lY,4328
3
3
  llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
4
- llm_trainer/dpo_trainer.py,sha256=8LYRxviJKcB-rN_XprVsWr5YshU8KolggMm7irjbXvI,11990
5
- llm_trainer/ds_checkpoint.py,sha256=U7f79uWLysiyLnbJET_G4RwWh53Z0C9HCcXmNKq8UvM,2151
4
+ llm_trainer/dpo_trainer.py,sha256=--ItH-rkkq24Da3M_Kf0VxpQ3t-k0fpZrzFGqkYsjks,12304
5
+ llm_trainer/ds_checkpoint.py,sha256=Wzy7PvVVWR794-BW4uragWFTAkkgDvjvkF-qMdyB4fc,2141
6
6
  llm_trainer/eval.py,sha256=ZyUfSo2Q8P-lrCdPEnGkoo5pGubd0AabREK5eMISRII,1109
7
7
  llm_trainer/generate_utils.py,sha256=wrZoG2g7CsOyG4sb3px9vURHQFV6_9j5kQmpFc5A8yg,15335
8
- llm_trainer/grpo_trainer.py,sha256=PVTlKOEJpI0AMlh7Siw_MHpLm9CAZepCAMjjSZF6eRU,15996
8
+ llm_trainer/grpo_trainer.py,sha256=g_ivzQop2SkvhlKAEWb0zUnIvNuHTfsOoIG6y29oTCw,16106
9
9
  llm_trainer/log.py,sha256=LxqTGRNZUGMTSQCePRpk-rYyxSnSIbT4kOdP8Fbzr0M,462
10
10
  llm_trainer/loss.py,sha256=NZCQeUXnLSj__mmDflE8g89KgE0emAJXIab0IERCLno,6023
11
- llm_trainer/parallel.py,sha256=DQu8GqEFxD99HQ6hKuIxxyKi-05dMO33eMhImYlPuOI,4468
11
+ llm_trainer/parallel.py,sha256=j1L4n-JmDkDZblURrNKpEAWEqqGIAXAN9PT_fSS_OnE,4492
12
12
  llm_trainer/parallel_ddp.py,sha256=Pob9vUlBZnkL4oP1Re11kFob7nufMSE96pn7m7fuOEM,1345
13
13
  llm_trainer/parallel_ds.py,sha256=oy8RRxHud3rACWubFlJqqd0pjPEQhKeAPGPQUSdJX2c,1145
14
14
  llm_trainer/parallel_none.py,sha256=TG6Pm829Dg-yQu-97O-EHV3FCARBlNcP47KkGFAs16E,676
@@ -18,16 +18,16 @@ llm_trainer/sft_trainer.py,sha256=gxQA7T1o1QGUsHp2CX1Qb_fO5LppBJuNbc0H4ixCYUA,17
18
18
  llm_trainer/tokenizer.py,sha256=SSpgXtb0e1NtQqRW0gCq09TTZi47umggy-Fh5EMHKJg,6708
19
19
  llm_trainer/tools.py,sha256=yF17lp6oOfLe2XJeKDQ1juZcbv-6vFamJSLwEeArduA,2975
20
20
  llm_trainer/train_configs.py,sha256=992wy0YhBG2WvxwdLEPL4_-JUl4NkwMPT-jj_BIHo6A,7347
21
- llm_trainer/trainer.py,sha256=FF75J-BRUp34No2TvQIgomvNozWYzVhDeOfaBgQLV9g,26079
21
+ llm_trainer/trainer.py,sha256=YqWhD9jXbrUdm3KEjEHLyg_qHiXCy5R7PK-arCXxJ6M,26399
22
22
  llm_trainer/utils.py,sha256=LWNhyQ0NDEZ9mZtk2Ryvh6EulvHIaUGIflugSpqmeFI,6791
23
- project_llm_trainer-0.5.14.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
24
- project_llm_trainer-0.5.14.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
25
- project_llm_trainer-0.5.14.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
26
- project_llm_trainer-0.5.14.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
27
- project_llm_trainer-0.5.14.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
28
- project_llm_trainer-0.5.14.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
29
- project_llm_trainer-0.5.14.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
30
- project_llm_trainer-0.5.14.dist-info/METADATA,sha256=639zcd1nZ1iEJDswOjEL8eudA39c6RhqPVi-H-xttWE,196
31
- project_llm_trainer-0.5.14.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
32
- project_llm_trainer-0.5.14.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
33
- project_llm_trainer-0.5.14.dist-info/RECORD,,
23
+ project_llm_trainer-0.5.16.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
24
+ project_llm_trainer-0.5.16.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
25
+ project_llm_trainer-0.5.16.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
26
+ project_llm_trainer-0.5.16.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
27
+ project_llm_trainer-0.5.16.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
28
+ project_llm_trainer-0.5.16.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
29
+ project_llm_trainer-0.5.16.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
30
+ project_llm_trainer-0.5.16.dist-info/METADATA,sha256=h0TMNrZMUU875tVasbuqt69EuPPMbo_nv6tHQLKeNbQ,196
31
+ project_llm_trainer-0.5.16.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
32
+ project_llm_trainer-0.5.16.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
33
+ project_llm_trainer-0.5.16.dist-info/RECORD,,