gpu-worker 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +115 -0
- package/api_client.py +288 -0
- package/batch_processor.py +436 -0
- package/bin/gpu-worker.js +275 -0
- package/cli.py +729 -0
- package/config.2gb.yaml +32 -0
- package/config.8gb.yaml +29 -0
- package/config.example.yaml +72 -0
- package/config.py +213 -0
- package/direct_server.py +140 -0
- package/distributed/__init__.py +35 -0
- package/distributed/grpc_server.py +561 -0
- package/distributed/kv_cache.py +555 -0
- package/distributed/model_shard.py +465 -0
- package/distributed/session.py +455 -0
- package/engines/__init__.py +215 -0
- package/engines/base.py +57 -0
- package/engines/image_gen.py +83 -0
- package/engines/llm.py +97 -0
- package/engines/llm_base.py +216 -0
- package/engines/llm_sglang.py +489 -0
- package/engines/llm_vllm.py +539 -0
- package/engines/speculative.py +513 -0
- package/engines/vision.py +139 -0
- package/machine_id.py +200 -0
- package/main.py +521 -0
- package/package.json +64 -0
- package/requirements-sglang.txt +12 -0
- package/requirements-vllm.txt +15 -0
- package/requirements.txt +35 -0
- package/scripts/postinstall.js +60 -0
- package/setup.py +43 -0
|
@@ -0,0 +1,561 @@
|
|
|
1
|
+
"""
|
|
2
|
+
gRPC 服务实现
|
|
3
|
+
|
|
4
|
+
用于 Worker 间的 P2P 通信,支持:
|
|
5
|
+
- 流式推理
|
|
6
|
+
- KV-Cache 传输
|
|
7
|
+
- 会话管理
|
|
8
|
+
"""
|
|
9
|
+
import asyncio
|
|
10
|
+
import logging
|
|
11
|
+
import time
|
|
12
|
+
from typing import Dict, Any, Optional, AsyncIterator
|
|
13
|
+
from concurrent import futures
|
|
14
|
+
|
|
15
|
+
try:
|
|
16
|
+
import grpc
|
|
17
|
+
from grpc import aio as grpc_aio
|
|
18
|
+
HAS_GRPC = True
|
|
19
|
+
except ImportError:
|
|
20
|
+
HAS_GRPC = False
|
|
21
|
+
|
|
22
|
+
import torch
|
|
23
|
+
|
|
24
|
+
from .session import WorkerSession, SessionState
|
|
25
|
+
from .kv_cache import DistributedKVCacheManager
|
|
26
|
+
|
|
27
|
+
# 本地导入
|
|
28
|
+
import sys
|
|
29
|
+
import os
|
|
30
|
+
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.dirname(__file__))))
|
|
31
|
+
from common.serialization import TensorSerializer
|
|
32
|
+
|
|
33
|
+
logger = logging.getLogger(__name__)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
class InferenceServicer:
    """Distributed-inference servicer shared by the gRPC and HTTP front ends.

    Handles inference requests arriving from other workers or from clients.
    Every handler returns a plain dict whose keys mirror the intended protobuf
    response message; failures are reported in-band through ``success=False``
    and ``error_message`` instead of being raised.
    """

    # Tokens per KV-cache block, used to convert the block counts reported by
    # ``kv_cache.get_stats()`` into token counts.
    # NOTE(review): hard-coded; confirm it matches the cache manager's block size.
    _BLOCK_SIZE = 16

    def __init__(
        self,
        worker_id: str,
        model_shard,
        kv_cache_manager: "DistributedKVCacheManager",
        max_sessions: int = 100,
    ):
        """Initialise the servicer.

        Args:
            worker_id: Identifier reported by ``HealthCheck``.
            model_shard: Object whose ``forward(hidden_states, position_ids=...,
                use_cache=True)`` returns ``(output, new_kv)``, or ``None`` to
                make ``_forward`` pass its input through unchanged.
            kv_cache_manager: Cache used by ``TransferKVCache`` and for the
                capacity figures in ``CreateSession``/``HealthCheck``.
            max_sessions: Session limit enforced by ``CreateSession``.
        """
        self.worker_id = worker_id
        self.model_shard = model_shard
        self.kv_cache = kv_cache_manager
        self.max_sessions = max_sessions

        # session_id -> session state ("session_id", "position", "created_at", ...).
        self._sessions: Dict[str, Dict[str, Any]] = {}

        # Aggregate counters exposed through HealthCheck.
        self._stats = {
            "total_requests": 0,
            "total_tokens": 0,
            "total_latency_ms": 0,
            "errors": 0,
        }

    async def StreamInference(
        self,
        request_iterator: AsyncIterator,
        context,
    ) -> AsyncIterator:
        """Run a stream of inference steps, yielding one response per request.

        Each request carries a session id, a serialized hidden-state tensor
        (``hidden_states``/``shape``/``dtype``), a ``position`` and KV-cache
        keys. The step runs through ``_forward``; when ``next_worker_address``
        is set the output is handed to ``_forward_to_next``. A failing step
        yields an error response (with the same ``session_id``/``step_id``)
        and the stream continues with the next request.
        """
        session_id = None

        async for request in request_iterator:
            step_id = getattr(request, "step_id", 0)
            try:
                start_time = time.time()
                session_id = request.session_id

                # NOTE(review): implicit creation here is not bounded by max_sessions.
                session = self._get_or_create_session(session_id, request)

                hidden_states = TensorSerializer.deserialize(
                    request.hidden_states,
                    tuple(request.shape),
                    request.dtype,
                    device=self._device(),
                )

                output, kv_keys = await self._forward(
                    session,
                    hidden_states,
                    request.position,
                    list(request.kv_cache_keys),
                )

                output_bytes, output_shape, output_dtype = TensorSerializer.serialize(output)

                # Hand the result to the next hop of the pipeline, if any.
                if request.next_worker_address:
                    await self._forward_to_next(
                        request.next_worker_address,
                        request.next_session_id,
                        output,
                        request.position,
                        kv_keys,
                    )

                latency_ms = (time.time() - start_time) * 1000
                tokens = self._num_tokens(hidden_states)
                self._record_success(latency_ms, tokens)

                # Response shaped like the protobuf message.
                response = {
                    "session_id": session_id,
                    "step_id": request.step_id,
                    "hidden_states": output_bytes,
                    "shape": list(output_shape),
                    "dtype": output_dtype,
                    "updated_kv_keys": kv_keys,
                    "latency_ms": int(latency_ms),
                    "tokens_processed": tokens,
                    "success": True,
                    "error_message": "",
                }

            except Exception as e:
                logger.exception("StreamInference error: %s", e)
                self._stats["errors"] += 1
                response = {
                    "session_id": session_id or "",
                    "step_id": step_id,
                    "success": False,
                    "error_message": str(e),
                }

            # Yield outside the try so exceptions thrown into the generator by
            # the consumer are not mistaken for step failures.
            yield response

    async def Forward(self, request, context) -> Dict[str, Any]:
        """Run a single forward pass and return the serialized output.

        ``request`` provides ``session_id``, ``input`` (serialized tensor
        bytes), ``shape``, ``dtype``, ``position`` and ``kv_cache_keys``.
        Unknown sessions are created on the fly.
        """
        try:
            start_time = time.time()

            session = self._get_or_create_session(request.session_id, request)

            input_tensor = TensorSerializer.deserialize(
                request.input,
                tuple(request.shape),
                request.dtype,
                device=self._device(),
            )

            output, kv_keys = await self._forward(
                session,
                input_tensor,
                request.position,
                list(request.kv_cache_keys),
            )

            output_bytes, output_shape, output_dtype = TensorSerializer.serialize(output)

            latency_ms = (time.time() - start_time) * 1000
            # Count single forwards too so HealthCheck reflects HTTP traffic.
            self._record_success(latency_ms, self._num_tokens(input_tensor))

            return {
                "output": output_bytes,
                "shape": list(output_shape),
                "dtype": output_dtype,
                "updated_kv_keys": kv_keys,
                "success": True,
                "error_message": "",
                "latency_ms": int(latency_ms),
            }

        except Exception as e:
            logger.exception("Forward error: %s", e)
            self._stats["errors"] += 1
            return {
                "success": False,
                "error_message": str(e),
            }

    async def TransferKVCache(self, request, context) -> Dict[str, Any]:
        """Receive per-layer KV tensors and store them in the local KV cache.

        Each entry of ``request.layers`` supplies serialized ``keys`` and
        ``values`` (both decoded with the entry's ``shape``/``dtype``) and a
        ``layer_idx``; they are stored under ``"<prefix_key>:<layer_idx>"``.
        The response reports the number of payload bytes received.
        """
        try:
            start_time = time.time()
            total_bytes = 0

            for layer_data in request.layers:
                shape = tuple(layer_data.shape)
                # NOTE(review): no device is passed, so tensors land on
                # TensorSerializer's default device; confirm that is intended.
                keys = TensorSerializer.deserialize(layer_data.keys, shape, layer_data.dtype)
                values = TensorSerializer.deserialize(layer_data.values, shape, layer_data.dtype)

                cache_key = f"{request.prefix_key}:{layer_data.layer_idx}"
                await self.kv_cache._store_kv(
                    cache_key,
                    keys,
                    values,
                    layer_data.layer_idx,
                    request.prefix_key,
                )

                total_bytes += len(layer_data.keys) + len(layer_data.values)

            latency_ms = (time.time() - start_time) * 1000

            return {
                "success": True,
                "error_message": "",
                "bytes_transferred": total_bytes,
                "latency_ms": int(latency_ms),
            }

        except Exception as e:
            logger.exception("TransferKVCache error: %s", e)
            return {
                "success": False,
                "error_message": str(e),
            }

    async def CreateSession(self, request, context) -> Dict[str, Any]:
        """Create a session with a fresh UUID, refusing once ``max_sessions`` is reached."""
        try:
            if len(self._sessions) >= self.max_sessions:
                return {
                    "success": False,
                    "error_message": f"Max sessions reached: {self.max_sessions}",
                }

            import uuid
            session_id = str(uuid.uuid4())

            self._sessions[session_id] = {
                "session_id": session_id,
                "model_name": request.model_name,
                "max_length": request.max_length,
                "start_layer": request.start_layer,
                "end_layer": request.end_layer,
                "position": 0,
                "created_at": time.time(),
            }

            return {
                "session_id": session_id,
                "success": True,
                "error_message": "",
                "cache_tokens_available": self._get_cache_capacity(),
            }

        except Exception as e:
            logger.exception("CreateSession error: %s", e)
            return {
                "success": False,
                "error_message": str(e),
            }

    async def CloseSession(self, request, context) -> Dict[str, Any]:
        """Forget ``request.session_id``; closing an unknown session also succeeds."""
        try:
            self._sessions.pop(request.session_id, None)
            # TODO: release the KV-cache entries that belong to the closed session.
            return {
                "success": True,
                "error_message": "",
            }

        except Exception as e:
            logger.exception("CloseSession error: %s", e)
            return {
                "success": False,
                "error_message": str(e),
            }

    async def HealthCheck(self, request, context) -> Dict[str, Any]:
        """Report liveness, plus GPU/cache/performance stats when ``request.include_stats`` is truthy."""
        try:
            response = {
                "healthy": True,
                "worker_id": self.worker_id,
                "status": "online",
                "active_sessions": len(self._sessions),
            }

            if request.include_stats:
                if torch.cuda.is_available():
                    # Query usage and capacity on the same (current) device.
                    device = torch.cuda.current_device()
                    response["gpu_memory_used_gb"] = torch.cuda.memory_allocated(device) / (1024 ** 3)
                    response["gpu_memory_total_gb"] = (
                        torch.cuda.get_device_properties(device).total_memory / (1024 ** 3)
                    )

                cache_stats = self.kv_cache.get_stats()
                # NOTE(review): based on "total_blocks"; confirm that stat counts
                # used blocks rather than overall capacity.
                response["cache_tokens_used"] = cache_stats.get("total_blocks", 0) * self._BLOCK_SIZE
                response["cache_tokens_available"] = self._get_cache_capacity()

                total_requests = self._stats["total_requests"]
                total_latency_ms = self._stats["total_latency_ms"]
                if total_requests > 0:
                    response["avg_latency_ms"] = total_latency_ms / total_requests
                    response["throughput_tokens_per_sec"] = (
                        self._stats["total_tokens"] / (total_latency_ms / 1000)
                        if total_latency_ms > 0 else 0
                    )

            return response

        except Exception as e:
            logger.exception("HealthCheck error: %s", e)
            return {
                "healthy": False,
                "status": "error",
                "error_message": str(e),
            }

    def _get_or_create_session(self, session_id: str, request) -> Dict[str, Any]:
        """Return the session for ``session_id``, creating a minimal one if absent.

        ``request`` is accepted for interface compatibility and is not used.
        """
        session = self._sessions.get(session_id)
        if session is None:
            session = self._create_session(session_id)
        return session

    def _create_session(self, session_id: str) -> Dict[str, Any]:
        """Register and return a new session positioned at 0."""
        session = {
            "session_id": session_id,
            "position": 0,
            "created_at": time.time(),
        }
        self._sessions[session_id] = session
        return session

    def _device(self) -> str:
        """Device string used when deserializing incoming tensors.

        Uses the shard's ``device`` attribute when present; otherwise CUDA if
        available, else CPU (so CPU-only hosts without a shard still work).
        """
        device = getattr(self.model_shard, "device", None)
        if device is not None:
            return str(device)
        return "cuda" if torch.cuda.is_available() else "cpu"

    @staticmethod
    def _num_tokens(tensor) -> int:
        """Size of dimension 1 (assumed to be the sequence axis), or 1 for tensors with fewer than 2 dims."""
        shape = getattr(tensor, "shape", None)
        if shape is None or len(shape) < 2:
            return 1
        return int(shape[1])

    def _record_success(self, latency_ms: float, tokens: int) -> None:
        """Accumulate request, latency and token counters for HealthCheck."""
        self._stats["total_requests"] += 1
        self._stats["total_latency_ms"] += latency_ms
        self._stats["total_tokens"] += tokens

    async def _forward(
        self,
        session: Dict[str, Any],
        hidden_states: torch.Tensor,
        position: int,
        kv_cache_keys: list,
    ) -> tuple:
        """Run the model shard on ``hidden_states`` and return ``(output, kv_cache_keys)``.

        Without a model shard the input is echoed back unchanged. With a shard,
        position ids ``position .. position + hidden_states.shape[1] - 1`` are
        passed and the session position is advanced past them.
        """
        if self.model_shard is None:
            # Pass-through used when no shard is loaded.
            return hidden_states, kv_cache_keys

        # NOTE(review): synchronous call that blocks the event loop for the
        # duration of the forward pass; the returned KV state is discarded
        # and the incoming cache keys are echoed back unchanged.
        output, _new_kv = self.model_shard.forward(
            hidden_states,
            position_ids=torch.arange(
                position, position + hidden_states.shape[1], device=hidden_states.device
            ).unsqueeze(0),
            use_cache=True,
        )

        session["position"] = position + hidden_states.shape[1]

        return output, kv_cache_keys

    async def _forward_to_next(
        self,
        next_address: str,
        next_session_id: str,
        hidden_states: torch.Tensor,
        position: int,
        kv_cache_keys: list,
    ) -> None:
        """Placeholder for server-to-server forwarding to the next worker.

        Only logs the intent; nothing is sent yet.
        """
        # TODO: implement server-to-server forwarding (needs a gRPC channel to the next worker).
        logger.debug("Forwarding to %s (session: %s)", next_address, next_session_id)

    def _get_cache_capacity(self) -> int:
        """Free KV-cache capacity in tokens (free blocks times block size)."""
        stats = self.kv_cache.get_stats()
        return stats.get("free_blocks", 0) * self._BLOCK_SIZE
|
|
395
|
+
|
|
396
|
+
|
|
397
|
+
class GRPCServer:
    """Thin lifecycle wrapper around a ``grpc.aio`` server.

    NOTE: no service handlers are registered yet (that requires code generated
    from the .proto definition), so the server accepts connections but every
    RPC is answered as unimplemented.
    """

    def __init__(
        self,
        servicer: "InferenceServicer",
        host: str = "0.0.0.0",
        port: int = 50051,
        max_workers: int = 10,
    ):
        """Store configuration; the server itself is created in ``start()``.

        Args:
            servicer: Servicer meant to back the RPCs (not registered yet).
            host: Interface to bind.
            port: Port to bind; 0 asks the OS for a free port, and the bound
                port is written back to ``self.port`` by ``start()``.
            max_workers: Size of the thread pool handed to the aio server.

        Raises:
            ImportError: If grpcio is not installed.
        """
        if not HAS_GRPC:
            raise ImportError("grpcio not installed. Please install with: pip install grpcio")

        self.servicer = servicer
        self.host = host
        self.port = port
        self.max_workers = max_workers

        self._server: Optional[grpc_aio.Server] = None
        # Kept so stop() can release the pool's threads.
        self._executor: Optional[futures.ThreadPoolExecutor] = None

    async def start(self) -> None:
        """Create the server, bind ``host:port`` and start serving.

        Raises:
            RuntimeError: If the address could not be bound.
        """
        self._executor = futures.ThreadPoolExecutor(max_workers=self.max_workers)
        self._server = grpc_aio.server(self._executor)

        # Service registration needs the proto-generated add_*_to_server helper;
        # until it exists the server exposes no RPC methods.
        logger.warning("gRPC server has no registered services; all RPCs will be UNIMPLEMENTED")

        listen_addr = f"{self.host}:{self.port}"
        bound_port = self._server.add_insecure_port(listen_addr)
        if bound_port == 0:
            # Older grpcio versions report a bind failure by returning 0 instead of raising.
            self._executor.shutdown(wait=False)
            self._executor = None
            self._server = None
            raise RuntimeError(f"Failed to bind gRPC server to {listen_addr}")
        self.port = bound_port

        await self._server.start()
        logger.info(f"gRPC server started on {self.host}:{self.port}")

    async def stop(self, grace: float = 5.0) -> None:
        """Stop the server, giving in-flight RPCs up to ``grace`` seconds, and free the thread pool."""
        if self._server is not None:
            await self._server.stop(grace)
            logger.info("gRPC server stopped")
        if self._executor is not None:
            self._executor.shutdown(wait=False)
            self._executor = None

    async def wait_for_termination(self) -> None:
        """Block until the server terminates; returns immediately if it was never started."""
        if self._server is not None:
            await self._server.wait_for_termination()
|
|
447
|
+
|
|
448
|
+
|
|
449
|
+
# HTTP 风格的 P2P API(作为 gRPC 的替代方案)
|
|
450
|
+
class HTTPInferenceServer:
    """HTTP alternative to the gRPC server, built on aiohttp.

    Routes:
        POST /inference/forward -> ``servicer.Forward``
        POST /inference/close   -> ``servicer.CloseSession``
        GET  /health            -> ``servicer.HealthCheck`` (with stats)

    Binary tensor payloads travel as hex strings in both directions.
    """

    def __init__(
        self,
        servicer: "InferenceServicer",
        host: str = "0.0.0.0",
        port: int = 8080,
    ):
        """Store configuration; the aiohttp app is built in ``start()``."""
        self.servicer = servicer
        self.host = host
        self.port = port
        self._app = None
        self._runner = None

    async def start(self) -> None:
        """Build the aiohttp application, register routes and start listening.

        Raises:
            ImportError: If aiohttp is not installed.
        """
        try:
            from aiohttp import web
        except ImportError as err:
            raise ImportError("aiohttp not installed. Please install with: pip install aiohttp") from err

        self._app = web.Application()

        self._app.router.add_post("/inference/forward", self._handle_forward)
        self._app.router.add_post("/inference/close", self._handle_close_session)
        self._app.router.add_get("/health", self._handle_health)

        self._runner = web.AppRunner(self._app)
        await self._runner.setup()

        site = web.TCPSite(self._runner, self.host, self.port)
        await site.start()

        logger.info(f"HTTP inference server started on {self.host}:{self.port}")

    async def stop(self) -> None:
        """Shut the server down; calling it again is a no-op."""
        if self._runner:
            await self._runner.cleanup()
            self._runner = None
            logger.info("HTTP inference server stopped")

    @staticmethod
    def _parse_forward_request(data: Dict[str, Any]):
        """Turn a JSON forward payload into the attribute-style request ``Forward`` expects.

        ``data["input"]`` holds ``data`` (hex string), ``shape`` and ``dtype``;
        missing fields fall back to empty input, ``float16``, position 0 and no
        KV-cache keys.
        """
        from types import SimpleNamespace

        tensor = data.get("input", {})
        return SimpleNamespace(
            session_id=data.get("session_id", ""),
            input=bytes.fromhex(tensor.get("data", "")),
            shape=tuple(tensor.get("shape", [])),
            dtype=tensor.get("dtype", "float16"),
            position=data.get("position", 0),
            kv_cache_keys=data.get("kv_cache_keys", []),
        )

    @staticmethod
    def _to_json_safe(result: Dict[str, Any]) -> Dict[str, Any]:
        """Hex-encode top-level binary values so the dict can be sent as JSON."""
        return {
            key: value.hex() if isinstance(value, (bytes, bytearray, memoryview)) else value
            for key, value in result.items()
        }

    async def _handle_forward(self, request) -> "web.Response":
        """Handle POST /inference/forward; the output tensor is returned hex-encoded."""
        from aiohttp import web

        try:
            data = await request.json()
            result = await self.servicer.Forward(self._parse_forward_request(data), None)
            # Forward returns raw bytes under "output", which JSON cannot carry.
            return web.json_response(self._to_json_safe(result))

        except Exception as e:
            logger.exception("Forward error: %s", e)
            return web.json_response(
                {"success": False, "error_message": str(e)},
                status=500
            )

    async def _handle_close_session(self, request) -> "web.Response":
        """Handle POST /inference/close with a JSON body ``{"session_id": ...}``."""
        from aiohttp import web
        from types import SimpleNamespace

        try:
            data = await request.json()
            close_request = SimpleNamespace(session_id=data.get("session_id", ""))
            result = await self.servicer.CloseSession(close_request, None)
            return web.json_response(result)

        except Exception as e:
            logger.exception("CloseSession error: %s", e)
            return web.json_response(
                {"success": False, "error_message": str(e)},
                status=500
            )

    async def _handle_health(self, request) -> "web.Response":
        """Handle GET /health, always including stats."""
        from aiohttp import web
        from types import SimpleNamespace

        try:
            result = await self.servicer.HealthCheck(SimpleNamespace(include_stats=True), None)
            return web.json_response(result)

        except Exception as e:
            logger.exception("HealthCheck error: %s", e)
            return web.json_response(
                {"healthy": False, "error_message": str(e)},
                status=500
            )
|