google-adk 1.4.2__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -65,10 +65,13 @@ from ..evaluation.eval_metrics import EvalMetric
  from ..evaluation.eval_metrics import EvalMetricResult
  from ..evaluation.eval_metrics import EvalMetricResultPerInvocation
  from ..evaluation.eval_result import EvalSetResult
+ from ..evaluation.gcs_eval_set_results_manager import GcsEvalSetResultsManager
+ from ..evaluation.gcs_eval_sets_manager import GcsEvalSetsManager
  from ..evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager
  from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager
  from ..events.event import Event
  from ..memory.in_memory_memory_service import InMemoryMemoryService
+ from ..memory.vertex_ai_memory_bank_service import VertexAiMemoryBankService
  from ..memory.vertex_ai_rag_memory_service import VertexAiRagMemoryService
  from ..runners import Runner
  from ..sessions.database_session_service import DatabaseSessionService
@@ -198,6 +201,7 @@ def get_fast_api_app(
    session_service_uri: Optional[str] = None,
    artifact_service_uri: Optional[str] = None,
    memory_service_uri: Optional[str] = None,
+   eval_storage_uri: Optional[str] = None,
    allow_origins: Optional[list[str]] = None,
    web: bool,
    trace_to_cloud: bool = False,
@@ -256,8 +260,18 @@ def get_fast_api_app(

    runner_dict = {}

-   eval_sets_manager = LocalEvalSetsManager(agents_dir=agents_dir)
-   eval_set_results_manager = LocalEvalSetResultsManager(agents_dir=agents_dir)
+   # Set up eval managers.
+   eval_sets_manager = None
+   eval_set_results_manager = None
+   if eval_storage_uri:
+     gcs_eval_managers = evals.create_gcs_eval_managers_from_uri(
+         eval_storage_uri
+     )
+     eval_sets_manager = gcs_eval_managers.eval_sets_manager
+     eval_set_results_manager = gcs_eval_managers.eval_set_results_manager
+   else:
+     eval_sets_manager = LocalEvalSetsManager(agents_dir=agents_dir)
+     eval_set_results_manager = LocalEvalSetResultsManager(agents_dir=agents_dir)

    # Build the Memory service
    if memory_service_uri:
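Together with the new `eval_storage_uri` parameter above, this lets eval sets and eval results live in Cloud Storage instead of the local agents directory. A minimal sketch of the wiring, assuming `GOOGLE_CLOUD_PROJECT` is set; the bucket name and agents path are placeholders:

```python
from google.adk.cli.fast_api import get_fast_api_app

# Hypothetical configuration: with eval_storage_uri unset, the Local*
# managers above are used; with a gs:// URI, the GCS-backed ones are.
app = get_fast_api_app(
    agents_dir="./agents",
    eval_storage_uri="gs://my-adk-evals",  # placeholder bucket
    web=True,
)
```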
@@ -269,6 +283,16 @@ def get_fast_api_app(
      memory_service = VertexAiRagMemoryService(
          rag_corpus=f'projects/{os.environ["GOOGLE_CLOUD_PROJECT"]}/locations/{os.environ["GOOGLE_CLOUD_LOCATION"]}/ragCorpora/{rag_corpus}'
      )
+   elif memory_service_uri.startswith("agentengine://"):
+     agent_engine_id = memory_service_uri.split("://")[1]
+     if not agent_engine_id:
+       raise click.ClickException("Agent engine id can not be empty.")
+     envs.load_dotenv_for_agent("", agents_dir)
+     memory_service = VertexAiMemoryBankService(
+         project=os.environ["GOOGLE_CLOUD_PROJECT"],
+         location=os.environ["GOOGLE_CLOUD_LOCATION"],
+         agent_engine_id=agent_engine_id,
+     )
    else:
      raise click.ClickException(
          "Unsupported memory service URI: %s" % memory_service_uri
@@ -14,17 +14,36 @@

  from __future__ import annotations

+ import dataclasses
+ import os
  from typing import Any
  from typing import Tuple

  from google.genai import types as genai_types
+ from pydantic import alias_generators
+ from pydantic import BaseModel
+ from pydantic import ConfigDict
  from typing_extensions import deprecated

  from ...evaluation.eval_case import IntermediateData
  from ...evaluation.eval_case import Invocation
+ from ...evaluation.gcs_eval_set_results_manager import GcsEvalSetResultsManager
+ from ...evaluation.gcs_eval_sets_manager import GcsEvalSetsManager
  from ...sessions.session import Session


+ class GcsEvalManagers(BaseModel):
+   model_config = ConfigDict(
+       alias_generator=alias_generators.to_camel,
+       populate_by_name=True,
+       arbitrary_types_allowed=True,
+   )
+
+   eval_sets_manager: GcsEvalSetsManager
+
+   eval_set_results_manager: GcsEvalSetResultsManager
+
+
  @deprecated('Use convert_session_to_eval_invocations instead.')
  def convert_session_to_eval_format(session: Session) -> list[dict[str, Any]]:
    """Converts a session data into eval format.
@@ -176,3 +195,37 @@ def convert_session_to_eval_invocations(session: Session) -> list[Invocation]:
    )

  return invocations
+
+
+ def create_gcs_eval_managers_from_uri(
+     eval_storage_uri: str,
+ ) -> GcsEvalManagers:
+   """Creates GcsEvalManagers from eval_storage_uri.
+
+   Args:
+     eval_storage_uri: The evals storage URI to use. Supported URIs:
+       gs://<bucket name>. If a path is provided, the bucket will be extracted.
+
+   Returns:
+     GcsEvalManagers: The GcsEvalManagers object.
+
+   Raises:
+     ValueError: If the eval_storage_uri is not supported.
+   """
+   if eval_storage_uri.startswith('gs://'):
+     gcs_bucket = eval_storage_uri.split('://')[1]
+     eval_sets_manager = GcsEvalSetsManager(
+         bucket_name=gcs_bucket, project=os.environ['GOOGLE_CLOUD_PROJECT']
+     )
+     eval_set_results_manager = GcsEvalSetResultsManager(
+         bucket_name=gcs_bucket, project=os.environ['GOOGLE_CLOUD_PROJECT']
+     )
+     return GcsEvalManagers(
+         eval_sets_manager=eval_sets_manager,
+         eval_set_results_manager=eval_set_results_manager,
+     )
+   else:
+     raise ValueError(
+         f'Unsupported evals storage URI: {eval_storage_uri}. Supported URIs:'
+         ' gs://<bucket name>'
+     )
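A quick usage sketch for the new helper. The module path (`google.adk.cli.utils.evals`) is inferred from the relative imports above, and the bucket is a placeholder; note the constructors read `GOOGLE_CLOUD_PROJECT` directly and raise `KeyError` if it is unset:

```python
import os

from google.adk.cli.utils import evals

os.environ.setdefault("GOOGLE_CLOUD_PROJECT", "my-project")  # placeholder

managers = evals.create_gcs_eval_managers_from_uri("gs://my-eval-bucket")
sets_manager = managers.eval_sets_manager            # GcsEvalSetsManager
results_manager = managers.eval_set_results_manager  # GcsEvalSetResultsManager
```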
@@ -0,0 +1,110 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from __future__ import annotations
+
+ from typing import Optional
+
+ from google.genai import types as genai_types
+ from rouge_score import rouge_scorer
+ from typing_extensions import override
+
+ from .eval_case import Invocation
+ from .eval_metrics import EvalMetric
+ from .evaluator import EvalStatus
+ from .evaluator import EvaluationResult
+ from .evaluator import Evaluator
+ from .evaluator import PerInvocationResult
+
+
+ class RougeEvaluator(Evaluator):
+   """Calculates the ROUGE-1 metric to compare responses."""
+
+   def __init__(self, eval_metric: EvalMetric):
+     self._eval_metric = eval_metric
+
+   @override
+   def evaluate_invocations(
+       self,
+       actual_invocations: list[Invocation],
+       expected_invocations: list[Invocation],
+   ) -> EvaluationResult:
+     total_score = 0.0
+     num_invocations = 0
+     per_invocation_results = []
+     for actual, expected in zip(actual_invocations, expected_invocations):
+       reference = _get_text_from_content(expected.final_response)
+       response = _get_text_from_content(actual.final_response)
+       rouge_1_scores = _calculate_rouge_1_scores(response, reference)
+       score = rouge_1_scores.fmeasure
+       per_invocation_results.append(
+           PerInvocationResult(
+               actual_invocation=actual,
+               expected_invocation=expected,
+               score=score,
+               eval_status=_get_eval_status(score, self._eval_metric.threshold),
+           )
+       )
+       total_score += score
+       num_invocations += 1
+
+     if per_invocation_results:
+       overall_score = total_score / num_invocations
+       return EvaluationResult(
+           overall_score=overall_score,
+           overall_eval_status=_get_eval_status(
+               overall_score, self._eval_metric.threshold
+           ),
+           per_invocation_results=per_invocation_results,
+       )
+
+     return EvaluationResult()
+
+
+ def _get_text_from_content(content: Optional[genai_types.Content]) -> str:
+   if content and content.parts:
+     return "\n".join([part.text for part in content.parts if part.text])
+
+   return ""
+
+
+ def _get_eval_status(score: float, threshold: float):
+   return EvalStatus.PASSED if score >= threshold else EvalStatus.FAILED
+
+
+ def _calculate_rouge_1_scores(candidate: str, reference: str):
+   """Calculates the ROUGE-1 score between a candidate and reference text.
+
+   ROUGE-1 measures the overlap of unigrams (single words) between the
+   candidate and reference texts. The score is broken down into:
+   - Precision: The proportion of unigrams in the candidate that are also in
+     the reference.
+   - Recall: The proportion of unigrams in the reference that are also in the
+     candidate.
+   - F-measure: The harmonic mean of precision and recall.
+
+   Args:
+     candidate: The generated text to be evaluated.
+     reference: The ground-truth text to compare against.
+
+   Returns:
+     The ROUGE-1 Score object with precision, recall, and fmeasure fields.
+   """
+   scorer = rouge_scorer.RougeScorer(["rouge1"], use_stemmer=True)
+
+   # The score method returns a dictionary where keys are the ROUGE types
+   # and values are Score objects (tuples) with precision, recall, and fmeasure.
+   scores = scorer.score(reference, candidate)
+
+   return scores["rouge1"]
@@ -72,6 +72,13 @@ class GcsEvalSetsManager(EvalSetsManager):
          f"Invalid {id_name}. {id_name} should have the `{pattern}` format",
      )

+   def _load_eval_set_from_blob(self, blob_name: str) -> Optional[EvalSet]:
+     blob = self.bucket.blob(blob_name)
+     if not blob.exists():
+       return None
+     eval_set_data = blob.download_as_text()
+     return EvalSet.model_validate_json(eval_set_data)
+
    def _write_eval_set_to_blob(self, blob_name: str, eval_set: EvalSet):
      """Writes an EvalSet to GCS."""
      blob = self.bucket.blob(blob_name)
@@ -88,11 +95,7 @@ class GcsEvalSetsManager(EvalSetsManager):
    def get_eval_set(self, app_name: str, eval_set_id: str) -> Optional[EvalSet]:
      """Returns an EvalSet identified by an app_name and eval_set_id."""
      eval_set_blob_name = self._get_eval_set_blob_name(app_name, eval_set_id)
-     blob = self.bucket.blob(eval_set_blob_name)
-     if not blob.exists():
-       return None
-     eval_set_data = blob.download_as_text()
-     return EvalSet.model_validate_json(eval_set_data)
+     return self._load_eval_set_from_blob(eval_set_blob_name)

    @override
    def create_eval_set(self, app_name: str, eval_set_id: str):
@@ -27,10 +27,12 @@ from vertexai.preview.evaluation import MetricPromptTemplateExamples

  from .eval_case import IntermediateData
  from .eval_case import Invocation
+ from .eval_metrics import EvalMetric
  from .evaluator import EvalStatus
  from .evaluator import EvaluationResult
  from .evaluator import Evaluator
  from .evaluator import PerInvocationResult
+ from .final_response_match_v1 import RougeEvaluator


  class ResponseEvaluator(Evaluator):
@@ -40,7 +42,7 @@ class ResponseEvaluator(Evaluator):
      if "response_evaluation_score" == metric_name:
        self._metric_name = MetricPromptTemplateExamples.Pointwise.COHERENCE
      elif "response_match_score" == metric_name:
-       self._metric_name = "rouge_1"
+       self._metric_name = "response_match_score"
      else:
        raise ValueError(f"`{metric_name}` is not supported.")

@@ -52,6 +54,15 @@ class ResponseEvaluator(Evaluator):
        actual_invocations: list[Invocation],
        expected_invocations: list[Invocation],
    ) -> EvaluationResult:
+     # If the metric is response_match_score, just use the RougeEvaluator.
+     if self._metric_name == "response_match_score":
+       rouge_evaluator = RougeEvaluator(
+           EvalMetric(metric_name=self._metric_name, threshold=self._threshold)
+       )
+       return rouge_evaluator.evaluate_invocations(
+           actual_invocations, expected_invocations
+       )
+
      total_score = 0.0
      num_invocations = 0
      per_invocation_results = []
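The practical effect: `response_match_score` is now computed locally via ROUGE-1 rather than through the Vertex AI evaluation service's `rouge_1` metric. A sketch, with the constructor shape assumed from the `self._metric_name` and `self._threshold` references above:

```python
from google.adk.evaluation.response_evaluator import ResponseEvaluator

# Assumed signature; evaluate_invocations(...) now short-circuits into
# RougeEvaluator before any Vertex AI call is made.
evaluator = ResponseEvaluator(threshold=0.8, metric_name="response_match_score")
```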
@@ -34,9 +34,10 @@ class Event(LlmResponse):
    taken by the agents like function calls, etc.

    Attributes:
-     invocation_id: The invocation ID of the event.
-     author: "user" or the name of the agent, indicating who appended the event
-       to the session.
+     invocation_id: Required. The invocation ID of the event. Should be non-empty
+       before appending to a session.
+     author: Required. "user" or the name of the agent, indicating who appended
+       the event to the session.
      actions: The actions taken by the agent.
      long_running_tool_ids: The ids of the long running function calls.
      branch: The branch of the event.
@@ -55,9 +56,8 @@ class Event(LlmResponse):
    )
    """The pydantic model config."""

-   # TODO: revert to be required after spark migration
    invocation_id: str = ''
-   """The invocation ID of the event."""
+   """The invocation ID of the event. Should be non-empty before appending to a session."""
    author: str
    """'user' or the name of the agent, indicating who appended the event to the
    session."""
@@ -43,12 +43,20 @@ class _ContentLlmRequestProcessor(BaseLlmRequestProcessor):
      if not isinstance(agent, LlmAgent):
        return

-     if agent.include_contents != 'none':
+     if agent.include_contents == 'default':
+       # Include full conversation history
        llm_request.contents = _get_contents(
            invocation_context.branch,
            invocation_context.session.events,
            agent.name,
        )
+     else:
+       # Include current turn context only (no conversation history)
+       llm_request.contents = _get_current_turn_contents(
+           invocation_context.branch,
+           invocation_context.session.events,
+           agent.name,
+       )

      # Maintain async generator behavior
      if False:  # Ensures it behaves as a generator
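These branches key off the agent's `include_contents` field; `'none'` now means "current turn only" rather than "no contents at all". A hedged configuration sketch with placeholder model, name, and instruction:

```python
from google.adk.agents import LlmAgent

# The latest user input and this turn's tool calls/responses are still
# sent to the model; only prior-turn history is dropped.
stateless_agent = LlmAgent(
    model="gemini-2.0-flash",  # placeholder
    name="stateless_helper",
    instruction="Answer using only the current request.",
    include_contents="none",
)
```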
@@ -190,13 +198,15 @@ def _get_contents(
  ) -> list[types.Content]:
    """Get the contents for the LLM request.

+   Applies filtering, rearrangement, and content processing to events.
+
    Args:
      current_branch: The current branch of the agent.
-     events: A list of events.
+     events: Events to process.
      agent_name: The name of the agent.

    Returns:
-     A list of contents.
+     A list of processed contents.
    """
    filtered_events = []
    # Parse the events, leaving the contents and the function calls and
@@ -211,12 +221,13 @@ def _get_contents(
        # Skip events without content, or generated neither by user nor by model
        # or has empty text.
        # E.g. events purely for mutating session states.
+
        continue
      if not _is_event_belongs_to_branch(current_branch, event):
        # Skip events not belong to current branch.
        continue
      if _is_auth_event(event):
-       # skip auth event
+       # Skip auth events.
        continue
      filtered_events.append(
          _convert_foreign_event(event)
@@ -224,12 +235,15 @@ def _get_contents(
          else event
      )

+   # Rearrange events for proper function call/response pairing
    result_events = _rearrange_events_for_latest_function_response(
        filtered_events
    )
    result_events = _rearrange_events_for_async_function_responses_in_history(
        result_events
    )
+
+   # Convert events to contents
    contents = []
    for event in result_events:
      content = copy.deepcopy(event.content)
@@ -238,6 +252,37 @@ def _get_contents(
    return contents


+ def _get_current_turn_contents(
+     current_branch: Optional[str], events: list[Event], agent_name: str = ''
+ ) -> list[types.Content]:
+   """Get contents for the current turn only (no conversation history).
+
+   When include_contents='none', we want to include:
+   - The current user input
+   - Tool calls and responses from the current turn
+   But exclude conversation history from previous turns.
+
+   In multi-agent scenarios, the "current turn" for an agent starts from an
+   actual user or from another agent.
+
+   Args:
+     current_branch: The current branch of the agent.
+     events: A list of all session events.
+     agent_name: The name of the agent.
+
+   Returns:
+     A list of contents for the current turn only, preserving context needed
+     for proper tool execution while excluding conversation history.
+   """
+   # Find the latest event that starts the current turn and process from there
+   for i in range(len(events) - 1, -1, -1):
+     event = events[i]
+     if event.author == 'user' or _is_other_agent_reply(agent_name, event):
+       return _get_contents(current_branch, events[i:], agent_name)
+
+   return []
+
+
  def _is_other_agent_reply(current_agent_name: str, event: Event) -> bool:
    """Whether the event is a reply from another agent."""
    return bool(
@@ -519,3 +519,35 @@ def merge_parallel_function_response_events(
    # Use the base_event as the timestamp
    merged_event.timestamp = base_event.timestamp
    return merged_event
+
+
+ def find_matching_function_call(
+     events: list[Event],
+ ) -> Optional[Event]:
+   """Finds the function call event that matches the function response id of the last event."""
+   if not events:
+     return None
+
+   last_event = events[-1]
+   if (
+       last_event.content
+       and last_event.content.parts
+       and any(part.function_response for part in last_event.content.parts)
+   ):
+     function_call_id = next(
+         part.function_response.id
+         for part in last_event.content.parts
+         if part.function_response
+     )
+     for i in range(len(events) - 2, -1, -1):
+       event = events[i]
+       # Looking for the system long-running request EUC function call.
+       function_calls = event.get_function_calls()
+       if not function_calls:
+         continue
+
+       for function_call in function_calls:
+         if function_call.id == function_call_id:
+           return event
+   return None
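A hedged illustration of the lookup; the minimal `Event` construction below is an assumption (only `author` and `content` are populated), and `find_matching_function_call` refers to the function defined in the hunk above:

```python
from google.adk.events.event import Event
from google.genai import types

call_event = Event(
    author="my_agent",
    content=types.Content(parts=[types.Part(
        function_call=types.FunctionCall(id="fc-1", name="lookup", args={})
    )]),
)
response_event = Event(
    author="user",
    content=types.Content(parts=[types.Part(
        function_response=types.FunctionResponse(id="fc-1", name="lookup", response={})
    )]),
)

# Walks backwards from the last event's function_response id to the
# event that issued the matching function_call.
assert find_matching_function_call([call_event, response_event]) is call_event
```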
@@ -15,12 +15,14 @@ import logging

  from .base_memory_service import BaseMemoryService
  from .in_memory_memory_service import InMemoryMemoryService
+ from .vertex_ai_memory_bank_service import VertexAiMemoryBankService

  logger = logging.getLogger('google_adk.' + __name__)

  __all__ = [
      'BaseMemoryService',
      'InMemoryMemoryService',
+     'VertexAiMemoryBankService',
  ]

  try:
@@ -29,7 +31,7 @@ try:
    __all__.append('VertexAiRagMemoryService')
  except ImportError:
    logger.debug(
-       'The Vertex sdk is not installed. If you want to use the'
+       'The Vertex SDK is not installed. If you want to use the'
        ' VertexAiRagMemoryService please install it. If not, you can ignore this'
        ' warning.'
    )
@@ -0,0 +1,150 @@
+ # Copyright 2025 Google LLC
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from __future__ import annotations
+
+ import json
+ import logging
+ from typing import Optional
+ from typing import TYPE_CHECKING
+
+ from typing_extensions import override
+
+ from google import genai
+
+ from .base_memory_service import BaseMemoryService
+ from .base_memory_service import SearchMemoryResponse
+ from .memory_entry import MemoryEntry
+
+ if TYPE_CHECKING:
+   from ..sessions.session import Session
+
+ logger = logging.getLogger('google_adk.' + __name__)
+
+
+ class VertexAiMemoryBankService(BaseMemoryService):
+   """Implementation of the BaseMemoryService using Vertex AI Memory Bank."""
+
+   def __init__(
+       self,
+       project: Optional[str] = None,
+       location: Optional[str] = None,
+       agent_engine_id: Optional[str] = None,
+   ):
+     """Initializes a VertexAiMemoryBankService.
+
+     Args:
+       project: The project ID of the Memory Bank to use.
+       location: The location of the Memory Bank to use.
+       agent_engine_id: The ID of the agent engine to use for the Memory Bank,
+         e.g. '456' in
+         'projects/my-project/locations/us-central1/reasoningEngines/456'.
+     """
+     self._project = project
+     self._location = location
+     self._agent_engine_id = agent_engine_id
+
+   @override
+   async def add_session_to_memory(self, session: Session):
+     api_client = self._get_api_client()
+
+     if not self._agent_engine_id:
+       raise ValueError('Agent Engine ID is required for Memory Bank.')
+
+     events = []
+     for event in session.events:
+       if event.content and event.content.parts:
+         events.append({
+             'content': event.content.model_dump(exclude_none=True, mode='json')
+         })
+     request_dict = {
+         'direct_contents_source': {
+             'events': events,
+         },
+         'scope': {
+             'app_name': session.app_name,
+             'user_id': session.user_id,
+         },
+     }
+
+     if events:
+       api_response = await api_client.async_request(
+           http_method='POST',
+           path=f'reasoningEngines/{self._agent_engine_id}/memories:generate',
+           request_dict=request_dict,
+       )
+       logger.info(f'Generate memory response: {api_response}')
+     else:
+       logger.info('No events to add to memory.')
+
+   @override
+   async def search_memory(self, *, app_name: str, user_id: str, query: str):
+     api_client = self._get_api_client()
+
+     api_response = await api_client.async_request(
+         http_method='POST',
+         path=f'reasoningEngines/{self._agent_engine_id}/memories:retrieve',
+         request_dict={
+             'scope': {
+                 'app_name': app_name,
+                 'user_id': user_id,
+             },
+             'similarity_search_params': {
+                 'search_query': query,
+             },
+         },
+     )
+     api_response = _convert_api_response(api_response)
+     logger.info(f'Search memory response: {api_response}')
+
+     if not api_response or not api_response.get('retrievedMemories', None):
+       return SearchMemoryResponse()
+
+     memory_events = []
+     for memory in api_response.get('retrievedMemories', []):
+       # TODO: add more complex error handling
+       memory_events.append(
+           MemoryEntry(
+               author='user',
+               content=genai.types.Content(
+                   parts=[
+                       genai.types.Part(text=memory.get('memory').get('fact'))
+                   ],
+                   role='user',
+               ),
+               timestamp=memory.get('updateTime'),
+           )
+       )
+     return SearchMemoryResponse(memories=memory_events)
+
+   def _get_api_client(self):
+     """Instantiates an API client for the given project and location.
+
+     It needs to be instantiated inside each request so that the event loop
+     management can be properly propagated.
+
+     Returns:
+       An API client for the given project and location.
+     """
+     client = genai.Client(
+         vertexai=True, project=self._project, location=self._location
+     )
+     return client._api_client
+
+
+ def _convert_api_response(api_response):
+   """Converts the API response to a JSON object based on the type."""
+   if hasattr(api_response, 'body'):
+     return json.loads(api_response.body)
+   return api_response
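A hedged sketch of the service used directly; project, location, and engine ID are placeholders, and Application Default Credentials with Vertex AI access are assumed:

```python
import asyncio

from google.adk.memory import VertexAiMemoryBankService


async def main():
  service = VertexAiMemoryBankService(
      project="my-project",    # placeholder
      location="us-central1",  # placeholder
      agent_engine_id="456",   # placeholder reasoningEngines ID
  )
  response = await service.search_memory(
      app_name="my_app", user_id="user_1", query="food preferences"
  )
  for memory in response.memories:
    print(memory.content.parts[0].text)


asyncio.run(main())
```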
@@ -29,6 +29,7 @@ from typing import Tuple
  from typing import Union

  from google.genai import types
+ import litellm
  from litellm import acompletion
  from litellm import ChatCompletionAssistantMessage
  from litellm import ChatCompletionAssistantToolCall
@@ -53,6 +54,9 @@ from .base_llm import BaseLlm
  from .llm_request import LlmRequest
  from .llm_response import LlmResponse

+ # This will add functions to prompts if functions are provided.
+ litellm.add_function_to_prompt = True
+
  logger = logging.getLogger("google_adk." + __name__)

  _NEW_LINE = "\n"
@@ -662,6 +666,10 @@ class LiteLlm(BaseLlm):

      messages, tools, response_format = _get_completion_inputs(llm_request)

+     if "functions" in self._additional_args:
+       # LiteLLM does not support both tools and functions together.
+       tools = None
+
      completion_args = {
          "model": self.model,
          "messages": messages,
@@ -679,7 +687,7 @@ class LiteLlm(BaseLlm):
      aggregated_llm_response_with_tool_call = None
      usage_metadata = None
      fallback_index = 0
-     for part in self.llm_client.completion(**completion_args):
+     async for part in await self.llm_client.acompletion(**completion_args):
        for chunk, finish_reason in _model_response_to_chunk(part):
          if isinstance(chunk, FunctionChunk):
            index = chunk.index or fallback_index