google-adk 1.0.0__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. google/adk/agents/callback_context.py +2 -1
  2. google/adk/agents/readonly_context.py +3 -1
  3. google/adk/auth/auth_credential.py +4 -1
  4. google/adk/cli/browser/index.html +4 -4
  5. google/adk/cli/browser/{main-QOEMUXM4.js → main-PKDNKWJE.js} +59 -59
  6. google/adk/cli/browser/polyfills-B6TNHZQ6.js +17 -0
  7. google/adk/cli/cli.py +3 -2
  8. google/adk/cli/cli_eval.py +6 -85
  9. google/adk/cli/cli_tools_click.py +39 -10
  10. google/adk/cli/fast_api.py +53 -184
  11. google/adk/cli/utils/agent_loader.py +137 -0
  12. google/adk/cli/utils/cleanup.py +40 -0
  13. google/adk/cli/utils/evals.py +2 -1
  14. google/adk/cli/utils/logs.py +2 -7
  15. google/adk/code_executors/code_execution_utils.py +2 -1
  16. google/adk/code_executors/container_code_executor.py +0 -1
  17. google/adk/code_executors/vertex_ai_code_executor.py +6 -8
  18. google/adk/evaluation/eval_case.py +3 -1
  19. google/adk/evaluation/eval_metrics.py +74 -0
  20. google/adk/evaluation/eval_result.py +86 -0
  21. google/adk/evaluation/eval_set.py +2 -0
  22. google/adk/evaluation/eval_set_results_manager.py +47 -0
  23. google/adk/evaluation/eval_sets_manager.py +2 -1
  24. google/adk/evaluation/evaluator.py +2 -0
  25. google/adk/evaluation/local_eval_set_results_manager.py +113 -0
  26. google/adk/evaluation/local_eval_sets_manager.py +4 -4
  27. google/adk/evaluation/response_evaluator.py +2 -1
  28. google/adk/evaluation/trajectory_evaluator.py +3 -2
  29. google/adk/examples/base_example_provider.py +1 -0
  30. google/adk/flows/llm_flows/base_llm_flow.py +4 -6
  31. google/adk/flows/llm_flows/contents.py +3 -1
  32. google/adk/flows/llm_flows/instructions.py +7 -77
  33. google/adk/flows/llm_flows/single_flow.py +1 -1
  34. google/adk/models/base_llm.py +2 -1
  35. google/adk/models/base_llm_connection.py +2 -0
  36. google/adk/models/google_llm.py +4 -1
  37. google/adk/models/lite_llm.py +3 -2
  38. google/adk/models/llm_response.py +2 -1
  39. google/adk/runners.py +36 -4
  40. google/adk/sessions/_session_util.py +2 -1
  41. google/adk/sessions/database_session_service.py +5 -8
  42. google/adk/sessions/vertex_ai_session_service.py +28 -13
  43. google/adk/telemetry.py +4 -2
  44. google/adk/tools/agent_tool.py +1 -1
  45. google/adk/tools/apihub_tool/apihub_toolset.py +1 -1
  46. google/adk/tools/apihub_tool/clients/apihub_client.py +10 -3
  47. google/adk/tools/apihub_tool/clients/secret_client.py +1 -0
  48. google/adk/tools/application_integration_tool/application_integration_toolset.py +6 -2
  49. google/adk/tools/application_integration_tool/clients/connections_client.py +8 -1
  50. google/adk/tools/application_integration_tool/clients/integration_client.py +3 -1
  51. google/adk/tools/application_integration_tool/integration_connector_tool.py +1 -1
  52. google/adk/tools/base_toolset.py +40 -2
  53. google/adk/tools/bigquery/__init__.py +28 -0
  54. google/adk/tools/bigquery/bigquery_credentials.py +216 -0
  55. google/adk/tools/bigquery/bigquery_tool.py +116 -0
  56. google/adk/tools/function_parameter_parse_util.py +7 -0
  57. google/adk/tools/function_tool.py +33 -3
  58. google/adk/tools/get_user_choice_tool.py +1 -0
  59. google/adk/tools/google_api_tool/__init__.py +17 -11
  60. google/adk/tools/google_api_tool/google_api_tool.py +1 -1
  61. google/adk/tools/google_api_tool/google_api_toolset.py +0 -14
  62. google/adk/tools/google_api_tool/google_api_toolsets.py +8 -2
  63. google/adk/tools/google_search_tool.py +2 -2
  64. google/adk/tools/mcp_tool/conversion_utils.py +6 -2
  65. google/adk/tools/mcp_tool/mcp_session_manager.py +62 -188
  66. google/adk/tools/mcp_tool/mcp_tool.py +27 -24
  67. google/adk/tools/mcp_tool/mcp_toolset.py +76 -131
  68. google/adk/tools/openapi_tool/auth/credential_exchangers/base_credential_exchanger.py +1 -3
  69. google/adk/tools/openapi_tool/auth/credential_exchangers/service_account_exchanger.py +6 -7
  70. google/adk/tools/openapi_tool/common/common.py +5 -1
  71. google/adk/tools/openapi_tool/openapi_spec_parser/__init__.py +7 -2
  72. google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py +2 -7
  73. google/adk/tools/openapi_tool/openapi_spec_parser/operation_parser.py +5 -1
  74. google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py +11 -1
  75. google/adk/tools/toolbox_toolset.py +31 -3
  76. google/adk/utils/__init__.py +13 -0
  77. google/adk/utils/instructions_utils.py +131 -0
  78. google/adk/version.py +1 -1
  79. {google_adk-1.0.0.dist-info → google_adk-1.1.0.dist-info}/METADATA +12 -15
  80. {google_adk-1.0.0.dist-info → google_adk-1.1.0.dist-info}/RECORD +83 -78
  81. google/adk/agents/base_agent.py.orig +0 -330
  82. google/adk/cli/browser/polyfills-FFHMD2TL.js +0 -18
  83. google/adk/cli/fast_api.py.orig +0 -822
  84. google/adk/memory/base_memory_service.py.orig +0 -76
  85. google/adk/models/google_llm.py.orig +0 -305
  86. google/adk/tools/_built_in_code_execution_tool.py +0 -70
  87. google/adk/tools/mcp_tool/mcp_session_manager.py.orig +0 -322
  88. {google_adk-1.0.0.dist-info → google_adk-1.1.0.dist-info}/WHEEL +0 -0
  89. {google_adk-1.0.0.dist-info → google_adk-1.1.0.dist-info}/entry_points.txt +0 -0
  90. {google_adk-1.0.0.dist-info → google_adk-1.1.0.dist-info}/licenses/LICENSE +0 -0
@@ -1,822 +0,0 @@
1
- # Copyright 2025 Google LLC
2
- #
3
- # Licensed under the Apache License, Version 2.0 (the "License");
4
- # you may not use this file except in compliance with the License.
5
- # You may obtain a copy of the License at
6
- #
7
- # http://www.apache.org/licenses/LICENSE-2.0
8
- #
9
- # Unless required by applicable law or agreed to in writing, software
10
- # distributed under the License is distributed on an "AS IS" BASIS,
11
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- # See the License for the specific language governing permissions and
13
- # limitations under the License.
14
-
15
- import asyncio
16
- from contextlib import asynccontextmanager
17
- import importlib
18
- import inspect
19
- import json
20
- import logging
21
- import os
22
- from pathlib import Path
23
- import re
24
- import sys
25
- import traceback
26
- import typing
27
- from typing import Any
28
- from typing import List
29
- from typing import Literal
30
- from typing import Optional
31
-
32
- import click
33
- from click import Tuple
34
- from fastapi import FastAPI
35
- from fastapi import HTTPException
36
- from fastapi import Query
37
- from fastapi.middleware.cors import CORSMiddleware
38
- from fastapi.responses import FileResponse
39
- from fastapi.responses import RedirectResponse
40
- from fastapi.responses import StreamingResponse
41
- from fastapi.staticfiles import StaticFiles
42
- from fastapi.websockets import WebSocket
43
- from fastapi.websockets import WebSocketDisconnect
44
- from google.genai import types
45
- import graphviz
46
- from opentelemetry import trace
47
- from opentelemetry.exporter.cloud_trace import CloudTraceSpanExporter
48
- from opentelemetry.sdk.trace import export
49
- from opentelemetry.sdk.trace import ReadableSpan
50
- from opentelemetry.sdk.trace import TracerProvider
51
- from pydantic import BaseModel
52
- from pydantic import ValidationError
53
- from starlette.types import Lifespan
54
-
55
- from ..agents import RunConfig
56
- from ..agents.live_request_queue import LiveRequest
57
- from ..agents.live_request_queue import LiveRequestQueue
58
- from ..agents.llm_agent import Agent
59
- from ..agents.run_config import StreamingMode
60
- from ..artifacts import InMemoryArtifactService
61
- from ..events.event import Event
62
- from ..memory.in_memory_memory_service import InMemoryMemoryService
63
- from ..runners import Runner
64
- from ..sessions.database_session_service import DatabaseSessionService
65
- from ..sessions.in_memory_session_service import InMemorySessionService
66
- from ..sessions.session import Session
67
- from ..sessions.vertex_ai_session_service import VertexAiSessionService
68
- from .cli_eval import EVAL_SESSION_ID_PREFIX
69
- from .cli_eval import EvalMetric
70
- from .cli_eval import EvalMetricResult
71
- from .cli_eval import EvalStatus
72
- from .utils import create_empty_state
73
- from .utils import envs
74
- from .utils import evals
75
-
76
# Logger for this module; records are emitted under this module's name.
logger = logging.getLogger(name=__name__)

# Suffix that marks a file inside an app directory as an eval set.
_EVAL_SET_FILE_EXTENSION = ".evalset.json"
79
-
80
-
81
class ApiServerSpanExporter(export.SpanExporter):
  """Span exporter that records selected spans into an in-memory dict.

  A span is considered when its name is ``call_llm`` or ``send_data``, or
  starts with ``tool_response``. For such spans, a copy of the span's
  attributes is augmented with ``trace_id`` and ``span_id``. The copy is stored
  in ``trace_dict`` under the ``gcp.vertex.agent.event_id`` attribute. Spans
  whose event id is missing or falsy are dropped.
  """

  def __init__(self, trace_dict):
    # Caller-owned mapping; this exporter only inserts into it.
    self.trace_dict = trace_dict

  def export(
      self, spans: typing.Sequence[ReadableSpan]
  ) -> export.SpanExportResult:
    """Copies matching spans into ``trace_dict``; always reports SUCCESS."""
    for current in spans:
      name = current.name
      tracked = name in ("call_llm", "send_data") or name.startswith(
          "tool_response"
      )
      if not tracked:
        continue

      span_context = current.get_span_context()
      record = dict(current.attributes)
      record["trace_id"] = span_context.trace_id
      record["span_id"] = span_context.span_id

      event_id = record.get("gcp.vertex.agent.event_id", None)
      if event_id:
        self.trace_dict[event_id] = record
    return export.SpanExportResult.SUCCESS

  def force_flush(self, timeout_millis: int = 30000) -> bool:
    """Nothing is buffered, so a flush always succeeds immediately."""
    return True
104
-
105
-
106
class AgentRunRequest(BaseModel):
  """Request body shared by the ``/run`` and ``/run_sse`` endpoints."""

  # App (agent directory) name; replaced by the agent engine id for session
  # lookups when one is configured.
  app_name: str
  user_id: str
  session_id: str
  # The user message handed to ``Runner.run_async`` as ``new_message``.
  new_message: types.Content
  # Read only by ``/run_sse``: True selects StreamingMode.SSE, False selects
  # StreamingMode.NONE. Ignored by ``/run``.
  streaming: bool = False
112
-
113
-
114
class AddSessionToEvalSetRequest(BaseModel):
  """Request body for ``POST /apps/{app_name}/eval_sets/{eval_set_id}/add_session``."""

  # Name of the new eval; must match ``^[a-zA-Z0-9_]+$`` and must not already
  # exist in the target eval set.
  eval_id: str
  # Identify the existing session that is converted into eval data.
  session_id: str
  user_id: str
118
-
119
-
120
class RunEvalRequest(BaseModel):
  """Request body for ``POST /apps/{app_name}/eval_sets/{eval_set_id}/run_eval``."""

  eval_ids: list[str]  # if empty, then all evals in the eval set are run.
  # Metrics forwarded unchanged to ``cli_eval.run_evals``.
  eval_metrics: list[EvalMetric]
123
-
124
-
125
class RunEvalResult(BaseModel):
  """One entry of the ``run_eval`` response, built from a single eval result."""

  eval_set_id: str
  eval_id: str
  final_eval_status: EvalStatus
  # Each metric paired with the result computed for it.
  eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]]
  # Id of the session the eval run produced.
  session_id: str
131
-
132
-
133
def get_fast_api_app(
    *,
    agent_dir: str,
    session_db_url: str = "",
    allow_origins: Optional[list[str]] = None,
    web: bool,
    trace_to_cloud: bool = False,
    lifespan: Optional[Lifespan[FastAPI]] = None,
) -> FastAPI:
  """Builds the FastAPI app that serves the agent apps found in ``agent_dir``.

  Args:
    agent_dir: Directory whose sub-directories are agent apps. It is appended
      to ``sys.path`` so each app can be imported by its directory name.
    session_db_url: Selects the session service. An empty value uses
      ``InMemorySessionService``. ``agentengine://<id>`` uses
      ``VertexAiSessionService``, and ``<id>`` replaces the app name in every
      session/artifact call. Any other value is passed to
      ``DatabaseSessionService`` as ``db_url``.
    allow_origins: When non-empty, CORS middleware is installed for these
      origins.
    web: When true, the bundled dev UI under ``browser/`` is also served.
    trace_to_cloud: When true and ``GOOGLE_CLOUD_PROJECT`` is set, spans are
      additionally exported to Cloud Trace.
    lifespan: Optional user lifespan context, entered around the server's own
      lifespan handling.

  Returns:
    The configured FastAPI application.

  Raises:
    click.ClickException: If ``session_db_url`` is ``agentengine://`` with an
      empty engine id.
  """
  # In-memory span store filled by ApiServerSpanExporter and served by the
  # /debug/trace/{event_id} endpoint.
  trace_dict: dict[str, Any] = {}

  # Set up tracing in the FastAPI server.
  provider = TracerProvider()
  provider.add_span_processor(
      export.SimpleSpanProcessor(ApiServerSpanExporter(trace_dict))
  )
  if trace_to_cloud:
    envs.load_dotenv_for_agent("", agent_dir)
    if project_id := os.environ.get("GOOGLE_CLOUD_PROJECT", None):
      processor = export.BatchSpanProcessor(
          CloudTraceSpanExporter(project_id=project_id)
      )
      provider.add_span_processor(processor)
    else:
      # Use the module logger rather than the root logger.
      logger.warning(
          "GOOGLE_CLOUD_PROJECT environment variable is not set. Tracing will"
          " not be enabled."
      )

  trace.set_tracer_provider(provider)

  # Exit stacks returned by awaitable root agents (see _get_root_agent_async);
  # they are closed when the server shuts down.
  exit_stacks = []

  async def _close_exit_stacks():
    """Closes every exit stack collected from awaitable root agents."""
    for stack in exit_stacks:
      await stack.aclose()

  @asynccontextmanager
  async def internal_lifespan(app: FastAPI):
    """Enters the optional user lifespan and closes exit stacks on shutdown."""
    # Close the stacks in a `finally` on both branches. Without a user
    # lifespan they would otherwise never be closed, and an exception during
    # shutdown would also skip the cleanup.
    if lifespan:
      async with lifespan(app):
        try:
          yield
        finally:
          await _close_exit_stacks()
    else:
      try:
        yield
      finally:
        await _close_exit_stacks()

  # Run the FastAPI server.
  app = FastAPI(lifespan=internal_lifespan)

  if allow_origins:
    app.add_middleware(
        CORSMiddleware,
        allow_origins=allow_origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

  # Agent apps are imported by directory name in _get_root_agent_async.
  if agent_dir not in sys.path:
    sys.path.append(agent_dir)

  # Per-app caches keyed by the app name from the request.
  runner_dict = {}
  root_agent_dict = {}

  # Build the Artifact and Memory services (in-memory).
  artifact_service = InMemoryArtifactService()
  memory_service = InMemoryMemoryService()

  # Build the Session service. When agent_engine_id is non-empty, it replaces
  # the app name in every session/artifact call below.
  agent_engine_id = ""
  if session_db_url:
    if session_db_url.startswith("agentengine://"):
      # Create vertex session service
      agent_engine_id = session_db_url.split("://")[1]
      if not agent_engine_id:
        raise click.ClickException("Agent engine id can not be empty.")
      envs.load_dotenv_for_agent("", agent_dir)
      session_service = VertexAiSessionService(
          os.environ["GOOGLE_CLOUD_PROJECT"],
          os.environ["GOOGLE_CLOUD_LOCATION"],
      )
    else:
      session_service = DatabaseSessionService(db_url=session_db_url)
  else:
    session_service = InMemorySessionService()

  # Allowed format for eval set ids and eval ids.
  id_pattern = r"^[a-zA-Z0-9_]+$"

  @app.get("/list-apps")
  def list_apps() -> list[str]:
    """Lists visible sub-directories of agent_dir, excluding __pycache__."""
    base_path = Path.cwd() / agent_dir
    if not base_path.exists():
      raise HTTPException(status_code=404, detail="Path not found")
    if not base_path.is_dir():
      raise HTTPException(status_code=400, detail="Not a directory")
    agent_names = [
        x
        for x in os.listdir(base_path)
        if os.path.isdir(os.path.join(base_path, x))
        and not x.startswith(".")
        and x != "__pycache__"
    ]
    agent_names.sort()
    return agent_names

  @app.get("/debug/trace/{event_id}")
  def get_trace_dict(event_id: str) -> Any:
    event_dict = trace_dict.get(event_id, None)
    if event_dict is None:
      raise HTTPException(status_code=404, detail="Trace not found")
    return event_dict

  @app.get(
      "/apps/{app_name}/users/{user_id}/sessions/{session_id}",
      response_model_exclude_none=True,
  )
  def get_session(app_name: str, user_id: str, session_id: str) -> Session:
    # Connect to managed session if agent_engine_id is set.
    app_name = agent_engine_id if agent_engine_id else app_name
    session = session_service.get_session(
        app_name=app_name, user_id=user_id, session_id=session_id
    )
    if not session:
      raise HTTPException(status_code=404, detail="Session not found")
    return session

  @app.get(
      "/apps/{app_name}/users/{user_id}/sessions",
      response_model_exclude_none=True,
  )
  def list_sessions(app_name: str, user_id: str) -> list[Session]:
    # Connect to managed session if agent_engine_id is set.
    app_name = agent_engine_id if agent_engine_id else app_name
    return [
        session
        for session in session_service.list_sessions(
            app_name=app_name, user_id=user_id
        ).sessions
        # Remove sessions that were generated as a part of Eval.
        if not session.id.startswith(EVAL_SESSION_ID_PREFIX)
    ]

  @app.post(
      "/apps/{app_name}/users/{user_id}/sessions/{session_id}",
      response_model_exclude_none=True,
  )
  def create_session_with_id(
      app_name: str,
      user_id: str,
      session_id: str,
      state: Optional[dict[str, Any]] = None,
  ) -> Session:
    # Connect to managed session if agent_engine_id is set.
    app_name = agent_engine_id if agent_engine_id else app_name
    if (
        session_service.get_session(
            app_name=app_name, user_id=user_id, session_id=session_id
        )
        is not None
    ):
      logger.warning("Session already exists: %s", session_id)
      raise HTTPException(
          status_code=400, detail=f"Session already exists: {session_id}"
      )

    logger.info("New session created: %s", session_id)
    return session_service.create_session(
        app_name=app_name, user_id=user_id, state=state, session_id=session_id
    )

  @app.post(
      "/apps/{app_name}/users/{user_id}/sessions",
      response_model_exclude_none=True,
  )
  def create_session(
      app_name: str,
      user_id: str,
      state: Optional[dict[str, Any]] = None,
  ) -> Session:
    # Connect to managed session if agent_engine_id is set.
    app_name = agent_engine_id if agent_engine_id else app_name

    logger.info("New session created")
    return session_service.create_session(
        app_name=app_name, user_id=user_id, state=state
    )

  def _get_eval_set_file_path(app_name, agent_dir, eval_set_id) -> str:
    """Returns ``<agent_dir>/<app_name>/<eval_set_id>.evalset.json``."""
    return os.path.join(
        agent_dir,
        app_name,
        eval_set_id + _EVAL_SET_FILE_EXTENSION,
    )

  def _load_eval_set_data(app_name: str, eval_set_id: str) -> list[Any]:
    """Loads the JSON list stored in an eval set file.

    Raises:
      HTTPException: 404 if the eval set file does not exist.
    """
    eval_set_file_path = _get_eval_set_file_path(
        app_name, agent_dir, eval_set_id
    )
    try:
      with open(eval_set_file_path, "r") as file:
        return json.load(file)
    except FileNotFoundError as e:
      raise HTTPException(
          status_code=404, detail=f"Eval set `{eval_set_id}` not found."
      ) from e

  @app.post(
      "/apps/{app_name}/eval_sets/{eval_set_id}",
      response_model_exclude_none=True,
  )
  def create_eval_set(
      app_name: str,
      eval_set_id: str,
  ):
    """Creates an eval set, given the id."""
    if not re.fullmatch(id_pattern, eval_set_id):
      raise HTTPException(
          status_code=400,
          detail=(
              f"Invalid eval set id. Eval set id should have the `{id_pattern}`"
              " format"
          ),
      )
    # Define the file path
    new_eval_set_path = _get_eval_set_file_path(
        app_name, agent_dir, eval_set_id
    )

    logger.info("Creating eval set file `%s`", new_eval_set_path)

    # An existing eval set file is left untouched.
    if not os.path.exists(new_eval_set_path):
      # Write the JSON string to the file
      logger.info("Eval set file doesn't exist, we will create a new one.")
      with open(new_eval_set_path, "w") as f:
        empty_content = json.dumps([], indent=2)
        f.write(empty_content)

  @app.get(
      "/apps/{app_name}/eval_sets",
      response_model_exclude_none=True,
  )
  def list_eval_sets(app_name: str) -> list[str]:
    """Lists all eval sets for the given app."""
    eval_set_file_path = os.path.join(agent_dir, app_name)
    eval_sets = [
        os.path.basename(file).removesuffix(_EVAL_SET_FILE_EXTENSION)
        for file in os.listdir(eval_set_file_path)
        if file.endswith(_EVAL_SET_FILE_EXTENSION)
    ]
    return sorted(eval_sets)

  @app.post(
      "/apps/{app_name}/eval_sets/{eval_set_id}/add_session",
      response_model_exclude_none=True,
  )
  async def add_session_to_eval_set(
      app_name: str, eval_set_id: str, req: AddSessionToEvalSetRequest
  ):
    """Appends an existing session, converted to eval format, to an eval set."""
    if not re.fullmatch(id_pattern, req.eval_id):
      raise HTTPException(
          status_code=400,
          detail=f"Invalid eval id. Eval id should have the `{id_pattern}` format",
      )

    # Sessions live under the agent engine id when one is configured, the
    # same as every other session endpoint.
    app_id = agent_engine_id if agent_engine_id else app_name
    session = session_service.get_session(
        app_name=app_id, user_id=req.user_id, session_id=req.session_id
    )
    # Raise a proper 404 (an `assert` is stripped under -O and would give a
    # 500 otherwise).
    if not session:
      raise HTTPException(status_code=404, detail="Session not found.")

    # Load the eval set file data
    eval_set_file_path = _get_eval_set_file_path(
        app_name, agent_dir, eval_set_id
    )
    eval_set_data = _load_eval_set_data(app_name, eval_set_id)

    if any(x["name"] == req.eval_id for x in eval_set_data):
      raise HTTPException(
          status_code=400,
          detail=(
              f"Eval id `{req.eval_id}` already exists in `{eval_set_id}`"
              " eval set."
          ),
      )

    # Convert the session data to evaluation format
    test_data = evals.convert_session_to_eval_format(session)

    # Populate the session with initial session state.
    initial_session_state = create_empty_state(
        await _get_root_agent_async(app_name)
    )

    eval_set_data.append({
        "name": req.eval_id,
        "data": test_data,
        "initial_session": {
            "state": initial_session_state,
            "app_name": app_name,
            "user_id": req.user_id,
        },
    })
    # Serialize the test data to JSON and write to the eval set file.
    with open(eval_set_file_path, "w") as f:
      f.write(json.dumps(eval_set_data, indent=2))

  @app.get(
      "/apps/{app_name}/eval_sets/{eval_set_id}/evals",
      response_model_exclude_none=True,
  )
  def list_evals_in_eval_set(
      app_name: str,
      eval_set_id: str,
  ) -> list[str]:
    """Lists all evals in an eval set."""
    eval_set_data = _load_eval_set_data(app_name, eval_set_id)
    return sorted(x["name"] for x in eval_set_data)

  @app.post(
      "/apps/{app_name}/eval_sets/{eval_set_id}/run_eval",
      response_model_exclude_none=True,
  )
  async def run_eval(
      app_name: str, eval_set_id: str, req: RunEvalRequest
  ) -> list[RunEvalResult]:
    """Runs an eval given the details in the eval request."""
    from .cli_eval import run_evals

    # Create a mapping from eval set file to all the evals that needed to be
    # run.
    eval_set_file_path = _get_eval_set_file_path(
        app_name, agent_dir, eval_set_id
    )
    eval_set_to_evals = {eval_set_file_path: req.eval_ids}

    if not req.eval_ids:
      logger.info(
          "Eval ids to run list is empty. We will run all evals in the eval"
          " set."
      )
    root_agent = await _get_root_agent_async(app_name)
    eval_results = list(
        run_evals(
            eval_set_to_evals,
            root_agent,
            getattr(root_agent, "reset_data", None),
            req.eval_metrics,
            session_service=session_service,
            artifact_service=artifact_service,
        )
    )

    return [
        RunEvalResult(
            eval_set_id=eval_set_id,
            eval_id=eval_result.eval_id,
            final_eval_status=eval_result.final_eval_status,
            eval_metric_results=eval_result.eval_metric_results,
            session_id=eval_result.session_id,
        )
        for eval_result in eval_results
    ]

  @app.delete("/apps/{app_name}/users/{user_id}/sessions/{session_id}")
  def delete_session(app_name: str, user_id: str, session_id: str):
    # Connect to managed session if agent_engine_id is set.
    app_name = agent_engine_id if agent_engine_id else app_name
    session_service.delete_session(
        app_name=app_name, user_id=user_id, session_id=session_id
    )

  @app.get(
      "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}",
      response_model_exclude_none=True,
  )
  def load_artifact(
      app_name: str,
      user_id: str,
      session_id: str,
      artifact_name: str,
      version: Optional[int] = Query(None),
  ) -> Optional[types.Part]:
    app_name = agent_engine_id if agent_engine_id else app_name
    artifact = artifact_service.load_artifact(
        app_name=app_name,
        user_id=user_id,
        session_id=session_id,
        filename=artifact_name,
        version=version,
    )
    if not artifact:
      raise HTTPException(status_code=404, detail="Artifact not found")
    return artifact

  @app.get(
      "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}/versions/{version_id}",
      response_model_exclude_none=True,
  )
  def load_artifact_version(
      app_name: str,
      user_id: str,
      session_id: str,
      artifact_name: str,
      version_id: int,
  ) -> Optional[types.Part]:
    app_name = agent_engine_id if agent_engine_id else app_name
    artifact = artifact_service.load_artifact(
        app_name=app_name,
        user_id=user_id,
        session_id=session_id,
        filename=artifact_name,
        version=version_id,
    )
    if not artifact:
      raise HTTPException(status_code=404, detail="Artifact not found")
    return artifact

  @app.get(
      "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts",
      response_model_exclude_none=True,
  )
  def list_artifact_names(
      app_name: str, user_id: str, session_id: str
  ) -> list[str]:
    app_name = agent_engine_id if agent_engine_id else app_name
    return artifact_service.list_artifact_keys(
        app_name=app_name, user_id=user_id, session_id=session_id
    )

  @app.get(
      "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}/versions",
      response_model_exclude_none=True,
  )
  def list_artifact_versions(
      app_name: str, user_id: str, session_id: str, artifact_name: str
  ) -> list[int]:
    app_name = agent_engine_id if agent_engine_id else app_name
    return artifact_service.list_versions(
        app_name=app_name,
        user_id=user_id,
        session_id=session_id,
        filename=artifact_name,
    )

  @app.delete(
      "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}",
  )
  def delete_artifact(
      app_name: str, user_id: str, session_id: str, artifact_name: str
  ):
    app_name = agent_engine_id if agent_engine_id else app_name
    artifact_service.delete_artifact(
        app_name=app_name,
        user_id=user_id,
        session_id=session_id,
        filename=artifact_name,
    )

  @app.post("/run", response_model_exclude_none=True)
  async def agent_run(req: AgentRunRequest) -> list[Event]:
    """Runs the agent to completion and returns all generated events."""
    # Connect to managed session if agent_engine_id is set.
    app_id = agent_engine_id if agent_engine_id else req.app_name
    session = session_service.get_session(
        app_name=app_id, user_id=req.user_id, session_id=req.session_id
    )
    if not session:
      raise HTTPException(status_code=404, detail="Session not found")
    runner = await _get_runner_async(req.app_name)
    events = [
        event
        async for event in runner.run_async(
            user_id=req.user_id,
            session_id=req.session_id,
            new_message=req.new_message,
        )
    ]
    logger.info("Generated %s events in agent run: %s", len(events), events)
    return events

  @app.post("/run_sse")
  async def agent_run_sse(req: AgentRunRequest) -> StreamingResponse:
    """Runs the agent and streams each event as a Server-Sent Event."""
    # Connect to managed session if agent_engine_id is set.
    app_id = agent_engine_id if agent_engine_id else req.app_name
    # SSE endpoint
    session = session_service.get_session(
        app_name=app_id, user_id=req.user_id, session_id=req.session_id
    )
    if not session:
      raise HTTPException(status_code=404, detail="Session not found")

    # Convert the events to properly formatted SSE
    async def event_generator():
      try:
        stream_mode = StreamingMode.SSE if req.streaming else StreamingMode.NONE
        runner = await _get_runner_async(req.app_name)
        async for event in runner.run_async(
            user_id=req.user_id,
            session_id=req.session_id,
            new_message=req.new_message,
            run_config=RunConfig(streaming_mode=stream_mode),
        ):
          # Format as SSE data
          sse_event = event.model_dump_json(exclude_none=True, by_alias=True)
          logger.info("Generated event in agent run streaming: %s", sse_event)
          yield f"data: {sse_event}\n\n"
      except Exception as e:
        logger.exception("Error in event_generator: %s", e)
        # json.dumps escapes quotes and newlines, so the payload stays valid
        # JSON and stays on a single SSE `data:` line.
        yield f"data: {json.dumps({'error': str(e)})}\n\n"

    # Returns a streaming response with the proper media type for SSE
    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
    )

  @app.get(
      "/apps/{app_name}/users/{user_id}/sessions/{session_id}/events/{event_id}/graph",
      response_model_exclude_none=True,
  )
  async def get_event_graph(
      app_name: str, user_id: str, session_id: str, event_id: str
  ):
    """Returns ``{"dot_src": ...}`` for the agent graph around an event.

    An empty dict is returned if the event is not found or no Digraph is
    produced.
    """
    # Connect to managed session if agent_engine_id is set.
    app_id = agent_engine_id if agent_engine_id else app_name
    session = session_service.get_session(
        app_name=app_id, user_id=user_id, session_id=session_id
    )
    session_events = session.events if session else []
    event = next((x for x in session_events if x.id == event_id), None)
    if not event:
      return {}

    from . import agent_graph

    function_calls = event.get_function_calls()
    function_responses = event.get_function_responses()
    root_agent = await _get_root_agent_async(app_name)
    if function_calls:
      # Highlight edges from the event author to each called function.
      highlights = [
          (event.author, function_call.name)
          for function_call in function_calls
      ]
    elif function_responses:
      # Highlight edges from each responding function back to the author.
      highlights = [
          (function_response.name, event.author)
          for function_response in function_responses
      ]
    else:
      highlights = [(event.author, "")]
    dot_graph = agent_graph.get_agent_graph(root_agent, highlights)
    if dot_graph and isinstance(dot_graph, graphviz.Digraph):
      return {"dot_src": dot_graph.source}
    return {}

  @app.websocket("/run_live")
  async def agent_live_run(
      websocket: WebSocket,
      app_name: str,
      user_id: str,
      session_id: str,
      modalities: List[Literal["TEXT", "AUDIO"]] = Query(
          default=["TEXT", "AUDIO"]
      ),  # Only allows "TEXT" or "AUDIO"
  ) -> None:
    """Bridges a websocket to ``Runner.run_live`` for an existing session."""
    await websocket.accept()

    # Connect to managed session if agent_engine_id is set.
    app_id = agent_engine_id if agent_engine_id else app_name
    session = session_service.get_session(
        app_name=app_id, user_id=user_id, session_id=session_id
    )
    if not session:
      # Accept first so that the client is aware of connection establishment,
      # then close with a specific code.
      await websocket.close(code=1002, reason="Session not found")
      return

    live_request_queue = LiveRequestQueue()

    async def forward_events():
      runner = await _get_runner_async(app_name)
      async for event in runner.run_live(
          session=session, live_request_queue=live_request_queue
      ):
        await websocket.send_text(
            event.model_dump_json(exclude_none=True, by_alias=True)
        )

    async def process_messages():
      try:
        while True:
          data = await websocket.receive_text()
          # Validate and send the received message to the live queue.
          live_request_queue.send(LiveRequest.model_validate_json(data))
      except ValidationError as ve:
        # NOTE(review): this returns normally, so asyncio.wait below
        # (FIRST_EXCEPTION) keeps waiting on forward_events and no further
        # client messages are read. Confirm whether that is intended.
        logger.error("Validation error in process_messages: %s", ve)

    # Run both tasks concurrently and cancel all if one fails.
    tasks = [
        asyncio.create_task(forward_events()),
        asyncio.create_task(process_messages()),
    ]
    done, pending = await asyncio.wait(
        tasks, return_when=asyncio.FIRST_EXCEPTION
    )
    try:
      # This will re-raise any exception from the completed tasks.
      for task in done:
        task.result()
    except WebSocketDisconnect:
      logger.info("Client disconnected during process_messages.")
    except Exception as e:
      logger.exception("Error during live websocket communication: %s", e)
      traceback.print_exc()
      WEBSOCKET_INTERNAL_ERROR_CODE = 1011
      WEBSOCKET_MAX_BYTES_FOR_REASON = 123
      await websocket.close(
          code=WEBSOCKET_INTERNAL_ERROR_CODE,
          reason=str(e)[:WEBSOCKET_MAX_BYTES_FOR_REASON],
      )
    finally:
      for task in pending:
        task.cancel()

  async def _get_root_agent_async(app_name: str) -> Agent:
    """Returns the root agent for the given app, importing it on first use.

    Raises:
      ValueError: If the app has no truthy ``agent.root_agent``.
      RuntimeError: If awaiting an awaitable root agent fails.
    """
    if app_name in root_agent_dict:
      return root_agent_dict[app_name]
    agent_module = importlib.import_module(app_name)
    # Use defaults so that a missing `agent` submodule or `root_agent`
    # attribute reaches the intended ValueError instead of an AttributeError.
    root_agent = getattr(
        getattr(agent_module, "agent", None), "root_agent", None
    )
    if not root_agent:
      raise ValueError(f'Unable to find "root_agent" from {app_name}.')

    # Handle an awaitable root agent and await for the actual agent. It is
    # expected to yield an (agent, exit_stack) pair; the stack is closed in
    # internal_lifespan.
    if inspect.isawaitable(root_agent):
      try:
        agent, exit_stack = await root_agent
        exit_stacks.append(exit_stack)
        root_agent = agent
      except Exception as e:
        raise RuntimeError(f"error getting root agent, {e}") from e

    root_agent_dict[app_name] = root_agent
    return root_agent

  async def _get_runner_async(app_name: str) -> Runner:
    """Returns the cached runner for the given app, creating it if needed."""
    # The app's .env is (re)loaded on every call, including cache hits.
    envs.load_dotenv_for_agent(os.path.basename(app_name), agent_dir)
    if app_name in runner_dict:
      return runner_dict[app_name]
    root_agent = await _get_root_agent_async(app_name)
    runner = Runner(
        app_name=agent_engine_id if agent_engine_id else app_name,
        agent=root_agent,
        artifact_service=artifact_service,
        session_service=session_service,
        memory_service=memory_service,
    )
    runner_dict[app_name] = runner
    return runner

  if web:
    BASE_DIR = Path(__file__).parent.resolve()
    ANGULAR_DIST_PATH = BASE_DIR / "browser"

    @app.get("/")
    async def redirect_to_dev_ui():
      return RedirectResponse("/dev-ui")

    @app.get("/dev-ui")
    async def dev_ui():
      return FileResponse(BASE_DIR / "browser/index.html")

    # Mounted last so the API routes above take precedence over static files.
    app.mount(
        "/", StaticFiles(directory=ANGULAR_DIST_PATH, html=True), name="static"
    )
  return app