project-llm-trainer: project_llm_trainer-0.5.6-py3-none-any.whl → project_llm_trainer-0.5.8-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of project-llm-trainer might be problematic. More details are available on the package's registry page.

llm_trainer/checkpoint.py CHANGED
@@ -104,17 +104,14 @@ def save_steps(global_steps: int, lr_scheduler: Optional[LRScheduler] = None):
104
104
  # 暂时只保存主进程的
105
105
  if TrainerTools().parallel.is_main_process:
106
106
  steps_checkpoint_name = f"{os.environ.get('LOG_DIR', './')}steps.pt"
107
- ckpt = {'global_steps': global_steps, 'lr_steps': lr_scheduler.cur_steps}
107
+ ckpt = {'global_steps': global_steps}
108
+ ckpt.update(lr_scheduler.get_ckpt_dict())
108
109
  torch.save(ckpt, steps_checkpoint_name)
109
110
 
110
111
 
111
- def load_steps(
112
- default_global_steps: int = 0,
113
- default_lr_steps: int = 0
114
- ) -> Tuple[Optional[int], Optional[int]]:
112
+ def load_steps() -> Optional[dict]:
115
113
  steps_checkpoint_name = f"{os.environ.get('LOG_DIR', './')}steps.pt"
116
114
  if os.path.exists(steps_checkpoint_name):
117
- ckpt = torch.load(steps_checkpoint_name, weights_only=True)
118
- return ckpt['global_steps'], ckpt['lr_steps']
115
+ return torch.load(steps_checkpoint_name, weights_only=True)
119
116
 
120
- return default_global_steps, default_lr_steps
117
+ return None
llm_trainer/scheduler.py CHANGED
@@ -15,15 +15,18 @@ class LRScheduler(ABC):
15
15
  @abstractmethod
16
16
  def cur_lr(self): ...
17
17
 
18
- @abstractmethod
19
- def update_steps(self, steps): ...
20
-
21
18
  @abstractmethod
22
19
  def step(self): ...
23
20
 
24
21
  @abstractmethod
25
22
  def can_clip_grad(self): ...
26
23
 
24
+ @abstractmethod
25
+ def get_ckpt_dict(self) -> dict: ...
26
+
27
+ @abstractmethod
28
+ def restore_ckpt_dict(self, ckpt: dict): ...
29
+
27
30
 
28
31
  class WarmupCosineAnnealingLRScheduler(LRScheduler):
29
32
  def __init__(
@@ -72,11 +75,6 @@ class WarmupCosineAnnealingLRScheduler(LRScheduler):
72
75
  def cur_lr(self):
73
76
  return self._current_lr
74
77
 
75
- def update_steps(self, steps):
76
- log(f'update step to {steps}')
77
- self._steps = steps
78
- self._update_lr()
79
-
80
78
  def step(self):
81
79
  self._steps += 1
82
80
  self._update_lr()
@@ -122,6 +120,33 @@ class WarmupCosineAnnealingLRScheduler(LRScheduler):
122
120
  if self.need_log:
123
121
  log(f"step={self.cur_steps},lr={lr}\n", f'{get_log_dir()}lr.txt')
124
122
 
123
+ def get_ckpt_dict(self) -> dict:
124
+ return {
125
+ 'cur_lr': self._current_lr,
126
+ 'lr_steps': self.cur_steps,
127
+ 'cosine_annealing_base_lr': self._cosine_annealing_base_lr,
128
+ 't_cur': self.T_cur,
129
+ 'cycle': self.cycle,
130
+ }
131
+
132
+ def restore_ckpt_dict(self, ckpt: dict):
133
+ if ckpt['cur_lr']:
134
+ self._current_lr = ckpt['cur_lr']
135
+
136
+ if ckpt['lr_steps']:
137
+ self._steps = ckpt['lr_steps']
138
+
139
+ if ckpt['cosine_annealing_base_lr']:
140
+ self._cosine_annealing_base_lr = ckpt['cosine_annealing_base_lr']
141
+
142
+ if ckpt['t_cur']:
143
+ self.T_cur = ckpt['t_cur']
144
+
145
+ if ckpt['cycle']:
146
+ self.cycle = ckpt['cycle']
147
+
148
+ self._update_lr()
149
+
125
150
 
126
151
  class NoneLRScheduler(LRScheduler):
127
152
  def __init__(self, initial_lr):
@@ -135,9 +160,14 @@ class NoneLRScheduler(LRScheduler):
135
160
  def cur_lr(self):
136
161
  return self._current_lr
137
162
 
138
- def update_steps(self, steps): ...
139
-
140
163
  def step(self): ...
141
164
 
142
165
  def can_clip_grad(self):
143
- return True
166
+ return True
167
+
168
+ def get_ckpt_dict(self) -> dict:
169
+ return {'cur_lr': self._current_lr}
170
+
171
+ def restore_ckpt_dict(self, ckpt: dict):
172
+ if ckpt['cur_lr']:
173
+ self._current_lr = ckpt['cur_lr']
llm_trainer/trainer.py CHANGED
@@ -59,6 +59,7 @@ class Trainer:
59
59
  self.eval_prompts = eval_prompts
60
60
  self.eval_image_tags = eval_image_tags
61
61
  self.eval_idx = -1
62
+ self.last_global_steps = 0
62
63
 
63
64
  if self.eval_image_tags:
64
65
  assert len(self.eval_prompts) == len(self.eval_image_tags)
@@ -92,12 +93,15 @@ class Trainer:
92
93
  device=TrainerTools().parallel.device
93
94
  )
94
95
 
95
- last_global_steps, last_lr_steps = load_steps(0, -1)
96
- self.last_global_steps = last_global_steps
97
- log(f'last_global_steps={last_global_steps}, last_lr_steps={last_lr_steps}')
96
+ steps_dict = load_steps()
97
+ if steps_dict:
98
+ self.last_global_steps = steps_dict['global_steps']
99
+ if not self.last_global_steps:
100
+ self.last_global_steps = 0
98
101
 
99
- if last_lr_steps != -1:
100
- self.lr_scheduler.update_steps(last_lr_steps)
102
+ self.lr_scheduler.restore_ckpt_dict(steps_dict)
103
+
104
+ log(f'restore steps_dict = {steps_dict}')
101
105
 
102
106
  if isinstance(train_config.model_config, VLMConfig):
103
107
  self.pixel_values_provider = train_config.pixel_values_provider
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: project_llm_trainer
3
- Version: 0.5.6
3
+ Version: 0.5.8
4
4
  Summary: LLM and VLM trainer
5
5
  Author: qibin
6
6
  Author-email: qibin0506@gmail.com
@@ -1,5 +1,5 @@
1
1
  llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
2
- llm_trainer/checkpoint.py,sha256=GHRnPpvG0lz8mg2qv0itHr1rXLlj-itOqZWCtOV1IRU,4411
2
+ llm_trainer/checkpoint.py,sha256=RoRlIB-Qtvl3MyY3g0FbEBHLpFRLnEMZLpEncXOLToQ,4242
3
3
  llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
4
4
  llm_trainer/dpo_trainer.py,sha256=1A_4QP2_xqM_YeqdXy-0RaMvEL80gim-pgnPQyHww9U,12052
5
5
  llm_trainer/ds_checkpoint.py,sha256=D092fkS1Up4QmpV9YCpqbSzfX_caCAeX-UiOrhOE1I8,1947
@@ -13,21 +13,21 @@ llm_trainer/parallel_ddp.py,sha256=Pob9vUlBZnkL4oP1Re11kFob7nufMSE96pn7m7fuOEM,1
13
13
  llm_trainer/parallel_ds.py,sha256=oy8RRxHud3rACWubFlJqqd0pjPEQhKeAPGPQUSdJX2c,1145
14
14
  llm_trainer/parallel_none.py,sha256=TG6Pm829Dg-yQu-97O-EHV3FCARBlNcP47KkGFAs16E,676
15
15
  llm_trainer/partition_utils.py,sha256=xzv8kwlbKp3dai2pBwX89gN5ymeHk1bGbTkGru5H-UM,5167
16
- llm_trainer/scheduler.py,sha256=lyC9TFuF_y8EXYq9d-WAqN4CSaq_w9kSKeh_BOo3EpI,4039
16
+ llm_trainer/scheduler.py,sha256=LAI_0VxClsIQkix0bRoduRD4vPfVuIZDhZgTAT_KK8k,4901
17
17
  llm_trainer/sft_trainer.py,sha256=gxQA7T1o1QGUsHp2CX1Qb_fO5LppBJuNbc0H4ixCYUA,1783
18
18
  llm_trainer/tokenizer.py,sha256=SSpgXtb0e1NtQqRW0gCq09TTZi47umggy-Fh5EMHKJg,6708
19
19
  llm_trainer/tools.py,sha256=yF17lp6oOfLe2XJeKDQ1juZcbv-6vFamJSLwEeArduA,2975
20
20
  llm_trainer/train_configs.py,sha256=c6bgivkkWRYcPD3NzI5uRItAUhZiIBgKVMuMgVFRnFo,7336
21
- llm_trainer/trainer.py,sha256=YW59dJWTyQy77cLDGzBHhfinGyfkvmWCkl1SR9hM6a8,26071
21
+ llm_trainer/trainer.py,sha256=U26dZc22nByfTZUzKeEiqqYVexBzgw0ep7N0Z2zIcWI,26141
22
22
  llm_trainer/utils.py,sha256=LWNhyQ0NDEZ9mZtk2Ryvh6EulvHIaUGIflugSpqmeFI,6791
23
- project_llm_trainer-0.5.6.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
24
- project_llm_trainer-0.5.6.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
25
- project_llm_trainer-0.5.6.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
26
- project_llm_trainer-0.5.6.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
27
- project_llm_trainer-0.5.6.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
28
- project_llm_trainer-0.5.6.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
29
- project_llm_trainer-0.5.6.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
30
- project_llm_trainer-0.5.6.dist-info/METADATA,sha256=5JaiIS6GqDsm9o_6Zz-vkec_0rmENISg32Qld1Gn-u8,195
31
- project_llm_trainer-0.5.6.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
32
- project_llm_trainer-0.5.6.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
33
- project_llm_trainer-0.5.6.dist-info/RECORD,,
23
+ project_llm_trainer-0.5.8.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
24
+ project_llm_trainer-0.5.8.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
25
+ project_llm_trainer-0.5.8.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
26
+ project_llm_trainer-0.5.8.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
27
+ project_llm_trainer-0.5.8.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
28
+ project_llm_trainer-0.5.8.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
29
+ project_llm_trainer-0.5.8.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
30
+ project_llm_trainer-0.5.8.dist-info/METADATA,sha256=54q4Nl2EWMYwShSqS8cLLqZ0iJVraAvJCGz8QPiVMiE,195
31
+ project_llm_trainer-0.5.8.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
32
+ project_llm_trainer-0.5.8.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
33
+ project_llm_trainer-0.5.8.dist-info/RECORD,,