gpu-worker 1.0.0

@@ -0,0 +1,489 @@
+ """
+ SGLang high-performance inference engine.
+
+ Features:
+ - RadixAttention prefix caching (automatically reuses the KV-cache of identical prefixes; well suited to RAG/Agent workloads)
+ - PagedAttention memory management (efficient GPU memory utilization, 80%+)
+ - Continuous batching
+ - Chunked prefill
+ - 3.1x throughput improvement (vs. the 2024 vLLM baseline)
+
+ Reference: https://github.com/sgl-project/sglang
+ """
+ from typing import Dict, Any, List, Optional, AsyncIterator
+ import logging
+ import asyncio
+ import time
+
+ from .llm_base import LLMBaseEngine, LLMBackend, GenerationConfig, GenerationResult
+
+ logger = logging.getLogger(__name__)
+
+
+ class SGLangEngine(LLMBaseEngine):
+     """
+     High-performance LLM inference engine built on SGLang.
+
+     Core SGLang advantages:
+     1. RadixAttention - automatic prefix caching; shares the KV-cache of system prompts / conversation history
+     2. Pure Python implementation (<4K lines of core code), easy to debug and customize
+     3. Native support for structured output and constrained decoding
+     4. Supports mainstream model architectures (Llama, Qwen, Mistral, Gemma, etc.)
+     """
+
+     def __init__(self, config: Dict[str, Any]):
+         super().__init__(config)
+         self.backend_type = LLMBackend.SGLANG
+         self.runtime = None
+         self._server_process = None
+         self._sglang_config = config.get("sglang", {})
+
+         # Cache statistics
+         self._cache_hits = 0
+         self._cache_misses = 0
+
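+     # Illustrative only, not part of the public API: a config dict of the shape read by
+     # __init__ and load_model might look like the following. The keys are exactly the ones
+     # consumed in this class; the values are example choices.
+     #
+     #     config = {
+     #         "model_id": "Qwen/Qwen2.5-7B-Instruct",
+     #         "quantization": "fp8",               # optional: int8 / fp8 / awq / gptq
+     #         "sglang": {
+     #             "tp_size": 1,
+     #             "mem_fraction_static": 0.85,
+     #             "chunked_prefill_size": 8192,
+     #             "enable_prefix_caching": True,
+     #             "max_running_requests": 32,
+     #             "context_length": 8192,
+     #             "use_server_mode": False,        # True = connect to an existing server
+     #             "server_url": "http://localhost:30000",
+     #         },
+     #     }
+     #     engine = SGLangEngine(config)
+     #     engine.load_model()
+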
+     def load_model(self) -> None:
+         """Load the model into the SGLang runtime."""
+         try:
+             import sglang as sgl
+         except ImportError:
+             raise ImportError(
+                 "SGLang not installed. Please install with:\n"
+                 " pip install sglang[all]\n"
+                 "Or for minimal install:\n"
+                 " pip install sglang"
+             )
+
+         model_id = self.config.get("model_id", "Qwen/Qwen2.5-7B-Instruct")
+         logger.info(f"Loading model with SGLang: {model_id}")
+
+         # Merge configuration
+         tp_size = self._sglang_config.get("tp_size", 1)
+         mem_fraction = self._sglang_config.get("mem_fraction_static", 0.85)
+         chunked_prefill = self._sglang_config.get("chunked_prefill_size", 8192)
+         enable_prefix_caching = self._sglang_config.get("enable_prefix_caching", True)
+         max_batch_size = self._sglang_config.get("max_running_requests", 32)
+         context_length = self._sglang_config.get("context_length", 8192)
+
+         # SGLang runtime configuration
+         runtime_config = {
+             "model_path": model_id,
+             "tp_size": tp_size,
+             "mem_fraction_static": mem_fraction,
+             "chunked_prefill_size": chunked_prefill,
+             "max_running_requests": max_batch_size,
+             "context_length": context_length,
+             "disable_radix_cache": not enable_prefix_caching,
+         }
+
+         # Quantization configuration
+         quantization = self.config.get("quantization")
+         if quantization:
+             if quantization in ["int8", "fp8", "awq", "gptq"]:
+                 runtime_config["quantization"] = quantization
+                 logger.info(f"Using quantization: {quantization}")
+
+         # Choose launch mode
+         if self._sglang_config.get("use_server_mode", False):
+             # Server mode - connect to an already running SGLang server
+             self._connect_to_server()
+         else:
+             # Embedded mode - start the runtime directly
+             self._start_embedded_runtime(model_id, runtime_config)
+
+         # Load the tokenizer (used for message formatting)
+         self._load_tokenizer(model_id)
+
+         self.loaded = True
+         logger.info("SGLang engine loaded successfully")
+         logger.info(f" - Tensor Parallel: {tp_size}")
+         logger.info(f" - Memory Fraction: {mem_fraction}")
+         logger.info(f" - Prefix Caching: {enable_prefix_caching}")
+         logger.info(f" - Max Batch Size: {max_batch_size}")
+
+     def _connect_to_server(self) -> None:
+         """Connect to an already running SGLang server."""
+         from sglang import RuntimeEndpoint
+
+         server_url = self._sglang_config.get("server_url", "http://localhost:30000")
+         logger.info(f"Connecting to SGLang server at {server_url}")
+
+         self.runtime = RuntimeEndpoint(server_url)
+
+     def _start_embedded_runtime(self, model_id: str, config: Dict[str, Any]) -> None:
+         """Start the embedded runtime."""
+         try:
+             import sglang as sgl
+
+             self.runtime = sgl.Runtime(**config)
+             sgl.set_default_backend(self.runtime)
+             logger.info("Started embedded SGLang runtime")
+
+         except Exception as e:
+             logger.warning(f"Failed to start embedded runtime: {e}")
+             logger.info("Falling back to server mode...")
+             self._start_sglang_server(model_id, config)
+
+     def _start_sglang_server(self, model_id: str, config: Dict[str, Any]) -> None:
+         """Start the SGLang server process."""
+         import subprocess
+
+         port = self._sglang_config.get("port", 30000)
+
+         cmd = [
+             "python", "-m", "sglang.launch_server",
+             "--model-path", model_id,
+             "--port", str(port),
+             "--mem-fraction-static", str(config.get("mem_fraction_static", 0.85)),
+         ]
+
+         if config.get("tp_size", 1) > 1:
+             cmd.extend(["--tp", str(config["tp_size"])])
+
+         if config.get("context_length"):
+             cmd.extend(["--context-length", str(config["context_length"])])
+
+         if not config.get("disable_radix_cache", False):
+             cmd.append("--enable-radix-cache")
+
+         logger.info(f"Starting SGLang server: {' '.join(cmd)}")
+         self._server_process = subprocess.Popen(
+             cmd,
+             stdout=subprocess.PIPE,
+             stderr=subprocess.PIPE
+         )
+
+         # Wait for the server to start
+         self._wait_for_server(port)
+
+         from sglang import RuntimeEndpoint
+         self.runtime = RuntimeEndpoint(f"http://localhost:{port}")
+
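+     # For reference, the subprocess launched above is equivalent to starting the server by
+     # hand; the values shown are the defaults assumed in this file, for illustration only:
+     #
+     #     python -m sglang.launch_server \
+     #         --model-path Qwen/Qwen2.5-7B-Instruct \
+     #         --port 30000 \
+     #         --mem-fraction-static 0.85 \
+     #         --enable-radix-cache
+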
+     def _wait_for_server(self, port: int, timeout: int = 120) -> None:
+         """Wait until the server is ready."""
+         import urllib.request
+         import urllib.error
+
+         url = f"http://localhost:{port}/health"
+         start_time = time.time()
+
+         while time.time() - start_time < timeout:
+             try:
+                 urllib.request.urlopen(url, timeout=1)
+                 logger.info("SGLang server is ready")
+                 return
+             except (urllib.error.URLError, urllib.error.HTTPError):
+                 time.sleep(2)
+
+         raise RuntimeError(f"SGLang server failed to start within {timeout}s")
+
+     def _load_tokenizer(self, model_id: str) -> None:
+         """Load the tokenizer."""
+         try:
+             from transformers import AutoTokenizer
+             self.tokenizer = AutoTokenizer.from_pretrained(
+                 model_id,
+                 trust_remote_code=True
+             )
+         except Exception as e:
+             logger.warning(f"Failed to load tokenizer: {e}")
+             self.tokenizer = None
+
+     async def generate_async(
+         self,
+         messages: List[Dict[str, str]],
+         config: Optional[GenerationConfig] = None
+     ) -> GenerationResult:
+         """Asynchronous generation."""
+         if config is None:
+             config = GenerationConfig()
+
+         start_time = time.time()
+
+         # Try the SGLang function API first
+         try:
+             result = await self._generate_with_sglang_api(messages, config)
+         except Exception as e:
+             logger.warning(f"SGLang function API failed: {e}, falling back to HTTP API")
+             result = await self._generate_with_http_api(messages, config)
+
+         latency_ms = (time.time() - start_time) * 1000
+         logger.debug(f"Generation completed in {latency_ms:.2f}ms")
+
+         return result
+
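+     # Illustrative usage sketch (not executed here), given an engine constructed as sketched
+     # earlier; assumes GenerationConfig accepts the fields this class reads (max_tokens, etc.):
+     #
+     #     messages = [
+     #         {"role": "system", "content": "You are a helpful assistant."},
+     #         {"role": "user", "content": "Summarize RadixAttention in one sentence."},
+     #     ]
+     #     result = asyncio.run(engine.generate_async(messages, GenerationConfig(max_tokens=128)))
+     #     print(result.text, result.cached_tokens)
+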
+     async def _generate_with_sglang_api(
+         self,
+         messages: List[Dict[str, str]],
+         config: GenerationConfig
+     ) -> GenerationResult:
+         """Generate using the native SGLang API."""
+         import sglang as sgl
+
+         @sgl.function
+         def chat_completion(s, messages_list, max_tokens, temperature, top_p, top_k):
+             for msg in messages_list:
+                 role = msg.get("role", "user")
+                 content = msg.get("content", "")
+                 if role == "system":
+                     s += sgl.system(content)
+                 elif role == "user":
+                     s += sgl.user(content)
+                 elif role == "assistant":
+                     s += sgl.assistant(content)
+
+             s += sgl.assistant(sgl.gen(
+                 "response",
+                 max_tokens=max_tokens,
+                 temperature=temperature,
+                 top_p=top_p,
+                 top_k=top_k,
+             ))
+
+         # Run in a thread pool
+         loop = asyncio.get_event_loop()
+         state = await loop.run_in_executor(
+             None,
+             lambda: chat_completion.run(
+                 messages_list=messages,
+                 max_tokens=config.max_tokens,
+                 temperature=config.temperature,
+                 top_p=config.top_p,
+                 top_k=config.top_k,
+             )
+         )
+
+         response_text = state["response"]
+
+         # Collect meta information
+         meta = {}
+         if hasattr(state, 'get_meta_info'):
+             meta = {
+                 "prompt_tokens": state.get_meta_info("prompt_tokens", 0),
+                 "completion_tokens": state.get_meta_info("completion_tokens", 0),
+                 "cached_tokens": state.get_meta_info("cached_tokens", 0),
+             }
+
+         return GenerationResult(
+             text=response_text,
+             prompt_tokens=meta.get("prompt_tokens", 0),
+             completion_tokens=meta.get("completion_tokens", 0),
+             total_tokens=meta.get("prompt_tokens", 0) + meta.get("completion_tokens", 0),
+             cached_tokens=meta.get("cached_tokens", 0),
+             finish_reason="stop"
+         )
+
+     async def _generate_with_http_api(
+         self,
+         messages: List[Dict[str, str]],
+         config: GenerationConfig
+     ) -> GenerationResult:
+         """Generate via the HTTP API (OpenAI-compatible endpoint)."""
+         import aiohttp
+
+         server_url = self._sglang_config.get("server_url", "http://localhost:30000")
+
+         payload = {
+             "model": self.config.get("model_id"),
+             "messages": messages,
+             "max_tokens": config.max_tokens,
+             "temperature": config.temperature,
+             "top_p": config.top_p,
+         }
+
+         if config.stop_sequences:
+             payload["stop"] = config.stop_sequences
+
+         async with aiohttp.ClientSession() as session:
+             async with session.post(
+                 f"{server_url}/v1/chat/completions",
+                 json=payload,
+                 timeout=aiohttp.ClientTimeout(total=300)
+             ) as response:
+                 if response.status != 200:
+                     error_text = await response.text()
+                     raise RuntimeError(f"SGLang API error: {error_text}")
+
+                 result = await response.json()
+
+                 if "error" in result:
+                     raise RuntimeError(f"SGLang API error: {result['error']}")
+
+                 choice = result.get("choices", [{}])[0]
+                 message = choice.get("message", {})
+                 usage = result.get("usage", {})
+
+                 return GenerationResult(
+                     text=message.get("content", ""),
+                     prompt_tokens=usage.get("prompt_tokens", 0),
+                     completion_tokens=usage.get("completion_tokens", 0),
+                     total_tokens=usage.get("total_tokens", 0),
+                     finish_reason=choice.get("finish_reason", "stop"),
+                     cached_tokens=usage.get("cached_tokens", 0)
+                 )
+
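+     # The endpoint used above is OpenAI-compatible, so the same request can be issued directly
+     # when debugging. Sketch only; server URL and model name are the defaults assumed elsewhere
+     # in this file:
+     #
+     #     curl http://localhost:30000/v1/chat/completions \
+     #         -H "Content-Type: application/json" \
+     #         -d '{"model": "Qwen/Qwen2.5-7B-Instruct",
+     #              "messages": [{"role": "user", "content": "Hello"}],
+     #              "max_tokens": 64}'
+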
+     async def batch_generate(
+         self,
+         batch_messages: List[List[Dict[str, str]]],
+         config: Optional[GenerationConfig] = None
+     ) -> List[GenerationResult]:
+         """Batch generation - leverages SGLang's continuous batching."""
+         if config is None:
+             config = GenerationConfig()
+
+         # Run all requests concurrently; SGLang batches them automatically
+         tasks = [
+             self.generate_async(messages, config)
+             for messages in batch_messages
+         ]
+
+         results = await asyncio.gather(*tasks, return_exceptions=True)
+
+         outputs = []
+         for result in results:
+             if isinstance(result, Exception):
+                 logger.error(f"Batch generation error: {result}")
+                 outputs.append(GenerationResult(
+                     text="",
+                     prompt_tokens=0,
+                     completion_tokens=0,
+                     total_tokens=0,
+                     finish_reason="error"
+                 ))
+             else:
+                 outputs.append(result)
+
+         return outputs
+
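+     # Illustrative usage sketch: several independent conversations submitted at once; the
+     # concurrent generate_async calls let the runtime batch them.
+     #
+     #     batch = [
+     #         [{"role": "user", "content": "Question 1"}],
+     #         [{"role": "user", "content": "Question 2"}],
+     #     ]
+     #     results = asyncio.run(engine.batch_generate(batch))
+     #     texts = [r.text for r in results]
+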
+     async def stream_generate(
+         self,
+         messages: List[Dict[str, str]],
+         config: Optional[GenerationConfig] = None
+     ) -> AsyncIterator[str]:
+         """Streaming generation."""
+         if config is None:
+             config = GenerationConfig()
+
+         import aiohttp
+         import json
+
+         server_url = self._sglang_config.get("server_url", "http://localhost:30000")
+
+         payload = {
+             "model": self.config.get("model_id"),
+             "messages": messages,
+             "max_tokens": config.max_tokens,
+             "temperature": config.temperature,
+             "top_p": config.top_p,
+             "stream": True,
+         }
+
+         if config.stop_sequences:
+             payload["stop"] = config.stop_sequences
+
+         try:
+             async with aiohttp.ClientSession() as session:
+                 async with session.post(
+                     f"{server_url}/v1/chat/completions",
+                     json=payload,
+                     timeout=aiohttp.ClientTimeout(total=300)
+                 ) as response:
+                     if response.status != 200:
+                         error_text = await response.text()
+                         raise RuntimeError(f"SGLang stream API error: {error_text}")
+
+                     async for line in response.content:
+                         line = line.decode("utf-8").strip()
+                         if not line or not line.startswith("data: "):
+                             continue
+
+                         data = line[6:]
+                         if data == "[DONE]":
+                             break
+
+                         try:
+                             chunk = json.loads(data)
+                             delta = chunk.get("choices", [{}])[0].get("delta", {})
+                             if "content" in delta:
+                                 yield delta["content"]
+                         except json.JSONDecodeError:
+                             continue
+
+         except Exception as e:
+             logger.error(f"Stream generation failed: {e}")
+             # Fallback to non-streaming
+             result = await self.generate_async(messages, config)
+             yield result.text
+
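+     # Illustrative usage sketch: consuming the stream from within a coroutine.
+     #
+     #     async def print_stream():
+     #         async for chunk in engine.stream_generate(messages):
+     #             print(chunk, end="", flush=True)
+     #
+     #     asyncio.run(print_stream())
+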
+     def supports_streaming(self) -> bool:
+         """Streaming output is supported."""
+         return True
+
+     def supports_prefix_caching(self) -> bool:
+         """Whether prefix caching (RadixAttention) is enabled."""
+         return self._sglang_config.get("enable_prefix_caching", True)
+
+     def supports_batch_inference(self) -> bool:
+         """Batch inference is supported."""
+         return True
+
+     def unload_model(self) -> None:
+         """Unload the model."""
+         if self.runtime:
+             try:
+                 if hasattr(self.runtime, 'shutdown'):
+                     self.runtime.shutdown()
+             except Exception as e:
+                 logger.warning(f"Error shutting down runtime: {e}")
+             self.runtime = None
+
+         if self._server_process:
+             import subprocess  # needed in this scope for subprocess.TimeoutExpired
+             self._server_process.terminate()
+             try:
+                 self._server_process.wait(timeout=10)
+             except subprocess.TimeoutExpired:
+                 self._server_process.kill()
+             self._server_process = None
+
+         if self.tokenizer:
+             del self.tokenizer
+             self.tokenizer = None
+
+         import torch
+         if torch.cuda.is_available():
+             torch.cuda.empty_cache()
+
+         self.loaded = False
+         logger.info("SGLang engine unloaded")
+
+     def get_cache_stats(self) -> Dict[str, Any]:
+         """Return prefix-cache statistics."""
+         stats = {
+             "hits": self._cache_hits,
+             "misses": self._cache_misses,
+             "hit_rate": self._cache_hits / max(1, self._cache_hits + self._cache_misses)
+         }
+
+         # Try to pull more detailed statistics from the runtime
+         if self.runtime and hasattr(self.runtime, 'get_server_info'):
+             try:
+                 server_info = self.runtime.get_server_info()
+                 if "cache_stats" in server_info:
+                     stats.update(server_info["cache_stats"])
+             except Exception:
+                 pass
+
+         return stats
+
+     def get_status(self) -> Dict[str, Any]:
+         """Return engine status."""
+         status = super().get_status()
+         status["cache_stats"] = self.get_cache_stats()
+         status["features"] = [
+             "paged_attention",
+             "radix_attention",
+             "continuous_batching",
+             "chunked_prefill",
+             "streaming",
+         ]
+         return status