service-forge 0.1.28__py3-none-any.whl → 0.1.39__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of service-forge might be problematic. Click here for more details.

Files changed (72) hide show
  1. service_forge/__init__.py +0 -0
  2. service_forge/api/deprecated_websocket_api.py +91 -33
  3. service_forge/api/deprecated_websocket_manager.py +70 -53
  4. service_forge/api/http_api.py +127 -53
  5. service_forge/api/kafka_api.py +113 -25
  6. service_forge/api/routers/meta_api/meta_api_router.py +57 -0
  7. service_forge/api/routers/service/service_router.py +42 -6
  8. service_forge/api/routers/trace/trace_router.py +326 -0
  9. service_forge/api/routers/websocket/websocket_router.py +56 -1
  10. service_forge/api/service_studio.py +9 -0
  11. service_forge/execution_context.py +106 -0
  12. service_forge/frontend/static/assets/CreateNewNodeDialog-DkrEMxSH.js +1 -0
  13. service_forge/frontend/static/assets/CreateNewNodeDialog-DwFcBiGp.css +1 -0
  14. service_forge/frontend/static/assets/EditorSidePanel-BNVms9Fq.css +1 -0
  15. service_forge/frontend/static/assets/EditorSidePanel-DZbB3ILL.js +1 -0
  16. service_forge/frontend/static/assets/FeedbackPanel-CC8HX7Yo.js +1 -0
  17. service_forge/frontend/static/assets/FeedbackPanel-ClgniIVk.css +1 -0
  18. service_forge/frontend/static/assets/FormattedCodeViewer.vue_vue_type_script_setup_true_lang-BNuI1NCs.js +1 -0
  19. service_forge/frontend/static/assets/NodeDetailWrapper-BqFFM7-r.js +1 -0
  20. service_forge/frontend/static/assets/NodeDetailWrapper-pZBxv3J0.css +1 -0
  21. service_forge/frontend/static/assets/TestRunningDialog-D0GrCoYs.js +1 -0
  22. service_forge/frontend/static/assets/TestRunningDialog-dhXOsPgH.css +1 -0
  23. service_forge/frontend/static/assets/TracePanelWrapper-B9zvDSc_.js +1 -0
  24. service_forge/frontend/static/assets/TracePanelWrapper-BiednCrq.css +1 -0
  25. service_forge/frontend/static/assets/WorkflowEditor-CcaGGbko.js +3 -0
  26. service_forge/frontend/static/assets/WorkflowEditor-CmasOOYK.css +1 -0
  27. service_forge/frontend/static/assets/WorkflowList-Copuwi-a.css +1 -0
  28. service_forge/frontend/static/assets/WorkflowList-LrRJ7B7h.js +1 -0
  29. service_forge/frontend/static/assets/WorkflowStudio-CthjgII2.css +1 -0
  30. service_forge/frontend/static/assets/WorkflowStudio-FCyhGD4y.js +2 -0
  31. service_forge/frontend/static/assets/api-BDer3rj7.css +1 -0
  32. service_forge/frontend/static/assets/api-DyiqpKJK.js +1 -0
  33. service_forge/frontend/static/assets/code-editor-DBSql_sc.js +12 -0
  34. service_forge/frontend/static/assets/el-collapse-item-D4LG0FJ0.css +1 -0
  35. service_forge/frontend/static/assets/el-empty-D4ZqTl4F.css +1 -0
  36. service_forge/frontend/static/assets/el-form-item-BWkJzdQ_.css +1 -0
  37. service_forge/frontend/static/assets/el-input-D6B3r8CH.css +1 -0
  38. service_forge/frontend/static/assets/el-select-B0XIb2QK.css +1 -0
  39. service_forge/frontend/static/assets/el-tag-DljBBxJR.css +1 -0
  40. service_forge/frontend/static/assets/element-ui-D3x2y3TA.js +12 -0
  41. service_forge/frontend/static/assets/elkjs-Dm5QV7uy.js +24 -0
  42. service_forge/frontend/static/assets/highlightjs-D4ATuRwX.js +3 -0
  43. service_forge/frontend/static/assets/index-BMvodlwc.js +2 -0
  44. service_forge/frontend/static/assets/index-CjSe8i2q.css +1 -0
  45. service_forge/frontend/static/assets/js-yaml-yTPt38rv.js +32 -0
  46. service_forge/frontend/static/assets/time-DKCKV6Ug.js +1 -0
  47. service_forge/frontend/static/assets/ui-components-DQ7-U3pr.js +1 -0
  48. service_forge/frontend/static/assets/vue-core-DL-LgTX0.js +1 -0
  49. service_forge/frontend/static/assets/vue-flow-Dn7R8GPr.js +39 -0
  50. service_forge/frontend/static/index.html +16 -0
  51. service_forge/frontend/static/vite.svg +1 -0
  52. service_forge/model/meta_api/__init__.py +0 -0
  53. service_forge/model/meta_api/schema.py +29 -0
  54. service_forge/model/trace.py +82 -0
  55. service_forge/service.py +32 -11
  56. service_forge/service_config.py +14 -0
  57. service_forge/sft/config/injector.py +32 -2
  58. service_forge/sft/config/injector_default_files.py +12 -0
  59. service_forge/sft/config/sf_metadata.py +5 -0
  60. service_forge/sft/config/sft_config.py +18 -0
  61. service_forge/telemetry.py +66 -0
  62. service_forge/workflow/node.py +266 -27
  63. service_forge/workflow/triggers/fast_api_trigger.py +61 -28
  64. service_forge/workflow/triggers/websocket_api_trigger.py +31 -10
  65. service_forge/workflow/workflow.py +87 -10
  66. service_forge/workflow/workflow_callback.py +24 -2
  67. service_forge/workflow/workflow_factory.py +13 -0
  68. {service_forge-0.1.28.dist-info → service_forge-0.1.39.dist-info}/METADATA +4 -1
  69. service_forge-0.1.39.dist-info/RECORD +134 -0
  70. service_forge-0.1.28.dist-info/RECORD +0 -85
  71. {service_forge-0.1.28.dist-info → service_forge-0.1.39.dist-info}/WHEEL +0 -0
  72. {service_forge-0.1.28.dist-info → service_forge-0.1.39.dist-info}/entry_points.txt +0 -0
@@ -1,51 +1,78 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import Any, AsyncIterator, Union, TYPE_CHECKING, Callable, Awaitable
4
- from abc import ABC, abstractmethod
3
+ import inspect
4
+ import json
5
5
  import uuid
6
+ from abc import ABC, abstractmethod
7
+ from dataclasses import asdict, is_dataclass
8
+ from typing import (
9
+ TYPE_CHECKING,
10
+ Any,
11
+ AsyncIterator,
12
+ Awaitable,
13
+ Callable,
14
+ Optional,
15
+ Union,
16
+ cast,
17
+ )
18
+
6
19
  from loguru import logger
7
- from .edge import Edge
8
- from .port import Port
9
- from .context import Context
10
- from ..utils.register import Register
20
+ from opentelemetry import context as otel_context_api
21
+ from opentelemetry import trace
22
+ from opentelemetry.trace import SpanKind
23
+ from pydantic import BaseModel
24
+
11
25
  from ..db.database import DatabaseManager, PostgresDatabase, MongoDatabase, RedisDatabase
12
26
  from ..utils.workflow_clone import node_clone
27
+ from ..execution_context import (
28
+ ExecutionContext,
29
+ get_current_context,
30
+ reset_current_context,
31
+ set_current_context,
32
+ )
33
+ from ..utils.register import Register
34
+ from .context import Context
35
+ from .edge import Edge
36
+ from .port import Port
13
37
  from .workflow_callback import CallbackEvent
14
38
 
15
39
  if TYPE_CHECKING:
16
40
  from .workflow import Workflow
17
41
 
42
+
18
43
  class Node(ABC):
19
44
  DEFAULT_INPUT_PORTS: list[Port] = []
20
45
  DEFAULT_OUTPUT_PORTS: list[Port] = []
21
46
 
22
- CLASS_NOT_REQUIRED_TO_REGISTER = ['Node']
47
+ CLASS_NOT_REQUIRED_TO_REGISTER = ["Node"]
23
48
  AUTO_FILL_INPUT_PORTS = []
24
49
 
25
50
  def __init__(
26
51
  self,
27
52
  name: str,
28
- context: Context = None,
29
- input_edges: list[Edge] = None,
30
- output_edges: list[Edge] = None,
53
+ context: Optional[Context] = None,
54
+ input_edges: Optional[list[Edge]] = None,
55
+ output_edges: Optional[list[Edge]] = None,
31
56
  input_ports: list[Port] = DEFAULT_INPUT_PORTS,
32
57
  output_ports: list[Port] = DEFAULT_OUTPUT_PORTS,
33
- query_user: Callable[[str, str], Awaitable[str]] = None,
58
+ query_user: Optional[Callable[[str, str], Awaitable[str]]] = None,
34
59
  ) -> None:
35
60
  from .workflow_group import WorkflowGroup
61
+
36
62
  self.name = name
37
63
  self.input_edges = [] if input_edges is None else input_edges
38
64
  self.output_edges = [] if output_edges is None else output_edges
39
65
  self.input_ports = input_ports
40
66
  self.output_ports = output_ports
41
- self.workflow: Workflow = None
67
+ self.workflow: Optional[Workflow] = None
42
68
  self.query_user = query_user
43
- self.sub_workflows: WorkflowGroup = None
69
+ self.sub_workflows: Optional[WorkflowGroup] = None
44
70
 
45
71
  # runtime variables
46
72
  self.context = context
47
73
  self.input_variables: dict[Port, Any] = {}
48
74
  self.num_activated_input_edges = 0
75
+ self._tracer = trace.get_tracer("service_forge.node")
49
76
 
50
77
  @property
51
78
  def default_postgres_database(self) -> PostgresDatabase | None:
@@ -61,6 +88,7 @@ class Node(ABC):
61
88
 
62
89
  @property
63
90
  def database_manager(self) -> DatabaseManager:
91
+ assert self.workflow
64
92
  return self.workflow.database_manager
65
93
 
66
94
  @property
@@ -71,22 +99,28 @@ class Node(ABC):
71
99
  # do NOT use deepcopy here
72
100
  # self.bak_context = deepcopy(self.context)
73
101
  # TODO: what if the value changes after backup?
74
- self.bak_input_variables = {port: value for port, value in self.input_variables.items()}
102
+ self.bak_input_variables = {
103
+ port: value for port, value in self.input_variables.items()
104
+ }
75
105
  self.bak_num_activated_input_edges = self.num_activated_input_edges
76
106
 
77
107
  def reset(self) -> None:
78
108
  # self.context = deepcopy(self.bak_context)
79
- self.input_variables = {port: value for port, value in self.bak_input_variables.items()}
109
+ self.input_variables = {
110
+ port: value for port, value in self.bak_input_variables.items()
111
+ }
80
112
  self.num_activated_input_edges = self.bak_num_activated_input_edges
81
113
 
82
114
  def __init_subclass__(cls) -> None:
83
115
  if cls.__name__ not in Node.CLASS_NOT_REQUIRED_TO_REGISTER:
116
+ # TODO: Register currently stores class objects; clarify Register typing vs instance usage.
84
117
  node_register.register(cls.__name__, cls)
85
118
  return super().__init_subclass__()
86
119
 
87
- def _query_user(self, prompt: str) -> Callable[[str, str], Awaitable[str]]:
120
+ def _query_user(self, prompt: str) -> Awaitable[str]:
121
+ assert self.query_user
88
122
  return self.query_user(self.name, prompt)
89
-
123
+
90
124
  def variables_to_params(self) -> dict[str, Any]:
91
125
  params = {port.name: self.input_variables[port] for port in self.input_variables.keys() if not port.is_extended_generated}
92
126
  for port in self.input_variables.keys():
@@ -99,6 +133,7 @@ class Node(ABC):
99
133
 
100
134
  def is_trigger(self) -> bool:
101
135
  from .trigger import Trigger
136
+
102
137
  return isinstance(self, Trigger)
103
138
 
104
139
  # TODO: maybe add a function before the run function?
@@ -106,23 +141,219 @@ class Node(ABC):
106
141
  @abstractmethod
107
142
  async def _run(self, **kwargs) -> Union[None, AsyncIterator]:
108
143
  ...
144
+
145
+ async def clear(self) -> None:
146
+ ...
109
147
 
110
148
  def run(self) -> Union[None, AsyncIterator]:
149
+ task_id: uuid.UUID | None = None
111
150
  for key in list(self.input_variables.keys()):
112
151
  if key and key.name[0].isupper():
113
152
  del self.input_variables[key]
114
153
  params = self.variables_to_params()
115
- return self._run(**params)
154
+ if task_id is not None and "task_id" in self._run.__code__.co_varnames:
155
+ params["task_id"] = task_id
156
+ base_context = get_current_context()
157
+ parent_context = (
158
+ base_context.trace_context
159
+ if base_context and base_context.trace_context
160
+ else otel_context_api.get_current()
161
+ )
162
+ span_name = f"Node {self.name}"
116
163
 
117
- def get_input_port_by_name(self, name: str) -> Port:
118
- # TODO: add warning if port is extended
164
+ if inspect.isasyncgenfunction(self._run):
165
+ return self._run_async_generator(
166
+ params, task_id, base_context, parent_context, span_name
167
+ )
168
+ if inspect.iscoroutinefunction(self._run):
169
+ return self._run_async(
170
+ params, task_id, base_context, parent_context, span_name
171
+ )
172
+
173
+ return self._run_sync(params, task_id, base_context, parent_context, span_name)
174
+
175
+ def _build_execution_context(
176
+ self, base_context: ExecutionContext | None, span: trace.Span
177
+ ) -> ExecutionContext:
178
+ return ExecutionContext(
179
+ trace_context=otel_context_api.get_current(),
180
+ span=span,
181
+ state=base_context.state if base_context else {},
182
+ metadata={
183
+ **(base_context.metadata if base_context else {}),
184
+ "node": self.name,
185
+ "workflow_name": getattr(self.workflow, "name", None),
186
+ },
187
+ )
188
+
189
+ @staticmethod
190
+ def _serialize_for_trace(value: Any, max_length: int = 4000) -> tuple[str, bool]:
191
+ def _normalize(val: Any) -> Any:
192
+ if isinstance(val, BaseModel):
193
+ return val.model_dump()
194
+ if hasattr(val, "model_dump"):
195
+ try:
196
+ return val.model_dump()
197
+ except Exception:
198
+ pass
199
+ if hasattr(val, "dict"):
200
+ try:
201
+ return val.dict()
202
+ except Exception:
203
+ pass
204
+ if is_dataclass is not None and is_dataclass(val):
205
+ return asdict(val) if asdict else val
206
+ if isinstance(val, dict):
207
+ return {k: _normalize(v) for k, v in val.items()}
208
+ if isinstance(val, (list, tuple)):
209
+ return [_normalize(v) for v in val]
210
+ return val
211
+
212
+ normalized_value = _normalize(value)
213
+ serialized = json.dumps(normalized_value, ensure_ascii=False, default=str)
214
+ if len(serialized) > max_length:
215
+ return serialized[:max_length], True
216
+ return serialized, False
217
+
218
+ def _set_span_attributes(
219
+ self, span: trace.Span, params: dict[str, Any], task_id: uuid.UUID | None
220
+ ) -> None:
221
+ span.set_attribute("node.name", self.name)
222
+ if self.workflow is not None:
223
+ span.set_attribute("workflow.name", self.workflow.name)
224
+ if task_id is not None:
225
+ span.set_attribute("workflow.task_id", str(task_id))
226
+ span.set_attribute("node.input_keys", ",".join(params.keys()))
227
+ serialized_inputs, inputs_truncated = self._serialize_for_trace(params)
228
+ span.set_attribute("node.inputs", serialized_inputs)
229
+ if inputs_truncated:
230
+ span.set_attribute("node.inputs_truncated", True)
231
+
232
+ def _record_output(self, span: trace.Span, output: Any) -> None:
233
+ span.set_attribute(
234
+ "node.output_type", type(output).__name__ if output is not None else "None"
235
+ )
236
+ serialized_output, output_truncated = self._serialize_for_trace(output)
237
+ span.set_attribute("node.output", serialized_output)
238
+ if output_truncated:
239
+ span.set_attribute("node.output_truncated", True)
240
+
241
+ async def _run_async(
242
+ self,
243
+ params: dict[str, Any],
244
+ task_id: uuid.UUID | None,
245
+ base_context: ExecutionContext | None,
246
+ parent_context: otel_context_api.Context,
247
+ span_name: str,
248
+ ) -> Any:
249
+ with self._tracer.start_as_current_span(
250
+ span_name,
251
+ context=parent_context,
252
+ kind=SpanKind.INTERNAL,
253
+ ) as span:
254
+ self._set_span_attributes(span, params, task_id)
255
+ exec_ctx = self._build_execution_context(base_context, span)
256
+ token = set_current_context(exec_ctx)
257
+ try:
258
+ result = self._run(**params)
259
+ if inspect.isawaitable(result):
260
+ result = await result
261
+ self._record_output(span, result)
262
+ return result
263
+ finally:
264
+ reset_current_context(token)
265
+
266
+ async def _run_async_generator(
267
+ self,
268
+ params: dict[str, Any],
269
+ task_id: uuid.UUID | None,
270
+ base_context: ExecutionContext | None,
271
+ parent_context: otel_context_api.Context,
272
+ span_name: str,
273
+ ) -> AsyncIterator[Any]:
274
+ # Trigger 节点是长期运行的 async generator,这里为每次触发单独生成/关闭 span,避免一个 span 挂载所有请求。
275
+ if self.is_trigger():
276
+ async for item in self._run(**params):
277
+ trigger_parent_context = parent_context
278
+ if hasattr(self, "task_contexts") and isinstance(item, uuid.UUID):
279
+ trigger_ctx = getattr(self, "task_contexts").get(item)
280
+ if trigger_ctx is not None:
281
+ trigger_parent_context = trigger_ctx
282
+ with self._tracer.start_as_current_span(
283
+ span_name,
284
+ context=trigger_parent_context,
285
+ kind=SpanKind.INTERNAL,
286
+ ) as span:
287
+ self._set_span_attributes(span, params, task_id)
288
+ span.set_attribute("node.output_type", "async_generator")
289
+ serialized_item, item_truncated = self._serialize_for_trace(item)
290
+ span.add_event(
291
+ "node.output_item",
292
+ {
293
+ "value": serialized_item,
294
+ "truncated": item_truncated,
295
+ },
296
+ )
297
+ exec_ctx = self._build_execution_context(base_context, span)
298
+ token = set_current_context(exec_ctx)
299
+ try:
300
+ yield item
301
+ finally:
302
+ reset_current_context(token)
303
+ else:
304
+ with self._tracer.start_as_current_span(
305
+ span_name,
306
+ context=parent_context,
307
+ kind=SpanKind.INTERNAL,
308
+ ) as span:
309
+ self._set_span_attributes(span, params, task_id)
310
+ span.set_attribute("node.output_type", "async_generator")
311
+ exec_ctx = self._build_execution_context(base_context, span)
312
+ token = set_current_context(exec_ctx)
313
+ try:
314
+ async for item in self._run(**params):
315
+ serialized_item, item_truncated = self._serialize_for_trace(item)
316
+ span.add_event(
317
+ "node.output_item",
318
+ {
319
+ "value": serialized_item,
320
+ "truncated": item_truncated,
321
+ },
322
+ )
323
+ yield item
324
+ finally:
325
+ reset_current_context(token)
326
+
327
+ def _run_sync(
328
+ self,
329
+ params: dict[str, Any],
330
+ task_id: uuid.UUID | None,
331
+ base_context: ExecutionContext | None,
332
+ parent_context: otel_context_api.Context,
333
+ span_name: str,
334
+ ) -> Any:
335
+ with self._tracer.start_as_current_span(
336
+ span_name,
337
+ context=parent_context,
338
+ kind=SpanKind.INTERNAL,
339
+ ) as span:
340
+ self._set_span_attributes(span, params, task_id)
341
+ exec_ctx = self._build_execution_context(base_context, span)
342
+ token = set_current_context(exec_ctx)
343
+ try:
344
+ result = self._run(**params)
345
+ self._record_output(span, result)
346
+ return result
347
+ finally:
348
+ reset_current_context(token)
349
+
350
+ def get_input_port_by_name(self, name: str) -> Optional[Port]:
119
351
  for port in self.input_ports:
120
352
  if port.name == name:
121
353
  return port
122
354
  return None
123
355
 
124
- def get_output_port_by_name(self, name: str) -> Port:
125
- # TODO: add warning if port is extended
356
+ def get_output_port_by_name(self, name: str) -> Optional[Port]:
126
357
  for port in self.output_ports:
127
358
  if port.name == name:
128
359
  return port
@@ -148,9 +379,9 @@ class Node(ABC):
148
379
  self.try_create_extended_input_port(port_name)
149
380
  port = self.get_input_port_by_name(port_name)
150
381
  if port is None:
151
- raise ValueError(f'{port_name} is not a valid input port.')
382
+ raise ValueError(f"{port_name} is not a valid input port.")
152
383
  self.fill_input(port, value)
153
-
384
+
154
385
  def fill_input(self, port: Port, value: Any) -> None:
155
386
  port.activate(value)
156
387
 
@@ -168,7 +399,7 @@ class Node(ABC):
168
399
  for output_edge in self.output_edges:
169
400
  if output_edge.start_port == port:
170
401
  output_edge.end_port.prepare(data)
171
-
402
+
172
403
  def trigger_output_edges(self, port: Port) -> None:
173
404
  if isinstance(port, str):
174
405
  port = self.get_output_port_by_name(port)
@@ -178,7 +409,14 @@ class Node(ABC):
178
409
 
179
410
  # TODO: the result is outputed to the trigger now, maybe we should add a new function to output the result to the workflow
180
411
  def output_to_workflow(self, data: Any) -> None:
181
- self.workflow._handle_workflow_output(self.name, data)
412
+ if self.workflow and hasattr(self.workflow, "_handle_workflow_output"):
413
+ handler = cast(
414
+ Callable[[str, Any], None],
415
+ getattr(self.workflow, "_handle_workflow_output"),
416
+ )
417
+ handler(self.name, data)
418
+ else:
419
+ logger.warning("Workflow output handler not set; skipping output dispatch.")
182
420
 
183
421
  def extended_output_name(self, name: str, index: int) -> str:
184
422
  return name + '_' + str(index)
@@ -189,4 +427,5 @@ class Node(ABC):
189
427
  async def stream_output(self, data: Any) -> None:
190
428
  await self.workflow.call_callbacks(CallbackEvent.ON_NODE_STREAM_OUTPUT, node=self, output=data)
191
429
 
192
- node_register = Register[Node]()
430
+
431
+ node_register = Register[Node]()
@@ -4,7 +4,7 @@ import asyncio
4
4
  import json
5
5
  from loguru import logger
6
6
  from service_forge.workflow.trigger import Trigger
7
- from typing import AsyncIterator, Any
7
+ from typing import AsyncIterator, Any, Optional
8
8
  from fastapi import FastAPI, Request
9
9
  from fastapi.responses import StreamingResponse
10
10
  from service_forge.workflow.port import Port
@@ -13,6 +13,7 @@ from service_forge.api.routers.websocket.websocket_manager import websocket_mana
13
13
  from fastapi import HTTPException
14
14
  from google.protobuf.message import Message
15
15
  from google.protobuf.json_format import MessageToJson
16
+ from opentelemetry import context as otel_context_api
16
17
 
17
18
  class FastAPITrigger(Trigger):
18
19
  DEFAULT_INPUT_PORTS = [
@@ -28,12 +29,14 @@ class FastAPITrigger(Trigger):
28
29
  Port("user_id", int),
29
30
  Port("token", str),
30
31
  Port("data", Any),
32
+ Port("path_params", dict),
31
33
  ]
32
34
 
33
35
  def __init__(self, name: str):
34
36
  super().__init__(name)
35
37
  self.events = {}
36
38
  self.is_setup_route = False
39
+ self.task_contexts: dict[uuid.UUID, otel_context_api.Context] = {}
37
40
  self.app = None
38
41
  self.route_path = None
39
42
  self.route_method = None
@@ -41,23 +44,44 @@ class FastAPITrigger(Trigger):
41
44
  @staticmethod
42
45
  def serialize_result(result: Any):
43
46
  if isinstance(result, Message):
44
- return MessageToJson(
45
- result,
46
- preserving_proto_field_name=True
47
- )
47
+ return MessageToJson(result, preserving_proto_field_name=True)
48
48
  return result
49
49
 
50
+ def _normalize_result_or_raise(self, result: Any):
51
+ # TODO: 检查合并
52
+ if hasattr(result, "is_error") and hasattr(result, "result"):
53
+ if result.is_error:
54
+ if isinstance(result.result, HTTPException):
55
+ raise result.result
56
+ raise HTTPException(status_code=500, detail=str(result.result))
57
+ return self.serialize_result(result.result)
58
+
59
+ if isinstance(result, HTTPException):
60
+ raise result
61
+ if isinstance(result, Exception):
62
+ raise HTTPException(status_code=500, detail=str(result))
63
+
64
+ return self.serialize_result(result)
65
+
50
66
  async def handle_request(
51
67
  self,
52
68
  request: Request,
53
69
  data_type: type,
54
70
  extract_data_fn: callable[[Request], dict],
55
71
  is_stream: bool,
72
+ path_params: Optional[dict] = None,
56
73
  ):
57
74
  task_id = uuid.uuid4()
58
75
  self.result_queues[task_id] = asyncio.Queue()
59
76
 
77
+ # parse trace context
78
+ trace_ctx = otel_context_api.get_current()
79
+ self.task_contexts[task_id] = trace_ctx
80
+
60
81
  body_data = await extract_data_fn(request)
82
+ # Merge path parameters into body_data (path params take precedence)
83
+ if path_params:
84
+ body_data = {**body_data, **path_params}
61
85
  converted_data = data_type(**body_data)
62
86
 
63
87
  client_id = (
@@ -70,11 +94,16 @@ class FastAPITrigger(Trigger):
70
94
  steps = len(self.workflow.nodes) if hasattr(self.workflow, "nodes") else 1
71
95
  websocket_manager.create_task_with_client(task_id, client_id, workflow_name, steps)
72
96
 
97
+ # trigger_queue with trace_context, used in _run()
98
+ logger.info(f'user_id {getattr(request.state, "user_id", None)} token {getattr(request.state, "auth_token", None)}')
99
+
73
100
  self.trigger_queue.put_nowait({
74
101
  "id": task_id,
75
102
  "user_id": getattr(request.state, "user_id", None),
76
103
  "token": getattr(request.state, "auth_token", None),
77
104
  "data": converted_data,
105
+ "trace_context": trace_ctx,
106
+ "path_params": path_params,
78
107
  })
79
108
 
80
109
  if is_stream:
@@ -85,26 +114,23 @@ class FastAPITrigger(Trigger):
85
114
  while True:
86
115
  item = await self.stream_queues[task_id].get()
87
116
 
88
- if item.is_error:
117
+ if getattr(item, "is_error", False):
89
118
  yield f"event: error\ndata: {json.dumps({'detail': str(item.result)})}\n\n"
90
119
  break
91
-
92
- if item.is_end:
93
- # TODO: send the result?
120
+
121
+ if getattr(item, "is_end", False):
94
122
  break
95
123
 
96
124
  # TODO: modify
97
125
  serialized = self.serialize_result(item.result)
98
- if isinstance(serialized, str):
99
- data = serialized
100
- else:
101
- data = json.dumps(serialized)
102
-
126
+ data = serialized if isinstance(serialized, str) else json.dumps(serialized)
103
127
  yield f"data: {data}\n\n"
104
-
128
+
105
129
  except Exception as e:
106
130
  yield f"event: error\ndata: {json.dumps({'detail': str(e)})}\n\n"
107
131
  finally:
132
+ self.stream_queues.pop(task_id, None)
133
+
108
134
  if task_id in self.stream_queues:
109
135
  del self.stream_queues[task_id]
110
136
  if task_id in self.result_queues:
@@ -117,19 +143,13 @@ class FastAPITrigger(Trigger):
117
143
  "Cache-Control": "no-cache",
118
144
  "Connection": "keep-alive",
119
145
  "X-Accel-Buffering": "no",
120
- }
146
+ },
121
147
  )
122
- else:
123
- result = await self.result_queues[task_id].get()
124
- del self.result_queues[task_id]
125
148
 
126
- if result.is_error:
127
- if isinstance(result.result, HTTPException):
128
- raise result.result
129
- else:
130
- raise HTTPException(status_code=500, detail=str(result.result))
131
-
132
- return self.serialize_result(result.result)
149
+ # 非流式:等待结果
150
+ result = await self.result_queues[task_id].get()
151
+ self.result_queues.pop(task_id, None)
152
+ return self._normalize_result_or_raise(result)
133
153
 
134
154
  def _setup_route(self, app: FastAPI, path: str, method: str, data_type: type, is_stream: bool) -> None:
135
155
  async def get_data(request: Request) -> dict:
@@ -144,7 +164,10 @@ class FastAPITrigger(Trigger):
144
164
  extractor = get_data if method == "GET" else body_data
145
165
 
146
166
  async def handler(request: Request):
147
- return await self.handle_request(request, data_type, extractor, is_stream)
167
+ # Get path parameters from FastAPI request
168
+ # request.path_params is always available in FastAPI and contains path parameters
169
+ path_params = dict(request.path_params)
170
+ return await self.handle_request(request, data_type, extractor, is_stream, path_params)
148
171
 
149
172
  # Save route information for cleanup
150
173
  self.app = app
@@ -162,7 +185,14 @@ class FastAPITrigger(Trigger):
162
185
  else:
163
186
  raise ValueError(f"Invalid method {method}")
164
187
 
165
- async def _run(self, app: FastAPI, path: str, method: str, data_type: type, is_stream: bool = False) -> AsyncIterator[bool]:
188
+ async def _run(
189
+ self,
190
+ app: FastAPI,
191
+ path: str,
192
+ method: str,
193
+ data_type: type,
194
+ is_stream: bool = False,
195
+ ) -> AsyncIterator[bool]:
166
196
  if not self.is_setup_route:
167
197
  self._setup_route(app, path, method, data_type, is_stream)
168
198
  self.is_setup_route = True
@@ -170,9 +200,12 @@ class FastAPITrigger(Trigger):
170
200
  while True:
171
201
  try:
172
202
  trigger = await self.trigger_queue.get()
203
+ if trace_ctx := trigger.get("trace_context"):
204
+ self.task_contexts[trigger["id"]] = trace_ctx
173
205
  self.prepare_output_edges(self.get_output_port_by_name('user_id'), trigger['user_id'])
174
206
  self.prepare_output_edges(self.get_output_port_by_name('token'), trigger['token'])
175
207
  self.prepare_output_edges(self.get_output_port_by_name('data'), trigger['data'])
208
+ self.prepare_output_edges(self.get_output_port_by_name('path_params'), trigger['path_params'])
176
209
  yield self.trigger(trigger['id'])
177
210
  except Exception as e:
178
211
  logger.error(f"Error in FastAPITrigger._run: {e}")