service-forge 0.1.11__py3-none-any.whl → 0.1.24__py3-none-any.whl
This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in the public registry.
Potentially problematic release. This version of service-forge might be problematic.
- service_forge/api/http_api.py +4 -0
- service_forge/api/routers/feedback/feedback_router.py +148 -0
- service_forge/api/routers/service/service_router.py +22 -32
- service_forge/current_service.py +14 -0
- service_forge/db/database.py +46 -32
- service_forge/db/migrations/feedback_migration.py +154 -0
- service_forge/db/models/__init__.py +0 -0
- service_forge/db/models/feedback.py +33 -0
- service_forge/llm/__init__.py +5 -0
- service_forge/model/feedback.py +30 -0
- service_forge/service.py +118 -126
- service_forge/service_config.py +42 -156
- service_forge/sft/cli.py +39 -0
- service_forge/sft/cmd/remote_deploy.py +160 -0
- service_forge/sft/cmd/remote_list_tars.py +111 -0
- service_forge/sft/config/injector.py +46 -24
- service_forge/sft/config/injector_default_files.py +1 -1
- service_forge/sft/config/sft_config.py +55 -8
- service_forge/storage/__init__.py +5 -0
- service_forge/storage/feedback_storage.py +245 -0
- service_forge/utils/default_type_converter.py +1 -1
- service_forge/utils/type_converter.py +5 -0
- service_forge/utils/workflow_clone.py +3 -2
- service_forge/workflow/node.py +8 -0
- service_forge/workflow/nodes/llm/query_llm_node.py +1 -1
- service_forge/workflow/trigger.py +4 -0
- service_forge/workflow/triggers/a2a_api_trigger.py +2 -0
- service_forge/workflow/triggers/fast_api_trigger.py +32 -0
- service_forge/workflow/triggers/kafka_api_trigger.py +3 -0
- service_forge/workflow/triggers/once_trigger.py +4 -1
- service_forge/workflow/triggers/period_trigger.py +4 -1
- service_forge/workflow/triggers/websocket_api_trigger.py +15 -11
- service_forge/workflow/workflow.py +74 -31
- service_forge/workflow/workflow_callback.py +3 -2
- service_forge/workflow/workflow_config.py +66 -0
- service_forge/workflow/workflow_factory.py +86 -85
- service_forge/workflow/workflow_group.py +33 -9
- {service_forge-0.1.11.dist-info → service_forge-0.1.24.dist-info}/METADATA +1 -1
- {service_forge-0.1.11.dist-info → service_forge-0.1.24.dist-info}/RECORD +41 -31
- service_forge/api/routers/service/__init__.py +0 -4
- {service_forge-0.1.11.dist-info → service_forge-0.1.24.dist-info}/WHEEL +0 -0
- {service_forge-0.1.11.dist-info → service_forge-0.1.24.dist-info}/entry_points.txt +0 -0
service_forge/storage/feedback_storage.py
ADDED
@@ -0,0 +1,245 @@
+from typing import Optional, Any
+from datetime import datetime
+from uuid import uuid4
+from loguru import logger
+from sqlalchemy import select, and_
+from sqlalchemy.exc import SQLAlchemyError
+
+from ..db.models.feedback import FeedbackBase
+from ..db.database import DatabaseManager
+
+
+class FeedbackStorage:
+    """Feedback storage manager, backed by a PostgreSQL database."""
+
+    def __init__(self, database_manager: DatabaseManager = None):
+        self.database_manager = database_manager
+        self._db = None
+        # In-memory fallback (used when no database is configured)
+        self._storage: dict[str, dict[str, Any]] = {}
+        self._task_index: dict[str, list[str]] = {}
+        self._workflow_index: dict[str, list[str]] = {}
+
+    @property
+    def db(self):
+        """Lazily acquire the database connection."""
+        if self._db is None and self.database_manager is not None:
+            self._db = self.database_manager.get_default_postgres_database()
+        return self._db
+
+    async def create_feedback(
+        self,
+        task_id: str,
+        workflow_name: str,
+        rating: Optional[int] = None,
+        comment: Optional[str] = None,
+        metadata: Optional[dict[str, Any]] = None,
+    ) -> dict[str, Any]:
+        """Create a new feedback record."""
+        if self.db is None:
+            logger.warning("Database not initialized; using in-memory storage (data will be lost on restart)")
+            return self._create_feedback_in_memory(task_id, workflow_name, rating, comment, metadata)
+
+        try:
+            feedback_id = uuid4()
+            created_at = datetime.now()
+
+            feedback = FeedbackBase(
+                feedback_id=feedback_id,
+                task_id=task_id,
+                workflow_name=workflow_name,
+                rating=rating,
+                comment=comment,
+                extra_metadata=metadata or {},
+                created_at=created_at,
+            )
+
+            session_factory = await self.db.get_session_factory()
+            async with session_factory() as session:
+                session.add(feedback)
+                await session.commit()
+                await session.refresh(feedback)
+
+            logger.info(f"Created feedback: feedback_id={feedback_id}, task_id={task_id}, workflow={workflow_name}")
+            return feedback.to_dict()
+
+        except SQLAlchemyError as e:
+            logger.error(f"Database error while creating feedback: {e}")
+            raise
+        except Exception as e:
+            logger.error(f"Failed to create feedback: {e}")
+            raise
+
+    async def get_feedback(self, feedback_id: str) -> Optional[dict[str, Any]]:
+        """Fetch a feedback record by feedback ID."""
+        if self.db is None:
+            return self._get_feedback_from_memory(feedback_id)
+
+        try:
+            session_factory = await self.db.get_session_factory()
+            async with session_factory() as session:
+                result = await session.execute(
+                    select(FeedbackBase).where(FeedbackBase.feedback_id == feedback_id)
+                )
+                feedback = result.scalar_one_or_none()
+                return feedback.to_dict() if feedback else None
+
+        except SQLAlchemyError as e:
+            logger.error(f"Failed to query feedback: {e}")
+            return None
+
+    async def get_feedbacks_by_task(self, task_id: str) -> list[dict[str, Any]]:
+        """Fetch all feedback for a given task ID."""
+        if self.db is None:
+            return self._get_feedbacks_by_task_from_memory(task_id)
+
+        try:
+            session_factory = await self.db.get_session_factory()
+            async with session_factory() as session:
+                result = await session.execute(
+                    select(FeedbackBase)
+                    .where(FeedbackBase.task_id == task_id)
+                    .order_by(FeedbackBase.created_at.desc())
+                )
+                feedbacks = result.scalars().all()
+                return [f.to_dict() for f in feedbacks]
+
+        except SQLAlchemyError as e:
+            logger.error(f"Failed to query feedback by task: {e}")
+            return []
+
+    async def get_feedbacks_by_workflow(self, workflow_name: str) -> list[dict[str, Any]]:
+        """Fetch all feedback for a given workflow name."""
+        if self.db is None:
+            return self._get_feedbacks_by_workflow_from_memory(workflow_name)
+
+        try:
+            session_factory = await self.db.get_session_factory()
+            async with session_factory() as session:
+                result = await session.execute(
+                    select(FeedbackBase)
+                    .where(FeedbackBase.workflow_name == workflow_name)
+                    .order_by(FeedbackBase.created_at.desc())
+                )
+                feedbacks = result.scalars().all()
+                return [f.to_dict() for f in feedbacks]
+
+        except SQLAlchemyError as e:
+            logger.error(f"Failed to query feedback by workflow: {e}")
+            return []
+
+    async def get_all_feedbacks(self) -> list[dict[str, Any]]:
+        """Fetch all feedback records."""
+        if self.db is None:
+            return self._get_all_feedbacks_from_memory()
+
+        try:
+            session_factory = await self.db.get_session_factory()
+            async with session_factory() as session:
+                result = await session.execute(
+                    select(FeedbackBase).order_by(FeedbackBase.created_at.desc())
+                )
+                feedbacks = result.scalars().all()
+                return [f.to_dict() for f in feedbacks]
+
+        except SQLAlchemyError as e:
+            logger.error(f"Failed to query all feedback: {e}")
+            return []
+
+    async def delete_feedback(self, feedback_id: str) -> bool:
+        """Delete a feedback record."""
+        if self.db is None:
+            return self._delete_feedback_from_memory(feedback_id)
+
+        try:
+            session_factory = await self.db.get_session_factory()
+            async with session_factory() as session:
+                result = await session.execute(
+                    select(FeedbackBase).where(FeedbackBase.feedback_id == feedback_id)
+                )
+                feedback = result.scalar_one_or_none()
+
+                if not feedback:
+                    return False
+
+                await session.delete(feedback)
+                await session.commit()
+
+            logger.info(f"Deleted feedback: feedback_id={feedback_id}")
+            return True
+
+        except SQLAlchemyError as e:
+            logger.error(f"Failed to delete feedback: {e}")
+            return False
+
+    # ========== In-memory fallback methods (used when no database is configured) ==========
+
+    def _create_feedback_in_memory(
+        self,
+        task_id: str,
+        workflow_name: str,
+        rating: Optional[int],
+        comment: Optional[str],
+        metadata: Optional[dict[str, Any]],
+    ) -> dict[str, Any]:
+        """In-memory variant of create_feedback."""
+        feedback_id = str(uuid4())
+        created_at = datetime.now()
+
+        feedback_data = {
+            "feedback_id": feedback_id,
+            "task_id": task_id,
+            "workflow_name": workflow_name,
+            "rating": rating,
+            "comment": comment,
+            "metadata": metadata or {},
+            "created_at": created_at,
+        }
+
+        self._storage[feedback_id] = feedback_data
+
+        if task_id not in self._task_index:
+            self._task_index[task_id] = []
+        self._task_index[task_id].append(feedback_id)
+
+        if workflow_name not in self._workflow_index:
+            self._workflow_index[workflow_name] = []
+        self._workflow_index[workflow_name].append(feedback_id)
+
+        logger.info(f"[in-memory] Created feedback: feedback_id={feedback_id}")
+        return feedback_data
+
+    def _get_feedback_from_memory(self, feedback_id: str) -> Optional[dict[str, Any]]:
+        return self._storage.get(feedback_id)
+
+    def _get_feedbacks_by_task_from_memory(self, task_id: str) -> list[dict[str, Any]]:
+        feedback_ids = self._task_index.get(task_id, [])
+        return [self._storage[fid] for fid in feedback_ids if fid in self._storage]
+
+    def _get_feedbacks_by_workflow_from_memory(self, workflow_name: str) -> list[dict[str, Any]]:
+        feedback_ids = self._workflow_index.get(workflow_name, [])
+        return [self._storage[fid] for fid in feedback_ids if fid in self._storage]
+
+    def _get_all_feedbacks_from_memory(self) -> list[dict[str, Any]]:
+        return list(self._storage.values())
+
+    def _delete_feedback_from_memory(self, feedback_id: str) -> bool:
+        if feedback_id not in self._storage:
+            return False
+
+        feedback = self._storage[feedback_id]
+        task_id = feedback["task_id"]
+        workflow_name = feedback["workflow_name"]
+
+        if task_id in self._task_index:
+            self._task_index[task_id].remove(feedback_id)
+        if workflow_name in self._workflow_index:
+            self._workflow_index[workflow_name].remove(feedback_id)
+
+        del self._storage[feedback_id]
+        logger.info(f"[in-memory] Deleted feedback: feedback_id={feedback_id}")
+        return True
+
+
+# Global singleton instance
+feedback_storage = FeedbackStorage()
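For orientation, here is a minimal usage sketch of the new class. It is hypothetical: it assumes the module-level `feedback_storage` singleton shown above is importable from `service_forge.storage.feedback_storage`, and since that singleton is constructed without a `DatabaseManager`, the calls exercise the in-memory fallback path:

import asyncio

from service_forge.storage.feedback_storage import feedback_storage  # assumed import path

async def main():
    # The singleton has no DatabaseManager, so this uses the in-memory
    # fallback and the data is lost on restart.
    fb = await feedback_storage.create_feedback(
        task_id="task-123",
        workflow_name="demo_workflow",
        rating=5,
        comment="worked as expected",
    )
    print(fb["feedback_id"])

    # Records come back as plain dicts (newest first on the database path).
    print(await feedback_storage.get_feedbacks_by_task("task-123"))

asyncio.run(main())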
service_forge/utils/default_type_converter.py
CHANGED
@@ -2,8 +2,8 @@ from ..utils.type_converter import TypeConverter
 from ..workflow.workflow import Workflow
 from ..api.http_api import fastapi_app
 from ..api.kafka_api import KafkaApp, kafka_app
-from fastapi import FastAPI
 from ..workflow.workflow_type import WorkflowType, workflow_type_register
+from fastapi import FastAPI
 
 type_converter = TypeConverter()
 type_converter.register(str, Workflow, lambda s, node: node.sub_workflows.get_workflow(s))
service_forge/utils/type_converter.py
CHANGED
@@ -1,6 +1,8 @@
 from typing import Any, Callable, Type, Dict, Tuple, Set, List
 from collections import deque
 import inspect
+import traceback
+from pydantic import BaseModel
 from typing_extensions import get_origin, get_args
 
 def is_type(value, dst_type):
@@ -57,6 +59,9 @@ class TypeConverter:
         except Exception:
             pass
 
+        if issubclass(dst_type, BaseModel) and isinstance(value, dict):
+            return dst_type(**value)
+
         path = self._find_path(src_type, dst_type)
         if not path:
             raise TypeError(f"No conversion path found from {src_type.__name__} to {dst_type.__name__}.")
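The new branch short-circuits dict-to-Pydantic conversions before falling back to the converter's path search. A standalone sketch of the same pattern (plain Pydantic, no service_forge imports); note that `issubclass` raises `TypeError` when `dst_type` is not a class (e.g. `list[int]`), so the `isinstance(dst_type, type)` guard added here may be a useful hardening over the patch as written:

from pydantic import BaseModel

class Point(BaseModel):  # hypothetical target model
    x: int
    y: int

def convert(value, dst_type):
    # Mirrors the new fallback: a dict destined for a BaseModel subclass
    # is constructed directly instead of walking the conversion graph.
    if isinstance(dst_type, type) and issubclass(dst_type, BaseModel) and isinstance(value, dict):
        return dst_type(**value)
    raise TypeError(f"No conversion path found from {type(value).__name__} to {dst_type!r}.")

print(convert({"x": 1, "y": 2}, Point))  # x=1 y=2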
service_forge/utils/workflow_clone.py
CHANGED
@@ -57,8 +57,8 @@ def workflow_clone(self: Workflow, task_id: uuid.UUID, trigger_node: Trigger) ->
     }
 
     workflow = Workflow(
-
-
+        id=self.id,
+        config=self.config,
         nodes=[node_map[node] for node in self.nodes],
         input_ports=[port_map[port] for port in self.input_ports],
         output_ports=[port_map[port] for port in self.output_ports],
@@ -69,6 +69,7 @@ def workflow_clone(self: Workflow, task_id: uuid.UUID, trigger_node: Trigger) ->
         callbacks=self.callbacks,
         task_id=task_id,
        real_trigger_node=trigger_node,
+        global_context=self.global_context,
     )
 
     for node in workflow.nodes:
service_forge/workflow/node.py
CHANGED
@@ -10,6 +10,7 @@ from .context import Context
 from ..utils.register import Register
 from ..db.database import DatabaseManager, PostgresDatabase, MongoDatabase, RedisDatabase
 from ..utils.workflow_clone import node_clone
+from .workflow_callback import CallbackEvent
 
 if TYPE_CHECKING:
     from .workflow import Workflow
@@ -62,6 +63,10 @@ class Node(ABC):
     def database_manager(self) -> DatabaseManager:
         return self.workflow.database_manager
 
+    @property
+    def global_context(self) -> Context:
+        return self.workflow.global_context
+
     def backup(self) -> None:
         # do NOT use deepcopy here
         # self.bak_context = deepcopy(self.context)
@@ -181,4 +186,7 @@ class Node(ABC):
     def _clone(self, context: Context) -> Node:
         return node_clone(self, context)
 
+    async def stream_output(self, data: Any) -> None:
+        await self.workflow.call_callbacks(CallbackEvent.ON_NODE_STREAM_OUTPUT, node=self, output=data)
+
 node_register = Register[Node]()
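The new `stream_output` helper gives any node a way to push incremental output through the workflow's callback chain (`ON_NODE_STREAM_OUTPUT`). A hypothetical subclass sketch, assuming the usual `Node` contract from this module:

from typing import Any

from service_forge.workflow.node import Node  # path per the diff above

class StreamingEchoNode(Node):
    """Hypothetical node that streams tokens before returning the full text."""

    async def _run(self, text: str) -> str:
        for token in text.split():
            # Each chunk fans out to registered callbacks, e.g. a
            # websocket trigger's stream queue.
            await self.stream_output(token)
        return text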
service_forge/workflow/nodes/llm/query_llm_node.py
CHANGED
@@ -36,6 +36,6 @@ class QueryLLMNode(Node):
             prompt = f.read()
 
         print(f"prompt: {prompt} temperature: {temperature}")
-        response = chat_stream(prompt, system_prompt, Model.
+        response = chat_stream(prompt, system_prompt, Model.GEMINI, temperature)
         for chunk in response:
             yield chunk
service_forge/workflow/trigger.py
CHANGED
@@ -19,6 +19,10 @@ class Trigger(Node, ABC):
     async def _run(self) -> AsyncIterator[bool]:
         ...
 
+    @abstractmethod
+    async def _stop(self) -> AsyncIterator[bool]:
+        ...
+
     def trigger(self, task_id: uuid.UUID) -> bool:
         self.prepare_output_edges(self.get_output_port_by_name('trigger'), True)
         return task_id
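Because `_stop` is now abstract, every trigger subclass must provide it, even as a no-op (as `PeriodTrigger` and `WebSocketAPITrigger` do below, while `FastAPITrigger` uses it to tear down its route). A minimal hypothetical trigger:

import uuid
from typing import AsyncIterator

from service_forge.workflow.trigger import Trigger  # path per the diff above

class OnceOnlyTrigger(Trigger):
    """Hypothetical trigger that fires a single task and then goes idle."""

    async def _run(self) -> AsyncIterator[bool]:
        yield uuid.uuid4()  # one task id, then the iterator is exhausted

    async def _stop(self) -> AsyncIterator[bool]:
        pass  # nothing to tear down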
service_forge/workflow/triggers/fast_api_trigger.py
CHANGED
@@ -33,6 +33,9 @@ class FastAPITrigger(Trigger):
         super().__init__(name)
         self.events = {}
         self.is_setup_route = False
+        self.app = None
+        self.route_path = None
+        self.route_method = None
 
     @staticmethod
     def serialize_result(result: Any):
@@ -141,6 +144,11 @@ class FastAPITrigger(Trigger):
         async def handler(request: Request):
             return await self.handle_request(request, data_type, extractor, is_stream)
 
+        # Save route information for cleanup
+        self.app = app
+        self.route_path = path
+        self.route_method = method.upper()
+
         if method == "GET":
             app.get(path)(handler)
         elif method == "POST":
@@ -167,3 +175,27 @@ class FastAPITrigger(Trigger):
                 logger.error(f"Error in FastAPITrigger._run: {e}")
                 continue
 
+    async def _stop(self) -> AsyncIterator[bool]:
+        if self.is_setup_route:
+            # Remove the route from the app
+            if self.app and self.route_path and self.route_method:
+                # Find and remove matching route
+                routes_to_remove = []
+                for route in self.app.routes:
+                    if hasattr(route, "path") and hasattr(route, "methods"):
+                        if route.path == self.route_path and self.route_method in route.methods:
+                            routes_to_remove.append(route)
+
+                # Remove found routes
+                for route in routes_to_remove:
+                    try:
+                        self.app.routes.remove(route)
+                        logger.info(f"Removed route {self.route_method} {self.route_path} from FastAPI app")
+                    except ValueError:
+                        logger.warning(f"Route {self.route_method} {self.route_path} not found in app.routes")
+
+            # Reset route information
+            self.app = None
+            self.route_path = None
+            self.route_method = None
+            self.is_setup_route = False
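FastAPI has no public API for de-registering a route, which is why `_stop` mutates `app.routes` directly; `app.routes` proxies the underlying Starlette router's route list, so removing entries from it takes effect immediately. The same technique in isolation, with a hypothetical endpoint:

from fastapi import FastAPI

app = FastAPI()

@app.get("/ping")  # hypothetical route
def ping():
    return {"ok": True}

# Collect matching routes first, then remove them, exactly as _stop does,
# so the list is not mutated while it is being iterated.
matches = [
    route for route in app.routes
    if getattr(route, "path", None) == "/ping" and "GET" in getattr(route, "methods", set())
]
for route in matches:
    app.routes.remove(route)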
service_forge/workflow/triggers/period_trigger.py
CHANGED
@@ -23,4 +23,7 @@ class PeriodTrigger(Trigger):
     async def _run(self, times: int, period: float) -> AsyncIterator[bool]:
         for _ in range(times):
             await asyncio.sleep(period)
-            yield uuid.uuid4()
+            yield uuid.uuid4()
+
+    async def _stop(self) -> AsyncIterator[bool]:
+        pass
service_forge/workflow/triggers/websocket_api_trigger.py
CHANGED
@@ -120,12 +120,16 @@ class WebSocketAPITrigger(Trigger):
         self.result_queues[task_id] = asyncio.Queue()
         self.stream_queues[task_id] = asyncio.Queue()
 
-
-        converted_data =
-
-
-
-
+        if data_type is Any:
+            converted_data = message_data
+        else:
+            try:
+                # TODO: message_data is Message, need to convert to dict
+                converted_data = data_type(**message_data)
+            except Exception as e:
+                error_msg = {"error": f"Failed to convert data: {str(e)}"}
+                await websocket.send_text(json.dumps(error_msg))
+                return
 
         # Always start background task to handle stream output
         asyncio.create_task(self.handle_stream_output(websocket, task_id))
@@ -143,16 +147,14 @@ class WebSocketAPITrigger(Trigger):
 
         try:
             while True:
-
-                data = await websocket.receive_text()
+                data = await websocket.receive()
                 try:
-                    message = json.loads(data)
-
+                    # message = json.loads(data)
                     # Handle the message and trigger workflow
                     await self.handle_websocket_message(
                         websocket,
                         data_type,
-
+                        data
                     )
                 except json.JSONDecodeError:
                     error_msg = {"error": "Invalid JSON format"}
@@ -182,3 +184,5 @@ class WebSocketAPITrigger(Trigger):
                 logger.error(f"Error in WebSocketAPITrigger._run: {e}")
                 continue
 
+    async def _stop(self) -> AsyncIterator[bool]:
+        pass
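One detail of this hunk is easy to miss: the switch from `websocket.receive_text()` to `websocket.receive()` means `data` is now the raw ASGI event dict (e.g. `{"type": "websocket.receive", "text": "..."}`) rather than decoded text, which is presumably why the `json.loads` call is commented out and the TODO above notes the payload still needs converting before `data_type(**message_data)` will work. A small sketch of unpacking such an event (plain Python, hypothetical helper):

import json

def extract_json(event: dict):
    # Starlette's WebSocket.receive() yields raw ASGI events such as
    # {"type": "websocket.receive", "text": "{\"job_id\": \"42\"}"}.
    if event.get("type") == "websocket.receive" and event.get("text") is not None:
        return json.loads(event["text"])
    return None

print(extract_json({"type": "websocket.receive", "text": "{\"job_id\": \"42\"}"}))
# {'job_id': '42'}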
service_forge/workflow/workflow.py
CHANGED
@@ -12,12 +12,14 @@ from .edge import Edge
 from ..db.database import DatabaseManager
 from ..utils.workflow_clone import workflow_clone
 from .workflow_callback import WorkflowCallback, BuiltinWorkflowCallback, CallbackEvent
+from .workflow_config import WorkflowConfig
+from .context import Context
 
 class Workflow:
     def __init__(
         self,
-
-
+        id: uuid.UUID,
+        config: WorkflowConfig,
         nodes: list[Node],
         input_ports: list[Port],
         output_ports: list[Port],
@@ -30,9 +32,12 @@ class Workflow:
         # for run
         task_id: uuid.UUID = None,
         real_trigger_node: Trigger = None,
+
+        # global variables
+        global_context: Context = None,
     ) -> None:
-        self.
-        self.
+        self.id = id
+        self.config = config
         self.nodes = nodes
         self.ready_nodes: list[Node] = []
         self.input_ports = input_ports
@@ -47,7 +52,20 @@ class Workflow:
         self.callbacks = callbacks
         self.task_id = task_id
         self.real_trigger_node = real_trigger_node
+        self.global_context = global_context
         self._validate()
+
+    @property
+    def name(self) -> str:
+        return self.config.name
+
+    @property
+    def version(self) -> str:
+        return self.config.version
+
+    @property
+    def description(self) -> str:
+        return self.config.description
 
     def register_callback(self, callback: WorkflowCallback) -> None:
         self.callbacks.append(callback)
@@ -61,6 +79,8 @@ class Workflow:
             await callback.on_workflow_start(*args, **kwargs)
         elif callback_type == CallbackEvent.ON_WORKFLOW_END:
             await callback.on_workflow_end(*args, **kwargs)
+        elif callback_type == CallbackEvent.ON_WORKFLOW_ERROR:
+            await callback.on_workflow_error(*args, **kwargs)
         elif callback_type == CallbackEvent.ON_NODE_START:
             await callback.on_node_start(*args, **kwargs)
         elif callback_type == CallbackEvent.ON_NODE_END:
@@ -104,7 +124,7 @@ class Workflow:
             raise ValueError("Multiple trigger nodes found in workflow.")
         return trigger_nodes[0]
 
-    async def _run_node_with_callbacks(self, node: Node) ->
+    async def _run_node_with_callbacks(self, node: Node) -> bool:
         await self.call_callbacks(CallbackEvent.ON_NODE_START, node=node)
 
         try:
@@ -113,8 +133,13 @@ class Workflow:
                 await self.handle_node_stream_output(node, result)
             elif asyncio.iscoroutine(result):
                 await result
+        except Exception as e:
+            await self.call_callbacks(CallbackEvent.ON_WORKFLOW_ERROR, workflow=self, node=node, error=e)
+            logger.error(f"Error when running node {node.name}: {str(e)}, task_id: {self.task_id}")
+            return False
         finally:
             await self.call_callbacks(CallbackEvent.ON_NODE_END, node=node)
+        return True
 
     async def run_after_trigger(self) -> Any:
         logger.info(f"Running workflow: {self.name}")
@@ -125,30 +150,41 @@ class Workflow:
         for edge in self.get_trigger_node().output_edges:
             edge.end_port.trigger()
 
-… (24 removed lines, content elided in the source diff)
+        for input_port in self.input_ports:
+            if input_port.value is not None:
+                input_port.port.node.fill_input(input_port.port, input_port.value)
+
+        for node in self.nodes:
+            for key in node.AUTO_FILL_INPUT_PORTS:
+                if key[0] not in [edge.end_port.name for edge in node.input_edges]:
+                    node.fill_input_by_name(key[0], key[1])
+
+        while self.ready_nodes:
+            nodes = self.ready_nodes.copy()
+            self.ready_nodes = []
+
+            tasks = []
+            for node in nodes:
+                tasks.append(asyncio.create_task(self._run_node_with_callbacks(node)))
+
+            results = await asyncio.gather(*tasks, return_exceptions=True)
+
+            for i, result in enumerate(results):
+                if isinstance(result, Exception):
+                    for task in tasks:
+                        if not task.done():
+                            task.cancel()
+                    await asyncio.gather(*tasks, return_exceptions=True)
+                    return
+                    # raise result
+                elif result is False:
+                    logger.error(f"Node execution failed, stopping workflow: {nodes[i].name}")
+                    for task in tasks:
+                        if not task.done():
+                            task.cancel()
+                    await asyncio.gather(*tasks, return_exceptions=True)
+                    return
+                    # raise RuntimeError(f"Workflow stopped due to node execution failure: {nodes[i].name}")
 
         if len(self.output_ports) > 0:
             if len(self.output_ports) == 1:
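The rewritten scheduler runs each batch of ready nodes concurrently, then inspects the results and aborts the workflow on the first failure. One subtlety: because `asyncio.gather(..., return_exceptions=True)` only returns once every task has completed, the subsequent `task.cancel()` loop is effectively a safeguard; by that point the batch's tasks are already done. The pattern in isolation (pure asyncio, hypothetical workers):

import asyncio

async def worker(i: int) -> bool:
    await asyncio.sleep(0.01 * i)
    if i == 2:
        raise RuntimeError("boom")  # simulate a failing node
    return True

async def run_batch() -> bool:
    tasks = [asyncio.create_task(worker(i)) for i in range(5)]
    # return_exceptions=True collects failures instead of raising the first one.
    results = await asyncio.gather(*tasks, return_exceptions=True)
    for i, result in enumerate(results):
        if isinstance(result, Exception) or result is False:
            for task in tasks:
                if not task.done():
                    task.cancel()
            # Drain the batch so any cancelled tasks are awaited cleanly.
            await asyncio.gather(*tasks, return_exceptions=True)
            print(f"batch aborted at worker {i}: {result}")
            return False
    return True

print(asyncio.run(run_batch()))  # batch aborted at worker 2: boom, then False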
@@ -173,8 +209,11 @@
             # TODO: clear new_workflow
 
         except Exception as e:
-
-
+            await self.call_callbacks(CallbackEvent.ON_WORKFLOW_ERROR, workflow=self, node=None, error=e)
+            # error_msg = f"Error running workflow: {str(e)}, {traceback.format_exc()}"
+            # logger.error(error_msg)
+            # await self.call_callbacks(CallbackEvent.ON_WORKFLOW_END, workflow=self, node=None, error=e)
+            return
 
     async def run(self):
         tasks = []
@@ -186,6 +225,10 @@ class Workflow:
         if tasks:
             await asyncio.gather(*tasks)
 
+    async def stop(self):
+        trigger = self.get_trigger_node()
+        await trigger._stop()
+
     def trigger(self, trigger_name: str, **kwargs) -> uuid.UUID:
         trigger = self.get_trigger_node()
         task_id = uuid.uuid4()
service_forge/workflow/workflow_callback.py
CHANGED
@@ -31,7 +31,7 @@ class WorkflowCallback:
         pass
 
     @abstractmethod
-    async def on_workflow_error(self, workflow: Workflow, error: Any) -> None:
+    async def on_workflow_error(self, workflow: Workflow, node: Node, error: Any) -> None:
         pass
 
     @abstractmethod
@@ -90,7 +90,7 @@ class BuiltinWorkflowCallback(WorkflowCallback):
             logger.error(f"Failed to send workflow_end message to websocket: {e}")
 
     @override
-    async def on_workflow_error(self, workflow: Workflow, error: Any) -> None:
+    async def on_workflow_error(self, workflow: Workflow, node: Node | None, error: Any) -> None:
         workflow_result = WorkflowResult(result=error, is_end=False, is_error=True)
 
         if workflow.task_id in workflow.real_trigger_node.result_queues:
@@ -103,6 +103,7 @@ class BuiltinWorkflowCallback(WorkflowCallback):
         message = {
             "type": "workflow_error",
             "task_id": str(workflow.task_id),
+            "node": node.name if node else None,
             "error": self._serialize_result(error),
             "is_end": False,
             "is_error": True
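With this signature change, any third-party `WorkflowCallback` implementation must accept the new `node` argument, which is `None` when the failure did not originate in a specific node. A hypothetical sketch of a conforming callback (the class's other abstract hooks are omitted for brevity):

from typing import Any
from loguru import logger

from service_forge.workflow.workflow_callback import WorkflowCallback  # assumed path

class AlertingCallback(WorkflowCallback):
    async def on_workflow_error(self, workflow, node, error: Any) -> None:
        # node is None for workflow-level failures, per the new convention.
        where = node.name if node is not None else "<workflow>"
        logger.error(f"workflow {workflow.name} failed at {where}: {error}")

    # ... the remaining abstract hooks (on_workflow_start, on_node_start,
    # etc.) would need concrete overrides as well; omitted here.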