service-forge 0.1.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of service-forge might be problematic. Click here for more details.
- service_forge/api/deprecated_websocket_api.py +86 -0
- service_forge/api/deprecated_websocket_manager.py +425 -0
- service_forge/api/http_api.py +148 -0
- service_forge/api/http_api_doc.py +455 -0
- service_forge/api/kafka_api.py +126 -0
- service_forge/api/routers/service/__init__.py +4 -0
- service_forge/api/routers/service/service_router.py +137 -0
- service_forge/api/routers/websocket/websocket_manager.py +83 -0
- service_forge/api/routers/websocket/websocket_router.py +78 -0
- service_forge/api/task_manager.py +141 -0
- service_forge/db/__init__.py +1 -0
- service_forge/db/database.py +240 -0
- service_forge/llm/__init__.py +62 -0
- service_forge/llm/llm.py +56 -0
- service_forge/model/__init__.py +0 -0
- service_forge/model/websocket.py +13 -0
- service_forge/proto/foo_input.py +5 -0
- service_forge/service.py +288 -0
- service_forge/service_config.py +158 -0
- service_forge/sft/cli.py +91 -0
- service_forge/sft/cmd/config_command.py +67 -0
- service_forge/sft/cmd/deploy_service.py +123 -0
- service_forge/sft/cmd/list_tars.py +41 -0
- service_forge/sft/cmd/service_command.py +149 -0
- service_forge/sft/cmd/upload_service.py +36 -0
- service_forge/sft/config/injector.py +119 -0
- service_forge/sft/config/injector_default_files.py +131 -0
- service_forge/sft/config/sf_metadata.py +30 -0
- service_forge/sft/config/sft_config.py +153 -0
- service_forge/sft/file/__init__.py +0 -0
- service_forge/sft/file/ignore_pattern.py +80 -0
- service_forge/sft/file/sft_file_manager.py +107 -0
- service_forge/sft/kubernetes/kubernetes_manager.py +257 -0
- service_forge/sft/util/assert_util.py +25 -0
- service_forge/sft/util/logger.py +16 -0
- service_forge/sft/util/name_util.py +8 -0
- service_forge/sft/util/yaml_utils.py +57 -0
- service_forge/utils/__init__.py +0 -0
- service_forge/utils/default_type_converter.py +12 -0
- service_forge/utils/register.py +39 -0
- service_forge/utils/type_converter.py +99 -0
- service_forge/utils/workflow_clone.py +124 -0
- service_forge/workflow/__init__.py +1 -0
- service_forge/workflow/context.py +14 -0
- service_forge/workflow/edge.py +24 -0
- service_forge/workflow/node.py +184 -0
- service_forge/workflow/nodes/__init__.py +8 -0
- service_forge/workflow/nodes/control/if_node.py +29 -0
- service_forge/workflow/nodes/control/switch_node.py +28 -0
- service_forge/workflow/nodes/input/console_input_node.py +26 -0
- service_forge/workflow/nodes/llm/query_llm_node.py +41 -0
- service_forge/workflow/nodes/nested/workflow_node.py +28 -0
- service_forge/workflow/nodes/output/kafka_output_node.py +27 -0
- service_forge/workflow/nodes/output/print_node.py +29 -0
- service_forge/workflow/nodes/test/if_console_input_node.py +33 -0
- service_forge/workflow/nodes/test/time_consuming_node.py +62 -0
- service_forge/workflow/port.py +89 -0
- service_forge/workflow/trigger.py +24 -0
- service_forge/workflow/triggers/__init__.py +6 -0
- service_forge/workflow/triggers/a2a_api_trigger.py +255 -0
- service_forge/workflow/triggers/fast_api_trigger.py +169 -0
- service_forge/workflow/triggers/kafka_api_trigger.py +44 -0
- service_forge/workflow/triggers/once_trigger.py +20 -0
- service_forge/workflow/triggers/period_trigger.py +26 -0
- service_forge/workflow/triggers/websocket_api_trigger.py +184 -0
- service_forge/workflow/workflow.py +210 -0
- service_forge/workflow/workflow_callback.py +141 -0
- service_forge/workflow/workflow_event.py +15 -0
- service_forge/workflow/workflow_factory.py +246 -0
- service_forge/workflow/workflow_group.py +27 -0
- service_forge/workflow/workflow_type.py +52 -0
- service_forge-0.1.11.dist-info/METADATA +98 -0
- service_forge-0.1.11.dist-info/RECORD +75 -0
- service_forge-0.1.11.dist-info/WHEEL +4 -0
- service_forge-0.1.11.dist-info/entry_points.txt +2 -0
|
@@ -0,0 +1,41 @@
|
|
|
1
|
+
from __future__ import annotations

import os
from typing import AsyncIterator

from service_forge.workflow.node import Node
from service_forge.workflow.port import Port
from service_forge.llm import chat_stream, Model
|
|
7
|
+
|
|
8
|
+
class QueryLLMNode(Node):
    """Workflow node that streams an LLM response for a prompt.

    ``prompt`` and ``system_prompt`` may each be literal text or a path to a
    regular file; in the latter case the file's contents are used as the text.
    """

    DEFAULT_INPUT_PORTS = [
        Port("prompt", str),
        Port("system_prompt", str),
        Port("temperature", float),
        Port("TRIGGER", bool),
    ]

    DEFAULT_OUTPUT_PORTS = [
        Port("response", str),
    ]

    # NOTE(review): presumably the base Node fills TRIGGER with True on its own
    # so this node runs without an upstream trigger edge — confirm in node.py.
    AUTO_FILL_INPUT_PORTS = [('TRIGGER', True)]

    def __init__(
        self,
        name: str,
    ) -> None:
        super().__init__(
            name,
        )

    @staticmethod
    def _resolve_text(value: str) -> str:
        """Return the contents of ``value`` if it names a regular file, else ``value`` itself."""
        # ``os.path.isfile`` (not ``exists``) so a directory path is treated as
        # literal text instead of crashing in ``open``; the isinstance guard
        # keeps a missing (None) value from raising TypeError here.
        if isinstance(value, str) and os.path.isfile(value):
            with open(value, "r", encoding="utf-8") as f:
                return f.read()
        return value

    async def _run(self, prompt: str, system_prompt: str, temperature: float) -> AsyncIterator[str]:
        """Stream the model's answer to ``prompt``.

        Args:
            prompt: user prompt text, or a path to a file containing it.
            system_prompt: system prompt text, or a path to a file containing it.
            temperature: sampling temperature forwarded to ``chat_stream``.

        Yields:
            Every chunk produced by ``chat_stream`` for ``Model.DEEPSEEK_V3_250324``.
        """
        system_prompt = self._resolve_text(system_prompt)
        prompt = self._resolve_text(prompt)

        print(f"prompt: {prompt} temperature: {temperature}")
        # NOTE(review): ``chat_stream`` is iterated synchronously; if it does
        # blocking network I/O between chunks it stalls the event loop — confirm.
        response = chat_stream(prompt, system_prompt, Model.DEEPSEEK_V3_250324, temperature)
        for chunk in response:
            yield chunk
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
from typing import TYPE_CHECKING
|
|
4
|
+
from service_forge.workflow.node import Node
|
|
5
|
+
from service_forge.workflow.port import Port
|
|
6
|
+
|
|
7
|
+
class WorkflowNode(Node):
    """Node that runs a nested workflow after forwarding its sub-workflow inputs."""

    # Imported inside the class body, presumably to sidestep a circular import
    # with the workflow module; this also exposes ``Workflow`` as a class attribute.
    from service_forge.workflow.workflow import Workflow

    DEFAULT_INPUT_PORTS = [
        Port("workflow", Workflow),
    ]

    DEFAULT_OUTPUT_PORTS = []

    def __init__(self, name: str) -> None:
        super().__init__(name)

    async def _run(self, workflow: Workflow, **kwargs) -> None:
        """Copy each sub-workflow input port's value into the port it wraps, then run ``workflow``."""
        forwarded = (p for p in self.input_ports if p.is_sub_workflow_input_port())
        for outer in forwarded:
            inner = outer.port
            inner.node.fill_input(inner, outer.value)
        await workflow.run()
|
|
@@ -0,0 +1,27 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from service_forge.workflow.node import Node
|
|
3
|
+
from service_forge.workflow.port import Port
|
|
4
|
+
from service_forge.api.kafka_api import KafkaApp
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
class KafkaOutputNode(Node):
    """Sink node that hands ``data`` to a ``KafkaApp`` for a given topic."""

    DEFAULT_INPUT_PORTS = [
        Port("app", KafkaApp),
        Port("topic", str),
        Port("data_type", type),
        Port("data", Any),
    ]

    DEFAULT_OUTPUT_PORTS = []

    def __init__(self, name: str) -> None:
        super().__init__(name)

    async def _run(self, app: KafkaApp, topic: str, data_type: type, data: Any) -> None:
        """Send ``data`` (declared as ``data_type``) to ``topic`` via ``app.send_message``."""
        send = app.send_message
        await send(topic, data_type, data)
|
|
@@ -0,0 +1,29 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
import asyncio
|
|
3
|
+
from typing import AsyncIterator
|
|
4
|
+
from service_forge.workflow.node import Node
|
|
5
|
+
from service_forge.workflow.port import Port
|
|
6
|
+
|
|
7
|
+
class PrintNode(Node):
    """Node that streams its message back one character at a time."""

    DEFAULT_INPUT_PORTS = [
        Port("TRIGGER", bool),
        Port("message", str),
    ]

    DEFAULT_OUTPUT_PORTS = []

    AUTO_FILL_INPUT_PORTS = [('TRIGGER', True)]

    def __init__(self, name: str) -> None:
        super().__init__(name)

    async def _run(self, message: str) -> AsyncIterator[str]:
        """Yield each character of ``str(message)``, sleeping 0.1 s before every one."""
        text = str(message)
        for ch in text:
            await asyncio.sleep(0.1)
            yield ch
|
|
@@ -0,0 +1,33 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from typing import AsyncIterator
|
|
3
|
+
from service_forge.workflow.node import Node
|
|
4
|
+
from service_forge.workflow.port import Port
|
|
5
|
+
|
|
6
|
+
class IfConsoleInputNode(Node):
    """Node that asks the user a yes/no question and fires the matching output."""

    DEFAULT_INPUT_PORTS = [
        Port("TRIGGER", bool),
        Port("condition", str),
    ]

    DEFAULT_OUTPUT_PORTS = [
        Port("true", bool),
        Port("false", bool),
    ]

    # Accepted answers, compared after trimming surrounding whitespace and
    # lower-casing the user's reply.
    _YES_ANSWERS = frozenset({'y', 'yes'})
    _NO_ANSWERS = frozenset({'n', 'no'})

    def __init__(
        self,
        name: str,
    ) -> None:
        super().__init__(
            name,
        )

    async def _run(self, condition: str) -> None:
        """Prompt with ``condition`` until the user answers yes or no.

        A yes answer activates the ``true`` output with ``True``; a no answer
        activates the ``false`` output with ``False``. Any other reply re-prompts.
        """
        while True:
            # Normalize once; stray whitespace (e.g. "yes ") must not force a re-prompt.
            answer = (await self._query_user(condition)).strip().lower()
            if answer in self._YES_ANSWERS:
                self.activate_output_edges(self.get_output_port_by_name('true'), True)
                return
            if answer in self._NO_ANSWERS:
                self.activate_output_edges(self.get_output_port_by_name('false'), False)
                return
|
|
@@ -0,0 +1,62 @@
|
|
|
1
|
+
|
|
2
|
+
from __future__ import annotations
|
|
3
|
+
import asyncio
|
|
4
|
+
import uuid
|
|
5
|
+
import json
|
|
6
|
+
from typing import Any
|
|
7
|
+
from ...node import Node
|
|
8
|
+
from ...port import Port
|
|
9
|
+
from ....api.routers.websocket.websocket_manager import websocket_manager
|
|
10
|
+
|
|
11
|
+
# It's deprecated, just for testing
|
|
12
|
+
class TimeConsumingNode(Node):
    """Simulated long-running node that periodically sends progress updates."""

    DEFAULT_INPUT_PORTS = [
        Port("input", Any)
    ]

    DEFAULT_OUTPUT_PORTS = [
        Port("output", Any)
    ]

    def __init__(self, name: str, duration: float = 2.0):
        super().__init__(name)
        self.duration = duration  # total run time, in seconds
        self.progress = 0.0  # fraction complete, from 0.0 to 1.0
        # Last task id received; kept and reused by later runs that pass none.
        self.task_id = None

    async def _run(self, input: Any = None, task_id: uuid.UUID | None = None) -> str:
        """Spend ``duration`` seconds in ten steps, reporting progress after each.

        Args:
            input: not read by this node.
            task_id: when given, stored on the node and used as the target of
                ``websocket_manager.send_progress`` updates.

        Returns:
            A completion message, which is also prepared on and activated through
            the ``output`` port.
        """
        if task_id is not None:
            self.task_id = task_id

        steps = 10
        message = f"Completed {self.name} after {self.duration} seconds"

        for step in range(steps + 1):
            self.progress = step / steps
            if self.task_id:
                await websocket_manager.send_progress(self.task_id, self.name, self.progress)
            # Pause between reports, but not after the final 100% report.
            if step != steps:
                await asyncio.sleep(self.duration / steps)

        port = self.get_output_port_by_name('output')
        port.prepare(message)
        self.activate_output_edges(port, message)
        return message
|
|
@@ -0,0 +1,89 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from typing import TYPE_CHECKING, Any
|
|
3
|
+
from ..utils.workflow_clone import port_clone
|
|
4
|
+
|
|
5
|
+
if TYPE_CHECKING:
|
|
6
|
+
from .node import Node
|
|
7
|
+
|
|
8
|
+
class Port:
    """A named, typed value slot on a workflow node.

    A port that wraps another port (``self.port`` is set) is treated as a
    sub-workflow input port.
    """

    def __init__(
        self,
        name: str,
        type: type,
        node: Node | None = None,
        port: Port | None = None,
        value: Any = None,
        default: Any = None,
        is_extended: bool = False,
        is_extended_generated: bool = False,
    ) -> None:
        self.name = name
        self.type = type
        self.node = node
        self.port = port  # wrapped port, for sub-workflow input ports
        self.value = value
        # not used yet
        self.default = default
        self.is_prepared = False
        self.is_extended = is_extended
        self.is_extended_generated = is_extended_generated

    def is_sub_workflow_input_port(self) -> bool:
        """Return True when this port wraps another port."""
        return self.port is not None

    def prepare(self, data: Any) -> None:
        """Convert ``data`` to this port's type, store it as the value, and mark the port prepared."""
        from ..utils.default_type_converter import type_converter
        self.value = type_converter.convert(data, self.type, node=self.node)
        self.is_prepared = True

    def trigger(self) -> None:
        """Register this port's value as an input of its node.

        Does nothing when the port has no node or is already registered.
        Otherwise stores the value in ``node.input_variables``, increments the
        node's activated-edge counter, and appends the node to its workflow's
        ``ready_nodes`` once ``node.is_ready()`` reports True.
        """
        node = self.node
        if node is None:
            return
        if self in node.input_variables:
            # Already registered; do not count this port twice.
            return
        node.input_variables[self] = self.value
        node.num_activated_input_edges += 1
        if node.is_ready():
            node.workflow.ready_nodes.append(node)

    def activate(self, data: Any) -> None:
        """Prepare ``data`` on this port and then register it with the node."""
        self.prepare(data)
        self.trigger()

    def get_extended_name(self) -> str:
        """Return an extended-generated port's name without its trailing ``_<suffix>``.

        Raises:
            ValueError: if the port is not extended-generated.
        """
        if not self.is_extended_generated:
            raise ValueError(f"Port {self.name} is not extended generated.")
        return self.name.rpartition('_')[0]

    def get_extended_index(self) -> int:
        """Return the integer suffix after the last ``_`` of an extended-generated port's name.

        Raises:
            ValueError: if the port is not extended-generated, or the suffix is not an integer.
        """
        if not self.is_extended_generated:
            raise ValueError(f"Port {self.name} is not extended generated.")
        return int(self.name.rpartition('_')[2])

    def _clone(self, node_map: dict[Node, Node]) -> Port:
        """Return a copy of this port remapped through ``node_map`` (delegates to ``port_clone``)."""
        return port_clone(self, node_map)
|
|
66
|
+
|
|
67
|
+
# node port
|
|
68
|
+
def create_port(name: str, type: type, node: Node = None, value: Any = None, port: Port = None) -> Port:
    """Build a regular node port.

    Unlike ``Port.__init__``, this helper takes ``value`` before ``port``.
    """
    return Port(name=name, type=type, node=node, port=port, value=value)
|
|
70
|
+
|
|
71
|
+
# workflow input port
|
|
72
|
+
def create_workflow_input_port(name: str, port: Port, value: Any = None) -> Port:
    """Build a workflow-level input port that wraps ``port``.

    The new port reuses ``port``'s type and node. A ``None`` ``value`` falls
    back to ``port.value``.
    """
    initial = port.value if value is None else value
    return Port(name, port.type, port.node, port, initial)
|
|
76
|
+
|
|
77
|
+
# sub workflow input port
|
|
78
|
+
# node is the node that the sub workflow is running on
|
|
79
|
+
def create_sub_workflow_input_port(name: str, node: Node, port: Port, value: Any = None) -> Port:
    """Build an input port on ``node`` that wraps a port of the sub-workflow.

    The new port takes ``port``'s type. A ``None`` ``value`` falls back to
    ``port.value``.
    """
    initial = port.value if value is None else value
    return Port(name, port.type, node, port, initial)
|
|
83
|
+
|
|
84
|
+
PORT_DELIMITER = '|'


def parse_port_name(port_name: str) -> tuple[str, str]:
    """Split a qualified port reference ``"<left>|<right>"`` into its two parts.

    (Presumably ``<left>`` names a node and ``<right>`` one of its ports —
    confirm against callers.)

    Returns:
        A ``(left, right)`` tuple, as the annotation promises.

    Raises:
        ValueError: if ``port_name`` does not contain exactly one ``PORT_DELIMITER``.
    """
    parts = port_name.split(PORT_DELIMITER)
    # split() always yields at least one part, so this single length check
    # also rejects names with no delimiter at all.
    if len(parts) != 2:
        raise ValueError(f"Invalid port name: {port_name}")
    left, right = parts
    return left, right
|
|
@@ -0,0 +1,24 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
import asyncio
|
|
3
|
+
from typing import AsyncIterator
|
|
4
|
+
from abc import ABC, abstractmethod
|
|
5
|
+
import uuid
|
|
6
|
+
from .node import Node
|
|
7
|
+
from .workflow_event import WorkflowResult
|
|
8
|
+
|
|
9
|
+
class Trigger(Node, ABC):
    """Abstract base for workflow entry nodes driven by external events.

    Subclasses implement ``_run`` as an async generator. Pending requests can
    be placed on ``trigger_queue``; per-task results and streamed output are
    delivered through ``result_queues`` and ``stream_queues``.
    """

    def __init__(self, name: str):
        super().__init__(name)
        # Pending trigger requests waiting to start a workflow run.
        self.trigger_queue: asyncio.Queue = asyncio.Queue()
        # for workflow result, keyed by task id
        self.result_queues: dict[uuid.UUID, asyncio.Queue[WorkflowResult]] = {}
        # for node stream output, keyed by task id
        self.stream_queues: dict[uuid.UUID, asyncio.Queue[WorkflowResult]] = {}

    @abstractmethod
    async def _run(self) -> AsyncIterator[bool]:
        """Drive the trigger; subclasses yield once per started run.

        NOTE(review): at least one subclass yields the value returned by
        ``trigger()`` (a task id), not a bool — confirm the intended element type.
        """
        ...

    def trigger(self, task_id: uuid.UUID) -> uuid.UUID:
        """Prepare the ``trigger`` output port with ``True`` and return ``task_id`` unchanged.

        Returns:
            The ``task_id`` argument, so callers can yield it directly from ``_run``.
        """
        self.prepare_output_edges(self.get_output_port_by_name('trigger'), True)
        return task_id
|
|
@@ -0,0 +1,6 @@
|
|
|
1
|
+
from .once_trigger import OnceTrigger
|
|
2
|
+
from .period_trigger import PeriodTrigger
|
|
3
|
+
from .fast_api_trigger import FastAPITrigger
|
|
4
|
+
from .kafka_api_trigger import KafkaAPITrigger
|
|
5
|
+
from .websocket_api_trigger import WebSocketAPITrigger
|
|
6
|
+
from .a2a_api_trigger import A2AAPITrigger
|
|
@@ -0,0 +1,255 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from loguru import logger
|
|
3
|
+
from service_forge.workflow.trigger import Trigger
|
|
4
|
+
from typing import AsyncIterator, Any
|
|
5
|
+
from service_forge.workflow.port import Port
|
|
6
|
+
from google.protobuf.message import Message
|
|
7
|
+
from google.protobuf.json_format import MessageToJson
|
|
8
|
+
from fastapi import FastAPI
|
|
9
|
+
from a2a.types import (
|
|
10
|
+
AgentCapabilities,
|
|
11
|
+
AgentCard,
|
|
12
|
+
AgentSkill,
|
|
13
|
+
)
|
|
14
|
+
from a2a.server.apps import A2AStarletteApplication
|
|
15
|
+
from a2a.server.request_handlers import DefaultRequestHandler
|
|
16
|
+
from a2a.server.tasks import InMemoryTaskStore
|
|
17
|
+
from a2a.server.agent_execution import AgentExecutor, RequestContext
|
|
18
|
+
from a2a.server.events import EventQueue
|
|
19
|
+
from a2a.utils import new_agent_text_message
|
|
20
|
+
from a2a.utils.constants import DEFAULT_RPC_URL, EXTENDED_AGENT_CARD_PATH, AGENT_CARD_WELL_KNOWN_PATH
|
|
21
|
+
|
|
22
|
+
import json
|
|
23
|
+
import uuid
|
|
24
|
+
import asyncio
|
|
25
|
+
from service_forge.workflow.workflow_event import WorkflowResult
|
|
26
|
+
|
|
27
|
+
class A2AAgentExecutor(AgentExecutor):
    """A2A ``AgentExecutor`` that forwards each request to an ``A2AAPITrigger``.

    Every request gets a fresh task id, is queued on the trigger's
    ``trigger_queue``, and is answered from the trigger's per-task queues.
    """

    # TODO: support stream output. While False, only the single final result
    # from ``result_queues`` is sent back (the streaming path is untested).
    stream_output: bool = False

    def __init__(self, trigger: A2AAPITrigger):
        self.trigger = trigger

    @staticmethod
    def serialize_result(result: Any) -> str:
        """Serialize ``result`` to JSON text (protobuf messages via ``MessageToJson``)."""
        if isinstance(result, Message):
            return MessageToJson(
                result,
                preserving_proto_field_name=True
            )
        return json.dumps(result)

    async def send_event(self, event_queue: EventQueue, item: WorkflowResult) -> None:
        """Enqueue exactly one text message describing ``item``.

        - Error items produce ``{'event': 'error', 'detail': str(result)}``.
        - End items produce ``{'event': 'end', 'detail': <serialized result>}``.
        - Anything else produces ``{'event': 'data', 'data': <serialized result>}``.
        """
        if item.is_error:
            # str() rather than serialize_result: the error payload may not be
            # JSON-serializable (presumably an exception object).
            payload = {
                'event': 'error',
                'detail': str(item.result)
            }
        elif item.is_end:
            payload = {
                'event': 'end',
                'detail': self.serialize_result(item.result)
            }
        else:
            payload = {
                'event': 'data',
                'data': self.serialize_result(item.result)
            }
        await event_queue.enqueue_event(new_agent_text_message(json.dumps(payload)))

    async def execute(
        self,
        context: RequestContext,
        event_queue: EventQueue,
    ) -> None:
        """Run one A2A request through the trigger and send the outcome to ``event_queue``."""
        task_id = uuid.uuid4()
        self.trigger.result_queues[task_id] = asyncio.Queue()
        if self.stream_output:
            self.trigger.stream_queues[task_id] = asyncio.Queue()

        try:
            self.trigger.trigger_queue.put_nowait({
                'id': task_id,
                'context': context,
            })

            if self.stream_output:
                while True:
                    item = await self.trigger.stream_queues[task_id].get()
                    await self.send_event(event_queue, item)
                    if item.is_error or item.is_end:
                        break
            else:
                result = await self.trigger.result_queues[task_id].get()
                await self.send_event(event_queue, result)
        finally:
            # Always drop this task's queues, even when sending fails or the
            # request is cancelled, so they do not pile up on the trigger.
            # NOTE(review): a producer that later puts into result_queues for a
            # cancelled task must tolerate the missing key — confirm.
            self.trigger.result_queues.pop(task_id, None)
            self.trigger.stream_queues.pop(task_id, None)

    async def cancel(
        self, context: RequestContext, event_queue: EventQueue
    ) -> None:
        """Cancellation is not supported; always raises."""
        raise Exception('cancel not supported')
|
|
97
|
+
|
|
98
|
+
|
|
99
|
+
class A2AAPITrigger(Trigger):
    """Trigger that exposes a workflow as an A2A agent on a FastAPI app.

    On its first run it builds an ``AgentCard`` from the skill/agent input
    ports and mounts the A2A routes under ``/a2a<path>``. Each queued A2A
    request then starts one workflow run, with the request's ``RequestContext``
    sent through the ``context`` output port.
    """

    DEFAULT_INPUT_PORTS = [
        Port("app", FastAPI),
        Port("path", str),
        Port("skill_id", str, is_extended=True),
        Port("skill_name", str, is_extended=True),
        Port("skill_description", str, is_extended=True),
        Port("skill_tags", list[str], is_extended=True),
        Port("skill_examples", list[str], is_extended=True),
        Port("agent_name", str),
        Port("agent_url", str),
        Port("agent_description", str),
        Port("agent_version", str),
    ]

    DEFAULT_OUTPUT_PORTS = [
        Port("trigger", bool),
        Port("context", RequestContext),
    ]

    def __init__(self, name: str):
        super().__init__(name)
        self.events = {}
        self.is_setup_handler = False  # True once the A2A routes are mounted
        self.agent_card: AgentCard | None = None

    @staticmethod
    def serialize_result(result: Any):
        """Return protobuf messages as JSON text; return any other value unchanged."""
        if isinstance(result, Message):
            return MessageToJson(
                result,
                preserving_proto_field_name=True
            )
        return result

    @staticmethod
    def _build_skills(
        skill_id: list[tuple[int, str]],
        skill_name: list[tuple[int, str]],
        skill_description: list[tuple[int, str]],
        skill_tags: list[tuple[int, list[str]]],
        skill_examples: list[tuple[int, list[str]]],
    ) -> list[AgentSkill]:
        """Group per-field ``(index, value)`` pairs into ``AgentSkill`` objects ordered by index.

        Pairs that share an index describe the same skill. A field with no
        pair for some index keeps an empty default. Grouping by index (rather
        than indexing a preallocated list) also accepts non-contiguous indices.
        """
        configs: dict[int, dict[str, Any]] = {}
        for field, entries in (
            ('id', skill_id),
            ('name', skill_name),
            ('description', skill_description),
            ('tags', skill_tags),
            ('examples', skill_examples),
        ):
            for index, value in entries:
                config = configs.setdefault(index, {
                    'id': '',
                    'name': '',
                    'description': '',
                    'tags': [],
                    'examples': [],
                })
                config[field] = value
        return [AgentSkill(**configs[index]) for index in sorted(configs)]

    def _setup_handler(
        self,
        app: FastAPI,
        path: str,
        skill_id: list[tuple[int, str]],
        skill_name: list[tuple[int, str]],
        skill_description: list[tuple[int, str]],
        skill_tags: list[tuple[int, list[str]]],
        skill_examples: list[tuple[int, list[str]]],
        agent_name: str,
        agent_url: str,
        agent_description: str,
        agent_version: str,
    ) -> None:
        """Build the agent card, store it on ``self.agent_card``, and mount the A2A routes on ``app``.

        Raises:
            Exception: re-raised (after logging) if the A2A app cannot be built or mounted.
        """
        skills = self._build_skills(
            skill_id,
            skill_name,
            skill_description,
            skill_tags,
            skill_examples,
        )

        agent_card = AgentCard(
            name=agent_name,
            description=agent_description,
            url=agent_url,
            version=agent_version,
            default_input_modes=['text'],
            default_output_modes=['text'],
            capabilities=AgentCapabilities(streaming=True),
            skills=skills,
            supports_authenticated_extended_card=False,
        )

        self.agent_card = agent_card

        request_handler = DefaultRequestHandler(
            agent_executor=A2AAgentExecutor(self),
            task_store=InMemoryTaskStore(),
        )

        try:
            server = A2AStarletteApplication(
                agent_card=agent_card,
                http_handler=request_handler,
            )

            server.add_routes_to_app(
                app,
                agent_card_url="/a2a" + path + AGENT_CARD_WELL_KNOWN_PATH,
                rpc_url="/a2a" + path + DEFAULT_RPC_URL,
                extended_agent_card_url="/a2a" + path + EXTENDED_AGENT_CARD_PATH,
            )

        except Exception as e:
            logger.error(f"Error adding A2A routes: {e}")
            raise

    async def _run(
        self,
        app: FastAPI,
        path: str,
        skill_id: list[tuple[int, str]],
        skill_name: list[tuple[int, str]],
        skill_description: list[tuple[int, str]],
        skill_tags: list[tuple[int, list[str]]],
        skill_examples: list[tuple[int, list[str]]],
        agent_name: str,
        agent_url: str,
        agent_description: str,
        agent_version: str,
    ) -> AsyncIterator[bool]:
        """Mount the A2A routes once, then yield one task id per queued A2A request.

        Raises:
            ValueError: if the five skill port lists differ in length.
        """
        lengths = {
            len(skill_id),
            len(skill_name),
            len(skill_description),
            len(skill_tags),
            len(skill_examples),
        }
        if len(lengths) != 1:
            raise ValueError("skill_id, skill_name, skill_description, skill_tags, skill_examples must have the same length")

        if not self.is_setup_handler:
            self._setup_handler(
                app,
                path,
                skill_id,
                skill_name,
                skill_description,
                skill_tags,
                skill_examples,
                agent_name,
                agent_url,
                agent_description,
                agent_version,
            )
            self.is_setup_handler = True

        logger.info(f"A2A Trigger {self.name} is running")

        while True:
            try:
                request = await self.trigger_queue.get()
                # Pass the Port object, as every other prepare/activate call site does.
                self.prepare_output_edges(self.get_output_port_by_name('context'), request['context'])
                yield self.trigger(request['id'])
            except Exception as e:
                # Keep serving later requests, but log with the traceback.
                # NOTE(review): the executor awaiting this request's result queue
                # receives nothing on this path and may wait forever — confirm.
                logger.exception(f"Error in A2AAPITrigger._run: {e}")
|
|
255
|
+
|