service-forge 0.1.0 (service_forge-0.1.0-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (64)
  1. service_forge/api/http_api.py +138 -0
  2. service_forge/api/kafka_api.py +126 -0
  3. service_forge/api/task_manager.py +141 -0
  4. service_forge/api/websocket_api.py +86 -0
  5. service_forge/api/websocket_manager.py +425 -0
  6. service_forge/db/__init__.py +1 -0
  7. service_forge/db/database.py +119 -0
  8. service_forge/llm/__init__.py +62 -0
  9. service_forge/llm/llm.py +56 -0
  10. service_forge/main.py +121 -0
  11. service_forge/model/__init__.py +0 -0
  12. service_forge/model/websocket.py +13 -0
  13. service_forge/proto/foo_input.py +5 -0
  14. service_forge/service.py +111 -0
  15. service_forge/service_config.py +115 -0
  16. service_forge/sft/cli.py +91 -0
  17. service_forge/sft/cmd/config_command.py +67 -0
  18. service_forge/sft/cmd/deploy_service.py +124 -0
  19. service_forge/sft/cmd/list_tars.py +41 -0
  20. service_forge/sft/cmd/service_command.py +149 -0
  21. service_forge/sft/cmd/upload_service.py +36 -0
  22. service_forge/sft/config/injector.py +87 -0
  23. service_forge/sft/config/injector_default_files.py +97 -0
  24. service_forge/sft/config/sf_metadata.py +30 -0
  25. service_forge/sft/config/sft_config.py +125 -0
  26. service_forge/sft/file/__init__.py +0 -0
  27. service_forge/sft/file/ignore_pattern.py +80 -0
  28. service_forge/sft/file/sft_file_manager.py +107 -0
  29. service_forge/sft/kubernetes/kubernetes_manager.py +257 -0
  30. service_forge/sft/util/assert_util.py +25 -0
  31. service_forge/sft/util/logger.py +16 -0
  32. service_forge/sft/util/name_util.py +2 -0
  33. service_forge/utils/__init__.py +0 -0
  34. service_forge/utils/default_type_converter.py +12 -0
  35. service_forge/utils/register.py +39 -0
  36. service_forge/utils/type_converter.py +74 -0
  37. service_forge/workflow/__init__.py +1 -0
  38. service_forge/workflow/context.py +13 -0
  39. service_forge/workflow/edge.py +31 -0
  40. service_forge/workflow/node.py +179 -0
  41. service_forge/workflow/nodes/__init__.py +7 -0
  42. service_forge/workflow/nodes/control/if_node.py +29 -0
  43. service_forge/workflow/nodes/input/console_input_node.py +26 -0
  44. service_forge/workflow/nodes/llm/query_llm_node.py +41 -0
  45. service_forge/workflow/nodes/nested/workflow_node.py +28 -0
  46. service_forge/workflow/nodes/output/kafka_output_node.py +27 -0
  47. service_forge/workflow/nodes/output/print_node.py +29 -0
  48. service_forge/workflow/nodes/test/if_console_input_node.py +33 -0
  49. service_forge/workflow/nodes/test/time_consuming_node.py +61 -0
  50. service_forge/workflow/port.py +86 -0
  51. service_forge/workflow/trigger.py +20 -0
  52. service_forge/workflow/triggers/__init__.py +4 -0
  53. service_forge/workflow/triggers/fast_api_trigger.py +125 -0
  54. service_forge/workflow/triggers/kafka_api_trigger.py +44 -0
  55. service_forge/workflow/triggers/once_trigger.py +20 -0
  56. service_forge/workflow/triggers/period_trigger.py +26 -0
  57. service_forge/workflow/workflow.py +251 -0
  58. service_forge/workflow/workflow_factory.py +227 -0
  59. service_forge/workflow/workflow_group.py +23 -0
  60. service_forge/workflow/workflow_type.py +52 -0
  61. service_forge-0.1.0.dist-info/METADATA +93 -0
  62. service_forge-0.1.0.dist-info/RECORD +64 -0
  63. service_forge-0.1.0.dist-info/WHEEL +4 -0
  64. service_forge-0.1.0.dist-info/entry_points.txt +2 -0
service_forge/workflow/triggers/fast_api_trigger.py
@@ -0,0 +1,125 @@
+
+ from __future__ import annotations
+ from service_forge.workflow.trigger import Trigger
+ from typing import AsyncIterator, Any, Callable
+ from fastapi import FastAPI, Request
+ import uuid
+ import asyncio
+ import json
+ from service_forge.workflow.port import Port
+ from service_forge.utils.default_type_converter import type_converter
+ from service_forge.api.websocket_manager import websocket_manager
+ from fastapi import HTTPException
+ from google.protobuf.message import Message
+ from google.protobuf.json_format import MessageToJson
+
+ class FastAPITrigger(Trigger):
+     DEFAULT_INPUT_PORTS = [
+         Port("app", FastAPI),
+         Port("path", str),
+         Port("method", str),
+         Port("data_type", type),
+     ]
+
+     DEFAULT_OUTPUT_PORTS = [
+         Port("trigger", bool),
+         Port("user_id", int),
+         Port("data", Any),
+     ]
+
+     def __init__(self, name: str):
+         super().__init__(name)
+         self.events = {}
+         self.is_setup_route = False
+
+     @staticmethod
+     def serialize_result(result: Any):
+         if isinstance(result, Message):
+             return MessageToJson(
+                 result,
+                 preserving_proto_field_name=True
+             )
+         try:
+             return json.dumps(result)
+         except Exception:
+             return result
+
+     async def handle_request(
+         self,
+         request: Request,
+         data_type: type,
+         extract_data_fn: Callable[[Request], dict]
+     ):
+         task_id = uuid.uuid4()
+         self.result_queues[task_id] = asyncio.Queue()
+
+         body_data = await extract_data_fn(request)
+         converted_data = data_type(**body_data)
+
+         client_id = (
+             body_data.get("client_id")
+             or request.query_params.get("client_id")
+             or request.headers.get("X-Client-ID")
+         )
+         if client_id:
+             workflow_name = getattr(self.workflow, "name", "Unknown")
+             steps = len(self.workflow.nodes) if hasattr(self.workflow, "nodes") else 1
+             websocket_manager.create_task_with_client(task_id, client_id, workflow_name, steps)
+
+         self.trigger_queue.put_nowait({
+             "id": task_id,
+             "user_id": getattr(request.state, "user_id", None),
+             "data": converted_data,
+         })
+
+         result = await self.result_queues[task_id].get()
+         del self.result_queues[task_id]
+
+         if isinstance(result, HTTPException):
+             raise result
+         if isinstance(result, Exception):
+             raise HTTPException(status_code=500, detail=str(result))
+
+         return self.serialize_result(result)
+
+     def _setup_route(self, app: FastAPI, path: str, method: str, data_type: type) -> None:
+         async def get_data(request: Request) -> dict:
+             return dict(request.query_params)
+
+         async def body_data(request: Request) -> dict:
+             raw = await request.body()
+             if not raw:
+                 return {}
+             return json.loads(raw.decode("utf-8"))
+
+         extractor = get_data if method == "GET" else body_data
+
+         async def handler(request: Request):
+             return await self.handle_request(request, data_type, extractor)
+
+         if method == "GET":
+             app.get(path)(handler)
+         elif method == "POST":
+             app.post(path)(handler)
+         elif method == "PUT":
+             app.put(path)(handler)
+         elif method == "DELETE":
+             app.delete(path)(handler)
+         else:
+             raise ValueError(f"Invalid method {method}")
+
+     async def _run(self, app: FastAPI, path: str, method: str, data_type: type) -> AsyncIterator[bool]:
+         if not self.is_setup_route:
+             self._setup_route(app, path, method, data_type)
+             self.is_setup_route = True
+
+         while True:
+             try:
+                 trigger = await self.trigger_queue.get()
+                 self.prepare_output_edges(self.get_output_port_by_name('user_id'), trigger['user_id'])
+                 self.prepare_output_edges(self.get_output_port_by_name('data'), trigger['data'])
+                 yield self.trigger(trigger['id'])
+             except Exception as e:
+                 from loguru import logger
+                 logger.error(f"Error in FastAPITrigger._run: {e}")
+                 continue
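Reviewer note: handle_request correlates each HTTP request with its eventual workflow result through a per-task asyncio.Queue keyed by a fresh UUID, then blocks on that queue until the workflow run posts a result or an exception. A standalone sketch of that correlation pattern, using illustrative names rather than the package API:

    import asyncio
    import uuid

    async def demo() -> None:
        # One result queue per in-flight task, keyed by task id,
        # mirroring handle_request above (names here are hypothetical).
        result_queues: dict[uuid.UUID, asyncio.Queue] = {}
        trigger_queue: asyncio.Queue = asyncio.Queue()

        async def worker() -> None:
            task = await trigger_queue.get()
            # Stand-in for the workflow run; hand the result back on the
            # queue registered for this task id.
            result_queues[task["id"]].put_nowait({"echo": task["data"]})

        task_id = uuid.uuid4()
        result_queues[task_id] = asyncio.Queue()   # register before enqueueing
        trigger_queue.put_nowait({"id": task_id, "data": "hello"})
        asyncio.create_task(worker())
        print(await result_queues[task_id].get())  # -> {'echo': 'hello'}
        del result_queues[task_id]                 # cleanup, as in handle_request

    asyncio.run(demo())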
service_forge/workflow/triggers/kafka_api_trigger.py
@@ -0,0 +1,44 @@
+ from __future__ import annotations
+ import uuid
+ from typing import Any
+ from service_forge.workflow.trigger import Trigger
+ from typing import AsyncIterator
+ from service_forge.workflow.port import Port
+ from service_forge.api.kafka_api import KafkaApp
+
+ class KafkaAPITrigger(Trigger):
+     DEFAULT_INPUT_PORTS = [
+         Port("app", KafkaApp),
+         Port("topic", str),
+         Port("data_type", type),
+         Port("group_id", str),
+     ]
+
+     DEFAULT_OUTPUT_PORTS = [
+         Port("trigger", bool),
+         Port("data", Any),
+     ]
+
+     def __init__(self, name: str):
+         super().__init__(name)
+         self.events = {}
+         self.is_setup_kafka_input = False
+
+     def _setup_kafka_input(self, app: KafkaApp, topic: str, data_type: type, group_id: str) -> None:
+         @app.kafka_input(topic, data_type, group_id)
+         async def handle_message(data):
+             task_id = uuid.uuid4()
+             self.trigger_queue.put_nowait({
+                 "id": task_id,
+                 "data": data,
+             })
+
+     async def _run(self, app: KafkaApp, topic: str, data_type: type, group_id: str) -> AsyncIterator[bool]:
+         if not self.is_setup_kafka_input:
+             self._setup_kafka_input(app, topic, data_type, group_id)
+             self.is_setup_kafka_input = True
+
+         while True:
+             trigger = await self.trigger_queue.get()
+             self.prepare_output_edges(self.get_output_port_by_name('data'), trigger['data'])
+             yield self.trigger(trigger['id'])
service_forge/workflow/triggers/once_trigger.py
@@ -0,0 +1,20 @@
+ from __future__ import annotations
+ from service_forge.workflow.node import Node
+ from service_forge.workflow.port import Port
+ from service_forge.workflow.trigger import Trigger
+ from typing import AsyncIterator
+ import uuid
+
+ class OnceTrigger(Trigger):
+     DEFAULT_INPUT_PORTS = [
+     ]
+
+     DEFAULT_OUTPUT_PORTS = [
+         Port("trigger", bool),
+     ]
+
+     def __init__(self, name: str):
+         super().__init__(name)
+
+     async def _run(self) -> AsyncIterator[bool]:
+         yield self.trigger(uuid.uuid4())
service_forge/workflow/triggers/period_trigger.py
@@ -0,0 +1,26 @@
+ from __future__ import annotations
+ import asyncio
+ from service_forge.workflow.node import Node
+ from service_forge.workflow.port import Port
+ from service_forge.workflow.trigger import Trigger
+ from typing import AsyncIterator
+ import uuid
+
+ class PeriodTrigger(Trigger):
+     DEFAULT_INPUT_PORTS = [
+         Port("TRIGGER", bool),
+         Port("period", float),
+         Port("times", int),
+     ]
+
+     DEFAULT_OUTPUT_PORTS = [
+         Port("trigger", bool),
+     ]
+
+     def __init__(self, name: str):
+         super().__init__(name)
+
+     async def _run(self, times: int, period: float) -> AsyncIterator[bool]:
+         for _ in range(times):
+             await asyncio.sleep(period)
+             yield self.trigger(uuid.uuid4())
service_forge/workflow/workflow.py
@@ -0,0 +1,251 @@
+
+ from __future__ import annotations
+ import asyncio
+ import uuid
+ from typing import AsyncIterator, Awaitable, Callable, Any
+ from loguru import logger
+ from copy import deepcopy
+ from .node import Node
+ from .port import Port
+ from .trigger import Trigger
+ from .edge import Edge
+ from ..db.database import DatabaseManager
+ from ..api.websocket_manager import websocket_manager
+
+
+ class Workflow:
+     def __init__(
+         self,
+         name: str,
+         nodes: list[Node],
+         input_ports: list[Port],
+         output_ports: list[Port],
+         _handle_stream_output: Callable[[str, AsyncIterator[str]], Awaitable[None]] = None,
+         _handle_query_user: Callable[[str, str], Awaitable[str]] = None,
+         database_manager: DatabaseManager = None,
+         max_concurrent_runs: int = 10,
+     ) -> None:
+         self.name = name
+         self.nodes = nodes
+         self.ready_nodes = []
+         self.input_ports = input_ports
+         self.output_ports = output_ports
+         self._handle_stream_output = _handle_stream_output
+         self._handle_query_user = _handle_query_user
+         self.after_trigger_workflow = None
+         self.result_port = Port("result", Any)
+         self.database_manager = database_manager
+         self.run_semaphore = asyncio.Semaphore(max_concurrent_runs)
+         self._validate()
+
+     def add_nodes(self, nodes: list[Node]) -> None:
+         for node in nodes:
+             node.workflow = self
+         self.nodes.extend(nodes)
+
+     def remove_nodes(self, nodes: list[Node]) -> None:
+         for node in nodes:
+             self.nodes.remove(node)
+
+     def load_config(self) -> None:
+         ...
+
+     def _validate(self) -> None:
+         # DAG
+         ...
+
+     def get_input_port_by_name(self, name: str) -> Port:
+         for input_port in self.input_ports:
+             if input_port.name == name:
+                 return input_port
+         return None
+
+     def get_output_port_by_name(self, name: str) -> Port:
+         for output_port in self.output_ports:
+             if output_port.name == name:
+                 return output_port
+         return None
+
+     async def run_after_trigger(self, task_id: uuid.UUID = None) -> Any:
+         trigger_nodes = [node for node in self.nodes if isinstance(node, Trigger)]
+         if not trigger_nodes:
+             raise ValueError("No trigger nodes found in workflow.")
+         self.ready_nodes = []
+         for edge in trigger_nodes[0].output_edges:
+             edge.end_port.trigger()
+
+         logger.info(f"Running workflow: {self.name}")
+         result = None
+         try:
+             for input_port in self.input_ports:
+                 if input_port.value is not None:
+                     input_port.port.node.fill_input(input_port.port, input_port.value)
+
+             for node in self.nodes:
+                 for key in node.AUTO_FILL_INPUT_PORTS:
+                     if key[0] not in [edge.end_port.name for edge in node.input_edges]:
+                         node.fill_input_by_name(key[0], key[1])
+
+             while self.ready_nodes:
+                 tasks = []
+                 nodes = self.ready_nodes.copy()
+                 self.ready_nodes = []
+                 for i, node in enumerate(nodes):
+                     # Notify over WebSocket that the node has started executing
+                     if task_id:
+                         # Update the current step
+                         websocket_manager.task_manager.update_current_step(task_id, i + 1)
+                         await websocket_manager.send_executing(task_id, node.name)
+
+                     result = node.run(task_id)
+
+                     if hasattr(result, '__anext__'):
+                         # Handle streaming output
+                         if self._handle_stream_output is None:
+                             tasks.append(self.handle_stream_output(node.name, result))
+                         else:
+                             tasks.append(self._handle_stream_output(node.name, result))
+                     elif asyncio.iscoroutine(result):
+                         tasks.append(result)
+
+                     # Send the result of each output port
+                     if task_id:
+                         for output_port in node.output_ports:
+                             if output_port.is_prepared:
+                                 await websocket_manager.send_node_output(task_id, node.name, output_port.name, output_port.value)
+
+                 if tasks:
+                     results = await asyncio.gather(*tasks)
+                     # Update result to the last non-None value
+                     for res in results:
+                         if res is not None:
+                             result = res
+
+                 # After all tasks finish, send each output port's result again
+                 if task_id:
+                     for node in nodes:
+                         for output_port in node.output_ports:
+                             if output_port.is_prepared:
+                                 await websocket_manager.send_node_output(task_id, node.name, output_port.name, output_port.value)
+
+         except Exception as e:
+             import traceback
+             error_msg = f"Error in run_after_trigger: {str(e)}"
+             logger.error(error_msg)
+             logger.error(traceback.format_exc())
+             # Notify over WebSocket that execution failed
+             if task_id:
+                 # Mark the task as failed
+                 websocket_manager.task_manager.fail_task(task_id, error_msg)
+                 await websocket_manager.send_execution_error(task_id, "workflow", error_msg)
+             raise e
+
+         # Notify over WebSocket that the workflow finished executing
+         if task_id:
+             # Mark the task as completed
+             websocket_manager.task_manager.complete_task(task_id)
+             await websocket_manager.send_executed(task_id, "workflow", result)
+
+         if len(self.output_ports) > 0 and self.output_ports[0].is_prepared:
+             return self.output_ports[0].value
+
+     async def _run(self, uuid: uuid.UUID, trigger_node: Trigger) -> None:
+         async with self.run_semaphore:
+             try:
+                 # Notify over WebSocket that the task has started executing
+                 await websocket_manager.send_execution_start(uuid)
+
+                 new_workflow = self._clone()
+                 result = await new_workflow.run_after_trigger(uuid)
+                 if uuid in trigger_node.result_queues:
+                     trigger_node.result_queues[uuid].put_nowait(result)
+
+             except Exception as e:
+                 import traceback
+                 error_msg = f"Error running workflow: {str(e)}, {traceback.format_exc()}"
+                 logger.error(error_msg)
+                 # Notify over WebSocket that execution failed
+                 # Mark the task as failed
+                 websocket_manager.task_manager.fail_task(uuid, error_msg)
+                 await websocket_manager.send_execution_error(uuid, "workflow", error_msg)
+                 if uuid in trigger_node.result_queues:
+                     trigger_node.result_queues[uuid].put_nowait(e)
+
+     async def run(self):
+         trigger_nodes = [node for node in self.nodes if isinstance(node, Trigger)]
+
+         if len(trigger_nodes) == 0:
+             raise ValueError("No trigger nodes found in workflow.")
+         if len(trigger_nodes) > 1:
+             raise ValueError("Multiple trigger nodes found in workflow.")
+
+         trigger_node = trigger_nodes[0]
+         async for uuid in trigger_node.run():
+             asyncio.create_task(self._run(uuid, trigger_node))
+
+     async def handle_stream_output(self, node_name: str, stream: AsyncIterator[str]) -> None:
+         logger.info(f"[{node_name}] Starting stream output:")
+         buffer = []
+         async for char in stream:
+             buffer.append(char)
+             logger.info(f"[{node_name}] Received char: '{char}'")
+
+         complete_message = ''.join(buffer)
+         logger.info(f"[{node_name}] Complete message: '{complete_message}'")
+
+     async def handle_query_user(self, node_name: str, prompt: str) -> str:
+         return await asyncio.to_thread(input, f"[{node_name}] {prompt}: ")
+
+     def _clone(self) -> Workflow:
+         node_map: dict[Node, Node] = {node: node._simple_clone() for node in self.nodes}
+
+         port_map: dict[Port, Port] = {}
+         port_map.update({port: port._simple_clone(node_map) for port in self.input_ports})
+         port_map.update({port: port._simple_clone(node_map) for port in self.output_ports})
+         for node in self.nodes:
+             for port in node.input_ports:
+                 if port not in port_map:
+                     port_map[port] = port._simple_clone(node_map)
+             for port in node.output_ports:
+                 if port not in port_map:
+                     port_map[port] = port._simple_clone(node_map)
+
+         edge_map: dict[Edge, Edge] = {}
+         for node in self.nodes:
+             for edge in node.input_edges:
+                 if edge not in edge_map:
+                     edge_map[edge] = edge._simple_clone(node_map, port_map)
+             for edge in node.output_edges:
+                 if edge not in edge_map:
+                     edge_map[edge] = edge._simple_clone(node_map, port_map)
+
+         # Re-link each cloned port's referenced port
+         for old_port, new_port in port_map.items():
+             if old_port.port is not None:
+                 new_port.port = port_map[old_port.port]
+
+
+         # Fill ports and edges in the cloned nodes
+         for old_node, new_node in node_map.items():
+             new_node.input_edges = [edge_map[edge] for edge in old_node.input_edges]
+             new_node.output_edges = [edge_map[edge] for edge in old_node.output_edges]
+             new_node.input_ports = [port_map[port] for port in old_node.input_ports]
+             new_node.output_ports = [port_map[port] for port in old_node.output_ports]
+             new_node.input_variables = {
+                 port_map[port]: value for port, value in old_node.input_variables.items()
+             }
+
+         workflow = Workflow(
+             name=self.name,
+             nodes=[node_map[node] for node in self.nodes],
+             input_ports=[port_map[port] for port in self.input_ports],
+             output_ports=[port_map[port] for port in self.output_ports],
+             _handle_stream_output=self._handle_stream_output,
+             _handle_query_user=self._handle_query_user,
+             database_manager=self.database_manager,
+         )
+
+         for node in workflow.nodes:
+             node.workflow = workflow
+
+         return workflow
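Reviewer note: run() spawns one asyncio task per trigger event, and _run clones the whole node/port/edge graph via _clone() so each concurrent run mutates only its own copy, while run_semaphore bounds how many clones execute at once (max_concurrent_runs, default 10). A minimal, package-independent sketch of that throttling, assuming nothing beyond asyncio:

    import asyncio

    async def bounded_run(sem: asyncio.Semaphore, n: int) -> None:
        # Each trigger event runs under the semaphore, as in Workflow._run,
        # so at most max_concurrent_runs cloned workflows execute at once.
        async with sem:
            await asyncio.sleep(0.1)  # stand-in for run_after_trigger
            print(f"run {n} done")

    async def main() -> None:
        sem = asyncio.Semaphore(2)  # cf. max_concurrent_runs=10 above
        await asyncio.gather(*(bounded_run(sem, n) for n in range(5)))

    asyncio.run(main())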
service_forge/workflow/workflow_factory.py
@@ -0,0 +1,227 @@
+ from omegaconf import OmegaConf
+ from typing import Callable, Awaitable, AsyncIterator, Any
+ from copy import deepcopy
+ from .workflow import Workflow
+ from .workflow_group import WorkflowGroup
+ from .node import Node
+ from .edge import Edge
+ from .port import Port, parse_port_name, create_workflow_input_port, create_sub_workflow_input_port, create_port
+ from .node import node_register
+ from .nodes import *
+ from .triggers import *
+ from .context import Context
+ from ..db.database import DatabaseManager
+
+ WORKFLOW_KEY_NAME = 'name'
+ WORKFLOW_KEY_NODES = 'nodes'
+ WORKFLOW_KEY_INPUTS = 'inputs'
+ WORKFLOW_KEY_OUTPUTS = 'outputs'
+
+ NODE_KEY_NAME = 'name'
+ NODE_KEY_TYPE = 'type'
+ NODE_KEY_ARGS = 'args'
+ NODE_KEY_OUTPUTS = 'outputs'
+ NODE_KEY_INPUT_PORTS = 'input_ports'
+ NODE_KEY_OUTPUT_PORTS = 'output_ports'
+ NODE_KEY_SUB_WORKFLOWS = 'sub_workflows'
+ NODE_KEY_SUB_WORKFLOWS_INPUT_PORTS = 'sub_workflows_input_ports'
+
+ PORT_KEY_NAME = 'name'
+ PORT_KEY_PORT = 'port'
+ PORT_KEY_VALUE = 'value'
+
+ def parse_argument(arg: Any, service_env: dict[str, Any] = None) -> Any:
+     if type(arg) == str and arg.startswith('<{') and arg.endswith('}>'):
+         key = arg[2:-2]
+         if key not in service_env:
+             raise ValueError(f"Key {key} not found in service env.")
+         return service_env[key]
+     return arg
+
+ def create_workflow(
+     config_path: str = None,
+     service_env: dict[str, Any] = None,
+     config: dict = None,
+     workflows: WorkflowGroup = None,
+     _handle_stream_output: Callable[[str, AsyncIterator[str]], Awaitable[None]] | None = None,
+     _handle_query_user: Callable[[str, str], Awaitable[str]] | None = None,
+     database_manager: DatabaseManager = None,
+ ) -> Workflow:
+     if config is None:
+         config = OmegaConf.to_object(OmegaConf.load(config_path))
+
+     if WORKFLOW_KEY_NAME not in config:
+         if config_path is None:
+             raise ValueError(f"{WORKFLOW_KEY_NAME} is required in workflow config in {config}.")
+         else:
+             raise ValueError(f"{WORKFLOW_KEY_NAME} is required in workflow config at {config_path}.")
+
+     workflow = Workflow(
+         name=config[WORKFLOW_KEY_NAME],
+         nodes=[],
+         input_ports=[],
+         output_ports=[],
+         _handle_stream_output=_handle_stream_output,
+         _handle_query_user=_handle_query_user,
+         database_manager=database_manager,
+     )
+
+     nodes: dict[str, Node] = {}
+
+     # Nodes
+     for node_config in config[WORKFLOW_KEY_NODES]:
+         params = {
+             "name": node_config[NODE_KEY_NAME],
+         }
+
+         node: Node = node_register.instance(node_config[NODE_KEY_TYPE], ignore_keys=['type'], kwargs=params)
+
+         # Context
+         node.context = Context(variables={})
+
+         # Input ports
+         if node_key_input_ports := node_config.get(NODE_KEY_INPUT_PORTS, None):
+             node.input_ports = [Port(**port_params) for port_params in node_key_input_ports]
+         else:
+             node.input_ports = deepcopy(node.DEFAULT_INPUT_PORTS)
+
+         for input_port in node.input_ports:
+             input_port.node = node
+
+         # Output ports
+         if node_key_output_ports := node_config.get(NODE_KEY_OUTPUT_PORTS, None):
+             node.output_ports = [Port(**port_params) for port_params in node_key_output_ports]
+         else:
+             node.output_ports = deepcopy(node.DEFAULT_OUTPUT_PORTS)
+
+         for output_port in node.output_ports:
+             output_port.node = node
+
+         # Sub workflows
+         if node_key_sub_workflows := node_config.get(NODE_KEY_SUB_WORKFLOWS, None):
+             sub_workflows: WorkflowGroup = WorkflowGroup(workflows=[])
+             for sub_workflow_config in node_key_sub_workflows:
+                 sub_workflow = workflows.get_workflow(sub_workflow_config['name'])
+                 sub_workflows.add_workflow(deepcopy(sub_workflow))
+             node.sub_workflows = sub_workflows
+
+         # Sub workflows input ports
+         if node_key_sub_network_input_ports := node_config.get(NODE_KEY_SUB_WORKFLOWS_INPUT_PORTS, None):
+             for sub_workflow_input_port_config in node_key_sub_network_input_ports:
+                 name = sub_workflow_input_port_config[PORT_KEY_NAME]
+                 sub_workflow_name, sub_workflow_port_name = parse_port_name(sub_workflow_input_port_config[PORT_KEY_PORT])
+                 sub_workflow = node.sub_workflows.get_workflow(sub_workflow_name)
+                 if sub_workflow is None:
+                     raise ValueError(f"{sub_workflow_name} is not a valid sub workflow.")
+                 sub_workflow_port = sub_workflow.get_input_port_by_name(sub_workflow_port_name)
+                 if sub_workflow_port is None:
+                     raise ValueError(f"{sub_workflow_port_name} is not a valid input port.")
+                 value = sub_workflow_input_port_config.get(PORT_KEY_VALUE, None)
+                 node.input_ports.append(create_sub_workflow_input_port(name=name, node=node, port=sub_workflow_port, value=value))
+
+         # Sub workflows output ports
+         ...
+
+         # Hooks
+         if _handle_query_user is None:
+             node.query_user = workflow.handle_query_user
+         else:
+             node.query_user = _handle_query_user
+
+         nodes[node_config[NODE_KEY_NAME]] = node
+
+     # Edges
+     for node_config in config[WORKFLOW_KEY_NODES]:
+         start_node = nodes[node_config[NODE_KEY_NAME]]
+         if NODE_KEY_OUTPUTS in node_config and node_config[NODE_KEY_OUTPUTS]:
+             for key, value in node_config[NODE_KEY_OUTPUTS].items():
+                 if value is None:
+                     continue
+
+                 if type(value) is str:
+                     value = [value]
+
+                 for edge_value in value:
+                     end_node_name, end_port_name = parse_port_name(edge_value)
+                     end_node = nodes[end_node_name]
+
+                     start_port = start_node.get_output_port_by_name(key)
+                     end_port = end_node.get_input_port_by_name(end_port_name)
+
+                     if start_port is None:
+                         raise ValueError(f"{key} is not a valid output port.")
+                     if end_port is None:
+                         raise ValueError(f"{end_port_name} is not a valid input port.")
+
+                     edge = Edge(start_node, end_node, start_port, end_port)
+
+                     start_node.output_edges.append(edge)
+                     end_node.input_edges.append(edge)
+
+     workflow.add_nodes(list(nodes.values()))
+
+     # Inputs
+     if workflow_key_inputs := config.get(WORKFLOW_KEY_INPUTS, None):
+         for port_config in workflow_key_inputs:
+             name = port_config[PORT_KEY_NAME]
+             node_name, node_port_name = parse_port_name(port_config[PORT_KEY_PORT])
+             if node_name not in nodes:
+                 raise ValueError(f"{node_name} is not a valid node.")
+             node = nodes[node_name]
+             port = node.get_input_port_by_name(node_port_name)
+             if port is None:
+                 raise ValueError(f"{node_port_name} is not a valid input port.")
+             value = port_config.get(PORT_KEY_VALUE, None)
+             workflow.input_ports.append(create_workflow_input_port(name=name, port=port, value=value))
+
+     # Outputs
+     if workflow_key_outputs := config.get(WORKFLOW_KEY_OUTPUTS, None):
+         for port_config in workflow_key_outputs:
+             name = port_config[PORT_KEY_NAME]
+             node_name, node_port_name = parse_port_name(port_config[PORT_KEY_PORT])
+             if node_name not in nodes:
+                 raise ValueError(f"{node_name} is not a valid node.")
+             node = nodes[node_name]
+             port = node.get_output_port_by_name(node_port_name)
+             if port is None:
+                 raise ValueError(f"{node_port_name} is not a valid output port.")
+             output_port = create_port(name=name, type=Any)
+             workflow.output_ports.append(output_port)
+             edge = Edge(node, None, port, output_port)
+             node.output_edges.append(edge)
+
+     for node_config in config[WORKFLOW_KEY_NODES]:
+         node = nodes[node_config[NODE_KEY_NAME]]
+         # Arguments
+         if node_key_args := node_config.get(NODE_KEY_ARGS, None):
+             for key, value in node_key_args.items():
+                 node.fill_input_by_name(key, parse_argument(value, service_env=service_env))
+
+     return workflow
+
+ def create_workflows(
+     config_path: str,
+     service_env: dict[str, Any] = None,
+     _handle_stream_output: Callable[[str, AsyncIterator[str]], Awaitable[None]] = None,
+     _handle_query_user: Callable[[str, str], Awaitable[str]] = None,
+     database_manager: DatabaseManager = None,
+ ) -> WorkflowGroup:
+     WORKFLOW_KEY_WORKFLOWS = 'workflows'
+     WORKFLOW_KEY_MAIN_WORKFLOW_NAME = 'main'
+
+     config = OmegaConf.to_object(OmegaConf.load(config_path))
+
+     if WORKFLOW_KEY_WORKFLOWS not in config:
+         workflow = create_workflow(
+             config_path,
+             service_env=service_env,
+             _handle_stream_output=_handle_stream_output,
+             _handle_query_user=_handle_query_user,
+             database_manager=database_manager,
+         )
+         return WorkflowGroup(workflows=[workflow], main_workflow_name=workflow.name)
+
+     workflows = WorkflowGroup(workflows=[], main_workflow_name=config.get(WORKFLOW_KEY_MAIN_WORKFLOW_NAME, None))
+     for workflow_config in config[WORKFLOW_KEY_WORKFLOWS]:
+         workflows.add_workflow(create_workflow(config=workflow_config, workflows=workflows, _handle_stream_output=_handle_stream_output, _handle_query_user=_handle_query_user))
+     return workflows
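Reviewer note: create_workflow accepts either a YAML path (loaded via OmegaConf) or an in-memory dict through config=. The sketch below shows the config shape implied by the keys above; the registered node-type names and the dotted 'node.port' reference format handled by parse_port_name are assumptions, since neither the node registry nor port.py appears in the hunks shown:

    from service_forge.workflow.workflow_factory import create_workflow

    config = {
        "name": "echo_workflow",
        "nodes": [
            {
                "name": "trig",
                "type": "OnceTrigger",                       # assumed registered type name
                "outputs": {"trigger": "printer.TRIGGER"},   # assumed 'node.port' reference format
            },
            {
                "name": "printer",
                "type": "PrintNode",                         # assumed registered type name
                "args": {"prefix": "<{PREFIX}>"},            # hypothetical arg; '<{KEY}>' is resolved
            },                                               # from service_env by parse_argument
        ],
    }

    workflow = create_workflow(config=config, service_env={"PREFIX": ">> "})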