gpu-worker 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,32 @@
+ # 2GB VRAM configuration
+ name: "Worker-2GB"
+ region: "asia-east"
+ country: "China"
+ city: "Shanghai"
+
+ server:
+   url: "http://<server-A-IP>:8880"  # change to server A's actual IP
+   timeout: 30
+   verify_ssl: false
+
+ gpu:
+   device_id: 0
+   enable_cpu_offload: true
+
+ # 2GB of VRAM is only enough for small models
+ supported_types:
+   - whisper
+   # - embedding
+
+ engines:
+   whisper:
+     model_id: "openai/whisper-small"  # ~2GB VRAM
+   # embedding:
+   #   model_id: "BAAI/bge-small-zh-v1.5"  # ~0.5GB
+
+ heartbeat_interval: 30
+ poll_interval: 2
+
+ load_control:
+   acceptance_rate: 1.0
+   max_concurrent_jobs: 1
@@ -0,0 +1,29 @@
+ # 8GB VRAM configuration
+ name: "Worker-8GB"
+ region: "asia-east"
+ country: "China"
+ city: "Shanghai"
+
+ server:
+   url: "http://127.0.0.1:8880"
+   timeout: 30
+   verify_ssl: false
+
+ gpu:
+   device_id: 0
+   enable_cpu_offload: true  # spill to CPU RAM when VRAM runs short
+
+ # With 8GB of VRAM, running a single task type is recommended
+ supported_types:
+   - llm
+
+ engines:
+   llm:
+     model_id: "Qwen/Qwen2.5-1.5B-Instruct"  # ~3GB VRAM
+
+ heartbeat_interval: 30
+ poll_interval: 2
+
+ load_control:
+   acceptance_rate: 1.0
+   max_concurrent_jobs: 1
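Both per-card configs close with a load_control block. The worker loop that consumes these knobs is not part of this diff; the sketch below is one plausible reading of how acceptance_rate, max_concurrent_jobs, and the max_jobs_per_hour / working-hours fields declared in package/config.py (further down) could gate incoming jobs. The should_accept_job helper and its inputs are hypothetical, not the package's actual API; lc stands for a LoadControlConfig instance.

import random
from datetime import datetime

def should_accept_job(lc, active_jobs: int, jobs_this_hour: int) -> bool:
    """Hypothetical gate combining the load_control knobs."""
    # Hard cap on parallelism.
    if active_jobs >= lc.max_concurrent_jobs:
        return False
    # Hourly budget; 0 means unlimited, matching the GPU_MAX_JOBS_PER_HOUR default.
    if lc.max_jobs_per_hour and jobs_this_hour >= lc.max_jobs_per_hour:
        return False
    # Optional working-hours window (applies only when both bounds are set).
    if lc.working_hours_start is not None and lc.working_hours_end is not None:
        if not (lc.working_hours_start <= datetime.now().hour < lc.working_hours_end):
            return False
    # Probabilistic throttle: acceptance_rate 1.0 accepts every job.
    return random.random() < lc.acceptance_rate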
@@ -0,0 +1,72 @@
+ # Worker configuration example - enhanced version
+
+ # Worker identity (auto-filled after the first run)
+ worker_id: null
+ token: null
+ name: "My GPU Worker"
+
+ # Geographic info (important: affects job assignment)
+ region: "asia-east"  # region code, see list below
+ country: "China"
+ city: "Shanghai"
+ timezone: "Asia/Shanghai"
+
+ # Region codes:
+ # - asia-east: East Asia (China, Japan, Korea)
+ # - asia-south: Southeast Asia (Singapore, Thailand)
+ # - europe-west: Western Europe (Germany, France, UK)
+ # - europe-east: Eastern Europe
+ # - america-north: North America (US, Canada)
+ # - america-south: South America
+ # - oceania: Oceania (Australia)
+
+ # Server settings
+ server:
+   url: "http://localhost:8000"  # central server address
+   timeout: 30
+   verify_ssl: true
+
+ # GPU settings
+ gpu:
+   enable_cpu_offload: true  # enable CPU offload to save VRAM
+   max_memory_gb: null       # cap VRAM usage (null = no limit)
+   device_id: 0              # GPU device ID to use
+
+ # Direct-connection settings (optional, for low-latency P2P scenarios)
+ direct:
+   enabled: false            # whether to run the direct-connection service
+   host: "0.0.0.0"
+   port: 8080
+   public_url: null          # publicly reachable URL, e.g. "http://your-ip:8080"
+
+ # Supported task types
+ supported_types:
+   - "llm"
+   - "image_gen"
+   # - "whisper"
+   # - "embedding"
+
+ # Engine settings
+ engines:
+   llm:
+     model_id: "Qwen/Qwen2.5-7B-Instruct"
+     # Optional settings:
+     # max_new_tokens: 2048
+     # temperature: 0.7
+
+   image_gen:
+     model_id: "black-forest-labs/FLUX.1-schnell"
+     # Optional settings:
+     # default_steps: 4
+     # default_width: 1024
+     # default_height: 1024
+
+   # whisper:
+   #   model_id: "openai/whisper-large-v3"
+
+   # embedding:
+   #   model_id: "BAAI/bge-large-zh-v1.5"
+
+ # Polling settings
+ heartbeat_interval: 30  # heartbeat interval (seconds)
+ poll_interval: 2        # job polling interval (seconds)
package/config.py ADDED
@@ -0,0 +1,213 @@
+ """
+ Worker configuration - supports environment variables and YAML config.
+ Precedence: config.yaml > environment variables > built-in defaults
+ """
+ from pydantic import BaseModel, Field
+ from typing import List, Optional, Dict, Any
+ import yaml
+ from pathlib import Path
+ import os
+
+
+ def get_env(key: str, default: Any = None, cast: type = str) -> Any:
+     """Read an environment variable and cast it to the requested type."""
+     value = os.getenv(key)
+     if value is None:
+         return default
+
+     if cast == bool:
+         return value.lower() in ('true', '1', 'yes', 'on')
+     elif cast == list:
+         return [x.strip() for x in value.split(',') if x.strip()]
+
+     try:
+         return cast(value)
+     except (ValueError, TypeError):
+         return default
+
+
+ class ServerConfig(BaseModel):
+     """Central-server settings"""
+     url: str = Field(default_factory=lambda: get_env('GPU_SERVER_URL', 'http://localhost:8000'))
+     timeout: int = Field(default_factory=lambda: get_env('GPU_SERVER_TIMEOUT', 30, int))
+     verify_ssl: bool = Field(default_factory=lambda: get_env('GPU_SERVER_VERIFY_SSL', True, bool))
+
+
+ class GPUConfig(BaseModel):
+     """GPU settings"""
+     enable_cpu_offload: bool = Field(default_factory=lambda: get_env('GPU_ENABLE_CPU_OFFLOAD', True, bool))
+     max_memory_gb: Optional[float] = Field(default_factory=lambda: get_env('GPU_MAX_MEMORY_GB', None, float))
+     device_id: int = Field(default_factory=lambda: get_env('GPU_DEVICE_ID', 0, int))
+
+
+ class DirectConfig(BaseModel):
+     """Direct-connection settings"""
+     enabled: bool = Field(default_factory=lambda: get_env('GPU_DIRECT_ENABLED', False, bool))
+     host: str = Field(default_factory=lambda: get_env('GPU_DIRECT_HOST', '0.0.0.0'))
+     port: int = Field(default_factory=lambda: get_env('GPU_DIRECT_PORT', 8080, int))
+     public_url: Optional[str] = Field(default_factory=lambda: get_env('GPU_DIRECT_PUBLIC_URL', None))
+
+
+ class LoadControlConfig(BaseModel):
+     """Load-control settings"""
+     acceptance_rate: float = Field(default_factory=lambda: get_env('GPU_ACCEPTANCE_RATE', 1.0, float))
+     max_concurrent_jobs: int = Field(default_factory=lambda: get_env('GPU_MAX_CONCURRENT_JOBS', 1, int))
+     max_jobs_per_hour: int = Field(default_factory=lambda: get_env('GPU_MAX_JOBS_PER_HOUR', 0, int))
+     working_hours_start: Optional[int] = Field(default_factory=lambda: get_env('GPU_WORKING_HOURS_START', None, int))
+     working_hours_end: Optional[int] = Field(default_factory=lambda: get_env('GPU_WORKING_HOURS_END', None, int))
+
+
+ class WorkerConfig(BaseModel):
+     """Top-level worker configuration"""
+
+     # Worker identity (auto-filled after the first run)
+     worker_id: Optional[str] = Field(default_factory=lambda: get_env('GPU_WORKER_ID', None))
+     token: Optional[str] = Field(default_factory=lambda: get_env('GPU_WORKER_TOKEN', None))
+     name: Optional[str] = Field(default_factory=lambda: get_env('GPU_WORKER_NAME', None))
+
+     # Geographic info
+     region: str = Field(default_factory=lambda: get_env('GPU_REGION', 'asia-east'))
+     country: Optional[str] = Field(default_factory=lambda: get_env('GPU_COUNTRY', None))
+     city: Optional[str] = Field(default_factory=lambda: get_env('GPU_CITY', None))
+     timezone: Optional[str] = Field(default_factory=lambda: get_env('GPU_TIMEZONE', None))
+
+     # Server settings
+     server: ServerConfig = Field(default_factory=ServerConfig)
+
+     # GPU settings
+     gpu: GPUConfig = Field(default_factory=GPUConfig)
+
+     # Direct-connection settings
+     direct: DirectConfig = Field(default_factory=DirectConfig)
+
+     # Load control
+     load_control: LoadControlConfig = Field(default_factory=LoadControlConfig)
+
+     # Supported task types
+     supported_types: List[str] = Field(
+         default_factory=lambda: get_env('GPU_SUPPORTED_TYPES', ['llm', 'image_gen'], list)
+     )
+
+     # Engine settings
+     engines: Dict[str, Dict[str, Any]] = Field(default_factory=dict)
+
+     # Polling settings
+     heartbeat_interval: int = Field(default_factory=lambda: get_env('GPU_HEARTBEAT_INTERVAL', 30, int))
+     poll_interval: int = Field(default_factory=lambda: get_env('GPU_POLL_INTERVAL', 2, int))
+
+     def save(self, path: str = "config.yaml"):
+         """Save the configuration to a YAML file."""
+         data = self.model_dump()
+         with open(path, "w", encoding="utf-8") as f:
+             yaml.dump(data, f, default_flow_style=False, allow_unicode=True)
+
+     @classmethod
+     def from_env(cls) -> 'WorkerConfig':
+         """Build a configuration from environment variables only."""
+         return cls()
+
+
+ def load_dotenv(path: str = ".env"):
+     """Load a .env file into os.environ."""
+     env_path = Path(path)
+     if not env_path.exists():
+         return
+
+     with open(env_path, encoding='utf-8') as f:
+         for line in f:
+             line = line.strip()
+             # Skip blank lines and comments
+             if not line or line.startswith('#'):
+                 continue
+
+             # Parse KEY=VALUE
+             if '=' in line:
+                 key, _, value = line.partition('=')
+                 key = key.strip()
+                 value = value.strip()
+
+                 # Strip matching surrounding quotes
+                 if value and value[0] in ('"', "'") and value[-1] == value[0]:
+                     value = value[1:-1]
+
+                 # Only set keys not already defined, so the real environment wins over .env
+                 if key and key not in os.environ:
+                     os.environ[key] = value
+
+
+ def load_config(path: str = "config.yaml") -> WorkerConfig:
+     """
+     Load the configuration.
+     Precedence: config.yaml > environment variables > built-in defaults
+     (env vars supply the field defaults, so YAML keys override them).
+     """
+     # Load the .env file first
+     load_dotenv()
+
+     config_path = Path(path)
+
+     if config_path.exists():
+         with open(config_path, encoding="utf-8") as f:
+             data = yaml.safe_load(f) or {}
+
+         # Build the nested config models
+         if "server" in data and isinstance(data["server"], dict):
+             data["server"] = ServerConfig(**data["server"])
+         if "gpu" in data and isinstance(data["gpu"], dict):
+             data["gpu"] = GPUConfig(**data["gpu"])
+         if "direct" in data and isinstance(data["direct"], dict):
+             data["direct"] = DirectConfig(**data["direct"])
+         if "load_control" in data and isinstance(data["load_control"], dict):
+             data["load_control"] = LoadControlConfig(**data["load_control"])
+
+         config = WorkerConfig(**data)
+     else:
+         # Build the configuration from environment variables only
+         config = WorkerConfig()
+
+     # Apply GPU_*_MODEL engine overrides from the environment (these win over YAML)
+     _load_engine_configs_from_env(config)
+
+     return config
+
+
+ def _load_engine_configs_from_env(config: WorkerConfig):
+     """Apply engine model overrides from environment variables."""
+     env_models = {
+         'llm': get_env('GPU_LLM_MODEL'),
+         'image_gen': get_env('GPU_IMAGE_MODEL'),
+         'vision': get_env('GPU_VISION_MODEL'),
+         'whisper': get_env('GPU_WHISPER_MODEL'),
+         'embedding': get_env('GPU_EMBEDDING_MODEL'),
+     }
+
+     for engine_type, model_id in env_models.items():
+         if model_id:
+             if engine_type not in config.engines:
+                 config.engines[engine_type] = {}
+             config.engines[engine_type]['model_id'] = model_id
+
+
+ # Default engine configurations
+ DEFAULT_ENGINE_CONFIGS = {
+     "llm": {
+         "model_id": "Qwen/Qwen2.5-7B-Instruct",
+         "max_new_tokens": 2048,
+         "temperature": 0.7,
+     },
+     "image_gen": {
+         "model_id": "Zhihu-ai/Z-Image-Turbo",
+         "default_steps": 4,
+         "default_width": 1024,
+         "default_height": 1024,
+     },
+     "vision": {
+         "model_id": "THUDM/glm-4v-9b",
+         "max_new_tokens": 1024,
+     },
+     "whisper": {
+         "model_id": "openai/whisper-large-v3",
+     },
+     "embedding": {
+         "model_id": "BAAI/bge-large-zh-v1.5",
+     }
+ }
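A quick way to see the precedence rules above in action. A minimal sketch, assuming config.py is importable from the working directory; the file contents, URLs, and model name are illustrative:

import os
from config import load_config

# A throwaway config.yaml that pins the server URL but omits the region.
with open("config.yaml", "w", encoding="utf-8") as f:
    f.write('server:\n  url: "http://10.0.0.5:8000"\n')

# Env vars act as field defaults, so they fill in whatever the YAML omits...
os.environ["GPU_REGION"] = "europe-west"
# ...but lose to keys the YAML does define.
os.environ["GPU_SERVER_URL"] = "http://ignored:9999"
# Engine model overrides are the exception: they are applied after the YAML.
os.environ["GPU_LLM_MODEL"] = "Qwen/Qwen2.5-1.5B-Instruct"

cfg = load_config("config.yaml")
print(cfg.server.url)                  # http://10.0.0.5:8000 (YAML wins)
print(cfg.region)                      # europe-west (env fills the gap)
print(cfg.engines["llm"]["model_id"])  # Qwen/Qwen2.5-1.5B-Instruct (GPU_LLM_MODEL wins)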
@@ -0,0 +1,140 @@
+ """
+ Direct-connection server - lets clients talk to a worker directly,
+ bypassing the central server to reduce latency.
+ """
+ import time
+ from fastapi import FastAPI, HTTPException
+ from pydantic import BaseModel
+ from typing import Optional, Dict, Any
+ import uvicorn
+ import logging
+ from threading import Thread
+
+ logger = logging.getLogger(__name__)
+
+
+ class DirectInferenceRequest(BaseModel):
+     """Direct inference request"""
+     type: str
+     params: Dict[str, Any]
+     timeout_seconds: int = 300
+
+
+ class DirectInferenceResponse(BaseModel):
+     """Direct inference response"""
+     success: bool
+     result: Optional[Dict[str, Any]] = None
+     error: Optional[str] = None
+     processing_time_ms: int = 0
+
+
+ class DirectServer:
+     """
+     Direct-connection server.
+
+     Lets clients bypass the central server and talk to the worker
+     directly; intended for low-latency scenarios.
+     """
+
+     def __init__(self, worker, host: str = "0.0.0.0", port: int = 8080):
+         self.worker = worker
+         self.host = host
+         self.port = port
+         self.app = FastAPI(title="Worker Direct API")
+         self._setup_routes()
+         self.server = None
+
+     def _setup_routes(self):
+         """Register the HTTP routes."""
+
+         @self.app.get("/health")
+         async def health():
+             return {
+                 "status": "healthy",
+                 "worker_id": self.worker.worker_id,
+                 "worker_status": self.worker.status,
+                 "supported_types": list(self.worker.engines.keys())
+             }
+
+         @self.app.get("/status")
+         async def status():
+             gpu_info = self.worker._get_gpu_info()
+             return {
+                 "worker_id": self.worker.worker_id,
+                 "status": self.worker.status,
+                 "current_job": self.worker.current_job_id,
+                 "supported_types": list(self.worker.engines.keys()),
+                 "gpu_info": gpu_info,
+                 "accepting_jobs": self.worker.accepting_jobs
+             }
+
+         @self.app.post("/inference", response_model=DirectInferenceResponse)
+         async def direct_inference(request: DirectInferenceRequest):
+             """
+             Direct inference endpoint.
+
+             Clients can call this directly, skipping the central server.
+             """
+             # Refuse new work while shutting down
+             if not self.worker.accepting_jobs:
+                 raise HTTPException(503, "Worker is going offline")
+
+             # Only accept work when idle
+             if self.worker.status != "idle":
+                 raise HTTPException(503, "Worker is busy")
+
+             # Look up the requested engine
+             engine = self.worker.engines.get(request.type)
+             if not engine:
+                 raise HTTPException(
+                     400,
+                     f"Unsupported type: {request.type}. "
+                     f"Supported: {list(self.worker.engines.keys())}"
+                 )
+
+             # Mark the worker as busy (the check-and-set above is not atomic)
+             self.worker.status = "busy"
+
+             try:
+                 start_time = time.time()
+                 # NOTE: engine.inference() is synchronous and blocks the event loop while it runs
+                 result = engine.inference(request.params)
+                 processing_time_ms = int((time.time() - start_time) * 1000)
+
+                 return DirectInferenceResponse(
+                     success=True,
+                     result=result,
+                     processing_time_ms=processing_time_ms
+                 )
+
+             except Exception as e:
+                 logger.error(f"Direct inference error: {e}")
+                 return DirectInferenceResponse(
+                     success=False,
+                     error=str(e)
+                 )
+
+             finally:
+                 self.worker.status = "idle"
+
+     def start(self):
+         """Start the server (blocking)."""
+         config = uvicorn.Config(
+             self.app,
+             host=self.host,
+             port=self.port,
+             log_level="warning"
+         )
+         self.server = uvicorn.Server(config)
+         self.server.run()
+
+     def start_background(self):
+         """Start the server in a daemon thread."""
+         thread = Thread(target=self.start, daemon=True)
+         thread.start()
+         return thread
+
+     def stop(self):
+         """Stop the server."""
+         if self.server:
+             self.server.should_exit = True
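Calling the direct endpoint is a single POST against a worker whose direct service is enabled (direct.enabled: true). A minimal client sketch using requests; the worker address and the llm params payload are assumptions, since each engine's exact parameter schema is not part of this diff:

import requests

WORKER_URL = "http://127.0.0.1:8080"  # assumed worker address

# Optional: confirm the worker is up and supports the task type we need.
health = requests.get(f"{WORKER_URL}/health", timeout=5).json()
assert "llm" in health["supported_types"]

resp = requests.post(
    f"{WORKER_URL}/inference",
    json={
        "type": "llm",
        # The params keys are engine-specific; these are illustrative.
        "params": {"prompt": "Hello!", "max_new_tokens": 64},
        "timeout_seconds": 120,
    },
    timeout=120,
)
resp.raise_for_status()  # 503 while busy/offline, 400 for unsupported types
data = resp.json()
if data["success"]:
    print(data["result"], f"({data['processing_time_ms']} ms)")
else:
    print("inference failed:", data["error"])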
@@ -0,0 +1,35 @@
+ """Distributed inference components.
+
+ Cross-worker model-shard inference, modeled on the Petals project:
+ - DistributedInferenceSession: manages a distributed inference session
+ - WorkerSession: per-worker session state
+ - ModelShard: model shard loader
+ - GRPCServer: P2P communication between workers
+ """
+ from .session import (
+     DistributedInferenceSession,
+     WorkerSession,
+     SessionManager,
+ )
+ from .model_shard import (
+     ModelShard,
+     ShardedModelLoader,
+     get_layer_range_for_worker,
+ )
+ from .kv_cache import (
+     DistributedKVCacheManager,
+     PagedKVCache,
+     KVCachePool,
+ )
+
+ __all__ = [
+     "DistributedInferenceSession",
+     "WorkerSession",
+     "SessionManager",
+     "ModelShard",
+     "ShardedModelLoader",
+     "get_layer_range_for_worker",
+     "DistributedKVCacheManager",
+     "PagedKVCache",
+     "KVCachePool",
+ ]
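The shard loader exports get_layer_range_for_worker, whose signature is not shown in this diff. As a rough illustration of the underlying idea (dividing a model's transformer layers into contiguous, near-equal blocks across workers), here is a hypothetical stand-in; the name and parameters below are assumptions, not the package's actual API:

def layer_range(worker_index: int, num_workers: int, total_layers: int) -> range:
    """Contiguous block of layers owned by one worker (hypothetical helper)."""
    base, extra = divmod(total_layers, num_workers)
    # The first `extra` workers each take one additional layer.
    start = worker_index * base + min(worker_index, extra)
    end = start + base + (1 if worker_index < extra else 0)
    return range(start, end)

# 32 layers over 3 workers -> range(0, 11), range(11, 22), range(22, 32)
print([layer_range(i, 3, 32) for i in range(3)])

In a Petals-style pipeline, each worker would then load only its block (the ModelShard role), receive hidden states from the previous worker over the P2P channel, and append to its slice of the distributed KV cache.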