celestialflow 3.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- celestialflow/__init__.py +39 -0
- celestialflow/task_graph.py +665 -0
- celestialflow/task_logging.py +154 -0
- celestialflow/task_manage.py +1070 -0
- celestialflow/task_nodes.py +160 -0
- celestialflow/task_progress.py +57 -0
- celestialflow/task_report.py +162 -0
- celestialflow/task_structure.py +151 -0
- celestialflow/task_tools.py +501 -0
- celestialflow/task_types.py +61 -0
- celestialflow/task_web.py +170 -0
- celestialflow-3.0.1.dist-info/METADATA +301 -0
- celestialflow-3.0.1.dist-info/RECORD +15 -0
- celestialflow-3.0.1.dist-info/WHEEL +5 -0
- celestialflow-3.0.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1070 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import asyncio, time
|
|
4
|
+
from asyncio import Queue as AsyncQueue, QueueEmpty
|
|
5
|
+
from collections import defaultdict
|
|
6
|
+
from collections.abc import Iterable
|
|
7
|
+
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
|
|
8
|
+
from multiprocessing import Queue as MPQueue
|
|
9
|
+
from queue import Queue as ThreadQueue, Empty
|
|
10
|
+
from threading import Event, Lock
|
|
11
|
+
from typing import List
|
|
12
|
+
|
|
13
|
+
from httpx import (
|
|
14
|
+
ConnectError,
|
|
15
|
+
ConnectTimeout,
|
|
16
|
+
PoolTimeout,
|
|
17
|
+
ProtocolError,
|
|
18
|
+
ReadError,
|
|
19
|
+
ReadTimeout,
|
|
20
|
+
ProxyError,
|
|
21
|
+
RequestError,
|
|
22
|
+
)
|
|
23
|
+
|
|
24
|
+
from .task_progress import ProgressManager, NullProgress
|
|
25
|
+
from .task_logging import LogListener, TaskLogger
|
|
26
|
+
from .task_types import ValueWrapper, NoOpContext, TerminationSignal, TERMINATION_SIGNAL
|
|
27
|
+
from .task_tools import make_hashable, format_repr, object_to_str_hash
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class TaskManager:
|
|
31
|
+
def __init__(
    self,
    func,
    execution_mode="serial",
    worker_limit=50,
    max_retries=3,
    max_info=50,
    unpack_task_args=False,
    enable_result_cache=False,
    progress_desc="Processing",
    show_progress=False,
):
    """
    Build a TaskManager around ``func``.

    :param func: callable invoked for every task
    :param execution_mode: execution mode, one of 'serial', 'thread', 'process', 'async'
    :param worker_limit: number of tasks handled at the same time
    :param max_retries: maximum number of retries for a task
    :param max_info: limit handed to ``format_repr`` when tasks/results are rendered for logs
    :param unpack_task_args: whether tuple tasks are unpacked into positional arguments
    :param enable_result_cache: whether results and errors are cached on the manager
    :param progress_desc: label shown on the progress bar
    :param show_progress: whether a progress bar is shown
    """
    # What to run and how to run it.
    self.func = func
    self.execution_mode = execution_mode
    self.worker_limit = worker_limit
    self.max_retries = max_retries

    # Task/result handling switches.
    self.unpack_task_args = unpack_task_args
    self.enable_result_cache = enable_result_cache
    self.max_info = max_info

    # Progress bar settings.
    self.show_progress = show_progress
    self.progress_desc = progress_desc

    # Executor pools are created lazily by init_pool().
    self.process_pool = None
    self.thread_pool = None

    # Round-robin polling state used by get_task_queues().
    self.terminated_queue_set = set()
    self.current_index = 0  # queue index the next polling round starts from

    # Graph wiring: upstream stages and this stage's name (defaults to id(self)).
    self.prev_stages: List[TaskManager] = []
    self.set_stage_name(None)

    # Exception types that trigger a retry instead of a failure.
    self.retry_exceptions = (
        ConnectTimeout,
        ProtocolError,
        ReadError,
        ConnectError,
        PoolTimeout,
        ReadTimeout,
        ProxyError,
    )

    self.init_counter()
|
|
87
|
+
|
|
88
|
+
def init_counter(
    self,
    task_counter=None,
    success_counter=None,
    error_counter=None,
    duplicate_counter=None,
    counter_lock=None,
    extra_stats=None,
):
    """
    Install the counters, the lock that guards them and the extra-stats mapping.

    Every argument left as ``None`` is replaced by a fresh default: a
    ``ValueWrapper()`` for each of the four counters, a ``NoOpContext()``
    for the lock and an empty dict for ``extra_stats``.
    """

    def _counter_or_default(counter):
        # Only None triggers the default, so a falsy counter object is kept.
        return ValueWrapper() if counter is None else counter

    self.task_counter = _counter_or_default(task_counter)
    self.success_counter = _counter_or_default(success_counter)
    self.error_counter = _counter_or_default(error_counter)
    self.duplicate_counter = _counter_or_default(duplicate_counter)

    self.counter_lock = NoOpContext() if counter_lock is None else counter_lock

    self.extra_stats = {} if extra_stats is None else extra_stats
|
|
114
|
+
|
|
115
|
+
def init_env(
    self, task_queues=None, result_queues=None, fail_queue=None, logger_queue=None
):
    """
    Prepare the runtime environment: queues first, then per-run state,
    the executor pool and finally the task logger.
    """
    queue_args = (task_queues, result_queues, fail_queue, logger_queue)
    self.init_queue(*queue_args)

    # The remaining steps take no arguments and run in this fixed order.
    self.init_state()
    self.init_pool()
    self.init_logger()
|
|
125
|
+
|
|
126
|
+
def init_queue(
    self, task_queues=None, result_queues=None, fail_queue=None, logger_queue=None
):
    """
    Initialise the task, result, fail and logger queues.

    Queues passed in by the caller are used as-is; any queue that is
    missing (``None``, or an empty list for the two list arguments) is
    created here with a type chosen from ``self.execution_mode``.

    :param task_queues: list of input queues, or None to create a single one
    :param result_queues: list of output queues, or None to create a single one
    :param fail_queue: queue receiving failure records, or None to create one
    :param logger_queue: queue consumed by the log listener, or None to create one
    """
    queue_map = {
        "process": ThreadQueue,  # an MPQueue is not required here, see the note below
        "async": AsyncQueue,
        "thread": ThreadQueue,
        "serial": ThreadQueue,
    }
    # An unrecognised mode falls back to the serial queue type instead of
    # raising KeyError, matching the fallback in set_execution_mode().
    queue_cls = queue_map.get(self.execution_mode, ThreadQueue)

    # task_queues, result_queues and fail_queue are only used inside the
    # process that runs this node, so unless several processes have to
    # exchange data through them a ThreadQueue is sufficient.
    self.task_queues: List[ThreadQueue | MPQueue | AsyncQueue] = task_queues or [
        queue_cls()
    ]
    self.result_queues: List[ThreadQueue | MPQueue | AsyncQueue] = (
        result_queues or [queue_cls()]
    )
    self.fail_queue: ThreadQueue | MPQueue | AsyncQueue = fail_queue or queue_cls()
    self.logger_queue: ThreadQueue | MPQueue = logger_queue or ThreadQueue()
|
|
150
|
+
|
|
151
|
+
def init_state(self):
    """
    Reset the per-run task state:
    - success_dict / error_dict: caches of execution results
    - retry_time_dict: retry count per task id
    - processed_set: task ids already seen, used for duplicate detection
    """
    self.processed_set = set()

    self.retry_time_dict = {}  # task_id -> retry_time
    self.error_dict = {}
    self.success_dict = {}
|
|
163
|
+
|
|
164
|
+
def init_pool(self):
    """
    Create the thread pool or process pool required by the execution mode.

    An existing pool is kept, so pools are re-used across calls; modes
    other than 'thread' and 'process' create nothing.
    """
    mode = self.execution_mode

    if mode == "thread":
        if self.thread_pool is None:
            self.thread_pool = ThreadPoolExecutor(max_workers=self.worker_limit)
        return

    if mode == "process":
        if self.process_pool is None:
            self.process_pool = ProcessPoolExecutor(max_workers=self.worker_limit)
|
|
173
|
+
|
|
174
|
+
def init_logger(self):
    """
    Create the task logger, handing it the current logger queue.
    """
    logger_queue = self.logger_queue
    self.task_logger = TaskLogger(logger_queue)
|
|
179
|
+
|
|
180
|
+
def init_listener(self):
    """
    Create the log listener (constructed with "INFO") and start it.
    """
    listener = LogListener("INFO")
    self.log_listener = listener
    listener.start()
|
|
186
|
+
|
|
187
|
+
def init_progress(self):
    """
    Create the progress manager.

    With ``show_progress`` disabled a ``NullProgress`` is installed;
    otherwise a ``ProgressManager`` starting at zero total tasks, whose
    description carries the execution mode (plus the worker limit for
    non-serial modes).
    """
    if self.show_progress:
        mode = self.execution_mode

        if mode == "serial":
            extra_desc = "serial"
        else:
            extra_desc = f"{mode}-{self.worker_limit}"

        # Only the async mode uses the "async" progress flavour.
        progress_mode = "async" if mode == "async" else "normal"

        self.progress_manager = ProgressManager(
            total_tasks=0,
            desc=f"{self.progress_desc}({extra_desc})",
            mode=progress_mode,
        )
    else:
        self.progress_manager = NullProgress()
|
|
207
|
+
|
|
208
|
+
def set_execution_mode(self, execution_mode):
    """
    Set the execution mode.

    :param execution_mode: 'thread', 'process', 'async' or 'serial';
        any other value is replaced by 'serial'
    """
    known_modes = ("thread", "process", "async", "serial")
    if execution_mode in known_modes:
        self.execution_mode = execution_mode
    else:
        self.execution_mode = "serial"
|
|
218
|
+
|
|
219
|
+
def set_graph_context(
    self,
    next_stages: List[TaskManager] = None,
    stage_mode: str = None,
    stage_name: str = None,
):
    """
    Set the chaining context (only relevant when the node is part of a graph).

    :param next_stages: list of downstream nodes
    :param stage_mode: how this node runs inside the graph, 'serial' or 'process'
    :param stage_name: name of this node
    """
    steps = (
        (self.set_next_stages, next_stages),
        (self.set_stage_mode, stage_mode),
        (self.set_stage_name, stage_name),
    )
    for setter, value in steps:
        setter(value)
|
|
234
|
+
|
|
235
|
+
def set_next_stages(self, next_stages: List[TaskManager]):
    """
    Store the downstream nodes and register this node as a predecessor
    of each of them. A falsy argument is stored as an empty list.
    """
    downstream = next_stages if next_stages else []
    self.next_stages = downstream
    for stage in self.next_stages:
        stage.add_prev_stages(self)
|
|
242
|
+
|
|
243
|
+
def set_stage_mode(self, stage_mode: str):
    """
    Set how this node runs inside a graph: 'process' is kept as given,
    every other value becomes 'serial'.
    """
    if stage_mode == "process":
        self.stage_mode = stage_mode
    else:
        self.stage_mode = "serial"
|
|
248
|
+
|
|
249
|
+
def set_stage_name(self, name: str):
    """
    Set the name of this node; a falsy name is replaced by ``id(self)``.
    """
    self.stage_name = name if name else id(self)
|
|
254
|
+
|
|
255
|
+
def add_prev_stages(self, prev_stage: TaskManager):
    """
    Register an upstream node; a node that is already registered is ignored.
    """
    if prev_stage not in self.prev_stages:
        self.prev_stages.append(prev_stage)
|
|
262
|
+
|
|
263
|
+
def get_stage_tag(self):
    """
    Return this node's graph label, formatted as ``<stage_name>[<func name>]``.
    """
    stage = self.stage_name
    func_name = self.func.__name__
    return f"{stage}[{func_name}]"
|
|
268
|
+
|
|
269
|
+
def get_stage_summary(self) -> dict:
    """
    Return a snapshot describing this node: stage mode, execution mode
    (suffixed with the worker limit unless serial), function name and
    class name.
    """
    mode = self.execution_mode
    if mode != "serial":
        mode = f"{mode}-{self.worker_limit}"

    summary = {}
    summary["stage_mode"] = self.stage_mode
    summary["execution_mode"] = mode
    summary["func_name"] = self.func.__name__
    summary["class_name"] = self.__class__.__name__
    return summary
|
|
283
|
+
|
|
284
|
+
def add_retry_exceptions(self, *exceptions):
    """
    Append further exception types to the set of exceptions that trigger a retry.
    """
    extra = tuple(exceptions)
    self.retry_exceptions = self.retry_exceptions + extra
|
|
289
|
+
|
|
290
|
+
def get_task_queues(self, poll_interval: float = 0.01):
    """
    Fetch the next task, polling the task queues in round-robin order.

    With a single queue a blocking ``get`` is used. With several queues
    each non-terminated queue is tried with ``get_nowait``; when a full
    round yields nothing the call sleeps and tries again.

    :param poll_interval: seconds to wait after a round in which no queue had data
    :return: the fetched task, or TERMINATION_SIGNAL once the single queue
        has delivered a TerminationSignal / all queues have terminated
    """
    total_queues = len(self.task_queues)

    if total_queues == 1:
        # With only one queue a blocking get is more efficient than polling.
        queue = self.task_queues[0]
        item = queue.get()  # blocks until an item arrives, no sleep needed
        if isinstance(item, TerminationSignal):
            self.terminated_queue_set.add(0)
            self.task_logger._log("TRACE", f"get_task_queues: queue[0] terminated")
            return TERMINATION_SIGNAL
        return item

    while True:
        for i in range(total_queues):
            idx = (self.current_index + i) % total_queues  # rotate over the queues
            if idx in self.terminated_queue_set:
                continue
            queue = self.task_queues[idx]
            try:
                item = queue.get_nowait()
                if isinstance(item, TerminationSignal):
                    # Remember the queue as finished and keep looking in the others.
                    self.terminated_queue_set.add(idx)
                    self.task_logger._log(
                        "TRACE", f"get_task_queues: queue[{idx}] terminated"
                    )
                    continue
                self.current_index = (
                    idx + 1
                ) % total_queues  # the next call starts from the following queue
                return item
            except Empty:
                continue
            except Exception as e:
                # Any other queue error is logged and the queue is skipped for this round.
                self.task_logger._log(
                    "WARNING",
                    f"get_task_queues: Error from queue[{idx}]: {type(e).__name__}({e})",
                )
                continue

        # Every queue has terminated.
        if len(self.terminated_queue_set) == total_queues:
            return TERMINATION_SIGNAL

        # No queue has data right now; sleep to avoid busy-waiting.
        time.sleep(poll_interval)
|
|
342
|
+
|
|
343
|
+
async def get_task_queues_async(self, poll_interval=0.01):
    """
    Asynchronously fetch the next task, polling the task queues in round-robin order.

    With a single queue the call simply awaits ``queue.get()``. With
    several queues each non-terminated queue is tried with ``get_nowait``;
    when a full round yields nothing the coroutine sleeps and tries again.

    :param poll_interval: seconds to sleep when every queue was empty
    :return: the task, or TERMINATION_SIGNAL once the single queue has
        delivered a TerminationSignal / all queues have terminated
    """
    total_queues = len(self.task_queues)

    if total_queues == 1:
        # With a single queue, just await the next item.
        queue = self.task_queues[0]
        task = await queue.get()
        if isinstance(task, TerminationSignal):
            self.terminated_queue_set.add(0)
            self.task_logger._log(
                "TRACE", "get_task_queues_async: queue[0] terminated"
            )
            return TERMINATION_SIGNAL
        return task

    while True:
        for i in range(total_queues):
            idx = (self.current_index + i) % total_queues  # rotate over the queues
            if idx in self.terminated_queue_set:
                continue
            queue = self.task_queues[idx]
            try:
                task = queue.get_nowait()
                if isinstance(task, TerminationSignal):
                    # Remember the queue as finished and keep looking in the others.
                    self.terminated_queue_set.add(idx)
                    self.task_logger._log(
                        "TRACE", f"get_task_queues_async: queue[{idx}] terminated"
                    )
                    continue
                # The next call starts from the following queue.
                self.current_index = (idx + 1) % total_queues
                return task
            except QueueEmpty:
                continue
            except Exception as e:
                # Any other queue error is logged and the queue is skipped for this round.
                self.task_logger._log(
                    "WARNING",
                    f"get_task_queues_async: queue[{idx}] error: {type(e).__name__}({e})",
                )
                continue

        # Every queue has terminated.
        if len(self.terminated_queue_set) == total_queues:
            return TERMINATION_SIGNAL

        # Nothing available right now; yield to the event loop before retrying.
        await asyncio.sleep(poll_interval)
|
|
393
|
+
|
|
394
|
+
def put_task_queues(self, task_source):
    """
    Put every task from ``task_source`` into the first task queue.

    Each item is passed through ``make_hashable`` before it is enqueued
    and the task counter is incremented once per item. The progress total
    grows in batches of 100, with the remainder added at the end.

    The batches are counted locally rather than read from
    ``task_counter``, so a counter that is already non-zero when this is
    called does not inflate the progress total.

    NOTE(review): ``put`` is called synchronously, so the first task queue
    is expected to be a thread/multiprocessing style queue here.

    :param task_source: iterable yielding the tasks
    """
    pending = 0  # items enqueued but not yet reported to the progress bar
    for item in task_source:
        self.task_queues[0].put(make_hashable(item))
        self.update_task_counter()
        pending += 1
        if pending == 100:
            self.progress_manager.add_total(100)
            pending = 0
    # Report the remainder (0 when the item count is a multiple of 100).
    self.progress_manager.add_total(pending)
|
|
406
|
+
|
|
407
|
+
async def put_task_queues_async(self, task_source):
    """
    Put every task from ``task_source`` into the first task queue (async mode).

    Each item is passed through ``make_hashable`` before it is enqueued
    with an awaited ``put``, and the task counter is incremented once per
    item. The progress total grows in batches of 100, with the remainder
    added at the end.

    The batches are counted locally rather than read from
    ``task_counter``, so a counter that is already non-zero when this is
    called does not inflate the progress total.

    :param task_source: iterable yielding the tasks
    """
    pending = 0  # items enqueued but not yet reported to the progress bar
    for item in task_source:
        await self.task_queues[0].put(make_hashable(item))
        self.update_task_counter()
        pending += 1
        if pending == 100:
            self.progress_manager.add_total(100)
            pending = 0
    # Report the remainder (0 when the item count is a multiple of 100).
    self.progress_manager.add_total(pending)
|
|
419
|
+
|
|
420
|
+
def terminate_task_queues(self):
    """
    Terminate all task queues by appending the termination sentinel to each.
    """
    sentinel = TERMINATION_SIGNAL  # sentinel task marking the end of a queue
    for task_queue in self.task_queues:
        task_queue.put(sentinel)
|
|
426
|
+
|
|
427
|
+
async def terminate_task_queues_async(self):
    """
    Terminate all task queues (async mode) by awaiting a put of the
    termination sentinel on each.
    """
    sentinel = TERMINATION_SIGNAL  # sentinel task marking the end of a queue
    for task_queue in self.task_queues:
        await task_queue.put(sentinel)
|
|
433
|
+
|
|
434
|
+
def put_result_queues(self, result):
    """
    Broadcast ``result`` to every result queue.
    """
    targets = self.result_queues
    for target in targets:
        target.put(result)
|
|
440
|
+
|
|
441
|
+
async def put_result_queues_async(self, result):
    """
    Broadcast ``result`` to every result queue (async mode), awaiting each put.
    """
    targets = self.result_queues
    for target in targets:
        await target.put(result)
|
|
447
|
+
|
|
448
|
+
def put_fail_queue(self, task, error):
    """
    Put a record describing a failed task into the fail queue.

    The record holds the stage tag, the task rendered with ``str``, the
    error rendered as ``TypeName(message)`` and the current timestamp.
    """
    record = {
        "stage_tag": self.get_stage_tag(),
        "task": str(task),
        "error_info": f"{type(error).__name__}({error})",
        "timestamp": time.time(),
    }
    self.fail_queue.put(record)
|
|
460
|
+
|
|
461
|
+
async def put_fail_queue_async(self, task, error):
    """
    Put a record describing a failed task into the fail queue (async version).

    The record holds the stage tag, the task rendered with ``str``, the
    error rendered as ``TypeName(message)`` and the current timestamp.
    """
    record = {
        "stage_tag": self.get_stage_tag(),
        "task": str(task),
        "error_info": f"{type(error).__name__}({error})",
        "timestamp": time.time(),
    }
    await self.fail_queue.put(record)
|
|
473
|
+
|
|
474
|
+
def update_task_counter(self):
    """Increment the task counter by one while holding the counter lock."""
    # The lock keeps the read-modify-write of ``value`` consistent.
    with self.counter_lock:
        counter = self.task_counter
        counter.value += 1
|
|
478
|
+
|
|
479
|
+
def update_success_counter(self):
    """Increment the success counter by one while holding the counter lock."""
    # The lock keeps the read-modify-write of ``value`` consistent.
    with self.counter_lock:
        counter = self.success_counter
        counter.value += 1
|
|
483
|
+
|
|
484
|
+
async def update_success_counter_async(self):
    """Run ``update_success_counter`` in a worker thread and wait for it to finish."""
    increment = self.update_success_counter
    await asyncio.to_thread(increment)
|
|
486
|
+
|
|
487
|
+
def update_error_counter(self):
    """Increment the error counter by one while holding the counter lock."""
    # The lock keeps the read-modify-write of ``value`` consistent.
    with self.counter_lock:
        counter = self.error_counter
        counter.value += 1
|
|
491
|
+
|
|
492
|
+
def update_duplicate_counter(self):
    """Increment the duplicate counter by one while holding the counter lock."""
    # The lock keeps the read-modify-write of ``value`` consistent.
    with self.counter_lock:
        counter = self.duplicate_counter
        counter.value += 1
|
|
496
|
+
|
|
497
|
+
def is_tasks_finished(self) -> bool:
    """
    Tell whether all tasks are done: the task counter equals the sum of
    the success, error and duplicate counters.
    """
    settled = (self.success_counter, self.error_counter, self.duplicate_counter)
    processed = sum(counter.value for counter in settled)
    return self.task_counter.value == processed
|
|
507
|
+
|
|
508
|
+
def is_duplicate(self, task_id):
    """
    Tell whether ``task_id`` has already been recorded in ``processed_set``.
    """
    already_seen = self.processed_set
    return task_id in already_seen
|
|
513
|
+
|
|
514
|
+
def deal_dupliacte(self, task):
    """
    Handle a duplicate task: bump the duplicate counter, then log the duplicate.
    """
    self.update_duplicate_counter()

    func_name = self.func.__name__
    task_info = self.get_task_info(task)
    self.task_logger.task_duplicate(func_name, task_info)
|
|
520
|
+
|
|
521
|
+
def get_args(self, task):
    """
    Turn a task into the positional arguments passed to ``func``.

    A tuple task is returned unchanged when ``unpack_task_args`` is
    enabled; in every other case the task is wrapped in a 1-tuple.
    """
    if not self.unpack_task_args:
        return (task,)
    if not isinstance(task, tuple):
        return (task,)
    return task
|
|
530
|
+
|
|
531
|
+
def process_result(self, task, result):
    """
    Post-process the raw result of a task before it is cached and forwarded.

    This default implementation returns ``result`` unchanged.

    :param task: the task that produced the result (unused here)
    :param result: the raw result returned for the task
    """
    return result
|
|
538
|
+
|
|
539
|
+
def process_result_dict(self):
    """
    Build the combined result dictionary.

    This default implementation merges the success dictionary with the
    error dictionary; for a key present in both, the error entry wins.
    """
    success_dict = self.get_success_dict()
    error_dict = self.get_error_dict()

    merged = dict(success_dict)
    merged.update(error_dict)
    return merged
|
|
549
|
+
|
|
550
|
+
def handle_error_dict(self):
    """
    Group the failed tasks by their error.

    The stored error objects themselves are used as dictionary keys, so
    tasks end up in the same group only when their errors hash and
    compare equal.

    :return: plain dict mapping each error to the list of tasks that hit it
    """
    grouped = {}
    for task, error in self.get_error_dict().items():
        grouped.setdefault(error, []).append(task)
    return grouped
|
|
563
|
+
|
|
564
|
+
def get_task_id(self, task):
    """
    Return the id of a task, computed with ``object_to_str_hash``.
    """
    task_id = object_to_str_hash(task)
    return task_id
|
|
569
|
+
|
|
570
|
+
def get_task_info(self, task):
    """
    Return a readable string describing the arguments of a task.

    Up to three arguments are all shown; for longer argument tuples only
    the first two and the last one are shown, separated by "...". Each
    argument is rendered with ``format_repr`` limited by ``max_info``.
    """
    args = self.get_args(task)
    limit = self.max_info

    if len(args) > 3:
        # Show the first two + ... + the last one.
        shown = [format_repr(arg, limit) for arg in args[:2]]
        shown.append("...")
        shown.append(format_repr(args[-1], limit))
    else:
        shown = [format_repr(arg, limit) for arg in args]

    return "(" + ", ".join(shown) + ")"
|
|
589
|
+
|
|
590
|
+
def get_result_info(self, result):
    """
    Return a readable rendering of a result, produced by ``format_repr``
    limited by ``max_info``.
    """
    limit = self.max_info
    return format_repr(result, limit)
|
|
595
|
+
|
|
596
|
+
def process_task_success(self, task, result, start_time):
    """
    Common handling for a task that finished successfully.

    :param task: the completed task
    :param result: the raw result of the task
    :param start_time: time at which the task was started
    """
    outcome = self.process_result(task, result)

    if self.enable_result_cache:
        self.success_dict[task] = outcome

    # The task is done, so its retry bookkeeping can be dropped.
    self.retry_time_dict.pop(self.get_task_id(task), None)

    self.update_success_counter()
    self.put_result_queues(outcome)

    # The log entry describes the raw result, not the post-processed one.
    func_name = self.func.__name__
    task_info = self.get_task_info(task)
    mode = self.execution_mode
    result_info = self.get_result_info(result)
    self.task_logger.task_success(
        func_name, task_info, mode, result_info, time.time() - start_time
    )
|
|
622
|
+
|
|
623
|
+
async def process_task_success_async(self, task, result, start_time):
    """
    Async version of the common handling for a task that finished successfully.

    :param task: the completed task
    :param result: the raw result of the task
    :param start_time: time at which the task was started
    """
    outcome = self.process_result(task, result)

    if self.enable_result_cache:
        self.success_dict[task] = outcome

    # The task is done, so its retry bookkeeping can be dropped.
    self.retry_time_dict.pop(self.get_task_id(task), None)

    await self.update_success_counter_async()
    await self.put_result_queues_async(outcome)

    # The log entry describes the raw result, not the post-processed one.
    func_name = self.func.__name__
    task_info = self.get_task_info(task)
    mode = self.execution_mode
    result_info = self.get_result_info(result)
    self.task_logger.task_success(
        func_name, task_info, mode, result_info, time.time() - start_time
    )
|
|
649
|
+
|
|
650
|
+
def handle_task_error(self, task, exception: Exception):
    """
    Common handling for a task that raised an exception.

    If the exception is one of ``retry_exceptions`` and the task has been
    retried fewer than ``max_retries`` times, the task is put back into
    the first task queue; otherwise it is recorded as failed.

    :param task: the task that raised
    :param exception: the exception that was caught
    :return: None
    """
    task_id = self.get_task_id(task)
    retry_time = self.retry_time_dict.setdefault(task_id, 0)

    # The exception type and the retry budget decide between retry and failure.
    should_retry = (
        isinstance(exception, self.retry_exceptions)
        and retry_time < self.max_retries
    )

    if should_retry:
        # Forget the id so the re-queued task is not rejected as a duplicate.
        # discard() rather than remove(): an id that was never registered
        # must not turn a retry into a KeyError.
        self.processed_set.discard(task_id)
        self.task_queues[0].put(task)  # retry tasks only go to the first queue

        self.progress_manager.add_total(1)
        self.retry_time_dict[task_id] += 1
        self.task_logger.task_retry(
            self.func.__name__,
            self.get_task_info(task),
            self.retry_time_dict[task_id],
            exception,
        )
        return

    # Not retryable, or the retry budget is used up: mark the task as failed.
    if self.enable_result_cache:
        self.error_dict[task] = exception

    # The task is finished, so its retry bookkeeping can be dropped.
    self.retry_time_dict.pop(task_id, None)

    self.update_error_counter()
    self.put_fail_queue(task, exception)
    self.task_logger.task_error(
        self.func.__name__, self.get_task_info(task), exception
    )
|
|
690
|
+
|
|
691
|
+
async def handle_task_error_async(self, task, exception: Exception):
    """
    Async version of the common handling for a task that raised an exception.

    If the exception is one of ``retry_exceptions`` and the task has been
    retried fewer than ``max_retries`` times, the task is put back into
    the first task queue; otherwise it is recorded as failed.

    :param task: the task that raised
    :param exception: the exception that was caught
    :return: None
    """
    task_id = self.get_task_id(task)
    retry_time = self.retry_time_dict.setdefault(task_id, 0)

    # The exception type and the retry budget decide between retry and failure.
    should_retry = (
        isinstance(exception, self.retry_exceptions)
        and retry_time < self.max_retries
    )

    if should_retry:
        # Forget the id so the re-queued task is not rejected as a duplicate.
        # discard() rather than remove(): an id that was never registered
        # must not turn a retry into a KeyError.
        self.processed_set.discard(task_id)
        await self.task_queues[0].put(task)  # retry tasks only go to the first queue

        self.progress_manager.add_total(1)
        self.retry_time_dict[task_id] += 1
        self.task_logger.task_retry(
            self.func.__name__,
            self.get_task_info(task),
            self.retry_time_dict[task_id],
            exception,
        )
        return

    # Not retryable, or the retry budget is used up: mark the task as failed.
    if self.enable_result_cache:
        self.error_dict[task] = exception

    # The task is finished, so its retry bookkeeping can be dropped.
    self.retry_time_dict.pop(task_id, None)

    self.update_error_counter()
    await self.put_fail_queue_async(task, exception)
    self.task_logger.task_error(
        self.func.__name__, self.get_task_info(task), exception
    )
|
|
731
|
+
|
|
732
|
+
def start(self, task_source: Iterable):
    """
    Run all tasks from ``task_source`` using the configured execution mode
    (serial, thread pool, process pool or asyncio).

    An unrecognised execution mode is replaced by 'serial' before any
    queue is created. In 'async' mode the tasks and the termination
    sentinel are fed through the awaiting queue helpers inside the event
    loop, because a plain ``put()`` call on an asyncio queue only creates
    a coroutine and enqueues nothing. The log listener is stopped even
    when the run raises.

    :param task_source: iterable or generator yielding the tasks
    """
    start_time = time.time()
    # Normalise first so queue creation and the dispatch below see a known mode.
    self.set_execution_mode(self.execution_mode)
    self.init_listener()
    try:
        self.init_progress()
        self.init_env(logger_queue=self.log_listener.get_queue())

        def log_manager_start():
            self.task_logger.start_manager(
                self.func.__name__,
                self.task_counter.value,
                self.execution_mode,
                self.worker_limit,
            )

        if self.execution_mode == "async":

            async def feed_and_run():
                # Fill and terminate the asyncio queue with awaited puts.
                await self.put_task_queues_async(task_source)
                await self.terminate_task_queues_async()
                log_manager_start()
                await self.run_in_async()

            asyncio.run(feed_and_run())
        else:
            self.put_task_queues(task_source)
            self.terminate_task_queues()
            log_manager_start()

            # Run the task loop that matches the mode.
            if self.execution_mode == "thread":
                self.run_with_executor(self.thread_pool)
            elif self.execution_mode == "process":
                self.run_with_executor(self.process_pool)
            else:
                self.run_in_serial()

        self.progress_manager.close()
        self.task_logger.end_manager(
            self.func.__name__,
            self.execution_mode,
            time.time() - start_time,
            self.success_counter.value,
            self.error_counter.value,
            self.duplicate_counter.value,
        )
    finally:
        # Always stop the listener started above, also when the run fails.
        self.log_listener.stop()
|
|
774
|
+
|
|
775
|
+
async def start_async(self, task_source: Iterable):
    """
    Run all tasks from ``task_source`` asynchronously inside the caller's event loop.

    The execution mode is forced to 'async'. The log listener is stopped
    even when the run raises.

    :param task_source: iterable or generator yielding the tasks
    """
    start_time = time.time()
    self.set_execution_mode("async")
    self.init_listener()
    try:
        self.init_progress()
        self.init_env(logger_queue=self.log_listener.get_queue())

        await self.put_task_queues_async(task_source)
        await self.terminate_task_queues_async()
        self.task_logger.start_manager(
            self.func.__name__,
            self.task_counter.value,
            "async(await)",
            self.worker_limit,
        )

        await self.run_in_async()

        self.progress_manager.close()
        self.task_logger.end_manager(
            self.func.__name__,
            self.execution_mode,
            time.time() - start_time,
            self.success_counter.value,
            self.error_counter.value,
            self.duplicate_counter.value,
        )
    finally:
        # Always stop the listener started above, also when the run fails.
        self.log_listener.stop()
|
|
808
|
+
|
|
809
|
+
def start_stage(
    self,
    input_queues: List[MPQueue],
    output_queues: List[MPQueue],
    fail_queue: MPQueue,
    logger_queue: MPQueue,
):
    """
    Run this node as a stage of a graph.

    Tasks are processed with the thread pool when the execution mode is
    'thread' and serially for every other mode. Once the input is
    exhausted the pool is released and the termination sentinel is sent
    to all output queues.

    :param input_queues: input queues
    :param output_queues: output queues
    :param fail_queue: fail queue
    :param logger_queue: queue handed to the task logger
    """
    start_time = time.time()
    self.active = True
    self.init_progress()
    self.init_env(input_queues, output_queues, fail_queue, logger_queue)
    self.task_logger.start_stage(
        self.stage_name, self.func.__name__, self.execution_mode, self.worker_limit
    )

    # Pick the task loop that matches the mode.
    use_thread_pool = self.execution_mode == "thread"
    if not use_thread_pool:
        self.run_in_serial()
    else:
        self.run_with_executor(self.thread_pool)

    self.release_pool()
    # Tell the downstream consumers that this stage is finished.
    self.put_result_queues(TERMINATION_SIGNAL)

    self.progress_manager.close()
    self.task_logger.end_stage(
        self.stage_name,
        self.func.__name__,
        self.execution_mode,
        time.time() - start_time,
        self.success_counter.value,
        self.error_counter.value,
        self.duplicate_counter.value,
    )
|
|
851
|
+
|
|
852
|
+
def run_in_serial(self):
    """
    Execute the tasks one after another.

    Tasks are taken from the queues until the termination signal shows
    up. If the counters then show unfinished tasks (tasks re-queued for a
    retry), the queues are terminated again and the method calls itself
    for another pass.
    """
    # Take the tasks from the queues one by one and execute them.
    while True:
        task = self.get_task_queues()
        task_id = self.get_task_id(task)
        self.task_logger._log(
            "TRACE", f"Task {task} is submitted to {self.func.__name__}"
        )

        if isinstance(task, TerminationSignal):
            break

        if self.is_duplicate(task_id):
            self.deal_dupliacte(task)
            self.progress_manager.update(1)
            continue

        self.processed_set.add(task_id)
        try:
            start_time = time.time()
            result = self.func(*self.get_args(task))
            self.process_task_success(task, result, start_time)
        except Exception as error:
            self.handle_task_error(task, error)
        self.progress_manager.update(1)

    # Forget which queues terminated so that a further pass polls them again.
    self.terminated_queue_set = set()

    if self.is_tasks_finished():
        return

    self.task_logger._log("DEBUG", f"Retrying tasks for '{self.func.__name__}'")
    self.terminate_task_queues()
    self.run_in_serial()
|
|
885
|
+
|
|
886
|
+
def run_with_executor(self, executor: ThreadPoolExecutor | ProcessPoolExecutor):
    """
    Execute the tasks in parallel on the given pool (thread pool or process pool).

    Tasks are submitted until the termination signal shows up; results
    are handled in done-callbacks. The method then waits until every
    submitted task and its callback have completed. If the counters show
    unfinished tasks afterwards (tasks re-queued for a retry), the queues
    are terminated again and the method calls itself for another pass.

    :param executor: thread pool or process pool
    """
    task_start_dict = {}  # task -> time at which it was submitted

    # Counter and event tracking the tasks that are currently in flight.
    in_flight = 0
    in_flight_lock = Lock()
    all_done_event = Event()
    all_done_event.set()  # nothing submitted yet, so start in the "done" state

    def on_task_done(future, task, progress_manager: ProgressManager):
        # Done-callback: handle the outcome of one task.
        nonlocal in_flight
        try:
            progress_manager.update(1)
            try:
                result = future.result()
                start_time = task_start_dict[task]
                self.process_task_success(task, result, start_time)
            except Exception as error:
                self.handle_task_error(task, error)
        finally:
            # Always give the slot back. If an exception escaped this callback
            # without the decrement, all_done_event would never be set and the
            # wait() below would block forever.
            with in_flight_lock:
                in_flight -= 1
                if in_flight == 0:
                    all_done_event.set()

    # Submit the tasks from the queues to the pool.
    while True:
        task = self.get_task_queues()
        task_id = self.get_task_id(task)
        self.task_logger._log(
            "TRACE", f"Task {task} is submitted to {self.func.__name__}"
        )

        if isinstance(task, TerminationSignal):
            # No further tasks are submitted after the termination signal.
            break
        elif self.is_duplicate(task_id):
            self.deal_dupliacte(task)
            self.progress_manager.update(1)
            continue
        self.processed_set.add(task_id)

        # A new task is about to be submitted: count it and clear the done event.
        with in_flight_lock:
            in_flight += 1
            all_done_event.clear()

        task_start_dict[task] = time.time()
        future = executor.submit(self.func, *self.get_args(task))
        future.add_done_callback(
            lambda f, t=task: on_task_done(f, t, self.progress_manager)
        )

    # Wait until all submitted tasks (including their callbacks) are done.
    all_done_event.wait()

    # Forget which queues terminated so that a further pass polls them again.
    self.terminated_queue_set = set()

    if not self.is_tasks_finished():
        self.task_logger._log("DEBUG", f"Retrying tasks for '{self.func.__name__}'")
        self.terminate_task_queues()
        self.run_with_executor(executor)
|
|
955
|
+
|
|
956
|
+
async def run_in_async(self):
    """
    Run queued tasks as coroutines with at most ``self.worker_limit`` running at once.

    Tasks are collected from ``get_task_queues_async()`` until a
    ``TerminationSignal`` arrives, then awaited together with
    ``asyncio.gather``. Each outcome is routed to the async success or error
    handler. If ``is_tasks_finished()`` is still false afterwards,
    ``terminate_task_queues_async()`` is awaited and the method runs again.
    """
    semaphore = asyncio.Semaphore(self.worker_limit)  # caps concurrency

    async def sem_task(task):
        async with semaphore:
            # Start the clock only once a concurrency slot is held, so time
            # spent waiting on the semaphore is not counted as run time.
            start_time = time.time()
            result = await self._run_single_task(task)
        return result, start_time

    # ``submitted[i]`` is the task behind ``async_tasks[i]``; gather keeps
    # that order, so outcomes can be matched back to tasks by position.
    submitted = []
    async_tasks = []

    while True:
        task = await self.get_task_queues_async()
        task_id = self.get_task_id(task)
        self.task_logger._log(
            "TRACE", f"Task {task} is submitted to {self.func.__name__}"
        )
        if isinstance(task, TerminationSignal):
            break
        elif self.is_duplicate(task_id):
            self.deal_dupliacte(task)
            self.progress_manager.update(1)
            continue
        self.processed_set.add(task_id)
        submitted.append(task)
        async_tasks.append(sem_task(task))  # semaphore-wrapped coroutine

    # Run everything concurrently.
    outcomes = await asyncio.gather(*async_tasks, return_exceptions=True)
    for task, outcome in zip(submitted, outcomes):
        if isinstance(outcome, BaseException):
            # With return_exceptions=True gather may hand back a bare
            # exception (e.g. a cancelled child); unpacking it as a tuple
            # would raise TypeError, so treat it as the task's error.
            await self.handle_task_error_async(task, outcome)
        else:
            result, start_time = outcome
            if isinstance(result, Exception):
                # _run_single_task returns the raised exception as a value.
                await self.handle_task_error_async(task, result)
            else:
                await self.process_task_success_async(task, result, start_time)
        self.progress_manager.update(1)

    self.terminated_queue_set = set()

    if not self.is_tasks_finished():
        self.task_logger._log("DEBUG", f"Retrying tasks for '{self.func.__name__}'")
        await self.terminate_task_queues_async()
        await self.run_in_async()
|
|
1002
|
+
|
|
1003
|
+
async def _run_single_task(self, task):
    """
    Await ``self.func`` for a single task.

    Any ``Exception`` raised while building the arguments or awaiting the
    call is returned as the result instead of being propagated.
    """
    try:
        return await self.func(*self.get_args(task))
    except Exception as exc:
        return exc
|
|
1012
|
+
|
|
1013
|
+
def get_success_dict(self) -> dict:
    """
    Return the successful-task mapping as a new plain ``dict``.
    """
    snapshot = dict(self.success_dict)
    return snapshot
|
|
1018
|
+
|
|
1019
|
+
def get_error_dict(self) -> dict:
    """
    Return the failed-task mapping as a new plain ``dict``.
    """
    snapshot = dict(self.error_dict)
    return snapshot
|
|
1024
|
+
|
|
1025
|
+
def release_queue(self):
    """
    Clean up by dropping the references to the task, result and fail queues.
    """
    # Chained assignment binds the targets left to right.
    self.task_queues = self.result_queues = self.fail_queue = None
|
|
1032
|
+
|
|
1033
|
+
def release_pool(self):
    """
    Shut down the thread pool and the process pool and release them.

    Each pool that is set is shut down with ``wait=True``; both attributes
    are then reset to ``None``.
    """
    for executor in (self.thread_pool, self.process_pool):
        if not executor:
            continue
        executor.shutdown(wait=True)
    self.thread_pool = self.process_pool = None
|
|
1042
|
+
|
|
1043
|
+
def test_method(self, execution_mode: str, task_list: list) -> float:
    """
    Time one full run of ``task_list`` under the given execution mode.

    The measurement covers switching the mode, re-initialising counters and
    state, and the ``start`` call itself.

    :param execution_mode: mode name passed to ``set_execution_mode``
    :param task_list: tasks passed to ``start``
    :return: elapsed time in seconds
    """
    # perf_counter is monotonic and high-resolution, so the measured
    # interval cannot be distorted by wall-clock adjustments.
    begin = time.perf_counter()
    self.set_execution_mode(execution_mode)
    self.init_counter()
    self.init_state()
    self.start(task_list)
    return time.perf_counter() - begin
|
|
1053
|
+
|
|
1054
|
+
def test_methods(
    self, task_source: Iterable, execution_modes: list | None = None
) -> tuple[list, list, list]:
    """
    Time the same batch of tasks under several execution modes.

    :param task_source: iterable of tasks; it is materialised into a list
        once so that generators and other one-shot iterables give every
        mode the same data
    :param execution_modes: modes to test; defaults to
        ``["serial", "thread", "process"]`` when omitted or empty
    :return: a 3-tuple ``(results, execution_modes, ["Time"])`` where
        ``results`` holds one single-element ``[elapsed]`` row per mode, in
        the order of ``execution_modes``
    """
    task_list = list(task_source)
    execution_modes = execution_modes or ["serial", "thread", "process"]

    results = [[self.test_method(mode, task_list)] for mode in execution_modes]
    return results, execution_modes, ["Time"]
|
|
1068
|
+
|
|
1069
|
+
def __exit__(self, exc_type, exc_val, exc_tb):
    """
    Release the queue references when the ``with`` block exits.

    Returns ``None``, so an exception raised inside the block propagates.
    """
    self.release_queue()
    return None
|