service_forge-0.1.18-py3-none-any.whl
This diff shows the content of publicly released package versions as they appear in their respective public registries; it is provided for informational purposes only.
Potentially problematic release.
This version of service-forge might be problematic.
- service_forge/api/deprecated_websocket_api.py +86 -0
- service_forge/api/deprecated_websocket_manager.py +425 -0
- service_forge/api/http_api.py +152 -0
- service_forge/api/http_api_doc.py +455 -0
- service_forge/api/kafka_api.py +126 -0
- service_forge/api/routers/feedback/feedback_router.py +148 -0
- service_forge/api/routers/service/service_router.py +127 -0
- service_forge/api/routers/websocket/websocket_manager.py +83 -0
- service_forge/api/routers/websocket/websocket_router.py +78 -0
- service_forge/api/task_manager.py +141 -0
- service_forge/current_service.py +14 -0
- service_forge/db/__init__.py +1 -0
- service_forge/db/database.py +237 -0
- service_forge/db/migrations/feedback_migration.py +154 -0
- service_forge/db/models/__init__.py +0 -0
- service_forge/db/models/feedback.py +33 -0
- service_forge/llm/__init__.py +67 -0
- service_forge/llm/llm.py +56 -0
- service_forge/model/__init__.py +0 -0
- service_forge/model/feedback.py +30 -0
- service_forge/model/websocket.py +13 -0
- service_forge/proto/foo_input.py +5 -0
- service_forge/service.py +280 -0
- service_forge/service_config.py +44 -0
- service_forge/sft/cli.py +91 -0
- service_forge/sft/cmd/config_command.py +67 -0
- service_forge/sft/cmd/deploy_service.py +123 -0
- service_forge/sft/cmd/list_tars.py +41 -0
- service_forge/sft/cmd/service_command.py +149 -0
- service_forge/sft/cmd/upload_service.py +36 -0
- service_forge/sft/config/injector.py +129 -0
- service_forge/sft/config/injector_default_files.py +131 -0
- service_forge/sft/config/sf_metadata.py +30 -0
- service_forge/sft/config/sft_config.py +200 -0
- service_forge/sft/file/__init__.py +0 -0
- service_forge/sft/file/ignore_pattern.py +80 -0
- service_forge/sft/file/sft_file_manager.py +107 -0
- service_forge/sft/kubernetes/kubernetes_manager.py +257 -0
- service_forge/sft/util/assert_util.py +25 -0
- service_forge/sft/util/logger.py +16 -0
- service_forge/sft/util/name_util.py +8 -0
- service_forge/sft/util/yaml_utils.py +57 -0
- service_forge/storage/__init__.py +5 -0
- service_forge/storage/feedback_storage.py +245 -0
- service_forge/utils/__init__.py +0 -0
- service_forge/utils/default_type_converter.py +12 -0
- service_forge/utils/register.py +39 -0
- service_forge/utils/type_converter.py +99 -0
- service_forge/utils/workflow_clone.py +124 -0
- service_forge/workflow/__init__.py +1 -0
- service_forge/workflow/context.py +14 -0
- service_forge/workflow/edge.py +24 -0
- service_forge/workflow/node.py +184 -0
- service_forge/workflow/nodes/__init__.py +8 -0
- service_forge/workflow/nodes/control/if_node.py +29 -0
- service_forge/workflow/nodes/control/switch_node.py +28 -0
- service_forge/workflow/nodes/input/console_input_node.py +26 -0
- service_forge/workflow/nodes/llm/query_llm_node.py +41 -0
- service_forge/workflow/nodes/nested/workflow_node.py +28 -0
- service_forge/workflow/nodes/output/kafka_output_node.py +27 -0
- service_forge/workflow/nodes/output/print_node.py +29 -0
- service_forge/workflow/nodes/test/if_console_input_node.py +33 -0
- service_forge/workflow/nodes/test/time_consuming_node.py +62 -0
- service_forge/workflow/port.py +89 -0
- service_forge/workflow/trigger.py +28 -0
- service_forge/workflow/triggers/__init__.py +6 -0
- service_forge/workflow/triggers/a2a_api_trigger.py +257 -0
- service_forge/workflow/triggers/fast_api_trigger.py +201 -0
- service_forge/workflow/triggers/kafka_api_trigger.py +47 -0
- service_forge/workflow/triggers/once_trigger.py +23 -0
- service_forge/workflow/triggers/period_trigger.py +29 -0
- service_forge/workflow/triggers/websocket_api_trigger.py +189 -0
- service_forge/workflow/workflow.py +227 -0
- service_forge/workflow/workflow_callback.py +141 -0
- service_forge/workflow/workflow_config.py +66 -0
- service_forge/workflow/workflow_event.py +15 -0
- service_forge/workflow/workflow_factory.py +246 -0
- service_forge/workflow/workflow_group.py +51 -0
- service_forge/workflow/workflow_type.py +52 -0
- service_forge-0.1.18.dist-info/METADATA +98 -0
- service_forge-0.1.18.dist-info/RECORD +83 -0
- service_forge-0.1.18.dist-info/WHEEL +4 -0
- service_forge-0.1.18.dist-info/entry_points.txt +2 -0
service_forge/storage/feedback_storage.py
@@ -0,0 +1,245 @@
+from typing import Optional, Any
+from datetime import datetime
+from uuid import uuid4
+from loguru import logger
+from sqlalchemy import select, and_
+from sqlalchemy.exc import SQLAlchemyError
+
+from ..db.models.feedback import FeedbackBase
+from ..db.database import DatabaseManager
+
+
+class FeedbackStorage:
+    """Feedback storage manager, backed by a PostgreSQL database."""
+
+    def __init__(self, database_manager: Optional[DatabaseManager] = None):
+        self.database_manager = database_manager
+        self._db = None
+        # In-memory fallback (used when no database is configured)
+        self._storage: dict[str, dict[str, Any]] = {}
+        self._task_index: dict[str, list[str]] = {}
+        self._workflow_index: dict[str, list[str]] = {}
+
+    @property
+    def db(self):
+        """Lazily acquire the database connection."""
+        if self._db is None and self.database_manager is not None:
+            self._db = self.database_manager.get_default_postgres_database()
+        return self._db
+
+    async def create_feedback(
+        self,
+        task_id: str,
+        workflow_name: str,
+        rating: Optional[int] = None,
+        comment: Optional[str] = None,
+        metadata: Optional[dict[str, Any]] = None,
+    ) -> dict[str, Any]:
+        """Create a new feedback record."""
+        if self.db is None:
+            logger.warning("Database not initialized; using in-memory storage (data will be lost on restart)")
+            return self._create_feedback_in_memory(task_id, workflow_name, rating, comment, metadata)
+
+        try:
+            feedback_id = uuid4()
+            created_at = datetime.now()
+
+            feedback = FeedbackBase(
+                feedback_id=feedback_id,
+                task_id=task_id,
+                workflow_name=workflow_name,
+                rating=rating,
+                comment=comment,
+                extra_metadata=metadata or {},
+                created_at=created_at,
+            )
+
+            session_factory = await self.db.get_session_factory()
+            async with session_factory() as session:
+                session.add(feedback)
+                await session.commit()
+                await session.refresh(feedback)
+
+            logger.info(f"Created feedback: feedback_id={feedback_id}, task_id={task_id}, workflow={workflow_name}")
+            return feedback.to_dict()
+
+        except SQLAlchemyError as e:
+            logger.error(f"Database error while creating feedback: {e}")
+            raise
+        except Exception as e:
+            logger.error(f"Failed to create feedback: {e}")
+            raise
+
+    async def get_feedback(self, feedback_id: str) -> Optional[dict[str, Any]]:
+        """Fetch a feedback record by its ID."""
+        if self.db is None:
+            return self._get_feedback_from_memory(feedback_id)
+
+        try:
+            session_factory = await self.db.get_session_factory()
+            async with session_factory() as session:
+                result = await session.execute(
+                    select(FeedbackBase).where(FeedbackBase.feedback_id == feedback_id)
+                )
+                feedback = result.scalar_one_or_none()
+                return feedback.to_dict() if feedback else None
+
+        except SQLAlchemyError as e:
+            logger.error(f"Failed to query feedback: {e}")
+            return None
+
+    async def get_feedbacks_by_task(self, task_id: str) -> list[dict[str, Any]]:
+        """Fetch all feedback for a given task ID."""
+        if self.db is None:
+            return self._get_feedbacks_by_task_from_memory(task_id)
+
+        try:
+            session_factory = await self.db.get_session_factory()
+            async with session_factory() as session:
+                result = await session.execute(
+                    select(FeedbackBase)
+                    .where(FeedbackBase.task_id == task_id)
+                    .order_by(FeedbackBase.created_at.desc())
+                )
+                feedbacks = result.scalars().all()
+                return [f.to_dict() for f in feedbacks]
+
+        except SQLAlchemyError as e:
+            logger.error(f"Failed to query feedback for task: {e}")
+            return []
+
+    async def get_feedbacks_by_workflow(self, workflow_name: str) -> list[dict[str, Any]]:
+        """Fetch all feedback for a given workflow name."""
+        if self.db is None:
+            return self._get_feedbacks_by_workflow_from_memory(workflow_name)
+
+        try:
+            session_factory = await self.db.get_session_factory()
+            async with session_factory() as session:
+                result = await session.execute(
+                    select(FeedbackBase)
+                    .where(FeedbackBase.workflow_name == workflow_name)
+                    .order_by(FeedbackBase.created_at.desc())
+                )
+                feedbacks = result.scalars().all()
+                return [f.to_dict() for f in feedbacks]
+
+        except SQLAlchemyError as e:
+            logger.error(f"Failed to query feedback for workflow: {e}")
+            return []
+
+    async def get_all_feedbacks(self) -> list[dict[str, Any]]:
+        """Fetch all feedback records."""
+        if self.db is None:
+            return self._get_all_feedbacks_from_memory()
+
+        try:
+            session_factory = await self.db.get_session_factory()
+            async with session_factory() as session:
+                result = await session.execute(
+                    select(FeedbackBase).order_by(FeedbackBase.created_at.desc())
+                )
+                feedbacks = result.scalars().all()
+                return [f.to_dict() for f in feedbacks]
+
+        except SQLAlchemyError as e:
+            logger.error(f"Failed to query all feedback: {e}")
+            return []
+
+    async def delete_feedback(self, feedback_id: str) -> bool:
+        """Delete a feedback record."""
+        if self.db is None:
+            return self._delete_feedback_from_memory(feedback_id)
+
+        try:
+            session_factory = await self.db.get_session_factory()
+            async with session_factory() as session:
+                result = await session.execute(
+                    select(FeedbackBase).where(FeedbackBase.feedback_id == feedback_id)
+                )
+                feedback = result.scalar_one_or_none()
+
+                if not feedback:
+                    return False
+
+                await session.delete(feedback)
+                await session.commit()
+
+            logger.info(f"Deleted feedback: feedback_id={feedback_id}")
+            return True
+
+        except SQLAlchemyError as e:
+            logger.error(f"Failed to delete feedback: {e}")
+            return False
+
+    # ========== In-memory fallback methods (used when no database is configured) ==========
+
+    def _create_feedback_in_memory(
+        self,
+        task_id: str,
+        workflow_name: str,
+        rating: Optional[int],
+        comment: Optional[str],
+        metadata: Optional[dict[str, Any]],
+    ) -> dict[str, Any]:
+        """In-memory variant of create_feedback."""
+        feedback_id = str(uuid4())
+        created_at = datetime.now()
+
+        feedback_data = {
+            "feedback_id": feedback_id,
+            "task_id": task_id,
+            "workflow_name": workflow_name,
+            "rating": rating,
+            "comment": comment,
+            "metadata": metadata or {},
+            "created_at": created_at,
+        }
+
+        self._storage[feedback_id] = feedback_data
+
+        if task_id not in self._task_index:
+            self._task_index[task_id] = []
+        self._task_index[task_id].append(feedback_id)
+
+        if workflow_name not in self._workflow_index:
+            self._workflow_index[workflow_name] = []
+        self._workflow_index[workflow_name].append(feedback_id)
+
+        logger.info(f"[in-memory] Created feedback: feedback_id={feedback_id}")
+        return feedback_data
+
+    def _get_feedback_from_memory(self, feedback_id: str) -> Optional[dict[str, Any]]:
+        return self._storage.get(feedback_id)
+
+    def _get_feedbacks_by_task_from_memory(self, task_id: str) -> list[dict[str, Any]]:
+        feedback_ids = self._task_index.get(task_id, [])
+        return [self._storage[fid] for fid in feedback_ids if fid in self._storage]
+
+    def _get_feedbacks_by_workflow_from_memory(self, workflow_name: str) -> list[dict[str, Any]]:
+        feedback_ids = self._workflow_index.get(workflow_name, [])
+        return [self._storage[fid] for fid in feedback_ids if fid in self._storage]
+
+    def _get_all_feedbacks_from_memory(self) -> list[dict[str, Any]]:
+        return list(self._storage.values())
+
+    def _delete_feedback_from_memory(self, feedback_id: str) -> bool:
+        if feedback_id not in self._storage:
+            return False
+
+        feedback = self._storage[feedback_id]
+        task_id = feedback["task_id"]
+        workflow_name = feedback["workflow_name"]
+
+        if task_id in self._task_index:
+            self._task_index[task_id].remove(feedback_id)
+        if workflow_name in self._workflow_index:
+            self._workflow_index[workflow_name].remove(feedback_id)
+
+        del self._storage[feedback_id]
+        logger.info(f"[in-memory] Deleted feedback: feedback_id={feedback_id}")
+        return True
+
+
+# Global singleton instance
+feedback_storage = FeedbackStorage()
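Constructed without a DatabaseManager, FeedbackStorage takes the in-memory fallback path above, so the API can be exercised without PostgreSQL. A minimal usage sketch (illustrative only, not part of the package):

import asyncio
from service_forge.storage.feedback_storage import FeedbackStorage

async def main() -> None:
    storage = FeedbackStorage()  # no DatabaseManager -> in-memory fallback
    fb = await storage.create_feedback(
        task_id="task-1",
        workflow_name="demo",
        rating=5,
        comment="looks good",
    )
    print(await storage.get_feedbacks_by_task("task-1"))  # one record
    assert await storage.delete_feedback(fb["feedback_id"])

asyncio.run(main())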
service_forge/utils/default_type_converter.py
@@ -0,0 +1,12 @@
+from ..utils.type_converter import TypeConverter
+from ..workflow.workflow import Workflow
+from ..api.http_api import fastapi_app
+from ..api.kafka_api import KafkaApp, kafka_app
+from fastapi import FastAPI
+from ..workflow.workflow_type import WorkflowType, workflow_type_register
+
+type_converter = TypeConverter()
+type_converter.register(str, Workflow, lambda s, node: node.sub_workflows.get_workflow(s))
+type_converter.register(str, FastAPI, lambda s, node: fastapi_app)
+type_converter.register(str, KafkaApp, lambda s, node: kafka_app)
+type_converter.register(str, type, lambda s, node: workflow_type_register.items[s].type)
service_forge/utils/register.py
@@ -0,0 +1,39 @@
+from typing import TypeVar, Generic, Any
+from loguru import logger
+
+T = TypeVar('T')
+
+class Register(Generic[T]):
+    def __init__(self) -> None:
+        self.items: dict[str, T] = {}
+
+    def register(
+        self,
+        name: str,
+        item: T,
+        show_info_log: bool = False,
+    ) -> None:
+        if name not in self.items:
+            self.items[name] = item
+            if show_info_log:
+                logger.info(f'Register {name}.')
+        else:
+            logger.warning(f'{name} has been registered.')
+
+    def instance(
+        self,
+        name: str,
+        kwargs: dict[str, Any] | None = None,
+        ignore_keys: list[str] | None = None,
+    ) -> Any:
+        # Copy so the caller's dict is not mutated, then drop ignored keys.
+        kwargs = dict(kwargs) if kwargs else {}
+        for key in (ignore_keys or []):
+            kwargs.pop(key, None)
+        if name not in self.items:
+            logger.error(f'{name} has not been registered.')
+            raise KeyError(name)
+        return self.items[name](**kwargs)
+
+    def __len__(self) -> int:
+        return len(self.items)
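For orientation, a minimal sketch of the registry in use (Greeter is a hypothetical class, not part of service-forge):

from service_forge.utils.register import Register

class Greeter:
    def __init__(self, name: str) -> None:
        self.name = name

registry: Register[type] = Register()
registry.register("greeter", Greeter, show_info_log=True)

# instance() drops ignored keys, then calls the registered item with the rest.
g = registry.instance("greeter", kwargs={"name": "ada", "debug": True}, ignore_keys=["debug"])
print(g.name, len(registry))  # ada 1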
service_forge/utils/type_converter.py
@@ -0,0 +1,99 @@
+from typing import Any, Callable, Type, Dict, Tuple, Set, List
+from collections import deque
+import inspect
+from typing_extensions import get_origin, get_args
+
+def is_type(value, dst_type):
+    origin = get_origin(dst_type)
+    if origin is None:
+        return isinstance(value, dst_type)
+
+    if not isinstance(value, origin):
+        return False
+
+    args = get_args(dst_type)
+    if not args:
+        return True
+
+    if origin is list:
+        elem_type = args[0]
+        return all(is_type(item, elem_type) for item in value)
+    elif origin is dict:
+        key_type, value_type = args
+        return all(
+            is_type(k, key_type) and is_type(v, value_type)
+            for k, v in value.items()
+        )
+
+    return True
+
+class TypeConverter:
+    def __init__(self):
+        self._registry: Dict[Tuple[Type, Type], Callable[..., Any]] = {}
+
+    def register(self, src_type: Type, dst_type: Type, func: Callable[..., Any]):
+        self._registry[(src_type, dst_type)] = func
+
+    def can_convert(self, src_type: Type, dst_type: Type) -> bool:
+        return self._find_path(src_type, dst_type) is not None
+
+    def convert(self, value: Any, dst_type: Type, **kwargs) -> Any:
+        if value is None:
+            return None
+
+        if dst_type == Any:
+            return value
+
+        src_type = type(value)
+
+        if is_type(value, dst_type):
+            return value
+
+        if (src_type, dst_type) in self._registry:
+            return self._call_func(self._registry[(src_type, dst_type)], value, **kwargs)
+
+        try:
+            return dst_type(value)
+        except Exception:
+            pass
+
+        path = self._find_path(src_type, dst_type)
+        if not path:
+            raise TypeError(f"No conversion path found from {src_type.__name__} to {dst_type.__name__}.")
+
+        result = value
+        for i in range(len(path) - 1):
+            func = self._registry[(path[i], path[i + 1])]
+            result = self._call_func(func, result, **kwargs)
+        return result
+
+    def _call_func(self, func: Callable[..., Any], value: Any, **kwargs) -> Any:
+        sig = inspect.signature(func)
+        if len(sig.parameters) == 1:
+            return func(value)
+        else:
+            return func(value, **kwargs)
+
+    def _find_path(self, src_type: Type, dst_type: Type) -> List[Type] | None:
+        if src_type == dst_type:
+            return [src_type]
+
+        graph: Dict[Type, Set[Type]] = {}
+        for (s, d) in self._registry.keys():
+            graph.setdefault(s, set()).add(d)
+
+        queue = deque([[src_type]])
+        visited = {src_type}
+
+        while queue:
+            path = queue.popleft()
+            current = path[-1]
+            for neighbor in graph.get(current, []):
+                if neighbor in visited:
+                    continue
+                new_path = path + [neighbor]
+                if neighbor == dst_type:
+                    return new_path
+                queue.append(new_path)
+                visited.add(neighbor)
+        return None
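convert() resolves a value in three steps: an exact registry hit, then direct construction via dst_type(value), and finally a breadth-first search that chains registered conversions across intermediate types. A minimal sketch of the multi-hop path (Celsius and Kelvin are hypothetical types, not part of the package):

from service_forge.utils.type_converter import TypeConverter

class Celsius:
    def __init__(self, value: float) -> None:
        self.value = value

class Kelvin:
    def __init__(self, value: float) -> None:
        self.value = value

converter = TypeConverter()
converter.register(Celsius, Kelvin, lambda c: Kelvin(c.value + 273.15))
converter.register(Kelvin, float, lambda k: k.value)

# No direct (Celsius, float) rule, and float(Celsius(...)) raises TypeError,
# so convert() falls back to the BFS path Celsius -> Kelvin -> float.
print(converter.can_convert(Celsius, float))     # True
print(converter.convert(Celsius(100.0), float))  # 373.15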
service_forge/utils/workflow_clone.py
@@ -0,0 +1,124 @@
+from __future__ import annotations
+
+import uuid
+from typing import TYPE_CHECKING
+if TYPE_CHECKING:
+    from service_forge.workflow.node import Node
+    from service_forge.workflow.port import Port
+    from service_forge.workflow.edge import Edge
+    from service_forge.workflow.workflow import Workflow
+    from service_forge.workflow.trigger import Trigger
+    from service_forge.workflow.context import Context
+
+def workflow_clone(self: Workflow, task_id: uuid.UUID, trigger_node: Trigger) -> Workflow:
+    from service_forge.workflow.workflow import Workflow
+    from service_forge.workflow.context import Context
+
+    if self.nodes is not None and len(self.nodes) > 0:
+        context = self.nodes[0].context._clone()
+    else:
+        context = Context(variables={})
+
+    node_map: dict[Node, Node] = {node: node._clone(context) for node in self.nodes}
+
+    port_map: dict[Port, Port] = {}
+    port_map.update({port: port._clone(node_map) for port in self.input_ports})
+    port_map.update({port: port._clone(node_map) for port in self.output_ports})
+    for node in self.nodes:
+        for port in node.input_ports:
+            if port not in port_map:
+                port_map[port] = port._clone(node_map)
+        for port in node.output_ports:
+            if port not in port_map:
+                port_map[port] = port._clone(node_map)
+
+    edge_map: dict[Edge, Edge] = {}
+    for node in self.nodes:
+        for edge in node.input_edges:
+            if edge not in edge_map:
+                edge_map[edge] = edge._clone(node_map, port_map)
+        for edge in node.output_edges:
+            if edge not in edge_map:
+                edge_map[edge] = edge._clone(node_map, port_map)
+
+    # Re-link chained ports (port.port) to their cloned counterparts
+    for old_port, new_port in port_map.items():
+        if old_port.port is not None:
+            new_port.port = port_map[old_port.port]
+
+    # Fill ports and edges in the cloned nodes
+    for old_node, new_node in node_map.items():
+        new_node.input_edges = [edge_map[edge] for edge in old_node.input_edges]
+        new_node.output_edges = [edge_map[edge] for edge in old_node.output_edges]
+        new_node.input_ports = [port_map[port] for port in old_node.input_ports]
+        new_node.output_ports = [port_map[port] for port in old_node.output_ports]
+        new_node.input_variables = {
+            port_map[port]: value for port, value in old_node.input_variables.items()
+        }
+
+    workflow = Workflow(
+        id=self.id,
+        config=self.config,
+        nodes=[node_map[node] for node in self.nodes],
+        input_ports=[port_map[port] for port in self.input_ports],
+        output_ports=[port_map[port] for port in self.output_ports],
+        _handle_stream_output=self._handle_stream_output,
+        _handle_query_user=self._handle_query_user,
+        database_manager=self.database_manager,
+        max_concurrent_runs=self.max_concurrent_runs,
+        callbacks=self.callbacks,
+        task_id=task_id,
+        real_trigger_node=trigger_node,
+    )
+
+    for node in workflow.nodes:
+        node.workflow = workflow
+
+    return workflow
+
+def port_clone(self: Port, node_map: dict[Node, Node]) -> Port:
+    from service_forge.workflow.port import Port
+    node = node_map[self.node] if self.node is not None else None
+    port = Port(
+        name=self.name,
+        type=self.type,
+        node=node,
+        port=None,
+        value=self.value,
+        default=self.default,
+        is_extended=self.is_extended,
+        is_extended_generated=self.is_extended_generated,
+    )
+    port.is_prepared = self.is_prepared
+    return port
+
+def node_clone(self: Node, context: Context) -> Node:
+    node = self.__class__(
+        name=self.name
+    )
+    node.context = context
+    node.input_edges = []
+    node.output_edges = []
+    node.input_ports = []
+    node.output_ports = []
+    node.query_user = self.query_user
+    node.workflow = None
+
+    if self.sub_workflows is not None:
+        raise ValueError("Sub workflows are not supported in node clone.")
+    node.sub_workflows = None
+    node.input_variables = {}
+    node.num_activated_input_edges = self.num_activated_input_edges
+
+    return node
+
+def edge_clone(self: Edge, node_map: dict[Node, Node], port_map: dict[Port, Port]) -> Edge:
+    from service_forge.workflow.edge import Edge
+    start_node = node_map[self.start_node] if self.start_node is not None else None
+    end_node = node_map[self.end_node] if self.end_node is not None else None
+    return Edge(
+        start_node=start_node,
+        end_node=end_node,
+        start_port=port_map[self.start_port],
+        end_port=port_map[self.end_port],
+    )
service_forge/workflow/__init__.py
@@ -0,0 +1 @@
+from .workflow_type import workflow_type_register
service_forge/workflow/context.py
@@ -0,0 +1,14 @@
+from __future__ import annotations
+from typing import Any
+
+class Context:
+    def __init__(
+        self,
+        variables: dict[Any, Any] | None = None,
+    ) -> None:
+        self.variables = variables if variables is not None else {}
+
+    def _clone(self) -> Context:
+        return Context(
+            variables={key: value for key, value in self.variables.items()},
+        )
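_clone() copies the variables dict one level deep: the clone gets an independent dict, but the values inside it are still shared. A minimal sketch (illustrative only):

from service_forge.workflow.context import Context

ctx = Context(variables={"items": [1, 2]})
clone = ctx._clone()

clone.variables["new"] = True        # top-level key: clone only
clone.variables["items"].append(3)   # nested value: shared with ctx

print("new" in ctx.variables)  # False
print(ctx.variables["items"])  # [1, 2, 3]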
service_forge/workflow/edge.py
@@ -0,0 +1,24 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+from ..utils.workflow_clone import edge_clone
+
+if TYPE_CHECKING:
+    from .node import Node
+    from .port import Port
+
+class Edge:
+    def __init__(
+        self,
+        start_node: Node,
+        end_node: Node,
+        start_port: Port,
+        end_port: Port,
+    ) -> None:
+        self.start_node = start_node
+        self.end_node = end_node
+        self.start_port = start_port
+        self.end_port = end_port
+
+    def _clone(self, node_map: dict[Node, Node], port_map: dict[Port, Port]) -> Edge:
+        return edge_clone(self, node_map, port_map)