project-llm-trainer 0.5.0__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
This version of project-llm-trainer has been flagged as potentially problematic.
- llm_trainer/checkpoint.py +37 -0
- llm_trainer/dpo_trainer.py +12 -1
- llm_trainer/generate_utils.py +7 -68
- llm_trainer/grpo_trainer.py +13 -1
- llm_trainer/scheduler.py +12 -7
- llm_trainer/tokenizer.py +10 -10
- llm_trainer/train_configs.py +4 -4
- llm_trainer/trainer.py +20 -9
- {project_llm_trainer-0.5.0.dist-info → project_llm_trainer-0.5.2.dist-info}/METADATA +1 -1
- {project_llm_trainer-0.5.0.dist-info → project_llm_trainer-0.5.2.dist-info}/RECORD +19 -19
- {project_llm_trainer-0.5.0.data → project_llm_trainer-0.5.2.data}/scripts/calc_intermediate_size +0 -0
- {project_llm_trainer-0.5.0.data → project_llm_trainer-0.5.2.data}/scripts/ddp_train +0 -0
- {project_llm_trainer-0.5.0.data → project_llm_trainer-0.5.2.data}/scripts/ds_train +0 -0
- {project_llm_trainer-0.5.0.data → project_llm_trainer-0.5.2.data}/scripts/plot_loss +0 -0
- {project_llm_trainer-0.5.0.data → project_llm_trainer-0.5.2.data}/scripts/plot_lr +0 -0
- {project_llm_trainer-0.5.0.data → project_llm_trainer-0.5.2.data}/scripts/py_train +0 -0
- {project_llm_trainer-0.5.0.data → project_llm_trainer-0.5.2.data}/scripts/smart_train +0 -0
- {project_llm_trainer-0.5.0.dist-info → project_llm_trainer-0.5.2.dist-info}/WHEEL +0 -0
- {project_llm_trainer-0.5.0.dist-info → project_llm_trainer-0.5.2.dist-info}/top_level.txt +0 -0
llm_trainer/checkpoint.py
CHANGED
@@ -1,5 +1,6 @@
 import os
 from typing import Optional, Union, Tuple
+import shutil
 import torch
 from torch import nn
 from torch.optim import Optimizer

@@ -34,6 +35,42 @@ def save_checkpoint(
     torch.save(ckpt, checkpoint_name)


+def save_best_checkpoint(
+        current_loss: float,
+        last_best_checkpoint_loss: float,
+        suffix: Optional[str] = None
+) -> bool:
+    need_replace = current_loss <= last_best_checkpoint_loss
+    if need_replace and TrainerTools().parallel.is_main_process:
+        if isinstance(TrainerTools().parallel, DsParallel):
+            checkpoint_name = os.environ.get('DIST_CHECKPOINT_DIR', 'checkpoint')
+            if suffix:
+                checkpoint_name = f"{checkpoint_name}_{suffix}"
+
+            best_checkpoint_name = f'best_{checkpoint_name}'
+            if not os.path.exists(best_checkpoint_name):
+                os.makedirs(best_checkpoint_name)
+
+            if os.path.exists(checkpoint_name):
+                shutil.rmtree(best_checkpoint_name)
+                shutil.copytree(checkpoint_name, best_checkpoint_name)
+        else:
+            checkpoint_name = os.environ.get('CHECKPOINT_NAME', DEFAULT_CHECKPOINT_NAME)
+            if suffix:
+                checkpoint_name = f"{checkpoint_name}_{suffix}"
+
+            best_checkpoint_name = f'best_{checkpoint_name}'
+
+            if os.path.exists(checkpoint_name):
+                if os.path.exists(best_checkpoint_name):
+                    os.remove(best_checkpoint_name)
+
+                shutil.copy2(checkpoint_name, best_checkpoint_name)
+
+    TrainerTools().parallel.wait()
+    return need_replace
+
+
 def load_checkpoint(
     model: nn.Module,
     optimizer: Optional[Optimizer] = None,
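For orientation, a minimal, self-contained sketch of the promote-on-improvement pattern the new helper implements. The names below are illustrative, not the package's API, which routes through TrainerTools and environment variables as shown in the diff:

import os
import shutil

def copy_best(checkpoint_path: str, current_loss: float, best_loss: float) -> bool:
    # Mirror the latest checkpoint to a 'best_' sibling whenever the new
    # loss is at least as good as the previous best (<= comparison, as above).
    improved = current_loss <= best_loss
    if improved and os.path.exists(checkpoint_path):
        best_path = f"best_{checkpoint_path}"
        if os.path.isdir(checkpoint_path):
            # Directory checkpoints (the DeepSpeed branch above) are replaced wholesale.
            shutil.rmtree(best_path, ignore_errors=True)
            shutil.copytree(checkpoint_path, best_path)
        else:
            if os.path.exists(best_path):
                os.remove(best_path)
            shutil.copy2(checkpoint_path, best_path)
    return improved

# Callers keep the running best themselves, as the trainers in this release do:
#     if copy_best('checkpoint', current_loss, best_loss):
#         best_loss = current_loss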
llm_trainer/dpo_trainer.py
CHANGED
@@ -16,6 +16,7 @@ from .partition_utils import sync_model_params

 from .checkpoint import (
     save_checkpoint,
+    save_best_checkpoint,
     save_steps,
 )

@@ -139,6 +140,9 @@ class DPOTrainer(Trainer):
         loss_accumulation = 0.0
         skipping_train = False

+        current_loss: float = 0.0
+        last_best_checkpoint_loss: float = 0.0
+
         aux_loss_coef = self.train_config.loss_config.aux_loss_coef

         for epoch in range(self.train_config.n_epochs):

@@ -243,6 +247,9 @@ class DPOTrainer(Trainer):

             if (batch - last_ckpt_batch) >= self.train_config.eval_batch_interval:
                 save_checkpoint(model=self.train_model, optimizer=self.optimizer)
+                if save_best_checkpoint(current_loss, last_best_checkpoint_loss):
+                    last_best_checkpoint_loss = current_loss
+
                 last_ckpt_batch = batch
                 self._on_batch_end(tag=f'epoch:{epoch}/batch:{batch}')

@@ -252,8 +259,12 @@ class DPOTrainer(Trainer):

         # end epoch
         if not skipping_train:
-            save_checkpoint(model=self.train_model, optimizer=self.optimizer)
             save_steps(global_steps=global_steps, lr_scheduler=self.lr_scheduler)
+
+            save_checkpoint(model=self.train_model, optimizer=self.optimizer)
+            if save_best_checkpoint(current_loss, last_best_checkpoint_loss):
+                last_best_checkpoint_loss = current_loss
+
             TrainerTools().parallel.on_epoch_end(epoch)
             self._on_epoch_end(tag=f'epoch:{epoch}')
llm_trainer/generate_utils.py
CHANGED
@@ -107,8 +107,7 @@ def _generate(
     pixel_values: Optional[torch.Tensor] = None,
     tokens_per_image: int = -1,
     suppress_tokens: Optional[List[int]] = None,
-    device: Union[str, torch.device, int],
-    reasoning_budget: Optional[int] = None
+    device: Union[str, torch.device, int]
 ):
     """
     :param model:

@@ -142,27 +141,6 @@ def _generate(
     kv_cache: Optional[KVCache] = None
     generate_tokens = tokens.clone()

-    reasoning_start = TrainerTools().tokenizer.reasoning_start
-    reasoning_end = TrainerTools().tokenizer.reasoning_end
-
-    # --- state initialization ---
-    in_reasoning_block = False
-    reasoning_step_count = 0
-    # "Cool-down" flag: after forcing reasoning to end, suppress <reasoning> on the next step.
-    suppress_reasoning_start_next = False
-
-    if reasoning_budget is not None:
-        prompt_tokens = tokens[0]
-        start_indices = (prompt_tokens == reasoning_start).nonzero(as_tuple=True)[0]
-        end_indices = (prompt_tokens == reasoning_end).nonzero(as_tuple=True)[0]
-
-        last_start_idx = start_indices[-1].item() if len(start_indices) > 0 else -1
-        last_end_idx = end_indices[-1].item() if len(end_indices) > 0 else -1
-
-        if last_start_idx > last_end_idx:
-            in_reasoning_block = True
-            reasoning_step_count = len(prompt_tokens) - 1 - last_start_idx
-
     with torch.inference_mode():
         for _ in range(max_new_tokens):
             # do we need to truncate??

@@ -182,23 +160,6 @@ def _generate(
             # (batch, vocab_size)
             logits = logits[:, -1, :]

-            # --- reasoning budget logic ---
-            force_end_reasoning_token = False
-            if reasoning_budget is not None:
-                # check whether <reasoning> should be suppressed on this step
-                should_suppress_this_step = suppress_reasoning_start_next
-                suppress_reasoning_start_next = False  # reset the flag immediately
-
-                # changed: check whether the budget has been exceeded
-                if in_reasoning_block and reasoning_step_count >= reasoning_budget:
-                    force_end_reasoning_token = True
-                    # set the flag so <reasoning> is suppressed on the next step
-                    suppress_reasoning_start_next = True
-
-                # if the previous step set the suppression flag, apply it now
-                if should_suppress_this_step:
-                    logits[:, reasoning_start] = -float("inf")
-
             # suppress special-token output
             if suppress_tokens and len(suppress_tokens) != 0:
                 logits = _suppress_warper(logits, suppress_tokens)

@@ -214,10 +175,6 @@ def _generate(
             if p and 0 < p <= 1:
                 logits = _top_p_warper(logits, p)

-            if force_end_reasoning_token:
-                logits[:] = -float("inf")
-                logits[:, reasoning_end] = 0.0
-
             if multinomial:
                 prob = logits.softmax(dim=-1)
                 # returns the index

@@ -226,18 +183,6 @@ def _generate(
                 # returns the index
                 next_token = logits.argmax(dim=-1, keepdim=True)

-            if reasoning_budget is not None:
-                current_token_id = next_token.item()
-                if not in_reasoning_block and current_token_id == reasoning_start:
-                    in_reasoning_block = True
-                    reasoning_step_count = 0
-                elif in_reasoning_block:
-                    if current_token_id == reasoning_end:
-                        in_reasoning_block = False
-                        reasoning_step_count = 0
-                    else:
-                        reasoning_step_count += 1
-
             # token, is_full_result
             yield next_token, False

@@ -266,8 +211,7 @@ def _streaming_generate(
     pixel_values: Optional[torch.Tensor] = None,
     tokens_per_image: int = -1,
     suppress_tokens: Optional[List[int]] = None,
-    device: Union[str, torch.device, int] = None,
-    reasoning_budget: Optional[int] = None
+    device: Union[str, torch.device, int] = None
 ):
     device = TrainerTools().parallel.device if not device else device
     encoded_tokens = TrainerTools().tokenizer.encode(prompt, unsqueeze=True, covert_tensor=True).to(device)

@@ -283,8 +227,7 @@ def _streaming_generate(
         pixel_values=pixel_values,
         tokens_per_image=tokens_per_image,
         suppress_tokens=suppress_tokens,
-        device=device,
-        reasoning_budget=reasoning_budget
+        device=device
     )

     for (token, is_full_result) in generate_text_iterator:

@@ -303,8 +246,7 @@ def streaming_generate(
     pixel_values: Optional[torch.Tensor] = None,
     tokens_per_image: int = -1,
     suppress_tokens: Optional[List[int]] = None,
-    device: Union[str, torch.device, int] = None,
-    reasoning_budget: Optional[int] = None
+    device: Union[str, torch.device, int] = None
 ):
     text_iterator = _streaming_generate(
         model=model,

@@ -317,8 +259,7 @@ def streaming_generate(
         pixel_values=pixel_values,
         tokens_per_image=tokens_per_image,
         suppress_tokens=suppress_tokens,
-        device=device,
-        reasoning_budget=reasoning_budget
+        device=device
    )

     for (token, is_full_result) in text_iterator:

@@ -338,8 +279,7 @@ def generate(
     pixel_values: Optional[torch.Tensor] = None,
     tokens_per_image: int = -1,
     suppress_tokens: Optional[List[int]] = None,
-    device: Union[str, torch.device, int] = None,
-    reasoning_budget: Optional[int] = None
+    device: Union[str, torch.device, int] = None
 ):
     text_iterator = _streaming_generate(
         model=model,

@@ -352,8 +292,7 @@ def generate(
         suppress_tokens=suppress_tokens,
         pixel_values=pixel_values,
         tokens_per_image=tokens_per_image,
-        device=device,
-        reasoning_budget=reasoning_budget
+        device=device
    )

     for (token, is_full_result) in text_iterator:
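With the reasoning-budget branches gone, the sampling path runs straight from the suppress/top-k/top-p warpers to token selection. For reference, a minimal nucleus (top-p) filter with the semantics such a warper conventionally has; this is a sketch of the standard technique, not the package's _top_p_warper:

import torch

def top_p_filter(logits: torch.Tensor, p: float) -> torch.Tensor:
    # Keep the smallest set of highest-probability tokens whose cumulative
    # probability reaches p; mask everything else with -inf.
    sorted_logits, sorted_idx = torch.sort(logits, descending=True, dim=-1)
    cum_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    remove = cum_probs > p
    # Shift right so the first token that crosses the threshold is kept.
    remove[..., 1:] = remove[..., :-1].clone()
    remove[..., 0] = False
    mask = remove.scatter(dim=-1, index=sorted_idx, src=remove)
    return logits.masked_fill(mask, float('-inf'))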
llm_trainer/grpo_trainer.py
CHANGED
@@ -22,6 +22,7 @@ from .partition_utils import (

 from .checkpoint import (
     save_checkpoint,
+    save_best_checkpoint,
     save_steps,
 )

@@ -280,6 +281,10 @@ class GRPOTrainer(Trainer):
     def train(self):
         global_steps = 0
         skipping_train = False
+
+        current_loss: float = 0.0
+        last_best_checkpoint_loss: float = 0.0
+
         aux_loss_coef = self.train_config.loss_config.aux_loss_coef

         for epoch in range(self.train_config.n_epochs):

@@ -361,6 +366,9 @@ class GRPOTrainer(Trainer):

             if (batch - last_ckpt_batch) >= self.train_config.eval_batch_interval:
                 save_checkpoint(model=self.train_model, optimizer=self.optimizer)
+                if save_best_checkpoint(current_loss, last_best_checkpoint_loss):
+                    last_best_checkpoint_loss = current_loss
+
                 last_ckpt_batch = batch
                 self._on_batch_end(tag=f'epoch:{epoch}/batch:{batch}')

@@ -370,8 +378,12 @@ class GRPOTrainer(Trainer):

         # end epoch
         if not skipping_train:
-            save_checkpoint(model=self.train_model, optimizer=self.optimizer)
             save_steps(global_steps=global_steps, lr_scheduler=self.lr_scheduler)
+
+            save_checkpoint(model=self.train_model, optimizer=self.optimizer)
+            if save_best_checkpoint(current_loss, last_best_checkpoint_loss):
+                last_best_checkpoint_loss = current_loss
+
             TrainerTools().parallel.on_epoch_end(epoch)
             self._on_epoch_end(tag=f'epoch:{epoch}')
llm_trainer/scheduler.py
CHANGED
@@ -30,12 +30,12 @@ class WarmupCosineAnnealingLRScheduler(LRScheduler):
         self,
         *,
         optimizer: torch.optim.Optimizer,
+        warmup_iters: int,
         initial_lr: float,
         min_lr: float,
         max_lr: float,
-
-
-        period_mul: int = 1,  # multiplier on the period length
+        cosine_annealing_period: int,  # steps per period
+        cosine_annealing_period_mul: int = 0,  # multiplier on the period length
         need_log: bool = False
     ):
         super().__init__()

@@ -46,8 +46,8 @@ class WarmupCosineAnnealingLRScheduler(LRScheduler):
         self._max_lr = max_lr
         self._warmup_iters = warmup_iters

-        self.
-        self.
+        self._cosine_annealing_period = cosine_annealing_period
+        self._cosine_annealing_period_mul = cosine_annealing_period_mul

         self.T_cur = 0  # steps taken within the current period
         self.cycle = 0  # index of the current period

@@ -85,7 +85,12 @@ class WarmupCosineAnnealingLRScheduler(LRScheduler):
         return self._steps > self._warmup_iters

     def _update_lr(self):
-
+        # If period_mul is 0, the schedule has no restart periods: once past the
+        # total cosine-annealing steps, hold the learning rate at min_lr.
+        if self._cosine_annealing_period_mul == 0 and self._steps >= self._cosine_annealing_period + self._warmup_iters:
+            lr = self._min_lr
+            for param_group in self._optimizer.param_groups:
+                param_group['lr'] = lr
+        elif self._steps <= self._warmup_iters:
             # Warmup: adjust learning rate linearly
             # (max_lr - initial_lr) / warmup_iters
             lr = self._initial_lr + self._steps * self._lr_increment

@@ -97,7 +102,7 @@ class WarmupCosineAnnealingLRScheduler(LRScheduler):

         """Update the learning rate every step"""
         # Compute the maximum number of steps in the current period
-        T_max = self.
+        T_max = self._cosine_annealing_period * (max(self._cosine_annealing_period_mul, 1) ** self.cycle)

         # Update period state
         self.T_cur += 1
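Taken together, the schedule now behaves as follows: linear warmup to max_lr, cosine annealing toward min_lr over cosine_annealing_period steps, and, when cosine_annealing_period_mul is 0, a flat min_lr tail instead of warm restarts. A standalone sketch of that shape, simplified to a single fixed period (the restart case grows T_max by period_mul each cycle, as in the line above):

import math

def lr_at(step: int, *, warmup_iters: int, initial_lr: float,
          max_lr: float, min_lr: float, period: int, period_mul: int = 0) -> float:
    if step <= warmup_iters:
        # Linear warmup from initial_lr to max_lr.
        return initial_lr + step * (max_lr - initial_lr) / warmup_iters
    if period_mul == 0 and step >= warmup_iters + period:
        # No restarts: hold at the floor once the single period is spent.
        return min_lr
    # Cosine annealing within the current period.
    t = (step - warmup_iters) % period
    return min_lr + 0.5 * (max_lr - min_lr) * (1 + math.cos(math.pi * t / period))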
llm_trainer/tokenizer.py
CHANGED
@@ -26,8 +26,8 @@ class Tokenizer:
         self.text_user = '<user>'
         self.text_assistant = '<assistant>'

-        self.
-        self.
+        self.text_think_start = '<think>'
+        self.text_think_end = '</think>'

         self.text_answer_start = '<answer>'
         self.text_answer_end = '</answer>'

@@ -47,8 +47,8 @@ class Tokenizer:
         additional_special_tokens = [
             AddedToken(self.text_user, lstrip=False, rstrip=False),
             AddedToken(self.text_assistant, lstrip=False, rstrip=False),
-            AddedToken(self.
-            AddedToken(self.
+            AddedToken(self.text_think_start, lstrip=False, rstrip=False),
+            AddedToken(self.text_think_end, lstrip=False, rstrip=False),
             AddedToken(self.text_answer_start, lstrip=False, rstrip=False),
             AddedToken(self.text_answer_end, lstrip=False, rstrip=False),
             AddedToken(self.text_system, lstrip=False, rstrip=False),

@@ -69,8 +69,8 @@ class Tokenizer:
         self.user = self.tokenizer.convert_tokens_to_ids(self.text_user)
         self.assistant = self.tokenizer.convert_tokens_to_ids(self.text_assistant)

-        self.
-        self.
+        self.think_start = self.tokenizer.convert_tokens_to_ids(self.text_think_start)
+        self.think_end = self.tokenizer.convert_tokens_to_ids(self.text_think_end)

         self.answer_start = self.tokenizer.convert_tokens_to_ids(self.text_answer_start)
         self.answer_end = self.tokenizer.convert_tokens_to_ids(self.text_answer_end)

@@ -140,9 +140,9 @@ class Tokenizer:
             {"role":"user", "content":"hello?"},
             {"role":"assistant", "content":"hello"},
             {"role":"user", "content":"hello hello?"},
-            {"role":"assistant", "
+            {"role":"assistant", "think":"thinking", "content":"hello hello"},
         ]
-        <system>{system_prompt}</s><user>hello?</s><assistant>hello</s><user>hello hello?</s><assistant><
+        <system>{system_prompt}</s><user>hello?</s><assistant>hello</s><user>hello hello?</s><assistant><think>thinking</think><answer>hello hello</answer></s>
        """

@@ -154,8 +154,8 @@ class Tokenizer:
         if add_answer_tag_for_assistant and role == 'assistant':
             content = f"{self.text_answer_start}{content}{self.text_answer_end}"

-        if '
-            content = f"{self.
+        if 'think' in conversation:
+            content = f"{self.text_think_start}{conversation['think']}{self.text_think_end}{content}"

         chat_template = f"{chat_template}{support_roles[role]}{content}{self.text_end}"
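The docstring above fixes the rendered layout. A minimal sketch of the same template logic, using only the tags visible in this diff (the real Tokenizer also handles the system role, special-token ids, and encoding):

def render_chat(messages, *, end_token: str = '</s>', add_answer_tag: bool = True) -> str:
    roles = {'system': '<system>', 'user': '<user>', 'assistant': '<assistant>'}
    out = ''
    for message in messages:
        content = message['content']
        if add_answer_tag and message['role'] == 'assistant':
            content = f"<answer>{content}</answer>"
        if 'think' in message:
            # The think block is prepended inside the assistant turn.
            content = f"<think>{message['think']}</think>{content}"
        out += f"{roles[message['role']]}{content}{end_token}"
    return out

# render_chat([{"role": "user", "content": "hello hello?"},
#              {"role": "assistant", "think": "thinking", "content": "hello hello"}])
# -> '<user>hello hello?</s><assistant><think>thinking</think><answer>hello hello</answer></s>'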
llm_trainer/train_configs.py
CHANGED
@@ -109,13 +109,13 @@ class DataLoaderConfig:
 @dataclass(kw_only=True)
 class LrConfig:
     enable_lr_scheduler: bool = False
-    initial_lr:
+    initial_lr: float
     weight_decay: float = 0.1
+    warmup_iters: Optional[int] = None
     max_lr: Optional[float] = None
     min_lr: Optional[float] = None
-
-
-    warmup_iters: Optional[int] = None
+    cosine_annealing_period: Optional[int] = None
+    cosine_annealing_period_mul: int = 0


 @dataclass(kw_only=True)
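A hypothetical configuration using the renamed fields; the values are illustrative only:

lr_config = LrConfig(
    enable_lr_scheduler=True,
    initial_lr=1e-5,
    warmup_iters=500,
    max_lr=3e-4,
    min_lr=1e-5,
    cosine_annealing_period=10_000,
    cosine_annealing_period_mul=0,  # 0 = no warm restarts; hold min_lr afterwards
)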
llm_trainer/trainer.py
CHANGED
@@ -30,6 +30,7 @@ from .scheduler import (
 from .checkpoint import (
     load_checkpoint,
     save_checkpoint,
+    save_best_checkpoint,
     load_steps,
     save_steps,
 )

@@ -172,20 +173,20 @@ class Trainer:

     def _init_lr_scheduler(self, initial_lr: float) -> LRScheduler:
         if self.train_config.lr_config.enable_lr_scheduler:
+            warmup_iters = self.train_config.lr_config.warmup_iters
             min_lr = self.train_config.lr_config.min_lr
             max_lr = self.train_config.lr_config.max_lr
-
-
-            period_mul = self.train_config.lr_config.period_mul
+            cosine_annealing_period = self.train_config.lr_config.cosine_annealing_period
+            cosine_annealing_period_mul = self.train_config.lr_config.cosine_annealing_period_mul

             return WarmupCosineAnnealingLRScheduler(
                 optimizer=self.optimizer,
+                warmup_iters=warmup_iters,
                 initial_lr=initial_lr,
                 min_lr=min_lr,
                 max_lr=max_lr,
-
-
-                period_mul=period_mul,
+                cosine_annealing_period=cosine_annealing_period,
+                cosine_annealing_period_mul=cosine_annealing_period_mul,
                 need_log=TrainerTools().parallel.is_main_process
             )

@@ -467,6 +468,9 @@ class Trainer:
         loss_accumulation = 0.0
         skipping_train = False

+        current_loss: float = 0.0
+        last_best_checkpoint_loss: float = 0.0
+
         for epoch in range(self.train_config.n_epochs):
             self.train_model.train()
             file_count = len(self.train_config.file_dataset)

@@ -539,7 +543,7 @@ class Trainer:
                 if TrainerTools().parallel.parallel_train:
                     dist.all_reduce(loss_tensor, dist.ReduceOp.AVG)

-
+                current_loss = loss_tensor.item()

                 # ds mode already integrates gradient clipping
                 if not isinstance(TrainerTools().parallel, DsParallel) and self.lr_scheduler.can_clip_grad():

@@ -553,7 +557,7 @@ class Trainer:
                     epoch_tag=f'epoch: {epoch}',
                     file_tag=f'file: {file_idx + 1}/{file_count}',
                     batch_tag=f'batch: {batch}/{batch_count_per_file}',
-                    loss=
+                    loss=current_loss
                 )
                 # reset to default
                 loss_accumulation = 0.0

@@ -565,6 +569,9 @@ class Trainer:

                 if (batch - last_ckpt_batch) >= self.train_config.eval_batch_interval:
                     save_checkpoint(model=self.train_model, optimizer=self.optimizer)
+                    if save_best_checkpoint(current_loss, last_best_checkpoint_loss):
+                        last_best_checkpoint_loss = current_loss
+
                     last_ckpt_batch = batch
                     self._on_batch_end(tag=f'epoch:{epoch}/batch:{batch}')

@@ -574,8 +581,12 @@ class Trainer:

         # end epoch
         if not skipping_train:
-            save_checkpoint(model=self.train_model, optimizer=self.optimizer)
             save_steps(global_steps=global_steps, lr_scheduler=self.lr_scheduler)
+
+            save_checkpoint(model=self.train_model, optimizer=self.optimizer)
+            if save_best_checkpoint(current_loss, last_best_checkpoint_loss):
+                last_best_checkpoint_loss = current_loss
+
             TrainerTools().parallel.on_epoch_end(epoch)
             self._on_epoch_end(tag=f'epoch:{epoch}')
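The per-step loss bookkeeping the three trainers now share reduces to a few lines. A sketch assuming an initialized torch.distributed process group; the package gates the all_reduce on its own parallel_train flag instead:

import torch
import torch.distributed as dist

def averaged_loss(loss: torch.Tensor) -> float:
    # Detach and average the step loss across ranks before tracking it.
    loss_tensor = loss.detach().clone()
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(loss_tensor, op=dist.ReduceOp.AVG)  # AVG needs a backend that supports it (NCCL does)
    return loss_tensor.item()

The returned value is what the trainers store in current_loss and compare against last_best_checkpoint_loss at each checkpoint interval.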
{project_llm_trainer-0.5.0.dist-info → project_llm_trainer-0.5.2.dist-info}/RECORD
CHANGED

@@ -1,11 +1,11 @@
 llm_trainer/__init__.py,sha256=HWgtTEVeQSnZmEyYQm2K6eFEG4X2QAoigMlB5Z2tcXE,260
-llm_trainer/checkpoint.py,sha256=
+llm_trainer/checkpoint.py,sha256=wC4GdIY2HAnxGHzUND5Yq-J_ynhDPT7A_2sXlWRElBc,4647
 llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
-llm_trainer/dpo_trainer.py,sha256=
+llm_trainer/dpo_trainer.py,sha256=xfYXlLA5TbqPKCUbk5_V79TreEh-dnLMaN72a3-Tdzg,11860
 llm_trainer/ds_checkpoint.py,sha256=wz48HoLBBt8QGO1tXfvJwrXoiGtPG_gjwHfEqARllso,2175
 llm_trainer/eval.py,sha256=fjASCILU3fSPJxo9cP3rIXEEnkc5ZlUyHqXlZtUiHrw,888
-llm_trainer/generate_utils.py,sha256=
-llm_trainer/grpo_trainer.py,sha256=
+llm_trainer/generate_utils.py,sha256=CbJ3mfAD6DkQ0GUHcJQ1AK02m-ocwmd-BPXEpiwvNNQ,14933
+llm_trainer/grpo_trainer.py,sha256=vTNi3n6R4NbwFh_s8LYN1TWEJm8AW2F5NVJlT5MHxKk,15990
 llm_trainer/log.py,sha256=LxqTGRNZUGMTSQCePRpk-rYyxSnSIbT4kOdP8Fbzr0M,462
 llm_trainer/loss.py,sha256=NZCQeUXnLSj__mmDflE8g89KgE0emAJXIab0IERCLno,6023
 llm_trainer/parallel.py,sha256=DQu8GqEFxD99HQ6hKuIxxyKi-05dMO33eMhImYlPuOI,4468

@@ -13,21 +13,21 @@ llm_trainer/parallel_ddp.py,sha256=Pob9vUlBZnkL4oP1Re11kFob7nufMSE96pn7m7fuOEM,1
 llm_trainer/parallel_ds.py,sha256=oy8RRxHud3rACWubFlJqqd0pjPEQhKeAPGPQUSdJX2c,1145
 llm_trainer/parallel_none.py,sha256=TG6Pm829Dg-yQu-97O-EHV3FCARBlNcP47KkGFAs16E,676
 llm_trainer/partition_utils.py,sha256=xzv8kwlbKp3dai2pBwX89gN5ymeHk1bGbTkGru5H-UM,5167
-llm_trainer/scheduler.py,sha256=
+llm_trainer/scheduler.py,sha256=lyC9TFuF_y8EXYq9d-WAqN4CSaq_w9kSKeh_BOo3EpI,4039
 llm_trainer/sft_trainer.py,sha256=gxQA7T1o1QGUsHp2CX1Qb_fO5LppBJuNbc0H4ixCYUA,1783
-llm_trainer/tokenizer.py,sha256=
+llm_trainer/tokenizer.py,sha256=SSpgXtb0e1NtQqRW0gCq09TTZi47umggy-Fh5EMHKJg,6708
 llm_trainer/tools.py,sha256=yF17lp6oOfLe2XJeKDQ1juZcbv-6vFamJSLwEeArduA,2975
-llm_trainer/train_configs.py,sha256=
-llm_trainer/trainer.py,sha256=
+llm_trainer/train_configs.py,sha256=c6bgivkkWRYcPD3NzI5uRItAUhZiIBgKVMuMgVFRnFo,7336
+llm_trainer/trainer.py,sha256=g8YUP0FmBP3MGwewyoyOW35p9CY98rS62pzjnOMiWvE,25875
 llm_trainer/utils.py,sha256=LWNhyQ0NDEZ9mZtk2Ryvh6EulvHIaUGIflugSpqmeFI,6791
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
-project_llm_trainer-0.5.
+project_llm_trainer-0.5.2.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
+project_llm_trainer-0.5.2.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
+project_llm_trainer-0.5.2.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
+project_llm_trainer-0.5.2.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
+project_llm_trainer-0.5.2.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
+project_llm_trainer-0.5.2.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
+project_llm_trainer-0.5.2.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
+project_llm_trainer-0.5.2.dist-info/METADATA,sha256=AP3lS957E984t8klMW_Z6VmrIU7-sBtrzAszA6V-KcQ,195
+project_llm_trainer-0.5.2.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
+project_llm_trainer-0.5.2.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
+project_llm_trainer-0.5.2.dist-info/RECORD,,
{project_llm_trainer-0.5.0.data → project_llm_trainer-0.5.2.data}/scripts/calc_intermediate_size
RENAMED
File without changes
{project_llm_trainer-0.5.0.data → project_llm_trainer-0.5.2.data}/scripts/ddp_train
RENAMED
File without changes
{project_llm_trainer-0.5.0.data → project_llm_trainer-0.5.2.data}/scripts/ds_train
RENAMED
File without changes
{project_llm_trainer-0.5.0.data → project_llm_trainer-0.5.2.data}/scripts/plot_loss
RENAMED
File without changes
{project_llm_trainer-0.5.0.data → project_llm_trainer-0.5.2.data}/scripts/plot_lr
RENAMED
File without changes
{project_llm_trainer-0.5.0.data → project_llm_trainer-0.5.2.data}/scripts/py_train
RENAMED
File without changes
{project_llm_trainer-0.5.0.data → project_llm_trainer-0.5.2.data}/scripts/smart_train
RENAMED
File without changes
{project_llm_trainer-0.5.0.dist-info → project_llm_trainer-0.5.2.dist-info}/WHEEL
RENAMED
File without changes
{project_llm_trainer-0.5.0.dist-info → project_llm_trainer-0.5.2.dist-info}/top_level.txt
RENAMED
File without changes