project-llm-trainer 0.5.4__py3-none-any.whl → 0.5.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of project-llm-trainer might be problematic.

llm_trainer/checkpoint.py CHANGED
@@ -38,29 +38,32 @@ def save_best_checkpoint(
 ) -> bool:
     need_replace = not last_best_checkpoint_loss or current_loss <= last_best_checkpoint_loss
     if need_replace and TrainerTools().parallel.is_main_process:
-        if isinstance(TrainerTools().parallel, DsParallel):
-            checkpoint_dir = os.environ.get('DIST_CHECKPOINT_DIR', 'checkpoint')
+        try:
+            if isinstance(TrainerTools().parallel, DsParallel):
+                checkpoint_dir = os.environ.get('DIST_CHECKPOINT_DIR', 'checkpoint')
 
-            if checkpoint_dir.endswith('/'):
-                best_checkpoint_dir = f'{checkpoint_dir[:-1]}_best'
-            else:
-                best_checkpoint_dir = f'{checkpoint_dir}_best'
+                if checkpoint_dir.endswith('/'):
+                    best_checkpoint_dir = f'{checkpoint_dir[:-1]}_best'
+                else:
+                    best_checkpoint_dir = f'{checkpoint_dir}_best'
 
-            if not os.path.exists(best_checkpoint_dir):
-                os.makedirs(best_checkpoint_dir)
+                if not os.path.exists(best_checkpoint_dir):
+                    os.makedirs(best_checkpoint_dir)
 
-            if os.path.exists(checkpoint_dir):
-                shutil.rmtree(best_checkpoint_dir)
-                shutil.copytree(checkpoint_dir, best_checkpoint_dir)
-        else:
-            checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
-            best_checkpoint_name = f'{checkpoint_name}_best'
+                if os.path.exists(checkpoint_dir):
+                    shutil.rmtree(best_checkpoint_dir)
+                    shutil.copytree(checkpoint_dir, best_checkpoint_dir)
+            else:
+                checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
+                best_checkpoint_name = f'{checkpoint_name}_best'
 
-            if os.path.exists(checkpoint_name):
-                if os.path.exists(best_checkpoint_name):
-                    os.remove(best_checkpoint_name)
+                if os.path.exists(checkpoint_name):
+                    if os.path.exists(best_checkpoint_name):
+                        os.remove(best_checkpoint_name)
 
-                shutil.copy2(checkpoint_name, best_checkpoint_name)
+                    shutil.copy2(checkpoint_name, best_checkpoint_name)
+        except:
+            pass
 
     TrainerTools().parallel.wait()
     return need_replace
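
What changed: the whole best-checkpoint copy is now wrapped in try/except, so a filesystem error on the main process (a checkpoint still being written, a permissions problem) no longer kills the training run. The bare `except: pass` does, however, swallow the failure silently, so a best checkpoint can quietly fail to update. Below is a minimal standalone sketch of the same guard, with a hypothetical `copy_best_checkpoint` helper that logs the error instead of silencing it:

# Minimal sketch of the guarded best-checkpoint copy introduced above.
# `copy_best_checkpoint` and `logger` are hypothetical; the package itself
# uses a bare `except: pass`, which drops the error on the floor.
import logging
import os
import shutil

logger = logging.getLogger(__name__)

def copy_best_checkpoint(checkpoint_path: str) -> None:
    best_path = f'{checkpoint_path.rstrip("/")}_best'
    try:
        if os.path.isdir(checkpoint_path):
            # copytree refuses to overwrite, so clear the old best dir first
            shutil.rmtree(best_path, ignore_errors=True)
            shutil.copytree(checkpoint_path, best_path)
        elif os.path.exists(checkpoint_path):
            shutil.copy2(checkpoint_path, best_path)
    except OSError as e:
        # 0.5.6 catches everything; logging the failure is friendlier
        logger.warning('best-checkpoint copy failed: %s', e)

Catching OSError rather than everything avoids hiding unrelated bugs and keeps KeyboardInterrupt working; a bare `except:` intercepts those too.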
llm_trainer/dpo_trainer.py CHANGED
@@ -137,9 +137,10 @@ class DPOTrainer(Trainer):
         # number of gradient accumulation steps
         gradient_accumulation_steps = self.train_config.gradient_accumulation_steps
         global_steps = 0
-        loss_accumulation = 0.0
        skipping_train = False
 
+        loss_accumulation = 0.0
+        batches_accumulated = 0
         current_loss: float = 0.0
         last_best_checkpoint_loss: Optional[float] = None
 
@@ -214,14 +215,15 @@ class DPOTrainer(Trainer):
 
                     loss_accumulation += loss.detach().item()
                     self._backward_loss(loss)
+                    batches_accumulated += 1
 
                     if need_update_grad:
-                        loss_tensor = torch.tensor(loss_accumulation, device=TrainerTools().parallel.device)
+                        loss_tensor = torch.tensor(loss_accumulation * gradient_accumulation_steps / batches_accumulated, device=TrainerTools().parallel.device)
 
                         if TrainerTools().parallel.parallel_train:
                             dist.all_reduce(loss_tensor, dist.ReduceOp.AVG)
 
-                        final_log_loss = loss_tensor.item()
+                        current_loss = loss_tensor.item()
 
                         # gradient_clipping is already built into ds (DeepSpeed) mode
                         if not isinstance(TrainerTools().parallel, DsParallel) and self.lr_scheduler.can_clip_grad():
@@ -235,10 +237,11 @@ class DPOTrainer(Trainer):
                             epoch_tag=f'epoch: {epoch}',
                             file_tag=f'file: {file_idx + 1}/{file_count}',
                             batch_tag=f'batch: {batch}/{batch_count_per_file}',
-                            loss=final_log_loss
+                            loss=current_loss
                         )
                         # reset to default
                         loss_accumulation = 0.0
+                        batches_accumulated = 0
                 except Exception as e:
                     self._on_exception(e, epoch, batch)
                 finally:
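
The logging fix above: assuming each micro-batch loss is pre-scaled by 1/gradient_accumulation_steps before `_backward_loss` (the usual accumulation idiom, and consistent with the `+=` line), the raw sum that 0.5.4 logged as `final_log_loss` understates the loss whenever a window holds fewer than `gradient_accumulation_steps` micro-batches, e.g. at the end of a data file. Rescaling by `gradient_accumulation_steps / batches_accumulated` recovers the true per-micro-batch mean. A small numeric check under that assumption, with made-up losses:

# Why the rescaling works: a hypothetical accumulation window of
# k = batches_accumulated micro-batches, each loss already divided by
# G = gradient_accumulation_steps before backward().
G = 8                          # gradient_accumulation_steps
raw_losses = [2.0, 2.4, 2.2]   # unscaled per-micro-batch losses, k = 3
                               # (a short window, e.g. end of a data file)

loss_accumulation = sum(l / G for l in raw_losses)   # what the loop sums

old_logged = loss_accumulation                        # 0.825 (0.5.4: too small)
new_logged = loss_accumulation * G / len(raw_losses)  # 2.2   (0.5.6: true mean)

assert abs(new_logged - sum(raw_losses) / len(raw_losses)) < 1e-9

The rename from `final_log_loss` to `current_loss` also suggests the reduced value now feeds both the progress log and the `save_best_checkpoint` comparison, though the comparison site is outside this hunk.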
llm_trainer/grpo_trainer.py CHANGED
@@ -345,6 +345,8 @@ class GRPOTrainer(Trainer):
                     if TrainerTools().parallel.parallel_train:
                         dist.all_reduce(loss, dist.ReduceOp.AVG)
 
+                    current_loss = loss.detach().item()
+
                     # gradient_clipping is already built into ds (DeepSpeed) mode
                     if not isinstance(TrainerTools().parallel, DsParallel) and self.lr_scheduler.can_clip_grad():
                         # clip grad
@@ -357,7 +359,7 @@ class GRPOTrainer(Trainer):
                             epoch_tag=f'epoch: {epoch}',
                             file_tag=f'file: {file_idx + 1}/{file_count}',
                             batch_tag=f'batch: {batch}/{batch_count_per_file}',
-                            loss=loss.detach().item()
+                            loss=current_loss
                         )
                 except Exception as e:
                     self._on_exception(e, epoch, batch)
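
In GRPOTrainer the reduced loss is now converted to a Python float once, right after the all_reduce, and the cached `current_loss` is reused at log time instead of a second `loss.detach().item()`. Each `.item()` forces a GPU-to-CPU synchronization, and caching the float also makes it available for the best-checkpoint comparison. A sketch of the pattern, with hypothetical `log_fn` and `parallel_train` stand-ins:

# Snapshot-once pattern: .item() forces a GPU->CPU sync, so call it a
# single time after the (optional) all_reduce and reuse the float.
import torch
import torch.distributed as dist

def reduce_and_log(loss: torch.Tensor, parallel_train: bool, log_fn) -> float:
    if parallel_train and dist.is_initialized():
        # average across ranks so every process logs the same value
        dist.all_reduce(loss, op=dist.ReduceOp.AVG)
    current_loss = loss.detach().item()  # one sync, reused below
    log_fn(loss=current_loss)
    return current_loss                  # e.g. for best-checkpoint comparison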
llm_trainer/trainer.py CHANGED
@@ -465,9 +465,10 @@ class Trainer:
         # number of gradient accumulation steps
         gradient_accumulation_steps = self.train_config.gradient_accumulation_steps
         global_steps = 0
-        loss_accumulation = 0.0
         skipping_train = False
 
+        loss_accumulation = 0.0
+        batches_accumulated = 0
         current_loss: float = 0.0
         last_best_checkpoint_loss: Optional[float] = None
 
@@ -536,9 +537,10 @@ class Trainer:
 
                     loss_accumulation += loss.detach().item()
                     self._backward_loss(loss)
+                    batches_accumulated += 1
 
                     if need_update_grad:
-                        loss_tensor = torch.tensor(loss_accumulation, device=TrainerTools().parallel.device)
+                        loss_tensor = torch.tensor(loss_accumulation * gradient_accumulation_steps / batches_accumulated, device=TrainerTools().parallel.device)
 
                         if TrainerTools().parallel.parallel_train:
                             dist.all_reduce(loss_tensor, dist.ReduceOp.AVG)
@@ -561,6 +563,7 @@ class Trainer:
                         )
                         # reset to default
                         loss_accumulation = 0.0
+                        batches_accumulated = 0
                 except Exception as e:
                     self._on_exception(e, epoch, batch)
                 finally:
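
Trainer gets the same paired counters as DPOTrainer, and both must now be reset together after every optimizer step. A hypothetical refactor (not in the package) that keeps the sum and the window length in lockstep:

# The same bookkeeping as the paired counters above, folded into a tiny
# helper; a hypothetical refactor, not part of project-llm-trainer.
class LossAccumulator:
    """Tracks the sum of pre-scaled micro-batch losses plus the window length."""

    def __init__(self, gradient_accumulation_steps: int) -> None:
        self.grad_accum_steps = gradient_accumulation_steps
        self.total = 0.0
        self.count = 0

    def add(self, scaled_loss: float) -> None:
        self.total += scaled_loss
        self.count += 1

    def mean(self) -> float:
        # recover the mean unscaled loss, robust to short final windows
        return self.total * self.grad_accum_steps / self.count

    def reset(self) -> None:
        self.total = 0.0
        self.count = 0

Usage would mirror the diff: `add(loss.detach().item())` per micro-batch, `mean()` when `need_update_grad` fires, then `reset()`, so the two fields can never drift apart.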
project_llm_trainer-0.5.6.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: project_llm_trainer
-Version: 0.5.4
+Version: 0.5.6
 Summary: LLM and VLM trainer
 Author: qibin
 Author-email: qibin0506@gmail.com
project_llm_trainer-0.5.6.dist-info/RECORD CHANGED
@@ -1,11 +1,11 @@
 llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
-llm_trainer/checkpoint.py,sha256=UVjOaDsiSIzRJ5VJZib6iXrdKv2A7K_gtJw3a9wNyoM,4293
+llm_trainer/checkpoint.py,sha256=GHRnPpvG0lz8mg2qv0itHr1rXLlj-itOqZWCtOV1IRU,4411
 llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
-llm_trainer/dpo_trainer.py,sha256=3hCjX06W6nt-8lio0YyGzXdI_saa5QlypXehQTtacO4,11871
+llm_trainer/dpo_trainer.py,sha256=1A_4QP2_xqM_YeqdXy-0RaMvEL80gim-pgnPQyHww9U,12052
 llm_trainer/ds_checkpoint.py,sha256=D092fkS1Up4QmpV9YCpqbSzfX_caCAeX-UiOrhOE1I8,1947
 llm_trainer/eval.py,sha256=fjASCILU3fSPJxo9cP3rIXEEnkc5ZlUyHqXlZtUiHrw,888
 llm_trainer/generate_utils.py,sha256=CbJ3mfAD6DkQ0GUHcJQ1AK02m-ocwmd-BPXEpiwvNNQ,14933
-llm_trainer/grpo_trainer.py,sha256=VltU9AngecJ5PEtpn_ToXfiLIc6VwwgOVwlLh-X2Je8,16001
+llm_trainer/grpo_trainer.py,sha256=sCYjvksdm9f7TpN23KXuCmua_8VFTZEfVEcflL89P_I,16058
 llm_trainer/log.py,sha256=LxqTGRNZUGMTSQCePRpk-rYyxSnSIbT4kOdP8Fbzr0M,462
 llm_trainer/loss.py,sha256=NZCQeUXnLSj__mmDflE8g89KgE0emAJXIab0IERCLno,6023
 llm_trainer/parallel.py,sha256=DQu8GqEFxD99HQ6hKuIxxyKi-05dMO33eMhImYlPuOI,4468
@@ -18,16 +18,16 @@ llm_trainer/sft_trainer.py,sha256=gxQA7T1o1QGUsHp2CX1Qb_fO5LppBJuNbc0H4ixCYUA,17
 llm_trainer/tokenizer.py,sha256=SSpgXtb0e1NtQqRW0gCq09TTZi47umggy-Fh5EMHKJg,6708
 llm_trainer/tools.py,sha256=yF17lp6oOfLe2XJeKDQ1juZcbv-6vFamJSLwEeArduA,2975
 llm_trainer/train_configs.py,sha256=c6bgivkkWRYcPD3NzI5uRItAUhZiIBgKVMuMgVFRnFo,7336
-llm_trainer/trainer.py,sha256=VHnuL8rZgxj2ewBVxmbN0jwVEPRVhbK2V2DYCtm_FxI,25886
+llm_trainer/trainer.py,sha256=YW59dJWTyQy77cLDGzBHhfinGyfkvmWCkl1SR9hM6a8,26071
 llm_trainer/utils.py,sha256=LWNhyQ0NDEZ9mZtk2Ryvh6EulvHIaUGIflugSpqmeFI,6791
-project_llm_trainer-0.5.4.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
-project_llm_trainer-0.5.4.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
-project_llm_trainer-0.5.4.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
-project_llm_trainer-0.5.4.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
-project_llm_trainer-0.5.4.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
-project_llm_trainer-0.5.4.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
-project_llm_trainer-0.5.4.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
-project_llm_trainer-0.5.4.dist-info/METADATA,sha256=lZWPvJQFiqZTl9b1FUSFl1Fl6berO7DKGZ-A_7ZkidE,195
-project_llm_trainer-0.5.4.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
-project_llm_trainer-0.5.4.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
-project_llm_trainer-0.5.4.dist-info/RECORD,,
+project_llm_trainer-0.5.6.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
+project_llm_trainer-0.5.6.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
+project_llm_trainer-0.5.6.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
+project_llm_trainer-0.5.6.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
+project_llm_trainer-0.5.6.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
+project_llm_trainer-0.5.6.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
+project_llm_trainer-0.5.6.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
+project_llm_trainer-0.5.6.dist-info/METADATA,sha256=5JaiIS6GqDsm9o_6Zz-vkec_0rmENISg32Qld1Gn-u8,195
+project_llm_trainer-0.5.6.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
+project_llm_trainer-0.5.6.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
+project_llm_trainer-0.5.6.dist-info/RECORD,,