service-forge 0.1.18__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of service-forge has been flagged as potentially problematic; see the advisory on the registry page for details.
- service_forge/api/deprecated_websocket_api.py +86 -0
- service_forge/api/deprecated_websocket_manager.py +425 -0
- service_forge/api/http_api.py +152 -0
- service_forge/api/http_api_doc.py +455 -0
- service_forge/api/kafka_api.py +126 -0
- service_forge/api/routers/feedback/feedback_router.py +148 -0
- service_forge/api/routers/service/service_router.py +127 -0
- service_forge/api/routers/websocket/websocket_manager.py +83 -0
- service_forge/api/routers/websocket/websocket_router.py +78 -0
- service_forge/api/task_manager.py +141 -0
- service_forge/current_service.py +14 -0
- service_forge/db/__init__.py +1 -0
- service_forge/db/database.py +237 -0
- service_forge/db/migrations/feedback_migration.py +154 -0
- service_forge/db/models/__init__.py +0 -0
- service_forge/db/models/feedback.py +33 -0
- service_forge/llm/__init__.py +67 -0
- service_forge/llm/llm.py +56 -0
- service_forge/model/__init__.py +0 -0
- service_forge/model/feedback.py +30 -0
- service_forge/model/websocket.py +13 -0
- service_forge/proto/foo_input.py +5 -0
- service_forge/service.py +280 -0
- service_forge/service_config.py +44 -0
- service_forge/sft/cli.py +91 -0
- service_forge/sft/cmd/config_command.py +67 -0
- service_forge/sft/cmd/deploy_service.py +123 -0
- service_forge/sft/cmd/list_tars.py +41 -0
- service_forge/sft/cmd/service_command.py +149 -0
- service_forge/sft/cmd/upload_service.py +36 -0
- service_forge/sft/config/injector.py +129 -0
- service_forge/sft/config/injector_default_files.py +131 -0
- service_forge/sft/config/sf_metadata.py +30 -0
- service_forge/sft/config/sft_config.py +200 -0
- service_forge/sft/file/__init__.py +0 -0
- service_forge/sft/file/ignore_pattern.py +80 -0
- service_forge/sft/file/sft_file_manager.py +107 -0
- service_forge/sft/kubernetes/kubernetes_manager.py +257 -0
- service_forge/sft/util/assert_util.py +25 -0
- service_forge/sft/util/logger.py +16 -0
- service_forge/sft/util/name_util.py +8 -0
- service_forge/sft/util/yaml_utils.py +57 -0
- service_forge/storage/__init__.py +5 -0
- service_forge/storage/feedback_storage.py +245 -0
- service_forge/utils/__init__.py +0 -0
- service_forge/utils/default_type_converter.py +12 -0
- service_forge/utils/register.py +39 -0
- service_forge/utils/type_converter.py +99 -0
- service_forge/utils/workflow_clone.py +124 -0
- service_forge/workflow/__init__.py +1 -0
- service_forge/workflow/context.py +14 -0
- service_forge/workflow/edge.py +24 -0
- service_forge/workflow/node.py +184 -0
- service_forge/workflow/nodes/__init__.py +8 -0
- service_forge/workflow/nodes/control/if_node.py +29 -0
- service_forge/workflow/nodes/control/switch_node.py +28 -0
- service_forge/workflow/nodes/input/console_input_node.py +26 -0
- service_forge/workflow/nodes/llm/query_llm_node.py +41 -0
- service_forge/workflow/nodes/nested/workflow_node.py +28 -0
- service_forge/workflow/nodes/output/kafka_output_node.py +27 -0
- service_forge/workflow/nodes/output/print_node.py +29 -0
- service_forge/workflow/nodes/test/if_console_input_node.py +33 -0
- service_forge/workflow/nodes/test/time_consuming_node.py +62 -0
- service_forge/workflow/port.py +89 -0
- service_forge/workflow/trigger.py +28 -0
- service_forge/workflow/triggers/__init__.py +6 -0
- service_forge/workflow/triggers/a2a_api_trigger.py +257 -0
- service_forge/workflow/triggers/fast_api_trigger.py +201 -0
- service_forge/workflow/triggers/kafka_api_trigger.py +47 -0
- service_forge/workflow/triggers/once_trigger.py +23 -0
- service_forge/workflow/triggers/period_trigger.py +29 -0
- service_forge/workflow/triggers/websocket_api_trigger.py +189 -0
- service_forge/workflow/workflow.py +227 -0
- service_forge/workflow/workflow_callback.py +141 -0
- service_forge/workflow/workflow_config.py +66 -0
- service_forge/workflow/workflow_event.py +15 -0
- service_forge/workflow/workflow_factory.py +246 -0
- service_forge/workflow/workflow_group.py +51 -0
- service_forge/workflow/workflow_type.py +52 -0
- service_forge-0.1.18.dist-info/METADATA +98 -0
- service_forge-0.1.18.dist-info/RECORD +83 -0
- service_forge-0.1.18.dist-info/WHEEL +4 -0
- service_forge-0.1.18.dist-info/entry_points.txt +2 -0
|
@@ -0,0 +1,189 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
import uuid
|
|
3
|
+
import asyncio
|
|
4
|
+
import json
|
|
5
|
+
from loguru import logger
|
|
6
|
+
from service_forge.workflow.trigger import Trigger
|
|
7
|
+
from typing import AsyncIterator, Any
|
|
8
|
+
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
|
|
9
|
+
from service_forge.workflow.port import Port
|
|
10
|
+
from google.protobuf.message import Message
|
|
11
|
+
from google.protobuf.json_format import MessageToJson
|
|
12
|
+
|
|
13
|
+
class WebSocketAPITrigger(Trigger):
    """Trigger that exposes a workflow over a FastAPI WebSocket endpoint.

    Each incoming JSON message becomes one trigger event. Results streamed
    by the workflow are forwarded back to the client as JSON frames of type
    ``stream`` / ``stream_end`` / ``stream_error``.

    NOTE(review): ``trigger_queue``, ``result_queues``, ``stream_queues``,
    ``prepare_output_edges`` and ``trigger`` are assumed to come from the
    ``Trigger`` base class -- confirm against its definition.
    """

    DEFAULT_INPUT_PORTS = [
        Port("app", FastAPI),
        Port("path", str),
        Port("data_type", type),
    ]

    DEFAULT_OUTPUT_PORTS = [
        Port("trigger", bool),
        Port("data", Any),
    ]

    def __init__(self, name: str):
        super().__init__(name)
        # NOTE(review): `events` is never read in this class -- possibly dead
        # state; confirm against the Trigger base class before removing.
        self.events = {}
        self.is_setup_websocket = False
        # Fix: strong references to the per-task stream-forwarder tasks;
        # asyncio only keeps weak references to running tasks, so a
        # fire-and-forget create_task() result can be garbage collected.
        self._stream_tasks: set = set()

    @staticmethod
    def serialize_result(result: Any):
        """Serialize protobuf messages to JSON text; pass anything else through."""
        if isinstance(result, Message):
            return MessageToJson(
                result,
                preserving_proto_field_name=True
            )
        return result

    @classmethod
    def _decode_result(cls, result: Any) -> Any:
        """Serialize `result` and decode JSON strings into Python objects.

        Protobuf messages become dicts/lists via their JSON form; plain
        strings that parse as JSON are decoded; everything else is returned
        unchanged. (Extracted: this logic was duplicated twice in
        handle_stream_output.)
        """
        serialized = cls.serialize_result(result)
        if isinstance(serialized, str):
            try:
                return json.loads(serialized)
            except json.JSONDecodeError:
                return serialized
        return serialized

    async def handle_stream_output(
        self,
        websocket: WebSocket,
        task_id: uuid.UUID,
    ):
        """Forward queued stream items for `task_id` to the client.

        Runs until an end or error item arrives, then removes the stream
        queue for the task.
        """
        try:
            while True:
                item = await self.stream_queues[task_id].get()

                if item.is_error:
                    error_response = {
                        "type": "stream_error",
                        "task_id": str(task_id),
                        "detail": str(item.result)
                    }
                    await websocket.send_text(json.dumps(error_response))
                    break

                if item.is_end:
                    # Stream ended, send final result if available
                    if item.result is not None:
                        end_response = {
                            "type": "stream_end",
                            "task_id": str(task_id),
                            "data": self._decode_result(item.result)
                        }
                    else:
                        end_response = {
                            "type": "stream_end",
                            "task_id": str(task_id)
                        }
                    await websocket.send_text(json.dumps(end_response))
                    break

                # Send stream data
                stream_response = {
                    "type": "stream",
                    "task_id": str(task_id),
                    "data": self._decode_result(item.result)
                }
                await websocket.send_text(json.dumps(stream_response))
        except Exception as e:
            logger.error(f"Error handling stream output for task {task_id}: {e}")
            error_response = {
                "type": "stream_error",
                "task_id": str(task_id),
                "detail": str(e)
            }
            try:
                await websocket.send_text(json.dumps(error_response))
            except Exception:
                # Best effort: the socket may already be closed.
                pass
        finally:
            # Drop the stream queue. NOTE(review): the matching entry in
            # result_queues is left behind here -- confirm it is cleaned up
            # elsewhere, otherwise it leaks one queue per task.
            if task_id in self.stream_queues:
                del self.stream_queues[task_id]

    async def handle_websocket_message(
        self,
        websocket: WebSocket,
        data_type: type,
        message_data: dict,
    ):
        """Convert one client message into a trigger event.

        Converts the payload to `data_type` (unless `data_type` is Any),
        allocates the per-task result/stream queues, starts the stream
        forwarder, and enqueues the trigger event.
        """
        task_id = uuid.uuid4()

        if data_type is Any:
            converted_data = message_data
        else:
            try:
                converted_data = data_type(**message_data)
            except Exception as e:
                # Fix: bail out BEFORE allocating per-task queues. The
                # original created both queues first, leaking an entry in
                # result_queues/stream_queues for every malformed payload.
                error_msg = {"error": f"Failed to convert data: {str(e)}"}
                await websocket.send_text(json.dumps(error_msg))
                return

        self.result_queues[task_id] = asyncio.Queue()
        self.stream_queues[task_id] = asyncio.Queue()

        # Always start background task to handle stream output.
        # Fix: keep a strong reference so the task is not garbage collected.
        task = asyncio.create_task(self.handle_stream_output(websocket, task_id))
        self._stream_tasks.add(task)
        task.add_done_callback(self._stream_tasks.discard)

        self.trigger_queue.put_nowait({
            "id": task_id,
            "data": converted_data,
        })

        # The stream handler will send all messages including stream_end when workflow completes

    def _setup_websocket(self, app: FastAPI, path: str, data_type: type) -> None:
        """Register the websocket endpoint at `path` on `app` (called once)."""
        async def websocket_handler(websocket: WebSocket):
            await websocket.accept()

            try:
                while True:
                    # Receive message from client
                    data = await websocket.receive_text()
                    try:
                        message = json.loads(data)

                        # Handle the message and trigger workflow
                        await self.handle_websocket_message(
                            websocket,
                            data_type,
                            message
                        )
                    except json.JSONDecodeError:
                        error_msg = {"error": "Invalid JSON format"}
                        await websocket.send_text(json.dumps(error_msg))
                    except Exception as e:
                        logger.error(f"Error handling websocket message: {e}")
                        error_msg = {"error": str(e)}
                        await websocket.send_text(json.dumps(error_msg))
            except WebSocketDisconnect:
                logger.info("WebSocket client disconnected")
            except Exception as e:
                logger.error(f"WebSocket connection error: {e}")

        app.websocket(path)(websocket_handler)

    async def _run(self, app: FastAPI, path: str, data_type: type) -> AsyncIterator[bool]:
        """Main trigger loop.

        Registers the endpoint on first call, then yields one trigger per
        queued client message; per-event errors are logged and skipped.
        """
        if not self.is_setup_websocket:
            self._setup_websocket(app, path, data_type)
            self.is_setup_websocket = True

        while True:
            try:
                trigger = await self.trigger_queue.get()
                self.prepare_output_edges(self.get_output_port_by_name('data'), trigger['data'])
                yield self.trigger(trigger['id'])
            except Exception as e:
                logger.error(f"Error in WebSocketAPITrigger._run: {e}")
                continue

    async def _stop(self) -> None:
        # Fix: annotated -> None. The original said AsyncIterator[bool], but
        # this is a plain coroutine (no yield) awaited by Workflow.stop().
        pass
|
@@ -0,0 +1,227 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
import traceback
|
|
3
|
+
import asyncio
|
|
4
|
+
import uuid
|
|
5
|
+
from typing import AsyncIterator, Awaitable, Callable, Any
|
|
6
|
+
from loguru import logger
|
|
7
|
+
from copy import deepcopy
|
|
8
|
+
from .node import Node
|
|
9
|
+
from .port import Port
|
|
10
|
+
from .trigger import Trigger
|
|
11
|
+
from .edge import Edge
|
|
12
|
+
from ..db.database import DatabaseManager
|
|
13
|
+
from ..utils.workflow_clone import workflow_clone
|
|
14
|
+
from .workflow_callback import WorkflowCallback, BuiltinWorkflowCallback, CallbackEvent
|
|
15
|
+
from .workflow_config import WorkflowConfig
|
|
16
|
+
|
|
17
|
+
class Workflow:
    """An executable DAG of nodes fired by a single Trigger node.

    `run()` consumes task ids emitted by the trigger; each task acquires
    the concurrency semaphore, clones this workflow (so concurrent runs do
    not share node state) and executes the clone via `run_after_trigger()`.
    """

    def __init__(
        self,
        id: uuid.UUID,
        config: WorkflowConfig,
        nodes: list[Node],
        input_ports: list[Port],
        output_ports: list[Port],
        _handle_stream_output: Callable[[str, AsyncIterator[str]], Awaitable[None]] = None,  # deprecated
        _handle_query_user: Callable[[str, str], Awaitable[str]] = None,
        database_manager: DatabaseManager = None,
        max_concurrent_runs: int = 10,
        callbacks: list[WorkflowCallback] | None = None,

        # for run
        task_id: uuid.UUID = None,
        real_trigger_node: Trigger = None,
    ) -> None:
        """Build a workflow.

        Args:
            id: Unique id of this workflow instance.
            config: Parsed workflow definition (name/version/description).
            nodes: All nodes, including exactly one Trigger.
            input_ports: Workflow-level input bindings.
            output_ports: Workflow-level output bindings.
            database_manager: Optional shared DB access.
            max_concurrent_runs: Max task clones executing at once.
            callbacks: Lifecycle observers (defaults to a fresh empty list).
            task_id: Set on per-run clones only.
            real_trigger_node: Original trigger (set on clones) whose queues
                receive results.
        """
        self.id = id
        self.config = config
        self.nodes = nodes
        self.ready_nodes: list[Node] = []
        self.input_ports = input_ports
        self.output_ports = output_ports
        self._handle_stream_output = _handle_stream_output
        self._handle_query_user = _handle_query_user
        self.after_trigger_workflow = None
        self.result_port = Port("result", Any)
        self.database_manager = database_manager
        self.max_concurrent_runs = max_concurrent_runs
        # Caps how many task clones may execute concurrently.
        self.run_semaphore = asyncio.Semaphore(max_concurrent_runs)
        # Fix: the original signature used a mutable default (callbacks=[]),
        # so every Workflow built without explicit callbacks shared ONE list
        # and register_callback() on one instance affected all of them.
        self.callbacks = callbacks if callbacks is not None else []
        self.task_id = task_id
        self.real_trigger_node = real_trigger_node
        # Fix: strong references to fire-and-forget tasks started by
        # trigger(); the event loop only keeps weak refs to running tasks.
        self._trigger_tasks: set = set()
        self._validate()

    @property
    def name(self) -> str:
        return self.config.name

    @property
    def version(self) -> str:
        return self.config.version

    @property
    def description(self) -> str:
        return self.config.description

    def register_callback(self, callback: WorkflowCallback) -> None:
        """Attach a lifecycle observer."""
        self.callbacks.append(callback)

    def unregister_callback(self, callback: WorkflowCallback) -> None:
        """Detach a previously registered observer (ValueError if absent)."""
        self.callbacks.remove(callback)

    async def call_callbacks(self, callback_type: CallbackEvent, *args, **kwargs) -> None:
        """Dispatch `callback_type` to the matching hook of every callback, in order."""
        for callback in self.callbacks:
            if callback_type == CallbackEvent.ON_WORKFLOW_START:
                await callback.on_workflow_start(*args, **kwargs)
            elif callback_type == CallbackEvent.ON_WORKFLOW_END:
                await callback.on_workflow_end(*args, **kwargs)
            elif callback_type == CallbackEvent.ON_WORKFLOW_ERROR:
                # Fix: this branch was missing -- ON_WORKFLOW_ERROR events
                # were silently dropped even though the hook exists on
                # WorkflowCallback and BuiltinWorkflowCallback implements it.
                await callback.on_workflow_error(*args, **kwargs)
            elif callback_type == CallbackEvent.ON_NODE_START:
                await callback.on_node_start(*args, **kwargs)
            elif callback_type == CallbackEvent.ON_NODE_END:
                await callback.on_node_end(*args, **kwargs)
            elif callback_type == CallbackEvent.ON_NODE_STREAM_OUTPUT:
                await callback.on_node_stream_output(*args, **kwargs)

    def add_nodes(self, nodes: list[Node]) -> None:
        """Add nodes and bind them back to this workflow."""
        for node in nodes:
            node.workflow = self
        self.nodes.extend(nodes)

    def remove_nodes(self, nodes: list[Node]) -> None:
        """Remove the given nodes (ValueError if one is absent)."""
        for node in nodes:
            self.nodes.remove(node)

    def load_config(self) -> None:
        # Not implemented yet.
        ...

    def _validate(self) -> None:
        # TODO: verify the node graph is a DAG.
        ...

    def get_input_port_by_name(self, name: str) -> Port:
        """Return the workflow input port called `name`, or None."""
        for input_port in self.input_ports:
            if input_port.name == name:
                return input_port
        return None

    def get_output_port_by_name(self, name: str) -> Port:
        """Return the workflow output port called `name`, or None."""
        for output_port in self.output_ports:
            if output_port.name == name:
                return output_port
        return None

    def get_trigger_node(self) -> Trigger:
        """Return the single Trigger node; raise ValueError on zero or many."""
        trigger_nodes = [node for node in self.nodes if isinstance(node, Trigger)]
        if not trigger_nodes:
            raise ValueError("No trigger nodes found in workflow.")
        if len(trigger_nodes) > 1:
            raise ValueError("Multiple trigger nodes found in workflow.")
        return trigger_nodes[0]

    async def _run_node_with_callbacks(self, node: Node) -> None:
        """Run one node, firing ON_NODE_START/ON_NODE_END around it.

        ON_NODE_END fires even when the node raises (finally).
        """
        await self.call_callbacks(CallbackEvent.ON_NODE_START, node=node)

        try:
            result = node.run()
            if hasattr(result, '__anext__'):
                # Async-generator node: drain it and fan out stream events.
                await self.handle_node_stream_output(node, result)
            elif asyncio.iscoroutine(result):
                await result
        finally:
            await self.call_callbacks(CallbackEvent.ON_NODE_END, node=node)

    async def run_after_trigger(self) -> Any:
        """Execute the already-triggered DAG to completion.

        Fires ON_WORKFLOW_START, feeds workflow inputs and auto-fill values
        into nodes, runs ready nodes wave by wave, then fires
        ON_WORKFLOW_END with the collected output-port value(s).

        Returns:
            The single output port's value, a dict of prepared output-port
            values, or None. (Fix: the original computed this but never
            returned it despite the `-> Any` annotation.)
        """
        logger.info(f"Running workflow: {self.name}")

        await self.call_callbacks(CallbackEvent.ON_WORKFLOW_START, workflow=self)

        self.ready_nodes = []
        for edge in self.get_trigger_node().output_edges:
            edge.end_port.trigger()

        try:
            # Feed explicit workflow-level input values into their nodes.
            for input_port in self.input_ports:
                if input_port.value is not None:
                    input_port.port.node.fill_input(input_port.port, input_port.value)

            # Apply per-node auto-fill defaults for ports with no incoming edge.
            for node in self.nodes:
                for key in node.AUTO_FILL_INPUT_PORTS:
                    if key[0] not in [edge.end_port.name for edge in node.input_edges]:
                        node.fill_input_by_name(key[0], key[1])

            # Execute wave by wave: each completed wave may ready new nodes.
            while self.ready_nodes:
                nodes = self.ready_nodes.copy()
                self.ready_nodes = []

                tasks = []
                for node in nodes:
                    tasks.append(asyncio.create_task(self._run_node_with_callbacks(node)))

                await asyncio.gather(*tasks)

        except Exception as e:
            error_msg = f"Error in run_after_trigger: {str(e)}"
            logger.error(error_msg)
            raise e

        if len(self.output_ports) > 0:
            if len(self.output_ports) == 1:
                if self.output_ports[0].is_prepared:
                    result = self.output_ports[0].value
                else:
                    result = None
            else:
                result = {}
                for port in self.output_ports:
                    if port.is_prepared:
                        result[port.name] = port.value
            await self.call_callbacks(CallbackEvent.ON_WORKFLOW_END, workflow=self, output=result)
            return result

        await self.call_callbacks(CallbackEvent.ON_WORKFLOW_END, workflow=self, output=None)
        return None

    async def _run(self, task_id: uuid.UUID, trigger_node: Trigger) -> None:
        """Run one task: clone this workflow and execute the clone, bounded
        by the concurrency semaphore. Errors are logged, not propagated."""
        async with self.run_semaphore:
            try:
                new_workflow = self._clone(task_id, trigger_node)
                await new_workflow.run_after_trigger()
                # TODO: clear new_workflow

            except Exception as e:
                error_msg = f"Error running workflow: {str(e)}, {traceback.format_exc()}"
                logger.error(error_msg)

    async def run(self):
        """Main loop: start the trigger and launch one task per trigger event."""
        tasks = []
        trigger = self.get_trigger_node()

        async for task_id in trigger.run():
            tasks.append(asyncio.create_task(self._run(task_id, trigger)))

        if tasks:
            await asyncio.gather(*tasks)

    async def stop(self):
        """Ask the trigger node to stop producing events."""
        trigger = self.get_trigger_node()
        await trigger._stop()

    def trigger(self, trigger_name: str, **kwargs) -> uuid.UUID:
        """Fire the workflow once programmatically; returns the new task id.

        `trigger_name` is currently unused but kept for interface
        compatibility. The run happens in a background task.
        """
        trigger = self.get_trigger_node()
        task_id = uuid.uuid4()
        for key, value in kwargs.items():
            # NOTE(review): other call sites pass a Port object as the first
            # argument of prepare_output_edges; here a string key is passed --
            # confirm the Trigger API also accepts a port name.
            trigger.prepare_output_edges(key, value)
        task = asyncio.create_task(self._run(task_id, trigger))
        # Fix: keep a strong reference so the task cannot be garbage
        # collected before it finishes (the original discarded it).
        self._trigger_tasks.add(task)
        task.add_done_callback(self._trigger_tasks.discard)
        return task_id

    async def handle_node_stream_output(
        self,
        node: Node,
        stream: AsyncIterator[Any],
    ) -> None:
        """Drain a node's async-generator output, emitting one
        ON_NODE_STREAM_OUTPUT callback per item."""
        async for data in stream:
            await self.call_callbacks(CallbackEvent.ON_NODE_STREAM_OUTPUT, node=node, output=data)

    # TODO: refactor this
    async def handle_query_user(self, node_name: str, prompt: str) -> str:
        """Prompt the operator on stdin without blocking the event loop.

        Fix: annotated -> str; the original said Awaitable[str], but the
        value is already awaited here.
        """
        return await asyncio.to_thread(input, f"[{node_name}] {prompt}: ")

    def _clone(self, task_id: uuid.UUID, trigger_node: Trigger) -> Workflow:
        """Copy this workflow for one task run (see workflow_clone)."""
        return workflow_clone(self, task_id, trigger_node)
|
@@ -0,0 +1,141 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
from abc import abstractmethod
|
|
5
|
+
from typing import TYPE_CHECKING, Any
|
|
6
|
+
from typing_extensions import override
|
|
7
|
+
from enum import Enum
|
|
8
|
+
from .workflow_event import WorkflowResult
|
|
9
|
+
from loguru import logger
|
|
10
|
+
|
|
11
|
+
if TYPE_CHECKING:
|
|
12
|
+
from .node import Node
|
|
13
|
+
from .workflow import Workflow
|
|
14
|
+
|
|
15
|
+
class CallbackEvent(Enum):
    """Identifies which WorkflowCallback hook a dispatched event maps to
    (see Workflow.call_callbacks)."""
    ON_WORKFLOW_START = "on_workflow_start"
    ON_WORKFLOW_END = "on_workflow_end"
    # NOTE(review): Workflow.call_callbacks has no dispatch branch for
    # ON_WORKFLOW_ERROR or ON_NODE_OUTPUT in this file -- confirm whether
    # they are dispatched elsewhere or are currently inert.
    ON_WORKFLOW_ERROR = "on_workflow_error"
    ON_NODE_START = "on_node_start"
    ON_NODE_END = "on_node_end"
    ON_NODE_OUTPUT = "on_node_output"
    ON_NODE_STREAM_OUTPUT = "on_node_stream_output"
|
+
|
|
24
|
+
class WorkflowCallback:
    """Observer interface for workflow lifecycle events.

    Methods are marked @abstractmethod for documentation, but the class
    deliberately does not inherit abc.ABC, so it stays instantiable and
    subclasses may override only the events they care about.
    (Fix: stub bodies were an inconsistent mix of `...` and `pass`;
    normalized to `...`.)
    """

    @abstractmethod
    async def on_workflow_start(self, workflow: Workflow) -> None:
        """Called once before any node of `workflow` runs."""
        ...

    @abstractmethod
    async def on_workflow_end(self, workflow: Workflow, output: Any) -> None:
        """Called after the workflow finishes; `output` is the collected
        output-port value(s), or None."""
        ...

    @abstractmethod
    async def on_workflow_error(self, workflow: Workflow, error: Any) -> None:
        """Called when the workflow fails; `error` is the failure payload."""
        ...

    @abstractmethod
    async def on_node_start(self, node: Node) -> None:
        """Called immediately before `node` executes."""
        ...

    @abstractmethod
    async def on_node_end(self, node: Node) -> None:
        """Called after `node` executes (fires even on failure, via the
        caller's finally block)."""
        ...

    @abstractmethod
    async def on_node_stream_output(self, node: Node, output: Any) -> None:
        """Called once per item yielded by a streaming node run()."""
        ...
|
+
|
|
49
|
+
class BuiltinWorkflowCallback(WorkflowCallback):
    """Default callback: relays workflow events to the trigger's per-task
    queues and (best-effort) to connected websocket clients."""

    def __init__(self):
        # Lazily resolved to avoid a circular import with the API layer.
        self._websocket_manager = None

    def _get_websocket_manager(self):
        """Import and cache the global websocket manager on first use."""
        if self._websocket_manager is None:
            from service_forge.api.routers.websocket.websocket_manager import websocket_manager
            self._websocket_manager = websocket_manager
        return self._websocket_manager

    def _serialize_result(self, result: Any) -> Any:
        """Return `result` unchanged if JSON-serializable, else its str()."""
        try:
            json.dumps(result)
            return result
        except (TypeError, ValueError):
            return str(result)

    def _push_to_trigger_queues(
        self,
        workflow: Workflow,
        workflow_result: Any,
        include_result_queue: bool,
    ) -> None:
        """Push `workflow_result` onto the trigger's per-task queues.

        Extracted from three near-identical copies. Fix: guards against
        `real_trigger_node` being None -- Workflow.__init__ defaults it to
        None, so the original attribute access could raise AttributeError
        on workflows not cloned for a run.
        """
        trigger = workflow.real_trigger_node
        if trigger is None:
            return
        if include_result_queue and workflow.task_id in trigger.result_queues:
            trigger.result_queues[workflow.task_id].put_nowait(workflow_result)
        if workflow.task_id in trigger.stream_queues:
            trigger.stream_queues[workflow.task_id].put_nowait(workflow_result)

    async def _send_to_websocket(self, task_id: Any, message: dict, label: str) -> None:
        """Best-effort delivery of `message` to websocket clients of `task_id`.

        `label` is the message type, interpolated into the (unchanged)
        error-log text on failure.
        """
        try:
            manager = self._get_websocket_manager()
            await manager.send_to_task(task_id, message)
        except Exception as e:
            logger.error(f"发送 {label} 消息到 websocket 失败: {e}")

    @override
    async def on_workflow_start(self, workflow: Workflow) -> None:
        ...

    @override
    async def on_workflow_end(self, workflow: Workflow, output: Any) -> None:
        workflow_result = WorkflowResult(result=output, is_end=True, is_error=False)
        self._push_to_trigger_queues(workflow, workflow_result, include_result_queue=True)

        message = {
            "type": "workflow_end",
            "task_id": str(workflow.task_id),
            "result": self._serialize_result(output),
            "is_end": True,
            "is_error": False
        }
        await self._send_to_websocket(workflow.task_id, message, "workflow_end")

    @override
    async def on_workflow_error(self, workflow: Workflow, error: Any) -> None:
        workflow_result = WorkflowResult(result=error, is_end=False, is_error=True)
        self._push_to_trigger_queues(workflow, workflow_result, include_result_queue=True)

        message = {
            "type": "workflow_error",
            "task_id": str(workflow.task_id),
            "error": self._serialize_result(error),
            "is_end": False,
            "is_error": True
        }
        await self._send_to_websocket(workflow.task_id, message, "workflow_error")

    @override
    async def on_node_start(self, node: Node) -> None:
        ...

    @override
    async def on_node_end(self, node: Node) -> None:
        ...

    @override
    async def on_node_stream_output(self, node: Node, output: Any) -> None:
        workflow_result = WorkflowResult(result=output, is_end=False, is_error=False)
        # Stream items go to the stream queue only, never the result queue.
        self._push_to_trigger_queues(node.workflow, workflow_result, include_result_queue=False)

        message = {
            "type": "node_stream_output",
            "task_id": str(node.workflow.task_id),
            "node": node.name,
            "output": self._serialize_result(output),
            "is_end": False,
            "is_error": False
        }
        await self._send_to_websocket(node.workflow.task_id, message, "node_stream_output")
|
@@ -0,0 +1,66 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import yaml
|
|
4
|
+
from typing import Any
|
|
5
|
+
from pydantic import BaseModel
|
|
6
|
+
|
|
7
|
+
# Version string assigned to workflows / sub-workflows that do not declare one.
DEFAULT_WORKFLOW_VERSION = "0"
|
|
8
|
+
|
|
9
|
+
class WorkflowNodeOutputConfig(BaseModel):
    """Named binding of a node output to a port (port reference format
    is defined by the workflow factory -- TODO confirm)."""
    name: str
    port: str

class WorkflowNodeArgConfig(BaseModel):
    """A named constant argument for a node; `value` may be omitted."""
    name: str
    value: Any | None = None

class WorkflowInputPortConfig(BaseModel):
    """Workflow-level input: exposed `name`, target node `port`, and an
    optional constant `value`."""
    name: str
    port: str
    value: Any | None = None

class WorkflowOutputPortConfig(BaseModel):
    """Workflow-level output: exposed `name` bound to a node `port`."""
    name: str
    port: str
|
|
25
|
+
|
|
26
|
+
class WorkflowNodeSubWorkflowConfig(BaseModel):
    """Reference to a sub-workflow by name and (optional) version."""
    name: str
    version: str = DEFAULT_WORKFLOW_VERSION

class WorkflowNodeSubWorkflowInputPortConfig(BaseModel):
    """Binding of a sub-workflow input port, with an optional constant value."""
    name: str
    port: str
    value: Any | None = None

class WorkflowNodeConfig(BaseModel):
    """One node entry in a workflow definition."""
    # Node instance name (presumably unique within the workflow -- TODO confirm).
    name: str
    # Node type identifier resolved by the workflow factory.
    type: str
    args: dict | None = None
    outputs: dict | None = None
    sub_workflows: list[WorkflowNodeSubWorkflowConfig] | None = None
    sub_workflows_input_ports: list[WorkflowNodeSubWorkflowInputPortConfig] | None = None
|
|
42
|
+
|
|
43
|
+
class WorkflowConfig(BaseModel):
    """Declarative definition of one workflow, typically loaded from YAML."""
    name: str
    version: str = DEFAULT_WORKFLOW_VERSION
    description: str
    nodes: list[WorkflowNodeConfig]
    inputs: list[WorkflowInputPortConfig] | None = None
    outputs: list[WorkflowOutputPortConfig] | None = None

    @classmethod
    def from_yaml_file(cls, filepath: str) -> WorkflowConfig:
        """Load and validate a WorkflowConfig from the YAML file at `filepath`.

        Raises pydantic's ValidationError on schema mismatch and OSError /
        yaml.YAMLError on file or parse problems.
        """
        with open(filepath, 'r', encoding='utf-8') as f:
            data = yaml.safe_load(f)
        return cls(**data)
|
|
56
|
+
|
|
57
|
+
class WorkflowGroupConfig(BaseModel):
    """A bundle of workflow definitions plus the entry-point selection.

    `workflows` and `main_workflow_name` are required keys but may be
    explicitly null (typed `| None` with no default).
    """
    workflows: list[WorkflowConfig] | None
    main_workflow_name: str | None
    main_workflow_version: str | None = DEFAULT_WORKFLOW_VERSION

    @classmethod
    def from_yaml_file(cls, filepath: str) -> WorkflowGroupConfig:
        """Load and validate a WorkflowGroupConfig from a YAML file."""
        with open(filepath, 'r', encoding='utf-8') as f:
            data = yaml.safe_load(f)
        return cls(**data)
|
|
@@ -0,0 +1,15 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
class WorkflowResult:
    """A single item emitted by a workflow run.

    At most one terminal flag should be set: `is_end` marks the final
    result (taken from the workflow's output ports), `is_error` marks a
    failure with `result` holding the error payload. Items with neither
    flag set are intermediate stream outputs.
    """

    def __init__(
        self,
        result: Any | None,
        is_end: bool,
        is_error: bool,
    ) -> None:
        # when is_end is True, result is from the output port of the workflow
        self.result = result
        self.is_end = is_end
        self.is_error = is_error

    def __repr__(self) -> str:
        # Added for debuggability (queue contents show up readably in logs);
        # purely additive, no caller-visible behavior change.
        return (
            f"{type(self).__name__}(result={self.result!r}, "
            f"is_end={self.is_end!r}, is_error={self.is_error!r})"
        )