google-adk 0.4.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (129):
  1. google/adk/agents/active_streaming_tool.py +1 -0
  2. google/adk/agents/base_agent.py +91 -47
  3. google/adk/agents/base_agent.py.orig +330 -0
  4. google/adk/agents/callback_context.py +4 -9
  5. google/adk/agents/invocation_context.py +1 -0
  6. google/adk/agents/langgraph_agent.py +1 -0
  7. google/adk/agents/live_request_queue.py +1 -0
  8. google/adk/agents/llm_agent.py +172 -35
  9. google/adk/agents/loop_agent.py +1 -1
  10. google/adk/agents/parallel_agent.py +7 -0
  11. google/adk/agents/readonly_context.py +7 -1
  12. google/adk/agents/run_config.py +5 -1
  13. google/adk/agents/sequential_agent.py +31 -0
  14. google/adk/agents/transcription_entry.py +5 -2
  15. google/adk/artifacts/base_artifact_service.py +5 -10
  16. google/adk/artifacts/gcs_artifact_service.py +9 -9
  17. google/adk/artifacts/in_memory_artifact_service.py +6 -6
  18. google/adk/auth/auth_credential.py +9 -5
  19. google/adk/auth/auth_preprocessor.py +7 -1
  20. google/adk/auth/auth_tool.py +3 -4
  21. google/adk/cli/agent_graph.py +5 -5
  22. google/adk/cli/browser/index.html +2 -2
  23. google/adk/cli/browser/{main-HWIBUY2R.js → main-QOEMUXM4.js} +58 -58
  24. google/adk/cli/cli.py +7 -7
  25. google/adk/cli/cli_deploy.py +7 -2
  26. google/adk/cli/cli_eval.py +181 -106
  27. google/adk/cli/cli_tools_click.py +147 -62
  28. google/adk/cli/fast_api.py +340 -158
  29. google/adk/cli/fast_api.py.orig +822 -0
  30. google/adk/cli/utils/common.py +23 -0
  31. google/adk/cli/utils/evals.py +83 -1
  32. google/adk/cli/utils/logs.py +13 -5
  33. google/adk/code_executors/__init__.py +3 -1
  34. google/adk/code_executors/built_in_code_executor.py +52 -0
  35. google/adk/evaluation/__init__.py +1 -1
  36. google/adk/evaluation/agent_evaluator.py +168 -128
  37. google/adk/evaluation/eval_case.py +102 -0
  38. google/adk/evaluation/eval_set.py +37 -0
  39. google/adk/evaluation/eval_sets_manager.py +42 -0
  40. google/adk/evaluation/evaluation_constants.py +1 -0
  41. google/adk/evaluation/evaluation_generator.py +89 -114
  42. google/adk/evaluation/evaluator.py +56 -0
  43. google/adk/evaluation/local_eval_sets_manager.py +264 -0
  44. google/adk/evaluation/response_evaluator.py +107 -3
  45. google/adk/evaluation/trajectory_evaluator.py +83 -2
  46. google/adk/events/event.py +7 -1
  47. google/adk/events/event_actions.py +7 -1
  48. google/adk/examples/example.py +1 -0
  49. google/adk/examples/example_util.py +3 -2
  50. google/adk/flows/__init__.py +0 -1
  51. google/adk/flows/llm_flows/_code_execution.py +19 -11
  52. google/adk/flows/llm_flows/audio_transcriber.py +4 -3
  53. google/adk/flows/llm_flows/base_llm_flow.py +86 -22
  54. google/adk/flows/llm_flows/basic.py +3 -0
  55. google/adk/flows/llm_flows/functions.py +10 -9
  56. google/adk/flows/llm_flows/instructions.py +28 -9
  57. google/adk/flows/llm_flows/single_flow.py +1 -1
  58. google/adk/memory/__init__.py +1 -1
  59. google/adk/memory/_utils.py +23 -0
  60. google/adk/memory/base_memory_service.py +25 -21
  61. google/adk/memory/base_memory_service.py.orig +76 -0
  62. google/adk/memory/in_memory_memory_service.py +59 -27
  63. google/adk/memory/memory_entry.py +37 -0
  64. google/adk/memory/vertex_ai_rag_memory_service.py +40 -17
  65. google/adk/models/anthropic_llm.py +36 -11
  66. google/adk/models/base_llm.py +45 -4
  67. google/adk/models/gemini_llm_connection.py +15 -2
  68. google/adk/models/google_llm.py +9 -44
  69. google/adk/models/google_llm.py.orig +305 -0
  70. google/adk/models/lite_llm.py +94 -38
  71. google/adk/models/llm_request.py +1 -1
  72. google/adk/models/llm_response.py +15 -3
  73. google/adk/models/registry.py +1 -1
  74. google/adk/runners.py +68 -44
  75. google/adk/sessions/__init__.py +1 -1
  76. google/adk/sessions/_session_util.py +14 -0
  77. google/adk/sessions/base_session_service.py +8 -32
  78. google/adk/sessions/database_session_service.py +58 -61
  79. google/adk/sessions/in_memory_session_service.py +108 -26
  80. google/adk/sessions/session.py +4 -0
  81. google/adk/sessions/vertex_ai_session_service.py +23 -45
  82. google/adk/telemetry.py +3 -0
  83. google/adk/tools/__init__.py +4 -7
  84. google/adk/tools/{built_in_code_execution_tool.py → _built_in_code_execution_tool.py} +11 -0
  85. google/adk/tools/_memory_entry_utils.py +30 -0
  86. google/adk/tools/agent_tool.py +16 -13
  87. google/adk/tools/apihub_tool/apihub_toolset.py +55 -74
  88. google/adk/tools/application_integration_tool/application_integration_toolset.py +107 -85
  89. google/adk/tools/application_integration_tool/clients/connections_client.py +29 -25
  90. google/adk/tools/application_integration_tool/clients/integration_client.py +6 -6
  91. google/adk/tools/application_integration_tool/integration_connector_tool.py +69 -26
  92. google/adk/tools/base_toolset.py +58 -0
  93. google/adk/tools/enterprise_search_tool.py +65 -0
  94. google/adk/tools/function_parameter_parse_util.py +2 -2
  95. google/adk/tools/google_api_tool/__init__.py +18 -70
  96. google/adk/tools/google_api_tool/google_api_tool.py +11 -5
  97. google/adk/tools/google_api_tool/google_api_toolset.py +126 -0
  98. google/adk/tools/google_api_tool/google_api_toolsets.py +102 -0
  99. google/adk/tools/google_api_tool/googleapi_to_openapi_converter.py +40 -42
  100. google/adk/tools/langchain_tool.py +96 -49
  101. google/adk/tools/load_artifacts_tool.py +4 -4
  102. google/adk/tools/load_memory_tool.py +16 -5
  103. google/adk/tools/mcp_tool/__init__.py +3 -2
  104. google/adk/tools/mcp_tool/conversion_utils.py +1 -1
  105. google/adk/tools/mcp_tool/mcp_session_manager.py +167 -16
  106. google/adk/tools/mcp_tool/mcp_session_manager.py.orig +322 -0
  107. google/adk/tools/mcp_tool/mcp_tool.py +12 -12
  108. google/adk/tools/mcp_tool/mcp_toolset.py +155 -195
  109. google/adk/tools/openapi_tool/common/common.py +2 -5
  110. google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py +32 -7
  111. google/adk/tools/openapi_tool/openapi_spec_parser/operation_parser.py +43 -33
  112. google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py +1 -1
  113. google/adk/tools/preload_memory_tool.py +27 -18
  114. google/adk/tools/retrieval/__init__.py +1 -1
  115. google/adk/tools/retrieval/vertex_ai_rag_retrieval.py +1 -1
  116. google/adk/tools/tool_context.py +4 -4
  117. google/adk/tools/toolbox_toolset.py +79 -0
  118. google/adk/tools/transfer_to_agent_tool.py +0 -1
  119. google/adk/version.py +1 -1
  120. {google_adk-0.4.0.dist-info → google_adk-1.0.0.dist-info}/METADATA +7 -5
  121. google_adk-1.0.0.dist-info/RECORD +195 -0
  122. google/adk/agents/remote_agent.py +0 -50
  123. google/adk/tools/google_api_tool/google_api_tool_set.py +0 -110
  124. google/adk/tools/google_api_tool/google_api_tool_sets.py +0 -112
  125. google/adk/tools/toolbox_tool.py +0 -46
  126. google_adk-0.4.0.dist-info/RECORD +0 -179
  127. {google_adk-0.4.0.dist-info → google_adk-1.0.0.dist-info}/WHEEL +0 -0
  128. {google_adk-0.4.0.dist-info → google_adk-1.0.0.dist-info}/entry_points.txt +0 -0
  129. {google_adk-0.4.0.dist-info → google_adk-1.0.0.dist-info}/licenses/LICENSE +0 -0
@@ -12,93 +12,63 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
+ import asyncio
15
16
  from contextlib import AsyncExitStack
17
+ import logging
18
+ import os
19
+ import signal
16
20
  import sys
17
- from types import TracebackType
18
- from typing import List, Optional, TextIO, Tuple, Type
21
+ from typing import List
22
+ from typing import Optional
23
+ from typing import TextIO
24
+ from typing import Union
19
25
 
20
- from .mcp_session_manager import MCPSessionManager, SseServerParams, retry_on_closed_resource
26
+ from typing_extensions import override
27
+
28
+ from ...agents.readonly_context import ReadonlyContext
29
+ from ..base_tool import BaseTool
30
+ from ..base_toolset import BaseToolset
31
+ from ..base_toolset import ToolPredicate
32
+ from .mcp_session_manager import MCPSessionManager
33
+ from .mcp_session_manager import retry_on_closed_resource
34
+ from .mcp_session_manager import SseServerParams
21
35
 
22
36
  # Attempt to import MCP Tool from the MCP library, and hints user to upgrade
23
37
  # their Python version to 3.10 if it fails.
24
38
  try:
25
- from mcp import ClientSession, StdioServerParameters
39
+ from mcp import ClientSession
40
+ from mcp import StdioServerParameters
26
41
  from mcp.types import ListToolsResult
27
42
  except ImportError as e:
28
43
  import sys
29
44
 
30
45
  if sys.version_info < (3, 10):
31
46
  raise ImportError(
32
- 'MCP Tool requires Python 3.10 or above. Please upgrade your Python'
33
- ' version.'
47
+ "MCP Tool requires Python 3.10 or above. Please upgrade your Python"
48
+ " version."
34
49
  ) from e
35
50
  else:
36
51
  raise e
37
52
 
38
53
  from .mcp_tool import MCPTool
39
54
 
55
+ logger = logging.getLogger("google_adk." + __name__)
56
+
40
57
 
41
- class MCPToolset:
58
+ class MCPToolset(BaseToolset):
42
59
  """Connects to a MCP Server, and retrieves MCP Tools into ADK Tools.
43
60
 
44
61
  Usage:
45
- Example 1: (using from_server helper):
46
62
  ```
47
- async def load_tools():
48
- return await MCPToolset.from_server(
49
- connection_params=StdioServerParameters(
50
- command='npx',
51
- args=["-y", "@modelcontextprotocol/server-filesystem"],
63
+ root_agent = LlmAgent(
64
+ tools=MCPToolset(
65
+ connection_params=StdioServerParameters(
66
+ command='npx',
67
+ args=["-y", "@modelcontextprotocol/server-filesystem"],
52
68
  )
53
- )
54
-
55
- # Use the tools in an LLM agent
56
- tools, exit_stack = await load_tools()
57
- agent = LlmAgent(
58
- tools=tools
59
- )
60
- ...
61
- await exit_stack.aclose()
62
- ```
63
-
64
- Example 2: (using `async with`):
65
-
66
- ```
67
- async def load_tools():
68
- async with MCPToolset(
69
- connection_params=SseServerParams(url="http://0.0.0.0:8090/sse")
70
- ) as toolset:
71
- tools = await toolset.load_tools()
72
-
73
- agent = LlmAgent(
74
- ...
75
- tools=tools
76
69
  )
70
+ )
77
71
  ```
78
-
79
- Example 3: (provide AsyncExitStack):
80
- ```
81
- async def load_tools():
82
- async_exit_stack = AsyncExitStack()
83
- toolset = MCPToolset(
84
- connection_params=StdioServerParameters(...),
85
- )
86
- async_exit_stack.enter_async_context(toolset)
87
- tools = await toolset.load_tools()
88
- agent = LlmAgent(
89
- ...
90
- tools=tools
91
- )
92
- ...
93
- await async_exit_stack.aclose()
94
-
95
- ```
96
-
97
- Attributes:
98
- connection_params: The connection parameters to the MCP server. Can be
99
- either `StdioServerParameters` or `SseServerParams`.
100
- exit_stack: The async exit stack to manage the connection to the MCP server.
101
- session: The MCP session being initialized with the connection.
102
72
  """
103
73
 
104
74
  def __init__(
@@ -106,161 +76,151 @@ class MCPToolset:
106
76
  *,
107
77
  connection_params: StdioServerParameters | SseServerParams,
108
78
  errlog: TextIO = sys.stderr,
109
- exit_stack=AsyncExitStack(),
79
+ tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
110
80
  ):
111
81
  """Initializes the MCPToolset.
112
82
 
113
- Usage:
114
- Example 1: (using from_server helper):
115
- ```
116
- async def load_tools():
117
- return await MCPToolset.from_server(
118
- connection_params=StdioServerParameters(
119
- command='npx',
120
- args=["-y", "@modelcontextprotocol/server-filesystem"],
121
- )
122
- )
123
-
124
- # Use the tools in an LLM agent
125
- tools, exit_stack = await load_tools()
126
- agent = LlmAgent(
127
- tools=tools
128
- )
129
- ...
130
- await exit_stack.aclose()
131
- ```
132
-
133
- Example 2: (using `async with`):
134
-
135
- ```
136
- async def load_tools():
137
- async with MCPToolset(
138
- connection_params=SseServerParams(url="http://0.0.0.0:8090/sse")
139
- ) as toolset:
140
- tools = await toolset.load_tools()
141
-
142
- agent = LlmAgent(
143
- ...
144
- tools=tools
145
- )
146
- ```
147
-
148
- Example 3: (provide AsyncExitStack):
149
- ```
150
- async def load_tools():
151
- async_exit_stack = AsyncExitStack()
152
- toolset = MCPToolset(
153
- connection_params=StdioServerParameters(...),
154
- )
155
- async_exit_stack.enter_async_context(toolset)
156
- tools = await toolset.load_tools()
157
- agent = LlmAgent(
158
- ...
159
- tools=tools
160
- )
161
- ...
162
- await async_exit_stack.aclose()
163
-
164
- ```
165
-
166
83
  Args:
167
84
  connection_params: The connection parameters to the MCP server. Can be:
168
85
  `StdioServerParameters` for using local mcp server (e.g. using `npx` or
169
86
  `python3`); or `SseServerParams` for a local/remote SSE server.
87
+ errlog: (Optional) TextIO stream for error logging. Use only for
88
+ initializing a local stdio MCP session.
170
89
  """
171
- if not connection_params:
172
- raise ValueError('Missing connection params in MCPToolset.')
173
- self.connection_params = connection_params
174
- self.errlog = errlog
175
- self.exit_stack = exit_stack
176
90
 
177
- self.session_manager = MCPSessionManager(
178
- connection_params=self.connection_params,
179
- exit_stack=self.exit_stack,
180
- errlog=self.errlog,
91
+ if not connection_params:
92
+ raise ValueError("Missing connection params in MCPToolset.")
93
+ self._connection_params = connection_params
94
+ self._errlog = errlog
95
+ self._exit_stack = AsyncExitStack()
96
+ self._creator_task_id = None
97
+ self._process_pid = None # Store the subprocess PID
98
+
99
+ self._session_manager = MCPSessionManager(
100
+ connection_params=self._connection_params,
101
+ exit_stack=self._exit_stack,
102
+ errlog=self._errlog,
181
103
  )
104
+ self._session = None
105
+ self.tool_filter = tool_filter
106
+ self._initialized = False
182
107
 
183
- @classmethod
184
- async def from_server(
185
- cls,
186
- *,
187
- connection_params: StdioServerParameters | SseServerParams,
188
- async_exit_stack: Optional[AsyncExitStack] = None,
189
- errlog: TextIO = sys.stderr,
190
- ) -> Tuple[List[MCPTool], AsyncExitStack]:
191
- """Retrieve all tools from the MCP connection.
192
-
193
- Usage:
194
- ```
195
- async def load_tools():
196
- tools, exit_stack = await MCPToolset.from_server(
197
- connection_params=StdioServerParameters(
198
- command='npx',
199
- args=["-y", "@modelcontextprotocol/server-filesystem"],
200
- )
201
- )
202
- ```
203
-
204
- Args:
205
- connection_params: The connection parameters to the MCP server.
206
- async_exit_stack: The async exit stack to use. If not provided, a new
207
- AsyncExitStack will be created.
108
+ async def _initialize(self) -> ClientSession:
109
+ """Connects to the MCP Server and initializes the ClientSession."""
110
+ # Store the current task ID when initializing
111
+ self._creator_task_id = id(asyncio.current_task())
112
+ self._session, process = await self._session_manager.create_session()
113
+ # Store the process PID if available
114
+ if process and hasattr(process, "pid"):
115
+ self._process_pid = process.pid
116
+ self._initialized = True
117
+ return self._session
118
+
119
+ def _is_selected(
120
+ self, tool: BaseTool, readonly_context: Optional[ReadonlyContext]
121
+ ) -> bool:
122
+ """Checks if a tool should be selected based on the tool filter."""
123
+ if self.tool_filter is None:
124
+ return True
125
+ if isinstance(self.tool_filter, ToolPredicate):
126
+ return self.tool_filter(tool, readonly_context)
127
+ if isinstance(self.tool_filter, list):
128
+ return tool.name in self.tool_filter
129
+ return False
130
+
131
+ @override
132
+ async def close(self):
133
+ """Safely closes the connection to MCP Server with guaranteed resource cleanup."""
134
+ if not self._initialized:
135
+ return # Nothing to close
136
+
137
+ logger.info("Closing MCP Toolset")
138
+
139
+ # Step 1: Try graceful shutdown of the session if it exists
140
+ if self._session:
141
+ try:
142
+ logger.info("Attempting graceful session shutdown")
143
+ await self._session.shutdown()
144
+ except Exception as e:
145
+ logger.warning(f"Session shutdown error (continuing cleanup): {e}")
146
+
147
+ # Step 2: Try to close the exit stack
148
+ try:
149
+ logger.info("Closing AsyncExitStack")
150
+ await self._exit_stack.aclose()
151
+ # If we get here, the exit stack closed successfully
152
+ logger.info("AsyncExitStack closed successfully")
153
+ return
154
+ except RuntimeError as e:
155
+ if "Attempted to exit cancel scope in a different task" in str(e):
156
+ logger.warning("Task mismatch during shutdown - using fallback cleanup")
157
+ # Continue to manual cleanup
158
+ else:
159
+ logger.error(f"Unexpected RuntimeError: {e}")
160
+ # Continue to manual cleanup
161
+ except Exception as e:
162
+ logger.error(f"Error during exit stack closure: {e}")
163
+ # Continue to manual cleanup
208
164
 
209
- Returns:
210
- A tuple of the list of MCPTools and the AsyncExitStack.
211
- - tools: The list of MCPTools.
212
- - async_exit_stack: The AsyncExitStack used to manage the connection to
213
- the MCP server. Use `await async_exit_stack.aclose()` to close the
214
- connection when server shuts down.
215
- """
216
- async_exit_stack = async_exit_stack or AsyncExitStack()
217
- toolset = cls(
218
- connection_params=connection_params,
219
- exit_stack=async_exit_stack,
220
- errlog=errlog,
221
- )
165
+ # Step 3: Manual cleanup of the subprocess if we have its PID
166
+ if self._process_pid:
167
+ await self._ensure_process_terminated(self._process_pid)
222
168
 
223
- await async_exit_stack.enter_async_context(toolset)
224
- tools = await toolset.load_tools()
225
- return (tools, async_exit_stack)
169
+ # Step 4: Ask the session manager to do any additional cleanup it can
170
+ await self._session_manager._emergency_cleanup()
226
171
 
227
- async def _initialize(self) -> ClientSession:
228
- """Connects to the MCP Server and initializes the ClientSession."""
229
- self.session = await self.session_manager.create_session()
230
- return self.session
172
+ async def _ensure_process_terminated(self, pid):
173
+ """Ensure a process is terminated using its PID."""
174
+ try:
175
+ # Check if process exists
176
+ os.kill(pid, 0) # This just checks if the process exists
177
+
178
+ logger.info(f"Terminating process with PID {pid}")
179
+ # First try SIGTERM for graceful shutdown
180
+ os.kill(pid, signal.SIGTERM)
181
+
182
+ # Give it a moment to terminate
183
+ for _ in range(30): # wait up to 3 seconds
184
+ await asyncio.sleep(0.1)
185
+ try:
186
+ os.kill(pid, 0) # Process still exists
187
+ except ProcessLookupError:
188
+ logger.info(f"Process {pid} terminated successfully")
189
+ return
190
+
191
+ # If we get here, process didn't terminate gracefully
192
+ logger.warning(
193
+ f"Process {pid} didn't terminate gracefully, using SIGKILL"
194
+ )
195
+ os.kill(pid, signal.SIGKILL)
231
196
 
232
- async def _exit(self):
233
- """Closes the connection to MCP Server."""
234
- await self.exit_stack.aclose()
197
+ except ProcessLookupError:
198
+ logger.info(f"Process {pid} already terminated")
199
+ except Exception as e:
200
+ logger.error(f"Error terminating process {pid}: {e}")
235
201
 
236
- @retry_on_closed_resource('_initialize')
237
- async def load_tools(self) -> List[MCPTool]:
202
+ @retry_on_closed_resource("_initialize")
203
+ @override
204
+ async def get_tools(
205
+ self,
206
+ readonly_context: Optional[ReadonlyContext] = None,
207
+ ) -> List[MCPTool]:
238
208
  """Loads all tools from the MCP Server.
239
209
 
240
210
  Returns:
241
211
  A list of MCPTools imported from the MCP Server.
242
212
  """
243
- tools_response: ListToolsResult = await self.session.list_tools()
244
- return [
245
- MCPTool(
246
- mcp_tool=tool,
247
- mcp_session=self.session,
248
- mcp_session_manager=self.session_manager,
249
- )
250
- for tool in tools_response.tools
251
- ]
252
-
253
- async def __aenter__(self):
254
- try:
213
+ if not self._session:
255
214
  await self._initialize()
256
- return self
257
- except Exception as e:
258
- raise e
215
+ tools_response: ListToolsResult = await self._session.list_tools()
216
+ tools = []
217
+ for tool in tools_response.tools:
218
+ mcp_tool = MCPTool(
219
+ mcp_tool=tool,
220
+ mcp_session=self._session,
221
+ mcp_session_manager=self._session_manager,
222
+ )
259
223
 
260
- async def __aexit__(
261
- self,
262
- exc_type: Optional[Type[BaseException]],
263
- exc: Optional[BaseException],
264
- tb: Optional[TracebackType],
265
- ) -> None:
266
- await self._exit()
224
+ if self._is_selected(mcp_tool, readonly_context):
225
+ tools.append(mcp_tool)
226
+ return tools
@@ -14,11 +14,7 @@
14
14
 
15
15
  import keyword
16
16
  import re
17
- from typing import Any
18
- from typing import Dict
19
- from typing import List
20
- from typing import Optional
21
- from typing import Union
17
+ from typing import Any, Dict, List, Optional, Union
22
18
 
23
19
  from fastapi.openapi.models import Response
24
20
  from fastapi.openapi.models import Schema
@@ -100,6 +96,7 @@ class ApiParameter(BaseModel):
100
96
  py_name: Optional[str] = ''
101
97
  type_value: type[Any] = Field(default=None, init_var=False)
102
98
  type_hint: str = Field(default=None, init_var=False)
99
+ required: bool = False
103
100
 
104
101
  def model_post_init(self, _: Any):
105
102
  self.py_name = (
@@ -20,18 +20,23 @@ from typing import Final
20
20
  from typing import List
21
21
  from typing import Literal
22
22
  from typing import Optional
23
+ from typing import Union
23
24
 
25
+ from typing_extensions import override
24
26
  import yaml
25
27
 
28
+ from ....agents.readonly_context import ReadonlyContext
26
29
  from ....auth.auth_credential import AuthCredential
27
30
  from ....auth.auth_schemes import AuthScheme
31
+ from ...base_toolset import BaseToolset
32
+ from ...base_toolset import ToolPredicate
28
33
  from .openapi_spec_parser import OpenApiSpecParser
29
34
  from .rest_api_tool import RestApiTool
30
35
 
31
- logger = logging.getLogger(__name__)
36
+ logger = logging.getLogger("google_adk." + __name__)
32
37
 
33
38
 
34
- class OpenAPIToolset:
39
+ class OpenAPIToolset(BaseToolset):
35
40
  """Class for parsing OpenAPI spec into a list of RestApiTool.
36
41
 
37
42
  Usage:
@@ -61,6 +66,7 @@ class OpenAPIToolset:
61
66
  spec_str_type: Literal["json", "yaml"] = "json",
62
67
  auth_scheme: Optional[AuthScheme] = None,
63
68
  auth_credential: Optional[AuthCredential] = None,
69
+ tool_filter: Optional[Union[ToolPredicate, List[str]]] = None,
64
70
  ):
65
71
  """Initializes the OpenAPIToolset.
66
72
 
@@ -94,31 +100,46 @@ class OpenAPIToolset:
94
100
  auth_credential: The auth credential to use for all tools. Use
95
101
  AuthCredential or use helpers in
96
102
  `google.adk.tools.openapi_tool.auth.auth_helpers`
103
+ tool_filter: The filter used to filter the tools in the toolset. It can be
104
+ either a tool predicate or a list of tool names of the tools to expose.
97
105
  """
98
106
  if not spec_dict:
99
107
  spec_dict = self._load_spec(spec_str, spec_str_type)
100
- self.tools: Final[List[RestApiTool]] = list(self._parse(spec_dict))
108
+ self._tools: Final[List[RestApiTool]] = list(self._parse(spec_dict))
101
109
  if auth_scheme or auth_credential:
102
110
  self._configure_auth_all(auth_scheme, auth_credential)
111
+ self.tool_filter = tool_filter
103
112
 
104
113
  def _configure_auth_all(
105
114
  self, auth_scheme: AuthScheme, auth_credential: AuthCredential
106
115
  ):
107
116
  """Configure auth scheme and credential for all tools."""
108
117
 
109
- for tool in self.tools:
118
+ for tool in self._tools:
110
119
  if auth_scheme:
111
120
  tool.configure_auth_scheme(auth_scheme)
112
121
  if auth_credential:
113
122
  tool.configure_auth_credential(auth_credential)
114
123
 
115
- def get_tools(self) -> List[RestApiTool]:
124
+ @override
125
+ async def get_tools(
126
+ self, readonly_context: Optional[ReadonlyContext] = None
127
+ ) -> List[RestApiTool]:
116
128
  """Get all tools in the toolset."""
117
- return self.tools
129
+ return [
130
+ tool
131
+ for tool in self._tools
132
+ if self.tool_filter is None
133
+ or (
134
+ self.tool_filter(tool, readonly_context)
135
+ if isinstance(self.tool_filter, ToolPredicate)
136
+ else tool.name in self.tool_filter
137
+ )
138
+ ]
118
139
 
119
140
  def get_tool(self, tool_name: str) -> Optional[RestApiTool]:
120
141
  """Get a tool by name."""
121
- matching_tool = filter(lambda t: t.name == tool_name, self.tools)
142
+ matching_tool = filter(lambda t: t.name == tool_name, self._tools)
122
143
  return next(matching_tool, None)
123
144
 
124
145
  def _load_spec(
@@ -142,3 +163,7 @@ class OpenAPIToolset:
142
163
  logger.info("Parsed tool: %s", tool.name)
143
164
  tools.append(tool)
144
165
  return tools
166
+
167
+ @override
168
+ async def close(self):
169
+ pass