agno 2.3.21__py3-none-any.whl → 2.3.23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agno/agent/agent.py +48 -2
- agno/agent/remote.py +234 -73
- agno/client/a2a/__init__.py +10 -0
- agno/client/a2a/client.py +554 -0
- agno/client/a2a/schemas.py +112 -0
- agno/client/a2a/utils.py +369 -0
- agno/db/migrations/utils.py +19 -0
- agno/db/migrations/v1_to_v2.py +54 -16
- agno/db/migrations/versions/v2_3_0.py +92 -53
- agno/db/mysql/async_mysql.py +5 -7
- agno/db/mysql/mysql.py +5 -7
- agno/db/mysql/schemas.py +39 -21
- agno/db/postgres/async_postgres.py +172 -42
- agno/db/postgres/postgres.py +186 -38
- agno/db/postgres/schemas.py +39 -21
- agno/db/postgres/utils.py +6 -2
- agno/db/singlestore/schemas.py +41 -21
- agno/db/singlestore/singlestore.py +14 -3
- agno/db/sqlite/async_sqlite.py +7 -2
- agno/db/sqlite/schemas.py +36 -21
- agno/db/sqlite/sqlite.py +3 -7
- agno/knowledge/chunking/document.py +3 -2
- agno/knowledge/chunking/markdown.py +8 -3
- agno/knowledge/chunking/recursive.py +2 -2
- agno/models/base.py +4 -0
- agno/models/google/gemini.py +27 -4
- agno/models/openai/chat.py +1 -1
- agno/models/openai/responses.py +14 -7
- agno/os/middleware/jwt.py +66 -27
- agno/os/routers/agents/router.py +3 -3
- agno/os/routers/evals/evals.py +2 -2
- agno/os/routers/knowledge/knowledge.py +5 -5
- agno/os/routers/knowledge/schemas.py +1 -1
- agno/os/routers/memory/memory.py +4 -4
- agno/os/routers/session/session.py +2 -2
- agno/os/routers/teams/router.py +4 -4
- agno/os/routers/traces/traces.py +3 -3
- agno/os/routers/workflows/router.py +3 -3
- agno/os/schema.py +1 -1
- agno/reasoning/deepseek.py +11 -1
- agno/reasoning/gemini.py +6 -2
- agno/reasoning/groq.py +8 -3
- agno/reasoning/openai.py +2 -0
- agno/remote/base.py +106 -9
- agno/skills/__init__.py +17 -0
- agno/skills/agent_skills.py +370 -0
- agno/skills/errors.py +32 -0
- agno/skills/loaders/__init__.py +4 -0
- agno/skills/loaders/base.py +27 -0
- agno/skills/loaders/local.py +216 -0
- agno/skills/skill.py +65 -0
- agno/skills/utils.py +107 -0
- agno/skills/validator.py +277 -0
- agno/team/remote.py +220 -60
- agno/team/team.py +41 -3
- agno/tools/brandfetch.py +27 -18
- agno/tools/browserbase.py +150 -13
- agno/tools/function.py +6 -1
- agno/tools/mcp/mcp.py +300 -17
- agno/tools/mcp/multi_mcp.py +269 -14
- agno/tools/toolkit.py +89 -21
- agno/utils/mcp.py +49 -8
- agno/utils/string.py +43 -1
- agno/workflow/condition.py +4 -2
- agno/workflow/loop.py +20 -1
- agno/workflow/remote.py +173 -33
- agno/workflow/router.py +4 -1
- agno/workflow/steps.py +4 -0
- agno/workflow/workflow.py +14 -0
- {agno-2.3.21.dist-info → agno-2.3.23.dist-info}/METADATA +13 -14
- {agno-2.3.21.dist-info → agno-2.3.23.dist-info}/RECORD +74 -60
- {agno-2.3.21.dist-info → agno-2.3.23.dist-info}/WHEEL +0 -0
- {agno-2.3.21.dist-info → agno-2.3.23.dist-info}/licenses/LICENSE +0 -0
- {agno-2.3.21.dist-info → agno-2.3.23.dist-info}/top_level.txt +0 -0
agno/tools/mcp/multi_mcp.py
CHANGED
|
@@ -1,9 +1,12 @@
|
|
|
1
|
+
import inspect
|
|
2
|
+
import time
|
|
3
|
+
import warnings
|
|
1
4
|
import weakref
|
|
2
5
|
from contextlib import AsyncExitStack
|
|
3
6
|
from dataclasses import asdict
|
|
4
7
|
from datetime import timedelta
|
|
5
8
|
from types import TracebackType
|
|
6
|
-
from typing import List, Literal, Optional, Union
|
|
9
|
+
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union
|
|
7
10
|
|
|
8
11
|
from agno.tools import Toolkit
|
|
9
12
|
from agno.tools.function import Function
|
|
@@ -11,6 +14,11 @@ from agno.tools.mcp.params import SSEClientParams, StreamableHTTPClientParams
|
|
|
11
14
|
from agno.utils.log import log_debug, log_error, log_info, log_warning
|
|
12
15
|
from agno.utils.mcp import get_entrypoint_for_tool, prepare_command
|
|
13
16
|
|
|
17
|
+
if TYPE_CHECKING:
|
|
18
|
+
from agno.agent import Agent
|
|
19
|
+
from agno.run import RunContext
|
|
20
|
+
from agno.team.team import Team
|
|
21
|
+
|
|
14
22
|
try:
|
|
15
23
|
from mcp import ClientSession, StdioServerParameters
|
|
16
24
|
from mcp.client.sse import sse_client
|
|
@@ -47,6 +55,7 @@ class MultiMCPTools(Toolkit):
|
|
|
47
55
|
exclude_tools: Optional[list[str]] = None,
|
|
48
56
|
refresh_connection: bool = False,
|
|
49
57
|
allow_partial_failure: bool = False,
|
|
58
|
+
header_provider: Optional[Callable[..., dict[str, Any]]] = None,
|
|
50
59
|
**kwargs,
|
|
51
60
|
):
|
|
52
61
|
"""
|
|
@@ -64,7 +73,14 @@ class MultiMCPTools(Toolkit):
|
|
|
64
73
|
exclude_tools: Optional list of tool names to exclude (if None, excludes none).
|
|
65
74
|
allow_partial_failure: If True, allows toolkit to initialize even if some MCP servers fail to connect. If False, any failure will raise an exception.
|
|
66
75
|
refresh_connection: If True, the connection and tools will be refreshed on each run
|
|
76
|
+
header_provider: Header provider function for all servers. Takes RunContext and returns dict of HTTP headers.
|
|
67
77
|
"""
|
|
78
|
+
warnings.warn(
|
|
79
|
+
"The MultiMCPTools class is deprecated and will be removed in a future version. Please use multiple MCPTools instances instead.",
|
|
80
|
+
DeprecationWarning,
|
|
81
|
+
stacklevel=2,
|
|
82
|
+
)
|
|
83
|
+
|
|
68
84
|
super().__init__(name="MultiMCPTools", **kwargs)
|
|
69
85
|
|
|
70
86
|
if urls_transports is not None:
|
|
@@ -86,6 +102,16 @@ class MultiMCPTools(Toolkit):
|
|
|
86
102
|
self.exclude_tools = exclude_tools
|
|
87
103
|
self.refresh_connection = refresh_connection
|
|
88
104
|
|
|
105
|
+
self.header_provider = header_provider
|
|
106
|
+
|
|
107
|
+
# Validate header_provider signature
|
|
108
|
+
if header_provider:
|
|
109
|
+
try:
|
|
110
|
+
# Just verify we can inspect the signature - no parameter requirements
|
|
111
|
+
inspect.signature(header_provider)
|
|
112
|
+
except Exception as e:
|
|
113
|
+
log_warning(f"Could not validate header_provider signature: {e}")
|
|
114
|
+
|
|
89
115
|
if server_params_list is None and commands is None and urls is None:
|
|
90
116
|
raise ValueError("Either server_params_list or commands or urls must be provided")
|
|
91
117
|
|
|
@@ -130,6 +156,14 @@ class MultiMCPTools(Toolkit):
|
|
|
130
156
|
self._connection_task = None
|
|
131
157
|
self._successful_connections = 0
|
|
132
158
|
self._sessions: list[ClientSession] = []
|
|
159
|
+
self._session_to_server_idx: Dict[int, int] = {} # Maps session list index to server params index
|
|
160
|
+
|
|
161
|
+
# Session management for per-agent-run sessions with dynamic headers
|
|
162
|
+
# For MultiMCP, we track sessions per (run_id, server_idx) since we have multiple servers
|
|
163
|
+
# Maps (run_id, server_idx) to (session, timestamp) for TTL-based cleanup
|
|
164
|
+
self._run_sessions: Dict[Tuple[str, int], Tuple[ClientSession, float]] = {}
|
|
165
|
+
self._run_session_contexts: Dict[Tuple[str, int], Any] = {} # Maps (run_id, server_idx) to context managers
|
|
166
|
+
self._session_ttl_seconds: float = 300.0 # 5 minutes default TTL
|
|
133
167
|
|
|
134
168
|
self.allow_partial_failure = allow_partial_failure
|
|
135
169
|
|
|
@@ -153,6 +187,205 @@ class MultiMCPTools(Toolkit):
|
|
|
153
187
|
except (RuntimeError, BaseException):
|
|
154
188
|
return False
|
|
155
189
|
|
|
190
|
+
def _call_header_provider(
|
|
191
|
+
self,
|
|
192
|
+
run_context: Optional["RunContext"] = None,
|
|
193
|
+
agent: Optional["Agent"] = None,
|
|
194
|
+
team: Optional["Team"] = None,
|
|
195
|
+
) -> dict[str, Any]:
|
|
196
|
+
"""Call the header_provider with run_context, agent, and/or team based on its signature.
|
|
197
|
+
|
|
198
|
+
Args:
|
|
199
|
+
run_context: The RunContext for the current agent run
|
|
200
|
+
agent: The Agent instance (if running within an agent)
|
|
201
|
+
team: The Team instance (if running within a team)
|
|
202
|
+
|
|
203
|
+
Returns:
|
|
204
|
+
dict[str, Any]: The headers returned by the header_provider
|
|
205
|
+
"""
|
|
206
|
+
header_provider = getattr(self, "header_provider", None)
|
|
207
|
+
if header_provider is None:
|
|
208
|
+
return {}
|
|
209
|
+
|
|
210
|
+
try:
|
|
211
|
+
sig = inspect.signature(header_provider)
|
|
212
|
+
param_names = set(sig.parameters.keys())
|
|
213
|
+
|
|
214
|
+
# Build kwargs based on what the function accepts
|
|
215
|
+
call_kwargs: dict[str, Any] = {}
|
|
216
|
+
|
|
217
|
+
if "run_context" in param_names:
|
|
218
|
+
call_kwargs["run_context"] = run_context
|
|
219
|
+
if "agent" in param_names:
|
|
220
|
+
call_kwargs["agent"] = agent
|
|
221
|
+
if "team" in param_names:
|
|
222
|
+
call_kwargs["team"] = team
|
|
223
|
+
|
|
224
|
+
# Check if function accepts **kwargs (VAR_KEYWORD)
|
|
225
|
+
has_var_keyword = any(p.kind == inspect.Parameter.VAR_KEYWORD for p in sig.parameters.values())
|
|
226
|
+
|
|
227
|
+
if has_var_keyword:
|
|
228
|
+
# Pass all available context to **kwargs
|
|
229
|
+
call_kwargs = {"run_context": run_context, "agent": agent, "team": team}
|
|
230
|
+
return header_provider(**call_kwargs)
|
|
231
|
+
elif call_kwargs:
|
|
232
|
+
return header_provider(**call_kwargs)
|
|
233
|
+
else:
|
|
234
|
+
# Function takes no recognized parameters - check for positional
|
|
235
|
+
positional_params = [
|
|
236
|
+
p
|
|
237
|
+
for p in sig.parameters.values()
|
|
238
|
+
if p.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
|
|
239
|
+
]
|
|
240
|
+
if positional_params:
|
|
241
|
+
# Legacy support: pass run_context as first positional arg
|
|
242
|
+
return header_provider(run_context)
|
|
243
|
+
else:
|
|
244
|
+
# Function takes no parameters
|
|
245
|
+
return header_provider()
|
|
246
|
+
except Exception as e:
|
|
247
|
+
log_warning(f"Error calling header_provider: {e}")
|
|
248
|
+
return {}
|
|
249
|
+
|
|
250
|
+
async def _cleanup_stale_sessions(self) -> None:
|
|
251
|
+
"""Clean up sessions older than TTL to prevent memory leaks."""
|
|
252
|
+
if not self._run_sessions:
|
|
253
|
+
return
|
|
254
|
+
|
|
255
|
+
now = time.time()
|
|
256
|
+
stale_keys = [
|
|
257
|
+
cache_key
|
|
258
|
+
for cache_key, (_, created_at) in self._run_sessions.items()
|
|
259
|
+
if now - created_at > self._session_ttl_seconds
|
|
260
|
+
]
|
|
261
|
+
|
|
262
|
+
for run_id, server_idx in stale_keys:
|
|
263
|
+
log_debug(f"Cleaning up stale session for run_id={run_id}, server_idx={server_idx}")
|
|
264
|
+
await self.cleanup_run_session(run_id, server_idx)
|
|
265
|
+
|
|
266
|
+
async def get_session_for_run(
|
|
267
|
+
self,
|
|
268
|
+
run_context: Optional["RunContext"] = None,
|
|
269
|
+
server_idx: int = 0,
|
|
270
|
+
agent: Optional["Agent"] = None,
|
|
271
|
+
team: Optional["Team"] = None,
|
|
272
|
+
) -> ClientSession:
|
|
273
|
+
"""
|
|
274
|
+
Get or create a session for the given run_context and server index.
|
|
275
|
+
|
|
276
|
+
If header_provider is configured and run_context is provided, this creates
|
|
277
|
+
a new session with dynamic headers for this specific agent run and server.
|
|
278
|
+
|
|
279
|
+
Args:
|
|
280
|
+
run_context: The RunContext containing user_id, metadata, etc.
|
|
281
|
+
server_idx: Index of the server in self._sessions list
|
|
282
|
+
agent: The Agent instance (if running within an agent)
|
|
283
|
+
team: The Team instance (if running within a team)
|
|
284
|
+
|
|
285
|
+
Returns:
|
|
286
|
+
ClientSession: Either the default session or a per-run session with dynamic headers
|
|
287
|
+
"""
|
|
288
|
+
# If no header_provider or no run_context, use the default session
|
|
289
|
+
if not self.header_provider or not run_context:
|
|
290
|
+
# Return the default session for this server
|
|
291
|
+
if server_idx < len(self._sessions):
|
|
292
|
+
return self._sessions[server_idx]
|
|
293
|
+
raise ValueError(f"Server index {server_idx} out of range")
|
|
294
|
+
|
|
295
|
+
# Lazy cleanup of stale sessions
|
|
296
|
+
await self._cleanup_stale_sessions()
|
|
297
|
+
|
|
298
|
+
# Check if we already have a session for this (run_id, server_idx)
|
|
299
|
+
run_id = run_context.run_id
|
|
300
|
+
cache_key = (run_id, server_idx)
|
|
301
|
+
if cache_key in self._run_sessions:
|
|
302
|
+
session, _ = self._run_sessions[cache_key]
|
|
303
|
+
return session
|
|
304
|
+
|
|
305
|
+
# Create a new session with dynamic headers for this run and server
|
|
306
|
+
log_debug(f"Creating new session for run_id={run_id}, server_idx={server_idx} with dynamic headers")
|
|
307
|
+
|
|
308
|
+
# Generate dynamic headers from the provider
|
|
309
|
+
dynamic_headers = self._call_header_provider(run_context=run_context, agent=agent, team=team)
|
|
310
|
+
|
|
311
|
+
# Get the server params for this server index
|
|
312
|
+
if server_idx >= len(self.server_params_list):
|
|
313
|
+
raise ValueError(f"Server index {server_idx} out of range")
|
|
314
|
+
|
|
315
|
+
server_params = self.server_params_list[server_idx]
|
|
316
|
+
|
|
317
|
+
# Create new session with merged headers based on transport type
|
|
318
|
+
if isinstance(server_params, SSEClientParams):
|
|
319
|
+
params_dict = asdict(server_params)
|
|
320
|
+
existing_headers = params_dict.get("headers") or {}
|
|
321
|
+
params_dict["headers"] = {**existing_headers, **dynamic_headers}
|
|
322
|
+
|
|
323
|
+
context = sse_client(**params_dict) # type: ignore
|
|
324
|
+
client_timeout = min(self.timeout_seconds, params_dict.get("timeout", self.timeout_seconds))
|
|
325
|
+
|
|
326
|
+
elif isinstance(server_params, StreamableHTTPClientParams):
|
|
327
|
+
params_dict = asdict(server_params)
|
|
328
|
+
existing_headers = params_dict.get("headers") or {}
|
|
329
|
+
params_dict["headers"] = {**existing_headers, **dynamic_headers}
|
|
330
|
+
|
|
331
|
+
context = streamablehttp_client(**params_dict) # type: ignore
|
|
332
|
+
params_timeout = params_dict.get("timeout", self.timeout_seconds)
|
|
333
|
+
if isinstance(params_timeout, timedelta):
|
|
334
|
+
params_timeout = int(params_timeout.total_seconds())
|
|
335
|
+
client_timeout = min(self.timeout_seconds, params_timeout)
|
|
336
|
+
else:
|
|
337
|
+
# stdio doesn't support headers, fall back to default session
|
|
338
|
+
log_warning(
|
|
339
|
+
f"Cannot use dynamic headers with stdio transport for server {server_idx}, using default session"
|
|
340
|
+
)
|
|
341
|
+
if server_idx < len(self._sessions):
|
|
342
|
+
return self._sessions[server_idx]
|
|
343
|
+
raise ValueError(f"Server index {server_idx} out of range")
|
|
344
|
+
|
|
345
|
+
# Enter the context and create session
|
|
346
|
+
session_params = await context.__aenter__() # type: ignore
|
|
347
|
+
read, write = session_params[0:2]
|
|
348
|
+
|
|
349
|
+
session_context = ClientSession(read, write, read_timeout_seconds=timedelta(seconds=client_timeout)) # type: ignore
|
|
350
|
+
session = await session_context.__aenter__() # type: ignore
|
|
351
|
+
|
|
352
|
+
# Initialize the session
|
|
353
|
+
await session.initialize()
|
|
354
|
+
|
|
355
|
+
# Store the session with timestamp and context for cleanup
|
|
356
|
+
self._run_sessions[cache_key] = (session, time.time())
|
|
357
|
+
self._run_session_contexts[cache_key] = (context, session_context)
|
|
358
|
+
|
|
359
|
+
return session
|
|
360
|
+
|
|
361
|
+
async def cleanup_run_session(self, run_id: str, server_idx: int) -> None:
|
|
362
|
+
"""Clean up a per-run session."""
|
|
363
|
+
cache_key = (run_id, server_idx)
|
|
364
|
+
if cache_key not in self._run_sessions:
|
|
365
|
+
return
|
|
366
|
+
|
|
367
|
+
try:
|
|
368
|
+
context, session_context = self._run_session_contexts[cache_key]
|
|
369
|
+
|
|
370
|
+
# Exit session context - silently ignore errors
|
|
371
|
+
try:
|
|
372
|
+
await session_context.__aexit__(None, None, None)
|
|
373
|
+
except (RuntimeError, Exception):
|
|
374
|
+
pass # Silently ignore
|
|
375
|
+
|
|
376
|
+
# Exit transport context - silently ignore errors
|
|
377
|
+
try:
|
|
378
|
+
await context.__aexit__(None, None, None)
|
|
379
|
+
except (RuntimeError, Exception):
|
|
380
|
+
pass # Silently ignore
|
|
381
|
+
|
|
382
|
+
except Exception:
|
|
383
|
+
pass # Silently ignore all cleanup errors
|
|
384
|
+
finally:
|
|
385
|
+
# Remove from cache
|
|
386
|
+
self._run_sessions.pop(cache_key, None)
|
|
387
|
+
self._run_session_contexts.pop(cache_key, None)
|
|
388
|
+
|
|
156
389
|
async def connect(self, force: bool = False):
|
|
157
390
|
"""Initialize a MultiMCPTools instance and connect to the MCP servers"""
|
|
158
391
|
|
|
@@ -214,7 +447,7 @@ class MultiMCPTools(Toolkit):
|
|
|
214
447
|
|
|
215
448
|
server_connection_errors = []
|
|
216
449
|
|
|
217
|
-
for server_params in self.server_params_list:
|
|
450
|
+
for server_idx, server_params in enumerate(self.server_params_list):
|
|
218
451
|
try:
|
|
219
452
|
# Handle stdio connections
|
|
220
453
|
if isinstance(server_params, StdioServerParameters):
|
|
@@ -223,7 +456,7 @@ class MultiMCPTools(Toolkit):
|
|
|
223
456
|
session = await self._async_exit_stack.enter_async_context(
|
|
224
457
|
ClientSession(read, write, read_timeout_seconds=timedelta(seconds=self.timeout_seconds))
|
|
225
458
|
)
|
|
226
|
-
await self.initialize(session)
|
|
459
|
+
await self.initialize(session, server_idx)
|
|
227
460
|
self._successful_connections += 1
|
|
228
461
|
|
|
229
462
|
# Handle SSE connections
|
|
@@ -233,7 +466,7 @@ class MultiMCPTools(Toolkit):
|
|
|
233
466
|
)
|
|
234
467
|
read, write = client_connection
|
|
235
468
|
session = await self._async_exit_stack.enter_async_context(ClientSession(read, write))
|
|
236
|
-
await self.initialize(session)
|
|
469
|
+
await self.initialize(session, server_idx)
|
|
237
470
|
self._successful_connections += 1
|
|
238
471
|
|
|
239
472
|
# Handle Streamable HTTP connections
|
|
@@ -243,7 +476,7 @@ class MultiMCPTools(Toolkit):
|
|
|
243
476
|
)
|
|
244
477
|
read, write = client_connection[0:2]
|
|
245
478
|
session = await self._async_exit_stack.enter_async_context(ClientSession(read, write))
|
|
246
|
-
await self.initialize(session)
|
|
479
|
+
await self.initialize(session, server_idx)
|
|
247
480
|
self._successful_connections += 1
|
|
248
481
|
|
|
249
482
|
except Exception as e:
|
|
@@ -268,13 +501,26 @@ class MultiMCPTools(Toolkit):
|
|
|
268
501
|
if not self._initialized:
|
|
269
502
|
return
|
|
270
503
|
|
|
271
|
-
|
|
272
|
-
await self._async_exit_stack.aclose()
|
|
273
|
-
self._sessions = []
|
|
274
|
-
self._successful_connections = 0
|
|
504
|
+
import warnings
|
|
275
505
|
|
|
276
|
-
|
|
277
|
-
|
|
506
|
+
# Suppress async generator cleanup warnings
|
|
507
|
+
with warnings.catch_warnings():
|
|
508
|
+
warnings.filterwarnings("ignore", category=RuntimeWarning, message=".*async_generator.*")
|
|
509
|
+
warnings.filterwarnings("ignore", message=".*cancel scope.*")
|
|
510
|
+
|
|
511
|
+
try:
|
|
512
|
+
# Clean up all per-run sessions first
|
|
513
|
+
cache_keys = list(self._run_sessions.keys())
|
|
514
|
+
for run_id, server_idx in cache_keys:
|
|
515
|
+
await self.cleanup_run_session(run_id, server_idx)
|
|
516
|
+
|
|
517
|
+
# Clean up main sessions
|
|
518
|
+
await self._async_exit_stack.aclose()
|
|
519
|
+
self._sessions = []
|
|
520
|
+
self._successful_connections = 0
|
|
521
|
+
|
|
522
|
+
except (RuntimeError, BaseException):
|
|
523
|
+
pass # Silently ignore all cleanup errors
|
|
278
524
|
|
|
279
525
|
self._initialized = False
|
|
280
526
|
|
|
@@ -298,7 +544,7 @@ class MultiMCPTools(Toolkit):
|
|
|
298
544
|
self._successful_connections = 0
|
|
299
545
|
|
|
300
546
|
async def build_tools(self) -> None:
|
|
301
|
-
for session in self._sessions:
|
|
547
|
+
for session_list_idx, session in enumerate(self._sessions):
|
|
302
548
|
# Get the list of tools from the MCP server
|
|
303
549
|
available_tools = await session.list_tools()
|
|
304
550
|
|
|
@@ -314,7 +560,12 @@ class MultiMCPTools(Toolkit):
|
|
|
314
560
|
for tool in filtered_tools:
|
|
315
561
|
try:
|
|
316
562
|
# Get an entrypoint for the tool
|
|
317
|
-
entrypoint = get_entrypoint_for_tool(
|
|
563
|
+
entrypoint = get_entrypoint_for_tool(
|
|
564
|
+
tool=tool,
|
|
565
|
+
session=session,
|
|
566
|
+
mcp_tools_instance=self, # Pass self to enable dynamic headers
|
|
567
|
+
server_idx=session_list_idx, # Pass session list index for session lookup
|
|
568
|
+
)
|
|
318
569
|
|
|
319
570
|
# Create a Function for the tool
|
|
320
571
|
f = Function(
|
|
@@ -333,14 +584,18 @@ class MultiMCPTools(Toolkit):
|
|
|
333
584
|
log_error(f"Failed to register tool {tool.name}: {e}")
|
|
334
585
|
raise
|
|
335
586
|
|
|
336
|
-
async def initialize(self, session: ClientSession) -> None:
|
|
587
|
+
async def initialize(self, session: ClientSession, server_idx: int = 0) -> None:
|
|
337
588
|
"""Initialize the MCP toolkit by getting available tools from the MCP server"""
|
|
338
589
|
|
|
339
590
|
try:
|
|
340
591
|
# Initialize the session if not already initialized
|
|
341
592
|
await session.initialize()
|
|
342
593
|
|
|
594
|
+
# Track which server index this session belongs to
|
|
595
|
+
session_list_idx = len(self._sessions)
|
|
343
596
|
self._sessions.append(session)
|
|
597
|
+
self._session_to_server_idx[session_list_idx] = server_idx
|
|
598
|
+
|
|
344
599
|
self._initialized = True
|
|
345
600
|
except Exception as e:
|
|
346
601
|
log_error(f"Failed to get MCP tools: {e}")
|
agno/tools/toolkit.py
CHANGED
|
@@ -1,4 +1,5 @@
|
|
|
1
1
|
from collections import OrderedDict
|
|
2
|
+
from inspect import iscoroutinefunction
|
|
2
3
|
from typing import Any, Callable, Dict, List, Optional, Sequence, Union
|
|
3
4
|
|
|
4
5
|
from agno.tools.function import Function
|
|
@@ -14,6 +15,7 @@ class Toolkit:
|
|
|
14
15
|
self,
|
|
15
16
|
name: str = "toolkit",
|
|
16
17
|
tools: Sequence[Union[Callable[..., Any], Function]] = [],
|
|
18
|
+
async_tools: Optional[Sequence[tuple[Callable[..., Any], str]]] = None,
|
|
17
19
|
instructions: Optional[str] = None,
|
|
18
20
|
add_instructions: bool = False,
|
|
19
21
|
include_tools: Optional[list[str]] = None,
|
|
@@ -32,6 +34,9 @@ class Toolkit:
|
|
|
32
34
|
Args:
|
|
33
35
|
name: A descriptive name for the toolkit
|
|
34
36
|
tools: List of tools to include in the toolkit (can be callables or Function objects from @tool decorator)
|
|
37
|
+
async_tools: List of (async_callable, tool_name) tuples for async variants.
|
|
38
|
+
Used when async methods have different names than sync methods.
|
|
39
|
+
Example: [(self.anavigate_to, "navigate_to"), (self.ascreenshot, "screenshot")]
|
|
35
40
|
instructions: Instructions for the toolkit
|
|
36
41
|
add_instructions: Whether to add instructions to the toolkit
|
|
37
42
|
include_tools: List of tool names to include in the toolkit
|
|
@@ -47,7 +52,11 @@ class Toolkit:
|
|
|
47
52
|
"""
|
|
48
53
|
self.name: str = name
|
|
49
54
|
self.tools: Sequence[Union[Callable[..., Any], Function]] = tools
|
|
55
|
+
self._async_tools: Sequence[tuple[Callable[..., Any], str]] = async_tools or []
|
|
56
|
+
# Functions dict - used by agent.run() and agent.print_response()
|
|
50
57
|
self.functions: Dict[str, Function] = OrderedDict()
|
|
58
|
+
# Async functions dict - used by agent.arun() and agent.aprint_response()
|
|
59
|
+
self.async_functions: Dict[str, Function] = OrderedDict()
|
|
51
60
|
self.instructions: Optional[str] = instructions
|
|
52
61
|
self.add_instructions: bool = add_instructions
|
|
53
62
|
|
|
@@ -71,8 +80,11 @@ class Toolkit:
|
|
|
71
80
|
self.cache_dir: Optional[str] = cache_dir
|
|
72
81
|
|
|
73
82
|
# Automatically register all methods if auto_register is True
|
|
74
|
-
if auto_register
|
|
75
|
-
self.
|
|
83
|
+
if auto_register:
|
|
84
|
+
if self.tools:
|
|
85
|
+
self._register_tools()
|
|
86
|
+
if self._async_tools:
|
|
87
|
+
self._register_async_tools()
|
|
76
88
|
|
|
77
89
|
def _get_tool_name(self, tool: Union[Callable[..., Any], Function]) -> str:
|
|
78
90
|
"""Get the name of a tool, whether it's a Function or callable."""
|
|
@@ -125,14 +137,25 @@ class Toolkit:
|
|
|
125
137
|
log_warning(f"Show result tool(s) not present in the toolkit: {', '.join(missing_show_result)}")
|
|
126
138
|
|
|
127
139
|
def _register_tools(self) -> None:
|
|
128
|
-
"""Register all tools."""
|
|
140
|
+
"""Register all sync tools."""
|
|
129
141
|
for tool in self.tools:
|
|
130
142
|
self.register(tool)
|
|
131
143
|
|
|
144
|
+
def _register_async_tools(self) -> None:
|
|
145
|
+
"""Register all async tools with their mapped names.
|
|
146
|
+
|
|
147
|
+
Async detection is automatic via iscoroutinefunction.
|
|
148
|
+
"""
|
|
149
|
+
for async_func, tool_name in self._async_tools:
|
|
150
|
+
self.register(async_func, name=tool_name)
|
|
151
|
+
|
|
132
152
|
def register(self, function: Union[Callable[..., Any], Function], name: Optional[str] = None) -> None:
|
|
133
153
|
"""Register a function with the toolkit.
|
|
134
154
|
|
|
135
155
|
This method supports both regular callables and Function objects (from @tool decorator).
|
|
156
|
+
Automatically detects if the function is async (using iscoroutinefunction) and registers
|
|
157
|
+
it to the appropriate dict (functions for sync, async_functions for async).
|
|
158
|
+
|
|
136
159
|
When a Function object is passed (e.g., from a @tool decorated method), it will:
|
|
137
160
|
1. Extract the configuration from the Function object
|
|
138
161
|
2. Look for a bound method with the same name on `self`
|
|
@@ -140,17 +163,18 @@ class Toolkit:
|
|
|
140
163
|
|
|
141
164
|
Args:
|
|
142
165
|
function: The callable or Function object to register
|
|
143
|
-
name: Optional custom name for the function
|
|
144
|
-
|
|
145
|
-
Returns:
|
|
146
|
-
The registered function
|
|
166
|
+
name: Optional custom name for the function (useful for aliasing)
|
|
147
167
|
"""
|
|
148
168
|
try:
|
|
149
169
|
# Handle Function objects (from @tool decorator)
|
|
150
170
|
if isinstance(function, Function):
|
|
151
|
-
|
|
171
|
+
# Auto-detect if this is an async function
|
|
172
|
+
is_async = function.entrypoint is not None and iscoroutinefunction(function.entrypoint)
|
|
173
|
+
return self._register_decorated_tool(function, name, is_async=is_async)
|
|
174
|
+
|
|
175
|
+
# Handle regular callables - auto-detect async
|
|
176
|
+
is_async = iscoroutinefunction(function)
|
|
152
177
|
|
|
153
|
-
# Handle regular callables
|
|
154
178
|
tool_name = name or function.__name__
|
|
155
179
|
if self.include_tools is not None and tool_name not in self.include_tools:
|
|
156
180
|
return
|
|
@@ -168,14 +192,19 @@ class Toolkit:
|
|
|
168
192
|
stop_after_tool_call=tool_name in self.stop_after_tool_call_tools,
|
|
169
193
|
show_result=tool_name in self.show_result_tools or tool_name in self.stop_after_tool_call_tools,
|
|
170
194
|
)
|
|
171
|
-
|
|
172
|
-
|
|
195
|
+
|
|
196
|
+
if is_async:
|
|
197
|
+
self.async_functions[f.name] = f
|
|
198
|
+
log_debug(f"Async function: {f.name} registered with {self.name}")
|
|
199
|
+
else:
|
|
200
|
+
self.functions[f.name] = f
|
|
201
|
+
log_debug(f"Function: {f.name} registered with {self.name}")
|
|
173
202
|
except Exception as e:
|
|
174
203
|
func_name = self._get_tool_name(function)
|
|
175
204
|
logger.warning(f"Failed to create Function for: {func_name}")
|
|
176
205
|
raise e
|
|
177
206
|
|
|
178
|
-
def _register_decorated_tool(self, function: Function, name: Optional[str] = None) -> None:
|
|
207
|
+
def _register_decorated_tool(self, function: Function, name: Optional[str] = None, is_async: bool = False) -> None:
|
|
179
208
|
"""Register a Function object from @tool decorator, binding it to self.
|
|
180
209
|
|
|
181
210
|
When @tool decorator is used on a class method, it creates a Function with an unbound
|
|
@@ -185,6 +214,7 @@ class Toolkit:
|
|
|
185
214
|
Args:
|
|
186
215
|
function: The Function object from @tool decorator
|
|
187
216
|
name: Optional custom name override
|
|
217
|
+
is_async: If True, register to async_functions dict instead of functions
|
|
188
218
|
"""
|
|
189
219
|
import inspect
|
|
190
220
|
|
|
@@ -207,14 +237,24 @@ class Toolkit:
|
|
|
207
237
|
|
|
208
238
|
if params and params[0] == "self":
|
|
209
239
|
# Create a bound method by wrapping the function to include self
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
240
|
+
if is_async:
|
|
241
|
+
|
|
242
|
+
def make_bound_method(func, instance):
|
|
243
|
+
async def bound(*args, **kwargs):
|
|
244
|
+
return await func(instance, *args, **kwargs)
|
|
245
|
+
|
|
246
|
+
bound.__name__ = getattr(func, "__name__", tool_name)
|
|
247
|
+
bound.__doc__ = getattr(func, "__doc__", None)
|
|
248
|
+
return bound
|
|
249
|
+
else:
|
|
250
|
+
|
|
251
|
+
def make_bound_method(func, instance):
|
|
252
|
+
def bound(*args, **kwargs):
|
|
253
|
+
return func(instance, *args, **kwargs)
|
|
213
254
|
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
return bound
|
|
255
|
+
bound.__name__ = getattr(func, "__name__", tool_name)
|
|
256
|
+
bound.__doc__ = getattr(func, "__doc__", None)
|
|
257
|
+
return bound
|
|
218
258
|
|
|
219
259
|
bound_method = make_bound_method(original_func, self)
|
|
220
260
|
else:
|
|
@@ -251,8 +291,36 @@ class Toolkit:
|
|
|
251
291
|
cache_dir=function.cache_dir if function.cache_dir else self.cache_dir,
|
|
252
292
|
cache_ttl=function.cache_ttl if function.cache_ttl != 3600 else self.cache_ttl,
|
|
253
293
|
)
|
|
254
|
-
|
|
255
|
-
|
|
294
|
+
|
|
295
|
+
if is_async:
|
|
296
|
+
self.async_functions[f.name] = f
|
|
297
|
+
log_debug(f"Async function: {f.name} registered with {self.name} (from @tool decorator)")
|
|
298
|
+
else:
|
|
299
|
+
self.functions[f.name] = f
|
|
300
|
+
log_debug(f"Function: {f.name} registered with {self.name} (from @tool decorator)")
|
|
301
|
+
|
|
302
|
+
def get_functions(self) -> Dict[str, Function]:
|
|
303
|
+
"""Get sync functions dict.
|
|
304
|
+
|
|
305
|
+
Returns:
|
|
306
|
+
Dict of function name to Function for sync execution
|
|
307
|
+
"""
|
|
308
|
+
return self.functions
|
|
309
|
+
|
|
310
|
+
def get_async_functions(self) -> Dict[str, Function]:
|
|
311
|
+
"""Get functions dict optimized for async execution.
|
|
312
|
+
|
|
313
|
+
Returns a merged dict where async_functions take precedence over functions.
|
|
314
|
+
This allows async-optimized implementations to be automatically used in async contexts,
|
|
315
|
+
while falling back to sync implementations for tools without async variants.
|
|
316
|
+
|
|
317
|
+
Returns:
|
|
318
|
+
Dict of function name to Function, with async variants preferred
|
|
319
|
+
"""
|
|
320
|
+
# Merge: start with sync functions, override with async variants
|
|
321
|
+
merged = OrderedDict(self.functions)
|
|
322
|
+
merged.update(self.async_functions)
|
|
323
|
+
return merged
|
|
256
324
|
|
|
257
325
|
@property
|
|
258
326
|
def requires_connect(self) -> bool:
|
agno/utils/mcp.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
|
1
1
|
import json
|
|
2
2
|
from functools import partial
|
|
3
|
+
from typing import TYPE_CHECKING, Optional, Union
|
|
3
4
|
from uuid import uuid4
|
|
4
5
|
|
|
5
6
|
from agno.utils.log import log_debug, log_exception
|
|
@@ -15,28 +16,68 @@ except (ImportError, ModuleNotFoundError):
|
|
|
15
16
|
from agno.media import Image
|
|
16
17
|
from agno.tools.function import ToolResult
|
|
17
18
|
|
|
19
|
+
if TYPE_CHECKING:
|
|
20
|
+
from agno.agent import Agent
|
|
21
|
+
from agno.run import RunContext
|
|
22
|
+
from agno.team.team import Team
|
|
23
|
+
from agno.tools.mcp.mcp import MCPTools
|
|
24
|
+
from agno.tools.mcp.multi_mcp import MultiMCPTools
|
|
18
25
|
|
|
19
|
-
|
|
26
|
+
|
|
27
|
+
def get_entrypoint_for_tool(
|
|
28
|
+
tool: MCPTool,
|
|
29
|
+
session: ClientSession,
|
|
30
|
+
mcp_tools_instance: Optional[Union["MCPTools", "MultiMCPTools"]] = None,
|
|
31
|
+
server_idx: int = 0,
|
|
32
|
+
):
|
|
20
33
|
"""
|
|
21
34
|
Return an entrypoint for an MCP tool.
|
|
22
35
|
|
|
23
36
|
Args:
|
|
24
37
|
tool: The MCP tool to create an entrypoint for
|
|
25
|
-
session: The
|
|
38
|
+
session: The MCP ClientSession to use
|
|
39
|
+
mcp_tools_instance: Optional MCPTools or MultiMCPTools instance
|
|
40
|
+
server_idx: Index of the server (for MultiMCPTools)
|
|
26
41
|
|
|
27
42
|
Returns:
|
|
28
43
|
Callable: The entrypoint function for the tool
|
|
29
44
|
"""
|
|
30
45
|
|
|
31
|
-
async def call_tool(
|
|
46
|
+
async def call_tool(
|
|
47
|
+
tool_name: str,
|
|
48
|
+
run_context: Optional["RunContext"] = None,
|
|
49
|
+
agent: Optional["Agent"] = None,
|
|
50
|
+
team: Optional["Team"] = None,
|
|
51
|
+
**kwargs,
|
|
52
|
+
) -> ToolResult:
|
|
53
|
+
# Execute the MCP tool call
|
|
32
54
|
try:
|
|
33
|
-
|
|
34
|
-
|
|
35
|
-
|
|
55
|
+
# Get the appropriate session for this run
|
|
56
|
+
# If mcp_tools_instance has header_provider and run_context is provided,
|
|
57
|
+
# this will create/reuse a session with dynamic headers
|
|
58
|
+
if mcp_tools_instance and hasattr(mcp_tools_instance, "get_session_for_run"):
|
|
59
|
+
# Import here to avoid circular imports
|
|
60
|
+
from agno.tools.mcp.multi_mcp import MultiMCPTools
|
|
61
|
+
|
|
62
|
+
# For MultiMCPTools, pass server_idx; for MCPTools, only pass run_context
|
|
63
|
+
if isinstance(mcp_tools_instance, MultiMCPTools):
|
|
64
|
+
active_session = await mcp_tools_instance.get_session_for_run(
|
|
65
|
+
run_context=run_context, server_idx=server_idx, agent=agent, team=team
|
|
66
|
+
)
|
|
67
|
+
else:
|
|
68
|
+
active_session = await mcp_tools_instance.get_session_for_run(
|
|
69
|
+
run_context=run_context, agent=agent, team=team
|
|
70
|
+
)
|
|
71
|
+
else:
|
|
72
|
+
active_session = session
|
|
73
|
+
|
|
74
|
+
try:
|
|
75
|
+
await active_session.send_ping()
|
|
76
|
+
except Exception as e:
|
|
77
|
+
log_exception(e)
|
|
36
78
|
|
|
37
|
-
try:
|
|
38
79
|
log_debug(f"Calling MCP Tool '{tool_name}' with args: {kwargs}")
|
|
39
|
-
result: CallToolResult = await
|
|
80
|
+
result: CallToolResult = await active_session.call_tool(tool_name, kwargs) # type: ignore
|
|
40
81
|
|
|
41
82
|
# Return an error if the tool call failed
|
|
42
83
|
if result.isError:
|