jettask 0.2.20__py3-none-any.whl → 0.2.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (110)
  1. jettask/__init__.py +4 -0
  2. jettask/cli.py +12 -8
  3. jettask/config/lua_scripts.py +37 -0
  4. jettask/config/nacos_config.py +1 -1
  5. jettask/core/app.py +313 -340
  6. jettask/core/container.py +4 -4
  7. jettask/{persistence → core}/namespace.py +93 -27
  8. jettask/core/task.py +16 -9
  9. jettask/core/unified_manager_base.py +136 -26
  10. jettask/db/__init__.py +67 -0
  11. jettask/db/base.py +137 -0
  12. jettask/{utils/db_connector.py → db/connector.py} +130 -26
  13. jettask/db/models/__init__.py +16 -0
  14. jettask/db/models/scheduled_task.py +196 -0
  15. jettask/db/models/task.py +77 -0
  16. jettask/db/models/task_run.py +85 -0
  17. jettask/executor/__init__.py +0 -15
  18. jettask/executor/core.py +76 -31
  19. jettask/executor/process_entry.py +29 -114
  20. jettask/executor/task_executor.py +4 -0
  21. jettask/messaging/event_pool.py +928 -685
  22. jettask/messaging/scanner.py +30 -0
  23. jettask/persistence/__init__.py +28 -103
  24. jettask/persistence/buffer.py +170 -0
  25. jettask/persistence/consumer.py +330 -249
  26. jettask/persistence/manager.py +304 -0
  27. jettask/persistence/persistence.py +391 -0
  28. jettask/scheduler/__init__.py +15 -3
  29. jettask/scheduler/{task_crud.py → database.py} +61 -57
  30. jettask/scheduler/loader.py +2 -2
  31. jettask/scheduler/{scheduler_coordinator.py → manager.py} +23 -6
  32. jettask/scheduler/models.py +14 -10
  33. jettask/scheduler/schedule.py +166 -0
  34. jettask/scheduler/scheduler.py +12 -11
  35. jettask/schemas/__init__.py +50 -1
  36. jettask/schemas/backlog.py +43 -6
  37. jettask/schemas/namespace.py +70 -19
  38. jettask/schemas/queue.py +19 -3
  39. jettask/schemas/responses.py +493 -0
  40. jettask/task/__init__.py +0 -2
  41. jettask/task/router.py +3 -0
  42. jettask/test_connection_monitor.py +1 -1
  43. jettask/utils/__init__.py +7 -5
  44. jettask/utils/db_init.py +8 -4
  45. jettask/utils/namespace_dep.py +167 -0
  46. jettask/utils/queue_matcher.py +186 -0
  47. jettask/utils/rate_limit/concurrency_limiter.py +7 -1
  48. jettask/utils/stream_backlog.py +1 -1
  49. jettask/webui/__init__.py +0 -1
  50. jettask/webui/api/__init__.py +4 -4
  51. jettask/webui/api/alerts.py +806 -71
  52. jettask/webui/api/example_refactored.py +400 -0
  53. jettask/webui/api/namespaces.py +390 -45
  54. jettask/webui/api/overview.py +300 -54
  55. jettask/webui/api/queues.py +971 -267
  56. jettask/webui/api/scheduled.py +1249 -56
  57. jettask/webui/api/settings.py +129 -7
  58. jettask/webui/api/workers.py +442 -0
  59. jettask/webui/app.py +46 -2329
  60. jettask/webui/middleware/__init__.py +6 -0
  61. jettask/webui/middleware/namespace_middleware.py +135 -0
  62. jettask/webui/services/__init__.py +146 -0
  63. jettask/webui/services/heartbeat_service.py +251 -0
  64. jettask/webui/services/overview_service.py +60 -51
  65. jettask/webui/services/queue_monitor_service.py +426 -0
  66. jettask/webui/services/redis_monitor_service.py +87 -0
  67. jettask/webui/services/settings_service.py +174 -111
  68. jettask/webui/services/task_monitor_service.py +222 -0
  69. jettask/webui/services/timeline_pg_service.py +452 -0
  70. jettask/webui/services/timeline_service.py +189 -0
  71. jettask/webui/services/worker_monitor_service.py +467 -0
  72. jettask/webui/utils/__init__.py +11 -0
  73. jettask/webui/utils/time_utils.py +122 -0
  74. jettask/worker/lifecycle.py +8 -2
  75. {jettask-0.2.20.dist-info → jettask-0.2.24.dist-info}/METADATA +1 -1
  76. jettask-0.2.24.dist-info/RECORD +142 -0
  77. jettask/executor/executor.py +0 -338
  78. jettask/persistence/backlog_monitor.py +0 -567
  79. jettask/persistence/base.py +0 -2334
  80. jettask/persistence/db_manager.py +0 -516
  81. jettask/persistence/maintenance.py +0 -81
  82. jettask/persistence/message_consumer.py +0 -259
  83. jettask/persistence/models.py +0 -49
  84. jettask/persistence/offline_recovery.py +0 -196
  85. jettask/persistence/queue_discovery.py +0 -215
  86. jettask/persistence/task_persistence.py +0 -218
  87. jettask/persistence/task_updater.py +0 -583
  88. jettask/scheduler/add_execution_count.sql +0 -11
  89. jettask/scheduler/add_priority_field.sql +0 -26
  90. jettask/scheduler/add_scheduler_id.sql +0 -25
  91. jettask/scheduler/add_scheduler_id_index.sql +0 -10
  92. jettask/scheduler/make_scheduler_id_required.sql +0 -28
  93. jettask/scheduler/migrate_interval_seconds.sql +0 -9
  94. jettask/scheduler/performance_optimization.sql +0 -45
  95. jettask/scheduler/run_scheduler.py +0 -186
  96. jettask/scheduler/schema.sql +0 -84
  97. jettask/task/task_executor.py +0 -318
  98. jettask/webui/api/analytics.py +0 -323
  99. jettask/webui/config.py +0 -90
  100. jettask/webui/models/__init__.py +0 -3
  101. jettask/webui/models/namespace.py +0 -63
  102. jettask/webui/namespace_manager/__init__.py +0 -10
  103. jettask/webui/namespace_manager/multi.py +0 -593
  104. jettask/webui/namespace_manager/unified.py +0 -193
  105. jettask/webui/run.py +0 -46
  106. jettask-0.2.20.dist-info/RECORD +0 -145
  107. {jettask-0.2.20.dist-info → jettask-0.2.24.dist-info}/WHEEL +0 -0
  108. {jettask-0.2.20.dist-info → jettask-0.2.24.dist-info}/entry_points.txt +0 -0
  109. {jettask-0.2.20.dist-info → jettask-0.2.24.dist-info}/licenses/LICENSE +0 -0
  110. {jettask-0.2.20.dist-info → jettask-0.2.24.dist-info}/top_level.txt +0 -0
jettask/persistence/manager.py
@@ -0,0 +1,304 @@
+ """
+ Unified PostgreSQL consumer manager.
+ Automatically detects single-namespace and multi-namespace modes.
+ """
+ import asyncio
+ import logging
+ import multiprocessing
+ from typing import Dict, Optional, Set
+ from jettask.core.unified_manager_base import UnifiedManagerBase
+ from jettask.core.namespace import NamespaceDataAccessManager
+ from jettask.persistence.consumer import PostgreSQLConsumer
+
+ logger = logging.getLogger(__name__)
+
+
+ class UnifiedConsumerManager(UnifiedManagerBase):
+     """
+     Unified consumer manager.
+     Inherits from UnifiedManagerBase and implements the consumer-specific logic.
+     """
+
+     def __init__(self,
+                  task_center_url: str,
+                  check_interval: int = 30,
+                  backlog_monitor_interval: int = 30,
+                  concurrency: int = 4,
+                  debug: bool = False):
+         """
+         Initialize the consumer manager.
+
+         Args:
+             task_center_url: Task center URL
+             check_interval: Namespace detection interval (seconds)
+             backlog_monitor_interval: Backlog monitoring interval (seconds)
+             concurrency: Concurrency (number of worker processes per namespace)
+             debug: Whether to enable debug mode
+         """
+         super().__init__(task_center_url, check_interval, debug)
+
+         self.backlog_monitor_interval = backlog_monitor_interval
+         self.concurrency = concurrency
+
+         # Consumer management
+         self.consumer_instance: Optional[PostgreSQLConsumer] = None  # single-namespace mode
+         self.consumer_processes: Dict[str, multiprocessing.Process] = {}  # multi-namespace mode
+         self.known_namespaces: Set[str] = set()
+
+         # Namespace data access manager
+         self.namespace_manager: Optional[NamespaceDataAccessManager] = None
+
+     async def run_single_namespace(self, namespace_name: str):
+         """
+         Run in single-namespace mode.
+
+         Args:
+             namespace_name: Namespace name
+         """
+         logger.info(f"Starting single-namespace consumer: {namespace_name}")
+         logger.info(f"Backlog monitor interval: {self.backlog_monitor_interval}s")
+
+         try:
+             # Create the namespace data access manager
+             base_url = self.get_base_url()
+             self.namespace_manager = NamespaceDataAccessManager(task_center_base_url=base_url)
+
+             # Obtain the namespace connection
+             conn = await self.namespace_manager.get_connection(namespace_name)
+
+             # Check that PostgreSQL is configured
+             if not conn.pg_config:
+                 logger.error(f"Namespace {namespace_name} has no PostgreSQL configuration; cannot start consumer")
+                 return
+
+             logger.info(f"Namespace {namespace_name} configuration:")
+             logger.info(f"  - Redis: {'configured' if conn.redis_config else 'not configured'}")
+             logger.info(f"  - PostgreSQL: configured")
+             logger.info(f"  - Redis prefix: {conn.redis_prefix}")
+
+             # Create and start the consumer
+             self.consumer_instance = PostgreSQLConsumer(
+                 pg_config=conn.pg_config,
+                 redis_config=conn.redis_config,
+                 prefix=conn.redis_prefix,
+                 namespace_name=namespace_name
+             )
+
+             logger.info(f"✓ Consumer started: {namespace_name}")
+
+             # Run the consumer
+             await self.consumer_instance.start(concurrency=self.concurrency)
+
+         except Exception as e:
+             logger.error(f"Single-namespace consumer failed: {e}", exc_info=self.debug)
+             raise
+         finally:
+             # Cleanup
+             if self.consumer_instance:
+                 await self.consumer_instance.stop()
+                 logger.info(f"Consumer stopped: {namespace_name}")
+
+             if self.namespace_manager:
+                 await self.namespace_manager.close_all()
+
+     async def run_multi_namespace(self, namespace_names: Optional[Set[str]]):
+         """
+         Run in multi-namespace mode.
+
+         Args:
+             namespace_names: Set of target namespaces; None means all namespaces
+         """
+         logger.info("Starting multi-namespace consumer management")
+         logger.info(f"Namespace detection interval: {self.check_interval}s")
+         logger.info(f"Backlog monitor interval: {self.backlog_monitor_interval}s")
+
+         # Create the namespace data access manager
+         base_url = self.get_base_url()
+         self.namespace_manager = NamespaceDataAccessManager(task_center_base_url=base_url)
+
+         # Fetch the initial namespaces
+         namespaces = await self.fetch_namespaces_info(namespace_names)
+
+         # Start a consumer process for each namespace
+         for ns_info in namespaces:
+             self._start_consumer_process(ns_info['name'])
+             self.known_namespaces.add(ns_info['name'])
+
+         # Create the concurrent tasks
+         try:
+             health_check_task = asyncio.create_task(self._health_check_loop())
+             namespace_check_task = asyncio.create_task(self._namespace_check_loop())
+
+             # Wait until either task finishes or raises
+             _, pending = await asyncio.wait(
+                 [health_check_task, namespace_check_task],
+                 return_when=asyncio.FIRST_EXCEPTION
+             )
+
+             # Cancel any tasks still pending
+             for task in pending:
+                 task.cancel()
+
+         except asyncio.CancelledError:
+             logger.info("Cancellation signal received")
+         finally:
+             # Cleanup
+             if self.namespace_manager:
+                 await self.namespace_manager.close_all()
+
+     def _start_consumer_process(self, namespace_name: str):
+         """Start the consumer process for a single namespace."""
+
+         # If the process already exists and is alive, skip
+         if namespace_name in self.consumer_processes:
+             process = self.consumer_processes[namespace_name]
+             if process.is_alive():
+                 logger.debug(f"Consumer process for namespace {namespace_name} is already running")
+                 return
+             else:
+                 # Clean up the dead process
+                 process.terminate()
+                 process.join(timeout=5)
+
+         # Create a new process
+         logger.info(f"Starting consumer process for namespace {namespace_name}")
+
+         process = multiprocessing.Process(
+             target=_run_consumer_in_process,
+             args=(self.task_center_url, namespace_name, self.backlog_monitor_interval, self.concurrency, self.debug),
+             name=f"Consumer-{namespace_name}"
+         )
+         process.start()
+         self.consumer_processes[namespace_name] = process
+
+         logger.info(f"✓ Consumer process started: {namespace_name} (PID: {process.pid})")
+
+     async def _health_check_loop(self):
+         """Health-check loop: monitors consumer process state."""
+         logger.info("Health-check loop started")
+
+         while True:
+             try:
+                 # Check every consumer process
+                 dead_processes = []
+                 for ns_name, process in self.consumer_processes.items():
+                     if not process.is_alive():
+                         logger.warning(f"Consumer process {ns_name} has stopped (exit code: {process.exitcode})")
+                         dead_processes.append(ns_name)
+
+                 # Restart dead processes
+                 for ns_name in dead_processes:
+                     logger.info(f"Restarting consumer process: {ns_name}")
+                     self._start_consumer_process(ns_name)
+
+                 # Wait for the next check
+                 await asyncio.sleep(self.check_interval)
+
+             except Exception as e:
+                 logger.error(f"Health-check loop error: {e}", exc_info=self.debug)
+                 await asyncio.sleep(10)
+
+     async def _namespace_check_loop(self):
+         """Namespace-check loop: detects new namespaces."""
+         logger.info("Namespace-check loop started")
+
+         while True:
+             try:
+                 # Fetch all current namespaces
+                 namespaces = await self.fetch_namespaces_info(None)
+                 current_namespaces = {ns['name'] for ns in namespaces}
+
+                 # Discover new namespaces
+                 new_namespaces = current_namespaces - self.known_namespaces
+                 if new_namespaces:
+                     logger.info(f"New namespaces discovered: {new_namespaces}")
+                     for ns_name in new_namespaces:
+                         self._start_consumer_process(ns_name)
+                         self.known_namespaces.add(ns_name)
+
+                 # Stop consumers for namespaces that were removed
+                 removed_namespaces = self.known_namespaces - current_namespaces
+                 if removed_namespaces:
+                     logger.info(f"Namespaces removed: {removed_namespaces}")
+                     for ns_name in removed_namespaces:
+                         if ns_name in self.consumer_processes:
+                             process = self.consumer_processes[ns_name]
+                             logger.info(f"Stopping consumer process: {ns_name}")
+                             process.terminate()
+                             process.join(timeout=10)
+                             del self.consumer_processes[ns_name]
+                         self.known_namespaces.remove(ns_name)
+
+                 # Wait for the next check
+                 await asyncio.sleep(self.check_interval)
+
+             except Exception as e:
+                 logger.error(f"Namespace-check loop error: {e}", exc_info=self.debug)
+                 await asyncio.sleep(10)
+
+     async def run(self):
+         """
+         Run the manager (automatically selects single- or multi-namespace mode).
+         """
+         try:
+             self.running = True
+
+             if self.is_single_namespace:
+                 # Single-namespace mode
+                 await self.run_single_namespace(self.namespace_name)
+             else:
+                 # Multi-namespace mode
+                 target_namespaces = self.get_target_namespaces()
+                 await self.run_multi_namespace(target_namespaces)
+
+         except KeyboardInterrupt:
+             logger.info("Interrupt received; stopping all consumers...")
+         finally:
+             self.running = False
+
+             # Stop all consumer processes
+             for ns_name, process in list(self.consumer_processes.items()):
+                 logger.info(f"Stopping consumer process: {ns_name}")
+                 process.terminate()
+                 process.join(timeout=10)
+
+             logger.info("All consumers stopped")
+
+
+ def _run_consumer_in_process(task_center_url: str, namespace_name: str,
+                              backlog_monitor_interval: int, concurrency: int, debug: bool):
+     """
+     Run a consumer in a standalone process (reuses the run_single_namespace logic).
+
+     Args:
+         task_center_url: Task center URL
+         namespace_name: Namespace name
+         backlog_monitor_interval: Backlog monitoring interval
+         concurrency: Concurrency
+         debug: Whether to enable debug mode
+     """
+     import logging
+
+     # Configure logging
+     log_level = logging.DEBUG if debug else logging.INFO
+     logging.basicConfig(
+         level=log_level,
+         format=f'%(asctime)s - [{namespace_name}] - %(name)s - %(levelname)s - %(message)s'
+     )
+
+     # Create a temporary manager instance and run the single namespace
+     manager = UnifiedConsumerManager(
+         task_center_url=task_center_url,
+         backlog_monitor_interval=backlog_monitor_interval,
+         concurrency=concurrency,
+         debug=debug
+     )
+
+     # Run the async task
+     try:
+         asyncio.run(manager.run_single_namespace(namespace_name))
+     except KeyboardInterrupt:
+         logging.getLogger(__name__).info("Process received interrupt signal")
+
+
+ __all__ = ['UnifiedConsumerManager']
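
A minimal launch sketch for the manager above, assuming this hunk is jettask/persistence/manager.py (per the file list); the task-center URL is illustrative, and its expected format is defined by UnifiedManagerBase, which also decides between single- and multi-namespace mode:

    import asyncio
    from jettask.persistence.manager import UnifiedConsumerManager

    manager = UnifiedConsumerManager(
        task_center_url="http://localhost:8001",  # hypothetical endpoint
        check_interval=30,
        backlog_monitor_interval=30,
        concurrency=4,
    )
    asyncio.run(manager.run())  # dispatches to single- or multi-namespace mode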
jettask/persistence/persistence.py
@@ -0,0 +1,391 @@
+ """Task persistence module.
+
+ Parses Redis Stream messages and batch-inserts the task data into PostgreSQL.
+ """
+
+ import json
+ import logging
+ import traceback
+ from typing import Dict, List, Optional, Any
+ from datetime import datetime, timezone
+
+ from sqlalchemy.orm import sessionmaker
+ from sqlalchemy.dialects.postgresql import insert
+
+ from jettask.db.models.task import Task
+
+ logger = logging.getLogger(__name__)
+
+
+ class TaskPersistence:
+     """Task persistence handler.
+
+     Responsibilities:
+     - Parse Stream messages into task records
+     - Batch-insert tasks into the PostgreSQL tasks table
+     - Fall back to a degraded strategy when a batch insert fails
+     """
+
+     def __init__(
+         self,
+         async_session_local: sessionmaker,
+         namespace_id: str,
+         namespace_name: str
+     ):
+         """Initialize the task persistence handler.
+
+         Args:
+             async_session_local: SQLAlchemy session factory
+             namespace_id: Namespace ID
+             namespace_name: Namespace name
+         """
+         self.AsyncSessionLocal = async_session_local
+         self.namespace_id = namespace_id
+         self.namespace_name = namespace_name
+
+     def parse_stream_message(self, task_id: str, data: dict) -> Optional[dict]:
+         """Parse a Stream message into task info (returns the full field set).
+
+         Args:
+             task_id: Task ID (Redis Stream ID)
+             data: Message data
+
+         Returns:
+             The parsed task info dict, or None on failure
+         """
+         try:
+             from jettask.utils.serializer import loads_str
+
+             if b'data' in data:
+                 task_data = loads_str(data[b'data'])
+             else:
+                 task_data = {}
+                 for k, v in data.items():
+                     key = k.decode('utf-8') if isinstance(k, bytes) else k
+                     if isinstance(v, bytes):
+                         try:
+                             value = loads_str(v)
+                         except Exception:
+                             value = str(v)
+                     else:
+                         value = v
+                     task_data[key] = value
+
+             # If a namespace is configured, check whether the message belongs to it
+             # if self.namespace_id:
+             #     msg_namespace_id = task_data.get('__namespace_id')
+             #     # If the message has no namespace_id and we are not the default namespace, skip it
+             #     if msg_namespace_id != self.namespace_id:
+             #         if not (msg_namespace_id is None and self.namespace_id == 'default'):
+             #             logger.debug(f"Skipping message from different namespace: {msg_namespace_id} != {self.namespace_id}")
+             #             return None
+
+             queue_name = task_data['queue']
+             task_name = task_data.get('name', task_data.get('task', 'unknown'))
+
+             created_at = None
+             if 'trigger_time' in task_data:
+                 try:
+                     timestamp = float(task_data['trigger_time'])
+                     created_at = datetime.fromtimestamp(timestamp, tz=timezone.utc)
+                 except (TypeError, ValueError):
+                     pass
+
+             # Return the full field set, including every field that may be None
+             return {
+                 'id': task_id,
+                 'queue_name': queue_name,
+                 'task_name': task_name,
+                 'task_data': json.dumps(task_data),
+                 'priority': int(task_data.get('priority', 0)),
+                 'retry_count': int(task_data.get('retry', 0)),
+                 'max_retry': int(task_data.get('max_retry', 3)),
+                 'status': 'pending',
+                 'result': None,  # a new task has no result
+                 'error_message': None,  # a new task has no error message
+                 'created_at': created_at,
+                 'started_at': None,  # a new task has not started yet
+                 'completed_at': None,  # a new task has not completed yet
+                 'scheduled_task_id': task_data.get('scheduled_task_id'),  # scheduled-task ID
+                 'metadata': json.dumps(task_data.get('metadata', {})),
+                 'worker_id': None,  # a new task has no assigned worker yet
+                 'execution_time': None,  # a new task has no execution time yet
+                 'duration': None,  # a new task has no duration yet
+                 'namespace_id': self.namespace_id  # attach the namespace ID
+             }
+
+         except Exception as e:
+             logger.error(f"Error parsing stream message for task {task_id}: {e}")
+             logger.error(traceback.format_exc())
+             return None
+
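
To make the parsing contract concrete, a sketch of the shapes involved (values illustrative; the wire format is whatever jettask.utils.serializer.loads_str understands, and `persistence` stands for a TaskPersistence instance):

    # Raw entry as redis-py delivers it: bytes keys and values.
    raw = {b'data': b'<serialized payload>'}

    # persistence.parse_stream_message('1718000000000-0', raw) returns roughly:
    # {
    #     'id': '1718000000000-0',       # the Redis Stream ID
    #     'queue_name': 'orders',        # from task_data['queue'] (required)
    #     'task_name': 'send_email',     # falls back to 'task', then 'unknown'
    #     'status': 'pending',
    #     'created_at': <UTC datetime>,  # derived from trigger_time, if present
    #     ...                            # remaining lifecycle fields default to None
    # }

Note that a payload missing the 'queue' key raises a KeyError, which the outer except converts into a logged None return, so the message is skipped rather than crashing the consumer.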
+     async def insert_tasks(self, tasks: List[Dict[str, Any]]) -> int:
+         """Batch-insert tasks into PostgreSQL (via the ORM).
+
+         Args:
+             tasks: List of task info dicts
+
+         Returns:
+             Number of records actually inserted
+         """
+         if not tasks:
+             return 0
+
+         logger.info(f"Attempting to insert {len(tasks)} tasks to tasks table")
+
+         try:
+             async with self.AsyncSessionLocal() as session:
+                 # Prepare the rows for the tasks table
+                 tasks_data = []
+                 for task in tasks:
+                     task_data = json.loads(task['task_data'])
+
+                     # Read scheduled_task_id from the task data
+                     scheduled_task_id = task_data.get('scheduled_task_id') or task.get('scheduled_task_id')
+
+                     # Infer the task source from the presence of scheduled_task_id
+                     if scheduled_task_id:
+                         source = 'scheduler'  # scheduled task
+                     else:
+                         source = 'redis_stream'  # regular task
+
+                     tasks_data.append({
+                         'stream_id': task['id'],  # the Redis Stream ID becomes stream_id
+                         'queue': task['queue_name'],
+                         'namespace': self.namespace_name,
+                         'scheduled_task_id': str(scheduled_task_id) if scheduled_task_id else None,
+                         'payload': json.loads(task['task_data']),  # parsed into a dict
+                         'priority': task['priority'],
+                         'created_at': task['created_at'],
+                         'source': source,
+                         'task_metadata': json.loads(task.get('metadata', '{}'))  # maps to the model's task_metadata field
+                     })
+
+                 # Batch insert via the ORM's INSERT ON CONFLICT DO NOTHING
+                 logger.debug(f"Executing batch insert with {len(tasks_data)} tasks")
+
+                 try:
+                     # PostgreSQL insert().on_conflict_do_nothing()
+                     stmt = insert(Task).values(tasks_data).on_conflict_do_nothing(
+                         constraint='tasks_pkey'  # skip on primary-key conflicts
+                     )
+
+                     await session.execute(stmt)
+                     await session.commit()
+
+                     # on_conflict_do_nothing does not report a per-row count here, so assume all succeeded
+                     inserted_count = len(tasks_data)
+                     logger.debug(f"Tasks table batch insert transaction completed: {inserted_count} tasks")
+                     return inserted_count
+
+                 except Exception as e:
+                     logger.error(f"Error in batch insert, trying fallback: {e}")
+                     await session.rollback()
+
+                     # Fall back to row-by-row inserts (more robust)
+                     total_inserted = 0
+
+                     for task_dict in tasks_data:
+                         try:
+                             stmt = insert(Task).values(**task_dict).on_conflict_do_nothing(
+                                 constraint='tasks_pkey'
+                             )
+                             await session.execute(stmt)
+                             await session.commit()
+                             total_inserted += 1
+                         except Exception as single_error:
+                             logger.error(f"Failed to insert task {task_dict.get('stream_id')}: {single_error}")
+                             await session.rollback()
+
+                     if total_inserted > 0:
+                         logger.info(f"Fallback insert completed: {total_inserted} tasks inserted")
+                     else:
+                         logger.info("No new tasks inserted in fallback mode")
+
+                     return total_inserted
+
+         except Exception as e:
+             logger.error(f"Error inserting tasks to PostgreSQL: {e}")
+             logger.error(traceback.format_exc())
+             return 0
+
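
A note on the counting above: with a single multi-row VALUES statement, PostgreSQL reports how many rows survived the conflict check, so a variant that reads the count back instead of assuming len(tasks_data) could look like this sketch (rowcount semantics depend on the DB driver, which is the stated assumption here):

    # Equivalent SQL shape (column list abbreviated):
    #   INSERT INTO tasks (stream_id, queue, namespace, ...) VALUES (...), (...)
    #   ON CONFLICT ON CONSTRAINT tasks_pkey DO NOTHING;
    result = await session.execute(stmt)
    await session.commit()
    inserted_count = result.rowcount  # rows actually inserted; conflicting rows excluded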
+     async def batch_insert_tasks(self, tasks: List[Dict[str, Any]]) -> int:
+         """Batch-insert tasks (interface compatible with the buffer.py caller).
+
+         Args:
+             tasks: List of task records
+
+         Returns:
+             Number of records actually inserted
+         """
+         if not tasks:
+             return 0
+
+         logger.info(f"[BATCH INSERT] Batch-inserting {len(tasks)} tasks...")
+
+         try:
+             async with self.AsyncSessionLocal() as session:
+                 # Prepare the ORM rows
+                 insert_data = []
+                 for record in tasks:
+                     # record uses the format passed in from consumer.py
+                     insert_data.append({
+                         'stream_id': record['stream_id'],
+                         'queue': record['queue'],
+                         'namespace': record['namespace'],
+                         'scheduled_task_id': record.get('scheduled_task_id'),
+                         'payload': record.get('payload', {}),
+                         'priority': record.get('priority', 0),
+                         'created_at': record.get('created_at'),
+                         'source': record.get('source', 'redis_stream'),
+                         'task_metadata': record.get('metadata', {})
+                     })
+
+                 # Batch insert via PostgreSQL INSERT ON CONFLICT DO NOTHING,
+                 # addressed by constraint name rather than column list
+                 stmt = insert(Task).values(insert_data).on_conflict_do_nothing(
+                     constraint='tasks_pkey'
+                 )
+
+                 await session.execute(stmt)
+                 await session.commit()
+
+                 logger.info(f"[BATCH INSERT] ✓ Inserted {len(insert_data)} tasks")
+                 return len(insert_data)
+
+         except Exception as e:
+             logger.error(f"[BATCH INSERT] ✗ Batch insert failed: {e}", exc_info=True)
+             return 0
+
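
A calling sketch for this buffer-compatible entry point; the record keys mirror exactly what the loop above reads, the values are illustrative, and `persistence` again stands for a TaskPersistence instance:

    from datetime import datetime, timezone

    records = [{
        'stream_id': '1718000000000-0',
        'queue': 'orders',
        'namespace': 'default',
        'payload': {'queue': 'orders', 'name': 'send_email'},
        'priority': 0,
        'created_at': datetime.now(timezone.utc),
        'source': 'redis_stream',
        'metadata': {},
    }]
    inserted = await persistence.batch_insert_tasks(records)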
+     async def batch_update_tasks(self, updates: List[Dict[str, Any]]) -> int:
+         """Batch-update task execution state in the task_runs table.
+
+         Uses PostgreSQL INSERT ... ON CONFLICT DO UPDATE to implement an UPSERT:
+         existing records are updated, missing ones are inserted.
+
+         Args:
+             updates: List of update records, each containing:
+                 - stream_id: Redis Stream ID (primary key)
+                 - status: Task status
+                 - result: Execution result
+                 - error: Error message
+                 - started_at: Start time
+                 - completed_at: Completion time
+                 - retries: Retry count
+
+         Returns:
+             Number of records actually updated
+         """
+         if not updates:
+             return 0
+
+         logger.info(f"[BATCH UPDATE] Batch-updating {len(updates)} task states...")
+         logger.info(f"[BATCH UPDATE] Sample update record: {updates[0] if updates else 'N/A'}")
+
+         try:
+             from sqlalchemy.dialects.postgresql import insert
+             from ..db.models import TaskRun
+             from ..utils.serializer import loads_str
+             from datetime import datetime, timezone
+
+             # Deduplicate records sharing a stream_id, keeping the newest:
+             # a dict keyed by stream_id, so later records overwrite earlier ones
+             deduplicated = {}
+             for record in updates:
+                 stream_id = record['stream_id']
+                 deduplicated[stream_id] = record
+
+             # Convert back to a list
+             unique_updates = list(deduplicated.values())
+
+             if len(unique_updates) < len(updates):
+                 logger.info(
+                     f"[BATCH UPDATE] Deduplicated: {len(updates)} → {len(unique_updates)} records "
+                     f"(merged {len(updates) - len(unique_updates)} duplicates)"
+                 )
+
+             async with self.AsyncSessionLocal() as session:
+                 # Prepare the UPSERT rows
+                 upsert_data = []
+                 for record in unique_updates:
+                     logger.debug(f"Processing record: {record}")
+                     # Parse the result field (may be a serialized string)
+                     result = record.get('result')
+                     if result and isinstance(result, bytes):
+                         try:
+                             result = loads_str(result)
+                         except Exception:
+                             result = result.decode('utf-8') if isinstance(result, bytes) else result
+
+                     # Parse the error field
+                     error = record.get('error')
+                     if error and isinstance(error, bytes):
+                         error = error.decode('utf-8')
+
+                     # Compute the execution duration
+                     duration = None
+                     started_at = record.get('started_at')
+                     completed_at = record.get('completed_at')
+                     if started_at and completed_at:
+                         duration = completed_at - started_at
+
+                     # Parse the status field
+                     status = record.get('status')
+                     if status and isinstance(status, bytes):
+                         status = status.decode('utf-8')
+
+                     # Parse the consumer field
+                     consumer = record.get('consumer')
+                     if consumer and isinstance(consumer, bytes):
+                         consumer = consumer.decode('utf-8')
+
+                     upsert_record = {
+                         'stream_id': record['stream_id'],
+                         'status': status,
+                         'result': result,
+                         'error': error,
+                         'started_at': started_at,
+                         'completed_at': completed_at,
+                         'retries': record.get('retries', 0),
+                         'duration': duration,
+                         'consumer': consumer,
+                         'updated_at': datetime.now(timezone.utc),
+                     }
+                     logger.debug(f"upsert_record: {upsert_record}")
+                     upsert_data.append(upsert_record)
+
+                 logger.info(f"[BATCH UPDATE] Writing {len(upsert_data)} records")
+
+                 # Batch UPSERT: update when the row exists, insert otherwise
+                 stmt = insert(TaskRun).values(upsert_data)
+
+                 # Conflict-resolution strategy:
+                 # use COALESCE so NULLs never overwrite existing data
+                 from sqlalchemy import func
+                 stmt = stmt.on_conflict_do_update(
+                     constraint='task_runs_pkey',
+                     set_={
+                         # status is always updated (state transitions)
+                         'status': stmt.excluded.status,
+                         # other fields: take the new value unless it is NULL, else keep the old one
+                         'result': func.coalesce(stmt.excluded.result, TaskRun.result),
+                         'error': func.coalesce(stmt.excluded.error, TaskRun.error),
+                         'started_at': func.coalesce(stmt.excluded.started_at, TaskRun.started_at),
+                         'completed_at': func.coalesce(stmt.excluded.completed_at, TaskRun.completed_at),
+                         'retries': func.coalesce(stmt.excluded.retries, TaskRun.retries),
+                         'duration': func.coalesce(stmt.excluded.duration, TaskRun.duration),
+                         'consumer': func.coalesce(stmt.excluded.consumer, TaskRun.consumer),
+                         # updated_at is always refreshed
+                         'updated_at': stmt.excluded.updated_at,
+                     }
+                 )
+
+                 await session.execute(stmt)
+                 await session.commit()
+
+                 logger.info(f"[BATCH UPDATE] ✓ Updated {len(upsert_data)} task states")
+                 return len(upsert_data)
+
+         except Exception as e:
+             logger.error(f"[BATCH UPDATE] ✗ Batch update failed: {e}", exc_info=True)
+             return 0
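
The COALESCE strategy is what makes partial updates safe: a record that carries only the terminal status does not wipe fields written by an earlier update. A worked sketch (`persistence` is a TaskPersistence instance; datetime timestamps are an assumption here, since duration is computed by plain subtraction of whatever type the caller supplies):

    from datetime import datetime, timezone
    t1 = datetime.now(timezone.utc)

    # Stored row:      status='running', started_at=t0, result=NULL
    # Incoming record: status='success', completed_at=t1 (started_at and result absent)
    await persistence.batch_update_tasks([{
        'stream_id': '1718000000000-0',
        'status': 'success',
        'completed_at': t1,
    }])
    # After the UPSERT: status='success', started_at=t0 (kept via COALESCE),
    # completed_at=t1, updated_at refreshed; result stays NULL on both sides.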