project-llm-trainer 0.4.6__py3-none-any.whl → 0.4.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of project-llm-trainer might be problematic. Click here for more details.
- llm_trainer/ds_checkpoint.py +1 -2
- llm_trainer/generate_utils.py +71 -10
- llm_trainer/train_configs.py +3 -3
- {project_llm_trainer-0.4.6.dist-info → project_llm_trainer-0.4.9.dist-info}/METADATA +1 -1
- {project_llm_trainer-0.4.6.dist-info → project_llm_trainer-0.4.9.dist-info}/RECORD +14 -14
- {project_llm_trainer-0.4.6.data → project_llm_trainer-0.4.9.data}/scripts/calc_intermediate_size +0 -0
- {project_llm_trainer-0.4.6.data → project_llm_trainer-0.4.9.data}/scripts/ddp_train +0 -0
- {project_llm_trainer-0.4.6.data → project_llm_trainer-0.4.9.data}/scripts/ds_train +0 -0
- {project_llm_trainer-0.4.6.data → project_llm_trainer-0.4.9.data}/scripts/plot_loss +0 -0
- {project_llm_trainer-0.4.6.data → project_llm_trainer-0.4.9.data}/scripts/plot_lr +0 -0
- {project_llm_trainer-0.4.6.data → project_llm_trainer-0.4.9.data}/scripts/py_train +0 -0
- {project_llm_trainer-0.4.6.data → project_llm_trainer-0.4.9.data}/scripts/smart_train +0 -0
- {project_llm_trainer-0.4.6.dist-info → project_llm_trainer-0.4.9.dist-info}/WHEEL +0 -0
- {project_llm_trainer-0.4.6.dist-info → project_llm_trainer-0.4.9.dist-info}/top_level.txt +0 -0
llm_trainer/ds_checkpoint.py
CHANGED
|
@@ -69,7 +69,7 @@ def load_ds_checkpoint_for_eval(model: nn.Module):
|
|
|
69
69
|
|
|
70
70
|
def _get_ds_full_state_dict_on_rank0(model: DeepSpeedEngine) -> Optional[dict]:
|
|
71
71
|
"""
|
|
72
|
-
|
|
72
|
+
需要在所有rank上调用,然后只有rank0有值
|
|
73
73
|
"""
|
|
74
74
|
|
|
75
75
|
if model.zero_optimization_stage() != 3:
|
|
@@ -99,7 +99,6 @@ def _get_ds_full_state_dict_on_rank0(model: DeepSpeedEngine) -> Optional[dict]:
|
|
|
99
99
|
# if TrainerTools().parallel.is_main_process:
|
|
100
100
|
# state_dict_on_rank_0[param_name] = param.data.to(torch.float32).cpu().clone()
|
|
101
101
|
# else:
|
|
102
|
-
# print("22222222")
|
|
103
102
|
# if TrainerTools().parallel.is_main_process:
|
|
104
103
|
# state_dict_on_rank_0[param_name] = param.data.to(torch.float32).cpu().clone()
|
|
105
104
|
#
|
llm_trainer/generate_utils.py
CHANGED
|
@@ -1,7 +1,6 @@
|
|
|
1
1
|
from typing import Union, Optional, List
|
|
2
2
|
from contextlib import nullcontext
|
|
3
3
|
import torch
|
|
4
|
-
import torch.distributed as dist
|
|
5
4
|
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
|
6
5
|
from llm_model import VlmModel, KVCache
|
|
7
6
|
from .tools import TrainerTools
|
|
@@ -109,7 +108,8 @@ def _generate(
|
|
|
109
108
|
pixel_values: Optional[torch.Tensor] = None,
|
|
110
109
|
tokens_per_image: int = -1,
|
|
111
110
|
suppress_tokens: Optional[List[int]] = None,
|
|
112
|
-
device: Union[str, torch.device, int]
|
|
111
|
+
device: Union[str, torch.device, int],
|
|
112
|
+
reasoning_budget: Optional[int] = None
|
|
113
113
|
):
|
|
114
114
|
"""
|
|
115
115
|
:param model:
|
|
@@ -144,6 +144,27 @@ def _generate(
|
|
|
144
144
|
kv_cache: Optional[KVCache] = None
|
|
145
145
|
generate_tokens = tokens.clone()
|
|
146
146
|
|
|
147
|
+
reasoning_start = TrainerTools().tokenizer.reasoning_start
|
|
148
|
+
reasoning_end = TrainerTools().tokenizer.reasoning_end
|
|
149
|
+
|
|
150
|
+
# --- 状态初始化 ---
|
|
151
|
+
in_reasoning_block = False
|
|
152
|
+
reasoning_step_count = 0
|
|
153
|
+
# “冷静期”标志位。当强制结束思考后,在下一步抑制<reasoning>的生成。
|
|
154
|
+
suppress_reasoning_start_next = False
|
|
155
|
+
|
|
156
|
+
if reasoning_budget is not None:
|
|
157
|
+
prompt_tokens = tokens[0]
|
|
158
|
+
start_indices = (prompt_tokens == reasoning_start).nonzero(as_tuple=True)[0]
|
|
159
|
+
end_indices = (prompt_tokens == reasoning_end).nonzero(as_tuple=True)[0]
|
|
160
|
+
|
|
161
|
+
last_start_idx = start_indices[-1].item() if len(start_indices) > 0 else -1
|
|
162
|
+
last_end_idx = end_indices[-1].item() if len(end_indices) > 0 else -1
|
|
163
|
+
|
|
164
|
+
if last_start_idx > last_end_idx:
|
|
165
|
+
in_reasoning_block = True
|
|
166
|
+
reasoning_step_count = len(prompt_tokens) - 1 - last_start_idx
|
|
167
|
+
|
|
147
168
|
model.eval()
|
|
148
169
|
with torch.inference_mode():
|
|
149
170
|
for _ in range(max_new_tokens):
|
|
@@ -163,6 +184,24 @@ def _generate(
|
|
|
163
184
|
|
|
164
185
|
# (batch, vocab_size)
|
|
165
186
|
logits = logits[:, -1, :]
|
|
187
|
+
|
|
188
|
+
# --- 推理预算逻辑 ---
|
|
189
|
+
force_end_reasoning_token = False
|
|
190
|
+
if reasoning_budget is not None:
|
|
191
|
+
# 检查是否需要在此步抑制 <reasoning>
|
|
192
|
+
should_suppress_this_step = suppress_reasoning_start_next
|
|
193
|
+
suppress_reasoning_start_next = False # 立即重置标志位
|
|
194
|
+
|
|
195
|
+
# 修改: 检查是否超出预算
|
|
196
|
+
if in_reasoning_block and reasoning_step_count >= reasoning_budget:
|
|
197
|
+
force_end_reasoning_token = True
|
|
198
|
+
# 设置标志位,在下一步抑制 <reasoning>
|
|
199
|
+
suppress_reasoning_start_next = True
|
|
200
|
+
|
|
201
|
+
# 如果上一轮设置了抑制标志,则在此轮执行抑制
|
|
202
|
+
if should_suppress_this_step:
|
|
203
|
+
logits[:, reasoning_start] = -float("inf")
|
|
204
|
+
|
|
166
205
|
# 抑制特殊token输出
|
|
167
206
|
if suppress_tokens and len(suppress_tokens) != 0:
|
|
168
207
|
logits = _suppress_warper(logits, suppress_tokens)
|
|
@@ -175,9 +214,13 @@ def _generate(
|
|
|
175
214
|
if k and k != 0:
|
|
176
215
|
logits = _top_k_warper(logits, k, device)
|
|
177
216
|
|
|
178
|
-
if p and
|
|
217
|
+
if p and 0 < p <= 1:
|
|
179
218
|
logits = _top_p_warper(logits, p)
|
|
180
219
|
|
|
220
|
+
if force_end_reasoning_token:
|
|
221
|
+
logits[:] = -float("inf")
|
|
222
|
+
logits[:, reasoning_end] = 0.0
|
|
223
|
+
|
|
181
224
|
if multinomial:
|
|
182
225
|
prob = logits.softmax(dim=-1)
|
|
183
226
|
# 返回下标
|
|
@@ -186,6 +229,18 @@ def _generate(
|
|
|
186
229
|
# 返回下标
|
|
187
230
|
next_token = logits.argmax(dim=-1, keepdim=True)
|
|
188
231
|
|
|
232
|
+
if reasoning_budget is not None:
|
|
233
|
+
current_token_id = next_token.item()
|
|
234
|
+
if not in_reasoning_block and current_token_id == reasoning_start:
|
|
235
|
+
in_reasoning_block = True
|
|
236
|
+
reasoning_step_count = 0
|
|
237
|
+
elif in_reasoning_block:
|
|
238
|
+
if current_token_id == reasoning_end:
|
|
239
|
+
in_reasoning_block = False
|
|
240
|
+
reasoning_step_count = 0
|
|
241
|
+
else:
|
|
242
|
+
reasoning_step_count += 1
|
|
243
|
+
|
|
189
244
|
# token, is_full_result
|
|
190
245
|
yield next_token, False
|
|
191
246
|
|
|
@@ -210,11 +265,12 @@ def _streaming_generate(
|
|
|
210
265
|
max_new_tokens: int,
|
|
211
266
|
temperature: Optional[float] = 1.0,
|
|
212
267
|
k: Optional[int] = None,
|
|
213
|
-
p: Optional[float] =
|
|
268
|
+
p: Optional[float] = None,
|
|
214
269
|
pixel_values: Optional[torch.Tensor] = None,
|
|
215
270
|
tokens_per_image: int = -1,
|
|
216
271
|
suppress_tokens: Optional[List[int]] = None,
|
|
217
272
|
device: Union[str, torch.device, int] = None,
|
|
273
|
+
reasoning_budget: Optional[int] = None
|
|
218
274
|
):
|
|
219
275
|
device = TrainerTools().parallel.device if not device else device
|
|
220
276
|
encoded_tokens = TrainerTools().tokenizer.encode(prompt, unsqueeze=True, covert_tensor=True).to(device)
|
|
@@ -230,7 +286,8 @@ def _streaming_generate(
|
|
|
230
286
|
pixel_values=pixel_values,
|
|
231
287
|
tokens_per_image=tokens_per_image,
|
|
232
288
|
suppress_tokens=suppress_tokens,
|
|
233
|
-
device=device
|
|
289
|
+
device=device,
|
|
290
|
+
reasoning_budget=reasoning_budget
|
|
234
291
|
)
|
|
235
292
|
|
|
236
293
|
for (token, is_full_result) in generate_text_iterator:
|
|
@@ -245,11 +302,12 @@ def streaming_generate(
|
|
|
245
302
|
max_new_tokens: int,
|
|
246
303
|
temperature: Optional[float] = 1.0,
|
|
247
304
|
k: Optional[int] = None,
|
|
248
|
-
p: Optional[float] =
|
|
305
|
+
p: Optional[float] = None,
|
|
249
306
|
pixel_values: Optional[torch.Tensor] = None,
|
|
250
307
|
tokens_per_image: int = -1,
|
|
251
308
|
suppress_tokens: Optional[List[int]] = None,
|
|
252
309
|
device: Union[str, torch.device, int] = None,
|
|
310
|
+
reasoning_budget: Optional[int] = None
|
|
253
311
|
):
|
|
254
312
|
text_iterator = _streaming_generate(
|
|
255
313
|
model=model,
|
|
@@ -262,7 +320,8 @@ def streaming_generate(
|
|
|
262
320
|
pixel_values=pixel_values,
|
|
263
321
|
tokens_per_image=tokens_per_image,
|
|
264
322
|
suppress_tokens=suppress_tokens,
|
|
265
|
-
device=device
|
|
323
|
+
device=device,
|
|
324
|
+
reasoning_budget=reasoning_budget
|
|
266
325
|
)
|
|
267
326
|
|
|
268
327
|
for (token, is_full_result) in text_iterator:
|
|
@@ -278,11 +337,12 @@ def generate(
|
|
|
278
337
|
max_new_tokens: int,
|
|
279
338
|
temperature: Optional[float] = 1.0,
|
|
280
339
|
k: Optional[int] = None,
|
|
281
|
-
p: Optional[float] =
|
|
340
|
+
p: Optional[float] = None,
|
|
282
341
|
pixel_values: Optional[torch.Tensor] = None,
|
|
283
342
|
tokens_per_image: int = -1,
|
|
284
343
|
suppress_tokens: Optional[List[int]] = None,
|
|
285
344
|
device: Union[str, torch.device, int] = None,
|
|
345
|
+
reasoning_budget: Optional[int] = None
|
|
286
346
|
):
|
|
287
347
|
text_iterator = _streaming_generate(
|
|
288
348
|
model=model,
|
|
@@ -295,7 +355,8 @@ def generate(
|
|
|
295
355
|
suppress_tokens=suppress_tokens,
|
|
296
356
|
pixel_values=pixel_values,
|
|
297
357
|
tokens_per_image=tokens_per_image,
|
|
298
|
-
device=device
|
|
358
|
+
device=device,
|
|
359
|
+
reasoning_budget=reasoning_budget
|
|
299
360
|
)
|
|
300
361
|
|
|
301
362
|
for (token, is_full_result) in text_iterator:
|
|
@@ -382,7 +443,7 @@ def batch_generate(
|
|
|
382
443
|
if k and k != 0:
|
|
383
444
|
logits = _top_k_warper(logits, k, device)
|
|
384
445
|
|
|
385
|
-
if p and
|
|
446
|
+
if p and 0 < p <= 1:
|
|
386
447
|
logits = _top_p_warper(logits, p)
|
|
387
448
|
|
|
388
449
|
prob = logits.softmax(dim=-1)
|
llm_trainer/train_configs.py
CHANGED
|
@@ -22,11 +22,11 @@ class DsActivationCheckpointingConfig:
|
|
|
22
22
|
self,
|
|
23
23
|
*,
|
|
24
24
|
partition_activations: bool = True,
|
|
25
|
-
cpu_checkpointing: bool =
|
|
25
|
+
cpu_checkpointing: bool = False,
|
|
26
26
|
contiguous_memory_optimization: bool = True,
|
|
27
27
|
number_checkpoints: Optional[int] = None,
|
|
28
|
-
synchronize_checkpoint_boundary: bool =
|
|
29
|
-
profile: bool =
|
|
28
|
+
synchronize_checkpoint_boundary: bool = False,
|
|
29
|
+
profile: bool = False
|
|
30
30
|
):
|
|
31
31
|
self.partition_activations =partition_activations
|
|
32
32
|
self.cpu_checkpointing = cpu_checkpointing
|
|
@@ -3,10 +3,10 @@ llm_trainer/checkpoint.py,sha256=yZcExxneN2yzvWxRiK-pstMWs35LV7GiOfqcLq-S6vc,574
|
|
|
3
3
|
llm_trainer/dataset.py,sha256=4QlOo0SFB5816BUYegQjgobUqTUMQvdmZMM_OEAMSjE,4347
|
|
4
4
|
llm_trainer/dcp.py,sha256=PkD97DyrOtoTKn4FJsfL3VqAy4dxufgjdzJEz8-Cnoc,3635
|
|
5
5
|
llm_trainer/dpo_trainer.py,sha256=rC_I5ipesSlP3gFK_SG2GB8NbgJAMu4K7KLxkAS-aRY,13406
|
|
6
|
-
llm_trainer/ds_checkpoint.py,sha256=
|
|
6
|
+
llm_trainer/ds_checkpoint.py,sha256=x_tjgJR47P8gVwV4qAnTUCGwx7eVq2Epw0vOVV7fkYo,4925
|
|
7
7
|
llm_trainer/eval.py,sha256=NDm8PbXLch7xT81xPYPRCNrcrB_Xj5GDJSCxyVwUOp4,1524
|
|
8
8
|
llm_trainer/fsdp_checkpoint.py,sha256=lqZFzHyWyfzuCq_81kQNtJd2qaiMeY1N5BCEMnrJTBw,3192
|
|
9
|
-
llm_trainer/generate_utils.py,sha256=
|
|
9
|
+
llm_trainer/generate_utils.py,sha256=tSbA_tLqSq5qJGHSOlPv5T3iRDZkbFg5ZvDAgJ_i_SE,17946
|
|
10
10
|
llm_trainer/grpo_trainer.py,sha256=bZPrxhyPQLAnFzWhI7hhA6fpuKVNwj7nOm9k0ku9aK4,15977
|
|
11
11
|
llm_trainer/log.py,sha256=LxqTGRNZUGMTSQCePRpk-rYyxSnSIbT4kOdP8Fbzr0M,462
|
|
12
12
|
llm_trainer/loss.py,sha256=Yv3fsaVuZ5AhnGPJOr5vEMb_tM2urR6mCb4DBbrHHI8,6030
|
|
@@ -19,17 +19,17 @@ llm_trainer/scheduler.py,sha256=Xz8HhwoRMjRe41sf_NHhpZfkTlEs0I2MYusvMY6hCVw,3531
|
|
|
19
19
|
llm_trainer/sft_trainer.py,sha256=gxQA7T1o1QGUsHp2CX1Qb_fO5LppBJuNbc0H4ixCYUA,1783
|
|
20
20
|
llm_trainer/tokenizer.py,sha256=A7TYYUbtPf75kjCvWP7yBui4xZBObMk2aPem62YpwpY,6776
|
|
21
21
|
llm_trainer/tools.py,sha256=O45-20wRmh-nyTfU-U-XtjbKAoe7boEIsUvWT_NaKx4,3041
|
|
22
|
-
llm_trainer/train_configs.py,sha256=
|
|
22
|
+
llm_trainer/train_configs.py,sha256=gzTXMLUuQexRvqyKIZQ1U6ESa0DELD7hPpYZdrDcyxg,15974
|
|
23
23
|
llm_trainer/trainer.py,sha256=Zy1oesBfsFlDedZ4hn3gcAkTrpi5fr76bFFQikfAkak,25351
|
|
24
24
|
llm_trainer/utils.py,sha256=-ivhMF0d999va13S1wt2uBvtVw8Nvr3uBzhaUFKL04Q,6826
|
|
25
|
-
project_llm_trainer-0.4.
|
|
26
|
-
project_llm_trainer-0.4.
|
|
27
|
-
project_llm_trainer-0.4.
|
|
28
|
-
project_llm_trainer-0.4.
|
|
29
|
-
project_llm_trainer-0.4.
|
|
30
|
-
project_llm_trainer-0.4.
|
|
31
|
-
project_llm_trainer-0.4.
|
|
32
|
-
project_llm_trainer-0.4.
|
|
33
|
-
project_llm_trainer-0.4.
|
|
34
|
-
project_llm_trainer-0.4.
|
|
35
|
-
project_llm_trainer-0.4.
|
|
25
|
+
project_llm_trainer-0.4.9.data/scripts/calc_intermediate_size,sha256=AggpgNHokJiJMbEtVdOnolqr_4bH3i1UYuZNEAzC2Gc,460
|
|
26
|
+
project_llm_trainer-0.4.9.data/scripts/ddp_train,sha256=x81AasaN2-9TwARFFF1l7iV1LmfMQ0bLw0i_CGbOwSw,299
|
|
27
|
+
project_llm_trainer-0.4.9.data/scripts/ds_train,sha256=qL3qc3TcedBCw98UZUjW07ONcErRawLE1HymW2AmscA,265
|
|
28
|
+
project_llm_trainer-0.4.9.data/scripts/plot_loss,sha256=MzFcdJESlVr1srj4Td6-AxPGUKkfB_QEcJwm0Bd-5fU,910
|
|
29
|
+
project_llm_trainer-0.4.9.data/scripts/plot_lr,sha256=w_7XR_x3KYYyboeOVAeu_I4fveLFI-C0wBmRrNlmWUI,894
|
|
30
|
+
project_llm_trainer-0.4.9.data/scripts/py_train,sha256=tOp9TquORQeU8XN5H7OVIk5O0Ypwi34p_GENxTwgwdk,265
|
|
31
|
+
project_llm_trainer-0.4.9.data/scripts/smart_train,sha256=Pmt4Q0to4Hoz82iB9uFPZuz7uahNUbfE7FR1940EBy8,716
|
|
32
|
+
project_llm_trainer-0.4.9.dist-info/METADATA,sha256=Xn5uE_6i2vFPTFf2O2iJT59czHzHxTAUeb6ZDCaH0-A,195
|
|
33
|
+
project_llm_trainer-0.4.9.dist-info/WHEEL,sha256=Nw36Djuh_5VDukK0H78QzOX-_FQEo6V37m3nkm96gtU,91
|
|
34
|
+
project_llm_trainer-0.4.9.dist-info/top_level.txt,sha256=LtRFg28i0QIG7iBCD2t095oSco99LCtkijibS9cMGik,12
|
|
35
|
+
project_llm_trainer-0.4.9.dist-info/RECORD,,
|
{project_llm_trainer-0.4.6.data → project_llm_trainer-0.4.9.data}/scripts/calc_intermediate_size
RENAMED
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|