service-forge 0.1.11 (service_forge-0.1.11-py3-none-any.whl)
This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in the public registry.
Potentially problematic release: this version of service-forge might be problematic.
- service_forge/api/deprecated_websocket_api.py +86 -0
- service_forge/api/deprecated_websocket_manager.py +425 -0
- service_forge/api/http_api.py +148 -0
- service_forge/api/http_api_doc.py +455 -0
- service_forge/api/kafka_api.py +126 -0
- service_forge/api/routers/service/__init__.py +4 -0
- service_forge/api/routers/service/service_router.py +137 -0
- service_forge/api/routers/websocket/websocket_manager.py +83 -0
- service_forge/api/routers/websocket/websocket_router.py +78 -0
- service_forge/api/task_manager.py +141 -0
- service_forge/db/__init__.py +1 -0
- service_forge/db/database.py +240 -0
- service_forge/llm/__init__.py +62 -0
- service_forge/llm/llm.py +56 -0
- service_forge/model/__init__.py +0 -0
- service_forge/model/websocket.py +13 -0
- service_forge/proto/foo_input.py +5 -0
- service_forge/service.py +288 -0
- service_forge/service_config.py +158 -0
- service_forge/sft/cli.py +91 -0
- service_forge/sft/cmd/config_command.py +67 -0
- service_forge/sft/cmd/deploy_service.py +123 -0
- service_forge/sft/cmd/list_tars.py +41 -0
- service_forge/sft/cmd/service_command.py +149 -0
- service_forge/sft/cmd/upload_service.py +36 -0
- service_forge/sft/config/injector.py +119 -0
- service_forge/sft/config/injector_default_files.py +131 -0
- service_forge/sft/config/sf_metadata.py +30 -0
- service_forge/sft/config/sft_config.py +153 -0
- service_forge/sft/file/__init__.py +0 -0
- service_forge/sft/file/ignore_pattern.py +80 -0
- service_forge/sft/file/sft_file_manager.py +107 -0
- service_forge/sft/kubernetes/kubernetes_manager.py +257 -0
- service_forge/sft/util/assert_util.py +25 -0
- service_forge/sft/util/logger.py +16 -0
- service_forge/sft/util/name_util.py +8 -0
- service_forge/sft/util/yaml_utils.py +57 -0
- service_forge/utils/__init__.py +0 -0
- service_forge/utils/default_type_converter.py +12 -0
- service_forge/utils/register.py +39 -0
- service_forge/utils/type_converter.py +99 -0
- service_forge/utils/workflow_clone.py +124 -0
- service_forge/workflow/__init__.py +1 -0
- service_forge/workflow/context.py +14 -0
- service_forge/workflow/edge.py +24 -0
- service_forge/workflow/node.py +184 -0
- service_forge/workflow/nodes/__init__.py +8 -0
- service_forge/workflow/nodes/control/if_node.py +29 -0
- service_forge/workflow/nodes/control/switch_node.py +28 -0
- service_forge/workflow/nodes/input/console_input_node.py +26 -0
- service_forge/workflow/nodes/llm/query_llm_node.py +41 -0
- service_forge/workflow/nodes/nested/workflow_node.py +28 -0
- service_forge/workflow/nodes/output/kafka_output_node.py +27 -0
- service_forge/workflow/nodes/output/print_node.py +29 -0
- service_forge/workflow/nodes/test/if_console_input_node.py +33 -0
- service_forge/workflow/nodes/test/time_consuming_node.py +62 -0
- service_forge/workflow/port.py +89 -0
- service_forge/workflow/trigger.py +24 -0
- service_forge/workflow/triggers/__init__.py +6 -0
- service_forge/workflow/triggers/a2a_api_trigger.py +255 -0
- service_forge/workflow/triggers/fast_api_trigger.py +169 -0
- service_forge/workflow/triggers/kafka_api_trigger.py +44 -0
- service_forge/workflow/triggers/once_trigger.py +20 -0
- service_forge/workflow/triggers/period_trigger.py +26 -0
- service_forge/workflow/triggers/websocket_api_trigger.py +184 -0
- service_forge/workflow/workflow.py +210 -0
- service_forge/workflow/workflow_callback.py +141 -0
- service_forge/workflow/workflow_event.py +15 -0
- service_forge/workflow/workflow_factory.py +246 -0
- service_forge/workflow/workflow_group.py +27 -0
- service_forge/workflow/workflow_type.py +52 -0
- service_forge-0.1.11.dist-info/METADATA +98 -0
- service_forge-0.1.11.dist-info/RECORD +75 -0
- service_forge-0.1.11.dist-info/WHEEL +4 -0
- service_forge-0.1.11.dist-info/entry_points.txt +2 -0
service_forge/workflow/triggers/fast_api_trigger.py
@@ -0,0 +1,169 @@

from __future__ import annotations
import uuid
import asyncio
import json
from loguru import logger
from service_forge.workflow.trigger import Trigger
from typing import AsyncIterator, Awaitable, Callable, Any
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
from service_forge.workflow.port import Port
from service_forge.utils.default_type_converter import type_converter
from service_forge.api.routers.websocket.websocket_manager import websocket_manager
from fastapi import HTTPException
from google.protobuf.message import Message
from google.protobuf.json_format import MessageToJson

class FastAPITrigger(Trigger):
    DEFAULT_INPUT_PORTS = [
        Port("app", FastAPI),
        Port("path", str),
        Port("method", str),
        Port("data_type", type),
        Port("is_stream", bool),
    ]

    DEFAULT_OUTPUT_PORTS = [
        Port("trigger", bool),
        Port("user_id", int),
        Port("data", Any),
    ]

    def __init__(self, name: str):
        super().__init__(name)
        self.events = {}
        self.is_setup_route = False

    @staticmethod
    def serialize_result(result: Any):
        if isinstance(result, Message):
            return MessageToJson(
                result,
                preserving_proto_field_name=True
            )
        return result

    async def handle_request(
        self,
        request: Request,
        data_type: type,
        extract_data_fn: Callable[[Request], Awaitable[dict]],
        is_stream: bool,
    ):
        task_id = uuid.uuid4()
        self.result_queues[task_id] = asyncio.Queue()

        body_data = await extract_data_fn(request)
        converted_data = data_type(**body_data)

        client_id = (
            body_data.get("client_id")
            or request.query_params.get("client_id")
            or request.headers.get("X-Client-ID")
        )
        if client_id:
            workflow_name = getattr(self.workflow, "name", "Unknown")
            steps = len(self.workflow.nodes) if hasattr(self.workflow, "nodes") else 1
            websocket_manager.create_task_with_client(task_id, client_id, workflow_name, steps)

        self.trigger_queue.put_nowait({
            "id": task_id,
            "user_id": getattr(request.state, "user_id", None),
            "data": converted_data,
        })

        if is_stream:
            self.stream_queues[task_id] = asyncio.Queue()

            async def generate_sse():
                try:
                    while True:
                        item = await self.stream_queues[task_id].get()

                        if item.is_error:
                            yield f"event: error\ndata: {json.dumps({'detail': str(item.result)})}\n\n"
                            break

                        if item.is_end:
                            # TODO: send the result?
                            break

                        # TODO: modify
                        serialized = self.serialize_result(item.result)
                        if isinstance(serialized, str):
                            data = serialized
                        else:
                            data = json.dumps(serialized)

                        yield f"data: {data}\n\n"

                except Exception as e:
                    yield f"event: error\ndata: {json.dumps({'detail': str(e)})}\n\n"
                finally:
                    if task_id in self.stream_queues:
                        del self.stream_queues[task_id]
                    if task_id in self.result_queues:
                        del self.result_queues[task_id]

            return StreamingResponse(
                generate_sse(),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                    "X-Accel-Buffering": "no",
                }
            )
        else:
            result = await self.result_queues[task_id].get()
            del self.result_queues[task_id]

            if result.is_error:
                if isinstance(result.result, HTTPException):
                    raise result.result
                else:
                    raise HTTPException(status_code=500, detail=str(result.result))

            return self.serialize_result(result.result)

    def _setup_route(self, app: FastAPI, path: str, method: str, data_type: type, is_stream: bool) -> None:
        async def get_data(request: Request) -> dict:
            return dict(request.query_params)

        async def body_data(request: Request) -> dict:
            raw = await request.body()
            if not raw:
                return {}
            return json.loads(raw.decode("utf-8"))

        extractor = get_data if method == "GET" else body_data

        async def handler(request: Request):
            return await self.handle_request(request, data_type, extractor, is_stream)

        if method == "GET":
            app.get(path)(handler)
        elif method == "POST":
            app.post(path)(handler)
        elif method == "PUT":
            app.put(path)(handler)
        elif method == "DELETE":
            app.delete(path)(handler)
        else:
            raise ValueError(f"Invalid method {method}")

    async def _run(self, app: FastAPI, path: str, method: str, data_type: type, is_stream: bool = False) -> AsyncIterator[bool]:
        if not self.is_setup_route:
            self._setup_route(app, path, method, data_type, is_stream)
            self.is_setup_route = True

        while True:
            try:
                trigger = await self.trigger_queue.get()
                self.prepare_output_edges(self.get_output_port_by_name('user_id'), trigger['user_id'])
                self.prepare_output_edges(self.get_output_port_by_name('data'), trigger['data'])
                yield self.trigger(trigger['id'])
            except Exception as e:
                logger.error(f"Error in FastAPITrigger._run: {e}")
                continue
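When is_stream is set, the route above responds with Server-Sent Events: generate_sse() emits "data: ..." frames and an "event: error" frame on failure. A minimal client sketch under those assumptions; the URL and payload are hypothetical, and httpx is just one HTTP client that can stream:

import httpx

def consume_sse(url: str, payload: dict) -> None:
    # Stream the response and print each SSE frame as it arrives.
    with httpx.stream("POST", url, json=payload, timeout=None) as response:
        for line in response.iter_lines():
            if line.startswith("event: error"):
                print("server reported an error; the next data frame holds the detail")
            elif line.startswith("data: "):
                print("frame:", line[len("data: "):])

# Hypothetical route registered via FastAPITrigger with is_stream=True:
# consume_sse("http://localhost:8000/my-workflow", {"text": "hi"})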
service_forge/workflow/triggers/kafka_api_trigger.py
@@ -0,0 +1,44 @@

from __future__ import annotations
import uuid
from typing import Any, AsyncIterator
from service_forge.workflow.trigger import Trigger
from service_forge.workflow.port import Port
from service_forge.api.kafka_api import KafkaApp

class KafkaAPITrigger(Trigger):
    DEFAULT_INPUT_PORTS = [
        Port("app", KafkaApp),
        Port("topic", str),
        Port("data_type", type),
        Port("group_id", str),
    ]

    DEFAULT_OUTPUT_PORTS = [
        Port("trigger", bool),
        Port("data", Any),
    ]

    def __init__(self, name: str):
        super().__init__(name)
        self.events = {}
        self.is_setup_kafka_input = False

    def _setup_kafka_input(self, app: KafkaApp, topic: str, data_type: type, group_id: str) -> None:
        @app.kafka_input(topic, data_type, group_id)
        async def handle_message(data):
            task_id = uuid.uuid4()
            self.trigger_queue.put_nowait({
                "id": task_id,
                "data": data,
            })

    async def _run(self, app: KafkaApp, topic: str, data_type: type, group_id: str) -> AsyncIterator[bool]:
        if not self.is_setup_kafka_input:
            self._setup_kafka_input(app, topic, data_type, group_id)
            self.is_setup_kafka_input = True

        while True:
            trigger = await self.trigger_queue.get()
            self.prepare_output_edges(self.get_output_port_by_name('data'), trigger['data'])
            yield self.trigger(trigger['id'])
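KafkaAPITrigger fires once per message that reaches its registered handler. A producer-side sketch for feeding it, assuming kafka-python as the client (the diff does not show which library KafkaApp wraps); the broker address, topic name, and payload are hypothetical:

import json
from kafka import KafkaProducer

# Serialize dict payloads to JSON bytes, matching a JSON-decoding consumer.
producer = KafkaProducer(
    bootstrap_servers="localhost:9092",
    value_serializer=lambda v: json.dumps(v).encode("utf-8"),
)
producer.send("my-workflow-topic", value={"text": "hi"})
producer.flush()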
service_forge/workflow/triggers/once_trigger.py
@@ -0,0 +1,20 @@

from __future__ import annotations
from service_forge.workflow.node import Node
from service_forge.workflow.port import Port
from service_forge.workflow.trigger import Trigger
from typing import AsyncIterator
import uuid

class OnceTrigger(Trigger):
    DEFAULT_INPUT_PORTS = []

    DEFAULT_OUTPUT_PORTS = [
        Port("trigger", bool),
    ]

    def __init__(self, name: str):
        super().__init__(name)

    async def _run(self) -> AsyncIterator[bool]:
        yield self.trigger(uuid.uuid4())
service_forge/workflow/triggers/period_trigger.py
@@ -0,0 +1,26 @@

from __future__ import annotations
import asyncio
from service_forge.workflow.node import Node
from service_forge.workflow.port import Port
from service_forge.workflow.trigger import Trigger
from typing import AsyncIterator
import uuid

class PeriodTrigger(Trigger):
    DEFAULT_INPUT_PORTS = [
        Port("TRIGGER", bool),
        Port("period", float),
        Port("times", int),
    ]

    DEFAULT_OUTPUT_PORTS = [
        Port("trigger", bool),
    ]

    def __init__(self, name: str):
        super().__init__(name)

    async def _run(self, times: int, period: float) -> AsyncIterator[bool]:
        for _ in range(times):
            await asyncio.sleep(period)
            # Register the run via self.trigger(), as the other triggers do.
            yield self.trigger(uuid.uuid4())
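OnceTrigger and PeriodTrigger illustrate the trigger contract: subclass Trigger, declare the port lists, and implement _run as an async generator that yields self.trigger(task_id) once per firing. A sketch of a custom trigger following that pattern; CountdownTrigger is hypothetical and not part of the package:

from __future__ import annotations
import asyncio
import uuid
from typing import AsyncIterator
from service_forge.workflow.port import Port
from service_forge.workflow.trigger import Trigger

class CountdownTrigger(Trigger):
    DEFAULT_INPUT_PORTS = [
        Port("delay", float),
    ]

    DEFAULT_OUTPUT_PORTS = [
        Port("trigger", bool),
    ]

    async def _run(self, delay: float) -> AsyncIterator[bool]:
        # Fire exactly once after the configured delay.
        await asyncio.sleep(delay)
        yield self.trigger(uuid.uuid4())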
service_forge/workflow/triggers/websocket_api_trigger.py
@@ -0,0 +1,184 @@

from __future__ import annotations
import uuid
import asyncio
import json
from loguru import logger
from service_forge.workflow.trigger import Trigger
from typing import AsyncIterator, Any
from fastapi import FastAPI, WebSocket, WebSocketDisconnect
from service_forge.workflow.port import Port
from google.protobuf.message import Message
from google.protobuf.json_format import MessageToJson

class WebSocketAPITrigger(Trigger):
    DEFAULT_INPUT_PORTS = [
        Port("app", FastAPI),
        Port("path", str),
        Port("data_type", type),
    ]

    DEFAULT_OUTPUT_PORTS = [
        Port("trigger", bool),
        Port("data", Any),
    ]

    def __init__(self, name: str):
        super().__init__(name)
        self.events = {}
        self.is_setup_websocket = False

    @staticmethod
    def serialize_result(result: Any):
        if isinstance(result, Message):
            return MessageToJson(
                result,
                preserving_proto_field_name=True
            )
        return result

    async def handle_stream_output(
        self,
        websocket: WebSocket,
        task_id: uuid.UUID,
    ):
        try:
            while True:
                item = await self.stream_queues[task_id].get()

                if item.is_error:
                    error_response = {
                        "type": "stream_error",
                        "task_id": str(task_id),
                        "detail": str(item.result)
                    }
                    await websocket.send_text(json.dumps(error_response))
                    break

                if item.is_end:
                    # Stream ended, send final result if available
                    if item.result is not None:
                        serialized = self.serialize_result(item.result)
                        if isinstance(serialized, str):
                            try:
                                data = json.loads(serialized)
                            except json.JSONDecodeError:
                                data = serialized
                        else:
                            data = serialized

                        end_response = {
                            "type": "stream_end",
                            "task_id": str(task_id),
                            "data": data
                        }
                    else:
                        end_response = {
                            "type": "stream_end",
                            "task_id": str(task_id)
                        }
                    await websocket.send_text(json.dumps(end_response))
                    break

                # Send stream data
                serialized = self.serialize_result(item.result)
                if isinstance(serialized, str):
                    try:
                        data = json.loads(serialized)
                    except json.JSONDecodeError:
                        data = serialized
                else:
                    data = serialized

                stream_response = {
                    "type": "stream",
                    "task_id": str(task_id),
                    "data": data
                }
                await websocket.send_text(json.dumps(stream_response))
        except Exception as e:
            logger.error(f"Error handling stream output for task {task_id}: {e}")
            error_response = {
                "type": "stream_error",
                "task_id": str(task_id),
                "detail": str(e)
            }
            try:
                await websocket.send_text(json.dumps(error_response))
            except Exception:
                pass
        finally:
            if task_id in self.stream_queues:
                del self.stream_queues[task_id]

    async def handle_websocket_message(
        self,
        websocket: WebSocket,
        data_type: type,
        message_data: dict,
    ):
        task_id = uuid.uuid4()
        self.result_queues[task_id] = asyncio.Queue()
        self.stream_queues[task_id] = asyncio.Queue()

        try:
            converted_data = data_type(**message_data)
        except Exception as e:
            error_msg = {"error": f"Failed to convert data: {str(e)}"}
            await websocket.send_text(json.dumps(error_msg))
            return

        # Always start a background task to handle stream output
        asyncio.create_task(self.handle_stream_output(websocket, task_id))

        self.trigger_queue.put_nowait({
            "id": task_id,
            "data": converted_data,
        })

        # The stream handler will send all messages, including stream_end when the workflow completes

    def _setup_websocket(self, app: FastAPI, path: str, data_type: type) -> None:
        async def websocket_handler(websocket: WebSocket):
            await websocket.accept()

            try:
                while True:
                    # Receive message from client
                    data = await websocket.receive_text()
                    try:
                        message = json.loads(data)

                        # Handle the message and trigger workflow
                        await self.handle_websocket_message(
                            websocket,
                            data_type,
                            message
                        )
                    except json.JSONDecodeError:
                        error_msg = {"error": "Invalid JSON format"}
                        await websocket.send_text(json.dumps(error_msg))
                    except Exception as e:
                        logger.error(f"Error handling websocket message: {e}")
                        error_msg = {"error": str(e)}
                        await websocket.send_text(json.dumps(error_msg))
            except WebSocketDisconnect:
                logger.info("WebSocket client disconnected")
            except Exception as e:
                logger.error(f"WebSocket connection error: {e}")

        app.websocket(path)(websocket_handler)

    async def _run(self, app: FastAPI, path: str, data_type: type) -> AsyncIterator[bool]:
        if not self.is_setup_websocket:
            self._setup_websocket(app, path, data_type)
            self.is_setup_websocket = True

        while True:
            try:
                trigger = await self.trigger_queue.get()
                self.prepare_output_edges(self.get_output_port_by_name('data'), trigger['data'])
                yield self.trigger(trigger['id'])
            except Exception as e:
                logger.error(f"Error in WebSocketAPITrigger._run: {e}")
                continue
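Clients of this WebSocket route send one JSON message per task and then read frames whose envelope carries a "type" of stream, stream_end, or stream_error, as built in handle_stream_output(). A minimal client sketch under those assumptions; the URL and payload are hypothetical, and the third-party websockets package is just one possible client:

import asyncio
import json
import websockets

async def consume(url: str, payload: dict) -> None:
    async with websockets.connect(url) as ws:
        await ws.send(json.dumps(payload))
        while True:
            message = json.loads(await ws.recv())
            kind = message.get("type")
            if kind == "stream":
                print("chunk:", message["data"])
            elif kind in ("stream_end", "stream_error"):
                print("final:", message)
                break
            else:
                # e.g. an {"error": ...} reply for malformed input
                print("other:", message)

# asyncio.run(consume("ws://localhost:8000/my-workflow", {"text": "hi"}))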
service_forge/workflow/workflow.py
@@ -0,0 +1,210 @@

from __future__ import annotations
import traceback
import asyncio
import uuid
from typing import AsyncIterator, Awaitable, Callable, Any
from loguru import logger
from copy import deepcopy
from .node import Node
from .port import Port
from .trigger import Trigger
from .edge import Edge
from ..db.database import DatabaseManager
from ..utils.workflow_clone import workflow_clone
from .workflow_callback import WorkflowCallback, BuiltinWorkflowCallback, CallbackEvent

class Workflow:
    def __init__(
        self,
        name: str,
        description: str,
        nodes: list[Node],
        input_ports: list[Port],
        output_ports: list[Port],
        _handle_stream_output: Callable[[str, AsyncIterator[str]], Awaitable[None]] = None,  # deprecated
        _handle_query_user: Callable[[str, str], Awaitable[str]] = None,
        database_manager: DatabaseManager = None,
        max_concurrent_runs: int = 10,
        callbacks: list[WorkflowCallback] = None,

        # for run
        task_id: uuid.UUID = None,
        real_trigger_node: Trigger = None,
    ) -> None:
        self.name = name
        self.description = description
        self.nodes = nodes
        self.ready_nodes: list[Node] = []
        self.input_ports = input_ports
        self.output_ports = output_ports
        self._handle_stream_output = _handle_stream_output
        self._handle_query_user = _handle_query_user
        self.after_trigger_workflow = None
        self.result_port = Port("result", Any)
        self.database_manager = database_manager
        self.max_concurrent_runs = max_concurrent_runs
        self.run_semaphore = asyncio.Semaphore(max_concurrent_runs)
        # Avoid a shared mutable default list across Workflow instances.
        self.callbacks = callbacks if callbacks is not None else []
        self.task_id = task_id
        self.real_trigger_node = real_trigger_node
        self._validate()

    def register_callback(self, callback: WorkflowCallback) -> None:
        self.callbacks.append(callback)

    def unregister_callback(self, callback: WorkflowCallback) -> None:
        self.callbacks.remove(callback)

    async def call_callbacks(self, callback_type: CallbackEvent, *args, **kwargs) -> None:
        for callback in self.callbacks:
            if callback_type == CallbackEvent.ON_WORKFLOW_START:
                await callback.on_workflow_start(*args, **kwargs)
            elif callback_type == CallbackEvent.ON_WORKFLOW_END:
                await callback.on_workflow_end(*args, **kwargs)
            elif callback_type == CallbackEvent.ON_NODE_START:
                await callback.on_node_start(*args, **kwargs)
            elif callback_type == CallbackEvent.ON_NODE_END:
                await callback.on_node_end(*args, **kwargs)
            elif callback_type == CallbackEvent.ON_NODE_STREAM_OUTPUT:
                await callback.on_node_stream_output(*args, **kwargs)

    def add_nodes(self, nodes: list[Node]) -> None:
        for node in nodes:
            node.workflow = self
        self.nodes.extend(nodes)

    def remove_nodes(self, nodes: list[Node]) -> None:
        for node in nodes:
            self.nodes.remove(node)

    def load_config(self) -> None:
        ...

    def _validate(self) -> None:
        # DAG
        ...

    def get_input_port_by_name(self, name: str) -> Port:
        for input_port in self.input_ports:
            if input_port.name == name:
                return input_port
        return None

    def get_output_port_by_name(self, name: str) -> Port:
        for output_port in self.output_ports:
            if output_port.name == name:
                return output_port
        return None

    def get_trigger_node(self) -> Trigger:
        trigger_nodes = [node for node in self.nodes if isinstance(node, Trigger)]
        if not trigger_nodes:
            raise ValueError("No trigger nodes found in workflow.")
        if len(trigger_nodes) > 1:
            raise ValueError("Multiple trigger nodes found in workflow.")
        return trigger_nodes[0]

    async def _run_node_with_callbacks(self, node: Node) -> None:
        await self.call_callbacks(CallbackEvent.ON_NODE_START, node=node)

        try:
            result = node.run()
            if hasattr(result, '__anext__'):
                await self.handle_node_stream_output(node, result)
            elif asyncio.iscoroutine(result):
                await result
        finally:
            await self.call_callbacks(CallbackEvent.ON_NODE_END, node=node)

    async def run_after_trigger(self) -> Any:
        logger.info(f"Running workflow: {self.name}")

        await self.call_callbacks(CallbackEvent.ON_WORKFLOW_START, workflow=self)

        self.ready_nodes = []
        for edge in self.get_trigger_node().output_edges:
            edge.end_port.trigger()

        try:
            for input_port in self.input_ports:
                if input_port.value is not None:
                    input_port.port.node.fill_input(input_port.port, input_port.value)

            for node in self.nodes:
                for key in node.AUTO_FILL_INPUT_PORTS:
                    if key[0] not in [edge.end_port.name for edge in node.input_edges]:
                        node.fill_input_by_name(key[0], key[1])

            while self.ready_nodes:
                nodes = self.ready_nodes.copy()
                self.ready_nodes = []

                tasks = []
                for node in nodes:
                    tasks.append(asyncio.create_task(self._run_node_with_callbacks(node)))

                await asyncio.gather(*tasks)

        except Exception as e:
            error_msg = f"Error in run_after_trigger: {str(e)}"
            logger.error(error_msg)
            raise e

        if len(self.output_ports) > 0:
            if len(self.output_ports) == 1:
                if self.output_ports[0].is_prepared:
                    result = self.output_ports[0].value
                else:
                    result = None
            else:
                result = {}
                for port in self.output_ports:
                    if port.is_prepared:
                        result[port.name] = port.value
            await self.call_callbacks(CallbackEvent.ON_WORKFLOW_END, workflow=self, output=result)
        else:
            await self.call_callbacks(CallbackEvent.ON_WORKFLOW_END, workflow=self, output=None)

    async def _run(self, task_id: uuid.UUID, trigger_node: Trigger) -> None:
        async with self.run_semaphore:
            try:
                new_workflow = self._clone(task_id, trigger_node)
                await new_workflow.run_after_trigger()
                # TODO: clear new_workflow

            except Exception as e:
                error_msg = f"Error running workflow: {str(e)}, {traceback.format_exc()}"
                logger.error(error_msg)

    async def run(self):
        tasks = []
        trigger = self.get_trigger_node()

        async for task_id in trigger.run():
            tasks.append(asyncio.create_task(self._run(task_id, trigger)))

        if tasks:
            await asyncio.gather(*tasks)

    def trigger(self, trigger_name: str, **kwargs) -> uuid.UUID:
        trigger = self.get_trigger_node()
        task_id = uuid.uuid4()
        for key, value in kwargs.items():
            trigger.prepare_output_edges(key, value)
        task = asyncio.create_task(self._run(task_id, trigger))
        return task_id

    async def handle_node_stream_output(
        self,
        node: Node,
        stream: AsyncIterator[Any],
    ) -> None:
        async for data in stream:
            await self.call_callbacks(CallbackEvent.ON_NODE_STREAM_OUTPUT, node=node, output=data)

    # TODO: refactor this
    async def handle_query_user(self, node_name: str, prompt: str) -> str:
        return await asyncio.to_thread(input, f"[{node_name}] {prompt}: ")

    def _clone(self, task_id: uuid.UUID, trigger_node: Trigger) -> Workflow:
        return workflow_clone(self, task_id, trigger_node)
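A hedged end-to-end sketch of driving a Workflow from a trigger, using only the constructor signature and run() loop shown above. The node list holds just the trigger and no ports are wired, so whether a bare trigger fires without further setup depends on the Trigger base class, which is outside this diff:

import asyncio
from service_forge.workflow.workflow import Workflow
from service_forge.workflow.triggers.once_trigger import OnceTrigger

async def main() -> None:
    trigger = OnceTrigger("start")
    workflow = Workflow(
        name="demo",
        description="fires once and runs no nodes",
        nodes=[trigger],  # run() locates the single Trigger among the nodes
        input_ports=[],
        output_ports=[],
    )
    # run() iterates the trigger's async generator and clones the workflow
    # once per yielded task id, bounded by max_concurrent_runs.
    await workflow.run()

asyncio.run(main())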