project-llm-trainer 0.13.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of project-llm-trainer might be problematic.
- llm_trainer/__init__.py +13 -0
- llm_trainer/base_trainer.py +707 -0
- llm_trainer/checkpoint.py +114 -0
- llm_trainer/dataset.py +335 -0
- llm_trainer/dpo_trainer.py +311 -0
- llm_trainer/ds_checkpoint.py +72 -0
- llm_trainer/eval.py +33 -0
- llm_trainer/generate_utils.py +463 -0
- llm_trainer/grpo_trainer.py +410 -0
- llm_trainer/log.py +65 -0
- llm_trainer/loss.py +266 -0
- llm_trainer/parallel.py +220 -0
- llm_trainer/partition_utils.py +219 -0
- llm_trainer/ppo_trainer.py +686 -0
- llm_trainer/scheduler.py +220 -0
- llm_trainer/sft_trainer.py +97 -0
- llm_trainer/tokenizer.py +162 -0
- llm_trainer/tools.py +116 -0
- llm_trainer/train_configs.py +327 -0
- llm_trainer/trainer.py +34 -0
- llm_trainer/utils.py +630 -0
- project_llm_trainer-0.13.4.data/scripts/calc_intermediate_size +15 -0
- project_llm_trainer-0.13.4.data/scripts/ddp_train +21 -0
- project_llm_trainer-0.13.4.data/scripts/ds_train +17 -0
- project_llm_trainer-0.13.4.data/scripts/py_train +12 -0
- project_llm_trainer-0.13.4.data/scripts/smart_train +37 -0
- project_llm_trainer-0.13.4.data/scripts/vis_log +98 -0
- project_llm_trainer-0.13.4.data/scripts/vis_lr +46 -0
- project_llm_trainer-0.13.4.dist-info/METADATA +9 -0
- project_llm_trainer-0.13.4.dist-info/RECORD +32 -0
- project_llm_trainer-0.13.4.dist-info/WHEEL +5 -0
- project_llm_trainer-0.13.4.dist-info/top_level.txt +1 -0
llm_trainer/generate_utils.py
@@ -0,0 +1,463 @@
from typing import Union, Optional, List
import torch
from llm_model import VlmModel, KVCache
from .tools import TrainerTools
from .utils import (
    autocast,
    batch_repeat_image_tok,
    calc_position_ids
)


def _suppress_warper(logits: torch.Tensor, suppress_tokens: List[int]) -> torch.Tensor:
    """
    Suppress special tokens in the output by masking their logits.
    :param logits: next-token logits, shape [batch, vocab_size]
    :param suppress_tokens: token ids to suppress
    :return: logits with the suppressed positions set to -inf
    """
    suppress_tokens = torch.tensor(suppress_tokens, device=logits.device)
    vocab_tensor = torch.arange(logits.shape[-1], device=logits.device)
    suppress_token_mask = torch.isin(vocab_tensor, suppress_tokens)
    logits = torch.where(suppress_token_mask, -float("inf"), logits)

    return logits
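
To make the masking concrete, here is a small illustrative check (not part of the package; the token ids and logit values are arbitrary):

    # vocabulary of 5 tokens, batch size 1; suppress token ids 0 and 4
    example_logits = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0]])
    _suppress_warper(example_logits, suppress_tokens=[0, 4])
    # -> tensor([[-inf, 2., 3., 4., -inf]]); a later softmax assigns those positions zero probability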


def _temperature_warper(logits: torch.Tensor, temperature: float) -> torch.Tensor:
    """
    Apply temperature scaling.
    :param logits: next-token logits, shape [batch, vocab_size]
    :param temperature: temperature value
    :return: scaled logits
    """
    logits = logits / temperature
    return logits
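
Temperature only rescales the logits before the softmax: values below 1 sharpen the distribution toward the argmax, values above 1 flatten it. A rough illustration (not from the package; probabilities rounded):

    example_logits = torch.tensor([[1.0, 2.0, 3.0]])
    example_logits.softmax(dim=-1)                              # ~[0.09, 0.24, 0.67]
    _temperature_warper(example_logits, 0.5).softmax(dim=-1)    # ~[0.02, 0.12, 0.87]  (sharper)
    _temperature_warper(example_logits, 2.0).softmax(dim=-1)    # ~[0.19, 0.31, 0.51]  (flatter)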


def _top_k_warper(logits: torch.Tensor, k: int, device: Union[str, torch.device, int] = None) -> torch.Tensor:
    """
    Top-k filtering.
    :param logits: next-token logits, shape [batch, vocab_size]
    :param k: number of highest-probability tokens to keep
    :param device: device used for the -inf fill value
    :return: logits with everything below the k-th largest value set to -inf
    """
    # [batch, k]
    topk_logits, _ = torch.topk(logits, k=k)
    # [batch]
    min_val: torch.Tensor = topk_logits[:, -1]
    logits = torch.where(logits < min_val.unsqueeze(-1), torch.tensor(-torch.inf).to(device), logits)
    return logits
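
For reference, with k=2 everything below the second-largest logit is pushed to -inf, so only the two most likely tokens can still be sampled. An illustrative call (not from the package; CPU tensors):

    example_logits = torch.tensor([[1.0, 4.0, 2.0, 3.0]])
    _top_k_warper(example_logits, k=2, device="cpu")
    # -> tensor([[-inf, 4., -inf, 3.]])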


def _top_p_warper(logits: torch.Tensor, p: float, min_tokens_to_keep: int = 1) -> torch.Tensor:
    """
    Top-p (nucleus) filtering.
    :param logits: next-token logits, shape [batch, vocab_size]
    :param p: cumulative probability mass to keep
    :param min_tokens_to_keep: minimum number of tokens that must survive the filter
    :return: logits with the low-probability tail set to -inf
    """
    # Sort in ascending order, e.g. [0.1, 0.2, 0.3]
    sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=False)
    # cumsum: each element is the sum of itself and all preceding elements.
    # For example, torch.cumsum(torch.tensor([[0.1, 0.2, 0.3]]), dim=-1) gives [0.1, 0.3, 0.6].
    cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
    # Remove positions whose cumulative probability is <= 1 - p. Because cumulative_probs is in
    # ascending order, the threshold has to be 1 - p rather than p.
    # For example, with p = 0.9 and, after sorting and the cumulative sum,
    #   cumulative_probs = [0.1, 0.3, 0.7, 0.92, 0.98]
    # the threshold (1 - p) is 0.1, so cumulative_probs <= (1 - p) yields
    #   sorted_indices_to_remove = [True, False, False, False, False]
    # meaning the token whose cumulative probability is at most 0.1 (the first one) is removed.
    # Why (1 - p)? It simplifies the later steps. Implementations that sort in descending order
    # usually shift sorted_indices_to_remove right by one position and force the first element to
    # False so that at least one token is kept even when its probability is very small. Comparing
    # against (1 - p) on the ascending cumulative sum avoids that extra handling of the first token.
    sorted_indices_to_remove = cumulative_probs <= (1 - p)
    # Guarantee that at least min_tokens_to_keep tokens survive.
    # For example, with sorted_indices_to_remove = [True, True, True] and min_tokens_to_keep = 1,
    # this turns the mask into [True, True, False].
    sorted_indices_to_remove[..., -min_tokens_to_keep:] = 0
    # Map the removal mask from sorted positions back to the original (unsorted) positions.
    # scatter copies values from src into the tensor at the positions given by index.
    # For example, with a batch of one and a vocabulary of 5:
    #   sorted_indices = [[2, 0, 4, 1, 3]]                              (sorted order)
    #   sorted_indices_to_remove = [[False, True, False, True, False]]  (removal mask in sorted order)
    # sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) then gives
    #   indices_to_remove = [[True, True, False, False, False]]
    indices_to_remove = sorted_indices_to_remove.scatter(1, index=sorted_indices, src=sorted_indices_to_remove)

    # Set the positions to be removed to -inf.
    scores_processed = logits.masked_fill_(indices_to_remove, -float("Inf"))

    return scores_processed
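
A concrete check of the nucleus filter above, assuming p = 0.9 and a 4-token vocabulary whose probabilities are 0.5, 0.3, 0.15 and 0.05 (illustrative, not from the package):

    example_logits = torch.log(torch.tensor([[0.5, 0.3, 0.15, 0.05]]))
    _top_p_warper(example_logits, p=0.9)
    # Ascending cumulative probabilities are [0.05, 0.2, 0.5, 1.0]; only 0.05 <= 1 - p = 0.1,
    # so the 0.05 token is masked to -inf and the kept tokens cover 0.95 >= p of the mass.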


def _generate(
        model: torch.nn.Module,
        *,
        tokens: torch.Tensor,
        max_new_tokens: int,
        temperature: Optional[float],
        k: Optional[int],
        p: Optional[float],
        pixel_values: Optional[torch.Tensor] = None,
        tokens_per_image: int = -1,
        suppress_tokens: Optional[List[int]] = None,
        device: Union[str, torch.device, int]
):
    """
    :param model: model used for generation
    :param tokens: prompt token ids, shape [seq_len] or [batch, seq_len]
    :param max_new_tokens: maximum number of new tokens to generate
    :param temperature: set to None to disable temperature scaling
    :param k: top-k parameter; None or 0 disables top-k filtering
    :param p: top-p parameter; None disables top-p filtering
    :param suppress_tokens: token ids to suppress during generation
    :param device: device used for generation

    If the generated content is low quality, decrease temperature, k and p.
    If temperature is high but the content lacks variety, increase k and p.
    """
    use_kv_cache = True

    # Make sure the input shape is [batch, seq]
    if tokens.dim() == 1:
        tokens = tokens.unsqueeze(0)

    if isinstance(model, VlmModel):
        tokens, _ = batch_repeat_image_tok(tokens, tokens_per_image)

    attention_mask = torch.ones_like(tokens, device=device, dtype=torch.long)

    kv_cache: Optional[KVCache] = None
    if use_kv_cache:
        # Prompt length + max generation length
        total_capacity = tokens.shape[1] + max_new_tokens
        kv_cache = KVCache(max_capacity=total_capacity)

    if pixel_values is not None:
        pixel_values = pixel_values.to(device)

    generate_tokens = tokens.clone()

    with torch.inference_mode():
        for _ in range(max_new_tokens):
            t = tokens
            with autocast(TrainerTools().parallel.device_type):
                result = model(
                    t,
                    attention_mask=attention_mask,
                    past_key_values=kv_cache,
                    use_cache=use_kv_cache,
                    pixel_values=pixel_values
                )

            # logits: (batch, seq_len, vocab_size)
            logits = result['logits']

            # (batch, vocab_size)
            logits = logits[:, -1, :]

            # Suppress special tokens
            if suppress_tokens and len(suppress_tokens) != 0:
                logits = _suppress_warper(logits, suppress_tokens)

            multinomial = False
            if temperature and temperature > 0:
                multinomial = True
                logits = _temperature_warper(logits, temperature)

            if k and k != 0:
                logits = _top_k_warper(logits, k, device)

            if p and 0 < p <= 1:
                logits = _top_p_warper(logits, p)

            if multinomial:
                prob = logits.softmax(dim=-1)
                # sample the next token index
                next_token = torch.multinomial(prob, num_samples=1)
            else:
                # take the argmax token index
                next_token = logits.argmax(dim=-1, keepdim=True)

            # token, is_full_result
            yield next_token, False

            if use_kv_cache:
                tokens = next_token
                generate_tokens = torch.cat((generate_tokens, next_token), dim=-1)
            else:
                tokens = torch.cat((tokens, next_token), dim=-1)

            # Update the mask: append a 1
            new_mask_bit = torch.ones((tokens.shape[0], 1), device=device, dtype=attention_mask.dtype)
            attention_mask = torch.cat((attention_mask, new_mask_bit), dim=-1)

            if next_token.item() == TrainerTools().tokenizer.end:
                break

    # token, is_full_result
    yield tokens if not use_kv_cache else generate_tokens, True
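
_generate yields one (token, is_full_result) pair per decoding step and then the full sequence as the final item. A minimal consumption sketch (assumptions: model is a trained model, prompt_ids is a 1-D LongTensor of prompt token ids already on the target device, and the TrainerTools() singleton has been initialized):

    generated_ids = []
    for token, is_full_result in _generate(
            model,
            tokens=prompt_ids,
            max_new_tokens=32,
            temperature=0.8,
            k=20,
            p=0.9,
            device="cuda"
    ):
        if is_full_result:
            full_sequence = token                # prompt + generated tokens, shape [1, seq_len]
        else:
            generated_ids.append(token.item())   # the token id produced at this step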


def _streaming_generate(
        model: torch.nn.Module,
        *,
        prompt: Union[str, torch.Tensor],
        max_new_tokens: int,
        temperature: Optional[float] = 1.0,
        k: Optional[int] = None,
        p: Optional[float] = None,
        pixel_values: Optional[torch.Tensor] = None,
        tokens_per_image: int = -1,
        suppress_tokens: Optional[List[int]] = None,
        device: Union[str, torch.device, int] = None
):
    device = TrainerTools().parallel.device if not device else device

    if isinstance(prompt, torch.Tensor):
        encoded_tokens = prompt.to(device)
    else:
        encoded_tokens = TrainerTools().tokenizer.encode(prompt, unsqueeze=True, covert_tensor=True).to(device)

    generate_text_iterator = _generate(
        model=model,
        tokens=encoded_tokens,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        k=k,
        p=p,
        pixel_values=pixel_values,
        tokens_per_image=tokens_per_image,
        suppress_tokens=suppress_tokens,
        device=device
    )

    for (token, is_full_result) in generate_text_iterator:
        yield token, is_full_result


def streaming_generate(
        model: torch.nn.Module,
        *,
        prompt: Union[str, torch.Tensor],
        max_new_tokens: int,
        temperature: Optional[float] = 1.0,
        k: Optional[int] = None,
        p: Optional[float] = None,
        pixel_values: Optional[torch.Tensor] = None,
        tokens_per_image: int = -1,
        suppress_tokens: Optional[List[int]] = None,
        device: Union[str, torch.device, int] = None,
        return_token: bool = False
):
    text_iterator = _streaming_generate(
        model=model,
        prompt=prompt,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        k=k,
        p=p,
        pixel_values=pixel_values,
        tokens_per_image=tokens_per_image,
        suppress_tokens=suppress_tokens,
        device=device
    )

    for (token, is_full_result) in text_iterator:
        if not is_full_result:
            if return_token:
                yield token.squeeze(0)
            else:
                yield TrainerTools().tokenizer.decode(token.squeeze(0))
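
A typical streaming loop over streaming_generate, printing each decoded piece as it arrives (a sketch; the prompt text and sampling settings are arbitrary, and TrainerTools() is assumed to be initialized so the default device can be resolved):

    for piece in streaming_generate(
            model,
            prompt="Hello",
            max_new_tokens=128,
            temperature=0.8,
            k=40,
            p=0.9
    ):
        print(piece, end="", flush=True)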


def generate(
        model: torch.nn.Module,
        *,
        prompt: Union[str, torch.Tensor],
        max_new_tokens: int,
        temperature: Optional[float] = 1.0,
        k: Optional[int] = None,
        p: Optional[float] = None,
        pixel_values: Optional[torch.Tensor] = None,
        tokens_per_image: int = -1,
        suppress_tokens: Optional[List[int]] = None,
        device: Union[str, torch.device, int] = None,
        return_token: bool = False
):
    text_iterator = _streaming_generate(
        model=model,
        prompt=prompt,
        max_new_tokens=max_new_tokens,
        temperature=temperature,
        k=k,
        p=p,
        suppress_tokens=suppress_tokens,
        pixel_values=pixel_values,
        tokens_per_image=tokens_per_image,
        device=device
    )

    for (token, is_full_result) in text_iterator:
        if is_full_result:
            if return_token:
                return token.squeeze(0)
            else:
                return TrainerTools().tokenizer.decode(token.squeeze(0))

    return None
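
generate drives the same iterator but only returns the final result: the decoded string by default, or the full token tensor when return_token=True. A sketch under the same assumptions as above:

    answer_text = generate(model, prompt="Hello", max_new_tokens=128, temperature=0.7, k=30, p=0.95)
    answer_ids = generate(model, prompt="Hello", max_new_tokens=128, return_token=True)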


def batch_generate(
        model: torch.nn.Module,
        *,
        tokens: torch.Tensor,
        attention_mask: torch.Tensor,
        max_new_tokens: int,
        temperature: Optional[float] = None,
        k: Optional[int] = None,
        p: Optional[float] = None,
        pixel_values: Optional[torch.Tensor] = None,
        tokens_per_image: int = -1,
        suppress_tokens: Optional[List[int]] = None,
        device: Union[str, torch.device, int],
        return_logits: bool = True
):
    use_kv_cache = True
    end_token = TrainerTools().tokenizer.end
    pad_token_id = TrainerTools().tokenizer.pad

    if isinstance(model, VlmModel):
        tokens, attention_mask = batch_repeat_image_tok(tokens, tokens_per_image, attention_mask)

    if pixel_values is not None:
        pixel_values = pixel_values.to(device)

    orig_tokens = tokens.clone()
    full_attention_mask = attention_mask.clone()

    # Initialize position_ids, taking left padding into account
    position_ids = calc_position_ids(full_attention_mask)

    kv_cache: Optional[KVCache] = None
    batch_size = tokens.shape[0]

    if use_kv_cache:
        # Prompt length + max generation length
        total_capacity = tokens.shape[1] + max_new_tokens
        kv_cache = KVCache(max_capacity=total_capacity)

    # Pre-allocate the maximum length to avoid memory fragmentation from repeated cat calls in the loop
    generated_tokens_buffer = torch.full(
        (batch_size, max_new_tokens),
        pad_token_id,
        dtype=torch.long,
        device=device
    )

    done = torch.zeros(batch_size, dtype=torch.bool, device=device)
    current_tokens = tokens

    padded_logits = None
    actual_gen_len = 0

    pad_token_tensor = torch.tensor(pad_token_id, device=device, dtype=torch.long)

    with torch.inference_mode():
        for i in range(max_new_tokens):
            if done.all():
                break

            actual_gen_len = i + 1

            if current_tokens.dtype != torch.long:
                current_tokens = current_tokens.long()

            if kv_cache is None:
                current_position_ids = position_ids
            else:
                # The next position id is based on the last valid position of the current masked sequence.
                # With a valid kv_cache the current token was generated in the previous step, so its
                # position is the previous position + 1.
                # Note: on the first iteration (prefill) the kv_cache is still empty, but the full tokens
                # are passed in, so prefill needs no special handling: the full position_ids are used directly.
                if i == 0:
                    current_position_ids = position_ids
                else:
                    current_position_ids = position_ids[:, -1:] + 1
                    position_ids = torch.cat((position_ids, current_position_ids), dim=-1)

            with autocast(TrainerTools().parallel.device_type):
                result = model(
                    current_tokens,
                    attention_mask=full_attention_mask,
                    position_ids=current_position_ids,
                    past_key_values=kv_cache,
                    use_cache=use_kv_cache,
                    pixel_values=pixel_values
                )
            logits = result['logits']

            logits = logits[:, -1, :]

            if return_logits:
                if padded_logits is None:
                    vocab_size = logits.shape[-1]
                    padded_logits = torch.zeros(
                        (batch_size, max_new_tokens, vocab_size),
                        dtype=logits.dtype,
                        device=device
                    )
                padded_logits[:, i, :] = logits

            if suppress_tokens:
                logits = _suppress_warper(logits, suppress_tokens)

            multinomial = False
            if temperature and temperature > 0:
                multinomial = True
                logits = _temperature_warper(logits, temperature)
            if k and k != 0:
                logits = _top_k_warper(logits, k, device)
            if p and 0 < p <= 1:
                logits = _top_p_warper(logits, p)

            if multinomial:
                prob = logits.softmax(dim=-1)
                next_token_active = torch.multinomial(prob, num_samples=1)
            else:
                next_token_active = logits.argmax(dim=-1, keepdim=True)

            # Sequences that are already finished keep emitting the pad token
            next_token = torch.where(
                done.unsqueeze(1),
                pad_token_tensor,
                next_token_active
            )

            generated_tokens_buffer[:, i] = next_token.squeeze(-1)

            new_done = (next_token.squeeze(-1) == end_token)
            done = done | new_done

            current_tokens = next_token

            new_mask = (~done).long().to(full_attention_mask.dtype)
            full_attention_mask = torch.cat((full_attention_mask, new_mask.unsqueeze(-1)), dim=-1)

    final_generated_tokens = generated_tokens_buffer[:, :actual_gen_len]

    if padded_logits is not None:
        final_padded_logits = padded_logits[:, :actual_gen_len, :]
    else:
        final_padded_logits = None

    final_full_sequences = torch.cat((orig_tokens, final_generated_tokens), dim=1)

    return final_full_sequences, final_padded_logits
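
batch_generate expects already tokenized, left-padded prompts together with their attention mask (which is why calc_position_ids is needed) and returns the full sequences plus, optionally, the per-step logits. A minimal sketch; the token ids are made up and model is assumed to live on the target device:

    pad = TrainerTools().tokenizer.pad
    # two prompts of different lengths, left-padded to the same width
    batch_tokens = torch.tensor([
        [pad, pad, 11, 12, 13],
        [21, 22, 23, 24, 25],
    ], dtype=torch.long, device="cuda")
    batch_mask = (batch_tokens != pad).long()

    sequences, step_logits = batch_generate(
        model,
        tokens=batch_tokens,
        attention_mask=batch_mask,
        max_new_tokens=64,
        temperature=1.0,
        device="cuda",
        return_logits=True
    )
    # sequences: [2, 5 + generated_len]; step_logits: [2, generated_len, vocab_size] or None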