project-llm-trainer 0.4.7__py3-none-any.whl → 0.4.10__py3-none-any.whl
This diff shows the changes between publicly released package versions as they appear in their respective public registries; it is provided for informational purposes only.
Potentially problematic release.
- llm_trainer/generate_utils.py +66 -5
- llm_trainer/train_configs.py +3 -3
- llm_trainer/trainer.py +7 -3
- {project_llm_trainer-0.4.7.dist-info → project_llm_trainer-0.4.10.dist-info}/METADATA +1 -1
- {project_llm_trainer-0.4.7.dist-info → project_llm_trainer-0.4.10.dist-info}/RECORD +14 -14
- {project_llm_trainer-0.4.7.data → project_llm_trainer-0.4.10.data}/scripts/calc_intermediate_size +0 -0
- {project_llm_trainer-0.4.7.data → project_llm_trainer-0.4.10.data}/scripts/ddp_train +0 -0
- {project_llm_trainer-0.4.7.data → project_llm_trainer-0.4.10.data}/scripts/ds_train +0 -0
- {project_llm_trainer-0.4.7.data → project_llm_trainer-0.4.10.data}/scripts/plot_loss +0 -0
- {project_llm_trainer-0.4.7.data → project_llm_trainer-0.4.10.data}/scripts/plot_lr +0 -0
- {project_llm_trainer-0.4.7.data → project_llm_trainer-0.4.10.data}/scripts/py_train +0 -0
- {project_llm_trainer-0.4.7.data → project_llm_trainer-0.4.10.data}/scripts/smart_train +0 -0
- {project_llm_trainer-0.4.7.dist-info → project_llm_trainer-0.4.10.dist-info}/WHEEL +0 -0
- {project_llm_trainer-0.4.7.dist-info → project_llm_trainer-0.4.10.dist-info}/top_level.txt +0 -0
llm_trainer/generate_utils.py
CHANGED
@@ -1,7 +1,6 @@
 from typing import Union, Optional, List
 from contextlib import nullcontext
 import torch
-import torch.distributed as dist
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from llm_model import VlmModel, KVCache
 from .tools import TrainerTools
@@ -109,7 +108,8 @@ def _generate(
     pixel_values: Optional[torch.Tensor] = None,
     tokens_per_image: int = -1,
     suppress_tokens: Optional[List[int]] = None,
-    device: Union[str, torch.device, int]
+    device: Union[str, torch.device, int],
+    reasoning_budget: Optional[int] = None
 ):
     """
     :param model:
@@ -144,6 +144,27 @@ def _generate(
     kv_cache: Optional[KVCache] = None
     generate_tokens = tokens.clone()
 
+    reasoning_start = TrainerTools().tokenizer.reasoning_start
+    reasoning_end = TrainerTools().tokenizer.reasoning_end
+
+    # --- State initialization ---
+    in_reasoning_block = False
+    reasoning_step_count = 0
+    # "Cooldown" flag: after reasoning is forcibly ended, suppress <reasoning> on the next step.
+    suppress_reasoning_start_next = False
+
+    if reasoning_budget is not None:
+        prompt_tokens = tokens[0]
+        start_indices = (prompt_tokens == reasoning_start).nonzero(as_tuple=True)[0]
+        end_indices = (prompt_tokens == reasoning_end).nonzero(as_tuple=True)[0]
+
+        last_start_idx = start_indices[-1].item() if len(start_indices) > 0 else -1
+        last_end_idx = end_indices[-1].item() if len(end_indices) > 0 else -1
+
+        if last_start_idx > last_end_idx:
+            in_reasoning_block = True
+            reasoning_step_count = len(prompt_tokens) - 1 - last_start_idx
+
     model.eval()
     with torch.inference_mode():
         for _ in range(max_new_tokens):
@@ -163,6 +184,24 @@ def _generate(
 
             # (batch, vocab_size)
             logits = logits[:, -1, :]
+
+            # --- Reasoning-budget logic ---
+            force_end_reasoning_token = False
+            if reasoning_budget is not None:
+                # Check whether <reasoning> needs to be suppressed on this step
+                should_suppress_this_step = suppress_reasoning_start_next
+                suppress_reasoning_start_next = False  # reset the flag immediately
+
+                # Change: check whether the budget has been exceeded
+                if in_reasoning_block and reasoning_step_count >= reasoning_budget:
+                    force_end_reasoning_token = True
+                    # Set the flag so <reasoning> is suppressed on the next step
+                    suppress_reasoning_start_next = True
+
+                # If the suppression flag was set on the previous step, apply it now
+                if should_suppress_this_step:
+                    logits[:, reasoning_start] = -float("inf")
+
             # Suppress special-token output
             if suppress_tokens and len(suppress_tokens) != 0:
                 logits = _suppress_warper(logits, suppress_tokens)
@@ -178,6 +217,10 @@ def _generate(
             if p and 0 < p <= 1:
                 logits = _top_p_warper(logits, p)
 
+            if force_end_reasoning_token:
+                logits[:] = -float("inf")
+                logits[:, reasoning_end] = 0.0
+
             if multinomial:
                 prob = logits.softmax(dim=-1)
                 # return indices
@@ -186,6 +229,18 @@ def _generate(
                 # return indices
                 next_token = logits.argmax(dim=-1, keepdim=True)
 
+            if reasoning_budget is not None:
+                current_token_id = next_token.item()
+                if not in_reasoning_block and current_token_id == reasoning_start:
+                    in_reasoning_block = True
+                    reasoning_step_count = 0
+                elif in_reasoning_block:
+                    if current_token_id == reasoning_end:
+                        in_reasoning_block = False
+                        reasoning_step_count = 0
+                    else:
+                        reasoning_step_count += 1
+
             # token, is_full_result
             yield next_token, False
 
@@ -215,6 +270,7 @@ def _streaming_generate(
     tokens_per_image: int = -1,
     suppress_tokens: Optional[List[int]] = None,
     device: Union[str, torch.device, int] = None,
+    reasoning_budget: Optional[int] = None
 ):
     device = TrainerTools().parallel.device if not device else device
     encoded_tokens = TrainerTools().tokenizer.encode(prompt, unsqueeze=True, covert_tensor=True).to(device)
@@ -230,7 +286,8 @@ def _streaming_generate(
         pixel_values=pixel_values,
         tokens_per_image=tokens_per_image,
         suppress_tokens=suppress_tokens,
-        device=device
+        device=device,
+        reasoning_budget=reasoning_budget
     )
 
     for (token, is_full_result) in generate_text_iterator:
@@ -250,6 +307,7 @@ def streaming_generate(
     tokens_per_image: int = -1,
     suppress_tokens: Optional[List[int]] = None,
     device: Union[str, torch.device, int] = None,
+    reasoning_budget: Optional[int] = None
 ):
     text_iterator = _streaming_generate(
         model=model,
@@ -262,7 +320,8 @@ def streaming_generate(
         pixel_values=pixel_values,
         tokens_per_image=tokens_per_image,
         suppress_tokens=suppress_tokens,
-        device=device
+        device=device,
+        reasoning_budget=reasoning_budget
     )
 
     for (token, is_full_result) in text_iterator:
@@ -283,6 +342,7 @@ def generate(
     tokens_per_image: int = -1,
     suppress_tokens: Optional[List[int]] = None,
     device: Union[str, torch.device, int] = None,
+    reasoning_budget: Optional[int] = None
 ):
     text_iterator = _streaming_generate(
         model=model,
@@ -295,7 +355,8 @@ def generate(
         suppress_tokens=suppress_tokens,
         pixel_values=pixel_values,
         tokens_per_image=tokens_per_image,
-        device=device
+        device=device,
+        reasoning_budget=reasoning_budget
    )
 
     for (token, is_full_result) in text_iterator:
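The hunks above thread a new optional reasoning_budget argument from the public generate/streaming_generate helpers down into _generate. Inside the sampling loop the code tracks whether generation is currently inside a <reasoning> block; once the block reaches the budget, the logits are masked so that only the reasoning-end token can be sampled, and the reasoning-start token is suppressed on the following step so the block is not immediately reopened. Below is a minimal, self-contained sketch of that masking step; the token ids and tensor sizes are illustrative stand-ins, not the package's values.

    import torch

    # Hypothetical special-token ids standing in for
    # TrainerTools().tokenizer.reasoning_start / reasoning_end.
    reasoning_start, reasoning_end = 5, 6
    vocab_size = 32

    # Step t: the budget is exhausted, so everything except the reasoning-end
    # token is masked out before sampling.
    logits = torch.randn(1, vocab_size)          # (batch, vocab_size)
    logits[:] = -float("inf")
    logits[:, reasoning_end] = 0.0
    assert logits.argmax(dim=-1).item() == reasoning_end

    # Step t+1: the "cooldown" flag set at step t blocks an immediate re-open
    # of a new <reasoning> block.
    logits = torch.randn(1, vocab_size)
    logits[:, reasoning_start] = -float("inf")
    assert logits.argmax(dim=-1).item() != reasoning_start
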
llm_trainer/train_configs.py
CHANGED
@@ -22,11 +22,11 @@ class DsActivationCheckpointingConfig:
         self,
         *,
         partition_activations: bool = True,
-        cpu_checkpointing: bool =
+        cpu_checkpointing: bool = False,
         contiguous_memory_optimization: bool = True,
         number_checkpoints: Optional[int] = None,
-        synchronize_checkpoint_boundary: bool =
-        profile: bool =
+        synchronize_checkpoint_boundary: bool = False,
+        profile: bool = False
     ):
         self.partition_activations =partition_activations
         self.cpu_checkpointing = cpu_checkpointing
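For reference, here is a hypothetical construction of DsActivationCheckpointingConfig spelled out with the keyword-only parameters visible in this hunk; with the new defaults shown above it is equivalent to passing no arguments.

    from llm_trainer.train_configs import DsActivationCheckpointingConfig

    # All keyword names and default values below are taken from the hunk above.
    ds_ac_config = DsActivationCheckpointingConfig(
        partition_activations=True,
        cpu_checkpointing=False,
        contiguous_memory_optimization=True,
        number_checkpoints=None,
        synchronize_checkpoint_boundary=False,
        profile=False,
    )
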
llm_trainer/trainer.py
CHANGED
@@ -540,13 +540,17 @@ class Trainer:
                 if gradient_accumulation_steps > 1:
                     loss = loss / gradient_accumulation_steps
 
-                loss_accumulation += loss.detach()
+                loss_accumulation += loss.detach().item()
                 self._backward_loss(loss)
 
                 if need_update_grad:
+                    loss_tensor = torch.tensor(loss_accumulation, device=TrainerTools().parallel.device)
+
                     # todo check all_reduce??
                     if TrainerTools().parallel.parallel_train:
-                        dist.all_reduce(
+                        dist.all_reduce(loss_tensor, dist.ReduceOp.AVG)
+
+                    final_log_loss = loss_tensor.item()
 
                     # gradient clipping is already integrated in ds mode
                     if not isinstance(TrainerTools().parallel, DsParallel) and self.lr_scheduler.can_clip_grad():
@@ -560,7 +564,7 @@ class Trainer:
                         epoch_tag=f'epoch: {epoch}',
                         file_tag=f'file: {file_idx + 1}/{file_count}',
                         batch_tag=f'batch: {batch}/{batch_count_per_file}',
-                        loss=
+                        loss=final_log_loss
                     )
                     # reset to default
                     loss_accumulation = 0.0
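The change above accumulates the loss as a Python float, and only at gradient-update steps wraps it in a tensor, averages it across ranks with all_reduce using ReduceOp.AVG, and logs the reduced value. A standalone sketch of that pattern, assuming an initialized torch.distributed process group; the names are stand-ins for the trainer's own state, not its API.

    import torch
    import torch.distributed as dist

    def average_loss_for_logging(loss_accumulation: float, device: torch.device) -> float:
        # Wrap the accumulated float in a tensor so it can be reduced across ranks.
        loss_tensor = torch.tensor(loss_accumulation, device=device)
        if dist.is_available() and dist.is_initialized():
            # Average (rather than sum) the per-rank losses for logging.
            dist.all_reduce(loss_tensor, op=dist.ReduceOp.AVG)
        return loss_tensor.item()
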
{project_llm_trainer-0.4.7.dist-info → project_llm_trainer-0.4.10.dist-info}/RECORD
CHANGED
@@ -6,7 +6,7 @@ llm_trainer/dpo_trainer.py,sha256=rC_I5ipesSlP3gFK_SG2GB8NbgJAMu4K7KLxkAS-aRY,13
 llm_trainer/ds_checkpoint.py,sha256=x_tjgJR47P8gVwV4qAnTUCGwx7eVq2Epw0vOVV7fkYo,4925
 llm_trainer/eval.py,sha256=NDm8PbXLch7xT81xPYPRCNrcrB_Xj5GDJSCxyVwUOp4,1524
 llm_trainer/fsdp_checkpoint.py,sha256=lqZFzHyWyfzuCq_81kQNtJd2qaiMeY1N5BCEMnrJTBw,3192
-llm_trainer/generate_utils.py,sha256=
+llm_trainer/generate_utils.py,sha256=tSbA_tLqSq5qJGHSOlPv5T3iRDZkbFg5ZvDAgJ_i_SE,17946
 llm_trainer/grpo_trainer.py,sha256=bZPrxhyPQLAnFzWhI7hhA6fpuKVNwj7nOm9k0ku9aK4,15977
 llm_trainer/log.py,sha256=LxqTGRNZUGMTSQCePRpk-rYyxSnSIbT4kOdP8Fbzr0M,462
 llm_trainer/loss.py,sha256=Yv3fsaVuZ5AhnGPJOr5vEMb_tM2urR6mCb4DBbrHHI8,6030
@@ -19,17 +19,17 @@ llm_trainer/scheduler.py,sha256=Xz8HhwoRMjRe41sf_NHhpZfkTlEs0I2MYusvMY6hCVw,3531
 llm_trainer/sft_trainer.py,sha256=gxQA7T1o1QGUsHp2CX1Qb_fO5LppBJuNbc0H4ixCYUA,1783
 llm_trainer/tokenizer.py,sha256=A7TYYUbtPf75kjCvWP7yBui4xZBObMk2aPem62YpwpY,6776
 llm_trainer/tools.py,sha256=O45-20wRmh-nyTfU-U-XtjbKAoe7boEIsUvWT_NaKx4,3041
-llm_trainer/train_configs.py,sha256=
-llm_trainer/trainer.py,sha256=
+llm_trainer/train_configs.py,sha256=gzTXMLUuQexRvqyKIZQ1U6ESa0DELD7hPpYZdrDcyxg,15974
+llm_trainer/trainer.py,sha256=pUtJVRosn54j1hn76CFAptJcAsrDo59H6p8NMkg2zt4,25521
 llm_trainer/utils.py,sha256=-ivhMF0d999va13S1wt2uBvtVw8Nvr3uBzhaUFKL04Q,6826
-project_llm_trainer-0.4.
-project_llm_trainer-0.4.
-project_llm_trainer-0.4.
-project_llm_trainer-0.4.
-project_llm_trainer-0.4.
-project_llm_trainer-0.4.
-project_llm_trainer-0.4.
-project_llm_trainer-0.4.
-project_llm_trainer-0.4.
-project_llm_trainer-0.4.
-project_llm_trainer-0.4.
+project_llm_trainer-0.4.10.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
+project_llm_trainer-0.4.10.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
+project_llm_trainer-0.4.10.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
+project_llm_trainer-0.4.10.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
+project_llm_trainer-0.4.10.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
+project_llm_trainer-0.4.10.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
+project_llm_trainer-0.4.10.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
+project_llm_trainer-0.4.10.dist-info/METADATA,sha256=zrHUkQPm7Zox2CSeYN5HBqedZebXuZAQgZVj0O24U6I,196
+project_llm_trainer-0.4.10.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
+project_llm_trainer-0.4.10.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
+project_llm_trainer-0.4.10.dist-info/RECORD,,
{project_llm_trainer-0.4.7.data → project_llm_trainer-0.4.10.data}/scripts/calc_intermediate_size
RENAMED
File without changes
{project_llm_trainer-0.4.7.data → project_llm_trainer-0.4.10.data}/scripts/ddp_train
RENAMED
File without changes
{project_llm_trainer-0.4.7.data → project_llm_trainer-0.4.10.data}/scripts/ds_train
RENAMED
File without changes
{project_llm_trainer-0.4.7.data → project_llm_trainer-0.4.10.data}/scripts/plot_loss
RENAMED
File without changes
{project_llm_trainer-0.4.7.data → project_llm_trainer-0.4.10.data}/scripts/plot_lr
RENAMED
File without changes
{project_llm_trainer-0.4.7.data → project_llm_trainer-0.4.10.data}/scripts/py_train
RENAMED
File without changes
{project_llm_trainer-0.4.7.data → project_llm_trainer-0.4.10.data}/scripts/smart_train
RENAMED
File without changes
{project_llm_trainer-0.4.7.dist-info → project_llm_trainer-0.4.10.dist-info}/WHEEL
RENAMED
File without changes
{project_llm_trainer-0.4.7.dist-info → project_llm_trainer-0.4.10.dist-info}/top_level.txt
RENAMED
File without changes