project-llm-trainer 0.5.16-py3-none-any.whl → 0.5.17-py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
Potentially problematic release: this version of project-llm-trainer has been flagged as possibly problematic.
- llm_trainer/dpo_trainer.py +21 -17
- llm_trainer/grpo_trainer.py +11 -11
- llm_trainer/parallel.py +1 -1
- llm_trainer/trainer.py +11 -5
- {project_llm_trainer-0.5.16.dist-info → project_llm_trainer-0.5.17.dist-info}/METADATA +1 -1
- {project_llm_trainer-0.5.16.dist-info → project_llm_trainer-0.5.17.dist-info}/RECORD +15 -15
- {project_llm_trainer-0.5.16.data → project_llm_trainer-0.5.17.data}/scripts/calc_intermediate_size +0 -0
- {project_llm_trainer-0.5.16.data → project_llm_trainer-0.5.17.data}/scripts/ddp_train +0 -0
- {project_llm_trainer-0.5.16.data → project_llm_trainer-0.5.17.data}/scripts/ds_train +0 -0
- {project_llm_trainer-0.5.16.data → project_llm_trainer-0.5.17.data}/scripts/plot_loss +0 -0
- {project_llm_trainer-0.5.16.data → project_llm_trainer-0.5.17.data}/scripts/plot_lr +0 -0
- {project_llm_trainer-0.5.16.data → project_llm_trainer-0.5.17.data}/scripts/py_train +0 -0
- {project_llm_trainer-0.5.16.data → project_llm_trainer-0.5.17.data}/scripts/smart_train +0 -0
- {project_llm_trainer-0.5.16.dist-info → project_llm_trainer-0.5.17.dist-info}/WHEEL +0 -0
- {project_llm_trainer-0.5.16.dist-info → project_llm_trainer-0.5.17.dist-info}/top_level.txt +0 -0
llm_trainer/dpo_trainer.py
CHANGED
@@ -12,7 +12,10 @@ from .dataset import DPODataset
 from .loss import DPOLoss
 from .tools import TrainerTools
 from .utils import get_dpo_collate_fn
-from .partition_utils import …
+from .partition_utils import (
+    sync_model_params,
+    unwrap_model_for_generation
+)
 
 from .checkpoint import (
     save_checkpoint,
@@ -35,28 +38,28 @@ class DPOTrainer(Trainer):
             eval_image_tags=eval_image_tags
         )
 
-        self.…
+        self.ref_model = self._init_ref_model()
 
-    def …
-        …
+    def _init_ref_model(self):
+        ref_model = self._new_model(self.train_config)
 
-        …
-            model=…
+        ref_model, _ = TrainerTools().parallel.process(
+            model=ref_model,
             optimizer=None,
-            kwargs=self.…
+            kwargs=self._init_ref_model_args(),
             save_instance=False
         )
 
-        …
-        for param in …
+        ref_model.eval()
+        for param in ref_model.parameters():
             param.requires_grad = False
 
         sync_model_params(
             _from=self.train_model,
-            _to=…
+            _to=ref_model
         )
 
-        return …
+        return ref_model
 
     def _init_loss(self):
         criterion = DPOLoss(
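The hunk above renames the old reference-model attribute to ref_model and moves its construction into a dedicated _init_ref_model helper. The pattern itself is the standard one for preference training: instantiate a second copy of the policy architecture, freeze it, and copy the policy weights in. A minimal sketch of that pattern, independent of this package's TrainerTools/sync_model_params plumbing (make_model is a hypothetical factory, not part of the package):

import torch

def init_frozen_ref(policy: torch.nn.Module, make_model) -> torch.nn.Module:
    ref_model = make_model()        # same architecture as the policy
    ref_model.eval()                # disable dropout and similar layers
    for param in ref_model.parameters():
        param.requires_grad = False # never touched by the optimizer
    # Start the reference from the current policy weights.
    ref_model.load_state_dict(policy.state_dict())
    return ref_model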
@@ -203,17 +206,18 @@ class DPOTrainer(Trainer):
 
         with self.ctx:
             policy_outputs = self.train_model(concat_inputs, attention_mask=concat_mask)
-            with torch.inference_mode():
-                ref_outputs = self.reference_model(concat_inputs, attention_mask=concat_mask)
-
             policy_probs = self._logprobs(policy_outputs['logits'], concat_labels, concat_mask)
-
+            aux_loss = policy_outputs.get('aux_loss')
+
+            with torch.no_grad():
+                ref_outputs = self.ref_model(concat_inputs, attention_mask=concat_mask)
+                ref_probs = self._logprobs(ref_outputs['logits'], concat_labels, concat_mask)
 
         # calc loss
         loss = self.criterion(policy_probs, ref_probs)
 
-        if aux_loss_coef and …
-            loss += aux_loss_coef * …
+        if aux_loss_coef and aux_loss:
+            loss += aux_loss_coef * aux_loss
 
         if gradient_accumulation_steps > 1:
             loss = loss / gradient_accumulation_steps
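This hunk switches the reference forward pass from torch.inference_mode() to torch.no_grad() and starts carrying an optional aux_loss (e.g. a MoE load-balancing term) alongside the main objective. Tensors created under inference_mode cannot later participate in autograd, so no_grad is the safer choice when the result feeds a differentiable loss. The DPOLoss internals are not part of this diff; as a reminder, the standard DPO objective that a policy_probs/ref_probs pair feeds into looks roughly like the sketch below, assuming chosen and rejected sequences are concatenated along the batch dimension (the concat_* names suggest this, but the exact layout is an assumption):

import torch.nn.functional as F

def dpo_loss(policy_logps, ref_logps, beta=0.1):
    # Assumed layout: first half of the batch = chosen, second half = rejected.
    policy_chosen, policy_rejected = policy_logps.chunk(2)
    ref_chosen, ref_rejected = ref_logps.chunk(2)
    # -log sigmoid(beta * ((pi_c - pi_r) - (ref_c - ref_r)))
    margins = (policy_chosen - policy_rejected) - (ref_chosen - ref_rejected)
    return -F.logsigmoid(beta * margins).mean()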
llm_trainer/grpo_trainer.py
CHANGED
@@ -42,27 +42,27 @@ class GRPOTrainer(Trainer):
         )
 
         self.reward_func = reward_func
-        self.…
+        self.ref_model = self._init_ref_model()
 
         # By default, use the pad_sequence provided by torch.
         # If pad_sequence does not support the padding_side argument, set this flag to False and fall back to the reversal approach.
         self._use_origin_pad_sequence = True
 
-    def …
-        …
+    def _init_ref_model(self):
+        ref_model = self._new_model(self.train_config)
 
-        …
-            model=…
+        ref_model, _ = TrainerTools().parallel.process(
+            model=ref_model,
             optimizer=None,
-            kwargs=self.…
+            kwargs=self._init_ref_model_args(),
             save_instance=False
         )
 
-        …
-        for param in …
+        ref_model.eval()
+        for param in ref_model.parameters():
             param.requires_grad = False
 
-        return …
+        return ref_model
 
     def _init_loss(self):
         criterion = GRPOLoss(
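The _use_origin_pad_sequence flag guards the compatibility path described in the comments: torch.nn.utils.rnn.pad_sequence only gained a padding_side argument in recent PyTorch releases, so on older versions left padding has to be emulated by flipping. A sketch of that fallback idea (not the package's exact implementation):

import torch
from torch.nn.utils.rnn import pad_sequence

def left_pad(seqs, pad_id):
    try:
        # Recent torch: pad on the left directly.
        return pad_sequence(seqs, batch_first=True,
                            padding_value=pad_id, padding_side='left')
    except TypeError:
        # Older torch: reverse each sequence, right-pad, reverse back.
        flipped = [s.flip(0) for s in seqs]
        return pad_sequence(flipped, batch_first=True,
                            padding_value=pad_id).flip(1)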
@@ -225,7 +225,7 @@ class GRPOTrainer(Trainer):
         old_log_probs, _ = self._compute_log_probabilities(generate_model, input_ids, attention_mask, logits_to_keep)
 
         # Compute ref_log_probs from the reference model, which remains static.
-        ref_log_probs, _ = self._compute_log_probabilities(self.…
+        ref_log_probs, _ = self._compute_log_probabilities(self.ref_model, input_ids, attention_mask, logits_to_keep)
 
         repeated_prompts = [p for p in prompts for _ in range(group_size)]
         repeated_answers = [a for a in answers for _ in range(group_size)]
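_compute_log_probabilities is called here with a logits_to_keep argument, the usual trick for scoring only the generated completion rather than the whole prompt. Its body is not part of this diff; a typical implementation of that signature looks like the sketch below (the shift-by-one indexing is the standard causal-LM convention; everything else, including returning only the per-token log-probs where the trainer's version returns a second value as well, is an assumption):

import torch
import torch.nn.functional as F

@torch.no_grad()
def compute_log_probabilities(model, input_ids, attention_mask, logits_to_keep):
    out = model(input_ids, attention_mask=attention_mask)
    # Logits at position t predict token t+1, so shift by one.
    logits = out['logits'][:, -logits_to_keep - 1:-1, :]
    targets = input_ids[:, -logits_to_keep:]
    log_probs = F.log_softmax(logits.float(), dim=-1)
    return log_probs.gather(-1, targets.unsqueeze(-1)).squeeze(-1)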
@@ -290,7 +290,7 @@ class GRPOTrainer(Trainer):
         for epoch in range(self.train_config.n_epochs):
             sync_model_params(
                 _from=self.train_model,
-                _to=self.…
+                _to=self.ref_model,
                 mixup_alpha=self.train_config.grpo_config.mixup_alpha
             )
 
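Each epoch, GRPO re-syncs the reference model from the policy, and the mixup_alpha argument suggests the sync is a soft blend rather than a hard copy. sync_model_params itself is not shown in this diff; one plausible reading, with the blend direction being a pure assumption, is a Polyak-style update:

import torch

@torch.no_grad()
def sync_model_params(_from, _to, mixup_alpha=0.0):
    # mixup_alpha = 0  -> hard copy of the policy into the reference;
    # 0 < mixup_alpha < 1 -> keep a fraction of the old reference weights.
    for src, dst in zip(_from.parameters(), _to.parameters()):
        dst.mul_(mixup_alpha).add_(src, alpha=1.0 - mixup_alpha)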
llm_trainer/parallel.py
CHANGED
llm_trainer/trainer.py
CHANGED
@@ -1,6 +1,6 @@
-import time
 from contextlib import nullcontext
 from typing import Optional, Tuple, List, Dict, Any
+import copy
 
 import torch
 import torch.distributed as dist
@@ -65,6 +65,7 @@ class Trainer:
         assert len(self.eval_prompts) == len(self.eval_image_tags)
 
         parallel_kwargs, data_loader_kwargs, sampler_kwargs, use_ds_optim = self._convert_train_args()
+        self.parallel_kwargs = parallel_kwargs
         self.data_loader_kwargs: dict[str, Any] = data_loader_kwargs
         self.sampler_kwargs: dict[str, Any] = sampler_kwargs
 
@@ -323,8 +324,8 @@ class Trainer:
 
         return parallel_kwargs, data_loader_kwargs, sampler_kwargs, use_ds_optim
 
-    def …
-        parallel_kwargs…
+    def _init_ref_model_args(self) -> dict:
+        parallel_kwargs = copy.deepcopy(self.parallel_kwargs)
 
         if parallel_kwargs and isinstance(TrainerTools().parallel, DsParallel):
             # reference to https://github.com/huggingface/trl/blob/main/trl/models/utils.py:prepare_deepspeed
@@ -346,8 +347,13 @@ class Trainer:
             # }
             # )
 
-            …
-            …
+            parallel_kwargs.pop('activation_checkpointing', None)
+            parallel_kwargs.pop('gradient_clipping', None)
+
+            # For now, run the ref_model with ZeRO stage 0 to work around a training hang.
+            parallel_kwargs["zero_optimization"] = {"stage": 0}
+            # if parallel_kwargs.get("zero_optimization", {}).get("stage", 0) != 3:
+            #     parallel_kwargs["zero_optimization"] = {"stage": 0}
 
         return parallel_kwargs
 
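The stage-0 override is the interesting part of this change: under ZeRO stage 3 every forward pass of a partitioned module requires coordinated all-gathers across ranks, and a frozen reference model that is only ever run for inference can deadlock if ranks fall out of step. Forcing {"stage": 0} keeps a full replica of the gradient-free reference weights on each rank, trading memory for robustness. The end-to-end shape of the helper, mirroring the diff above as a standalone sketch:

import copy

def ref_model_args(train_parallel_kwargs: dict) -> dict:
    # The reference model has no optimizer and no gradients, so
    # training-only options are dropped from the DeepSpeed config.
    kwargs = copy.deepcopy(train_parallel_kwargs)
    kwargs.pop('activation_checkpointing', None)
    kwargs.pop('gradient_clipping', None)
    # Full replica per rank instead of ZeRO partitioning (see hang note above).
    kwargs['zero_optimization'] = {'stage': 0}
    return kwargs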
{project_llm_trainer-0.5.16.dist-info → project_llm_trainer-0.5.17.dist-info}/RECORD
CHANGED
@@ -1,14 +1,14 @@
 llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
 llm_trainer/checkpoint.py,sha256=gz31pZbbQvRTYrBhxV-MFaBAIFeqpe7rM6nFsjwT9lY,4328
 llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
-llm_trainer/dpo_trainer.py,sha256=…
+llm_trainer/dpo_trainer.py,sha256=pNJaXvk-g0lGkZoRhbODNH34hTiz8EdP4Z12ws4W0t8,12309
 llm_trainer/ds_checkpoint.py,sha256=Wzy7PvVVWR794-BW4uragWFTAkkgDvjvkF-qMdyB4fc,2141
 llm_trainer/eval.py,sha256=ZyUfSo2Q8P-lrCdPEnGkoo5pGubd0AabREK5eMISRII,1109
 llm_trainer/generate_utils.py,sha256=wrZoG2g7CsOyG4sb3px9vURHQFV6_9j5kQmpFc5A8yg,15335
-llm_trainer/grpo_trainer.py,sha256=…
+llm_trainer/grpo_trainer.py,sha256=tuzcSi1uBzUPVKojEheJ3-Tx8-g99mf6LYYxC5nsNiw,16040
 llm_trainer/log.py,sha256=LxqTGRNZUGMTSQCePRpk-rYyxSnSIbT4kOdP8Fbzr0M,462
 llm_trainer/loss.py,sha256=NZCQeUXnLSj__mmDflE8g89KgE0emAJXIab0IERCLno,6023
-llm_trainer/parallel.py,sha256=…
+llm_trainer/parallel.py,sha256=G9X0FddIJwd9j-5XOknB4AlBe4G2W6fUCaQH6ycC2Fo,4490
 llm_trainer/parallel_ddp.py,sha256=Pob9vUlBZnkL4oP1Re11kFob7nufMSE96pn7m7fuOEM,1345
 llm_trainer/parallel_ds.py,sha256=oy8RRxHud3rACWubFlJqqd0pjPEQhKeAPGPQUSdJX2c,1145
 llm_trainer/parallel_none.py,sha256=TG6Pm829Dg-yQu-97O-EHV3FCARBlNcP47KkGFAs16E,676
@@ -18,16 +18,16 @@ llm_trainer/sft_trainer.py,sha256=gxQA7T1o1QGUsHp2CX1Qb_fO5LppBJuNbc0H4ixCYUA,17
 llm_trainer/tokenizer.py,sha256=SSpgXtb0e1NtQqRW0gCq09TTZi47umggy-Fh5EMHKJg,6708
 llm_trainer/tools.py,sha256=yF17lp6oOfLe2XJeKDQ1juZcbv-6vFamJSLwEeArduA,2975
 llm_trainer/train_configs.py,sha256=992wy0YhBG2WvxwdLEPL4_-JUl4NkwMPT-jj_BIHo6A,7347
-llm_trainer/trainer.py,sha256=…
+llm_trainer/trainer.py,sha256=Q821nlLDKRZVpaRoiZ7DiJplpAJRRLtvR_33FbClGA0,26729
 llm_trainer/utils.py,sha256=LWNhyQ0NDEZ9mZtk2Ryvh6EulvHIaUGIflugSpqmeFI,6791
-project_llm_trainer-0.5.…
-project_llm_trainer-0.5.…
-project_llm_trainer-0.5.…
-project_llm_trainer-0.5.…
-project_llm_trainer-0.5.…
-project_llm_trainer-0.5.…
-project_llm_trainer-0.5.…
-project_llm_trainer-0.5.…
-project_llm_trainer-0.5.…
-project_llm_trainer-0.5.…
-project_llm_trainer-0.5.…
+project_llm_trainer-0.5.17.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
+project_llm_trainer-0.5.17.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
+project_llm_trainer-0.5.17.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
+project_llm_trainer-0.5.17.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
+project_llm_trainer-0.5.17.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
+project_llm_trainer-0.5.17.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
+project_llm_trainer-0.5.17.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
+project_llm_trainer-0.5.17.dist-info/METADATA,sha256=BVzwe45PQXSE-f5-BCZulqWCK3PIpKzxv9z__moTEJY,196
+project_llm_trainer-0.5.17.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
+project_llm_trainer-0.5.17.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
+project_llm_trainer-0.5.17.dist-info/RECORD,,
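The sha256= values in RECORD are urlsafe base64 encodings of each file's SHA-256 digest with the trailing = padding stripped, per the wheel spec; each entry is path, hash, size in bytes. A quick way to recompute an entry for verification:

import base64
import hashlib

def record_entry(path: str) -> str:
    with open(path, 'rb') as f:
        data = f.read()
    digest = hashlib.sha256(data).digest()
    b64 = base64.urlsafe_b64encode(digest).rstrip(b'=').decode()
    return f'{path},sha256={b64},{len(data)}'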
{project_llm_trainer-0.5.16.data → project_llm_trainer-0.5.17.data}/scripts/calc_intermediate_size
RENAMED
File without changes
{project_llm_trainer-0.5.16.data → project_llm_trainer-0.5.17.data}/scripts/ddp_train
RENAMED
File without changes
{project_llm_trainer-0.5.16.data → project_llm_trainer-0.5.17.data}/scripts/ds_train
RENAMED
File without changes
{project_llm_trainer-0.5.16.data → project_llm_trainer-0.5.17.data}/scripts/plot_loss
RENAMED
File without changes
{project_llm_trainer-0.5.16.data → project_llm_trainer-0.5.17.data}/scripts/plot_lr
RENAMED
File without changes
{project_llm_trainer-0.5.16.data → project_llm_trainer-0.5.17.data}/scripts/py_train
RENAMED
File without changes
{project_llm_trainer-0.5.16.data → project_llm_trainer-0.5.17.data}/scripts/smart_train
RENAMED
File without changes
{project_llm_trainer-0.5.16.dist-info → project_llm_trainer-0.5.17.dist-info}/WHEEL
RENAMED
File without changes
{project_llm_trainer-0.5.16.dist-info → project_llm_trainer-0.5.17.dist-info}/top_level.txt
RENAMED
File without changes