gpu-worker 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,561 @@
+ """
+ gRPC service implementation
+
+ Handles P2P communication between workers and supports:
+ - streaming inference
+ - KV-cache transfer
+ - session management
+ """
+ import asyncio
+ import logging
+ import time
+ from typing import Dict, Any, Optional, AsyncIterator
+ from concurrent import futures
+
+ try:
+     import grpc
+     from grpc import aio as grpc_aio
+     HAS_GRPC = True
+ except ImportError:
+     HAS_GRPC = False
+
+ import torch
+
+ from .session import WorkerSession, SessionState
+ from .kv_cache import DistributedKVCacheManager
+
+ # Local import: make the repository root importable so that
+ # common.serialization resolves regardless of the working directory
+ import sys
+ import os
+ sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
+ from common.serialization import TensorSerializer
+
+ logger = logging.getLogger(__name__)
+
+
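+ # Contract assumed for TensorSerializer in this module (inferred from the
+ # call sites below; the actual common.serialization API may differ):
+ #   serialize(tensor) -> (data: bytes, shape: tuple, dtype: str)
+ #   deserialize(data, shape, dtype, device=...) -> torch.Tensor
+
+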
+ class InferenceServicer:
+     """
+     Distributed-inference gRPC servicer
+
+     Handles inference requests from other workers or clients.
+     """
+
+     def __init__(
+         self,
+         worker_id: str,
+         model_shard,
+         kv_cache_manager: DistributedKVCacheManager,
+         max_sessions: int = 100,
+     ):
+         self.worker_id = worker_id
+         self.model_shard = model_shard
+         self.kv_cache = kv_cache_manager
+         self.max_sessions = max_sessions
+
+         # Session management
+         self._sessions: Dict[str, Dict[str, Any]] = {}
+
+         # Statistics
+         self._stats = {
+             "total_requests": 0,
+             "total_tokens": 0,
+             "total_latency_ms": 0,
+             "errors": 0,
+         }
+
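+     # Request fields consumed by StreamInference (inferred from the attribute
+     # accesses below; the authoritative message definition would live in the
+     # proto-generated code, which is not part of this diff):
+     #   session_id, step_id, hidden_states (bytes), shape, dtype, position,
+     #   kv_cache_keys, next_worker_address, next_session_id
+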
+     async def StreamInference(
+         self,
+         request_iterator: AsyncIterator,
+         context,
+     ) -> AsyncIterator:
+         """
+         Streaming inference
+
+         Supports consecutive inference steps while maintaining session state.
+         """
+         session_id = None
+
+         async for request in request_iterator:
+             try:
+                 start_time = time.time()
+                 session_id = request.session_id
+
+                 # Get or create the session
+                 session = self._get_or_create_session(session_id, request)
+
+                 # Deserialize the input
+                 hidden_states = TensorSerializer.deserialize(
+                     request.hidden_states,
+                     tuple(request.shape),
+                     request.dtype,
+                     device=str(self.model_shard.device) if hasattr(self.model_shard, 'device') else "cuda"
+                 )
+
+                 # Run the forward pass
+                 output, kv_keys = await self._forward(
+                     session,
+                     hidden_states,
+                     request.position,
+                     list(request.kv_cache_keys),
+                 )
+
+                 # Serialize the output
+                 output_bytes, output_shape, output_dtype = TensorSerializer.serialize(output)
+
+                 # Forward to the next hop, if any
+                 if request.next_worker_address:
+                     await self._forward_to_next(
+                         request.next_worker_address,
+                         request.next_session_id,
+                         output,
+                         request.position,
+                         kv_keys,
+                     )
+
+                 latency_ms = (time.time() - start_time) * 1000
+                 tokens_processed = hidden_states.shape[1] if hasattr(hidden_states, 'shape') else 1
+                 self._stats["total_requests"] += 1
+                 self._stats["total_tokens"] += tokens_processed
+                 self._stats["total_latency_ms"] += latency_ms
+
+                 # Build the response (mimics a protobuf message)
+                 yield {
+                     "session_id": session_id,
+                     "step_id": request.step_id,
+                     "hidden_states": output_bytes,
+                     "shape": list(output_shape),
+                     "dtype": output_dtype,
+                     "updated_kv_keys": kv_keys,
+                     "latency_ms": int(latency_ms),
+                     "tokens_processed": tokens_processed,
+                     "success": True,
+                     "error_message": "",
+                 }
+
+             except Exception as e:
+                 logger.error(f"StreamInference error: {e}")
+                 self._stats["errors"] += 1
+                 yield {
+                     "session_id": session_id or "",
+                     "success": False,
+                     "error_message": str(e),
+                 }
+
+     async def Forward(self, request, context) -> Dict[str, Any]:
+         """Single forward pass"""
+         try:
+             start_time = time.time()
+
+             # Look up the session, creating one if needed
+             session = self._sessions.get(request.session_id)
+             if not session:
+                 session = self._create_session(request.session_id)
+
+             # Deserialize the input
+             input_tensor = TensorSerializer.deserialize(
+                 request.input,
+                 tuple(request.shape),
+                 request.dtype,
+                 device=str(self.model_shard.device) if hasattr(self.model_shard, 'device') else "cuda"
+             )
+
+             # Run the forward pass
+             output, kv_keys = await self._forward(
+                 session,
+                 input_tensor,
+                 request.position,
+                 list(request.kv_cache_keys),
+             )
+
+             # Serialize the output
+             output_bytes, output_shape, output_dtype = TensorSerializer.serialize(output)
+
+             latency_ms = (time.time() - start_time) * 1000
+
+             return {
+                 "output": output_bytes,
+                 "shape": list(output_shape),
+                 "dtype": output_dtype,
+                 "updated_kv_keys": kv_keys,
+                 "success": True,
+                 "error_message": "",
+                 "latency_ms": int(latency_ms),
+             }
+
+         except Exception as e:
+             logger.error(f"Forward error: {e}")
+             return {
+                 "success": False,
+                 "error_message": str(e),
+             }
+
+     async def TransferKVCache(self, request, context) -> Dict[str, Any]:
+         """Receive a KV-cache transfer"""
+         try:
+             start_time = time.time()
+             total_bytes = 0
+
+             for layer_data in request.layers:
+                 # Deserialize the KV tensors
+                 keys = TensorSerializer.deserialize(
+                     layer_data.keys,
+                     tuple(layer_data.shape),
+                     layer_data.dtype,
+                 )
+                 values = TensorSerializer.deserialize(
+                     layer_data.values,
+                     tuple(layer_data.shape),
+                     layer_data.dtype,
+                 )
+
+                 # Store them in the cache
+                 cache_key = f"{request.prefix_key}:{layer_data.layer_idx}"
+                 await self.kv_cache._store_kv(
+                     cache_key,
+                     keys,
+                     values,
+                     layer_data.layer_idx,
+                     request.prefix_key,
+                 )
+
+                 total_bytes += len(layer_data.keys) + len(layer_data.values)
+
+             latency_ms = (time.time() - start_time) * 1000
+
+             return {
+                 "success": True,
+                 "error_message": "",
+                 "bytes_transferred": total_bytes,
+                 "latency_ms": int(latency_ms),
+             }
+
+         except Exception as e:
+             logger.error(f"TransferKVCache error: {e}")
+             return {
+                 "success": False,
+                 "error_message": str(e),
+             }
+
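+     # Fields consumed per entry in request.layers (inferred from the loop
+     # above): layer_idx, keys (bytes), values (bytes), shape, dtype; plus
+     # request.prefix_key on the envelope.
+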
+     async def CreateSession(self, request, context) -> Dict[str, Any]:
+         """Create a session"""
+         try:
+             if len(self._sessions) >= self.max_sessions:
+                 return {
+                     "success": False,
+                     "error_message": f"Max sessions reached: {self.max_sessions}",
+                 }
+
+             import uuid
+             session_id = str(uuid.uuid4())
+
+             self._sessions[session_id] = {
+                 "session_id": session_id,
+                 "model_name": request.model_name,
+                 "max_length": request.max_length,
+                 "start_layer": request.start_layer,
+                 "end_layer": request.end_layer,
+                 "position": 0,
+                 "created_at": time.time(),
+             }
+
+             return {
+                 "session_id": session_id,
+                 "success": True,
+                 "error_message": "",
+                 "cache_tokens_available": self._get_cache_capacity(),
+             }
+
+         except Exception as e:
+             logger.error(f"CreateSession error: {e}")
+             return {
+                 "success": False,
+                 "error_message": str(e),
+             }
+
+     async def CloseSession(self, request, context) -> Dict[str, Any]:
+         """Close a session"""
+         try:
+             session = self._sessions.pop(request.session_id, None)
+             if session:
+                 # Clean up the session's KV-cache entries
+                 # TODO: implement KV-cache cleanup
+                 pass
+
+             return {
+                 "success": True,
+                 "error_message": "",
+             }
+
+         except Exception as e:
+             return {
+                 "success": False,
+                 "error_message": str(e),
+             }
+
+     async def HealthCheck(self, request, context) -> Dict[str, Any]:
+         """Health check"""
+         try:
+             response = {
+                 "healthy": True,
+                 "worker_id": self.worker_id,
+                 "status": "online",
+                 "active_sessions": len(self._sessions),
+             }
+
+             if request.include_stats:
+                 # GPU information
+                 if torch.cuda.is_available():
+                     response["gpu_memory_used_gb"] = torch.cuda.memory_allocated() / (1024 ** 3)
+                     response["gpu_memory_total_gb"] = torch.cuda.get_device_properties(0).total_memory / (1024 ** 3)
+
+                 # KV-cache information
+                 cache_stats = self.kv_cache.get_stats()
+                 response["cache_tokens_used"] = cache_stats.get("total_blocks", 0) * 16  # block_size
+                 response["cache_tokens_available"] = self._get_cache_capacity()
+
+                 # Performance metrics
+                 if self._stats["total_requests"] > 0:
+                     response["avg_latency_ms"] = self._stats["total_latency_ms"] / self._stats["total_requests"]
+                     response["throughput_tokens_per_sec"] = (
+                         self._stats["total_tokens"] / (self._stats["total_latency_ms"] / 1000)
+                         if self._stats["total_latency_ms"] > 0 else 0
+                     )
+
+             return response
+
+         except Exception as e:
+             return {
+                 "healthy": False,
+                 "status": "error",
+                 "error_message": str(e),
+             }
+
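+     # Illustrative HealthCheck response with include_stats=True (values are
+     # made up; keys match the dict built above):
+     #   {"healthy": true, "worker_id": "worker-0", "status": "online",
+     #    "active_sessions": 2, "gpu_memory_used_gb": 10.4,
+     #    "gpu_memory_total_gb": 24.0, "cache_tokens_used": 4096,
+     #    "cache_tokens_available": 12288, "avg_latency_ms": 38.5,
+     #    "throughput_tokens_per_sec": 415.6}
+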
+     def _get_or_create_session(self, session_id: str, request) -> Dict[str, Any]:
+         """Get an existing session or create one"""
+         if session_id not in self._sessions:
+             self._sessions[session_id] = {
+                 "session_id": session_id,
+                 "position": 0,
+                 "created_at": time.time(),
+             }
+         return self._sessions[session_id]
+
+     def _create_session(self, session_id: str) -> Dict[str, Any]:
+         """Create a new session"""
+         session = {
+             "session_id": session_id,
+             "position": 0,
+             "created_at": time.time(),
+         }
+         self._sessions[session_id] = session
+         return session
+
+     async def _forward(
+         self,
+         session: Dict[str, Any],
+         hidden_states: torch.Tensor,
+         position: int,
+         kv_cache_keys: list,
+     ) -> tuple:
+         """Run the forward pass"""
+         # This should delegate to model_shard's forward method.
+         # Simplified implementation; a real one needs to handle the KV-cache.
+         if self.model_shard is None:
+             # Simulated output for running without a model
+             return hidden_states, kv_cache_keys
+
+         output, new_kv = self.model_shard.forward(
+             hidden_states,
+             position_ids=torch.arange(position, position + hidden_states.shape[1], device=hidden_states.device).unsqueeze(0),
+             use_cache=True,
+         )
+
+         # Advance the session position; new_kv is not yet propagated into
+         # the distributed cache (see the simplification note above)
+         session["position"] = position + hidden_states.shape[1]
+
+         return output, kv_cache_keys
+
+     async def _forward_to_next(
+         self,
+         next_address: str,
+         next_session_id: str,
+         hidden_states: torch.Tensor,
+         position: int,
+         kv_cache_keys: list,
+     ) -> None:
+         """Forward to the next worker"""
+         # TODO: implement server-to-server forwarding
+         # This requires opening a gRPC connection to the next worker
+         logger.debug(f"Forwarding to {next_address} (session: {next_session_id})")
+
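+     # A minimal sketch of what the forwarding hop could look like over the
+     # HTTP fallback API served by HTTPInferenceServer below. The helper name
+     # and payload layout are assumptions that mirror _handle_forward; this is
+     # not wired into the request path.
+     async def _forward_to_next_http(
+         self,
+         next_address: str,
+         next_session_id: str,
+         hidden_states: torch.Tensor,
+         position: int,
+         kv_cache_keys: list,
+     ) -> Dict[str, Any]:
+         import aiohttp  # optional dependency, imported lazily as elsewhere in this module
+
+         data, shape, dtype = TensorSerializer.serialize(hidden_states)
+         payload = {
+             "session_id": next_session_id,
+             "position": position,
+             "kv_cache_keys": list(kv_cache_keys),
+             # _handle_forward expects hex-encoded tensor bytes
+             "input": {"data": data.hex(), "shape": list(shape), "dtype": dtype},
+         }
+         async with aiohttp.ClientSession() as http:
+             url = f"http://{next_address}/inference/forward"
+             async with http.post(url, json=payload) as resp:
+                 return await resp.json()
+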
+     def _get_cache_capacity(self) -> int:
+         """Get the remaining cache capacity in tokens"""
+         stats = self.kv_cache.get_stats()
+         free_blocks = stats.get("free_blocks", 0)
+         return free_blocks * 16  # block_size
+
+
+ class GRPCServer:
+     """
+     gRPC server
+
+     Starts and manages the gRPC service.
+     """
+
+     def __init__(
+         self,
+         servicer: InferenceServicer,
+         host: str = "0.0.0.0",
+         port: int = 50051,
+         max_workers: int = 10,
+     ):
+         if not HAS_GRPC:
+             raise ImportError("grpcio not installed. Please install with: pip install grpcio")
+
+         self.servicer = servicer
+         self.host = host
+         self.port = port
+         self.max_workers = max_workers
+
+         self._server: Optional[grpc_aio.Server] = None
+
+     async def start(self) -> None:
+         """Start the server"""
+         self._server = grpc_aio.server(
+             futures.ThreadPoolExecutor(max_workers=self.max_workers)
+         )
+
+         # Register services
+         # Note: the servicer must be registered via the protoc-generated code
+         # (the conventional add_<Service>Servicer_to_server helper), which is
+         # not part of this package; the simplified HTTP-style service below is
+         # used instead.
+
+         listen_addr = f"{self.host}:{self.port}"
+         self._server.add_insecure_port(listen_addr)
+
+         await self._server.start()
+         logger.info(f"gRPC server started on {listen_addr}")
+
+     async def stop(self, grace: float = 5.0) -> None:
+         """Stop the server"""
+         if self._server:
+             await self._server.stop(grace)
+             logger.info("gRPC server stopped")
+
+     async def wait_for_termination(self) -> None:
+         """Wait for the server to terminate"""
+         if self._server:
+             await self._server.wait_for_termination()
+
+
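+ # Sketch of how the gRPC path would be started once proto-based registration
+ # is in place (see the note in GRPCServer.start); illustrative only:
+ #   server = GRPCServer(servicer, port=50051)
+ #   await server.start()
+ #   await server.wait_for_termination()
+
+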
+ # HTTP-style P2P API (an alternative to gRPC)
+ class HTTPInferenceServer:
+     """
+     HTTP-style inference server
+
+     Implemented with aiohttp as a simple alternative to gRPC.
+     """
+
+     def __init__(
+         self,
+         servicer: InferenceServicer,
+         host: str = "0.0.0.0",
+         port: int = 8080,
+     ):
+         self.servicer = servicer
+         self.host = host
+         self.port = port
+         self._app = None
+         self._runner = None
+
+     async def start(self) -> None:
+         """Start the HTTP server"""
+         try:
+             from aiohttp import web
+         except ImportError:
+             raise ImportError("aiohttp not installed. Please install with: pip install aiohttp")
+
+         self._app = web.Application()
+
+         # Register routes
+         self._app.router.add_post("/inference/forward", self._handle_forward)
+         self._app.router.add_post("/inference/close", self._handle_close_session)
+         self._app.router.add_get("/health", self._handle_health)
+
+         self._runner = web.AppRunner(self._app)
+         await self._runner.setup()
+
+         site = web.TCPSite(self._runner, self.host, self.port)
+         await site.start()
+
+         logger.info(f"HTTP inference server started on {self.host}:{self.port}")
+
+     async def stop(self) -> None:
+         """Stop the server"""
+         if self._runner:
+             await self._runner.cleanup()
+             logger.info("HTTP inference server stopped")
+
+     async def _handle_forward(self, request) -> "web.Response":
+         """Handle a forward-pass request"""
+         from aiohttp import web
+
+         try:
+             data = await request.json()
+
+             # Build a request object that mimics the gRPC message
+             class MockRequest:
+                 def __init__(self, d):
+                     self.session_id = d.get("session_id", "")
+                     self.input = bytes.fromhex(d.get("input", {}).get("data", ""))
+                     self.shape = tuple(d.get("input", {}).get("shape", []))
+                     self.dtype = d.get("input", {}).get("dtype", "float16")
+                     self.position = d.get("position", 0)
+                     self.kv_cache_keys = d.get("kv_cache_keys", [])
+
+             result = await self.servicer.Forward(MockRequest(data), None)
+
+             # Raw tensor bytes are not JSON-serializable; hex-encode the
+             # output, mirroring the request encoding
+             if isinstance(result.get("output"), bytes):
+                 result["output"] = result["output"].hex()
+
+             return web.json_response(result)
+
+         except Exception as e:
+             logger.error(f"Forward error: {e}")
+             return web.json_response(
+                 {"success": False, "error_message": str(e)},
+                 status=500
+             )
+
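+     # Example request body for POST /inference/forward (tensor bytes are
+     # hex-encoded, matching MockRequest above; values are illustrative):
+     #   {"session_id": "abc123", "position": 0, "kv_cache_keys": [],
+     #    "input": {"data": "<hex bytes>", "shape": [1, 8, 4096],
+     #              "dtype": "float16"}}
+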
+     async def _handle_close_session(self, request) -> "web.Response":
+         """Handle a close-session request"""
+         from aiohttp import web
+
+         try:
+             data = await request.json()
+
+             class MockRequest:
+                 def __init__(self, d):
+                     self.session_id = d.get("session_id", "")
+
+             result = await self.servicer.CloseSession(MockRequest(data), None)
+             return web.json_response(result)
+
+         except Exception as e:
+             return web.json_response(
+                 {"success": False, "error_message": str(e)},
+                 status=500
+             )
+
+     async def _handle_health(self, request) -> "web.Response":
+         """Handle a health-check request"""
+         from aiohttp import web
+
+         try:
+             class MockRequest:
+                 include_stats = True
+
+             result = await self.servicer.HealthCheck(MockRequest(), None)
+             return web.json_response(result)
+
+         except Exception as e:
+             return web.json_response(
+                 {"healthy": False, "error_message": str(e)},
+                 status=500
+             )
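+
+
+ # A minimal wiring sketch for running this worker standalone. The
+ # DistributedKVCacheManager constructor arguments are an assumption here,
+ # and model_shard=None exercises the simulated path in _forward.
+ async def _demo() -> None:
+     kv_manager = DistributedKVCacheManager()  # hypothetical default construction
+     servicer = InferenceServicer(
+         worker_id="worker-0",
+         model_shard=None,
+         kv_cache_manager=kv_manager,
+     )
+     server = HTTPInferenceServer(servicer, host="0.0.0.0", port=8080)
+     await server.start()
+     try:
+         await asyncio.Event().wait()  # serve until cancelled
+     finally:
+         await server.stop()
+
+
+ if __name__ == "__main__":
+     asyncio.run(_demo())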