google-adk 0.4.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/adk/agents/active_streaming_tool.py +1 -0
- google/adk/agents/base_agent.py +91 -47
- google/adk/agents/base_agent.py.orig +330 -0
- google/adk/agents/callback_context.py +4 -9
- google/adk/agents/invocation_context.py +1 -0
- google/adk/agents/langgraph_agent.py +1 -0
- google/adk/agents/live_request_queue.py +1 -0
- google/adk/agents/llm_agent.py +172 -35
- google/adk/agents/loop_agent.py +1 -1
- google/adk/agents/parallel_agent.py +7 -0
- google/adk/agents/readonly_context.py +7 -1
- google/adk/agents/run_config.py +5 -1
- google/adk/agents/sequential_agent.py +31 -0
- google/adk/agents/transcription_entry.py +5 -2
- google/adk/artifacts/base_artifact_service.py +5 -10
- google/adk/artifacts/gcs_artifact_service.py +9 -9
- google/adk/artifacts/in_memory_artifact_service.py +6 -6
- google/adk/auth/auth_credential.py +9 -5
- google/adk/auth/auth_preprocessor.py +7 -1
- google/adk/auth/auth_tool.py +3 -4
- google/adk/cli/agent_graph.py +5 -5
- google/adk/cli/browser/index.html +2 -2
- google/adk/cli/browser/{main-HWIBUY2R.js → main-QOEMUXM4.js} +58 -58
- google/adk/cli/cli.py +7 -7
- google/adk/cli/cli_deploy.py +7 -2
- google/adk/cli/cli_eval.py +181 -106
- google/adk/cli/cli_tools_click.py +147 -62
- google/adk/cli/fast_api.py +340 -158
- google/adk/cli/fast_api.py.orig +822 -0
- google/adk/cli/utils/common.py +23 -0
- google/adk/cli/utils/evals.py +83 -1
- google/adk/cli/utils/logs.py +13 -5
- google/adk/code_executors/__init__.py +3 -1
- google/adk/code_executors/built_in_code_executor.py +52 -0
- google/adk/evaluation/__init__.py +1 -1
- google/adk/evaluation/agent_evaluator.py +168 -128
- google/adk/evaluation/eval_case.py +102 -0
- google/adk/evaluation/eval_set.py +37 -0
- google/adk/evaluation/eval_sets_manager.py +42 -0
- google/adk/evaluation/evaluation_constants.py +1 -0
- google/adk/evaluation/evaluation_generator.py +89 -114
- google/adk/evaluation/evaluator.py +56 -0
- google/adk/evaluation/local_eval_sets_manager.py +264 -0
- google/adk/evaluation/response_evaluator.py +107 -3
- google/adk/evaluation/trajectory_evaluator.py +83 -2
- google/adk/events/event.py +7 -1
- google/adk/events/event_actions.py +7 -1
- google/adk/examples/example.py +1 -0
- google/adk/examples/example_util.py +3 -2
- google/adk/flows/__init__.py +0 -1
- google/adk/flows/llm_flows/_code_execution.py +19 -11
- google/adk/flows/llm_flows/audio_transcriber.py +4 -3
- google/adk/flows/llm_flows/base_llm_flow.py +86 -22
- google/adk/flows/llm_flows/basic.py +3 -0
- google/adk/flows/llm_flows/functions.py +10 -9
- google/adk/flows/llm_flows/instructions.py +28 -9
- google/adk/flows/llm_flows/single_flow.py +1 -1
- google/adk/memory/__init__.py +1 -1
- google/adk/memory/_utils.py +23 -0
- google/adk/memory/base_memory_service.py +25 -21
- google/adk/memory/base_memory_service.py.orig +76 -0
- google/adk/memory/in_memory_memory_service.py +59 -27
- google/adk/memory/memory_entry.py +37 -0
- google/adk/memory/vertex_ai_rag_memory_service.py +40 -17
- google/adk/models/anthropic_llm.py +36 -11
- google/adk/models/base_llm.py +45 -4
- google/adk/models/gemini_llm_connection.py +15 -2
- google/adk/models/google_llm.py +9 -44
- google/adk/models/google_llm.py.orig +305 -0
- google/adk/models/lite_llm.py +94 -38
- google/adk/models/llm_request.py +1 -1
- google/adk/models/llm_response.py +15 -3
- google/adk/models/registry.py +1 -1
- google/adk/runners.py +68 -44
- google/adk/sessions/__init__.py +1 -1
- google/adk/sessions/_session_util.py +14 -0
- google/adk/sessions/base_session_service.py +8 -32
- google/adk/sessions/database_session_service.py +58 -61
- google/adk/sessions/in_memory_session_service.py +108 -26
- google/adk/sessions/session.py +4 -0
- google/adk/sessions/vertex_ai_session_service.py +23 -45
- google/adk/telemetry.py +3 -0
- google/adk/tools/__init__.py +4 -7
- google/adk/tools/{built_in_code_execution_tool.py → _built_in_code_execution_tool.py} +11 -0
- google/adk/tools/_memory_entry_utils.py +30 -0
- google/adk/tools/agent_tool.py +16 -13
- google/adk/tools/apihub_tool/apihub_toolset.py +55 -74
- google/adk/tools/application_integration_tool/application_integration_toolset.py +107 -85
- google/adk/tools/application_integration_tool/clients/connections_client.py +29 -25
- google/adk/tools/application_integration_tool/clients/integration_client.py +6 -6
- google/adk/tools/application_integration_tool/integration_connector_tool.py +69 -26
- google/adk/tools/base_toolset.py +58 -0
- google/adk/tools/enterprise_search_tool.py +65 -0
- google/adk/tools/function_parameter_parse_util.py +2 -2
- google/adk/tools/google_api_tool/__init__.py +18 -70
- google/adk/tools/google_api_tool/google_api_tool.py +11 -5
- google/adk/tools/google_api_tool/google_api_toolset.py +126 -0
- google/adk/tools/google_api_tool/google_api_toolsets.py +102 -0
- google/adk/tools/google_api_tool/googleapi_to_openapi_converter.py +40 -42
- google/adk/tools/langchain_tool.py +96 -49
- google/adk/tools/load_artifacts_tool.py +4 -4
- google/adk/tools/load_memory_tool.py +16 -5
- google/adk/tools/mcp_tool/__init__.py +3 -2
- google/adk/tools/mcp_tool/conversion_utils.py +1 -1
- google/adk/tools/mcp_tool/mcp_session_manager.py +167 -16
- google/adk/tools/mcp_tool/mcp_session_manager.py.orig +322 -0
- google/adk/tools/mcp_tool/mcp_tool.py +12 -12
- google/adk/tools/mcp_tool/mcp_toolset.py +155 -195
- google/adk/tools/openapi_tool/common/common.py +2 -5
- google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py +32 -7
- google/adk/tools/openapi_tool/openapi_spec_parser/operation_parser.py +43 -33
- google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py +1 -1
- google/adk/tools/preload_memory_tool.py +27 -18
- google/adk/tools/retrieval/__init__.py +1 -1
- google/adk/tools/retrieval/vertex_ai_rag_retrieval.py +1 -1
- google/adk/tools/tool_context.py +4 -4
- google/adk/tools/toolbox_toolset.py +79 -0
- google/adk/tools/transfer_to_agent_tool.py +0 -1
- google/adk/version.py +1 -1
- {google_adk-0.4.0.dist-info → google_adk-1.0.0.dist-info}/METADATA +7 -5
- google_adk-1.0.0.dist-info/RECORD +195 -0
- google/adk/agents/remote_agent.py +0 -50
- google/adk/tools/google_api_tool/google_api_tool_set.py +0 -110
- google/adk/tools/google_api_tool/google_api_tool_sets.py +0 -112
- google/adk/tools/toolbox_tool.py +0 -46
- google_adk-0.4.0.dist-info/RECORD +0 -179
- {google_adk-0.4.0.dist-info → google_adk-1.0.0.dist-info}/WHEEL +0 -0
- {google_adk-0.4.0.dist-info → google_adk-1.0.0.dist-info}/entry_points.txt +0 -0
- {google_adk-0.4.0.dist-info → google_adk-1.0.0.dist-info}/licenses/LICENSE +0 -0
@@ -12,23 +12,127 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
from typing import Any
|
15
|
+
from typing import Any, Optional
|
16
16
|
|
17
|
+
from deprecated import deprecated
|
18
|
+
from google.genai import types as genai_types
|
17
19
|
import pandas as pd
|
18
20
|
from tabulate import tabulate
|
21
|
+
from typing_extensions import override
|
19
22
|
from vertexai.preview.evaluation import EvalTask
|
20
23
|
from vertexai.preview.evaluation import MetricPromptTemplateExamples
|
21
24
|
|
25
|
+
from .eval_case import IntermediateData
|
26
|
+
from .eval_case import Invocation
|
27
|
+
from .evaluator import EvalStatus
|
28
|
+
from .evaluator import EvaluationResult
|
29
|
+
from .evaluator import Evaluator
|
30
|
+
from .evaluator import PerInvocationResult
|
22
31
|
|
23
|
-
|
32
|
+
|
33
|
+
class ResponseEvaluator(Evaluator):
|
24
34
|
"""Runs response evaluation for agents."""
|
25
35
|
|
36
|
+
def __init__(self, threshold: float, metric_name: str):
|
37
|
+
if "response_evaluation_score" == metric_name:
|
38
|
+
self._metric_name = MetricPromptTemplateExamples.Pointwise.COHERENCE
|
39
|
+
elif "response_match_score" == metric_name:
|
40
|
+
self._metric_name = "rouge_1"
|
41
|
+
else:
|
42
|
+
raise ValueError(f"`{metric_name}` is not supported.")
|
43
|
+
|
44
|
+
self._threshold = threshold
|
45
|
+
|
46
|
+
@override
|
47
|
+
def evaluate_invocations(
|
48
|
+
self,
|
49
|
+
actual_invocations: list[Invocation],
|
50
|
+
expected_invocations: list[Invocation],
|
51
|
+
) -> EvaluationResult:
|
52
|
+
total_score = 0.0
|
53
|
+
num_invocations = 0
|
54
|
+
per_invocation_results = []
|
55
|
+
for actual, expected in zip(actual_invocations, expected_invocations):
|
56
|
+
prompt = self._get_text(expected.user_content)
|
57
|
+
reference = self._get_text(expected.final_response)
|
58
|
+
response = self._get_text(actual.final_response)
|
59
|
+
actual_tool_use = self._get_tool_use_trajectory(actual.intermediate_data)
|
60
|
+
reference_trajectory = self._get_tool_use_trajectory(
|
61
|
+
expected.intermediate_data
|
62
|
+
)
|
63
|
+
|
64
|
+
eval_case = {
|
65
|
+
"prompt": prompt,
|
66
|
+
"reference": reference,
|
67
|
+
"response": response,
|
68
|
+
"actual_tool_user": actual_tool_use,
|
69
|
+
"reference_trajectory": reference_trajectory,
|
70
|
+
}
|
71
|
+
|
72
|
+
eval_case_result = ResponseEvaluator._perform_eval(
|
73
|
+
pd.DataFrame([eval_case]), [self._metric_name]
|
74
|
+
)
|
75
|
+
score = self._get_score(eval_case_result)
|
76
|
+
per_invocation_results.append(
|
77
|
+
PerInvocationResult(
|
78
|
+
actual_invocation=actual,
|
79
|
+
expected_invocation=expected,
|
80
|
+
score=score,
|
81
|
+
eval_status=self._get_eval_status(score),
|
82
|
+
)
|
83
|
+
)
|
84
|
+
total_score += score
|
85
|
+
num_invocations += 1
|
86
|
+
|
87
|
+
if per_invocation_results:
|
88
|
+
overall_score = total_score / num_invocations
|
89
|
+
return EvaluationResult(
|
90
|
+
overall_score=overall_score,
|
91
|
+
overall_eval_status=self._get_eval_status(overall_score),
|
92
|
+
per_invocation_results=per_invocation_results,
|
93
|
+
)
|
94
|
+
|
95
|
+
return EvaluationResult()
|
96
|
+
|
97
|
+
def _get_text(self, content: Optional[genai_types.Content]) -> str:
|
98
|
+
if content and content.parts:
|
99
|
+
return "\n".join([p.text for p in content.parts if p.text])
|
100
|
+
|
101
|
+
return ""
|
102
|
+
|
103
|
+
def _get_tool_use_trajectory(
|
104
|
+
self, intermediate_data: Optional[IntermediateData]
|
105
|
+
) -> list[dict[str, Any]]:
|
106
|
+
tool_use_trajectory = []
|
107
|
+
if not intermediate_data:
|
108
|
+
return tool_use_trajectory
|
109
|
+
|
110
|
+
for function_call in intermediate_data.tool_uses:
|
111
|
+
tool_use_trajectory.append({
|
112
|
+
"tool_name": function_call.name,
|
113
|
+
"tool_input": function_call.args or {},
|
114
|
+
})
|
115
|
+
|
116
|
+
return tool_use_trajectory
|
117
|
+
|
118
|
+
def _get_score(self, eval_result) -> float:
|
119
|
+
return eval_result.summary_metrics[f"{self._metric_name}/mean"].item()
|
120
|
+
|
121
|
+
def _get_eval_status(self, score: float):
|
122
|
+
return EvalStatus.PASSED if score >= self._threshold else EvalStatus.FAILED
|
123
|
+
|
26
124
|
@staticmethod
|
125
|
+
@deprecated(
|
126
|
+
reason=(
|
127
|
+
"This method has been deprecated and will be removed soon. Please use"
|
128
|
+
" evaluate_invocations instead."
|
129
|
+
)
|
130
|
+
)
|
27
131
|
def evaluate(
|
28
132
|
raw_eval_dataset: list[list[dict[str, Any]]],
|
29
133
|
evaluation_criteria: list[str],
|
30
134
|
*,
|
31
|
-
print_detailed_results: bool = False
|
135
|
+
print_detailed_results: bool = False,
|
32
136
|
):
|
33
137
|
r"""Returns the value of requested evaluation metrics.
|
34
138
|
|
@@ -12,18 +12,98 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
-
from typing import Any
|
15
|
+
from typing import Any, cast
|
16
16
|
|
17
|
+
from deprecated import deprecated
|
18
|
+
from google.genai import types as genai_types
|
17
19
|
import pandas as pd
|
18
20
|
from tabulate import tabulate
|
21
|
+
from typing_extensions import override
|
19
22
|
|
23
|
+
from .eval_case import Invocation
|
20
24
|
from .evaluation_constants import EvalConstants
|
25
|
+
from .evaluator import EvalStatus
|
26
|
+
from .evaluator import EvaluationResult
|
27
|
+
from .evaluator import Evaluator
|
28
|
+
from .evaluator import PerInvocationResult
|
21
29
|
|
22
30
|
|
23
|
-
class TrajectoryEvaluator:
|
31
|
+
class TrajectoryEvaluator(Evaluator):
|
24
32
|
"""Evaluates tool use trajectories for accuracy."""
|
25
33
|
|
34
|
+
def __init__(self, threshold: float):
|
35
|
+
self._threshold = threshold
|
36
|
+
|
37
|
+
@override
|
38
|
+
def evaluate_invocations(
|
39
|
+
self,
|
40
|
+
actual_invocations: list[Invocation],
|
41
|
+
expected_invocations: list[Invocation],
|
42
|
+
) -> EvaluationResult:
|
43
|
+
"""Returns EvaluationResult after performing evaluations using actual and expected invocations."""
|
44
|
+
total_tool_use_accuracy = 0.0
|
45
|
+
num_invocations = 0
|
46
|
+
per_invocation_results = []
|
47
|
+
|
48
|
+
for actual, expected in zip(actual_invocations, expected_invocations):
|
49
|
+
actual_tool_uses = (
|
50
|
+
actual.intermediate_data.tool_uses if actual.intermediate_data else []
|
51
|
+
)
|
52
|
+
expected_tool_uses = (
|
53
|
+
expected.intermediate_data.tool_uses
|
54
|
+
if expected.intermediate_data
|
55
|
+
else []
|
56
|
+
)
|
57
|
+
tool_use_accuracy = (
|
58
|
+
1.0
|
59
|
+
if self._are_tool_calls_equal(actual_tool_uses, expected_tool_uses)
|
60
|
+
else 0.0
|
61
|
+
)
|
62
|
+
per_invocation_results.append(
|
63
|
+
PerInvocationResult(
|
64
|
+
actual_invocation=actual,
|
65
|
+
expected_invocation=expected,
|
66
|
+
score=tool_use_accuracy,
|
67
|
+
eval_status=self._get_eval_status(tool_use_accuracy),
|
68
|
+
)
|
69
|
+
)
|
70
|
+
total_tool_use_accuracy += tool_use_accuracy
|
71
|
+
num_invocations += 1
|
72
|
+
|
73
|
+
if per_invocation_results:
|
74
|
+
overall_score = total_tool_use_accuracy / num_invocations
|
75
|
+
return EvaluationResult(
|
76
|
+
overall_score=overall_score,
|
77
|
+
overall_eval_status=self._get_eval_status(overall_score),
|
78
|
+
per_invocation_results=per_invocation_results,
|
79
|
+
)
|
80
|
+
|
81
|
+
return EvaluationResult()
|
82
|
+
|
83
|
+
def _are_tool_calls_equal(
|
84
|
+
self,
|
85
|
+
actual_tool_calls: list[genai_types.FunctionCall],
|
86
|
+
expected_tool_calls: list[genai_types.FunctionCall],
|
87
|
+
) -> bool:
|
88
|
+
if len(actual_tool_calls) != len(expected_tool_calls):
|
89
|
+
return False
|
90
|
+
|
91
|
+
for actual, expected in zip(actual_tool_calls, expected_tool_calls):
|
92
|
+
if actual.name != expected.name or actual.args != expected.args:
|
93
|
+
return False
|
94
|
+
|
95
|
+
return True
|
96
|
+
|
97
|
+
def _get_eval_status(self, score: float):
|
98
|
+
return EvalStatus.PASSED if score >= self._threshold else EvalStatus.FAILED
|
99
|
+
|
26
100
|
@staticmethod
|
101
|
+
@deprecated(
|
102
|
+
reason=(
|
103
|
+
"This method has been deprecated and will be removed soon. Please use"
|
104
|
+
" evaluate_invocations instead."
|
105
|
+
)
|
106
|
+
)
|
27
107
|
def evaluate(
|
28
108
|
eval_dataset: list[list[dict[str, Any]]],
|
29
109
|
*,
|
@@ -137,6 +217,7 @@ class TrajectoryEvaluator:
|
|
137
217
|
return new_row, failure
|
138
218
|
|
139
219
|
@staticmethod
|
220
|
+
@deprecated()
|
140
221
|
def are_tools_equal(list_a_original, list_b_original):
|
141
222
|
# Remove other entries that we don't want to evaluate
|
142
223
|
list_a = [
|
google/adk/events/event.py
CHANGED
@@ -19,6 +19,7 @@ import string
|
|
19
19
|
from typing import Optional
|
20
20
|
|
21
21
|
from google.genai import types
|
22
|
+
from pydantic import alias_generators
|
22
23
|
from pydantic import ConfigDict
|
23
24
|
from pydantic import Field
|
24
25
|
|
@@ -46,8 +47,13 @@ class Event(LlmResponse):
|
|
46
47
|
"""
|
47
48
|
|
48
49
|
model_config = ConfigDict(
|
49
|
-
extra='forbid',
|
50
|
+
extra='forbid',
|
51
|
+
ser_json_bytes='base64',
|
52
|
+
val_json_bytes='base64',
|
53
|
+
alias_generator=alias_generators.to_camel,
|
54
|
+
populate_by_name=True,
|
50
55
|
)
|
56
|
+
"""The pydantic model config."""
|
51
57
|
|
52
58
|
# TODO: revert to be required after spark migration
|
53
59
|
invocation_id: str = ''
|
@@ -16,6 +16,7 @@ from __future__ import annotations
|
|
16
16
|
|
17
17
|
from typing import Optional
|
18
18
|
|
19
|
+
from pydantic import alias_generators
|
19
20
|
from pydantic import BaseModel
|
20
21
|
from pydantic import ConfigDict
|
21
22
|
from pydantic import Field
|
@@ -26,7 +27,12 @@ from ..auth.auth_tool import AuthConfig
|
|
26
27
|
class EventActions(BaseModel):
|
27
28
|
"""Represents the actions attached to an event."""
|
28
29
|
|
29
|
-
model_config = ConfigDict(
|
30
|
+
model_config = ConfigDict(
|
31
|
+
extra='forbid',
|
32
|
+
alias_generator=alias_generators.to_camel,
|
33
|
+
populate_by_name=True,
|
34
|
+
)
|
35
|
+
"""The pydantic model config."""
|
30
36
|
|
31
37
|
skip_summarization: Optional[bool] = None
|
32
38
|
"""If true, it won't call model to summarize function response.
|
google/adk/examples/example.py
CHANGED
@@ -15,8 +15,9 @@
|
|
15
15
|
"""Utility functions for converting examples to a string that can be used in system instructions in the prompt."""
|
16
16
|
|
17
17
|
import logging
|
18
|
-
from typing import Optional
|
18
|
+
from typing import Optional
|
19
19
|
from typing import TYPE_CHECKING
|
20
|
+
from typing import Union
|
20
21
|
|
21
22
|
from .base_example_provider import BaseExampleProvider
|
22
23
|
from .example import Example
|
@@ -24,7 +25,7 @@ from .example import Example
|
|
24
25
|
if TYPE_CHECKING:
|
25
26
|
from ..sessions.session import Session
|
26
27
|
|
27
|
-
logger = logging.getLogger(__name__)
|
28
|
+
logger = logging.getLogger("google_adk." + __name__)
|
28
29
|
|
29
30
|
# Constant parts of the example string
|
30
31
|
_EXAMPLES_INTRO = (
|
google/adk/flows/__init__.py
CHANGED
@@ -22,7 +22,6 @@ import dataclasses
|
|
22
22
|
import os
|
23
23
|
import re
|
24
24
|
from typing import AsyncGenerator
|
25
|
-
from typing import Generator
|
26
25
|
from typing import Optional
|
27
26
|
from typing import TYPE_CHECKING
|
28
27
|
|
@@ -31,6 +30,7 @@ from typing_extensions import override
|
|
31
30
|
|
32
31
|
from ...agents.invocation_context import InvocationContext
|
33
32
|
from ...code_executors.base_code_executor import BaseCodeExecutor
|
33
|
+
from ...code_executors.built_in_code_executor import BuiltInCodeExecutor
|
34
34
|
from ...code_executors.code_execution_utils import CodeExecutionInput
|
35
35
|
from ...code_executors.code_execution_utils import CodeExecutionResult
|
36
36
|
from ...code_executors.code_execution_utils import CodeExecutionUtils
|
@@ -122,7 +122,7 @@ class _CodeExecutionRequestProcessor(BaseLlmRequestProcessor):
|
|
122
122
|
if not invocation_context.agent.code_executor:
|
123
123
|
return
|
124
124
|
|
125
|
-
for event in _run_pre_processor(invocation_context, llm_request):
|
125
|
+
async for event in _run_pre_processor(invocation_context, llm_request):
|
126
126
|
yield event
|
127
127
|
|
128
128
|
# Convert the code execution parts to text parts.
|
@@ -152,17 +152,17 @@ class _CodeExecutionResponseProcessor(BaseLlmResponseProcessor):
|
|
152
152
|
if llm_response.partial:
|
153
153
|
return
|
154
154
|
|
155
|
-
for event in _run_post_processor(invocation_context, llm_response):
|
155
|
+
async for event in _run_post_processor(invocation_context, llm_response):
|
156
156
|
yield event
|
157
157
|
|
158
158
|
|
159
159
|
response_processor = _CodeExecutionResponseProcessor()
|
160
160
|
|
161
161
|
|
162
|
-
def _run_pre_processor(
|
162
|
+
async def _run_pre_processor(
|
163
163
|
invocation_context: InvocationContext,
|
164
164
|
llm_request: LlmRequest,
|
165
|
-
) ->
|
165
|
+
) -> AsyncGenerator[Event, None]:
|
166
166
|
"""Pre-process the user message by adding the user message to the Colab notebook."""
|
167
167
|
from ...agents.llm_agent import LlmAgent
|
168
168
|
|
@@ -174,6 +174,11 @@ def _run_pre_processor(
|
|
174
174
|
|
175
175
|
if not code_executor or not isinstance(code_executor, BaseCodeExecutor):
|
176
176
|
return
|
177
|
+
|
178
|
+
if isinstance(code_executor, BuiltInCodeExecutor):
|
179
|
+
code_executor.process_llm_request(llm_request)
|
180
|
+
return
|
181
|
+
|
177
182
|
if not code_executor.optimize_data_file:
|
178
183
|
return
|
179
184
|
|
@@ -242,17 +247,17 @@ def _run_pre_processor(
|
|
242
247
|
code_executor_context.add_processed_file_names([file.name])
|
243
248
|
|
244
249
|
# Emit the execution result, and add it to the LLM request.
|
245
|
-
execution_result_event = _post_process_code_execution_result(
|
250
|
+
execution_result_event = await _post_process_code_execution_result(
|
246
251
|
invocation_context, code_executor_context, code_execution_result
|
247
252
|
)
|
248
253
|
yield execution_result_event
|
249
254
|
llm_request.contents.append(copy.deepcopy(execution_result_event.content))
|
250
255
|
|
251
256
|
|
252
|
-
def _run_post_processor(
|
257
|
+
async def _run_post_processor(
|
253
258
|
invocation_context: InvocationContext,
|
254
259
|
llm_response,
|
255
|
-
) ->
|
260
|
+
) -> AsyncGenerator[Event, None]:
|
256
261
|
"""Post-process the model response by extracting and executing the first code block."""
|
257
262
|
agent = invocation_context.agent
|
258
263
|
code_executor = agent.code_executor
|
@@ -262,6 +267,9 @@ def _run_post_processor(
|
|
262
267
|
if not llm_response or not llm_response.content:
|
263
268
|
return
|
264
269
|
|
270
|
+
if isinstance(code_executor, BuiltInCodeExecutor):
|
271
|
+
return
|
272
|
+
|
265
273
|
code_executor_context = CodeExecutorContext(invocation_context.session.state)
|
266
274
|
# Skip if the error count exceeds the max retry attempts.
|
267
275
|
if (
|
@@ -305,7 +313,7 @@ def _run_post_processor(
|
|
305
313
|
code_execution_result.stdout,
|
306
314
|
code_execution_result.stderr,
|
307
315
|
)
|
308
|
-
yield _post_process_code_execution_result(
|
316
|
+
yield await _post_process_code_execution_result(
|
309
317
|
invocation_context, code_executor_context, code_execution_result
|
310
318
|
)
|
311
319
|
|
@@ -375,7 +383,7 @@ def _get_or_set_execution_id(
|
|
375
383
|
return execution_id
|
376
384
|
|
377
385
|
|
378
|
-
def _post_process_code_execution_result(
|
386
|
+
async def _post_process_code_execution_result(
|
379
387
|
invocation_context: InvocationContext,
|
380
388
|
code_executor_context: CodeExecutorContext,
|
381
389
|
code_execution_result: CodeExecutionResult,
|
@@ -406,7 +414,7 @@ def _post_process_code_execution_result(
|
|
406
414
|
|
407
415
|
# Handle output files.
|
408
416
|
for output_file in code_execution_result.output_files:
|
409
|
-
version = invocation_context.artifact_service.save_artifact(
|
417
|
+
version = await invocation_context.artifact_service.save_artifact(
|
410
418
|
app_name=invocation_context.app_name,
|
411
419
|
user_id=invocation_context.user_id,
|
412
420
|
session_id=invocation_context.session.id,
|
@@ -25,8 +25,9 @@ if TYPE_CHECKING:
|
|
25
25
|
class AudioTranscriber:
|
26
26
|
"""Transcribes audio using Google Cloud Speech-to-Text."""
|
27
27
|
|
28
|
-
def __init__(self):
|
29
|
-
|
28
|
+
def __init__(self, init_client=False):
|
29
|
+
if init_client:
|
30
|
+
self.client = speech.SpeechClient()
|
30
31
|
|
31
32
|
def transcribe_file(
|
32
33
|
self, invocation_context: InvocationContext
|
@@ -84,7 +85,7 @@ class AudioTranscriber:
|
|
84
85
|
|
85
86
|
# Step2: transcription
|
86
87
|
for speaker, data in bundled_audio:
|
87
|
-
if
|
88
|
+
if isinstance(data, genai_types.Blob):
|
88
89
|
audio = speech.RecognitionAudio(content=data)
|
89
90
|
|
90
91
|
config = speech.RecognitionConfig(
|
@@ -16,6 +16,7 @@ from __future__ import annotations
|
|
16
16
|
|
17
17
|
from abc import ABC
|
18
18
|
import asyncio
|
19
|
+
import inspect
|
19
20
|
import logging
|
20
21
|
from typing import AsyncGenerator
|
21
22
|
from typing import cast
|
@@ -28,6 +29,7 @@ from ...agents.base_agent import BaseAgent
|
|
28
29
|
from ...agents.callback_context import CallbackContext
|
29
30
|
from ...agents.invocation_context import InvocationContext
|
30
31
|
from ...agents.live_request_queue import LiveRequestQueue
|
32
|
+
from ...agents.readonly_context import ReadonlyContext
|
31
33
|
from ...agents.run_config import StreamingMode
|
32
34
|
from ...agents.transcription_entry import TranscriptionEntry
|
33
35
|
from ...events.event import Event
|
@@ -46,7 +48,7 @@ if TYPE_CHECKING:
|
|
46
48
|
from ._base_llm_processor import BaseLlmRequestProcessor
|
47
49
|
from ._base_llm_processor import BaseLlmResponseProcessor
|
48
50
|
|
49
|
-
logger = logging.getLogger(__name__)
|
51
|
+
logger = logging.getLogger('google_adk.' + __name__)
|
50
52
|
|
51
53
|
|
52
54
|
class BaseLlmFlow(ABC):
|
@@ -87,7 +89,12 @@ class BaseLlmFlow(ABC):
|
|
87
89
|
if invocation_context.transcription_cache:
|
88
90
|
from . import audio_transcriber
|
89
91
|
|
90
|
-
audio_transcriber = audio_transcriber.AudioTranscriber(
|
92
|
+
audio_transcriber = audio_transcriber.AudioTranscriber(
|
93
|
+
init_client=True
|
94
|
+
if invocation_context.run_config.input_audio_transcription
|
95
|
+
is None
|
96
|
+
else False
|
97
|
+
)
|
91
98
|
contents = audio_transcriber.transcribe_file(invocation_context)
|
92
99
|
logger.debug('Sending history to model: %s', contents)
|
93
100
|
await llm_connection.send_history(contents)
|
@@ -128,6 +135,18 @@ class BaseLlmFlow(ABC):
|
|
128
135
|
# cancel the tasks that belongs to the closed connection.
|
129
136
|
send_task.cancel()
|
130
137
|
await llm_connection.close()
|
138
|
+
if (
|
139
|
+
event.content
|
140
|
+
and event.content.parts
|
141
|
+
and event.content.parts[0].function_response
|
142
|
+
and event.content.parts[0].function_response.name
|
143
|
+
== 'task_completed'
|
144
|
+
):
|
145
|
+
# this is used for sequential agent to signal the end of the agent.
|
146
|
+
await asyncio.sleep(1)
|
147
|
+
# cancel the tasks that belongs to the closed connection.
|
148
|
+
send_task.cancel()
|
149
|
+
return
|
131
150
|
finally:
|
132
151
|
# Clean up
|
133
152
|
if not send_task.done():
|
@@ -175,9 +194,12 @@ class BaseLlmFlow(ABC):
|
|
175
194
|
# Cache audio data here for transcription
|
176
195
|
if not invocation_context.transcription_cache:
|
177
196
|
invocation_context.transcription_cache = []
|
178
|
-
invocation_context.
|
179
|
-
|
180
|
-
|
197
|
+
if not invocation_context.run_config.input_audio_transcription:
|
198
|
+
# if the live model's input transcription is not enabled, then
|
199
|
+
# we use our onwn audio transcriber to achieve that.
|
200
|
+
invocation_context.transcription_cache.append(
|
201
|
+
TranscriptionEntry(role='user', data=live_request.blob)
|
202
|
+
)
|
181
203
|
await llm_connection.send_realtime(live_request.blob)
|
182
204
|
if live_request.content:
|
183
205
|
await llm_connection.send_content(live_request.content)
|
@@ -190,6 +212,25 @@ class BaseLlmFlow(ABC):
|
|
190
212
|
llm_request: LlmRequest,
|
191
213
|
) -> AsyncGenerator[Event, None]:
|
192
214
|
"""Receive data from model and process events using BaseLlmConnection."""
|
215
|
+
|
216
|
+
def get_author_for_event(llm_response):
|
217
|
+
"""Get the author of the event.
|
218
|
+
|
219
|
+
When the model returns transcription, the author is "user". Otherwise, the
|
220
|
+
author is the agent name(not 'model').
|
221
|
+
|
222
|
+
Args:
|
223
|
+
llm_response: The LLM response from the LLM call.
|
224
|
+
"""
|
225
|
+
if (
|
226
|
+
llm_response
|
227
|
+
and llm_response.content
|
228
|
+
and llm_response.content.role == 'user'
|
229
|
+
):
|
230
|
+
return 'user'
|
231
|
+
else:
|
232
|
+
return invocation_context.agent.name
|
233
|
+
|
193
234
|
assert invocation_context.live_request_queue
|
194
235
|
try:
|
195
236
|
while True:
|
@@ -197,7 +238,7 @@ class BaseLlmFlow(ABC):
|
|
197
238
|
model_response_event = Event(
|
198
239
|
id=Event.new_id(),
|
199
240
|
invocation_id=invocation_context.invocation_id,
|
200
|
-
author=
|
241
|
+
author=get_author_for_event(llm_response),
|
201
242
|
)
|
202
243
|
async for event in self._postprocess_live(
|
203
244
|
invocation_context,
|
@@ -208,13 +249,20 @@ class BaseLlmFlow(ABC):
|
|
208
249
|
if (
|
209
250
|
event.content
|
210
251
|
and event.content.parts
|
211
|
-
and event.content.parts[0].
|
252
|
+
and event.content.parts[0].inline_data is None
|
212
253
|
and not event.partial
|
213
254
|
):
|
255
|
+
# This can be either user data or transcription data.
|
256
|
+
# when output transcription enabled, it will contain model's
|
257
|
+
# transcription.
|
258
|
+
# when input transcription enabled, it will contain user
|
259
|
+
# transcription.
|
214
260
|
if not invocation_context.transcription_cache:
|
215
261
|
invocation_context.transcription_cache = []
|
216
262
|
invocation_context.transcription_cache.append(
|
217
|
-
TranscriptionEntry(
|
263
|
+
TranscriptionEntry(
|
264
|
+
role=event.content.role, data=event.content
|
265
|
+
)
|
218
266
|
)
|
219
267
|
yield event
|
220
268
|
# Give opportunity for other tasks to run.
|
@@ -261,6 +309,8 @@ class BaseLlmFlow(ABC):
|
|
261
309
|
async for event in self._postprocess_async(
|
262
310
|
invocation_context, llm_request, llm_response, model_response_event
|
263
311
|
):
|
312
|
+
# Update the mutable event id to avoid conflict
|
313
|
+
model_response_event.id = Event.new_id()
|
264
314
|
yield event
|
265
315
|
|
266
316
|
async def _preprocess_async(
|
@@ -278,7 +328,9 @@ class BaseLlmFlow(ABC):
|
|
278
328
|
yield event
|
279
329
|
|
280
330
|
# Run processors for tools.
|
281
|
-
for tool in agent.canonical_tools
|
331
|
+
for tool in await agent.canonical_tools(
|
332
|
+
ReadonlyContext(invocation_context)
|
333
|
+
):
|
282
334
|
tool_context = ToolContext(invocation_context)
|
283
335
|
await tool.process_llm_request(
|
284
336
|
tool_context=tool_context, llm_request=llm_request
|
@@ -437,7 +489,7 @@ class BaseLlmFlow(ABC):
|
|
437
489
|
model_response_event: Event,
|
438
490
|
) -> AsyncGenerator[LlmResponse, None]:
|
439
491
|
# Runs before_model_callback if it exists.
|
440
|
-
if response := self._handle_before_model_callback(
|
492
|
+
if response := await self._handle_before_model_callback(
|
441
493
|
invocation_context, llm_request, model_response_event
|
442
494
|
):
|
443
495
|
yield response
|
@@ -450,7 +502,7 @@ class BaseLlmFlow(ABC):
|
|
450
502
|
invocation_context.live_request_queue = LiveRequestQueue()
|
451
503
|
async for llm_response in self.run_live(invocation_context):
|
452
504
|
# Runs after_model_callback if it exists.
|
453
|
-
if altered_llm_response := self._handle_after_model_callback(
|
505
|
+
if altered_llm_response := await self._handle_after_model_callback(
|
454
506
|
invocation_context, llm_response, model_response_event
|
455
507
|
):
|
456
508
|
llm_response = altered_llm_response
|
@@ -479,14 +531,14 @@ class BaseLlmFlow(ABC):
|
|
479
531
|
llm_response,
|
480
532
|
)
|
481
533
|
# Runs after_model_callback if it exists.
|
482
|
-
if altered_llm_response := self._handle_after_model_callback(
|
534
|
+
if altered_llm_response := await self._handle_after_model_callback(
|
483
535
|
invocation_context, llm_response, model_response_event
|
484
536
|
):
|
485
537
|
llm_response = altered_llm_response
|
486
538
|
|
487
539
|
yield llm_response
|
488
540
|
|
489
|
-
def _handle_before_model_callback(
|
541
|
+
async def _handle_before_model_callback(
|
490
542
|
self,
|
491
543
|
invocation_context: InvocationContext,
|
492
544
|
llm_request: LlmRequest,
|
@@ -498,17 +550,23 @@ class BaseLlmFlow(ABC):
|
|
498
550
|
if not isinstance(agent, LlmAgent):
|
499
551
|
return
|
500
552
|
|
501
|
-
if not agent.
|
553
|
+
if not agent.canonical_before_model_callbacks:
|
502
554
|
return
|
503
555
|
|
504
556
|
callback_context = CallbackContext(
|
505
557
|
invocation_context, event_actions=model_response_event.actions
|
506
558
|
)
|
507
|
-
return agent.before_model_callback(
|
508
|
-
callback_context=callback_context, llm_request=llm_request
|
509
|
-
)
|
510
559
|
|
511
|
-
|
560
|
+
for callback in agent.canonical_before_model_callbacks:
|
561
|
+
before_model_callback_content = callback(
|
562
|
+
callback_context=callback_context, llm_request=llm_request
|
563
|
+
)
|
564
|
+
if inspect.isawaitable(before_model_callback_content):
|
565
|
+
before_model_callback_content = await before_model_callback_content
|
566
|
+
if before_model_callback_content:
|
567
|
+
return before_model_callback_content
|
568
|
+
|
569
|
+
async def _handle_after_model_callback(
|
512
570
|
self,
|
513
571
|
invocation_context: InvocationContext,
|
514
572
|
llm_response: LlmResponse,
|
@@ -520,15 +578,21 @@ class BaseLlmFlow(ABC):
|
|
520
578
|
if not isinstance(agent, LlmAgent):
|
521
579
|
return
|
522
580
|
|
523
|
-
if not agent.
|
581
|
+
if not agent.canonical_after_model_callbacks:
|
524
582
|
return
|
525
583
|
|
526
584
|
callback_context = CallbackContext(
|
527
585
|
invocation_context, event_actions=model_response_event.actions
|
528
586
|
)
|
529
|
-
|
530
|
-
|
531
|
-
|
587
|
+
|
588
|
+
for callback in agent.canonical_after_model_callbacks:
|
589
|
+
after_model_callback_content = callback(
|
590
|
+
callback_context=callback_context, llm_response=llm_response
|
591
|
+
)
|
592
|
+
if inspect.isawaitable(after_model_callback_content):
|
593
|
+
after_model_callback_content = await after_model_callback_content
|
594
|
+
if after_model_callback_content:
|
595
|
+
return after_model_callback_content
|
532
596
|
|
533
597
|
def _finalize_model_response_event(
|
534
598
|
self,
|