project-llm-trainer 0.3.4__py3-none-any.whl → 0.3.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of project-llm-trainer might be problematic.
- llm_trainer/dpo_trainer.py +18 -8
- llm_trainer/grpo_trainer.py +1 -1
- llm_trainer/train_configs.py +3 -1
- llm_trainer/trainer.py +21 -3
- {project_llm_trainer-0.3.4.dist-info → project_llm_trainer-0.3.6.dist-info}/METADATA +1 -1
- {project_llm_trainer-0.3.4.dist-info → project_llm_trainer-0.3.6.dist-info}/RECORD +15 -15
- {project_llm_trainer-0.3.4.data → project_llm_trainer-0.3.6.data}/scripts/calc_intermediate_size +0 -0
- {project_llm_trainer-0.3.4.data → project_llm_trainer-0.3.6.data}/scripts/ddp_train +0 -0
- {project_llm_trainer-0.3.4.data → project_llm_trainer-0.3.6.data}/scripts/ds_train +0 -0
- {project_llm_trainer-0.3.4.data → project_llm_trainer-0.3.6.data}/scripts/plot_loss +0 -0
- {project_llm_trainer-0.3.4.data → project_llm_trainer-0.3.6.data}/scripts/plot_lr +0 -0
- {project_llm_trainer-0.3.4.data → project_llm_trainer-0.3.6.data}/scripts/py_train +0 -0
- {project_llm_trainer-0.3.4.data → project_llm_trainer-0.3.6.data}/scripts/smart_train +0 -0
- {project_llm_trainer-0.3.4.dist-info → project_llm_trainer-0.3.6.dist-info}/WHEEL +0 -0
- {project_llm_trainer-0.3.4.dist-info → project_llm_trainer-0.3.6.dist-info}/top_level.txt +0 -0
llm_trainer/dpo_trainer.py
CHANGED
@@ -69,19 +69,29 @@ class DPOTrainer(Trainer):
                 zero_optimization = {'stage': 0}
                 parallel_kwargs['zero_optimization'] = zero_optimization

-
+
+            if (self.train_config.ds_config.bf16_config is not None
+                    and self.train_config.ds_config.bf16_config.enabled):
+                bf16_config = self.train_config.ds_config.bf16_config
+                bf16 = {
+                    'enabled': bf16_config.enabled
+                }
+                parallel_kwargs['bf16'] = bf16
+            elif self.train_config.ds_config.fp16_config:
                 fb16_config = self.train_config.ds_config.fp16_config
-                fp16 = {
+                fp16 = {
+                    'enabled': fb16_config.enabled,
+                    'loss_scale': fb16_config.loss_scale,
+                    'loss_scale_window': fb16_config.loss_scale_window,
+                    'initial_scale_power': fb16_config.initial_scale_power,
+                    'hysteresis': fb16_config.hysteresis,
+                    'min_loss_scale': fb16_config.min_loss_scale
+                }

                 if fb16_config.fp16_opt_level is not None:
                     fp16['fp16_opt_level'] = fb16_config.fp16_opt_level

                 parallel_kwargs['fp16'] = fp16
-
-            if self.train_config.ds_config.bf16_config:
-                bf16_config = self.train_config.ds_config.bf16_config
-                bf16 = { 'enabled': bf16_config.enabled }
-                parallel_kwargs['bf16'] = bf16
         elif isinstance(TrainerTools().parallel, FsdpParallel) and self.train_config.fsdp_config:
             parallel_kwargs = {
                 'transformer_layer_cls': self.train_config.fsdp_config.transformer_layer_cls,

@@ -257,7 +267,7 @@ class DPOTrainer(Trainer):
             if not isinstance(TrainerTools().parallel, DsParallel) and self.lr_scheduler.can_clip_grad():
                 # clip grad
                 self.scalar.unscale_(self.optimizer)
-                torch.nn.utils.clip_grad_norm_(self.train_model
+                torch.nn.utils.clip_grad_norm_(self._get_trainable_params(self.train_model), 1.0)

             self._step()
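Note: the rewritten block checks bf16 first and only falls through to the fp16 settings in an elif, so this path can no longer emit both precision sections into the DeepSpeed kwargs. Below is a minimal, self-contained sketch of that precedence logic; the SimpleNamespace objects stand in for the package's real ds_config/bf16_config/fp16_config classes and are not its API, and only a few fp16 fields are carried over.

# Sketch of the bf16-over-fp16 precedence introduced in dpo_trainer.py.
# The config objects here are illustrative stand-ins, not the package's classes.
from types import SimpleNamespace

def build_precision_kwargs(ds_config) -> dict:
    parallel_kwargs = {'zero_optimization': {'stage': 0}}

    if ds_config.bf16_config is not None and ds_config.bf16_config.enabled:
        # bf16 wins; any fp16 settings are ignored entirely
        parallel_kwargs['bf16'] = {'enabled': True}
    elif ds_config.fp16_config is not None:
        fp16_config = ds_config.fp16_config
        parallel_kwargs['fp16'] = {
            'enabled': fp16_config.enabled,
            'loss_scale': fp16_config.loss_scale,
            'initial_scale_power': fp16_config.initial_scale_power,
        }
    return parallel_kwargs

bf16 = SimpleNamespace(enabled=True)
fp16 = SimpleNamespace(enabled=True, loss_scale=0, initial_scale_power=16)
print(build_precision_kwargs(SimpleNamespace(bf16_config=bf16, fp16_config=fp16)))
# -> only 'bf16' is present; before this release both sections could be set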
llm_trainer/grpo_trainer.py
CHANGED
@@ -355,7 +355,7 @@ class GRPOTrainer(Trainer):
             if not isinstance(TrainerTools().parallel, DsParallel) and self.lr_scheduler.can_clip_grad():
                 # clip grad
                 self.scalar.unscale_(self.optimizer)
-                torch.nn.utils.clip_grad_norm_(self.train_model
+                torch.nn.utils.clip_grad_norm_(self._get_trainable_params(self.train_model), 1.0)

             self._step()
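Note: as in dpo_trainer.py and trainer.py, gradient clipping now receives the parameters returned by the new _get_trainable_params helper together with an explicit max norm of 1.0. A small sketch of clipping only still-trainable parameters on a toy module (the module itself is hypothetical):

# Clip gradients of only the parameters that still require grad,
# mirroring clip_grad_norm_(self._get_trainable_params(...), 1.0).
import torch

model = torch.nn.Sequential(torch.nn.Linear(8, 8), torch.nn.Linear(8, 2))
for p in model[0].parameters():          # pretend the first layer is frozen
    p.requires_grad = False

loss = model(torch.randn(4, 8)).sum()
loss.backward()

trainable = [p for p in model.parameters() if p.requires_grad]
torch.nn.utils.clip_grad_norm_(trainable, 1.0)   # frozen params are excluded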
llm_trainer/train_configs.py
CHANGED
@@ -422,7 +422,8 @@ class TrainConfig:
             kd_config: Optional[KDConfig] = None,
             pixel_values_provider: Optional[Callable[[list[str]], torch.Tensor]] = None,
             init_state_dict: Optional[Mapping[str, Any]] = None,
-            eval_config: EvalConfig = EvalConfig()
+            eval_config: EvalConfig = EvalConfig(),
+            freeze_llm_model: bool = False
     ):
         self.n_epochs = n_epochs
         self.batch_size = batch_size

@@ -443,5 +444,6 @@ class TrainConfig:
         self.pixel_values_provider = pixel_values_provider
         self.init_state_dict = init_state_dict
         self.eval_config = eval_config
+        self.freeze_llm_model = freeze_llm_model
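Note: TrainConfig gains a freeze_llm_model keyword that defaults to False and is stored on the instance. A stripped-down sketch of the same pattern; MinimalTrainConfig below is an illustrative stand-in, not the package's full constructor:

# Illustrative stand-in for the new TrainConfig keyword; the real class
# takes many more arguments (n_epochs, batch_size, ds_config, ...).
class MinimalTrainConfig:
    def __init__(self, n_epochs: int, freeze_llm_model: bool = False):
        self.n_epochs = n_epochs
        self.freeze_llm_model = freeze_llm_model

default_cfg = MinimalTrainConfig(n_epochs=1)                       # stays False
vlm_cfg = MinimalTrainConfig(n_epochs=1, freeze_llm_model=True)    # freeze LLM for VLM training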
llm_trainer/trainer.py
CHANGED
@@ -116,6 +116,10 @@ class Trainer:
         else:
             return LlmModel(train_config.model_config)

+    def _get_trainable_params(self, model):
+        freeze_llm_model = self.train_config.freeze_llm_model
+        return model.parameters() if not freeze_llm_model else filter(lambda p: p.requires_grad, model.parameters())
+
     def _init_train_model_and_optim(
             self,
             initial_lr: float,

@@ -128,10 +132,24 @@ class Trainer:
             model.load_state_dict(self.train_config.init_state_dict, strict=False)
             self.train_config.init_state_dict = None

+        # freeze llm model for vlm training
+        if self.train_config.freeze_llm_model:
+            for name, param in model.named_parameters():
+                if not any(sub_module in name for sub_module in ['vision_tower', 'multi_modal_projector']):
+                    param.requires_grad = False
+
+            model.embed_tokens.eval()
+            model.layers.eval()
+            model.head_norm.eval()
+            model.lm_head.eval()
+
         if TrainerTools().parallel.is_main_process:
             total_params = sum(p.numel() for p in model.parameters())
             log(f"Total number of parameters: {total_params:,}")

+            trainable_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
+            log(f"Trainable number of parameters: {trainable_params:,}")
+
             total_size_bytes = total_params * 4
             total_size_mb = total_size_bytes / (1024 * 1024)
             log(f"Total size of the model: {total_size_mb:.2f} MB")

@@ -139,13 +157,13 @@ class Trainer:
         if use_ds_optim:
             import deepspeed
             origin_optim = deepspeed.ops.adam.DeepSpeedCPUAdam(
-
+                self._get_trainable_params(model),
                 lr=initial_lr,
                 weight_decay=self.train_config.lr_config.weight_decay
             )
         else:
             origin_optim = torch.optim.AdamW(
-
+                self._get_trainable_params(model),
                 lr=initial_lr,
                 weight_decay=self.train_config.lr_config.weight_decay
             )

@@ -529,7 +547,7 @@ class Trainer:
             if not isinstance(TrainerTools().parallel, DsParallel) and self.lr_scheduler.can_clip_grad():
                 # clip grad
                 self.scalar.unscale_(self.optimizer)
-                torch.nn.utils.clip_grad_norm_(self.train_model
+                torch.nn.utils.clip_grad_norm_(self._get_trainable_params(self.train_model), 1.0)

             self._step()
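Note: taken together, the trainer.py hunks (a) add _get_trainable_params, which returns all parameters unless freeze_llm_model is set and otherwise only those with requires_grad, (b) freeze every parameter whose name does not contain 'vision_tower' or 'multi_modal_projector' and put the frozen LLM submodules in eval mode, (c) log the trainable-parameter count next to the total, and (d) hand the filtered parameters to DeepSpeedCPUAdam/AdamW. Below is a self-contained sketch of that flow on a toy module; ToyVlm and its submodule names are stand-ins for the package's real model classes.

# Minimal sketch of the freeze_llm_model flow added in trainer.py,
# using a toy module whose submodule names echo the ones the trainer checks.
import torch
from torch import nn

class ToyVlm(nn.Module):
    def __init__(self):
        super().__init__()
        self.vision_tower = nn.Linear(16, 32)
        self.multi_modal_projector = nn.Linear(32, 32)
        self.embed_tokens = nn.Embedding(100, 32)
        self.layers = nn.ModuleList([nn.Linear(32, 32)])
        self.lm_head = nn.Linear(32, 100)

model = ToyVlm()

# freeze everything except the vision tower and the projector
for name, param in model.named_parameters():
    if not any(sub in name for sub in ['vision_tower', 'multi_modal_projector']):
        param.requires_grad = False
model.embed_tokens.eval()
model.layers.eval()
model.lm_head.eval()

total = sum(p.numel() for p in model.parameters())
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total number of parameters: {total:,}")
print(f"Trainable number of parameters: {trainable:,}")

# the optimizer only ever sees the parameters that still require grad
optimizer = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()),
    lr=1e-4,
    weight_decay=0.01,
)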
{project_llm_trainer-0.3.4.dist-info → project_llm_trainer-0.3.6.dist-info}/RECORD
CHANGED

@@ -2,11 +2,11 @@ llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
 llm_trainer/checkpoint.py,sha256=Dlkcit0o7Gx6S9QUrIrVp2pTurP9X0zVA7w7ImSuVQU,6049
 llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
 llm_trainer/dcp.py,sha256=PkD97DyrOtoTKn4FJsfL3VqAy4dxufgjdzJEz8-Cnoc,3635
-llm_trainer/dpo_trainer.py,sha256=
+llm_trainer/dpo_trainer.py,sha256=rEhoVN4gPweX5NYKZaEH7jgWav4w6OQ2x-QRocahYjg,13640
 llm_trainer/ds_checkpoint.py,sha256=_svpzqRaa43--DKPputoXAelc6X9vPM0gNQu-hlh6NI,2153
 llm_trainer/eval.py,sha256=sCvdYnqWWf5_nuDQN5BHb_YivXLOQW-V0ET9mPu0tPU,2389
 llm_trainer/generate_utils.py,sha256=4iM0vyc_1C_iTL31GlS9PR4eZtYaELPRZ02KDSPZA9U,15158
-llm_trainer/grpo_trainer.py,sha256=
+llm_trainer/grpo_trainer.py,sha256=1oH0argbpITlzAEkGKW8F9kZPr67bcb95FGOVpP8XTM,16385
 llm_trainer/log.py,sha256=LxqTGRNZUGMTSQCePRpk-rYyxSnSIbT4kOdP8Fbzr0M,462
 llm_trainer/loss.py,sha256=Yv3fsaVuZ5AhnGPJOr5vEMb_tM2urR6mCb4DBbrHHI8,6030
 llm_trainer/parallel.py,sha256=2VJtW3Gq2c1yS_LdcrNhk7B12prFwBmFnKhvV8FS2d8,4428

@@ -18,17 +18,17 @@ llm_trainer/scheduler.py,sha256=Xz8HhwoRMjRe41sf_NHhpZfkTlEs0I2MYusvMY6hCVw,3531
 llm_trainer/sft_trainer.py,sha256=gxQA7T1o1QGUsHp2CX1Qb_fO5LppBJuNbc0H4ixCYUA,1783
 llm_trainer/tokenizer.py,sha256=A7TYYUbtPf75kjCvWP7yBui4xZBObMk2aPem62YpwpY,6776
 llm_trainer/tools.py,sha256=AhfjN9oln5Pyif1SgCWwgQg-Q5acTCd9xpz4L26QUjA,3039
-llm_trainer/train_configs.py,sha256=
-llm_trainer/trainer.py,sha256=
+llm_trainer/train_configs.py,sha256=arnet3tIzgVnwshod08F1jE7r4I7e-SIgMy55IagPnE,15971
+llm_trainer/trainer.py,sha256=2cO-MwWJsgPbTisOp_HVIdA0SVodFZx3M8lafarnLdw,25188
 llm_trainer/utils.py,sha256=-ivhMF0d999va13S1wt2uBvtVw8Nvr3uBzhaUFKL04Q,6826
-project_llm_trainer-0.3.
-project_llm_trainer-0.3.
-project_llm_trainer-0.3.
-project_llm_trainer-0.3.
-project_llm_trainer-0.3.
-project_llm_trainer-0.3.
-project_llm_trainer-0.3.
-project_llm_trainer-0.3.
-project_llm_trainer-0.3.
-project_llm_trainer-0.3.
-project_llm_trainer-0.3.
+project_llm_trainer-0.3.6.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
+project_llm_trainer-0.3.6.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
+project_llm_trainer-0.3.6.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
+project_llm_trainer-0.3.6.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
+project_llm_trainer-0.3.6.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
+project_llm_trainer-0.3.6.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
+project_llm_trainer-0.3.6.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
+project_llm_trainer-0.3.6.dist-info/METADATA,sha256=1ClKvVThd4g8uToJQevDXmjAI8gbVYzDfYImWXHFRqI,195
+project_llm_trainer-0.3.6.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
+project_llm_trainer-0.3.6.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
+project_llm_trainer-0.3.6.dist-info/RECORD,,
{project_llm_trainer-0.3.4.data → project_llm_trainer-0.3.6.data}/scripts/calc_intermediate_size
RENAMED
File without changes

{project_llm_trainer-0.3.4.data → project_llm_trainer-0.3.6.data}/scripts/ddp_train
RENAMED
File without changes

{project_llm_trainer-0.3.4.data → project_llm_trainer-0.3.6.data}/scripts/ds_train
RENAMED
File without changes

{project_llm_trainer-0.3.4.data → project_llm_trainer-0.3.6.data}/scripts/plot_loss
RENAMED
File without changes

{project_llm_trainer-0.3.4.data → project_llm_trainer-0.3.6.data}/scripts/plot_lr
RENAMED
File without changes

{project_llm_trainer-0.3.4.data → project_llm_trainer-0.3.6.data}/scripts/py_train
RENAMED
File without changes

{project_llm_trainer-0.3.4.data → project_llm_trainer-0.3.6.data}/scripts/smart_train
RENAMED
File without changes

{project_llm_trainer-0.3.4.dist-info → project_llm_trainer-0.3.6.dist-info}/WHEEL
RENAMED
File without changes

{project_llm_trainer-0.3.4.dist-info → project_llm_trainer-0.3.6.dist-info}/top_level.txt
RENAMED
File without changes