project-llm-trainer 0.4.6__py3-none-any.whl → 0.4.9__py3-none-any.whl

This diff shows the content of publicly available package versions as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release.


This version of project-llm-trainer might be problematic.

--- a/llm_trainer/ds_checkpoint.py
+++ b/llm_trainer/ds_checkpoint.py
@@ -69,7 +69,7 @@ def load_ds_checkpoint_for_eval(model: nn.Module):
 
 def _get_ds_full_state_dict_on_rank0(model: DeepSpeedEngine) -> Optional[dict]:
     """
-    Can be called on any rank; only rank0 ends up with a value
+    Must be called on all ranks; only rank0 ends up with a value
     """
 
     if model.zero_optimization_stage() != 3:
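
The docstring correction above is not cosmetic: under ZeRO stage 3 a full state dict is materialized by a collective all-gather, so a rank that skips the call leaves the participating ranks blocked. A minimal sketch of the pattern, using DeepSpeed's public `GatheredParameters` context manager (the helper name and loop are illustrative, not this package's exact code):

```python
import deepspeed
import torch

def full_state_dict_on_rank0(engine):
    """Must run on ALL ranks; returns a dict on rank0 and None elsewhere."""
    state_dict = {}
    for name, param in engine.module.named_parameters():
        # Collective all-gather of the ZeRO-3 partitioned parameter:
        # every rank must enter this context, or the others deadlock.
        with deepspeed.zero.GatheredParameters(param, modifier_rank=None):
            if engine.global_rank == 0:
                state_dict[name] = param.data.to(torch.float32).cpu().clone()
    return state_dict if engine.global_rank == 0 else None
```
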
@@ -99,7 +99,6 @@ def _get_ds_full_state_dict_on_rank0(model: DeepSpeedEngine) -> Optional[dict]:
     # if TrainerTools().parallel.is_main_process:
     #     state_dict_on_rank_0[param_name] = param.data.to(torch.float32).cpu().clone()
     # else:
-    #     print("22222222")
     #     if TrainerTools().parallel.is_main_process:
     #         state_dict_on_rank_0[param_name] = param.data.to(torch.float32).cpu().clone()
     #
--- a/llm_trainer/generate_utils.py
+++ b/llm_trainer/generate_utils.py
@@ -1,7 +1,6 @@
 from typing import Union, Optional, List
 from contextlib import nullcontext
 import torch
-import torch.distributed as dist
 from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
 from llm_model import VlmModel, KVCache
 from .tools import TrainerTools
@@ -109,7 +108,8 @@ def _generate(
         pixel_values: Optional[torch.Tensor] = None,
         tokens_per_image: int = -1,
         suppress_tokens: Optional[List[int]] = None,
-        device: Union[str, torch.device, int]
+        device: Union[str, torch.device, int],
+        reasoning_budget: Optional[int] = None
 ):
     """
     :param model:
@@ -144,6 +144,27 @@ def _generate(
     kv_cache: Optional[KVCache] = None
     generate_tokens = tokens.clone()
 
+    reasoning_start = TrainerTools().tokenizer.reasoning_start
+    reasoning_end = TrainerTools().tokenizer.reasoning_end
+
+    # --- State initialization ---
+    in_reasoning_block = False
+    reasoning_step_count = 0
+    # "Cooldown" flag: after reasoning is forcibly ended, suppress generation of <reasoning> on the next step.
+    suppress_reasoning_start_next = False
+
+    if reasoning_budget is not None:
+        prompt_tokens = tokens[0]
+        start_indices = (prompt_tokens == reasoning_start).nonzero(as_tuple=True)[0]
+        end_indices = (prompt_tokens == reasoning_end).nonzero(as_tuple=True)[0]
+
+        last_start_idx = start_indices[-1].item() if len(start_indices) > 0 else -1
+        last_end_idx = end_indices[-1].item() if len(end_indices) > 0 else -1
+
+        if last_start_idx > last_end_idx:
+            in_reasoning_block = True
+            reasoning_step_count = len(prompt_tokens) - 1 - last_start_idx
+
     model.eval()
     with torch.inference_mode():
         for _ in range(max_new_tokens):
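
The block added above lets generation resume inside an unclosed reasoning block: if the last `reasoning_start` in the prompt has no matching `reasoning_end` after it, the budget counter starts from the tokens already spent. A toy check of the same tag-matching logic (token ids 7 and 8 stand in for the real special tokens):

```python
import torch

# Prompt: ... <r> x x </r> ... <r> x   -- the last block is still open
prompt_tokens = torch.tensor([1, 7, 5, 5, 8, 2, 7, 5])

start_indices = (prompt_tokens == 7).nonzero(as_tuple=True)[0]
end_indices = (prompt_tokens == 8).nonzero(as_tuple=True)[0]
last_start = start_indices[-1].item() if len(start_indices) > 0 else -1
last_end = end_indices[-1].item() if len(end_indices) > 0 else -1

assert last_start > last_end                      # block still open
assert len(prompt_tokens) - 1 - last_start == 1   # one budget unit already spent
```
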
@@ -163,6 +184,24 @@ def _generate(
 
             # (batch, vocab_size)
             logits = logits[:, -1, :]
+
+            # --- Reasoning-budget logic ---
+            force_end_reasoning_token = False
+            if reasoning_budget is not None:
+                # Check whether <reasoning> needs to be suppressed on this step
+                should_suppress_this_step = suppress_reasoning_start_next
+                suppress_reasoning_start_next = False  # reset the flag immediately
+
+                # Changed: check whether the budget has been exceeded
+                if in_reasoning_block and reasoning_step_count >= reasoning_budget:
+                    force_end_reasoning_token = True
+                    # Set the flag so <reasoning> is suppressed on the next step
+                    suppress_reasoning_start_next = True
+
+                # If the suppression flag was set last round, apply it this round
+                if should_suppress_this_step:
+                    logits[:, reasoning_start] = -float("inf")
+
             # Suppress special-token output
             if suppress_tokens and len(suppress_tokens) != 0:
                 logits = _suppress_warper(logits, suppress_tokens)
@@ -175,9 +214,13 @@ def _generate(
             if k and k != 0:
                 logits = _top_k_warper(logits, k, device)
 
-            if p and p < 1:
+            if p and 0 < p <= 1:
                 logits = _top_p_warper(logits, p)
 
+            if force_end_reasoning_token:
+                logits[:] = -float("inf")
+                logits[:, reasoning_end] = 0.0
+
             if multinomial:
                 prob = logits.softmax(dim=-1)
                 # returns indices
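
Together, the two hunks above implement the forced stop: on the step the budget is exhausted, every logit except `reasoning_end` is masked to `-inf` after the samplers have run, and on the following step `reasoning_start` is masked so the block cannot immediately reopen. The mask makes the choice deterministic for both decoding paths, as this small check illustrates (vocabulary size and token id are made up):

```python
import torch

reasoning_end = 8
logits = torch.full((1, 16), -float("inf"))
logits[:, reasoning_end] = 0.0

# softmax puts probability 1.0 on the single finite logit, so greedy
# argmax and multinomial sampling must both emit </reasoning>.
prob = logits.softmax(dim=-1)
assert torch.isclose(prob[0, reasoning_end], torch.tensor(1.0))
assert logits.argmax(dim=-1).item() == reasoning_end
assert torch.multinomial(prob, num_samples=1).item() == reasoning_end
```
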
@@ -186,6 +229,18 @@ def _generate(
                 # returns indices
                 next_token = logits.argmax(dim=-1, keepdim=True)
 
+            if reasoning_budget is not None:
+                current_token_id = next_token.item()
+                if not in_reasoning_block and current_token_id == reasoning_start:
+                    in_reasoning_block = True
+                    reasoning_step_count = 0
+                elif in_reasoning_block:
+                    if current_token_id == reasoning_end:
+                        in_reasoning_block = False
+                        reasoning_step_count = 0
+                    else:
+                        reasoning_step_count += 1
+
             # token, is_full_result
             yield next_token, False
 
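
The bookkeeping above is a small state machine applied to each sampled token. A condensed restatement (illustrative only, with stand-in token ids) plus a short trace:

```python
def update(in_block, count, token_id, start_id=7, end_id=8):
    if not in_block and token_id == start_id:
        return True, 0            # entering a block resets the counter
    if in_block and token_id == end_id:
        return False, 0           # leaving a block clears the state
    if in_block:
        return True, count + 1    # every token inside consumes budget
    return False, count           # tokens outside a block are free

state = (False, 0)
for tok in (7, 5, 5, 8):          # <reasoning> a b </reasoning>
    state = update(*state, tok)
assert state == (False, 0)        # block closed, counter cleared
```
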
@@ -210,11 +265,12 @@ def _streaming_generate(
         max_new_tokens: int,
         temperature: Optional[float] = 1.0,
         k: Optional[int] = None,
-        p: Optional[float] = 1.0,
+        p: Optional[float] = None,
         pixel_values: Optional[torch.Tensor] = None,
         tokens_per_image: int = -1,
         suppress_tokens: Optional[List[int]] = None,
         device: Union[str, torch.device, int] = None,
+        reasoning_budget: Optional[int] = None
 ):
     device = TrainerTools().parallel.device if not device else device
     encoded_tokens = TrainerTools().tokenizer.encode(prompt, unsqueeze=True, covert_tensor=True).to(device)
@@ -230,7 +286,8 @@ def _streaming_generate(
         pixel_values=pixel_values,
         tokens_per_image=tokens_per_image,
         suppress_tokens=suppress_tokens,
-        device=device
+        device=device,
+        reasoning_budget=reasoning_budget
     )
 
     for (token, is_full_result) in generate_text_iterator:
@@ -245,11 +302,12 @@ def streaming_generate(
         max_new_tokens: int,
         temperature: Optional[float] = 1.0,
         k: Optional[int] = None,
-        p: Optional[float] = 1.0,
+        p: Optional[float] = None,
         pixel_values: Optional[torch.Tensor] = None,
         tokens_per_image: int = -1,
         suppress_tokens: Optional[List[int]] = None,
         device: Union[str, torch.device, int] = None,
+        reasoning_budget: Optional[int] = None
 ):
     text_iterator = _streaming_generate(
         model=model,
@@ -262,7 +320,8 @@ def streaming_generate(
         pixel_values=pixel_values,
         tokens_per_image=tokens_per_image,
         suppress_tokens=suppress_tokens,
-        device=device
+        device=device,
+        reasoning_budget=reasoning_budget
     )
 
     for (token, is_full_result) in text_iterator:
@@ -278,11 +337,12 @@ def generate(
         max_new_tokens: int,
         temperature: Optional[float] = 1.0,
         k: Optional[int] = None,
-        p: Optional[float] = 1.0,
+        p: Optional[float] = None,
         pixel_values: Optional[torch.Tensor] = None,
         tokens_per_image: int = -1,
         suppress_tokens: Optional[List[int]] = None,
         device: Union[str, torch.device, int] = None,
+        reasoning_budget: Optional[int] = None
 ):
     text_iterator = _streaming_generate(
         model=model,
@@ -295,7 +355,8 @@ def generate(
         suppress_tokens=suppress_tokens,
         pixel_values=pixel_values,
         tokens_per_image=tokens_per_image,
-        device=device
+        device=device,
+        reasoning_budget=reasoning_budget
     )
 
     for (token, is_full_result) in text_iterator:
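
With the plumbing through `_streaming_generate`, `streaming_generate`, and `generate` complete, callers can cap chain-of-thought length per call. A hypothetical usage sketch: the import path and the `prompt` parameter name are assumed from `_streaming_generate` in this diff, and the model is a placeholder:

```python
from llm_trainer.generate_utils import streaming_generate  # path assumed

def chat(model, user_prompt: str):
    for token, is_full_result in streaming_generate(
        model=model,               # a trained llm_model instance
        prompt=user_prompt,
        max_new_tokens=512,
        temperature=0.8,
        k=20,
        p=0.95,
        reasoning_budget=128,      # force </reasoning> after 128 reasoning tokens
    ):
        ...
```
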
@@ -382,7 +443,7 @@ def batch_generate(
         if k and k != 0:
             logits = _top_k_warper(logits, k, device)
 
-        if p and p < 1:
+        if p and 0 < p <= 1:
             logits = _top_p_warper(logits, p)
 
         prob = logits.softmax(dim=-1)
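
The new guard, applied in both `_generate` and `batch_generate`, fixes two edge cases: `p == 1.0` no longer silently disables nucleus sampling, and a negative `p` no longer slips through. Because `p == 1.0` now activates the warper, this release also changes the `p` defaults from `1.0` to `None`, so top-p stays off unless explicitly requested. A quick truth table of both predicates:

```python
for p in (None, -0.5, 0.0, 0.9, 1.0):
    old = bool(p and p < 1)        # guard in 0.4.6
    new = bool(p and 0 < p <= 1)   # guard in 0.4.9
    print(f"p={p}: old={old}, new={new}")
# p=None: old=False, new=False
# p=-0.5: old=True, new=False     <- bogus value used to pass
# p=0.0:  old=False, new=False
# p=0.9:  old=True, new=True
# p=1.0:  old=False, new=True     <- valid value used to be skipped
```
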
--- a/llm_trainer/train_configs.py
+++ b/llm_trainer/train_configs.py
@@ -22,11 +22,11 @@ class DsActivationCheckpointingConfig:
         self,
         *,
         partition_activations: bool = True,
-        cpu_checkpointing: bool = True,
+        cpu_checkpointing: bool = False,
         contiguous_memory_optimization: bool = True,
         number_checkpoints: Optional[int] = None,
-        synchronize_checkpoint_boundary: bool = True,
-        profile: bool = True
+        synchronize_checkpoint_boundary: bool = False,
+        profile: bool = False
     ):
         self.partition_activations = partition_activations
         self.cpu_checkpointing = cpu_checkpointing
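
For reference, these constructor arguments correspond to DeepSpeed's standard `activation_checkpointing` config keys (how the package assembles the final config is not shown in this diff). The 0.4.9 defaults disable CPU offload of checkpointed activations, boundary synchronization, and profiling:

```python
# Equivalent DeepSpeed config fragment under the new defaults.
ds_config = {
    "activation_checkpointing": {
        "partition_activations": True,
        "cpu_checkpointing": False,
        "contiguous_memory_optimization": True,
        "number_checkpoints": None,
        "synchronize_checkpoint_boundary": False,
        "profile": False,
    }
}
```
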
--- a/project_llm_trainer-0.4.6.dist-info/METADATA
+++ b/project_llm_trainer-0.4.9.dist-info/METADATA
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: project_llm_trainer
-Version: 0.4.6
+Version: 0.4.9
 Summary: LLM and VLM trainer
 Author: qibin
 Author-email: qibin0506@gmail.com
--- a/project_llm_trainer-0.4.6.dist-info/RECORD
+++ b/project_llm_trainer-0.4.9.dist-info/RECORD
@@ -3,10 +3,10 @@ llm_trainer/checkpoint.py,sha256=yZcExxneN2yzvWxRiK-pstMWs35LV7GiOfqcLq-S6vc,574
 llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
 llm_trainer/dcp.py,sha256=PkD97DyrOtoTKn4FJsfL3VqAy4dxufgjdzJEz8-Cnoc,3635
 llm_trainer/dpo_trainer.py,sha256=rC_I5ipesSlP3gFK_SG2GB8NbgJAMu4K7KLxkAS-aRY,13406
-llm_trainer/ds_checkpoint.py,sha256=7o9oxHUqPJNQESuZz83vHUmV83AkUh19mV9nY6qg4PE,4957
+llm_trainer/ds_checkpoint.py,sha256=x_tjgJR47P8gVwV4qAnTUCGwx7eVq2Epw0vOVV7fkYo,4925
 llm_trainer/eval.py,sha256=NDm8PbXLch7xT81xPYPRCNrcrB_Xj5GDJSCxyVwUOp4,1524
 llm_trainer/fsdp_checkpoint.py,sha256=lqZFzHyWyfzuCq_81kQNtJd2qaiMeY1N5BCEMnrJTBw,3192
-llm_trainer/generate_utils.py,sha256=_3TWAt-W8ZIzDZrLHEBR2iiZ3bn4V34WuVvKgCuDtyI,15193
+llm_trainer/generate_utils.py,sha256=tSbA_tLqSq5qJGHSOlPv5T3iRDZkbFg5ZvDAgJ_i_SE,17946
 llm_trainer/grpo_trainer.py,sha256=bZPrxhyPQLAnFzWhI7hhA6fpuKVNwj7nOm9k0ku9aK4,15977
 llm_trainer/log.py,sha256=LxqTGRNZUGMTSQCePRpk-rYyxSnSIbT4kOdP8Fbzr0M,462
 llm_trainer/loss.py,sha256=Yv3fsaVuZ5AhnGPJOr5vEMb_tM2urR6mCb4DBbrHHI8,6030
@@ -19,17 +19,17 @@ llm_trainer/scheduler.py,sha256=Xz8HhwoRMjRe41sf_NHhpZfkTlEs0I2MYusvMY6hCVw,3531
 llm_trainer/sft_trainer.py,sha256=gxQA7T1o1QGUsHp2CX1Qb_fO5LppBJuNbc0H4ixCYUA,1783
 llm_trainer/tokenizer.py,sha256=A7TYYUbtPf75kjCvWP7yBui4xZBObMk2aPem62YpwpY,6776
 llm_trainer/tools.py,sha256=O45-20wRmh-nyTfU-U-XtjbKAoe7boEIsUvWT_NaKx4,3041
-llm_trainer/train_configs.py,sha256=arnet3tIzgVnwshod08F1jE7r4I7e-SIgMy55IagPnE,15971
+llm_trainer/train_configs.py,sha256=gzTXMLUuQexRvqyKIZQ1U6ESa0DELD7hPpYZdrDcyxg,15974
 llm_trainer/trainer.py,sha256=Zy1oesBfsFlDedZ4hn3gcAkTrpi5fr76bFFQikfAkak,25351
 llm_trainer/utils.py,sha256=-ivhMF0d999va13S1wt2uBvtVw8Nvr3uBzhaUFKL04Q,6826
-project_llm_trainer-0.4.6.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
-project_llm_trainer-0.4.6.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
-project_llm_trainer-0.4.6.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
-project_llm_trainer-0.4.6.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
-project_llm_trainer-0.4.6.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
-project_llm_trainer-0.4.6.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
-project_llm_trainer-0.4.6.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
-project_llm_trainer-0.4.6.dist-info/METADATA,sha256=PcLGKG5luK4XZFXbVRjiBrXNq9EzRvWCgZ6K1cxMzlo,195
-project_llm_trainer-0.4.6.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
-project_llm_trainer-0.4.6.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
-project_llm_trainer-0.4.6.dist-info/RECORD,,
+project_llm_trainer-0.4.9.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
+project_llm_trainer-0.4.9.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
+project_llm_trainer-0.4.9.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
+project_llm_trainer-0.4.9.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
+project_llm_trainer-0.4.9.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
+project_llm_trainer-0.4.9.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
+project_llm_trainer-0.4.9.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
+project_llm_trainer-0.4.9.dist-info/METADATA,sha256=Xn5uE_6i2vFPTFf2O2iJT59czHzHxTAUeb6ZDCaH0-A,195
+project_llm_trainer-0.4.9.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
+project_llm_trainer-0.4.9.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
+project_llm_trainer-0.4.9.dist-info/RECORD,,