smarta2a 0.2.1__tar.gz → 0.2.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {smarta2a-0.2.1 → smarta2a-0.2.3}/PKG-INFO +13 -7
- {smarta2a-0.2.1 → smarta2a-0.2.3}/README.md +1 -1
- {smarta2a-0.2.1 → smarta2a-0.2.3}/pyproject.toml +12 -6
- {smarta2a-0.2.1 → smarta2a-0.2.3}/requirements.txt +2 -0
- smarta2a-0.2.3/smarta2a/__init__.py +10 -0
- smarta2a-0.2.3/smarta2a/agent/a2a_agent.py +38 -0
- smarta2a-0.2.3/smarta2a/agent/a2a_mcp_server.py +37 -0
- smarta2a-0.2.3/smarta2a/archive/mcp_client.py +86 -0
- smarta2a-0.2.3/smarta2a/client/__init__.py +0 -0
- smarta2a-0.2.3/smarta2a/client/a2a_client.py +267 -0
- smarta2a-0.2.3/smarta2a/client/smart_mcp_client.py +60 -0
- smarta2a-0.2.3/smarta2a/client/tools_manager.py +58 -0
- smarta2a-0.2.3/smarta2a/history_update_strategies/__init__.py +8 -0
- smarta2a-0.2.3/smarta2a/history_update_strategies/append_strategy.py +10 -0
- smarta2a-0.2.3/smarta2a/history_update_strategies/history_update_strategy.py +15 -0
- smarta2a-0.2.3/smarta2a/model_providers/__init__.py +5 -0
- smarta2a-0.2.3/smarta2a/model_providers/base_llm_provider.py +15 -0
- smarta2a-0.2.3/smarta2a/model_providers/openai_provider.py +281 -0
- smarta2a-0.2.3/smarta2a/server/__init__.py +3 -0
- smarta2a-0.2.3/smarta2a/server/handler_registry.py +23 -0
- {smarta2a-0.2.1/smarta2a → smarta2a-0.2.3/smarta2a/server}/server.py +224 -254
- smarta2a-0.2.3/smarta2a/server/state_manager.py +34 -0
- smarta2a-0.2.3/smarta2a/server/subscription_service.py +109 -0
- smarta2a-0.2.3/smarta2a/server/task_service.py +155 -0
- smarta2a-0.2.3/smarta2a/state_stores/__init__.py +8 -0
- smarta2a-0.2.3/smarta2a/state_stores/base_state_store.py +20 -0
- smarta2a-0.2.3/smarta2a/state_stores/inmemory_state_store.py +21 -0
- smarta2a-0.2.3/smarta2a/utils/__init__.py +32 -0
- smarta2a-0.2.3/smarta2a/utils/prompt_helpers.py +38 -0
- smarta2a-0.2.3/smarta2a/utils/task_builder.py +153 -0
- smarta2a-0.2.3/smarta2a/utils/task_request_builder.py +114 -0
- {smarta2a-0.2.1/smarta2a → smarta2a-0.2.3/smarta2a/utils}/types.py +62 -2
- smarta2a-0.2.3/tests/__init__.py +3 -0
- {smarta2a-0.2.1 → smarta2a-0.2.3}/tests/test_server.py +4 -10
- smarta2a-0.2.3/tests/test_server_history.py +189 -0
- smarta2a-0.2.3/tests/test_task_request_builder.py +130 -0
- smarta2a-0.2.1/smarta2a/__init__.py +0 -10
- {smarta2a-0.2.1 → smarta2a-0.2.3}/.gitignore +0 -0
- {smarta2a-0.2.1 → smarta2a-0.2.3}/LICENSE +0 -0
@@ -1,7 +1,7 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: smarta2a
|
3
|
-
Version: 0.2.
|
4
|
-
Summary: A Python
|
3
|
+
Version: 0.2.3
|
4
|
+
Summary: A simple Python framework (built on top of FastAPI) for creating Agents following Google's Agent2Agent protocol
|
5
5
|
Project-URL: Homepage, https://github.com/siddharthsma/smarta2a
|
6
6
|
Project-URL: Bug Tracker, https://github.com/siddharthsma/smarta2a/issues
|
7
7
|
Author-email: Siddharth Ambegaonkar <siddharthsma@gmail.com>
|
@@ -10,10 +10,16 @@ Classifier: License :: OSI Approved :: MIT License
|
|
10
10
|
Classifier: Operating System :: OS Independent
|
11
11
|
Classifier: Programming Language :: Python :: 3
|
12
12
|
Requires-Python: >=3.8
|
13
|
-
Requires-Dist:
|
14
|
-
Requires-Dist:
|
15
|
-
Requires-Dist:
|
16
|
-
Requires-Dist:
|
13
|
+
Requires-Dist: anyio>=4.9.0
|
14
|
+
Requires-Dist: fastapi>=0.115.12
|
15
|
+
Requires-Dist: httpx>=0.28.1
|
16
|
+
Requires-Dist: mcp>=0.1.0
|
17
|
+
Requires-Dist: openai>=1.0.0
|
18
|
+
Requires-Dist: pydantic>=2.11.3
|
19
|
+
Requires-Dist: sse-starlette>=2.2.1
|
20
|
+
Requires-Dist: starlette>=0.46.2
|
21
|
+
Requires-Dist: typing-extensions>=4.13.2
|
22
|
+
Requires-Dist: uvicorn>=0.34.1
|
17
23
|
Description-Content-Type: text/markdown
|
18
24
|
|
19
25
|
# SmartA2A
|
@@ -45,7 +51,7 @@ pip install smarta2a
|
|
45
51
|
## Simple Echo Server Implementation
|
46
52
|
|
47
53
|
```python
|
48
|
-
from smarta2a import SmartA2A
|
54
|
+
from smarta2a.server import SmartA2A
|
49
55
|
|
50
56
|
app = SmartA2A("EchoServer")
|
51
57
|
|
@@ -4,11 +4,11 @@ build-backend = "hatchling.build"
|
|
4
4
|
|
5
5
|
[project]
|
6
6
|
name = "smarta2a"
|
7
|
-
version = "0.2.
|
7
|
+
version = "0.2.3"
|
8
8
|
authors = [
|
9
9
|
{ name = "Siddharth Ambegaonkar", email = "siddharthsma@gmail.com" },
|
10
10
|
]
|
11
|
-
description = "A Python
|
11
|
+
description = "A simple Python framework (built on top of FastAPI) for creating Agents following Google's Agent2Agent protocol"
|
12
12
|
readme = "README.md"
|
13
13
|
requires-python = ">=3.8"
|
14
14
|
classifiers = [
|
@@ -17,10 +17,16 @@ classifiers = [
|
|
17
17
|
"Operating System :: OS Independent",
|
18
18
|
]
|
19
19
|
dependencies = [
|
20
|
-
"fastapi",
|
21
|
-
"pydantic",
|
22
|
-
"uvicorn",
|
23
|
-
"sse-starlette",
|
20
|
+
"fastapi>=0.115.12",
|
21
|
+
"pydantic>=2.11.3",
|
22
|
+
"uvicorn>=0.34.1",
|
23
|
+
"sse-starlette>=2.2.1",
|
24
|
+
"openai>=1.0.0",
|
25
|
+
"mcp>=0.1.0",
|
26
|
+
"httpx>=0.28.1",
|
27
|
+
"starlette>=0.46.2",
|
28
|
+
"typing-extensions>=4.13.2",
|
29
|
+
"anyio>=4.9.0"
|
24
30
|
]
|
25
31
|
|
26
32
|
[project.urls]
|
@@ -0,0 +1,38 @@
|
|
1
|
+
# Library imports
|
2
|
+
|
3
|
+
|
4
|
+
# Local imports
|
5
|
+
from smarta2a.server import SmartA2A
|
6
|
+
from smarta2a.model_providers.base_llm_provider import BaseLLMProvider
|
7
|
+
from smarta2a.history_update_strategies.history_update_strategy import HistoryUpdateStrategy
|
8
|
+
from smarta2a.state_stores.base_state_store import BaseStateStore
|
9
|
+
from smarta2a.utils.types import StateData, SendTaskRequest
|
10
|
+
|
11
|
+
class A2AAgent:
    """Serve a model provider as an A2A agent on top of ``SmartA2A``.

    The agent builds a ``SmartA2A`` app and registers a send-task handler that
    asks the configured model provider to generate a reply from the task's
    history.
    """

    def __init__(
        self,
        name: str,
        model_provider: BaseLLMProvider,
        history_update_strategy: HistoryUpdateStrategy,
        state_storage: BaseStateStore,
    ):
        """Create the underlying ``SmartA2A`` app and register handlers.

        Args:
            name: Agent name, passed through to ``SmartA2A``.
            model_provider: Provider whose ``generate`` produces the reply.
            history_update_strategy: Passed through to ``SmartA2A``.
            state_storage: Passed through to ``SmartA2A``.
        """
        self.model_provider = model_provider
        self.app = SmartA2A(
            name=name,
            history_update_strategy=history_update_strategy,
            state_storage=state_storage
        )
        self.__register_handlers()

    def __register_handlers(self):
        """Attach the send-task handler to ``self.app``."""
        import inspect

        @self.app.on_send_task()
        async def on_send_task(request: SendTaskRequest, state: StateData):
            response = self.model_provider.generate(state.history)
            # The provider interface is not visible here. If ``generate`` is a
            # coroutine function, await it so the handler returns the actual
            # reply rather than an un-awaited coroutine object; synchronous
            # providers are unaffected.
            if inspect.isawaitable(response):
                response = await response
            return response

    def start(self, **kwargs):
        """Apply ``kwargs`` via ``SmartA2A.configure`` and then run the app."""
        self.app.configure(**kwargs)
        self.app.run()
|
36
|
+
|
37
|
+
|
38
|
+
|
@@ -0,0 +1,37 @@
|
|
1
|
+
# Library imports
|
2
|
+
from mcp.server.fastmcp import FastMCP, Context
|
3
|
+
|
4
|
+
class A2AMCPServer:
    """MCP server intended to expose A2A task operations as tools.

    NOTE(review): in this version the tool bodies and ``start`` are
    placeholders that perform no work.
    """

    def __init__(self):
        """Create the FastMCP instance and register middleware and tools."""
        self.mcp = FastMCP("A2AMCPServer")
        self._register_handlers()

    def _register_handlers(self):
        """Register the session-id middleware and the placeholder tools."""
        server = self.mcp

        # NOTE(review): confirm the installed FastMCP exposes a ``middleware``
        # decorator; if it does not, construction fails with AttributeError.
        @server.middleware
        async def extract_session_id(request, call_next):
            # Stash the optional "x-session-id" header on request.state
            # (None when the header is absent) for tools to read later.
            request.state.session_id = request.headers.get("x-session-id")
            return await call_next(request)

        # Tool functions intentionally carry no docstrings: FastMCP may use a
        # tool function's docstring as its description.
        @server.tool()
        def send_task(ctx: Context, url: str, message: str):
            # Read the session id stored by the middleware; not used yet.
            _session = ctx.request.state.session_id

        @server.tool()
        def get_task():
            pass

        @server.tool()
        def cancel_task():
            pass

    def start(self):
        """Placeholder; currently does nothing."""
        pass
|
@@ -0,0 +1,86 @@
|
|
1
|
+
# Library imports
|
2
|
+
import re
|
3
|
+
from contextlib import AsyncExitStack
|
4
|
+
from mcp.client import ClientSession, sse_client, stdio_client, StdioServerParameters
|
5
|
+
|
6
|
+
|
7
|
+
class MCPClient:
    """Minimal MCP client that connects to a server over SSE or stdio.

    Typical use: create the client, ``await client.connect(path_or_url)``,
    call ``list_tools`` / ``call_tool``, then ``await client.cleanup()``.
    """

    # Inputs starting with http:// or https:// are treated as SSE endpoints.
    _URL_PATTERN = re.compile(r'^https?://')

    def __init__(self):
        self.session = None
        self.exit_stack = AsyncExitStack()
        # Only set by the SSE path; released manually in ``cleanup``.
        self._session_context = None
        self._streams_context = None
        # Connecting is deferred to ``connect`` / ``_connect_to_server``:
        # those are coroutines that need a target path or URL, so they cannot
        # be run from this synchronous constructor (calling the coroutine here
        # without its argument raised TypeError).

    async def connect(self, server_path_or_url: str):
        """Connect to an MCP server (public entry point for ``_connect_to_server``)."""
        await self._connect_to_server(server_path_or_url)

    async def _connect_to_sse_server(self, server_url: str):
        """Connect to an SSE MCP server."""
        self._streams_context = sse_client(url=server_url)
        streams = await self._streams_context.__aenter__()

        self._session_context = ClientSession(*streams)
        self.session = await self._session_context.__aenter__()

        # Initialize
        await self.session.initialize()

    @staticmethod
    def _resolve_stdio_command(server_script_path: str):
        """Return the ``(command, args)`` pair used to launch a stdio server.

        Resolution order:
          * ``@scope/...`` values are npm packages run with ``npx``;
          * ``*.py`` files run with ``python`` and ``*.js`` files with ``node``
            (checked before the "no slash" rule so a bare ``server.py`` in the
            current directory is not mistaken for an npm package);
          * any other value without ``/`` is treated as an npm package.

        Raises:
            ValueError: for a path that is neither a .py/.js file nor an npm package.
        """
        args = [server_script_path]
        if server_script_path.startswith("@"):
            return "npx", args
        if server_script_path.endswith(".py"):
            return "python", args
        if server_script_path.endswith(".js"):
            return "node", args
        if "/" not in server_script_path:
            return "npx", args
        raise ValueError("Server script must be a .py, .js file or npm package.")

    async def _connect_to_stdio_server(self, server_script_path: str):
        """Connect to a stdio MCP server."""
        command, args = self._resolve_stdio_command(server_script_path)

        server_params = StdioServerParameters(
            command=command,
            args=args,
            env=None
        )

        # Start the server; both contexts are owned by the exit stack.
        stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
        self.stdio, self.writer = stdio_transport
        self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.writer))

        await self.session.initialize()

    async def _connect_to_server(self, server_path_or_url: str):
        """Connect to an MCP server (either stdio or SSE)."""
        if self._URL_PATTERN.match(server_path_or_url):
            # It's a URL, connect to SSE server
            await self._connect_to_sse_server(server_path_or_url)
        else:
            # It's a script path or package, connect to stdio server
            await self._connect_to_stdio_server(server_path_or_url)

    async def list_tools(self):
        """List available tools."""
        response = await self.session.list_tools()
        return response.tools

    async def call_tool(self, tool_name: str, **tool_args):
        """Call a tool."""
        # NOTE(review): keyword arguments are forwarded verbatim to
        # ``ClientSession.call_tool``; callers presumably pass ``arguments=...``
        # rather than the tool's own parameters -- confirm.
        response = await self.session.call_tool(tool_name, **tool_args)
        return response.content

    async def cleanup(self):
        """Clean up resources; safe to call more than once."""
        await self.exit_stack.aclose()
        session_context = getattr(self, '_session_context', None)
        if session_context:
            self._session_context = None
            await session_context.__aexit__(None, None, None)
        streams_context = getattr(self, '_streams_context', None)
        if streams_context:
            self._streams_context = None
            await streams_context.__aexit__(None, None, None)
|
File without changes
|
@@ -0,0 +1,267 @@
|
|
1
|
+
# Library imports
|
2
|
+
from typing import Any, Literal, AsyncIterable, get_origin, get_args
|
3
|
+
import httpx
|
4
|
+
import json
|
5
|
+
from httpx_sse import connect_sse
|
6
|
+
from inspect import signature, Parameter, iscoroutinefunction
|
7
|
+
from pydantic import create_model, Field, BaseModel
|
8
|
+
|
9
|
+
# Local imports
|
10
|
+
from smarta2a.utils.types import (
|
11
|
+
PushNotificationConfig,
|
12
|
+
SendTaskStreamingResponse,
|
13
|
+
SendTaskResponse,
|
14
|
+
SendTaskStreamingRequest,
|
15
|
+
SendTaskRequest,
|
16
|
+
JSONRPCRequest,
|
17
|
+
A2AClientJSONError,
|
18
|
+
A2AClientHTTPError,
|
19
|
+
AgentCard,
|
20
|
+
AuthenticationInfo,
|
21
|
+
GetTaskResponse,
|
22
|
+
CancelTaskResponse,
|
23
|
+
SetTaskPushNotificationResponse,
|
24
|
+
GetTaskPushNotificationResponse,
|
25
|
+
)
|
26
|
+
from smarta2a.utils.task_request_builder import TaskRequestBuilder
|
27
|
+
|
28
|
+
|
29
|
+
class A2AClient:
    """Client for calling another agent over the A2A JSON-RPC protocol.

    Besides the protocol methods, the client exposes a small tool interface
    (``list_tools`` / ``call_tool``) so that an LLM can drive ``send``.
    """

    def __init__(self, agent_card: AgentCard = None, url: str = None):
        """Target ``agent_card.url`` if a card is given, otherwise ``url``.

        Raises:
            ValueError: if neither ``agent_card`` nor ``url`` is provided.
        """
        if agent_card:
            self.url = agent_card.url
        elif url:
            self.url = url
        else:
            raise ValueError("Must provide either agent_card or url")

    async def send(
        self,
        *,
        id: str,
        role: Literal["user", "agent"] = "user",
        text: str | None = None,
        data: dict[str, Any] | None = None,
        file_uri: str | None = None,
        session_id: str | None = None,
        accepted_output_modes: list[str] | None = None,
        push_notification: PushNotificationConfig | None = None,
        history_length: int | None = None,
        metadata: dict[str, Any] | None = None,
    ):
        """Send a task to another Agent

        The first line of this docstring is the tool description reported by
        ``list_tools``, and this signature defines the tool's input schema.
        """
        params = TaskRequestBuilder.build_send_task_request(
            id=id,
            role=role,
            text=text,
            data=data,
            file_uri=file_uri,
            session_id=session_id,
            accepted_output_modes=accepted_output_modes,
            push_notification=push_notification,
            history_length=history_length,
            metadata=metadata,
        )
        request = SendTaskRequest(params=params)
        return SendTaskResponse(**await self._send_request(request))

    def subscribe(
        self,
        *,
        id: str,
        role: Literal["user", "agent"] = "user",
        text: str | None = None,
        data: dict[str, Any] | None = None,
        file_uri: str | None = None,
        session_id: str | None = None,
        accepted_output_modes: list[str] | None = None,
        push_notification: PushNotificationConfig | None = None,
        history_length: int | None = None,
        metadata: dict[str, Any] | None = None,
    ):
        """Send to another Agent and receive a stream of responses

        This is a blocking (synchronous) generator of
        ``SendTaskStreamingResponse`` objects.
        """
        params = TaskRequestBuilder.build_send_task_request(
            id=id,
            role=role,
            text=text,
            data=data,
            file_uri=file_uri,
            session_id=session_id,
            accepted_output_modes=accepted_output_modes,
            push_notification=push_notification,
            history_length=history_length,
            metadata=metadata,
        )
        request = SendTaskStreamingRequest(params=params)
        yield from self._iter_sse(request)

    async def get_task(
        self,
        *,
        id: str,
        history_length: int | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> GetTaskResponse:
        """Get a task from another Agent"""
        req = TaskRequestBuilder.get_task(id, history_length, metadata)
        raw = await self._send_request(req)
        return GetTaskResponse(**raw)

    async def cancel_task(
        self,
        *,
        id: str,
        metadata: dict[str, Any] | None = None,
    ) -> CancelTaskResponse:
        """Cancel a task from another Agent"""
        req = TaskRequestBuilder.cancel_task(id, metadata)
        raw = await self._send_request(req)
        return CancelTaskResponse(**raw)

    async def set_push_notification(
        self,
        *,
        id: str,
        url: str,
        token: str | None = None,
        authentication: AuthenticationInfo | dict[str, Any] | None = None,
    ) -> SetTaskPushNotificationResponse:
        """Set a push notification for a task"""
        req = TaskRequestBuilder.set_push_notification(id, url, token, authentication)
        raw = await self._send_request(req)
        return SetTaskPushNotificationResponse(**raw)

    async def get_push_notification(
        self,
        *,
        id: str,
        metadata: dict[str, Any] | None = None,
    ) -> GetTaskPushNotificationResponse:
        """Get a push notification for a task"""
        req = TaskRequestBuilder.get_push_notification(id, metadata)
        raw = await self._send_request(req)
        return GetTaskPushNotificationResponse(**raw)

    async def _send_request(self, request: JSONRPCRequest) -> dict[str, Any]:
        """POST ``request`` as JSON and return the decoded JSON response body.

        Raises:
            A2AClientHTTPError: on a non-2xx HTTP status.
            A2AClientJSONError: if the response body is not valid JSON.
        """
        async with httpx.AsyncClient() as client:
            try:
                # Remote agents may take a while to answer; allow up to 30 s.
                response = await client.post(
                    self.url, json=request.model_dump(), timeout=30
                )
                response.raise_for_status()
                return response.json()
            except httpx.HTTPStatusError as e:
                raise A2AClientHTTPError(e.response.status_code, str(e)) from e
            except json.JSONDecodeError as e:
                raise A2AClientJSONError(str(e)) from e

    def _iter_sse(self, request: JSONRPCRequest):
        """POST ``request`` and yield each SSE event as a ``SendTaskStreamingResponse``.

        Uses a synchronous HTTP client with no timeout, so iteration blocks.

        Raises:
            A2AClientJSONError: if an event payload is not valid JSON.
            A2AClientHTTPError: on a transport-level failure. The ``try`` wraps
                the connection setup too, so failures while opening the stream
                are translated as well, not only those raised mid-stream.
        """
        try:
            with httpx.Client(timeout=None) as client:
                with connect_sse(
                    client, "POST", self.url, json=request.model_dump()
                ) as event_source:
                    for sse in event_source.iter_sse():
                        yield SendTaskStreamingResponse(**json.loads(sse.data))
        except json.JSONDecodeError as e:
            raise A2AClientJSONError(str(e)) from e
        except httpx.RequestError as e:
            raise A2AClientHTTPError(400, str(e)) from e

    async def _send_streaming_request(self, request: JSONRPCRequest) -> AsyncIterable[SendTaskStreamingResponse]:
        """Async-iterator view of ``_iter_sse``.

        NOTE: the underlying HTTP client is synchronous, so each step blocks
        the event loop while waiting for the next event.
        """
        for response in self._iter_sse(request):
            yield response

    @staticmethod
    def _build_input_model(model_name: str, method):
        """Create a Pydantic model whose fields mirror ``method``'s parameters.

        Parameters without a default become required fields; unannotated
        parameters are typed as ``Any``. Shared by ``list_tools`` (schema) and
        ``call_tool`` (validation) so both always agree.
        """
        fields = {}
        for param_name, param in signature(method).parameters.items():
            if param_name == 'self':
                continue
            annotation = Any if param.annotation is Parameter.empty else param.annotation
            if param.default is Parameter.empty:
                fields[param_name] = (annotation, Field(...))
            else:
                fields[param_name] = (annotation, Field(default=param.default))
        return create_model(model_name, **fields)

    def list_tools(self) -> list[dict[str, Any]]:
        """Return metadata for all available tools."""
        tools = []
        for name in ['send']:
            method = getattr(self, name)
            doc = method.__doc__ or ""
            # Only the first docstring line is used as the description.
            description = doc.strip().split('\n')[0] if doc else ""
            model = self._build_input_model(f"{name}_Input", method)
            tools.append({
                'name': name,
                'description': description,
                'input_schema': model.model_json_schema(),
            })
        return tools

    async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
        """Call a tool by name with validated arguments.

        Args:
            tool_name: Name of a public method on this client (e.g. ``"send"``).
            arguments: Keyword arguments for that method; ``None`` means none.

        Raises:
            ValueError: if ``tool_name`` is private, missing, or not callable,
                or (as a pydantic ValidationError) if ``arguments`` do not match
                the method signature.
        """
        # Tool names may come from model output: never dispatch to private
        # helpers or to non-callable attributes such as ``url``.
        method = None if tool_name.startswith('_') else getattr(self, tool_name, None)
        if method is None or not callable(method):
            raise ValueError(f"Tool {tool_name} not found")

        model = self._build_input_model(f"{tool_name}_ValidationModel", method)
        validated = model(**(arguments or {}))
        # Read fields directly instead of dumping to a dict, so nested models
        # (e.g. PushNotificationConfig) stay as validated objects.
        validated_args = {field: getattr(validated, field) for field in model.model_fields}

        if iscoroutinefunction(method):
            return await method(**validated_args)
        # Note: synchronous methods (like subscribe) will block the event loop;
        # ``subscribe`` returns a generator that the caller must iterate.
        return method(**validated_args)
|
@@ -0,0 +1,60 @@
|
|
1
|
+
from mcp.client import Client
|
2
|
+
from typing import Optional, Dict, Any
|
3
|
+
|
4
|
+
class SmartMCPClient:
    """Thin MCP client wrapper that opens a fresh connection for every call.

    Any truthy ``session_id`` passed to a method is sent as the
    ``x-session-id`` header on that call's connection.
    """

    def __init__(self, base_url: str):
        """Remember the server URL; headers are built per request, not globally."""
        self.base_url = base_url

    def _connect(self, session_id: Optional[str]):
        """Return an unopened ``Client`` for ``base_url`` carrying the session header."""
        headers = self._build_headers(session_id)
        return Client(self.base_url, headers=headers)

    async def list_tools(self, session_id: Optional[str] = None) -> Any:
        """List the server's tools."""
        async with self._connect(session_id) as conn:
            return await conn.list_tools()

    async def call_tool(self, tool_name: str, arguments: Dict[str, Any], session_id: Optional[str] = None) -> Any:
        """Invoke ``tool_name`` with ``arguments``."""
        async with self._connect(session_id) as conn:
            return await conn.call_tool(tool_name, arguments)

    async def list_resources(self, session_id: Optional[str] = None) -> Any:
        """List the server's resources."""
        async with self._connect(session_id) as conn:
            return await conn.list_resources()

    async def read_resource(self, resource_uri: str, session_id: Optional[str] = None) -> Any:
        """Read the resource at ``resource_uri``."""
        async with self._connect(session_id) as conn:
            return await conn.read_resource(resource_uri)

    async def get_prompt(self, prompt_name: str, arguments: Dict[str, Any], session_id: Optional[str] = None) -> Any:
        """Fetch prompt ``prompt_name`` rendered with ``arguments``."""
        async with self._connect(session_id) as conn:
            return await conn.get_prompt(prompt_name, arguments)

    async def ping(self, session_id: Optional[str] = None) -> Any:
        """Ping the server."""
        async with self._connect(session_id) as conn:
            return await conn.ping()

    def _build_headers(self, session_id: Optional[str]) -> Dict[str, str]:
        """Return ``{"x-session-id": session_id}`` for a truthy id, else an empty dict."""
        return {"x-session-id": session_id} if session_id else {}
|
@@ -0,0 +1,58 @@
|
|
1
|
+
# Library imports
|
2
|
+
import json
|
3
|
+
from typing import List, Dict, Any, Union, Literal
|
4
|
+
|
5
|
+
# Local imports
|
6
|
+
from smarta2a.client.smart_mcp_client import SmartMCPClient
|
7
|
+
from smarta2a.client.a2a_client import A2AClient
|
8
|
+
from smarta2a.utils.types import AgentCard
|
9
|
+
|
10
|
+
class ToolsManager:
    """
    Manages loading, describing, and invoking tools from various providers.
    Acts as a wrapper around the MCP and A2A clients.

    Tools are indexed by name in ``self.clients``; a later tool with the same
    name replaces the earlier mapping (e.g. several A2A agents all expose a
    tool named ``send``).
    """
    def __init__(self):
        self.tools_list: List[Any] = []
        self.clients: Dict[str, Union[SmartMCPClient, A2AClient]] = {}

    @staticmethod
    def _tool_attr(tool: Any, *keys: str, default: Any = None) -> Any:
        """Return the first of ``keys`` present on ``tool``, else ``default``.

        Tools may be dicts (as built by ``A2AClient.list_tools``) or objects
        exposing attributes (MCP tool descriptions), so both are supported.
        """
        for key in keys:
            if isinstance(tool, dict):
                if key in tool:
                    return tool[key]
            elif hasattr(tool, key):
                return getattr(tool, key)
        return default

    def load_mcp_tools(self, urls_or_paths: List[str]) -> None:
        """Register every tool listed by each MCP server in ``urls_or_paths``."""
        # NOTE(review): SmartMCPClient.list_tools is declared ``async def``, so
        # iterating its return value directly yields a coroutine, not tools.
        # This presumably needs to become async (``await mcp_client.list_tools()``)
        # -- confirm against callers before changing the interface.
        for url in urls_or_paths:
            mcp_client = SmartMCPClient(url)
            for tool in mcp_client.list_tools():
                self.tools_list.append(tool)
                self.clients[self._tool_attr(tool, "name")] = mcp_client

    def load_a2a_tools(self, agent_cards: List[AgentCard]) -> None:
        """Register the tools exposed by an ``A2AClient`` for each agent card."""
        for agent_card in agent_cards:
            a2a_client = A2AClient(agent_card)
            for tool in a2a_client.list_tools():
                self.tools_list.append(tool)
                # A2AClient.list_tools returns plain dicts, so ``tool.name``
                # would raise AttributeError; read the key instead.
                self.clients[self._tool_attr(tool, "name")] = a2a_client

    def get_tools(self) -> List[Any]:
        """Return the loaded tools (the live internal list)."""
        return self.tools_list

    def describe_tools(self, client_type: Literal["mcp", "a2a"]) -> str:
        """Render a markdown description of the tools served by ``client_type``.

        Tools are selected by the type of the client registered for their
        name. An unknown ``client_type`` yields an empty string.
        """
        expected_client = {"mcp": SmartMCPClient, "a2a": A2AClient}.get(client_type)
        if expected_client is None:
            return ""
        lines = []
        for tool in self.tools_list:
            name = self._tool_attr(tool, "name")
            # The original check applied isinstance(...) to the tool itself,
            # which is never a client, so nothing was ever described.
            if not isinstance(self.clients.get(name), expected_client):
                continue
            description = self._tool_attr(tool, "description", default="")
            # A2A tool dicts use "input_schema"; MCP tool objects use "inputSchema".
            input_schema = self._tool_attr(tool, "input_schema", "inputSchema", default={})
            schema = json.dumps(input_schema, indent=2)
            lines.append(
                f"- **{name}**: {description}\n Parameters schema:\n ```json\n{schema}\n```"
            )
        return "\n".join(lines)

    def get_client(self, tool_name: str) -> Any:
        """Return the client registered for ``tool_name``, or ``None``."""
        return self.clients.get(tool_name)

    async def call_tool(self, tool_name: str, args: Dict[str, Any]) -> Any:
        """Dispatch ``tool_name`` with ``args`` to its registered client.

        Raises:
            ValueError: if no client is registered for ``tool_name``.
        """
        client = self.get_client(tool_name)
        if not client:
            raise ValueError(f"Tool not found: {tool_name}")
        return await client.call_tool(tool_name, args)
|
@@ -0,0 +1,10 @@
|
|
1
|
+
# Library imports
|
2
|
+
from typing import List
|
3
|
+
|
4
|
+
# Local imports
|
5
|
+
from smarta2a.utils.types import Message
|
6
|
+
|
7
|
+
class AppendStrategy:
    """History update strategy that appends new messages to the history.

    This is the default behavior.
    """

    def update_history(self, existing_history: List[Message], new_messages: List[Message]) -> List[Message]:
        """Return ``existing_history`` followed by ``new_messages``.

        The result comes from ``+``, so for list inputs it is a new list and
        neither argument is modified.
        """
        combined = existing_history + new_messages
        return combined
|