service_forge-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- service_forge/api/http_api.py +138 -0
- service_forge/api/kafka_api.py +126 -0
- service_forge/api/task_manager.py +141 -0
- service_forge/api/websocket_api.py +86 -0
- service_forge/api/websocket_manager.py +425 -0
- service_forge/db/__init__.py +1 -0
- service_forge/db/database.py +119 -0
- service_forge/llm/__init__.py +62 -0
- service_forge/llm/llm.py +56 -0
- service_forge/main.py +121 -0
- service_forge/model/__init__.py +0 -0
- service_forge/model/websocket.py +13 -0
- service_forge/proto/foo_input.py +5 -0
- service_forge/service.py +111 -0
- service_forge/service_config.py +115 -0
- service_forge/sft/cli.py +91 -0
- service_forge/sft/cmd/config_command.py +67 -0
- service_forge/sft/cmd/deploy_service.py +124 -0
- service_forge/sft/cmd/list_tars.py +41 -0
- service_forge/sft/cmd/service_command.py +149 -0
- service_forge/sft/cmd/upload_service.py +36 -0
- service_forge/sft/config/injector.py +87 -0
- service_forge/sft/config/injector_default_files.py +97 -0
- service_forge/sft/config/sf_metadata.py +30 -0
- service_forge/sft/config/sft_config.py +125 -0
- service_forge/sft/file/__init__.py +0 -0
- service_forge/sft/file/ignore_pattern.py +80 -0
- service_forge/sft/file/sft_file_manager.py +107 -0
- service_forge/sft/kubernetes/kubernetes_manager.py +257 -0
- service_forge/sft/util/assert_util.py +25 -0
- service_forge/sft/util/logger.py +16 -0
- service_forge/sft/util/name_util.py +2 -0
- service_forge/utils/__init__.py +0 -0
- service_forge/utils/default_type_converter.py +12 -0
- service_forge/utils/register.py +39 -0
- service_forge/utils/type_converter.py +74 -0
- service_forge/workflow/__init__.py +1 -0
- service_forge/workflow/context.py +13 -0
- service_forge/workflow/edge.py +31 -0
- service_forge/workflow/node.py +179 -0
- service_forge/workflow/nodes/__init__.py +7 -0
- service_forge/workflow/nodes/control/if_node.py +29 -0
- service_forge/workflow/nodes/input/console_input_node.py +26 -0
- service_forge/workflow/nodes/llm/query_llm_node.py +41 -0
- service_forge/workflow/nodes/nested/workflow_node.py +28 -0
- service_forge/workflow/nodes/output/kafka_output_node.py +27 -0
- service_forge/workflow/nodes/output/print_node.py +29 -0
- service_forge/workflow/nodes/test/if_console_input_node.py +33 -0
- service_forge/workflow/nodes/test/time_consuming_node.py +61 -0
- service_forge/workflow/port.py +86 -0
- service_forge/workflow/trigger.py +20 -0
- service_forge/workflow/triggers/__init__.py +4 -0
- service_forge/workflow/triggers/fast_api_trigger.py +125 -0
- service_forge/workflow/triggers/kafka_api_trigger.py +44 -0
- service_forge/workflow/triggers/once_trigger.py +20 -0
- service_forge/workflow/triggers/period_trigger.py +26 -0
- service_forge/workflow/workflow.py +251 -0
- service_forge/workflow/workflow_factory.py +227 -0
- service_forge/workflow/workflow_group.py +23 -0
- service_forge/workflow/workflow_type.py +52 -0
- service_forge-0.1.0.dist-info/METADATA +93 -0
- service_forge-0.1.0.dist-info/RECORD +64 -0
- service_forge-0.1.0.dist-info/WHEEL +4 -0
- service_forge-0.1.0.dist-info/entry_points.txt +2 -0
service_forge/workflow/triggers/fast_api_trigger.py
@@ -0,0 +1,125 @@
```python
from __future__ import annotations
from service_forge.workflow.trigger import Trigger
from typing import AsyncIterator, Any
from fastapi import FastAPI, Request
import uuid
import asyncio
import json
from service_forge.workflow.port import Port
from service_forge.utils.default_type_converter import type_converter
from service_forge.api.websocket_manager import websocket_manager
from fastapi import HTTPException
from google.protobuf.message import Message
from google.protobuf.json_format import MessageToJson


class FastAPITrigger(Trigger):
    DEFAULT_INPUT_PORTS = [
        Port("app", FastAPI),
        Port("path", str),
        Port("method", str),
        Port("data_type", type),
    ]

    DEFAULT_OUTPUT_PORTS = [
        Port("trigger", bool),
        Port("user_id", int),
        Port("data", Any),
    ]

    def __init__(self, name: str):
        super().__init__(name)
        self.events = {}
        self.is_setup_route = False

    @staticmethod
    def serialize_result(result: Any):
        if isinstance(result, Message):
            return MessageToJson(
                result,
                preserving_proto_field_name=True
            )
        try:
            return json.dumps(result)
        except Exception:
            return result

    async def handle_request(
        self,
        request: Request,
        data_type: type,
        extract_data_fn: callable[[Request], dict]
    ):
        task_id = uuid.uuid4()
        self.result_queues[task_id] = asyncio.Queue()

        body_data = await extract_data_fn(request)
        converted_data = data_type(**body_data)

        client_id = (
            body_data.get("client_id")
            or request.query_params.get("client_id")
            or request.headers.get("X-Client-ID")
        )
        if client_id:
            workflow_name = getattr(self.workflow, "name", "Unknown")
            steps = len(self.workflow.nodes) if hasattr(self.workflow, "nodes") else 1
            websocket_manager.create_task_with_client(task_id, client_id, workflow_name, steps)

        self.trigger_queue.put_nowait({
            "id": task_id,
            "user_id": getattr(request.state, "user_id", None),
            "data": converted_data,
        })

        result = await self.result_queues[task_id].get()
        del self.result_queues[task_id]

        if isinstance(result, HTTPException):
            raise result
        if isinstance(result, Exception):
            raise HTTPException(status_code=500, detail=str(result))

        return self.serialize_result(result)

    def _setup_route(self, app: FastAPI, path: str, method: str, data_type: type) -> None:
        async def get_data(request: Request) -> dict:
            return dict(request.query_params)

        async def body_data(request: Request) -> dict:
            raw = await request.body()
            if not raw:
                return {}
            return json.loads(raw.decode("utf-8"))

        extractor = get_data if method == "GET" else body_data

        async def handler(request: Request):
            return await self.handle_request(request, data_type, extractor)

        if method == "GET":
            app.get(path)(handler)
        elif method == "POST":
            app.post(path)(handler)
        elif method == "PUT":
            app.put(path)(handler)
        elif method == "DELETE":
            app.delete(path)(handler)
        else:
            raise ValueError(f"Invalid method {method}")

    async def _run(self, app: FastAPI, path: str, method: str, data_type: type) -> AsyncIterator[bool]:
        if not self.is_setup_route:
            self._setup_route(app, path, method, data_type)
            self.is_setup_route = True

        while True:
            try:
                trigger = await self.trigger_queue.get()
                self.prepare_output_edges(self.get_output_port_by_name('user_id'), trigger['user_id'])
                self.prepare_output_edges(self.get_output_port_by_name('data'), trigger['data'])
                yield self.trigger(trigger['id'])
            except Exception as e:
                from loguru import logger
                logger.error(f"Error in FastAPITrigger._run: {e}")
                continue
```
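For orientation, a client-side sketch of how a request flows through this trigger: `handle_request` blocks on a per-task result queue, so the HTTP call only returns once the workflow run has finished, and the optional client id used for WebSocket progress updates is taken from the JSON body, the `client_id` query parameter, or the `X-Client-ID` header, in that order. The endpoint path and payload below are hypothetical, not part of the package.

```python
import requests

# Hypothetical endpoint registered by FastAPITrigger via _setup_route(app, "/foo", "POST", FooInput).
resp = requests.post(
    "http://localhost:8000/foo",
    json={"text": "hello"},                  # fields must match the configured data_type
    headers={"X-Client-ID": "demo-client"},  # third fallback checked for client_id
    timeout=60,                              # the call blocks until the workflow run completes
)
print(resp.status_code, resp.text)
```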
service_forge/workflow/triggers/kafka_api_trigger.py
@@ -0,0 +1,44 @@
```python
from __future__ import annotations
import uuid
from typing import Any
from service_forge.workflow.trigger import Trigger
from typing import AsyncIterator
from service_forge.workflow.port import Port
from service_forge.api.kafka_api import KafkaApp


class KafkaAPITrigger(Trigger):
    DEFAULT_INPUT_PORTS = [
        Port("app", KafkaApp),
        Port("topic", str),
        Port("data_type", type),
        Port("group_id", str),
    ]

    DEFAULT_OUTPUT_PORTS = [
        Port("trigger", bool),
        Port("data", Any),
    ]

    def __init__(self, name: str):
        super().__init__(name)
        self.events = {}
        self.is_setup_kafka_input = False

    def _setup_kafka_input(self, app: KafkaApp, topic: str, data_type: type, group_id: str) -> None:
        @app.kafka_input(topic, data_type, group_id)
        async def handle_message(data):
            task_id = uuid.uuid4()
            self.trigger_queue.put_nowait({
                "id": task_id,
                "data": data,
            })

    async def _run(self, app: KafkaApp, topic: str, data_type: type, group_id: str) -> AsyncIterator[bool]:
        if not self.is_setup_kafka_input:
            self._setup_kafka_input(app, topic, data_type, group_id)
            self.is_setup_kafka_input = True

        while True:
            trigger = await self.trigger_queue.get()
            self.prepare_output_edges(self.get_output_port_by_name('data'), trigger['data'])
            yield self.trigger(trigger['id'])
```
service_forge/workflow/triggers/once_trigger.py
@@ -0,0 +1,20 @@
```python
from __future__ import annotations
from service_forge.workflow.node import Node
from service_forge.workflow.port import Port
from service_forge.workflow.trigger import Trigger
from typing import AsyncIterator
import uuid


class OnceTrigger(Trigger):
    DEFAULT_INPUT_PORTS = []

    DEFAULT_OUTPUT_PORTS = [
        Port("trigger", bool),
    ]

    def __init__(self, name: str):
        super().__init__(name)

    async def _run(self) -> AsyncIterator[bool]:
        yield self.trigger(uuid.uuid4())
```
service_forge/workflow/triggers/period_trigger.py
@@ -0,0 +1,26 @@
```python
from __future__ import annotations
import asyncio
from service_forge.workflow.node import Node
from service_forge.workflow.port import Port
from service_forge.workflow.trigger import Trigger
from typing import AsyncIterator
import uuid


class PeriodTrigger(Trigger):
    DEFAULT_INPUT_PORTS = [
        Port("TRIGGER", bool),
        Port("period", float),
        Port("times", int),
    ]

    DEFAULT_OUTPUT_PORTS = [
        Port("trigger", bool),
    ]

    def __init__(self, name: str):
        super().__init__(name)

    async def _run(self, times: int, period: float) -> AsyncIterator[bool]:
        for _ in range(times):
            await asyncio.sleep(period)
            yield uuid.uuid4()
```
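The trigger classes above share one pattern: declare `DEFAULT_INPUT_PORTS` and `DEFAULT_OUTPUT_PORTS`, then implement `_run` as an async generator whose parameters match the input port names and which yields a task id via `self.trigger(...)` each time the workflow should fire. A minimal sketch of a custom trigger following that pattern; the class itself is hypothetical and assumes the `Trigger` base class supplies `trigger()` and the port plumbing used above.

```python
from __future__ import annotations
import asyncio
import uuid
from typing import AsyncIterator

from service_forge.workflow.port import Port
from service_forge.workflow.trigger import Trigger


class CountdownTrigger(Trigger):
    """Hypothetical trigger that fires once after a configurable delay."""

    DEFAULT_INPUT_PORTS = [
        Port("delay", float),
    ]

    DEFAULT_OUTPUT_PORTS = [
        Port("trigger", bool),
    ]

    def __init__(self, name: str):
        super().__init__(name)

    async def _run(self, delay: float) -> AsyncIterator[bool]:
        # Same shape as OnceTrigger/PeriodTrigger: wait, then hand a task id
        # to self.trigger() so the workflow scheduler starts a run.
        await asyncio.sleep(delay)
        yield self.trigger(uuid.uuid4())
```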
service_forge/workflow/workflow.py
@@ -0,0 +1,251 @@
```python
from __future__ import annotations
import asyncio
import uuid
from typing import AsyncIterator, Awaitable, Callable, Any
from loguru import logger
from copy import deepcopy
from .node import Node
from .port import Port
from .trigger import Trigger
from .edge import Edge
from ..db.database import DatabaseManager
from ..api.websocket_manager import websocket_manager


class Workflow:
    def __init__(
        self,
        name: str,
        nodes: list[Node],
        input_ports: list[Port],
        output_ports: list[Port],
        _handle_stream_output: Callable[[str, AsyncIterator[str]], Awaitable[None]] = None,
        _handle_query_user: Callable[[str, str], Awaitable[str]] = None,
        database_manager: DatabaseManager = None,
        max_concurrent_runs: int = 10,
    ) -> None:
        self.name = name
        self.nodes = nodes
        self.ready_nodes = []
        self.input_ports = input_ports
        self.output_ports = output_ports
        self._handle_stream_output = _handle_stream_output
        self._handle_query_user = _handle_query_user
        self.after_trigger_workflow = None
        self.result_port = Port("result", Any)
        self.database_manager = database_manager
        self.run_semaphore = asyncio.Semaphore(max_concurrent_runs)
        self._validate()

    def add_nodes(self, nodes: list[Node]) -> None:
        for node in nodes:
            node.workflow = self
        self.nodes.extend(nodes)

    def remove_nodes(self, nodes: list[Node]) -> None:
        for node in nodes:
            self.nodes.remove(node)

    def load_config(self) -> None:
        ...

    def _validate(self) -> None:
        # DAG
        ...

    def get_input_port_by_name(self, name: str) -> Port:
        for input_port in self.input_ports:
            if input_port.name == name:
                return input_port
        return None

    def get_output_port_by_name(self, name: str) -> Port:
        for output_port in self.output_ports:
            if output_port.name == name:
                return output_port
        return None

    async def run_after_trigger(self, task_id: uuid.UUID = None) -> Any:
        trigger_nodes = [node for node in self.nodes if isinstance(node, Trigger)]
        if not trigger_nodes:
            raise ValueError("No trigger nodes found in workflow.")
        self.ready_nodes = []
        for edge in trigger_nodes[0].output_edges:
            edge.end_port.trigger()

        logger.info(f"Running workflow: {self.name}")
        result = None
        try:
            for input_port in self.input_ports:
                if input_port.value is not None:
                    input_port.port.node.fill_input(input_port.port, input_port.value)

            for node in self.nodes:
                for key in node.AUTO_FILL_INPUT_PORTS:
                    if key[0] not in [edge.end_port.name for edge in node.input_edges]:
                        node.fill_input_by_name(key[0], key[1])

            while self.ready_nodes:
                tasks = []
                nodes = self.ready_nodes.copy()
                self.ready_nodes = []
                for i, node in enumerate(nodes):
                    # Send a WebSocket notification that the node has started executing
                    if task_id:
                        # Update the current step
                        websocket_manager.task_manager.update_current_step(task_id, i + 1)
                        await websocket_manager.send_executing(task_id, node.name)

                    result = node.run(task_id)

                    if hasattr(result, '__anext__'):
                        # Handle streamed output
                        if self._handle_stream_output is None:
                            tasks.append(self.handle_stream_output(node.name, result))
                        else:
                            tasks.append(self._handle_stream_output(node.name, result))
                    elif asyncio.iscoroutine(result):
                        tasks.append(result)

                    # Send the result of every output port
                    if task_id:
                        for output_port in node.output_ports:
                            if output_port.is_prepared:
                                await websocket_manager.send_node_output(task_id, node.name, output_port.name, output_port.value)

                if tasks:
                    results = await asyncio.gather(*tasks)
                    # Update result to the last non-None result
                    for res in results:
                        if res is not None:
                            result = res

                # After all tasks have finished, send every output port's result again
                if task_id:
                    for node in nodes:
                        for output_port in node.output_ports:
                            if output_port.is_prepared:
                                await websocket_manager.send_node_output(task_id, node.name, output_port.name, output_port.value)

        except Exception as e:
            import traceback
            error_msg = f"Error in run_after_trigger: {str(e)}"
            logger.error(error_msg)
            logger.error(traceback.format_exc())
            # Send a WebSocket notification about the execution error
            if task_id:
                # Mark the task as failed
                websocket_manager.task_manager.fail_task(task_id, error_msg)
                await websocket_manager.send_execution_error(task_id, "workflow", error_msg)
            raise e

        # Send a WebSocket notification that the workflow has finished executing
        if task_id:
            # Mark the task as completed
            websocket_manager.task_manager.complete_task(task_id)
            await websocket_manager.send_executed(task_id, "workflow", result)

        if len(self.output_ports) > 0 and self.output_ports[0].is_prepared:
            return self.output_ports[0].value

    async def _run(self, uuid: uuid.UUID, trigger_node: Trigger) -> None:
        async with self.run_semaphore:
            try:
                # Send a WebSocket notification that the task has started executing
                await websocket_manager.send_execution_start(uuid)

                new_workflow = self._clone()
                result = await new_workflow.run_after_trigger(uuid)
                if uuid in trigger_node.result_queues:
                    trigger_node.result_queues[uuid].put_nowait(result)

            except Exception as e:
                import traceback
                error_msg = f"Error running workflow: {str(e)}, {traceback.format_exc()}"
                logger.error(error_msg)
                # Send a WebSocket notification about the execution error
                # Mark the task as failed
                websocket_manager.task_manager.fail_task(uuid, error_msg)
                await websocket_manager.send_execution_error(uuid, "workflow", error_msg)
                if uuid in trigger_node.result_queues:
                    trigger_node.result_queues[uuid].put_nowait(e)

    async def run(self):
        trigger_nodes = [node for node in self.nodes if isinstance(node, Trigger)]

        if len(trigger_nodes) == 0:
            raise ValueError("No trigger nodes found in workflow.")
        if len(trigger_nodes) > 1:
            raise ValueError("Multiple trigger nodes found in workflow.")

        trigger_node = trigger_nodes[0]
        async for uuid in trigger_node.run():
            asyncio.create_task(self._run(uuid, trigger_node))

    async def handle_stream_output(self, node_name: str, stream: AsyncIterator[str]) -> None:
        logger.info(f"[{node_name}] Starting stream output:")
        buffer = []
        async for char in stream:
            buffer.append(char)
            logger.info(f"[{node_name}] Received char: '{char}'")

        complete_message = ''.join(buffer)
        logger.info(f"[{node_name}] Complete message: '{complete_message}'")

    async def handle_query_user(self, node_name: str, prompt: str) -> Awaitable[str]:
        return await asyncio.to_thread(input, f"[{node_name}] {prompt}: ")

    def _clone(self) -> Workflow:
        node_map: dict[Node, Node] = {node: node._simple_clone() for node in self.nodes}

        port_map: dict[Port, Port] = {}
        port_map.update({port: port._simple_clone(node_map) for port in self.input_ports})
        port_map.update({port: port._simple_clone(node_map) for port in self.output_ports})
        for node in self.nodes:
            for port in node.input_ports:
                if port not in port_map:
                    port_map[port] = port._simple_clone(node_map)
            for port in node.output_ports:
                if port not in port_map:
                    port_map[port] = port._simple_clone(node_map)

        edge_map: dict[Edge, Edge] = {}
        for node in self.nodes:
            for edge in node.input_edges:
                if edge not in edge_map:
                    edge_map[edge] = edge._simple_clone(node_map, port_map)
            for edge in node.output_edges:
                if edge not in edge_map:
                    edge_map[edge] = edge._simple_clone(node_map, port_map)

        # fill port.port
        for old_port, new_port in port_map.items():
            if old_port.port is not None:
                new_port.port = port_map[old_port.port]

        # fill ports and edges in nodes
        for old_node, new_node in node_map.items():
            new_node.input_edges = [edge_map[edge] for edge in old_node.input_edges]
            new_node.output_edges = [edge_map[edge] for edge in old_node.output_edges]
            new_node.input_ports = [port_map[port] for port in old_node.input_ports]
            new_node.output_ports = [port_map[port] for port in old_node.output_ports]
            new_node.input_variables = {
                port_map[port]: value for port, value in old_node.input_variables.items()
            }

        workflow = Workflow(
            name=self.name,
            nodes=[node_map[node] for node in self.nodes],
            input_ports=[port_map[port] for port in self.input_ports],
            output_ports=[port_map[port] for port in self.output_ports],
            _handle_stream_output=self._handle_stream_output,
            _handle_query_user=self._handle_query_user,
            database_manager=self.database_manager,
        )

        for node in workflow.nodes:
            node.workflow = workflow

        return workflow
```
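The `Workflow` constructor takes two optional hooks. `_handle_stream_output` receives a node name plus an async iterator of string chunks and is awaited via `asyncio.gather` alongside the other node coroutines; `_handle_query_user` is what nodes use to ask the user a question, falling back to the defaults above. A minimal sketch of custom hooks matching those signatures; the behaviour shown (collecting chunks, prompting on stdin) is illustrative, not part of the package.

```python
import asyncio
from typing import AsyncIterator


async def collect_stream(node_name: str, stream: AsyncIterator[str]) -> None:
    # Matches the _handle_stream_output signature: drain the node's streamed
    # chunks and print them in one piece instead of the default per-char logging.
    chunks = [chunk async for chunk in stream]
    print(f"[{node_name}] {''.join(chunks)}")


async def ask_user(node_name: str, prompt: str) -> str:
    # Matches the _handle_query_user signature used by nodes via node.query_user.
    return await asyncio.to_thread(input, f"[{node_name}] {prompt}: ")
```

These could be passed as `_handle_stream_output=collect_stream, _handle_query_user=ask_user` to `Workflow(...)` or to the factory functions below.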
service_forge/workflow/workflow_factory.py
@@ -0,0 +1,227 @@
```python
from omegaconf import OmegaConf
from typing import Callable, Awaitable, AsyncIterator, Any
from copy import deepcopy
from .workflow import Workflow
from .workflow_group import WorkflowGroup
from .node import Node
from .edge import Edge
from .port import Port, parse_port_name, create_workflow_input_port, create_sub_workflow_input_port, create_port
from .node import node_register
from .nodes import *
from .triggers import *
from .context import Context
from ..db.database import DatabaseManager

WORKFLOW_KEY_NAME = 'name'
WORKFLOW_KEY_NODES = 'nodes'
WORKFLOW_KEY_INPUTS = 'inputs'
WORKFLOW_KEY_OUTPUTS = 'outputs'

NODE_KEY_NAME = 'name'
NODE_KEY_TYPE = 'type'
NODE_KEY_ARGS = 'args'
NODE_KEY_OUTPUTS = 'outputs'
NODE_KEY_INPUT_PORTS = 'input_ports'
NODE_KEY_OUTPUT_PORTS = 'output_ports'
NODE_KEY_SUB_WORKFLOWS = 'sub_workflows'
NODE_KEY_SUB_WORKFLOWS_INPUT_PORTS = 'sub_workflows_input_ports'

PORT_KEY_NAME = 'name'
PORT_KEY_PORT = 'port'
PORT_KEY_VALUE = 'value'


def parse_argument(arg: Any, service_env: dict[str, Any] = None) -> Any:
    if type(arg) == str and arg.startswith(f'<{{') and arg.endswith(f'}}>'):
        key = arg[2:-2]
        if key not in service_env:
            raise ValueError(f"Key {key} not found in service env.")
        return service_env[key]
    return arg


def create_workflow(
    config_path: str = None,
    service_env: dict[str, Any] = None,
    config: dict = None,
    workflows: WorkflowGroup = None,
    _handle_stream_output: Callable[[str, AsyncIterator[str]], Awaitable[None]] | None = None,
    _handle_query_user: Callable[[str, str], Awaitable[str]] | None = None,
    database_manager: DatabaseManager = None,
) -> Workflow:
    if config is None:
        config = OmegaConf.to_object(OmegaConf.load(config_path))

    if WORKFLOW_KEY_NAME not in config:
        if config_path is None:
            raise ValueError(f"{WORKFLOW_KEY_NAME} is required in workflow config in {config}.")
        else:
            raise ValueError(f"{WORKFLOW_KEY_NAME} is required in workflow config at {config_path}.")

    workflow = Workflow(
        name = config[WORKFLOW_KEY_NAME],
        nodes = [],
        input_ports = [],
        output_ports = [],
        _handle_stream_output = _handle_stream_output,
        _handle_query_user = _handle_query_user,
        database_manager = database_manager,
    )

    nodes: dict[str, Node] = {}

    # Nodes
    for node_config in config[WORKFLOW_KEY_NODES]:
        params = {
            "name": node_config[NODE_KEY_NAME],
        }

        node: Node = node_register.instance(node_config[NODE_KEY_TYPE], ignore_keys=['type'], kwargs=params)

        # Context
        node.context = Context(variables = {})

        # Input ports
        if node_key_input_ports := node_config.get(NODE_KEY_INPUT_PORTS, None):
            node.input_ports = [Port(**port_params) for port_params in node_key_input_ports]
        else:
            node.input_ports = deepcopy(node.DEFAULT_INPUT_PORTS)

        for input_port in node.input_ports:
            input_port.node = node

        # Output ports
        if node_key_output_ports := node_config.get(NODE_KEY_OUTPUT_PORTS, None):
            node.output_ports = [Port(**port_params) for port_params in node_key_output_ports]
        else:
            node.output_ports = deepcopy(node.DEFAULT_OUTPUT_PORTS)

        for output_port in node.output_ports:
            output_port.node = node

        # Sub workflows
        if node_key_sub_workflows := node_config.get(NODE_KEY_SUB_WORKFLOWS, None):
            sub_workflows: WorkflowGroup = WorkflowGroup(workflows=[])
            for sub_workflow_config in node_key_sub_workflows:
                sub_workflow = workflows.get_workflow(sub_workflow_config['name'])
                sub_workflows.add_workflow(deepcopy(sub_workflow))
            node.sub_workflows = sub_workflows

        # Sub workflows input ports
        if node_key_sub_network_input_ports := node_config.get(NODE_KEY_SUB_WORKFLOWS_INPUT_PORTS, None):
            for sub_workflow_input_port_config in node_key_sub_network_input_ports:
                name = sub_workflow_input_port_config[PORT_KEY_NAME]
                sub_workflow_name, sub_workflow_port_name = parse_port_name(sub_workflow_input_port_config[PORT_KEY_PORT])
                sub_workflow = node.sub_workflows.get_workflow(sub_workflow_name)
                if sub_workflow is None:
                    raise ValueError(f"{sub_workflow_name} is not a valid sub workflow.")
                sub_workflow_port = sub_workflow.get_input_port_by_name(sub_workflow_port_name)
                if sub_workflow_port is None:
                    raise ValueError(f"{sub_workflow_port_name} is not a valid input port.")
                value = sub_workflow_input_port_config.get(PORT_KEY_VALUE, None)
                node.input_ports.append(create_sub_workflow_input_port(name=name, node=node, port=sub_workflow_port, value=value))

        # Sub workflows output ports
        ...

        # Hooks
        if _handle_query_user is None:
            node.query_user = workflow.handle_query_user
        else:
            node.query_user = _handle_query_user

        nodes[node_config[NODE_KEY_NAME]] = node

    # Edges
    for node_config in config[WORKFLOW_KEY_NODES]:
        start_node = nodes[node_config[NODE_KEY_NAME]]
        if NODE_KEY_OUTPUTS in node_config and node_config[NODE_KEY_OUTPUTS]:
            for key, value in node_config[NODE_KEY_OUTPUTS].items():
                if value is None:
                    continue

                if type(value) is str:
                    value = [value]

                for edge_value in value:
                    end_node_name, end_port_name = parse_port_name(edge_value)
                    end_node = nodes[end_node_name]

                    start_port = start_node.get_output_port_by_name(key)
                    end_port = end_node.get_input_port_by_name(end_port_name)

                    if start_port is None:
                        raise ValueError(f"{key} is not a valid output port.")
                    if end_port is None:
                        raise ValueError(f"{end_port_name} is not a valid input port.")

                    edge = Edge(start_node, end_node, start_port, end_port)

                    start_node.output_edges.append(edge)
                    end_node.input_edges.append(edge)

    workflow.add_nodes(list(nodes.values()))

    # Inputs
    if workflow_key_inputs := config.get(WORKFLOW_KEY_INPUTS, None):
        for port_config in workflow_key_inputs:
            name = port_config[PORT_KEY_NAME]
            node_name, node_port_name = parse_port_name(port_config[PORT_KEY_PORT])
            if node_name not in nodes:
                raise ValueError(f"{node_name} is not a valid node.")
            node = nodes[node_name]
            port = node.get_input_port_by_name(node_port_name)
            if port is None:
                raise ValueError(f"{node_port_name} is not a valid input port.")
            value = port_config.get(PORT_KEY_VALUE, None)
            workflow.input_ports.append(create_workflow_input_port(name=name, port=port, value=value))

    # Outputs
    if workflow_key_outputs := config.get(WORKFLOW_KEY_OUTPUTS, None):
        for port_config in workflow_key_outputs:
            name = port_config[PORT_KEY_NAME]
            node_name, node_port_name = parse_port_name(port_config[PORT_KEY_PORT])
            if node_name not in nodes:
                raise ValueError(f"{node_name} is not a valid node.")
            node = nodes[node_name]
            port = node.get_output_port_by_name(node_port_name)
            if port is None:
                raise ValueError(f"{node_port_name} is not a valid output port.")
            output_port = create_port(name=name, type=Any)
            workflow.output_ports.append(output_port)
            edge = Edge(node, None, port, output_port)
            node.output_edges.append(edge)

    for node_config in config[WORKFLOW_KEY_NODES]:
        node = nodes[node_config[NODE_KEY_NAME]]
        # Arguments
        if node_key_args := node_config.get(NODE_KEY_ARGS, None):
            for key, value in node_key_args.items():
                node.fill_input_by_name(key, parse_argument(value, service_env=service_env))

    return workflow


def create_workflows(
    config_path: str,
    service_env: dict[str, Any] = None,
    _handle_stream_output: Callable[[str, AsyncIterator[str]], Awaitable[None]] = None,
    _handle_query_user: Callable[[str, str], Awaitable[str]] = None,
    database_manager: DatabaseManager = None,
) -> WorkflowGroup:
    WORKFLOW_KEY_WORKFLOWS = 'workflows'
    WORKFLOW_KEY_MAIN_WORKFLOW_NAME = 'main'

    config = OmegaConf.to_object(OmegaConf.load(config_path))

    if WORKFLOW_KEY_WORKFLOWS not in config:
        workflow = create_workflow(
            config_path,
            service_env=service_env,
            _handle_stream_output=_handle_stream_output,
            _handle_query_user=_handle_query_user,
            database_manager=database_manager,
        )
        return WorkflowGroup(workflows=[workflow], main_workflow_name=workflow.name)

    workflows = WorkflowGroup(workflows=[], main_workflow_name=config.get(WORKFLOW_KEY_MAIN_WORKFLOW_NAME, None))
    for workflow_config in config[WORKFLOW_KEY_WORKFLOWS]:
        workflows.add_workflow(create_workflow(config = workflow_config, workflows=workflows, _handle_stream_output=_handle_stream_output, _handle_query_user=_handle_query_user))
    return workflows
```
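`create_workflow` accepts either a YAML path (loaded through OmegaConf) or an in-memory dict via its `config=` parameter. A hedged sketch of the shape it expects, built from the `WORKFLOW_KEY_*` / `NODE_KEY_*` / `PORT_KEY_*` constants parsed above. The node `type` values, the port names, and the dotted `"Node.port"` references are illustrative assumptions: types must be names known to `node_register`, and port references must use whatever form `parse_port_name` actually splits. The `<{PREFIX}>` placeholder follows the `parse_argument` convention and is resolved against `service_env`.

```python
from service_forge.workflow.workflow_factory import create_workflow

# Hypothetical workflow definition; key names mirror the constants above.
config = {
    "name": "echo_workflow",
    "nodes": [
        {
            "name": "Input",
            "type": "ConsoleInputNode",          # assumed registered node type
            "outputs": {"data": "Print.data"},   # output port -> downstream node.port reference
        },
        {
            "name": "Print",
            "type": "PrintNode",                 # assumed registered node type
            "args": {"prefix": "<{PREFIX}>"},    # <{KEY}> placeholders resolve via service_env
        },
    ],
    "outputs": [
        {"name": "result", "port": "Print.data"},  # assumed output port on the Print node
    ],
}

workflow = create_workflow(config=config, service_env={"PREFIX": ">> "})
```

For multi-workflow files, `create_workflows` reads a top-level `workflows` list plus an optional `main` key and returns a `WorkflowGroup`.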