project-llm-trainer 0.12.3__py3-none-any.whl

@@ -0,0 +1,450 @@
+ from typing import Union, Optional, List
+ import torch
+ from llm_model import VlmModel, KVCache
+ from .tools import TrainerTools
+ from .utils import (
+     autocast,
+     batch_repeat_image_tok,
+     calc_position_ids
+ )
+
+
+ def _suppress_warper(logits: torch.Tensor, suppress_tokens: List[int]) -> torch.Tensor:
+     """
+     Suppress special tokens in the output distribution.
+     :param logits: next-token logits of shape [batch, vocab_size]
+     :param suppress_tokens: token ids that must not be sampled
+     :return: logits with the suppressed positions set to -inf
+     """
+     suppress_tokens = torch.tensor(suppress_tokens, device=logits.device)
+     vocab_tensor = torch.arange(logits.shape[-1], device=logits.device)
+     suppress_token_mask = torch.isin(vocab_tensor, suppress_tokens)
+     logits = torch.where(suppress_token_mask, -float("inf"), logits)
+
+     return logits
+
+
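+ # Hedged illustration of _suppress_warper; the toy values below are
+ # assumptions made for this example, not anything produced by the package:
+ #   logits = torch.tensor([[2.0, 1.0, 0.5]])
+ #   _suppress_warper(logits, [1])  ->  tensor([[2.0, -inf, 0.5]])
+ # After softmax, token id 1 receives probability 0 and can never be emitted.
+
+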
+ def _temperature_warper(logits: torch.Tensor, temperature: float) -> torch.Tensor:
+     """
+     Apply temperature scaling.
+     :param logits: next-token logits of shape [batch, vocab_size]
+     :param temperature: sampling temperature
+     :return: scaled logits
+     """
+     logits = logits / temperature
+     return logits
+
+
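+ # For example (illustrative numbers): temperature=0.5 doubles every logit, so
+ # softmax concentrates on the top token; temperature=2.0 halves the logits
+ # and spreads probability mass more evenly.
+
+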
+ def _top_k_warper(logits: torch.Tensor, k: int, device: Union[str, torch.device, int] = None) -> torch.Tensor:
+     """
+     Top-k sampling: keep only the k highest-scoring tokens.
+     :param logits: next-token logits of shape [batch, vocab_size]
+     :param k: number of tokens to keep
+     :param device: device used for the -inf fill value
+     :return: logits with all positions outside the top k set to -inf
+     """
+     # [batch, k]
+     topk_logits, _ = torch.topk(logits, k=k)
+     # [batch] -- the smallest logit that still survives the cut
+     min_val: torch.Tensor = topk_logits[:, -1]
+     logits = torch.where(logits < min_val.unsqueeze(-1), torch.tensor(-torch.inf).to(device), logits)
+     return logits
+
+
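+ # Hedged illustration (toy values assumed for the example): with k=2,
+ #   _top_k_warper(torch.tensor([[4.0, 3.0, 1.0, 0.5]]), k=2)
+ # keeps 4.0 and 3.0 and fills the rest with -inf: [[4.0, 3.0, -inf, -inf]].
+
+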
+ def _top_p_warper(logits: torch.Tensor, p: float, min_tokens_to_keep: int = 1) -> torch.Tensor:
+     """
+     Top-p (nucleus) sampling.
+     :param logits: next-token logits of shape [batch, vocab_size]
+     :param p: cumulative-probability threshold in (0, 1]
+     :param min_tokens_to_keep: lower bound on the number of surviving tokens
+     :return: logits with tokens outside the nucleus set to -inf
+     """
+     # Sort in ascending order, e.g. [0.1, 0.2, 0.3]
+     sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=False)
+     # cumsum: each element is the running sum of everything up to it.
+     # e.g. torch.cumsum(torch.tensor([[0.1, 0.2, 0.3]]), dim=-1) -> [0.1, 0.3, 0.6]
+     cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
+     # Drop everything whose cumulative probability is <= 1 - p. Because
+     # cumulative_probs is ascending, the threshold is (1 - p) rather than p.
+     # Example: suppose p = 0.9 and, after sorting,
+     #   cumulative_probs = [0.1, 0.3, 0.7, 0.92, 0.98]
+     # then (1 - p) = 0.1 and cumulative_probs <= (1 - p) gives
+     #   sorted_indices_to_remove = [True, False, False, False, False]
+     # i.e. the tokens whose combined probability is at most 0.1 (here just the
+     # first one) are removed, and the kept tokens cover at least p = 0.9 of
+     # the probability mass.
+     # Why (1 - p)? The descending formulation has to shift the removal mask
+     # one position to the right and clear its first element so that at least
+     # one token survives even when the top token alone exceeds p. Comparing
+     # the ascending cumsum against (1 - p) lets us use cumulative_probs
+     # directly, without that extra step.
+     sorted_indices_to_remove = cumulative_probs <= (1 - p)
+     # Guarantee that at least min_tokens_to_keep tokens survive.
+     # Example: with sorted_indices_to_remove = [True, True, True] and
+     # min_tokens_to_keep = 1 this yields [True, True, False].
+     sorted_indices_to_remove[..., -min_tokens_to_keep:] = 0
+     # Map the removal mask from sorted positions back to vocabulary positions.
+     # scatter copies values from src into the tensor at the indices in index.
+     # Example with one batch row and a vocabulary of 5:
+     #   sorted_indices           = [[2, 0, 4, 1, 3]]   (sorted order)
+     #   sorted_indices_to_remove = [[False, True, False, True, False]]
+     # then sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+     # produces
+     #   indices_to_remove = [[True, True, False, False, False]]
+     indices_to_remove = sorted_indices_to_remove.scatter(1, index=sorted_indices, src=sorted_indices_to_remove)
+
+     # Set the removed positions to -inf.
+     scores_processed = logits.masked_fill_(indices_to_remove, -float("Inf"))
+
+     return scores_processed
+
+
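+ # Minimal runnable sketch tying the warpers above together. The tensor values
+ # and parameter choices are assumptions made for illustration only; nothing
+ # in the package calls this helper.
+ def _warper_chain_example() -> torch.Tensor:
+     # one batch row over a toy 4-token vocabulary
+     logits = torch.tensor([[4.0, 3.0, 1.0, 0.5]])
+     logits = _temperature_warper(logits, temperature=0.8)        # rescale
+     logits = _top_k_warper(logits, k=3, device=logits.device)    # keep best 3
+     logits = _top_p_warper(logits, p=0.9)                        # nucleus cut
+     # every position filtered by top-k / top-p now holds -inf, so sampling
+     # from logits.softmax(dim=-1) can only pick surviving tokens
+     return logits
+
+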
+ def _generate(
+         model: torch.nn.Module,
+         *,
+         tokens: torch.Tensor,
+         max_new_tokens: int,
+         temperature: Optional[float],
+         k: Optional[int],
+         p: Optional[float],
+         pixel_values: Optional[torch.Tensor] = None,
+         tokens_per_image: int = -1,
+         suppress_tokens: Optional[List[int]] = None,
+         device: Union[str, torch.device, int]
+ ):
+     """
+     :param model:
+     :param tokens: prompt token ids, shape [seq] or [batch, seq]
+     :param max_new_tokens:
+     :param temperature: set to None to disable temperature scaling
+     :param k: top-k parameter; set to None or 0 to disable top-k
+     :param p: top-p parameter; set to None to disable top-p
+     :param suppress_tokens: token ids to suppress
+     :param device:
+
+     If output quality is poor, reduce temperature, k and p.
+     If temperature is high but the output lacks variety, increase k and p.
+     """
+     use_kv_cache = True
+
+     # Ensure the input has shape [batch, seq]
+     if tokens.dim() == 1:
+         tokens = tokens.unsqueeze(0)
+
+     attention_mask = torch.ones_like(tokens, device=device, dtype=torch.long)
+
+     if isinstance(model, VlmModel):
+         tokens = batch_repeat_image_tok(tokens, tokens_per_image)
+
+     if pixel_values is not None:
+         pixel_values = pixel_values.to(device)
+
+     kv_cache: Optional[KVCache] = None
+     generate_tokens = tokens.clone()
+
+     with torch.inference_mode():
+         for _ in range(max_new_tokens):
+             t = tokens
+             with autocast(TrainerTools().parallel.device_type):
+                 result = model(
+                     t,
+                     attention_mask=attention_mask,
+                     past_key_values=kv_cache,
+                     use_cache=use_kv_cache,
+                     pixel_values=pixel_values
+                 )
+
+             # logits: (batch, seq_len, vocab_size)
+             logits = result['logits']
+             kv_cache = result['past_key_values']
+
+             # (batch, vocab_size)
+             logits = logits[:, -1, :]
+
+             # suppress special tokens in the output
+             if suppress_tokens and len(suppress_tokens) != 0:
+                 logits = _suppress_warper(logits, suppress_tokens)
+
+             multinomial = False
+             if temperature and temperature > 0:
+                 multinomial = True
+                 logits = _temperature_warper(logits, temperature)
+
+             if k and k != 0:
+                 logits = _top_k_warper(logits, k, device)
+
+             if p and 0 < p <= 1:
+                 logits = _top_p_warper(logits, p)
+
+             if multinomial:
+                 prob = logits.softmax(dim=-1)
+                 # multinomial returns sampled indices
+                 next_token = torch.multinomial(prob, num_samples=1)
+             else:
+                 # argmax returns the index of the largest logit
+                 next_token = logits.argmax(dim=-1, keepdim=True)
+
+             # (token, is_full_result)
+             yield next_token, False
+
+             if use_kv_cache:
+                 tokens = next_token
+                 generate_tokens = torch.cat((generate_tokens, next_token), dim=-1)
+             else:
+                 tokens = torch.cat((tokens, next_token), dim=-1)
+
+             # [Key fix] Update the mask: append a 1 so the position ids keep growing.
+             new_mask_bit = torch.ones((tokens.shape[0], 1), device=device, dtype=attention_mask.dtype)
+             attention_mask = torch.cat((attention_mask, new_mask_bit), dim=-1)
+
+             # Note: .item() assumes a batch size of 1; use batch_generate for batches.
+             if next_token.item() == TrainerTools().tokenizer.end:
+                 break
+
+     # (token, is_full_result)
+     yield (tokens if not use_kv_cache else generate_tokens), True
+
+
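+ # Hedged consumption sketch for the generator protocol above: _generate
+ # yields (next_token, False) once per decoding step and finally the full
+ # sequence as (tokens, True). Illustrative only; callers normally go through
+ # streaming_generate / generate below.
+ #
+ #     for token, is_full_result in _generate(model, tokens=ids, max_new_tokens=32,
+ #                                            temperature=0.8, k=40, p=0.9, device=device):
+ #         if is_full_result:
+ #             full_sequence = token
+
+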
+ def _streaming_generate(
+         model: torch.nn.Module,
+         *,
+         prompt: Union[str, torch.Tensor],
+         max_new_tokens: int,
+         temperature: Optional[float] = 1.0,
+         k: Optional[int] = None,
+         p: Optional[float] = None,
+         pixel_values: Optional[torch.Tensor] = None,
+         tokens_per_image: int = -1,
+         suppress_tokens: Optional[List[int]] = None,
+         device: Union[str, torch.device, int] = None
+ ):
+     device = TrainerTools().parallel.device if not device else device
+
+     if isinstance(prompt, torch.Tensor):
+         encoded_tokens = prompt.to(device)
+     else:
+         encoded_tokens = TrainerTools().tokenizer.encode(prompt, unsqueeze=True, covert_tensor=True).to(device)
+
+     generate_text_iterator = _generate(
+         model=model,
+         tokens=encoded_tokens,
+         max_new_tokens=max_new_tokens,
+         temperature=temperature,
+         k=k,
+         p=p,
+         pixel_values=pixel_values,
+         tokens_per_image=tokens_per_image,
+         suppress_tokens=suppress_tokens,
+         device=device
+     )
+
+     for (token, is_full_result) in generate_text_iterator:
+         yield token, is_full_result
+
+
+ def streaming_generate(
+         model: torch.nn.Module,
+         *,
+         prompt: Union[str, torch.Tensor],
+         max_new_tokens: int,
+         temperature: Optional[float] = 1.0,
+         k: Optional[int] = None,
+         p: Optional[float] = None,
+         pixel_values: Optional[torch.Tensor] = None,
+         tokens_per_image: int = -1,
+         suppress_tokens: Optional[List[int]] = None,
+         device: Union[str, torch.device, int] = None,
+         return_token: bool = False
+ ):
+     text_iterator = _streaming_generate(
+         model=model,
+         prompt=prompt,
+         max_new_tokens=max_new_tokens,
+         temperature=temperature,
+         k=k,
+         p=p,
+         pixel_values=pixel_values,
+         tokens_per_image=tokens_per_image,
+         suppress_tokens=suppress_tokens,
+         device=device
+     )
+
+     for (token, is_full_result) in text_iterator:
+         if not is_full_result:
+             if return_token:
+                 yield token.squeeze(0)
+             else:
+                 yield TrainerTools().tokenizer.decode(token.squeeze(0))
+
+
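+ # Hedged usage sketch for streaming_generate (assumes a loaded model; the
+ # prompt and sampling values are illustrative assumptions, not
+ # recommendations from this package):
+ #
+ #     for piece in streaming_generate(model, prompt="Hello", max_new_tokens=64,
+ #                                     temperature=0.8, k=40, p=0.9):
+ #         print(piece, end="", flush=True)
+
+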
+ def generate(
+         model: torch.nn.Module,
+         *,
+         prompt: Union[str, torch.Tensor],
+         max_new_tokens: int,
+         temperature: Optional[float] = 1.0,
+         k: Optional[int] = None,
+         p: Optional[float] = None,
+         pixel_values: Optional[torch.Tensor] = None,
+         tokens_per_image: int = -1,
+         suppress_tokens: Optional[List[int]] = None,
+         device: Union[str, torch.device, int] = None,
+         return_token: bool = False
+ ):
+     text_iterator = _streaming_generate(
+         model=model,
+         prompt=prompt,
+         max_new_tokens=max_new_tokens,
+         temperature=temperature,
+         k=k,
+         p=p,
+         suppress_tokens=suppress_tokens,
+         pixel_values=pixel_values,
+         tokens_per_image=tokens_per_image,
+         device=device
+     )
+
+     for (token, is_full_result) in text_iterator:
+         if is_full_result:
+             if return_token:
+                 return token.squeeze(0)
+             else:
+                 return TrainerTools().tokenizer.decode(token.squeeze(0))
+
+     return None
+
+
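+ # Non-streaming counterpart (same assumptions as the sketch above):
+ #
+ #     text = generate(model, prompt="Hello", max_new_tokens=64, temperature=0.8)
+
+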
+ def batch_generate(
+         model: torch.nn.Module,
+         *,
+         tokens: torch.Tensor,
+         attention_mask: torch.Tensor,
+         max_new_tokens: int,
+         temperature: Optional[float] = None,
+         k: Optional[int] = None,
+         p: Optional[float] = None,
+         pixel_values: Optional[torch.Tensor] = None,
+         tokens_per_image: int = -1,
+         suppress_tokens: Optional[List[int]] = None,
+         device: Union[str, torch.device, int],
+         return_logits: bool = True
+ ):
+     use_kv_cache = True
+     end_token = TrainerTools().tokenizer.end
+     pad_token_id = TrainerTools().tokenizer.pad
+
+     if isinstance(model, VlmModel):
+         tokens = batch_repeat_image_tok(tokens, tokens_per_image)
+
+     if pixel_values is not None:
+         pixel_values = pixel_values.to(device)
+
+     orig_tokens = tokens.clone()
+     full_attention_mask = attention_mask.clone()
+
+     # Initialize position_ids, accounting for left padding
+     position_ids = calc_position_ids(full_attention_mask)
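+     # Assumption about .utils.calc_position_ids (defined outside this file):
+     # for a left-padded mask such as [[0, 0, 1, 1]] it is expected to start
+     # the real tokens at position 0, e.g. position ids [[0, 0, 0, 1]].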
+
+     kv_cache: Optional[KVCache] = None
+     batch_size = tokens.shape[0]
+
+     # Preallocate the maximum length to avoid the memory fragmentation caused
+     # by repeated cat calls inside the loop
+     generated_tokens_buffer = torch.full(
+         (batch_size, max_new_tokens),
+         pad_token_id,
+         dtype=torch.long,
+         device=device
+     )
+
+     done = torch.zeros(batch_size, dtype=torch.bool, device=device)
+     current_tokens = tokens
+
+     padded_logits = None
+     actual_gen_len = 0
+
+     pad_token_tensor = torch.tensor(pad_token_id, device=device, dtype=torch.long)
+
+     with torch.inference_mode():
+         for i in range(max_new_tokens):
+             if done.all():
+                 break
+
+             actual_gen_len = i + 1
+
+             if current_tokens.dtype != torch.long:
+                 current_tokens = current_tokens.long()
+
+             if kv_cache is None:
+                 current_position_ids = position_ids
+             else:
+                 # With a valid kv_cache the current token is the one generated
+                 # in the previous step, so its position id is the previous
+                 # position plus one.
+                 current_position_ids = position_ids[:, -1:] + 1
+                 position_ids = torch.cat((position_ids, current_position_ids), dim=-1)
+
+             with autocast(TrainerTools().parallel.device_type):
+                 result = model(
+                     current_tokens,
+                     attention_mask=full_attention_mask,
+                     position_ids=current_position_ids,
+                     past_key_values=kv_cache,
+                     use_cache=use_kv_cache,
+                     pixel_values=pixel_values
+                 )
+             logits = result['logits']
+             kv_cache = result['past_key_values']
+
+             logits = logits[:, -1, :]
+
+             if return_logits:
+                 if padded_logits is None:
+                     vocab_size = logits.shape[-1]
+                     padded_logits = torch.zeros(
+                         (batch_size, max_new_tokens, vocab_size),
+                         dtype=logits.dtype,
+                         device=device
+                     )
+                 # raw logits are recorded before any warper is applied
+                 padded_logits[:, i, :] = logits
+
+             if suppress_tokens:
+                 logits = _suppress_warper(logits, suppress_tokens)
+
+             multinomial = False
+             if temperature and temperature > 0:
+                 multinomial = True
+                 logits = _temperature_warper(logits, temperature)
+             if k and k != 0:
+                 logits = _top_k_warper(logits, k, device)
+             if p and 0 < p <= 1:
+                 logits = _top_p_warper(logits, p)
+
+             if multinomial:
+                 prob = logits.softmax(dim=-1)
+                 next_token_active = torch.multinomial(prob, num_samples=1)
+             else:
+                 next_token_active = logits.argmax(dim=-1, keepdim=True)
+
+             # finished sequences keep emitting the pad token
+             next_token = torch.where(
+                 done.unsqueeze(1),
+                 pad_token_tensor,
+                 next_token_active
+             )
+
+             generated_tokens_buffer[:, i] = next_token.squeeze(-1)
+
+             new_done = (next_token.squeeze(-1) == end_token)
+             done = done | new_done
+
+             current_tokens = next_token
+
+             # append 0 for finished sequences so their padding stays masked out
+             new_mask = (~done).long().to(full_attention_mask.dtype)
+             full_attention_mask = torch.cat((full_attention_mask, new_mask.unsqueeze(-1)), dim=-1)
+
+     final_generated_tokens = generated_tokens_buffer[:, :actual_gen_len]
+
+     if padded_logits is not None:
+         final_padded_logits = padded_logits[:, :actual_gen_len, :]
+     else:
+         final_padded_logits = None
+
+     final_full_sequences = torch.cat((orig_tokens, final_generated_tokens), dim=1)
+
+     return final_full_sequences, final_padded_logits
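+
+
+ # Hedged usage sketch for batch_generate. batch_ids / batch_mask are assumed
+ # to be a left-padded [batch, seq] id tensor and its 0/1 attention mask;
+ # parameter values are illustrative:
+ #
+ #     sequences, logits = batch_generate(
+ #         model,
+ #         tokens=batch_ids,
+ #         attention_mask=batch_mask,
+ #         max_new_tokens=32,
+ #         temperature=0.8,
+ #         device=device,
+ #     )
+ #     # sequences: [batch, seq + gen_len]; logits: [batch, gen_len, vocab]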