google-adk 0.4.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (129) hide show
  1. google/adk/agents/active_streaming_tool.py +1 -0
  2. google/adk/agents/base_agent.py +91 -47
  3. google/adk/agents/base_agent.py.orig +330 -0
  4. google/adk/agents/callback_context.py +4 -9
  5. google/adk/agents/invocation_context.py +1 -0
  6. google/adk/agents/langgraph_agent.py +1 -0
  7. google/adk/agents/live_request_queue.py +1 -0
  8. google/adk/agents/llm_agent.py +172 -35
  9. google/adk/agents/loop_agent.py +1 -1
  10. google/adk/agents/parallel_agent.py +7 -0
  11. google/adk/agents/readonly_context.py +7 -1
  12. google/adk/agents/run_config.py +5 -1
  13. google/adk/agents/sequential_agent.py +31 -0
  14. google/adk/agents/transcription_entry.py +5 -2
  15. google/adk/artifacts/base_artifact_service.py +5 -10
  16. google/adk/artifacts/gcs_artifact_service.py +9 -9
  17. google/adk/artifacts/in_memory_artifact_service.py +6 -6
  18. google/adk/auth/auth_credential.py +9 -5
  19. google/adk/auth/auth_preprocessor.py +7 -1
  20. google/adk/auth/auth_tool.py +3 -4
  21. google/adk/cli/agent_graph.py +5 -5
  22. google/adk/cli/browser/index.html +2 -2
  23. google/adk/cli/browser/{main-HWIBUY2R.js → main-QOEMUXM4.js} +58 -58
  24. google/adk/cli/cli.py +7 -7
  25. google/adk/cli/cli_deploy.py +7 -2
  26. google/adk/cli/cli_eval.py +181 -106
  27. google/adk/cli/cli_tools_click.py +147 -62
  28. google/adk/cli/fast_api.py +340 -158
  29. google/adk/cli/fast_api.py.orig +822 -0
  30. google/adk/cli/utils/common.py +23 -0
  31. google/adk/cli/utils/evals.py +83 -1
  32. google/adk/cli/utils/logs.py +13 -5
  33. google/adk/code_executors/__init__.py +3 -1
  34. google/adk/code_executors/built_in_code_executor.py +52 -0
  35. google/adk/evaluation/__init__.py +1 -1
  36. google/adk/evaluation/agent_evaluator.py +168 -128
  37. google/adk/evaluation/eval_case.py +102 -0
  38. google/adk/evaluation/eval_set.py +37 -0
  39. google/adk/evaluation/eval_sets_manager.py +42 -0
  40. google/adk/evaluation/evaluation_constants.py +1 -0
  41. google/adk/evaluation/evaluation_generator.py +89 -114
  42. google/adk/evaluation/evaluator.py +56 -0
  43. google/adk/evaluation/local_eval_sets_manager.py +264 -0
  44. google/adk/evaluation/response_evaluator.py +107 -3
  45. google/adk/evaluation/trajectory_evaluator.py +83 -2
  46. google/adk/events/event.py +7 -1
  47. google/adk/events/event_actions.py +7 -1
  48. google/adk/examples/example.py +1 -0
  49. google/adk/examples/example_util.py +3 -2
  50. google/adk/flows/__init__.py +0 -1
  51. google/adk/flows/llm_flows/_code_execution.py +19 -11
  52. google/adk/flows/llm_flows/audio_transcriber.py +4 -3
  53. google/adk/flows/llm_flows/base_llm_flow.py +86 -22
  54. google/adk/flows/llm_flows/basic.py +3 -0
  55. google/adk/flows/llm_flows/functions.py +10 -9
  56. google/adk/flows/llm_flows/instructions.py +28 -9
  57. google/adk/flows/llm_flows/single_flow.py +1 -1
  58. google/adk/memory/__init__.py +1 -1
  59. google/adk/memory/_utils.py +23 -0
  60. google/adk/memory/base_memory_service.py +25 -21
  61. google/adk/memory/base_memory_service.py.orig +76 -0
  62. google/adk/memory/in_memory_memory_service.py +59 -27
  63. google/adk/memory/memory_entry.py +37 -0
  64. google/adk/memory/vertex_ai_rag_memory_service.py +40 -17
  65. google/adk/models/anthropic_llm.py +36 -11
  66. google/adk/models/base_llm.py +45 -4
  67. google/adk/models/gemini_llm_connection.py +15 -2
  68. google/adk/models/google_llm.py +9 -44
  69. google/adk/models/google_llm.py.orig +305 -0
  70. google/adk/models/lite_llm.py +94 -38
  71. google/adk/models/llm_request.py +1 -1
  72. google/adk/models/llm_response.py +15 -3
  73. google/adk/models/registry.py +1 -1
  74. google/adk/runners.py +68 -44
  75. google/adk/sessions/__init__.py +1 -1
  76. google/adk/sessions/_session_util.py +14 -0
  77. google/adk/sessions/base_session_service.py +8 -32
  78. google/adk/sessions/database_session_service.py +58 -61
  79. google/adk/sessions/in_memory_session_service.py +108 -26
  80. google/adk/sessions/session.py +4 -0
  81. google/adk/sessions/vertex_ai_session_service.py +23 -45
  82. google/adk/telemetry.py +3 -0
  83. google/adk/tools/__init__.py +4 -7
  84. google/adk/tools/{built_in_code_execution_tool.py → _built_in_code_execution_tool.py} +11 -0
  85. google/adk/tools/_memory_entry_utils.py +30 -0
  86. google/adk/tools/agent_tool.py +16 -13
  87. google/adk/tools/apihub_tool/apihub_toolset.py +55 -74
  88. google/adk/tools/application_integration_tool/application_integration_toolset.py +107 -85
  89. google/adk/tools/application_integration_tool/clients/connections_client.py +29 -25
  90. google/adk/tools/application_integration_tool/clients/integration_client.py +6 -6
  91. google/adk/tools/application_integration_tool/integration_connector_tool.py +69 -26
  92. google/adk/tools/base_toolset.py +58 -0
  93. google/adk/tools/enterprise_search_tool.py +65 -0
  94. google/adk/tools/function_parameter_parse_util.py +2 -2
  95. google/adk/tools/google_api_tool/__init__.py +18 -70
  96. google/adk/tools/google_api_tool/google_api_tool.py +11 -5
  97. google/adk/tools/google_api_tool/google_api_toolset.py +126 -0
  98. google/adk/tools/google_api_tool/google_api_toolsets.py +102 -0
  99. google/adk/tools/google_api_tool/googleapi_to_openapi_converter.py +40 -42
  100. google/adk/tools/langchain_tool.py +96 -49
  101. google/adk/tools/load_artifacts_tool.py +4 -4
  102. google/adk/tools/load_memory_tool.py +16 -5
  103. google/adk/tools/mcp_tool/__init__.py +3 -2
  104. google/adk/tools/mcp_tool/conversion_utils.py +1 -1
  105. google/adk/tools/mcp_tool/mcp_session_manager.py +167 -16
  106. google/adk/tools/mcp_tool/mcp_session_manager.py.orig +322 -0
  107. google/adk/tools/mcp_tool/mcp_tool.py +12 -12
  108. google/adk/tools/mcp_tool/mcp_toolset.py +155 -195
  109. google/adk/tools/openapi_tool/common/common.py +2 -5
  110. google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py +32 -7
  111. google/adk/tools/openapi_tool/openapi_spec_parser/operation_parser.py +43 -33
  112. google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py +1 -1
  113. google/adk/tools/preload_memory_tool.py +27 -18
  114. google/adk/tools/retrieval/__init__.py +1 -1
  115. google/adk/tools/retrieval/vertex_ai_rag_retrieval.py +1 -1
  116. google/adk/tools/tool_context.py +4 -4
  117. google/adk/tools/toolbox_toolset.py +79 -0
  118. google/adk/tools/transfer_to_agent_tool.py +0 -1
  119. google/adk/version.py +1 -1
  120. {google_adk-0.4.0.dist-info → google_adk-1.0.0.dist-info}/METADATA +7 -5
  121. google_adk-1.0.0.dist-info/RECORD +195 -0
  122. google/adk/agents/remote_agent.py +0 -50
  123. google/adk/tools/google_api_tool/google_api_tool_set.py +0 -110
  124. google/adk/tools/google_api_tool/google_api_tool_sets.py +0 -112
  125. google/adk/tools/toolbox_tool.py +0 -46
  126. google_adk-0.4.0.dist-info/RECORD +0 -179
  127. {google_adk-0.4.0.dist-info → google_adk-1.0.0.dist-info}/WHEEL +0 -0
  128. {google_adk-0.4.0.dist-info → google_adk-1.0.0.dist-info}/entry_points.txt +0 -0
  129. {google_adk-0.4.0.dist-info → google_adk-1.0.0.dist-info}/licenses/LICENSE +0 -0
@@ -62,6 +62,9 @@ class _BasicLlmRequestProcessor(BaseLlmRequestProcessor):
62
62
  llm_request.live_connect_config.output_audio_transcription = (
63
63
  invocation_context.run_config.output_audio_transcription
64
64
  )
65
+ llm_request.live_connect_config.input_audio_transcription = (
66
+ invocation_context.run_config.input_audio_transcription
67
+ )
65
68
 
66
69
  # TODO: handle tool append here, instead of in BaseTool.process_llm_request.
67
70
 
@@ -41,7 +41,7 @@ from ...tools.tool_context import ToolContext
41
41
  AF_FUNCTION_CALL_ID_PREFIX = 'adk-'
42
42
  REQUEST_EUC_FUNCTION_CALL_NAME = 'adk_request_credential'
43
43
 
44
- logger = logging.getLogger(__name__)
44
+ logger = logging.getLogger('google_adk.' + __name__)
45
45
 
46
46
 
47
47
  def generate_client_function_call_id() -> str:
@@ -106,7 +106,7 @@ def generate_auth_event(
106
106
  args=AuthToolArguments(
107
107
  function_call_id=function_call_id,
108
108
  auth_config=auth_config,
109
- ).model_dump(exclude_none=True),
109
+ ).model_dump(exclude_none=True, by_alias=True),
110
110
  )
111
111
  request_euc_function_call.id = generate_client_function_call_id()
112
112
  long_running_tool_ids.add(request_euc_function_call.id)
@@ -153,22 +153,22 @@ async def handle_function_calls_async(
153
153
  function_args = function_call.args or {}
154
154
  function_response: Optional[dict] = None
155
155
 
156
- # before_tool_callback (sync or async)
157
- if agent.before_tool_callback:
158
- function_response = agent.before_tool_callback(
156
+ for callback in agent.canonical_before_tool_callbacks:
157
+ function_response = callback(
159
158
  tool=tool, args=function_args, tool_context=tool_context
160
159
  )
161
160
  if inspect.isawaitable(function_response):
162
161
  function_response = await function_response
162
+ if function_response:
163
+ break
163
164
 
164
165
  if not function_response:
165
166
  function_response = await __call_tool_async(
166
167
  tool, args=function_args, tool_context=tool_context
167
168
  )
168
169
 
169
- # after_tool_callback (sync or async)
170
- if agent.after_tool_callback:
171
- altered_function_response = agent.after_tool_callback(
170
+ for callback in agent.canonical_after_tool_callbacks:
171
+ altered_function_response = callback(
172
172
  tool=tool,
173
173
  args=function_args,
174
174
  tool_context=tool_context,
@@ -178,6 +178,7 @@ async def handle_function_calls_async(
178
178
  altered_function_response = await altered_function_response
179
179
  if altered_function_response is not None:
180
180
  function_response = altered_function_response
181
+ break
181
182
 
182
183
  if tool.is_long_running:
183
184
  # Allow long running function to return None to not provide function response.
@@ -332,7 +333,7 @@ async def _process_function_live_helper(
332
333
  function_response = {
333
334
  'status': f'No active streaming function named {function_name} found'
334
335
  }
335
- elif hasattr(tool, "func") and inspect.isasyncgenfunction(tool.func):
336
+ elif hasattr(tool, 'func') and inspect.isasyncgenfunction(tool.func):
336
337
  # for streaming tool use case
337
338
  # we require the function to be a async generator function
338
339
  async def run_tool_and_update_queue(tool, function_args, tool_context):
@@ -53,16 +53,24 @@ class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor):
53
53
  if (
54
54
  isinstance(root_agent, LlmAgent) and root_agent.global_instruction
55
55
  ): # not empty str
56
- raw_si = root_agent.canonical_global_instruction(
57
- ReadonlyContext(invocation_context)
56
+ raw_si, bypass_state_injection = (
57
+ await root_agent.canonical_global_instruction(
58
+ ReadonlyContext(invocation_context)
59
+ )
58
60
  )
59
- si = _populate_values(raw_si, invocation_context)
61
+ si = raw_si
62
+ if not bypass_state_injection:
63
+ si = await _populate_values(raw_si, invocation_context)
60
64
  llm_request.append_instructions([si])
61
65
 
62
66
  # Appends agent instructions if set.
63
67
  if agent.instruction: # not empty str
64
- raw_si = agent.canonical_instruction(ReadonlyContext(invocation_context))
65
- si = _populate_values(raw_si, invocation_context)
68
+ raw_si, bypass_state_injection = await agent.canonical_instruction(
69
+ ReadonlyContext(invocation_context)
70
+ )
71
+ si = raw_si
72
+ if not bypass_state_injection:
73
+ si = await _populate_values(raw_si, invocation_context)
66
74
  llm_request.append_instructions([si])
67
75
 
68
76
  # Maintain async generator behavior
@@ -73,13 +81,24 @@ class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor):
73
81
  request_processor = _InstructionsLlmRequestProcessor()
74
82
 
75
83
 
76
- def _populate_values(
84
+ async def _populate_values(
77
85
  instruction_template: str,
78
86
  context: InvocationContext,
79
87
  ) -> str:
80
88
  """Populates values in the instruction template, e.g. state, artifact, etc."""
81
89
 
82
- def _replace_match(match) -> str:
90
+ async def _async_sub(pattern, repl_async_fn, string) -> str:
91
+ result = []
92
+ last_end = 0
93
+ for match in re.finditer(pattern, string):
94
+ result.append(string[last_end : match.start()])
95
+ replacement = await repl_async_fn(match)
96
+ result.append(replacement)
97
+ last_end = match.end()
98
+ result.append(string[last_end:])
99
+ return ''.join(result)
100
+
101
+ async def _replace_match(match) -> str:
83
102
  var_name = match.group().lstrip('{').rstrip('}').strip()
84
103
  optional = False
85
104
  if var_name.endswith('?'):
@@ -89,7 +108,7 @@ def _populate_values(
89
108
  var_name = var_name.removeprefix('artifact.')
90
109
  if context.artifact_service is None:
91
110
  raise ValueError('Artifact service is not initialized.')
92
- artifact = context.artifact_service.load_artifact(
111
+ artifact = await context.artifact_service.load_artifact(
93
112
  app_name=context.session.app_name,
94
113
  user_id=context.session.user_id,
95
114
  session_id=context.session.id,
@@ -109,7 +128,7 @@ def _populate_values(
109
128
  else:
110
129
  raise KeyError(f'Context variable not found: `{var_name}`.')
111
130
 
112
- return re.sub(r'{+[^{}]*}+', _replace_match, instruction_template)
131
+ return await _async_sub(r'{+[^{}]*}+', _replace_match, instruction_template)
113
132
 
114
133
 
115
134
  def _is_valid_state_name(var_name):
@@ -25,7 +25,7 @@ from . import identity
25
25
  from . import instructions
26
26
  from .base_llm_flow import BaseLlmFlow
27
27
 
28
- logger = logging.getLogger(__name__)
28
+ logger = logging.getLogger('google_adk.' + __name__)
29
29
 
30
30
 
31
31
  class SingleFlow(BaseLlmFlow):
@@ -16,7 +16,7 @@ import logging
16
16
  from .base_memory_service import BaseMemoryService
17
17
  from .in_memory_memory_service import InMemoryMemoryService
18
18
 
19
- logger = logging.getLogger(__name__)
19
+ logger = logging.getLogger('google_adk.' + __name__)
20
20
 
21
21
  __all__ = [
22
22
  'BaseMemoryService',
@@ -0,0 +1,23 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from __future__ import annotations
17
+
18
+ from datetime import datetime
19
+
20
+
21
+ def format_timestamp(timestamp: float) -> str:
22
+ """Formats the timestamp of the memory entry."""
23
+ return datetime.fromtimestamp(timestamp).isoformat()
@@ -12,44 +12,44 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- import abc
15
+
16
+ from __future__ import annotations
17
+
18
+ from abc import ABC
19
+ from abc import abstractmethod
20
+ from typing import TYPE_CHECKING
16
21
 
17
22
  from pydantic import BaseModel
18
23
  from pydantic import Field
19
24
 
20
- from ..events.event import Event
21
- from ..sessions.session import Session
25
+ from .memory_entry import MemoryEntry
22
26
 
23
-
24
- class MemoryResult(BaseModel):
25
- """Represents a single memory retrieval result.
26
-
27
- Attributes:
28
- session_id: The session id associated with the memory.
29
- events: A list of events in the session.
30
- """
31
- session_id: str
32
- events: list[Event]
27
+ if TYPE_CHECKING:
28
+ from ..sessions.session import Session
33
29
 
34
30
 
35
31
  class SearchMemoryResponse(BaseModel):
36
32
  """Represents the response from a memory search.
37
33
 
38
34
  Attributes:
39
- memories: A list of memory results matching the search query.
35
+ memories: A list of memory entries that relate to the search query.
40
36
  """
41
- memories: list[MemoryResult] = Field(default_factory=list)
37
+
38
+ memories: list[MemoryEntry] = Field(default_factory=list)
42
39
 
43
40
 
44
- class BaseMemoryService(abc.ABC):
41
+ class BaseMemoryService(ABC):
45
42
  """Base class for memory services.
46
43
 
47
44
  The service provides functionalities to ingest sessions into memory so that
48
45
  the memory can be used for user queries.
49
46
  """
50
47
 
51
- @abc.abstractmethod
52
- def add_session_to_memory(self, session: Session):
48
+ @abstractmethod
49
+ async def add_session_to_memory(
50
+ self,
51
+ session: Session,
52
+ ):
53
53
  """Adds a session to the memory service.
54
54
 
55
55
  A session may be added multiple times during its lifetime.
@@ -58,9 +58,13 @@ class BaseMemoryService(abc.ABC):
58
58
  session: The session to add.
59
59
  """
60
60
 
61
- @abc.abstractmethod
62
- def search_memory(
63
- self, *, app_name: str, user_id: str, query: str
61
+ @abstractmethod
62
+ async def search_memory(
63
+ self,
64
+ *,
65
+ app_name: str,
66
+ user_id: str,
67
+ query: str,
64
68
  ) -> SearchMemoryResponse:
65
69
  """Searches for sessions that match the query.
66
70
 
@@ -0,0 +1,76 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import abc
16
+
17
+ from pydantic import BaseModel
18
+ from pydantic import Field
19
+
20
+ from ..events.event import Event
21
+ from ..sessions.session import Session
22
+
23
+
24
+ class MemoryResult(BaseModel):
25
+ """Represents a single memory retrieval result.
26
+
27
+ Attributes:
28
+ session_id: The session id associated with the memory.
29
+ events: A list of events in the session.
30
+ """
31
+
32
+ session_id: str
33
+ events: list[Event]
34
+
35
+
36
+ class SearchMemoryResponse(BaseModel):
37
+ """Represents the response from a memory search.
38
+
39
+ Attributes:
40
+ memories: A list of memory results matching the search query.
41
+ """
42
+
43
+ memories: list[MemoryResult] = Field(default_factory=list)
44
+
45
+
46
+ class BaseMemoryService(abc.ABC):
47
+ """Base class for memory services.
48
+
49
+ The service provides functionalities to ingest sessions into memory so that
50
+ the memory can be used for user queries.
51
+ """
52
+
53
+ @abc.abstractmethod
54
+ async def add_session_to_memory(self, session: Session):
55
+ """Adds a session to the memory service.
56
+
57
+ A session may be added multiple times during its lifetime.
58
+
59
+ Args:
60
+ session: The session to add.
61
+ """
62
+
63
+ @abc.abstractmethod
64
+ async def search_memory(
65
+ self, *, app_name: str, user_id: str, query: str
66
+ ) -> SearchMemoryResponse:
67
+ """Searches for sessions that match the query.
68
+
69
+ Args:
70
+ app_name: The name of the application.
71
+ user_id: The id of the user.
72
+ query: The query to search for.
73
+
74
+ Returns:
75
+ A SearchMemoryResponse containing the matching memories.
76
+ """
@@ -12,11 +12,31 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from ..events.event import Event
16
- from ..sessions.session import Session
15
+
16
+ from __future__ import annotations
17
+
18
+ import re
19
+ from typing import TYPE_CHECKING
20
+
21
+ from typing_extensions import override
22
+
23
+ from . import _utils
17
24
  from .base_memory_service import BaseMemoryService
18
- from .base_memory_service import MemoryResult
19
25
  from .base_memory_service import SearchMemoryResponse
26
+ from .memory_entry import MemoryEntry
27
+
28
+ if TYPE_CHECKING:
29
+ from ..events.event import Event
30
+ from ..sessions.session import Session
31
+
32
+
33
+ def _user_key(app_name: str, user_id: str):
34
+ return f'{app_name}/{user_id}'
35
+
36
+
37
+ def _extract_words_lower(text: str) -> set[str]:
38
+ """Extracts words from a string and converts them to lowercase."""
39
+ return set([word.lower() for word in re.findall(r'[A-Za-z]+', text)])
20
40
 
21
41
 
22
42
  class InMemoryMemoryService(BaseMemoryService):
@@ -26,37 +46,49 @@ class InMemoryMemoryService(BaseMemoryService):
26
46
  """
27
47
 
28
48
  def __init__(self):
29
- self.session_events: dict[str, list[Event]] = {}
30
- """keys are app_name/user_id/session_id"""
49
+ self._session_events: dict[str, dict[str, list[Event]]] = {}
50
+ """Keys are app_name/user_id, session_id. Values are session event lists."""
31
51
 
32
- def add_session_to_memory(self, session: Session):
33
- key = f'{session.app_name}/{session.user_id}/{session.id}'
34
- self.session_events[key] = [
35
- event for event in session.events if event.content
52
+ @override
53
+ async def add_session_to_memory(self, session: Session):
54
+ user_key = _user_key(session.app_name, session.user_id)
55
+ self._session_events[user_key] = self._session_events.get(
56
+ _user_key(session.app_name, session.user_id), {}
57
+ )
58
+ self._session_events[user_key][session.id] = [
59
+ event
60
+ for event in session.events
61
+ if event.content and event.content.parts
36
62
  ]
37
63
 
38
- def search_memory(
64
+ @override
65
+ async def search_memory(
39
66
  self, *, app_name: str, user_id: str, query: str
40
67
  ) -> SearchMemoryResponse:
41
- """Prototyping purpose only."""
42
- keywords = set(query.lower().split())
68
+ user_key = _user_key(app_name, user_id)
69
+ if user_key not in self._session_events:
70
+ return SearchMemoryResponse()
71
+
72
+ words_in_query = set(query.lower().split())
43
73
  response = SearchMemoryResponse()
44
- for key, events in self.session_events.items():
45
- if not key.startswith(f'{app_name}/{user_id}/'):
46
- continue
47
- matched_events = []
48
- for event in events:
74
+
75
+ for session_events in self._session_events[user_key].values():
76
+ for event in session_events:
49
77
  if not event.content or not event.content.parts:
50
78
  continue
51
- parts = event.content.parts
52
- text = '\n'.join([part.text for part in parts if part.text]).lower()
53
- for keyword in keywords:
54
- if keyword in text:
55
- matched_events.append(event)
56
- break
57
- if matched_events:
58
- session_id = key.split('/')[-1]
59
- response.memories.append(
60
- MemoryResult(session_id=session_id, events=matched_events)
79
+ words_in_event = _extract_words_lower(
80
+ ' '.join([part.text for part in event.content.parts if part.text])
61
81
  )
82
+ if not words_in_event:
83
+ continue
84
+
85
+ if any(query_word in words_in_event for query_word in words_in_query):
86
+ response.memories.append(
87
+ MemoryEntry(
88
+ content=event.content,
89
+ author=event.author,
90
+ timestamp=_utils.format_timestamp(event.timestamp),
91
+ )
92
+ )
93
+
62
94
  return response
@@ -0,0 +1,37 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from __future__ import annotations
17
+
18
+ from typing import Optional
19
+
20
+ from google.genai import types
21
+ from pydantic import BaseModel
22
+
23
+
24
+ class MemoryEntry(BaseModel):
25
+ """Represent one memory entry."""
26
+
27
+ content: types.Content
28
+ """The main content of the memory."""
29
+
30
+ author: Optional[str] = None
31
+ """The author of the memory."""
32
+
33
+ timestamp: Optional[str] = None
34
+ """The timestamp when the original content of this memory happened.
35
+
36
+ This string will be forwarded to LLM. Preferred format is ISO 8601 format.
37
+ """
@@ -12,20 +12,28 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
+
16
+ from __future__ import annotations
17
+
15
18
  from collections import OrderedDict
16
19
  import json
17
20
  import os
18
21
  import tempfile
22
+ from typing import Optional
23
+ from typing import TYPE_CHECKING
19
24
 
20
25
  from google.genai import types
21
26
  from typing_extensions import override
22
27
  from vertexai.preview import rag
23
28
 
24
- from ..events.event import Event
25
- from ..sessions.session import Session
29
+ from . import _utils
26
30
  from .base_memory_service import BaseMemoryService
27
- from .base_memory_service import MemoryResult
28
31
  from .base_memory_service import SearchMemoryResponse
32
+ from .memory_entry import MemoryEntry
33
+
34
+ if TYPE_CHECKING:
35
+ from ..events.event import Event
36
+ from ..sessions.session import Session
29
37
 
30
38
 
31
39
  class VertexAiRagMemoryService(BaseMemoryService):
@@ -33,8 +41,8 @@ class VertexAiRagMemoryService(BaseMemoryService):
33
41
 
34
42
  def __init__(
35
43
  self,
36
- rag_corpus: str = None,
37
- similarity_top_k: int = None,
44
+ rag_corpus: Optional[str] = None,
45
+ similarity_top_k: Optional[int] = None,
38
46
  vector_distance_threshold: float = 10,
39
47
  ):
40
48
  """Initializes a VertexAiRagMemoryService.
@@ -47,14 +55,16 @@ class VertexAiRagMemoryService(BaseMemoryService):
47
55
  vector_distance_threshold: Only returns contexts with vector distance
48
56
  smaller than the threshold..
49
57
  """
50
- self.vertex_rag_store = types.VertexRagStore(
51
- rag_resources=[rag.RagResource(rag_corpus=rag_corpus)],
58
+ self._vertex_rag_store = types.VertexRagStore(
59
+ rag_resources=[
60
+ types.VertexRagStoreRagResource(rag_corpus=rag_corpus),
61
+ ],
52
62
  similarity_top_k=similarity_top_k,
53
63
  vector_distance_threshold=vector_distance_threshold,
54
64
  )
55
65
 
56
66
  @override
57
- def add_session_to_memory(self, session: Session):
67
+ async def add_session_to_memory(self, session: Session):
58
68
  with tempfile.NamedTemporaryFile(
59
69
  mode="w", delete=False, suffix=".txt"
60
70
  ) as temp_file:
@@ -79,7 +89,11 @@ class VertexAiRagMemoryService(BaseMemoryService):
79
89
  output_string = "\n".join(output_lines)
80
90
  temp_file.write(output_string)
81
91
  temp_file_path = temp_file.name
82
- for rag_resource in self.vertex_rag_store.rag_resources:
92
+
93
+ if not self._vertex_rag_store.rag_resources:
94
+ raise ValueError("Rag resources must be set.")
95
+
96
+ for rag_resource in self._vertex_rag_store.rag_resources:
83
97
  rag.upload_file(
84
98
  corpus_name=rag_resource.rag_corpus,
85
99
  path=temp_file_path,
@@ -91,16 +105,18 @@ class VertexAiRagMemoryService(BaseMemoryService):
91
105
  os.remove(temp_file_path)
92
106
 
93
107
  @override
94
- def search_memory(
108
+ async def search_memory(
95
109
  self, *, app_name: str, user_id: str, query: str
96
110
  ) -> SearchMemoryResponse:
97
111
  """Searches for sessions that match the query using rag.retrieval_query."""
112
+ from ..events.event import Event
113
+
98
114
  response = rag.retrieval_query(
99
115
  text=query,
100
- rag_resources=self.vertex_rag_store.rag_resources,
101
- rag_corpora=self.vertex_rag_store.rag_corpora,
102
- similarity_top_k=self.vertex_rag_store.similarity_top_k,
103
- vector_distance_threshold=self.vertex_rag_store.vector_distance_threshold,
116
+ rag_resources=self._vertex_rag_store.rag_resources,
117
+ rag_corpora=self._vertex_rag_store.rag_corpora,
118
+ similarity_top_k=self._vertex_rag_store.similarity_top_k,
119
+ vector_distance_threshold=self._vertex_rag_store.vector_distance_threshold,
104
120
  )
105
121
 
106
122
  memory_results = []
@@ -144,9 +160,16 @@ class VertexAiRagMemoryService(BaseMemoryService):
144
160
  for session_id, event_lists in session_events_map.items():
145
161
  for events in _merge_event_lists(event_lists):
146
162
  sorted_events = sorted(events, key=lambda e: e.timestamp)
147
- memory_results.append(
148
- MemoryResult(session_id=session_id, events=sorted_events)
149
- )
163
+
164
+ memory_results.extend([
165
+ MemoryEntry(
166
+ author=event.author,
167
+ content=event.content,
168
+ timestamp=_utils.format_timestamp(event.timestamp),
169
+ )
170
+ for event in sorted_events
171
+ if event.content
172
+ ])
150
173
  return SearchMemoryResponse(memories=memory_results)
151
174
 
152
175