project-llm-trainer 0.4.1__py3-none-any.whl → 0.4.3__py3-none-any.whl

This diff shows the changes between publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.

Potentially problematic release: this version of project-llm-trainer has been flagged as possibly problematic.

llm_trainer/grpo_trainer.py CHANGED
@@ -1,6 +1,7 @@
 from typing import Union, Optional, List
 from contextlib import nullcontext
 import torch
+import torch.distributed as dist
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from llm_model import VlmModel, KVCache
 from .tools import TrainerTools
@@ -14,6 +14,7 @@ from .dataset import GRPORolloutDataset
 from .loss import GRPOLoss
 from .tools import TrainerTools
 from .generate_utils import batch_generate
+from .log import log
 
 from .checkpoint import (
     save_checkpoint,
@@ -46,12 +47,9 @@ class GRPOTrainer(Trainer):
 
     def _init_reference_model(self):
         reference_model = self._new_model(self.train_config)
-
-        device = 'cpu' # TrainerTools().parallel.device
-        reference_model.to(device)
-        # load_checkpoint_for_eval(model=reference_model, device=device)
-
+        reference_model.to('cpu')
         reference_model.eval()
+
         for param in reference_model.parameters():
             param.requires_grad = False
 
@@ -59,17 +57,6 @@ class GRPOTrainer(Trainer):
 
     def _init_generate_model(self):
         return copy.deepcopy(self.reference_model)
-        # generate_model = self._new_model(self.train_config)
-        #
-        # device = 'cpu' #TrainerTools().parallel.device
-        # generate_model.to(device)
-        # # load_checkpoint_for_eval(model=generate_model, device=device)
-        #
-        # generate_model.eval()
-        # for param in generate_model.parameters():
-        #     param.requires_grad = False
-        #
-        # return generate_model
 
     def _init_loss(self):
         criterion = GRPOLoss(
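
Net effect of the two hunks above: 0.4.3 drops the dead checkpoint-loading code, builds the frozen reference policy directly on CPU, and derives the generation model as an independent deep copy of it. A minimal sketch of that pattern, assuming a new_model_fn factory standing in for the trainer's self._new_model(self.train_config) call (function names here are illustrative, not the package's API):

import copy
import torch.nn as nn

def init_frozen_reference(new_model_fn) -> nn.Module:
    # Build the reference policy, park it on CPU until rollout time,
    # and freeze it so autograd never tracks its parameters.
    reference_model = new_model_fn()
    reference_model.to('cpu')
    reference_model.eval()
    for param in reference_model.parameters():
        param.requires_grad = False
    return reference_model

def init_generate_model(reference_model: nn.Module) -> nn.Module:
    # The generation model is an independent frozen copy; its weights are
    # refreshed from the training model before each rollout (see below).
    return copy.deepcopy(reference_model)
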
@@ -194,7 +181,6 @@ class GRPOTrainer(Trainer):
 
         # [batch*group_size, max_prompt_len+max_gen_len]
         outputs: torch.Tensor = batch_generate(
-            # model=self.train_model,
             model=self.generate_model,
             tokens=prompt_ids,
             pad_token_id=pad_token_id,
@@ -325,10 +311,14 @@
             self.generate_model.to(device)
             self.reference_model.to(device)
 
-            # After the train_model checkpoint is saved, this keeps the generate model's parameters up to date
-            copy_model_params(_from=self.train_model, _to=self.generate_model)
+            if TrainerTools().parallel.is_main_process:
+                log(f'start generate for batch {batch}/{batch_count_per_file}')
+
             # generate rollout data
-            rollout_data = self._generate_rollout_data(batch_data)
+            with torch.no_grad():
+                # after the train_model checkpoint is saved, this keeps the generate model's parameters up to date
+                copy_model_params(_from=self.train_model, _to=self.generate_model)
+                rollout_data = self._generate_rollout_data(batch_data)
 
             # offload to CPU; move back to GPU on next use
             self.generate_model.to('cpu')
@@ -337,6 +327,9 @@
             # end generate
 
             try:
+                if TrainerTools().parallel.is_main_process:
+                    log(f'start train for batch {batch}/{batch_count_per_file}')
+
                 for grpo_step in range(self.train_config.grpo_config.grpo_steps):
                     with self.ctx:
                         loss, aux_loss = self._maximize_grpo_objective(rollout_data)
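
Taken together, the 0.4.3 rollout path now moves the frozen models to the GPU, logs from the main process only, syncs the latest trained weights into the generation model, samples rollouts inside torch.no_grad() so no autograd graph is built, and offloads back to CPU before the GRPO optimization steps. A condensed sketch of that flow, reusing the attribute and helper names shown in the diff (the wrapper function itself is illustrative, and copy_model_params is assumed to be in scope from the package's checkpoint helpers):

import torch

def rollout_step(trainer, batch_data, batch, batch_count_per_file, device):
    # Rollout models live on CPU between batches; bring them up for generation.
    trainer.generate_model.to(device)
    trainer.reference_model.to(device)

    if TrainerTools().parallel.is_main_process:
        log(f'start generate for batch {batch}/{batch_count_per_file}')

    # Sampling needs no gradients: sync the newest trained weights into the
    # generation model, then generate rollouts without tracking autograd state.
    with torch.no_grad():
        copy_model_params(_from=trainer.train_model, _to=trainer.generate_model)
        rollout_data = trainer._generate_rollout_data(batch_data)

    # Offload again so GPU memory is free for the GRPO optimization steps.
    trainer.generate_model.to('cpu')
    return rollout_data
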
llm_trainer/trainer.py CHANGED
@@ -136,7 +136,7 @@ class Trainer:
         # freeze llm model for vlm training
         if self.train_config.freeze_llm_model:
             for name, param in model.named_parameters():
-                if not any(sub_module in name for sub_module in ['vision_tower', 'multi_modal_projector']):
+                if not any(sub_module in name for sub_module in ['multi_modal_projector']):
                     param.requires_grad = False
 
         model.embed_tokens.eval()
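
This one-line change tightens freeze_llm_model: in 0.4.1 both 'vision_tower' and 'multi_modal_projector' parameters stayed trainable, while in 0.4.3 only the projector does, so the vision tower is now frozen along with the language model. A minimal sketch of the resulting behavior (the function name and constant are illustrative):

import torch.nn as nn

# As of 0.4.3, only projector parameters remain trainable under freeze_llm_model.
TRAINABLE_SUBMODULES = ['multi_modal_projector']

def freeze_for_vlm_training(model: nn.Module) -> None:
    # Freeze every parameter whose qualified name does not mention a trainable
    # submodule; with this list, the vision tower is frozen too.
    for name, param in model.named_parameters():
        if not any(sub in name for sub in TRAINABLE_SUBMODULES):
            param.requires_grad = False
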
project_llm_trainer-0.4.1.dist-info/METADATA → project_llm_trainer-0.4.3.dist-info/METADATA RENAMED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: project_llm_trainer
-Version: 0.4.1
+Version: 0.4.3
 Summary: LLM and VLM trainer
 Author: qibin
 Author-email: qibin0506@gmail.com
project_llm_trainer-0.4.1.dist-info/RECORD → project_llm_trainer-0.4.3.dist-info/RECORD RENAMED
@@ -6,8 +6,8 @@ llm_trainer/dpo_trainer.py,sha256=rC_I5ipesSlP3gFK_SG2GB8NbgJAMu4K7KLxkAS-aRY,13
 llm_trainer/ds_checkpoint.py,sha256=nchGocJE2oJnQ_KNN1kw-BkOAEIyTtO8SJt41cuN_xM,4232
 llm_trainer/eval.py,sha256=NDm8PbXLch7xT81xPYPRCNrcrB_Xj5GDJSCxyVwUOp4,1524
 llm_trainer/fsdp_checkpoint.py,sha256=lqZFzHyWyfzuCq_81kQNtJd2qaiMeY1N5BCEMnrJTBw,3192
-llm_trainer/generate_utils.py,sha256=4iM0vyc_1C_iTL31GlS9PR4eZtYaELPRZ02KDSPZA9U,15158
-llm_trainer/grpo_trainer.py,sha256=fqLT48ORSCece_e8dpyt8J7EarDuTnGoJ_eHk7Oy-1k,16177
+llm_trainer/generate_utils.py,sha256=RpAIjN0fvyTkMk9b9x7YE6c5GiiE3x5YGyPaa4R_BjA,15191
+llm_trainer/grpo_trainer.py,sha256=bZPrxhyPQLAnFzWhI7hhA6fpuKVNwj7nOm9k0ku9aK4,15977
 llm_trainer/log.py,sha256=LxqTGRNZUGMTSQCePRpk-rYyxSnSIbT4kOdP8Fbzr0M,462
 llm_trainer/loss.py,sha256=Yv3fsaVuZ5AhnGPJOr5vEMb_tM2urR6mCb4DBbrHHI8,6030
 llm_trainer/parallel.py,sha256=DQu8GqEFxD99HQ6hKuIxxyKi-05dMO33eMhImYlPuOI,4468
@@ -20,16 +20,16 @@ llm_trainer/sft_trainer.py,sha256=gxQA7T1o1QGUsHp2CX1Qb_fO5LppBJuNbc0H4ixCYUA,17
 llm_trainer/tokenizer.py,sha256=A7TYYUbtPf75kjCvWP7yBui4xZBObMk2aPem62YpwpY,6776
 llm_trainer/tools.py,sha256=O45-20wRmh-nyTfU-U-XtjbKAoe7boEIsUvWT_NaKx4,3041
 llm_trainer/train_configs.py,sha256=arnet3tIzgVnwshod08F1jE7r4I7e-SIgMy55IagPnE,15971
-llm_trainer/trainer.py,sha256=hOn-z8kOd67RTuaaNMmdQjlw7N5LIZRHjSt5frpA1xI,25355
+llm_trainer/trainer.py,sha256=aoZYL5U4Z5axXBMM_DHgzIzJ89YbU9xUQ56jppcT65c,25339
 llm_trainer/utils.py,sha256=-ivhMF0d999va13S1wt2uBvtVw8Nvr3uBzhaUFKL04Q,6826
-project_llm_trainer-0.4.1.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
-project_llm_trainer-0.4.1.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
-project_llm_trainer-0.4.1.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
-project_llm_trainer-0.4.1.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
-project_llm_trainer-0.4.1.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
-project_llm_trainer-0.4.1.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
-project_llm_trainer-0.4.1.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
-project_llm_trainer-0.4.1.dist-info/METADATA,sha256=9z1AB745r7BzQHNc3j-3N2nOdB9ZRUYsxcM42QoSb1o,195
-project_llm_trainer-0.4.1.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
-project_llm_trainer-0.4.1.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
-project_llm_trainer-0.4.1.dist-info/RECORD,,
+project_llm_trainer-0.4.3.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
+project_llm_trainer-0.4.3.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
+project_llm_trainer-0.4.3.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
+project_llm_trainer-0.4.3.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
+project_llm_trainer-0.4.3.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
+project_llm_trainer-0.4.3.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
+project_llm_trainer-0.4.3.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
+project_llm_trainer-0.4.3.dist-info/METADATA,sha256=kmmc6L6SE9iBvNHutWpeb0TocGX5vixhvHHLS4ltqec,195
+project_llm_trainer-0.4.3.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
+project_llm_trainer-0.4.3.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
+project_llm_trainer-0.4.3.dist-info/RECORD,,