service-forge 0.1.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of service-forge might be problematic. Click here for more details.
- service_forge/api/deprecated_websocket_api.py +86 -0
- service_forge/api/deprecated_websocket_manager.py +425 -0
- service_forge/api/http_api.py +152 -0
- service_forge/api/http_api_doc.py +455 -0
- service_forge/api/kafka_api.py +126 -0
- service_forge/api/routers/feedback/feedback_router.py +148 -0
- service_forge/api/routers/service/service_router.py +127 -0
- service_forge/api/routers/websocket/websocket_manager.py +83 -0
- service_forge/api/routers/websocket/websocket_router.py +78 -0
- service_forge/api/task_manager.py +141 -0
- service_forge/current_service.py +14 -0
- service_forge/db/__init__.py +1 -0
- service_forge/db/database.py +237 -0
- service_forge/db/migrations/feedback_migration.py +154 -0
- service_forge/db/models/__init__.py +0 -0
- service_forge/db/models/feedback.py +33 -0
- service_forge/llm/__init__.py +67 -0
- service_forge/llm/llm.py +56 -0
- service_forge/model/__init__.py +0 -0
- service_forge/model/feedback.py +30 -0
- service_forge/model/websocket.py +13 -0
- service_forge/proto/foo_input.py +5 -0
- service_forge/service.py +280 -0
- service_forge/service_config.py +44 -0
- service_forge/sft/cli.py +91 -0
- service_forge/sft/cmd/config_command.py +67 -0
- service_forge/sft/cmd/deploy_service.py +123 -0
- service_forge/sft/cmd/list_tars.py +41 -0
- service_forge/sft/cmd/service_command.py +149 -0
- service_forge/sft/cmd/upload_service.py +36 -0
- service_forge/sft/config/injector.py +129 -0
- service_forge/sft/config/injector_default_files.py +131 -0
- service_forge/sft/config/sf_metadata.py +30 -0
- service_forge/sft/config/sft_config.py +200 -0
- service_forge/sft/file/__init__.py +0 -0
- service_forge/sft/file/ignore_pattern.py +80 -0
- service_forge/sft/file/sft_file_manager.py +107 -0
- service_forge/sft/kubernetes/kubernetes_manager.py +257 -0
- service_forge/sft/util/assert_util.py +25 -0
- service_forge/sft/util/logger.py +16 -0
- service_forge/sft/util/name_util.py +8 -0
- service_forge/sft/util/yaml_utils.py +57 -0
- service_forge/storage/__init__.py +5 -0
- service_forge/storage/feedback_storage.py +245 -0
- service_forge/utils/__init__.py +0 -0
- service_forge/utils/default_type_converter.py +12 -0
- service_forge/utils/register.py +39 -0
- service_forge/utils/type_converter.py +99 -0
- service_forge/utils/workflow_clone.py +124 -0
- service_forge/workflow/__init__.py +1 -0
- service_forge/workflow/context.py +14 -0
- service_forge/workflow/edge.py +24 -0
- service_forge/workflow/node.py +184 -0
- service_forge/workflow/nodes/__init__.py +8 -0
- service_forge/workflow/nodes/control/if_node.py +29 -0
- service_forge/workflow/nodes/control/switch_node.py +28 -0
- service_forge/workflow/nodes/input/console_input_node.py +26 -0
- service_forge/workflow/nodes/llm/query_llm_node.py +41 -0
- service_forge/workflow/nodes/nested/workflow_node.py +28 -0
- service_forge/workflow/nodes/output/kafka_output_node.py +27 -0
- service_forge/workflow/nodes/output/print_node.py +29 -0
- service_forge/workflow/nodes/test/if_console_input_node.py +33 -0
- service_forge/workflow/nodes/test/time_consuming_node.py +62 -0
- service_forge/workflow/port.py +89 -0
- service_forge/workflow/trigger.py +28 -0
- service_forge/workflow/triggers/__init__.py +6 -0
- service_forge/workflow/triggers/a2a_api_trigger.py +257 -0
- service_forge/workflow/triggers/fast_api_trigger.py +201 -0
- service_forge/workflow/triggers/kafka_api_trigger.py +47 -0
- service_forge/workflow/triggers/once_trigger.py +23 -0
- service_forge/workflow/triggers/period_trigger.py +29 -0
- service_forge/workflow/triggers/websocket_api_trigger.py +189 -0
- service_forge/workflow/workflow.py +227 -0
- service_forge/workflow/workflow_callback.py +141 -0
- service_forge/workflow/workflow_config.py +66 -0
- service_forge/workflow/workflow_event.py +15 -0
- service_forge/workflow/workflow_factory.py +246 -0
- service_forge/workflow/workflow_group.py +51 -0
- service_forge/workflow/workflow_type.py +52 -0
- service_forge-0.1.18.dist-info/METADATA +98 -0
- service_forge-0.1.18.dist-info/RECORD +83 -0
- service_forge-0.1.18.dist-info/WHEEL +4 -0
- service_forge-0.1.18.dist-info/entry_points.txt +2 -0
|
@@ -0,0 +1,86 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
import uuid
|
|
3
|
+
import asyncio
|
|
4
|
+
import json
|
|
5
|
+
from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends, Request, Query
|
|
6
|
+
from fastapi.responses import JSONResponse
|
|
7
|
+
from loguru import logger
|
|
8
|
+
from typing import Dict, Any, Optional
|
|
9
|
+
from .websocket_manager import websocket_manager
|
|
10
|
+
|
|
11
|
+
router = APIRouter(prefix="/ws", tags=["websocket"])  # routes declared on this router are served under the /ws prefix
|
|
12
|
+
|
|
13
|
+
@router.websocket("/connect")
async def websocket_endpoint(websocket: WebSocket, client_id: Optional[str] = Query(None)):
    """WebSocket entry point; the caller may supply its own client ID.

    The connection is registered with the shared ``websocket_manager`` (which
    returns the effective client ID), then every text frame is decoded as JSON
    and handed to ``handle_client_message``. Frames that are not valid JSON
    are answered with an error payload and the loop keeps running. Any
    exception that escapes the loop unregisters the client.
    """
    client_id = await websocket_manager.connect(websocket, client_id)
    invalid_json_reply = json.dumps({"error": "Invalid JSON format"})

    try:
        while True:
            # Block until the client sends the next text frame.
            raw_frame = await websocket.receive_text()
            try:
                await handle_client_message(client_id, json.loads(raw_frame))
            except json.JSONDecodeError:
                logger.error(f"从客户端 {client_id} 收到无效JSON消息: {raw_frame}")
                await websocket_manager.send_personal_message(invalid_json_reply, client_id)
    except Exception as exc:
        # A plain disconnect is expected and not logged; anything else is.
        if not isinstance(exc, WebSocketDisconnect):
            logger.error(f"WebSocket连接处理异常: {exc}")
        websocket_manager.disconnect(client_id)
|
|
35
|
+
|
|
36
|
+
async def handle_client_message(client_id: str, message: Dict[str, Any]):
    """Handle one decoded JSON message received from a client.

    Supported messages are ``{"type": "subscribe", "task_id": "<uuid>"}`` and
    ``{"type": "unsubscribe", "task_id": "<uuid>"}``. The outcome is always
    reported back to the same client through ``websocket_manager`` as either
    ``{"success": <bool>}`` or ``{"error": "<reason>"}``.

    Args:
        client_id: ID of the client the message came from (and the reply goes to).
        message: Result of ``json.loads`` on the client frame. Expected to be a
            dict, but any JSON value is tolerated and answered with an error.
    """

    async def reply(payload: Dict[str, Any]) -> None:
        await websocket_manager.send_personal_message(json.dumps(payload), client_id)

    # json.loads may return a list, string, number, ... which has no .get();
    # reject those here instead of letting AttributeError drop the connection.
    if not isinstance(message, dict):
        await reply({"error": "Invalid message format"})
        return

    message_type = message.get("type")
    if message_type == "subscribe":
        action = websocket_manager.subscribe_to_task
    elif message_type == "unsubscribe":
        action = websocket_manager.unsubscribe_from_task
    else:
        await reply({"error": f"Unknown message type: {message_type}"})
        return

    raw_task_id = message.get("task_id")
    if not raw_task_id:
        await reply({"error": f"Missing task_id in {message_type} message"})
        return

    # uuid.UUID() raises AttributeError (not ValueError) for non-string input
    # such as a JSON number, so the type is checked before parsing.
    if not isinstance(raw_task_id, str):
        await reply({"error": "Invalid task_id format"})
        return
    try:
        task_id = uuid.UUID(raw_task_id)
    except ValueError:
        await reply({"error": "Invalid task_id format"})
        return

    success = await action(client_id, task_id)
    await reply({"success": success})
|
|
@@ -0,0 +1,425 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
import asyncio
|
|
3
|
+
import uuid
|
|
4
|
+
import json
|
|
5
|
+
from typing import Dict, List, Set, Any
|
|
6
|
+
from fastapi import WebSocket, WebSocketDisconnect
|
|
7
|
+
from loguru import logger
|
|
8
|
+
from .task_manager import TaskManager
|
|
9
|
+
|
|
10
|
+
class WebSocketManager:
|
|
11
|
+
def __init__(self):
|
|
12
|
+
# 存储活动连接: {client_id: websocket}
|
|
13
|
+
self.active_connections: Dict[str, WebSocket] = {}
|
|
14
|
+
# 存储任务与客户端的映射: {task_id: client_id}
|
|
15
|
+
self.task_client_mapping: Dict[uuid.UUID, str] = {}
|
|
16
|
+
# 存储客户端订阅的任务: {client_id: set(task_id)}
|
|
17
|
+
self.client_task_subscriptions: Dict[str, Set[uuid.UUID]] = {}
|
|
18
|
+
# 存储客户端历史记录,用于重连时恢复订阅: {client_id: last_active_time}
|
|
19
|
+
self.client_history: Dict[str, float] = {}
|
|
20
|
+
# 设置客户端记录过期时间(秒),默认0.5小时
|
|
21
|
+
self.client_history_expiry = 0.5 * 60 * 60
|
|
22
|
+
# 初始化任务管理器
|
|
23
|
+
self.task_manager = TaskManager()
|
|
24
|
+
# 启动定期清理任务
|
|
25
|
+
self._cleanup_task = None
|
|
26
|
+
self._start_cleanup_task()
|
|
27
|
+
|
|
28
|
+
async def connect(self, websocket: WebSocket, client_id: str = None) -> str:
|
|
29
|
+
"""接受WebSocket连接,可以使用指定客户端ID或生成新ID"""
|
|
30
|
+
await websocket.accept()
|
|
31
|
+
|
|
32
|
+
# 如果没有提供客户端ID,则生成一个新的
|
|
33
|
+
if client_id is None:
|
|
34
|
+
client_id = f"client_{uuid.uuid4().hex[:12]}"
|
|
35
|
+
|
|
36
|
+
# 如果客户端ID已存在,先断开旧连接
|
|
37
|
+
if client_id in self.active_connections:
|
|
38
|
+
logger.warning(f"客户端 {client_id} 已存在连接,断开旧连接")
|
|
39
|
+
await self.active_connections[client_id].close()
|
|
40
|
+
|
|
41
|
+
# 更新连接记录
|
|
42
|
+
self.active_connections[client_id] = websocket
|
|
43
|
+
self.client_history[client_id] = asyncio.get_event_loop().time()
|
|
44
|
+
logger.info(f"客户端 {client_id} 已连接到WebSocket")
|
|
45
|
+
|
|
46
|
+
# 确保清理任务已启动
|
|
47
|
+
self._start_cleanup_task()
|
|
48
|
+
|
|
49
|
+
# 发送连接确认消息,包含客户端ID和恢复的订阅信息
|
|
50
|
+
connection_message = {
|
|
51
|
+
"type": "connection established",
|
|
52
|
+
"client_id": client_id,
|
|
53
|
+
"timestamp": str(asyncio.get_event_loop().time()),
|
|
54
|
+
"restored_subscriptions": []
|
|
55
|
+
}
|
|
56
|
+
|
|
57
|
+
# 如果有历史订阅,恢复它们
|
|
58
|
+
if client_id in self.client_task_subscriptions and self.client_task_subscriptions[client_id]:
|
|
59
|
+
restored_tasks = []
|
|
60
|
+
for task_id in self.client_task_subscriptions[client_id]:
|
|
61
|
+
restored_tasks.append(str(task_id))
|
|
62
|
+
logger.info(f"恢复客户端 {client_id} 对任务 {task_id} 的订阅")
|
|
63
|
+
|
|
64
|
+
connection_message["restored_subscriptions"] = restored_tasks
|
|
65
|
+
|
|
66
|
+
await self.send_personal_message(json.dumps(connection_message), client_id)
|
|
67
|
+
return client_id
|
|
68
|
+
|
|
69
|
+
def disconnect(self, client_id: str):
|
|
70
|
+
"""断开WebSocket连接,但保留客户端的订阅信息"""
|
|
71
|
+
if client_id in self.active_connections:
|
|
72
|
+
# 删除连接记录,但保留订阅信息
|
|
73
|
+
del self.active_connections[client_id]
|
|
74
|
+
# 更新客户端的最后活动时间
|
|
75
|
+
self.client_history[client_id] = asyncio.get_event_loop().time()
|
|
76
|
+
logger.info(f"客户端 {client_id} 已断开WebSocket连接,保留订阅信息")
|
|
77
|
+
|
|
78
|
+
async def subscribe_to_task(self, client_id: str, task_id: uuid.UUID) -> bool:
|
|
79
|
+
"""客户端订阅任务"""
|
|
80
|
+
if client_id not in self.client_task_subscriptions:
|
|
81
|
+
self.client_task_subscriptions[client_id] = set()
|
|
82
|
+
|
|
83
|
+
# 添加任务到客户端的订阅列表
|
|
84
|
+
self.client_task_subscriptions[client_id].add(task_id)
|
|
85
|
+
logger.info(f"客户端 {client_id} 已订阅任务 {task_id}")
|
|
86
|
+
return True
|
|
87
|
+
|
|
88
|
+
async def unsubscribe_from_task(self, client_id: str, task_id: uuid.UUID) -> bool:
|
|
89
|
+
"""客户端取消订阅任务"""
|
|
90
|
+
if client_id not in self.client_task_subscriptions:
|
|
91
|
+
return False
|
|
92
|
+
|
|
93
|
+
# 从客户端的订阅列表中移除任务
|
|
94
|
+
if task_id in self.client_task_subscriptions[client_id]:
|
|
95
|
+
self.client_task_subscriptions[client_id].remove(task_id)
|
|
96
|
+
logger.info(f"客户端 {client_id} 已取消订阅任务 {task_id}")
|
|
97
|
+
return True
|
|
98
|
+
|
|
99
|
+
return False
|
|
100
|
+
|
|
101
|
+
def create_task_with_client(self, task_id: uuid.UUID, client_id: str, workflow_name: str = "Unknown", steps: int = 1) -> bool:
|
|
102
|
+
"""创建任务与客户端的映射,并添加到任务管理器"""
|
|
103
|
+
# 建立任务与客户端的映射
|
|
104
|
+
self.task_client_mapping[task_id] = client_id
|
|
105
|
+
|
|
106
|
+
# 自动将任务添加到客户端的订阅列表
|
|
107
|
+
if client_id not in self.client_task_subscriptions:
|
|
108
|
+
self.client_task_subscriptions[client_id] = set()
|
|
109
|
+
self.client_task_subscriptions[client_id].add(task_id)
|
|
110
|
+
|
|
111
|
+
# 添加任务到任务管理器
|
|
112
|
+
self.task_manager.add_task(task_id, client_id, workflow_name, steps)
|
|
113
|
+
|
|
114
|
+
logger.info(f"已为任务 {task_id} 与客户端 {client_id} 建立映射")
|
|
115
|
+
return True
|
|
116
|
+
|
|
117
|
+
async def send_personal_message(self, message: str, client_id: str):
|
|
118
|
+
"""向特定客户端发送消息"""
|
|
119
|
+
if client_id in self.active_connections:
|
|
120
|
+
try:
|
|
121
|
+
await self.active_connections[client_id].send_text(message)
|
|
122
|
+
return True
|
|
123
|
+
except Exception as e:
|
|
124
|
+
logger.error(f"向客户端 {client_id} 发送消息失败: {e}")
|
|
125
|
+
self.disconnect(client_id)
|
|
126
|
+
return False
|
|
127
|
+
return False
|
|
128
|
+
|
|
129
|
+
async def send_to_task_client(self, task_id: uuid.UUID, message: Dict[str, Any]):
|
|
130
|
+
"""向任务关联的客户端发送消息"""
|
|
131
|
+
if task_id not in self.task_client_mapping:
|
|
132
|
+
return # 没有关联的客户端
|
|
133
|
+
|
|
134
|
+
client_id = self.task_client_mapping[task_id]
|
|
135
|
+
|
|
136
|
+
# 确保task_id是字符串,避免JSON序列化问题
|
|
137
|
+
if "task_id" in message and isinstance(message["task_id"], uuid.UUID):
|
|
138
|
+
message["task_id"] = str(message["task_id"])
|
|
139
|
+
|
|
140
|
+
# 递归处理嵌套字典中的UUID
|
|
141
|
+
def convert_uuids(obj):
|
|
142
|
+
if isinstance(obj, dict):
|
|
143
|
+
return {k: convert_uuids(v) for k, v in obj.items()}
|
|
144
|
+
elif isinstance(obj, list):
|
|
145
|
+
return [convert_uuids(item) for item in obj]
|
|
146
|
+
elif isinstance(obj, uuid.UUID):
|
|
147
|
+
return str(obj)
|
|
148
|
+
else:
|
|
149
|
+
return obj
|
|
150
|
+
|
|
151
|
+
message = convert_uuids(message)
|
|
152
|
+
message_str = json.dumps(message)
|
|
153
|
+
await self.send_personal_message(message_str, client_id)
|
|
154
|
+
|
|
155
|
+
async def send_task_status(self, task_id: uuid.UUID, status: str, node: str = None, progress: float = None, error: str = None):
|
|
156
|
+
"""发送任务状态更新"""
|
|
157
|
+
message = {
|
|
158
|
+
"task_id": str(task_id),
|
|
159
|
+
"type": "status",
|
|
160
|
+
"status": status
|
|
161
|
+
}
|
|
162
|
+
|
|
163
|
+
if node is not None:
|
|
164
|
+
message["node"] = node
|
|
165
|
+
|
|
166
|
+
if progress is not None:
|
|
167
|
+
message["progress"] = progress
|
|
168
|
+
|
|
169
|
+
if error is not None:
|
|
170
|
+
message["error"] = error
|
|
171
|
+
|
|
172
|
+
await self.send_to_task_client(task_id, message)
|
|
173
|
+
|
|
174
|
+
async def send_execution_start(self, task_id: uuid.UUID):
|
|
175
|
+
"""发送任务开始执行消息"""
|
|
176
|
+
# 获取客户端ID
|
|
177
|
+
client_id = self.task_client_mapping.get(task_id)
|
|
178
|
+
|
|
179
|
+
# 更新任务状态为运行中
|
|
180
|
+
self.task_manager.start_task(task_id)
|
|
181
|
+
|
|
182
|
+
# 获取客户端的任务队列信息
|
|
183
|
+
client_tasks = []
|
|
184
|
+
if client_id:
|
|
185
|
+
client_tasks = self.task_manager.get_client_tasks(client_id)
|
|
186
|
+
|
|
187
|
+
# 获取全局任务队列信息
|
|
188
|
+
global_queue_info = self.task_manager.get_global_queue_info()
|
|
189
|
+
|
|
190
|
+
# 获取当前任务在队列中的位置
|
|
191
|
+
queue_position = self.task_manager.get_queue_position(task_id)
|
|
192
|
+
|
|
193
|
+
message = {
|
|
194
|
+
"task_id": str(task_id),
|
|
195
|
+
"type": "execution start",
|
|
196
|
+
"client_tasks": {
|
|
197
|
+
"total": len(client_tasks),
|
|
198
|
+
"tasks": client_tasks
|
|
199
|
+
},
|
|
200
|
+
"global_queue": global_queue_info,
|
|
201
|
+
"queue_position": queue_position
|
|
202
|
+
}
|
|
203
|
+
await self.send_to_task_client(task_id, message)
|
|
204
|
+
|
|
205
|
+
async def send_executing(self, task_id: uuid.UUID, node: str):
|
|
206
|
+
"""发送节点正在执行消息"""
|
|
207
|
+
# 获取客户端ID
|
|
208
|
+
client_id = self.task_client_mapping.get(task_id)
|
|
209
|
+
|
|
210
|
+
# 获取客户端的任务队列信息
|
|
211
|
+
client_tasks = []
|
|
212
|
+
if client_id:
|
|
213
|
+
client_tasks = self.task_manager.get_client_tasks(client_id)
|
|
214
|
+
|
|
215
|
+
# 获取全局任务队列信息
|
|
216
|
+
global_queue_info = self.task_manager.get_global_queue_info()
|
|
217
|
+
|
|
218
|
+
# 获取当前任务在队列中的位置
|
|
219
|
+
queue_position = self.task_manager.get_queue_position(task_id)
|
|
220
|
+
|
|
221
|
+
message = {
|
|
222
|
+
"task_id": str(task_id),
|
|
223
|
+
"type": "executing",
|
|
224
|
+
"node": node,
|
|
225
|
+
"client_tasks": {
|
|
226
|
+
"total": len(client_tasks),
|
|
227
|
+
"tasks": client_tasks
|
|
228
|
+
},
|
|
229
|
+
"global_queue": global_queue_info,
|
|
230
|
+
"queue_position": queue_position
|
|
231
|
+
}
|
|
232
|
+
await self.send_to_task_client(task_id, message)
|
|
233
|
+
|
|
234
|
+
async def send_progress(self, task_id: uuid.UUID, node: str, progress: float):
|
|
235
|
+
"""发送节点执行进度消息"""
|
|
236
|
+
# 获取客户端ID
|
|
237
|
+
client_id = self.task_client_mapping.get(task_id)
|
|
238
|
+
|
|
239
|
+
# 获取客户端的任务队列信息
|
|
240
|
+
client_tasks = []
|
|
241
|
+
if client_id:
|
|
242
|
+
client_tasks = self.task_manager.get_client_tasks(client_id)
|
|
243
|
+
|
|
244
|
+
# 获取全局任务队列信息
|
|
245
|
+
global_queue_info = self.task_manager.get_global_queue_info()
|
|
246
|
+
|
|
247
|
+
# 获取当前任务在队列中的位置
|
|
248
|
+
queue_position = self.task_manager.get_queue_position(task_id)
|
|
249
|
+
|
|
250
|
+
message = {
|
|
251
|
+
"task_id": str(task_id),
|
|
252
|
+
"type": "progress",
|
|
253
|
+
"node": node,
|
|
254
|
+
"progress": progress,
|
|
255
|
+
"client_tasks": {
|
|
256
|
+
"total": len(client_tasks),
|
|
257
|
+
"tasks": client_tasks
|
|
258
|
+
},
|
|
259
|
+
"global_queue": global_queue_info,
|
|
260
|
+
"queue_position": queue_position
|
|
261
|
+
}
|
|
262
|
+
await self.send_to_task_client(task_id, message)
|
|
263
|
+
|
|
264
|
+
async def send_executed(self, task_id: uuid.UUID, node: str, result: Any = None):
|
|
265
|
+
"""发送节点执行完成消息"""
|
|
266
|
+
# 获取客户端ID
|
|
267
|
+
client_id = self.task_client_mapping.get(task_id)
|
|
268
|
+
|
|
269
|
+
# 获取客户端的任务队列信息
|
|
270
|
+
client_tasks = []
|
|
271
|
+
if client_id:
|
|
272
|
+
client_tasks = self.task_manager.get_client_tasks(client_id)
|
|
273
|
+
|
|
274
|
+
# 获取全局任务队列信息
|
|
275
|
+
global_queue_info = self.task_manager.get_global_queue_info()
|
|
276
|
+
|
|
277
|
+
# 获取当前任务在队列中的位置
|
|
278
|
+
queue_position = self.task_manager.get_queue_position(task_id)
|
|
279
|
+
|
|
280
|
+
message = {
|
|
281
|
+
"task_id": str(task_id),
|
|
282
|
+
"type": "executed",
|
|
283
|
+
"node": node,
|
|
284
|
+
"client_tasks": {
|
|
285
|
+
"total": len(client_tasks),
|
|
286
|
+
"tasks": client_tasks
|
|
287
|
+
},
|
|
288
|
+
"global_queue": global_queue_info,
|
|
289
|
+
"queue_position": queue_position
|
|
290
|
+
}
|
|
291
|
+
|
|
292
|
+
if result is not None:
|
|
293
|
+
# 检查是否为协程对象
|
|
294
|
+
import asyncio
|
|
295
|
+
if asyncio.iscoroutine(result):
|
|
296
|
+
message["result"] = "<coroutine object>"
|
|
297
|
+
else:
|
|
298
|
+
# 尝试序列化结果,如果失败则转换为字符串
|
|
299
|
+
try:
|
|
300
|
+
message["result"] = result
|
|
301
|
+
except TypeError:
|
|
302
|
+
message["result"] = str(result)
|
|
303
|
+
|
|
304
|
+
await self.send_to_task_client(task_id, message)
|
|
305
|
+
|
|
306
|
+
async def send_execution_error(self, task_id: uuid.UUID, node: str, error: str):
|
|
307
|
+
"""发送执行错误消息"""
|
|
308
|
+
# 获取客户端ID
|
|
309
|
+
client_id = self.task_client_mapping.get(task_id)
|
|
310
|
+
|
|
311
|
+
# 更新任务状态为失败
|
|
312
|
+
self.task_manager.fail_task(task_id, error)
|
|
313
|
+
|
|
314
|
+
# 获取客户端的任务队列信息
|
|
315
|
+
client_tasks = []
|
|
316
|
+
if client_id:
|
|
317
|
+
client_tasks = self.task_manager.get_client_tasks(client_id)
|
|
318
|
+
|
|
319
|
+
# 获取全局任务队列信息
|
|
320
|
+
global_queue_info = self.task_manager.get_global_queue_info()
|
|
321
|
+
|
|
322
|
+
# 获取当前任务在队列中的位置
|
|
323
|
+
queue_position = self.task_manager.get_queue_position(task_id)
|
|
324
|
+
|
|
325
|
+
message = {
|
|
326
|
+
"task_id": str(task_id),
|
|
327
|
+
"type": "execution error",
|
|
328
|
+
"node": node,
|
|
329
|
+
"error": error,
|
|
330
|
+
"client_tasks": {
|
|
331
|
+
"total": len(client_tasks),
|
|
332
|
+
"tasks": client_tasks
|
|
333
|
+
},
|
|
334
|
+
"global_queue": global_queue_info,
|
|
335
|
+
"queue_position": queue_position
|
|
336
|
+
}
|
|
337
|
+
await self.send_to_task_client(task_id, message)
|
|
338
|
+
|
|
339
|
+
def _start_cleanup_task(self):
|
|
340
|
+
"""启动定期清理任务"""
|
|
341
|
+
# 检查是否有运行的事件循环
|
|
342
|
+
try:
|
|
343
|
+
loop = asyncio.get_running_loop()
|
|
344
|
+
except RuntimeError:
|
|
345
|
+
# 没有运行的事件循环,延迟启动清理任务
|
|
346
|
+
return
|
|
347
|
+
|
|
348
|
+
if self._cleanup_task is None:
|
|
349
|
+
self._cleanup_task = loop.create_task(self._cleanup_expired_clients())
|
|
350
|
+
|
|
351
|
+
async def _cleanup_expired_clients(self):
|
|
352
|
+
"""定期清理过期的客户端记录"""
|
|
353
|
+
while True:
|
|
354
|
+
try:
|
|
355
|
+
# 每小时执行一次清理
|
|
356
|
+
await asyncio.sleep(60 * 60)
|
|
357
|
+
current_time = asyncio.get_event_loop().time()
|
|
358
|
+
expired_clients = []
|
|
359
|
+
|
|
360
|
+
# 查找过期的客户端记录
|
|
361
|
+
for client_id, last_active_time in self.client_history.items():
|
|
362
|
+
# 如果客户端不在活动连接中且超过过期时间,则标记为过期
|
|
363
|
+
if (client_id not in self.active_connections and
|
|
364
|
+
current_time - last_active_time > self.client_history_expiry):
|
|
365
|
+
expired_clients.append(client_id)
|
|
366
|
+
|
|
367
|
+
# 清理过期客户端的订阅记录
|
|
368
|
+
for client_id in expired_clients:
|
|
369
|
+
# 移除客户端的订阅记录
|
|
370
|
+
if client_id in self.client_task_subscriptions:
|
|
371
|
+
del self.client_task_subscriptions[client_id]
|
|
372
|
+
|
|
373
|
+
# 移除客户端的历史记录
|
|
374
|
+
del self.client_history[client_id]
|
|
375
|
+
|
|
376
|
+
# 清理任务映射中的过期记录
|
|
377
|
+
tasks_to_remove = []
|
|
378
|
+
for task_id, mapped_client_id in self.task_client_mapping.items():
|
|
379
|
+
if mapped_client_id == client_id:
|
|
380
|
+
tasks_to_remove.append(task_id)
|
|
381
|
+
|
|
382
|
+
for task_id in tasks_to_remove:
|
|
383
|
+
del self.task_client_mapping[task_id]
|
|
384
|
+
|
|
385
|
+
logger.info(f"已清理过期客户端 {client_id} 的所有记录")
|
|
386
|
+
|
|
387
|
+
if expired_clients:
|
|
388
|
+
logger.info(f"清理了 {len(expired_clients)} 个过期客户端记录")
|
|
389
|
+
|
|
390
|
+
except Exception as e:
|
|
391
|
+
logger.error(f"清理过期客户端时出错: {e}")
|
|
392
|
+
# 出错后继续尝试清理,不中断任务
|
|
393
|
+
|
|
394
|
+
def set_client_history_expiry(self, seconds: int):
|
|
395
|
+
"""设置客户端记录过期时间"""
|
|
396
|
+
self.client_history_expiry = seconds
|
|
397
|
+
logger.info(f"客户端记录过期时间已设置为 {seconds} 秒")
|
|
398
|
+
|
|
399
|
+
async def send_node_output(self, task_id: uuid.UUID, node: str, port: str, value: Any):
|
|
400
|
+
"""发送节点输出结果消息"""
|
|
401
|
+
message = {
|
|
402
|
+
"task_id": str(task_id),
|
|
403
|
+
"type": "node output",
|
|
404
|
+
"node": node,
|
|
405
|
+
"port": port
|
|
406
|
+
}
|
|
407
|
+
|
|
408
|
+
# 尝试将值序列化为JSON
|
|
409
|
+
try:
|
|
410
|
+
import json
|
|
411
|
+
# 如果值是基本类型,直接使用
|
|
412
|
+
if isinstance(value, (str, int, float, bool)) or value is None:
|
|
413
|
+
message["value"] = value
|
|
414
|
+
else:
|
|
415
|
+
# 尝试JSON序列化复杂对象
|
|
416
|
+
json.dumps(value) # 测试是否可以序列化
|
|
417
|
+
message["value"] = value
|
|
418
|
+
except (TypeError, ValueError, json.JSONDecodeError):
|
|
419
|
+
# 如果序列化失败,转换为字符串
|
|
420
|
+
message["value"] = str(value)
|
|
421
|
+
|
|
422
|
+
await self.send_to_task_client(task_id, message)
|
|
423
|
+
|
|
424
|
+
# Global WebSocket manager instance
|
|
425
|
+
websocket_manager = WebSocketManager()  # built at import time; the cleanup task is only started once an event loop is running (see connect())
|
|
@@ -0,0 +1,152 @@
|
|
|
1
|
+
from fastapi import FastAPI
|
|
2
|
+
import uvicorn
|
|
3
|
+
from fastapi import APIRouter
|
|
4
|
+
from loguru import logger
|
|
5
|
+
from urllib.parse import urlparse
|
|
6
|
+
from fastapi import HTTPException, Request
|
|
7
|
+
from fastapi.middleware.cors import CORSMiddleware
|
|
8
|
+
from fastapi.openapi.utils import get_openapi
|
|
9
|
+
from service_forge.api.routers.websocket.websocket_router import websocket_router
|
|
10
|
+
from service_forge.api.routers.service.service_router import service_router
|
|
11
|
+
from service_forge.api.routers.feedback.feedback_router import router as feedback_router
|
|
12
|
+
from service_forge.sft.config.sf_metadata import load_metadata
|
|
13
|
+
from service_forge.sft.util.name_util import get_service_url_name
|
|
14
|
+
|
|
15
|
+
def is_trusted_origin(origin_host: str, host: str, trusted_root: str = "ring.shiweinan.com") -> bool:
    """
    Check if the origin host is trusted based on domain matching.

    The origin is trusted when it is the same host the request was sent to,
    or when it is a subdomain of ``trusted_root``. Comparison is
    case-insensitive and ignores a trailing dot.

    Args:
        origin_host: The hostname from the origin header
        host: The hostname from the host header
        trusted_root: The trusted root domain (can be customized)

    Returns:
        bool: True if the origin is trusted, False otherwise (including when
        ``origin_host`` is empty or None)
    """
    # urlparse(...).hostname is None for malformed headers; never trust that.
    if not origin_host:
        return False

    origin_host = origin_host.lower().rstrip(".")
    host = (host or "").lower().rstrip(".")
    trusted_root = (trusted_root or "").lower().strip(".")

    if origin_host == host:
        return True

    # Only the ORIGIN decides trust. The server's own host being under the
    # trusted root must not make an arbitrary foreign origin trusted.
    return bool(trusted_root) and origin_host.endswith("." + trusted_root)
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
def create_app(
    app: FastAPI | None = None,
    routers: list[APIRouter] | None = None,
    cors_origins: list[str] | None = None,
    enable_auth_middleware: bool = True,
    trusted_domain: str = "ring.shiweinan.com",
    root_path: str | None = None,
) -> FastAPI:
    """
    Create or configure a FastAPI app with common middleware and configuration.

    Args:
        app: Optional existing FastAPI instance. If None, creates a new one.
        routers: List of APIRouter instances to include
        cors_origins: List of allowed CORS origins. Defaults to ["*"]
        enable_auth_middleware: Whether to enable authentication middleware
        trusted_domain: Trusted domain for origin validation
        root_path: ASGI root path for a newly created app. Ignored when an
            existing ``app`` is passed in.

    Returns:
        FastAPI: Configured FastAPI application instance
    """
    # Local import: only the auth middleware below needs it.
    from fastapi.responses import JSONResponse

    if app is None:
        app = FastAPI(root_path=root_path)

    # Configure CORS middleware
    if cors_origins is None:
        cors_origins = ["*"]

    # NOTE(review): the FastAPI CORS docs say allow_origins must not be ["*"]
    # when allow_credentials is True - confirm whether the wildcard default
    # is intended for credentialed requests.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=cors_origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    # Caller-supplied routers first, then the built-in ones.
    for router in routers or []:
        app.include_router(router)

    # WebSocket, feedback and service routers are always included.
    app.include_router(websocket_router)
    app.include_router(feedback_router)
    app.include_router(service_router)

    if enable_auth_middleware:
        @app.middleware("http")
        async def auth_middleware(request: Request, call_next):
            """
            Authentication middleware for API routes.

            For paths starting with /api: same-origin requests get the default
            user id "0"; all other requests must carry an X-User-ID header,
            otherwise a 401 response is returned.
            """
            if request.url.path.startswith("/api"):
                origin = request.headers.get("origin") or request.headers.get("referer")
                scheme = request.url.scheme
                host = request.headers.get("host", "")
                is_same_origin = False

                logger.debug(f"origin {origin}, host:{host}")

                if origin and host:
                    try:
                        parsed_origin = urlparse(origin)
                        parsed_host = urlparse(f"{scheme}://{host}")
                        is_same_origin = (
                            parsed_origin.hostname == parsed_host.hostname
                            and parsed_origin.port == parsed_host.port
                            and is_trusted_origin(parsed_origin.hostname, parsed_host.hostname, trusted_domain)
                        )
                    except Exception as exc:
                        # Unparseable headers (e.g. a bad port) are treated as
                        # cross-origin, so the X-User-ID check still applies.
                        logger.debug(f"origin check failed, treating request as cross-origin: {exc}")

                if is_same_origin:
                    # Same-origin requests skip auth but still get a default user_id.
                    request.state.user_id = "0"
                else:
                    user_id = request.headers.get("X-User-ID")
                    if not user_id:
                        # Return the response directly: an HTTPException raised
                        # inside an http middleware is not seen by FastAPI's
                        # exception handlers and would surface as a 500.
                        return JSONResponse(status_code=401, content={"detail": "Unauthorized"})
                    request.state.user_id = user_id

            return await call_next(request)

    return app
|
|
132
|
+
|
|
133
|
+
async def start_fastapi_server(host: str, port: int):
    """Serve the module-level ``fastapi_app`` with uvicorn on ``host``:``port``.

    The coroutine returns when the uvicorn server stops. Any exception is
    logged and re-raised to the caller.
    """
    try:
        server_config = uvicorn.Config(
            fastapi_app,
            host=host,
            port=int(port),  # tolerate a port passed as a numeric string
            log_level="info",
            access_log=True,
        )
        await uvicorn.Server(server_config).serve()
    except Exception as err:
        logger.error(f"Server error: {err}")
        raise
|
|
147
|
+
|
|
148
|
+
# Build the module-level app. When sf-meta.yaml can be loaded, the app is
# served under /api/v1/<service url name>; otherwise it falls back to an app
# without a root path.
try:
    metadata = load_metadata("sf-meta.yaml")
    fastapi_app = create_app(
        enable_auth_middleware=False,
        root_path=f"/api/v1/{get_service_url_name(metadata.name, metadata.version)}",
    )
except Exception as e:
    # Best effort: keep the module importable, but say why the root path is missing.
    logger.warning(f"Could not build root_path from sf-meta.yaml, serving without a root path: {e}")
    fastapi_app = create_app(enable_auth_middleware=False, root_path=None)
|