project-llm-trainer 0.5.6__py3-none-any.whl → 0.5.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of project-llm-trainer might be problematic.
- llm_trainer/checkpoint.py +5 -8
- llm_trainer/scheduler.py +41 -11
- llm_trainer/trainer.py +9 -5
- {project_llm_trainer-0.5.6.dist-info → project_llm_trainer-0.5.8.dist-info}/METADATA +1 -1
- {project_llm_trainer-0.5.6.dist-info → project_llm_trainer-0.5.8.dist-info}/RECORD +14 -14
- {project_llm_trainer-0.5.6.data → project_llm_trainer-0.5.8.data}/scripts/calc_intermediate_size +0 -0
- {project_llm_trainer-0.5.6.data → project_llm_trainer-0.5.8.data}/scripts/ddp_train +0 -0
- {project_llm_trainer-0.5.6.data → project_llm_trainer-0.5.8.data}/scripts/ds_train +0 -0
- {project_llm_trainer-0.5.6.data → project_llm_trainer-0.5.8.data}/scripts/plot_loss +0 -0
- {project_llm_trainer-0.5.6.data → project_llm_trainer-0.5.8.data}/scripts/plot_lr +0 -0
- {project_llm_trainer-0.5.6.data → project_llm_trainer-0.5.8.data}/scripts/py_train +0 -0
- {project_llm_trainer-0.5.6.data → project_llm_trainer-0.5.8.data}/scripts/smart_train +0 -0
- {project_llm_trainer-0.5.6.dist-info → project_llm_trainer-0.5.8.dist-info}/WHEEL +0 -0
- {project_llm_trainer-0.5.6.dist-info → project_llm_trainer-0.5.8.dist-info}/top_level.txt +0 -0
llm_trainer/checkpoint.py
CHANGED

@@ -104,17 +104,14 @@ def save_steps(global_steps: int, lr_scheduler: Optional[LRScheduler] = None):
     # For now, only save the checkpoint from the main process
     if TrainerTools().parallel.is_main_process:
         steps_checkpoint_name = f"{os.environ.get('LOG_DIR', './')}steps.pt"
-        ckpt = {'global_steps': global_steps
+        ckpt = {'global_steps': global_steps}
+        ckpt.update(lr_scheduler.get_ckpt_dict())
         torch.save(ckpt, steps_checkpoint_name)


-def load_steps(
-        default_global_steps: int = 0,
-        default_lr_steps: int = 0
-) -> Tuple[Optional[int], Optional[int]]:
+def load_steps() -> Optional[dict]:
     steps_checkpoint_name = f"{os.environ.get('LOG_DIR', './')}steps.pt"
     if os.path.exists(steps_checkpoint_name):
-
-        return ckpt['global_steps'], ckpt['lr_steps']
+        return torch.load(steps_checkpoint_name, weights_only=True)

-    return
+    return None
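Net effect of this change: steps.pt now holds a single dict that bundles global_steps with whatever the scheduler reports via get_ckpt_dict(), and load_steps() returns that dict whole instead of a (global_steps, lr_steps) tuple. A minimal round-trip sketch follows; the helper names save_steps_sketch and load_steps_sketch are invented for illustration, and the main-process guard is omitted:

import os
from typing import Optional

import torch


def save_steps_sketch(global_steps: int, scheduler_state: dict) -> str:
    # One file now carries both the trainer's step counter and the
    # scheduler's state (e.g. cur_lr, lr_steps, t_cur, cycle).
    path = f"{os.environ.get('LOG_DIR', './')}steps.pt"
    ckpt = {'global_steps': global_steps}
    ckpt.update(scheduler_state)
    torch.save(ckpt, path)
    return path


def load_steps_sketch() -> Optional[dict]:
    # New API shape: the whole dict (or None), not a tuple of two ints.
    path = f"{os.environ.get('LOG_DIR', './')}steps.pt"
    if os.path.exists(path):
        return torch.load(path, weights_only=True)
    return None

weights_only=True restricts torch.load to tensors and primitive containers, which suits a dict of ints and floats and avoids unpickling arbitrary objects.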
llm_trainer/scheduler.py
CHANGED

@@ -15,15 +15,18 @@ class LRScheduler(ABC):
     @abstractmethod
     def cur_lr(self): ...

-    @abstractmethod
-    def update_steps(self, steps): ...
-
     @abstractmethod
     def step(self): ...

     @abstractmethod
     def can_clip_grad(self): ...

+    @abstractmethod
+    def get_ckpt_dict(self) -> dict: ...
+
+    @abstractmethod
+    def restore_ckpt_dict(self, ckpt: dict): ...
+

 class WarmupCosineAnnealingLRScheduler(LRScheduler):
     def __init__(

@@ -72,11 +75,6 @@ class WarmupCosineAnnealingLRScheduler(LRScheduler):
     def cur_lr(self):
         return self._current_lr

-    def update_steps(self, steps):
-        log(f'update step to {steps}')
-        self._steps = steps
-        self._update_lr()
-
     def step(self):
         self._steps += 1
         self._update_lr()

@@ -122,6 +120,33 @@ class WarmupCosineAnnealingLRScheduler(LRScheduler):
         if self.need_log:
             log(f"step={self.cur_steps},lr={lr}\n", f'{get_log_dir()}lr.txt')

+    def get_ckpt_dict(self) -> dict:
+        return {
+            'cur_lr': self._current_lr,
+            'lr_steps': self.cur_steps,
+            'cosine_annealing_base_lr': self._cosine_annealing_base_lr,
+            't_cur': self.T_cur,
+            'cycle': self.cycle,
+        }
+
+    def restore_ckpt_dict(self, ckpt: dict):
+        if ckpt['cur_lr']:
+            self._current_lr = ckpt['cur_lr']
+
+        if ckpt['lr_steps']:
+            self._steps = ckpt['lr_steps']
+
+        if ckpt['cosine_annealing_base_lr']:
+            self._cosine_annealing_base_lr = ckpt['cosine_annealing_base_lr']
+
+        if ckpt['t_cur']:
+            self.T_cur = ckpt['t_cur']
+
+        if ckpt['cycle']:
+            self.cycle = ckpt['cycle']
+
+        self._update_lr()
+

 class NoneLRScheduler(LRScheduler):
     def __init__(self, initial_lr):

@@ -135,9 +160,14 @@ class NoneLRScheduler(LRScheduler):
     def cur_lr(self):
         return self._current_lr

-    def update_steps(self, steps): ...
-
     def step(self): ...

     def can_clip_grad(self):
-        return True
+        return True
+
+    def get_ckpt_dict(self) -> dict:
+        return {'cur_lr': self._current_lr}
+
+    def restore_ckpt_dict(self, ckpt: dict):
+        if ckpt['cur_lr']:
+            self._current_lr = ckpt['cur_lr']
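The interface change in one place: update_steps(steps) is gone, replaced by a get_ckpt_dict()/restore_ckpt_dict(ckpt) pair that every LRScheduler subclass must now implement. A hypothetical minimal subclass, assuming the wheel's llm_trainer.scheduler import path and an argument-free base __init__; the class itself is invented for illustration:

from llm_trainer.scheduler import LRScheduler


class ConstantLR(LRScheduler):
    # Invented example: a fixed learning rate, much like NoneLRScheduler.
    def __init__(self, initial_lr: float):
        self._current_lr = initial_lr

    def cur_lr(self):
        return self._current_lr

    def step(self): ...

    def can_clip_grad(self):
        return True

    def get_ckpt_dict(self) -> dict:
        # Everything returned here gets merged into steps.pt by save_steps.
        return {'cur_lr': self._current_lr}

    def restore_ckpt_dict(self, ckpt: dict):
        # Mirrors the package's style: falsy values are left untouched.
        if ckpt.get('cur_lr'):
            self._current_lr = ckpt['cur_lr']

Callers that previously rewound a scheduler with update_steps(n) now pass a dict containing 'lr_steps': n through restore_ckpt_dict instead. One thing to watch in these restore paths, including WarmupCosineAnnealingLRScheduler's: the truthiness guards skip legitimately zero values (lr_steps == 0, t_cur == 0) rather than restoring them; an `is not None` check would be stricter.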
llm_trainer/trainer.py
CHANGED

@@ -59,6 +59,7 @@ class Trainer:
         self.eval_prompts = eval_prompts
         self.eval_image_tags = eval_image_tags
         self.eval_idx = -1
+        self.last_global_steps = 0

         if self.eval_image_tags:
             assert len(self.eval_prompts) == len(self.eval_image_tags)

@@ -92,12 +93,15 @@ class Trainer:
             device=TrainerTools().parallel.device
         )

-
-
-
+        steps_dict = load_steps()
+        if steps_dict:
+            self.last_global_steps = steps_dict['global_steps']
+            if not self.last_global_steps:
+                self.last_global_steps = 0

-
-
+            self.lr_scheduler.restore_ckpt_dict(steps_dict)
+
+            log(f'restore steps_dict = {steps_dict}')

         if isinstance(train_config.model_config, VLMConfig):
             self.pixel_values_provider = train_config.pixel_values_provider
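On the trainer side, resuming is now: load the dict, remember global_steps, and hand the same dict to the scheduler. A condensed sketch of that path using the diff's own load_steps; resume_from_steps is an invented wrapper, not package API:

from llm_trainer.checkpoint import load_steps


def resume_from_steps(scheduler) -> int:
    # Returns the global step to resume from (0 when starting fresh),
    # restoring the scheduler's saved state as a side effect.
    steps_dict = load_steps()
    if not steps_dict:
        return 0
    scheduler.restore_ckpt_dict(steps_dict)
    return steps_dict.get('global_steps') or 0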
{project_llm_trainer-0.5.6.dist-info → project_llm_trainer-0.5.8.dist-info}/RECORD
CHANGED

@@ -1,5 +1,5 @@
 llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
-llm_trainer/checkpoint.py,sha256=
+llm_trainer/checkpoint.py,sha256=RoRlIB-Qtvl3MyY3g0FbEBHLpFRLnEMZLpEncXOLToQ,4242
 llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
 llm_trainer/dpo_trainer.py,sha256=1A_4QP2_xqM_YeqdXy-0RaMvEL80gim-pgnPQyHww9U,12052
 llm_trainer/ds_checkpoint.py,sha256=D092fkS1Up4QmpV9YCpqbSzfX_caCAeX-UiOrhOE1I8,1947

@@ -13,21 +13,21 @@ llm_trainer/parallel_ddp.py,sha256=Pob9vUlBZnkL4oP1Re11kFob7nufMSE96pn7m7fuOEM,1
 llm_trainer/parallel_ds.py,sha256=oy8RRxHud3rACWubFlJqqd0pjPEQhKeAPGPQUSdJX2c,1145
 llm_trainer/parallel_none.py,sha256=TG6Pm829Dg-yQu-97O-EHV3FCARBlNcP47KkGFAs16E,676
 llm_trainer/partition_utils.py,sha256=xzv8kwlbKp3dai2pBwX89gN5ymeHk1bGbTkGru5H-UM,5167
-llm_trainer/scheduler.py,sha256=
+llm_trainer/scheduler.py,sha256=LAI_0VxClsIQkix0bRoduRD4vPfVuIZDhZgTAT_KK8k,4901
 llm_trainer/sft_trainer.py,sha256=gxQA7T1o1QGUsHp2CX1Qb_fO5LppBJuNbc0H4ixCYUA,1783
 llm_trainer/tokenizer.py,sha256=SSpgXtb0e1NtQqRW0gCq09TTZi47umggy-Fh5EMHKJg,6708
 llm_trainer/tools.py,sha256=yF17lp6oOfLe2XJeKDQ1juZcbv-6vFamJSLwEeArduA,2975
 llm_trainer/train_configs.py,sha256=c6bgivkkWRYcPD3NzI5uRItAUhZiIBgKVMuMgVFRnFo,7336
-llm_trainer/trainer.py,sha256=
+llm_trainer/trainer.py,sha256=U26dZc22nByfTZUzKeEiqqYVexBzgw0ep7N0Z2zIcWI,26141
 llm_trainer/utils.py,sha256=LWNhyQ0NDEZ9mZtk2Ryvh6EulvHIaUGIflugSpqmeFI,6791
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
+project_llm_trainer-0.5.8.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
+project_llm_trainer-0.5.8.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
+project_llm_trainer-0.5.8.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
+project_llm_trainer-0.5.8.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
+project_llm_trainer-0.5.8.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
+project_llm_trainer-0.5.8.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
+project_llm_trainer-0.5.8.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
+project_llm_trainer-0.5.8.dist-info/METADATA,sha256=54q4Nl2EWMYwShSqS8cLLqZ0iJVraAvJCGz8QPiVMiE,195
+project_llm_trainer-0.5.8.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
+project_llm_trainer-0.5.8.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
+project_llm_trainer-0.5.8.dist-info/RECORD,,
{project_llm_trainer-0.5.6.data → project_llm_trainer-0.5.8.data}/scripts/calc_intermediate_size
RENAMED
File without changes

{project_llm_trainer-0.5.6.data → project_llm_trainer-0.5.8.data}/scripts/ddp_train
RENAMED
File without changes

{project_llm_trainer-0.5.6.data → project_llm_trainer-0.5.8.data}/scripts/ds_train
RENAMED
File without changes

{project_llm_trainer-0.5.6.data → project_llm_trainer-0.5.8.data}/scripts/plot_loss
RENAMED
File without changes

{project_llm_trainer-0.5.6.data → project_llm_trainer-0.5.8.data}/scripts/plot_lr
RENAMED
File without changes

{project_llm_trainer-0.5.6.data → project_llm_trainer-0.5.8.data}/scripts/py_train
RENAMED
File without changes

{project_llm_trainer-0.5.6.data → project_llm_trainer-0.5.8.data}/scripts/smart_train
RENAMED
File without changes

{project_llm_trainer-0.5.6.dist-info → project_llm_trainer-0.5.8.dist-info}/WHEEL
RENAMED
File without changes

{project_llm_trainer-0.5.6.dist-info → project_llm_trainer-0.5.8.dist-info}/top_level.txt
RENAMED
File without changes