service-forge 0.1.11__py3-none-any.whl → 0.1.21__py3-none-any.whl

This diff compares the contents of publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.

Potentially problematic release: this version of service-forge might be problematic.

Files changed (35)
  1. service_forge/api/http_api.py +4 -0
  2. service_forge/api/routers/feedback/feedback_router.py +148 -0
  3. service_forge/api/routers/service/service_router.py +22 -32
  4. service_forge/current_service.py +14 -0
  5. service_forge/db/database.py +29 -32
  6. service_forge/db/migrations/feedback_migration.py +154 -0
  7. service_forge/db/models/__init__.py +0 -0
  8. service_forge/db/models/feedback.py +33 -0
  9. service_forge/llm/__init__.py +5 -0
  10. service_forge/model/feedback.py +30 -0
  11. service_forge/service.py +118 -126
  12. service_forge/service_config.py +42 -156
  13. service_forge/sft/config/injector.py +33 -23
  14. service_forge/sft/config/sft_config.py +55 -8
  15. service_forge/storage/__init__.py +5 -0
  16. service_forge/storage/feedback_storage.py +245 -0
  17. service_forge/utils/workflow_clone.py +3 -2
  18. service_forge/workflow/node.py +8 -0
  19. service_forge/workflow/nodes/llm/query_llm_node.py +1 -1
  20. service_forge/workflow/trigger.py +4 -0
  21. service_forge/workflow/triggers/a2a_api_trigger.py +2 -0
  22. service_forge/workflow/triggers/fast_api_trigger.py +32 -0
  23. service_forge/workflow/triggers/kafka_api_trigger.py +3 -0
  24. service_forge/workflow/triggers/once_trigger.py +4 -1
  25. service_forge/workflow/triggers/period_trigger.py +4 -1
  26. service_forge/workflow/triggers/websocket_api_trigger.py +15 -11
  27. service_forge/workflow/workflow.py +26 -4
  28. service_forge/workflow/workflow_config.py +66 -0
  29. service_forge/workflow/workflow_factory.py +86 -85
  30. service_forge/workflow/workflow_group.py +33 -9
  31. {service_forge-0.1.11.dist-info → service_forge-0.1.21.dist-info}/METADATA +1 -1
  32. {service_forge-0.1.11.dist-info → service_forge-0.1.21.dist-info}/RECORD +34 -26
  33. service_forge/api/routers/service/__init__.py +0 -4
  34. {service_forge-0.1.11.dist-info → service_forge-0.1.21.dist-info}/WHEEL +0 -0
  35. {service_forge-0.1.11.dist-info → service_forge-0.1.21.dist-info}/entry_points.txt +0 -0

service_forge/sft/config/sft_config.py
@@ -21,6 +21,8 @@ class SftConfig:
         "inject_postgres_port": "Postgres port for services",
         "inject_postgres_user": "Postgres user for services",
         "inject_postgres_password": "Postgres password for services",
+        "inject_feedback_api_url": "Feedback API URL for services",
+        "inject_feedback_api_timeout": "Feedback API timeout for services",
         "deepseek_api_key": "DeepSeek API key",
         "deepseek_base_url": "DeepSeek base URL",
     }
@@ -52,6 +54,9 @@ class SftConfig:
         inject_redis_port: int = 6379,
         inject_redis_password: str = "rDdM2Y2gX9",

+        inject_feedback_api_url: str = "http://vps.shiweinan.com:37919/api/v1/feedback",
+        inject_feedback_api_timeout: int = 5,
+
         deepseek_api_key: str = "82c9df22-f6ed-411e-90d7-c5255376b7ca",
         deepseek_base_url: str = "https://ark.cn-beijing.volces.com/api/v3",
     ):
@@ -80,6 +85,9 @@ class SftConfig:
         self.inject_redis_port = inject_redis_port
         self.inject_redis_password = inject_redis_password

+        self.inject_feedback_api_url = inject_feedback_api_url
+        self.inject_feedback_api_timeout = inject_feedback_api_timeout
+
         self.deepseek_api_key = deepseek_api_key
         self.deepseek_base_url = deepseek_base_url

@@ -91,10 +99,23 @@ class SftConfig:
     def upload_timeout(self) -> int:
         return 300  # 5 minutes default timeout

-    @classmethod
-    def get_config_keys(cls) -> list[str]:
-        sig = inspect.signature(cls.__init__)
-        return [param for param in sig.parameters.keys() if param != 'self']
+    def get_config_keys(self) -> list[str]:
+        # Get the initial configuration parameters from the __init__ method
+        sig = inspect.signature(self.__class__.__init__)
+        init_keys = [param for param in sig.parameters.keys() if param != 'self']
+
+        # Get all instance attributes (including dynamically added configurations)
+        instance_keys = []
+        for attr_name in dir(self):
+            # Exclude special methods, private attributes, class attributes, and methods
+            if (not attr_name.startswith('_') and
+                    not callable(getattr(self, attr_name)) and
+                    attr_name not in ['CONFIG_ROOT', 'CONFIG_DESCRIPTIONS']):
+                instance_keys.append(attr_name)
+
+        # Merge and deduplicate, keeping order (initial configs first, then dynamically added)
+        all_keys = list(dict.fromkeys(init_keys + instance_keys))
+        return all_keys

     @property
     def config_file_path(self) -> Path:
@@ -105,13 +126,30 @@ class SftConfig:

     def to_dict(self) -> dict:
         config_keys = self.get_config_keys()
-        return {key: getattr(self, key) for key in config_keys}
+        result = {}
+        for key in config_keys:
+            value = getattr(self, key)
+            # Convert Path objects to strings for JSON serialization
+            if isinstance(value, Path):
+                value = str(value)
+            result[key] = value
+        return result

     def from_dict(self, data: dict) -> None:
-        config_keys = self.get_config_keys()
-        for key in config_keys:
+        # Get the initial configuration parameters from the __init__ method
+        sig = inspect.signature(self.__class__.__init__)
+        init_keys = [param for param in sig.parameters.keys() if param != 'self']
+
+        # First, set all initial configuration parameters
+        for key in init_keys:
             if key in data:
                 setattr(self, key, data[key])
+
+        # Then handle any additional keys that might be dynamically added configurations
+        for key, value in data.items():
+            if key not in init_keys and key not in ['CONFIG_ROOT', 'CONFIG_DESCRIPTIONS']:
+                # This might be a dynamically added configuration
+                setattr(self, key, value)

     def save(self) -> None:
         self.ensure_config_dir()
@@ -121,13 +159,22 @@ class SftConfig:
     def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
         return getattr(self, key, default)

-    def set(self, key: str, value: str) -> None:
+    def set(self, key: str, value: str, description: Optional[str] = None) -> None:
         if key in ["config_root"]:
             raise ValueError(f"{key} is read-only")
         if hasattr(self, key):
             setattr(self, key, value)
+            if description:
+                self.CONFIG_DESCRIPTIONS[key] = description
         else:
             raise ValueError(f"Unknown config key: {key}")
+
+    def add(self, key: str, value: str, description: Optional[str] = None) -> None:
+        if hasattr(self, key):
+            raise ValueError(f"{key} already exists")
+        setattr(self, key, value)
+        if description:
+            self.CONFIG_DESCRIPTIONS[key] = description

     def update(self, updates: dict) -> None:
         for key, value in updates.items():
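
Taken together, these SftConfig changes let callers register config keys at runtime and have them survive serialization. A minimal usage sketch, with a hypothetical key name (inject_smtp_host) and assuming SftConfig() is constructible with the defaults shown in its signature above:

    from service_forge.sft.config.sft_config import SftConfig

    config = SftConfig()

    # add() registers a brand-new key (ValueError if it already exists)
    # and can record a human-readable description alongside it.
    config.add("inject_smtp_host", "smtp.example.com", description="SMTP host for services")

    # set() updates an existing key; unknown keys raise ValueError,
    # and "config_root" stays read-only.
    config.set("inject_feedback_api_timeout", "10")

    # get_config_keys() now scans instance attributes, so the dynamic
    # key shows up in serialized output...
    assert "inject_smtp_host" in config.to_dict()

    # ...and from_dict() accepts keys beyond the __init__ parameters.
    restored = SftConfig()
    restored.from_dict({"inject_smtp_host": "smtp.example.com"})
    assert restored.get("inject_smtp_host") == "smtp.example.com"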

service_forge/storage/__init__.py
@@ -0,0 +1,5 @@
+"""Storage layer for business logic."""
+
+from .feedback_storage import FeedbackStorage, feedback_storage
+
+__all__ = ["FeedbackStorage", "feedback_storage"]

service_forge/storage/feedback_storage.py
@@ -0,0 +1,245 @@
+from typing import Optional, Any
+from datetime import datetime
+from uuid import uuid4
+from loguru import logger
+from sqlalchemy import select, and_
+from sqlalchemy.exc import SQLAlchemyError
+
+from ..db.models.feedback import FeedbackBase
+from ..db.database import DatabaseManager
+
+
+class FeedbackStorage:
+    """Feedback storage manager, backed by a PostgreSQL database."""
+
+    def __init__(self, database_manager: DatabaseManager = None):
+        self.database_manager = database_manager
+        self._db = None
+        # In-memory fallback storage (used when no database is configured)
+        self._storage: dict[str, dict[str, Any]] = {}
+        self._task_index: dict[str, list[str]] = {}
+        self._workflow_index: dict[str, list[str]] = {}
+
+    @property
+    def db(self):
+        """Lazily acquire the database connection."""
+        if self._db is None and self.database_manager is not None:
+            self._db = self.database_manager.get_default_postgres_database()
+        return self._db
+
+    async def create_feedback(
+        self,
+        task_id: str,
+        workflow_name: str,
+        rating: Optional[int] = None,
+        comment: Optional[str] = None,
+        metadata: Optional[dict[str, Any]] = None,
+    ) -> dict[str, Any]:
+        """Create a new feedback record."""
+        if self.db is None:
+            logger.warning("Database not initialized; using in-memory storage (data will be lost on restart)")
+            return self._create_feedback_in_memory(task_id, workflow_name, rating, comment, metadata)
+
+        try:
+            feedback_id = uuid4()
+            created_at = datetime.now()
+
+            feedback = FeedbackBase(
+                feedback_id=feedback_id,
+                task_id=task_id,
+                workflow_name=workflow_name,
+                rating=rating,
+                comment=comment,
+                extra_metadata=metadata or {},
+                created_at=created_at,
+            )
+
+            session_factory = await self.db.get_session_factory()
+            async with session_factory() as session:
+                session.add(feedback)
+                await session.commit()
+                await session.refresh(feedback)
+
+            logger.info(f"Created feedback: feedback_id={feedback_id}, task_id={task_id}, workflow={workflow_name}")
+            return feedback.to_dict()
+
+        except SQLAlchemyError as e:
+            logger.error(f"Database error while creating feedback: {e}")
+            raise
+        except Exception as e:
+            logger.error(f"Failed to create feedback: {e}")
+            raise
+
+    async def get_feedback(self, feedback_id: str) -> Optional[dict[str, Any]]:
+        """Fetch a feedback record by feedback ID."""
+        if self.db is None:
+            return self._get_feedback_from_memory(feedback_id)
+
+        try:
+            session_factory = await self.db.get_session_factory()
+            async with session_factory() as session:
+                result = await session.execute(
+                    select(FeedbackBase).where(FeedbackBase.feedback_id == feedback_id)
+                )
+                feedback = result.scalar_one_or_none()
+                return feedback.to_dict() if feedback else None
+
+        except SQLAlchemyError as e:
+            logger.error(f"Failed to query feedback: {e}")
+            return None
+
+    async def get_feedbacks_by_task(self, task_id: str) -> list[dict[str, Any]]:
+        """Fetch all feedback for a given task ID."""
+        if self.db is None:
+            return self._get_feedbacks_by_task_from_memory(task_id)
+
+        try:
+            session_factory = await self.db.get_session_factory()
+            async with session_factory() as session:
+                result = await session.execute(
+                    select(FeedbackBase)
+                    .where(FeedbackBase.task_id == task_id)
+                    .order_by(FeedbackBase.created_at.desc())
+                )
+                feedbacks = result.scalars().all()
+                return [f.to_dict() for f in feedbacks]
+
+        except SQLAlchemyError as e:
+            logger.error(f"Failed to query feedback for task: {e}")
+            return []
+
+    async def get_feedbacks_by_workflow(self, workflow_name: str) -> list[dict[str, Any]]:
+        """Fetch all feedback for a given workflow name."""
+        if self.db is None:
+            return self._get_feedbacks_by_workflow_from_memory(workflow_name)
+
+        try:
+            session_factory = await self.db.get_session_factory()
+            async with session_factory() as session:
+                result = await session.execute(
+                    select(FeedbackBase)
+                    .where(FeedbackBase.workflow_name == workflow_name)
+                    .order_by(FeedbackBase.created_at.desc())
+                )
+                feedbacks = result.scalars().all()
+                return [f.to_dict() for f in feedbacks]
+
+        except SQLAlchemyError as e:
+            logger.error(f"Failed to query feedback for workflow: {e}")
+            return []
+
+    async def get_all_feedbacks(self) -> list[dict[str, Any]]:
+        """Fetch all feedback records."""
+        if self.db is None:
+            return self._get_all_feedbacks_from_memory()
+
+        try:
+            session_factory = await self.db.get_session_factory()
+            async with session_factory() as session:
+                result = await session.execute(
+                    select(FeedbackBase).order_by(FeedbackBase.created_at.desc())
+                )
+                feedbacks = result.scalars().all()
+                return [f.to_dict() for f in feedbacks]
+
+        except SQLAlchemyError as e:
+            logger.error(f"Failed to query all feedback: {e}")
+            return []
+
+    async def delete_feedback(self, feedback_id: str) -> bool:
+        """Delete a feedback record."""
+        if self.db is None:
+            return self._delete_feedback_from_memory(feedback_id)
+
+        try:
+            session_factory = await self.db.get_session_factory()
+            async with session_factory() as session:
+                result = await session.execute(
+                    select(FeedbackBase).where(FeedbackBase.feedback_id == feedback_id)
+                )
+                feedback = result.scalar_one_or_none()
+
+                if not feedback:
+                    return False
+
+                await session.delete(feedback)
+                await session.commit()
+
+            logger.info(f"Deleted feedback: feedback_id={feedback_id}")
+            return True
+
+        except SQLAlchemyError as e:
+            logger.error(f"Failed to delete feedback: {e}")
+            return False
+
+    # ========== In-memory fallback methods (used when no database is configured) ==========
+
+    def _create_feedback_in_memory(
+        self,
+        task_id: str,
+        workflow_name: str,
+        rating: Optional[int],
+        comment: Optional[str],
+        metadata: Optional[dict[str, Any]],
+    ) -> dict[str, Any]:
+        """In-memory variant of create_feedback."""
+        feedback_id = str(uuid4())
+        created_at = datetime.now()
+
+        feedback_data = {
+            "feedback_id": feedback_id,
+            "task_id": task_id,
+            "workflow_name": workflow_name,
+            "rating": rating,
+            "comment": comment,
+            "metadata": metadata or {},
+            "created_at": created_at,
+        }
+
+        self._storage[feedback_id] = feedback_data
+
+        if task_id not in self._task_index:
+            self._task_index[task_id] = []
+        self._task_index[task_id].append(feedback_id)
+
+        if workflow_name not in self._workflow_index:
+            self._workflow_index[workflow_name] = []
+        self._workflow_index[workflow_name].append(feedback_id)
+
+        logger.info(f"[in-memory] Created feedback: feedback_id={feedback_id}")
+        return feedback_data
+
+    def _get_feedback_from_memory(self, feedback_id: str) -> Optional[dict[str, Any]]:
+        return self._storage.get(feedback_id)
+
+    def _get_feedbacks_by_task_from_memory(self, task_id: str) -> list[dict[str, Any]]:
+        feedback_ids = self._task_index.get(task_id, [])
+        return [self._storage[fid] for fid in feedback_ids if fid in self._storage]
+
+    def _get_feedbacks_by_workflow_from_memory(self, workflow_name: str) -> list[dict[str, Any]]:
+        feedback_ids = self._workflow_index.get(workflow_name, [])
+        return [self._storage[fid] for fid in feedback_ids if fid in self._storage]
+
+    def _get_all_feedbacks_from_memory(self) -> list[dict[str, Any]]:
+        return list(self._storage.values())
+
+    def _delete_feedback_from_memory(self, feedback_id: str) -> bool:
+        if feedback_id not in self._storage:
+            return False
+
+        feedback = self._storage[feedback_id]
+        task_id = feedback["task_id"]
+        workflow_name = feedback["workflow_name"]
+
+        if task_id in self._task_index:
+            self._task_index[task_id].remove(feedback_id)
+        if workflow_name in self._workflow_index:
+            self._workflow_index[workflow_name].remove(feedback_id)
+
+        del self._storage[feedback_id]
+        logger.info(f"[in-memory] Deleted feedback: feedback_id={feedback_id}")
+        return True
+
+
+# Global singleton instance
+feedback_storage = FeedbackStorage()
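
For orientation, a small usage sketch of the storage API above. Without a DatabaseManager, the module-level feedback_storage singleton falls back to the in-memory dictionaries, so this runs standalone (data is lost on restart); the argument values are illustrative:

    import asyncio
    from service_forge.storage import feedback_storage

    async def main():
        created = await feedback_storage.create_feedback(
            task_id="task-123",
            workflow_name="demo_workflow",
            rating=5,
            comment="Looks good",
        )
        # Records come back as plain dicts keyed feedback_id/task_id/etc.
        assert await feedback_storage.get_feedback(created["feedback_id"]) is not None
        assert len(await feedback_storage.get_feedbacks_by_task("task-123")) == 1
        assert await feedback_storage.delete_feedback(created["feedback_id"])

    asyncio.run(main())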

service_forge/utils/workflow_clone.py
@@ -57,8 +57,8 @@ def workflow_clone(self: Workflow, task_id: uuid.UUID, trigger_node: Trigger) ->
     }

     workflow = Workflow(
-        name=self.name,
-        description=self.description,
+        id=self.id,
+        config=self.config,
         nodes=[node_map[node] for node in self.nodes],
         input_ports=[port_map[port] for port in self.input_ports],
         output_ports=[port_map[port] for port in self.output_ports],
@@ -69,6 +69,7 @@ def workflow_clone(self: Workflow, task_id: uuid.UUID, trigger_node: Trigger) ->
         callbacks=self.callbacks,
         task_id=task_id,
         real_trigger_node=trigger_node,
+        global_context=self.global_context,
     )

     for node in workflow.nodes:

service_forge/workflow/node.py
@@ -10,6 +10,7 @@ from .context import Context
 from ..utils.register import Register
 from ..db.database import DatabaseManager, PostgresDatabase, MongoDatabase, RedisDatabase
 from ..utils.workflow_clone import node_clone
+from .workflow_callback import CallbackEvent

 if TYPE_CHECKING:
     from .workflow import Workflow
@@ -62,6 +63,10 @@ class Node(ABC):
     def database_manager(self) -> DatabaseManager:
         return self.workflow.database_manager

+    @property
+    def global_context(self) -> Context:
+        return self.workflow.global_context
+
     def backup(self) -> None:
         # do NOT use deepcopy here
         # self.bak_context = deepcopy(self.context)
@@ -181,4 +186,7 @@ class Node(ABC):
     def _clone(self, context: Context) -> Node:
         return node_clone(self, context)

+    async def stream_output(self, data: Any) -> None:
+        await self.workflow.call_callbacks(CallbackEvent.ON_NODE_STREAM_OUTPUT, node=self, output=data)
+
 node_register = Register[Node]()

service_forge/workflow/nodes/llm/query_llm_node.py
@@ -36,6 +36,6 @@ class QueryLLMNode(Node):
            prompt = f.read()

        print(f"prompt: {prompt} temperature: {temperature}")
-       response = chat_stream(prompt, system_prompt, Model.DEEPSEEK_V3_250324, temperature)
+       response = chat_stream(prompt, system_prompt, Model.GEMINI, temperature)
        for chunk in response:
            yield chunk

service_forge/workflow/trigger.py
@@ -19,6 +19,10 @@ class Trigger(Node, ABC):
     async def _run(self) -> AsyncIterator[bool]:
         ...

+    @abstractmethod
+    async def _stop(self) -> AsyncIterator[bool]:
+        ...
+
     def trigger(self, task_id: uuid.UUID) -> bool:
         self.prepare_output_edges(self.get_output_port_by_name('trigger'), True)
         return task_id
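
Every Trigger subclass must now implement _stop alongside _run. A minimal custom trigger under this contract might look like the following sketch (class name illustrative). Note that despite the AsyncIterator[bool] annotation, the bundled triggers below implement _stop as a plain no-op coroutine, which is what the new Workflow.stop() awaits (see the workflow.py hunk further down):

    import uuid
    from typing import AsyncIterator

    from service_forge.workflow.trigger import Trigger

    class FireOnceTrigger(Trigger):
        """Illustrative trigger that fires a single task, then idles."""

        async def _run(self) -> AsyncIterator[bool]:
            # Emit one task id, mirroring OnceTrigger below.
            yield self.trigger(uuid.uuid4())

        async def _stop(self) -> AsyncIterator[bool]:
            # No resources to release; a bare coroutine keeps
            # `await trigger._stop()` in Workflow.stop() working.
            pass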

service_forge/workflow/triggers/a2a_api_trigger.py
@@ -253,3 +253,5 @@ class A2AAPITrigger(Trigger):
                logger.error(f"Error in A2AAPITrigger._run: {e}")
                continue

+    async def _stop(self) -> AsyncIterator[bool]:
+        pass

service_forge/workflow/triggers/fast_api_trigger.py
@@ -33,6 +33,9 @@ class FastAPITrigger(Trigger):
         super().__init__(name)
         self.events = {}
         self.is_setup_route = False
+        self.app = None
+        self.route_path = None
+        self.route_method = None

     @staticmethod
     def serialize_result(result: Any):
@@ -141,6 +144,11 @@ class FastAPITrigger(Trigger):
         async def handler(request: Request):
             return await self.handle_request(request, data_type, extractor, is_stream)

+        # Save route information for cleanup
+        self.app = app
+        self.route_path = path
+        self.route_method = method.upper()
+
         if method == "GET":
             app.get(path)(handler)
         elif method == "POST":
@@ -167,3 +175,27 @@ class FastAPITrigger(Trigger):
                logger.error(f"Error in FastAPITrigger._run: {e}")
                continue

+    async def _stop(self) -> AsyncIterator[bool]:
+        if self.is_setup_route:
+            # Remove the route from the app
+            if self.app and self.route_path and self.route_method:
+                # Find the matching routes
+                routes_to_remove = []
+                for route in self.app.routes:
+                    if hasattr(route, "path") and hasattr(route, "methods"):
+                        if route.path == self.route_path and self.route_method in route.methods:
+                            routes_to_remove.append(route)
+
+                # Remove the found routes
+                for route in routes_to_remove:
+                    try:
+                        self.app.routes.remove(route)
+                        logger.info(f"Removed route {self.route_method} {self.route_path} from FastAPI app")
+                    except ValueError:
+                        logger.warning(f"Route {self.route_method} {self.route_path} not found in app.routes")
+
+            # Reset route information
+            self.app = None
+            self.route_path = None
+            self.route_method = None
+            self.is_setup_route = False
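
The removal logic above relies on FastAPI (Starlette) keeping registered endpoints in the mutable app.routes list, so deleting the matching route object unregisters the endpoint. A standalone sketch of that technique, outside service-forge (path and handler are illustrative):

    from fastapi import FastAPI

    app = FastAPI()

    @app.get("/ping")
    async def ping():
        return {"ok": True}

    # Route matching walks app.routes on every request, so removing the
    # matching route object is enough to unregister the endpoint.
    for route in list(app.routes):
        if hasattr(route, "path") and hasattr(route, "methods"):
            if route.path == "/ping" and "GET" in (route.methods or set()):
                app.routes.remove(route)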

service_forge/workflow/triggers/kafka_api_trigger.py
@@ -42,3 +42,6 @@ class KafkaAPITrigger(Trigger):
            trigger = await self.trigger_queue.get()
            self.prepare_output_edges(self.get_output_port_by_name('data'), trigger['data'])
            yield self.trigger(trigger['id'])
+
+    async def _stop(self) -> AsyncIterator[bool]:
+        pass

service_forge/workflow/triggers/once_trigger.py
@@ -17,4 +17,7 @@ class OnceTrigger(Trigger):
         super().__init__(name)

     async def _run(self) -> AsyncIterator[bool]:
-        yield self.trigger(uuid.uuid4())
+        yield self.trigger(uuid.uuid4())
+
+    async def _stop(self) -> AsyncIterator[bool]:
+        pass

service_forge/workflow/triggers/period_trigger.py
@@ -23,4 +23,7 @@ class PeriodTrigger(Trigger):
     async def _run(self, times: int, period: float) -> AsyncIterator[bool]:
         for _ in range(times):
             await asyncio.sleep(period)
-            yield uuid.uuid4()
+            yield uuid.uuid4()
+
+    async def _stop(self) -> AsyncIterator[bool]:
+        pass

service_forge/workflow/triggers/websocket_api_trigger.py
@@ -120,12 +120,16 @@ class WebSocketAPITrigger(Trigger):
        self.result_queues[task_id] = asyncio.Queue()
        self.stream_queues[task_id] = asyncio.Queue()

-       try:
-           converted_data = data_type(**message_data)
-       except Exception as e:
-           error_msg = {"error": f"Failed to convert data: {str(e)}"}
-           await websocket.send_text(json.dumps(error_msg))
-           return
+       if data_type is Any:
+           converted_data = message_data
+       else:
+           try:
+               # TODO: message_data is Message, need to convert to dict
+               converted_data = data_type(**message_data)
+           except Exception as e:
+               error_msg = {"error": f"Failed to convert data: {str(e)}"}
+               await websocket.send_text(json.dumps(error_msg))
+               return

        # Always start background task to handle stream output
        asyncio.create_task(self.handle_stream_output(websocket, task_id))
@@ -143,16 +147,14 @@ class WebSocketAPITrigger(Trigger):

        try:
            while True:
-               # Receive message from client
-               data = await websocket.receive_text()
+               data = await websocket.receive()
                try:
-                   message = json.loads(data)
-
+                   # message = json.loads(data)
                    # Handle the message and trigger workflow
                    await self.handle_websocket_message(
                        websocket,
                        data_type,
-                       message
+                       data
                    )
                except json.JSONDecodeError:
                    error_msg = {"error": "Invalid JSON format"}
@@ -182,3 +184,5 @@ class WebSocketAPITrigger(Trigger):
                logger.error(f"Error in WebSocketAPITrigger._run: {e}")
                continue

+    async def _stop(self) -> AsyncIterator[bool]:
+        pass

service_forge/workflow/workflow.py
@@ -12,12 +12,14 @@ from .edge import Edge
 from ..db.database import DatabaseManager
 from ..utils.workflow_clone import workflow_clone
 from .workflow_callback import WorkflowCallback, BuiltinWorkflowCallback, CallbackEvent
+from .workflow_config import WorkflowConfig
+from .context import Context

 class Workflow:
     def __init__(
         self,
-        name: str,
-        description: str,
+        id: uuid.UUID,
+        config: WorkflowConfig,
         nodes: list[Node],
         input_ports: list[Port],
         output_ports: list[Port],
@@ -30,9 +32,12 @@ class Workflow:
         # for run
         task_id: uuid.UUID = None,
         real_trigger_node: Trigger = None,
+
+        # global variables
+        global_context: Context = None,
     ) -> None:
-        self.name = name
-        self.description = description
+        self.id = id
+        self.config = config
         self.nodes = nodes
         self.ready_nodes: list[Node] = []
         self.input_ports = input_ports
@@ -47,7 +52,20 @@ class Workflow:
         self.callbacks = callbacks
         self.task_id = task_id
         self.real_trigger_node = real_trigger_node
+        self.global_context = global_context
         self._validate()
+
+    @property
+    def name(self) -> str:
+        return self.config.name
+
+    @property
+    def version(self) -> str:
+        return self.config.version
+
+    @property
+    def description(self) -> str:
+        return self.config.description

     def register_callback(self, callback: WorkflowCallback) -> None:
         self.callbacks.append(callback)
@@ -186,6 +204,10 @@ class Workflow:
         if tasks:
             await asyncio.gather(*tasks)

+    async def stop(self):
+        trigger = self.get_trigger_node()
+        await trigger._stop()
+
     def trigger(self, trigger_name: str, **kwargs) -> uuid.UUID:
         trigger = self.get_trigger_node()
         task_id = uuid.uuid4()