service-forge 0.1.11__py3-none-any.whl → 0.1.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of service-forge might be problematic. Click here for more details.
- service_forge/api/http_api.py +4 -0
- service_forge/api/routers/feedback/feedback_router.py +148 -0
- service_forge/api/routers/service/service_router.py +22 -32
- service_forge/current_service.py +14 -0
- service_forge/db/database.py +29 -32
- service_forge/db/migrations/feedback_migration.py +154 -0
- service_forge/db/models/__init__.py +0 -0
- service_forge/db/models/feedback.py +33 -0
- service_forge/llm/__init__.py +5 -0
- service_forge/model/feedback.py +30 -0
- service_forge/service.py +118 -126
- service_forge/service_config.py +42 -156
- service_forge/sft/config/injector.py +33 -23
- service_forge/sft/config/sft_config.py +55 -8
- service_forge/storage/__init__.py +5 -0
- service_forge/storage/feedback_storage.py +245 -0
- service_forge/utils/workflow_clone.py +3 -2
- service_forge/workflow/node.py +8 -0
- service_forge/workflow/nodes/llm/query_llm_node.py +1 -1
- service_forge/workflow/trigger.py +4 -0
- service_forge/workflow/triggers/a2a_api_trigger.py +2 -0
- service_forge/workflow/triggers/fast_api_trigger.py +32 -0
- service_forge/workflow/triggers/kafka_api_trigger.py +3 -0
- service_forge/workflow/triggers/once_trigger.py +4 -1
- service_forge/workflow/triggers/period_trigger.py +4 -1
- service_forge/workflow/triggers/websocket_api_trigger.py +15 -11
- service_forge/workflow/workflow.py +26 -4
- service_forge/workflow/workflow_config.py +66 -0
- service_forge/workflow/workflow_factory.py +86 -85
- service_forge/workflow/workflow_group.py +33 -9
- {service_forge-0.1.11.dist-info → service_forge-0.1.21.dist-info}/METADATA +1 -1
- {service_forge-0.1.11.dist-info → service_forge-0.1.21.dist-info}/RECORD +34 -26
- service_forge/api/routers/service/__init__.py +0 -4
- {service_forge-0.1.11.dist-info → service_forge-0.1.21.dist-info}/WHEEL +0 -0
- {service_forge-0.1.11.dist-info → service_forge-0.1.21.dist-info}/entry_points.txt +0 -0
|
@@ -21,6 +21,8 @@ class SftConfig:
|
|
|
21
21
|
"inject_postgres_port": "Postgres port for services",
|
|
22
22
|
"inject_postgres_user": "Postgres user for services",
|
|
23
23
|
"inject_postgres_password": "Postgres password for services",
|
|
24
|
+
"inject_feedback_api_url": "Feedback API URL for services",
|
|
25
|
+
"inject_feedback_api_timeout": "Feedback API timeout for services",
|
|
24
26
|
"deepseek_api_key": "DeepSeek API key",
|
|
25
27
|
"deepseek_base_url": "DeepSeek base URL",
|
|
26
28
|
}
|
|
@@ -52,6 +54,9 @@ class SftConfig:
|
|
|
52
54
|
inject_redis_port: int = 6379,
|
|
53
55
|
inject_redis_password: str = "rDdM2Y2gX9",
|
|
54
56
|
|
|
57
|
+
inject_feedback_api_url: str = "http://vps.shiweinan.com:37919/api/v1/feedback",
|
|
58
|
+
inject_feedback_api_timeout: int = 5,
|
|
59
|
+
|
|
55
60
|
deepseek_api_key: str = "82c9df22-f6ed-411e-90d7-c5255376b7ca",
|
|
56
61
|
deepseek_base_url: str = "https://ark.cn-beijing.volces.com/api/v3",
|
|
57
62
|
):
|
|
@@ -80,6 +85,9 @@ class SftConfig:
|
|
|
80
85
|
self.inject_redis_port = inject_redis_port
|
|
81
86
|
self.inject_redis_password = inject_redis_password
|
|
82
87
|
|
|
88
|
+
self.inject_feedback_api_url = inject_feedback_api_url
|
|
89
|
+
self.inject_feedback_api_timeout = inject_feedback_api_timeout
|
|
90
|
+
|
|
83
91
|
self.deepseek_api_key = deepseek_api_key
|
|
84
92
|
self.deepseek_base_url = deepseek_base_url
|
|
85
93
|
|
|
@@ -91,10 +99,23 @@ class SftConfig:
|
|
|
91
99
|
def upload_timeout(self) -> int:
|
|
92
100
|
return 300 # 5 minutes default timeout
|
|
93
101
|
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
sig = inspect.signature(
|
|
97
|
-
|
|
102
|
+
def get_config_keys(self) -> list[str]:
|
|
103
|
+
# Get initial configuration parameters from __init__ method
|
|
104
|
+
sig = inspect.signature(self.__class__.__init__)
|
|
105
|
+
init_keys = [param for param in sig.parameters.keys() if param != 'self']
|
|
106
|
+
|
|
107
|
+
# Get all instance attributes (including dynamically added configurations)
|
|
108
|
+
instance_keys = []
|
|
109
|
+
for attr_name in dir(self):
|
|
110
|
+
# Exclude special methods, private attributes, class attributes, and methods
|
|
111
|
+
if (not attr_name.startswith('_') and
|
|
112
|
+
not callable(getattr(self, attr_name)) and
|
|
113
|
+
attr_name not in ['CONFIG_ROOT', 'CONFIG_DESCRIPTIONS']):
|
|
114
|
+
instance_keys.append(attr_name)
|
|
115
|
+
|
|
116
|
+
# Merge and deduplicate, maintaining order (initial configs first, then dynamically added)
|
|
117
|
+
all_keys = list(dict.fromkeys(init_keys + instance_keys))
|
|
118
|
+
return all_keys
|
|
98
119
|
|
|
99
120
|
@property
|
|
100
121
|
def config_file_path(self) -> Path:
|
|
@@ -105,13 +126,30 @@ class SftConfig:
|
|
|
105
126
|
|
|
106
127
|
def to_dict(self) -> dict:
|
|
107
128
|
config_keys = self.get_config_keys()
|
|
108
|
-
|
|
129
|
+
result = {}
|
|
130
|
+
for key in config_keys:
|
|
131
|
+
value = getattr(self, key)
|
|
132
|
+
# Convert Path objects to strings for JSON serialization
|
|
133
|
+
if isinstance(value, Path):
|
|
134
|
+
value = str(value)
|
|
135
|
+
result[key] = value
|
|
136
|
+
return result
|
|
109
137
|
|
|
110
138
|
def from_dict(self, data: dict) -> None:
|
|
111
|
-
|
|
112
|
-
|
|
139
|
+
# Get initial configuration parameters from __init__ method
|
|
140
|
+
sig = inspect.signature(self.__class__.__init__)
|
|
141
|
+
init_keys = [param for param in sig.parameters.keys() if param != 'self']
|
|
142
|
+
|
|
143
|
+
# First, set all initial configuration parameters
|
|
144
|
+
for key in init_keys:
|
|
113
145
|
if key in data:
|
|
114
146
|
setattr(self, key, data[key])
|
|
147
|
+
|
|
148
|
+
# Then, handle any additional keys that might be dynamically added configurations
|
|
149
|
+
for key, value in data.items():
|
|
150
|
+
if key not in init_keys and key not in ['CONFIG_ROOT', 'CONFIG_DESCRIPTIONS']:
|
|
151
|
+
# This might be a dynamically added configuration
|
|
152
|
+
setattr(self, key, value)
|
|
115
153
|
|
|
116
154
|
def save(self) -> None:
|
|
117
155
|
self.ensure_config_dir()
|
|
@@ -121,13 +159,22 @@ class SftConfig:
|
|
|
121
159
|
def get(self, key: str, default: Optional[str] = None) -> Optional[str]:
|
|
122
160
|
return getattr(self, key, default)
|
|
123
161
|
|
|
124
|
-
def set(self, key: str, value: str) -> None:
|
|
162
|
+
def set(self, key: str, value: str, description: Optional[str] = None) -> None:
|
|
125
163
|
if key in ["config_root"]:
|
|
126
164
|
raise ValueError(f"{key} is read-only")
|
|
127
165
|
if hasattr(self, key):
|
|
128
166
|
setattr(self, key, value)
|
|
167
|
+
if description:
|
|
168
|
+
self.CONFIG_DESCRIPTIONS[key] = description
|
|
129
169
|
else:
|
|
130
170
|
raise ValueError(f"Unknown config key: {key}")
|
|
171
|
+
|
|
172
|
+
def add(self, key: str, value: str, description: Optional[str] = None) -> None:
|
|
173
|
+
if hasattr(self, key):
|
|
174
|
+
raise ValueError(f"{key} already exists")
|
|
175
|
+
setattr(self, key, value)
|
|
176
|
+
if description:
|
|
177
|
+
self.CONFIG_DESCRIPTIONS[key] = description
|
|
131
178
|
|
|
132
179
|
def update(self, updates: dict) -> None:
|
|
133
180
|
for key, value in updates.items():
|
|
@@ -0,0 +1,245 @@
|
|
|
1
|
+
from typing import Optional, Any
|
|
2
|
+
from datetime import datetime
|
|
3
|
+
from uuid import uuid4
|
|
4
|
+
from loguru import logger
|
|
5
|
+
from sqlalchemy import select, and_
|
|
6
|
+
from sqlalchemy.exc import SQLAlchemyError
|
|
7
|
+
|
|
8
|
+
from ..db.models.feedback import FeedbackBase
|
|
9
|
+
from ..db.database import DatabaseManager
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
class FeedbackStorage:
|
|
13
|
+
"""反馈存储管理器 - 使用 PostgreSQL 数据库存储"""
|
|
14
|
+
|
|
15
|
+
def __init__(self, database_manager: DatabaseManager = None):
|
|
16
|
+
self.database_manager = database_manager
|
|
17
|
+
self._db = None
|
|
18
|
+
# 内存存储后备(用于数据库未配置时)
|
|
19
|
+
self._storage: dict[str, dict[str, Any]] = {}
|
|
20
|
+
self._task_index: dict[str, list[str]] = {}
|
|
21
|
+
self._workflow_index: dict[str, list[str]] = {}
|
|
22
|
+
|
|
23
|
+
@property
|
|
24
|
+
def db(self):
|
|
25
|
+
"""延迟获取数据库连接"""
|
|
26
|
+
if self._db is None and self.database_manager is not None:
|
|
27
|
+
self._db = self.database_manager.get_default_postgres_database()
|
|
28
|
+
return self._db
|
|
29
|
+
|
|
30
|
+
async def create_feedback(
|
|
31
|
+
self,
|
|
32
|
+
task_id: str,
|
|
33
|
+
workflow_name: str,
|
|
34
|
+
rating: Optional[int] = None,
|
|
35
|
+
comment: Optional[str] = None,
|
|
36
|
+
metadata: Optional[dict[str, Any]] = None,
|
|
37
|
+
) -> dict[str, Any]:
|
|
38
|
+
"""创建新的反馈记录"""
|
|
39
|
+
if self.db is None:
|
|
40
|
+
logger.warning("数据库未初始化,使用内存存储(数据将在重启后丢失)")
|
|
41
|
+
return self._create_feedback_in_memory(task_id, workflow_name, rating, comment, metadata)
|
|
42
|
+
|
|
43
|
+
try:
|
|
44
|
+
feedback_id = uuid4()
|
|
45
|
+
created_at = datetime.now()
|
|
46
|
+
|
|
47
|
+
feedback = FeedbackBase(
|
|
48
|
+
feedback_id=feedback_id,
|
|
49
|
+
task_id=task_id,
|
|
50
|
+
workflow_name=workflow_name,
|
|
51
|
+
rating=rating,
|
|
52
|
+
comment=comment,
|
|
53
|
+
extra_metadata=metadata or {},
|
|
54
|
+
created_at=created_at,
|
|
55
|
+
)
|
|
56
|
+
|
|
57
|
+
session_factory = await self.db.get_session_factory()
|
|
58
|
+
async with session_factory() as session:
|
|
59
|
+
session.add(feedback)
|
|
60
|
+
await session.commit()
|
|
61
|
+
await session.refresh(feedback)
|
|
62
|
+
|
|
63
|
+
logger.info(f"创建反馈: feedback_id={feedback_id}, task_id={task_id}, workflow={workflow_name}")
|
|
64
|
+
return feedback.to_dict()
|
|
65
|
+
|
|
66
|
+
except SQLAlchemyError as e:
|
|
67
|
+
logger.error(f"数据库创建反馈失败: {e}")
|
|
68
|
+
raise
|
|
69
|
+
except Exception as e:
|
|
70
|
+
logger.error(f"创建反馈失败: {e}")
|
|
71
|
+
raise
|
|
72
|
+
|
|
73
|
+
async def get_feedback(self, feedback_id: str) -> Optional[dict[str, Any]]:
|
|
74
|
+
"""根据反馈ID获取反馈"""
|
|
75
|
+
if self.db is None:
|
|
76
|
+
return self._get_feedback_from_memory(feedback_id)
|
|
77
|
+
|
|
78
|
+
try:
|
|
79
|
+
session_factory = await self.db.get_session_factory()
|
|
80
|
+
async with session_factory() as session:
|
|
81
|
+
result = await session.execute(
|
|
82
|
+
select(FeedbackBase).where(FeedbackBase.feedback_id == feedback_id)
|
|
83
|
+
)
|
|
84
|
+
feedback = result.scalar_one_or_none()
|
|
85
|
+
return feedback.to_dict() if feedback else None
|
|
86
|
+
|
|
87
|
+
except SQLAlchemyError as e:
|
|
88
|
+
logger.error(f"查询反馈失败: {e}")
|
|
89
|
+
return None
|
|
90
|
+
|
|
91
|
+
async def get_feedbacks_by_task(self, task_id: str) -> list[dict[str, Any]]:
|
|
92
|
+
"""根据任务ID获取所有反馈"""
|
|
93
|
+
if self.db is None:
|
|
94
|
+
return self._get_feedbacks_by_task_from_memory(task_id)
|
|
95
|
+
|
|
96
|
+
try:
|
|
97
|
+
session_factory = await self.db.get_session_factory()
|
|
98
|
+
async with session_factory() as session:
|
|
99
|
+
result = await session.execute(
|
|
100
|
+
select(FeedbackBase)
|
|
101
|
+
.where(FeedbackBase.task_id == task_id)
|
|
102
|
+
.order_by(FeedbackBase.created_at.desc())
|
|
103
|
+
)
|
|
104
|
+
feedbacks = result.scalars().all()
|
|
105
|
+
return [f.to_dict() for f in feedbacks]
|
|
106
|
+
|
|
107
|
+
except SQLAlchemyError as e:
|
|
108
|
+
logger.error(f"查询任务反馈失败: {e}")
|
|
109
|
+
return []
|
|
110
|
+
|
|
111
|
+
async def get_feedbacks_by_workflow(self, workflow_name: str) -> list[dict[str, Any]]:
|
|
112
|
+
"""根据工作流名称获取所有反馈"""
|
|
113
|
+
if self.db is None:
|
|
114
|
+
return self._get_feedbacks_by_workflow_from_memory(workflow_name)
|
|
115
|
+
|
|
116
|
+
try:
|
|
117
|
+
session_factory = await self.db.get_session_factory()
|
|
118
|
+
async with session_factory() as session:
|
|
119
|
+
result = await session.execute(
|
|
120
|
+
select(FeedbackBase)
|
|
121
|
+
.where(FeedbackBase.workflow_name == workflow_name)
|
|
122
|
+
.order_by(FeedbackBase.created_at.desc())
|
|
123
|
+
)
|
|
124
|
+
feedbacks = result.scalars().all()
|
|
125
|
+
return [f.to_dict() for f in feedbacks]
|
|
126
|
+
|
|
127
|
+
except SQLAlchemyError as e:
|
|
128
|
+
logger.error(f"查询工作流反馈失败: {e}")
|
|
129
|
+
return []
|
|
130
|
+
|
|
131
|
+
async def get_all_feedbacks(self) -> list[dict[str, Any]]:
|
|
132
|
+
"""获取所有反馈"""
|
|
133
|
+
if self.db is None:
|
|
134
|
+
return self._get_all_feedbacks_from_memory()
|
|
135
|
+
|
|
136
|
+
try:
|
|
137
|
+
session_factory = await self.db.get_session_factory()
|
|
138
|
+
async with session_factory() as session:
|
|
139
|
+
result = await session.execute(
|
|
140
|
+
select(FeedbackBase).order_by(FeedbackBase.created_at.desc())
|
|
141
|
+
)
|
|
142
|
+
feedbacks = result.scalars().all()
|
|
143
|
+
return [f.to_dict() for f in feedbacks]
|
|
144
|
+
|
|
145
|
+
except SQLAlchemyError as e:
|
|
146
|
+
logger.error(f"查询所有反馈失败: {e}")
|
|
147
|
+
return []
|
|
148
|
+
|
|
149
|
+
async def delete_feedback(self, feedback_id: str) -> bool:
|
|
150
|
+
"""删除反馈"""
|
|
151
|
+
if self.db is None:
|
|
152
|
+
return self._delete_feedback_from_memory(feedback_id)
|
|
153
|
+
|
|
154
|
+
try:
|
|
155
|
+
session_factory = await self.db.get_session_factory()
|
|
156
|
+
async with session_factory() as session:
|
|
157
|
+
result = await session.execute(
|
|
158
|
+
select(FeedbackBase).where(FeedbackBase.feedback_id == feedback_id)
|
|
159
|
+
)
|
|
160
|
+
feedback = result.scalar_one_or_none()
|
|
161
|
+
|
|
162
|
+
if not feedback:
|
|
163
|
+
return False
|
|
164
|
+
|
|
165
|
+
await session.delete(feedback)
|
|
166
|
+
await session.commit()
|
|
167
|
+
|
|
168
|
+
logger.info(f"删除反馈: feedback_id={feedback_id}")
|
|
169
|
+
return True
|
|
170
|
+
|
|
171
|
+
except SQLAlchemyError as e:
|
|
172
|
+
logger.error(f"删除反馈失败: {e}")
|
|
173
|
+
return False
|
|
174
|
+
|
|
175
|
+
# ========== 内存存储后备方法(用于数据库未配置时) ==========
|
|
176
|
+
|
|
177
|
+
def _create_feedback_in_memory(
|
|
178
|
+
self,
|
|
179
|
+
task_id: str,
|
|
180
|
+
workflow_name: str,
|
|
181
|
+
rating: Optional[int],
|
|
182
|
+
comment: Optional[str],
|
|
183
|
+
metadata: Optional[dict[str, Any]],
|
|
184
|
+
) -> dict[str, Any]:
|
|
185
|
+
"""内存存储版本"""
|
|
186
|
+
feedback_id = str(uuid4())
|
|
187
|
+
created_at = datetime.now()
|
|
188
|
+
|
|
189
|
+
feedback_data = {
|
|
190
|
+
"feedback_id": feedback_id,
|
|
191
|
+
"task_id": task_id,
|
|
192
|
+
"workflow_name": workflow_name,
|
|
193
|
+
"rating": rating,
|
|
194
|
+
"comment": comment,
|
|
195
|
+
"metadata": metadata or {},
|
|
196
|
+
"created_at": created_at,
|
|
197
|
+
}
|
|
198
|
+
|
|
199
|
+
self._storage[feedback_id] = feedback_data
|
|
200
|
+
|
|
201
|
+
if task_id not in self._task_index:
|
|
202
|
+
self._task_index[task_id] = []
|
|
203
|
+
self._task_index[task_id].append(feedback_id)
|
|
204
|
+
|
|
205
|
+
if workflow_name not in self._workflow_index:
|
|
206
|
+
self._workflow_index[workflow_name] = []
|
|
207
|
+
self._workflow_index[workflow_name].append(feedback_id)
|
|
208
|
+
|
|
209
|
+
logger.info(f"[内存存储]创建反馈: feedback_id={feedback_id}")
|
|
210
|
+
return feedback_data
|
|
211
|
+
|
|
212
|
+
def _get_feedback_from_memory(self, feedback_id: str) -> Optional[dict[str, Any]]:
|
|
213
|
+
return self._storage.get(feedback_id)
|
|
214
|
+
|
|
215
|
+
def _get_feedbacks_by_task_from_memory(self, task_id: str) -> list[dict[str, Any]]:
|
|
216
|
+
feedback_ids = self._task_index.get(task_id, [])
|
|
217
|
+
return [self._storage[fid] for fid in feedback_ids if fid in self._storage]
|
|
218
|
+
|
|
219
|
+
def _get_feedbacks_by_workflow_from_memory(self, workflow_name: str) -> list[dict[str, Any]]:
|
|
220
|
+
feedback_ids = self._workflow_index.get(workflow_name, [])
|
|
221
|
+
return [self._storage[fid] for fid in feedback_ids if fid in self._storage]
|
|
222
|
+
|
|
223
|
+
def _get_all_feedbacks_from_memory(self) -> list[dict[str, Any]]:
|
|
224
|
+
return list(self._storage.values())
|
|
225
|
+
|
|
226
|
+
def _delete_feedback_from_memory(self, feedback_id: str) -> bool:
|
|
227
|
+
if feedback_id not in self._storage:
|
|
228
|
+
return False
|
|
229
|
+
|
|
230
|
+
feedback = self._storage[feedback_id]
|
|
231
|
+
task_id = feedback["task_id"]
|
|
232
|
+
workflow_name = feedback["workflow_name"]
|
|
233
|
+
|
|
234
|
+
if task_id in self._task_index:
|
|
235
|
+
self._task_index[task_id].remove(feedback_id)
|
|
236
|
+
if workflow_name in self._workflow_index:
|
|
237
|
+
self._workflow_index[workflow_name].remove(feedback_id)
|
|
238
|
+
|
|
239
|
+
del self._storage[feedback_id]
|
|
240
|
+
logger.info(f"[内存存储]删除反馈: feedback_id={feedback_id}")
|
|
241
|
+
return True
|
|
242
|
+
|
|
243
|
+
|
|
244
|
+
# 全局单例实例
|
|
245
|
+
feedback_storage = FeedbackStorage()
|
|
@@ -57,8 +57,8 @@ def workflow_clone(self: Workflow, task_id: uuid.UUID, trigger_node: Trigger) ->
|
|
|
57
57
|
}
|
|
58
58
|
|
|
59
59
|
workflow = Workflow(
|
|
60
|
-
|
|
61
|
-
|
|
60
|
+
id=self.id,
|
|
61
|
+
config=self.config,
|
|
62
62
|
nodes=[node_map[node] for node in self.nodes],
|
|
63
63
|
input_ports=[port_map[port] for port in self.input_ports],
|
|
64
64
|
output_ports=[port_map[port] for port in self.output_ports],
|
|
@@ -69,6 +69,7 @@ def workflow_clone(self: Workflow, task_id: uuid.UUID, trigger_node: Trigger) ->
|
|
|
69
69
|
callbacks=self.callbacks,
|
|
70
70
|
task_id=task_id,
|
|
71
71
|
real_trigger_node=trigger_node,
|
|
72
|
+
global_context=self.global_context,
|
|
72
73
|
)
|
|
73
74
|
|
|
74
75
|
for node in workflow.nodes:
|
service_forge/workflow/node.py
CHANGED
|
@@ -10,6 +10,7 @@ from .context import Context
|
|
|
10
10
|
from ..utils.register import Register
|
|
11
11
|
from ..db.database import DatabaseManager, PostgresDatabase, MongoDatabase, RedisDatabase
|
|
12
12
|
from ..utils.workflow_clone import node_clone
|
|
13
|
+
from .workflow_callback import CallbackEvent
|
|
13
14
|
|
|
14
15
|
if TYPE_CHECKING:
|
|
15
16
|
from .workflow import Workflow
|
|
@@ -62,6 +63,10 @@ class Node(ABC):
|
|
|
62
63
|
def database_manager(self) -> DatabaseManager:
|
|
63
64
|
return self.workflow.database_manager
|
|
64
65
|
|
|
66
|
+
@property
|
|
67
|
+
def global_context(self) -> Context:
|
|
68
|
+
return self.workflow.global_context
|
|
69
|
+
|
|
65
70
|
def backup(self) -> None:
|
|
66
71
|
# do NOT use deepcopy here
|
|
67
72
|
# self.bak_context = deepcopy(self.context)
|
|
@@ -181,4 +186,7 @@ class Node(ABC):
|
|
|
181
186
|
def _clone(self, context: Context) -> Node:
|
|
182
187
|
return node_clone(self, context)
|
|
183
188
|
|
|
189
|
+
async def stream_output(self, data: Any) -> None:
|
|
190
|
+
await self.workflow.call_callbacks(CallbackEvent.ON_NODE_STREAM_OUTPUT, node=self, output=data)
|
|
191
|
+
|
|
184
192
|
node_register = Register[Node]()
|
|
@@ -36,6 +36,6 @@ class QueryLLMNode(Node):
|
|
|
36
36
|
prompt = f.read()
|
|
37
37
|
|
|
38
38
|
print(f"prompt: {prompt} temperature: {temperature}")
|
|
39
|
-
response = chat_stream(prompt, system_prompt, Model.
|
|
39
|
+
response = chat_stream(prompt, system_prompt, Model.GEMINI, temperature)
|
|
40
40
|
for chunk in response:
|
|
41
41
|
yield chunk
|
|
@@ -19,6 +19,10 @@ class Trigger(Node, ABC):
|
|
|
19
19
|
async def _run(self) -> AsyncIterator[bool]:
|
|
20
20
|
...
|
|
21
21
|
|
|
22
|
+
@abstractmethod
|
|
23
|
+
async def _stop(self) -> AsyncIterator[bool]:
|
|
24
|
+
...
|
|
25
|
+
|
|
22
26
|
def trigger(self, task_id: uuid.UUID) -> bool:
|
|
23
27
|
self.prepare_output_edges(self.get_output_port_by_name('trigger'), True)
|
|
24
28
|
return task_id
|
|
@@ -33,6 +33,9 @@ class FastAPITrigger(Trigger):
|
|
|
33
33
|
super().__init__(name)
|
|
34
34
|
self.events = {}
|
|
35
35
|
self.is_setup_route = False
|
|
36
|
+
self.app = None
|
|
37
|
+
self.route_path = None
|
|
38
|
+
self.route_method = None
|
|
36
39
|
|
|
37
40
|
@staticmethod
|
|
38
41
|
def serialize_result(result: Any):
|
|
@@ -141,6 +144,11 @@ class FastAPITrigger(Trigger):
|
|
|
141
144
|
async def handler(request: Request):
|
|
142
145
|
return await self.handle_request(request, data_type, extractor, is_stream)
|
|
143
146
|
|
|
147
|
+
# Save route information for cleanup
|
|
148
|
+
self.app = app
|
|
149
|
+
self.route_path = path
|
|
150
|
+
self.route_method = method.upper()
|
|
151
|
+
|
|
144
152
|
if method == "GET":
|
|
145
153
|
app.get(path)(handler)
|
|
146
154
|
elif method == "POST":
|
|
@@ -167,3 +175,27 @@ class FastAPITrigger(Trigger):
|
|
|
167
175
|
logger.error(f"Error in FastAPITrigger._run: {e}")
|
|
168
176
|
continue
|
|
169
177
|
|
|
178
|
+
async def _stop(self) -> AsyncIterator[bool]:
|
|
179
|
+
if self.is_setup_route:
|
|
180
|
+
# Remove the route from the app
|
|
181
|
+
if self.app and self.route_path and self.route_method:
|
|
182
|
+
# Find and remove matching route
|
|
183
|
+
routes_to_remove = []
|
|
184
|
+
for route in self.app.routes:
|
|
185
|
+
if hasattr(route, "path") and hasattr(route, "methods"):
|
|
186
|
+
if route.path == self.route_path and self.route_method in route.methods:
|
|
187
|
+
routes_to_remove.append(route)
|
|
188
|
+
|
|
189
|
+
# Remove found routes
|
|
190
|
+
for route in routes_to_remove:
|
|
191
|
+
try:
|
|
192
|
+
self.app.routes.remove(route)
|
|
193
|
+
logger.info(f"Removed route {self.route_method} {self.route_path} from FastAPI app")
|
|
194
|
+
except ValueError:
|
|
195
|
+
logger.warning(f"Route {self.route_method} {self.route_path} not found in app.routes")
|
|
196
|
+
|
|
197
|
+
# Reset route information
|
|
198
|
+
self.app = None
|
|
199
|
+
self.route_path = None
|
|
200
|
+
self.route_method = None
|
|
201
|
+
self.is_setup_route = False
|
|
@@ -23,4 +23,7 @@ class PeriodTrigger(Trigger):
|
|
|
23
23
|
async def _run(self, times: int, period: float) -> AsyncIterator[bool]:
|
|
24
24
|
for _ in range(times):
|
|
25
25
|
await asyncio.sleep(period)
|
|
26
|
-
yield uuid.uuid4()
|
|
26
|
+
yield uuid.uuid4()
|
|
27
|
+
|
|
28
|
+
async def _stop(self) -> AsyncIterator[bool]:
|
|
29
|
+
pass
|
|
@@ -120,12 +120,16 @@ class WebSocketAPITrigger(Trigger):
|
|
|
120
120
|
self.result_queues[task_id] = asyncio.Queue()
|
|
121
121
|
self.stream_queues[task_id] = asyncio.Queue()
|
|
122
122
|
|
|
123
|
-
|
|
124
|
-
converted_data =
|
|
125
|
-
|
|
126
|
-
|
|
127
|
-
|
|
128
|
-
|
|
123
|
+
if data_type is Any:
|
|
124
|
+
converted_data = message_data
|
|
125
|
+
else:
|
|
126
|
+
try:
|
|
127
|
+
# TODO: message_data is Message, need to convert to dict
|
|
128
|
+
converted_data = data_type(**message_data)
|
|
129
|
+
except Exception as e:
|
|
130
|
+
error_msg = {"error": f"Failed to convert data: {str(e)}"}
|
|
131
|
+
await websocket.send_text(json.dumps(error_msg))
|
|
132
|
+
return
|
|
129
133
|
|
|
130
134
|
# Always start background task to handle stream output
|
|
131
135
|
asyncio.create_task(self.handle_stream_output(websocket, task_id))
|
|
@@ -143,16 +147,14 @@ class WebSocketAPITrigger(Trigger):
|
|
|
143
147
|
|
|
144
148
|
try:
|
|
145
149
|
while True:
|
|
146
|
-
|
|
147
|
-
data = await websocket.receive_text()
|
|
150
|
+
data = await websocket.receive()
|
|
148
151
|
try:
|
|
149
|
-
message = json.loads(data)
|
|
150
|
-
|
|
152
|
+
# message = json.loads(data)
|
|
151
153
|
# Handle the message and trigger workflow
|
|
152
154
|
await self.handle_websocket_message(
|
|
153
155
|
websocket,
|
|
154
156
|
data_type,
|
|
155
|
-
|
|
157
|
+
data
|
|
156
158
|
)
|
|
157
159
|
except json.JSONDecodeError:
|
|
158
160
|
error_msg = {"error": "Invalid JSON format"}
|
|
@@ -182,3 +184,5 @@ class WebSocketAPITrigger(Trigger):
|
|
|
182
184
|
logger.error(f"Error in WebSocketAPITrigger._run: {e}")
|
|
183
185
|
continue
|
|
184
186
|
|
|
187
|
+
async def _stop(self) -> AsyncIterator[bool]:
|
|
188
|
+
pass
|
|
@@ -12,12 +12,14 @@ from .edge import Edge
|
|
|
12
12
|
from ..db.database import DatabaseManager
|
|
13
13
|
from ..utils.workflow_clone import workflow_clone
|
|
14
14
|
from .workflow_callback import WorkflowCallback, BuiltinWorkflowCallback, CallbackEvent
|
|
15
|
+
from .workflow_config import WorkflowConfig
|
|
16
|
+
from .context import Context
|
|
15
17
|
|
|
16
18
|
class Workflow:
|
|
17
19
|
def __init__(
|
|
18
20
|
self,
|
|
19
|
-
|
|
20
|
-
|
|
21
|
+
id: uuid.UUID,
|
|
22
|
+
config: WorkflowConfig,
|
|
21
23
|
nodes: list[Node],
|
|
22
24
|
input_ports: list[Port],
|
|
23
25
|
output_ports: list[Port],
|
|
@@ -30,9 +32,12 @@ class Workflow:
|
|
|
30
32
|
# for run
|
|
31
33
|
task_id: uuid.UUID = None,
|
|
32
34
|
real_trigger_node: Trigger = None,
|
|
35
|
+
|
|
36
|
+
# global variables
|
|
37
|
+
global_context: Context = None,
|
|
33
38
|
) -> None:
|
|
34
|
-
self.
|
|
35
|
-
self.
|
|
39
|
+
self.id = id
|
|
40
|
+
self.config = config
|
|
36
41
|
self.nodes = nodes
|
|
37
42
|
self.ready_nodes: list[Node] = []
|
|
38
43
|
self.input_ports = input_ports
|
|
@@ -47,7 +52,20 @@ class Workflow:
|
|
|
47
52
|
self.callbacks = callbacks
|
|
48
53
|
self.task_id = task_id
|
|
49
54
|
self.real_trigger_node = real_trigger_node
|
|
55
|
+
self.global_context = global_context
|
|
50
56
|
self._validate()
|
|
57
|
+
|
|
58
|
+
@property
|
|
59
|
+
def name(self) -> str:
|
|
60
|
+
return self.config.name
|
|
61
|
+
|
|
62
|
+
@property
|
|
63
|
+
def version(self) -> str:
|
|
64
|
+
return self.config.version
|
|
65
|
+
|
|
66
|
+
@property
|
|
67
|
+
def description(self) -> str:
|
|
68
|
+
return self.config.description
|
|
51
69
|
|
|
52
70
|
def register_callback(self, callback: WorkflowCallback) -> None:
|
|
53
71
|
self.callbacks.append(callback)
|
|
@@ -186,6 +204,10 @@ class Workflow:
|
|
|
186
204
|
if tasks:
|
|
187
205
|
await asyncio.gather(*tasks)
|
|
188
206
|
|
|
207
|
+
async def stop(self):
|
|
208
|
+
trigger = self.get_trigger_node()
|
|
209
|
+
await trigger._stop()
|
|
210
|
+
|
|
189
211
|
def trigger(self, trigger_name: str, **kwargs) -> uuid.UUID:
|
|
190
212
|
trigger = self.get_trigger_node()
|
|
191
213
|
task_id = uuid.uuid4()
|