gpu-worker 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +115 -0
- package/api_client.py +288 -0
- package/batch_processor.py +436 -0
- package/bin/gpu-worker.js +275 -0
- package/cli.py +729 -0
- package/config.2gb.yaml +32 -0
- package/config.8gb.yaml +29 -0
- package/config.example.yaml +72 -0
- package/config.py +213 -0
- package/direct_server.py +140 -0
- package/distributed/__init__.py +35 -0
- package/distributed/grpc_server.py +561 -0
- package/distributed/kv_cache.py +555 -0
- package/distributed/model_shard.py +465 -0
- package/distributed/session.py +455 -0
- package/engines/__init__.py +215 -0
- package/engines/base.py +57 -0
- package/engines/image_gen.py +83 -0
- package/engines/llm.py +97 -0
- package/engines/llm_base.py +216 -0
- package/engines/llm_sglang.py +489 -0
- package/engines/llm_vllm.py +539 -0
- package/engines/speculative.py +513 -0
- package/engines/vision.py +139 -0
- package/machine_id.py +200 -0
- package/main.py +521 -0
- package/package.json +64 -0
- package/requirements-sglang.txt +12 -0
- package/requirements-vllm.txt +15 -0
- package/requirements.txt +35 -0
- package/scripts/postinstall.js +60 -0
- package/setup.py +43 -0
|
@@ -0,0 +1,455 @@
|
|
|
1
|
+
"""
|
|
2
|
+
分布式推理会话
|
|
3
|
+
|
|
4
|
+
参考 Petals InferenceSession 设计,实现:
|
|
5
|
+
- 跨 Worker 的推理会话管理
|
|
6
|
+
- 故障检测与自动恢复
|
|
7
|
+
- Server-to-Server 直连传输
|
|
8
|
+
"""
|
|
9
|
+
import asyncio
|
|
10
|
+
import threading
|
|
11
|
+
from concurrent.futures import Future
|
|
12
|
+
import uuid
|
|
13
|
+
import time
|
|
14
|
+
import logging
|
|
15
|
+
from typing import Dict, Any, List, Optional, Tuple
|
|
16
|
+
from dataclasses import dataclass, field
|
|
17
|
+
from enum import Enum
|
|
18
|
+
|
|
19
|
+
import aiohttp
|
|
20
|
+
|
|
21
|
+
# 本地导入
|
|
22
|
+
import sys
|
|
23
|
+
import os
|
|
24
|
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
|
25
|
+
from common.data_structures import (
|
|
26
|
+
BlockRange,
|
|
27
|
+
WorkerInfo,
|
|
28
|
+
InferenceState,
|
|
29
|
+
SessionConfig,
|
|
30
|
+
WorkerState,
|
|
31
|
+
)
|
|
32
|
+
from common.serialization import serialize_tensor, deserialize_tensor
|
|
33
|
+
|
|
34
|
+
logger = logging.getLogger(__name__)
|
|
35
|
+
|
|
36
|
+
def _run_coroutine_in_new_thread(coro):
|
|
37
|
+
future: Future = Future()
|
|
38
|
+
|
|
39
|
+
def runner() -> None:
|
|
40
|
+
try:
|
|
41
|
+
future.set_result(asyncio.run(coro))
|
|
42
|
+
except BaseException as exc:
|
|
43
|
+
future.set_exception(exc)
|
|
44
|
+
|
|
45
|
+
threading.Thread(target=runner, daemon=True).start()
|
|
46
|
+
return future.result()
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class SessionState(Enum):
    """Lifecycle state used by both worker-level and distributed sessions."""

    INITIALIZING = "initializing"  # constructed, not yet connected
    READY = "ready"                # connected and able to accept requests
    ACTIVE = "active"              # a forward/step has been started
    ERROR = "error"                # a connect, setup or forward failure occurred
    CLOSED = "closed"              # close() has been called
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
@dataclass
class WorkerSession:
    """Inference session with a single worker.

    Holds the HTTP client used to talk to one worker (``worker_info``) plus
    this session's state and sequence position. When ``next_session`` is set,
    its endpoint and session id are sent with every forward request
    (presumably so the worker can pass its output straight to the next hop,
    i.e. server-to-server transfer).
    """

    worker_info: WorkerInfo
    session_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    state: SessionState = SessionState.INITIALIZING

    # Progress of this session
    position: int = 0
    history: Optional[Any] = None  # input history intended for failure recovery (not read by this class)

    # Downstream session, used for server-to-server forwarding
    next_session: Optional["WorkerSession"] = None

    # HTTP client; created by connect(), released by close()
    _http_session: Optional[aiohttp.ClientSession] = None

    async def connect(self, timeout: float = 30.0) -> None:
        """Open the HTTP client and check that the worker answers ``GET /health``.

        Args:
            timeout: Total timeout in seconds applied to requests made by the
                client (only used when the client is created here).

        Raises:
            ConnectionError: If the health check raises or returns a non-200
                status. The session is then left in ``ERROR`` state with its
                HTTP client closed.
        """
        if self._http_session is None:
            self._http_session = aiohttp.ClientSession(
                timeout=aiohttp.ClientTimeout(total=timeout)
            )

        try:
            async with self._http_session.get(
                f"{self.worker_info.api_endpoint}/health"
            ) as response:
                if response.status != 200:
                    raise ConnectionError(
                        f"Worker health check failed: {response.status}"
                    )
        except Exception as e:
            self.state = SessionState.ERROR
            # Release the client so a failed connect does not leak an open
            # aiohttp session.
            client, self._http_session = self._http_session, None
            try:
                await client.close()
            except Exception as close_err:
                logger.warning(f"Error closing HTTP client: {close_err}")
            raise ConnectionError(f"Failed to connect to worker: {e}") from e

        self.state = SessionState.READY
        logger.info(f"Connected to worker {self.worker_info.worker_id}")

    async def forward(
        self,
        hidden_states: Any,
        position: int,
        kv_cache_keys: Optional[List[str]] = None,
    ) -> Tuple[Any, List[str]]:
        """Run this worker's part of the forward pass.

        Args:
            hidden_states: Input hidden states (tensor). If it has a ``shape``
                attribute, ``shape[1]`` is used as the number of new positions;
                otherwise one position is assumed.
            position: Sequence position the input starts at.
            kv_cache_keys: KV-cache keys to send to the worker.

        Returns:
            ``(output_hidden_states, updated_kv_keys)`` as returned by the worker.

        Raises:
            RuntimeError: If the session is not ``READY``/``ACTIVE``, the
                request fails, or the worker answers with a non-200 status.
                In the last two cases the session is moved to ``ERROR``.
        """
        if self.state not in (SessionState.READY, SessionState.ACTIVE):
            raise RuntimeError(f"Session not ready: {self.state}")

        self.state = SessionState.ACTIVE

        payload = {
            "session_id": self.session_id,
            "input": serialize_tensor(hidden_states),
            "position": position,
            "kv_cache_keys": kv_cache_keys or [],
            "blocks": self.worker_info.blocks.to_dict() if self.worker_info.blocks else None,
        }

        # Attach next-hop routing information when there is a downstream session.
        if self.next_session is not None:
            payload["next_worker"] = {
                "address": self.next_session.worker_info.api_endpoint,
                "session_id": self.next_session.session_id,
            }

        failed_status: Optional[int] = None
        error_text = ""
        try:
            async with self._http_session.post(
                f"{self.worker_info.api_endpoint}/inference/forward",
                json=payload,
            ) as response:
                if response.status != 200:
                    failed_status = response.status
                    error_text = await response.text()
                else:
                    result = await response.json()
        except Exception as e:
            self.state = SessionState.ERROR
            raise RuntimeError(f"Forward error: {e}") from e

        if failed_status is not None:
            # Raised outside the try so the broad handler above does not
            # re-wrap it as "Forward error: Forward failed: ...".
            self.state = SessionState.ERROR
            raise RuntimeError(f"Forward failed: {error_text}")

        output_hidden_states = deserialize_tensor(result["output"])
        updated_kv_keys = result.get("kv_cache_keys", [])

        consumed = hidden_states.shape[1] if hasattr(hidden_states, "shape") else 1
        self.position = position + consumed

        return output_hidden_states, updated_kv_keys

    async def close(self) -> None:
        """Ask the worker to drop this session, then release the HTTP client.

        The remote close request is best-effort: failures and non-200
        responses are logged, not raised. The session ends in ``CLOSED`` state.
        """
        if self._http_session is not None:
            try:
                async with self._http_session.post(
                    f"{self.worker_info.api_endpoint}/inference/close",
                    json={"session_id": self.session_id},
                ) as response:
                    if response.status != 200:
                        logger.warning(
                            f"Worker {self.worker_info.worker_id} returned status "
                            f"{response.status} when closing session {self.session_id}"
                        )
            except Exception as e:
                logger.warning(f"Error closing worker session: {e}")

            await self._http_session.close()
            self._http_session = None

        self.state = SessionState.CLOSED

    def __enter__(self) -> "WorkerSession":
        return self

    def __exit__(self, *exc) -> None:
        """Close the session from synchronous code.

        Without a running event loop ``close()`` is driven by ``asyncio.run``.
        Otherwise it runs on a fresh loop in a helper thread, and the caller
        (and its loop) blocks until it finishes.

        NOTE(review): the aiohttp client was created on the loop that ran
        ``connect()``, and aiohttp clients are generally tied to the loop that
        created them, so closing from a different loop may not work cleanly.
        Prefer ``async with`` from async code.
        """
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            asyncio.run(self.close())
        else:
            _run_coroutine_in_new_thread(self.close())

    async def __aenter__(self) -> "WorkerSession":
        """Connect (if no HTTP client exists yet) and return the session."""
        if self._http_session is None:
            await self.connect()
        return self

    async def __aexit__(self, *exc) -> None:
        await self.close()
|
|
196
|
+
|
|
197
|
+
|
|
198
|
+
class DistributedInferenceSession:
    """Inference session spanning several workers.

    Modeled on Petals' ``InferenceSession``. ``route`` lists the workers in
    pipeline order. ``setup()`` opens one ``WorkerSession`` per worker and
    links each to its successor. ``step()`` pushes the inputs through every
    worker in turn, retrying a failing worker before giving up.
    """

    def __init__(
        self,
        config: SessionConfig,
        route: List[WorkerInfo],
    ):
        """
        Args:
            config: Session configuration. This class reads ``max_length``,
                ``connect_timeout`` and ``max_retries`` from it.
            route: Workers to run, in pipeline order.
        """
        self.config = config
        self.route = route

        self.session_id = str(uuid.uuid4())
        self.state = SessionState.INITIALIZING

        # One WorkerSession per route entry (filled in by setup()).
        self._worker_sessions: List[WorkerSession] = []

        # Tokens processed so far and the configured cap.
        self._position = 0
        self._max_length = config.max_length

        # Running counters reported by get_stats().
        self._stats = {
            "total_tokens": 0,
            "total_steps": 0,
            "total_latency_ms": 0,
            "retries": 0,
        }

    @property
    def position(self) -> int:
        """Number of tokens consumed by this session so far."""
        return self._position

    @position.setter
    def position(self, value: int) -> None:
        # Keep every worker session's position in sync with ours.
        self._position = value
        for session in self._worker_sessions:
            session.position = value

    async def setup(self) -> None:
        """Connect to every worker on the route and chain the sessions.

        Raises:
            Exception: Whatever a worker's ``connect()`` raised. All sessions
                opened so far (including the failing one) are closed first and
                the session is left in ``ERROR`` state.
        """
        logger.info(f"Setting up distributed session with {len(self.route)} workers")

        try:
            for worker_info in self.route:
                session = WorkerSession(worker_info=worker_info)
                # Track the session before connecting so the cleanup below
                # also covers a worker whose connect() fails.
                self._worker_sessions.append(session)
                await session.connect(timeout=self.config.connect_timeout)

            # Link each session to its successor (for server-to-server transfer).
            for upstream, downstream in zip(self._worker_sessions, self._worker_sessions[1:]):
                upstream.next_session = downstream

            self.state = SessionState.READY
            logger.info("Distributed session setup complete")

        except Exception:
            self.state = SessionState.ERROR
            await self._close_worker_sessions()
            raise

    async def _close_worker_sessions(self) -> None:
        """Close every worker session, logging (not raising) failures, then forget them."""
        for session in self._worker_sessions:
            try:
                await session.close()
            except Exception as e:
                logger.warning(f"Error closing session: {e}")
        self._worker_sessions.clear()

    async def step(
        self,
        inputs: Any,
        kv_cache_keys: Optional[List[str]] = None,
    ) -> Any:
        """Run one inference step through every worker on the route.

        Args:
            inputs: Input tensor. If it has a ``shape`` attribute, ``shape[1]``
                is taken as the number of new tokens; otherwise one token is assumed.
            kv_cache_keys: KV-cache keys sent to the first worker; the keys each
                worker returns are sent to the next one.

        Returns:
            The hidden states returned by the last worker (``inputs`` unchanged
            if there are no worker sessions).

        Raises:
            RuntimeError: If the session is not ``READY``/``ACTIVE``, or a worker
                keeps failing after all retries.
            ValueError: If the step would exceed ``max_length``.
        """
        if self.state not in (SessionState.READY, SessionState.ACTIVE):
            raise RuntimeError(f"Session not ready: {self.state}")

        self.state = SessionState.ACTIVE
        # perf_counter is monotonic, so wall-clock adjustments cannot skew latency stats.
        step_start = time.perf_counter()

        n_input_tokens = inputs.shape[1] if hasattr(inputs, "shape") else 1
        if self._position + n_input_tokens > self._max_length:
            raise ValueError(
                f"Maximum length exceeded: {self._position} + {n_input_tokens} > {self._max_length}"
            )

        hidden_states = inputs
        current_kv_keys = kv_cache_keys or []

        # Iterate by index: failure recovery may replace the session at an index.
        for idx in range(len(self._worker_sessions)):
            hidden_states, current_kv_keys = await self._forward_with_retries(
                idx, hidden_states, current_kv_keys
            )

        self._position += n_input_tokens
        self._stats["total_tokens"] += n_input_tokens
        self._stats["total_steps"] += 1
        self._stats["total_latency_ms"] += (time.perf_counter() - step_start) * 1000

        return hidden_states

    async def _forward_with_retries(
        self,
        idx: int,
        hidden_states: Any,
        kv_keys: List[str],
    ) -> Tuple[Any, List[str]]:
        """Forward through worker ``idx``.

        Retries with linear backoff and escalates to ``_handle_failure`` once
        the attempts are used up.
        """
        # Always make at least one attempt. With max_retries <= 0 the worker
        # would otherwise be skipped and its input passed through unchanged.
        max_attempts = max(1, self.config.max_retries)

        for attempt in range(max_attempts):
            session = self._worker_sessions[idx]
            if attempt > 0 and session.state is SessionState.ERROR:
                # WorkerSession.forward() marks the session ERROR on failure and
                # rejects any call made in that state. Clear it, or the retry
                # would fail immediately without contacting the worker.
                session.state = SessionState.READY
            try:
                return await session.forward(
                    hidden_states,
                    position=self._position,
                    kv_cache_keys=kv_keys,
                )
            except Exception as e:
                logger.warning(
                    f"Worker {session.worker_info.worker_id} failed "
                    f"(attempt {attempt + 1}/{max_attempts}): {e}"
                )
                self._stats["retries"] += 1

                if attempt + 1 < max_attempts:
                    await asyncio.sleep(0.5 * (attempt + 1))  # linear backoff
                    continue

                await self._handle_failure(idx, e)
                # Only reached if _handle_failure recovered. A recovery would
                # install a replacement session at idx, so re-read it rather
                # than reusing the failed (closed) one.
                return await self._worker_sessions[idx].forward(
                    hidden_states,
                    position=self._position,
                    kv_cache_keys=kv_keys,
                )

    async def _handle_failure(
        self,
        failed_idx: int,
        error: Exception
    ) -> None:
        """Handle a worker that exhausted its retries.

        Closes the failed worker session. Replacing it with another worker is
        not implemented, so this marks the whole session ``ERROR`` and raises.

        Args:
            failed_idx: Index of the failed worker session.
            error: The last error raised by that worker.

        Raises:
            RuntimeError: Always, until recovery is implemented.
        """
        failed_session = self._worker_sessions[failed_idx]
        logger.error(
            f"Worker {failed_session.worker_info.worker_id} failed: {error}. "
            f"Attempting recovery..."
        )

        try:
            await failed_session.close()
        except Exception as e:
            logger.warning(f"Error closing failed worker session: {e}")

        # TODO: obtain a replacement worker from the scheduler (requires the
        # scheduler service integration) and install it at failed_idx.
        self.state = SessionState.ERROR
        raise RuntimeError(
            f"Worker failure recovery not implemented. "
            f"Failed worker: {failed_session.worker_info.worker_id}"
        )

    async def close(self) -> None:
        """Close all worker sessions (errors are logged) and mark this session CLOSED."""
        await self._close_worker_sessions()
        self.state = SessionState.CLOSED
        logger.info(f"Distributed session closed. Stats: {self._stats}")

    def get_stats(self) -> Dict[str, Any]:
        """Return a copy of the counters, plus averages once a step has completed."""
        stats = self._stats.copy()
        if stats["total_steps"] > 0:
            stats["avg_latency_ms"] = stats["total_latency_ms"] / stats["total_steps"]
            stats["tokens_per_second"] = (
                stats["total_tokens"] / (stats["total_latency_ms"] / 1000)
                if stats["total_latency_ms"] > 0 else 0
            )
        return stats

    async def __aenter__(self):
        await self.setup()
        return self

    async def __aexit__(self, *exc):
        await self.close()
|
|
396
|
+
|
|
397
|
+
|
|
398
|
+
class SessionManager:
    """Lifecycle manager for multiple distributed inference sessions.

    Changes to the session table are made while holding ``self._lock``.
    """

    def __init__(self, max_sessions: int = 100):
        self.max_sessions = max_sessions
        self._sessions: Dict[str, DistributedInferenceSession] = {}
        self._lock = asyncio.Lock()

    async def create_session(
        self,
        config: SessionConfig,
        route: List[WorkerInfo],
    ) -> DistributedInferenceSession:
        """Create, set up and register a new session.

        When the table is full, CLOSED/ERROR sessions are purged first. Setup
        runs while the lock is held, so concurrent creations are serialized.

        Raises:
            RuntimeError: If ``max_sessions`` live sessions already exist.
            Exception: Whatever ``DistributedInferenceSession.setup()`` raises.
        """
        async with self._lock:
            if len(self._sessions) >= self.max_sessions:
                await self._cleanup_expired_sessions()

                if len(self._sessions) >= self.max_sessions:
                    raise RuntimeError(f"Maximum sessions reached: {self.max_sessions}")

            session = DistributedInferenceSession(config, route)
            await session.setup()
            self._sessions[session.session_id] = session

        return session

    async def get_session(self, session_id: str) -> Optional[DistributedInferenceSession]:
        """Return the session with ``session_id``, or None if unknown."""
        return self._sessions.get(session_id)

    async def close_session(self, session_id: str) -> None:
        """Unregister and close a session; unknown ids are ignored."""
        async with self._lock:
            session = self._sessions.pop(session_id, None)
            if session:
                await session.close()

    async def _cleanup_expired_sessions(self) -> None:
        """Close and unregister sessions in CLOSED or ERROR state.

        The caller must already hold ``self._lock``. ``asyncio.Lock`` is not
        reentrant, so routing through ``close_session`` here would deadlock.
        """
        expired = [
            sid for sid, session in self._sessions.items()
            if session.state in (SessionState.CLOSED, SessionState.ERROR)
        ]
        for sid in expired:
            session = self._sessions.pop(sid)
            try:
                await session.close()
            except Exception as e:
                logger.warning(f"Error closing expired session {sid}: {e}")

    async def close_all(self) -> None:
        """Close every session, logging individual failures, and empty the table."""
        async with self._lock:
            for sid, session in list(self._sessions.items()):
                try:
                    await session.close()
                except Exception as e:
                    logger.warning(f"Error closing session {sid}: {e}")
            self._sessions.clear()
|
|
@@ -0,0 +1,215 @@
|
|
|
1
|
+
"""引擎模块
|
|
2
|
+
|
|
3
|
+
支持多种推理后端:
|
|
4
|
+
- llm: 原生 Transformers 后端(兼容性好)
|
|
5
|
+
- llm_sglang: SGLang 高性能后端(推荐,RadixAttention)
|
|
6
|
+
- llm_vllm: vLLM 高性能后端(PagedAttention)
|
|
7
|
+
- llm_vllm_async: vLLM 异步引擎(支持流式)
|
|
8
|
+
- image_gen: 图像生成引擎
|
|
9
|
+
- vision: 视觉模型引擎
|
|
10
|
+
|
|
11
|
+
使用示例:
|
|
12
|
+
# 方式1: 直接使用引擎类
|
|
13
|
+
from engines import LLMEngine
|
|
14
|
+
engine = LLMEngine(config)
|
|
15
|
+
|
|
16
|
+
# 方式2: 通过配置选择后端
|
|
17
|
+
from engines import create_llm_engine
|
|
18
|
+
engine = create_llm_engine({"backend": "sglang", "model_id": "..."})
|
|
19
|
+
|
|
20
|
+
# 方式3: 通过类型名获取
|
|
21
|
+
from engines import get_engine
|
|
22
|
+
EngineClass = get_engine("llm_sglang")
|
|
23
|
+
engine = EngineClass(config)
|
|
24
|
+
"""
|
|
25
|
+
from typing import Dict, Any, Optional
|
|
26
|
+
|
|
27
|
+
from .base import BaseEngine
|
|
28
|
+
from .llm import LLMEngine
|
|
29
|
+
from .llm_base import LLMBaseEngine, LLMBackend, GenerationConfig, GenerationResult
|
|
30
|
+
from .image_gen import ImageGenEngine
|
|
31
|
+
from .vision import VisionEngine
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
# 延迟导入高性能引擎(可能需要额外依赖)
|
|
35
|
+
def _get_sglang_engine():
|
|
36
|
+
from .llm_sglang import SGLangEngine
|
|
37
|
+
return SGLangEngine
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
def _get_vllm_engine():
|
|
41
|
+
from .llm_vllm import VLLMEngine
|
|
42
|
+
return VLLMEngine
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _get_vllm_async_engine():
|
|
46
|
+
from .llm_vllm import VLLMAsyncEngine
|
|
47
|
+
return VLLMAsyncEngine
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
# 引擎注册表
|
|
51
|
+
ENGINE_REGISTRY = dict(
    # Native backends, imported eagerly at module load
    llm=LLMEngine,
    image_gen=ImageGenEngine,
    vision=VisionEngine,
)

# High-performance backends: imported on first use by get_engine(), then
# cached in ENGINE_REGISTRY
_LAZY_ENGINES = dict(
    llm_sglang=_get_sglang_engine,
    llm_vllm=_get_vllm_engine,
    llm_vllm_async=_get_vllm_async_engine,
)

# Backend name -> engine type name
_BACKEND_ALIASES = dict(
    native="llm",
    transformers="llm",
    sglang="llm_sglang",
    vllm="llm_vllm",
    vllm_async="llm_vllm_async",
)
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def get_engine(engine_type: str) -> type:
    """Resolve an engine type name (or backend alias) to its engine class.

    Engines already in ``ENGINE_REGISTRY`` are returned directly. Lazily
    loaded engines are imported on first request and then cached in
    ``ENGINE_REGISTRY``.

    Args:
        engine_type: Engine type name such as ``"llm_sglang"``, or an alias
            from ``_BACKEND_ALIASES`` such as ``"sglang"``.

    Returns:
        The engine class.

    Raises:
        ValueError: The name is neither registered nor lazily loadable.
        ImportError: The engine's extra dependencies are not installed.
    """
    resolved = _BACKEND_ALIASES.get(engine_type, engine_type)

    if resolved in ENGINE_REGISTRY:
        return ENGINE_REGISTRY[resolved]

    if resolved not in _LAZY_ENGINES:
        raise ValueError(f"Unknown engine type: {resolved}")

    loader = _LAZY_ENGINES[resolved]
    try:
        engine_class = loader()
    except ImportError as e:
        raise ImportError(
            f"Engine '{resolved}' requires additional dependencies: {e}"
        )

    ENGINE_REGISTRY[resolved] = engine_class  # cache for later lookups
    return engine_class
|
|
106
|
+
|
|
107
|
+
|
|
108
|
+
def create_llm_engine(config: Dict[str, Any]) -> LLMBaseEngine:
    """Create an LLM engine, choosing the implementation from ``config["backend"]``.

    This is the recommended way to build an LLM engine.

    Args:
        config: Engine configuration. It should contain:
            - backend: backend name ("native", "sglang", "vllm", "vllm_async"
              or any engine type name starting with "llm"). Defaults to
              "native" when missing or None. Matching ignores case and
              surrounding whitespace.
            - model_id: model ID
            - any backend-specific settings
            The whole dict is passed to the engine constructor.

    Returns:
        An LLM engine instance.

    Raises:
        ValueError: If the backend does not resolve to an LLM engine type.
        ImportError: If the chosen backend's extra dependencies are missing.

    Example:
        config = {
            "backend": "sglang",
            "model_id": "Qwen/Qwen2.5-7B-Instruct",
            "sglang": {
                "tp_size": 1,
                "mem_fraction_static": 0.85,
                "enable_prefix_caching": True,
            }
        }
        engine = create_llm_engine(config)
    """
    raw_backend = config.get("backend")
    # config.get returns None when the key is present with a null value;
    # treat that like a missing key instead of crashing on None.lower().
    if raw_backend is None:
        raw_backend = "native"
    backend = str(raw_backend).strip().lower()

    engine_type = _BACKEND_ALIASES.get(backend, backend)

    # Only LLM engine types are acceptable here.
    if not engine_type.startswith("llm"):
        raise ValueError(f"'{backend}' is not a valid LLM backend")

    engine_class = get_engine(engine_type)
    return engine_class(config)
|
|
150
|
+
|
|
151
|
+
|
|
152
|
+
def list_engines() -> dict:
    """Report every known engine and whether it can be used.

    Returns:
        A dict mapping engine name to a status dict:
        ``{"available": True, "loaded": bool}`` or
        ``{"available": False, "error": str}``.
        ``loaded`` is True only for engines already in ``ENGINE_REGISTRY``.
        Probing a lazy engine imports its module but does not register it.
    """
    engines = {name: {"available": True, "loaded": True} for name in ENGINE_REGISTRY}

    for name, loader in _LAZY_ENGINES.items():
        if name in engines:
            continue
        try:
            loader()
        except ImportError as e:
            engines[name] = {"available": False, "error": str(e)}
        except Exception as e:
            # An optional backend may fail at import time with errors other
            # than ImportError. Report it as unavailable rather than letting
            # it crash the whole listing.
            engines[name] = {"available": False, "error": f"{type(e).__name__}: {e}"}
        else:
            engines[name] = {"available": True, "loaded": False}

    return engines
|
|
170
|
+
|
|
171
|
+
|
|
172
|
+
def get_recommended_backend() -> str:
    """Return the preferred LLM backend name that can actually be imported.

    Priority: SGLang > vLLM > native. A backend whose import fails is skipped.
    A missing dependency (ImportError) is skipped silently; any other import
    failure is logged as a warning.

    Returns:
        ``"sglang"``, ``"vllm"`` or ``"native"``.
    """
    import logging

    candidates = (
        ("sglang", _get_sglang_engine),
        ("vllm", _get_vllm_engine),
    )
    for backend, loader in candidates:
        try:
            loader()
        except ImportError:
            continue
        except Exception as e:
            # Don't let a broken optional backend crash the fallback logic.
            logging.getLogger(__name__).warning(
                "Backend %r failed to import (%s: %s); trying the next one",
                backend, type(e).__name__, e,
            )
            continue
        return backend

    # Fall back to the native Transformers backend.
    return "native"
|
|
194
|
+
|
|
195
|
+
|
|
196
|
+
__all__ = [
    # Base classes and shared types
    "BaseEngine",
    "LLMBaseEngine",
    "LLMBackend",
    "GenerationConfig",
    "GenerationResult",
    # Concrete engines (imported eagerly)
    "LLMEngine",
    "ImageGenEngine",
    "VisionEngine",
    # Registry and factory helpers
    "ENGINE_REGISTRY",
    "get_engine",
    "create_llm_engine",
    "list_engines",
    "get_recommended_backend",
]
|