celestialflow 3.0.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- celestialflow/__init__.py +39 -0
- celestialflow/task_graph.py +665 -0
- celestialflow/task_logging.py +154 -0
- celestialflow/task_manage.py +1070 -0
- celestialflow/task_nodes.py +160 -0
- celestialflow/task_progress.py +57 -0
- celestialflow/task_report.py +162 -0
- celestialflow/task_structure.py +151 -0
- celestialflow/task_tools.py +501 -0
- celestialflow/task_types.py +61 -0
- celestialflow/task_web.py +170 -0
- celestialflow-3.0.1.dist-info/METADATA +301 -0
- celestialflow-3.0.1.dist-info/RECORD +15 -0
- celestialflow-3.0.1.dist-info/WHEEL +5 -0
- celestialflow-3.0.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,501 @@
|
|
|
1
|
+
import json, ast
|
|
2
|
+
import hashlib
|
|
3
|
+
import pickle
|
|
4
|
+
import networkx as nx
|
|
5
|
+
from networkx import is_directed_acyclic_graph
|
|
6
|
+
from itertools import zip_longest
|
|
7
|
+
from collections import defaultdict
|
|
8
|
+
from datetime import datetime
|
|
9
|
+
from multiprocessing import Queue as MPQueue
|
|
10
|
+
from asyncio import Queue as AsyncQueue
|
|
11
|
+
from queue import Queue as ThreadQueue
|
|
12
|
+
from pathlib import Path
|
|
13
|
+
from queue import Empty
|
|
14
|
+
from asyncio import QueueEmpty as AsyncQueueEmpty
|
|
15
|
+
from typing import TYPE_CHECKING, Dict, Any, List, Set, Optional
|
|
16
|
+
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
from .task_manage import TaskManager
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
# ========调用于task_graph.py========
|
|
22
|
+
def format_duration(seconds):
    """Render a duration as ``HH:MM:SS``, or as ``MM:SS`` when the hours part is not positive.

    The input is truncated to whole seconds with ``int()`` before splitting.

    :param seconds: Duration in seconds (int or float).
    :return: The formatted duration string.
    """
    total_seconds = int(seconds)
    hrs, leftover = divmod(total_seconds, 3600)
    mins, secs = divmod(leftover, 60)

    # The hours field is only shown once it is positive.
    if hrs <= 0:
        return f"{mins:02d}:{secs:02d}"
    return f"{hrs:02d}:{mins:02d}:{secs:02d}"
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def format_timestamp(timestamp) -> str:
    """Format a POSIX timestamp as ``YYYY-MM-DD HH:MM:SS`` in local time.

    :param timestamp: Seconds since the epoch, as accepted by ``datetime.fromtimestamp``.
    :return: The formatted date-time string.
    """
    moment = datetime.fromtimestamp(timestamp)
    return f"{moment:%Y-%m-%d %H:%M:%S}"
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def build_structure_graph(root_stages: List["TaskManager"]) -> List[Dict[str, Any]]:
    """Build one nested structure dict per root stage.

    A single set of already-seen stage tags is shared across all roots.
    A stage reachable from several roots is therefore expanded only the
    first time it is met. Later occurrences come back from
    ``_build_structure_subgraph`` as leaves flagged ``"visited": True``.

    :param root_stages: Root stages to start from, in output order.
    :return: One nested dict per root stage.
    """
    seen_tags: Set[str] = set()
    return [_build_structure_subgraph(stage, seen_tags) for stage in root_stages]
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def _build_structure_subgraph(
|
|
57
|
+
task_manager: "TaskManager", visited_stages: Set[str]
|
|
58
|
+
) -> Dict[str, Any]:
|
|
59
|
+
"""
|
|
60
|
+
构建单个子图结构
|
|
61
|
+
"""
|
|
62
|
+
stage_tag = task_manager.get_stage_tag()
|
|
63
|
+
node = {
|
|
64
|
+
"stage_name": task_manager.stage_name,
|
|
65
|
+
"stage_mode": task_manager.stage_mode,
|
|
66
|
+
"func_name": task_manager.func.__name__,
|
|
67
|
+
"visited": False,
|
|
68
|
+
"next_stages": [],
|
|
69
|
+
}
|
|
70
|
+
|
|
71
|
+
if stage_tag in visited_stages:
|
|
72
|
+
node["visited"] = True
|
|
73
|
+
return node
|
|
74
|
+
|
|
75
|
+
visited_stages.add(stage_tag)
|
|
76
|
+
|
|
77
|
+
for next_stage in task_manager.next_stages:
|
|
78
|
+
child_node = _build_structure_subgraph(next_stage, visited_stages)
|
|
79
|
+
node["next_stages"].append(child_node)
|
|
80
|
+
|
|
81
|
+
return node
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def format_structure_list_from_graph(
    root_roots: List[Dict] = None, indent=0
) -> List[str]:
    """Render structure-graph roots as a list of text lines inside an ASCII box.

    Each node becomes ``"<name> (stage_mode: <mode>, func: <func>)"``, with
    ``" [Visited]"`` appended when its ``visited`` flag is set. Each child
    line is prefixed by an arrow that is indented by the parent's depth; one
    nesting level adds two spaces. Consecutive roots are separated by a blank
    line.

    :param root_roots: Root node dicts (as produced by ``build_structure_graph``).
    :param indent: Indentation for the arrows of the roots' direct children.
    :return: Boxed lines, or ``["+ No stages defined +"]`` when there is nothing to show.
    """

    def render(node: Dict, depth: int) -> List[str]:
        # Render one node followed by its descendants.
        suffix = " [Visited]" if node.get("visited") else ""
        rendered = [
            f"{node['stage_name']} (stage_mode: {node['stage_mode']}, func: {node['func_name']}){suffix}"
        ]
        arrow = " " * depth + "╘-->"
        for child in node.get("next_stages", []):
            child_lines = render(child, depth + 2)
            rendered.append(arrow + child_lines[0])
            rendered.extend(child_lines[1:])
        return rendered

    body: List[str] = []
    for position, root in enumerate(root_roots or []):
        if position:
            body.append("")  # blank line between roots
        body += render(root, indent)

    if not body:
        return ["+ No stages defined +"]

    width = max(map(len, body))
    edge = "+" + "-" * (width + 2) + "+"
    return [edge, *(f"| {text.ljust(width)} |" for text in body), edge]
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
def append_jsonl_log(
    log_data: dict, start_time: float, base_path: str, prefix: str, logger=None
):
    """Append one record as a JSON line to a per-run log file.

    The target is ``<base_path>/<YYYY-MM-DD>/<prefix>(<HH-MM-SS-mmm>).jsonl``.
    Both the date directory and the time part come from ``start_time`` in
    local time, so calls sharing ``start_time`` and ``prefix`` append to the
    same file. Missing directories are created. Non-ASCII text is written
    as-is (``ensure_ascii=False``).

    Any exception raised while building the path or writing is swallowed.
    If ``logger`` is truthy, a warning is reported through
    ``logger._log("WARNING", ...)``.

    :param log_data: Record to write; must be JSON-serialisable.
    :param start_time: Run start timestamp used to build the path.
    :param base_path: Root directory, e.g. ``'./fallback'``.
    :param prefix: File-name prefix, e.g. ``'realtime_errors'``.
    :param logger: Optional object exposing ``_log(level, message)``.
    """
    try:
        started = datetime.fromtimestamp(start_time)
        day_dir = started.strftime("%Y-%m-%d")
        # %f is microseconds; dropping its last three digits leaves milliseconds.
        stamp = started.strftime("%H-%M-%S-%f")[:-3]
        target = Path(base_path) / day_dir / f"{prefix}({stamp}).jsonl"
        target.parent.mkdir(parents=True, exist_ok=True)

        with target.open("a", encoding="utf-8") as handle:
            handle.write(json.dumps(log_data, ensure_ascii=False) + "\n")
    except Exception as e:
        if logger:
            logger._log("WARNING", f"[Persist] 写入日志失败: {e}")
|
|
147
|
+
|
|
148
|
+
|
|
149
|
+
def cluster_by_value_sorted(input_dict: Dict[str, int]) -> Dict[int, List[str]]:
    """Invert a ``{name: value}`` mapping into ``{value: [names, ...]}``.

    Names that share a value keep their original insertion order. The
    returned dict is ordered by value, ascending.

    :param input_dict: Mapping from name to an orderable value (e.g. a level).
    :return: A new dict mapping each distinct value to the names that had it.
    """
    grouped: Dict[int, List[str]] = {}
    for name, value in input_dict.items():
        grouped.setdefault(value, []).append(name)
    return dict(sorted(grouped.items()))
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
# ========(图论分析)========
|
|
163
|
+
def format_networkx_graph(structure_graph: List[Dict[str, Any]]) -> nx.DiGraph:
    """Convert a structure graph from ``build_structure_graph`` into a ``networkx.DiGraph``.

    Each node is identified by ``"<stage_name>[<func_name>]"`` and gets a
    ``mode`` attribute copied from ``stage_mode``. Every parent-to-child
    link in the nested dicts becomes a directed edge.

    :param structure_graph: List of root node dicts.
    :return: The populated directed graph.
    """
    graph = nx.DiGraph()

    def node_key(entry: Dict[str, Any]) -> str:
        # Stable identifier combining stage name and function name.
        return f'{entry["stage_name"]}[{entry["func_name"]}]'

    def visit(entry: Dict[str, Any]) -> None:
        parent = node_key(entry)
        graph.add_node(parent, mode=entry.get("stage_mode"))
        for child in entry.get("next_stages", []):
            graph.add_edge(parent, node_key(child))
            visit(child)

    for root in structure_graph:
        visit(root)
    return graph
|
|
186
|
+
|
|
187
|
+
|
|
188
|
+
def compute_node_levels(G: nx.DiGraph) -> Dict[str, int]:
    """Assign each node of a DAG its level (earliest execution stage).

    A node's level is the length of the longest path reaching it from any
    node without predecessors. Nodes with no predecessors are level 0.

    :param G: A directed acyclic graph.
    :return: ``{node: level}``, keyed in ``G.nodes`` order.
    :raises ValueError: If ``G`` is not a DAG.
    """
    if not nx.is_directed_acyclic_graph(G):
        raise ValueError("该图不是 DAG,无法进行层级划分")

    depth: Dict[str, int] = {}
    # In topological order, every predecessor is resolved before the node itself.
    for vertex in nx.topological_sort(G):
        depth[vertex] = max((depth[p] + 1 for p in G.predecessors(vertex)), default=0)

    # Re-key in G.nodes order so callers iterating the result see the same order.
    return {vertex: depth[vertex] for vertex in G.nodes}
|
|
205
|
+
|
|
206
|
+
|
|
207
|
+
# ========调用于task_manage.py========
|
|
208
|
+
def is_queue_empty(q: ThreadQueue) -> bool:
    """Return whether ``q`` currently holds no items, without modifying it.

    Like any emptiness test on a queue shared between threads or processes,
    the answer may already be stale when it is returned.

    :param q: A queue exposing ``empty()`` (e.g. ``queue.Queue`` or ``multiprocessing.Queue``).
    :return: ``True`` if the queue is empty, otherwise ``False``.
    """
    # Do not probe by popping an item and putting it back: the item would move
    # to the tail of the queue (breaking FIFO order, e.g. behind a termination
    # sentinel), and queue.Queue.put() would bump unfinished_tasks without a
    # matching task_done(), so q.join() could never return.
    return q.empty()
|
|
218
|
+
|
|
219
|
+
|
|
220
|
+
async def is_queue_empty_async(q: AsyncQueue) -> bool:
    """Return whether the asyncio queue ``q`` currently holds no items, without modifying it.

    Kept as a coroutine for compatibility with existing ``await`` call sites.

    :param q: An ``asyncio.Queue``.
    :return: ``True`` if the queue is empty, otherwise ``False``.
    """
    # Popping and re-putting an item would move it to the tail of the queue
    # and increment the unfinished-task count with no matching task_done(),
    # which can make q.join() wait forever.
    return q.empty()
|
|
230
|
+
|
|
231
|
+
|
|
232
|
+
def are_queues_empty(queues: List[ThreadQueue]) -> bool:
    """Return ``True`` only if every queue in ``queues`` is empty.

    Checking stops at the first non-empty queue. An empty list yields ``True``.

    :param queues: Synchronous queues to inspect via ``is_queue_empty``.
    :return: Whether all queues are empty.
    """
    return all(is_queue_empty(each) for each in queues)
|
|
241
|
+
|
|
242
|
+
|
|
243
|
+
async def are_queues_empty_async(queues: List[AsyncQueue]) -> bool:
    """Return ``True`` only if every asyncio queue in ``queues`` is empty.

    Queues are checked one at a time with ``is_queue_empty_async``, stopping
    at the first non-empty one. An empty list yields ``True``.

    :param queues: Asyncio queues to inspect.
    :return: Whether all queues are empty.
    """
    for pending in queues:
        is_empty = await is_queue_empty_async(pending)
        if not is_empty:
            break
    else:
        return True
    return False
|
|
252
|
+
|
|
253
|
+
|
|
254
|
+
def format_repr(obj: Any, max_length: int) -> str:
    """Return ``str(obj)`` escaped onto one line and optionally shortened.

    Backslashes are doubled and newlines become the two characters
    backslash-n, so the result stays on a single line. If ``max_length`` is
    positive and the escaped text is longer, only the first
    ``int(max_length * 2 / 3)`` and last ``int(max_length / 3)`` characters
    are kept, joined by ``"..."``.

    :param obj: Any object; converted with ``str()``.
    :param max_length: Maximum number of characters to keep; ``<= 0`` disables truncation.
    :return: The formatted string.
    """
    obj_str = str(obj).replace("\\", "\\\\").replace("\n", "\\n")
    if max_length <= 0 or len(obj_str) <= max_length:
        return obj_str

    # Truncate: first 2/3 + "..." + last 1/3.
    head_len = int(max_length * 2 / 3)
    tail_len = int(max_length / 3)
    first_part = obj_str[:head_len]
    # Guard tail_len == 0: obj_str[-0:] is the *whole* string, which would make
    # the "truncated" result longer than the input.
    last_part = obj_str[-tail_len:] if tail_len > 0 else ""
    return f"{first_part}...{last_part}"
|
|
269
|
+
|
|
270
|
+
|
|
271
|
+
def object_to_str_hash(obj) -> str:
    """Return the hex MD5 digest of ``pickle.dumps(obj)``.

    NOTE(review): the digest is only as stable as the pickle byte stream.
    Equal objects may pickle differently (e.g. sets of strings, whose
    iteration order can vary between processes). Confirm this is acceptable
    for the callers.

    :param obj: Any picklable object.
    :return: A 32-character lowercase hex string.
    """
    digest = hashlib.md5()
    digest.update(pickle.dumps(obj))
    return digest.hexdigest()
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
# ========公共函数========
|
|
280
|
+
def _sorted_tuple(items: list) -> tuple:
    """Sort ``items`` into a tuple, falling back to a type/repr key when they are not mutually orderable."""
    try:
        return tuple(sorted(items))
    except TypeError:
        # Mixed, unorderable elements (e.g. int and str dict keys): order by
        # type name, then repr, so the result normally does not depend on the
        # container's insertion/iteration order.
        return tuple(sorted(items, key=lambda item: (type(item).__name__, repr(item))))


def make_hashable(obj):
    """Convert ``obj`` into a hashable, order-normalised equivalent.

    - ``list`` / ``tuple``: tuple of converted elements, order kept.
    - ``dict``: tuple of converted ``(key, value)`` pairs, sorted.
    - ``set``: tuple of converted elements, sorted.
    - Anything else is returned unchanged.

    Sorting makes equal dicts and sets map to equal results regardless of
    insertion order. When elements cannot be compared with each other (e.g. a
    dict with both ``int`` and ``str`` keys), they are ordered by
    ``(type name, repr)`` instead of raising ``TypeError``.

    Note that different containers can map to the same tuple: ``[(1, 2)]``
    and ``{1: 2}`` both become ``((1, 2),)``.
    """
    if isinstance(obj, (tuple, list)):
        return tuple(make_hashable(e) for e in obj)
    if isinstance(obj, dict):
        # Materialise first so a failed sort attempt can be retried.
        return _sorted_tuple([(make_hashable(k), make_hashable(v)) for k, v in obj.items()])
    if isinstance(obj, set):
        return _sorted_tuple([make_hashable(e) for e in obj])
    # Scalars and other (assumed hashable) objects are returned as-is.
    return obj
|
|
297
|
+
|
|
298
|
+
|
|
299
|
+
def cleanup_mpqueue(queue: MPQueue):
    """Close a ``multiprocessing.Queue`` and wait for its background feeder thread.

    ``close()`` indicates that this process will put no more data on the
    queue. ``join_thread()`` then blocks until the feeder thread has flushed
    buffered data to the pipe and exited.

    :param queue: The multiprocessing queue to shut down.
    """
    queue.close()
    # join_thread() may only be called after close().
    queue.join_thread()
|
|
305
|
+
|
|
306
|
+
|
|
307
|
+
def format_table(
    data: list,
    row_names: list = None,
    column_names: list = None,
    index_header: str = "#",
    fill_value: str = "N/A",
    align: str = "left",
) -> str:
    """Render rows of data as a bordered plain-text table.

    Simplified version of the function of the same name in ``CelestialVault.TextTools``.

    :param data: Sequence of rows; each row is a sequence of cell values.
        Rows may differ in length, and missing cells are filled with ``fill_value``.
    :param row_names: Labels for the leading index column. Missing labels
        default to the row's position. The caller's sequence is not modified.
    :param column_names: Column headers. Missing headers are generated in
        Excel style (A, B, ..., Z, AA, ...), continuing after the supplied
        ones. Extra headers beyond the data width produce columns filled with
        ``fill_value``. The caller's sequence is not modified.
    :param index_header: Header text of the index column.
    :param fill_value: Placeholder for cells a row does not provide.
    :param align: ``"left"``, ``"right"`` or ``"center"``; anything else falls back to left.
    :return: The table as one string, or ``"表格数据为空!"`` when ``data`` is empty.
    """

    def _generate_excel_column_names(n: int, start_index: int = 0) -> list[str]:
        """Return ``n`` Excel-style column names starting at zero-based ``start_index``."""
        names = []
        for i in range(start_index, start_index + n):
            name = ""
            x = i
            while True:
                # Bijective base-26: the "- 1" makes index 26 map to "AA" rather than "BA".
                name = chr(ord("A") + (x % 26)) + name
                x = x // 26 - 1
                if x < 0:
                    break
            names.append(name)
        return names

    if not data:
        return "表格数据为空!"

    max_cols = max(map(len, data))

    # Work on copies so caller-owned lists are never mutated and tuples are accepted.
    if column_names is None:
        column_names = _generate_excel_column_names(max_cols)
    else:
        column_names = list(column_names)
        if len(column_names) < max_cols:
            start = len(column_names)  # continue naming after the supplied headers
            column_names.extend(_generate_excel_column_names(max_cols - start, start))

    if row_names is None:
        row_names = list(range(len(data)))
    else:
        row_names = list(row_names)
        if len(row_names) < len(data):
            row_names.extend(range(len(row_names), len(data)))

    # Prepend the index column.
    column_names = [index_header] + column_names
    num_columns = len(column_names)

    # Prefix each row with its label and pad it to the full column count.
    # Padding to num_columns (not just the longest row) keeps everything aligned
    # when more column names than data columns were supplied.
    formatted_data = []
    for label, row in zip(row_names, data):
        cells = [label] + list(row)
        cells.extend([fill_value] * (num_columns - len(cells)))
        formatted_data.append(cells)

    # Width of each column = longest rendered cell, header included.
    col_widths = [
        max(len(str(item)) for item in col)
        for col in zip(column_names, *formatted_data)
    ]

    align_funcs = {
        "left": lambda text, width: f"{text:<{width}}",
        "right": lambda text, width: f"{text:>{width}}",
        "center": lambda text, width: f"{text:^{width}}",
    }
    align_func = align_funcs.get(align, align_funcs["left"])  # default: left

    separator = "+" + "+".join("-" * (width + 2) for width in col_widths) + "+"
    header = (
        "| "
        + " | ".join(
            align_func(str(name), width) for name, width in zip(column_names, col_widths)
        )
        + " |"
    )
    rows = "\n".join(
        "| "
        + " | ".join(align_func(str(cell), width) for cell, width in zip(row, col_widths))
        + " |"
        for row in formatted_data
    )

    return f"{separator}\n{header}\n{separator}\n{rows}\n{separator}"
|
|
416
|
+
|
|
417
|
+
|
|
418
|
+
# ========外部调用========
|
|
419
|
+
def load_jsonl_grouped_by_keys(
    jsonl_path: str,
    group_keys: List[str],
    extract_fields: Optional[List[str]] = None,
    eval_fields: Optional[List[str]] = None,
    skip_if_missing: bool = True,
) -> Dict[str, List[Any]]:
    """Load a JSONL file and group its records by one or more fields.

    Lines that are not valid JSON, or whose JSON value is not an object, are skipped.

    :param jsonl_path: Path to the JSONL file (read as UTF-8).
    :param group_keys: Field names to group by, e.g. ``['error', 'stage']``.
    :param extract_fields: Fields to keep from each record. With one field the
        bare value is stored. With several, a dict of the present fields is
        stored. When empty or None, the whole record is stored.
    :param eval_fields: Fields whose values are parsed with ``ast.literal_eval``;
        values that fail to parse are kept unchanged.
    :param skip_if_missing: If True, drop records lacking any group key or any
        extract field. If False, keep them: a missing group key contributes
        ``""`` and a missing single extract field yields ``None``.
    :return: ``{group_key: [values, ...]}``. With exactly one group key, its
        raw value is used as the key (converted with ``str()`` if unhashable).
        Otherwise the key is the string ``"(v1, v2, ...)"``.
    """
    result_dict = defaultdict(list)

    with open(jsonl_path, "r", encoding="utf-8") as f:
        for line in f:
            try:
                item = json.loads(line)
            except ValueError:  # json.JSONDecodeError: blank or malformed line
                continue
            if not isinstance(item, dict):
                continue  # only JSON objects carry named fields

            # Ensure every group key is present.
            if skip_if_missing and any(k not in item for k in group_keys):
                continue

            # Build the group key: a single value is used as-is; any other
            # count (including zero) is rendered as "(v1, v2, ...)".
            group_values = tuple(item.get(k, "") for k in group_keys)
            if len(group_values) == 1:
                group_key = group_values[0]
                try:
                    hash(group_key)
                except TypeError:
                    # e.g. a list-valued field: fall back to its text form.
                    group_key = str(group_key)
            else:
                group_key = f"({', '.join(map(str, group_values))})"

            # Deserialise only the requested fields.
            if eval_fields:
                for key in eval_fields:
                    if key in item:
                        try:
                            item[key] = ast.literal_eval(item[key])
                        except Exception:
                            pass  # best-effort: keep the raw value if it does not parse

            # Extract the stored value.
            if extract_fields:
                if skip_if_missing and any(k not in item for k in extract_fields):
                    continue
                if len(extract_fields) == 1:
                    value = item.get(extract_fields[0])
                else:
                    value = {k: item[k] for k in extract_fields if k in item}
            else:
                value = item

            result_dict[group_key].append(value)

    return dict(result_dict)
|
|
481
|
+
|
|
482
|
+
|
|
483
|
+
def load_task_by_stage(jsonl_path):
    """Load error records from a JSONL file, grouped by their ``stage`` field.

    Each group holds the records' ``task`` values, parsed with
    ``ast.literal_eval`` where possible. Records missing ``stage`` or
    ``task`` are skipped (the default ``skip_if_missing=True``).

    :param jsonl_path: Path to the error-record JSONL file.
    :return: ``{stage: [task, ...]}``.
    """
    task_field = "task"
    return load_jsonl_grouped_by_keys(
        jsonl_path=jsonl_path,
        group_keys=["stage"],
        extract_fields=[task_field],
        eval_fields=[task_field],
    )
|
|
490
|
+
|
|
491
|
+
|
|
492
|
+
def load_task_by_error(jsonl_path):
    """Load error records from a JSONL file, grouped by ``error`` and ``stage`` together.

    Group keys are strings of the form ``"(<error>, <stage>)"``. Each group
    holds the records' ``task`` values, parsed with ``ast.literal_eval``
    where possible. Records missing any of the fields are skipped (the
    default ``skip_if_missing=True``).

    :param jsonl_path: Path to the error-record JSONL file.
    :return: ``{"(error, stage)": [task, ...]}``.
    """
    task_field = "task"
    return load_jsonl_grouped_by_keys(
        jsonl_path=jsonl_path,
        group_keys=["error", "stage"],
        extract_fields=[task_field],
        eval_fields=[task_field],
    )
|
|
@@ -0,0 +1,61 @@
|
|
|
1
|
+
from enum import IntEnum
|
|
2
|
+
from typing import List
|
|
3
|
+
from multiprocessing import Value as MPValue
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
class TerminationSignal:
    """Sentinel type used to mark the termination of a task queue.

    A single shared instance, ``TERMINATION_SIGNAL``, is created at module level.
    """


# Module-wide singleton termination sentinel.
TERMINATION_SIGNAL = TerminationSignal()
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class TaskError(Exception):
    """Exception raised to signal that a task failed during execution."""
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class NoOpContext:
    """Context manager that does nothing, for disabling ``with`` logic.

    ``__enter__`` returns the instance itself. ``__exit__`` returns ``None``,
    so exceptions raised inside the ``with`` body propagate unchanged.
    """

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Falsy return value: never suppress exceptions.
        return None
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class ValueWrapper:
    """Plain holder exposing a mutable ``value`` attribute.

    It mirrors the ``.value`` interface of ``multiprocessing.Value``, so it
    can stand in wherever only ``.value`` is read (as ``SumCounter.value``
    does). It is an ordinary Python object, however, with no shared memory
    or locking of its own.
    """

    def __init__(self, value=0):
        self.value = value  # starting value, 0 by default
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class SumCounter:
    """Counter whose value is a shared base plus the sum of attached counters.

    The base, ``init_value``, is a ``multiprocessing.Value('i')``. Additional
    counters are any objects with a numeric ``.value`` (e.g. ``ValueWrapper``
    or ``multiprocessing.Value``) registered through ``add_counter``.
    """

    def __init__(self):
        self.init_value = MPValue("i", 0)
        self.counters: List[ValueWrapper] = []

    def add_init_value(self, value):
        """Atomically add ``value`` to the shared base count."""
        # ``+=`` on a Value is a separate read and write. Without holding its
        # lock, concurrent increments from other processes or threads can be lost.
        with self.init_value.get_lock():
            self.init_value.value += value

    def add_counter(self, counter):
        """Register an object whose ``.value`` is included in the total."""
        self.counters.append(counter)

    @property
    def value(self):
        """Current total: the base count plus every registered counter's ``.value``."""
        return self.init_value.value + sum(c.value for c in self.counters)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class StageStatus(IntEnum):
    """Lifecycle state of a stage.

    This is an ``IntEnum``, so members compare equal to their integer values.
    """

    NOT_STARTED = 0  # the stage has not started yet
    RUNNING = 1  # the stage is running
    STOPPED = 2  # the stage has stopped
|