project-llm-trainer 0.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of project-llm-trainer might be problematic. Click here for more details.
- llm_trainer/__init__.py +6 -0
- llm_trainer/checkpoint.py +161 -0
- llm_trainer/dataset.py +140 -0
- llm_trainer/dcp.py +93 -0
- llm_trainer/dpo_trainer.py +300 -0
- llm_trainer/ds_checkpoint.py +61 -0
- llm_trainer/eval.py +86 -0
- llm_trainer/generate_utils.py +424 -0
- llm_trainer/grpo_trainer.py +393 -0
- llm_trainer/log.py +16 -0
- llm_trainer/loss.py +171 -0
- llm_trainer/parallel.py +146 -0
- llm_trainer/parallel_ddp.py +39 -0
- llm_trainer/parallel_ds.py +45 -0
- llm_trainer/parallel_fsdp.py +115 -0
- llm_trainer/parallel_none.py +28 -0
- llm_trainer/scheduler.py +138 -0
- llm_trainer/sft_trainer.py +39 -0
- llm_trainer/tokenizer.py +166 -0
- llm_trainer/tools.py +102 -0
- llm_trainer/train_configs.py +445 -0
- llm_trainer/trainer.py +569 -0
- llm_trainer/utils.py +262 -0
- project_llm_trainer-0.3.data/scripts/calc_intermediate_size +15 -0
- project_llm_trainer-0.3.data/scripts/ddp_train +12 -0
- project_llm_trainer-0.3.data/scripts/ds_train +12 -0
- project_llm_trainer-0.3.data/scripts/plot_loss +39 -0
- project_llm_trainer-0.3.data/scripts/plot_lr +41 -0
- project_llm_trainer-0.3.data/scripts/py_train +12 -0
- project_llm_trainer-0.3.data/scripts/smart_train +28 -0
- project_llm_trainer-0.3.dist-info/METADATA +9 -0
- project_llm_trainer-0.3.dist-info/RECORD +34 -0
- project_llm_trainer-0.3.dist-info/WHEEL +5 -0
- project_llm_trainer-0.3.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,424 @@
|
|
|
1
|
+
from typing import Union, Optional, List
|
|
2
|
+
from contextlib import nullcontext
|
|
3
|
+
import torch
|
|
4
|
+
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
|
|
5
|
+
from llm_model import VlmModel, KVCache
|
|
6
|
+
from .tools import TrainerTools
|
|
7
|
+
from .utils import batch_repeat_image_tok
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
def _suppress_warper(logits: torch.Tensor, suppress_tokens: List[int]) -> torch.Tensor:
    """
    Suppress special tokens by forcing their logits to -inf.

    :param logits: logits whose last dimension indexes the vocabulary
    :param suppress_tokens: token ids that must never be selected
    :return: a new tensor equal to ``logits`` except that every suppressed
        vocabulary position along the last dimension is -inf
    """
    banned_ids = torch.tensor(suppress_tokens, device=logits.device)
    vocab_ids = torch.arange(logits.shape[-1], device=logits.device)
    # True at every vocabulary position that appears in suppress_tokens
    is_banned = torch.isin(vocab_ids, banned_ids)
    return logits.masked_fill(is_banned, -float("inf"))
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
def _temperature_warper(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    """
    Apply temperature scaling to the logits.

    :param logits: raw logits
    :param temperature: divisor applied element-wise; values below 1 sharpen
        the resulting distribution, values above 1 flatten it
    :return: ``logits / temperature`` as a new tensor
    """
    return torch.div(logits, temperature)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def _top_k_warper(logits: torch.Tensor, k: int, device: Union[str, torch.device, int] = None) -> torch.Tensor:
    """
    Top-k filtering: keep the ``k`` largest logits along the last dimension and
    set every other logit to -inf.

    :param logits: logits whose last dimension indexes the vocabulary
    :param k: number of candidates to keep; a value larger than the vocabulary
        size is clamped, so every token is kept
    :param device: unused, kept only for backward compatibility (the result is
        computed on ``logits.device``)
    :return: a new tensor of the same shape; ties with the k-th largest logit
        are kept
    """
    # torch.topk raises if k exceeds the size of the dimension.
    k = min(k, logits.size(-1))
    # Smallest kept logit per row, shaped [..., 1] so it broadcasts against logits.
    kth_largest = torch.topk(logits, k=k, dim=-1).values[..., -1:]
    return logits.masked_fill(logits < kth_largest, -float("inf"))
|
|
50
|
+
|
|
51
|
+
|
|
52
|
+
def _top_p_warper(logits: torch.Tensor, p: float, min_tokens_to_keep: int = 1) -> torch.Tensor:
    """
    Top-p (nucleus) filtering.

    Keeps the highest-probability tokens whose cumulative probability reaches
    ``p``, plus at least ``min_tokens_to_keep`` tokens. Every other logit is set
    to -inf.

    :param logits: logits whose last dimension indexes the vocabulary
    :param p: cumulative probability threshold, expected in (0, 1]
    :param min_tokens_to_keep: minimum number of most-likely tokens that always
        survive filtering; 0 means no minimum
    :return: a new tensor of the same shape; the input tensor is not modified
    """
    # Sort ascending, so the least likely tokens come first.
    sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=False)
    # Running probability mass in ascending order; the last element is ~1.0.
    # e.g. probabilities [0.05, 0.05, 0.2, 0.3, 0.4] -> [0.05, 0.1, 0.3, 0.6, 1.0]
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # The order is ascending, so the tail to drop is the prefix whose mass is <= 1 - p.
    # The remaining tokens carry at least p of the probability mass.
    # e.g. with p = 0.8 the mask above becomes [True, True, False, False, False].
    sorted_indices_to_remove = cumulative_probs <= (1 - p)
    # The most likely tokens are at the end: always keep the last min_tokens_to_keep.
    # Guard against 0, because the slice [..., -0:] would clear the whole mask.
    if min_tokens_to_keep > 0:
        sorted_indices_to_remove[..., -min_tokens_to_keep:] = False
    # Map the mask from sorted order back to vocabulary order:
    # indices_to_remove[..., sorted_indices[..., j]] = sorted_indices_to_remove[..., j]
    # e.g. sorted_indices [[2, 0, 4, 1, 3]] with mask [[F, T, F, T, F]] -> [[T, T, F, F, F]]
    indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
    # Out-of-place, so the caller's logits tensor is left untouched.
    return logits.masked_fill(indices_to_remove, -float("inf"))
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
def _generate(
    model: torch.nn.Module,
    *,
    tokens: torch.Tensor,
    max_position_embeddings: int,
    max_new_tokens: int,
    temperature: Optional[float],
    k: Optional[int],
    p: Optional[float],
    pixel_values: Optional[torch.Tensor] = None,
    tokens_per_image: int = -1,
    suppress_tokens: Optional[List[int]] = None,
    device: Union[str, torch.device, int]
):
    """
    Autoregressively generate up to ``max_new_tokens`` tokens for one prompt.

    This is a generator. After each step it yields ``(next_token, False)``,
    where ``next_token`` has shape (batch, 1). When generation stops it yields
    ``(all_tokens, True)``: the prompt followed by every generated token.
    Generation stops when the tokenizer's ``end`` token is produced or after
    ``max_new_tokens`` steps.

    :param model: model called as ``model(t, past_key_values=..., use_cache=...,
        pixel_values=...)`` and returning a mapping with ``'logits'`` and
        ``'past_key_values'``
    :param tokens: prompt token ids, indexed as (batch, seq_len). The stop check
        calls ``.item()`` on the new token, so batch size must be 1.
    :param max_position_embeddings: the model input is truncated to its last
        ``max_position_embeddings`` tokens
    :param max_new_tokens: maximum number of tokens to generate
    :param temperature: sampling temperature; None or <= 0 selects greedy
        (argmax) decoding
    :param k: top-k parameter; None or 0 disables top-k
    :param p: top-p parameter; None or >= 1 disables top-p
    :param pixel_values: optional image input forwarded to the model
    :param tokens_per_image: passed to ``batch_repeat_image_tok`` when the model
        is a VlmModel
    :param suppress_tokens: token ids that must never be generated
    :param device: device to run on; a string such as "cuda" or "cuda:0", a
        torch.device, or a device index

    If output quality is low, decrease temperature, k and p.
    If temperature is high but the output is monotonous, increase k and p.
    """
    use_kv_cache = True

    # torch.autocast only accepts a bare device type ("cuda", "cpu", ...).
    # It rejects "cuda:0", a torch.device object or an index, so normalize first.
    device_type = torch.device(device).type
    ctx = torch.autocast(
        device_type=device_type,
        dtype=TrainerTools().dtype,
        enabled=True,
        # FSDP mode requires cache_enabled=False
        cache_enabled=False if isinstance(model, FSDP) else None
    ) if TrainerTools().use_amp else nullcontext()

    if isinstance(model, VlmModel):
        tokens = batch_repeat_image_tok(tokens, tokens_per_image)

    if pixel_values is not None:
        pixel_values = pixel_values.to(device)

    kv_cache: Optional[KVCache] = None
    # Accumulates prompt + generated ids when the KV cache makes `tokens` hold only the newest token.
    generate_tokens = tokens.clone()
    end_token = TrainerTools().tokenizer.end

    model.eval()
    with torch.inference_mode():
        for _ in range(max_new_tokens):
            # NOTE(review): unclear whether truncation is still needed once a KV cache is used.
            t = tokens[:, -max_position_embeddings:]
            with ctx:
                result = model(
                    t,
                    past_key_values=kv_cache,
                    use_cache=use_kv_cache,
                    pixel_values=pixel_values
                )

            # logits: (batch, seq_len, vocab_size)
            logits = result['logits']
            kv_cache = result['past_key_values']

            # Keep only the last position: (batch, vocab_size)
            logits = logits[:, -1, :]

            # Suppress special tokens
            if suppress_tokens:
                logits = _suppress_warper(logits, suppress_tokens)

            multinomial = False
            if temperature and temperature > 0:
                multinomial = True
                logits = _temperature_warper(logits, temperature)

            if k:
                logits = _top_k_warper(logits, k, device)

            if p and p < 1:
                logits = _top_p_warper(logits, p)

            if multinomial:
                prob = logits.softmax(dim=-1)
                # sampled token index
                next_token = torch.multinomial(prob, num_samples=1)
            else:
                # greedy token index
                next_token = logits.argmax(dim=-1, keepdim=True)

            # (token, is_full_result)
            yield next_token, False

            if use_kv_cache:
                # With a KV cache only the newest token is fed to the next step.
                tokens = next_token
                generate_tokens = torch.cat((generate_tokens, next_token), dim=-1)
            else:
                tokens = torch.cat((tokens, next_token), dim=-1)

            # .item() requires a single-element tensor, i.e. batch size 1.
            if next_token.item() == end_token:
                break

    # (token, is_full_result)
    yield tokens if not use_kv_cache else generate_tokens, True
|
|
202
|
+
|
|
203
|
+
|
|
204
|
+
def _streaming_generate(
    model: torch.nn.Module,
    *,
    prompt: str,
    max_position_embeddings: int,
    max_new_tokens: int,
    temperature: Optional[float] = 1.0,
    k: Optional[int] = None,
    p: Optional[float] = 1.0,
    pixel_values: Optional[torch.Tensor] = None,
    tokens_per_image: int = -1,
    suppress_tokens: Optional[List[int]] = None,
    device: Union[str, torch.device, int] = None,
):
    """
    Encode ``prompt`` and re-yield every ``(token, is_full_result)`` pair
    produced by ``_generate`` for it.

    When ``device`` is falsy, ``TrainerTools().parallel.device`` is used.
    The remaining arguments are forwarded to ``_generate`` unchanged.
    """
    if not device:
        device = TrainerTools().parallel.device

    prompt_ids = TrainerTools().tokenizer.encode(prompt, unsqueeze=True, covert_tensor=True)
    prompt_ids = prompt_ids.to(device)

    yield from _generate(
        model=model,
        tokens=prompt_ids,
        max_position_embeddings=max_position_embeddings,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        k=k,
        p=p,
        pixel_values=pixel_values,
        tokens_per_image=tokens_per_image,
        suppress_tokens=suppress_tokens,
        device=device
    )
|
|
237
|
+
|
|
238
|
+
|
|
239
|
+
def streaming_generate(
    model: torch.nn.Module,
    *,
    prompt: str,
    max_position_embeddings: int,
    max_new_tokens: int,
    temperature: Optional[float] = 1.0,
    k: Optional[int] = 50,
    p: Optional[float] = 1.0,
    pixel_values: Optional[torch.Tensor] = None,
    tokens_per_image: int = -1,
    suppress_tokens: Optional[List[int]] = None,
    device: Union[str, torch.device, int] = None,
):
    """
    Stream generated text for ``prompt`` one token at a time.

    Yields the decoded text of each newly generated token. The final
    full-sequence result from the underlying generator is not yielded.
    """
    stream = _streaming_generate(
        model=model,
        prompt=prompt,
        max_position_embeddings=max_position_embeddings,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        k=k,
        p=p,
        pixel_values=pixel_values,
        tokens_per_image=tokens_per_image,
        suppress_tokens=suppress_tokens,
        device=device
    )

    tokenizer_decode = TrainerTools().tokenizer.decode
    for token, is_full_result in stream:
        if is_full_result:
            continue
        yield tokenizer_decode(token.squeeze(0))
|
|
270
|
+
|
|
271
|
+
|
|
272
|
+
def generate(
    model: torch.nn.Module,
    *,
    prompt: str,
    max_position_embeddings: int,
    max_new_tokens: int,
    temperature: Optional[float] = 1.0,
    k: Optional[int] = None,
    p: Optional[float] = 1.0,
    pixel_values: Optional[torch.Tensor] = None,
    tokens_per_image: int = -1,
    suppress_tokens: Optional[List[int]] = None,
    device: Union[str, torch.device, int] = None,
):
    """
    Generate a completion for ``prompt`` and return it as decoded text.

    The returned string decodes the full token sequence reported by the
    underlying generator (prompt plus generated tokens). Returns None if no
    full result is produced.
    """
    stream = _streaming_generate(
        model=model,
        prompt=prompt,
        max_position_embeddings=max_position_embeddings,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        k=k,
        p=p,
        suppress_tokens=suppress_tokens,
        pixel_values=pixel_values,
        tokens_per_image=tokens_per_image,
        device=device
    )

    full_tokens = next((tok for tok, is_full in stream if is_full), None)
    if full_tokens is None:
        return None
    return TrainerTools().tokenizer.decode(full_tokens.squeeze(0))
|
|
303
|
+
|
|
304
|
+
|
|
305
|
+
def batch_generate(
    model: torch.nn.Module,
    *,
    tokens: torch.Tensor,
    pad_token_id: Union[int, torch.Tensor],
    attention_mask: torch.Tensor,
    max_position_embeddings: int,
    max_new_tokens: int,
    temperature: Optional[float],
    k: Optional[int],
    p: Optional[float],
    pixel_values: Optional[torch.Tensor] = None,
    tokens_per_image: int = -1,
    suppress_tokens: Optional[List[int]] = None,
    device: Union[str, torch.device, int]
):
    """
    Generate continuations for a batch of prompts.

    A row is marked done once it produces the tokenizer's ``end`` token. From
    then on it only receives ``pad_token_id``, so all rows keep the same length.
    The loop stops when every row is done or after ``max_new_tokens`` steps.

    :param model: model called as ``model(t, attention_mask=..., past_key_values=...,
        use_cache=..., pixel_values=...)`` and returning a mapping with
        ``'logits'`` and ``'past_key_values'``
    :param tokens: prompt token ids, indexed as (batch, seq_len)
    :param pad_token_id: id emitted by finished rows. It is used as an index into
        the vocabulary dimension, so it must be an int or a 0-dim integer tensor.
    :param attention_mask: mask for ``tokens``; one column of ones is appended
        per generated step
    :param max_position_embeddings: currently unused (the input is not
        truncated); kept for API compatibility
    :param max_new_tokens: maximum number of generation steps
    :param temperature: sampling temperature; None or <= 0 selects greedy decoding
    :param k: top-k parameter; None or 0 disables top-k
    :param p: top-p parameter; None or >= 1 disables top-p
    :param pixel_values: optional image input forwarded to the model
    :param tokens_per_image: passed to ``batch_repeat_image_tok`` when the model
        is a VlmModel
    :param suppress_tokens: token ids that must never be generated
    :param device: device to run on; a string such as "cuda" or "cuda:0", a
        torch.device, or a device index
    :return: prompt ids followed by all generated ids, concatenated along the
        last dimension
    """
    use_kv_cache = True

    # torch.autocast only accepts a bare device type ("cuda", "cpu", ...).
    # It rejects "cuda:0", a torch.device object or an index, so normalize first.
    device_type = torch.device(device).type
    ctx = torch.autocast(
        device_type=device_type,
        dtype=TrainerTools().dtype,
        enabled=True,
        # FSDP mode requires cache_enabled=False
        cache_enabled=False if isinstance(model, FSDP) else None
    ) if TrainerTools().use_amp else nullcontext()

    if isinstance(model, VlmModel):
        tokens = batch_repeat_image_tok(tokens, tokens_per_image)

    if pixel_values is not None:
        pixel_values = pixel_values.to(device)

    kv_cache: Optional[KVCache] = None
    # Accumulates prompt + generated ids when the KV cache makes `tokens` hold only the newest token.
    generate_tokens = tokens.clone()
    batch_size = tokens.shape[0]

    # Per-row completion flags
    end_token = TrainerTools().tokenizer.end
    done = torch.zeros(batch_size, dtype=torch.bool, device=device)

    model.eval()
    with torch.inference_mode():
        for _ in range(max_new_tokens):
            if done.all():
                break

            # NOTE: the input is intentionally not truncated to max_position_embeddings here.
            with ctx:
                result = model(
                    tokens,
                    attention_mask=attention_mask,
                    past_key_values=kv_cache,
                    use_cache=use_kv_cache,
                    pixel_values=pixel_values
                )

            logits = result['logits']
            kv_cache = result['past_key_values']

            # Keep only the last position: (batch, vocab_size)
            logits = logits[:, -1, :]

            if done.any():
                # Finished rows: make pad_token_id the only plausible choice.
                logits_done = torch.full_like(logits[done], -torch.finfo(logits.dtype).max)
                logits_done[:, pad_token_id] = 0
                logits[done] = logits_done

            if suppress_tokens:
                logits = _suppress_warper(logits, suppress_tokens)

            multinomial = False
            if temperature and temperature > 0:
                multinomial = True
                logits = _temperature_warper(logits, temperature)

            if k:
                logits = _top_k_warper(logits, k, device)

            if p and p < 1:
                logits = _top_p_warper(logits, p)

            if multinomial:
                prob = logits.softmax(dim=-1)
                # A row whose logits are all -inf softmaxes to NaN. Zero those
                # entries first so the validity check below catches the row
                # (NaN <= 0 is False, which would let it slip through).
                prob = torch.nan_to_num(prob, nan=0.0, posinf=1.0, neginf=0.0)
                # Rows with no probability mass left fall back to a uniform distribution.
                # The mask keeps its last dim, so it broadcasts row-wise against prob.
                invalid_mask = prob.sum(dim=-1, keepdim=True) <= 0
                if invalid_mask.any():
                    uniform_prob = torch.full_like(prob, 1.0 / prob.size(-1))
                    prob = torch.where(invalid_mask, uniform_prob, prob)
                next_token = torch.multinomial(prob, num_samples=1)
            else:
                next_token = logits.argmax(dim=-1, keepdim=True)

            # Finished rows always emit pad_token_id, regardless of what the warpers did to their logits.
            next_token[done] = pad_token_id

            # Update the completion flags
            done = done | (next_token.squeeze(-1) == end_token)

            # Append the new tokens
            if use_kv_cache:
                # With a KV cache only the newest token is fed to the next step.
                tokens = next_token
                generate_tokens = torch.cat((generate_tokens, next_token), dim=-1)
            else:
                tokens = torch.cat((tokens, next_token), dim=-1)

            # Extend the mask with the caller's dtype so torch.cat never has to mix dtypes.
            new_mask = torch.ones_like(next_token, dtype=attention_mask.dtype)
            attention_mask = torch.cat((attention_mask, new_mask), dim=-1)

    # Full result
    return tokens if not use_kv_cache else generate_tokens
|