google-adk 1.1.0__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only.
- google/adk/agents/base_agent.py +0 -2
- google/adk/agents/invocation_context.py +3 -3
- google/adk/agents/parallel_agent.py +17 -7
- google/adk/agents/sequential_agent.py +8 -8
- google/adk/auth/auth_preprocessor.py +18 -17
- google/adk/cli/agent_graph.py +165 -23
- google/adk/cli/browser/assets/ADK-512-color.svg +9 -0
- google/adk/cli/browser/index.html +2 -2
- google/adk/cli/browser/{main-PKDNKWJE.js → main-CS5OLUMF.js} +59 -59
- google/adk/cli/browser/polyfills-FFHMD2TL.js +17 -0
- google/adk/cli/cli.py +9 -9
- google/adk/cli/cli_deploy.py +157 -0
- google/adk/cli/cli_tools_click.py +228 -99
- google/adk/cli/fast_api.py +119 -34
- google/adk/cli/utils/agent_loader.py +60 -44
- google/adk/cli/utils/envs.py +1 -1
- google/adk/code_executors/unsafe_local_code_executor.py +11 -0
- google/adk/errors/__init__.py +13 -0
- google/adk/errors/not_found_error.py +28 -0
- google/adk/evaluation/agent_evaluator.py +1 -1
- google/adk/evaluation/eval_sets_manager.py +36 -6
- google/adk/evaluation/evaluation_generator.py +5 -4
- google/adk/evaluation/local_eval_sets_manager.py +101 -6
- google/adk/flows/llm_flows/agent_transfer.py +2 -2
- google/adk/flows/llm_flows/base_llm_flow.py +19 -0
- google/adk/flows/llm_flows/contents.py +4 -4
- google/adk/flows/llm_flows/functions.py +140 -127
- google/adk/memory/vertex_ai_rag_memory_service.py +2 -2
- google/adk/models/anthropic_llm.py +7 -10
- google/adk/models/google_llm.py +46 -18
- google/adk/models/lite_llm.py +63 -26
- google/adk/py.typed +0 -0
- google/adk/sessions/_session_util.py +10 -16
- google/adk/sessions/database_session_service.py +81 -66
- google/adk/sessions/vertex_ai_session_service.py +32 -6
- google/adk/telemetry.py +91 -24
- google/adk/tools/_automatic_function_calling_util.py +31 -25
- google/adk/tools/{function_parameter_parse_util.py → _function_parameter_parse_util.py} +9 -3
- google/adk/tools/_gemini_schema_util.py +158 -0
- google/adk/tools/apihub_tool/apihub_toolset.py +3 -2
- google/adk/tools/application_integration_tool/clients/connections_client.py +7 -0
- google/adk/tools/application_integration_tool/integration_connector_tool.py +5 -7
- google/adk/tools/base_tool.py +4 -8
- google/adk/tools/bigquery/__init__.py +11 -1
- google/adk/tools/bigquery/bigquery_credentials.py +9 -4
- google/adk/tools/bigquery/bigquery_toolset.py +86 -0
- google/adk/tools/bigquery/client.py +33 -0
- google/adk/tools/bigquery/metadata_tool.py +249 -0
- google/adk/tools/bigquery/query_tool.py +76 -0
- google/adk/tools/function_tool.py +4 -4
- google/adk/tools/langchain_tool.py +20 -13
- google/adk/tools/load_memory_tool.py +1 -0
- google/adk/tools/mcp_tool/conversion_utils.py +4 -2
- google/adk/tools/mcp_tool/mcp_session_manager.py +63 -5
- google/adk/tools/mcp_tool/mcp_tool.py +3 -2
- google/adk/tools/mcp_tool/mcp_toolset.py +15 -8
- google/adk/tools/openapi_tool/common/common.py +4 -43
- google/adk/tools/openapi_tool/openapi_spec_parser/__init__.py +0 -2
- google/adk/tools/openapi_tool/openapi_spec_parser/openapi_spec_parser.py +4 -2
- google/adk/tools/openapi_tool/openapi_spec_parser/operation_parser.py +4 -2
- google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py +7 -127
- google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py +2 -7
- google/adk/tools/transfer_to_agent_tool.py +8 -1
- google/adk/tools/vertex_ai_search_tool.py +8 -1
- google/adk/utils/variant_utils.py +51 -0
- google/adk/version.py +1 -1
- {google_adk-1.1.0.dist-info → google_adk-1.2.0.dist-info}/METADATA +7 -7
- {google_adk-1.1.0.dist-info → google_adk-1.2.0.dist-info}/RECORD +71 -61
- google/adk/cli/browser/polyfills-B6TNHZQ6.js +0 -17
- {google_adk-1.1.0.dist-info → google_adk-1.2.0.dist-info}/WHEEL +0 -0
- {google_adk-1.1.0.dist-info → google_adk-1.2.0.dist-info}/entry_points.txt +0 -0
- {google_adk-1.1.0.dist-info → google_adk-1.2.0.dist-info}/licenses/LICENSE +0 -0
google/adk/models/lite_llm.py
CHANGED
```diff
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from __future__ import annotations
 
 import base64
 import json
@@ -62,6 +63,7 @@ class FunctionChunk(BaseModel):
   id: Optional[str]
   name: Optional[str]
   args: Optional[str]
+  index: Optional[int] = 0
 
 
 class TextChunk(BaseModel):
@@ -136,7 +138,7 @@ def _safe_json_serialize(obj) -> str:
 
   try:
     # Try direct JSON serialization first
-    return json.dumps(obj)
+    return json.dumps(obj, ensure_ascii=False)
   except (TypeError, OverflowError):
     return str(obj)
 
@@ -186,7 +188,7 @@ def _content_to_message_param(
             id=part.function_call.id,
             function=Function(
                 name=part.function_call.name,
-                arguments=
+                arguments=_safe_json_serialize(part.function_call.args),
             ),
         )
     )
@@ -194,6 +196,14 @@ def _content_to_message_param(
       content_present = True
 
   final_content = message_content if content_present else None
+  if final_content and isinstance(final_content, list):
+    # when the content is a single text object, we can use it directly.
+    # this is needed for ollama_chat provider which fails if content is a list
+    final_content = (
+        final_content[0].get("text", "")
+        if final_content[0].get("type", None) == "text"
+        else final_content
+    )
 
   return ChatCompletionAssistantMessage(
       role=role,
@@ -386,6 +396,7 @@ def _model_response_to_chunk(
           id=tool_call.id,
           name=tool_call.function.name,
           args=tool_call.function.arguments,
+          index=tool_call.index,
       ), finish_reason
 
   if finish_reason and not (
@@ -477,7 +488,7 @@ def _get_completion_inputs(
     llm_request: The LlmRequest to convert.
 
   Returns:
-    The litellm inputs (message list and
+    The litellm inputs (message list, tool dictionary and response format).
   """
   messages = []
   for content in llm_request.contents or []:
@@ -506,7 +517,13 @@ def _get_completion_inputs(
         _function_declaration_to_tool_param(tool)
         for tool in llm_request.config.tools[0].function_declarations
     ]
-  return messages, tools
+
+  response_format = None
+
+  if llm_request.config.response_schema:
+    response_format = llm_request.config.response_schema
+
+  return messages, tools, response_format
 
 
 def _build_function_declaration_log(
@@ -643,33 +660,48 @@ class LiteLlm(BaseLlm):
     self._maybe_append_user_content(llm_request)
     logger.debug(_build_request_log(llm_request))
 
-    messages, tools = _get_completion_inputs(llm_request)
+    messages, tools, response_format = _get_completion_inputs(llm_request)
 
     completion_args = {
         "model": self.model,
         "messages": messages,
        "tools": tools,
+        "response_format": response_format,
     }
     completion_args.update(self._additional_args)
 
     if stream:
       text = ""
-      function_name = ""
-      function_args = ""
-      function_id = None
+      # Track function calls by index
+      function_calls = {}  # index -> {name, args, id}
       completion_args["stream"] = True
       aggregated_llm_response = None
       aggregated_llm_response_with_tool_call = None
       usage_metadata = None
-
+      fallback_index = 0
       for part in self.llm_client.completion(**completion_args):
         for chunk, finish_reason in _model_response_to_chunk(part):
           if isinstance(chunk, FunctionChunk):
+            index = chunk.index or fallback_index
+            if index not in function_calls:
+              function_calls[index] = {"name": "", "args": "", "id": None}
+
             if chunk.name:
-              function_name += chunk.name
+              function_calls[index]["name"] += chunk.name
             if chunk.args:
-              function_args += chunk.args
-            function_id = chunk.id or function_id
+              function_calls[index]["args"] += chunk.args
+
+              # check if args is completed (workaround for improper chunk
+              # indexing)
+              try:
+                json.loads(function_calls[index]["args"])
+                fallback_index += 1
+              except json.JSONDecodeError:
+                pass
+
+            function_calls[index]["id"] = (
+                chunk.id or function_calls[index]["id"] or str(index)
+            )
           elif isinstance(chunk, TextChunk):
            text += chunk.text
            yield _message_to_generate_content_response(
@@ -686,28 +718,33 @@ class LiteLlm(BaseLlm):
                total_token_count=chunk.total_tokens,
            )
 
-          if
+          if (
+              finish_reason == "tool_calls" or finish_reason == "stop"
+          ) and function_calls:
+            tool_calls = []
+            for index, func_data in function_calls.items():
+              if func_data["id"]:
+                tool_calls.append(
+                    ChatCompletionMessageToolCall(
+                        type="function",
+                        id=func_data["id"],
+                        function=Function(
+                            name=func_data["name"],
+                            arguments=func_data["args"],
+                            index=index,
+                        ),
+                    )
+                )
            aggregated_llm_response_with_tool_call = (
                _message_to_generate_content_response(
                    ChatCompletionAssistantMessage(
                        role="assistant",
                        content="",
-                        tool_calls=
-                            ChatCompletionMessageToolCall(
-                                type="function",
-                                id=function_id,
-                                function=Function(
-                                    name=function_name,
-                                    arguments=function_args,
-                                ),
-                            )
-                        ],
+                        tool_calls=tool_calls,
                    )
                )
            )
-            function_name = ""
-            function_args = ""
-            function_id = None
+            function_calls.clear()
          elif finish_reason == "stop" and text:
            aggregated_llm_response = _message_to_generate_content_response(
                ChatCompletionAssistantMessage(role="assistant", content=text)
```
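The interesting part of this change is the streaming tool-call handling: each streamed delta now carries an `index`, and deltas sharing an index belong to the same function call, so accumulating per index lets parallel calls be reconstructed instead of overwriting one another. A minimal standalone sketch of that accumulation pattern (plain Python with made-up delta tuples; it does not use the ADK or LiteLLM classes):

```python
import json

# Hypothetical streamed tool-call deltas: (index, id, name, args_fragment).
# Two parallel calls arrive interleaved, the way providers may emit them.
deltas = [
    (0, "call_a", "get_weather", '{"city": '),
    (1, "call_b", "get_time", '{"tz": "UTC"}'),
    (0, None, None, '"Paris"}'),
]

function_calls = {}  # index -> {"id", "name", "args"}
for index, call_id, name, args in deltas:
    entry = function_calls.setdefault(index, {"id": None, "name": "", "args": ""})
    if call_id:
        entry["id"] = call_id
    if name:
        entry["name"] += name
    if args:
        entry["args"] += args  # arguments accumulate until they parse as JSON

for index, entry in sorted(function_calls.items()):
    print(index, entry["id"], entry["name"], json.loads(entry["args"]))
```

The same release also forwards `llm_request.config.response_schema` to LiteLLM as `response_format`, so structured-output settings reach the underlying provider.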
google/adk/py.typed
ADDED
File without changes
google/adk/sessions/_session_util.py
CHANGED
```diff
@@ -11,34 +11,28 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-
 """Utility functions for session service."""
+from __future__ import annotations
 
-import base64
 from typing import Any
 from typing import Optional
 
 from google.genai import types
 
 
-def encode_content(content: types.Content):
-  """Encodes a content object to a JSON dictionary."""
-  encoded_content = content.model_dump(exclude_none=True)
-  for p in encoded_content["parts"]:
-    if "inline_data" in p:
-      p["inline_data"]["data"] = base64.b64encode(
-          p["inline_data"]["data"]
-      ).decode("utf-8")
-  return encoded_content
-
-
 def decode_content(
     content: Optional[dict[str, Any]],
 ) -> Optional[types.Content]:
   """Decodes a content object from a JSON dictionary."""
   if not content:
     return None
-  for p in content["parts"]:
-    if "inline_data" in p:
-      p["inline_data"]["data"] = base64.b64decode(p["inline_data"]["data"])
   return types.Content.model_validate(content)
+
+
+def decode_grounding_metadata(
+    grounding_metadata: Optional[dict[str, Any]],
+) -> Optional[types.GroundingMetadata]:
+  """Decodes a grounding metadata object from a JSON dictionary."""
+  if not grounding_metadata:
+    return None
+  return types.GroundingMetadata.model_validate(grounding_metadata)
```
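The removal of `encode_content` pairs with the storage changes below: events are now written with `model_dump(exclude_none=True, mode="json")` and read back through these decode helpers. A small round-trip sketch (it assumes only that the `google-genai` package is installed; the sample values are made up):

```python
from google.genai import types

# Write path: roughly what the storage layer now does for event content.
content = types.Content(role="user", parts=[types.Part(text="hello")])
stored = content.model_dump(exclude_none=True, mode="json")

# Read path: what decode_content() and decode_grounding_metadata() do.
restored = types.Content.model_validate(stored)
metadata = types.GroundingMetadata.model_validate(
    {"web_search_queries": ["adk 1.2.0"]}  # hypothetical stored dict
)

assert restored.parts[0].text == "hello"
```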
google/adk/sessions/database_session_service.py
CHANGED
```diff
@@ -11,6 +11,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from __future__ import annotations
+
 import copy
 from datetime import datetime
 import json
@@ -19,6 +21,7 @@ from typing import Any
 from typing import Optional
 import uuid
 
+from google.genai import types
 from sqlalchemy import Boolean
 from sqlalchemy import delete
 from sqlalchemy import Dialect
@@ -89,6 +92,18 @@ class DynamicJSON(TypeDecorator):
     return value
 
 
+class PreciseTimestamp(TypeDecorator):
+  """Represents a timestamp precise to the microsecond."""
+
+  impl = DateTime
+  cache_ok = True
+
+  def load_dialect_impl(self, dialect):
+    if dialect.name == "mysql":
+      return dialect.type_descriptor(mysql.DATETIME(fsp=6))
+    return self.impl
+
+
 class Base(DeclarativeBase):
   """Base class for database tables."""
 
@@ -153,7 +168,9 @@ class StorageEvent(Base):
   branch: Mapped[str] = mapped_column(
       String(DEFAULT_MAX_VARCHAR_LENGTH), nullable=True
   )
-  timestamp: Mapped[
+  timestamp: Mapped[PreciseTimestamp] = mapped_column(
+      PreciseTimestamp, default=func.now()
+  )
   content: Mapped[dict[str, Any]] = mapped_column(DynamicJSON, nullable=True)
   actions: Mapped[MutableDict[str, Any]] = mapped_column(PickleType)
 
@@ -199,6 +216,55 @@ class StorageEvent(Base):
     else:
       self.long_running_tool_ids_json = json.dumps(list(value))
 
+  @classmethod
+  def from_event(cls, session: Session, event: Event) -> StorageEvent:
+    storage_event = StorageEvent(
+        id=event.id,
+        invocation_id=event.invocation_id,
+        author=event.author,
+        branch=event.branch,
+        actions=event.actions,
+        session_id=session.id,
+        app_name=session.app_name,
+        user_id=session.user_id,
+        timestamp=datetime.fromtimestamp(event.timestamp),
+        long_running_tool_ids=event.long_running_tool_ids,
+        partial=event.partial,
+        turn_complete=event.turn_complete,
+        error_code=event.error_code,
+        error_message=event.error_message,
+        interrupted=event.interrupted,
+    )
+    if event.content:
+      storage_event.content = event.content.model_dump(
+          exclude_none=True, mode="json"
+      )
+    if event.grounding_metadata:
+      storage_event.grounding_metadata = event.grounding_metadata.model_dump(
+          exclude_none=True, mode="json"
+      )
+    return storage_event
+
+  def to_event(self) -> Event:
+    return Event(
+        id=self.id,
+        invocation_id=self.invocation_id,
+        author=self.author,
+        branch=self.branch,
+        actions=self.actions,
+        timestamp=self.timestamp.timestamp(),
+        content=_session_util.decode_content(self.content),
+        long_running_tool_ids=self.long_running_tool_ids,
+        partial=self.partial,
+        turn_complete=self.turn_complete,
+        error_code=self.error_code,
+        error_message=self.error_message,
+        interrupted=self.interrupted,
+        grounding_metadata=_session_util.decode_grounding_metadata(
+            self.grounding_metadata
+        ),
+    )
+
 
 class StorageAppState(Base):
   """Represents an app state stored in the database."""
@@ -238,14 +304,14 @@ class StorageUserState(Base):
 class DatabaseSessionService(BaseSessionService):
   """A session service that uses a database for storage."""
 
-  def __init__(self, db_url: str):
+  def __init__(self, db_url: str, **kwargs: Any):
     """Initializes the database session service with a database URL."""
     # 1. Create DB engine for db connection
     # 2. Create all tables based on schema
     # 3. Initialize all properties
 
     try:
-      db_engine = create_engine(db_url)
+      db_engine = create_engine(db_url, **kwargs)
     except Exception as e:
       if isinstance(e, ArgumentError):
         raise ValueError(
@@ -409,25 +475,7 @@ class DatabaseSessionService(BaseSessionService):
          state=merged_state,
          last_update_time=storage_session.update_time.timestamp(),
      )
-      session.events = [
-          Event(
-              id=e.id,
-              author=e.author,
-              branch=e.branch,
-              invocation_id=e.invocation_id,
-              content=_session_util.decode_content(e.content),
-              actions=e.actions,
-              timestamp=e.timestamp.timestamp(),
-              long_running_tool_ids=e.long_running_tool_ids,
-              grounding_metadata=e.grounding_metadata,
-              partial=e.partial,
-              turn_complete=e.turn_complete,
-              error_code=e.error_code,
-              error_message=e.error_message,
-              interrupted=e.interrupted,
-          )
-          for e in reversed(storage_events)
-      ]
+      session.events = [e.to_event() for e in reversed(storage_events)]
      return session
 
  @override
@@ -512,38 +560,18 @@ class DatabaseSessionService(BaseSessionService):
          _extract_state_delta(event.actions.state_delta)
      )
 
-      # Merge state
-
-
-
-
-
-
-
-
-
-      storage_event = StorageEvent(
-          id=event.id,
-          invocation_id=event.invocation_id,
-          author=event.author,
-          branch=event.branch,
-          actions=event.actions,
-          session_id=session.id,
-          app_name=session.app_name,
-          user_id=session.user_id,
-          timestamp=datetime.fromtimestamp(event.timestamp),
-          long_running_tool_ids=event.long_running_tool_ids,
-          grounding_metadata=event.grounding_metadata,
-          partial=event.partial,
-          turn_complete=event.turn_complete,
-          error_code=event.error_code,
-          error_message=event.error_message,
-          interrupted=event.interrupted,
-      )
-      if event.content:
-        storage_event.content = _session_util.encode_content(event.content)
+      # Merge state and update storage
+      if app_state_delta:
+        app_state.update(app_state_delta)
+        storage_app_state.state = app_state
+      if user_state_delta:
+        user_state.update(user_state_delta)
+        storage_user_state.state = user_state
+      if session_state_delta:
+        session_state.update(session_state_delta)
+        storage_session.state = session_state
 
-      session_factory.add(
+      session_factory.add(StorageEvent.from_event(session, event))
 
      session_factory.commit()
      session_factory.refresh(storage_session)
@@ -556,19 +584,6 @@ class DatabaseSessionService(BaseSessionService):
      return event
 
 
-def convert_event(event: StorageEvent) -> Event:
-  """Converts a storage event to an event."""
-  return Event(
-      id=event.id,
-      author=event.author,
-      branch=event.branch,
-      invocation_id=event.invocation_id,
-      content=event.content,
-      actions=event.actions,
-      timestamp=event.timestamp.timestamp(),
-  )
-
-
 def _extract_state_delta(state: dict[str, Any]):
   app_state_delta = {}
   user_state_delta = {}
```
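A practical consequence of the new `__init__` signature: SQLAlchemy engine options can now be passed straight through instead of requiring a custom engine. A hedged usage sketch (the connection string and option values are placeholders):

```python
from google.adk.sessions import DatabaseSessionService

# Extra keyword arguments are forwarded verbatim to SQLAlchemy's
# create_engine(), so pooling and debugging options can be set here.
service = DatabaseSessionService(
    "postgresql+psycopg2://user:pass@localhost/adk",  # hypothetical DSN
    pool_size=10,
    pool_pre_ping=True,
    echo=False,
)
```

The service still creates its tables on construction, so the sketch needs a reachable database to actually run.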
google/adk/sessions/vertex_ai_session_service.py
CHANGED
```diff
@@ -11,13 +11,18 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+from __future__ import annotations
+
 import asyncio
 import logging
 import re
+import time
 from typing import Any
 from typing import Optional
+import urllib.parse
 
 from dateutil import parser
+from google.genai import types
 from typing_extensions import override
 
 from google import genai
@@ -154,15 +159,29 @@ class VertexAiSessionService(BaseSessionService):
     if list_events_api_response.get('httpHeaders', None):
       return session
 
-    session.events
+    session.events += [
         _from_api_event(event)
         for event in list_events_api_response['sessionEvents']
     ]
+
+    while list_events_api_response.get('nextPageToken', None):
+      page_token = list_events_api_response.get('nextPageToken', None)
+      list_events_api_response = await api_client.async_request(
+          http_method='GET',
+          path=f'reasoningEngines/{reasoning_engine_id}/sessions/{session_id}/events?pageToken={page_token}',
+          request_dict={},
+      )
+      session.events += [
+          _from_api_event(event)
+          for event in list_events_api_response['sessionEvents']
+      ]
+
     session.events = [
         event for event in session.events if event.timestamp <= update_timestamp
     ]
     session.events.sort(key=lambda event: event.timestamp)
 
+    # Filter events based on config
     if config:
       if config.num_recent_events:
         session.events = session.events[-config.num_recent_events :]
@@ -183,10 +202,15 @@ class VertexAiSessionService(BaseSessionService):
   ) -> ListSessionsResponse:
     reasoning_engine_id = _parse_reasoning_engine_id(app_name)
 
+    path = f'reasoningEngines/{reasoning_engine_id}/sessions'
+    if user_id:
+      parsed_user_id = urllib.parse.quote(f'''"{user_id}"''', safe='')
+      path = path + f'?filter=user_id={parsed_user_id}'
+
     api_client = _get_api_client(self.project, self.location)
     api_response = await api_client.async_request(
         http_method='GET',
-        path=
+        path=path,
         request_dict={},
     )
 
@@ -256,7 +280,7 @@ def _convert_event_to_json(event: Event):
   }
   if event.grounding_metadata:
     metadata_json['grounding_metadata'] = event.grounding_metadata.model_dump(
-        exclude_none=True
+        exclude_none=True, mode='json'
     )
 
   event_json = {
@@ -284,7 +308,9 @@ def _convert_event_to_json(event: Event):
   }
   event_json['actions'] = actions_json
   if event.content:
-    event_json['content'] =
+    event_json['content'] = event.content.model_dump(
+        exclude_none=True, mode='json'
+    )
   if event.error_code:
     event_json['error_code'] = event.error_code
   if event.error_message:
@@ -325,8 +351,8 @@ def _from_api_event(api_event: dict) -> Event:
     event.turn_complete = api_event['eventMetadata'].get('turnComplete', None)
     event.interrupted = api_event['eventMetadata'].get('interrupted', None)
     event.branch = api_event['eventMetadata'].get('branch', None)
-    event.grounding_metadata =
-        'groundingMetadata', None
+    event.grounding_metadata = _session_util.decode_grounding_metadata(
+        api_event['eventMetadata'].get('groundingMetadata', None)
     )
     event.long_running_tool_ids = (
         set(long_running_tool_ids_list) if long_running_tool_ids_list else None
```
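The event-listing change is essentially a pagination loop over `nextPageToken`. The shape of that loop, shown with a stubbed request function so the sketch runs standalone (the real code drives the genai API client asynchronously):

```python
def fetch_all_events(request_page):
    # First page, then keep following nextPageToken until it disappears.
    response = request_page(page_token=None)
    events = list(response.get("sessionEvents", []))
    while response.get("nextPageToken"):
        response = request_page(page_token=response["nextPageToken"])
        events += response.get("sessionEvents", [])
    return events

# Stubbed two-page backend for illustration.
pages = {
    None: {"sessionEvents": [{"id": 1}], "nextPageToken": "p2"},
    "p2": {"sessionEvents": [{"id": 2}]},
}
assert [e["id"] for e in fetch_all_events(lambda page_token: pages[page_token])] == [1, 2]
```

The `list_sessions` change works the same way as before but now URL-encodes the quoted user id, so `user_id='alice'` becomes `?filter=user_id=%22alice%22`.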