gpu-worker 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +115 -0
- package/api_client.py +288 -0
- package/batch_processor.py +436 -0
- package/bin/gpu-worker.js +275 -0
- package/cli.py +729 -0
- package/config.2gb.yaml +32 -0
- package/config.8gb.yaml +29 -0
- package/config.example.yaml +72 -0
- package/config.py +213 -0
- package/direct_server.py +140 -0
- package/distributed/__init__.py +35 -0
- package/distributed/grpc_server.py +561 -0
- package/distributed/kv_cache.py +555 -0
- package/distributed/model_shard.py +465 -0
- package/distributed/session.py +455 -0
- package/engines/__init__.py +215 -0
- package/engines/base.py +57 -0
- package/engines/image_gen.py +83 -0
- package/engines/llm.py +97 -0
- package/engines/llm_base.py +216 -0
- package/engines/llm_sglang.py +489 -0
- package/engines/llm_vllm.py +539 -0
- package/engines/speculative.py +513 -0
- package/engines/vision.py +139 -0
- package/machine_id.py +200 -0
- package/main.py +521 -0
- package/package.json +64 -0
- package/requirements-sglang.txt +12 -0
- package/requirements-vllm.txt +15 -0
- package/requirements.txt +35 -0
- package/scripts/postinstall.js +60 -0
- package/setup.py +43 -0
|
@@ -0,0 +1,513 @@
|
|
|
1
|
+
"""
|
|
2
|
+
推测解码引擎
|
|
3
|
+
|
|
4
|
+
实现 EAGLE-3 风格的推测解码,提升单请求解码速度 2-3x
|
|
5
|
+
|
|
6
|
+
核心思想:
|
|
7
|
+
1. 使用轻量级 Draft 模型预测多个候选 token
|
|
8
|
+
2. 目标模型并行验证候选序列
|
|
9
|
+
3. 接受最长的正确前缀
|
|
10
|
+
|
|
11
|
+
参考:
|
|
12
|
+
- EAGLE-3: https://arxiv.org/abs/2503.01840
|
|
13
|
+
- Medusa: https://arxiv.org/abs/2401.10774
|
|
14
|
+
- SpecInfer: https://arxiv.org/abs/2305.09781
|
|
15
|
+
"""
|
|
16
|
+
import logging
|
|
17
|
+
import time
|
|
18
|
+
from typing import Dict, Any, List, Optional, Tuple
|
|
19
|
+
from dataclasses import dataclass
|
|
20
|
+
|
|
21
|
+
import torch
|
|
22
|
+
import torch.nn as nn
|
|
23
|
+
|
|
24
|
+
logger = logging.getLogger(__name__)
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
@dataclass
class SpeculativeConfig:
    """Configuration for speculative decoding.

    Attributes:
        draft_model_id: Identifier of a standalone draft model, or ``None``.
        use_self_draft: Use the target model itself as the drafter
            (EAGLE style).
        draft_head_hidden_size: Hidden-layer width of the draft head.
        num_speculative_tokens: Number of tokens speculated per step.
        tree_width: Tree width, i.e. candidates kept per position.
        tree_depth: Tree depth.
        temperature: Temperature used during verification.
        top_p: ``top_p`` used during verification.
        min_accept_rate: Minimum acceptance-rate threshold.
        adaptive_depth: Adapt the speculation depth automatically.
    """

    # --- draft model ---
    draft_model_id: Optional[str] = None
    use_self_draft: bool = True
    draft_head_hidden_size: int = 1024

    # --- speculation ---
    num_speculative_tokens: int = 5
    tree_width: int = 3
    tree_depth: int = 5

    # --- verification ---
    temperature: float = 0.0
    top_p: float = 1.0

    # --- performance ---
    min_accept_rate: float = 0.3
    adaptive_depth: bool = True
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
@dataclass
class SpeculativeOutput:
    """Result record of a speculative decoding step.

    Attributes:
        tokens: The accepted tokens.
        accept_rate: Acceptance rate.
        draft_tokens: Number of draft tokens generated.
        accepted_tokens: Number of tokens accepted.
        latency_ms: Latency in milliseconds.
    """

    tokens: List[int]
    accept_rate: float
    draft_tokens: int
    accepted_tokens: int
    latency_ms: float
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class DraftHead(nn.Module):
    """EAGLE-style draft head.

    Performs the autoregressive prediction at the feature (hidden-state)
    level rather than at the token level.
    """

    def __init__(
        self,
        hidden_size: int,
        vocab_size: int,
        num_layers: int = 2,
        hidden_dim: int = 1024,
    ):
        """Build the feature-prediction MLP.

        Args:
            hidden_size: Width of the hidden states (input and output).
            vocab_size: Vocabulary size; stored on the instance.
            num_layers: Number of linear layers in the predictor.
            hidden_dim: Width of the intermediate layers.
        """
        super().__init__()

        self.hidden_size = hidden_size
        self.vocab_size = vocab_size

        # The predictor consumes [hidden state ; token embedding], hence the
        # doubled input width. Every layer but the last is followed by SiLU;
        # the last one maps back to ``hidden_size`` and is left linear.
        stack = []
        in_features = 2 * hidden_size
        final_idx = num_layers - 1
        for idx in range(num_layers):
            is_final = idx >= final_idx
            out_features = hidden_size if is_final else hidden_dim
            stack.append(nn.Linear(in_features, out_features))
            stack.append(nn.Identity() if is_final else nn.SiLU())
            in_features = out_features

        self.feature_predictor = nn.Sequential(*stack)

        # Token embedding; meant to be shared with the target model and
        # injected later through set_token_embedding().
        self.token_embedding = None

    def set_token_embedding(self, embedding: nn.Embedding) -> None:
        """Attach the token embedding (usually the target model's own)."""
        self.token_embedding = embedding

    def forward(
        self,
        hidden_states: torch.Tensor,
        token_ids: torch.Tensor,
    ) -> torch.Tensor:
        """Predict the next-step hidden states.

        Args:
            hidden_states: [batch, seq_len, hidden_size]
            token_ids: [batch, seq_len]

        Returns:
            predicted_hidden: [batch, seq_len, hidden_size]

        Raises:
            RuntimeError: If no token embedding has been attached.
        """
        embedding = self.token_embedding
        if embedding is None:
            raise RuntimeError("Token embedding not set")

        # Fuse features and token embeddings along the channel axis, then
        # run the predictor over the fused representation.
        fused = torch.cat((hidden_states, embedding(token_ids)), dim=-1)
        return self.feature_predictor(fused)
|
|
126
|
+
|
|
127
|
+
|
|
128
|
+
class TreeDraftBuffer:
    """Buffer holding the candidate token tree used by speculative decoding.

    Nodes are stored in insertion order as ``(token_id, log_prob,
    parent_idx)`` tuples, and ``layer_offsets[d]`` is the index of the first
    node added by the ``d``-th :meth:`add_candidates` call.

    A node counts as a root when its parent index does not point at an
    *earlier* node (``-1`` by convention). Because nodes are only ever
    appended, a well-formed parent always has a smaller index than its child;
    enforcing that makes every ancestor walk terminate, even if a caller
    stores a self-referencing or out-of-range parent index.
    """

    def __init__(
        self,
        tree_width: int = 3,
        tree_depth: int = 5,
        device: str = "cuda",
    ):
        """Create an empty buffer.

        Args:
            tree_width: Tree width (stored; not enforced by the buffer).
            tree_depth: Tree depth (stored; not enforced by the buffer).
            device: Device on which returned tensors are allocated.
        """
        self.tree_width = tree_width
        self.tree_depth = tree_depth
        self.device = device

        # Tree nodes: (token_id, log_prob, parent_idx)
        self.nodes: List[Tuple[int, float, int]] = []

        # Index of the first node of every layer.
        self.layer_offsets: List[int] = []

    def reset(self) -> None:
        """Drop all nodes and layer offsets."""
        self.nodes.clear()
        self.layer_offsets.clear()

    def add_candidates(
        self,
        token_ids: torch.Tensor,
        log_probs: torch.Tensor,
        parent_indices: torch.Tensor,
    ) -> None:
        """Append one layer of candidate tokens.

        Args:
            token_ids: [num_candidates]
            log_probs: [num_candidates]
            parent_indices: [num_candidates] parent node index of each
                candidate.
        """
        self.layer_offsets.append(len(self.nodes))

        # reshape(-1) keeps a single candidate (0-d tensor) iterable.
        self.nodes.extend(
            zip(
                token_ids.reshape(-1).cpu().tolist(),
                log_probs.reshape(-1).cpu().tolist(),
                parent_indices.reshape(-1).cpu().tolist(),
            )
        )

    def get_tree_tokens(self) -> torch.Tensor:
        """Return every token in the tree, in node order (for verification)."""
        tokens = [node[0] for node in self.nodes]
        # Explicit dtype: an empty list would otherwise yield a float tensor.
        return torch.tensor(tokens, dtype=torch.long, device=self.device)

    def _ancestor_chain(self, index: int) -> List[int]:
        """Return node indices from ``index`` up to its root, inclusive."""
        chain = [index]
        parent = self.nodes[index][2]
        # A valid parent is strictly earlier than its child, so the indices
        # shrink on every hop and the walk cannot loop.
        while 0 <= parent < chain[-1]:
            chain.append(parent)
            parent = self.nodes[parent][2]
        return chain

    def get_tree_attention_mask(self, seq_len: int) -> torch.Tensor:
        """Build the tree-shaped attention mask.

        Every node may attend to all ``seq_len`` prefix positions, to itself
        and to its ancestors.

        Args:
            seq_len: Number of prefix positions preceding the tree.

        Returns:
            mask: [num_nodes, seq_len + num_nodes], 1 where attention is
            allowed and 0 elsewhere.
        """
        num_nodes = len(self.nodes)
        mask = torch.zeros(num_nodes, seq_len + num_nodes, device=self.device)

        # The whole prefix is visible to every node.
        mask[:, :seq_len] = 1

        # Collect (row, col) pairs on the host and write them in one indexed
        # assignment instead of one tensor write per ancestor.
        rows: List[int] = []
        cols: List[int] = []
        for row in range(num_nodes):
            for ancestor in self._ancestor_chain(row):
                rows.append(row)
                cols.append(seq_len + ancestor)
        if rows:
            mask[rows, cols] = 1

        return mask

    def trace_accepted_path(
        self,
        accepted_mask: torch.Tensor
    ) -> List[int]:
        """Return the tokens of the longest fully accepted root-to-node path.

        A path only counts when every node on it, from the root down, is
        accepted; a rejected ancestor invalidates everything beneath it.
        Among equally long paths the one ending at the highest node index
        wins.

        Args:
            accepted_mask: [num_nodes] boolean mask.

        Returns:
            accepted_tokens: Tokens along the winning path in root-to-leaf
            order, or an empty list when no root was accepted.
        """
        accepted = accepted_mask.reshape(-1).cpu().tolist()

        best_path: List[int] = []
        for leaf in range(len(self.nodes) - 1, -1, -1):
            if not accepted[leaf]:
                continue
            chain = self._ancestor_chain(leaf)
            if len(chain) <= len(best_path):
                continue
            if all(accepted[idx] for idx in chain):
                best_path = [self.nodes[idx][0] for idx in reversed(chain)]

        return best_path
|
|
246
|
+
|
|
247
|
+
|
|
248
|
+
class SpeculativeDecoder:
    """Speculative decoder.

    One :meth:`decode_step` drafts a tree of candidate tokens with the draft
    head, runs the target model over the candidates, and keeps the path the
    tree buffer reports as accepted.

    The tree handling is written for a single sequence (batch size 1).
    """

    def __init__(
        self,
        target_model: nn.Module,
        config: "SpeculativeConfig",
        device: str = "cuda",
    ):
        """Store the target model and create an empty tree buffer.

        Args:
            target_model: Model used for verification and for projecting
                hidden states to vocabulary logits.
            config: Speculative decoding settings.
            device: Device for the draft head and the tree buffer.
        """
        self.target = target_model
        self.config = config
        self.device = device

        # Draft components; the head is created by setup_draft_head().
        self.draft_head: Optional[DraftHead] = None
        self.draft_model: Optional[nn.Module] = None

        # Candidate tree for the current step.
        self.tree_buffer = TreeDraftBuffer(
            tree_width=config.tree_width,
            tree_depth=config.tree_depth,
            device=device,
        )

        # Running statistics.
        self._stats = {
            "total_steps": 0,
            "total_draft_tokens": 0,
            "total_accepted_tokens": 0,
            "avg_accept_rate": 0.0,
        }

        # Current (possibly adapted) speculation depth.
        self._current_depth = config.tree_depth

    def setup_draft_head(
        self,
        hidden_size: int,
        vocab_size: int,
    ) -> None:
        """Create the draft head and share the target's token embedding.

        Args:
            hidden_size: Hidden-state width of the target model.
            vocab_size: Vocabulary size of the target model.
        """
        self.draft_head = DraftHead(
            hidden_size=hidden_size,
            vocab_size=vocab_size,
            hidden_dim=self.config.draft_head_hidden_size,
        ).to(self.device)

        # Share the embedding; two attribute layouts are recognised.
        if hasattr(self.target, "model") and hasattr(self.target.model, "embed_tokens"):
            self.draft_head.set_token_embedding(self.target.model.embed_tokens)
        elif hasattr(self.target, "transformer") and hasattr(self.target.transformer, "wte"):
            self.draft_head.set_token_embedding(self.target.transformer.wte)
        else:
            # Without an embedding the draft head raises on first use, so
            # surface the cause here rather than at decode time.
            logger.warning(
                "No token embedding found on target model; "
                "call draft_head.set_token_embedding() before decoding"
            )

    async def decode_step(
        self,
        hidden_states: torch.Tensor,
        past_key_values: Optional[List[Tuple[torch.Tensor, torch.Tensor]]] = None,
        input_ids: Optional[torch.Tensor] = None,
    ) -> Tuple[List[int], int]:
        """Run one speculative decoding step.

        Args:
            hidden_states: Current hidden states.
            past_key_values: KV cache.
            input_ids: Current input tokens. Required despite the ``None``
                default; the draft tree is seeded from the last token.

        Returns:
            (accepted_tokens, num_accepted)

        Raises:
            RuntimeError: If the draft head has not been set up.
            TypeError: If ``input_ids`` is ``None``.
            ValueError: If ``input_ids`` is not a single 2-D sequence.
        """
        start_time = time.perf_counter()

        # 1. Draft: build the candidate token tree.
        self.tree_buffer.reset()
        draft_tokens = await self._generate_draft_tree(
            hidden_states,
            input_ids,
            past_key_values,
        )

        # 2. Verify: run the target model over the candidates.
        accepted_mask = await self._verify_candidates(
            draft_tokens,
            hidden_states,
            past_key_values,
        )

        # 3. Accept: pick the accepted path.
        accepted_tokens = self.tree_buffer.trace_accepted_path(accepted_mask)

        # Statistics are kept per drafted node.
        num_draft = len(self.tree_buffer.nodes)
        num_accepted = len(accepted_tokens)
        accept_rate = num_accepted / max(1, num_draft)

        self._stats["total_steps"] += 1
        self._stats["total_draft_tokens"] += num_draft
        self._stats["total_accepted_tokens"] += num_accepted
        self._stats["avg_accept_rate"] = (
            self._stats["total_accepted_tokens"] /
            max(1, self._stats["total_draft_tokens"])
        )

        if self.config.adaptive_depth:
            # At most one token per tree layer can be accepted, so the
            # per-node rate above can never exceed 1 / tree_width and would
            # only ever shrink the depth. Adapt on the fraction of drafted
            # layers that were accepted instead.
            drafted_depth = len(self.tree_buffer.layer_offsets)
            self._adapt_depth(num_accepted / max(1, drafted_depth))

        latency_ms = (time.perf_counter() - start_time) * 1000
        logger.debug(
            f"Speculative step: {num_accepted}/{num_draft} accepted "
            f"({accept_rate:.1%}) in {latency_ms:.1f}ms"
        )

        return accepted_tokens, num_accepted

    async def _generate_draft_tree(
        self,
        hidden_states: torch.Tensor,
        input_ids: torch.Tensor,
        past_key_values,
    ) -> torch.Tensor:
        """Draft the candidate token tree into ``self.tree_buffer``.

        Each layer holds the top ``tree_width`` tokens predicted from the
        best candidate of the previous layer; first-layer nodes are roots
        (parent index ``-1``).

        Args:
            hidden_states: Hidden states; only the last position is used.
            input_ids: Input tokens of shape [1, seq_len]; only the last
                token is used.
            past_key_values: Unused here; kept for signature parity.

        Returns:
            All drafted tokens in node order.
        """
        if self.draft_head is None:
            raise RuntimeError("Draft head not initialized")
        if input_ids is None:
            raise TypeError("input_ids is required to seed the draft tree")
        if input_ids.dim() != 2 or input_ids.size(0) != 1:
            raise ValueError(
                "speculative decoding expects input_ids of shape [1, seq_len], "
                f"got {tuple(input_ids.shape)}"
            )

        # The draft head concatenates features and token embeddings along
        # the channel axis, so both must cover the same positions: keep only
        # the last one of each.
        current_hidden = hidden_states[:, -1:]
        current_tokens = input_ids[:, -1:]
        parent_index = -1  # first-layer candidates are roots

        # Only token ids leave this method, so no autograd graph is needed.
        with torch.no_grad():
            for _ in range(self._current_depth):
                # Predict the next-step hidden states.
                predicted_hidden = self.draft_head(current_hidden, current_tokens)

                # Project to vocabulary logits with the target's LM head,
                # falling back to the embedding matrix.
                if hasattr(self.target, "lm_head"):
                    logits = self.target.lm_head(predicted_hidden)
                else:
                    logits = predicted_hidden @ self.target.model.embed_tokens.weight.T

                # Top-k candidates, sorted best first by torch.topk.
                log_probs = torch.log_softmax(logits[:, -1], dim=-1)
                top_log_probs, top_indices = torch.topk(
                    log_probs,
                    k=min(self.config.tree_width, logits.size(-1)),
                    dim=-1
                )

                layer_start = len(self.tree_buffer.nodes)
                parent_indices = torch.full(
                    (top_indices.size(-1),),
                    parent_index,
                    dtype=torch.long,
                    device=top_indices.device,
                )

                # Index the single batch row explicitly: a bare squeeze()
                # would collapse to a 0-d tensor when tree_width == 1.
                self.tree_buffer.add_candidates(
                    top_indices[0],
                    top_log_probs[0],
                    parent_indices,
                )

                # Drafting continues from the best candidate, which is the
                # first node of the layer just added; the next layer hangs
                # off that node.
                parent_index = layer_start
                current_tokens = top_indices[:, :1]
                current_hidden = predicted_hidden

        return self.tree_buffer.get_tree_tokens()

    async def _verify_candidates(
        self,
        draft_tokens: torch.Tensor,
        hidden_states: torch.Tensor,
        past_key_values,
    ) -> torch.Tensor:
        """Run the target model over the draft tokens and mark matches.

        Args:
            draft_tokens: [num_nodes] drafted tokens in node order.
            hidden_states: Unused here; kept for signature parity.
            past_key_values: KV cache forwarded to the target model.

        Returns:
            Boolean mask [num_nodes]; ``True`` where the target's token at a
            position equals the draft token at that position.
        """
        if draft_tokens.numel() == 0:
            return torch.zeros(0, dtype=torch.bool, device=draft_tokens.device)

        # NOTE(review): the candidates are fed as one flat sequence and no
        # tree attention mask is passed to the target, so with
        # tree_width > 1 sibling candidates are presumably attended to as if
        # they were preceding context. Wiring in
        # TreeDraftBuffer.get_tree_attention_mask() depends on the mask
        # format the target accepts -- confirm before relying on deep paths.
        verify_input = draft_tokens.unsqueeze(0)

        with torch.no_grad():
            outputs = self.target(
                input_ids=verify_input,
                past_key_values=past_key_values,
                use_cache=True,
            )

        logits = outputs.logits[0]  # [num_tokens, vocab_size]

        # config.top_p is not applied here.
        if self.config.temperature > 0:
            probs = torch.softmax(logits / self.config.temperature, dim=-1)
            sampled_tokens = torch.multinomial(probs, num_samples=1).squeeze(-1)
        else:
            sampled_tokens = logits.argmax(dim=-1)

        # NOTE(review): for a causal LM, logits[i] is normally the
        # prediction for the token *after* input i, so comparing it with
        # draft token i itself looks off by one (a node would usually be
        # checked against the prediction at its parent's position). The fix
        # depends on whether hidden_states / past_key_values already cover
        # the last input token -- confirm with the caller.
        return sampled_tokens.to(draft_tokens.device) == draft_tokens

    def _adapt_depth(self, accept_rate: float) -> None:
        """Shrink or grow the speculation depth based on ``accept_rate``.

        Args:
            accept_rate: Acceptance rate of the last step, in [0, 1].
        """
        if accept_rate < self.config.min_accept_rate:
            # Acceptance too low: speculate less deeply (never below 1).
            self._current_depth = max(1, self._current_depth - 1)
        elif accept_rate > 0.7 and self._current_depth < self.config.tree_depth:
            # Acceptance high: speculate deeper, capped at the configured depth.
            self._current_depth = min(self.config.tree_depth, self._current_depth + 1)

    def get_stats(self) -> Dict[str, Any]:
        """Return the running statistics plus depth and a speedup estimate."""
        return {
            **self._stats,
            "current_depth": self._current_depth,
            "speedup_estimate": max(1.0, self._stats["avg_accept_rate"] * self._current_depth),
        }
|
|
472
|
+
|
|
473
|
+
|
|
474
|
+
class MedusaHead(nn.Module):
    """Medusa-style multi-head predictor.

    Holds several independent prediction heads, each intended to predict the
    token at a different position.
    """

    def __init__(
        self,
        hidden_size: int,
        vocab_size: int,
        num_heads: int = 4,
        hidden_dim: int = 1024,
    ):
        """Create ``num_heads`` two-layer MLP heads.

        Args:
            hidden_size: Width of the incoming hidden states.
            vocab_size: Output width of every head.
            num_heads: Number of independent heads.
            hidden_dim: Width of each head's intermediate layer.
        """
        super().__init__()

        self.num_heads = num_heads

        # Each head: Linear -> SiLU -> Linear, mapping hidden states to logits.
        branches = []
        for _ in range(num_heads):
            up_proj = nn.Linear(hidden_size, hidden_dim)
            down_proj = nn.Linear(hidden_dim, vocab_size)
            branches.append(nn.Sequential(up_proj, nn.SiLU(), down_proj))
        self.heads = nn.ModuleList(branches)

    def forward(
        self,
        hidden_states: torch.Tensor,
    ) -> List[torch.Tensor]:
        """Compute logits with every head.

        Args:
            hidden_states: [batch, seq_len, hidden_size]

        Returns:
            logits_list: List of [batch, seq_len, vocab_size], one entry per
            head in head order.
        """
        logits_list = []
        for head in self.heads:
            logits_list.append(head(hidden_states))
        return logits_list
|
|
@@ -0,0 +1,139 @@
|
|
|
1
|
+
"""视觉理解推理引擎 - GLM-4V"""
|
|
2
|
+
from typing import Dict, Any
|
|
3
|
+
import torch
|
|
4
|
+
import base64
|
|
5
|
+
import io
|
|
6
|
+
import logging
|
|
7
|
+
from PIL import Image
|
|
8
|
+
|
|
9
|
+
from .base import BaseEngine
|
|
10
|
+
|
|
11
|
+
logger = logging.getLogger(__name__)
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class VisionEngine(BaseEngine):
    """Vision-language inference engine: image recognition and image Q&A."""

    def load_model(self) -> None:
        """Load the tokenizer and the vision-language model.

        The model id is read from ``self.config["model_id"]`` and defaults
        to ``THUDM/glm-4v-9b``.
        """
        from transformers import AutoModelForCausalLM, AutoTokenizer

        model_id = self.config.get("model_id", "THUDM/glm-4v-9b")
        logger.info(f"Loading vision model: {model_id}")

        # SECURITY: trust_remote_code=True executes Python code shipped in
        # the model repository; only load model ids from trusted sources.
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            trust_remote_code=True
        )

        load_kwargs = {
            "torch_dtype": torch.bfloat16,
            "trust_remote_code": True,
            "low_cpu_mem_usage": True,
        }

        # With CPU offload enabled (the default) placement is automatic;
        # otherwise the whole model is pinned to the engine's device.
        if self.config.get("enable_cpu_offload", True):
            load_kwargs["device_map"] = "auto"
        else:
            load_kwargs["device_map"] = {"": self.device}

        self.model = AutoModelForCausalLM.from_pretrained(model_id, **load_kwargs)
        self.model.eval()

        self.loaded = True
        logger.info("Vision model loaded successfully")

    @staticmethod
    def _decode_image(image_data: str):
        """Decode a base64 string (optionally a ``data:`` URL) to an RGB image."""
        payload = image_data.strip()
        if payload.startswith("data:"):
            # data:[<mediatype>][;base64],<data> -- keep only the data part.
            # Decoding the prefix along with the data would silently
            # corrupt the bytes.
            payload = payload.partition(",")[2]
        image_bytes = base64.b64decode(payload)
        return Image.open(io.BytesIO(image_bytes)).convert("RGB")

    def inference(self, params: Dict[str, Any]) -> Dict[str, Any]:
        """Run vision-language inference on one image.

        Supported tasks (``params["task"]``, default ``image_qa``):
        - image_qa: answer ``params["question"]`` about the image
        - image_caption: describe the image
        - ocr: extract text from the image

        Any other task value is handled like ``image_qa``.

        Args:
            params: Request parameters. ``image_base64`` (or ``image``) must
                be a base64 string, optionally wrapped in a ``data:`` URL.
                ``max_tokens`` caps the number of generated tokens
                (default 1024).

        Returns:
            Dict with ``response``, ``task`` and a ``usage`` dict of token
            counts.

        Raises:
            ValueError: If no base64 image string is supplied.
        """
        task = params.get("task", "image_qa")
        image_data = params.get("image_base64") or params.get("image")
        question = params.get("question", "请描述这张图片的内容")
        max_tokens = params.get("max_tokens", 1024)

        if not isinstance(image_data, str):
            raise ValueError("image_base64 is required")
        image = self._decode_image(image_data)

        # Fixed prompts for captioning and OCR; everything else uses the
        # caller's question.
        if task == "image_caption":
            prompt = "请详细描述这张图片的内容,包括场景、物体、人物、颜色等细节。"
        elif task == "ocr":
            prompt = "请识别并提取图片中的所有文字内容。"
        else:
            prompt = question

        messages = [
            {
                "role": "user",
                "content": [
                    {"type": "image", "image": image},
                    {"type": "text", "text": prompt}
                ]
            }
        ]

        inputs = self.tokenizer.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_tensors="pt",
            return_dict=True
        ).to(self.model.device)

        input_length = inputs["input_ids"].shape[1]

        with torch.no_grad():
            outputs = self.model.generate(
                **inputs,
                max_new_tokens=max_tokens,
                do_sample=True,
                temperature=0.7,
                top_p=0.9,
            )

        # The generated sequence starts with the prompt; decode only the
        # newly generated tail.
        response = self.tokenizer.decode(
            outputs[0][input_length:],
            skip_special_tokens=True
        )

        output_length = outputs.shape[1] - input_length

        return {
            "response": response,
            "task": task,
            "usage": {
                "prompt_tokens": input_length,
                "completion_tokens": output_length,
                "total_tokens": input_length + output_length
            }
        }

    def unload_model(self) -> None:
        """Release the model and tokenizer and clear the CUDA cache."""
        # Compare against None instead of relying on truthiness, and
        # tolerate unload being called before any load.
        if getattr(self, "model", None) is not None:
            del self.model
            self.model = None
        if hasattr(self, "tokenizer"):
            del self.tokenizer
        torch.cuda.empty_cache()
        self.loaded = False
        logger.info("Vision model unloaded")
|