google-adk 0.5.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/adk/agents/base_agent.py +76 -30
- google/adk/agents/base_agent.py.orig +330 -0
- google/adk/agents/callback_context.py +0 -5
- google/adk/agents/llm_agent.py +122 -30
- google/adk/agents/loop_agent.py +1 -1
- google/adk/agents/parallel_agent.py +7 -0
- google/adk/agents/readonly_context.py +7 -1
- google/adk/agents/run_config.py +1 -1
- google/adk/agents/sequential_agent.py +31 -0
- google/adk/agents/transcription_entry.py +4 -2
- google/adk/artifacts/gcs_artifact_service.py +1 -1
- google/adk/artifacts/in_memory_artifact_service.py +1 -1
- google/adk/auth/auth_credential.py +6 -1
- google/adk/auth/auth_preprocessor.py +7 -1
- google/adk/auth/auth_tool.py +3 -4
- google/adk/cli/agent_graph.py +5 -5
- google/adk/cli/browser/index.html +2 -2
- google/adk/cli/browser/{main-ULN5R5I5.js → main-QOEMUXM4.js} +44 -45
- google/adk/cli/cli.py +7 -7
- google/adk/cli/cli_deploy.py +7 -2
- google/adk/cli/cli_eval.py +172 -99
- google/adk/cli/cli_tools_click.py +147 -64
- google/adk/cli/fast_api.py +330 -148
- google/adk/cli/fast_api.py.orig +174 -80
- google/adk/cli/utils/common.py +23 -0
- google/adk/cli/utils/evals.py +83 -1
- google/adk/cli/utils/logs.py +13 -5
- google/adk/code_executors/__init__.py +3 -1
- google/adk/code_executors/built_in_code_executor.py +52 -0
- google/adk/evaluation/__init__.py +1 -1
- google/adk/evaluation/agent_evaluator.py +168 -128
- google/adk/evaluation/eval_case.py +102 -0
- google/adk/evaluation/eval_set.py +37 -0
- google/adk/evaluation/eval_sets_manager.py +42 -0
- google/adk/evaluation/evaluation_generator.py +88 -113
- google/adk/evaluation/evaluator.py +56 -0
- google/adk/evaluation/local_eval_sets_manager.py +264 -0
- google/adk/evaluation/response_evaluator.py +106 -2
- google/adk/evaluation/trajectory_evaluator.py +83 -2
- google/adk/events/event.py +6 -1
- google/adk/events/event_actions.py +6 -1
- google/adk/examples/example_util.py +3 -2
- google/adk/flows/llm_flows/_code_execution.py +9 -1
- google/adk/flows/llm_flows/audio_transcriber.py +4 -3
- google/adk/flows/llm_flows/base_llm_flow.py +54 -15
- google/adk/flows/llm_flows/functions.py +9 -8
- google/adk/flows/llm_flows/instructions.py +13 -5
- google/adk/flows/llm_flows/single_flow.py +1 -1
- google/adk/memory/__init__.py +1 -1
- google/adk/memory/_utils.py +23 -0
- google/adk/memory/base_memory_service.py +23 -21
- google/adk/memory/base_memory_service.py.orig +76 -0
- google/adk/memory/in_memory_memory_service.py +57 -25
- google/adk/memory/memory_entry.py +37 -0
- google/adk/memory/vertex_ai_rag_memory_service.py +38 -15
- google/adk/models/anthropic_llm.py +16 -9
- google/adk/models/gemini_llm_connection.py +11 -11
- google/adk/models/google_llm.py +9 -2
- google/adk/models/google_llm.py.orig +305 -0
- google/adk/models/lite_llm.py +77 -21
- google/adk/models/llm_response.py +14 -2
- google/adk/models/registry.py +1 -1
- google/adk/runners.py +65 -41
- google/adk/sessions/__init__.py +1 -1
- google/adk/sessions/base_session_service.py +6 -33
- google/adk/sessions/database_session_service.py +58 -65
- google/adk/sessions/in_memory_session_service.py +106 -24
- google/adk/sessions/session.py +3 -0
- google/adk/sessions/vertex_ai_session_service.py +23 -45
- google/adk/telemetry.py +3 -0
- google/adk/tools/__init__.py +4 -7
- google/adk/tools/{built_in_code_execution_tool.py → _built_in_code_execution_tool.py} +11 -0
- google/adk/tools/_memory_entry_utils.py +30 -0
- google/adk/tools/agent_tool.py +9 -9
- google/adk/tools/apihub_tool/apihub_toolset.py +55 -74
- google/adk/tools/application_integration_tool/application_integration_toolset.py +107 -85
- google/adk/tools/application_integration_tool/clients/connections_client.py +20 -0
- google/adk/tools/application_integration_tool/clients/integration_client.py +6 -6
- google/adk/tools/application_integration_tool/integration_connector_tool.py +69 -26
- google/adk/tools/base_toolset.py +58 -0
- google/adk/tools/enterprise_search_tool.py +65 -0
- google/adk/tools/function_parameter_parse_util.py +2 -2
- google/adk/tools/google_api_tool/__init__.py +18 -70
- google/adk/tools/google_api_tool/google_api_tool.py +11 -5
- google/adk/tools/google_api_tool/google_api_toolset.py +126 -0
- google/adk/tools/google_api_tool/google_api_toolsets.py +102 -0
- google/adk/tools/google_api_tool/googleapi_to_openapi_converter.py +40 -42
- google/adk/tools/langchain_tool.py +96 -49
- google/adk/tools/load_memory_tool.py +14 -5
- google/adk/tools/mcp_tool/__init__.py +3 -2
- google/adk/tools/mcp_tool/mcp_session_manager.py +153 -16
- google/adk/tools/mcp_tool/mcp_session_manager.py.orig +322 -0
- google/adk/tools/mcp_tool/mcp_tool.py +12 -12
- google/adk/tools/mcp_tool/mcp_toolset.py +155 -195
- google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py +32 -7
- google/adk/tools/openapi_tool/openapi_spec_parser/operation_parser.py +31 -31
- google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py +1 -1
- google/adk/tools/preload_memory_tool.py +27 -18
- google/adk/tools/retrieval/__init__.py +1 -1
- google/adk/tools/retrieval/vertex_ai_rag_retrieval.py +1 -1
- google/adk/tools/toolbox_toolset.py +79 -0
- google/adk/tools/transfer_to_agent_tool.py +0 -1
- google/adk/version.py +1 -1
- {google_adk-0.5.0.dist-info → google_adk-1.0.0.dist-info}/METADATA +7 -5
- google_adk-1.0.0.dist-info/RECORD +195 -0
- google/adk/agents/remote_agent.py +0 -50
- google/adk/tools/google_api_tool/google_api_tool_set.py +0 -110
- google/adk/tools/google_api_tool/google_api_tool_sets.py +0 -112
- google/adk/tools/toolbox_tool.py +0 -46
- google_adk-0.5.0.dist-info/RECORD +0 -180
- {google_adk-0.5.0.dist-info → google_adk-1.0.0.dist-info}/WHEEL +0 -0
- {google_adk-0.5.0.dist-info → google_adk-1.0.0.dist-info}/entry_points.txt +0 -0
- {google_adk-0.5.0.dist-info → google_adk-1.0.0.dist-info}/licenses/LICENSE +0 -0
@@ -18,19 +18,13 @@ import logging
|
|
18
18
|
from typing import Any
|
19
19
|
from typing import Dict
|
20
20
|
from typing import List
|
21
|
-
from typing import Optional
|
22
|
-
from typing import Union
|
23
21
|
|
24
22
|
# Google API client
|
25
23
|
from googleapiclient.discovery import build
|
26
|
-
from googleapiclient.discovery import Resource
|
27
24
|
from googleapiclient.errors import HttpError
|
28
25
|
|
29
26
|
# Configure logging
|
30
|
-
logging.
|
31
|
-
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
32
|
-
)
|
33
|
-
logger = logging.getLogger(__name__)
|
27
|
+
logger = logging.getLogger("google_adk." + __name__)
|
34
28
|
|
35
29
|
|
36
30
|
class GoogleApiToOpenApiConverter:
|
@@ -43,11 +37,11 @@ class GoogleApiToOpenApiConverter:
|
|
43
37
|
api_name: The name of the Google API (e.g., "calendar")
|
44
38
|
api_version: The version of the API (e.g., "v3")
|
45
39
|
"""
|
46
|
-
self.
|
47
|
-
self.
|
48
|
-
self.
|
49
|
-
self.
|
50
|
-
self.
|
40
|
+
self._api_name = api_name
|
41
|
+
self._api_version = api_version
|
42
|
+
self._google_api_resource = None
|
43
|
+
self._google_api_spec = None
|
44
|
+
self._openapi_spec = {
|
51
45
|
"openapi": "3.0.0",
|
52
46
|
"info": {},
|
53
47
|
"servers": [],
|
@@ -59,18 +53,20 @@ class GoogleApiToOpenApiConverter:
|
|
59
53
|
"""Fetches the Google API specification using discovery service."""
|
60
54
|
try:
|
61
55
|
logger.info(
|
62
|
-
"Fetching Google API spec for %s %s",
|
56
|
+
"Fetching Google API spec for %s %s",
|
57
|
+
self._api_name,
|
58
|
+
self._api_version,
|
63
59
|
)
|
64
60
|
# Build a resource object for the specified API
|
65
|
-
self.
|
61
|
+
self._google_api_resource = build(self._api_name, self._api_version)
|
66
62
|
|
67
63
|
# Access the underlying API discovery document
|
68
|
-
self.
|
64
|
+
self._google_api_spec = self._google_api_resource._rootDesc
|
69
65
|
|
70
|
-
if not self.
|
66
|
+
if not self._google_api_spec:
|
71
67
|
raise ValueError("Failed to retrieve API specification")
|
72
68
|
|
73
|
-
logger.info("Successfully fetched %s API specification", self.
|
69
|
+
logger.info("Successfully fetched %s API specification", self._api_name)
|
74
70
|
except HttpError as e:
|
75
71
|
logger.error("HTTP Error: %s", e)
|
76
72
|
raise
|
@@ -84,7 +80,7 @@ class GoogleApiToOpenApiConverter:
|
|
84
80
|
Returns:
|
85
81
|
Dict containing the converted OpenAPI v3 specification
|
86
82
|
"""
|
87
|
-
if not self.
|
83
|
+
if not self._google_api_spec:
|
88
84
|
self.fetch_google_api_spec()
|
89
85
|
|
90
86
|
# Convert basic API information
|
@@ -100,49 +96,49 @@ class GoogleApiToOpenApiConverter:
|
|
100
96
|
self._convert_schemas()
|
101
97
|
|
102
98
|
# Convert endpoints/paths
|
103
|
-
self._convert_resources(self.
|
99
|
+
self._convert_resources(self._google_api_spec.get("resources", {}))
|
104
100
|
|
105
101
|
# Convert top-level methods, if any
|
106
|
-
self._convert_methods(self.
|
102
|
+
self._convert_methods(self._google_api_spec.get("methods", {}), "/")
|
107
103
|
|
108
|
-
return self.
|
104
|
+
return self._openapi_spec
|
109
105
|
|
110
106
|
def _convert_info(self) -> None:
|
111
107
|
"""Convert basic API information."""
|
112
|
-
self.
|
113
|
-
"title": self.
|
114
|
-
"description": self.
|
115
|
-
"version": self.
|
108
|
+
self._openapi_spec["info"] = {
|
109
|
+
"title": self._google_api_spec.get("title", f"{self._api_name} API"),
|
110
|
+
"description": self._google_api_spec.get("description", ""),
|
111
|
+
"version": self._google_api_spec.get("version", self._api_version),
|
116
112
|
"contact": {},
|
117
|
-
"termsOfService": self.
|
113
|
+
"termsOfService": self._google_api_spec.get("documentationLink", ""),
|
118
114
|
}
|
119
115
|
|
120
116
|
# Add documentation links if available
|
121
|
-
docs_link = self.
|
117
|
+
docs_link = self._google_api_spec.get("documentationLink")
|
122
118
|
if docs_link:
|
123
|
-
self.
|
119
|
+
self._openapi_spec["externalDocs"] = {
|
124
120
|
"description": "API Documentation",
|
125
121
|
"url": docs_link,
|
126
122
|
}
|
127
123
|
|
128
124
|
def _convert_servers(self) -> None:
|
129
125
|
"""Convert server information."""
|
130
|
-
base_url = self.
|
126
|
+
base_url = self._google_api_spec.get(
|
131
127
|
"rootUrl", ""
|
132
|
-
) + self.
|
128
|
+
) + self._google_api_spec.get("servicePath", "")
|
133
129
|
|
134
130
|
# Remove trailing slash if present
|
135
131
|
if base_url.endswith("/"):
|
136
132
|
base_url = base_url[:-1]
|
137
133
|
|
138
|
-
self.
|
134
|
+
self._openapi_spec["servers"] = [{
|
139
135
|
"url": base_url,
|
140
|
-
"description": f"{self.
|
136
|
+
"description": f"{self._api_name} {self._api_version} API",
|
141
137
|
}]
|
142
138
|
|
143
139
|
def _convert_security_schemes(self) -> None:
|
144
140
|
"""Convert authentication and authorization schemes."""
|
145
|
-
auth = self.
|
141
|
+
auth = self._google_api_spec.get("auth", {})
|
146
142
|
oauth2 = auth.get("oauth2", {})
|
147
143
|
|
148
144
|
if oauth2:
|
@@ -153,7 +149,7 @@ class GoogleApiToOpenApiConverter:
|
|
153
149
|
for scope, scope_info in scopes.items():
|
154
150
|
formatted_scopes[scope] = scope_info.get("description", "")
|
155
151
|
|
156
|
-
self.
|
152
|
+
self._openapi_spec["components"]["securitySchemes"]["oauth2"] = {
|
157
153
|
"type": "oauth2",
|
158
154
|
"description": "OAuth 2.0 authentication",
|
159
155
|
"flows": {
|
@@ -168,7 +164,7 @@ class GoogleApiToOpenApiConverter:
|
|
168
164
|
}
|
169
165
|
|
170
166
|
# Add API key authentication (most Google APIs support this)
|
171
|
-
self.
|
167
|
+
self._openapi_spec["components"]["securitySchemes"]["apiKey"] = {
|
172
168
|
"type": "apiKey",
|
173
169
|
"in": "query",
|
174
170
|
"name": "key",
|
@@ -176,18 +172,20 @@ class GoogleApiToOpenApiConverter:
|
|
176
172
|
}
|
177
173
|
|
178
174
|
# Create global security requirement
|
179
|
-
self.
|
175
|
+
self._openapi_spec["security"] = [
|
180
176
|
{"oauth2": list(formatted_scopes.keys())} if oauth2 else {},
|
181
177
|
{"apiKey": []},
|
182
178
|
]
|
183
179
|
|
184
180
|
def _convert_schemas(self) -> None:
|
185
181
|
"""Convert schema definitions (models)."""
|
186
|
-
schemas = self.
|
182
|
+
schemas = self._google_api_spec.get("schemas", {})
|
187
183
|
|
188
184
|
for schema_name, schema_def in schemas.items():
|
189
185
|
converted_schema = self._convert_schema_object(schema_def)
|
190
|
-
self.
|
186
|
+
self._openapi_spec["components"]["schemas"][
|
187
|
+
schema_name
|
188
|
+
] = converted_schema
|
191
189
|
|
192
190
|
def _convert_schema_object(
|
193
191
|
self, schema_def: Dict[str, Any]
|
@@ -320,11 +318,11 @@ class GoogleApiToOpenApiConverter:
|
|
320
318
|
path_params = self._extract_path_parameters(rest_path)
|
321
319
|
|
322
320
|
# Create path entry if it doesn't exist
|
323
|
-
if rest_path not in self.
|
324
|
-
self.
|
321
|
+
if rest_path not in self._openapi_spec["paths"]:
|
322
|
+
self._openapi_spec["paths"][rest_path] = {}
|
325
323
|
|
326
324
|
# Add the operation for this method
|
327
|
-
self.
|
325
|
+
self._openapi_spec["paths"][rest_path][http_method] = (
|
328
326
|
self._convert_operation(method_data, path_params)
|
329
327
|
)
|
330
328
|
|
@@ -478,7 +476,7 @@ class GoogleApiToOpenApiConverter:
|
|
478
476
|
output_path: Path where the OpenAPI spec should be saved
|
479
477
|
"""
|
480
478
|
with open(output_path, "w", encoding="utf-8") as f:
|
481
|
-
json.dump(self.
|
479
|
+
json.dump(self._openapi_spec, f, indent=2)
|
482
480
|
logger.info("OpenAPI specification saved to %s", output_path)
|
483
481
|
|
484
482
|
|
@@ -13,10 +13,12 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
15
|
from typing import Any
|
16
|
-
from typing import
|
16
|
+
from typing import Optional
|
17
|
+
from typing import Union
|
17
18
|
|
18
19
|
from google.genai import types
|
19
|
-
from
|
20
|
+
from langchain.agents import Tool
|
21
|
+
from langchain_core.tools import BaseTool
|
20
22
|
from typing_extensions import override
|
21
23
|
|
22
24
|
from . import _automatic_function_calling_util
|
@@ -24,63 +26,108 @@ from .function_tool import FunctionTool
|
|
24
26
|
|
25
27
|
|
26
28
|
class LangchainTool(FunctionTool):
|
27
|
-
"""
|
29
|
+
"""Adapter class that wraps a Langchain tool for use with ADK.
|
28
30
|
|
29
|
-
|
30
|
-
|
31
|
+
This adapter converts Langchain tools into a format compatible with Google's
|
32
|
+
generative AI function calling interface. It preserves the tool's name,
|
33
|
+
description, and functionality while adapting its schema.
|
34
|
+
|
35
|
+
The original tool's name and description can be overridden if needed.
|
36
|
+
|
37
|
+
Args:
|
38
|
+
tool: A Langchain tool to wrap (BaseTool or a tool with a .run method)
|
39
|
+
name: Optional override for the tool's name
|
40
|
+
description: Optional override for the tool's description
|
41
|
+
|
42
|
+
Examples:
|
43
|
+
```python
|
44
|
+
from langchain.tools import DuckDuckGoSearchTool
|
45
|
+
from google.genai.tools import LangchainTool
|
46
|
+
|
47
|
+
search_tool = DuckDuckGoSearchTool()
|
48
|
+
wrapped_tool = LangchainTool(search_tool)
|
49
|
+
```
|
31
50
|
"""
|
32
51
|
|
33
|
-
|
52
|
+
_langchain_tool: Union[BaseTool, object]
|
34
53
|
"""The wrapped langchain tool."""
|
35
54
|
|
36
|
-
def __init__(
|
37
|
-
|
38
|
-
|
39
|
-
|
55
|
+
def __init__(
|
56
|
+
self,
|
57
|
+
tool: Union[BaseTool, object],
|
58
|
+
name: Optional[str] = None,
|
59
|
+
description: Optional[str] = None,
|
60
|
+
):
|
61
|
+
# Check if the tool has a 'run' method
|
62
|
+
if not hasattr(tool, 'run') and not hasattr(tool, '_run'):
|
63
|
+
raise ValueError("Langchain tool must have a 'run' or '_run' method")
|
64
|
+
|
65
|
+
# Determine which function to use
|
66
|
+
func = tool._run if hasattr(tool, '_run') else tool.run
|
67
|
+
super().__init__(func)
|
68
|
+
|
69
|
+
self._langchain_tool = tool
|
70
|
+
|
71
|
+
# Set name: priority is 1) explicitly provided name, 2) tool's name, 3) default
|
72
|
+
if name is not None:
|
73
|
+
self.name = name
|
74
|
+
elif hasattr(tool, 'name') and tool.name:
|
40
75
|
self.name = tool.name
|
41
|
-
|
42
|
-
self.description = tool.description
|
76
|
+
# else: keep default from FunctionTool
|
43
77
|
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
|
48
|
-
|
49
|
-
|
78
|
+
# Set description: similar priority
|
79
|
+
if description is not None:
|
80
|
+
self.description = description
|
81
|
+
elif hasattr(tool, 'description') and tool.description:
|
82
|
+
self.description = tool.description
|
83
|
+
# else: keep default from FunctionTool
|
50
84
|
|
51
85
|
@override
|
52
86
|
def _get_declaration(self) -> types.FunctionDeclaration:
|
53
|
-
"""Build the function declaration for the tool.
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
-
|
63
|
-
|
64
|
-
|
65
|
-
|
66
|
-
)
|
67
|
-
if self.
|
68
|
-
tool_wrapper
|
69
|
-
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
87
|
+
"""Build the function declaration for the tool.
|
88
|
+
|
89
|
+
Returns:
|
90
|
+
A FunctionDeclaration object that describes the tool's interface.
|
91
|
+
|
92
|
+
Raises:
|
93
|
+
ValueError: If the tool schema cannot be correctly parsed.
|
94
|
+
"""
|
95
|
+
try:
|
96
|
+
# There are two types of tools:
|
97
|
+
# 1. BaseTool: the tool is defined in langchain_core.tools.
|
98
|
+
# 2. Other tools: the tool doesn't inherit any class but follow some
|
99
|
+
# conventions, like having a "run" method.
|
100
|
+
# Handle BaseTool type (preferred Langchain approach)
|
101
|
+
if isinstance(self._langchain_tool, BaseTool):
|
102
|
+
tool_wrapper = Tool(
|
103
|
+
name=self.name,
|
104
|
+
func=self.func,
|
105
|
+
description=self.description,
|
106
|
+
)
|
107
|
+
|
108
|
+
# Add schema if available
|
109
|
+
if (
|
110
|
+
hasattr(self._langchain_tool, 'args_schema')
|
111
|
+
and self._langchain_tool.args_schema
|
112
|
+
):
|
113
|
+
tool_wrapper.args_schema = self._langchain_tool.args_schema
|
114
|
+
|
115
|
+
return _automatic_function_calling_util.build_function_declaration_for_langchain(
|
116
|
+
False,
|
117
|
+
self.name,
|
118
|
+
self.description,
|
119
|
+
tool_wrapper.func,
|
120
|
+
getattr(tool_wrapper, 'args', None),
|
121
|
+
)
|
122
|
+
|
78
123
|
# Need to provide a way to override the function names and descriptions
|
79
124
|
# as the original function names are mostly ".run" and the descriptions
|
80
|
-
# may not meet users' needs
|
81
|
-
|
82
|
-
|
83
|
-
func=self.tool.run,
|
84
|
-
)
|
125
|
+
# may not meet users' needs
|
126
|
+
return _automatic_function_calling_util.build_function_declaration(
|
127
|
+
func=self._langchain_tool.run,
|
85
128
|
)
|
86
|
-
|
129
|
+
|
130
|
+
except Exception as e:
|
131
|
+
raise ValueError(
|
132
|
+
f'Failed to build function declaration for Langchain tool: {e}'
|
133
|
+
) from e
|
@@ -17,19 +17,25 @@ from __future__ import annotations
|
|
17
17
|
from typing import TYPE_CHECKING
|
18
18
|
|
19
19
|
from google.genai import types
|
20
|
+
from pydantic import BaseModel
|
21
|
+
from pydantic import Field
|
20
22
|
from typing_extensions import override
|
21
23
|
|
24
|
+
from ..memory.memory_entry import MemoryEntry
|
22
25
|
from .function_tool import FunctionTool
|
23
26
|
from .tool_context import ToolContext
|
24
27
|
|
25
28
|
if TYPE_CHECKING:
|
26
|
-
from ..memory.base_memory_service import MemoryResult
|
27
29
|
from ..models import LlmRequest
|
28
30
|
|
29
31
|
|
32
|
+
class LoadMemoryResponse(BaseModel):
|
33
|
+
memories: list[MemoryEntry] = Field(default_factory=list)
|
34
|
+
|
35
|
+
|
30
36
|
async def load_memory(
|
31
37
|
query: str, tool_context: ToolContext
|
32
|
-
) ->
|
38
|
+
) -> LoadMemoryResponse:
|
33
39
|
"""Loads the memory for the current user.
|
34
40
|
|
35
41
|
Args:
|
@@ -38,12 +44,15 @@ async def load_memory(
|
|
38
44
|
Returns:
|
39
45
|
A list of memory results.
|
40
46
|
"""
|
41
|
-
|
42
|
-
return
|
47
|
+
search_memory_response = await tool_context.search_memory(query)
|
48
|
+
return LoadMemoryResponse(memories=search_memory_response.memories)
|
43
49
|
|
44
50
|
|
45
51
|
class LoadMemoryTool(FunctionTool):
|
46
|
-
"""A tool that loads the memory for the current user.
|
52
|
+
"""A tool that loads the memory for the current user.
|
53
|
+
|
54
|
+
NOTE: Currently this tool only uses text part from the memory.
|
55
|
+
"""
|
47
56
|
|
48
57
|
def __init__(self):
|
49
58
|
super().__init__(load_memory)
|
@@ -15,7 +15,8 @@
|
|
15
15
|
__all__ = []
|
16
16
|
|
17
17
|
try:
|
18
|
-
from .conversion_utils import adk_to_mcp_tool_type
|
18
|
+
from .conversion_utils import adk_to_mcp_tool_type
|
19
|
+
from .conversion_utils import gemini_to_json_schema
|
19
20
|
from .mcp_tool import MCPTool
|
20
21
|
from .mcp_toolset import MCPToolset
|
21
22
|
|
@@ -30,7 +31,7 @@ except ImportError as e:
|
|
30
31
|
import logging
|
31
32
|
import sys
|
32
33
|
|
33
|
-
logger = logging.getLogger(__name__)
|
34
|
+
logger = logging.getLogger('google_adk.' + __name__)
|
34
35
|
|
35
36
|
if sys.version_info < (3, 10):
|
36
37
|
logger.warning(
|
@@ -12,15 +12,22 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
+
import asyncio
|
16
|
+
from contextlib import asynccontextmanager
|
15
17
|
from contextlib import AsyncExitStack
|
16
18
|
import functools
|
19
|
+
import logging
|
17
20
|
import sys
|
18
|
-
from typing import Any
|
21
|
+
from typing import Any
|
22
|
+
from typing import Optional
|
23
|
+
from typing import TextIO
|
24
|
+
|
19
25
|
import anyio
|
20
26
|
from pydantic import BaseModel
|
21
27
|
|
22
28
|
try:
|
23
|
-
from mcp import ClientSession
|
29
|
+
from mcp import ClientSession
|
30
|
+
from mcp import StdioServerParameters
|
24
31
|
from mcp.client.sse import sse_client
|
25
32
|
from mcp.client.stdio import stdio_client
|
26
33
|
except ImportError as e:
|
@@ -34,6 +41,8 @@ except ImportError as e:
|
|
34
41
|
else:
|
35
42
|
raise e
|
36
43
|
|
44
|
+
logger = logging.getLogger('google_adk.' + __name__)
|
45
|
+
|
37
46
|
|
38
47
|
class SseServerParams(BaseModel):
|
39
48
|
"""Parameters for the MCP SSE connection.
|
@@ -108,6 +117,45 @@ def retry_on_closed_resource(async_reinit_func_name: str):
|
|
108
117
|
return decorator
|
109
118
|
|
110
119
|
|
120
|
+
@asynccontextmanager
|
121
|
+
async def tracked_stdio_client(server, errlog, process=None):
|
122
|
+
"""A wrapper around stdio_client that ensures proper process tracking and cleanup."""
|
123
|
+
our_process = process
|
124
|
+
|
125
|
+
# If no process was provided, create one
|
126
|
+
if our_process is None:
|
127
|
+
our_process = await asyncio.create_subprocess_exec(
|
128
|
+
server.command,
|
129
|
+
*server.args,
|
130
|
+
stdin=asyncio.subprocess.PIPE,
|
131
|
+
stdout=asyncio.subprocess.PIPE,
|
132
|
+
stderr=errlog,
|
133
|
+
)
|
134
|
+
|
135
|
+
# Use the original stdio_client, but ensure process cleanup
|
136
|
+
try:
|
137
|
+
async with stdio_client(server=server, errlog=errlog) as client:
|
138
|
+
yield client, our_process
|
139
|
+
finally:
|
140
|
+
# Ensure the process is properly terminated if it still exists
|
141
|
+
if our_process and our_process.returncode is None:
|
142
|
+
try:
|
143
|
+
logger.info(
|
144
|
+
f'Terminating process {our_process.pid} from tracked_stdio_client'
|
145
|
+
)
|
146
|
+
our_process.terminate()
|
147
|
+
try:
|
148
|
+
await asyncio.wait_for(our_process.wait(), timeout=3.0)
|
149
|
+
except asyncio.TimeoutError:
|
150
|
+
# Force kill if it doesn't terminate quickly
|
151
|
+
if our_process.returncode is None:
|
152
|
+
logger.warning(f'Forcing kill of process {our_process.pid}')
|
153
|
+
our_process.kill()
|
154
|
+
except ProcessLookupError:
|
155
|
+
# Process already gone, that's fine
|
156
|
+
logger.info(f'Process {our_process.pid} already terminated')
|
157
|
+
|
158
|
+
|
111
159
|
class MCPSessionManager:
|
112
160
|
"""Manages MCP client sessions.
|
113
161
|
|
@@ -120,7 +168,7 @@ class MCPSessionManager:
|
|
120
168
|
connection_params: StdioServerParameters | SseServerParams,
|
121
169
|
exit_stack: AsyncExitStack,
|
122
170
|
errlog: TextIO = sys.stderr,
|
123
|
-
)
|
171
|
+
):
|
124
172
|
"""Initializes the MCP session manager.
|
125
173
|
|
126
174
|
Example usage:
|
@@ -138,25 +186,39 @@ class MCPSessionManager:
|
|
138
186
|
errlog: (Optional) TextIO stream for error logging. Use only for
|
139
187
|
initializing a local stdio MCP session.
|
140
188
|
"""
|
141
|
-
|
142
|
-
self.
|
143
|
-
self.
|
144
|
-
|
145
|
-
|
146
|
-
|
147
|
-
|
148
|
-
|
149
|
-
|
189
|
+
|
190
|
+
self._connection_params = connection_params
|
191
|
+
self._exit_stack = exit_stack
|
192
|
+
self._errlog = errlog
|
193
|
+
self._process = None # Track the subprocess
|
194
|
+
self._active_processes = set() # Track all processes created
|
195
|
+
self._active_file_handles = set() # Track file handles
|
196
|
+
|
197
|
+
async def create_session(
|
198
|
+
self,
|
199
|
+
) -> tuple[ClientSession, Optional[asyncio.subprocess.Process]]:
|
200
|
+
"""Creates a new MCP session and tracks the associated process."""
|
201
|
+
session, process = await self._initialize_session(
|
202
|
+
connection_params=self._connection_params,
|
203
|
+
exit_stack=self._exit_stack,
|
204
|
+
errlog=self._errlog,
|
150
205
|
)
|
206
|
+
self._process = process # Store reference to process
|
207
|
+
|
208
|
+
# Track the process
|
209
|
+
if process:
|
210
|
+
self._active_processes.add(process)
|
211
|
+
|
212
|
+
return session, process
|
151
213
|
|
152
214
|
@classmethod
|
153
|
-
async def
|
215
|
+
async def _initialize_session(
|
154
216
|
cls,
|
155
217
|
*,
|
156
218
|
connection_params: StdioServerParameters | SseServerParams,
|
157
219
|
exit_stack: AsyncExitStack,
|
158
220
|
errlog: TextIO = sys.stderr,
|
159
|
-
) -> ClientSession:
|
221
|
+
) -> tuple[ClientSession, Optional[asyncio.subprocess.Process]]:
|
160
222
|
"""Initializes an MCP client session.
|
161
223
|
|
162
224
|
Args:
|
@@ -168,9 +230,17 @@ class MCPSessionManager:
|
|
168
230
|
Returns:
|
169
231
|
ClientSession: The initialized MCP client session.
|
170
232
|
"""
|
233
|
+
process = None
|
234
|
+
|
171
235
|
if isinstance(connection_params, StdioServerParameters):
|
172
|
-
|
236
|
+
# For stdio connections, we need to track the subprocess
|
237
|
+
client, process = await cls._create_stdio_client(
|
238
|
+
server=connection_params,
|
239
|
+
errlog=errlog,
|
240
|
+
exit_stack=exit_stack,
|
241
|
+
)
|
173
242
|
elif isinstance(connection_params, SseServerParams):
|
243
|
+
# For SSE connections, create the client without a subprocess
|
174
244
|
client = sse_client(
|
175
245
|
url=connection_params.url,
|
176
246
|
headers=connection_params.headers,
|
@@ -184,7 +254,74 @@ class MCPSessionManager:
|
|
184
254
|
f' {connection_params}'
|
185
255
|
)
|
186
256
|
|
257
|
+
# Create the session with the client
|
187
258
|
transports = await exit_stack.enter_async_context(client)
|
188
259
|
session = await exit_stack.enter_async_context(ClientSession(*transports))
|
189
260
|
await session.initialize()
|
190
|
-
|
261
|
+
|
262
|
+
return session, process
|
263
|
+
|
264
|
+
@staticmethod
|
265
|
+
async def _create_stdio_client(
|
266
|
+
server: StdioServerParameters,
|
267
|
+
errlog: TextIO,
|
268
|
+
exit_stack: AsyncExitStack,
|
269
|
+
) -> tuple[Any, asyncio.subprocess.Process]:
|
270
|
+
"""Create stdio client and return both the client and process.
|
271
|
+
|
272
|
+
This implementation adapts to how the MCP stdio_client is created.
|
273
|
+
The actual implementation may need to be adjusted based on the MCP library
|
274
|
+
structure.
|
275
|
+
"""
|
276
|
+
# Create the subprocess directly so we can track it
|
277
|
+
process = await asyncio.create_subprocess_exec(
|
278
|
+
server.command,
|
279
|
+
*server.args,
|
280
|
+
stdin=asyncio.subprocess.PIPE,
|
281
|
+
stdout=asyncio.subprocess.PIPE,
|
282
|
+
stderr=errlog,
|
283
|
+
)
|
284
|
+
|
285
|
+
# Create the stdio client using the MCP library
|
286
|
+
try:
|
287
|
+
# Method 1: Try using the existing process if stdio_client supports it
|
288
|
+
client = stdio_client(server=server, errlog=errlog, process=process)
|
289
|
+
except TypeError:
|
290
|
+
# Method 2: If the above doesn't work, let stdio_client create its own process
|
291
|
+
# and we'll need to terminate both processes later
|
292
|
+
logger.warning(
|
293
|
+
'Using stdio_client with its own process - may lead to duplicate'
|
294
|
+
' processes'
|
295
|
+
)
|
296
|
+
client = stdio_client(server=server, errlog=errlog)
|
297
|
+
|
298
|
+
return client, process
|
299
|
+
|
300
|
+
async def _emergency_cleanup(self):
|
301
|
+
"""Perform emergency cleanup of resources when normal cleanup fails."""
|
302
|
+
logger.info('Performing emergency cleanup of MCPSessionManager resources')
|
303
|
+
|
304
|
+
# Clean up any tracked processes
|
305
|
+
for proc in list(self._active_processes):
|
306
|
+
try:
|
307
|
+
if proc and proc.returncode is None:
|
308
|
+
logger.info(f'Emergency termination of process {proc.pid}')
|
309
|
+
proc.terminate()
|
310
|
+
try:
|
311
|
+
await asyncio.wait_for(proc.wait(), timeout=1.0)
|
312
|
+
except asyncio.TimeoutError:
|
313
|
+
logger.warning(f"Process {proc.pid} didn't terminate, forcing kill")
|
314
|
+
proc.kill()
|
315
|
+
self._active_processes.remove(proc)
|
316
|
+
except Exception as e:
|
317
|
+
logger.error(f'Error during process cleanup: {e}')
|
318
|
+
|
319
|
+
# Clean up any tracked file handles
|
320
|
+
for handle in list(self._active_file_handles):
|
321
|
+
try:
|
322
|
+
if not handle.closed:
|
323
|
+
logger.info('Closing file handle')
|
324
|
+
handle.close()
|
325
|
+
self._active_file_handles.remove(handle)
|
326
|
+
except Exception as e:
|
327
|
+
logger.error(f'Error closing file handle: {e}')
|