gpu-worker 1.0.0
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- package/README.md +115 -0
- package/api_client.py +288 -0
- package/batch_processor.py +436 -0
- package/bin/gpu-worker.js +275 -0
- package/cli.py +729 -0
- package/config.2gb.yaml +32 -0
- package/config.8gb.yaml +29 -0
- package/config.example.yaml +72 -0
- package/config.py +213 -0
- package/direct_server.py +140 -0
- package/distributed/__init__.py +35 -0
- package/distributed/grpc_server.py +561 -0
- package/distributed/kv_cache.py +555 -0
- package/distributed/model_shard.py +465 -0
- package/distributed/session.py +455 -0
- package/engines/__init__.py +215 -0
- package/engines/base.py +57 -0
- package/engines/image_gen.py +83 -0
- package/engines/llm.py +97 -0
- package/engines/llm_base.py +216 -0
- package/engines/llm_sglang.py +489 -0
- package/engines/llm_vllm.py +539 -0
- package/engines/speculative.py +513 -0
- package/engines/vision.py +139 -0
- package/machine_id.py +200 -0
- package/main.py +521 -0
- package/package.json +64 -0
- package/requirements-sglang.txt +12 -0
- package/requirements-vllm.txt +15 -0
- package/requirements.txt +35 -0
- package/scripts/postinstall.js +60 -0
- package/setup.py +43 -0
package/main.py
ADDED
@@ -0,0 +1,521 @@
"""
Worker main program (lightweight edition)

The core logic lives on the server; the worker is only responsible for:
- registration and heartbeats
- executing inference jobs
- reporting status
"""
import time
import signal
import logging
from threading import Thread, Event
from typing import Optional, Dict, Any, List
from datetime import datetime
import torch
import sys

from config import WorkerConfig, load_config
from api_client import APIClient
from engines import ENGINE_REGISTRY, BaseEngine

logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)


class Worker:
    """
    Distributed GPU inference worker (lightweight edition)

    Core principles:
    - The worker is a "dumb terminal"; the server makes the decisions
    - Configuration is fetched from the server
    - Load control is decided by the server
    """

    def __init__(self, config: WorkerConfig = None):
        self.config = config or load_config()
        self.api_client: Optional[APIClient] = None
        self.engines: Dict[str, BaseEngine] = {}

        # Credentials
        self.worker_id: Optional[str] = self.config.worker_id
        self.token: Optional[str] = self.config.token
        self.refresh_token: Optional[str] = None
        self.signing_secret: Optional[str] = None
        self.token_expires_at: Optional[datetime] = None

        # State
        self.status = "initializing"
        self.current_job_id: Optional[str] = None
        self.running = False
        self.accepting_jobs = True

        # Server-side configuration (fetched from the server)
        self.remote_config: Dict[str, Any] = {}
        self.config_version = 0

        # For graceful shutdown
        self.shutdown_event = Event()

        # Direct-connection server (optional)
        self.direct_server = None

    def _get_gpu_info(self) -> Dict[str, Any]:
        """Collect GPU information."""
        if torch.cuda.is_available():
            props = torch.cuda.get_device_properties(0)
            return {
                "model": torch.cuda.get_device_name(0),
                "memory_total_gb": round(props.total_memory / 1024**3, 2),
                "memory_used_gb": round(torch.cuda.memory_allocated() / 1024**3, 2),
                "gpu_count": torch.cuda.device_count()
            }
        return {
            "model": "CPU Only",
            "memory_total_gb": 0,
            "memory_used_gb": 0,
            "gpu_count": 0
        }

    def _register(self):
        """Register this worker with the server."""
        if self.worker_id and self.token:
            logger.info(f"Using existing worker ID: {self.worker_id}")
            self.api_client.set_credentials(self.token, self.signing_secret)

            # Check that the stored credentials are still valid
            if not self._verify_credentials():
                logger.warning("Existing credentials invalid, re-registering...")
                self.worker_id = None
                self.token = None

        if not self.worker_id:
            self._do_register()

        # Fetch the remote configuration
        self._fetch_remote_config()

    def _do_register(self):
        """Perform the registration call."""
        gpu_info = self._get_gpu_info()

        # Build the direct-connection URL
        direct_url = None
        if self.config.direct.enabled and self.config.direct.public_url:
            direct_url = self.config.direct.public_url

        result = self.api_client.register(
            name=self.config.name or f"Worker-{gpu_info.get('model', 'Unknown')[:20]}",
            region=self.config.region,
            country=self.config.country,
            city=self.config.city,
            timezone=self.config.timezone,
            gpu_model=gpu_info.get("model"),
            gpu_memory_gb=gpu_info.get("memory_total_gb"),
            gpu_count=gpu_info.get("gpu_count", 1),
            supported_types=self.config.supported_types,
            direct_url=direct_url,
            supports_direct=self.config.direct.enabled
        )

        # Store the credentials
        self.worker_id = result["worker_id"]
        self.token = result["token"]
        self.refresh_token = result.get("refresh_token")
        self.signing_secret = result.get("signing_secret")

        if result.get("token_expires_at"):
            self.token_expires_at = datetime.fromisoformat(result["token_expires_at"])

        # Persist them to the config file
        self.config.worker_id = self.worker_id
        self.config.token = self.token
        self.config.save()

        # Update the API client credentials
        self.api_client.set_credentials(self.token, self.signing_secret)

        logger.info(f"Registered as worker: {self.worker_id}")

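    # Shape of the registration response consumed above, inferred from the key
    # accesses in _do_register() (illustrative values, not from the real API):
    #
    #     {
    #         "worker_id": "wk-123",               # required
    #         "token": "<access token>",           # required
    #         "refresh_token": "<refresh token>",  # optional
    #         "signing_secret": "<hmac secret>",   # optional
    #         "token_expires_at": "2025-06-01T00:00:00",  # optional, ISO 8601
    #     }
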
    def _verify_credentials(self) -> bool:
        """Check whether the current credentials are valid."""
        try:
            return self.api_client.verify_credentials(self.worker_id, self.token)
        except Exception as e:
            logger.error(f"Credential verification failed: {e}")
            return False

    def _fetch_remote_config(self):
        """Fetch configuration from the server."""
        try:
            config = self.api_client.get_config(self.worker_id)

            if config and config.get("version", 0) > self.config_version:
                self.remote_config = config
                self.config_version = config.get("version", 0)
                logger.info(f"Remote config updated to version {self.config_version}")

                # Apply the load-control settings
                self._apply_load_control(config.get("load_control", {}))

        except Exception as e:
            logger.warning(f"Failed to fetch remote config: {e}")

    def _apply_load_control(self, load_control: Dict[str, Any]):
        """Apply load-control settings."""
        # Check working hours
        if "working_hours_start" in load_control and "working_hours_end" in load_control:
            current_hour = datetime.now().hour
            start = load_control["working_hours_start"]
            end = load_control["working_hours_end"]

            if start <= end:
                in_working_hours = start <= current_hour < end
            else:
                # Window wraps around midnight
                in_working_hours = current_hour >= start or current_hour < end

            if not in_working_hours:
                self.accepting_jobs = False
                logger.info("Outside working hours, not accepting jobs")

    def _should_accept_job(self) -> bool:
        """Decide whether to accept a job (the server's policy takes precedence)."""
        if not self.accepting_jobs:
            return False

        # Server-side load-control settings
        load_control = self.remote_config.get("load_control", {})

        # Check working hours
        if "working_hours_start" in load_control:
            current_hour = datetime.now().hour
            start = load_control.get("working_hours_start", 0)
            end = load_control.get("working_hours_end", 24)

            if start <= end:
                if not (start <= current_hour < end):
                    return False
            else:
                # Window wraps around midnight
                if not (current_hour >= start or current_hour < end):
                    return False

        return True

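    # Worked example for the wrap-around branch above: with
    # working_hours_start=22 and working_hours_end=6 the window crosses
    # midnight, so the predicate becomes (hour >= 22 or hour < 6):
    #
    #     hour = 23  ->  23 >= 22            -> inside the window
    #     hour = 3   ->  3 < 6               -> inside the window
    #     hour = 12  ->  12 < 22 and 12 >= 6 -> outside, job is declined
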
    def _refresh_token_if_needed(self):
        """Refresh the token when it is close to expiry."""
        if not self.token_expires_at or not self.refresh_token:
            return

        # Refresh 4 hours before expiry
        now = datetime.utcnow()
        hours_until_expiry = (self.token_expires_at - now).total_seconds() / 3600

        if hours_until_expiry < 4:
            try:
                result = self.api_client.refresh_token(
                    self.worker_id, self.refresh_token
                )

                if result:
                    self.token = result["token"]
                    self.refresh_token = result.get("refresh_token")
                    if result.get("token_expires_at"):
                        self.token_expires_at = datetime.fromisoformat(result["token_expires_at"])

                    self.api_client.set_credentials(self.token, self.signing_secret)
                    logger.info("Token refreshed successfully")

            except Exception as e:
                logger.error(f"Token refresh failed: {e}")

    def _load_engines(self):
        """Load the inference engines."""
        # Prefer the models configured on the server
        model_configs = self.remote_config.get("model_configs", {})

        # Iterate over a copy: engine types that fail to load are removed
        # from the list below, and mutating it mid-iteration would skip entries
        for engine_type in list(self.config.supported_types):
            if engine_type not in ENGINE_REGISTRY:
                logger.warning(f"Unknown engine type: {engine_type}, skipping")
                continue

            # Merge local and remote configuration
            engine_config = self.config.engines.get(engine_type, {})
            if engine_type in model_configs:
                engine_config.update(model_configs[engine_type])

            engine_config["enable_cpu_offload"] = self.config.gpu.enable_cpu_offload

            logger.info(f"Loading engine: {engine_type}")

            try:
                engine = ENGINE_REGISTRY[engine_type](engine_config)
                engine.load_model()
                self.engines[engine_type] = engine
                logger.info(f"Engine {engine_type} loaded successfully")
            except Exception as e:
                logger.error(f"Failed to load engine {engine_type}: {e}")
                if engine_type in self.config.supported_types:
                    self.config.supported_types.remove(engine_type)

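    # A minimal sketch of the engine contract assumed by _load_engines() and
    # _process_job(); the real interface lives in engines/base.py (BaseEngine),
    # which this diff does not show, and "echo" is a hypothetical engine type:
    #
    #     class EchoEngine(BaseEngine):
    #         def __init__(self, config):
    #             self.config = config           # may carry "model_id"
    #         def load_model(self):
    #             pass                           # allocate weights / GPU memory
    #         def inference(self, params):
    #             return {"echo": params}        # returned to complete_job()
    #         def unload_model(self):
    #             pass                           # release resources on shutdown
    #
    #     ENGINE_REGISTRY["echo"] = EchoEngine
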
    def _heartbeat_loop(self):
        """Heartbeat loop."""
        heartbeat_count = 0

        while self.running and not self.shutdown_event.is_set():
            try:
                gpu_info = self._get_gpu_info()

                # Determine the status to report
                if not self.accepting_jobs:
                    status = "going_offline"
                elif self.current_job_id:
                    status = "busy"
                else:
                    status = "online"

                response = self.api_client.heartbeat(
                    worker_id=self.worker_id,
                    status=status,
                    current_job_id=self.current_job_id,
                    gpu_memory_used_gb=gpu_info.get("memory_used_gb"),
                    supported_types=list(self.engines.keys()),
                    loaded_models=self._get_loaded_models(),
                    config_version=self.config_version
                )

                # Handle the server response
                if response:
                    # Configuration update
                    if response.get("config_changed"):
                        self._fetch_remote_config()

                    # Server commands
                    action = response.get("action")
                    if action == "shutdown":
                        logger.info("Received shutdown command from server")
                        self.request_shutdown()
                    elif action == "reload_config":
                        self._fetch_remote_config()

                # Refresh the token periodically
                heartbeat_count += 1
                if heartbeat_count % 10 == 0:
                    self._refresh_token_if_needed()

            except Exception as e:
                logger.error(f"Heartbeat error: {e}")

            self.shutdown_event.wait(timeout=self.config.heartbeat_interval)

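    # Heartbeat response fields consumed above, shown with illustrative values
    # (the field names come from the .get() calls; the values are assumptions):
    #
    #     {
    #         "config_changed": True,     # triggers _fetch_remote_config()
    #         "action": "reload_config",  # or "shutdown"; anything else is ignored
    #     }
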
    def _main_loop(self):
        """Main job-processing loop."""
        while self.running:
            if not self.accepting_jobs and not self.current_job_id:
                logger.info("No more jobs to process, shutting down")
                break

            # Check whether we should accept a job
            if self._should_accept_job() and self.status == "idle":
                try:
                    job = self.api_client.fetch_next_job(self.worker_id)

                    if job:
                        self._process_job(job)

                except Exception as e:
                    logger.error(f"Error fetching job: {e}")

            if self.shutdown_event.wait(timeout=self.config.poll_interval):
                if not self.current_job_id:
                    break

    def _process_job(self, job: Dict[str, Any]):
        """Process a single job."""
        job_id = job["job_id"]
        job_type = job["type"]
        params = job["params"]

        logger.info(f"Processing job {job_id} (type: {job_type})")

        self.status = "busy"
        self.current_job_id = job_id

        try:
            engine = self.engines.get(job_type)
            if not engine:
                raise ValueError(f"No engine for type: {job_type}")

            start_time = time.time()
            result = engine.inference(params)
            processing_time_ms = int((time.time() - start_time) * 1000)

            self.api_client.complete_job(
                worker_id=self.worker_id,
                job_id=job_id,
                success=True,
                result=result,
                processing_time_ms=processing_time_ms
            )

            logger.info(f"Job {job_id} completed in {processing_time_ms}ms")

        except Exception as e:
            logger.error(f"Error processing job {job_id}: {e}")
            self.api_client.complete_job(
                worker_id=self.worker_id,
                job_id=job_id,
                success=False,
                error=str(e)
            )

        finally:
            self.status = "idle"
            self.current_job_id = None

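    # Shape of the job payload consumed by _process_job(), inferred from the
    # subscript accesses above (illustrative values):
    #
    #     {
    #         "job_id": "job-42",
    #         "type": "llm",                # must match a loaded engine type
    #         "params": {"prompt": "..."},  # passed straight to engine.inference()
    #     }
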
    def _get_loaded_models(self) -> List[str]:
        """Return the list of loaded model IDs."""
        models = []
        for engine_type, engine in self.engines.items():
            if hasattr(engine, 'config') and 'model_id' in engine.config:
                models.append(engine.config['model_id'])
        return models

    def _start_direct_server(self):
        """Start the direct-connection server (optional)."""
        if not self.config.direct.enabled:
            return

        from direct_server import DirectServer

        self.direct_server = DirectServer(
            worker=self,
            host=self.config.direct.host,
            port=self.config.direct.port
        )

        direct_thread = Thread(target=self.direct_server.start, daemon=True)
        direct_thread.start()

        logger.info(f"Direct server started on {self.config.direct.host}:{self.config.direct.port}")

    def start(self):
        """Start the worker."""
        logger.info("=" * 50)
        logger.info("Starting GPU Worker (Lightweight)")
        logger.info("=" * 50)

        signal.signal(signal.SIGINT, self._signal_handler)
        signal.signal(signal.SIGTERM, self._signal_handler)

        self.api_client = APIClient(
            base_url=self.config.server.url,
            token=self.token,
            timeout=self.config.server.timeout
        )

        self._register()
        self._load_engines()

        if not self.engines:
            logger.error("No engines loaded, exiting")
            return

        self.running = True
        self.status = "idle"

        self._start_direct_server()

        heartbeat_thread = Thread(target=self._heartbeat_loop, daemon=True)
        heartbeat_thread.start()

        logger.info(f"Worker {self.worker_id} started")
        logger.info(f"Supported types: {list(self.engines.keys())}")

        try:
            self._main_loop()
        except Exception as e:
            logger.error(f"Main loop error: {e}")
        finally:
            self.shutdown()

    def request_shutdown(self, graceful: bool = True):
        """Request a worker shutdown."""
        logger.info(f"Shutdown requested (graceful={graceful})")

        if graceful:
            self.accepting_jobs = False

            try:
                self.api_client.notify_going_offline(
                    self.worker_id,
                    finish_current=True
                )
            except Exception as e:
                logger.error(f"Failed to notify server: {e}")

            if not self.current_job_id:
                self.shutdown_event.set()
        else:
            self.shutdown_event.set()
            self.running = False

    def shutdown(self):
        """Shut down the worker."""
        logger.info("Shutting down worker...")
        self.running = False
        self.shutdown_event.set()

        try:
            if self.api_client and self.worker_id:
                self.api_client.notify_offline(self.worker_id)
        except Exception as e:
            logger.error(f"Failed to notify offline: {e}")

        if self.direct_server:
            self.direct_server.stop()

        for name, engine in self.engines.items():
            try:
                engine.unload_model()
                logger.info(f"Engine {name} unloaded")
            except Exception as e:
                logger.error(f"Error unloading engine {name}: {e}")

        if self.api_client:
            self.api_client.close()

        logger.info("Worker shutdown complete")

    def _signal_handler(self, signum, frame):
        """Signal handler."""
        logger.info(f"Received signal {signum}")
        self.request_shutdown(graceful=True)


def main():
    """Main entry point - invoked by the CLI."""
    import argparse

    parser = argparse.ArgumentParser(description="GPU Worker")
    parser.add_argument("--config", "-c", default="config.yaml", help="Config file path")
    parser.add_argument("--region", "-r", help="Override region")
    parser.add_argument("--server", "-s", help="Override server URL")

    args = parser.parse_args()

    config = load_config(args.config)

    if args.region:
        config.region = args.region
    if args.server:
        config.server.url = args.server

    worker = Worker(config)
    worker.start()


if __name__ == "__main__":
    main()
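Besides the CLI entry point above, the worker can be embedded programmatically. A minimal sketch, assuming a config.yaml in the working directory; the region and server values are placeholders mirroring the --region/--server overrides in main():

    from config import load_config
    from main import Worker

    config = load_config("config.yaml")
    config.region = "us-west"                      # placeholder region
    config.server.url = "https://api.example.com"  # placeholder server URL
    Worker(config).start()                         # blocks until shutdown
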
package/package.json
ADDED
@@ -0,0 +1,64 @@
{
  "name": "gpu-worker",
  "version": "1.0.0",
  "description": "Distributed GPU Inference Worker - Share idle GPU computing power for LLM and image generation",
  "keywords": [
    "gpu",
    "inference",
    "distributed",
    "machine-learning",
    "llm",
    "stable-diffusion",
    "cuda",
    "pytorch",
    "transformers",
    "vllm",
    "sglang"
  ],
  "author": "Baozhi888 <kj331704@gmail.com>",
  "license": "MIT",
  "repository": {
    "type": "git",
    "url": "git+https://github.com/Baozhi888/distributed-gpu-inference.git",
    "directory": "worker"
  },
  "homepage": "https://github.com/Baozhi888/distributed-gpu-inference#readme",
  "bugs": {
    "url": "https://github.com/Baozhi888/distributed-gpu-inference/issues"
  },
  "bin": {
    "gpu-worker": "./bin/gpu-worker.js"
  },
  "scripts": {
    "postinstall": "node scripts/postinstall.js",
    "start": "node bin/gpu-worker.js start",
    "configure": "node bin/gpu-worker.js configure",
    "status": "node bin/gpu-worker.js status"
  },
  "engines": {
    "node": ">=16.0.0"
  },
  "os": [
    "win32",
    "linux",
    "darwin"
  ],
  "files": [
    "bin/",
    "scripts/",
    "engines/",
    "distributed/",
    "*.py",
    "*.yaml",
    "*.txt",
    "!.env*",
    "README.md"
  ],
  "dependencies": {
    "chalk": "^4.1.2",
    "commander": "^11.1.0",
    "inquirer": "^8.2.6",
    "ora": "^5.4.1",
    "which": "^3.0.1"
  }
}
package/requirements-sglang.txt
ADDED
@@ -0,0 +1,12 @@
# SGLang high-performance inference backend dependencies
# Install with: pip install -r requirements-sglang.txt

# Core dependency
sglang[all]>=0.4.0

# FlashInfer (optional, for extra speed)
# Install the wheel matching your CUDA version:
# pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/
# pip install flashinfer -i https://flashinfer.ai/whl/cu124/torch2.4/

# Note: SGLang requires CUDA 12.1+ and Python 3.9+
package/requirements-vllm.txt
ADDED
@@ -0,0 +1,15 @@
# vLLM high-performance inference backend dependencies
# Install with: pip install -r requirements-vllm.txt

# Core dependency
vllm>=0.6.0

# Note: vLLM requires:
# - CUDA 11.8+ or 12.x
# - Python 3.8+
# - Linux (Windows needs WSL2)

# Optional quantization support
# bitsandbytes>=0.42.0  # INT8/NF4 quantization
# auto-gptq>=0.7.0      # GPTQ quantization
# autoawq>=0.2.0        # AWQ quantization
package/requirements.txt
ADDED
@@ -0,0 +1,35 @@
# Worker core dependencies
torch>=2.0.0
transformers>=4.35.0
diffusers>=0.24.0
accelerate>=0.24.0
peft>=0.6.0
httpx>=0.25.0
pyyaml>=6.0
pydantic>=2.0.0
sentencepiece
protobuf

# CLI tooling
rich>=13.0.0

# Direct-connection server
fastapi>=0.100.0
uvicorn>=0.23.0
aiohttp>=3.9.0

# High-performance inference backends (optional; see requirements-sglang.txt / requirements-vllm.txt)
# sglang>=0.1.0  # SGLang inference backend
# vllm>=0.4.0    # vLLM inference backend

# Distributed communication
grpcio>=1.60.0
grpcio-tools>=1.60.0

# Caching and serialization
redis>=5.0.0
lz4>=4.3.0  # fast compression

# Optional: CUDA quantization support (requires an NVIDIA GPU)
# bitsandbytes>=0.41.0