google-adk 1.4.1__py3-none-any.whl → 1.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (35)
  1. google/adk/a2a/converters/event_converter.py +382 -0
  2. google/adk/a2a/converters/part_converter.py +4 -2
  3. google/adk/a2a/converters/request_converter.py +90 -0
  4. google/adk/a2a/converters/utils.py +71 -0
  5. google/adk/agents/llm_agent.py +5 -3
  6. google/adk/artifacts/gcs_artifact_service.py +3 -2
  7. google/adk/auth/auth_tool.py +2 -2
  8. google/adk/auth/credential_service/session_state_credential_service.py +83 -0
  9. google/adk/cli/cli_deploy.py +9 -2
  10. google/adk/cli/cli_tools_click.py +110 -52
  11. google/adk/cli/fast_api.py +26 -2
  12. google/adk/cli/utils/evals.py +53 -0
  13. google/adk/evaluation/final_response_match_v1.py +110 -0
  14. google/adk/evaluation/gcs_eval_sets_manager.py +8 -5
  15. google/adk/evaluation/response_evaluator.py +12 -1
  16. google/adk/events/event.py +5 -5
  17. google/adk/flows/llm_flows/contents.py +49 -4
  18. google/adk/flows/llm_flows/functions.py +32 -0
  19. google/adk/memory/__init__.py +3 -1
  20. google/adk/memory/vertex_ai_memory_bank_service.py +150 -0
  21. google/adk/models/lite_llm.py +9 -1
  22. google/adk/runners.py +10 -0
  23. google/adk/sessions/vertex_ai_session_service.py +70 -19
  24. google/adk/telemetry.py +10 -0
  25. google/adk/tools/bigquery/bigquery_credentials.py +28 -11
  26. google/adk/tools/bigquery/bigquery_tool.py +1 -1
  27. google/adk/tools/bigquery/client.py +1 -1
  28. google/adk/tools/bigquery/metadata_tool.py +1 -1
  29. google/adk/tools/bigquery/query_tool.py +1 -1
  30. google/adk/version.py +1 -1
  31. {google_adk-1.4.1.dist-info → google_adk-1.5.0.dist-info}/METADATA +6 -5
  32. {google_adk-1.4.1.dist-info → google_adk-1.5.0.dist-info}/RECORD +35 -29
  33. {google_adk-1.4.1.dist-info → google_adk-1.5.0.dist-info}/WHEEL +0 -0
  34. {google_adk-1.4.1.dist-info → google_adk-1.5.0.dist-info}/entry_points.txt +0 -0
  35. {google_adk-1.4.1.dist-info → google_adk-1.5.0.dist-info}/licenses/LICENSE +0 -0
@@ -0,0 +1,110 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ from typing import Optional
18
+
19
+ from google.genai import types as genai_types
20
+ from rouge_score import rouge_scorer
21
+ from typing_extensions import override
22
+
23
+ from .eval_case import Invocation
24
+ from .eval_metrics import EvalMetric
25
+ from .evaluator import EvalStatus
26
+ from .evaluator import EvaluationResult
27
+ from .evaluator import Evaluator
28
+ from .evaluator import PerInvocationResult
29
+
30
+
31
class RougeEvaluator(Evaluator):
  """Evaluator that scores final responses with the ROUGE-1 F-measure."""

  def __init__(self, eval_metric: EvalMetric):
    """Creates the evaluator.

    Args:
      eval_metric: Metric configuration; its ``threshold`` decides whether a
        score counts as PASSED or FAILED.
    """
    self._eval_metric = eval_metric

  @override
  def evaluate_invocations(
      self,
      actual_invocations: list[Invocation],
      expected_invocations: list[Invocation],
  ) -> EvaluationResult:
    """Compares final responses pairwise and averages the ROUGE-1 F-measure.

    Invocations are paired positionally with ``zip``; if the two lists differ
    in length, the surplus entries of the longer one are ignored.

    Args:
      actual_invocations: Invocations produced by the agent under test.
      expected_invocations: Reference invocations to compare against.

    Returns:
      An EvaluationResult whose ``overall_score`` is the mean per-pair
      F-measure, with one PerInvocationResult per pair. When there are no
      pairs, a default-constructed EvaluationResult is returned.
    """
    running_total = 0.0
    results = []

    for actual, expected in zip(actual_invocations, expected_invocations):
      expected_text = _get_text_from_content(expected.final_response)
      actual_text = _get_text_from_content(actual.final_response)
      f_measure = _calculate_rouge_1_scores(actual_text, expected_text).fmeasure
      results.append(
          PerInvocationResult(
              actual_invocation=actual,
              expected_invocation=expected,
              score=f_measure,
              eval_status=_get_eval_status(
                  f_measure, self._eval_metric.threshold
              ),
          )
      )
      running_total += f_measure

    if not results:
      return EvaluationResult()

    mean_score = running_total / len(results)
    return EvaluationResult(
        overall_score=mean_score,
        overall_eval_status=_get_eval_status(
            mean_score, self._eval_metric.threshold
        ),
        per_invocation_results=results,
    )
73
+
74
+
75
def _get_text_from_content(content: Optional[genai_types.Content]) -> str:
  """Joins the non-empty text parts of ``content`` with newlines.

  Returns an empty string when ``content`` is None or has no parts.
  """
  if not (content and content.parts):
    return ""
  texts = (part.text for part in content.parts if part.text)
  return "\n".join(texts)
80
+
81
+
82
def _get_eval_status(score: float, threshold: float):
  """Returns PASSED when ``score`` reaches ``threshold``, otherwise FAILED."""
  if score >= threshold:
    return EvalStatus.PASSED
  return EvalStatus.FAILED
84
+
85
+
86
def _calculate_rouge_1_scores(candidate: str, reference: str):
  """Calculates the ROUGE-1 score between a candidate and reference text.

  ROUGE-1 measures the overlap of unigrams (single words) between the
  candidate and reference texts. The score is broken down into:
  - Precision: The proportion of unigrams in the candidate that are also in the
    reference.
  - Recall: The proportion of unigrams in the reference that are also in the
    candidate.
  - F-measure: The harmonic mean of precision and recall.

  Stemming is enabled (``use_stemmer=True``), so inflected forms of the same
  word can count as a match.

  Args:
    candidate: The generated text to be evaluated.
    reference: The ground-truth text to compare against.

  Returns:
    The ``rouge1`` entry of the scorer's result. This is a ``Score`` tuple with
    ``precision``, ``recall`` and ``fmeasure`` fields, not a dictionary.
  """
  scorer = rouge_scorer.RougeScorer(["rouge1"], use_stemmer=True)

  # ``score`` takes (target, prediction), so the reference goes first. It
  # returns a dict keyed by ROUGE type whose values are Score tuples.
  scores = scorer.score(reference, candidate)

  return scores["rouge1"]
@@ -72,6 +72,13 @@ class GcsEvalSetsManager(EvalSetsManager):
72
72
  f"Invalid {id_name}. {id_name} should have the `{pattern}` format",
73
73
  )
74
74
 
75
def _load_eval_set_from_blob(self, blob_name: str) -> Optional[EvalSet]:
  """Reads and parses the EvalSet stored in the named GCS blob.

  Args:
    blob_name: Name of the blob within ``self.bucket``.

  Returns:
    The parsed EvalSet, or None if the blob does not exist.
  """
  eval_set_blob = self.bucket.blob(blob_name)
  if eval_set_blob.exists():
    return EvalSet.model_validate_json(eval_set_blob.download_as_text())
  return None
81
+
75
82
  def _write_eval_set_to_blob(self, blob_name: str, eval_set: EvalSet):
76
83
  """Writes an EvalSet to GCS."""
77
84
  blob = self.bucket.blob(blob_name)
@@ -88,11 +95,7 @@ class GcsEvalSetsManager(EvalSetsManager):
88
95
  def get_eval_set(self, app_name: str, eval_set_id: str) -> Optional[EvalSet]:
89
96
  """Returns an EvalSet identified by an app_name and eval_set_id."""
90
97
  eval_set_blob_name = self._get_eval_set_blob_name(app_name, eval_set_id)
91
- blob = self.bucket.blob(eval_set_blob_name)
92
- if not blob.exists():
93
- return None
94
- eval_set_data = blob.download_as_text()
95
- return EvalSet.model_validate_json(eval_set_data)
98
+ return self._load_eval_set_from_blob(eval_set_blob_name)
96
99
 
97
100
  @override
98
101
  def create_eval_set(self, app_name: str, eval_set_id: str):
@@ -27,10 +27,12 @@ from vertexai.preview.evaluation import MetricPromptTemplateExamples
27
27
 
28
28
  from .eval_case import IntermediateData
29
29
  from .eval_case import Invocation
30
+ from .eval_metrics import EvalMetric
30
31
  from .evaluator import EvalStatus
31
32
  from .evaluator import EvaluationResult
32
33
  from .evaluator import Evaluator
33
34
  from .evaluator import PerInvocationResult
35
+ from .final_response_match_v1 import RougeEvaluator
34
36
 
35
37
 
36
38
  class ResponseEvaluator(Evaluator):
@@ -40,7 +42,7 @@ class ResponseEvaluator(Evaluator):
40
42
  if "response_evaluation_score" == metric_name:
41
43
  self._metric_name = MetricPromptTemplateExamples.Pointwise.COHERENCE
42
44
  elif "response_match_score" == metric_name:
43
- self._metric_name = "rouge_1"
45
+ self._metric_name = "response_match_score"
44
46
  else:
45
47
  raise ValueError(f"`{metric_name}` is not supported.")
46
48
 
@@ -52,6 +54,15 @@ class ResponseEvaluator(Evaluator):
52
54
  actual_invocations: list[Invocation],
53
55
  expected_invocations: list[Invocation],
54
56
  ) -> EvaluationResult:
57
+ # If the metric is response_match_score, just use the RougeEvaluator.
58
+ if self._metric_name == "response_match_score":
59
+ rouge_evaluator = RougeEvaluator(
60
+ EvalMetric(metric_name=self._metric_name, threshold=self._threshold)
61
+ )
62
+ return rouge_evaluator.evaluate_invocations(
63
+ actual_invocations, expected_invocations
64
+ )
65
+
55
66
  total_score = 0.0
56
67
  num_invocations = 0
57
68
  per_invocation_results = []
@@ -34,9 +34,10 @@ class Event(LlmResponse):
34
34
  taken by the agents like function calls, etc.
35
35
 
36
36
  Attributes:
37
- invocation_id: The invocation ID of the event.
38
- author: "user" or the name of the agent, indicating who appended the event
39
- to the session.
37
+ invocation_id: Required. The invocation ID of the event. Should be non-empty
38
+ before appending to a session.
39
+ author: Required. "user" or the name of the agent, indicating who appended
40
+ the event to the session.
40
41
  actions: The actions taken by the agent.
41
42
  long_running_tool_ids: The ids of the long running function calls.
42
43
  branch: The branch of the event.
@@ -55,9 +56,8 @@ class Event(LlmResponse):
55
56
  )
56
57
  """The pydantic model config."""
57
58
 
58
- # TODO: revert to be required after spark migration
59
59
  invocation_id: str = ''
60
- """The invocation ID of the event."""
60
+ """The invocation ID of the event. Should be non-empty before appending to a session."""
61
61
  author: str
62
62
  """'user' or the name of the agent, indicating who appended the event to the
63
63
  session."""
@@ -43,12 +43,20 @@ class _ContentLlmRequestProcessor(BaseLlmRequestProcessor):
43
43
  if not isinstance(agent, LlmAgent):
44
44
  return
45
45
 
46
- if agent.include_contents != 'none':
46
+ if agent.include_contents == 'default':
47
+ # Include full conversation history
47
48
  llm_request.contents = _get_contents(
48
49
  invocation_context.branch,
49
50
  invocation_context.session.events,
50
51
  agent.name,
51
52
  )
53
+ else:
54
+ # Include current turn context only (no conversation history)
55
+ llm_request.contents = _get_current_turn_contents(
56
+ invocation_context.branch,
57
+ invocation_context.session.events,
58
+ agent.name,
59
+ )
52
60
 
53
61
  # Maintain async generator behavior
54
62
  if False: # Ensures it behaves as a generator
@@ -190,13 +198,15 @@ def _get_contents(
190
198
  ) -> list[types.Content]:
191
199
  """Get the contents for the LLM request.
192
200
 
201
+ Applies filtering, rearrangement, and content processing to events.
202
+
193
203
  Args:
194
204
  current_branch: The current branch of the agent.
195
- events: A list of events.
205
+ events: Events to process.
196
206
  agent_name: The name of the agent.
197
207
 
198
208
  Returns:
199
- A list of contents.
209
+ A list of processed contents.
200
210
  """
201
211
  filtered_events = []
202
212
  # Parse the events, leaving the contents and the function calls and
@@ -211,12 +221,13 @@ def _get_contents(
211
221
  # Skip events without content, or generated neither by user nor by model
212
222
  # or has empty text.
213
223
  # E.g. events purely for mutating session states.
224
+
214
225
  continue
215
226
  if not _is_event_belongs_to_branch(current_branch, event):
216
227
  # Skip events not belong to current branch.
217
228
  continue
218
229
  if _is_auth_event(event):
219
- # skip auth event
230
+ # Skip auth events.
220
231
  continue
221
232
  filtered_events.append(
222
233
  _convert_foreign_event(event)
@@ -224,12 +235,15 @@ def _get_contents(
224
235
  else event
225
236
  )
226
237
 
238
+ # Rearrange events for proper function call/response pairing
227
239
  result_events = _rearrange_events_for_latest_function_response(
228
240
  filtered_events
229
241
  )
230
242
  result_events = _rearrange_events_for_async_function_responses_in_history(
231
243
  result_events
232
244
  )
245
+
246
+ # Convert events to contents
233
247
  contents = []
234
248
  for event in result_events:
235
249
  content = copy.deepcopy(event.content)
@@ -238,6 +252,37 @@ def _get_contents(
238
252
  return contents
239
253
 
240
254
 
255
def _get_current_turn_contents(
    current_branch: Optional[str], events: list[Event], agent_name: str = ''
) -> list[types.Content]:
  """Get contents for the current turn only (no conversation history).

  When include_contents='none', we want to include:
  - The current user input
  - Tool calls and responses from the current turn
  But exclude conversation history from previous turns.

  In multi-agent scenarios, the "current turn" for an agent starts from an
  actual user or from another agent.

  Only events this agent would actually see can mark the start of the turn.
  Events without content, and events outside ``current_branch``, are skipped
  when searching for the boundary. ``_get_contents`` drops such events anyway.
  Treating one of them as the boundary would discard the real turn start,
  e.g. a sibling agent's reply in a parallel branch would hide the user's
  message.

  Args:
    current_branch: The current branch of the agent.
    events: A list of all session events.
    agent_name: The name of the agent.

  Returns:
    A list of contents for the current turn only, preserving context needed
    for proper tool execution while excluding conversation history. Returns
    an empty list if no event qualifies as the start of the turn.
  """
  # Scan backwards for the latest event that starts the current turn.
  for i in range(len(events) - 1, -1, -1):
    event = events[i]
    if not event.content or not _is_event_belongs_to_branch(
        current_branch, event
    ):
      # Invisible to this agent; it cannot be the start of its turn.
      continue
    if event.author == 'user' or _is_other_agent_reply(agent_name, event):
      return _get_contents(current_branch, events[i:], agent_name)

  return []
284
+
285
+
241
286
  def _is_other_agent_reply(current_agent_name: str, event: Event) -> bool:
242
287
  """Whether the event is a reply from another agent."""
243
288
  return bool(
@@ -519,3 +519,35 @@ def merge_parallel_function_response_events(
519
519
  # Use the base_event as the timestamp
520
520
  merged_event.timestamp = base_event.timestamp
521
521
  return merged_event
522
+
523
+
524
def find_matching_function_call(
    events: list[Event],
) -> Optional[Event]:
  """Finds the event holding the function call answered by the last event.

  If the last event carries a function response, this walks backwards through
  the earlier events. It returns the most recent event whose function calls
  include one with the same id as that response. Only the first function
  response in the last event is used for matching.

  Args:
    events: The session events, oldest first.

  Returns:
    The matching function-call event, or None in any of these cases:
    ``events`` is empty, the last event has no function response, that
    response has no id, or no earlier event contains a call with that id.
  """
  if not events:
    return None

  last_event = events[-1]
  if not (last_event.content and last_event.content.parts):
    return None

  function_response = next(
      (
          part.function_response
          for part in last_event.content.parts
          if part.function_response
      ),
      None,
  )
  if function_response is None:
    return None

  function_call_id = function_response.id
  # A response without an id cannot be paired reliably. Comparing None to
  # None would match an arbitrary id-less function call.
  if function_call_id is None:
    return None

  # Walk backwards so the most recent matching call wins. This is typically
  # the system long-running request (EUC) function call.
  for event in reversed(events[:-1]):
    for function_call in event.get_function_calls() or ():
      if function_call.id == function_call_id:
        return event
  return None
@@ -15,12 +15,14 @@ import logging
15
15
 
16
16
  from .base_memory_service import BaseMemoryService
17
17
  from .in_memory_memory_service import InMemoryMemoryService
18
+ from .vertex_ai_memory_bank_service import VertexAiMemoryBankService
18
19
 
19
20
  logger = logging.getLogger('google_adk.' + __name__)
20
21
 
21
22
  __all__ = [
22
23
  'BaseMemoryService',
23
24
  'InMemoryMemoryService',
25
+ 'VertexAiMemoryBankService',
24
26
  ]
25
27
 
26
28
  try:
@@ -29,7 +31,7 @@ try:
29
31
  __all__.append('VertexAiRagMemoryService')
30
32
  except ImportError:
31
33
  logger.debug(
32
- 'The Vertex sdk is not installed. If you want to use the'
34
+ 'The Vertex SDK is not installed. If you want to use the'
33
35
  ' VertexAiRagMemoryService please install it. If not, you can ignore this'
34
36
  ' warning.'
35
37
  )
@@ -0,0 +1,150 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from __future__ import annotations
16
+
17
+ import json
18
+ import logging
19
+ from typing import Optional
20
+ from typing import TYPE_CHECKING
21
+
22
+ from typing_extensions import override
23
+
24
+ from google import genai
25
+
26
+ from .base_memory_service import BaseMemoryService
27
+ from .base_memory_service import SearchMemoryResponse
28
+ from .memory_entry import MemoryEntry
29
+
30
+ if TYPE_CHECKING:
31
+ from ..sessions.session import Session
32
+
33
+ logger = logging.getLogger('google_adk.' + __name__)
34
+
35
+
36
class VertexAiMemoryBankService(BaseMemoryService):
  """Implementation of the BaseMemoryService using Vertex AI Memory Bank."""

  def __init__(
      self,
      project: Optional[str] = None,
      location: Optional[str] = None,
      agent_engine_id: Optional[str] = None,
  ):
    """Initializes a VertexAiMemoryBankService.

    Args:
      project: The project ID of the Memory Bank to use.
      location: The location of the Memory Bank to use.
      agent_engine_id: The ID of the agent engine to use for the Memory Bank.
        e.g. '456' in
        'projects/my-project/locations/us-central1/reasoningEngines/456'.
    """
    self._project = project
    self._location = location
    self._agent_engine_id = agent_engine_id

  @override
  async def add_session_to_memory(self, session: Session):
    """Asks Memory Bank to generate memories from the session's events.

    Only events with content parts are sent. If there are none, no request
    is made.

    Args:
      session: The session whose events are submitted, scoped by its
        ``app_name`` and ``user_id``.

    Raises:
      ValueError: If no agent engine ID was configured.
    """
    # Validate before building a client so misconfiguration fails fast.
    agent_engine_id = self._require_agent_engine_id()

    events = [
        {'content': event.content.model_dump(exclude_none=True, mode='json')}
        for event in session.events
        if event.content and event.content.parts
    ]
    if not events:
      logger.info('No events to add to memory.')
      return

    api_client = self._get_api_client()
    api_response = await api_client.async_request(
        http_method='POST',
        path=f'reasoningEngines/{agent_engine_id}/memories:generate',
        request_dict={
            'direct_contents_source': {
                'events': events,
            },
            'scope': {
                'app_name': session.app_name,
                'user_id': session.user_id,
            },
        },
    )
    logger.info('Generate memory response: %s', api_response)

  @override
  async def search_memory(
      self, *, app_name: str, user_id: str, query: str
  ) -> SearchMemoryResponse:
    """Retrieves memories similar to ``query`` for the given app and user.

    Args:
      app_name: The app scope of the memories.
      user_id: The user scope of the memories.
      query: The similarity-search query.

    Returns:
      A SearchMemoryResponse holding one MemoryEntry per retrieved memory
      that has a non-empty ``fact``. Entries without one are skipped.

    Raises:
      ValueError: If no agent engine ID was configured.
    """
    agent_engine_id = self._require_agent_engine_id()
    api_client = self._get_api_client()

    api_response = await api_client.async_request(
        http_method='POST',
        path=f'reasoningEngines/{agent_engine_id}/memories:retrieve',
        request_dict={
            'scope': {
                'app_name': app_name,
                'user_id': user_id,
            },
            'similarity_search_params': {
                'search_query': query,
            },
        },
    )
    api_response = _convert_api_response(api_response)
    logger.info('Search memory response: %s', api_response)

    if not api_response or not api_response.get('retrievedMemories', None):
      return SearchMemoryResponse()

    memory_events = []
    for memory in api_response.get('retrievedMemories', []):
      # Tolerate malformed entries instead of failing the whole search.
      fact = (memory.get('memory') or {}).get('fact')
      if not fact:
        logger.debug('Skipping retrieved memory without a fact: %s', memory)
        continue
      memory_events.append(
          MemoryEntry(
              author='user',
              content=genai.types.Content(
                  parts=[genai.types.Part(text=fact)],
                  role='user',
              ),
              timestamp=memory.get('updateTime'),
          )
      )
    return SearchMemoryResponse(memories=memory_events)

  def _require_agent_engine_id(self) -> str:
    """Returns the configured agent engine ID.

    Raises:
      ValueError: If no agent engine ID was configured.
    """
    if not self._agent_engine_id:
      raise ValueError('Agent Engine ID is required for Memory Bank.')
    return self._agent_engine_id

  def _get_api_client(self):
    """Instantiates an API client for the given project and location.

    It needs to be instantiated inside each request so that the event loop
    management can be properly propagated.

    Returns:
      An API client for the given project and location.
    """
    client = genai.Client(
        vertexai=True, project=self._project, location=self._location
    )
    return client._api_client
144
+
145
+
146
def _convert_api_response(api_response):
  """Returns the API response as parsed JSON when it carries a raw ``body``.

  Objects exposing a ``body`` attribute are decoded with ``json.loads``.
  Anything else, such as an already-parsed dict, is returned unchanged.
  """
  if not hasattr(api_response, 'body'):
    return api_response
  raw_body = api_response.body
  return json.loads(raw_body)
@@ -29,6 +29,7 @@ from typing import Tuple
29
29
  from typing import Union
30
30
 
31
31
  from google.genai import types
32
+ import litellm
32
33
  from litellm import acompletion
33
34
  from litellm import ChatCompletionAssistantMessage
34
35
  from litellm import ChatCompletionAssistantToolCall
@@ -53,6 +54,9 @@ from .base_llm import BaseLlm
53
54
  from .llm_request import LlmRequest
54
55
  from .llm_response import LlmResponse
55
56
 
57
+ # This will add functions to prompts if functions are provided.
58
+ litellm.add_function_to_prompt = True
59
+
56
60
  logger = logging.getLogger("google_adk." + __name__)
57
61
 
58
62
  _NEW_LINE = "\n"
@@ -662,6 +666,10 @@ class LiteLlm(BaseLlm):
662
666
 
663
667
  messages, tools, response_format = _get_completion_inputs(llm_request)
664
668
 
669
+ if "functions" in self._additional_args:
670
+ # LiteLLM does not support both tools and functions together.
671
+ tools = None
672
+
665
673
  completion_args = {
666
674
  "model": self.model,
667
675
  "messages": messages,
@@ -679,7 +687,7 @@ class LiteLlm(BaseLlm):
679
687
  aggregated_llm_response_with_tool_call = None
680
688
  usage_metadata = None
681
689
  fallback_index = 0
682
- for part in self.llm_client.completion(**completion_args):
690
+ async for part in await self.llm_client.acompletion(**completion_args):
683
691
  for chunk, finish_reason in _model_response_to_chunk(part):
684
692
  if isinstance(chunk, FunctionChunk):
685
693
  index = chunk.index or fallback_index
google/adk/runners.py CHANGED
@@ -36,6 +36,7 @@ from .artifacts.in_memory_artifact_service import InMemoryArtifactService
36
36
  from .auth.credential_service.base_credential_service import BaseCredentialService
37
37
  from .code_executors.built_in_code_executor import BuiltInCodeExecutor
38
38
  from .events.event import Event
39
+ from .flows.llm_flows.functions import find_matching_function_call
39
40
  from .memory.base_memory_service import BaseMemoryService
40
41
  from .memory.in_memory_memory_service import InMemoryMemoryService
41
42
  from .platform.thread import create_thread
@@ -337,6 +338,8 @@ class Runner:
337
338
  """Finds the agent to run to continue the session.
338
339
 
339
340
  A qualified agent must be either of:
341
+ - The agent that returned a function call and the last user message is a
342
+ function response to this function call.
340
343
  - The root agent;
341
344
  - An LlmAgent who replied last and is capable to transfer to any other agent
342
345
  in the agent hierarchy.
@@ -348,6 +351,13 @@ class Runner:
348
351
  Returns:
349
352
  The agent of the last message in the session or the root agent.
350
353
  """
354
+ # If the last event is a function response, should send this response to
355
+ # the agent that returned the corresponding function call, regardless of the
356
+ # type of the agent. e.g. a remote a2a agent may surface a credential
357
+ # request as a special long running function tool call.
358
+ event = find_matching_function_call(session.events)
359
+ if event and event.author:
360
+ return root_agent.find_agent(event.author)
351
361
  for event in filter(lambda e: e.author != 'user', reversed(session.events)):
352
362
  if event.author == root_agent.name:
353
363
  # Found root agent.