project-llm-trainer: project_llm_trainer-0.7.4-py3-none-any.whl → project_llm_trainer-0.7.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of project-llm-trainer might be problematic. (The original page linked to more details here.)

@@ -50,6 +50,10 @@ class GRPOTrainer(Trainer):
50
50
  self._use_origin_pad_sequence = True
51
51
 
52
52
  def _init_ref_model(self):
53
+ # beta == 0,不需要ref_model
54
+ if self.train_config.grpo_config.loss_beta == 0.0:
55
+ return None
56
+
53
57
  ref_model = self._new_model(self.train_config)
54
58
 
55
59
  ref_model, _ = TrainerTools().parallel.process(
@@ -230,8 +234,11 @@ class GRPOTrainer(Trainer):
230
234
  # Compute old_log_probs from the current model, with gradients disabled.
231
235
  old_log_probs, _ = self._compute_log_probabilities(generate_model, input_ids, attention_mask, logits_to_keep)
232
236
 
233
- # Compute ref_log_probs from the reference model, which remains static.
234
- ref_log_probs, _ = self._compute_log_probabilities(self.ref_model, input_ids, attention_mask, logits_to_keep)
237
+ if self.ref_model:
238
+ # Compute ref_log_probs from the reference model, which remains static.
239
+ ref_log_probs, _ = self._compute_log_probabilities(self.ref_model, input_ids, attention_mask, logits_to_keep)
240
+ else:
241
+ ref_log_probs = None
235
242
 
236
243
  repeated_prompts = [p for p in prompts for _ in range(group_size)]
237
244
  repeated_answers = [a for a in answers for _ in range(group_size)]
@@ -294,11 +301,12 @@ class GRPOTrainer(Trainer):
294
301
  aux_loss_coef = self.train_config.loss_config.aux_loss_coef
295
302
 
296
303
  for epoch in range(self.train_config.n_epochs):
297
- sync_model_params(
298
- _from=self.train_model,
299
- _to=self.ref_model,
300
- mixup_alpha=self.train_config.grpo_config.mixup_alpha
301
- )
304
+ if self.ref_model:
305
+ sync_model_params(
306
+ _from=self.train_model,
307
+ _to=self.ref_model,
308
+ mixup_alpha=self.train_config.grpo_config.mixup_alpha
309
+ )
302
310
 
303
311
  file_count = len(self.train_config.file_dataset)
304
312
 
@@ -366,7 +374,7 @@ class GRPOTrainer(Trainer):
366
374
  self._log_loss(
367
375
  epoch_tag=f'epoch: {epoch}',
368
376
  file_tag=f'file: {file_idx + 1}/{file_count}',
369
- batch_tag=f'batch: {batch}/{batch_count_per_file}',
377
+ batch_tag=f'batch: {batch}/{batch_count_per_file}, grpo_step={grpo_step}',
370
378
  loss=current_loss
371
379
  )
372
380
  except Exception as e:
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: project_llm_trainer
3
- Version: 0.7.4
3
+ Version: 0.7.5
4
4
  Summary: LLM and VLM trainer
5
5
  Author: qibin
6
6
  Author-email: qibin0506@gmail.com
@@ -5,7 +5,7 @@ llm_trainer/dpo_trainer.py,sha256=RMfbTsl3eav4yTJ2PK59mi6a0ECVOg8WwYVsHvMbNUE,12
5
5
  llm_trainer/ds_checkpoint.py,sha256=X2IWgpgi0yOtogph7n6DEwvK_0Ceb7juu1WMutv3HSk,2270
6
6
  llm_trainer/eval.py,sha256=ZyUfSo2Q8P-lrCdPEnGkoo5pGubd0AabREK5eMISRII,1109
7
7
  llm_trainer/generate_utils.py,sha256=8K3YFbp7IF_lCkmkzjHhqTW26EBFb2AilQmarVcfMvs,15001
8
- llm_trainer/grpo_trainer.py,sha256=2mMuRa7UXAgPSgav4Wp9-cs0QOPWQghv2IrW515Gn2Q,16515
8
+ llm_trainer/grpo_trainer.py,sha256=MXnP8Kc9CQJw0CB3uMbHxIYwvpuujai4hgbbpUut_K4,16808
9
9
  llm_trainer/log.py,sha256=XwychwKF6gvFPhthCIZCAEUZ0G3DY3fiQrOHqPWsxz0,463
10
10
  llm_trainer/loss.py,sha256=glf4IeDWHvA2cJo-QKLRL8P6OxK4QjRJGrYJWOZiTPQ,6929
11
11
  llm_trainer/parallel.py,sha256=yjStV21DJ26yM8-0O6GTMxdFAcyShY5GsQWSZmbI7HU,4543
@@ -20,14 +20,14 @@ llm_trainer/tools.py,sha256=5op5qrjjkK-Lr9oes5VxIVnOVYOYGoAdlIJq9mPUf64,2637
20
20
  llm_trainer/train_configs.py,sha256=N3ykM1uaLHcSNRC8ErYIxp9VYhSP7voJyAP-2D4ZJe0,7574
21
21
  llm_trainer/trainer.py,sha256=jS31zEXIIj9BoPTPlmaGYq61x72HGCjKfS2u3_gOkDk,27924
22
22
  llm_trainer/utils.py,sha256=xcdzpvPvXRKqsOK2yB7PZ9GmOvZMDFcglDPUZY2hJTY,11484
23
- project_llm_trainer-0.7.4.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
24
- project_llm_trainer-0.7.4.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
25
- project_llm_trainer-0.7.4.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
26
- project_llm_trainer-0.7.4.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
27
- project_llm_trainer-0.7.4.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
28
- project_llm_trainer-0.7.4.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
29
- project_llm_trainer-0.7.4.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
30
- project_llm_trainer-0.7.4.dist-info/METADATA,sha256=Uq-6PqNSqhtsjYPmwAuMW07Y4SLq_caecSoCxYSbJHw,195
31
- project_llm_trainer-0.7.4.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
32
- project_llm_trainer-0.7.4.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
33
- project_llm_trainer-0.7.4.dist-info/RECORD,,
23
+ project_llm_trainer-0.7.5.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
24
+ project_llm_trainer-0.7.5.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
25
+ project_llm_trainer-0.7.5.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
26
+ project_llm_trainer-0.7.5.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
27
+ project_llm_trainer-0.7.5.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
28
+ project_llm_trainer-0.7.5.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
29
+ project_llm_trainer-0.7.5.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
30
+ project_llm_trainer-0.7.5.dist-info/METADATA,sha256=9DcoFVuXDrhxZOVWF1Ouzk7NF6NTEnpBTkg1n6bMCYQ,195
31
+ project_llm_trainer-0.7.5.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
32
+ project_llm_trainer-0.7.5.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
33
+ project_llm_trainer-0.7.5.dist-info/RECORD,,