google-adk 0.1.0__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/adk/agents/base_agent.py +4 -4
- google/adk/agents/invocation_context.py +1 -1
- google/adk/agents/remote_agent.py +1 -1
- google/adk/agents/run_config.py +1 -1
- google/adk/auth/auth_preprocessor.py +2 -2
- google/adk/auth/auth_tool.py +1 -1
- google/adk/cli/browser/index.html +2 -2
- google/adk/cli/browser/{main-SLIAU2JL.js → main-ZBO76GRM.js} +65 -81
- google/adk/cli/cli_create.py +279 -0
- google/adk/cli/cli_deploy.py +4 -0
- google/adk/cli/cli_eval.py +2 -2
- google/adk/cli/cli_tools_click.py +67 -7
- google/adk/cli/fast_api.py +51 -16
- google/adk/cli/utils/envs.py +0 -3
- google/adk/cli/utils/evals.py +2 -2
- google/adk/evaluation/evaluation_generator.py +4 -4
- google/adk/evaluation/response_evaluator.py +15 -3
- google/adk/events/event.py +3 -3
- google/adk/flows/llm_flows/_nl_planning.py +10 -4
- google/adk/flows/llm_flows/contents.py +1 -1
- google/adk/models/lite_llm.py +51 -34
- google/adk/planners/plan_re_act_planner.py +2 -2
- google/adk/runners.py +1 -1
- google/adk/sessions/database_session_service.py +84 -23
- google/adk/sessions/state.py +1 -1
- google/adk/telemetry.py +2 -2
- google/adk/tools/application_integration_tool/clients/integration_client.py +3 -2
- google/adk/tools/base_tool.py +1 -1
- google/adk/tools/function_parameter_parse_util.py +2 -2
- google/adk/tools/google_api_tool/__init__.py +74 -1
- google/adk/tools/google_api_tool/google_api_tool_sets.py +91 -34
- google/adk/tools/google_api_tool/googleapi_to_openapi_converter.py +3 -1
- google/adk/tools/load_memory_tool.py +25 -2
- google/adk/tools/mcp_tool/mcp_session_manager.py +176 -0
- google/adk/tools/mcp_tool/mcp_tool.py +15 -2
- google/adk/tools/mcp_tool/mcp_toolset.py +31 -37
- google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py +1 -1
- google/adk/tools/toolbox_tool.py +1 -1
- google/adk/version.py +1 -1
- google_adk-0.2.0.dist-info/METADATA +212 -0
- {google_adk-0.1.0.dist-info → google_adk-0.2.0.dist-info}/RECORD +44 -42
- google_adk-0.1.0.dist-info/METADATA +0 -160
- {google_adk-0.1.0.dist-info → google_adk-0.2.0.dist-info}/WHEEL +0 -0
- {google_adk-0.1.0.dist-info → google_adk-0.2.0.dist-info}/entry_points.txt +0 -0
- {google_adk-0.1.0.dist-info → google_adk-0.2.0.dist-info}/licenses/LICENSE +0 -0
@@ -19,37 +19,94 @@ from .google_api_tool_set import GoogleApiToolSet
|
|
19
19
|
|
20
20
|
logger = logging.getLogger(__name__)
|
21
21
|
|
22
|
-
|
23
|
-
|
24
|
-
|
25
|
-
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
)
|
41
|
-
|
42
|
-
|
43
|
-
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
|
53
|
-
|
54
|
-
|
55
|
-
|
22
|
+
_bigquery_tool_set = None
|
23
|
+
_calendar_tool_set = None
|
24
|
+
_gmail_tool_set = None
|
25
|
+
_youtube_tool_set = None
|
26
|
+
_slides_tool_set = None
|
27
|
+
_sheets_tool_set = None
|
28
|
+
_docs_tool_set = None
|
29
|
+
|
30
|
+
|
31
|
+
def __getattr__(name):
|
32
|
+
"""This method dynamically loads and returns GoogleApiToolSet instances for
|
33
|
+
|
34
|
+
various Google APIs. It uses a lazy loading approach, initializing each
|
35
|
+
tool set only when it is first requested. This avoids unnecessary loading
|
36
|
+
of tool sets that are not used in a given session.
|
37
|
+
|
38
|
+
Args:
|
39
|
+
name (str): The name of the tool set to retrieve (e.g.,
|
40
|
+
"bigquery_tool_set").
|
41
|
+
|
42
|
+
Returns:
|
43
|
+
GoogleApiToolSet: The requested tool set instance.
|
44
|
+
|
45
|
+
Raises:
|
46
|
+
AttributeError: If the requested tool set name is not recognized.
|
47
|
+
"""
|
48
|
+
global _bigquery_tool_set, _calendar_tool_set, _gmail_tool_set, _youtube_tool_set, _slides_tool_set, _sheets_tool_set, _docs_tool_set
|
49
|
+
|
50
|
+
match name:
|
51
|
+
case "bigquery_tool_set":
|
52
|
+
if _bigquery_tool_set is None:
|
53
|
+
_bigquery_tool_set = GoogleApiToolSet.load_tool_set(
|
54
|
+
api_name="bigquery",
|
55
|
+
api_version="v2",
|
56
|
+
)
|
57
|
+
|
58
|
+
return _bigquery_tool_set
|
59
|
+
|
60
|
+
case "calendar_tool_set":
|
61
|
+
if _calendar_tool_set is None:
|
62
|
+
_calendar_tool_set = GoogleApiToolSet.load_tool_set(
|
63
|
+
api_name="calendar",
|
64
|
+
api_version="v3",
|
65
|
+
)
|
66
|
+
|
67
|
+
return _calendar_tool_set
|
68
|
+
|
69
|
+
case "gmail_tool_set":
|
70
|
+
if _gmail_tool_set is None:
|
71
|
+
_gmail_tool_set = GoogleApiToolSet.load_tool_set(
|
72
|
+
api_name="gmail",
|
73
|
+
api_version="v1",
|
74
|
+
)
|
75
|
+
|
76
|
+
return _gmail_tool_set
|
77
|
+
|
78
|
+
case "youtube_tool_set":
|
79
|
+
if _youtube_tool_set is None:
|
80
|
+
_youtube_tool_set = GoogleApiToolSet.load_tool_set(
|
81
|
+
api_name="youtube",
|
82
|
+
api_version="v3",
|
83
|
+
)
|
84
|
+
|
85
|
+
return _youtube_tool_set
|
86
|
+
|
87
|
+
case "slides_tool_set":
|
88
|
+
if _slides_tool_set is None:
|
89
|
+
_slides_tool_set = GoogleApiToolSet.load_tool_set(
|
90
|
+
api_name="slides",
|
91
|
+
api_version="v1",
|
92
|
+
)
|
93
|
+
|
94
|
+
return _slides_tool_set
|
95
|
+
|
96
|
+
case "sheets_tool_set":
|
97
|
+
if _sheets_tool_set is None:
|
98
|
+
_sheets_tool_set = GoogleApiToolSet.load_tool_set(
|
99
|
+
api_name="sheets",
|
100
|
+
api_version="v4",
|
101
|
+
)
|
102
|
+
|
103
|
+
return _sheets_tool_set
|
104
|
+
|
105
|
+
case "docs_tool_set":
|
106
|
+
if _docs_tool_set is None:
|
107
|
+
_docs_tool_set = GoogleApiToolSet.load_tool_set(
|
108
|
+
api_name="docs",
|
109
|
+
api_version="v1",
|
110
|
+
)
|
111
|
+
|
112
|
+
return _docs_tool_set
|
@@ -311,7 +311,9 @@ class GoogleApiToOpenApiConverter:
|
|
311
311
|
|
312
312
|
# Determine the actual endpoint path
|
313
313
|
# Google often has the format something like 'users.messages.list'
|
314
|
-
|
314
|
+
# flatPath is preferred as it provides the actual path, while path
|
315
|
+
# might contain variables like {+projectId}
|
316
|
+
rest_path = method_data.get("flatPath", method_data.get("path", "/"))
|
315
317
|
if not rest_path.startswith("/"):
|
316
318
|
rest_path = "/" + rest_path
|
317
319
|
|
@@ -16,18 +16,26 @@ from __future__ import annotations
|
|
16
16
|
|
17
17
|
from typing import TYPE_CHECKING
|
18
18
|
|
19
|
+
from google.genai import types
|
19
20
|
from typing_extensions import override
|
20
21
|
|
21
22
|
from .function_tool import FunctionTool
|
22
23
|
from .tool_context import ToolContext
|
23
24
|
|
24
25
|
if TYPE_CHECKING:
|
25
|
-
from ..models import LlmRequest
|
26
26
|
from ..memory.base_memory_service import MemoryResult
|
27
|
+
from ..models import LlmRequest
|
27
28
|
|
28
29
|
|
29
30
|
def load_memory(query: str, tool_context: ToolContext) -> 'list[MemoryResult]':
|
30
|
-
"""Loads the memory for the current user.
|
31
|
+
"""Loads the memory for the current user.
|
32
|
+
|
33
|
+
Args:
|
34
|
+
query: The query to load the memory for.
|
35
|
+
|
36
|
+
Returns:
|
37
|
+
A list of memory results.
|
38
|
+
"""
|
31
39
|
response = tool_context.search_memory(query)
|
32
40
|
return response.memories
|
33
41
|
|
@@ -38,6 +46,21 @@ class LoadMemoryTool(FunctionTool):
|
|
38
46
|
def __init__(self):
|
39
47
|
super().__init__(load_memory)
|
40
48
|
|
49
|
+
@override
|
50
|
+
def _get_declaration(self) -> types.FunctionDeclaration | None:
|
51
|
+
return types.FunctionDeclaration(
|
52
|
+
name=self.name,
|
53
|
+
description=self.description,
|
54
|
+
parameters=types.Schema(
|
55
|
+
type=types.Type.OBJECT,
|
56
|
+
properties={
|
57
|
+
'query': types.Schema(
|
58
|
+
type=types.Type.STRING,
|
59
|
+
)
|
60
|
+
},
|
61
|
+
),
|
62
|
+
)
|
63
|
+
|
41
64
|
@override
|
42
65
|
async def process_llm_request(
|
43
66
|
self,
|
@@ -0,0 +1,176 @@
|
|
1
|
+
from contextlib import AsyncExitStack
|
2
|
+
import functools
|
3
|
+
import sys
|
4
|
+
from typing import Any, TextIO
|
5
|
+
import anyio
|
6
|
+
from pydantic import BaseModel
|
7
|
+
|
8
|
+
try:
|
9
|
+
from mcp import ClientSession, StdioServerParameters
|
10
|
+
from mcp.client.sse import sse_client
|
11
|
+
from mcp.client.stdio import stdio_client
|
12
|
+
except ImportError as e:
|
13
|
+
import sys
|
14
|
+
|
15
|
+
if sys.version_info < (3, 10):
|
16
|
+
raise ImportError(
|
17
|
+
'MCP Tool requires Python 3.10 or above. Please upgrade your Python'
|
18
|
+
' version.'
|
19
|
+
) from e
|
20
|
+
else:
|
21
|
+
raise e
|
22
|
+
|
23
|
+
|
24
|
+
class SseServerParams(BaseModel):
|
25
|
+
"""Parameters for the MCP SSE connection.
|
26
|
+
|
27
|
+
See MCP SSE Client documentation for more details.
|
28
|
+
https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/client/sse.py
|
29
|
+
"""
|
30
|
+
|
31
|
+
url: str
|
32
|
+
headers: dict[str, Any] | None = None
|
33
|
+
timeout: float = 5
|
34
|
+
sse_read_timeout: float = 60 * 5
|
35
|
+
|
36
|
+
|
37
|
+
def retry_on_closed_resource(async_reinit_func_name: str):
|
38
|
+
"""Decorator to automatically reinitialize session and retry action.
|
39
|
+
|
40
|
+
When MCP session was closed, the decorator will automatically recreate the
|
41
|
+
session and retry the action with the same parameters.
|
42
|
+
|
43
|
+
Note:
|
44
|
+
1. async_reinit_func_name is the name of the class member function that
|
45
|
+
reinitializes the MCP session.
|
46
|
+
2. Both the decorated function and the async_reinit_func_name must be async
|
47
|
+
functions.
|
48
|
+
|
49
|
+
Usage:
|
50
|
+
class MCPTool:
|
51
|
+
...
|
52
|
+
async def create_session(self):
|
53
|
+
self.session = ...
|
54
|
+
|
55
|
+
@retry_on_closed_resource('create_session')
|
56
|
+
async def use_session(self):
|
57
|
+
await self.session.call_tool()
|
58
|
+
|
59
|
+
Args:
|
60
|
+
async_reinit_func_name: The name of the async function to recreate session.
|
61
|
+
|
62
|
+
Returns:
|
63
|
+
The decorated function.
|
64
|
+
"""
|
65
|
+
|
66
|
+
def decorator(func):
|
67
|
+
@functools.wraps(
|
68
|
+
func
|
69
|
+
) # Preserves original function metadata (name, docstring)
|
70
|
+
async def wrapper(self, *args, **kwargs):
|
71
|
+
try:
|
72
|
+
return await func(self, *args, **kwargs)
|
73
|
+
except anyio.ClosedResourceError:
|
74
|
+
try:
|
75
|
+
if hasattr(self, async_reinit_func_name) and callable(
|
76
|
+
getattr(self, async_reinit_func_name)
|
77
|
+
):
|
78
|
+
async_init_fn = getattr(self, async_reinit_func_name)
|
79
|
+
await async_init_fn()
|
80
|
+
else:
|
81
|
+
raise ValueError(
|
82
|
+
f'Function {async_reinit_func_name} does not exist in decorated'
|
83
|
+
' class. Please check the function name in'
|
84
|
+
' retry_on_closed_resource decorator.'
|
85
|
+
)
|
86
|
+
except Exception as reinit_err:
|
87
|
+
raise RuntimeError(
|
88
|
+
f'Error reinitializing: {reinit_err}'
|
89
|
+
) from reinit_err
|
90
|
+
return await func(self, *args, **kwargs)
|
91
|
+
|
92
|
+
return wrapper
|
93
|
+
|
94
|
+
return decorator
|
95
|
+
|
96
|
+
|
97
|
+
class MCPSessionManager:
|
98
|
+
"""Manages MCP client sessions.
|
99
|
+
|
100
|
+
This class provides methods for creating and initializing MCP client sessions,
|
101
|
+
handling different connection parameters (Stdio and SSE).
|
102
|
+
"""
|
103
|
+
|
104
|
+
def __init__(
|
105
|
+
self,
|
106
|
+
connection_params: StdioServerParameters | SseServerParams,
|
107
|
+
exit_stack: AsyncExitStack,
|
108
|
+
errlog: TextIO = sys.stderr,
|
109
|
+
) -> ClientSession:
|
110
|
+
"""Initializes the MCP session manager.
|
111
|
+
|
112
|
+
Example usage:
|
113
|
+
```
|
114
|
+
mcp_session_manager = MCPSessionManager(
|
115
|
+
connection_params=connection_params,
|
116
|
+
exit_stack=exit_stack,
|
117
|
+
)
|
118
|
+
session = await mcp_session_manager.create_session()
|
119
|
+
```
|
120
|
+
|
121
|
+
Args:
|
122
|
+
connection_params: Parameters for the MCP connection (Stdio or SSE).
|
123
|
+
exit_stack: AsyncExitStack to manage the session lifecycle.
|
124
|
+
errlog: (Optional) TextIO stream for error logging. Use only for
|
125
|
+
initializing a local stdio MCP session.
|
126
|
+
"""
|
127
|
+
self.connection_params = connection_params
|
128
|
+
self.exit_stack = exit_stack
|
129
|
+
self.errlog = errlog
|
130
|
+
|
131
|
+
async def create_session(self) -> ClientSession:
|
132
|
+
return await MCPSessionManager.initialize_session(
|
133
|
+
connection_params=self.connection_params,
|
134
|
+
exit_stack=self.exit_stack,
|
135
|
+
errlog=self.errlog,
|
136
|
+
)
|
137
|
+
|
138
|
+
@classmethod
|
139
|
+
async def initialize_session(
|
140
|
+
cls,
|
141
|
+
*,
|
142
|
+
connection_params: StdioServerParameters | SseServerParams,
|
143
|
+
exit_stack: AsyncExitStack,
|
144
|
+
errlog: TextIO = sys.stderr,
|
145
|
+
) -> ClientSession:
|
146
|
+
"""Initializes an MCP client session.
|
147
|
+
|
148
|
+
Args:
|
149
|
+
connection_params: Parameters for the MCP connection (Stdio or SSE).
|
150
|
+
exit_stack: AsyncExitStack to manage the session lifecycle.
|
151
|
+
errlog: (Optional) TextIO stream for error logging. Use only for
|
152
|
+
initializing a local stdio MCP session.
|
153
|
+
|
154
|
+
Returns:
|
155
|
+
ClientSession: The initialized MCP client session.
|
156
|
+
"""
|
157
|
+
if isinstance(connection_params, StdioServerParameters):
|
158
|
+
client = stdio_client(server=connection_params, errlog=errlog)
|
159
|
+
elif isinstance(connection_params, SseServerParams):
|
160
|
+
client = sse_client(
|
161
|
+
url=connection_params.url,
|
162
|
+
headers=connection_params.headers,
|
163
|
+
timeout=connection_params.timeout,
|
164
|
+
sse_read_timeout=connection_params.sse_read_timeout,
|
165
|
+
)
|
166
|
+
else:
|
167
|
+
raise ValueError(
|
168
|
+
'Unable to initialize connection. Connection should be'
|
169
|
+
' StdioServerParameters or SseServerParams, but got'
|
170
|
+
f' {connection_params}'
|
171
|
+
)
|
172
|
+
|
173
|
+
transports = await exit_stack.enter_async_context(client)
|
174
|
+
session = await exit_stack.enter_async_context(ClientSession(*transports))
|
175
|
+
await session.initialize()
|
176
|
+
return session
|
@@ -17,6 +17,8 @@ from typing import Optional
|
|
17
17
|
from google.genai.types import FunctionDeclaration
|
18
18
|
from typing_extensions import override
|
19
19
|
|
20
|
+
from .mcp_session_manager import MCPSessionManager, retry_on_closed_resource
|
21
|
+
|
20
22
|
# Attempt to import MCP Tool from the MCP library, and hints user to upgrade
|
21
23
|
# their Python version to 3.10 if it fails.
|
22
24
|
try:
|
@@ -33,6 +35,7 @@ except ImportError as e:
|
|
33
35
|
else:
|
34
36
|
raise e
|
35
37
|
|
38
|
+
|
36
39
|
from ..base_tool import BaseTool
|
37
40
|
from ...auth.auth_credential import AuthCredential
|
38
41
|
from ...auth.auth_schemes import AuthScheme
|
@@ -51,6 +54,7 @@ class MCPTool(BaseTool):
|
|
51
54
|
self,
|
52
55
|
mcp_tool: McpBaseTool,
|
53
56
|
mcp_session: ClientSession,
|
57
|
+
mcp_session_manager: MCPSessionManager,
|
54
58
|
auth_scheme: Optional[AuthScheme] = None,
|
55
59
|
auth_credential: Optional[AuthCredential] | None = None,
|
56
60
|
):
|
@@ -79,10 +83,14 @@ class MCPTool(BaseTool):
|
|
79
83
|
self.description = mcp_tool.description if mcp_tool.description else ""
|
80
84
|
self.mcp_tool = mcp_tool
|
81
85
|
self.mcp_session = mcp_session
|
86
|
+
self.mcp_session_manager = mcp_session_manager
|
82
87
|
# TODO(cheliu): Support passing auth to MCP Server.
|
83
88
|
self.auth_scheme = auth_scheme
|
84
89
|
self.auth_credential = auth_credential
|
85
90
|
|
91
|
+
async def _reinitialize_session(self):
|
92
|
+
self.mcp_session = await self.mcp_session_manager.create_session()
|
93
|
+
|
86
94
|
@override
|
87
95
|
def _get_declaration(self) -> FunctionDeclaration:
|
88
96
|
"""Gets the function declaration for the tool.
|
@@ -98,6 +106,7 @@ class MCPTool(BaseTool):
|
|
98
106
|
return function_decl
|
99
107
|
|
100
108
|
@override
|
109
|
+
@retry_on_closed_resource("_reinitialize_session")
|
101
110
|
async def run_async(self, *, args, tool_context: ToolContext):
|
102
111
|
"""Runs the tool asynchronously.
|
103
112
|
|
@@ -109,5 +118,9 @@ class MCPTool(BaseTool):
|
|
109
118
|
Any: The response from the tool.
|
110
119
|
"""
|
111
120
|
# TODO(cheliu): Support passing tool context to MCP Server.
|
112
|
-
|
113
|
-
|
121
|
+
try:
|
122
|
+
response = await self.mcp_session.call_tool(self.name, arguments=args)
|
123
|
+
return response
|
124
|
+
except Exception as e:
|
125
|
+
print(e)
|
126
|
+
raise e
|
@@ -13,15 +13,16 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
15
|
from contextlib import AsyncExitStack
|
16
|
+
import sys
|
16
17
|
from types import TracebackType
|
17
|
-
from typing import
|
18
|
+
from typing import List, Optional, TextIO, Tuple, Type
|
19
|
+
|
20
|
+
from .mcp_session_manager import MCPSessionManager, SseServerParams, retry_on_closed_resource
|
18
21
|
|
19
22
|
# Attempt to import MCP Tool from the MCP library, and hints user to upgrade
|
20
23
|
# their Python version to 3.10 if it fails.
|
21
24
|
try:
|
22
25
|
from mcp import ClientSession, StdioServerParameters
|
23
|
-
from mcp.client.sse import sse_client
|
24
|
-
from mcp.client.stdio import stdio_client
|
25
26
|
from mcp.types import ListToolsResult
|
26
27
|
except ImportError as e:
|
27
28
|
import sys
|
@@ -34,18 +35,9 @@ except ImportError as e:
|
|
34
35
|
else:
|
35
36
|
raise e
|
36
37
|
|
37
|
-
from pydantic import BaseModel
|
38
|
-
|
39
38
|
from .mcp_tool import MCPTool
|
40
39
|
|
41
40
|
|
42
|
-
class SseServerParams(BaseModel):
|
43
|
-
url: str
|
44
|
-
headers: dict[str, Any] | None = None
|
45
|
-
timeout: float = 5
|
46
|
-
sse_read_timeout: float = 60 * 5
|
47
|
-
|
48
|
-
|
49
41
|
class MCPToolset:
|
50
42
|
"""Connects to a MCP Server, and retrieves MCP Tools into ADK Tools.
|
51
43
|
|
@@ -110,7 +102,11 @@ class MCPToolset:
|
|
110
102
|
"""
|
111
103
|
|
112
104
|
def __init__(
|
113
|
-
self,
|
105
|
+
self,
|
106
|
+
*,
|
107
|
+
connection_params: StdioServerParameters | SseServerParams,
|
108
|
+
errlog: TextIO = sys.stderr,
|
109
|
+
exit_stack=AsyncExitStack(),
|
114
110
|
):
|
115
111
|
"""Initializes the MCPToolset.
|
116
112
|
|
@@ -175,7 +171,14 @@ class MCPToolset:
|
|
175
171
|
if not connection_params:
|
176
172
|
raise ValueError('Missing connection params in MCPToolset.')
|
177
173
|
self.connection_params = connection_params
|
178
|
-
self.
|
174
|
+
self.errlog = errlog
|
175
|
+
self.exit_stack = exit_stack
|
176
|
+
|
177
|
+
self.session_manager = MCPSessionManager(
|
178
|
+
connection_params=self.connection_params,
|
179
|
+
exit_stack=self.exit_stack,
|
180
|
+
errlog=self.errlog,
|
181
|
+
)
|
179
182
|
|
180
183
|
@classmethod
|
181
184
|
async def from_server(
|
@@ -183,6 +186,7 @@ class MCPToolset:
|
|
183
186
|
*,
|
184
187
|
connection_params: StdioServerParameters | SseServerParams,
|
185
188
|
async_exit_stack: Optional[AsyncExitStack] = None,
|
189
|
+
errlog: TextIO = sys.stderr,
|
186
190
|
) -> Tuple[List[MCPTool], AsyncExitStack]:
|
187
191
|
"""Retrieve all tools from the MCP connection.
|
188
192
|
|
@@ -209,41 +213,27 @@ class MCPToolset:
|
|
209
213
|
the MCP server. Use `await async_exit_stack.aclose()` to close the
|
210
214
|
connection when server shuts down.
|
211
215
|
"""
|
212
|
-
toolset = cls(connection_params=connection_params)
|
213
216
|
async_exit_stack = async_exit_stack or AsyncExitStack()
|
217
|
+
toolset = cls(
|
218
|
+
connection_params=connection_params,
|
219
|
+
exit_stack=async_exit_stack,
|
220
|
+
errlog=errlog,
|
221
|
+
)
|
222
|
+
|
214
223
|
await async_exit_stack.enter_async_context(toolset)
|
215
224
|
tools = await toolset.load_tools()
|
216
225
|
return (tools, async_exit_stack)
|
217
226
|
|
218
227
|
async def _initialize(self) -> ClientSession:
|
219
228
|
"""Connects to the MCP Server and initializes the ClientSession."""
|
220
|
-
|
221
|
-
client = stdio_client(self.connection_params)
|
222
|
-
elif isinstance(self.connection_params, SseServerParams):
|
223
|
-
client = sse_client(
|
224
|
-
url=self.connection_params.url,
|
225
|
-
headers=self.connection_params.headers,
|
226
|
-
timeout=self.connection_params.timeout,
|
227
|
-
sse_read_timeout=self.connection_params.sse_read_timeout,
|
228
|
-
)
|
229
|
-
else:
|
230
|
-
raise ValueError(
|
231
|
-
'Unable to initialize connection. Connection should be'
|
232
|
-
' StdioServerParameters or SseServerParams, but got'
|
233
|
-
f' {self.connection_params}'
|
234
|
-
)
|
235
|
-
|
236
|
-
transports = await self.exit_stack.enter_async_context(client)
|
237
|
-
self.session = await self.exit_stack.enter_async_context(
|
238
|
-
ClientSession(*transports)
|
239
|
-
)
|
240
|
-
await self.session.initialize()
|
229
|
+
self.session = await self.session_manager.create_session()
|
241
230
|
return self.session
|
242
231
|
|
243
232
|
async def _exit(self):
|
244
233
|
"""Closes the connection to MCP Server."""
|
245
234
|
await self.exit_stack.aclose()
|
246
235
|
|
236
|
+
@retry_on_closed_resource('_initialize')
|
247
237
|
async def load_tools(self) -> List[MCPTool]:
|
248
238
|
"""Loads all tools from the MCP Server.
|
249
239
|
|
@@ -252,7 +242,11 @@ class MCPToolset:
|
|
252
242
|
"""
|
253
243
|
tools_response: ListToolsResult = await self.session.list_tools()
|
254
244
|
return [
|
255
|
-
MCPTool(
|
245
|
+
MCPTool(
|
246
|
+
mcp_tool=tool,
|
247
|
+
mcp_session=self.session,
|
248
|
+
mcp_session_manager=self.session_manager,
|
249
|
+
)
|
256
250
|
for tool in tools_response.tools
|
257
251
|
]
|
258
252
|
|
@@ -28,7 +28,7 @@ from typing_extensions import override
|
|
28
28
|
|
29
29
|
from ....auth.auth_credential import AuthCredential
|
30
30
|
from ....auth.auth_schemes import AuthScheme
|
31
|
-
from ....tools import BaseTool
|
31
|
+
from ....tools.base_tool import BaseTool
|
32
32
|
from ...tool_context import ToolContext
|
33
33
|
from ..auth.auth_helpers import credential_to_param
|
34
34
|
from ..auth.auth_helpers import dict_to_auth_scheme
|
google/adk/tools/toolbox_tool.py
CHANGED
google/adk/version.py
CHANGED