gpu-worker 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/main.py ADDED
@@ -0,0 +1,521 @@
+ """
+ Worker main program - lightweight edition
+ The core logic lives on the server; the worker is only responsible for:
+ - registering and sending heartbeats
+ - running inference jobs
+ - reporting status
+ """
+ import time
+ import signal
+ import logging
+ from threading import Thread, Event
+ from typing import Optional, Dict, Any, List
+ from datetime import datetime
+ import torch
+ import sys
+
+ from config import WorkerConfig, load_config
+ from api_client import APIClient
+ from engines import ENGINE_REGISTRY, BaseEngine
+
+ logging.basicConfig(
+     level=logging.INFO,
+     format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
+ )
+ logger = logging.getLogger(__name__)
+
+
+ class Worker:
+     """
+     Distributed GPU inference worker - lightweight edition
+
+     Core principles:
+     - The worker is a "dumb terminal"; the server makes the decisions
+     - Configuration is fetched from the server
+     - Load control is decided by the server
+     """
+
+     def __init__(self, config: WorkerConfig = None):
+         self.config = config or load_config()
+         self.api_client: Optional[APIClient] = None
+         self.engines: Dict[str, BaseEngine] = {}
+
+         # Credentials
+         self.worker_id: Optional[str] = self.config.worker_id
+         self.token: Optional[str] = self.config.token
+         self.refresh_token: Optional[str] = None
+         self.signing_secret: Optional[str] = None
+         self.token_expires_at: Optional[datetime] = None
+
+         # State
+         self.status = "initializing"
+         self.current_job_id: Optional[str] = None
+         self.running = False
+         self.accepting_jobs = True
+
+         # Server-side configuration (fetched from the server)
+         self.remote_config: Dict[str, Any] = {}
+         self.config_version = 0
+
+         # Used for graceful shutdown
+         self.shutdown_event = Event()
+
+         # Direct-connection server (optional)
+         self.direct_server = None
+
+     def _get_gpu_info(self) -> Dict[str, Any]:
+         """Collect GPU information."""
+         if torch.cuda.is_available():
+             props = torch.cuda.get_device_properties(0)
+             return {
+                 "model": torch.cuda.get_device_name(0),
+                 "memory_total_gb": round(props.total_memory / 1024**3, 2),
+                 "memory_used_gb": round(torch.cuda.memory_allocated() / 1024**3, 2),
+                 "gpu_count": torch.cuda.device_count()
+             }
+         return {
+             "model": "CPU Only",
+             "memory_total_gb": 0,
+             "memory_used_gb": 0,
+             "gpu_count": 0
+         }
+
+     def _register(self):
+         """Register this worker with the server."""
+         if self.worker_id and self.token:
+             logger.info(f"Using existing worker ID: {self.worker_id}")
+             self.api_client.set_credentials(self.token, self.signing_secret)
+
+             # Check whether the stored credentials are still valid
+             if not self._verify_credentials():
+                 logger.warning("Existing credentials invalid, re-registering...")
+                 self.worker_id = None
+                 self.token = None
+
+         if not self.worker_id:
+             self._do_register()
+
+         # Fetch the remote configuration
+         self._fetch_remote_config()
+
+     def _do_register(self):
+         """Perform the registration call."""
+         gpu_info = self._get_gpu_info()
+
+         # Build the direct-connection URL
+         direct_url = None
+         if self.config.direct.enabled and self.config.direct.public_url:
+             direct_url = self.config.direct.public_url
+
+         result = self.api_client.register(
+             name=self.config.name or f"Worker-{gpu_info.get('model', 'Unknown')[:20]}",
+             region=self.config.region,
+             country=self.config.country,
+             city=self.config.city,
+             timezone=self.config.timezone,
+             gpu_model=gpu_info.get("model"),
+             gpu_memory_gb=gpu_info.get("memory_total_gb"),
+             gpu_count=gpu_info.get("gpu_count", 1),
+             supported_types=self.config.supported_types,
+             direct_url=direct_url,
+             supports_direct=self.config.direct.enabled
+         )
+
+         # Store the returned credentials
+         self.worker_id = result["worker_id"]
+         self.token = result["token"]
+         self.refresh_token = result.get("refresh_token")
+         self.signing_secret = result.get("signing_secret")
+
+         if result.get("token_expires_at"):
+             self.token_expires_at = datetime.fromisoformat(result["token_expires_at"])
+
+         # Persist to the config file
+         self.config.worker_id = self.worker_id
+         self.config.token = self.token
+         self.config.save()
+
+         # Update the API client credentials
+         self.api_client.set_credentials(self.token, self.signing_secret)
+
+         logger.info(f"Registered as worker: {self.worker_id}")
+
+     def _verify_credentials(self) -> bool:
+         """Check whether the current credentials are valid."""
+         try:
+             return self.api_client.verify_credentials(self.worker_id, self.token)
+         except Exception as e:
+             logger.error(f"Credential verification failed: {e}")
+             return False
+
+     def _fetch_remote_config(self):
+         """Fetch configuration from the server."""
+         try:
+             config = self.api_client.get_config(self.worker_id)
+
+             if config and config.get("version", 0) > self.config_version:
+                 self.remote_config = config
+                 self.config_version = config.get("version", 0)
+                 logger.info(f"Remote config updated to version {self.config_version}")
+
+                 # Apply the load-control settings
+                 self._apply_load_control(config.get("load_control", {}))
+
+         except Exception as e:
+             logger.warning(f"Failed to fetch remote config: {e}")
+
+     def _apply_load_control(self, load_control: Dict[str, Any]):
+         """Apply the load-control settings."""
+         # Check working hours (end < start means the window wraps past midnight)
+         if "working_hours_start" in load_control and "working_hours_end" in load_control:
+             current_hour = datetime.now().hour
+             start = load_control["working_hours_start"]
+             end = load_control["working_hours_end"]
+
+             if start <= end:
+                 in_working_hours = start <= current_hour < end
+             else:
+                 in_working_hours = current_hour >= start or current_hour < end
+
+             if not in_working_hours:
+                 self.accepting_jobs = False
+                 logger.info("Outside working hours, not accepting jobs")
+
+     def _should_accept_job(self) -> bool:
+         """Decide whether to accept a job (the server's policy takes precedence)."""
+         if not self.accepting_jobs:
+             return False
+
+         # Load-control settings pushed by the server
+         load_control = self.remote_config.get("load_control", {})
+
+         # Check working hours
+         if "working_hours_start" in load_control:
+             current_hour = datetime.now().hour
+             start = load_control.get("working_hours_start", 0)
+             end = load_control.get("working_hours_end", 24)
+
+             if start <= end:
+                 if not (start <= current_hour < end):
+                     return False
+             else:
+                 if not (current_hour >= start or current_hour < end):
+                     return False
+
+         return True
+
+     def _refresh_token_if_needed(self):
+         """Refresh the token if it is close to expiry."""
+         if not self.token_expires_at or not self.refresh_token:
+             return
+
+         # Refresh once less than 4 hours of validity remain
+         now = datetime.utcnow()
+         hours_until_expiry = (self.token_expires_at - now).total_seconds() / 3600  # assumes naive UTC timestamps
+
+         if hours_until_expiry < 4:
+             try:
+                 result = self.api_client.refresh_token(
+                     self.worker_id, self.refresh_token
+                 )
+
+                 if result:
+                     self.token = result["token"]
+                     self.refresh_token = result.get("refresh_token")
+                     if result.get("token_expires_at"):
+                         self.token_expires_at = datetime.fromisoformat(result["token_expires_at"])
+
+                     self.api_client.set_credentials(self.token, self.signing_secret)
+                     logger.info("Token refreshed successfully")
+
+             except Exception as e:
+                 logger.error(f"Token refresh failed: {e}")
+
+     def _load_engines(self):
+         """Load the inference engines."""
+         # Prefer model settings configured on the server
+         model_configs = self.remote_config.get("model_configs", {})
+
+         for engine_type in list(self.config.supported_types):  # copy: entries may be removed on failure
+             if engine_type not in ENGINE_REGISTRY:
+                 logger.warning(f"Unknown engine type: {engine_type}, skipping")
+                 continue
+
+             # Merge local and remote configuration
+             engine_config = self.config.engines.get(engine_type, {})
+             if engine_type in model_configs:
+                 engine_config.update(model_configs[engine_type])
+
+             engine_config["enable_cpu_offload"] = self.config.gpu.enable_cpu_offload
+
+             logger.info(f"Loading engine: {engine_type}")
+
+             try:
+                 engine = ENGINE_REGISTRY[engine_type](engine_config)
+                 engine.load_model()
+                 self.engines[engine_type] = engine
+                 logger.info(f"Engine {engine_type} loaded successfully")
+             except Exception as e:
+                 logger.error(f"Failed to load engine {engine_type}: {e}")
+                 if engine_type in self.config.supported_types:
+                     self.config.supported_types.remove(engine_type)
+
+     def _heartbeat_loop(self):
+         """Heartbeat loop."""
+         heartbeat_count = 0
+
+         while self.running and not self.shutdown_event.is_set():
+             try:
+                 gpu_info = self._get_gpu_info()
+
+                 # Determine the status to report
+                 if not self.accepting_jobs:
+                     status = "going_offline"
+                 elif self.current_job_id:
+                     status = "busy"
+                 else:
+                     status = "online"
+
+                 response = self.api_client.heartbeat(
+                     worker_id=self.worker_id,
+                     status=status,
+                     current_job_id=self.current_job_id,
+                     gpu_memory_used_gb=gpu_info.get("memory_used_gb"),
+                     supported_types=list(self.engines.keys()),
+                     loaded_models=self._get_loaded_models(),
+                     config_version=self.config_version
+                 )
+
+                 # Handle the server's response
+                 if response:
+                     # Configuration update
+                     if response.get("config_changed"):
+                         self._fetch_remote_config()
+
+                     # Server commands
+                     action = response.get("action")
+                     if action == "shutdown":
+                         logger.info("Received shutdown command from server")
+                         self.request_shutdown()
+                     elif action == "reload_config":
+                         self._fetch_remote_config()
+
+                 # Refresh the token periodically
+                 heartbeat_count += 1
+                 if heartbeat_count % 10 == 0:
+                     self._refresh_token_if_needed()
+
+             except Exception as e:
+                 logger.error(f"Heartbeat error: {e}")
+
+             self.shutdown_event.wait(timeout=self.config.heartbeat_interval)
+
+     def _main_loop(self):
+         """Main job-processing loop."""
+         while self.running:
+             if not self.accepting_jobs and not self.current_job_id:
+                 logger.info("No more jobs to process, shutting down")
+                 break
+
+             # Check whether a job should be accepted
+             if self._should_accept_job() and self.status == "idle":
+                 try:
+                     job = self.api_client.fetch_next_job(self.worker_id)
+
+                     if job:
+                         self._process_job(job)
+
+                 except Exception as e:
+                     logger.error(f"Error fetching job: {e}")
+
+             if self.shutdown_event.wait(timeout=self.config.poll_interval):
+                 if not self.current_job_id:
+                     break
+
+     def _process_job(self, job: Dict[str, Any]):
+         """Process a single job."""
+         job_id = job["job_id"]
+         job_type = job["type"]
+         params = job["params"]
+
+         logger.info(f"Processing job {job_id} (type: {job_type})")
+
+         self.status = "busy"
+         self.current_job_id = job_id
+
+         try:
+             engine = self.engines.get(job_type)
+             if not engine:
+                 raise ValueError(f"No engine for type: {job_type}")
+
+             start_time = time.time()
+             result = engine.inference(params)
+             processing_time_ms = int((time.time() - start_time) * 1000)
+
+             self.api_client.complete_job(
+                 worker_id=self.worker_id,
+                 job_id=job_id,
+                 success=True,
+                 result=result,
+                 processing_time_ms=processing_time_ms
+             )
+
+             logger.info(f"Job {job_id} completed in {processing_time_ms}ms")
+
+         except Exception as e:
+             logger.error(f"Error processing job {job_id}: {e}")
+             self.api_client.complete_job(
+                 worker_id=self.worker_id,
+                 job_id=job_id,
+                 success=False,
+                 error=str(e)
+             )
+
+         finally:
+             self.status = "idle"
+             self.current_job_id = None
+
+     def _get_loaded_models(self) -> List[str]:
+         """Return the list of loaded model IDs."""
+         models = []
+         for engine in self.engines.values():
+             if hasattr(engine, 'config') and 'model_id' in engine.config:
+                 models.append(engine.config['model_id'])
+         return models
+
+     def _start_direct_server(self):
+         """Start the direct-connection server (optional)."""
+         if not self.config.direct.enabled:
+             return
+
+         from direct_server import DirectServer
+
+         self.direct_server = DirectServer(
+             worker=self,
+             host=self.config.direct.host,
+             port=self.config.direct.port
+         )
+
+         direct_thread = Thread(target=self.direct_server.start, daemon=True)
+         direct_thread.start()
+
+         logger.info(f"Direct server started on {self.config.direct.host}:{self.config.direct.port}")
+
+     def start(self):
+         """Start the worker."""
+         logger.info("=" * 50)
+         logger.info("Starting GPU Worker (Lightweight)")
+         logger.info("=" * 50)
+
+         signal.signal(signal.SIGINT, self._signal_handler)
+         signal.signal(signal.SIGTERM, self._signal_handler)
+
+         self.api_client = APIClient(
+             base_url=self.config.server.url,
+             token=self.token,
+             timeout=self.config.server.timeout
+         )
+
+         self._register()
+         self._load_engines()
+
+         if not self.engines:
+             logger.error("No engines loaded, exiting")
+             return
+
+         self.running = True
+         self.status = "idle"
+
+         self._start_direct_server()
+
+         heartbeat_thread = Thread(target=self._heartbeat_loop, daemon=True)
+         heartbeat_thread.start()
+
+         logger.info(f"Worker {self.worker_id} started")
+         logger.info(f"Supported types: {list(self.engines.keys())}")
+
+         try:
+             self._main_loop()
+         except Exception as e:
+             logger.error(f"Main loop error: {e}")
+         finally:
+             self.shutdown()
+
+     def request_shutdown(self, graceful: bool = True):
+         """Request a worker shutdown."""
+         logger.info(f"Shutdown requested (graceful={graceful})")
+
+         if graceful:
+             self.accepting_jobs = False
+
+             try:
+                 self.api_client.notify_going_offline(
+                     self.worker_id,
+                     finish_current=True
+                 )
+             except Exception as e:
+                 logger.error(f"Failed to notify server: {e}")
+
+             if not self.current_job_id:
+                 self.shutdown_event.set()
+         else:
+             self.shutdown_event.set()
+             self.running = False
+
+     def shutdown(self):
+         """Shut the worker down."""
+         logger.info("Shutting down worker...")
+         self.running = False
+         self.shutdown_event.set()
+
+         try:
+             if self.api_client and self.worker_id:
+                 self.api_client.notify_offline(self.worker_id)
+         except Exception as e:
+             logger.error(f"Failed to notify offline: {e}")
+
+         if self.direct_server:
+             self.direct_server.stop()
+
+         for name, engine in self.engines.items():
+             try:
+                 engine.unload_model()
+                 logger.info(f"Engine {name} unloaded")
+             except Exception as e:
+                 logger.error(f"Error unloading engine {name}: {e}")
+
+         if self.api_client:
+             self.api_client.close()
+
+         logger.info("Worker shutdown complete")
+
+     def _signal_handler(self, signum, frame):
+         """Signal handler."""
+         logger.info(f"Received signal {signum}")
+         self.request_shutdown(graceful=True)
+
+
+ def main():
+     """Main entry point - invoked by the CLI."""
+     import argparse
+
+     parser = argparse.ArgumentParser(description="GPU Worker")
+     parser.add_argument("--config", "-c", default="config.yaml", help="Config file path")
+     parser.add_argument("--region", "-r", help="Override region")
+     parser.add_argument("--server", "-s", help="Override server URL")
+
+     args = parser.parse_args()
+
+     config = load_config(args.config)
+
+     if args.region:
+         config.region = args.region
+     if args.server:
+         config.server.url = args.server
+
+     worker = Worker(config)
+     worker.start()
+
+
+ if __name__ == "__main__":
+     main()
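
Usage note: besides the packaged CLI (the Node wrapper's gpu-worker start, or python main.py --config config.yaml), the worker in main.py above can be started programmatically. A minimal sketch, assuming the package directory is importable and a config.yaml exists in the layout load_config expects; the region and server URL values are placeholders, mirroring the overrides main() applies for --region and --server:

    from config import load_config
    from main import Worker

    config = load_config("config.yaml")
    config.region = "us-east"                      # placeholder value, same effect as --region
    config.server.url = "https://example.com/api"  # placeholder value, same effect as --server

    worker = Worker(config)
    worker.start()  # blocks in the main loop; SIGINT/SIGTERM triggers a graceful shutdown
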
package/package.json ADDED
@@ -0,0 +1,64 @@
+ {
+   "name": "gpu-worker",
+   "version": "1.0.0",
+   "description": "Distributed GPU Inference Worker - Share idle GPU computing power for LLM and image generation",
+   "keywords": [
+     "gpu",
+     "inference",
+     "distributed",
+     "machine-learning",
+     "llm",
+     "stable-diffusion",
+     "cuda",
+     "pytorch",
+     "transformers",
+     "vllm",
+     "sglang"
+   ],
+   "author": "Baozhi888 <kj331704@gmail.com>",
+   "license": "MIT",
+   "repository": {
+     "type": "git",
+     "url": "git+https://github.com/Baozhi888/distributed-gpu-inference.git",
+     "directory": "worker"
+   },
+   "homepage": "https://github.com/Baozhi888/distributed-gpu-inference#readme",
+   "bugs": {
+     "url": "https://github.com/Baozhi888/distributed-gpu-inference/issues"
+   },
+   "bin": {
+     "gpu-worker": "./bin/gpu-worker.js"
+   },
+   "scripts": {
+     "postinstall": "node scripts/postinstall.js",
+     "start": "node bin/gpu-worker.js start",
+     "configure": "node bin/gpu-worker.js configure",
+     "status": "node bin/gpu-worker.js status"
+   },
+   "engines": {
+     "node": ">=16.0.0"
+   },
+   "os": [
+     "win32",
+     "linux",
+     "darwin"
+   ],
+   "files": [
+     "bin/",
+     "scripts/",
+     "engines/",
+     "distributed/",
+     "*.py",
+     "*.yaml",
+     "*.txt",
+     "!.env*",
+     "README.md"
+   ],
+   "dependencies": {
+     "chalk": "^4.1.2",
+     "commander": "^11.1.0",
+     "inquirer": "^8.2.6",
+     "ora": "^5.4.1",
+     "which": "^3.0.1"
+   }
+ }
package/requirements-sglang.txt ADDED
@@ -0,0 +1,12 @@
+ # Dependencies for the SGLang high-performance inference backend
+ # Install: pip install -r requirements-sglang.txt
+
+ # Core dependency
+ sglang[all]>=0.4.0
+
+ # FlashInfer (optional, for acceleration)
+ # Install the wheel matching your CUDA version:
+ # pip install flashinfer -i https://flashinfer.ai/whl/cu121/torch2.4/
+ # pip install flashinfer -i https://flashinfer.ai/whl/cu124/torch2.4/
+
+ # Note: SGLang requires CUDA 12.1+ and Python 3.9+
package/requirements-vllm.txt ADDED
@@ -0,0 +1,15 @@
+ # Dependencies for the vLLM high-performance inference backend
+ # Install: pip install -r requirements-vllm.txt
+
+ # Core dependency
+ vllm>=0.6.0
+
+ # Note: vLLM requires:
+ # - CUDA 11.8+ or 12.x
+ # - Python 3.8+
+ # - Linux (Windows requires WSL2)
+
+ # Optional quantization support
+ # bitsandbytes>=0.42.0  # INT8/NF4 quantization
+ # auto-gptq>=0.7.0      # GPTQ quantization
+ # autoawq>=0.2.0        # AWQ quantization
package/requirements.txt ADDED
@@ -0,0 +1,35 @@
+ # Worker core dependencies
+ torch>=2.0.0
+ transformers>=4.35.0
+ diffusers>=0.24.0
+ accelerate>=0.24.0
+ peft>=0.6.0
+ httpx>=0.25.0
+ pyyaml>=6.0
+ pydantic>=2.0.0
+ sentencepiece
+ protobuf
+
+ # CLI tooling
+ rich>=13.0.0
+
+ # Direct-connection server
+ fastapi>=0.100.0
+ uvicorn>=0.23.0
+ aiohttp>=3.9.0
+
+ # High-performance inference backends (optional, install as needed)
+ # sglang>=0.1.0  # SGLang inference backend
+ # vllm>=0.4.0    # vLLM inference backend
+
+ # Distributed communication
+ grpcio>=1.60.0
+ grpcio-tools>=1.60.0
+
+ # Caching and serialization
+ redis>=5.0.0
+ lz4>=4.3.0  # fast compression
+
+ # Optional: CUDA quantization support (requires an NVIDIA GPU)
+ # bitsandbytes>=0.41.0
+
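
Note on the engine interface: main.py resolves backends through ENGINE_REGISTRY and only ever calls a constructor that takes a config dict, load_model(), inference(params), and unload_model(), and optionally reads engine.config['model_id']. A minimal sketch of that implied contract, with a hypothetical EchoEngine and registry key that are not part of the packaged implementation:

    from typing import Any, Dict
    from engines import BaseEngine, ENGINE_REGISTRY

    class EchoEngine(BaseEngine):
        """Toy engine exercising the interface main.py relies on."""

        def __init__(self, config: Dict[str, Any]):
            self.config = config  # main.py reads config['model_id'] if present

        def load_model(self):
            pass  # a real backend would load weights onto the GPU here

        def inference(self, params: Dict[str, Any]) -> Dict[str, Any]:
            return {"echo": params}  # result is reported back via complete_job

        def unload_model(self):
            pass  # called from Worker.shutdown()

    ENGINE_REGISTRY["echo"] = EchoEngine  # hypothetical engine type name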