agno 2.3.21__py3-none-any.whl → 2.3.23__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. agno/agent/agent.py +48 -2
  2. agno/agent/remote.py +234 -73
  3. agno/client/a2a/__init__.py +10 -0
  4. agno/client/a2a/client.py +554 -0
  5. agno/client/a2a/schemas.py +112 -0
  6. agno/client/a2a/utils.py +369 -0
  7. agno/db/migrations/utils.py +19 -0
  8. agno/db/migrations/v1_to_v2.py +54 -16
  9. agno/db/migrations/versions/v2_3_0.py +92 -53
  10. agno/db/mysql/async_mysql.py +5 -7
  11. agno/db/mysql/mysql.py +5 -7
  12. agno/db/mysql/schemas.py +39 -21
  13. agno/db/postgres/async_postgres.py +172 -42
  14. agno/db/postgres/postgres.py +186 -38
  15. agno/db/postgres/schemas.py +39 -21
  16. agno/db/postgres/utils.py +6 -2
  17. agno/db/singlestore/schemas.py +41 -21
  18. agno/db/singlestore/singlestore.py +14 -3
  19. agno/db/sqlite/async_sqlite.py +7 -2
  20. agno/db/sqlite/schemas.py +36 -21
  21. agno/db/sqlite/sqlite.py +3 -7
  22. agno/knowledge/chunking/document.py +3 -2
  23. agno/knowledge/chunking/markdown.py +8 -3
  24. agno/knowledge/chunking/recursive.py +2 -2
  25. agno/models/base.py +4 -0
  26. agno/models/google/gemini.py +27 -4
  27. agno/models/openai/chat.py +1 -1
  28. agno/models/openai/responses.py +14 -7
  29. agno/os/middleware/jwt.py +66 -27
  30. agno/os/routers/agents/router.py +3 -3
  31. agno/os/routers/evals/evals.py +2 -2
  32. agno/os/routers/knowledge/knowledge.py +5 -5
  33. agno/os/routers/knowledge/schemas.py +1 -1
  34. agno/os/routers/memory/memory.py +4 -4
  35. agno/os/routers/session/session.py +2 -2
  36. agno/os/routers/teams/router.py +4 -4
  37. agno/os/routers/traces/traces.py +3 -3
  38. agno/os/routers/workflows/router.py +3 -3
  39. agno/os/schema.py +1 -1
  40. agno/reasoning/deepseek.py +11 -1
  41. agno/reasoning/gemini.py +6 -2
  42. agno/reasoning/groq.py +8 -3
  43. agno/reasoning/openai.py +2 -0
  44. agno/remote/base.py +106 -9
  45. agno/skills/__init__.py +17 -0
  46. agno/skills/agent_skills.py +370 -0
  47. agno/skills/errors.py +32 -0
  48. agno/skills/loaders/__init__.py +4 -0
  49. agno/skills/loaders/base.py +27 -0
  50. agno/skills/loaders/local.py +216 -0
  51. agno/skills/skill.py +65 -0
  52. agno/skills/utils.py +107 -0
  53. agno/skills/validator.py +277 -0
  54. agno/team/remote.py +220 -60
  55. agno/team/team.py +41 -3
  56. agno/tools/brandfetch.py +27 -18
  57. agno/tools/browserbase.py +150 -13
  58. agno/tools/function.py +6 -1
  59. agno/tools/mcp/mcp.py +300 -17
  60. agno/tools/mcp/multi_mcp.py +269 -14
  61. agno/tools/toolkit.py +89 -21
  62. agno/utils/mcp.py +49 -8
  63. agno/utils/string.py +43 -1
  64. agno/workflow/condition.py +4 -2
  65. agno/workflow/loop.py +20 -1
  66. agno/workflow/remote.py +173 -33
  67. agno/workflow/router.py +4 -1
  68. agno/workflow/steps.py +4 -0
  69. agno/workflow/workflow.py +14 -0
  70. {agno-2.3.21.dist-info → agno-2.3.23.dist-info}/METADATA +13 -14
  71. {agno-2.3.21.dist-info → agno-2.3.23.dist-info}/RECORD +74 -60
  72. {agno-2.3.21.dist-info → agno-2.3.23.dist-info}/WHEEL +0 -0
  73. {agno-2.3.21.dist-info → agno-2.3.23.dist-info}/licenses/LICENSE +0 -0
  74. {agno-2.3.21.dist-info → agno-2.3.23.dist-info}/top_level.txt +0 -0
@@ -1,9 +1,12 @@
1
+ import inspect
2
+ import time
3
+ import warnings
1
4
  import weakref
2
5
  from contextlib import AsyncExitStack
3
6
  from dataclasses import asdict
4
7
  from datetime import timedelta
5
8
  from types import TracebackType
6
- from typing import List, Literal, Optional, Union
9
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Literal, Optional, Tuple, Union
7
10
 
8
11
  from agno.tools import Toolkit
9
12
  from agno.tools.function import Function
@@ -11,6 +14,11 @@ from agno.tools.mcp.params import SSEClientParams, StreamableHTTPClientParams
11
14
  from agno.utils.log import log_debug, log_error, log_info, log_warning
12
15
  from agno.utils.mcp import get_entrypoint_for_tool, prepare_command
13
16
 
17
+ if TYPE_CHECKING:
18
+ from agno.agent import Agent
19
+ from agno.run import RunContext
20
+ from agno.team.team import Team
21
+
14
22
  try:
15
23
  from mcp import ClientSession, StdioServerParameters
16
24
  from mcp.client.sse import sse_client
@@ -47,6 +55,7 @@ class MultiMCPTools(Toolkit):
47
55
  exclude_tools: Optional[list[str]] = None,
48
56
  refresh_connection: bool = False,
49
57
  allow_partial_failure: bool = False,
58
+ header_provider: Optional[Callable[..., dict[str, Any]]] = None,
50
59
  **kwargs,
51
60
  ):
52
61
  """
@@ -64,7 +73,14 @@ class MultiMCPTools(Toolkit):
64
73
  exclude_tools: Optional list of tool names to exclude (if None, excludes none).
65
74
  allow_partial_failure: If True, allows toolkit to initialize even if some MCP servers fail to connect. If False, any failure will raise an exception.
66
75
  refresh_connection: If True, the connection and tools will be refreshed on each run
76
+ header_provider: Header provider function for all servers. Takes RunContext and returns dict of HTTP headers.
67
77
  """
78
+ warnings.warn(
79
+ "The MultiMCPTools class is deprecated and will be removed in a future version. Please use multiple MCPTools instances instead.",
80
+ DeprecationWarning,
81
+ stacklevel=2,
82
+ )
83
+
68
84
  super().__init__(name="MultiMCPTools", **kwargs)
69
85
 
70
86
  if urls_transports is not None:
@@ -86,6 +102,16 @@ class MultiMCPTools(Toolkit):
86
102
  self.exclude_tools = exclude_tools
87
103
  self.refresh_connection = refresh_connection
88
104
 
105
+ self.header_provider = header_provider
106
+
107
+ # Validate header_provider signature
108
+ if header_provider:
109
+ try:
110
+ # Just verify we can inspect the signature - no parameter requirements
111
+ inspect.signature(header_provider)
112
+ except Exception as e:
113
+ log_warning(f"Could not validate header_provider signature: {e}")
114
+
89
115
  if server_params_list is None and commands is None and urls is None:
90
116
  raise ValueError("Either server_params_list or commands or urls must be provided")
91
117
 
@@ -130,6 +156,14 @@ class MultiMCPTools(Toolkit):
130
156
  self._connection_task = None
131
157
  self._successful_connections = 0
132
158
  self._sessions: list[ClientSession] = []
159
+ self._session_to_server_idx: Dict[int, int] = {} # Maps session list index to server params index
160
+
161
+ # Session management for per-agent-run sessions with dynamic headers
162
+ # For MultiMCP, we track sessions per (run_id, server_idx) since we have multiple servers
163
+ # Maps (run_id, server_idx) to (session, timestamp) for TTL-based cleanup
164
+ self._run_sessions: Dict[Tuple[str, int], Tuple[ClientSession, float]] = {}
165
+ self._run_session_contexts: Dict[Tuple[str, int], Any] = {} # Maps (run_id, server_idx) to context managers
166
+ self._session_ttl_seconds: float = 300.0 # 5 minutes default TTL
133
167
 
134
168
  self.allow_partial_failure = allow_partial_failure
135
169
 
@@ -153,6 +187,205 @@ class MultiMCPTools(Toolkit):
153
187
  except (RuntimeError, BaseException):
154
188
  return False
155
189
 
190
def _call_header_provider(
    self,
    run_context: Optional["RunContext"] = None,
    agent: Optional["Agent"] = None,
    team: Optional["Team"] = None,
) -> dict[str, Any]:
    """Call ``self.header_provider`` with whichever context arguments its signature accepts.

    Dispatch rules, checked in order:
      * the provider takes ``**kwargs``: ``run_context``, ``agent`` and ``team`` are all passed by keyword;
      * the provider has parameters named ``run_context`` / ``agent`` / ``team``: only those are passed,
        by keyword;
      * the provider has other positional parameters: ``run_context`` is passed as the first
        positional argument (legacy form);
      * otherwise the provider is called with no arguments.

    Args:
        run_context: The RunContext for the current agent run
        agent: The Agent instance (if running within an agent)
        team: The Team instance (if running within a team)

    Returns:
        dict[str, Any]: A new dict with the provider's headers. ``{}`` when no provider is set,
        when calling it raised, when it returned ``None``, or when it returned something that is
        not a mapping (that last case is logged). Callers can therefore always ``**``-merge it.
    """
    from collections.abc import Mapping

    header_provider = getattr(self, "header_provider", None)
    if header_provider is None:
        return {}

    available: dict[str, Any] = {"run_context": run_context, "agent": agent, "team": team}

    try:
        sig = inspect.signature(header_provider)
        params = list(sig.parameters.values())
        param_names = set(sig.parameters)

        if any(p.kind == inspect.Parameter.VAR_KEYWORD for p in params):
            # **kwargs accepts everything we have.
            headers = header_provider(**available)
        elif param_names & available.keys():
            headers = header_provider(**{name: value for name, value in available.items() if name in param_names})
        elif any(
            p.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD) for p in params
        ):
            # Legacy support: pass run_context as first positional arg
            headers = header_provider(run_context)
        else:
            headers = header_provider()
    except Exception as e:
        log_warning(f"Error calling header_provider: {e}")
        return {}

    if headers is None:
        return {}
    if not isinstance(headers, Mapping):
        # Returning this as-is would crash the ``{**existing, **dynamic}`` merge in the caller.
        log_warning(f"header_provider returned {type(headers).__name__}, expected a dict of headers; ignoring it")
        return {}
    return dict(headers)
249
+
250
async def _cleanup_stale_sessions(self) -> None:
    """Close every per-run session that has outlived ``self._session_ttl_seconds``.

    Runs lazily (it is awaited from ``get_session_for_run``) so sessions left behind by
    finished or abandoned runs are eventually released instead of accumulating.
    """
    cached = self._run_sessions
    if not cached:
        return

    now = time.time()
    ttl = self._session_ttl_seconds

    # Collect the expired keys up front: cleanup_run_session removes entries from the dict.
    expired = []
    for key, (_session, created_at) in cached.items():
        if now - created_at > ttl:
            expired.append(key)

    for run_id, server_idx in expired:
        log_debug(f"Cleaning up stale session for run_id={run_id}, server_idx={server_idx}")
        await self.cleanup_run_session(run_id, server_idx)
265
+
266
async def get_session_for_run(
    self,
    run_context: Optional["RunContext"] = None,
    server_idx: int = 0,
    agent: Optional["Agent"] = None,
    team: Optional["Team"] = None,
) -> "ClientSession":
    """
    Get or create a session for the given run_context and server index.

    If header_provider is configured and run_context is provided, a dedicated session carrying
    the provider's headers is created for this (run_id, server_idx) pair and cached. The entry
    is released by ``cleanup_run_session`` or by the TTL sweep in ``_cleanup_stale_sessions``.
    The cached timestamp is refreshed on every reuse, so the TTL measures idle time and a
    session still in use is not torn down. Otherwise the default session for the server is
    returned.

    Args:
        run_context: The RunContext containing run_id, user_id, metadata, etc.
        server_idx: Index of the server's session in ``self._sessions`` (the index ``build_tools``
            passes). It is translated to an index into ``self.server_params_list`` through
            ``self._session_to_server_idx``, because the two lists diverge when some servers
            failed to connect.
        agent: The Agent instance (if running within an agent)
        team: The Team instance (if running within a team)

    Returns:
        ClientSession: Either the default session or a per-run session with dynamic headers

    Raises:
        ValueError: If ``server_idx`` does not refer to a known session / server.
    """

    async def _close_quietly(ctx: Any) -> None:
        # Best-effort exit of a context manager that was entered manually below.
        try:
            await ctx.__aexit__(None, None, None)
        except Exception as close_error:
            log_debug(f"Ignoring error while closing MCP context for server {server_idx}: {close_error}")

    def _default_session() -> "ClientSession":
        if 0 <= server_idx < len(self._sessions):
            return self._sessions[server_idx]
        raise ValueError(f"Server index {server_idx} out of range")

    # If no header_provider or no run_context, use the default session
    if not self.header_provider or not run_context:
        return _default_session()

    # Lazy cleanup of stale sessions
    await self._cleanup_stale_sessions()

    run_id = run_context.run_id
    cache_key = (run_id, server_idx)
    cached = self._run_sessions.get(cache_key)
    if cached is not None:
        session = cached[0]
        # Refresh the timestamp so the TTL sweep only reclaims idle sessions.
        self._run_sessions[cache_key] = (session, time.time())
        return session

    # server_idx indexes self._sessions; map it back to the server params it was created from.
    params_idx = self._session_to_server_idx.get(server_idx, server_idx)
    if not 0 <= params_idx < len(self.server_params_list):
        raise ValueError(f"Server index {server_idx} out of range")
    server_params = self.server_params_list[params_idx]

    if not isinstance(server_params, (SSEClientParams, StreamableHTTPClientParams)):
        # stdio doesn't support headers, fall back to default session
        log_warning(
            f"Cannot use dynamic headers with stdio transport for server {server_idx}, using default session"
        )
        return _default_session()

    log_debug(f"Creating new session for run_id={run_id}, server_idx={server_idx} with dynamic headers")

    # Generate dynamic headers from the provider; guard against a provider that yields None.
    dynamic_headers = self._call_header_provider(run_context=run_context, agent=agent, team=team) or {}

    params_dict = asdict(server_params)
    existing_headers = params_dict.get("headers") or {}
    params_dict["headers"] = {**existing_headers, **dynamic_headers}

    if isinstance(server_params, SSEClientParams):
        context = sse_client(**params_dict)  # type: ignore
        client_timeout = min(self.timeout_seconds, params_dict.get("timeout", self.timeout_seconds))
    else:
        context = streamablehttp_client(**params_dict)  # type: ignore
        params_timeout = params_dict.get("timeout", self.timeout_seconds)
        if isinstance(params_timeout, timedelta):
            params_timeout = int(params_timeout.total_seconds())
        client_timeout = min(self.timeout_seconds, params_timeout)

    # Enter the transport, then the session. If any later step fails, unwind what was already
    # entered so the connection is not leaked.
    session_params = await context.__aenter__()  # type: ignore
    try:
        read, write = session_params[0:2]
        session_context = ClientSession(read, write, read_timeout_seconds=timedelta(seconds=client_timeout))  # type: ignore
        session = await session_context.__aenter__()  # type: ignore
        try:
            await session.initialize()
        except BaseException:
            await _close_quietly(session_context)
            raise
    except BaseException:
        await _close_quietly(context)
        raise

    existing = self._run_sessions.get(cache_key)
    if existing is not None:
        # A concurrent tool call for the same run/server finished connecting first: keep that
        # session and close ours rather than overwriting (and orphaning) one of them.
        await _close_quietly(session_context)
        await _close_quietly(context)
        return existing[0]

    # Store the session with timestamp and context for cleanup
    self._run_sessions[cache_key] = (session, time.time())
    self._run_session_contexts[cache_key] = (context, session_context)

    return session
360
+
361
async def cleanup_run_session(self, run_id: str, server_idx: int) -> None:
    """Close and forget the per-run session cached for ``(run_id, server_idx)``.

    Cleanup is best-effort. The cache entries are removed first, so no caller can be handed a
    session that is being torn down. The ClientSession is then exited, followed by the
    transport it runs on. Errors raised while exiting are logged at debug level and otherwise
    ignored. Calling this for a key that is not cached is a no-op.

    Args:
        run_id: The run id the session was created for.
        server_idx: The server index used as part of the cache key by ``get_session_for_run``.
    """
    cache_key = (run_id, server_idx)
    if cache_key not in self._run_sessions:
        return

    self._run_sessions.pop(cache_key, None)
    contexts = self._run_session_contexts.pop(cache_key, None)
    if contexts is None:
        return

    transport_context, session_context = contexts
    # Exit the session before the transport it reads from and writes to.
    for label, ctx in (("session", session_context), ("transport", transport_context)):
        try:
            await ctx.__aexit__(None, None, None)
        except Exception as e:
            log_debug(
                f"Ignoring error while closing MCP {label} for run_id={run_id}, server_idx={server_idx}: {e}"
            )
388
+
156
389
  async def connect(self, force: bool = False):
157
390
  """Initialize a MultiMCPTools instance and connect to the MCP servers"""
158
391
 
@@ -214,7 +447,7 @@ class MultiMCPTools(Toolkit):
214
447
 
215
448
  server_connection_errors = []
216
449
 
217
- for server_params in self.server_params_list:
450
+ for server_idx, server_params in enumerate(self.server_params_list):
218
451
  try:
219
452
  # Handle stdio connections
220
453
  if isinstance(server_params, StdioServerParameters):
@@ -223,7 +456,7 @@ class MultiMCPTools(Toolkit):
223
456
  session = await self._async_exit_stack.enter_async_context(
224
457
  ClientSession(read, write, read_timeout_seconds=timedelta(seconds=self.timeout_seconds))
225
458
  )
226
- await self.initialize(session)
459
+ await self.initialize(session, server_idx)
227
460
  self._successful_connections += 1
228
461
 
229
462
  # Handle SSE connections
@@ -233,7 +466,7 @@ class MultiMCPTools(Toolkit):
233
466
  )
234
467
  read, write = client_connection
235
468
  session = await self._async_exit_stack.enter_async_context(ClientSession(read, write))
236
- await self.initialize(session)
469
+ await self.initialize(session, server_idx)
237
470
  self._successful_connections += 1
238
471
 
239
472
  # Handle Streamable HTTP connections
@@ -243,7 +476,7 @@ class MultiMCPTools(Toolkit):
243
476
  )
244
477
  read, write = client_connection[0:2]
245
478
  session = await self._async_exit_stack.enter_async_context(ClientSession(read, write))
246
- await self.initialize(session)
479
+ await self.initialize(session, server_idx)
247
480
  self._successful_connections += 1
248
481
 
249
482
  except Exception as e:
@@ -268,13 +501,26 @@ class MultiMCPTools(Toolkit):
268
501
  if not self._initialized:
269
502
  return
270
503
 
271
- try:
272
- await self._async_exit_stack.aclose()
273
- self._sessions = []
274
- self._successful_connections = 0
504
+ import warnings
275
505
 
276
- except (RuntimeError, BaseException) as e:
277
- log_error(f"Failed to close MCP connections: {e}")
506
+ # Suppress async generator cleanup warnings
507
+ with warnings.catch_warnings():
508
+ warnings.filterwarnings("ignore", category=RuntimeWarning, message=".*async_generator.*")
509
+ warnings.filterwarnings("ignore", message=".*cancel scope.*")
510
+
511
+ try:
512
+ # Clean up all per-run sessions first
513
+ cache_keys = list(self._run_sessions.keys())
514
+ for run_id, server_idx in cache_keys:
515
+ await self.cleanup_run_session(run_id, server_idx)
516
+
517
+ # Clean up main sessions
518
+ await self._async_exit_stack.aclose()
519
+ self._sessions = []
520
+ self._successful_connections = 0
521
+
522
+ except (RuntimeError, BaseException):
523
+ pass # Silently ignore all cleanup errors
278
524
 
279
525
  self._initialized = False
280
526
 
@@ -298,7 +544,7 @@ class MultiMCPTools(Toolkit):
298
544
  self._successful_connections = 0
299
545
 
300
546
  async def build_tools(self) -> None:
301
- for session in self._sessions:
547
+ for session_list_idx, session in enumerate(self._sessions):
302
548
  # Get the list of tools from the MCP server
303
549
  available_tools = await session.list_tools()
304
550
 
@@ -314,7 +560,12 @@ class MultiMCPTools(Toolkit):
314
560
  for tool in filtered_tools:
315
561
  try:
316
562
  # Get an entrypoint for the tool
317
- entrypoint = get_entrypoint_for_tool(tool, session)
563
+ entrypoint = get_entrypoint_for_tool(
564
+ tool=tool,
565
+ session=session,
566
+ mcp_tools_instance=self, # Pass self to enable dynamic headers
567
+ server_idx=session_list_idx, # Pass session list index for session lookup
568
+ )
318
569
 
319
570
  # Create a Function for the tool
320
571
  f = Function(
@@ -333,14 +584,18 @@ class MultiMCPTools(Toolkit):
333
584
  log_error(f"Failed to register tool {tool.name}: {e}")
334
585
  raise
335
586
 
336
- async def initialize(self, session: ClientSession) -> None:
587
+ async def initialize(self, session: ClientSession, server_idx: int = 0) -> None:
337
588
  """Initialize the MCP toolkit by getting available tools from the MCP server"""
338
589
 
339
590
  try:
340
591
  # Initialize the session if not already initialized
341
592
  await session.initialize()
342
593
 
594
+ # Track which server index this session belongs to
595
+ session_list_idx = len(self._sessions)
343
596
  self._sessions.append(session)
597
+ self._session_to_server_idx[session_list_idx] = server_idx
598
+
344
599
  self._initialized = True
345
600
  except Exception as e:
346
601
  log_error(f"Failed to get MCP tools: {e}")
agno/tools/toolkit.py CHANGED
@@ -1,4 +1,5 @@
1
1
  from collections import OrderedDict
2
+ from inspect import iscoroutinefunction
2
3
  from typing import Any, Callable, Dict, List, Optional, Sequence, Union
3
4
 
4
5
  from agno.tools.function import Function
@@ -14,6 +15,7 @@ class Toolkit:
14
15
  self,
15
16
  name: str = "toolkit",
16
17
  tools: Sequence[Union[Callable[..., Any], Function]] = [],
18
+ async_tools: Optional[Sequence[tuple[Callable[..., Any], str]]] = None,
17
19
  instructions: Optional[str] = None,
18
20
  add_instructions: bool = False,
19
21
  include_tools: Optional[list[str]] = None,
@@ -32,6 +34,9 @@ class Toolkit:
32
34
  Args:
33
35
  name: A descriptive name for the toolkit
34
36
  tools: List of tools to include in the toolkit (can be callables or Function objects from @tool decorator)
37
+ async_tools: List of (async_callable, tool_name) tuples for async variants.
38
+ Used when async methods have different names than sync methods.
39
+ Example: [(self.anavigate_to, "navigate_to"), (self.ascreenshot, "screenshot")]
35
40
  instructions: Instructions for the toolkit
36
41
  add_instructions: Whether to add instructions to the toolkit
37
42
  include_tools: List of tool names to include in the toolkit
@@ -47,7 +52,11 @@ class Toolkit:
47
52
  """
48
53
  self.name: str = name
49
54
  self.tools: Sequence[Union[Callable[..., Any], Function]] = tools
55
+ self._async_tools: Sequence[tuple[Callable[..., Any], str]] = async_tools or []
56
+ # Functions dict - used by agent.run() and agent.print_response()
50
57
  self.functions: Dict[str, Function] = OrderedDict()
58
+ # Async functions dict - used by agent.arun() and agent.aprint_response()
59
+ self.async_functions: Dict[str, Function] = OrderedDict()
51
60
  self.instructions: Optional[str] = instructions
52
61
  self.add_instructions: bool = add_instructions
53
62
 
@@ -71,8 +80,11 @@ class Toolkit:
71
80
  self.cache_dir: Optional[str] = cache_dir
72
81
 
73
82
  # Automatically register all methods if auto_register is True
74
- if auto_register and self.tools:
75
- self._register_tools()
83
+ if auto_register:
84
+ if self.tools:
85
+ self._register_tools()
86
+ if self._async_tools:
87
+ self._register_async_tools()
76
88
 
77
89
  def _get_tool_name(self, tool: Union[Callable[..., Any], Function]) -> str:
78
90
  """Get the name of a tool, whether it's a Function or callable."""
@@ -125,14 +137,25 @@ class Toolkit:
125
137
  log_warning(f"Show result tool(s) not present in the toolkit: {', '.join(missing_show_result)}")
126
138
 
127
139
  def _register_tools(self) -> None:
128
- """Register all tools."""
140
+ """Register all sync tools."""
129
141
  for tool in self.tools:
130
142
  self.register(tool)
131
143
 
144
def _register_async_tools(self) -> None:
    """Register each ``(async_callable, tool_name)`` pair from ``self._async_tools``.

    The callable is registered under ``tool_name`` so an async variant with a different method
    name (e.g. ``anavigate_to``) is exposed under its sync counterpart's tool name. ``register``
    detects coroutine functions and files them under ``async_functions``.
    """
    for entry in self._async_tools:
        handler, alias = entry
        self.register(handler, name=alias)
151
+
132
152
  def register(self, function: Union[Callable[..., Any], Function], name: Optional[str] = None) -> None:
133
153
  """Register a function with the toolkit.
134
154
 
135
155
  This method supports both regular callables and Function objects (from @tool decorator).
156
+ Automatically detects if the function is async (using iscoroutinefunction) and registers
157
+ it to the appropriate dict (functions for sync, async_functions for async).
158
+
136
159
  When a Function object is passed (e.g., from a @tool decorated method), it will:
137
160
  1. Extract the configuration from the Function object
138
161
  2. Look for a bound method with the same name on `self`
@@ -140,17 +163,18 @@ class Toolkit:
140
163
 
141
164
  Args:
142
165
  function: The callable or Function object to register
143
- name: Optional custom name for the function
144
-
145
- Returns:
146
- The registered function
166
+ name: Optional custom name for the function (useful for aliasing)
147
167
  """
148
168
  try:
149
169
  # Handle Function objects (from @tool decorator)
150
170
  if isinstance(function, Function):
151
- return self._register_decorated_tool(function, name)
171
+ # Auto-detect if this is an async function
172
+ is_async = function.entrypoint is not None and iscoroutinefunction(function.entrypoint)
173
+ return self._register_decorated_tool(function, name, is_async=is_async)
174
+
175
+ # Handle regular callables - auto-detect async
176
+ is_async = iscoroutinefunction(function)
152
177
 
153
- # Handle regular callables
154
178
  tool_name = name or function.__name__
155
179
  if self.include_tools is not None and tool_name not in self.include_tools:
156
180
  return
@@ -168,14 +192,19 @@ class Toolkit:
168
192
  stop_after_tool_call=tool_name in self.stop_after_tool_call_tools,
169
193
  show_result=tool_name in self.show_result_tools or tool_name in self.stop_after_tool_call_tools,
170
194
  )
171
- self.functions[f.name] = f
172
- log_debug(f"Function: {f.name} registered with {self.name}")
195
+
196
+ if is_async:
197
+ self.async_functions[f.name] = f
198
+ log_debug(f"Async function: {f.name} registered with {self.name}")
199
+ else:
200
+ self.functions[f.name] = f
201
+ log_debug(f"Function: {f.name} registered with {self.name}")
173
202
  except Exception as e:
174
203
  func_name = self._get_tool_name(function)
175
204
  logger.warning(f"Failed to create Function for: {func_name}")
176
205
  raise e
177
206
 
178
- def _register_decorated_tool(self, function: Function, name: Optional[str] = None) -> None:
207
+ def _register_decorated_tool(self, function: Function, name: Optional[str] = None, is_async: bool = False) -> None:
179
208
  """Register a Function object from @tool decorator, binding it to self.
180
209
 
181
210
  When @tool decorator is used on a class method, it creates a Function with an unbound
@@ -185,6 +214,7 @@ class Toolkit:
185
214
  Args:
186
215
  function: The Function object from @tool decorator
187
216
  name: Optional custom name override
217
+ is_async: If True, register to async_functions dict instead of functions
188
218
  """
189
219
  import inspect
190
220
 
@@ -207,14 +237,24 @@ class Toolkit:
207
237
 
208
238
  if params and params[0] == "self":
209
239
  # Create a bound method by wrapping the function to include self
210
- def make_bound_method(func, instance):
211
- def bound(*args, **kwargs):
212
- return func(instance, *args, **kwargs)
240
+ if is_async:
241
+
242
+ def make_bound_method(func, instance):
243
+ async def bound(*args, **kwargs):
244
+ return await func(instance, *args, **kwargs)
245
+
246
+ bound.__name__ = getattr(func, "__name__", tool_name)
247
+ bound.__doc__ = getattr(func, "__doc__", None)
248
+ return bound
249
+ else:
250
+
251
+ def make_bound_method(func, instance):
252
+ def bound(*args, **kwargs):
253
+ return func(instance, *args, **kwargs)
213
254
 
214
- # Preserve function metadata for debugging
215
- bound.__name__ = getattr(func, "__name__", tool_name)
216
- bound.__doc__ = getattr(func, "__doc__", None)
217
- return bound
255
+ bound.__name__ = getattr(func, "__name__", tool_name)
256
+ bound.__doc__ = getattr(func, "__doc__", None)
257
+ return bound
218
258
 
219
259
  bound_method = make_bound_method(original_func, self)
220
260
  else:
@@ -251,8 +291,36 @@ class Toolkit:
251
291
  cache_dir=function.cache_dir if function.cache_dir else self.cache_dir,
252
292
  cache_ttl=function.cache_ttl if function.cache_ttl != 3600 else self.cache_ttl,
253
293
  )
254
- self.functions[f.name] = f
255
- log_debug(f"Function: {f.name} registered with {self.name} (from @tool decorator)")
294
+
295
+ if is_async:
296
+ self.async_functions[f.name] = f
297
+ log_debug(f"Async function: {f.name} registered with {self.name} (from @tool decorator)")
298
+ else:
299
+ self.functions[f.name] = f
300
+ log_debug(f"Function: {f.name} registered with {self.name} (from @tool decorator)")
301
+
302
def get_functions(self) -> Dict[str, Function]:
    """Return the registry of non-coroutine tools.

    This is ``self.functions`` itself (not a copy), which ``register`` fills with every
    callable it does not detect as async.

    Returns:
        Dict of function name to Function for sync execution
    """
    sync_registry: Dict[str, Function] = self.functions
    return sync_registry
309
+
310
def get_async_functions(self) -> Dict[str, Function]:
    """Return the tool registry to use from async code paths.

    Builds a new ordered dict that starts from every sync function. A same-named entry in
    ``self.async_functions`` replaces its sync counterpart in place, and async-only tools are
    appended after the sync ones. Neither source dict is modified.

    Returns:
        Dict of function name to Function, with async variants preferred
    """
    return OrderedDict({**self.functions, **self.async_functions})
256
324
 
257
325
  @property
258
326
  def requires_connect(self) -> bool:
agno/utils/mcp.py CHANGED
@@ -1,5 +1,6 @@
1
1
  import json
2
2
  from functools import partial
3
+ from typing import TYPE_CHECKING, Optional, Union
3
4
  from uuid import uuid4
4
5
 
5
6
  from agno.utils.log import log_debug, log_exception
@@ -15,28 +16,68 @@ except (ImportError, ModuleNotFoundError):
15
16
  from agno.media import Image
16
17
  from agno.tools.function import ToolResult
17
18
 
19
+ if TYPE_CHECKING:
20
+ from agno.agent import Agent
21
+ from agno.run import RunContext
22
+ from agno.team.team import Team
23
+ from agno.tools.mcp.mcp import MCPTools
24
+ from agno.tools.mcp.multi_mcp import MultiMCPTools
18
25
 
19
- def get_entrypoint_for_tool(tool: MCPTool, session: ClientSession):
26
+
27
+ def get_entrypoint_for_tool(
28
+ tool: MCPTool,
29
+ session: ClientSession,
30
+ mcp_tools_instance: Optional[Union["MCPTools", "MultiMCPTools"]] = None,
31
+ server_idx: int = 0,
32
+ ):
20
33
  """
21
34
  Return an entrypoint for an MCP tool.
22
35
 
23
36
  Args:
24
37
  tool: The MCP tool to create an entrypoint for
25
- session: The session to use
38
+ session: The MCP ClientSession to use
39
+ mcp_tools_instance: Optional MCPTools or MultiMCPTools instance
40
+ server_idx: Index of the server (for MultiMCPTools)
26
41
 
27
42
  Returns:
28
43
  Callable: The entrypoint function for the tool
29
44
  """
30
45
 
31
- async def call_tool(tool_name: str, **kwargs) -> ToolResult:
46
+ async def call_tool(
47
+ tool_name: str,
48
+ run_context: Optional["RunContext"] = None,
49
+ agent: Optional["Agent"] = None,
50
+ team: Optional["Team"] = None,
51
+ **kwargs,
52
+ ) -> ToolResult:
53
+ # Execute the MCP tool call
32
54
  try:
33
- await session.send_ping()
34
- except Exception as e:
35
- log_exception(e)
55
+ # Get the appropriate session for this run
56
+ # If mcp_tools_instance has header_provider and run_context is provided,
57
+ # this will create/reuse a session with dynamic headers
58
+ if mcp_tools_instance and hasattr(mcp_tools_instance, "get_session_for_run"):
59
+ # Import here to avoid circular imports
60
+ from agno.tools.mcp.multi_mcp import MultiMCPTools
61
+
62
+ # For MultiMCPTools, pass server_idx; for MCPTools, only pass run_context
63
+ if isinstance(mcp_tools_instance, MultiMCPTools):
64
+ active_session = await mcp_tools_instance.get_session_for_run(
65
+ run_context=run_context, server_idx=server_idx, agent=agent, team=team
66
+ )
67
+ else:
68
+ active_session = await mcp_tools_instance.get_session_for_run(
69
+ run_context=run_context, agent=agent, team=team
70
+ )
71
+ else:
72
+ active_session = session
73
+
74
+ try:
75
+ await active_session.send_ping()
76
+ except Exception as e:
77
+ log_exception(e)
36
78
 
37
- try:
38
79
  log_debug(f"Calling MCP Tool '{tool_name}' with args: {kwargs}")
39
- result: CallToolResult = await session.call_tool(tool_name, kwargs) # type: ignore
80
+ result: CallToolResult = await active_session.call_tool(tool_name, kwargs) # type: ignore
40
81
 
41
82
  # Return an error if the tool call failed
42
83
  if result.isError: