gpu-worker 1.0.0

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,436 @@
1
+ """
2
+ 连续批处理器 (Continuous Batcher)
3
+
4
+ 实现动态批处理,将多个请求合并为一个批次执行,
5
+ 提升 GPU 利用率和整体吞吐量。
6
+
7
+ 支持:
8
+ - 动态批处理大小调整
9
+ - 请求优先级队列
10
+ - 超时控制
11
+ - 前缀共享优化
12
+ """
13
+ import asyncio
14
+ import time
15
+ import logging
16
+ from dataclasses import dataclass, field
17
+ from typing import Dict, Any, List, Optional, Callable, Awaitable
18
+ from enum import Enum
19
+ from collections import defaultdict
20
+ import heapq
21
+
22
+ logger = logging.getLogger(__name__)
23
+
24
+
25
class RequestPriority(Enum):
    """Scheduling priority of a request; lower values are served first."""

    HIGH = 0    # latency-sensitive traffic
    NORMAL = 1  # default priority
    LOW = 2     # best-effort / background work
30
+
31
+
32
@dataclass(order=True)
class PendingRequest:
    """A queued inference request.

    Ordered by ``(priority, timestamp)`` so that ``heapq`` pops the
    highest-priority (lowest value), oldest request first; all other
    fields are excluded from comparison.
    """
    priority: int    # RequestPriority.value — lower sorts first
    timestamp: float  # enqueue time from time.time()
    job_id: str = field(compare=False)
    params: Dict[str, Any] = field(compare=False)
    future: asyncio.Future = field(compare=False)
    prefix_hash: str = field(compare=False, default="")

    @classmethod
    def create(
        cls,
        job_id: str,
        params: Dict[str, Any],
        priority: RequestPriority = RequestPriority.NORMAL,
        prefix_hash: str = ""
    ) -> "PendingRequest":
        """Build a request stamped with the current time.

        Must be called from within a running event loop: the result future
        is created on that loop via ``get_running_loop().create_future()``,
        which avoids the deprecated implicit ``get_event_loop()`` lookup
        that a bare ``asyncio.Future()`` performs.
        """
        return cls(
            priority=priority.value,
            timestamp=time.time(),
            job_id=job_id,
            params=params,
            future=asyncio.get_running_loop().create_future(),
            prefix_hash=prefix_hash
        )
58
+
59
+
60
class ContinuousBatcher:
    """Continuous (dynamic) batcher.

    Merges individual inference requests into batches before handing them
    to the engine, improving GPU utilization and overall throughput.
    Supports:

    - a maximum batch size
    - a maximum wait time before a partial batch is flushed
    - a priority queue of pending requests
    - optional grouping of requests that share a prompt prefix
    """

    def __init__(
        self,
        engine,
        max_batch_size: int = 32,
        max_wait_ms: float = 50,
        enable_prefix_grouping: bool = True,
        max_queue_size: int = 1000,
    ):
        """
        Args:
            engine: inference backend.  May expose ``batch_inference_async``,
                ``batch_inference``, ``inference_async`` or ``inference``
                (probed in that order by ``_execute_batch``).
            max_batch_size: upper bound on requests per batch.
            max_wait_ms: how long a partial batch may wait before flushing.
            enable_prefix_grouping: batch requests sharing a system-prompt
                prefix together.
            max_queue_size: maximum number of queued requests before
                ``submit`` rejects new ones.
        """
        self.engine = engine
        self.max_batch_size = max_batch_size
        self.max_wait_ms = max_wait_ms
        self.enable_prefix_grouping = enable_prefix_grouping
        self.max_queue_size = max_queue_size

        # Priority heap of pending requests plus an index by prefix hash.
        self._pending: List[PendingRequest] = []
        self._pending_by_prefix: Dict[str, List[PendingRequest]] = defaultdict(list)

        # Timer task that flushes a partial batch after max_wait_ms.
        self._batch_task: Optional[asyncio.Task] = None
        # Strong references to immediate-flush tasks so they cannot be
        # garbage-collected while still running.
        self._flush_tasks: set = set()
        self._lock = asyncio.Lock()

        # Running statistics; see get_stats().
        self._stats = {
            "total_requests": 0,
            "total_batches": 0,
            "avg_batch_size": 0.0,
            "avg_wait_time_ms": 0.0,
        }

        # Toggled by start()/stop().
        self._running = False

    async def start(self) -> None:
        """Start accepting requests."""
        self._running = True
        logger.info("ContinuousBatcher started")

    async def stop(self) -> None:
        """Stop the batcher and cancel every pending request."""
        self._running = False

        # Cancel all queued requests and clear both indices.
        async with self._lock:
            for req in self._pending:
                if not req.future.done():
                    req.future.cancel()
            self._pending.clear()
            self._pending_by_prefix.clear()

        if self._batch_task:
            self._batch_task.cancel()
            try:
                await self._batch_task
            except asyncio.CancelledError:
                pass

        logger.info("ContinuousBatcher stopped")

    async def submit(
        self,
        job_id: str,
        params: Dict[str, Any],
        priority: RequestPriority = RequestPriority.NORMAL,
        timeout: float = 120.0,
    ) -> Dict[str, Any]:
        """Submit one inference request and wait for its result.

        Args:
            job_id: job identifier.
            params: inference parameters (forwarded to the engine).
            priority: request priority.
            timeout: seconds to wait for the result.

        Returns:
            The engine's result for this request.

        Raises:
            RuntimeError: if the batcher is stopped or the queue is full.
            asyncio.TimeoutError: if no result arrives within *timeout*.
        """
        if not self._running:
            raise RuntimeError("Batcher is not running")

        if len(self._pending) >= self.max_queue_size:
            raise RuntimeError(f"Queue full (max={self.max_queue_size})")

        # Hash of the shared prompt prefix, used for batch grouping.
        prefix_hash = ""
        if self.enable_prefix_grouping:
            prefix_hash = self._compute_prefix_hash(params)

        request = PendingRequest.create(
            job_id=job_id,
            params=params,
            priority=priority,
            prefix_hash=prefix_hash
        )

        async with self._lock:
            heapq.heappush(self._pending, request)

            if self.enable_prefix_grouping and prefix_hash:
                self._pending_by_prefix[prefix_hash].append(request)

            self._stats["total_requests"] += 1

            if len(self._pending) >= self.max_batch_size:
                # Full batch: flush immediately.  Keep a strong reference so
                # the task is not garbage-collected before it finishes.
                task = asyncio.create_task(self._process_batch())
                self._flush_tasks.add(task)
                task.add_done_callback(self._flush_tasks.discard)
            elif self._batch_task is None or self._batch_task.done():
                # Arm the wait timer for a partial batch.
                self._batch_task = asyncio.create_task(self._wait_and_process())

        try:
            return await asyncio.wait_for(request.future, timeout=timeout)
        except asyncio.TimeoutError:
            # Timed out: remove from the queue *and* the prefix index so a
            # later batch cannot pick the dead request up again.
            async with self._lock:
                try:
                    self._pending.remove(request)
                    heapq.heapify(self._pending)
                except ValueError:
                    pass  # already taken by a batch
                if request.prefix_hash:
                    try:
                        self._pending_by_prefix[request.prefix_hash].remove(request)
                    except ValueError:
                        pass
            raise

    async def _wait_and_process(self) -> None:
        """Sleep for the configured wait window, then flush one batch."""
        await asyncio.sleep(self.max_wait_ms / 1000)
        await self._process_batch()

    async def _process_batch(self) -> None:
        """Select one batch under the lock, then run it outside the lock."""
        batch_start_time = time.time()

        async with self._lock:
            if not self._pending:
                return

            # Choose which requests make up this batch.
            if self.enable_prefix_grouping:
                batch = self._select_batch_with_prefix_grouping()
            else:
                batch = self._select_batch_simple()

            if not batch:
                return

            # Detach selected requests from the queue and the prefix index.
            for req in batch:
                try:
                    self._pending.remove(req)
                except ValueError:
                    pass  # simple selection already popped it off the heap
                if req.prefix_hash:
                    try:
                        self._pending_by_prefix[req.prefix_hash].remove(req)
                    except ValueError:
                        pass
            # Drop emptied groups so the defaultdict does not grow forever.
            for key in [k for k, grp in self._pending_by_prefix.items() if not grp]:
                del self._pending_by_prefix[key]
            heapq.heapify(self._pending)

        # Run inference outside the lock so new submissions are not blocked.
        try:
            results = await self._execute_batch(batch)

            # Deliver per-request results (exceptions propagate as such).
            for req, result in zip(batch, results):
                if not req.future.done():
                    if isinstance(result, Exception):
                        req.future.set_exception(result)
                    else:
                        req.future.set_result(result)

        except Exception as e:
            logger.error(f"Batch processing error: {e}")
            # Fail every request in the batch.
            for req in batch:
                if not req.future.done():
                    req.future.set_exception(e)

        # Update running averages (incremental mean).
        batch_time = (time.time() - batch_start_time) * 1000
        self._stats["total_batches"] += 1
        n = self._stats["total_batches"]
        self._stats["avg_batch_size"] += (len(batch) - self._stats["avg_batch_size"]) / n
        self._stats["avg_wait_time_ms"] += (batch_time - self._stats["avg_wait_time_ms"]) / n

    def _select_batch_simple(self) -> List[PendingRequest]:
        """Pop up to ``max_batch_size`` requests in priority order."""
        count = min(len(self._pending), self.max_batch_size)
        return [heapq.heappop(self._pending) for _ in range(count)]

    def _select_batch_with_prefix_grouping(self) -> List[PendingRequest]:
        """Pick a batch favouring the largest shared-prefix groups."""
        batch: List[PendingRequest] = []

        if self._pending_by_prefix:
            # Largest groups first to maximise prefix sharing.
            sorted_groups = sorted(
                self._pending_by_prefix.items(),
                key=lambda x: len(x[1]),
                reverse=True
            )

            for prefix_hash, group in sorted_groups:
                if not group:
                    continue
                take_count = min(len(group), self.max_batch_size - len(batch))
                batch.extend(group[:take_count])
                if len(batch) >= self.max_batch_size:
                    break

        # Top up with prefix-less requests if there is room left.
        remaining = self.max_batch_size - len(batch)
        if remaining > 0:
            no_prefix_requests = [
                req for req in self._pending
                if not req.prefix_hash and req not in batch
            ]
            batch.extend(no_prefix_requests[:remaining])

        return batch

    async def _execute_batch(
        self,
        batch: List[PendingRequest]
    ) -> List[Any]:
        """Run one batch on the engine, preferring batch APIs.

        Returns one result per request; in the serial fallback a failed
        item is returned as the raised exception object.
        """
        params_list = [req.params for req in batch]

        if hasattr(self.engine, "batch_inference_async"):
            return await self.engine.batch_inference_async(params_list)

        if hasattr(self.engine, "batch_inference"):
            # Run the synchronous batch method in the default executor.
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(
                None,
                self.engine.batch_inference,
                params_list
            )

        # Fallback: serial execution, one request at a time.
        results: List[Any] = []
        for params in params_list:
            try:
                if hasattr(self.engine, "inference_async"):
                    result = await self.engine.inference_async(params)
                else:
                    loop = asyncio.get_running_loop()
                    result = await loop.run_in_executor(
                        None,
                        self.engine.inference,
                        params
                    )
                results.append(result)
            except Exception as e:
                results.append(e)
        return results

    def _compute_prefix_hash(self, params: Dict[str, Any]) -> str:
        """Hash the system messages of a chat request for prefix grouping.

        Returns an empty string when there is no system prompt, which
        disables grouping for that request.
        """
        import hashlib

        messages = params.get("messages", [])
        if not messages:
            return ""

        # Use the system messages as the shared prefix.
        system_messages = [
            m.get("content", "")
            for m in messages
            if m.get("role") == "system"
        ]

        if not system_messages:
            return ""

        prefix_str = "".join(system_messages)
        return hashlib.sha256(prefix_str.encode()).hexdigest()[:16]

    def get_stats(self) -> Dict[str, Any]:
        """Return a snapshot of batching statistics."""
        return {
            **self._stats,
            "queue_size": len(self._pending),
            "prefix_groups": len(self._pending_by_prefix),
        }
366
+
367
+
368
class AdaptiveBatcher(ContinuousBatcher):
    """Adaptive batcher.

    Tunes the effective batch size at runtime from observed batch latency,
    steering towards ``target_latency_ms``.
    """

    def __init__(
        self,
        engine,
        min_batch_size: int = 1,
        max_batch_size: int = 64,
        target_latency_ms: float = 100,
        **kwargs
    ):
        """
        Args:
            engine: inference backend (see ContinuousBatcher).
            min_batch_size: lower bound for the adaptive batch size.
            max_batch_size: upper bound for the adaptive batch size.
            target_latency_ms: latency the controller aims for.
            **kwargs: forwarded to ContinuousBatcher.
        """
        super().__init__(engine, max_batch_size=max_batch_size, **kwargs)
        self.min_batch_size = min_batch_size
        self.target_latency_ms = target_latency_ms

        # Adaptive state: current size plus a bounded latency history.
        self._current_batch_size = max_batch_size // 2
        self._latency_history: List[float] = []
        self._max_history = 100

    async def _process_batch(self) -> None:
        """Process a batch with the adaptive size, then re-tune it."""
        start_time = time.time()

        # Temporarily narrow max_batch_size to the adaptive value.
        original_max = self.max_batch_size
        self.max_batch_size = self._current_batch_size
        try:
            await super()._process_batch()
        finally:
            # Restore even if the batch raised; without this the configured
            # maximum would stay permanently shrunk after an error.
            self.max_batch_size = original_max

        # Record latency, keeping the history bounded.
        latency_ms = (time.time() - start_time) * 1000
        self._latency_history.append(latency_ms)
        if len(self._latency_history) > self._max_history:
            self._latency_history.pop(0)

        self._adapt_batch_size()

    def _adapt_batch_size(self) -> None:
        """Grow or shrink the batch size based on recent latency."""
        if len(self._latency_history) < 10:
            return

        avg_latency = sum(self._latency_history[-10:]) / 10

        if avg_latency > self.target_latency_ms * 1.2:
            # Too slow: shrink by 20% (never below the minimum).
            self._current_batch_size = max(
                self.min_batch_size,
                int(self._current_batch_size * 0.8)
            )
        elif avg_latency < self.target_latency_ms * 0.8:
            # Headroom: grow by 20% (never above the maximum).
            self._current_batch_size = min(
                self.max_batch_size,
                int(self._current_batch_size * 1.2)
            )

        logger.debug(
            f"Adaptive batch size: {self._current_batch_size} "
            f"(avg latency: {avg_latency:.1f}ms)"
        )
@@ -0,0 +1,275 @@
1
+ #!/usr/bin/env node
2
+ /**
3
+ * GPU Worker CLI - Node.js 入口
4
+ * 包装 Python Worker,提供简单的 npm/npx 安装体验
5
+ */
6
+
7
+ const { Command } = require('commander');
8
+ const chalk = require('chalk');
9
+ const ora = require('ora');
10
+ const inquirer = require('inquirer');
11
+ const { spawn, execSync } = require('child_process');
12
+ const path = require('path');
13
+ const fs = require('fs');
14
+ const which = require('which');
15
+
16
+ const PACKAGE_DIR = path.resolve(__dirname, '..');
17
+ const PYTHON_DIR = PACKAGE_DIR;
18
+ const CONFIG_FILE = path.join(process.cwd(), 'config.yaml');
19
+
20
+ // 检测 Python
21
// Locate a suitable Python interpreter (>= 3.9) on PATH.
// Returns the command name, or null when none qualifies.
function findPython() {
  const pythonCommands = ['python3', 'python', 'py'];

  for (const cmd of pythonCommands) {
    try {
      which.sync(cmd); // throws when the command is not on PATH
      // Verify the interpreter version.
      const version = execSync(`${cmd} --version`, { encoding: 'utf8' });
      const match = version.match(/Python (\d+)\.(\d+)/);
      if (!match) {
        continue;
      }
      const major = parseInt(match[1], 10);
      const minor = parseInt(match[2], 10);
      // Accept 3.9 and later.  The original `major >= 3 && minor >= 9`
      // check wrongly rejected any future 4.x release with minor < 9.
      if (major > 3 || (major === 3 && minor >= 9)) {
        return cmd;
      }
    } catch (e) {
      continue;
    }
  }
  return null;
}
39
+
40
+ // 检查虚拟环境
41
// Return the virtualenv's Python executable path, or null when no
// virtualenv has been created yet.
function getVenvPython() {
  const venvPath = path.join(PACKAGE_DIR, '.venv');

  // Windows keeps the interpreter under Scripts/, POSIX under bin/.
  const candidate = process.platform === 'win32'
    ? path.join(venvPath, 'Scripts', 'python.exe')
    : path.join(venvPath, 'bin', 'python');

  return fs.existsSync(candidate) ? candidate : null;
}
53
+
54
+ // 创建虚拟环境
55
// Create the package-local Python virtual environment.
// Returns true on success, false on failure (error already printed).
async function createVenv(pythonCmd) {
  const venvPath = path.join(PACKAGE_DIR, '.venv');
  const spinner = ora('Creating Python virtual environment...').start();

  try {
    execSync(`${pythonCmd} -m venv "${venvPath}"`, { stdio: 'pipe' });
  } catch (e) {
    spinner.fail('Failed to create virtual environment');
    console.error(chalk.red(e.message));
    return false;
  }

  spinner.succeed('Virtual environment created');
  return true;
}
69
+
70
+ // 安装 Python 依赖
71
// Install the Python requirements into the package virtualenv.
// Returns true on success, false otherwise (error already printed).
async function installDependencies() {
  const venvPython = getVenvPython();
  if (!venvPython) {
    console.error(chalk.red('Virtual environment not found'));
    return false;
  }

  const requirementsFile = path.join(PACKAGE_DIR, 'requirements.txt');
  const spinner = ora('Installing Python dependencies...').start();

  try {
    execSync(`"${venvPython}" -m pip install -r "${requirementsFile}" -q`, {
      stdio: 'pipe',
      timeout: 600000 // abort after 10 minutes
    });
  } catch (e) {
    spinner.fail('Failed to install dependencies');
    console.error(chalk.red(e.message));
    return false;
  }

  spinner.succeed('Dependencies installed');
  return true;
}
94
+
95
+ // 运行 Python CLI
96
// Spawn the Python CLI (cli.py) with the given arguments, inheriting
// stdio, and exit this process with the child's exit code.
function runPythonCLI(args) {
  const pythonCmd = getVenvPython() || findPython();

  if (!pythonCmd) {
    console.error(chalk.red('Python 3.9+ not found!'));
    console.log(chalk.yellow('Please install Python 3.9 or higher:'));
    console.log(' - Windows: https://www.python.org/downloads/');
    console.log(' - macOS: brew install python@3.11');
    console.log(' - Linux: sudo apt install python3.11');
    process.exit(1);
  }

  const cliPath = path.join(PACKAGE_DIR, 'cli.py');

  const proc = spawn(pythonCmd, [cliPath, ...args], {
    stdio: 'inherit',
    cwd: process.cwd()
  });

  proc.on('close', (code) => {
    // `code` is null when the child was killed by a signal; treat that as
    // a failure instead of passing null to process.exit().
    process.exit(code === null ? 1 : code);
  });

  proc.on('error', (err) => {
    console.error(chalk.red('Failed to start Python process:'), err.message);
    process.exit(1);
  });
}
124
+
125
+ // 初始化检查
126
// First-run bootstrap: make sure the virtualenv and its dependencies
// exist.  Exits the process when Python is missing or setup fails.
async function ensureSetup() {
  if (getVenvPython()) {
    return; // environment already prepared
  }

  console.log(chalk.cyan('First time setup detected. Setting up environment...\n'));

  const pythonCmd = findPython();
  if (!pythonCmd) {
    console.error(chalk.red('Python 3.9+ is required but not found!'));
    console.log(chalk.yellow('\nPlease install Python:'));
    console.log(' - Windows: https://www.python.org/downloads/');
    console.log(' - macOS: brew install python@3.11');
    console.log(' - Linux: sudo apt install python3.11');
    process.exit(1);
  }

  console.log(chalk.green(`Found Python: ${pythonCmd}`));

  const venvOk = await createVenv(pythonCmd);
  if (!venvOk) {
    process.exit(1);
  }

  const depsOk = await installDependencies();
  if (!depsOk) {
    process.exit(1);
  }

  console.log(chalk.green('\n✓ Setup complete!\n'));
}
155
+
156
// ---------------------------------------------------------------------------
// CLI definition (commander)
// ---------------------------------------------------------------------------
const program = new Command();

program
  .name('gpu-worker')
  .description('分布式GPU推理 Worker 节点')
  .version('1.0.0');

// install: refresh the Python dependencies (npm-side installer).
program
  .command('install')
  .description('安装/更新 Python 依赖')
  .action(async () => {
    await ensureSetup();
    await installDependencies();
  });

// configure: interactive configuration wizard (delegated to the Python CLI).
program
  .command('configure')
  .description('交互式配置向导')
  .action(async () => {
    await ensureSetup();
    runPythonCLI(['configure']);
  });

// start: launch the worker, falling back to the wizard when unconfigured.
program
  .command('start')
  .description('启动 Worker')
  .option('-c, --config <path>', '配置文件路径', 'config.yaml')
  .action(async (options) => {
    await ensureSetup();

    const configPath = path.resolve(options.config);
    if (!fs.existsSync(configPath)) {
      console.log(chalk.yellow('No config file found. Starting configuration wizard...\n'));
      runPythonCLI(['configure']);
      return;
    }

    runPythonCLI(['start', '-c', configPath]);
  });

// status: show the worker status.
program
  .command('status')
  .description('查看状态')
  .action(async () => {
    await ensureSetup();
    runPythonCLI(['status']);
  });

// set: write a single configuration value.
program
  .command('set <key> <value>')
  .description('设置配置项')
  .action(async (key, value) => {
    await ensureSetup();
    runPythonCLI(['set', key, value]);
  });

// setup: explicitly create the virtualenv and install dependencies.
program
  .command('setup')
  .description('初始化环境(创建虚拟环境并安装依赖)')
  .action(async () => {
    const pythonCmd = findPython();
    if (!pythonCmd) {
      console.error(chalk.red('Python 3.9+ not found!'));
      process.exit(1);
    }

    console.log(chalk.cyan('Setting up GPU Worker environment...\n'));
    console.log(chalk.green(`Python: ${pythonCmd}`));

    await createVenv(pythonCmd);
    await installDependencies();

    console.log(chalk.green('\n✓ Setup complete!'));
    console.log(chalk.cyan('\nNext steps:'));
    console.log(' 1. Run: gpu-worker configure');
    console.log(' 2. Run: gpu-worker start');
  });

// quick: interactive menu, the default when no subcommand is given.
program
  .command('quick', { isDefault: true, hidden: true })
  .action(async () => {
    await ensureSetup();

    console.log(chalk.cyan.bold('\n GPU Worker - 分布式GPU推理节点\n'));

    const { action } = await inquirer.prompt([{
      type: 'list',
      name: 'action',
      message: '请选择操作:',
      choices: [
        { name: '🚀 启动 Worker', value: 'start' },
        { name: '⚙️ 配置向导', value: 'configure' },
        { name: '📊 查看状态', value: 'status' },
        { name: '📦 安装依赖', value: 'install' },
        { name: '❌ 退出', value: 'exit' }
      ]
    }]);

    if (action === 'exit') {
      process.exit(0);
    }

    if (action === 'start') {
      const configPath = path.join(process.cwd(), 'config.yaml');
      if (!fs.existsSync(configPath)) {
        console.log(chalk.yellow('\n未找到配置文件,先进行配置...\n'));
        runPythonCLI(['configure']);
        return;
      }
    }

    // NOTE(review): the 'install' menu entry is forwarded to the Python CLI
    // here, while the top-level `install` command runs the npm-side
    // installDependencies() — confirm the Python CLI implements `install`.
    runPythonCLI([action]);
  });

program.parse();