google-adk 0.1.1__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. google/adk/agents/base_agent.py +4 -4
  2. google/adk/agents/callback_context.py +0 -1
  3. google/adk/agents/invocation_context.py +1 -1
  4. google/adk/agents/remote_agent.py +1 -1
  5. google/adk/agents/run_config.py +1 -1
  6. google/adk/auth/auth_credential.py +2 -1
  7. google/adk/auth/auth_handler.py +7 -3
  8. google/adk/auth/auth_preprocessor.py +2 -2
  9. google/adk/auth/auth_tool.py +1 -1
  10. google/adk/cli/browser/index.html +2 -2
  11. google/adk/cli/browser/{main-SLIAU2JL.js → main-HWIBUY2R.js} +69 -69
  12. google/adk/cli/cli_create.py +279 -0
  13. google/adk/cli/cli_deploy.py +10 -1
  14. google/adk/cli/cli_eval.py +3 -3
  15. google/adk/cli/cli_tools_click.py +95 -19
  16. google/adk/cli/fast_api.py +57 -16
  17. google/adk/cli/utils/envs.py +0 -3
  18. google/adk/cli/utils/evals.py +2 -2
  19. google/adk/evaluation/agent_evaluator.py +2 -2
  20. google/adk/evaluation/evaluation_generator.py +4 -4
  21. google/adk/evaluation/response_evaluator.py +17 -5
  22. google/adk/evaluation/trajectory_evaluator.py +4 -5
  23. google/adk/events/event.py +3 -3
  24. google/adk/flows/llm_flows/_nl_planning.py +10 -4
  25. google/adk/flows/llm_flows/agent_transfer.py +1 -1
  26. google/adk/flows/llm_flows/base_llm_flow.py +1 -1
  27. google/adk/flows/llm_flows/contents.py +2 -2
  28. google/adk/flows/llm_flows/functions.py +1 -3
  29. google/adk/flows/llm_flows/instructions.py +2 -2
  30. google/adk/models/gemini_llm_connection.py +2 -2
  31. google/adk/models/lite_llm.py +51 -34
  32. google/adk/models/llm_response.py +10 -1
  33. google/adk/planners/built_in_planner.py +1 -0
  34. google/adk/planners/plan_re_act_planner.py +2 -2
  35. google/adk/runners.py +1 -1
  36. google/adk/sessions/database_session_service.py +91 -26
  37. google/adk/sessions/state.py +2 -2
  38. google/adk/telemetry.py +2 -2
  39. google/adk/tools/agent_tool.py +2 -3
  40. google/adk/tools/application_integration_tool/clients/integration_client.py +3 -2
  41. google/adk/tools/base_tool.py +1 -1
  42. google/adk/tools/function_parameter_parse_util.py +2 -2
  43. google/adk/tools/google_api_tool/__init__.py +74 -1
  44. google/adk/tools/google_api_tool/google_api_tool_set.py +12 -9
  45. google/adk/tools/google_api_tool/google_api_tool_sets.py +91 -34
  46. google/adk/tools/google_api_tool/googleapi_to_openapi_converter.py +3 -1
  47. google/adk/tools/load_artifacts_tool.py +1 -1
  48. google/adk/tools/load_memory_tool.py +25 -2
  49. google/adk/tools/mcp_tool/mcp_session_manager.py +176 -0
  50. google/adk/tools/mcp_tool/mcp_tool.py +15 -2
  51. google/adk/tools/mcp_tool/mcp_toolset.py +31 -37
  52. google/adk/tools/openapi_tool/auth/credential_exchangers/oauth2_exchanger.py +4 -4
  53. google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py +1 -1
  54. google/adk/tools/openapi_tool/openapi_spec_parser/operation_parser.py +5 -12
  55. google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py +47 -9
  56. google/adk/tools/toolbox_tool.py +1 -1
  57. google/adk/version.py +1 -1
  58. google_adk-0.3.0.dist-info/METADATA +235 -0
  59. {google_adk-0.1.1.dist-info → google_adk-0.3.0.dist-info}/RECORD +62 -60
  60. google_adk-0.1.1.dist-info/METADATA +0 -181
  61. {google_adk-0.1.1.dist-info → google_adk-0.3.0.dist-info}/WHEEL +0 -0
  62. {google_adk-0.1.1.dist-info → google_adk-0.3.0.dist-info}/entry_points.txt +0 -0
  63. {google_adk-0.1.1.dist-info → google_adk-0.3.0.dist-info}/licenses/LICENSE +0 -0
google/adk/cli/fast_api.py
@@ -13,7 +13,9 @@
 # limitations under the License.
 
 import asyncio
+from contextlib import asynccontextmanager
 import importlib
+import inspect
 import json
 import logging
 import os
@@ -28,6 +30,7 @@ from typing import Literal
 from typing import Optional
 
 import click
+from click import Tuple
 from fastapi import FastAPI
 from fastapi import HTTPException
 from fastapi import Query
@@ -56,6 +59,7 @@ from ..agents.llm_agent import Agent
 from ..agents.run_config import StreamingMode
 from ..artifacts import InMemoryArtifactService
 from ..events.event import Event
+from ..memory.in_memory_memory_service import InMemoryMemoryService
 from ..runners import Runner
 from ..sessions.database_session_service import DatabaseSessionService
 from ..sessions.in_memory_session_service import InMemorySessionService
@@ -144,7 +148,7 @@ def get_fast_api_app(
       export.SimpleSpanProcessor(ApiServerSpanExporter(trace_dict))
   )
   if trace_to_cloud:
-    envs.load_dotenv()
+    envs.load_dotenv_for_agent("", agent_dir)
     if project_id := os.environ.get("GOOGLE_CLOUD_PROJECT", None):
       processor = export.BatchSpanProcessor(
           CloudTraceSpanExporter(project_id=project_id)
@@ -158,8 +162,22 @@ def get_fast_api_app(
 
   trace.set_tracer_provider(provider)
 
+  exit_stacks = []
+
+  @asynccontextmanager
+  async def internal_lifespan(app: FastAPI):
+    if lifespan:
+      async with lifespan(app) as lifespan_context:
+        yield
+
+        if exit_stacks:
+          for stack in exit_stacks:
+            await stack.aclose()
+    else:
+      yield
+
   # Run the FastAPI server.
-  app = FastAPI(lifespan=lifespan)
+  app = FastAPI(lifespan=internal_lifespan)
 
   if allow_origins:
     app.add_middleware(
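
Note: the internal_lifespan wrapper added above delegates to a caller-supplied lifespan and then closes any async exit stacks collected while agents were loaded. A minimal runnable sketch of the same pattern, assuming a caller-provided lifespan; everything except the FastAPI API itself is illustrative:

    from contextlib import AsyncExitStack, asynccontextmanager

    from fastapi import FastAPI

    exit_stacks: list[AsyncExitStack] = []

    @asynccontextmanager
    async def user_lifespan(app: FastAPI):
        # Stands in for the lifespan argument a caller passes in.
        yield

    @asynccontextmanager
    async def internal_lifespan(app: FastAPI):
        # Run the caller's lifespan, then close any exit stacks that agent
        # loading registered while the server was running.
        async with user_lifespan(app):
            yield
            for stack in exit_stacks:
                await stack.aclose()

    app = FastAPI(lifespan=internal_lifespan)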
@@ -178,6 +196,7 @@ def get_fast_api_app(
 
   # Build the Artifact service
   artifact_service = InMemoryArtifactService()
+  memory_service = InMemoryMemoryService()
 
   # Build the Session service
   agent_engine_id = ""
@@ -355,7 +374,7 @@ def get_fast_api_app(
       "/apps/{app_name}/eval_sets/{eval_set_id}/add_session",
       response_model_exclude_none=True,
   )
-  def add_session_to_eval_set(
+  async def add_session_to_eval_set(
       app_name: str, eval_set_id: str, req: AddSessionToEvalSetRequest
   ):
     pattern = r"^[a-zA-Z0-9_]+$"
@@ -390,7 +409,9 @@ def get_fast_api_app(
     test_data = evals.convert_session_to_eval_format(session)
 
     # Populate the session with initial session state.
-    initial_session_state = create_empty_state(_get_root_agent(app_name))
+    initial_session_state = create_empty_state(
+        await _get_root_agent_async(app_name)
+    )
 
     eval_set_data.append({
         "name": req.eval_id,
@@ -427,7 +448,7 @@
       "/apps/{app_name}/eval_sets/{eval_set_id}/run_eval",
       response_model_exclude_none=True,
   )
-  def run_eval(
+  async def run_eval(
       app_name: str, eval_set_id: str, req: RunEvalRequest
   ) -> list[RunEvalResult]:
     from .cli_eval import run_evals
@@ -444,7 +465,7 @@
       logger.info(
           "Eval ids to run list is empty. We will all evals in the eval set."
       )
-    root_agent = _get_root_agent(app_name)
+    root_agent = await _get_root_agent_async(app_name)
     eval_results = list(
         run_evals(
             eval_set_to_evals,
@@ -574,7 +595,7 @@
     )
     if not session:
       raise HTTPException(status_code=404, detail="Session not found")
-    runner = _get_runner(req.app_name)
+    runner = await _get_runner_async(req.app_name)
     events = [
         event
         async for event in runner.run_async(
@@ -601,7 +622,7 @@
     async def event_generator():
       try:
         stream_mode = StreamingMode.SSE if req.streaming else StreamingMode.NONE
-        runner = _get_runner(req.app_name)
+        runner = await _get_runner_async(req.app_name)
         async for event in runner.run_async(
             user_id=req.user_id,
             session_id=req.session_id,
@@ -627,7 +648,7 @@
       "/apps/{app_name}/users/{user_id}/sessions/{session_id}/events/{event_id}/graph",
       response_model_exclude_none=True,
   )
-  def get_event_graph(
+  async def get_event_graph(
      app_name: str, user_id: str, session_id: str, event_id: str
   ):
     # Connect to managed session if agent_engine_id is set.
@@ -644,7 +665,7 @@
 
     function_calls = event.get_function_calls()
     function_responses = event.get_function_responses()
-    root_agent = _get_root_agent(app_name)
+    root_agent = await _get_root_agent_async(app_name)
     dot_graph = None
     if function_calls:
       function_call_highlights = []
@@ -701,7 +722,7 @@
     live_request_queue = LiveRequestQueue()
 
     async def forward_events():
-      runner = _get_runner(app_name)
+      runner = await _get_runner_async(app_name)
       async for event in runner.run_live(
           session=session, live_request_queue=live_request_queue
       ):
@@ -735,30 +756,50 @@
     except Exception as e:
       logger.exception("Error during live websocket communication: %s", e)
       traceback.print_exc()
+      WEBSOCKET_INTERNAL_ERROR_CODE = 1011
+      WEBSOCKET_MAX_BYTES_FOR_REASON = 123
+      await websocket.close(
+          code=WEBSOCKET_INTERNAL_ERROR_CODE,
+          reason=str(e)[:WEBSOCKET_MAX_BYTES_FOR_REASON],
+      )
     finally:
       for task in pending:
         task.cancel()
 
-  def _get_root_agent(app_name: str) -> Agent:
+  async def _get_root_agent_async(app_name: str) -> Agent:
     """Returns the root agent for the given app."""
     if app_name in root_agent_dict:
       return root_agent_dict[app_name]
-    envs.load_dotenv_for_agent(os.path.basename(app_name), agent_dir)
     agent_module = importlib.import_module(app_name)
-    root_agent: Agent = agent_module.agent.root_agent
+    if getattr(agent_module.agent, "root_agent"):
+      root_agent = agent_module.agent.root_agent
+    else:
+      raise ValueError(f'Unable to find "root_agent" from {app_name}.')
+
+    # Handle an awaitable root agent and await for the actual agent.
+    if inspect.isawaitable(root_agent):
+      try:
+        agent, exit_stack = await root_agent
+        exit_stacks.append(exit_stack)
+        root_agent = agent
+      except Exception as e:
+        raise RuntimeError(f"error getting root agent, {e}") from e
+
     root_agent_dict[app_name] = root_agent
     return root_agent
 
-  def _get_runner(app_name: str) -> Runner:
+  async def _get_runner_async(app_name: str) -> Runner:
     """Returns the runner for the given app."""
+    envs.load_dotenv_for_agent(os.path.basename(app_name), agent_dir)
     if app_name in runner_dict:
       return runner_dict[app_name]
-    root_agent = _get_root_agent(app_name)
+    root_agent = await _get_root_agent_async(app_name)
     runner = Runner(
         app_name=agent_engine_id if agent_engine_id else app_name,
         agent=root_agent,
         artifact_service=artifact_service,
         session_service=session_service,
+        memory_service=memory_service,
     )
     runner_dict[app_name] = runner
     return runner
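
Note: _get_root_agent_async now accepts an awaitable root_agent that resolves to an (agent, exit_stack) pair, e.g. when tools need asynchronous setup. A hypothetical my_agent/agent.py that would work with this handling; the module layout and names below are illustrative, not prescribed by the diff:

    from contextlib import AsyncExitStack

    from google.adk.agents.llm_agent import Agent

    async def _build_agent() -> tuple[Agent, AsyncExitStack]:
        exit_stack = AsyncExitStack()
        # Async resources (e.g. MCP tool connections) would be entered on
        # exit_stack here; the server closes it on shutdown via its lifespan.
        agent = Agent(name="assistant", model="gemini-2.0-flash")
        return agent, exit_stack

    root_agent = _build_agent()  # a coroutine; the server awaits it once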

google/adk/cli/utils/envs.py
@@ -50,8 +50,5 @@ def load_dotenv_for_agent(
         agent_name,
         dotenv_file_path,
     )
-    logger.info(
-        'Reloaded %s file for %s at %s', filename, agent_name, dotenv_file_path
-    )
   else:
     logger.info('No %s file found for %s', filename, agent_name)

google/adk/cli/utils/evals.py
@@ -66,7 +66,7 @@ def convert_session_to_eval_format(session: Session) -> list[dict[str, Any]]:
               'tool_input': tool_input,
           })
         elif subsequent_part.text:
-          # Also keep track of all the natural langauge responses that
+          # Also keep track of all the natural language responses that
           # agent (or sub agents) generated.
           intermediate_agent_responses.append(
               {'author': event_author, 'text': subsequent_part.text}
@@ -75,7 +75,7 @@ def convert_session_to_eval_format(session: Session) -> list[dict[str, Any]]:
     # If we are here then either we are done reading all the events or we
     # encountered an event that had content authored by the end-user.
     # This, basically means an end of turn.
-    # We assume that the last natural langauge intermediate response is the
+    # We assume that the last natural language intermediate response is the
     # final response from the agent/model. We treat that as a reference.
     eval_case.append({
         'query': query,

google/adk/evaluation/agent_evaluator.py
@@ -55,7 +55,7 @@ def load_json(file_path: str) -> Union[Dict, List]:
 
 
 class AgentEvaluator:
-  """An evaluator for Agents, mainly intented for helping with test cases."""
+  """An evaluator for Agents, mainly intended for helping with test cases."""
 
   @staticmethod
   def find_config_for_test_file(test_file: str):
@@ -91,7 +91,7 @@ class AgentEvaluator:
         look for 'root_agent' in the loaded module.
       eval_dataset: The eval data set. This can be either a string representing
         full path to the file containing eval dataset, or a directory that is
-        recusively explored for all files that have a `.test.json` suffix.
+        recursively explored for all files that have a `.test.json` suffix.
       num_runs: Number of times all entries in the eval dataset should be
         assessed.
       agent_name: The name of the agent.

google/adk/evaluation/evaluation_generator.py
@@ -42,10 +42,10 @@ class EvaluationGenerator:
     """Returns evaluation responses for the given dataset and agent.
 
     Args:
-      eval_dataset: The dataset that needs to be scraped for resposnes.
+      eval_dataset: The dataset that needs to be scraped for responses.
       agent_module_path: Path to the module that contains the root agent.
       repeat_num: Number of time the eval dataset should be repeated. This is
-        usually done to remove uncertainity that a single run may bring.
+        usually done to remove uncertainty that a single run may bring.
       agent_name: The name of the agent that should be evaluated. This is
         usually the sub-agent.
       initial_session: Initial session for the eval data.
@@ -253,8 +253,8 @@ class EvaluationGenerator:
       all_mock_tools: set[str],
   ):
     """Recursively apply the before_tool_callback to the root agent and all its subagents."""
-    # check if the agent has tools that defined by evalset
-    # We use function name to check if tools match
+    # Check if the agent has tools that are defined by evalset.
+    # We use function names to check if tools match
     if not isinstance(agent, Agent) and not isinstance(agent, LlmAgent):
       return
 

google/adk/evaluation/response_evaluator.py
@@ -35,14 +35,14 @@ class ResponseEvaluator:
     Args:
       raw_eval_dataset: The dataset that will be evaluated.
       evaluation_criteria: The evaluation criteria to be used. This method
-        support two criterias, `response_evaluation_score` and
+        support two criteria, `response_evaluation_score` and
         `response_match_score`.
       print_detailed_results: Prints detailed results on the console. This is
         usually helpful during debugging.
 
     A note on evaluation_criteria:
       `response_match_score`: This metric compares the agents final natural
-        language reponse with the expected final response, stored in the
+        language response with the expected final response, stored in the
         "reference" field in test/eval files. We use Rouge metric to compare the
         two responses.
 
@@ -56,7 +56,7 @@ class ResponseEvaluator:
         Value range: [0, 5], where 0 means that the agent's response is not
         coherent, while 5 means it is . High values are good.
     A note on raw_eval_dataset:
-      The dataset should be a list session, where each sesssion is represented
+      The dataset should be a list session, where each session is represented
       as a list of interaction that need evaluation. Each evaluation is
       represented as a dictionary that is expected to have values for the
       following keys:
@@ -106,9 +106,11 @@ class ResponseEvaluator:
     eval_dataset = pd.DataFrame(flattened_queries).rename(
         columns={"query": "prompt", "expected_tool_use": "reference_trajectory"}
     )
-    eval_task = EvalTask(dataset=eval_dataset, metrics=metrics)
 
-    eval_result = eval_task.evaluate()
+    eval_result = ResponseEvaluator._perform_eval(
+        dataset=eval_dataset, metrics=metrics
+    )
+
     if print_detailed_results:
       ResponseEvaluator._print_results(eval_result)
     return eval_result.summary_metrics
@@ -129,6 +131,16 @@ class ResponseEvaluator:
       metrics.append("rouge_1")
     return metrics
 
+  @staticmethod
+  def _perform_eval(dataset, metrics):
+    """This method hides away the call to external service.
+
+    Primarily helps with unit testing.
+    """
+    eval_task = EvalTask(dataset=dataset, metrics=metrics)
+
+    return eval_task.evaluate()
+
   @staticmethod
   def _print_results(eval_result):
     print("Evaluation Summary Metrics:", eval_result.summary_metrics)

google/adk/evaluation/trajectory_evaluator.py
@@ -31,10 +31,9 @@ class TrajectoryEvaluator:
   ):
     r"""Returns the mean tool use accuracy of the eval dataset.
 
-    Tool use accuracy is calculated by comparing the expected and actuall tool
-    use trajectories. An exact match scores a 1, 0 otherwise. The final number
-    is an
-    average of these individual scores.
+    Tool use accuracy is calculated by comparing the expected and the actual
+    tool use trajectories. An exact match scores a 1, 0 otherwise. The final
+    number is an average of these individual scores.
 
     Value range: [0, 1], where 0 is means none of the too use entries aligned,
     and 1 would mean all of them aligned. Higher value is good.
@@ -45,7 +44,7 @@
         usually helpful during debugging.
 
     A note on eval_dataset:
-      The dataset should be a list session, where each sesssion is represented
+      The dataset should be a list session, where each session is represented
       as a list of interaction that need evaluation. Each evaluation is
       represented as a dictionary that is expected to have values for the
       following keys:

google/adk/events/event.py
@@ -70,7 +70,7 @@ class Event(LlmResponse):
       agent_2, and agent_2 is the parent of agent_3.
 
       Branch is used when multiple sub-agent shouldn't see their peer agents'
-      conversaction history.
+      conversation history.
   """
 
   # The following are computed fields.
@@ -94,7 +94,7 @@
         not self.get_function_calls()
         and not self.get_function_responses()
         and not self.partial
-        and not self.has_trailing_code_exeuction_result()
+        and not self.has_trailing_code_execution_result()
     )
 
   def get_function_calls(self) -> list[types.FunctionCall]:
@@ -115,7 +115,7 @@
           func_response.append(part.function_response)
     return func_response
 
-  def has_trailing_code_exeuction_result(
+  def has_trailing_code_execution_result(
       self,
   ) -> bool:
     """Returns whether the event has a trailing code execution result."""

google/adk/flows/llm_flows/_nl_planning.py
@@ -87,15 +87,21 @@ class _NlPlanningResponse(BaseLlmResponseProcessor):
       return
 
     # Postprocess the LLM response.
+    callback_context = CallbackContext(invocation_context)
     processed_parts = planner.process_planning_response(
-        CallbackContext(invocation_context), llm_response.content.parts
+        callback_context, llm_response.content.parts
     )
     if processed_parts:
       llm_response.content.parts = processed_parts
 
-    # Maintain async generator behavior
-    if False: # Ensures it behaves as a generator
-      yield # This is a no-op but maintains generator structure
+    if callback_context.state.has_delta():
+      state_update_event = Event(
+          invocation_id=invocation_context.invocation_id,
+          author=invocation_context.agent.name,
+          branch=invocation_context.branch,
+          actions=callback_context._event_actions,
+      )
+      yield state_update_event
 
 
 response_processor = _NlPlanningResponse()

google/adk/flows/llm_flows/agent_transfer.py
@@ -94,7 +94,7 @@ can answer it.
 
 If another agent is better for answering the question according to its
 description, call `{_TRANSFER_TO_AGENT_FUNCTION_NAME}` function to transfer the
-question to that agent. When transfering, do not generate any text other than
+question to that agent. When transferring, do not generate any text other than
 the function call.
 """
 

google/adk/flows/llm_flows/base_llm_flow.py
@@ -115,7 +115,7 @@ class BaseLlmFlow(ABC):
         yield event
         # send back the function response
         if event.get_function_responses():
-          logger.debug('Sending back last function resonse event: %s', event)
+          logger.debug('Sending back last function response event: %s', event)
           invocation_context.live_request_queue.send_content(event.content)
         if (
             event.content

google/adk/flows/llm_flows/contents.py
@@ -111,7 +111,7 @@ def _rearrange_events_for_latest_function_response(
   """Rearrange the events for the latest function_response.
 
   If the latest function_response is for an async function_call, all events
-  bewteen the initial function_call and the latest function_response will be
+  between the initial function_call and the latest function_response will be
   removed.
 
   Args:
@@ -310,7 +310,7 @@ def _merge_function_response_events(
     function_response_events: A list of function_response events.
       NOTE: function_response_events must fulfill these requirements: 1. The
       list is in increasing order of timestamp; 2. the first event is the
-      initial function_reponse event; 3. all later events should contain at
+      initial function_response event; 3. all later events should contain at
       least one function_response part that related to the function_call
       event. (Note, 3. may not be true when aync function return some
       intermediate response, there could also be some intermediate model

google/adk/flows/llm_flows/functions.py
@@ -310,9 +310,7 @@ async def _process_function_live_helper(
     function_response = {
         'status': f'No active streaming function named {function_name} found'
     }
-  elif inspect.isasyncgenfunction(tool.func):
-    print('is async')
-
+  elif hasattr(tool, "func") and inspect.isasyncgenfunction(tool.func):
     # for streaming tool use case
     # we require the function to be a async generator function
     async def run_tool_and_update_queue(tool, function_args, tool_context):

google/adk/flows/llm_flows/instructions.py
@@ -52,7 +52,7 @@ class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor):
     # Appends global instructions if set.
     if (
         isinstance(root_agent, LlmAgent) and root_agent.global_instruction
-    ): # not emtpy str
+    ): # not empty str
       raw_si = root_agent.canonical_global_instruction(
           ReadonlyContext(invocation_context)
       )
@@ -60,7 +60,7 @@ class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor):
       llm_request.append_instructions([si])
 
     # Appends agent instructions if set.
-    if agent.instruction: # not emtpy str
+    if agent.instruction: # not empty str
       raw_si = agent.canonical_instruction(ReadonlyContext(invocation_context))
       si = _populate_values(raw_si, invocation_context)
       llm_request.append_instructions([si])

google/adk/models/gemini_llm_connection.py
@@ -152,7 +152,7 @@ class GeminiLlmConnection(BaseLlmConnection):
         ):
           # TODO: Right now, we just support output_transcription without
           # changing interface and data protocol. Later, we can consider to
-          # support output_transcription as a separete field in LlmResponse.
+          # support output_transcription as a separate field in LlmResponse.
 
           # Transcription is always considered as partial event
           # We rely on other control signals to determine when to yield the
@@ -179,7 +179,7 @@ class GeminiLlmConnection(BaseLlmConnection):
         # in case of empty content or parts, we sill surface it
         # in case it's an interrupted message, we merge the previous partial
         # text. Other we don't merge. because content can be none when model
-        # safty threshold is triggered
+        # safety threshold is triggered
         if message.server_content.interrupted and text:
           yield self.__build_full_text_response(text)
           text = ''

google/adk/models/lite_llm.py
@@ -136,54 +136,68 @@ def _safe_json_serialize(obj) -> str:
 
 def _content_to_message_param(
     content: types.Content,
-) -> Message:
-  """Converts a types.Content to a litellm Message.
+) -> Union[Message, list[Message]]:
+  """Converts a types.Content to a litellm Message or list of Messages.
+
+  Handles multipart function responses by returning a list of
+  ChatCompletionToolMessage objects if multiple function_response parts exist.
 
   Args:
     content: The content to convert.
 
   Returns:
-    The litellm Message.
+    A litellm Message, a list of litellm Messages.
   """
 
-  if content.parts and content.parts[0].function_response:
-    return ChatCompletionToolMessage(
-        role="tool",
-        tool_call_id=content.parts[0].function_response.id,
-        content=_safe_json_serialize(
-            content.parts[0].function_response.response
-        ),
-    )
+  tool_messages = []
+  for part in content.parts:
+    if part.function_response:
+      tool_messages.append(
+          ChatCompletionToolMessage(
+              role="tool",
+              tool_call_id=part.function_response.id,
+              content=_safe_json_serialize(part.function_response.response),
+          )
+      )
+  if tool_messages:
+    return tool_messages if len(tool_messages) > 1 else tool_messages[0]
 
+  # Handle user or assistant messages
   role = _to_litellm_role(content.role)
+  message_content = _get_content(content.parts) or None
 
   if role == "user":
-    return ChatCompletionUserMessage(
-        role="user", content=_get_content(content.parts)
-    )
-  else:
-
-    tool_calls = [
-        ChatCompletionMessageToolCall(
-            type="function",
-            id=part.function_call.id,
-            function=Function(
-                name=part.function_call.name,
-                arguments=part.function_call.args,
-            ),
-        )
-        for part in content.parts
-        if part.function_call
-    ]
+    return ChatCompletionUserMessage(role="user", content=message_content)
+  else: # assistant/model
+    tool_calls = []
+    content_present = False
+    for part in content.parts:
+      if part.function_call:
+        tool_calls.append(
+            ChatCompletionMessageToolCall(
+                type="function",
+                id=part.function_call.id,
+                function=Function(
+                    name=part.function_call.name,
+                    arguments=part.function_call.args,
+                ),
+            )
+        )
+      elif part.text or part.inline_data:
+        content_present = True
+
+    final_content = message_content if content_present else None
 
     return ChatCompletionAssistantMessage(
         role=role,
-        content=_get_content(content.parts),
+        content=final_content,
        tool_calls=tool_calls or None,
    )
 
 
-def _get_content(parts: Iterable[types.Part]) -> OpenAIMessageContent | str:
+def _get_content(
+    parts: Iterable[types.Part],
+) -> Union[OpenAIMessageContent, str]:
   """Converts a list of parts to litellm content.
 
   Args:
@@ -435,10 +449,13 @@ def _get_completion_inputs(
   Returns:
     The litellm inputs (message list and tool dictionary).
   """
-  messages = [
-      _content_to_message_param(content)
-      for content in llm_request.contents or []
-  ]
+  messages = []
+  for content in llm_request.contents or []:
+    message_param_or_list = _content_to_message_param(content)
+    if isinstance(message_param_or_list, list):
+      messages.extend(message_param_or_list)
+    elif message_param_or_list: # Ensure it's not None before appending
+      messages.append(message_param_or_list)
 
   if llm_request.config.system_instruction:
     messages.insert(
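
Note: with these two lite_llm changes, a Content carrying several function_response parts maps to one ChatCompletionToolMessage per part, and _get_completion_inputs flattens the resulting list into the message history. An illustrative input using google.genai types; the ids, names, and payloads below are arbitrary:

    from google.genai import types

    content = types.Content(
        role="user",
        parts=[
            types.Part(
                function_response=types.FunctionResponse(
                    id="call_1", name="get_time", response={"time": "10:00"}
                )
            ),
            types.Part(
                function_response=types.FunctionResponse(
                    id="call_2", name="get_weather", response={"temp_c": 21}
                )
            ),
        ],
    )
    # _content_to_message_param(content) -> [tool_message_1, tool_message_2]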

google/adk/models/llm_response.py
@@ -14,7 +14,7 @@
 
 from __future__ import annotations
 
-from typing import Optional
+from typing import Any, Optional
 
 from google.genai import types
 from pydantic import BaseModel
@@ -37,6 +37,7 @@ class LlmResponse(BaseModel):
     error_message: Error message if the response is an error.
     interrupted: Flag indicating that LLM was interrupted when generating the
       content. Usually it's due to user interruption during a bidi streaming.
+    custom_metadata: The custom metadata of the LlmResponse.
   """
 
   model_config = ConfigDict(extra='forbid')
@@ -71,6 +72,14 @@
   Usually it's due to user interruption during a bidi streaming.
   """
 
+  custom_metadata: Optional[dict[str, Any]] = None
+  """The custom metadata of the LlmResponse.
+
+  An optional key-value pair to label an LlmResponse.
+
+  NOTE: the entire dict must be JSON serializable.
+  """
+
   @staticmethod
   def create(
       generate_content_response: types.GenerateContentResponse,
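
Note: a small usage sketch for the new custom_metadata field; the keys and values below are arbitrary, but the whole dict must be JSON serializable:

    from google.adk.models.llm_response import LlmResponse

    response = LlmResponse(custom_metadata={"experiment": "baseline", "run": 3})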

google/adk/planners/built_in_planner.py
@@ -56,6 +56,7 @@ class BuiltInPlanner(BasePlanner):
       llm_request: The LLM request to apply the thinking config to.
     """
     if self.thinking_config:
+      llm_request.config = llm_request.config or types.GenerateContentConfig()
       llm_request.config.thinking_config = self.thinking_config
 
   @override
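
Note: the added line guards against llm_request.config being None before thinking_config is assigned. A sketch of the pattern in isolation, assuming only the google.genai types module:

    from google.genai import types

    config = None  # llm_request.config before any config is set
    config = config or types.GenerateContentConfig()
    config.thinking_config = types.ThinkingConfig(include_thoughts=True)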

google/adk/planners/plan_re_act_planner.py
@@ -31,9 +31,9 @@ FINAL_ANSWER_TAG = '/*FINAL_ANSWER*/'
 
 
 class PlanReActPlanner(BasePlanner):
-  """Plan-Re-Act planner that constraints the LLM response to generate a plan before any action/observation.
+  """Plan-Re-Act planner that constrains the LLM response to generate a plan before any action/observation.
 
-  Note: this planner does not require the model to support buil-in thinking
+  Note: this planner does not require the model to support built-in thinking
   features or setting the thinking config.
   """
 