project-llm-trainer 0.5.3__py3-none-any.whl → 0.5.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of project-llm-trainer might be problematic; see the registry's advisory page for this release for more details.

llm_trainer/checkpoint.py CHANGED
@@ -14,17 +14,14 @@ DEFAULT_CHECKPOINT_NAME = "checkpoint.pth"
14
14
 
15
15
  def save_checkpoint(
16
16
  model: nn.Module,
17
- optimizer: Optional[Optimizer] = None,
18
- suffix: Optional[str] = None
17
+ optimizer: Optional[Optimizer] = None
19
18
  ):
20
19
  if isinstance(TrainerTools().parallel, DsParallel):
21
20
  from .ds_checkpoint import save_ds_checkpoint
22
- save_ds_checkpoint(model, suffix)
21
+ save_ds_checkpoint(model)
23
22
  else:
24
23
  if TrainerTools().parallel.is_main_process:
25
24
  checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
26
- if suffix:
27
- checkpoint_name = f"{checkpoint_name}_{suffix}"
28
25
 
29
26
  raw_model = model if not isinstance(model, DDP) else model.module
30
27
  ckpt = {'model_state_dict': raw_model.state_dict()}
@@ -37,28 +34,26 @@ def save_checkpoint(
37
34
 
38
35
  def save_best_checkpoint(
39
36
  current_loss: float,
40
- last_best_checkpoint_loss: float,
41
- suffix: Optional[str] = None
37
+ last_best_checkpoint_loss: Optional[float] = None
42
38
  ) -> bool:
43
- need_replace = current_loss <= last_best_checkpoint_loss
39
+ need_replace = not last_best_checkpoint_loss or current_loss <= last_best_checkpoint_loss
44
40
  if need_replace and TrainerTools().parallel.is_main_process:
45
41
  if isinstance(TrainerTools().parallel, DsParallel):
46
- checkpoint_name = os.environ.get('DIST_CHECKPOINT_DIR', 'checkpoint')
47
- if suffix:
48
- checkpoint_name = f"{checkpoint_name}_{suffix}"
42
+ checkpoint_dir = os.environ.get('DIST_CHECKPOINT_DIR', 'checkpoint')
49
43
 
50
- best_checkpoint_name = f'{checkpoint_name}_best'
51
- if not os.path.exists(best_checkpoint_name):
52
- os.makedirs(best_checkpoint_name)
44
+ if checkpoint_dir.endswith('/'):
45
+ best_checkpoint_dir = f'{checkpoint_dir[:-1]}_best'
46
+ else:
47
+ best_checkpoint_dir = f'{checkpoint_dir}_best'
53
48
 
54
- if os.path.exists(checkpoint_name):
55
- shutil.rmtree(best_checkpoint_name)
56
- shutil.copytree(checkpoint_name, best_checkpoint_name)
49
+ if not os.path.exists(best_checkpoint_dir):
50
+ os.makedirs(best_checkpoint_dir)
51
+
52
+ if os.path.exists(checkpoint_dir):
53
+ shutil.rmtree(best_checkpoint_dir)
54
+ shutil.copytree(checkpoint_dir, best_checkpoint_dir)
57
55
  else:
58
56
  checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
59
- if suffix:
60
- checkpoint_name = f"{checkpoint_name}_{suffix}"
61
-
62
57
  best_checkpoint_name = f'{checkpoint_name}_best'
63
58
 
64
59
  if os.path.exists(checkpoint_name):
@@ -75,16 +70,13 @@ def load_checkpoint(
75
70
  model: nn.Module,
76
71
  optimizer: Optional[Optimizer] = None,
77
72
  device: Optional[Union[torch.device, str]] = None,
78
- load_module_only: bool = False,
79
- suffix: Optional[str] = None
73
+ load_module_only: bool = False
80
74
  ):
81
75
  if isinstance(TrainerTools().parallel, DsParallel):
82
76
  from .ds_checkpoint import load_ds_checkpoint
83
- load_ds_checkpoint(model, load_module_only=load_module_only, suffix=suffix)
77
+ load_ds_checkpoint(model, load_module_only=load_module_only)
84
78
  else:
85
79
  checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
86
- if suffix:
87
- checkpoint_name = f"{checkpoint_name}_{suffix}"
88
80
 
89
81
  state_dict = torch.load(checkpoint_name, weights_only=True, map_location=device)
90
82
  raw_model = model.module if isinstance(model, DDP) else model
@@ -96,14 +88,13 @@ def load_checkpoint(
96
88
 
97
89
  def load_checkpoint_for_eval(
98
90
  model: nn.Module,
99
- device: Optional[Union[torch.device, str]] = None,
100
- suffix: Optional[str] = None
91
+ device: Optional[Union[torch.device, str]] = None
101
92
  ):
102
93
  if isinstance(TrainerTools().parallel, DsParallel):
103
94
  from .ds_checkpoint import load_ds_checkpoint_for_eval
104
95
  load_ds_checkpoint_for_eval(model)
105
96
  else:
106
- load_checkpoint(model, None, device, suffix=suffix)
97
+ load_checkpoint(model, None, device)
107
98
 
108
99
 
109
100
  def save_steps(global_steps: int, lr_scheduler: Optional[LRScheduler] = None):
@@ -141,7 +141,7 @@ class DPOTrainer(Trainer):
141
141
  skipping_train = False
142
142
 
143
143
  current_loss: float = 0.0
144
- last_best_checkpoint_loss: float = 0.0
144
+ last_best_checkpoint_loss: Optional[float] = None
145
145
 
146
146
  aux_loss_coef = self.train_config.loss_config.aux_loss_coef
147
147
 
@@ -1,5 +1,4 @@
1
1
  import os
2
- from typing import Optional
3
2
  from glob import glob
4
3
  import shutil
5
4
  from torch import nn
@@ -17,14 +16,9 @@ load_state_dict_from_zero_checkpoint 从 ZeRO 检查点加载模型和优化器
17
16
  convert_zero_checkpoint_to_fp32_state_dict 将 ZeRO 检查点转换为独立的 FP32 状态字典文件 否 是 创建可移植的 FP32 权重文件,用于部署、分享等
18
17
  """
19
18
 
20
- def save_ds_checkpoint(
21
- model: nn.Module,
22
- suffix: Optional[str] = None
23
- ):
19
+ def save_ds_checkpoint(model: nn.Module):
24
20
  assert isinstance(model, DeepSpeedEngine)
25
21
  ckpt_dir = os.environ.get('DIST_CHECKPOINT_DIR', 'checkpoint')
26
- if suffix:
27
- ckpt_dir = f"{ckpt_dir}_{suffix}"
28
22
 
29
23
  try:
30
24
  # 包括model、optimizer等状态
@@ -44,13 +38,10 @@ def save_ds_checkpoint(
44
38
 
45
39
  def load_ds_checkpoint(
46
40
  model: nn.Module,
47
- load_module_only: bool = False,
48
- suffix: Optional[str] = None
41
+ load_module_only: bool = False
49
42
  ):
50
43
  assert isinstance(model, DeepSpeedEngine)
51
44
  ckpt_dir = os.environ.get('DIST_CHECKPOINT_DIR', 'checkpoint')
52
- if suffix:
53
- ckpt_dir = f"{ckpt_dir}_{suffix}"
54
45
 
55
46
  # 包括model、optimizer等状态
56
47
  if os.path.exists(ckpt_dir):
@@ -283,7 +283,7 @@ class GRPOTrainer(Trainer):
283
283
  skipping_train = False
284
284
 
285
285
  current_loss: float = 0.0
286
- last_best_checkpoint_loss: float = 0.0
286
+ last_best_checkpoint_loss: Optional[float] = None
287
287
 
288
288
  aux_loss_coef = self.train_config.loss_config.aux_loss_coef
289
289
 
llm_trainer/trainer.py CHANGED
@@ -469,7 +469,7 @@ class Trainer:
469
469
  skipping_train = False
470
470
 
471
471
  current_loss: float = 0.0
472
- last_best_checkpoint_loss: float = 0.0
472
+ last_best_checkpoint_loss: Optional[float] = None
473
473
 
474
474
  for epoch in range(self.train_config.n_epochs):
475
475
  self.train_model.train()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: project_llm_trainer
3
- Version: 0.5.3
3
+ Version: 0.5.4
4
4
  Summary: LLM and VLM trainer
5
5
  Author: qibin
6
6
  Author-email: qibin0506@gmail.com
@@ -1,11 +1,11 @@
1
1
  llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
2
- llm_trainer/checkpoint.py,sha256=67q1zDYvcbS2zce1PVt3nmsPzqniHu0f2pI-cyyCkng,4647
2
+ llm_trainer/checkpoint.py,sha256=UVjOaDsiSIzRJ5VJZib6iXrdKv2A7K_gtJw3a9wNyoM,4293
3
3
  llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
4
- llm_trainer/dpo_trainer.py,sha256=xfYXlLA5TbqPKCUbk5_V79TreEh-dnLMaN72a3-Tdzg,11860
5
- llm_trainer/ds_checkpoint.py,sha256=wz48HoLBBt8QGO1tXfvJwrXoiGtPG_gjwHfEqARllso,2175
4
+ llm_trainer/dpo_trainer.py,sha256=3hCjX06W6nt-8lio0YyGzXdI_saa5QlypXehQTtacO4,11871
5
+ llm_trainer/ds_checkpoint.py,sha256=D092fkS1Up4QmpV9YCpqbSzfX_caCAeX-UiOrhOE1I8,1947
6
6
  llm_trainer/eval.py,sha256=fjASCILU3fSPJxo9cP3rIXEEnkc5ZlUyHqXlZtUiHrw,888
7
7
  llm_trainer/generate_utils.py,sha256=CbJ3mfAD6DkQ0GUHcJQ1AK02m-ocwmd-BPXEpiwvNNQ,14933
8
- llm_trainer/grpo_trainer.py,sha256=vTNi3n6R4NbwFh_s8LYN1TWEJm8AW2F5NVJlT5MHxKk,15990
8
+ llm_trainer/grpo_trainer.py,sha256=VltU9AngecJ5PEtpn_ToXfiLIc6VwwgOVwlLh-X2Je8,16001
9
9
  llm_trainer/log.py,sha256=LxqTGRNZUGMTSQCePRpk-rYyxSnSIbT4kOdP8Fbzr0M,462
10
10
  llm_trainer/loss.py,sha256=NZCQeUXnLSj__mmDflE8g89KgE0emAJXIab0IERCLno,6023
11
11
  llm_trainer/parallel.py,sha256=DQu8GqEFxD99HQ6hKuIxxyKi-05dMO33eMhImYlPuOI,4468
@@ -18,16 +18,16 @@ llm_trainer/sft_trainer.py,sha256=gxQA7T1o1QGUsHp2CX1Qb_fO5LppBJuNbc0H4ixCYUA,17
18
18
  llm_trainer/tokenizer.py,sha256=SSpgXtb0e1NtQqRW0gCq09TTZi47umggy-Fh5EMHKJg,6708
19
19
  llm_trainer/tools.py,sha256=yF17lp6oOfLe2XJeKDQ1juZcbv-6vFamJSLwEeArduA,2975
20
20
  llm_trainer/train_configs.py,sha256=c6bgivkkWRYcPD3NzI5uRItAUhZiIBgKVMuMgVFRnFo,7336
21
- llm_trainer/trainer.py,sha256=g8YUP0FmBP3MGwewyoyOW35p9CY98rS62pzjnOMiWvE,25875
21
+ llm_trainer/trainer.py,sha256=VHnuL8rZgxj2ewBVxmbN0jwVEPRVhbK2V2DYCtm_FxI,25886
22
22
  llm_trainer/utils.py,sha256=LWNhyQ0NDEZ9mZtk2Ryvh6EulvHIaUGIflugSpqmeFI,6791
23
- project_llm_trainer-0.5.3.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
24
- project_llm_trainer-0.5.3.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
25
- project_llm_trainer-0.5.3.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
26
- project_llm_trainer-0.5.3.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
27
- project_llm_trainer-0.5.3.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
28
- project_llm_trainer-0.5.3.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
29
- project_llm_trainer-0.5.3.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
30
- project_llm_trainer-0.5.3.dist-info/METADATA,sha256=iDB3C1trVLQsnwsRxeFm7Oi2YpNevuX3XO2WZFlL7wg,195
31
- project_llm_trainer-0.5.3.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
32
- project_llm_trainer-0.5.3.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
33
- project_llm_trainer-0.5.3.dist-info/RECORD,,
23
+ project_llm_trainer-0.5.4.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
24
+ project_llm_trainer-0.5.4.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
25
+ project_llm_trainer-0.5.4.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
26
+ project_llm_trainer-0.5.4.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
27
+ project_llm_trainer-0.5.4.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
28
+ project_llm_trainer-0.5.4.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
29
+ project_llm_trainer-0.5.4.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
30
+ project_llm_trainer-0.5.4.dist-info/METADATA,sha256=lZWPvJQFiqZTl9b1FUSFl1Fl6berO7DKGZ-A_7ZkidE,195
31
+ project_llm_trainer-0.5.4.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
32
+ project_llm_trainer-0.5.4.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
33
+ project_llm_trainer-0.5.4.dist-info/RECORD,,