gpu-worker 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
package/cli.py ADDED
@@ -0,0 +1,729 @@
1
#!/usr/bin/env python3
"""
GPU Worker CLI installer and configuration wizard.

Provides an interactive installation and configuration experience.
"""
import os
import sys
import argparse
import subprocess
import platform
import shutil
from pathlib import Path
from typing import Optional, Dict, Any, List
import json
import time

# Try to import the rich library (used for pretty terminal output).
# rich is optional: when missing, the CLI degrades to plain print()-based
# output via SimpleConsole below, gated by RICH_AVAILABLE.
try:
    from rich.console import Console
    from rich.panel import Panel
    from rich.prompt import Prompt, Confirm, IntPrompt, FloatPrompt
    from rich.table import Table
    from rich.progress import Progress, SpinnerColumn, TextColumn
    from rich import print as rprint
    RICH_AVAILABLE = True
except ImportError:
    RICH_AVAILABLE = False
28
+
29
# Plain-console fallback used when rich is not installed.
class SimpleConsole:
    """Minimal stand-in for rich.Console: plain print with the same call shape."""

    def print(self, *args, **kwargs):
        """Print positional args, silently dropping the rich-only 'style' kwarg."""
        kwargs.pop('style', None)
        print(*args, **kwargs)

    def rule(self, title=""):
        """Draw a 50-character horizontal rule, optionally with a title line."""
        bar = '=' * 50
        print(f"\n{bar}")
        if title:
            print(f" {title}")
        print(bar)
41
+
42
+ console = Console() if RICH_AVAILABLE else SimpleConsole()
43
+
44
+
45
# ==================== Constants ====================

# Region code -> human-readable description (shown in the region picker).
REGIONS = {
    "asia-east": "东亚(中国、日本、韩国)",
    "asia-south": "东南亚(新加坡、泰国)",
    "europe-west": "西欧(德国、法国、英国)",
    "europe-east": "东欧",
    "america-north": "北美(美国、加拿大)",
    "america-south": "南美",
    "oceania": "大洋洲(澳大利亚)"
}

# Task type code -> human-readable description (shown in the task-type picker).
TASK_TYPES = {
    "llm": "大语言模型推理 (LLM)",
    "image_gen": "图像生成 (Stable Diffusion/FLUX)",
    "whisper": "语音识别 (Whisper)",
    "embedding": "文本嵌入 (Embedding)"
}

# Default coordination-server URL offered by the configuration wizard.
DEFAULT_SERVER_URL = "https://gpu-inference.example.com"

# Path of the YAML config file read/written by save_config()/load_config().
CONFIG_FILE = "config.yaml"
67
+
68
+
69
# ==================== Utility functions ====================

def clear_screen():
    """Clear the terminal window ('cls' on Windows, 'clear' elsewhere)."""
    command = 'cls' if platform.system() == 'Windows' else 'clear'
    os.system(command)
74
+
75
+
76
def check_gpu():
    """Probe for a CUDA GPU via torch.

    Returns a dict with keys: ``available`` (bool), ``count`` (int),
    ``model`` (device name of GPU 0), and ``memory_gb`` (total memory of
    GPU 0, rounded to one decimal). Falls back to the "no GPU" defaults
    when torch is not installed or CUDA is unavailable.
    """
    info = {
        "available": False,
        "count": 0,
        "model": "Unknown",
        "memory_gb": 0
    }

    try:
        import torch
    except ImportError:
        # torch not installed: report the defaults rather than fail.
        return info

    if torch.cuda.is_available():
        props = torch.cuda.get_device_properties(0)
        info.update(
            available=True,
            count=torch.cuda.device_count(),
            model=torch.cuda.get_device_name(0),
            memory_gb=round(props.total_memory / 1024**3, 1),
        )

    return info
97
+
98
+
99
def check_dependencies() -> Dict[str, bool]:
    """Report which runtime requirements are satisfied.

    Returns a dict with keys ``python`` (interpreter is 3.9+), ``torch``
    and ``transformers`` (importable), and ``cuda`` (torch reports an
    available CUDA device).
    """
    deps = {
        "python": sys.version_info >= (3, 9),
        "torch": False,
        "transformers": False,
        "cuda": False
    }

    try:
        import torch
    except ImportError:
        torch = None

    if torch is not None:
        deps["torch"] = True
        deps["cuda"] = torch.cuda.is_available()

    try:
        import transformers  # noqa: F401 -- imported only to test availability
    except ImportError:
        pass
    else:
        deps["transformers"] = True

    return deps
122
+
123
+
124
def install_dependencies(progress_callback=None):
    """Install the worker's Python requirements via pip, one at a time.

    Args:
        progress_callback: optional callable taking a status string;
            invoked before each package install (used for progress display).

    Raises:
        RuntimeError: if pip exits non-zero for any requirement.
            (Previously a bare ``Exception`` was raised; RuntimeError is
            more idiomatic and remains compatible with callers that catch
            ``Exception``, such as cmd_install.)
    """
    requirements = [
        "torch>=2.0.0",
        "transformers>=4.35.0",
        "diffusers>=0.24.0",
        "accelerate>=0.24.0",
        "peft>=0.6.0",
        "bitsandbytes>=0.41.0",
        "httpx>=0.25.0",
        "pyyaml>=6.0",
        "pydantic>=2.0.0",
        "fastapi>=0.100.0",
        "uvicorn>=0.23.0"
    ]

    # One pip invocation per requirement so progress can be reported per package.
    pip_base = [sys.executable, "-m", "pip", "install"]
    for req in requirements:
        if progress_callback:
            progress_callback(f"Installing {req}...")

        result = subprocess.run(
            pip_base + [req],
            capture_output=True,
            text=True
        )

        if result.returncode != 0:
            raise RuntimeError(f"Failed to install {req}: {result.stderr}")
152
+
153
+
154
def save_config(config: Dict[str, Any], path: str = CONFIG_FILE):
    """Serialize *config* to YAML and write it to *path* (UTF-8, unicode kept)."""
    import yaml
    serialized = yaml.dump(config, default_flow_style=False, allow_unicode=True)
    with open(path, 'w', encoding='utf-8') as handle:
        handle.write(serialized)
159
+
160
+
161
def load_config(path: str = CONFIG_FILE) -> Optional[Dict[str, Any]]:
    """Read the YAML config at *path*; return None when the file is absent."""
    if not Path(path).exists():
        return None

    import yaml
    with open(path, encoding='utf-8') as handle:
        return yaml.safe_load(handle)
169
+
170
+
171
# ==================== Interactive configuration wizard ====================

class ConfigWizard:
    """Interactive configuration wizard.

    Walks the user through six numbered steps (server, region, GPU,
    task types, load control, direct connection) plus a final
    confirmation, accumulating answers into ``self.config`` and writing
    them to CONFIG_FILE on confirm. Each step has a rich-based UI and a
    plain input()-based fallback, selected by RICH_AVAILABLE.
    """

    def __init__(self):
        # Accumulated configuration; filled in step by step by run().
        self.config = {}

    def run(self) -> Dict[str, Any]:
        """Run the wizard end to end and return the collected config dict."""
        clear_screen()
        self._show_welcome()

        # Step 1: server settings
        self._configure_server()

        # Step 2: region selection
        self._configure_region()

        # Step 3: GPU detection
        self._configure_gpu()

        # Step 4: task types
        self._configure_task_types()

        # Step 5: load control
        self._configure_load_control()

        # Step 6: direct connection
        self._configure_direct_connection()

        # Step 7: confirm and save (exits the process if declined)
        self._confirm_config()

        return self.config

    def _show_welcome(self):
        """Show the welcome banner and block until the user presses Enter."""
        if RICH_AVAILABLE:
            console.print(Panel.fit(
                "[bold cyan]分布式GPU推理 Worker 配置向导[/bold cyan]\n\n"
                "本向导将帮助您配置 GPU Worker 节点\n"
                "您可以随时按 Ctrl+C 退出",
                title="欢迎",
                border_style="cyan"
            ))
        else:
            console.rule("分布式GPU推理 Worker 配置向导")
            print("\n本向导将帮助您配置 GPU Worker 节点")
            print("您可以随时按 Ctrl+C 退出\n")

        input("按 Enter 键继续...")

    def _configure_server(self):
        """Step 1: ask for the server URL and HTTPS preference."""
        console.rule("步骤 1/6: 服务器配置")

        if RICH_AVAILABLE:
            server_url = Prompt.ask(
                "请输入服务器地址",
                default=DEFAULT_SERVER_URL
            )

            use_https = Confirm.ask(
                "是否使用 HTTPS(推荐)",
                default=True
            )
        else:
            server_url = input(f"请输入服务器地址 [{DEFAULT_SERVER_URL}]: ").strip()
            if not server_url:
                server_url = DEFAULT_SERVER_URL

            # Anything other than an explicit 'n' counts as yes.
            use_https_input = input("是否使用 HTTPS(推荐)[Y/n]: ").strip().lower()
            use_https = use_https_input != 'n'

        # Ensure the URL carries an explicit scheme matching the HTTPS choice.
        if not server_url.startswith(('http://', 'https://')):
            server_url = ('https://' if use_https else 'http://') + server_url

        self.config['server'] = {
            'url': server_url,
            'timeout': 30,
            'verify_ssl': use_https
        }

        print(f"\n服务器地址: {server_url}")

    def _configure_region(self):
        """Step 2: pick a region from REGIONS (out-of-range -> asia-east)."""
        console.rule("步骤 2/6: 区域选择")

        print("\n可用区域:")
        if RICH_AVAILABLE:
            table = Table(show_header=True)
            table.add_column("序号", style="cyan")
            table.add_column("区域代码", style="green")
            table.add_column("描述")

            for i, (code, desc) in enumerate(REGIONS.items(), 1):
                table.add_row(str(i), code, desc)

            console.print(table)

            choice = IntPrompt.ask(
                "请选择您的区域",
                default=1,
                show_default=True
            )
        else:
            for i, (code, desc) in enumerate(REGIONS.items(), 1):
                print(f" {i}. {code} - {desc}")

            choice = int(input("\n请选择您的区域 [1]: ") or "1")

        # Map the 1-based menu index back to a region code.
        region_codes = list(REGIONS.keys())
        if 1 <= choice <= len(region_codes):
            selected_region = region_codes[choice - 1]
        else:
            # Out-of-range answers silently fall back to the first region.
            selected_region = "asia-east"

        self.config['region'] = selected_region
        print(f"\n已选择区域: {selected_region} ({REGIONS[selected_region]})")

    def _configure_gpu(self):
        """Step 3: detect the GPU and record model/memory/count.

        CPU offload is enabled for GPUs with less than 16 GB of memory,
        and always in CPU-only mode.
        """
        console.rule("步骤 3/6: GPU 检测")

        print("\n正在检测 GPU...")
        gpu_info = check_gpu()

        if gpu_info["available"]:
            if RICH_AVAILABLE:
                console.print(f"[green]检测到 GPU:[/green]")
                console.print(f" 型号: {gpu_info['model']}")
                console.print(f" 显存: {gpu_info['memory_gb']} GB")
                console.print(f" 数量: {gpu_info['count']}")
            else:
                print(f"检测到 GPU:")
                print(f" 型号: {gpu_info['model']}")
                print(f" 显存: {gpu_info['memory_gb']} GB")
                print(f" 数量: {gpu_info['count']}")

            self.config['gpu'] = {
                'model': gpu_info['model'],
                'memory_gb': gpu_info['memory_gb'],
                'count': gpu_info['count'],
                'enable_cpu_offload': gpu_info['memory_gb'] < 16
            }
        else:
            if RICH_AVAILABLE:
                console.print("[yellow]未检测到 GPU,将使用 CPU 模式(性能有限)[/yellow]")
            else:
                print("未检测到 GPU,将使用 CPU 模式(性能有限)")

            self.config['gpu'] = {
                'model': 'CPU',
                'memory_gb': 0,
                'count': 0,
                'enable_cpu_offload': True
            }

    def _configure_task_types(self):
        """Step 4: choose supported task types via comma-separated indices."""
        console.rule("步骤 4/6: 任务类型")

        print("\n可用任务类型:")
        if RICH_AVAILABLE:
            table = Table(show_header=True)
            table.add_column("序号", style="cyan")
            table.add_column("类型", style="green")
            table.add_column("描述")
            table.add_column("推荐显存")

            # Recommended GPU memory per task type (display only).
            requirements = {
                "llm": "8GB+",
                "image_gen": "12GB+",
                "whisper": "4GB+",
                "embedding": "4GB+"
            }

            for i, (code, desc) in enumerate(TASK_TYPES.items(), 1):
                table.add_row(str(i), code, desc, requirements.get(code, "N/A"))

            console.print(table)
        else:
            requirements = {
                "llm": "8GB+",
                "image_gen": "12GB+",
                "whisper": "4GB+",
                "embedding": "4GB+"
            }
            for i, (code, desc) in enumerate(TASK_TYPES.items(), 1):
                print(f" {i}. {code} - {desc} (推荐: {requirements.get(code, 'N/A')})")

        print("\n请输入要支持的任务类型序号(用逗号分隔,如: 1,2)")

        if RICH_AVAILABLE:
            choices_str = Prompt.ask("选择", default="1")
        else:
            choices_str = input("选择 [1]: ") or "1"

        type_codes = list(TASK_TYPES.keys())
        selected_types = []

        # Parse "1,2"-style input; non-numeric or out-of-range entries are skipped.
        for choice in choices_str.split(','):
            try:
                idx = int(choice.strip()) - 1
                if 0 <= idx < len(type_codes):
                    selected_types.append(type_codes[idx])
            except ValueError:
                continue

        if not selected_types:
            # Nothing valid was selected: default to LLM inference.
            selected_types = ["llm"]

        self.config['supported_types'] = selected_types
        print(f"\n已选择任务类型: {', '.join(selected_types)}")

    def _configure_load_control(self):
        """Step 5: acceptance rate, hourly job cap, optional working hours."""
        console.rule("步骤 5/6: 负载控制")

        print("\n配置 Worker 的负载控制参数")
        print("这些设置决定了您的 GPU 如何参与任务处理\n")

        if RICH_AVAILABLE:
            acceptance_rate = FloatPrompt.ask(
                "任务接受率 (0.1-1.0, 1.0=接受全部任务)",
                default=1.0
            )

            max_jobs_per_hour = IntPrompt.ask(
                "每小时最大任务数 (0=不限制)",
                default=0
            )

            # Working-hours window (only accept tasks during certain hours).
            set_working_hours = Confirm.ask(
                "是否设置工作时间段(只在特定时间接受任务)",
                default=False
            )
        else:
            acceptance_rate = float(input("任务接受率 (0.1-1.0) [1.0]: ") or "1.0")
            max_jobs_per_hour = int(input("每小时最大任务数 (0=不限制) [0]: ") or "0")
            set_working_hours = input("是否设置工作时间段 [y/N]: ").lower() == 'y'

        self.config['load_control'] = {
            # Clamp into [0.1, 1.0] regardless of what was typed.
            'acceptance_rate': min(1.0, max(0.1, acceptance_rate)),
            'max_jobs_per_hour': max(0, max_jobs_per_hour),
            'max_concurrent_jobs': 1,
            'cooldown_seconds': 0
        }

        if set_working_hours:
            if RICH_AVAILABLE:
                start_hour = IntPrompt.ask("开始时间(24小时制,如 9 表示 9:00)", default=9)
                end_hour = IntPrompt.ask("结束时间(24小时制,如 22 表示 22:00)", default=22)
            else:
                start_hour = int(input("开始时间(24小时制)[9]: ") or "9")
                end_hour = int(input("结束时间(24小时制)[22]: ") or "22")

            self.config['load_control']['working_hours_start'] = start_hour
            self.config['load_control']['working_hours_end'] = end_hour

    def _configure_direct_connection(self):
        """Step 6: optional direct-connection (client -> worker) settings."""
        console.rule("步骤 6/6: 直连配置")

        print("\n直连模式允许客户端直接与您的 Worker 通信")
        print("这可以降低延迟,但需要您的机器有公网可访问的地址\n")

        if RICH_AVAILABLE:
            enable_direct = Confirm.ask(
                "是否启用直连模式",
                default=False
            )
        else:
            enable_direct = input("是否启用直连模式 [y/N]: ").lower() == 'y'

        # Defaults; port/public_url are refined below when direct mode is on.
        self.config['direct'] = {
            'enabled': enable_direct,
            'host': '0.0.0.0',
            'port': 8080,
            'public_url': None
        }

        if enable_direct:
            if RICH_AVAILABLE:
                port = IntPrompt.ask("直连端口", default=8080)
                public_url = Prompt.ask(
                    "公网访问地址(如 http://your-ip:8080,留空自动检测)",
                    default=""
                )
            else:
                port = int(input("直连端口 [8080]: ") or "8080")
                public_url = input("公网访问地址(留空自动检测): ").strip()

            self.config['direct']['port'] = port
            if public_url:
                self.config['direct']['public_url'] = public_url

    def _confirm_config(self):
        """Summarize the collected config; save on confirm, exit(0) otherwise."""
        console.rule("配置确认")

        print("\n您的配置如下:\n")

        if RICH_AVAILABLE:
            table = Table(show_header=True)
            table.add_column("配置项", style="cyan")
            table.add_column("值", style="green")

            table.add_row("服务器", self.config['server']['url'])
            table.add_row("区域", f"{self.config['region']} ({REGIONS[self.config['region']]})")
            table.add_row("GPU", self.config['gpu']['model'])
            table.add_row("任务类型", ", ".join(self.config['supported_types']))
            table.add_row("任务接受率", f"{self.config['load_control']['acceptance_rate']*100:.0f}%")
            table.add_row("直连模式", "启用" if self.config['direct']['enabled'] else "禁用")

            console.print(table)

            confirm = Confirm.ask("\n确认保存配置", default=True)
        else:
            print(f" 服务器: {self.config['server']['url']}")
            print(f" 区域: {self.config['region']}")
            print(f" GPU: {self.config['gpu']['model']}")
            print(f" 任务类型: {', '.join(self.config['supported_types'])}")
            print(f" 任务接受率: {self.config['load_control']['acceptance_rate']*100:.0f}%")
            print(f" 直连模式: {'启用' if self.config['direct']['enabled'] else '禁用'}")

            confirm = input("\n确认保存配置 [Y/n]: ").lower() != 'n'

        if confirm:
            save_config(self.config)
            print(f"\n配置已保存到 {CONFIG_FILE}")
        else:
            # Declining aborts the whole CLI run without writing anything.
            print("\n配置已取消")
            sys.exit(0)
509
+
510
+
511
# ==================== CLI commands ====================

def cmd_install(args):
    """Check for missing dependencies and interactively install them."""
    console.rule("安装依赖")

    print("\n正在检查依赖...")
    status = check_dependencies()

    # CUDA is informational only -- it cannot be pip-installed.
    missing = [name for name, ok in status.items() if not ok and name != 'cuda']

    if not missing:
        print("所有依赖已安装!")
        return

    print(f"缺少依赖: {', '.join(missing)}")

    if RICH_AVAILABLE:
        proceed = Confirm.ask("是否安装缺少的依赖")
    else:
        proceed = input("是否安装缺少的依赖 [Y/n]: ").lower() != 'n'

    if not proceed:
        return

    print("\n开始安装...")

    try:
        if RICH_AVAILABLE:
            # Spinner with a live description updated per package.
            with Progress(
                SpinnerColumn(),
                TextColumn("[progress.description]{task.description}"),
                console=console
            ) as progress:
                task_id = progress.add_task("Installing...", total=None)
                install_dependencies(
                    lambda msg: progress.update(task_id, description=msg)
                )
        else:
            install_dependencies(lambda msg: print(f" {msg}"))
    except Exception as e:
        print(f"\n安装失败: {e}")
        sys.exit(1)

    print("\n依赖安装完成!")
558
+
559
+
560
def cmd_configure(args):
    """Run the interactive configuration wizard."""
    ConfigWizard().run()
564
+
565
+
566
def cmd_start(args):
    """Load the saved config and run the Worker until interrupted."""
    config = load_config()
    if not config:
        print("未找到配置文件,请先运行 'gpu-worker configure'")
        sys.exit(1)

    console.rule("启动 Worker")

    print(f"\n服务器: {config['server']['url']}")
    print(f"区域: {config['region']}")
    print(f"任务类型: {', '.join(config['supported_types'])}")
    print()

    # Import lazily: these modules pull in the heavy ML stack.
    try:
        from main import Worker
        from config import WorkerConfig

        # Convert the YAML dict into the worker's typed config object.
        worker_config = WorkerConfig(**config)
        Worker(worker_config).start()
    except ImportError as exc:
        print(f"导入失败: {exc}")
        print("请确保已安装所有依赖: gpu-worker install")
        sys.exit(1)
    except KeyboardInterrupt:
        print("\nWorker 已停止")
597
+
598
+
599
def cmd_status(args):
    """Show dependency, GPU, and configuration status."""
    config = load_config()
    if not config:
        print("未找到配置文件")
        return

    console.rule("Worker 状态")

    # Dependency check, rendered either as a rich table or plain lines.
    deps = check_dependencies()
    rows = [
        ("Python", "OK" if deps["python"] else "需要 3.9+"),
        ("PyTorch", "OK" if deps["torch"] else "未安装"),
        ("CUDA", "可用" if deps["cuda"] else "不可用"),
        ("Transformers", "OK" if deps["transformers"] else "未安装"),
    ]

    if RICH_AVAILABLE:
        table = Table(title="系统状态")
        table.add_column("项目", style="cyan")
        table.add_column("状态", style="green")
        for item, state in rows:
            table.add_row(item, state)
        console.print(table)
    else:
        print()
        for item, state in rows:
            print(f"{item}: {state}")

    # GPU details.
    gpu_info = check_gpu()
    gpu_label = gpu_info['model'] if gpu_info['available'] else '未检测到'
    print(f"\nGPU: {gpu_label}")
    if gpu_info['available']:
        print(f"显存: {gpu_info['memory_gb']} GB")

    # Saved configuration summary.
    print(f"\n配置:")
    print(f" 服务器: {config.get('server', {}).get('url', '未配置')}")
    print(f" 区域: {config.get('region', '未配置')}")
    print(f" 任务类型: {', '.join(config.get('supported_types', []))}")
640
+
641
+
642
+ def cmd_config_set(args):
643
+ """设置单个配置项"""
644
+ config = load_config() or {}
645
+
646
+ key = args.key
647
+ value = args.value
648
+
649
+ # 解析键路径 (如 load_control.acceptance_rate)
650
+ keys = key.split('.')
651
+ current = config
652
+
653
+ for k in keys[:-1]:
654
+ if k not in current:
655
+ current[k] = {}
656
+ current = current[k]
657
+
658
+ # 尝试解析值类型
659
+ try:
660
+ if value.lower() == 'true':
661
+ value = True
662
+ elif value.lower() == 'false':
663
+ value = False
664
+ elif '.' in value:
665
+ value = float(value)
666
+ else:
667
+ value = int(value)
668
+ except ValueError:
669
+ pass # 保持字符串
670
+
671
+ current[keys[-1]] = value
672
+ save_config(config)
673
+
674
+ print(f"已设置 {key} = {value}")
675
+
676
+
677
# ==================== Main entry point ====================

def main():
    """Parse CLI arguments and dispatch to the matching sub-command."""
    parser = argparse.ArgumentParser(
        description="分布式GPU推理 Worker CLI",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
示例:
gpu-worker install 安装依赖
gpu-worker configure 交互式配置
gpu-worker start 启动 Worker
gpu-worker status 查看状态
gpu-worker set load_control.acceptance_rate 0.5
"""
    )

    subparsers = parser.add_subparsers(dest='command', help='可用命令')

    # Commands without extra arguments.
    subparsers.add_parser('install', help='安装依赖')
    subparsers.add_parser('configure', help='交互式配置向导')
    subparsers.add_parser('start', help='启动 Worker')
    subparsers.add_parser('status', help='查看状态')

    # 'set' takes a dotted key path and a value.
    set_parser = subparsers.add_parser('set', help='设置配置项')
    set_parser.add_argument('key', help='配置键 (如 load_control.acceptance_rate)')
    set_parser.add_argument('value', help='配置值')

    args = parser.parse_args()

    # Dispatch table instead of an if/elif chain; no command -> help text.
    handlers = {
        'install': cmd_install,
        'configure': cmd_configure,
        'start': cmd_start,
        'status': cmd_status,
        'set': cmd_config_set,
    }
    handler = handlers.get(args.command)
    if handler is None:
        parser.print_help()
    else:
        handler(args)


if __name__ == '__main__':
    main()