smarta2a 0.2.2__tar.gz → 0.2.3__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {smarta2a-0.2.2 → smarta2a-0.2.3}/PKG-INFO +12 -6
- {smarta2a-0.2.2 → smarta2a-0.2.3}/pyproject.toml +12 -6
- {smarta2a-0.2.2 → smarta2a-0.2.3}/requirements.txt +2 -0
- {smarta2a-0.2.2 → smarta2a-0.2.3}/smarta2a/__init__.py +1 -1
- smarta2a-0.2.3/smarta2a/agent/a2a_agent.py +38 -0
- smarta2a-0.2.3/smarta2a/agent/a2a_mcp_server.py +37 -0
- smarta2a-0.2.3/smarta2a/archive/mcp_client.py +86 -0
- {smarta2a-0.2.2 → smarta2a-0.2.3}/smarta2a/client/a2a_client.py +97 -3
- smarta2a-0.2.3/smarta2a/client/smart_mcp_client.py +60 -0
- smarta2a-0.2.3/smarta2a/client/tools_manager.py +58 -0
- smarta2a-0.2.3/smarta2a/history_update_strategies/__init__.py +8 -0
- smarta2a-0.2.3/smarta2a/history_update_strategies/append_strategy.py +10 -0
- smarta2a-0.2.3/smarta2a/history_update_strategies/history_update_strategy.py +15 -0
- smarta2a-0.2.3/smarta2a/model_providers/__init__.py +5 -0
- smarta2a-0.2.3/smarta2a/model_providers/base_llm_provider.py +15 -0
- smarta2a-0.2.3/smarta2a/model_providers/openai_provider.py +281 -0
- smarta2a-0.2.3/smarta2a/server/handler_registry.py +23 -0
- {smarta2a-0.2.2 → smarta2a-0.2.3}/smarta2a/server/server.py +225 -252
- smarta2a-0.2.3/smarta2a/server/state_manager.py +34 -0
- smarta2a-0.2.3/smarta2a/server/subscription_service.py +109 -0
- smarta2a-0.2.3/smarta2a/server/task_service.py +155 -0
- smarta2a-0.2.3/smarta2a/state_stores/__init__.py +8 -0
- smarta2a-0.2.3/smarta2a/state_stores/base_state_store.py +20 -0
- smarta2a-0.2.3/smarta2a/state_stores/inmemory_state_store.py +21 -0
- smarta2a-0.2.3/smarta2a/utils/prompt_helpers.py +38 -0
- smarta2a-0.2.3/smarta2a/utils/task_builder.py +153 -0
- {smarta2a-0.2.2/smarta2a/common → smarta2a-0.2.3/smarta2a/utils}/task_request_builder.py +1 -1
- {smarta2a-0.2.2/smarta2a/common → smarta2a-0.2.3/smarta2a/utils}/types.py +62 -2
- smarta2a-0.2.3/tests/__init__.py +3 -0
- {smarta2a-0.2.2 → smarta2a-0.2.3}/tests/test_server.py +3 -7
- smarta2a-0.2.3/tests/test_server_history.py +189 -0
- {smarta2a-0.2.2 → smarta2a-0.2.3}/tests/test_task_request_builder.py +2 -2
- {smarta2a-0.2.2 → smarta2a-0.2.3}/.gitignore +0 -0
- {smarta2a-0.2.2 → smarta2a-0.2.3}/LICENSE +0 -0
- {smarta2a-0.2.2 → smarta2a-0.2.3}/README.md +0 -0
- {smarta2a-0.2.2 → smarta2a-0.2.3}/smarta2a/client/__init__.py +0 -0
- {smarta2a-0.2.2 → smarta2a-0.2.3}/smarta2a/server/__init__.py +0 -0
- {smarta2a-0.2.2/smarta2a/common → smarta2a-0.2.3/smarta2a/utils}/__init__.py +0 -0
@@ -1,7 +1,7 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: smarta2a
|
3
|
-
Version: 0.2.
|
4
|
-
Summary: A Python
|
3
|
+
Version: 0.2.3
|
4
|
+
Summary: A simple Python framework (built on top of FastAPI) for creating Agents following Google's Agent2Agent protocol
|
5
5
|
Project-URL: Homepage, https://github.com/siddharthsma/smarta2a
|
6
6
|
Project-URL: Bug Tracker, https://github.com/siddharthsma/smarta2a/issues
|
7
7
|
Author-email: Siddharth Ambegaonkar <siddharthsma@gmail.com>
|
@@ -10,10 +10,16 @@ Classifier: License :: OSI Approved :: MIT License
|
|
10
10
|
Classifier: Operating System :: OS Independent
|
11
11
|
Classifier: Programming Language :: Python :: 3
|
12
12
|
Requires-Python: >=3.8
|
13
|
-
Requires-Dist:
|
14
|
-
Requires-Dist:
|
15
|
-
Requires-Dist:
|
16
|
-
Requires-Dist:
|
13
|
+
Requires-Dist: anyio>=4.9.0
|
14
|
+
Requires-Dist: fastapi>=0.115.12
|
15
|
+
Requires-Dist: httpx>=0.28.1
|
16
|
+
Requires-Dist: mcp>=0.1.0
|
17
|
+
Requires-Dist: openai>=1.0.0
|
18
|
+
Requires-Dist: pydantic>=2.11.3
|
19
|
+
Requires-Dist: sse-starlette>=2.2.1
|
20
|
+
Requires-Dist: starlette>=0.46.2
|
21
|
+
Requires-Dist: typing-extensions>=4.13.2
|
22
|
+
Requires-Dist: uvicorn>=0.34.1
|
17
23
|
Description-Content-Type: text/markdown
|
18
24
|
|
19
25
|
# SmartA2A
|
@@ -4,11 +4,11 @@ build-backend = "hatchling.build"
|
|
4
4
|
|
5
5
|
[project]
|
6
6
|
name = "smarta2a"
|
7
|
-
version = "0.2.
|
7
|
+
version = "0.2.3"
|
8
8
|
authors = [
|
9
9
|
{ name = "Siddharth Ambegaonkar", email = "siddharthsma@gmail.com" },
|
10
10
|
]
|
11
|
-
description = "A Python
|
11
|
+
description = "A simple Python framework (built on top of FastAPI) for creating Agents following Google's Agent2Agent protocol"
|
12
12
|
readme = "README.md"
|
13
13
|
requires-python = ">=3.8"
|
14
14
|
classifiers = [
|
@@ -17,10 +17,16 @@ classifiers = [
|
|
17
17
|
"Operating System :: OS Independent",
|
18
18
|
]
|
19
19
|
dependencies = [
|
20
|
-
"fastapi",
|
21
|
-
"pydantic",
|
22
|
-
"uvicorn",
|
23
|
-
"sse-starlette",
|
20
|
+
"fastapi>=0.115.12",
|
21
|
+
"pydantic>=2.11.3",
|
22
|
+
"uvicorn>=0.34.1",
|
23
|
+
"sse-starlette>=2.2.1",
|
24
|
+
"openai>=1.0.0",
|
25
|
+
"mcp>=0.1.0",
|
26
|
+
"httpx>=0.28.1",
|
27
|
+
"starlette>=0.46.2",
|
28
|
+
"typing-extensions>=4.13.2",
|
29
|
+
"anyio>=4.9.0"
|
24
30
|
]
|
25
31
|
|
26
32
|
[project.urls]
|
@@ -0,0 +1,38 @@
|
|
1
|
+
# Library imports
|
2
|
+
|
3
|
+
|
4
|
+
# Local imports
|
5
|
+
from smarta2a.server import SmartA2A
|
6
|
+
from smarta2a.model_providers.base_llm_provider import BaseLLMProvider
|
7
|
+
from smarta2a.history_update_strategies.history_update_strategy import HistoryUpdateStrategy
|
8
|
+
from smarta2a.state_stores.base_state_store import BaseStateStore
|
9
|
+
from smarta2a.utils.types import StateData, SendTaskRequest
|
10
|
+
|
11
|
+
class A2AAgent:
|
12
|
+
def __init__(
|
13
|
+
self,
|
14
|
+
name: str,
|
15
|
+
model_provider: BaseLLMProvider,
|
16
|
+
history_update_strategy: HistoryUpdateStrategy,
|
17
|
+
state_storage: BaseStateStore,
|
18
|
+
):
|
19
|
+
self.model_provider = model_provider
|
20
|
+
self.app = SmartA2A(
|
21
|
+
name=name,
|
22
|
+
history_update_strategy=history_update_strategy,
|
23
|
+
state_storage=state_storage
|
24
|
+
)
|
25
|
+
self.__register_handlers()
|
26
|
+
|
27
|
+
def __register_handlers(self):
|
28
|
+
@self.app.on_send_task()
|
29
|
+
async def on_send_task(request: SendTaskRequest, state: StateData):
|
30
|
+
response = self.model_provider.generate(state.history)
|
31
|
+
return response
|
32
|
+
|
33
|
+
def start(self, **kwargs):
|
34
|
+
self.app.configure(**kwargs)
|
35
|
+
self.app.run()
|
36
|
+
|
37
|
+
|
38
|
+
|
@@ -0,0 +1,37 @@
|
|
1
|
+
# Library imports
|
2
|
+
from mcp.server.fastmcp import FastMCP, Context
|
3
|
+
|
4
|
+
class A2AMCPServer:
|
5
|
+
def __init__(self):
|
6
|
+
self.mcp = FastMCP("A2AMCPServer")
|
7
|
+
self._register_handlers()
|
8
|
+
|
9
|
+
def _register_handlers(self):
|
10
|
+
|
11
|
+
@self.mcp.middleware
|
12
|
+
async def extract_session_id(request, call_next):
|
13
|
+
# Extract 'x-session-id' from headers
|
14
|
+
session_id = request.headers.get("x-session-id")
|
15
|
+
# Store it in the request state for later access
|
16
|
+
request.state.session_id = session_id
|
17
|
+
# Proceed with the request processing
|
18
|
+
response = await call_next(request)
|
19
|
+
return response
|
20
|
+
|
21
|
+
@self.mcp.tool()
|
22
|
+
def send_task(ctx: Context, url: str, message: str):
|
23
|
+
session_id = ctx.request.state.session_id
|
24
|
+
pass
|
25
|
+
|
26
|
+
@self.mcp.tool()
|
27
|
+
def get_task():
|
28
|
+
pass
|
29
|
+
|
30
|
+
@self.mcp.tool()
|
31
|
+
def cancel_task():
|
32
|
+
pass
|
33
|
+
|
34
|
+
|
35
|
+
|
36
|
+
def start(self):
|
37
|
+
pass
|
@@ -0,0 +1,86 @@
|
|
1
|
+
# Library imports
|
2
|
+
import re
|
3
|
+
from contextlib import AsyncExitStack
|
4
|
+
from mcp.client import ClientSession, sse_client, stdio_client, StdioServerParameters
|
5
|
+
|
6
|
+
|
7
|
+
class MCPClient:
|
8
|
+
def __init__(self):
|
9
|
+
self.session = None
|
10
|
+
self.exit_stack = AsyncExitStack()
|
11
|
+
self._connect_to_server()
|
12
|
+
|
13
|
+
async def _connect_to_sse_server(self, server_url: str):
|
14
|
+
"""Connect to an SSE MCP server."""
|
15
|
+
self._streams_context = sse_client(url=server_url)
|
16
|
+
streams = await self._streams_context.__aenter__()
|
17
|
+
|
18
|
+
self._session_context = ClientSession(*streams)
|
19
|
+
self.session = await self._session_context.__aenter__()
|
20
|
+
|
21
|
+
# Initialize
|
22
|
+
await self.session.initialize()
|
23
|
+
|
24
|
+
async def _connect_to_stdio_server(self, server_script_path: str):
|
25
|
+
"""Connect to a stdio MCP server."""
|
26
|
+
is_python = False
|
27
|
+
is_javascript = False
|
28
|
+
command = None
|
29
|
+
args = [server_script_path]
|
30
|
+
|
31
|
+
# Determine if the server is a file path or npm package
|
32
|
+
if server_script_path.startswith("@") or "/" not in server_script_path:
|
33
|
+
# Assume it's an npm package
|
34
|
+
is_javascript = True
|
35
|
+
command = "npx"
|
36
|
+
else:
|
37
|
+
# It's a file path
|
38
|
+
is_python = server_script_path.endswith(".py")
|
39
|
+
is_javascript = server_script_path.endswith(".js")
|
40
|
+
if not (is_python or is_javascript):
|
41
|
+
raise ValueError("Server script must be a .py, .js file or npm package.")
|
42
|
+
|
43
|
+
command = "python" if is_python else "node"
|
44
|
+
|
45
|
+
server_params = StdioServerParameters(
|
46
|
+
command=command,
|
47
|
+
args=args,
|
48
|
+
env=None
|
49
|
+
)
|
50
|
+
|
51
|
+
# Start the server
|
52
|
+
stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
|
53
|
+
self.stdio, self.writer = stdio_transport
|
54
|
+
self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.writer))
|
55
|
+
|
56
|
+
await self.session.initialize()
|
57
|
+
|
58
|
+
async def _connect_to_server(self, server_path_or_url: str):
|
59
|
+
"""Connect to an MCP server (either stdio or SSE)."""
|
60
|
+
# Check if the input is a URL (for SSE server)
|
61
|
+
url_pattern = re.compile(r'^https?://')
|
62
|
+
|
63
|
+
if url_pattern.match(server_path_or_url):
|
64
|
+
# It's a URL, connect to SSE server
|
65
|
+
await self._connect_to_sse_server(server_path_or_url)
|
66
|
+
else:
|
67
|
+
# It's a script path, connect to stdio server
|
68
|
+
await self._connect_to_stdio_server(server_path_or_url)
|
69
|
+
|
70
|
+
async def list_tools(self):
|
71
|
+
"""List available tools."""
|
72
|
+
response = await self.session.list_tools()
|
73
|
+
return response.tools
|
74
|
+
|
75
|
+
async def call_tool(self, tool_name: str, **tool_args):
|
76
|
+
"""Call a tool."""
|
77
|
+
response = await self.session.call_tool(tool_name, **tool_args)
|
78
|
+
return response.content
|
79
|
+
|
80
|
+
async def cleanup(self):
|
81
|
+
"""Clean up resources."""
|
82
|
+
await self.exit_stack.aclose()
|
83
|
+
if hasattr(self, '_session_context') and self._session_context:
|
84
|
+
await self._session_context.__aexit__(None, None, None)
|
85
|
+
if hasattr(self, '_streams_context') and self._streams_context:
|
86
|
+
await self._streams_context.__aexit__(None, None, None)
|
@@ -1,11 +1,13 @@
|
|
1
1
|
# Library imports
|
2
|
-
from typing import Any, Literal, AsyncIterable
|
2
|
+
from typing import Any, Literal, AsyncIterable, get_origin, get_args
|
3
3
|
import httpx
|
4
4
|
import json
|
5
5
|
from httpx_sse import connect_sse
|
6
|
+
from inspect import signature, Parameter, iscoroutinefunction
|
7
|
+
from pydantic import create_model, Field, BaseModel
|
6
8
|
|
7
9
|
# Local imports
|
8
|
-
from smarta2a.
|
10
|
+
from smarta2a.utils.types import (
|
9
11
|
PushNotificationConfig,
|
10
12
|
SendTaskStreamingResponse,
|
11
13
|
SendTaskResponse,
|
@@ -21,7 +23,7 @@ from smarta2a.common.types import (
|
|
21
23
|
SetTaskPushNotificationResponse,
|
22
24
|
GetTaskPushNotificationResponse,
|
23
25
|
)
|
24
|
-
from smarta2a.
|
26
|
+
from smarta2a.utils.task_request_builder import TaskRequestBuilder
|
25
27
|
|
26
28
|
|
27
29
|
class A2AClient:
|
@@ -47,6 +49,7 @@ class A2AClient:
|
|
47
49
|
history_length: int | None = None,
|
48
50
|
metadata: dict[str, Any] | None = None,
|
49
51
|
):
|
52
|
+
"""Send a task to another Agent"""
|
50
53
|
params = TaskRequestBuilder.build_send_task_request(
|
51
54
|
id=id,
|
52
55
|
role=role,
|
@@ -76,6 +79,7 @@ class A2AClient:
|
|
76
79
|
history_length: int | None = None,
|
77
80
|
metadata: dict[str, Any] | None = None,
|
78
81
|
):
|
82
|
+
"""Send to another Agent and receive a stream of responses"""
|
79
83
|
params = TaskRequestBuilder.build_send_task_request(
|
80
84
|
id=id,
|
81
85
|
role=role,
|
@@ -108,6 +112,7 @@ class A2AClient:
|
|
108
112
|
history_length: int | None = None,
|
109
113
|
metadata: dict[str, Any] | None = None,
|
110
114
|
) -> GetTaskResponse:
|
115
|
+
"""Get a task from another Agent"""
|
111
116
|
req = TaskRequestBuilder.get_task(id, history_length, metadata)
|
112
117
|
raw = await self._send_request(req)
|
113
118
|
return GetTaskResponse(**raw)
|
@@ -118,6 +123,7 @@ class A2AClient:
|
|
118
123
|
id: str,
|
119
124
|
metadata: dict[str, Any] | None = None,
|
120
125
|
) -> CancelTaskResponse:
|
126
|
+
"""Cancel a task from another Agent"""
|
121
127
|
req = TaskRequestBuilder.cancel_task(id, metadata)
|
122
128
|
raw = await self._send_request(req)
|
123
129
|
return CancelTaskResponse(**raw)
|
@@ -130,6 +136,7 @@ class A2AClient:
|
|
130
136
|
token: str | None = None,
|
131
137
|
authentication: AuthenticationInfo | dict[str, Any] | None = None,
|
132
138
|
) -> SetTaskPushNotificationResponse:
|
139
|
+
"""Set a push notification for a task"""
|
133
140
|
req = TaskRequestBuilder.set_push_notification(id, url, token, authentication)
|
134
141
|
raw = await self._send_request(req)
|
135
142
|
return SetTaskPushNotificationResponse(**raw)
|
@@ -140,6 +147,7 @@ class A2AClient:
|
|
140
147
|
id: str,
|
141
148
|
metadata: dict[str, Any] | None = None,
|
142
149
|
) -> GetTaskPushNotificationResponse:
|
150
|
+
"""Get a push notification for a task"""
|
143
151
|
req = TaskRequestBuilder.get_push_notification(id, metadata)
|
144
152
|
raw = await self._send_request(req)
|
145
153
|
return GetTaskPushNotificationResponse(**raw)
|
@@ -171,3 +179,89 @@ class A2AClient:
|
|
171
179
|
raise A2AClientJSONError(str(e)) from e
|
172
180
|
except httpx.RequestError as e:
|
173
181
|
raise A2AClientHTTPError(400, str(e)) from e
|
182
|
+
|
183
|
+
|
184
|
+
def list_tools(self) -> list[dict[str, Any]]:
|
185
|
+
"""Return metadata for all available tools."""
|
186
|
+
tools = []
|
187
|
+
tool_names = [
|
188
|
+
'send'
|
189
|
+
]
|
190
|
+
for name in tool_names:
|
191
|
+
method = getattr(self, name)
|
192
|
+
doc = method.__doc__ or ""
|
193
|
+
description = doc.strip().split('\n')[0] if doc else ""
|
194
|
+
|
195
|
+
# Generate input schema
|
196
|
+
sig = signature(method)
|
197
|
+
parameters = sig.parameters
|
198
|
+
|
199
|
+
fields = {}
|
200
|
+
required = []
|
201
|
+
for param_name, param in parameters.items():
|
202
|
+
if param_name == 'self':
|
203
|
+
continue
|
204
|
+
annotation = param.annotation
|
205
|
+
if annotation is Parameter.empty:
|
206
|
+
annotation = Any
|
207
|
+
# Handle Literal types
|
208
|
+
if get_origin(annotation) is Literal:
|
209
|
+
enum_values = get_args(annotation)
|
210
|
+
annotation = Literal.__getitem__(enum_values)
|
211
|
+
# Handle default
|
212
|
+
default = param.default
|
213
|
+
if default is Parameter.empty:
|
214
|
+
required.append(param_name)
|
215
|
+
field = Field(...)
|
216
|
+
else:
|
217
|
+
field = Field(default=default)
|
218
|
+
fields[param_name] = (annotation, field)
|
219
|
+
|
220
|
+
# Create dynamic Pydantic model
|
221
|
+
model = create_model(f"{name}_Input", **fields)
|
222
|
+
schema = model.schema()
|
223
|
+
|
224
|
+
tools.append({
|
225
|
+
'name': name,
|
226
|
+
'description': description,
|
227
|
+
'input_schema': schema
|
228
|
+
})
|
229
|
+
return tools
|
230
|
+
|
231
|
+
async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
|
232
|
+
"""Call a tool by name with validated arguments."""
|
233
|
+
if not hasattr(self, tool_name):
|
234
|
+
raise ValueError(f"Tool {tool_name} not found")
|
235
|
+
method = getattr(self, tool_name)
|
236
|
+
|
237
|
+
# Validate arguments using the same schema as list_tools
|
238
|
+
sig = signature(method)
|
239
|
+
parameters = sig.parameters
|
240
|
+
|
241
|
+
fields = {}
|
242
|
+
for param_name, param in parameters.items():
|
243
|
+
if param_name == 'self':
|
244
|
+
continue
|
245
|
+
annotation = param.annotation
|
246
|
+
if annotation is Parameter.empty:
|
247
|
+
annotation = Any
|
248
|
+
# Handle Literal
|
249
|
+
if get_origin(annotation) is Literal:
|
250
|
+
enum_values = get_args(annotation)
|
251
|
+
annotation = Literal.__getitem__(enum_values)
|
252
|
+
default = param.default
|
253
|
+
if default is Parameter.empty:
|
254
|
+
fields[param_name] = (annotation, Field(...))
|
255
|
+
else:
|
256
|
+
fields[param_name] = (annotation, Field(default=default))
|
257
|
+
|
258
|
+
# Create validation model
|
259
|
+
model = create_model(f"{tool_name}_ValidationModel", **fields)
|
260
|
+
validated_args = model(**arguments).dict()
|
261
|
+
|
262
|
+
# Call the method
|
263
|
+
if iscoroutinefunction(method):
|
264
|
+
return await method(**validated_args)
|
265
|
+
else:
|
266
|
+
# Note: Synchronous methods (like subscribe) will block the event loop
|
267
|
+
return method(**validated_args)
|
@@ -0,0 +1,60 @@
|
|
1
|
+
from mcp.client import Client
|
2
|
+
from typing import Optional, Dict, Any
|
3
|
+
|
4
|
+
class SmartMCPClient:
|
5
|
+
def __init__(self, base_url: str):
|
6
|
+
"""
|
7
|
+
Initialize with the server URL. Headers are provided per request, not globally.
|
8
|
+
"""
|
9
|
+
self.base_url = base_url
|
10
|
+
|
11
|
+
async def list_tools(self, session_id: Optional[str] = None) -> Any:
|
12
|
+
"""
|
13
|
+
List tools with optional session_id header.
|
14
|
+
"""
|
15
|
+
async with Client(self.base_url, headers=self._build_headers(session_id)) as client:
|
16
|
+
return await client.list_tools()
|
17
|
+
|
18
|
+
async def call_tool(self, tool_name: str, arguments: Dict[str, Any], session_id: Optional[str] = None) -> Any:
|
19
|
+
"""
|
20
|
+
Call a tool with dynamic session_id.
|
21
|
+
"""
|
22
|
+
async with Client(self.base_url, headers=self._build_headers(session_id)) as client:
|
23
|
+
return await client.call_tool(tool_name, arguments)
|
24
|
+
|
25
|
+
async def list_resources(self, session_id: Optional[str] = None) -> Any:
|
26
|
+
"""
|
27
|
+
List resources with optional session_id.
|
28
|
+
"""
|
29
|
+
async with Client(self.base_url, headers=self._build_headers(session_id)) as client:
|
30
|
+
return await client.list_resources()
|
31
|
+
|
32
|
+
async def read_resource(self, resource_uri: str, session_id: Optional[str] = None) -> Any:
|
33
|
+
"""
|
34
|
+
Read a resource with dynamic session_id.
|
35
|
+
"""
|
36
|
+
async with Client(self.base_url, headers=self._build_headers(session_id)) as client:
|
37
|
+
return await client.read_resource(resource_uri)
|
38
|
+
|
39
|
+
async def get_prompt(self, prompt_name: str, arguments: Dict[str, Any], session_id: Optional[str] = None) -> Any:
|
40
|
+
"""
|
41
|
+
Fetch a prompt with optional session_id.
|
42
|
+
"""
|
43
|
+
async with Client(self.base_url, headers=self._build_headers(session_id)) as client:
|
44
|
+
return await client.get_prompt(prompt_name, arguments)
|
45
|
+
|
46
|
+
async def ping(self, session_id: Optional[str] = None) -> Any:
|
47
|
+
"""
|
48
|
+
Ping server with optional session_id.
|
49
|
+
"""
|
50
|
+
async with Client(self.base_url, headers=self._build_headers(session_id)) as client:
|
51
|
+
return await client.ping()
|
52
|
+
|
53
|
+
def _build_headers(self, session_id: Optional[str]) -> Dict[str, str]:
|
54
|
+
"""
|
55
|
+
Internal helper to build headers dynamically.
|
56
|
+
"""
|
57
|
+
headers = {}
|
58
|
+
if session_id:
|
59
|
+
headers["x-session-id"] = session_id
|
60
|
+
return headers
|
@@ -0,0 +1,58 @@
|
|
1
|
+
# Library imports
|
2
|
+
import json
|
3
|
+
from typing import List, Dict, Any, Union, Literal
|
4
|
+
|
5
|
+
# Local imports
|
6
|
+
from smarta2a.client.smart_mcp_client import SmartMCPClient
|
7
|
+
from smarta2a.client.a2a_client import A2AClient
|
8
|
+
from smarta2a.utils.types import AgentCard
|
9
|
+
|
10
|
+
class ToolsManager:
|
11
|
+
"""
|
12
|
+
Manages loading, describing, and invoking tools from various providers.
|
13
|
+
Acts as a wrapper around the MCP and A2A clients.
|
14
|
+
"""
|
15
|
+
def __init__(self):
|
16
|
+
self.tools_list: List[Any] = []
|
17
|
+
self.clients: Dict[str, Union[SmartMCPClient, A2AClient]] = {}
|
18
|
+
|
19
|
+
def load_mcp_tools(self, urls_or_paths: List[str]) -> None:
|
20
|
+
for url in urls_or_paths:
|
21
|
+
mcp_client = SmartMCPClient(url)
|
22
|
+
for tool in mcp_client.list_tools():
|
23
|
+
self.tools_list.append(tool)
|
24
|
+
self.clients[tool.name] = mcp_client
|
25
|
+
|
26
|
+
def load_a2a_tools(self, agent_cards: List[AgentCard]) -> None:
|
27
|
+
for agent_card in agent_cards:
|
28
|
+
a2a_client = A2AClient(agent_card)
|
29
|
+
for tool in a2a_client.list_tools():
|
30
|
+
self.tools_list.append(tool)
|
31
|
+
self.clients[tool.name] = a2a_client
|
32
|
+
|
33
|
+
def get_tools(self) -> List[Any]:
|
34
|
+
return self.tools_list
|
35
|
+
|
36
|
+
|
37
|
+
def describe_tools(self, client_type: Literal["mcp", "a2a"]) -> str:
|
38
|
+
lines = []
|
39
|
+
for tool in self.tools_list:
|
40
|
+
if client_type == "mcp" and isinstance(tool, SmartMCPClient):
|
41
|
+
schema = json.dumps(tool.input_schema, indent=2)
|
42
|
+
lines.append(
|
43
|
+
f"- **{tool.name}**: {tool.description}\n Parameters schema:\n ```json\n{schema}\n```"
|
44
|
+
)
|
45
|
+
elif client_type == "a2a" and isinstance(tool, A2AClient):
|
46
|
+
lines.append(
|
47
|
+
f"- **{tool.name}**: {tool.description} Parameters schema:\n ```json\n{schema}\n```"
|
48
|
+
)
|
49
|
+
return "\n".join(lines)
|
50
|
+
|
51
|
+
def get_client(self, tool_name: str) -> Any:
|
52
|
+
return self.clients.get(tool_name)
|
53
|
+
|
54
|
+
async def call_tool(self, tool_name: str, args: Dict[str, Any]) -> Any:
|
55
|
+
client = self.get_client(tool_name)
|
56
|
+
if not client:
|
57
|
+
raise ValueError(f"Tool not found: {tool_name}")
|
58
|
+
return await client.call_tool(tool_name, args)
|
@@ -0,0 +1,10 @@
|
|
1
|
+
# Library imports
|
2
|
+
from typing import List
|
3
|
+
|
4
|
+
# Local imports
|
5
|
+
from smarta2a.utils.types import Message
|
6
|
+
|
7
|
+
class AppendStrategy:
|
8
|
+
"""Default append behavior"""
|
9
|
+
def update_history(self, existing_history: List[Message], new_messages: List[Message]) -> List[Message]:
|
10
|
+
return existing_history + new_messages
|
@@ -0,0 +1,15 @@
|
|
1
|
+
# Library imports
|
2
|
+
from typing import Protocol, List
|
3
|
+
from typing import List
|
4
|
+
|
5
|
+
# Local imports
|
6
|
+
from smarta2a.utils.types import Message
|
7
|
+
|
8
|
+
class HistoryUpdateStrategy(Protocol):
|
9
|
+
def update_history(
|
10
|
+
self,
|
11
|
+
existing_history: List[Message],
|
12
|
+
new_messages: List[Message]
|
13
|
+
) -> List[Message]:
|
14
|
+
"""Process history with new messages"""
|
15
|
+
pass
|
@@ -0,0 +1,15 @@
|
|
1
|
+
# Library imports
|
2
|
+
from abc import ABC, abstractmethod
|
3
|
+
from typing import AsyncGenerator, List
|
4
|
+
|
5
|
+
# Local imports
|
6
|
+
from smarta2a.utils.types import Message
|
7
|
+
|
8
|
+
class BaseLLMProvider(ABC):
|
9
|
+
@abstractmethod
|
10
|
+
async def generate(self, messages: List[Message], **kwargs) -> str:
|
11
|
+
pass
|
12
|
+
|
13
|
+
@abstractmethod
|
14
|
+
async def generate_stream(self, messages: List[Message], **kwargs) -> AsyncGenerator[str, None]:
|
15
|
+
pass
|