project-llm-trainer 0.5.5__py3-none-any.whl → 0.5.7__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release.



llm_trainer/checkpoint.py CHANGED
@@ -38,29 +38,32 @@ def save_best_checkpoint(
 ) -> bool:
     need_replace = not last_best_checkpoint_loss or current_loss <= last_best_checkpoint_loss
     if need_replace and TrainerTools().parallel.is_main_process:
-        if isinstance(TrainerTools().parallel, DsParallel):
-            checkpoint_dir = os.environ.get('DIST_CHECKPOINT_DIR', 'checkpoint')
+        try:
+            if isinstance(TrainerTools().parallel, DsParallel):
+                checkpoint_dir = os.environ.get('DIST_CHECKPOINT_DIR', 'checkpoint')
 
-            if checkpoint_dir.endswith('/'):
-                best_checkpoint_dir = f'{checkpoint_dir[:-1]}_best'
-            else:
-                best_checkpoint_dir = f'{checkpoint_dir}_best'
+                if checkpoint_dir.endswith('/'):
+                    best_checkpoint_dir = f'{checkpoint_dir[:-1]}_best'
+                else:
+                    best_checkpoint_dir = f'{checkpoint_dir}_best'
 
-            if not os.path.exists(best_checkpoint_dir):
-                os.makedirs(best_checkpoint_dir)
+                if not os.path.exists(best_checkpoint_dir):
+                    os.makedirs(best_checkpoint_dir)
 
-            if os.path.exists(checkpoint_dir):
-                shutil.rmtree(best_checkpoint_dir)
-                shutil.copytree(checkpoint_dir, best_checkpoint_dir)
-        else:
-            checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
-            best_checkpoint_name = f'{checkpoint_name}_best'
+                if os.path.exists(checkpoint_dir):
+                    shutil.rmtree(best_checkpoint_dir)
+                    shutil.copytree(checkpoint_dir, best_checkpoint_dir)
+            else:
+                checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
+                best_checkpoint_name = f'{checkpoint_name}_best'
 
-            if os.path.exists(checkpoint_name):
-                if os.path.exists(best_checkpoint_name):
-                    os.remove(best_checkpoint_name)
+                if os.path.exists(checkpoint_name):
+                    if os.path.exists(best_checkpoint_name):
+                        os.remove(best_checkpoint_name)
 
-                shutil.copy2(checkpoint_name, best_checkpoint_name)
+                    shutil.copy2(checkpoint_name, best_checkpoint_name)
+        except:
+            pass
 
     TrainerTools().parallel.wait()
     return need_replace
@@ -101,17 +104,14 @@ def save_steps(global_steps: int, lr_scheduler: Optional[LRScheduler] = None):
     # For now, only the main process's steps are saved.
     if TrainerTools().parallel.is_main_process:
         steps_checkpoint_name = f"{os.environ.get('LOG_DIR', './')}steps.pt"
-        ckpt = {'global_steps': global_steps, 'lr_steps': lr_scheduler.cur_steps}
+        ckpt = {'global_steps': global_steps}
+        ckpt.update(lr_scheduler.get_ckpt_dict())
         torch.save(ckpt, steps_checkpoint_name)
 
 
-def load_steps(
-        default_global_steps: int = 0,
-        default_lr_steps: int = 0
-) -> Tuple[Optional[int], Optional[int]]:
+def load_steps() -> Optional[dict]:
     steps_checkpoint_name = f"{os.environ.get('LOG_DIR', './')}steps.pt"
     if os.path.exists(steps_checkpoint_name):
-        ckpt = torch.load(steps_checkpoint_name, weights_only=True)
-        return ckpt['global_steps'], ckpt['lr_steps']
+        return torch.load(steps_checkpoint_name, weights_only=True)
 
-    return default_global_steps, default_lr_steps
+    return None
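The net effect on the steps checkpoint: save_steps now flattens the scheduler's state into the same dict as global_steps, and load_steps hands the whole dict (or None) back to the caller instead of a tuple. A minimal round-trip sketch, using a hypothetical SchedulerStub in place of the real LRScheduler subclasses:

import os
import torch

class SchedulerStub:
    # Hypothetical stand-in for an LRScheduler implementing the new hooks.
    def __init__(self):
        self.cur_steps = 0
        self._current_lr = 1e-4

    def get_ckpt_dict(self) -> dict:
        return {'cur_lr': self._current_lr, 'lr_steps': self.cur_steps}

    def restore_ckpt_dict(self, ckpt: dict):
        if ckpt['cur_lr']:
            self._current_lr = ckpt['cur_lr']
        if ckpt['lr_steps']:
            self.cur_steps = ckpt['lr_steps']

steps_checkpoint_name = f"{os.environ.get('LOG_DIR', './')}steps.pt"
scheduler = SchedulerStub()

# Saving: global_steps plus whatever the scheduler reports, in one flat dict.
ckpt = {'global_steps': 120}
ckpt.update(scheduler.get_ckpt_dict())
torch.save(ckpt, steps_checkpoint_name)

# Loading: the caller gets the raw dict back (or None when no file exists)
# and feeds it to the scheduler, rather than unpacking (steps, lr_steps).
if os.path.exists(steps_checkpoint_name):
    restored = torch.load(steps_checkpoint_name, weights_only=True)
    scheduler.restore_ckpt_dict(restored)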
llm_trainer/scheduler.py CHANGED
@@ -15,15 +15,18 @@ class LRScheduler(ABC):
     @abstractmethod
     def cur_lr(self): ...
 
-    @abstractmethod
-    def update_steps(self, steps): ...
-
     @abstractmethod
     def step(self): ...
 
     @abstractmethod
     def can_clip_grad(self): ...
 
+    @abstractmethod
+    def get_ckpt_dict(self) -> dict: ...
+
+    @abstractmethod
+    def restore_ckpt_dict(self, ckpt: dict): ...
+
 
 class WarmupCosineAnnealingLRScheduler(LRScheduler):
     def __init__(
@@ -72,11 +75,6 @@ class WarmupCosineAnnealingLRScheduler(LRScheduler):
     def cur_lr(self):
         return self._current_lr
 
-    def update_steps(self, steps):
-        log(f'update step to {steps}')
-        self._steps = steps
-        self._update_lr()
-
     def step(self):
         self._steps += 1
         self._update_lr()
@@ -122,6 +120,33 @@ class WarmupCosineAnnealingLRScheduler(LRScheduler):
         if self.need_log:
             log(f"step={self.cur_steps},lr={lr}\n", f'{get_log_dir()}lr.txt')
 
+    def get_ckpt_dict(self) -> dict:
+        return {
+            'cur_lr': self._current_lr,
+            'lr_steps': self.cur_steps,
+            'cosine_annealing_base_lr': self._cosine_annealing_base_lr,
+            't_cur': self.T_cur,
+            'cycle': self.cycle,
+        }
+
+    def restore_ckpt_dict(self, ckpt: dict):
+        if ckpt['cur_lr']:
+            self._current_lr = ckpt['cur_lr']
+
+        if ckpt['lr_steps']:
+            self._steps = ckpt['lr_steps']
+
+        if ckpt['cosine_annealing_base_lr']:
+            self._cosine_annealing_base_lr = ckpt['cosine_annealing_base_lr']
+
+        if ckpt['t_cur']:
+            self.T_cur = ckpt['t_cur']
+
+        if ckpt['cycle']:
+            self.cycle = ckpt['cycle']
+
+        self._update_lr()
+
 
 class NoneLRScheduler(LRScheduler):
     def __init__(self, initial_lr):
@@ -135,9 +160,14 @@ class NoneLRScheduler(LRScheduler):
     def cur_lr(self):
         return self._current_lr
 
-    def update_steps(self, steps): ...
-
     def step(self): ...
 
     def can_clip_grad(self):
-        return True
+        return True
+
+    def get_ckpt_dict(self) -> dict:
+        return {'cur_lr': self._current_lr}
+
+    def restore_ckpt_dict(self, ckpt: dict):
+        if ckpt['cur_lr']:
+            self._current_lr = ckpt['cur_lr']
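One consequence of the `if ckpt[...]:` truthiness guards in both restore_ckpt_dict implementations: falsy saved values (0, 0.0) are silently skipped, so a scheduler checkpointed at step 0, or with a warmup learning rate still at 0.0, keeps its in-memory defaults on restore. A standalone illustration of the guard (not library code; the values are made up):

# A checkpoint captured before the first step stores falsy values.
ckpt = {'cur_lr': 0.0, 'lr_steps': 0}

# Mimic the restore guard: `if ckpt['lr_steps']:` is False for 0,
# so the stored value is never applied.
steps = 7  # pretend in-memory state
if ckpt['lr_steps']:
    steps = ckpt['lr_steps']
assert steps == 7  # the saved 0 was ignored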
llm_trainer/trainer.py CHANGED
@@ -92,12 +92,15 @@ class Trainer:
             device=TrainerTools().parallel.device
         )
 
-        last_global_steps, last_lr_steps = load_steps(0, -1)
-        self.last_global_steps = last_global_steps
-        log(f'last_global_steps={last_global_steps}, last_lr_steps={last_lr_steps}')
+        steps_dict = load_steps()
+        if steps_dict:
+            self.last_global_steps = steps_dict['global_steps']
+            if not self.last_global_steps:
+                self.last_global_steps = 0
 
-        if last_lr_steps != -1:
-            self.lr_scheduler.update_steps(last_lr_steps)
+            self.lr_scheduler.restore_ckpt_dict(steps_dict)
+
+            log(f'restore steps_dict = {steps_dict}')
 
         if isinstance(train_config.model_config, VLMConfig):
             self.pixel_values_provider = train_config.pixel_values_provider
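For downstream code that called load_steps directly, the resume pattern changes shape as shown above. A hedged before/after sketch of a call site, with the library pieces stubbed out (the real load_steps and scheduler live in llm_trainer; the stub return value is made up):

from typing import Optional

def load_steps() -> Optional[dict]:
    # Stub standing in for llm_trainer.checkpoint.load_steps (0.5.7):
    # returns the saved steps dict, or None when no checkpoint exists.
    return {'global_steps': 120, 'cur_lr': 5e-5, 'lr_steps': 120}

class SchedulerStub:
    # Stand-in for an LRScheduler with the new restore hook.
    def restore_ckpt_dict(self, ckpt: dict):
        print(f'restoring from {ckpt}')

lr_scheduler = SchedulerStub()

# 0.5.5 style (removed):
#   last_global_steps, last_lr_steps = load_steps(0, -1)
#   if last_lr_steps != -1:
#       lr_scheduler.update_steps(last_lr_steps)

# 0.5.7 style: one dict drives both the trainer and the scheduler.
last_global_steps = 0
steps_dict = load_steps()
if steps_dict:
    last_global_steps = steps_dict['global_steps'] or 0
    lr_scheduler.restore_ckpt_dict(steps_dict)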
project_llm_trainer-0.5.7.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: project_llm_trainer
-Version: 0.5.5
+Version: 0.5.7
 Summary: LLM and VLM trainer
 Author: qibin
 Author-email: qibin0506@gmail.com
project_llm_trainer-0.5.7.dist-info/RECORD CHANGED
@@ -1,5 +1,5 @@
 llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
-llm_trainer/checkpoint.py,sha256=UVjOaDsiSIzRJ5VJZib6iXrdKv2A7K_gtJw3a9wNyoM,4293
+llm_trainer/checkpoint.py,sha256=RoRlIB-Qtvl3MyY3g0FbEBHLpFRLnEMZLpEncXOLToQ,4242
 llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
 llm_trainer/dpo_trainer.py,sha256=1A_4QP2_xqM_YeqdXy-0RaMvEL80gim-pgnPQyHww9U,12052
 llm_trainer/ds_checkpoint.py,sha256=D092fkS1Up4QmpV9YCpqbSzfX_caCAeX-UiOrhOE1I8,1947
@@ -13,21 +13,21 @@ llm_trainer/parallel_ddp.py,sha256=Pob9vUlBZnkL4oP1Re11kFob7nufMSE96pn7m7fuOEM,1
 llm_trainer/parallel_ds.py,sha256=oy8RRxHud3rACWubFlJqqd0pjPEQhKeAPGPQUSdJX2c,1145
 llm_trainer/parallel_none.py,sha256=TG6Pm829Dg-yQu-97O-EHV3FCARBlNcP47KkGFAs16E,676
 llm_trainer/partition_utils.py,sha256=xzv8kwlbKp3dai2pBwX89gN5ymeHk1bGbTkGru5H-UM,5167
-llm_trainer/scheduler.py,sha256=lyC9TFuF_y8EXYq9d-WAqN4CSaq_w9kSKeh_BOo3EpI,4039
+llm_trainer/scheduler.py,sha256=LAI_0VxClsIQkix0bRoduRD4vPfVuIZDhZgTAT_KK8k,4901
 llm_trainer/sft_trainer.py,sha256=gxQA7T1o1QGUsHp2CX1Qb_fO5LppBJuNbc0H4ixCYUA,1783
 llm_trainer/tokenizer.py,sha256=SSpgXtb0e1NtQqRW0gCq09TTZi47umggy-Fh5EMHKJg,6708
 llm_trainer/tools.py,sha256=yF17lp6oOfLe2XJeKDQ1juZcbv-6vFamJSLwEeArduA,2975
 llm_trainer/train_configs.py,sha256=c6bgivkkWRYcPD3NzI5uRItAUhZiIBgKVMuMgVFRnFo,7336
-llm_trainer/trainer.py,sha256=YW59dJWTyQy77cLDGzBHhfinGyfkvmWCkl1SR9hM6a8,26071
+llm_trainer/trainer.py,sha256=sqN5cXsFAH9xe8-px6tAgcUe5nw6iZU5PEjT9mgEusE,26106
 llm_trainer/utils.py,sha256=LWNhyQ0NDEZ9mZtk2Ryvh6EulvHIaUGIflugSpqmeFI,6791
-project_llm_trainer-0.5.5.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
-project_llm_trainer-0.5.5.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
-project_llm_trainer-0.5.5.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
-project_llm_trainer-0.5.5.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
-project_llm_trainer-0.5.5.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
-project_llm_trainer-0.5.5.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
-project_llm_trainer-0.5.5.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
-project_llm_trainer-0.5.5.dist-info/METADATA,sha256=ajxfapuo4Q2xfdJ3kjZoCzs7Q5ynGp6BssXRFOIbF7Y,195
-project_llm_trainer-0.5.5.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
-project_llm_trainer-0.5.5.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
-project_llm_trainer-0.5.5.dist-info/RECORD,,
+project_llm_trainer-0.5.7.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
+project_llm_trainer-0.5.7.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
+project_llm_trainer-0.5.7.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
+project_llm_trainer-0.5.7.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
+project_llm_trainer-0.5.7.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
+project_llm_trainer-0.5.7.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
+project_llm_trainer-0.5.7.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
+project_llm_trainer-0.5.7.dist-info/METADATA,sha256=3yxEJlE4psbIzjpHGKnrpz04BT1n03vUR0xlnqu0-V0,195
+project_llm_trainer-0.5.7.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
+project_llm_trainer-0.5.7.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
+project_llm_trainer-0.5.7.dist-info/RECORD,,