project-llm-trainer 0.5.16-py3-none-any.whl → 0.5.17-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

llm_trainer/dpo_trainer.py CHANGED
@@ -12,7 +12,10 @@ from .dataset import DPODataset
 from .loss import DPOLoss
 from .tools import TrainerTools
 from .utils import get_dpo_collate_fn
-from .partition_utils import sync_model_params
+from .partition_utils import (
+    sync_model_params,
+    unwrap_model_for_generation
+)
 
 from .checkpoint import (
     save_checkpoint,
@@ -35,28 +38,28 @@ class DPOTrainer(Trainer):
             eval_image_tags=eval_image_tags
         )
 
-        self.reference_model = self._init_reference_model()
+        self.ref_model = self._init_ref_model()
 
-    def _init_reference_model(self):
-        reference_model = self._new_model(self.train_config)
+    def _init_ref_model(self):
+        ref_model = self._new_model(self.train_config)
 
-        reference_model, _ = TrainerTools().parallel.process(
-            model=reference_model,
+        ref_model, _ = TrainerTools().parallel.process(
+            model=ref_model,
             optimizer=None,
-            kwargs=self._init_reference_args(),
+            kwargs=self._init_ref_model_args(),
             save_instance=False
         )
 
-        reference_model.eval()
-        for param in reference_model.parameters():
+        ref_model.eval()
+        for param in ref_model.parameters():
             param.requires_grad = False
 
         sync_model_params(
             _from=self.train_model,
-            _to=reference_model
+            _to=ref_model
         )
 
-        return reference_model
+        return ref_model
 
     def _init_loss(self):
         criterion = DPOLoss(
@@ -203,17 +206,18 @@ class DPOTrainer(Trainer):
 
         with self.ctx:
             policy_outputs = self.train_model(concat_inputs, attention_mask=concat_mask)
-            with torch.inference_mode():
-                ref_outputs = self.reference_model(concat_inputs, attention_mask=concat_mask)
-
             policy_probs = self._logprobs(policy_outputs['logits'], concat_labels, concat_mask)
-            ref_probs = self._logprobs(ref_outputs['logits'], concat_labels, concat_mask)
+            aux_loss = policy_outputs.get('aux_loss')
+
+            with torch.no_grad():
+                ref_outputs = self.ref_model(concat_inputs, attention_mask=concat_mask)
+                ref_probs = self._logprobs(ref_outputs['logits'], concat_labels, concat_mask)
 
             # calc loss
             loss = self.criterion(policy_probs, ref_probs)
 
-            if aux_loss_coef and policy_outputs['aux_loss']:
-                loss += aux_loss_coef * policy_outputs['aux_loss']
+            if aux_loss_coef and aux_loss:
+                loss += aux_loss_coef * aux_loss
 
             if gradient_accumulation_steps > 1:
                 loss = loss / gradient_accumulation_steps
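The reference forward pass above moves from torch.inference_mode() to torch.no_grad(): tensors created under inference mode cannot later be mixed into an autograd graph, while no-grad tensors can be consumed as constants by the loss call that follows. For context, below is a minimal sketch of how a DPO-style loss typically combines the two sets of sequence log-probabilities; it is illustrative only (not the package's DPOLoss), and the beta value and the chosen/rejected batch layout are assumptions.

import torch
import torch.nn.functional as F

def dpo_loss_sketch(policy_logps: torch.Tensor, ref_logps: torch.Tensor, beta: float = 0.1) -> torch.Tensor:
    # Assumes the batch is the concatenation [chosen; rejected] along dim 0.
    half = policy_logps.shape[0] // 2
    policy_chosen, policy_rejected = policy_logps[:half], policy_logps[half:]
    ref_chosen, ref_rejected = ref_logps[:half], ref_logps[half:]

    # Implicit reward: scaled log-ratio of the policy against the frozen reference model.
    chosen_rewards = beta * (policy_chosen - ref_chosen)
    rejected_rewards = beta * (policy_rejected - ref_rejected)

    # Standard DPO objective: push the chosen margin above the rejected margin.
    return -F.logsigmoid(chosen_rewards - rejected_rewards).mean()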
llm_trainer/grpo_trainer.py CHANGED
@@ -42,27 +42,27 @@ class GRPOTrainer(Trainer):
         )
 
         self.reward_func = reward_func
-        self.reference_model = self._init_reference_model()
+        self.ref_model = self._init_ref_model()
 
         # By default, use the pad_sequence provided by torch.
         # If pad_sequence does not support the padding_side parameter, set this flag to False and fall back to the reversal approach.
         self._use_origin_pad_sequence = True
 
-    def _init_reference_model(self):
-        reference_model = self._new_model(self.train_config)
+    def _init_ref_model(self):
+        ref_model = self._new_model(self.train_config)
 
-        reference_model, _ = TrainerTools().parallel.process(
-            model=reference_model,
+        ref_model, _ = TrainerTools().parallel.process(
+            model=ref_model,
             optimizer=None,
-            kwargs=self._init_reference_args(),
+            kwargs=self._init_ref_model_args(),
             save_instance=False
         )
 
-        reference_model.eval()
-        for param in reference_model.parameters():
+        ref_model.eval()
+        for param in ref_model.parameters():
             param.requires_grad = False
 
-        return reference_model
+        return ref_model
 
     def _init_loss(self):
         criterion = GRPOLoss(
@@ -225,7 +225,7 @@ class GRPOTrainer(Trainer):
         old_log_probs, _ = self._compute_log_probabilities(generate_model, input_ids, attention_mask, logits_to_keep)
 
         # Compute ref_log_probs from the reference model, which remains static.
-        ref_log_probs, _ = self._compute_log_probabilities(self.reference_model, input_ids, attention_mask, logits_to_keep)
+        ref_log_probs, _ = self._compute_log_probabilities(self.ref_model, input_ids, attention_mask, logits_to_keep)
 
         repeated_prompts = [p for p in prompts for _ in range(group_size)]
         repeated_answers = [a for a in answers for _ in range(group_size)]
@@ -290,7 +290,7 @@ class GRPOTrainer(Trainer):
         for epoch in range(self.train_config.n_epochs):
             sync_model_params(
                 _from=self.train_model,
-                _to=self.reference_model,
+                _to=self.ref_model,
                 mixup_alpha=self.train_config.grpo_config.mixup_alpha
             )
 
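sync_model_params is called here with a mixup_alpha, which suggests the reference weights are blended toward the current policy each epoch rather than overwritten outright. The helper below is a hypothetical sketch of that kind of soft update; it is not the package's partition_utils.sync_model_params, and treating mixup_alpha as the weight kept on the old reference is an assumption.

import torch

@torch.no_grad()
def soft_sync_params(_from: torch.nn.Module, _to: torch.nn.Module, mixup_alpha=None) -> None:
    # Copy (or blend) parameters from the trained policy into the frozen reference model.
    for src, dst in zip(_from.parameters(), _to.parameters()):
        if mixup_alpha is None:
            dst.copy_(src)  # hard sync: the reference becomes an exact copy of the policy
        else:
            # EMA-style blend: keep mixup_alpha of the old reference, mix in the rest from the policy.
            dst.mul_(mixup_alpha).add_(src, alpha=1.0 - mixup_alpha)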
llm_trainer/parallel.py CHANGED
@@ -140,7 +140,7 @@ class Parallel(ABC):
         return 1
 
     def wait(self, msg=None):
-        msg = f' for {msg}' if msg else None
+        msg = f' for {msg}' if msg else ''
         log(f'wait at {self.device}{msg}')
         dist.barrier()
         log(f'continue at {self.device}{msg}')
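The wait() change above swaps the None fallback for an empty string: interpolating None into an f-string renders the literal text "None", so the old code logged lines like "wait at cuda:0None" whenever no message was passed. A quick standalone demonstration of the difference (the device string is just an example):

device = 'cuda:0'

msg = None
print(f'wait at {device}{msg}')   # -> wait at cuda:0None  (the old behaviour)

msg = ''
print(f'wait at {device}{msg}')   # -> wait at cuda:0      (the fixed behaviour)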
llm_trainer/trainer.py CHANGED
@@ -1,6 +1,6 @@
-import time
 from contextlib import nullcontext
 from typing import Optional, Tuple, List, Dict, Any
+import copy
 
 import torch
 import torch.distributed as dist
@@ -65,6 +65,7 @@ class Trainer:
             assert len(self.eval_prompts) == len(self.eval_image_tags)
 
         parallel_kwargs, data_loader_kwargs, sampler_kwargs, use_ds_optim = self._convert_train_args()
+        self.parallel_kwargs = parallel_kwargs
         self.data_loader_kwargs: dict[str, Any] = data_loader_kwargs
         self.sampler_kwargs: dict[str, Any] = sampler_kwargs
 
@@ -323,8 +324,8 @@ class Trainer:
 
         return parallel_kwargs, data_loader_kwargs, sampler_kwargs, use_ds_optim
 
-    def _init_reference_args(self) -> dict:
-        parallel_kwargs, _, _, _ = self._convert_train_args()
+    def _init_ref_model_args(self) -> dict:
+        parallel_kwargs = copy.deepcopy(self.parallel_kwargs)
 
         if parallel_kwargs and isinstance(TrainerTools().parallel, DsParallel):
             # reference to https://github.com/huggingface/trl/blob/main/trl/models/utils.py:prepare_deepspeed
@@ -346,8 +347,13 @@ class Trainer:
             # }
             # )
 
-            if parallel_kwargs['zero_optimization']['stage'] != 3:
-                parallel_kwargs['zero_optimization']['stage'] = 0
+            parallel_kwargs.pop('activation_checkpointing', None)
+            parallel_kwargs.pop('gradient_clipping', None)
+
+            # For now the ref_model uses ZeRO stage 0 to work around a training hang.
+            parallel_kwargs["zero_optimization"] = {"stage": 0}
+            # if parallel_kwargs.get("zero_optimization", {}).get("stage", 0) != 3:
+            #     parallel_kwargs["zero_optimization"] = {"stage": 0}
 
         return parallel_kwargs
 
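_init_ref_model_args now deep-copies the kwargs cached at construction time instead of rebuilding them, strips optimizer-oriented entries, and pins the reference model's DeepSpeed config to ZeRO stage 0 (the previous stage-preserving branch is left commented out). The snippet below sketches that transformation on a typical DeepSpeed-style config dict; the specific keys and values are illustrative assumptions, not the package's actual training config.

import copy

train_kwargs = {
    "train_micro_batch_size_per_gpu": 4,
    "gradient_clipping": 1.0,
    "activation_checkpointing": {"partition_activations": True},
    "zero_optimization": {"stage": 2, "offload_optimizer": {"device": "cpu"}},
}

ref_kwargs = copy.deepcopy(train_kwargs)          # never mutate the shared training config
ref_kwargs.pop("activation_checkpointing", None)  # inference-only model: no recomputation needed
ref_kwargs.pop("gradient_clipping", None)         # no optimizer step runs on the reference model
ref_kwargs["zero_optimization"] = {"stage": 0}    # stage 0 to avoid the reported training hang

print(ref_kwargs["zero_optimization"])  # {'stage': 0}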
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: project_llm_trainer
-Version: 0.5.16
+Version: 0.5.17
 Summary: LLM and VLM trainer
 Author: qibin
 Author-email: qibin0506@gmail.com
@@ -1,14 +1,14 @@
 llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
 llm_trainer/checkpoint.py,sha256=gz31pZbbQvRTYrBhxV-MFaBAIFeqpe7rM6nFsjwT9lY,4328
 llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
-llm_trainer/dpo_trainer.py,sha256=--ItH-rkkq24Da3M_Kf0VxpQ3t-k0fpZrzFGqkYsjks,12304
+llm_trainer/dpo_trainer.py,sha256=pNJaXvk-g0lGkZoRhbODNH34hTiz8EdP4Z12ws4W0t8,12309
 llm_trainer/ds_checkpoint.py,sha256=Wzy7PvVVWR794-BW4uragWFTAkkgDvjvkF-qMdyB4fc,2141
 llm_trainer/eval.py,sha256=ZyUfSo2Q8P-lrCdPEnGkoo5pGubd0AabREK5eMISRII,1109
 llm_trainer/generate_utils.py,sha256=wrZoG2g7CsOyG4sb3px9vURHQFV6_9j5kQmpFc5A8yg,15335
-llm_trainer/grpo_trainer.py,sha256=g_ivzQop2SkvhlKAEWb0zUnIvNuHTfsOoIG6y29oTCw,16106
+llm_trainer/grpo_trainer.py,sha256=tuzcSi1uBzUPVKojEheJ3-Tx8-g99mf6LYYxC5nsNiw,16040
 llm_trainer/log.py,sha256=LxqTGRNZUGMTSQCePRpk-rYyxSnSIbT4kOdP8Fbzr0M,462
 llm_trainer/loss.py,sha256=NZCQeUXnLSj__mmDflE8g89KgE0emAJXIab0IERCLno,6023
-llm_trainer/parallel.py,sha256=j1L4n-JmDkDZblURrNKpEAWEqqGIAXAN9PT_fSS_OnE,4492
+llm_trainer/parallel.py,sha256=G9X0FddIJwd9j-5XOknB4AlBe4G2W6fUCaQH6ycC2Fo,4490
 llm_trainer/parallel_ddp.py,sha256=Pob9vUlBZnkL4oP1Re11kFob7nufMSE96pn7m7fuOEM,1345
 llm_trainer/parallel_ds.py,sha256=oy8RRxHud3rACWubFlJqqd0pjPEQhKeAPGPQUSdJX2c,1145
 llm_trainer/parallel_none.py,sha256=TG6Pm829Dg-yQu-97O-EHV3FCARBlNcP47KkGFAs16E,676
@@ -18,16 +18,16 @@ llm_trainer/sft_trainer.py,sha256=gxQA7T1o1QGUsHp2CX1Qb_fO5LppBJuNbc0H4ixCYUA,17
 llm_trainer/tokenizer.py,sha256=SSpgXtb0e1NtQqRW0gCq09TTZi47umggy-Fh5EMHKJg,6708
 llm_trainer/tools.py,sha256=yF17lp6oOfLe2XJeKDQ1juZcbv-6vFamJSLwEeArduA,2975
 llm_trainer/train_configs.py,sha256=992wy0YhBG2WvxwdLEPL4_-JUl4NkwMPT-jj_BIHo6A,7347
-llm_trainer/trainer.py,sha256=YqWhD9jXbrUdm3KEjEHLyg_qHiXCy5R7PK-arCXxJ6M,26399
+llm_trainer/trainer.py,sha256=Q821nlLDKRZVpaRoiZ7DiJplpAJRRLtvR_33FbClGA0,26729
 llm_trainer/utils.py,sha256=LWNhyQ0NDEZ9mZtk2Ryvh6EulvHIaUGIflugSpqmeFI,6791
-project_llm_trainer-0.5.16.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
-project_llm_trainer-0.5.16.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
-project_llm_trainer-0.5.16.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
-project_llm_trainer-0.5.16.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
-project_llm_trainer-0.5.16.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
-project_llm_trainer-0.5.16.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
-project_llm_trainer-0.5.16.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
-project_llm_trainer-0.5.16.dist-info/METADATA,sha256=h0TMNrZMUU875tVasbuqt69EuPPMbo_nv6tHQLKeNbQ,196
-project_llm_trainer-0.5.16.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
-project_llm_trainer-0.5.16.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
-project_llm_trainer-0.5.16.dist-info/RECORD,,
+project_llm_trainer-0.5.17.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
+project_llm_trainer-0.5.17.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
+project_llm_trainer-0.5.17.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
+project_llm_trainer-0.5.17.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
+project_llm_trainer-0.5.17.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
+project_llm_trainer-0.5.17.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
+project_llm_trainer-0.5.17.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
+project_llm_trainer-0.5.17.dist-info/METADATA,sha256=BVzwe45PQXSE-f5-BCZulqWCK3PIpKzxv9z__moTEJY,196
+project_llm_trainer-0.5.17.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
+project_llm_trainer-0.5.17.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
+project_llm_trainer-0.5.17.dist-info/RECORD,,