project-llm-trainer 0.5.3__py3-none-any.whl → 0.5.5__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release: this version of project-llm-trainer might be problematic.

llm_trainer/checkpoint.py CHANGED
@@ -14,17 +14,14 @@ DEFAULT_CHECKPOINT_NAME = "checkpoint.pth"
 
 def save_checkpoint(
         model: nn.Module,
-        optimizer: Optional[Optimizer] = None,
-        suffix: Optional[str] = None
+        optimizer: Optional[Optimizer] = None
 ):
     if isinstance(TrainerTools().parallel, DsParallel):
         from .ds_checkpoint import save_ds_checkpoint
-        save_ds_checkpoint(model, suffix)
+        save_ds_checkpoint(model)
     else:
         if TrainerTools().parallel.is_main_process:
             checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
-            if suffix:
-                checkpoint_name = f"{checkpoint_name}_{suffix}"
 
             raw_model = model if not isinstance(model, DDP) else model.module
             ckpt = {'model_state_dict': raw_model.state_dict()}
@@ -37,28 +34,26 @@ def save_checkpoint(
 
 def save_best_checkpoint(
         current_loss: float,
-        last_best_checkpoint_loss: float,
-        suffix: Optional[str] = None
+        last_best_checkpoint_loss: Optional[float] = None
 ) -> bool:
-    need_replace = current_loss <= last_best_checkpoint_loss
+    need_replace = not last_best_checkpoint_loss or current_loss <= last_best_checkpoint_loss
     if need_replace and TrainerTools().parallel.is_main_process:
         if isinstance(TrainerTools().parallel, DsParallel):
-            checkpoint_name = os.environ.get('DIST_CHECKPOINT_DIR', 'checkpoint')
-            if suffix:
-                checkpoint_name = f"{checkpoint_name}_{suffix}"
+            checkpoint_dir = os.environ.get('DIST_CHECKPOINT_DIR', 'checkpoint')
 
-            best_checkpoint_name = f'{checkpoint_name}_best'
-            if not os.path.exists(best_checkpoint_name):
-                os.makedirs(best_checkpoint_name)
+            if checkpoint_dir.endswith('/'):
+                best_checkpoint_dir = f'{checkpoint_dir[:-1]}_best'
+            else:
+                best_checkpoint_dir = f'{checkpoint_dir}_best'
 
-            if os.path.exists(checkpoint_name):
-                shutil.rmtree(best_checkpoint_name)
-                shutil.copytree(checkpoint_name, best_checkpoint_name)
+            if not os.path.exists(best_checkpoint_dir):
+                os.makedirs(best_checkpoint_dir)
+
+            if os.path.exists(checkpoint_dir):
+                shutil.rmtree(best_checkpoint_dir)
+                shutil.copytree(checkpoint_dir, best_checkpoint_dir)
         else:
             checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
-            if suffix:
-                checkpoint_name = f"{checkpoint_name}_{suffix}"
-
             best_checkpoint_name = f'{checkpoint_name}_best'
 
             if os.path.exists(checkpoint_name):
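Note on the new best-checkpoint test above: last_best_checkpoint_loss is now optional, so the very first call records a best checkpoint unconditionally. A minimal standalone sketch of the updated predicate (illustrative only, not imported from the package), including the edge case where a recorded best loss of exactly 0.0 is falsy and therefore also allows a save:

from typing import Optional

def need_replace(current_loss: float, last_best_checkpoint_loss: Optional[float] = None) -> bool:
    # Mirrors the updated check: an unset (None) best loss always allows the save.
    return not last_best_checkpoint_loss or current_loss <= last_best_checkpoint_loss

assert need_replace(2.5) is True        # first call: nothing recorded yet
assert need_replace(2.5, 2.4) is False  # worse than the recorded best
assert need_replace(2.3, 2.4) is True   # better than (or equal to) the recorded best
assert need_replace(5.0, 0.0) is True   # falsy 0.0 best loss behaves like "unset"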
@@ -75,16 +70,13 @@ def load_checkpoint(
         model: nn.Module,
         optimizer: Optional[Optimizer] = None,
         device: Optional[Union[torch.device, str]] = None,
-        load_module_only: bool = False,
-        suffix: Optional[str] = None
+        load_module_only: bool = False
 ):
     if isinstance(TrainerTools().parallel, DsParallel):
         from .ds_checkpoint import load_ds_checkpoint
-        load_ds_checkpoint(model, load_module_only=load_module_only, suffix=suffix)
+        load_ds_checkpoint(model, load_module_only=load_module_only)
     else:
         checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
-        if suffix:
-            checkpoint_name = f"{checkpoint_name}_{suffix}"
 
         state_dict = torch.load(checkpoint_name, weights_only=True, map_location=device)
         raw_model = model.module if isinstance(model, DDP) else model
@@ -96,14 +88,13 @@ def load_checkpoint(
 
 def load_checkpoint_for_eval(
         model: nn.Module,
-        device: Optional[Union[torch.device, str]] = None,
-        suffix: Optional[str] = None
+        device: Optional[Union[torch.device, str]] = None
 ):
     if isinstance(TrainerTools().parallel, DsParallel):
         from .ds_checkpoint import load_ds_checkpoint_for_eval
         load_ds_checkpoint_for_eval(model)
     else:
-        load_checkpoint(model, None, device, suffix=suffix)
+        load_checkpoint(model, None, device)
 
 
 def save_steps(global_steps: int, lr_scheduler: Optional[LRScheduler] = None):
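With the suffix parameter gone, the DeepSpeed branch of save_best_checkpoint above derives the best-checkpoint directory purely from DIST_CHECKPOINT_DIR, trimming a single trailing slash before appending _best. A standalone sketch of that derivation (illustrative; it does not import the package):

import os

def best_checkpoint_dir_for(checkpoint_dir: str) -> str:
    # Strip one trailing '/' so 'ckpt/' maps to 'ckpt_best' rather than 'ckpt/_best'.
    if checkpoint_dir.endswith('/'):
        return f'{checkpoint_dir[:-1]}_best'
    return f'{checkpoint_dir}_best'

print(best_checkpoint_dir_for(os.environ.get('DIST_CHECKPOINT_DIR', 'checkpoint')))  # checkpoint_best by default
print(best_checkpoint_dir_for('ckpt/'))                                              # ckpt_best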
llm_trainer/dpo_trainer.py CHANGED
@@ -137,11 +137,12 @@ class DPOTrainer(Trainer):
         # number of gradient accumulation steps
         gradient_accumulation_steps = self.train_config.gradient_accumulation_steps
         global_steps = 0
-        loss_accumulation = 0.0
         skipping_train = False
 
+        loss_accumulation = 0.0
+        batches_accumulated = 0
         current_loss: float = 0.0
-        last_best_checkpoint_loss: float = 0.0
+        last_best_checkpoint_loss: Optional[float] = None
 
         aux_loss_coef = self.train_config.loss_config.aux_loss_coef
 
@@ -214,14 +215,15 @@ class DPOTrainer(Trainer):
 
             loss_accumulation += loss.detach().item()
             self._backward_loss(loss)
+            batches_accumulated += 1
 
             if need_update_grad:
-                loss_tensor = torch.tensor(loss_accumulation, device=TrainerTools().parallel.device)
+                loss_tensor = torch.tensor(loss_accumulation * gradient_accumulation_steps / batches_accumulated, device=TrainerTools().parallel.device)
 
                 if TrainerTools().parallel.parallel_train:
                     dist.all_reduce(loss_tensor, dist.ReduceOp.AVG)
 
-                final_log_loss = loss_tensor.item()
+                current_loss = loss_tensor.item()
 
                 # DeepSpeed mode already integrates gradient_clipping
                 if not isinstance(TrainerTools().parallel, DsParallel) and self.lr_scheduler.can_clip_grad():
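The logged value is now rescaled before the all-reduce. Assuming each micro-batch loss has already been divided by gradient_accumulation_steps before it is accumulated (the usual gradient-accumulation convention; the scaling itself is not visible in this hunk), multiplying by gradient_accumulation_steps / batches_accumulated recovers the mean per-batch loss even when the final accumulation window contains fewer batches than configured. A small numeric sketch, illustrative only:

# Illustrative only: per-batch losses for a short final accumulation window.
raw_losses = [2.0, 1.6, 1.8]                  # 3 batches actually accumulated
gradient_accumulation_steps = 4               # configured window size

# What the trainer accumulates: each backward loss is assumed pre-scaled by 1/gradient_accumulation_steps.
loss_accumulation = sum(l / gradient_accumulation_steps for l in raw_losses)   # 1.35
batches_accumulated = len(raw_losses)

# Old logging reported 1.35, understating the mean because only 3 of 4 slots were filled.
# New logging rescales by steps/batches to recover the mean of the raw losses.
current_loss = loss_accumulation * gradient_accumulation_steps / batches_accumulated
assert abs(current_loss - sum(raw_losses) / len(raw_losses)) < 1e-9            # 1.8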
@@ -235,10 +237,11 @@ class DPOTrainer(Trainer):
                     epoch_tag=f'epoch: {epoch}',
                     file_tag=f'file: {file_idx + 1}/{file_count}',
                     batch_tag=f'batch: {batch}/{batch_count_per_file}',
-                    loss=final_log_loss
+                    loss=current_loss
                 )
                 # reset to default
                 loss_accumulation = 0.0
+                batches_accumulated = 0
         except Exception as e:
             self._on_exception(e, epoch, batch)
         finally:
llm_trainer/ds_checkpoint.py CHANGED
@@ -1,5 +1,4 @@
 import os
-from typing import Optional
 from glob import glob
 import shutil
 from torch import nn
@@ -17,14 +16,9 @@ load_state_dict_from_zero_checkpoint loads the model and optimizer from a ZeRO checkpoint
 convert_zero_checkpoint_to_fp32_state_dict  converts a ZeRO checkpoint into a standalone FP32 state-dict file  No  Yes  creates a portable FP32 weights file for deployment, sharing, etc.
 """
 
-def save_ds_checkpoint(
-        model: nn.Module,
-        suffix: Optional[str] = None
-):
+def save_ds_checkpoint(model: nn.Module):
     assert isinstance(model, DeepSpeedEngine)
     ckpt_dir = os.environ.get('DIST_CHECKPOINT_DIR', 'checkpoint')
-    if suffix:
-        ckpt_dir = f"{ckpt_dir}_{suffix}"
 
     try:
         # includes model, optimizer, and other state
@@ -44,13 +38,10 @@ def save_ds_checkpoint(
 
 def load_ds_checkpoint(
         model: nn.Module,
-        load_module_only: bool = False,
-        suffix: Optional[str] = None
+        load_module_only: bool = False
 ):
     assert isinstance(model, DeepSpeedEngine)
     ckpt_dir = os.environ.get('DIST_CHECKPOINT_DIR', 'checkpoint')
-    if suffix:
-        ckpt_dir = f"{ckpt_dir}_{suffix}"
 
     # includes model, optimizer, and other state
     if os.path.exists(ckpt_dir):
llm_trainer/grpo_trainer.py CHANGED
@@ -283,7 +283,7 @@ class GRPOTrainer(Trainer):
         skipping_train = False
 
         current_loss: float = 0.0
-        last_best_checkpoint_loss: float = 0.0
+        last_best_checkpoint_loss: Optional[float] = None
 
         aux_loss_coef = self.train_config.loss_config.aux_loss_coef
 
@@ -345,6 +345,8 @@ class GRPOTrainer(Trainer):
             if TrainerTools().parallel.parallel_train:
                 dist.all_reduce(loss, dist.ReduceOp.AVG)
 
+            current_loss = loss.detach().item()
+
             # DeepSpeed mode already integrates gradient_clipping
             if not isinstance(TrainerTools().parallel, DsParallel) and self.lr_scheduler.can_clip_grad():
                 # clip grad
@@ -357,7 +359,7 @@ class GRPOTrainer(Trainer):
                 epoch_tag=f'epoch: {epoch}',
                 file_tag=f'file: {file_idx + 1}/{file_count}',
                 batch_tag=f'batch: {batch}/{batch_count_per_file}',
-                loss=loss.detach().item()
+                loss=current_loss
             )
         except Exception as e:
             self._on_exception(e, epoch, batch)
llm_trainer/trainer.py CHANGED
@@ -465,11 +465,12 @@ class Trainer:
         # number of gradient accumulation steps
         gradient_accumulation_steps = self.train_config.gradient_accumulation_steps
         global_steps = 0
-        loss_accumulation = 0.0
         skipping_train = False
 
+        loss_accumulation = 0.0
+        batches_accumulated = 0
         current_loss: float = 0.0
-        last_best_checkpoint_loss: float = 0.0
+        last_best_checkpoint_loss: Optional[float] = None
 
         for epoch in range(self.train_config.n_epochs):
             self.train_model.train()
@@ -536,9 +537,10 @@ class Trainer:
 
             loss_accumulation += loss.detach().item()
             self._backward_loss(loss)
+            batches_accumulated += 1
 
             if need_update_grad:
-                loss_tensor = torch.tensor(loss_accumulation, device=TrainerTools().parallel.device)
+                loss_tensor = torch.tensor(loss_accumulation * gradient_accumulation_steps / batches_accumulated, device=TrainerTools().parallel.device)
 
                 if TrainerTools().parallel.parallel_train:
                     dist.all_reduce(loss_tensor, dist.ReduceOp.AVG)
@@ -561,6 +563,7 @@ class Trainer:
                 )
                 # reset to default
                 loss_accumulation = 0.0
+                batches_accumulated = 0
         except Exception as e:
             self._on_exception(e, epoch, batch)
         finally:
project_llm_trainer-0.5.3.dist-info/METADATA → project_llm_trainer-0.5.5.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: project_llm_trainer
-Version: 0.5.3
+Version: 0.5.5
 Summary: LLM and VLM trainer
 Author: qibin
 Author-email: qibin0506@gmail.com
project_llm_trainer-0.5.3.dist-info/RECORD → project_llm_trainer-0.5.5.dist-info/RECORD CHANGED
@@ -1,11 +1,11 @@
 llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
-llm_trainer/checkpoint.py,sha256=67q1zDYvcbS2zce1PVt3nmsPzqniHu0f2pI-cyyCkng,4647
+llm_trainer/checkpoint.py,sha256=UVjOaDsiSIzRJ5VJZib6iXrdKv2A7K_gtJw3a9wNyoM,4293
 llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
-llm_trainer/dpo_trainer.py,sha256=xfYXlLA5TbqPKCUbk5_V79TreEh-dnLMaN72a3-Tdzg,11860
-llm_trainer/ds_checkpoint.py,sha256=wz48HoLBBt8QGO1tXfvJwrXoiGtPG_gjwHfEqARllso,2175
+llm_trainer/dpo_trainer.py,sha256=1A_4QP2_xqM_YeqdXy-0RaMvEL80gim-pgnPQyHww9U,12052
+llm_trainer/ds_checkpoint.py,sha256=D092fkS1Up4QmpV9YCpqbSzfX_caCAeX-UiOrhOE1I8,1947
 llm_trainer/eval.py,sha256=fjASCILU3fSPJxo9cP3rIXEEnkc5ZlUyHqXlZtUiHrw,888
 llm_trainer/generate_utils.py,sha256=CbJ3mfAD6DkQ0GUHcJQ1AK02m-ocwmd-BPXEpiwvNNQ,14933
-llm_trainer/grpo_trainer.py,sha256=vTNi3n6R4NbwFh_s8LYN1TWEJm8AW2F5NVJlT5MHxKk,15990
+llm_trainer/grpo_trainer.py,sha256=sCYjvksdm9f7TpN23KXuCmua_8VFTZEfVEcflL89P_I,16058
 llm_trainer/log.py,sha256=LxqTGRNZUGMTSQCePRpk-rYyxSnSIbT4kOdP8Fbzr0M,462
 llm_trainer/loss.py,sha256=NZCQeUXnLSj__mmDflE8g89KgE0emAJXIab0IERCLno,6023
 llm_trainer/parallel.py,sha256=DQu8GqEFxD99HQ6hKuIxxyKi-05dMO33eMhImYlPuOI,4468
@@ -18,16 +18,16 @@ llm_trainer/sft_trainer.py,sha256=gxQA7T1o1QGUsHp2CX1Qb_fO5LppBJuNbc0H4ixCYUA,17
 llm_trainer/tokenizer.py,sha256=SSpgXtb0e1NtQqRW0gCq09TTZi47umggy-Fh5EMHKJg,6708
 llm_trainer/tools.py,sha256=yF17lp6oOfLe2XJeKDQ1juZcbv-6vFamJSLwEeArduA,2975
 llm_trainer/train_configs.py,sha256=c6bgivkkWRYcPD3NzI5uRItAUhZiIBgKVMuMgVFRnFo,7336
-llm_trainer/trainer.py,sha256=g8YUP0FmBP3MGwewyoyOW35p9CY98rS62pzjnOMiWvE,25875
+llm_trainer/trainer.py,sha256=YW59dJWTyQy77cLDGzBHhfinGyfkvmWCkl1SR9hM6a8,26071
 llm_trainer/utils.py,sha256=LWNhyQ0NDEZ9mZtk2Ryvh6EulvHIaUGIflugSpqmeFI,6791
-project_llm_trainer-0.5.3.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
-project_llm_trainer-0.5.3.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
-project_llm_trainer-0.5.3.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
-project_llm_trainer-0.5.3.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
-project_llm_trainer-0.5.3.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
-project_llm_trainer-0.5.3.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
-project_llm_trainer-0.5.3.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
-project_llm_trainer-0.5.3.dist-info/METADATA,sha256=iDB3C1trVLQsnwsRxeFm7Oi2YpNevuX3XO2WZFlL7wg,195
-project_llm_trainer-0.5.3.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
-project_llm_trainer-0.5.3.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
-project_llm_trainer-0.5.3.dist-info/RECORD,,
+project_llm_trainer-0.5.5.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
+project_llm_trainer-0.5.5.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
+project_llm_trainer-0.5.5.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
+project_llm_trainer-0.5.5.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
+project_llm_trainer-0.5.5.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
+project_llm_trainer-0.5.5.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
+project_llm_trainer-0.5.5.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
+project_llm_trainer-0.5.5.dist-info/METADATA,sha256=ajxfapuo4Q2xfdJ3kjZoCzs7Q5ynGp6BssXRFOIbF7Y,195
+project_llm_trainer-0.5.5.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
+project_llm_trainer-0.5.5.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
+project_llm_trainer-0.5.5.dist-info/RECORD,,