celestialflow 3.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,501 @@
1
+ import json, ast
2
+ import hashlib
3
+ import pickle
4
+ import networkx as nx
5
+ from networkx import is_directed_acyclic_graph
6
+ from itertools import zip_longest
7
+ from collections import defaultdict
8
+ from datetime import datetime
9
+ from multiprocessing import Queue as MPQueue
10
+ from asyncio import Queue as AsyncQueue
11
+ from queue import Queue as ThreadQueue
12
+ from pathlib import Path
13
+ from queue import Empty
14
+ from asyncio import QueueEmpty as AsyncQueueEmpty
15
+ from typing import TYPE_CHECKING, Dict, Any, List, Set, Optional
16
+
17
+ if TYPE_CHECKING:
18
+ from .task_manage import TaskManager
19
+
20
+
21
+ # ========调用于task_graph.py========
22
def format_duration(seconds):
    """Format a duration in seconds as ``HH:MM:SS``, or ``MM:SS`` when under an hour.

    Fractional seconds are truncated toward zero (``int()``). Negative
    durations are formatted from their absolute value with a leading ``-``;
    feeding a negative number straight into ``divmod`` would floor toward
    minus infinity and turn e.g. ``-5`` into ``"59:55"``.

    :param seconds: duration in seconds (int or float)
    :return: formatted string such as ``"01:02:03"`` or ``"02:03"``
    """
    total = int(seconds)
    sign = "-" if total < 0 else ""
    hours, remainder = divmod(abs(total), 3600)
    minutes, secs = divmod(remainder, 60)

    if hours > 0:
        return f"{sign}{hours:02d}:{minutes:02d}:{secs:02d}"
    return f"{sign}{minutes:02d}:{secs:02d}"
32
+
33
+
34
def format_timestamp(timestamp) -> str:
    """Render a POSIX timestamp as local time in ``YYYY-MM-DD HH:MM:SS`` form."""
    moment = datetime.fromtimestamp(timestamp)
    return moment.strftime("%Y-%m-%d %H:%M:%S")
37
+
38
+
39
def build_structure_graph(root_stages: List["TaskManager"]) -> List[Dict[str, Any]]:
    """Build one JSON-style structure tree per root stage.

    A single ``seen`` set of stage tags is shared across all roots. A stage
    reachable from several roots is therefore expanded only the first time it
    is met; later occurrences come back from ``_build_structure_subgraph`` as
    ``visited`` stubs.

    :param root_stages: root task managers, in the order their trees should appear
    :return: list of tree dicts, one per root
    """
    seen: Set[str] = set()
    return [_build_structure_subgraph(root, seen) for root in root_stages]
54
+
55
+
56
+ def _build_structure_subgraph(
57
+ task_manager: "TaskManager", visited_stages: Set[str]
58
+ ) -> Dict[str, Any]:
59
+ """
60
+ 构建单个子图结构
61
+ """
62
+ stage_tag = task_manager.get_stage_tag()
63
+ node = {
64
+ "stage_name": task_manager.stage_name,
65
+ "stage_mode": task_manager.stage_mode,
66
+ "func_name": task_manager.func.__name__,
67
+ "visited": False,
68
+ "next_stages": [],
69
+ }
70
+
71
+ if stage_tag in visited_stages:
72
+ node["visited"] = True
73
+ return node
74
+
75
+ visited_stages.add(stage_tag)
76
+
77
+ for next_stage in task_manager.next_stages:
78
+ child_node = _build_structure_subgraph(next_stage, visited_stages)
79
+ node["next_stages"].append(child_node)
80
+
81
+ return node
82
+
83
+
84
def format_structure_list_from_graph(
    root_roots: List[Dict] = None, indent=0
) -> List[str]:
    """Render structure trees (from ``build_structure_graph``) as framed text lines.

    Each node becomes ``"<stage_name> (stage_mode: <mode>, func: <func>)"``,
    with ``" [Visited]"`` appended for stub nodes. A child's first line is
    prefixed with the parent's indent in spaces plus ``"╘-->"``. Each deeper
    level adds two more spaces. Separate roots are divided by an empty line,
    and the whole listing is framed with ``+---+`` / ``| ... |`` borders.

    :param root_roots: list of root node dicts; ``None`` or empty yields a placeholder
    :param indent: indentation used for the first level of child arrows
    :return: framed lines, or ``["+ No stages defined +"]`` when there is nothing to show
    """

    def render(node: Dict, depth: int) -> List[str]:
        suffix = " [Visited]" if node.get("visited") else ""
        out = [
            f"{node['stage_name']} (stage_mode: {node['stage_mode']}, "
            f"func: {node['func_name']}){suffix}"
        ]
        arrow = " " * depth + "╘-->"
        for child in node.get("next_stages", []):
            head, *tail = render(child, depth + 2)
            out.append(arrow + head)
            out.extend(tail)
        return out

    body: List[str] = []
    for position, root in enumerate(root_roots or []):
        if position:
            body.append("")  # blank separator line between roots
        body.extend(render(root, indent))

    if not body:
        return ["+ No stages defined +"]

    width = max(map(len, body))
    frame = "+" + "-" * (width + 2) + "+"
    return [frame, *(f"| {text.ljust(width)} |" for text in body), frame]
122
+
123
+
124
def append_jsonl_log(
    log_data: dict, start_time: float, base_path: str, prefix: str, logger=None
):
    """Append ``log_data`` as one JSON line to a per-run JSONL file.

    The target is ``<base_path>/<YYYY-MM-DD>/<prefix>(<HH-MM-SS-mmm>).jsonl``.
    The date and time come from ``start_time`` in local time, so all calls
    sharing a ``start_time`` append to the same file. Missing directories are
    created. The write is best-effort: any exception is swallowed. If
    ``logger`` is given, the failure is reported via
    ``logger._log("WARNING", ...)``.

    :param log_data: dict to write; non-ASCII characters are kept as-is
    :param start_time: POSIX timestamp of the run start, used to build the path
    :param base_path: root directory, e.g. ``'./fallback'``
    :param prefix: file-name prefix, e.g. ``'realtime_errors'``
    :param logger: optional object exposing ``_log(level, message)``
    """
    try:
        started = datetime.fromtimestamp(start_time)
        day_dir = Path(base_path) / started.strftime("%Y-%m-%d")
        stamp = started.strftime("%H-%M-%S-%f")[:-3]  # microseconds -> milliseconds
        target = day_dir / f"{prefix}({stamp}).jsonl"
        target.parent.mkdir(parents=True, exist_ok=True)

        with open(target, "a", encoding="utf-8") as handle:
            record = json.dumps(log_data, ensure_ascii=False)
            handle.write(record + "\n")
    except Exception as exc:
        if logger:
            logger._log("WARNING", f"[Persist] 写入日志失败: {exc}")
147
+
148
+
149
def cluster_by_value_sorted(input_dict: Dict[str, int]) -> Dict[int, List[str]]:
    """Invert ``input_dict`` into ``{value: [keys having that value]}``.

    Keys inside each group keep the insertion order of ``input_dict``. The
    groups themselves are ordered by ascending value.
    """
    groups: Dict[int, List[str]] = {}
    for name, bucket in input_dict.items():
        groups.setdefault(bucket, []).append(name)
    return {bucket: groups[bucket] for bucket in sorted(groups)}
160
+
161
+
162
+ # ========(图论分析)========
163
def format_networkx_graph(structure_graph: List[Dict[str, Any]]) -> nx.DiGraph:
    """Convert structure trees (from ``build_structure_graph``) into a ``networkx.DiGraph``.

    Node ids are ``"<stage_name>[<func_name>]"``, and each node gets a ``mode``
    attribute taken from ``stage_mode``. Every parent -> child link in the
    trees becomes a directed edge. Tree nodes that produce the same id map to
    one graph node.

    :param structure_graph: list of root node dicts
    :return: the populated directed graph
    """
    graph = nx.DiGraph()

    def node_id(entry: Dict[str, Any]) -> str:
        return f'{entry["stage_name"]}[{entry["func_name"]}]'

    def visit(entry: Dict[str, Any]) -> None:
        current = node_id(entry)
        graph.add_node(current, mode=entry.get("stage_mode"))
        for child in entry.get("next_stages", []):
            graph.add_edge(current, node_id(child))
            visit(child)  # descend so the child's own attributes/edges are added

    for root in structure_graph:
        visit(root)
    return graph
186
+
187
+
188
def compute_node_levels(G: nx.DiGraph) -> Dict[str, int]:
    """Assign each DAG node its level: the longest path length reaching it from any source.

    Nodes without predecessors get level 0. Every other node sits one level
    below its deepest predecessor, i.e. the earliest stage at which it can
    execute.

    :param G: directed graph to analyse; must be acyclic
    :return: mapping node -> level, with keys in ``G.nodes`` order
    :raises ValueError: if ``G`` is not a DAG
    """
    if not nx.is_directed_acyclic_graph(G):
        raise ValueError("该图不是 DAG,无法进行层级划分")

    levels = dict.fromkeys(G.nodes, 0)
    # In topological order every predecessor already holds its final level.
    for node in nx.topological_sort(G):
        parent_levels = [levels[parent] for parent in G.predecessors(node)]
        if parent_levels:
            levels[node] = max(parent_levels) + 1
    return levels
205
+
206
+
207
+ # ========调用于task_manage.py========
208
def is_queue_empty(q: ThreadQueue) -> bool:
    """Return whether ``q`` currently holds no items, without modifying the queue.

    Popping an item with ``get_nowait()`` and pushing it back with ``put()``
    is not a safe emptiness probe. It has three side effects:

    - it moves the head item to the tail, so e.g. a termination sentinel can
      overtake pending work;
    - on ``queue.Queue`` it increments ``unfinished_tasks`` with no matching
      ``task_done()``, so ``join()`` never returns;
    - it can block if a bounded queue is refilled in between.

    ``empty()`` answers the same question non-destructively. As with any check
    on a concurrently used queue, the result is only a snapshot.

    :param q: a queue exposing ``empty()`` (``queue.Queue``, ``multiprocessing.Queue``, ...)
    :return: True if the queue had no items at the moment of the check
    """
    return q.empty()
218
+
219
+
220
async def is_queue_empty_async(q: AsyncQueue) -> bool:
    """Return whether the asyncio queue ``q`` is empty, without modifying it.

    Popping an item and awaiting ``put()`` to return it moves that item to the
    tail of the queue, reordering pending work. It also increments the queue's
    unfinished-task count without a matching ``task_done()``. ``empty()``
    answers the question without touching the queue. The function stays a
    coroutine so existing ``await`` call sites keep working.

    :param q: an ``asyncio.Queue``
    :return: True if the queue had no items at the moment of the check
    """
    return q.empty()
230
+
231
+
232
def are_queues_empty(queues: List[ThreadQueue]) -> bool:
    """Return True only if every queue in ``queues`` is empty (also True for no queues).

    Checking stops at the first non-empty queue; later queues are not inspected.
    """
    return all(is_queue_empty(queue) for queue in queues)
241
+
242
+
243
async def are_queues_empty_async(queues: List[AsyncQueue]) -> bool:
    """Async counterpart of ``are_queues_empty``: True only if every queue is empty.

    Queues are checked one at a time. Checking stops at the first non-empty
    queue, so later queues are not inspected.
    """
    for queue in queues:
        queue_is_empty = await is_queue_empty_async(queue)
        if queue_is_empty:
            continue
        return False
    return True
252
+
253
+
254
def format_repr(obj: Any, max_length: int) -> str:
    """Stringify ``obj`` on a single line, escaping it and truncating long text.

    In ``str(obj)``, backslashes are doubled and newlines are replaced by a
    literal backslash followed by ``n``. If ``max_length > 0`` and the escaped
    text is longer than that, the result keeps:

    - the first ``int(max_length * 2 / 3)`` characters, then ``"..."``,
    - then the last ``int(max_length / 3)`` characters.

    The result is therefore at most ``max_length + 3`` characters long.

    :param obj: any object
    :param max_length: truncation threshold; ``<= 0`` disables truncation
    :return: the escaped, possibly truncated string
    """
    obj_str = str(obj).replace("\\", "\\\\").replace("\n", "\\n")
    if max_length <= 0 or len(obj_str) <= max_length:
        return obj_str

    head_len = int(max_length * 2 / 3)
    tail_len = int(max_length / 3)
    # Slice the tail with an explicit start index: for max_length < 3 the tail
    # length is 0, and obj_str[-0:] would return the WHOLE string, not "".
    tail = obj_str[len(obj_str) - tail_len :]
    return f"{obj_str[:head_len]}...{tail}"
269
+
270
+
271
def object_to_str_hash(obj) -> str:
    """Return the hex MD5 digest of ``pickle.dumps(obj)``.

    NOTE(review): the digest follows pickle's byte output. "Equal" objects can
    therefore hash differently, e.g. dicts with a different insertion order,
    or sets whose iteration order changes between processes under hash
    randomisation. The digest may also change with the interpreter's default
    pickle protocol. Confirm callers only compare digests produced in the
    same context.
    """
    digest = hashlib.md5()
    digest.update(pickle.dumps(obj))
    return digest.hexdigest()
277
+
278
+
279
+ # ========公共函数========
280
def _sorted_for_hash(items) -> tuple:
    """Sort ``items`` into a tuple, falling back to a type/repr key when elements are not comparable."""
    items = list(items)
    try:
        # Natural ordering first, so results match plain sorted() whenever possible.
        return tuple(sorted(items))
    except TypeError:
        # e.g. {1, "a"} or {1: ..., "b": ...}: int and str cannot be ordered.
        return tuple(sorted(items, key=lambda e: (type(e).__qualname__, repr(e))))


def make_hashable(obj):
    """Recursively convert ``obj`` into a hashable, order-normalised value.

    - lists and tuples become tuples (element order preserved);
    - dicts become a tuple of ``(key, value)`` pairs sorted by key, so equal
      dicts give equal results regardless of insertion order;
    - sets become a sorted tuple;
    - anything else is returned unchanged (so it must already be hashable).

    Sorting uses natural ordering when the elements are mutually comparable.
    Otherwise (mixed key/element types) a deterministic type-name/``repr``
    ordering is used instead of raising ``TypeError``.
    """
    if isinstance(obj, (tuple, list)):
        return tuple(make_hashable(e) for e in obj)
    if isinstance(obj, dict):
        return _sorted_for_hash(
            (make_hashable(k), make_hashable(v)) for k, v in obj.items()
        )
    if isinstance(obj, set):
        return _sorted_for_hash(make_hashable(e) for e in obj)
    # Basic / already-hashable values pass through unchanged.
    return obj
297
+
298
+
299
def cleanup_mpqueue(queue: MPQueue):
    """Close a ``multiprocessing.Queue`` and wait for its background feeder thread.

    ``close()`` declares that this process will put no more data on the
    queue. ``join_thread()`` then blocks until the background thread exits,
    which guarantees that everything buffered has been flushed to the pipe.
    NOTE(review): if the pipe is full and nothing reads from it,
    ``join_thread()`` can block indefinitely.
    """
    queue.close()  # no further put() from this process
    queue.join_thread()  # wait for the feeder thread to flush and exit
305
+
306
+
307
def format_table(
    data: list,
    row_names: list = None,
    column_names: list = None,
    index_header: str = "#",
    fill_value: str = "N/A",
    align: str = "left",
) -> str:
    """Format a list of rows as a bordered text table.

    Simplified version of the same-named function in ``CelestialVault.TextTools``.

    :param data: list of rows (each a sized iterable); rows may differ in length
    :param row_names: labels for the leading index column. Missing labels are
        filled with the row's position. The caller's sequence is not modified.
    :param column_names: header names. Missing names are filled with
        Excel-style names (A, B, ..., Z, AA, ...) continuing from the given
        count. The caller's sequence is not modified.
    :param index_header: header text of the index column
    :param fill_value: placeholder for cells missing from short rows
    :param align: ``"left"``, ``"right"`` or ``"center"``; anything else means left
    :return: the rendered table, or ``"表格数据为空!"`` if ``data`` is empty
    """

    def _generate_excel_column_names(n: int, start_index: int = 0) -> list[str]:
        """Generate ``n`` Excel-style column names (A..Z, AA, AB, ...) starting at ``start_index``."""
        names = []
        for i in range(start_index, start_index + n):
            name = ""
            x = i
            while True:
                name = chr(ord("A") + (x % 26)) + name
                x = x // 26 - 1
                if x < 0:
                    break
            names.append(name)
        return names

    if not data:
        return "表格数据为空!"

    max_cols = max(map(len, data))

    # Work on copies: extending the caller's lists in place leaked generated
    # names/labels back to the caller, and failed outright for tuples.
    column_names = [] if column_names is None else list(column_names)
    if len(column_names) < max_cols:
        start = len(column_names)  # continue naming after the given names
        column_names.extend(_generate_excel_column_names(max_cols - start, start))

    row_names = [] if row_names is None else list(row_names)
    if len(row_names) < len(data):
        row_names.extend(range(len(row_names), len(data)))

    header_cells = [index_header, *column_names]
    num_columns = len(header_cells)

    # Prefix each row with its label and pad it to the full header width.
    # Short rows, including the case of more column names than data columns,
    # get fill_value instead of raising IndexError later.
    body_rows = []
    for label, row in zip(row_names, data):
        cells = [label, *row]
        cells.extend([fill_value] * (num_columns - len(cells)))
        body_rows.append(cells)

    col_widths = [
        max(len(str(cell)) for cell in column)
        for column in zip(header_cells, *body_rows)
    ]

    align_funcs = {
        "left": lambda text, width: f"{text:<{width}}",
        "right": lambda text, width: f"{text:>{width}}",
        "center": lambda text, width: f"{text:^{width}}",
    }
    align_func = align_funcs.get(align, align_funcs["left"])  # default: left

    def _render(cells) -> str:
        # str() so header cells are formatted exactly like the widths were measured.
        return (
            "| "
            + " | ".join(
                align_func(str(cell), width) for cell, width in zip(cells, col_widths)
            )
            + " |"
        )

    separator = "+" + "+".join("-" * (width + 2) for width in col_widths) + "+"
    header = _render(header_cells)
    rows = "\n".join(_render(cells) for cells in body_rows)
    return f"{separator}\n{header}\n{separator}\n{rows}\n{separator}"
416
+
417
+
418
+ # ========外部调用========
419
def load_jsonl_grouped_by_keys(
    jsonl_path: str,
    group_keys: List[str],
    extract_fields: Optional[List[str]] = None,
    eval_fields: Optional[List[str]] = None,
    skip_if_missing: bool = True,
) -> Dict[str, List[Any]]:
    """Load a JSONL file and group its records by one or more keys.

    Lines that are not valid JSON (including blank lines), or that do not
    decode to a JSON object, are skipped.

    :param jsonl_path: path to the JSONL file (read as UTF-8)
    :param group_keys: field names to group by, e.g. ``['error', 'stage']``.
        With exactly one key the raw field value is the group key; otherwise
        the key is the string ``"(v1, v2, ...)"``.
    :param extract_fields: fields to return per record; ``None``/empty returns
        the whole record. One field returns just its value; several return a
        dict of the fields present.
    :param eval_fields: fields to parse with ``ast.literal_eval``; a value
        that fails to parse is kept as-is
    :param skip_if_missing: skip records lacking any group key or extract
        field. When False, missing group keys become ``""`` and a missing
        single extract field becomes ``None``.
    :return: ``{group_key: [values, ...]}`` with values in file order
    """
    result_dict = defaultdict(list)

    with open(jsonl_path, "r", encoding="utf-8") as f:
        for line in f:
            try:
                item = json.loads(line)
            except ValueError:  # json.JSONDecodeError; also covers blank lines
                continue
            if not isinstance(item, dict):
                # e.g. a bare number line: `key in item` would raise TypeError.
                continue

            # Make sure every group key is present.
            if skip_if_missing and any(k not in item for k in group_keys):
                continue

            group_values = tuple(item.get(k, "") for k in group_keys)
            group_key = (
                group_values[0]
                if len(group_values) == 1
                else f"({', '.join(map(str, group_values))})"
            )

            # Deserialise selected fields (e.g. a task stored via repr()).
            for key in eval_fields or []:
                if key in item:
                    try:
                        item[key] = ast.literal_eval(item[key])
                    except Exception:
                        pass  # best-effort: keep the raw value

            if extract_fields:
                if skip_if_missing and any(k not in item for k in extract_fields):
                    continue
                if len(extract_fields) == 1:
                    # .get(): with skip_if_missing=False a missing field must not raise.
                    value = item.get(extract_fields[0])
                else:
                    value = {k: item[k] for k in extract_fields if k in item}
            else:
                value = item

            result_dict[group_key].append(value)

    return dict(result_dict)
481
+
482
+
483
def load_task_by_stage(jsonl_path):
    """Load error records from ``jsonl_path``, grouped by their ``stage`` field.

    Each group holds the records' ``task`` values, parsed with
    ``ast.literal_eval`` where possible. Records lacking ``stage`` or ``task``
    are skipped (the loader's default ``skip_if_missing=True``).
    """
    options = dict(group_keys=["stage"], extract_fields=["task"], eval_fields=["task"])
    return load_jsonl_grouped_by_keys(jsonl_path, **options)
490
+
491
+
492
def load_task_by_error(jsonl_path):
    """Load error records from ``jsonl_path``, grouped by their ``error`` and ``stage`` fields.

    Group keys are strings of the form ``"(<error>, <stage>)"``. Values are
    the records' ``task`` fields, parsed with ``ast.literal_eval`` where
    possible.
    """
    grouping = ["error", "stage"]
    return load_jsonl_grouped_by_keys(
        jsonl_path,
        group_keys=grouping,
        extract_fields=["task"],
        eval_fields=["task"],
    )
@@ -0,0 +1,61 @@
1
+ from enum import IntEnum
2
+ from typing import List
3
+ from multiprocessing import Value as MPValue
4
+
5
+
6
class TerminationSignal:
    """Sentinel type put on a task queue to mark that the queue is finished.

    NOTE(review): pickling (e.g. sending through a multiprocessing queue)
    produces a new instance. ``is TERMINATION_SIGNAL`` identity checks
    therefore only hold within one process; confirm that cross-process
    callers use ``isinstance``.
    """


# Module-level singleton instance of the termination sentinel.
TERMINATION_SIGNAL = TerminationSignal()
14
+
15
+
16
class TaskError(Exception):
    """Exception used to flag an error raised while executing a task."""
20
+
21
+
22
class NoOpContext:
    """Context manager that does nothing, for places where a ``with`` block should be disabled.

    ``__enter__`` hands back the instance itself. ``__exit__`` returns None,
    so exceptions raised inside the block propagate unchanged.
    """

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        return None
29
+
30
+
31
class ValueWrapper:
    """Plain holder exposing a ``.value`` attribute, like ``multiprocessing.Value`` does.

    NOTE(review): this is an ordinary instance attribute, so on its own it is
    not shared between processes (unlike ``multiprocessing.Value``).
    Presumably it is a lightweight stand-in where cross-process sharing is not
    needed; confirm with callers.
    """

    def __init__(self, value=0):
        self.value = value
35
+
36
+
37
class SumCounter:
    """Sum of a base count plus any number of attached counters.

    The base count lives in a ``multiprocessing.Value("i")`` (a C int).
    Attached counters are any objects exposing ``.value`` (e.g.
    ``ValueWrapper`` or ``multiprocessing.Value``). ``value`` returns the base
    plus the sum of all attached counters.
    """

    def __init__(self):
        self.init_value = MPValue("i", 0)
        self.counters: List[ValueWrapper] = []

    def add_init_value(self, value):
        """Add ``value`` to the base count.

        ``+=`` on a ``multiprocessing.Value`` is a separate read and write, not
        an atomic operation. It is wrapped in the Value's own lock so that
        concurrent callers cannot lose updates.
        """
        with self.init_value.get_lock():
            self.init_value.value += value

    def add_counter(self, counter):
        """Attach a counter whose ``.value`` is included in the total."""
        self.counters.append(counter)

    @property
    def value(self):
        """Current total: base count plus every attached counter's ``.value``."""
        # sum() of no counters is 0, so no special case is needed.
        return self.init_value.value + sum(c.value for c in self.counters)
55
+ )
56
+
57
+
58
class StageStatus(IntEnum):
    """Lifecycle state of a stage, as an int-compatible enum."""

    NOT_STARTED = 0  # created but not yet started
    RUNNING = 1
    STOPPED = 2