service-forge 0.1.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of service-forge might be problematic. Click here for more details.

Files changed (83)
  1. service_forge/api/deprecated_websocket_api.py +86 -0
  2. service_forge/api/deprecated_websocket_manager.py +425 -0
  3. service_forge/api/http_api.py +152 -0
  4. service_forge/api/http_api_doc.py +455 -0
  5. service_forge/api/kafka_api.py +126 -0
  6. service_forge/api/routers/feedback/feedback_router.py +148 -0
  7. service_forge/api/routers/service/service_router.py +127 -0
  8. service_forge/api/routers/websocket/websocket_manager.py +83 -0
  9. service_forge/api/routers/websocket/websocket_router.py +78 -0
  10. service_forge/api/task_manager.py +141 -0
  11. service_forge/current_service.py +14 -0
  12. service_forge/db/__init__.py +1 -0
  13. service_forge/db/database.py +237 -0
  14. service_forge/db/migrations/feedback_migration.py +154 -0
  15. service_forge/db/models/__init__.py +0 -0
  16. service_forge/db/models/feedback.py +33 -0
  17. service_forge/llm/__init__.py +67 -0
  18. service_forge/llm/llm.py +56 -0
  19. service_forge/model/__init__.py +0 -0
  20. service_forge/model/feedback.py +30 -0
  21. service_forge/model/websocket.py +13 -0
  22. service_forge/proto/foo_input.py +5 -0
  23. service_forge/service.py +280 -0
  24. service_forge/service_config.py +44 -0
  25. service_forge/sft/cli.py +91 -0
  26. service_forge/sft/cmd/config_command.py +67 -0
  27. service_forge/sft/cmd/deploy_service.py +123 -0
  28. service_forge/sft/cmd/list_tars.py +41 -0
  29. service_forge/sft/cmd/service_command.py +149 -0
  30. service_forge/sft/cmd/upload_service.py +36 -0
  31. service_forge/sft/config/injector.py +129 -0
  32. service_forge/sft/config/injector_default_files.py +131 -0
  33. service_forge/sft/config/sf_metadata.py +30 -0
  34. service_forge/sft/config/sft_config.py +200 -0
  35. service_forge/sft/file/__init__.py +0 -0
  36. service_forge/sft/file/ignore_pattern.py +80 -0
  37. service_forge/sft/file/sft_file_manager.py +107 -0
  38. service_forge/sft/kubernetes/kubernetes_manager.py +257 -0
  39. service_forge/sft/util/assert_util.py +25 -0
  40. service_forge/sft/util/logger.py +16 -0
  41. service_forge/sft/util/name_util.py +8 -0
  42. service_forge/sft/util/yaml_utils.py +57 -0
  43. service_forge/storage/__init__.py +5 -0
  44. service_forge/storage/feedback_storage.py +245 -0
  45. service_forge/utils/__init__.py +0 -0
  46. service_forge/utils/default_type_converter.py +12 -0
  47. service_forge/utils/register.py +39 -0
  48. service_forge/utils/type_converter.py +99 -0
  49. service_forge/utils/workflow_clone.py +124 -0
  50. service_forge/workflow/__init__.py +1 -0
  51. service_forge/workflow/context.py +14 -0
  52. service_forge/workflow/edge.py +24 -0
  53. service_forge/workflow/node.py +184 -0
  54. service_forge/workflow/nodes/__init__.py +8 -0
  55. service_forge/workflow/nodes/control/if_node.py +29 -0
  56. service_forge/workflow/nodes/control/switch_node.py +28 -0
  57. service_forge/workflow/nodes/input/console_input_node.py +26 -0
  58. service_forge/workflow/nodes/llm/query_llm_node.py +41 -0
  59. service_forge/workflow/nodes/nested/workflow_node.py +28 -0
  60. service_forge/workflow/nodes/output/kafka_output_node.py +27 -0
  61. service_forge/workflow/nodes/output/print_node.py +29 -0
  62. service_forge/workflow/nodes/test/if_console_input_node.py +33 -0
  63. service_forge/workflow/nodes/test/time_consuming_node.py +62 -0
  64. service_forge/workflow/port.py +89 -0
  65. service_forge/workflow/trigger.py +28 -0
  66. service_forge/workflow/triggers/__init__.py +6 -0
  67. service_forge/workflow/triggers/a2a_api_trigger.py +257 -0
  68. service_forge/workflow/triggers/fast_api_trigger.py +201 -0
  69. service_forge/workflow/triggers/kafka_api_trigger.py +47 -0
  70. service_forge/workflow/triggers/once_trigger.py +23 -0
  71. service_forge/workflow/triggers/period_trigger.py +29 -0
  72. service_forge/workflow/triggers/websocket_api_trigger.py +189 -0
  73. service_forge/workflow/workflow.py +227 -0
  74. service_forge/workflow/workflow_callback.py +141 -0
  75. service_forge/workflow/workflow_config.py +66 -0
  76. service_forge/workflow/workflow_event.py +15 -0
  77. service_forge/workflow/workflow_factory.py +246 -0
  78. service_forge/workflow/workflow_group.py +51 -0
  79. service_forge/workflow/workflow_type.py +52 -0
  80. service_forge-0.1.18.dist-info/METADATA +98 -0
  81. service_forge-0.1.18.dist-info/RECORD +83 -0
  82. service_forge-0.1.18.dist-info/WHEEL +4 -0
  83. service_forge-0.1.18.dist-info/entry_points.txt +2 -0
@@ -0,0 +1,189 @@
1
+ from __future__ import annotations
2
+ import uuid
3
+ import asyncio
4
+ import json
5
+ from loguru import logger
6
+ from service_forge.workflow.trigger import Trigger
7
+ from typing import AsyncIterator, Any
8
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect
9
+ from service_forge.workflow.port import Port
10
+ from google.protobuf.message import Message
11
+ from google.protobuf.json_format import MessageToJson
12
+
13
class WebSocketAPITrigger(Trigger):
    """Workflow trigger fed by JSON messages arriving on a FastAPI websocket route.

    Every message received on ``path`` is converted to ``data_type`` and queued
    as a separate run with its own ``task_id``. A background relay task
    (``handle_stream_output``) sends that run's intermediate output, final
    result, or error back to the client over the same connection.
    """

    DEFAULT_INPUT_PORTS = [
        Port("app", FastAPI),
        Port("path", str),
        Port("data_type", type),
    ]

    DEFAULT_OUTPUT_PORTS = [
        Port("trigger", bool),
        Port("data", Any),
    ]

    # Strong references to the fire-and-forget relay tasks started in
    # handle_websocket_message. asyncio keeps only weak references to tasks, so
    # an unreferenced task can be garbage-collected before it finishes.
    # NOTE(review): class-level so it also works for instances whose __init__
    # was bypassed; a per-instance set is fine if that never happens.
    _stream_tasks: set[asyncio.Task] = set()

    def __init__(self, name: str):
        super().__init__(name)
        # NOTE(review): not read anywhere in this class.
        self.events = {}
        # Ensures the websocket route is registered only once.
        self.is_setup_websocket = False

    @staticmethod
    def serialize_result(result: Any):
        """Render protobuf messages as JSON text (keeping proto field names); return anything else unchanged."""
        if isinstance(result, Message):
            return MessageToJson(
                result,
                preserving_proto_field_name=True
            )
        return result

    def _to_payload(self, result: Any) -> Any:
        """Serialize ``result`` and, when that yields a JSON string, decode it back into Python data.

        As a result, protobuf messages reach the client as nested JSON objects
        rather than as an escaped string. A string that is not valid JSON is
        passed through unchanged.
        """
        serialized = self.serialize_result(result)
        if not isinstance(serialized, str):
            return serialized
        try:
            return json.loads(serialized)
        except json.JSONDecodeError:
            return serialized

    async def handle_stream_output(
        self,
        websocket: WebSocket,
        task_id: uuid.UUID,
    ):
        """Relay the queued results of run ``task_id`` to ``websocket`` until the run ends or fails.

        Intermediate items are sent as ``stream`` messages. The relay then stops
        after a ``stream_end`` or ``stream_error`` message. The per-task queue
        entries are released however the relay stops.
        """
        try:
            while True:
                item = await self.stream_queues[task_id].get()

                if item.is_error:
                    error_response = {
                        "type": "stream_error",
                        "task_id": str(task_id),
                        "detail": str(item.result)
                    }
                    await websocket.send_text(json.dumps(error_response))
                    break

                if item.is_end:
                    # Stream ended; include the final result only when there is one.
                    end_response = {
                        "type": "stream_end",
                        "task_id": str(task_id),
                    }
                    if item.result is not None:
                        end_response["data"] = self._to_payload(item.result)
                    await websocket.send_text(json.dumps(end_response))
                    break

                stream_response = {
                    "type": "stream",
                    "task_id": str(task_id),
                    "data": self._to_payload(item.result)
                }
                await websocket.send_text(json.dumps(stream_response))
        except Exception as e:
            logger.error(f"Error handling stream output for task {task_id}: {e}")
            error_response = {
                "type": "stream_error",
                "task_id": str(task_id),
                "detail": str(e)
            }
            try:
                await websocket.send_text(json.dumps(error_response))
            except Exception:
                # Best effort: if the socket itself is broken, there is nobody left to tell.
                pass
        finally:
            if task_id in self.stream_queues:
                del self.stream_queues[task_id]
            # This trigger never reads result_queues, but an entry is registered
            # per task (and may be filled by callbacks); drop it too so it cannot
            # accumulate for the lifetime of the trigger.
            if task_id in self.result_queues:
                del self.result_queues[task_id]

    async def handle_websocket_message(
        self,
        websocket: WebSocket,
        data_type: type,
        message_data: dict,
    ):
        """Convert one decoded client message and enqueue a workflow run for it.

        ``message_data`` is passed through unchanged when ``data_type`` is
        ``Any``. Otherwise it is converted with ``data_type(**message_data)``.
        If conversion fails, an ``{"error": ...}`` message is sent back and no
        run is started. On success, per-task queues are registered, a relay task
        is started, and the run is placed on ``trigger_queue`` for ``_run``.
        """
        if data_type is Any:
            converted_data = message_data
        else:
            try:
                converted_data = data_type(**message_data)
            except Exception as e:
                error_msg = {"error": f"Failed to convert data: {str(e)}"}
                await websocket.send_text(json.dumps(error_msg))
                return

        # Queues are registered only once the run will actually be started, so
        # a rejected message leaves no entries behind.
        task_id = uuid.uuid4()
        self.result_queues[task_id] = asyncio.Queue()
        self.stream_queues[task_id] = asyncio.Queue()

        # The relay task sends every message for this run, including stream_end.
        relay_task = asyncio.create_task(self.handle_stream_output(websocket, task_id))
        self._stream_tasks.add(relay_task)
        relay_task.add_done_callback(self._stream_tasks.discard)

        self.trigger_queue.put_nowait({
            "id": task_id,
            "data": converted_data,
        })

    def _setup_websocket(self, app: FastAPI, path: str, data_type: type) -> None:
        """Register a websocket endpoint on ``app`` at ``path`` that forwards each JSON message to handle_websocket_message."""

        async def websocket_handler(websocket: WebSocket):
            await websocket.accept()

            try:
                while True:
                    # Receive message from client
                    data = await websocket.receive_text()
                    try:
                        message = json.loads(data)

                        # Handle the message and trigger workflow
                        await self.handle_websocket_message(
                            websocket,
                            data_type,
                            message
                        )
                    except json.JSONDecodeError:
                        error_msg = {"error": "Invalid JSON format"}
                        await websocket.send_text(json.dumps(error_msg))
                    except Exception as e:
                        logger.error(f"Error handling websocket message: {e}")
                        error_msg = {"error": str(e)}
                        await websocket.send_text(json.dumps(error_msg))
            except WebSocketDisconnect:
                logger.info("WebSocket client disconnected")
            except Exception as e:
                logger.error(f"WebSocket connection error: {e}")

        app.websocket(path)(websocket_handler)

    async def _run(self, app: FastAPI, path: str, data_type: type) -> AsyncIterator[bool]:
        """Register the route once, then yield ``self.trigger(task_id)`` for every queued message, indefinitely."""
        if not self.is_setup_websocket:
            self._setup_websocket(app, path, data_type)
            self.is_setup_websocket = True

        while True:
            try:
                trigger = await self.trigger_queue.get()
                self.prepare_output_edges(self.get_output_port_by_name('data'), trigger['data'])
                yield self.trigger(trigger['id'])
            except Exception as e:
                # A single bad item must not terminate the trigger loop.
                logger.error(f"Error in WebSocketAPITrigger._run: {e}")
                continue

    async def _stop(self) -> None:
        """No-op: this trigger has no shutdown logic of its own."""
        pass
@@ -0,0 +1,227 @@
1
+ from __future__ import annotations
2
+ import traceback
3
+ import asyncio
4
+ import uuid
5
+ from typing import AsyncIterator, Awaitable, Callable, Any
6
+ from loguru import logger
7
+ from copy import deepcopy
8
+ from .node import Node
9
+ from .port import Port
10
+ from .trigger import Trigger
11
+ from .edge import Edge
12
+ from ..db.database import DatabaseManager
13
+ from ..utils.workflow_clone import workflow_clone
14
+ from .workflow_callback import WorkflowCallback, BuiltinWorkflowCallback, CallbackEvent
15
+ from .workflow_config import WorkflowConfig
16
+
17
class Workflow:
    """A graph of nodes driven by exactly one trigger node.

    The instance built from configuration acts as a template. ``run`` consumes
    the task ids produced by the trigger node. For each one, ``_run`` clones
    this workflow (``_clone``) and executes the clone via ``run_after_trigger``.
    ``run_semaphore`` caps concurrent clones at ``max_concurrent_runs``.
    """

    # Strong references to tasks started by ``trigger``. asyncio keeps only weak
    # references to tasks, so an unreferenced task may be garbage-collected
    # before it finishes. Class-level so it needs no per-instance state.
    _background_tasks: set[asyncio.Task] = set()

    # Callback events that are dispatched to WorkflowCallback hooks. Each
    # event's value equals the name of the hook method that handles it.
    _DISPATCHED_EVENTS = frozenset({
        "on_workflow_start",
        "on_workflow_end",
        "on_workflow_error",
        "on_node_start",
        "on_node_end",
        "on_node_stream_output",
    })

    def __init__(
        self,
        id: uuid.UUID,
        config: WorkflowConfig,
        nodes: list[Node],
        input_ports: list[Port],
        output_ports: list[Port],
        _handle_stream_output: Callable[[str, AsyncIterator[str]], Awaitable[None]] | None = None, # deprecated
        _handle_query_user: Callable[[str, str], Awaitable[str]] | None = None,
        database_manager: DatabaseManager | None = None,
        max_concurrent_runs: int = 10,
        callbacks: list[WorkflowCallback] | None = None,

        # for run
        task_id: uuid.UUID | None = None,
        real_trigger_node: Trigger | None = None,
    ) -> None:
        """Build a workflow.

        Args:
            id: Workflow identifier.
            config: Parsed configuration; supplies ``name``/``version``/``description``.
            nodes: Nodes of the graph (exactly one must be a Trigger for ``run``).
            input_ports: Workflow-level input ports; prepared values are filled in on each run.
            output_ports: Workflow-level output ports collected into the end-of-run result.
            _handle_stream_output: Deprecated; only stored.
            _handle_query_user: Stored as-is.
            database_manager: Stored as-is.
            max_concurrent_runs: Maximum number of concurrently executing clones.
            callbacks: Lifecycle hooks. ``None`` means none. A list passed in is
                used as-is (not copied).
            task_id: Id of the run this instance executes (set on clones).
            real_trigger_node: The trigger node that started this run (set on clones).
        """
        self.id = id
        self.config = config
        self.nodes = nodes
        self.ready_nodes: list[Node] = []
        self.input_ports = input_ports
        self.output_ports = output_ports
        self._handle_stream_output = _handle_stream_output
        self._handle_query_user = _handle_query_user
        self.after_trigger_workflow = None
        self.result_port = Port("result", Any)
        self.database_manager = database_manager
        self.max_concurrent_runs = max_concurrent_runs
        self.run_semaphore = asyncio.Semaphore(max_concurrent_runs)
        # A fresh list per instance. A shared mutable default would let
        # register_callback on one workflow leak into every other workflow built
        # without explicit callbacks.
        # NOTE(review): confirm workflow_clone passes ``callbacks=`` explicitly
        # so clones keep the template's callbacks.
        self.callbacks = [] if callbacks is None else callbacks
        self.task_id = task_id
        self.real_trigger_node = real_trigger_node
        self._validate()

    @property
    def name(self) -> str:
        return self.config.name

    @property
    def version(self) -> str:
        return self.config.version

    @property
    def description(self) -> str:
        return self.config.description

    def register_callback(self, callback: WorkflowCallback) -> None:
        self.callbacks.append(callback)

    def unregister_callback(self, callback: WorkflowCallback) -> None:
        self.callbacks.remove(callback)

    async def call_callbacks(self, callback_type: CallbackEvent, *args, **kwargs) -> None:
        """Await the hook for ``callback_type`` on every registered callback, in registration order.

        Events that have no hook method (e.g. ``ON_NODE_OUTPUT``) are ignored.
        """
        hook_name = callback_type.value
        if hook_name not in self._DISPATCHED_EVENTS:
            return
        # Iterate over a snapshot so a hook may (un)register callbacks safely.
        for callback in list(self.callbacks):
            await getattr(callback, hook_name)(*args, **kwargs)

    def add_nodes(self, nodes: list[Node]) -> None:
        for node in nodes:
            node.workflow = self
        self.nodes.extend(nodes)

    def remove_nodes(self, nodes: list[Node]) -> None:
        for node in nodes:
            self.nodes.remove(node)

    def load_config(self) -> None:
        ...

    def _validate(self) -> None:
        # DAG validation is not implemented yet.
        ...

    def get_input_port_by_name(self, name: str) -> Port | None:
        for input_port in self.input_ports:
            if input_port.name == name:
                return input_port
        return None

    def get_output_port_by_name(self, name: str) -> Port | None:
        for output_port in self.output_ports:
            if output_port.name == name:
                return output_port
        return None

    def get_trigger_node(self) -> Trigger:
        """Return the single Trigger node; raise ValueError if there are none or several."""
        trigger_nodes = [node for node in self.nodes if isinstance(node, Trigger)]
        if not trigger_nodes:
            raise ValueError("No trigger nodes found in workflow.")
        if len(trigger_nodes) > 1:
            raise ValueError("Multiple trigger nodes found in workflow.")
        return trigger_nodes[0]

    async def _run_node_with_callbacks(self, node: Node) -> None:
        """Run ``node`` between its start/end hooks.

        Async iterators are drained as stream output, and coroutines are awaited.
        """
        await self.call_callbacks(CallbackEvent.ON_NODE_START, node=node)

        try:
            result = node.run()
            if hasattr(result, '__anext__'):
                await self.handle_node_stream_output(node, result)
            elif asyncio.iscoroutine(result):
                await result
        finally:
            await self.call_callbacks(CallbackEvent.ON_NODE_END, node=node)

    def _collect_output(self) -> Any:
        """Build the end-of-run output from the prepared workflow output ports.

        Returns ``None`` when there are no output ports. With a single port, its
        value is returned (``None`` if unprepared). Otherwise a
        ``{port_name: value}`` dict of the prepared ports is returned.
        """
        if not self.output_ports:
            return None
        if len(self.output_ports) == 1:
            port = self.output_ports[0]
            return port.value if port.is_prepared else None
        return {port.name: port.value for port in self.output_ports if port.is_prepared}

    async def _notify_workflow_error(self, error: Exception) -> None:
        """Dispatch ON_WORKFLOW_ERROR without letting a failing hook mask ``error``."""
        try:
            await self.call_callbacks(CallbackEvent.ON_WORKFLOW_ERROR, workflow=self, error=error)
        except Exception as callback_error:
            logger.error(f"Error in on_workflow_error callbacks: {callback_error}")

    async def run_after_trigger(self) -> Any:
        """Execute the graph once, wave by wave, after the trigger has fired.

        Nodes that become ready are run concurrently, one wave at a time, until
        none remain. On failure, ON_WORKFLOW_ERROR is dispatched so that
        listeners waiting on this run are notified, and the exception is
        re-raised. On success, ON_WORKFLOW_END is dispatched with the collected
        output.
        """
        logger.info(f"Running workflow: {self.name}")

        await self.call_callbacks(CallbackEvent.ON_WORKFLOW_START, workflow=self)

        self.ready_nodes = []
        for edge in self.get_trigger_node().output_edges:
            edge.end_port.trigger()

        try:
            for input_port in self.input_ports:
                if input_port.value is not None:
                    input_port.port.node.fill_input(input_port.port, input_port.value)

            for node in self.nodes:
                # Ports fed by an edge are not auto-filled; build the set once per node.
                connected_ports = {edge.end_port.name for edge in node.input_edges}
                for key in node.AUTO_FILL_INPUT_PORTS:
                    if key[0] not in connected_ports:
                        node.fill_input_by_name(key[0], key[1])

            while self.ready_nodes:
                # Nodes that become ready during this wave go into the next one.
                nodes, self.ready_nodes = self.ready_nodes, []
                tasks = [asyncio.create_task(self._run_node_with_callbacks(node)) for node in nodes]
                await asyncio.gather(*tasks)

        except Exception as e:
            error_msg = f"Error in run_after_trigger: {str(e)}"
            logger.error(error_msg)
            await self._notify_workflow_error(e)
            raise

        await self.call_callbacks(CallbackEvent.ON_WORKFLOW_END, workflow=self, output=self._collect_output())

    async def _run(self, task_id: uuid.UUID, trigger_node: Trigger) -> None:
        """Run one cloned instance for ``task_id``, bounded by ``run_semaphore``; errors are logged, not raised."""
        async with self.run_semaphore:
            try:
                new_workflow = self._clone(task_id, trigger_node)
                await new_workflow.run_after_trigger()
                # TODO: clear new_workflow

            except Exception as e:
                error_msg = f"Error running workflow: {str(e)}, {traceback.format_exc()}"
                logger.error(error_msg)

    async def run(self):
        """Start a run for every task id yielded by the trigger node, then wait for outstanding runs."""
        trigger = self.get_trigger_node()
        # Only in-flight runs are tracked; finished tasks drop out, so a
        # long-lived trigger does not accumulate completed tasks.
        pending: set[asyncio.Task] = set()

        async for task_id in trigger.run():
            task = asyncio.create_task(self._run(task_id, trigger))
            pending.add(task)
            task.add_done_callback(pending.discard)

        if pending:
            await asyncio.gather(*pending)

    async def stop(self):
        trigger = self.get_trigger_node()
        await trigger._stop()

    def trigger(self, trigger_name: str, **kwargs) -> uuid.UUID:
        """Schedule one run programmatically and return its task id without waiting for it.

        Each keyword argument is handed to the trigger node's
        ``prepare_output_edges`` before the run is scheduled.
        NOTE(review): ``trigger_name`` is unused. Also, the key passed here is a
        port *name*, whereas WebSocketAPITrigger passes a Port object; confirm
        which form ``prepare_output_edges`` expects.
        """
        trigger = self.get_trigger_node()
        task_id = uuid.uuid4()
        for key, value in kwargs.items():
            trigger.prepare_output_edges(key, value)
        task = asyncio.create_task(self._run(task_id, trigger))
        # Hold a strong reference until the run completes.
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task_id

    async def handle_node_stream_output(
        self,
        node: Node,
        stream: AsyncIterator[Any],
    ) -> None:
        """Forward every item of ``stream`` to the ON_NODE_STREAM_OUTPUT hooks."""
        async for data in stream:
            await self.call_callbacks(CallbackEvent.ON_NODE_STREAM_OUTPUT, node=node, output=data)

    # TODO: refactor this
    async def handle_query_user(self, node_name: str, prompt: str) -> str:
        """Prompt on the console in a worker thread so the event loop is not blocked."""
        return await asyncio.to_thread(input, f"[{node_name}] {prompt}: ")

    def _clone(self, task_id: uuid.UUID, trigger_node: Trigger) -> Workflow:
        return workflow_clone(self, task_id, trigger_node)
@@ -0,0 +1,141 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ from abc import abstractmethod
5
+ from typing import TYPE_CHECKING, Any
6
+ from typing_extensions import override
7
+ from enum import Enum
8
+ from .workflow_event import WorkflowResult
9
+ from loguru import logger
10
+
11
+ if TYPE_CHECKING:
12
+ from .node import Node
13
+ from .workflow import Workflow
14
+
15
class CallbackEvent(Enum):
    """Lifecycle events a workflow can report to its callbacks.

    Each member's value is the snake_case name of the corresponding event.
    """

    # Workflow-level events.
    ON_WORKFLOW_START = "on_workflow_start"
    ON_WORKFLOW_END = "on_workflow_end"
    ON_WORKFLOW_ERROR = "on_workflow_error"

    # Node-level events.
    ON_NODE_START = "on_node_start"
    ON_NODE_END = "on_node_end"
    ON_NODE_OUTPUT = "on_node_output"
    ON_NODE_STREAM_OUTPUT = "on_node_stream_output"
23
+
24
class WorkflowCallback:
    """Base class for workflow lifecycle hooks.

    Subclasses override whichever hooks they care about. The hooks carry
    ``@abstractmethod``, but this class does not use ``ABCMeta``, so the
    decorator is not enforced. The class can be instantiated directly, and
    every default hook just returns ``None``.
    """

    @abstractmethod
    async def on_workflow_start(self, workflow: Workflow) -> None:
        """Hook for the start of a workflow run."""

    @abstractmethod
    async def on_workflow_end(self, workflow: Workflow, output: Any) -> None:
        """Hook for a completed run; ``output`` is the run's collected result."""

    @abstractmethod
    async def on_workflow_error(self, workflow: Workflow, error: Any) -> None:
        """Hook for a failed run; ``error`` describes the failure."""

    @abstractmethod
    async def on_node_start(self, node: Node) -> None:
        """Hook invoked before ``node`` runs."""

    @abstractmethod
    async def on_node_end(self, node: Node) -> None:
        """Hook invoked after ``node`` has run."""

    @abstractmethod
    async def on_node_stream_output(self, node: Node, output: Any) -> None:
        """Hook for one streamed item ``output`` produced by ``node``."""
48
+
49
class BuiltinWorkflowCallback(WorkflowCallback):
    """Default callback that relays workflow progress to the trigger and to websocket clients.

    Each event is wrapped in a ``WorkflowResult`` and pushed onto the per-task
    queues of ``workflow.real_trigger_node``. These are ``result_queues`` and/or
    ``stream_queues``, keyed by ``workflow.task_id``, and only queues registered
    for that task are written. The event is then mirrored as a message through
    the websocket manager's ``send_to_task``. Websocket failures are logged,
    never raised.
    """

    def __init__(self):
        self._websocket_manager = None

    def _get_websocket_manager(self):
        """Return the websocket manager, importing it on first use."""
        # Deferred import. Presumably this avoids a circular import between the
        # workflow and API packages; TODO confirm.
        if self._websocket_manager is None:
            from service_forge.api.routers.websocket.websocket_manager import websocket_manager
            self._websocket_manager = websocket_manager
        return self._websocket_manager

    def _serialize_result(self, result: Any) -> Any:
        """Return ``result`` unchanged if ``json.dumps`` accepts it, otherwise ``str(result)``."""
        try:
            json.dumps(result)
        except (TypeError, ValueError):
            return str(result)
        return result

    @staticmethod
    def _enqueue(workflow: Workflow, workflow_result: WorkflowResult, *, include_result_queue: bool) -> None:
        """Push ``workflow_result`` onto the trigger queues registered for ``workflow.task_id``.

        The result queue (if requested) is written before the stream queue. A
        workflow whose ``real_trigger_node`` is ``None`` has no queues, so
        nothing is enqueued. Callers still go on to notify websocket clients.
        """
        trigger = workflow.real_trigger_node
        if trigger is None:
            return
        task_id = workflow.task_id
        if include_result_queue and task_id in trigger.result_queues:
            trigger.result_queues[task_id].put_nowait(workflow_result)
        if task_id in trigger.stream_queues:
            trigger.stream_queues[task_id].put_nowait(workflow_result)

    @override
    async def on_workflow_start(self, workflow: Workflow) -> None:
        ...

    @override
    async def on_workflow_end(self, workflow: Workflow, output: Any) -> None:
        """Publish ``output`` as the end-of-run result to both trigger queues and websocket clients."""
        self._enqueue(
            workflow,
            WorkflowResult(result=output, is_end=True, is_error=False),
            include_result_queue=True,
        )

        try:
            manager = self._get_websocket_manager()
            message = {
                "type": "workflow_end",
                "task_id": str(workflow.task_id),
                "result": self._serialize_result(output),
                "is_end": True,
                "is_error": False
            }
            await manager.send_to_task(workflow.task_id, message)
        except Exception as e:
            # Log text: "failed to send workflow_end message to websocket".
            logger.error(f"发送 workflow_end 消息到 websocket 失败: {e}")

    @override
    async def on_workflow_error(self, workflow: Workflow, error: Any) -> None:
        """Publish ``error`` as an error result to both trigger queues and websocket clients."""
        self._enqueue(
            workflow,
            WorkflowResult(result=error, is_end=False, is_error=True),
            include_result_queue=True,
        )

        try:
            manager = self._get_websocket_manager()
            message = {
                "type": "workflow_error",
                "task_id": str(workflow.task_id),
                "error": self._serialize_result(error),
                "is_end": False,
                "is_error": True
            }
            await manager.send_to_task(workflow.task_id, message)
        except Exception as e:
            # Log text: "failed to send workflow_error message to websocket".
            logger.error(f"发送 workflow_error 消息到 websocket 失败: {e}")

    @override
    async def on_node_start(self, node: Node) -> None:
        ...

    @override
    async def on_node_end(self, node: Node) -> None:
        ...

    @override
    async def on_node_stream_output(self, node: Node, output: Any) -> None:
        """Publish one streamed item from ``node`` to the stream queue (not the result queue) and to websocket clients."""
        self._enqueue(
            node.workflow,
            WorkflowResult(result=output, is_end=False, is_error=False),
            include_result_queue=False,
        )

        try:
            manager = self._get_websocket_manager()
            message = {
                "type": "node_stream_output",
                "task_id": str(node.workflow.task_id),
                "node": node.name,
                "output": self._serialize_result(output),
                "is_end": False,
                "is_error": False
            }
            await manager.send_to_task(node.workflow.task_id, message)
        except Exception as e:
            # Log text: "failed to send node_stream_output message to websocket".
            logger.error(f"发送 node_stream_output 消息到 websocket 失败: {e}")
@@ -0,0 +1,66 @@
1
+ from __future__ import annotations
2
+
3
+ import yaml
4
+ from typing import Any
5
+ from pydantic import BaseModel
6
+
7
# Version string used when a workflow, sub-workflow reference, or group's main
# workflow omits an explicit "version" in its config.
DEFAULT_WORKFLOW_VERSION = "0"
8
+
9
# Pydantic model pairing an output name with a port name.
# NOTE(review): no model in this file references it (WorkflowNodeConfig.outputs
# is a plain dict); confirm whether it is still used.
class WorkflowNodeOutputConfig(BaseModel):
    name: str
    port: str
12
+
13
# Pydantic model for a single named argument with an optional value.
# NOTE(review): WorkflowNodeConfig.args is a plain dict, so this model is not
# referenced in this file; confirm where it is used.
class WorkflowNodeArgConfig(BaseModel):
    name: str
    value: Any | None = None
16
+
17
# One entry of WorkflowConfig.inputs: a workflow-level input name, the port it
# maps to (presumably "node.port"-style; confirm in the workflow factory), and
# an optional preset value.
class WorkflowInputPortConfig(BaseModel):
    name: str
    port: str
    value: Any | None = None
21
+
22
# One entry of WorkflowConfig.outputs: a workflow-level output name and the port
# it is read from (port string format not defined here; confirm in the factory).
class WorkflowOutputPortConfig(BaseModel):
    name: str
    port: str
25
+
26
# Reference from a node to another workflow by name and version
# (version defaults to DEFAULT_WORKFLOW_VERSION).
class WorkflowNodeSubWorkflowConfig(BaseModel):
    name: str
    version: str = DEFAULT_WORKFLOW_VERSION
29
+
30
# Input-port binding for a node's sub-workflows: a name, the target port, and an
# optional preset value. NOTE(review): exact binding semantics live in the
# workflow factory; confirm there.
class WorkflowNodeSubWorkflowInputPortConfig(BaseModel):
    name: str
    port: str
    value: Any | None = None
34
+
35
# Configuration of one node inside a workflow definition.
class WorkflowNodeConfig(BaseModel):
    name: str
    # Node type identifier; presumably resolved to a Node class by the workflow
    # factory — TODO confirm.
    type: str
    # Free-form mappings; their keys/values are not validated by this model.
    args: dict | None = None
    outputs: dict | None = None
    # Workflows referenced by this node (e.g. nested-workflow nodes — presumably;
    # confirm) and the input-port bindings passed to them.
    sub_workflows: list[WorkflowNodeSubWorkflowConfig] | None = None
    sub_workflows_input_ports: list[WorkflowNodeSubWorkflowInputPortConfig] | None = None
42
+
43
# Complete definition of a single workflow: identity, nodes, and optional
# workflow-level input/output port bindings.
class WorkflowConfig(BaseModel):
    name: str
    version: str = DEFAULT_WORKFLOW_VERSION
    description: str
    nodes: list[WorkflowNodeConfig]
    inputs: list[WorkflowInputPortConfig] | None = None
    outputs: list[WorkflowOutputPortConfig] | None = None

    @classmethod
    def from_yaml_file(cls, filepath: str) -> WorkflowConfig:
        """Load and validate a workflow definition from a UTF-8 YAML file.

        The file is parsed with ``yaml.safe_load``, so arbitrary Python objects
        cannot be constructed from it.

        Raises:
            OSError: The file cannot be opened.
            yaml.YAMLError: The file is not valid YAML.
            ValueError: The top-level YAML value is not a mapping (for example,
                an empty file). Pydantic's validation error is also a
                ValueError subclass.
        """
        with open(filepath, 'r', encoding='utf-8') as f:
            data = yaml.safe_load(f)
        # safe_load yields None for an empty file and may yield a list or scalar;
        # `cls(**data)` would then fail with an opaque TypeError.
        if not isinstance(data, dict):
            raise ValueError(
                f"Workflow config {filepath!r} must contain a YAML mapping at the top level, "
                f"got {type(data).__name__}"
            )
        return cls(**data)
56
+
57
# A set of workflow definitions plus the name/version of the "main" one
# (presumably the entry point; confirm with the loader).
# NOTE(review): `workflows` and `main_workflow_name` have no default. Under
# pydantic v2 they are required keys (null allowed); under v1, Optional fields
# default to None. Confirm which behavior is intended.
class WorkflowGroupConfig(BaseModel):
    workflows: list[WorkflowConfig] | None
    main_workflow_name: str | None
    main_workflow_version: str | None = DEFAULT_WORKFLOW_VERSION

    @classmethod
    def from_yaml_file(cls, filepath: str) -> WorkflowGroupConfig:
        """Load and validate a workflow group definition from a UTF-8 YAML file.

        The file is parsed with ``yaml.safe_load``, so arbitrary Python objects
        cannot be constructed from it.

        Raises:
            OSError: The file cannot be opened.
            yaml.YAMLError: The file is not valid YAML.
            ValueError: The top-level YAML value is not a mapping (for example,
                an empty file). Pydantic's validation error is also a
                ValueError subclass.
        """
        with open(filepath, 'r', encoding='utf-8') as f:
            data = yaml.safe_load(f)
        # safe_load yields None for an empty file and may yield a list or scalar;
        # `cls(**data)` would then fail with an opaque TypeError.
        if not isinstance(data, dict):
            raise ValueError(
                f"Workflow group config {filepath!r} must contain a YAML mapping at the top level, "
                f"got {type(data).__name__}"
            )
        return cls(**data)
@@ -0,0 +1,15 @@
1
+ from __future__ import annotations
2
+
3
+ from typing import Any
4
+
5
class WorkflowResult:
    """Envelope describing one piece of workflow progress.

    Attributes:
        result: The payload. For an end-of-run envelope (``is_end=True``) it is
            taken from the workflow's output port(s). For an error envelope it
            holds the error.
        is_end: True when this envelope marks the end of the run.
        is_error: True when this envelope reports a failure.
    """

    def __init__(
        self,
        result: Any | None,
        is_end: bool,
        is_error: bool,
    ) -> None:
        self.result, self.is_end, self.is_error = result, is_end, is_error