project-llm-trainer 0.4.7__py3-none-any.whl → 0.4.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of project-llm-trainer might be problematic.

@@ -1,7 +1,6 @@
  from typing import Union, Optional, List
  from contextlib import nullcontext
  import torch
- import torch.distributed as dist
  from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
  from llm_model import VlmModel, KVCache
  from .tools import TrainerTools
@@ -109,7 +108,8 @@ def _generate(
  pixel_values: Optional[torch.Tensor] = None,
  tokens_per_image: int = -1,
  suppress_tokens: Optional[List[int]] = None,
- device: Union[str, torch.device, int]
+ device: Union[str, torch.device, int],
+ reasoning_budget: Optional[int] = None
  ):
  """
  :param model:
@@ -144,6 +144,27 @@ def _generate(
  kv_cache: Optional[KVCache] = None
  generate_tokens = tokens.clone()
 
+ reasoning_start = TrainerTools().tokenizer.reasoning_start
+ reasoning_end = TrainerTools().tokenizer.reasoning_end
+
+ # --- State initialization ---
+ in_reasoning_block = False
+ reasoning_step_count = 0
+ # "Cooldown" flag: after reasoning is force-ended, suppress generation of <reasoning> on the next step.
+ suppress_reasoning_start_next = False
+
+ if reasoning_budget is not None:
+ prompt_tokens = tokens[0]
+ start_indices = (prompt_tokens == reasoning_start).nonzero(as_tuple=True)[0]
+ end_indices = (prompt_tokens == reasoning_end).nonzero(as_tuple=True)[0]
+
+ last_start_idx = start_indices[-1].item() if len(start_indices) > 0 else -1
+ last_end_idx = end_indices[-1].item() if len(end_indices) > 0 else -1
+
+ if last_start_idx > last_end_idx:
+ in_reasoning_block = True
+ reasoning_step_count = len(prompt_tokens) - 1 - last_start_idx
+
  model.eval()
  with torch.inference_mode():
  for _ in range(max_new_tokens):
@@ -163,6 +184,24 @@ def _generate(
 
  # (batch, vocab_size)
  logits = logits[:, -1, :]
+
+ # --- Reasoning-budget logic ---
+ force_end_reasoning_token = False
+ if reasoning_budget is not None:
+ # Check whether <reasoning> needs to be suppressed on this step
+ should_suppress_this_step = suppress_reasoning_start_next
+ suppress_reasoning_start_next = False # Reset the flag immediately
+
+ # Changed: check whether the budget has been exceeded
+ if in_reasoning_block and reasoning_step_count >= reasoning_budget:
+ force_end_reasoning_token = True
+ # Set the flag so that <reasoning> is suppressed on the next step
+ suppress_reasoning_start_next = True
+
+ # If the suppression flag was set on the previous step, apply the suppression now
+ if should_suppress_this_step:
+ logits[:, reasoning_start] = -float("inf")
+
  # Suppress special-token output
  if suppress_tokens and len(suppress_tokens) != 0:
  logits = _suppress_warper(logits, suppress_tokens)
@@ -178,6 +217,10 @@ def _generate(
  if p and 0 < p <= 1:
  logits = _top_p_warper(logits, p)
 
+ if force_end_reasoning_token:
+ logits[:] = -float("inf")
+ logits[:, reasoning_end] = 0.0
+
  if multinomial:
  prob = logits.softmax(dim=-1)
  # Return the index
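
When the reasoning budget is exhausted, the hunk above collapses the distribution so that only the closing tag can be sampled, and the earlier hunk then suppresses <reasoning> on the following step so the block is not immediately reopened. A minimal standalone sketch of that masking trick (the shapes and token id below are illustrative, not the package's values):

import torch

def force_single_token(logits: torch.Tensor, token_id: int) -> torch.Tensor:
    # Mask every candidate except `token_id`; softmax/argmax can then only pick it.
    forced = torch.full_like(logits, float("-inf"))
    forced[:, token_id] = 0.0
    return forced

logits = torch.randn(1, 8)                        # (batch=1, vocab=8)
print(force_single_token(logits, 5).softmax(-1))  # probability 1.0 at index 5
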
@@ -186,6 +229,18 @@ def _generate(
  # Return the index
  next_token = logits.argmax(dim=-1, keepdim=True)
 
+ if reasoning_budget is not None:
+ current_token_id = next_token.item()
+ if not in_reasoning_block and current_token_id == reasoning_start:
+ in_reasoning_block = True
+ reasoning_step_count = 0
+ elif in_reasoning_block:
+ if current_token_id == reasoning_end:
+ in_reasoning_block = False
+ reasoning_step_count = 0
+ else:
+ reasoning_step_count += 1
+
  # token, is_full_result
  yield next_token, False
 
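
The per-step bookkeeping added above is effectively a small state machine: enter a reasoning block on <reasoning>, count tokens inside it, reset on </reasoning>. A standalone sketch of the same idea over plain token ids (the tracker class and the placeholder ids are illustrative, not part of the llm_trainer API):

from dataclasses import dataclass

REASONING_START_ID = 10  # placeholder ids; the real values come from the tokenizer
REASONING_END_ID = 11

@dataclass
class ReasoningBudgetTracker:
    budget: int
    in_block: bool = False
    steps: int = 0

    def over_budget(self) -> bool:
        return self.in_block and self.steps >= self.budget

    def observe(self, token_id: int) -> None:
        if not self.in_block and token_id == REASONING_START_ID:
            self.in_block, self.steps = True, 0
        elif self.in_block:
            if token_id == REASONING_END_ID:
                self.in_block, self.steps = False, 0
            else:
                self.steps += 1

tracker = ReasoningBudgetTracker(budget=3)
for tok in [REASONING_START_ID, 1, 2, 3, 4, REASONING_END_ID]:
    if tracker.over_budget():
        print("budget hit before token", tok)  # fires for token 4 and the end tag
    tracker.observe(tok)
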
@@ -215,6 +270,7 @@ def _streaming_generate(
  tokens_per_image: int = -1,
  suppress_tokens: Optional[List[int]] = None,
  device: Union[str, torch.device, int] = None,
+ reasoning_budget: Optional[int] = None
  ):
  device = TrainerTools().parallel.device if not device else device
  encoded_tokens = TrainerTools().tokenizer.encode(prompt, unsqueeze=True, covert_tensor=True).to(device)
@@ -230,7 +286,8 @@ def _streaming_generate(
  pixel_values=pixel_values,
  tokens_per_image=tokens_per_image,
  suppress_tokens=suppress_tokens,
- device=device
+ device=device,
+ reasoning_budget=reasoning_budget
  )
 
  for (token, is_full_result) in generate_text_iterator:
@@ -250,6 +307,7 @@ def streaming_generate(
  tokens_per_image: int = -1,
  suppress_tokens: Optional[List[int]] = None,
  device: Union[str, torch.device, int] = None,
+ reasoning_budget: Optional[int] = None
  ):
  text_iterator = _streaming_generate(
  model=model,
@@ -262,7 +320,8 @@ def streaming_generate(
  pixel_values=pixel_values,
  tokens_per_image=tokens_per_image,
  suppress_tokens=suppress_tokens,
- device=device
+ device=device,
+ reasoning_budget=reasoning_budget
  )
 
  for (token, is_full_result) in text_iterator:
@@ -283,6 +342,7 @@ def generate(
  tokens_per_image: int = -1,
  suppress_tokens: Optional[List[int]] = None,
  device: Union[str, torch.device, int] = None,
+ reasoning_budget: Optional[int] = None
  ):
  text_iterator = _streaming_generate(
  model=model,
@@ -295,7 +355,8 @@ def generate(
  suppress_tokens=suppress_tokens,
  pixel_values=pixel_values,
  tokens_per_image=tokens_per_image,
- device=device
+ device=device,
+ reasoning_budget=reasoning_budget
  )
 
  for (token, is_full_result) in text_iterator:
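
With the parameter threaded through generate and streaming_generate, callers can cap reasoning length per call. A hedged usage sketch; apart from model, device and reasoning_budget, which appear in this diff, the remaining argument names are assumptions about the existing signatures rather than something this diff confirms:

def stream_with_budget(model, prompt: str, budget: int = 256):
    # Hypothetical wrapper around the package's streaming helper.
    from llm_trainer.generate_utils import streaming_generate  # module path taken from the RECORD below
    for token, is_full_result in streaming_generate(
        model=model,
        prompt=prompt,            # assumed prompt argument
        device="cuda",
        reasoning_budget=budget,  # cap the reasoning block at `budget` generated tokens
    ):
        if not is_full_result:
            yield token
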
@@ -22,11 +22,11 @@ class DsActivationCheckpointingConfig:
  self,
  *,
  partition_activations: bool = True,
- cpu_checkpointing: bool = True,
+ cpu_checkpointing: bool = False,
  contiguous_memory_optimization: bool = True,
  number_checkpoints: Optional[int] = None,
- synchronize_checkpoint_boundary: bool = True,
- profile: bool = True
+ synchronize_checkpoint_boundary: bool = False,
+ profile: bool = False
  ):
  self.partition_activations =partition_activations
  self.cpu_checkpointing = cpu_checkpointing
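
Note that cpu_checkpointing, synchronize_checkpoint_boundary and profile are now opt-in rather than enabled by default. A hedged sketch of pinning the old 0.4.7 behaviour explicitly, using only the keywords visible in the hunk above (the import path is inferred from the RECORD below):

from llm_trainer.train_configs import DsActivationCheckpointingConfig

# Explicitly re-enable the options that 0.4.10 turns off by default.
ckpt_cfg = DsActivationCheckpointingConfig(
    partition_activations=True,
    cpu_checkpointing=True,
    contiguous_memory_optimization=True,
    number_checkpoints=None,
    synchronize_checkpoint_boundary=True,
    profile=True,
)
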
llm_trainer/trainer.py CHANGED
@@ -540,13 +540,17 @@ class Trainer:
  if gradient_accumulation_steps > 1:
  loss = loss / gradient_accumulation_steps
 
- loss_accumulation += loss.detach()
+ loss_accumulation += loss.detach().item()
  self._backward_loss(loss)
 
  if need_update_grad:
+ loss_tensor = torch.tensor(loss_accumulation, device=TrainerTools().parallel.device)
+
  # todo check all_reduce??
  if TrainerTools().parallel.parallel_train:
- dist.all_reduce(loss_accumulation, dist.ReduceOp.AVG)
+ dist.all_reduce(loss_tensor, dist.ReduceOp.AVG)
+
+ final_log_loss = loss_tensor.item()
 
  # gradient_clipping is already integrated in ds mode
  if not isinstance(TrainerTools().parallel, DsParallel) and self.lr_scheduler.can_clip_grad():
@@ -560,7 +564,7 @@
  epoch_tag=f'epoch: {epoch}',
  file_tag=f'file: {file_idx + 1}/{file_count}',
  batch_tag=f'batch: {batch}/{batch_count_per_file}',
- loss=loss_accumulation.item()
+ loss=final_log_loss
  )
  # reset to default
  loss_accumulation = 0.0
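
The trainer now accumulates the logged loss as a plain Python float and builds a fresh tensor only when a cross-rank average is needed for logging. A minimal sketch of the same pattern, assuming an initialized torch.distributed process group (standalone code, not the package's implementation):

import torch
import torch.distributed as dist

def average_logged_loss(loss_accumulation: float, device: torch.device) -> float:
    # Wrap the accumulated scalar in a tensor only for the collective,
    # then convert back to a float for logging.
    loss_tensor = torch.tensor(loss_accumulation, device=device)
    if dist.is_available() and dist.is_initialized():
        dist.all_reduce(loss_tensor, op=dist.ReduceOp.AVG)
    return loss_tensor.item()
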
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: project_llm_trainer
- Version: 0.4.7
+ Version: 0.4.10
  Summary: LLM and VLM trainer
  Author: qibin
  Author-email: qibin0506@gmail.com
@@ -6,7 +6,7 @@ llm_trainer/dpo_trainer.py,sha256=rC_I5ipesSlP3gFK_SG2GB8NbgJAMu4K7KLxkAS-aRY,13
  llm_trainer/ds_checkpoint.py,sha256=x_tjgJR47P8gVwV4qAnTUCGwx7eVq2Epw0vOVV7fkYo,4925
  llm_trainer/eval.py,sha256=NDm8PbXLch7xT81xPYPRCNrcrB_Xj5GDJSCxyVwUOp4,1524
  llm_trainer/fsdp_checkpoint.py,sha256=lqZFzHyWyfzuCq_81kQNtJd2qaiMeY1N5BCEMnrJTBw,3192
- llm_trainer/generate_utils.py,sha256=BmjpCrus_jvJ3SM2KS1bQNzJWAFnpJ9mI28iBWXZpvo,15206
+ llm_trainer/generate_utils.py,sha256=tSbA_tLqSq5qJGHSOlPv5T3iRDZkbFg5ZvDAgJ_i_SE,17946
  llm_trainer/grpo_trainer.py,sha256=bZPrxhyPQLAnFzWhI7hhA6fpuKVNwj7nOm9k0ku9aK4,15977
  llm_trainer/log.py,sha256=LxqTGRNZUGMTSQCePRpk-rYyxSnSIbT4kOdP8Fbzr0M,462
  llm_trainer/loss.py,sha256=Yv3fsaVuZ5AhnGPJOr5vEMb_tM2urR6mCb4DBbrHHI8,6030
@@ -19,17 +19,17 @@ llm_trainer/scheduler.py,sha256=Xz8HhwoRMjRe41sf_NHhpZfkTlEs0I2MYusvMY6hCVw,3531
  llm_trainer/sft_trainer.py,sha256=gxQA7T1o1QGUsHp2CX1Qb_fO5LppBJuNbc0H4ixCYUA,1783
  llm_trainer/tokenizer.py,sha256=A7TYYUbtPf75kjCvWP7yBui4xZBObMk2aPem62YpwpY,6776
  llm_trainer/tools.py,sha256=O45-20wRmh-nyTfU-U-XtjbKAoe7boEIsUvWT_NaKx4,3041
- llm_trainer/train_configs.py,sha256=arnet3tIzgVnwshod08F1jE7r4I7e-SIgMy55IagPnE,15971
- llm_trainer/trainer.py,sha256=Zy1oesBfsFlDedZ4hn3gcAkTrpi5fr76bFFQikfAkak,25351
+ llm_trainer/train_configs.py,sha256=gzTXMLUuQexRvqyKIZQ1U6ESa0DELD7hPpYZdrDcyxg,15974
+ llm_trainer/trainer.py,sha256=pUtJVRosn54j1hn76CFAptJcAsrDo59H6p8NMkg2zt4,25521
  llm_trainer/utils.py,sha256=-ivhMF0d999va13S1wt2uBvtVw8Nvr3uBzhaUFKL04Q,6826
- project_llm_trainer-0.4.7.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
- project_llm_trainer-0.4.7.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
- project_llm_trainer-0.4.7.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
- project_llm_trainer-0.4.7.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
- project_llm_trainer-0.4.7.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
- project_llm_trainer-0.4.7.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
- project_llm_trainer-0.4.7.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
- project_llm_trainer-0.4.7.dist-info/METADATA,sha256=u4_cQkQaH9QKqG_XcWiXzGHD5rnrzqHjvJWQvgVnkZQ,195
- project_llm_trainer-0.4.7.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
- project_llm_trainer-0.4.7.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
- project_llm_trainer-0.4.7.dist-info/RECORD,,
+ project_llm_trainer-0.4.10.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
+ project_llm_trainer-0.4.10.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
+ project_llm_trainer-0.4.10.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
+ project_llm_trainer-0.4.10.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
+ project_llm_trainer-0.4.10.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
+ project_llm_trainer-0.4.10.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
+ project_llm_trainer-0.4.10.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
+ project_llm_trainer-0.4.10.dist-info/METADATA,sha256=zrHUkQPm7Zox2CSeYN5HBqedZebXuZAQgZVj0O24U6I,196
+ project_llm_trainer-0.4.10.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
+ project_llm_trainer-0.4.10.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
+ project_llm_trainer-0.4.10.dist-info/RECORD,,