google-adk 0.4.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/adk/agents/active_streaming_tool.py +1 -0
- google/adk/agents/base_agent.py +91 -47
- google/adk/agents/base_agent.py.orig +330 -0
- google/adk/agents/callback_context.py +4 -9
- google/adk/agents/invocation_context.py +1 -0
- google/adk/agents/langgraph_agent.py +1 -0
- google/adk/agents/live_request_queue.py +1 -0
- google/adk/agents/llm_agent.py +172 -35
- google/adk/agents/loop_agent.py +1 -1
- google/adk/agents/parallel_agent.py +7 -0
- google/adk/agents/readonly_context.py +7 -1
- google/adk/agents/run_config.py +5 -1
- google/adk/agents/sequential_agent.py +31 -0
- google/adk/agents/transcription_entry.py +5 -2
- google/adk/artifacts/base_artifact_service.py +5 -10
- google/adk/artifacts/gcs_artifact_service.py +9 -9
- google/adk/artifacts/in_memory_artifact_service.py +6 -6
- google/adk/auth/auth_credential.py +9 -5
- google/adk/auth/auth_preprocessor.py +7 -1
- google/adk/auth/auth_tool.py +3 -4
- google/adk/cli/agent_graph.py +5 -5
- google/adk/cli/browser/index.html +2 -2
- google/adk/cli/browser/{main-HWIBUY2R.js → main-QOEMUXM4.js} +58 -58
- google/adk/cli/cli.py +7 -7
- google/adk/cli/cli_deploy.py +7 -2
- google/adk/cli/cli_eval.py +181 -106
- google/adk/cli/cli_tools_click.py +147 -62
- google/adk/cli/fast_api.py +340 -158
- google/adk/cli/fast_api.py.orig +822 -0
- google/adk/cli/utils/common.py +23 -0
- google/adk/cli/utils/evals.py +83 -1
- google/adk/cli/utils/logs.py +13 -5
- google/adk/code_executors/__init__.py +3 -1
- google/adk/code_executors/built_in_code_executor.py +52 -0
- google/adk/evaluation/__init__.py +1 -1
- google/adk/evaluation/agent_evaluator.py +168 -128
- google/adk/evaluation/eval_case.py +102 -0
- google/adk/evaluation/eval_set.py +37 -0
- google/adk/evaluation/eval_sets_manager.py +42 -0
- google/adk/evaluation/evaluation_constants.py +1 -0
- google/adk/evaluation/evaluation_generator.py +89 -114
- google/adk/evaluation/evaluator.py +56 -0
- google/adk/evaluation/local_eval_sets_manager.py +264 -0
- google/adk/evaluation/response_evaluator.py +107 -3
- google/adk/evaluation/trajectory_evaluator.py +83 -2
- google/adk/events/event.py +7 -1
- google/adk/events/event_actions.py +7 -1
- google/adk/examples/example.py +1 -0
- google/adk/examples/example_util.py +3 -2
- google/adk/flows/__init__.py +0 -1
- google/adk/flows/llm_flows/_code_execution.py +19 -11
- google/adk/flows/llm_flows/audio_transcriber.py +4 -3
- google/adk/flows/llm_flows/base_llm_flow.py +86 -22
- google/adk/flows/llm_flows/basic.py +3 -0
- google/adk/flows/llm_flows/functions.py +10 -9
- google/adk/flows/llm_flows/instructions.py +28 -9
- google/adk/flows/llm_flows/single_flow.py +1 -1
- google/adk/memory/__init__.py +1 -1
- google/adk/memory/_utils.py +23 -0
- google/adk/memory/base_memory_service.py +25 -21
- google/adk/memory/base_memory_service.py.orig +76 -0
- google/adk/memory/in_memory_memory_service.py +59 -27
- google/adk/memory/memory_entry.py +37 -0
- google/adk/memory/vertex_ai_rag_memory_service.py +40 -17
- google/adk/models/anthropic_llm.py +36 -11
- google/adk/models/base_llm.py +45 -4
- google/adk/models/gemini_llm_connection.py +15 -2
- google/adk/models/google_llm.py +9 -44
- google/adk/models/google_llm.py.orig +305 -0
- google/adk/models/lite_llm.py +94 -38
- google/adk/models/llm_request.py +1 -1
- google/adk/models/llm_response.py +15 -3
- google/adk/models/registry.py +1 -1
- google/adk/runners.py +68 -44
- google/adk/sessions/__init__.py +1 -1
- google/adk/sessions/_session_util.py +14 -0
- google/adk/sessions/base_session_service.py +8 -32
- google/adk/sessions/database_session_service.py +58 -61
- google/adk/sessions/in_memory_session_service.py +108 -26
- google/adk/sessions/session.py +4 -0
- google/adk/sessions/vertex_ai_session_service.py +23 -45
- google/adk/telemetry.py +3 -0
- google/adk/tools/__init__.py +4 -7
- google/adk/tools/{built_in_code_execution_tool.py → _built_in_code_execution_tool.py} +11 -0
- google/adk/tools/_memory_entry_utils.py +30 -0
- google/adk/tools/agent_tool.py +16 -13
- google/adk/tools/apihub_tool/apihub_toolset.py +55 -74
- google/adk/tools/application_integration_tool/application_integration_toolset.py +107 -85
- google/adk/tools/application_integration_tool/clients/connections_client.py +29 -25
- google/adk/tools/application_integration_tool/clients/integration_client.py +6 -6
- google/adk/tools/application_integration_tool/integration_connector_tool.py +69 -26
- google/adk/tools/base_toolset.py +58 -0
- google/adk/tools/enterprise_search_tool.py +65 -0
- google/adk/tools/function_parameter_parse_util.py +2 -2
- google/adk/tools/google_api_tool/__init__.py +18 -70
- google/adk/tools/google_api_tool/google_api_tool.py +11 -5
- google/adk/tools/google_api_tool/google_api_toolset.py +126 -0
- google/adk/tools/google_api_tool/google_api_toolsets.py +102 -0
- google/adk/tools/google_api_tool/googleapi_to_openapi_converter.py +40 -42
- google/adk/tools/langchain_tool.py +96 -49
- google/adk/tools/load_artifacts_tool.py +4 -4
- google/adk/tools/load_memory_tool.py +16 -5
- google/adk/tools/mcp_tool/__init__.py +3 -2
- google/adk/tools/mcp_tool/conversion_utils.py +1 -1
- google/adk/tools/mcp_tool/mcp_session_manager.py +167 -16
- google/adk/tools/mcp_tool/mcp_session_manager.py.orig +322 -0
- google/adk/tools/mcp_tool/mcp_tool.py +12 -12
- google/adk/tools/mcp_tool/mcp_toolset.py +155 -195
- google/adk/tools/openapi_tool/common/common.py +2 -5
- google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py +32 -7
- google/adk/tools/openapi_tool/openapi_spec_parser/operation_parser.py +43 -33
- google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py +1 -1
- google/adk/tools/preload_memory_tool.py +27 -18
- google/adk/tools/retrieval/__init__.py +1 -1
- google/adk/tools/retrieval/vertex_ai_rag_retrieval.py +1 -1
- google/adk/tools/tool_context.py +4 -4
- google/adk/tools/toolbox_toolset.py +79 -0
- google/adk/tools/transfer_to_agent_tool.py +0 -1
- google/adk/version.py +1 -1
- {google_adk-0.4.0.dist-info → google_adk-1.0.0.dist-info}/METADATA +7 -5
- google_adk-1.0.0.dist-info/RECORD +195 -0
- google/adk/agents/remote_agent.py +0 -50
- google/adk/tools/google_api_tool/google_api_tool_set.py +0 -110
- google/adk/tools/google_api_tool/google_api_tool_sets.py +0 -112
- google/adk/tools/toolbox_tool.py +0 -46
- google_adk-0.4.0.dist-info/RECORD +0 -179
- {google_adk-0.4.0.dist-info → google_adk-1.0.0.dist-info}/WHEEL +0 -0
- {google_adk-0.4.0.dist-info → google_adk-1.0.0.dist-info}/entry_points.txt +0 -0
- {google_adk-0.4.0.dist-info → google_adk-1.0.0.dist-info}/licenses/LICENSE +0 -0
@@ -15,7 +15,8 @@
|
|
15
15
|
__all__ = []
|
16
16
|
|
17
17
|
try:
|
18
|
-
from .conversion_utils import adk_to_mcp_tool_type
|
18
|
+
from .conversion_utils import adk_to_mcp_tool_type
|
19
|
+
from .conversion_utils import gemini_to_json_schema
|
19
20
|
from .mcp_tool import MCPTool
|
20
21
|
from .mcp_toolset import MCPToolset
|
21
22
|
|
@@ -30,7 +31,7 @@ except ImportError as e:
|
|
30
31
|
import logging
|
31
32
|
import sys
|
32
33
|
|
33
|
-
logger = logging.getLogger(__name__)
|
34
|
+
logger = logging.getLogger('google_adk.' + __name__)
|
34
35
|
|
35
36
|
if sys.version_info < (3, 10):
|
36
37
|
logger.warning(
|
@@ -22,7 +22,7 @@ def adk_to_mcp_tool_type(tool: BaseTool) -> mcp_types.Tool:
|
|
22
22
|
"""Convert a Tool in ADK into MCP tool type.
|
23
23
|
|
24
24
|
This function transforms an ADK tool definition into its equivalent
|
25
|
-
representation in the MCP (Model
|
25
|
+
representation in the MCP (Model Context Protocol) system.
|
26
26
|
|
27
27
|
Args:
|
28
28
|
tool: The ADK tool to convert. It should be an instance of a class derived
|
@@ -1,12 +1,33 @@
|
|
1
|
+
# Copyright 2025 Google LLC
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import asyncio
|
16
|
+
from contextlib import asynccontextmanager
|
1
17
|
from contextlib import AsyncExitStack
|
2
18
|
import functools
|
19
|
+
import logging
|
3
20
|
import sys
|
4
|
-
from typing import Any
|
21
|
+
from typing import Any
|
22
|
+
from typing import Optional
|
23
|
+
from typing import TextIO
|
24
|
+
|
5
25
|
import anyio
|
6
26
|
from pydantic import BaseModel
|
7
27
|
|
8
28
|
try:
|
9
|
-
from mcp import ClientSession
|
29
|
+
from mcp import ClientSession
|
30
|
+
from mcp import StdioServerParameters
|
10
31
|
from mcp.client.sse import sse_client
|
11
32
|
from mcp.client.stdio import stdio_client
|
12
33
|
except ImportError as e:
|
@@ -20,6 +41,8 @@ except ImportError as e:
|
|
20
41
|
else:
|
21
42
|
raise e
|
22
43
|
|
44
|
+
logger = logging.getLogger('google_adk.' + __name__)
|
45
|
+
|
23
46
|
|
24
47
|
class SseServerParams(BaseModel):
|
25
48
|
"""Parameters for the MCP SSE connection.
|
@@ -94,6 +117,45 @@ def retry_on_closed_resource(async_reinit_func_name: str):
|
|
94
117
|
return decorator
|
95
118
|
|
96
119
|
|
120
|
+
@asynccontextmanager
|
121
|
+
async def tracked_stdio_client(server, errlog, process=None):
|
122
|
+
"""A wrapper around stdio_client that ensures proper process tracking and cleanup."""
|
123
|
+
our_process = process
|
124
|
+
|
125
|
+
# If no process was provided, create one
|
126
|
+
if our_process is None:
|
127
|
+
our_process = await asyncio.create_subprocess_exec(
|
128
|
+
server.command,
|
129
|
+
*server.args,
|
130
|
+
stdin=asyncio.subprocess.PIPE,
|
131
|
+
stdout=asyncio.subprocess.PIPE,
|
132
|
+
stderr=errlog,
|
133
|
+
)
|
134
|
+
|
135
|
+
# Use the original stdio_client, but ensure process cleanup
|
136
|
+
try:
|
137
|
+
async with stdio_client(server=server, errlog=errlog) as client:
|
138
|
+
yield client, our_process
|
139
|
+
finally:
|
140
|
+
# Ensure the process is properly terminated if it still exists
|
141
|
+
if our_process and our_process.returncode is None:
|
142
|
+
try:
|
143
|
+
logger.info(
|
144
|
+
f'Terminating process {our_process.pid} from tracked_stdio_client'
|
145
|
+
)
|
146
|
+
our_process.terminate()
|
147
|
+
try:
|
148
|
+
await asyncio.wait_for(our_process.wait(), timeout=3.0)
|
149
|
+
except asyncio.TimeoutError:
|
150
|
+
# Force kill if it doesn't terminate quickly
|
151
|
+
if our_process.returncode is None:
|
152
|
+
logger.warning(f'Forcing kill of process {our_process.pid}')
|
153
|
+
our_process.kill()
|
154
|
+
except ProcessLookupError:
|
155
|
+
# Process already gone, that's fine
|
156
|
+
logger.info(f'Process {our_process.pid} already terminated')
|
157
|
+
|
158
|
+
|
97
159
|
class MCPSessionManager:
|
98
160
|
"""Manages MCP client sessions.
|
99
161
|
|
@@ -106,7 +168,7 @@ class MCPSessionManager:
|
|
106
168
|
connection_params: StdioServerParameters | SseServerParams,
|
107
169
|
exit_stack: AsyncExitStack,
|
108
170
|
errlog: TextIO = sys.stderr,
|
109
|
-
)
|
171
|
+
):
|
110
172
|
"""Initializes the MCP session manager.
|
111
173
|
|
112
174
|
Example usage:
|
@@ -124,25 +186,39 @@ class MCPSessionManager:
|
|
124
186
|
errlog: (Optional) TextIO stream for error logging. Use only for
|
125
187
|
initializing a local stdio MCP session.
|
126
188
|
"""
|
127
|
-
|
128
|
-
self.
|
129
|
-
self.
|
130
|
-
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
189
|
+
|
190
|
+
self._connection_params = connection_params
|
191
|
+
self._exit_stack = exit_stack
|
192
|
+
self._errlog = errlog
|
193
|
+
self._process = None # Track the subprocess
|
194
|
+
self._active_processes = set() # Track all processes created
|
195
|
+
self._active_file_handles = set() # Track file handles
|
196
|
+
|
197
|
+
async def create_session(
|
198
|
+
self,
|
199
|
+
) -> tuple[ClientSession, Optional[asyncio.subprocess.Process]]:
|
200
|
+
"""Creates a new MCP session and tracks the associated process."""
|
201
|
+
session, process = await self._initialize_session(
|
202
|
+
connection_params=self._connection_params,
|
203
|
+
exit_stack=self._exit_stack,
|
204
|
+
errlog=self._errlog,
|
136
205
|
)
|
206
|
+
self._process = process # Store reference to process
|
207
|
+
|
208
|
+
# Track the process
|
209
|
+
if process:
|
210
|
+
self._active_processes.add(process)
|
211
|
+
|
212
|
+
return session, process
|
137
213
|
|
138
214
|
@classmethod
|
139
|
-
async def
|
215
|
+
async def _initialize_session(
|
140
216
|
cls,
|
141
217
|
*,
|
142
218
|
connection_params: StdioServerParameters | SseServerParams,
|
143
219
|
exit_stack: AsyncExitStack,
|
144
220
|
errlog: TextIO = sys.stderr,
|
145
|
-
) -> ClientSession:
|
221
|
+
) -> tuple[ClientSession, Optional[asyncio.subprocess.Process]]:
|
146
222
|
"""Initializes an MCP client session.
|
147
223
|
|
148
224
|
Args:
|
@@ -154,9 +230,17 @@ class MCPSessionManager:
|
|
154
230
|
Returns:
|
155
231
|
ClientSession: The initialized MCP client session.
|
156
232
|
"""
|
233
|
+
process = None
|
234
|
+
|
157
235
|
if isinstance(connection_params, StdioServerParameters):
|
158
|
-
|
236
|
+
# For stdio connections, we need to track the subprocess
|
237
|
+
client, process = await cls._create_stdio_client(
|
238
|
+
server=connection_params,
|
239
|
+
errlog=errlog,
|
240
|
+
exit_stack=exit_stack,
|
241
|
+
)
|
159
242
|
elif isinstance(connection_params, SseServerParams):
|
243
|
+
# For SSE connections, create the client without a subprocess
|
160
244
|
client = sse_client(
|
161
245
|
url=connection_params.url,
|
162
246
|
headers=connection_params.headers,
|
@@ -170,7 +254,74 @@ class MCPSessionManager:
|
|
170
254
|
f' {connection_params}'
|
171
255
|
)
|
172
256
|
|
257
|
+
# Create the session with the client
|
173
258
|
transports = await exit_stack.enter_async_context(client)
|
174
259
|
session = await exit_stack.enter_async_context(ClientSession(*transports))
|
175
260
|
await session.initialize()
|
176
|
-
|
261
|
+
|
262
|
+
return session, process
|
263
|
+
|
264
|
+
@staticmethod
|
265
|
+
async def _create_stdio_client(
|
266
|
+
server: StdioServerParameters,
|
267
|
+
errlog: TextIO,
|
268
|
+
exit_stack: AsyncExitStack,
|
269
|
+
) -> tuple[Any, asyncio.subprocess.Process]:
|
270
|
+
"""Create stdio client and return both the client and process.
|
271
|
+
|
272
|
+
This implementation adapts to how the MCP stdio_client is created.
|
273
|
+
The actual implementation may need to be adjusted based on the MCP library
|
274
|
+
structure.
|
275
|
+
"""
|
276
|
+
# Create the subprocess directly so we can track it
|
277
|
+
process = await asyncio.create_subprocess_exec(
|
278
|
+
server.command,
|
279
|
+
*server.args,
|
280
|
+
stdin=asyncio.subprocess.PIPE,
|
281
|
+
stdout=asyncio.subprocess.PIPE,
|
282
|
+
stderr=errlog,
|
283
|
+
)
|
284
|
+
|
285
|
+
# Create the stdio client using the MCP library
|
286
|
+
try:
|
287
|
+
# Method 1: Try using the existing process if stdio_client supports it
|
288
|
+
client = stdio_client(server=server, errlog=errlog, process=process)
|
289
|
+
except TypeError:
|
290
|
+
# Method 2: If the above doesn't work, let stdio_client create its own process
|
291
|
+
# and we'll need to terminate both processes later
|
292
|
+
logger.warning(
|
293
|
+
'Using stdio_client with its own process - may lead to duplicate'
|
294
|
+
' processes'
|
295
|
+
)
|
296
|
+
client = stdio_client(server=server, errlog=errlog)
|
297
|
+
|
298
|
+
return client, process
|
299
|
+
|
300
|
+
async def _emergency_cleanup(self):
|
301
|
+
"""Perform emergency cleanup of resources when normal cleanup fails."""
|
302
|
+
logger.info('Performing emergency cleanup of MCPSessionManager resources')
|
303
|
+
|
304
|
+
# Clean up any tracked processes
|
305
|
+
for proc in list(self._active_processes):
|
306
|
+
try:
|
307
|
+
if proc and proc.returncode is None:
|
308
|
+
logger.info(f'Emergency termination of process {proc.pid}')
|
309
|
+
proc.terminate()
|
310
|
+
try:
|
311
|
+
await asyncio.wait_for(proc.wait(), timeout=1.0)
|
312
|
+
except asyncio.TimeoutError:
|
313
|
+
logger.warning(f"Process {proc.pid} didn't terminate, forcing kill")
|
314
|
+
proc.kill()
|
315
|
+
self._active_processes.remove(proc)
|
316
|
+
except Exception as e:
|
317
|
+
logger.error(f'Error during process cleanup: {e}')
|
318
|
+
|
319
|
+
# Clean up any tracked file handles
|
320
|
+
for handle in list(self._active_file_handles):
|
321
|
+
try:
|
322
|
+
if not handle.closed:
|
323
|
+
logger.info('Closing file handle')
|
324
|
+
handle.close()
|
325
|
+
self._active_file_handles.remove(handle)
|
326
|
+
except Exception as e:
|
327
|
+
logger.error(f'Error closing file handle: {e}')
|
@@ -0,0 +1,322 @@
|
|
1
|
+
# Copyright 2025 Google LLC
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import asyncio
|
16
|
+
from contextlib import AsyncExitStack, asynccontextmanager
|
17
|
+
import functools
|
18
|
+
import logging
|
19
|
+
import sys
|
20
|
+
from typing import Any, Optional, TextIO
|
21
|
+
import anyio
|
22
|
+
from pydantic import BaseModel
|
23
|
+
|
24
|
+
try:
|
25
|
+
from mcp import ClientSession, StdioServerParameters
|
26
|
+
from mcp.client.sse import sse_client
|
27
|
+
from mcp.client.stdio import stdio_client
|
28
|
+
except ImportError as e:
|
29
|
+
import sys
|
30
|
+
|
31
|
+
if sys.version_info < (3, 10):
|
32
|
+
raise ImportError(
|
33
|
+
'MCP Tool requires Python 3.10 or above. Please upgrade your Python'
|
34
|
+
' version.'
|
35
|
+
) from e
|
36
|
+
else:
|
37
|
+
raise e
|
38
|
+
|
39
|
+
logger = logging.getLogger(__name__)
|
40
|
+
|
41
|
+
|
42
|
+
class SseServerParams(BaseModel):
|
43
|
+
"""Parameters for the MCP SSE connection.
|
44
|
+
|
45
|
+
See MCP SSE Client documentation for more details.
|
46
|
+
https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/client/sse.py
|
47
|
+
"""
|
48
|
+
|
49
|
+
url: str
|
50
|
+
headers: dict[str, Any] | None = None
|
51
|
+
timeout: float = 5
|
52
|
+
sse_read_timeout: float = 60 * 5
|
53
|
+
|
54
|
+
|
55
|
+
def retry_on_closed_resource(async_reinit_func_name: str):
|
56
|
+
"""Decorator to automatically reinitialize session and retry action.
|
57
|
+
|
58
|
+
When MCP session was closed, the decorator will automatically recreate the
|
59
|
+
session and retry the action with the same parameters.
|
60
|
+
|
61
|
+
Note:
|
62
|
+
1. async_reinit_func_name is the name of the class member function that
|
63
|
+
reinitializes the MCP session.
|
64
|
+
2. Both the decorated function and the async_reinit_func_name must be async
|
65
|
+
functions.
|
66
|
+
|
67
|
+
Usage:
|
68
|
+
class MCPTool:
|
69
|
+
...
|
70
|
+
async def create_session(self):
|
71
|
+
self.session = ...
|
72
|
+
|
73
|
+
@retry_on_closed_resource('create_session')
|
74
|
+
async def use_session(self):
|
75
|
+
await self.session.call_tool()
|
76
|
+
|
77
|
+
Args:
|
78
|
+
async_reinit_func_name: The name of the async function to recreate session.
|
79
|
+
|
80
|
+
Returns:
|
81
|
+
The decorated function.
|
82
|
+
"""
|
83
|
+
|
84
|
+
def decorator(func):
|
85
|
+
@functools.wraps(
|
86
|
+
func
|
87
|
+
) # Preserves original function metadata (name, docstring)
|
88
|
+
async def wrapper(self, *args, **kwargs):
|
89
|
+
try:
|
90
|
+
return await func(self, *args, **kwargs)
|
91
|
+
except anyio.ClosedResourceError:
|
92
|
+
try:
|
93
|
+
if hasattr(self, async_reinit_func_name) and callable(
|
94
|
+
getattr(self, async_reinit_func_name)
|
95
|
+
):
|
96
|
+
async_init_fn = getattr(self, async_reinit_func_name)
|
97
|
+
await async_init_fn()
|
98
|
+
else:
|
99
|
+
raise ValueError(
|
100
|
+
f'Function {async_reinit_func_name} does not exist in decorated'
|
101
|
+
' class. Please check the function name in'
|
102
|
+
' retry_on_closed_resource decorator.'
|
103
|
+
)
|
104
|
+
except Exception as reinit_err:
|
105
|
+
raise RuntimeError(
|
106
|
+
f'Error reinitializing: {reinit_err}'
|
107
|
+
) from reinit_err
|
108
|
+
return await func(self, *args, **kwargs)
|
109
|
+
|
110
|
+
return wrapper
|
111
|
+
|
112
|
+
return decorator
|
113
|
+
|
114
|
+
|
115
|
+
@asynccontextmanager
|
116
|
+
async def tracked_stdio_client(server, errlog, process=None):
|
117
|
+
"""A wrapper around stdio_client that ensures proper process tracking and cleanup."""
|
118
|
+
our_process = process
|
119
|
+
|
120
|
+
# If no process was provided, create one
|
121
|
+
if our_process is None:
|
122
|
+
our_process = await asyncio.create_subprocess_exec(
|
123
|
+
server.command,
|
124
|
+
*server.args,
|
125
|
+
stdin=asyncio.subprocess.PIPE,
|
126
|
+
stdout=asyncio.subprocess.PIPE,
|
127
|
+
stderr=errlog,
|
128
|
+
)
|
129
|
+
|
130
|
+
# Use the original stdio_client, but ensure process cleanup
|
131
|
+
try:
|
132
|
+
async with stdio_client(server=server, errlog=errlog) as client:
|
133
|
+
yield client, our_process
|
134
|
+
finally:
|
135
|
+
# Ensure the process is properly terminated if it still exists
|
136
|
+
if our_process and our_process.returncode is None:
|
137
|
+
try:
|
138
|
+
logger.info(
|
139
|
+
f'Terminating process {our_process.pid} from tracked_stdio_client'
|
140
|
+
)
|
141
|
+
our_process.terminate()
|
142
|
+
try:
|
143
|
+
await asyncio.wait_for(our_process.wait(), timeout=3.0)
|
144
|
+
except asyncio.TimeoutError:
|
145
|
+
# Force kill if it doesn't terminate quickly
|
146
|
+
if our_process.returncode is None:
|
147
|
+
logger.warning(f'Forcing kill of process {our_process.pid}')
|
148
|
+
our_process.kill()
|
149
|
+
except ProcessLookupError:
|
150
|
+
# Process already gone, that's fine
|
151
|
+
logger.info(f'Process {our_process.pid} already terminated')
|
152
|
+
|
153
|
+
|
154
|
+
class MCPSessionManager:
|
155
|
+
"""Manages MCP client sessions.
|
156
|
+
|
157
|
+
This class provides methods for creating and initializing MCP client sessions,
|
158
|
+
handling different connection parameters (Stdio and SSE).
|
159
|
+
"""
|
160
|
+
|
161
|
+
def __init__(
|
162
|
+
self,
|
163
|
+
connection_params: StdioServerParameters | SseServerParams,
|
164
|
+
exit_stack: AsyncExitStack,
|
165
|
+
errlog: TextIO = sys.stderr,
|
166
|
+
) -> ClientSession:
|
167
|
+
"""Initializes the MCP session manager.
|
168
|
+
|
169
|
+
Example usage:
|
170
|
+
```
|
171
|
+
mcp_session_manager = MCPSessionManager(
|
172
|
+
connection_params=connection_params,
|
173
|
+
exit_stack=exit_stack,
|
174
|
+
)
|
175
|
+
session = await mcp_session_manager.create_session()
|
176
|
+
```
|
177
|
+
|
178
|
+
Args:
|
179
|
+
connection_params: Parameters for the MCP connection (Stdio or SSE).
|
180
|
+
exit_stack: AsyncExitStack to manage the session lifecycle.
|
181
|
+
errlog: (Optional) TextIO stream for error logging. Use only for
|
182
|
+
initializing a local stdio MCP session.
|
183
|
+
"""
|
184
|
+
|
185
|
+
self._connection_params = connection_params
|
186
|
+
self._exit_stack = exit_stack
|
187
|
+
self._errlog = errlog
|
188
|
+
self._process = None # Track the subprocess
|
189
|
+
self._active_processes = set() # Track all processes created
|
190
|
+
self._active_file_handles = set() # Track file handles
|
191
|
+
|
192
|
+
async def create_session(
|
193
|
+
self,
|
194
|
+
) -> tuple[ClientSession, Optional[asyncio.subprocess.Process]]:
|
195
|
+
"""Creates a new MCP session and tracks the associated process."""
|
196
|
+
session, process = await self._initialize_session(
|
197
|
+
connection_params=self._connection_params,
|
198
|
+
exit_stack=self._exit_stack,
|
199
|
+
errlog=self._errlog,
|
200
|
+
)
|
201
|
+
self._process = process # Store reference to process
|
202
|
+
|
203
|
+
# Track the process
|
204
|
+
if process:
|
205
|
+
self._active_processes.add(process)
|
206
|
+
|
207
|
+
return session, process
|
208
|
+
|
209
|
+
@classmethod
|
210
|
+
async def _initialize_session(
|
211
|
+
cls,
|
212
|
+
*,
|
213
|
+
connection_params: StdioServerParameters | SseServerParams,
|
214
|
+
exit_stack: AsyncExitStack,
|
215
|
+
errlog: TextIO = sys.stderr,
|
216
|
+
) -> tuple[ClientSession, Optional[asyncio.subprocess.Process]]:
|
217
|
+
"""Initializes an MCP client session.
|
218
|
+
|
219
|
+
Args:
|
220
|
+
connection_params: Parameters for the MCP connection (Stdio or SSE).
|
221
|
+
exit_stack: AsyncExitStack to manage the session lifecycle.
|
222
|
+
errlog: (Optional) TextIO stream for error logging. Use only for
|
223
|
+
initializing a local stdio MCP session.
|
224
|
+
|
225
|
+
Returns:
|
226
|
+
ClientSession: The initialized MCP client session.
|
227
|
+
"""
|
228
|
+
process = None
|
229
|
+
|
230
|
+
if isinstance(connection_params, StdioServerParameters):
|
231
|
+
# For stdio connections, we need to track the subprocess
|
232
|
+
client, process = await cls._create_stdio_client(
|
233
|
+
server=connection_params,
|
234
|
+
errlog=errlog,
|
235
|
+
exit_stack=exit_stack,
|
236
|
+
)
|
237
|
+
elif isinstance(connection_params, SseServerParams):
|
238
|
+
# For SSE connections, create the client without a subprocess
|
239
|
+
client = sse_client(
|
240
|
+
url=connection_params.url,
|
241
|
+
headers=connection_params.headers,
|
242
|
+
timeout=connection_params.timeout,
|
243
|
+
sse_read_timeout=connection_params.sse_read_timeout,
|
244
|
+
)
|
245
|
+
else:
|
246
|
+
raise ValueError(
|
247
|
+
'Unable to initialize connection. Connection should be'
|
248
|
+
' StdioServerParameters or SseServerParams, but got'
|
249
|
+
f' {connection_params}'
|
250
|
+
)
|
251
|
+
|
252
|
+
# Create the session with the client
|
253
|
+
transports = await exit_stack.enter_async_context(client)
|
254
|
+
session = await exit_stack.enter_async_context(ClientSession(*transports))
|
255
|
+
await session.initialize()
|
256
|
+
|
257
|
+
return session, process
|
258
|
+
|
259
|
+
@staticmethod
|
260
|
+
async def _create_stdio_client(
|
261
|
+
server: StdioServerParameters,
|
262
|
+
errlog: TextIO,
|
263
|
+
exit_stack: AsyncExitStack,
|
264
|
+
) -> tuple[Any, asyncio.subprocess.Process]:
|
265
|
+
"""Create stdio client and return both the client and process.
|
266
|
+
|
267
|
+
This implementation adapts to how the MCP stdio_client is created.
|
268
|
+
The actual implementation may need to be adjusted based on the MCP library
|
269
|
+
structure.
|
270
|
+
"""
|
271
|
+
# Create the subprocess directly so we can track it
|
272
|
+
process = await asyncio.create_subprocess_exec(
|
273
|
+
server.command,
|
274
|
+
*server.args,
|
275
|
+
stdin=asyncio.subprocess.PIPE,
|
276
|
+
stdout=asyncio.subprocess.PIPE,
|
277
|
+
stderr=errlog,
|
278
|
+
)
|
279
|
+
|
280
|
+
# Create the stdio client using the MCP library
|
281
|
+
try:
|
282
|
+
# Method 1: Try using the existing process if stdio_client supports it
|
283
|
+
client = stdio_client(server=server, errlog=errlog, process=process)
|
284
|
+
except TypeError:
|
285
|
+
# Method 2: If the above doesn't work, let stdio_client create its own process
|
286
|
+
# and we'll need to terminate both processes later
|
287
|
+
logger.warning(
|
288
|
+
'Using stdio_client with its own process - may lead to duplicate'
|
289
|
+
' processes'
|
290
|
+
)
|
291
|
+
client = stdio_client(server=server, errlog=errlog)
|
292
|
+
|
293
|
+
return client, process
|
294
|
+
|
295
|
+
async def _emergency_cleanup(self):
|
296
|
+
"""Perform emergency cleanup of resources when normal cleanup fails."""
|
297
|
+
logger.info('Performing emergency cleanup of MCPSessionManager resources')
|
298
|
+
|
299
|
+
# Clean up any tracked processes
|
300
|
+
for proc in list(self._active_processes):
|
301
|
+
try:
|
302
|
+
if proc and proc.returncode is None:
|
303
|
+
logger.info(f'Emergency termination of process {proc.pid}')
|
304
|
+
proc.terminate()
|
305
|
+
try:
|
306
|
+
await asyncio.wait_for(proc.wait(), timeout=1.0)
|
307
|
+
except asyncio.TimeoutError:
|
308
|
+
logger.warning(f"Process {proc.pid} didn't terminate, forcing kill")
|
309
|
+
proc.kill()
|
310
|
+
self._active_processes.remove(proc)
|
311
|
+
except Exception as e:
|
312
|
+
logger.error(f'Error during process cleanup: {e}')
|
313
|
+
|
314
|
+
# Clean up any tracked file handles
|
315
|
+
for handle in list(self._active_file_handles):
|
316
|
+
try:
|
317
|
+
if not handle.closed:
|
318
|
+
logger.info('Closing file handle')
|
319
|
+
handle.close()
|
320
|
+
self._active_file_handles.remove(handle)
|
321
|
+
except Exception as e:
|
322
|
+
logger.error(f'Error closing file handle: {e}')
|
@@ -17,7 +17,8 @@ from typing import Optional
|
|
17
17
|
from google.genai.types import FunctionDeclaration
|
18
18
|
from typing_extensions import override
|
19
19
|
|
20
|
-
from .mcp_session_manager import MCPSessionManager
|
20
|
+
from .mcp_session_manager import MCPSessionManager
|
21
|
+
from .mcp_session_manager import retry_on_closed_resource
|
21
22
|
|
22
23
|
# Attempt to import MCP Tool from the MCP library, and hints user to upgrade
|
23
24
|
# their Python version to 3.10 if it fails.
|
@@ -36,9 +37,9 @@ except ImportError as e:
|
|
36
37
|
raise e
|
37
38
|
|
38
39
|
|
39
|
-
from ..base_tool import BaseTool
|
40
40
|
from ...auth.auth_credential import AuthCredential
|
41
41
|
from ...auth.auth_schemes import AuthScheme
|
42
|
+
from ..base_tool import BaseTool
|
42
43
|
from ..openapi_tool.openapi_spec_parser.rest_api_tool import to_gemini_schema
|
43
44
|
from ..tool_context import ToolContext
|
44
45
|
|
@@ -79,17 +80,16 @@ class MCPTool(BaseTool):
|
|
79
80
|
raise ValueError("mcp_tool cannot be None")
|
80
81
|
if mcp_session is None:
|
81
82
|
raise ValueError("mcp_session cannot be None")
|
82
|
-
|
83
|
-
self.
|
84
|
-
self.
|
85
|
-
self.
|
86
|
-
self.mcp_session_manager = mcp_session_manager
|
83
|
+
super().__init__(name=mcp_tool.name, description=mcp_tool.description or "")
|
84
|
+
self._mcp_tool = mcp_tool
|
85
|
+
self._mcp_session = mcp_session
|
86
|
+
self._mcp_session_manager = mcp_session_manager
|
87
87
|
# TODO(cheliu): Support passing auth to MCP Server.
|
88
|
-
self.
|
89
|
-
self.
|
88
|
+
self._auth_scheme = auth_scheme
|
89
|
+
self._auth_credential = auth_credential
|
90
90
|
|
91
91
|
async def _reinitialize_session(self):
|
92
|
-
self.
|
92
|
+
self._mcp_session = await self._mcp_session_manager.create_session()
|
93
93
|
|
94
94
|
@override
|
95
95
|
def _get_declaration(self) -> FunctionDeclaration:
|
@@ -98,7 +98,7 @@ class MCPTool(BaseTool):
|
|
98
98
|
Returns:
|
99
99
|
FunctionDeclaration: The Gemini function declaration for the tool.
|
100
100
|
"""
|
101
|
-
schema_dict = self.
|
101
|
+
schema_dict = self._mcp_tool.inputSchema
|
102
102
|
parameters = to_gemini_schema(schema_dict)
|
103
103
|
function_decl = FunctionDeclaration(
|
104
104
|
name=self.name, description=self.description, parameters=parameters
|
@@ -119,7 +119,7 @@ class MCPTool(BaseTool):
|
|
119
119
|
"""
|
120
120
|
# TODO(cheliu): Support passing tool context to MCP Server.
|
121
121
|
try:
|
122
|
-
response = await self.
|
122
|
+
response = await self._mcp_session.call_tool(self.name, arguments=args)
|
123
123
|
return response
|
124
124
|
except Exception as e:
|
125
125
|
print(e)
|