service-forge 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- service_forge/api/http_api.py +138 -0
- service_forge/api/kafka_api.py +126 -0
- service_forge/api/task_manager.py +141 -0
- service_forge/api/websocket_api.py +86 -0
- service_forge/api/websocket_manager.py +425 -0
- service_forge/db/__init__.py +1 -0
- service_forge/db/database.py +119 -0
- service_forge/llm/__init__.py +62 -0
- service_forge/llm/llm.py +56 -0
- service_forge/main.py +121 -0
- service_forge/model/__init__.py +0 -0
- service_forge/model/websocket.py +13 -0
- service_forge/proto/foo_input.py +5 -0
- service_forge/service.py +111 -0
- service_forge/service_config.py +115 -0
- service_forge/sft/cli.py +91 -0
- service_forge/sft/cmd/config_command.py +67 -0
- service_forge/sft/cmd/deploy_service.py +124 -0
- service_forge/sft/cmd/list_tars.py +41 -0
- service_forge/sft/cmd/service_command.py +149 -0
- service_forge/sft/cmd/upload_service.py +36 -0
- service_forge/sft/config/injector.py +87 -0
- service_forge/sft/config/injector_default_files.py +97 -0
- service_forge/sft/config/sf_metadata.py +30 -0
- service_forge/sft/config/sft_config.py +125 -0
- service_forge/sft/file/__init__.py +0 -0
- service_forge/sft/file/ignore_pattern.py +80 -0
- service_forge/sft/file/sft_file_manager.py +107 -0
- service_forge/sft/kubernetes/kubernetes_manager.py +257 -0
- service_forge/sft/util/assert_util.py +25 -0
- service_forge/sft/util/logger.py +16 -0
- service_forge/sft/util/name_util.py +2 -0
- service_forge/utils/__init__.py +0 -0
- service_forge/utils/default_type_converter.py +12 -0
- service_forge/utils/register.py +39 -0
- service_forge/utils/type_converter.py +74 -0
- service_forge/workflow/__init__.py +1 -0
- service_forge/workflow/context.py +13 -0
- service_forge/workflow/edge.py +31 -0
- service_forge/workflow/node.py +179 -0
- service_forge/workflow/nodes/__init__.py +7 -0
- service_forge/workflow/nodes/control/if_node.py +29 -0
- service_forge/workflow/nodes/input/console_input_node.py +26 -0
- service_forge/workflow/nodes/llm/query_llm_node.py +41 -0
- service_forge/workflow/nodes/nested/workflow_node.py +28 -0
- service_forge/workflow/nodes/output/kafka_output_node.py +27 -0
- service_forge/workflow/nodes/output/print_node.py +29 -0
- service_forge/workflow/nodes/test/if_console_input_node.py +33 -0
- service_forge/workflow/nodes/test/time_consuming_node.py +61 -0
- service_forge/workflow/port.py +86 -0
- service_forge/workflow/trigger.py +20 -0
- service_forge/workflow/triggers/__init__.py +4 -0
- service_forge/workflow/triggers/fast_api_trigger.py +125 -0
- service_forge/workflow/triggers/kafka_api_trigger.py +44 -0
- service_forge/workflow/triggers/once_trigger.py +20 -0
- service_forge/workflow/triggers/period_trigger.py +26 -0
- service_forge/workflow/workflow.py +251 -0
- service_forge/workflow/workflow_factory.py +227 -0
- service_forge/workflow/workflow_group.py +23 -0
- service_forge/workflow/workflow_type.py +52 -0
- service_forge-0.1.0.dist-info/METADATA +93 -0
- service_forge-0.1.0.dist-info/RECORD +64 -0
- service_forge-0.1.0.dist-info/WHEEL +4 -0
- service_forge-0.1.0.dist-info/entry_points.txt +2 -0
service_forge/api/websocket_manager.py
ADDED
@@ -0,0 +1,425 @@
from __future__ import annotations
import asyncio
import uuid
import json
from typing import Dict, List, Set, Any
from fastapi import WebSocket, WebSocketDisconnect
from loguru import logger
from .task_manager import TaskManager

class WebSocketManager:
    def __init__(self):
        # Active connections: {client_id: websocket}
        self.active_connections: Dict[str, WebSocket] = {}
        # Task-to-client mapping: {task_id: client_id}
        self.task_client_mapping: Dict[uuid.UUID, str] = {}
        # Tasks each client is subscribed to: {client_id: set(task_id)}
        self.client_task_subscriptions: Dict[str, Set[uuid.UUID]] = {}
        # Client history, used to restore subscriptions on reconnect: {client_id: last_active_time}
        self.client_history: Dict[str, float] = {}
        # Client record expiry time in seconds, half an hour by default
        self.client_history_expiry = 0.5 * 60 * 60
        # Initialize the task manager
        self.task_manager = TaskManager()
        # Start the periodic cleanup task
        self._cleanup_task = None
        self._start_cleanup_task()

    async def connect(self, websocket: WebSocket, client_id: str = None) -> str:
        """Accept a WebSocket connection, using the given client ID or generating a new one."""
        await websocket.accept()

        # Generate a new client ID if none was provided
        if client_id is None:
            client_id = f"client_{uuid.uuid4().hex[:12]}"

        # If the client ID already has a connection, close the old one first
        if client_id in self.active_connections:
            logger.warning(f"Client {client_id} already has a connection; closing the old one")
            await self.active_connections[client_id].close()

        # Update connection records
        self.active_connections[client_id] = websocket
        self.client_history[client_id] = asyncio.get_event_loop().time()
        logger.info(f"Client {client_id} connected via WebSocket")

        # Make sure the cleanup task has been started
        self._start_cleanup_task()

        # Send a connection acknowledgement containing the client ID and restored subscriptions
        connection_message = {
            "type": "connection established",
            "client_id": client_id,
            "timestamp": str(asyncio.get_event_loop().time()),
            "restored_subscriptions": []
        }

        # Restore any historical subscriptions
        if client_id in self.client_task_subscriptions and self.client_task_subscriptions[client_id]:
            restored_tasks = []
            for task_id in self.client_task_subscriptions[client_id]:
                restored_tasks.append(str(task_id))
                logger.info(f"Restored client {client_id}'s subscription to task {task_id}")

            connection_message["restored_subscriptions"] = restored_tasks

        await self.send_personal_message(json.dumps(connection_message), client_id)
        return client_id

    def disconnect(self, client_id: str):
        """Disconnect a WebSocket connection but keep the client's subscription info."""
        if client_id in self.active_connections:
            # Remove the connection record but keep the subscriptions
            del self.active_connections[client_id]
            # Update the client's last-active time
            self.client_history[client_id] = asyncio.get_event_loop().time()
            logger.info(f"Client {client_id} disconnected from WebSocket; subscriptions retained")

    async def subscribe_to_task(self, client_id: str, task_id: uuid.UUID) -> bool:
        """Subscribe a client to a task."""
        if client_id not in self.client_task_subscriptions:
            self.client_task_subscriptions[client_id] = set()

        # Add the task to the client's subscription set
        self.client_task_subscriptions[client_id].add(task_id)
        logger.info(f"Client {client_id} subscribed to task {task_id}")
        return True

    async def unsubscribe_from_task(self, client_id: str, task_id: uuid.UUID) -> bool:
        """Unsubscribe a client from a task."""
        if client_id not in self.client_task_subscriptions:
            return False

        # Remove the task from the client's subscription set
        if task_id in self.client_task_subscriptions[client_id]:
            self.client_task_subscriptions[client_id].remove(task_id)
            logger.info(f"Client {client_id} unsubscribed from task {task_id}")
            return True

        return False

    def create_task_with_client(self, task_id: uuid.UUID, client_id: str, workflow_name: str = "Unknown", steps: int = 1) -> bool:
        """Create a task-to-client mapping and register the task with the task manager."""
        # Map the task to the client
        self.task_client_mapping[task_id] = client_id

        # Automatically add the task to the client's subscriptions
        if client_id not in self.client_task_subscriptions:
            self.client_task_subscriptions[client_id] = set()
        self.client_task_subscriptions[client_id].add(task_id)

        # Register the task with the task manager
        self.task_manager.add_task(task_id, client_id, workflow_name, steps)

        logger.info(f"Mapped task {task_id} to client {client_id}")
        return True

    async def send_personal_message(self, message: str, client_id: str):
        """Send a message to a specific client."""
        if client_id in self.active_connections:
            try:
                await self.active_connections[client_id].send_text(message)
                return True
            except Exception as e:
                logger.error(f"Failed to send message to client {client_id}: {e}")
                self.disconnect(client_id)
                return False
        return False

    async def send_to_task_client(self, task_id: uuid.UUID, message: Dict[str, Any]):
        """Send a message to the client associated with a task."""
        if task_id not in self.task_client_mapping:
            return  # No associated client

        client_id = self.task_client_mapping[task_id]

        # Make sure task_id is a string to avoid JSON serialization issues
        if "task_id" in message and isinstance(message["task_id"], uuid.UUID):
            message["task_id"] = str(message["task_id"])

        # Recursively convert UUIDs inside nested dictionaries
        def convert_uuids(obj):
            if isinstance(obj, dict):
                return {k: convert_uuids(v) for k, v in obj.items()}
            elif isinstance(obj, list):
                return [convert_uuids(item) for item in obj]
            elif isinstance(obj, uuid.UUID):
                return str(obj)
            else:
                return obj

        message = convert_uuids(message)
        message_str = json.dumps(message)
        await self.send_personal_message(message_str, client_id)

    async def send_task_status(self, task_id: uuid.UUID, status: str, node: str = None, progress: float = None, error: str = None):
        """Send a task status update."""
        message = {
            "task_id": str(task_id),
            "type": "status",
            "status": status
        }

        if node is not None:
            message["node"] = node

        if progress is not None:
            message["progress"] = progress

        if error is not None:
            message["error"] = error

        await self.send_to_task_client(task_id, message)

    async def send_execution_start(self, task_id: uuid.UUID):
        """Send a task-execution-start message."""
        # Look up the client ID
        client_id = self.task_client_mapping.get(task_id)

        # Mark the task as running
        self.task_manager.start_task(task_id)

        # Get the client's task queue info
        client_tasks = []
        if client_id:
            client_tasks = self.task_manager.get_client_tasks(client_id)

        # Get the global task queue info
        global_queue_info = self.task_manager.get_global_queue_info()

        # Get the current task's position in the queue
        queue_position = self.task_manager.get_queue_position(task_id)

        message = {
            "task_id": str(task_id),
            "type": "execution start",
            "client_tasks": {
                "total": len(client_tasks),
                "tasks": client_tasks
            },
            "global_queue": global_queue_info,
            "queue_position": queue_position
        }
        await self.send_to_task_client(task_id, message)

    async def send_executing(self, task_id: uuid.UUID, node: str):
        """Send a node-executing message."""
        # Look up the client ID
        client_id = self.task_client_mapping.get(task_id)

        # Get the client's task queue info
        client_tasks = []
        if client_id:
            client_tasks = self.task_manager.get_client_tasks(client_id)

        # Get the global task queue info
        global_queue_info = self.task_manager.get_global_queue_info()

        # Get the current task's position in the queue
        queue_position = self.task_manager.get_queue_position(task_id)

        message = {
            "task_id": str(task_id),
            "type": "executing",
            "node": node,
            "client_tasks": {
                "total": len(client_tasks),
                "tasks": client_tasks
            },
            "global_queue": global_queue_info,
            "queue_position": queue_position
        }
        await self.send_to_task_client(task_id, message)

    async def send_progress(self, task_id: uuid.UUID, node: str, progress: float):
        """Send a node progress message."""
        # Look up the client ID
        client_id = self.task_client_mapping.get(task_id)

        # Get the client's task queue info
        client_tasks = []
        if client_id:
            client_tasks = self.task_manager.get_client_tasks(client_id)

        # Get the global task queue info
        global_queue_info = self.task_manager.get_global_queue_info()

        # Get the current task's position in the queue
        queue_position = self.task_manager.get_queue_position(task_id)

        message = {
            "task_id": str(task_id),
            "type": "progress",
            "node": node,
            "progress": progress,
            "client_tasks": {
                "total": len(client_tasks),
                "tasks": client_tasks
            },
            "global_queue": global_queue_info,
            "queue_position": queue_position
        }
        await self.send_to_task_client(task_id, message)

    async def send_executed(self, task_id: uuid.UUID, node: str, result: Any = None):
        """Send a node-executed message."""
        # Look up the client ID
        client_id = self.task_client_mapping.get(task_id)

        # Get the client's task queue info
        client_tasks = []
        if client_id:
            client_tasks = self.task_manager.get_client_tasks(client_id)

        # Get the global task queue info
        global_queue_info = self.task_manager.get_global_queue_info()

        # Get the current task's position in the queue
        queue_position = self.task_manager.get_queue_position(task_id)

        message = {
            "task_id": str(task_id),
            "type": "executed",
            "node": node,
            "client_tasks": {
                "total": len(client_tasks),
                "tasks": client_tasks
            },
            "global_queue": global_queue_info,
            "queue_position": queue_position
        }

        if result is not None:
            # Check whether the result is a coroutine object
            import asyncio
            if asyncio.iscoroutine(result):
                message["result"] = "<coroutine object>"
            else:
                # Try to serialize the result; fall back to a string on failure
                try:
                    message["result"] = result
                except TypeError:
                    message["result"] = str(result)

        await self.send_to_task_client(task_id, message)

    async def send_execution_error(self, task_id: uuid.UUID, node: str, error: str):
        """Send an execution error message."""
        # Look up the client ID
        client_id = self.task_client_mapping.get(task_id)

        # Mark the task as failed
        self.task_manager.fail_task(task_id, error)

        # Get the client's task queue info
        client_tasks = []
        if client_id:
            client_tasks = self.task_manager.get_client_tasks(client_id)

        # Get the global task queue info
        global_queue_info = self.task_manager.get_global_queue_info()

        # Get the current task's position in the queue
        queue_position = self.task_manager.get_queue_position(task_id)

        message = {
            "task_id": str(task_id),
            "type": "execution error",
            "node": node,
            "error": error,
            "client_tasks": {
                "total": len(client_tasks),
                "tasks": client_tasks
            },
            "global_queue": global_queue_info,
            "queue_position": queue_position
        }
        await self.send_to_task_client(task_id, message)

    def _start_cleanup_task(self):
        """Start the periodic cleanup task."""
        # Check whether there is a running event loop
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No running event loop; defer starting the cleanup task
            return

        if self._cleanup_task is None:
            self._cleanup_task = loop.create_task(self._cleanup_expired_clients())

    async def _cleanup_expired_clients(self):
        """Periodically clean up expired client records."""
        while True:
            try:
                # Run cleanup once an hour
                await asyncio.sleep(60 * 60)
                current_time = asyncio.get_event_loop().time()
                expired_clients = []

                # Find expired client records
                for client_id, last_active_time in self.client_history.items():
                    # Mark the client as expired if it has no active connection and has passed the expiry window
                    if (client_id not in self.active_connections and
                        current_time - last_active_time > self.client_history_expiry):
                        expired_clients.append(client_id)

                # Clean up the subscription records of expired clients
                for client_id in expired_clients:
                    # Remove the client's subscription records
                    if client_id in self.client_task_subscriptions:
                        del self.client_task_subscriptions[client_id]

                    # Remove the client's history record
                    del self.client_history[client_id]

                    # Clean up expired entries in the task mapping
                    tasks_to_remove = []
                    for task_id, mapped_client_id in self.task_client_mapping.items():
                        if mapped_client_id == client_id:
                            tasks_to_remove.append(task_id)

                    for task_id in tasks_to_remove:
                        del self.task_client_mapping[task_id]

                    logger.info(f"Cleared all records for expired client {client_id}")

                if expired_clients:
                    logger.info(f"Cleaned up {len(expired_clients)} expired client record(s)")

            except Exception as e:
                logger.error(f"Error while cleaning up expired clients: {e}")
                # Keep trying to clean up after an error; do not abort the task

    def set_client_history_expiry(self, seconds: int):
        """Set the client record expiry time."""
        self.client_history_expiry = seconds
        logger.info(f"Client record expiry set to {seconds} seconds")

    async def send_node_output(self, task_id: uuid.UUID, node: str, port: str, value: Any):
        """Send a node output message."""
        message = {
            "task_id": str(task_id),
            "type": "node output",
            "node": node,
            "port": port
        }

        # Try to serialize the value as JSON
        try:
            import json
            # Use primitive values directly
            if isinstance(value, (str, int, float, bool)) or value is None:
                message["value"] = value
            else:
                # Try JSON-serializing the complex object
                json.dumps(value)  # Test whether it can be serialized
                message["value"] = value
        except (TypeError, ValueError, json.JSONDecodeError):
            # Fall back to a string if serialization fails
            message["value"] = str(value)

        await self.send_to_task_client(task_id, message)

# Global WebSocket manager instance
websocket_manager = WebSocketManager()
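For orientation, here is a minimal sketch of how the manager above might be wired into a FastAPI WebSocket endpoint. The package's actual wiring lives in service_forge/api/websocket_api.py, which is not shown in this excerpt; the route path and the "action"/"task_id" client message shape below are illustrative assumptions, not the package's documented protocol.

import uuid
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from service_forge.api.websocket_manager import websocket_manager  # global instance defined above

app = FastAPI()

@app.websocket("/ws")  # route path is an assumption
async def websocket_endpoint(websocket: WebSocket, client_id: str = None):
    # Accept the connection; a client_id query parameter lets a client restore its subscriptions
    client_id = await websocket_manager.connect(websocket, client_id)
    try:
        while True:
            data = await websocket.receive_json()
            # Assumed client message shape: {"action": "subscribe", "task_id": "<uuid>"}
            task_id = uuid.UUID(data["task_id"])
            if data.get("action") == "subscribe":
                await websocket_manager.subscribe_to_task(client_id, task_id)
            elif data.get("action") == "unsubscribe":
                await websocket_manager.unsubscribe_from_task(client_id, task_id)
    except WebSocketDisconnect:
        # Keep the client's subscriptions so they can be restored on reconnect
        websocket_manager.disconnect(client_id)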
service_forge/db/__init__.py
ADDED
@@ -0,0 +1 @@
from .database import DatabaseManager, Database
service_forge/db/database.py
ADDED
@@ -0,0 +1,119 @@
from __future__ import annotations

from typing import AsyncGenerator

from loguru import logger
from omegaconf import OmegaConf
from sqlalchemy.ext.asyncio import AsyncEngine, AsyncSession, async_sessionmaker, create_async_engine

class Database:
    def __init__(
        self,
        name: str,
        postgres_user: str,
        postgres_password: str,
        postgres_host: str,
        postgres_port: int,
        postgres_db: str,
    ) -> None:
        self.name = name
        self.postgres_user = postgres_user
        self.postgres_password = postgres_password
        self.postgres_host = postgres_host
        self.postgres_port = postgres_port
        self.postgres_db = postgres_db
        self.engine = None
        self.session_factory = None

    @property
    def database_url(self) -> str:
        return f"postgresql+asyncpg://{self.postgres_user}:{self.postgres_password}@{self.postgres_host}:{self.postgres_port}/{self.postgres_db}"

    @property
    def database_base_url(self) -> str:
        return f"postgresql+asyncpg://{self.postgres_user}:{self.postgres_password}@{self.postgres_host}:{self.postgres_port}/postgres"

    async def init(self) -> None:
        if self.engine is None:
            self.engine = await self.create_engine()
            self.session_factory = async_sessionmaker(bind=self.engine, class_=AsyncSession, expire_on_commit=False)

    async def close(self) -> None:
        if self.engine:
            await self.engine.dispose()
            self.engine = None
            self.session_factory = None
            logger.info("Database connection closed")

    async def create_engine(self) -> AsyncEngine:
        if not all([self.postgres_user, self.postgres_host, self.postgres_port, self.postgres_db]):
            raise ValueError("Missing required database configuration. Please check your .env file or configuration.")
        logger.info(f"Creating database engine: {self.database_url}")
        return create_async_engine(self.database_url)

    async def get_async_session(self) -> AsyncGenerator[AsyncSession, None]:
        if self.session_factory is None:
            await self.init()

        if self.session_factory is None:
            raise RuntimeError("Session factory is not initialized")

        async with self.session_factory() as session:
            try:
                yield session
            except Exception:
                await session.rollback()
                raise
            finally:
                await session.close()

    async def get_session_factory(self) -> async_sessionmaker[AsyncSession]:
        if self.engine is None:
            await self.init()

        if self.session_factory is None:
            raise RuntimeError("Session factory is not initialized")

        return self.session_factory

class DatabaseManager:
    def __init__(
        self,
        databases: list[Database],
    ) -> None:
        self.databases = databases

    def get_database(self, name: str) -> Database | None:
        for database in self.databases:
            if database.name == name:
                return database
        return None

    def get_default_database(self) -> Database | None:
        if len(self.databases) > 0:
            return self.databases[0]
        return None

    @staticmethod
    def from_config(config_path: str = None, config = None) -> DatabaseManager:
        if config is None:
            config = OmegaConf.to_object(OmegaConf.load(config_path))

        databases_config = config.get('databases', None)
        databases = []
        if databases_config is not None:
            for database_config in databases_config:
                if 'postgres_db' in database_config and database_config['postgres_db'] is not None:
                    databases.append(Database(
                        name=database_config['name'],
                        postgres_user=database_config['postgres_user'],
                        postgres_password=database_config['postgres_password'],
                        postgres_host=database_config['postgres_host'],
                        postgres_port=database_config['postgres_port'],
                        postgres_db=database_config['postgres_db'],
                    ))
        return DatabaseManager(databases=databases)

def create_database_manager(config_path: str = None, config = None) -> DatabaseManager:
    return DatabaseManager.from_config(config_path, config)
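As a usage sketch (not part of the package), the inline dict below mirrors the keys that DatabaseManager.from_config reads; the connection values are placeholders and the example assumes a reachable PostgreSQL instance with the asyncpg driver installed.

import asyncio
from sqlalchemy import text
from service_forge.db import DatabaseManager  # re-exported in service_forge/db/__init__.py

# Inline config with the keys from_config expects; values are placeholders
config = {
    "databases": [
        {
            "name": "main",
            "postgres_user": "forge",
            "postgres_password": "secret",
            "postgres_host": "localhost",
            "postgres_port": 5432,
            "postgres_db": "forge_db",
        }
    ]
}

async def main():
    manager = DatabaseManager.from_config(config=config)
    db = manager.get_default_database()
    # get_async_session() is an async generator; iterate it to obtain a session
    async for session in db.get_async_session():
        result = await session.execute(text("SELECT 1"))
        print(result.scalar())
    await db.close()

asyncio.run(main())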
service_forge/llm/__init__.py
ADDED
@@ -0,0 +1,62 @@
import os
from .llm import LLM
from enum import Enum
from typing import Iterator

_llm_dicts = {}

class Model(Enum):
    GPT_4_1_NANO = "gpt-4.1-nano"
    QWEN_TURBO_LATEST = "qwen-turbo-latest"
    QWEN_PLUS_LATEST = "qwen-plus-latest"
    QWEN_MAX_LATEST = "qwen-max-latest"
    DOUBO_SEED_1_6_250615 = "doubao-seed-1-6-250615"
    DOUBO_SEED_1_6_THINKING_250615 = "doubao-seed-1-6-thinking-250615"
    DOUBO_SEED_1_6_FLASH_250615 = "doubao-seed-1-6-flash-250615"
    DEEPSEEK_V3_250324 = "deepseek-v3-250324"
    AZURE_GPT_4O_MINI = "azure-gpt-4o-mini"

    def provider(self) -> str:
        if self.value.startswith("gpt"):
            return "openai"
        elif self.value.startswith("qwen"):
            return "dashscope"
        elif self.value.startswith("doubao"):
            return "doubao"
        elif self.value.startswith("deepseek"):
            return "deepseek"
        elif self.value.startswith("azure"):
            return "azure"
        raise ValueError(f"Invalid model: {self.value}")

def get_model(model: str) -> Model:
    if model in Model.__members__:
        return Model[model]

    model = model.upper().replace("-", "_")
    if model in Model.__members__:
        return Model[model]

    raise ValueError(f"Invalid model: {model}")

def get_llm(provider: str) -> LLM:
    if provider not in _llm_dicts:
        if provider == "openai":
            _llm_dicts[provider] = LLM(os.environ.get("OPENAI_API_KEY", ""), os.environ.get("OPENAI_BASE_URL", ""), int(os.environ.get("OPENAI_TIMEOUT", 2000)))
        elif provider == "doubao":
            _llm_dicts[provider] = LLM(os.environ.get("DOUBAO_API_KEY", ""), os.environ.get("DOUBAO_BASE_URL", ""), int(os.environ.get("DOUBAO_TIMEOUT", 2000)))
        elif provider == "dashscope":
            _llm_dicts[provider] = LLM(os.environ.get("DASHSCOPE_API_KEY", ""), os.environ.get("DASHSCOPE_BASE_URL", ""), int(os.environ.get("DASHSCOPE_TIMEOUT", 2000)))
        elif provider == "deepseek":
            _llm_dicts[provider] = LLM(os.environ.get("DEEPSEEK_API_KEY", ""), os.environ.get("DEEPSEEK_BASE_URL", ""), int(os.environ.get("DEEPSEEK_TIMEOUT", 2000)))
        elif provider == "azure":
            _llm_dicts[provider] = LLM(os.environ.get("AZURE_API_KEY", ""), os.environ.get("AZURE_BASE_URL", ""), int(os.environ.get("AZURE_TIMEOUT", 2000)), os.environ.get("AZURE_API_VERSION", ""))
        else:
            raise ValueError(f"Invalid provider: {provider}")
    return _llm_dicts[provider]

def chat(input: str, system_prompt: str, model: Model, temperature: float) -> str:
    return get_llm(model.provider()).chat(input, system_prompt, model.value, temperature)

def chat_stream(input: str, system_prompt: str, model: Model, temperature: float) -> Iterator[str]:
    return get_llm(model.provider()).chat_stream(input, system_prompt, model.value, temperature)
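A hedged usage sketch for the helpers above: credentials and base URLs are read from environment variables (for example DASHSCOPE_API_KEY and DASHSCOPE_BASE_URL for qwen models) and must be exported beforehand, since the defaults are empty strings; the prompts and temperature below are arbitrary examples.

from service_forge.llm import get_model, chat, chat_stream

# "qwen-turbo-latest" normalizes to the enum member QWEN_TURBO_LATEST
model = get_model("qwen-turbo-latest")
print(chat("Say hello", "You are a terse assistant.", model, temperature=0.2))

# chat_stream yields the reply incrementally
for token in chat_stream("Count to three", "You are a terse assistant.", model, temperature=0.2):
    print(token, end="", flush=True)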
service_forge/llm/llm.py
ADDED
@@ -0,0 +1,56 @@
import random
from openai import OpenAI
from openai import AzureOpenAI
from typing import Iterator

class LLM():
    def __init__(self, api_key: str, base_url: str, timeout: int, api_version: str | None = None):
        if api_version is not None:
            self.client = AzureOpenAI(
                api_key=api_key,
                azure_endpoint=base_url,
                timeout=timeout,
                api_version=api_version,
            )
        else:
            self.client = OpenAI(
                api_key=api_key,
                base_url=base_url,
                timeout=timeout,
            )

    def chat(self, input: str, system_prompt: str, model: str, temperature: float) -> str:
        if model.startswith("azure"):
            model = model.replace("azure-", "")

        response = self.client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": input},
            ],
            temperature=temperature,
        )

        if response.choices[0].message.content is None:
            return "Error"
        else:
            return response.choices[0].message.content

    def chat_stream(self, input: str, system_prompt: str, model: str, temperature: float) -> Iterator[str]:
        if model.startswith("azure"):
            model = model.replace("azure-", "")

        stream = self.client.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": input},
            ],
            temperature=temperature,
            stream=True,
        )

        for chunk in stream:
            if chunk.choices[0].delta.content is not None:
                yield chunk.choices[0].delta.content
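The wrapper can also be used directly; in the sketch below the keys, endpoints, and Azure api_version are placeholders. Passing api_version selects the AzureOpenAI client, and chat() strips the "azure-" prefix from the model name before the request is sent, so the remainder is treated as the Azure deployment name.

from service_forge.llm import LLM  # re-exported from service_forge/llm/llm.py

# Standard OpenAI-compatible endpoint; values are placeholders
llm = LLM(api_key="sk-...", base_url="https://api.openai.com/v1", timeout=2000)
print(llm.chat("Ping?", "Reply with one word.", model="gpt-4.1-nano", temperature=0.0))

# Azure variant: api_version switches the client, "azure-" is stripped from the model name
azure_llm = LLM(api_key="...", base_url="https://example.openai.azure.com", timeout=2000,
                api_version="2024-02-01")
print(azure_llm.chat("Ping?", "Reply with one word.", model="azure-gpt-4o-mini", temperature=0.0))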