google-adk 1.4.2__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/adk/a2a/converters/request_converter.py +90 -0
- google/adk/a2a/converters/utils.py +37 -0
- google/adk/agents/llm_agent.py +5 -3
- google/adk/artifacts/gcs_artifact_service.py +3 -2
- google/adk/auth/auth_tool.py +2 -2
- google/adk/auth/credential_service/session_state_credential_service.py +83 -0
- google/adk/cli/cli_deploy.py +9 -2
- google/adk/cli/cli_tools_click.py +110 -52
- google/adk/cli/fast_api.py +26 -2
- google/adk/cli/utils/evals.py +53 -0
- google/adk/evaluation/final_response_match_v1.py +110 -0
- google/adk/evaluation/gcs_eval_sets_manager.py +8 -5
- google/adk/evaluation/response_evaluator.py +12 -1
- google/adk/events/event.py +5 -5
- google/adk/flows/llm_flows/contents.py +49 -4
- google/adk/flows/llm_flows/functions.py +32 -0
- google/adk/memory/__init__.py +3 -1
- google/adk/memory/vertex_ai_memory_bank_service.py +150 -0
- google/adk/models/lite_llm.py +9 -1
- google/adk/runners.py +10 -0
- google/adk/sessions/vertex_ai_session_service.py +53 -18
- google/adk/telemetry.py +10 -0
- google/adk/version.py +1 -1
- {google_adk-1.4.2.dist-info → google_adk-1.5.0.dist-info}/METADATA +6 -5
- {google_adk-1.4.2.dist-info → google_adk-1.5.0.dist-info}/RECORD +28 -24
- {google_adk-1.4.2.dist-info → google_adk-1.5.0.dist-info}/WHEEL +0 -0
- {google_adk-1.4.2.dist-info → google_adk-1.5.0.dist-info}/entry_points.txt +0 -0
- {google_adk-1.4.2.dist-info → google_adk-1.5.0.dist-info}/licenses/LICENSE +0 -0
google/adk/cli/fast_api.py
CHANGED
@@ -65,10 +65,13 @@ from ..evaluation.eval_metrics import EvalMetric
|
|
65
65
|
from ..evaluation.eval_metrics import EvalMetricResult
|
66
66
|
from ..evaluation.eval_metrics import EvalMetricResultPerInvocation
|
67
67
|
from ..evaluation.eval_result import EvalSetResult
|
68
|
+
from ..evaluation.gcs_eval_set_results_manager import GcsEvalSetResultsManager
|
69
|
+
from ..evaluation.gcs_eval_sets_manager import GcsEvalSetsManager
|
68
70
|
from ..evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager
|
69
71
|
from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager
|
70
72
|
from ..events.event import Event
|
71
73
|
from ..memory.in_memory_memory_service import InMemoryMemoryService
|
74
|
+
from ..memory.vertex_ai_memory_bank_service import VertexAiMemoryBankService
|
72
75
|
from ..memory.vertex_ai_rag_memory_service import VertexAiRagMemoryService
|
73
76
|
from ..runners import Runner
|
74
77
|
from ..sessions.database_session_service import DatabaseSessionService
|
@@ -198,6 +201,7 @@ def get_fast_api_app(
|
|
198
201
|
session_service_uri: Optional[str] = None,
|
199
202
|
artifact_service_uri: Optional[str] = None,
|
200
203
|
memory_service_uri: Optional[str] = None,
|
204
|
+
eval_storage_uri: Optional[str] = None,
|
201
205
|
allow_origins: Optional[list[str]] = None,
|
202
206
|
web: bool,
|
203
207
|
trace_to_cloud: bool = False,
|
@@ -256,8 +260,18 @@ def get_fast_api_app(
|
|
256
260
|
|
257
261
|
runner_dict = {}
|
258
262
|
|
259
|
-
|
260
|
-
|
263
|
+
# Set up eval managers.
|
264
|
+
eval_sets_manager = None
|
265
|
+
eval_set_results_manager = None
|
266
|
+
if eval_storage_uri:
|
267
|
+
gcs_eval_managers = evals.create_gcs_eval_managers_from_uri(
|
268
|
+
eval_storage_uri
|
269
|
+
)
|
270
|
+
eval_sets_manager = gcs_eval_managers.eval_sets_manager
|
271
|
+
eval_set_results_manager = gcs_eval_managers.eval_set_results_manager
|
272
|
+
else:
|
273
|
+
eval_sets_manager = LocalEvalSetsManager(agents_dir=agents_dir)
|
274
|
+
eval_set_results_manager = LocalEvalSetResultsManager(agents_dir=agents_dir)
|
261
275
|
|
262
276
|
# Build the Memory service
|
263
277
|
if memory_service_uri:
|
@@ -269,6 +283,16 @@ def get_fast_api_app(
|
|
269
283
|
memory_service = VertexAiRagMemoryService(
|
270
284
|
rag_corpus=f'projects/{os.environ["GOOGLE_CLOUD_PROJECT"]}/locations/{os.environ["GOOGLE_CLOUD_LOCATION"]}/ragCorpora/{rag_corpus}'
|
271
285
|
)
|
286
|
+
elif memory_service_uri.startswith("agentengine://"):
|
287
|
+
agent_engine_id = memory_service_uri.split("://")[1]
|
288
|
+
if not agent_engine_id:
|
289
|
+
raise click.ClickException("Agent engine id can not be empty.")
|
290
|
+
envs.load_dotenv_for_agent("", agents_dir)
|
291
|
+
memory_service = VertexAiMemoryBankService(
|
292
|
+
project=os.environ["GOOGLE_CLOUD_PROJECT"],
|
293
|
+
location=os.environ["GOOGLE_CLOUD_LOCATION"],
|
294
|
+
agent_engine_id=agent_engine_id,
|
295
|
+
)
|
272
296
|
else:
|
273
297
|
raise click.ClickException(
|
274
298
|
"Unsupported memory service URI: %s" % memory_service_uri
|
google/adk/cli/utils/evals.py
CHANGED
@@ -14,17 +14,36 @@
|
|
14
14
|
|
15
15
|
from __future__ import annotations
|
16
16
|
|
17
|
+
import dataclasses
|
18
|
+
import os
|
17
19
|
from typing import Any
|
18
20
|
from typing import Tuple
|
19
21
|
|
20
22
|
from google.genai import types as genai_types
|
23
|
+
from pydantic import alias_generators
|
24
|
+
from pydantic import BaseModel
|
25
|
+
from pydantic import ConfigDict
|
21
26
|
from typing_extensions import deprecated
|
22
27
|
|
23
28
|
from ...evaluation.eval_case import IntermediateData
|
24
29
|
from ...evaluation.eval_case import Invocation
|
30
|
+
from ...evaluation.gcs_eval_set_results_manager import GcsEvalSetResultsManager
|
31
|
+
from ...evaluation.gcs_eval_sets_manager import GcsEvalSetsManager
|
25
32
|
from ...sessions.session import Session
|
26
33
|
|
27
34
|
|
35
|
+
class GcsEvalManagers(BaseModel):
|
36
|
+
model_config = ConfigDict(
|
37
|
+
alias_generator=alias_generators.to_camel,
|
38
|
+
populate_by_name=True,
|
39
|
+
arbitrary_types_allowed=True,
|
40
|
+
)
|
41
|
+
|
42
|
+
eval_sets_manager: GcsEvalSetsManager
|
43
|
+
|
44
|
+
eval_set_results_manager: GcsEvalSetResultsManager
|
45
|
+
|
46
|
+
|
28
47
|
@deprecated('Use convert_session_to_eval_invocations instead.')
|
29
48
|
def convert_session_to_eval_format(session: Session) -> list[dict[str, Any]]:
|
30
49
|
"""Converts a session data into eval format.
|
@@ -176,3 +195,37 @@ def convert_session_to_eval_invocations(session: Session) -> list[Invocation]:
|
|
176
195
|
)
|
177
196
|
|
178
197
|
return invocations
|
198
|
+
|
199
|
+
|
200
|
+
def create_gcs_eval_managers_from_uri(
|
201
|
+
eval_storage_uri: str,
|
202
|
+
) -> GcsEvalManagers:
|
203
|
+
"""Creates GcsEvalManagers from eval_storage_uri.
|
204
|
+
|
205
|
+
Args:
|
206
|
+
eval_storage_uri: The evals storage URI to use. Supported URIs:
|
207
|
+
gs://<bucket name>. If a path is provided, the bucket will be extracted.
|
208
|
+
|
209
|
+
Returns:
|
210
|
+
GcsEvalManagers: The GcsEvalManagers object.
|
211
|
+
|
212
|
+
Raises:
|
213
|
+
ValueError: If the eval_storage_uri is not supported.
|
214
|
+
"""
|
215
|
+
if eval_storage_uri.startswith('gs://'):
|
216
|
+
gcs_bucket = eval_storage_uri.split('://')[1]
|
217
|
+
eval_sets_manager = GcsEvalSetsManager(
|
218
|
+
bucket_name=gcs_bucket, project=os.environ['GOOGLE_CLOUD_PROJECT']
|
219
|
+
)
|
220
|
+
eval_set_results_manager = GcsEvalSetResultsManager(
|
221
|
+
bucket_name=gcs_bucket, project=os.environ['GOOGLE_CLOUD_PROJECT']
|
222
|
+
)
|
223
|
+
return GcsEvalManagers(
|
224
|
+
eval_sets_manager=eval_sets_manager,
|
225
|
+
eval_set_results_manager=eval_set_results_manager,
|
226
|
+
)
|
227
|
+
else:
|
228
|
+
raise ValueError(
|
229
|
+
f'Unsupported evals storage URI: {eval_storage_uri}. Supported URIs:'
|
230
|
+
' gs://<bucket name>'
|
231
|
+
)
|
google/adk/evaluation/final_response_match_v1.py
ADDED
@@ -0,0 +1,110 @@
|
|
1
|
+
# Copyright 2025 Google LLC
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from __future__ import annotations
|
16
|
+
|
17
|
+
from typing import Optional
|
18
|
+
|
19
|
+
from google.genai import types as genai_types
|
20
|
+
from rouge_score import rouge_scorer
|
21
|
+
from typing_extensions import override
|
22
|
+
|
23
|
+
from .eval_case import Invocation
|
24
|
+
from .eval_metrics import EvalMetric
|
25
|
+
from .evaluator import EvalStatus
|
26
|
+
from .evaluator import EvaluationResult
|
27
|
+
from .evaluator import Evaluator
|
28
|
+
from .evaluator import PerInvocationResult
|
29
|
+
|
30
|
+
|
31
|
+
class RougeEvaluator(Evaluator):
|
32
|
+
"""Calculates the ROUGE-1 metric to compare responses."""
|
33
|
+
|
34
|
+
def __init__(self, eval_metric: EvalMetric):
|
35
|
+
self._eval_metric = eval_metric
|
36
|
+
|
37
|
+
@override
|
38
|
+
def evaluate_invocations(
|
39
|
+
self,
|
40
|
+
actual_invocations: list[Invocation],
|
41
|
+
expected_invocations: list[Invocation],
|
42
|
+
) -> EvaluationResult:
|
43
|
+
total_score = 0.0
|
44
|
+
num_invocations = 0
|
45
|
+
per_invocation_results = []
|
46
|
+
for actual, expected in zip(actual_invocations, expected_invocations):
|
47
|
+
reference = _get_text_from_content(expected.final_response)
|
48
|
+
response = _get_text_from_content(actual.final_response)
|
49
|
+
rouge_1_scores = _calculate_rouge_1_scores(response, reference)
|
50
|
+
score = rouge_1_scores.fmeasure
|
51
|
+
per_invocation_results.append(
|
52
|
+
PerInvocationResult(
|
53
|
+
actual_invocation=actual,
|
54
|
+
expected_invocation=expected,
|
55
|
+
score=score,
|
56
|
+
eval_status=_get_eval_status(score, self._eval_metric.threshold),
|
57
|
+
)
|
58
|
+
)
|
59
|
+
total_score += score
|
60
|
+
num_invocations += 1
|
61
|
+
|
62
|
+
if per_invocation_results:
|
63
|
+
overall_score = total_score / num_invocations
|
64
|
+
return EvaluationResult(
|
65
|
+
overall_score=overall_score,
|
66
|
+
overall_eval_status=_get_eval_status(
|
67
|
+
overall_score, self._eval_metric.threshold
|
68
|
+
),
|
69
|
+
per_invocation_results=per_invocation_results,
|
70
|
+
)
|
71
|
+
|
72
|
+
return EvaluationResult()
|
73
|
+
|
74
|
+
|
75
|
+
def _get_text_from_content(content: Optional[genai_types.Content]) -> str:
|
76
|
+
if content and content.parts:
|
77
|
+
return "\n".join([part.text for part in content.parts if part.text])
|
78
|
+
|
79
|
+
return ""
|
80
|
+
|
81
|
+
|
82
|
+
def _get_eval_status(score: float, threshold: float):
|
83
|
+
return EvalStatus.PASSED if score >= threshold else EvalStatus.FAILED
|
84
|
+
|
85
|
+
|
86
|
+
def _calculate_rouge_1_scores(candidate: str, reference: str):
|
87
|
+
"""Calculates the ROUGE-1 score between a candidate and reference text.
|
88
|
+
|
89
|
+
ROUGE-1 measures the overlap of unigrams (single words) between the
|
90
|
+
candidate and reference texts. The score is broken down into:
|
91
|
+
- Precision: The proportion of unigrams in the candidate that are also in the
|
92
|
+
reference.
|
93
|
+
- Recall: The proportion of unigrams in the reference that are also in the
|
94
|
+
candidate.
|
95
|
+
- F-measure: The harmonic mean of precision and recall.
|
96
|
+
|
97
|
+
Args:
|
98
|
+
candidate: The generated text to be evaluated.
|
99
|
+
reference: The ground-truth text to compare against.
|
100
|
+
|
101
|
+
Returns:
|
102
|
+
A dictionary containing the ROUGE-1 precision, recall, and f-measure.
|
103
|
+
"""
|
104
|
+
scorer = rouge_scorer.RougeScorer(["rouge1"], use_stemmer=True)
|
105
|
+
|
106
|
+
# The score method returns a dictionary where keys are the ROUGE types
|
107
|
+
# and values are Score objects (tuples) with precision, recall, and fmeasure.
|
108
|
+
scores = scorer.score(reference, candidate)
|
109
|
+
|
110
|
+
return scores["rouge1"]
|
google/adk/evaluation/gcs_eval_sets_manager.py
CHANGED
@@ -72,6 +72,13 @@ class GcsEvalSetsManager(EvalSetsManager):
|
|
72
72
|
f"Invalid {id_name}. {id_name} should have the `{pattern}` format",
|
73
73
|
)
|
74
74
|
|
75
|
+
def _load_eval_set_from_blob(self, blob_name: str) -> Optional[EvalSet]:
|
76
|
+
blob = self.bucket.blob(blob_name)
|
77
|
+
if not blob.exists():
|
78
|
+
return None
|
79
|
+
eval_set_data = blob.download_as_text()
|
80
|
+
return EvalSet.model_validate_json(eval_set_data)
|
81
|
+
|
75
82
|
def _write_eval_set_to_blob(self, blob_name: str, eval_set: EvalSet):
|
76
83
|
"""Writes an EvalSet to GCS."""
|
77
84
|
blob = self.bucket.blob(blob_name)
|
@@ -88,11 +95,7 @@ class GcsEvalSetsManager(EvalSetsManager):
|
|
88
95
|
def get_eval_set(self, app_name: str, eval_set_id: str) -> Optional[EvalSet]:
|
89
96
|
"""Returns an EvalSet identified by an app_name and eval_set_id."""
|
90
97
|
eval_set_blob_name = self._get_eval_set_blob_name(app_name, eval_set_id)
|
91
|
-
|
92
|
-
if not blob.exists():
|
93
|
-
return None
|
94
|
-
eval_set_data = blob.download_as_text()
|
95
|
-
return EvalSet.model_validate_json(eval_set_data)
|
98
|
+
return self._load_eval_set_from_blob(eval_set_blob_name)
|
96
99
|
|
97
100
|
@override
|
98
101
|
def create_eval_set(self, app_name: str, eval_set_id: str):
|
google/adk/evaluation/response_evaluator.py
CHANGED
@@ -27,10 +27,12 @@ from vertexai.preview.evaluation import MetricPromptTemplateExamples
|
|
27
27
|
|
28
28
|
from .eval_case import IntermediateData
|
29
29
|
from .eval_case import Invocation
|
30
|
+
from .eval_metrics import EvalMetric
|
30
31
|
from .evaluator import EvalStatus
|
31
32
|
from .evaluator import EvaluationResult
|
32
33
|
from .evaluator import Evaluator
|
33
34
|
from .evaluator import PerInvocationResult
|
35
|
+
from .final_response_match_v1 import RougeEvaluator
|
34
36
|
|
35
37
|
|
36
38
|
class ResponseEvaluator(Evaluator):
|
@@ -40,7 +42,7 @@ class ResponseEvaluator(Evaluator):
|
|
40
42
|
if "response_evaluation_score" == metric_name:
|
41
43
|
self._metric_name = MetricPromptTemplateExamples.Pointwise.COHERENCE
|
42
44
|
elif "response_match_score" == metric_name:
|
43
|
-
self._metric_name = "
|
45
|
+
self._metric_name = "response_match_score"
|
44
46
|
else:
|
45
47
|
raise ValueError(f"`{metric_name}` is not supported.")
|
46
48
|
|
@@ -52,6 +54,15 @@ class ResponseEvaluator(Evaluator):
|
|
52
54
|
actual_invocations: list[Invocation],
|
53
55
|
expected_invocations: list[Invocation],
|
54
56
|
) -> EvaluationResult:
|
57
|
+
# If the metric is response_match_score, just use the RougeEvaluator.
|
58
|
+
if self._metric_name == "response_match_score":
|
59
|
+
rouge_evaluator = RougeEvaluator(
|
60
|
+
EvalMetric(metric_name=self._metric_name, threshold=self._threshold)
|
61
|
+
)
|
62
|
+
return rouge_evaluator.evaluate_invocations(
|
63
|
+
actual_invocations, expected_invocations
|
64
|
+
)
|
65
|
+
|
55
66
|
total_score = 0.0
|
56
67
|
num_invocations = 0
|
57
68
|
per_invocation_results = []
|
google/adk/events/event.py
CHANGED
@@ -34,9 +34,10 @@ class Event(LlmResponse):
|
|
34
34
|
taken by the agents like function calls, etc.
|
35
35
|
|
36
36
|
Attributes:
|
37
|
-
invocation_id: The invocation ID of the event.
|
38
|
-
|
39
|
-
|
37
|
+
invocation_id: Required. The invocation ID of the event. Should be non-empty
|
38
|
+
before appending to a session.
|
39
|
+
author: Required. "user" or the name of the agent, indicating who appended
|
40
|
+
the event to the session.
|
40
41
|
actions: The actions taken by the agent.
|
41
42
|
long_running_tool_ids: The ids of the long running function calls.
|
42
43
|
branch: The branch of the event.
|
@@ -55,9 +56,8 @@ class Event(LlmResponse):
|
|
55
56
|
)
|
56
57
|
"""The pydantic model config."""
|
57
58
|
|
58
|
-
# TODO: revert to be required after spark migration
|
59
59
|
invocation_id: str = ''
|
60
|
-
"""The invocation ID of the event."""
|
60
|
+
"""The invocation ID of the event. Should be non-empty before appending to a session."""
|
61
61
|
author: str
|
62
62
|
"""'user' or the name of the agent, indicating who appended the event to the
|
63
63
|
session."""
|
google/adk/flows/llm_flows/contents.py
CHANGED
@@ -43,12 +43,20 @@ class _ContentLlmRequestProcessor(BaseLlmRequestProcessor):
|
|
43
43
|
if not isinstance(agent, LlmAgent):
|
44
44
|
return
|
45
45
|
|
46
|
-
if agent.include_contents
|
46
|
+
if agent.include_contents == 'default':
|
47
|
+
# Include full conversation history
|
47
48
|
llm_request.contents = _get_contents(
|
48
49
|
invocation_context.branch,
|
49
50
|
invocation_context.session.events,
|
50
51
|
agent.name,
|
51
52
|
)
|
53
|
+
else:
|
54
|
+
# Include current turn context only (no conversation history)
|
55
|
+
llm_request.contents = _get_current_turn_contents(
|
56
|
+
invocation_context.branch,
|
57
|
+
invocation_context.session.events,
|
58
|
+
agent.name,
|
59
|
+
)
|
52
60
|
|
53
61
|
# Maintain async generator behavior
|
54
62
|
if False: # Ensures it behaves as a generator
|
@@ -190,13 +198,15 @@ def _get_contents(
|
|
190
198
|
) -> list[types.Content]:
|
191
199
|
"""Get the contents for the LLM request.
|
192
200
|
|
201
|
+
Applies filtering, rearrangement, and content processing to events.
|
202
|
+
|
193
203
|
Args:
|
194
204
|
current_branch: The current branch of the agent.
|
195
|
-
events:
|
205
|
+
events: Events to process.
|
196
206
|
agent_name: The name of the agent.
|
197
207
|
|
198
208
|
Returns:
|
199
|
-
A list of contents.
|
209
|
+
A list of processed contents.
|
200
210
|
"""
|
201
211
|
filtered_events = []
|
202
212
|
# Parse the events, leaving the contents and the function calls and
|
@@ -211,12 +221,13 @@ def _get_contents(
|
|
211
221
|
# Skip events without content, or generated neither by user nor by model
|
212
222
|
# or has empty text.
|
213
223
|
# E.g. events purely for mutating session states.
|
224
|
+
|
214
225
|
continue
|
215
226
|
if not _is_event_belongs_to_branch(current_branch, event):
|
216
227
|
# Skip events not belong to current branch.
|
217
228
|
continue
|
218
229
|
if _is_auth_event(event):
|
219
|
-
#
|
230
|
+
# Skip auth events.
|
220
231
|
continue
|
221
232
|
filtered_events.append(
|
222
233
|
_convert_foreign_event(event)
|
@@ -224,12 +235,15 @@ def _get_contents(
|
|
224
235
|
else event
|
225
236
|
)
|
226
237
|
|
238
|
+
# Rearrange events for proper function call/response pairing
|
227
239
|
result_events = _rearrange_events_for_latest_function_response(
|
228
240
|
filtered_events
|
229
241
|
)
|
230
242
|
result_events = _rearrange_events_for_async_function_responses_in_history(
|
231
243
|
result_events
|
232
244
|
)
|
245
|
+
|
246
|
+
# Convert events to contents
|
233
247
|
contents = []
|
234
248
|
for event in result_events:
|
235
249
|
content = copy.deepcopy(event.content)
|
@@ -238,6 +252,37 @@ def _get_contents(
|
|
238
252
|
return contents
|
239
253
|
|
240
254
|
|
255
|
+
def _get_current_turn_contents(
|
256
|
+
current_branch: Optional[str], events: list[Event], agent_name: str = ''
|
257
|
+
) -> list[types.Content]:
|
258
|
+
"""Get contents for the current turn only (no conversation history).
|
259
|
+
|
260
|
+
When include_contents='none', we want to include:
|
261
|
+
- The current user input
|
262
|
+
- Tool calls and responses from the current turn
|
263
|
+
But exclude conversation history from previous turns.
|
264
|
+
|
265
|
+
In multi-agent scenarios, the "current turn" for an agent starts from an
|
266
|
+
actual user or from another agent.
|
267
|
+
|
268
|
+
Args:
|
269
|
+
current_branch: The current branch of the agent.
|
270
|
+
events: A list of all session events.
|
271
|
+
agent_name: The name of the agent.
|
272
|
+
|
273
|
+
Returns:
|
274
|
+
A list of contents for the current turn only, preserving context needed
|
275
|
+
for proper tool execution while excluding conversation history.
|
276
|
+
"""
|
277
|
+
# Find the latest event that starts the current turn and process from there
|
278
|
+
for i in range(len(events) - 1, -1, -1):
|
279
|
+
event = events[i]
|
280
|
+
if event.author == 'user' or _is_other_agent_reply(agent_name, event):
|
281
|
+
return _get_contents(current_branch, events[i:], agent_name)
|
282
|
+
|
283
|
+
return []
|
284
|
+
|
285
|
+
|
241
286
|
def _is_other_agent_reply(current_agent_name: str, event: Event) -> bool:
|
242
287
|
"""Whether the event is a reply from another agent."""
|
243
288
|
return bool(
|
google/adk/flows/llm_flows/functions.py
CHANGED
@@ -519,3 +519,35 @@ def merge_parallel_function_response_events(
|
|
519
519
|
# Use the base_event as the timestamp
|
520
520
|
merged_event.timestamp = base_event.timestamp
|
521
521
|
return merged_event
|
522
|
+
|
523
|
+
|
524
|
+
def find_matching_function_call(
|
525
|
+
events: list[Event],
|
526
|
+
) -> Optional[Event]:
|
527
|
+
"""Finds the function call event that matches the function response id of the last event."""
|
528
|
+
if not events:
|
529
|
+
return None
|
530
|
+
|
531
|
+
last_event = events[-1]
|
532
|
+
if (
|
533
|
+
last_event.content
|
534
|
+
and last_event.content.parts
|
535
|
+
and any(part.function_response for part in last_event.content.parts)
|
536
|
+
):
|
537
|
+
|
538
|
+
function_call_id = next(
|
539
|
+
part.function_response.id
|
540
|
+
for part in last_event.content.parts
|
541
|
+
if part.function_response
|
542
|
+
)
|
543
|
+
for i in range(len(events) - 2, -1, -1):
|
544
|
+
event = events[i]
|
545
|
+
# looking for the system long running request euc function call
|
546
|
+
function_calls = event.get_function_calls()
|
547
|
+
if not function_calls:
|
548
|
+
continue
|
549
|
+
|
550
|
+
for function_call in function_calls:
|
551
|
+
if function_call.id == function_call_id:
|
552
|
+
return event
|
553
|
+
return None
|
google/adk/memory/__init__.py
CHANGED
@@ -15,12 +15,14 @@ import logging
|
|
15
15
|
|
16
16
|
from .base_memory_service import BaseMemoryService
|
17
17
|
from .in_memory_memory_service import InMemoryMemoryService
|
18
|
+
from .vertex_ai_memory_bank_service import VertexAiMemoryBankService
|
18
19
|
|
19
20
|
logger = logging.getLogger('google_adk.' + __name__)
|
20
21
|
|
21
22
|
__all__ = [
|
22
23
|
'BaseMemoryService',
|
23
24
|
'InMemoryMemoryService',
|
25
|
+
'VertexAiMemoryBankService',
|
24
26
|
]
|
25
27
|
|
26
28
|
try:
|
@@ -29,7 +31,7 @@ try:
|
|
29
31
|
__all__.append('VertexAiRagMemoryService')
|
30
32
|
except ImportError:
|
31
33
|
logger.debug(
|
32
|
-
'The Vertex
|
34
|
+
'The Vertex SDK is not installed. If you want to use the'
|
33
35
|
' VertexAiRagMemoryService please install it. If not, you can ignore this'
|
34
36
|
' warning.'
|
35
37
|
)
|
google/adk/memory/vertex_ai_memory_bank_service.py
ADDED
@@ -0,0 +1,150 @@
|
|
1
|
+
# Copyright 2025 Google LLC
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from __future__ import annotations
|
16
|
+
|
17
|
+
import json
|
18
|
+
import logging
|
19
|
+
from typing import Optional
|
20
|
+
from typing import TYPE_CHECKING
|
21
|
+
|
22
|
+
from typing_extensions import override
|
23
|
+
|
24
|
+
from google import genai
|
25
|
+
|
26
|
+
from .base_memory_service import BaseMemoryService
|
27
|
+
from .base_memory_service import SearchMemoryResponse
|
28
|
+
from .memory_entry import MemoryEntry
|
29
|
+
|
30
|
+
if TYPE_CHECKING:
|
31
|
+
from ..sessions.session import Session
|
32
|
+
|
33
|
+
logger = logging.getLogger('google_adk.' + __name__)
|
34
|
+
|
35
|
+
|
36
|
+
class VertexAiMemoryBankService(BaseMemoryService):
|
37
|
+
"""Implementation of the BaseMemoryService using Vertex AI Memory Bank."""
|
38
|
+
|
39
|
+
def __init__(
|
40
|
+
self,
|
41
|
+
project: Optional[str] = None,
|
42
|
+
location: Optional[str] = None,
|
43
|
+
agent_engine_id: Optional[str] = None,
|
44
|
+
):
|
45
|
+
"""Initializes a VertexAiMemoryBankService.
|
46
|
+
|
47
|
+
Args:
|
48
|
+
project: The project ID of the Memory Bank to use.
|
49
|
+
location: The location of the Memory Bank to use.
|
50
|
+
agent_engine_id: The ID of the agent engine to use for the Memory Bank.
|
51
|
+
e.g. '456' in
|
52
|
+
'projects/my-project/locations/us-central1/reasoningEngines/456'.
|
53
|
+
"""
|
54
|
+
self._project = project
|
55
|
+
self._location = location
|
56
|
+
self._agent_engine_id = agent_engine_id
|
57
|
+
|
58
|
+
@override
|
59
|
+
async def add_session_to_memory(self, session: Session):
|
60
|
+
api_client = self._get_api_client()
|
61
|
+
|
62
|
+
if not self._agent_engine_id:
|
63
|
+
raise ValueError('Agent Engine ID is required for Memory Bank.')
|
64
|
+
|
65
|
+
events = []
|
66
|
+
for event in session.events:
|
67
|
+
if event.content and event.content.parts:
|
68
|
+
events.append({
|
69
|
+
'content': event.content.model_dump(exclude_none=True, mode='json')
|
70
|
+
})
|
71
|
+
request_dict = {
|
72
|
+
'direct_contents_source': {
|
73
|
+
'events': events,
|
74
|
+
},
|
75
|
+
'scope': {
|
76
|
+
'app_name': session.app_name,
|
77
|
+
'user_id': session.user_id,
|
78
|
+
},
|
79
|
+
}
|
80
|
+
|
81
|
+
if events:
|
82
|
+
api_response = await api_client.async_request(
|
83
|
+
http_method='POST',
|
84
|
+
path=f'reasoningEngines/{self._agent_engine_id}/memories:generate',
|
85
|
+
request_dict=request_dict,
|
86
|
+
)
|
87
|
+
logger.info(f'Generate memory response: {api_response}')
|
88
|
+
else:
|
89
|
+
logger.info('No events to add to memory.')
|
90
|
+
|
91
|
+
@override
|
92
|
+
async def search_memory(self, *, app_name: str, user_id: str, query: str):
|
93
|
+
api_client = self._get_api_client()
|
94
|
+
|
95
|
+
api_response = await api_client.async_request(
|
96
|
+
http_method='POST',
|
97
|
+
path=f'reasoningEngines/{self._agent_engine_id}/memories:retrieve',
|
98
|
+
request_dict={
|
99
|
+
'scope': {
|
100
|
+
'app_name': app_name,
|
101
|
+
'user_id': user_id,
|
102
|
+
},
|
103
|
+
'similarity_search_params': {
|
104
|
+
'search_query': query,
|
105
|
+
},
|
106
|
+
},
|
107
|
+
)
|
108
|
+
api_response = _convert_api_response(api_response)
|
109
|
+
logger.info(f'Search memory response: {api_response}')
|
110
|
+
|
111
|
+
if not api_response or not api_response.get('retrievedMemories', None):
|
112
|
+
return SearchMemoryResponse()
|
113
|
+
|
114
|
+
memory_events = []
|
115
|
+
for memory in api_response.get('retrievedMemories', []):
|
116
|
+
# TODO: add more complex error handling
|
117
|
+
memory_events.append(
|
118
|
+
MemoryEntry(
|
119
|
+
author='user',
|
120
|
+
content=genai.types.Content(
|
121
|
+
parts=[
|
122
|
+
genai.types.Part(text=memory.get('memory').get('fact'))
|
123
|
+
],
|
124
|
+
role='user',
|
125
|
+
),
|
126
|
+
timestamp=memory.get('updateTime'),
|
127
|
+
)
|
128
|
+
)
|
129
|
+
return SearchMemoryResponse(memories=memory_events)
|
130
|
+
|
131
|
+
def _get_api_client(self):
|
132
|
+
"""Instantiates an API client for the given project and location.
|
133
|
+
|
134
|
+
It needs to be instantiated inside each request so that the event loop
|
135
|
+
management can be properly propagated.
|
136
|
+
|
137
|
+
Returns:
|
138
|
+
An API client for the given project and location.
|
139
|
+
"""
|
140
|
+
client = genai.Client(
|
141
|
+
vertexai=True, project=self._project, location=self._location
|
142
|
+
)
|
143
|
+
return client._api_client
|
144
|
+
|
145
|
+
|
146
|
+
def _convert_api_response(api_response):
|
147
|
+
"""Converts the API response to a JSON object based on the type."""
|
148
|
+
if hasattr(api_response, 'body'):
|
149
|
+
return json.loads(api_response.body)
|
150
|
+
return api_response
|
google/adk/models/lite_llm.py
CHANGED
@@ -29,6 +29,7 @@ from typing import Tuple
|
|
29
29
|
from typing import Union
|
30
30
|
|
31
31
|
from google.genai import types
|
32
|
+
import litellm
|
32
33
|
from litellm import acompletion
|
33
34
|
from litellm import ChatCompletionAssistantMessage
|
34
35
|
from litellm import ChatCompletionAssistantToolCall
|
@@ -53,6 +54,9 @@ from .base_llm import BaseLlm
|
|
53
54
|
from .llm_request import LlmRequest
|
54
55
|
from .llm_response import LlmResponse
|
55
56
|
|
57
|
+
# This will add functions to prompts if functions are provided.
|
58
|
+
litellm.add_function_to_prompt = True
|
59
|
+
|
56
60
|
logger = logging.getLogger("google_adk." + __name__)
|
57
61
|
|
58
62
|
_NEW_LINE = "\n"
|
@@ -662,6 +666,10 @@ class LiteLlm(BaseLlm):
|
|
662
666
|
|
663
667
|
messages, tools, response_format = _get_completion_inputs(llm_request)
|
664
668
|
|
669
|
+
if "functions" in self._additional_args:
|
670
|
+
# LiteLLM does not support both tools and functions together.
|
671
|
+
tools = None
|
672
|
+
|
665
673
|
completion_args = {
|
666
674
|
"model": self.model,
|
667
675
|
"messages": messages,
|
@@ -679,7 +687,7 @@ class LiteLlm(BaseLlm):
|
|
679
687
|
aggregated_llm_response_with_tool_call = None
|
680
688
|
usage_metadata = None
|
681
689
|
fallback_index = 0
|
682
|
-
for part in self.llm_client.
|
690
|
+
async for part in await self.llm_client.acompletion(**completion_args):
|
683
691
|
for chunk, finish_reason in _model_response_to_chunk(part):
|
684
692
|
if isinstance(chunk, FunctionChunk):
|
685
693
|
index = chunk.index or fallback_index
|