google-adk 1.6.1__py3-none-any.whl → 1.7.0__py3-none-any.whl
- google/adk/a2a/converters/event_converter.py +5 -85
- google/adk/a2a/executor/a2a_agent_executor.py +45 -16
- google/adk/agents/__init__.py +5 -0
- google/adk/agents/agent_config.py +46 -0
- google/adk/agents/base_agent.py +234 -41
- google/adk/agents/callback_context.py +41 -0
- google/adk/agents/common_configs.py +79 -0
- google/adk/agents/config_agent_utils.py +184 -0
- google/adk/agents/config_schemas/AgentConfig.json +544 -0
- google/adk/agents/invocation_context.py +5 -1
- google/adk/agents/llm_agent.py +190 -9
- google/adk/agents/loop_agent.py +29 -0
- google/adk/agents/parallel_agent.py +24 -3
- google/adk/agents/remote_a2a_agent.py +15 -3
- google/adk/agents/sequential_agent.py +22 -1
- google/adk/artifacts/gcs_artifact_service.py +24 -2
- google/adk/auth/auth_handler.py +3 -3
- google/adk/auth/credential_manager.py +23 -23
- google/adk/auth/credential_service/base_credential_service.py +6 -6
- google/adk/auth/credential_service/in_memory_credential_service.py +10 -8
- google/adk/auth/credential_service/session_state_credential_service.py +8 -8
- google/adk/auth/exchanger/oauth2_credential_exchanger.py +3 -3
- google/adk/auth/oauth2_credential_util.py +2 -2
- google/adk/auth/refresher/oauth2_credential_refresher.py +4 -4
- google/adk/cli/agent_graph.py +3 -1
- google/adk/cli/browser/index.html +1 -1
- google/adk/cli/browser/main-SRBSE46V.js +3914 -0
- google/adk/cli/browser/polyfills-B6TNHZQ6.js +17 -0
- google/adk/cli/fast_api.py +42 -2
- google/adk/cli/utils/agent_loader.py +35 -1
- google/adk/code_executors/base_code_executor.py +14 -19
- google/adk/code_executors/built_in_code_executor.py +4 -1
- google/adk/evaluation/base_eval_service.py +46 -2
- google/adk/evaluation/evaluation_generator.py +1 -1
- google/adk/evaluation/in_memory_eval_sets_manager.py +151 -0
- google/adk/evaluation/local_eval_service.py +389 -0
- google/adk/evaluation/local_eval_sets_manager.py +23 -8
- google/adk/flows/llm_flows/auto_flow.py +6 -11
- google/adk/flows/llm_flows/base_llm_flow.py +41 -23
- google/adk/flows/llm_flows/contents.py +16 -10
- google/adk/flows/llm_flows/functions.py +76 -33
- google/adk/memory/in_memory_memory_service.py +20 -14
- google/adk/models/anthropic_llm.py +44 -5
- google/adk/models/google_llm.py +11 -6
- google/adk/models/lite_llm.py +21 -4
- google/adk/plugins/__init__.py +17 -0
- google/adk/plugins/base_plugin.py +317 -0
- google/adk/plugins/plugin_manager.py +265 -0
- google/adk/runners.py +122 -18
- google/adk/sessions/database_session_service.py +26 -28
- google/adk/sessions/vertex_ai_session_service.py +14 -7
- google/adk/tools/agent_tool.py +1 -0
- google/adk/tools/apihub_tool/apihub_toolset.py +38 -39
- google/adk/tools/application_integration_tool/application_integration_toolset.py +35 -37
- google/adk/tools/application_integration_tool/integration_connector_tool.py +2 -3
- google/adk/tools/base_tool.py +9 -9
- google/adk/tools/base_toolset.py +7 -5
- google/adk/tools/bigquery/__init__.py +3 -3
- google/adk/tools/enterprise_search_tool.py +4 -2
- google/adk/tools/google_api_tool/google_api_tool.py +16 -1
- google/adk/tools/google_api_tool/google_api_toolset.py +9 -7
- google/adk/tools/google_api_tool/google_api_toolsets.py +41 -20
- google/adk/tools/google_search_tool.py +4 -2
- google/adk/tools/langchain_tool.py +2 -3
- google/adk/tools/long_running_tool.py +21 -0
- google/adk/tools/mcp_tool/mcp_toolset.py +27 -28
- google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py +8 -8
- google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py +4 -6
- google/adk/tools/retrieval/vertex_ai_rag_retrieval.py +3 -2
- google/adk/tools/tool_context.py +0 -10
- google/adk/tools/url_context_tool.py +4 -2
- google/adk/tools/vertex_ai_search_tool.py +4 -2
- google/adk/utils/model_name_utils.py +90 -0
- google/adk/version.py +1 -1
- {google_adk-1.6.1.dist-info → google_adk-1.7.0.dist-info}/METADATA +2 -2
- {google_adk-1.6.1.dist-info → google_adk-1.7.0.dist-info}/RECORD +79 -69
- google/adk/cli/browser/main-RXDVX3K6.js +0 -3914
- google/adk/cli/browser/polyfills-FFHMD2TL.js +0 -17
- {google_adk-1.6.1.dist-info → google_adk-1.7.0.dist-info}/WHEEL +0 -0
- {google_adk-1.6.1.dist-info → google_adk-1.7.0.dist-info}/entry_points.txt +0 -0
- {google_adk-1.6.1.dist-info → google_adk-1.7.0.dist-info}/licenses/LICENSE +0 -0
google/adk/evaluation/local_eval_service.py (new file):

@@ -0,0 +1,389 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import asyncio
+import inspect
+import logging
+from typing import AsyncGenerator
+from typing import Callable
+from typing import Optional
+import uuid
+
+from typing_extensions import override
+
+from ..agents import BaseAgent
+from ..artifacts.base_artifact_service import BaseArtifactService
+from ..artifacts.in_memory_artifact_service import InMemoryArtifactService
+from ..errors.not_found_error import NotFoundError
+from ..sessions.base_session_service import BaseSessionService
+from ..sessions.in_memory_session_service import InMemorySessionService
+from ..utils.feature_decorator import working_in_progress
+from .base_eval_service import BaseEvalService
+from .base_eval_service import EvaluateConfig
+from .base_eval_service import EvaluateRequest
+from .base_eval_service import InferenceRequest
+from .base_eval_service import InferenceResult
+from .base_eval_service import InferenceStatus
+from .eval_case import Invocation
+from .eval_metrics import EvalMetric
+from .eval_metrics import EvalMetricResult
+from .eval_metrics import EvalMetricResultPerInvocation
+from .eval_result import EvalCaseResult
+from .eval_set import EvalCase
+from .eval_set_results_manager import EvalSetResultsManager
+from .eval_sets_manager import EvalSetsManager
+from .evaluation_generator import EvaluationGenerator
+from .evaluator import EvalStatus
+from .evaluator import EvaluationResult
+from .metric_evaluator_registry import DEFAULT_METRIC_EVALUATOR_REGISTRY
+from .metric_evaluator_registry import MetricEvaluatorRegistry
+
+logger = logging.getLogger('google_adk.' + __name__)
+
+EVAL_SESSION_ID_PREFIX = '___eval___session___'
+
+
+def _get_session_id() -> str:
+  return f'{EVAL_SESSION_ID_PREFIX}{str(uuid.uuid4())}'
+
+
+@working_in_progress("Incomplete feature, don't use yet")
+class LocalEvalService(BaseEvalService):
+  """An implementation of BaseEvalService, that runs the evals locally."""
+
+  def __init__(
+      self,
+      root_agent: BaseAgent,
+      eval_sets_manager: EvalSetsManager,
+      metric_evaluator_registry: MetricEvaluatorRegistry = DEFAULT_METRIC_EVALUATOR_REGISTRY,
+      session_service: BaseSessionService = InMemorySessionService(),
+      artifact_service: BaseArtifactService = InMemoryArtifactService(),
+      eval_set_results_manager: Optional[EvalSetResultsManager] = None,
+      session_id_supplier: Callable[[], str] = _get_session_id,
+  ):
+    self._root_agent = root_agent
+    self._eval_sets_manager = eval_sets_manager
+    self._metric_evaluator_registry = metric_evaluator_registry
+    self._session_service = session_service
+    self._artifact_service = artifact_service
+    self._eval_set_results_manager = eval_set_results_manager
+    self._session_id_supplier = session_id_supplier
+
+  @override
+  async def perform_inference(
+      self,
+      inference_request: InferenceRequest,
+  ) -> AsyncGenerator[InferenceResult, None]:
+    """Returns InferenceResult obtained from the Agent as and when they are available.
+
+    Args:
+      inference_request: The request for generating inferences.
+    """
+    # Get the eval set from the storage.
+    eval_set = self._eval_sets_manager.get_eval_set(
+        app_name=inference_request.app_name,
+        eval_set_id=inference_request.eval_set_id,
+    )
+
+    if not eval_set:
+      raise NotFoundError(
+          f'Eval set with id {inference_request.eval_set_id} not found for app'
+          f' {inference_request.app_name}'
+      )
+
+    # Select eval cases for which we need to run inferencing. If the inference
+    # request specified eval cases, then we use only those.
+    eval_cases = eval_set.eval_cases
+    if inference_request.eval_case_ids:
+      eval_cases = [
+          eval_case
+          for eval_case in eval_cases
+          if eval_case.eval_id in inference_request.eval_case_ids
+      ]
+
+    root_agent = self._root_agent.clone()
+
+    semaphore = asyncio.Semaphore(
+        value=inference_request.inference_config.parallelism
+    )
+
+    async def run_inference(eval_case):
+      async with semaphore:
+        return await self._perform_inference_sigle_eval_item(
+            app_name=inference_request.app_name,
+            eval_set_id=inference_request.eval_set_id,
+            eval_case=eval_case,
+            root_agent=root_agent,
+        )
+
+    inference_results = [run_inference(eval_case) for eval_case in eval_cases]
+    for inference_result in asyncio.as_completed(inference_results):
+      yield await inference_result
+
+  @override
+  async def evaluate(
+      self,
+      evaluate_request: EvaluateRequest,
+  ) -> AsyncGenerator[EvalCaseResult, None]:
+    """Returns EvalCaseResult for each item as and when they are available.
+
+    Args:
+      evaluate_request: The request to perform metric evaluations on the
+        inferences.
+    """
+    semaphore = asyncio.Semaphore(
+        value=evaluate_request.evaluate_config.parallelism
+    )
+
+    async def run_evaluation(inference_result):
+      async with semaphore:
+        return await self._evaluate_single_inference_result(
+            inference_result=inference_result,
+            evaluate_config=evaluate_request.evaluate_config,
+        )
+
+    evaluation_tasks = [
+        run_evaluation(inference_result)
+        for inference_result in evaluate_request.inference_results
+    ]
+
+    for evaluation_task in asyncio.as_completed(evaluation_tasks):
+      inference_result, eval_case_result = await evaluation_task
+
+      if self._eval_set_results_manager:
+        self._eval_set_results_manager.save_eval_set_result(
+            app_name=inference_result.app_name,
+            eval_set_id=inference_result.eval_set_id,
+            eval_case_results=[eval_case_result],
+        )
+
+      yield eval_case_result
+
+  async def _evaluate_single_inference_result(
+      self, inference_result: InferenceResult, evaluate_config: EvaluateConfig
+  ) -> tuple[InferenceResult, EvalCaseResult]:
+    """Returns EvalCaseResult for the given inference result.
+
+    A single inference result can have multiple invocations. For each
+    invocaiton, this method evaluates the metrics present in evaluate config.
+
+    The EvalCaseResult contains scores for each metric per invocation and the
+    overall score.
+    """
+    eval_case = self._eval_sets_manager.get_eval_case(
+        app_name=inference_result.app_name,
+        eval_set_id=inference_result.eval_set_id,
+        eval_case_id=inference_result.eval_case_id,
+    )
+
+    if eval_case is None:
+      raise NotFoundError(
+          f'Eval case with id {inference_result.eval_case_id} not found for'
+          f' app {inference_result.app_name} and eval set'
+          f' {inference_result.eval_set_id}.'
+      )
+
+    # Metric results for each invocation
+    eval_metric_result_per_invocation = []
+
+    # We also keep track of the overall score for a metric, derived from all
+    # invocation. For example, if we were keeping track the metric that compares
+    # how well is the final resposne as compared to a golden answer, then each
+    # invocation will have the value of this metric. We will also have an
+    # overall score using aggregation strategy across all invocations. This
+    # would be the score for the eval case.
+    overall_eval_metric_results = []
+
+    if len(inference_result.inferences) != len(eval_case.conversation):
+      raise ValueError(
+          'Inferences should match conversations in eval case. Found'
+          f'{len(inference_result.inferences)} inferences '
+          f'{len(eval_case.conversation)} conversations in eval cases.'
+      )
+
+    # Pre-creating the EvalMetricResults entries for each invocation.
+    for actual, expected in zip(
+        inference_result.inferences, eval_case.conversation
+    ):
+      eval_metric_result_per_invocation.append(
+          EvalMetricResultPerInvocation(
+              actual_invocation=actual,
+              expected_invocation=expected,
+              # We will fill this as we evaluate each metric per invocation.
+              eval_metric_results=[],
+          )
+      )
+
+    for eval_metric in evaluate_config.eval_metrics:
+      # Perform evaluation of the metric.
+      evaluation_result = await self._evaluate_metric(
+          eval_metric=eval_metric,
+          actual_invocations=inference_result.inferences,
+          expected_invocations=eval_case.conversation,
+      )
+
+      # Track overall scrore across all invocations.
+      overall_eval_metric_results.append(
+          EvalMetricResult(
+              metric_name=eval_metric.metric_name,
+              threshold=eval_metric.threshold,
+              score=evaluation_result.overall_score,
+              eval_status=evaluation_result.overall_eval_status,
+          )
+      )
+
+      if len(evaluation_result.per_invocation_results) != len(
+          eval_metric_result_per_invocation
+      ):
+        raise ValueError(
+            'Eval metric should return results for each invocation. Found '
+            f'{len(evaluation_result.per_invocation_results)} results for '
+            f'{len(eval_metric_result_per_invocation)} invocations.'
+        )
+
+      # Track score across individual invocations.
+      for invocation_result, invocation in zip(
+          evaluation_result.per_invocation_results,
+          eval_metric_result_per_invocation,
+      ):
+        invocation.eval_metric_results.append(
+            EvalMetricResult(
+                metric_name=eval_metric.metric_name,
+                threshold=eval_metric.threshold,
+                score=invocation_result.score,
+                eval_status=invocation_result.eval_status,
+            )
+        )
+
+    final_eval_status = self._generate_final_eval_status(
+        overall_eval_metric_results
+    )
+    user_id = (
+        eval_case.session_input.user_id
+        if eval_case.session_input and eval_case.session_input.user_id
+        else 'test_user_id'
+    )
+
+    eval_case_result = EvalCaseResult(
+        eval_set_file=inference_result.eval_set_id,
+        eval_set_id=inference_result.eval_set_id,
+        eval_id=inference_result.eval_case_id,
+        final_eval_status=final_eval_status,
+        overall_eval_metric_results=overall_eval_metric_results,
+        eval_metric_result_per_invocation=eval_metric_result_per_invocation,
+        session_id=inference_result.session_id,
+        session_details=await self._session_service.get_session(
+            app_name=inference_result.app_name,
+            user_id=user_id,
+            session_id=inference_result.session_id,
+        ),
+        user_id=user_id,
+    )
+
+    return (inference_result, eval_case_result)
+
+  async def _evaluate_metric(
+      self,
+      eval_metric: EvalMetric,
+      actual_invocations: list[Invocation],
+      expected_invocations: list[Invocation],
+  ) -> EvaluationResult:
+    """Returns EvaluationResult obtained from evaluating a metric using an Evaluator."""
+
+    # Get the metric evaluator from the registry.
+    metric_evaluator = self._metric_evaluator_registry.get_evaluator(
+        eval_metric=eval_metric
+    )
+
+    if inspect.iscoroutinefunction(metric_evaluator.evaluate_invocations):
+      # Some evaluators could be async, for example those that use llm as a
+      # judge, so we need to make sure that we wait on them.
+      return await metric_evaluator.evaluate_invocations(
+          actual_invocations=actual_invocations,
+          expected_invocations=expected_invocations,
+      )
+    else:
+      # Metrics that perform computation synchronously, mostly these don't
+      # perform any i/o. An example of this would calculation of rouge_1 score.
+      return metric_evaluator.evaluate_invocations(
+          actual_invocations=actual_invocations,
+          expected_invocations=expected_invocations,
+      )
+
+  def _generate_final_eval_status(
+      self, overall_eval_metric_results: list[EvalMetricResult]
+  ) -> EvalStatus:
+    final_eval_status = EvalStatus.NOT_EVALUATED
+    # Go over the all the eval statuses and mark the final eval status as
+    # passed if all of them pass, otherwise mark the final eval status to
+    # failed.
+    for overall_eval_metric_result in overall_eval_metric_results:
+      overall_eval_status = overall_eval_metric_result.eval_status
+      if overall_eval_status == EvalStatus.PASSED:
+        final_eval_status = EvalStatus.PASSED
+      elif overall_eval_status == EvalStatus.NOT_EVALUATED:
+        continue
+      elif overall_eval_status == EvalStatus.FAILED:
+        final_eval_status = EvalStatus.FAILED
+        break
+      else:
+        raise ValueError(f'Unknown eval status: {overall_eval_status}.')
+
+    return final_eval_status
+
+  async def _perform_inference_sigle_eval_item(
+      self,
+      app_name: str,
+      eval_set_id: str,
+      eval_case: EvalCase,
+      root_agent: BaseAgent,
+  ) -> InferenceResult:
+    initial_session = eval_case.session_input
+    session_id = self._session_id_supplier()
+    inference_result = InferenceResult(
+        app_name=app_name,
+        eval_set_id=eval_set_id,
+        eval_case_id=eval_case.eval_id,
+        session_id=session_id,
+    )
+
+    try:
+      inferences = (
+          await EvaluationGenerator._generate_inferences_from_root_agent(
+              invocations=eval_case.conversation,
+              root_agent=root_agent,
+              initial_session=initial_session,
+              session_id=session_id,
+              session_service=self._session_service,
+              artifact_service=self._artifact_service,
+          )
+      )
+
+      inference_result.inferences = inferences
+      inference_result.status = InferenceStatus.SUCCESS
+
+      return inference_result
+    except Exception as e:
+      # We intentionally catch the Exception as we don't failures to affect
+      # other inferences.
+      logger.error(
+          'Inference failed for eval case `%s` with error %s',
+          eval_case.eval_id,
+          e,
+      )
+      inference_result.status = InferenceStatus.FAILURE
+      inference_result.error_message = str(e)
+      return inference_result
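Taken together, `perform_inference` and `evaluate` are the two streaming halves of the new service: the first fans eval cases out through a semaphore-bounded set of agent runs, the second scores the resulting inferences per metric and per invocation, persisting each case via `eval_set_results_manager` if one was supplied. A minimal driver sketch follows; it assumes a root agent and an `EvalSetsManager` are already constructed, and the request/config shapes are inferred only from the attributes this file reads (`app_name`, `eval_set_id`, `inference_config.parallelism`, `evaluate_config.eval_metrics`, ...), so treat the constructor calls as assumptions. The class is also still flagged `working_in_progress`.

```python
from google.adk.evaluation.base_eval_service import EvaluateConfig
from google.adk.evaluation.base_eval_service import EvaluateRequest
from google.adk.evaluation.base_eval_service import InferenceConfig  # assumed name
from google.adk.evaluation.base_eval_service import InferenceRequest
from google.adk.evaluation.local_eval_service import LocalEvalService


async def run_local_eval(root_agent, eval_sets_manager, eval_metrics):
  service = LocalEvalService(
      root_agent=root_agent,
      eval_sets_manager=eval_sets_manager,
  )

  # Stage 1: run the agent over every eval case in the eval set, bounded by
  # the parallelism configured on the request.
  inference_results = []
  async for inference_result in service.perform_inference(
      InferenceRequest(
          app_name='my_app',
          eval_set_id='my_eval_set',
          inference_config=InferenceConfig(parallelism=4),  # assumed shape
      )
  ):
    inference_results.append(inference_result)

  # Stage 2: score the collected inferences with the configured metrics.
  async for eval_case_result in service.evaluate(
      EvaluateRequest(
          inference_results=inference_results,
          evaluate_config=EvaluateConfig(
              eval_metrics=eval_metrics,  # list of EvalMetric
              parallelism=4,  # assumed field, mirroring the inference side
          ),
      )
  ):
    print(eval_case_result.eval_id, eval_case_result.final_eval_status)


# Typically driven with asyncio.run(run_local_eval(agent, manager, metrics)).
```

Streaming both stages keeps memory flat for large eval sets and lets results be saved as soon as each case finishes rather than at the end of the run.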
google/adk/evaluation/local_eval_sets_manager.py:

@@ -27,6 +27,7 @@ from google.genai import types as genai_types
 from pydantic import ValidationError
 from typing_extensions import override
 
+from ..errors.not_found_error import NotFoundError
 from ._eval_sets_manager_utils import add_eval_case_to_eval_set
 from ._eval_sets_manager_utils import delete_eval_case_from_eval_set
 from ._eval_sets_manager_utils import get_eval_case_from_eval_set
@@ -226,16 +227,30 @@ class LocalEvalSetsManager(EvalSetsManager):
 
   @override
   def list_eval_sets(self, app_name: str) -> list[str]:
-    """Returns a list of EvalSets that belong to the given app_name."""
+    """Returns a list of EvalSets that belong to the given app_name.
+
+    Args:
+      app_name: The app name to list the eval sets for.
+
+    Returns:
+      A list of EvalSet ids.
+
+    Raises:
+      NotFoundError: If the eval directory for the app is not found.
+    """
     eval_set_file_path = os.path.join(self._agents_dir, app_name)
     eval_sets = []
-    for file in os.listdir(eval_set_file_path):
-      if file.endswith(_EVAL_SET_FILE_EXTENSION):
-        eval_sets.append(
-            os.path.basename(file).removesuffix(_EVAL_SET_FILE_EXTENSION)
-        )
-
-    return sorted(eval_sets)
+    try:
+      for file in os.listdir(eval_set_file_path):
+        if file.endswith(_EVAL_SET_FILE_EXTENSION):
+          eval_sets.append(
+              os.path.basename(file).removesuffix(_EVAL_SET_FILE_EXTENSION)
+          )
+      return sorted(eval_sets)
+    except FileNotFoundError as e:
+      raise NotFoundError(
+          f"Eval directory for app `{app_name}` not found."
+      ) from e
 
   @override
   def get_eval_case(
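The user-visible effect of this hunk is the error type: a missing `<agents_dir>/<app_name>` directory now surfaces as the ADK's `NotFoundError` instead of a raw `FileNotFoundError` leaking out of `os.listdir`. A small sketch of a caller relying on that; the `agents_dir` constructor argument name is an assumption, only the `self._agents_dir` attribute is visible above.

```python
from google.adk.errors.not_found_error import NotFoundError
from google.adk.evaluation.local_eval_sets_manager import LocalEvalSetsManager

manager = LocalEvalSetsManager(agents_dir='/path/to/agents')  # assumed kwarg name

try:
  eval_set_ids = manager.list_eval_sets(app_name='my_app')
except NotFoundError:
  # Before 1.7.0 a missing /path/to/agents/my_app raised FileNotFoundError;
  # it is now wrapped in the ADK-level NotFoundError.
  eval_set_ids = []
```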
google/adk/flows/llm_flows/auto_flow.py:

@@ -14,6 +14,8 @@
 
 """Implementation of AutoFlow."""
 
+from __future__ import annotations
+
 from . import agent_transfer
 from .single_flow import SingleFlow
 
@@ -29,19 +31,12 @@ class AutoFlow(SingleFlow):
 
   For peer-agent transfers, it's only enabled when all below conditions are met:
 
-  - The parent agent is also …
+  - The parent agent is also an LlmAgent.
   - `disallow_transfer_to_peer` option of this agent is False (default).
 
-  Depending on the target agent …
-  reversed. …
-
-  - If the flow type of the tranferee agent is also auto, transfee agent will
-    remain as the active agent. The transfee agent will respond to the user's
-    next message directly.
-  - If the flow type of the transfere agent is not auto, the active agent will
-    be reversed back to previous agent.
-
-  TODO: allow user to config auto-reverse function.
+  Depending on the target agent type, the transfer may be automatically
+  reversed. (see Runner._find_agent_to_run method for which agent will remain
+  active to handle next user message.)
   """
 
   def __init__(self):
google/adk/flows/llm_flows/base_llm_flow.py:

@@ -283,14 +283,10 @@ class BaseLlmFlow(ABC):
       async for event in self._run_one_step_async(invocation_context):
         last_event = event
         yield event
-      if not last_event or last_event.is_final_response():
+      if not last_event or last_event.is_final_response() or last_event.partial:
+        if last_event and last_event.partial:
+          logger.warning('The last event is partial, which is not expected.')
         break
-      if last_event.partial:
-        # TODO: handle this in BaseLlm level.
-        raise ValueError(
-            f"Last event shouldn't be partial. LLM max output limit may be"
-            f' reached.'
-        )
 
   async def _run_one_step_async(
       self,
@@ -569,21 +565,32 @@ class BaseLlmFlow(ABC):
     if not isinstance(agent, LlmAgent):
       return
 
-    if not agent.canonical_before_model_callbacks:
-      return
-
     callback_context = CallbackContext(
         invocation_context, event_actions=model_response_event.actions
     )
 
+    # First run callbacks from the plugins.
+    callback_response = (
+        await invocation_context.plugin_manager.run_before_model_callback(
+            callback_context=callback_context,
+            llm_request=llm_request,
+        )
+    )
+    if callback_response:
+      return callback_response
+
+    # If no overrides are provided from the plugins, further run the canonical
+    # callbacks.
+    if not agent.canonical_before_model_callbacks:
+      return
     for callback in agent.canonical_before_model_callbacks:
-      …
+      callback_response = callback(
           callback_context=callback_context, llm_request=llm_request
       )
-      if inspect.isawaitable(…
-      …
-      if …
-      return …
+      if inspect.isawaitable(callback_response):
+        callback_response = await callback_response
+      if callback_response:
+        return callback_response
 
   async def _handle_after_model_callback(
       self,
@@ -597,21 +604,32 @@ class BaseLlmFlow(ABC):
     if not isinstance(agent, LlmAgent):
       return
 
-    if not agent.canonical_after_model_callbacks:
-      return
-
     callback_context = CallbackContext(
         invocation_context, event_actions=model_response_event.actions
     )
 
+    # First run callbacks from the plugins.
+    callback_response = (
+        await invocation_context.plugin_manager.run_after_model_callback(
+            callback_context=CallbackContext(invocation_context),
+            llm_response=llm_response,
+        )
+    )
+    if callback_response:
+      return callback_response
+
+    # If no overrides are provided from the plugins, further run the canonical
+    # callbacks.
+    if not agent.canonical_after_model_callbacks:
+      return
     for callback in agent.canonical_after_model_callbacks:
-      …
+      callback_response = callback(
           callback_context=callback_context, llm_response=llm_response
       )
-      if inspect.isawaitable(…
-      …
-      if …
-      return …
+      if inspect.isawaitable(callback_response):
+        callback_response = await callback_response
+      if callback_response:
+        return callback_response
 
   def _finalize_model_response_event(
       self,
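In both callback hunks the plugin layer now runs first: if `plugin_manager.run_before_model_callback` or `run_after_model_callback` returns a response, the agent's canonical callbacks (and, in the before case, the model call itself) are skipped. A rough sketch of a plugin exploiting that precedence, assuming the new `BasePlugin` from `google/adk/plugins/base_plugin.py` (added in this release) exposes an overridable async `before_model_callback` hook taking the keyword arguments seen above; the exact base-class signature is an assumption, not confirmed by this diff.

```python
from typing import Optional

from google.adk.agents.callback_context import CallbackContext
from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse
from google.adk.plugins.base_plugin import BasePlugin  # new module in 1.7.0


class GuardrailPlugin(BasePlugin):
  """Hypothetical plugin: may veto a model call before agent callbacks run."""

  async def before_model_callback(
      self, *, callback_context: CallbackContext, llm_request: LlmRequest
  ) -> Optional[LlmResponse]:
    # Assumed hook signature, mirroring the keyword arguments that
    # plugin_manager.run_before_model_callback receives in the hunk above.
    if callback_context.state.get('blocked', False):
      # A non-None return short-circuits the flow: the agent's canonical
      # before_model_callbacks and the model call itself are skipped.
      return LlmResponse()
    return None  # fall through to canonical callbacks and the real model call
```

Presumably such plugins are registered once on the `Runner` (runners.py is also touched in this release) and then apply to every agent in the tree, which is the point of running them ahead of the per-agent callbacks.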
google/adk/flows/llm_flows/contents.py:

@@ -157,12 +157,21 @@ def _rearrange_events_for_latest_function_response(
     for function_call in function_calls:
       if function_call.id in function_responses_ids:
         function_call_event_idx = idx
-        …
-        …
-        …
-        # …
-        …
-        …
+        function_call_ids = {
+            function_call.id for function_call in function_calls
+        }
+        # last response event should only contain the responses for the
+        # function calls in the same function call event
+        if not function_responses_ids.issubset(function_call_ids):
+          raise ValueError(
+              'Last response event should only contain the responses for the'
+              ' function calls in the same function call event. Function'
+              f' call ids found : {function_call_ids}, function response'
+              f' ids provided: {function_responses_ids}'
+          )
+        # collect all function responses from the function call event to
+        # the last response event
+        function_responses_ids = function_call_ids
         break
 
   if function_call_event_idx == -1:
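The added guard enforces an invariant rather than changing the rearrangement itself: every id in the trailing function-response event must belong to the matched function-call event, after which the working id set widens to that event's full call-id set so all of its responses get collected. The check in isolation, with plain sets standing in for the real `Event` objects (values illustrative):

```python
# Ids issued by one function_call event vs. ids present in the latest
# function_response event (illustrative values, not real ADK events).
function_call_ids = {'call_1', 'call_2'}
function_responses_ids = {'call_2'}  # OK: a subset of the call event's ids

if not function_responses_ids.issubset(function_call_ids):
  # e.g. {'call_2', 'call_9'} now fails fast instead of being rearranged
  # against the wrong function_call event.
  raise ValueError(
      'Last response event should only contain the responses for the'
      ' function calls in the same function call event.'
  )

# After the check, responses for *all* calls in that event are collected, so
# the working id set widens to the full call id set.
function_responses_ids = function_call_ids
```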
@@ -363,10 +372,7 @@ def _merge_function_response_events(
       list is in increasing order of timestamp; 2. the first event is the
       initial function_response event; 3. all later events should contain at
       least one function_response part that related to the function_call
-      event. …
-      intermediate response, there could also be some intermediate model
-      response event without any function_response and such event will be
-      ignored.)
+      event.
   Caveat: This implementation doesn't support when a parallel function_call
     event contains async function_call of the same name.
 