gpu-worker 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +115 -0
- package/api_client.py +288 -0
- package/batch_processor.py +436 -0
- package/bin/gpu-worker.js +275 -0
- package/cli.py +729 -0
- package/config.2gb.yaml +32 -0
- package/config.8gb.yaml +29 -0
- package/config.example.yaml +72 -0
- package/config.py +213 -0
- package/direct_server.py +140 -0
- package/distributed/__init__.py +35 -0
- package/distributed/grpc_server.py +561 -0
- package/distributed/kv_cache.py +555 -0
- package/distributed/model_shard.py +465 -0
- package/distributed/session.py +455 -0
- package/engines/__init__.py +215 -0
- package/engines/base.py +57 -0
- package/engines/image_gen.py +83 -0
- package/engines/llm.py +97 -0
- package/engines/llm_base.py +216 -0
- package/engines/llm_sglang.py +489 -0
- package/engines/llm_vllm.py +539 -0
- package/engines/speculative.py +513 -0
- package/engines/vision.py +139 -0
- package/machine_id.py +200 -0
- package/main.py +521 -0
- package/package.json +64 -0
- package/requirements-sglang.txt +12 -0
- package/requirements-vllm.txt +15 -0
- package/requirements.txt +35 -0
- package/scripts/postinstall.js +60 -0
- package/setup.py +43 -0
|
@@ -0,0 +1,489 @@
|
|
|
1
|
+
"""
|
|
2
|
+
SGLang 高性能推理引擎
|
|
3
|
+
|
|
4
|
+
特点:
|
|
5
|
+
- RadixAttention 前缀缓存(自动复用相同前缀的 KV-Cache,适合 RAG/Agent)
|
|
6
|
+
- PagedAttention 内存管理(高效显存利用率 80%+)
|
|
7
|
+
- 连续批处理 (Continuous Batching)
|
|
8
|
+
- 分块预填充 (Chunked Prefill)
|
|
9
|
+
- 3.1x 吞吐量提升(相比 vLLM 2024 基准)
|
|
10
|
+
|
|
11
|
+
参考:https://github.com/sgl-project/sglang
|
|
12
|
+
"""
|
|
13
|
+
from typing import Dict, Any, List, Optional, AsyncIterator
|
|
14
|
+
import logging
|
|
15
|
+
import asyncio
|
|
16
|
+
import time
|
|
17
|
+
|
|
18
|
+
from .llm_base import LLMBaseEngine, LLMBackend, GenerationConfig, GenerationResult
|
|
19
|
+
|
|
20
|
+
# Module-level logger used by the SGLang engine for load/startup progress,
# fallback warnings and generation errors.
logger = logging.getLogger(__name__)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
class SGLangEngine(LLMBaseEngine):
    """
    High-performance LLM inference engine built on SGLang.

    Key SGLang advantages (as described by the upstream project):
    1. RadixAttention - automatic prefix caching that shares the KV-cache of
       common system prompts / conversation history.
    2. A compact pure-Python core that is easy to debug and customize.
    3. Native support for structured output and constrained decoding.
    4. Support for mainstream model architectures (Llama, Qwen, Mistral, Gemma, ...).

    Backend selection happens in ``load_model``:

    * ``sglang.use_server_mode`` true: connect to an already running server at
      ``sglang.server_url``.
    * otherwise: start an embedded ``sgl.Runtime``; if that fails, launch
      ``sglang.launch_server`` as a child process and connect to it.

    ``generate_async`` tries SGLang's native frontend API first and falls back
    to the OpenAI-compatible HTTP endpoint; ``stream_generate`` always uses the
    HTTP endpoint.
    """

    # Server URL used when neither the config nor the backend provides one.
    _DEFAULT_SERVER_URL = "http://localhost:30000"
    # Quantization methods forwarded to SGLang; other values are ignored.
    _SUPPORTED_QUANTIZATION = ("int8", "fp8", "awq", "gptq")

    def __init__(self, config: Dict[str, Any]):
        super().__init__(config)
        self.backend_type = LLMBackend.SGLANG
        # Embedded ``sgl.Runtime`` or a ``RuntimeEndpoint`` client, once loaded.
        self.runtime = None
        # Child process handle when this engine launched the SGLang server itself.
        self._server_process = None
        # Base URL of the server this engine launched (None otherwise).
        self._launched_server_url: Optional[str] = None
        # ``or {}`` also covers an explicit ``sglang: null`` in the config.
        self._sglang_config = config.get("sglang") or {}

        # Prefix-cache statistics, counted per request: a "hit" is a request
        # for which the backend reported cached prompt tokens.
        self._cache_hits = 0
        self._cache_misses = 0

    def load_model(self) -> None:
        """
        Load the configured model into SGLang.

        Reads ``model_id`` and ``quantization`` from the engine config and the
        runtime options from its ``sglang`` sub-section, starts (or connects
        to) the backend, then loads the HuggingFace tokenizer.

        Raises:
            ImportError: if the ``sglang`` package is not installed.
        """
        try:
            import sglang  # noqa: F401  (availability check only)
        except ImportError as err:
            raise ImportError(
                "SGLang not installed. Please install with:\n"
                "  pip install sglang[all]\n"
                "Or for minimal install:\n"
                "  pip install sglang"
            ) from err

        model_id = self.config.get("model_id", "Qwen/Qwen2.5-7B-Instruct")
        logger.info(f"Loading model with SGLang: {model_id}")

        # Merge the user configuration with defaults.
        tp_size = self._sglang_config.get("tp_size", 1)
        mem_fraction = self._sglang_config.get("mem_fraction_static", 0.85)
        chunked_prefill = self._sglang_config.get("chunked_prefill_size", 8192)
        enable_prefix_caching = self._sglang_config.get("enable_prefix_caching", True)
        max_batch_size = self._sglang_config.get("max_running_requests", 32)
        context_length = self._sglang_config.get("context_length", 8192)

        # Keyword arguments for sgl.Runtime (also used to build the CLI of a
        # launched server).
        runtime_config = {
            "model_path": model_id,
            "tp_size": tp_size,
            "mem_fraction_static": mem_fraction,
            "chunked_prefill_size": chunked_prefill,
            "max_running_requests": max_batch_size,
            "context_length": context_length,
            "disable_radix_cache": not enable_prefix_caching,
        }

        # Quantization settings.
        quantization = self.config.get("quantization")
        if quantization:
            if quantization in self._SUPPORTED_QUANTIZATION:
                runtime_config["quantization"] = quantization
                logger.info(f"Using quantization: {quantization}")
            else:
                logger.warning(
                    f"Unsupported quantization '{quantization}' for SGLang; "
                    "loading without quantization"
                )

        if self._sglang_config.get("use_server_mode", False):
            # Server mode: connect to an SGLang server that is already running.
            self._connect_to_server()
        else:
            # Embedded mode: start the runtime in this process.
            self._start_embedded_runtime(model_id, runtime_config)

        # The tokenizer is used for message formatting.
        self._load_tokenizer(model_id)

        self.loaded = True
        logger.info("SGLang engine loaded successfully")
        logger.info(f"  - Tensor Parallel: {tp_size}")
        logger.info(f"  - Memory Fraction: {mem_fraction}")
        logger.info(f"  - Prefix Caching: {enable_prefix_caching}")
        logger.info(f"  - Max Batch Size: {max_batch_size}")

    def _connect_to_server(self) -> None:
        """Connect to an already running SGLang server at ``sglang.server_url``."""
        from sglang import RuntimeEndpoint

        server_url = self._sglang_config.get("server_url", self._DEFAULT_SERVER_URL)
        logger.info(f"Connecting to SGLang server at {server_url}")

        self.runtime = RuntimeEndpoint(server_url)

    def _start_embedded_runtime(self, model_id: str, config: Dict[str, Any]) -> None:
        """
        Start an in-process ``sgl.Runtime`` and make it the default backend.

        Falls back to launching a server subprocess if the runtime cannot be
        created.
        """
        try:
            import sglang as sgl

            self.runtime = sgl.Runtime(**config)
            sgl.set_default_backend(self.runtime)
            logger.info("Started embedded SGLang runtime")

        except Exception as e:
            logger.warning(f"Failed to start embedded runtime: {e}")
            logger.info("Falling back to server mode...")
            self._start_sglang_server(model_id, config)

    def _start_sglang_server(self, model_id: str, config: Dict[str, Any]) -> None:
        """
        Launch ``sglang.launch_server`` as a child process and connect to it.

        Args:
            model_id: model path / HuggingFace id passed as ``--model-path``.
            config: runtime options assembled by ``load_model``.

        Raises:
            RuntimeError: if the server exits or does not become healthy in
                time; the child process is stopped before re-raising.
        """
        import subprocess
        import sys

        port = self._sglang_config.get("port", 30000)

        cmd = [
            # Use the running interpreter so the child sees the same
            # environment (and installed sglang) as this worker.
            sys.executable, "-m", "sglang.launch_server",
            "--model-path", model_id,
            "--port", str(port),
            "--mem-fraction-static", str(config.get("mem_fraction_static", 0.85)),
        ]

        if config.get("tp_size", 1) > 1:
            cmd.extend(["--tp", str(config["tp_size"])])

        if config.get("context_length"):
            cmd.extend(["--context-length", str(config["context_length"])])

        if config.get("chunked_prefill_size") is not None:
            cmd.extend(["--chunked-prefill-size", str(config["chunked_prefill_size"])])

        if config.get("max_running_requests") is not None:
            cmd.extend(["--max-running-requests", str(config["max_running_requests"])])

        if config.get("quantization"):
            cmd.extend(["--quantization", str(config["quantization"])])

        # Radix (prefix) caching is SGLang's default; the server only exposes
        # an opt-out flag, so pass it only when caching is disabled.
        if config.get("disable_radix_cache", False):
            cmd.append("--disable-radix-cache")

        logger.info(f"Starting SGLang server: {' '.join(cmd)}")
        # stdout/stderr are inherited: PIPEs that nobody reads would fill up
        # and stall a long-running, chatty server.
        self._server_process = subprocess.Popen(cmd)

        # Wait for the server to come up; do not leave a half-started child
        # behind on failure.
        try:
            self._wait_for_server(port)
        except BaseException:
            self._terminate_server_process()
            raise

        from sglang import RuntimeEndpoint
        server_url = f"http://localhost:{port}"
        self.runtime = RuntimeEndpoint(server_url)
        self._launched_server_url = server_url

    def _wait_for_server(self, port: int, timeout: int = 120) -> None:
        """
        Poll ``http://localhost:{port}/health`` until it answers without error.

        Args:
            port: port the launched server listens on.
            timeout: overall time budget in seconds.

        Raises:
            RuntimeError: if the launched process exits first, or the endpoint
                is not healthy within ``timeout`` seconds.
        """
        import http.client
        import urllib.request

        url = f"http://localhost:{port}/health"
        deadline = time.monotonic() + timeout

        while time.monotonic() < deadline:
            process = self._server_process
            if process is not None and process.poll() is not None:
                raise RuntimeError(
                    f"SGLang server exited with code {process.returncode} "
                    "before becoming ready"
                )
            try:
                with urllib.request.urlopen(url, timeout=1):
                    pass
            except (OSError, http.client.HTTPException):
                # URLError/HTTPError, refused/reset connections and socket
                # timeouts are all OSError subclasses: not ready yet.
                time.sleep(2)
            else:
                logger.info("SGLang server is ready")
                return

        raise RuntimeError(f"SGLang server failed to start within {timeout}s")

    def _terminate_server_process(self) -> None:
        """Stop the launched server: ``terminate()``, then ``kill()`` after 10 s."""
        import subprocess

        process = self._server_process
        if process is None:
            return
        self._server_process = None

        if process.poll() is None:
            process.terminate()
            try:
                process.wait(timeout=10)
            except subprocess.TimeoutExpired:
                process.kill()
                process.wait()

    def _load_tokenizer(self, model_id: str) -> None:
        """Load the HuggingFace tokenizer; on failure log and set it to None."""
        try:
            from transformers import AutoTokenizer
            self.tokenizer = AutoTokenizer.from_pretrained(
                model_id,
                trust_remote_code=True
            )
        except Exception as e:
            logger.warning(f"Failed to load tokenizer: {e}")
            self.tokenizer = None

    def _resolve_server_url(self) -> str:
        """
        Return the base URL of the OpenAI-compatible HTTP endpoint.

        Precedence: explicit ``sglang.server_url`` config; the server this
        engine launched itself; a ``url``/``base_url`` string attribute on the
        runtime object, if it has one; otherwise ``http://localhost:30000``.
        """
        configured = self._sglang_config.get("server_url")
        if configured:
            return configured.rstrip("/")
        if self._launched_server_url:
            return self._launched_server_url
        for attr in ("url", "base_url"):
            value = getattr(self.runtime, attr, None)
            if isinstance(value, str) and value:
                return value.rstrip("/")
        return self._DEFAULT_SERVER_URL

    def _build_chat_payload(
        self,
        messages: List[Dict[str, str]],
        config: GenerationConfig,
        stream: bool = False
    ) -> Dict[str, Any]:
        """Build a ``/v1/chat/completions`` request body from ``config``."""
        payload: Dict[str, Any] = {
            "model": self.config.get("model_id"),
            "messages": messages,
            "max_tokens": config.max_tokens,
            "temperature": config.temperature,
            "top_p": config.top_p,
        }
        if stream:
            payload["stream"] = True
        if config.stop_sequences:
            payload["stop"] = config.stop_sequences
        return payload

    def _record_cache_usage(self, result: GenerationResult) -> None:
        """Count a request as a cache hit if the backend reported cached tokens."""
        if (getattr(result, "cached_tokens", 0) or 0) > 0:
            self._cache_hits += 1
        else:
            self._cache_misses += 1

    async def generate_async(
        self,
        messages: List[Dict[str, str]],
        config: Optional[GenerationConfig] = None
    ) -> GenerationResult:
        """
        Generate a chat completion asynchronously.

        Tries SGLang's native frontend API first and falls back to the
        OpenAI-compatible HTTP API if that raises.
        """
        if config is None:
            config = GenerationConfig()

        start_time = time.perf_counter()

        try:
            result = await self._generate_with_sglang_api(messages, config)
        except Exception as e:
            logger.warning(f"SGLang function API failed: {e}, falling back to HTTP API")
            result = await self._generate_with_http_api(messages, config)

        self._record_cache_usage(result)

        latency_ms = (time.perf_counter() - start_time) * 1000
        logger.debug(f"Generation completed in {latency_ms:.2f}ms")

        return result

    @staticmethod
    def _extract_meta_info(state: Any) -> Dict[str, Any]:
        """
        Best-effort read of the meta info SGLang recorded for the ``response``
        generation; returns ``{}`` when it is unavailable.
        """
        getter = getattr(state, "get_meta_info", None)
        if getter is None:
            return {}
        try:
            info = getter("response")
        except Exception as e:
            logger.debug(f"SGLang meta info unavailable: {e}")
            return {}
        return info if isinstance(info, dict) else {}

    async def _generate_with_sglang_api(
        self,
        messages: List[Dict[str, str]],
        config: GenerationConfig
    ) -> GenerationResult:
        """Generate using SGLang's native frontend API (``sgl.function``)."""
        import sglang as sgl

        # The ``gen_`` prefix keeps these values from colliding with the
        # sampling keywords that ``SglFunction.run`` accepts itself (e.g.
        # ``temperature``, ``top_p``, ``top_k``). Otherwise ``run`` would
        # consume them instead of forwarding them here.
        @sgl.function
        def chat_completion(s, messages_list, gen_max_tokens, gen_temperature,
                            gen_top_p, gen_top_k, gen_stop):
            for msg in messages_list:
                role = msg.get("role", "user")
                content = msg.get("content", "")
                if role == "system":
                    s += sgl.system(content)
                elif role == "user":
                    s += sgl.user(content)
                elif role == "assistant":
                    s += sgl.assistant(content)

            s += sgl.assistant(sgl.gen(
                "response",
                max_tokens=gen_max_tokens,
                temperature=gen_temperature,
                top_p=gen_top_p,
                top_k=gen_top_k,
                stop=gen_stop,
            ))

        # ``run`` blocks, so execute it in the default thread pool. The backend
        # is passed explicitly because only embedded mode sets a default one.
        loop = asyncio.get_running_loop()
        state = await loop.run_in_executor(
            None,
            lambda: chat_completion.run(
                messages_list=messages,
                gen_max_tokens=config.max_tokens,
                gen_temperature=config.temperature,
                gen_top_p=config.top_p,
                gen_top_k=config.top_k,
                gen_stop=config.stop_sequences or None,
                backend=self.runtime,
            )
        )

        response_text = state["response"]
        meta = self._extract_meta_info(state)

        prompt_tokens = meta.get("prompt_tokens") or 0
        completion_tokens = meta.get("completion_tokens") or 0
        finish = meta.get("finish_reason")
        if isinstance(finish, dict):
            finish = finish.get("type")

        return GenerationResult(
            text=response_text,
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=prompt_tokens + completion_tokens,
            cached_tokens=meta.get("cached_tokens") or 0,
            finish_reason=finish or "stop"
        )

    async def _generate_with_http_api(
        self,
        messages: List[Dict[str, str]],
        config: GenerationConfig
    ) -> GenerationResult:
        """
        Generate via the OpenAI-compatible ``/v1/chat/completions`` endpoint.

        Raises:
            RuntimeError: on a non-200 status or an ``error`` field in the body.
        """
        import aiohttp

        payload = self._build_chat_payload(messages, config)

        async with aiohttp.ClientSession() as session:
            async with session.post(
                f"{self._resolve_server_url()}/v1/chat/completions",
                json=payload,
                timeout=aiohttp.ClientTimeout(total=300)
            ) as response:
                if response.status != 200:
                    error_text = await response.text()
                    raise RuntimeError(f"SGLang API error: {error_text}")

                result = await response.json()

        if "error" in result:
            raise RuntimeError(f"SGLang API error: {result['error']}")

        # Guard against empty ``choices`` and explicit nulls in the response.
        choice = (result.get("choices") or [{}])[0] or {}
        message = choice.get("message") or {}
        usage = result.get("usage") or {}

        prompt_tokens = usage.get("prompt_tokens") or 0
        completion_tokens = usage.get("completion_tokens") or 0
        # Cached-token count may be top-level or under OpenAI-style
        # ``prompt_tokens_details``.
        cached_tokens = usage.get("cached_tokens")
        if cached_tokens is None:
            cached_tokens = (usage.get("prompt_tokens_details") or {}).get("cached_tokens")

        return GenerationResult(
            text=message.get("content") or "",
            prompt_tokens=prompt_tokens,
            completion_tokens=completion_tokens,
            total_tokens=usage.get("total_tokens") or (prompt_tokens + completion_tokens),
            finish_reason=choice.get("finish_reason") or "stop",
            cached_tokens=cached_tokens or 0
        )

    async def batch_generate(
        self,
        batch_messages: List[List[Dict[str, str]]],
        config: Optional[GenerationConfig] = None
    ) -> List[GenerationResult]:
        """
        Generate for several conversations concurrently.

        Requests are issued together so SGLang's continuous batching can group
        them. A failed request yields an empty result with
        ``finish_reason="error"`` at the same index instead of failing the batch.
        """
        if config is None:
            config = GenerationConfig()

        results = await asyncio.gather(
            *(self.generate_async(messages, config) for messages in batch_messages),
            return_exceptions=True
        )

        outputs = []
        for result in results:
            # CancelledError is not an Exception subclass on Python 3.8+.
            if isinstance(result, (Exception, asyncio.CancelledError)):
                logger.error(f"Batch generation error: {result}")
                outputs.append(GenerationResult(
                    text="",
                    prompt_tokens=0,
                    completion_tokens=0,
                    total_tokens=0,
                    finish_reason="error"
                ))
            else:
                outputs.append(result)

        return outputs

    async def stream_generate(
        self,
        messages: List[Dict[str, str]],
        config: Optional[GenerationConfig] = None
    ) -> AsyncIterator[str]:
        """
        Stream completion text chunks from the HTTP endpoint (SSE).

        If the request fails before any text was produced, falls back to a
        single non-streaming generation and yields its full text. If it fails
        after partial output, the error is re-raised, because regenerating
        would duplicate text the consumer already received.
        """
        if config is None:
            config = GenerationConfig()

        import json
        import aiohttp

        payload = self._build_chat_payload(messages, config, stream=True)
        url = f"{self._resolve_server_url()}/v1/chat/completions"
        emitted = False

        try:
            async with aiohttp.ClientSession() as session:
                async with session.post(
                    url,
                    json=payload,
                    timeout=aiohttp.ClientTimeout(total=300)
                ) as response:
                    if response.status != 200:
                        error_text = await response.text()
                        raise RuntimeError(f"SGLang stream API error: {error_text}")

                    async for raw_line in response.content:
                        line = raw_line.decode("utf-8").strip()
                        if not line.startswith("data: "):
                            continue

                        data = line[len("data: "):]
                        if data == "[DONE]":
                            break

                        try:
                            chunk = json.loads(data)
                        except json.JSONDecodeError:
                            continue
                        if not isinstance(chunk, dict):
                            continue

                        choice = (chunk.get("choices") or [{}])[0] or {}
                        # The first delta usually carries only the role, with
                        # ``content`` absent or null.
                        content = (choice.get("delta") or {}).get("content")
                        if content:
                            emitted = True
                            yield content

        except Exception as e:
            if emitted:
                logger.error(f"Stream generation failed after partial output: {e}")
                raise
            logger.error(f"Stream generation failed: {e}")
            # Fall back to non-streaming generation.
            result = await self.generate_async(messages, config)
            yield result.text

    def supports_streaming(self) -> bool:
        """Streaming output is supported."""
        return True

    def supports_prefix_caching(self) -> bool:
        """Whether prefix caching (RadixAttention) is enabled in the config."""
        return self._sglang_config.get("enable_prefix_caching", True)

    def supports_batch_inference(self) -> bool:
        """Batch inference is supported."""
        return True

    def unload_model(self) -> None:
        """
        Release the backend, any launched server process and the tokenizer.

        Also empties the CUDA cache when torch with CUDA is available.
        """
        if self.runtime:
            try:
                if hasattr(self.runtime, 'shutdown'):
                    self.runtime.shutdown()
            except Exception as e:
                logger.warning(f"Error shutting down runtime: {e}")
            self.runtime = None

        self._terminate_server_process()
        self._launched_server_url = None

        self.tokenizer = None

        try:
            import torch
        except ImportError:
            torch = None
        if torch is not None and torch.cuda.is_available():
            torch.cuda.empty_cache()

        self.loaded = False
        logger.info("SGLang engine unloaded")

    def get_cache_stats(self) -> Dict[str, Any]:
        """
        Return prefix-cache statistics.

        Local per-request hit/miss counters are overlaid with a ``cache_stats``
        mapping from the runtime's ``get_server_info()``, if it provides one.
        """
        stats = {
            "hits": self._cache_hits,
            "misses": self._cache_misses,
            "hit_rate": self._cache_hits / max(1, self._cache_hits + self._cache_misses)
        }

        # Try to fetch more detailed statistics from the runtime (best effort).
        if self.runtime and hasattr(self.runtime, 'get_server_info'):
            try:
                server_info = self.runtime.get_server_info()
                if isinstance(server_info, dict) and "cache_stats" in server_info:
                    stats.update(server_info["cache_stats"])
            except Exception as e:
                logger.debug(f"Could not fetch SGLang server cache stats: {e}")

        return stats

    def get_status(self) -> Dict[str, Any]:
        """Return the base engine status plus cache stats and enabled features."""
        status = super().get_status()
        status["cache_stats"] = self.get_cache_stats()

        features = ["paged_attention"]
        # Only advertise RadixAttention when prefix caching is actually on.
        if self.supports_prefix_caching():
            features.append("radix_attention")
        features.extend(["continuous_batching", "chunked_prefill", "streaming"])
        status["features"] = features
        return status
|