gpu-worker 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,539 @@
+ """
+ vLLM high-performance inference engine.
+
+ Features:
+ - PagedAttention for efficient memory management (a PyTorch Foundation project)
+ - Continuous batching
+ - Tensor parallelism
+ - Prefix caching
+ - Chunked prefill
+ - Multiple quantization schemes (AWQ, GPTQ, INT8, FP8)
+
+ Reference: https://github.com/vllm-project/vllm
+ """
+ from typing import Dict, Any, List, Optional, AsyncIterator
+ import logging
+ import asyncio
+ import time
+
+ from .llm_base import LLMBaseEngine, LLMBackend, GenerationConfig, GenerationResult
+
+ logger = logging.getLogger(__name__)
+
+
+ class VLLMEngine(LLMBaseEngine):
+     """
+     High-performance LLM inference engine built on vLLM.
+
+     Core advantages of vLLM:
+     1. PagedAttention - efficient GPU memory management, pushing utilization to 85%+
+     2. Mature ecosystem and community support
+     3. Multiple quantization schemes (AWQ, GPTQ, FP8, INT8)
+     4. Native tensor parallelism, well suited to multi-GPU deployments
+     """
+
+     def __init__(self, config: Dict[str, Any]):
+         super().__init__(config)
+         self.backend_type = LLMBackend.VLLM
+         self.llm = None
+         self._vllm_config = config.get("vllm", {})
+         self._default_sampling_params = None
+
+     def load_model(self) -> None:
+         """Load the model into vLLM."""
+         try:
+             from vllm import LLM, SamplingParams
+         except ImportError:
+             raise ImportError(
+                 "vLLM not installed. Please install with:\n"
+                 " pip install vllm"
+             )
+
+         model_id = self.config.get("model_id", "Qwen/Qwen2.5-7B-Instruct")
+         logger.info(f"Loading model with vLLM: {model_id}")
+
+         # Resolve engine options from the vllm sub-config (with defaults)
+         tp_size = self._vllm_config.get("tensor_parallel_size", 1)
+         gpu_util = self._vllm_config.get("gpu_memory_utilization", 0.85)
+         max_model_len = self._vllm_config.get("max_model_len", 8192)
+         max_num_seqs = self._vllm_config.get("max_num_seqs", 256)
+         enable_prefix_caching = self._vllm_config.get("enable_prefix_caching", True)
+         enable_chunked_prefill = self._vllm_config.get("enable_chunked_prefill", True)
+
+         # vLLM constructor arguments
+         llm_config = {
+             "model": model_id,
+             "tensor_parallel_size": tp_size,
+             "gpu_memory_utilization": gpu_util,
+             "max_model_len": max_model_len,
+             "max_num_seqs": max_num_seqs,
+             "trust_remote_code": True,
+             "enforce_eager": self._vllm_config.get("enforce_eager", False),
+         }
+
+         # Prefix caching
+         if enable_prefix_caching:
+             llm_config["enable_prefix_caching"] = True
+
+         # Chunked prefill
+         if enable_chunked_prefill:
+             llm_config["enable_chunked_prefill"] = True
+
+         # Quantization
+         quantization = self.config.get("quantization")
+         if quantization:
+             if quantization in ["awq", "gptq", "squeezellm", "fp8", "int8"]:
+                 llm_config["quantization"] = quantization
+                 logger.info(f"Using quantization: {quantization}")
+
+         # dtype
+         dtype = self._vllm_config.get("dtype", "auto")
+         llm_config["dtype"] = dtype
+
+         # Create the LLM instance
+         logger.info(f"Creating vLLM instance with config: {llm_config}")
+         self.llm = LLM(**llm_config)
+
+         # Reuse the engine's tokenizer
+         self.tokenizer = self.llm.get_tokenizer()
+
+         # Default sampling parameters
+         self._default_sampling_params = SamplingParams(
+             temperature=0.7,
+             top_p=0.9,
+             max_tokens=2048,
+         )
+
+         self.loaded = True
+         logger.info("vLLM engine loaded successfully")
+         logger.info(f" - Tensor Parallel: {tp_size}")
+         logger.info(f" - GPU Memory Utilization: {gpu_util}")
+         logger.info(f" - Prefix Caching: {enable_prefix_caching}")
+         logger.info(f" - Max Sequences: {max_num_seqs}")
+
+     async def generate_async(
+         self,
+         messages: List[Dict[str, str]],
+         config: Optional[GenerationConfig] = None
+     ) -> GenerationResult:
+         """Generate asynchronously."""
+         if config is None:
+             config = GenerationConfig()
+
+         # vLLM's generate() is synchronous, so run it in a thread pool
+         loop = asyncio.get_running_loop()
+         return await loop.run_in_executor(
+             None,
+             lambda: self._generate_sync(messages, config)
+         )
+
+     def _generate_sync(
+         self,
+         messages: List[Dict[str, str]],
+         config: GenerationConfig
+     ) -> GenerationResult:
+         """Synchronous generation."""
+         from vllm import SamplingParams
+
+         start_time = time.time()
+
+         # Format the input messages into a prompt
+         prompt = self._format_messages(messages)
+
+         # Sampling parameters
+         # (temperature=0 selects greedy decoding in vLLM)
+         sampling_params = SamplingParams(
+             temperature=max(config.temperature, 0.0),
+             top_p=config.top_p,
+             top_k=config.top_k if config.top_k > 0 else -1,
+             max_tokens=config.max_tokens,
+             stop=config.stop_sequences,
+         )
+
+         # Run generation
+         outputs = self.llm.generate([prompt], sampling_params)
+
+         # Parse the output
+         output = outputs[0]
+         response_text = output.outputs[0].text
+         prompt_tokens = len(output.prompt_token_ids)
+         completion_tokens = len(output.outputs[0].token_ids)
+         finish_reason = output.outputs[0].finish_reason or "stop"
+
+         latency_ms = (time.time() - start_time) * 1000
+         logger.debug(f"Generation completed in {latency_ms:.2f}ms")
+
+         return GenerationResult(
+             text=response_text,
+             prompt_tokens=prompt_tokens,
+             completion_tokens=completion_tokens,
+             total_tokens=prompt_tokens + completion_tokens,
+             finish_reason=finish_reason
+         )
+
+     async def batch_generate(
+         self,
+         batch_messages: List[List[Dict[str, str]]],
+         config: Optional[GenerationConfig] = None
+     ) -> List[GenerationResult]:
+         """Batch generation, leveraging vLLM's continuous batching."""
+         if config is None:
+             config = GenerationConfig()
+
+         # Run the batch in a thread pool
+         loop = asyncio.get_running_loop()
+         return await loop.run_in_executor(
+             None,
+             lambda: self._batch_generate_sync(batch_messages, config)
+         )
+
+     def _batch_generate_sync(
+         self,
+         batch_messages: List[List[Dict[str, str]]],
+         config: GenerationConfig
+     ) -> List[GenerationResult]:
+         """Synchronous batch generation."""
+         from vllm import SamplingParams
+
+         # Format all inputs
+         prompts = [self._format_messages(msgs) for msgs in batch_messages]
+
+         # Sampling parameters (temperature=0 selects greedy decoding in vLLM)
+         sampling_params = SamplingParams(
+             temperature=max(config.temperature, 0.0),
+             top_p=config.top_p,
+             top_k=config.top_k if config.top_k > 0 else -1,
+             max_tokens=config.max_tokens,
+             stop=config.stop_sequences,
+         )
+
+         # Batch generation (vLLM schedules the batch itself)
+         outputs = self.llm.generate(prompts, sampling_params)
+
+         results = []
+         for output in outputs:
+             response_text = output.outputs[0].text
+             prompt_tokens = len(output.prompt_token_ids)
+             completion_tokens = len(output.outputs[0].token_ids)
+             finish_reason = output.outputs[0].finish_reason or "stop"
+
+             results.append(GenerationResult(
+                 text=response_text,
+                 prompt_tokens=prompt_tokens,
+                 completion_tokens=completion_tokens,
+                 total_tokens=prompt_tokens + completion_tokens,
+                 finish_reason=finish_reason
+             ))
+
+         return results
+
+     def _format_messages(self, messages: List[Dict[str, str]]) -> str:
+         """Format chat messages into a prompt string."""
+         if self.tokenizer and hasattr(self.tokenizer, "apply_chat_template"):
+             return self.tokenizer.apply_chat_template(
+                 messages,
+                 tokenize=False,
+                 add_generation_prompt=True
+             )
+
+         # Fallback: simple concatenation
+         formatted = []
+         for msg in messages:
+             role = msg.get("role", "user")
+             content = msg.get("content", "")
+             formatted.append(f"{role}: {content}")
+
+         formatted.append("assistant: ")
+         return "\n".join(formatted)
+
+     def supports_streaming(self) -> bool:
+         """The synchronous vLLM engine does not support streaming."""
+         return False
+
+     def supports_prefix_caching(self) -> bool:
+         """Whether prefix caching is enabled."""
+         return self._vllm_config.get("enable_prefix_caching", True)
+
+     def supports_batch_inference(self) -> bool:
+         """Batch inference is supported."""
+         return True
+
+     def unload_model(self) -> None:
+         """Unload the model and release GPU memory."""
+         if self.llm:
+             del self.llm
+             self.llm = None
+
+         if self.tokenizer:
+             del self.tokenizer
+             self.tokenizer = None
+
+         self._default_sampling_params = None
+
+         import torch
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+
+         self.loaded = False
+         logger.info("vLLM engine unloaded")
+
+     def get_status(self) -> Dict[str, Any]:
+         """Return engine status."""
+         status = super().get_status()
+         status["features"] = [
+             "paged_attention",
+             "continuous_batching",
+             "tensor_parallelism",
+             "prefix_caching",
+             "chunked_prefill",
+         ]
+         return status
+
+
+ class VLLMAsyncEngine(LLMBaseEngine):
+     """
+     Asynchronous inference engine built on vLLM's AsyncLLMEngine.
+
+     Intended for scenarios that need streaming output and high concurrency.
+     """
+
+     def __init__(self, config: Dict[str, Any]):
+         super().__init__(config)
+         self.backend_type = LLMBackend.VLLM
+         self.engine = None
+         self._vllm_config = config.get("vllm", {})
+
+     def load_model(self) -> None:
+         """Load the model into an AsyncLLMEngine."""
+         try:
+             from vllm import AsyncLLMEngine, AsyncEngineArgs
+         except ImportError:
+             raise ImportError(
+                 "vLLM not installed. Please install with:\n"
+                 " pip install vllm"
+             )
+
+         model_id = self.config.get("model_id", "Qwen/Qwen2.5-7B-Instruct")
+         logger.info(f"Loading model with vLLM AsyncEngine: {model_id}")
+
+         # Resolve engine options from the vllm sub-config (with defaults)
+         tp_size = self._vllm_config.get("tensor_parallel_size", 1)
+         gpu_util = self._vllm_config.get("gpu_memory_utilization", 0.85)
+         max_model_len = self._vllm_config.get("max_model_len", 8192)
+         enable_prefix_caching = self._vllm_config.get("enable_prefix_caching", True)
+         enable_chunked_prefill = self._vllm_config.get("enable_chunked_prefill", True)
+
+         # Engine arguments
+         engine_args = AsyncEngineArgs(
+             model=model_id,
+             tensor_parallel_size=tp_size,
+             gpu_memory_utilization=gpu_util,
+             max_model_len=max_model_len,
+             trust_remote_code=True,
+             enable_prefix_caching=enable_prefix_caching,
+             enable_chunked_prefill=enable_chunked_prefill,
+         )
+
+         # Quantization
+         quantization = self.config.get("quantization")
+         if quantization:
+             engine_args.quantization = quantization
+
+         # Create the async engine
+         self.engine = AsyncLLMEngine.from_engine_args(engine_args)
+
+         # Load the tokenizer separately for prompt formatting
+         from transformers import AutoTokenizer
+         self.tokenizer = AutoTokenizer.from_pretrained(
+             model_id,
+             trust_remote_code=True
+         )
+
+         self.loaded = True
+         logger.info("vLLM AsyncEngine loaded successfully")
+
+     async def generate_async(
+         self,
+         messages: List[Dict[str, str]],
+         config: Optional[GenerationConfig] = None
+     ) -> GenerationResult:
+         """Generate asynchronously."""
+         from vllm import SamplingParams
+         import uuid
+
+         if config is None:
+             config = GenerationConfig()
+
+         start_time = time.time()
+
+         # Format the input
+         prompt = self._format_messages(messages)
+         request_id = str(uuid.uuid4())
+
+         # Sampling parameters (temperature=0 selects greedy decoding in vLLM)
+         sampling_params = SamplingParams(
+             temperature=max(config.temperature, 0.0),
+             top_p=config.top_p,
+             top_k=config.top_k if config.top_k > 0 else -1,
+             max_tokens=config.max_tokens,
+             stop=config.stop_sequences,
+         )
+
+         # Asynchronous generation; the final iteration holds the full output
+         results_generator = self.engine.generate(prompt, sampling_params, request_id)
+
+         final_output = None
+         async for request_output in results_generator:
+             final_output = request_output
+
+         if final_output is None:
+             raise RuntimeError("No output generated")
+
+         output = final_output.outputs[0]
+         response_text = output.text
+         prompt_tokens = len(final_output.prompt_token_ids)
+         completion_tokens = len(output.token_ids)
+         finish_reason = output.finish_reason or "stop"
+
+         latency_ms = (time.time() - start_time) * 1000
+         logger.debug(f"Generation completed in {latency_ms:.2f}ms")
+
+         return GenerationResult(
+             text=response_text,
+             prompt_tokens=prompt_tokens,
+             completion_tokens=completion_tokens,
+             total_tokens=prompt_tokens + completion_tokens,
+             finish_reason=finish_reason
+         )
+
+     async def batch_generate(
+         self,
+         batch_messages: List[List[Dict[str, str]]],
+         config: Optional[GenerationConfig] = None
+     ) -> List[GenerationResult]:
+         """Batch generation."""
+         if config is None:
+             config = GenerationConfig()
+
+         # Run all requests concurrently
+         tasks = [
+             self.generate_async(messages, config)
+             for messages in batch_messages
+         ]
+
+         results = await asyncio.gather(*tasks, return_exceptions=True)
+
+         outputs = []
+         for result in results:
+             if isinstance(result, Exception):
+                 logger.error(f"Batch generation error: {result}")
+                 outputs.append(GenerationResult(
+                     text="",
+                     prompt_tokens=0,
+                     completion_tokens=0,
+                     total_tokens=0,
+                     finish_reason="error"
+                 ))
+             else:
+                 outputs.append(result)
+
+         return outputs
+
+     async def stream_generate(
+         self,
+         messages: List[Dict[str, str]],
+         config: Optional[GenerationConfig] = None
+     ) -> AsyncIterator[str]:
+         """Streaming generation."""
+         from vllm import SamplingParams
+         import uuid
+
+         if config is None:
+             config = GenerationConfig()
+
+         # Format the input
+         prompt = self._format_messages(messages)
+         request_id = str(uuid.uuid4())
+
+         # Sampling parameters (temperature=0 selects greedy decoding in vLLM)
+         sampling_params = SamplingParams(
+             temperature=max(config.temperature, 0.0),
+             top_p=config.top_p,
+             top_k=config.top_k if config.top_k > 0 else -1,
+             max_tokens=config.max_tokens,
+             stop=config.stop_sequences,
+         )
+
+         # Stream the output, yielding only the newly generated text each step
+         results_generator = self.engine.generate(prompt, sampling_params, request_id)
+
+         prev_text = ""
+         async for request_output in results_generator:
+             output = request_output.outputs[0]
+             new_text = output.text[len(prev_text):]
+             prev_text = output.text
+
+             if new_text:
+                 yield new_text
+
+     def _format_messages(self, messages: List[Dict[str, str]]) -> str:
+         """Format chat messages into a prompt string."""
+         if self.tokenizer and hasattr(self.tokenizer, "apply_chat_template"):
+             return self.tokenizer.apply_chat_template(
+                 messages,
+                 tokenize=False,
+                 add_generation_prompt=True
+             )
+
+         # Fallback: simple concatenation
+         formatted = []
+         for msg in messages:
+             role = msg.get("role", "user")
+             content = msg.get("content", "")
+             formatted.append(f"{role}: {content}")
+
+         formatted.append("assistant: ")
+         return "\n".join(formatted)
+
+     def supports_streaming(self) -> bool:
+         """Streaming output is supported."""
+         return True
+
+     def supports_prefix_caching(self) -> bool:
+         """Whether prefix caching is enabled."""
+         return self._vllm_config.get("enable_prefix_caching", True)
+
+     def supports_batch_inference(self) -> bool:
+         """Batch inference is supported."""
+         return True
+
+     def unload_model(self) -> None:
+         """Unload the model and release GPU memory."""
+         if self.engine:
+             del self.engine
+             self.engine = None
+
+         if self.tokenizer:
+             del self.tokenizer
+             self.tokenizer = None
+
+         import torch
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+
+         self.loaded = False
+         logger.info("vLLM AsyncEngine unloaded")
+
+     def get_status(self) -> Dict[str, Any]:
+         """Return engine status."""
+         status = super().get_status()
+         status["async_mode"] = True
+         status["features"] = [
+             "paged_attention",
+             "continuous_batching",
+             "tensor_parallelism",
+             "prefix_caching",
+             "async_inference",
+             "streaming",
+         ]
+         return status
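
For illustration, here is a minimal usage sketch of the synchronous VLLMEngine defined above. The config keys mirror the ones the code reads (model_id, quantization, and the vllm sub-dictionary); the import paths are assumptions, since this diff does not show the package's module layout, and LLMBaseEngine / GenerationConfig come from the llm_base module that is not part of this file.

# Illustrative sketch only: the gpu_worker module paths below are assumptions,
# not confirmed by this diff.
import asyncio

from gpu_worker.vllm_engine import VLLMEngine      # hypothetical import path
from gpu_worker.llm_base import GenerationConfig   # hypothetical import path

config = {
    "model_id": "Qwen/Qwen2.5-7B-Instruct",
    "quantization": "awq",  # optional: awq / gptq / squeezellm / fp8 / int8
    "vllm": {
        "tensor_parallel_size": 1,
        "gpu_memory_utilization": 0.85,
        "max_model_len": 8192,
        "max_num_seqs": 256,
        "enable_prefix_caching": True,
        "enable_chunked_prefill": True,
        "dtype": "auto",
    },
}

async def main() -> None:
    engine = VLLMEngine(config)
    engine.load_model()  # blocking: builds the vLLM LLM instance

    messages = [{"role": "user", "content": "Explain PagedAttention in one sentence."}]
    result = await engine.generate_async(messages, GenerationConfig())
    print(result.text, result.total_tokens)

    engine.unload_model()

asyncio.run(main())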
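
Streaming follows the same pattern with VLLMAsyncEngine, whose stream_generate yields incremental text chunks as generation progresses; again a sketch under the same assumed import paths.

# Illustrative sketch only: import paths are assumptions, as above.
import asyncio

from gpu_worker.vllm_engine import VLLMAsyncEngine   # hypothetical import path
from gpu_worker.llm_base import GenerationConfig     # hypothetical import path

async def stream_demo() -> None:
    engine = VLLMAsyncEngine({"model_id": "Qwen/Qwen2.5-7B-Instruct", "vllm": {}})
    engine.load_model()

    messages = [{"role": "user", "content": "Write a haiku about GPUs."}]
    async for chunk in engine.stream_generate(messages, GenerationConfig()):
        print(chunk, end="", flush=True)
    print()

    engine.unload_model()

asyncio.run(stream_demo())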