google-adk 1.0.0__py3-none-any.whl → 1.1.1__py3-none-any.whl

This diff reflects the contents of publicly released package versions as they appear in their respective public registries, and is provided for informational purposes only.
Files changed (94)
  1. google/adk/agents/callback_context.py +2 -1
  2. google/adk/agents/readonly_context.py +3 -1
  3. google/adk/auth/auth_credential.py +4 -1
  4. google/adk/cli/browser/index.html +4 -4
  5. google/adk/cli/browser/{main-QOEMUXM4.js → main-PKDNKWJE.js} +59 -59
  6. google/adk/cli/browser/polyfills-B6TNHZQ6.js +17 -0
  7. google/adk/cli/cli.py +3 -2
  8. google/adk/cli/cli_eval.py +6 -85
  9. google/adk/cli/cli_tools_click.py +39 -10
  10. google/adk/cli/fast_api.py +53 -184
  11. google/adk/cli/utils/agent_loader.py +137 -0
  12. google/adk/cli/utils/cleanup.py +40 -0
  13. google/adk/cli/utils/evals.py +2 -1
  14. google/adk/cli/utils/logs.py +2 -7
  15. google/adk/code_executors/code_execution_utils.py +2 -1
  16. google/adk/code_executors/container_code_executor.py +0 -1
  17. google/adk/code_executors/vertex_ai_code_executor.py +6 -8
  18. google/adk/evaluation/eval_case.py +3 -1
  19. google/adk/evaluation/eval_metrics.py +74 -0
  20. google/adk/evaluation/eval_result.py +86 -0
  21. google/adk/evaluation/eval_set.py +2 -0
  22. google/adk/evaluation/eval_set_results_manager.py +47 -0
  23. google/adk/evaluation/eval_sets_manager.py +2 -1
  24. google/adk/evaluation/evaluator.py +2 -0
  25. google/adk/evaluation/local_eval_set_results_manager.py +113 -0
  26. google/adk/evaluation/local_eval_sets_manager.py +4 -4
  27. google/adk/evaluation/response_evaluator.py +2 -1
  28. google/adk/evaluation/trajectory_evaluator.py +3 -2
  29. google/adk/examples/base_example_provider.py +1 -0
  30. google/adk/flows/llm_flows/base_llm_flow.py +4 -6
  31. google/adk/flows/llm_flows/contents.py +3 -1
  32. google/adk/flows/llm_flows/instructions.py +7 -77
  33. google/adk/flows/llm_flows/single_flow.py +1 -1
  34. google/adk/models/base_llm.py +2 -1
  35. google/adk/models/base_llm_connection.py +2 -0
  36. google/adk/models/google_llm.py +4 -1
  37. google/adk/models/lite_llm.py +3 -2
  38. google/adk/models/llm_response.py +2 -1
  39. google/adk/runners.py +36 -4
  40. google/adk/sessions/_session_util.py +2 -1
  41. google/adk/sessions/database_session_service.py +5 -8
  42. google/adk/sessions/vertex_ai_session_service.py +28 -13
  43. google/adk/telemetry.py +4 -2
  44. google/adk/tools/agent_tool.py +1 -1
  45. google/adk/tools/apihub_tool/apihub_toolset.py +1 -1
  46. google/adk/tools/apihub_tool/clients/apihub_client.py +10 -3
  47. google/adk/tools/apihub_tool/clients/secret_client.py +1 -0
  48. google/adk/tools/application_integration_tool/application_integration_toolset.py +6 -2
  49. google/adk/tools/application_integration_tool/clients/connections_client.py +8 -1
  50. google/adk/tools/application_integration_tool/clients/integration_client.py +3 -1
  51. google/adk/tools/application_integration_tool/integration_connector_tool.py +1 -1
  52. google/adk/tools/base_toolset.py +40 -2
  53. google/adk/tools/bigquery/__init__.py +38 -0
  54. google/adk/tools/bigquery/bigquery_credentials.py +217 -0
  55. google/adk/tools/bigquery/bigquery_tool.py +116 -0
  56. google/adk/tools/bigquery/bigquery_toolset.py +86 -0
  57. google/adk/tools/bigquery/client.py +33 -0
  58. google/adk/tools/bigquery/metadata_tool.py +249 -0
  59. google/adk/tools/bigquery/query_tool.py +76 -0
  60. google/adk/tools/function_parameter_parse_util.py +7 -0
  61. google/adk/tools/function_tool.py +33 -3
  62. google/adk/tools/get_user_choice_tool.py +1 -0
  63. google/adk/tools/google_api_tool/__init__.py +17 -11
  64. google/adk/tools/google_api_tool/google_api_tool.py +1 -1
  65. google/adk/tools/google_api_tool/google_api_toolset.py +0 -14
  66. google/adk/tools/google_api_tool/google_api_toolsets.py +8 -2
  67. google/adk/tools/google_search_tool.py +2 -2
  68. google/adk/tools/mcp_tool/conversion_utils.py +6 -2
  69. google/adk/tools/mcp_tool/mcp_session_manager.py +62 -188
  70. google/adk/tools/mcp_tool/mcp_tool.py +27 -24
  71. google/adk/tools/mcp_tool/mcp_toolset.py +76 -131
  72. google/adk/tools/openapi_tool/auth/credential_exchangers/base_credential_exchanger.py +1 -3
  73. google/adk/tools/openapi_tool/auth/credential_exchangers/service_account_exchanger.py +6 -7
  74. google/adk/tools/openapi_tool/common/common.py +5 -1
  75. google/adk/tools/openapi_tool/openapi_spec_parser/__init__.py +7 -2
  76. google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py +2 -7
  77. google/adk/tools/openapi_tool/openapi_spec_parser/operation_parser.py +5 -1
  78. google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py +11 -1
  79. google/adk/tools/toolbox_toolset.py +31 -3
  80. google/adk/utils/__init__.py +13 -0
  81. google/adk/utils/instructions_utils.py +131 -0
  82. google/adk/version.py +1 -1
  83. {google_adk-1.0.0.dist-info → google_adk-1.1.1.dist-info}/METADATA +12 -15
  84. {google_adk-1.0.0.dist-info → google_adk-1.1.1.dist-info}/RECORD +87 -78
  85. google/adk/agents/base_agent.py.orig +0 -330
  86. google/adk/cli/browser/polyfills-FFHMD2TL.js +0 -18
  87. google/adk/cli/fast_api.py.orig +0 -822
  88. google/adk/memory/base_memory_service.py.orig +0 -76
  89. google/adk/models/google_llm.py.orig +0 -305
  90. google/adk/tools/_built_in_code_execution_tool.py +0 -70
  91. google/adk/tools/mcp_tool/mcp_session_manager.py.orig +0 -322
  92. {google_adk-1.0.0.dist-info → google_adk-1.1.1.dist-info}/WHEEL +0 -0
  93. {google_adk-1.0.0.dist-info → google_adk-1.1.1.dist-info}/entry_points.txt +0 -0
  94. {google_adk-1.0.0.dist-info → google_adk-1.1.1.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,76 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from google.oauth2.credentials import Credentials
+
+ from ...tools.bigquery import client
+
+ MAX_DOWNLOADED_QUERY_RESULT_ROWS = 50
+
+
+ def execute_sql(project_id: str, query: str, credentials: Credentials) -> dict:
+   """Run a BigQuery SQL query in the project and return the result.
+
+   Args:
+     project_id (str): The GCP project id in which the query should be
+       executed.
+     query (str): The BigQuery SQL query to be executed.
+     credentials (Credentials): The credentials to use for the request.
+
+   Returns:
+     dict: Dictionary representing the result of the query.
+       If the result contains the key "result_is_likely_truncated" with
+       value True, it means that there may be additional rows matching the
+       query not returned in the result.
+
+   Examples:
+     >>> execute_sql("bigframes-dev",
+     ...   "SELECT island, COUNT(*) AS population "
+     ...   "FROM bigquery-public-data.ml_datasets.penguins GROUP BY island")
+     {
+       "rows": [
+         {
+           "island": "Dream",
+           "population": 124
+         },
+         {
+           "island": "Biscoe",
+           "population": 168
+         },
+         {
+           "island": "Torgersen",
+           "population": 52
+         }
+       ]
+     }
+   """
+
+   try:
+     bq_client = client.get_bigquery_client(credentials=credentials)
+     row_iterator = bq_client.query_and_wait(
+         query, project=project_id, max_results=MAX_DOWNLOADED_QUERY_RESULT_ROWS
+     )
+     rows = [{key: val for key, val in row.items()} for row in row_iterator]
+     result = {"rows": rows}
+     if (
+         MAX_DOWNLOADED_QUERY_RESULT_ROWS is not None
+         and len(rows) == MAX_DOWNLOADED_QUERY_RESULT_ROWS
+     ):
+       result["result_is_likely_truncated"] = True
+     return result
+   except Exception as ex:
+     return {
+         "status": "ERROR",
+         "error_details": str(ex),
+     }
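The file above is the new BigQuery query tool (google/adk/tools/bigquery/query_tool.py in the file list). A minimal sketch of calling it directly, assuming OAuth user credentials are already available; the token, project id, and query below are placeholders, not part of the release:

    from google.oauth2.credentials import Credentials

    from google.adk.tools.bigquery import query_tool

    # Placeholder token: in practice this comes from an OAuth flow or ADC.
    creds = Credentials(token="ya29.placeholder-access-token")

    result = query_tool.execute_sql(
        project_id="my-gcp-project",
        query="SELECT 1 AS x",
        credentials=creds,
    )
    # Either {"rows": [...]} (possibly with "result_is_likely_truncated": True)
    # or {"status": "ERROR", "error_details": "..."} on failure.
    print(result)
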
@@ -289,6 +289,13 @@ def _parse_schema_from_parameter(
      )
      _raise_if_schema_unsupported(variant, schema)
      return schema
+   if param.annotation is None:
+     # https://swagger.io/docs/specification/v3_0/data-models/data-types/#null
+     # null is not a valid type in schema, use object instead.
+     schema.type = types.Type.OBJECT
+     schema.nullable = True
+     _raise_if_schema_unsupported(variant, schema)
+     return schema
    raise ValueError(
        f'Failed to parse the parameter {param} of function {func_name} for'
        ' automatic function calling. Automatic function calling works best with'
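The new branch covers parameters annotated literally with None, which previously fell through to the ValueError below. A tiny, hypothetical illustration of the kind of signature it now accepts:

    # 'reason' is annotated as None; under the new branch its generated schema
    # becomes type OBJECT with nullable=True instead of raising ValueError.
    def archive_note(note_id: str, reason: None = None) -> dict:
      return {"note_id": note_id, "reason": reason}
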
@@ -33,8 +33,31 @@ class FunctionTool(BaseTool):
    """
 
    def __init__(self, func: Callable[..., Any]):
-     super().__init__(name=func.__name__, description=func.__doc__)
+     """Extract metadata from a callable object."""
+     name = ''
+     doc = ''
+     # Handle different types of callables
+     if hasattr(func, '__name__'):
+       # Regular functions, unbound methods, etc.
+       name = func.__name__
+     elif hasattr(func, '__class__'):
+       # Callable objects, bound methods, etc.
+       name = func.__class__.__name__
+
+     # Get documentation (prioritize direct __doc__ if available)
+     if hasattr(func, '__doc__') and func.__doc__:
+       doc = func.__doc__
+     elif (
+         hasattr(func, '__call__')
+         and hasattr(func.__call__, '__doc__')
+         and func.__call__.__doc__
+     ):
+       # For callable objects, try to get docstring from __call__ method
+       doc = func.__call__.__doc__
+
+     super().__init__(name=name, description=doc)
      self.func = func
+     self._ignore_params = ['tool_context', 'input_stream']
 
    @override
    def _get_declaration(self) -> Optional[types.FunctionDeclaration]:
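With this change FunctionTool can wrap a callable object, not just a plain function: the name falls back to the class name and the description to the object's docstring (or, failing that, the __call__ docstring). A small sketch; the GetWeather class is invented for illustration:

    from google.adk.tools.function_tool import FunctionTool

    class GetWeather:
      """Returns a canned weather report for a city."""

      def __call__(self, city: str) -> dict:
        return {"city": city, "forecast": "sunny"}

    tool = FunctionTool(GetWeather())
    print(tool.name)         # "GetWeather" (no __name__, so the class name is used)
    print(tool.description)  # "Returns a canned weather report for a city."
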
@@ -43,7 +66,7 @@ class FunctionTool(BaseTool):
              func=self.func,
              # The model doesn't understand the function context.
              # input_stream is for streaming tool
-             ignore_params=['tool_context', 'input_stream'],
+             ignore_params=self._ignore_params,
              variant=self._api_variant,
          )
      )
@@ -76,7 +99,14 @@ class FunctionTool(BaseTool):
  You could retry calling this tool, but it is IMPORTANT for you to provide all the mandatory parameters."""
        return {'error': error_str}
 
-     if inspect.iscoroutinefunction(self.func):
+     # Functions are callable objects, but not all callable objects are functions;
+     # checking coroutine function is not enough. We also need to check whether
+     # Callable's __call__ function is a coroutine function.
+     if (
+         inspect.iscoroutinefunction(self.func)
+         or hasattr(self.func, '__call__')
+         and inspect.iscoroutinefunction(self.func.__call__)
+     ):
        return await self.func(**args_to_call) or {}
      else:
        return self.func(**args_to_call) or {}
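The widened check matters for callable objects whose __call__ is a coroutine function: the instance itself is not a coroutine function, so the old test would have returned an unawaited coroutine. A brief sketch under the same assumptions as above:

    import inspect

    class FetchStatus:

      async def __call__(self, service: str) -> dict:
        return {"service": service, "status": "ok"}

    checker = FetchStatus()
    inspect.iscoroutinefunction(checker)           # False
    inspect.iscoroutinefunction(checker.__call__)  # True, so run_async now awaits it
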
@@ -13,6 +13,7 @@
  # limitations under the License.
 
  from typing import Optional
+
  from .long_running_tool import LongRunningFunctionTool
  from .tool_context import ToolContext
 
@@ -11,18 +11,12 @@
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  # See the License for the specific language governing permissions and
  # limitations under the License.
- __all__ = [
-     'BigQueryToolset',
-     'CalendarToolset',
-     'GmailToolset',
-     'YoutubeToolset',
-     'SlidesToolset',
-     'SheetsToolset',
-     'DocsToolset',
-     'GoogleApiToolset',
-     'GoogleApiTool',
- ]
 
+ """Auto-generated tools and toolsets for Google APIs.
+
+ These tools and toolsets are auto-generated based on the API specifications
+ provided by the Google API Discovery API.
+ """
 
  from .google_api_tool import GoogleApiTool
  from .google_api_toolset import GoogleApiToolset
@@ -33,3 +27,15 @@ from .google_api_toolsets import GmailToolset
  from .google_api_toolsets import SheetsToolset
  from .google_api_toolsets import SlidesToolset
  from .google_api_toolsets import YoutubeToolset
+
+ __all__ = [
+     'BigQueryToolset',
+     'CalendarToolset',
+     'GmailToolset',
+     'YoutubeToolset',
+     'SlidesToolset',
+     'SheetsToolset',
+     'DocsToolset',
+     'GoogleApiToolset',
+     'GoogleApiTool',
+ ]
@@ -19,10 +19,10 @@ from typing import Optional
  from google.genai.types import FunctionDeclaration
  from typing_extensions import override
 
+ from .. import BaseTool
  from ...auth import AuthCredential
  from ...auth import AuthCredentialTypes
  from ...auth import OAuth2Auth
- from .. import BaseTool
  from ..openapi_tool import RestApiTool
  from ..tool_context import ToolContext
 
@@ -56,20 +56,6 @@ class GoogleApiToolset(BaseToolset):
      self._openapi_toolset = self._load_toolset_with_oidc_auth()
      self.tool_filter = tool_filter
 
-   def _is_tool_selected(
-       self, tool: GoogleApiTool, readonly_context: ReadonlyContext
-   ) -> bool:
-     if not self.tool_filter:
-       return True
-
-     if isinstance(self.tool_filter, ToolPredicate):
-       return self.tool_filter(tool, readonly_context)
-
-     if isinstance(self.tool_filter, list):
-       return tool.name in self.tool_filter
-
-     return False
-
    @override
    async def get_tools(
        self, readonly_context: Optional[ReadonlyContext] = None
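The removed helper documents the tool_filter contract, which presumably now lives in the shared BaseToolset (see the base_toolset.py entry in the file list): a filter is either a list of tool names or a ToolPredicate called with the tool and the readonly context. A hedged sketch of both forms; the tool names are hypothetical:

    # Form 1: a plain allow-list of tool names.
    tool_filter = ['calendar_events_list', 'calendar_events_get']

    # Form 2: a predicate with the (tool, readonly_context) signature seen above.
    def read_only_tools(tool, readonly_context=None) -> bool:
      return tool.name.endswith(('_list', '_get'))
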
@@ -18,14 +18,14 @@ from typing import List
  from typing import Optional
  from typing import Union
 
- from google.adk.tools.base_toolset import ToolPredicate
-
+ from ..base_toolset import ToolPredicate
  from .google_api_toolset import GoogleApiToolset
 
  logger = logging.getLogger("google_adk." + __name__)
 
 
  class BigQueryToolset(GoogleApiToolset):
+   """Auto-generated Bigquery toolset based on Google BigQuery API v2 spec exposed by Google API discovery API"""
 
    def __init__(
        self,
@@ -37,6 +37,7 @@ class BigQueryToolset(GoogleApiToolset):
 
 
  class CalendarToolset(GoogleApiToolset):
+   """Auto-generated Calendar toolset based on Google Calendar API v3 spec exposed by Google API discovery API"""
 
    def __init__(
        self,
@@ -48,6 +49,7 @@ class CalendarToolset(GoogleApiToolset):
 
 
  class GmailToolset(GoogleApiToolset):
+   """Auto-generated Gmail toolset based on Google Gmail API v1 spec exposed by Google API discovery API"""
 
    def __init__(
        self,
@@ -59,6 +61,7 @@ class GmailToolset(GoogleApiToolset):
 
 
  class YoutubeToolset(GoogleApiToolset):
+   """Auto-generated Youtube toolset based on Youtube API v3 spec exposed by Google API discovery API"""
 
    def __init__(
        self,
@@ -70,6 +73,7 @@ class YoutubeToolset(GoogleApiToolset):
 
 
  class SlidesToolset(GoogleApiToolset):
+   """Auto-generated Slides toolset based on Google Slides API v1 spec exposed by Google API discovery API"""
 
    def __init__(
        self,
@@ -81,6 +85,7 @@ class SlidesToolset(GoogleApiToolset):
 
 
  class SheetsToolset(GoogleApiToolset):
+   """Auto-generated Sheets toolset based on Google Sheets API v4 spec exposed by Google API discovery API"""
 
    def __init__(
        self,
@@ -92,6 +97,7 @@ class SheetsToolset(GoogleApiToolset):
 
 
  class DocsToolset(GoogleApiToolset):
+   """Auto-generated Docs toolset based on Google Docs API v1 spec exposed by Google API discovery API"""
 
    def __init__(
        self,
@@ -46,7 +46,7 @@ class GoogleSearchTool(BaseTool):
    ) -> None:
      llm_request.config = llm_request.config or types.GenerateContentConfig()
      llm_request.config.tools = llm_request.config.tools or []
-     if llm_request.model and llm_request.model.startswith('gemini-1'):
+     if llm_request.model and 'gemini-1' in llm_request.model:
        if llm_request.config.tools:
          print(llm_request.config.tools)
          raise ValueError(
@@ -55,7 +55,7 @@ class GoogleSearchTool(BaseTool):
        llm_request.config.tools.append(
            types.Tool(google_search_retrieval=types.GoogleSearchRetrieval())
        )
-     elif llm_request.model and llm_request.model.startswith('gemini-2'):
+     elif llm_request.model and 'gemini-2' in llm_request.model:
        llm_request.config.tools.append(
            types.Tool(google_search=types.GoogleSearch())
        )
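Switching from startswith to a substring check keeps the tool wiring working when the model is referenced by something other than a bare model id, for example a full Vertex AI resource path; the path below is illustrative:

    model = "projects/demo/locations/us-central1/publishers/google/models/gemini-2.0-flash"

    model.startswith('gemini-2')   # False: the old check would miss this model
    'gemini-2' in model            # True: the new check still routes to google_search
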
@@ -12,9 +12,13 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
 
- from typing import Any, Dict
- from google.genai.types import Schema, Type
+ from typing import Any
+ from typing import Dict
+
+ from google.genai.types import Schema
+ from google.genai.types import Type
  import mcp.types as mcp_types
+
  from ..base_tool import BaseTool
 
 
@@ -12,8 +12,7 @@
  # See the License for the specific language governing permissions and
  # limitations under the License.
 
- import asyncio
- from contextlib import asynccontextmanager
+
  from contextlib import AsyncExitStack
  import functools
  import logging
@@ -71,29 +70,27 @@ def retry_on_closed_resource(async_reinit_func_name: str):
 
    Usage:
      class MCPTool:
-       ...
-       async def create_session(self):
-         self.session = ...
+       ...
+       async def create_session(self):
+         self.session = ...
 
-       @retry_on_closed_resource('create_session')
-       async def use_session(self):
-         await self.session.call_tool()
+       @retry_on_closed_resource('create_session')
+       async def use_session(self):
+         await self.session.call_tool()
 
    Args:
-     async_reinit_func_name: The name of the async function to recreate session.
+     async_reinit_func_name: The name of the async function to recreate session.
 
    Returns:
-     The decorated function.
+     The decorated function.
    """
 
    def decorator(func):
-     @functools.wraps(
-         func
-     )  # Preserves original function metadata (name, docstring)
+     @functools.wraps(func)  # Preserves original function metadata
      async def wrapper(self, *args, **kwargs):
        try:
          return await func(self, *args, **kwargs)
-       except anyio.ClosedResourceError:
+       except anyio.ClosedResourceError as close_err:
          try:
            if hasattr(self, async_reinit_func_name) and callable(
                getattr(self, async_reinit_func_name)
@@ -105,7 +102,7 @@ def retry_on_closed_resource(async_reinit_func_name: str):
                f'Function {async_reinit_func_name} does not exist in decorated'
                ' class. Please check the function name in'
                ' retry_on_closed_resource decorator.'
-           )
+           ) from close_err
        except Exception as reinit_err:
          raise RuntimeError(
              f'Error reinitializing: {reinit_err}'
@@ -117,45 +114,6 @@ def retry_on_closed_resource(async_reinit_func_name: str):
    return decorator
 
 
- @asynccontextmanager
- async def tracked_stdio_client(server, errlog, process=None):
-   """A wrapper around stdio_client that ensures proper process tracking and cleanup."""
-   our_process = process
-
-   # If no process was provided, create one
-   if our_process is None:
-     our_process = await asyncio.create_subprocess_exec(
-         server.command,
-         *server.args,
-         stdin=asyncio.subprocess.PIPE,
-         stdout=asyncio.subprocess.PIPE,
-         stderr=errlog,
-     )
-
-   # Use the original stdio_client, but ensure process cleanup
-   try:
-     async with stdio_client(server=server, errlog=errlog) as client:
-       yield client, our_process
-   finally:
-     # Ensure the process is properly terminated if it still exists
-     if our_process and our_process.returncode is None:
-       try:
-         logger.info(
-             f'Terminating process {our_process.pid} from tracked_stdio_client'
-         )
-         our_process.terminate()
-         try:
-           await asyncio.wait_for(our_process.wait(), timeout=3.0)
-         except asyncio.TimeoutError:
-           # Force kill if it doesn't terminate quickly
-           if our_process.returncode is None:
-             logger.warning(f'Forcing kill of process {our_process.pid}')
-             our_process.kill()
-       except ProcessLookupError:
-         # Process already gone, that's fine
-         logger.info(f'Process {our_process.pid} already terminated')
-
-
  class MCPSessionManager:
    """Manages MCP client sessions.
 
@@ -166,162 +124,78 @@ class MCPSessionManager:
    def __init__(
        self,
        connection_params: StdioServerParameters | SseServerParams,
-       exit_stack: AsyncExitStack,
        errlog: TextIO = sys.stderr,
    ):
      """Initializes the MCP session manager.
 
-     Example usage:
-     ```
-     mcp_session_manager = MCPSessionManager(
-         connection_params=connection_params,
-         exit_stack=exit_stack,
-     )
-     session = await mcp_session_manager.create_session()
-     ```
-
      Args:
        connection_params: Parameters for the MCP connection (Stdio or SSE).
-       exit_stack: AsyncExitStack to manage the session lifecycle.
        errlog: (Optional) TextIO stream for error logging. Use only for
          initializing a local stdio MCP session.
      """
-
      self._connection_params = connection_params
-     self._exit_stack = exit_stack
      self._errlog = errlog
-     self._process = None  # Track the subprocess
-     self._active_processes = set()  # Track all processes created
-     self._active_file_handles = set()  # Track file handles
+     # Each session manager maintains its own exit stack for proper cleanup
+     self._exit_stack: Optional[AsyncExitStack] = None
+     self._session: Optional[ClientSession] = None
 
-   async def create_session(
-       self,
-   ) -> tuple[ClientSession, Optional[asyncio.subprocess.Process]]:
-     """Creates a new MCP session and tracks the associated process."""
-     session, process = await self._initialize_session(
-         connection_params=self._connection_params,
-         exit_stack=self._exit_stack,
-         errlog=self._errlog,
-     )
-     self._process = process  # Store reference to process
-
-     # Track the process
-     if process:
-       self._active_processes.add(process)
-
-     return session, process
-
-   @classmethod
-   async def _initialize_session(
-       cls,
-       *,
-       connection_params: StdioServerParameters | SseServerParams,
-       exit_stack: AsyncExitStack,
-       errlog: TextIO = sys.stderr,
-   ) -> tuple[ClientSession, Optional[asyncio.subprocess.Process]]:
-     """Initializes an MCP client session.
-
-     Args:
-       connection_params: Parameters for the MCP connection (Stdio or SSE).
-       exit_stack: AsyncExitStack to manage the session lifecycle.
-       errlog: (Optional) TextIO stream for error logging. Use only for
-         initializing a local stdio MCP session.
+   async def create_session(self) -> ClientSession:
+     """Creates and initializes an MCP client session.
 
      Returns:
        ClientSession: The initialized MCP client session.
      """
-     process = None
-
-     if isinstance(connection_params, StdioServerParameters):
-       # For stdio connections, we need to track the subprocess
-       client, process = await cls._create_stdio_client(
-           server=connection_params,
-           errlog=errlog,
-           exit_stack=exit_stack,
-       )
-     elif isinstance(connection_params, SseServerParams):
-       # For SSE connections, create the client without a subprocess
-       client = sse_client(
-           url=connection_params.url,
-           headers=connection_params.headers,
-           timeout=connection_params.timeout,
-           sse_read_timeout=connection_params.sse_read_timeout,
-       )
-     else:
-       raise ValueError(
-           'Unable to initialize connection. Connection should be'
-           ' StdioServerParameters or SseServerParams, but got'
-           f' {connection_params}'
-       )
+     if self._session is not None:
+       return self._session
 
-     # Create the session with the client
-     transports = await exit_stack.enter_async_context(client)
-     session = await exit_stack.enter_async_context(ClientSession(*transports))
-     await session.initialize()
+     # Create a new exit stack for this session
+     self._exit_stack = AsyncExitStack()
 
-     return session, process
-
-   @staticmethod
-   async def _create_stdio_client(
-       server: StdioServerParameters,
-       errlog: TextIO,
-       exit_stack: AsyncExitStack,
-   ) -> tuple[Any, asyncio.subprocess.Process]:
-     """Create stdio client and return both the client and process.
-
-     This implementation adapts to how the MCP stdio_client is created.
-     The actual implementation may need to be adjusted based on the MCP library
-     structure.
-     """
-     # Create the subprocess directly so we can track it
-     process = await asyncio.create_subprocess_exec(
-         server.command,
-         *server.args,
-         stdin=asyncio.subprocess.PIPE,
-         stdout=asyncio.subprocess.PIPE,
-         stderr=errlog,
-     )
-
-     # Create the stdio client using the MCP library
      try:
-       # Method 1: Try using the existing process if stdio_client supports it
-       client = stdio_client(server=server, errlog=errlog, process=process)
-     except TypeError:
-       # Method 2: If the above doesn't work, let stdio_client create its own process
-       # and we'll need to terminate both processes later
-       logger.warning(
-           'Using stdio_client with its own process - may lead to duplicate'
-           ' processes'
-       )
-       client = stdio_client(server=server, errlog=errlog)
+       if isinstance(self._connection_params, StdioServerParameters):
+         client = stdio_client(
+             server=self._connection_params, errlog=self._errlog
+         )
+       elif isinstance(self._connection_params, SseServerParams):
+         client = sse_client(
+             url=self._connection_params.url,
+             headers=self._connection_params.headers,
+             timeout=self._connection_params.timeout,
+             sse_read_timeout=self._connection_params.sse_read_timeout,
+         )
+       else:
+         raise ValueError(
+             'Unable to initialize connection. Connection should be'
+             ' StdioServerParameters or SseServerParams, but got'
+             f' {self._connection_params}'
+         )
 
-     return client, process
+       transports = await self._exit_stack.enter_async_context(client)
+       session = await self._exit_stack.enter_async_context(
+           ClientSession(*transports)
+       )
+       await session.initialize()
 
-   async def _emergency_cleanup(self):
-     """Perform emergency cleanup of resources when normal cleanup fails."""
-     logger.info('Performing emergency cleanup of MCPSessionManager resources')
+       self._session = session
+       return session
 
-     # Clean up any tracked processes
-     for proc in list(self._active_processes):
-       try:
-         if proc and proc.returncode is None:
-           logger.info(f'Emergency termination of process {proc.pid}')
-           proc.terminate()
-           try:
-             await asyncio.wait_for(proc.wait(), timeout=1.0)
-           except asyncio.TimeoutError:
-             logger.warning(f"Process {proc.pid} didn't terminate, forcing kill")
-             proc.kill()
-         self._active_processes.remove(proc)
-       except Exception as e:
-         logger.error(f'Error during process cleanup: {e}')
+     except Exception:
+       # If session creation fails, clean up the exit stack
+       if self._exit_stack:
+         await self._exit_stack.aclose()
+         self._exit_stack = None
+       raise
 
-     # Clean up any tracked file handles
-     for handle in list(self._active_file_handles):
+   async def close(self):
+     """Closes the session and cleans up resources."""
+     if self._exit_stack:
        try:
-         if not handle.closed:
-           logger.info('Closing file handle')
-           handle.close()
-         self._active_file_handles.remove(handle)
+         await self._exit_stack.aclose()
        except Exception as e:
-         logger.error(f'Error closing file handle: {e}')
+         # Log the error but don't re-raise to avoid blocking shutdown
+         print(
+             f'Warning: Error during MCP session cleanup: {e}', file=self._errlog
+         )
+       finally:
+         self._exit_stack = None
+         self._session = None
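Taken together, the rewritten manager owns its AsyncExitStack and exposes a much smaller surface: create_session() returns (and caches) a ClientSession, and close() tears everything down. A minimal usage sketch, assuming a local stdio MCP server; the command and args below are placeholders:

    import asyncio

    from mcp import StdioServerParameters

    from google.adk.tools.mcp_tool.mcp_session_manager import MCPSessionManager

    async def main():
      manager = MCPSessionManager(
          connection_params=StdioServerParameters(
              command='npx',  # placeholder launcher for some MCP server
              args=['-y', '@modelcontextprotocol/server-everything'],
          ),
      )
      try:
        session = await manager.create_session()  # cached on repeat calls
        tools = await session.list_tools()
        print([tool.name for tool in tools.tools])
      finally:
        await manager.close()  # closes the manager-owned exit stack

    asyncio.run(main())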