project-llm-trainer 0.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

@@ -0,0 +1,424 @@
+ from typing import Union, Optional, List
+ from contextlib import nullcontext
+ import torch
+ from torch.distributed.fsdp import FullyShardedDataParallel as FSDP
+ from llm_model import VlmModel, KVCache
+ from .tools import TrainerTools
+ from .utils import batch_repeat_image_tok
+
+
+ def _suppress_warper(logits: torch.Tensor, suppress_tokens: List[int]) -> torch.Tensor:
+     """
+     Suppress special tokens in the output by masking their logits.
+     :param logits: next-token logits of shape (batch, vocab_size)
+     :param suppress_tokens: token ids that must never be sampled
+     :return: logits with the suppressed positions set to -inf
+     """
+     suppress_tokens = torch.tensor(suppress_tokens, device=logits.device)
+     vocab_tensor = torch.arange(logits.shape[-1], device=logits.device)
+     suppress_token_mask = torch.isin(vocab_tensor, suppress_tokens)
+     logits = torch.where(suppress_token_mask, -float("inf"), logits)
+
+     return logits
+
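+ # Worked example (illustrative sketch): with a toy vocabulary of size 5,
+ # suppressing ids 1 and 3 pins their logits to -inf, so softmax assigns them
+ # zero probability:
+ #
+ #     logits = torch.zeros(1, 5)
+ #     out = _suppress_warper(logits, suppress_tokens=[1, 3])
+ #     out.softmax(dim=-1)  # tensor([[0.3333, 0.0000, 0.3333, 0.0000, 0.3333]])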
+
+
+ def _temperature_warper(logits: torch.Tensor, temperature: float) -> torch.Tensor:
+     """
+     Apply temperature scaling.
+     :param logits: next-token logits of shape (batch, vocab_size)
+     :param temperature: values < 1 sharpen the distribution, values > 1 flatten it
+     :return: scaled logits
+     """
+     logits = logits / temperature
+     return logits
+
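+ # Worked example (illustrative sketch): with a 3-token vocabulary, temperature
+ # 0.5 doubles the logit gaps (sharper softmax), while 2.0 halves them
+ # (flatter softmax):
+ #
+ #     logits = torch.tensor([[1.0, 2.0, 3.0]])
+ #     _temperature_warper(logits, 0.5)  # tensor([[2., 4., 6.]])
+ #     _temperature_warper(logits, 2.0)  # tensor([[0.5000, 1.0000, 1.5000]])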
+
+
+ def _top_k_warper(logits: torch.Tensor, k: int, device: Optional[Union[str, torch.device, int]] = None) -> torch.Tensor:
+     """
+     Top-k sampling: keep only the k highest-scoring tokens.
+     :param logits: next-token logits of shape (batch, vocab_size)
+     :param k: number of tokens to keep
+     :param device: device used for the -inf fill value
+     :return: logits with everything below the k-th largest value set to -inf
+     """
+     # (batch, k)
+     topk_logits, _ = torch.topk(logits, k=k)
+     # (batch,) -- the smallest logit that survives the cut
+     min_val: torch.Tensor = topk_logits[:, -1]
+     logits = torch.where(logits < min_val.unsqueeze(-1), torch.tensor(-torch.inf).to(device), logits)
+     return logits
+
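+ # Worked example (illustrative sketch): with k=2 only the two largest logits
+ # per row survive; everything below the k-th largest value becomes -inf:
+ #
+ #     logits = torch.tensor([[1.0, 3.0, 2.0, 0.5]])
+ #     _top_k_warper(logits, k=2, device="cpu")  # tensor([[-inf, 3., 2., -inf]])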
+
+
+ def _top_p_warper(logits: torch.Tensor, p: float, min_tokens_to_keep: int = 1) -> torch.Tensor:
+     """
+     Top-p (nucleus) sampling.
+     :param logits: next-token logits of shape (batch, vocab_size)
+     :param p: cumulative probability mass to keep
+     :param min_tokens_to_keep: lower bound on the number of surviving tokens
+     :return: logits with the low-probability tail set to -inf
+     """
+     # Sort in ascending order, e.g. [0.1, 0.2, 0.3]
+     sorted_logits, sorted_indices = torch.sort(logits, dim=-1, descending=False)
+     # cumsum: each element is the sum of itself and all preceding elements.
+     # Example: torch.cumsum(torch.tensor([[0.1, 0.2, 0.3]]), dim=-1) -> [0.1, 0.3, 0.6]
+     cumulative_probs = sorted_logits.softmax(dim=-1).cumsum(dim=-1)
+     # Drop the tokens whose cumulative probability is <= 1 - p. Because
+     # cumulative_probs is in ascending order, the threshold is 1 - p.
+     # Example: suppose p = 0.9 and, after sorting and cumsum, we have
+     #   cumulative_probs = [0.1, 0.3, 0.7, 0.92, 0.98]
+     # Then 1 - p is 0.1, and cumulative_probs <= (1 - p) yields
+     #   sorted_indices_to_remove = [True, False, False, False, False]
+     # i.e. the tokens whose cumulative probability is at most 0.1 (here just
+     # the first one) are removed.
+     # Why (1 - p)? It simplifies the downstream handling. The usual
+     # (descending-order) implementation shifts sorted_indices_to_remove right
+     # by one position and clears its first element, so that at least one token
+     # survives even when the most likely token's probability is tiny. By
+     # comparing against (1 - p) on the ascending cumsum, cumulative_probs can
+     # be used directly, without that extra step for the first token.
+     sorted_indices_to_remove = cumulative_probs <= (1 - p)
+     # Guarantee that at least min_tokens_to_keep tokens survive.
+     # Example: with sorted_indices_to_remove = [True, True, True] and
+     # min_tokens_to_keep = 1, this yields [True, True, False].
+     sorted_indices_to_remove[..., -min_tokens_to_keep:] = 0
+     # Map the removal mask from sorted positions back to vocabulary positions.
+     # scatter copies values from src into the tensor at the indices given by index.
+     # Example: for one batch row with vocabulary size 5 and
+     #   sorted_indices = [[2, 0, 4, 1, 3]]                    (sorted order)
+     #   sorted_indices_to_remove = [[False, True, False, True, False]]
+     # sorted_indices_to_remove.scatter(1, sorted_indices, sorted_indices_to_remove)
+     # produces
+     #   indices_to_remove = [[True, True, False, False, False]]
+     indices_to_remove = sorted_indices_to_remove.scatter(1, index=sorted_indices, src=sorted_indices_to_remove)
+
+     # Set the removed positions to -inf
+     scores_processed = logits.masked_fill_(indices_to_remove, -float("Inf"))
+
+     return scores_processed
+
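+ # Worked example (illustrative sketch; floats rounded): four tokens whose
+ # softmax probabilities are [0.1, 0.2, 0.3, 0.4], with p = 0.85. The ascending
+ # cumulative sums are [0.1, 0.3, 0.6, 1.0]; only 0.1 <= 1 - p = 0.15, so just
+ # the least likely token is masked and the rest (0.9 of the mass) survive:
+ #
+ #     logits = torch.tensor([[0.1, 0.2, 0.3, 0.4]]).log()
+ #     _top_p_warper(logits, p=0.85)
+ #     # tensor([[   -inf, -1.6094, -1.2040, -0.9163]])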
98
+
99
+ def _generate(
100
+ model: torch.nn.Module,
101
+ *,
102
+ tokens: torch.Tensor,
103
+ max_position_embeddings: int,
104
+ max_new_tokens: int,
105
+ temperature: Optional[float],
106
+ k: Optional[int],
107
+ p: Optional[float],
108
+ pixel_values: Optional[torch.Tensor] = None,
109
+ tokens_per_image: int = -1,
110
+ suppress_tokens: Optional[List[int]] = None,
111
+ device: Union[str, torch.device, int]
112
+ ):
113
+ """
114
+ :param model:
115
+ :param tokens:
116
+ :param max_position_embeddings:
117
+ :param max_new_tokens:
118
+ :param temperature: 设置None不不生效temperature
119
+ :param k: top k参数,设置为None或者0不生效topk
120
+ :param p: top p参数,设置为None不生效top p
121
+ :param suppress_tokens: 要抑制的tokens
122
+ :param device:
123
+
124
+ 如果内容质量底,需要减小temperature、k、p
125
+ 如果temperature很大但内容单一,需要增大k、p
126
+ """
127
+ use_kv_cache = True
128
+
129
+ ctx = torch.autocast(
130
+ device_type=device,
131
+ dtype=TrainerTools().dtype,
132
+ enabled=True,
133
+ # fsdp模式,需要将cache_enabled设置为false
134
+ cache_enabled=False if isinstance(model, FSDP) else None
135
+ ) if TrainerTools().use_amp else nullcontext()
136
+
137
+ if isinstance(model, VlmModel):
138
+ tokens = batch_repeat_image_tok(tokens, tokens_per_image)
139
+
140
+ if pixel_values is not None:
141
+ pixel_values = pixel_values.to(device)
142
+
143
+ kv_cache: Optional[KVCache] = None
144
+ generate_tokens = tokens.clone()
145
+
146
+ model.eval()
147
+ with torch.inference_mode():
148
+ for _ in range(max_new_tokens):
149
+ # 是否需要截取??
150
+ t = tokens[:, -max_position_embeddings:]
151
+ with ctx:
152
+ result = model(
153
+ t,
154
+ past_key_values=kv_cache,
155
+ use_cache=use_kv_cache,
156
+ pixel_values=pixel_values
157
+ )
158
+
159
+ # logits (batch, seq_len, vocab_size)
160
+ logits = result['logits']
161
+ kv_cache = result['past_key_values']
162
+
163
+ # (batch, vocab_size)
164
+ logits = logits[:, -1, :]
165
+ # 抑制特殊token输出
166
+ if suppress_tokens and len(suppress_tokens) != 0:
167
+ logits = _suppress_warper(logits, suppress_tokens)
168
+
169
+ multinomial = False
170
+ if temperature and temperature > 0:
171
+ multinomial = True
172
+ logits = _temperature_warper(logits, temperature)
173
+
174
+ if k and k != 0:
175
+ logits = _top_k_warper(logits, k, device)
176
+
177
+ if p and p < 1:
178
+ logits = _top_p_warper(logits, p)
179
+
180
+ if multinomial:
181
+ prob = logits.softmax(dim=-1)
182
+ # 返回下标
183
+ next_token = torch.multinomial(prob, num_samples=1)
184
+ else:
185
+ # 返回下标
186
+ next_token = logits.argmax(dim=-1, keepdim=True)
187
+
188
+ # token, is_full_result
189
+ yield next_token, False
190
+
191
+ if use_kv_cache:
192
+ tokens = next_token
193
+ generate_tokens = torch.cat((generate_tokens, next_token), dim=-1)
194
+ else:
195
+ tokens = torch.cat((tokens, next_token), dim=-1)
196
+
197
+ if next_token.item() == TrainerTools().tokenizer.end:
198
+ break
199
+
200
+ # token, is_full_result
201
+ yield tokens if not use_kv_cache else generate_tokens, True
202
+
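+ # Consumption sketch (illustrative; `model`, `encoded_tokens` and the sampling
+ # values are assumptions, not package defaults): the generator yields
+ # (token, False) for every step and finally (full_sequence, True).
+ #
+ #     for tok, is_full in _generate(model=model, tokens=encoded_tokens,
+ #                                   max_position_embeddings=1024,
+ #                                   max_new_tokens=64, temperature=0.8,
+ #                                   k=50, p=0.95, device='cuda'):
+ #         if is_full:
+ #             print(TrainerTools().tokenizer.decode(tok.squeeze(0)))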
+
+
+ def _streaming_generate(
+         model: torch.nn.Module,
+         *,
+         prompt: str,
+         max_position_embeddings: int,
+         max_new_tokens: int,
+         temperature: Optional[float] = 1.0,
+         k: Optional[int] = None,
+         p: Optional[float] = 1.0,
+         pixel_values: Optional[torch.Tensor] = None,
+         tokens_per_image: int = -1,
+         suppress_tokens: Optional[List[int]] = None,
+         device: Optional[Union[str, torch.device, int]] = None,
+ ):
+     # test against None: `if not device` would misfire for device index 0
+     device = TrainerTools().parallel.device if device is None else device
+     encoded_tokens = TrainerTools().tokenizer.encode(prompt, unsqueeze=True, covert_tensor=True).to(device)
+
+     generate_text_iterator = _generate(
+         model=model,
+         tokens=encoded_tokens,
+         max_position_embeddings=max_position_embeddings,
+         max_new_tokens=max_new_tokens,
+         temperature=temperature,
+         k=k,
+         p=p,
+         pixel_values=pixel_values,
+         tokens_per_image=tokens_per_image,
+         suppress_tokens=suppress_tokens,
+         device=device
+     )
+
+     for (token, is_full_result) in generate_text_iterator:
+         yield token, is_full_result
+
+
+ def streaming_generate(
+         model: torch.nn.Module,
+         *,
+         prompt: str,
+         max_position_embeddings: int,
+         max_new_tokens: int,
+         temperature: Optional[float] = 1.0,
+         k: Optional[int] = 50,
+         p: Optional[float] = 1.0,
+         pixel_values: Optional[torch.Tensor] = None,
+         tokens_per_image: int = -1,
+         suppress_tokens: Optional[List[int]] = None,
+         device: Optional[Union[str, torch.device, int]] = None,
+ ):
+     text_iterator = _streaming_generate(
+         model=model,
+         prompt=prompt,
+         max_position_embeddings=max_position_embeddings,
+         max_new_tokens=max_new_tokens,
+         temperature=temperature,
+         k=k,
+         p=p,
+         pixel_values=pixel_values,
+         tokens_per_image=tokens_per_image,
+         suppress_tokens=suppress_tokens,
+         device=device
+     )
+
+     for (token, is_full_result) in text_iterator:
+         if not is_full_result:
+             yield TrainerTools().tokenizer.decode(token.squeeze(0))
+
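+ # Usage sketch (illustrative; the prompt and sampling values are assumptions):
+ # stream decoded text pieces as they are sampled.
+ #
+ #     for text_piece in streaming_generate(model=model, prompt='hello',
+ #                                          max_position_embeddings=1024,
+ #                                          max_new_tokens=128,
+ #                                          temperature=0.8, k=50, p=0.95):
+ #         print(text_piece, end='', flush=True)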
+
+ def generate(
+         model: torch.nn.Module,
+         *,
+         prompt: str,
+         max_position_embeddings: int,
+         max_new_tokens: int,
+         temperature: Optional[float] = 1.0,
+         k: Optional[int] = None,
+         p: Optional[float] = 1.0,
+         pixel_values: Optional[torch.Tensor] = None,
+         tokens_per_image: int = -1,
+         suppress_tokens: Optional[List[int]] = None,
+         device: Optional[Union[str, torch.device, int]] = None,
+ ):
+     text_iterator = _streaming_generate(
+         model=model,
+         prompt=prompt,
+         max_position_embeddings=max_position_embeddings,
+         max_new_tokens=max_new_tokens,
+         temperature=temperature,
+         k=k,
+         p=p,
+         suppress_tokens=suppress_tokens,
+         pixel_values=pixel_values,
+         tokens_per_image=tokens_per_image,
+         device=device
+     )
+
+     for (token, is_full_result) in text_iterator:
+         if is_full_result:
+             return TrainerTools().tokenizer.decode(token.squeeze(0))
+
+
+ def batch_generate(
+         model: torch.nn.Module,
+         *,
+         tokens: torch.Tensor,
+         pad_token_id: torch.Tensor,
+         attention_mask: torch.Tensor,
+         max_position_embeddings: int,
+         max_new_tokens: int,
+         temperature: Optional[float],
+         k: Optional[int],
+         p: Optional[float],
+         pixel_values: Optional[torch.Tensor] = None,
+         tokens_per_image: int = -1,
+         suppress_tokens: Optional[List[int]] = None,
+         device: Union[str, torch.device, int]
+ ):
+     use_kv_cache = True
+
+     ctx = torch.autocast(
+         device_type=device,
+         dtype=TrainerTools().dtype,
+         enabled=True,
+         # In FSDP mode, cache_enabled must be set to False
+         cache_enabled=False if isinstance(model, FSDP) else None
+     ) if TrainerTools().use_amp else nullcontext()
+
+     if isinstance(model, VlmModel):
+         tokens = batch_repeat_image_tok(tokens, tokens_per_image)
+
+     if pixel_values is not None:
+         pixel_values = pixel_values.to(device)
+
+     kv_cache: Optional[KVCache] = None
+     generate_tokens = tokens.clone()
+     batch_size = tokens.shape[0]
+
+     # Initialize the per-sample completion flags
+     end_token = TrainerTools().tokenizer.end
+     done = torch.zeros(batch_size, dtype=torch.bool, device=device)
+
+     model.eval()
+     with torch.inference_mode():
+         for _ in range(max_new_tokens):
+             # Stop early once every sample has finished
+             if done.all():
+                 break
+
+             t = tokens  # tokens[:, -max_position_embeddings:]
+             with ctx:
+                 result = model(
+                     t,
+                     attention_mask=attention_mask,
+                     past_key_values=kv_cache,
+                     use_cache=use_kv_cache,
+                     pixel_values=pixel_values
+                 )
+
+             logits = result['logits']
+             kv_cache = result['past_key_values']
+
+             # Process the logits
+             logits = logits[:, -1, :]  # (batch, vocab_size)
+
+             if done.any():
+                 # Force a one-hot distribution so finished samples can only emit pad_token_id
+                 logits_done = torch.full_like(logits[done], -torch.finfo(logits.dtype).max)
+                 logits_done[:, pad_token_id] = 0  # logit 0 for pad_token_id
+                 logits[done] = logits_done
+
+             if suppress_tokens:
+                 logits = _suppress_warper(logits, suppress_tokens)
+
+             multinomial = False
+             if temperature and temperature > 0:
+                 multinomial = True
+                 logits = _temperature_warper(logits, temperature)
+
+             if k and k != 0:
+                 logits = _top_k_warper(logits, k, device)
+
+             if p and p < 1:
+                 logits = _top_p_warper(logits, p)
+
+             prob = logits.softmax(dim=-1)
+
+             # Detect and repair invalid probability distributions (sum <= 0)
+             prob_sum = prob.sum(dim=-1, keepdim=True)
+             invalid_mask = (prob_sum <= 0)
+             if invalid_mask.any():
+                 # Fall back to a uniform distribution so the sum is positive
+                 uniform_prob = torch.ones_like(prob) / prob.size(-1)
+                 prob = torch.where(invalid_mask.unsqueeze(-1), uniform_prob, prob)
+
+             # Pin finished samples (pad_token_id gets probability 1)
+             prob[done] = 0
+             prob[done, pad_token_id] = 1
+
+             # Numerical-stability guard
+             if torch.isnan(prob).any() or torch.isinf(prob).any():
+                 prob = torch.nan_to_num(prob, nan=0.0, posinf=1.0, neginf=0.0)
+
+             if multinomial:
+                 next_token = torch.multinomial(prob, num_samples=1)
+             else:
+                 next_token = logits.argmax(dim=-1, keepdim=True)
+
+             # Update the completion flags
+             done = done | (next_token.squeeze(-1) == end_token)
+
+             # Append the new tokens
+             if use_kv_cache:
+                 tokens = next_token
+                 generate_tokens = torch.cat((generate_tokens, next_token), dim=-1)
+             else:
+                 tokens = torch.cat((tokens, next_token), dim=-1)
+
+             new_mask = torch.ones_like(next_token, dtype=torch.bool)
+             attention_mask = torch.cat((attention_mask, new_mask), dim=-1)
+
+     # Return the full sequences
+     return tokens if not use_kv_cache else generate_tokens
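+
+ # Usage sketch (illustrative; the left-padded batch `padded_batch`, its
+ # `pad_id` and the sampling values are assumptions for the example):
+ #
+ #     out = batch_generate(model=model,
+ #                          tokens=padded_batch,  # (batch, seq_len)
+ #                          pad_token_id=pad_id,
+ #                          attention_mask=(padded_batch != pad_id),
+ #                          max_position_embeddings=1024,
+ #                          max_new_tokens=64,
+ #                          temperature=0.8, k=50, p=0.95,
+ #                          device='cuda')
+ #     # out: (batch, prompt_len + generated_len); rows that hit the end token
+ #     # early are right-padded with pad_token_id.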