google-adk 1.6.1__py3-none-any.whl → 1.7.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. google/adk/a2a/converters/event_converter.py +5 -85
  2. google/adk/a2a/executor/a2a_agent_executor.py +45 -16
  3. google/adk/agents/__init__.py +5 -0
  4. google/adk/agents/agent_config.py +46 -0
  5. google/adk/agents/base_agent.py +234 -41
  6. google/adk/agents/callback_context.py +41 -0
  7. google/adk/agents/common_configs.py +79 -0
  8. google/adk/agents/config_agent_utils.py +184 -0
  9. google/adk/agents/config_schemas/AgentConfig.json +544 -0
  10. google/adk/agents/invocation_context.py +5 -1
  11. google/adk/agents/llm_agent.py +190 -9
  12. google/adk/agents/loop_agent.py +29 -0
  13. google/adk/agents/parallel_agent.py +24 -3
  14. google/adk/agents/remote_a2a_agent.py +15 -3
  15. google/adk/agents/sequential_agent.py +22 -1
  16. google/adk/artifacts/gcs_artifact_service.py +24 -2
  17. google/adk/auth/auth_handler.py +3 -3
  18. google/adk/auth/credential_manager.py +23 -23
  19. google/adk/auth/credential_service/base_credential_service.py +6 -6
  20. google/adk/auth/credential_service/in_memory_credential_service.py +10 -8
  21. google/adk/auth/credential_service/session_state_credential_service.py +8 -8
  22. google/adk/auth/exchanger/oauth2_credential_exchanger.py +3 -3
  23. google/adk/auth/oauth2_credential_util.py +2 -2
  24. google/adk/auth/refresher/oauth2_credential_refresher.py +4 -4
  25. google/adk/cli/agent_graph.py +3 -1
  26. google/adk/cli/browser/index.html +1 -1
  27. google/adk/cli/browser/main-SRBSE46V.js +3914 -0
  28. google/adk/cli/browser/polyfills-B6TNHZQ6.js +17 -0
  29. google/adk/cli/fast_api.py +42 -2
  30. google/adk/cli/utils/agent_loader.py +35 -1
  31. google/adk/code_executors/base_code_executor.py +14 -19
  32. google/adk/code_executors/built_in_code_executor.py +4 -1
  33. google/adk/evaluation/base_eval_service.py +46 -2
  34. google/adk/evaluation/evaluation_generator.py +1 -1
  35. google/adk/evaluation/in_memory_eval_sets_manager.py +151 -0
  36. google/adk/evaluation/local_eval_service.py +389 -0
  37. google/adk/evaluation/local_eval_sets_manager.py +23 -8
  38. google/adk/flows/llm_flows/auto_flow.py +6 -11
  39. google/adk/flows/llm_flows/base_llm_flow.py +41 -23
  40. google/adk/flows/llm_flows/contents.py +16 -10
  41. google/adk/flows/llm_flows/functions.py +76 -33
  42. google/adk/memory/in_memory_memory_service.py +20 -14
  43. google/adk/models/anthropic_llm.py +44 -5
  44. google/adk/models/google_llm.py +11 -6
  45. google/adk/models/lite_llm.py +21 -4
  46. google/adk/plugins/__init__.py +17 -0
  47. google/adk/plugins/base_plugin.py +317 -0
  48. google/adk/plugins/plugin_manager.py +265 -0
  49. google/adk/runners.py +122 -18
  50. google/adk/sessions/database_session_service.py +26 -28
  51. google/adk/sessions/vertex_ai_session_service.py +14 -7
  52. google/adk/tools/agent_tool.py +1 -0
  53. google/adk/tools/apihub_tool/apihub_toolset.py +38 -39
  54. google/adk/tools/application_integration_tool/application_integration_toolset.py +35 -37
  55. google/adk/tools/application_integration_tool/integration_connector_tool.py +2 -3
  56. google/adk/tools/base_tool.py +9 -9
  57. google/adk/tools/base_toolset.py +7 -5
  58. google/adk/tools/bigquery/__init__.py +3 -3
  59. google/adk/tools/enterprise_search_tool.py +4 -2
  60. google/adk/tools/google_api_tool/google_api_tool.py +16 -1
  61. google/adk/tools/google_api_tool/google_api_toolset.py +9 -7
  62. google/adk/tools/google_api_tool/google_api_toolsets.py +41 -20
  63. google/adk/tools/google_search_tool.py +4 -2
  64. google/adk/tools/langchain_tool.py +2 -3
  65. google/adk/tools/long_running_tool.py +21 -0
  66. google/adk/tools/mcp_tool/mcp_toolset.py +27 -28
  67. google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py +8 -8
  68. google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py +4 -6
  69. google/adk/tools/retrieval/vertex_ai_rag_retrieval.py +3 -2
  70. google/adk/tools/tool_context.py +0 -10
  71. google/adk/tools/url_context_tool.py +4 -2
  72. google/adk/tools/vertex_ai_search_tool.py +4 -2
  73. google/adk/utils/model_name_utils.py +90 -0
  74. google/adk/version.py +1 -1
  75. {google_adk-1.6.1.dist-info → google_adk-1.7.0.dist-info}/METADATA +2 -2
  76. {google_adk-1.6.1.dist-info → google_adk-1.7.0.dist-info}/RECORD +79 -69
  77. google/adk/cli/browser/main-RXDVX3K6.js +0 -3914
  78. google/adk/cli/browser/polyfills-FFHMD2TL.js +0 -17
  79. {google_adk-1.6.1.dist-info → google_adk-1.7.0.dist-info}/WHEEL +0 -0
  80. {google_adk-1.6.1.dist-info → google_adk-1.7.0.dist-info}/entry_points.txt +0 -0
  81. {google_adk-1.6.1.dist-info → google_adk-1.7.0.dist-info}/licenses/LICENSE +0 -0
google/adk/evaluation/local_eval_service.py (new file)
@@ -0,0 +1,389 @@
+# Copyright 2025 Google LLC
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import annotations
+
+import asyncio
+import inspect
+import logging
+from typing import AsyncGenerator
+from typing import Callable
+from typing import Optional
+import uuid
+
+from typing_extensions import override
+
+from ..agents import BaseAgent
+from ..artifacts.base_artifact_service import BaseArtifactService
+from ..artifacts.in_memory_artifact_service import InMemoryArtifactService
+from ..errors.not_found_error import NotFoundError
+from ..sessions.base_session_service import BaseSessionService
+from ..sessions.in_memory_session_service import InMemorySessionService
+from ..utils.feature_decorator import working_in_progress
+from .base_eval_service import BaseEvalService
+from .base_eval_service import EvaluateConfig
+from .base_eval_service import EvaluateRequest
+from .base_eval_service import InferenceRequest
+from .base_eval_service import InferenceResult
+from .base_eval_service import InferenceStatus
+from .eval_case import Invocation
+from .eval_metrics import EvalMetric
+from .eval_metrics import EvalMetricResult
+from .eval_metrics import EvalMetricResultPerInvocation
+from .eval_result import EvalCaseResult
+from .eval_set import EvalCase
+from .eval_set_results_manager import EvalSetResultsManager
+from .eval_sets_manager import EvalSetsManager
+from .evaluation_generator import EvaluationGenerator
+from .evaluator import EvalStatus
+from .evaluator import EvaluationResult
+from .metric_evaluator_registry import DEFAULT_METRIC_EVALUATOR_REGISTRY
+from .metric_evaluator_registry import MetricEvaluatorRegistry
+
+logger = logging.getLogger('google_adk.' + __name__)
+
+EVAL_SESSION_ID_PREFIX = '___eval___session___'
+
+
+def _get_session_id() -> str:
+  return f'{EVAL_SESSION_ID_PREFIX}{str(uuid.uuid4())}'
+
+
+@working_in_progress("Incomplete feature, don't use yet")
+class LocalEvalService(BaseEvalService):
+  """An implementation of BaseEvalService that runs the evals locally."""
+
+  def __init__(
+      self,
+      root_agent: BaseAgent,
+      eval_sets_manager: EvalSetsManager,
+      metric_evaluator_registry: MetricEvaluatorRegistry = DEFAULT_METRIC_EVALUATOR_REGISTRY,
+      session_service: BaseSessionService = InMemorySessionService(),
+      artifact_service: BaseArtifactService = InMemoryArtifactService(),
+      eval_set_results_manager: Optional[EvalSetResultsManager] = None,
+      session_id_supplier: Callable[[], str] = _get_session_id,
+  ):
+    self._root_agent = root_agent
+    self._eval_sets_manager = eval_sets_manager
+    self._metric_evaluator_registry = metric_evaluator_registry
+    self._session_service = session_service
+    self._artifact_service = artifact_service
+    self._eval_set_results_manager = eval_set_results_manager
+    self._session_id_supplier = session_id_supplier
+
+  @override
+  async def perform_inference(
+      self,
+      inference_request: InferenceRequest,
+  ) -> AsyncGenerator[InferenceResult, None]:
+    """Yields InferenceResults obtained from the agent as they become available.
+
+    Args:
+      inference_request: The request for generating inferences.
+    """
+    # Get the eval set from the storage.
+    eval_set = self._eval_sets_manager.get_eval_set(
+        app_name=inference_request.app_name,
+        eval_set_id=inference_request.eval_set_id,
+    )
+
+    if not eval_set:
+      raise NotFoundError(
+          f'Eval set with id {inference_request.eval_set_id} not found for app'
+          f' {inference_request.app_name}'
+      )
+
+    # Select eval cases for which we need to run inferencing. If the inference
+    # request specified eval cases, then we use only those.
+    eval_cases = eval_set.eval_cases
+    if inference_request.eval_case_ids:
+      eval_cases = [
+          eval_case
+          for eval_case in eval_cases
+          if eval_case.eval_id in inference_request.eval_case_ids
+      ]
+
+    root_agent = self._root_agent.clone()
+
+    semaphore = asyncio.Semaphore(
+        value=inference_request.inference_config.parallelism
+    )
+
+    async def run_inference(eval_case):
+      async with semaphore:
+        return await self._perform_inference_single_eval_item(
+            app_name=inference_request.app_name,
+            eval_set_id=inference_request.eval_set_id,
+            eval_case=eval_case,
+            root_agent=root_agent,
+        )
+
+    inference_results = [run_inference(eval_case) for eval_case in eval_cases]
+    for inference_result in asyncio.as_completed(inference_results):
+      yield await inference_result
+
+  @override
+  async def evaluate(
+      self,
+      evaluate_request: EvaluateRequest,
+  ) -> AsyncGenerator[EvalCaseResult, None]:
+    """Yields an EvalCaseResult for each item as they become available.
+
+    Args:
+      evaluate_request: The request to perform metric evaluations on the
+        inferences.
+    """
+    semaphore = asyncio.Semaphore(
+        value=evaluate_request.evaluate_config.parallelism
+    )
+
+    async def run_evaluation(inference_result):
+      async with semaphore:
+        return await self._evaluate_single_inference_result(
+            inference_result=inference_result,
+            evaluate_config=evaluate_request.evaluate_config,
+        )
+
+    evaluation_tasks = [
+        run_evaluation(inference_result)
+        for inference_result in evaluate_request.inference_results
+    ]
+
+    for evaluation_task in asyncio.as_completed(evaluation_tasks):
+      inference_result, eval_case_result = await evaluation_task
+
+      if self._eval_set_results_manager:
+        self._eval_set_results_manager.save_eval_set_result(
+            app_name=inference_result.app_name,
+            eval_set_id=inference_result.eval_set_id,
+            eval_case_results=[eval_case_result],
+        )
+
+      yield eval_case_result
+
+  async def _evaluate_single_inference_result(
+      self, inference_result: InferenceResult, evaluate_config: EvaluateConfig
+  ) -> tuple[InferenceResult, EvalCaseResult]:
+    """Returns an EvalCaseResult for the given inference result.
+
+    A single inference result can have multiple invocations. For each
+    invocation, this method evaluates the metrics present in evaluate config.
+
+    The EvalCaseResult contains scores for each metric per invocation and the
+    overall score.
+    """
+    eval_case = self._eval_sets_manager.get_eval_case(
+        app_name=inference_result.app_name,
+        eval_set_id=inference_result.eval_set_id,
+        eval_case_id=inference_result.eval_case_id,
+    )
+
+    if eval_case is None:
+      raise NotFoundError(
+          f'Eval case with id {inference_result.eval_case_id} not found for'
+          f' app {inference_result.app_name} and eval set'
+          f' {inference_result.eval_set_id}.'
+      )
+
+    # Metric results for each invocation.
+    eval_metric_result_per_invocation = []
+
+    # We also keep track of the overall score for a metric, derived from all
+    # invocations. For example, for a metric that compares how well the final
+    # response matches a golden answer, each invocation will have a value for
+    # this metric. We will also have an overall score using an aggregation
+    # strategy across all invocations. This would be the score for the eval
+    # case.
+    overall_eval_metric_results = []
+
+    if len(inference_result.inferences) != len(eval_case.conversation):
+      raise ValueError(
+          'Inferences should match conversations in eval case. Found'
+          f' {len(inference_result.inferences)} inferences and'
+          f' {len(eval_case.conversation)} conversations in eval cases.'
+      )
+
+    # Pre-create the EvalMetricResult entries for each invocation.
+    for actual, expected in zip(
+        inference_result.inferences, eval_case.conversation
+    ):
+      eval_metric_result_per_invocation.append(
+          EvalMetricResultPerInvocation(
+              actual_invocation=actual,
+              expected_invocation=expected,
+              # We will fill this as we evaluate each metric per invocation.
+              eval_metric_results=[],
+          )
+      )
+
+    for eval_metric in evaluate_config.eval_metrics:
+      # Perform evaluation of the metric.
+      evaluation_result = await self._evaluate_metric(
+          eval_metric=eval_metric,
+          actual_invocations=inference_result.inferences,
+          expected_invocations=eval_case.conversation,
+      )
+
+      # Track the overall score across all invocations.
+      overall_eval_metric_results.append(
+          EvalMetricResult(
+              metric_name=eval_metric.metric_name,
+              threshold=eval_metric.threshold,
+              score=evaluation_result.overall_score,
+              eval_status=evaluation_result.overall_eval_status,
+          )
+      )
+
+      if len(evaluation_result.per_invocation_results) != len(
+          eval_metric_result_per_invocation
+      ):
+        raise ValueError(
+            'Eval metric should return results for each invocation. Found '
+            f'{len(evaluation_result.per_invocation_results)} results for '
+            f'{len(eval_metric_result_per_invocation)} invocations.'
+        )
+
+      # Track the score across individual invocations.
+      for invocation_result, invocation in zip(
+          evaluation_result.per_invocation_results,
+          eval_metric_result_per_invocation,
+      ):
+        invocation.eval_metric_results.append(
+            EvalMetricResult(
+                metric_name=eval_metric.metric_name,
+                threshold=eval_metric.threshold,
+                score=invocation_result.score,
+                eval_status=invocation_result.eval_status,
+            )
+        )
+
+    final_eval_status = self._generate_final_eval_status(
+        overall_eval_metric_results
+    )
+    user_id = (
+        eval_case.session_input.user_id
+        if eval_case.session_input and eval_case.session_input.user_id
+        else 'test_user_id'
+    )
+
+    eval_case_result = EvalCaseResult(
+        eval_set_file=inference_result.eval_set_id,
+        eval_set_id=inference_result.eval_set_id,
+        eval_id=inference_result.eval_case_id,
+        final_eval_status=final_eval_status,
+        overall_eval_metric_results=overall_eval_metric_results,
+        eval_metric_result_per_invocation=eval_metric_result_per_invocation,
+        session_id=inference_result.session_id,
+        session_details=await self._session_service.get_session(
+            app_name=inference_result.app_name,
+            user_id=user_id,
+            session_id=inference_result.session_id,
+        ),
+        user_id=user_id,
+    )
+
+    return (inference_result, eval_case_result)
+
+  async def _evaluate_metric(
+      self,
+      eval_metric: EvalMetric,
+      actual_invocations: list[Invocation],
+      expected_invocations: list[Invocation],
+  ) -> EvaluationResult:
+    """Returns EvaluationResult obtained from evaluating a metric using an Evaluator."""
+
+    # Get the metric evaluator from the registry.
+    metric_evaluator = self._metric_evaluator_registry.get_evaluator(
+        eval_metric=eval_metric
+    )
+
+    if inspect.iscoroutinefunction(metric_evaluator.evaluate_invocations):
+      # Some evaluators could be async, for example those that use an LLM as a
+      # judge, so we need to make sure that we await them.
+      return await metric_evaluator.evaluate_invocations(
+          actual_invocations=actual_invocations,
+          expected_invocations=expected_invocations,
+      )
+    else:
+      # Metrics that perform computation synchronously; mostly these don't
+      # perform any I/O. An example would be calculation of the rouge_1 score.
+      return metric_evaluator.evaluate_invocations(
+          actual_invocations=actual_invocations,
+          expected_invocations=expected_invocations,
+      )
+
+  def _generate_final_eval_status(
+      self, overall_eval_metric_results: list[EvalMetricResult]
+  ) -> EvalStatus:
+    final_eval_status = EvalStatus.NOT_EVALUATED
+    # Go over all the eval statuses and mark the final eval status as passed
+    # only if all of them pass; otherwise mark the final eval status as
+    # failed.
+    for overall_eval_metric_result in overall_eval_metric_results:
+      overall_eval_status = overall_eval_metric_result.eval_status
+      if overall_eval_status == EvalStatus.PASSED:
+        final_eval_status = EvalStatus.PASSED
+      elif overall_eval_status == EvalStatus.NOT_EVALUATED:
+        continue
+      elif overall_eval_status == EvalStatus.FAILED:
+        final_eval_status = EvalStatus.FAILED
+        break
+      else:
+        raise ValueError(f'Unknown eval status: {overall_eval_status}.')
+
+    return final_eval_status
+
+  async def _perform_inference_single_eval_item(
+      self,
+      app_name: str,
+      eval_set_id: str,
+      eval_case: EvalCase,
+      root_agent: BaseAgent,
+  ) -> InferenceResult:
+    initial_session = eval_case.session_input
+    session_id = self._session_id_supplier()
+    inference_result = InferenceResult(
+        app_name=app_name,
+        eval_set_id=eval_set_id,
+        eval_case_id=eval_case.eval_id,
+        session_id=session_id,
+    )
+
+    try:
+      inferences = (
+          await EvaluationGenerator._generate_inferences_from_root_agent(
+              invocations=eval_case.conversation,
+              root_agent=root_agent,
+              initial_session=initial_session,
+              session_id=session_id,
+              session_service=self._session_service,
+              artifact_service=self._artifact_service,
+          )
+      )
+
+      inference_result.inferences = inferences
+      inference_result.status = InferenceStatus.SUCCESS
+
+      return inference_result
+    except Exception as e:
+      # We intentionally catch broad Exceptions as we don't want failures to
+      # affect other inferences.
+      logger.error(
+          'Inference failed for eval case `%s` with error %s',
+          eval_case.eval_id,
+          e,
+      )
+      inference_result.status = InferenceStatus.FAILURE
+      inference_result.error_message = str(e)
+      return inference_result
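
Taken together, the service is driven in two stages: stream InferenceResults out of perform_inference, then feed them to evaluate. Below is a minimal sketch of that flow. Note the class is gated behind @working_in_progress, and the app name, eval set id, the agents_dir parameter, and the default-constructed InferenceConfig are illustrative assumptions rather than confirmed API.

import asyncio

from google.adk.evaluation.base_eval_service import EvaluateConfig
from google.adk.evaluation.base_eval_service import EvaluateRequest
from google.adk.evaluation.base_eval_service import InferenceConfig
from google.adk.evaluation.base_eval_service import InferenceRequest
from google.adk.evaluation.local_eval_service import LocalEvalService
from google.adk.evaluation.local_eval_sets_manager import LocalEvalSetsManager


async def run_local_eval(root_agent, eval_metrics):
  # Hypothetical wiring: assumes an eval set 'my_eval_set' already exists
  # under ./agents/my_app, and that InferenceConfig/EvaluateConfig carry the
  # parallelism and eval_metrics fields the service reads above.
  service = LocalEvalService(
      root_agent=root_agent,
      eval_sets_manager=LocalEvalSetsManager(agents_dir='./agents'),
  )

  # Stage 1: run the agent against every eval case in the set.
  inference_results = []
  async for inference_result in service.perform_inference(
      InferenceRequest(
          app_name='my_app',
          eval_set_id='my_eval_set',
          inference_config=InferenceConfig(),
      )
  ):
    inference_results.append(inference_result)

  # Stage 2: score each inference against the expected conversation.
  async for eval_case_result in service.evaluate(
      EvaluateRequest(
          inference_results=inference_results,
          evaluate_config=EvaluateConfig(eval_metrics=eval_metrics),
      )
  ):
    print(eval_case_result.eval_id, eval_case_result.final_eval_status)

Both stages bound concurrency with an asyncio.Semaphore sized by the request's parallelism setting and drain tasks via asyncio.as_completed, which is why results arrive in completion order rather than eval-case order.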
google/adk/evaluation/local_eval_sets_manager.py
@@ -27,6 +27,7 @@ from google.genai import types as genai_types
 from pydantic import ValidationError
 from typing_extensions import override
 
+from ..errors.not_found_error import NotFoundError
 from ._eval_sets_manager_utils import add_eval_case_to_eval_set
 from ._eval_sets_manager_utils import delete_eval_case_from_eval_set
 from ._eval_sets_manager_utils import get_eval_case_from_eval_set
@@ -226,16 +227,30 @@ class LocalEvalSetsManager(EvalSetsManager):
 
   @override
   def list_eval_sets(self, app_name: str) -> list[str]:
-    """Returns a list of EvalSets that belong to the given app_name."""
+    """Returns a list of EvalSets that belong to the given app_name.
+
+    Args:
+      app_name: The app name to list the eval sets for.
+
+    Returns:
+      A list of EvalSet ids.
+
+    Raises:
+      NotFoundError: If the eval directory for the app is not found.
+    """
     eval_set_file_path = os.path.join(self._agents_dir, app_name)
     eval_sets = []
-    for file in os.listdir(eval_set_file_path):
-      if file.endswith(_EVAL_SET_FILE_EXTENSION):
-        eval_sets.append(
-            os.path.basename(file).removesuffix(_EVAL_SET_FILE_EXTENSION)
-        )
-
-    return sorted(eval_sets)
+    try:
+      for file in os.listdir(eval_set_file_path):
+        if file.endswith(_EVAL_SET_FILE_EXTENSION):
+          eval_sets.append(
+              os.path.basename(file).removesuffix(_EVAL_SET_FILE_EXTENSION)
+          )
+      return sorted(eval_sets)
+    except FileNotFoundError as e:
+      raise NotFoundError(
+          f"Eval directory for app `{app_name}` not found."
+      ) from e
 
   @override
   def get_eval_case(
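
The visible behavioral change here: a missing per-app eval directory now surfaces as ADK's own NotFoundError instead of leaking a raw FileNotFoundError. A small caller-side sketch, assuming the manager is constructed with an agents_dir (parameter name inferred from self._agents_dir above) that has no folder for the app:

from google.adk.errors.not_found_error import NotFoundError
from google.adk.evaluation.local_eval_sets_manager import LocalEvalSetsManager

manager = LocalEvalSetsManager(agents_dir='./agents')  # path is illustrative
try:
  eval_set_ids = manager.list_eval_sets('app_with_no_eval_dir')
except NotFoundError as e:
  # As of 1.7.0, raised instead of FileNotFoundError when
  # ./agents/<app_name> does not exist.
  print(f'No eval directory yet: {e}')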
google/adk/flows/llm_flows/auto_flow.py
@@ -14,6 +14,8 @@
 
 """Implementation of AutoFlow."""
 
+from __future__ import annotations
+
 from . import agent_transfer
 from .single_flow import SingleFlow
 
@@ -29,19 +31,12 @@ class AutoFlow(SingleFlow):
 
   For peer-agent transfers, it's only enabled when all below conditions are met:
 
-  - The parent agent is also of AutoFlow;
+  - The parent agent is also an LlmAgent.
   - `disallow_transfer_to_peer` option of this agent is False (default).
 
-  Depending on the target agent flow type, the transfer may be automatically
-  reversed. The condition is as below:
-
-  - If the flow type of the tranferee agent is also auto, transfee agent will
-    remain as the active agent. The transfee agent will respond to the user's
-    next message directly.
-  - If the flow type of the transfere agent is not auto, the active agent will
-    be reversed back to previous agent.
-
-  TODO: allow user to config auto-reverse function.
+  Depending on the target agent type, the transfer may be automatically
+  reversed. (See Runner._find_agent_to_run for which agent remains active to
+  handle the next user message.)
   """
 
   def __init__(self):
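
To make the two conditions concrete, here is a sketch of the agent tree the docstring describes: two children under a common LlmAgent parent, so each child can transfer to its peer. Model names and instructions are placeholders, and note the flag is believed to appear on LlmAgent as disallow_transfer_to_peers (plural), while the docstring spells it disallow_transfer_to_peer.

from google.adk.agents import LlmAgent

billing = LlmAgent(
    name='billing',
    model='gemini-2.0-flash',  # placeholder model name
    instruction='Handle billing questions.',
)
support = LlmAgent(
    name='support',
    model='gemini-2.0-flash',
    instruction='Handle support questions.',
    # Condition 2: leaving this False (the default) keeps peer transfer on.
    disallow_transfer_to_peers=False,
)
# Condition 1: both children share an LlmAgent parent, so each can be
# transferred to the other as a peer.
root = LlmAgent(
    name='coordinator',
    model='gemini-2.0-flash',
    instruction='Route the user to the right specialist.',
    sub_agents=[support, billing],
)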
google/adk/flows/llm_flows/base_llm_flow.py
@@ -283,14 +283,10 @@ class BaseLlmFlow(ABC):
       async for event in self._run_one_step_async(invocation_context):
         last_event = event
         yield event
-      if not last_event or last_event.is_final_response():
+      if not last_event or last_event.is_final_response() or last_event.partial:
+        if last_event and last_event.partial:
+          logger.warning('The last event is partial, which is not expected.')
         break
-      if last_event.partial:
-        # TODO: handle this in BaseLlm level.
-        raise ValueError(
-            f"Last event shouldn't be partial. LLM max output limit may be"
-            f' reached.'
-        )
 
   async def _run_one_step_async(
       self,
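
Since a partial trailing event now ends the loop with a logged warning instead of a raised ValueError, callers that relied on that exception may want to check for truncation themselves. A hedged consumer-side sketch, assuming an already configured Runner:

from google.genai import types


async def ask(runner, user_id: str, session_id: str, text: str) -> None:
  message = types.Content(role='user', parts=[types.Part(text=text)])
  last_event = None
  async for event in runner.run_async(
      user_id=user_id, session_id=session_id, new_message=message
  ):
    last_event = event
  if last_event is not None and last_event.partial:
    # In 1.7.0 this case is logged as a warning rather than raised as a
    # ValueError; it usually means the model hit its max output limit.
    print('Response was truncated; consider retrying or raising the limit.')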
@@ -569,21 +565,32 @@ class BaseLlmFlow(ABC):
     if not isinstance(agent, LlmAgent):
       return
 
-    if not agent.canonical_before_model_callbacks:
-      return
-
     callback_context = CallbackContext(
         invocation_context, event_actions=model_response_event.actions
     )
 
+    # First run callbacks from the plugins.
+    callback_response = (
+        await invocation_context.plugin_manager.run_before_model_callback(
+            callback_context=callback_context,
+            llm_request=llm_request,
+        )
+    )
+    if callback_response:
+      return callback_response
+
+    # If no overrides are provided from the plugins, further run the canonical
+    # callbacks.
+    if not agent.canonical_before_model_callbacks:
+      return
     for callback in agent.canonical_before_model_callbacks:
-      before_model_callback_content = callback(
+      callback_response = callback(
          callback_context=callback_context, llm_request=llm_request
      )
-      if inspect.isawaitable(before_model_callback_content):
-        before_model_callback_content = await before_model_callback_content
-      if before_model_callback_content:
-        return before_model_callback_content
+      if inspect.isawaitable(callback_response):
+        callback_response = await callback_response
+      if callback_response:
+        return callback_response
 
   async def _handle_after_model_callback(
       self,
@@ -597,21 +604,32 @@ class BaseLlmFlow(ABC):
     if not isinstance(agent, LlmAgent):
       return
 
-    if not agent.canonical_after_model_callbacks:
-      return
-
     callback_context = CallbackContext(
         invocation_context, event_actions=model_response_event.actions
     )
 
+    # First run callbacks from the plugins.
+    callback_response = (
+        await invocation_context.plugin_manager.run_after_model_callback(
+            callback_context=CallbackContext(invocation_context),
+            llm_response=llm_response,
+        )
+    )
+    if callback_response:
+      return callback_response
+
+    # If no overrides are provided from the plugins, further run the canonical
+    # callbacks.
+    if not agent.canonical_after_model_callbacks:
+      return
     for callback in agent.canonical_after_model_callbacks:
-      after_model_callback_content = callback(
+      callback_response = callback(
          callback_context=callback_context, llm_response=llm_response
      )
-      if inspect.isawaitable(after_model_callback_content):
-        after_model_callback_content = await after_model_callback_content
-      if after_model_callback_content:
-        return after_model_callback_content
+      if inspect.isawaitable(callback_response):
+        callback_response = await callback_response
+      if callback_response:
+        return callback_response
 
   def _finalize_model_response_event(
       self,
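
The ordering in both hooks is the interesting part: plugin callbacks run first, and a non-None return short-circuits the canonical per-agent callbacks entirely. A sketch of a plugin that exploits this follows, assuming the new google.adk.plugins package exposes BasePlugin with an async before_model_callback hook whose signature mirrors how run_before_model_callback is invoked above (inferred, not confirmed):

from typing import Optional

from google.adk.agents.callback_context import CallbackContext
from google.adk.models.llm_request import LlmRequest
from google.adk.models.llm_response import LlmResponse
from google.adk.plugins import BasePlugin
from google.genai import types


class BlocklistPlugin(BasePlugin):
  """Short-circuits the model call when the request contains a banned word."""

  async def before_model_callback(
      self, *, callback_context: CallbackContext, llm_request: LlmRequest
  ) -> Optional[LlmResponse]:
    text = None
    if llm_request.contents and llm_request.contents[-1].parts:
      text = llm_request.contents[-1].parts[0].text
    if text and 'forbidden' in text:
      # A non-None return is treated as the model response: the canonical
      # per-agent callbacks and the real model call are both skipped.
      return LlmResponse(
          content=types.Content(
              role='model',
              parts=[types.Part(text='I cannot help with that request.')],
          )
      )
    return None  # Fall through to canonical callbacks, then the model.

Such a plugin would presumably be registered once on the Runner (runners.py gains plugin wiring in this release), applying across all agents without touching per-agent callbacks.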
google/adk/flows/llm_flows/contents.py
@@ -157,12 +157,21 @@ def _rearrange_events_for_latest_function_response(
     for function_call in function_calls:
       if function_call.id in function_responses_ids:
         function_call_event_idx = idx
-        break
-  if function_call_event_idx != -1:
-    # in case the last response event only have part of the responses
-    # for the function calls in the function call event
-    for function_call in function_calls:
-      function_responses_ids.add(function_call.id)
+        function_call_ids = {
+            function_call.id for function_call in function_calls
+        }
+        # last response event should only contain the responses for the
+        # function calls in the same function call event
+        if not function_responses_ids.issubset(function_call_ids):
+          raise ValueError(
+              'Last response event should only contain the responses for the'
+              ' function calls in the same function call event. Function'
+              f' call ids found: {function_call_ids}, function response'
+              f' ids provided: {function_responses_ids}'
+          )
+        # collect all function responses from the function call event to
+        # the last response event
+        function_responses_ids = function_call_ids
         break
 
   if function_call_event_idx == -1:
@@ -363,10 +372,7 @@ def _merge_function_response_events(
       list is in increasing order of timestamp; 2. the first event is the
       initial function_response event; 3. all later events should contain at
       least one function_response part that related to the function_call
-      event. (Note, 3. may not be true when aync function return some
-      intermediate response, there could also be some intermediate model
-      response event without any function_response and such event will be
-      ignored.)
+      event.
   Caveat: This implementation doesn't support when a parallel function_call
     event contains async function_call of the same name.
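
Returning to the rearrangement change above: the new guard is a plain subset test over function-call ids. With hypothetical ids, the case that now raises looks like this:

# Hypothetical ids emitted by the matched function-call event.
function_call_ids = {'call_1', 'call_2'}
# Hypothetical ids carried by the latest function-response event.
function_responses_ids = {'call_2', 'call_9'}

# 'call_9' answers a call from a different event, so the subset check fails
# and 1.7.0 raises ValueError, where 1.6.1 silently widened the id set.
assert not function_responses_ids.issubset(function_call_ids)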