google-adk 0.4.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (129) hide show
  1. google/adk/agents/active_streaming_tool.py +1 -0
  2. google/adk/agents/base_agent.py +91 -47
  3. google/adk/agents/base_agent.py.orig +330 -0
  4. google/adk/agents/callback_context.py +4 -9
  5. google/adk/agents/invocation_context.py +1 -0
  6. google/adk/agents/langgraph_agent.py +1 -0
  7. google/adk/agents/live_request_queue.py +1 -0
  8. google/adk/agents/llm_agent.py +172 -35
  9. google/adk/agents/loop_agent.py +1 -1
  10. google/adk/agents/parallel_agent.py +7 -0
  11. google/adk/agents/readonly_context.py +7 -1
  12. google/adk/agents/run_config.py +5 -1
  13. google/adk/agents/sequential_agent.py +31 -0
  14. google/adk/agents/transcription_entry.py +5 -2
  15. google/adk/artifacts/base_artifact_service.py +5 -10
  16. google/adk/artifacts/gcs_artifact_service.py +9 -9
  17. google/adk/artifacts/in_memory_artifact_service.py +6 -6
  18. google/adk/auth/auth_credential.py +9 -5
  19. google/adk/auth/auth_preprocessor.py +7 -1
  20. google/adk/auth/auth_tool.py +3 -4
  21. google/adk/cli/agent_graph.py +5 -5
  22. google/adk/cli/browser/index.html +2 -2
  23. google/adk/cli/browser/{main-HWIBUY2R.js → main-QOEMUXM4.js} +58 -58
  24. google/adk/cli/cli.py +7 -7
  25. google/adk/cli/cli_deploy.py +7 -2
  26. google/adk/cli/cli_eval.py +181 -106
  27. google/adk/cli/cli_tools_click.py +147 -62
  28. google/adk/cli/fast_api.py +340 -158
  29. google/adk/cli/fast_api.py.orig +822 -0
  30. google/adk/cli/utils/common.py +23 -0
  31. google/adk/cli/utils/evals.py +83 -1
  32. google/adk/cli/utils/logs.py +13 -5
  33. google/adk/code_executors/__init__.py +3 -1
  34. google/adk/code_executors/built_in_code_executor.py +52 -0
  35. google/adk/evaluation/__init__.py +1 -1
  36. google/adk/evaluation/agent_evaluator.py +168 -128
  37. google/adk/evaluation/eval_case.py +102 -0
  38. google/adk/evaluation/eval_set.py +37 -0
  39. google/adk/evaluation/eval_sets_manager.py +42 -0
  40. google/adk/evaluation/evaluation_constants.py +1 -0
  41. google/adk/evaluation/evaluation_generator.py +89 -114
  42. google/adk/evaluation/evaluator.py +56 -0
  43. google/adk/evaluation/local_eval_sets_manager.py +264 -0
  44. google/adk/evaluation/response_evaluator.py +107 -3
  45. google/adk/evaluation/trajectory_evaluator.py +83 -2
  46. google/adk/events/event.py +7 -1
  47. google/adk/events/event_actions.py +7 -1
  48. google/adk/examples/example.py +1 -0
  49. google/adk/examples/example_util.py +3 -2
  50. google/adk/flows/__init__.py +0 -1
  51. google/adk/flows/llm_flows/_code_execution.py +19 -11
  52. google/adk/flows/llm_flows/audio_transcriber.py +4 -3
  53. google/adk/flows/llm_flows/base_llm_flow.py +86 -22
  54. google/adk/flows/llm_flows/basic.py +3 -0
  55. google/adk/flows/llm_flows/functions.py +10 -9
  56. google/adk/flows/llm_flows/instructions.py +28 -9
  57. google/adk/flows/llm_flows/single_flow.py +1 -1
  58. google/adk/memory/__init__.py +1 -1
  59. google/adk/memory/_utils.py +23 -0
  60. google/adk/memory/base_memory_service.py +25 -21
  61. google/adk/memory/base_memory_service.py.orig +76 -0
  62. google/adk/memory/in_memory_memory_service.py +59 -27
  63. google/adk/memory/memory_entry.py +37 -0
  64. google/adk/memory/vertex_ai_rag_memory_service.py +40 -17
  65. google/adk/models/anthropic_llm.py +36 -11
  66. google/adk/models/base_llm.py +45 -4
  67. google/adk/models/gemini_llm_connection.py +15 -2
  68. google/adk/models/google_llm.py +9 -44
  69. google/adk/models/google_llm.py.orig +305 -0
  70. google/adk/models/lite_llm.py +94 -38
  71. google/adk/models/llm_request.py +1 -1
  72. google/adk/models/llm_response.py +15 -3
  73. google/adk/models/registry.py +1 -1
  74. google/adk/runners.py +68 -44
  75. google/adk/sessions/__init__.py +1 -1
  76. google/adk/sessions/_session_util.py +14 -0
  77. google/adk/sessions/base_session_service.py +8 -32
  78. google/adk/sessions/database_session_service.py +58 -61
  79. google/adk/sessions/in_memory_session_service.py +108 -26
  80. google/adk/sessions/session.py +4 -0
  81. google/adk/sessions/vertex_ai_session_service.py +23 -45
  82. google/adk/telemetry.py +3 -0
  83. google/adk/tools/__init__.py +4 -7
  84. google/adk/tools/{built_in_code_execution_tool.py → _built_in_code_execution_tool.py} +11 -0
  85. google/adk/tools/_memory_entry_utils.py +30 -0
  86. google/adk/tools/agent_tool.py +16 -13
  87. google/adk/tools/apihub_tool/apihub_toolset.py +55 -74
  88. google/adk/tools/application_integration_tool/application_integration_toolset.py +107 -85
  89. google/adk/tools/application_integration_tool/clients/connections_client.py +29 -25
  90. google/adk/tools/application_integration_tool/clients/integration_client.py +6 -6
  91. google/adk/tools/application_integration_tool/integration_connector_tool.py +69 -26
  92. google/adk/tools/base_toolset.py +58 -0
  93. google/adk/tools/enterprise_search_tool.py +65 -0
  94. google/adk/tools/function_parameter_parse_util.py +2 -2
  95. google/adk/tools/google_api_tool/__init__.py +18 -70
  96. google/adk/tools/google_api_tool/google_api_tool.py +11 -5
  97. google/adk/tools/google_api_tool/google_api_toolset.py +126 -0
  98. google/adk/tools/google_api_tool/google_api_toolsets.py +102 -0
  99. google/adk/tools/google_api_tool/googleapi_to_openapi_converter.py +40 -42
  100. google/adk/tools/langchain_tool.py +96 -49
  101. google/adk/tools/load_artifacts_tool.py +4 -4
  102. google/adk/tools/load_memory_tool.py +16 -5
  103. google/adk/tools/mcp_tool/__init__.py +3 -2
  104. google/adk/tools/mcp_tool/conversion_utils.py +1 -1
  105. google/adk/tools/mcp_tool/mcp_session_manager.py +167 -16
  106. google/adk/tools/mcp_tool/mcp_session_manager.py.orig +322 -0
  107. google/adk/tools/mcp_tool/mcp_tool.py +12 -12
  108. google/adk/tools/mcp_tool/mcp_toolset.py +155 -195
  109. google/adk/tools/openapi_tool/common/common.py +2 -5
  110. google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py +32 -7
  111. google/adk/tools/openapi_tool/openapi_spec_parser/operation_parser.py +43 -33
  112. google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py +1 -1
  113. google/adk/tools/preload_memory_tool.py +27 -18
  114. google/adk/tools/retrieval/__init__.py +1 -1
  115. google/adk/tools/retrieval/vertex_ai_rag_retrieval.py +1 -1
  116. google/adk/tools/tool_context.py +4 -4
  117. google/adk/tools/toolbox_toolset.py +79 -0
  118. google/adk/tools/transfer_to_agent_tool.py +0 -1
  119. google/adk/version.py +1 -1
  120. {google_adk-0.4.0.dist-info → google_adk-1.0.0.dist-info}/METADATA +7 -5
  121. google_adk-1.0.0.dist-info/RECORD +195 -0
  122. google/adk/agents/remote_agent.py +0 -50
  123. google/adk/tools/google_api_tool/google_api_tool_set.py +0 -110
  124. google/adk/tools/google_api_tool/google_api_tool_sets.py +0 -112
  125. google/adk/tools/toolbox_tool.py +0 -46
  126. google_adk-0.4.0.dist-info/RECORD +0 -179
  127. {google_adk-0.4.0.dist-info → google_adk-1.0.0.dist-info}/WHEEL +0 -0
  128. {google_adk-0.4.0.dist-info → google_adk-1.0.0.dist-info}/entry_points.txt +0 -0
  129. {google_adk-0.4.0.dist-info → google_adk-1.0.0.dist-info}/licenses/LICENSE +0 -0
@@ -12,23 +12,127 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import Any
15
+ from typing import Any, Optional
16
16
 
17
+ from deprecated import deprecated
18
+ from google.genai import types as genai_types
17
19
  import pandas as pd
18
20
  from tabulate import tabulate
21
+ from typing_extensions import override
19
22
  from vertexai.preview.evaluation import EvalTask
20
23
  from vertexai.preview.evaluation import MetricPromptTemplateExamples
21
24
 
25
+ from .eval_case import IntermediateData
26
+ from .eval_case import Invocation
27
+ from .evaluator import EvalStatus
28
+ from .evaluator import EvaluationResult
29
+ from .evaluator import Evaluator
30
+ from .evaluator import PerInvocationResult
22
31
 
23
- class ResponseEvaluator:
32
+
33
+ class ResponseEvaluator(Evaluator):
24
34
  """Runs response evaluation for agents."""
25
35
 
36
+ def __init__(self, threshold: float, metric_name: str):
37
+ if "response_evaluation_score" == metric_name:
38
+ self._metric_name = MetricPromptTemplateExamples.Pointwise.COHERENCE
39
+ elif "response_match_score" == metric_name:
40
+ self._metric_name = "rouge_1"
41
+ else:
42
+ raise ValueError(f"`{metric_name}` is not supported.")
43
+
44
+ self._threshold = threshold
45
+
46
+ @override
47
+ def evaluate_invocations(
48
+ self,
49
+ actual_invocations: list[Invocation],
50
+ expected_invocations: list[Invocation],
51
+ ) -> EvaluationResult:
52
+ total_score = 0.0
53
+ num_invocations = 0
54
+ per_invocation_results = []
55
+ for actual, expected in zip(actual_invocations, expected_invocations):
56
+ prompt = self._get_text(expected.user_content)
57
+ reference = self._get_text(expected.final_response)
58
+ response = self._get_text(actual.final_response)
59
+ actual_tool_use = self._get_tool_use_trajectory(actual.intermediate_data)
60
+ reference_trajectory = self._get_tool_use_trajectory(
61
+ expected.intermediate_data
62
+ )
63
+
64
+ eval_case = {
65
+ "prompt": prompt,
66
+ "reference": reference,
67
+ "response": response,
68
+ "actual_tool_user": actual_tool_use,
69
+ "reference_trajectory": reference_trajectory,
70
+ }
71
+
72
+ eval_case_result = ResponseEvaluator._perform_eval(
73
+ pd.DataFrame([eval_case]), [self._metric_name]
74
+ )
75
+ score = self._get_score(eval_case_result)
76
+ per_invocation_results.append(
77
+ PerInvocationResult(
78
+ actual_invocation=actual,
79
+ expected_invocation=expected,
80
+ score=score,
81
+ eval_status=self._get_eval_status(score),
82
+ )
83
+ )
84
+ total_score += score
85
+ num_invocations += 1
86
+
87
+ if per_invocation_results:
88
+ overall_score = total_score / num_invocations
89
+ return EvaluationResult(
90
+ overall_score=overall_score,
91
+ overall_eval_status=self._get_eval_status(overall_score),
92
+ per_invocation_results=per_invocation_results,
93
+ )
94
+
95
+ return EvaluationResult()
96
+
97
+ def _get_text(self, content: Optional[genai_types.Content]) -> str:
98
+ if content and content.parts:
99
+ return "\n".join([p.text for p in content.parts if p.text])
100
+
101
+ return ""
102
+
103
+ def _get_tool_use_trajectory(
104
+ self, intermediate_data: Optional[IntermediateData]
105
+ ) -> list[dict[str, Any]]:
106
+ tool_use_trajectory = []
107
+ if not intermediate_data:
108
+ return tool_use_trajectory
109
+
110
+ for function_call in intermediate_data.tool_uses:
111
+ tool_use_trajectory.append({
112
+ "tool_name": function_call.name,
113
+ "tool_input": function_call.args or {},
114
+ })
115
+
116
+ return tool_use_trajectory
117
+
118
+ def _get_score(self, eval_result) -> float:
119
+ return eval_result.summary_metrics[f"{self._metric_name}/mean"].item()
120
+
121
+ def _get_eval_status(self, score: float):
122
+ return EvalStatus.PASSED if score >= self._threshold else EvalStatus.FAILED
123
+
26
124
  @staticmethod
125
+ @deprecated(
126
+ reason=(
127
+ "This method has been deprecated and will be removed soon. Please use"
128
+ " evaluate_invocations instead."
129
+ )
130
+ )
27
131
  def evaluate(
28
132
  raw_eval_dataset: list[list[dict[str, Any]]],
29
133
  evaluation_criteria: list[str],
30
134
  *,
31
- print_detailed_results: bool = False
135
+ print_detailed_results: bool = False,
32
136
  ):
33
137
  r"""Returns the value of requested evaluation metrics.
34
138
 
@@ -12,18 +12,98 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import Any
15
+ from typing import Any, cast
16
16
 
17
+ from deprecated import deprecated
18
+ from google.genai import types as genai_types
17
19
  import pandas as pd
18
20
  from tabulate import tabulate
21
+ from typing_extensions import override
19
22
 
23
+ from .eval_case import Invocation
20
24
  from .evaluation_constants import EvalConstants
25
+ from .evaluator import EvalStatus
26
+ from .evaluator import EvaluationResult
27
+ from .evaluator import Evaluator
28
+ from .evaluator import PerInvocationResult
21
29
 
22
30
 
23
- class TrajectoryEvaluator:
31
+ class TrajectoryEvaluator(Evaluator):
24
32
  """Evaluates tool use trajectories for accuracy."""
25
33
 
34
+ def __init__(self, threshold: float):
35
+ self._threshold = threshold
36
+
37
+ @override
38
+ def evaluate_invocations(
39
+ self,
40
+ actual_invocations: list[Invocation],
41
+ expected_invocations: list[Invocation],
42
+ ) -> EvaluationResult:
43
+ """Returns EvaluationResult after performing evaluations using actual and expected invocations."""
44
+ total_tool_use_accuracy = 0.0
45
+ num_invocations = 0
46
+ per_invocation_results = []
47
+
48
+ for actual, expected in zip(actual_invocations, expected_invocations):
49
+ actual_tool_uses = (
50
+ actual.intermediate_data.tool_uses if actual.intermediate_data else []
51
+ )
52
+ expected_tool_uses = (
53
+ expected.intermediate_data.tool_uses
54
+ if expected.intermediate_data
55
+ else []
56
+ )
57
+ tool_use_accuracy = (
58
+ 1.0
59
+ if self._are_tool_calls_equal(actual_tool_uses, expected_tool_uses)
60
+ else 0.0
61
+ )
62
+ per_invocation_results.append(
63
+ PerInvocationResult(
64
+ actual_invocation=actual,
65
+ expected_invocation=expected,
66
+ score=tool_use_accuracy,
67
+ eval_status=self._get_eval_status(tool_use_accuracy),
68
+ )
69
+ )
70
+ total_tool_use_accuracy += tool_use_accuracy
71
+ num_invocations += 1
72
+
73
+ if per_invocation_results:
74
+ overall_score = total_tool_use_accuracy / num_invocations
75
+ return EvaluationResult(
76
+ overall_score=overall_score,
77
+ overall_eval_status=self._get_eval_status(overall_score),
78
+ per_invocation_results=per_invocation_results,
79
+ )
80
+
81
+ return EvaluationResult()
82
+
83
+ def _are_tool_calls_equal(
84
+ self,
85
+ actual_tool_calls: list[genai_types.FunctionCall],
86
+ expected_tool_calls: list[genai_types.FunctionCall],
87
+ ) -> bool:
88
+ if len(actual_tool_calls) != len(expected_tool_calls):
89
+ return False
90
+
91
+ for actual, expected in zip(actual_tool_calls, expected_tool_calls):
92
+ if actual.name != expected.name or actual.args != expected.args:
93
+ return False
94
+
95
+ return True
96
+
97
+ def _get_eval_status(self, score: float):
98
+ return EvalStatus.PASSED if score >= self._threshold else EvalStatus.FAILED
99
+
26
100
  @staticmethod
101
+ @deprecated(
102
+ reason=(
103
+ "This method has been deprecated and will be removed soon. Please use"
104
+ " evaluate_invocations instead."
105
+ )
106
+ )
27
107
  def evaluate(
28
108
  eval_dataset: list[list[dict[str, Any]]],
29
109
  *,
@@ -137,6 +217,7 @@ class TrajectoryEvaluator:
137
217
  return new_row, failure
138
218
 
139
219
  @staticmethod
220
+ @deprecated()
140
221
  def are_tools_equal(list_a_original, list_b_original):
141
222
  # Remove other entries that we don't want to evaluate
142
223
  list_a = [
@@ -19,6 +19,7 @@ import string
19
19
  from typing import Optional
20
20
 
21
21
  from google.genai import types
22
+ from pydantic import alias_generators
22
23
  from pydantic import ConfigDict
23
24
  from pydantic import Field
24
25
 
@@ -46,8 +47,13 @@ class Event(LlmResponse):
46
47
  """
47
48
 
48
49
  model_config = ConfigDict(
49
- extra='forbid', ser_json_bytes='base64', val_json_bytes='base64'
50
+ extra='forbid',
51
+ ser_json_bytes='base64',
52
+ val_json_bytes='base64',
53
+ alias_generator=alias_generators.to_camel,
54
+ populate_by_name=True,
50
55
  )
56
+ """The pydantic model config."""
51
57
 
52
58
  # TODO: revert to be required after spark migration
53
59
  invocation_id: str = ''
@@ -16,6 +16,7 @@ from __future__ import annotations
16
16
 
17
17
  from typing import Optional
18
18
 
19
+ from pydantic import alias_generators
19
20
  from pydantic import BaseModel
20
21
  from pydantic import ConfigDict
21
22
  from pydantic import Field
@@ -26,7 +27,12 @@ from ..auth.auth_tool import AuthConfig
26
27
  class EventActions(BaseModel):
27
28
  """Represents the actions attached to an event."""
28
29
 
29
- model_config = ConfigDict(extra='forbid')
30
+ model_config = ConfigDict(
31
+ extra='forbid',
32
+ alias_generator=alias_generators.to_camel,
33
+ populate_by_name=True,
34
+ )
35
+ """The pydantic model config."""
30
36
 
31
37
  skip_summarization: Optional[bool] = None
32
38
  """If true, it won't call model to summarize function response.
@@ -23,5 +23,6 @@ class Example(BaseModel):
23
23
  input: The input content for the example.
24
24
  output: The expected output content for the example.
25
25
  """
26
+
26
27
  input: types.Content
27
28
  output: list[types.Content]
@@ -15,8 +15,9 @@
15
15
  """Utility functions for converting examples to a string that can be used in system instructions in the prompt."""
16
16
 
17
17
  import logging
18
- from typing import Optional, Union
18
+ from typing import Optional
19
19
  from typing import TYPE_CHECKING
20
+ from typing import Union
20
21
 
21
22
  from .base_example_provider import BaseExampleProvider
22
23
  from .example import Example
@@ -24,7 +25,7 @@ from .example import Example
24
25
  if TYPE_CHECKING:
25
26
  from ..sessions.session import Session
26
27
 
27
- logger = logging.getLogger(__name__)
28
+ logger = logging.getLogger("google_adk." + __name__)
28
29
 
29
30
  # Constant parts of the example string
30
31
  _EXAMPLES_INTRO = (
@@ -11,4 +11,3 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
-
@@ -22,7 +22,6 @@ import dataclasses
22
22
  import os
23
23
  import re
24
24
  from typing import AsyncGenerator
25
- from typing import Generator
26
25
  from typing import Optional
27
26
  from typing import TYPE_CHECKING
28
27
 
@@ -31,6 +30,7 @@ from typing_extensions import override
31
30
 
32
31
  from ...agents.invocation_context import InvocationContext
33
32
  from ...code_executors.base_code_executor import BaseCodeExecutor
33
+ from ...code_executors.built_in_code_executor import BuiltInCodeExecutor
34
34
  from ...code_executors.code_execution_utils import CodeExecutionInput
35
35
  from ...code_executors.code_execution_utils import CodeExecutionResult
36
36
  from ...code_executors.code_execution_utils import CodeExecutionUtils
@@ -122,7 +122,7 @@ class _CodeExecutionRequestProcessor(BaseLlmRequestProcessor):
122
122
  if not invocation_context.agent.code_executor:
123
123
  return
124
124
 
125
- for event in _run_pre_processor(invocation_context, llm_request):
125
+ async for event in _run_pre_processor(invocation_context, llm_request):
126
126
  yield event
127
127
 
128
128
  # Convert the code execution parts to text parts.
@@ -152,17 +152,17 @@ class _CodeExecutionResponseProcessor(BaseLlmResponseProcessor):
152
152
  if llm_response.partial:
153
153
  return
154
154
 
155
- for event in _run_post_processor(invocation_context, llm_response):
155
+ async for event in _run_post_processor(invocation_context, llm_response):
156
156
  yield event
157
157
 
158
158
 
159
159
  response_processor = _CodeExecutionResponseProcessor()
160
160
 
161
161
 
162
- def _run_pre_processor(
162
+ async def _run_pre_processor(
163
163
  invocation_context: InvocationContext,
164
164
  llm_request: LlmRequest,
165
- ) -> Generator[Event, None, None]:
165
+ ) -> AsyncGenerator[Event, None]:
166
166
  """Pre-process the user message by adding the user message to the Colab notebook."""
167
167
  from ...agents.llm_agent import LlmAgent
168
168
 
@@ -174,6 +174,11 @@ def _run_pre_processor(
174
174
 
175
175
  if not code_executor or not isinstance(code_executor, BaseCodeExecutor):
176
176
  return
177
+
178
+ if isinstance(code_executor, BuiltInCodeExecutor):
179
+ code_executor.process_llm_request(llm_request)
180
+ return
181
+
177
182
  if not code_executor.optimize_data_file:
178
183
  return
179
184
 
@@ -242,17 +247,17 @@ def _run_pre_processor(
242
247
  code_executor_context.add_processed_file_names([file.name])
243
248
 
244
249
  # Emit the execution result, and add it to the LLM request.
245
- execution_result_event = _post_process_code_execution_result(
250
+ execution_result_event = await _post_process_code_execution_result(
246
251
  invocation_context, code_executor_context, code_execution_result
247
252
  )
248
253
  yield execution_result_event
249
254
  llm_request.contents.append(copy.deepcopy(execution_result_event.content))
250
255
 
251
256
 
252
- def _run_post_processor(
257
+ async def _run_post_processor(
253
258
  invocation_context: InvocationContext,
254
259
  llm_response,
255
- ) -> Generator[Event, None, None]:
260
+ ) -> AsyncGenerator[Event, None]:
256
261
  """Post-process the model response by extracting and executing the first code block."""
257
262
  agent = invocation_context.agent
258
263
  code_executor = agent.code_executor
@@ -262,6 +267,9 @@ def _run_post_processor(
262
267
  if not llm_response or not llm_response.content:
263
268
  return
264
269
 
270
+ if isinstance(code_executor, BuiltInCodeExecutor):
271
+ return
272
+
265
273
  code_executor_context = CodeExecutorContext(invocation_context.session.state)
266
274
  # Skip if the error count exceeds the max retry attempts.
267
275
  if (
@@ -305,7 +313,7 @@ def _run_post_processor(
305
313
  code_execution_result.stdout,
306
314
  code_execution_result.stderr,
307
315
  )
308
- yield _post_process_code_execution_result(
316
+ yield await _post_process_code_execution_result(
309
317
  invocation_context, code_executor_context, code_execution_result
310
318
  )
311
319
 
@@ -375,7 +383,7 @@ def _get_or_set_execution_id(
375
383
  return execution_id
376
384
 
377
385
 
378
- def _post_process_code_execution_result(
386
+ async def _post_process_code_execution_result(
379
387
  invocation_context: InvocationContext,
380
388
  code_executor_context: CodeExecutorContext,
381
389
  code_execution_result: CodeExecutionResult,
@@ -406,7 +414,7 @@ def _post_process_code_execution_result(
406
414
 
407
415
  # Handle output files.
408
416
  for output_file in code_execution_result.output_files:
409
- version = invocation_context.artifact_service.save_artifact(
417
+ version = await invocation_context.artifact_service.save_artifact(
410
418
  app_name=invocation_context.app_name,
411
419
  user_id=invocation_context.user_id,
412
420
  session_id=invocation_context.session.id,
@@ -25,8 +25,9 @@ if TYPE_CHECKING:
25
25
  class AudioTranscriber:
26
26
  """Transcribes audio using Google Cloud Speech-to-Text."""
27
27
 
28
- def __init__(self):
29
- self.client = speech.SpeechClient()
28
+ def __init__(self, init_client=False):
29
+ if init_client:
30
+ self.client = speech.SpeechClient()
30
31
 
31
32
  def transcribe_file(
32
33
  self, invocation_context: InvocationContext
@@ -84,7 +85,7 @@ class AudioTranscriber:
84
85
 
85
86
  # Step2: transcription
86
87
  for speaker, data in bundled_audio:
87
- if speaker == 'user':
88
+ if isinstance(data, genai_types.Blob):
88
89
  audio = speech.RecognitionAudio(content=data)
89
90
 
90
91
  config = speech.RecognitionConfig(
@@ -16,6 +16,7 @@ from __future__ import annotations
16
16
 
17
17
  from abc import ABC
18
18
  import asyncio
19
+ import inspect
19
20
  import logging
20
21
  from typing import AsyncGenerator
21
22
  from typing import cast
@@ -28,6 +29,7 @@ from ...agents.base_agent import BaseAgent
28
29
  from ...agents.callback_context import CallbackContext
29
30
  from ...agents.invocation_context import InvocationContext
30
31
  from ...agents.live_request_queue import LiveRequestQueue
32
+ from ...agents.readonly_context import ReadonlyContext
31
33
  from ...agents.run_config import StreamingMode
32
34
  from ...agents.transcription_entry import TranscriptionEntry
33
35
  from ...events.event import Event
@@ -46,7 +48,7 @@ if TYPE_CHECKING:
46
48
  from ._base_llm_processor import BaseLlmRequestProcessor
47
49
  from ._base_llm_processor import BaseLlmResponseProcessor
48
50
 
49
- logger = logging.getLogger(__name__)
51
+ logger = logging.getLogger('google_adk.' + __name__)
50
52
 
51
53
 
52
54
  class BaseLlmFlow(ABC):
@@ -87,7 +89,12 @@ class BaseLlmFlow(ABC):
87
89
  if invocation_context.transcription_cache:
88
90
  from . import audio_transcriber
89
91
 
90
- audio_transcriber = audio_transcriber.AudioTranscriber()
92
+ audio_transcriber = audio_transcriber.AudioTranscriber(
93
+ init_client=True
94
+ if invocation_context.run_config.input_audio_transcription
95
+ is None
96
+ else False
97
+ )
91
98
  contents = audio_transcriber.transcribe_file(invocation_context)
92
99
  logger.debug('Sending history to model: %s', contents)
93
100
  await llm_connection.send_history(contents)
@@ -128,6 +135,18 @@ class BaseLlmFlow(ABC):
128
135
  # cancel the tasks that belongs to the closed connection.
129
136
  send_task.cancel()
130
137
  await llm_connection.close()
138
+ if (
139
+ event.content
140
+ and event.content.parts
141
+ and event.content.parts[0].function_response
142
+ and event.content.parts[0].function_response.name
143
+ == 'task_completed'
144
+ ):
145
+ # this is used for sequential agent to signal the end of the agent.
146
+ await asyncio.sleep(1)
147
+ # cancel the tasks that belongs to the closed connection.
148
+ send_task.cancel()
149
+ return
131
150
  finally:
132
151
  # Clean up
133
152
  if not send_task.done():
@@ -175,9 +194,12 @@ class BaseLlmFlow(ABC):
175
194
  # Cache audio data here for transcription
176
195
  if not invocation_context.transcription_cache:
177
196
  invocation_context.transcription_cache = []
178
- invocation_context.transcription_cache.append(
179
- TranscriptionEntry(role='user', data=live_request.blob)
180
- )
197
+ if not invocation_context.run_config.input_audio_transcription:
198
+ # if the live model's input transcription is not enabled, then
199
+ # we use our onwn audio transcriber to achieve that.
200
+ invocation_context.transcription_cache.append(
201
+ TranscriptionEntry(role='user', data=live_request.blob)
202
+ )
181
203
  await llm_connection.send_realtime(live_request.blob)
182
204
  if live_request.content:
183
205
  await llm_connection.send_content(live_request.content)
@@ -190,6 +212,25 @@ class BaseLlmFlow(ABC):
190
212
  llm_request: LlmRequest,
191
213
  ) -> AsyncGenerator[Event, None]:
192
214
  """Receive data from model and process events using BaseLlmConnection."""
215
+
216
+ def get_author_for_event(llm_response):
217
+ """Get the author of the event.
218
+
219
+ When the model returns transcription, the author is "user". Otherwise, the
220
+ author is the agent name(not 'model').
221
+
222
+ Args:
223
+ llm_response: The LLM response from the LLM call.
224
+ """
225
+ if (
226
+ llm_response
227
+ and llm_response.content
228
+ and llm_response.content.role == 'user'
229
+ ):
230
+ return 'user'
231
+ else:
232
+ return invocation_context.agent.name
233
+
193
234
  assert invocation_context.live_request_queue
194
235
  try:
195
236
  while True:
@@ -197,7 +238,7 @@ class BaseLlmFlow(ABC):
197
238
  model_response_event = Event(
198
239
  id=Event.new_id(),
199
240
  invocation_id=invocation_context.invocation_id,
200
- author=invocation_context.agent.name,
241
+ author=get_author_for_event(llm_response),
201
242
  )
202
243
  async for event in self._postprocess_live(
203
244
  invocation_context,
@@ -208,13 +249,20 @@ class BaseLlmFlow(ABC):
208
249
  if (
209
250
  event.content
210
251
  and event.content.parts
211
- and event.content.parts[0].text
252
+ and event.content.parts[0].inline_data is None
212
253
  and not event.partial
213
254
  ):
255
+ # This can be either user data or transcription data.
256
+ # when output transcription enabled, it will contain model's
257
+ # transcription.
258
+ # when input transcription enabled, it will contain user
259
+ # transcription.
214
260
  if not invocation_context.transcription_cache:
215
261
  invocation_context.transcription_cache = []
216
262
  invocation_context.transcription_cache.append(
217
- TranscriptionEntry(role='model', data=event.content)
263
+ TranscriptionEntry(
264
+ role=event.content.role, data=event.content
265
+ )
218
266
  )
219
267
  yield event
220
268
  # Give opportunity for other tasks to run.
@@ -261,6 +309,8 @@ class BaseLlmFlow(ABC):
261
309
  async for event in self._postprocess_async(
262
310
  invocation_context, llm_request, llm_response, model_response_event
263
311
  ):
312
+ # Update the mutable event id to avoid conflict
313
+ model_response_event.id = Event.new_id()
264
314
  yield event
265
315
 
266
316
  async def _preprocess_async(
@@ -278,7 +328,9 @@ class BaseLlmFlow(ABC):
278
328
  yield event
279
329
 
280
330
  # Run processors for tools.
281
- for tool in agent.canonical_tools:
331
+ for tool in await agent.canonical_tools(
332
+ ReadonlyContext(invocation_context)
333
+ ):
282
334
  tool_context = ToolContext(invocation_context)
283
335
  await tool.process_llm_request(
284
336
  tool_context=tool_context, llm_request=llm_request
@@ -437,7 +489,7 @@ class BaseLlmFlow(ABC):
437
489
  model_response_event: Event,
438
490
  ) -> AsyncGenerator[LlmResponse, None]:
439
491
  # Runs before_model_callback if it exists.
440
- if response := self._handle_before_model_callback(
492
+ if response := await self._handle_before_model_callback(
441
493
  invocation_context, llm_request, model_response_event
442
494
  ):
443
495
  yield response
@@ -450,7 +502,7 @@ class BaseLlmFlow(ABC):
450
502
  invocation_context.live_request_queue = LiveRequestQueue()
451
503
  async for llm_response in self.run_live(invocation_context):
452
504
  # Runs after_model_callback if it exists.
453
- if altered_llm_response := self._handle_after_model_callback(
505
+ if altered_llm_response := await self._handle_after_model_callback(
454
506
  invocation_context, llm_response, model_response_event
455
507
  ):
456
508
  llm_response = altered_llm_response
@@ -479,14 +531,14 @@ class BaseLlmFlow(ABC):
479
531
  llm_response,
480
532
  )
481
533
  # Runs after_model_callback if it exists.
482
- if altered_llm_response := self._handle_after_model_callback(
534
+ if altered_llm_response := await self._handle_after_model_callback(
483
535
  invocation_context, llm_response, model_response_event
484
536
  ):
485
537
  llm_response = altered_llm_response
486
538
 
487
539
  yield llm_response
488
540
 
489
- def _handle_before_model_callback(
541
+ async def _handle_before_model_callback(
490
542
  self,
491
543
  invocation_context: InvocationContext,
492
544
  llm_request: LlmRequest,
@@ -498,17 +550,23 @@ class BaseLlmFlow(ABC):
498
550
  if not isinstance(agent, LlmAgent):
499
551
  return
500
552
 
501
- if not agent.before_model_callback:
553
+ if not agent.canonical_before_model_callbacks:
502
554
  return
503
555
 
504
556
  callback_context = CallbackContext(
505
557
  invocation_context, event_actions=model_response_event.actions
506
558
  )
507
- return agent.before_model_callback(
508
- callback_context=callback_context, llm_request=llm_request
509
- )
510
559
 
511
- def _handle_after_model_callback(
560
+ for callback in agent.canonical_before_model_callbacks:
561
+ before_model_callback_content = callback(
562
+ callback_context=callback_context, llm_request=llm_request
563
+ )
564
+ if inspect.isawaitable(before_model_callback_content):
565
+ before_model_callback_content = await before_model_callback_content
566
+ if before_model_callback_content:
567
+ return before_model_callback_content
568
+
569
+ async def _handle_after_model_callback(
512
570
  self,
513
571
  invocation_context: InvocationContext,
514
572
  llm_response: LlmResponse,
@@ -520,15 +578,21 @@ class BaseLlmFlow(ABC):
520
578
  if not isinstance(agent, LlmAgent):
521
579
  return
522
580
 
523
- if not agent.after_model_callback:
581
+ if not agent.canonical_after_model_callbacks:
524
582
  return
525
583
 
526
584
  callback_context = CallbackContext(
527
585
  invocation_context, event_actions=model_response_event.actions
528
586
  )
529
- return agent.after_model_callback(
530
- callback_context=callback_context, llm_response=llm_response
531
- )
587
+
588
+ for callback in agent.canonical_after_model_callbacks:
589
+ after_model_callback_content = callback(
590
+ callback_context=callback_context, llm_response=llm_response
591
+ )
592
+ if inspect.isawaitable(after_model_callback_content):
593
+ after_model_callback_content = await after_model_callback_content
594
+ if after_model_callback_content:
595
+ return after_model_callback_content
532
596
 
533
597
  def _finalize_model_response_event(
534
598
  self,