google-adk 0.1.1__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/adk/agents/base_agent.py +4 -4
- google/adk/agents/callback_context.py +0 -1
- google/adk/agents/invocation_context.py +1 -1
- google/adk/agents/remote_agent.py +1 -1
- google/adk/agents/run_config.py +1 -1
- google/adk/auth/auth_credential.py +2 -1
- google/adk/auth/auth_handler.py +7 -3
- google/adk/auth/auth_preprocessor.py +2 -2
- google/adk/auth/auth_tool.py +1 -1
- google/adk/cli/browser/index.html +2 -2
- google/adk/cli/browser/{main-SLIAU2JL.js → main-HWIBUY2R.js} +69 -69
- google/adk/cli/cli_create.py +279 -0
- google/adk/cli/cli_deploy.py +10 -1
- google/adk/cli/cli_eval.py +3 -3
- google/adk/cli/cli_tools_click.py +95 -19
- google/adk/cli/fast_api.py +57 -16
- google/adk/cli/utils/envs.py +0 -3
- google/adk/cli/utils/evals.py +2 -2
- google/adk/evaluation/agent_evaluator.py +2 -2
- google/adk/evaluation/evaluation_generator.py +4 -4
- google/adk/evaluation/response_evaluator.py +17 -5
- google/adk/evaluation/trajectory_evaluator.py +4 -5
- google/adk/events/event.py +3 -3
- google/adk/flows/llm_flows/_nl_planning.py +10 -4
- google/adk/flows/llm_flows/agent_transfer.py +1 -1
- google/adk/flows/llm_flows/base_llm_flow.py +1 -1
- google/adk/flows/llm_flows/contents.py +2 -2
- google/adk/flows/llm_flows/functions.py +1 -3
- google/adk/flows/llm_flows/instructions.py +2 -2
- google/adk/models/gemini_llm_connection.py +2 -2
- google/adk/models/lite_llm.py +51 -34
- google/adk/models/llm_response.py +10 -1
- google/adk/planners/built_in_planner.py +1 -0
- google/adk/planners/plan_re_act_planner.py +2 -2
- google/adk/runners.py +1 -1
- google/adk/sessions/database_session_service.py +91 -26
- google/adk/sessions/state.py +2 -2
- google/adk/telemetry.py +2 -2
- google/adk/tools/agent_tool.py +2 -3
- google/adk/tools/application_integration_tool/clients/integration_client.py +3 -2
- google/adk/tools/base_tool.py +1 -1
- google/adk/tools/function_parameter_parse_util.py +2 -2
- google/adk/tools/google_api_tool/__init__.py +74 -1
- google/adk/tools/google_api_tool/google_api_tool_set.py +12 -9
- google/adk/tools/google_api_tool/google_api_tool_sets.py +91 -34
- google/adk/tools/google_api_tool/googleapi_to_openapi_converter.py +3 -1
- google/adk/tools/load_artifacts_tool.py +1 -1
- google/adk/tools/load_memory_tool.py +25 -2
- google/adk/tools/mcp_tool/mcp_session_manager.py +176 -0
- google/adk/tools/mcp_tool/mcp_tool.py +15 -2
- google/adk/tools/mcp_tool/mcp_toolset.py +31 -37
- google/adk/tools/openapi_tool/auth/credential_exchangers/oauth2_exchanger.py +4 -4
- google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py +1 -1
- google/adk/tools/openapi_tool/openapi_spec_parser/operation_parser.py +5 -12
- google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py +47 -9
- google/adk/tools/toolbox_tool.py +1 -1
- google/adk/version.py +1 -1
- google_adk-0.3.0.dist-info/METADATA +235 -0
- {google_adk-0.1.1.dist-info → google_adk-0.3.0.dist-info}/RECORD +62 -60
- google_adk-0.1.1.dist-info/METADATA +0 -181
- {google_adk-0.1.1.dist-info → google_adk-0.3.0.dist-info}/WHEEL +0 -0
- {google_adk-0.1.1.dist-info → google_adk-0.3.0.dist-info}/entry_points.txt +0 -0
- {google_adk-0.1.1.dist-info → google_adk-0.3.0.dist-info}/licenses/LICENSE +0 -0
@@ -16,18 +16,26 @@ from __future__ import annotations
|
|
16
16
|
|
17
17
|
from typing import TYPE_CHECKING
|
18
18
|
|
19
|
+
from google.genai import types
|
19
20
|
from typing_extensions import override
|
20
21
|
|
21
22
|
from .function_tool import FunctionTool
|
22
23
|
from .tool_context import ToolContext
|
23
24
|
|
24
25
|
if TYPE_CHECKING:
|
25
|
-
from ..models import LlmRequest
|
26
26
|
from ..memory.base_memory_service import MemoryResult
|
27
|
+
from ..models import LlmRequest
|
27
28
|
|
28
29
|
|
29
30
|
def load_memory(query: str, tool_context: ToolContext) -> 'list[MemoryResult]':
  """Loads the memory for the current user.

  Args:
    query: The query to load the memory for.

  Returns:
    A list of memory results.
  """
  # NOTE(review): this docstring is presumably surfaced to the model as the
  # tool description via FunctionTool; keep it stable unless that is intended.
  search_result = tool_context.search_memory(query)
  return search_result.memories
|
33
41
|
|
@@ -38,6 +46,21 @@ class LoadMemoryTool(FunctionTool):
|
|
38
46
|
def __init__(self):
  # Registers the module-level `load_memory` function as this tool's callable.
  super().__init__(load_memory)
|
40
48
|
|
49
|
+
@override
def _get_declaration(self) -> types.FunctionDeclaration | None:
  # Declares the tool schema explicitly: an OBJECT with a single STRING
  # property named `query`.
  # NOTE(review): presumably this replaces a declaration FunctionTool would
  # otherwise derive from `load_memory`'s signature; confirm.
  query_schema = types.Schema(type=types.Type.STRING)
  parameters_schema = types.Schema(
      type=types.Type.OBJECT,
      properties={'query': query_schema},
  )
  return types.FunctionDeclaration(
      name=self.name,
      description=self.description,
      parameters=parameters_schema,
  )
|
63
|
+
|
41
64
|
@override
|
42
65
|
async def process_llm_request(
|
43
66
|
self,
|
@@ -0,0 +1,176 @@
|
|
1
|
+
from contextlib import AsyncExitStack
|
2
|
+
import functools
|
3
|
+
import sys
|
4
|
+
from typing import Any, TextIO
|
5
|
+
import anyio
|
6
|
+
from pydantic import BaseModel
|
7
|
+
|
8
|
+
try:
|
9
|
+
from mcp import ClientSession, StdioServerParameters
|
10
|
+
from mcp.client.sse import sse_client
|
11
|
+
from mcp.client.stdio import stdio_client
|
12
|
+
except ImportError as e:
|
13
|
+
import sys
|
14
|
+
|
15
|
+
if sys.version_info < (3, 10):
|
16
|
+
raise ImportError(
|
17
|
+
'MCP Tool requires Python 3.10 or above. Please upgrade your Python'
|
18
|
+
' version.'
|
19
|
+
) from e
|
20
|
+
else:
|
21
|
+
raise e
|
22
|
+
|
23
|
+
|
24
|
+
class SseServerParams(BaseModel):
  """Parameters for the MCP SSE connection.

  See MCP SSE Client documentation for more details.
  https://github.com/modelcontextprotocol/python-sdk/blob/main/src/mcp/client/sse.py
  """

  # URL of the SSE endpoint; MCPSessionManager.initialize_session forwards it
  # as `url` to `sse_client`.
  url: str
  # Optional HTTP headers; forwarded as `headers` to `sse_client`.
  headers: dict[str, Any] | None = None
  # Forwarded as `timeout` to `sse_client`; presumably seconds — confirm
  # against the MCP SSE client.
  timeout: float = 5
  # Forwarded as `sse_read_timeout` to `sse_client`; 60 * 5 = 300, i.e. five
  # minutes if the unit is seconds (TODO confirm).
  sse_read_timeout: float = 60 * 5
|
35
|
+
|
36
|
+
|
37
|
+
def retry_on_closed_resource(async_reinit_func_name: str):
  """Decorator to automatically reinitialize session and retry action.

  When MCP session was closed, the decorator will automatically recreate the
  session and retry the action with the same parameters.

  Note:
  1. async_reinit_func_name is the name of the class member function that
  reinitializes the MCP session.
  2. Both the decorated function and the async_reinit_func_name must be async
  functions.

  Usage:
  class MCPTool:
    ...
    async def create_session(self):
      self.session = ...

    @retry_on_closed_resource('create_session')
    async def use_session(self):
      await self.session.call_tool()

  Args:
    async_reinit_func_name: The name of the async function to recreate session.

  Returns:
    The decorated function.

  Raises (from the decorated function, only after anyio.ClosedResourceError):
    ValueError: The instance has no callable attribute named
      `async_reinit_func_name`.
    RuntimeError: The re-initialization function itself raised.
  """

  def decorator(func):
    @functools.wraps(func)  # Preserves original function metadata (name, docstring)
    async def wrapper(self, *args, **kwargs):
      try:
        return await func(self, *args, **kwargs)
      except anyio.ClosedResourceError:
        # Resolve the re-init hook outside the try below: a misconfigured
        # decorator must surface as ValueError rather than being re-wrapped
        # as RuntimeError('Error reinitializing: ...').
        reinit_fn = getattr(self, async_reinit_func_name, None)
        if not callable(reinit_fn):
          raise ValueError(
              f'Function {async_reinit_func_name} does not exist in decorated'
              ' class. Please check the function name in'
              ' retry_on_closed_resource decorator.'
          )
        try:
          await reinit_fn()
        except Exception as reinit_err:
          raise RuntimeError(
              f'Error reinitializing: {reinit_err}'
          ) from reinit_err
      # Retry exactly once with the same arguments; a second failure propagates.
      return await func(self, *args, **kwargs)

    return wrapper

  return decorator
|
95
|
+
|
96
|
+
|
97
|
+
class MCPSessionManager:
  """Manages MCP client sessions.

  This class provides methods for creating and initializing MCP client sessions,
  handling different connection parameters (Stdio and SSE).
  """

  def __init__(
      self,
      connection_params: StdioServerParameters | SseServerParams,
      exit_stack: AsyncExitStack,
      errlog: TextIO = sys.stderr,
  ) -> None:
    """Initializes the MCP session manager.

    Example usage:
    ```
    mcp_session_manager = MCPSessionManager(
        connection_params=connection_params,
        exit_stack=exit_stack,
    )
    session = await mcp_session_manager.create_session()
    ```

    Args:
      connection_params: Parameters for the MCP connection (Stdio or SSE).
      exit_stack: AsyncExitStack to manage the session lifecycle.
      errlog: (Optional) TextIO stream for error logging. Use only for
        initializing a local stdio MCP session. The default is the
        `sys.stderr` object as it was when this module was imported.
    """
    self.connection_params = connection_params
    self.exit_stack = exit_stack
    self.errlog = errlog

  async def create_session(self) -> ClientSession:
    """Creates and initializes a new session from the stored parameters.

    Every call enters a new transport context and a new ClientSession context
    on `self.exit_stack`. Contexts entered by earlier calls are not closed
    here; they are released only when the exit stack itself is closed.

    Returns:
      ClientSession: The initialized MCP client session.
    """
    return await MCPSessionManager.initialize_session(
        connection_params=self.connection_params,
        exit_stack=self.exit_stack,
        errlog=self.errlog,
    )

  @classmethod
  async def initialize_session(
      cls,
      *,
      connection_params: StdioServerParameters | SseServerParams,
      exit_stack: AsyncExitStack,
      errlog: TextIO = sys.stderr,
  ) -> ClientSession:
    """Initializes an MCP client session.

    Args:
      connection_params: Parameters for the MCP connection (Stdio or SSE).
      exit_stack: AsyncExitStack to manage the session lifecycle.
      errlog: (Optional) TextIO stream for error logging. Use only for
        initializing a local stdio MCP session.

    Returns:
      ClientSession: The initialized MCP client session.

    Raises:
      ValueError: `connection_params` is neither StdioServerParameters nor
        SseServerParams.
    """
    if isinstance(connection_params, StdioServerParameters):
      client = stdio_client(server=connection_params, errlog=errlog)
    elif isinstance(connection_params, SseServerParams):
      client = sse_client(
          url=connection_params.url,
          headers=connection_params.headers,
          timeout=connection_params.timeout,
          sse_read_timeout=connection_params.sse_read_timeout,
      )
    else:
      raise ValueError(
          'Unable to initialize connection. Connection should be'
          ' StdioServerParameters or SseServerParams, but got'
          f' {connection_params}'
      )

    # Both contexts are owned by `exit_stack`, so closing it tears down the
    # session first and then the underlying transport.
    transports = await exit_stack.enter_async_context(client)
    session = await exit_stack.enter_async_context(ClientSession(*transports))
    await session.initialize()
    return session
|
@@ -17,6 +17,8 @@ from typing import Optional
|
|
17
17
|
from google.genai.types import FunctionDeclaration
|
18
18
|
from typing_extensions import override
|
19
19
|
|
20
|
+
from .mcp_session_manager import MCPSessionManager, retry_on_closed_resource
|
21
|
+
|
20
22
|
# Attempt to import MCP Tool from the MCP library, and hints user to upgrade
|
21
23
|
# their Python version to 3.10 if it fails.
|
22
24
|
try:
|
@@ -33,6 +35,7 @@ except ImportError as e:
|
|
33
35
|
else:
|
34
36
|
raise e
|
35
37
|
|
38
|
+
|
36
39
|
from ..base_tool import BaseTool
|
37
40
|
from ...auth.auth_credential import AuthCredential
|
38
41
|
from ...auth.auth_schemes import AuthScheme
|
@@ -51,6 +54,7 @@ class MCPTool(BaseTool):
|
|
51
54
|
self,
|
52
55
|
mcp_tool: McpBaseTool,
|
53
56
|
mcp_session: ClientSession,
|
57
|
+
mcp_session_manager: MCPSessionManager,
|
54
58
|
auth_scheme: Optional[AuthScheme] = None,
|
55
59
|
auth_credential: Optional[AuthCredential] | None = None,
|
56
60
|
):
|
@@ -79,10 +83,14 @@ class MCPTool(BaseTool):
|
|
79
83
|
self.description = mcp_tool.description if mcp_tool.description else ""
|
80
84
|
self.mcp_tool = mcp_tool
|
81
85
|
self.mcp_session = mcp_session
|
86
|
+
self.mcp_session_manager = mcp_session_manager
|
82
87
|
# TODO(cheliu): Support passing auth to MCP Server.
|
83
88
|
self.auth_scheme = auth_scheme
|
84
89
|
self.auth_credential = auth_credential
|
85
90
|
|
91
|
+
async def _reinitialize_session(self):
|
92
|
+
self.mcp_session = await self.mcp_session_manager.create_session()
|
93
|
+
|
86
94
|
@override
|
87
95
|
def _get_declaration(self) -> FunctionDeclaration:
|
88
96
|
"""Gets the function declaration for the tool.
|
@@ -98,6 +106,7 @@ class MCPTool(BaseTool):
|
|
98
106
|
return function_decl
|
99
107
|
|
100
108
|
@override
|
109
|
+
@retry_on_closed_resource("_reinitialize_session")
|
101
110
|
async def run_async(self, *, args, tool_context: ToolContext):
|
102
111
|
"""Runs the tool asynchronously.
|
103
112
|
|
@@ -109,5 +118,9 @@ class MCPTool(BaseTool):
|
|
109
118
|
Any: The response from the tool.
|
110
119
|
"""
|
111
120
|
# TODO(cheliu): Support passing tool context to MCP Server.
|
112
|
-
|
113
|
-
|
121
|
+
try:
|
122
|
+
response = await self.mcp_session.call_tool(self.name, arguments=args)
|
123
|
+
return response
|
124
|
+
except Exception as e:
|
125
|
+
print(e)
|
126
|
+
raise e
|
@@ -13,15 +13,16 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
15
|
from contextlib import AsyncExitStack
|
16
|
+
import sys
|
16
17
|
from types import TracebackType
|
17
|
-
from typing import
|
18
|
+
from typing import List, Optional, TextIO, Tuple, Type
|
19
|
+
|
20
|
+
from .mcp_session_manager import MCPSessionManager, SseServerParams, retry_on_closed_resource
|
18
21
|
|
19
22
|
# Attempt to import MCP Tool from the MCP library, and hints user to upgrade
|
20
23
|
# their Python version to 3.10 if it fails.
|
21
24
|
try:
|
22
25
|
from mcp import ClientSession, StdioServerParameters
|
23
|
-
from mcp.client.sse import sse_client
|
24
|
-
from mcp.client.stdio import stdio_client
|
25
26
|
from mcp.types import ListToolsResult
|
26
27
|
except ImportError as e:
|
27
28
|
import sys
|
@@ -34,18 +35,9 @@ except ImportError as e:
|
|
34
35
|
else:
|
35
36
|
raise e
|
36
37
|
|
37
|
-
from pydantic import BaseModel
|
38
|
-
|
39
38
|
from .mcp_tool import MCPTool
|
40
39
|
|
41
40
|
|
42
|
-
class SseServerParams(BaseModel):
|
43
|
-
url: str
|
44
|
-
headers: dict[str, Any] | None = None
|
45
|
-
timeout: float = 5
|
46
|
-
sse_read_timeout: float = 60 * 5
|
47
|
-
|
48
|
-
|
49
41
|
class MCPToolset:
|
50
42
|
"""Connects to a MCP Server, and retrieves MCP Tools into ADK Tools.
|
51
43
|
|
@@ -110,7 +102,11 @@ class MCPToolset:
|
|
110
102
|
"""
|
111
103
|
|
112
104
|
def __init__(
|
113
|
-
self,
|
105
|
+
self,
|
106
|
+
*,
|
107
|
+
connection_params: StdioServerParameters | SseServerParams,
|
108
|
+
errlog: TextIO = sys.stderr,
|
109
|
+
exit_stack=AsyncExitStack(),
|
114
110
|
):
|
115
111
|
"""Initializes the MCPToolset.
|
116
112
|
|
@@ -175,7 +171,14 @@ class MCPToolset:
|
|
175
171
|
if not connection_params:
|
176
172
|
raise ValueError('Missing connection params in MCPToolset.')
|
177
173
|
self.connection_params = connection_params
|
178
|
-
self.
|
174
|
+
self.errlog = errlog
|
175
|
+
self.exit_stack = exit_stack
|
176
|
+
|
177
|
+
self.session_manager = MCPSessionManager(
|
178
|
+
connection_params=self.connection_params,
|
179
|
+
exit_stack=self.exit_stack,
|
180
|
+
errlog=self.errlog,
|
181
|
+
)
|
179
182
|
|
180
183
|
@classmethod
|
181
184
|
async def from_server(
|
@@ -183,6 +186,7 @@ class MCPToolset:
|
|
183
186
|
*,
|
184
187
|
connection_params: StdioServerParameters | SseServerParams,
|
185
188
|
async_exit_stack: Optional[AsyncExitStack] = None,
|
189
|
+
errlog: TextIO = sys.stderr,
|
186
190
|
) -> Tuple[List[MCPTool], AsyncExitStack]:
|
187
191
|
"""Retrieve all tools from the MCP connection.
|
188
192
|
|
@@ -209,41 +213,27 @@ class MCPToolset:
|
|
209
213
|
the MCP server. Use `await async_exit_stack.aclose()` to close the
|
210
214
|
connection when server shuts down.
|
211
215
|
"""
|
212
|
-
toolset = cls(connection_params=connection_params)
|
213
216
|
async_exit_stack = async_exit_stack or AsyncExitStack()
|
217
|
+
toolset = cls(
|
218
|
+
connection_params=connection_params,
|
219
|
+
exit_stack=async_exit_stack,
|
220
|
+
errlog=errlog,
|
221
|
+
)
|
222
|
+
|
214
223
|
await async_exit_stack.enter_async_context(toolset)
|
215
224
|
tools = await toolset.load_tools()
|
216
225
|
return (tools, async_exit_stack)
|
217
226
|
|
218
227
|
async def _initialize(self) -> ClientSession:
|
219
228
|
"""Connects to the MCP Server and initializes the ClientSession."""
|
220
|
-
|
221
|
-
client = stdio_client(self.connection_params)
|
222
|
-
elif isinstance(self.connection_params, SseServerParams):
|
223
|
-
client = sse_client(
|
224
|
-
url=self.connection_params.url,
|
225
|
-
headers=self.connection_params.headers,
|
226
|
-
timeout=self.connection_params.timeout,
|
227
|
-
sse_read_timeout=self.connection_params.sse_read_timeout,
|
228
|
-
)
|
229
|
-
else:
|
230
|
-
raise ValueError(
|
231
|
-
'Unable to initialize connection. Connection should be'
|
232
|
-
' StdioServerParameters or SseServerParams, but got'
|
233
|
-
f' {self.connection_params}'
|
234
|
-
)
|
235
|
-
|
236
|
-
transports = await self.exit_stack.enter_async_context(client)
|
237
|
-
self.session = await self.exit_stack.enter_async_context(
|
238
|
-
ClientSession(*transports)
|
239
|
-
)
|
240
|
-
await self.session.initialize()
|
229
|
+
self.session = await self.session_manager.create_session()
|
241
230
|
return self.session
|
242
231
|
|
243
232
|
async def _exit(self):
|
244
233
|
"""Closes the connection to MCP Server."""
|
245
234
|
await self.exit_stack.aclose()
|
246
235
|
|
236
|
+
@retry_on_closed_resource('_initialize')
|
247
237
|
async def load_tools(self) -> List[MCPTool]:
|
248
238
|
"""Loads all tools from the MCP Server.
|
249
239
|
|
@@ -252,7 +242,11 @@ class MCPToolset:
|
|
252
242
|
"""
|
253
243
|
tools_response: ListToolsResult = await self.session.list_tools()
|
254
244
|
return [
|
255
|
-
MCPTool(
|
245
|
+
MCPTool(
|
246
|
+
mcp_tool=tool,
|
247
|
+
mcp_session=self.session,
|
248
|
+
mcp_session_manager=self.session_manager,
|
249
|
+
)
|
256
250
|
for tool in tools_response.tools
|
257
251
|
]
|
258
252
|
|
@@ -66,10 +66,10 @@ class OAuth2CredentialExchanger(BaseAuthCredentialExchanger):
|
|
66
66
|
|
67
67
|
Returns:
|
68
68
|
An AuthCredential object containing the HTTP bearer access token. If the
|
69
|
-
|
69
|
+
HTTP bearer token cannot be generated, return the original credential.
|
70
70
|
"""
|
71
71
|
|
72
|
-
if
|
72
|
+
if not auth_credential.oauth2.access_token:
|
73
73
|
return auth_credential
|
74
74
|
|
75
75
|
# Return the access token as a bearer token.
|
@@ -78,7 +78,7 @@ class OAuth2CredentialExchanger(BaseAuthCredentialExchanger):
|
|
78
78
|
http=HttpAuth(
|
79
79
|
scheme="bearer",
|
80
80
|
credentials=HttpCredentials(
|
81
|
-
token=auth_credential.oauth2.
|
81
|
+
token=auth_credential.oauth2.access_token
|
82
82
|
),
|
83
83
|
),
|
84
84
|
)
|
@@ -111,7 +111,7 @@ class OAuth2CredentialExchanger(BaseAuthCredentialExchanger):
|
|
111
111
|
return auth_credential
|
112
112
|
|
113
113
|
# If access token is exchanged, exchange a HTTPBearer token.
|
114
|
-
if auth_credential.oauth2.
|
114
|
+
if auth_credential.oauth2.access_token:
|
115
115
|
return self.generate_auth_token(auth_credential)
|
116
116
|
|
117
117
|
return None
|
@@ -124,7 +124,7 @@ class OpenAPIToolset:
|
|
124
124
|
def _load_spec(
|
125
125
|
self, spec_str: str, spec_type: Literal["json", "yaml"]
|
126
126
|
) -> Dict[str, Any]:
|
127
|
-
"""Loads the OpenAPI spec string into
|
127
|
+
"""Loads the OpenAPI spec string into a dictionary."""
|
128
128
|
if spec_type == "json":
|
129
129
|
return json.loads(spec_str)
|
130
130
|
elif spec_type == "yaml":
|
@@ -14,20 +14,12 @@
|
|
14
14
|
|
15
15
|
import inspect
|
16
16
|
from textwrap import dedent
|
17
|
-
from typing import Any
|
18
|
-
from typing import Dict
|
19
|
-
from typing import List
|
20
|
-
from typing import Optional
|
21
|
-
from typing import Union
|
17
|
+
from typing import Any, Dict, List, Optional, Union
|
22
18
|
|
23
19
|
from fastapi.encoders import jsonable_encoder
|
24
|
-
from fastapi.openapi.models import Operation
|
25
|
-
from fastapi.openapi.models import Parameter
|
26
|
-
from fastapi.openapi.models import Schema
|
20
|
+
from fastapi.openapi.models import Operation, Parameter, Schema
|
27
21
|
|
28
|
-
from ..common.common import ApiParameter
|
29
|
-
from ..common.common import PydocHelper
|
30
|
-
from ..common.common import to_snake_case
|
22
|
+
from ..common.common import ApiParameter, PydocHelper, to_snake_case
|
31
23
|
|
32
24
|
|
33
25
|
class OperationParser:
|
@@ -110,7 +102,8 @@ class OperationParser:
|
|
110
102
|
description = request_body.description or ''
|
111
103
|
|
112
104
|
if schema and schema.type == 'object':
|
113
|
-
|
105
|
+
properties = schema.properties or {}
|
106
|
+
for prop_name, prop_details in properties.items():
|
114
107
|
self.params.append(
|
115
108
|
ApiParameter(
|
116
109
|
original_name=prop_name,
|
@@ -17,6 +17,7 @@ from typing import Dict
|
|
17
17
|
from typing import List
|
18
18
|
from typing import Literal
|
19
19
|
from typing import Optional
|
20
|
+
from typing import Sequence
|
20
21
|
from typing import Tuple
|
21
22
|
from typing import Union
|
22
23
|
|
@@ -28,7 +29,7 @@ from typing_extensions import override
|
|
28
29
|
|
29
30
|
from ....auth.auth_credential import AuthCredential
|
30
31
|
from ....auth.auth_schemes import AuthScheme
|
31
|
-
from ....tools import BaseTool
|
32
|
+
from ....tools.base_tool import BaseTool
|
32
33
|
from ...tool_context import ToolContext
|
33
34
|
from ..auth.auth_helpers import credential_to_param
|
34
35
|
from ..auth.auth_helpers import dict_to_auth_scheme
|
@@ -59,6 +60,40 @@ def snake_to_lower_camel(snake_case_string: str):
|
|
59
60
|
])
|
60
61
|
|
61
62
|
|
63
|
+
# TODO: Switch to Gemini `from_json_schema` util when it is released
|
64
|
+
# in Gemini SDK.
|
65
|
+
def normalize_json_schema_type(
    json_schema_type: Optional[Union[str, Sequence[str]]],
) -> tuple[Optional[str], bool]:
  """Maps a JSON Schema `type` value onto a single Gemini schema type.

  Adopted and modified from Gemini SDK. JSON Schema lets `type` be one type
  name or a list of names, while Gemini needs exactly one: the first non-"null"
  entry is chosen, and whether "null" appeared is reported separately.

  Remove this after switching to Gemini `from_json_schema`.

  Args:
    json_schema_type: A type name, a sequence of type names, or None.

  Returns:
    A `(type_name, nullable)` tuple. `type_name` is None when no non-null type
    is present; `nullable` is True when "null" appeared.
  """
  if json_schema_type is None:
    return None, False

  if isinstance(json_schema_type, str):
    # A lone "null" carries nullability but no concrete type.
    if json_schema_type == "null":
      return None, True
    return json_schema_type, False

  # Materialize once so one-shot iterables are consumed exactly once.
  candidates = list(json_schema_type)
  first_concrete = next((name for name in candidates if name != "null"), None)
  return first_concrete, "null" in candidates
|
93
|
+
|
94
|
+
|
95
|
+
# TODO: Switch to Gemini `from_json_schema` util when it is released
|
96
|
+
# in Gemini SDK.
|
62
97
|
def to_gemini_schema(openapi_schema: Optional[Dict[str, Any]] = None) -> Schema:
|
63
98
|
"""Converts an OpenAPI schema dictionary to a Gemini Schema object.
|
64
99
|
|
@@ -82,13 +117,6 @@ def to_gemini_schema(openapi_schema: Optional[Dict[str, Any]] = None) -> Schema:
|
|
82
117
|
if not openapi_schema.get("type"):
|
83
118
|
openapi_schema["type"] = "object"
|
84
119
|
|
85
|
-
# Adding this to avoid "properties: should be non-empty for OBJECT type" error
|
86
|
-
# See b/385165182
|
87
|
-
if openapi_schema.get("type", "") == "object" and not openapi_schema.get(
|
88
|
-
"properties"
|
89
|
-
):
|
90
|
-
openapi_schema["properties"] = {"dummy_DO_NOT_GENERATE": {"type": "string"}}
|
91
|
-
|
92
120
|
for key, value in openapi_schema.items():
|
93
121
|
snake_case_key = to_snake_case(key)
|
94
122
|
# Check if the snake_case_key exists in the Schema model's fields.
|
@@ -99,7 +127,17 @@ def to_gemini_schema(openapi_schema: Optional[Dict[str, Any]] = None) -> Schema:
|
|
99
127
|
# Format: properties[expiration].format: only 'enum' and 'date-time' are
|
100
128
|
# supported for STRING type
|
101
129
|
continue
|
102
|
-
|
130
|
+
elif snake_case_key == "type":
|
131
|
+
schema_type, nullable = normalize_json_schema_type(
|
132
|
+
openapi_schema.get("type", None)
|
133
|
+
)
|
134
|
+
# Adding this to force adding a type to an empty dict
|
135
|
+
# This avoid "... one_of or any_of must specify a type" error
|
136
|
+
pydantic_schema_data["type"] = schema_type if schema_type else "object"
|
137
|
+
pydantic_schema_data["type"] = pydantic_schema_data["type"].upper()
|
138
|
+
if nullable:
|
139
|
+
pydantic_schema_data["nullable"] = True
|
140
|
+
elif snake_case_key == "properties" and isinstance(value, dict):
|
103
141
|
pydantic_schema_data[snake_case_key] = {
|
104
142
|
k: to_gemini_schema(v) for k, v in value.items()
|
105
143
|
}
|
google/adk/tools/toolbox_tool.py
CHANGED
google/adk/version.py
CHANGED