service-forge 0.1.11__py3-none-any.whl → 0.1.24__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of service-forge might be problematic; see this release's page on the package registry for more details.

Files changed (42)
  1. service_forge/api/http_api.py +4 -0
  2. service_forge/api/routers/feedback/feedback_router.py +148 -0
  3. service_forge/api/routers/service/service_router.py +22 -32
  4. service_forge/current_service.py +14 -0
  5. service_forge/db/database.py +46 -32
  6. service_forge/db/migrations/feedback_migration.py +154 -0
  7. service_forge/db/models/__init__.py +0 -0
  8. service_forge/db/models/feedback.py +33 -0
  9. service_forge/llm/__init__.py +5 -0
  10. service_forge/model/feedback.py +30 -0
  11. service_forge/service.py +118 -126
  12. service_forge/service_config.py +42 -156
  13. service_forge/sft/cli.py +39 -0
  14. service_forge/sft/cmd/remote_deploy.py +160 -0
  15. service_forge/sft/cmd/remote_list_tars.py +111 -0
  16. service_forge/sft/config/injector.py +46 -24
  17. service_forge/sft/config/injector_default_files.py +1 -1
  18. service_forge/sft/config/sft_config.py +55 -8
  19. service_forge/storage/__init__.py +5 -0
  20. service_forge/storage/feedback_storage.py +245 -0
  21. service_forge/utils/default_type_converter.py +1 -1
  22. service_forge/utils/type_converter.py +5 -0
  23. service_forge/utils/workflow_clone.py +3 -2
  24. service_forge/workflow/node.py +8 -0
  25. service_forge/workflow/nodes/llm/query_llm_node.py +1 -1
  26. service_forge/workflow/trigger.py +4 -0
  27. service_forge/workflow/triggers/a2a_api_trigger.py +2 -0
  28. service_forge/workflow/triggers/fast_api_trigger.py +32 -0
  29. service_forge/workflow/triggers/kafka_api_trigger.py +3 -0
  30. service_forge/workflow/triggers/once_trigger.py +4 -1
  31. service_forge/workflow/triggers/period_trigger.py +4 -1
  32. service_forge/workflow/triggers/websocket_api_trigger.py +15 -11
  33. service_forge/workflow/workflow.py +74 -31
  34. service_forge/workflow/workflow_callback.py +3 -2
  35. service_forge/workflow/workflow_config.py +66 -0
  36. service_forge/workflow/workflow_factory.py +86 -85
  37. service_forge/workflow/workflow_group.py +33 -9
  38. {service_forge-0.1.11.dist-info → service_forge-0.1.24.dist-info}/METADATA +1 -1
  39. {service_forge-0.1.11.dist-info → service_forge-0.1.24.dist-info}/RECORD +41 -31
  40. service_forge/api/routers/service/__init__.py +0 -4
  41. {service_forge-0.1.11.dist-info → service_forge-0.1.24.dist-info}/WHEEL +0 -0
  42. {service_forge-0.1.11.dist-info → service_forge-0.1.24.dist-info}/entry_points.txt +0 -0
@@ -0,0 +1,245 @@
1
+ from typing import Optional, Any
2
+ from datetime import datetime
3
+ from uuid import uuid4
4
+ from loguru import logger
5
+ from sqlalchemy import select, and_
6
+ from sqlalchemy.exc import SQLAlchemyError
7
+
8
+ from ..db.models.feedback import FeedbackBase
9
+ from ..db.database import DatabaseManager
10
+
11
+
12
class FeedbackStorage:
    """Feedback storage manager backed by a PostgreSQL database.

    Every public method first checks :attr:`db`. When it is ``None`` the call
    is served from an in-process dictionary instead. That happens when no
    ``DatabaseManager`` was supplied, or the manager returned ``None`` for its
    default Postgres database. The in-memory fallback is not persistent: its
    contents are lost when the process restarts.

    The fallback is kept consistent with the database path in two ways:
    - List queries return records newest-first, matching the
      ``ORDER BY created_at DESC`` used against the database.
    - Records are returned as copies, so callers cannot mutate stored state
      through them.
    """

    def __init__(self, database_manager: Optional["DatabaseManager"] = None):
        """Create a storage manager.

        Args:
            database_manager: Provider of the default Postgres database. It may
                be ``None``, in which case the in-memory fallback is used.
        """
        self.database_manager = database_manager
        self._db = None
        # In-memory fallback, used when no database is configured.
        # feedback_id (str) -> stored record.
        self._storage: dict[str, dict[str, Any]] = {}
        # Secondary indexes: task_id / workflow_name -> feedback ids in insertion order.
        self._task_index: dict[str, list[str]] = {}
        self._workflow_index: dict[str, list[str]] = {}

    @property
    def db(self):
        """Return the default Postgres database, resolving it lazily.

        The result is cached once it is non-``None``. ``None`` means "use the
        in-memory fallback".
        """
        if self._db is None and self.database_manager is not None:
            self._db = self.database_manager.get_default_postgres_database()
        return self._db

    async def create_feedback(
        self,
        task_id: str,
        workflow_name: str,
        rating: Optional[int] = None,
        comment: Optional[str] = None,
        metadata: Optional[dict[str, Any]] = None,
    ) -> dict[str, Any]:
        """Create and persist a new feedback record.

        Args:
            task_id: Task the feedback refers to.
            workflow_name: Workflow the feedback refers to.
            rating: Optional rating value.
            comment: Optional free-text comment.
            metadata: Optional extra data. It is stored as ``extra_metadata``
                in the database, or under ``"metadata"`` in memory.

        Returns:
            The stored record as a dict. On the database path this is
            ``FeedbackBase.to_dict()``; in memory it is a copy of the record.

        Raises:
            SQLAlchemyError: The database write failed (logged, then re-raised).
            Exception: Any other failure on the database path (logged, then
                re-raised).
        """
        if self.db is None:
            logger.warning("数据库未初始化,使用内存存储(数据将在重启后丢失)")
            return self._create_feedback_in_memory(task_id, workflow_name, rating, comment, metadata)

        try:
            feedback_id = uuid4()
            feedback = FeedbackBase(
                feedback_id=feedback_id,
                task_id=task_id,
                workflow_name=workflow_name,
                rating=rating,
                comment=comment,
                extra_metadata=metadata or {},
                created_at=datetime.now(),
            )

            session_factory = await self.db.get_session_factory()
            async with session_factory() as session:
                session.add(feedback)
                await session.commit()
                # Reload server-side state so to_dict() reflects what was stored.
                await session.refresh(feedback)

            logger.info(f"创建反馈: feedback_id={feedback_id}, task_id={task_id}, workflow={workflow_name}")
            return feedback.to_dict()

        except SQLAlchemyError as e:
            logger.error(f"数据库创建反馈失败: {e}")
            raise
        except Exception as e:
            logger.error(f"创建反馈失败: {e}")
            raise

    async def get_feedback(self, feedback_id: str) -> Optional[dict[str, Any]]:
        """Return the feedback with *feedback_id*.

        Returns ``None`` if no record matches, or if the database query fails
        (the failure is logged).
        """
        if self.db is None:
            return self._get_feedback_from_memory(feedback_id)

        try:
            session_factory = await self.db.get_session_factory()
            async with session_factory() as session:
                result = await session.execute(
                    select(FeedbackBase).where(FeedbackBase.feedback_id == feedback_id)
                )
                feedback = result.scalar_one_or_none()
                return feedback.to_dict() if feedback else None

        except SQLAlchemyError as e:
            logger.error(f"查询反馈失败: {e}")
            return None

    async def get_feedbacks_by_task(self, task_id: str) -> list[dict[str, Any]]:
        """Return all feedback for *task_id*, newest first ([] on database error)."""
        if self.db is None:
            return self._get_feedbacks_by_task_from_memory(task_id)

        return await self._fetch_feedback_list(
            select(FeedbackBase)
            .where(FeedbackBase.task_id == task_id)
            .order_by(FeedbackBase.created_at.desc()),
            "查询任务反馈失败",
        )

    async def get_feedbacks_by_workflow(self, workflow_name: str) -> list[dict[str, Any]]:
        """Return all feedback for *workflow_name*, newest first ([] on database error)."""
        if self.db is None:
            return self._get_feedbacks_by_workflow_from_memory(workflow_name)

        return await self._fetch_feedback_list(
            select(FeedbackBase)
            .where(FeedbackBase.workflow_name == workflow_name)
            .order_by(FeedbackBase.created_at.desc()),
            "查询工作流反馈失败",
        )

    async def get_all_feedbacks(self) -> list[dict[str, Any]]:
        """Return every feedback record, newest first ([] on database error)."""
        if self.db is None:
            return self._get_all_feedbacks_from_memory()

        return await self._fetch_feedback_list(
            select(FeedbackBase).order_by(FeedbackBase.created_at.desc()),
            "查询所有反馈失败",
        )

    async def delete_feedback(self, feedback_id: str) -> bool:
        """Delete the feedback with *feedback_id*.

        Returns ``True`` if a record was deleted. Returns ``False`` if none
        matched, or if the database operation failed (the failure is logged).
        """
        if self.db is None:
            return self._delete_feedback_from_memory(feedback_id)

        try:
            session_factory = await self.db.get_session_factory()
            async with session_factory() as session:
                result = await session.execute(
                    select(FeedbackBase).where(FeedbackBase.feedback_id == feedback_id)
                )
                feedback = result.scalar_one_or_none()

                if not feedback:
                    return False

                await session.delete(feedback)
                await session.commit()

                logger.info(f"删除反馈: feedback_id={feedback_id}")
                return True

        except SQLAlchemyError as e:
            logger.error(f"删除反馈失败: {e}")
            return False

    async def _fetch_feedback_list(self, statement, error_prefix: str) -> list[dict[str, Any]]:
        """Run a SELECT over ``FeedbackBase`` rows and return them as dicts.

        On ``SQLAlchemyError`` this logs ``"{error_prefix}: {error}"`` and
        returns ``[]``.
        """
        try:
            session_factory = await self.db.get_session_factory()
            async with session_factory() as session:
                result = await session.execute(statement)
                return [row.to_dict() for row in result.scalars().all()]

        except SQLAlchemyError as e:
            logger.error(f"{error_prefix}: {e}")
            return []

    # ========== In-memory fallback (used when no database is configured) ==========

    @staticmethod
    def _copy_record(record: dict[str, Any]) -> dict[str, Any]:
        """Return a copy of *record* with its own top-level and ``metadata`` dicts."""
        copied = dict(record)
        copied["metadata"] = dict(record.get("metadata") or {})
        return copied

    @classmethod
    def _records_newest_first(cls, records: list[dict[str, Any]]) -> list[dict[str, Any]]:
        """Return copies of *records* ordered by ``created_at`` descending.

        This matches the database path's ``ORDER BY created_at DESC``.
        *records* arrive in insertion order, so reversing them before the
        stable sort places later-inserted records first when timestamps tie.
        """
        ordered = sorted(reversed(records), key=lambda r: r["created_at"], reverse=True)
        return [cls._copy_record(r) for r in ordered]

    def _collect_from_index(self, index: dict[str, list[str]], key: str) -> list[dict[str, Any]]:
        """Return the stored records listed under *key* in *index*, newest first."""
        ids = index.get(key, [])
        return self._records_newest_first([self._storage[fid] for fid in ids if fid in self._storage])

    @staticmethod
    def _remove_from_index(index: dict[str, list[str]], key: str, feedback_id: str) -> None:
        """Drop *feedback_id* from ``index[key]``, deleting the entry once it is empty."""
        ids = index.get(key)
        if ids is None:
            return
        if feedback_id in ids:
            ids.remove(feedback_id)
        if not ids:
            del index[key]

    def _create_feedback_in_memory(
        self,
        task_id: str,
        workflow_name: str,
        rating: Optional[int],
        comment: Optional[str],
        metadata: Optional[dict[str, Any]],
    ) -> dict[str, Any]:
        """In-memory counterpart of :meth:`create_feedback`; returns a copy of the stored record."""
        feedback_id = str(uuid4())
        feedback_data = {
            "feedback_id": feedback_id,
            "task_id": task_id,
            "workflow_name": workflow_name,
            "rating": rating,
            "comment": comment,
            # Copy so later mutation of the caller's dict cannot alter the stored record.
            "metadata": dict(metadata or {}),
            "created_at": datetime.now(),
        }

        self._storage[feedback_id] = feedback_data
        self._task_index.setdefault(task_id, []).append(feedback_id)
        self._workflow_index.setdefault(workflow_name, []).append(feedback_id)

        logger.info(f"[内存存储]创建反馈: feedback_id={feedback_id}")
        return self._copy_record(feedback_data)

    def _get_feedback_from_memory(self, feedback_id: str) -> Optional[dict[str, Any]]:
        """In-memory counterpart of :meth:`get_feedback`."""
        # Keys are str; normalize so a UUID object also matches.
        record = self._storage.get(str(feedback_id))
        return self._copy_record(record) if record is not None else None

    def _get_feedbacks_by_task_from_memory(self, task_id: str) -> list[dict[str, Any]]:
        """In-memory counterpart of :meth:`get_feedbacks_by_task`."""
        return self._collect_from_index(self._task_index, task_id)

    def _get_feedbacks_by_workflow_from_memory(self, workflow_name: str) -> list[dict[str, Any]]:
        """In-memory counterpart of :meth:`get_feedbacks_by_workflow`."""
        return self._collect_from_index(self._workflow_index, workflow_name)

    def _get_all_feedbacks_from_memory(self) -> list[dict[str, Any]]:
        """In-memory counterpart of :meth:`get_all_feedbacks`."""
        # dicts preserve insertion order, which _records_newest_first relies on for ties.
        return self._records_newest_first(list(self._storage.values()))

    def _delete_feedback_from_memory(self, feedback_id: str) -> bool:
        """In-memory counterpart of :meth:`delete_feedback`; tolerant of stale index entries."""
        key = str(feedback_id)
        record = self._storage.pop(key, None)
        if record is None:
            return False

        self._remove_from_index(self._task_index, record["task_id"], key)
        self._remove_from_index(self._workflow_index, record["workflow_name"], key)

        logger.info(f"[内存存储]删除反馈: feedback_id={feedback_id}")
        return True
# Module-wide shared instance. It is constructed without a DatabaseManager, so
# every call uses the in-memory fallback until `database_manager` is assigned.
# That works because the `db` property reads the attribute lazily while no
# database has been resolved.
feedback_storage = FeedbackStorage()
@@ -2,8 +2,8 @@ from ..utils.type_converter import TypeConverter
2
2
  from ..workflow.workflow import Workflow
3
3
  from ..api.http_api import fastapi_app
4
4
  from ..api.kafka_api import KafkaApp, kafka_app
5
- from fastapi import FastAPI
6
5
  from ..workflow.workflow_type import WorkflowType, workflow_type_register
6
+ from fastapi import FastAPI
7
7
 
8
8
  type_converter = TypeConverter()
9
9
  type_converter.register(str, Workflow, lambda s, node: node.sub_workflows.get_workflow(s))
@@ -1,6 +1,8 @@
1
1
  from typing import Any, Callable, Type, Dict, Tuple, Set, List
2
2
  from collections import deque
3
3
  import inspect
4
+ import traceback
5
+ from pydantic import BaseModel
4
6
  from typing_extensions import get_origin, get_args
5
7
 
6
8
  def is_type(value, dst_type):
@@ -57,6 +59,9 @@ class TypeConverter:
57
59
  except Exception:
58
60
  pass
59
61
 
62
+ if issubclass(dst_type, BaseModel) and isinstance(value, dict):
63
+ return dst_type(**value)
64
+
60
65
  path = self._find_path(src_type, dst_type)
61
66
  if not path:
62
67
  raise TypeError(f"No conversion path found from {src_type.__name__} to {dst_type.__name__}.")
@@ -57,8 +57,8 @@ def workflow_clone(self: Workflow, task_id: uuid.UUID, trigger_node: Trigger) ->
57
57
  }
58
58
 
59
59
  workflow = Workflow(
60
- name=self.name,
61
- description=self.description,
60
+ id=self.id,
61
+ config=self.config,
62
62
  nodes=[node_map[node] for node in self.nodes],
63
63
  input_ports=[port_map[port] for port in self.input_ports],
64
64
  output_ports=[port_map[port] for port in self.output_ports],
@@ -69,6 +69,7 @@ def workflow_clone(self: Workflow, task_id: uuid.UUID, trigger_node: Trigger) ->
69
69
  callbacks=self.callbacks,
70
70
  task_id=task_id,
71
71
  real_trigger_node=trigger_node,
72
+ global_context=self.global_context,
72
73
  )
73
74
 
74
75
  for node in workflow.nodes:
@@ -10,6 +10,7 @@ from .context import Context
10
10
  from ..utils.register import Register
11
11
  from ..db.database import DatabaseManager, PostgresDatabase, MongoDatabase, RedisDatabase
12
12
  from ..utils.workflow_clone import node_clone
13
+ from .workflow_callback import CallbackEvent
13
14
 
14
15
  if TYPE_CHECKING:
15
16
  from .workflow import Workflow
@@ -62,6 +63,10 @@ class Node(ABC):
62
63
  def database_manager(self) -> DatabaseManager:
63
64
  return self.workflow.database_manager
64
65
 
66
+ @property
67
+ def global_context(self) -> Context:
68
+ return self.workflow.global_context
69
+
65
70
  def backup(self) -> None:
66
71
  # do NOT use deepcopy here
67
72
  # self.bak_context = deepcopy(self.context)
@@ -181,4 +186,7 @@ class Node(ABC):
181
186
  def _clone(self, context: Context) -> Node:
182
187
  return node_clone(self, context)
183
188
 
189
+ async def stream_output(self, data: Any) -> None:
190
+ await self.workflow.call_callbacks(CallbackEvent.ON_NODE_STREAM_OUTPUT, node=self, output=data)
191
+
184
192
  node_register = Register[Node]()
@@ -36,6 +36,6 @@ class QueryLLMNode(Node):
36
36
  prompt = f.read()
37
37
 
38
38
  print(f"prompt: {prompt} temperature: {temperature}")
39
- response = chat_stream(prompt, system_prompt, Model.DEEPSEEK_V3_250324, temperature)
39
+ response = chat_stream(prompt, system_prompt, Model.GEMINI, temperature)
40
40
  for chunk in response:
41
41
  yield chunk
@@ -19,6 +19,10 @@ class Trigger(Node, ABC):
19
19
  async def _run(self) -> AsyncIterator[bool]:
20
20
  ...
21
21
 
22
+ @abstractmethod
23
+ async def _stop(self) -> AsyncIterator[bool]:
24
+ ...
25
+
22
26
  def trigger(self, task_id: uuid.UUID) -> bool:
23
27
  self.prepare_output_edges(self.get_output_port_by_name('trigger'), True)
24
28
  return task_id
@@ -253,3 +253,5 @@ class A2AAPITrigger(Trigger):
253
253
  logger.error(f"Error in A2AAPITrigger._run: {e}")
254
254
  continue
255
255
 
256
+ async def _stop(self) -> AsyncIterator[bool]:
257
+ pass
@@ -33,6 +33,9 @@ class FastAPITrigger(Trigger):
33
33
  super().__init__(name)
34
34
  self.events = {}
35
35
  self.is_setup_route = False
36
+ self.app = None
37
+ self.route_path = None
38
+ self.route_method = None
36
39
 
37
40
  @staticmethod
38
41
  def serialize_result(result: Any):
@@ -141,6 +144,11 @@ class FastAPITrigger(Trigger):
141
144
  async def handler(request: Request):
142
145
  return await self.handle_request(request, data_type, extractor, is_stream)
143
146
 
147
+ # Save route information for cleanup
148
+ self.app = app
149
+ self.route_path = path
150
+ self.route_method = method.upper()
151
+
144
152
  if method == "GET":
145
153
  app.get(path)(handler)
146
154
  elif method == "POST":
@@ -167,3 +175,27 @@ class FastAPITrigger(Trigger):
167
175
  logger.error(f"Error in FastAPITrigger._run: {e}")
168
176
  continue
169
177
 
178
+ async def _stop(self) -> AsyncIterator[bool]:
179
+ if self.is_setup_route:
180
+ # Remove the route from the app
181
+ if self.app and self.route_path and self.route_method:
182
+ # Find and remove matching route
183
+ routes_to_remove = []
184
+ for route in self.app.routes:
185
+ if hasattr(route, "path") and hasattr(route, "methods"):
186
+ if route.path == self.route_path and self.route_method in route.methods:
187
+ routes_to_remove.append(route)
188
+
189
+ # Remove found routes
190
+ for route in routes_to_remove:
191
+ try:
192
+ self.app.routes.remove(route)
193
+ logger.info(f"Removed route {self.route_method} {self.route_path} from FastAPI app")
194
+ except ValueError:
195
+ logger.warning(f"Route {self.route_method} {self.route_path} not found in app.routes")
196
+
197
+ # Reset route information
198
+ self.app = None
199
+ self.route_path = None
200
+ self.route_method = None
201
+ self.is_setup_route = False
@@ -42,3 +42,6 @@ class KafkaAPITrigger(Trigger):
42
42
  trigger = await self.trigger_queue.get()
43
43
  self.prepare_output_edges(self.get_output_port_by_name('data'), trigger['data'])
44
44
  yield self.trigger(trigger['id'])
45
+
46
+ async def _stop(self) -> AsyncIterator[bool]:
47
+ pass
@@ -17,4 +17,7 @@ class OnceTrigger(Trigger):
17
17
  super().__init__(name)
18
18
 
19
19
  async def _run(self) -> AsyncIterator[bool]:
20
- yield self.trigger(uuid.uuid4())
20
+ yield self.trigger(uuid.uuid4())
21
+
22
+ async def _stop(self) -> AsyncIterator[bool]:
23
+ pass
@@ -23,4 +23,7 @@ class PeriodTrigger(Trigger):
23
23
  async def _run(self, times: int, period: float) -> AsyncIterator[bool]:
24
24
  for _ in range(times):
25
25
  await asyncio.sleep(period)
26
- yield uuid.uuid4()
26
+ yield uuid.uuid4()
27
+
28
+ async def _stop(self) -> AsyncIterator[bool]:
29
+ pass
@@ -120,12 +120,16 @@ class WebSocketAPITrigger(Trigger):
120
120
  self.result_queues[task_id] = asyncio.Queue()
121
121
  self.stream_queues[task_id] = asyncio.Queue()
122
122
 
123
- try:
124
- converted_data = data_type(**message_data)
125
- except Exception as e:
126
- error_msg = {"error": f"Failed to convert data: {str(e)}"}
127
- await websocket.send_text(json.dumps(error_msg))
128
- return
123
+ if data_type is Any:
124
+ converted_data = message_data
125
+ else:
126
+ try:
127
+ # TODO: message_data is Message, need to convert to dict
128
+ converted_data = data_type(**message_data)
129
+ except Exception as e:
130
+ error_msg = {"error": f"Failed to convert data: {str(e)}"}
131
+ await websocket.send_text(json.dumps(error_msg))
132
+ return
129
133
 
130
134
  # Always start background task to handle stream output
131
135
  asyncio.create_task(self.handle_stream_output(websocket, task_id))
@@ -143,16 +147,14 @@ class WebSocketAPITrigger(Trigger):
143
147
 
144
148
  try:
145
149
  while True:
146
- # Receive message from client
147
- data = await websocket.receive_text()
150
+ data = await websocket.receive()
148
151
  try:
149
- message = json.loads(data)
150
-
152
+ # message = json.loads(data)
151
153
  # Handle the message and trigger workflow
152
154
  await self.handle_websocket_message(
153
155
  websocket,
154
156
  data_type,
155
- message
157
+ data
156
158
  )
157
159
  except json.JSONDecodeError:
158
160
  error_msg = {"error": "Invalid JSON format"}
@@ -182,3 +184,5 @@ class WebSocketAPITrigger(Trigger):
182
184
  logger.error(f"Error in WebSocketAPITrigger._run: {e}")
183
185
  continue
184
186
 
187
+ async def _stop(self) -> AsyncIterator[bool]:
188
+ pass
@@ -12,12 +12,14 @@ from .edge import Edge
12
12
  from ..db.database import DatabaseManager
13
13
  from ..utils.workflow_clone import workflow_clone
14
14
  from .workflow_callback import WorkflowCallback, BuiltinWorkflowCallback, CallbackEvent
15
+ from .workflow_config import WorkflowConfig
16
+ from .context import Context
15
17
 
16
18
  class Workflow:
17
19
  def __init__(
18
20
  self,
19
- name: str,
20
- description: str,
21
+ id: uuid.UUID,
22
+ config: WorkflowConfig,
21
23
  nodes: list[Node],
22
24
  input_ports: list[Port],
23
25
  output_ports: list[Port],
@@ -30,9 +32,12 @@ class Workflow:
30
32
  # for run
31
33
  task_id: uuid.UUID = None,
32
34
  real_trigger_node: Trigger = None,
35
+
36
+ # global variables
37
+ global_context: Context = None,
33
38
  ) -> None:
34
- self.name = name
35
- self.description = description
39
+ self.id = id
40
+ self.config = config
36
41
  self.nodes = nodes
37
42
  self.ready_nodes: list[Node] = []
38
43
  self.input_ports = input_ports
@@ -47,7 +52,20 @@ class Workflow:
47
52
  self.callbacks = callbacks
48
53
  self.task_id = task_id
49
54
  self.real_trigger_node = real_trigger_node
55
+ self.global_context = global_context
50
56
  self._validate()
57
+
58
+ @property
59
+ def name(self) -> str:
60
+ return self.config.name
61
+
62
+ @property
63
+ def version(self) -> str:
64
+ return self.config.version
65
+
66
+ @property
67
+ def description(self) -> str:
68
+ return self.config.description
51
69
 
52
70
  def register_callback(self, callback: WorkflowCallback) -> None:
53
71
  self.callbacks.append(callback)
@@ -61,6 +79,8 @@ class Workflow:
61
79
  await callback.on_workflow_start(*args, **kwargs)
62
80
  elif callback_type == CallbackEvent.ON_WORKFLOW_END:
63
81
  await callback.on_workflow_end(*args, **kwargs)
82
+ elif callback_type == CallbackEvent.ON_WORKFLOW_ERROR:
83
+ await callback.on_workflow_error(*args, **kwargs)
64
84
  elif callback_type == CallbackEvent.ON_NODE_START:
65
85
  await callback.on_node_start(*args, **kwargs)
66
86
  elif callback_type == CallbackEvent.ON_NODE_END:
@@ -104,7 +124,7 @@ class Workflow:
104
124
  raise ValueError("Multiple trigger nodes found in workflow.")
105
125
  return trigger_nodes[0]
106
126
 
107
- async def _run_node_with_callbacks(self, node: Node) -> None:
127
+ async def _run_node_with_callbacks(self, node: Node) -> bool:
108
128
  await self.call_callbacks(CallbackEvent.ON_NODE_START, node=node)
109
129
 
110
130
  try:
@@ -113,8 +133,13 @@ class Workflow:
113
133
  await self.handle_node_stream_output(node, result)
114
134
  elif asyncio.iscoroutine(result):
115
135
  await result
136
+ except Exception as e:
137
+ await self.call_callbacks(CallbackEvent.ON_WORKFLOW_ERROR, workflow=self, node=node, error=e)
138
+ logger.error(f"Error when running node {node.name}: {str(e)}, task_id: {self.task_id}")
139
+ return False
116
140
  finally:
117
141
  await self.call_callbacks(CallbackEvent.ON_NODE_END, node=node)
142
+ return True
118
143
 
119
144
  async def run_after_trigger(self) -> Any:
120
145
  logger.info(f"Running workflow: {self.name}")
@@ -125,30 +150,41 @@ class Workflow:
125
150
  for edge in self.get_trigger_node().output_edges:
126
151
  edge.end_port.trigger()
127
152
 
128
- try:
129
- for input_port in self.input_ports:
130
- if input_port.value is not None:
131
- input_port.port.node.fill_input(input_port.port, input_port.value)
132
-
133
- for node in self.nodes:
134
- for key in node.AUTO_FILL_INPUT_PORTS:
135
- if key[0] not in [edge.end_port.name for edge in node.input_edges]:
136
- node.fill_input_by_name(key[0], key[1])
137
-
138
- while self.ready_nodes:
139
- nodes = self.ready_nodes.copy()
140
- self.ready_nodes = []
141
-
142
- tasks = []
143
- for node in nodes:
144
- tasks.append(asyncio.create_task(self._run_node_with_callbacks(node)))
145
-
146
- await asyncio.gather(*tasks)
147
-
148
- except Exception as e:
149
- error_msg = f"Error in run_after_trigger: {str(e)}"
150
- logger.error(error_msg)
151
- raise e
153
+ for input_port in self.input_ports:
154
+ if input_port.value is not None:
155
+ input_port.port.node.fill_input(input_port.port, input_port.value)
156
+
157
+ for node in self.nodes:
158
+ for key in node.AUTO_FILL_INPUT_PORTS:
159
+ if key[0] not in [edge.end_port.name for edge in node.input_edges]:
160
+ node.fill_input_by_name(key[0], key[1])
161
+
162
+ while self.ready_nodes:
163
+ nodes = self.ready_nodes.copy()
164
+ self.ready_nodes = []
165
+
166
+ tasks = []
167
+ for node in nodes:
168
+ tasks.append(asyncio.create_task(self._run_node_with_callbacks(node)))
169
+
170
+ results = await asyncio.gather(*tasks, return_exceptions=True)
171
+
172
+ for i, result in enumerate(results):
173
+ if isinstance(result, Exception):
174
+ for task in tasks:
175
+ if not task.done():
176
+ task.cancel()
177
+ await asyncio.gather(*tasks, return_exceptions=True)
178
+ return
179
+ # raise result
180
+ elif result is False:
181
+ logger.error(f"Node execution failed, stopping workflow: {nodes[i].name}")
182
+ for task in tasks:
183
+ if not task.done():
184
+ task.cancel()
185
+ await asyncio.gather(*tasks, return_exceptions=True)
186
+ return
187
+ # raise RuntimeError(f"Workflow stopped due to node execution failure: {nodes[i].name}")
152
188
 
153
189
  if len(self.output_ports) > 0:
154
190
  if len(self.output_ports) == 1:
@@ -173,8 +209,11 @@ class Workflow:
173
209
  # TODO: clear new_workflow
174
210
 
175
211
  except Exception as e:
176
- error_msg = f"Error running workflow: {str(e)}, {traceback.format_exc()}"
177
- logger.error(error_msg)
212
+ await self.call_callbacks(CallbackEvent.ON_WORKFLOW_ERROR, workflow=self, node=None, error=e)
213
+ # error_msg = f"Error running workflow: {str(e)}, {traceback.format_exc()}"
214
+ # logger.error(error_msg)
215
+ # await self.call_callbacks(CallbackEvent.ON_WORKFLOW_END, workflow=self, node=None, error=e)
216
+ return
178
217
 
179
218
  async def run(self):
180
219
  tasks = []
@@ -186,6 +225,10 @@ class Workflow:
186
225
  if tasks:
187
226
  await asyncio.gather(*tasks)
188
227
 
228
+ async def stop(self):
229
+ trigger = self.get_trigger_node()
230
+ await trigger._stop()
231
+
189
232
  def trigger(self, trigger_name: str, **kwargs) -> uuid.UUID:
190
233
  trigger = self.get_trigger_node()
191
234
  task_id = uuid.uuid4()
@@ -31,7 +31,7 @@ class WorkflowCallback:
31
31
  pass
32
32
 
33
33
  @abstractmethod
34
- async def on_workflow_error(self, workflow: Workflow, error: Any) -> None:
34
+ async def on_workflow_error(self, workflow: Workflow, node: Node, error: Any) -> None:
35
35
  pass
36
36
 
37
37
  @abstractmethod
@@ -90,7 +90,7 @@ class BuiltinWorkflowCallback(WorkflowCallback):
90
90
  logger.error(f"发送 workflow_end 消息到 websocket 失败: {e}")
91
91
 
92
92
  @override
93
- async def on_workflow_error(self, workflow: Workflow, error: Any) -> None:
93
+ async def on_workflow_error(self, workflow: Workflow, node: Node | None, error: Any) -> None:
94
94
  workflow_result = WorkflowResult(result=error, is_end=False, is_error=True)
95
95
 
96
96
  if workflow.task_id in workflow.real_trigger_node.result_queues:
@@ -103,6 +103,7 @@ class BuiltinWorkflowCallback(WorkflowCallback):
103
103
  message = {
104
104
  "type": "workflow_error",
105
105
  "task_id": str(workflow.task_id),
106
+ "node": node.name if node else None,
106
107
  "error": self._serialize_result(error),
107
108
  "is_end": False,
108
109
  "is_error": True