project-llm-trainer 0.5.0__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of project-llm-trainer might be problematic. Click here for more details.
- llm_trainer/generate_utils.py +7 -68
- llm_trainer/tokenizer.py +10 -10
- {project_llm_trainer-0.5.0.dist-info → project_llm_trainer-0.5.1.dist-info}/METADATA +1 -1
- {project_llm_trainer-0.5.0.dist-info → project_llm_trainer-0.5.1.dist-info}/RECORD +13 -13
- {project_llm_trainer-0.5.0.data → project_llm_trainer-0.5.1.data}/scripts/calc_intermediate_size +0 -0
- {project_llm_trainer-0.5.0.data → project_llm_trainer-0.5.1.data}/scripts/ddp_train +0 -0
- {project_llm_trainer-0.5.0.data → project_llm_trainer-0.5.1.data}/scripts/ds_train +0 -0
- {project_llm_trainer-0.5.0.data → project_llm_trainer-0.5.1.data}/scripts/plot_loss +0 -0
- {project_llm_trainer-0.5.0.data → project_llm_trainer-0.5.1.data}/scripts/plot_lr +0 -0
- {project_llm_trainer-0.5.0.data → project_llm_trainer-0.5.1.data}/scripts/py_train +0 -0
- {project_llm_trainer-0.5.0.data → project_llm_trainer-0.5.1.data}/scripts/smart_train +0 -0
- {project_llm_trainer-0.5.0.dist-info → project_llm_trainer-0.5.1.dist-info}/WHEEL +0 -0
- {project_llm_trainer-0.5.0.dist-info → project_llm_trainer-0.5.1.dist-info}/top_level.txt +0 -0
llm_trainer/generate_utils.py
CHANGED
|
@@ -107,8 +107,7 @@ def _generate(
|
|
|
107
107
|
pixel_values: Optional[torch.Tensor] = None,
|
|
108
108
|
tokens_per_image: int = -1,
|
|
109
109
|
suppress_tokens: Optional[List[int]] = None,
|
|
110
|
-
device: Union[str, torch.device, int]
|
|
111
|
-
reasoning_budget: Optional[int] = None
|
|
110
|
+
device: Union[str, torch.device, int]
|
|
112
111
|
):
|
|
113
112
|
"""
|
|
114
113
|
:param model:
|
|
@@ -142,27 +141,6 @@ def _generate(
|
|
|
142
141
|
kv_cache: Optional[KVCache] = None
|
|
143
142
|
generate_tokens = tokens.clone()
|
|
144
143
|
|
|
145
|
-
reasoning_start = TrainerTools().tokenizer.reasoning_start
|
|
146
|
-
reasoning_end = TrainerTools().tokenizer.reasoning_end
|
|
147
|
-
|
|
148
|
-
# --- 状态初始化 ---
|
|
149
|
-
in_reasoning_block = False
|
|
150
|
-
reasoning_step_count = 0
|
|
151
|
-
# “冷静期”标志位。当强制结束思考后,在下一步抑制<reasoning>的生成。
|
|
152
|
-
suppress_reasoning_start_next = False
|
|
153
|
-
|
|
154
|
-
if reasoning_budget is not None:
|
|
155
|
-
prompt_tokens = tokens[0]
|
|
156
|
-
start_indices = (prompt_tokens == reasoning_start).nonzero(as_tuple=True)[0]
|
|
157
|
-
end_indices = (prompt_tokens == reasoning_end).nonzero(as_tuple=True)[0]
|
|
158
|
-
|
|
159
|
-
last_start_idx = start_indices[-1].item() if len(start_indices) > 0 else -1
|
|
160
|
-
last_end_idx = end_indices[-1].item() if len(end_indices) > 0 else -1
|
|
161
|
-
|
|
162
|
-
if last_start_idx > last_end_idx:
|
|
163
|
-
in_reasoning_block = True
|
|
164
|
-
reasoning_step_count = len(prompt_tokens) - 1 - last_start_idx
|
|
165
|
-
|
|
166
144
|
with torch.inference_mode():
|
|
167
145
|
for _ in range(max_new_tokens):
|
|
168
146
|
# 是否需要截取??
|
|
@@ -182,23 +160,6 @@ def _generate(
|
|
|
182
160
|
# (batch, vocab_size)
|
|
183
161
|
logits = logits[:, -1, :]
|
|
184
162
|
|
|
185
|
-
# --- 推理预算逻辑 ---
|
|
186
|
-
force_end_reasoning_token = False
|
|
187
|
-
if reasoning_budget is not None:
|
|
188
|
-
# 检查是否需要在此步抑制 <reasoning>
|
|
189
|
-
should_suppress_this_step = suppress_reasoning_start_next
|
|
190
|
-
suppress_reasoning_start_next = False # 立即重置标志位
|
|
191
|
-
|
|
192
|
-
# 修改: 检查是否超出预算
|
|
193
|
-
if in_reasoning_block and reasoning_step_count >= reasoning_budget:
|
|
194
|
-
force_end_reasoning_token = True
|
|
195
|
-
# 设置标志位,在下一步抑制 <reasoning>
|
|
196
|
-
suppress_reasoning_start_next = True
|
|
197
|
-
|
|
198
|
-
# 如果上一轮设置了抑制标志,则在此轮执行抑制
|
|
199
|
-
if should_suppress_this_step:
|
|
200
|
-
logits[:, reasoning_start] = -float("inf")
|
|
201
|
-
|
|
202
163
|
# 抑制特殊token输出
|
|
203
164
|
if suppress_tokens and len(suppress_tokens) != 0:
|
|
204
165
|
logits = _suppress_warper(logits, suppress_tokens)
|
|
@@ -214,10 +175,6 @@ def _generate(
|
|
|
214
175
|
if p and 0 < p <= 1:
|
|
215
176
|
logits = _top_p_warper(logits, p)
|
|
216
177
|
|
|
217
|
-
if force_end_reasoning_token:
|
|
218
|
-
logits[:] = -float("inf")
|
|
219
|
-
logits[:, reasoning_end] = 0.0
|
|
220
|
-
|
|
221
178
|
if multinomial:
|
|
222
179
|
prob = logits.softmax(dim=-1)
|
|
223
180
|
# 返回下标
|
|
@@ -226,18 +183,6 @@ def _generate(
|
|
|
226
183
|
# 返回下标
|
|
227
184
|
next_token = logits.argmax(dim=-1, keepdim=True)
|
|
228
185
|
|
|
229
|
-
if reasoning_budget is not None:
|
|
230
|
-
current_token_id = next_token.item()
|
|
231
|
-
if not in_reasoning_block and current_token_id == reasoning_start:
|
|
232
|
-
in_reasoning_block = True
|
|
233
|
-
reasoning_step_count = 0
|
|
234
|
-
elif in_reasoning_block:
|
|
235
|
-
if current_token_id == reasoning_end:
|
|
236
|
-
in_reasoning_block = False
|
|
237
|
-
reasoning_step_count = 0
|
|
238
|
-
else:
|
|
239
|
-
reasoning_step_count += 1
|
|
240
|
-
|
|
241
186
|
# token, is_full_result
|
|
242
187
|
yield next_token, False
|
|
243
188
|
|
|
@@ -266,8 +211,7 @@ def _streaming_generate(
|
|
|
266
211
|
pixel_values: Optional[torch.Tensor] = None,
|
|
267
212
|
tokens_per_image: int = -1,
|
|
268
213
|
suppress_tokens: Optional[List[int]] = None,
|
|
269
|
-
device: Union[str, torch.device, int] = None
|
|
270
|
-
reasoning_budget: Optional[int] = None
|
|
214
|
+
device: Union[str, torch.device, int] = None
|
|
271
215
|
):
|
|
272
216
|
device = TrainerTools().parallel.device if not device else device
|
|
273
217
|
encoded_tokens = TrainerTools().tokenizer.encode(prompt, unsqueeze=True, covert_tensor=True).to(device)
|
|
@@ -283,8 +227,7 @@ def _streaming_generate(
|
|
|
283
227
|
pixel_values=pixel_values,
|
|
284
228
|
tokens_per_image=tokens_per_image,
|
|
285
229
|
suppress_tokens=suppress_tokens,
|
|
286
|
-
device=device
|
|
287
|
-
reasoning_budget=reasoning_budget
|
|
230
|
+
device=device
|
|
288
231
|
)
|
|
289
232
|
|
|
290
233
|
for (token, is_full_result) in generate_text_iterator:
|
|
@@ -303,8 +246,7 @@ def streaming_generate(
|
|
|
303
246
|
pixel_values: Optional[torch.Tensor] = None,
|
|
304
247
|
tokens_per_image: int = -1,
|
|
305
248
|
suppress_tokens: Optional[List[int]] = None,
|
|
306
|
-
device: Union[str, torch.device, int] = None
|
|
307
|
-
reasoning_budget: Optional[int] = None
|
|
249
|
+
device: Union[str, torch.device, int] = None
|
|
308
250
|
):
|
|
309
251
|
text_iterator = _streaming_generate(
|
|
310
252
|
model=model,
|
|
@@ -317,8 +259,7 @@ def streaming_generate(
|
|
|
317
259
|
pixel_values=pixel_values,
|
|
318
260
|
tokens_per_image=tokens_per_image,
|
|
319
261
|
suppress_tokens=suppress_tokens,
|
|
320
|
-
device=device
|
|
321
|
-
reasoning_budget=reasoning_budget
|
|
262
|
+
device=device
|
|
322
263
|
)
|
|
323
264
|
|
|
324
265
|
for (token, is_full_result) in text_iterator:
|
|
@@ -338,8 +279,7 @@ def generate(
|
|
|
338
279
|
pixel_values: Optional[torch.Tensor] = None,
|
|
339
280
|
tokens_per_image: int = -1,
|
|
340
281
|
suppress_tokens: Optional[List[int]] = None,
|
|
341
|
-
device: Union[str, torch.device, int] = None
|
|
342
|
-
reasoning_budget: Optional[int] = None
|
|
282
|
+
device: Union[str, torch.device, int] = None
|
|
343
283
|
):
|
|
344
284
|
text_iterator = _streaming_generate(
|
|
345
285
|
model=model,
|
|
@@ -352,8 +292,7 @@ def generate(
|
|
|
352
292
|
suppress_tokens=suppress_tokens,
|
|
353
293
|
pixel_values=pixel_values,
|
|
354
294
|
tokens_per_image=tokens_per_image,
|
|
355
|
-
device=device
|
|
356
|
-
reasoning_budget=reasoning_budget
|
|
295
|
+
device=device
|
|
357
296
|
)
|
|
358
297
|
|
|
359
298
|
for (token, is_full_result) in text_iterator:
|
llm_trainer/tokenizer.py
CHANGED
|
@@ -26,8 +26,8 @@ class Tokenizer:
|
|
|
26
26
|
self.text_user = '<user>'
|
|
27
27
|
self.text_assistant = '<assistant>'
|
|
28
28
|
|
|
29
|
-
self.
|
|
30
|
-
self.
|
|
29
|
+
self.text_think_start = '<think>'
|
|
30
|
+
self.text_think_end = '</think>'
|
|
31
31
|
|
|
32
32
|
self.text_answer_start = '<answer>'
|
|
33
33
|
self.text_answer_end = '</answer>'
|
|
@@ -47,8 +47,8 @@ class Tokenizer:
|
|
|
47
47
|
additional_special_tokens = [
|
|
48
48
|
AddedToken(self.text_user, lstrip=False, rstrip=False),
|
|
49
49
|
AddedToken(self.text_assistant, lstrip=False, rstrip=False),
|
|
50
|
-
AddedToken(self.
|
|
51
|
-
AddedToken(self.
|
|
50
|
+
AddedToken(self.text_think_start, lstrip=False, rstrip=False),
|
|
51
|
+
AddedToken(self.text_think_end, lstrip=False, rstrip=False),
|
|
52
52
|
AddedToken(self.text_answer_start, lstrip=False, rstrip=False),
|
|
53
53
|
AddedToken(self.text_answer_end, lstrip=False, rstrip=False),
|
|
54
54
|
AddedToken(self.text_system, lstrip=False, rstrip=False),
|
|
@@ -69,8 +69,8 @@ class Tokenizer:
|
|
|
69
69
|
self.user = self.tokenizer.convert_tokens_to_ids(self.text_user)
|
|
70
70
|
self.assistant = self.tokenizer.convert_tokens_to_ids(self.text_assistant)
|
|
71
71
|
|
|
72
|
-
self.
|
|
73
|
-
self.
|
|
72
|
+
self.think_start = self.tokenizer.convert_tokens_to_ids(self.text_think_start)
|
|
73
|
+
self.think_end = self.tokenizer.convert_tokens_to_ids(self.text_think_end)
|
|
74
74
|
|
|
75
75
|
self.answer_start = self.tokenizer.convert_tokens_to_ids(self.text_answer_start)
|
|
76
76
|
self.answer_end = self.tokenizer.convert_tokens_to_ids(self.text_answer_end)
|
|
@@ -140,9 +140,9 @@ class Tokenizer:
|
|
|
140
140
|
{"role":"user", "content":"hello?"},
|
|
141
141
|
{"role":"assistant", "content":"hello"},
|
|
142
142
|
{"role":"user", "content":"hello hello?"},
|
|
143
|
-
{"role":"assistant", "
|
|
143
|
+
{"role":"assistant", "think":"thinking", "content":"hello hello"},
|
|
144
144
|
]
|
|
145
|
-
<system>{system_prompt}</s><user>hello?</s><assistant>hello</s><user>hello hello?</s><assistant><
|
|
145
|
+
<system>{system_prompt}</s><user>hello?</s><assistant>hello</s><user>hello hello?</s><assistant><think>thinking</think><answer>hello hello</answer></s>
|
|
146
146
|
"""
|
|
147
147
|
|
|
148
148
|
chat_template = ''
|
|
@@ -154,8 +154,8 @@ class Tokenizer:
|
|
|
154
154
|
if add_answer_tag_for_assistant and role == 'assistant':
|
|
155
155
|
content = f"{self.text_answer_start}{content}{self.text_answer_end}"
|
|
156
156
|
|
|
157
|
-
if '
|
|
158
|
-
content = f"{self.
|
|
157
|
+
if 'think' in conversation:
|
|
158
|
+
content = f"{self.text_think_start}{conversation['think']}{self.text_think_end}{content}"
|
|
159
159
|
|
|
160
160
|
chat_template = f"{chat_template}{support_roles[role]}{content}{self.text_end}"
|
|
161
161
|
|
|
@@ -4,7 +4,7 @@ llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
|
|
|
4
4
|
llm_trainer/dpo_trainer.py,sha256=wMREatLt0I8Ajdm_sI2U8Zj-IN1L6txP9s_tH1oI3-s,11431
|
|
5
5
|
llm_trainer/ds_checkpoint.py,sha256=wz48HoLBBt8QGO1tXfvJwrXoiGtPG_gjwHfEqARllso,2175
|
|
6
6
|
llm_trainer/eval.py,sha256=fjASCILU3fSPJxo9cP3rIXEEnkc5ZlUyHqXlZtUiHrw,888
|
|
7
|
-
llm_trainer/generate_utils.py,sha256=
|
|
7
|
+
llm_trainer/generate_utils.py,sha256=CbJ3mfAD6DkQ0GUHcJQ1AK02m-ocwmd-BPXEpiwvNNQ,14933
|
|
8
8
|
llm_trainer/grpo_trainer.py,sha256=qiC3KwxYPSB9UKqyk4eSRvORP3b6GM-2ozqI8u3QvI0,15568
|
|
9
9
|
llm_trainer/log.py,sha256=LxqTGRNZUGMTSQCePRpk-rYyxSnSIbT4kOdP8Fbzr0M,462
|
|
10
10
|
llm_trainer/loss.py,sha256=NZCQeUXnLSj__mmDflE8g89KgE0emAJXIab0IERCLno,6023
|
|
@@ -15,19 +15,19 @@ llm_trainer/parallel_none.py,sha256=TG6Pm829Dg-yQu-97O-EHV3FCARBlNcP47KkGFAs16E,
|
|
|
15
15
|
llm_trainer/partition_utils.py,sha256=xzv8kwlbKp3dai2pBwX89gN5ymeHk1bGbTkGru5H-UM,5167
|
|
16
16
|
llm_trainer/scheduler.py,sha256=Xz8HhwoRMjRe41sf_NHhpZfkTlEs0I2MYusvMY6hCVw,3531
|
|
17
17
|
llm_trainer/sft_trainer.py,sha256=gxQA7T1o1QGUsHp2CX1Qb_fO5LppBJuNbc0H4ixCYUA,1783
|
|
18
|
-
llm_trainer/tokenizer.py,sha256=
|
|
18
|
+
llm_trainer/tokenizer.py,sha256=SSpgXtb0e1NtQqRW0gCq09TTZi47umggy-Fh5EMHKJg,6708
|
|
19
19
|
llm_trainer/tools.py,sha256=yF17lp6oOfLe2XJeKDQ1juZcbv-6vFamJSLwEeArduA,2975
|
|
20
20
|
llm_trainer/train_configs.py,sha256=m57W71SI5VCCU9aJ_nJkB-3AJrSGiNXmV28rdpuYmLg,7332
|
|
21
21
|
llm_trainer/trainer.py,sha256=zTJVyY1cAjJdTkyXCOy2ZPVP18SOMLdWhD54Mz2JRe4,25314
|
|
22
22
|
llm_trainer/utils.py,sha256=LWNhyQ0NDEZ9mZtk2Ryvh6EulvHIaUGIflugSpqmeFI,6791
|
|
23
|
-
project_llm_trainer-0.5.
|
|
24
|
-
project_llm_trainer-0.5.
|
|
25
|
-
project_llm_trainer-0.5.
|
|
26
|
-
project_llm_trainer-0.5.
|
|
27
|
-
project_llm_trainer-0.5.
|
|
28
|
-
project_llm_trainer-0.5.
|
|
29
|
-
project_llm_trainer-0.5.
|
|
30
|
-
project_llm_trainer-0.5.
|
|
31
|
-
project_llm_trainer-0.5.
|
|
32
|
-
project_llm_trainer-0.5.
|
|
33
|
-
project_llm_trainer-0.5.
|
|
23
|
+
project_llm_trainer-0.5.1.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
|
|
24
|
+
project_llm_trainer-0.5.1.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
|
|
25
|
+
project_llm_trainer-0.5.1.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
|
|
26
|
+
project_llm_trainer-0.5.1.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
|
|
27
|
+
project_llm_trainer-0.5.1.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
|
|
28
|
+
project_llm_trainer-0.5.1.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
|
|
29
|
+
project_llm_trainer-0.5.1.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
|
|
30
|
+
project_llm_trainer-0.5.1.dist-info/METADATA,sha256=x-Bobn0EH7wyKznJydUeVLK9sdIrkBmDYDbEpyG4pKc,195
|
|
31
|
+
project_llm_trainer-0.5.1.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
|
|
32
|
+
project_llm_trainer-0.5.1.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
|
|
33
|
+
project_llm_trainer-0.5.1.dist-info/RECORD,,
|
{project_llm_trainer-0.5.0.data → project_llm_trainer-0.5.1.data}/scripts/calc_intermediate_size
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|