flexllm-0.3.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (39) hide show
  1. flexllm/__init__.py +224 -0
  2. flexllm/__main__.py +1096 -0
  3. flexllm/async_api/__init__.py +9 -0
  4. flexllm/async_api/concurrent_call.py +100 -0
  5. flexllm/async_api/concurrent_executor.py +1036 -0
  6. flexllm/async_api/core.py +373 -0
  7. flexllm/async_api/interface.py +12 -0
  8. flexllm/async_api/progress.py +277 -0
  9. flexllm/base_client.py +988 -0
  10. flexllm/batch_tools/__init__.py +16 -0
  11. flexllm/batch_tools/folder_processor.py +317 -0
  12. flexllm/batch_tools/table_processor.py +363 -0
  13. flexllm/cache/__init__.py +10 -0
  14. flexllm/cache/response_cache.py +293 -0
  15. flexllm/chain_of_thought_client.py +1120 -0
  16. flexllm/claudeclient.py +402 -0
  17. flexllm/client_pool.py +698 -0
  18. flexllm/geminiclient.py +563 -0
  19. flexllm/llm_client.py +523 -0
  20. flexllm/llm_parser.py +60 -0
  21. flexllm/mllm_client.py +559 -0
  22. flexllm/msg_processors/__init__.py +174 -0
  23. flexllm/msg_processors/image_processor.py +729 -0
  24. flexllm/msg_processors/image_processor_helper.py +485 -0
  25. flexllm/msg_processors/messages_processor.py +341 -0
  26. flexllm/msg_processors/unified_processor.py +1404 -0
  27. flexllm/openaiclient.py +256 -0
  28. flexllm/pricing/__init__.py +104 -0
  29. flexllm/pricing/data.json +1201 -0
  30. flexllm/pricing/updater.py +223 -0
  31. flexllm/provider_router.py +213 -0
  32. flexllm/token_counter.py +270 -0
  33. flexllm/utils/__init__.py +1 -0
  34. flexllm/utils/core.py +41 -0
  35. flexllm-0.3.3.dist-info/METADATA +573 -0
  36. flexllm-0.3.3.dist-info/RECORD +39 -0
  37. flexllm-0.3.3.dist-info/WHEEL +4 -0
  38. flexllm-0.3.3.dist-info/entry_points.txt +3 -0
  39. flexllm-0.3.3.dist-info/licenses/LICENSE +201 -0
@@ -0,0 +1,1036 @@
1
+ import asyncio
2
+ import time
3
+ import itertools
4
+ import inspect
5
+ import os
6
+ import uuid
7
+ from asyncio import Queue
8
+ from typing import (
9
+ Any,
10
+ Dict,
11
+ Iterable,
12
+ List,
13
+ Optional,
14
+ Callable,
15
+ AsyncIterator,
16
+ AsyncGenerator,
17
+ Tuple,
18
+ Union,
19
+ Awaitable,
20
+ Protocol,
21
+ )
22
+ from dataclasses import dataclass, field
23
+ import heapq
24
+
25
+ from ..utils.core import async_retry
26
+ from .interface import RequestResult
27
+ from .progress import ProgressTracker, ProgressBarConfig
28
+
29
+ # FlaxKV2 is used as the checkpoint storage backend (optional dependency)
30
+ try:
31
+ from flaxkv2 import FlaxKV
32
+ FLAXKV_AVAILABLE = True
33
+ except ImportError:
34
+ FLAXKV_AVAILABLE = False
35
+
36
+
37
# Structural type describing the callables the executor can schedule.
class TaskFunction(Protocol):
    """Protocol for a task function: any async callable returning a value."""

    async def __call__(self, *args, **kwargs) -> Any:
        ...
42
+
43
+
44
@dataclass
class TaskContext:
    """Execution context handed to a task; bundles everything a task may need."""

    # Identifier of the task within its batch.
    task_id: int
    # The input item for this task.
    data: Any
    # Optional per-task metadata.
    meta: Optional[Dict[str, Any]] = None
    # Number of failed attempts that preceded the current one.
    retry_count: int = 0
    # Keyword arguments shared by every task of the batch.
    executor_kwargs: Optional[Dict[str, Any]] = None
53
+
54
+
55
@dataclass
class TaskItem:
    """A unit of work with a priority; a smaller ``priority`` sorts first."""

    # Scheduling priority (lower value means earlier execution).
    priority: int
    # Identifier of the task within its batch.
    task_id: int
    # The input item for this task.
    data: Any
    # Per-task metadata; every instance gets its own fresh dict by default.
    meta: Optional[dict] = field(default_factory=dict)

    def __lt__(self, other):
        """Order items by ``priority`` only, so they can live in a ``heapq``.

        Returns ``NotImplemented`` for foreign types so Python raises the
        usual ``TypeError`` instead of an ``AttributeError`` from a missing
        ``priority`` attribute.
        """
        if not isinstance(other, TaskItem):
            return NotImplemented
        return self.priority < other.priority
66
+
67
+
68
@dataclass
class ExecutionResult:
    """Outcome of running a single task."""

    # Identifier of the task this result belongs to.
    task_id: int
    # Value returned by the task function (None when the task failed).
    data: Any
    # Either 'success' or 'error'.
    status: str
    # Metadata that accompanied the task, if any.
    meta: Optional[Dict[str, Any]] = None
    # Wall-clock duration in seconds.
    latency: float = 0.0
    # The last exception raised when status is 'error'.
    error: Optional[Exception] = None
    # Number of retries performed.
    retry_count: int = 0
79
+
80
+
81
@dataclass
class StreamingExecutionResult:
    """One chunk of results emitted while a batch is being streamed."""

    # Results completed since the previous chunk.
    completed_tasks: List[ExecutionResult]
    # Progress tracker of the running batch, or None when progress is off.
    progress: Optional[ProgressTracker]
    # Whether the producer marked this chunk as the final one.
    is_final: bool
88
+
89
+
90
class RateLimiter:
    """Rate limiter that spaces calls at least ``1 / max_qps`` seconds apart.

    With ``max_qps`` unset (``None`` or ``0``) the limiter is a no-op.
    """

    def __init__(self, max_qps: Optional[float] = None):
        """
        Args:
            max_qps: Maximum number of acquisitions per second; falsy disables
                rate limiting.
        """
        self.max_qps = max_qps
        self.min_interval = 1 / max_qps if max_qps else 0
        # Timestamp of the most recently reserved slot (may lie in the future
        # while callers are still waiting for their turn).
        self.last_request_time = 0

    async def acquire(self):
        """Wait until the caller is allowed to issue its next request.

        The slot is reserved *before* awaiting: there is no suspension point
        between reading and writing ``last_request_time``, so coroutines that
        call ``acquire`` concurrently each get a distinct slot. The wait uses
        ``asyncio.sleep`` so the event loop keeps running other tasks; a
        blocking ``time.sleep`` here would freeze every in-flight request.
        """
        if not self.max_qps:
            return

        now = time.time()
        slot = max(now, self.last_request_time + self.min_interval)
        self.last_request_time = slot

        delay = slot - now
        if delay > 0:
            await asyncio.sleep(delay)
107
+
108
+
109
class ConcurrentExecutor:
    """
    General-purpose concurrent executor for async callables.

    Schedules an arbitrary async function over a collection of inputs, with:
    - a cap on the number of concurrently running tasks
    - optional QPS limiting
    - progress tracking
    - retries
    - streamed results
    - priority scheduling
    - a custom error handler
    - signature-aware ("intelligent") calling conventions

    Execution methods:
    - execute_batch: batch execution that adapts to the function signature
    - execute_batch_with_adapter: batch execution through a task adapter
    - execute_batch_with_context: batch execution in TaskContext mode
    - execute_batch_with_factory: batch execution through a function factory
    - execute_priority_batch: batch execution ordered by priority
    - aiter_execute_batch: streaming batch execution
    - execute_batch_sync: synchronous wrapper around execute_batch

    Example
    -------

    # Style 1: simple function - receives only the data item
    async def simple_task(data):
        await asyncio.sleep(0.1)
        return f"processed: {data}"

    # Style 2: context function - receives the full TaskContext
    async def context_task(context: TaskContext):
        return f"task_{context.task_id}: {context.data} (retry: {context.retry_count})"

    # Style 3: generic function - accepts arbitrary keyword options
    async def flexible_task(item, **options):
        return f"processed {item} with {options}"

    # Create the executor
    executor = ConcurrentExecutor(
        concurrency_limit=5,
        max_qps=10,
        retry_times=3
    )

    # Batch execution - the calling convention is picked from the signature
    results, _ = await executor.execute_batch(
        async_func=simple_task,  # simple function
        tasks_data=["task1", "task2", "task3"]
    )

    results, _ = await executor.execute_batch(
        async_func=flexible_task,  # generic function
        tasks_data=["item1", "item2"],
        executor_kwargs={"user_id": 123, "mode": "fast"}
    )

    # Task adapter (for functions with a more complex signature)
    results, _ = await executor.execute_batch_with_adapter(
        async_func=complex_function,
        tasks_data=complex_data,
        task_adapter=my_adapter_function
    )

    # Context mode (the function sees the full execution information)
    results, _ = await executor.execute_batch_with_context(
        async_func=context_task,
        tasks_data=["data1", "data2"]
    )
    """

    def __init__(
        self,
        concurrency_limit: int,
        max_qps: Optional[float] = None,
        retry_times: int = 3,
        retry_delay: float = 0.3,
        error_handler: Optional[Callable[[Exception, Any, int], bool]] = None,
    ):
        """
        Args:
            concurrency_limit: Maximum number of tasks running at once.
            max_qps: Optional cap on task starts per second.
            retry_times: Number of retries after the first failed attempt.
            retry_delay: Seconds to sleep between attempts.
            error_handler: Optional callback ``(error, task_data, retry_count)
                -> should_retry``; returning a falsy value stops retrying.
        """
        self._concurrency_limit = concurrency_limit
        self._rate_limiter = RateLimiter(max_qps)
        self._semaphore = asyncio.Semaphore(concurrency_limit)
        self.retry_times = retry_times
        self.retry_delay = retry_delay
        self.error_handler = error_handler  # custom error handling callback

    # ------------------------------------------------------------------ #
    # Internal helpers
    # ------------------------------------------------------------------ #

    def _prepare_progress(
        self,
        tasks_data: Iterable[Any],
        total_tasks: Optional[int],
        show_progress: bool,
    ) -> Tuple[Iterable[Any], Optional[ProgressTracker]]:
        """Return ``(tasks_data, progress)`` for a batch run.

        When progress is requested and ``total_tasks`` is unknown, the items
        are counted: directly for lists/tuples, otherwise through
        ``itertools.tee`` (which consumes and buffers the iterable). The
        returned ``tasks_data`` must be used in place of the argument because
        the original iterator may have been consumed by the counting pass.
        """
        if not show_progress:
            return tasks_data, None

        if total_tasks is None:
            if isinstance(tasks_data, (list, tuple)):
                total_tasks = len(tasks_data)
            else:
                tasks_data, data_for_counting = itertools.tee(tasks_data)
                total_tasks = sum(1 for _ in data_for_counting)

        progress = ProgressTracker(
            total_tasks,
            concurrency=self._concurrency_limit,
            config=ProgressBarConfig(),
        )
        return tasks_data, progress

    @staticmethod
    def _to_request_result(result: ExecutionResult) -> RequestResult:
        """Convert an ExecutionResult into the RequestResult ProgressTracker takes.

        For failed tasks the payload is replaced by a dict carrying the
        exception class name and message.
        """
        payload = result.data
        if result.status == "error" and result.error:
            payload = {
                "error": result.error.__class__.__name__,
                "detail": str(result.error),
            }
        return RequestResult(
            request_id=result.task_id,
            data=payload,
            status=result.status,
            meta=result.meta,
            latency=result.latency,
        )

    def _inspect_function_signature(self, func: Callable) -> dict:
        """Inspect ``func`` and summarise how it expects to be called.

        Returns:
            A dict with ``param_names``, ``has_context_param``,
            ``accepts_var_kwargs`` and ``param_count`` (number of positional
            parameters).
        """
        params = inspect.signature(func).parameters

        def looks_like_context(param) -> bool:
            # Direct annotation with the TaskContext class.
            if param.annotation == TaskContext:
                return True
            # String (forward-reference) annotation mentioning TaskContext.
            if isinstance(param.annotation, str) and "TaskContext" in param.annotation:
                return True
            # Fallback: a sole parameter literally named "context".
            return param.name == "context" and len(params) == 1

        positional_kinds = (
            inspect.Parameter.POSITIONAL_ONLY,
            inspect.Parameter.POSITIONAL_OR_KEYWORD,
        )
        return {
            "param_names": list(params.keys()),
            "has_context_param": any(looks_like_context(p) for p in params.values()),
            "accepts_var_kwargs": any(
                p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()
            ),
            "param_count": sum(1 for p in params.values() if p.kind in positional_kinds),
        }

    async def _call_function_intelligently(
        self,
        func: Callable,
        task_context: TaskContext,
        executor_kwargs: Optional[dict] = None,
    ) -> Any:
        """Call ``func`` with arguments chosen from its signature.

        * TaskContext-style function -> ``func(task_context)``
        * exactly one positional parameter and no ``**kwargs`` -> ``func(data)``
        * anything else -> ``func(data, **executor_kwargs)``
        """
        sig_info = self._inspect_function_signature(func)
        executor_kwargs = executor_kwargs or {}

        # Context functions get no extra keyword arguments: everything they
        # need is already carried by the TaskContext.
        if sig_info["has_context_param"]:
            return await func(task_context)

        # Simple function that only takes the data item.
        if sig_info["param_count"] == 1 and not sig_info["accepts_var_kwargs"]:
            return await func(task_context.data)

        # Generic function accepting extra arguments.
        return await func(task_context.data, **executor_kwargs)

    async def _execute_single_task(
        self,
        async_func: Callable[..., Awaitable[Any]],
        task_data: Any,
        task_id: int,
        meta: Optional[dict] = None,
        executor_kwargs: Optional[dict] = None,
        task_adapter: Optional[Callable] = None,
        **kwargs,
    ) -> ExecutionResult:
        """Run one task with rate limiting and retries.

        Exceptions raised by the task are caught and reported through an
        ``ExecutionResult`` with ``status="error"``; they are not re-raised.

        Args:
            async_func: The async function to run.
            task_data: The input item.
            task_id: Identifier of the task within the batch.
            meta: Optional metadata copied into the result.
            executor_kwargs: Keyword arguments shared by all tasks.
            task_adapter: Optional ``(data, context) -> (args, kwargs)`` adapter.
            **kwargs: Extra keyword arguments, merged over ``executor_kwargs``.

        Returns:
            The ExecutionResult of the successful attempt, or of the final
            failure.
        """
        retry_count = 0
        last_error = None
        # Start time of the latest attempt; stays None if no attempt got as far
        # as starting the clock.
        start_time = None
        executor_kwargs = {**(executor_kwargs or {}), **kwargs}

        # The semaphore is held across retries (including the retry delay).
        async with self._semaphore:
            while retry_count <= self.retry_times:
                try:
                    await self._rate_limiter.acquire()

                    start_time = time.time()

                    task_context = TaskContext(
                        task_id=task_id,
                        data=task_data,
                        meta=meta,
                        retry_count=retry_count,
                        executor_kwargs=executor_kwargs,
                    )

                    if task_adapter:
                        # The adapter decides the positional / keyword arguments.
                        args, kwargs_from_adapter = task_adapter(task_data, task_context)
                        call_kwargs = {**executor_kwargs, **kwargs_from_adapter}
                        if isinstance(args, (list, tuple)):
                            result = await async_func(*args, **call_kwargs)
                        else:
                            result = await async_func(args, **call_kwargs)
                    else:
                        result = await self._call_function_intelligently(
                            async_func, task_context, executor_kwargs
                        )

                    return ExecutionResult(
                        task_id=task_id,
                        data=result,
                        status="success",
                        meta=meta,
                        latency=time.time() - start_time,
                        retry_count=retry_count,
                    )

                except Exception as e:
                    last_error = e
                    retry_count += 1

                    # Let the custom handler veto further retries.
                    if self.error_handler:
                        if not self.error_handler(e, task_data, retry_count):
                            break

                    if retry_count <= self.retry_times:
                        await asyncio.sleep(self.retry_delay)

        # Every attempt failed (or the error handler stopped the retries).
        return ExecutionResult(
            task_id=task_id,
            data=None,
            status="error",
            meta=meta,
            latency=time.time() - start_time if start_time is not None else 0,
            error=last_error,
            retry_count=retry_count - 1,
        )

    async def _process_with_concurrency_window(
        self,
        async_func: Callable[..., Awaitable[Any]],
        tasks_data: Iterable[Any],
        progress: Optional[ProgressTracker] = None,
        batch_size: int = 1,
        **kwargs,
    ) -> AsyncGenerator[StreamingExecutionResult, Any]:
        """
        Run tasks through a sliding concurrency window and stream the results.

        At most ``concurrency_limit`` tasks are in flight. Completed results
        are grouped into chunks of ``batch_size``; exactly one chunk (the last
        one) is emitted with ``is_final=True``, right after
        ``progress.summary()`` has been called once. Nothing is yielded for an
        empty ``tasks_data``.
        """

        async def handle_completed_tasks(done_tasks, batch, is_final=False):
            """Record finished tasks and yield a chunk whenever one is ready."""
            remaining = len(done_tasks)
            for task in done_tasks:
                remaining -= 1
                result = await task
                if progress:
                    progress.update(self._to_request_result(result))
                batch.append(result)

                # Only the very last result of the final drain may flush a
                # partial chunk and carry the final flag; otherwise the summary
                # would be printed once per remaining task.
                closes_stream = is_final and remaining == 0
                if len(batch) >= batch_size or closes_stream:
                    if closes_stream and progress:
                        progress.summary()
                    yield StreamingExecutionResult(
                        completed_tasks=sorted(batch, key=lambda x: x.task_id),
                        progress=progress,
                        is_final=closes_stream,
                    )
                    batch.clear()

        task_id = 0
        active_tasks = set()
        completed_batch = []

        try:
            for data in tasks_data:
                # Window full: wait for at least one task before adding more.
                if len(active_tasks) >= self._concurrency_limit:
                    done, active_tasks = await asyncio.wait(
                        active_tasks, return_when=asyncio.FIRST_COMPLETED
                    )
                    async for result in handle_completed_tasks(done, completed_batch):
                        yield result

                active_tasks.add(
                    asyncio.create_task(
                        self._execute_single_task(async_func, data, task_id, **kwargs)
                    )
                )
                task_id += 1

            # Drain whatever is still running.
            if active_tasks:
                done, _ = await asyncio.wait(active_tasks)
                async for result in handle_completed_tasks(
                    done, completed_batch, is_final=True
                ):
                    yield result
        finally:
            # If the stream is abandoned or cancelled midway, do not leave
            # orphaned tasks running. After a normal run all tasks are done and
            # this loop does nothing.
            for task in active_tasks:
                if not task.done():
                    task.cancel()

    # ------------------------------------------------------------------ #
    # Public execution API
    # ------------------------------------------------------------------ #

    async def execute_batch(
        self,
        async_func: Callable[..., Awaitable[Any]],
        tasks_data: Iterable[Any],
        total_tasks: Optional[int] = None,
        show_progress: bool = True,
        **kwargs,
    ) -> Tuple[List[ExecutionResult], Optional[ProgressTracker]]:
        """
        Execute a batch of async tasks.

        Args:
            async_func: The async function to run; it is called according to
                its signature (TaskContext function, single-argument function,
                or ``func(data, **executor_kwargs)``).
            tasks_data: The input items.
            total_tasks: Total number of tasks; counted automatically when
                omitted and progress is shown.
            show_progress: Whether to track/display progress.
            **kwargs: Extra arguments forwarded to each task execution.

        Returns:
            (results sorted by task id, progress tracker or None)
        """
        tasks_data, progress = self._prepare_progress(
            tasks_data, total_tasks, show_progress
        )

        results = []
        async for chunk in self._process_with_concurrency_window(
            async_func=async_func, tasks_data=tasks_data, progress=progress, **kwargs
        ):
            results.extend(chunk.completed_tasks)

        results.sort(key=lambda x: x.task_id)
        return results, progress

    async def _stream_execute(
        self,
        queue: Queue,
        async_func: Callable[..., Awaitable[Any]],
        tasks_data: Iterable[Any],
        total_tasks: Optional[int] = None,
        show_progress: bool = True,
        batch_size: Optional[int] = None,
        **kwargs,
    ):
        """Run the batch and push every result chunk onto ``queue``.

        A ``None`` sentinel is always queued at the end - also when the run
        fails or is cancelled - so the consumer can never block forever.
        """
        if batch_size is None:
            batch_size = self._concurrency_limit

        try:
            tasks_data, progress = self._prepare_progress(
                tasks_data, total_tasks, show_progress
            )
            async for chunk in self._process_with_concurrency_window(
                async_func=async_func,
                tasks_data=tasks_data,
                progress=progress,
                batch_size=batch_size,
                **kwargs,
            ):
                await queue.put(chunk)
        finally:
            # The queue handed in by aiter_execute_batch is unbounded, so the
            # non-blocking put cannot fail there.
            queue.put_nowait(None)

    async def aiter_execute_batch(
        self,
        async_func: Callable[..., Awaitable[Any]],
        tasks_data: Iterable[Any],
        total_tasks: Optional[int] = None,
        show_progress: bool = True,
        batch_size: Optional[int] = None,
        **kwargs,
    ) -> AsyncIterator[StreamingExecutionResult]:
        """
        Execute a batch of async tasks and stream the results.

        Args:
            async_func: The async function to run.
            tasks_data: The input items.
            total_tasks: Total number of tasks.
            show_progress: Whether to track/display progress.
            batch_size: Number of results per yielded chunk (defaults to the
                concurrency limit).
            **kwargs: Extra arguments forwarded to each task execution.

        Yields:
            StreamingExecutionResult: a chunk of completed task results.

        Raises:
            Exception: whatever the background producer raised (for example
                while iterating ``tasks_data``), re-raised after the chunks
                produced so far have been yielded.
        """
        queue = Queue()
        producer = asyncio.create_task(
            self._stream_execute(
                queue=queue,
                async_func=async_func,
                tasks_data=tasks_data,
                total_tasks=total_tasks,
                show_progress=show_progress,
                batch_size=batch_size,
                **kwargs,
            )
        )

        try:
            while True:
                chunk = await queue.get()
                if chunk is None:
                    break
                yield chunk
            # Surface a producer failure instead of ending silently.
            await producer
        finally:
            if not producer.done():
                producer.cancel()

    def execute_batch_sync(
        self,
        async_func: Callable[..., Awaitable[Any]],
        tasks_data: Iterable[Any],
        total_tasks: Optional[int] = None,
        show_progress: bool = True,
        **kwargs,
    ) -> Tuple[List[ExecutionResult], Optional[ProgressTracker]]:
        """Synchronous version of ``execute_batch``.

        Without a running event loop the batch runs through ``asyncio.run`` in
        the calling thread. When called from inside a running loop, the batch
        runs on a fresh loop in a worker thread and the calling thread blocks
        until it finishes.
        """

        def run_batch():
            return asyncio.run(
                self.execute_batch(
                    async_func=async_func,
                    tasks_data=tasks_data,
                    total_tasks=total_tasks,
                    show_progress=show_progress,
                    **kwargs,
                )
            )

        # Keep the try body minimal: only the loop probe may be interpreted as
        # "no running loop". A RuntimeError raised by the batch itself must
        # propagate instead of triggering a second run.
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            return run_batch()

        import concurrent.futures

        with concurrent.futures.ThreadPoolExecutor(max_workers=1) as pool:
            return pool.submit(run_batch).result()

    async def execute_priority_batch(
        self,
        async_func: Callable[..., Awaitable[Any]],
        priority_tasks: List[TaskItem],
        show_progress: bool = True,
        **kwargs,
    ) -> Tuple[List[ExecutionResult], Optional[ProgressTracker]]:
        """
        Execute tasks in priority order.

        Args:
            async_func: The async function to run.
            priority_tasks: Tasks with priorities (smaller number runs first;
                equal priorities start in the order given).
            show_progress: Whether to track/display progress.
            **kwargs: Extra arguments forwarded to each task execution.

        Returns:
            (results sorted by task id, progress tracker or None)
        """
        # (priority, sequence, item) entries: the sequence number breaks ties
        # deterministically and keeps the items themselves out of comparisons.
        task_queue = [
            (item.priority, seq, item) for seq, item in enumerate(priority_tasks)
        ]
        heapq.heapify(task_queue)

        progress = None
        if show_progress:
            progress = ProgressTracker(
                len(priority_tasks),
                concurrency=self._concurrency_limit,
                config=ProgressBarConfig(),
            )

        results = []
        active_tasks = set()

        while task_queue or active_tasks:
            # Start new tasks until the concurrency limit is reached.
            while task_queue and len(active_tasks) < self._concurrency_limit:
                _, _, task_item = heapq.heappop(task_queue)
                active_tasks.add(
                    asyncio.create_task(
                        self._execute_single_task(
                            async_func=async_func,
                            task_data=task_item.data,
                            task_id=task_item.task_id,
                            meta=task_item.meta,
                            **kwargs,
                        )
                    )
                )

            # Wait for at least one task to finish.
            if active_tasks:
                done, active_tasks = await asyncio.wait(
                    active_tasks, return_when=asyncio.FIRST_COMPLETED
                )

                for task in done:
                    result = await task
                    results.append(result)
                    if progress:
                        # Same conversion as the sliding-window path, so failed
                        # tasks report an error payload rather than None.
                        progress.update(self._to_request_result(result))

        if progress:
            progress.summary()

        results.sort(key=lambda x: x.task_id)
        return results, progress

    def add_custom_error_handler(self, handler: Callable[[Exception, Any, int], bool]):
        """
        Install a custom error handler (replaces any previous one).

        Args:
            handler: Callback ``(error, task_data, retry_count) -> should_retry``.
        """
        self.error_handler = handler

    # === Additional, more flexible execution methods ===

    async def execute_batch_with_adapter(
        self,
        async_func: Callable[..., Awaitable[Any]],
        tasks_data: Iterable[Any],
        task_adapter: Callable[[Any, TaskContext], Tuple[Any, Dict]],
        executor_kwargs: Optional[dict] = None,
        total_tasks: Optional[int] = None,
        show_progress: bool = True,
        **kwargs,
    ) -> Tuple[List[ExecutionResult], Optional[ProgressTracker]]:
        """
        Batch execution through a task adapter.

        Args:
            async_func: The async function to run.
            tasks_data: The input items.
            task_adapter: Adapter ``(data, context) -> (args, kwargs)``; it
                returns a (positional arguments, keyword arguments) tuple.
            executor_kwargs: Keyword arguments shared by all tasks.
            total_tasks: Total number of tasks.
            show_progress: Whether to track/display progress.
            **kwargs: Other arguments.

        Returns:
            (results sorted by task id, progress tracker or None)

        Example:
            # Define the adapter
            def my_adapter(data, context):
                # Return positional and keyword arguments
                return (data['item'],), {'user_id': data['user_id'], 'batch_id': context.task_id}

            # Run
            results, _ = await executor.execute_batch_with_adapter(
                async_func=my_custom_function,
                tasks_data=[{'item': 'a', 'user_id': 1}, {'item': 'b', 'user_id': 2}],
                task_adapter=my_adapter,
                executor_kwargs={'mode': 'fast'}
            )
        """
        return await self.execute_batch(
            async_func=async_func,
            tasks_data=tasks_data,
            total_tasks=total_tasks,
            show_progress=show_progress,
            executor_kwargs=executor_kwargs,
            task_adapter=task_adapter,
            **kwargs,
        )

    async def execute_batch_with_context(
        self,
        async_func: Callable[[TaskContext], Awaitable[Any]],
        tasks_data: Iterable[Any],
        executor_kwargs: Optional[dict] = None,
        total_tasks: Optional[int] = None,
        show_progress: bool = True,
        **kwargs,
    ) -> Tuple[List[ExecutionResult], Optional[ProgressTracker]]:
        """
        Batch execution in context mode: the function receives a TaskContext.

        Args:
            async_func: The async function to run, with signature
                ``async def func(context: TaskContext) -> Any``.
            tasks_data: The input items.
            executor_kwargs: Extra arguments exposed through the TaskContext.
            total_tasks: Total number of tasks.
            show_progress: Whether to track/display progress.
            **kwargs: Other arguments.

        Returns:
            (results sorted by task id, progress tracker or None)

        Example:
            async def context_task(context: TaskContext):
                print(f"task {context.task_id}: {context.data}")
                print(f"retries: {context.retry_count}")
                print(f"extra arguments: {context.executor_kwargs}")
                return f"result: {context.data}"

            results, _ = await executor.execute_batch_with_context(
                async_func=context_task,
                tasks_data=["data1", "data2", "data3"],
                executor_kwargs={'user_id': 123}
            )
        """
        return await self.execute_batch(
            async_func=async_func,
            tasks_data=tasks_data,
            total_tasks=total_tasks,
            show_progress=show_progress,
            executor_kwargs=executor_kwargs,
            **kwargs,
        )

    async def execute_batch_with_factory(
        self,
        func_factory: Callable[[TaskContext], Callable[..., Awaitable[Any]]],
        tasks_data: Iterable[Any],
        total_tasks: Optional[int] = None,
        show_progress: bool = True,
        **kwargs,
    ) -> Tuple[List[ExecutionResult], Optional[ProgressTracker]]:
        """
        Batch execution through a function factory, which can choose a
        different function for every task.

        Args:
            func_factory: Factory returning the function to run for a context;
                the chosen function is called with ``context.data`` only.
            tasks_data: The input items.
            total_tasks: Total number of tasks.
            show_progress: Whether to track/display progress.
            **kwargs: Other arguments.

        Returns:
            (results sorted by task id, progress tracker or None)

        Example:
            def task_factory(context: TaskContext):
                if context.task_id % 2 == 0:
                    return slow_processor
                else:
                    return fast_processor

            results, _ = await executor.execute_batch_with_factory(
                func_factory=task_factory,
                tasks_data=["data1", "data2", "data3"]
            )
        """

        async def factory_wrapper(context: TaskContext):
            actual_func = func_factory(context)
            return await actual_func(context.data)

        return await self.execute_batch_with_context(
            async_func=factory_wrapper,
            tasks_data=tasks_data,
            total_tasks=total_tasks,
            show_progress=show_progress,
            **kwargs,
        )
826
+
827
+
828
# ============== Checkpointing / resumable execution ==============

# Default directory for checkpoint data; "~" is expanded once, at import time.
DEFAULT_CHECKPOINT_DIR = os.path.expanduser("~/.cache/maque/checkpoints")
831
+
832
+
833
@dataclass
class CheckpointConfig:
    """
    Checkpoint settings.

    Attributes:
        enabled: Whether checkpointing is turned on.
        checkpoint_dir: Directory in which checkpoint data is stored.
        checkpoint_interval: Save a checkpoint after every N completed tasks.
    """

    enabled: bool = False
    checkpoint_dir: str = DEFAULT_CHECKPOINT_DIR
    checkpoint_interval: int = 100

    @classmethod
    def disabled(cls) -> "CheckpointConfig":
        """Build a configuration with checkpointing switched off."""
        return cls(enabled=False)

    @classmethod
    def default(cls) -> "CheckpointConfig":
        """Build a configuration with checkpointing switched on and default values."""
        return cls(enabled=True)
854
+
855
+
856
class CheckpointManager:
    """
    Checkpoint manager.

    Saves and restores progress of large batch runs, using FlaxKV2 as the
    storage backend. With a disabled configuration every operation is a
    no-op and FlaxKV2 is not required.

    Keys written per checkpoint:
        ``<id>:meta``              name, created_at, completed_count, last_updated
        ``<id>:result:<task_id>``  one saved task result
        ``<id>:pending``           list of (task_id, data) still to be processed

    Example:
        # Create a checkpoint for a run
        checkpoint = CheckpointManager(CheckpointConfig.default())
        checkpoint_id = checkpoint.create("my_batch_task")

        # Save progress while running
        for i, result in enumerate(results):
            checkpoint.save_result(checkpoint_id, i, result)

        # Resume after an interruption
        completed, pending_indices = checkpoint.load(checkpoint_id, total_tasks=1000)
    """

    def __init__(self, config: Optional[CheckpointConfig] = None):
        """
        Args:
            config: Checkpoint settings; defaults to a disabled configuration.

        Raises:
            ImportError: If checkpointing is enabled but flaxkv2 is not
                installed.
        """
        self.config = config or CheckpointConfig.disabled()
        self._db: Optional["FlaxKV"] = None

        # A disabled manager never touches the backend, so the optional
        # dependency is only demanded when it is actually needed.
        if not self.config.enabled:
            return

        if not FLAXKV_AVAILABLE:
            raise ImportError("检查点功能需要安装 flaxkv2: pip install flaxkv2")

        self._db = FlaxKV(
            "checkpoints",
            self.config.checkpoint_dir,
            write_buffer_size=50,
            async_flush=True,
        )

    def create(self, name: str = "") -> str:
        """
        Create a new checkpoint.

        Args:
            name: Optional checkpoint name, used as a prefix of the id.

        Returns:
            The checkpoint id (``<name>_<8 hex chars>`` or just the hex part).
        """
        token = uuid.uuid4().hex[:8]
        checkpoint_id = f"{name}_{token}" if name else token

        if self._db is not None:
            self._db[f"{checkpoint_id}:meta"] = {
                "name": name,
                "created_at": time.time(),
                "completed_count": 0,
            }

        return checkpoint_id

    def save_result(
        self,
        checkpoint_id: str,
        task_id: int,
        result: ExecutionResult,
    ) -> None:
        """Save the result of a single task and bump the completed counter.

        The ``error`` exception object of the result is not persisted.
        """
        if self._db is None:
            return

        # Only the serialisable part of the result is stored.
        self._db[f"{checkpoint_id}:result:{task_id}"] = {
            "task_id": result.task_id,
            "data": result.data,
            "status": result.status,
            "meta": result.meta,
            "latency": result.latency,
            "retry_count": result.retry_count,
        }

        # Update the bookkeeping entry.
        meta_key = f"{checkpoint_id}:meta"
        meta = self._db.get(meta_key, {})
        meta["completed_count"] = meta.get("completed_count", 0) + 1
        meta["last_updated"] = time.time()
        self._db[meta_key] = meta

    def save_pending(
        self,
        checkpoint_id: str,
        pending_data: List[Tuple[int, Any]],
    ) -> None:
        """
        Save the data of tasks that are still to be processed.

        Args:
            checkpoint_id: Checkpoint id.
            pending_data: List of (task_id, data) tuples.
        """
        if self._db is None:
            return
        self._db[f"{checkpoint_id}:pending"] = pending_data

    def load(
        self,
        checkpoint_id: str,
        total_tasks: int,
    ) -> Tuple[List[ExecutionResult], List[int]]:
        """
        Load a checkpoint.

        Args:
            checkpoint_id: Checkpoint id.
            total_tasks: Total number of tasks in the run.

        Returns:
            (completed results sorted by task id, indices still to process)
        """
        if self._db is None:
            return [], list(range(total_tasks))

        result_prefix = f"{checkpoint_id}:result:"
        completed = []
        completed_ids = set()

        # Collect every stored result of this checkpoint.
        for key in self._db.keys():
            if not key.startswith(result_prefix):
                continue
            data = self._db[key]
            completed.append(
                ExecutionResult(
                    task_id=data["task_id"],
                    data=data["data"],
                    status=data["status"],
                    meta=data.get("meta"),
                    latency=data.get("latency", 0),
                    retry_count=data.get("retry_count", 0),
                )
            )
            completed_ids.add(data["task_id"])

        # Everything without a stored result is still pending.
        pending_indices = [i for i in range(total_tasks) if i not in completed_ids]

        return sorted(completed, key=lambda x: x.task_id), pending_indices

    def delete(self, checkpoint_id: str) -> None:
        """Delete a checkpoint and every entry stored under it."""
        if self._db is None:
            return

        prefix = f"{checkpoint_id}:"
        # Materialise the key list first: entries are removed while looping.
        keys_to_delete = [key for key in self._db.keys() if key.startswith(prefix)]
        for key in keys_to_delete:
            del self._db[key]

    def list_checkpoints(self) -> List[Dict[str, Any]]:
        """List all checkpoints, newest first.

        Returns:
            One dict per checkpoint: its ``id`` plus the stored meta fields.
        """
        if self._db is None:
            return []

        suffix = ":meta"
        checkpoints = []
        for key in self._db.keys():
            if key.endswith(suffix):
                # Strip only the trailing suffix; str.replace would also mangle
                # ids that happen to contain ":meta" elsewhere.
                checkpoints.append({"id": key[: -len(suffix)], **self._db[key]})

        return sorted(checkpoints, key=lambda x: x.get("created_at", 0), reverse=True)

    def close(self):
        """Close the checkpoint manager and release the storage backend."""
        if self._db is not None:
            self._db.close()
            self._db = None

    def __enter__(self):
        return self

    def __exit__(self, *args):
        self.close()