google-adk 0.2.0__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/adk/agents/base_agent.py +7 -7
- google/adk/agents/callback_context.py +0 -1
- google/adk/agents/llm_agent.py +3 -8
- google/adk/auth/auth_credential.py +2 -1
- google/adk/auth/auth_handler.py +7 -3
- google/adk/cli/browser/index.html +1 -1
- google/adk/cli/browser/{main-ZBO76GRM.js → main-HWIBUY2R.js} +69 -53
- google/adk/cli/cli.py +54 -47
- google/adk/cli/cli_deploy.py +6 -1
- google/adk/cli/cli_eval.py +1 -1
- google/adk/cli/cli_tools_click.py +78 -13
- google/adk/cli/fast_api.py +6 -0
- google/adk/evaluation/agent_evaluator.py +2 -2
- google/adk/evaluation/response_evaluator.py +2 -2
- google/adk/evaluation/trajectory_evaluator.py +4 -5
- google/adk/events/event_actions.py +9 -4
- google/adk/flows/llm_flows/agent_transfer.py +1 -1
- google/adk/flows/llm_flows/base_llm_flow.py +1 -1
- google/adk/flows/llm_flows/contents.py +10 -6
- google/adk/flows/llm_flows/functions.py +38 -18
- google/adk/flows/llm_flows/instructions.py +2 -2
- google/adk/models/gemini_llm_connection.py +2 -2
- google/adk/models/llm_response.py +10 -1
- google/adk/planners/built_in_planner.py +1 -0
- google/adk/sessions/_session_util.py +29 -0
- google/adk/sessions/database_session_service.py +60 -43
- google/adk/sessions/state.py +1 -1
- google/adk/sessions/vertex_ai_session_service.py +7 -5
- google/adk/tools/agent_tool.py +2 -3
- google/adk/tools/application_integration_tool/__init__.py +2 -0
- google/adk/tools/application_integration_tool/application_integration_toolset.py +48 -26
- google/adk/tools/application_integration_tool/clients/connections_client.py +26 -54
- google/adk/tools/application_integration_tool/integration_connector_tool.py +159 -0
- google/adk/tools/function_tool.py +42 -0
- google/adk/tools/google_api_tool/google_api_tool_set.py +12 -9
- google/adk/tools/load_artifacts_tool.py +1 -1
- google/adk/tools/openapi_tool/auth/credential_exchangers/oauth2_exchanger.py +4 -4
- google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py +1 -1
- google/adk/tools/openapi_tool/openapi_spec_parser/operation_parser.py +5 -12
- google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py +46 -8
- google/adk/version.py +1 -1
- {google_adk-0.2.0.dist-info → google_adk-0.4.0.dist-info}/METADATA +28 -9
- {google_adk-0.2.0.dist-info → google_adk-0.4.0.dist-info}/RECORD +46 -44
- {google_adk-0.2.0.dist-info → google_adk-0.4.0.dist-info}/WHEEL +0 -0
- {google_adk-0.2.0.dist-info → google_adk-0.4.0.dist-info}/entry_points.txt +0 -0
- {google_adk-0.2.0.dist-info → google_adk-0.4.0.dist-info}/licenses/LICENSE +0 -0
@@ -14,7 +14,7 @@
|
|
14
14
|
|
15
15
|
from __future__ import annotations
|
16
16
|
|
17
|
-
from typing import Optional
|
17
|
+
from typing import Any, Optional
|
18
18
|
|
19
19
|
from google.genai import types
|
20
20
|
from pydantic import BaseModel
|
@@ -37,6 +37,7 @@ class LlmResponse(BaseModel):
|
|
37
37
|
error_message: Error message if the response is an error.
|
38
38
|
interrupted: Flag indicating that LLM was interrupted when generating the
|
39
39
|
content. Usually it's due to user interruption during a bidi streaming.
|
40
|
+
custom_metadata: The custom metadata of the LlmResponse.
|
40
41
|
"""
|
41
42
|
|
42
43
|
model_config = ConfigDict(extra='forbid')
|
@@ -71,6 +72,14 @@ class LlmResponse(BaseModel):
|
|
71
72
|
Usually it's due to user interruption during a bidi streaming.
|
72
73
|
"""
|
73
74
|
|
75
|
+
custom_metadata: Optional[dict[str, Any]] = None
|
76
|
+
"""The custom metadata of the LlmResponse.
|
77
|
+
|
78
|
+
An optional key-value pair to label an LlmResponse.
|
79
|
+
|
80
|
+
NOTE: the entire dict must be JSON serializable.
|
81
|
+
"""
|
82
|
+
|
74
83
|
@staticmethod
|
75
84
|
def create(
|
76
85
|
generate_content_response: types.GenerateContentResponse,
|
@@ -56,6 +56,7 @@ class BuiltInPlanner(BasePlanner):
|
|
56
56
|
llm_request: The LLM request to apply the thinking config to.
|
57
57
|
"""
|
58
58
|
if self.thinking_config:
|
59
|
+
llm_request.config = llm_request.config or types.GenerateContentConfig()
|
59
60
|
llm_request.config.thinking_config = self.thinking_config
|
60
61
|
|
61
62
|
@override
|
@@ -0,0 +1,29 @@
|
|
1
|
+
"""Utility functions for session service."""
|
2
|
+
|
3
|
+
import base64
|
4
|
+
from typing import Any, Optional
|
5
|
+
|
6
|
+
from google.genai import types
|
7
|
+
|
8
|
+
|
9
|
+
def encode_content(content: types.Content):
|
10
|
+
"""Encodes a content object to a JSON dictionary."""
|
11
|
+
encoded_content = content.model_dump(exclude_none=True)
|
12
|
+
for p in encoded_content["parts"]:
|
13
|
+
if "inline_data" in p:
|
14
|
+
p["inline_data"]["data"] = base64.b64encode(
|
15
|
+
p["inline_data"]["data"]
|
16
|
+
).decode("utf-8")
|
17
|
+
return encoded_content
|
18
|
+
|
19
|
+
|
20
|
+
def decode_content(
|
21
|
+
content: Optional[dict[str, Any]],
|
22
|
+
) -> Optional[types.Content]:
|
23
|
+
"""Decodes a content object from a JSON dictionary."""
|
24
|
+
if not content:
|
25
|
+
return None
|
26
|
+
for p in content["parts"]:
|
27
|
+
if "inline_data" in p:
|
28
|
+
p["inline_data"]["data"] = base64.b64decode(p["inline_data"]["data"])
|
29
|
+
return types.Content.model_validate(content)
|
@@ -11,14 +11,11 @@
|
|
11
11
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
|
-
|
15
|
-
import base64
|
16
14
|
import copy
|
17
15
|
from datetime import datetime
|
18
16
|
import json
|
19
17
|
import logging
|
20
|
-
from typing import Any
|
21
|
-
from typing import Optional
|
18
|
+
from typing import Any, Optional
|
22
19
|
import uuid
|
23
20
|
|
24
21
|
from sqlalchemy import Boolean
|
@@ -27,6 +24,7 @@ from sqlalchemy import Dialect
|
|
27
24
|
from sqlalchemy import ForeignKeyConstraint
|
28
25
|
from sqlalchemy import func
|
29
26
|
from sqlalchemy import Text
|
27
|
+
from sqlalchemy.dialects import mysql
|
30
28
|
from sqlalchemy.dialects import postgresql
|
31
29
|
from sqlalchemy.engine import create_engine
|
32
30
|
from sqlalchemy.engine import Engine
|
@@ -48,6 +46,7 @@ from typing_extensions import override
|
|
48
46
|
from tzlocal import get_localzone
|
49
47
|
|
50
48
|
from ..events.event import Event
|
49
|
+
from . import _session_util
|
51
50
|
from .base_session_service import BaseSessionService
|
52
51
|
from .base_session_service import GetSessionConfig
|
53
52
|
from .base_session_service import ListEventsResponse
|
@@ -58,6 +57,9 @@ from .state import State
|
|
58
57
|
|
59
58
|
logger = logging.getLogger(__name__)
|
60
59
|
|
60
|
+
DEFAULT_MAX_KEY_LENGTH = 128
|
61
|
+
DEFAULT_MAX_VARCHAR_LENGTH = 256
|
62
|
+
|
61
63
|
|
62
64
|
class DynamicJSON(TypeDecorator):
|
63
65
|
"""A JSON-like type that uses JSONB on PostgreSQL and TEXT with JSON
|
@@ -70,15 +72,16 @@ class DynamicJSON(TypeDecorator):
|
|
70
72
|
def load_dialect_impl(self, dialect: Dialect):
|
71
73
|
if dialect.name == "postgresql":
|
72
74
|
return dialect.type_descriptor(postgresql.JSONB)
|
73
|
-
|
74
|
-
|
75
|
+
if dialect.name == "mysql":
|
76
|
+
# Use LONGTEXT for MySQL to address the data too long issue
|
77
|
+
return dialect.type_descriptor(mysql.LONGTEXT)
|
78
|
+
return dialect.type_descriptor(Text) # Default to Text for other dialects
|
75
79
|
|
76
80
|
def process_bind_param(self, value, dialect: Dialect):
|
77
81
|
if value is not None:
|
78
82
|
if dialect.name == "postgresql":
|
79
83
|
return value # JSONB handles dict directly
|
80
|
-
|
81
|
-
return json.dumps(value) # Serialize to JSON string for TEXT
|
84
|
+
return json.dumps(value) # Serialize to JSON string for TEXT
|
82
85
|
return value
|
83
86
|
|
84
87
|
def process_result_value(self, value, dialect: Dialect):
|
@@ -92,17 +95,25 @@ class DynamicJSON(TypeDecorator):
|
|
92
95
|
|
93
96
|
class Base(DeclarativeBase):
|
94
97
|
"""Base class for database tables."""
|
98
|
+
|
95
99
|
pass
|
96
100
|
|
97
101
|
|
98
102
|
class StorageSession(Base):
|
99
103
|
"""Represents a session stored in the database."""
|
104
|
+
|
100
105
|
__tablename__ = "sessions"
|
101
106
|
|
102
|
-
app_name: Mapped[str] = mapped_column(
|
103
|
-
|
107
|
+
app_name: Mapped[str] = mapped_column(
|
108
|
+
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
|
109
|
+
)
|
110
|
+
user_id: Mapped[str] = mapped_column(
|
111
|
+
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
|
112
|
+
)
|
104
113
|
id: Mapped[str] = mapped_column(
|
105
|
-
String
|
114
|
+
String(DEFAULT_MAX_KEY_LENGTH),
|
115
|
+
primary_key=True,
|
116
|
+
default=lambda: str(uuid.uuid4()),
|
106
117
|
)
|
107
118
|
|
108
119
|
state: Mapped[MutableDict[str, Any]] = mapped_column(
|
@@ -125,18 +136,29 @@ class StorageSession(Base):
|
|
125
136
|
|
126
137
|
class StorageEvent(Base):
|
127
138
|
"""Represents an event stored in the database."""
|
139
|
+
|
128
140
|
__tablename__ = "events"
|
129
141
|
|
130
|
-
id: Mapped[str] = mapped_column(
|
131
|
-
|
132
|
-
|
133
|
-
|
142
|
+
id: Mapped[str] = mapped_column(
|
143
|
+
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
|
144
|
+
)
|
145
|
+
app_name: Mapped[str] = mapped_column(
|
146
|
+
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
|
147
|
+
)
|
148
|
+
user_id: Mapped[str] = mapped_column(
|
149
|
+
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
|
150
|
+
)
|
151
|
+
session_id: Mapped[str] = mapped_column(
|
152
|
+
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
|
153
|
+
)
|
134
154
|
|
135
|
-
invocation_id: Mapped[str] = mapped_column(String)
|
136
|
-
author: Mapped[str] = mapped_column(String)
|
137
|
-
branch: Mapped[str] = mapped_column(
|
155
|
+
invocation_id: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
|
156
|
+
author: Mapped[str] = mapped_column(String(DEFAULT_MAX_VARCHAR_LENGTH))
|
157
|
+
branch: Mapped[str] = mapped_column(
|
158
|
+
String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True
|
159
|
+
)
|
138
160
|
timestamp: Mapped[DateTime] = mapped_column(DateTime(), default=func.now())
|
139
|
-
content: Mapped[dict[str, Any]] = mapped_column(DynamicJSON)
|
161
|
+
content: Mapped[dict[str, Any]] = mapped_column(DynamicJSON, nullable=True)
|
140
162
|
actions: Mapped[MutableDict[str, Any]] = mapped_column(PickleType)
|
141
163
|
|
142
164
|
long_running_tool_ids_json: Mapped[Optional[str]] = mapped_column(
|
@@ -147,8 +169,10 @@ class StorageEvent(Base):
|
|
147
169
|
)
|
148
170
|
partial: Mapped[bool] = mapped_column(Boolean, nullable=True)
|
149
171
|
turn_complete: Mapped[bool] = mapped_column(Boolean, nullable=True)
|
150
|
-
error_code: Mapped[str] = mapped_column(
|
151
|
-
|
172
|
+
error_code: Mapped[str] = mapped_column(
|
173
|
+
String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True
|
174
|
+
)
|
175
|
+
error_message: Mapped[str] = mapped_column(String(1024), nullable=True)
|
152
176
|
interrupted: Mapped[bool] = mapped_column(Boolean, nullable=True)
|
153
177
|
|
154
178
|
storage_session: Mapped[StorageSession] = relationship(
|
@@ -182,9 +206,12 @@ class StorageEvent(Base):
|
|
182
206
|
|
183
207
|
class StorageAppState(Base):
|
184
208
|
"""Represents an app state stored in the database."""
|
209
|
+
|
185
210
|
__tablename__ = "app_states"
|
186
211
|
|
187
|
-
app_name: Mapped[str] = mapped_column(
|
212
|
+
app_name: Mapped[str] = mapped_column(
|
213
|
+
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
|
214
|
+
)
|
188
215
|
state: Mapped[MutableDict[str, Any]] = mapped_column(
|
189
216
|
MutableDict.as_mutable(DynamicJSON), default={}
|
190
217
|
)
|
@@ -192,13 +219,17 @@ class StorageAppState(Base):
|
|
192
219
|
DateTime(), default=func.now(), onupdate=func.now()
|
193
220
|
)
|
194
221
|
|
195
|
-
|
196
222
|
class StorageUserState(Base):
|
197
223
|
"""Represents a user state stored in the database."""
|
224
|
+
|
198
225
|
__tablename__ = "user_states"
|
199
226
|
|
200
|
-
app_name: Mapped[str] = mapped_column(
|
201
|
-
|
227
|
+
app_name: Mapped[str] = mapped_column(
|
228
|
+
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
|
229
|
+
)
|
230
|
+
user_id: Mapped[str] = mapped_column(
|
231
|
+
String(DEFAULT_MAX_KEY_LENGTH), primary_key=True
|
232
|
+
)
|
202
233
|
state: Mapped[MutableDict[str, Any]] = mapped_column(
|
203
234
|
MutableDict.as_mutable(DynamicJSON), default={}
|
204
235
|
)
|
@@ -217,7 +248,7 @@ class DatabaseSessionService(BaseSessionService):
|
|
217
248
|
"""
|
218
249
|
# 1. Create DB engine for db connection
|
219
250
|
# 2. Create all tables based on schema
|
220
|
-
# 3. Initialize all
|
251
|
+
# 3. Initialize all properties
|
221
252
|
|
222
253
|
try:
|
223
254
|
db_engine = create_engine(db_url)
|
@@ -353,6 +384,7 @@ class DatabaseSessionService(BaseSessionService):
|
|
353
384
|
else True
|
354
385
|
)
|
355
386
|
.limit(config.num_recent_events if config else None)
|
387
|
+
.order_by(StorageEvent.timestamp.asc())
|
356
388
|
.all()
|
357
389
|
)
|
358
390
|
|
@@ -383,7 +415,7 @@ class DatabaseSessionService(BaseSessionService):
|
|
383
415
|
author=e.author,
|
384
416
|
branch=e.branch,
|
385
417
|
invocation_id=e.invocation_id,
|
386
|
-
content=
|
418
|
+
content=_session_util.decode_content(e.content),
|
387
419
|
actions=e.actions,
|
388
420
|
timestamp=e.timestamp.timestamp(),
|
389
421
|
long_running_tool_ids=e.long_running_tool_ids,
|
@@ -506,15 +538,7 @@ class DatabaseSessionService(BaseSessionService):
|
|
506
538
|
interrupted=event.interrupted,
|
507
539
|
)
|
508
540
|
if event.content:
|
509
|
-
|
510
|
-
# Workaround for multimodal Content throwing JSON not serializable
|
511
|
-
# error with SQLAlchemy.
|
512
|
-
for p in encoded_content["parts"]:
|
513
|
-
if "inline_data" in p:
|
514
|
-
p["inline_data"]["data"] = (
|
515
|
-
base64.b64encode(p["inline_data"]["data"]).decode("utf-8"),
|
516
|
-
)
|
517
|
-
storage_event.content = encoded_content
|
541
|
+
storage_event.content = _session_util.encode_content(event.content)
|
518
542
|
|
519
543
|
sessionFactory.add(storage_event)
|
520
544
|
|
@@ -574,10 +598,3 @@ def _merge_state(app_state, user_state, session_state):
|
|
574
598
|
for key in user_state.keys():
|
575
599
|
merged_state[State.USER_PREFIX + key] = user_state[key]
|
576
600
|
return merged_state
|
577
|
-
|
578
|
-
|
579
|
-
def _decode_content(content: dict[str, Any]) -> dict[str, Any]:
|
580
|
-
for p in content["parts"]:
|
581
|
-
if "inline_data" in p:
|
582
|
-
p["inline_data"]["data"] = base64.b64decode(p["inline_data"]["data"][0])
|
583
|
-
return content
|
google/adk/sessions/state.py
CHANGED
@@ -26,7 +26,7 @@ class State:
|
|
26
26
|
"""
|
27
27
|
Args:
|
28
28
|
value: The current value of the state dict.
|
29
|
-
delta: The delta change to the current value that hasn't been
|
29
|
+
delta: The delta change to the current value that hasn't been committed.
|
30
30
|
"""
|
31
31
|
self._value = value
|
32
32
|
self._delta = delta
|
@@ -14,21 +14,23 @@
|
|
14
14
|
import logging
|
15
15
|
import re
|
16
16
|
import time
|
17
|
-
from typing import Any
|
18
|
-
from typing import Optional
|
17
|
+
from typing import Any, Optional
|
19
18
|
|
20
|
-
from dateutil
|
19
|
+
from dateutil import parser
|
21
20
|
from google import genai
|
22
21
|
from typing_extensions import override
|
23
22
|
|
24
23
|
from ..events.event import Event
|
25
24
|
from ..events.event_actions import EventActions
|
25
|
+
from . import _session_util
|
26
26
|
from .base_session_service import BaseSessionService
|
27
27
|
from .base_session_service import GetSessionConfig
|
28
28
|
from .base_session_service import ListEventsResponse
|
29
29
|
from .base_session_service import ListSessionsResponse
|
30
30
|
from .session import Session
|
31
31
|
|
32
|
+
|
33
|
+
isoparse = parser.isoparse
|
32
34
|
logger = logging.getLogger(__name__)
|
33
35
|
|
34
36
|
|
@@ -289,7 +291,7 @@ def _convert_event_to_json(event: Event):
|
|
289
291
|
}
|
290
292
|
event_json['actions'] = actions_json
|
291
293
|
if event.content:
|
292
|
-
event_json['content'] = event.content
|
294
|
+
event_json['content'] = _session_util.encode_content(event.content)
|
293
295
|
if event.error_code:
|
294
296
|
event_json['error_code'] = event.error_code
|
295
297
|
if event.error_message:
|
@@ -316,7 +318,7 @@ def _from_api_event(api_event: dict) -> Event:
|
|
316
318
|
invocation_id=api_event['invocationId'],
|
317
319
|
author=api_event['author'],
|
318
320
|
actions=event_actions,
|
319
|
-
content=api_event.get('content', None),
|
321
|
+
content=_session_util.decode_content(api_event.get('content', None)),
|
320
322
|
timestamp=isoparse(api_event['timestamp']).timestamp(),
|
321
323
|
error_code=api_event.get('errorCode', None),
|
322
324
|
error_message=api_event.get('errorMessage', None),
|
google/adk/tools/agent_tool.py
CHANGED
@@ -45,10 +45,9 @@ class AgentTool(BaseTool):
|
|
45
45
|
skip_summarization: Whether to skip summarization of the agent output.
|
46
46
|
"""
|
47
47
|
|
48
|
-
def __init__(self, agent: BaseAgent):
|
48
|
+
def __init__(self, agent: BaseAgent, skip_summarization: bool = False):
|
49
49
|
self.agent = agent
|
50
|
-
self.skip_summarization: bool =
|
51
|
-
"""Whether to skip summarization of the agent output."""
|
50
|
+
self.skip_summarization: bool = skip_summarization
|
52
51
|
|
53
52
|
super().__init__(name=agent.name, description=agent.description)
|
54
53
|
|
@@ -13,7 +13,9 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
15
|
from .application_integration_toolset import ApplicationIntegrationToolset
|
16
|
+
from .integration_connector_tool import IntegrationConnectorTool
|
16
17
|
|
17
18
|
__all__ = [
|
18
19
|
'ApplicationIntegrationToolset',
|
20
|
+
'IntegrationConnectorTool',
|
19
21
|
]
|
@@ -12,21 +12,21 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
from typing import Dict
|
16
|
-
from typing import List
|
17
|
-
from typing import Optional
|
15
|
+
from typing import Dict, List, Optional
|
18
16
|
|
19
17
|
from fastapi.openapi.models import HTTPBearer
|
20
|
-
from google.adk.tools.application_integration_tool.clients.connections_client import ConnectionsClient
|
21
|
-
from google.adk.tools.application_integration_tool.clients.integration_client import IntegrationClient
|
22
|
-
from google.adk.tools.openapi_tool.auth.auth_helpers import service_account_scheme_credential
|
23
|
-
from google.adk.tools.openapi_tool.openapi_spec_parser.openapi_toolset import OpenAPIToolset
|
24
|
-
from google.adk.tools.openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
|
25
18
|
|
26
19
|
from ...auth.auth_credential import AuthCredential
|
27
20
|
from ...auth.auth_credential import AuthCredentialTypes
|
28
21
|
from ...auth.auth_credential import ServiceAccount
|
29
22
|
from ...auth.auth_credential import ServiceAccountCredential
|
23
|
+
from ..openapi_tool.auth.auth_helpers import service_account_scheme_credential
|
24
|
+
from ..openapi_tool.openapi_spec_parser.openapi_spec_parser import OpenApiSpecParser
|
25
|
+
from ..openapi_tool.openapi_spec_parser.openapi_toolset import OpenAPIToolset
|
26
|
+
from ..openapi_tool.openapi_spec_parser.rest_api_tool import RestApiTool
|
27
|
+
from .clients.connections_client import ConnectionsClient
|
28
|
+
from .clients.integration_client import IntegrationClient
|
29
|
+
from .integration_connector_tool import IntegrationConnectorTool
|
30
30
|
|
31
31
|
|
32
32
|
# TODO(cheliu): Apply a common toolset interface
|
@@ -168,6 +168,7 @@ class ApplicationIntegrationToolset:
|
|
168
168
|
actions,
|
169
169
|
service_account_json,
|
170
170
|
)
|
171
|
+
connection_details = {}
|
171
172
|
if integration and trigger:
|
172
173
|
spec = integration_client.get_openapi_spec_for_integration()
|
173
174
|
elif connection and (entity_operations or actions):
|
@@ -175,16 +176,6 @@ class ApplicationIntegrationToolset:
|
|
175
176
|
project, location, connection, service_account_json
|
176
177
|
)
|
177
178
|
connection_details = connections_client.get_connection_details()
|
178
|
-
tool_instructions += (
|
179
|
-
"ALWAYS use serviceName = "
|
180
|
-
+ connection_details["serviceName"]
|
181
|
-
+ ", host = "
|
182
|
-
+ connection_details["host"]
|
183
|
-
+ " and the connection name = "
|
184
|
-
+ f"projects/{project}/locations/{location}/connections/{connection} when"
|
185
|
-
" using this tool"
|
186
|
-
+ ". DONOT ask the user for these values as you already have those."
|
187
|
-
)
|
188
179
|
spec = integration_client.get_openapi_spec_for_connection(
|
189
180
|
tool_name,
|
190
181
|
tool_instructions,
|
@@ -194,9 +185,9 @@ class ApplicationIntegrationToolset:
|
|
194
185
|
"Either (integration and trigger) or (connection and"
|
195
186
|
" (entity_operations or actions)) should be provided."
|
196
187
|
)
|
197
|
-
self._parse_spec_to_tools(spec)
|
188
|
+
self._parse_spec_to_tools(spec, connection_details)
|
198
189
|
|
199
|
-
def _parse_spec_to_tools(self, spec_dict):
|
190
|
+
def _parse_spec_to_tools(self, spec_dict, connection_details):
|
200
191
|
"""Parses the spec dict to a list of RestApiTool."""
|
201
192
|
if self.service_account_json:
|
202
193
|
sa_credential = ServiceAccountCredential.model_validate_json(
|
@@ -218,12 +209,43 @@ class ApplicationIntegrationToolset:
|
|
218
209
|
),
|
219
210
|
)
|
220
211
|
auth_scheme = HTTPBearer(bearerFormat="JWT")
|
221
|
-
|
222
|
-
|
223
|
-
|
224
|
-
|
225
|
-
|
226
|
-
|
212
|
+
|
213
|
+
if self.integration and self.trigger:
|
214
|
+
tools = OpenAPIToolset(
|
215
|
+
spec_dict=spec_dict,
|
216
|
+
auth_credential=auth_credential,
|
217
|
+
auth_scheme=auth_scheme,
|
218
|
+
).get_tools()
|
219
|
+
for tool in tools:
|
220
|
+
self.generated_tools[tool.name] = tool
|
221
|
+
return
|
222
|
+
|
223
|
+
operations = OpenApiSpecParser().parse(spec_dict)
|
224
|
+
|
225
|
+
for open_api_operation in operations:
|
226
|
+
operation = getattr(open_api_operation.operation, "x-operation")
|
227
|
+
entity = None
|
228
|
+
action = None
|
229
|
+
if hasattr(open_api_operation.operation, "x-entity"):
|
230
|
+
entity = getattr(open_api_operation.operation, "x-entity")
|
231
|
+
elif hasattr(open_api_operation.operation, "x-action"):
|
232
|
+
action = getattr(open_api_operation.operation, "x-action")
|
233
|
+
rest_api_tool = RestApiTool.from_parsed_operation(open_api_operation)
|
234
|
+
if auth_scheme:
|
235
|
+
rest_api_tool.configure_auth_scheme(auth_scheme)
|
236
|
+
if auth_credential:
|
237
|
+
rest_api_tool.configure_auth_credential(auth_credential)
|
238
|
+
tool = IntegrationConnectorTool(
|
239
|
+
name=rest_api_tool.name,
|
240
|
+
description=rest_api_tool.description,
|
241
|
+
connection_name=connection_details["name"],
|
242
|
+
connection_host=connection_details["host"],
|
243
|
+
connection_service_name=connection_details["serviceName"],
|
244
|
+
entity=entity,
|
245
|
+
action=action,
|
246
|
+
operation=operation,
|
247
|
+
rest_api_tool=rest_api_tool,
|
248
|
+
)
|
227
249
|
self.generated_tools[tool.name] = tool
|
228
250
|
|
229
251
|
def get_tools(self) -> List[RestApiTool]:
|
@@ -68,12 +68,14 @@ class ConnectionsClient:
|
|
68
68
|
response = self._execute_api_call(url)
|
69
69
|
|
70
70
|
connection_data = response.json()
|
71
|
+
connection_name = connection_data.get("name", "")
|
71
72
|
service_name = connection_data.get("serviceDirectory", "")
|
72
73
|
host = connection_data.get("host", "")
|
73
74
|
if host:
|
74
75
|
service_name = connection_data.get("tlsServiceDirectory", "")
|
75
76
|
auth_override_enabled = connection_data.get("authOverrideEnabled", False)
|
76
77
|
return {
|
78
|
+
"name": connection_name,
|
77
79
|
"serviceName": service_name,
|
78
80
|
"host": host,
|
79
81
|
"authOverrideEnabled": auth_override_enabled,
|
@@ -291,13 +293,9 @@ class ConnectionsClient:
|
|
291
293
|
tool_name: str = "",
|
292
294
|
tool_instructions: str = "",
|
293
295
|
) -> Dict[str, Any]:
|
294
|
-
description =
|
295
|
-
f"Use this tool with" f' action = "{action}" and'
|
296
|
-
) + f' operation = "{operation}" only. Dont ask these values from user.'
|
296
|
+
description = f"Use this tool to execute {action}"
|
297
297
|
if operation == "EXECUTE_QUERY":
|
298
|
-
description
|
299
|
-
(f"Use this tool with" f' action = "{action}" and')
|
300
|
-
+ f' operation = "{operation}" only. Dont ask these values from user.'
|
298
|
+
description += (
|
301
299
|
" Use pageSize = 50 and timeout = 120 until user specifies a"
|
302
300
|
" different value otherwise. If user provides a query in natural"
|
303
301
|
" language, convert it to SQL query and then execute it using the"
|
@@ -308,6 +306,8 @@ class ConnectionsClient:
|
|
308
306
|
"summary": f"{action_display_name}",
|
309
307
|
"description": f"{description} {tool_instructions}",
|
310
308
|
"operationId": f"{tool_name}_{action_display_name}",
|
309
|
+
"x-action": f"{action}",
|
310
|
+
"x-operation": f"{operation}",
|
311
311
|
"requestBody": {
|
312
312
|
"content": {
|
313
313
|
"application/json": {
|
@@ -347,16 +347,12 @@ class ConnectionsClient:
|
|
347
347
|
"post": {
|
348
348
|
"summary": f"List {entity}",
|
349
349
|
"description": (
|
350
|
-
f"Returns
|
351
|
-
|
352
|
-
|
353
|
-
" from"
|
354
|
-
+ ' user. Always use ""'
|
355
|
-
+ ' as filter clause and ""'
|
356
|
-
+ " as page token and 50 as page size until user specifies a"
|
357
|
-
" different value otherwise. Use single quotes for strings in"
|
358
|
-
f" filter clause. {tool_instructions}"
|
350
|
+
f"""Returns the list of {entity} data. If the page token was available in the response, let users know there are more records available. Ask if the user wants to fetch the next page of results. When passing filter use the
|
351
|
+
following format: `field_name1='value1' AND field_name2='value2'
|
352
|
+
`. {tool_instructions}"""
|
359
353
|
),
|
354
|
+
"x-operation": "LIST_ENTITIES",
|
355
|
+
"x-entity": f"{entity}",
|
360
356
|
"operationId": f"{tool_name}_list_{entity}",
|
361
357
|
"requestBody": {
|
362
358
|
"content": {
|
@@ -401,14 +397,11 @@ class ConnectionsClient:
|
|
401
397
|
"post": {
|
402
398
|
"summary": f"Get {entity}",
|
403
399
|
"description": (
|
404
|
-
|
405
|
-
f"Returns the details of the {entity}. Use this tool with"
|
406
|
-
f' entity = "{entity}" and'
|
407
|
-
)
|
408
|
-
+ ' operation = "GET_ENTITY" only. Dont ask these values from'
|
409
|
-
f" user. {tool_instructions}"
|
400
|
+
f"Returns the details of the {entity}. {tool_instructions}"
|
410
401
|
),
|
411
402
|
"operationId": f"{tool_name}_get_{entity}",
|
403
|
+
"x-operation": "GET_ENTITY",
|
404
|
+
"x-entity": f"{entity}",
|
412
405
|
"requestBody": {
|
413
406
|
"content": {
|
414
407
|
"application/json": {
|
@@ -445,17 +438,10 @@ class ConnectionsClient:
|
|
445
438
|
) -> Dict[str, Any]:
|
446
439
|
return {
|
447
440
|
"post": {
|
448
|
-
"summary": f"
|
449
|
-
"description":
|
450
|
-
|
451
|
-
|
452
|
-
f' entity = "{entity}" and'
|
453
|
-
)
|
454
|
-
+ ' operation = "CREATE_ENTITY" only. Dont ask these values'
|
455
|
-
" from"
|
456
|
-
+ " user. Follow the schema of the entity provided in the"
|
457
|
-
f" instructions to create {entity}. {tool_instructions}"
|
458
|
-
),
|
441
|
+
"summary": f"Creates a new {entity}",
|
442
|
+
"description": f"Creates a new {entity}. {tool_instructions}",
|
443
|
+
"x-operation": "CREATE_ENTITY",
|
444
|
+
"x-entity": f"{entity}",
|
459
445
|
"operationId": f"{tool_name}_create_{entity}",
|
460
446
|
"requestBody": {
|
461
447
|
"content": {
|
@@ -491,18 +477,10 @@ class ConnectionsClient:
|
|
491
477
|
) -> Dict[str, Any]:
|
492
478
|
return {
|
493
479
|
"post": {
|
494
|
-
"summary": f"
|
495
|
-
"description":
|
496
|
-
|
497
|
-
|
498
|
-
f' entity = "{entity}" and'
|
499
|
-
)
|
500
|
-
+ ' operation = "UPDATE_ENTITY" only. Dont ask these values'
|
501
|
-
" from"
|
502
|
-
+ " user. Use entityId to uniquely identify the entity to"
|
503
|
-
" update. Follow the schema of the entity provided in the"
|
504
|
-
f" instructions to update {entity}. {tool_instructions}"
|
505
|
-
),
|
480
|
+
"summary": f"Updates the {entity}",
|
481
|
+
"description": f"Updates the {entity}. {tool_instructions}",
|
482
|
+
"x-operation": "UPDATE_ENTITY",
|
483
|
+
"x-entity": f"{entity}",
|
506
484
|
"operationId": f"{tool_name}_update_{entity}",
|
507
485
|
"requestBody": {
|
508
486
|
"content": {
|
@@ -538,16 +516,10 @@ class ConnectionsClient:
|
|
538
516
|
) -> Dict[str, Any]:
|
539
517
|
return {
|
540
518
|
"post": {
|
541
|
-
"summary": f"Delete {entity}",
|
542
|
-
"description":
|
543
|
-
|
544
|
-
|
545
|
-
f' entity = "{entity}" and'
|
546
|
-
)
|
547
|
-
+ ' operation = "DELETE_ENTITY" only. Dont ask these values'
|
548
|
-
" from"
|
549
|
-
f" user. {tool_instructions}"
|
550
|
-
),
|
519
|
+
"summary": f"Delete the {entity}",
|
520
|
+
"description": f"Deletes the {entity}. {tool_instructions}",
|
521
|
+
"x-operation": "DELETE_ENTITY",
|
522
|
+
"x-entity": f"{entity}",
|
551
523
|
"operationId": f"{tool_name}_delete_{entity}",
|
552
524
|
"requestBody": {
|
553
525
|
"content": {
|