project-llm-trainer 0.3.4__py3-none-any.whl → 0.3.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of project-llm-trainer might be problematic.
@@ -69,19 +69,29 @@ class DPOTrainer(Trainer):
             zero_optimization = {'stage': 0}
             parallel_kwargs['zero_optimization'] = zero_optimization
 
-            if self.train_config.ds_config.fp16_config:
+
+            if (self.train_config.ds_config.bf16_config is not None
+                    and self.train_config.ds_config.bf16_config.enabled):
+                bf16_config = self.train_config.ds_config.bf16_config
+                bf16 = {
+                    'enabled': bf16_config.enabled
+                }
+                parallel_kwargs['bf16'] = bf16
+            elif self.train_config.ds_config.fp16_config:
                 fb16_config = self.train_config.ds_config.fp16_config
-                fp16 = { 'enabled': fb16_config.enabled }
+                fp16 = {
+                    'enabled': fb16_config.enabled,
+                    'loss_scale': fb16_config.loss_scale,
+                    'loss_scale_window': fb16_config.loss_scale_window,
+                    'initial_scale_power': fb16_config.initial_scale_power,
+                    'hysteresis': fb16_config.hysteresis,
+                    'min_loss_scale': fb16_config.min_loss_scale
+                }
 
                 if fb16_config.fp16_opt_level is not None:
                     fp16['fp16_opt_level'] = fb16_config.fp16_opt_level
 
                 parallel_kwargs['fp16'] = fp16
-
-            if self.train_config.ds_config.bf16_config:
-                bf16_config = self.train_config.ds_config.bf16_config
-                bf16 = { 'enabled': bf16_config.enabled }
-                parallel_kwargs['bf16'] = bf16
         elif isinstance(TrainerTools().parallel, FsdpParallel) and self.train_config.fsdp_config:
             parallel_kwargs = {
                 'transformer_layer_cls': self.train_config.fsdp_config.transformer_layer_cls,
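
This hunk changes the DeepSpeed precision handling in two ways: an enabled bf16 config now takes precedence over fp16 (previously both sections could be emitted when both configs were present), and the fp16 section forwards the full set of loss-scaling options instead of only the enabled flag. A minimal sketch of the resulting kwargs, assuming ds_config exposes bf16_config/fp16_config objects with the attributes shown in the diff (the helper itself is illustrative):

```python
# Sketch only: mirrors the branch added in 0.3.6. The ds_config attribute
# names come from the diff; the helper itself is illustrative.
def build_precision_kwargs(ds_config) -> dict:
    parallel_kwargs = {}
    if ds_config.bf16_config is not None and ds_config.bf16_config.enabled:
        # bf16 takes precedence and needs no loss-scaling knobs
        parallel_kwargs['bf16'] = {'enabled': True}
    elif ds_config.fp16_config:
        c = ds_config.fp16_config
        parallel_kwargs['fp16'] = {
            'enabled': c.enabled,
            'loss_scale': c.loss_scale,              # 0 selects dynamic loss scaling in DeepSpeed
            'loss_scale_window': c.loss_scale_window,
            'initial_scale_power': c.initial_scale_power,
            'hysteresis': c.hysteresis,
            'min_loss_scale': c.min_loss_scale,
        }
        if c.fp16_opt_level is not None:
            parallel_kwargs['fp16']['fp16_opt_level'] = c.fp16_opt_level
    return parallel_kwargs
```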
@@ -257,7 +267,7 @@ class DPOTrainer(Trainer):
         if not isinstance(TrainerTools().parallel, DsParallel) and self.lr_scheduler.can_clip_grad():
             # clip grad
             self.scalar.unscale_(self.optimizer)
-            torch.nn.utils.clip_grad_norm_(self.train_model.parameters(), 1.0)
+            torch.nn.utils.clip_grad_norm_(self._get_trainable_params(self.train_model), 1.0)
 
         self._step()
 
@@ -355,7 +355,7 @@ class GRPOTrainer(Trainer):
         if not isinstance(TrainerTools().parallel, DsParallel) and self.lr_scheduler.can_clip_grad():
             # clip grad
             self.scalar.unscale_(self.optimizer)
-            torch.nn.utils.clip_grad_norm_(self.train_model.parameters(), 1.0)
+            torch.nn.utils.clip_grad_norm_(self._get_trainable_params(self.train_model), 1.0)
 
         self._step()
 
@@ -422,7 +422,8 @@ class TrainConfig:
             kd_config: Optional[KDConfig] = None,
             pixel_values_provider: Optional[Callable[[list[str]], torch.Tensor]] = None,
             init_state_dict: Optional[Mapping[str, Any]] = None,
-            eval_config: EvalConfig = EvalConfig()
+            eval_config: EvalConfig = EvalConfig(),
+            freeze_llm_model: bool = False
     ):
         self.n_epochs = n_epochs
         self.batch_size = batch_size
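
The new freeze_llm_model flag defaults to False, so existing configurations are unaffected. A minimal, hypothetical usage sketch (only n_epochs, batch_size and freeze_llm_model appear in the diff; any other required arguments follow whatever your existing 0.3.x setup already passes):

```python
# Hypothetical usage: only n_epochs, batch_size and freeze_llm_model are taken
# from the diff; remaining fields are left to their defaults or your own setup.
train_config = TrainConfig(
    n_epochs=1,
    batch_size=8,
    freeze_llm_model=True,  # new in 0.3.6: keep the language model frozen during VLM training
)
```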
@@ -443,5 +444,6 @@ class TrainConfig:
         self.pixel_values_provider = pixel_values_provider
         self.init_state_dict = init_state_dict
         self.eval_config = eval_config
+        self.freeze_llm_model = freeze_llm_model
 
 
llm_trainer/trainer.py CHANGED
@@ -116,6 +116,10 @@ class Trainer:
         else:
             return LlmModel(train_config.model_config)
 
+    def _get_trainable_params(self, model):
+        freeze_llm_model = self.train_config.freeze_llm_model
+        return model.parameters() if not freeze_llm_model else filter(lambda p: p.requires_grad, model.parameters())
+
     def _init_train_model_and_optim(
             self,
             initial_lr: float,
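
_get_trainable_params returns every parameter when freeze_llm_model is off and only parameters with requires_grad=True when it is on, so the optimizer and the gradient-clipping calls below operate on the same reduced set. A standalone sketch of that filtering on a toy module (the module is purely illustrative):

```python
import torch.nn as nn

def get_trainable_params(model: nn.Module, freeze_llm_model: bool):
    # Same selection logic as Trainer._get_trainable_params in 0.3.6
    if not freeze_llm_model:
        return model.parameters()
    return filter(lambda p: p.requires_grad, model.parameters())

toy = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 2))  # illustrative stand-in model
for p in toy[0].parameters():
    p.requires_grad = False  # pretend the first block is the frozen LLM part

selected = list(get_trainable_params(toy, freeze_llm_model=True))
print(sum(p.numel() for p in selected))  # 10: only the second Linear's weight and bias
```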
@@ -128,10 +132,24 @@ class Trainer:
             model.load_state_dict(self.train_config.init_state_dict, strict=False)
             self.train_config.init_state_dict = None
 
+        # freeze llm model for vlm training
+        if self.train_config.freeze_llm_model:
+            for name, param in model.named_parameters():
+                if not any(sub_module in name for sub_module in ['vision_tower', 'multi_modal_projector']):
+                    param.requires_grad = False
+
+            model.embed_tokens.eval()
+            model.layers.eval()
+            model.head_norm.eval()
+            model.lm_head.eval()
+
         if TrainerTools().parallel.is_main_process:
             total_params = sum(p.numel() for p in model.parameters())
             log(f"Total number of parameters: {total_params:,}")
 
+            trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+            log(f"Trainable number of parameters: {trainable_params:,}")
+
             total_size_bytes = total_params * 4
             total_size_mb = total_size_bytes / (1024 * 1024)
             log(f"Total size of the model: {total_size_mb:.2f} MB")
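
The freeze step keeps only parameters whose names contain vision_tower or multi_modal_projector trainable and switches the language-model submodules to eval mode, so dropout-style layers behave deterministically while frozen; the new log line then reports how many parameters remain trainable. A toy illustration of the same name-based filtering (module shapes here are made up; only the pattern matches the diff):

```python
import torch.nn as nn

class ToyVlm(nn.Module):
    # Illustrative stand-in; the real model exposes vision_tower,
    # multi_modal_projector, embed_tokens, layers, head_norm and lm_head.
    def __init__(self):
        super().__init__()
        self.vision_tower = nn.Linear(8, 8)
        self.multi_modal_projector = nn.Linear(8, 8)
        self.embed_tokens = nn.Embedding(100, 8)
        self.lm_head = nn.Linear(8, 100)

model = ToyVlm()
for name, param in model.named_parameters():
    # keep only the vision tower and projector trainable, as in the 0.3.6 diff
    if not any(sub in name for sub in ['vision_tower', 'multi_modal_projector']):
        param.requires_grad = False

model.embed_tokens.eval()  # frozen submodules are also switched to eval mode
model.lm_head.eval()

trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable / total parameters: {trainable:,} / {total:,}")
```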
@@ -139,13 +157,13 @@ class Trainer:
         if use_ds_optim:
             import deepspeed
             origin_optim = deepspeed.ops.adam.DeepSpeedCPUAdam(
-                model.parameters(),
+                self._get_trainable_params(model),
                 lr=initial_lr,
                 weight_decay=self.train_config.lr_config.weight_decay
             )
         else:
             origin_optim = torch.optim.AdamW(
-                model.parameters(),
+                self._get_trainable_params(model),
                 lr=initial_lr,
                 weight_decay=self.train_config.lr_config.weight_decay
             )
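
Both optimizer constructors now receive the filtered iterator, and the clip_grad_norm_ calls in the trainers do the same. torch.nn.utils.clip_grad_norm_ already ignores parameters whose .grad is None, so the practical effect is consistency: clipping operates on exactly the set the optimizer updates. A short, self-contained sketch of the pattern; the Linear layer and the GradScaler here are illustrative stand-ins for the trainer's train_model and self.scalar:

```python
import torch
import torch.nn as nn

model = nn.Linear(4, 4)              # illustrative stand-in for train_model
model.bias.requires_grad = False     # pretend part of the model is frozen

trainable = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.AdamW(trainable, lr=1e-4)
scaler = torch.cuda.amp.GradScaler(enabled=False)  # stands in for self.scalar

loss = model(torch.randn(2, 4)).sum()
scaler.scale(loss).backward()
scaler.unscale_(optimizer)                               # undo loss scaling before clipping
torch.nn.utils.clip_grad_norm_(trainable, max_norm=1.0)  # clip only what the optimizer updates
scaler.step(optimizer)
scaler.update()
```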
@@ -529,7 +547,7 @@ class Trainer:
         if not isinstance(TrainerTools().parallel, DsParallel) and self.lr_scheduler.can_clip_grad():
             # clip grad
             self.scalar.unscale_(self.optimizer)
-            torch.nn.utils.clip_grad_norm_(self.train_model.parameters(), 1.0)
+            torch.nn.utils.clip_grad_norm_(self._get_trainable_params(self.train_model), 1.0)
 
         self._step()
 
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: project_llm_trainer
-Version: 0.3.4
+Version: 0.3.6
 Summary: LLM and VLM trainer
 Author: qibin
 Author-email: qibin0506@gmail.com
@@ -2,11 +2,11 @@ llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
 llm_trainer/checkpoint.py,sha256=Dlkcit0o7Gx6S9QUrIrVp2pTurP9X0zVA7w7ImSuVQU,6049
 llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
 llm_trainer/dcp.py,sha256=PkD97DyrOtoTKn4FJsfL3VqAy4dxufgjdzJEz8-Cnoc,3635
-llm_trainer/dpo_trainer.py,sha256=7Bf6snWcu2fT8QRDI1CSzmrc7Cog6JauIeK2KoW_f8I,13135
+llm_trainer/dpo_trainer.py,sha256=rEhoVN4gPweX5NYKZaEH7jgWav4w6OQ2x-QRocahYjg,13640
 llm_trainer/ds_checkpoint.py,sha256=_svpzqRaa43--DKPputoXAelc6X9vPM0gNQu-hlh6NI,2153
 llm_trainer/eval.py,sha256=sCvdYnqWWf5_nuDQN5BHb_YivXLOQW-V0ET9mPu0tPU,2389
 llm_trainer/generate_utils.py,sha256=4iM0vyc_1C_iTL31GlS9PR4eZtYaELPRZ02KDSPZA9U,15158
-llm_trainer/grpo_trainer.py,sha256=M6vp6QjxhBQVaw3e_3BJ4earuezQNKQ3JeZfQLBaSLQ,16370
+llm_trainer/grpo_trainer.py,sha256=1oH0argbpITlzAEkGKW8F9kZPr67bcb95FGOVpP8XTM,16385
 llm_trainer/log.py,sha256=LxqTGRNZUGMTSQCePRpk-rYyxSnSIbT4kOdP8Fbzr0M,462
 llm_trainer/loss.py,sha256=Yv3fsaVuZ5AhnGPJOr5vEMb_tM2urR6mCb4DBbrHHI8,6030
 llm_trainer/parallel.py,sha256=2VJtW3Gq2c1yS_LdcrNhk7B12prFwBmFnKhvV8FS2d8,4428
@@ -18,17 +18,17 @@ llm_trainer/scheduler.py,sha256=Xz8HhwoRMjRe41sf_NHhpZfkTlEs0I2MYusvMY6hCVw,3531
 llm_trainer/sft_trainer.py,sha256=gxQA7T1o1QGUsHp2CX1Qb_fO5LppBJuNbc0H4ixCYUA,1783
 llm_trainer/tokenizer.py,sha256=A7TYYUbtPf75kjCvWP7yBui4xZBObMk2aPem62YpwpY,6776
 llm_trainer/tools.py,sha256=AhfjN9oln5Pyif1SgCWwgQg-Q5acTCd9xpz4L26QUjA,3039
-llm_trainer/train_configs.py,sha256=cadfo8RgxNUR-L3ZLyjiRXTQvhjUl4A1qENaq-ol8h4,15878
-llm_trainer/trainer.py,sha256=5DgDzg0TReZrXsIaM6A4DzeJnzePNybGdfoVSDybQ2U,24308
+llm_trainer/train_configs.py,sha256=arnet3tIzgVnwshod08F1jE7r4I7e-SIgMy55IagPnE,15971
+llm_trainer/trainer.py,sha256=2cO-MwWJsgPbTisOp_HVIdA0SVodFZx3M8lafarnLdw,25188
 llm_trainer/utils.py,sha256=-ivhMF0d999va13S1wt2uBvtVw8Nvr3uBzhaUFKL04Q,6826
-project_llm_trainer-0.3.4.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
-project_llm_trainer-0.3.4.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
-project_llm_trainer-0.3.4.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
-project_llm_trainer-0.3.4.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
-project_llm_trainer-0.3.4.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
-project_llm_trainer-0.3.4.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
-project_llm_trainer-0.3.4.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
-project_llm_trainer-0.3.4.dist-info/METADATA,sha256=Y8XjOGdQb7VxN5QKHyKICkkOzjGcXJuI6hPziULJNfc,195
-project_llm_trainer-0.3.4.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
-project_llm_trainer-0.3.4.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
-project_llm_trainer-0.3.4.dist-info/RECORD,,
+project_llm_trainer-0.3.6.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
+project_llm_trainer-0.3.6.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
+project_llm_trainer-0.3.6.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
+project_llm_trainer-0.3.6.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
+project_llm_trainer-0.3.6.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
+project_llm_trainer-0.3.6.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
+project_llm_trainer-0.3.6.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
+project_llm_trainer-0.3.6.dist-info/METADATA,sha256=1ClKvVThd4g8uToJQevDXmjAI8gbVYzDfYImWXHFRqI,195
+project_llm_trainer-0.3.6.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
+project_llm_trainer-0.3.6.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
+project_llm_trainer-0.3.6.dist-info/RECORD,,