gpu-worker 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +115 -0
- package/api_client.py +288 -0
- package/batch_processor.py +436 -0
- package/bin/gpu-worker.js +275 -0
- package/cli.py +729 -0
- package/config.2gb.yaml +32 -0
- package/config.8gb.yaml +29 -0
- package/config.example.yaml +72 -0
- package/config.py +213 -0
- package/direct_server.py +140 -0
- package/distributed/__init__.py +35 -0
- package/distributed/grpc_server.py +561 -0
- package/distributed/kv_cache.py +555 -0
- package/distributed/model_shard.py +465 -0
- package/distributed/session.py +455 -0
- package/engines/__init__.py +215 -0
- package/engines/base.py +57 -0
- package/engines/image_gen.py +83 -0
- package/engines/llm.py +97 -0
- package/engines/llm_base.py +216 -0
- package/engines/llm_sglang.py +489 -0
- package/engines/llm_vllm.py +539 -0
- package/engines/speculative.py +513 -0
- package/engines/vision.py +139 -0
- package/machine_id.py +200 -0
- package/main.py +521 -0
- package/package.json +64 -0
- package/requirements-sglang.txt +12 -0
- package/requirements-vllm.txt +15 -0
- package/requirements.txt +35 -0
- package/scripts/postinstall.js +60 -0
- package/setup.py +43 -0
package/batch_processor.py
@@ -0,0 +1,436 @@
"""
Continuous Batcher

Implements dynamic batching: multiple requests are merged into a single
batch for execution, improving GPU utilization and overall throughput.

Supports:
- dynamic batch-size adjustment
- a request priority queue
- timeout control
- prefix-sharing optimization
"""
import asyncio
import time
import logging
from dataclasses import dataclass, field
from typing import Dict, Any, List, Optional, Callable, Awaitable
from enum import Enum
from collections import defaultdict
import heapq

logger = logging.getLogger(__name__)


class RequestPriority(Enum):
    """Request priority (lower value = higher priority)."""
    HIGH = 0
    NORMAL = 1
    LOW = 2


@dataclass(order=True)
class PendingRequest:
    """A pending request.

    With order=True, comparison (and equality) uses only the fields that
    keep compare=True, i.e. (priority, timestamp), which is exactly what
    the heap in ContinuousBatcher needs.
    """
    priority: int
    timestamp: float
    job_id: str = field(compare=False)
    params: Dict[str, Any] = field(compare=False)
    future: asyncio.Future = field(compare=False)
    prefix_hash: str = field(compare=False, default="")

    @classmethod
    def create(
        cls,
        job_id: str,
        params: Dict[str, Any],
        priority: RequestPriority = RequestPriority.NORMAL,
        prefix_hash: str = ""
    ) -> "PendingRequest":
        return cls(
            priority=priority.value,
            timestamp=time.time(),
            job_id=job_id,
            params=params,
            # Future() binds to the current event loop, so create() must be
            # called while a loop is running (as submit() does).
            future=asyncio.Future(),
            prefix_hash=prefix_hash
        )
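The compare=False fields are what make this safe to push onto a heap: ordering uses only (priority, timestamp). A minimal sketch of the resulting pop order (the batch_processor module name is inferred from the file path; it assumes successive time.time() calls return strictly increasing values):

import asyncio
import heapq

from batch_processor import PendingRequest, RequestPriority  # assumed module name

async def demo() -> None:
    heap: list = []
    for job_id, prio in [("job-a", RequestPriority.NORMAL),
                         ("job-b", RequestPriority.HIGH),
                         ("job-c", RequestPriority.NORMAL)]:
        heapq.heappush(heap, PendingRequest.create(job_id, {}, prio))
    # HIGH (0) pops before NORMAL (1); equal priorities fall back to the
    # creation timestamp, so same-priority requests stay FIFO.
    print([heapq.heappop(heap).job_id for _ in range(3)])
    # -> ['job-b', 'job-a', 'job-c']

asyncio.run(demo())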


class ContinuousBatcher:
    """
    Continuous batcher.

    Dynamically merges inference requests into batches. Supports:
    - a maximum batch size
    - a maximum wait time
    - a priority queue
    - prefix-grouped batching
    """

    def __init__(
        self,
        engine,
        max_batch_size: int = 32,
        max_wait_ms: float = 50,
        enable_prefix_grouping: bool = True,
        max_queue_size: int = 1000,
    ):
        self.engine = engine
        self.max_batch_size = max_batch_size
        self.max_wait_ms = max_wait_ms
        self.enable_prefix_grouping = enable_prefix_grouping
        self.max_queue_size = max_queue_size

        # Request queue (a priority heap)
        self._pending: List[PendingRequest] = []
        self._pending_by_prefix: Dict[str, List[PendingRequest]] = defaultdict(list)

        # Batching task
        self._batch_task: Optional[asyncio.Task] = None
        self._lock = asyncio.Lock()

        # Statistics
        self._stats = {
            "total_requests": 0,
            "total_batches": 0,
            "avg_batch_size": 0.0,
            "avg_wait_time_ms": 0.0,
        }

        # Run state
        self._running = False

    async def start(self) -> None:
        """Start the batcher."""
        self._running = True
        logger.info("ContinuousBatcher started")

    async def stop(self) -> None:
        """Stop the batcher."""
        self._running = False

        # Cancel all pending requests
        async with self._lock:
            for req in self._pending:
                if not req.future.done():
                    req.future.cancel()
            self._pending.clear()
            self._pending_by_prefix.clear()

        if self._batch_task:
            self._batch_task.cancel()
            try:
                await self._batch_task
            except asyncio.CancelledError:
                pass

        logger.info("ContinuousBatcher stopped")

    async def submit(
        self,
        job_id: str,
        params: Dict[str, Any],
        priority: RequestPriority = RequestPriority.NORMAL,
        timeout: float = 120.0,
    ) -> Dict[str, Any]:
        """
        Submit an inference request.

        Args:
            job_id: job ID
            params: inference parameters
            priority: request priority
            timeout: timeout in seconds

        Returns:
            The inference result.
        """
        if not self._running:
            raise RuntimeError("Batcher is not running")

        # Soft limit, checked without holding the lock
        if len(self._pending) >= self.max_queue_size:
            raise RuntimeError(f"Queue full (max={self.max_queue_size})")

        # Compute the prefix hash (used for grouping)
        prefix_hash = ""
        if self.enable_prefix_grouping:
            prefix_hash = self._compute_prefix_hash(params)

        # Create the request
        request = PendingRequest.create(
            job_id=job_id,
            params=params,
            priority=priority,
            prefix_hash=prefix_hash
        )

        async with self._lock:
            heapq.heappush(self._pending, request)

            if self.enable_prefix_grouping and prefix_hash:
                self._pending_by_prefix[prefix_hash].append(request)

            self._stats["total_requests"] += 1

            # Decide whether to trigger batching now
            if len(self._pending) >= self.max_batch_size:
                # Full batch: process immediately (fire-and-forget task)
                asyncio.create_task(self._process_batch())
            elif self._batch_task is None or self._batch_task.done():
                # Start the wait timer
                self._batch_task = asyncio.create_task(self._wait_and_process())

        # Wait for the result
        try:
            return await asyncio.wait_for(request.future, timeout=timeout)
        except asyncio.TimeoutError:
            # Timed out: try to remove the request from the queue
            async with self._lock:
                try:
                    self._pending.remove(request)
                    heapq.heapify(self._pending)
                except ValueError:
                    pass  # may already have been processed
            raise

    async def _wait_and_process(self) -> None:
        """Wait for max_wait_ms, then process a batch."""
        await asyncio.sleep(self.max_wait_ms / 1000)
        await self._process_batch()

    async def _process_batch(self) -> None:
        """Process one batch."""
        async with self._lock:
            if not self._pending:
                return

            batch_start_time = time.time()

            # Select the requests to process
            if self.enable_prefix_grouping:
                batch = self._select_batch_with_prefix_grouping()
            else:
                batch = self._select_batch_simple()

            if not batch:
                return

            # Remove them from the queues (entries already popped by
            # _select_batch_simple raise ValueError and are skipped)
            for req in batch:
                try:
                    self._pending.remove(req)
                except ValueError:
                    pass
                if req.prefix_hash:
                    try:
                        self._pending_by_prefix[req.prefix_hash].remove(req)
                    except ValueError:
                        pass
            heapq.heapify(self._pending)

        # Run the batched inference (outside the lock)
        try:
            results = await self._execute_batch(batch)

            # Deliver the results
            for req, result in zip(batch, results):
                if not req.future.done():
                    if isinstance(result, Exception):
                        req.future.set_exception(result)
                    else:
                        req.future.set_result(result)

        except Exception as e:
            logger.error(f"Batch processing error: {e}")
            # Fail every request in the batch
            for req in batch:
                if not req.future.done():
                    req.future.set_exception(e)

        # Update statistics (running averages)
        batch_time = (time.time() - batch_start_time) * 1000
        self._stats["total_batches"] += 1
        self._stats["avg_batch_size"] = (
            (self._stats["avg_batch_size"] * (self._stats["total_batches"] - 1) + len(batch))
            / self._stats["total_batches"]
        )
        self._stats["avg_wait_time_ms"] = (
            (self._stats["avg_wait_time_ms"] * (self._stats["total_batches"] - 1) + batch_time)
            / self._stats["total_batches"]
        )

    def _select_batch_simple(self) -> List[PendingRequest]:
        """Simple batch selection (by priority)."""
        # heappop already removes the entries from self._pending
        return [heapq.heappop(self._pending) for _ in range(min(len(self._pending), self.max_batch_size))]

    def _select_batch_with_prefix_grouping(self) -> List[PendingRequest]:
        """Batch selection with prefix grouping."""
        batch = []

        # Prefer the largest prefix groups first
        if self._pending_by_prefix:
            # Sort groups by size, largest first
            sorted_groups = sorted(
                self._pending_by_prefix.items(),
                key=lambda x: len(x[1]),
                reverse=True
            )

            for prefix_hash, group in sorted_groups:
                if not group:
                    continue

                # Take requests from this group
                take_count = min(len(group), self.max_batch_size - len(batch))
                batch.extend(group[:take_count])

                if len(batch) >= self.max_batch_size:
                    break

        # If there is room left, add requests without a prefix
        # (note: `in` compares by (priority, timestamp); see PendingRequest)
        remaining = self.max_batch_size - len(batch)
        if remaining > 0:
            no_prefix_requests = [
                req for req in self._pending
                if not req.prefix_hash and req not in batch
            ]
            batch.extend(no_prefix_requests[:remaining])

        return batch

    async def _execute_batch(
        self,
        batch: List[PendingRequest]
    ) -> List[Any]:
        """Run batched inference."""
        params_list = [req.params for req in batch]

        if hasattr(self.engine, "batch_inference_async"):
            return await self.engine.batch_inference_async(params_list)
        elif hasattr(self.engine, "batch_inference"):
            # Run the sync method in the default thread pool
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(
                None,
                self.engine.batch_inference,
                params_list
            )
        else:
            # Fall back to sequential execution
            results = []
            for params in params_list:
                try:
                    if hasattr(self.engine, "inference_async"):
                        result = await self.engine.inference_async(params)
                    else:
                        loop = asyncio.get_running_loop()
                        result = await loop.run_in_executor(
                            None,
                            self.engine.inference,
                            params
                        )
                    results.append(result)
                except Exception as e:
                    results.append(e)
            return results

    def _compute_prefix_hash(self, params: Dict[str, Any]) -> str:
        """Compute the prefix hash for a request."""
        import hashlib

        messages = params.get("messages", [])
        if not messages:
            return ""

        # Use the system messages as the prefix
        system_messages = [
            m.get("content", "")
            for m in messages
            if m.get("role") == "system"
        ]

        if not system_messages:
            return ""

        prefix_str = "".join(system_messages)
        return hashlib.sha256(prefix_str.encode()).hexdigest()[:16]
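Only the concatenated system messages feed the hash, so requests sharing a system prompt land in the same group while user turns are ignored. A small check of that behavior (module name inferred from the file path; the engine argument is never touched by this method):

from batch_processor import ContinuousBatcher  # assumed module name

b = ContinuousBatcher(engine=None)  # engine is unused by _compute_prefix_hash
sys_msg = {"role": "system", "content": "You are a helpful assistant."}
h1 = b._compute_prefix_hash({"messages": [sys_msg, {"role": "user", "content": "hi"}]})
h2 = b._compute_prefix_hash({"messages": [sys_msg, {"role": "user", "content": "bye"}]})
print(h1 == h2)                                 # True -> same prefix group
print(b._compute_prefix_hash({"prompt": "x"}))  # "" -> never grouped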

    def get_stats(self) -> Dict[str, Any]:
        """Return statistics, plus the current queue and group sizes."""
        return {
            **self._stats,
            "queue_size": len(self._pending),
            "prefix_groups": len(self._pending_by_prefix),
        }
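Taken together, this class's public surface is start(), submit(), get_stats(), and stop(). A minimal usage sketch, assuming only what the class itself requires of an engine (EchoEngine below is a hypothetical stand-in, not part of the package; the batch_processor module name is inferred from the file path):

import asyncio
from typing import Any, Dict, List

from batch_processor import ContinuousBatcher  # assumed module name

class EchoEngine:
    """Hypothetical engine: any object with batch_inference_async works."""
    async def batch_inference_async(self, params_list: List[Dict[str, Any]]) -> List[Any]:
        await asyncio.sleep(0.01)  # stand-in for a GPU forward pass
        return [{"text": p.get("prompt", "")} for p in params_list]

async def main() -> None:
    batcher = ContinuousBatcher(EchoEngine(), max_batch_size=8, max_wait_ms=20)
    await batcher.start()
    # Four concurrent submissions fit in one batch; the 20 ms timer flushes it.
    results = await asyncio.gather(*[
        batcher.submit(f"job-{i}", {"prompt": f"hello {i}"})
        for i in range(4)
    ])
    print(results)              # four {"text": ...} dicts, in submission order
    print(batcher.get_stats())  # total_requests == 4, total_batches >= 1
    await batcher.stop()

asyncio.run(main())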


class AdaptiveBatcher(ContinuousBatcher):
    """
    Adaptive batcher.

    Dynamically adjusts batching parameters based on load and latency targets.
    """

    def __init__(
        self,
        engine,
        min_batch_size: int = 1,
        max_batch_size: int = 64,
        target_latency_ms: float = 100,
        **kwargs
    ):
        super().__init__(engine, max_batch_size=max_batch_size, **kwargs)
        self.min_batch_size = min_batch_size
        self.target_latency_ms = target_latency_ms

        # Adaptive state
        self._current_batch_size = max_batch_size // 2
        self._latency_history: List[float] = []
        self._max_history = 100

    async def _process_batch(self) -> None:
        """Process a batch, then adapt the parameters."""
        start_time = time.time()

        # Temporarily apply the adaptive batch size (note: mutating shared
        # state like this is not safe if two batches overlap)
        original_max = self.max_batch_size
        self.max_batch_size = self._current_batch_size

        await super()._process_batch()

        self.max_batch_size = original_max

        # Record the latency
        latency_ms = (time.time() - start_time) * 1000
        self._latency_history.append(latency_ms)
        if len(self._latency_history) > self._max_history:
            self._latency_history.pop(0)

        # Adapt
        self._adapt_batch_size()

    def _adapt_batch_size(self) -> None:
        """Adjust the batch size from recent latency history."""
        if len(self._latency_history) < 10:
            return

        avg_latency = sum(self._latency_history[-10:]) / 10

        if avg_latency > self.target_latency_ms * 1.2:
            # Latency too high: shrink the batch
            self._current_batch_size = max(
                self.min_batch_size,
                int(self._current_batch_size * 0.8)
            )
        elif avg_latency < self.target_latency_ms * 0.8:
            # Latency is low: grow the batch
            self._current_batch_size = min(
                self.max_batch_size,
                int(self._current_batch_size * 1.2)
            )

        logger.debug(
            f"Adaptive batch size: {self._current_batch_size} "
            f"(avg latency: {avg_latency:.1f}ms)"
        )
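The 0.8x/1.2x thresholds around target_latency_ms form a dead band, so the size settles rather than oscillating. A toy trace of the same update rule under an assumed linear cost model (2 ms per batched request; the model is illustrative, not measured from the package):

size, target = 32, 100.0
for _ in range(6):
    latency = 2.0 * size                # hypothetical cost model
    if latency > target * 1.2:          # too slow: shrink by 20%
        size = max(1, int(size * 0.8))
    elif latency < target * 0.8:        # headroom: grow by 20%
        size = min(64, int(size * 1.2))
    print(size, latency)
# sizes 38, 45, then steady at 45: 90 ms sits inside the 80-120 ms dead band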
package/bin/gpu-worker.js
@@ -0,0 +1,275 @@
#!/usr/bin/env node
/**
 * GPU Worker CLI - Node.js entry point
 * Wraps the Python worker to provide a simple npm/npx install experience
 */

const { Command } = require('commander');
const chalk = require('chalk');
const ora = require('ora');
const inquirer = require('inquirer');
const { spawn, execSync } = require('child_process');
const path = require('path');
const fs = require('fs');
const which = require('which');

const PACKAGE_DIR = path.resolve(__dirname, '..');
const PYTHON_DIR = PACKAGE_DIR;
const CONFIG_FILE = path.join(process.cwd(), 'config.yaml');

// Locate a suitable Python interpreter
function findPython() {
  const pythonCommands = ['python3', 'python', 'py'];

  for (const cmd of pythonCommands) {
    try {
      const pythonPath = which.sync(cmd); // throws if cmd is not on PATH
      // Check the version: accept Python 3.9 or newer
      const version = execSync(`${cmd} --version`, { encoding: 'utf8' });
      const match = version.match(/Python (\d+)\.(\d+)/);
      if (match) {
        const major = parseInt(match[1], 10);
        const minor = parseInt(match[2], 10);
        if (major > 3 || (major === 3 && minor >= 9)) {
          return cmd;
        }
      }
    } catch (e) {
      continue;
    }
  }
  return null;
}

// Look for the package-local virtual environment
function getVenvPython() {
  const venvPath = path.join(PACKAGE_DIR, '.venv');

  if (process.platform === 'win32') {
    const pythonPath = path.join(venvPath, 'Scripts', 'python.exe');
    if (fs.existsSync(pythonPath)) return pythonPath;
  } else {
    const pythonPath = path.join(venvPath, 'bin', 'python');
    if (fs.existsSync(pythonPath)) return pythonPath;
  }
  return null;
}

// Create the virtual environment
async function createVenv(pythonCmd) {
  const spinner = ora('Creating Python virtual environment...').start();
  const venvPath = path.join(PACKAGE_DIR, '.venv');

  try {
    execSync(`${pythonCmd} -m venv "${venvPath}"`, { stdio: 'pipe' });
    spinner.succeed('Virtual environment created');
    return true;
  } catch (e) {
    spinner.fail('Failed to create virtual environment');
    console.error(chalk.red(e.message));
    return false;
  }
}

// Install the Python dependencies
async function installDependencies() {
  const venvPython = getVenvPython();
  if (!venvPython) {
    console.error(chalk.red('Virtual environment not found'));
    return false;
  }

  const spinner = ora('Installing Python dependencies...').start();
  const requirementsFile = path.join(PACKAGE_DIR, 'requirements.txt');

  try {
    execSync(`"${venvPython}" -m pip install -r "${requirementsFile}" -q`, {
      stdio: 'pipe',
      timeout: 600000 // 10-minute timeout
    });
    spinner.succeed('Dependencies installed');
    return true;
  } catch (e) {
    spinner.fail('Failed to install dependencies');
    console.error(chalk.red(e.message));
    return false;
  }
}

// Run the Python CLI
function runPythonCLI(args) {
  let pythonCmd = getVenvPython() || findPython();

  if (!pythonCmd) {
    console.error(chalk.red('Python 3.9+ not found!'));
    console.log(chalk.yellow('Please install Python 3.9 or higher:'));
    console.log(' - Windows: https://www.python.org/downloads/');
    console.log(' - macOS: brew install python@3.11');
    console.log(' - Linux: sudo apt install python3.11');
    process.exit(1);
  }

  const cliPath = path.join(PACKAGE_DIR, 'cli.py');

  const proc = spawn(pythonCmd, [cliPath, ...args], {
    stdio: 'inherit',
    cwd: process.cwd()
  });

  proc.on('close', (code) => {
    process.exit(code);
  });

  proc.on('error', (err) => {
    console.error(chalk.red('Failed to start Python process:'), err.message);
    process.exit(1);
  });
}

// First-run setup check
async function ensureSetup() {
  const venvPython = getVenvPython();

  if (!venvPython) {
    console.log(chalk.cyan('First time setup detected. Setting up environment...\n'));

    const pythonCmd = findPython();
    if (!pythonCmd) {
      console.error(chalk.red('Python 3.9+ is required but not found!'));
      console.log(chalk.yellow('\nPlease install Python:'));
      console.log(' - Windows: https://www.python.org/downloads/');
      console.log(' - macOS: brew install python@3.11');
      console.log(' - Linux: sudo apt install python3.11');
      process.exit(1);
    }

    console.log(chalk.green(`Found Python: ${pythonCmd}`));

    if (!await createVenv(pythonCmd)) {
      process.exit(1);
    }

    if (!await installDependencies()) {
      process.exit(1);
    }

    console.log(chalk.green('\n✓ Setup complete!\n'));
  }
}

// Main program
const program = new Command();

program
  .name('gpu-worker')
  .description('Distributed GPU inference worker node')
  .version('1.0.0');

program
  .command('install')
  .description('Install/update the Python dependencies')
  .action(async () => {
    await ensureSetup();
    await installDependencies();
  });

program
  .command('configure')
  .description('Interactive configuration wizard')
  .action(async () => {
    await ensureSetup();
    runPythonCLI(['configure']);
  });

program
  .command('start')
  .description('Start the worker')
  .option('-c, --config <path>', 'path to the config file', 'config.yaml')
  .action(async (options) => {
    await ensureSetup();

    // Check for the config file
    const configPath = path.resolve(options.config);
    if (!fs.existsSync(configPath)) {
      console.log(chalk.yellow('No config file found. Starting configuration wizard...\n'));
      runPythonCLI(['configure']);
      return;
    }

    runPythonCLI(['start', '-c', configPath]);
  });

program
  .command('status')
  .description('Show worker status')
  .action(async () => {
    await ensureSetup();
    runPythonCLI(['status']);
  });

program
  .command('set <key> <value>')
  .description('Set a configuration value')
  .action(async (key, value) => {
    await ensureSetup();
    runPythonCLI(['set', key, value]);
  });

program
  .command('setup')
  .description('Initialize the environment (create the venv and install dependencies)')
  .action(async () => {
    const pythonCmd = findPython();
    if (!pythonCmd) {
      console.error(chalk.red('Python 3.9+ not found!'));
      process.exit(1);
    }

    console.log(chalk.cyan('Setting up GPU Worker environment...\n'));
    console.log(chalk.green(`Python: ${pythonCmd}`));

    await createVenv(pythonCmd);
    await installDependencies();

    console.log(chalk.green('\n✓ Setup complete!'));
    console.log(chalk.cyan('\nNext steps:'));
    console.log(' 1. Run: gpu-worker configure');
    console.log(' 2. Run: gpu-worker start');
  });

// Quick-start menu (the default command when no arguments are given)
program
  .command('quick', { isDefault: true, hidden: true })
  .action(async () => {
    await ensureSetup();

    console.log(chalk.cyan.bold('\n GPU Worker - distributed GPU inference node\n'));

    const choices = [
      { name: '🚀 Start worker', value: 'start' },
      { name: '⚙️ Configuration wizard', value: 'configure' },
      { name: '📊 Show status', value: 'status' },
      { name: '📦 Install dependencies', value: 'install' },
      { name: '❌ Exit', value: 'exit' }
    ];

    const { action } = await inquirer.prompt([{
      type: 'list',
      name: 'action',
      message: 'Select an action:',
      choices
    }]);

    if (action === 'exit') {
      process.exit(0);
    }

    if (action === 'start') {
      const configPath = path.join(process.cwd(), 'config.yaml');
      if (!fs.existsSync(configPath)) {
        console.log(chalk.yellow('\nNo config file found, running configuration first...\n'));
        runPythonCLI(['configure']);
        return;
      }
    }

    runPythonCLI([action]);
  });

program.parse();