google-adk 0.3.0__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -44,7 +44,7 @@ Args:
44
44
  callback_context: MUST be named 'callback_context' (enforced).
45
45
 
46
46
  Returns:
47
- The content to return to the user. When set, the agent run will skipped and
47
+ The content to return to the user. When set, the agent run will be skipped and
48
48
  the provided content will be returned to user.
49
49
  """
50
50
 
@@ -55,8 +55,8 @@ Args:
55
55
  callback_context: MUST be named 'callback_context' (enforced).
56
56
 
57
57
  Returns:
58
- The content to return to the user. When set, the agent run will skipped and
59
- the provided content will be appended to event history as agent response.
58
+ The content to return to the user. When set, the provided content will be
59
+ appended to event history as agent response.
60
60
  """
61
61
 
62
62
 
@@ -101,8 +101,8 @@ class BaseAgent(BaseModel):
101
101
  callback_context: MUST be named 'callback_context' (enforced).
102
102
 
103
103
  Returns:
104
- The content to return to the user. When set, the agent run will skipped and
105
- the provided content will be returned to user.
104
+ The content to return to the user. When set, the agent run will be skipped
105
+ and the provided content will be returned to user.
106
106
  """
107
107
  after_agent_callback: Optional[AfterAgentCallback] = None
108
108
  """Callback signature that is invoked after the agent run.
@@ -111,8 +111,8 @@ class BaseAgent(BaseModel):
111
111
  callback_context: MUST be named 'callback_context' (enforced).
112
112
 
113
113
  Returns:
114
- The content to return to the user. When set, the agent run will skipped and
115
- the provided content will be appended to event history as agent response.
114
+ The content to return to the user. When set, the provided content will be
115
+ appended to event history as agent response.
116
116
  """
117
117
 
118
118
  @final
@@ -15,12 +15,7 @@
15
15
  from __future__ import annotations
16
16
 
17
17
  import logging
18
- from typing import Any
19
- from typing import AsyncGenerator
20
- from typing import Callable
21
- from typing import Literal
22
- from typing import Optional
23
- from typing import Union
18
+ from typing import Any, AsyncGenerator, Awaitable, Callable, Literal, Optional, Union
24
19
 
25
20
  from google.genai import types
26
21
  from pydantic import BaseModel
@@ -62,11 +57,11 @@ AfterModelCallback: TypeAlias = Callable[
62
57
  ]
63
58
  BeforeToolCallback: TypeAlias = Callable[
64
59
  [BaseTool, dict[str, Any], ToolContext],
65
- Optional[dict],
60
+ Union[Awaitable[Optional[dict]], Optional[dict]],
66
61
  ]
67
62
  AfterToolCallback: TypeAlias = Callable[
68
63
  [BaseTool, dict[str, Any], ToolContext, dict],
69
- Optional[dict],
64
+ Union[Awaitable[Optional[dict]], Optional[dict]],
70
65
  ]
71
66
 
72
67
  InstructionProvider: TypeAlias = Callable[[ReadonlyContext], str]
google/adk/cli/cli.py CHANGED
@@ -39,12 +39,12 @@ class InputFile(BaseModel):
39
39
 
40
40
  async def run_input_file(
41
41
  app_name: str,
42
+ user_id: str,
42
43
  root_agent: LlmAgent,
43
44
  artifact_service: BaseArtifactService,
44
- session: Session,
45
45
  session_service: BaseSessionService,
46
46
  input_path: str,
47
- ) -> None:
47
+ ) -> Session:
48
48
  runner = Runner(
49
49
  app_name=app_name,
50
50
  agent=root_agent,
@@ -55,9 +55,11 @@ async def run_input_file(
55
55
  input_file = InputFile.model_validate_json(f.read())
56
56
  input_file.state['_time'] = datetime.now()
57
57
 
58
- session.state = input_file.state
58
+ session = session_service.create_session(
59
+ app_name=app_name, user_id=user_id, state=input_file.state
60
+ )
59
61
  for query in input_file.queries:
60
- click.echo(f'user: {query}')
62
+ click.echo(f'[user]: {query}')
61
63
  content = types.Content(role='user', parts=[types.Part(text=query)])
62
64
  async for event in runner.run_async(
63
65
  user_id=session.user_id, session_id=session.id, new_message=content
@@ -65,23 +67,23 @@ async def run_input_file(
65
67
  if event.content and event.content.parts:
66
68
  if text := ''.join(part.text or '' for part in event.content.parts):
67
69
  click.echo(f'[{event.author}]: {text}')
70
+ return session
68
71
 
69
72
 
70
73
  async def run_interactively(
71
- app_name: str,
72
74
  root_agent: LlmAgent,
73
75
  artifact_service: BaseArtifactService,
74
76
  session: Session,
75
77
  session_service: BaseSessionService,
76
78
  ) -> None:
77
79
  runner = Runner(
78
- app_name=app_name,
80
+ app_name=session.app_name,
79
81
  agent=root_agent,
80
82
  artifact_service=artifact_service,
81
83
  session_service=session_service,
82
84
  )
83
85
  while True:
84
- query = input('user: ')
86
+ query = input('[user]: ')
85
87
  if not query or not query.strip():
86
88
  continue
87
89
  if query == 'exit':
@@ -100,7 +102,8 @@ async def run_cli(
100
102
  *,
101
103
  agent_parent_dir: str,
102
104
  agent_folder_name: str,
103
- json_file_path: Optional[str] = None,
105
+ input_file: Optional[str] = None,
106
+ saved_session_file: Optional[str] = None,
104
107
  save_session: bool,
105
108
  ) -> None:
106
109
  """Runs an interactive CLI for a certain agent.
@@ -109,8 +112,11 @@ async def run_cli(
109
112
  agent_parent_dir: str, the absolute path of the parent folder of the agent
110
113
  folder.
111
114
  agent_folder_name: str, the name of the agent folder.
112
- json_file_path: Optional[str], the absolute path to the json file, either
113
- *.input.json or *.session.json.
115
+ input_file: Optional[str], the absolute path to the json file that contains
116
+ the initial session state and user queries, exclusive with
117
+ saved_session_file.
118
+ saved_session_file: Optional[str], the absolute path to the json file that
119
+ contains a previously saved session, exclusive with input_file.
114
120
  save_session: bool, whether to save the session on exit.
115
121
  """
116
122
  if agent_parent_dir not in sys.path:
@@ -118,46 +124,50 @@ async def run_cli(
118
124
 
119
125
  artifact_service = InMemoryArtifactService()
120
126
  session_service = InMemorySessionService()
121
- session = session_service.create_session(
122
- app_name=agent_folder_name, user_id='test_user'
123
- )
124
127
 
125
128
  agent_module_path = os.path.join(agent_parent_dir, agent_folder_name)
126
129
  agent_module = importlib.import_module(agent_folder_name)
130
+ user_id = 'test_user'
131
+ session = session_service.create_session(
132
+ app_name=agent_folder_name, user_id=user_id
133
+ )
127
134
  root_agent = agent_module.agent.root_agent
128
135
  envs.load_dotenv_for_agent(agent_folder_name, agent_parent_dir)
129
- if json_file_path:
130
- if json_file_path.endswith('.input.json'):
131
- await run_input_file(
132
- app_name=agent_folder_name,
133
- root_agent=root_agent,
134
- artifact_service=artifact_service,
135
- session=session,
136
- session_service=session_service,
137
- input_path=json_file_path,
138
- )
139
- elif json_file_path.endswith('.session.json'):
140
- with open(json_file_path, 'r') as f:
141
- session = Session.model_validate_json(f.read())
142
- for content in session.get_contents():
143
- if content.role == 'user':
144
- print('user: ', content.parts[0].text)
136
+ if input_file:
137
+ session = await run_input_file(
138
+ app_name=agent_folder_name,
139
+ user_id=user_id,
140
+ root_agent=root_agent,
141
+ artifact_service=artifact_service,
142
+ session_service=session_service,
143
+ input_path=input_file,
144
+ )
145
+ elif saved_session_file:
146
+
147
+ loaded_session = None
148
+ with open(saved_session_file, 'r') as f:
149
+ loaded_session = Session.model_validate_json(f.read())
150
+
151
+ if loaded_session:
152
+ for event in loaded_session.events:
153
+ session_service.append_event(session, event)
154
+ content = event.content
155
+ if not content or not content.parts or not content.parts[0].text:
156
+ continue
157
+ if event.author == 'user':
158
+ click.echo(f'[user]: {content.parts[0].text}')
145
159
  else:
146
- print(content.parts[0].text)
147
- await run_interactively(
148
- agent_folder_name,
149
- root_agent,
150
- artifact_service,
151
- session,
152
- session_service,
153
- )
154
- else:
155
- print(f'Unsupported file type: {json_file_path}')
156
- exit(1)
160
+ click.echo(f'[{event.author}]: {content.parts[0].text}')
161
+
162
+ await run_interactively(
163
+ root_agent,
164
+ artifact_service,
165
+ session,
166
+ session_service,
167
+ )
157
168
  else:
158
- print(f'Running agent {root_agent.name}, type exit to exit.')
169
+ click.echo(f'Running agent {root_agent.name}, type exit to exit.')
159
170
  await run_interactively(
160
- agent_folder_name,
161
171
  root_agent,
162
172
  artifact_service,
163
173
  session,
@@ -165,11 +175,8 @@ async def run_cli(
165
175
  )
166
176
 
167
177
  if save_session:
168
- if json_file_path:
169
- session_path = json_file_path.replace('.input.json', '.session.json')
170
- else:
171
- session_id = input('Session ID to save: ')
172
- session_path = f'{agent_module_path}/{session_id}.session.json'
178
+ session_id = input('Session ID to save: ')
179
+ session_path = f'{agent_module_path}/{session_id}.session.json'
173
180
 
174
181
  # Fetch the session again to get all the details.
175
182
  session = session_service.get_session(
@@ -96,6 +96,23 @@ def cli_create_cmd(
96
96
  )
97
97
 
98
98
 
99
+ def validate_exclusive(ctx, param, value):
100
+ # Store the validated parameters in the context
101
+ if not hasattr(ctx, "exclusive_opts"):
102
+ ctx.exclusive_opts = {}
103
+
104
+ # If this option has a value and we've already seen another exclusive option
105
+ if value is not None and any(ctx.exclusive_opts.values()):
106
+ exclusive_opt = next(key for key, val in ctx.exclusive_opts.items() if val)
107
+ raise click.UsageError(
108
+ f"Options '{param.name}' and '{exclusive_opt}' cannot be set together."
109
+ )
110
+
111
+ # Record this option's value
112
+ ctx.exclusive_opts[param.name] = value is not None
113
+ return value
114
+
115
+
99
116
  @main.command("run")
100
117
  @click.option(
101
118
  "--save_session",
@@ -105,13 +122,43 @@ def cli_create_cmd(
105
122
  default=False,
106
123
  help="Optional. Whether to save the session to a json file on exit.",
107
124
  )
125
+ @click.option(
126
+ "--replay",
127
+ type=click.Path(
128
+ exists=True, dir_okay=False, file_okay=True, resolve_path=True
129
+ ),
130
+ help=(
131
+ "The json file that contains the initial state of the session and user"
132
+ " queries. A new session will be created using this state. And user"
133
+ " queries are run againt the newly created session. Users cannot"
134
+ " continue to interact with the agent."
135
+ ),
136
+ callback=validate_exclusive,
137
+ )
138
+ @click.option(
139
+ "--resume",
140
+ type=click.Path(
141
+ exists=True, dir_okay=False, file_okay=True, resolve_path=True
142
+ ),
143
+ help=(
144
+ "The json file that contains a previously saved session (by"
145
+ "--save_session option). The previous session will be re-displayed. And"
146
+ " user can continue to interact with the agent."
147
+ ),
148
+ callback=validate_exclusive,
149
+ )
108
150
  @click.argument(
109
151
  "agent",
110
152
  type=click.Path(
111
153
  exists=True, dir_okay=True, file_okay=False, resolve_path=True
112
154
  ),
113
155
  )
114
- def cli_run(agent: str, save_session: bool):
156
+ def cli_run(
157
+ agent: str,
158
+ save_session: bool,
159
+ replay: Optional[str],
160
+ resume: Optional[str],
161
+ ):
115
162
  """Runs an interactive CLI for a certain agent.
116
163
 
117
164
  AGENT: The path to the agent source code folder.
@@ -129,6 +176,8 @@ def cli_run(agent: str, save_session: bool):
129
176
  run_cli(
130
177
  agent_parent_dir=agent_parent_folder,
131
178
  agent_folder_name=agent_folder_name,
179
+ input_file=replay,
180
+ saved_session_file=resume,
132
181
  save_session=save_session,
133
182
  )
134
183
  )
@@ -48,8 +48,13 @@ class EventActions(BaseModel):
48
48
  """The agent is escalating to a higher level agent."""
49
49
 
50
50
  requested_auth_configs: dict[str, AuthConfig] = Field(default_factory=dict)
51
- """Will only be set by a tool response indicating tool request euc.
52
- dict key is the function call id since one function call response (from model)
53
- could correspond to multiple function calls.
54
- dict value is the required auth config.
51
+ """Authentication configurations requested by tool responses.
52
+
53
+ This field will only be set by a tool response event indicating tool request
54
+ auth credential.
55
+ - Keys: The function call id. Since one function response event could contain
56
+ multiple function responses that correspond to multiple function calls. Each
57
+ function call could request different auth configs. This id is used to
58
+ identify the function call.
59
+ - Values: The requested auth config.
55
60
  """
@@ -15,9 +15,7 @@
15
15
  from __future__ import annotations
16
16
 
17
17
  import copy
18
- from typing import AsyncGenerator
19
- from typing import Generator
20
- from typing import Optional
18
+ from typing import AsyncGenerator, Generator, Optional
21
19
 
22
20
  from google.genai import types
23
21
  from typing_extensions import override
@@ -202,8 +200,14 @@ def _get_contents(
202
200
  # Parse the events, leaving the contents and the function calls and
203
201
  # responses from the current agent.
204
202
  for event in events:
205
- if not event.content or not event.content.role:
206
- # Skip events without content, or generated neither by user nor by model.
203
+ if (
204
+ not event.content
205
+ or not event.content.role
206
+ or not event.content.parts
207
+ or event.content.parts[0].text == ''
208
+ ):
209
+ # Skip events without content, or generated neither by user nor by model
210
+ # or has empty text.
207
211
  # E.g. events purely for mutating session states.
208
212
  continue
209
213
  if not _is_event_belongs_to_branch(current_branch, event):
@@ -151,28 +151,33 @@ async def handle_function_calls_async(
151
151
  # do not use "args" as the variable name, because it is a reserved keyword
152
152
  # in python debugger.
153
153
  function_args = function_call.args or {}
154
- function_response = None
155
- # Calls the tool if before_tool_callback does not exist or returns None.
154
+ function_response: Optional[dict] = None
155
+
156
+ # before_tool_callback (sync or async)
156
157
  if agent.before_tool_callback:
157
158
  function_response = agent.before_tool_callback(
158
159
  tool=tool, args=function_args, tool_context=tool_context
159
160
  )
161
+ if inspect.isawaitable(function_response):
162
+ function_response = await function_response
160
163
 
161
164
  if not function_response:
162
165
  function_response = await __call_tool_async(
163
166
  tool, args=function_args, tool_context=tool_context
164
167
  )
165
168
 
166
- # Calls after_tool_callback if it exists.
169
+ # after_tool_callback (sync or async)
167
170
  if agent.after_tool_callback:
168
- new_response = agent.after_tool_callback(
171
+ altered_function_response = agent.after_tool_callback(
169
172
  tool=tool,
170
173
  args=function_args,
171
174
  tool_context=tool_context,
172
175
  tool_response=function_response,
173
176
  )
174
- if new_response:
175
- function_response = new_response
177
+ if inspect.isawaitable(altered_function_response):
178
+ altered_function_response = await altered_function_response
179
+ if altered_function_response is not None:
180
+ function_response = altered_function_response
176
181
 
177
182
  if tool.is_long_running:
178
183
  # Allow long running function to return None to not provide function response.
@@ -223,11 +228,17 @@ async def handle_function_calls_live(
223
228
  # in python debugger.
224
229
  function_args = function_call.args or {}
225
230
  function_response = None
226
- # Calls the tool if before_tool_callback does not exist or returns None.
231
+ # # Calls the tool if before_tool_callback does not exist or returns None.
232
+ # if agent.before_tool_callback:
233
+ # function_response = agent.before_tool_callback(
234
+ # tool, function_args, tool_context
235
+ # )
227
236
  if agent.before_tool_callback:
228
237
  function_response = agent.before_tool_callback(
229
- tool, function_args, tool_context
238
+ tool=tool, args=function_args, tool_context=tool_context
230
239
  )
240
+ if inspect.isawaitable(function_response):
241
+ function_response = await function_response
231
242
 
232
243
  if not function_response:
233
244
  function_response = await _process_function_live_helper(
@@ -235,15 +246,26 @@ async def handle_function_calls_live(
235
246
  )
236
247
 
237
248
  # Calls after_tool_callback if it exists.
249
+ # if agent.after_tool_callback:
250
+ # new_response = agent.after_tool_callback(
251
+ # tool,
252
+ # function_args,
253
+ # tool_context,
254
+ # function_response,
255
+ # )
256
+ # if new_response:
257
+ # function_response = new_response
238
258
  if agent.after_tool_callback:
239
- new_response = agent.after_tool_callback(
240
- tool,
241
- function_args,
242
- tool_context,
243
- function_response,
259
+ altered_function_response = agent.after_tool_callback(
260
+ tool=tool,
261
+ args=function_args,
262
+ tool_context=tool_context,
263
+ tool_response=function_response,
244
264
  )
245
- if new_response:
246
- function_response = new_response
265
+ if inspect.isawaitable(altered_function_response):
266
+ altered_function_response = await altered_function_response
267
+ if altered_function_response is not None:
268
+ function_response = altered_function_response
247
269
 
248
270
  if tool.is_long_running:
249
271
  # Allow async function to return None to not provide function response.
@@ -0,0 +1,29 @@
1
+ """Utility functions for session service."""
2
+
3
+ import base64
4
+ from typing import Any, Optional
5
+
6
+ from google.genai import types
7
+
8
+
9
+ def encode_content(content: types.Content):
10
+ """Encodes a content object to a JSON dictionary."""
11
+ encoded_content = content.model_dump(exclude_none=True)
12
+ for p in encoded_content["parts"]:
13
+ if "inline_data" in p:
14
+ p["inline_data"]["data"] = base64.b64encode(
15
+ p["inline_data"]["data"]
16
+ ).decode("utf-8")
17
+ return encoded_content
18
+
19
+
20
+ def decode_content(
21
+ content: Optional[dict[str, Any]],
22
+ ) -> Optional[types.Content]:
23
+ """Decodes a content object from a JSON dictionary."""
24
+ if not content:
25
+ return None
26
+ for p in content["parts"]:
27
+ if "inline_data" in p:
28
+ p["inline_data"]["data"] = base64.b64decode(p["inline_data"]["data"])
29
+ return types.Content.model_validate(content)