google-adk 1.2.1__py3-none-any.whl → 1.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (91)
  1. google/adk/a2a/__init__.py +13 -0
  2. google/adk/a2a/converters/__init__.py +13 -0
  3. google/adk/a2a/converters/part_converter.py +177 -0
  4. google/adk/agents/invocation_context.py +2 -0
  5. google/adk/agents/llm_agent.py +1 -6
  6. google/adk/agents/run_config.py +11 -0
  7. google/adk/auth/auth_credential.py +4 -0
  8. google/adk/auth/auth_handler.py +22 -96
  9. google/adk/auth/auth_preprocessor.py +3 -3
  10. google/adk/auth/auth_tool.py +46 -0
  11. google/adk/auth/credential_manager.py +261 -0
  12. google/adk/auth/credential_service/__init__.py +13 -0
  13. google/adk/auth/credential_service/base_credential_service.py +75 -0
  14. google/adk/auth/credential_service/in_memory_credential_service.py +64 -0
  15. google/adk/auth/exchanger/__init__.py +21 -0
  16. google/adk/auth/exchanger/base_credential_exchanger.py +57 -0
  17. google/adk/auth/exchanger/credential_exchanger_registry.py +58 -0
  18. google/adk/auth/exchanger/oauth2_credential_exchanger.py +104 -0
  19. google/adk/auth/oauth2_credential_util.py +107 -0
  20. google/adk/auth/refresher/__init__.py +21 -0
  21. google/adk/auth/refresher/base_credential_refresher.py +74 -0
  22. google/adk/auth/refresher/credential_refresher_registry.py +59 -0
  23. google/adk/auth/refresher/oauth2_credential_refresher.py +126 -0
  24. google/adk/cli/agent_graph.py +34 -32
  25. google/adk/cli/browser/index.html +2 -2
  26. google/adk/cli/browser/main-JAAWEV7F.js +92 -0
  27. google/adk/cli/browser/polyfills-B6TNHZQ6.js +17 -0
  28. google/adk/cli/cli.py +10 -0
  29. google/adk/cli/cli_deploy.py +80 -21
  30. google/adk/cli/cli_tools_click.py +132 -61
  31. google/adk/cli/fast_api.py +46 -41
  32. google/adk/cli/utils/agent_loader.py +15 -2
  33. google/adk/code_executors/container_code_executor.py +10 -6
  34. google/adk/code_executors/vertex_ai_code_executor.py +8 -2
  35. google/adk/evaluation/_eval_set_results_manager_utils.py +44 -0
  36. google/adk/evaluation/_eval_sets_manager_utils.py +108 -0
  37. google/adk/evaluation/eval_metrics.py +0 -5
  38. google/adk/evaluation/eval_result.py +12 -7
  39. google/adk/evaluation/eval_set_results_manager.py +6 -1
  40. google/adk/evaluation/gcs_eval_set_results_manager.py +121 -0
  41. google/adk/evaluation/gcs_eval_sets_manager.py +196 -0
  42. google/adk/evaluation/local_eval_set_results_manager.py +6 -18
  43. google/adk/evaluation/local_eval_sets_manager.py +27 -78
  44. google/adk/flows/llm_flows/basic.py +9 -0
  45. google/adk/flows/llm_flows/functions.py +1 -2
  46. google/adk/models/anthropic_llm.py +1 -1
  47. google/adk/models/gemini_llm_connection.py +2 -0
  48. google/adk/models/google_llm.py +57 -16
  49. google/adk/models/lite_llm.py +2 -1
  50. google/adk/platform/__init__.py +13 -0
  51. google/adk/platform/internal/__init__.py +15 -0
  52. google/adk/platform/internal/thread.py +30 -0
  53. google/adk/platform/thread.py +31 -0
  54. google/adk/runners.py +8 -2
  55. google/adk/sessions/in_memory_session_service.py +12 -1
  56. google/adk/sessions/vertex_ai_session_service.py +71 -50
  57. google/adk/tools/__init__.py +2 -0
  58. google/adk/tools/_automatic_function_calling_util.py +1 -0
  59. google/adk/tools/_forwarding_artifact_service.py +96 -0
  60. google/adk/tools/_function_parameter_parse_util.py +1 -0
  61. google/adk/tools/agent_tool.py +5 -39
  62. google/adk/tools/application_integration_tool/integration_connector_tool.py +2 -2
  63. google/adk/tools/authenticated_function_tool.py +107 -0
  64. google/adk/tools/base_authenticated_tool.py +107 -0
  65. google/adk/tools/bigquery/bigquery_credentials.py +6 -4
  66. google/adk/tools/bigquery/bigquery_tool.py +22 -9
  67. google/adk/tools/bigquery/bigquery_toolset.py +9 -3
  68. google/adk/tools/bigquery/client.py +7 -3
  69. google/adk/tools/bigquery/config.py +46 -0
  70. google/adk/tools/bigquery/metadata_tool.py +114 -91
  71. google/adk/tools/bigquery/query_tool.py +141 -23
  72. google/adk/tools/google_api_tool/googleapi_to_openapi_converter.py +7 -4
  73. google/adk/tools/google_search_tool.py +0 -1
  74. google/adk/tools/mcp_tool/__init__.py +6 -0
  75. google/adk/tools/mcp_tool/mcp_session_manager.py +271 -149
  76. google/adk/tools/mcp_tool/mcp_tool.py +73 -22
  77. google/adk/tools/mcp_tool/mcp_toolset.py +32 -29
  78. google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py +3 -3
  79. google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py +55 -33
  80. google/adk/tools/retrieval/files_retrieval.py +7 -1
  81. google/adk/tools/url_context_tool.py +61 -0
  82. google/adk/tools/vertex_ai_search_tool.py +13 -2
  83. google/adk/utils/feature_decorator.py +175 -0
  84. google/adk/version.py +1 -1
  85. {google_adk-1.2.1.dist-info → google_adk-1.4.1.dist-info}/METADATA +10 -2
  86. {google_adk-1.2.1.dist-info → google_adk-1.4.1.dist-info}/RECORD +89 -59
  87. google/adk/cli/browser/main-CS5OLUMF.js +0 -91
  88. google/adk/cli/browser/polyfills-FFHMD2TL.js +0 -17
  89. {google_adk-1.2.1.dist-info → google_adk-1.4.1.dist-info}/WHEEL +0 -0
  90. {google_adk-1.2.1.dist-info → google_adk-1.4.1.dist-info}/entry_points.txt +0 -0
  91. {google_adk-1.2.1.dist-info → google_adk-1.4.1.dist-info}/licenses/LICENSE +0 -0
@@ -14,12 +14,16 @@
14
14
 
15
15
  from __future__ import annotations
16
16
 
17
+ import asyncio
17
18
  from contextlib import AsyncExitStack
18
19
  from datetime import timedelta
19
20
  import functools
21
+ import hashlib
22
+ import json
20
23
  import logging
21
24
  import sys
22
25
  from typing import Any
26
+ from typing import Dict
23
27
  from typing import Optional
24
28
  from typing import TextIO
25
29
  from typing import Union
@@ -34,7 +38,6 @@ try:
34
38
  from mcp.client.stdio import stdio_client
35
39
  from mcp.client.streamable_http import streamablehttp_client
36
40
  except ImportError as e:
37
- import sys
38
41
 
39
42
  if sys.version_info < (3, 10):
40
43
  raise ImportError(
@@ -47,102 +50,106 @@ except ImportError as e:
47
50
  logger = logging.getLogger('google_adk.' + __name__)
48
51
 
49
52
 
50
- class SseServerParams(BaseModel):
53
+ class StdioConnectionParams(BaseModel):
54
+ """Parameters for the MCP Stdio connection.
55
+
56
+ Attributes:
57
+ server_params: Parameters for the MCP Stdio server.
58
+ timeout: Timeout in seconds for establishing the connection to the MCP
59
+ stdio server.
60
+ """
61
+
62
+ server_params: StdioServerParameters
63
+ timeout: float = 5.0
64
+
65
+
66
+ class SseConnectionParams(BaseModel):
51
67
  """Parameters for the MCP SSE connection.
52
68
 
53
69
  See MCP SSE Client documentation for more details.
54
70
  https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/client/sse.py
71
+
72
+ Attributes:
73
+ url: URL for the MCP SSE server.
74
+ headers: Headers for the MCP SSE connection.
75
+ timeout: Timeout in seconds for establishing the connection to the MCP SSE
76
+ server.
77
+ sse_read_timeout: Timeout in seconds for reading data from the MCP SSE
78
+ server.
55
79
  """
56
80
 
57
81
  url: str
58
82
  headers: dict[str, Any] | None = None
59
- timeout: float = 5
60
- sse_read_timeout: float = 60 * 5
83
+ timeout: float = 5.0
84
+ sse_read_timeout: float = 60 * 5.0
61
85
 
62
86
 
63
- class StreamableHTTPServerParams(BaseModel):
87
+ class StreamableHTTPConnectionParams(BaseModel):
64
88
  """Parameters for the MCP Streamable HTTP connection.
65
89
 
66
90
  See MCP Streamable HTTP Client documentation for more details.
67
91
  https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/client/streamable_http.py
92
+
93
+ Attributes:
94
+ url: URL for the MCP Streamable HTTP server.
95
+ headers: Headers for the MCP Streamable HTTP connection.
96
+ timeout: Timeout in seconds for establishing the connection to the MCP
97
+ Streamable HTTP server.
98
+ sse_read_timeout: Timeout in seconds for reading data from the MCP
99
+ Streamable HTTP server.
100
+ terminate_on_close: Whether to terminate the MCP Streamable HTTP server
101
+ when the connection is closed.
68
102
  """
69
103
 
70
104
  url: str
71
105
  headers: dict[str, Any] | None = None
72
- timeout: float = 5
73
- sse_read_timeout: float = 60 * 5
106
+ timeout: float = 5.0
107
+ sse_read_timeout: float = 60 * 5.0
74
108
  terminate_on_close: bool = True
75
109
 
76
110
 
77
- def retry_on_closed_resource(async_reinit_func_name: str):
78
- """Decorator to automatically reinitialize session and retry action.
79
-
80
- When MCP session was closed, the decorator will automatically recreate the
81
- session and retry the action with the same parameters.
82
-
83
- Note:
84
- 1. async_reinit_func_name is the name of the class member function that
85
- reinitializes the MCP session.
86
- 2. Both the decorated function and the async_reinit_func_name must be async
87
- functions.
111
+ def retry_on_closed_resource(func):
112
+ """Decorator to automatically retry action when MCP session is closed.
88
113
 
89
- Usage:
90
- class MCPTool:
91
- ...
92
- async def create_session(self):
93
- self.session = ...
94
-
95
- @retry_on_closed_resource('create_session')
96
- async def use_session(self):
97
- await self.session.call_tool()
114
+ When MCP session was closed, the decorator will automatically retry the
115
+ action once. The create_session method will handle creating a new session
116
+ if the old one was disconnected.
98
117
 
99
118
  Args:
100
- async_reinit_func_name: The name of the async function to recreate session.
119
+ func: The function to decorate.
101
120
 
102
121
  Returns:
103
122
  The decorated function.
104
123
  """
105
124
 
106
- def decorator(func):
107
- @functools.wraps(func) # Preserves original function metadata
108
- async def wrapper(self, *args, **kwargs):
109
- try:
110
- return await func(self, *args, **kwargs)
111
- except anyio.ClosedResourceError as close_err:
112
- try:
113
- if hasattr(self, async_reinit_func_name) and callable(
114
- getattr(self, async_reinit_func_name)
115
- ):
116
- async_init_fn = getattr(self, async_reinit_func_name)
117
- await async_init_fn()
118
- else:
119
- raise ValueError(
120
- f'Function {async_reinit_func_name} does not exist in decorated'
121
- ' class. Please check the function name in'
122
- ' retry_on_closed_resource decorator.'
123
- ) from close_err
124
- except Exception as reinit_err:
125
- raise RuntimeError(
126
- f'Error reinitializing: {reinit_err}'
127
- ) from reinit_err
128
- return await func(self, *args, **kwargs)
129
-
130
- return wrapper
131
-
132
- return decorator
125
+ @functools.wraps(func) # Preserves original function metadata
126
+ async def wrapper(self, *args, **kwargs):
127
+ try:
128
+ return await func(self, *args, **kwargs)
129
+ except anyio.ClosedResourceError:
130
+ # Simply retry the function - create_session will handle
131
+ # detecting and replacing disconnected sessions
132
+ logger.info('Retrying %s due to closed resource', func.__name__)
133
+ return await func(self, *args, **kwargs)
134
+
135
+ return wrapper
133
136
 
134
137
 
135
138
  class MCPSessionManager:
136
139
  """Manages MCP client sessions.
137
140
 
138
141
  This class provides methods for creating and initializing MCP client sessions,
139
- handling different connection parameters (Stdio and SSE).
142
+ handling different connection parameters (Stdio and SSE) and supporting
143
+ session pooling based on authentication headers.
140
144
  """
141
145
 
142
146
  def __init__(
143
147
  self,
144
148
  connection_params: Union[
145
- StdioServerParameters, SseServerParams, StreamableHTTPServerParams
149
+ StdioServerParameters,
150
+ StdioConnectionParams,
151
+ SseConnectionParams,
152
+ StreamableHTTPConnectionParams,
146
153
  ],
147
154
  errlog: TextIO = sys.stderr,
148
155
  ):
@@ -155,105 +162,220 @@ class MCPSessionManager:
155
162
  errlog: (Optional) TextIO stream for error logging. Use only for
156
163
  initializing a local stdio MCP session.
157
164
  """
158
- self._connection_params = connection_params
165
+ if isinstance(connection_params, StdioServerParameters):
166
+ # So far timeout is not configurable. Given MCP is still evolving, we
167
+ # would expect stdio_client to evolve to accept timeout parameter like
168
+ # other client.
169
+ logger.warning(
170
+ 'StdioServerParameters is not recommended. Please use'
171
+ ' StdioConnectionParams.'
172
+ )
173
+ self._connection_params = StdioConnectionParams(
174
+ server_params=connection_params,
175
+ timeout=5,
176
+ )
177
+ else:
178
+ self._connection_params = connection_params
159
179
  self._errlog = errlog
160
- # Each session manager maintains its own exit stack for proper cleanup
161
- self._exit_stack: Optional[AsyncExitStack] = None
162
- self._session: Optional[ClientSession] = None
163
180
 
164
- async def create_session(self) -> ClientSession:
181
+ # Session pool: maps session keys to (session, exit_stack) tuples
182
+ self._sessions: Dict[str, tuple[ClientSession, AsyncExitStack]] = {}
183
+
184
+ # Lock to prevent race conditions in session creation
185
+ self._session_lock = asyncio.Lock()
186
+
187
+ def _generate_session_key(
188
+ self, merged_headers: Optional[Dict[str, str]] = None
189
+ ) -> str:
190
+ """Generates a session key based on connection params and merged headers.
191
+
192
+ For StdioConnectionParams, returns a constant key since headers are not
193
+ supported. For SSE and StreamableHTTP connections, generates a key based
194
+ on the provided merged headers.
195
+
196
+ Args:
197
+ merged_headers: Already merged headers (base + additional).
198
+
199
+ Returns:
200
+ A unique session key string.
201
+ """
202
+ if isinstance(self._connection_params, StdioConnectionParams):
203
+ # For stdio connections, headers are not supported, so use constant key
204
+ return 'stdio_session'
205
+
206
+ # For SSE and StreamableHTTP connections, use merged headers
207
+ if merged_headers:
208
+ headers_json = json.dumps(merged_headers, sort_keys=True)
209
+ headers_hash = hashlib.md5(headers_json.encode()).hexdigest()
210
+ return f'session_{headers_hash}'
211
+ else:
212
+ return 'session_no_headers'
213
+
214
+ def _merge_headers(
215
+ self, additional_headers: Optional[Dict[str, str]] = None
216
+ ) -> Optional[Dict[str, str]]:
217
+ """Merges base connection headers with additional headers.
218
+
219
+ Args:
220
+ additional_headers: Optional headers to merge with connection headers.
221
+
222
+ Returns:
223
+ Merged headers dictionary, or None if no headers are provided.
224
+ """
225
+ if isinstance(self._connection_params, StdioConnectionParams) or isinstance(
226
+ self._connection_params, StdioServerParameters
227
+ ):
228
+ # Stdio connections don't support headers
229
+ return None
230
+
231
+ base_headers = {}
232
+ if (
233
+ hasattr(self._connection_params, 'headers')
234
+ and self._connection_params.headers
235
+ ):
236
+ base_headers = self._connection_params.headers.copy()
237
+
238
+ if additional_headers:
239
+ base_headers.update(additional_headers)
240
+
241
+ return base_headers
242
+
243
+ def _is_session_disconnected(self, session: ClientSession) -> bool:
244
+ """Checks if a session is disconnected or closed.
245
+
246
+ Args:
247
+ session: The ClientSession to check.
248
+
249
+ Returns:
250
+ True if the session is disconnected, False otherwise.
251
+ """
252
+ return session._read_stream._closed or session._write_stream._closed
253
+
254
+ async def create_session(
255
+ self, headers: Optional[Dict[str, str]] = None
256
+ ) -> ClientSession:
165
257
  """Creates and initializes an MCP client session.
166
258
 
259
+ This method will check if an existing session for the given headers
260
+ is still connected. If it's disconnected, it will be cleaned up and
261
+ a new session will be created.
262
+
263
+ Args:
264
+ headers: Optional headers to include in the session. These will be
265
+ merged with any existing connection headers. Only applicable
266
+ for SSE and StreamableHTTP connections.
267
+
167
268
  Returns:
168
269
  ClientSession: The initialized MCP client session.
169
270
  """
170
- if self._session is not None:
171
- return self._session
172
-
173
- # Create a new exit stack for this session
174
- self._exit_stack = AsyncExitStack()
271
+ # Merge headers once at the beginning
272
+ merged_headers = self._merge_headers(headers)
273
+
274
+ # Generate session key using merged headers
275
+ session_key = self._generate_session_key(merged_headers)
276
+
277
+ # Use async lock to prevent race conditions
278
+ async with self._session_lock:
279
+ # Check if we have an existing session
280
+ if session_key in self._sessions:
281
+ session, exit_stack = self._sessions[session_key]
282
+
283
+ # Check if the existing session is still connected
284
+ if not self._is_session_disconnected(session):
285
+ # Session is still good, return it
286
+ return session
287
+ else:
288
+ # Session is disconnected, clean it up
289
+ logger.info('Cleaning up disconnected session: %s', session_key)
290
+ try:
291
+ await exit_stack.aclose()
292
+ except Exception as e:
293
+ logger.warning('Error during disconnected session cleanup: %s', e)
294
+ finally:
295
+ del self._sessions[session_key]
296
+
297
+ # Create a new session (either first time or replacing disconnected one)
298
+ exit_stack = AsyncExitStack()
175
299
 
176
- try:
177
- if isinstance(self._connection_params, StdioServerParameters):
178
- # So far timeout is not configurable. Given MCP is still evolving, we
179
- # would expect stdio_client to evolve to accept timeout parameter like
180
- # other client.
181
- client = stdio_client(
182
- server=self._connection_params, errlog=self._errlog
183
- )
184
- elif isinstance(self._connection_params, SseServerParams):
185
- client = sse_client(
186
- url=self._connection_params.url,
187
- headers=self._connection_params.headers,
188
- timeout=self._connection_params.timeout,
189
- sse_read_timeout=self._connection_params.sse_read_timeout,
190
- )
191
- elif isinstance(self._connection_params, StreamableHTTPServerParams):
192
- client = streamablehttp_client(
193
- url=self._connection_params.url,
194
- headers=self._connection_params.headers,
195
- timeout=timedelta(seconds=self._connection_params.timeout),
196
- sse_read_timeout=timedelta(
197
- seconds=self._connection_params.sse_read_timeout
198
- ),
199
- terminate_on_close=self._connection_params.terminate_on_close,
200
- )
201
- else:
202
- raise ValueError(
203
- 'Unable to initialize connection. Connection should be'
204
- ' StdioServerParameters or SseServerParams, but got'
205
- f' {self._connection_params}'
206
- )
207
-
208
- transports = await self._exit_stack.enter_async_context(client)
209
- # The streamable http client returns a GetSessionCallback in addition to the read/write MemoryObjectStreams
210
- # needed to build the ClientSession, we limit then to the two first values to be compatible with all clients.
211
- # The StdioServerParameters does not provide a timeout parameter for the
212
- # session, so we need to set a default timeout for it. Other clients
213
- # (SseServerParams and StreamableHTTPServerParams) already provide a
214
- # timeout parameter in their configuration.
215
- if isinstance(self._connection_params, StdioServerParameters):
216
- # Default timeout for MCP session is 5 seconds, same as SseServerParams
217
- # and StreamableHTTPServerParams.
218
- # TODO :
219
- # 1. make timeout configurable
220
- # 2. Add StdioConnectionParams to include StdioServerParameters as a
221
- # field and rename other two params to XXXXConnetionParams. Ohter
222
- # two params are actually connection params, while stdio is
223
- # special, stdio_client takes the resposibility of starting the
224
- # server and working as a client.
225
- session = await self._exit_stack.enter_async_context(
226
- ClientSession(
227
- *transports[:2],
228
- read_timeout_seconds=timedelta(seconds=5),
229
- )
230
- )
231
- else:
232
- session = await self._exit_stack.enter_async_context(
233
- ClientSession(*transports[:2])
234
- )
235
- await session.initialize()
236
-
237
- self._session = session
238
- return session
239
-
240
- except Exception:
241
- # If session creation fails, clean up the exit stack
242
- if self._exit_stack:
243
- await self._exit_stack.aclose()
244
- self._exit_stack = None
245
- raise
300
+ try:
301
+ if isinstance(self._connection_params, StdioConnectionParams):
302
+ client = stdio_client(
303
+ server=self._connection_params.server_params,
304
+ errlog=self._errlog,
305
+ )
306
+ elif isinstance(self._connection_params, SseConnectionParams):
307
+ client = sse_client(
308
+ url=self._connection_params.url,
309
+ headers=merged_headers,
310
+ timeout=self._connection_params.timeout,
311
+ sse_read_timeout=self._connection_params.sse_read_timeout,
312
+ )
313
+ elif isinstance(
314
+ self._connection_params, StreamableHTTPConnectionParams
315
+ ):
316
+ client = streamablehttp_client(
317
+ url=self._connection_params.url,
318
+ headers=merged_headers,
319
+ timeout=timedelta(seconds=self._connection_params.timeout),
320
+ sse_read_timeout=timedelta(
321
+ seconds=self._connection_params.sse_read_timeout
322
+ ),
323
+ terminate_on_close=self._connection_params.terminate_on_close,
324
+ )
325
+ else:
326
+ raise ValueError(
327
+ 'Unable to initialize connection. Connection should be'
328
+ ' StdioServerParameters or SseServerParams, but got'
329
+ f' {self._connection_params}'
330
+ )
331
+
332
+ transports = await exit_stack.enter_async_context(client)
333
+ # The streamable http client returns a GetSessionCallback in addition to the read/write MemoryObjectStreams
334
+ # needed to build the ClientSession, so we limit them to the first two values to be compatible with all clients.
335
+ if isinstance(self._connection_params, StdioConnectionParams):
336
+ session = await exit_stack.enter_async_context(
337
+ ClientSession(
338
+ *transports[:2],
339
+ read_timeout_seconds=timedelta(
340
+ seconds=self._connection_params.timeout
341
+ ),
342
+ )
343
+ )
344
+ else:
345
+ session = await exit_stack.enter_async_context(
346
+ ClientSession(*transports[:2])
347
+ )
348
+ await session.initialize()
349
+
350
+ # Store session and exit stack in the pool
351
+ self._sessions[session_key] = (session, exit_stack)
352
+ logger.debug('Created new session: %s', session_key)
353
+ return session
354
+
355
+ except Exception:
356
+ # If session creation fails, clean up the exit stack
357
+ if exit_stack:
358
+ await exit_stack.aclose()
359
+ raise
246
360
 
247
361
  async def close(self):
248
- """Closes the session and cleans up resources."""
249
- if self._exit_stack:
250
- try:
251
- await self._exit_stack.aclose()
252
- except Exception as e:
253
- # Log the error but don't re-raise to avoid blocking shutdown
254
- print(
255
- f'Warning: Error during MCP session cleanup: {e}', file=self._errlog
256
- )
257
- finally:
258
- self._exit_stack = None
259
- self._session = None
362
+ """Closes all sessions and cleans up resources."""
363
+ async with self._session_lock:
364
+ for session_key in list(self._sessions.keys()):
365
+ _, exit_stack = self._sessions[session_key]
366
+ try:
367
+ await exit_stack.aclose()
368
+ except Exception as e:
369
+ # Log the error but don't re-raise to avoid blocking shutdown
370
+ print(
371
+ 'Warning: Error during MCP session cleanup for'
372
+ f' {session_key}: {e}',
373
+ file=self._errlog,
374
+ )
375
+ finally:
376
+ del self._sessions[session_key]
377
+
378
+
379
+ SseServerParams = SseConnectionParams
380
+
381
+ StreamableHTTPServerParams = StreamableHTTPConnectionParams
@@ -14,10 +14,13 @@
14
14
 
15
15
  from __future__ import annotations
16
16
 
17
+ import base64
18
+ import json
17
19
  import logging
18
20
  from typing import Optional
19
21
 
20
22
  from google.genai.types import FunctionDeclaration
23
+ from google.oauth2.credentials import Credentials
21
24
  from typing_extensions import override
22
25
 
23
26
  from .._gemini_schema_util import _to_gemini_schema
@@ -42,14 +45,16 @@ except ImportError as e:
42
45
 
43
46
  from ...auth.auth_credential import AuthCredential
44
47
  from ...auth.auth_schemes import AuthScheme
45
- from ..base_tool import BaseTool
48
+ from ...auth.auth_tool import AuthConfig
49
+ from ..base_authenticated_tool import BaseAuthenticatedTool
50
+ # import
46
51
  from ..tool_context import ToolContext
47
52
 
48
53
  logger = logging.getLogger("google_adk." + __name__)
49
54
 
50
55
 
51
- class MCPTool(BaseTool):
52
- """Turns a MCP Tool into a Vertex Agent Framework Tool.
56
+ class MCPTool(BaseAuthenticatedTool):
57
+ """Turns an MCP Tool into an ADK Tool.
53
58
 
54
59
  Internally, the tool initializes from a MCP Tool, and uses the MCP Session to
55
60
  call the tool.
@@ -63,9 +68,9 @@ class MCPTool(BaseTool):
63
68
  auth_scheme: Optional[AuthScheme] = None,
64
69
  auth_credential: Optional[AuthCredential] = None,
65
70
  ):
66
- """Initializes a MCPTool.
71
+ """Initializes an MCPTool.
67
72
 
68
- This tool wraps a MCP Tool interface and uses a session manager to
73
+ This tool wraps an MCP Tool interface and uses a session manager to
69
74
  communicate with the MCP server.
70
75
 
71
76
  Args:
@@ -77,19 +82,17 @@ class MCPTool(BaseTool):
77
82
  Raises:
78
83
  ValueError: If mcp_tool or mcp_session_manager is None.
79
84
  """
80
- if mcp_tool is None:
81
- raise ValueError("mcp_tool cannot be None")
82
- if mcp_session_manager is None:
83
- raise ValueError("mcp_session_manager cannot be None")
84
85
  super().__init__(
85
86
  name=mcp_tool.name,
86
87
  description=mcp_tool.description if mcp_tool.description else "",
88
+ auth_config=AuthConfig(
89
+ auth_scheme=auth_scheme, raw_auth_credential=auth_credential
90
+ )
91
+ if auth_scheme
92
+ else None,
87
93
  )
88
94
  self._mcp_tool = mcp_tool
89
95
  self._mcp_session_manager = mcp_session_manager
90
- # TODO(cheliu): Support passing auth to MCP Server.
91
- self._auth_scheme = auth_scheme
92
- self._auth_credential = auth_credential
93
96
 
94
97
  @override
95
98
  def _get_declaration(self) -> FunctionDeclaration:
@@ -105,26 +108,74 @@ class MCPTool(BaseTool):
105
108
  )
106
109
  return function_decl
107
110
 
108
- @retry_on_closed_resource("_reinitialize_session")
109
- async def run_async(self, *, args, tool_context: ToolContext):
111
+ @retry_on_closed_resource
112
+ @override
113
+ async def _run_async_impl(
114
+ self, *, args, tool_context: ToolContext, credential: AuthCredential
115
+ ):
110
116
  """Runs the tool asynchronously.
111
117
 
112
118
  Args:
113
119
  args: The arguments as a dict to pass to the tool.
114
- tool_context: The tool context from upper level ADK agent.
120
+ tool_context: The tool context of the current invocation.
115
121
 
116
122
  Returns:
117
123
  Any: The response from the tool.
118
124
  """
125
+ # Extract headers from credential for session pooling
126
+ headers = await self._get_headers(tool_context, credential)
127
+
119
128
  # Get the session from the session manager
120
- session = await self._mcp_session_manager.create_session()
129
+ session = await self._mcp_session_manager.create_session(headers=headers)
121
130
 
122
- # TODO(cheliu): Support passing tool context to MCP Server.
123
131
  response = await session.call_tool(self.name, arguments=args)
124
132
  return response
125
133
 
126
- async def _reinitialize_session(self):
127
- """Reinitializes the session when connection is lost."""
128
- # Close the old session and create a new one
129
- await self._mcp_session_manager.close()
130
- await self._mcp_session_manager.create_session()
134
+ async def _get_headers(
135
+ self, tool_context: ToolContext, credential: AuthCredential
136
+ ) -> Optional[dict[str, str]]:
137
+ headers = None
138
+ if credential:
139
+ if credential.oauth2:
140
+ headers = {"Authorization": f"Bearer {credential.oauth2.access_token}"}
141
+ elif credential.http:
142
+ # Handle HTTP authentication schemes
143
+ if (
144
+ credential.http.scheme.lower() == "bearer"
145
+ and credential.http.credentials.token
146
+ ):
147
+ headers = {
148
+ "Authorization": f"Bearer {credential.http.credentials.token}"
149
+ }
150
+ elif credential.http.scheme.lower() == "basic":
151
+ # Handle basic auth
152
+ if (
153
+ credential.http.credentials.username
154
+ and credential.http.credentials.password
155
+ ):
156
+
157
+ credentials = f"{credential.http.credentials.username}:{credential.http.credentials.password}"
158
+ encoded_credentials = base64.b64encode(
159
+ credentials.encode()
160
+ ).decode()
161
+ headers = {"Authorization": f"Basic {encoded_credentials}"}
162
+ elif credential.http.credentials.token:
163
+ # Handle other HTTP schemes with token
164
+ headers = {
165
+ "Authorization": (
166
+ f"{credential.http.scheme} {credential.http.credentials.token}"
167
+ )
168
+ }
169
+ elif credential.api_key:
170
+ # For API keys, we'll add them as headers since MCP typically uses header-based auth
171
+ # The specific header name would depend on the API, using a common default
172
+ # TODO Allow user to specify the header name for API keys.
173
+ headers = {"X-API-Key": credential.api_key}
174
+ elif credential.service_account:
175
+ # Service accounts should be exchanged for access tokens before reaching this point
176
+ logger.warning(
177
+ "Service account credentials should be exchanged before MCP"
178
+ " session creation"
179
+ )
180
+
181
+ return headers