service-forge 0.1.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of service-forge might be problematic. Click here for more details.

Files changed (75) hide show
  1. service_forge/api/deprecated_websocket_api.py +86 -0
  2. service_forge/api/deprecated_websocket_manager.py +425 -0
  3. service_forge/api/http_api.py +148 -0
  4. service_forge/api/http_api_doc.py +455 -0
  5. service_forge/api/kafka_api.py +126 -0
  6. service_forge/api/routers/service/__init__.py +4 -0
  7. service_forge/api/routers/service/service_router.py +137 -0
  8. service_forge/api/routers/websocket/websocket_manager.py +83 -0
  9. service_forge/api/routers/websocket/websocket_router.py +78 -0
  10. service_forge/api/task_manager.py +141 -0
  11. service_forge/db/__init__.py +1 -0
  12. service_forge/db/database.py +240 -0
  13. service_forge/llm/__init__.py +62 -0
  14. service_forge/llm/llm.py +56 -0
  15. service_forge/model/__init__.py +0 -0
  16. service_forge/model/websocket.py +13 -0
  17. service_forge/proto/foo_input.py +5 -0
  18. service_forge/service.py +288 -0
  19. service_forge/service_config.py +158 -0
  20. service_forge/sft/cli.py +91 -0
  21. service_forge/sft/cmd/config_command.py +67 -0
  22. service_forge/sft/cmd/deploy_service.py +123 -0
  23. service_forge/sft/cmd/list_tars.py +41 -0
  24. service_forge/sft/cmd/service_command.py +149 -0
  25. service_forge/sft/cmd/upload_service.py +36 -0
  26. service_forge/sft/config/injector.py +119 -0
  27. service_forge/sft/config/injector_default_files.py +131 -0
  28. service_forge/sft/config/sf_metadata.py +30 -0
  29. service_forge/sft/config/sft_config.py +153 -0
  30. service_forge/sft/file/__init__.py +0 -0
  31. service_forge/sft/file/ignore_pattern.py +80 -0
  32. service_forge/sft/file/sft_file_manager.py +107 -0
  33. service_forge/sft/kubernetes/kubernetes_manager.py +257 -0
  34. service_forge/sft/util/assert_util.py +25 -0
  35. service_forge/sft/util/logger.py +16 -0
  36. service_forge/sft/util/name_util.py +8 -0
  37. service_forge/sft/util/yaml_utils.py +57 -0
  38. service_forge/utils/__init__.py +0 -0
  39. service_forge/utils/default_type_converter.py +12 -0
  40. service_forge/utils/register.py +39 -0
  41. service_forge/utils/type_converter.py +99 -0
  42. service_forge/utils/workflow_clone.py +124 -0
  43. service_forge/workflow/__init__.py +1 -0
  44. service_forge/workflow/context.py +14 -0
  45. service_forge/workflow/edge.py +24 -0
  46. service_forge/workflow/node.py +184 -0
  47. service_forge/workflow/nodes/__init__.py +8 -0
  48. service_forge/workflow/nodes/control/if_node.py +29 -0
  49. service_forge/workflow/nodes/control/switch_node.py +28 -0
  50. service_forge/workflow/nodes/input/console_input_node.py +26 -0
  51. service_forge/workflow/nodes/llm/query_llm_node.py +41 -0
  52. service_forge/workflow/nodes/nested/workflow_node.py +28 -0
  53. service_forge/workflow/nodes/output/kafka_output_node.py +27 -0
  54. service_forge/workflow/nodes/output/print_node.py +29 -0
  55. service_forge/workflow/nodes/test/if_console_input_node.py +33 -0
  56. service_forge/workflow/nodes/test/time_consuming_node.py +62 -0
  57. service_forge/workflow/port.py +89 -0
  58. service_forge/workflow/trigger.py +24 -0
  59. service_forge/workflow/triggers/__init__.py +6 -0
  60. service_forge/workflow/triggers/a2a_api_trigger.py +255 -0
  61. service_forge/workflow/triggers/fast_api_trigger.py +169 -0
  62. service_forge/workflow/triggers/kafka_api_trigger.py +44 -0
  63. service_forge/workflow/triggers/once_trigger.py +20 -0
  64. service_forge/workflow/triggers/period_trigger.py +26 -0
  65. service_forge/workflow/triggers/websocket_api_trigger.py +184 -0
  66. service_forge/workflow/workflow.py +210 -0
  67. service_forge/workflow/workflow_callback.py +141 -0
  68. service_forge/workflow/workflow_event.py +15 -0
  69. service_forge/workflow/workflow_factory.py +246 -0
  70. service_forge/workflow/workflow_group.py +27 -0
  71. service_forge/workflow/workflow_type.py +52 -0
  72. service_forge-0.1.11.dist-info/METADATA +98 -0
  73. service_forge-0.1.11.dist-info/RECORD +75 -0
  74. service_forge-0.1.11.dist-info/WHEEL +4 -0
  75. service_forge-0.1.11.dist-info/entry_points.txt +2 -0
@@ -0,0 +1,169 @@
1
+ from __future__ import annotations
2
+ import uuid
3
+ import asyncio
4
+ import json
5
+ from loguru import logger
6
+ from service_forge.workflow.trigger import Trigger
7
+ from typing import AsyncIterator, Any
8
+ from fastapi import FastAPI, Request
9
+ from fastapi.responses import StreamingResponse
10
+ from service_forge.workflow.port import Port
11
+ from service_forge.utils.default_type_converter import type_converter
12
+ from service_forge.api.routers.websocket.websocket_manager import websocket_manager
13
+ from fastapi import HTTPException
14
+ from google.protobuf.message import Message
15
+ from google.protobuf.json_format import MessageToJson
16
+
17
class FastAPITrigger(Trigger):
    """Trigger that starts one workflow run per request on a FastAPI route.

    The route is described by the input ports (``app``, ``path``, ``method``,
    ``data_type``, ``is_stream``) and is registered the first time ``_run`` is
    iterated. Each request payload is converted with ``data_type(**payload)``
    and published on the ``data`` output port, together with ``user_id`` read
    from ``request.state``.
    """

    DEFAULT_INPUT_PORTS = [
        Port("app", FastAPI),
        Port("path", str),
        Port("method", str),
        Port("data_type", type),
        Port("is_stream", bool),
    ]

    DEFAULT_OUTPUT_PORTS = [
        Port("trigger", bool),
        Port("user_id", int),
        Port("data", Any),
    ]

    def __init__(self, name: str):
        super().__init__(name)
        self.events = {}
        self.is_setup_route = False

    @staticmethod
    def serialize_result(result: Any):
        """Render protobuf messages as JSON text; return any other value unchanged."""
        if isinstance(result, Message):
            return MessageToJson(
                result,
                preserving_proto_field_name=True
            )
        return result

    @staticmethod
    def _format_sse(data: str, event: str | None = None) -> str:
        """Frame ``data`` as a single server-sent event.

        The SSE format ends a field at CR, LF or CRLF. A multi-line payload
        (for example the indented text produced by ``MessageToJson``) must
        therefore be sent as one ``data:`` field per line. Clients join those
        fields back together with newlines.
        """
        fields = [] if event is None else [f"event: {event}"]
        normalized = data.replace("\r\n", "\n").replace("\r", "\n")
        fields.extend(f"data: {line}" for line in normalized.split("\n"))
        return "\n".join(fields) + "\n\n"

    async def _stream_events(self, task_id: uuid.UUID) -> AsyncIterator[str]:
        """Yield SSE frames for ``task_id`` until the run ends or reports an error.

        Both per-task queues are removed when the stream stops. This also
        happens when the client disconnects mid-stream and the generator is
        closed.
        """
        try:
            while True:
                item = await self.stream_queues[task_id].get()

                if item.is_error:
                    yield self._format_sse(json.dumps({'detail': str(item.result)}), event="error")
                    break

                if item.is_end:
                    # TODO: send the result?
                    break

                serialized = self.serialize_result(item.result)
                data = serialized if isinstance(serialized, str) else json.dumps(serialized)
                yield self._format_sse(data)

        except Exception as e:
            yield self._format_sse(json.dumps({'detail': str(e)}), event="error")
        finally:
            self.stream_queues.pop(task_id, None)
            self.result_queues.pop(task_id, None)

    async def handle_request(
        self,
        request: Request,
        data_type: type,
        extract_data_fn: callable[[Request], dict],
        is_stream: bool,
    ):
        """Turn one HTTP request into a workflow run and return its outcome.

        Args:
            request: The incoming FastAPI request.
            data_type: Constructed as ``data_type(**payload)`` to build the
                value published on the ``data`` output port.
            extract_data_fn: Async callable that is awaited with ``request``
                and returns the payload mapping.
            is_stream: When true, return a server-sent-events stream of the
                run's intermediate results. Otherwise wait for and return the
                final result.

        Raises:
            HTTPException: 422 if ``data_type`` rejects the payload. If the run
                reports an error, that error is re-raised as-is when it is
                already an ``HTTPException``, and wrapped as a 500 otherwise.
        """
        task_id = uuid.uuid4()

        body_data = await extract_data_fn(request)
        try:
            converted_data = data_type(**body_data)
        except Exception as e:
            # Construction failures come from the client's payload (wrong
            # fields, non-object JSON, ...), so report them as a validation
            # error rather than an internal server error.
            raise HTTPException(status_code=422, detail=f"Invalid request data: {e}") from e

        client_id = (
            body_data.get("client_id")
            or request.query_params.get("client_id")
            or request.headers.get("X-Client-ID")
        )
        if client_id:
            workflow_name = getattr(self.workflow, "name", "Unknown")
            steps = len(self.workflow.nodes) if hasattr(self.workflow, "nodes") else 1
            websocket_manager.create_task_with_client(task_id, client_id, workflow_name, steps)

        # Queues are only created once the request is known to be valid, so
        # rejected requests leave nothing behind. They are registered before the
        # trigger is enqueued so the run always has somewhere to publish.
        self.result_queues[task_id] = asyncio.Queue()
        if is_stream:
            self.stream_queues[task_id] = asyncio.Queue()

        self.trigger_queue.put_nowait({
            "id": task_id,
            "user_id": getattr(request.state, "user_id", None),
            "data": converted_data,
        })

        if is_stream:
            return StreamingResponse(
                self._stream_events(task_id),
                media_type="text/event-stream",
                headers={
                    "Cache-Control": "no-cache",
                    "Connection": "keep-alive",
                    "X-Accel-Buffering": "no",
                }
            )

        try:
            result = await self.result_queues[task_id].get()
        finally:
            # Also runs when the client disconnects and this coroutine is
            # cancelled while waiting; otherwise the queue would leak.
            self.result_queues.pop(task_id, None)

        if result.is_error:
            if isinstance(result.result, HTTPException):
                raise result.result
            raise HTTPException(status_code=500, detail=str(result.result))

        if isinstance(result.result, Message):
            # MessageToJson yields JSON *text*. Decode it so FastAPI returns a
            # JSON object instead of re-encoding it as a JSON string.
            return json.loads(self.serialize_result(result.result))
        return result.result

    def _setup_route(self, app: FastAPI, path: str, method: str, data_type: type, is_stream: bool) -> None:
        """Register the FastAPI route that feeds requests into this trigger.

        GET requests take their payload from the query string. The other verbs
        read a JSON body, and an empty body is treated as ``{}``. ``method`` is
        matched case-insensitively.

        Raises:
            ValueError: if ``method`` is not GET, POST, PUT or DELETE.
        """
        verb = method.upper() if isinstance(method, str) else method
        if verb not in ("GET", "POST", "PUT", "DELETE"):
            raise ValueError(f"Invalid method {method}")

        async def get_data(request: Request) -> dict:
            return dict(request.query_params)

        async def body_data(request: Request) -> dict:
            raw = await request.body()
            if not raw:
                return {}
            try:
                return json.loads(raw.decode("utf-8"))
            except (UnicodeDecodeError, json.JSONDecodeError) as e:
                # A malformed body is a client error, not a server failure.
                raise HTTPException(status_code=400, detail=f"Invalid JSON body: {e}") from e

        extractor = get_data if verb == "GET" else body_data

        async def handler(request: Request):
            return await self.handle_request(request, data_type, extractor, is_stream)

        # FastAPI exposes one decorator factory per verb: app.get, app.post, ...
        getattr(app, verb.lower())(path)(handler)

    async def _run(self, app: FastAPI, path: str, method: str, data_type: type, is_stream: bool = False) -> AsyncIterator[bool]:
        """Register the route once, then yield one task id per queued request.

        Errors while publishing a request are logged with their traceback and
        skipped, so the trigger keeps serving later requests.
        """
        if not self.is_setup_route:
            self._setup_route(app, path, method, data_type, is_stream)
            self.is_setup_route = True

        while True:
            try:
                trigger = await self.trigger_queue.get()
                self.prepare_output_edges(self.get_output_port_by_name('user_id'), trigger['user_id'])
                self.prepare_output_edges(self.get_output_port_by_name('data'), trigger['data'])
                yield self.trigger(trigger['id'])
            except Exception as e:
                logger.exception(f"Error in FastAPITrigger._run: {e}")
                continue
169
+
@@ -0,0 +1,44 @@
1
+ from __future__ import annotations
2
+ import uuid
3
+ from typing import Any
4
+ from service_forge.workflow.trigger import Trigger
5
+ from typing import AsyncIterator
6
+ from service_forge.workflow.port import Port
7
+ from service_forge.api.kafka_api import KafkaApp
8
+
9
class KafkaAPITrigger(Trigger):
    """Trigger that starts one workflow run per message consumed from a Kafka topic.

    The consumer is registered on ``app`` the first time ``_run`` is iterated.
    Each received message is published on the ``data`` output port.
    """

    DEFAULT_INPUT_PORTS = [
        Port("app", KafkaApp),
        Port("topic", str),
        Port("data_type", type),
        Port("group_id", str),
    ]

    DEFAULT_OUTPUT_PORTS = [
        Port("trigger", bool),
        Port("data", Any),
    ]

    def __init__(self, name: str):
        super().__init__(name)
        self.events = {}
        self.is_setup_kafka_input = False

    def _setup_kafka_input(self, app: KafkaApp, topic: str, data_type: type, group_id: str) -> None:
        """Register a handler via ``app.kafka_input`` that enqueues every message.

        Each message is put on ``self.trigger_queue`` under a fresh task id.
        The decoding of ``data`` is up to ``KafkaApp`` (not visible here).
        """
        @app.kafka_input(topic, data_type, group_id)
        async def handle_message(data):
            task_id = uuid.uuid4()
            self.trigger_queue.put_nowait({
                "id": task_id,
                "data": data,
            })

    async def _run(self, app: KafkaApp, topic: str, data_type: type, group_id: str) -> AsyncIterator[bool]:
        """Register the consumer once, then yield one task id per received message.

        A failure while publishing a single message is logged and skipped,
        matching the HTTP and WebSocket triggers. Otherwise it would end this
        generator, and presumably the workflow's ``async for`` over the
        trigger along with it.
        """
        from loguru import logger  # same logger the sibling triggers use

        if not self.is_setup_kafka_input:
            self._setup_kafka_input(app, topic, data_type, group_id)
            self.is_setup_kafka_input = True

        while True:
            try:
                trigger = await self.trigger_queue.get()
                self.prepare_output_edges(self.get_output_port_by_name('data'), trigger['data'])
                yield self.trigger(trigger['id'])
            except Exception as e:
                logger.exception(f"Error in KafkaAPITrigger._run: {e}")
                continue
@@ -0,0 +1,20 @@
1
+ from __future__ import annotations
2
+ from service_forge.workflow.node import Node
3
+ from service_forge.workflow.port import Port
4
+ from service_forge.workflow.trigger import Trigger
5
+ from typing import AsyncIterator
6
+ import uuid
7
+
8
class OnceTrigger(Trigger):
    """Input-less trigger that starts exactly one workflow run."""

    DEFAULT_INPUT_PORTS = []

    DEFAULT_OUTPUT_PORTS = [Port("trigger", bool)]

    def __init__(self, name: str):
        super().__init__(name)

    async def _run(self) -> AsyncIterator[bool]:
        """Yield a single freshly generated task id, passed through ``self.trigger``."""
        run_id = uuid.uuid4()
        yield self.trigger(run_id)
@@ -0,0 +1,26 @@
1
+ from __future__ import annotations
2
+ import asyncio
3
+ from service_forge.workflow.node import Node
4
+ from service_forge.workflow.port import Port
5
+ from service_forge.workflow.trigger import Trigger
6
+ from typing import AsyncIterator
7
+ import uuid
8
+
9
class PeriodTrigger(Trigger):
    """Trigger that starts ``times`` runs, sleeping ``period`` seconds before each."""

    DEFAULT_INPUT_PORTS = [
        Port("TRIGGER", bool),
        Port("period", float),
        Port("times", int),
    ]

    DEFAULT_OUTPUT_PORTS = [
        Port("trigger", bool),
    ]

    def __init__(self, name: str):
        super().__init__(name)

    async def _run(self, times: int, period: float) -> AsyncIterator[bool]:
        """Yield one task id every ``period`` seconds, ``times`` times in total.

        Each id goes through ``self.trigger`` (defined on ``Trigger``), exactly
        like every other trigger in this package. The workflow uses the yielded
        value as the task id.
        """
        for _ in range(times):
            await asyncio.sleep(period)
            yield self.trigger(uuid.uuid4())
@@ -0,0 +1,184 @@
1
+ from __future__ import annotations
2
+ import uuid
3
+ import asyncio
4
+ import json
5
+ from loguru import logger
6
+ from service_forge.workflow.trigger import Trigger
7
+ from typing import AsyncIterator, Any
8
+ from fastapi import FastAPI, WebSocket, WebSocketDisconnect
9
+ from service_forge.workflow.port import Port
10
+ from google.protobuf.message import Message
11
+ from google.protobuf.json_format import MessageToJson
12
+
13
class WebSocketAPITrigger(Trigger):
    """Trigger that starts one workflow run per JSON message on a WebSocket route.

    Each message is converted with ``data_type(**message)`` and published on the
    ``data`` output port. The run's streamed output is sent back over the same
    socket as ``stream`` / ``stream_end`` / ``stream_error`` frames.
    """

    DEFAULT_INPUT_PORTS = [
        Port("app", FastAPI),
        Port("path", str),
        Port("data_type", type),
    ]

    DEFAULT_OUTPUT_PORTS = [
        Port("trigger", bool),
        Port("data", Any),
    ]

    # Strong references to running ``handle_stream_output`` tasks. asyncio only
    # keeps weak references to tasks, so a fire-and-forget task can be
    # garbage-collected mid-stream. Kept at class level so no Task objects
    # become part of instance state.
    _stream_tasks: set[asyncio.Task] = set()

    def __init__(self, name: str):
        super().__init__(name)
        self.events = {}
        self.is_setup_websocket = False

    @staticmethod
    def serialize_result(result: Any):
        """Render protobuf messages as JSON text; return any other value unchanged."""
        if isinstance(result, Message):
            return MessageToJson(
                result,
                preserving_proto_field_name=True
            )
        return result

    @classmethod
    def _to_payload(cls, result: Any) -> Any:
        """Serialize ``result`` for a frame's ``data`` field.

        When serialization yields a string, it is parsed as JSON if possible.
        Otherwise the raw string is used.
        """
        serialized = cls.serialize_result(result)
        if isinstance(serialized, str):
            try:
                return json.loads(serialized)
            except json.JSONDecodeError:
                return serialized
        return serialized

    async def handle_stream_output(
        self,
        websocket: WebSocket,
        task_id: uuid.UUID,
    ):
        """Forward items from ``self.stream_queues[task_id]`` to ``websocket``.

        Forwarding stops after an error item or an end item. Any failure is
        reported to the client on a best-effort basis. Both per-task queues are
        released when forwarding stops.
        """
        try:
            while True:
                item = await self.stream_queues[task_id].get()

                if item.is_error:
                    error_response = {
                        "type": "stream_error",
                        "task_id": str(task_id),
                        "detail": str(item.result)
                    }
                    await websocket.send_text(json.dumps(error_response))
                    break

                if item.is_end:
                    # Stream ended; attach the final result when there is one.
                    end_response = {
                        "type": "stream_end",
                        "task_id": str(task_id)
                    }
                    if item.result is not None:
                        end_response["data"] = self._to_payload(item.result)
                    await websocket.send_text(json.dumps(end_response))
                    break

                stream_response = {
                    "type": "stream",
                    "task_id": str(task_id),
                    "data": self._to_payload(item.result)
                }
                await websocket.send_text(json.dumps(stream_response))
        except Exception as e:
            logger.error(f"Error handling stream output for task {task_id}: {e}")
            error_response = {
                "type": "stream_error",
                "task_id": str(task_id),
                "detail": str(e)
            }
            try:
                await websocket.send_text(json.dumps(error_response))
            except Exception:
                # Best effort only: the socket is most likely already closed.
                pass
        finally:
            self.stream_queues.pop(task_id, None)
            # Nothing in this trigger reads the result queue; release it too so
            # it does not accumulate one queue per message.
            self.result_queues.pop(task_id, None)

    async def handle_websocket_message(
        self,
        websocket: WebSocket,
        data_type: type,
        message_data: dict,
    ):
        """Validate one client message and enqueue a workflow run for it.

        If ``data_type(**message_data)`` fails, an ``{"error": ...}`` frame is
        sent and nothing is enqueued. Otherwise a background task starts
        forwarding the run's stream output, and the run is put on
        ``self.trigger_queue``.
        """
        try:
            converted_data = data_type(**message_data)
        except Exception as e:
            error_msg = {"error": f"Failed to convert data: {str(e)}"}
            await websocket.send_text(json.dumps(error_msg))
            return

        task_id = uuid.uuid4()
        # Created only after conversion succeeds, so rejected messages don't
        # leave orphaned queues behind.
        self.result_queues[task_id] = asyncio.Queue()
        self.stream_queues[task_id] = asyncio.Queue()

        # Always start a background task to handle stream output. Keep a strong
        # reference until it finishes.
        forwarder = asyncio.create_task(self.handle_stream_output(websocket, task_id))
        self._stream_tasks.add(forwarder)
        forwarder.add_done_callback(self._stream_tasks.discard)

        self.trigger_queue.put_nowait({
            "id": task_id,
            "data": converted_data,
        })
        # The stream handler sends every message, including stream_end, when
        # the workflow completes.

    def _setup_websocket(self, app: FastAPI, path: str, data_type: type) -> None:
        """Register a WebSocket endpoint at ``path`` that triggers a run per JSON message."""
        async def websocket_handler(websocket: WebSocket):
            await websocket.accept()

            try:
                while True:
                    data = await websocket.receive_text()
                    try:
                        message = json.loads(data)
                        await self.handle_websocket_message(
                            websocket,
                            data_type,
                            message
                        )
                    except json.JSONDecodeError:
                        error_msg = {"error": "Invalid JSON format"}
                        await websocket.send_text(json.dumps(error_msg))
                    except Exception as e:
                        logger.error(f"Error handling websocket message: {e}")
                        error_msg = {"error": str(e)}
                        await websocket.send_text(json.dumps(error_msg))
            except WebSocketDisconnect:
                logger.info("WebSocket client disconnected")
            except Exception as e:
                logger.error(f"WebSocket connection error: {e}")

        app.websocket(path)(websocket_handler)

    async def _run(self, app: FastAPI, path: str, data_type: type) -> AsyncIterator[bool]:
        """Register the endpoint once, then yield one task id per queued message.

        Errors while publishing a message are logged with their traceback and
        skipped.
        """
        if not self.is_setup_websocket:
            self._setup_websocket(app, path, data_type)
            self.is_setup_websocket = True

        while True:
            try:
                trigger = await self.trigger_queue.get()
                self.prepare_output_edges(self.get_output_port_by_name('data'), trigger['data'])
                yield self.trigger(trigger['id'])
            except Exception as e:
                logger.exception(f"Error in WebSocketAPITrigger._run: {e}")
                continue
184
+
@@ -0,0 +1,210 @@
1
+ from __future__ import annotations
2
+ import traceback
3
+ import asyncio
4
+ import uuid
5
+ from typing import AsyncIterator, Awaitable, Callable, Any
6
+ from loguru import logger
7
+ from copy import deepcopy
8
+ from .node import Node
9
+ from .port import Port
10
+ from .trigger import Trigger
11
+ from .edge import Edge
12
+ from ..db.database import DatabaseManager
13
+ from ..utils.workflow_clone import workflow_clone
14
+ from .workflow_callback import WorkflowCallback, BuiltinWorkflowCallback, CallbackEvent
15
+
16
class Workflow:
    """A graph of nodes started by exactly one trigger node.

    ``run`` consumes task ids from the trigger node. It executes each task on
    a clone produced by ``_clone``, with at most ``max_concurrent_runs`` clones
    running at once (guarded by ``run_semaphore``). Registered callbacks are
    notified of workflow- and node-level events.
    """

    # Strong references to runs scheduled by ``trigger()``. asyncio keeps only
    # weak references to tasks, so an unreferenced run could be
    # garbage-collected before it finishes. Kept at class level so Task
    # objects never become part of an instance's (possibly copied) state.
    _background_tasks: set[asyncio.Task] = set()

    def __init__(
        self,
        name: str,
        description: str,
        nodes: list[Node],
        input_ports: list[Port],
        output_ports: list[Port],
        _handle_stream_output: Callable[[str, AsyncIterator[str]], Awaitable[None]] | None = None,  # deprecated
        _handle_query_user: Callable[[str, str], Awaitable[str]] | None = None,
        database_manager: DatabaseManager | None = None,
        max_concurrent_runs: int = 10,
        callbacks: list[WorkflowCallback] | None = None,

        # for run
        task_id: uuid.UUID | None = None,
        real_trigger_node: Trigger | None = None,
    ) -> None:
        """Create a workflow.

        Args:
            name: Workflow name; included in log messages.
            description: Free-form description.
            nodes: Nodes of the workflow. ``run``/``trigger`` require exactly
                one of them to be a ``Trigger``.
            input_ports: Workflow-level input ports; non-None values are fed to
                their inner node at the start of a run.
            output_ports: Workflow-level output ports; prepared ones form the
                run's result.
            _handle_stream_output: Deprecated; only stored.
            _handle_query_user: Only stored on the instance here.
            database_manager: Optional ``DatabaseManager``; only stored here.
            max_concurrent_runs: Maximum number of runs executing at once.
            callbacks: Callbacks notified of workflow/node events. A new empty
                list is created when omitted.
            task_id: Presumably set on per-run clones (see ``_clone``).
            real_trigger_node: Presumably set on per-run clones (see ``_clone``).
        """
        self.name = name
        self.description = description
        self.nodes = nodes
        self.ready_nodes: list[Node] = []
        self.input_ports = input_ports
        self.output_ports = output_ports
        self._handle_stream_output = _handle_stream_output
        self._handle_query_user = _handle_query_user
        self.after_trigger_workflow = None
        self.result_port = Port("result", Any)
        self.database_manager = database_manager
        self.max_concurrent_runs = max_concurrent_runs
        self.run_semaphore = asyncio.Semaphore(max_concurrent_runs)
        # A fresh list per instance. A shared ``[]`` default would make
        # register_callback() on one workflow affect every workflow that was
        # created without explicit callbacks.
        self.callbacks = [] if callbacks is None else callbacks
        self.task_id = task_id
        self.real_trigger_node = real_trigger_node
        self._validate()

    def register_callback(self, callback: WorkflowCallback) -> None:
        """Add ``callback`` to the callbacks notified by ``call_callbacks``."""
        self.callbacks.append(callback)

    def unregister_callback(self, callback: WorkflowCallback) -> None:
        """Remove ``callback``; raises ``ValueError`` if it is not registered."""
        self.callbacks.remove(callback)

    async def call_callbacks(self, callback_type: CallbackEvent, *args, **kwargs) -> None:
        """Await the handler matching ``callback_type`` on every registered callback.

        Unknown event types are ignored.
        """
        for callback in self.callbacks:
            if callback_type == CallbackEvent.ON_WORKFLOW_START:
                await callback.on_workflow_start(*args, **kwargs)
            elif callback_type == CallbackEvent.ON_WORKFLOW_END:
                await callback.on_workflow_end(*args, **kwargs)
            elif callback_type == CallbackEvent.ON_NODE_START:
                await callback.on_node_start(*args, **kwargs)
            elif callback_type == CallbackEvent.ON_NODE_END:
                await callback.on_node_end(*args, **kwargs)
            elif callback_type == CallbackEvent.ON_NODE_STREAM_OUTPUT:
                await callback.on_node_stream_output(*args, **kwargs)

    def add_nodes(self, nodes: list[Node]) -> None:
        """Attach ``nodes`` to this workflow and append them to ``self.nodes``."""
        for node in nodes:
            node.workflow = self
        self.nodes.extend(nodes)

    def remove_nodes(self, nodes: list[Node]) -> None:
        """Remove ``nodes`` from ``self.nodes``; raises ``ValueError`` for unknown nodes."""
        for node in nodes:
            self.nodes.remove(node)

    def load_config(self) -> None:
        # Not implemented.
        ...

    def _validate(self) -> None:
        # TODO: check that the node graph is a DAG (not implemented yet).
        ...

    def get_input_port_by_name(self, name: str) -> Port | None:
        """Return the workflow input port called ``name``, or ``None``."""
        for input_port in self.input_ports:
            if input_port.name == name:
                return input_port
        return None

    def get_output_port_by_name(self, name: str) -> Port | None:
        """Return the workflow output port called ``name``, or ``None``."""
        for output_port in self.output_ports:
            if output_port.name == name:
                return output_port
        return None

    def get_trigger_node(self) -> Trigger:
        """Return the single ``Trigger`` among ``self.nodes``.

        Raises:
            ValueError: if there is no trigger node or more than one.
        """
        trigger_nodes = [node for node in self.nodes if isinstance(node, Trigger)]
        if not trigger_nodes:
            raise ValueError("No trigger nodes found in workflow.")
        if len(trigger_nodes) > 1:
            raise ValueError("Multiple trigger nodes found in workflow.")
        return trigger_nodes[0]

    async def _run_node_with_callbacks(self, node: Node) -> None:
        """Run ``node`` between ON_NODE_START and ON_NODE_END callbacks.

        ``node.run()`` may return an async iterator (forwarded as stream
        output), a coroutine (awaited) or anything else (ignored).
        ON_NODE_END fires even if the node raises.
        """
        await self.call_callbacks(CallbackEvent.ON_NODE_START, node=node)

        try:
            result = node.run()
            if hasattr(result, '__anext__'):
                await self.handle_node_stream_output(node, result)
            elif asyncio.iscoroutine(result):
                await result
        finally:
            await self.call_callbacks(CallbackEvent.ON_NODE_END, node=node)

    def _collect_output(self) -> Any:
        """Build the run result from the output ports.

        Returns ``None`` when there are no output ports. With a single output
        port, returns its value, or ``None`` if it was never prepared. With
        several ports, returns a ``{name: value}`` dict of the prepared ones.
        """
        if not self.output_ports:
            return None
        if len(self.output_ports) == 1:
            port = self.output_ports[0]
            return port.value if port.is_prepared else None
        return {port.name: port.value for port in self.output_ports if port.is_prepared}

    async def run_after_trigger(self) -> Any:
        """Execute the workflow body after the trigger has fired.

        Ready nodes are run in waves. Every node currently in
        ``self.ready_nodes`` is started concurrently, and the wave completes
        before the next batch is taken. Nodes are expected to be appended to
        ``ready_nodes`` by port triggering (not visible here — assumption).

        Returns:
            The collected output (see ``_collect_output``), which is also
            passed to the ON_WORKFLOW_END callbacks.

        Raises:
            Exception: whatever a node raised; it is logged and re-raised, and
                ON_WORKFLOW_END is not called in that case.
        """
        logger.info(f"Running workflow: {self.name}")

        await self.call_callbacks(CallbackEvent.ON_WORKFLOW_START, workflow=self)

        self.ready_nodes = []
        for edge in self.get_trigger_node().output_edges:
            edge.end_port.trigger()

        try:
            for input_port in self.input_ports:
                if input_port.value is not None:
                    input_port.port.node.fill_input(input_port.port, input_port.value)

            for node in self.nodes:
                if not node.AUTO_FILL_INPUT_PORTS:
                    continue
                # Ports fed by an edge must not be overwritten by auto-fill
                # defaults. Compute the set once per node instead of per key.
                connected = {edge.end_port.name for edge in node.input_edges}
                for key in node.AUTO_FILL_INPUT_PORTS:
                    if key[0] not in connected:
                        node.fill_input_by_name(key[0], key[1])

            while self.ready_nodes:
                nodes = self.ready_nodes.copy()
                self.ready_nodes = []

                tasks = [asyncio.create_task(self._run_node_with_callbacks(node)) for node in nodes]
                await asyncio.gather(*tasks)

        except Exception as e:
            error_msg = f"Error in run_after_trigger: {str(e)}"
            logger.error(error_msg)
            raise

        result = self._collect_output()
        await self.call_callbacks(CallbackEvent.ON_WORKFLOW_END, workflow=self, output=result)
        return result

    async def _run(self, task_id: uuid.UUID, trigger_node: Trigger) -> None:
        """Execute one task on a fresh clone, bounded by ``run_semaphore``.

        Errors are logged with their traceback and not propagated.
        """
        async with self.run_semaphore:
            try:
                new_workflow = self._clone(task_id, trigger_node)
                await new_workflow.run_after_trigger()
                # TODO: clear new_workflow

            except Exception as e:
                error_msg = f"Error running workflow: {str(e)}, {traceback.format_exc()}"
                logger.error(error_msg)

    async def run(self):
        """Start a run for each task id yielded by the trigger; wait for all of them.

        Finished runs are dropped from the pending set as they complete, so a
        trigger that never stops (e.g. the HTTP trigger's endless loop) does
        not accumulate completed tasks forever.
        """
        pending: set[asyncio.Task] = set()
        trigger = self.get_trigger_node()

        async for task_id in trigger.run():
            task = asyncio.create_task(self._run(task_id, trigger))
            pending.add(task)
            task.add_done_callback(pending.discard)

        if pending:
            await asyncio.gather(*pending)

    def trigger(self, trigger_name: str, **kwargs) -> uuid.UUID:
        """Fire the workflow manually and return the new task id.

        Each keyword names an output port of the trigger node. Its value is
        pushed along that port's edges before the run is scheduled.
        ``trigger_name`` is accepted for compatibility but unused; the single
        trigger node is always used. Must be called while an event loop is
        running.

        Raises:
            ValueError: if a keyword does not name an output port of the
                trigger node.
        """
        trigger = self.get_trigger_node()
        task_id = uuid.uuid4()
        for key, value in kwargs.items():
            # prepare_output_edges is called with a Port everywhere else, so
            # resolve the name first.
            port = trigger.get_output_port_by_name(key)
            if port is None:
                raise ValueError(f"Trigger has no output port named {key!r}")
            trigger.prepare_output_edges(port, value)
        task = asyncio.create_task(self._run(task_id, trigger))
        self._background_tasks.add(task)
        task.add_done_callback(self._background_tasks.discard)
        return task_id

    async def handle_node_stream_output(
        self,
        node: Node,
        stream: AsyncIterator[Any],
    ) -> None:
        """Report every item of ``stream`` via ON_NODE_STREAM_OUTPUT callbacks."""
        async for data in stream:
            await self.call_callbacks(CallbackEvent.ON_NODE_STREAM_OUTPUT, node=node, output=data)

    # TODO: refactor this
    async def handle_query_user(self, node_name: str, prompt: str) -> str:
        """Prompt on stdin (in a worker thread) and return the typed answer."""
        return await asyncio.to_thread(input, f"[{node_name}] {prompt}: ")

    def _clone(self, task_id: uuid.UUID, trigger_node: Trigger) -> Workflow:
        """Return a per-task copy of this workflow via ``workflow_clone``."""
        return workflow_clone(self, task_id, trigger_node)