celestialflow 3.0.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,160 @@
1
+ import json
2
+ import time
3
+ import redis
4
+
5
+ from .task_manage import TaskManager
6
+
7
+
8
class RemoteWorkerError(Exception):
    """Raised when a remote Redis worker replies with ``status == "error"``."""
10
+
11
+
12
class TaskSplitter(TaskManager):
    """
    Pass-through stage that fans each task out into several downstream items.

    Every element of a task's (normalised) result is pushed to the result
    queues as an independent item.
    """

    def __init__(self):
        """
        Initialise the TaskSplitter.

        Runs serially with no retries. The wrapped function is
        ``_split_task``, which simply echoes its arguments back.
        """
        super().__init__(
            func=self._split_task,
            execution_mode="serial",
            max_retries=0,
        )

    def _split_task(self, *task):
        """
        Return the task arguments unchanged, as a tuple.

        This does no real work. It exists only so the splitter fits the
        TaskManager ``func`` architecture.
        """
        return task

    def get_args(self, task):
        """Return the task itself as the argument sequence for ``func``."""
        return task

    def put_split_result(self, result: tuple):
        """
        Push every element of ``result`` to the downstream result queues.

        Also adds the number of emitted items to the ``split_output_count``
        stat (a counter object exposing ``.value``).

        :param result: iterable of items to emit individually
        :return: number of items emitted
        """
        split_count = 0
        for item in result:
            self.put_result_queues(item)
            split_count += 1

        self.extra_stats["split_output_count"].value += split_count
        return split_count

    def process_result(self, task, result):
        """
        Normalise a task result into something that can be split.

        - Non-iterables, ``str`` and ``bytes`` are wrapped in a one-element
          tuple, so a string is emitted whole instead of per character.
        - Lists are converted to tuples.
        - Any other iterable is returned unchanged.
        """
        if not hasattr(result, "__iter__") or isinstance(result, (str, bytes)):
            result = (result,)
        elif isinstance(result, list):
            result = tuple(result)

        return result

    def process_task_success(self, task, result, start_time):
        """
        Handle a successfully completed task.

        The raw result is normalised via ``process_result``, and the
        *normalised* items are what gets cached and emitted downstream.

        :param task: the completed task
        :param result: the raw value returned by ``func``
        :param start_time: start timestamp of the task, compared against time.time()
        """
        processed_result = self.process_result(task, result)

        if self.enable_result_cache:
            self.success_dict[task] = processed_result

        # The task succeeded, so drop any retry bookkeeping for it.
        task_id = self.get_task_id(task)
        self.retry_time_dict.pop(task_id, None)

        # Emit the normalised result, not the raw one. A raw scalar would make
        # put_split_result raise TypeError, and a raw str would be split into
        # individual characters.
        split_count = self.put_split_result(processed_result)
        self.update_success_counter()

        self.task_logger.splitter_success(
            self.func.__name__,
            self.get_task_info(task),
            split_count,
            time.time() - start_time,
        )
78
+
79
+
80
class TaskRedisTransfer(TaskManager):
    """
    Stage that forwards each task to remote workers through Redis and waits
    for their reply.

    Tasks are RPUSHed as JSON onto the ``<stage_tag>:input`` list. Replies
    are read from the ``<stage_tag>:output`` hash, keyed by task id.
    """

    def __init__(
        self,
        worker_limit=50,
        unpack_task_args=False,
        host="localhost",
        port=6379,
        db=0,
        fetch_timeout=10,
        result_timeout=10,
    ):
        """
        Initialise the TaskRedisTransfer.

        :param worker_limit: number of parallel worker threads
        :param unpack_task_args: whether task arguments are unpacked
        :param host: Redis host
        :param port: Redis port
        :param db: Redis database index
        :param fetch_timeout: seconds to wait for a worker to fetch the task
        :param result_timeout: seconds to wait for the result once fetched
        """
        super().__init__(
            func=self._trans_redis,
            execution_mode="thread",
            worker_limit=worker_limit,
            unpack_task_args=unpack_task_args,
        )

        self.host = host
        self.port = port
        self.db = db
        self.fetch_timeout = fetch_timeout
        self.result_timeout = result_timeout

    def init_redis(self):
        """Lazily create the Redis client (``decode_responses=True``) on first use."""
        if not hasattr(self, "redis_client"):
            self.redis_client = redis.Redis(
                host=self.host, port=self.port, db=self.db, decode_responses=True
            )

    def _trans_redis(self, *task):
        """
        Submit ``task`` to Redis and block until a remote worker replies.

        Phase 1 waits up to ``fetch_timeout`` seconds for the payload to
        disappear from the input list (i.e. a worker popped it). Phase 2 waits
        up to ``result_timeout`` seconds for a reply under this task's id in
        the output hash.

        NOTE: relies on the LPOS command (Redis >= 6.0.6).

        :return: the ``result`` field of a ``{"status": "success"}`` reply
        :raises TimeoutError: if the task is not fetched, or no reply arrives, in time
        :raises RemoteWorkerError: if the worker replies with ``status == "error"``
        :raises ValueError: if the reply carries any other status
        """
        self.init_redis()
        stage_tag = self.get_stage_tag()
        input_key = f"{stage_tag}:input"
        output_key = f"{stage_tag}:output"

        # Submit the task.
        task_id = self.get_task_id(task)
        payload = json.dumps({"id": task_id, "task": task})
        self.redis_client.rpush(input_key, payload)

        # Phase 1: wait until a worker has taken the payload off the list.
        # monotonic() measures elapsed time, so wall-clock jumps cannot
        # shorten or stretch the wait.
        fetch_deadline = time.monotonic() + self.fetch_timeout
        while self.redis_client.lpos(input_key, payload) is not None:
            if time.monotonic() > fetch_deadline:
                # Withdraw the unclaimed payload so no worker processes a task
                # we have already given up on (and a retry does not duplicate
                # it). If nothing was removed, a worker claimed it just now,
                # so go on to wait for its result instead of failing.
                if self.redis_client.lrem(input_key, 1, payload):
                    raise TimeoutError("Task not fetched from Redis in time")
                break
            time.sleep(0.1)

        # Phase 2: the task has been fetched; wait for its result.
        result_deadline = time.monotonic() + self.result_timeout
        while True:
            result = self.redis_client.hget(output_key, task_id)
            if result:
                self.redis_client.hdel(output_key, task_id)
                result_obj = json.loads(result)
                status = result_obj.get("status")
                if status == "success":
                    return result_obj.get("result")
                if status == "error":
                    raise RemoteWorkerError(f"{result_obj.get('error')}")
                raise ValueError(f"Unknown result status: {result_obj}")
            if time.monotonic() > result_deadline:
                raise TimeoutError(
                    "Redis result not returned in time after being fetched"
                )
            time.sleep(0.1)
@@ -0,0 +1,57 @@
1
+ from tqdm import tqdm
2
+ from tqdm.asyncio import tqdm as tqdm_asy
3
+
4
+
5
class NullProgress:
    """
    No-op progress object.

    Mirrors the ProgressManager method set (update, close, refresh_total,
    add_total). Every method ignores its arguments and returns None.
    """

    def update(self, n=1):
        """Ignore a progress increment."""
        return None

    def close(self):
        """Nothing to close."""
        return None

    def refresh_total(self, total):
        """Ignore a new total."""
        return None

    def add_total(self, add_num):
        """Ignore an increase of the total."""
        return None
14
+
15
+ def add_total(self, add_num):
16
+ pass
17
+
18
+
19
class ProgressManager:
    """Wrapper around a tqdm progress bar whose total can be changed at runtime."""

    def __init__(
        self,
        total_tasks: int,
        desc: str = "Processing",
        mode: str = "normal",
    ):
        """
        Initialise the progress-bar manager.

        :param total_tasks: total number of tasks; used as the bar's total length
        :param desc: description text shown on the bar
        :param mode: ``"async"`` uses ``tqdm.asyncio.tqdm``; any other value
            uses the plain ``tqdm`` bar
        """
        if mode == "async":
            self.progress_bar = tqdm_asy(total=total_tasks, desc=desc)
        else:
            self.progress_bar = tqdm(total=total_tasks, desc=desc)

    def update(self, n=1):
        """Advance the bar by ``n`` steps."""
        self.progress_bar.update(n)

    def close(self):
        """Close the progress bar."""
        self.progress_bar.close()

    def refresh_total(self, total):
        """Set the bar's total to ``total`` and redraw it."""
        self.progress_bar.total = total
        self.progress_bar.refresh()

    def add_total(self, add_num):
        """
        Increase the bar's total by ``add_num``.

        A falsy ``add_num`` (0 or None) is a no-op. A bar without a total
        (``total=None``, which tqdm allows) is treated as having total 0, so
        this does not raise TypeError.
        """
        if not add_num:
            return
        current_total = self.progress_bar.total or 0
        self.refresh_total(current_total + add_num)
@@ -0,0 +1,162 @@
1
+ from threading import Event, Thread
2
+
3
+ import requests
4
+
5
+ from .task_logging import TaskLogger
6
+ from .task_types import TERMINATION_SIGNAL
7
+
8
+
9
class TaskReporter:
    """
    Reporter that periodically pushes task-graph run status to a remote service.

    - Periodically pulls configuration from the server (report interval,
      task-injection requests).
    - Pushes the task graph's errors, status, structure and topology to the
      backend API endpoints.
    - Runs in a background daemon thread that can be start()-ed / stop()-ped.
    - Mainly intended for visual monitoring, remote task control and Web UI sync.
    """

    def __init__(self, task_graph, logger_queue, host="127.0.0.1", port=5000):
        """
        :param task_graph: TaskGraph whose state is reported and into which
            injected tasks are put
        :param logger_queue: queue passed to TaskLogger for log records
        :param host: backend host
        :param port: backend port
        """
        # Imported locally and used only for the annotation; presumably this
        # avoids a circular import with task_graph.
        from .task_graph import TaskGraph

        self.task_graph: TaskGraph = task_graph
        self.logger = TaskLogger(logger_queue)
        self.base_url = f"http://{host}:{port}"
        self._stop_flag = Event()
        self._thread = None
        # Seconds between pushes; overwritten by _pull_interval (clamped to [1, 60]).
        self.interval = 5

    def start(self):
        """Start the background reporting thread unless one is already running."""
        if self._thread is None or not self._thread.is_alive():
            self._stop_flag.clear()
            self._thread = Thread(target=self._loop, daemon=True)
            self._thread.start()

    def stop(self):
        """
        Stop the background thread, then perform one final push.

        The thread is signalled and joined (waiting at most 2 seconds) *before*
        the final push. This way the final push does not overlap an in-flight
        push from the background loop.
        """
        if self._thread:
            self._stop_flag.set()
            self._thread.join(timeout=2)
            self._thread = None
            self.push_once()  # final push, after the loop has stopped pushing
            self.logger._log("DEBUG", "[Reporter] Stopped.")

    def _loop(self):
        """Push once per ``interval`` seconds until the stop flag is set."""
        while not self._stop_flag.is_set():
            try:
                self.push_once()
            except Exception as e:
                self.logger._log(
                    "ERROR", f"[Reporter] Push error: {type(e).__name__}({e})."
                )
            # Event.wait returns early as soon as stop() sets the flag.
            self._stop_flag.wait(self.interval)

    def push_once(self):
        """Run one full pull-then-push cycle against the backend."""
        # Pull: configuration and injected tasks.
        self._pull_interval()
        self._pull_and_inject_tasks()

        # Push: errors, status, structure, topology.
        self._push_errors()
        self._push_status()
        self._push_structure()
        self._push_topology()

    def _pull_interval(self):
        """Fetch the report interval from the backend, clamped to [1, 60] seconds."""
        try:
            res = requests.get(f"{self.base_url}/api/get_interval", timeout=1)
            if res.ok:
                interval = res.json().get("interval", 5)
                self.interval = max(1.0, min(interval, 60.0))
        except Exception as e:
            self.logger._log(
                "WARNING", f"[Reporter] Interval fetch failed: {type(e).__name__}({e})."
            )

    def _pull_and_inject_tasks(self):
        """
        Fetch pending task injections and put them into the target stages' queues.

        Each injection is read as a dict with ``"node"`` (expected to be a key
        of ``task_graph.stages_status_dict``) and ``"task_datas"`` (the tasks
        to inject). The string ``"TERMINATION_SIGNAL"`` is mapped to the real
        TERMINATION_SIGNAL object. Injections for unknown nodes, or without
        task data, are skipped with a warning; the remaining injections are
        still processed.
        """
        try:
            res = requests.get(f"{self.base_url}/api/get_task_injection", timeout=2)
            if res.ok:
                for injection in res.json():
                    target_node = injection.get("node")
                    task_datas = injection.get("task_datas")

                    if target_node not in self.task_graph.stages_status_dict:
                        self.logger._log(
                            "WARNING",
                            f"[Reporter] Task injection target node {target_node} not found.",
                        )
                        continue

                    if task_datas is None:
                        self.logger._log(
                            "WARNING",
                            f"[Reporter] Task injection for node {target_node} has no task_datas.",
                        )
                        continue

                    # Map the wire-level marker string to the in-process sentinel.
                    task_datas = [
                        TERMINATION_SIGNAL if item == "TERMINATION_SIGNAL" else item
                        for item in task_datas
                    ]
                    self.task_graph.put_stage_queue(
                        {target_node: task_datas}, put_termination_signal=False
                    )
                    self.logger._log(
                        "INFO", f"[Reporter] 注入任务到 {target_node}: {task_datas}"
                    )
        except Exception as e:
            self.logger._log(
                "WARNING",
                f"[Reporter] Task injection fetch failed: {type(e).__name__}({e}).",
            )

    def _push_errors(self):
        """
        Push the graph's error timeline to the backend.

        Calls ``handle_fail_queue()`` first (presumably to flush pending
        failures into the timeline). Each entry is sent as a dict with
        ``error``, ``node``, ``task_id`` and ``timestamp``.
        """
        try:
            self.task_graph.handle_fail_queue()
            error_data = []
            for (
                err,
                tag,
            ), task_list in self.task_graph.get_error_timeline_dict().items():
                for task, ts in task_list:
                    error_data.append(
                        {
                            "error": err,
                            "node": tag,
                            # Assumes task is a str; long values are truncated
                            # to 100 characters plus "...".
                            "task_id": task if len(task) < 100 else task[:100] + "...",
                            "timestamp": ts,
                        }
                    )
            payload = {"errors": error_data}
            requests.post(f"{self.base_url}/api/push_errors", json=payload, timeout=1)
        except Exception as e:
            self.logger._log(
                "WARNING", f"[Reporter] Error push failed: {type(e).__name__}({e})."
            )

    def _push_status(self):
        """Push the graph's status dict to the backend."""
        try:
            status_data = self.task_graph.get_status_dict()
            payload = {"status": status_data}
            requests.post(f"{self.base_url}/api/push_status", json=payload, timeout=1)
        except Exception as e:
            self.logger._log(
                "WARNING", f"[Reporter] Status push failed: {type(e).__name__}({e})."
            )

    def _push_structure(self):
        """Push the graph's structure JSON to the backend."""
        try:
            structure = self.task_graph.get_structure_json()
            payload = {"items": structure}
            requests.post(
                f"{self.base_url}/api/push_structure", json=payload, timeout=1
            )
        except Exception as e:
            self.logger._log(
                "WARNING", f"[Reporter] Structure push failed: {type(e).__name__}({e})"
            )

    def _push_topology(self):
        """Push the graph's topology to the backend."""
        try:
            topology = self.task_graph.get_graph_topology()
            payload = {"topology": topology}
            requests.post(f"{self.base_url}/api/push_topology", json=payload, timeout=1)
        except Exception as e:
            self.logger._log(
                "WARNING", f"[Reporter] Topology push failed: {type(e).__name__}({e})."
            )
162
+
@@ -0,0 +1,151 @@
1
+ from typing import List
2
+
3
+ from .task_manage import TaskManager
4
+ from .task_graph import TaskGraph
5
+
6
+
7
+ # ========有向无环图(DAG)========
8
class TaskChain(TaskGraph):
    def __init__(self, stages: List[TaskManager], chain_mode: str = "serial"):
        """
        Build a linear chain in which stage k feeds stage k + 1.

        :param stages: TaskManager instances in execution order; the first
            one is the graph's root
        :param chain_mode: stage mode applied to every stage (default 'serial')
        """
        last_index = len(stages) - 1
        for idx, stage in enumerate(stages):
            successors = [] if idx == last_index else [stages[idx + 1]]
            stage.set_graph_context(successors, chain_mode, f"Stage {idx + 1}")

        super().__init__([stages[0]])

    def start_chain(self, init_tasks_dict: dict, put_termination_signal: bool = True):
        """
        Start the chain.

        :param init_tasks_dict: initial tasks, forwarded to start_graph
        :param put_termination_signal: forwarded to start_graph
        """
        self.start_graph(init_tasks_dict, put_termination_signal)
29
+
30
+
31
class TaskCross(TaskGraph):
    def __init__(self, layers: List[List[TaskManager]], layout_mode: str = "process"):
        """
        TaskCross: multi-layer cross structure.

        Stages are organised into layers. Each layer may hold several
        TaskManager nodes, and adjacent layers are fully connected: every node
        of one layer feeds every node of the next. All stages use the
        'process' stage mode. The first layer's nodes are the graph's roots.

        :param layers: List[List[TaskManager]]
            Stage nodes grouped by layer; each inner list is one layer.

        :param layout_mode: str, default = 'process'
            Scheduling layout of the task graph (forwarded to TaskGraph):
            - 'serial': run layer by layer; the next layer starts only after
              the previous one has finished;
            - 'process': all layers start concurrently and execution order
              follows the dependencies.
        """
        final_index = len(layers) - 1
        for layer_idx, layer in enumerate(layers):
            successors = [] if layer_idx == final_index else layers[layer_idx + 1]
            for pos, stage in enumerate(layer):
                stage.set_graph_context(
                    next_stages=successors,
                    stage_mode="process",
                    stage_name=f"Layer{layer_idx + 1}-{pos + 1}",
                )
        super().__init__(layers[0], layout_mode)

    def start_cross(self, init_tasks_dict: dict, put_termination_signal: bool = True):
        """
        Start the multi-layer cross graph.

        :param init_tasks_dict: initial tasks, forwarded to start_graph
        :param put_termination_signal: forwarded to start_graph
        """
        self.start_graph(init_tasks_dict, put_termination_signal)
65
+
66
+
67
class TaskGrid(TaskGraph):
    def __init__(self, grid: List[List[TaskManager]], layout_mode: str = "process"):
        """
        Build a 2-D grid in which every cell feeds the cell below it and the
        cell to its right. The top-left cell is the graph's root.

        Rows may have different lengths. A cell is linked downwards only if
        the next row has a cell in that column, and rightwards only within
        its own row. For rectangular grids the wiring (and its order: down,
        then right) is the same as before.

        :param grid: rows of TaskManager nodes; every node uses 'process' mode
        :param layout_mode: forwarded to TaskGraph
        """
        row_count = len(grid)
        for i, row in enumerate(grid):
            below = grid[i + 1] if i + 1 < row_count else []
            for j, curr in enumerate(row):
                nexts = []
                if j < len(below):
                    nexts.append(below[j])  # down
                if j + 1 < len(row):
                    nexts.append(row[j + 1])  # right
                curr.set_graph_context(nexts, "process", f"Grid-{i+1}-{j+1}")
        super().__init__([grid[0][0]], layout_mode)  # root is the top-left cell

    def start_grid(self, init_tasks_dict: dict, put_termination_signal: bool = True):
        """
        Start the task grid.

        :param init_tasks_dict: initial tasks, forwarded to start_graph
        :param put_termination_signal: forwarded to start_graph
        """
        self.start_graph(init_tasks_dict, put_termination_signal)
87
+
88
+
89
+ # ========有环图========
90
class TaskLoop(TaskGraph):
    def __init__(self, stages: List[TaskManager]):
        """
        Build a ring: each stage feeds the next one and the last feeds the first.

        Because of the cycle, every stage is forced into 'process' stage mode.
        The first stage is the graph's root.

        :param stages: TaskManager instances in ring order
        """
        ring_size = len(stages)
        for idx, stage in enumerate(stages):
            successor = stages[(idx + 1) % ring_size]
            stage.set_graph_context([successor], "process", f"Stage {idx + 1}")

        super().__init__([stages[0]])

    def start_loop(self, init_tasks_dict: dict):
        """
        Start the ring.

        No termination signal is put, so the ring does not stop on its own;
        it can only be stopped by a signal injected from outside.

        :param init_tasks_dict: initial tasks, forwarded to start_graph
        """
        self.start_graph(init_tasks_dict, False)
110
+
111
+
112
class TaskWheel(TaskGraph):
    def __init__(self, center: TaskManager, ring: List[TaskManager]):
        """
        Build a wheel: the center feeds every ring node, and the ring nodes
        form a closed cycle (each feeds the next, the last feeds the first).

        All nodes use 'process' stage mode. The center is the graph's root.

        :param center: hub node
        :param ring: rim nodes in cycle order
        """
        # Hub -> every rim node.
        center.set_graph_context(ring, "process", "Center")

        # Rim: pair each node with its neighbour, wrapping around at the end.
        rotated = ring[1:] + ring[:1]
        for pos, (node, neighbour) in enumerate(zip(ring, rotated), start=1):
            node.set_graph_context([neighbour], "process", f"Ring-{pos}")

        super().__init__([center])

    def start_wheel(self, init_tasks_dict: dict):
        """
        Start the wheel without putting a termination signal.

        :param init_tasks_dict: initial tasks, forwarded to start_graph
        """
        self.start_graph(init_tasks_dict, False)
128
+
129
+
130
class TaskComplete(TaskGraph):
    def __init__(self, stages: List[TaskManager]):
        """
        TaskComplete: complete-graph structure in which every node feeds every
        other node (but not itself).

        All nodes use 'process' stage mode, and every node is passed to
        TaskGraph as a starting node.

        :param stages: all TaskManager nodes
        """
        for idx, stage in enumerate(stages):
            # Everything except the node at this position.
            others = [*stages[:idx], *stages[idx + 1:]]
            stage.set_graph_context(
                next_stages=others,
                stage_mode="process",
                stage_name=f"Node {idx + 1}",
            )

        super().__init__(stages)

    def start_complete(self, init_tasks_dict: dict):
        """
        Start the complete graph without putting a termination signal.

        :param init_tasks_dict: initial tasks, forwarded to start_graph
        """
        self.start_graph(init_tasks_dict, False)