service-forge 0.1.18__py3-none-any.whl → 0.1.39__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of service-forge might be problematic. Click here for more details.

Files changed (80) hide show
  1. service_forge/__init__.py +0 -0
  2. service_forge/api/deprecated_websocket_api.py +91 -33
  3. service_forge/api/deprecated_websocket_manager.py +70 -53
  4. service_forge/api/http_api.py +205 -55
  5. service_forge/api/kafka_api.py +113 -25
  6. service_forge/api/routers/meta_api/meta_api_router.py +57 -0
  7. service_forge/api/routers/service/service_router.py +42 -6
  8. service_forge/api/routers/trace/trace_router.py +326 -0
  9. service_forge/api/routers/websocket/websocket_router.py +69 -1
  10. service_forge/api/service_studio.py +9 -0
  11. service_forge/db/database.py +17 -0
  12. service_forge/execution_context.py +106 -0
  13. service_forge/frontend/static/assets/CreateNewNodeDialog-DkrEMxSH.js +1 -0
  14. service_forge/frontend/static/assets/CreateNewNodeDialog-DwFcBiGp.css +1 -0
  15. service_forge/frontend/static/assets/EditorSidePanel-BNVms9Fq.css +1 -0
  16. service_forge/frontend/static/assets/EditorSidePanel-DZbB3ILL.js +1 -0
  17. service_forge/frontend/static/assets/FeedbackPanel-CC8HX7Yo.js +1 -0
  18. service_forge/frontend/static/assets/FeedbackPanel-ClgniIVk.css +1 -0
  19. service_forge/frontend/static/assets/FormattedCodeViewer.vue_vue_type_script_setup_true_lang-BNuI1NCs.js +1 -0
  20. service_forge/frontend/static/assets/NodeDetailWrapper-BqFFM7-r.js +1 -0
  21. service_forge/frontend/static/assets/NodeDetailWrapper-pZBxv3J0.css +1 -0
  22. service_forge/frontend/static/assets/TestRunningDialog-D0GrCoYs.js +1 -0
  23. service_forge/frontend/static/assets/TestRunningDialog-dhXOsPgH.css +1 -0
  24. service_forge/frontend/static/assets/TracePanelWrapper-B9zvDSc_.js +1 -0
  25. service_forge/frontend/static/assets/TracePanelWrapper-BiednCrq.css +1 -0
  26. service_forge/frontend/static/assets/WorkflowEditor-CcaGGbko.js +3 -0
  27. service_forge/frontend/static/assets/WorkflowEditor-CmasOOYK.css +1 -0
  28. service_forge/frontend/static/assets/WorkflowList-Copuwi-a.css +1 -0
  29. service_forge/frontend/static/assets/WorkflowList-LrRJ7B7h.js +1 -0
  30. service_forge/frontend/static/assets/WorkflowStudio-CthjgII2.css +1 -0
  31. service_forge/frontend/static/assets/WorkflowStudio-FCyhGD4y.js +2 -0
  32. service_forge/frontend/static/assets/api-BDer3rj7.css +1 -0
  33. service_forge/frontend/static/assets/api-DyiqpKJK.js +1 -0
  34. service_forge/frontend/static/assets/code-editor-DBSql_sc.js +12 -0
  35. service_forge/frontend/static/assets/el-collapse-item-D4LG0FJ0.css +1 -0
  36. service_forge/frontend/static/assets/el-empty-D4ZqTl4F.css +1 -0
  37. service_forge/frontend/static/assets/el-form-item-BWkJzdQ_.css +1 -0
  38. service_forge/frontend/static/assets/el-input-D6B3r8CH.css +1 -0
  39. service_forge/frontend/static/assets/el-select-B0XIb2QK.css +1 -0
  40. service_forge/frontend/static/assets/el-tag-DljBBxJR.css +1 -0
  41. service_forge/frontend/static/assets/element-ui-D3x2y3TA.js +12 -0
  42. service_forge/frontend/static/assets/elkjs-Dm5QV7uy.js +24 -0
  43. service_forge/frontend/static/assets/highlightjs-D4ATuRwX.js +3 -0
  44. service_forge/frontend/static/assets/index-BMvodlwc.js +2 -0
  45. service_forge/frontend/static/assets/index-CjSe8i2q.css +1 -0
  46. service_forge/frontend/static/assets/js-yaml-yTPt38rv.js +32 -0
  47. service_forge/frontend/static/assets/time-DKCKV6Ug.js +1 -0
  48. service_forge/frontend/static/assets/ui-components-DQ7-U3pr.js +1 -0
  49. service_forge/frontend/static/assets/vue-core-DL-LgTX0.js +1 -0
  50. service_forge/frontend/static/assets/vue-flow-Dn7R8GPr.js +39 -0
  51. service_forge/frontend/static/index.html +16 -0
  52. service_forge/frontend/static/vite.svg +1 -0
  53. service_forge/model/meta_api/__init__.py +0 -0
  54. service_forge/model/meta_api/schema.py +29 -0
  55. service_forge/model/trace.py +82 -0
  56. service_forge/service.py +39 -11
  57. service_forge/service_config.py +14 -0
  58. service_forge/sft/cli.py +39 -0
  59. service_forge/sft/cmd/remote_deploy.py +160 -0
  60. service_forge/sft/cmd/remote_list_tars.py +111 -0
  61. service_forge/sft/config/injector.py +54 -7
  62. service_forge/sft/config/injector_default_files.py +13 -1
  63. service_forge/sft/config/sf_metadata.py +31 -27
  64. service_forge/sft/config/sft_config.py +18 -0
  65. service_forge/sft/util/assert_util.py +0 -1
  66. service_forge/telemetry.py +66 -0
  67. service_forge/utils/default_type_converter.py +1 -1
  68. service_forge/utils/type_converter.py +5 -0
  69. service_forge/utils/workflow_clone.py +1 -0
  70. service_forge/workflow/node.py +274 -27
  71. service_forge/workflow/triggers/fast_api_trigger.py +64 -28
  72. service_forge/workflow/triggers/websocket_api_trigger.py +66 -38
  73. service_forge/workflow/workflow.py +140 -37
  74. service_forge/workflow/workflow_callback.py +27 -4
  75. service_forge/workflow/workflow_factory.py +14 -0
  76. {service_forge-0.1.18.dist-info → service_forge-0.1.39.dist-info}/METADATA +4 -1
  77. service_forge-0.1.39.dist-info/RECORD +134 -0
  78. service_forge-0.1.18.dist-info/RECORD +0 -83
  79. {service_forge-0.1.18.dist-info → service_forge-0.1.39.dist-info}/WHEEL +0 -0
  80. {service_forge-0.1.18.dist-info → service_forge-0.1.39.dist-info}/entry_points.txt +0 -0
@@ -1,50 +1,78 @@
1
1
  from __future__ import annotations
2
2
 
3
- from typing import Any, AsyncIterator, Union, TYPE_CHECKING, Callable, Awaitable
4
- from abc import ABC, abstractmethod
3
+ import inspect
4
+ import json
5
5
  import uuid
6
+ from abc import ABC, abstractmethod
7
+ from dataclasses import asdict, is_dataclass
8
+ from typing import (
9
+ TYPE_CHECKING,
10
+ Any,
11
+ AsyncIterator,
12
+ Awaitable,
13
+ Callable,
14
+ Optional,
15
+ Union,
16
+ cast,
17
+ )
18
+
6
19
  from loguru import logger
7
- from .edge import Edge
8
- from .port import Port
9
- from .context import Context
10
- from ..utils.register import Register
20
+ from opentelemetry import context as otel_context_api
21
+ from opentelemetry import trace
22
+ from opentelemetry.trace import SpanKind
23
+ from pydantic import BaseModel
24
+
11
25
  from ..db.database import DatabaseManager, PostgresDatabase, MongoDatabase, RedisDatabase
12
26
  from ..utils.workflow_clone import node_clone
27
+ from ..execution_context import (
28
+ ExecutionContext,
29
+ get_current_context,
30
+ reset_current_context,
31
+ set_current_context,
32
+ )
33
+ from ..utils.register import Register
34
+ from .context import Context
35
+ from .edge import Edge
36
+ from .port import Port
37
+ from .workflow_callback import CallbackEvent
13
38
 
14
39
  if TYPE_CHECKING:
15
40
  from .workflow import Workflow
16
41
 
42
+
17
43
  class Node(ABC):
18
44
  DEFAULT_INPUT_PORTS: list[Port] = []
19
45
  DEFAULT_OUTPUT_PORTS: list[Port] = []
20
46
 
21
- CLASS_NOT_REQUIRED_TO_REGISTER = ['Node']
47
+ CLASS_NOT_REQUIRED_TO_REGISTER = ["Node"]
22
48
  AUTO_FILL_INPUT_PORTS = []
23
49
 
24
50
  def __init__(
25
51
  self,
26
52
  name: str,
27
- context: Context = None,
28
- input_edges: list[Edge] = None,
29
- output_edges: list[Edge] = None,
53
+ context: Optional[Context] = None,
54
+ input_edges: Optional[list[Edge]] = None,
55
+ output_edges: Optional[list[Edge]] = None,
30
56
  input_ports: list[Port] = DEFAULT_INPUT_PORTS,
31
57
  output_ports: list[Port] = DEFAULT_OUTPUT_PORTS,
32
- query_user: Callable[[str, str], Awaitable[str]] = None,
58
+ query_user: Optional[Callable[[str, str], Awaitable[str]]] = None,
33
59
  ) -> None:
34
60
  from .workflow_group import WorkflowGroup
61
+
35
62
  self.name = name
36
63
  self.input_edges = [] if input_edges is None else input_edges
37
64
  self.output_edges = [] if output_edges is None else output_edges
38
65
  self.input_ports = input_ports
39
66
  self.output_ports = output_ports
40
- self.workflow: Workflow = None
67
+ self.workflow: Optional[Workflow] = None
41
68
  self.query_user = query_user
42
- self.sub_workflows: WorkflowGroup = None
69
+ self.sub_workflows: Optional[WorkflowGroup] = None
43
70
 
44
71
  # runtime variables
45
72
  self.context = context
46
73
  self.input_variables: dict[Port, Any] = {}
47
74
  self.num_activated_input_edges = 0
75
+ self._tracer = trace.get_tracer("service_forge.node")
48
76
 
49
77
  @property
50
78
  def default_postgres_database(self) -> PostgresDatabase | None:
@@ -60,28 +88,39 @@ class Node(ABC):
60
88
 
61
89
  @property
62
90
  def database_manager(self) -> DatabaseManager:
91
+ assert self.workflow
63
92
  return self.workflow.database_manager
64
93
 
94
+ @property
95
+ def global_context(self) -> Context:
96
+ return self.workflow.global_context
97
+
65
98
  def backup(self) -> None:
66
99
  # do NOT use deepcopy here
67
100
  # self.bak_context = deepcopy(self.context)
68
101
  # TODO: what if the value changes after backup?
69
- self.bak_input_variables = {port: value for port, value in self.input_variables.items()}
102
+ self.bak_input_variables = {
103
+ port: value for port, value in self.input_variables.items()
104
+ }
70
105
  self.bak_num_activated_input_edges = self.num_activated_input_edges
71
106
 
72
107
  def reset(self) -> None:
73
108
  # self.context = deepcopy(self.bak_context)
74
- self.input_variables = {port: value for port, value in self.bak_input_variables.items()}
109
+ self.input_variables = {
110
+ port: value for port, value in self.bak_input_variables.items()
111
+ }
75
112
  self.num_activated_input_edges = self.bak_num_activated_input_edges
76
113
 
77
114
  def __init_subclass__(cls) -> None:
78
115
  if cls.__name__ not in Node.CLASS_NOT_REQUIRED_TO_REGISTER:
116
+ # TODO: Register currently stores class objects; clarify Register typing vs instance usage.
79
117
  node_register.register(cls.__name__, cls)
80
118
  return super().__init_subclass__()
81
119
 
82
- def _query_user(self, prompt: str) -> Callable[[str, str], Awaitable[str]]:
120
+ def _query_user(self, prompt: str) -> Awaitable[str]:
121
+ assert self.query_user
83
122
  return self.query_user(self.name, prompt)
84
-
123
+
85
124
  def variables_to_params(self) -> dict[str, Any]:
86
125
  params = {port.name: self.input_variables[port] for port in self.input_variables.keys() if not port.is_extended_generated}
87
126
  for port in self.input_variables.keys():
@@ -94,6 +133,7 @@ class Node(ABC):
94
133
 
95
134
  def is_trigger(self) -> bool:
96
135
  from .trigger import Trigger
136
+
97
137
  return isinstance(self, Trigger)
98
138
 
99
139
  # TODO: maybe add a function before the run function?
@@ -101,23 +141,219 @@ class Node(ABC):
101
141
  @abstractmethod
102
142
  async def _run(self, **kwargs) -> Union[None, AsyncIterator]:
103
143
  ...
144
+
145
+ async def clear(self) -> None:
146
+ ...
104
147
 
105
148
  def run(self) -> Union[None, AsyncIterator]:
149
+ task_id: uuid.UUID | None = None
106
150
  for key in list(self.input_variables.keys()):
107
151
  if key and key.name[0].isupper():
108
152
  del self.input_variables[key]
109
153
  params = self.variables_to_params()
110
- return self._run(**params)
154
+ if task_id is not None and "task_id" in self._run.__code__.co_varnames:
155
+ params["task_id"] = task_id
156
+ base_context = get_current_context()
157
+ parent_context = (
158
+ base_context.trace_context
159
+ if base_context and base_context.trace_context
160
+ else otel_context_api.get_current()
161
+ )
162
+ span_name = f"Node {self.name}"
111
163
 
112
- def get_input_port_by_name(self, name: str) -> Port:
113
- # TODO: add warning if port is extended
164
+ if inspect.isasyncgenfunction(self._run):
165
+ return self._run_async_generator(
166
+ params, task_id, base_context, parent_context, span_name
167
+ )
168
+ if inspect.iscoroutinefunction(self._run):
169
+ return self._run_async(
170
+ params, task_id, base_context, parent_context, span_name
171
+ )
172
+
173
+ return self._run_sync(params, task_id, base_context, parent_context, span_name)
174
+
175
+ def _build_execution_context(
176
+ self, base_context: ExecutionContext | None, span: trace.Span
177
+ ) -> ExecutionContext:
178
+ return ExecutionContext(
179
+ trace_context=otel_context_api.get_current(),
180
+ span=span,
181
+ state=base_context.state if base_context else {},
182
+ metadata={
183
+ **(base_context.metadata if base_context else {}),
184
+ "node": self.name,
185
+ "workflow_name": getattr(self.workflow, "name", None),
186
+ },
187
+ )
188
+
189
+ @staticmethod
190
+ def _serialize_for_trace(value: Any, max_length: int = 4000) -> tuple[str, bool]:
191
+ def _normalize(val: Any) -> Any:
192
+ if isinstance(val, BaseModel):
193
+ return val.model_dump()
194
+ if hasattr(val, "model_dump"):
195
+ try:
196
+ return val.model_dump()
197
+ except Exception:
198
+ pass
199
+ if hasattr(val, "dict"):
200
+ try:
201
+ return val.dict()
202
+ except Exception:
203
+ pass
204
+ if is_dataclass is not None and is_dataclass(val):
205
+ return asdict(val) if asdict else val
206
+ if isinstance(val, dict):
207
+ return {k: _normalize(v) for k, v in val.items()}
208
+ if isinstance(val, (list, tuple)):
209
+ return [_normalize(v) for v in val]
210
+ return val
211
+
212
+ normalized_value = _normalize(value)
213
+ serialized = json.dumps(normalized_value, ensure_ascii=False, default=str)
214
+ if len(serialized) > max_length:
215
+ return serialized[:max_length], True
216
+ return serialized, False
217
+
218
+ def _set_span_attributes(
219
+ self, span: trace.Span, params: dict[str, Any], task_id: uuid.UUID | None
220
+ ) -> None:
221
+ span.set_attribute("node.name", self.name)
222
+ if self.workflow is not None:
223
+ span.set_attribute("workflow.name", self.workflow.name)
224
+ if task_id is not None:
225
+ span.set_attribute("workflow.task_id", str(task_id))
226
+ span.set_attribute("node.input_keys", ",".join(params.keys()))
227
+ serialized_inputs, inputs_truncated = self._serialize_for_trace(params)
228
+ span.set_attribute("node.inputs", serialized_inputs)
229
+ if inputs_truncated:
230
+ span.set_attribute("node.inputs_truncated", True)
231
+
232
+ def _record_output(self, span: trace.Span, output: Any) -> None:
233
+ span.set_attribute(
234
+ "node.output_type", type(output).__name__ if output is not None else "None"
235
+ )
236
+ serialized_output, output_truncated = self._serialize_for_trace(output)
237
+ span.set_attribute("node.output", serialized_output)
238
+ if output_truncated:
239
+ span.set_attribute("node.output_truncated", True)
240
+
241
+ async def _run_async(
242
+ self,
243
+ params: dict[str, Any],
244
+ task_id: uuid.UUID | None,
245
+ base_context: ExecutionContext | None,
246
+ parent_context: otel_context_api.Context,
247
+ span_name: str,
248
+ ) -> Any:
249
+ with self._tracer.start_as_current_span(
250
+ span_name,
251
+ context=parent_context,
252
+ kind=SpanKind.INTERNAL,
253
+ ) as span:
254
+ self._set_span_attributes(span, params, task_id)
255
+ exec_ctx = self._build_execution_context(base_context, span)
256
+ token = set_current_context(exec_ctx)
257
+ try:
258
+ result = self._run(**params)
259
+ if inspect.isawaitable(result):
260
+ result = await result
261
+ self._record_output(span, result)
262
+ return result
263
+ finally:
264
+ reset_current_context(token)
265
+
266
+ async def _run_async_generator(
267
+ self,
268
+ params: dict[str, Any],
269
+ task_id: uuid.UUID | None,
270
+ base_context: ExecutionContext | None,
271
+ parent_context: otel_context_api.Context,
272
+ span_name: str,
273
+ ) -> AsyncIterator[Any]:
274
+ # Trigger 节点是长期运行的 async generator,这里为每次触发单独生成/关闭 span,避免一个 span 挂载所有请求。
275
+ if self.is_trigger():
276
+ async for item in self._run(**params):
277
+ trigger_parent_context = parent_context
278
+ if hasattr(self, "task_contexts") and isinstance(item, uuid.UUID):
279
+ trigger_ctx = getattr(self, "task_contexts").get(item)
280
+ if trigger_ctx is not None:
281
+ trigger_parent_context = trigger_ctx
282
+ with self._tracer.start_as_current_span(
283
+ span_name,
284
+ context=trigger_parent_context,
285
+ kind=SpanKind.INTERNAL,
286
+ ) as span:
287
+ self._set_span_attributes(span, params, task_id)
288
+ span.set_attribute("node.output_type", "async_generator")
289
+ serialized_item, item_truncated = self._serialize_for_trace(item)
290
+ span.add_event(
291
+ "node.output_item",
292
+ {
293
+ "value": serialized_item,
294
+ "truncated": item_truncated,
295
+ },
296
+ )
297
+ exec_ctx = self._build_execution_context(base_context, span)
298
+ token = set_current_context(exec_ctx)
299
+ try:
300
+ yield item
301
+ finally:
302
+ reset_current_context(token)
303
+ else:
304
+ with self._tracer.start_as_current_span(
305
+ span_name,
306
+ context=parent_context,
307
+ kind=SpanKind.INTERNAL,
308
+ ) as span:
309
+ self._set_span_attributes(span, params, task_id)
310
+ span.set_attribute("node.output_type", "async_generator")
311
+ exec_ctx = self._build_execution_context(base_context, span)
312
+ token = set_current_context(exec_ctx)
313
+ try:
314
+ async for item in self._run(**params):
315
+ serialized_item, item_truncated = self._serialize_for_trace(item)
316
+ span.add_event(
317
+ "node.output_item",
318
+ {
319
+ "value": serialized_item,
320
+ "truncated": item_truncated,
321
+ },
322
+ )
323
+ yield item
324
+ finally:
325
+ reset_current_context(token)
326
+
327
+ def _run_sync(
328
+ self,
329
+ params: dict[str, Any],
330
+ task_id: uuid.UUID | None,
331
+ base_context: ExecutionContext | None,
332
+ parent_context: otel_context_api.Context,
333
+ span_name: str,
334
+ ) -> Any:
335
+ with self._tracer.start_as_current_span(
336
+ span_name,
337
+ context=parent_context,
338
+ kind=SpanKind.INTERNAL,
339
+ ) as span:
340
+ self._set_span_attributes(span, params, task_id)
341
+ exec_ctx = self._build_execution_context(base_context, span)
342
+ token = set_current_context(exec_ctx)
343
+ try:
344
+ result = self._run(**params)
345
+ self._record_output(span, result)
346
+ return result
347
+ finally:
348
+ reset_current_context(token)
349
+
350
+ def get_input_port_by_name(self, name: str) -> Optional[Port]:
114
351
  for port in self.input_ports:
115
352
  if port.name == name:
116
353
  return port
117
354
  return None
118
355
 
119
- def get_output_port_by_name(self, name: str) -> Port:
120
- # TODO: add warning if port is extended
356
+ def get_output_port_by_name(self, name: str) -> Optional[Port]:
121
357
  for port in self.output_ports:
122
358
  if port.name == name:
123
359
  return port
@@ -143,9 +379,9 @@ class Node(ABC):
143
379
  self.try_create_extended_input_port(port_name)
144
380
  port = self.get_input_port_by_name(port_name)
145
381
  if port is None:
146
- raise ValueError(f'{port_name} is not a valid input port.')
382
+ raise ValueError(f"{port_name} is not a valid input port.")
147
383
  self.fill_input(port, value)
148
-
384
+
149
385
  def fill_input(self, port: Port, value: Any) -> None:
150
386
  port.activate(value)
151
387
 
@@ -163,7 +399,7 @@ class Node(ABC):
163
399
  for output_edge in self.output_edges:
164
400
  if output_edge.start_port == port:
165
401
  output_edge.end_port.prepare(data)
166
-
402
+
167
403
  def trigger_output_edges(self, port: Port) -> None:
168
404
  if isinstance(port, str):
169
405
  port = self.get_output_port_by_name(port)
@@ -173,7 +409,14 @@ class Node(ABC):
173
409
 
174
410
  # TODO: the result is outputed to the trigger now, maybe we should add a new function to output the result to the workflow
175
411
  def output_to_workflow(self, data: Any) -> None:
176
- self.workflow._handle_workflow_output(self.name, data)
412
+ if self.workflow and hasattr(self.workflow, "_handle_workflow_output"):
413
+ handler = cast(
414
+ Callable[[str, Any], None],
415
+ getattr(self.workflow, "_handle_workflow_output"),
416
+ )
417
+ handler(self.name, data)
418
+ else:
419
+ logger.warning("Workflow output handler not set; skipping output dispatch.")
177
420
 
178
421
  def extended_output_name(self, name: str, index: int) -> str:
179
422
  return name + '_' + str(index)
@@ -181,4 +424,8 @@ class Node(ABC):
181
424
  def _clone(self, context: Context) -> Node:
182
425
  return node_clone(self, context)
183
426
 
184
- node_register = Register[Node]()
427
+ async def stream_output(self, data: Any) -> None:
428
+ await self.workflow.call_callbacks(CallbackEvent.ON_NODE_STREAM_OUTPUT, node=self, output=data)
429
+
430
+
431
+ node_register = Register[Node]()
@@ -4,7 +4,7 @@ import asyncio
4
4
  import json
5
5
  from loguru import logger
6
6
  from service_forge.workflow.trigger import Trigger
7
- from typing import AsyncIterator, Any
7
+ from typing import AsyncIterator, Any, Optional
8
8
  from fastapi import FastAPI, Request
9
9
  from fastapi.responses import StreamingResponse
10
10
  from service_forge.workflow.port import Port
@@ -13,6 +13,7 @@ from service_forge.api.routers.websocket.websocket_manager import websocket_mana
13
13
  from fastapi import HTTPException
14
14
  from google.protobuf.message import Message
15
15
  from google.protobuf.json_format import MessageToJson
16
+ from opentelemetry import context as otel_context_api
16
17
 
17
18
  class FastAPITrigger(Trigger):
18
19
  DEFAULT_INPUT_PORTS = [
@@ -26,13 +27,16 @@ class FastAPITrigger(Trigger):
26
27
  DEFAULT_OUTPUT_PORTS = [
27
28
  Port("trigger", bool),
28
29
  Port("user_id", int),
30
+ Port("token", str),
29
31
  Port("data", Any),
32
+ Port("path_params", dict),
30
33
  ]
31
34
 
32
35
  def __init__(self, name: str):
33
36
  super().__init__(name)
34
37
  self.events = {}
35
38
  self.is_setup_route = False
39
+ self.task_contexts: dict[uuid.UUID, otel_context_api.Context] = {}
36
40
  self.app = None
37
41
  self.route_path = None
38
42
  self.route_method = None
@@ -40,23 +44,44 @@ class FastAPITrigger(Trigger):
40
44
  @staticmethod
41
45
  def serialize_result(result: Any):
42
46
  if isinstance(result, Message):
43
- return MessageToJson(
44
- result,
45
- preserving_proto_field_name=True
46
- )
47
+ return MessageToJson(result, preserving_proto_field_name=True)
47
48
  return result
48
49
 
50
+ def _normalize_result_or_raise(self, result: Any):
51
+ # TODO: 检查合并
52
+ if hasattr(result, "is_error") and hasattr(result, "result"):
53
+ if result.is_error:
54
+ if isinstance(result.result, HTTPException):
55
+ raise result.result
56
+ raise HTTPException(status_code=500, detail=str(result.result))
57
+ return self.serialize_result(result.result)
58
+
59
+ if isinstance(result, HTTPException):
60
+ raise result
61
+ if isinstance(result, Exception):
62
+ raise HTTPException(status_code=500, detail=str(result))
63
+
64
+ return self.serialize_result(result)
65
+
49
66
  async def handle_request(
50
67
  self,
51
68
  request: Request,
52
69
  data_type: type,
53
70
  extract_data_fn: callable[[Request], dict],
54
71
  is_stream: bool,
72
+ path_params: Optional[dict] = None,
55
73
  ):
56
74
  task_id = uuid.uuid4()
57
75
  self.result_queues[task_id] = asyncio.Queue()
58
76
 
77
+ # parse trace context
78
+ trace_ctx = otel_context_api.get_current()
79
+ self.task_contexts[task_id] = trace_ctx
80
+
59
81
  body_data = await extract_data_fn(request)
82
+ # Merge path parameters into body_data (path params take precedence)
83
+ if path_params:
84
+ body_data = {**body_data, **path_params}
60
85
  converted_data = data_type(**body_data)
61
86
 
62
87
  client_id = (
@@ -69,10 +94,16 @@ class FastAPITrigger(Trigger):
69
94
  steps = len(self.workflow.nodes) if hasattr(self.workflow, "nodes") else 1
70
95
  websocket_manager.create_task_with_client(task_id, client_id, workflow_name, steps)
71
96
 
97
+ # trigger_queue with trace_context, used in _run()
98
+ logger.info(f'user_id {getattr(request.state, "user_id", None)} token {getattr(request.state, "auth_token", None)}')
99
+
72
100
  self.trigger_queue.put_nowait({
73
101
  "id": task_id,
74
102
  "user_id": getattr(request.state, "user_id", None),
103
+ "token": getattr(request.state, "auth_token", None),
75
104
  "data": converted_data,
105
+ "trace_context": trace_ctx,
106
+ "path_params": path_params,
76
107
  })
77
108
 
78
109
  if is_stream:
@@ -83,26 +114,23 @@ class FastAPITrigger(Trigger):
83
114
  while True:
84
115
  item = await self.stream_queues[task_id].get()
85
116
 
86
- if item.is_error:
117
+ if getattr(item, "is_error", False):
87
118
  yield f"event: error\ndata: {json.dumps({'detail': str(item.result)})}\n\n"
88
119
  break
89
-
90
- if item.is_end:
91
- # TODO: send the result?
120
+
121
+ if getattr(item, "is_end", False):
92
122
  break
93
123
 
94
124
  # TODO: modify
95
125
  serialized = self.serialize_result(item.result)
96
- if isinstance(serialized, str):
97
- data = serialized
98
- else:
99
- data = json.dumps(serialized)
100
-
126
+ data = serialized if isinstance(serialized, str) else json.dumps(serialized)
101
127
  yield f"data: {data}\n\n"
102
-
128
+
103
129
  except Exception as e:
104
130
  yield f"event: error\ndata: {json.dumps({'detail': str(e)})}\n\n"
105
131
  finally:
132
+ self.stream_queues.pop(task_id, None)
133
+
106
134
  if task_id in self.stream_queues:
107
135
  del self.stream_queues[task_id]
108
136
  if task_id in self.result_queues:
@@ -115,19 +143,13 @@ class FastAPITrigger(Trigger):
115
143
  "Cache-Control": "no-cache",
116
144
  "Connection": "keep-alive",
117
145
  "X-Accel-Buffering": "no",
118
- }
146
+ },
119
147
  )
120
- else:
121
- result = await self.result_queues[task_id].get()
122
- del self.result_queues[task_id]
123
148
 
124
- if result.is_error:
125
- if isinstance(result.result, HTTPException):
126
- raise result.result
127
- else:
128
- raise HTTPException(status_code=500, detail=str(result.result))
129
-
130
- return self.serialize_result(result.result)
149
+ # 非流式:等待结果
150
+ result = await self.result_queues[task_id].get()
151
+ self.result_queues.pop(task_id, None)
152
+ return self._normalize_result_or_raise(result)
131
153
 
132
154
  def _setup_route(self, app: FastAPI, path: str, method: str, data_type: type, is_stream: bool) -> None:
133
155
  async def get_data(request: Request) -> dict:
@@ -142,7 +164,10 @@ class FastAPITrigger(Trigger):
142
164
  extractor = get_data if method == "GET" else body_data
143
165
 
144
166
  async def handler(request: Request):
145
- return await self.handle_request(request, data_type, extractor, is_stream)
167
+ # Get path parameters from FastAPI request
168
+ # request.path_params is always available in FastAPI and contains path parameters
169
+ path_params = dict(request.path_params)
170
+ return await self.handle_request(request, data_type, extractor, is_stream, path_params)
146
171
 
147
172
  # Save route information for cleanup
148
173
  self.app = app
@@ -160,7 +185,14 @@ class FastAPITrigger(Trigger):
160
185
  else:
161
186
  raise ValueError(f"Invalid method {method}")
162
187
 
163
- async def _run(self, app: FastAPI, path: str, method: str, data_type: type, is_stream: bool = False) -> AsyncIterator[bool]:
188
+ async def _run(
189
+ self,
190
+ app: FastAPI,
191
+ path: str,
192
+ method: str,
193
+ data_type: type,
194
+ is_stream: bool = False,
195
+ ) -> AsyncIterator[bool]:
164
196
  if not self.is_setup_route:
165
197
  self._setup_route(app, path, method, data_type, is_stream)
166
198
  self.is_setup_route = True
@@ -168,8 +200,12 @@ class FastAPITrigger(Trigger):
168
200
  while True:
169
201
  try:
170
202
  trigger = await self.trigger_queue.get()
203
+ if trace_ctx := trigger.get("trace_context"):
204
+ self.task_contexts[trigger["id"]] = trace_ctx
171
205
  self.prepare_output_edges(self.get_output_port_by_name('user_id'), trigger['user_id'])
206
+ self.prepare_output_edges(self.get_output_port_by_name('token'), trigger['token'])
172
207
  self.prepare_output_edges(self.get_output_port_by_name('data'), trigger['data'])
208
+ self.prepare_output_edges(self.get_output_port_by_name('path_params'), trigger['path_params'])
173
209
  yield self.trigger(trigger['id'])
174
210
  except Exception as e:
175
211
  logger.error(f"Error in FastAPITrigger._run: {e}")