project-llm-trainer 0.5.15__py3-none-any.whl → 0.5.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of project-llm-trainer might be problematic.
- llm_trainer/checkpoint.py +1 -1
- llm_trainer/dpo_trainer.py +29 -20
- llm_trainer/ds_checkpoint.py +1 -1
- llm_trainer/grpo_trainer.py +14 -12
- llm_trainer/parallel.py +5 -6
- llm_trainer/trainer.py +20 -9
- {project_llm_trainer-0.5.15.dist-info → project_llm_trainer-0.5.17.dist-info}/METADATA +1 -1
- {project_llm_trainer-0.5.15.dist-info → project_llm_trainer-0.5.17.dist-info}/RECORD +17 -17
- {project_llm_trainer-0.5.15.data → project_llm_trainer-0.5.17.data}/scripts/calc_intermediate_size +0 -0
- {project_llm_trainer-0.5.15.data → project_llm_trainer-0.5.17.data}/scripts/ddp_train +0 -0
- {project_llm_trainer-0.5.15.data → project_llm_trainer-0.5.17.data}/scripts/ds_train +0 -0
- {project_llm_trainer-0.5.15.data → project_llm_trainer-0.5.17.data}/scripts/plot_loss +0 -0
- {project_llm_trainer-0.5.15.data → project_llm_trainer-0.5.17.data}/scripts/plot_lr +0 -0
- {project_llm_trainer-0.5.15.data → project_llm_trainer-0.5.17.data}/scripts/py_train +0 -0
- {project_llm_trainer-0.5.15.data → project_llm_trainer-0.5.17.data}/scripts/smart_train +0 -0
- {project_llm_trainer-0.5.15.dist-info → project_llm_trainer-0.5.17.dist-info}/WHEEL +0 -0
- {project_llm_trainer-0.5.15.dist-info → project_llm_trainer-0.5.17.dist-info}/top_level.txt +0 -0
llm_trainer/checkpoint.py
CHANGED
llm_trainer/dpo_trainer.py
CHANGED
@@ -12,7 +12,10 @@ from .dataset import DPODataset
 from .loss import DPOLoss
 from .tools import TrainerTools
 from .utils import get_dpo_collate_fn
-from .partition_utils import …
+from .partition_utils import (
+    sync_model_params,
+    unwrap_model_for_generation
+)
 
 from .checkpoint import (
     save_checkpoint,
@@ -35,28 +38,28 @@ class DPOTrainer(Trainer):
             eval_image_tags=eval_image_tags
         )
 
-        self.…
+        self.ref_model = self._init_ref_model()
 
-    def …
-
+    def _init_ref_model(self):
+        ref_model = self._new_model(self.train_config)
 
-        …
-        model=…
+        ref_model, _ = TrainerTools().parallel.process(
+            model=ref_model,
             optimizer=None,
-            kwargs=self.…
+            kwargs=self._init_ref_model_args(),
             save_instance=False
         )
 
-        …
-        for param in …
+        ref_model.eval()
+        for param in ref_model.parameters():
             param.requires_grad = False
 
         sync_model_params(
             _from=self.train_model,
-            _to=…
+            _to=ref_model
         )
 
-        return …
+        return ref_model
 
     def _init_loss(self):
         criterion = DPOLoss(
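The hunk above moves reference-model construction into a dedicated `_init_ref_model()` that builds a second model, freezes it, and copies the policy weights into it. A minimal standalone sketch of that frozen-reference pattern, assuming a plain `torch.nn.Module` and using `copy.deepcopy` in place of the trainer's `_new_model`/`sync_model_params` pair (the helper name here is illustrative, not part of the package):

import copy
import torch

def build_frozen_reference(policy: torch.nn.Module) -> torch.nn.Module:
    # Clone the policy weights, then freeze the clone so it never receives gradients.
    ref = copy.deepcopy(policy)
    ref.eval()
    for p in ref.parameters():
        p.requires_grad = False
    return ref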
@@ -170,14 +173,19 @@ class DPOTrainer(Trainer):
                     skipping_train = True
                     continue
 
-                skipping_train = False
-
                 # whether gradients need to be updated
-                if gradient_accumulation_steps > 1:
+                if skipping_train:
+                    need_update_grad = False
+                elif gradient_accumulation_steps > 1:
                     need_update_grad = (batch + 1) % gradient_accumulation_steps == 0 or batch == batch_count_per_file - 1
                 else:
                     need_update_grad = True
 
+                # must stay below the need_update_grad assignment; works around an unexplained hang when resuming training
+                if skipping_train:
+                    TrainerTools().parallel.wait('skip train')
+                    skipping_train = False
+
                 try:
                     chosen_inputs: torch.Tensor = batch_data['chosen_inputs'].to(TrainerTools().parallel.device)
                     chosen_labels: torch.Tensor = batch_data['chosen_labels'].to(TrainerTools().parallel.device)
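The reworked control flow above first decides whether this step should apply gradients and only afterwards runs the resume-skip barrier, which (per the comment) avoids a hang when resuming training. A condensed sketch of the decision logic; the function name is chosen here for illustration only:

def should_update_grad(batch: int, batches_per_file: int, accum_steps: int, skipping: bool) -> bool:
    # No optimizer step while the loader is still skipping already-trained batches.
    if skipping:
        return False
    # With accumulation, step on every accum_steps-th batch or on the file's last batch.
    if accum_steps > 1:
        return (batch + 1) % accum_steps == 0 or batch == batches_per_file - 1
    return True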
@@ -198,17 +206,18 @@
 
                 with self.ctx:
                     policy_outputs = self.train_model(concat_inputs, attention_mask=concat_mask)
-                    with torch.inference_mode():
-                        ref_outputs = self.reference_model(concat_inputs, attention_mask=concat_mask)
-
                     policy_probs = self._logprobs(policy_outputs['logits'], concat_labels, concat_mask)
-                    …
+                    aux_loss = policy_outputs.get('aux_loss')
+
+                    with torch.no_grad():
+                        ref_outputs = self.ref_model(concat_inputs, attention_mask=concat_mask)
+                        ref_probs = self._logprobs(ref_outputs['logits'], concat_labels, concat_mask)
 
                     # calc loss
                     loss = self.criterion(policy_probs, ref_probs)
 
-                    if aux_loss_coef and …
-                        loss += aux_loss_coef * …
+                    if aux_loss_coef and aux_loss:
+                        loss += aux_loss_coef * aux_loss
 
                 if gradient_accumulation_steps > 1:
                     loss = loss / gradient_accumulation_steps
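In the hunk above the reference forward pass now runs under `torch.no_grad()` against `self.ref_model`, and its log-probabilities feed the criterion together with the policy's. For orientation, a minimal sketch of the standard DPO objective over concatenated chosen/rejected sequence log-probs, assuming chosen examples occupy the first half of the batch; the package's `DPOLoss` may differ in details such as beta handling or label smoothing:

import torch
import torch.nn.functional as F

def dpo_loss(policy_logps: torch.Tensor, ref_logps: torch.Tensor, beta: float = 0.1) -> torch.Tensor:
    # policy_logps / ref_logps: per-sequence log-probs for [chosen; rejected], shape (2B,).
    half = policy_logps.shape[0] // 2
    pi_c, pi_r = policy_logps[:half], policy_logps[half:]
    ref_c, ref_r = ref_logps[:half], ref_logps[half:]
    # Reward margin of chosen over rejected, measured relative to the frozen reference model.
    logits = beta * ((pi_c - ref_c) - (pi_r - ref_r))
    return -F.logsigmoid(logits).mean()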
llm_trainer/ds_checkpoint.py
CHANGED
llm_trainer/grpo_trainer.py
CHANGED
@@ -42,27 +42,27 @@ class GRPOTrainer(Trainer):
         )
 
         self.reward_func = reward_func
-        self.…
+        self.ref_model = self._init_ref_model()
 
         # use torch's pad_sequence by default
         # if pad_sequence does not support the padding_side argument, set this flag to False and fall back to the reversal approach
         self._use_origin_pad_sequence = True
 
-    def …
-
+    def _init_ref_model(self):
+        ref_model = self._new_model(self.train_config)
 
-        …
-        model=…
+        ref_model, _ = TrainerTools().parallel.process(
+            model=ref_model,
             optimizer=None,
-            kwargs=self.…
+            kwargs=self._init_ref_model_args(),
             save_instance=False
         )
 
-        …
-        for param in …
+        ref_model.eval()
+        for param in ref_model.parameters():
             param.requires_grad = False
 
-        return …
+        return ref_model
 
     def _init_loss(self):
         criterion = GRPOLoss(
@@ -225,7 +225,7 @@
         old_log_probs, _ = self._compute_log_probabilities(generate_model, input_ids, attention_mask, logits_to_keep)
 
         # Compute ref_log_probs from the reference model, which remains static.
-        ref_log_probs, _ = self._compute_log_probabilities(self.…
+        ref_log_probs, _ = self._compute_log_probabilities(self.ref_model, input_ids, attention_mask, logits_to_keep)
 
         repeated_prompts = [p for p in prompts for _ in range(group_size)]
         repeated_answers = [a for a in answers for _ in range(group_size)]
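The change above points `_compute_log_probabilities` at the renamed `self.ref_model`. For context, a common way to compute per-token log-probabilities for the last `logits_to_keep` tokens of a causal LM output looks roughly like this; the helper is illustrative and not the package's implementation:

import torch
import torch.nn.functional as F

def token_log_probs(logits: torch.Tensor, input_ids: torch.Tensor, logits_to_keep: int) -> torch.Tensor:
    # Position t predicts token t+1, so take the logits that predict the kept tokens...
    kept_logits = logits[:, -logits_to_keep - 1:-1, :]      # (B, K, V)
    targets = input_ids[:, -logits_to_keep:]                # (B, K)
    logps = F.log_softmax(kept_logits, dim=-1)
    # ...then gather the log-probability of each realized token.
    return logps.gather(-1, targets.unsqueeze(-1)).squeeze(-1)  # (B, K)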
@@ -290,7 +290,7 @@
         for epoch in range(self.train_config.n_epochs):
             sync_model_params(
                 _from=self.train_model,
-                _to=self.…
+                _to=self.ref_model,
                 mixup_alpha=self.train_config.grpo_config.mixup_alpha
             )
 
@@ -317,7 +317,9 @@
                     skipping_train = True
                     continue
 
-                skipping_train = False
+                if skipping_train:
+                    TrainerTools().parallel.wait('skip train')
+                    skipping_train = False
 
                 # start generate
                 if TrainerTools().parallel.is_main_process:
llm_trainer/parallel.py
CHANGED
@@ -139,9 +139,8 @@ class Parallel(ABC):
             return dist.get_world_size()
         return 1
 
-    def wait(self):
-        …
-        …
-        …
-        …
-        log(f'continue at {self.device}')
+    def wait(self, msg=None):
+        msg = f' for {msg}' if msg else ''
+        log(f'wait at {self.device}{msg}')
+        dist.barrier()
+        log(f'continue at {self.device}{msg}')
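`wait()` now takes an optional message so the barrier logs say what the ranks are waiting for. A self-contained sketch of the same labeled-barrier idea, with a guard added here so it also runs outside a distributed context (the guard and the function name are assumptions, not part of the diff):

import torch.distributed as dist

def labeled_barrier(device: str, msg: str = '') -> None:
    suffix = f' for {msg}' if msg else ''
    print(f'wait at {device}{suffix}')
    # Block until every rank reaches this point (only when torch.distributed is initialized).
    if dist.is_available() and dist.is_initialized():
        dist.barrier()
    print(f'continue at {device}{suffix}')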
llm_trainer/trainer.py
CHANGED
@@ -1,6 +1,6 @@
-import time
 from contextlib import nullcontext
 from typing import Optional, Tuple, List, Dict, Any
+import copy
 
 import torch
 import torch.distributed as dist
@@ -65,6 +65,7 @@ class Trainer:
         assert len(self.eval_prompts) == len(self.eval_image_tags)
 
         parallel_kwargs, data_loader_kwargs, sampler_kwargs, use_ds_optim = self._convert_train_args()
+        self.parallel_kwargs = parallel_kwargs
         self.data_loader_kwargs: dict[str, Any] = data_loader_kwargs
         self.sampler_kwargs: dict[str, Any] = sampler_kwargs
 
@@ -323,8 +324,8 @@
 
         return parallel_kwargs, data_loader_kwargs, sampler_kwargs, use_ds_optim
 
-    def …
-        parallel_kwargs …
+    def _init_ref_model_args(self) -> dict:
+        parallel_kwargs = copy.deepcopy(self.parallel_kwargs)
 
         if parallel_kwargs and isinstance(TrainerTools().parallel, DsParallel):
             # reference to https://github.com/huggingface/trl/blob/main/trl/models/utils.py:prepare_deepspeed
@@ -346,8 +347,13 @@
             # }
             # )
 
-            …
-            …
+            parallel_kwargs.pop('activation_checkpointing', None)
+            parallel_kwargs.pop('gradient_clipping', None)
+
+            # use ZeRO stage 0 for the ref_model for now, to work around a training hang
+            parallel_kwargs["zero_optimization"] = {"stage": 0}
+            # if parallel_kwargs.get("zero_optimization", {}).get("stage", 0) != 3:
+            #     parallel_kwargs["zero_optimization"] = {"stage": 0}
 
         return parallel_kwargs
 
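`_init_ref_model_args()` now deep-copies the training-time DeepSpeed kwargs, strips training-only options, and pins the reference model to ZeRO stage 0 to work around the hang noted in the comment. A minimal sketch of that config transformation; the function name and the exact shape of the config dict are assumptions:

import copy

def ref_model_ds_kwargs(train_kwargs: dict) -> dict:
    cfg = copy.deepcopy(train_kwargs)
    # The frozen reference model neither checkpoints activations nor clips gradients.
    cfg.pop('activation_checkpointing', None)
    cfg.pop('gradient_clipping', None)
    # Pin the reference model to ZeRO stage 0 (no parameter/optimizer partitioning).
    cfg['zero_optimization'] = {'stage': 0}
    return cfg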
@@ -449,7 +455,7 @@
         )
         generate_model.train()
 
-        TrainerTools().parallel.wait()
+        TrainerTools().parallel.wait('eval')
 
     def _on_batch_end(self, tag: str):
         self._eval(f'sign:batch/{tag}')
@@ -500,14 +506,19 @@
                     skipping_train = True
                     continue
 
-                skipping_train = False
-
                 # whether gradients need to be updated
-                if gradient_accumulation_steps > 1:
+                if skipping_train:
+                    need_update_grad = False
+                elif gradient_accumulation_steps > 1:
                     need_update_grad = (batch + 1) % gradient_accumulation_steps == 0 or batch == batch_count_per_file - 1
                 else:
                     need_update_grad = True
 
+                # must stay below the need_update_grad assignment; works around an unexplained hang when resuming training
+                if skipping_train:
+                    TrainerTools().parallel.wait('skip train')
+                    skipping_train = False
+
                 inputs = batch_data['inputs']
                 labels = batch_data['labels']
 
{project_llm_trainer-0.5.15.dist-info → project_llm_trainer-0.5.17.dist-info}/RECORD
CHANGED
@@ -1,14 +1,14 @@
 llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
-llm_trainer/checkpoint.py,sha256=…
+llm_trainer/checkpoint.py,sha256=gz31pZbbQvRTYrBhxV-MFaBAIFeqpe7rM6nFsjwT9lY,4328
 llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
-llm_trainer/dpo_trainer.py,sha256=…
-llm_trainer/ds_checkpoint.py,sha256=…
+llm_trainer/dpo_trainer.py,sha256=pNJaXvk-g0lGkZoRhbODNH34hTiz8EdP4Z12ws4W0t8,12309
+llm_trainer/ds_checkpoint.py,sha256=Wzy7PvVVWR794-BW4uragWFTAkkgDvjvkF-qMdyB4fc,2141
 llm_trainer/eval.py,sha256=ZyUfSo2Q8P-lrCdPEnGkoo5pGubd0AabREK5eMISRII,1109
 llm_trainer/generate_utils.py,sha256=wrZoG2g7CsOyG4sb3px9vURHQFV6_9j5kQmpFc5A8yg,15335
-llm_trainer/grpo_trainer.py,sha256=…
+llm_trainer/grpo_trainer.py,sha256=tuzcSi1uBzUPVKojEheJ3-Tx8-g99mf6LYYxC5nsNiw,16040
 llm_trainer/log.py,sha256=LxqTGRNZUGMTSQCePRpk-rYyxSnSIbT4kOdP8Fbzr0M,462
 llm_trainer/loss.py,sha256=NZCQeUXnLSj__mmDflE8g89KgE0emAJXIab0IERCLno,6023
-llm_trainer/parallel.py,sha256=…
+llm_trainer/parallel.py,sha256=G9X0FddIJwd9j-5XOknB4AlBe4G2W6fUCaQH6ycC2Fo,4490
 llm_trainer/parallel_ddp.py,sha256=Pob9vUlBZnkL4oP1Re11kFob7nufMSE96pn7m7fuOEM,1345
 llm_trainer/parallel_ds.py,sha256=oy8RRxHud3rACWubFlJqqd0pjPEQhKeAPGPQUSdJX2c,1145
 llm_trainer/parallel_none.py,sha256=TG6Pm829Dg-yQu-97O-EHV3FCARBlNcP47KkGFAs16E,676
@@ -18,16 +18,16 @@ llm_trainer/sft_trainer.py,sha256=gxQA7T1o1QGUsHp2CX1Qb_fO5LppBJuNbc0H4ixCYUA,17
 llm_trainer/tokenizer.py,sha256=SSpgXtb0e1NtQqRW0gCq09TTZi47umggy-Fh5EMHKJg,6708
 llm_trainer/tools.py,sha256=yF17lp6oOfLe2XJeKDQ1juZcbv-6vFamJSLwEeArduA,2975
 llm_trainer/train_configs.py,sha256=992wy0YhBG2WvxwdLEPL4_-JUl4NkwMPT-jj_BIHo6A,7347
-llm_trainer/trainer.py,sha256=…
+llm_trainer/trainer.py,sha256=Q821nlLDKRZVpaRoiZ7DiJplpAJRRLtvR_33FbClGA0,26729
 llm_trainer/utils.py,sha256=LWNhyQ0NDEZ9mZtk2Ryvh6EulvHIaUGIflugSpqmeFI,6791
-project_llm_trainer-0.5.…
-project_llm_trainer-0.5.…
-project_llm_trainer-0.5.…
-project_llm_trainer-0.5.…
-project_llm_trainer-0.5.…
-project_llm_trainer-0.5.…
-project_llm_trainer-0.5.…
-project_llm_trainer-0.5.…
-project_llm_trainer-0.5.…
-project_llm_trainer-0.5.…
-project_llm_trainer-0.5.…
+project_llm_trainer-0.5.17.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
+project_llm_trainer-0.5.17.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
+project_llm_trainer-0.5.17.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
+project_llm_trainer-0.5.17.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
+project_llm_trainer-0.5.17.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
+project_llm_trainer-0.5.17.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
+project_llm_trainer-0.5.17.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
+project_llm_trainer-0.5.17.dist-info/METADATA,sha256=BVzwe45PQXSE-f5-BCZulqWCK3PIpKzxv9z__moTEJY,196
+project_llm_trainer-0.5.17.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
+project_llm_trainer-0.5.17.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
+project_llm_trainer-0.5.17.dist-info/RECORD,,
{project_llm_trainer-0.5.15.data → project_llm_trainer-0.5.17.data}/scripts/calc_intermediate_size
RENAMED
File without changes
{project_llm_trainer-0.5.15.data → project_llm_trainer-0.5.17.data}/scripts/ddp_train
RENAMED
File without changes
{project_llm_trainer-0.5.15.data → project_llm_trainer-0.5.17.data}/scripts/ds_train
RENAMED
File without changes
{project_llm_trainer-0.5.15.data → project_llm_trainer-0.5.17.data}/scripts/plot_loss
RENAMED
File without changes
{project_llm_trainer-0.5.15.data → project_llm_trainer-0.5.17.data}/scripts/plot_lr
RENAMED
File without changes
{project_llm_trainer-0.5.15.data → project_llm_trainer-0.5.17.data}/scripts/py_train
RENAMED
File without changes
{project_llm_trainer-0.5.15.data → project_llm_trainer-0.5.17.data}/scripts/smart_train
RENAMED
File without changes
{project_llm_trainer-0.5.15.dist-info → project_llm_trainer-0.5.17.dist-info}/WHEEL
RENAMED
File without changes
{project_llm_trainer-0.5.15.dist-info → project_llm_trainer-0.5.17.dist-info}/top_level.txt
RENAMED
File without changes