gpu-worker 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +115 -0
- package/api_client.py +288 -0
- package/batch_processor.py +436 -0
- package/bin/gpu-worker.js +275 -0
- package/cli.py +729 -0
- package/config.2gb.yaml +32 -0
- package/config.8gb.yaml +29 -0
- package/config.example.yaml +72 -0
- package/config.py +213 -0
- package/direct_server.py +140 -0
- package/distributed/__init__.py +35 -0
- package/distributed/grpc_server.py +561 -0
- package/distributed/kv_cache.py +555 -0
- package/distributed/model_shard.py +465 -0
- package/distributed/session.py +455 -0
- package/engines/__init__.py +215 -0
- package/engines/base.py +57 -0
- package/engines/image_gen.py +83 -0
- package/engines/llm.py +97 -0
- package/engines/llm_base.py +216 -0
- package/engines/llm_sglang.py +489 -0
- package/engines/llm_vllm.py +539 -0
- package/engines/speculative.py +513 -0
- package/engines/vision.py +139 -0
- package/machine_id.py +200 -0
- package/main.py +521 -0
- package/package.json +64 -0
- package/requirements-sglang.txt +12 -0
- package/requirements-vllm.txt +15 -0
- package/requirements.txt +35 -0
- package/scripts/postinstall.js +60 -0
- package/setup.py +43 -0
package/engines/llm_vllm.py

@@ -0,0 +1,539 @@
"""
vLLM high-performance inference engine.

Features:
- PagedAttention for efficient memory management (a PyTorch Foundation project)
- Continuous batching
- Tensor parallelism
- Prefix caching
- Chunked prefill
- Multiple quantization formats (AWQ, GPTQ, INT8, FP8)

Reference: https://github.com/vllm-project/vllm
"""
from typing import Dict, Any, List, Optional, AsyncIterator
import logging
import asyncio
import time

from .llm_base import LLMBaseEngine, LLMBackend, GenerationConfig, GenerationResult

logger = logging.getLogger(__name__)

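# Illustrative config shape, not part of the packaged module. The key names
# below are the ones read via config.get(...) in the load_model() methods in
# this file; the values are example defaults only.
EXAMPLE_CONFIG: Dict[str, Any] = {
    "model_id": "Qwen/Qwen2.5-7B-Instruct",  # HF repo id or local path
    "quantization": None,                    # "awq", "gptq", "squeezellm", "fp8", "int8", or None
    "vllm": {
        "tensor_parallel_size": 1,           # GPUs used for tensor parallelism
        "gpu_memory_utilization": 0.85,      # fraction of GPU memory vLLM may claim
        "max_model_len": 8192,               # maximum context length
        "max_num_seqs": 256,                 # max sequences scheduled concurrently
        "enable_prefix_caching": True,       # reuse KV cache across shared prompt prefixes
        "enable_chunked_prefill": True,      # split long prefills into chunks
        "enforce_eager": False,              # True disables CUDA graph capture
        "dtype": "auto",                     # e.g. "auto", "float16", "bfloat16"
    },
}
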
class VLLMEngine(LLMBaseEngine):
    """
    High-performance LLM inference engine built on vLLM.

    Core advantages of vLLM:
    1. PagedAttention - efficient GPU memory management, pushing utilization above 85%
    2. A mature ecosystem with strong community support
    3. Multiple quantization schemes (AWQ, GPTQ, FP8, INT8)
    4. Native tensor parallelism, well suited to multi-GPU deployments
    """

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        self.backend_type = LLMBackend.VLLM
        self.llm = None
        self._vllm_config = config.get("vllm", {})
        self._default_sampling_params = None

    def load_model(self) -> None:
        """Load the model into vLLM."""
        try:
            from vllm import LLM, SamplingParams
        except ImportError:
            raise ImportError(
                "vLLM not installed. Please install with:\n"
                "  pip install vllm"
            )

        model_id = self.config.get("model_id", "Qwen/Qwen2.5-7B-Instruct")
        logger.info(f"Loading model with vLLM: {model_id}")

        # Merge configuration
        tp_size = self._vllm_config.get("tensor_parallel_size", 1)
        gpu_util = self._vllm_config.get("gpu_memory_utilization", 0.85)
        max_model_len = self._vllm_config.get("max_model_len", 8192)
        max_num_seqs = self._vllm_config.get("max_num_seqs", 256)
        enable_prefix_caching = self._vllm_config.get("enable_prefix_caching", True)
        enable_chunked_prefill = self._vllm_config.get("enable_chunked_prefill", True)

        # vLLM configuration
        llm_config = {
            "model": model_id,
            "tensor_parallel_size": tp_size,
            "gpu_memory_utilization": gpu_util,
            "max_model_len": max_model_len,
            "max_num_seqs": max_num_seqs,
            "trust_remote_code": True,
            "enforce_eager": self._vllm_config.get("enforce_eager", False),
        }

        # Prefix caching
        if enable_prefix_caching:
            llm_config["enable_prefix_caching"] = True

        # Chunked prefill
        if enable_chunked_prefill:
            llm_config["enable_chunked_prefill"] = True

        # Quantization
        quantization = self.config.get("quantization")
        if quantization:
            if quantization in ["awq", "gptq", "squeezellm", "fp8", "int8"]:
                llm_config["quantization"] = quantization
                logger.info(f"Using quantization: {quantization}")

        # dtype
        dtype = self._vllm_config.get("dtype", "auto")
        llm_config["dtype"] = dtype

        # Create the LLM instance
        logger.info(f"Creating vLLM instance with config: {llm_config}")
        self.llm = LLM(**llm_config)

        # Grab the tokenizer
        self.tokenizer = self.llm.get_tokenizer()

        # Default sampling parameters
        self._default_sampling_params = SamplingParams(
            temperature=0.7,
            top_p=0.9,
            max_tokens=2048,
        )

        self.loaded = True
        logger.info("vLLM engine loaded successfully")
        logger.info(f" - Tensor Parallel: {tp_size}")
        logger.info(f" - GPU Memory Utilization: {gpu_util}")
        logger.info(f" - Prefix Caching: {enable_prefix_caching}")
        logger.info(f" - Max Sequences: {max_num_seqs}")

    async def generate_async(
        self,
        messages: List[Dict[str, str]],
        config: Optional[GenerationConfig] = None
    ) -> GenerationResult:
        """Generate asynchronously."""
        if config is None:
            config = GenerationConfig()

        # vLLM's generate() is synchronous, so run it in the default thread pool.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None,
            lambda: self._generate_sync(messages, config)
        )

    def _generate_sync(
        self,
        messages: List[Dict[str, str]],
        config: GenerationConfig
    ) -> GenerationResult:
        """Generate synchronously."""
        from vllm import SamplingParams

        start_time = time.time()

        # Format the input
        prompt = self._format_messages(messages)

        # Sampling parameters. Requests with temperature <= 0 fall back to
        # temperature=1.0 sampling; beam search is not used here.
        sampling_params = SamplingParams(
            temperature=config.temperature if config.temperature > 0 else 1.0,
            top_p=config.top_p,
            top_k=config.top_k if config.top_k > 0 else -1,
            max_tokens=config.max_tokens,
            stop=config.stop_sequences,
        )

        # Run generation
        outputs = self.llm.generate([prompt], sampling_params)

        # Parse the output
        output = outputs[0]
        response_text = output.outputs[0].text
        prompt_tokens = len(output.prompt_token_ids)
        completion_tokens = len(output.outputs[0].token_ids)
        finish_reason = output.outputs[0].finish_reason or "stop"

        latency_ms = (time.time() - start_time) * 1000
        logger.debug(f"Generation completed in {latency_ms:.2f}ms")

        return GenerationResult(
            text=response_text,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
            finish_reason=finish_reason
        )

    async def batch_generate(
        self,
        batch_messages: List[List[Dict[str, str]]],
        config: Optional[GenerationConfig] = None
    ) -> List[GenerationResult]:
        """Batch generation - leverages vLLM's continuous batching."""
        if config is None:
            config = GenerationConfig()

        # Run the batched generation in the default thread pool.
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None,
            lambda: self._batch_generate_sync(batch_messages, config)
        )

    def _batch_generate_sync(
        self,
        batch_messages: List[List[Dict[str, str]]],
        config: GenerationConfig
    ) -> List[GenerationResult]:
        """Synchronous batch generation."""
        from vllm import SamplingParams

        # Format all inputs
        prompts = [self._format_messages(msgs) for msgs in batch_messages]

        # Sampling parameters
        sampling_params = SamplingParams(
            temperature=config.temperature if config.temperature > 0 else 1.0,
            top_p=config.top_p,
            top_k=config.top_k if config.top_k > 0 else -1,
            max_tokens=config.max_tokens,
            stop=config.stop_sequences,
        )

        # Batch generation (vLLM schedules the batch automatically)
        outputs = self.llm.generate(prompts, sampling_params)

        results = []
        for output in outputs:
            response_text = output.outputs[0].text
            prompt_tokens = len(output.prompt_token_ids)
            completion_tokens = len(output.outputs[0].token_ids)
            finish_reason = output.outputs[0].finish_reason or "stop"

            results.append(GenerationResult(
                text=response_text,
                prompt_tokens=prompt_tokens,
                completion_tokens=completion_tokens,
                total_tokens=prompt_tokens + completion_tokens,
                finish_reason=finish_reason
            ))

        return results

    def _format_messages(self, messages: List[Dict[str, str]]) -> str:
        """Format chat messages into a prompt string."""
        if self.tokenizer and hasattr(self.tokenizer, "apply_chat_template"):
            return self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )

        # Fallback: simple concatenation
        formatted = []
        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            formatted.append(f"{role}: {content}")

        formatted.append("assistant: ")
        return "\n".join(formatted)

    def supports_streaming(self) -> bool:
        """The synchronous vLLM path does not support streaming."""
        return False

    def supports_prefix_caching(self) -> bool:
        """Prefix caching is supported."""
        return self._vllm_config.get("enable_prefix_caching", True)

    def supports_batch_inference(self) -> bool:
        """Batch inference is supported."""
        return True

    def unload_model(self) -> None:
        """Unload the model."""
        if self.llm:
            del self.llm
            self.llm = None

        if self.tokenizer:
            del self.tokenizer
            self.tokenizer = None

        self._default_sampling_params = None

        import torch
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        self.loaded = False
        logger.info("vLLM engine unloaded")

    def get_status(self) -> Dict[str, Any]:
        """Return engine status."""
        status = super().get_status()
        status["features"] = [
            "paged_attention",
            "continuous_batching",
            "tensor_parallelism",
            "prefix_caching",
            "chunked_prefill",
        ]
        return status

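# Illustrative usage sketch, not part of the packaged module. It exercises only
# the VLLMEngine methods defined above; the model name, config keys, and
# messages are placeholders. Run with: asyncio.run(_example_vllm_engine_usage())
async def _example_vllm_engine_usage() -> None:
    engine = VLLMEngine({
        "model_id": "Qwen/Qwen2.5-7B-Instruct",
        "vllm": {"tensor_parallel_size": 1, "gpu_memory_utilization": 0.85},
    })
    engine.load_model()  # blocking; builds the vllm.LLM instance

    messages = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Explain PagedAttention in one sentence."},
    ]

    # Single request: generate_async() offloads the blocking call to a thread pool.
    result = await engine.generate_async(messages)
    logger.info("completion (%s): %s", result.finish_reason, result.text)

    # Batched requests: one llm.generate() call lets vLLM batch them internally.
    batch_results = await engine.batch_generate([messages, messages])
    logger.info("batch token counts: %s", [r.total_tokens for r in batch_results])

    engine.unload_model()
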
class VLLMAsyncEngine(LLMBaseEngine):
    """
    Asynchronous inference engine built on vLLM's AsyncLLMEngine.

    Suitable for workloads that need streaming output and high concurrency.
    """

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        self.backend_type = LLMBackend.VLLM
        self.engine = None
        self._vllm_config = config.get("vllm", {})

    def load_model(self) -> None:
        """Load the model into an AsyncLLMEngine."""
        try:
            from vllm import AsyncLLMEngine, AsyncEngineArgs
        except ImportError:
            raise ImportError(
                "vLLM not installed. Please install with:\n"
                "  pip install vllm"
            )

        model_id = self.config.get("model_id", "Qwen/Qwen2.5-7B-Instruct")
        logger.info(f"Loading model with vLLM AsyncEngine: {model_id}")

        # Merge configuration
        tp_size = self._vllm_config.get("tensor_parallel_size", 1)
        gpu_util = self._vllm_config.get("gpu_memory_utilization", 0.85)
        max_model_len = self._vllm_config.get("max_model_len", 8192)
        enable_prefix_caching = self._vllm_config.get("enable_prefix_caching", True)
        enable_chunked_prefill = self._vllm_config.get("enable_chunked_prefill", True)

        # Engine arguments
        engine_args = AsyncEngineArgs(
            model=model_id,
            tensor_parallel_size=tp_size,
            gpu_memory_utilization=gpu_util,
            max_model_len=max_model_len,
            trust_remote_code=True,
            enable_prefix_caching=enable_prefix_caching,
            enable_chunked_prefill=enable_chunked_prefill,
        )

        # Quantization
        quantization = self.config.get("quantization")
        if quantization:
            engine_args.quantization = quantization

        # Create the async engine
        self.engine = AsyncLLMEngine.from_engine_args(engine_args)

        # Load the tokenizer
        from transformers import AutoTokenizer
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_id,
            trust_remote_code=True
        )

        self.loaded = True
        logger.info("vLLM AsyncEngine loaded successfully")

    async def generate_async(
        self,
        messages: List[Dict[str, str]],
        config: Optional[GenerationConfig] = None
    ) -> GenerationResult:
        """Generate asynchronously."""
        from vllm import SamplingParams
        import uuid

        if config is None:
            config = GenerationConfig()

        start_time = time.time()

        # Format the input
        prompt = self._format_messages(messages)
        request_id = str(uuid.uuid4())

        # Sampling parameters
        sampling_params = SamplingParams(
            temperature=config.temperature if config.temperature > 0 else 1.0,
            top_p=config.top_p,
            top_k=config.top_k if config.top_k > 0 else -1,
            max_tokens=config.max_tokens,
            stop=config.stop_sequences,
        )

        # Asynchronous generation
        results_generator = self.engine.generate(prompt, sampling_params, request_id)

        final_output = None
        async for request_output in results_generator:
            final_output = request_output

        if final_output is None:
            raise RuntimeError("No output generated")

        output = final_output.outputs[0]
        response_text = output.text
        prompt_tokens = len(final_output.prompt_token_ids)
        completion_tokens = len(output.token_ids)
        finish_reason = output.finish_reason or "stop"

        latency_ms = (time.time() - start_time) * 1000
        logger.debug(f"Generation completed in {latency_ms:.2f}ms")

        return GenerationResult(
            text=response_text,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
            finish_reason=finish_reason
        )

    async def batch_generate(
        self,
        batch_messages: List[List[Dict[str, str]]],
        config: Optional[GenerationConfig] = None
    ) -> List[GenerationResult]:
        """Batch generation via concurrent requests."""
        if config is None:
            config = GenerationConfig()

        # Issue all requests concurrently
        tasks = [
            self.generate_async(messages, config)
            for messages in batch_messages
        ]

        results = await asyncio.gather(*tasks, return_exceptions=True)

        outputs = []
        for result in results:
            if isinstance(result, Exception):
                logger.error(f"Batch generation error: {result}")
                outputs.append(GenerationResult(
                    text="",
                    prompt_tokens=0,
                    completion_tokens=0,
                    total_tokens=0,
                    finish_reason="error"
                ))
            else:
                outputs.append(result)

        return outputs

    async def stream_generate(
        self,
        messages: List[Dict[str, str]],
        config: Optional[GenerationConfig] = None
    ) -> AsyncIterator[str]:
        """Streaming generation."""
        from vllm import SamplingParams
        import uuid

        if config is None:
            config = GenerationConfig()

        # Format the input
        prompt = self._format_messages(messages)
        request_id = str(uuid.uuid4())

        # Sampling parameters
        sampling_params = SamplingParams(
            temperature=config.temperature if config.temperature > 0 else 1.0,
            top_p=config.top_p,
            top_k=config.top_k if config.top_k > 0 else -1,
            max_tokens=config.max_tokens,
            stop=config.stop_sequences,
        )

        # Streaming generation
        results_generator = self.engine.generate(prompt, sampling_params, request_id)

        prev_text = ""
        async for request_output in results_generator:
            output = request_output.outputs[0]
            new_text = output.text[len(prev_text):]
            prev_text = output.text

            if new_text:
                yield new_text

    def _format_messages(self, messages: List[Dict[str, str]]) -> str:
        """Format chat messages into a prompt string."""
        if self.tokenizer and hasattr(self.tokenizer, "apply_chat_template"):
            return self.tokenizer.apply_chat_template(
                messages,
                tokenize=False,
                add_generation_prompt=True
            )

        # Fallback
        formatted = []
        for msg in messages:
            role = msg.get("role", "user")
            content = msg.get("content", "")
            formatted.append(f"{role}: {content}")

        formatted.append("assistant: ")
        return "\n".join(formatted)

    def supports_streaming(self) -> bool:
        """Streaming output is supported."""
        return True

    def supports_prefix_caching(self) -> bool:
        """Prefix caching is supported."""
        return self._vllm_config.get("enable_prefix_caching", True)

    def supports_batch_inference(self) -> bool:
        """Batch inference is supported."""
        return True

    def unload_model(self) -> None:
        """Unload the model."""
        if self.engine:
            del self.engine
            self.engine = None

        if self.tokenizer:
            del self.tokenizer
            self.tokenizer = None

        import torch
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        self.loaded = False
        logger.info("vLLM AsyncEngine unloaded")

    def get_status(self) -> Dict[str, Any]:
        """Return engine status."""
        status = super().get_status()
        status["async_mode"] = True
        status["features"] = [
            "paged_attention",
            "continuous_batching",
            "tensor_parallelism",
            "prefix_caching",
            "async_inference",
            "streaming",
        ]
        return status
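
# Illustrative usage sketch, not part of the packaged module: streaming with
# VLLMAsyncEngine. Model name and messages are placeholders; only methods
# defined above are used. Run with: asyncio.run(_example_vllm_async_engine_usage())
async def _example_vllm_async_engine_usage() -> None:
    engine = VLLMAsyncEngine({"model_id": "Qwen/Qwen2.5-7B-Instruct"})
    engine.load_model()  # builds AsyncLLMEngine via AsyncEngineArgs

    messages = [{"role": "user", "content": "Write a haiku about GPUs."}]

    # stream_generate() yields incremental text deltas as vLLM produces them.
    async for delta in engine.stream_generate(messages):
        print(delta, end="", flush=True)
    print()

    engine.unload_model()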