google-adk 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/adk/a2a/converters/event_converter.py +382 -0
- google/adk/a2a/converters/part_converter.py +4 -2
- google/adk/a2a/converters/request_converter.py +90 -0
- google/adk/a2a/converters/utils.py +71 -0
- google/adk/agents/llm_agent.py +5 -3
- google/adk/artifacts/gcs_artifact_service.py +3 -2
- google/adk/auth/auth_tool.py +2 -2
- google/adk/auth/credential_service/session_state_credential_service.py +83 -0
- google/adk/cli/cli_deploy.py +9 -2
- google/adk/cli/cli_tools_click.py +110 -52
- google/adk/cli/fast_api.py +26 -2
- google/adk/cli/utils/evals.py +53 -0
- google/adk/evaluation/final_response_match_v1.py +110 -0
- google/adk/evaluation/gcs_eval_sets_manager.py +8 -5
- google/adk/evaluation/response_evaluator.py +12 -1
- google/adk/events/event.py +5 -5
- google/adk/flows/llm_flows/contents.py +49 -4
- google/adk/flows/llm_flows/functions.py +32 -0
- google/adk/memory/__init__.py +3 -1
- google/adk/memory/vertex_ai_memory_bank_service.py +150 -0
- google/adk/models/lite_llm.py +9 -1
- google/adk/runners.py +10 -0
- google/adk/sessions/vertex_ai_session_service.py +70 -19
- google/adk/telemetry.py +10 -0
- google/adk/tools/bigquery/bigquery_credentials.py +28 -11
- google/adk/tools/bigquery/bigquery_tool.py +1 -1
- google/adk/tools/bigquery/client.py +1 -1
- google/adk/tools/bigquery/metadata_tool.py +1 -1
- google/adk/tools/bigquery/query_tool.py +1 -1
- google/adk/version.py +1 -1
- {google_adk-1.4.1.dist-info → google_adk-1.5.0.dist-info}/METADATA +6 -5
- {google_adk-1.4.1.dist-info → google_adk-1.5.0.dist-info}/RECORD +35 -29
- {google_adk-1.4.1.dist-info → google_adk-1.5.0.dist-info}/WHEEL +0 -0
- {google_adk-1.4.1.dist-info → google_adk-1.5.0.dist-info}/entry_points.txt +0 -0
- {google_adk-1.4.1.dist-info → google_adk-1.5.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,110 @@
|
|
1
|
+
# Copyright 2025 Google LLC
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from __future__ import annotations
|
16
|
+
|
17
|
+
from typing import Optional
|
18
|
+
|
19
|
+
from google.genai import types as genai_types
|
20
|
+
from rouge_score import rouge_scorer
|
21
|
+
from typing_extensions import override
|
22
|
+
|
23
|
+
from .eval_case import Invocation
|
24
|
+
from .eval_metrics import EvalMetric
|
25
|
+
from .evaluator import EvalStatus
|
26
|
+
from .evaluator import EvaluationResult
|
27
|
+
from .evaluator import Evaluator
|
28
|
+
from .evaluator import PerInvocationResult
|
29
|
+
|
30
|
+
|
31
|
+
class RougeEvaluator(Evaluator):
  """Evaluator that compares final responses using the ROUGE-1 F-measure."""

  def __init__(self, eval_metric: EvalMetric):
    """Stores the metric configuration used to decide pass/fail.

    Args:
      eval_metric: Metric configuration; this evaluator reads its
        ``threshold``.
    """
    self._eval_metric = eval_metric

  @override
  def evaluate_invocations(
      self,
      actual_invocations: list[Invocation],
      expected_invocations: list[Invocation],
  ) -> EvaluationResult:
    """Scores each actual/expected invocation pair and aggregates the scores.

    Pairs are formed positionally with ``zip``, so surplus invocations in the
    longer list are ignored. Each pair is scored by the ROUGE-1 F-measure of
    the actual final response against the expected one.

    Args:
      actual_invocations: Invocations produced by the agent under evaluation.
      expected_invocations: Reference invocations to compare against.

    Returns:
      An EvaluationResult whose overall score is the mean per-pair F-measure,
      or an empty EvaluationResult when no pairs were scored.
    """
    pair_scores: list[float] = []
    results: list[PerInvocationResult] = []

    for actual, expected in zip(actual_invocations, expected_invocations):
      reference_text = _get_text_from_content(expected.final_response)
      candidate_text = _get_text_from_content(actual.final_response)
      f_measure = _calculate_rouge_1_scores(
          candidate_text, reference_text
      ).fmeasure
      pair_scores.append(f_measure)
      results.append(
          PerInvocationResult(
              actual_invocation=actual,
              expected_invocation=expected,
              score=f_measure,
              eval_status=_get_eval_status(
                  f_measure, self._eval_metric.threshold
              ),
          )
      )

    if not results:
      return EvaluationResult()

    # Sequential left-to-right sum starting from 0.0, then the mean.
    mean_score = sum(pair_scores, 0.0) / len(pair_scores)
    return EvaluationResult(
        overall_score=mean_score,
        overall_eval_status=_get_eval_status(
            mean_score, self._eval_metric.threshold
        ),
        per_invocation_results=results,
    )
|
73
|
+
|
74
|
+
|
75
|
+
def _get_text_from_content(content: Optional[genai_types.Content]) -> str:
|
76
|
+
if content and content.parts:
|
77
|
+
return "\n".join([part.text for part in content.parts if part.text])
|
78
|
+
|
79
|
+
return ""
|
80
|
+
|
81
|
+
|
82
|
+
def _get_eval_status(score: float, threshold: float):
  """Maps a score to an EvalStatus.

  Args:
    score: The score to classify.
    threshold: Minimum score (inclusive) required to pass.

  Returns:
    EvalStatus.PASSED when ``score >= threshold``, otherwise
    EvalStatus.FAILED (including when the comparison is False, e.g. NaN).
  """
  if score >= threshold:
    return EvalStatus.PASSED
  return EvalStatus.FAILED
|
84
|
+
|
85
|
+
|
86
|
+
def _calculate_rouge_1_scores(candidate: str, reference: str):
  """Calculates the ROUGE-1 score between a candidate and reference text.

  ROUGE-1 measures the overlap of unigrams (single words) between the
  candidate and reference texts. The score is broken down into:
  - Precision: The proportion of unigrams in the candidate that are also in the
    reference.
  - Recall: The proportion of unigrams in the reference that are also in the
    candidate.
  - F-measure: The harmonic mean of precision and recall.

  Words are stemmed before matching (the scorer is built with
  ``use_stemmer=True``).

  Args:
    candidate: The generated text to be evaluated.
    reference: The ground-truth text to compare against.

  Returns:
    The ``rouge1`` Score tuple, exposing ``precision``, ``recall`` and
    ``fmeasure`` attributes. This is a single Score object, not a dictionary.
  """
  scorer = rouge_scorer.RougeScorer(["rouge1"], use_stemmer=True)

  # The score method returns a dictionary where keys are the ROUGE types
  # and values are Score objects (tuples) with precision, recall, and fmeasure.
  # The reference is passed first (the scorer's target argument), the
  # candidate second (the prediction).
  scores = scorer.score(reference, candidate)

  return scores["rouge1"]
|
@@ -72,6 +72,13 @@ class GcsEvalSetsManager(EvalSetsManager):
|
|
72
72
|
f"Invalid {id_name}. {id_name} should have the `{pattern}` format",
|
73
73
|
)
|
74
74
|
|
75
|
+
def _load_eval_set_from_blob(self, blob_name: str) -> Optional[EvalSet]:
|
76
|
+
blob = self.bucket.blob(blob_name)
|
77
|
+
if not blob.exists():
|
78
|
+
return None
|
79
|
+
eval_set_data = blob.download_as_text()
|
80
|
+
return EvalSet.model_validate_json(eval_set_data)
|
81
|
+
|
75
82
|
def _write_eval_set_to_blob(self, blob_name: str, eval_set: EvalSet):
|
76
83
|
"""Writes an EvalSet to GCS."""
|
77
84
|
blob = self.bucket.blob(blob_name)
|
@@ -88,11 +95,7 @@ class GcsEvalSetsManager(EvalSetsManager):
|
|
88
95
|
def get_eval_set(self, app_name: str, eval_set_id: str) -> Optional[EvalSet]:
|
89
96
|
"""Returns an EvalSet identified by an app_name and eval_set_id."""
|
90
97
|
eval_set_blob_name = self._get_eval_set_blob_name(app_name, eval_set_id)
|
91
|
-
|
92
|
-
if not blob.exists():
|
93
|
-
return None
|
94
|
-
eval_set_data = blob.download_as_text()
|
95
|
-
return EvalSet.model_validate_json(eval_set_data)
|
98
|
+
return self._load_eval_set_from_blob(eval_set_blob_name)
|
96
99
|
|
97
100
|
@override
|
98
101
|
def create_eval_set(self, app_name: str, eval_set_id: str):
|
@@ -27,10 +27,12 @@ from vertexai.preview.evaluation import MetricPromptTemplateExamples
|
|
27
27
|
|
28
28
|
from .eval_case import IntermediateData
|
29
29
|
from .eval_case import Invocation
|
30
|
+
from .eval_metrics import EvalMetric
|
30
31
|
from .evaluator import EvalStatus
|
31
32
|
from .evaluator import EvaluationResult
|
32
33
|
from .evaluator import Evaluator
|
33
34
|
from .evaluator import PerInvocationResult
|
35
|
+
from .final_response_match_v1 import RougeEvaluator
|
34
36
|
|
35
37
|
|
36
38
|
class ResponseEvaluator(Evaluator):
|
@@ -40,7 +42,7 @@ class ResponseEvaluator(Evaluator):
|
|
40
42
|
if "response_evaluation_score" == metric_name:
|
41
43
|
self._metric_name = MetricPromptTemplateExamples.Pointwise.COHERENCE
|
42
44
|
elif "response_match_score" == metric_name:
|
43
|
-
self._metric_name = "
|
45
|
+
self._metric_name = "response_match_score"
|
44
46
|
else:
|
45
47
|
raise ValueError(f"`{metric_name}` is not supported.")
|
46
48
|
|
@@ -52,6 +54,15 @@ class ResponseEvaluator(Evaluator):
|
|
52
54
|
actual_invocations: list[Invocation],
|
53
55
|
expected_invocations: list[Invocation],
|
54
56
|
) -> EvaluationResult:
|
57
|
+
# If the metric is response_match_score, just use the RougeEvaluator.
|
58
|
+
if self._metric_name == "response_match_score":
|
59
|
+
rouge_evaluator = RougeEvaluator(
|
60
|
+
EvalMetric(metric_name=self._metric_name, threshold=self._threshold)
|
61
|
+
)
|
62
|
+
return rouge_evaluator.evaluate_invocations(
|
63
|
+
actual_invocations, expected_invocations
|
64
|
+
)
|
65
|
+
|
55
66
|
total_score = 0.0
|
56
67
|
num_invocations = 0
|
57
68
|
per_invocation_results = []
|
google/adk/events/event.py
CHANGED
@@ -34,9 +34,10 @@ class Event(LlmResponse):
|
|
34
34
|
taken by the agents like function calls, etc.
|
35
35
|
|
36
36
|
Attributes:
|
37
|
-
invocation_id: The invocation ID of the event.
|
38
|
-
|
39
|
-
|
37
|
+
invocation_id: Required. The invocation ID of the event. Should be non-empty
|
38
|
+
before appending to a session.
|
39
|
+
author: Required. "user" or the name of the agent, indicating who appended
|
40
|
+
the event to the session.
|
40
41
|
actions: The actions taken by the agent.
|
41
42
|
long_running_tool_ids: The ids of the long running function calls.
|
42
43
|
branch: The branch of the event.
|
@@ -55,9 +56,8 @@ class Event(LlmResponse):
|
|
55
56
|
)
|
56
57
|
"""The pydantic model config."""
|
57
58
|
|
58
|
-
# TODO: revert to be required after spark migration
|
59
59
|
invocation_id: str = ''
|
60
|
-
"""The invocation ID of the event."""
|
60
|
+
"""The invocation ID of the event. Should be non-empty before appending to a session."""
|
61
61
|
author: str
|
62
62
|
"""'user' or the name of the agent, indicating who appended the event to the
|
63
63
|
session."""
|
@@ -43,12 +43,20 @@ class _ContentLlmRequestProcessor(BaseLlmRequestProcessor):
|
|
43
43
|
if not isinstance(agent, LlmAgent):
|
44
44
|
return
|
45
45
|
|
46
|
-
if agent.include_contents
|
46
|
+
if agent.include_contents == 'default':
|
47
|
+
# Include full conversation history
|
47
48
|
llm_request.contents = _get_contents(
|
48
49
|
invocation_context.branch,
|
49
50
|
invocation_context.session.events,
|
50
51
|
agent.name,
|
51
52
|
)
|
53
|
+
else:
|
54
|
+
# Include current turn context only (no conversation history)
|
55
|
+
llm_request.contents = _get_current_turn_contents(
|
56
|
+
invocation_context.branch,
|
57
|
+
invocation_context.session.events,
|
58
|
+
agent.name,
|
59
|
+
)
|
52
60
|
|
53
61
|
# Maintain async generator behavior
|
54
62
|
if False: # Ensures it behaves as a generator
|
@@ -190,13 +198,15 @@ def _get_contents(
|
|
190
198
|
) -> list[types.Content]:
|
191
199
|
"""Get the contents for the LLM request.
|
192
200
|
|
201
|
+
Applies filtering, rearrangement, and content processing to events.
|
202
|
+
|
193
203
|
Args:
|
194
204
|
current_branch: The current branch of the agent.
|
195
|
-
events:
|
205
|
+
events: Events to process.
|
196
206
|
agent_name: The name of the agent.
|
197
207
|
|
198
208
|
Returns:
|
199
|
-
A list of contents.
|
209
|
+
A list of processed contents.
|
200
210
|
"""
|
201
211
|
filtered_events = []
|
202
212
|
# Parse the events, leaving the contents and the function calls and
|
@@ -211,12 +221,13 @@ def _get_contents(
|
|
211
221
|
# Skip events without content, or generated neither by user nor by model
|
212
222
|
# or has empty text.
|
213
223
|
# E.g. events purely for mutating session states.
|
224
|
+
|
214
225
|
continue
|
215
226
|
if not _is_event_belongs_to_branch(current_branch, event):
|
216
227
|
# Skip events not belong to current branch.
|
217
228
|
continue
|
218
229
|
if _is_auth_event(event):
|
219
|
-
#
|
230
|
+
# Skip auth events.
|
220
231
|
continue
|
221
232
|
filtered_events.append(
|
222
233
|
_convert_foreign_event(event)
|
@@ -224,12 +235,15 @@ def _get_contents(
|
|
224
235
|
else event
|
225
236
|
)
|
226
237
|
|
238
|
+
# Rearrange events for proper function call/response pairing
|
227
239
|
result_events = _rearrange_events_for_latest_function_response(
|
228
240
|
filtered_events
|
229
241
|
)
|
230
242
|
result_events = _rearrange_events_for_async_function_responses_in_history(
|
231
243
|
result_events
|
232
244
|
)
|
245
|
+
|
246
|
+
# Convert events to contents
|
233
247
|
contents = []
|
234
248
|
for event in result_events:
|
235
249
|
content = copy.deepcopy(event.content)
|
@@ -238,6 +252,37 @@ def _get_contents(
|
|
238
252
|
return contents
|
239
253
|
|
240
254
|
|
255
|
+
def _get_current_turn_contents(
    current_branch: Optional[str], events: list[Event], agent_name: str = ''
) -> list[types.Content]:
  """Get contents for the current turn only (no conversation history).

  When include_contents='none', we want to include:
    - The current user input
    - Tool calls and responses from the current turn
  But exclude conversation history from previous turns.

  In multi-agent scenarios, the "current turn" for an agent starts from an
  actual user or from another agent. Only events that belong to
  ``current_branch`` may start the turn. ``_get_contents`` drops events from
  other branches anyway. If such an event were chosen as the turn start, the
  real turn input and this agent's earlier tool calls would be lost. The
  function response could then reach the model without its function call.

  Args:
    current_branch: The current branch of the agent.
    events: A list of all session events.
    agent_name: The name of the agent.

  Returns:
    A list of contents for the current turn only, preserving context needed
    for proper tool execution while excluding conversation history. Empty if
    no turn-starting event is found.
  """
  # Walk backwards to find the most recent event that starts the turn.
  for i in range(len(events) - 1, -1, -1):
    event = events[i]
    if not _is_event_belongs_to_branch(current_branch, event):
      # Events from other branches are invisible to this agent.
      continue
    if event.author == 'user' or _is_other_agent_reply(agent_name, event):
      return _get_contents(current_branch, events[i:], agent_name)

  return []
|
284
|
+
|
285
|
+
|
241
286
|
def _is_other_agent_reply(current_agent_name: str, event: Event) -> bool:
|
242
287
|
"""Whether the event is a reply from another agent."""
|
243
288
|
return bool(
|
@@ -519,3 +519,35 @@ def merge_parallel_function_response_events(
|
|
519
519
|
# Use the base_event as the timestamp
|
520
520
|
merged_event.timestamp = base_event.timestamp
|
521
521
|
return merged_event
|
522
|
+
|
523
|
+
|
524
|
+
def find_matching_function_call(
    events: list[Event],
) -> Optional[Event]:
  """Finds the event whose function call is answered by the last event.

  Takes the id of the first function response in the last event's content and
  searches the earlier events, newest first, for a function call with that id.

  Args:
    events: Session events; the last one is treated as the candidate function
      response.

  Returns:
    The most recent earlier event containing the matching function call, or
    None when ``events`` is empty, the last event has no function response,
    that response has no id, or no earlier event holds a matching call.
  """
  if not events:
    return None

  last_event = events[-1]
  if not (last_event.content and last_event.content.parts):
    return None

  function_call_id = next(
      (
          part.function_response.id
          for part in last_event.content.parts
          if part.function_response
      ),
      None,
  )
  if not function_call_id:
    # Either no function response, or one without an id. An id-less response
    # would otherwise match any id-less function call and be routed to an
    # arbitrary agent.
    return None

  # Search every earlier event, newest first, for the matching function call.
  for event in reversed(events[:-1]):
    for function_call in event.get_function_calls() or []:
      if function_call.id == function_call_id:
        return event
  return None
|
google/adk/memory/__init__.py
CHANGED
@@ -15,12 +15,14 @@ import logging
|
|
15
15
|
|
16
16
|
from .base_memory_service import BaseMemoryService
|
17
17
|
from .in_memory_memory_service import InMemoryMemoryService
|
18
|
+
from .vertex_ai_memory_bank_service import VertexAiMemoryBankService
|
18
19
|
|
19
20
|
logger = logging.getLogger('google_adk.' + __name__)
|
20
21
|
|
21
22
|
__all__ = [
|
22
23
|
'BaseMemoryService',
|
23
24
|
'InMemoryMemoryService',
|
25
|
+
'VertexAiMemoryBankService',
|
24
26
|
]
|
25
27
|
|
26
28
|
try:
|
@@ -29,7 +31,7 @@ try:
|
|
29
31
|
__all__.append('VertexAiRagMemoryService')
|
30
32
|
except ImportError:
|
31
33
|
logger.debug(
|
32
|
-
'The Vertex
|
34
|
+
'The Vertex SDK is not installed. If you want to use the'
|
33
35
|
' VertexAiRagMemoryService please install it. If not, you can ignore this'
|
34
36
|
' warning.'
|
35
37
|
)
|
@@ -0,0 +1,150 @@
|
|
1
|
+
# Copyright 2025 Google LLC
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from __future__ import annotations
|
16
|
+
|
17
|
+
import json
|
18
|
+
import logging
|
19
|
+
from typing import Optional
|
20
|
+
from typing import TYPE_CHECKING
|
21
|
+
|
22
|
+
from typing_extensions import override
|
23
|
+
|
24
|
+
from google import genai
|
25
|
+
|
26
|
+
from .base_memory_service import BaseMemoryService
|
27
|
+
from .base_memory_service import SearchMemoryResponse
|
28
|
+
from .memory_entry import MemoryEntry
|
29
|
+
|
30
|
+
if TYPE_CHECKING:
|
31
|
+
from ..sessions.session import Session
|
32
|
+
|
33
|
+
# Module logger, named under the 'google_adk.' prefix.
logger = logging.getLogger('google_adk.' + __name__)
|
34
|
+
|
35
|
+
|
36
|
+
class VertexAiMemoryBankService(BaseMemoryService):
  """Implementation of the BaseMemoryService using Vertex AI Memory Bank."""

  def __init__(
      self,
      project: Optional[str] = None,
      location: Optional[str] = None,
      agent_engine_id: Optional[str] = None,
  ):
    """Initializes a VertexAiMemoryBankService.

    Args:
      project: The project ID of the Memory Bank to use.
      location: The location of the Memory Bank to use.
      agent_engine_id: The ID of the agent engine to use for the Memory Bank.
        e.g. '456' in
        'projects/my-project/locations/us-central1/reasoningEngines/456'.
    """
    self._project = project
    self._location = location
    self._agent_engine_id = agent_engine_id

  def _require_agent_engine_id(self) -> None:
    """Raises ValueError if no agent engine ID was configured."""
    if not self._agent_engine_id:
      raise ValueError('Agent Engine ID is required for Memory Bank.')

  @override
  async def add_session_to_memory(self, session: Session):
    """Asks Memory Bank to generate memories from a session's events.

    Every event with non-empty content is sent, scoped to the session's
    ``app_name`` and ``user_id``. No request is made when no event has
    content.

    Args:
      session: The session whose events should be turned into memories.

    Raises:
      ValueError: If no agent engine ID was configured.
    """
    # Validate before creating a client so misconfiguration fails fast.
    self._require_agent_engine_id()

    events = [
        {'content': event.content.model_dump(exclude_none=True, mode='json')}
        for event in session.events
        if event.content and event.content.parts
    ]
    if not events:
      logger.info('No events to add to memory.')
      return

    api_client = self._get_api_client()
    api_response = await api_client.async_request(
        http_method='POST',
        path=f'reasoningEngines/{self._agent_engine_id}/memories:generate',
        request_dict={
            'direct_contents_source': {
                'events': events,
            },
            'scope': {
                'app_name': session.app_name,
                'user_id': session.user_id,
            },
        },
    )
    logger.info('Generate memory response: %s', api_response)

  @override
  async def search_memory(self, *, app_name: str, user_id: str, query: str):
    """Retrieves memories similar to ``query`` for the given app and user.

    Args:
      app_name: App whose memories are searched.
      user_id: User whose memories are searched.
      query: Text used for the similarity search.

    Returns:
      A SearchMemoryResponse with one MemoryEntry per retrieved memory that
      carries a fact; empty when nothing was retrieved.

    Raises:
      ValueError: If no agent engine ID was configured.
    """
    # Without this check the request would target 'reasoningEngines/None/...'.
    self._require_agent_engine_id()
    api_client = self._get_api_client()

    api_response = await api_client.async_request(
        http_method='POST',
        path=f'reasoningEngines/{self._agent_engine_id}/memories:retrieve',
        request_dict={
            'scope': {
                'app_name': app_name,
                'user_id': user_id,
            },
            'similarity_search_params': {
                'search_query': query,
            },
        },
    )
    api_response = _convert_api_response(api_response)
    logger.info('Search memory response: %s', api_response)

    if not api_response or not api_response.get('retrievedMemories', None):
      return SearchMemoryResponse()

    memory_events = []
    for memory in api_response.get('retrievedMemories', []):
      # Tolerate entries missing the 'memory' object instead of raising
      # AttributeError on None.
      fact = (memory.get('memory') or {}).get('fact')
      if not fact:
        logger.warning('Skipping retrieved memory without a fact: %s', memory)
        continue
      memory_events.append(
          MemoryEntry(
              author='user',
              content=genai.types.Content(
                  parts=[genai.types.Part(text=fact)],
                  role='user',
              ),
              timestamp=memory.get('updateTime'),
          )
      )
    return SearchMemoryResponse(memories=memory_events)

  def _get_api_client(self):
    """Instantiates an API client for the given project and location.

    It needs to be instantiated inside each request so that the event loop
    management can be properly propagated.

    Returns:
      An API client for the given project and location.
    """
    client = genai.Client(
        vertexai=True, project=self._project, location=self._location
    )
    return client._api_client
|
144
|
+
|
145
|
+
|
146
|
+
def _convert_api_response(api_response):
|
147
|
+
"""Converts the API response to a JSON object based on the type."""
|
148
|
+
if hasattr(api_response, 'body'):
|
149
|
+
return json.loads(api_response.body)
|
150
|
+
return api_response
|
google/adk/models/lite_llm.py
CHANGED
@@ -29,6 +29,7 @@ from typing import Tuple
|
|
29
29
|
from typing import Union
|
30
30
|
|
31
31
|
from google.genai import types
|
32
|
+
import litellm
|
32
33
|
from litellm import acompletion
|
33
34
|
from litellm import ChatCompletionAssistantMessage
|
34
35
|
from litellm import ChatCompletionAssistantToolCall
|
@@ -53,6 +54,9 @@ from .base_llm import BaseLlm
|
|
53
54
|
from .llm_request import LlmRequest
|
54
55
|
from .llm_response import LlmResponse
|
55
56
|
|
57
|
+
# This will add functions to prompts if functions are provided.
|
58
|
+
litellm.add_function_to_prompt = True
|
59
|
+
|
56
60
|
logger = logging.getLogger("google_adk." + __name__)
|
57
61
|
|
58
62
|
_NEW_LINE = "\n"
|
@@ -662,6 +666,10 @@ class LiteLlm(BaseLlm):
|
|
662
666
|
|
663
667
|
messages, tools, response_format = _get_completion_inputs(llm_request)
|
664
668
|
|
669
|
+
if "functions" in self._additional_args:
|
670
|
+
# LiteLLM does not support both tools and functions together.
|
671
|
+
tools = None
|
672
|
+
|
665
673
|
completion_args = {
|
666
674
|
"model": self.model,
|
667
675
|
"messages": messages,
|
@@ -679,7 +687,7 @@ class LiteLlm(BaseLlm):
|
|
679
687
|
aggregated_llm_response_with_tool_call = None
|
680
688
|
usage_metadata = None
|
681
689
|
fallback_index = 0
|
682
|
-
for part in self.llm_client.
|
690
|
+
async for part in await self.llm_client.acompletion(**completion_args):
|
683
691
|
for chunk, finish_reason in _model_response_to_chunk(part):
|
684
692
|
if isinstance(chunk, FunctionChunk):
|
685
693
|
index = chunk.index or fallback_index
|
google/adk/runners.py
CHANGED
@@ -36,6 +36,7 @@ from .artifacts.in_memory_artifact_service import InMemoryArtifactService
|
|
36
36
|
from .auth.credential_service.base_credential_service import BaseCredentialService
|
37
37
|
from .code_executors.built_in_code_executor import BuiltInCodeExecutor
|
38
38
|
from .events.event import Event
|
39
|
+
from .flows.llm_flows.functions import find_matching_function_call
|
39
40
|
from .memory.base_memory_service import BaseMemoryService
|
40
41
|
from .memory.in_memory_memory_service import InMemoryMemoryService
|
41
42
|
from .platform.thread import create_thread
|
@@ -337,6 +338,8 @@ class Runner:
|
|
337
338
|
"""Finds the agent to run to continue the session.
|
338
339
|
|
339
340
|
A qualified agent must be either of:
|
341
|
+
- The agent that returned a function call and the last user message is a
|
342
|
+
function response to this function call.
|
340
343
|
- The root agent;
|
341
344
|
- An LlmAgent who replied last and is capable to transfer to any other agent
|
342
345
|
in the agent hierarchy.
|
@@ -348,6 +351,13 @@ class Runner:
|
|
348
351
|
Returns:
|
349
352
|
The agent of the last message in the session or the root agent.
|
350
353
|
"""
|
354
|
+
# If the last event is a function response, should send this response to
|
355
|
+
# the agent that returned the corressponding function call regardless the
|
356
|
+
# type of the agent. e.g. a remote a2a agent may surface a credential
|
357
|
+
# request as a special long running function tool call.
|
358
|
+
event = find_matching_function_call(session.events)
|
359
|
+
if event and event.author:
|
360
|
+
return root_agent.find_agent(event.author)
|
351
361
|
for event in filter(lambda e: e.author != 'user', reversed(session.events)):
|
352
362
|
if event.author == root_agent.name:
|
353
363
|
# Found root agent.
|