google-adk 0.3.0__py3-none-any.whl → 0.5.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/adk/agents/active_streaming_tool.py +1 -0
- google/adk/agents/base_agent.py +27 -29
- google/adk/agents/callback_context.py +4 -4
- google/adk/agents/invocation_context.py +1 -0
- google/adk/agents/langgraph_agent.py +1 -0
- google/adk/agents/live_request_queue.py +1 -0
- google/adk/agents/llm_agent.py +54 -14
- google/adk/agents/run_config.py +4 -0
- google/adk/agents/transcription_entry.py +1 -0
- google/adk/artifacts/base_artifact_service.py +5 -10
- google/adk/artifacts/gcs_artifact_service.py +8 -8
- google/adk/artifacts/in_memory_artifact_service.py +5 -5
- google/adk/auth/auth_credential.py +4 -5
- google/adk/cli/browser/index.html +1 -1
- google/adk/cli/browser/{main-HWIBUY2R.js → main-ULN5R5I5.js} +40 -39
- google/adk/cli/cli.py +54 -47
- google/adk/cli/cli_eval.py +13 -11
- google/adk/cli/cli_tools_click.py +58 -7
- google/adk/cli/fast_api.py +11 -11
- google/adk/cli/fast_api.py.orig +728 -0
- google/adk/evaluation/agent_evaluator.py +3 -3
- google/adk/evaluation/evaluation_constants.py +1 -0
- google/adk/evaluation/evaluation_generator.py +5 -5
- google/adk/evaluation/response_evaluator.py +1 -1
- google/adk/events/event.py +1 -0
- google/adk/events/event_actions.py +10 -4
- google/adk/examples/example.py +1 -0
- google/adk/flows/__init__.py +0 -1
- google/adk/flows/llm_flows/_code_execution.py +10 -10
- google/adk/flows/llm_flows/base_llm_flow.py +40 -15
- google/adk/flows/llm_flows/basic.py +3 -0
- google/adk/flows/llm_flows/contents.py +9 -5
- google/adk/flows/llm_flows/functions.py +38 -16
- google/adk/flows/llm_flows/instructions.py +17 -6
- google/adk/memory/base_memory_service.py +4 -2
- google/adk/memory/in_memory_memory_service.py +2 -2
- google/adk/memory/vertex_ai_rag_memory_service.py +2 -2
- google/adk/models/anthropic_llm.py +20 -2
- google/adk/models/base_llm.py +45 -4
- google/adk/models/gemini_llm_connection.py +14 -1
- google/adk/models/google_llm.py +0 -42
- google/adk/models/lite_llm.py +17 -17
- google/adk/models/llm_request.py +1 -1
- google/adk/models/llm_response.py +1 -1
- google/adk/runners.py +5 -5
- google/adk/sessions/_session_util.py +43 -0
- google/adk/sessions/base_session_service.py +3 -0
- google/adk/sessions/database_session_service.py +63 -46
- google/adk/sessions/in_memory_session_service.py +3 -3
- google/adk/sessions/session.py +1 -0
- google/adk/sessions/vertex_ai_session_service.py +7 -5
- google/adk/tools/agent_tool.py +7 -4
- google/adk/tools/application_integration_tool/__init__.py +2 -0
- google/adk/tools/application_integration_tool/application_integration_toolset.py +48 -26
- google/adk/tools/application_integration_tool/clients/connections_client.py +33 -77
- google/adk/tools/application_integration_tool/integration_connector_tool.py +159 -0
- google/adk/tools/function_tool.py +42 -0
- google/adk/tools/load_artifacts_tool.py +4 -4
- google/adk/tools/load_memory_tool.py +4 -2
- google/adk/tools/mcp_tool/conversion_utils.py +1 -1
- google/adk/tools/mcp_tool/mcp_session_manager.py +14 -0
- google/adk/tools/openapi_tool/common/common.py +2 -5
- google/adk/tools/openapi_tool/openapi_spec_parser/operation_parser.py +13 -3
- google/adk/tools/preload_memory_tool.py +1 -1
- google/adk/tools/tool_context.py +4 -4
- google/adk/version.py +1 -1
- {google_adk-0.3.0.dist-info → google_adk-0.5.0.dist-info}/METADATA +3 -7
- {google_adk-0.3.0.dist-info → google_adk-0.5.0.dist-info}/RECORD +71 -68
- {google_adk-0.3.0.dist-info → google_adk-0.5.0.dist-info}/WHEEL +0 -0
- {google_adk-0.3.0.dist-info → google_adk-0.5.0.dist-info}/entry_points.txt +0 -0
- {google_adk-0.3.0.dist-info → google_adk-0.5.0.dist-info}/licenses/LICENSE +0 -0
@@ -76,7 +76,7 @@ class AgentEvaluator:
|
|
76
76
|
return DEFAULT_CRITERIA
|
77
77
|
|
78
78
|
@staticmethod
|
79
|
-
def evaluate(
|
79
|
+
async def evaluate(
|
80
80
|
agent_module,
|
81
81
|
eval_dataset_file_path_or_dir,
|
82
82
|
num_runs=NUM_RUNS,
|
@@ -120,7 +120,7 @@ class AgentEvaluator:
|
|
120
120
|
|
121
121
|
AgentEvaluator._validate_input([dataset], criteria)
|
122
122
|
|
123
|
-
evaluation_response = AgentEvaluator._generate_responses(
|
123
|
+
evaluation_response = await AgentEvaluator._generate_responses(
|
124
124
|
agent_module,
|
125
125
|
[dataset],
|
126
126
|
num_runs,
|
@@ -246,7 +246,7 @@ class AgentEvaluator:
|
|
246
246
|
return inferred_criteria
|
247
247
|
|
248
248
|
@staticmethod
|
249
|
-
def _generate_responses(
|
249
|
+
async def _generate_responses(
|
250
250
|
agent_module, eval_dataset, num_runs, agent_name=None, initial_session={}
|
251
251
|
):
|
252
252
|
"""Generates evaluation responses by running the agent module multiple times."""
|
@@ -32,7 +32,7 @@ class EvaluationGenerator:
|
|
32
32
|
"""Generates evaluation responses for agents."""
|
33
33
|
|
34
34
|
@staticmethod
|
35
|
-
def generate_responses(
|
35
|
+
async def generate_responses(
|
36
36
|
eval_dataset,
|
37
37
|
agent_module_path,
|
38
38
|
repeat_num=3,
|
@@ -107,7 +107,7 @@ class EvaluationGenerator:
|
|
107
107
|
)
|
108
108
|
|
109
109
|
@staticmethod
|
110
|
-
def _process_query_with_root_agent(
|
110
|
+
async def _process_query_with_root_agent(
|
111
111
|
data,
|
112
112
|
root_agent,
|
113
113
|
reset_func,
|
@@ -128,7 +128,7 @@ class EvaluationGenerator:
|
|
128
128
|
all_mock_tools.add(expected[EvalConstants.TOOL_NAME])
|
129
129
|
|
130
130
|
eval_data_copy = data.copy()
|
131
|
-
EvaluationGenerator.apply_before_tool_callback(
|
131
|
+
await EvaluationGenerator.apply_before_tool_callback(
|
132
132
|
root_agent,
|
133
133
|
lambda *args: EvaluationGenerator.before_tool_callback(
|
134
134
|
*args, eval_dataset=eval_data_copy
|
@@ -247,7 +247,7 @@ class EvaluationGenerator:
|
|
247
247
|
return None
|
248
248
|
|
249
249
|
@staticmethod
|
250
|
-
def apply_before_tool_callback(
|
250
|
+
async def apply_before_tool_callback(
|
251
251
|
agent: BaseAgent,
|
252
252
|
callback: BeforeToolCallback,
|
253
253
|
all_mock_tools: set[str],
|
@@ -265,6 +265,6 @@ class EvaluationGenerator:
|
|
265
265
|
|
266
266
|
# Apply recursively to subagents if they exist
|
267
267
|
for sub_agent in agent.sub_agents:
|
268
|
-
EvaluationGenerator.apply_before_tool_callback(
|
268
|
+
await EvaluationGenerator.apply_before_tool_callback(
|
269
269
|
sub_agent, callback, all_mock_tools
|
270
270
|
)
|
@@ -28,7 +28,7 @@ class ResponseEvaluator:
|
|
28
28
|
raw_eval_dataset: list[list[dict[str, Any]]],
|
29
29
|
evaluation_criteria: list[str],
|
30
30
|
*,
|
31
|
-
print_detailed_results: bool = False
|
31
|
+
print_detailed_results: bool = False,
|
32
32
|
):
|
33
33
|
r"""Returns the value of requested evaluation metrics.
|
34
34
|
|
google/adk/events/event.py
CHANGED
@@ -27,6 +27,7 @@ class EventActions(BaseModel):
|
|
27
27
|
"""Represents the actions attached to an event."""
|
28
28
|
|
29
29
|
model_config = ConfigDict(extra='forbid')
|
30
|
+
"""The pydantic model config."""
|
30
31
|
|
31
32
|
skip_summarization: Optional[bool] = None
|
32
33
|
"""If true, it won't call model to summarize function response.
|
@@ -48,8 +49,13 @@ class EventActions(BaseModel):
|
|
48
49
|
"""The agent is escalating to a higher level agent."""
|
49
50
|
|
50
51
|
requested_auth_configs: dict[str, AuthConfig] = Field(default_factory=dict)
|
51
|
-
"""
|
52
|
-
|
53
|
-
|
54
|
-
|
52
|
+
"""Authentication configurations requested by tool responses.
|
53
|
+
|
54
|
+
This field will only be set by a tool response event indicating tool request
|
55
|
+
auth credential.
|
56
|
+
- Keys: The function call id. Since one function response event could contain
|
57
|
+
multiple function responses that correspond to multiple function calls. Each
|
58
|
+
function call could request different auth configs. This id is used to
|
59
|
+
identify the function call.
|
60
|
+
- Values: The requested auth config.
|
55
61
|
"""
|
google/adk/examples/example.py
CHANGED
google/adk/flows/__init__.py
CHANGED
@@ -122,7 +122,7 @@ class _CodeExecutionRequestProcessor(BaseLlmRequestProcessor):
|
|
122
122
|
if not invocation_context.agent.code_executor:
|
123
123
|
return
|
124
124
|
|
125
|
-
for event in _run_pre_processor(invocation_context, llm_request):
|
125
|
+
async for event in _run_pre_processor(invocation_context, llm_request):
|
126
126
|
yield event
|
127
127
|
|
128
128
|
# Convert the code execution parts to text parts.
|
@@ -152,17 +152,17 @@ class _CodeExecutionResponseProcessor(BaseLlmResponseProcessor):
|
|
152
152
|
if llm_response.partial:
|
153
153
|
return
|
154
154
|
|
155
|
-
for event in _run_post_processor(invocation_context, llm_response):
|
155
|
+
async for event in _run_post_processor(invocation_context, llm_response):
|
156
156
|
yield event
|
157
157
|
|
158
158
|
|
159
159
|
response_processor = _CodeExecutionResponseProcessor()
|
160
160
|
|
161
161
|
|
162
|
-
def _run_pre_processor(
|
162
|
+
async def _run_pre_processor(
|
163
163
|
invocation_context: InvocationContext,
|
164
164
|
llm_request: LlmRequest,
|
165
|
-
) ->
|
165
|
+
) -> AsyncGenerator[Event, None]:
|
166
166
|
"""Pre-process the user message by adding the user message to the Colab notebook."""
|
167
167
|
from ...agents.llm_agent import LlmAgent
|
168
168
|
|
@@ -242,17 +242,17 @@ def _run_pre_processor(
|
|
242
242
|
code_executor_context.add_processed_file_names([file.name])
|
243
243
|
|
244
244
|
# Emit the execution result, and add it to the LLM request.
|
245
|
-
execution_result_event = _post_process_code_execution_result(
|
245
|
+
execution_result_event = await _post_process_code_execution_result(
|
246
246
|
invocation_context, code_executor_context, code_execution_result
|
247
247
|
)
|
248
248
|
yield execution_result_event
|
249
249
|
llm_request.contents.append(copy.deepcopy(execution_result_event.content))
|
250
250
|
|
251
251
|
|
252
|
-
def _run_post_processor(
|
252
|
+
async def _run_post_processor(
|
253
253
|
invocation_context: InvocationContext,
|
254
254
|
llm_response,
|
255
|
-
) ->
|
255
|
+
) -> AsyncGenerator[Event, None]:
|
256
256
|
"""Post-process the model response by extracting and executing the first code block."""
|
257
257
|
agent = invocation_context.agent
|
258
258
|
code_executor = agent.code_executor
|
@@ -305,7 +305,7 @@ def _run_post_processor(
|
|
305
305
|
code_execution_result.stdout,
|
306
306
|
code_execution_result.stderr,
|
307
307
|
)
|
308
|
-
yield _post_process_code_execution_result(
|
308
|
+
yield await _post_process_code_execution_result(
|
309
309
|
invocation_context, code_executor_context, code_execution_result
|
310
310
|
)
|
311
311
|
|
@@ -375,7 +375,7 @@ def _get_or_set_execution_id(
|
|
375
375
|
return execution_id
|
376
376
|
|
377
377
|
|
378
|
-
def _post_process_code_execution_result(
|
378
|
+
async def _post_process_code_execution_result(
|
379
379
|
invocation_context: InvocationContext,
|
380
380
|
code_executor_context: CodeExecutorContext,
|
381
381
|
code_execution_result: CodeExecutionResult,
|
@@ -406,7 +406,7 @@ def _post_process_code_execution_result(
|
|
406
406
|
|
407
407
|
# Handle output files.
|
408
408
|
for output_file in code_execution_result.output_files:
|
409
|
-
version = invocation_context.artifact_service.save_artifact(
|
409
|
+
version = await invocation_context.artifact_service.save_artifact(
|
410
410
|
app_name=invocation_context.app_name,
|
411
411
|
user_id=invocation_context.user_id,
|
412
412
|
session_id=invocation_context.session.id,
|
@@ -16,6 +16,7 @@ from __future__ import annotations
|
|
16
16
|
|
17
17
|
from abc import ABC
|
18
18
|
import asyncio
|
19
|
+
import inspect
|
19
20
|
import logging
|
20
21
|
from typing import AsyncGenerator
|
21
22
|
from typing import cast
|
@@ -190,6 +191,17 @@ class BaseLlmFlow(ABC):
|
|
190
191
|
llm_request: LlmRequest,
|
191
192
|
) -> AsyncGenerator[Event, None]:
|
192
193
|
"""Receive data from model and process events using BaseLlmConnection."""
|
194
|
+
def get_author(llm_response):
|
195
|
+
"""Get the author of the event.
|
196
|
+
|
197
|
+
When the model returns transcription, the author is "user". Otherwise, the
|
198
|
+
author is the agent.
|
199
|
+
"""
|
200
|
+
if llm_response and llm_response.content and llm_response.content.role == "user":
|
201
|
+
return "user"
|
202
|
+
else:
|
203
|
+
return invocation_context.agent.name
|
204
|
+
|
193
205
|
assert invocation_context.live_request_queue
|
194
206
|
try:
|
195
207
|
while True:
|
@@ -197,7 +209,7 @@ class BaseLlmFlow(ABC):
|
|
197
209
|
model_response_event = Event(
|
198
210
|
id=Event.new_id(),
|
199
211
|
invocation_id=invocation_context.invocation_id,
|
200
|
-
author=
|
212
|
+
author=get_author(llm_response),
|
201
213
|
)
|
202
214
|
async for event in self._postprocess_live(
|
203
215
|
invocation_context,
|
@@ -249,7 +261,6 @@ class BaseLlmFlow(ABC):
|
|
249
261
|
|
250
262
|
# Calls the LLM.
|
251
263
|
model_response_event = Event(
|
252
|
-
id=Event.new_id(),
|
253
264
|
invocation_id=invocation_context.invocation_id,
|
254
265
|
author=invocation_context.agent.name,
|
255
266
|
branch=invocation_context.branch,
|
@@ -261,6 +272,8 @@ class BaseLlmFlow(ABC):
|
|
261
272
|
async for event in self._postprocess_async(
|
262
273
|
invocation_context, llm_request, llm_response, model_response_event
|
263
274
|
):
|
275
|
+
# Use a new id for every event.
|
276
|
+
event.id = Event.new_id()
|
264
277
|
yield event
|
265
278
|
|
266
279
|
async def _preprocess_async(
|
@@ -437,7 +450,7 @@ class BaseLlmFlow(ABC):
|
|
437
450
|
model_response_event: Event,
|
438
451
|
) -> AsyncGenerator[LlmResponse, None]:
|
439
452
|
# Runs before_model_callback if it exists.
|
440
|
-
if response := self._handle_before_model_callback(
|
453
|
+
if response := await self._handle_before_model_callback(
|
441
454
|
invocation_context, llm_request, model_response_event
|
442
455
|
):
|
443
456
|
yield response
|
@@ -450,7 +463,7 @@ class BaseLlmFlow(ABC):
|
|
450
463
|
invocation_context.live_request_queue = LiveRequestQueue()
|
451
464
|
async for llm_response in self.run_live(invocation_context):
|
452
465
|
# Runs after_model_callback if it exists.
|
453
|
-
if altered_llm_response := self._handle_after_model_callback(
|
466
|
+
if altered_llm_response := await self._handle_after_model_callback(
|
454
467
|
invocation_context, llm_response, model_response_event
|
455
468
|
):
|
456
469
|
llm_response = altered_llm_response
|
@@ -479,14 +492,14 @@ class BaseLlmFlow(ABC):
|
|
479
492
|
llm_response,
|
480
493
|
)
|
481
494
|
# Runs after_model_callback if it exists.
|
482
|
-
if altered_llm_response := self._handle_after_model_callback(
|
495
|
+
if altered_llm_response := await self._handle_after_model_callback(
|
483
496
|
invocation_context, llm_response, model_response_event
|
484
497
|
):
|
485
498
|
llm_response = altered_llm_response
|
486
499
|
|
487
500
|
yield llm_response
|
488
501
|
|
489
|
-
def _handle_before_model_callback(
|
502
|
+
async def _handle_before_model_callback(
|
490
503
|
self,
|
491
504
|
invocation_context: InvocationContext,
|
492
505
|
llm_request: LlmRequest,
|
@@ -498,17 +511,23 @@ class BaseLlmFlow(ABC):
|
|
498
511
|
if not isinstance(agent, LlmAgent):
|
499
512
|
return
|
500
513
|
|
501
|
-
if not agent.
|
514
|
+
if not agent.canonical_before_model_callbacks:
|
502
515
|
return
|
503
516
|
|
504
517
|
callback_context = CallbackContext(
|
505
518
|
invocation_context, event_actions=model_response_event.actions
|
506
519
|
)
|
507
|
-
return agent.before_model_callback(
|
508
|
-
callback_context=callback_context, llm_request=llm_request
|
509
|
-
)
|
510
520
|
|
511
|
-
|
521
|
+
for callback in agent.canonical_before_model_callbacks:
|
522
|
+
before_model_callback_content = callback(
|
523
|
+
callback_context=callback_context, llm_request=llm_request
|
524
|
+
)
|
525
|
+
if inspect.isawaitable(before_model_callback_content):
|
526
|
+
before_model_callback_content = await before_model_callback_content
|
527
|
+
if before_model_callback_content:
|
528
|
+
return before_model_callback_content
|
529
|
+
|
530
|
+
async def _handle_after_model_callback(
|
512
531
|
self,
|
513
532
|
invocation_context: InvocationContext,
|
514
533
|
llm_response: LlmResponse,
|
@@ -520,15 +539,21 @@ class BaseLlmFlow(ABC):
|
|
520
539
|
if not isinstance(agent, LlmAgent):
|
521
540
|
return
|
522
541
|
|
523
|
-
if not agent.
|
542
|
+
if not agent.canonical_after_model_callbacks:
|
524
543
|
return
|
525
544
|
|
526
545
|
callback_context = CallbackContext(
|
527
546
|
invocation_context, event_actions=model_response_event.actions
|
528
547
|
)
|
529
|
-
|
530
|
-
|
531
|
-
|
548
|
+
|
549
|
+
for callback in agent.canonical_after_model_callbacks:
|
550
|
+
after_model_callback_content = callback(
|
551
|
+
callback_context=callback_context, llm_response=llm_response
|
552
|
+
)
|
553
|
+
if inspect.isawaitable(after_model_callback_content):
|
554
|
+
after_model_callback_content = await after_model_callback_content
|
555
|
+
if after_model_callback_content:
|
556
|
+
return after_model_callback_content
|
532
557
|
|
533
558
|
def _finalize_model_response_event(
|
534
559
|
self,
|
@@ -62,6 +62,9 @@ class _BasicLlmRequestProcessor(BaseLlmRequestProcessor):
|
|
62
62
|
llm_request.live_connect_config.output_audio_transcription = (
|
63
63
|
invocation_context.run_config.output_audio_transcription
|
64
64
|
)
|
65
|
+
llm_request.live_connect_config.input_audio_transcription = (
|
66
|
+
invocation_context.run_config.input_audio_transcription
|
67
|
+
)
|
65
68
|
|
66
69
|
# TODO: handle tool append here, instead of in BaseTool.process_llm_request.
|
67
70
|
|
@@ -15,9 +15,7 @@
|
|
15
15
|
from __future__ import annotations
|
16
16
|
|
17
17
|
import copy
|
18
|
-
from typing import AsyncGenerator
|
19
|
-
from typing import Generator
|
20
|
-
from typing import Optional
|
18
|
+
from typing import AsyncGenerator, Generator, Optional
|
21
19
|
|
22
20
|
from google.genai import types
|
23
21
|
from typing_extensions import override
|
@@ -202,8 +200,14 @@ def _get_contents(
|
|
202
200
|
# Parse the events, leaving the contents and the function calls and
|
203
201
|
# responses from the current agent.
|
204
202
|
for event in events:
|
205
|
-
if
|
206
|
-
|
203
|
+
if (
|
204
|
+
not event.content
|
205
|
+
or not event.content.role
|
206
|
+
or not event.content.parts
|
207
|
+
or event.content.parts[0].text == ''
|
208
|
+
):
|
209
|
+
# Skip events without content, or generated neither by user nor by model
|
210
|
+
# or has empty text.
|
207
211
|
# E.g. events purely for mutating session states.
|
208
212
|
continue
|
209
213
|
if not _is_event_belongs_to_branch(current_branch, event):
|
@@ -151,28 +151,33 @@ async def handle_function_calls_async(
|
|
151
151
|
# do not use "args" as the variable name, because it is a reserved keyword
|
152
152
|
# in python debugger.
|
153
153
|
function_args = function_call.args or {}
|
154
|
-
function_response = None
|
155
|
-
|
154
|
+
function_response: Optional[dict] = None
|
155
|
+
|
156
|
+
# before_tool_callback (sync or async)
|
156
157
|
if agent.before_tool_callback:
|
157
158
|
function_response = agent.before_tool_callback(
|
158
159
|
tool=tool, args=function_args, tool_context=tool_context
|
159
160
|
)
|
161
|
+
if inspect.isawaitable(function_response):
|
162
|
+
function_response = await function_response
|
160
163
|
|
161
164
|
if not function_response:
|
162
165
|
function_response = await __call_tool_async(
|
163
166
|
tool, args=function_args, tool_context=tool_context
|
164
167
|
)
|
165
168
|
|
166
|
-
#
|
169
|
+
# after_tool_callback (sync or async)
|
167
170
|
if agent.after_tool_callback:
|
168
|
-
|
171
|
+
altered_function_response = agent.after_tool_callback(
|
169
172
|
tool=tool,
|
170
173
|
args=function_args,
|
171
174
|
tool_context=tool_context,
|
172
175
|
tool_response=function_response,
|
173
176
|
)
|
174
|
-
if
|
175
|
-
|
177
|
+
if inspect.isawaitable(altered_function_response):
|
178
|
+
altered_function_response = await altered_function_response
|
179
|
+
if altered_function_response is not None:
|
180
|
+
function_response = altered_function_response
|
176
181
|
|
177
182
|
if tool.is_long_running:
|
178
183
|
# Allow long running function to return None to not provide function response.
|
@@ -223,11 +228,17 @@ async def handle_function_calls_live(
|
|
223
228
|
# in python debugger.
|
224
229
|
function_args = function_call.args or {}
|
225
230
|
function_response = None
|
226
|
-
# Calls the tool if before_tool_callback does not exist or returns None.
|
231
|
+
# # Calls the tool if before_tool_callback does not exist or returns None.
|
232
|
+
# if agent.before_tool_callback:
|
233
|
+
# function_response = agent.before_tool_callback(
|
234
|
+
# tool, function_args, tool_context
|
235
|
+
# )
|
227
236
|
if agent.before_tool_callback:
|
228
237
|
function_response = agent.before_tool_callback(
|
229
|
-
tool, function_args, tool_context
|
238
|
+
tool=tool, args=function_args, tool_context=tool_context
|
230
239
|
)
|
240
|
+
if inspect.isawaitable(function_response):
|
241
|
+
function_response = await function_response
|
231
242
|
|
232
243
|
if not function_response:
|
233
244
|
function_response = await _process_function_live_helper(
|
@@ -235,15 +246,26 @@ async def handle_function_calls_live(
|
|
235
246
|
)
|
236
247
|
|
237
248
|
# Calls after_tool_callback if it exists.
|
249
|
+
# if agent.after_tool_callback:
|
250
|
+
# new_response = agent.after_tool_callback(
|
251
|
+
# tool,
|
252
|
+
# function_args,
|
253
|
+
# tool_context,
|
254
|
+
# function_response,
|
255
|
+
# )
|
256
|
+
# if new_response:
|
257
|
+
# function_response = new_response
|
238
258
|
if agent.after_tool_callback:
|
239
|
-
|
240
|
-
tool,
|
241
|
-
function_args,
|
242
|
-
tool_context,
|
243
|
-
function_response,
|
259
|
+
altered_function_response = agent.after_tool_callback(
|
260
|
+
tool=tool,
|
261
|
+
args=function_args,
|
262
|
+
tool_context=tool_context,
|
263
|
+
tool_response=function_response,
|
244
264
|
)
|
245
|
-
if
|
246
|
-
|
265
|
+
if inspect.isawaitable(altered_function_response):
|
266
|
+
altered_function_response = await altered_function_response
|
267
|
+
if altered_function_response is not None:
|
268
|
+
function_response = altered_function_response
|
247
269
|
|
248
270
|
if tool.is_long_running:
|
249
271
|
# Allow async function to return None to not provide function response.
|
@@ -310,7 +332,7 @@ async def _process_function_live_helper(
|
|
310
332
|
function_response = {
|
311
333
|
'status': f'No active streaming function named {function_name} found'
|
312
334
|
}
|
313
|
-
elif hasattr(tool,
|
335
|
+
elif hasattr(tool, 'func') and inspect.isasyncgenfunction(tool.func):
|
314
336
|
# for streaming tool use case
|
315
337
|
# we require the function to be a async generator function
|
316
338
|
async def run_tool_and_update_queue(tool, function_args, tool_context):
|
@@ -56,13 +56,13 @@ class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor):
|
|
56
56
|
raw_si = root_agent.canonical_global_instruction(
|
57
57
|
ReadonlyContext(invocation_context)
|
58
58
|
)
|
59
|
-
si = _populate_values(raw_si, invocation_context)
|
59
|
+
si = await _populate_values(raw_si, invocation_context)
|
60
60
|
llm_request.append_instructions([si])
|
61
61
|
|
62
62
|
# Appends agent instructions if set.
|
63
63
|
if agent.instruction: # not empty str
|
64
64
|
raw_si = agent.canonical_instruction(ReadonlyContext(invocation_context))
|
65
|
-
si = _populate_values(raw_si, invocation_context)
|
65
|
+
si = await _populate_values(raw_si, invocation_context)
|
66
66
|
llm_request.append_instructions([si])
|
67
67
|
|
68
68
|
# Maintain async generator behavior
|
@@ -73,13 +73,24 @@ class _InstructionsLlmRequestProcessor(BaseLlmRequestProcessor):
|
|
73
73
|
request_processor = _InstructionsLlmRequestProcessor()
|
74
74
|
|
75
75
|
|
76
|
-
def _populate_values(
|
76
|
+
async def _populate_values(
|
77
77
|
instruction_template: str,
|
78
78
|
context: InvocationContext,
|
79
79
|
) -> str:
|
80
80
|
"""Populates values in the instruction template, e.g. state, artifact, etc."""
|
81
81
|
|
82
|
-
def
|
82
|
+
async def _async_sub(pattern, repl_async_fn, string) -> str:
|
83
|
+
result = []
|
84
|
+
last_end = 0
|
85
|
+
for match in re.finditer(pattern, string):
|
86
|
+
result.append(string[last_end : match.start()])
|
87
|
+
replacement = await repl_async_fn(match)
|
88
|
+
result.append(replacement)
|
89
|
+
last_end = match.end()
|
90
|
+
result.append(string[last_end:])
|
91
|
+
return ''.join(result)
|
92
|
+
|
93
|
+
async def _replace_match(match) -> str:
|
83
94
|
var_name = match.group().lstrip('{').rstrip('}').strip()
|
84
95
|
optional = False
|
85
96
|
if var_name.endswith('?'):
|
@@ -89,7 +100,7 @@ def _populate_values(
|
|
89
100
|
var_name = var_name.removeprefix('artifact.')
|
90
101
|
if context.artifact_service is None:
|
91
102
|
raise ValueError('Artifact service is not initialized.')
|
92
|
-
artifact = context.artifact_service.load_artifact(
|
103
|
+
artifact = await context.artifact_service.load_artifact(
|
93
104
|
app_name=context.session.app_name,
|
94
105
|
user_id=context.session.user_id,
|
95
106
|
session_id=context.session.id,
|
@@ -109,7 +120,7 @@ def _populate_values(
|
|
109
120
|
else:
|
110
121
|
raise KeyError(f'Context variable not found: `{var_name}`.')
|
111
122
|
|
112
|
-
return
|
123
|
+
return await _async_sub(r'{+[^{}]*}+', _replace_match, instruction_template)
|
113
124
|
|
114
125
|
|
115
126
|
def _is_valid_state_name(var_name):
|
@@ -28,6 +28,7 @@ class MemoryResult(BaseModel):
|
|
28
28
|
session_id: The session id associated with the memory.
|
29
29
|
events: A list of events in the session.
|
30
30
|
"""
|
31
|
+
|
31
32
|
session_id: str
|
32
33
|
events: list[Event]
|
33
34
|
|
@@ -38,6 +39,7 @@ class SearchMemoryResponse(BaseModel):
|
|
38
39
|
Attributes:
|
39
40
|
memories: A list of memory results matching the search query.
|
40
41
|
"""
|
42
|
+
|
41
43
|
memories: list[MemoryResult] = Field(default_factory=list)
|
42
44
|
|
43
45
|
|
@@ -49,7 +51,7 @@ class BaseMemoryService(abc.ABC):
|
|
49
51
|
"""
|
50
52
|
|
51
53
|
@abc.abstractmethod
|
52
|
-
def add_session_to_memory(self, session: Session):
|
54
|
+
async def add_session_to_memory(self, session: Session):
|
53
55
|
"""Adds a session to the memory service.
|
54
56
|
|
55
57
|
A session may be added multiple times during its lifetime.
|
@@ -59,7 +61,7 @@ class BaseMemoryService(abc.ABC):
|
|
59
61
|
"""
|
60
62
|
|
61
63
|
@abc.abstractmethod
|
62
|
-
def search_memory(
|
64
|
+
async def search_memory(
|
63
65
|
self, *, app_name: str, user_id: str, query: str
|
64
66
|
) -> SearchMemoryResponse:
|
65
67
|
"""Searches for sessions that match the query.
|
@@ -29,13 +29,13 @@ class InMemoryMemoryService(BaseMemoryService):
|
|
29
29
|
self.session_events: dict[str, list[Event]] = {}
|
30
30
|
"""keys are app_name/user_id/session_id"""
|
31
31
|
|
32
|
-
def add_session_to_memory(self, session: Session):
|
32
|
+
async def add_session_to_memory(self, session: Session):
|
33
33
|
key = f'{session.app_name}/{session.user_id}/{session.id}'
|
34
34
|
self.session_events[key] = [
|
35
35
|
event for event in session.events if event.content
|
36
36
|
]
|
37
37
|
|
38
|
-
def search_memory(
|
38
|
+
async def search_memory(
|
39
39
|
self, *, app_name: str, user_id: str, query: str
|
40
40
|
) -> SearchMemoryResponse:
|
41
41
|
"""Prototyping purpose only."""
|
@@ -54,7 +54,7 @@ class VertexAiRagMemoryService(BaseMemoryService):
|
|
54
54
|
)
|
55
55
|
|
56
56
|
@override
|
57
|
-
def add_session_to_memory(self, session: Session):
|
57
|
+
async def add_session_to_memory(self, session: Session):
|
58
58
|
with tempfile.NamedTemporaryFile(
|
59
59
|
mode="w", delete=False, suffix=".txt"
|
60
60
|
) as temp_file:
|
@@ -91,7 +91,7 @@ class VertexAiRagMemoryService(BaseMemoryService):
|
|
91
91
|
os.remove(temp_file_path)
|
92
92
|
|
93
93
|
@override
|
94
|
-
def search_memory(
|
94
|
+
async def search_memory(
|
95
95
|
self, *, app_name: str, user_id: str, query: str
|
96
96
|
) -> SearchMemoryResponse:
|
97
97
|
"""Searches for sessions that match the query using rag.retrieval_query."""
|
@@ -19,6 +19,7 @@ from __future__ import annotations
|
|
19
19
|
from functools import cached_property
|
20
20
|
import logging
|
21
21
|
import os
|
22
|
+
from typing import Any
|
22
23
|
from typing import AsyncGenerator
|
23
24
|
from typing import Generator
|
24
25
|
from typing import Iterable
|
@@ -151,6 +152,24 @@ def message_to_generate_content_response(
|
|
151
152
|
)
|
152
153
|
|
153
154
|
|
155
|
+
def _update_type_string(value_dict: dict[str, Any]):
|
156
|
+
"""Updates 'type' field to expected JSON schema format."""
|
157
|
+
if "type" in value_dict:
|
158
|
+
value_dict["type"] = value_dict["type"].lower()
|
159
|
+
|
160
|
+
if "items" in value_dict:
|
161
|
+
# 'type' field could exist for items as well, this would be the case if
|
162
|
+
# items represent primitive types.
|
163
|
+
_update_type_string(value_dict["items"])
|
164
|
+
|
165
|
+
if "properties" in value_dict["items"]:
|
166
|
+
# There could be properties as well on the items, especially if the items
|
167
|
+
# are complex object themselves. We recursively traverse each individual
|
168
|
+
# property as well and fix the "type" value.
|
169
|
+
for _, value in value_dict["items"]["properties"].items():
|
170
|
+
_update_type_string(value)
|
171
|
+
|
172
|
+
|
154
173
|
def function_declaration_to_tool_param(
|
155
174
|
function_declaration: types.FunctionDeclaration,
|
156
175
|
) -> anthropic_types.ToolParam:
|
@@ -163,8 +182,7 @@ def function_declaration_to_tool_param(
|
|
163
182
|
):
|
164
183
|
for key, value in function_declaration.parameters.properties.items():
|
165
184
|
value_dict = value.model_dump(exclude_none=True)
|
166
|
-
|
167
|
-
value_dict["type"] = value_dict["type"].lower()
|
185
|
+
_update_type_string(value_dict)
|
168
186
|
properties[key] = value_dict
|
169
187
|
|
170
188
|
return anthropic_types.ToolParam(
|