service-forge 0.1.28__py3-none-any.whl → 0.1.39__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of service-forge might be problematic. Click here for more details.
- service_forge/__init__.py +0 -0
- service_forge/api/deprecated_websocket_api.py +91 -33
- service_forge/api/deprecated_websocket_manager.py +70 -53
- service_forge/api/http_api.py +127 -53
- service_forge/api/kafka_api.py +113 -25
- service_forge/api/routers/meta_api/meta_api_router.py +57 -0
- service_forge/api/routers/service/service_router.py +42 -6
- service_forge/api/routers/trace/trace_router.py +326 -0
- service_forge/api/routers/websocket/websocket_router.py +56 -1
- service_forge/api/service_studio.py +9 -0
- service_forge/execution_context.py +106 -0
- service_forge/frontend/static/assets/CreateNewNodeDialog-DkrEMxSH.js +1 -0
- service_forge/frontend/static/assets/CreateNewNodeDialog-DwFcBiGp.css +1 -0
- service_forge/frontend/static/assets/EditorSidePanel-BNVms9Fq.css +1 -0
- service_forge/frontend/static/assets/EditorSidePanel-DZbB3ILL.js +1 -0
- service_forge/frontend/static/assets/FeedbackPanel-CC8HX7Yo.js +1 -0
- service_forge/frontend/static/assets/FeedbackPanel-ClgniIVk.css +1 -0
- service_forge/frontend/static/assets/FormattedCodeViewer.vue_vue_type_script_setup_true_lang-BNuI1NCs.js +1 -0
- service_forge/frontend/static/assets/NodeDetailWrapper-BqFFM7-r.js +1 -0
- service_forge/frontend/static/assets/NodeDetailWrapper-pZBxv3J0.css +1 -0
- service_forge/frontend/static/assets/TestRunningDialog-D0GrCoYs.js +1 -0
- service_forge/frontend/static/assets/TestRunningDialog-dhXOsPgH.css +1 -0
- service_forge/frontend/static/assets/TracePanelWrapper-B9zvDSc_.js +1 -0
- service_forge/frontend/static/assets/TracePanelWrapper-BiednCrq.css +1 -0
- service_forge/frontend/static/assets/WorkflowEditor-CcaGGbko.js +3 -0
- service_forge/frontend/static/assets/WorkflowEditor-CmasOOYK.css +1 -0
- service_forge/frontend/static/assets/WorkflowList-Copuwi-a.css +1 -0
- service_forge/frontend/static/assets/WorkflowList-LrRJ7B7h.js +1 -0
- service_forge/frontend/static/assets/WorkflowStudio-CthjgII2.css +1 -0
- service_forge/frontend/static/assets/WorkflowStudio-FCyhGD4y.js +2 -0
- service_forge/frontend/static/assets/api-BDer3rj7.css +1 -0
- service_forge/frontend/static/assets/api-DyiqpKJK.js +1 -0
- service_forge/frontend/static/assets/code-editor-DBSql_sc.js +12 -0
- service_forge/frontend/static/assets/el-collapse-item-D4LG0FJ0.css +1 -0
- service_forge/frontend/static/assets/el-empty-D4ZqTl4F.css +1 -0
- service_forge/frontend/static/assets/el-form-item-BWkJzdQ_.css +1 -0
- service_forge/frontend/static/assets/el-input-D6B3r8CH.css +1 -0
- service_forge/frontend/static/assets/el-select-B0XIb2QK.css +1 -0
- service_forge/frontend/static/assets/el-tag-DljBBxJR.css +1 -0
- service_forge/frontend/static/assets/element-ui-D3x2y3TA.js +12 -0
- service_forge/frontend/static/assets/elkjs-Dm5QV7uy.js +24 -0
- service_forge/frontend/static/assets/highlightjs-D4ATuRwX.js +3 -0
- service_forge/frontend/static/assets/index-BMvodlwc.js +2 -0
- service_forge/frontend/static/assets/index-CjSe8i2q.css +1 -0
- service_forge/frontend/static/assets/js-yaml-yTPt38rv.js +32 -0
- service_forge/frontend/static/assets/time-DKCKV6Ug.js +1 -0
- service_forge/frontend/static/assets/ui-components-DQ7-U3pr.js +1 -0
- service_forge/frontend/static/assets/vue-core-DL-LgTX0.js +1 -0
- service_forge/frontend/static/assets/vue-flow-Dn7R8GPr.js +39 -0
- service_forge/frontend/static/index.html +16 -0
- service_forge/frontend/static/vite.svg +1 -0
- service_forge/model/meta_api/__init__.py +0 -0
- service_forge/model/meta_api/schema.py +29 -0
- service_forge/model/trace.py +82 -0
- service_forge/service.py +32 -11
- service_forge/service_config.py +14 -0
- service_forge/sft/config/injector.py +32 -2
- service_forge/sft/config/injector_default_files.py +12 -0
- service_forge/sft/config/sf_metadata.py +5 -0
- service_forge/sft/config/sft_config.py +18 -0
- service_forge/telemetry.py +66 -0
- service_forge/workflow/node.py +266 -27
- service_forge/workflow/triggers/fast_api_trigger.py +61 -28
- service_forge/workflow/triggers/websocket_api_trigger.py +31 -10
- service_forge/workflow/workflow.py +87 -10
- service_forge/workflow/workflow_callback.py +24 -2
- service_forge/workflow/workflow_factory.py +13 -0
- {service_forge-0.1.28.dist-info → service_forge-0.1.39.dist-info}/METADATA +4 -1
- service_forge-0.1.39.dist-info/RECORD +134 -0
- service_forge-0.1.28.dist-info/RECORD +0 -85
- {service_forge-0.1.28.dist-info → service_forge-0.1.39.dist-info}/WHEEL +0 -0
- {service_forge-0.1.28.dist-info → service_forge-0.1.39.dist-info}/entry_points.txt +0 -0
service_forge/workflow/node.py
CHANGED
|
@@ -1,51 +1,78 @@
|
|
|
1
1
|
from __future__ import annotations
|
|
2
2
|
|
|
3
|
-
|
|
4
|
-
|
|
3
|
+
import inspect
|
|
4
|
+
import json
|
|
5
5
|
import uuid
|
|
6
|
+
from abc import ABC, abstractmethod
|
|
7
|
+
from dataclasses import asdict, is_dataclass
|
|
8
|
+
from typing import (
|
|
9
|
+
TYPE_CHECKING,
|
|
10
|
+
Any,
|
|
11
|
+
AsyncIterator,
|
|
12
|
+
Awaitable,
|
|
13
|
+
Callable,
|
|
14
|
+
Optional,
|
|
15
|
+
Union,
|
|
16
|
+
cast,
|
|
17
|
+
)
|
|
18
|
+
|
|
6
19
|
from loguru import logger
|
|
7
|
-
from
|
|
8
|
-
from
|
|
9
|
-
from .
|
|
10
|
-
from
|
|
20
|
+
from opentelemetry import context as otel_context_api
|
|
21
|
+
from opentelemetry import trace
|
|
22
|
+
from opentelemetry.trace import SpanKind
|
|
23
|
+
from pydantic import BaseModel
|
|
24
|
+
|
|
11
25
|
from ..db.database import DatabaseManager, PostgresDatabase, MongoDatabase, RedisDatabase
|
|
12
26
|
from ..utils.workflow_clone import node_clone
|
|
27
|
+
from ..execution_context import (
|
|
28
|
+
ExecutionContext,
|
|
29
|
+
get_current_context,
|
|
30
|
+
reset_current_context,
|
|
31
|
+
set_current_context,
|
|
32
|
+
)
|
|
33
|
+
from ..utils.register import Register
|
|
34
|
+
from .context import Context
|
|
35
|
+
from .edge import Edge
|
|
36
|
+
from .port import Port
|
|
13
37
|
from .workflow_callback import CallbackEvent
|
|
14
38
|
|
|
15
39
|
if TYPE_CHECKING:
|
|
16
40
|
from .workflow import Workflow
|
|
17
41
|
|
|
42
|
+
|
|
18
43
|
class Node(ABC):
|
|
19
44
|
DEFAULT_INPUT_PORTS: list[Port] = []
|
|
20
45
|
DEFAULT_OUTPUT_PORTS: list[Port] = []
|
|
21
46
|
|
|
22
|
-
CLASS_NOT_REQUIRED_TO_REGISTER = [
|
|
47
|
+
CLASS_NOT_REQUIRED_TO_REGISTER = ["Node"]
|
|
23
48
|
AUTO_FILL_INPUT_PORTS = []
|
|
24
49
|
|
|
25
50
|
def __init__(
    self,
    name: str,
    context: Optional[Context] = None,
    input_edges: Optional[list[Edge]] = None,
    output_edges: Optional[list[Edge]] = None,
    input_ports: list[Port] = DEFAULT_INPUT_PORTS,
    output_ports: list[Port] = DEFAULT_OUTPUT_PORTS,
    query_user: Optional[Callable[[str, str], Awaitable[str]]] = None,
) -> None:
    """Set up a node's static wiring and its per-run state.

    Args:
        name: Node name.
        context: Optional Context stored as ``self.context``.
        input_edges: Incoming edges; ``None`` means a fresh empty list.
        output_edges: Outgoing edges; ``None`` means a fresh empty list.
        input_ports: Input ports. Defaults to the class-level
            ``DEFAULT_INPUT_PORTS`` list. That list object is shared, not copied.
        output_ports: Output ports. Defaults to ``DEFAULT_OUTPUT_PORTS``, with
            the same sharing caveat.
        query_user: Optional async callback invoked as
            ``query_user(node_name, prompt)``.
    """
    # Local import; presumably kept out of module scope to avoid an import cycle.
    from .workflow_group import WorkflowGroup

    if input_edges is None:
        input_edges = []
    if output_edges is None:
        output_edges = []

    self.name = name
    self.input_edges = input_edges
    self.output_edges = output_edges
    self.input_ports = input_ports
    self.output_ports = output_ports
    self.workflow: Optional[Workflow] = None  # attached later by the owning workflow
    self.query_user = query_user
    self.sub_workflows: Optional[WorkflowGroup] = None

    # Per-run state.
    self.context = context
    self.input_variables: dict[Port, Any] = {}
    self.num_activated_input_edges = 0
    self._tracer = trace.get_tracer("service_forge.node")
|
|
49
76
|
|
|
50
77
|
@property
|
|
51
78
|
def default_postgres_database(self) -> PostgresDatabase | None:
|
|
@@ -61,6 +88,7 @@ class Node(ABC):
|
|
|
61
88
|
|
|
62
89
|
@property
def database_manager(self) -> DatabaseManager:
    """Return the DatabaseManager of the workflow this node is attached to.

    Asserts that a workflow is attached. The check only runs when assertions
    are enabled.
    """
    owner = self.workflow
    assert owner
    return owner.database_manager
|
|
65
93
|
|
|
66
94
|
@property
|
|
@@ -71,22 +99,28 @@ class Node(ABC):
|
|
|
71
99
|
# do NOT use deepcopy here
|
|
72
100
|
# self.bak_context = deepcopy(self.context)
|
|
73
101
|
# TODO: what if the value changes after backup?
|
|
74
|
-
self.bak_input_variables = {
|
|
102
|
+
self.bak_input_variables = {
|
|
103
|
+
port: value for port, value in self.input_variables.items()
|
|
104
|
+
}
|
|
75
105
|
self.bak_num_activated_input_edges = self.num_activated_input_edges
|
|
76
106
|
|
|
77
107
|
def reset(self) -> None:
    """Restore the input state saved in the ``bak_*`` attributes.

    ``input_variables`` becomes a fresh shallow copy of ``bak_input_variables``.
    Later mutations therefore do not alter the saved snapshot.
    """
    snapshot = self.bak_input_variables
    self.input_variables = dict(snapshot)
    self.num_activated_input_edges = self.bak_num_activated_input_edges
|
|
81
113
|
|
|
82
114
|
def __init_subclass__(cls, **kwargs: Any) -> None:
    """Register each new Node subclass in ``node_register`` under its class name.

    Class names listed in ``Node.CLASS_NOT_REQUIRED_TO_REGISTER`` are not
    registered.

    Keyword arguments given in the class statement (``class X(Node, flag=1)``)
    are forwarded to the next ``__init_subclass__`` in the MRO. This keeps
    cooperative mixins working. Previously such arguments were rejected
    outright.
    """
    if cls.__name__ not in Node.CLASS_NOT_REQUIRED_TO_REGISTER:
        # TODO: Register currently stores class objects; clarify Register typing vs instance usage.
        node_register.register(cls.__name__, cls)
    super().__init_subclass__(**kwargs)
|
|
86
119
|
|
|
87
|
-
def _query_user(self, prompt: str) -> Awaitable[str]:
    """Send ``prompt`` to the ``query_user`` callback, tagged with this node's name.

    Returns the callback's awaitable un-awaited; the caller awaits it.
    Asserts that a callback is configured.
    """
    ask = self.query_user
    assert ask
    return ask(self.name, prompt)
|
|
89
|
-
|
|
123
|
+
|
|
90
124
|
def variables_to_params(self) -> dict[str, Any]:
|
|
91
125
|
params = {port.name: self.input_variables[port] for port in self.input_variables.keys() if not port.is_extended_generated}
|
|
92
126
|
for port in self.input_variables.keys():
|
|
@@ -99,6 +133,7 @@ class Node(ABC):
|
|
|
99
133
|
|
|
100
134
|
def is_trigger(self) -> bool:
    """Return True when this node is an instance of ``Trigger``."""
    # Imported lazily, presumably to avoid a circular import with the trigger module.
    from . import trigger as trigger_module

    return isinstance(self, trigger_module.Trigger)
|
|
103
138
|
|
|
104
139
|
# TODO: maybe add a function before the run function?
|
|
@@ -106,23 +141,219 @@ class Node(ABC):
|
|
|
106
141
|
@abstractmethod
async def _run(self, **kwargs) -> Union[None, AsyncIterator]:
    """Node-specific work that every concrete subclass must implement.

    ``run`` passes the node's input parameters as keyword arguments. It
    accepts either a coroutine or an async-generator implementation.
    """
|
|
144
|
+
|
|
145
|
+
async def clear(self) -> None:
    """Cleanup hook that subclasses may override; the base version does nothing.

    NOTE(review): presumably invoked on workflow teardown — confirm with callers.
    """
|
|
109
147
|
|
|
110
148
|
def run(self) -> Union[None, AsyncIterator]:
    """Execute ``_run`` with the current inputs inside an OpenTelemetry span.

    Input variables whose port name starts with an uppercase letter are
    discarded first. The rest become keyword arguments via
    ``variables_to_params``.

    The span's parent is chosen as follows:
    - the trace context of the current ExecutionContext, when one is set;
    - otherwise, the active OpenTelemetry context.

    Returns:
        For an async-generator ``_run``: an async iterator over its items.
        For a coroutine ``_run``: a coroutine that resolves to its result.
        Otherwise: the result of calling ``_run`` directly.
    """
    # NOTE(review): uppercase-named ports look reserved/internal — confirm intent.
    for port in list(self.input_variables):
        # ``name[:1]`` rather than ``name[0]`` so an empty port name cannot raise IndexError.
        if port and port.name[:1].isupper():
            del self.input_variables[port]
    params = self.variables_to_params()

    # Nothing supplies a task id at this level yet. The helpers still accept
    # one so it can be threaded through later. The old guarded
    # ``params["task_id"]`` injection could never run and has been removed.
    task_id: uuid.UUID | None = None

    base_context = get_current_context()
    if base_context and base_context.trace_context:
        parent_context = base_context.trace_context
    else:
        parent_context = otel_context_api.get_current()
    span_name = f"Node {self.name}"
    helper_args = (params, task_id, base_context, parent_context, span_name)

    if inspect.isasyncgenfunction(self._run):
        return self._run_async_generator(*helper_args)
    if inspect.iscoroutinefunction(self._run):
        return self._run_async(*helper_args)
    return self._run_sync(*helper_args)
|
|
174
|
+
|
|
175
|
+
def _build_execution_context(
    self, base_context: ExecutionContext | None, span: trace.Span
) -> ExecutionContext:
    """Derive the ExecutionContext this node runs under.

    - Captures the current OpenTelemetry context and ``span``.
    - ``state`` is the very same object as ``base_context.state``, not a copy;
      it is an empty dict when there is no base context.
    - ``metadata`` is a copy of the base metadata with ``node`` (this node's
      name) and ``workflow_name`` (None without a workflow) set on top.
    """
    if base_context:
        inherited_state = base_context.state
        metadata = dict(base_context.metadata)
    else:
        inherited_state = {}
        metadata = {}
    metadata["node"] = self.name
    metadata["workflow_name"] = getattr(self.workflow, "name", None)
    return ExecutionContext(
        trace_context=otel_context_api.get_current(),
        span=span,
        state=inherited_state,
        metadata=metadata,
    )
|
|
188
|
+
|
|
189
|
+
@staticmethod
def _serialize_for_trace(value: Any, max_length: int = 4000) -> tuple[str, bool]:
    """Render ``value`` as JSON text for use as a span attribute or event value.

    Conversion rules:
    - Objects exposing ``model_dump()`` or ``dict()`` (e.g. pydantic models)
      are converted with the first one that succeeds.
    - Dataclass *instances* go through ``asdict``.
    - Dicts, lists and tuples are converted recursively.
    - Anything else JSON cannot encode is stringified by ``default=str``.
    - Dict keys JSON does not accept are converted with ``str``.

    Tracing is best-effort. If conversion or encoding still fails (for
    example on a self-referencing container), ``repr(value)`` is used instead.
    This way a node never fails because of its own telemetry.

    Args:
        value: Object to serialize.
        max_length: Maximum number of characters kept.

    Returns:
        ``(text, truncated)``, where ``truncated`` is True when ``text`` was cut
        to ``max_length`` characters.
    """

    def _json_key(key: Any) -> Any:
        # json.dumps accepts only these key types and never applies ``default=``
        # to keys, so anything else would raise TypeError.
        if key is None or isinstance(key, (str, int, float, bool)):
            return key
        return str(key)

    def _normalize(val: Any) -> Any:
        for method_name in ("model_dump", "dict"):
            if hasattr(val, method_name):
                try:
                    return getattr(val, method_name)()
                except Exception:
                    # E.g. an unbound method found on a class object; try the next form.
                    pass
        # is_dataclass() is also True for dataclass *types*, which asdict() rejects.
        if is_dataclass(val) and not isinstance(val, type):
            return asdict(val)
        if isinstance(val, dict):
            return {_json_key(k): _normalize(v) for k, v in val.items()}
        if isinstance(val, (list, tuple)):
            return [_normalize(v) for v in val]
        return val

    try:
        serialized = json.dumps(_normalize(value), ensure_ascii=False, default=str)
    except Exception:
        # Best-effort fallback: telemetry must not break the traced node.
        try:
            serialized = repr(value)
        except Exception:
            serialized = f"<unserializable {type(value).__name__}>"

    if len(serialized) > max_length:
        return serialized[:max_length], True
    return serialized, False
|
|
217
|
+
|
|
218
|
+
def _set_span_attributes(
    self, span: trace.Span, params: dict[str, Any], task_id: uuid.UUID | None
) -> None:
    """Annotate ``span`` with this node's identity and its serialized inputs.

    Attributes set:
    - ``node.name``, always;
    - ``workflow.name``, when a workflow is attached;
    - ``workflow.task_id``, when ``task_id`` is given;
    - ``node.input_keys``, the comma-joined parameter names;
    - ``node.inputs``, the serialized parameters;
    - ``node.inputs_truncated``, only when that serialization was cut short.
    """
    identity: list[tuple[str, Any]] = [("node.name", self.name)]
    if self.workflow is not None:
        identity.append(("workflow.name", self.workflow.name))
    if task_id is not None:
        identity.append(("workflow.task_id", str(task_id)))
    identity.append(("node.input_keys", ",".join(params)))
    for attr_name, attr_value in identity:
        span.set_attribute(attr_name, attr_value)

    inputs_text, was_truncated = self._serialize_for_trace(params)
    span.set_attribute("node.inputs", inputs_text)
    if was_truncated:
        span.set_attribute("node.inputs_truncated", True)
|
|
231
|
+
|
|
232
|
+
def _record_output(self, span: trace.Span, output: Any) -> None:
    """Record the type name and serialized value of ``output`` on ``span``.

    The type is reported as ``"None"`` (not ``"NoneType"``) for a None output.
    ``node.output_truncated`` is set only when the serialization was cut short.
    """
    type_label = "None" if output is None else type(output).__name__
    span.set_attribute("node.output_type", type_label)
    output_text, was_truncated = self._serialize_for_trace(output)
    span.set_attribute("node.output", output_text)
    if was_truncated:
        span.set_attribute("node.output_truncated", True)
|
|
240
|
+
|
|
241
|
+
async def _run_async(
    self,
    params: dict[str, Any],
    task_id: uuid.UUID | None,
    base_context: ExecutionContext | None,
    parent_context: otel_context_api.Context,
    span_name: str,
) -> Any:
    """Run a coroutine ``_run`` inside its own INTERNAL span and record its result.

    While ``_run`` executes, the node's ExecutionContext is installed as the
    current context. It is reset in ``finally``, even if ``_run`` raises.
    If ``_run`` returns an awaitable, it is awaited before the result is
    recorded.
    """
    span_cm = self._tracer.start_as_current_span(
        span_name, context=parent_context, kind=SpanKind.INTERNAL
    )
    with span_cm as span:
        self._set_span_attributes(span, params, task_id)
        token = set_current_context(self._build_execution_context(base_context, span))
        try:
            outcome = self._run(**params)
            if inspect.isawaitable(outcome):
                outcome = await outcome
            self._record_output(span, outcome)
        finally:
            reset_current_context(token)
    return outcome
|
|
265
|
+
|
|
266
|
+
async def _run_async_generator(
    self,
    params: dict[str, Any],
    task_id: uuid.UUID | None,
    base_context: ExecutionContext | None,
    parent_context: otel_context_api.Context,
    span_name: str,
) -> AsyncIterator[Any]:
    """Iterate an async-generator ``_run`` and trace every item it yields.

    Trigger nodes:
    - They are long-running generators, so each yielded item gets its own
      span. This avoids one span covering every request.
    - A yielded ``uuid.UUID`` is treated as the task id. It is recorded as the
      span's ``workflow.task_id``. If the node has a ``task_contexts`` mapping
      with an entry for that id, the stored trace context becomes the span's
      parent.

    Other nodes share one span for the whole iteration, with one
    ``node.output_item`` event per yielded item.

    In both cases the node's ExecutionContext is set right before each
    ``yield`` and reset in a ``finally`` around it. For non-triggers it is
    set once, around the whole loop.
    """

    def _trace_item(span: trace.Span, item: Any) -> None:
        # Attach a serialized (possibly truncated) copy of ``item`` as a span event.
        serialized_item, item_truncated = self._serialize_for_trace(item)
        span.add_event(
            "node.output_item",
            {"value": serialized_item, "truncated": item_truncated},
        )

    if self.is_trigger():
        async for item in self._run(**params):
            item_task_id = task_id
            item_parent_context = parent_context
            if isinstance(item, uuid.UUID):
                # The UUID is the key into ``task_contexts``, i.e. the task id;
                # previously it was never recorded on the span.
                item_task_id = item
                if hasattr(self, "task_contexts"):
                    stored_context = getattr(self, "task_contexts").get(item)
                    if stored_context is not None:
                        item_parent_context = stored_context
            with self._tracer.start_as_current_span(
                span_name,
                context=item_parent_context,
                kind=SpanKind.INTERNAL,
            ) as span:
                self._set_span_attributes(span, params, item_task_id)
                span.set_attribute("node.output_type", "async_generator")
                _trace_item(span, item)
                exec_ctx = self._build_execution_context(base_context, span)
                token = set_current_context(exec_ctx)
                try:
                    yield item
                finally:
                    reset_current_context(token)
        return

    with self._tracer.start_as_current_span(
        span_name,
        context=parent_context,
        kind=SpanKind.INTERNAL,
    ) as span:
        self._set_span_attributes(span, params, task_id)
        span.set_attribute("node.output_type", "async_generator")
        exec_ctx = self._build_execution_context(base_context, span)
        token = set_current_context(exec_ctx)
        try:
            async for item in self._run(**params):
                _trace_item(span, item)
                yield item
        finally:
            reset_current_context(token)
|
|
326
|
+
|
|
327
|
+
def _run_sync(
    self,
    params: dict[str, Any],
    task_id: uuid.UUID | None,
    base_context: ExecutionContext | None,
    parent_context: otel_context_api.Context,
    span_name: str,
) -> Any:
    """Call a non-async ``_run`` inside its own INTERNAL span and record its result.

    NOTE(review): suppose ``_run`` is wrapped so that
    ``inspect.iscoroutinefunction`` no longer detects it. Then the coroutine it
    returns is recorded and returned un-awaited, after this span has ended.
    Confirm that no such wrappers are used.
    """
    span_cm = self._tracer.start_as_current_span(
        span_name, context=parent_context, kind=SpanKind.INTERNAL
    )
    with span_cm as span:
        self._set_span_attributes(span, params, task_id)
        token = set_current_context(self._build_execution_context(base_context, span))
        try:
            outcome = self._run(**params)
            self._record_output(span, outcome)
        finally:
            reset_current_context(token)
    return outcome
|
|
349
|
+
|
|
350
|
+
def get_input_port_by_name(self, name: str) -> Optional[Port]:
    """Return the first input port whose ``name`` equals ``name``, or None."""
    matches = (candidate for candidate in self.input_ports if candidate.name == name)
    return next(matches, None)
|
|
123
355
|
|
|
124
|
-
def get_output_port_by_name(self, name: str) -> Port:
|
|
125
|
-
# TODO: add warning if port is extended
|
|
356
|
+
def get_output_port_by_name(self, name: str) -> Optional[Port]:
|
|
126
357
|
for port in self.output_ports:
|
|
127
358
|
if port.name == name:
|
|
128
359
|
return port
|
|
@@ -148,9 +379,9 @@ class Node(ABC):
|
|
|
148
379
|
self.try_create_extended_input_port(port_name)
|
|
149
380
|
port = self.get_input_port_by_name(port_name)
|
|
150
381
|
if port is None:
|
|
151
|
-
raise ValueError(f
|
|
382
|
+
raise ValueError(f"{port_name} is not a valid input port.")
|
|
152
383
|
self.fill_input(port, value)
|
|
153
|
-
|
|
384
|
+
|
|
154
385
|
def fill_input(self, port: Port, value: Any) -> None:
|
|
155
386
|
port.activate(value)
|
|
156
387
|
|
|
@@ -168,7 +399,7 @@ class Node(ABC):
|
|
|
168
399
|
for output_edge in self.output_edges:
|
|
169
400
|
if output_edge.start_port == port:
|
|
170
401
|
output_edge.end_port.prepare(data)
|
|
171
|
-
|
|
402
|
+
|
|
172
403
|
def trigger_output_edges(self, port: Port) -> None:
|
|
173
404
|
if isinstance(port, str):
|
|
174
405
|
port = self.get_output_port_by_name(port)
|
|
@@ -178,7 +409,14 @@ class Node(ABC):
|
|
|
178
409
|
|
|
179
410
|
# TODO: the result is outputed to the trigger now, maybe we should add a new function to output the result to the workflow
|
|
180
411
|
def output_to_workflow(self, data: Any) -> None:
    """Hand ``data`` to the workflow's ``_handle_workflow_output(node_name, data)``.

    If the node has no workflow, or the workflow has no such handler, a
    warning is logged and nothing is dispatched.
    """
    owner = self.workflow
    if not owner or not hasattr(owner, "_handle_workflow_output"):
        logger.warning("Workflow output handler not set; skipping output dispatch.")
        return
    dispatch: Callable[[str, Any], None] = getattr(owner, "_handle_workflow_output")
    dispatch(self.name, data)
|
|
182
420
|
|
|
183
421
|
def extended_output_name(self, name: str, index: int) -> str:
    """Build the name of an extended output port: ``"<name>_<index>"``."""
    return "_".join((name, str(index)))
|
|
@@ -189,4 +427,5 @@ class Node(ABC):
|
|
|
189
427
|
async def stream_output(self, data: Any) -> None:
    """Fire ``CallbackEvent.ON_NODE_STREAM_OUTPUT`` on the workflow for one streamed chunk.

    The callbacks receive this node as ``node`` and the chunk as ``output``.
    """
    owner = self.workflow
    await owner.call_callbacks(
        CallbackEvent.ON_NODE_STREAM_OUTPUT,
        node=self,
        output=data,
    )
|
|
191
429
|
|
|
192
|
-
|
|
430
|
+
|
|
431
|
+
# Registry of Node subclasses keyed by class name; Node.__init_subclass__ fills it
# as subclasses are defined (names in CLASS_NOT_REQUIRED_TO_REGISTER are skipped).
node_register = Register[Node]()
|
|
@@ -4,7 +4,7 @@ import asyncio
|
|
|
4
4
|
import json
|
|
5
5
|
from loguru import logger
|
|
6
6
|
from service_forge.workflow.trigger import Trigger
|
|
7
|
-
from typing import AsyncIterator, Any
|
|
7
|
+
from typing import AsyncIterator, Any, Optional
|
|
8
8
|
from fastapi import FastAPI, Request
|
|
9
9
|
from fastapi.responses import StreamingResponse
|
|
10
10
|
from service_forge.workflow.port import Port
|
|
@@ -13,6 +13,7 @@ from service_forge.api.routers.websocket.websocket_manager import websocket_mana
|
|
|
13
13
|
from fastapi import HTTPException
|
|
14
14
|
from google.protobuf.message import Message
|
|
15
15
|
from google.protobuf.json_format import MessageToJson
|
|
16
|
+
from opentelemetry import context as otel_context_api
|
|
16
17
|
|
|
17
18
|
class FastAPITrigger(Trigger):
|
|
18
19
|
DEFAULT_INPUT_PORTS = [
|
|
@@ -28,12 +29,14 @@ class FastAPITrigger(Trigger):
|
|
|
28
29
|
Port("user_id", int),
|
|
29
30
|
Port("token", str),
|
|
30
31
|
Port("data", Any),
|
|
32
|
+
Port("path_params", dict),
|
|
31
33
|
]
|
|
32
34
|
|
|
33
35
|
def __init__(self, name: str):
|
|
34
36
|
super().__init__(name)
|
|
35
37
|
self.events = {}
|
|
36
38
|
self.is_setup_route = False
|
|
39
|
+
self.task_contexts: dict[uuid.UUID, otel_context_api.Context] = {}
|
|
37
40
|
self.app = None
|
|
38
41
|
self.route_path = None
|
|
39
42
|
self.route_method = None
|
|
@@ -41,23 +44,44 @@ class FastAPITrigger(Trigger):
|
|
|
41
44
|
@staticmethod
def serialize_result(result: Any):
    """Convert protobuf messages to JSON text; return any other value unchanged."""
    if not isinstance(result, Message):
        return result
    # Keep the field names from the .proto file instead of lowerCamelCase JSON names.
    return MessageToJson(result, preserving_proto_field_name=True)
|
|
49
49
|
|
|
50
|
+
def _normalize_result_or_raise(self, result: Any):
    """Turn a workflow result into a response payload, or raise HTTPException.

    ``result`` is either a wrapper exposing ``is_error`` and ``result``, or a
    bare value.

    Failure handling: an error value that is an ``HTTPException`` is
    re-raised as is. Any other error is wrapped in
    ``HTTPException(status_code=500)`` with ``str(error)`` as detail. For a
    bare value, "error" means an ``Exception`` instance.

    Successful values are passed through ``serialize_result``.
    """
    # TODO: review the merged logic.
    if hasattr(result, "is_error") and hasattr(result, "result"):
        failed = result.is_error
        payload = result.result
        if not failed:
            return self.serialize_result(payload)
        error = payload
    elif isinstance(result, Exception):
        error = result
    else:
        return self.serialize_result(result)

    if isinstance(error, HTTPException):
        raise error
    raise HTTPException(status_code=500, detail=str(error))
|
|
65
|
+
|
|
50
66
|
async def handle_request(
|
|
51
67
|
self,
|
|
52
68
|
request: Request,
|
|
53
69
|
data_type: type,
|
|
54
70
|
extract_data_fn: callable[[Request], dict],
|
|
55
71
|
is_stream: bool,
|
|
72
|
+
path_params: Optional[dict] = None,
|
|
56
73
|
):
|
|
57
74
|
task_id = uuid.uuid4()
|
|
58
75
|
self.result_queues[task_id] = asyncio.Queue()
|
|
59
76
|
|
|
77
|
+
# parse trace context
|
|
78
|
+
trace_ctx = otel_context_api.get_current()
|
|
79
|
+
self.task_contexts[task_id] = trace_ctx
|
|
80
|
+
|
|
60
81
|
body_data = await extract_data_fn(request)
|
|
82
|
+
# Merge path parameters into body_data (path params take precedence)
|
|
83
|
+
if path_params:
|
|
84
|
+
body_data = {**body_data, **path_params}
|
|
61
85
|
converted_data = data_type(**body_data)
|
|
62
86
|
|
|
63
87
|
client_id = (
|
|
@@ -70,11 +94,16 @@ class FastAPITrigger(Trigger):
|
|
|
70
94
|
steps = len(self.workflow.nodes) if hasattr(self.workflow, "nodes") else 1
|
|
71
95
|
websocket_manager.create_task_with_client(task_id, client_id, workflow_name, steps)
|
|
72
96
|
|
|
97
|
+
# trigger_queue with trace_context, used in _run()
|
|
98
|
+
logger.info(f'user_id {getattr(request.state, "user_id", None)} token {getattr(request.state, "auth_token", None)}')
|
|
99
|
+
|
|
73
100
|
self.trigger_queue.put_nowait({
|
|
74
101
|
"id": task_id,
|
|
75
102
|
"user_id": getattr(request.state, "user_id", None),
|
|
76
103
|
"token": getattr(request.state, "auth_token", None),
|
|
77
104
|
"data": converted_data,
|
|
105
|
+
"trace_context": trace_ctx,
|
|
106
|
+
"path_params": path_params,
|
|
78
107
|
})
|
|
79
108
|
|
|
80
109
|
if is_stream:
|
|
@@ -85,26 +114,23 @@ class FastAPITrigger(Trigger):
|
|
|
85
114
|
while True:
|
|
86
115
|
item = await self.stream_queues[task_id].get()
|
|
87
116
|
|
|
88
|
-
if item
|
|
117
|
+
if getattr(item, "is_error", False):
|
|
89
118
|
yield f"event: error\ndata: {json.dumps({'detail': str(item.result)})}\n\n"
|
|
90
119
|
break
|
|
91
|
-
|
|
92
|
-
if item
|
|
93
|
-
# TODO: send the result?
|
|
120
|
+
|
|
121
|
+
if getattr(item, "is_end", False):
|
|
94
122
|
break
|
|
95
123
|
|
|
96
124
|
# TODO: modify
|
|
97
125
|
serialized = self.serialize_result(item.result)
|
|
98
|
-
if isinstance(serialized, str)
|
|
99
|
-
data = serialized
|
|
100
|
-
else:
|
|
101
|
-
data = json.dumps(serialized)
|
|
102
|
-
|
|
126
|
+
data = serialized if isinstance(serialized, str) else json.dumps(serialized)
|
|
103
127
|
yield f"data: {data}\n\n"
|
|
104
|
-
|
|
128
|
+
|
|
105
129
|
except Exception as e:
|
|
106
130
|
yield f"event: error\ndata: {json.dumps({'detail': str(e)})}\n\n"
|
|
107
131
|
finally:
|
|
132
|
+
self.stream_queues.pop(task_id, None)
|
|
133
|
+
|
|
108
134
|
if task_id in self.stream_queues:
|
|
109
135
|
del self.stream_queues[task_id]
|
|
110
136
|
if task_id in self.result_queues:
|
|
@@ -117,19 +143,13 @@ class FastAPITrigger(Trigger):
|
|
|
117
143
|
"Cache-Control": "no-cache",
|
|
118
144
|
"Connection": "keep-alive",
|
|
119
145
|
"X-Accel-Buffering": "no",
|
|
120
|
-
}
|
|
146
|
+
},
|
|
121
147
|
)
|
|
122
|
-
else:
|
|
123
|
-
result = await self.result_queues[task_id].get()
|
|
124
|
-
del self.result_queues[task_id]
|
|
125
148
|
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
129
|
-
|
|
130
|
-
raise HTTPException(status_code=500, detail=str(result.result))
|
|
131
|
-
|
|
132
|
-
return self.serialize_result(result.result)
|
|
149
|
+
# 非流式:等待结果
|
|
150
|
+
result = await self.result_queues[task_id].get()
|
|
151
|
+
self.result_queues.pop(task_id, None)
|
|
152
|
+
return self._normalize_result_or_raise(result)
|
|
133
153
|
|
|
134
154
|
def _setup_route(self, app: FastAPI, path: str, method: str, data_type: type, is_stream: bool) -> None:
|
|
135
155
|
async def get_data(request: Request) -> dict:
|
|
@@ -144,7 +164,10 @@ class FastAPITrigger(Trigger):
|
|
|
144
164
|
extractor = get_data if method == "GET" else body_data
|
|
145
165
|
|
|
146
166
|
async def handler(request: Request):
|
|
147
|
-
|
|
167
|
+
# Get path parameters from FastAPI request
|
|
168
|
+
# request.path_params is always available in FastAPI and contains path parameters
|
|
169
|
+
path_params = dict(request.path_params)
|
|
170
|
+
return await self.handle_request(request, data_type, extractor, is_stream, path_params)
|
|
148
171
|
|
|
149
172
|
# Save route information for cleanup
|
|
150
173
|
self.app = app
|
|
@@ -162,7 +185,14 @@ class FastAPITrigger(Trigger):
|
|
|
162
185
|
else:
|
|
163
186
|
raise ValueError(f"Invalid method {method}")
|
|
164
187
|
|
|
165
|
-
async def _run(
|
|
188
|
+
async def _run(
|
|
189
|
+
self,
|
|
190
|
+
app: FastAPI,
|
|
191
|
+
path: str,
|
|
192
|
+
method: str,
|
|
193
|
+
data_type: type,
|
|
194
|
+
is_stream: bool = False,
|
|
195
|
+
) -> AsyncIterator[bool]:
|
|
166
196
|
if not self.is_setup_route:
|
|
167
197
|
self._setup_route(app, path, method, data_type, is_stream)
|
|
168
198
|
self.is_setup_route = True
|
|
@@ -170,9 +200,12 @@ class FastAPITrigger(Trigger):
|
|
|
170
200
|
while True:
|
|
171
201
|
try:
|
|
172
202
|
trigger = await self.trigger_queue.get()
|
|
203
|
+
if trace_ctx := trigger.get("trace_context"):
|
|
204
|
+
self.task_contexts[trigger["id"]] = trace_ctx
|
|
173
205
|
self.prepare_output_edges(self.get_output_port_by_name('user_id'), trigger['user_id'])
|
|
174
206
|
self.prepare_output_edges(self.get_output_port_by_name('token'), trigger['token'])
|
|
175
207
|
self.prepare_output_edges(self.get_output_port_by_name('data'), trigger['data'])
|
|
208
|
+
self.prepare_output_edges(self.get_output_port_by_name('path_params'), trigger['path_params'])
|
|
176
209
|
yield self.trigger(trigger['id'])
|
|
177
210
|
except Exception as e:
|
|
178
211
|
logger.error(f"Error in FastAPITrigger._run: {e}")
|