google-adk 1.0.0__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/adk/agents/callback_context.py +2 -1
- google/adk/agents/readonly_context.py +3 -1
- google/adk/auth/auth_credential.py +4 -1
- google/adk/cli/browser/index.html +4 -4
- google/adk/cli/browser/{main-QOEMUXM4.js → main-PKDNKWJE.js} +59 -59
- google/adk/cli/browser/polyfills-B6TNHZQ6.js +17 -0
- google/adk/cli/cli.py +3 -2
- google/adk/cli/cli_eval.py +6 -85
- google/adk/cli/cli_tools_click.py +39 -10
- google/adk/cli/fast_api.py +53 -184
- google/adk/cli/utils/agent_loader.py +137 -0
- google/adk/cli/utils/cleanup.py +40 -0
- google/adk/cli/utils/evals.py +2 -1
- google/adk/cli/utils/logs.py +2 -7
- google/adk/code_executors/code_execution_utils.py +2 -1
- google/adk/code_executors/container_code_executor.py +0 -1
- google/adk/code_executors/vertex_ai_code_executor.py +6 -8
- google/adk/evaluation/eval_case.py +3 -1
- google/adk/evaluation/eval_metrics.py +74 -0
- google/adk/evaluation/eval_result.py +86 -0
- google/adk/evaluation/eval_set.py +2 -0
- google/adk/evaluation/eval_set_results_manager.py +47 -0
- google/adk/evaluation/eval_sets_manager.py +2 -1
- google/adk/evaluation/evaluator.py +2 -0
- google/adk/evaluation/local_eval_set_results_manager.py +113 -0
- google/adk/evaluation/local_eval_sets_manager.py +4 -4
- google/adk/evaluation/response_evaluator.py +2 -1
- google/adk/evaluation/trajectory_evaluator.py +3 -2
- google/adk/examples/base_example_provider.py +1 -0
- google/adk/flows/llm_flows/base_llm_flow.py +4 -6
- google/adk/flows/llm_flows/contents.py +3 -1
- google/adk/flows/llm_flows/instructions.py +7 -77
- google/adk/flows/llm_flows/single_flow.py +1 -1
- google/adk/models/base_llm.py +2 -1
- google/adk/models/base_llm_connection.py +2 -0
- google/adk/models/google_llm.py +4 -1
- google/adk/models/lite_llm.py +3 -2
- google/adk/models/llm_response.py +2 -1
- google/adk/runners.py +36 -4
- google/adk/sessions/_session_util.py +2 -1
- google/adk/sessions/database_session_service.py +5 -8
- google/adk/sessions/vertex_ai_session_service.py +28 -13
- google/adk/telemetry.py +4 -2
- google/adk/tools/agent_tool.py +1 -1
- google/adk/tools/apihub_tool/apihub_toolset.py +1 -1
- google/adk/tools/apihub_tool/clients/apihub_client.py +10 -3
- google/adk/tools/apihub_tool/clients/secret_client.py +1 -0
- google/adk/tools/application_integration_tool/application_integration_toolset.py +6 -2
- google/adk/tools/application_integration_tool/clients/connections_client.py +8 -1
- google/adk/tools/application_integration_tool/clients/integration_client.py +3 -1
- google/adk/tools/application_integration_tool/integration_connector_tool.py +1 -1
- google/adk/tools/base_toolset.py +40 -2
- google/adk/tools/bigquery/__init__.py +38 -0
- google/adk/tools/bigquery/bigquery_credentials.py +217 -0
- google/adk/tools/bigquery/bigquery_tool.py +116 -0
- google/adk/tools/bigquery/bigquery_toolset.py +86 -0
- google/adk/tools/bigquery/client.py +33 -0
- google/adk/tools/bigquery/metadata_tool.py +249 -0
- google/adk/tools/bigquery/query_tool.py +76 -0
- google/adk/tools/function_parameter_parse_util.py +7 -0
- google/adk/tools/function_tool.py +33 -3
- google/adk/tools/get_user_choice_tool.py +1 -0
- google/adk/tools/google_api_tool/__init__.py +17 -11
- google/adk/tools/google_api_tool/google_api_tool.py +1 -1
- google/adk/tools/google_api_tool/google_api_toolset.py +0 -14
- google/adk/tools/google_api_tool/google_api_toolsets.py +8 -2
- google/adk/tools/google_search_tool.py +2 -2
- google/adk/tools/mcp_tool/conversion_utils.py +6 -2
- google/adk/tools/mcp_tool/mcp_session_manager.py +62 -188
- google/adk/tools/mcp_tool/mcp_tool.py +27 -24
- google/adk/tools/mcp_tool/mcp_toolset.py +76 -131
- google/adk/tools/openapi_tool/auth/credential_exchangers/base_credential_exchanger.py +1 -3
- google/adk/tools/openapi_tool/auth/credential_exchangers/service_account_exchanger.py +6 -7
- google/adk/tools/openapi_tool/common/common.py +5 -1
- google/adk/tools/openapi_tool/openapi_spec_parser/__init__.py +7 -2
- google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py +2 -7
- google/adk/tools/openapi_tool/openapi_spec_parser/operation_parser.py +5 -1
- google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py +11 -1
- google/adk/tools/toolbox_toolset.py +31 -3
- google/adk/utils/__init__.py +13 -0
- google/adk/utils/instructions_utils.py +131 -0
- google/adk/version.py +1 -1
- {google_adk-1.0.0.dist-info → google_adk-1.1.1.dist-info}/METADATA +12 -15
- {google_adk-1.0.0.dist-info → google_adk-1.1.1.dist-info}/RECORD +87 -78
- google/adk/agents/base_agent.py.orig +0 -330
- google/adk/cli/browser/polyfills-FFHMD2TL.js +0 -18
- google/adk/cli/fast_api.py.orig +0 -822
- google/adk/memory/base_memory_service.py.orig +0 -76
- google/adk/models/google_llm.py.orig +0 -305
- google/adk/tools/_built_in_code_execution_tool.py +0 -70
- google/adk/tools/mcp_tool/mcp_session_manager.py.orig +0 -322
- {google_adk-1.0.0.dist-info → google_adk-1.1.1.dist-info}/WHEEL +0 -0
- {google_adk-1.0.0.dist-info → google_adk-1.1.1.dist-info}/entry_points.txt +0 -0
- {google_adk-1.0.0.dist-info → google_adk-1.1.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,76 @@
|
|
1
|
+
# Copyright 2025 Google LLC
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from google.oauth2.credentials import Credentials
|
16
|
+
|
17
|
+
from ...tools.bigquery import client
|
18
|
+
|
19
|
+
MAX_DOWNLOADED_QUERY_RESULT_ROWS = 50
|
20
|
+
|
21
|
+
|
22
|
+
def execute_sql(project_id: str, query: str, credentials: Credentials) -> dict:
|
23
|
+
"""Run a BigQuery SQL query in the project and return the result.
|
24
|
+
|
25
|
+
Args:
|
26
|
+
project_id (str): The GCP project id in which the query should be
|
27
|
+
executed.
|
28
|
+
query (str): The BigQuery SQL query to be executed.
|
29
|
+
credentials (Credentials): The credentials to use for the request.
|
30
|
+
|
31
|
+
Returns:
|
32
|
+
dict: Dictionary representing the result of the query.
|
33
|
+
If the result contains the key "result_is_likely_truncated" with
|
34
|
+
value True, it means that there may be additional rows matching the
|
35
|
+
query not returned in the result.
|
36
|
+
|
37
|
+
Examples:
|
38
|
+
>>> execute_sql("bigframes-dev",
|
39
|
+
... "SELECT island, COUNT(*) AS population "
|
40
|
+
... "FROM bigquery-public-data.ml_datasets.penguins GROUP BY island")
|
41
|
+
{
|
42
|
+
"rows": [
|
43
|
+
{
|
44
|
+
"island": "Dream",
|
45
|
+
"population": 124
|
46
|
+
},
|
47
|
+
{
|
48
|
+
"island": "Biscoe",
|
49
|
+
"population": 168
|
50
|
+
},
|
51
|
+
{
|
52
|
+
"island": "Torgersen",
|
53
|
+
"population": 52
|
54
|
+
}
|
55
|
+
]
|
56
|
+
}
|
57
|
+
"""
|
58
|
+
|
59
|
+
try:
|
60
|
+
bq_client = client.get_bigquery_client(credentials=credentials)
|
61
|
+
row_iterator = bq_client.query_and_wait(
|
62
|
+
query, project=project_id, max_results=MAX_DOWNLOADED_QUERY_RESULT_ROWS
|
63
|
+
)
|
64
|
+
rows = [{key: val for key, val in row.items()} for row in row_iterator]
|
65
|
+
result = {"rows": rows}
|
66
|
+
if (
|
67
|
+
MAX_DOWNLOADED_QUERY_RESULT_ROWS is not None
|
68
|
+
and len(rows) == MAX_DOWNLOADED_QUERY_RESULT_ROWS
|
69
|
+
):
|
70
|
+
result["result_is_likely_truncated"] = True
|
71
|
+
return result
|
72
|
+
except Exception as ex:
|
73
|
+
return {
|
74
|
+
"status": "ERROR",
|
75
|
+
"error_details": str(ex),
|
76
|
+
}
|
@@ -289,6 +289,13 @@ def _parse_schema_from_parameter(
|
|
289
289
|
)
|
290
290
|
_raise_if_schema_unsupported(variant, schema)
|
291
291
|
return schema
|
292
|
+
if param.annotation is None:
|
293
|
+
# https://swagger.io/docs/specification/v3_0/data-models/data-types/#null
|
294
|
+
# null is not a valid type in schema, use object instead.
|
295
|
+
schema.type = types.Type.OBJECT
|
296
|
+
schema.nullable = True
|
297
|
+
_raise_if_schema_unsupported(variant, schema)
|
298
|
+
return schema
|
292
299
|
raise ValueError(
|
293
300
|
f'Failed to parse the parameter {param} of function {func_name} for'
|
294
301
|
' automatic function calling. Automatic function calling works best with'
|
@@ -33,8 +33,31 @@ class FunctionTool(BaseTool):
|
|
33
33
|
"""
|
34
34
|
|
35
35
|
def __init__(self, func: Callable[..., Any]):
|
36
|
-
|
36
|
+
"""Extract metadata from a callable object."""
|
37
|
+
name = ''
|
38
|
+
doc = ''
|
39
|
+
# Handle different types of callables
|
40
|
+
if hasattr(func, '__name__'):
|
41
|
+
# Regular functions, unbound methods, etc.
|
42
|
+
name = func.__name__
|
43
|
+
elif hasattr(func, '__class__'):
|
44
|
+
# Callable objects, bound methods, etc.
|
45
|
+
name = func.__class__.__name__
|
46
|
+
|
47
|
+
# Get documentation (prioritize direct __doc__ if available)
|
48
|
+
if hasattr(func, '__doc__') and func.__doc__:
|
49
|
+
doc = func.__doc__
|
50
|
+
elif (
|
51
|
+
hasattr(func, '__call__')
|
52
|
+
and hasattr(func.__call__, '__doc__')
|
53
|
+
and func.__call__.__doc__
|
54
|
+
):
|
55
|
+
# For callable objects, try to get docstring from __call__ method
|
56
|
+
doc = func.__call__.__doc__
|
57
|
+
|
58
|
+
super().__init__(name=name, description=doc)
|
37
59
|
self.func = func
|
60
|
+
self._ignore_params = ['tool_context', 'input_stream']
|
38
61
|
|
39
62
|
@override
|
40
63
|
def _get_declaration(self) -> Optional[types.FunctionDeclaration]:
|
@@ -43,7 +66,7 @@ class FunctionTool(BaseTool):
|
|
43
66
|
func=self.func,
|
44
67
|
# The model doesn't understand the function context.
|
45
68
|
# input_stream is for streaming tool
|
46
|
-
ignore_params=
|
69
|
+
ignore_params=self._ignore_params,
|
47
70
|
variant=self._api_variant,
|
48
71
|
)
|
49
72
|
)
|
@@ -76,7 +99,14 @@ class FunctionTool(BaseTool):
|
|
76
99
|
You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
|
77
100
|
return {'error': error_str}
|
78
101
|
|
79
|
-
|
102
|
+
# Functions are callable objects, but not all callable objects are functions
|
103
|
+
# checking coroutine function is not enough. We also need to check whether
|
104
|
+
# Callable's __call__ function is a coroutine funciton
|
105
|
+
if (
|
106
|
+
inspect.iscoroutinefunction(self.func)
|
107
|
+
or hasattr(self.func, '__call__')
|
108
|
+
and inspect.iscoroutinefunction(self.func.__call__)
|
109
|
+
):
|
80
110
|
return await self.func(**args_to_call) or {}
|
81
111
|
else:
|
82
112
|
return self.func(**args_to_call) or {}
|
@@ -11,18 +11,12 @@
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
|
-
__all__ = [
|
15
|
-
'BigQueryToolset',
|
16
|
-
'CalendarToolset',
|
17
|
-
'GmailToolset',
|
18
|
-
'YoutubeToolset',
|
19
|
-
'SlidesToolset',
|
20
|
-
'SheetsToolset',
|
21
|
-
'DocsToolset',
|
22
|
-
'GoogleApiToolset',
|
23
|
-
'GoogleApiTool',
|
24
|
-
]
|
25
14
|
|
15
|
+
"""Auto-generated tools and toolsets for Google APIs.
|
16
|
+
|
17
|
+
These tools and toolsets are auto-generated based on the API specifications
|
18
|
+
provided by the Google API Discovery API.
|
19
|
+
"""
|
26
20
|
|
27
21
|
from .google_api_tool import GoogleApiTool
|
28
22
|
from .google_api_toolset import GoogleApiToolset
|
@@ -33,3 +27,15 @@ from .google_api_toolsets import GmailToolset
|
|
33
27
|
from .google_api_toolsets import SheetsToolset
|
34
28
|
from .google_api_toolsets import SlidesToolset
|
35
29
|
from .google_api_toolsets import YoutubeToolset
|
30
|
+
|
31
|
+
__all__ = [
|
32
|
+
'BigQueryToolset',
|
33
|
+
'CalendarToolset',
|
34
|
+
'GmailToolset',
|
35
|
+
'YoutubeToolset',
|
36
|
+
'SlidesToolset',
|
37
|
+
'SheetsToolset',
|
38
|
+
'DocsToolset',
|
39
|
+
'GoogleApiToolset',
|
40
|
+
'GoogleApiTool',
|
41
|
+
]
|
@@ -19,10 +19,10 @@ from typing import Optional
|
|
19
19
|
from google.genai.types import FunctionDeclaration
|
20
20
|
from typing_extensions import override
|
21
21
|
|
22
|
+
from .. import BaseTool
|
22
23
|
from ...auth import AuthCredential
|
23
24
|
from ...auth import AuthCredentialTypes
|
24
25
|
from ...auth import OAuth2Auth
|
25
|
-
from .. import BaseTool
|
26
26
|
from ..openapi_tool import RestApiTool
|
27
27
|
from ..tool_context import ToolContext
|
28
28
|
|
@@ -56,20 +56,6 @@ class GoogleApiToolset(BaseToolset):
|
|
56
56
|
self._openapi_toolset = self._load_toolset_with_oidc_auth()
|
57
57
|
self.tool_filter = tool_filter
|
58
58
|
|
59
|
-
def _is_tool_selected(
|
60
|
-
self, tool: GoogleApiTool, readonly_context: ReadonlyContext
|
61
|
-
) -> bool:
|
62
|
-
if not self.tool_filter:
|
63
|
-
return True
|
64
|
-
|
65
|
-
if isinstance(self.tool_filter, ToolPredicate):
|
66
|
-
return self.tool_filter(tool, readonly_context)
|
67
|
-
|
68
|
-
if isinstance(self.tool_filter, list):
|
69
|
-
return tool.name in self.tool_filter
|
70
|
-
|
71
|
-
return False
|
72
|
-
|
73
59
|
@override
|
74
60
|
async def get_tools(
|
75
61
|
self, readonly_context: Optional[ReadonlyContext] = None
|
@@ -18,14 +18,14 @@ from typing import List
|
|
18
18
|
from typing import Optional
|
19
19
|
from typing import Union
|
20
20
|
|
21
|
-
from
|
22
|
-
|
21
|
+
from ..base_toolset import ToolPredicate
|
23
22
|
from .google_api_toolset import GoogleApiToolset
|
24
23
|
|
25
24
|
logger = logging.getLogger("google_adk." + __name__)
|
26
25
|
|
27
26
|
|
28
27
|
class BigQueryToolset(GoogleApiToolset):
|
28
|
+
"""Auto-generated Bigquery toolset based on Google BigQuery API v2 spec exposed by Google API discovery API"""
|
29
29
|
|
30
30
|
def __init__(
|
31
31
|
self,
|
@@ -37,6 +37,7 @@ class BigQueryToolset(GoogleApiToolset):
|
|
37
37
|
|
38
38
|
|
39
39
|
class CalendarToolset(GoogleApiToolset):
|
40
|
+
"""Auto-generated Calendar toolset based on Google Calendar API v3 spec exposed by Google API discovery API"""
|
40
41
|
|
41
42
|
def __init__(
|
42
43
|
self,
|
@@ -48,6 +49,7 @@ class CalendarToolset(GoogleApiToolset):
|
|
48
49
|
|
49
50
|
|
50
51
|
class GmailToolset(GoogleApiToolset):
|
52
|
+
"""Auto-generated Gmail toolset based on Google Gmail API v1 spec exposed by Google API discovery API"""
|
51
53
|
|
52
54
|
def __init__(
|
53
55
|
self,
|
@@ -59,6 +61,7 @@ class GmailToolset(GoogleApiToolset):
|
|
59
61
|
|
60
62
|
|
61
63
|
class YoutubeToolset(GoogleApiToolset):
|
64
|
+
"""Auto-generated Youtube toolset based on Youtube API v3 spec exposed by Google API discovery API"""
|
62
65
|
|
63
66
|
def __init__(
|
64
67
|
self,
|
@@ -70,6 +73,7 @@ class YoutubeToolset(GoogleApiToolset):
|
|
70
73
|
|
71
74
|
|
72
75
|
class SlidesToolset(GoogleApiToolset):
|
76
|
+
"""Auto-generated Slides toolset based on Google Slides API v1 spec exposed by Google API discovery API"""
|
73
77
|
|
74
78
|
def __init__(
|
75
79
|
self,
|
@@ -81,6 +85,7 @@ class SlidesToolset(GoogleApiToolset):
|
|
81
85
|
|
82
86
|
|
83
87
|
class SheetsToolset(GoogleApiToolset):
|
88
|
+
"""Auto-generated Sheets toolset based on Google Sheets API v4 spec exposed by Google API discovery API"""
|
84
89
|
|
85
90
|
def __init__(
|
86
91
|
self,
|
@@ -92,6 +97,7 @@ class SheetsToolset(GoogleApiToolset):
|
|
92
97
|
|
93
98
|
|
94
99
|
class DocsToolset(GoogleApiToolset):
|
100
|
+
"""Auto-generated Docs toolset based on Google Docs API v1 spec exposed by Google API discovery API"""
|
95
101
|
|
96
102
|
def __init__(
|
97
103
|
self,
|
@@ -46,7 +46,7 @@ class GoogleSearchTool(BaseTool):
|
|
46
46
|
) -> None:
|
47
47
|
llm_request.config = llm_request.config or types.GenerateContentConfig()
|
48
48
|
llm_request.config.tools = llm_request.config.tools or []
|
49
|
-
if llm_request.model and
|
49
|
+
if llm_request.model and 'gemini-1' in llm_request.model:
|
50
50
|
if llm_request.config.tools:
|
51
51
|
print(llm_request.config.tools)
|
52
52
|
raise ValueError(
|
@@ -55,7 +55,7 @@ class GoogleSearchTool(BaseTool):
|
|
55
55
|
llm_request.config.tools.append(
|
56
56
|
types.Tool(google_search_retrieval=types.GoogleSearchRetrieval())
|
57
57
|
)
|
58
|
-
elif llm_request.model and
|
58
|
+
elif llm_request.model and 'gemini-2' in llm_request.model:
|
59
59
|
llm_request.config.tools.append(
|
60
60
|
types.Tool(google_search=types.GoogleSearch())
|
61
61
|
)
|
@@ -12,9 +12,13 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
from typing import Any
|
16
|
-
from
|
15
|
+
from typing import Any
|
16
|
+
from typing import Dict
|
17
|
+
|
18
|
+
from google.genai.types import Schema
|
19
|
+
from google.genai.types import Type
|
17
20
|
import mcp.types as mcp_types
|
21
|
+
|
18
22
|
from ..base_tool import BaseTool
|
19
23
|
|
20
24
|
|
@@ -12,8 +12,7 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
|
16
|
-
from contextlib import asynccontextmanager
|
15
|
+
|
17
16
|
from contextlib import AsyncExitStack
|
18
17
|
import functools
|
19
18
|
import logging
|
@@ -71,29 +70,27 @@ def retry_on_closed_resource(async_reinit_func_name: str):
|
|
71
70
|
|
72
71
|
Usage:
|
73
72
|
class MCPTool:
|
74
|
-
|
75
|
-
|
76
|
-
|
73
|
+
...
|
74
|
+
async def create_session(self):
|
75
|
+
self.session = ...
|
77
76
|
|
78
|
-
|
79
|
-
|
80
|
-
|
77
|
+
@retry_on_closed_resource('create_session')
|
78
|
+
async def use_session(self):
|
79
|
+
await self.session.call_tool()
|
81
80
|
|
82
81
|
Args:
|
83
|
-
|
82
|
+
async_reinit_func_name: The name of the async function to recreate session.
|
84
83
|
|
85
84
|
Returns:
|
86
|
-
|
85
|
+
The decorated function.
|
87
86
|
"""
|
88
87
|
|
89
88
|
def decorator(func):
|
90
|
-
@functools.wraps(
|
91
|
-
func
|
92
|
-
) # Preserves original function metadata (name, docstring)
|
89
|
+
@functools.wraps(func) # Preserves original function metadata
|
93
90
|
async def wrapper(self, *args, **kwargs):
|
94
91
|
try:
|
95
92
|
return await func(self, *args, **kwargs)
|
96
|
-
except anyio.ClosedResourceError:
|
93
|
+
except anyio.ClosedResourceError as close_err:
|
97
94
|
try:
|
98
95
|
if hasattr(self, async_reinit_func_name) and callable(
|
99
96
|
getattr(self, async_reinit_func_name)
|
@@ -105,7 +102,7 @@ def retry_on_closed_resource(async_reinit_func_name: str):
|
|
105
102
|
f'Function {async_reinit_func_name} does not exist in decorated'
|
106
103
|
' class. Please check the function name in'
|
107
104
|
' retry_on_closed_resource decorator.'
|
108
|
-
)
|
105
|
+
) from close_err
|
109
106
|
except Exception as reinit_err:
|
110
107
|
raise RuntimeError(
|
111
108
|
f'Error reinitializing: {reinit_err}'
|
@@ -117,45 +114,6 @@ def retry_on_closed_resource(async_reinit_func_name: str):
|
|
117
114
|
return decorator
|
118
115
|
|
119
116
|
|
120
|
-
@asynccontextmanager
|
121
|
-
async def tracked_stdio_client(server, errlog, process=None):
|
122
|
-
"""A wrapper around stdio_client that ensures proper process tracking and cleanup."""
|
123
|
-
our_process = process
|
124
|
-
|
125
|
-
# If no process was provided, create one
|
126
|
-
if our_process is None:
|
127
|
-
our_process = await asyncio.create_subprocess_exec(
|
128
|
-
server.command,
|
129
|
-
*server.args,
|
130
|
-
stdin=asyncio.subprocess.PIPE,
|
131
|
-
stdout=asyncio.subprocess.PIPE,
|
132
|
-
stderr=errlog,
|
133
|
-
)
|
134
|
-
|
135
|
-
# Use the original stdio_client, but ensure process cleanup
|
136
|
-
try:
|
137
|
-
async with stdio_client(server=server, errlog=errlog) as client:
|
138
|
-
yield client, our_process
|
139
|
-
finally:
|
140
|
-
# Ensure the process is properly terminated if it still exists
|
141
|
-
if our_process and our_process.returncode is None:
|
142
|
-
try:
|
143
|
-
logger.info(
|
144
|
-
f'Terminating process {our_process.pid} from tracked_stdio_client'
|
145
|
-
)
|
146
|
-
our_process.terminate()
|
147
|
-
try:
|
148
|
-
await asyncio.wait_for(our_process.wait(), timeout=3.0)
|
149
|
-
except asyncio.TimeoutError:
|
150
|
-
# Force kill if it doesn't terminate quickly
|
151
|
-
if our_process.returncode is None:
|
152
|
-
logger.warning(f'Forcing kill of process {our_process.pid}')
|
153
|
-
our_process.kill()
|
154
|
-
except ProcessLookupError:
|
155
|
-
# Process already gone, that's fine
|
156
|
-
logger.info(f'Process {our_process.pid} already terminated')
|
157
|
-
|
158
|
-
|
159
117
|
class MCPSessionManager:
|
160
118
|
"""Manages MCP client sessions.
|
161
119
|
|
@@ -166,162 +124,78 @@ class MCPSessionManager:
|
|
166
124
|
def __init__(
|
167
125
|
self,
|
168
126
|
connection_params: StdioServerParameters | SseServerParams,
|
169
|
-
exit_stack: AsyncExitStack,
|
170
127
|
errlog: TextIO = sys.stderr,
|
171
128
|
):
|
172
129
|
"""Initializes the MCP session manager.
|
173
130
|
|
174
|
-
Example usage:
|
175
|
-
```
|
176
|
-
mcp_session_manager = MCPSessionManager(
|
177
|
-
connection_params=connection_params,
|
178
|
-
exit_stack=exit_stack,
|
179
|
-
)
|
180
|
-
session = await mcp_session_manager.create_session()
|
181
|
-
```
|
182
|
-
|
183
131
|
Args:
|
184
132
|
connection_params: Parameters for the MCP connection (Stdio or SSE).
|
185
|
-
exit_stack: AsyncExitStack to manage the session lifecycle.
|
186
133
|
errlog: (Optional) TextIO stream for error logging. Use only for
|
187
134
|
initializing a local stdio MCP session.
|
188
135
|
"""
|
189
|
-
|
190
136
|
self._connection_params = connection_params
|
191
|
-
self._exit_stack = exit_stack
|
192
137
|
self._errlog = errlog
|
193
|
-
|
194
|
-
self.
|
195
|
-
self.
|
138
|
+
# Each session manager maintains its own exit stack for proper cleanup
|
139
|
+
self._exit_stack: Optional[AsyncExitStack] = None
|
140
|
+
self._session: Optional[ClientSession] = None
|
196
141
|
|
197
|
-
async def create_session(
|
198
|
-
|
199
|
-
) -> tuple[ClientSession, Optional[asyncio.subprocess.Process]]:
|
200
|
-
"""Creates a new MCP session and tracks the associated process."""
|
201
|
-
session, process = await self._initialize_session(
|
202
|
-
connection_params=self._connection_params,
|
203
|
-
exit_stack=self._exit_stack,
|
204
|
-
errlog=self._errlog,
|
205
|
-
)
|
206
|
-
self._process = process # Store reference to process
|
207
|
-
|
208
|
-
# Track the process
|
209
|
-
if process:
|
210
|
-
self._active_processes.add(process)
|
211
|
-
|
212
|
-
return session, process
|
213
|
-
|
214
|
-
@classmethod
|
215
|
-
async def _initialize_session(
|
216
|
-
cls,
|
217
|
-
*,
|
218
|
-
connection_params: StdioServerParameters | SseServerParams,
|
219
|
-
exit_stack: AsyncExitStack,
|
220
|
-
errlog: TextIO = sys.stderr,
|
221
|
-
) -> tuple[ClientSession, Optional[asyncio.subprocess.Process]]:
|
222
|
-
"""Initializes an MCP client session.
|
223
|
-
|
224
|
-
Args:
|
225
|
-
connection_params: Parameters for the MCP connection (Stdio or SSE).
|
226
|
-
exit_stack: AsyncExitStack to manage the session lifecycle.
|
227
|
-
errlog: (Optional) TextIO stream for error logging. Use only for
|
228
|
-
initializing a local stdio MCP session.
|
142
|
+
async def create_session(self) -> ClientSession:
|
143
|
+
"""Creates and initializes an MCP client session.
|
229
144
|
|
230
145
|
Returns:
|
231
146
|
ClientSession: The initialized MCP client session.
|
232
147
|
"""
|
233
|
-
|
234
|
-
|
235
|
-
if isinstance(connection_params, StdioServerParameters):
|
236
|
-
# For stdio connections, we need to track the subprocess
|
237
|
-
client, process = await cls._create_stdio_client(
|
238
|
-
server=connection_params,
|
239
|
-
errlog=errlog,
|
240
|
-
exit_stack=exit_stack,
|
241
|
-
)
|
242
|
-
elif isinstance(connection_params, SseServerParams):
|
243
|
-
# For SSE connections, create the client without a subprocess
|
244
|
-
client = sse_client(
|
245
|
-
url=connection_params.url,
|
246
|
-
headers=connection_params.headers,
|
247
|
-
timeout=connection_params.timeout,
|
248
|
-
sse_read_timeout=connection_params.sse_read_timeout,
|
249
|
-
)
|
250
|
-
else:
|
251
|
-
raise ValueError(
|
252
|
-
'Unable to initialize connection. Connection should be'
|
253
|
-
' StdioServerParameters or SseServerParams, but got'
|
254
|
-
f' {connection_params}'
|
255
|
-
)
|
148
|
+
if self._session is not None:
|
149
|
+
return self._session
|
256
150
|
|
257
|
-
# Create
|
258
|
-
|
259
|
-
session = await exit_stack.enter_async_context(ClientSession(*transports))
|
260
|
-
await session.initialize()
|
151
|
+
# Create a new exit stack for this session
|
152
|
+
self._exit_stack = AsyncExitStack()
|
261
153
|
|
262
|
-
return session, process
|
263
|
-
|
264
|
-
@staticmethod
|
265
|
-
async def _create_stdio_client(
|
266
|
-
server: StdioServerParameters,
|
267
|
-
errlog: TextIO,
|
268
|
-
exit_stack: AsyncExitStack,
|
269
|
-
) -> tuple[Any, asyncio.subprocess.Process]:
|
270
|
-
"""Create stdio client and return both the client and process.
|
271
|
-
|
272
|
-
This implementation adapts to how the MCP stdio_client is created.
|
273
|
-
The actual implementation may need to be adjusted based on the MCP library
|
274
|
-
structure.
|
275
|
-
"""
|
276
|
-
# Create the subprocess directly so we can track it
|
277
|
-
process = await asyncio.create_subprocess_exec(
|
278
|
-
server.command,
|
279
|
-
*server.args,
|
280
|
-
stdin=asyncio.subprocess.PIPE,
|
281
|
-
stdout=asyncio.subprocess.PIPE,
|
282
|
-
stderr=errlog,
|
283
|
-
)
|
284
|
-
|
285
|
-
# Create the stdio client using the MCP library
|
286
154
|
try:
|
287
|
-
|
288
|
-
|
289
|
-
|
290
|
-
|
291
|
-
|
292
|
-
|
293
|
-
|
294
|
-
|
295
|
-
|
296
|
-
|
155
|
+
if isinstance(self._connection_params, StdioServerParameters):
|
156
|
+
client = stdio_client(
|
157
|
+
server=self._connection_params, errlog=self._errlog
|
158
|
+
)
|
159
|
+
elif isinstance(self._connection_params, SseServerParams):
|
160
|
+
client = sse_client(
|
161
|
+
url=self._connection_params.url,
|
162
|
+
headers=self._connection_params.headers,
|
163
|
+
timeout=self._connection_params.timeout,
|
164
|
+
sse_read_timeout=self._connection_params.sse_read_timeout,
|
165
|
+
)
|
166
|
+
else:
|
167
|
+
raise ValueError(
|
168
|
+
'Unable to initialize connection. Connection should be'
|
169
|
+
' StdioServerParameters or SseServerParams, but got'
|
170
|
+
f' {self._connection_params}'
|
171
|
+
)
|
297
172
|
|
298
|
-
|
173
|
+
transports = await self._exit_stack.enter_async_context(client)
|
174
|
+
session = await self._exit_stack.enter_async_context(
|
175
|
+
ClientSession(*transports)
|
176
|
+
)
|
177
|
+
await session.initialize()
|
299
178
|
|
300
|
-
|
301
|
-
|
302
|
-
logger.info('Performing emergency cleanup of MCPSessionManager resources')
|
179
|
+
self._session = session
|
180
|
+
return session
|
303
181
|
|
304
|
-
|
305
|
-
|
306
|
-
|
307
|
-
|
308
|
-
|
309
|
-
|
310
|
-
try:
|
311
|
-
await asyncio.wait_for(proc.wait(), timeout=1.0)
|
312
|
-
except asyncio.TimeoutError:
|
313
|
-
logger.warning(f"Process {proc.pid} didn't terminate, forcing kill")
|
314
|
-
proc.kill()
|
315
|
-
self._active_processes.remove(proc)
|
316
|
-
except Exception as e:
|
317
|
-
logger.error(f'Error during process cleanup: {e}')
|
182
|
+
except Exception:
|
183
|
+
# If session creation fails, clean up the exit stack
|
184
|
+
if self._exit_stack:
|
185
|
+
await self._exit_stack.aclose()
|
186
|
+
self._exit_stack = None
|
187
|
+
raise
|
318
188
|
|
319
|
-
|
320
|
-
|
189
|
+
async def close(self):
|
190
|
+
"""Closes the session and cleans up resources."""
|
191
|
+
if self._exit_stack:
|
321
192
|
try:
|
322
|
-
|
323
|
-
logger.info('Closing file handle')
|
324
|
-
handle.close()
|
325
|
-
self._active_file_handles.remove(handle)
|
193
|
+
await self._exit_stack.aclose()
|
326
194
|
except Exception as e:
|
327
|
-
|
195
|
+
# Log the error but don't re-raise to avoid blocking shutdown
|
196
|
+
print(
|
197
|
+
f'Warning: Error during MCP session cleanup: {e}', file=self._errlog
|
198
|
+
)
|
199
|
+
finally:
|
200
|
+
self._exit_stack = None
|
201
|
+
self._session = None
|