smarta2a 0.2.2__py3-none-any.whl → 0.2.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
smarta2a/__init__.py CHANGED
@@ -3,7 +3,7 @@ py_a2a - A Python package for implementing an A2A server
3
3
  """
4
4
 
5
5
  from .server.server import SmartA2A
6
- from .common.types import *
6
+ from .utils.types import *
7
7
 
8
8
  __version__ = "0.1.0"
9
9
 
@@ -0,0 +1,38 @@
1
+ # Library imports
2
+
3
+
4
+ # Local imports
5
+ from smarta2a.server import SmartA2A
6
+ from smarta2a.model_providers.base_llm_provider import BaseLLMProvider
7
+ from smarta2a.history_update_strategies.history_update_strategy import HistoryUpdateStrategy
8
+ from smarta2a.state_stores.base_state_store import BaseStateStore
9
+ from smarta2a.utils.types import StateData, SendTaskRequest
10
+
11
class A2AAgent:
    """Wire an LLM provider into a SmartA2A server.

    Incoming ``send_task`` requests are answered by passing the task's
    ``state.history`` to ``model_provider.generate``.
    """

    def __init__(
        self,
        name: str,
        model_provider: BaseLLMProvider,
        history_update_strategy: HistoryUpdateStrategy,
        state_storage: BaseStateStore,
    ):
        """Create the underlying SmartA2A app and register the request handlers.

        Args:
            name: Name forwarded to ``SmartA2A``.
            model_provider: Provider whose ``generate`` coroutine produces replies.
            history_update_strategy: Forwarded to ``SmartA2A``.
            state_storage: Forwarded to ``SmartA2A``.
        """
        self.model_provider = model_provider
        self.app = SmartA2A(
            name=name,
            history_update_strategy=history_update_strategy,
            state_storage=state_storage,
        )
        self.__register_handlers()

    def __register_handlers(self):
        """Attach the ``send_task`` handler to ``self.app``."""

        @self.app.on_send_task()
        async def on_send_task(request: SendTaskRequest, state: StateData):
            # BaseLLMProvider.generate is declared ``async``; without ``await``
            # the handler would hand back an un-awaited coroutine object
            # instead of the generated text.
            response = await self.model_provider.generate(state.history)
            return response

    def start(self, **kwargs):
        """Apply ``kwargs`` via ``self.app.configure`` and then run the app."""
        self.app.configure(**kwargs)
        self.app.run()
36
+
37
+
38
+
@@ -0,0 +1,37 @@
1
+ # Library imports
2
+ from mcp.server.fastmcp import FastMCP, Context
3
+
4
class A2AMCPServer:
    """FastMCP server meant to expose A2A task operations as MCP tools.

    In this version the tool bodies and ``start`` are placeholders that return
    ``None``.
    """

    def __init__(self):
        self.mcp = FastMCP("A2AMCPServer")
        self._register_handlers()

    def _register_handlers(self):
        """Register the session-id middleware and the placeholder tools on ``self.mcp``."""
        server = self.mcp

        @server.middleware
        async def extract_session_id(request, call_next):
            # Copy the optional "x-session-id" header onto request.state
            # (None when the header is absent) so handlers can read it later.
            request.state.session_id = request.headers.get("x-session-id")
            return await call_next(request)

        # NOTE: the tool functions deliberately carry no docstrings; the tool
        # decorator may use them as tool descriptions.
        @server.tool()
        def send_task(ctx: Context, url: str, message: str):
            # Reads the id stored by the middleware; not used yet.
            _session_id = ctx.request.state.session_id

        @server.tool()
        def get_task():
            return None

        @server.tool()
        def cancel_task():
            return None

    def start(self):
        """Placeholder: currently does nothing and returns None."""
        return None
@@ -0,0 +1,86 @@
1
+ # Library imports
2
+ import re
3
+ from contextlib import AsyncExitStack
4
+ from mcp.client import ClientSession, sse_client, stdio_client, StdioServerParameters
5
+
6
+
7
class MCPClient:
    """Client for one MCP server, reached over SSE (http/https URL) or stdio.

    Usage::

        client = MCPClient()
        await client.connect("http://localhost:8000/sse")  # or a script path / npm package
        tools = await client.list_tools()
        ...
        await client.cleanup()
    """

    _URL_PATTERN = re.compile(r'^https?://')

    def __init__(self):
        # Connecting is asynchronous and needs a target, so it cannot happen in
        # __init__; call ``await connect(...)`` after construction.
        self.session = None
        self.exit_stack = AsyncExitStack()

    async def connect(self, server_path_or_url: str):
        """Connect to ``server_path_or_url`` (URL -> SSE, otherwise stdio)."""
        await self._connect_to_server(server_path_or_url)

    async def _connect_to_sse_server(self, server_url: str):
        """Connect to an SSE MCP server and initialize the session."""
        # Both contexts go through the exit stack so cleanup() closes them in
        # reverse order: the session first, then the SSE streams.
        streams = await self.exit_stack.enter_async_context(sse_client(url=server_url))
        self.session = await self.exit_stack.enter_async_context(ClientSession(*streams))
        await self.session.initialize()

    @staticmethod
    def _resolve_stdio_command(server_script_path: str) -> tuple[str, list[str]]:
        """Return the ``(command, args)`` used to launch a stdio MCP server.

        ``@scope/...`` names run via ``npx``; ``.py`` files run with ``python``;
        ``.js`` files run with ``node``; any other name without a path
        separator is treated as an npm package and run via ``npx``.

        Raises:
            ValueError: for a path that is not ``.py``/``.js`` and not a package name.
        """
        args = [server_script_path]
        if server_script_path.startswith("@"):
            return "npx", args
        # Check extensions before the "no path separator" rule so a bare
        # "server.py" in the working directory is not mistaken for a package.
        if server_script_path.endswith(".py"):
            return "python", args
        if server_script_path.endswith(".js"):
            return "node", args
        if "/" not in server_script_path and "\\" not in server_script_path:
            return "npx", args
        raise ValueError("Server script must be a .py, .js file or npm package.")

    async def _connect_to_stdio_server(self, server_script_path: str):
        """Launch a stdio MCP server as a subprocess and initialize the session."""
        command, args = self._resolve_stdio_command(server_script_path)
        server_params = StdioServerParameters(command=command, args=args, env=None)

        stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
        self.stdio, self.writer = stdio_transport
        self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.writer))
        await self.session.initialize()

    async def _connect_to_server(self, server_path_or_url: str):
        """Connect to an MCP server: SSE for http(s) URLs, stdio otherwise."""
        if self._URL_PATTERN.match(server_path_or_url):
            await self._connect_to_sse_server(server_path_or_url)
        else:
            await self._connect_to_stdio_server(server_path_or_url)

    def _require_session(self):
        """Return the active session or raise if connect() has not completed."""
        if self.session is None:
            raise RuntimeError("MCPClient is not connected; await connect(...) first.")
        return self.session

    async def list_tools(self):
        """List the tools exposed by the connected server."""
        response = await self._require_session().list_tools()
        return response.tools

    async def call_tool(self, tool_name: str, **tool_args):
        """Call ``tool_name`` with ``tool_args`` as its arguments; return the result content."""
        # ClientSession.call_tool takes the tool's arguments as one mapping,
        # not as individual keyword arguments.
        response = await self._require_session().call_tool(tool_name, arguments=tool_args)
        return response.content

    async def cleanup(self):
        """Close the session and its transport, leaving the client reusable."""
        await self.exit_stack.aclose()
        self.session = None
        self.exit_stack = AsyncExitStack()
@@ -1,11 +1,13 @@
1
1
  # Library imports
2
- from typing import Any, Literal, AsyncIterable
2
+ from typing import Any, Literal, AsyncIterable, get_origin, get_args
3
3
  import httpx
4
4
  import json
5
5
  from httpx_sse import connect_sse
6
+ from inspect import signature, Parameter, iscoroutinefunction
7
+ from pydantic import create_model, Field, BaseModel
6
8
 
7
9
  # Local imports
8
- from smarta2a.common.types import (
10
+ from smarta2a.utils.types import (
9
11
  PushNotificationConfig,
10
12
  SendTaskStreamingResponse,
11
13
  SendTaskResponse,
@@ -21,7 +23,7 @@ from smarta2a.common.types import (
21
23
  SetTaskPushNotificationResponse,
22
24
  GetTaskPushNotificationResponse,
23
25
  )
24
- from smarta2a.common.task_request_builder import TaskRequestBuilder
26
+ from smarta2a.utils.task_request_builder import TaskRequestBuilder
25
27
 
26
28
 
27
29
  class A2AClient:
@@ -47,6 +49,7 @@ class A2AClient:
47
49
  history_length: int | None = None,
48
50
  metadata: dict[str, Any] | None = None,
49
51
  ):
52
+ """Send a task to another Agent"""
50
53
  params = TaskRequestBuilder.build_send_task_request(
51
54
  id=id,
52
55
  role=role,
@@ -76,6 +79,7 @@ class A2AClient:
76
79
  history_length: int | None = None,
77
80
  metadata: dict[str, Any] | None = None,
78
81
  ):
82
+ """Send to another Agent and receive a stream of responses"""
79
83
  params = TaskRequestBuilder.build_send_task_request(
80
84
  id=id,
81
85
  role=role,
@@ -108,6 +112,7 @@ class A2AClient:
108
112
  history_length: int | None = None,
109
113
  metadata: dict[str, Any] | None = None,
110
114
  ) -> GetTaskResponse:
115
+ """Get a task from another Agent"""
111
116
  req = TaskRequestBuilder.get_task(id, history_length, metadata)
112
117
  raw = await self._send_request(req)
113
118
  return GetTaskResponse(**raw)
@@ -118,6 +123,7 @@ class A2AClient:
118
123
  id: str,
119
124
  metadata: dict[str, Any] | None = None,
120
125
  ) -> CancelTaskResponse:
126
+ """Cancel a task from another Agent"""
121
127
  req = TaskRequestBuilder.cancel_task(id, metadata)
122
128
  raw = await self._send_request(req)
123
129
  return CancelTaskResponse(**raw)
@@ -130,6 +136,7 @@ class A2AClient:
130
136
  token: str | None = None,
131
137
  authentication: AuthenticationInfo | dict[str, Any] | None = None,
132
138
  ) -> SetTaskPushNotificationResponse:
139
+ """Set a push notification for a task"""
133
140
  req = TaskRequestBuilder.set_push_notification(id, url, token, authentication)
134
141
  raw = await self._send_request(req)
135
142
  return SetTaskPushNotificationResponse(**raw)
@@ -140,6 +147,7 @@ class A2AClient:
140
147
  id: str,
141
148
  metadata: dict[str, Any] | None = None,
142
149
  ) -> GetTaskPushNotificationResponse:
150
+ """Get a push notification for a task"""
143
151
  req = TaskRequestBuilder.get_push_notification(id, metadata)
144
152
  raw = await self._send_request(req)
145
153
  return GetTaskPushNotificationResponse(**raw)
@@ -171,3 +179,89 @@ class A2AClient:
171
179
  raise A2AClientJSONError(str(e)) from e
172
180
  except httpx.RequestError as e:
173
181
  raise A2AClientHTTPError(400, str(e)) from e
182
+
183
+
184
def _exposed_tool_names(self) -> tuple[str, ...]:
    """Names of the client methods advertised by list_tools and callable via call_tool."""
    return ('send',)

def _build_input_model(self, method, model_name: str):
    """Build a pydantic model whose fields mirror ``method``'s parameters.

    Unannotated parameters are typed ``Any``; parameters without a default are
    required. ``self`` and ``*args``/``**kwargs`` are skipped because they
    cannot be expressed as named fields.
    """
    fields = {}
    for param_name, param in signature(method).parameters.items():
        if param_name == 'self' or param.kind in (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD):
            continue
        annotation = Any if param.annotation is Parameter.empty else param.annotation
        if param.default is Parameter.empty:
            fields[param_name] = (annotation, Field(...))
        else:
            fields[param_name] = (annotation, Field(default=param.default))
    return create_model(model_name, **fields)

def list_tools(self) -> list[dict[str, Any]]:
    """Return metadata for all available tools.

    Each entry is a dict with ``name``, ``description`` (first line of the
    method's docstring) and ``input_schema`` (JSON schema of a pydantic model
    generated from the method signature).
    """
    tools = []
    for name in self._exposed_tool_names():
        method = getattr(self, name)
        doc = (method.__doc__ or "").strip()
        description = doc.split('\n')[0] if doc else ""

        model = self._build_input_model(method, f"{name}_Input")
        # pydantic v2 renamed .schema() to .model_json_schema(); support both.
        if hasattr(model, "model_json_schema"):
            schema = model.model_json_schema()
        else:
            schema = model.schema()

        tools.append({
            'name': name,
            'description': description,
            'input_schema': schema,
        })
    return tools

async def call_tool(self, tool_name: str, arguments: dict[str, Any]) -> Any:
    """Call an advertised tool by name after validating ``arguments``.

    Raises:
        ValueError: if ``tool_name`` is not one of the tools from list_tools.
        pydantic.ValidationError: if ``arguments`` do not match the signature.
    """
    # Only advertised tools are callable. Tool names may come from untrusted
    # input (e.g. model output); a bare hasattr() check would expose every
    # attribute, including private helpers and call_tool itself.
    if tool_name not in self._exposed_tool_names():
        raise ValueError(f"Tool {tool_name} not found")
    method = getattr(self, tool_name)

    # Validate with the same schema that list_tools advertises.
    model = self._build_input_model(method, f"{tool_name}_ValidationModel")
    validated = model(**(arguments or {}))
    # dict(model) is shallow: nested models (e.g. AuthenticationInfo) keep
    # their type, whereas .dict() would turn them back into plain dicts.
    validated_args = dict(validated)

    if iscoroutinefunction(method):
        return await method(**validated_args)
    # Note: Synchronous methods (like subscribe) will block the event loop
    return method(**validated_args)
@@ -0,0 +1,60 @@
1
+ from mcp.client import Client
2
+ from typing import Optional, Dict, Any
3
+
4
class SmartMCPClient:
    """MCP client that opens a fresh connection for every call.

    Each method accepts an optional ``session_id``; when it is non-empty it is
    sent as the ``x-session-id`` header of that call only.
    """

    def __init__(self, base_url: str):
        """
        Store the server URL; headers are built per call instead of being kept on the instance.
        """
        self.base_url = base_url

    def _open(self, session_id: Optional[str]):
        # One short-lived client per request, so every call can carry its own headers.
        return Client(self.base_url, headers=self._build_headers(session_id))

    async def list_tools(self, session_id: Optional[str] = None) -> Any:
        """Return the server's tool listing."""
        async with self._open(session_id) as conn:
            return await conn.list_tools()

    async def call_tool(self, tool_name: str, arguments: Dict[str, Any], session_id: Optional[str] = None) -> Any:
        """Invoke ``tool_name`` with ``arguments``."""
        async with self._open(session_id) as conn:
            return await conn.call_tool(tool_name, arguments)

    async def list_resources(self, session_id: Optional[str] = None) -> Any:
        """Return the server's resource listing."""
        async with self._open(session_id) as conn:
            return await conn.list_resources()

    async def read_resource(self, resource_uri: str, session_id: Optional[str] = None) -> Any:
        """Read the resource identified by ``resource_uri``."""
        async with self._open(session_id) as conn:
            return await conn.read_resource(resource_uri)

    async def get_prompt(self, prompt_name: str, arguments: Dict[str, Any], session_id: Optional[str] = None) -> Any:
        """Fetch prompt ``prompt_name`` rendered with ``arguments``."""
        async with self._open(session_id) as conn:
            return await conn.get_prompt(prompt_name, arguments)

    async def ping(self, session_id: Optional[str] = None) -> Any:
        """Ping the server."""
        async with self._open(session_id) as conn:
            return await conn.ping()

    def _build_headers(self, session_id: Optional[str]) -> Dict[str, str]:
        """Return the per-request headers: the session id header if one was given, else none."""
        return {"x-session-id": session_id} if session_id else {}
@@ -0,0 +1,58 @@
1
+ # Library imports
2
+ import json
3
+ from typing import List, Dict, Any, Union, Literal
4
+
5
+ # Local imports
6
+ from smarta2a.client.smart_mcp_client import SmartMCPClient
7
+ from smarta2a.client.a2a_client import A2AClient
8
+ from smarta2a.utils.types import AgentCard
9
+
10
class ToolsManager:
    """
    Manages loading, describing, and invoking tools from various providers.
    Acts as a wrapper around the MCP and A2A clients.

    ``tools_list`` holds tool descriptions in load order. ``clients`` maps each
    tool name to the client serving it; a later tool with the same name
    replaces the earlier mapping.
    """

    def __init__(self):
        self.tools_list: List[Any] = []
        self.clients: Dict[str, Union[SmartMCPClient, A2AClient]] = {}

    @staticmethod
    def _tool_field(tool: Any, *keys: str, default: Any = None) -> Any:
        """Return the first present field among ``keys``; ``tool`` may be a dict or an object."""
        for key in keys:
            if isinstance(tool, dict):
                if key in tool:
                    return tool[key]
            elif hasattr(tool, key):
                return getattr(tool, key)
        return default

    def _register(self, tool: Any, client: Union[SmartMCPClient, A2AClient]) -> None:
        """Record ``tool`` and route its name to ``client``."""
        self.tools_list.append(tool)
        self.clients[self._tool_field(tool, "name")] = client

    async def aload_mcp_tools(self, urls_or_paths: List[str]) -> None:
        """Fetch and register the tools of every MCP server in ``urls_or_paths``."""
        for url in urls_or_paths:
            mcp_client = SmartMCPClient(url)
            listing = await mcp_client.list_tools()
            # Accept a plain list of tools or a result object exposing
            # ``.tools``; the shape depends on the MCP client library (assumption).
            for tool in getattr(listing, "tools", listing):
                self._register(tool, mcp_client)

    def load_mcp_tools(self, urls_or_paths: List[str]) -> None:
        """Synchronous wrapper around ``aload_mcp_tools``.

        ``SmartMCPClient.list_tools`` is a coroutine and must be awaited, so
        this wrapper drives it with ``asyncio.run``.

        Raises:
            RuntimeError: if called while an event loop is running (await
                ``aload_mcp_tools`` there instead).
        """
        import asyncio

        if not urls_or_paths:
            return
        try:
            asyncio.get_running_loop()
        except RuntimeError:
            pass  # No running loop: safe to start one below.
        else:
            raise RuntimeError(
                "load_mcp_tools() cannot be called from a running event loop; "
                "use 'await aload_mcp_tools(...)' instead"
            )
        asyncio.run(self.aload_mcp_tools(urls_or_paths))

    def load_a2a_tools(self, agent_cards: List[AgentCard]) -> None:
        """Register the tools advertised by an A2AClient for each agent card."""
        for agent_card in agent_cards:
            a2a_client = A2AClient(agent_card)
            # A2AClient.list_tools returns plain dicts, hence the dict-aware lookup.
            for tool in a2a_client.list_tools():
                self._register(tool, a2a_client)

    def get_tools(self) -> List[Any]:
        """Return all loaded tool descriptions."""
        return self.tools_list

    def describe_tools(self, client_type: Literal["mcp", "a2a"]) -> str:
        """Return a Markdown bullet list of the loaded tools served by ``client_type`` clients.

        A tool's type is taken from the client registered for its name; an
        unknown ``client_type`` yields an empty string.
        """
        client_class = {"mcp": SmartMCPClient, "a2a": A2AClient}.get(client_type)
        if client_class is None:
            return ""

        lines = []
        for tool in self.tools_list:
            name = self._tool_field(tool, "name")
            if not isinstance(self.clients.get(name), client_class):
                continue
            description = self._tool_field(tool, "description", default="")
            # A2AClient.list_tools uses "input_schema"; MCP tool objects may
            # name it "inputSchema" instead.
            schema = json.dumps(
                self._tool_field(tool, "input_schema", "inputSchema", default={}),
                indent=2,
            )
            lines.append(
                f"- **{name}**: {description}\n Parameters schema:\n ```json\n{schema}\n```"
            )
        return "\n".join(lines)

    def get_client(self, tool_name: str) -> Any:
        """Return the client registered for ``tool_name``, or None."""
        return self.clients.get(tool_name)

    async def call_tool(self, tool_name: str, args: Dict[str, Any]) -> Any:
        """Dispatch ``tool_name`` with ``args`` to its registered client.

        Raises:
            ValueError: if no client is registered for ``tool_name``.
        """
        client = self.get_client(tool_name)
        if not client:
            raise ValueError(f"Tool not found: {tool_name}")
        return await client.call_tool(tool_name, args)
@@ -0,0 +1,8 @@
1
+ """
2
+ Strategies for updating conversation history.
3
+ """
4
+
5
+ from .history_update_strategy import HistoryUpdateStrategy
6
+ from .append_strategy import AppendStrategy
7
+
8
+ __all__ = ['HistoryUpdateStrategy', 'AppendStrategy']
@@ -0,0 +1,10 @@
1
+ # Library imports
2
+ from typing import List
3
+
4
+ # Local imports
5
+ from smarta2a.utils.types import Message
6
+
7
class AppendStrategy:
    """History strategy that keeps everything: new messages follow the existing ones."""

    def update_history(self, existing_history: List[Message], new_messages: List[Message]) -> List[Message]:
        # For list inputs, ``+`` builds a new list and leaves both arguments untouched.
        merged = existing_history + new_messages
        return merged
@@ -0,0 +1,15 @@
1
+ # Library imports
2
+ from typing import Protocol, List
3
+ from typing import List
4
+
5
+ # Local imports
6
+ from smarta2a.utils.types import Message
7
+
8
class HistoryUpdateStrategy(Protocol):
    """Structural interface for objects that compute an updated conversation history."""

    def update_history(
        self,
        existing_history: List[Message],
        new_messages: List[Message],
    ) -> List[Message]:
        """Return the history produced by combining ``existing_history`` with ``new_messages``."""
        ...
@@ -0,0 +1,5 @@
1
+ """
2
+ Model provider implementations for different AI models.
3
+ """
4
+
5
+ # This package is currently empty but will contain model provider implementations
@@ -0,0 +1,15 @@
1
+ # Library imports
2
+ from abc import ABC, abstractmethod
3
+ from typing import AsyncGenerator, List
4
+
5
+ # Local imports
6
+ from smarta2a.utils.types import Message
7
+
8
class BaseLLMProvider(ABC):
    """Abstract interface implemented by every model provider."""

    @abstractmethod
    async def generate(self, messages: List[Message], **kwargs) -> str:
        """Return the generated text for ``messages`` (awaited by callers)."""
        ...

    @abstractmethod
    def generate_stream(self, messages: List[Message], **kwargs) -> AsyncGenerator[str, None]:
        """Yield generated text for ``messages`` incrementally.

        Declared with plain ``def`` on purpose. Implementations are expected to
        be ``async def`` functions containing ``yield`` (async generators),
        which callers consume directly with ``async for``. An abstract
        ``async def`` without ``yield`` would instead type as a coroutine that
        returns a generator, which must be awaited before iterating.
        """
        ...