service-forge 0.1.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of service-forge might be problematic. Click here for more details.

Files changed (83)
  1. service_forge/api/deprecated_websocket_api.py +86 -0
  2. service_forge/api/deprecated_websocket_manager.py +425 -0
  3. service_forge/api/http_api.py +152 -0
  4. service_forge/api/http_api_doc.py +455 -0
  5. service_forge/api/kafka_api.py +126 -0
  6. service_forge/api/routers/feedback/feedback_router.py +148 -0
  7. service_forge/api/routers/service/service_router.py +127 -0
  8. service_forge/api/routers/websocket/websocket_manager.py +83 -0
  9. service_forge/api/routers/websocket/websocket_router.py +78 -0
  10. service_forge/api/task_manager.py +141 -0
  11. service_forge/current_service.py +14 -0
  12. service_forge/db/__init__.py +1 -0
  13. service_forge/db/database.py +237 -0
  14. service_forge/db/migrations/feedback_migration.py +154 -0
  15. service_forge/db/models/__init__.py +0 -0
  16. service_forge/db/models/feedback.py +33 -0
  17. service_forge/llm/__init__.py +67 -0
  18. service_forge/llm/llm.py +56 -0
  19. service_forge/model/__init__.py +0 -0
  20. service_forge/model/feedback.py +30 -0
  21. service_forge/model/websocket.py +13 -0
  22. service_forge/proto/foo_input.py +5 -0
  23. service_forge/service.py +280 -0
  24. service_forge/service_config.py +44 -0
  25. service_forge/sft/cli.py +91 -0
  26. service_forge/sft/cmd/config_command.py +67 -0
  27. service_forge/sft/cmd/deploy_service.py +123 -0
  28. service_forge/sft/cmd/list_tars.py +41 -0
  29. service_forge/sft/cmd/service_command.py +149 -0
  30. service_forge/sft/cmd/upload_service.py +36 -0
  31. service_forge/sft/config/injector.py +129 -0
  32. service_forge/sft/config/injector_default_files.py +131 -0
  33. service_forge/sft/config/sf_metadata.py +30 -0
  34. service_forge/sft/config/sft_config.py +200 -0
  35. service_forge/sft/file/__init__.py +0 -0
  36. service_forge/sft/file/ignore_pattern.py +80 -0
  37. service_forge/sft/file/sft_file_manager.py +107 -0
  38. service_forge/sft/kubernetes/kubernetes_manager.py +257 -0
  39. service_forge/sft/util/assert_util.py +25 -0
  40. service_forge/sft/util/logger.py +16 -0
  41. service_forge/sft/util/name_util.py +8 -0
  42. service_forge/sft/util/yaml_utils.py +57 -0
  43. service_forge/storage/__init__.py +5 -0
  44. service_forge/storage/feedback_storage.py +245 -0
  45. service_forge/utils/__init__.py +0 -0
  46. service_forge/utils/default_type_converter.py +12 -0
  47. service_forge/utils/register.py +39 -0
  48. service_forge/utils/type_converter.py +99 -0
  49. service_forge/utils/workflow_clone.py +124 -0
  50. service_forge/workflow/__init__.py +1 -0
  51. service_forge/workflow/context.py +14 -0
  52. service_forge/workflow/edge.py +24 -0
  53. service_forge/workflow/node.py +184 -0
  54. service_forge/workflow/nodes/__init__.py +8 -0
  55. service_forge/workflow/nodes/control/if_node.py +29 -0
  56. service_forge/workflow/nodes/control/switch_node.py +28 -0
  57. service_forge/workflow/nodes/input/console_input_node.py +26 -0
  58. service_forge/workflow/nodes/llm/query_llm_node.py +41 -0
  59. service_forge/workflow/nodes/nested/workflow_node.py +28 -0
  60. service_forge/workflow/nodes/output/kafka_output_node.py +27 -0
  61. service_forge/workflow/nodes/output/print_node.py +29 -0
  62. service_forge/workflow/nodes/test/if_console_input_node.py +33 -0
  63. service_forge/workflow/nodes/test/time_consuming_node.py +62 -0
  64. service_forge/workflow/port.py +89 -0
  65. service_forge/workflow/trigger.py +28 -0
  66. service_forge/workflow/triggers/__init__.py +6 -0
  67. service_forge/workflow/triggers/a2a_api_trigger.py +257 -0
  68. service_forge/workflow/triggers/fast_api_trigger.py +201 -0
  69. service_forge/workflow/triggers/kafka_api_trigger.py +47 -0
  70. service_forge/workflow/triggers/once_trigger.py +23 -0
  71. service_forge/workflow/triggers/period_trigger.py +29 -0
  72. service_forge/workflow/triggers/websocket_api_trigger.py +189 -0
  73. service_forge/workflow/workflow.py +227 -0
  74. service_forge/workflow/workflow_callback.py +141 -0
  75. service_forge/workflow/workflow_config.py +66 -0
  76. service_forge/workflow/workflow_event.py +15 -0
  77. service_forge/workflow/workflow_factory.py +246 -0
  78. service_forge/workflow/workflow_group.py +51 -0
  79. service_forge/workflow/workflow_type.py +52 -0
  80. service_forge-0.1.18.dist-info/METADATA +98 -0
  81. service_forge-0.1.18.dist-info/RECORD +83 -0
  82. service_forge-0.1.18.dist-info/WHEEL +4 -0
  83. service_forge-0.1.18.dist-info/entry_points.txt +2 -0
@@ -0,0 +1,86 @@
1
+ from __future__ import annotations
2
+ import uuid
3
+ import asyncio
4
+ import json
5
+ from fastapi import APIRouter, WebSocket, WebSocketDisconnect, Depends, Request, Query
6
+ from fastapi.responses import JSONResponse
7
+ from loguru import logger
8
+ from typing import Dict, Any, Optional
9
+ from .websocket_manager import websocket_manager
10
+
11
+ router = APIRouter(prefix="/ws", tags=["websocket"])
12
+
13
+ @router.websocket("/connect")
14
+ async def websocket_endpoint(websocket: WebSocket, client_id: Optional[str] = Query(None)):
15
+ """WebSocket连接端点,支持指定客户端ID"""
16
+ client_id = await websocket_manager.connect(websocket, client_id)
17
+ try:
18
+ while True:
19
+ # 接收客户端消息
20
+ data = await websocket.receive_text()
21
+ try:
22
+ message = json.loads(data)
23
+ await handle_client_message(client_id, message)
24
+ except json.JSONDecodeError:
25
+ logger.error(f"从客户端 {client_id} 收到无效JSON消息: {data}")
26
+ await websocket_manager.send_personal_message(
27
+ json.dumps({"error": "Invalid JSON format"}),
28
+ client_id
29
+ )
30
+ except WebSocketDisconnect:
31
+ websocket_manager.disconnect(client_id)
32
+ except Exception as e:
33
+ logger.error(f"WebSocket连接处理异常: {e}")
34
+ websocket_manager.disconnect(client_id)
35
+
36
async def _send_error(client_id: str, error: str) -> None:
    """Send ``{"error": error}`` to *client_id*."""
    await websocket_manager.send_personal_message(json.dumps({"error": error}), client_id)


async def handle_client_message(client_id: str, message: Dict[str, Any]):
    """Handle one decoded JSON message received from *client_id*.

    Supported messages::

        {"type": "subscribe", "task_id": "<uuid>"}
        {"type": "unsubscribe", "task_id": "<uuid>"}

    The client receives ``{"success": <bool>}`` with the manager's result.
    It receives ``{"error": "..."}`` for any of these:

    * a payload that is not a JSON object
    * an unknown message type
    * a missing ``task_id``
    * a ``task_id`` that is not a valid UUID
    """
    if not isinstance(message, dict):
        # json.loads may return a list, string or number. Calling .get() on
        # those raises AttributeError, which would propagate to the connection
        # loop and end the client's connection.
        await _send_error(client_id, "Invalid message format: expected a JSON object")
        return

    message_type = message.get("type")
    if message_type == "subscribe":
        handler = websocket_manager.subscribe_to_task
    elif message_type == "unsubscribe":
        handler = websocket_manager.unsubscribe_from_task
    else:
        await _send_error(client_id, f"Unknown message type: {message_type}")
        return

    task_id_str = message.get("task_id")
    if not task_id_str:
        await _send_error(client_id, f"Missing task_id in {message_type} message")
        return

    try:
        # str() so a numeric or otherwise non-string task_id is reported as an
        # invalid format rather than raising AttributeError inside uuid.UUID.
        task_id = uuid.UUID(str(task_id_str))
    except ValueError:
        await _send_error(client_id, "Invalid task_id format")
        return

    success = await handler(client_id, task_id)
    await websocket_manager.send_personal_message(json.dumps({"success": success}), client_id)
@@ -0,0 +1,425 @@
1
+ from __future__ import annotations
2
+ import asyncio
3
+ import uuid
4
+ import json
5
+ from typing import Dict, List, Set, Any
6
+ from fastapi import WebSocket, WebSocketDisconnect
7
+ from loguru import logger
8
+ from .task_manager import TaskManager
9
+
10
class WebSocketManager:
    """Track WebSocket clients and push task progress events to them.

    A client's task subscriptions and task ownership survive a disconnect, so
    reconnecting with the same ``client_id`` keeps them. A periodic background
    task purges the records of clients that stay disconnected for longer than
    ``client_history_expiry`` seconds.
    """

    def __init__(self):
        # Live sockets: {client_id: websocket}
        self.active_connections: Dict[str, WebSocket] = {}
        # Owner of each task: {task_id: client_id}
        self.task_client_mapping: Dict[uuid.UUID, str] = {}
        # Tasks each client is subscribed to: {client_id: {task_id, ...}}
        self.client_task_subscriptions: Dict[str, Set[uuid.UUID]] = {}
        # Last connect/disconnect time per client (event-loop clock seconds),
        # used to expire the records of clients that never come back.
        self.client_history: Dict[str, float] = {}
        # Seconds a disconnected client's records are kept (default: half an hour).
        self.client_history_expiry = 0.5 * 60 * 60
        # Seconds between two cleanup passes (default: one hour).
        self.cleanup_interval = 60 * 60
        self.task_manager = TaskManager()
        # Periodic cleanup task. It is created lazily because there may be no
        # running event loop yet when the manager is instantiated.
        self._cleanup_task = None
        self._start_cleanup_task()

    async def connect(self, websocket: "WebSocket", client_id: str = None) -> str:
        """Accept *websocket* and register it under *client_id*.

        When *client_id* is None, a new id of the form ``client_<12 hex chars>``
        is generated. If the id already has a live socket, that socket is closed
        and replaced. A "connection established" message listing the task
        subscriptions still held for this id is sent to the client.

        Returns:
            The client id the connection was registered under.
        """
        await websocket.accept()

        if client_id is None:
            client_id = f"client_{uuid.uuid4().hex[:12]}"

        previous = self.active_connections.get(client_id)
        if previous is not None:
            logger.warning(f"客户端 {client_id} 已存在连接,断开旧连接")
            try:
                await previous.close()
            except Exception as e:
                # The old socket may already be closed or broken; that must not
                # prevent the new connection from being registered.
                logger.warning(f"Failed to close previous connection of client {client_id}: {e}")

        self.active_connections[client_id] = websocket
        self.client_history[client_id] = asyncio.get_event_loop().time()
        logger.info(f"客户端 {client_id} 已连接到WebSocket")

        # No event loop may have been running at construction time.
        self._start_cleanup_task()

        restored_tasks = []
        for task_id in self.client_task_subscriptions.get(client_id, ()):
            restored_tasks.append(str(task_id))
            logger.info(f"恢复客户端 {client_id} 对任务 {task_id} 的订阅")

        connection_message = {
            "type": "connection established",
            "client_id": client_id,
            "timestamp": str(asyncio.get_event_loop().time()),
            "restored_subscriptions": restored_tasks,
        }
        await self.send_personal_message(json.dumps(connection_message), client_id)
        return client_id

    def disconnect(self, client_id: str, websocket: "WebSocket | None" = None):
        """Forget the live socket of *client_id* while keeping its subscriptions.

        Args:
            client_id: The client to disconnect.
            websocket: If given, disconnect only when this is still the socket
                registered for *client_id*. A handler whose socket was already
                replaced by a reconnect can then not evict the newer connection.
        """
        current = self.active_connections.get(client_id)
        if current is None or (websocket is not None and current is not websocket):
            return
        del self.active_connections[client_id]
        # Start the expiry clock for this client's retained records.
        self.client_history[client_id] = asyncio.get_event_loop().time()
        logger.info(f"客户端 {client_id} 已断开WebSocket连接,保留订阅信息")

    async def subscribe_to_task(self, client_id: str, task_id: uuid.UUID) -> bool:
        """Subscribe *client_id* to *task_id*. Always returns True."""
        self.client_task_subscriptions.setdefault(client_id, set()).add(task_id)
        logger.info(f"客户端 {client_id} 已订阅任务 {task_id}")
        return True

    async def unsubscribe_from_task(self, client_id: str, task_id: uuid.UUID) -> bool:
        """Unsubscribe *client_id* from *task_id*.

        Returns:
            True if the subscription existed and was removed, False otherwise.
        """
        subscriptions = self.client_task_subscriptions.get(client_id)
        if not subscriptions or task_id not in subscriptions:
            return False
        subscriptions.remove(task_id)
        logger.info(f"客户端 {client_id} 已取消订阅任务 {task_id}")
        return True

    def create_task_with_client(self, task_id: uuid.UUID, client_id: str, workflow_name: str = "Unknown", steps: int = 1) -> bool:
        """Make *client_id* the owner and a subscriber of *task_id*.

        The task is also registered with the task manager. Always returns True.
        """
        self.task_client_mapping[task_id] = client_id
        self.client_task_subscriptions.setdefault(client_id, set()).add(task_id)
        self.task_manager.add_task(task_id, client_id, workflow_name, steps)
        logger.info(f"已为任务 {task_id} 与客户端 {client_id} 建立映射")
        return True

    async def send_personal_message(self, message: str, client_id: str):
        """Send the text *message* to *client_id* if it is connected.

        Returns:
            True if the message was sent. False if the client is not connected
            or sending failed; in the latter case the client is disconnected.
        """
        websocket = self.active_connections.get(client_id)
        if websocket is None:
            return False
        try:
            await websocket.send_text(message)
        except Exception as e:
            logger.error(f"向客户端 {client_id} 发送消息失败: {e}")
            # Drop only the socket that failed; a reconnect may have replaced
            # it while the send was in progress.
            self.disconnect(client_id, websocket)
            return False
        return True

    @classmethod
    def _convert_uuids(cls, obj: Any) -> Any:
        """Return a copy of *obj* with every UUID turned into its string form.

        The conversion recurses into dicts (keys and values), lists and tuples;
        tuples become lists, which matches how ``json.dumps`` renders them.
        """
        if isinstance(obj, dict):
            return {
                (str(key) if isinstance(key, uuid.UUID) else key): cls._convert_uuids(value)
                for key, value in obj.items()
            }
        if isinstance(obj, (list, tuple)):
            return [cls._convert_uuids(item) for item in obj]
        if isinstance(obj, uuid.UUID):
            return str(obj)
        return obj

    @classmethod
    def _json_safe(cls, value: Any) -> Any:
        """Return *value* if it can be JSON-encoded (after UUID conversion), else ``str(value)``."""
        if isinstance(value, (str, int, float, bool)) or value is None:
            return value
        try:
            json.dumps(cls._convert_uuids(value))
        except (TypeError, ValueError):
            return str(value)
        return value

    async def send_to_task_client(self, task_id: uuid.UUID, message: Dict[str, Any]):
        """Send *message* as JSON to the client that owns *task_id*, if there is one.

        UUIDs anywhere in *message* are sent as strings. *message* itself is not
        modified.
        """
        if task_id not in self.task_client_mapping:
            return
        client_id = self.task_client_mapping[task_id]

        payload = self._convert_uuids(message)
        try:
            message_str = json.dumps(payload)
        except TypeError as e:
            # A progress notification should not crash the workflow that emits
            # it; degrade unknown values to their string form instead.
            logger.warning(f"Message for task {task_id} is not JSON-serializable ({e}); sending str() of offending values")
            message_str = json.dumps(payload, default=str)
        await self.send_personal_message(message_str, client_id)

    def _queue_snapshot(self, task_id: uuid.UUID) -> Dict[str, Any]:
        """Build the queue fields that every per-node event message carries."""
        client_id = self.task_client_mapping.get(task_id)
        client_tasks = self.task_manager.get_client_tasks(client_id) if client_id else []
        return {
            "client_tasks": {
                "total": len(client_tasks),
                "tasks": client_tasks,
            },
            "global_queue": self.task_manager.get_global_queue_info(),
            "queue_position": self.task_manager.get_queue_position(task_id),
        }

    async def send_task_status(self, task_id: uuid.UUID, status: str, node: str = None, progress: float = None, error: str = None):
        """Send a "status" message. Optional fields are included only when not None."""
        message: Dict[str, Any] = {"task_id": str(task_id), "type": "status", "status": status}
        optional_fields = {"node": node, "progress": progress, "error": error}
        message.update({key: value for key, value in optional_fields.items() if value is not None})
        await self.send_to_task_client(task_id, message)

    async def send_execution_start(self, task_id: uuid.UUID):
        """Mark *task_id* as running and send an "execution start" message."""
        self.task_manager.start_task(task_id)
        message = {
            "task_id": str(task_id),
            "type": "execution start",
            **self._queue_snapshot(task_id),
        }
        await self.send_to_task_client(task_id, message)

    async def send_executing(self, task_id: uuid.UUID, node: str):
        """Send an "executing" message for *node*."""
        message = {
            "task_id": str(task_id),
            "type": "executing",
            "node": node,
            **self._queue_snapshot(task_id),
        }
        await self.send_to_task_client(task_id, message)

    async def send_progress(self, task_id: uuid.UUID, node: str, progress: float):
        """Send a "progress" message for *node*."""
        message = {
            "task_id": str(task_id),
            "type": "progress",
            "node": node,
            "progress": progress,
            **self._queue_snapshot(task_id),
        }
        await self.send_to_task_client(task_id, message)

    async def send_executed(self, task_id: uuid.UUID, node: str, result: Any = None):
        """Send an "executed" message for *node*, including *result* when not None."""
        message = {
            "task_id": str(task_id),
            "type": "executed",
            "node": node,
            **self._queue_snapshot(task_id),
        }
        if result is not None:
            if asyncio.iscoroutine(result):
                message["result"] = "<coroutine object>"
            else:
                # Serializability must be checked here; a bare assignment never
                # raises, so an unserializable result would fail later in
                # json.dumps and abort the whole message.
                message["result"] = self._json_safe(result)
        await self.send_to_task_client(task_id, message)

    async def send_execution_error(self, task_id: uuid.UUID, node: str, error: str):
        """Mark *task_id* as failed and send an "execution error" message."""
        self.task_manager.fail_task(task_id, error)
        message = {
            "task_id": str(task_id),
            "type": "execution error",
            "node": node,
            "error": error,
            **self._queue_snapshot(task_id),
        }
        await self.send_to_task_client(task_id, message)

    def _start_cleanup_task(self):
        """Start the periodic cleanup task if a loop is running and no task is active."""
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            # No running loop yet; connect() calls this again later.
            return
        # Also restart when a previous task has finished, e.g. because the event
        # loop it belonged to was closed.
        if self._cleanup_task is None or self._cleanup_task.done():
            self._cleanup_task = loop.create_task(self._cleanup_expired_clients())

    async def _cleanup_expired_clients(self):
        """Every ``cleanup_interval`` seconds, purge records of long-disconnected clients.

        A client is expired when it has no live socket and its last recorded
        activity is more than ``client_history_expiry`` seconds old. Its
        subscriptions, history entry and task ownership entries are removed.
        Errors are logged and the loop keeps running.
        """
        while True:
            try:
                await asyncio.sleep(self.cleanup_interval)
                now = asyncio.get_event_loop().time()
                expired_clients = [
                    client_id
                    for client_id, last_active_time in self.client_history.items()
                    if client_id not in self.active_connections
                    and now - last_active_time > self.client_history_expiry
                ]
                if not expired_clients:
                    continue

                expired = set(expired_clients)
                # One pass over the task map instead of one pass per expired client.
                stale_tasks = [task_id for task_id, owner in self.task_client_mapping.items() if owner in expired]
                for task_id in stale_tasks:
                    del self.task_client_mapping[task_id]

                for client_id in expired_clients:
                    self.client_task_subscriptions.pop(client_id, None)
                    del self.client_history[client_id]
                    logger.info(f"已清理过期客户端 {client_id} 的所有记录")

                logger.info(f"清理了 {len(expired_clients)} 个过期客户端记录")
            except Exception as e:
                # Keep the cleanup loop alive across unexpected errors.
                logger.error(f"清理过期客户端时出错: {e}")

    def set_client_history_expiry(self, seconds: int):
        """Set how many seconds a disconnected client's records are kept."""
        self.client_history_expiry = seconds
        logger.info(f"客户端记录过期时间已设置为 {seconds} 秒")

    async def send_node_output(self, task_id: uuid.UUID, node: str, port: str, value: Any):
        """Send a "node output" message; values that are not JSON-encodable are sent as ``str(value)``."""
        message = {
            "task_id": str(task_id),
            "type": "node output",
            "node": node,
            "port": port,
            "value": self._json_safe(value),
        }
        await self.send_to_task_client(task_id, message)
423
+
424
# Global WebSocket manager instance, created once at module import and shared by
# every importer of this module. If no event loop is running at import time,
# WebSocketManager.__init__ cannot start the periodic cleanup task; it is then
# started by the first connect().
websocket_manager = WebSocketManager()
@@ -0,0 +1,152 @@
1
+ from fastapi import FastAPI
2
+ import uvicorn
3
+ from fastapi import APIRouter
4
+ from loguru import logger
5
+ from urllib.parse import urlparse
6
+ from fastapi import HTTPException, Request
7
+ from fastapi.middleware.cors import CORSMiddleware
8
+ from fastapi.openapi.utils import get_openapi
9
+ from service_forge.api.routers.websocket.websocket_router import websocket_router
10
+ from service_forge.api.routers.service.service_router import service_router
11
+ from service_forge.api.routers.feedback.feedback_router import router as feedback_router
12
+ from service_forge.sft.config.sf_metadata import load_metadata
13
+ from service_forge.sft.util.name_util import get_service_url_name
14
+
15
+ def is_trusted_origin(origin_host: str, host: str, trusted_root: str = "ring.shiweinan.com") -> bool:
16
+ """
17
+ Check if the origin host is trusted based on domain matching.
18
+
19
+ Args:
20
+ origin_host: The hostname from the origin header
21
+ host: The hostname from the host header
22
+ trusted_root: The trusted root domain (can be customized)
23
+
24
+ Returns:
25
+ bool: True if the origin is trusted, False otherwise
26
+ """
27
+ # Convert to lowercase to avoid case sensitivity issues
28
+ origin_host = origin_host.lower()
29
+ host = host.lower()
30
+
31
+ # Allow same domain, or subdomains under the same trusted root
32
+ return (
33
+ origin_host == host or
34
+ origin_host.endswith("." + trusted_root) or
35
+ host.endswith("." + trusted_root)
36
+ )
37
+
38
+
39
def _is_same_origin_request(request: Request, trusted_domain: str) -> bool:
    """Return True when the request's Origin (or Referer) matches its Host and is trusted."""
    origin = request.headers.get("origin") or request.headers.get("referer")
    host = request.headers.get("host", "")
    logger.debug(f"origin {origin}, host:{host}")

    if not (origin and host):
        return False
    try:
        parsed_origin = urlparse(origin)
        parsed_host = urlparse(f"{request.url.scheme}://{host}")
        return bool(
            parsed_origin.hostname == parsed_host.hostname
            and parsed_origin.port == parsed_host.port
            and is_trusted_origin(parsed_origin.hostname, parsed_host.hostname, trusted_domain)
        )
    except Exception as e:
        # Unparseable headers (e.g. a non-numeric port) are treated as
        # cross-origin, which makes the caller require an X-User-ID header.
        logger.debug(f"could not compare origin {origin!r} with host {host!r}: {e}")
        return False


def create_app(
    app: FastAPI | None = None,
    routers: list[APIRouter] | None = None,
    cors_origins: list[str] | None = None,
    enable_auth_middleware: bool = True,
    trusted_domain: str = "ring.shiweinan.com",
    root_path: str | None = None,
) -> FastAPI:
    """
    Create or configure a FastAPI app with common middleware and configuration.

    The WebSocket, feedback and service routers are always included, after any
    routers passed in *routers*.

    Args:
        app: Optional existing FastAPI instance. If None, creates a new one.
        routers: List of APIRouter instances to include.
        cors_origins: List of allowed CORS origins. Defaults to ["*"].
        enable_auth_middleware: Whether to enable the authentication middleware.
            For paths starting with "/api", it requires a non-empty X-User-ID
            header unless the request is same-origin and trusted. The header
            value is stored in ``request.state.user_id`` and is not verified
            here.
        trusted_domain: Trusted root domain for origin validation.
        root_path: ASGI root path for a newly created app; ignored when *app*
            is given.

    Returns:
        FastAPI: Configured FastAPI application instance.
    """
    # Local import: only the auth middleware needs it.
    from fastapi.responses import JSONResponse

    if app is None:
        app = FastAPI(root_path=root_path)

    if cors_origins is None:
        cors_origins = ["*"]

    # NOTE(review): with allow_origins=["*"] and allow_credentials=True,
    # Starlette's CORSMiddleware echoes the caller's Origin for credentialed
    # requests, so any site can make credentialed calls. Consider passing
    # explicit origins in production.
    app.add_middleware(
        CORSMiddleware,
        allow_origins=cors_origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    for router in routers or []:
        app.include_router(router)

    app.include_router(websocket_router)
    app.include_router(feedback_router)
    app.include_router(service_router)

    if enable_auth_middleware:
        @app.middleware("http")
        async def auth_middleware(request: Request, call_next):
            """
            Authentication middleware for API routes.

            For /api paths, trusted same-origin requests get user_id "0".
            All other requests must carry a non-empty X-User-ID header.
            """
            if request.url.path.startswith("/api"):
                if _is_same_origin_request(request, trusted_domain):
                    # Same-origin requests can skip auth, but still get a default user_id.
                    request.state.user_id = "0"
                else:
                    user_id = request.headers.get("X-User-ID")
                    if not user_id:
                        # An HTTPException raised inside an @app.middleware
                        # function bypasses FastAPI's exception handlers and
                        # surfaces as a 500, so build the 401 response directly.
                        return JSONResponse(status_code=401, content={"detail": "Unauthorized"})
                    request.state.user_id = user_id

            return await call_next(request)

    return app
132
+
133
async def start_fastapi_server(host: str, port: int, app: FastAPI | None = None):
    """Serve a FastAPI app with uvicorn until the server shuts down.

    Args:
        host: Interface to bind.
        port: Port to bind; converted with ``int()`` so numeric strings are accepted.
        app: Application to serve. Defaults to the module-level ``fastapi_app``.

    Raises:
        Exception: Any error from uvicorn is logged and re-raised.
    """
    try:
        target_app = fastapi_app if app is None else app
        config = uvicorn.Config(
            target_app,
            host=host,
            port=int(port),
            log_level="info",
            access_log=True,
        )
        server = uvicorn.Server(config)
        await server.serve()
    except Exception as e:
        logger.error(f"Server error: {e}")
        raise
147
+
148
# Module-level application. When sf-meta.yaml can be loaded, the app's root_path
# is "/api/v1/<service url name>", built from the metadata's name and version.
# Otherwise the app is created without a root_path.
try:
    metadata = load_metadata("sf-meta.yaml")
    _root_path = f"/api/v1/{get_service_url_name(metadata.name, metadata.version)}"
except Exception as e:
    # Best-effort: a missing or invalid metadata file must not prevent startup,
    # but the reason should be visible.
    logger.warning(f"Could not load service metadata from sf-meta.yaml ({e}); creating app without root_path")
    _root_path = None

fastapi_app = create_app(enable_auth_middleware=False, root_path=_root_path)