service-forge 0.1.18__py3-none-any.whl → 0.1.39__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of service-forge might be problematic.
- service_forge/__init__.py +0 -0
- service_forge/api/deprecated_websocket_api.py +91 -33
- service_forge/api/deprecated_websocket_manager.py +70 -53
- service_forge/api/http_api.py +205 -55
- service_forge/api/kafka_api.py +113 -25
- service_forge/api/routers/meta_api/meta_api_router.py +57 -0
- service_forge/api/routers/service/service_router.py +42 -6
- service_forge/api/routers/trace/trace_router.py +326 -0
- service_forge/api/routers/websocket/websocket_router.py +69 -1
- service_forge/api/service_studio.py +9 -0
- service_forge/db/database.py +17 -0
- service_forge/execution_context.py +106 -0
- service_forge/frontend/static/assets/CreateNewNodeDialog-DkrEMxSH.js +1 -0
- service_forge/frontend/static/assets/CreateNewNodeDialog-DwFcBiGp.css +1 -0
- service_forge/frontend/static/assets/EditorSidePanel-BNVms9Fq.css +1 -0
- service_forge/frontend/static/assets/EditorSidePanel-DZbB3ILL.js +1 -0
- service_forge/frontend/static/assets/FeedbackPanel-CC8HX7Yo.js +1 -0
- service_forge/frontend/static/assets/FeedbackPanel-ClgniIVk.css +1 -0
- service_forge/frontend/static/assets/FormattedCodeViewer.vue_vue_type_script_setup_true_lang-BNuI1NCs.js +1 -0
- service_forge/frontend/static/assets/NodeDetailWrapper-BqFFM7-r.js +1 -0
- service_forge/frontend/static/assets/NodeDetailWrapper-pZBxv3J0.css +1 -0
- service_forge/frontend/static/assets/TestRunningDialog-D0GrCoYs.js +1 -0
- service_forge/frontend/static/assets/TestRunningDialog-dhXOsPgH.css +1 -0
- service_forge/frontend/static/assets/TracePanelWrapper-B9zvDSc_.js +1 -0
- service_forge/frontend/static/assets/TracePanelWrapper-BiednCrq.css +1 -0
- service_forge/frontend/static/assets/WorkflowEditor-CcaGGbko.js +3 -0
- service_forge/frontend/static/assets/WorkflowEditor-CmasOOYK.css +1 -0
- service_forge/frontend/static/assets/WorkflowList-Copuwi-a.css +1 -0
- service_forge/frontend/static/assets/WorkflowList-LrRJ7B7h.js +1 -0
- service_forge/frontend/static/assets/WorkflowStudio-CthjgII2.css +1 -0
- service_forge/frontend/static/assets/WorkflowStudio-FCyhGD4y.js +2 -0
- service_forge/frontend/static/assets/api-BDer3rj7.css +1 -0
- service_forge/frontend/static/assets/api-DyiqpKJK.js +1 -0
- service_forge/frontend/static/assets/code-editor-DBSql_sc.js +12 -0
- service_forge/frontend/static/assets/el-collapse-item-D4LG0FJ0.css +1 -0
- service_forge/frontend/static/assets/el-empty-D4ZqTl4F.css +1 -0
- service_forge/frontend/static/assets/el-form-item-BWkJzdQ_.css +1 -0
- service_forge/frontend/static/assets/el-input-D6B3r8CH.css +1 -0
- service_forge/frontend/static/assets/el-select-B0XIb2QK.css +1 -0
- service_forge/frontend/static/assets/el-tag-DljBBxJR.css +1 -0
- service_forge/frontend/static/assets/element-ui-D3x2y3TA.js +12 -0
- service_forge/frontend/static/assets/elkjs-Dm5QV7uy.js +24 -0
- service_forge/frontend/static/assets/highlightjs-D4ATuRwX.js +3 -0
- service_forge/frontend/static/assets/index-BMvodlwc.js +2 -0
- service_forge/frontend/static/assets/index-CjSe8i2q.css +1 -0
- service_forge/frontend/static/assets/js-yaml-yTPt38rv.js +32 -0
- service_forge/frontend/static/assets/time-DKCKV6Ug.js +1 -0
- service_forge/frontend/static/assets/ui-components-DQ7-U3pr.js +1 -0
- service_forge/frontend/static/assets/vue-core-DL-LgTX0.js +1 -0
- service_forge/frontend/static/assets/vue-flow-Dn7R8GPr.js +39 -0
- service_forge/frontend/static/index.html +16 -0
- service_forge/frontend/static/vite.svg +1 -0
- service_forge/model/meta_api/__init__.py +0 -0
- service_forge/model/meta_api/schema.py +29 -0
- service_forge/model/trace.py +82 -0
- service_forge/service.py +39 -11
- service_forge/service_config.py +14 -0
- service_forge/sft/cli.py +39 -0
- service_forge/sft/cmd/remote_deploy.py +160 -0
- service_forge/sft/cmd/remote_list_tars.py +111 -0
- service_forge/sft/config/injector.py +54 -7
- service_forge/sft/config/injector_default_files.py +13 -1
- service_forge/sft/config/sf_metadata.py +31 -27
- service_forge/sft/config/sft_config.py +18 -0
- service_forge/sft/util/assert_util.py +0 -1
- service_forge/telemetry.py +66 -0
- service_forge/utils/default_type_converter.py +1 -1
- service_forge/utils/type_converter.py +5 -0
- service_forge/utils/workflow_clone.py +1 -0
- service_forge/workflow/node.py +274 -27
- service_forge/workflow/triggers/fast_api_trigger.py +64 -28
- service_forge/workflow/triggers/websocket_api_trigger.py +66 -38
- service_forge/workflow/workflow.py +140 -37
- service_forge/workflow/workflow_callback.py +27 -4
- service_forge/workflow/workflow_factory.py +14 -0
- {service_forge-0.1.18.dist-info → service_forge-0.1.39.dist-info}/METADATA +4 -1
- service_forge-0.1.39.dist-info/RECORD +134 -0
- service_forge-0.1.18.dist-info/RECORD +0 -83
- {service_forge-0.1.18.dist-info → service_forge-0.1.39.dist-info}/WHEEL +0 -0
- {service_forge-0.1.18.dist-info → service_forge-0.1.39.dist-info}/entry_points.txt +0 -0
service_forge/workflow/node.py
CHANGED
@@ -1,50 +1,78 @@
 from __future__ import annotations
 
-
-
+import inspect
+import json
 import uuid
+from abc import ABC, abstractmethod
+from dataclasses import asdict, is_dataclass
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    AsyncIterator,
+    Awaitable,
+    Callable,
+    Optional,
+    Union,
+    cast,
+)
+
 from loguru import logger
-from
-from
-from .
-from
+from opentelemetry import context as otel_context_api
+from opentelemetry import trace
+from opentelemetry.trace import SpanKind
+from pydantic import BaseModel
+
 from ..db.database import DatabaseManager, PostgresDatabase, MongoDatabase, RedisDatabase
 from ..utils.workflow_clone import node_clone
+from ..execution_context import (
+    ExecutionContext,
+    get_current_context,
+    reset_current_context,
+    set_current_context,
+)
+from ..utils.register import Register
+from .context import Context
+from .edge import Edge
+from .port import Port
+from .workflow_callback import CallbackEvent
 
 if TYPE_CHECKING:
     from .workflow import Workflow
 
+
 class Node(ABC):
     DEFAULT_INPUT_PORTS: list[Port] = []
     DEFAULT_OUTPUT_PORTS: list[Port] = []
 
-    CLASS_NOT_REQUIRED_TO_REGISTER = [
+    CLASS_NOT_REQUIRED_TO_REGISTER = ["Node"]
     AUTO_FILL_INPUT_PORTS = []
 
     def __init__(
         self,
         name: str,
-        context: Context = None,
-        input_edges: list[Edge] = None,
-        output_edges: list[Edge] = None,
+        context: Optional[Context] = None,
+        input_edges: Optional[list[Edge]] = None,
+        output_edges: Optional[list[Edge]] = None,
         input_ports: list[Port] = DEFAULT_INPUT_PORTS,
         output_ports: list[Port] = DEFAULT_OUTPUT_PORTS,
-        query_user: Callable[[str, str], Awaitable[str]] = None,
+        query_user: Optional[Callable[[str, str], Awaitable[str]]] = None,
     ) -> None:
         from .workflow_group import WorkflowGroup
+
         self.name = name
         self.input_edges = [] if input_edges is None else input_edges
         self.output_edges = [] if output_edges is None else output_edges
         self.input_ports = input_ports
         self.output_ports = output_ports
-        self.workflow: Workflow = None
+        self.workflow: Optional[Workflow] = None
         self.query_user = query_user
-        self.sub_workflows: WorkflowGroup = None
+        self.sub_workflows: Optional[WorkflowGroup] = None
 
         # runtime variables
         self.context = context
         self.input_variables: dict[Port, Any] = {}
         self.num_activated_input_edges = 0
+        self._tracer = trace.get_tracer("service_forge.node")
 
     @property
     def default_postgres_database(self) -> PostgresDatabase | None:
@@ -60,28 +88,39 @@ class Node(ABC):
 
     @property
     def database_manager(self) -> DatabaseManager:
+        assert self.workflow
         return self.workflow.database_manager
 
+    @property
+    def global_context(self) -> Context:
+        return self.workflow.global_context
+
     def backup(self) -> None:
         # do NOT use deepcopy here
         # self.bak_context = deepcopy(self.context)
         # TODO: what if the value changes after backup?
-        self.bak_input_variables = {
+        self.bak_input_variables = {
+            port: value for port, value in self.input_variables.items()
+        }
         self.bak_num_activated_input_edges = self.num_activated_input_edges
 
     def reset(self) -> None:
         # self.context = deepcopy(self.bak_context)
-        self.input_variables = {
+        self.input_variables = {
+            port: value for port, value in self.bak_input_variables.items()
+        }
         self.num_activated_input_edges = self.bak_num_activated_input_edges
 
     def __init_subclass__(cls) -> None:
         if cls.__name__ not in Node.CLASS_NOT_REQUIRED_TO_REGISTER:
+            # TODO: Register currently stores class objects; clarify Register typing vs instance usage.
             node_register.register(cls.__name__, cls)
         return super().__init_subclass__()
 
-    def _query_user(self, prompt: str) ->
+    def _query_user(self, prompt: str) -> Awaitable[str]:
+        assert self.query_user
         return self.query_user(self.name, prompt)
-
+
     def variables_to_params(self) -> dict[str, Any]:
         params = {port.name: self.input_variables[port] for port in self.input_variables.keys() if not port.is_extended_generated}
         for port in self.input_variables.keys():
@@ -94,6 +133,7 @@ class Node(ABC):
 
     def is_trigger(self) -> bool:
         from .trigger import Trigger
+
         return isinstance(self, Trigger)
 
     # TODO: maybe add a function before the run function?
@@ -101,23 +141,219 @@ class Node(ABC):
     @abstractmethod
     async def _run(self, **kwargs) -> Union[None, AsyncIterator]:
         ...
+
+    async def clear(self) -> None:
+        ...
 
     def run(self) -> Union[None, AsyncIterator]:
+        task_id: uuid.UUID | None = None
         for key in list(self.input_variables.keys()):
             if key and key.name[0].isupper():
                 del self.input_variables[key]
         params = self.variables_to_params()
-
+        if task_id is not None and "task_id" in self._run.__code__.co_varnames:
+            params["task_id"] = task_id
+        base_context = get_current_context()
+        parent_context = (
+            base_context.trace_context
+            if base_context and base_context.trace_context
+            else otel_context_api.get_current()
+        )
+        span_name = f"Node {self.name}"
 
-
-
+        if inspect.isasyncgenfunction(self._run):
+            return self._run_async_generator(
+                params, task_id, base_context, parent_context, span_name
+            )
+        if inspect.iscoroutinefunction(self._run):
+            return self._run_async(
+                params, task_id, base_context, parent_context, span_name
+            )
+
+        return self._run_sync(params, task_id, base_context, parent_context, span_name)
+
+    def _build_execution_context(
+        self, base_context: ExecutionContext | None, span: trace.Span
+    ) -> ExecutionContext:
+        return ExecutionContext(
+            trace_context=otel_context_api.get_current(),
+            span=span,
+            state=base_context.state if base_context else {},
+            metadata={
+                **(base_context.metadata if base_context else {}),
+                "node": self.name,
+                "workflow_name": getattr(self.workflow, "name", None),
+            },
+        )
+
+    @staticmethod
+    def _serialize_for_trace(value: Any, max_length: int = 4000) -> tuple[str, bool]:
+        def _normalize(val: Any) -> Any:
+            if isinstance(val, BaseModel):
+                return val.model_dump()
+            if hasattr(val, "model_dump"):
+                try:
+                    return val.model_dump()
+                except Exception:
+                    pass
+            if hasattr(val, "dict"):
+                try:
+                    return val.dict()
+                except Exception:
+                    pass
+            if is_dataclass is not None and is_dataclass(val):
+                return asdict(val) if asdict else val
+            if isinstance(val, dict):
+                return {k: _normalize(v) for k, v in val.items()}
+            if isinstance(val, (list, tuple)):
+                return [_normalize(v) for v in val]
+            return val
+
+        normalized_value = _normalize(value)
+        serialized = json.dumps(normalized_value, ensure_ascii=False, default=str)
+        if len(serialized) > max_length:
+            return serialized[:max_length], True
+        return serialized, False
+
+    def _set_span_attributes(
+        self, span: trace.Span, params: dict[str, Any], task_id: uuid.UUID | None
+    ) -> None:
+        span.set_attribute("node.name", self.name)
+        if self.workflow is not None:
+            span.set_attribute("workflow.name", self.workflow.name)
+        if task_id is not None:
+            span.set_attribute("workflow.task_id", str(task_id))
+        span.set_attribute("node.input_keys", ",".join(params.keys()))
+        serialized_inputs, inputs_truncated = self._serialize_for_trace(params)
+        span.set_attribute("node.inputs", serialized_inputs)
+        if inputs_truncated:
+            span.set_attribute("node.inputs_truncated", True)
+
+    def _record_output(self, span: trace.Span, output: Any) -> None:
+        span.set_attribute(
+            "node.output_type", type(output).__name__ if output is not None else "None"
+        )
+        serialized_output, output_truncated = self._serialize_for_trace(output)
+        span.set_attribute("node.output", serialized_output)
+        if output_truncated:
+            span.set_attribute("node.output_truncated", True)
+
+    async def _run_async(
+        self,
+        params: dict[str, Any],
+        task_id: uuid.UUID | None,
+        base_context: ExecutionContext | None,
+        parent_context: otel_context_api.Context,
+        span_name: str,
+    ) -> Any:
+        with self._tracer.start_as_current_span(
+            span_name,
+            context=parent_context,
+            kind=SpanKind.INTERNAL,
+        ) as span:
+            self._set_span_attributes(span, params, task_id)
+            exec_ctx = self._build_execution_context(base_context, span)
+            token = set_current_context(exec_ctx)
+            try:
+                result = self._run(**params)
+                if inspect.isawaitable(result):
+                    result = await result
+                self._record_output(span, result)
+                return result
+            finally:
+                reset_current_context(token)
+
+    async def _run_async_generator(
+        self,
+        params: dict[str, Any],
+        task_id: uuid.UUID | None,
+        base_context: ExecutionContext | None,
+        parent_context: otel_context_api.Context,
+        span_name: str,
+    ) -> AsyncIterator[Any]:
+        # Trigger nodes are long-running async generators; open and close a separate span per trigger event so a single span does not accumulate every request.
+        if self.is_trigger():
+            async for item in self._run(**params):
+                trigger_parent_context = parent_context
+                if hasattr(self, "task_contexts") and isinstance(item, uuid.UUID):
+                    trigger_ctx = getattr(self, "task_contexts").get(item)
+                    if trigger_ctx is not None:
+                        trigger_parent_context = trigger_ctx
+                with self._tracer.start_as_current_span(
+                    span_name,
+                    context=trigger_parent_context,
+                    kind=SpanKind.INTERNAL,
+                ) as span:
+                    self._set_span_attributes(span, params, task_id)
+                    span.set_attribute("node.output_type", "async_generator")
+                    serialized_item, item_truncated = self._serialize_for_trace(item)
+                    span.add_event(
+                        "node.output_item",
+                        {
+                            "value": serialized_item,
+                            "truncated": item_truncated,
+                        },
+                    )
+                    exec_ctx = self._build_execution_context(base_context, span)
+                    token = set_current_context(exec_ctx)
+                    try:
+                        yield item
+                    finally:
+                        reset_current_context(token)
+        else:
+            with self._tracer.start_as_current_span(
+                span_name,
+                context=parent_context,
+                kind=SpanKind.INTERNAL,
+            ) as span:
+                self._set_span_attributes(span, params, task_id)
+                span.set_attribute("node.output_type", "async_generator")
+                exec_ctx = self._build_execution_context(base_context, span)
+                token = set_current_context(exec_ctx)
+                try:
+                    async for item in self._run(**params):
+                        serialized_item, item_truncated = self._serialize_for_trace(item)
+                        span.add_event(
+                            "node.output_item",
+                            {
+                                "value": serialized_item,
+                                "truncated": item_truncated,
+                            },
+                        )
+                        yield item
+                finally:
+                    reset_current_context(token)
+
+    def _run_sync(
+        self,
+        params: dict[str, Any],
+        task_id: uuid.UUID | None,
+        base_context: ExecutionContext | None,
+        parent_context: otel_context_api.Context,
+        span_name: str,
+    ) -> Any:
+        with self._tracer.start_as_current_span(
+            span_name,
+            context=parent_context,
+            kind=SpanKind.INTERNAL,
+        ) as span:
+            self._set_span_attributes(span, params, task_id)
+            exec_ctx = self._build_execution_context(base_context, span)
+            token = set_current_context(exec_ctx)
+            try:
+                result = self._run(**params)
+                self._record_output(span, result)
+                return result
+            finally:
+                reset_current_context(token)
+
+    def get_input_port_by_name(self, name: str) -> Optional[Port]:
         for port in self.input_ports:
             if port.name == name:
                 return port
         return None
 
-    def get_output_port_by_name(self, name: str) -> Port:
-        # TODO: add warning if port is extended
+    def get_output_port_by_name(self, name: str) -> Optional[Port]:
         for port in self.output_ports:
             if port.name == name:
                 return port
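Worth calling out from the hunk above: `_serialize_for_trace` returns a `(serialized, truncated)` pair so span attributes stay bounded. The following is a standalone illustration of that contract, a simplified re-implementation rather than the package's code; it mirrors the diff's `max_length` default and `json.dumps` fallback:

# Illustration of the (serialized, truncated) contract from the hunk above.
import json
from typing import Any

def serialize_for_trace(value: Any, max_length: int = 4000) -> tuple[str, bool]:
    # default=str keeps non-JSON-serializable values from raising
    serialized = json.dumps(value, ensure_ascii=False, default=str)
    if len(serialized) > max_length:
        return serialized[:max_length], True
    return serialized, False

print(serialize_for_trace({"k": "v"}))     # ('{"k": "v"}', False)
print(serialize_for_trace("x" * 5000)[1])  # True: payload was clipped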
@@ -143,9 +379,9 @@ class Node(ABC):
         self.try_create_extended_input_port(port_name)
         port = self.get_input_port_by_name(port_name)
         if port is None:
-            raise ValueError(f
+            raise ValueError(f"{port_name} is not a valid input port.")
         self.fill_input(port, value)
-
+
     def fill_input(self, port: Port, value: Any) -> None:
         port.activate(value)
 
@@ -163,7 +399,7 @@ class Node(ABC):
         for output_edge in self.output_edges:
             if output_edge.start_port == port:
                 output_edge.end_port.prepare(data)
-
+
     def trigger_output_edges(self, port: Port) -> None:
         if isinstance(port, str):
             port = self.get_output_port_by_name(port)
@@ -173,7 +409,14 @@ class Node(ABC):
 
     # TODO: the result is outputed to the trigger now, maybe we should add a new function to output the result to the workflow
     def output_to_workflow(self, data: Any) -> None:
-        self.workflow
+        if self.workflow and hasattr(self.workflow, "_handle_workflow_output"):
+            handler = cast(
+                Callable[[str, Any], None],
+                getattr(self.workflow, "_handle_workflow_output"),
+            )
+            handler(self.name, data)
+        else:
+            logger.warning("Workflow output handler not set; skipping output dispatch.")
 
     def extended_output_name(self, name: str, index: int) -> str:
         return name + '_' + str(index)
@@ -181,4 +424,8 @@ class Node(ABC):
     def _clone(self, context: Context) -> Node:
         return node_clone(self, context)
 
-
+    async def stream_output(self, data: Any) -> None:
+        await self.workflow.call_callbacks(CallbackEvent.ON_NODE_STREAM_OUTPUT, node=self, output=data)
+
+
+node_register = Register[Node]()
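Net effect of the node.py changes: every node execution is wrapped in an OpenTelemetry span, with sync, coroutine, and async-generator `_run` implementations each dispatched to a dedicated wrapper. Below is a minimal sketch of that dispatch pattern, assuming only the public opentelemetry-api surface; `DemoNode` and its `_run` are hypothetical, not service-forge classes:

# Sketch of the span-per-execution dispatch the diff introduces.
import asyncio
import inspect
from opentelemetry import trace

tracer = trace.get_tracer("sketch.node")

class DemoNode:  # hypothetical stand-in for a service-forge Node subclass
    def __init__(self, name: str) -> None:
        self.name = name

    async def _run(self, x: int) -> int:
        return x * 2

    def run(self, **params):
        # Dispatch on the flavor of _run, as Node.run now does: coroutines
        # and plain functions each get a wrapper that brackets the real work
        # in a span (the real code also handles async generators).
        if inspect.iscoroutinefunction(self._run):
            return self._run_async(params)
        return self._run_sync(params)

    async def _run_async(self, params):
        with tracer.start_as_current_span(f"Node {self.name}") as span:
            span.set_attribute("node.name", self.name)
            result = await self._run(**params)
            span.set_attribute("node.output", str(result))
            return result

    def _run_sync(self, params):
        with tracer.start_as_current_span(f"Node {self.name}") as span:
            span.set_attribute("node.name", self.name)
            result = self._run(**params)
            span.set_attribute("node.output", str(result))
            return result

print(asyncio.run(DemoNode("double").run(x=21)))  # -> 42

Without an SDK configured, the opentelemetry-api calls are no-ops, so the sketch runs as-is.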
service_forge/workflow/triggers/fast_api_trigger.py
CHANGED

@@ -4,7 +4,7 @@ import asyncio
 import json
 from loguru import logger
 from service_forge.workflow.trigger import Trigger
-from typing import AsyncIterator, Any
+from typing import AsyncIterator, Any, Optional
 from fastapi import FastAPI, Request
 from fastapi.responses import StreamingResponse
 from service_forge.workflow.port import Port
@@ -13,6 +13,7 @@ from service_forge.api.routers.websocket.websocket_manager import websocket_manager
 from fastapi import HTTPException
 from google.protobuf.message import Message
 from google.protobuf.json_format import MessageToJson
+from opentelemetry import context as otel_context_api
 
 class FastAPITrigger(Trigger):
     DEFAULT_INPUT_PORTS = [
@@ -26,13 +27,16 @@ class FastAPITrigger(Trigger):
     DEFAULT_OUTPUT_PORTS = [
         Port("trigger", bool),
         Port("user_id", int),
+        Port("token", str),
         Port("data", Any),
+        Port("path_params", dict),
     ]
 
     def __init__(self, name: str):
         super().__init__(name)
         self.events = {}
         self.is_setup_route = False
+        self.task_contexts: dict[uuid.UUID, otel_context_api.Context] = {}
         self.app = None
         self.route_path = None
         self.route_method = None
@@ -40,23 +44,44 @@ class FastAPITrigger(Trigger):
     @staticmethod
     def serialize_result(result: Any):
         if isinstance(result, Message):
-            return MessageToJson(
-                result,
-                preserving_proto_field_name=True
-            )
+            return MessageToJson(result, preserving_proto_field_name=True)
         return result
 
+    def _normalize_result_or_raise(self, result: Any):
+        # TODO: check this merge
+        if hasattr(result, "is_error") and hasattr(result, "result"):
+            if result.is_error:
+                if isinstance(result.result, HTTPException):
+                    raise result.result
+                raise HTTPException(status_code=500, detail=str(result.result))
+            return self.serialize_result(result.result)
+
+        if isinstance(result, HTTPException):
+            raise result
+        if isinstance(result, Exception):
+            raise HTTPException(status_code=500, detail=str(result))
+
+        return self.serialize_result(result)
+
     async def handle_request(
         self,
         request: Request,
         data_type: type,
         extract_data_fn: callable[[Request], dict],
         is_stream: bool,
+        path_params: Optional[dict] = None,
     ):
         task_id = uuid.uuid4()
         self.result_queues[task_id] = asyncio.Queue()
 
+        # parse trace context
+        trace_ctx = otel_context_api.get_current()
+        self.task_contexts[task_id] = trace_ctx
+
         body_data = await extract_data_fn(request)
+        # Merge path parameters into body_data (path params take precedence)
+        if path_params:
+            body_data = {**body_data, **path_params}
         converted_data = data_type(**body_data)
 
         client_id = (
@@ -69,10 +94,16 @@ class FastAPITrigger(Trigger):
         steps = len(self.workflow.nodes) if hasattr(self.workflow, "nodes") else 1
         websocket_manager.create_task_with_client(task_id, client_id, workflow_name, steps)
 
+        # trigger_queue with trace_context, used in _run()
+        logger.info(f'user_id {getattr(request.state, "user_id", None)} token {getattr(request.state, "auth_token", None)}')
+
         self.trigger_queue.put_nowait({
             "id": task_id,
             "user_id": getattr(request.state, "user_id", None),
+            "token": getattr(request.state, "auth_token", None),
             "data": converted_data,
+            "trace_context": trace_ctx,
+            "path_params": path_params,
         })
 
         if is_stream:
@@ -83,26 +114,23 @@ class FastAPITrigger(Trigger):
             while True:
                 item = await self.stream_queues[task_id].get()
 
-                if item
+                if getattr(item, "is_error", False):
                     yield f"event: error\ndata: {json.dumps({'detail': str(item.result)})}\n\n"
                     break
-
-                if item
-                    # TODO: send the result?
+
+                if getattr(item, "is_end", False):
                     break
 
                 # TODO: modify
                 serialized = self.serialize_result(item.result)
-                if isinstance(serialized, str)
-                    data = serialized
-                else:
-                    data = json.dumps(serialized)
-
+                data = serialized if isinstance(serialized, str) else json.dumps(serialized)
                 yield f"data: {data}\n\n"
-
+
         except Exception as e:
             yield f"event: error\ndata: {json.dumps({'detail': str(e)})}\n\n"
         finally:
+            self.stream_queues.pop(task_id, None)
+
             if task_id in self.stream_queues:
                 del self.stream_queues[task_id]
             if task_id in self.result_queues:
@@ -115,19 +143,13 @@ class FastAPITrigger(Trigger):
                     "Cache-Control": "no-cache",
                     "Connection": "keep-alive",
                     "X-Accel-Buffering": "no",
-                }
+                },
             )
-        else:
-            result = await self.result_queues[task_id].get()
-            del self.result_queues[task_id]
 
-
-
-
-
-            raise HTTPException(status_code=500, detail=str(result.result))
-
-            return self.serialize_result(result.result)
+        # Non-streaming: wait for the result
+        result = await self.result_queues[task_id].get()
+        self.result_queues.pop(task_id, None)
+        return self._normalize_result_or_raise(result)
 
     def _setup_route(self, app: FastAPI, path: str, method: str, data_type: type, is_stream: bool) -> None:
         async def get_data(request: Request) -> dict:
@@ -142,7 +164,10 @@ class FastAPITrigger(Trigger):
         extractor = get_data if method == "GET" else body_data
 
         async def handler(request: Request):
-
+            # Get path parameters from FastAPI request
+            # request.path_params is always available in FastAPI and contains path parameters
+            path_params = dict(request.path_params)
+            return await self.handle_request(request, data_type, extractor, is_stream, path_params)
 
         # Save route information for cleanup
         self.app = app
@@ -160,7 +185,14 @@ class FastAPITrigger(Trigger):
         else:
             raise ValueError(f"Invalid method {method}")
 
-    async def _run(
+    async def _run(
+        self,
+        app: FastAPI,
+        path: str,
+        method: str,
+        data_type: type,
+        is_stream: bool = False,
+    ) -> AsyncIterator[bool]:
         if not self.is_setup_route:
             self._setup_route(app, path, method, data_type, is_stream)
             self.is_setup_route = True
@@ -168,8 +200,12 @@ class FastAPITrigger(Trigger):
         while True:
             try:
                 trigger = await self.trigger_queue.get()
+                if trace_ctx := trigger.get("trace_context"):
+                    self.task_contexts[trigger["id"]] = trace_ctx
                 self.prepare_output_edges(self.get_output_port_by_name('user_id'), trigger['user_id'])
+                self.prepare_output_edges(self.get_output_port_by_name('token'), trigger['token'])
                 self.prepare_output_edges(self.get_output_port_by_name('data'), trigger['data'])
+                self.prepare_output_edges(self.get_output_port_by_name('path_params'), trigger['path_params'])
                 yield self.trigger(trigger['id'])
             except Exception as e:
                 logger.error(f"Error in FastAPITrigger._run: {e}")