project-llm-trainer 0.5.15__py3-none-any.whl → 0.5.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of project-llm-trainer might be problematic. Click here for more details.
- llm_trainer/checkpoint.py +1 -1
- llm_trainer/dpo_trainer.py +8 -3
- llm_trainer/ds_checkpoint.py +1 -1
- llm_trainer/grpo_trainer.py +3 -1
- llm_trainer/parallel.py +5 -6
- llm_trainer/trainer.py +9 -4
- {project_llm_trainer-0.5.15.dist-info → project_llm_trainer-0.5.16.dist-info}/METADATA +1 -1
- {project_llm_trainer-0.5.15.dist-info → project_llm_trainer-0.5.16.dist-info}/RECORD +17 -17
- {project_llm_trainer-0.5.15.data → project_llm_trainer-0.5.16.data}/scripts/calc_intermediate_size +0 -0
- {project_llm_trainer-0.5.15.data → project_llm_trainer-0.5.16.data}/scripts/ddp_train +0 -0
- {project_llm_trainer-0.5.15.data → project_llm_trainer-0.5.16.data}/scripts/ds_train +0 -0
- {project_llm_trainer-0.5.15.data → project_llm_trainer-0.5.16.data}/scripts/plot_loss +0 -0
- {project_llm_trainer-0.5.15.data → project_llm_trainer-0.5.16.data}/scripts/plot_lr +0 -0
- {project_llm_trainer-0.5.15.data → project_llm_trainer-0.5.16.data}/scripts/py_train +0 -0
- {project_llm_trainer-0.5.15.data → project_llm_trainer-0.5.16.data}/scripts/smart_train +0 -0
- {project_llm_trainer-0.5.15.dist-info → project_llm_trainer-0.5.16.dist-info}/WHEEL +0 -0
- {project_llm_trainer-0.5.15.dist-info → project_llm_trainer-0.5.16.dist-info}/top_level.txt +0 -0
llm_trainer/checkpoint.py
CHANGED
llm_trainer/dpo_trainer.py
CHANGED
|
@@ -170,14 +170,19 @@ class DPOTrainer(Trainer):
|
|
|
170
170
|
skipping_train = True
|
|
171
171
|
continue
|
|
172
172
|
|
|
173
|
-
skipping_train = False
|
|
174
|
-
|
|
175
173
|
# 是否需要更新梯度
|
|
176
|
-
if
|
|
174
|
+
if skipping_train:
|
|
175
|
+
need_update_grad = False
|
|
176
|
+
elif gradient_accumulation_steps > 1:
|
|
177
177
|
need_update_grad = (batch + 1) % gradient_accumulation_steps == 0 or batch == batch_count_per_file - 1
|
|
178
178
|
else:
|
|
179
179
|
need_update_grad = True
|
|
180
180
|
|
|
181
|
+
# 要放在need_update_grad赋值下面,解决在继续训练时未知原因的卡死现象
|
|
182
|
+
if skipping_train:
|
|
183
|
+
TrainerTools().parallel.wait('skip train')
|
|
184
|
+
skipping_train = False
|
|
185
|
+
|
|
181
186
|
try:
|
|
182
187
|
chosen_inputs: torch.Tensor = batch_data['chosen_inputs'].to(TrainerTools().parallel.device)
|
|
183
188
|
chosen_labels: torch.Tensor = batch_data['chosen_labels'].to(TrainerTools().parallel.device)
|
llm_trainer/ds_checkpoint.py
CHANGED
llm_trainer/grpo_trainer.py
CHANGED
|
@@ -317,7 +317,9 @@ class GRPOTrainer(Trainer):
|
|
|
317
317
|
skipping_train = True
|
|
318
318
|
continue
|
|
319
319
|
|
|
320
|
-
skipping_train
|
|
320
|
+
if skipping_train:
|
|
321
|
+
TrainerTools().parallel.wait('skip train')
|
|
322
|
+
skipping_train = False
|
|
321
323
|
|
|
322
324
|
# start generate
|
|
323
325
|
if TrainerTools().parallel.is_main_process:
|
llm_trainer/parallel.py
CHANGED
|
@@ -139,9 +139,8 @@ class Parallel(ABC):
|
|
|
139
139
|
return dist.get_world_size()
|
|
140
140
|
return 1
|
|
141
141
|
|
|
142
|
-
def wait(self):
|
|
143
|
-
|
|
144
|
-
|
|
145
|
-
|
|
146
|
-
|
|
147
|
-
log(f'continue at {self.device}')
|
|
142
|
+
def wait(self, msg=None):
    """Synchronize all processes at a barrier, logging before and after.

    The two log lines make it possible to see which rank is stuck when a
    distributed run hangs.

    Args:
        msg: Optional label describing why we are waiting (callers pass
            e.g. 'eval' or 'skip train'). When given, it is appended to both
            log lines as `` for <msg>``; when omitted or empty, nothing is
            appended.
    """
    # Use an empty suffix (not None) when no label is given. Otherwise the
    # f-strings below would render the literal text "None" into the log,
    # e.g. "wait at cuda:0None".
    suffix = f' for {msg}' if msg else ''
    log(f'wait at {self.device}{suffix}')
    # NOTE(review): assumes `dist` is torch.distributed. Only enter the
    # barrier when a process group actually exists, so single-process
    # (non-distributed) runs do not raise on the new wait() call sites.
    if dist.is_available() and dist.is_initialized():
        dist.barrier()
    log(f'continue at {self.device}{suffix}')
|
llm_trainer/trainer.py
CHANGED
|
@@ -449,7 +449,7 @@ class Trainer:
|
|
|
449
449
|
)
|
|
450
450
|
generate_model.train()
|
|
451
451
|
|
|
452
|
-
TrainerTools().parallel.wait()
|
|
452
|
+
TrainerTools().parallel.wait('eval')
|
|
453
453
|
|
|
454
454
|
def _on_batch_end(self, tag: str):
|
|
455
455
|
self._eval(f'sign:batch/{tag}')
|
|
@@ -500,14 +500,19 @@ class Trainer:
|
|
|
500
500
|
skipping_train = True
|
|
501
501
|
continue
|
|
502
502
|
|
|
503
|
-
skipping_train = False
|
|
504
|
-
|
|
505
503
|
# 是否需要更新梯度
|
|
506
|
-
if
|
|
504
|
+
if skipping_train:
|
|
505
|
+
need_update_grad = False
|
|
506
|
+
elif gradient_accumulation_steps > 1:
|
|
507
507
|
need_update_grad = (batch + 1) % gradient_accumulation_steps == 0 or batch == batch_count_per_file - 1
|
|
508
508
|
else:
|
|
509
509
|
need_update_grad = True
|
|
510
510
|
|
|
511
|
+
# 要放在need_update_grad赋值下面,解决在继续训练时未知原因的卡死现象
|
|
512
|
+
if skipping_train:
|
|
513
|
+
TrainerTools().parallel.wait('skip train')
|
|
514
|
+
skipping_train = False
|
|
515
|
+
|
|
511
516
|
inputs = batch_data['inputs']
|
|
512
517
|
labels = batch_data['labels']
|
|
513
518
|
|
|
@@ -1,14 +1,14 @@
|
|
|
1
1
|
llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
|
|
2
|
-
llm_trainer/checkpoint.py,sha256=
|
|
2
|
+
llm_trainer/checkpoint.py,sha256=gz31pZbbQvRTYrBhxV-MFaBAIFeqpe7rM6nFsjwT9lY,4328
|
|
3
3
|
llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
|
|
4
|
-
llm_trainer/dpo_trainer.py,sha256
|
|
5
|
-
llm_trainer/ds_checkpoint.py,sha256=
|
|
4
|
+
llm_trainer/dpo_trainer.py,sha256=--ItH-rkkq24Da3M_Kf0VxpQ3t-k0fpZrzFGqkYsjks,12304
|
|
5
|
+
llm_trainer/ds_checkpoint.py,sha256=Wzy7PvVVWR794-BW4uragWFTAkkgDvjvkF-qMdyB4fc,2141
|
|
6
6
|
llm_trainer/eval.py,sha256=ZyUfSo2Q8P-lrCdPEnGkoo5pGubd0AabREK5eMISRII,1109
|
|
7
7
|
llm_trainer/generate_utils.py,sha256=wrZoG2g7CsOyG4sb3px9vURHQFV6_9j5kQmpFc5A8yg,15335
|
|
8
|
-
llm_trainer/grpo_trainer.py,sha256=
|
|
8
|
+
llm_trainer/grpo_trainer.py,sha256=g_ivzQop2SkvhlKAEWb0zUnIvNuHTfsOoIG6y29oTCw,16106
|
|
9
9
|
llm_trainer/log.py,sha256=LxqTGRNZUGMTSQCePRpk-rYyxSnSIbT4kOdP8Fbzr0M,462
|
|
10
10
|
llm_trainer/loss.py,sha256=NZCQeUXnLSj__mmDflE8g89KgE0emAJXIab0IERCLno,6023
|
|
11
|
-
llm_trainer/parallel.py,sha256=
|
|
11
|
+
llm_trainer/parallel.py,sha256=j1L4n-JmDkDZblURrNKpEAWEqqGIAXAN9PT_fSS_OnE,4492
|
|
12
12
|
llm_trainer/parallel_ddp.py,sha256=Pob9vUlBZnkL4oP1Re11kFob7nufMSE96pn7m7fuOEM,1345
|
|
13
13
|
llm_trainer/parallel_ds.py,sha256=oy8RRxHud3rACWubFlJqqd0pjPEQhKeAPGPQUSdJX2c,1145
|
|
14
14
|
llm_trainer/parallel_none.py,sha256=TG6Pm829Dg-yQu-97O-EHV3FCARBlNcP47KkGFAs16E,676
|
|
@@ -18,16 +18,16 @@ llm_trainer/sft_trainer.py,sha256=gxQA7T1o1QGUsHp2CX1Qb_fO5LppBJuNbc0H4ixCYUA,17
|
|
|
18
18
|
llm_trainer/tokenizer.py,sha256=SSpgXtb0e1NtQqRW0gCq09TTZi47umggy-Fh5EMHKJg,6708
|
|
19
19
|
llm_trainer/tools.py,sha256=yF17lp6oOfLe2XJeKDQ1juZcbv-6vFamJSLwEeArduA,2975
|
|
20
20
|
llm_trainer/train_configs.py,sha256=992wy0YhBG2WvxwdLEPL4_-JUl4NkwMPT-jj_BIHo6A,7347
|
|
21
|
-
llm_trainer/trainer.py,sha256=
|
|
21
|
+
llm_trainer/trainer.py,sha256=YqWhD9jXbrUdm3KEjEHLyg_qHiXCy5R7PK-arCXxJ6M,26399
|
|
22
22
|
llm_trainer/utils.py,sha256=LWNhyQ0NDEZ9mZtk2Ryvh6EulvHIaUGIflugSpqmeFI,6791
|
|
23
|
-
project_llm_trainer-0.5.
|
|
24
|
-
project_llm_trainer-0.5.
|
|
25
|
-
project_llm_trainer-0.5.
|
|
26
|
-
project_llm_trainer-0.5.
|
|
27
|
-
project_llm_trainer-0.5.
|
|
28
|
-
project_llm_trainer-0.5.
|
|
29
|
-
project_llm_trainer-0.5.
|
|
30
|
-
project_llm_trainer-0.5.
|
|
31
|
-
project_llm_trainer-0.5.
|
|
32
|
-
project_llm_trainer-0.5.
|
|
33
|
-
project_llm_trainer-0.5.
|
|
23
|
+
project_llm_trainer-0.5.16.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
|
|
24
|
+
project_llm_trainer-0.5.16.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
|
|
25
|
+
project_llm_trainer-0.5.16.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
|
|
26
|
+
project_llm_trainer-0.5.16.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
|
|
27
|
+
project_llm_trainer-0.5.16.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
|
|
28
|
+
project_llm_trainer-0.5.16.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
|
|
29
|
+
project_llm_trainer-0.5.16.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
|
|
30
|
+
project_llm_trainer-0.5.16.dist-info/METADATA,sha256=h0TMNrZMUU875tVasbuqt69EuPPMbo_nv6tHQLKeNbQ,196
|
|
31
|
+
project_llm_trainer-0.5.16.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
|
|
32
|
+
project_llm_trainer-0.5.16.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
|
|
33
|
+
project_llm_trainer-0.5.16.dist-info/RECORD,,
|
{project_llm_trainer-0.5.15.data → project_llm_trainer-0.5.16.data}/scripts/calc_intermediate_size
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|