service-forge 0.1.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of service-forge might be problematic. Click here for more details.
- service_forge/api/deprecated_websocket_api.py +86 -0
- service_forge/api/deprecated_websocket_manager.py +425 -0
- service_forge/api/http_api.py +148 -0
- service_forge/api/http_api_doc.py +455 -0
- service_forge/api/kafka_api.py +126 -0
- service_forge/api/routers/service/__init__.py +4 -0
- service_forge/api/routers/service/service_router.py +137 -0
- service_forge/api/routers/websocket/websocket_manager.py +83 -0
- service_forge/api/routers/websocket/websocket_router.py +78 -0
- service_forge/api/task_manager.py +141 -0
- service_forge/db/__init__.py +1 -0
- service_forge/db/database.py +240 -0
- service_forge/llm/__init__.py +62 -0
- service_forge/llm/llm.py +56 -0
- service_forge/model/__init__.py +0 -0
- service_forge/model/websocket.py +13 -0
- service_forge/proto/foo_input.py +5 -0
- service_forge/service.py +288 -0
- service_forge/service_config.py +158 -0
- service_forge/sft/cli.py +91 -0
- service_forge/sft/cmd/config_command.py +67 -0
- service_forge/sft/cmd/deploy_service.py +123 -0
- service_forge/sft/cmd/list_tars.py +41 -0
- service_forge/sft/cmd/service_command.py +149 -0
- service_forge/sft/cmd/upload_service.py +36 -0
- service_forge/sft/config/injector.py +119 -0
- service_forge/sft/config/injector_default_files.py +131 -0
- service_forge/sft/config/sf_metadata.py +30 -0
- service_forge/sft/config/sft_config.py +153 -0
- service_forge/sft/file/__init__.py +0 -0
- service_forge/sft/file/ignore_pattern.py +80 -0
- service_forge/sft/file/sft_file_manager.py +107 -0
- service_forge/sft/kubernetes/kubernetes_manager.py +257 -0
- service_forge/sft/util/assert_util.py +25 -0
- service_forge/sft/util/logger.py +16 -0
- service_forge/sft/util/name_util.py +8 -0
- service_forge/sft/util/yaml_utils.py +57 -0
- service_forge/utils/__init__.py +0 -0
- service_forge/utils/default_type_converter.py +12 -0
- service_forge/utils/register.py +39 -0
- service_forge/utils/type_converter.py +99 -0
- service_forge/utils/workflow_clone.py +124 -0
- service_forge/workflow/__init__.py +1 -0
- service_forge/workflow/context.py +14 -0
- service_forge/workflow/edge.py +24 -0
- service_forge/workflow/node.py +184 -0
- service_forge/workflow/nodes/__init__.py +8 -0
- service_forge/workflow/nodes/control/if_node.py +29 -0
- service_forge/workflow/nodes/control/switch_node.py +28 -0
- service_forge/workflow/nodes/input/console_input_node.py +26 -0
- service_forge/workflow/nodes/llm/query_llm_node.py +41 -0
- service_forge/workflow/nodes/nested/workflow_node.py +28 -0
- service_forge/workflow/nodes/output/kafka_output_node.py +27 -0
- service_forge/workflow/nodes/output/print_node.py +29 -0
- service_forge/workflow/nodes/test/if_console_input_node.py +33 -0
- service_forge/workflow/nodes/test/time_consuming_node.py +62 -0
- service_forge/workflow/port.py +89 -0
- service_forge/workflow/trigger.py +24 -0
- service_forge/workflow/triggers/__init__.py +6 -0
- service_forge/workflow/triggers/a2a_api_trigger.py +255 -0
- service_forge/workflow/triggers/fast_api_trigger.py +169 -0
- service_forge/workflow/triggers/kafka_api_trigger.py +44 -0
- service_forge/workflow/triggers/once_trigger.py +20 -0
- service_forge/workflow/triggers/period_trigger.py +26 -0
- service_forge/workflow/triggers/websocket_api_trigger.py +184 -0
- service_forge/workflow/workflow.py +210 -0
- service_forge/workflow/workflow_callback.py +141 -0
- service_forge/workflow/workflow_event.py +15 -0
- service_forge/workflow/workflow_factory.py +246 -0
- service_forge/workflow/workflow_group.py +27 -0
- service_forge/workflow/workflow_type.py +52 -0
- service_forge-0.1.11.dist-info/METADATA +98 -0
- service_forge-0.1.11.dist-info/RECORD +75 -0
- service_forge-0.1.11.dist-info/WHEEL +4 -0
- service_forge-0.1.11.dist-info/entry_points.txt +2 -0
|
@@ -0,0 +1,57 @@
|
|
|
1
|
+
import yaml
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import Any, Dict
|
|
4
|
+
|
|
5
|
+
|
|
6
|
+
def extract_yaml_content_without_comments(yaml_file: Path) -> str:
    """Read a YAML file and return its content re-serialized without comments.

    Args:
        yaml_file: Path to the YAML file to read.

    Returns:
        The YAML content as a normalized string with all comments removed,
        or "" when the file contains no data.  If the file cannot be parsed
        as YAML, a best-effort comment-stripped copy of the raw text is
        returned instead.
    """
    with open(yaml_file, 'r', encoding='utf-8') as f:
        content = f.read()

    try:
        # BUGFIX: the previous implementation stripped every line whose
        # stripped form starts with '#' BEFORE parsing, which corrupts
        # block scalars (| / >) whose content lines legitimately begin
        # with '#'.  yaml.safe_load already discards comments (including
        # inline ones), so parsing the raw text is both simpler and safe.
        data = yaml.safe_load(content)
    except yaml.YAMLError:
        # Unparseable input: fall back to a line-based strip of blank
        # lines and full-line comments, matching the old fallback output.
        kept = [
            line for line in content.split('\n')
            if line.strip() and not line.strip().startswith('#')
        ]
        return '\n'.join(kept)

    if data is None:
        return ""
    return yaml.dump(data, default_flow_style=False, allow_unicode=True, sort_keys=False)
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def load_sf_metadata_as_string(metadata_file: Path) -> str:
    """Load a sf-meta.yaml file and return its comment-free YAML content.

    Args:
        metadata_file: Path to the sf-meta.yaml file.

    Returns:
        The file's YAML content with comments stripped.

    Raises:
        FileNotFoundError: If *metadata_file* does not exist.
    """
    if not metadata_file.exists():
        raise FileNotFoundError(f"Metadata file not found: {metadata_file}")

    return extract_yaml_content_without_comments(metadata_file)
|
|
File without changes
|
|
@@ -0,0 +1,12 @@
|
|
|
1
|
+
from ..utils.type_converter import TypeConverter
|
|
2
|
+
from ..workflow.workflow import Workflow
|
|
3
|
+
from ..api.http_api import fastapi_app
|
|
4
|
+
from ..api.kafka_api import KafkaApp, kafka_app
|
|
5
|
+
from fastapi import FastAPI
|
|
6
|
+
from ..workflow.workflow_type import WorkflowType, workflow_type_register
|
|
7
|
+
|
|
8
|
+
# Module-level converter used by the workflow machinery to coerce string
# values (typically from configuration) into runtime objects.  Each rule
# receives the raw string plus the node that requested the conversion.
type_converter = TypeConverter()
# "workflow name" -> Workflow: looked up in the requesting node's sub-workflows.
type_converter.register(str, Workflow, lambda s, node: node.sub_workflows.get_workflow(s))
# Any string resolves to the process-wide FastAPI application instance.
type_converter.register(str, FastAPI, lambda s, node: fastapi_app)
# Any string resolves to the process-wide Kafka application instance.
type_converter.register(str, KafkaApp, lambda s, node: kafka_app)
# "type name" -> the Python class of the registered workflow type.
type_converter.register(str, type, lambda s, node: workflow_type_register.items[s].type)
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
from typing import TypeVar, Generic, Any
|
|
2
|
+
from loguru import logger
|
|
3
|
+
|
|
4
|
+
T = TypeVar('T')

class Register(Generic[T]):
    """A simple name -> item registry with optional instantiation support."""

    def __init__(self) -> None:
        # Registered items, keyed by their unique name.
        self.items: dict[str, T] = {}

    def register(
        self,
        name: str,
        item: T,
        show_info_log: bool = False,
    ) -> None:
        """Register *item* under *name*; the first registration wins.

        A duplicate name is ignored with a warning instead of overwriting
        the existing entry.

        Args:
            name: Unique registry key.
            item: The object (usually a class) to register.
            show_info_log: Emit an info log entry on success.
        """
        if name in self.items:
            logger.warning(f'{name} has been registered.')
            return
        self.items[name] = item
        if show_info_log:
            logger.info(f'Register {name}.')

    def instance(
        self,
        name: str,
        kwargs: dict[str, Any] | None = None,
        ignore_keys: list[str] | None = None,
    ) -> T:
        """Call the registered item *name* with *kwargs* and return the result.

        Args:
            name: Registry key of the callable to instantiate.
            kwargs: Keyword arguments for the call; not mutated.
            ignore_keys: Keys to drop from *kwargs* before calling.

        Raises:
            KeyError: If *name* has not been registered.
        """
        # BUGFIX: the previous version used mutable default arguments
        # (kwargs={}, ignore_keys=[]) shared across calls, and popped keys
        # out of the caller's dict.  Copy so the caller is never mutated.
        kwargs = dict(kwargs) if kwargs else {}
        for key in ignore_keys or []:
            kwargs.pop(key, None)
        if name not in self.items:
            logger.error(f'{name} has not been registered.')
        # Unregistered names still raise KeyError here, as before.
        return self.items[name](**kwargs)

    def __len__(self) -> int:
        return len(self.items)
|
|
@@ -0,0 +1,99 @@
|
|
|
1
|
+
from typing import Any, Callable, Type, Dict, Tuple, Set, List
|
|
2
|
+
from collections import deque
|
|
3
|
+
import inspect
|
|
4
|
+
from typing_extensions import get_origin, get_args
|
|
5
|
+
|
|
6
|
+
def is_type(value, dst_type):
|
|
7
|
+
origin = get_origin(dst_type)
|
|
8
|
+
if origin is None:
|
|
9
|
+
return isinstance(value, dst_type)
|
|
10
|
+
|
|
11
|
+
if not isinstance(value, origin):
|
|
12
|
+
return False
|
|
13
|
+
|
|
14
|
+
args = get_args(dst_type)
|
|
15
|
+
if not args:
|
|
16
|
+
return True
|
|
17
|
+
|
|
18
|
+
if origin is list:
|
|
19
|
+
elem_type = args[0]
|
|
20
|
+
return all(is_type(item, elem_type) for item in value)
|
|
21
|
+
elif origin is dict:
|
|
22
|
+
key_type, value_type = args
|
|
23
|
+
return all(
|
|
24
|
+
is_type(k, key_type) and is_type(v, value_type)
|
|
25
|
+
for k, v in value.items()
|
|
26
|
+
)
|
|
27
|
+
|
|
28
|
+
return True
|
|
29
|
+
|
|
30
|
+
class TypeConverter:
|
|
31
|
+
def __init__(self):
|
|
32
|
+
self._registry: Dict[Tuple[Type, Type], Callable[..., Any]] = {}
|
|
33
|
+
|
|
34
|
+
def register(self, src_type: Type, dst_type: Type, func: Callable[..., Any]):
|
|
35
|
+
self._registry[(src_type, dst_type)] = func
|
|
36
|
+
|
|
37
|
+
def can_convert(self, src_type: Type, dst_type: Type) -> bool:
|
|
38
|
+
return self._find_path(src_type, dst_type) is not None
|
|
39
|
+
|
|
40
|
+
def convert(self, value: Any, dst_type: Type, **kwargs) -> Any:
|
|
41
|
+
if value is None:
|
|
42
|
+
return None
|
|
43
|
+
|
|
44
|
+
if dst_type == Any:
|
|
45
|
+
return value
|
|
46
|
+
|
|
47
|
+
src_type = type(value)
|
|
48
|
+
|
|
49
|
+
if is_type(value, dst_type):
|
|
50
|
+
return value
|
|
51
|
+
|
|
52
|
+
if (src_type, dst_type) in self._registry:
|
|
53
|
+
return self._call_func(self._registry[(src_type, dst_type)], value, **kwargs)
|
|
54
|
+
|
|
55
|
+
try:
|
|
56
|
+
return dst_type(value)
|
|
57
|
+
except Exception:
|
|
58
|
+
pass
|
|
59
|
+
|
|
60
|
+
path = self._find_path(src_type, dst_type)
|
|
61
|
+
if not path:
|
|
62
|
+
raise TypeError(f"No conversion path found from {src_type.__name__} to {dst_type.__name__}.")
|
|
63
|
+
|
|
64
|
+
result = value
|
|
65
|
+
for i in range(len(path) - 1):
|
|
66
|
+
func = self._registry[(path[i], path[i + 1])]
|
|
67
|
+
result = self._call_func(func, result, **kwargs)
|
|
68
|
+
return result
|
|
69
|
+
|
|
70
|
+
def _call_func(self, func: Callable[..., Any], value: Any, **kwargs) -> Any:
|
|
71
|
+
sig = inspect.signature(func)
|
|
72
|
+
if len(sig.parameters) == 1:
|
|
73
|
+
return func(value)
|
|
74
|
+
else:
|
|
75
|
+
return func(value, **kwargs)
|
|
76
|
+
|
|
77
|
+
def _find_path(self, src_type: Type, dst_type: Type) -> List[Type] | None:
|
|
78
|
+
if src_type == dst_type:
|
|
79
|
+
return [src_type]
|
|
80
|
+
|
|
81
|
+
graph: Dict[Type, Set[Type]] = {}
|
|
82
|
+
for (s, d) in self._registry.keys():
|
|
83
|
+
graph.setdefault(s, set()).add(d)
|
|
84
|
+
|
|
85
|
+
queue = deque([[src_type]])
|
|
86
|
+
visited = {src_type}
|
|
87
|
+
|
|
88
|
+
while queue:
|
|
89
|
+
path = queue.popleft()
|
|
90
|
+
current = path[-1]
|
|
91
|
+
for neighbor in graph.get(current, []):
|
|
92
|
+
if neighbor in visited:
|
|
93
|
+
continue
|
|
94
|
+
new_path = path + [neighbor]
|
|
95
|
+
if neighbor == dst_type:
|
|
96
|
+
return new_path
|
|
97
|
+
queue.append(new_path)
|
|
98
|
+
visited.add(neighbor)
|
|
99
|
+
return None
|
|
@@ -0,0 +1,124 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import uuid
|
|
4
|
+
from typing import TYPE_CHECKING
|
|
5
|
+
if TYPE_CHECKING:
|
|
6
|
+
from service_forge.workflow.node import Node
|
|
7
|
+
from service_forge.workflow.port import Port
|
|
8
|
+
from service_forge.workflow.edge import Edge
|
|
9
|
+
from service_forge.workflow.workflow import Workflow
|
|
10
|
+
from service_forge.workflow.trigger import Trigger
|
|
11
|
+
from service_forge.workflow.context import Context
|
|
12
|
+
|
|
13
|
+
def workflow_clone(self: Workflow, task_id: uuid.UUID, trigger_node: Trigger) -> Workflow:
    """Deep-clone *self* into a new Workflow bound to *task_id*.

    Nodes, ports and edges are copied via their ``_clone`` hooks, and the
    identity maps built here are used to rewire every cross-reference in
    the copy so the clone shares no mutable structure with the original.

    Args:
        task_id: Identifier for the task the clone will run.
        trigger_node: Trigger instance driving the cloned run.

    Returns:
        The fully rewired Workflow copy.
    """
    from service_forge.workflow.workflow import Workflow
    # BUGFIX: Context is only imported under TYPE_CHECKING at module
    # level, so the empty-workflow branch below raised NameError at
    # runtime; import it here for real use.
    from service_forge.workflow.context import Context

    # All nodes share one context, so clone it once from the first node.
    if self.nodes is not None and len(self.nodes) > 0:
        context = self.nodes[0].context._clone()
    else:
        context = Context(variables={})

    node_map: dict[Node, Node] = {node: node._clone(context) for node in self.nodes}

    # Clone every port reachable from the workflow or its nodes exactly once.
    port_map: dict[Port, Port] = {}
    port_map.update({port: port._clone(node_map) for port in self.input_ports})
    port_map.update({port: port._clone(node_map) for port in self.output_ports})
    for node in self.nodes:
        for port in node.input_ports:
            if port not in port_map:
                port_map[port] = port._clone(node_map)
        for port in node.output_ports:
            if port not in port_map:
                port_map[port] = port._clone(node_map)

    edge_map: dict[Edge, Edge] = {}
    for node in self.nodes:
        for edge in node.input_edges:
            if edge not in edge_map:
                edge_map[edge] = edge._clone(node_map, port_map)
        for edge in node.output_edges:
            if edge not in edge_map:
                edge_map[edge] = edge._clone(node_map, port_map)

    # Rewire port-to-port links (port_clone leaves .port as None).
    for old_port, new_port in port_map.items():
        if old_port.port is not None:
            new_port.port = port_map[old_port.port]

    # Rewire each cloned node's edges, ports and pending input variables
    # (node_clone resets these to empty).
    for old_node, new_node in node_map.items():
        new_node.input_edges = [edge_map[edge] for edge in old_node.input_edges]
        new_node.output_edges = [edge_map[edge] for edge in old_node.output_edges]
        new_node.input_ports = [port_map[port] for port in old_node.input_ports]
        new_node.output_ports = [port_map[port] for port in old_node.output_ports]
        new_node.input_variables = {
            port_map[port]: value for port, value in old_node.input_variables.items()
        }

    workflow = Workflow(
        name=self.name,
        description=self.description,
        nodes=[node_map[node] for node in self.nodes],
        input_ports=[port_map[port] for port in self.input_ports],
        output_ports=[port_map[port] for port in self.output_ports],
        _handle_stream_output=self._handle_stream_output,
        _handle_query_user=self._handle_query_user,
        database_manager=self.database_manager,
        max_concurrent_runs=self.max_concurrent_runs,
        callbacks=self.callbacks,
        task_id=task_id,
        real_trigger_node=trigger_node,
    )

    # Back-reference every cloned node to its new owning workflow.
    for node in workflow.nodes:
        node.workflow = workflow

    return workflow
|
|
78
|
+
|
|
79
|
+
def port_clone(self: Port, node_map: dict[Node, Node]) -> Port:
    """Copy *self*, remapping its owning node through *node_map*.

    The port-to-port link (``.port``) is intentionally left as None; the
    caller rewires it once every port has been cloned.
    """
    from service_forge.workflow.port import Port

    owner = None
    if self.node is not None:
        owner = node_map[self.node]

    clone = Port(
        name=self.name,
        type=self.type,
        node=owner,
        port=None,
        value=self.value,
        default=self.default,
        is_extended=self.is_extended,
        is_extended_generated=self.is_extended_generated,
    )
    # Preserve runtime readiness state, which Port's constructor resets.
    clone.is_prepared = self.is_prepared
    return clone
|
|
94
|
+
|
|
95
|
+
def node_clone(self: Node, context: Context) -> Node:
    """Create a fresh node of the same class and name, bound to *context*.

    Wiring (edges/ports) is reset to empty; the caller reattaches cloned
    edges and ports afterwards.
    """
    clone = self.__class__(
        name=self.name
    )
    clone.context = context
    clone.input_edges, clone.output_edges = [], []
    clone.input_ports, clone.output_ports = [], []
    clone.query_user = self.query_user
    clone.workflow = None

    if self.sub_workflows is not None:
        raise ValueError("Sub workflows are not supported in node clone.")
    clone.sub_workflows = None

    # Runtime state: pending inputs start empty, but the activation count
    # carries over so a partially-triggered node keeps its progress.
    clone.input_variables = {}
    clone.num_activated_input_edges = self.num_activated_input_edges

    return clone
|
|
114
|
+
|
|
115
|
+
def edge_clone(self: Edge, node_map: dict[Node, Node], port_map: dict[Port, Port]) -> Edge:
    """Copy *self*, remapping its endpoint nodes and ports through the maps."""
    from service_forge.workflow.edge import Edge

    def remap(node):
        # Endpoint nodes may legitimately be absent (workflow-level ports).
        return node_map[node] if node is not None else None

    return Edge(
        start_node=remap(self.start_node),
        end_node=remap(self.end_node),
        start_port=port_map[self.start_port],
        end_port=port_map[self.end_port],
    )
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
from .workflow_type import workflow_type_register
|
|
@@ -0,0 +1,14 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from typing import Any
|
|
3
|
+
|
|
4
|
+
class Context():
    """Shared mutable variable store passed between workflow nodes."""

    def __init__(
        self,
        variables: dict[Any, Any] | None = None,
    ) -> None:
        # BUGFIX: the previous default (variables=dict()) was a mutable
        # default argument, so every Context created without an explicit
        # dict shared one instance and leaked variables across workflows.
        self.variables = {} if variables is None else variables

    def _clone(self) -> Context:
        """Return a Context holding a shallow copy of the variables dict."""
        return Context(
            variables={key: value for key, value in self.variables.items()},
        )
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
from ..utils.workflow_clone import edge_clone
|
|
5
|
+
|
|
6
|
+
if TYPE_CHECKING:
|
|
7
|
+
from .node import Node
|
|
8
|
+
from .port import Port
|
|
9
|
+
|
|
10
|
+
class Edge:
    """A directed connection carrying data from one node's output port to
    another node's input port."""

    def __init__(
        self,
        start_node: Node,
        end_node: Node,
        start_port: Port,
        end_port: Port,
    ) -> None:
        self.start_node, self.end_node = start_node, end_node
        self.start_port, self.end_port = start_port, end_port

    def _clone(self, node_map: dict[Node, Node], port_map: dict[Port, Port]) -> Edge:
        """Return a copy wired to the cloned nodes and ports in the maps."""
        return edge_clone(self, node_map, port_map)
|
|
@@ -0,0 +1,184 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import Any, AsyncIterator, Union, TYPE_CHECKING, Callable, Awaitable
|
|
4
|
+
from abc import ABC, abstractmethod
|
|
5
|
+
import uuid
|
|
6
|
+
from loguru import logger
|
|
7
|
+
from .edge import Edge
|
|
8
|
+
from .port import Port
|
|
9
|
+
from .context import Context
|
|
10
|
+
from ..utils.register import Register
|
|
11
|
+
from ..db.database import DatabaseManager, PostgresDatabase, MongoDatabase, RedisDatabase
|
|
12
|
+
from ..utils.workflow_clone import node_clone
|
|
13
|
+
|
|
14
|
+
if TYPE_CHECKING:
|
|
15
|
+
from .workflow import Workflow
|
|
16
|
+
|
|
17
|
+
class Node(ABC):
    """Abstract base class for all workflow nodes.

    A node owns input/output ports, is wired to other nodes through edges,
    and implements its behavior in the async ``_run`` hook.  Concrete
    subclasses are automatically added to the module-level node register.
    """

    # Per-class port templates; subclasses override these.
    DEFAULT_INPUT_PORTS: list[Port] = []
    DEFAULT_OUTPUT_PORTS: list[Port] = []

    # Class names listed here are excluded from automatic registration.
    CLASS_NOT_REQUIRED_TO_REGISTER = ['Node']
    AUTO_FILL_INPUT_PORTS = []

    def __init__(
        self,
        name: str,
        context: Context | None = None,
        input_edges: list[Edge] | None = None,
        output_edges: list[Edge] | None = None,
        input_ports: list[Port] | None = None,
        output_ports: list[Port] | None = None,
        query_user: Callable[[str, str], Awaitable[str]] | None = None,
    ) -> None:
        from .workflow_group import WorkflowGroup
        self.name = name
        self.input_edges = [] if input_edges is None else input_edges
        self.output_edges = [] if output_edges is None else output_edges
        # BUGFIX: the previous defaults (input_ports=DEFAULT_INPUT_PORTS,
        # output_ports=DEFAULT_OUTPUT_PORTS) bound the shared class-level
        # lists as mutable default arguments; try_create_extended_* appends
        # to these lists, so extended ports leaked across every
        # default-constructed node.  Fresh lists per instance instead.
        self.input_ports = [] if input_ports is None else input_ports
        self.output_ports = [] if output_ports is None else output_ports
        self.workflow: Workflow = None
        self.query_user = query_user
        self.sub_workflows: WorkflowGroup = None

        # runtime variables
        self.context = context
        self.input_variables: dict[Port, Any] = {}
        self.num_activated_input_edges = 0

    @property
    def default_postgres_database(self) -> PostgresDatabase | None:
        """The workflow's default Postgres database, if configured."""
        return self.database_manager.get_default_postgres_database()

    @property
    def default_mongo_database(self) -> MongoDatabase | None:
        """The workflow's default Mongo database, if configured."""
        return self.database_manager.get_default_mongo_database()

    @property
    def default_redis_database(self) -> RedisDatabase | None:
        """The workflow's default Redis database, if configured."""
        return self.database_manager.get_default_redis_database()

    @property
    def database_manager(self) -> DatabaseManager:
        """Delegates to the owning workflow's database manager."""
        return self.workflow.database_manager

    def backup(self) -> None:
        """Snapshot mutable runtime state so ``reset`` can restore it."""
        # do NOT use deepcopy here
        # self.bak_context = deepcopy(self.context)
        # TODO: what if the value changes after backup?
        self.bak_input_variables = {port: value for port, value in self.input_variables.items()}
        self.bak_num_activated_input_edges = self.num_activated_input_edges

    def reset(self) -> None:
        """Restore the runtime state captured by the last ``backup``."""
        # self.context = deepcopy(self.bak_context)
        self.input_variables = {port: value for port, value in self.bak_input_variables.items()}
        self.num_activated_input_edges = self.bak_num_activated_input_edges

    def __init_subclass__(cls) -> None:
        # Auto-register every concrete subclass by class name.
        if cls.__name__ not in Node.CLASS_NOT_REQUIRED_TO_REGISTER:
            node_register.register(cls.__name__, cls)
        return super().__init_subclass__()

    def _query_user(self, prompt: str) -> Awaitable[str]:
        """Forward *prompt* to the injected query_user callback.

        Returns the awaitable produced by the callback; callers must await
        it.  (The previous annotation claimed a Callable return.)
        """
        return self.query_user(self.name, prompt)

    def variables_to_params(self) -> dict[str, Any]:
        """Convert collected input variables into ``_run`` keyword args.

        Extended-generated ports (name_0, name_1, ...) are folded into a
        single list of (index, value) pairs, sorted by index.
        """
        params = {port.name: self.input_variables[port] for port in self.input_variables.keys() if not port.is_extended_generated}
        for port in self.input_variables.keys():
            if port.is_extended_generated:
                if port.get_extended_name() not in params:
                    params[port.get_extended_name()] = []
                params[port.get_extended_name()].append((port.get_extended_index(), self.input_variables[port]))
                params[port.get_extended_name()].sort()
        return params

    def is_trigger(self) -> bool:
        """True if this node is a Trigger subclass."""
        from .trigger import Trigger
        return isinstance(self, Trigger)

    # TODO: maybe add a function before the run function?

    @abstractmethod
    async def _run(self, **kwargs) -> Union[None, AsyncIterator]:
        """Node behavior hook; implemented by concrete subclasses."""
        ...

    def run(self) -> Union[None, AsyncIterator]:
        """Build kwargs from input variables and invoke ``_run``.

        Control ports whose names start with an uppercase letter (e.g.
        TRIGGER) are dropped first — they gate execution but carry no data.
        """
        for key in list(self.input_variables.keys()):
            if key and key.name[0].isupper():
                del self.input_variables[key]
        params = self.variables_to_params()
        return self._run(**params)

    def get_input_port_by_name(self, name: str) -> Port:
        """Return the input port named *name*, or None."""
        # TODO: add warning if port is extended
        for port in self.input_ports:
            if port.name == name:
                return port
        return None

    def get_output_port_by_name(self, name: str) -> Port:
        """Return the output port named *name*, or None."""
        # TODO: add warning if port is extended
        for port in self.output_ports:
            if port.name == name:
                return port
        return None

    def try_create_extended_input_port(self, name: str) -> None:
        """Materialize input port *name* if it matches an extended
        template's ``<base>_<digits>`` pattern."""
        for port in self.input_ports:
            if port.is_extended and name.startswith(port.name + '_') and name[len(port.name + '_'):].isdigit():
                self.input_ports.append(Port(name=name, type=port.type, node=port.node, port=port.port, value=port.value, default=port.default, is_extended=False, is_extended_generated=True))

    def try_create_extended_output_port(self, name: str) -> None:
        """Materialize output port *name* if it matches an extended
        template's ``<base>_<digits>`` pattern."""
        for port in self.output_ports:
            if port.is_extended and name.startswith(port.name + '_') and name[len(port.name + '_'):].isdigit():
                self.output_ports.append(Port(name=name, type=port.type, node=port.node, port=port.port, value=port.value, default=port.default, is_extended=False, is_extended_generated=True))

    def num_input_ports(self) -> int:
        """Number of concrete (non-template) input ports."""
        return sum(1 for port in self.input_ports if not port.is_extended)

    def is_ready(self) -> bool:
        """True once every concrete input port has been activated."""
        return self.num_activated_input_edges == self.num_input_ports()

    def fill_input_by_name(self, port_name: str, value: Any) -> None:
        """Activate the input port *port_name* with *value*.

        Raises:
            ValueError: If no such input port exists (after attempting to
                materialize it as an extended port).
        """
        self.try_create_extended_input_port(port_name)
        port = self.get_input_port_by_name(port_name)
        if port is None:
            raise ValueError(f'{port_name} is not a valid input port.')
        self.fill_input(port, value)

    def fill_input(self, port: Port, value: Any) -> None:
        """Activate *port* with *value*."""
        port.activate(value)

    def activate_output_edges(self, port: str | Port, data: Any) -> None:
        """Push *data* to every edge starting at *port* (name or object)."""
        if isinstance(port, str):
            port = self.get_output_port_by_name(port)
        for output_edge in self.output_edges:
            if output_edge.start_port == port:
                output_edge.end_port.activate(data)

    # for trigger nodes
    def prepare_output_edges(self, port: str | Port, data: Any) -> None:
        """Stage *data* on downstream ports without activating them."""
        if isinstance(port, str):
            port = self.get_output_port_by_name(port)
        for output_edge in self.output_edges:
            if output_edge.start_port == port:
                output_edge.end_port.prepare(data)

    def trigger_output_edges(self, port: str | Port) -> None:
        """Fire the downstream ports previously staged via prepare."""
        if isinstance(port, str):
            port = self.get_output_port_by_name(port)
        for output_edge in self.output_edges:
            if output_edge.start_port == port:
                output_edge.end_port.trigger()

    # TODO: the result is outputed to the trigger now, maybe we should add a new function to output the result to the workflow
    def output_to_workflow(self, data: Any) -> None:
        """Report *data* as this node's workflow-level output."""
        self.workflow._handle_workflow_output(self.name, data)

    def extended_output_name(self, name: str, index: int) -> str:
        """Build the ``<base>_<index>`` name of an extended output port."""
        return name + '_' + str(index)

    def _clone(self, context: Context) -> Node:
        """Clone this node bound to *context*; wiring is reset by the caller."""
        return node_clone(self, context)

# Global registry of all concrete Node subclasses, keyed by class name.
node_register = Register[Node]()
|
|
@@ -0,0 +1,8 @@
|
|
|
1
|
+
from .output.print_node import PrintNode
|
|
2
|
+
from .output.kafka_output_node import KafkaOutputNode
|
|
3
|
+
from .control.if_node import IfNode
|
|
4
|
+
from .control.switch_node import SwitchNode
|
|
5
|
+
from .llm.query_llm_node import QueryLLMNode
|
|
6
|
+
from .test.if_console_input_node import IfConsoleInputNode
|
|
7
|
+
from .nested.workflow_node import WorkflowNode
|
|
8
|
+
from .test.time_consuming_node import TimeConsumingNode
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from ...node import Node
|
|
3
|
+
from ...port import Port
|
|
4
|
+
|
|
5
|
+
class IfNode(Node):
    """Two-way branch: evaluates a Python expression and activates either
    the 'true' or the 'false' output port.

    NOTE(security): the condition string is executed with ``eval``; it
    must come from trusted workflow configuration, never untrusted input.
    """

    DEFAULT_INPUT_PORTS = [
        Port("TRIGGER", bool),
        Port("condition", str),
    ]

    DEFAULT_OUTPUT_PORTS = [
        Port("true", bool),
        Port("false", bool),
    ]

    def __init__(self, name: str) -> None:
        super().__init__(name)

    async def _run(self, condition: str) -> None:
        # Exactly one branch fires, carrying its boolean as payload.
        if eval(condition):
            self.activate_output_edges('true', True)
        else:
            self.activate_output_edges('false', False)
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from typing import Any
|
|
3
|
+
from ...node import Node
|
|
4
|
+
from ...port import Port
|
|
5
|
+
|
|
6
|
+
class SwitchNode(Node):
    """Multi-way branch: evaluates indexed conditions in order and fires
    the extended 'result' output of the first condition that holds.

    NOTE(security): condition strings are executed with ``eval``; they
    must come from trusted workflow configuration, never untrusted input.
    """

    DEFAULT_INPUT_PORTS = [
        Port("TRIGGER", bool),
        Port("condition", str, is_extended=True),
    ]

    DEFAULT_OUTPUT_PORTS = [
        Port("result", Any, is_extended=True),
    ]

    def __init__(self, name: str) -> None:
        super().__init__(name)

    async def _run(self, condition: list[tuple[int, str]]) -> None:
        for index, expr in condition:
            if not eval(expr):
                continue
            # First matching condition wins; its index (as a string) is
            # emitted on the matching extended result port.
            self.activate_output_edges(self.extended_output_name('result', index), str(index))
            break
|
@@ -0,0 +1,26 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from service_forge.workflow.node import Node
|
|
3
|
+
from service_forge.workflow.port import Port
|
|
4
|
+
from service_forge.llm import chat_stream, Model
|
|
5
|
+
|
|
6
|
+
class ConsoleInputNode(Node):
    """Asks the user a question via the injected query_user callback and
    emits the answer on the 'user_input' output port."""

    DEFAULT_INPUT_PORTS = [
        Port("TRIGGER", bool),
        Port("prompt", str),
    ]

    DEFAULT_OUTPUT_PORTS = [
        Port("user_input", str),
    ]

    AUTO_FILL_INPUT_PORTS = []

    def __init__(
        self,
        name: str,
    ) -> None:
        super().__init__(name)

    async def _run(self, prompt: str) -> None:
        # BUGFIX: Node.query_user is declared Callable[[str, str],
        # Awaitable[str]], so _query_user returns an awaitable.  Without
        # `await`, the raw coroutine object — not the user's answer — was
        # forwarded downstream.
        user_input = await self._query_user(prompt)
        self.activate_output_edges(self.get_output_port_by_name('user_input'), user_input)