google-adk 0.1.1__py3-none-any.whl → 0.3.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/adk/agents/base_agent.py +4 -4
- google/adk/agents/callback_context.py +0 -1
- google/adk/agents/invocation_context.py +1 -1
- google/adk/agents/remote_agent.py +1 -1
- google/adk/agents/run_config.py +1 -1
- google/adk/auth/auth_credential.py +2 -1
- google/adk/auth/auth_handler.py +7 -3
- google/adk/auth/auth_preprocessor.py +2 -2
- google/adk/auth/auth_tool.py +1 -1
- google/adk/cli/browser/index.html +2 -2
- google/adk/cli/browser/{main-SLIAU2JL.js → main-HWIBUY2R.js} +69 -69
- google/adk/cli/cli_create.py +279 -0
- google/adk/cli/cli_deploy.py +10 -1
- google/adk/cli/cli_eval.py +3 -3
- google/adk/cli/cli_tools_click.py +95 -19
- google/adk/cli/fast_api.py +57 -16
- google/adk/cli/utils/envs.py +0 -3
- google/adk/cli/utils/evals.py +2 -2
- google/adk/evaluation/agent_evaluator.py +2 -2
- google/adk/evaluation/evaluation_generator.py +4 -4
- google/adk/evaluation/response_evaluator.py +17 -5
- google/adk/evaluation/trajectory_evaluator.py +4 -5
- google/adk/events/event.py +3 -3
- google/adk/flows/llm_flows/_nl_planning.py +10 -4
- google/adk/flows/llm_flows/agent_transfer.py +1 -1
- google/adk/flows/llm_flows/base_llm_flow.py +1 -1
- google/adk/flows/llm_flows/contents.py +2 -2
- google/adk/flows/llm_flows/functions.py +1 -3
- google/adk/flows/llm_flows/instructions.py +2 -2
- google/adk/models/gemini_llm_connection.py +2 -2
- google/adk/models/lite_llm.py +51 -34
- google/adk/models/llm_response.py +10 -1
- google/adk/planners/built_in_planner.py +1 -0
- google/adk/planners/plan_re_act_planner.py +2 -2
- google/adk/runners.py +1 -1
- google/adk/sessions/database_session_service.py +91 -26
- google/adk/sessions/state.py +2 -2
- google/adk/telemetry.py +2 -2
- google/adk/tools/agent_tool.py +2 -3
- google/adk/tools/application_integration_tool/clients/integration_client.py +3 -2
- google/adk/tools/base_tool.py +1 -1
- google/adk/tools/function_parameter_parse_util.py +2 -2
- google/adk/tools/google_api_tool/__init__.py +74 -1
- google/adk/tools/google_api_tool/google_api_tool_set.py +12 -9
- google/adk/tools/google_api_tool/google_api_tool_sets.py +91 -34
- google/adk/tools/google_api_tool/googleapi_to_openapi_converter.py +3 -1
- google/adk/tools/load_artifacts_tool.py +1 -1
- google/adk/tools/load_memory_tool.py +25 -2
- google/adk/tools/mcp_tool/mcp_session_manager.py +176 -0
- google/adk/tools/mcp_tool/mcp_tool.py +15 -2
- google/adk/tools/mcp_tool/mcp_toolset.py +31 -37
- google/adk/tools/openapi_tool/auth/credential_exchangers/oauth2_exchanger.py +4 -4
- google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py +1 -1
- google/adk/tools/openapi_tool/openapi_spec_parser/operation_parser.py +5 -12
- google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py +47 -9
- google/adk/tools/toolbox_tool.py +1 -1
- google/adk/version.py +1 -1
- google_adk-0.3.0.dist-info/METADATA +235 -0
- {google_adk-0.1.1.dist-info → google_adk-0.3.0.dist-info}/RECORD +62 -60
- google_adk-0.1.1.dist-info/METADATA +0 -181
- {google_adk-0.1.1.dist-info → google_adk-0.3.0.dist-info}/WHEEL +0 -0
- {google_adk-0.1.1.dist-info → google_adk-0.3.0.dist-info}/entry_points.txt +0 -0
- {google_adk-0.1.1.dist-info → google_adk-0.3.0.dist-info}/licenses/LICENSE +0 -0
google/adk/cli/fast_api.py
CHANGED
@@ -13,7 +13,9 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
15
|
import asyncio
|
16
|
+
from contextlib import asynccontextmanager
|
16
17
|
import importlib
|
18
|
+
import inspect
|
17
19
|
import json
|
18
20
|
import logging
|
19
21
|
import os
|
@@ -28,6 +30,7 @@ from typing import Literal
|
|
28
30
|
from typing import Optional
|
29
31
|
|
30
32
|
import click
|
33
|
+
from click import Tuple
|
31
34
|
from fastapi import FastAPI
|
32
35
|
from fastapi import HTTPException
|
33
36
|
from fastapi import Query
|
@@ -56,6 +59,7 @@ from ..agents.llm_agent import Agent
|
|
56
59
|
from ..agents.run_config import StreamingMode
|
57
60
|
from ..artifacts import InMemoryArtifactService
|
58
61
|
from ..events.event import Event
|
62
|
+
from ..memory.in_memory_memory_service import InMemoryMemoryService
|
59
63
|
from ..runners import Runner
|
60
64
|
from ..sessions.database_session_service import DatabaseSessionService
|
61
65
|
from ..sessions.in_memory_session_service import InMemorySessionService
|
@@ -144,7 +148,7 @@ def get_fast_api_app(
|
|
144
148
|
export.SimpleSpanProcessor(ApiServerSpanExporter(trace_dict))
|
145
149
|
)
|
146
150
|
if trace_to_cloud:
|
147
|
-
envs.
|
151
|
+
envs.load_dotenv_for_agent("", agent_dir)
|
148
152
|
if project_id := os.environ.get("GOOGLE_CLOUD_PROJECT", None):
|
149
153
|
processor = export.BatchSpanProcessor(
|
150
154
|
CloudTraceSpanExporter(project_id=project_id)
|
@@ -158,8 +162,22 @@ def get_fast_api_app(
|
|
158
162
|
|
159
163
|
trace.set_tracer_provider(provider)
|
160
164
|
|
165
|
+
exit_stacks = []
|
166
|
+
|
167
|
+
@asynccontextmanager
|
168
|
+
async def internal_lifespan(app: FastAPI):
|
169
|
+
if lifespan:
|
170
|
+
async with lifespan(app) as lifespan_context:
|
171
|
+
yield
|
172
|
+
|
173
|
+
if exit_stacks:
|
174
|
+
for stack in exit_stacks:
|
175
|
+
await stack.aclose()
|
176
|
+
else:
|
177
|
+
yield
|
178
|
+
|
161
179
|
# Run the FastAPI server.
|
162
|
-
app = FastAPI(lifespan=
|
180
|
+
app = FastAPI(lifespan=internal_lifespan)
|
163
181
|
|
164
182
|
if allow_origins:
|
165
183
|
app.add_middleware(
|
@@ -178,6 +196,7 @@ def get_fast_api_app(
|
|
178
196
|
|
179
197
|
# Build the Artifact service
|
180
198
|
artifact_service = InMemoryArtifactService()
|
199
|
+
memory_service = InMemoryMemoryService()
|
181
200
|
|
182
201
|
# Build the Session service
|
183
202
|
agent_engine_id = ""
|
@@ -355,7 +374,7 @@ def get_fast_api_app(
|
|
355
374
|
"/apps/{app_name}/eval_sets/{eval_set_id}/add_session",
|
356
375
|
response_model_exclude_none=True,
|
357
376
|
)
|
358
|
-
def add_session_to_eval_set(
|
377
|
+
async def add_session_to_eval_set(
|
359
378
|
app_name: str, eval_set_id: str, req: AddSessionToEvalSetRequest
|
360
379
|
):
|
361
380
|
pattern = r"^[a-zA-Z0-9_]+$"
|
@@ -390,7 +409,9 @@ def get_fast_api_app(
|
|
390
409
|
test_data = evals.convert_session_to_eval_format(session)
|
391
410
|
|
392
411
|
# Populate the session with initial session state.
|
393
|
-
initial_session_state = create_empty_state(
|
412
|
+
initial_session_state = create_empty_state(
|
413
|
+
await _get_root_agent_async(app_name)
|
414
|
+
)
|
394
415
|
|
395
416
|
eval_set_data.append({
|
396
417
|
"name": req.eval_id,
|
@@ -427,7 +448,7 @@ def get_fast_api_app(
|
|
427
448
|
"/apps/{app_name}/eval_sets/{eval_set_id}/run_eval",
|
428
449
|
response_model_exclude_none=True,
|
429
450
|
)
|
430
|
-
def run_eval(
|
451
|
+
async def run_eval(
|
431
452
|
app_name: str, eval_set_id: str, req: RunEvalRequest
|
432
453
|
) -> list[RunEvalResult]:
|
433
454
|
from .cli_eval import run_evals
|
@@ -444,7 +465,7 @@ def get_fast_api_app(
|
|
444
465
|
logger.info(
|
445
466
|
"Eval ids to run list is empty. We will all evals in the eval set."
|
446
467
|
)
|
447
|
-
root_agent =
|
468
|
+
root_agent = await _get_root_agent_async(app_name)
|
448
469
|
eval_results = list(
|
449
470
|
run_evals(
|
450
471
|
eval_set_to_evals,
|
@@ -574,7 +595,7 @@ def get_fast_api_app(
|
|
574
595
|
)
|
575
596
|
if not session:
|
576
597
|
raise HTTPException(status_code=404, detail="Session not found")
|
577
|
-
runner =
|
598
|
+
runner = await _get_runner_async(req.app_name)
|
578
599
|
events = [
|
579
600
|
event
|
580
601
|
async for event in runner.run_async(
|
@@ -601,7 +622,7 @@ def get_fast_api_app(
|
|
601
622
|
async def event_generator():
|
602
623
|
try:
|
603
624
|
stream_mode = StreamingMode.SSE if req.streaming else StreamingMode.NONE
|
604
|
-
runner =
|
625
|
+
runner = await _get_runner_async(req.app_name)
|
605
626
|
async for event in runner.run_async(
|
606
627
|
user_id=req.user_id,
|
607
628
|
session_id=req.session_id,
|
@@ -627,7 +648,7 @@ def get_fast_api_app(
|
|
627
648
|
"/apps/{app_name}/users/{user_id}/sessions/{session_id}/events/{event_id}/graph",
|
628
649
|
response_model_exclude_none=True,
|
629
650
|
)
|
630
|
-
def get_event_graph(
|
651
|
+
async def get_event_graph(
|
631
652
|
app_name: str, user_id: str, session_id: str, event_id: str
|
632
653
|
):
|
633
654
|
# Connect to managed session if agent_engine_id is set.
|
@@ -644,7 +665,7 @@ def get_fast_api_app(
|
|
644
665
|
|
645
666
|
function_calls = event.get_function_calls()
|
646
667
|
function_responses = event.get_function_responses()
|
647
|
-
root_agent =
|
668
|
+
root_agent = await _get_root_agent_async(app_name)
|
648
669
|
dot_graph = None
|
649
670
|
if function_calls:
|
650
671
|
function_call_highlights = []
|
@@ -701,7 +722,7 @@ def get_fast_api_app(
|
|
701
722
|
live_request_queue = LiveRequestQueue()
|
702
723
|
|
703
724
|
async def forward_events():
|
704
|
-
runner =
|
725
|
+
runner = await _get_runner_async(app_name)
|
705
726
|
async for event in runner.run_live(
|
706
727
|
session=session, live_request_queue=live_request_queue
|
707
728
|
):
|
@@ -735,30 +756,50 @@ def get_fast_api_app(
|
|
735
756
|
except Exception as e:
|
736
757
|
logger.exception("Error during live websocket communication: %s", e)
|
737
758
|
traceback.print_exc()
|
759
|
+
WEBSOCKET_INTERNAL_ERROR_CODE = 1011
|
760
|
+
WEBSOCKET_MAX_BYTES_FOR_REASON = 123
|
761
|
+
await websocket.close(
|
762
|
+
code=WEBSOCKET_INTERNAL_ERROR_CODE,
|
763
|
+
reason=str(e)[:WEBSOCKET_MAX_BYTES_FOR_REASON],
|
764
|
+
)
|
738
765
|
finally:
|
739
766
|
for task in pending:
|
740
767
|
task.cancel()
|
741
768
|
|
742
|
-
def
|
769
|
+
async def _get_root_agent_async(app_name: str) -> Agent:
|
743
770
|
"""Returns the root agent for the given app."""
|
744
771
|
if app_name in root_agent_dict:
|
745
772
|
return root_agent_dict[app_name]
|
746
|
-
envs.load_dotenv_for_agent(os.path.basename(app_name), agent_dir)
|
747
773
|
agent_module = importlib.import_module(app_name)
|
748
|
-
|
774
|
+
if getattr(agent_module.agent, "root_agent"):
|
775
|
+
root_agent = agent_module.agent.root_agent
|
776
|
+
else:
|
777
|
+
raise ValueError(f'Unable to find "root_agent" from {app_name}.')
|
778
|
+
|
779
|
+
# Handle an awaitable root agent and await for the actual agent.
|
780
|
+
if inspect.isawaitable(root_agent):
|
781
|
+
try:
|
782
|
+
agent, exit_stack = await root_agent
|
783
|
+
exit_stacks.append(exit_stack)
|
784
|
+
root_agent = agent
|
785
|
+
except Exception as e:
|
786
|
+
raise RuntimeError(f"error getting root agent, {e}") from e
|
787
|
+
|
749
788
|
root_agent_dict[app_name] = root_agent
|
750
789
|
return root_agent
|
751
790
|
|
752
|
-
def
|
791
|
+
async def _get_runner_async(app_name: str) -> Runner:
|
753
792
|
"""Returns the runner for the given app."""
|
793
|
+
envs.load_dotenv_for_agent(os.path.basename(app_name), agent_dir)
|
754
794
|
if app_name in runner_dict:
|
755
795
|
return runner_dict[app_name]
|
756
|
-
root_agent =
|
796
|
+
root_agent = await _get_root_agent_async(app_name)
|
757
797
|
runner = Runner(
|
758
798
|
app_name=agent_engine_id if agent_engine_id else app_name,
|
759
799
|
agent=root_agent,
|
760
800
|
artifact_service=artifact_service,
|
761
801
|
session_service=session_service,
|
802
|
+
memory_service=memory_service,
|
762
803
|
)
|
763
804
|
runner_dict[app_name] = runner
|
764
805
|
return runner
|
google/adk/cli/utils/envs.py
CHANGED
google/adk/cli/utils/evals.py
CHANGED
@@ -66,7 +66,7 @@ def convert_session_to_eval_format(session: Session) -> list[dict[str, Any]]:
|
|
66
66
|
'tool_input': tool_input,
|
67
67
|
})
|
68
68
|
elif subsequent_part.text:
|
69
|
-
# Also keep track of all the natural
|
69
|
+
# Also keep track of all the natural language responses that
|
70
70
|
# agent (or sub agents) generated.
|
71
71
|
intermediate_agent_responses.append(
|
72
72
|
{'author': event_author, 'text': subsequent_part.text}
|
@@ -75,7 +75,7 @@ def convert_session_to_eval_format(session: Session) -> list[dict[str, Any]]:
|
|
75
75
|
# If we are here then either we are done reading all the events or we
|
76
76
|
# encountered an event that had content authored by the end-user.
|
77
77
|
# This, basically means an end of turn.
|
78
|
-
# We assume that the last natural
|
78
|
+
# We assume that the last natural language intermediate response is the
|
79
79
|
# final response from the agent/model. We treat that as a reference.
|
80
80
|
eval_case.append({
|
81
81
|
'query': query,
|
@@ -55,7 +55,7 @@ def load_json(file_path: str) -> Union[Dict, List]:
|
|
55
55
|
|
56
56
|
|
57
57
|
class AgentEvaluator:
|
58
|
-
"""An evaluator for Agents, mainly
|
58
|
+
"""An evaluator for Agents, mainly intended for helping with test cases."""
|
59
59
|
|
60
60
|
@staticmethod
|
61
61
|
def find_config_for_test_file(test_file: str):
|
@@ -91,7 +91,7 @@ class AgentEvaluator:
|
|
91
91
|
look for 'root_agent' in the loaded module.
|
92
92
|
eval_dataset: The eval data set. This can be either a string representing
|
93
93
|
full path to the file containing eval dataset, or a directory that is
|
94
|
-
|
94
|
+
recursively explored for all files that have a `.test.json` suffix.
|
95
95
|
num_runs: Number of times all entries in the eval dataset should be
|
96
96
|
assessed.
|
97
97
|
agent_name: The name of the agent.
|
@@ -42,10 +42,10 @@ class EvaluationGenerator:
|
|
42
42
|
"""Returns evaluation responses for the given dataset and agent.
|
43
43
|
|
44
44
|
Args:
|
45
|
-
eval_dataset: The dataset that needs to be scraped for
|
45
|
+
eval_dataset: The dataset that needs to be scraped for responses.
|
46
46
|
agent_module_path: Path to the module that contains the root agent.
|
47
47
|
repeat_num: Number of time the eval dataset should be repeated. This is
|
48
|
-
usually done to remove
|
48
|
+
usually done to remove uncertainty that a single run may bring.
|
49
49
|
agent_name: The name of the agent that should be evaluated. This is
|
50
50
|
usually the sub-agent.
|
51
51
|
initial_session: Initial session for the eval data.
|
@@ -253,8 +253,8 @@ class EvaluationGenerator:
|
|
253
253
|
all_mock_tools: set[str],
|
254
254
|
):
|
255
255
|
"""Recursively apply the before_tool_callback to the root agent and all its subagents."""
|
256
|
-
#
|
257
|
-
# We use function
|
256
|
+
# Check if the agent has tools that are defined by evalset.
|
257
|
+
# We use function names to check if tools match
|
258
258
|
if not isinstance(agent, Agent) and not isinstance(agent, LlmAgent):
|
259
259
|
return
|
260
260
|
|
@@ -35,14 +35,14 @@ class ResponseEvaluator:
|
|
35
35
|
Args:
|
36
36
|
raw_eval_dataset: The dataset that will be evaluated.
|
37
37
|
evaluation_criteria: The evaluation criteria to be used. This method
|
38
|
-
support two
|
38
|
+
support two criteria, `response_evaluation_score` and
|
39
39
|
`response_match_score`.
|
40
40
|
print_detailed_results: Prints detailed results on the console. This is
|
41
41
|
usually helpful during debugging.
|
42
42
|
|
43
43
|
A note on evaluation_criteria:
|
44
44
|
`response_match_score`: This metric compares the agents final natural
|
45
|
-
language
|
45
|
+
language response with the expected final response, stored in the
|
46
46
|
"reference" field in test/eval files. We use Rouge metric to compare the
|
47
47
|
two responses.
|
48
48
|
|
@@ -56,7 +56,7 @@ class ResponseEvaluator:
|
|
56
56
|
Value range: [0, 5], where 0 means that the agent's response is not
|
57
57
|
coherent, while 5 means it is . High values are good.
|
58
58
|
A note on raw_eval_dataset:
|
59
|
-
The dataset should be a list session, where each
|
59
|
+
The dataset should be a list session, where each session is represented
|
60
60
|
as a list of interaction that need evaluation. Each evaluation is
|
61
61
|
represented as a dictionary that is expected to have values for the
|
62
62
|
following keys:
|
@@ -106,9 +106,11 @@ class ResponseEvaluator:
|
|
106
106
|
eval_dataset = pd.DataFrame(flattened_queries).rename(
|
107
107
|
columns={"query": "prompt", "expected_tool_use": "reference_trajectory"}
|
108
108
|
)
|
109
|
-
eval_task = EvalTask(dataset=eval_dataset, metrics=metrics)
|
110
109
|
|
111
|
-
eval_result =
|
110
|
+
eval_result = ResponseEvaluator._perform_eval(
|
111
|
+
dataset=eval_dataset, metrics=metrics
|
112
|
+
)
|
113
|
+
|
112
114
|
if print_detailed_results:
|
113
115
|
ResponseEvaluator._print_results(eval_result)
|
114
116
|
return eval_result.summary_metrics
|
@@ -129,6 +131,16 @@ class ResponseEvaluator:
|
|
129
131
|
metrics.append("rouge_1")
|
130
132
|
return metrics
|
131
133
|
|
134
|
+
@staticmethod
|
135
|
+
def _perform_eval(dataset, metrics):
|
136
|
+
"""This method hides away the call to external service.
|
137
|
+
|
138
|
+
Primarily helps with unit testing.
|
139
|
+
"""
|
140
|
+
eval_task = EvalTask(dataset=dataset, metrics=metrics)
|
141
|
+
|
142
|
+
return eval_task.evaluate()
|
143
|
+
|
132
144
|
@staticmethod
|
133
145
|
def _print_results(eval_result):
|
134
146
|
print("Evaluation Summary Metrics:", eval_result.summary_metrics)
|
@@ -31,10 +31,9 @@ class TrajectoryEvaluator:
|
|
31
31
|
):
|
32
32
|
r"""Returns the mean tool use accuracy of the eval dataset.
|
33
33
|
|
34
|
-
Tool use accuracy is calculated by comparing the expected and
|
35
|
-
use trajectories. An exact match scores a 1, 0 otherwise. The final
|
36
|
-
is an
|
37
|
-
average of these individual scores.
|
34
|
+
Tool use accuracy is calculated by comparing the expected and the actual
|
35
|
+
tool use trajectories. An exact match scores a 1, 0 otherwise. The final
|
36
|
+
number is an average of these individual scores.
|
38
37
|
|
39
38
|
Value range: [0, 1], where 0 is means none of the too use entries aligned,
|
40
39
|
and 1 would mean all of them aligned. Higher value is good.
|
@@ -45,7 +44,7 @@ class TrajectoryEvaluator:
|
|
45
44
|
usually helpful during debugging.
|
46
45
|
|
47
46
|
A note on eval_dataset:
|
48
|
-
The dataset should be a list session, where each
|
47
|
+
The dataset should be a list session, where each session is represented
|
49
48
|
as a list of interaction that need evaluation. Each evaluation is
|
50
49
|
represented as a dictionary that is expected to have values for the
|
51
50
|
following keys:
|
google/adk/events/event.py
CHANGED
@@ -70,7 +70,7 @@ class Event(LlmResponse):
|
|
70
70
|
agent_2, and agent_2 is the parent of agent_3.
|
71
71
|
|
72
72
|
Branch is used when multiple sub-agent shouldn't see their peer agents'
|
73
|
-
|
73
|
+
conversation history.
|
74
74
|
"""
|
75
75
|
|
76
76
|
# The following are computed fields.
|
@@ -94,7 +94,7 @@ class Event(LlmResponse):
|
|
94
94
|
not self.get_function_calls()
|
95
95
|
and not self.get_function_responses()
|
96
96
|
and not self.partial
|
97
|
-
and not self.
|
97
|
+
and not self.has_trailing_code_execution_result()
|
98
98
|
)
|
99
99
|
|
100
100
|
def get_function_calls(self) -> list[types.FunctionCall]:
|
@@ -115,7 +115,7 @@ class Event(LlmResponse):
|
|
115
115
|
func_response.append(part.function_response)
|
116
116
|
return func_response
|
117
117
|
|
118
|
-
def
|
118
|
+
def has_trailing_code_execution_result(
|
119
119
|
self,
|
120
120
|
) -> bool:
|
121
121
|
"""Returns whether the event has a trailing code execution result."""
|
@@ -87,15 +87,21 @@ class _NlPlanningResponse(BaseLlmResponseProcessor):
|
|
87
87
|
return
|
88
88
|
|
89
89
|
# Postprocess the LLM response.
|
90
|
+
callback_context = CallbackContext(invocation_context)
|
90
91
|
processed_parts = planner.process_planning_response(
|
91
|
-
|
92
|
+
callback_context, llm_response.content.parts
|
92
93
|
)
|
93
94
|
if processed_parts:
|
94
95
|
llm_response.content.parts = processed_parts
|
95
96
|
|
96
|
-
|
97
|
-
|
98
|
-
|
97
|
+
if callback_context.state.has_delta():
|
98
|
+
state_update_event = Event(
|
99
|
+
invocation_id=invocation_context.invocation_id,
|
100
|
+
author=invocation_context.agent.name,
|
101
|
+
branch=invocation_context.branch,
|
102
|
+
actions=callback_context._event_actions,
|
103
|
+
)
|
104
|
+
yield state_update_event
|
99
105
|
|
100
106
|
|
101
107
|
response_processor = _NlPlanningResponse()
|
@@ -94,7 +94,7 @@ can answer it.
|
|
94
94
|
|
95
95
|
If another agent is better for answering the question according to its
|
96
96
|
description, call `{_TRANSFER_TO_AGENT_FUNCTION_NAME}` function to transfer the
|
97
|
-
question to that agent. When
|
97
|
+
question to that agent. When transferring, do not generate any text other than
|
98
98
|
the function call.
|
99
99
|
"""
|
100
100
|
|
@@ -115,7 +115,7 @@ class BaseLlmFlow(ABC):
|
|
115
115
|
yield event
|
116
116
|
# send back the function response
|
117
117
|
if event.get_function_responses():
|
118
|
-
logger.debug('Sending back last function
|
118
|
+
logger.debug('Sending back last function response event: %s', event)
|
119
119
|
invocation_context.live_request_queue.send_content(event.content)
|
120
120
|
if (
|
121
121
|
event.content
|
@@ -111,7 +111,7 @@ def _rearrange_events_for_latest_function_response(
|
|
111
111
|
"""Rearrange the events for the latest function_response.
|
112
112
|
|
113
113
|
If the latest function_response is for an async function_call, all events
|
114
|
-
|
114
|
+
between the initial function_call and the latest function_response will be
|
115
115
|
removed.
|
116
116
|
|
117
117
|
Args:
|
@@ -310,7 +310,7 @@ def _merge_function_response_events(
|
|
310
310
|
function_response_events: A list of function_response events.
|
311
311
|
NOTE: function_response_events must fulfill these requirements: 1. The
|
312
312
|
list is in increasing order of timestamp; 2. the first event is the
|
313
|
-
initial
|
313
|
+
initial function_response event; 3. all later events should contain at
|
314
314
|
least one function_response part that related to the function_call
|
315
315
|
event. (Note, 3. may not be true when aync function return some
|
316
316
|
intermediate response, there could also be some intermediate model
|
@@ -310,9 +310,7 @@ async def _process_function_live_helper(
|
|
310
310
|
function_response = {
|
311
311
|
'status': f'No active streaming function named {function_name} found'
|
312
312
|
}
|
313
|
-
elif inspect.isasyncgenfunction(tool.func):
|
314
|
-
print('is async')
|
315
|
-
|
313
|
+
elif hasattr(tool, "func") and inspect.isasyncgenfunction(tool.func):
|
316
314
|
# for streaming tool use case
|
317
315
|
# we require the function to be a async generator function
|
318
316
|
async def run_tool_and_update_queue(tool, function_args, tool_context):
|
@@ -52,7 +52,7 @@ class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor):
|
|
52
52
|
# Appends global instructions if set.
|
53
53
|
if (
|
54
54
|
isinstance(root_agent, LlmAgent) and root_agent.global_instruction
|
55
|
-
): # not
|
55
|
+
): # not empty str
|
56
56
|
raw_si = root_agent.canonical_global_instruction(
|
57
57
|
ReadonlyContext(invocation_context)
|
58
58
|
)
|
@@ -60,7 +60,7 @@ class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor):
|
|
60
60
|
llm_request.append_instructions([si])
|
61
61
|
|
62
62
|
# Appends agent instructions if set.
|
63
|
-
if agent.instruction: # not
|
63
|
+
if agent.instruction: # not empty str
|
64
64
|
raw_si = agent.canonical_instruction(ReadonlyContext(invocation_context))
|
65
65
|
si = _populate_values(raw_si, invocation_context)
|
66
66
|
llm_request.append_instructions([si])
|
@@ -152,7 +152,7 @@ class GeminiLlmConnection(BaseLlmConnection):
|
|
152
152
|
):
|
153
153
|
# TODO: Right now, we just support output_transcription without
|
154
154
|
# changing interface and data protocol. Later, we can consider to
|
155
|
-
# support output_transcription as a
|
155
|
+
# support output_transcription as a separate field in LlmResponse.
|
156
156
|
|
157
157
|
# Transcription is always considered as partial event
|
158
158
|
# We rely on other control signals to determine when to yield the
|
@@ -179,7 +179,7 @@ class GeminiLlmConnection(BaseLlmConnection):
|
|
179
179
|
# in case of empty content or parts, we sill surface it
|
180
180
|
# in case it's an interrupted message, we merge the previous partial
|
181
181
|
# text. Other we don't merge. because content can be none when model
|
182
|
-
#
|
182
|
+
# safety threshold is triggered
|
183
183
|
if message.server_content.interrupted and text:
|
184
184
|
yield self.__build_full_text_response(text)
|
185
185
|
text = ''
|
google/adk/models/lite_llm.py
CHANGED
@@ -136,54 +136,68 @@ def _safe_json_serialize(obj) -> str:
|
|
136
136
|
|
137
137
|
def _content_to_message_param(
|
138
138
|
content: types.Content,
|
139
|
-
) -> Message:
|
140
|
-
"""Converts a types.Content to a litellm Message.
|
139
|
+
) -> Union[Message, list[Message]]:
|
140
|
+
"""Converts a types.Content to a litellm Message or list of Messages.
|
141
|
+
|
142
|
+
Handles multipart function responses by returning a list of
|
143
|
+
ChatCompletionToolMessage objects if multiple function_response parts exist.
|
141
144
|
|
142
145
|
Args:
|
143
146
|
content: The content to convert.
|
144
147
|
|
145
148
|
Returns:
|
146
|
-
|
149
|
+
A litellm Message, a list of litellm Messages.
|
147
150
|
"""
|
148
151
|
|
149
|
-
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
155
|
-
|
156
|
-
|
152
|
+
tool_messages = []
|
153
|
+
for part in content.parts:
|
154
|
+
if part.function_response:
|
155
|
+
tool_messages.append(
|
156
|
+
ChatCompletionToolMessage(
|
157
|
+
role="tool",
|
158
|
+
tool_call_id=part.function_response.id,
|
159
|
+
content=_safe_json_serialize(part.function_response.response),
|
160
|
+
)
|
161
|
+
)
|
162
|
+
if tool_messages:
|
163
|
+
return tool_messages if len(tool_messages) > 1 else tool_messages[0]
|
157
164
|
|
165
|
+
# Handle user or assistant messages
|
158
166
|
role = _to_litellm_role(content.role)
|
167
|
+
message_content = _get_content(content.parts) or None
|
159
168
|
|
160
169
|
if role == "user":
|
161
|
-
return ChatCompletionUserMessage(
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
|
170
|
-
|
171
|
-
|
172
|
-
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
170
|
+
return ChatCompletionUserMessage(role="user", content=message_content)
|
171
|
+
else: # assistant/model
|
172
|
+
tool_calls = []
|
173
|
+
content_present = False
|
174
|
+
for part in content.parts:
|
175
|
+
if part.function_call:
|
176
|
+
tool_calls.append(
|
177
|
+
ChatCompletionMessageToolCall(
|
178
|
+
type="function",
|
179
|
+
id=part.function_call.id,
|
180
|
+
function=Function(
|
181
|
+
name=part.function_call.name,
|
182
|
+
arguments=part.function_call.args,
|
183
|
+
),
|
184
|
+
)
|
185
|
+
)
|
186
|
+
elif part.text or part.inline_data:
|
187
|
+
content_present = True
|
188
|
+
|
189
|
+
final_content = message_content if content_present else None
|
178
190
|
|
179
191
|
return ChatCompletionAssistantMessage(
|
180
192
|
role=role,
|
181
|
-
content=
|
193
|
+
content=final_content,
|
182
194
|
tool_calls=tool_calls or None,
|
183
195
|
)
|
184
196
|
|
185
197
|
|
186
|
-
def _get_content(
|
198
|
+
def _get_content(
|
199
|
+
parts: Iterable[types.Part],
|
200
|
+
) -> Union[OpenAIMessageContent, str]:
|
187
201
|
"""Converts a list of parts to litellm content.
|
188
202
|
|
189
203
|
Args:
|
@@ -435,10 +449,13 @@ def _get_completion_inputs(
|
|
435
449
|
Returns:
|
436
450
|
The litellm inputs (message list and tool dictionary).
|
437
451
|
"""
|
438
|
-
messages = [
|
439
|
-
|
440
|
-
|
441
|
-
|
452
|
+
messages = []
|
453
|
+
for content in llm_request.contents or []:
|
454
|
+
message_param_or_list = _content_to_message_param(content)
|
455
|
+
if isinstance(message_param_or_list, list):
|
456
|
+
messages.extend(message_param_or_list)
|
457
|
+
elif message_param_or_list: # Ensure it's not None before appending
|
458
|
+
messages.append(message_param_or_list)
|
442
459
|
|
443
460
|
if llm_request.config.system_instruction:
|
444
461
|
messages.insert(
|
@@ -14,7 +14,7 @@
|
|
14
14
|
|
15
15
|
from __future__ import annotations
|
16
16
|
|
17
|
-
from typing import Optional
|
17
|
+
from typing import Any, Optional
|
18
18
|
|
19
19
|
from google.genai import types
|
20
20
|
from pydantic import BaseModel
|
@@ -37,6 +37,7 @@ class LlmResponse(BaseModel):
|
|
37
37
|
error_message: Error message if the response is an error.
|
38
38
|
interrupted: Flag indicating that LLM was interrupted when generating the
|
39
39
|
content. Usually it's due to user interruption during a bidi streaming.
|
40
|
+
custom_metadata: The custom metadata of the LlmResponse.
|
40
41
|
"""
|
41
42
|
|
42
43
|
model_config = ConfigDict(extra='forbid')
|
@@ -71,6 +72,14 @@ class LlmResponse(BaseModel):
|
|
71
72
|
Usually it's due to user interruption during a bidi streaming.
|
72
73
|
"""
|
73
74
|
|
75
|
+
custom_metadata: Optional[dict[str, Any]] = None
|
76
|
+
"""The custom metadata of the LlmResponse.
|
77
|
+
|
78
|
+
An optional key-value pair to label an LlmResponse.
|
79
|
+
|
80
|
+
NOTE: the entire dict must be JSON serializable.
|
81
|
+
"""
|
82
|
+
|
74
83
|
@staticmethod
|
75
84
|
def create(
|
76
85
|
generate_content_response: types.GenerateContentResponse,
|
@@ -56,6 +56,7 @@ class BuiltInPlanner(BasePlanner):
|
|
56
56
|
llm_request: The LLM request to apply the thinking config to.
|
57
57
|
"""
|
58
58
|
if self.thinking_config:
|
59
|
+
llm_request.config = llm_request.config or types.GenerateContentConfig()
|
59
60
|
llm_request.config.thinking_config = self.thinking_config
|
60
61
|
|
61
62
|
@override
|
@@ -31,9 +31,9 @@ FINAL_ANSWER_TAG = '/*FINAL_ANSWER*/'
|
|
31
31
|
|
32
32
|
|
33
33
|
class PlanReActPlanner(BasePlanner):
|
34
|
-
"""Plan-Re-Act planner that
|
34
|
+
"""Plan-Re-Act planner that constrains the LLM response to generate a plan before any action/observation.
|
35
35
|
|
36
|
-
Note: this planner does not require the model to support
|
36
|
+
Note: this planner does not require the model to support built-in thinking
|
37
37
|
features or setting the thinking config.
|
38
38
|
"""
|
39
39
|
|