project_llm_trainer-0.13.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of project-llm-trainer might be problematic.

@@ -0,0 +1,463 @@
+ from typing import Union, Optional, List
+ import torch
+ from llm_model import VlmModel, KVCache
+ from .tools import TrainerTools
+ from .utils import (
+     autocast,
+     batch_repeat_image_tok,
+     calc_position_ids
+ )
+
+
+ def _suppress_warper(logits: torch.Tensor, suppress_tokens: List[int]) -> torch.Tensor:
+     """
+     Suppress the given special tokens in the output distribution.
+     :param logits:
+     :param suppress_tokens:
+     :return:
+     """
+     suppress_tokens = torch.tensor(suppress_tokens, device=logits.device)
+     vocab_tensor = torch.arange(logits.shape[-1], device=logits.device)
+     suppress_token_mask = torch.isin(vocab_tensor, suppress_tokens)
+     logits = torch.where(suppress_token_mask, -float("inf"), logits)
+
+     return logits
+
+
+ def _temperature_warper(logits: torch.Tensor, temperature: float) -> torch.Tensor:
+     """
+     Apply temperature scaling.
+     :param logits:
+     :param temperature:
+     :return:
+     """
+     logits = logits / temperature
+     return logits
+
+
+ def _top_k_warper(logits: torch.Tensor, k: int, device: Union[str, torch.device, int] = None) -> torch.Tensor:
+     """
+     Top-k sampling.
+     :param logits:
+     :param k:
+     :param device:
+     :return:
+     """
+     # [batch, k]
+     topk_logits, _ = torch.topk(logits, k=k)
+     # [batch]
+     min_val: torch.Tensor = topk_logits[:, -1]
+     logits = torch.where(logits < min_val.unsqueeze(-1), torch.tensor(-torch.inf).to(device), logits)
+     return logits
+
+
+ def _top_p_warper(logits: torch.Tensor, p: float, min_tokens_to_keep: int = 1) -> torch.Tensor:
+     """
+     Top-p (nucleus) sampling.
+     :param logits:
+     :param p:
+     :param min_tokens_to_keep:
+     :return:
+     """
+     # Sort in ascending order, e.g. [0.1, 0.2, 0.3]
+     sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=False)
+     # cumsum: each element is the running sum of all elements up to and including it.
+     # E.g. torch.cumsum(torch.tensor([[0.1, 0.2, 0.3]]), dim=-1) gives [0.1, 0.3, 0.6].
+     cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
+     # Remove the part whose cumulative probability is <= 1 - p; because cumulative_probs is ascending, the threshold is 1 - p.
+     # Example:
+     # assume p = 0.9 and, after sorting and the cumulative sum, we have
+     # cumulative_probs = [0.1, 0.3, 0.7, 0.92, 0.98]
+     # then (1 - p) is 0.1.
+     # After cumulative_probs <= (1 - p), sorted_indices_to_remove is
+     # sorted_indices_to_remove = [True, False, False, False, False]
+     # which means tokens whose cumulative probability is <= 0.1 (here, the first one) are removed.
+     # Why (1 - p)?
+     # Using (1 - p) keeps the later steps simple. Implementations that sort in descending order
+     # usually shift sorted_indices_to_remove one position to the right and set the first element to False,
+     # so that at least one token is kept even when its probability is very small.
+     # With ascending order and (1 - p), cumulative_probs can be compared directly, without that extra step for the first token.
+     sorted_indices_to_remove = cumulative_probs <= (1 - p)
+     # Guarantee that at least min_tokens_to_keep tokens are kept.
+     # Example:
+     # with sorted_indices_to_remove = [True, True, True] and min_tokens_to_keep = 1,
+     # this operation turns sorted_indices_to_remove into [True, True, False].
+     sorted_indices_to_remove[..., -min_tokens_to_keep:] = 0
+     # The next step maps the removal decisions from sorted positions back to the original (unsorted) vocabulary indices.
+     # scatter copies values from src into the tensor at the positions given by index.
+     # For example, with a batch of one, a vocabulary of size 5 and
+     # sorted_indices = [[2, 0, 4, 1, 3]] (indices after sorting)
+     # sorted_indices_to_remove = [[False, True, False, True, False]] (removal mask in sorted order)
+     # sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove) gives
+     # indices_to_remove = [[True, True, False, False, False]]
+     indices_to_remove = sorted_indices_to_remove.scatter(1, index=sorted_indices, src=sorted_indices_to_remove)
+
+     # Set the entries to be removed to -inf.
+     scores_processed = logits.masked_fill_(indices_to_remove, -float("Inf"))
+
+     return scores_processed
+
+
+ def _generate(
+         model: torch.nn.Module,
+         *,
+         tokens: torch.Tensor,
+         max_new_tokens: int,
+         temperature: Optional[float],
+         k: Optional[int],
+         p: Optional[float],
+         pixel_values: Optional[torch.Tensor] = None,
+         tokens_per_image: int = -1,
+         suppress_tokens: Optional[List[int]] = None,
+         device: Union[str, torch.device, int]
+ ):
+     """
+     :param model:
+     :param tokens:
+     :param max_new_tokens:
+     :param temperature: set to None to disable temperature scaling
+     :param k: top-k parameter; set to None or 0 to disable top-k
+     :param p: top-p parameter; set to None to disable top-p
+     :param suppress_tokens: tokens to suppress
+     :param device:
+
+     If the output quality is low, decrease temperature, k and p.
+     If temperature is high but the output lacks variety, increase k and p.
+     """
+     use_kv_cache = True
+
+     # Make sure the input shape is [batch, seq].
+     if tokens.dim() == 1:
+         tokens = tokens.unsqueeze(0)
+
+     if isinstance(model, VlmModel):
+         tokens, _ = batch_repeat_image_tok(tokens, tokens_per_image)
+
+     attention_mask = torch.ones_like(tokens, device=device, dtype=torch.long)
+
+     kv_cache: Optional[KVCache] = None
+     if use_kv_cache:
+         # Prompt Length + Max Generation Length
+         total_capacity = tokens.shape[1] + max_new_tokens
+         kv_cache = KVCache(max_capacity=total_capacity)
+
+     if pixel_values is not None:
+         pixel_values = pixel_values.to(device)
+
+     generate_tokens = tokens.clone()
+
+     with torch.inference_mode():
+         for _ in range(max_new_tokens):
+             t = tokens
+             with autocast(TrainerTools().parallel.device_type):
+                 result = model(
+                     t,
+                     attention_mask=attention_mask,
+                     past_key_values=kv_cache,
+                     use_cache=use_kv_cache,
+                     pixel_values=pixel_values
+                 )
+
+             # logits (batch, seq_len, vocab_size)
+             logits = result['logits']
+
+             # (batch, vocab_size)
+             logits = logits[:, -1, :]
+
+             # Suppress special tokens.
+             if suppress_tokens and len(suppress_tokens) != 0:
+                 logits = _suppress_warper(logits, suppress_tokens)
+
+             multinomial = False
+             if temperature and temperature > 0:
+                 multinomial = True
+                 logits = _temperature_warper(logits, temperature)
+
+             if k and k != 0:
+                 logits = _top_k_warper(logits, k, device)
+
+             if p and 0 < p <= 1:
+                 logits = _top_p_warper(logits, p)
+
+             if multinomial:
+                 prob = logits.softmax(dim=-1)
+                 # Sampled token index.
+                 next_token = torch.multinomial(prob, num_samples=1)
+             else:
+                 # Greedy token index.
+                 next_token = logits.argmax(dim=-1, keepdim=True)
+
+             # token, is_full_result
+             yield next_token, False
+
+             if use_kv_cache:
+                 tokens = next_token
+                 generate_tokens = torch.cat((generate_tokens, next_token), dim=-1)
+             else:
+                 tokens = torch.cat((tokens, next_token), dim=-1)
+
+             # Update the mask: append a 1.
+             new_mask_bit = torch.ones((tokens.shape[0], 1), device=device, dtype=attention_mask.dtype)
+             attention_mask = torch.cat((attention_mask, new_mask_bit), dim=-1)
+
+             if next_token.item() == TrainerTools().tokenizer.end:
+                 break
+
+     # token, is_full_result
+     yield tokens if not use_kv_cache else generate_tokens, True
+
+
+ def _streaming_generate(
+         model: torch.nn.Module,
+         *,
+         prompt: Union[str, torch.Tensor],
+         max_new_tokens: int,
+         temperature: Optional[float] = 1.0,
+         k: Optional[int] = None,
+         p: Optional[float] = None,
+         pixel_values: Optional[torch.Tensor] = None,
+         tokens_per_image: int = -1,
+         suppress_tokens: Optional[List[int]] = None,
+         device: Union[str, torch.device, int] = None
+ ):
+     device = device if device is not None else TrainerTools().parallel.device
+
+     if isinstance(prompt, torch.Tensor):
+         encoded_tokens = prompt.to(device)
+     else:
+         encoded_tokens = TrainerTools().tokenizer.encode(prompt, unsqueeze=True, covert_tensor=True).to(device)
+
+     generate_text_iterator = _generate(
+         model=model,
+         tokens=encoded_tokens,
+         max_new_tokens=max_new_tokens,
+         temperature=temperature,
+         k=k,
+         p=p,
+         pixel_values=pixel_values,
+         tokens_per_image=tokens_per_image,
+         suppress_tokens=suppress_tokens,
+         device=device
+     )
+
+     for (token, is_full_result) in generate_text_iterator:
+         yield token, is_full_result
+
+
+ def streaming_generate(
+         model: torch.nn.Module,
+         *,
+         prompt: Union[str, torch.Tensor],
+         max_new_tokens: int,
+         temperature: Optional[float] = 1.0,
+         k: Optional[int] = None,
+         p: Optional[float] = None,
+         pixel_values: Optional[torch.Tensor] = None,
+         tokens_per_image: int = -1,
+         suppress_tokens: Optional[List[int]] = None,
+         device: Union[str, torch.device, int] = None,
+         return_token: bool = False
+ ):
+     text_iterator = _streaming_generate(
+         model=model,
+         prompt=prompt,
+         max_new_tokens=max_new_tokens,
+         temperature=temperature,
+         k=k,
+         p=p,
+         pixel_values=pixel_values,
+         tokens_per_image=tokens_per_image,
+         suppress_tokens=suppress_tokens,
+         device=device
+     )
+
+     for (token, is_full_result) in text_iterator:
+         if not is_full_result:
+             if return_token:
+                 yield token.squeeze(0)
+             else:
+                 yield TrainerTools().tokenizer.decode(token.squeeze(0))
+
+
+ def generate(
+         model: torch.nn.Module,
+         *,
+         prompt: Union[str, torch.Tensor],
+         max_new_tokens: int,
+         temperature: Optional[float] = 1.0,
+         k: Optional[int] = None,
+         p: Optional[float] = None,
+         pixel_values: Optional[torch.Tensor] = None,
+         tokens_per_image: int = -1,
+         suppress_tokens: Optional[List[int]] = None,
+         device: Union[str, torch.device, int] = None,
+         return_token: bool = False
+ ):
+     text_iterator = _streaming_generate(
+         model=model,
+         prompt=prompt,
+         max_new_tokens=max_new_tokens,
+         temperature=temperature,
+         k=k,
+         p=p,
+         suppress_tokens=suppress_tokens,
+         pixel_values=pixel_values,
+         tokens_per_image=tokens_per_image,
+         device=device
+     )
+
+     for (token, is_full_result) in text_iterator:
+         if is_full_result:
+             if return_token:
+                 return token.squeeze(0)
+             else:
+                 return TrainerTools().tokenizer.decode(token.squeeze(0))
+
+     return None
+
+
+ def batch_generate(
+         model: torch.nn.Module,
+         *,
+         tokens: torch.Tensor,
+         attention_mask: torch.Tensor,
+         max_new_tokens: int,
+         temperature: Optional[float] = None,
+         k: Optional[int] = None,
+         p: Optional[float] = None,
+         pixel_values: Optional[torch.Tensor] = None,
+         tokens_per_image: int = -1,
+         suppress_tokens: Optional[List[int]] = None,
+         device: Union[str, torch.device, int],
+         return_logits: bool = True
+ ):
+     use_kv_cache = True
+     end_token = TrainerTools().tokenizer.end
+     pad_token_id = TrainerTools().tokenizer.pad
+
+     if isinstance(model, VlmModel):
+         tokens, attention_mask = batch_repeat_image_tok(tokens, tokens_per_image, attention_mask)
+
+     if pixel_values is not None:
+         pixel_values = pixel_values.to(device)
+
+     orig_tokens = tokens.clone()
+     full_attention_mask = attention_mask.clone()
+
+     # Initialize position_ids, accounting for left padding.
+     position_ids = calc_position_ids(full_attention_mask)
+
+     kv_cache: Optional[KVCache] = None
+     batch_size = tokens.shape[0]
+
+     if use_kv_cache:
+         # Prompt Length + Max Generation Length
+         total_capacity = tokens.shape[1] + max_new_tokens
+         kv_cache = KVCache(max_capacity=total_capacity)
+
+     # Pre-allocate the maximum length to avoid memory fragmentation from repeated cat calls in the loop.
+     generated_tokens_buffer = torch.full(
+         (batch_size, max_new_tokens),
+         pad_token_id,
+         dtype=torch.long,
+         device=device
+     )
+
+     done = torch.zeros(batch_size, dtype=torch.bool, device=device)
+     current_tokens = tokens
+
+     padded_logits = None
+     actual_gen_len = 0
+
+     pad_token_tensor = torch.tensor(pad_token_id, device=device, dtype=torch.long)
+
+     with torch.inference_mode():
+         for i in range(max_new_tokens):
+             if done.all():
+                 break
+
+             actual_gen_len = i + 1
+
+             if current_tokens.dtype != torch.long:
+                 current_tokens = current_tokens.long()
+
+             if kv_cache is None:
+                 current_position_ids = position_ids
+             else:
+                 # The next position id is based on the last valid position in the current masked sequence.
+                 # With a valid kv_cache, the current token is the one generated in the previous step, so its position is the previous position + 1.
+                 # Note: in the first iteration (prefill) the kv_cache is still empty, but the full tokens are passed in,
+                 # so prefill needs no special handling of position_ids; the full position_ids are passed directly.
+                 if i == 0:
+                     current_position_ids = position_ids
+                 else:
+                     current_position_ids = position_ids[:, -1:] + 1
+                     position_ids = torch.cat((position_ids, current_position_ids), dim=-1)
+
+             with autocast(TrainerTools().parallel.device_type):
+                 result = model(
+                     current_tokens,
+                     attention_mask=full_attention_mask,
+                     position_ids=current_position_ids,
+                     past_key_values=kv_cache,
+                     use_cache=use_kv_cache,
+                     pixel_values=pixel_values
+                 )
+             logits = result['logits']
+
+             logits = logits[:, -1, :]
+
+             if return_logits:
+                 if padded_logits is None:
+                     vocab_size = logits.shape[-1]
+                     padded_logits = torch.zeros(
+                         (batch_size, max_new_tokens, vocab_size),
+                         dtype=logits.dtype,
+                         device=device
+                     )
+                 padded_logits[:, i, :] = logits
+
+             if suppress_tokens:
+                 logits = _suppress_warper(logits, suppress_tokens)
+
+             multinomial = False
+             if temperature and temperature > 0:
+                 multinomial = True
+                 logits = _temperature_warper(logits, temperature)
+             if k and k != 0:
+                 logits = _top_k_warper(logits, k, device)
+             if p and 0 < p <= 1:
+                 logits = _top_p_warper(logits, p)
+
+             if multinomial:
+                 prob = logits.softmax(dim=-1)
+                 next_token_active = torch.multinomial(prob, num_samples=1)
+             else:
+                 next_token_active = logits.argmax(dim=-1, keepdim=True)
+
+             next_token = torch.where(
+                 done.unsqueeze(1),
+                 pad_token_tensor,
+                 next_token_active
+             )
+
+             generated_tokens_buffer[:, i] = next_token.squeeze(-1)
+
+             new_done = (next_token.squeeze(-1) == end_token)
+             done = done | new_done
+
+             current_tokens = next_token
+
+             new_mask = (~done).long().to(full_attention_mask.dtype)
+             full_attention_mask = torch.cat((full_attention_mask, new_mask.unsqueeze(-1)), dim=-1)
+
+     final_generated_tokens = generated_tokens_buffer[:, :actual_gen_len]
+
+     if padded_logits is not None:
+         final_padded_logits = padded_logits[:, :actual_gen_len, :]
+     else:
+         final_padded_logits = None
+
+     final_full_sequences = torch.cat((orig_tokens, final_generated_tokens), dim=1)
+
+     return final_full_sequences, final_padded_logits
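
For orientation, below is a minimal usage sketch of the public helpers defined in this file (streaming_generate and generate). It is not part of the released package: the import path, the model object and the device string are illustrative placeholders, and it assumes a model trained with this project plus an already-configured TrainerTools() tokenizer; only the keyword arguments mirror the signatures shown in the diff above.

    import torch

    # Hypothetical import path; the diff does not show the module's file name.
    from project_llm_trainer import generate, streaming_generate

    model: torch.nn.Module = ...  # a trained llm_model network returning {'logits': ...}; loading is project-specific

    # Stream the completion piece by piece as tokens are sampled.
    for piece in streaming_generate(
        model,
        prompt="Hello",
        max_new_tokens=64,
        temperature=0.8,   # None disables temperature scaling
        k=50,              # None or 0 disables top-k
        p=0.9,             # None (or a value outside (0, 1]) disables top-p
        device="cuda",
    ):
        print(piece, end="", flush=True)

    # Or collect the whole decoded completion in one call;
    # return_token=True would return token ids instead of decoded text.
    text = generate(model, prompt="Hello", max_new_tokens=64, temperature=0.8, k=50, p=0.9, device="cuda")

batch_generate follows the same pattern but takes an already-tokenized, left-padded batch together with its attention_mask and returns the full sequences plus (optionally) the per-step logits.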