project-llm-trainer 0.5.0__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of project-llm-trainer might be problematic. Click here for more details.

llm_trainer/generate_utils.py CHANGED
@@ -107,8 +107,7 @@ def _generate(
107
107
  pixel_values: Optional[torch.Tensor] = None,
108
108
  tokens_per_image: int = -1,
109
109
  suppress_tokens: Optional[List[int]] = None,
110
- device: Union[str, torch.device, int],
111
- reasoning_budget: Optional[int] = None
110
+ device: Union[str, torch.device, int]
112
111
  ):
113
112
  """
114
113
  :param model:
@@ -142,27 +141,6 @@ def _generate(
142
141
  kv_cache: Optional[KVCache] = None
143
142
  generate_tokens = tokens.clone()
144
143
 
145
- reasoning_start = TrainerTools().tokenizer.reasoning_start
146
- reasoning_end = TrainerTools().tokenizer.reasoning_end
147
-
148
- # --- 状态初始化 ---
149
- in_reasoning_block = False
150
- reasoning_step_count = 0
151
- # “冷静期”标志位。当强制结束思考后,在下一步抑制<reasoning>的生成。
152
- suppress_reasoning_start_next = False
153
-
154
- if reasoning_budget is not None:
155
- prompt_tokens = tokens[0]
156
- start_indices = (prompt_tokens == reasoning_start).nonzero(as_tuple=True)[0]
157
- end_indices = (prompt_tokens == reasoning_end).nonzero(as_tuple=True)[0]
158
-
159
- last_start_idx = start_indices[-1].item() if len(start_indices) > 0 else -1
160
- last_end_idx = end_indices[-1].item() if len(end_indices) > 0 else -1
161
-
162
- if last_start_idx > last_end_idx:
163
- in_reasoning_block = True
164
- reasoning_step_count = len(prompt_tokens) - 1 - last_start_idx
165
-
166
144
  with torch.inference_mode():
167
145
  for _ in range(max_new_tokens):
168
146
  # 是否需要截取??
@@ -182,23 +160,6 @@ def _generate(
182
160
  # (batch, vocab_size)
183
161
  logits = logits[:, -1, :]
184
162
 
185
- # --- 推理预算逻辑 ---
186
- force_end_reasoning_token = False
187
- if reasoning_budget is not None:
188
- # 检查是否需要在此步抑制 <reasoning>
189
- should_suppress_this_step = suppress_reasoning_start_next
190
- suppress_reasoning_start_next = False # 立即重置标志位
191
-
192
- # 修改: 检查是否超出预算
193
- if in_reasoning_block and reasoning_step_count >= reasoning_budget:
194
- force_end_reasoning_token = True
195
- # 设置标志位,在下一步抑制 <reasoning>
196
- suppress_reasoning_start_next = True
197
-
198
- # 如果上一轮设置了抑制标志,则在此轮执行抑制
199
- if should_suppress_this_step:
200
- logits[:, reasoning_start] = -float("inf")
201
-
202
163
  # 抑制特殊token输出
203
164
  if suppress_tokens and len(suppress_tokens) != 0:
204
165
  logits = _suppress_warper(logits, suppress_tokens)
@@ -214,10 +175,6 @@ def _generate(
214
175
  if p and 0 < p <= 1:
215
176
  logits = _top_p_warper(logits, p)
216
177
 
217
- if force_end_reasoning_token:
218
- logits[:] = -float("inf")
219
- logits[:, reasoning_end] = 0.0
220
-
221
178
  if multinomial:
222
179
  prob = logits.softmax(dim=-1)
223
180
  # 返回下标
@@ -226,18 +183,6 @@ def _generate(
226
183
  # 返回下标
227
184
  next_token = logits.argmax(dim=-1, keepdim=True)
228
185
 
229
- if reasoning_budget is not None:
230
- current_token_id = next_token.item()
231
- if not in_reasoning_block and current_token_id == reasoning_start:
232
- in_reasoning_block = True
233
- reasoning_step_count = 0
234
- elif in_reasoning_block:
235
- if current_token_id == reasoning_end:
236
- in_reasoning_block = False
237
- reasoning_step_count = 0
238
- else:
239
- reasoning_step_count += 1
240
-
241
186
  # token, is_full_result
242
187
  yield next_token, False
243
188
 
@@ -266,8 +211,7 @@ def _streaming_generate(
266
211
  pixel_values: Optional[torch.Tensor] = None,
267
212
  tokens_per_image: int = -1,
268
213
  suppress_tokens: Optional[List[int]] = None,
269
- device: Union[str, torch.device, int] = None,
270
- reasoning_budget: Optional[int] = None
214
+ device: Union[str, torch.device, int] = None
271
215
  ):
272
216
  device = TrainerTools().parallel.device if not device else device
273
217
  encoded_tokens = TrainerTools().tokenizer.encode(prompt, unsqueeze=True, covert_tensor=True).to(device)
@@ -283,8 +227,7 @@ def _streaming_generate(
283
227
  pixel_values=pixel_values,
284
228
  tokens_per_image=tokens_per_image,
285
229
  suppress_tokens=suppress_tokens,
286
- device=device,
287
- reasoning_budget=reasoning_budget
230
+ device=device
288
231
  )
289
232
 
290
233
  for (token, is_full_result) in generate_text_iterator:
@@ -303,8 +246,7 @@ def streaming_generate(
303
246
  pixel_values: Optional[torch.Tensor] = None,
304
247
  tokens_per_image: int = -1,
305
248
  suppress_tokens: Optional[List[int]] = None,
306
- device: Union[str, torch.device, int] = None,
307
- reasoning_budget: Optional[int] = None
249
+ device: Union[str, torch.device, int] = None
308
250
  ):
309
251
  text_iterator = _streaming_generate(
310
252
  model=model,
@@ -317,8 +259,7 @@ def streaming_generate(
317
259
  pixel_values=pixel_values,
318
260
  tokens_per_image=tokens_per_image,
319
261
  suppress_tokens=suppress_tokens,
320
- device=device,
321
- reasoning_budget=reasoning_budget
262
+ device=device
322
263
  )
323
264
 
324
265
  for (token, is_full_result) in text_iterator:
@@ -338,8 +279,7 @@ def generate(
338
279
  pixel_values: Optional[torch.Tensor] = None,
339
280
  tokens_per_image: int = -1,
340
281
  suppress_tokens: Optional[List[int]] = None,
341
- device: Union[str, torch.device, int] = None,
342
- reasoning_budget: Optional[int] = None
282
+ device: Union[str, torch.device, int] = None
343
283
  ):
344
284
  text_iterator = _streaming_generate(
345
285
  model=model,
@@ -352,8 +292,7 @@ def generate(
352
292
  suppress_tokens=suppress_tokens,
353
293
  pixel_values=pixel_values,
354
294
  tokens_per_image=tokens_per_image,
355
- device=device,
356
- reasoning_budget=reasoning_budget
295
+ device=device
357
296
  )
358
297
 
359
298
  for (token, is_full_result) in text_iterator:
llm_trainer/tokenizer.py CHANGED
@@ -26,8 +26,8 @@ class Tokenizer:
26
26
  self.text_user = '<user>'
27
27
  self.text_assistant = '<assistant>'
28
28
 
29
- self.text_reasoning_start = '<reasoning>'
30
- self.text_reasoning_end = '</reasoning>'
29
+ self.text_think_start = '<think>'
30
+ self.text_think_end = '</think>'
31
31
 
32
32
  self.text_answer_start = '<answer>'
33
33
  self.text_answer_end = '</answer>'
@@ -47,8 +47,8 @@ class Tokenizer:
47
47
  additional_special_tokens = [
48
48
  AddedToken(self.text_user, lstrip=False, rstrip=False),
49
49
  AddedToken(self.text_assistant, lstrip=False, rstrip=False),
50
- AddedToken(self.text_reasoning_start, lstrip=False, rstrip=False),
51
- AddedToken(self.text_reasoning_end, lstrip=False, rstrip=False),
50
+ AddedToken(self.text_think_start, lstrip=False, rstrip=False),
51
+ AddedToken(self.text_think_end, lstrip=False, rstrip=False),
52
52
  AddedToken(self.text_answer_start, lstrip=False, rstrip=False),
53
53
  AddedToken(self.text_answer_end, lstrip=False, rstrip=False),
54
54
  AddedToken(self.text_system, lstrip=False, rstrip=False),
@@ -69,8 +69,8 @@ class Tokenizer:
69
69
  self.user = self.tokenizer.convert_tokens_to_ids(self.text_user)
70
70
  self.assistant = self.tokenizer.convert_tokens_to_ids(self.text_assistant)
71
71
 
72
- self.reasoning_start = self.tokenizer.convert_tokens_to_ids(self.text_reasoning_start)
73
- self.reasoning_end = self.tokenizer.convert_tokens_to_ids(self.text_reasoning_end)
72
+ self.think_start = self.tokenizer.convert_tokens_to_ids(self.text_think_start)
73
+ self.think_end = self.tokenizer.convert_tokens_to_ids(self.text_think_end)
74
74
 
75
75
  self.answer_start = self.tokenizer.convert_tokens_to_ids(self.text_answer_start)
76
76
  self.answer_end = self.tokenizer.convert_tokens_to_ids(self.text_answer_end)
@@ -140,9 +140,9 @@ class Tokenizer:
140
140
  {"role":"user", "content":"hello?"},
141
141
  {"role":"assistant", "content":"hello"},
142
142
  {"role":"user", "content":"hello hello?"},
143
- {"role":"assistant", "reasoning":"thinking", "content":"hello hello"},
143
+ {"role":"assistant", "think":"thinking", "content":"hello hello"},
144
144
  ]
145
- <system>{system_prompt}</s><user>hello?</s><assistant>hello</s><user>hello hello?</s><assistant><reasoning>thinking</reasoning><answer>hello hello</answer></s>
145
+ <system>{system_prompt}</s><user>hello?</s><assistant>hello</s><user>hello hello?</s><assistant><think>thinking</think><answer>hello hello</answer></s>
146
146
  """
147
147
 
148
148
  chat_template = ''
@@ -154,8 +154,8 @@ class Tokenizer:
154
154
  if add_answer_tag_for_assistant and role == 'assistant':
155
155
  content = f"{self.text_answer_start}{content}{self.text_answer_end}"
156
156
 
157
- if 'reasoning' in conversation:
158
- content = f"{self.text_reasoning_start}{conversation['reasoning']}{self.text_reasoning_end}{content}"
157
+ if 'think' in conversation:
158
+ content = f"{self.text_think_start}{conversation['think']}{self.text_think_end}{content}"
159
159
 
160
160
  chat_template = f"{chat_template}{support_roles[role]}{content}{self.text_end}"
161
161
 
project_llm_trainer-0.5.1.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: project_llm_trainer
3
- Version: 0.5.0
3
+ Version: 0.5.1
4
4
  Summary: LLM and VLM trainer
5
5
  Author: qibin
6
6
  Author-email: qibin0506@gmail.com
project_llm_trainer-0.5.1.dist-info/RECORD CHANGED
@@ -4,7 +4,7 @@ llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
4
4
  llm_trainer/dpo_trainer.py,sha256=wMREatLt0I8Ajdm_sI2U8Zj-IN1L6txP9s_tH1oI3-s,11431
5
5
  llm_trainer/ds_checkpoint.py,sha256=wz48HoLBBt8QGO1tXfvJwrXoiGtPG_gjwHfEqARllso,2175
6
6
  llm_trainer/eval.py,sha256=fjASCILU3fSPJxo9cP3rIXEEnkc5ZlUyHqXlZtUiHrw,888
7
- llm_trainer/generate_utils.py,sha256=2MoEGEpoTzx7khO3dPcC2akFLyjtbFFpdJtuB_QQ3OY,17708
7
+ llm_trainer/generate_utils.py,sha256=CbJ3mfAD6DkQ0GUHcJQ1AK02m-ocwmd-BPXEpiwvNNQ,14933
8
8
  llm_trainer/grpo_trainer.py,sha256=qiC3KwxYPSB9UKqyk4eSRvORP3b6GM-2ozqI8u3QvI0,15568
9
9
  llm_trainer/log.py,sha256=LxqTGRNZUGMTSQCePRpk-rYyxSnSIbT4kOdP8Fbzr0M,462
10
10
  llm_trainer/loss.py,sha256=NZCQeUXnLSj__mmDflE8g89KgE0emAJXIab0IERCLno,6023
@@ -15,19 +15,19 @@ llm_trainer/parallel_none.py,sha256=TG6Pm829Dg-yQu-97O-EHV3FCARBlNcP47KkGFAs16E,
15
15
  llm_trainer/partition_utils.py,sha256=xzv8kwlbKp3dai2pBwX89gN5ymeHk1bGbTkGru5H-UM,5167
16
16
  llm_trainer/scheduler.py,sha256=Xz8HhwoRMjRe41sf_NHhpZfkTlEs0I2MYusvMY6hCVw,3531
17
17
  llm_trainer/sft_trainer.py,sha256=gxQA7T1o1QGUsHp2CX1Qb_fO5LppBJuNbc0H4ixCYUA,1783
18
- llm_trainer/tokenizer.py,sha256=A7TYYUbtPf75kjCvWP7yBui4xZBObMk2aPem62YpwpY,6776
18
+ llm_trainer/tokenizer.py,sha256=SSpgXtb0e1NtQqRW0gCq09TTZi47umggy-Fh5EMHKJg,6708
19
19
  llm_trainer/tools.py,sha256=yF17lp6oOfLe2XJeKDQ1juZcbv-6vFamJSLwEeArduA,2975
20
20
  llm_trainer/train_configs.py,sha256=m57W71SI5VCCU9aJ_nJkB-3AJrSGiNXmV28rdpuYmLg,7332
21
21
  llm_trainer/trainer.py,sha256=zTJVyY1cAjJdTkyXCOy2ZPVP18SOMLdWhD54Mz2JRe4,25314
22
22
  llm_trainer/utils.py,sha256=LWNhyQ0NDEZ9mZtk2Ryvh6EulvHIaUGIflugSpqmeFI,6791
23
- project_llm_trainer-0.5.0.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
24
- project_llm_trainer-0.5.0.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
25
- project_llm_trainer-0.5.0.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
26
- project_llm_trainer-0.5.0.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
27
- project_llm_trainer-0.5.0.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
28
- project_llm_trainer-0.5.0.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
29
- project_llm_trainer-0.5.0.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
30
- project_llm_trainer-0.5.0.dist-info/METADATA,sha256=YDj-N4VL8O_AqNanwfU6Yt38J97p3RgtUSzmwl0Y-GM,195
31
- project_llm_trainer-0.5.0.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
32
- project_llm_trainer-0.5.0.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
33
- project_llm_trainer-0.5.0.dist-info/RECORD,,
23
+ project_llm_trainer-0.5.1.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
24
+ project_llm_trainer-0.5.1.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
25
+ project_llm_trainer-0.5.1.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
26
+ project_llm_trainer-0.5.1.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
27
+ project_llm_trainer-0.5.1.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
28
+ project_llm_trainer-0.5.1.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
29
+ project_llm_trainer-0.5.1.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
30
+ project_llm_trainer-0.5.1.dist-info/METADATA,sha256=x-Bobn0EH7wyKznJydUeVLK9sdIrkBmDYDbEpyG4pKc,195
31
+ project_llm_trainer-0.5.1.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
32
+ project_llm_trainer-0.5.1.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
33
+ project_llm_trainer-0.5.1.dist-info/RECORD,,