project-llm-trainer 0.7.4__py3-none-any.whl → 0.7.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of project-llm-trainer might be problematic; consult the package registry's advisory for this release for more details.
- llm_trainer/grpo_trainer.py +16 -8
- {project_llm_trainer-0.7.4.data → project_llm_trainer-0.7.6.data}/scripts/smart_train +20 -2
- {project_llm_trainer-0.7.4.dist-info → project_llm_trainer-0.7.6.dist-info}/METADATA +1 -1
- {project_llm_trainer-0.7.4.dist-info → project_llm_trainer-0.7.6.dist-info}/RECORD +12 -12
- {project_llm_trainer-0.7.4.data → project_llm_trainer-0.7.6.data}/scripts/calc_intermediate_size +0 -0
- {project_llm_trainer-0.7.4.data → project_llm_trainer-0.7.6.data}/scripts/ddp_train +0 -0
- {project_llm_trainer-0.7.4.data → project_llm_trainer-0.7.6.data}/scripts/ds_train +0 -0
- {project_llm_trainer-0.7.4.data → project_llm_trainer-0.7.6.data}/scripts/plot_loss +0 -0
- {project_llm_trainer-0.7.4.data → project_llm_trainer-0.7.6.data}/scripts/plot_lr +0 -0
- {project_llm_trainer-0.7.4.data → project_llm_trainer-0.7.6.data}/scripts/py_train +0 -0
- {project_llm_trainer-0.7.4.dist-info → project_llm_trainer-0.7.6.dist-info}/WHEEL +0 -0
- {project_llm_trainer-0.7.4.dist-info → project_llm_trainer-0.7.6.dist-info}/top_level.txt +0 -0
llm_trainer/grpo_trainer.py
CHANGED
|
@@ -50,6 +50,10 @@ class GRPOTrainer(Trainer):
|
|
|
50
50
|
self._use_origin_pad_sequence = True
|
|
51
51
|
|
|
52
52
|
def _init_ref_model(self):
|
|
53
|
+
# beta == 0,不需要ref_model
|
|
54
|
+
if self.train_config.grpo_config.loss_beta == 0.0:
|
|
55
|
+
return None
|
|
56
|
+
|
|
53
57
|
ref_model = self._new_model(self.train_config)
|
|
54
58
|
|
|
55
59
|
ref_model, _ = TrainerTools().parallel.process(
|
|
@@ -230,8 +234,11 @@ class GRPOTrainer(Trainer):
|
|
|
230
234
|
# Compute old_log_probs from the current model, with gradients disabled.
|
|
231
235
|
old_log_probs, _ = self._compute_log_probabilities(generate_model, input_ids, attention_mask, logits_to_keep)
|
|
232
236
|
|
|
233
|
-
|
|
234
|
-
|
|
237
|
+
if self.ref_model:
|
|
238
|
+
# Compute ref_log_probs from the reference model, which remains static.
|
|
239
|
+
ref_log_probs, _ = self._compute_log_probabilities(self.ref_model, input_ids, attention_mask, logits_to_keep)
|
|
240
|
+
else:
|
|
241
|
+
ref_log_probs = None
|
|
235
242
|
|
|
236
243
|
repeated_prompts = [p for p in prompts for _ in range(group_size)]
|
|
237
244
|
repeated_answers = [a for a in answers for _ in range(group_size)]
|
|
@@ -294,11 +301,12 @@ class GRPOTrainer(Trainer):
|
|
|
294
301
|
aux_loss_coef = self.train_config.loss_config.aux_loss_coef
|
|
295
302
|
|
|
296
303
|
for epoch in range(self.train_config.n_epochs):
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
304
|
+
if self.ref_model:
|
|
305
|
+
sync_model_params(
|
|
306
|
+
_from=self.train_model,
|
|
307
|
+
_to=self.ref_model,
|
|
308
|
+
mixup_alpha=self.train_config.grpo_config.mixup_alpha
|
|
309
|
+
)
|
|
302
310
|
|
|
303
311
|
file_count = len(self.train_config.file_dataset)
|
|
304
312
|
|
|
@@ -366,7 +374,7 @@ class GRPOTrainer(Trainer):
|
|
|
366
374
|
self._log_loss(
|
|
367
375
|
epoch_tag=f'epoch: {epoch}',
|
|
368
376
|
file_tag=f'file: {file_idx + 1}/{file_count}',
|
|
369
|
-
batch_tag=f'batch: {batch}/{batch_count_per_file}',
|
|
377
|
+
batch_tag=f'batch: {batch}/{batch_count_per_file}, grpo_step={grpo_step}',
|
|
370
378
|
loss=current_loss
|
|
371
379
|
)
|
|
372
380
|
except Exception as e:
|
|
@@ -2,9 +2,24 @@
|
|
|
2
2
|
|
|
3
3
|
if __name__ == '__main__':
|
|
4
4
|
import os, sys, torch
|
|
5
|
+
|
|
5
6
|
arguments = sys.argv[1:]
|
|
7
|
+
# file name
|
|
6
8
|
run_file_name = arguments[0]
|
|
7
9
|
|
|
10
|
+
# cuda_visible_devive
|
|
11
|
+
if len(arguments) > 1:
|
|
12
|
+
# 0,1,2,3
|
|
13
|
+
cuda_visible_devive = arguments[1]
|
|
14
|
+
else:
|
|
15
|
+
cuda_visible_devive = None
|
|
16
|
+
|
|
17
|
+
# cuda location
|
|
18
|
+
if len(arguments) > 2:
|
|
19
|
+
cuda_loc = arguments[2]
|
|
20
|
+
else:
|
|
21
|
+
cuda_loc = 'localhost'
|
|
22
|
+
|
|
8
23
|
try:
|
|
9
24
|
import deepspeed
|
|
10
25
|
parallel_type = 'ds'
|
|
@@ -18,11 +33,14 @@ if __name__ == '__main__':
|
|
|
18
33
|
os.environ['PARALLEL_TYPE'] = parallel_type
|
|
19
34
|
|
|
20
35
|
if parallel_type == 'ds':
|
|
21
|
-
|
|
36
|
+
cuda_ctrl = f' --include {cuda_loc}:{cuda_visible_devive}' if cuda_visible_devive else ''
|
|
37
|
+
command = f'deepspeed{cuda_ctrl} {run_file_name}'
|
|
22
38
|
elif parallel_type == 'ddp':
|
|
39
|
+
if cuda_visible_devive:
|
|
40
|
+
os.environ['CUDA_VISIBLE_DEVICES'] = cuda_visible_devive
|
|
23
41
|
command = f'torchrun --standalone --nproc_per_node=gpu {run_file_name}'
|
|
24
42
|
else:
|
|
25
43
|
command = f'python3 {run_file_name}'
|
|
26
44
|
|
|
27
|
-
print(f'
|
|
45
|
+
print(f'run command {command}')
|
|
28
46
|
os.system(command)
|
|
@@ -5,7 +5,7 @@ llm_trainer/dpo_trainer.py,sha256=RMfbTsl3eav4yTJ2PK59mi6a0ECVOg8WwYVsHvMbNUE,12
|
|
|
5
5
|
llm_trainer/ds_checkpoint.py,sha256=X2IWgpgi0yOtogph7n6DEwvK_0Ceb7juu1WMutv3HSk,2270
|
|
6
6
|
llm_trainer/eval.py,sha256=ZyUfSo2Q8P-lrCdPEnGkoo5pGubd0AabREK5eMISRII,1109
|
|
7
7
|
llm_trainer/generate_utils.py,sha256=8K3YFbp7IF_lCkmkzjHhqTW26EBFb2AilQmarVcfMvs,15001
|
|
8
|
-
llm_trainer/grpo_trainer.py,sha256=
|
|
8
|
+
llm_trainer/grpo_trainer.py,sha256=MXnP8Kc9CQJw0CB3uMbHxIYwvpuujai4hgbbpUut_K4,16808
|
|
9
9
|
llm_trainer/log.py,sha256=XwychwKF6gvFPhthCIZCAEUZ0G3DY3fiQrOHqPWsxz0,463
|
|
10
10
|
llm_trainer/loss.py,sha256=glf4IeDWHvA2cJo-QKLRL8P6OxK4QjRJGrYJWOZiTPQ,6929
|
|
11
11
|
llm_trainer/parallel.py,sha256=yjStV21DJ26yM8-0O6GTMxdFAcyShY5GsQWSZmbI7HU,4543
|
|
@@ -20,14 +20,14 @@ llm_trainer/tools.py,sha256=5op5qrjjkK-Lr9oes5VxIVnOVYOYGoAdlIJq9mPUf64,2637
|
|
|
20
20
|
llm_trainer/train_configs.py,sha256=N3ykM1uaLHcSNRC8ErYIxp9VYhSP7voJyAP-2D4ZJe0,7574
|
|
21
21
|
llm_trainer/trainer.py,sha256=jS31zEXIIj9BoPTPlmaGYq61x72HGCjKfS2u3_gOkDk,27924
|
|
22
22
|
llm_trainer/utils.py,sha256=xcdzpvPvXRKqsOK2yB7PZ9GmOvZMDFcglDPUZY2hJTY,11484
|
|
23
|
-
project_llm_trainer-0.7.
|
|
24
|
-
project_llm_trainer-0.7.
|
|
25
|
-
project_llm_trainer-0.7.
|
|
26
|
-
project_llm_trainer-0.7.
|
|
27
|
-
project_llm_trainer-0.7.
|
|
28
|
-
project_llm_trainer-0.7.
|
|
29
|
-
project_llm_trainer-0.7.
|
|
30
|
-
project_llm_trainer-0.7.
|
|
31
|
-
project_llm_trainer-0.7.
|
|
32
|
-
project_llm_trainer-0.7.
|
|
33
|
-
project_llm_trainer-0.7.
|
|
23
|
+
project_llm_trainer-0.7.6.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
|
|
24
|
+
project_llm_trainer-0.7.6.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
|
|
25
|
+
project_llm_trainer-0.7.6.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
|
|
26
|
+
project_llm_trainer-0.7.6.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
|
|
27
|
+
project_llm_trainer-0.7.6.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
|
|
28
|
+
project_llm_trainer-0.7.6.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
|
|
29
|
+
project_llm_trainer-0.7.6.data/scripts/smart_train,sha256=3oLIDuuqb4U4TU1lXy9V8lw_0gIf7i8tGsxlQ_s6bro,1220
|
|
30
|
+
project_llm_trainer-0.7.6.dist-info/METADATA,sha256=t52f6ahI8WvTnTguykneF91x-ChSZ84sE9PaBjvqb1g,195
|
|
31
|
+
project_llm_trainer-0.7.6.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
|
|
32
|
+
project_llm_trainer-0.7.6.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
|
|
33
|
+
project_llm_trainer-0.7.6.dist-info/RECORD,,
|
{project_llm_trainer-0.7.4.data → project_llm_trainer-0.7.6.data}/scripts/calc_intermediate_size
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|