google-adk 0.1.1__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as they appear in their public registries. It is provided for informational purposes only.
- google/adk/agents/base_agent.py +4 -4
- google/adk/agents/callback_context.py +0 -1
- google/adk/agents/invocation_context.py +1 -1
- google/adk/agents/remote_agent.py +1 -1
- google/adk/agents/run_config.py +1 -1
- google/adk/auth/auth_credential.py +2 -1
- google/adk/auth/auth_handler.py +7 -3
- google/adk/auth/auth_preprocessor.py +2 -2
- google/adk/auth/auth_tool.py +1 -1
- google/adk/cli/browser/index.html +2 -2
- google/adk/cli/browser/{main-SLIAU2JL.js → main-HWIBUY2R.js} +69 -69
- google/adk/cli/cli_create.py +279 -0
- google/adk/cli/cli_deploy.py +10 -1
- google/adk/cli/cli_eval.py +3 -3
- google/adk/cli/cli_tools_click.py +95 -19
- google/adk/cli/fast_api.py +57 -16
- google/adk/cli/utils/envs.py +0 -3
- google/adk/cli/utils/evals.py +2 -2
- google/adk/evaluation/agent_evaluator.py +2 -2
- google/adk/evaluation/evaluation_generator.py +4 -4
- google/adk/evaluation/response_evaluator.py +17 -5
- google/adk/evaluation/trajectory_evaluator.py +4 -5
- google/adk/events/event.py +3 -3
- google/adk/flows/llm_flows/_nl_planning.py +10 -4
- google/adk/flows/llm_flows/agent_transfer.py +1 -1
- google/adk/flows/llm_flows/base_llm_flow.py +1 -1
- google/adk/flows/llm_flows/contents.py +2 -2
- google/adk/flows/llm_flows/functions.py +1 -3
- google/adk/flows/llm_flows/instructions.py +2 -2
- google/adk/models/gemini_llm_connection.py +2 -2
- google/adk/models/lite_llm.py +51 -34
- google/adk/models/llm_response.py +10 -1
- google/adk/planners/built_in_planner.py +1 -0
- google/adk/planners/plan_re_act_planner.py +2 -2
- google/adk/runners.py +1 -1
- google/adk/sessions/database_session_service.py +91 -26
- google/adk/sessions/state.py +2 -2
- google/adk/telemetry.py +2 -2
- google/adk/tools/agent_tool.py +2 -3
- google/adk/tools/application_integration_tool/clients/integration_client.py +3 -2
- google/adk/tools/base_tool.py +1 -1
- google/adk/tools/function_parameter_parse_util.py +2 -2
- google/adk/tools/google_api_tool/__init__.py +74 -1
- google/adk/tools/google_api_tool/google_api_tool_set.py +12 -9
- google/adk/tools/google_api_tool/google_api_tool_sets.py +91 -34
- google/adk/tools/google_api_tool/googleapi_to_openapi_converter.py +3 -1
- google/adk/tools/load_artifacts_tool.py +1 -1
- google/adk/tools/load_memory_tool.py +25 -2
- google/adk/tools/mcp_tool/mcp_session_manager.py +176 -0
- google/adk/tools/mcp_tool/mcp_tool.py +15 -2
- google/adk/tools/mcp_tool/mcp_toolset.py +31 -37
- google/adk/tools/openapi_tool/auth/credential_exchangers/oauth2_exchanger.py +4 -4
- google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py +1 -1
- google/adk/tools/openapi_tool/openapi_spec_parser/operation_parser.py +5 -12
- google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py +47 -9
- google/adk/tools/toolbox_tool.py +1 -1
- google/adk/version.py +1 -1
- google_adk-0.3.0.dist-info/METADATA +235 -0
- {google_adk-0.1.1.dist-info → google_adk-0.3.0.dist-info}/RECORD +62 -60
- google_adk-0.1.1.dist-info/METADATA +0 -181
- {google_adk-0.1.1.dist-info → google_adk-0.3.0.dist-info}/WHEEL +0 -0
- {google_adk-0.1.1.dist-info → google_adk-0.3.0.dist-info}/entry_points.txt +0 -0
- {google_adk-0.1.1.dist-info → google_adk-0.3.0.dist-info}/licenses/LICENSE +0 -0
google/adk/runners.py
CHANGED
@@ -108,7 +108,7 @@ class Runner:
     """Runs the agent.

     NOTE: This sync interface is only for local testing and convenience purpose.
-    Consider
+    Consider using `run_async` for production usage.

     Args:
       user_id: The user ID of the session.
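The updated docstring steers production callers toward the async API. Below is a minimal sketch of driving `run_async`; the runner construction, IDs, and message text are placeholders and not part of this diff.

import asyncio

from google.genai import types


async def chat_once(runner, user_id: str, session_id: str, text: str) -> None:
  # run_async yields events as the agent produces them.
  message = types.Content(role="user", parts=[types.Part(text=text)])
  async for event in runner.run_async(
      user_id=user_id, session_id=session_id, new_message=message
  ):
    print(event)


# asyncio.run(chat_once(runner, "user-1", "session-1", "hello"))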
google/adk/sessions/database_session_service.py
CHANGED
@@ -12,23 +12,25 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import base64
 import copy
 from datetime import datetime
 import json
 import logging
-from typing import Any
-from typing import Optional
+from typing import Any, Optional
 import uuid

+from google.genai import types
+from sqlalchemy import Boolean
 from sqlalchemy import delete
 from sqlalchemy import Dialect
 from sqlalchemy import ForeignKeyConstraint
 from sqlalchemy import func
-from sqlalchemy import select
 from sqlalchemy import Text
 from sqlalchemy.dialects import postgresql
 from sqlalchemy.engine import create_engine
 from sqlalchemy.engine import Engine
+from sqlalchemy.exc import ArgumentError
 from sqlalchemy.ext.mutable import MutableDict
 from sqlalchemy.inspection import inspect
 from sqlalchemy.orm import DeclarativeBase
@@ -53,6 +55,7 @@ from .base_session_service import ListSessionsResponse
 from .session import Session
 from .state import State

+
 logger = logging.getLogger(__name__)


@@ -102,7 +105,7 @@ class StorageSession(Base):
       String, primary_key=True, default=lambda: str(uuid.uuid4())
   )

-  state: Mapped[
+  state: Mapped[MutableDict[str, Any]] = mapped_column(
       MutableDict.as_mutable(DynamicJSON), default={}
   )

@@ -133,8 +136,20 @@ class StorageEvent(Base):
   author: Mapped[str] = mapped_column(String)
   branch: Mapped[str] = mapped_column(String, nullable=True)
   timestamp: Mapped[DateTime] = mapped_column(DateTime(), default=func.now())
-  content: Mapped[dict] = mapped_column(DynamicJSON)
-  actions: Mapped[
+  content: Mapped[dict[str, Any]] = mapped_column(DynamicJSON, nullable=True)
+  actions: Mapped[MutableDict[str, Any]] = mapped_column(PickleType)
+
+  long_running_tool_ids_json: Mapped[Optional[str]] = mapped_column(
+      Text, nullable=True
+  )
+  grounding_metadata: Mapped[dict[str, Any]] = mapped_column(
+      DynamicJSON, nullable=True
+  )
+  partial: Mapped[bool] = mapped_column(Boolean, nullable=True)
+  turn_complete: Mapped[bool] = mapped_column(Boolean, nullable=True)
+  error_code: Mapped[str] = mapped_column(String, nullable=True)
+  error_message: Mapped[str] = mapped_column(String, nullable=True)
+  interrupted: Mapped[bool] = mapped_column(Boolean, nullable=True)

   storage_session: Mapped[StorageSession] = relationship(
       "StorageSession",
@@ -149,13 +164,28 @@ class StorageEvent(Base):
       ),
   )

+  @property
+  def long_running_tool_ids(self) -> set[str]:
+    return (
+        set(json.loads(self.long_running_tool_ids_json))
+        if self.long_running_tool_ids_json
+        else set()
+    )
+
+  @long_running_tool_ids.setter
+  def long_running_tool_ids(self, value: set[str]):
+    if value is None:
+      self.long_running_tool_ids_json = None
+    else:
+      self.long_running_tool_ids_json = json.dumps(list(value))
+

 class StorageAppState(Base):
   """Represents an app state stored in the database."""
   __tablename__ = "app_states"

   app_name: Mapped[str] = mapped_column(String, primary_key=True)
-  state: Mapped[
+  state: Mapped[MutableDict[str, Any]] = mapped_column(
       MutableDict.as_mutable(DynamicJSON), default={}
   )
   update_time: Mapped[DateTime] = mapped_column(
@@ -169,7 +199,7 @@ class StorageUserState(Base):

   app_name: Mapped[str] = mapped_column(String, primary_key=True)
   user_id: Mapped[str] = mapped_column(String, primary_key=True)
-  state: Mapped[
+  state: Mapped[MutableDict[str, Any]] = mapped_column(
       MutableDict.as_mutable(DynamicJSON), default={}
   )
   update_time: Mapped[DateTime] = mapped_column(
@@ -187,15 +217,22 @@ class DatabaseSessionService(BaseSessionService):
     """
     # 1. Create DB engine for db connection
     # 2. Create all tables based on schema
-    # 3. Initialize all
-
-    supported_dialects = ["postgresql", "mysql", "sqlite"]
-    dialect = db_url.split("://")[0]
+    # 3. Initialize all properties

-
+    try:
       db_engine = create_engine(db_url)
-
-
+    except Exception as e:
+      if isinstance(e, ArgumentError):
+        raise ValueError(
+            f"Invalid database URL format or argument '{db_url}'."
+        ) from e
+      if isinstance(e, ImportError):
+        raise ValueError(
+            f"Database related module not found for URL '{db_url}'."
+        ) from e
+      raise ValueError(
+          f"Failed to create database engine for URL '{db_url}'"
+      ) from e

     # Get the local timezone
     local_timezone = get_localzone()
@@ -287,7 +324,6 @@ class DatabaseSessionService(BaseSessionService):
           last_update_time=storage_session.update_time.timestamp(),
       )
       return session
-    return None

   @override
   def get_session(
@@ -301,7 +337,6 @@ class DatabaseSessionService(BaseSessionService):
     # 1. Get the storage session entry from session table
     # 2. Get all the events based on session id and filtering config
     # 3. Convert and return the session
-    session: Session = None
     with self.DatabaseSessionFactory() as sessionFactory:
       storage_session = sessionFactory.get(
           StorageSession, (app_name, user_id, session_id)
@@ -348,13 +383,19 @@ class DatabaseSessionService(BaseSessionService):
               author=e.author,
               branch=e.branch,
               invocation_id=e.invocation_id,
-              content=e.content,
+              content=_decode_content(e.content),
               actions=e.actions,
               timestamp=e.timestamp.timestamp(),
+              long_running_tool_ids=e.long_running_tool_ids,
+              grounding_metadata=e.grounding_metadata,
+              partial=e.partial,
+              turn_complete=e.turn_complete,
+              error_code=e.error_code,
+              error_message=e.error_message,
+              interrupted=e.interrupted,
           )
           for e in storage_events
       ]
-
     return session

   @override
@@ -379,7 +420,6 @@ class DatabaseSessionService(BaseSessionService):
       )
       sessions.append(session)
       return ListSessionsResponse(sessions=sessions)
-    raise ValueError("Failed to retrieve sessions.")

   @override
   def delete_session(
@@ -398,7 +438,7 @@ class DatabaseSessionService(BaseSessionService):
   def append_event(self, session: Session, event: Event) -> Event:
     logger.info(f"Append event: {event} to session {session.id}")

-    if event.partial
+    if event.partial:
       return event

     # 1. Check if timestamp is stale
@@ -447,19 +487,34 @@ class DatabaseSessionService(BaseSessionService):
       storage_user_state.state = user_state
       storage_session.state = session_state

-      encoded_content = event.content.model_dump(exclude_none=True)
       storage_event = StorageEvent(
           id=event.id,
           invocation_id=event.invocation_id,
           author=event.author,
           branch=event.branch,
-          content=encoded_content,
           actions=event.actions,
           session_id=session.id,
           app_name=session.app_name,
           user_id=session.user_id,
           timestamp=datetime.fromtimestamp(event.timestamp),
+          long_running_tool_ids=event.long_running_tool_ids,
+          grounding_metadata=event.grounding_metadata,
+          partial=event.partial,
+          turn_complete=event.turn_complete,
+          error_code=event.error_code,
+          error_message=event.error_message,
+          interrupted=event.interrupted,
       )
+      if event.content:
+        encoded_content = event.content.model_dump(exclude_none=True)
+        # Workaround for multimodal Content throwing JSON not serializable
+        # error with SQLAlchemy.
+        for p in encoded_content["parts"]:
+          if "inline_data" in p:
+            p["inline_data"]["data"] = (
+                base64.b64encode(p["inline_data"]["data"]).decode("utf-8"),
+            )
+        storage_event.content = encoded_content

       sessionFactory.add(storage_event)

@@ -481,8 +536,7 @@ class DatabaseSessionService(BaseSessionService):
       user_id: str,
       session_id: str,
   ) -> ListEventsResponse:
-
-
+    raise NotImplementedError()

 def convert_event(event: StorageEvent) -> Event:
   """Converts a storage event to an event."""
@@ -497,7 +551,7 @@ def convert_event(event: StorageEvent) -> Event:
   )


-def _extract_state_delta(state: dict):
+def _extract_state_delta(state: dict[str, Any]):
   app_state_delta = {}
   user_state_delta = {}
   session_state_delta = {}
@@ -520,3 +574,14 @@ def _merge_state(app_state, user_state, session_state):
   for key in user_state.keys():
     merged_state[State.USER_PREFIX + key] = user_state[key]
   return merged_state
+
+
+def _decode_content(
+    content: Optional[dict[str, Any]],
+) -> Optional[types.Content]:
+  if not content:
+    return None
+  for p in content["parts"]:
+    if "inline_data" in p:
+      p["inline_data"]["data"] = base64.b64decode(p["inline_data"]["data"][0])
+  return types.Content.model_validate(content)
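The `append_event` and `_decode_content` hunks above encode binary `inline_data` parts as base64 before writing the JSON column and decode them on read; note that the encoder stores a one-element tuple, which is why the decoder indexes `[0]`. A standalone sketch of that round-trip, outside SQLAlchemy, using only the structures visible in the diff:

import base64

# A dumped Content dict with a binary part, as produced by model_dump().
content = {
    "parts": [{"inline_data": {"mime_type": "image/png", "data": b"\x89PNG..."}}]
}

# Write path: bytes -> base64 text so the dict becomes JSON-serializable.
for p in content["parts"]:
  if "inline_data" in p:
    p["inline_data"]["data"] = (
        base64.b64encode(p["inline_data"]["data"]).decode("utf-8"),
    )

# Read path: undo the encoding before validating back into types.Content.
for p in content["parts"]:
  if "inline_data" in p:
    p["inline_data"]["data"] = base64.b64decode(p["inline_data"]["data"][0])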
google/adk/sessions/state.py
CHANGED
@@ -26,7 +26,7 @@ class State:
     """
     Args:
      value: The current value of the state dict.
-      delta: The delta change to the current value that hasn't been
+      delta: The delta change to the current value that hasn't been committed.
     """
     self._value = value
     self._delta = delta
@@ -49,7 +49,7 @@ class State:
     return key in self._value or key in self._delta

   def has_delta(self) -> bool:
-    """Whether the state has pending
+    """Whether the state has pending delta."""
     return bool(self._delta)

   def get(self, key: str, default: Any = None) -> Any:
google/adk/telemetry.py
CHANGED
@@ -16,8 +16,8 @@
 #
 # We expect that the underlying GenAI SDK will provide a certain
 # level of tracing and logging telemetry aligned with Open Telemetry
-# Semantic Conventions (such as logging prompts,
-# properties, etc.) and so the information that is recorded by the
+# Semantic Conventions (such as logging prompts, responses,
+# request properties, etc.) and so the information that is recorded by the
 # Agent Development Kit should be focused on the higher-level
 # constructs of the framework that are not observable by the SDK.

google/adk/tools/agent_tool.py
CHANGED
@@ -45,10 +45,9 @@ class AgentTool(BaseTool):
     skip_summarization: Whether to skip summarization of the agent output.
   """

-  def __init__(self, agent: BaseAgent):
+  def __init__(self, agent: BaseAgent, skip_summarization: bool = False):
     self.agent = agent
-    self.skip_summarization: bool =
-    """Whether to skip summarization of the agent output."""
+    self.skip_summarization: bool = skip_summarization

     super().__init__(name=agent.name, description=agent.description)

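Since `skip_summarization` is now a constructor parameter, it can be set when the tool is built. A minimal sketch; the wrapped agent's name, model, and instruction are placeholders, not taken from this diff:

from google.adk.agents import LlmAgent
from google.adk.tools.agent_tool import AgentTool

# Illustrative child agent; this configuration is hypothetical.
summarizer = LlmAgent(
    name="summarizer",
    model="gemini-2.0-flash",
    instruction="Summarize the text you are given.",
)

# As of 0.3.0, skip_summarization is passed at construction time instead of
# being assigned on the instance afterwards.
tool = AgentTool(agent=summarizer, skip_summarization=True)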
google/adk/tools/application_integration_tool/clients/integration_client.py
CHANGED
@@ -196,11 +196,12 @@ class IntegrationClient:
       action_details = connections_client.get_action_schema(action)
       input_schema = action_details["inputSchema"]
       output_schema = action_details["outputSchema"]
-
+      # Remove spaces from the display name to generate valid spec
+      action_display_name = action_details["displayName"].replace(" ", "")
       operation = "EXECUTE_ACTION"
       if action == "ExecuteCustomQuery":
         connector_spec["components"]["schemas"][
-            f"{
+            f"{action_display_name}_Request"
         ] = connections_client.execute_custom_query_request()
         operation = "EXECUTE_QUERY"
       else:
google/adk/tools/base_tool.py
CHANGED
@@ -53,7 +53,7 @@ def _raise_for_any_of_if_mldev(schema: types.Schema):

 def _update_for_default_if_mldev(schema: types.Schema):
   if schema.default is not None:
-    # TODO(kech): Remove this
+    # TODO(kech): Remove this workaround once mldev supports default value.
     schema.default = None
     logger.warning(
         'Default value is not supported in function declaration schema for'
@@ -291,7 +291,7 @@ def _parse_schema_from_parameter(
     return schema
   raise ValueError(
       f'Failed to parse the parameter {param} of function {func_name} for'
-      ' automatic function calling.Automatic function calling works best with'
+      ' automatic function calling. Automatic function calling works best with'
       ' simpler function signature schema,consider manually parse your'
       f' function declaration for function {func_name}.'
   )
google/adk/tools/google_api_tool/__init__.py
CHANGED
@@ -11,4 +11,77 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
+__all__ = [
+    'bigquery_tool_set',
+    'calendar_tool_set',
+    'gmail_tool_set',
+    'youtube_tool_set',
+    'slides_tool_set',
+    'sheets_tool_set',
+    'docs_tool_set',
+]
+
+# Nothing is imported here automatically
+# Each tool set will only be imported when accessed
+
+_bigquery_tool_set = None
+_calendar_tool_set = None
+_gmail_tool_set = None
+_youtube_tool_set = None
+_slides_tool_set = None
+_sheets_tool_set = None
+_docs_tool_set = None
+
+
+def __getattr__(name):
+  global _bigquery_tool_set, _calendar_tool_set, _gmail_tool_set, _youtube_tool_set, _slides_tool_set, _sheets_tool_set, _docs_tool_set
+
+  match name:
+    case 'bigquery_tool_set':
+      if _bigquery_tool_set is None:
+        from .google_api_tool_sets import bigquery_tool_set as bigquery
+
+        _bigquery_tool_set = bigquery
+      return _bigquery_tool_set
+
+    case 'calendar_tool_set':
+      if _calendar_tool_set is None:
+        from .google_api_tool_sets import calendar_tool_set as calendar
+
+        _calendar_tool_set = calendar
+      return _calendar_tool_set
+
+    case 'gmail_tool_set':
+      if _gmail_tool_set is None:
+        from .google_api_tool_sets import gmail_tool_set as gmail
+
+        _gmail_tool_set = gmail
+      return _gmail_tool_set
+
+    case 'youtube_tool_set':
+      if _youtube_tool_set is None:
+        from .google_api_tool_sets import youtube_tool_set as youtube
+
+        _youtube_tool_set = youtube
+      return _youtube_tool_set
+
+    case 'slides_tool_set':
+      if _slides_tool_set is None:
+        from .google_api_tool_sets import slides_tool_set as slides
+
+        _slides_tool_set = slides
+      return _slides_tool_set
+
+    case 'sheets_tool_set':
+      if _sheets_tool_set is None:
+        from .google_api_tool_sets import sheets_tool_set as sheets
+
+        _sheets_tool_set = sheets
+      return _sheets_tool_set
+
+    case 'docs_tool_set':
+      if _docs_tool_set is None:
+        from .google_api_tool_sets import docs_tool_set as docs
+
+        _docs_tool_set = docs
+      return _docs_tool_set
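With the module-level `__getattr__` hook above, importing the package no longer builds any tool set; each one is constructed on first attribute access and then cached. A small usage sketch, purely illustrative; note the first access may trigger the underlying Google API discovery call:

# Importing the subpackage is cheap: nothing listed in __all__ is built yet.
from google.adk.tools import google_api_tool

# First access goes through __getattr__ and lazily builds the tool set.
calendar_tools = google_api_tool.calendar_tool_set

# Later accesses return the cached module-level instance.
assert google_api_tool.calendar_tool_set is calendar_tools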
google/adk/tools/google_api_tool/google_api_tool_set.py
CHANGED
@@ -11,10 +11,12 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
+from __future__ import annotations
+
 import inspect
 import os
 from typing import Any
-from typing import Dict
 from typing import Final
 from typing import List
 from typing import Optional
@@ -28,6 +30,7 @@ from .googleapi_to_openapi_converter import GoogleApiToOpenApiConverter


 class GoogleApiToolSet:
+  """Google API Tool Set."""

   def __init__(self, tools: List[RestApiTool]):
     self.tools: Final[List[GoogleApiTool]] = [
@@ -45,10 +48,10 @@ class GoogleApiToolSet:

   @staticmethod
   def _load_tool_set_with_oidc_auth(
-      spec_file: str = None,
-      spec_dict:
-      scopes: list[str] = None,
-  ) ->
+      spec_file: Optional[str] = None,
+      spec_dict: Optional[dict[str, Any]] = None,
+      scopes: Optional[list[str]] = None,
+  ) -> OpenAPIToolset:
     spec_str = None
     if spec_file:
       # Get the frame of the caller
@@ -90,18 +93,18 @@ class GoogleApiToolSet:

   @classmethod
   def load_tool_set(
-
+      cls: Type[GoogleApiToolSet],
       api_name: str,
       api_version: str,
-  ) ->
-
+  ) -> GoogleApiToolSet:
     spec_dict = GoogleApiToOpenApiConverter(api_name, api_version).convert()
     scope = list(
         spec_dict['components']['securitySchemes']['oauth2']['flows'][
             'authorizationCode'
         ]['scopes'].keys()
     )[0]
-    return
-
+    return cls(
+        cls._load_tool_set_with_oidc_auth(
             spec_dict=spec_dict, scopes=[scope]
         ).get_tools()
     )
google/adk/tools/google_api_tool/google_api_tool_sets.py
CHANGED
@@ -19,37 +19,94 @@ from .google_api_tool_set import GoogleApiToolSet

 logger = logging.getLogger(__name__)

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
- )
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+_bigquery_tool_set = None
+_calendar_tool_set = None
+_gmail_tool_set = None
+_youtube_tool_set = None
+_slides_tool_set = None
+_sheets_tool_set = None
+_docs_tool_set = None
+
+
+def __getattr__(name):
+  """This method dynamically loads and returns GoogleApiToolSet instances for
+
+  various Google APIs. It uses a lazy loading approach, initializing each
+  tool set only when it is first requested. This avoids unnecessary loading
+  of tool sets that are not used in a given session.
+
+  Args:
+      name (str): The name of the tool set to retrieve (e.g.,
+        "bigquery_tool_set").
+
+  Returns:
+      GoogleApiToolSet: The requested tool set instance.
+
+  Raises:
+      AttributeError: If the requested tool set name is not recognized.
+  """
+  global _bigquery_tool_set, _calendar_tool_set, _gmail_tool_set, _youtube_tool_set, _slides_tool_set, _sheets_tool_set, _docs_tool_set
+
+  match name:
+    case "bigquery_tool_set":
+      if _bigquery_tool_set is None:
+        _bigquery_tool_set = GoogleApiToolSet.load_tool_set(
+            api_name="bigquery",
+            api_version="v2",
+        )
+
+      return _bigquery_tool_set
+
+    case "calendar_tool_set":
+      if _calendar_tool_set is None:
+        _calendar_tool_set = GoogleApiToolSet.load_tool_set(
+            api_name="calendar",
+            api_version="v3",
+        )
+
+      return _calendar_tool_set
+
+    case "gmail_tool_set":
+      if _gmail_tool_set is None:
+        _gmail_tool_set = GoogleApiToolSet.load_tool_set(
+            api_name="gmail",
+            api_version="v1",
+        )
+
+      return _gmail_tool_set
+
+    case "youtube_tool_set":
+      if _youtube_tool_set is None:
+        _youtube_tool_set = GoogleApiToolSet.load_tool_set(
+            api_name="youtube",
+            api_version="v3",
+        )
+
+      return _youtube_tool_set
+
+    case "slides_tool_set":
+      if _slides_tool_set is None:
+        _slides_tool_set = GoogleApiToolSet.load_tool_set(
+            api_name="slides",
+            api_version="v1",
+        )
+
+      return _slides_tool_set
+
+    case "sheets_tool_set":
+      if _sheets_tool_set is None:
+        _sheets_tool_set = GoogleApiToolSet.load_tool_set(
+            api_name="sheets",
+            api_version="v4",
+        )
+
+      return _sheets_tool_set
+
+    case "docs_tool_set":
+      if _docs_tool_set is None:
+        _docs_tool_set = GoogleApiToolSet.load_tool_set(
+            api_name="docs",
+            api_version="v1",
+        )
+
+      return _docs_tool_set
google/adk/tools/google_api_tool/googleapi_to_openapi_converter.py
CHANGED
@@ -311,7 +311,9 @@ class GoogleApiToOpenApiConverter:

       # Determine the actual endpoint path
       # Google often has the format something like 'users.messages.list'
-
+      # flatPath is preferred as it provides the actual path, while path
+      # might contain variables like {+projectId}
+      rest_path = method_data.get("flatPath", method_data.get("path", "/"))
       if not rest_path.startswith("/"):
         rest_path = "/" + rest_path

google/adk/tools/load_artifacts_tool.py
CHANGED
@@ -89,7 +89,7 @@ class LoadArtifactsTool(BaseTool):
         than the function call.
       """])

-    #
+    # Attach the content of the artifacts if the model requests them.
    # This only adds the content to the model request, instead of the session.
    if llm_request.contents and llm_request.contents[-1].parts:
      function_response = llm_request.contents[-1].parts[0].function_response