gpu-worker 1.0.0

@@ -0,0 +1,455 @@
+ """
+ Distributed inference session.
+
+ Modeled on the Petals InferenceSession design; implements:
+ - inference session management across workers
+ - failure detection and automatic recovery
+ - direct server-to-server transfer
+ """
+ import asyncio
+ import threading
+ from concurrent.futures import Future
+ import uuid
+ import time
+ import logging
+ from typing import Dict, Any, List, Optional, Tuple
+ from dataclasses import dataclass, field
+ from enum import Enum
+
+ import aiohttp
+
+ # Local imports
+ import sys
+ import os
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
+ from common.data_structures import (
+     BlockRange,
+     WorkerInfo,
+     InferenceState,
+     SessionConfig,
+     WorkerState,
+ )
+ from common.serialization import serialize_tensor, deserialize_tensor
+
+ logger = logging.getLogger(__name__)
+
+ def _run_coroutine_in_new_thread(coro):
+     """Run a coroutine to completion on a fresh event loop in a new thread.
+
+     Used to close async resources from synchronous code that may itself be
+     running inside an event loop (see WorkerSession.__exit__).
+     """
+     future: Future = Future()
+
+     def runner() -> None:
+         try:
+             future.set_result(asyncio.run(coro))
+         except BaseException as exc:
+             future.set_exception(exc)
+
+     threading.Thread(target=runner, daemon=True).start()
+     return future.result()
+
+
+ class SessionState(Enum):
+     """Session lifecycle states."""
+     INITIALIZING = "initializing"
+     READY = "ready"
+     ACTIVE = "active"
+     ERROR = "error"
+     CLOSED = "closed"
+
+
+ @dataclass
+ class WorkerSession:
+     """
+     Inference session with a single worker.
+
+     Manages the connection to, and the state of, one specific worker.
+     """
+     worker_info: WorkerInfo
+     session_id: str = field(default_factory=lambda: str(uuid.uuid4()))
+     state: SessionState = SessionState.INITIALIZING
+
+     # Session state
+     position: int = 0
+     history: Optional[Any] = None  # input history kept for failure recovery
+
+     # Next-hop session (for server-to-server transfer)
+     next_session: Optional["WorkerSession"] = None
+
+     # Connection
+     _http_session: Optional[aiohttp.ClientSession] = None
+
+     async def connect(self, timeout: float = 30.0) -> None:
+         """Establish the connection to the worker."""
+         if self._http_session is None:
+             self._http_session = aiohttp.ClientSession(
+                 timeout=aiohttp.ClientTimeout(total=timeout)
+             )
+
+         # Verify that the worker is reachable
+         try:
+             async with self._http_session.get(
+                 f"{self.worker_info.api_endpoint}/health"
+             ) as response:
+                 if response.status != 200:
+                     raise ConnectionError(
+                         f"Worker health check failed: {response.status}"
+                     )
+         except Exception as e:
+             self.state = SessionState.ERROR
+             raise ConnectionError(f"Failed to connect to worker: {e}") from e
+
+         self.state = SessionState.READY
+         logger.info(f"Connected to worker {self.worker_info.worker_id}")
+
+     async def forward(
+         self,
+         hidden_states: Any,
+         position: int,
+         kv_cache_keys: Optional[List[str]] = None,
+     ) -> Tuple[Any, List[str]]:
+         """
+         Run a forward pass on this worker.
+
+         Args:
+             hidden_states: input hidden states (tensor)
+             position: current sequence position
+             kv_cache_keys: list of KV-cache keys
+
+         Returns:
+             (output_hidden_states, updated_kv_keys)
+         """
+         if self.state not in (SessionState.READY, SessionState.ACTIVE):
+             raise RuntimeError(f"Session not ready: {self.state}")
+
+         self.state = SessionState.ACTIVE
+
+         # Serialize the input
+         serialized_input = serialize_tensor(hidden_states)
+
+         # Build the request
+         payload = {
+             "session_id": self.session_id,
+             "input": serialized_input,
+             "position": position,
+             "kv_cache_keys": kv_cache_keys or [],
+             "blocks": self.worker_info.blocks.to_dict() if self.worker_info.blocks else None,
+         }
+
+         # If there is a next hop, attach routing information
+         if self.next_session:
+             payload["next_worker"] = {
+                 "address": self.next_session.worker_info.api_endpoint,
+                 "session_id": self.next_session.session_id,
+             }
+
+         # Send the request
+         try:
+             async with self._http_session.post(
+                 f"{self.worker_info.api_endpoint}/inference/forward",
+                 json=payload
+             ) as response:
+                 if response.status != 200:
+                     error = await response.text()
+                     raise RuntimeError(f"Forward failed: {error}")
+
+                 result = await response.json()
+
+         except Exception as e:
+             self.state = SessionState.ERROR
+             raise RuntimeError(f"Forward error: {e}") from e
+
+         # Deserialize the output
+         output_hidden_states = deserialize_tensor(result["output"])
+         updated_kv_keys = result.get("kv_cache_keys", [])
+
+         # Advance the position by the number of input tokens
+         self.position = (
+             position + hidden_states.shape[1]
+             if hasattr(hidden_states, "shape")
+             else position + 1
+         )
+
+         return output_hidden_states, updated_kv_keys
+
+     async def close(self) -> None:
+         """Close the session."""
+         if self._http_session:
+             # Tell the worker to close its side of the session
+             try:
+                 async with self._http_session.post(
+                     f"{self.worker_info.api_endpoint}/inference/close",
+                     json={"session_id": self.session_id}
+                 ) as response:
+                     pass
+             except Exception as e:
+                 logger.warning(f"Error closing worker session: {e}")
+
+             await self._http_session.close()
+             self._http_session = None
+
+         self.state = SessionState.CLOSED
+
+     def __enter__(self):
+         return self
+
+     def __exit__(self, *exc):
+         # If no event loop is running we can block here directly; otherwise
+         # close on a fresh loop in a separate thread to avoid re-entering it.
+         try:
+             asyncio.get_running_loop()
+         except RuntimeError:
+             asyncio.run(self.close())
+         else:
+             _run_coroutine_in_new_thread(self.close())
+
+
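+ # Usage sketch (illustrative, not part of the module API): driving a single
+ # WorkerSession by hand. It assumes a reachable worker whose WorkerInfo has a
+ # valid `api_endpoint`, and a tensor-like `hidden_states` produced elsewhere.
+ #
+ #     async def run_single_hop(worker_info: WorkerInfo, hidden_states: Any) -> Any:
+ #         session = WorkerSession(worker_info=worker_info)
+ #         await session.connect(timeout=10.0)
+ #         try:
+ #             output, kv_keys = await session.forward(hidden_states, position=0)
+ #             return output
+ #         finally:
+ #             await session.close()
+
+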
+ class DistributedInferenceSession:
+     """
+     Distributed inference session.
+
+     Manages an inference session that spans multiple workers, modeled on
+     the Petals InferenceSession.
+     """
+
+     def __init__(
+         self,
+         config: SessionConfig,
+         route: List[WorkerInfo],
+     ):
+         """
+         Args:
+             config: session configuration
+             route: inference route (ordered list of workers)
+         """
+         self.config = config
+         self.route = route
+
+         self.session_id = str(uuid.uuid4())
+         self.state = SessionState.INITIALIZING
+
+         # Per-worker sessions
+         self._worker_sessions: List[WorkerSession] = []
+
+         # Inference state
+         self._position = 0
+         self._max_length = config.max_length
+
+         # Statistics
+         self._stats = {
+             "total_tokens": 0,
+             "total_steps": 0,
+             "total_latency_ms": 0,
+             "retries": 0,
+         }
+
+     @property
+     def position(self) -> int:
+         return self._position
+
+     @position.setter
+     def position(self, value: int) -> None:
+         self._position = value
+         for session in self._worker_sessions:
+             session.position = value
+
+     async def setup(self) -> None:
+         """Connect to every worker on the route."""
+         logger.info(f"Setting up distributed session with {len(self.route)} workers")
+
+         try:
+             for worker_info in self.route:
+                 session = WorkerSession(worker_info=worker_info)
+                 await session.connect(timeout=self.config.connect_timeout)
+                 self._worker_sessions.append(session)
+
+             # Chain the sessions (for server-to-server transfer)
+             for i in range(len(self._worker_sessions) - 1):
+                 self._worker_sessions[i].next_session = self._worker_sessions[i + 1]
+
+             self.state = SessionState.READY
+             logger.info("Distributed session setup complete")
+
+         except Exception:
+             self.state = SessionState.ERROR
+             # Clean up the sessions created so far
+             for session in self._worker_sessions:
+                 await session.close()
+             self._worker_sessions.clear()
+             raise
+
+     async def step(
+         self,
+         inputs: Any,
+         kv_cache_keys: Optional[List[str]] = None,
+     ) -> Any:
+         """
+         Run one inference step.
+
+         Args:
+             inputs: input tensor
+             kv_cache_keys: list of KV-cache keys
+
+         Returns:
+             output tensor
+         """
+         if self.state not in (SessionState.READY, SessionState.ACTIVE):
+             raise RuntimeError(f"Session not ready: {self.state}")
+
+         self.state = SessionState.ACTIVE
+         step_start = time.time()
+
+         # Enforce the length limit
+         n_input_tokens = inputs.shape[1] if hasattr(inputs, 'shape') else 1
+         if self._position + n_input_tokens > self._max_length:
+             raise ValueError(
+                 f"Maximum length exceeded: {self._position} + {n_input_tokens} > {self._max_length}"
+             )
+
+         hidden_states = inputs
+         current_kv_keys = kv_cache_keys or []
+
+         # Pass through each worker in turn
+         for i, session in enumerate(self._worker_sessions):
+             for attempt in range(self.config.max_retries):
+                 try:
+                     hidden_states, current_kv_keys = await session.forward(
+                         hidden_states,
+                         position=self._position,
+                         kv_cache_keys=current_kv_keys,
+                     )
+                     break
+
+                 except Exception as e:
+                     logger.warning(
+                         f"Worker {session.worker_info.worker_id} failed "
+                         f"(attempt {attempt + 1}/{self.config.max_retries}): {e}"
+                     )
+                     self._stats["retries"] += 1
+
+                     if attempt + 1 == self.config.max_retries:
+                         # Attempt failure recovery
+                         await self._handle_failure(i, e)
+                         hidden_states, current_kv_keys = await session.forward(
+                             hidden_states,
+                             position=self._position,
+                             kv_cache_keys=current_kv_keys,
+                         )
+                     else:
+                         await asyncio.sleep(0.5 * (attempt + 1))  # linear backoff
+
+         # Update state
+         self._position += n_input_tokens
+         self._stats["total_tokens"] += n_input_tokens
+         self._stats["total_steps"] += 1
+         self._stats["total_latency_ms"] += (time.time() - step_start) * 1000
+
+         return hidden_states
+
+     async def _handle_failure(
+         self,
+         failed_idx: int,
+         error: Exception
+     ) -> None:
+         """
+         Handle a worker failure.
+
+         Args:
+             failed_idx: index of the failed worker
+             error: the triggering exception
+         """
+         failed_session = self._worker_sessions[failed_idx]
+         logger.error(
+             f"Worker {failed_session.worker_info.worker_id} failed: {error}. "
+             f"Attempting recovery..."
+         )
+
+         # Close the failed session
+         await failed_session.close()
+
+         # TODO: obtain a replacement worker from the scheduler;
+         # this requires integrating the scheduler service.
+         raise RuntimeError(
+             f"Worker failure recovery not implemented. "
+             f"Failed worker: {failed_session.worker_info.worker_id}"
+         )
+
+     async def close(self) -> None:
+         """Close the session."""
+         for session in self._worker_sessions:
+             try:
+                 await session.close()
+             except Exception as e:
+                 logger.warning(f"Error closing session: {e}")
+
+         self._worker_sessions.clear()
+         self.state = SessionState.CLOSED
+         logger.info(f"Distributed session closed. Stats: {self._stats}")
+
+     def get_stats(self) -> Dict[str, Any]:
+         """Return session statistics."""
+         stats = self._stats.copy()
+         if stats["total_steps"] > 0:
+             stats["avg_latency_ms"] = stats["total_latency_ms"] / stats["total_steps"]
+             stats["tokens_per_second"] = (
+                 stats["total_tokens"] / (stats["total_latency_ms"] / 1000)
+                 if stats["total_latency_ms"] > 0 else 0
+             )
+         return stats
+
+     async def __aenter__(self):
+         await self.setup()
+         return self
+
+     async def __aexit__(self, *exc):
+         await self.close()
+
+
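+ # Usage sketch (illustrative): running decoding steps through a route of
+ # workers via the async context manager. `config` and `route` are assumed to
+ # come from a scheduler; `first_input` is a tensor-like prompt embedding.
+ #
+ #     async def generate(config: SessionConfig, route: List[WorkerInfo], first_input: Any):
+ #         async with DistributedInferenceSession(config, route) as session:
+ #             hidden = await session.step(first_input)
+ #             # ...sample the next token and feed it back through session.step()
+ #             return session.get_stats()
+
+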
+ class SessionManager:
+     """
+     Session manager.
+
+     Manages the lifecycle of multiple distributed inference sessions.
+     """
+
+     def __init__(self, max_sessions: int = 100):
+         self.max_sessions = max_sessions
+         self._sessions: Dict[str, DistributedInferenceSession] = {}
+         self._lock = asyncio.Lock()
+
+     async def create_session(
+         self,
+         config: SessionConfig,
+         route: List[WorkerInfo],
+     ) -> DistributedInferenceSession:
+         """Create a new session."""
+         async with self._lock:
+             if len(self._sessions) >= self.max_sessions:
+                 # Evict finished sessions first
+                 await self._cleanup_expired_sessions()
+
+                 if len(self._sessions) >= self.max_sessions:
+                     raise RuntimeError(f"Maximum sessions reached: {self.max_sessions}")
+
+             session = DistributedInferenceSession(config, route)
+             await session.setup()
+             self._sessions[session.session_id] = session
+
+             return session
+
+     async def get_session(self, session_id: str) -> Optional[DistributedInferenceSession]:
+         """Look up a session by id."""
+         return self._sessions.get(session_id)
+
+     async def close_session(self, session_id: str) -> None:
+         """Close a session and remove it from the manager."""
+         async with self._lock:
+             session = self._sessions.pop(session_id, None)
+             if session:
+                 await session.close()
+
+     async def _cleanup_expired_sessions(self) -> None:
+         """Evict closed or failed sessions.
+
+         Called with self._lock already held, so it must not go through
+         close_session(), which would re-acquire the non-reentrant lock
+         and deadlock.
+         """
+         expired = [
+             sid for sid, session in self._sessions.items()
+             if session.state in (SessionState.CLOSED, SessionState.ERROR)
+         ]
+         for sid in expired:
+             session = self._sessions.pop(sid, None)
+             if session:
+                 await session.close()
+
+     async def close_all(self) -> None:
+         """Close every session."""
+         async with self._lock:
+             for session in list(self._sessions.values()):
+                 await session.close()
+             self._sessions.clear()
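+
+ # Usage sketch (illustrative): a long-lived manager owning many concurrent
+ # sessions, e.g. inside an API server. `config` and `route` arrive per request.
+ #
+ #     manager = SessionManager(max_sessions=100)
+ #
+ #     async def handle_request(config: SessionConfig, route: List[WorkerInfo]):
+ #         session = await manager.create_session(config, route)
+ #         try:
+ #             ...  # call session.step(...) as tokens arrive
+ #         finally:
+ #             await manager.close_session(session.session_id)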
@@ -0,0 +1,215 @@
+ """Engine module.
+
+ Supported inference backends:
+ - llm: native Transformers backend (best compatibility)
+ - llm_sglang: SGLang high-performance backend (recommended; RadixAttention)
+ - llm_vllm: vLLM high-performance backend (PagedAttention)
+ - llm_vllm_async: vLLM async engine (supports streaming)
+ - image_gen: image generation engine
+ - vision: vision model engine
+
+ Usage examples:
+     # Option 1: use an engine class directly
+     from engines import LLMEngine
+     engine = LLMEngine(config)
+
+     # Option 2: select the backend via configuration
+     from engines import create_llm_engine
+     engine = create_llm_engine({"backend": "sglang", "model_id": "..."})
+
+     # Option 3: look an engine class up by type name
+     from engines import get_engine
+     EngineClass = get_engine("llm_sglang")
+     engine = EngineClass(config)
+ """
+ from typing import Dict, Any, Optional
+
+ from .base import BaseEngine
+ from .llm import LLMEngine
+ from .llm_base import LLMBaseEngine, LLMBackend, GenerationConfig, GenerationResult
+ from .image_gen import ImageGenEngine
+ from .vision import VisionEngine
+
+
+ # Lazily import the high-performance engines (they may need extra dependencies)
+ def _get_sglang_engine():
+     from .llm_sglang import SGLangEngine
+     return SGLangEngine
+
+
+ def _get_vllm_engine():
+     from .llm_vllm import VLLMEngine
+     return VLLMEngine
+
+
+ def _get_vllm_async_engine():
+     from .llm_vllm import VLLMAsyncEngine
+     return VLLMAsyncEngine
+
+
+ # Engine registry
+ ENGINE_REGISTRY = {
+     # Native backends
+     "llm": LLMEngine,
+     "image_gen": ImageGenEngine,
+     "vision": VisionEngine,
+ }
+
+ # High-performance backends (registered lazily)
+ _LAZY_ENGINES = {
+     "llm_sglang": _get_sglang_engine,
+     "llm_vllm": _get_vllm_engine,
+     "llm_vllm_async": _get_vllm_async_engine,
+ }
+
+ # Backend alias map
+ _BACKEND_ALIASES = {
+     "native": "llm",
+     "transformers": "llm",
+     "sglang": "llm_sglang",
+     "vllm": "llm_vllm",
+     "vllm_async": "llm_vllm_async",
+ }
+
+
+ def get_engine(engine_type: str) -> type:
+     """
+     Look up an engine class.
+
+     Args:
+         engine_type: engine type name
+
+     Returns:
+         The engine class.
+
+     Raises:
+         ValueError: unknown engine type
+         ImportError: the engine's dependencies are not installed
+     """
+     # Resolve aliases
+     engine_type = _BACKEND_ALIASES.get(engine_type, engine_type)
+
+     if engine_type in ENGINE_REGISTRY:
+         return ENGINE_REGISTRY[engine_type]
+
+     if engine_type in _LAZY_ENGINES:
+         try:
+             engine_class = _LAZY_ENGINES[engine_type]()
+             ENGINE_REGISTRY[engine_type] = engine_class  # cache it
+             return engine_class
+         except ImportError as e:
+             raise ImportError(
+                 f"Engine '{engine_type}' requires additional dependencies: {e}"
+             ) from e
+
+     raise ValueError(f"Unknown engine type: {engine_type}")
+
+
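+ # Illustrative note: aliases resolve before lookup, and a lazily imported
+ # engine class is cached in ENGINE_REGISTRY on first use, so repeated lookups
+ # are cheap (the example below assumes SGLang is installed).
+ #
+ #     EngineClass = get_engine("sglang")          # same as get_engine("llm_sglang")
+ #     assert get_engine("sglang") is EngineClass  # served from the registry cache
+
+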
+ def create_llm_engine(config: Dict[str, Any]) -> LLMBaseEngine:
+     """
+     Create an LLM engine from a configuration.
+
+     This is the recommended way to create an LLM engine: the implementation
+     is selected automatically from the `backend` field of the configuration.
+
+     Args:
+         config: engine configuration; should contain:
+             - backend: backend type ("native", "sglang", "vllm", "vllm_async")
+             - model_id: model ID
+             - other backend-specific settings
+
+     Returns:
+         An LLM engine instance.
+
+     Example:
+         config = {
+             "backend": "sglang",
+             "model_id": "Qwen/Qwen2.5-7B-Instruct",
+             "sglang": {
+                 "tp_size": 1,
+                 "mem_fraction_static": 0.85,
+                 "enable_prefix_caching": True,
+             }
+         }
+         engine = create_llm_engine(config)
+     """
+     backend = config.get("backend", "native").lower()
+
+     # Resolve the engine type
+     engine_type = _BACKEND_ALIASES.get(backend, backend)
+
+     # Make sure it is an LLM engine
+     if not engine_type.startswith("llm"):
+         raise ValueError(f"'{backend}' is not a valid LLM backend")
+
+     # Look up the engine class
+     engine_class = get_engine(engine_type)
+
+     # Instantiate it
+     return engine_class(config)
+
+
+ def list_engines() -> dict:
+     """List every available engine and its status."""
+     engines = {}
+
+     # Engines already registered
+     for name in ENGINE_REGISTRY:
+         engines[name] = {"available": True, "loaded": True}
+
+     # Lazily loaded engines
+     for name, loader in _LAZY_ENGINES.items():
+         if name not in engines:
+             try:
+                 loader()
+                 engines[name] = {"available": True, "loaded": False}
+             except ImportError as e:
+                 engines[name] = {"available": False, "error": str(e)}
+
+     return engines
+
+
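+ # Illustrative: list_engines() probes each lazy backend by importing it and
+ # records why unavailable ones failed, e.g. (shape only; values will differ):
+ #
+ #     {"llm": {"available": True, "loaded": True},
+ #      "llm_vllm": {"available": False, "error": "No module named 'vllm'"}}
+
+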
+ def get_recommended_backend() -> str:
+     """
+     Return the recommended LLM backend.
+
+     Tried in priority order: SGLang > vLLM > native.
+     """
+     # Try SGLang first
+     try:
+         _get_sglang_engine()
+         return "sglang"
+     except ImportError:
+         pass
+
+     # Then try vLLM
+     try:
+         _get_vllm_engine()
+         return "vllm"
+     except ImportError:
+         pass
+
+     # Fall back to the native backend
+     return "native"
+
+
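+ # Usage sketch (illustrative): pick the best installed backend at startup and
+ # build an engine from it; the model id below is the one from the docstring
+ # example above.
+ #
+ #     backend = get_recommended_backend()
+ #     engine = create_llm_engine({"backend": backend,
+ #                                 "model_id": "Qwen/Qwen2.5-7B-Instruct"})
+
+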
+ __all__ = [
+     # Base classes
+     "BaseEngine",
+     "LLMBaseEngine",
+     "LLMBackend",
+     "GenerationConfig",
+     "GenerationResult",
+
+     # Concrete engines
+     "LLMEngine",
+     "ImageGenEngine",
+     "VisionEngine",
+
+     # Factory and registry
+     "ENGINE_REGISTRY",
+     "get_engine",
+     "create_llm_engine",
+     "list_engines",
+     "get_recommended_backend",
+ ]