project-llm-trainer 0.5.15__py3-none-any.whl → 0.5.17__py3-none-any.whl

This diff shows the contents of publicly released package versions as they appear in their respective public registries; it is provided for informational purposes only.

Potentially problematic release: this version of project-llm-trainer might be problematic.

llm_trainer/checkpoint.py CHANGED
@@ -65,7 +65,7 @@ def save_best_checkpoint(
      except:
          pass
 
-     TrainerTools().parallel.wait()
+     TrainerTools().parallel.wait('save best checkpoint')
      return need_replace
 
 
llm_trainer/dpo_trainer.py CHANGED
@@ -12,7 +12,10 @@ from .dataset import DPODataset
  from .loss import DPOLoss
  from .tools import TrainerTools
  from .utils import get_dpo_collate_fn
- from .partition_utils import sync_model_params
+ from .partition_utils import (
+     sync_model_params,
+     unwrap_model_for_generation
+ )
 
  from .checkpoint import (
      save_checkpoint,
@@ -35,28 +38,28 @@ class DPOTrainer(Trainer):
              eval_image_tags=eval_image_tags
          )
 
-         self.reference_model = self._init_reference_model()
+         self.ref_model = self._init_ref_model()
 
-     def _init_reference_model(self):
-         reference_model = self._new_model(self.train_config)
+     def _init_ref_model(self):
+         ref_model = self._new_model(self.train_config)
 
-         reference_model, _ = TrainerTools().parallel.process(
-             model=reference_model,
+         ref_model, _ = TrainerTools().parallel.process(
+             model=ref_model,
              optimizer=None,
-             kwargs=self._init_reference_args(),
+             kwargs=self._init_ref_model_args(),
              save_instance=False
          )
 
-         reference_model.eval()
-         for param in reference_model.parameters():
+         ref_model.eval()
+         for param in ref_model.parameters():
              param.requires_grad = False
 
          sync_model_params(
              _from=self.train_model,
-             _to=reference_model
+             _to=ref_model
          )
 
-         return reference_model
+         return ref_model
 
      def _init_loss(self):
          criterion = DPOLoss(
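
The renamed _init_ref_model keeps the standard DPO setup for a reference policy: build a second copy of the model, switch it to eval mode, freeze every parameter, and copy the current policy weights into it. A minimal standalone sketch of that pattern (generic helper names, not the package's own API):

    import copy
    import torch.nn as nn

    def build_frozen_reference(policy: nn.Module) -> nn.Module:
        ref = copy.deepcopy(policy)      # independent copy of the policy weights
        ref.eval()                       # disable dropout etc. for stable reference scores
        for p in ref.parameters():
            p.requires_grad = False      # the reference never receives gradients
        return ref
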
@@ -170,14 +173,19 @@ class DPOTrainer(Trainer):
              skipping_train = True
              continue
 
-         skipping_train = False
-
          # whether gradients need to be updated on this step
-         if gradient_accumulation_steps > 1:
+         if skipping_train:
+             need_update_grad = False
+         elif gradient_accumulation_steps > 1:
              need_update_grad = (batch + 1) % gradient_accumulation_steps == 0 or batch == batch_count_per_file - 1
          else:
              need_update_grad = True
 
+         # must stay below the need_update_grad assignment; works around an unexplained hang when resuming training
+         if skipping_train:
+             TrainerTools().parallel.wait('skip train')
+             skipping_train = False
+
          try:
              chosen_inputs: torch.Tensor = batch_data['chosen_inputs'].to(TrainerTools().parallel.device)
              chosen_labels: torch.Tensor = batch_data['chosen_labels'].to(TrainerTools().parallel.device)
@@ -198,17 +206,18 @@ class DPOTrainer(Trainer):
 
          with self.ctx:
              policy_outputs = self.train_model(concat_inputs, attention_mask=concat_mask)
-             with torch.inference_mode():
-                 ref_outputs = self.reference_model(concat_inputs, attention_mask=concat_mask)
-
              policy_probs = self._logprobs(policy_outputs['logits'], concat_labels, concat_mask)
-             ref_probs = self._logprobs(ref_outputs['logits'], concat_labels, concat_mask)
+             aux_loss = policy_outputs.get('aux_loss')
+
+             with torch.no_grad():
+                 ref_outputs = self.ref_model(concat_inputs, attention_mask=concat_mask)
+                 ref_probs = self._logprobs(ref_outputs['logits'], concat_labels, concat_mask)
 
              # calc loss
              loss = self.criterion(policy_probs, ref_probs)
 
-             if aux_loss_coef and policy_outputs['aux_loss']:
-                 loss += aux_loss_coef * policy_outputs['aux_loss']
+             if aux_loss_coef and aux_loss:
+                 loss += aux_loss_coef * aux_loss
 
          if gradient_accumulation_steps > 1:
              loss = loss / gradient_accumulation_steps
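
The reordered forward pass now reads the auxiliary loss out of the policy output first, then scores the reference model under torch.no_grad instead of torch.inference_mode (no_grad tensors, unlike inference tensors, can be combined freely with graph-producing tensors). A hedged sketch of the same ordering, with generic stand-ins for the trainer's model and log-prob helpers:

    import torch

    def dpo_step(policy, ref_model, criterion, logprob_fn,
                 inputs, mask, labels, aux_loss_coef=0.0):
        policy_out = policy(inputs, attention_mask=mask)
        policy_probs = logprob_fn(policy_out['logits'], labels, mask)
        aux_loss = policy_out.get('aux_loss')        # e.g. an MoE balancing term; may be None

        with torch.no_grad():                        # reference is frozen, no graph needed
            ref_out = ref_model(inputs, attention_mask=mask)
            ref_probs = logprob_fn(ref_out['logits'], labels, mask)

        loss = criterion(policy_probs, ref_probs)
        if aux_loss_coef and aux_loss is not None:
            loss = loss + aux_loss_coef * aux_loss
        return loss
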
llm_trainer/ds_checkpoint.py CHANGED
@@ -37,7 +37,7 @@ def save_ds_checkpoint(model: nn.Module):
          shutil.rmtree(oldest_ckpt)
      except: ...
 
-     TrainerTools().parallel.wait()
+     TrainerTools().parallel.wait('remove old ds checkpoint')
 
 
  def load_ds_checkpoint(
llm_trainer/grpo_trainer.py CHANGED
@@ -42,27 +42,27 @@ class GRPOTrainer(Trainer):
          )
 
          self.reward_func = reward_func
-         self.reference_model = self._init_reference_model()
+         self.ref_model = self._init_ref_model()
 
          # use torch's pad_sequence by default
          # if pad_sequence does not support the padding_side argument, set this flag to False and fall back to reversing the sequences
          self._use_origin_pad_sequence = True
 
-     def _init_reference_model(self):
-         reference_model = self._new_model(self.train_config)
+     def _init_ref_model(self):
+         ref_model = self._new_model(self.train_config)
 
-         reference_model, _ = TrainerTools().parallel.process(
-             model=reference_model,
+         ref_model, _ = TrainerTools().parallel.process(
+             model=ref_model,
              optimizer=None,
-             kwargs=self._init_reference_args(),
+             kwargs=self._init_ref_model_args(),
              save_instance=False
          )
 
-         reference_model.eval()
-         for param in reference_model.parameters():
+         ref_model.eval()
+         for param in ref_model.parameters():
              param.requires_grad = False
 
-         return reference_model
+         return ref_model
 
      def _init_loss(self):
          criterion = GRPOLoss(
@@ -225,7 +225,7 @@ class GRPOTrainer(Trainer):
          old_log_probs, _ = self._compute_log_probabilities(generate_model, input_ids, attention_mask, logits_to_keep)
 
          # Compute ref_log_probs from the reference model, which remains static.
-         ref_log_probs, _ = self._compute_log_probabilities(self.reference_model, input_ids, attention_mask, logits_to_keep)
+         ref_log_probs, _ = self._compute_log_probabilities(self.ref_model, input_ids, attention_mask, logits_to_keep)
 
          repeated_prompts = [p for p in prompts for _ in range(group_size)]
          repeated_answers = [a for a in answers for _ in range(group_size)]
@@ -290,7 +290,7 @@ class GRPOTrainer(Trainer):
          for epoch in range(self.train_config.n_epochs):
              sync_model_params(
                  _from=self.train_model,
-                 _to=self.reference_model,
+                 _to=self.ref_model,
                  mixup_alpha=self.train_config.grpo_config.mixup_alpha
              )
 
@@ -317,7 +317,9 @@ class GRPOTrainer(Trainer):
              skipping_train = True
              continue
 
-         skipping_train = False
+         if skipping_train:
+             TrainerTools().parallel.wait('skip train')
+             skipping_train = False
 
          # start generate
          if TrainerTools().parallel.is_main_process:
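
In the @@ -290 hunk above, sync_model_params now targets self.ref_model and still passes grpo_config.mixup_alpha, which suggests the reference weights are refreshed each epoch by interpolating toward the current policy rather than being overwritten outright. A sketch of what such a mixed update could look like; the interpolation semantics here are an assumption, not the function's documented behavior:

    import torch

    @torch.no_grad()
    def mixed_param_sync(policy, ref, mixup_alpha=1.0):
        # mixup_alpha=1.0 -> plain copy; smaller values retain part of the old reference
        for p_ref, p_pol in zip(ref.parameters(), policy.parameters()):
            p_ref.mul_(1.0 - mixup_alpha).add_(p_pol, alpha=mixup_alpha)
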
llm_trainer/parallel.py CHANGED
@@ -139,9 +139,8 @@ class Parallel(ABC):
              return dist.get_world_size()
          return 1
 
-     def wait(self):
-         try:
-             log(f'wait at {self.device}')
-             dist.barrier()
-         except: ...
-         log(f'continue at {self.device}')
+     def wait(self, msg=None):
+         msg = f' for {msg}' if msg else ''
+         log(f'wait at {self.device}{msg}')
+         dist.barrier()
+         log(f'continue at {self.device}{msg}')
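
wait() now accepts an optional message, logs it on both sides of the barrier, and no longer swallows exceptions, so a rank that fails to reach a barrier raises instead of silently drifting out of sync; the tag also shows which barrier a hung rank is stuck on. A standalone sketch of the same tagged-barrier idea using plain torch.distributed rather than the package's Parallel class:

    from typing import Optional
    import torch.distributed as dist

    def tagged_barrier(msg: Optional[str] = None) -> None:
        tag = f' for {msg}' if msg else ''
        rank = dist.get_rank() if dist.is_initialized() else 0
        print(f'wait at rank {rank}{tag}')
        if dist.is_initialized():
            dist.barrier()               # any failure now propagates to the caller
        print(f'continue at rank {rank}{tag}')
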
llm_trainer/trainer.py CHANGED
@@ -1,6 +1,6 @@
- import time
  from contextlib import nullcontext
  from typing import Optional, Tuple, List, Dict, Any
+ import copy
 
  import torch
  import torch.distributed as dist
@@ -65,6 +65,7 @@ class Trainer:
          assert len(self.eval_prompts) == len(self.eval_image_tags)
 
          parallel_kwargs, data_loader_kwargs, sampler_kwargs, use_ds_optim = self._convert_train_args()
+         self.parallel_kwargs = parallel_kwargs
          self.data_loader_kwargs: dict[str, Any] = data_loader_kwargs
          self.sampler_kwargs: dict[str, Any] = sampler_kwargs
 
@@ -323,8 +324,8 @@ class Trainer:
 
          return parallel_kwargs, data_loader_kwargs, sampler_kwargs, use_ds_optim
 
-     def _init_reference_args(self) -> dict:
-         parallel_kwargs, _, _, _ = self._convert_train_args()
+     def _init_ref_model_args(self) -> dict:
+         parallel_kwargs = copy.deepcopy(self.parallel_kwargs)
 
          if parallel_kwargs and isinstance(TrainerTools().parallel, DsParallel):
              # reference to https://github.com/huggingface/trl/blob/main/trl/models/utils.py:prepare_deepspeed
@@ -346,8 +347,13 @@ class Trainer:
              # }
              # )
 
-             if parallel_kwargs['zero_optimization']['stage'] != 3:
-                 parallel_kwargs['zero_optimization']['stage'] = 0
+             parallel_kwargs.pop('activation_checkpointing', None)
+             parallel_kwargs.pop('gradient_clipping', None)
+
+             # use ZeRO stage 0 for ref_model for now, to work around the training hang
+             parallel_kwargs["zero_optimization"] = {"stage": 0}
+             # if parallel_kwargs.get("zero_optimization", {}).get("stage", 0) != 3:
+             #     parallel_kwargs["zero_optimization"] = {"stage": 0}
 
          return parallel_kwargs
 
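
_init_ref_model_args now deep-copies the kwargs cached at construction time instead of recomputing them, drops training-only entries, and pins the reference model to ZeRO stage 0 as a workaround for a hang (the earlier stage-3-or-0 logic is left commented out). A condensed sketch of that derivation, assuming the kwargs are a plain DeepSpeed config dict:

    import copy

    def ref_model_ds_config(train_kwargs: dict) -> dict:
        cfg = copy.deepcopy(train_kwargs)          # never mutate the training config
        cfg.pop('activation_checkpointing', None)  # training-only settings
        cfg.pop('gradient_clipping', None)
        cfg['zero_optimization'] = {'stage': 0}    # inference-only copy, no parameter sharding
        return cfg
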
@@ -449,7 +455,7 @@ class Trainer:
          )
          generate_model.train()
 
-         TrainerTools().parallel.wait()
+         TrainerTools().parallel.wait('eval')
 
      def _on_batch_end(self, tag: str):
          self._eval(f'sign:batch/{tag}')
@@ -500,14 +506,19 @@ class Trainer:
              skipping_train = True
              continue
 
-         skipping_train = False
-
          # whether gradients need to be updated on this step
-         if gradient_accumulation_steps > 1:
+         if skipping_train:
+             need_update_grad = False
+         elif gradient_accumulation_steps > 1:
              need_update_grad = (batch + 1) % gradient_accumulation_steps == 0 or batch == batch_count_per_file - 1
          else:
              need_update_grad = True
 
+         # must stay below the need_update_grad assignment; works around an unexplained hang when resuming training
+         if skipping_train:
+             TrainerTools().parallel.wait('skip train')
+             skipping_train = False
+
          inputs = batch_data['inputs']
          labels = batch_data['labels']
 
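
All three training loops now derive need_update_grad before the skip-train barrier: a skipped batch forces need_update_grad = False, and only afterwards does the rank call wait('skip train') (the in-code comment ties this ordering to an unexplained hang when resuming training). A condensed, runnable restatement of that ordering; the barrier callable stands in for TrainerTools().parallel.wait:

    def resolve_batch_flags(batch: int,
                            skipping_train: bool,
                            gradient_accumulation_steps: int,
                            batch_count_per_file: int,
                            barrier=lambda msg: None):
        # decide whether this batch should step the optimizer
        if skipping_train:
            need_update_grad = False
        elif gradient_accumulation_steps > 1:
            need_update_grad = ((batch + 1) % gradient_accumulation_steps == 0
                                or batch == batch_count_per_file - 1)
        else:
            need_update_grad = True

        # synchronize only after need_update_grad is assigned, mirroring the fix above
        if skipping_train:
            barrier('skip train')
            skipping_train = False

        return need_update_grad, skipping_train
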
project_llm_trainer-0.5.15.dist-info/METADATA → project_llm_trainer-0.5.17.dist-info/METADATA
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: project_llm_trainer
- Version: 0.5.15
+ Version: 0.5.17
  Summary: LLM and VLM trainer
  Author: qibin
  Author-email: qibin0506@gmail.com
project_llm_trainer-0.5.15.dist-info/RECORD → project_llm_trainer-0.5.17.dist-info/RECORD
@@ -1,14 +1,14 @@
  llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
- llm_trainer/checkpoint.py,sha256=Wh5CwceIajTgJ9i_mH3I1R9N2nOLFqVFmlEMkTiGcD4,4306
+ llm_trainer/checkpoint.py,sha256=gz31pZbbQvRTYrBhxV-MFaBAIFeqpe7rM6nFsjwT9lY,4328
  llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
- llm_trainer/dpo_trainer.py,sha256=8LYRxviJKcB-rN_XprVsWr5YshU8KolggMm7irjbXvI,11990
- llm_trainer/ds_checkpoint.py,sha256=kM7--wZyo4WIg4C2xk3bwad-m3V8ICfNLF3aFKtvzSA,2115
+ llm_trainer/dpo_trainer.py,sha256=pNJaXvk-g0lGkZoRhbODNH34hTiz8EdP4Z12ws4W0t8,12309
+ llm_trainer/ds_checkpoint.py,sha256=Wzy7PvVVWR794-BW4uragWFTAkkgDvjvkF-qMdyB4fc,2141
  llm_trainer/eval.py,sha256=ZyUfSo2Q8P-lrCdPEnGkoo5pGubd0AabREK5eMISRII,1109
  llm_trainer/generate_utils.py,sha256=wrZoG2g7CsOyG4sb3px9vURHQFV6_9j5kQmpFc5A8yg,15335
- llm_trainer/grpo_trainer.py,sha256=PVTlKOEJpI0AMlh7Siw_MHpLm9CAZepCAMjjSZF6eRU,15996
+ llm_trainer/grpo_trainer.py,sha256=tuzcSi1uBzUPVKojEheJ3-Tx8-g99mf6LYYxC5nsNiw,16040
  llm_trainer/log.py,sha256=LxqTGRNZUGMTSQCePRpk-rYyxSnSIbT4kOdP8Fbzr0M,462
  llm_trainer/loss.py,sha256=NZCQeUXnLSj__mmDflE8g89KgE0emAJXIab0IERCLno,6023
- llm_trainer/parallel.py,sha256=DQu8GqEFxD99HQ6hKuIxxyKi-05dMO33eMhImYlPuOI,4468
+ llm_trainer/parallel.py,sha256=G9X0FddIJwd9j-5XOknB4AlBe4G2W6fUCaQH6ycC2Fo,4490
  llm_trainer/parallel_ddp.py,sha256=Pob9vUlBZnkL4oP1Re11kFob7nufMSE96pn7m7fuOEM,1345
  llm_trainer/parallel_ds.py,sha256=oy8RRxHud3rACWubFlJqqd0pjPEQhKeAPGPQUSdJX2c,1145
  llm_trainer/parallel_none.py,sha256=TG6Pm829Dg-yQu-97O-EHV3FCARBlNcP47KkGFAs16E,676
@@ -18,16 +18,16 @@ llm_trainer/sft_trainer.py,sha256=gxQA7T1o1QGUsHp2CX1Qb_fO5LppBJuNbc0H4ixCYUA,17
  llm_trainer/tokenizer.py,sha256=SSpgXtb0e1NtQqRW0gCq09TTZi47umggy-Fh5EMHKJg,6708
  llm_trainer/tools.py,sha256=yF17lp6oOfLe2XJeKDQ1juZcbv-6vFamJSLwEeArduA,2975
  llm_trainer/train_configs.py,sha256=992wy0YhBG2WvxwdLEPL4_-JUl4NkwMPT-jj_BIHo6A,7347
- llm_trainer/trainer.py,sha256=FF75J-BRUp34No2TvQIgomvNozWYzVhDeOfaBgQLV9g,26079
+ llm_trainer/trainer.py,sha256=Q821nlLDKRZVpaRoiZ7DiJplpAJRRLtvR_33FbClGA0,26729
  llm_trainer/utils.py,sha256=LWNhyQ0NDEZ9mZtk2Ryvh6EulvHIaUGIflugSpqmeFI,6791
- project_llm_trainer-0.5.15.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
- project_llm_trainer-0.5.15.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
- project_llm_trainer-0.5.15.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
- project_llm_trainer-0.5.15.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
- project_llm_trainer-0.5.15.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
- project_llm_trainer-0.5.15.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
- project_llm_trainer-0.5.15.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
- project_llm_trainer-0.5.15.dist-info/METADATA,sha256=7ObRAx3PO5Dn55rgnJRS-bXGp-NU-SHgoPKVdTUTGCc,196
- project_llm_trainer-0.5.15.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
- project_llm_trainer-0.5.15.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
- project_llm_trainer-0.5.15.dist-info/RECORD,,
+ project_llm_trainer-0.5.17.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
+ project_llm_trainer-0.5.17.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
+ project_llm_trainer-0.5.17.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
+ project_llm_trainer-0.5.17.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
+ project_llm_trainer-0.5.17.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
+ project_llm_trainer-0.5.17.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
+ project_llm_trainer-0.5.17.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
+ project_llm_trainer-0.5.17.dist-info/METADATA,sha256=BVzwe45PQXSE-f5-BCZulqWCK3PIpKzxv9z__moTEJY,196
+ project_llm_trainer-0.5.17.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
+ project_llm_trainer-0.5.17.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
+ project_llm_trainer-0.5.17.dist-info/RECORD,,