project_llm_trainer-0.5.4-py3-none-any.whl → project_llm_trainer-0.5.6-py3-none-any.whl
This diff shows the content of publicly released package versions as they appear in their public registries. It is provided for informational purposes only and reflects the changes between the two versions.
Potentially problematic release: this version of project-llm-trainer might be problematic.
- llm_trainer/checkpoint.py +21 -18
- llm_trainer/dpo_trainer.py +7 -4
- llm_trainer/grpo_trainer.py +3 -1
- llm_trainer/trainer.py +5 -2
- {project_llm_trainer-0.5.4.dist-info → project_llm_trainer-0.5.6.dist-info}/METADATA +1 -1
- {project_llm_trainer-0.5.4.dist-info → project_llm_trainer-0.5.6.dist-info}/RECORD +15 -15
- {project_llm_trainer-0.5.4.data → project_llm_trainer-0.5.6.data}/scripts/calc_intermediate_size +0 -0
- {project_llm_trainer-0.5.4.data → project_llm_trainer-0.5.6.data}/scripts/ddp_train +0 -0
- {project_llm_trainer-0.5.4.data → project_llm_trainer-0.5.6.data}/scripts/ds_train +0 -0
- {project_llm_trainer-0.5.4.data → project_llm_trainer-0.5.6.data}/scripts/plot_loss +0 -0
- {project_llm_trainer-0.5.4.data → project_llm_trainer-0.5.6.data}/scripts/plot_lr +0 -0
- {project_llm_trainer-0.5.4.data → project_llm_trainer-0.5.6.data}/scripts/py_train +0 -0
- {project_llm_trainer-0.5.4.data → project_llm_trainer-0.5.6.data}/scripts/smart_train +0 -0
- {project_llm_trainer-0.5.4.dist-info → project_llm_trainer-0.5.6.dist-info}/WHEEL +0 -0
- {project_llm_trainer-0.5.4.dist-info → project_llm_trainer-0.5.6.dist-info}/top_level.txt +0 -0
llm_trainer/checkpoint.py
CHANGED
@@ -38,29 +38,32 @@ def save_best_checkpoint(
 ) -> bool:
     need_replace = not last_best_checkpoint_loss or current_loss <= last_best_checkpoint_loss
     if need_replace and TrainerTools().parallel.is_main_process:
-
-
+        try:
+            if isinstance(TrainerTools().parallel, DsParallel):
+                checkpoint_dir = os.environ.get('DIST_CHECKPOINT_DIR', 'checkpoint')
 
-
-
-
-
+                if checkpoint_dir.endswith('/'):
+                    best_checkpoint_dir = f'{checkpoint_dir[:-1]}_best'
+                else:
+                    best_checkpoint_dir = f'{checkpoint_dir}_best'
 
-
-
+                if not os.path.exists(best_checkpoint_dir):
+                    os.makedirs(best_checkpoint_dir)
 
-
-
-
-
-
-
+                if os.path.exists(checkpoint_dir):
+                    shutil.rmtree(best_checkpoint_dir)
+                    shutil.copytree(checkpoint_dir, best_checkpoint_dir)
+            else:
+                checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
+                best_checkpoint_name = f'{checkpoint_name}_best'
 
-
-
-
+                if os.path.exists(checkpoint_name):
+                    if os.path.exists(best_checkpoint_name):
+                        os.remove(best_checkpoint_name)
 
-
+                    shutil.copy2(checkpoint_name, best_checkpoint_name)
+        except:
+            pass
 
     TrainerTools().parallel.wait()
     return need_replace
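The 0.5.6 change wraps the whole best-checkpoint copy in a try/except and adds a DeepSpeed (DsParallel) branch that copies the checkpoint directory named by DIST_CHECKPOINT_DIR rather than a single file. The rmtree-before-copytree order matters: shutil.copytree refuses to overwrite an existing destination unless dirs_exist_ok=True is passed (Python 3.8+). A minimal, self-contained sketch of that behavior, using throwaway temp paths rather than the trainer's real checkpoint layout:

    import os
    import shutil
    import tempfile

    src = tempfile.mkdtemp()                # stand-in for the live checkpoint dir
    open(os.path.join(src, 'weights.pt'), 'w').close()
    dst = f'{src}_best'                     # same '_best' suffix convention as the diff
    os.makedirs(dst, exist_ok=True)

    shutil.rmtree(dst)                      # without this, copytree raises FileExistsError
    shutil.copytree(src, dst)               # fresh copy of the whole directory
    print(os.listdir(dst))                  # ['weights.pt']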
llm_trainer/dpo_trainer.py
CHANGED
@@ -137,9 +137,10 @@ class DPOTrainer(Trainer):
         # gradient accumulation steps
         gradient_accumulation_steps = self.train_config.gradient_accumulation_steps
         global_steps = 0
-        loss_accumulation = 0.0
         skipping_train = False
 
+        loss_accumulation = 0.0
+        batches_accumulated = 0
         current_loss: float = 0.0
         last_best_checkpoint_loss: Optional[float] = None
 
@@ -214,14 +215,15 @@ class DPOTrainer(Trainer):
 
                     loss_accumulation += loss.detach().item()
                     self._backward_loss(loss)
+                    batches_accumulated += 1
 
                     if need_update_grad:
-                        loss_tensor = torch.tensor(loss_accumulation, device=TrainerTools().parallel.device)
+                        loss_tensor = torch.tensor(loss_accumulation * gradient_accumulation_steps / batches_accumulated, device=TrainerTools().parallel.device)
 
                         if TrainerTools().parallel.parallel_train:
                             dist.all_reduce(loss_tensor, dist.ReduceOp.AVG)
 
-
+                        current_loss = loss_tensor.item()
 
                         # ds mode already has gradient_clipping built in
                         if not isinstance(TrainerTools().parallel, DsParallel) and self.lr_scheduler.can_clip_grad():
@@ -235,10 +237,11 @@ class DPOTrainer(Trainer):
                             epoch_tag=f'epoch: {epoch}',
                             file_tag=f'file: {file_idx + 1}/{file_count}',
                             batch_tag=f'batch: {batch}/{batch_count_per_file}',
-                            loss=
+                            loss=current_loss
                         )
                         # reset to default
                         loss_accumulation = 0.0
+                        batches_accumulated = 0
                 except Exception as e:
                     self._on_exception(e, epoch, batch)
                 finally:
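The new batches_accumulated counter fixes the reported loss for partial accumulation windows, e.g. when a data file ends before a full gradient_accumulation_steps worth of batches has been seen. Assuming each micro-batch loss is pre-scaled by 1/gradient_accumulation_steps before accumulation (the usual convention; the scaling itself is not visible in this hunk), the rescale recovers the true per-batch mean. A worked sketch:

    gradient_accumulation_steps = 4
    micro_losses = [0.9, 1.1, 1.0]   # only 3 batches left at the end of a file
    loss_accumulation = sum(l / gradient_accumulation_steps for l in micro_losses)
    batches_accumulated = len(micro_losses)

    old_report = loss_accumulation   # 0.75 — biased low for every partial window
    new_report = loss_accumulation * gradient_accumulation_steps / batches_accumulated
    assert abs(new_report - 1.0) < 1e-9   # true per-batch mean restored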
llm_trainer/grpo_trainer.py
CHANGED
@@ -345,6 +345,8 @@ class GRPOTrainer(Trainer):
                         if TrainerTools().parallel.parallel_train:
                             dist.all_reduce(loss, dist.ReduceOp.AVG)
 
+                        current_loss = loss.detach().item()
+
                         # ds mode already has gradient_clipping built in
                         if not isinstance(TrainerTools().parallel, DsParallel) and self.lr_scheduler.can_clip_grad():
                             # clip grad
@@ -357,7 +359,7 @@ class GRPOTrainer(Trainer):
                             epoch_tag=f'epoch: {epoch}',
                             file_tag=f'file: {file_idx + 1}/{file_count}',
                             batch_tag=f'batch: {batch}/{batch_count_per_file}',
-                            loss=
+                            loss=current_loss
                         )
                 except Exception as e:
                     self._on_exception(e, epoch, batch)
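grpo_trainer.py gets the reporting half of the same change: the loss tensor is converted to a plain float once, immediately after the optional cross-rank all_reduce, and that precomputed current_loss is what the log call receives (the old inline loss= expression is truncated in this view). A single-process sketch of the pattern:

    import torch

    loss = torch.tensor(1.25)            # stand-in for the per-step loss tensor
    # In distributed runs, dist.all_reduce(loss, dist.ReduceOp.AVG) would first
    # average the tensor across ranks (omitted in this single-process sketch).
    current_loss = loss.detach().item()  # one conversion, reused wherever needed
    print(f'loss: {current_loss}')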
llm_trainer/trainer.py
CHANGED
@@ -465,9 +465,10 @@ class Trainer:
         # gradient accumulation steps
         gradient_accumulation_steps = self.train_config.gradient_accumulation_steps
         global_steps = 0
-        loss_accumulation = 0.0
         skipping_train = False
 
+        loss_accumulation = 0.0
+        batches_accumulated = 0
         current_loss: float = 0.0
         last_best_checkpoint_loss: Optional[float] = None
 
@@ -536,9 +537,10 @@ class Trainer:
 
                     loss_accumulation += loss.detach().item()
                     self._backward_loss(loss)
+                    batches_accumulated += 1
 
                     if need_update_grad:
-                        loss_tensor = torch.tensor(loss_accumulation, device=TrainerTools().parallel.device)
+                        loss_tensor = torch.tensor(loss_accumulation * gradient_accumulation_steps / batches_accumulated, device=TrainerTools().parallel.device)
 
                         if TrainerTools().parallel.parallel_train:
                             dist.all_reduce(loss_tensor, dist.ReduceOp.AVG)
@@ -561,6 +563,7 @@ class Trainer:
                         )
                         # reset to default
                         loss_accumulation = 0.0
+                        batches_accumulated = 0
                 except Exception as e:
                     self._on_exception(e, epoch, batch)
                 finally:
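trainer.py applies the identical partial-window correction as dpo_trainer.py above: the same batches_accumulated counter, the same rescaled loss_tensor, and the same reset after each optimizer step; see the worked sketch after the dpo_trainer.py diff for the arithmetic.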
{project_llm_trainer-0.5.4.dist-info → project_llm_trainer-0.5.6.dist-info}/RECORD
CHANGED
@@ -1,11 +1,11 @@
 llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
-llm_trainer/checkpoint.py,sha256=
+llm_trainer/checkpoint.py,sha256=GHRnPpvG0lz8mg2qv0itHr1rXLlj-itOqZWCtOV1IRU,4411
 llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
-llm_trainer/dpo_trainer.py,sha256=
+llm_trainer/dpo_trainer.py,sha256=1A_4QP2_xqM_YeqdXy-0RaMvEL80gim-pgnPQyHww9U,12052
 llm_trainer/ds_checkpoint.py,sha256=D092fkS1Up4QmpV9YCpqbSzfX_caCAeX-UiOrhOE1I8,1947
 llm_trainer/eval.py,sha256=fjASCILU3fSPJxo9cP3rIXEEnkc5ZlUyHqXlZtUiHrw,888
 llm_trainer/generate_utils.py,sha256=CbJ3mfAD6DkQ0GUHcJQ1AK02m-ocwmd-BPXEpiwvNNQ,14933
-llm_trainer/grpo_trainer.py,sha256=
+llm_trainer/grpo_trainer.py,sha256=sCYjvksdm9f7TpN23KXuCmua_8VFTZEfVEcflL89P_I,16058
 llm_trainer/log.py,sha256=LxqTGRNZUGMTSQCePRpk-rYyxSnSIbT4kOdP8Fbzr0M,462
 llm_trainer/loss.py,sha256=NZCQeUXnLSj__mmDflE8g89KgE0emAJXIab0IERCLno,6023
 llm_trainer/parallel.py,sha256=DQu8GqEFxD99HQ6hKuIxxyKi-05dMO33eMhImYlPuOI,4468
@@ -18,16 +18,16 @@ llm_trainer/sft_trainer.py,sha256=gxQA7T1o1QGUsHp2CX1Qb_fO5LppBJuNbc0H4ixCYUA,17
 llm_trainer/tokenizer.py,sha256=SSpgXtb0e1NtQqRW0gCq09TTZi47umggy-Fh5EMHKJg,6708
 llm_trainer/tools.py,sha256=yF17lp6oOfLe2XJeKDQ1juZcbv-6vFamJSLwEeArduA,2975
 llm_trainer/train_configs.py,sha256=c6bgivkkWRYcPD3NzI5uRItAUhZiIBgKVMuMgVFRnFo,7336
-llm_trainer/trainer.py,sha256=
+llm_trainer/trainer.py,sha256=YW59dJWTyQy77cLDGzBHhfinGyfkvmWCkl1SR9hM6a8,26071
 llm_trainer/utils.py,sha256=LWNhyQ0NDEZ9mZtk2Ryvh6EulvHIaUGIflugSpqmeFI,6791
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
+project_llm_trainer-0.5.6.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
+project_llm_trainer-0.5.6.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
+project_llm_trainer-0.5.6.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
+project_llm_trainer-0.5.6.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
+project_llm_trainer-0.5.6.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
+project_llm_trainer-0.5.6.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
+project_llm_trainer-0.5.6.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
+project_llm_trainer-0.5.6.dist-info/METADATA,sha256=5JaiIS6GqDsm9o_6Zz-vkec_0rmENISg32Qld1Gn-u8,195
+project_llm_trainer-0.5.6.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
+project_llm_trainer-0.5.6.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
+project_llm_trainer-0.5.6.dist-info/RECORD,,
{project_llm_trainer-0.5.4.data → project_llm_trainer-0.5.6.data}/scripts/calc_intermediate_size
RENAMED
File without changes
{project_llm_trainer-0.5.4.data → project_llm_trainer-0.5.6.data}/scripts/ddp_train
RENAMED
File without changes
{project_llm_trainer-0.5.4.data → project_llm_trainer-0.5.6.data}/scripts/ds_train
RENAMED
File without changes
{project_llm_trainer-0.5.4.data → project_llm_trainer-0.5.6.data}/scripts/plot_loss
RENAMED
File without changes
{project_llm_trainer-0.5.4.data → project_llm_trainer-0.5.6.data}/scripts/plot_lr
RENAMED
File without changes
{project_llm_trainer-0.5.4.data → project_llm_trainer-0.5.6.data}/scripts/py_train
RENAMED
File without changes
{project_llm_trainer-0.5.4.data → project_llm_trainer-0.5.6.data}/scripts/smart_train
RENAMED
File without changes
{project_llm_trainer-0.5.4.dist-info → project_llm_trainer-0.5.6.dist-info}/WHEEL
RENAMED
File without changes
{project_llm_trainer-0.5.4.dist-info → project_llm_trainer-0.5.6.dist-info}/top_level.txt
RENAMED
File without changes