google-adk 0.4.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/adk/agents/active_streaming_tool.py +1 -0
- google/adk/agents/base_agent.py +91 -47
- google/adk/agents/base_agent.py.orig +330 -0
- google/adk/agents/callback_context.py +4 -9
- google/adk/agents/invocation_context.py +1 -0
- google/adk/agents/langgraph_agent.py +1 -0
- google/adk/agents/live_request_queue.py +1 -0
- google/adk/agents/llm_agent.py +172 -35
- google/adk/agents/loop_agent.py +1 -1
- google/adk/agents/parallel_agent.py +7 -0
- google/adk/agents/readonly_context.py +7 -1
- google/adk/agents/run_config.py +5 -1
- google/adk/agents/sequential_agent.py +31 -0
- google/adk/agents/transcription_entry.py +5 -2
- google/adk/artifacts/base_artifact_service.py +5 -10
- google/adk/artifacts/gcs_artifact_service.py +9 -9
- google/adk/artifacts/in_memory_artifact_service.py +6 -6
- google/adk/auth/auth_credential.py +9 -5
- google/adk/auth/auth_preprocessor.py +7 -1
- google/adk/auth/auth_tool.py +3 -4
- google/adk/cli/agent_graph.py +5 -5
- google/adk/cli/browser/index.html +2 -2
- google/adk/cli/browser/{main-HWIBUY2R.js → main-QOEMUXM4.js} +58 -58
- google/adk/cli/cli.py +7 -7
- google/adk/cli/cli_deploy.py +7 -2
- google/adk/cli/cli_eval.py +181 -106
- google/adk/cli/cli_tools_click.py +147 -62
- google/adk/cli/fast_api.py +340 -158
- google/adk/cli/fast_api.py.orig +822 -0
- google/adk/cli/utils/common.py +23 -0
- google/adk/cli/utils/evals.py +83 -1
- google/adk/cli/utils/logs.py +13 -5
- google/adk/code_executors/__init__.py +3 -1
- google/adk/code_executors/built_in_code_executor.py +52 -0
- google/adk/evaluation/__init__.py +1 -1
- google/adk/evaluation/agent_evaluator.py +168 -128
- google/adk/evaluation/eval_case.py +102 -0
- google/adk/evaluation/eval_set.py +37 -0
- google/adk/evaluation/eval_sets_manager.py +42 -0
- google/adk/evaluation/evaluation_constants.py +1 -0
- google/adk/evaluation/evaluation_generator.py +89 -114
- google/adk/evaluation/evaluator.py +56 -0
- google/adk/evaluation/local_eval_sets_manager.py +264 -0
- google/adk/evaluation/response_evaluator.py +107 -3
- google/adk/evaluation/trajectory_evaluator.py +83 -2
- google/adk/events/event.py +7 -1
- google/adk/events/event_actions.py +7 -1
- google/adk/examples/example.py +1 -0
- google/adk/examples/example_util.py +3 -2
- google/adk/flows/__init__.py +0 -1
- google/adk/flows/llm_flows/_code_execution.py +19 -11
- google/adk/flows/llm_flows/audio_transcriber.py +4 -3
- google/adk/flows/llm_flows/base_llm_flow.py +86 -22
- google/adk/flows/llm_flows/basic.py +3 -0
- google/adk/flows/llm_flows/functions.py +10 -9
- google/adk/flows/llm_flows/instructions.py +28 -9
- google/adk/flows/llm_flows/single_flow.py +1 -1
- google/adk/memory/__init__.py +1 -1
- google/adk/memory/_utils.py +23 -0
- google/adk/memory/base_memory_service.py +25 -21
- google/adk/memory/base_memory_service.py.orig +76 -0
- google/adk/memory/in_memory_memory_service.py +59 -27
- google/adk/memory/memory_entry.py +37 -0
- google/adk/memory/vertex_ai_rag_memory_service.py +40 -17
- google/adk/models/anthropic_llm.py +36 -11
- google/adk/models/base_llm.py +45 -4
- google/adk/models/gemini_llm_connection.py +15 -2
- google/adk/models/google_llm.py +9 -44
- google/adk/models/google_llm.py.orig +305 -0
- google/adk/models/lite_llm.py +94 -38
- google/adk/models/llm_request.py +1 -1
- google/adk/models/llm_response.py +15 -3
- google/adk/models/registry.py +1 -1
- google/adk/runners.py +68 -44
- google/adk/sessions/__init__.py +1 -1
- google/adk/sessions/_session_util.py +14 -0
- google/adk/sessions/base_session_service.py +8 -32
- google/adk/sessions/database_session_service.py +58 -61
- google/adk/sessions/in_memory_session_service.py +108 -26
- google/adk/sessions/session.py +4 -0
- google/adk/sessions/vertex_ai_session_service.py +23 -45
- google/adk/telemetry.py +3 -0
- google/adk/tools/__init__.py +4 -7
- google/adk/tools/{built_in_code_execution_tool.py → _built_in_code_execution_tool.py} +11 -0
- google/adk/tools/_memory_entry_utils.py +30 -0
- google/adk/tools/agent_tool.py +16 -13
- google/adk/tools/apihub_tool/apihub_toolset.py +55 -74
- google/adk/tools/application_integration_tool/application_integration_toolset.py +107 -85
- google/adk/tools/application_integration_tool/clients/connections_client.py +29 -25
- google/adk/tools/application_integration_tool/clients/integration_client.py +6 -6
- google/adk/tools/application_integration_tool/integration_connector_tool.py +69 -26
- google/adk/tools/base_toolset.py +58 -0
- google/adk/tools/enterprise_search_tool.py +65 -0
- google/adk/tools/function_parameter_parse_util.py +2 -2
- google/adk/tools/google_api_tool/__init__.py +18 -70
- google/adk/tools/google_api_tool/google_api_tool.py +11 -5
- google/adk/tools/google_api_tool/google_api_toolset.py +126 -0
- google/adk/tools/google_api_tool/google_api_toolsets.py +102 -0
- google/adk/tools/google_api_tool/googleapi_to_openapi_converter.py +40 -42
- google/adk/tools/langchain_tool.py +96 -49
- google/adk/tools/load_artifacts_tool.py +4 -4
- google/adk/tools/load_memory_tool.py +16 -5
- google/adk/tools/mcp_tool/__init__.py +3 -2
- google/adk/tools/mcp_tool/conversion_utils.py +1 -1
- google/adk/tools/mcp_tool/mcp_session_manager.py +167 -16
- google/adk/tools/mcp_tool/mcp_session_manager.py.orig +322 -0
- google/adk/tools/mcp_tool/mcp_tool.py +12 -12
- google/adk/tools/mcp_tool/mcp_toolset.py +155 -195
- google/adk/tools/openapi_tool/common/common.py +2 -5
- google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py +32 -7
- google/adk/tools/openapi_tool/openapi_spec_parser/operation_parser.py +43 -33
- google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py +1 -1
- google/adk/tools/preload_memory_tool.py +27 -18
- google/adk/tools/retrieval/__init__.py +1 -1
- google/adk/tools/retrieval/vertex_ai_rag_retrieval.py +1 -1
- google/adk/tools/tool_context.py +4 -4
- google/adk/tools/toolbox_toolset.py +79 -0
- google/adk/tools/transfer_to_agent_tool.py +0 -1
- google/adk/version.py +1 -1
- {google_adk-0.4.0.dist-info → google_adk-1.0.0.dist-info}/METADATA +7 -5
- google_adk-1.0.0.dist-info/RECORD +195 -0
- google/adk/agents/remote_agent.py +0 -50
- google/adk/tools/google_api_tool/google_api_tool_set.py +0 -110
- google/adk/tools/google_api_tool/google_api_tool_sets.py +0 -112
- google/adk/tools/toolbox_tool.py +0 -46
- google_adk-0.4.0.dist-info/RECORD +0 -179
- {google_adk-0.4.0.dist-info → google_adk-1.0.0.dist-info}/WHEEL +0 -0
- {google_adk-0.4.0.dist-info → google_adk-1.0.0.dist-info}/entry_points.txt +0 -0
- {google_adk-0.4.0.dist-info → google_adk-1.0.0.dist-info}/licenses/LICENSE +0 -0
@@ -62,6 +62,9 @@ class _BasicLlmRequestProcessor(BaseLlmRequestProcessor):
|
|
62
62
|
llm_request.live_connect_config.output_audio_transcription = (
|
63
63
|
invocation_context.run_config.output_audio_transcription
|
64
64
|
)
|
65
|
+
llm_request.live_connect_config.input_audio_transcription = (
|
66
|
+
invocation_context.run_config.input_audio_transcription
|
67
|
+
)
|
65
68
|
|
66
69
|
# TODO: handle tool append here, instead of in BaseTool.process_llm_request.
|
67
70
|
|
@@ -41,7 +41,7 @@ from ...tools.tool_context import ToolContext
|
|
41
41
|
AF_FUNCTION_CALL_ID_PREFIX = 'adk-'
|
42
42
|
REQUEST_EUC_FUNCTION_CALL_NAME = 'adk_request_credential'
|
43
43
|
|
44
|
-
logger = logging.getLogger(__name__)
|
44
|
+
logger = logging.getLogger('google_adk.' + __name__)
|
45
45
|
|
46
46
|
|
47
47
|
def generate_client_function_call_id() -> str:
|
@@ -106,7 +106,7 @@ def generate_auth_event(
|
|
106
106
|
args=AuthToolArguments(
|
107
107
|
function_call_id=function_call_id,
|
108
108
|
auth_config=auth_config,
|
109
|
-
).model_dump(exclude_none=True),
|
109
|
+
).model_dump(exclude_none=True, by_alias=True),
|
110
110
|
)
|
111
111
|
request_euc_function_call.id = generate_client_function_call_id()
|
112
112
|
long_running_tool_ids.add(request_euc_function_call.id)
|
@@ -153,22 +153,22 @@ async def handle_function_calls_async(
|
|
153
153
|
function_args = function_call.args or {}
|
154
154
|
function_response: Optional[dict] = None
|
155
155
|
|
156
|
-
|
157
|
-
|
158
|
-
function_response = agent.before_tool_callback(
|
156
|
+
for callback in agent.canonical_before_tool_callbacks:
|
157
|
+
function_response = callback(
|
159
158
|
tool=tool, args=function_args, tool_context=tool_context
|
160
159
|
)
|
161
160
|
if inspect.isawaitable(function_response):
|
162
161
|
function_response = await function_response
|
162
|
+
if function_response:
|
163
|
+
break
|
163
164
|
|
164
165
|
if not function_response:
|
165
166
|
function_response = await __call_tool_async(
|
166
167
|
tool, args=function_args, tool_context=tool_context
|
167
168
|
)
|
168
169
|
|
169
|
-
|
170
|
-
|
171
|
-
altered_function_response = agent.after_tool_callback(
|
170
|
+
for callback in agent.canonical_after_tool_callbacks:
|
171
|
+
altered_function_response = callback(
|
172
172
|
tool=tool,
|
173
173
|
args=function_args,
|
174
174
|
tool_context=tool_context,
|
@@ -178,6 +178,7 @@ async def handle_function_calls_async(
|
|
178
178
|
altered_function_response = await altered_function_response
|
179
179
|
if altered_function_response is not None:
|
180
180
|
function_response = altered_function_response
|
181
|
+
break
|
181
182
|
|
182
183
|
if tool.is_long_running:
|
183
184
|
# Allow long running function to return None to not provide function response.
|
@@ -332,7 +333,7 @@ async def _process_function_live_helper(
|
|
332
333
|
function_response = {
|
333
334
|
'status': f'No active streaming function named {function_name} found'
|
334
335
|
}
|
335
|
-
elif hasattr(tool,
|
336
|
+
elif hasattr(tool, 'func') and inspect.isasyncgenfunction(tool.func):
|
336
337
|
# for streaming tool use case
|
337
338
|
# we require the function to be a async generator function
|
338
339
|
async def run_tool_and_update_queue(tool, function_args, tool_context):
|
@@ -53,16 +53,24 @@ class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor):
|
|
53
53
|
if (
|
54
54
|
isinstance(root_agent, LlmAgent) and root_agent.global_instruction
|
55
55
|
): # not empty str
|
56
|
-
raw_si =
|
57
|
-
|
56
|
+
raw_si, bypass_state_injection = (
|
57
|
+
await root_agent.canonical_global_instruction(
|
58
|
+
ReadonlyContext(invocation_context)
|
59
|
+
)
|
58
60
|
)
|
59
|
-
si =
|
61
|
+
si = raw_si
|
62
|
+
if not bypass_state_injection:
|
63
|
+
si = await _populate_values(raw_si, invocation_context)
|
60
64
|
llm_request.append_instructions([si])
|
61
65
|
|
62
66
|
# Appends agent instructions if set.
|
63
67
|
if agent.instruction: # not empty str
|
64
|
-
raw_si = agent.canonical_instruction(
|
65
|
-
|
68
|
+
raw_si, bypass_state_injection = await agent.canonical_instruction(
|
69
|
+
ReadonlyContext(invocation_context)
|
70
|
+
)
|
71
|
+
si = raw_si
|
72
|
+
if not bypass_state_injection:
|
73
|
+
si = await _populate_values(raw_si, invocation_context)
|
66
74
|
llm_request.append_instructions([si])
|
67
75
|
|
68
76
|
# Maintain async generator behavior
|
@@ -73,13 +81,24 @@ class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor):
|
|
73
81
|
request_processor = _InstructionsLlmRequestProcessor()
|
74
82
|
|
75
83
|
|
76
|
-
def _populate_values(
|
84
|
+
async def _populate_values(
|
77
85
|
instruction_template: str,
|
78
86
|
context: InvocationContext,
|
79
87
|
) -> str:
|
80
88
|
"""Populates values in the instruction template, e.g. state, artifact, etc."""
|
81
89
|
|
82
|
-
def
|
90
|
+
async def _async_sub(pattern, repl_async_fn, string) -> str:
|
91
|
+
result = []
|
92
|
+
last_end = 0
|
93
|
+
for match in re.finditer(pattern, string):
|
94
|
+
result.append(string[last_end : match.start()])
|
95
|
+
replacement = await repl_async_fn(match)
|
96
|
+
result.append(replacement)
|
97
|
+
last_end = match.end()
|
98
|
+
result.append(string[last_end:])
|
99
|
+
return ''.join(result)
|
100
|
+
|
101
|
+
async def _replace_match(match) -> str:
|
83
102
|
var_name = match.group().lstrip('{').rstrip('}').strip()
|
84
103
|
optional = False
|
85
104
|
if var_name.endswith('?'):
|
@@ -89,7 +108,7 @@ def _populate_values(
|
|
89
108
|
var_name = var_name.removeprefix('artifact.')
|
90
109
|
if context.artifact_service is None:
|
91
110
|
raise ValueError('Artifact service is not initialized.')
|
92
|
-
artifact = context.artifact_service.load_artifact(
|
111
|
+
artifact = await context.artifact_service.load_artifact(
|
93
112
|
app_name=context.session.app_name,
|
94
113
|
user_id=context.session.user_id,
|
95
114
|
session_id=context.session.id,
|
@@ -109,7 +128,7 @@ def _populate_values(
|
|
109
128
|
else:
|
110
129
|
raise KeyError(f'Context variable not found: `{var_name}`.')
|
111
130
|
|
112
|
-
return
|
131
|
+
return await _async_sub(r'{+[^{}]*}+', _replace_match, instruction_template)
|
113
132
|
|
114
133
|
|
115
134
|
def _is_valid_state_name(var_name):
|
google/adk/memory/__init__.py
CHANGED
@@ -16,7 +16,7 @@ import logging
|
|
16
16
|
from .base_memory_service import BaseMemoryService
|
17
17
|
from .in_memory_memory_service import InMemoryMemoryService
|
18
18
|
|
19
|
-
logger = logging.getLogger(__name__)
|
19
|
+
logger = logging.getLogger('google_adk.' + __name__)
|
20
20
|
|
21
21
|
__all__ = [
|
22
22
|
'BaseMemoryService',
|
@@ -0,0 +1,23 @@
|
|
1
|
+
# Copyright 2025 Google LLC
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
|
16
|
+
from __future__ import annotations
|
17
|
+
|
18
|
+
from datetime import datetime
|
19
|
+
|
20
|
+
|
21
|
+
def format_timestamp(timestamp: float) -> str:
|
22
|
+
"""Formats the timestamp of the memory entry."""
|
23
|
+
return datetime.fromtimestamp(timestamp).isoformat()
|
@@ -12,44 +12,44 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
|
15
|
+
|
16
|
+
from __future__ import annotations
|
17
|
+
|
18
|
+
from abc import ABC
|
19
|
+
from abc import abstractmethod
|
20
|
+
from typing import TYPE_CHECKING
|
16
21
|
|
17
22
|
from pydantic import BaseModel
|
18
23
|
from pydantic import Field
|
19
24
|
|
20
|
-
from
|
21
|
-
from ..sessions.session import Session
|
25
|
+
from .memory_entry import MemoryEntry
|
22
26
|
|
23
|
-
|
24
|
-
|
25
|
-
"""Represents a single memory retrieval result.
|
26
|
-
|
27
|
-
Attributes:
|
28
|
-
session_id: The session id associated with the memory.
|
29
|
-
events: A list of events in the session.
|
30
|
-
"""
|
31
|
-
session_id: str
|
32
|
-
events: list[Event]
|
27
|
+
if TYPE_CHECKING:
|
28
|
+
from ..sessions.session import Session
|
33
29
|
|
34
30
|
|
35
31
|
class SearchMemoryResponse(BaseModel):
|
36
32
|
"""Represents the response from a memory search.
|
37
33
|
|
38
34
|
Attributes:
|
39
|
-
memories: A list of memory
|
35
|
+
memories: A list of memory entries that relate to the search query.
|
40
36
|
"""
|
41
|
-
|
37
|
+
|
38
|
+
memories: list[MemoryEntry] = Field(default_factory=list)
|
42
39
|
|
43
40
|
|
44
|
-
class BaseMemoryService(
|
41
|
+
class BaseMemoryService(ABC):
|
45
42
|
"""Base class for memory services.
|
46
43
|
|
47
44
|
The service provides functionalities to ingest sessions into memory so that
|
48
45
|
the memory can be used for user queries.
|
49
46
|
"""
|
50
47
|
|
51
|
-
@
|
52
|
-
def add_session_to_memory(
|
48
|
+
@abstractmethod
|
49
|
+
async def add_session_to_memory(
|
50
|
+
self,
|
51
|
+
session: Session,
|
52
|
+
):
|
53
53
|
"""Adds a session to the memory service.
|
54
54
|
|
55
55
|
A session may be added multiple times during its lifetime.
|
@@ -58,9 +58,13 @@ class BaseMemoryService(abc.ABC):
|
|
58
58
|
session: The session to add.
|
59
59
|
"""
|
60
60
|
|
61
|
-
@
|
62
|
-
def search_memory(
|
63
|
-
self,
|
61
|
+
@abstractmethod
|
62
|
+
async def search_memory(
|
63
|
+
self,
|
64
|
+
*,
|
65
|
+
app_name: str,
|
66
|
+
user_id: str,
|
67
|
+
query: str,
|
64
68
|
) -> SearchMemoryResponse:
|
65
69
|
"""Searches for sessions that match the query.
|
66
70
|
|
@@ -0,0 +1,76 @@
|
|
1
|
+
# Copyright 2025 Google LLC
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import abc
|
16
|
+
|
17
|
+
from pydantic import BaseModel
|
18
|
+
from pydantic import Field
|
19
|
+
|
20
|
+
from ..events.event import Event
|
21
|
+
from ..sessions.session import Session
|
22
|
+
|
23
|
+
|
24
|
+
class MemoryResult(BaseModel):
|
25
|
+
"""Represents a single memory retrieval result.
|
26
|
+
|
27
|
+
Attributes:
|
28
|
+
session_id: The session id associated with the memory.
|
29
|
+
events: A list of events in the session.
|
30
|
+
"""
|
31
|
+
|
32
|
+
session_id: str
|
33
|
+
events: list[Event]
|
34
|
+
|
35
|
+
|
36
|
+
class SearchMemoryResponse(BaseModel):
|
37
|
+
"""Represents the response from a memory search.
|
38
|
+
|
39
|
+
Attributes:
|
40
|
+
memories: A list of memory results matching the search query.
|
41
|
+
"""
|
42
|
+
|
43
|
+
memories: list[MemoryResult] = Field(default_factory=list)
|
44
|
+
|
45
|
+
|
46
|
+
class BaseMemoryService(abc.ABC):
|
47
|
+
"""Base class for memory services.
|
48
|
+
|
49
|
+
The service provides functionalities to ingest sessions into memory so that
|
50
|
+
the memory can be used for user queries.
|
51
|
+
"""
|
52
|
+
|
53
|
+
@abc.abstractmethod
|
54
|
+
async def add_session_to_memory(self, session: Session):
|
55
|
+
"""Adds a session to the memory service.
|
56
|
+
|
57
|
+
A session may be added multiple times during its lifetime.
|
58
|
+
|
59
|
+
Args:
|
60
|
+
session: The session to add.
|
61
|
+
"""
|
62
|
+
|
63
|
+
@abc.abstractmethod
|
64
|
+
async def search_memory(
|
65
|
+
self, *, app_name: str, user_id: str, query: str
|
66
|
+
) -> SearchMemoryResponse:
|
67
|
+
"""Searches for sessions that match the query.
|
68
|
+
|
69
|
+
Args:
|
70
|
+
app_name: The name of the application.
|
71
|
+
user_id: The id of the user.
|
72
|
+
query: The query to search for.
|
73
|
+
|
74
|
+
Returns:
|
75
|
+
A SearchMemoryResponse containing the matching memories.
|
76
|
+
"""
|
@@ -12,11 +12,31 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
|
16
|
-
from
|
15
|
+
|
16
|
+
from __future__ import annotations
|
17
|
+
|
18
|
+
import re
|
19
|
+
from typing import TYPE_CHECKING
|
20
|
+
|
21
|
+
from typing_extensions import override
|
22
|
+
|
23
|
+
from . import _utils
|
17
24
|
from .base_memory_service import BaseMemoryService
|
18
|
-
from .base_memory_service import MemoryResult
|
19
25
|
from .base_memory_service import SearchMemoryResponse
|
26
|
+
from .memory_entry import MemoryEntry
|
27
|
+
|
28
|
+
if TYPE_CHECKING:
|
29
|
+
from ..events.event import Event
|
30
|
+
from ..sessions.session import Session
|
31
|
+
|
32
|
+
|
33
|
+
def _user_key(app_name: str, user_id: str):
|
34
|
+
return f'{app_name}/{user_id}'
|
35
|
+
|
36
|
+
|
37
|
+
def _extract_words_lower(text: str) -> set[str]:
|
38
|
+
"""Extracts words from a string and converts them to lowercase."""
|
39
|
+
return set([word.lower() for word in re.findall(r'[A-Za-z]+', text)])
|
20
40
|
|
21
41
|
|
22
42
|
class InMemoryMemoryService(BaseMemoryService):
|
@@ -26,37 +46,49 @@ class InMemoryMemoryService(BaseMemoryService):
|
|
26
46
|
"""
|
27
47
|
|
28
48
|
def __init__(self):
|
29
|
-
self.
|
30
|
-
"""
|
49
|
+
self._session_events: dict[str, dict[str, list[Event]]] = {}
|
50
|
+
"""Keys are app_name/user_id, session_id. Values are session event lists."""
|
31
51
|
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
52
|
+
@override
|
53
|
+
async def add_session_to_memory(self, session: Session):
|
54
|
+
user_key = _user_key(session.app_name, session.user_id)
|
55
|
+
self._session_events[user_key] = self._session_events.get(
|
56
|
+
_user_key(session.app_name, session.user_id), {}
|
57
|
+
)
|
58
|
+
self._session_events[user_key][session.id] = [
|
59
|
+
event
|
60
|
+
for event in session.events
|
61
|
+
if event.content and event.content.parts
|
36
62
|
]
|
37
63
|
|
38
|
-
|
64
|
+
@override
|
65
|
+
async def search_memory(
|
39
66
|
self, *, app_name: str, user_id: str, query: str
|
40
67
|
) -> SearchMemoryResponse:
|
41
|
-
|
42
|
-
|
68
|
+
user_key = _user_key(app_name, user_id)
|
69
|
+
if user_key not in self._session_events:
|
70
|
+
return SearchMemoryResponse()
|
71
|
+
|
72
|
+
words_in_query = set(query.lower().split())
|
43
73
|
response = SearchMemoryResponse()
|
44
|
-
|
45
|
-
|
46
|
-
|
47
|
-
matched_events = []
|
48
|
-
for event in events:
|
74
|
+
|
75
|
+
for session_events in self._session_events[user_key].values():
|
76
|
+
for event in session_events:
|
49
77
|
if not event.content or not event.content.parts:
|
50
78
|
continue
|
51
|
-
|
52
|
-
|
53
|
-
for keyword in keywords:
|
54
|
-
if keyword in text:
|
55
|
-
matched_events.append(event)
|
56
|
-
break
|
57
|
-
if matched_events:
|
58
|
-
session_id = key.split('/')[-1]
|
59
|
-
response.memories.append(
|
60
|
-
MemoryResult(session_id=session_id, events=matched_events)
|
79
|
+
words_in_event = _extract_words_lower(
|
80
|
+
' '.join([part.text for part in event.content.parts if part.text])
|
61
81
|
)
|
82
|
+
if not words_in_event:
|
83
|
+
continue
|
84
|
+
|
85
|
+
if any(query_word in words_in_event for query_word in words_in_query):
|
86
|
+
response.memories.append(
|
87
|
+
MemoryEntry(
|
88
|
+
content=event.content,
|
89
|
+
author=event.author,
|
90
|
+
timestamp=_utils.format_timestamp(event.timestamp),
|
91
|
+
)
|
92
|
+
)
|
93
|
+
|
62
94
|
return response
|
@@ -0,0 +1,37 @@
|
|
1
|
+
# Copyright 2025 Google LLC
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
|
16
|
+
from __future__ import annotations
|
17
|
+
|
18
|
+
from typing import Optional
|
19
|
+
|
20
|
+
from google.genai import types
|
21
|
+
from pydantic import BaseModel
|
22
|
+
|
23
|
+
|
24
|
+
class MemoryEntry(BaseModel):
|
25
|
+
"""Represent one memory entry."""
|
26
|
+
|
27
|
+
content: types.Content
|
28
|
+
"""The main content of the memory."""
|
29
|
+
|
30
|
+
author: Optional[str] = None
|
31
|
+
"""The author of the memory."""
|
32
|
+
|
33
|
+
timestamp: Optional[str] = None
|
34
|
+
"""The timestamp when the original content of this memory happened.
|
35
|
+
|
36
|
+
This string will be forwarded to LLM. Preferred format is ISO 8601 format.
|
37
|
+
"""
|
@@ -12,20 +12,28 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
+
|
16
|
+
from __future__ import annotations
|
17
|
+
|
15
18
|
from collections import OrderedDict
|
16
19
|
import json
|
17
20
|
import os
|
18
21
|
import tempfile
|
22
|
+
from typing import Optional
|
23
|
+
from typing import TYPE_CHECKING
|
19
24
|
|
20
25
|
from google.genai import types
|
21
26
|
from typing_extensions import override
|
22
27
|
from vertexai.preview import rag
|
23
28
|
|
24
|
-
from
|
25
|
-
from ..sessions.session import Session
|
29
|
+
from . import _utils
|
26
30
|
from .base_memory_service import BaseMemoryService
|
27
|
-
from .base_memory_service import MemoryResult
|
28
31
|
from .base_memory_service import SearchMemoryResponse
|
32
|
+
from .memory_entry import MemoryEntry
|
33
|
+
|
34
|
+
if TYPE_CHECKING:
|
35
|
+
from ..events.event import Event
|
36
|
+
from ..sessions.session import Session
|
29
37
|
|
30
38
|
|
31
39
|
class VertexAiRagMemoryService(BaseMemoryService):
|
@@ -33,8 +41,8 @@ class VertexAiRagMemoryService(BaseMemoryService):
|
|
33
41
|
|
34
42
|
def __init__(
|
35
43
|
self,
|
36
|
-
rag_corpus: str = None,
|
37
|
-
similarity_top_k: int = None,
|
44
|
+
rag_corpus: Optional[str] = None,
|
45
|
+
similarity_top_k: Optional[int] = None,
|
38
46
|
vector_distance_threshold: float = 10,
|
39
47
|
):
|
40
48
|
"""Initializes a VertexAiRagMemoryService.
|
@@ -47,14 +55,16 @@ class VertexAiRagMemoryService(BaseMemoryService):
|
|
47
55
|
vector_distance_threshold: Only returns contexts with vector distance
|
48
56
|
smaller than the threshold..
|
49
57
|
"""
|
50
|
-
self.
|
51
|
-
rag_resources=[
|
58
|
+
self._vertex_rag_store = types.VertexRagStore(
|
59
|
+
rag_resources=[
|
60
|
+
types.VertexRagStoreRagResource(rag_corpus=rag_corpus),
|
61
|
+
],
|
52
62
|
similarity_top_k=similarity_top_k,
|
53
63
|
vector_distance_threshold=vector_distance_threshold,
|
54
64
|
)
|
55
65
|
|
56
66
|
@override
|
57
|
-
def add_session_to_memory(self, session: Session):
|
67
|
+
async def add_session_to_memory(self, session: Session):
|
58
68
|
with tempfile.NamedTemporaryFile(
|
59
69
|
mode="w", delete=False, suffix=".txt"
|
60
70
|
) as temp_file:
|
@@ -79,7 +89,11 @@ class VertexAiRagMemoryService(BaseMemoryService):
|
|
79
89
|
output_string = "\n".join(output_lines)
|
80
90
|
temp_file.write(output_string)
|
81
91
|
temp_file_path = temp_file.name
|
82
|
-
|
92
|
+
|
93
|
+
if not self._vertex_rag_store.rag_resources:
|
94
|
+
raise ValueError("Rag resources must be set.")
|
95
|
+
|
96
|
+
for rag_resource in self._vertex_rag_store.rag_resources:
|
83
97
|
rag.upload_file(
|
84
98
|
corpus_name=rag_resource.rag_corpus,
|
85
99
|
path=temp_file_path,
|
@@ -91,16 +105,18 @@ class VertexAiRagMemoryService(BaseMemoryService):
|
|
91
105
|
os.remove(temp_file_path)
|
92
106
|
|
93
107
|
@override
|
94
|
-
def search_memory(
|
108
|
+
async def search_memory(
|
95
109
|
self, *, app_name: str, user_id: str, query: str
|
96
110
|
) -> SearchMemoryResponse:
|
97
111
|
"""Searches for sessions that match the query using rag.retrieval_query."""
|
112
|
+
from ..events.event import Event
|
113
|
+
|
98
114
|
response = rag.retrieval_query(
|
99
115
|
text=query,
|
100
|
-
rag_resources=self.
|
101
|
-
rag_corpora=self.
|
102
|
-
similarity_top_k=self.
|
103
|
-
vector_distance_threshold=self.
|
116
|
+
rag_resources=self._vertex_rag_store.rag_resources,
|
117
|
+
rag_corpora=self._vertex_rag_store.rag_corpora,
|
118
|
+
similarity_top_k=self._vertex_rag_store.similarity_top_k,
|
119
|
+
vector_distance_threshold=self._vertex_rag_store.vector_distance_threshold,
|
104
120
|
)
|
105
121
|
|
106
122
|
memory_results = []
|
@@ -144,9 +160,16 @@ class VertexAiRagMemoryService(BaseMemoryService):
|
|
144
160
|
for session_id, event_lists in session_events_map.items():
|
145
161
|
for events in _merge_event_lists(event_lists):
|
146
162
|
sorted_events = sorted(events, key=lambda e: e.timestamp)
|
147
|
-
|
148
|
-
|
149
|
-
|
163
|
+
|
164
|
+
memory_results.extend([
|
165
|
+
MemoryEntry(
|
166
|
+
author=event.author,
|
167
|
+
content=event.content,
|
168
|
+
timestamp=_utils.format_timestamp(event.timestamp),
|
169
|
+
)
|
170
|
+
for event in sorted_events
|
171
|
+
if event.content
|
172
|
+
])
|
150
173
|
return SearchMemoryResponse(memories=memory_results)
|
151
174
|
|
152
175
|
|