google-adk 0.5.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (113) hide show
  1. google/adk/agents/base_agent.py +76 -30
  2. google/adk/agents/base_agent.py.orig +330 -0
  3. google/adk/agents/callback_context.py +0 -5
  4. google/adk/agents/llm_agent.py +122 -30
  5. google/adk/agents/loop_agent.py +1 -1
  6. google/adk/agents/parallel_agent.py +7 -0
  7. google/adk/agents/readonly_context.py +7 -1
  8. google/adk/agents/run_config.py +1 -1
  9. google/adk/agents/sequential_agent.py +31 -0
  10. google/adk/agents/transcription_entry.py +4 -2
  11. google/adk/artifacts/gcs_artifact_service.py +1 -1
  12. google/adk/artifacts/in_memory_artifact_service.py +1 -1
  13. google/adk/auth/auth_credential.py +6 -1
  14. google/adk/auth/auth_preprocessor.py +7 -1
  15. google/adk/auth/auth_tool.py +3 -4
  16. google/adk/cli/agent_graph.py +5 -5
  17. google/adk/cli/browser/index.html +2 -2
  18. google/adk/cli/browser/{main-ULN5R5I5.js → main-QOEMUXM4.js} +44 -45
  19. google/adk/cli/cli.py +7 -7
  20. google/adk/cli/cli_deploy.py +7 -2
  21. google/adk/cli/cli_eval.py +172 -99
  22. google/adk/cli/cli_tools_click.py +147 -64
  23. google/adk/cli/fast_api.py +330 -148
  24. google/adk/cli/fast_api.py.orig +174 -80
  25. google/adk/cli/utils/common.py +23 -0
  26. google/adk/cli/utils/evals.py +83 -1
  27. google/adk/cli/utils/logs.py +13 -5
  28. google/adk/code_executors/__init__.py +3 -1
  29. google/adk/code_executors/built_in_code_executor.py +52 -0
  30. google/adk/evaluation/__init__.py +1 -1
  31. google/adk/evaluation/agent_evaluator.py +168 -128
  32. google/adk/evaluation/eval_case.py +102 -0
  33. google/adk/evaluation/eval_set.py +37 -0
  34. google/adk/evaluation/eval_sets_manager.py +42 -0
  35. google/adk/evaluation/evaluation_generator.py +88 -113
  36. google/adk/evaluation/evaluator.py +56 -0
  37. google/adk/evaluation/local_eval_sets_manager.py +264 -0
  38. google/adk/evaluation/response_evaluator.py +106 -2
  39. google/adk/evaluation/trajectory_evaluator.py +83 -2
  40. google/adk/events/event.py +6 -1
  41. google/adk/events/event_actions.py +6 -1
  42. google/adk/examples/example_util.py +3 -2
  43. google/adk/flows/llm_flows/_code_execution.py +9 -1
  44. google/adk/flows/llm_flows/audio_transcriber.py +4 -3
  45. google/adk/flows/llm_flows/base_llm_flow.py +54 -15
  46. google/adk/flows/llm_flows/functions.py +9 -8
  47. google/adk/flows/llm_flows/instructions.py +13 -5
  48. google/adk/flows/llm_flows/single_flow.py +1 -1
  49. google/adk/memory/__init__.py +1 -1
  50. google/adk/memory/_utils.py +23 -0
  51. google/adk/memory/base_memory_service.py +23 -21
  52. google/adk/memory/base_memory_service.py.orig +76 -0
  53. google/adk/memory/in_memory_memory_service.py +57 -25
  54. google/adk/memory/memory_entry.py +37 -0
  55. google/adk/memory/vertex_ai_rag_memory_service.py +38 -15
  56. google/adk/models/anthropic_llm.py +16 -9
  57. google/adk/models/gemini_llm_connection.py +11 -11
  58. google/adk/models/google_llm.py +9 -2
  59. google/adk/models/google_llm.py.orig +305 -0
  60. google/adk/models/lite_llm.py +77 -21
  61. google/adk/models/llm_response.py +14 -2
  62. google/adk/models/registry.py +1 -1
  63. google/adk/runners.py +65 -41
  64. google/adk/sessions/__init__.py +1 -1
  65. google/adk/sessions/base_session_service.py +6 -33
  66. google/adk/sessions/database_session_service.py +58 -65
  67. google/adk/sessions/in_memory_session_service.py +106 -24
  68. google/adk/sessions/session.py +3 -0
  69. google/adk/sessions/vertex_ai_session_service.py +23 -45
  70. google/adk/telemetry.py +3 -0
  71. google/adk/tools/__init__.py +4 -7
  72. google/adk/tools/{built_in_code_execution_tool.py → _built_in_code_execution_tool.py} +11 -0
  73. google/adk/tools/_memory_entry_utils.py +30 -0
  74. google/adk/tools/agent_tool.py +9 -9
  75. google/adk/tools/apihub_tool/apihub_toolset.py +55 -74
  76. google/adk/tools/application_integration_tool/application_integration_toolset.py +107 -85
  77. google/adk/tools/application_integration_tool/clients/connections_client.py +20 -0
  78. google/adk/tools/application_integration_tool/clients/integration_client.py +6 -6
  79. google/adk/tools/application_integration_tool/integration_connector_tool.py +69 -26
  80. google/adk/tools/base_toolset.py +58 -0
  81. google/adk/tools/enterprise_search_tool.py +65 -0
  82. google/adk/tools/function_parameter_parse_util.py +2 -2
  83. google/adk/tools/google_api_tool/__init__.py +18 -70
  84. google/adk/tools/google_api_tool/google_api_tool.py +11 -5
  85. google/adk/tools/google_api_tool/google_api_toolset.py +126 -0
  86. google/adk/tools/google_api_tool/google_api_toolsets.py +102 -0
  87. google/adk/tools/google_api_tool/googleapi_to_openapi_converter.py +40 -42
  88. google/adk/tools/langchain_tool.py +96 -49
  89. google/adk/tools/load_memory_tool.py +14 -5
  90. google/adk/tools/mcp_tool/__init__.py +3 -2
  91. google/adk/tools/mcp_tool/mcp_session_manager.py +153 -16
  92. google/adk/tools/mcp_tool/mcp_session_manager.py.orig +322 -0
  93. google/adk/tools/mcp_tool/mcp_tool.py +12 -12
  94. google/adk/tools/mcp_tool/mcp_toolset.py +155 -195
  95. google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py +32 -7
  96. google/adk/tools/openapi_tool/openapi_spec_parser/operation_parser.py +31 -31
  97. google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py +1 -1
  98. google/adk/tools/preload_memory_tool.py +27 -18
  99. google/adk/tools/retrieval/__init__.py +1 -1
  100. google/adk/tools/retrieval/vertex_ai_rag_retrieval.py +1 -1
  101. google/adk/tools/toolbox_toolset.py +79 -0
  102. google/adk/tools/transfer_to_agent_tool.py +0 -1
  103. google/adk/version.py +1 -1
  104. {google_adk-0.5.0.dist-info → google_adk-1.0.0.dist-info}/METADATA +7 -5
  105. google_adk-1.0.0.dist-info/RECORD +195 -0
  106. google/adk/agents/remote_agent.py +0 -50
  107. google/adk/tools/google_api_tool/google_api_tool_set.py +0 -110
  108. google/adk/tools/google_api_tool/google_api_tool_sets.py +0 -112
  109. google/adk/tools/toolbox_tool.py +0 -46
  110. google_adk-0.5.0.dist-info/RECORD +0 -180
  111. {google_adk-0.5.0.dist-info → google_adk-1.0.0.dist-info}/WHEEL +0 -0
  112. {google_adk-0.5.0.dist-info → google_adk-1.0.0.dist-info}/entry_points.txt +0 -0
  113. {google_adk-0.5.0.dist-info → google_adk-1.0.0.dist-info}/licenses/LICENSE +0 -0
@@ -13,7 +13,9 @@
13
13
  # limitations under the License.
14
14
 
15
15
  import asyncio
16
+ from contextlib import asynccontextmanager
16
17
  import importlib
18
+ import inspect
17
19
  import json
18
20
  import logging
19
21
  import os
@@ -28,10 +30,10 @@ from typing import Literal
28
30
  from typing import Optional
29
31
 
30
32
  import click
33
+ from click import Tuple
31
34
  from fastapi import FastAPI
32
35
  from fastapi import HTTPException
33
36
  from fastapi import Query
34
- from fastapi import Response
35
37
  from fastapi.middleware.cors import CORSMiddleware
36
38
  from fastapi.responses import FileResponse
37
39
  from fastapi.responses import RedirectResponse
@@ -40,6 +42,7 @@ from fastapi.staticfiles import StaticFiles
40
42
  from fastapi.websockets import WebSocket
41
43
  from fastapi.websockets import WebSocketDisconnect
42
44
  from google.genai import types
45
+ import graphviz
43
46
  from opentelemetry import trace
44
47
  from opentelemetry.exporter.cloud_trace import CloudTraceSpanExporter
45
48
  from opentelemetry.sdk.trace import export
@@ -47,6 +50,7 @@ from opentelemetry.sdk.trace import ReadableSpan
47
50
  from opentelemetry.sdk.trace import TracerProvider
48
51
  from pydantic import BaseModel
49
52
  from pydantic import ValidationError
53
+ from starlette.types import Lifespan
50
54
 
51
55
  from ..agents import RunConfig
52
56
  from ..agents.live_request_queue import LiveRequest
@@ -55,6 +59,7 @@ from ..agents.llm_agent import Agent
55
59
  from ..agents.run_config import StreamingMode
56
60
  from ..artifacts import InMemoryArtifactService
57
61
  from ..events.event import Event
62
+ from ..memory.in_memory_memory_service import InMemoryMemoryService
58
63
  from ..runners import Runner
59
64
  from ..sessions.database_session_service import DatabaseSessionService
60
65
  from ..sessions.in_memory_session_service import InMemorySessionService
@@ -82,11 +87,16 @@ class ApiServerSpanExporter(export.SpanExporter):
82
87
  self, spans: typing.Sequence[ReadableSpan]
83
88
  ) -> export.SpanExportResult:
84
89
  for span in spans:
85
- if span.name == "call_llm" or span.name == "send_data":
90
+ if (
91
+ span.name == "call_llm"
92
+ or span.name == "send_data"
93
+ or span.name.startswith("tool_response")
94
+ ):
86
95
  attributes = dict(span.attributes)
87
96
  attributes["trace_id"] = span.get_span_context().trace_id
88
97
  attributes["span_id"] = span.get_span_context().span_id
89
- self.trace_dict[attributes["gcp.vertex.agent.event_id"]] = attributes
98
+ if attributes.get("gcp.vertex.agent.event_id", None):
99
+ self.trace_dict[attributes["gcp.vertex.agent.event_id"]] = attributes
90
100
  return export.SpanExportResult.SUCCESS
91
101
 
92
102
  def force_flush(self, timeout_millis: int = 30000) -> bool:
@@ -126,6 +136,8 @@ def get_fast_api_app(
126
136
  session_db_url: str = "",
127
137
  allow_origins: Optional[list[str]] = None,
128
138
  web: bool,
139
+ trace_to_cloud: bool = False,
140
+ lifespan: Optional[Lifespan[FastAPI]] = None,
129
141
  ) -> FastAPI:
130
142
  # InMemory tracing dict.
131
143
  trace_dict: dict[str, Any] = {}
@@ -135,26 +147,37 @@ def get_fast_api_app(
135
147
  provider.add_span_processor(
136
148
  export.SimpleSpanProcessor(ApiServerSpanExporter(trace_dict))
137
149
  )
138
- if os.environ.get("ADK_TRACE_TO_CLOUD", "0") == "1":
139
- processor = export.BatchSpanProcessor(
140
- CloudTraceSpanExporter(
141
- project_id=os.environ.get("GOOGLE_CLOUD_PROJECT", "")
142
- )
143
- )
144
- provider.add_span_processor(processor)
150
+ if trace_to_cloud:
151
+ envs.load_dotenv_for_agent("", agent_dir)
152
+ if project_id := os.environ.get("GOOGLE_CLOUD_PROJECT", None):
153
+ processor = export.BatchSpanProcessor(
154
+ CloudTraceSpanExporter(project_id=project_id)
155
+ )
156
+ provider.add_span_processor(processor)
157
+ else:
158
+ logging.warning(
159
+ "GOOGLE_CLOUD_PROJECT environment variable is not set. Tracing will"
160
+ " not be enabled."
161
+ )
145
162
 
146
163
  trace.set_tracer_provider(provider)
147
164
 
165
+ exit_stacks = []
166
+
167
+ @asynccontextmanager
168
+ async def internal_lifespan(app: FastAPI):
169
+ if lifespan:
170
+ async with lifespan(app) as lifespan_context:
171
+ yield
172
+
173
+ if exit_stacks:
174
+ for stack in exit_stacks:
175
+ await stack.aclose()
176
+ else:
177
+ yield
178
+
148
179
  # Run the FastAPI server.
149
- app = FastAPI()
150
- origins = ["http://localhost:4200"]
151
- app.add_middleware(
152
- CORSMiddleware,
153
- allow_origins=origins,
154
- allow_credentials=True,
155
- allow_methods=["*"],
156
- allow_headers=["*"],
157
- )
180
+ app = FastAPI(lifespan=internal_lifespan)
158
181
 
159
182
  if allow_origins:
160
183
  app.add_middleware(
@@ -173,6 +196,7 @@ def get_fast_api_app(
173
196
 
174
197
  # Build the Artifact service
175
198
  artifact_service = InMemoryArtifactService()
199
+ memory_service = InMemoryMemoryService()
176
200
 
177
201
  # Build the Session service
178
202
  agent_engine_id = ""
@@ -223,7 +247,9 @@ def get_fast_api_app(
223
247
  def get_session(app_name: str, user_id: str, session_id: str) -> Session:
224
248
  # Connect to managed session if agent_engine_id is set.
225
249
  app_name = agent_engine_id if agent_engine_id else app_name
226
- session = session_service.get_session(app_name, user_id, session_id)
250
+ session = session_service.get_session(
251
+ app_name=app_name, user_id=user_id, session_id=session_id
252
+ )
227
253
  if not session:
228
254
  raise HTTPException(status_code=404, detail="Session not found")
229
255
  return session
@@ -237,7 +263,9 @@ def get_fast_api_app(
237
263
  app_name = agent_engine_id if agent_engine_id else app_name
238
264
  return [
239
265
  session
240
- for session in session_service.list_sessions(app_name, user_id).sessions
266
+ for session in session_service.list_sessions(
267
+ app_name=app_name, user_id=user_id
268
+ ).sessions
241
269
  # Remove sessions that were generated as a part of Eval.
242
270
  if not session.id.startswith(EVAL_SESSION_ID_PREFIX)
243
271
  ]
@@ -246,7 +274,7 @@ def get_fast_api_app(
246
274
  "/apps/{app_name}/users/{user_id}/sessions/{session_id}",
247
275
  response_model_exclude_none=True,
248
276
  )
249
- def create_session(
277
+ def create_session_with_id(
250
278
  app_name: str,
251
279
  user_id: str,
252
280
  session_id: str,
@@ -254,7 +282,12 @@ def get_fast_api_app(
254
282
  ) -> Session:
255
283
  # Connect to managed session if agent_engine_id is set.
256
284
  app_name = agent_engine_id if agent_engine_id else app_name
257
- if session_service.get_session(app_name, user_id, session_id) is not None:
285
+ if (
286
+ session_service.get_session(
287
+ app_name=app_name, user_id=user_id, session_id=session_id
288
+ )
289
+ is not None
290
+ ):
258
291
  logger.warning("Session already exists: %s", session_id)
259
292
  raise HTTPException(
260
293
  status_code=400, detail=f"Session already exists: {session_id}"
@@ -262,7 +295,7 @@ def get_fast_api_app(
262
295
 
263
296
  logger.info("New session created: %s", session_id)
264
297
  return session_service.create_session(
265
- app_name, user_id, state, session_id=session_id
298
+ app_name=app_name, user_id=user_id, state=state, session_id=session_id
266
299
  )
267
300
 
268
301
  @app.post(
@@ -278,7 +311,9 @@ def get_fast_api_app(
278
311
  app_name = agent_engine_id if agent_engine_id else app_name
279
312
 
280
313
  logger.info("New session created")
281
- return session_service.create_session(app_name, user_id, state)
314
+ return session_service.create_session(
315
+ app_name=app_name, user_id=user_id, state=state
316
+ )
282
317
 
283
318
  def _get_eval_set_file_path(app_name, agent_dir, eval_set_id) -> str:
284
319
  return os.path.join(
@@ -339,7 +374,7 @@ def get_fast_api_app(
339
374
  "/apps/{app_name}/eval_sets/{eval_set_id}/add_session",
340
375
  response_model_exclude_none=True,
341
376
  )
342
- def add_session_to_eval_set(
377
+ async def add_session_to_eval_set(
343
378
  app_name: str, eval_set_id: str, req: AddSessionToEvalSetRequest
344
379
  ):
345
380
  pattern = r"^[a-zA-Z0-9_]+$"
@@ -350,7 +385,9 @@ def get_fast_api_app(
350
385
  )
351
386
 
352
387
  # Get the session
353
- session = session_service.get_session(app_name, req.user_id, req.session_id)
388
+ session = session_service.get_session(
389
+ app_name=app_name, user_id=req.user_id, session_id=req.session_id
390
+ )
354
391
  assert session, "Session not found."
355
392
  # Load the eval set file data
356
393
  eval_set_file_path = _get_eval_set_file_path(
@@ -372,7 +409,9 @@ def get_fast_api_app(
372
409
  test_data = evals.convert_session_to_eval_format(session)
373
410
 
374
411
  # Populate the session with initial session state.
375
- initial_session_state = create_empty_state(_get_root_agent(app_name))
412
+ initial_session_state = create_empty_state(
413
+ await _get_root_agent_async(app_name)
414
+ )
376
415
 
377
416
  eval_set_data.append({
378
417
  "name": req.eval_id,
@@ -409,7 +448,7 @@ def get_fast_api_app(
409
448
  "/apps/{app_name}/eval_sets/{eval_set_id}/run_eval",
410
449
  response_model_exclude_none=True,
411
450
  )
412
- def run_eval(
451
+ async def run_eval(
413
452
  app_name: str, eval_set_id: str, req: RunEvalRequest
414
453
  ) -> list[RunEvalResult]:
415
454
  from .cli_eval import run_evals
@@ -426,7 +465,7 @@ def get_fast_api_app(
426
465
  logger.info(
427
466
  "Eval ids to run list is empty. We will all evals in the eval set."
428
467
  )
429
- root_agent = _get_root_agent(app_name)
468
+ root_agent = await _get_root_agent_async(app_name)
430
469
  eval_results = list(
431
470
  run_evals(
432
471
  eval_set_to_evals,
@@ -456,7 +495,9 @@ def get_fast_api_app(
456
495
  def delete_session(app_name: str, user_id: str, session_id: str):
457
496
  # Connect to managed session if agent_engine_id is set.
458
497
  app_name = agent_engine_id if agent_engine_id else app_name
459
- session_service.delete_session(app_name, user_id, session_id)
498
+ session_service.delete_session(
499
+ app_name=app_name, user_id=user_id, session_id=session_id
500
+ )
460
501
 
461
502
  @app.get(
462
503
  "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}",
@@ -469,8 +510,13 @@ def get_fast_api_app(
469
510
  artifact_name: str,
470
511
  version: Optional[int] = Query(None),
471
512
  ) -> Optional[types.Part]:
513
+ app_name = agent_engine_id if agent_engine_id else app_name
472
514
  artifact = artifact_service.load_artifact(
473
- app_name, user_id, session_id, artifact_name, version
515
+ app_name=app_name,
516
+ user_id=user_id,
517
+ session_id=session_id,
518
+ filename=artifact_name,
519
+ version=version,
474
520
  )
475
521
  if not artifact:
476
522
  raise HTTPException(status_code=404, detail="Artifact not found")
@@ -487,8 +533,13 @@ def get_fast_api_app(
487
533
  artifact_name: str,
488
534
  version_id: int,
489
535
  ) -> Optional[types.Part]:
536
+ app_name = agent_engine_id if agent_engine_id else app_name
490
537
  artifact = artifact_service.load_artifact(
491
- app_name, user_id, session_id, artifact_name, version_id
538
+ app_name=app_name,
539
+ user_id=user_id,
540
+ session_id=session_id,
541
+ filename=artifact_name,
542
+ version=version_id,
492
543
  )
493
544
  if not artifact:
494
545
  raise HTTPException(status_code=404, detail="Artifact not found")
@@ -501,7 +552,10 @@ def get_fast_api_app(
501
552
  def list_artifact_names(
502
553
  app_name: str, user_id: str, session_id: str
503
554
  ) -> list[str]:
504
- return artifact_service.list_artifact_keys(app_name, user_id, session_id)
555
+ app_name = agent_engine_id if agent_engine_id else app_name
556
+ return artifact_service.list_artifact_keys(
557
+ app_name=app_name, user_id=user_id, session_id=session_id
558
+ )
505
559
 
506
560
  @app.get(
507
561
  "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}/versions",
@@ -510,8 +564,12 @@ def get_fast_api_app(
510
564
  def list_artifact_versions(
511
565
  app_name: str, user_id: str, session_id: str, artifact_name: str
512
566
  ) -> list[int]:
567
+ app_name = agent_engine_id if agent_engine_id else app_name
513
568
  return artifact_service.list_versions(
514
- app_name, user_id, session_id, artifact_name
569
+ app_name=app_name,
570
+ user_id=user_id,
571
+ session_id=session_id,
572
+ filename=artifact_name,
515
573
  )
516
574
 
517
575
  @app.delete(
@@ -520,18 +578,24 @@ def get_fast_api_app(
520
578
  def delete_artifact(
521
579
  app_name: str, user_id: str, session_id: str, artifact_name: str
522
580
  ):
581
+ app_name = agent_engine_id if agent_engine_id else app_name
523
582
  artifact_service.delete_artifact(
524
- app_name, user_id, session_id, artifact_name
583
+ app_name=app_name,
584
+ user_id=user_id,
585
+ session_id=session_id,
586
+ filename=artifact_name,
525
587
  )
526
588
 
527
589
  @app.post("/run", response_model_exclude_none=True)
528
590
  async def agent_run(req: AgentRunRequest) -> list[Event]:
529
591
  # Connect to managed session if agent_engine_id is set.
530
592
  app_id = agent_engine_id if agent_engine_id else req.app_name
531
- session = session_service.get_session(app_id, req.user_id, req.session_id)
593
+ session = session_service.get_session(
594
+ app_name=app_id, user_id=req.user_id, session_id=req.session_id
595
+ )
532
596
  if not session:
533
597
  raise HTTPException(status_code=404, detail="Session not found")
534
- runner = _get_runner(req.app_name)
598
+ runner = await _get_runner_async(req.app_name)
535
599
  events = [
536
600
  event
537
601
  async for event in runner.run_async(
@@ -548,7 +612,9 @@ def get_fast_api_app(
548
612
  # Connect to managed session if agent_engine_id is set.
549
613
  app_id = agent_engine_id if agent_engine_id else req.app_name
550
614
  # SSE endpoint
551
- session = session_service.get_session(app_id, req.user_id, req.session_id)
615
+ session = session_service.get_session(
616
+ app_name=app_id, user_id=req.user_id, session_id=req.session_id
617
+ )
552
618
  if not session:
553
619
  raise HTTPException(status_code=404, detail="Session not found")
554
620
 
@@ -556,7 +622,7 @@ def get_fast_api_app(
556
622
  async def event_generator():
557
623
  try:
558
624
  stream_mode = StreamingMode.SSE if req.streaming else StreamingMode.NONE
559
- runner = _get_runner(req.app_name)
625
+ runner = await _get_runner_async(req.app_name)
560
626
  async for event in runner.run_async(
561
627
  user_id=req.user_id,
562
628
  session_id=req.session_id,
@@ -582,47 +648,53 @@ def get_fast_api_app(
582
648
  "/apps/{app_name}/users/{user_id}/sessions/{session_id}/events/{event_id}/graph",
583
649
  response_model_exclude_none=True,
584
650
  )
585
- def get_event_graph(
651
+ async def get_event_graph(
586
652
  app_name: str, user_id: str, session_id: str, event_id: str
587
653
  ):
588
654
  # Connect to managed session if agent_engine_id is set.
589
655
  app_id = agent_engine_id if agent_engine_id else app_name
590
- session = session_service.get_session(app_id, user_id, session_id)
656
+ session = session_service.get_session(
657
+ app_name=app_id, user_id=user_id, session_id=session_id
658
+ )
591
659
  session_events = session.events if session else []
592
660
  event = next((x for x in session_events if x.id == event_id), None)
593
- if event:
594
- from . import agent_graph
595
-
596
- function_calls = event.get_function_calls()
597
- function_responses = event.get_function_responses()
598
- root_agent = _get_root_agent(app_name)
599
- image_bytes = None
600
- if function_calls:
601
- function_call_highlights = []
602
- for function_call in function_calls:
603
- from_name = event.author
604
- to_name = function_call.name
605
- function_call_highlights.append((from_name, to_name))
606
- image_bytes = agent_graph.get_agent_graph(
607
- root_agent, function_call_highlights, True
608
- )
609
- elif function_responses:
610
- function_responses_highlights = []
611
- for function_response in function_responses:
612
- from_name = function_response.name
613
- to_name = event.author
614
- function_responses_highlights.append((from_name, to_name))
615
- image_bytes = agent_graph.get_agent_graph(
616
- root_agent, function_responses_highlights, True
617
- )
618
- else:
661
+ if not event:
662
+ return {}
663
+
664
+ from . import agent_graph
665
+
666
+ function_calls = event.get_function_calls()
667
+ function_responses = event.get_function_responses()
668
+ root_agent = await _get_root_agent_async(app_name)
669
+ dot_graph = None
670
+ if function_calls:
671
+ function_call_highlights = []
672
+ for function_call in function_calls:
619
673
  from_name = event.author
620
- to_name = ""
621
- image_bytes = agent_graph.get_agent_graph(
622
- root_agent, [(from_name, to_name)], True
674
+ to_name = function_call.name
675
+ function_call_highlights.append((from_name, to_name))
676
+ dot_graph = agent_graph.get_agent_graph(
677
+ root_agent, function_call_highlights
678
+ )
679
+ elif function_responses:
680
+ function_responses_highlights = []
681
+ for function_response in function_responses:
682
+ from_name = function_response.name
683
+ to_name = event.author
684
+ function_responses_highlights.append((from_name, to_name))
685
+ dot_graph = agent_graph.get_agent_graph(
686
+ root_agent, function_responses_highlights
623
687
  )
624
- return Response(content=image_bytes, media_type="image/png")
625
- return None
688
+ else:
689
+ from_name = event.author
690
+ to_name = ""
691
+ dot_graph = agent_graph.get_agent_graph(
692
+ root_agent, [(from_name, to_name)]
693
+ )
694
+ if dot_graph and isinstance(dot_graph, graphviz.Digraph):
695
+ return {"dot_src": dot_graph.source}
696
+ else:
697
+ return {}
626
698
 
627
699
  @app.websocket("/run_live")
628
700
  async def agent_live_run(
@@ -638,7 +710,9 @@ def get_fast_api_app(
638
710
 
639
711
  # Connect to managed session if agent_engine_id is set.
640
712
  app_id = agent_engine_id if agent_engine_id else app_name
641
- session = session_service.get_session(app_id, user_id, session_id)
713
+ session = session_service.get_session(
714
+ app_name=app_id, user_id=user_id, session_id=session_id
715
+ )
642
716
  if not session:
643
717
  # Accept first so that the client is aware of connection establishment,
644
718
  # then close with a specific code.
@@ -648,7 +722,7 @@ def get_fast_api_app(
648
722
  live_request_queue = LiveRequestQueue()
649
723
 
650
724
  async def forward_events():
651
- runner = _get_runner(app_name)
725
+ runner = await _get_runner_async(app_name)
652
726
  async for event in runner.run_live(
653
727
  session=session, live_request_queue=live_request_queue
654
728
  ):
@@ -682,30 +756,50 @@ def get_fast_api_app(
682
756
  except Exception as e:
683
757
  logger.exception("Error during live websocket communication: %s", e)
684
758
  traceback.print_exc()
759
+ WEBSOCKET_INTERNAL_ERROR_CODE = 1011
760
+ WEBSOCKET_MAX_BYTES_FOR_REASON = 123
761
+ await websocket.close(
762
+ code=WEBSOCKET_INTERNAL_ERROR_CODE,
763
+ reason=str(e)[:WEBSOCKET_MAX_BYTES_FOR_REASON],
764
+ )
685
765
  finally:
686
766
  for task in pending:
687
767
  task.cancel()
688
768
 
689
- def _get_root_agent(app_name: str) -> Agent:
769
+ async def _get_root_agent_async(app_name: str) -> Agent:
690
770
  """Returns the root agent for the given app."""
691
771
  if app_name in root_agent_dict:
692
772
  return root_agent_dict[app_name]
693
- envs.load_dotenv_for_agent(os.path.basename(app_name), agent_dir)
694
773
  agent_module = importlib.import_module(app_name)
695
- root_agent: Agent = agent_module.agent.root_agent
774
+ if getattr(agent_module.agent, "root_agent"):
775
+ root_agent = agent_module.agent.root_agent
776
+ else:
777
+ raise ValueError(f'Unable to find "root_agent" from {app_name}.')
778
+
779
+ # Handle an awaitable root agent and await for the actual agent.
780
+ if inspect.isawaitable(root_agent):
781
+ try:
782
+ agent, exit_stack = await root_agent
783
+ exit_stacks.append(exit_stack)
784
+ root_agent = agent
785
+ except Exception as e:
786
+ raise RuntimeError(f"error getting root agent, {e}") from e
787
+
696
788
  root_agent_dict[app_name] = root_agent
697
789
  return root_agent
698
790
 
699
- def _get_runner(app_name: str) -> Runner:
791
+ async def _get_runner_async(app_name: str) -> Runner:
700
792
  """Returns the runner for the given app."""
793
+ envs.load_dotenv_for_agent(os.path.basename(app_name), agent_dir)
701
794
  if app_name in runner_dict:
702
795
  return runner_dict[app_name]
703
- root_agent = _get_root_agent(app_name)
796
+ root_agent = await _get_root_agent_async(app_name)
704
797
  runner = Runner(
705
798
  app_name=agent_engine_id if agent_engine_id else app_name,
706
799
  agent=root_agent,
707
800
  artifact_service=artifact_service,
708
801
  session_service=session_service,
802
+ memory_service=memory_service,
709
803
  )
710
804
  runner_dict[app_name] = runner
711
805
  return runner
@@ -0,0 +1,23 @@
1
+ # Copyright 2025 Google LLC
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import pydantic
16
+ from pydantic import alias_generators
17
+
18
+
19
+ class BaseModel(pydantic.BaseModel):
20
+ model_config = pydantic.ConfigDict(
21
+ alias_generator=alias_generators.to_camel,
22
+ populate_by_name=True,
23
+ )
@@ -12,11 +12,17 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
- from typing import Any
15
+ from typing import Any, Tuple
16
16
 
17
+ from deprecated import deprecated
18
+ from google.genai import types as genai_types
19
+
20
+ from ...evaluation.eval_case import IntermediateData
21
+ from ...evaluation.eval_case import Invocation
17
22
  from ...sessions.session import Session
18
23
 
19
24
 
25
+ @deprecated(reason='Use convert_session_to_eval_invocations instead.')
20
26
  def convert_session_to_eval_format(session: Session) -> list[dict[str, Any]]:
21
27
  """Converts a session data into eval format.
22
28
 
@@ -91,3 +97,79 @@ def convert_session_to_eval_format(session: Session) -> list[dict[str, Any]]:
91
97
  })
92
98
 
93
99
  return eval_case
100
+
101
+
102
+ def convert_session_to_eval_invocations(session: Session) -> list[Invocation]:
103
+ """Converts a session data into a list of Invocation.
104
+
105
+ Args:
106
+ session: The session that should be converted.
107
+
108
+ Returns:
109
+ list: A list of invocation.
110
+ """
111
+ invocations: list[Invocation] = []
112
+ events = session.events if session and session.events else []
113
+
114
+ for event in events:
115
+ if event.author == 'user':
116
+ if not event.content or not event.content.parts:
117
+ continue
118
+
119
+ # The content present in this event is the user content.
120
+ user_content = event.content
121
+ invocation_id = event.invocation_id
122
+ invocaton_timestamp = event.timestamp
123
+
124
+ # Find the corresponding tool usage or response for the query
125
+ tool_uses: list[genai_types.FunctionCall] = []
126
+ intermediate_responses: list[Tuple[str, list[genai_types.Part]]] = []
127
+
128
+ # Check subsequent events to extract tool uses or responses for this turn.
129
+ for subsequent_event in events[events.index(event) + 1 :]:
130
+ event_author = subsequent_event.author or 'agent'
131
+ if event_author == 'user':
132
+ # We found an event where the author was the user. This means that a
133
+ # new turn has started. So close this turn here.
134
+ break
135
+
136
+ if not subsequent_event.content or not subsequent_event.content.parts:
137
+ continue
138
+
139
+ intermediate_response_parts = []
140
+ for subsequent_part in subsequent_event.content.parts:
141
+ # Some events have both function call and reference
142
+ if subsequent_part.function_call:
143
+ tool_uses.append(subsequent_part.function_call)
144
+ elif subsequent_part.text:
145
+ # Also keep track of all the natural language responses that
146
+ # agent (or sub agents) generated.
147
+ intermediate_response_parts.append(subsequent_part)
148
+
149
+ if intermediate_response_parts:
150
+ # Only add an entry if there any intermediate entries.
151
+ intermediate_responses.append(
152
+ (event_author, intermediate_response_parts)
153
+ )
154
+
155
+ # If we are here then either we are done reading all the events or we
156
+ # encountered an event that had content authored by the end-user.
157
+ # This, basically means an end of turn.
158
+ # We assume that the last natural language intermediate response is the
159
+ # final response from the agent/model. We treat that as a reference.
160
+ invocations.append(
161
+ Invocation(
162
+ user_content=user_content,
163
+ invocation_id=invocation_id,
164
+ creation_timestamp=invocaton_timestamp,
165
+ intermediate_data=IntermediateData(
166
+ tool_uses=tool_uses,
167
+ intermediate_responses=intermediate_responses[:-1],
168
+ ),
169
+ final_response=genai_types.Content(
170
+ parts=intermediate_responses[-1][1]
171
+ ),
172
+ )
173
+ )
174
+
175
+ return invocations
@@ -14,6 +14,7 @@
14
14
 
15
15
  import logging
16
16
  import os
17
+ import sys
17
18
  import tempfile
18
19
  import time
19
20
 
@@ -22,11 +23,18 @@ LOGGING_FORMAT = (
22
23
  )
23
24
 
24
25
 
25
- def log_to_stderr(level=logging.INFO):
26
- logging.basicConfig(
27
- level=level,
28
- format=LOGGING_FORMAT,
29
- )
26
+ def setup_adk_logger(level=logging.INFO):
27
+ # Configure the root logger format and level.
28
+ logging.basicConfig(level=level, format=LOGGING_FORMAT)
29
+
30
+ # Set up adk_logger and log to stderr.
31
+ handler = logging.StreamHandler(sys.stderr)
32
+ handler.setLevel(level)
33
+ handler.setFormatter(logging.Formatter(LOGGING_FORMAT))
34
+
35
+ adk_logger = logging.getLogger('google_adk')
36
+ adk_logger.setLevel(level)
37
+ adk_logger.addHandler(handler)
30
38
 
31
39
 
32
40
  def log_to_tmp_folder(
@@ -15,13 +15,15 @@
15
15
  import logging
16
16
 
17
17
  from .base_code_executor import BaseCodeExecutor
18
+ from .built_in_code_executor import BuiltInCodeExecutor
18
19
  from .code_executor_context import CodeExecutorContext
19
20
  from .unsafe_local_code_executor import UnsafeLocalCodeExecutor
20
21
 
21
- logger = logging.getLogger(__name__)
22
+ logger = logging.getLogger('google_adk.' + __name__)
22
23
 
23
24
  __all__ = [
24
25
  'BaseCodeExecutor',
26
+ 'BuiltInCodeExecutor',
25
27
  'CodeExecutorContext',
26
28
  'UnsafeLocalCodeExecutor',
27
29
  ]