project-llm-trainer 0.4.1__py3-none-any.whl → 0.4.3__py3-none-any.whl
This diff shows the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
Potentially problematic release.
This version of project-llm-trainer might be problematic. Click here for more details.
- llm_trainer/generate_utils.py +1 -0
- llm_trainer/grpo_trainer.py +13 -20
- llm_trainer/trainer.py +1 -1
- {project_llm_trainer-0.4.1.dist-info → project_llm_trainer-0.4.3.dist-info}/METADATA +1 -1
- {project_llm_trainer-0.4.1.dist-info → project_llm_trainer-0.4.3.dist-info}/RECORD +14 -14
- {project_llm_trainer-0.4.1.data → project_llm_trainer-0.4.3.data}/scripts/calc_intermediate_size +0 -0
- {project_llm_trainer-0.4.1.data → project_llm_trainer-0.4.3.data}/scripts/ddp_train +0 -0
- {project_llm_trainer-0.4.1.data → project_llm_trainer-0.4.3.data}/scripts/ds_train +0 -0
- {project_llm_trainer-0.4.1.data → project_llm_trainer-0.4.3.data}/scripts/plot_loss +0 -0
- {project_llm_trainer-0.4.1.data → project_llm_trainer-0.4.3.data}/scripts/plot_lr +0 -0
- {project_llm_trainer-0.4.1.data → project_llm_trainer-0.4.3.data}/scripts/py_train +0 -0
- {project_llm_trainer-0.4.1.data → project_llm_trainer-0.4.3.data}/scripts/smart_train +0 -0
- {project_llm_trainer-0.4.1.dist-info → project_llm_trainer-0.4.3.dist-info}/WHEEL +0 -0
- {project_llm_trainer-0.4.1.dist-info → project_llm_trainer-0.4.3.dist-info}/top_level.txt +0 -0
llm_trainer/generate_utils.py
CHANGED
|
@@ -1,6 +1,7 @@
|
|
|
1
1
|
from typing import Union, Optional, List
|
|
2
2
|
from contextlib import nullcontext
|
|
3
3
|
import torch
|
|
4
|
+
import torch.distributed as dist
|
|
4
5
|
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
|
5
6
|
from llm_model import VlmModel, KVCache
|
|
6
7
|
from .tools import TrainerTools
|
llm_trainer/grpo_trainer.py
CHANGED
|
@@ -14,6 +14,7 @@ from .dataset import GRPORolloutDataset
|
|
|
14
14
|
from .loss import GRPOLoss
|
|
15
15
|
from .tools import TrainerTools
|
|
16
16
|
from .generate_utils import batch_generate
|
|
17
|
+
from .log import log
|
|
17
18
|
|
|
18
19
|
from .checkpoint import (
|
|
19
20
|
save_checkpoint,
|
|
@@ -46,12 +47,9 @@ class GRPOTrainer(Trainer):
|
|
|
46
47
|
|
|
47
48
|
def _init_reference_model(self):
|
|
48
49
|
reference_model = self._new_model(self.train_config)
|
|
49
|
-
|
|
50
|
-
device = 'cpu' # TrainerTools().parallel.device
|
|
51
|
-
reference_model.to(device)
|
|
52
|
-
# load_checkpoint_for_eval(model=reference_model, device=device)
|
|
53
|
-
|
|
50
|
+
reference_model.to('cpu')
|
|
54
51
|
reference_model.eval()
|
|
52
|
+
|
|
55
53
|
for param in reference_model.parameters():
|
|
56
54
|
param.requires_grad = False
|
|
57
55
|
|
|
@@ -59,17 +57,6 @@ class GRPOTrainer(Trainer):
|
|
|
59
57
|
|
|
60
58
|
def _init_generate_model(self):
|
|
61
59
|
return copy.deepcopy(self.reference_model)
|
|
62
|
-
# generate_model = self._new_model(self.train_config)
|
|
63
|
-
#
|
|
64
|
-
# device = 'cpu' #TrainerTools().parallel.device
|
|
65
|
-
# generate_model.to(device)
|
|
66
|
-
# # load_checkpoint_for_eval(model=generate_model, device=device)
|
|
67
|
-
#
|
|
68
|
-
# generate_model.eval()
|
|
69
|
-
# for param in generate_model.parameters():
|
|
70
|
-
# param.requires_grad = False
|
|
71
|
-
#
|
|
72
|
-
# return generate_model
|
|
73
60
|
|
|
74
61
|
def _init_loss(self):
|
|
75
62
|
criterion = GRPOLoss(
|
|
@@ -194,7 +181,6 @@ class GRPOTrainer(Trainer):
|
|
|
194
181
|
|
|
195
182
|
# [batch*group_size, max_prompt_len+max_gen_len]
|
|
196
183
|
outputs: torch.Tensor = batch_generate(
|
|
197
|
-
# model=self.train_model,
|
|
198
184
|
model=self.generate_model,
|
|
199
185
|
tokens=prompt_ids,
|
|
200
186
|
pad_token_id=pad_token_id,
|
|
@@ -325,10 +311,14 @@ class GRPOTrainer(Trainer):
|
|
|
325
311
|
self.generate_model.to(device)
|
|
326
312
|
self.reference_model.to(device)
|
|
327
313
|
|
|
328
|
-
|
|
329
|
-
|
|
314
|
+
if TrainerTools().parallel.is_main_process:
|
|
315
|
+
log(f'start generate for batch {batch}/{batch_count_per_file}')
|
|
316
|
+
|
|
330
317
|
# 生成数据
|
|
331
|
-
|
|
318
|
+
with torch.no_grad():
|
|
319
|
+
# 保存了train_model checkpoint后,这里保证生成模型使用的参数是最新
|
|
320
|
+
copy_model_params(_from=self.train_model, _to=self.generate_model)
|
|
321
|
+
rollout_data = self._generate_rollout_data(batch_data)
|
|
332
322
|
|
|
333
323
|
# 卸载到cpu上,等待下次使用时再to gpu
|
|
334
324
|
self.generate_model.to('cpu')
|
|
@@ -337,6 +327,9 @@ class GRPOTrainer(Trainer):
|
|
|
337
327
|
# end generate
|
|
338
328
|
|
|
339
329
|
try:
|
|
330
|
+
if TrainerTools().parallel.is_main_process:
|
|
331
|
+
log(f'start train for batch {batch}/{batch_count_per_file}')
|
|
332
|
+
|
|
340
333
|
for grpo_step in range(self.train_config.grpo_config.grpo_steps):
|
|
341
334
|
with self.ctx:
|
|
342
335
|
loss, aux_loss = self._maximize_grpo_objective(rollout_data)
|
llm_trainer/trainer.py
CHANGED
|
@@ -136,7 +136,7 @@ class Trainer:
|
|
|
136
136
|
# freeze llm model for vlm training
|
|
137
137
|
if self.train_config.freeze_llm_model:
|
|
138
138
|
for name, param in model.named_parameters():
|
|
139
|
-
if not any(sub_module in name for sub_module in ['
|
|
139
|
+
if not any(sub_module in name for sub_module in ['multi_modal_projector']):
|
|
140
140
|
param.requires_grad = False
|
|
141
141
|
|
|
142
142
|
model.embed_tokens.eval()
|
{project_llm_trainer-0.4.1.dist-info → project_llm_trainer-0.4.3.dist-info}/RECORD
CHANGED
|
@@ -6,8 +6,8 @@ llm_trainer/dpo_trainer.py,sha256=rC_I5ipesSlP3gFK_SG2GB8NbgJAMu4K7KLxkAS-aRY,13
|
|
|
6
6
|
llm_trainer/ds_checkpoint.py,sha256=nchGocJE2oJnQ_KNN1kw-BkOAEIyTtO8SJt41cuN_xM,4232
|
|
7
7
|
llm_trainer/eval.py,sha256=NDm8PbXLch7xT81xPYPRCNrcrB_Xj5GDJSCxyVwUOp4,1524
|
|
8
8
|
llm_trainer/fsdp_checkpoint.py,sha256=lqZFzHyWyfzuCq_81kQNtJd2qaiMeY1N5BCEMnrJTBw,3192
|
|
9
|
-
llm_trainer/generate_utils.py,sha256=
|
|
10
|
-
llm_trainer/grpo_trainer.py,sha256=
|
|
9
|
+
llm_trainer/generate_utils.py,sha256=RpAIjN0fvyTkMk9b9x7YE6c5GiiE3x5YGyPaa4R_BjA,15191
|
|
10
|
+
llm_trainer/grpo_trainer.py,sha256=bZPrxhyPQLAnFzWhI7hhA6fpuKVNwj7nOm9k0ku9aK4,15977
|
|
11
11
|
llm_trainer/log.py,sha256=LxqTGRNZUGMTSQCePRpk-rYyxSnSIbT4kOdP8Fbzr0M,462
|
|
12
12
|
llm_trainer/loss.py,sha256=Yv3fsaVuZ5AhnGPJOr5vEMb_tM2urR6mCb4DBbrHHI8,6030
|
|
13
13
|
llm_trainer/parallel.py,sha256=DQu8GqEFxD99HQ6hKuIxxyKi-05dMO33eMhImYlPuOI,4468
|
|
@@ -20,16 +20,16 @@ llm_trainer/sft_trainer.py,sha256=gxQA7T1o1QGUsHp2CX1Qb_fO5LppBJuNbc0H4ixCYUA,17
|
|
|
20
20
|
llm_trainer/tokenizer.py,sha256=A7TYYUbtPf75kjCvWP7yBui4xZBObMk2aPem62YpwpY,6776
|
|
21
21
|
llm_trainer/tools.py,sha256=O45-20wRmh-nyTfU-U-XtjbKAoe7boEIsUvWT_NaKx4,3041
|
|
22
22
|
llm_trainer/train_configs.py,sha256=arnet3tIzgVnwshod08F1jE7r4I7e-SIgMy55IagPnE,15971
|
|
23
|
-
llm_trainer/trainer.py,sha256=
|
|
23
|
+
llm_trainer/trainer.py,sha256=aoZYL5U4Z5axXBMM_DHgzIzJ89YbU9xUQ56jppcT65c,25339
|
|
24
24
|
llm_trainer/utils.py,sha256=-ivhMF0d999va13S1wt2uBvtVw8Nvr3uBzhaUFKL04Q,6826
|
|
25
|
-
project_llm_trainer-0.4.
|
|
26
|
-
project_llm_trainer-0.4.
|
|
27
|
-
project_llm_trainer-0.4.
|
|
28
|
-
project_llm_trainer-0.4.
|
|
29
|
-
project_llm_trainer-0.4.
|
|
30
|
-
project_llm_trainer-0.4.
|
|
31
|
-
project_llm_trainer-0.4.
|
|
32
|
-
project_llm_trainer-0.4.
|
|
33
|
-
project_llm_trainer-0.4.
|
|
34
|
-
project_llm_trainer-0.4.
|
|
35
|
-
project_llm_trainer-0.4.
|
|
25
|
+
project_llm_trainer-0.4.3.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
|
|
26
|
+
project_llm_trainer-0.4.3.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
|
|
27
|
+
project_llm_trainer-0.4.3.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
|
|
28
|
+
project_llm_trainer-0.4.3.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
|
|
29
|
+
project_llm_trainer-0.4.3.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
|
|
30
|
+
project_llm_trainer-0.4.3.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
|
|
31
|
+
project_llm_trainer-0.4.3.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
|
|
32
|
+
project_llm_trainer-0.4.3.dist-info/METADATA,sha256=kmmc6L6SE9iBvNHutWpeb0TocGX5vixhvHHLS4ltqec,195
|
|
33
|
+
project_llm_trainer-0.4.3.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
|
|
34
|
+
project_llm_trainer-0.4.3.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
|
|
35
|
+
project_llm_trainer-0.4.3.dist-info/RECORD,,
|
{project_llm_trainer-0.4.1.data → project_llm_trainer-0.4.3.data}/scripts/calc_intermediate_size
RENAMED
|
File without changes
|
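{project_llm_trainer-0.4.1.data → project_llm_trainer-0.4.3.data}/scripts/ddp_train
RENAMED
|
File without changes
|
{project_llm_trainer-0.4.1.data → project_llm_trainer-0.4.3.data}/scripts/ds_train
RENAMED
|
File without changes
|
{project_llm_trainer-0.4.1.data → project_llm_trainer-0.4.3.data}/scripts/plot_loss
RENAMED
|
File without changes
|
{project_llm_trainer-0.4.1.data → project_llm_trainer-0.4.3.data}/scripts/plot_lr
RENAMED
|
File without changes
|
{project_llm_trainer-0.4.1.data → project_llm_trainer-0.4.3.data}/scripts/py_train
RENAMED
|
File without changes
|
{project_llm_trainer-0.4.1.data → project_llm_trainer-0.4.3.data}/scripts/smart_train
RENAMED
|
File without changes
|
{project_llm_trainer-0.4.1.dist-info → project_llm_trainer-0.4.3.dist-info}/WHEEL
RENAMED
|
File without changes
|
{project_llm_trainer-0.4.1.dist-info → project_llm_trainer-0.4.3.dist-info}/top_level.txt
RENAMED
|
File without changes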
|