codegnipy 0.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
codegnipy/scheduler.py ADDED
@@ -0,0 +1,498 @@
1
+ """
2
+ Codegnipy 异步调度器模块
3
+
4
+ 提供高性能的异步 LLM 调用调度,包括:
5
+ - 异步执行
6
+ - 超时控制
7
+ - 重试机制
8
+ - 并发控制
9
+ - 优先级队列
10
+ """
11
+
12
+ import asyncio
13
+ import time
14
+ from dataclasses import dataclass, field
15
+ from enum import Enum
16
+ from typing import (
17
+ Any, Optional, Callable, Coroutine, List, Dict,
18
+ TypeVar, Generic, Union
19
+ )
20
+
21
+ from .runtime import LLMConfig, CognitiveContext
22
+
23
+
24
+ T = TypeVar('T')
25
+
26
+
27
class TaskStatus(Enum):
    """Lifecycle state of a scheduled task."""

    # Not yet executing (freshly submitted, or re-queued for a retry).
    PENDING = "pending"
    # Currently being executed.
    RUNNING = "running"
    # Finished and produced a result.
    COMPLETED = "completed"
    # Raised an exception.
    FAILED = "failed"
    # Exceeded its time limit.
    TIMEOUT = "timeout"
    # Cancelled before it could finish.
    CANCELLED = "cancelled"
35
+
36
+
37
class Priority(Enum):
    """Task priority levels; a larger value means a more urgent task."""

    LOW = 1
    NORMAL = 5
    HIGH = 10
    CRITICAL = 100
43
+
44
+
45
@dataclass(order=True)
class ScheduledTask(Generic[T]):
    """A unit of work tracked by the scheduler.

    Only ``priority`` takes part in the generated comparison methods; every
    other field is declared with ``compare=False``.
    """

    # Sort key used by the generated ordering methods.
    priority: int
    task_id: str = field(compare=False)
    # Zero-argument callable that returns a coroutine each time it is called.
    coro_factory: Callable[[], Coroutine] = field(repr=False, compare=False)
    # Timestamps, all produced by / comparable with time.time().
    created_at: float = field(compare=False, default_factory=time.time)
    started_at: Optional[float] = field(compare=False, default=None)
    completed_at: Optional[float] = field(compare=False, default=None)
    status: TaskStatus = field(compare=False, default=TaskStatus.PENDING)
    # Outcome of the task: a result on success, an exception on failure.
    result: Optional[T] = field(compare=False, default=None)
    error: Optional[Exception] = field(compare=False, default=None)
    # Retry bookkeeping.
    retries: int = field(compare=False, default=0)
    max_retries: int = field(compare=False, default=3)
    # Time limit in seconds for one execution attempt (None = no limit).
    timeout: Optional[float] = field(compare=False, default=None)
    # Optional hook invoked with the result.
    callback: Optional[Callable[[T], None]] = field(
        repr=False, compare=False, default=None
    )

    def __post_init__(self):
        # Slot for the asyncio task executing this work item; not a dataclass
        # field, so it is excluded from __init__, __repr__ and comparisons.
        self._async_task: Optional[asyncio.Task] = None

    def create_coro(self) -> Coroutine:
        """Return a coroutine obtained by calling the stored factory."""
        fresh = self.coro_factory()
        return fresh
68
+
69
+
70
@dataclass
class RetryPolicy:
    """Retry settings using capped exponential backoff."""

    # Maximum number of retries.
    max_retries: int = 3
    # Delay in seconds used for attempt 0.
    base_delay: float = 1.0
    # Upper bound in seconds for any computed delay.
    max_delay: float = 60.0
    # Multiplier applied once per attempt.
    exponential_base: float = 2.0

    def get_delay(self, attempt: int) -> float:
        """Compute the delay before a retry (exponential backoff).

        Args:
            attempt: Zero-based attempt index; the delay is
                ``base_delay * exponential_base ** attempt``.

        Returns:
            The computed delay, never larger than ``max_delay``.
        """
        try:
            delay = self.base_delay * (self.exponential_base ** attempt)
        except OverflowError:
            # The power no longer fits in a float (e.g. 2.0 ** 5000). The
            # uncapped value is astronomically large, so the cap applies.
            return self.max_delay
        return min(delay, self.max_delay)
82
+
83
+
84
@dataclass
class SchedulerConfig:
    """Settings that control how the scheduler runs tasks."""

    # Maximum number of tasks executing at the same time.
    max_concurrent: int = field(default=10)
    # Time limit in seconds used when a task does not specify its own.
    default_timeout: float = field(default=60.0)
    # Priority used when a task does not specify its own.
    default_priority: Priority = field(default=Priority.NORMAL)
    # Each config instance gets its own RetryPolicy object.
    retry_policy: RetryPolicy = field(default_factory=RetryPolicy)
91
+
92
+
93
class CognitiveScheduler:
    """
    Scheduler for asynchronous cognitive (LLM) tasks.

    Manages the execution of submitted coroutines and supports:
    - a cap on the number of concurrently running tasks
    - a priority queue (higher ``Priority`` values are dispatched first)
    - per-task timeouts and retries with exponential backoff
    - completion callbacks

    Example:
        scheduler = CognitiveScheduler(SchedulerConfig(max_concurrent=5))

        async def main():
            # Submit a task
            task_id = await scheduler.submit(
                async_cognitive_call("Hello"),
                priority=Priority.HIGH
            )

            # Wait for its result
            result = await scheduler.get_result(task_id)
            print(result)

        asyncio.run(main())
    """

    def __init__(self, config: Optional[SchedulerConfig] = None):
        """Create a scheduler; ``config`` defaults to ``SchedulerConfig()``."""
        self.config = config or SchedulerConfig()
        self._tasks: Dict[str, ScheduledTask[Any]] = {}
        # The queue and lock are created lazily by _ensure_initialized().
        self._pending_queue: Optional[asyncio.PriorityQueue[ScheduledTask[Any]]] = None
        self._running_count: int = 0
        self._counter: int = 0
        self._lock: Optional[asyncio.Lock] = None
        self._started: bool = False
        # Strong references to fire-and-forget asyncio tasks. The event loop
        # only keeps weak references, so an unreferenced task could be
        # garbage collected before it finishes.
        self._background_tasks: set = set()

    async def _ensure_initialized(self):
        """Create the pending queue and the lock if they do not exist yet."""
        if self._pending_queue is None:
            self._pending_queue = asyncio.PriorityQueue()
        if self._lock is None:
            self._lock = asyncio.Lock()

    def _generate_task_id(self) -> str:
        """Return a new task ID built from a counter and the current time."""
        self._counter += 1
        return f"task_{self._counter}_{int(time.time() * 1000)}"

    def _spawn(self, coro: Coroutine) -> "asyncio.Task":
        """Schedule ``coro`` as an asyncio task and keep it referenced until done."""
        handle = asyncio.create_task(coro)
        self._background_tasks.add(handle)
        handle.add_done_callback(self._background_tasks.discard)
        return handle

    async def submit(
        self,
        coro_or_factory: Union[Coroutine[Any, Any, T], Callable[[], Coroutine[Any, Any, T]]],
        priority: Optional[Priority] = None,
        timeout: Optional[float] = None,
        max_retries: Optional[int] = None,
        callback: Optional[Callable[[T], None]] = None
    ) -> str:
        """
        Submit an asynchronous task.

        Args:
            coro_or_factory: A coroutine object, or a zero-argument factory
                returning a coroutine (required for retries to work).
            priority: Task priority; defaults to ``config.default_priority``.
            timeout: Per-attempt timeout in seconds; a falsy value falls back
                to ``config.default_timeout``.
            max_retries: Maximum number of retries; ``None`` falls back to the
                retry policy, while an explicit ``0`` disables retries.
            callback: Called with the result when the task completes.

        Returns:
            The ID of the new task.

        Note:
            A coroutine object can only be awaited once, so tasks submitted
            that way are never retried. Pass a factory to enable retries.
        """
        await self._ensure_initialized()

        if asyncio.iscoroutine(coro_or_factory):
            # A bare coroutine cannot be re-created, hence no retries.
            coro = coro_or_factory

            def coro_factory():
                return coro

            effective_max_retries = 0
        else:
            coro_factory = coro_or_factory
            coro = coro_factory()
            # Compare against None so that an explicit 0 is honoured.
            if max_retries is not None:
                effective_max_retries = max_retries
            else:
                effective_max_retries = self.config.retry_policy.max_retries

        task_id = self._generate_task_id()
        task = ScheduledTask(
            # Negated so that higher priorities leave the min-queue first.
            priority=(priority or self.config.default_priority).value * -1,
            task_id=task_id,
            coro_factory=coro_factory,
            timeout=timeout or self.config.default_timeout,
            max_retries=effective_max_retries,
            callback=callback
        )

        # Coroutine to await on the first execution attempt.
        task._current_coro = coro  # type: ignore[attr-defined]

        self._tasks[task_id] = task
        assert self._pending_queue is not None
        await self._pending_queue.put(task)

        # Kick the dispatcher without blocking the caller.
        self._spawn(self._process_queue())

        return task_id

    async def _process_queue(self):
        """Dispatch queued tasks while free execution slots remain."""
        await self._ensure_initialized()

        assert self._lock is not None
        async with self._lock:
            assert self._pending_queue is not None
            while self._running_count < self.config.max_concurrent:
                if self._pending_queue.empty():
                    break

                task = await self._pending_queue.get()
                if task.status == TaskStatus.CANCELLED:
                    # Cancelled while waiting in the queue: drop it.
                    continue

                self._running_count += 1
                # Remember the handle so cancel() can interrupt a running task.
                task._async_task = self._spawn(self._execute_task(task))

    async def _execute_task(self, task: ScheduledTask):
        """Run one attempt of ``task`` and record its outcome.

        The ``finally`` clause always releases the execution slot taken by
        ``_process_queue`` and then dispatches further queued tasks.
        """
        try:
            if task.status == TaskStatus.CANCELLED:
                # Cancelled after being dispatched but before starting.
                return

            task.status = TaskStatus.RUNNING
            task.started_at = time.time()

            # Use the prepared coroutine, or build one from the factory.
            coro = getattr(task, '_current_coro', None) or task.create_coro()

            result = await asyncio.wait_for(coro, timeout=task.timeout)

            task.result = result
            task.status = TaskStatus.COMPLETED
            task.completed_at = time.time()

            if task.callback:
                try:
                    task.callback(result)
                except Exception:
                    # Deliberate best-effort: a failing callback must not
                    # change the state of a task that completed.
                    pass

        except asyncio.TimeoutError:
            task.status = TaskStatus.TIMEOUT
            task.error = TimeoutError(f"Task {task.task_id} timed out after {task.timeout}s")

            # NOTE: the backoff sleep in _retry_task runs before the slot is
            # released in the finally clause below.
            if task.retries < task.max_retries:
                await self._retry_task(task)

        except asyncio.CancelledError:
            task.status = TaskStatus.CANCELLED

        except Exception as e:
            task.status = TaskStatus.FAILED
            task.error = e

            if task.retries < task.max_retries:
                await self._retry_task(task)

        finally:
            assert self._lock is not None
            async with self._lock:
                self._running_count -= 1

            # A slot was freed: continue with the queue.
            await self._process_queue()

    async def _retry_task(self, task: ScheduledTask):
        """Re-queue ``task`` for another attempt after the backoff delay."""
        task.retries += 1
        task.status = TaskStatus.PENDING

        # A coroutine can only be awaited once; build a new one.
        task._current_coro = task.create_coro()  # type: ignore[attr-defined]

        delay = self.config.retry_policy.get_delay(task.retries - 1)
        await asyncio.sleep(delay)

        assert self._pending_queue is not None
        await self._pending_queue.put(task)

    async def get_result(self, task_id: str, timeout: Optional[float] = None) -> Any:
        """
        Wait for a task to finish and return its result.

        Args:
            task_id: ID returned by ``submit``.
            timeout: Maximum time to wait in seconds. ``None`` waits without
                a limit; ``0`` does not wait at all.

        Returns:
            The task result.

        Raises:
            KeyError: The task does not exist.
            TimeoutError: The task did not finish within ``timeout``.
            Exception: The task's own error, or a generic ``Exception`` when
                it ended unsuccessfully without a recorded error.
        """
        if task_id not in self._tasks:
            raise KeyError(f"Task {task_id} not found")

        task = self._tasks[task_id]

        # A monotonic clock is immune to wall-clock adjustments.
        deadline = None if timeout is None else time.monotonic() + timeout
        while task.status in (TaskStatus.PENDING, TaskStatus.RUNNING):
            if deadline is not None and time.monotonic() >= deadline:
                raise TimeoutError(f"Timeout waiting for task {task_id}")
            await asyncio.sleep(0.1)

        if task.status == TaskStatus.COMPLETED:
            return task.result
        elif task.error:
            raise task.error
        else:
            raise Exception(f"Task {task_id} failed with status {task.status}")

    async def cancel(self, task_id: str) -> bool:
        """Cancel a pending or running task.

        Returns:
            ``True`` if the task was pending or running and is now marked
            cancelled; ``False`` if it is unknown or already finished.
        """
        task = self._tasks.get(task_id)
        if task is None:
            return False

        if task.status not in (TaskStatus.PENDING, TaskStatus.RUNNING):
            # Already finished: keep the recorded outcome intact.
            return False

        was_running = task.status == TaskStatus.RUNNING
        task.status = TaskStatus.CANCELLED

        if was_running:
            if task._async_task is not None:
                task._async_task.cancel()
        else:
            # The prepared coroutine has not been started and never will be.
            # Close it so Python does not warn that it was never awaited.
            # The dispatcher / _execute_task skip cancelled tasks themselves.
            unstarted = getattr(task, '_current_coro', None)
            close = getattr(unstarted, 'close', None)
            if close is not None:
                close()
            task._current_coro = None  # type: ignore[attr-defined]

        return True

    def get_status(self, task_id: str) -> Optional[TaskStatus]:
        """Return the status of a task, or ``None`` if the ID is unknown."""
        task = self._tasks.get(task_id)
        return task.status if task else None

    async def wait_all(self, timeout: Optional[float] = None) -> Dict[str, Any]:
        """Wait until every known task has reached a final state.

        Args:
            timeout: Maximum time to wait in seconds. ``None`` waits without
                a limit; ``0`` does not wait at all.

        Returns:
            A mapping of task ID to result for the tasks that completed.

        Raises:
            TimeoutError: Not all tasks finished within ``timeout``.
        """
        final_states = (
            TaskStatus.COMPLETED,
            TaskStatus.FAILED,
            TaskStatus.CANCELLED,
            TaskStatus.TIMEOUT,
        )
        deadline = None if timeout is None else time.monotonic() + timeout

        while True:
            if all(t.status in final_states for t in self._tasks.values()):
                break

            if deadline is not None and time.monotonic() >= deadline:
                raise TimeoutError("Timeout waiting for all tasks")

            await asyncio.sleep(0.1)

        return {tid: t.result for tid, t in self._tasks.items() if t.status == TaskStatus.COMPLETED}

    def stats(self) -> Dict[str, Any]:
        """Return task counts: total, running, the limit, and per status."""
        statuses: Dict[str, int] = {}
        for task in self._tasks.values():
            status = task.status.value
            statuses[status] = statuses.get(status, 0) + 1

        return {
            "total_tasks": len(self._tasks),
            "running": self._running_count,
            "max_concurrent": self.config.max_concurrent,
            "by_status": statuses
        }
377
+
378
+
379
+ # ============ 异步 LLM 调用 ============
380
+
381
async def async_cognitive_call(
    prompt: str,
    context: Optional[CognitiveContext] = None,
    *,
    model: Optional[str] = None,
    temperature: Optional[float] = None,
    config: Optional[LLMConfig] = None
) -> str:
    """
    Perform an asynchronous cognitive (LLM) call.

    Args:
        prompt: Prompt text.
        context: Cognitive context; when omitted,
            ``CognitiveContext.get_current()`` is consulted.
        model: Model name overriding the one in the configuration.
        temperature: Temperature overriding the one in the configuration.
        config: LLM configuration; when omitted it is taken from the context,
            or a default ``LLMConfig()`` is used if there is no context.

    Returns:
        The LLM response text.
    """
    import copy

    # Resolve the base configuration.
    if config is None:
        ctx = context or CognitiveContext.get_current()
        config = ctx.get_config() if ctx else LLMConfig()

    # Apply per-call overrides whatever the source of the configuration, and
    # do so on a shallow copy so the caller's / context's object is not
    # modified by a single call.
    if model or temperature is not None:
        config = copy.copy(config)
        if model:
            config.model = model
        if temperature is not None:
            config.temperature = temperature

    return await _call_openai_async(config, prompt)
416
+
417
+
418
async def _call_openai_async(config: LLMConfig, prompt: str) -> str:
    """Send ``prompt`` as a single user message to the OpenAI chat API.

    Args:
        config: Supplies ``api_key``, ``base_url``, ``model``,
            ``temperature`` and ``max_tokens`` for the request.
        prompt: Text of the user message.

    Returns:
        The content of the first choice, or ``""`` when it has no content.

    Raises:
        ImportError: The ``openai`` package is not installed.
    """
    try:
        import openai
    except ImportError as exc:
        # Chain the original error so the failing import stays visible.
        raise ImportError("需要安装 openai 包。运行: pip install openai") from exc

    # The context manager closes the client's HTTP resources once the request
    # is done instead of leaving a fresh client open after every call.
    async with openai.AsyncOpenAI(
        api_key=config.api_key,
        base_url=config.base_url
    ) as client:
        response = await client.chat.completions.create(
            model=config.model,
            messages=[{"role": "user", "content": prompt}],
            temperature=config.temperature,
            max_tokens=config.max_tokens
        )

    return response.choices[0].message.content or ""
438
+
439
+
440
+ # ============ 便捷函数 ============
441
+
442
def run_async(coro: Coroutine) -> Any:
    """Convenience helper for running a coroutine.

    When an event loop is already running in the current thread, the
    coroutine is scheduled on it and the resulting task is returned.
    Otherwise the coroutine is run to completion with ``asyncio.run`` and its
    result is returned.
    """
    try:
        active_loop = asyncio.get_running_loop()
    except RuntimeError:
        # get_running_loop() raises when no loop is running in this thread.
        active_loop = None

    inside_loop = active_loop is not None and active_loop.is_running()
    if not inside_loop:
        # No loop to attach to: start a new one.
        return asyncio.run(coro)

    # Already inside a loop: hand back a task instead of blocking.
    return asyncio.create_task(coro)
455
+
456
+
457
async def batch_call(
    prompts: List[str],
    context: Optional[CognitiveContext] = None,
    max_concurrent: int = 5,
    timeout: float = 60.0
) -> List[Optional[str]]:
    """
    Run a batch of prompts concurrently.

    Args:
        prompts: Prompts to send.
        context: Cognitive context passed to every call.
        max_concurrent: Maximum number of calls in flight at once.
        timeout: Timeout in seconds for each individual call.

    Returns:
        One entry per prompt, in input order: the response text, or ``None``
        for a call that did not complete.
    """
    scheduler = CognitiveScheduler(SchedulerConfig(max_concurrent=max_concurrent))

    # Queue one task per prompt, remembering IDs in input order.
    submitted = [
        await scheduler.submit(async_cognitive_call(text, context), timeout=timeout)
        for text in prompts
    ]

    # wait_all() returns results only for tasks that completed, so any ID
    # missing from the mapping maps to None below.
    finished = await scheduler.wait_all()

    return [finished.get(tid) for tid in submitted]