google-adk 1.2.1__py3-none-any.whl → 1.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/adk/a2a/__init__.py +13 -0
- google/adk/a2a/converters/__init__.py +13 -0
- google/adk/a2a/converters/part_converter.py +166 -0
- google/adk/agents/invocation_context.py +2 -0
- google/adk/agents/llm_agent.py +1 -6
- google/adk/agents/run_config.py +11 -0
- google/adk/auth/auth_credential.py +5 -0
- google/adk/auth/auth_handler.py +22 -96
- google/adk/auth/auth_preprocessor.py +3 -3
- google/adk/auth/auth_tool.py +46 -0
- google/adk/auth/credential_manager.py +265 -0
- google/adk/auth/credential_service/__init__.py +13 -0
- google/adk/auth/credential_service/base_credential_service.py +75 -0
- google/adk/auth/credential_service/in_memory_credential_service.py +64 -0
- google/adk/auth/exchanger/__init__.py +23 -0
- google/adk/auth/exchanger/base_credential_exchanger.py +57 -0
- google/adk/auth/exchanger/credential_exchanger_registry.py +58 -0
- google/adk/auth/exchanger/oauth2_credential_exchanger.py +104 -0
- google/adk/auth/exchanger/service_account_credential_exchanger.py +104 -0
- google/adk/auth/oauth2_credential_util.py +107 -0
- google/adk/auth/refresher/__init__.py +21 -0
- google/adk/auth/refresher/base_credential_refresher.py +74 -0
- google/adk/auth/refresher/credential_refresher_registry.py +59 -0
- google/adk/auth/refresher/oauth2_credential_refresher.py +154 -0
- google/adk/cli/agent_graph.py +34 -32
- google/adk/cli/browser/index.html +2 -2
- google/adk/cli/browser/main-JAAWEV7F.js +92 -0
- google/adk/cli/browser/polyfills-B6TNHZQ6.js +17 -0
- google/adk/cli/cli.py +10 -0
- google/adk/cli/cli_deploy.py +80 -21
- google/adk/cli/cli_tools_click.py +132 -61
- google/adk/cli/fast_api.py +46 -41
- google/adk/cli/utils/agent_loader.py +15 -2
- google/adk/code_executors/container_code_executor.py +10 -6
- google/adk/code_executors/vertex_ai_code_executor.py +8 -2
- google/adk/evaluation/_eval_set_results_manager_utils.py +44 -0
- google/adk/evaluation/_eval_sets_manager_utils.py +108 -0
- google/adk/evaluation/eval_metrics.py +0 -5
- google/adk/evaluation/eval_result.py +12 -7
- google/adk/evaluation/eval_set_results_manager.py +6 -1
- google/adk/evaluation/gcs_eval_set_results_manager.py +121 -0
- google/adk/evaluation/gcs_eval_sets_manager.py +196 -0
- google/adk/evaluation/local_eval_set_results_manager.py +6 -18
- google/adk/evaluation/local_eval_sets_manager.py +27 -78
- google/adk/flows/llm_flows/basic.py +9 -0
- google/adk/models/anthropic_llm.py +1 -1
- google/adk/models/gemini_llm_connection.py +2 -0
- google/adk/models/google_llm.py +57 -16
- google/adk/models/lite_llm.py +2 -1
- google/adk/platform/__init__.py +13 -0
- google/adk/platform/internal/__init__.py +15 -0
- google/adk/platform/internal/thread.py +30 -0
- google/adk/platform/thread.py +31 -0
- google/adk/runners.py +8 -2
- google/adk/sessions/in_memory_session_service.py +12 -1
- google/adk/sessions/vertex_ai_session_service.py +71 -50
- google/adk/tools/__init__.py +2 -0
- google/adk/tools/_automatic_function_calling_util.py +1 -0
- google/adk/tools/_forwarding_artifact_service.py +96 -0
- google/adk/tools/_function_parameter_parse_util.py +1 -0
- google/adk/tools/agent_tool.py +5 -39
- google/adk/tools/application_integration_tool/integration_connector_tool.py +2 -2
- google/adk/tools/authenticated_function_tool.py +107 -0
- google/adk/tools/base_authenticated_tool.py +107 -0
- google/adk/tools/bigquery/bigquery_credentials.py +6 -4
- google/adk/tools/bigquery/bigquery_tool.py +22 -9
- google/adk/tools/bigquery/bigquery_toolset.py +9 -3
- google/adk/tools/bigquery/client.py +7 -3
- google/adk/tools/bigquery/config.py +46 -0
- google/adk/tools/bigquery/metadata_tool.py +114 -91
- google/adk/tools/bigquery/query_tool.py +141 -23
- google/adk/tools/google_api_tool/googleapi_to_openapi_converter.py +7 -4
- google/adk/tools/google_search_tool.py +0 -1
- google/adk/tools/mcp_tool/__init__.py +6 -0
- google/adk/tools/mcp_tool/mcp_session_manager.py +271 -149
- google/adk/tools/mcp_tool/mcp_tool.py +79 -22
- google/adk/tools/mcp_tool/mcp_toolset.py +32 -29
- google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py +3 -3
- google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py +56 -33
- google/adk/tools/retrieval/files_retrieval.py +7 -1
- google/adk/tools/url_context_tool.py +61 -0
- google/adk/tools/vertex_ai_search_tool.py +13 -2
- google/adk/utils/feature_decorator.py +175 -0
- google/adk/version.py +1 -1
- {google_adk-1.2.1.dist-info → google_adk-1.4.0.dist-info}/METADATA +10 -2
- {google_adk-1.2.1.dist-info → google_adk-1.4.0.dist-info}/RECORD +89 -58
- google/adk/cli/browser/main-CS5OLUMF.js +0 -91
- google/adk/cli/browser/polyfills-FFHMD2TL.js +0 -17
- {google_adk-1.2.1.dist-info → google_adk-1.4.0.dist-info}/WHEEL +0 -0
- {google_adk-1.2.1.dist-info → google_adk-1.4.0.dist-info}/entry_points.txt +0 -0
- {google_adk-1.2.1.dist-info → google_adk-1.4.0.dist-info}/licenses/LICENSE +0 -0
@@ -14,12 +14,16 @@
|
|
14
14
|
|
15
15
|
from __future__ import annotations
|
16
16
|
|
17
|
+
import asyncio
|
17
18
|
from contextlib import AsyncExitStack
|
18
19
|
from datetime import timedelta
|
19
20
|
import functools
|
21
|
+
import hashlib
|
22
|
+
import json
|
20
23
|
import logging
|
21
24
|
import sys
|
22
25
|
from typing import Any
|
26
|
+
from typing import Dict
|
23
27
|
from typing import Optional
|
24
28
|
from typing import TextIO
|
25
29
|
from typing import Union
|
@@ -34,7 +38,6 @@ try:
|
|
34
38
|
from mcp.client.stdio import stdio_client
|
35
39
|
from mcp.client.streamable_http import streamablehttp_client
|
36
40
|
except ImportError as e:
|
37
|
-
import sys
|
38
41
|
|
39
42
|
if sys.version_info < (3, 10):
|
40
43
|
raise ImportError(
|
@@ -47,102 +50,106 @@ except ImportError as e:
|
|
47
50
|
logger = logging.getLogger('google_adk.' + __name__)
|
48
51
|
|
49
52
|
|
50
|
-
class
|
53
|
+
class StdioConnectionParams(BaseModel):
|
54
|
+
"""Parameters for the MCP Stdio connection.
|
55
|
+
|
56
|
+
Attributes:
|
57
|
+
server_params: Parameters for the MCP Stdio server.
|
58
|
+
timeout: Timeout in seconds for establishing the connection to the MCP
|
59
|
+
stdio server.
|
60
|
+
"""
|
61
|
+
|
62
|
+
server_params: StdioServerParameters
|
63
|
+
timeout: float = 5.0
|
64
|
+
|
65
|
+
|
66
|
+
class SseConnectionParams(BaseModel):
|
51
67
|
"""Parameters for the MCP SSE connection.
|
52
68
|
|
53
69
|
See MCP SSE Client documentation for more details.
|
54
70
|
https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/client/sse.py
|
71
|
+
|
72
|
+
Attributes:
|
73
|
+
url: URL for the MCP SSE server.
|
74
|
+
headers: Headers for the MCP SSE connection.
|
75
|
+
timeout: Timeout in seconds for establishing the connection to the MCP SSE
|
76
|
+
server.
|
77
|
+
sse_read_timeout: Timeout in seconds for reading data from the MCP SSE
|
78
|
+
server.
|
55
79
|
"""
|
56
80
|
|
57
81
|
url: str
|
58
82
|
headers: dict[str, Any] | None = None
|
59
|
-
timeout: float = 5
|
60
|
-
sse_read_timeout: float = 60 * 5
|
83
|
+
timeout: float = 5.0
|
84
|
+
sse_read_timeout: float = 60 * 5.0
|
61
85
|
|
62
86
|
|
63
|
-
class
|
87
|
+
class StreamableHTTPConnectionParams(BaseModel):
|
64
88
|
"""Parameters for the MCP Streamable HTTP connection.
|
65
89
|
|
66
90
|
See MCP Streamable HTTP Client documentation for more details.
|
67
91
|
https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/client/streamable_http.py
|
92
|
+
|
93
|
+
Attributes:
|
94
|
+
url: URL for the MCP Streamable HTTP server.
|
95
|
+
headers: Headers for the MCP Streamable HTTP connection.
|
96
|
+
timeout: Timeout in seconds for establishing the connection to the MCP
|
97
|
+
Streamable HTTP server.
|
98
|
+
sse_read_timeout: Timeout in seconds for reading data from the MCP
|
99
|
+
Streamable HTTP server.
|
100
|
+
terminate_on_close: Whether to terminate the MCP Streamable HTTP server
|
101
|
+
when the connection is closed.
|
68
102
|
"""
|
69
103
|
|
70
104
|
url: str
|
71
105
|
headers: dict[str, Any] | None = None
|
72
|
-
timeout: float = 5
|
73
|
-
sse_read_timeout: float = 60 * 5
|
106
|
+
timeout: float = 5.0
|
107
|
+
sse_read_timeout: float = 60 * 5.0
|
74
108
|
terminate_on_close: bool = True
|
75
109
|
|
76
110
|
|
77
|
-
def retry_on_closed_resource(
|
78
|
-
"""Decorator to automatically
|
79
|
-
|
80
|
-
When MCP session was closed, the decorator will automatically recreate the
|
81
|
-
session and retry the action with the same parameters.
|
82
|
-
|
83
|
-
Note:
|
84
|
-
1. async_reinit_func_name is the name of the class member function that
|
85
|
-
reinitializes the MCP session.
|
86
|
-
2. Both the decorated function and the async_reinit_func_name must be async
|
87
|
-
functions.
|
111
|
+
def retry_on_closed_resource(func):
|
112
|
+
"""Decorator to automatically retry action when MCP session is closed.
|
88
113
|
|
89
|
-
|
90
|
-
|
91
|
-
|
92
|
-
async def create_session(self):
|
93
|
-
self.session = ...
|
94
|
-
|
95
|
-
@retry_on_closed_resource('create_session')
|
96
|
-
async def use_session(self):
|
97
|
-
await self.session.call_tool()
|
114
|
+
When MCP session was closed, the decorator will automatically retry the
|
115
|
+
action once. The create_session method will handle creating a new session
|
116
|
+
if the old one was disconnected.
|
98
117
|
|
99
118
|
Args:
|
100
|
-
|
119
|
+
func: The function to decorate.
|
101
120
|
|
102
121
|
Returns:
|
103
122
|
The decorated function.
|
104
123
|
"""
|
105
124
|
|
106
|
-
|
107
|
-
|
108
|
-
|
109
|
-
|
110
|
-
|
111
|
-
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
await async_init_fn()
|
118
|
-
else:
|
119
|
-
raise ValueError(
|
120
|
-
f'Function {async_reinit_func_name} does not exist in decorated'
|
121
|
-
' class. Please check the function name in'
|
122
|
-
' retry_on_closed_resource decorator.'
|
123
|
-
) from close_err
|
124
|
-
except Exception as reinit_err:
|
125
|
-
raise RuntimeError(
|
126
|
-
f'Error reinitializing: {reinit_err}'
|
127
|
-
) from reinit_err
|
128
|
-
return await func(self, *args, **kwargs)
|
129
|
-
|
130
|
-
return wrapper
|
131
|
-
|
132
|
-
return decorator
|
125
|
+
@functools.wraps(func) # Preserves original function metadata
|
126
|
+
async def wrapper(self, *args, **kwargs):
|
127
|
+
try:
|
128
|
+
return await func(self, *args, **kwargs)
|
129
|
+
except anyio.ClosedResourceError:
|
130
|
+
# Simply retry the function - create_session will handle
|
131
|
+
# detecting and replacing disconnected sessions
|
132
|
+
logger.info('Retrying %s due to closed resource', func.__name__)
|
133
|
+
return await func(self, *args, **kwargs)
|
134
|
+
|
135
|
+
return wrapper
|
133
136
|
|
134
137
|
|
135
138
|
class MCPSessionManager:
|
136
139
|
"""Manages MCP client sessions.
|
137
140
|
|
138
141
|
This class provides methods for creating and initializing MCP client sessions,
|
139
|
-
handling different connection parameters (Stdio and SSE)
|
142
|
+
handling different connection parameters (Stdio, SSE and Streamable HTTP) and supporting
|
143
|
+
session pooling based on authentication headers.
|
140
144
|
"""
|
141
145
|
|
142
146
|
def __init__(
|
143
147
|
self,
|
144
148
|
connection_params: Union[
|
145
|
-
StdioServerParameters,
|
149
|
+
StdioServerParameters,
|
150
|
+
StdioConnectionParams,
|
151
|
+
SseConnectionParams,
|
152
|
+
StreamableHTTPConnectionParams,
|
146
153
|
],
|
147
154
|
errlog: TextIO = sys.stderr,
|
148
155
|
):
|
@@ -155,105 +162,220 @@ class MCPSessionManager:
|
|
155
162
|
errlog: (Optional) TextIO stream for error logging. Use only for
|
156
163
|
initializing a local stdio MCP session.
|
157
164
|
"""
|
158
|
-
|
165
|
+
if isinstance(connection_params, StdioServerParameters):
|
166
|
+
# So far timeout is not configurable. Given MCP is still evolving, we
|
167
|
+
# would expect stdio_client to evolve to accept timeout parameter like
|
168
|
+
# other clients.
|
169
|
+
logger.warning(
|
170
|
+
'StdioServerParameters is not recommended. Please use'
|
171
|
+
' StdioConnectionParams.'
|
172
|
+
)
|
173
|
+
self._connection_params = StdioConnectionParams(
|
174
|
+
server_params=connection_params,
|
175
|
+
timeout=5,
|
176
|
+
)
|
177
|
+
else:
|
178
|
+
self._connection_params = connection_params
|
159
179
|
self._errlog = errlog
|
160
|
-
# Each session manager maintains its own exit stack for proper cleanup
|
161
|
-
self._exit_stack: Optional[AsyncExitStack] = None
|
162
|
-
self._session: Optional[ClientSession] = None
|
163
180
|
|
164
|
-
|
181
|
+
# Session pool: maps session keys to (session, exit_stack) tuples
|
182
|
+
self._sessions: Dict[str, tuple[ClientSession, AsyncExitStack]] = {}
|
183
|
+
|
184
|
+
# Lock to prevent race conditions in session creation
|
185
|
+
self._session_lock = asyncio.Lock()
|
186
|
+
|
187
|
+
def _generate_session_key(
|
188
|
+
self, merged_headers: Optional[Dict[str, str]] = None
|
189
|
+
) -> str:
|
190
|
+
"""Generates a session key based on connection params and merged headers.
|
191
|
+
|
192
|
+
For StdioConnectionParams, returns a constant key since headers are not
|
193
|
+
supported. For SSE and StreamableHTTP connections, generates a key based
|
194
|
+
on the provided merged headers.
|
195
|
+
|
196
|
+
Args:
|
197
|
+
merged_headers: Already merged headers (base + additional).
|
198
|
+
|
199
|
+
Returns:
|
200
|
+
A unique session key string.
|
201
|
+
"""
|
202
|
+
if isinstance(self._connection_params, StdioConnectionParams):
|
203
|
+
# For stdio connections, headers are not supported, so use constant key
|
204
|
+
return 'stdio_session'
|
205
|
+
|
206
|
+
# For SSE and StreamableHTTP connections, use merged headers
|
207
|
+
if merged_headers:
|
208
|
+
headers_json = json.dumps(merged_headers, sort_keys=True)
|
209
|
+
headers_hash = hashlib.md5(headers_json.encode()).hexdigest()
|
210
|
+
return f'session_{headers_hash}'
|
211
|
+
else:
|
212
|
+
return 'session_no_headers'
|
213
|
+
|
214
|
+
def _merge_headers(
|
215
|
+
self, additional_headers: Optional[Dict[str, str]] = None
|
216
|
+
) -> Optional[Dict[str, str]]:
|
217
|
+
"""Merges base connection headers with additional headers.
|
218
|
+
|
219
|
+
Args:
|
220
|
+
additional_headers: Optional headers to merge with connection headers.
|
221
|
+
|
222
|
+
Returns:
|
223
|
+
Merged headers dictionary, or None if no headers are provided.
|
224
|
+
"""
|
225
|
+
if isinstance(self._connection_params, StdioConnectionParams) or isinstance(
|
226
|
+
self._connection_params, StdioServerParameters
|
227
|
+
):
|
228
|
+
# Stdio connections don't support headers
|
229
|
+
return None
|
230
|
+
|
231
|
+
base_headers = {}
|
232
|
+
if (
|
233
|
+
hasattr(self._connection_params, 'headers')
|
234
|
+
and self._connection_params.headers
|
235
|
+
):
|
236
|
+
base_headers = self._connection_params.headers.copy()
|
237
|
+
|
238
|
+
if additional_headers:
|
239
|
+
base_headers.update(additional_headers)
|
240
|
+
|
241
|
+
return base_headers
|
242
|
+
|
243
|
+
def _is_session_disconnected(self, session: ClientSession) -> bool:
|
244
|
+
"""Checks if a session is disconnected or closed.
|
245
|
+
|
246
|
+
Args:
|
247
|
+
session: The ClientSession to check.
|
248
|
+
|
249
|
+
Returns:
|
250
|
+
True if the session is disconnected, False otherwise.
|
251
|
+
"""
|
252
|
+
return session._read_stream._closed or session._write_stream._closed
|
253
|
+
|
254
|
+
async def create_session(
|
255
|
+
self, headers: Optional[Dict[str, str]] = None
|
256
|
+
) -> ClientSession:
|
165
257
|
"""Creates and initializes an MCP client session.
|
166
258
|
|
259
|
+
This method will check if an existing session for the given headers
|
260
|
+
is still connected. If it's disconnected, it will be cleaned up and
|
261
|
+
a new session will be created.
|
262
|
+
|
263
|
+
Args:
|
264
|
+
headers: Optional headers to include in the session. These will be
|
265
|
+
merged with any existing connection headers. Only applicable
|
266
|
+
for SSE and StreamableHTTP connections.
|
267
|
+
|
167
268
|
Returns:
|
168
269
|
ClientSession: The initialized MCP client session.
|
169
270
|
"""
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
#
|
174
|
-
|
271
|
+
# Merge headers once at the beginning
|
272
|
+
merged_headers = self._merge_headers(headers)
|
273
|
+
|
274
|
+
# Generate session key using merged headers
|
275
|
+
session_key = self._generate_session_key(merged_headers)
|
276
|
+
|
277
|
+
# Use async lock to prevent race conditions
|
278
|
+
async with self._session_lock:
|
279
|
+
# Check if we have an existing session
|
280
|
+
if session_key in self._sessions:
|
281
|
+
session, exit_stack = self._sessions[session_key]
|
282
|
+
|
283
|
+
# Check if the existing session is still connected
|
284
|
+
if not self._is_session_disconnected(session):
|
285
|
+
# Session is still good, return it
|
286
|
+
return session
|
287
|
+
else:
|
288
|
+
# Session is disconnected, clean it up
|
289
|
+
logger.info('Cleaning up disconnected session: %s', session_key)
|
290
|
+
try:
|
291
|
+
await exit_stack.aclose()
|
292
|
+
except Exception as e:
|
293
|
+
logger.warning('Error during disconnected session cleanup: %s', e)
|
294
|
+
finally:
|
295
|
+
del self._sessions[session_key]
|
296
|
+
|
297
|
+
# Create a new session (either first time or replacing disconnected one)
|
298
|
+
exit_stack = AsyncExitStack()
|
175
299
|
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
|
189
|
-
|
190
|
-
|
191
|
-
|
192
|
-
|
193
|
-
|
194
|
-
|
195
|
-
|
196
|
-
|
197
|
-
|
198
|
-
|
199
|
-
|
200
|
-
|
201
|
-
|
202
|
-
|
203
|
-
|
204
|
-
|
205
|
-
|
206
|
-
|
207
|
-
|
208
|
-
|
209
|
-
|
210
|
-
|
211
|
-
|
212
|
-
|
213
|
-
|
214
|
-
|
215
|
-
|
216
|
-
|
217
|
-
|
218
|
-
|
219
|
-
|
220
|
-
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
227
|
-
|
228
|
-
|
229
|
-
|
230
|
-
|
231
|
-
|
232
|
-
session
|
233
|
-
|
234
|
-
|
235
|
-
|
236
|
-
|
237
|
-
self._session = session
|
238
|
-
return session
|
239
|
-
|
240
|
-
except Exception:
|
241
|
-
# If session creation fails, clean up the exit stack
|
242
|
-
if self._exit_stack:
|
243
|
-
await self._exit_stack.aclose()
|
244
|
-
self._exit_stack = None
|
245
|
-
raise
|
300
|
+
try:
|
301
|
+
if isinstance(self._connection_params, StdioConnectionParams):
|
302
|
+
client = stdio_client(
|
303
|
+
server=self._connection_params.server_params,
|
304
|
+
errlog=self._errlog,
|
305
|
+
)
|
306
|
+
elif isinstance(self._connection_params, SseConnectionParams):
|
307
|
+
client = sse_client(
|
308
|
+
url=self._connection_params.url,
|
309
|
+
headers=merged_headers,
|
310
|
+
timeout=self._connection_params.timeout,
|
311
|
+
sse_read_timeout=self._connection_params.sse_read_timeout,
|
312
|
+
)
|
313
|
+
elif isinstance(
|
314
|
+
self._connection_params, StreamableHTTPConnectionParams
|
315
|
+
):
|
316
|
+
client = streamablehttp_client(
|
317
|
+
url=self._connection_params.url,
|
318
|
+
headers=merged_headers,
|
319
|
+
timeout=timedelta(seconds=self._connection_params.timeout),
|
320
|
+
sse_read_timeout=timedelta(
|
321
|
+
seconds=self._connection_params.sse_read_timeout
|
322
|
+
),
|
323
|
+
terminate_on_close=self._connection_params.terminate_on_close,
|
324
|
+
)
|
325
|
+
else:
|
326
|
+
raise ValueError(
|
327
|
+
'Unable to initialize connection. Connection should be'
|
328
|
+
' StdioServerParameters or SseServerParams, but got'
|
329
|
+
f' {self._connection_params}'
|
330
|
+
)
|
331
|
+
|
332
|
+
transports = await exit_stack.enter_async_context(client)
|
333
|
+
# The streamable http client returns a GetSessionCallback in addition to the read/write MemoryObjectStreams
|
334
|
+
# needed to build the ClientSession, so we limit them to the first two values to be compatible with all clients.
|
335
|
+
if isinstance(self._connection_params, StdioConnectionParams):
|
336
|
+
session = await exit_stack.enter_async_context(
|
337
|
+
ClientSession(
|
338
|
+
*transports[:2],
|
339
|
+
read_timeout_seconds=timedelta(
|
340
|
+
seconds=self._connection_params.timeout
|
341
|
+
),
|
342
|
+
)
|
343
|
+
)
|
344
|
+
else:
|
345
|
+
session = await exit_stack.enter_async_context(
|
346
|
+
ClientSession(*transports[:2])
|
347
|
+
)
|
348
|
+
await session.initialize()
|
349
|
+
|
350
|
+
# Store session and exit stack in the pool
|
351
|
+
self._sessions[session_key] = (session, exit_stack)
|
352
|
+
logger.debug('Created new session: %s', session_key)
|
353
|
+
return session
|
354
|
+
|
355
|
+
except Exception:
|
356
|
+
# If session creation fails, clean up the exit stack
|
357
|
+
if exit_stack:
|
358
|
+
await exit_stack.aclose()
|
359
|
+
raise
|
246
360
|
|
247
361
|
async def close(self):
|
248
|
-
"""Closes
|
249
|
-
|
250
|
-
|
251
|
-
|
252
|
-
|
253
|
-
|
254
|
-
|
255
|
-
|
256
|
-
|
257
|
-
|
258
|
-
|
259
|
-
|
362
|
+
"""Closes all sessions and cleans up resources."""
|
363
|
+
async with self._session_lock:
|
364
|
+
for session_key in list(self._sessions.keys()):
|
365
|
+
_, exit_stack = self._sessions[session_key]
|
366
|
+
try:
|
367
|
+
await exit_stack.aclose()
|
368
|
+
except Exception as e:
|
369
|
+
# Log the error but don't re-raise to avoid blocking shutdown
|
370
|
+
print(
|
371
|
+
'Warning: Error during MCP session cleanup for'
|
372
|
+
f' {session_key}: {e}',
|
373
|
+
file=self._errlog,
|
374
|
+
)
|
375
|
+
finally:
|
376
|
+
del self._sessions[session_key]
|
377
|
+
|
378
|
+
|
379
|
+
SseServerParams = SseConnectionParams
|
380
|
+
|
381
|
+
StreamableHTTPServerParams = StreamableHTTPConnectionParams
|
@@ -14,10 +14,13 @@
|
|
14
14
|
|
15
15
|
from __future__ import annotations
|
16
16
|
|
17
|
+
import base64
|
18
|
+
import json
|
17
19
|
import logging
|
18
20
|
from typing import Optional
|
19
21
|
|
20
22
|
from google.genai.types import FunctionDeclaration
|
23
|
+
from google.oauth2.credentials import Credentials
|
21
24
|
from typing_extensions import override
|
22
25
|
|
23
26
|
from .._gemini_schema_util import _to_gemini_schema
|
@@ -42,14 +45,16 @@ except ImportError as e:
|
|
42
45
|
|
43
46
|
from ...auth.auth_credential import AuthCredential
|
44
47
|
from ...auth.auth_schemes import AuthScheme
|
45
|
-
from
|
48
|
+
from ...auth.auth_tool import AuthConfig
|
49
|
+
from ..base_authenticated_tool import BaseAuthenticatedTool
|
50
|
+
# import
|
46
51
|
from ..tool_context import ToolContext
|
47
52
|
|
48
53
|
logger = logging.getLogger("google_adk." + __name__)
|
49
54
|
|
50
55
|
|
51
|
-
class MCPTool(
|
52
|
-
"""Turns
|
56
|
+
class MCPTool(BaseAuthenticatedTool):
|
57
|
+
"""Turns an MCP Tool into an ADK Tool.
|
53
58
|
|
54
59
|
Internally, the tool initializes from a MCP Tool, and uses the MCP Session to
|
55
60
|
call the tool.
|
@@ -63,9 +68,9 @@ class MCPTool(BaseTool):
|
|
63
68
|
auth_scheme: Optional[AuthScheme] = None,
|
64
69
|
auth_credential: Optional[AuthCredential] = None,
|
65
70
|
):
|
66
|
-
"""Initializes
|
71
|
+
"""Initializes an MCPTool.
|
67
72
|
|
68
|
-
This tool wraps
|
73
|
+
This tool wraps an MCP Tool interface and uses a session manager to
|
69
74
|
communicate with the MCP server.
|
70
75
|
|
71
76
|
Args:
|
@@ -77,19 +82,17 @@ class MCPTool(BaseTool):
|
|
77
82
|
Raises:
|
78
83
|
ValueError: If mcp_tool or mcp_session_manager is None.
|
79
84
|
"""
|
80
|
-
if mcp_tool is None:
|
81
|
-
raise ValueError("mcp_tool cannot be None")
|
82
|
-
if mcp_session_manager is None:
|
83
|
-
raise ValueError("mcp_session_manager cannot be None")
|
84
85
|
super().__init__(
|
85
86
|
name=mcp_tool.name,
|
86
87
|
description=mcp_tool.description if mcp_tool.description else "",
|
88
|
+
auth_config=AuthConfig(
|
89
|
+
auth_scheme=auth_scheme, raw_auth_credential=auth_credential
|
90
|
+
)
|
91
|
+
if auth_scheme
|
92
|
+
else None,
|
87
93
|
)
|
88
94
|
self._mcp_tool = mcp_tool
|
89
95
|
self._mcp_session_manager = mcp_session_manager
|
90
|
-
# TODO(cheliu): Support passing auth to MCP Server.
|
91
|
-
self._auth_scheme = auth_scheme
|
92
|
-
self._auth_credential = auth_credential
|
93
96
|
|
94
97
|
@override
|
95
98
|
def _get_declaration(self) -> FunctionDeclaration:
|
@@ -105,26 +108,80 @@ class MCPTool(BaseTool):
|
|
105
108
|
)
|
106
109
|
return function_decl
|
107
110
|
|
108
|
-
@retry_on_closed_resource
|
109
|
-
|
111
|
+
@retry_on_closed_resource
|
112
|
+
@override
|
113
|
+
async def _run_async_impl(
|
114
|
+
self, *, args, tool_context: ToolContext, credential: AuthCredential
|
115
|
+
):
|
110
116
|
"""Runs the tool asynchronously.
|
111
117
|
|
112
118
|
Args:
|
113
119
|
args: The arguments as a dict to pass to the tool.
|
114
|
-
tool_context: The tool context
|
120
|
+
tool_context: The tool context of the current invocation.
|
115
121
|
|
116
122
|
Returns:
|
117
123
|
Any: The response from the tool.
|
118
124
|
"""
|
125
|
+
# Extract headers from credential for session pooling
|
126
|
+
headers = await self._get_headers(tool_context, credential)
|
127
|
+
|
119
128
|
# Get the session from the session manager
|
120
|
-
session = await self._mcp_session_manager.create_session()
|
129
|
+
session = await self._mcp_session_manager.create_session(headers=headers)
|
121
130
|
|
122
|
-
# TODO(cheliu): Support passing tool context to MCP Server.
|
123
131
|
response = await session.call_tool(self.name, arguments=args)
|
124
132
|
return response
|
125
133
|
|
126
|
-
async def
|
127
|
-
|
128
|
-
|
129
|
-
|
130
|
-
|
134
|
+
async def _get_headers(
|
135
|
+
self, tool_context: ToolContext, credential: AuthCredential
|
136
|
+
) -> Optional[dict[str, str]]:
|
137
|
+
headers = None
|
138
|
+
if credential:
|
139
|
+
if credential.oauth2:
|
140
|
+
headers = {"Authorization": f"Bearer {credential.oauth2.access_token}"}
|
141
|
+
elif credential.google_oauth2_json:
|
142
|
+
google_credential = Credentials.from_authorized_user_info(
|
143
|
+
json.loads(credential.google_oauth2_json)
|
144
|
+
)
|
145
|
+
headers = {"Authorization": f"Bearer {google_credential.token}"}
|
146
|
+
elif credential.http:
|
147
|
+
# Handle HTTP authentication schemes
|
148
|
+
if (
|
149
|
+
credential.http.scheme.lower() == "bearer"
|
150
|
+
and credential.http.credentials.token
|
151
|
+
):
|
152
|
+
headers = {
|
153
|
+
"Authorization": f"Bearer {credential.http.credentials.token}"
|
154
|
+
}
|
155
|
+
elif credential.http.scheme.lower() == "basic":
|
156
|
+
# Handle basic auth
|
157
|
+
if (
|
158
|
+
credential.http.credentials.username
|
159
|
+
and credential.http.credentials.password
|
160
|
+
):
|
161
|
+
|
162
|
+
credentials = f"{credential.http.credentials.username}:{credential.http.credentials.password}"
|
163
|
+
encoded_credentials = base64.b64encode(
|
164
|
+
credentials.encode()
|
165
|
+
).decode()
|
166
|
+
headers = {"Authorization": f"Basic {encoded_credentials}"}
|
167
|
+
elif credential.http.credentials.token:
|
168
|
+
# Handle other HTTP schemes with token
|
169
|
+
headers = {
|
170
|
+
"Authorization": (
|
171
|
+
f"{credential.http.scheme} {credential.http.credentials.token}"
|
172
|
+
)
|
173
|
+
}
|
174
|
+
elif credential.api_key:
|
175
|
+
# For API keys, we'll add them as headers since MCP typically uses header-based auth
|
176
|
+
# The specific header name would depend on the API, using a common default
|
177
|
+
# TODO Allow user to specify the header name for API keys.
|
178
|
+
headers = {"X-API-Key": credential.api_key}
|
179
|
+
elif credential.service_account:
|
180
|
+
# Service accounts should be exchanged for access tokens before reaching this point
|
181
|
+
# If we reach here, we can try to use google_oauth2_json or log a warning
|
182
|
+
logger.warning(
|
183
|
+
"Service account credentials should be exchanged for access"
|
184
|
+
" tokens before MCP session creation"
|
185
|
+
)
|
186
|
+
|
187
|
+
return headers
|