celestialflow 3.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,1070 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio, time
4
+ from asyncio import Queue as AsyncQueue, QueueEmpty
5
+ from collections import defaultdict
6
+ from collections.abc import Iterable
7
+ from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
8
+ from multiprocessing import Queue as MPQueue
9
+ from queue import Queue as ThreadQueue, Empty
10
+ from threading import Event, Lock
11
+ from typing import List
12
+
13
+ from httpx import (
14
+ ConnectError,
15
+ ConnectTimeout,
16
+ PoolTimeout,
17
+ ProtocolError,
18
+ ReadError,
19
+ ReadTimeout,
20
+ ProxyError,
21
+ RequestError,
22
+ )
23
+
24
+ from .task_progress import ProgressManager, NullProgress
25
+ from .task_logging import LogListener, TaskLogger
26
+ from .task_types import ValueWrapper, NoOpContext, TerminationSignal, TERMINATION_SIGNAL
27
+ from .task_tools import make_hashable, format_repr, object_to_str_hash
28
+
29
+
30
+ class TaskManager:
31
def __init__(
    self,
    func,
    execution_mode="serial",
    worker_limit=50,
    max_retries=3,
    max_info=50,
    unpack_task_args=False,
    enable_result_cache=False,
    progress_desc="Processing",
    show_progress=False,
):
    """
    Build a task manager around a single callable.

    :param func: callable applied to every task
    :param execution_mode: one of 'serial', 'thread', 'process', 'async'
    :param worker_limit: number of tasks handled concurrently
    :param max_retries: maximum number of retries for one task
    :param max_info: length limit passed to the task/result formatter for logs
    :param unpack_task_args: when true, tuple tasks are splatted into ``func``
    :param enable_result_cache: when true, results and errors are kept in dicts
    :param progress_desc: label shown on the progress bar
    :param show_progress: whether a progress bar is displayed
    """
    self.func = func
    self.execution_mode = execution_mode
    self.worker_limit = worker_limit
    self.max_retries = max_retries
    self.max_info = max_info
    self.unpack_task_args = unpack_task_args
    self.enable_result_cache = enable_result_cache

    self.progress_desc = progress_desc
    self.show_progress = show_progress

    # Executors are created lazily by init_pool().
    self.thread_pool = None
    self.process_pool = None

    self.current_index = 0  # queue index where round-robin polling starts
    self.terminated_queue_set = set()

    self.prev_stages: List[TaskManager] = []
    # Graph attributes get safe defaults so that get_stage_summary() also
    # works on a manager that was never wired through set_graph_context().
    self.next_stages: List[TaskManager] = []
    self.stage_mode = "serial"
    self.set_stage_name(None)

    # Exception types that trigger a retry instead of a failure.
    self.retry_exceptions = (
        ConnectTimeout,
        ProtocolError,
        ReadError,
        ConnectError,
        PoolTimeout,
        ReadTimeout,
        ProxyError,
    )

    self.init_counter()
87
+
88
def init_counter(
    self,
    task_counter=None,
    success_counter=None,
    error_counter=None,
    duplicate_counter=None,
    counter_lock=None,
    extra_stats=None,
):
    """
    Install the counters, the lock guarding them and the extra-stats mapping.

    Every argument left as ``None`` is replaced by a fresh default: a
    ``ValueWrapper`` for counters, a ``NoOpContext`` for the lock and an
    empty dict for ``extra_stats``.
    """
    if task_counter is None:
        task_counter = ValueWrapper()
    self.task_counter = task_counter

    if success_counter is None:
        success_counter = ValueWrapper()
    self.success_counter = success_counter

    if error_counter is None:
        error_counter = ValueWrapper()
    self.error_counter = error_counter

    if duplicate_counter is None:
        duplicate_counter = ValueWrapper()
    self.duplicate_counter = duplicate_counter

    if counter_lock is None:
        counter_lock = NoOpContext()
    self.counter_lock = counter_lock

    if extra_stats is None:
        extra_stats = {}
    self.extra_stats = extra_stats
114
+
115
def init_env(
    self, task_queues=None, result_queues=None, fail_queue=None, logger_queue=None
):
    """
    Prepare everything a run needs: queues, state, executor pool and logger.

    The four queue arguments are forwarded unchanged to ``init_queue``.
    """
    self.init_queue(task_queues, result_queues, fail_queue, logger_queue)
    # The remaining initialisers take no arguments; run them in order.
    for prepare in (self.init_state, self.init_pool, self.init_logger):
        prepare()
125
+
126
def init_queue(
    self, task_queues=None, result_queues=None, fail_queue=None, logger_queue=None
):
    """
    Install the task, result, fail and logger queues.

    Any queue argument that is missing (or falsy) is replaced by a fresh
    queue whose type depends on ``self.execution_mode``.

    :param task_queues: list of input queues, or None for one new queue
    :param result_queues: list of output queues, or None for one new queue
    :param fail_queue: queue receiving failure records, or None for a new one
    :param logger_queue: queue used by the logger, or None for a new ThreadQueue
    """
    queue_map = {
        "process": ThreadQueue,  # MPqueue
        "async": AsyncQueue,
        "thread": ThreadQueue,
        "serial": ThreadQueue,
    }
    # An unrecognised mode is executed serially elsewhere in this class, so
    # give it the serial queue type instead of failing with KeyError.
    queue_factory = queue_map.get(self.execution_mode, ThreadQueue)

    # The task/result/fail queues only live inside the node's own process;
    # unless nodes in different processes must talk to each other, a
    # ThreadQueue is sufficient for all of them.
    self.task_queues: List[ThreadQueue | MPQueue | AsyncQueue] = task_queues or [
        queue_factory()
    ]
    self.result_queues: List[ThreadQueue | MPQueue | AsyncQueue] = (
        result_queues or [queue_factory()]
    )
    self.fail_queue: ThreadQueue | MPQueue | AsyncQueue = (
        fail_queue or queue_factory()
    )
    self.logger_queue: ThreadQueue | MPQueue = logger_queue or ThreadQueue()
150
+
151
def init_state(self):
    """
    Reset the per-run bookkeeping containers.

    - success_dict / error_dict: cached results and errors
    - retry_time_dict: task_id -> number of retries so far
    - processed_set: task ids already seen, used for duplicate detection
    """
    self.success_dict, self.error_dict = {}, {}
    self.retry_time_dict = dict()
    self.processed_set = set()
163
+
164
def init_pool(self):
    """
    Create the executor matching the execution mode, if not already present.

    An existing pool is kept so it can be reused across runs; the serial and
    async modes create nothing.
    """
    mode = self.execution_mode
    if mode == "thread":
        if self.thread_pool is None:
            self.thread_pool = ThreadPoolExecutor(max_workers=self.worker_limit)
    elif mode == "process":
        if self.process_pool is None:
            self.process_pool = ProcessPoolExecutor(max_workers=self.worker_limit)
173
+
174
def init_logger(self):
    """
    Create the task logger bound to the current logger queue.
    """
    log_queue = self.logger_queue
    self.task_logger = TaskLogger(log_queue)
179
+
180
def init_listener(self):
    """
    Create a log listener at level "INFO", store it and start it.
    """
    listener = LogListener("INFO")
    self.log_listener = listener
    listener.start()
186
+
187
def init_progress(self):
    """
    Create the progress reporter.

    With ``show_progress`` off a ``NullProgress`` is installed; otherwise a
    ``ProgressManager`` whose description carries the execution mode (and the
    worker limit for non-serial modes).
    """
    if self.show_progress:
        mode = self.execution_mode
        extra_desc = "serial" if mode == "serial" else f"{mode}-{self.worker_limit}"
        progress_mode = "async" if mode == "async" else "normal"
        self.progress_manager = ProgressManager(
            total_tasks=0,
            desc=f"{self.progress_desc}({extra_desc})",
            mode=progress_mode,
        )
    else:
        self.progress_manager = NullProgress()
207
+
208
def set_execution_mode(self, execution_mode):
    """
    Set the execution mode.

    :param execution_mode: 'thread', 'process', 'async' or 'serial';
        any other value falls back to 'serial'
    """
    supported = ("thread", "process", "async", "serial")
    if execution_mode not in supported:
        execution_mode = "serial"
    self.execution_mode = execution_mode
218
+
219
def set_graph_context(
    self,
    next_stages: List[TaskManager] = None,
    stage_mode: str = None,
    stage_name: str = None,
):
    """
    Configure this manager as a node of a graph.

    :param next_stages: downstream nodes
    :param stage_mode: how this node runs inside the graph,
        'serial' or 'process'
    :param stage_name: name of this node
    """
    # Each setter applies its own defaulting rules.
    self.set_next_stages(next_stages)
    self.set_stage_mode(stage_mode)
    self.set_stage_name(stage_name)
234
+
235
def set_next_stages(self, next_stages: List[TaskManager]):
    """
    Store the downstream nodes and register this node as their predecessor.

    A missing or empty argument results in an empty list.
    """
    self.next_stages = next_stages if next_stages else []
    for successor in self.next_stages:
        successor.add_prev_stages(self)
242
+
243
def set_stage_mode(self, stage_mode: str):
    """
    Set how this node runs inside a graph: 'process' (parallel) or,
    for any other value, 'serial'.
    """
    if stage_mode != "process":
        stage_mode = "serial"
    self.stage_mode = stage_mode
248
+
249
def set_stage_name(self, name: str):
    """
    Set the node name; a missing or empty name falls back to ``id(self)``.
    """
    self.stage_name = name if name else id(self)
254
+
255
def add_prev_stages(self, prev_stage: TaskManager):
    """
    Register an upstream node, ignoring one that is already registered.
    """
    if prev_stage not in self.prev_stages:
        self.prev_stages.append(prev_stage)
262
+
263
def get_stage_tag(self):
    """
    Return the node label used in a graph: ``<stage_name>[<func name>]``.
    """
    func_name = self.func.__name__
    return f"{self.stage_name}[{func_name}]"
268
+
269
def get_stage_summary(self) -> dict:
    """
    Return a snapshot describing this node.

    The dict holds the stage mode, the execution mode (suffixed with the
    worker limit unless serial), the function name and the class name.
    """
    summary = {"stage_mode": self.stage_mode}

    mode_label = self.execution_mode
    if mode_label != "serial":
        mode_label = f"{mode_label}-{self.worker_limit}"
    summary["execution_mode"] = mode_label

    summary["func_name"] = self.func.__name__
    summary["class_name"] = self.__class__.__name__
    return summary
283
+
284
def add_retry_exceptions(self, *exceptions):
    """
    Append exception types to the tuple of exceptions that trigger a retry.
    """
    extended = self.retry_exceptions + tuple(exceptions)
    self.retry_exceptions = extended
289
+
290
def get_task_queues(self, poll_interval: float = 0.01):
    """
    Fetch the next task, polling all task queues in round-robin order.

    :param poll_interval: seconds to sleep after a full pass found nothing
    :return: the task, or TERMINATION_SIGNAL once every queue has terminated
    """
    queue_count = len(self.task_queues)

    if queue_count == 1:
        # A single queue can simply block; no polling or sleeping needed.
        item = self.task_queues[0].get()
        if not isinstance(item, TerminationSignal):
            return item
        self.terminated_queue_set.add(0)
        self.task_logger._log("TRACE", "get_task_queues: queue[0] terminated")
        return TERMINATION_SIGNAL

    while True:
        for offset in range(queue_count):
            # Start from current_index so queues are served in rotation.
            idx = (self.current_index + offset) % queue_count
            if idx in self.terminated_queue_set:
                continue
            source = self.task_queues[idx]
            try:
                item = source.get_nowait()
                if not isinstance(item, TerminationSignal):
                    # The next call starts at the queue after this one.
                    self.current_index = (idx + 1) % queue_count
                    return item
                self.terminated_queue_set.add(idx)
                self.task_logger._log(
                    "TRACE", f"get_task_queues: queue[{idx}] terminated"
                )
            except Empty:
                pass
            except Exception as e:
                self.task_logger._log(
                    "WARNING",
                    f"get_task_queues: Error from queue[{idx}]: {type(e).__name__}({e})",
                )

        # Every queue has delivered its termination signal.
        if len(self.terminated_queue_set) == queue_count:
            return TERMINATION_SIGNAL

        # Nothing available right now; sleep to avoid a busy-wait.
        time.sleep(poll_interval)
342
+
343
async def get_task_queues_async(self, poll_interval=0.01):
    """
    Asynchronously fetch the next task, polling all task queues in rotation.

    :param poll_interval: seconds to sleep after a full pass found nothing
    :return: the task, or TERMINATION_SIGNAL once every queue has terminated
    """
    queue_count = len(self.task_queues)

    if queue_count == 1:
        # A single queue can be awaited directly.
        task = await self.task_queues[0].get()
        if not isinstance(task, TerminationSignal):
            return task
        self.terminated_queue_set.add(0)
        self.task_logger._log(
            "TRACE", "get_task_queues_async: queue[0] terminated"
        )
        return TERMINATION_SIGNAL

    while True:
        for offset in range(queue_count):
            idx = (self.current_index + offset) % queue_count
            if idx in self.terminated_queue_set:
                continue
            source = self.task_queues[idx]
            try:
                task = source.get_nowait()
                if not isinstance(task, TerminationSignal):
                    # The next call starts at the queue after this one.
                    self.current_index = (idx + 1) % queue_count
                    return task
                self.terminated_queue_set.add(idx)
                self.task_logger._log(
                    "TRACE", f"get_task_queues_async: queue[{idx}] terminated"
                )
            except QueueEmpty:
                pass
            except Exception as e:
                self.task_logger._log(
                    "WARNING",
                    f"get_task_queues_async: queue[{idx}] error: {type(e).__name__}({e})",
                )

        # Every queue has delivered its termination signal.
        if len(self.terminated_queue_set) == queue_count:
            return TERMINATION_SIGNAL

        await asyncio.sleep(poll_interval)
393
+
394
def put_task_queues(self, task_source):
    """
    Put every task from ``task_source`` into the first task queue.

    Each item is passed through ``make_hashable`` and counted via
    ``update_task_counter``. The progress total is raised in batches of 100
    and once more at the end for the remainder.

    :param task_source: iterable of tasks
    """
    batch_size = 100
    # Count what THIS call enqueued instead of reading the absolute task
    # counter: the counter may already be non-zero (reused or shared), which
    # would otherwise inflate the progress total.
    pending = 0
    for item in task_source:
        self.task_queues[0].put(make_hashable(item))
        self.update_task_counter()
        pending += 1
        if pending == batch_size:
            self.progress_manager.add_total(batch_size)
            pending = 0
    self.progress_manager.add_total(pending)
406
+
407
async def put_task_queues_async(self, task_source):
    """
    Put every task from ``task_source`` into the first task queue (async mode).

    Each item is passed through ``make_hashable`` and counted via
    ``update_task_counter``. The progress total is raised in batches of 100
    and once more at the end for the remainder.

    :param task_source: iterable of tasks
    """
    batch_size = 100
    # Count what THIS call enqueued instead of reading the absolute task
    # counter: the counter may already be non-zero (reused or shared), which
    # would otherwise inflate the progress total.
    pending = 0
    for item in task_source:
        await self.task_queues[0].put(make_hashable(item))
        self.update_task_counter()
        pending += 1
        if pending == batch_size:
            self.progress_manager.add_total(batch_size)
            pending = 0
    self.progress_manager.add_total(pending)
419
+
420
def terminate_task_queues(self):
    """
    Put the termination sentinel into every task queue.
    """
    for task_queue in self.task_queues:
        # The sentinel tells the consumer that this queue is finished.
        task_queue.put(TERMINATION_SIGNAL)
426
+
427
async def terminate_task_queues_async(self):
    """
    Put the termination sentinel into every task queue (async mode).
    """
    for task_queue in self.task_queues:
        # The sentinel tells the consumer that this queue is finished.
        await task_queue.put(TERMINATION_SIGNAL)
433
+
434
def put_result_queues(self, result):
    """
    Broadcast ``result`` to every result queue.
    """
    for sink in self.result_queues:
        sink.put(result)
440
+
441
async def put_result_queues_async(self, result):
    """
    Broadcast ``result`` to every result queue (async mode).
    """
    for sink in self.result_queues:
        await sink.put(result)
447
+
448
def put_fail_queue(self, task, error):
    """
    Put a failure record for ``task`` into the fail queue.

    The record holds the stage tag, the task as a string, the error as
    ``TypeName(message)`` and the current timestamp.
    """
    record = {
        "stage_tag": self.get_stage_tag(),
        "task": str(task),
        "error_info": f"{type(error).__name__}({error})",
        "timestamp": time.time(),
    }
    self.fail_queue.put(record)
460
+
461
async def put_fail_queue_async(self, task, error):
    """
    Put a failure record for ``task`` into the fail queue (async version).

    The record holds the stage tag, the task as a string, the error as
    ``TypeName(message)`` and the current timestamp.
    """
    record = {
        "stage_tag": self.get_stage_tag(),
        "task": str(task),
        "error_info": f"{type(error).__name__}({error})",
        "timestamp": time.time(),
    }
    await self.fail_queue.put(record)
473
+
474
def update_task_counter(self):
    """Increment the task counter while holding the counter lock."""
    with self.counter_lock:
        counter = self.task_counter
        counter.value = counter.value + 1
478
+
479
def update_success_counter(self):
    """Increment the success counter while holding the counter lock."""
    with self.counter_lock:
        counter = self.success_counter
        counter.value = counter.value + 1
483
+
484
async def update_success_counter_async(self):
    """Run ``update_success_counter`` in a worker thread and await it."""
    increment = self.update_success_counter
    await asyncio.to_thread(increment)
486
+
487
def update_error_counter(self):
    """Increment the error counter while holding the counter lock."""
    with self.counter_lock:
        counter = self.error_counter
        counter.value = counter.value + 1
491
+
492
def update_duplicate_counter(self):
    """Increment the duplicate counter while holding the counter lock."""
    with self.counter_lock:
        counter = self.duplicate_counter
        counter.value = counter.value + 1
496
+
497
def is_tasks_finished(self) -> bool:
    """
    Tell whether every counted task has been accounted for, i.e. the task
    counter equals successes + errors + duplicates.
    """
    settled = self.success_counter.value
    settled = settled + self.error_counter.value
    settled = settled + self.duplicate_counter.value
    return self.task_counter.value == settled
507
+
508
def is_duplicate(self, task_id):
    """
    Tell whether ``task_id`` is already in the processed set.
    """
    already_seen = task_id in self.processed_set
    return already_seen
513
+
514
def deal_dupliacte(self, task):
    """
    Handle a duplicate task: count it and log it.
    """
    self.update_duplicate_counter()
    func_name = self.func.__name__
    task_info = self.get_task_info(task)
    self.task_logger.task_duplicate(func_name, task_info)
520
+
521
def get_args(self, task):
    """
    Turn a task into the positional-argument tuple passed to ``func``.

    A tuple task is returned as-is when ``unpack_task_args`` is set; in every
    other case the task is wrapped in a one-element tuple.
    """
    should_unpack = self.unpack_task_args and isinstance(task, tuple)
    return task if should_unpack else (task,)
530
+
531
def process_result(self, task, result):
    """
    Post-process a task result before it is cached and forwarded.

    This default implementation returns ``result`` unchanged.
    """
    processed = result
    return processed
538
+
539
def process_result_dict(self):
    """
    Return one dict combining the success and error dicts.

    Error entries are applied last, so they win when a key is in both.
    """
    combined = dict(self.get_success_dict())
    combined.update(self.get_error_dict())
    return combined
549
+
550
def handle_error_dict(self):
    """
    Group failed tasks by their error.

    :return: plain dict mapping each error value to the list of tasks that
        are stored with that error
    """
    grouped = {}
    for task, error in self.get_error_dict().items():
        grouped.setdefault(error, []).append(task)
    return grouped
563
+
564
def get_task_id(self, task):
    """
    Return the identifier of a task, computed by ``object_to_str_hash``.
    """
    task_id = object_to_str_hash(task)
    return task_id
569
+
570
def get_task_info(self, task):
    """
    Return a readable, parenthesised string of the task's arguments.

    With more than three arguments only the first two and the last one are
    shown, separated by "...".
    """
    args = self.get_args(task)

    def render(values):
        # Format each value with the configured length limit.
        return [format_repr(value, self.max_info) for value in values]

    if len(args) > 3:
        parts = [*render(args[:2]), "...", *render([args[-1]])]
    else:
        parts = render(args)

    return "(" + ", ".join(parts) + ")"
589
+
590
def get_result_info(self, result):
    """
    Return the result formatted by ``format_repr`` with the ``max_info`` limit.
    """
    limit = self.max_info
    return format_repr(result, limit)
595
+
596
def process_task_success(self, task, result, start_time):
    """
    Common handling for a task that finished successfully.

    :param task: the finished task
    :param result: raw result returned for the task
    :param start_time: timestamp at which the task started
    """
    outcome = self.process_result(task, result)

    if self.enable_result_cache:
        self.success_dict[task] = outcome

    # The task is done, so its retry bookkeeping is no longer needed.
    self.retry_time_dict.pop(self.get_task_id(task), None)

    self.update_success_counter()
    self.put_result_queues(outcome)

    log_fields = (
        self.func.__name__,
        self.get_task_info(task),
        self.execution_mode,
        self.get_result_info(result),
        time.time() - start_time,
    )
    self.task_logger.task_success(*log_fields)
622
+
623
async def process_task_success_async(self, task, result, start_time):
    """
    Async version of the common handling for a successful task.

    :param task: the finished task
    :param result: raw result returned for the task
    :param start_time: timestamp at which the task started
    """
    outcome = self.process_result(task, result)

    if self.enable_result_cache:
        self.success_dict[task] = outcome

    # The task is done, so its retry bookkeeping is no longer needed.
    self.retry_time_dict.pop(self.get_task_id(task), None)

    await self.update_success_counter_async()
    await self.put_result_queues_async(outcome)

    log_fields = (
        self.func.__name__,
        self.get_task_info(task),
        self.execution_mode,
        self.get_result_info(result),
        time.time() - start_time,
    )
    self.task_logger.task_success(*log_fields)
649
+
650
def handle_task_error(self, task, exception: Exception):
    """
    Common handling for a task that raised.

    A retryable exception with retries left puts the task back into the
    first task queue; anything else is recorded as a failure.

    :param task: the task that failed
    :param exception: the exception that was caught
    """
    task_id = self.get_task_id(task)
    attempts = self.retry_time_dict.setdefault(task_id, 0)

    # Retry only for the configured exception types and within the limit.
    retryable = (
        isinstance(exception, self.retry_exceptions)
        and attempts < self.max_retries
    )

    if not retryable:
        if self.enable_result_cache:
            self.error_dict[task] = exception

        # The task is finished for good; drop its retry bookkeeping.
        self.retry_time_dict.pop(task_id, None)

        self.update_error_counter()
        self.put_fail_queue(task, exception)
        self.task_logger.task_error(
            self.func.__name__, self.get_task_info(task), exception
        )
        return

    # Forget the id so the re-queued task is not treated as a duplicate.
    self.processed_set.remove(task_id)
    self.task_queues[0].put(task)  # retries always go to the first queue

    self.progress_manager.add_total(1)
    self.retry_time_dict[task_id] += 1
    self.task_logger.task_retry(
        self.func.__name__,
        self.get_task_info(task),
        self.retry_time_dict[task_id],
        exception,
    )
690
+
691
async def handle_task_error_async(self, task, exception: Exception):
    """
    Async version of the common handling for a task that raised.

    A retryable exception with retries left puts the task back into the
    first task queue; anything else is recorded as a failure.

    :param task: the task that failed
    :param exception: the exception that was caught
    """
    task_id = self.get_task_id(task)
    attempts = self.retry_time_dict.setdefault(task_id, 0)

    # Retry only for the configured exception types and within the limit.
    retryable = (
        isinstance(exception, self.retry_exceptions)
        and attempts < self.max_retries
    )

    if not retryable:
        if self.enable_result_cache:
            self.error_dict[task] = exception

        # The task is finished for good; drop its retry bookkeeping.
        self.retry_time_dict.pop(task_id, None)

        self.update_error_counter()
        await self.put_fail_queue_async(task, exception)
        self.task_logger.task_error(
            self.func.__name__, self.get_task_info(task), exception
        )
        return

    # Forget the id so the re-queued task is not treated as a duplicate.
    self.processed_set.remove(task_id)
    await self.task_queues[0].put(task)  # retries always go to the first queue

    self.progress_manager.add_total(1)
    self.retry_time_dict[task_id] += 1
    self.task_logger.task_retry(
        self.func.__name__,
        self.get_task_info(task),
        self.retry_time_dict[task_id],
        exception,
    )
731
+
732
def start(self, task_source: Iterable):
    """
    Run all tasks in the configured execution mode: serial, thread,
    process or async.

    An unrecognised execution mode is treated as 'serial'.

    :param task_source: iterable or generator of tasks
    """
    start_time = time.time()
    # Normalise the mode first: init_env() builds queues from it, and the
    # serial fallback must apply before that happens.
    self.set_execution_mode(self.execution_mode)
    self.init_listener()
    try:
        self.init_progress()
        self.init_env(logger_queue=self.log_listener.get_queue())

        def announce():
            self.task_logger.start_manager(
                self.func.__name__,
                self.task_counter.value,
                self.execution_mode,
                self.worker_limit,
            )

        if self.execution_mode == "async":
            # The async queues expose put() as a coroutine, so feeding and
            # terminating them has to be awaited inside the event loop; the
            # synchronous helpers would enqueue nothing.
            async def feed_and_run():
                await self.put_task_queues_async(task_source)
                await self.terminate_task_queues_async()
                announce()
                await self.run_in_async()

            asyncio.run(feed_and_run())
        else:
            self.put_task_queues(task_source)
            self.terminate_task_queues()
            announce()

            # Dispatch to the runner matching the mode.
            if self.execution_mode == "thread":
                self.run_with_executor(self.thread_pool)
            elif self.execution_mode == "process":
                self.run_with_executor(self.process_pool)
            else:
                self.run_in_serial()

        self.progress_manager.close()
        self.task_logger.end_manager(
            self.func.__name__,
            self.execution_mode,
            time.time() - start_time,
            self.success_counter.value,
            self.error_counter.value,
            self.duplicate_counter.value,
        )
    finally:
        # Stop the listener even when the run fails, so it is not left running.
        self.log_listener.stop()
774
+
775
async def start_async(self, task_source: Iterable):
    """
    Run all tasks asynchronously from inside an already running event loop.

    The execution mode is forced to 'async'.

    :param task_source: iterable or generator of tasks
    """
    began = time.time()
    self.set_execution_mode("async")
    self.init_listener()
    self.init_progress()
    listener_queue = self.log_listener.get_queue()
    self.init_env(logger_queue=listener_queue)

    # Feed the queue, then close it with the termination sentinel.
    await self.put_task_queues_async(task_source)
    await self.terminate_task_queues_async()

    func_name = self.func.__name__
    self.task_logger.start_manager(
        func_name, self.task_counter.value, "async(await)", self.worker_limit
    )

    await self.run_in_async()

    self.progress_manager.close()
    summary = (
        self.func.__name__,
        self.execution_mode,
        time.time() - began,
        self.success_counter.value,
        self.error_counter.value,
        self.duplicate_counter.value,
    )
    self.task_logger.end_manager(*summary)
    self.log_listener.stop()
808
+
809
def start_stage(
    self,
    input_queues: List[MPQueue],
    output_queues: List[MPQueue],
    fail_queue: MPQueue,
    logger_queue: MPQueue,
):
    """
    Run this manager as one stage of a graph.

    Only the 'thread' execution mode uses the thread pool; every other mode
    is run serially. When the stage is done, the termination sentinel is
    forwarded to all output queues.

    :param input_queues: queues the stage reads tasks from
    :param output_queues: queues the stage writes results to
    :param fail_queue: queue receiving failure records
    :param logger_queue: queue used by the task logger
    """
    began = time.time()
    self.active = True
    self.init_progress()
    self.init_env(input_queues, output_queues, fail_queue, logger_queue)
    self.task_logger.start_stage(
        self.stage_name, self.func.__name__, self.execution_mode, self.worker_limit
    )

    # Pick the runner for the current mode.
    if self.execution_mode != "thread":
        self.run_in_serial()
    else:
        self.run_with_executor(self.thread_pool)

    self.release_pool()
    # Tell downstream consumers that this stage has produced everything.
    self.put_result_queues(TERMINATION_SIGNAL)

    self.progress_manager.close()
    summary = (
        self.stage_name,
        self.func.__name__,
        self.execution_mode,
        time.time() - began,
        self.success_counter.value,
        self.error_counter.value,
        self.duplicate_counter.value,
    )
    self.task_logger.end_stage(*summary)
851
+
852
def run_in_serial(self):
    """
    Execute the queued tasks one after another in the calling thread.

    When tasks remain unaccounted for after the queues are drained (retries
    were re-queued), the queues are terminated again and the method re-runs.
    """
    while True:
        task = self.get_task_queues()
        task_id = self.get_task_id(task)
        self.task_logger._log(
            "TRACE", f"Task {task} is submitted to {self.func.__name__}"
        )

        if isinstance(task, TerminationSignal):
            break

        if self.is_duplicate(task_id):
            self.deal_dupliacte(task)
            self.progress_manager.update(1)
            continue

        self.processed_set.add(task_id)
        try:
            began = time.time()
            outcome = self.func(*self.get_args(task))
            self.process_task_success(task, outcome, began)
        except Exception as error:
            self.handle_task_error(task, error)
        self.progress_manager.update(1)

    # Allow the queues to be polled again on a later pass.
    self.terminated_queue_set = set()

    if self.is_tasks_finished():
        return

    self.task_logger._log("DEBUG", f"Retrying tasks for '{self.func.__name__}'")
    self.terminate_task_queues()
    self.run_in_serial()
885
+
886
def run_with_executor(self, executor: ThreadPoolExecutor | ProcessPoolExecutor):
    """
    Execute the queued tasks in parallel on the given executor.

    Tasks are submitted until the termination signal arrives; the method
    then waits until every submitted task and its completion callback have
    finished. When tasks remain unaccounted for afterwards (retries were
    re-queued), the queues are terminated again and the method re-runs.

    :param executor: thread pool or process pool
    """
    task_start_dict = {}  # task -> submission timestamp

    # Number of submitted-but-unfinished tasks and the event signalling zero.
    in_flight = 0
    in_flight_lock = Lock()
    all_done_event = Event()
    all_done_event.set()  # nothing submitted yet, so "all done" holds

    def on_task_done(future, task, progress_manager):
        # Completion callback: record the outcome of one task.
        nonlocal in_flight
        try:
            progress_manager.update(1)
            try:
                result = future.result()
                start_time = task_start_dict[task]
                self.process_task_success(task, result, start_time)
            except Exception as error:
                self.handle_task_error(task, error)
        finally:
            # Always release the slot: if the bookkeeping above raises, the
            # counter would otherwise never reach zero and the wait() below
            # would block forever.
            with in_flight_lock:
                in_flight -= 1
                if in_flight == 0:
                    all_done_event.set()

    # Submit tasks from the queues to the executor.
    while True:
        task = self.get_task_queues()
        task_id = self.get_task_id(task)
        self.task_logger._log(
            "TRACE", f"Task {task} is submitted to {self.func.__name__}"
        )

        if isinstance(task, TerminationSignal):
            # No more submissions after the termination signal.
            break
        elif self.is_duplicate(task_id):
            self.deal_dupliacte(task)
            self.progress_manager.update(1)
            continue
        self.processed_set.add(task_id)

        # Account for the task before submitting it, so a fast callback
        # cannot observe a zero count prematurely.
        with in_flight_lock:
            in_flight += 1
            all_done_event.clear()

        task_start_dict[task] = time.time()
        future = executor.submit(self.func, *self.get_args(task))
        future.add_done_callback(
            lambda f, t=task: on_task_done(f, t, self.progress_manager)
        )

    # Wait for every submitted task, callbacks included.
    all_done_event.wait()

    # Allow the queues to be polled again on a later pass.
    self.terminated_queue_set = set()

    if not self.is_tasks_finished():
        self.task_logger._log("DEBUG", f"Retrying tasks for '{self.func.__name__}'")
        self.terminate_task_queues()
        self.run_with_executor(executor)
955
+
956
async def run_in_async(self):
    """
    Execute the queued tasks asynchronously with bounded concurrency.

    At most ``worker_limit`` tasks run at once. When tasks remain
    unaccounted for afterwards (retries were re-queued), the queues are
    terminated again and the method re-runs.
    """
    limiter = asyncio.Semaphore(self.worker_limit)  # caps concurrent tasks

    async def guarded(task):
        began = time.time()  # taken before waiting for a free slot
        async with limiter:
            outcome = await self._run_single_task(task)
        return task, outcome, began

    pending = []

    while True:
        task = await self.get_task_queues_async()
        task_id = self.get_task_id(task)
        self.task_logger._log(
            "TRACE", f"Task {task} is submitted to {self.func.__name__}"
        )
        if isinstance(task, TerminationSignal):
            break
        if self.is_duplicate(task_id):
            self.deal_dupliacte(task)
            self.progress_manager.update(1)
            continue
        self.processed_set.add(task_id)
        pending.append(guarded(task))

    # Run everything concurrently, then record each outcome in order.
    finished = await asyncio.gather(*pending, return_exceptions=True)
    for task, result, start_time in finished:
        if isinstance(result, Exception):
            await self.handle_task_error_async(task, result)
        else:
            await self.process_task_success_async(task, result, start_time)
        self.progress_manager.update(1)

    # Allow the queues to be polled again on a later pass.
    self.terminated_queue_set = set()

    if self.is_tasks_finished():
        return

    self.task_logger._log("DEBUG", f"Retrying tasks for '{self.func.__name__}'")
    await self.terminate_task_queues_async()
    await self.run_in_async()
1002
+
1003
async def _run_single_task(self, task):
    """
    Await ``func`` for one task; a raised exception is returned, not raised.
    """
    try:
        return await self.func(*self.get_args(task))
    except Exception as error:
        return error
1012
+
1013
def get_success_dict(self) -> dict:
    """
    Return a plain-dict copy of the cached successful results.
    """
    snapshot = dict(self.success_dict)
    return snapshot
1018
+
1019
def get_error_dict(self) -> dict:
    """
    Return a plain-dict copy of the cached task errors.
    """
    snapshot = dict(self.error_dict)
    return snapshot
1024
+
1025
def release_queue(self):
    """
    Drop the references to the task, result and fail queues.
    """
    self.task_queues = self.result_queues = self.fail_queue = None
1032
+
1033
def release_pool(self):
    """
    Shut down the thread and process pools (waiting for running work) and
    clear both references.
    """
    existing = [pool for pool in (self.thread_pool, self.process_pool) if pool]
    for pool in existing:
        pool.shutdown(wait=True)
    self.thread_pool = None
    self.process_pool = None
1042
+
1043
+ def test_method(self, execution_mode: str, task_list: list) -> float:
1044
+ """
1045
+ 测试方法
1046
+ """
1047
+ start = time.time()
1048
+ self.set_execution_mode(execution_mode)
1049
+ self.init_counter()
1050
+ self.init_state()
1051
+ self.start(task_list)
1052
+ return time.time() - start
1053
+
1054
+ def test_methods(self, task_source: Iterable, execution_modes: list = None) -> list:
1055
+ """
1056
+ 测试多种方法
1057
+ """
1058
+ # 如果 task_source 是生成器或一次性可迭代对象,需要提前转化成列表
1059
+ # 确保对不同模式的测试使用同一批任务数据
1060
+ task_list = list(task_source)
1061
+ execution_modes = execution_modes or ["serial", "thread", "process"]
1062
+
1063
+ results = []
1064
+ for mode in execution_modes:
1065
+ result = self.test_method(mode, task_list)
1066
+ results.append([result])
1067
+ return results, execution_modes, ["Time"]
1068
+
1069
def __enter__(self):
    """Enter the ``with`` block; the manager itself is the context value."""
    return self

def __exit__(self, exc_type, exc_val, exc_tb):
    """
    Leave the ``with`` block: shut down the executors and drop the queues.

    Returns None, so an exception raised inside the block propagates.
    """
    self.release_pool()
    self.release_queue()