google-adk 0.3.0__py3-none-any.whl → 0.5.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (71) hide show
  1. google/adk/agents/active_streaming_tool.py +1 -0
  2. google/adk/agents/base_agent.py +27 -29
  3. google/adk/agents/callback_context.py +4 -4
  4. google/adk/agents/invocation_context.py +1 -0
  5. google/adk/agents/langgraph_agent.py +1 -0
  6. google/adk/agents/live_request_queue.py +1 -0
  7. google/adk/agents/llm_agent.py +54 -14
  8. google/adk/agents/run_config.py +4 -0
  9. google/adk/agents/transcription_entry.py +1 -0
  10. google/adk/artifacts/base_artifact_service.py +5 -10
  11. google/adk/artifacts/gcs_artifact_service.py +8 -8
  12. google/adk/artifacts/in_memory_artifact_service.py +5 -5
  13. google/adk/auth/auth_credential.py +4 -5
  14. google/adk/cli/browser/index.html +1 -1
  15. google/adk/cli/browser/{main-HWIBUY2R.js → main-ULN5R5I5.js} +40 -39
  16. google/adk/cli/cli.py +54 -47
  17. google/adk/cli/cli_eval.py +13 -11
  18. google/adk/cli/cli_tools_click.py +58 -7
  19. google/adk/cli/fast_api.py +11 -11
  20. google/adk/cli/fast_api.py.orig +728 -0
  21. google/adk/evaluation/agent_evaluator.py +3 -3
  22. google/adk/evaluation/evaluation_constants.py +1 -0
  23. google/adk/evaluation/evaluation_generator.py +5 -5
  24. google/adk/evaluation/response_evaluator.py +1 -1
  25. google/adk/events/event.py +1 -0
  26. google/adk/events/event_actions.py +10 -4
  27. google/adk/examples/example.py +1 -0
  28. google/adk/flows/__init__.py +0 -1
  29. google/adk/flows/llm_flows/_code_execution.py +10 -10
  30. google/adk/flows/llm_flows/base_llm_flow.py +40 -15
  31. google/adk/flows/llm_flows/basic.py +3 -0
  32. google/adk/flows/llm_flows/contents.py +9 -5
  33. google/adk/flows/llm_flows/functions.py +38 -16
  34. google/adk/flows/llm_flows/instructions.py +17 -6
  35. google/adk/memory/base_memory_service.py +4 -2
  36. google/adk/memory/in_memory_memory_service.py +2 -2
  37. google/adk/memory/vertex_ai_rag_memory_service.py +2 -2
  38. google/adk/models/anthropic_llm.py +20 -2
  39. google/adk/models/base_llm.py +45 -4
  40. google/adk/models/gemini_llm_connection.py +14 -1
  41. google/adk/models/google_llm.py +0 -42
  42. google/adk/models/lite_llm.py +17 -17
  43. google/adk/models/llm_request.py +1 -1
  44. google/adk/models/llm_response.py +1 -1
  45. google/adk/runners.py +5 -5
  46. google/adk/sessions/_session_util.py +43 -0
  47. google/adk/sessions/base_session_service.py +3 -0
  48. google/adk/sessions/database_session_service.py +63 -46
  49. google/adk/sessions/in_memory_session_service.py +3 -3
  50. google/adk/sessions/session.py +1 -0
  51. google/adk/sessions/vertex_ai_session_service.py +7 -5
  52. google/adk/tools/agent_tool.py +7 -4
  53. google/adk/tools/application_integration_tool/__init__.py +2 -0
  54. google/adk/tools/application_integration_tool/application_integration_toolset.py +48 -26
  55. google/adk/tools/application_integration_tool/clients/connections_client.py +33 -77
  56. google/adk/tools/application_integration_tool/integration_connector_tool.py +159 -0
  57. google/adk/tools/function_tool.py +42 -0
  58. google/adk/tools/load_artifacts_tool.py +4 -4
  59. google/adk/tools/load_memory_tool.py +4 -2
  60. google/adk/tools/mcp_tool/conversion_utils.py +1 -1
  61. google/adk/tools/mcp_tool/mcp_session_manager.py +14 -0
  62. google/adk/tools/openapi_tool/common/common.py +2 -5
  63. google/adk/tools/openapi_tool/openapi_spec_parser/operation_parser.py +13 -3
  64. google/adk/tools/preload_memory_tool.py +1 -1
  65. google/adk/tools/tool_context.py +4 -4
  66. google/adk/version.py +1 -1
  67. {google_adk-0.3.0.dist-info → google_adk-0.5.0.dist-info}/METADATA +3 -7
  68. {google_adk-0.3.0.dist-info → google_adk-0.5.0.dist-info}/RECORD +71 -68
  69. {google_adk-0.3.0.dist-info → google_adk-0.5.0.dist-info}/WHEEL +0 -0
  70. {google_adk-0.3.0.dist-info → google_adk-0.5.0.dist-info}/entry_points.txt +0 -0
  71. {google_adk-0.3.0.dist-info → google_adk-0.5.0.dist-info}/licenses/LICENSE +0 -0
@@ -76,7 +76,7 @@ class AgentEvaluator:
76
76
  return DEFAULT_CRITERIA
77
77
 
78
78
  @staticmethod
79
- def evaluate(
79
+ async def evaluate(
80
80
  agent_module,
81
81
  eval_dataset_file_path_or_dir,
82
82
  num_runs=NUM_RUNS,
@@ -120,7 +120,7 @@ class AgentEvaluator:
120
120
 
121
121
  AgentEvaluator._validate_input([dataset], criteria)
122
122
 
123
- evaluation_response = AgentEvaluator._generate_responses(
123
+ evaluation_response = await AgentEvaluator._generate_responses(
124
124
  agent_module,
125
125
  [dataset],
126
126
  num_runs,
@@ -246,7 +246,7 @@ class AgentEvaluator:
246
246
  return inferred_criteria
247
247
 
248
248
  @staticmethod
249
- def _generate_responses(
249
+ async def _generate_responses(
250
250
  agent_module, eval_dataset, num_runs, agent_name=None, initial_session={}
251
251
  ):
252
252
  """Generates evaluation responses by running the agent module multiple times."""
@@ -12,6 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
+
15
16
  class EvalConstants:
16
17
  """Holds constants for evaluation file constants."""
17
18
 
@@ -32,7 +32,7 @@ class EvaluationGenerator:
32
32
  """Generates evaluation responses for agents."""
33
33
 
34
34
  @staticmethod
35
- def generate_responses(
35
+ async def generate_responses(
36
36
  eval_dataset,
37
37
  agent_module_path,
38
38
  repeat_num=3,
@@ -107,7 +107,7 @@ class EvaluationGenerator:
107
107
  )
108
108
 
109
109
  @staticmethod
110
- def _process_query_with_root_agent(
110
+ async def _process_query_with_root_agent(
111
111
  data,
112
112
  root_agent,
113
113
  reset_func,
@@ -128,7 +128,7 @@ class EvaluationGenerator:
128
128
  all_mock_tools.add(expected[EvalConstants.TOOL_NAME])
129
129
 
130
130
  eval_data_copy = data.copy()
131
- EvaluationGenerator.apply_before_tool_callback(
131
+ await EvaluationGenerator.apply_before_tool_callback(
132
132
  root_agent,
133
133
  lambda *args: EvaluationGenerator.before_tool_callback(
134
134
  *args, eval_dataset=eval_data_copy
@@ -247,7 +247,7 @@ class EvaluationGenerator:
247
247
  return None
248
248
 
249
249
  @staticmethod
250
- def apply_before_tool_callback(
250
+ async def apply_before_tool_callback(
251
251
  agent: BaseAgent,
252
252
  callback: BeforeToolCallback,
253
253
  all_mock_tools: set[str],
@@ -265,6 +265,6 @@ class EvaluationGenerator:
265
265
 
266
266
  # Apply recursively to subagents if they exist
267
267
  for sub_agent in agent.sub_agents:
268
- EvaluationGenerator.apply_before_tool_callback(
268
+ await EvaluationGenerator.apply_before_tool_callback(
269
269
  sub_agent, callback, all_mock_tools
270
270
  )
@@ -28,7 +28,7 @@ class ResponseEvaluator:
28
28
  raw_eval_dataset: list[list[dict[str, Any]]],
29
29
  evaluation_criteria: list[str],
30
30
  *,
31
- print_detailed_results: bool = False
31
+ print_detailed_results: bool = False,
32
32
  ):
33
33
  r"""Returns the value of requested evaluation metrics.
34
34
 
@@ -48,6 +48,7 @@ class Event(LlmResponse):
48
48
  model_config = ConfigDict(
49
49
  extra='forbid', ser_json_bytes='base64', val_json_bytes='base64'
50
50
  )
51
+ """The pydantic model config."""
51
52
 
52
53
  # TODO: revert to be required after spark migration
53
54
  invocation_id: str = ''
@@ -27,6 +27,7 @@ class EventActions(BaseModel):
27
27
  """Represents the actions attached to an event."""
28
28
 
29
29
  model_config = ConfigDict(extra='forbid')
30
+ """The pydantic model config."""
30
31
 
31
32
  skip_summarization: Optional[bool] = None
32
33
  """If true, it won't call model to summarize function response.
@@ -48,8 +49,13 @@ class EventActions(BaseModel):
48
49
  """The agent is escalating to a higher level agent."""
49
50
 
50
51
  requested_auth_configs: dict[str, AuthConfig] = Field(default_factory=dict)
51
- """Will only be set by a tool response indicating tool request euc.
52
- dict key is the function call id since one function call response (from model)
53
- could correspond to multiple function calls.
54
- dict value is the required auth config.
52
+ """Authentication configurations requested by tool responses.
53
+
54
+ This field will only be set by a tool response event indicating tool request
55
+ auth credential.
56
+ - Keys: The function call id. Since one function response event could contain
57
+ multiple function responses that correspond to multiple function calls. Each
58
+ function call could request different auth configs. This id is used to
59
+ identify the function call.
60
+ - Values: The requested auth config.
55
61
  """
@@ -23,5 +23,6 @@ class Example(BaseModel):
23
23
  input: The input content for the example.
24
24
  output: The expected output content for the example.
25
25
  """
26
+
26
27
  input: types.Content
27
28
  output: list[types.Content]
@@ -11,4 +11,3 @@
11
11
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
-
@@ -122,7 +122,7 @@ class _CodeExecutionRequestProcessor(BaseLlmRequestProcessor):
122
122
  if not invocation_context.agent.code_executor:
123
123
  return
124
124
 
125
- for event in _run_pre_processor(invocation_context, llm_request):
125
+ async for event in _run_pre_processor(invocation_context, llm_request):
126
126
  yield event
127
127
 
128
128
  # Convert the code execution parts to text parts.
@@ -152,17 +152,17 @@ class _CodeExecutionResponseProcessor(BaseLlmResponseProcessor):
152
152
  if llm_response.partial:
153
153
  return
154
154
 
155
- for event in _run_post_processor(invocation_context, llm_response):
155
+ async for event in _run_post_processor(invocation_context, llm_response):
156
156
  yield event
157
157
 
158
158
 
159
159
  response_processor = _CodeExecutionResponseProcessor()
160
160
 
161
161
 
162
- def _run_pre_processor(
162
+ async def _run_pre_processor(
163
163
  invocation_context: InvocationContext,
164
164
  llm_request: LlmRequest,
165
- ) -> Generator[Event, None, None]:
165
+ ) -> AsyncGenerator[Event, None]:
166
166
  """Pre-process the user message by adding the user message to the Colab notebook."""
167
167
  from ...agents.llm_agent import LlmAgent
168
168
 
@@ -242,17 +242,17 @@ def _run_pre_processor(
242
242
  code_executor_context.add_processed_file_names([file.name])
243
243
 
244
244
  # Emit the execution result, and add it to the LLM request.
245
- execution_result_event = _post_process_code_execution_result(
245
+ execution_result_event = await _post_process_code_execution_result(
246
246
  invocation_context, code_executor_context, code_execution_result
247
247
  )
248
248
  yield execution_result_event
249
249
  llm_request.contents.append(copy.deepcopy(execution_result_event.content))
250
250
 
251
251
 
252
- def _run_post_processor(
252
+ async def _run_post_processor(
253
253
  invocation_context: InvocationContext,
254
254
  llm_response,
255
- ) -> Generator[Event, None, None]:
255
+ ) -> AsyncGenerator[Event, None]:
256
256
  """Post-process the model response by extracting and executing the first code block."""
257
257
  agent = invocation_context.agent
258
258
  code_executor = agent.code_executor
@@ -305,7 +305,7 @@ def _run_post_processor(
305
305
  code_execution_result.stdout,
306
306
  code_execution_result.stderr,
307
307
  )
308
- yield _post_process_code_execution_result(
308
+ yield await _post_process_code_execution_result(
309
309
  invocation_context, code_executor_context, code_execution_result
310
310
  )
311
311
 
@@ -375,7 +375,7 @@ def _get_or_set_execution_id(
375
375
  return execution_id
376
376
 
377
377
 
378
- def _post_process_code_execution_result(
378
+ async def _post_process_code_execution_result(
379
379
  invocation_context: InvocationContext,
380
380
  code_executor_context: CodeExecutorContext,
381
381
  code_execution_result: CodeExecutionResult,
@@ -406,7 +406,7 @@ def _post_process_code_execution_result(
406
406
 
407
407
  # Handle output files.
408
408
  for output_file in code_execution_result.output_files:
409
- version = invocation_context.artifact_service.save_artifact(
409
+ version = await invocation_context.artifact_service.save_artifact(
410
410
  app_name=invocation_context.app_name,
411
411
  user_id=invocation_context.user_id,
412
412
  session_id=invocation_context.session.id,
@@ -16,6 +16,7 @@ from __future__ import annotations
16
16
 
17
17
  from abc import ABC
18
18
  import asyncio
19
+ import inspect
19
20
  import logging
20
21
  from typing import AsyncGenerator
21
22
  from typing import cast
@@ -190,6 +191,17 @@ class BaseLlmFlow(ABC):
190
191
  llm_request: LlmRequest,
191
192
  ) -> AsyncGenerator[Event, None]:
192
193
  """Receive data from model and process events using BaseLlmConnection."""
194
+ def get_author(llm_response):
195
+ """Get the author of the event.
196
+
197
+ When the model returns transcription, the author is "user". Otherwise, the
198
+ author is the agent.
199
+ """
200
+ if llm_response and llm_response.content and llm_response.content.role == "user":
201
+ return "user"
202
+ else:
203
+ return invocation_context.agent.name
204
+
193
205
  assert invocation_context.live_request_queue
194
206
  try:
195
207
  while True:
@@ -197,7 +209,7 @@ class BaseLlmFlow(ABC):
197
209
  model_response_event = Event(
198
210
  id=Event.new_id(),
199
211
  invocation_id=invocation_context.invocation_id,
200
- author=invocation_context.agent.name,
212
+ author=get_author(llm_response),
201
213
  )
202
214
  async for event in self._postprocess_live(
203
215
  invocation_context,
@@ -249,7 +261,6 @@ class BaseLlmFlow(ABC):
249
261
 
250
262
  # Calls the LLM.
251
263
  model_response_event = Event(
252
- id=Event.new_id(),
253
264
  invocation_id=invocation_context.invocation_id,
254
265
  author=invocation_context.agent.name,
255
266
  branch=invocation_context.branch,
@@ -261,6 +272,8 @@ class BaseLlmFlow(ABC):
261
272
  async for event in self._postprocess_async(
262
273
  invocation_context, llm_request, llm_response, model_response_event
263
274
  ):
275
+ # Use a new id for every event.
276
+ event.id = Event.new_id()
264
277
  yield event
265
278
 
266
279
  async def _preprocess_async(
@@ -437,7 +450,7 @@ class BaseLlmFlow(ABC):
437
450
  model_response_event: Event,
438
451
  ) -> AsyncGenerator[LlmResponse, None]:
439
452
  # Runs before_model_callback if it exists.
440
- if response := self._handle_before_model_callback(
453
+ if response := await self._handle_before_model_callback(
441
454
  invocation_context, llm_request, model_response_event
442
455
  ):
443
456
  yield response
@@ -450,7 +463,7 @@ class BaseLlmFlow(ABC):
450
463
  invocation_context.live_request_queue = LiveRequestQueue()
451
464
  async for llm_response in self.run_live(invocation_context):
452
465
  # Runs after_model_callback if it exists.
453
- if altered_llm_response := self._handle_after_model_callback(
466
+ if altered_llm_response := await self._handle_after_model_callback(
454
467
  invocation_context, llm_response, model_response_event
455
468
  ):
456
469
  llm_response = altered_llm_response
@@ -479,14 +492,14 @@ class BaseLlmFlow(ABC):
479
492
  llm_response,
480
493
  )
481
494
  # Runs after_model_callback if it exists.
482
- if altered_llm_response := self._handle_after_model_callback(
495
+ if altered_llm_response := await self._handle_after_model_callback(
483
496
  invocation_context, llm_response, model_response_event
484
497
  ):
485
498
  llm_response = altered_llm_response
486
499
 
487
500
  yield llm_response
488
501
 
489
- def _handle_before_model_callback(
502
+ async def _handle_before_model_callback(
490
503
  self,
491
504
  invocation_context: InvocationContext,
492
505
  llm_request: LlmRequest,
@@ -498,17 +511,23 @@ class BaseLlmFlow(ABC):
498
511
  if not isinstance(agent, LlmAgent):
499
512
  return
500
513
 
501
- if not agent.before_model_callback:
514
+ if not agent.canonical_before_model_callbacks:
502
515
  return
503
516
 
504
517
  callback_context = CallbackContext(
505
518
  invocation_context, event_actions=model_response_event.actions
506
519
  )
507
- return agent.before_model_callback(
508
- callback_context=callback_context, llm_request=llm_request
509
- )
510
520
 
511
- def _handle_after_model_callback(
521
+ for callback in agent.canonical_before_model_callbacks:
522
+ before_model_callback_content = callback(
523
+ callback_context=callback_context, llm_request=llm_request
524
+ )
525
+ if inspect.isawaitable(before_model_callback_content):
526
+ before_model_callback_content = await before_model_callback_content
527
+ if before_model_callback_content:
528
+ return before_model_callback_content
529
+
530
+ async def _handle_after_model_callback(
512
531
  self,
513
532
  invocation_context: InvocationContext,
514
533
  llm_response: LlmResponse,
@@ -520,15 +539,21 @@ class BaseLlmFlow(ABC):
520
539
  if not isinstance(agent, LlmAgent):
521
540
  return
522
541
 
523
- if not agent.after_model_callback:
542
+ if not agent.canonical_after_model_callbacks:
524
543
  return
525
544
 
526
545
  callback_context = CallbackContext(
527
546
  invocation_context, event_actions=model_response_event.actions
528
547
  )
529
- return agent.after_model_callback(
530
- callback_context=callback_context, llm_response=llm_response
531
- )
548
+
549
+ for callback in agent.canonical_after_model_callbacks:
550
+ after_model_callback_content = callback(
551
+ callback_context=callback_context, llm_response=llm_response
552
+ )
553
+ if inspect.isawaitable(after_model_callback_content):
554
+ after_model_callback_content = await after_model_callback_content
555
+ if after_model_callback_content:
556
+ return after_model_callback_content
532
557
 
533
558
  def _finalize_model_response_event(
534
559
  self,
@@ -62,6 +62,9 @@ class _BasicLlmRequestProcessor(BaseLlmRequestProcessor):
62
62
  llm_request.live_connect_config.output_audio_transcription = (
63
63
  invocation_context.run_config.output_audio_transcription
64
64
  )
65
+ llm_request.live_connect_config.input_audio_transcription = (
66
+ invocation_context.run_config.input_audio_transcription
67
+ )
65
68
 
66
69
  # TODO: handle tool append here, instead of in BaseTool.process_llm_request.
67
70
 
@@ -15,9 +15,7 @@
15
15
  from __future__ import annotations
16
16
 
17
17
  import copy
18
- from typing import AsyncGenerator
19
- from typing import Generator
20
- from typing import Optional
18
+ from typing import AsyncGenerator, Generator, Optional
21
19
 
22
20
  from google.genai import types
23
21
  from typing_extensions import override
@@ -202,8 +200,14 @@ def _get_contents(
202
200
  # Parse the events, leaving the contents and the function calls and
203
201
  # responses from the current agent.
204
202
  for event in events:
205
- if not event.content or not event.content.role:
206
- # Skip events without content, or generated neither by user nor by model.
203
+ if (
204
+ not event.content
205
+ or not event.content.role
206
+ or not event.content.parts
207
+ or event.content.parts[0].text == ''
208
+ ):
209
+ # Skip events without content, or generated neither by user nor by model
210
+ # or has empty text.
207
211
  # E.g. events purely for mutating session states.
208
212
  continue
209
213
  if not _is_event_belongs_to_branch(current_branch, event):
@@ -151,28 +151,33 @@ async def handle_function_calls_async(
151
151
  # do not use "args" as the variable name, because it is a reserved keyword
152
152
  # in python debugger.
153
153
  function_args = function_call.args or {}
154
- function_response = None
155
- # Calls the tool if before_tool_callback does not exist or returns None.
154
+ function_response: Optional[dict] = None
155
+
156
+ # before_tool_callback (sync or async)
156
157
  if agent.before_tool_callback:
157
158
  function_response = agent.before_tool_callback(
158
159
  tool=tool, args=function_args, tool_context=tool_context
159
160
  )
161
+ if inspect.isawaitable(function_response):
162
+ function_response = await function_response
160
163
 
161
164
  if not function_response:
162
165
  function_response = await __call_tool_async(
163
166
  tool, args=function_args, tool_context=tool_context
164
167
  )
165
168
 
166
- # Calls after_tool_callback if it exists.
169
+ # after_tool_callback (sync or async)
167
170
  if agent.after_tool_callback:
168
- new_response = agent.after_tool_callback(
171
+ altered_function_response = agent.after_tool_callback(
169
172
  tool=tool,
170
173
  args=function_args,
171
174
  tool_context=tool_context,
172
175
  tool_response=function_response,
173
176
  )
174
- if new_response:
175
- function_response = new_response
177
+ if inspect.isawaitable(altered_function_response):
178
+ altered_function_response = await altered_function_response
179
+ if altered_function_response is not None:
180
+ function_response = altered_function_response
176
181
 
177
182
  if tool.is_long_running:
178
183
  # Allow long running function to return None to not provide function response.
@@ -223,11 +228,17 @@ async def handle_function_calls_live(
223
228
  # in python debugger.
224
229
  function_args = function_call.args or {}
225
230
  function_response = None
226
- # Calls the tool if before_tool_callback does not exist or returns None.
231
+ # # Calls the tool if before_tool_callback does not exist or returns None.
232
+ # if agent.before_tool_callback:
233
+ # function_response = agent.before_tool_callback(
234
+ # tool, function_args, tool_context
235
+ # )
227
236
  if agent.before_tool_callback:
228
237
  function_response = agent.before_tool_callback(
229
- tool, function_args, tool_context
238
+ tool=tool, args=function_args, tool_context=tool_context
230
239
  )
240
+ if inspect.isawaitable(function_response):
241
+ function_response = await function_response
231
242
 
232
243
  if not function_response:
233
244
  function_response = await _process_function_live_helper(
@@ -235,15 +246,26 @@ async def handle_function_calls_live(
235
246
  )
236
247
 
237
248
  # Calls after_tool_callback if it exists.
249
+ # if agent.after_tool_callback:
250
+ # new_response = agent.after_tool_callback(
251
+ # tool,
252
+ # function_args,
253
+ # tool_context,
254
+ # function_response,
255
+ # )
256
+ # if new_response:
257
+ # function_response = new_response
238
258
  if agent.after_tool_callback:
239
- new_response = agent.after_tool_callback(
240
- tool,
241
- function_args,
242
- tool_context,
243
- function_response,
259
+ altered_function_response = agent.after_tool_callback(
260
+ tool=tool,
261
+ args=function_args,
262
+ tool_context=tool_context,
263
+ tool_response=function_response,
244
264
  )
245
- if new_response:
246
- function_response = new_response
265
+ if inspect.isawaitable(altered_function_response):
266
+ altered_function_response = await altered_function_response
267
+ if altered_function_response is not None:
268
+ function_response = altered_function_response
247
269
 
248
270
  if tool.is_long_running:
249
271
  # Allow async function to return None to not provide function response.
@@ -310,7 +332,7 @@ async def _process_function_live_helper(
310
332
  function_response = {
311
333
  'status': f'No active streaming function named {function_name} found'
312
334
  }
313
- elif hasattr(tool, "func") and inspect.isasyncgenfunction(tool.func):
335
+ elif hasattr(tool, 'func') and inspect.isasyncgenfunction(tool.func):
314
336
  # for streaming tool use case
315
337
  # we require the function to be a async generator function
316
338
  async def run_tool_and_update_queue(tool, function_args, tool_context):
@@ -56,13 +56,13 @@ class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor):
56
56
  raw_si = root_agent.canonical_global_instruction(
57
57
  ReadonlyContext(invocation_context)
58
58
  )
59
- si = _populate_values(raw_si, invocation_context)
59
+ si = await _populate_values(raw_si, invocation_context)
60
60
  llm_request.append_instructions([si])
61
61
 
62
62
  # Appends agent instructions if set.
63
63
  if agent.instruction: # not empty str
64
64
  raw_si = agent.canonical_instruction(ReadonlyContext(invocation_context))
65
- si = _populate_values(raw_si, invocation_context)
65
+ si = await _populate_values(raw_si, invocation_context)
66
66
  llm_request.append_instructions([si])
67
67
 
68
68
  # Maintain async generator behavior
@@ -73,13 +73,24 @@ class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor):
73
73
  request_processor = _InstructionsLlmRequestProcessor()
74
74
 
75
75
 
76
- def _populate_values(
76
+ async def _populate_values(
77
77
  instruction_template: str,
78
78
  context: InvocationContext,
79
79
  ) -> str:
80
80
  """Populates values in the instruction template, e.g. state, artifact, etc."""
81
81
 
82
- def _replace_match(match) -> str:
82
+ async def _async_sub(pattern, repl_async_fn, string) -> str:
83
+ result = []
84
+ last_end = 0
85
+ for match in re.finditer(pattern, string):
86
+ result.append(string[last_end : match.start()])
87
+ replacement = await repl_async_fn(match)
88
+ result.append(replacement)
89
+ last_end = match.end()
90
+ result.append(string[last_end:])
91
+ return ''.join(result)
92
+
93
+ async def _replace_match(match) -> str:
83
94
  var_name = match.group().lstrip('{').rstrip('}').strip()
84
95
  optional = False
85
96
  if var_name.endswith('?'):
@@ -89,7 +100,7 @@ def _populate_values(
89
100
  var_name = var_name.removeprefix('artifact.')
90
101
  if context.artifact_service is None:
91
102
  raise ValueError('Artifact service is not initialized.')
92
- artifact = context.artifact_service.load_artifact(
103
+ artifact = await context.artifact_service.load_artifact(
93
104
  app_name=context.session.app_name,
94
105
  user_id=context.session.user_id,
95
106
  session_id=context.session.id,
@@ -109,7 +120,7 @@ def _populate_values(
109
120
  else:
110
121
  raise KeyError(f'Context variable not found: `{var_name}`.')
111
122
 
112
- return re.sub(r'{+[^{}]*}+', _replace_match, instruction_template)
123
+ return await _async_sub(r'{+[^{}]*}+', _replace_match, instruction_template)
113
124
 
114
125
 
115
126
  def _is_valid_state_name(var_name):
@@ -28,6 +28,7 @@ class MemoryResult(BaseModel):
28
28
  session_id: The session id associated with the memory.
29
29
  events: A list of events in the session.
30
30
  """
31
+
31
32
  session_id: str
32
33
  events: list[Event]
33
34
 
@@ -38,6 +39,7 @@ class SearchMemoryResponse(BaseModel):
38
39
  Attributes:
39
40
  memories: A list of memory results matching the search query.
40
41
  """
42
+
41
43
  memories: list[MemoryResult] = Field(default_factory=list)
42
44
 
43
45
 
@@ -49,7 +51,7 @@ class BaseMemoryService(abc.ABC):
49
51
  """
50
52
 
51
53
  @abc.abstractmethod
52
- def add_session_to_memory(self, session: Session):
54
+ async def add_session_to_memory(self, session: Session):
53
55
  """Adds a session to the memory service.
54
56
 
55
57
  A session may be added multiple times during its lifetime.
@@ -59,7 +61,7 @@ class BaseMemoryService(abc.ABC):
59
61
  """
60
62
 
61
63
  @abc.abstractmethod
62
- def search_memory(
64
+ async def search_memory(
63
65
  self, *, app_name: str, user_id: str, query: str
64
66
  ) -> SearchMemoryResponse:
65
67
  """Searches for sessions that match the query.
@@ -29,13 +29,13 @@ class InMemoryMemoryService(BaseMemoryService):
29
29
  self.session_events: dict[str, list[Event]] = {}
30
30
  """keys are app_name/user_id/session_id"""
31
31
 
32
- def add_session_to_memory(self, session: Session):
32
+ async def add_session_to_memory(self, session: Session):
33
33
  key = f'{session.app_name}/{session.user_id}/{session.id}'
34
34
  self.session_events[key] = [
35
35
  event for event in session.events if event.content
36
36
  ]
37
37
 
38
- def search_memory(
38
+ async def search_memory(
39
39
  self, *, app_name: str, user_id: str, query: str
40
40
  ) -> SearchMemoryResponse:
41
41
  """Prototyping purpose only."""
@@ -54,7 +54,7 @@ class VertexAiRagMemoryService(BaseMemoryService):
54
54
  )
55
55
 
56
56
  @override
57
- def add_session_to_memory(self, session: Session):
57
+ async def add_session_to_memory(self, session: Session):
58
58
  with tempfile.NamedTemporaryFile(
59
59
  mode="w", delete=False, suffix=".txt"
60
60
  ) as temp_file:
@@ -91,7 +91,7 @@ class VertexAiRagMemoryService(BaseMemoryService):
91
91
  os.remove(temp_file_path)
92
92
 
93
93
  @override
94
- def search_memory(
94
+ async def search_memory(
95
95
  self, *, app_name: str, user_id: str, query: str
96
96
  ) -> SearchMemoryResponse:
97
97
  """Searches for sessions that match the query using rag.retrieval_query."""
@@ -19,6 +19,7 @@ from __future__ import annotations
19
19
  from functools import cached_property
20
20
  import logging
21
21
  import os
22
+ from typing import Any
22
23
  from typing import AsyncGenerator
23
24
  from typing import Generator
24
25
  from typing import Iterable
@@ -151,6 +152,24 @@ def message_to_generate_content_response(
151
152
  )
152
153
 
153
154
 
155
+ def _update_type_string(value_dict: dict[str, Any]):
156
+ """Updates 'type' field to expected JSON schema format."""
157
+ if "type" in value_dict:
158
+ value_dict["type"] = value_dict["type"].lower()
159
+
160
+ if "items" in value_dict:
161
+ # 'type' field could exist for items as well, this would be the case if
162
+ # items represent primitive types.
163
+ _update_type_string(value_dict["items"])
164
+
165
+ if "properties" in value_dict["items"]:
166
+ # There could be properties as well on the items, especially if the items
167
+ # are complex object themselves. We recursively traverse each individual
168
+ # property as well and fix the "type" value.
169
+ for _, value in value_dict["items"]["properties"].items():
170
+ _update_type_string(value)
171
+
172
+
154
173
  def function_declaration_to_tool_param(
155
174
  function_declaration: types.FunctionDeclaration,
156
175
  ) -> anthropic_types.ToolParam:
@@ -163,8 +182,7 @@ def function_declaration_to_tool_param(
163
182
  ):
164
183
  for key, value in function_declaration.parameters.properties.items():
165
184
  value_dict = value.model_dump(exclude_none=True)
166
- if "type" in value_dict:
167
- value_dict["type"] = value_dict["type"].lower()
185
+ _update_type_string(value_dict)
168
186
  properties[key] = value_dict
169
187
 
170
188
  return anthropic_types.ToolParam(