google-adk 0.5.0__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (113)
  1. google/adk/agents/base_agent.py +76 -30
  2. google/adk/agents/base_agent.py.orig +330 -0
  3. google/adk/agents/callback_context.py +0 -5
  4. google/adk/agents/llm_agent.py +122 -30
  5. google/adk/agents/loop_agent.py +1 -1
  6. google/adk/agents/parallel_agent.py +7 -0
  7. google/adk/agents/readonly_context.py +7 -1
  8. google/adk/agents/run_config.py +1 -1
  9. google/adk/agents/sequential_agent.py +31 -0
  10. google/adk/agents/transcription_entry.py +4 -2
  11. google/adk/artifacts/gcs_artifact_service.py +1 -1
  12. google/adk/artifacts/in_memory_artifact_service.py +1 -1
  13. google/adk/auth/auth_credential.py +6 -1
  14. google/adk/auth/auth_preprocessor.py +7 -1
  15. google/adk/auth/auth_tool.py +3 -4
  16. google/adk/cli/agent_graph.py +5 -5
  17. google/adk/cli/browser/index.html +2 -2
  18. google/adk/cli/browser/{main-ULN5R5I5.js → main-QOEMUXM4.js} +44 -45
  19. google/adk/cli/cli.py +7 -7
  20. google/adk/cli/cli_deploy.py +7 -2
  21. google/adk/cli/cli_eval.py +172 -99
  22. google/adk/cli/cli_tools_click.py +147 -64
  23. google/adk/cli/fast_api.py +330 -148
  24. google/adk/cli/fast_api.py.orig +174 -80
  25. google/adk/cli/utils/common.py +23 -0
  26. google/adk/cli/utils/evals.py +83 -1
  27. google/adk/cli/utils/logs.py +13 -5
  28. google/adk/code_executors/__init__.py +3 -1
  29. google/adk/code_executors/built_in_code_executor.py +52 -0
  30. google/adk/evaluation/__init__.py +1 -1
  31. google/adk/evaluation/agent_evaluator.py +168 -128
  32. google/adk/evaluation/eval_case.py +102 -0
  33. google/adk/evaluation/eval_set.py +37 -0
  34. google/adk/evaluation/eval_sets_manager.py +42 -0
  35. google/adk/evaluation/evaluation_generator.py +88 -113
  36. google/adk/evaluation/evaluator.py +56 -0
  37. google/adk/evaluation/local_eval_sets_manager.py +264 -0
  38. google/adk/evaluation/response_evaluator.py +106 -2
  39. google/adk/evaluation/trajectory_evaluator.py +83 -2
  40. google/adk/events/event.py +6 -1
  41. google/adk/events/event_actions.py +6 -1
  42. google/adk/examples/example_util.py +3 -2
  43. google/adk/flows/llm_flows/_code_execution.py +9 -1
  44. google/adk/flows/llm_flows/audio_transcriber.py +4 -3
  45. google/adk/flows/llm_flows/base_llm_flow.py +54 -15
  46. google/adk/flows/llm_flows/functions.py +9 -8
  47. google/adk/flows/llm_flows/instructions.py +13 -5
  48. google/adk/flows/llm_flows/single_flow.py +1 -1
  49. google/adk/memory/__init__.py +1 -1
  50. google/adk/memory/_utils.py +23 -0
  51. google/adk/memory/base_memory_service.py +23 -21
  52. google/adk/memory/base_memory_service.py.orig +76 -0
  53. google/adk/memory/in_memory_memory_service.py +57 -25
  54. google/adk/memory/memory_entry.py +37 -0
  55. google/adk/memory/vertex_ai_rag_memory_service.py +38 -15
  56. google/adk/models/anthropic_llm.py +16 -9
  57. google/adk/models/gemini_llm_connection.py +11 -11
  58. google/adk/models/google_llm.py +9 -2
  59. google/adk/models/google_llm.py.orig +305 -0
  60. google/adk/models/lite_llm.py +77 -21
  61. google/adk/models/llm_response.py +14 -2
  62. google/adk/models/registry.py +1 -1
  63. google/adk/runners.py +65 -41
  64. google/adk/sessions/__init__.py +1 -1
  65. google/adk/sessions/base_session_service.py +6 -33
  66. google/adk/sessions/database_session_service.py +58 -65
  67. google/adk/sessions/in_memory_session_service.py +106 -24
  68. google/adk/sessions/session.py +3 -0
  69. google/adk/sessions/vertex_ai_session_service.py +23 -45
  70. google/adk/telemetry.py +3 -0
  71. google/adk/tools/__init__.py +4 -7
  72. google/adk/tools/{built_in_code_execution_tool.py → _built_in_code_execution_tool.py} +11 -0
  73. google/adk/tools/_memory_entry_utils.py +30 -0
  74. google/adk/tools/agent_tool.py +9 -9
  75. google/adk/tools/apihub_tool/apihub_toolset.py +55 -74
  76. google/adk/tools/application_integration_tool/application_integration_toolset.py +107 -85
  77. google/adk/tools/application_integration_tool/clients/connections_client.py +20 -0
  78. google/adk/tools/application_integration_tool/clients/integration_client.py +6 -6
  79. google/adk/tools/application_integration_tool/integration_connector_tool.py +69 -26
  80. google/adk/tools/base_toolset.py +58 -0
  81. google/adk/tools/enterprise_search_tool.py +65 -0
  82. google/adk/tools/function_parameter_parse_util.py +2 -2
  83. google/adk/tools/google_api_tool/__init__.py +18 -70
  84. google/adk/tools/google_api_tool/google_api_tool.py +11 -5
  85. google/adk/tools/google_api_tool/google_api_toolset.py +126 -0
  86. google/adk/tools/google_api_tool/google_api_toolsets.py +102 -0
  87. google/adk/tools/google_api_tool/googleapi_to_openapi_converter.py +40 -42
  88. google/adk/tools/langchain_tool.py +96 -49
  89. google/adk/tools/load_memory_tool.py +14 -5
  90. google/adk/tools/mcp_tool/__init__.py +3 -2
  91. google/adk/tools/mcp_tool/mcp_session_manager.py +153 -16
  92. google/adk/tools/mcp_tool/mcp_session_manager.py.orig +322 -0
  93. google/adk/tools/mcp_tool/mcp_tool.py +12 -12
  94. google/adk/tools/mcp_tool/mcp_toolset.py +155 -195
  95. google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py +32 -7
  96. google/adk/tools/openapi_tool/openapi_spec_parser/operation_parser.py +31 -31
  97. google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py +1 -1
  98. google/adk/tools/preload_memory_tool.py +27 -18
  99. google/adk/tools/retrieval/__init__.py +1 -1
  100. google/adk/tools/retrieval/vertex_ai_rag_retrieval.py +1 -1
  101. google/adk/tools/toolbox_toolset.py +79 -0
  102. google/adk/tools/transfer_to_agent_tool.py +0 -1
  103. google/adk/version.py +1 -1
  104. {google_adk-0.5.0.dist-info → google_adk-1.0.0.dist-info}/METADATA +7 -5
  105. google_adk-1.0.0.dist-info/RECORD +195 -0
  106. google/adk/agents/remote_agent.py +0 -50
  107. google/adk/tools/google_api_tool/google_api_tool_set.py +0 -110
  108. google/adk/tools/google_api_tool/google_api_tool_sets.py +0 -112
  109. google/adk/tools/toolbox_tool.py +0 -46
  110. google_adk-0.5.0.dist-info/RECORD +0 -180
  111. {google_adk-0.5.0.dist-info → google_adk-1.0.0.dist-info}/WHEEL +0 -0
  112. {google_adk-0.5.0.dist-info → google_adk-1.0.0.dist-info}/entry_points.txt +0 -0
  113. {google_adk-0.5.0.dist-info → google_adk-1.0.0.dist-info}/licenses/LICENSE +0 -0
@@ -12,6 +12,7 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
+
15
16
  import asyncio
16
17
  from contextlib import asynccontextmanager
17
18
  import importlib
@@ -20,8 +21,9 @@ import json
20
21
  import logging
21
22
  import os
22
23
  from pathlib import Path
23
- import re
24
+ import signal
24
25
  import sys
26
+ import time
25
27
  import traceback
26
28
  import typing
27
29
  from typing import Any
@@ -30,7 +32,6 @@ from typing import Literal
30
32
  from typing import Optional
31
33
 
32
34
  import click
33
- from click import Tuple
34
35
  from fastapi import FastAPI
35
36
  from fastapi import HTTPException
36
37
  from fastapi import Query
@@ -48,16 +49,22 @@ from opentelemetry.exporter.cloud_trace import CloudTraceSpanExporter
48
49
  from opentelemetry.sdk.trace import export
49
50
  from opentelemetry.sdk.trace import ReadableSpan
50
51
  from opentelemetry.sdk.trace import TracerProvider
51
- from pydantic import BaseModel
52
+ from pydantic import Field
52
53
  from pydantic import ValidationError
53
54
  from starlette.types import Lifespan
55
+ from typing_extensions import override
54
56
 
55
57
  from ..agents import RunConfig
58
+ from ..agents.base_agent import BaseAgent
56
59
  from ..agents.live_request_queue import LiveRequest
57
60
  from ..agents.live_request_queue import LiveRequestQueue
58
61
  from ..agents.llm_agent import Agent
62
+ from ..agents.llm_agent import LlmAgent
59
63
  from ..agents.run_config import StreamingMode
60
64
  from ..artifacts import InMemoryArtifactService
65
+ from ..evaluation.eval_case import EvalCase
66
+ from ..evaluation.eval_case import SessionInput
67
+ from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager
61
68
  from ..events.event import Event
62
69
  from ..memory.in_memory_memory_service import InMemoryMemoryService
63
70
  from ..runners import Runner
@@ -65,17 +72,23 @@ from ..sessions.database_session_service import DatabaseSessionService
65
72
  from ..sessions.in_memory_session_service import InMemorySessionService
66
73
  from ..sessions.session import Session
67
74
  from ..sessions.vertex_ai_session_service import VertexAiSessionService
75
+ from ..tools.base_toolset import BaseToolset
68
76
  from .cli_eval import EVAL_SESSION_ID_PREFIX
77
+ from .cli_eval import EvalCaseResult
69
78
  from .cli_eval import EvalMetric
70
79
  from .cli_eval import EvalMetricResult
80
+ from .cli_eval import EvalMetricResultPerInvocation
81
+ from .cli_eval import EvalSetResult
71
82
  from .cli_eval import EvalStatus
83
+ from .utils import common
72
84
  from .utils import create_empty_state
73
85
  from .utils import envs
74
86
  from .utils import evals
75
87
 
76
- logger = logging.getLogger(__name__)
88
+ logger = logging.getLogger("google_adk." + __name__)
77
89
 
78
90
  _EVAL_SET_FILE_EXTENSION = ".evalset.json"
91
+ _EVAL_SET_RESULT_FILE_EXTENSION = ".evalset_result.json"
79
92
 
80
93
 
81
94
  class ApiServerSpanExporter(export.SpanExporter):
@@ -103,7 +116,45 @@ class ApiServerSpanExporter(export.SpanExporter):
103
116
  return True
104
117
 
105
118
 
106
- class AgentRunRequest(BaseModel):
119
+ class InMemoryExporter(export.SpanExporter):
120
+
121
+ def __init__(self, trace_dict):
122
+ super().__init__()
123
+ self._spans = []
124
+ self.trace_dict = trace_dict
125
+
126
+ @override
127
+ def export(
128
+ self, spans: typing.Sequence[ReadableSpan]
129
+ ) -> export.SpanExportResult:
130
+ for span in spans:
131
+ trace_id = span.context.trace_id
132
+ if span.name == "call_llm":
133
+ attributes = dict(span.attributes)
134
+ session_id = attributes.get("gcp.vertex.agent.session_id", None)
135
+ if session_id:
136
+ if session_id not in self.trace_dict:
137
+ self.trace_dict[session_id] = [trace_id]
138
+ else:
139
+ self.trace_dict[session_id] += [trace_id]
140
+ self._spans.extend(spans)
141
+ return export.SpanExportResult.SUCCESS
142
+
143
+ @override
144
+ def force_flush(self, timeout_millis: int = 30000) -> bool:
145
+ return True
146
+
147
+ def get_finished_spans(self, session_id: str):
148
+ trace_ids = self.trace_dict.get(session_id, None)
149
+ if trace_ids is None or not trace_ids:
150
+ return []
151
+ return [x for x in self._spans if x.context.trace_id in trace_ids]
152
+
153
+ def clear(self):
154
+ self._spans.clear()
155
+
156
+
157
+ class AgentRunRequest(common.BaseModel):
107
158
  app_name: str
108
159
  user_id: str
109
160
  session_id: str
@@ -111,25 +162,38 @@ class AgentRunRequest(BaseModel):
111
162
  streaming: bool = False
112
163
 
113
164
 
114
- class AddSessionToEvalSetRequest(BaseModel):
165
+ class AddSessionToEvalSetRequest(common.BaseModel):
115
166
  eval_id: str
116
167
  session_id: str
117
168
  user_id: str
118
169
 
119
170
 
120
- class RunEvalRequest(BaseModel):
171
+ class RunEvalRequest(common.BaseModel):
121
172
  eval_ids: list[str] # if empty, then all evals in the eval set are run.
122
173
  eval_metrics: list[EvalMetric]
123
174
 
124
175
 
125
- class RunEvalResult(BaseModel):
176
+ class RunEvalResult(common.BaseModel):
177
+ eval_set_file: str
126
178
  eval_set_id: str
127
179
  eval_id: str
128
180
  final_eval_status: EvalStatus
129
- eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]]
181
+ eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]] = Field(
182
+ deprecated=True,
183
+ description=(
184
+ "This field is deprecated, use overall_eval_metric_results instead."
185
+ ),
186
+ )
187
+ overall_eval_metric_results: list[EvalMetricResult]
188
+ eval_metric_result_per_invocation: list[EvalMetricResultPerInvocation]
189
+ user_id: str
130
190
  session_id: str
131
191
 
132
192
 
193
+ class GetEventGraphResult(common.BaseModel):
194
+ dot_src: str
195
+
196
+
133
197
  def get_fast_api_app(
134
198
  *,
135
199
  agent_dir: str,
@@ -141,12 +205,15 @@ def get_fast_api_app(
141
205
  ) -> FastAPI:
142
206
  # InMemory tracing dict.
143
207
  trace_dict: dict[str, Any] = {}
208
+ session_trace_dict: dict[str, Any] = {}
144
209
 
145
210
  # Set up tracing in the FastAPI server.
146
211
  provider = TracerProvider()
147
212
  provider.add_span_processor(
148
213
  export.SimpleSpanProcessor(ApiServerSpanExporter(trace_dict))
149
214
  )
215
+ memory_exporter = InMemoryExporter(session_trace_dict)
216
+ provider.add_span_processor(export.SimpleSpanProcessor(memory_exporter))
150
217
  if trace_to_cloud:
151
218
  envs.load_dotenv_for_agent("", agent_dir)
152
219
  if project_id := os.environ.get("GOOGLE_CLOUD_PROJECT", None):
@@ -155,26 +222,82 @@ def get_fast_api_app(
155
222
  )
156
223
  provider.add_span_processor(processor)
157
224
  else:
158
- logging.warning(
225
+ logger.warning(
159
226
  "GOOGLE_CLOUD_PROJECT environment variable is not set. Tracing will"
160
227
  " not be enabled."
161
228
  )
162
229
 
163
230
  trace.set_tracer_provider(provider)
164
231
 
165
- exit_stacks = []
232
+ toolsets_to_close: set[BaseToolset] = set()
166
233
 
167
234
  @asynccontextmanager
168
235
  async def internal_lifespan(app: FastAPI):
169
- if lifespan:
170
- async with lifespan(app) as lifespan_context:
236
+ # Set up signal handlers for graceful shutdown
237
+ original_sigterm = signal.getsignal(signal.SIGTERM)
238
+ original_sigint = signal.getsignal(signal.SIGINT)
239
+
240
+ def cleanup_handler(sig, frame):
241
+ # Log the signal
242
+ logger.info("Received signal %s, performing pre-shutdown cleanup", sig)
243
+ # Do synchronous cleanup if needed
244
+ # Then call original handler if it exists
245
+ if sig == signal.SIGTERM and callable(original_sigterm):
246
+ original_sigterm(sig, frame)
247
+ elif sig == signal.SIGINT and callable(original_sigint):
248
+ original_sigint(sig, frame)
249
+
250
+ # Install cleanup handlers
251
+ signal.signal(signal.SIGTERM, cleanup_handler)
252
+ signal.signal(signal.SIGINT, cleanup_handler)
253
+
254
+ try:
255
+ if lifespan:
256
+ async with lifespan(app) as lifespan_context:
257
+ yield lifespan_context
258
+ else:
171
259
  yield
260
+ finally:
261
+ # During shutdown, properly clean up all toolsets
262
+ logger.info(
263
+ "Server shutdown initiated, cleaning up %s toolsets",
264
+ len(toolsets_to_close),
265
+ )
172
266
 
173
- if exit_stacks:
174
- for stack in exit_stacks:
175
- await stack.aclose()
176
- else:
177
- yield
267
+ # Create tasks for all toolset closures to run concurrently
268
+ cleanup_tasks = []
269
+ for toolset in toolsets_to_close:
270
+ task = asyncio.create_task(close_toolset_safely(toolset))
271
+ cleanup_tasks.append(task)
272
+
273
+ if cleanup_tasks:
274
+ # Wait for all cleanup tasks with timeout
275
+ done, pending = await asyncio.wait(
276
+ cleanup_tasks,
277
+ timeout=10.0, # 10 second timeout for cleanup
278
+ return_when=asyncio.ALL_COMPLETED,
279
+ )
280
+
281
+ # If any tasks are still pending, log it
282
+ if pending:
283
+ logger.warning(
284
+ f"{len(pending)} toolset cleanup tasks didn't complete in time"
285
+ )
286
+ for task in pending:
287
+ task.cancel()
288
+
289
+ # Restore original signal handlers
290
+ signal.signal(signal.SIGTERM, original_sigterm)
291
+ signal.signal(signal.SIGINT, original_sigint)
292
+
293
+ async def close_toolset_safely(toolset):
294
+ """Safely close a toolset with error handling."""
295
+ try:
296
+ logger.info(f"Closing toolset: {type(toolset).__name__}")
297
+ await toolset.close()
298
+ logger.info(f"Successfully closed toolset: {type(toolset).__name__}")
299
+ except Exception as e:
300
+ logger.error(f"Error closing toolset {type(toolset).__name__}: {e}")
178
301
 
179
302
  # Run the FastAPI server.
180
303
  app = FastAPI(lifespan=internal_lifespan)
@@ -198,6 +321,8 @@ def get_fast_api_app(
198
321
  artifact_service = InMemoryArtifactService()
199
322
  memory_service = InMemoryMemoryService()
200
323
 
324
+ eval_sets_manager = LocalEvalSetsManager(agent_dir=agent_dir)
325
+
201
326
  # Build the Session service
202
327
  agent_engine_id = ""
203
328
  if session_db_url:
@@ -240,14 +365,34 @@ def get_fast_api_app(
240
365
  raise HTTPException(status_code=404, detail="Trace not found")
241
366
  return event_dict
242
367
 
368
+ @app.get("/debug/trace/session/{session_id}")
369
+ def get_session_trace(session_id: str) -> Any:
370
+ spans = memory_exporter.get_finished_spans(session_id)
371
+ if not spans:
372
+ return []
373
+ return [
374
+ {
375
+ "name": s.name,
376
+ "span_id": s.context.span_id,
377
+ "trace_id": s.context.trace_id,
378
+ "start_time": s.start_time,
379
+ "end_time": s.end_time,
380
+ "attributes": dict(s.attributes),
381
+ "parent_span_id": s.parent.span_id if s.parent else None,
382
+ }
383
+ for s in spans
384
+ ]
385
+
243
386
  @app.get(
244
387
  "/apps/{app_name}/users/{user_id}/sessions/{session_id}",
245
388
  response_model_exclude_none=True,
246
389
  )
247
- def get_session(app_name: str, user_id: str, session_id: str) -> Session:
390
+ async def get_session(
391
+ app_name: str, user_id: str, session_id: str
392
+ ) -> Session:
248
393
  # Connect to managed session if agent_engine_id is set.
249
394
  app_name = agent_engine_id if agent_engine_id else app_name
250
- session = session_service.get_session(
395
+ session = await session_service.get_session(
251
396
  app_name=app_name, user_id=user_id, session_id=session_id
252
397
  )
253
398
  if not session:
@@ -258,14 +403,15 @@ def get_fast_api_app(
258
403
  "/apps/{app_name}/users/{user_id}/sessions",
259
404
  response_model_exclude_none=True,
260
405
  )
261
- def list_sessions(app_name: str, user_id: str) -> list[Session]:
406
+ async def list_sessions(app_name: str, user_id: str) -> list[Session]:
262
407
  # Connect to managed session if agent_engine_id is set.
263
408
  app_name = agent_engine_id if agent_engine_id else app_name
409
+ list_sessions_response = await session_service.list_sessions(
410
+ app_name=app_name, user_id=user_id
411
+ )
264
412
  return [
265
413
  session
266
- for session in session_service.list_sessions(
267
- app_name=app_name, user_id=user_id
268
- ).sessions
414
+ for session in list_sessions_response.sessions
269
415
  # Remove sessions that were generated as a part of Eval.
270
416
  if not session.id.startswith(EVAL_SESSION_ID_PREFIX)
271
417
  ]
@@ -274,7 +420,7 @@ def get_fast_api_app(
274
420
  "/apps/{app_name}/users/{user_id}/sessions/{session_id}",
275
421
  response_model_exclude_none=True,
276
422
  )
277
- def create_session_with_id(
423
+ async def create_session_with_id(
278
424
  app_name: str,
279
425
  user_id: str,
280
426
  session_id: str,
@@ -283,7 +429,7 @@ def get_fast_api_app(
283
429
  # Connect to managed session if agent_engine_id is set.
284
430
  app_name = agent_engine_id if agent_engine_id else app_name
285
431
  if (
286
- session_service.get_session(
432
+ await session_service.get_session(
287
433
  app_name=app_name, user_id=user_id, session_id=session_id
288
434
  )
289
435
  is not None
@@ -292,9 +438,8 @@ def get_fast_api_app(
292
438
  raise HTTPException(
293
439
  status_code=400, detail=f"Session already exists: {session_id}"
294
440
  )
295
-
296
441
  logger.info("New session created: %s", session_id)
297
- return session_service.create_session(
442
+ return await session_service.create_session(
298
443
  app_name=app_name, user_id=user_id, state=state, session_id=session_id
299
444
  )
300
445
 
@@ -302,16 +447,15 @@ def get_fast_api_app(
302
447
  "/apps/{app_name}/users/{user_id}/sessions",
303
448
  response_model_exclude_none=True,
304
449
  )
305
- def create_session(
450
+ async def create_session(
306
451
  app_name: str,
307
452
  user_id: str,
308
453
  state: Optional[dict[str, Any]] = None,
309
454
  ) -> Session:
310
455
  # Connect to managed session if agent_engine_id is set.
311
456
  app_name = agent_engine_id if agent_engine_id else app_name
312
-
313
457
  logger.info("New session created")
314
- return session_service.create_session(
458
+ return await session_service.create_session(
315
459
  app_name=app_name, user_id=user_id, state=state
316
460
  )
317
461
 
@@ -331,28 +475,13 @@ def get_fast_api_app(
331
475
  eval_set_id: str,
332
476
  ):
333
477
  """Creates an eval set, given the id."""
334
- pattern = r"^[a-zA-Z0-9_]+$"
335
- if not bool(re.fullmatch(pattern, eval_set_id)):
478
+ try:
479
+ eval_sets_manager.create_eval_set(app_name, eval_set_id)
480
+ except ValueError as ve:
336
481
  raise HTTPException(
337
482
  status_code=400,
338
- detail=(
339
- f"Invalid eval set id. Eval set id should have the `{pattern}`"
340
- " format"
341
- ),
342
- )
343
- # Define the file path
344
- new_eval_set_path = _get_eval_set_file_path(
345
- app_name, agent_dir, eval_set_id
346
- )
347
-
348
- logger.info("Creating eval set file `%s`", new_eval_set_path)
349
-
350
- if not os.path.exists(new_eval_set_path):
351
- # Write the JSON string to the file
352
- logger.info("Eval set file doesn't exist, we will create a new one.")
353
- with open(new_eval_set_path, "w") as f:
354
- empty_content = json.dumps([], indent=2)
355
- f.write(empty_content)
483
+ detail=str(ve),
484
+ ) from ve
356
485
 
357
486
  @app.get(
358
487
  "/apps/{app_name}/eval_sets",
@@ -360,15 +489,7 @@ def get_fast_api_app(
360
489
  )
361
490
  def list_eval_sets(app_name: str) -> list[str]:
362
491
  """Lists all eval sets for the given app."""
363
- eval_set_file_path = os.path.join(agent_dir, app_name)
364
- eval_sets = []
365
- for file in os.listdir(eval_set_file_path):
366
- if file.endswith(_EVAL_SET_FILE_EXTENSION):
367
- eval_sets.append(
368
- os.path.basename(file).removesuffix(_EVAL_SET_FILE_EXTENSION)
369
- )
370
-
371
- return sorted(eval_sets)
492
+ return eval_sets_manager.list_eval_sets(app_name)
372
493
 
373
494
  @app.post(
374
495
  "/apps/{app_name}/eval_sets/{eval_set_id}/add_session",
@@ -377,54 +498,33 @@ def get_fast_api_app(
377
498
  async def add_session_to_eval_set(
378
499
  app_name: str, eval_set_id: str, req: AddSessionToEvalSetRequest
379
500
  ):
380
- pattern = r"^[a-zA-Z0-9_]+$"
381
- if not bool(re.fullmatch(pattern, req.eval_id)):
382
- raise HTTPException(
383
- status_code=400,
384
- detail=f"Invalid eval id. Eval id should have the `{pattern}` format",
385
- )
386
-
387
501
  # Get the session
388
- session = session_service.get_session(
502
+ session = await session_service.get_session(
389
503
  app_name=app_name, user_id=req.user_id, session_id=req.session_id
390
504
  )
391
505
  assert session, "Session not found."
392
- # Load the eval set file data
393
- eval_set_file_path = _get_eval_set_file_path(
394
- app_name, agent_dir, eval_set_id
395
- )
396
- with open(eval_set_file_path, "r") as file:
397
- eval_set_data = json.load(file) # Load JSON into a list
398
506
 
399
- if [x for x in eval_set_data if x["name"] == req.eval_id]:
400
- raise HTTPException(
401
- status_code=400,
402
- detail=(
403
- f"Eval id `{req.eval_id}` already exists in `{eval_set_id}`"
404
- " eval set."
405
- ),
406
- )
407
-
408
- # Convert the session data to evaluation format
409
- test_data = evals.convert_session_to_eval_format(session)
507
+ # Convert the session data to eval invocations
508
+ invocations = evals.convert_session_to_eval_invocations(session)
410
509
 
411
510
  # Populate the session with initial session state.
412
511
  initial_session_state = create_empty_state(
413
512
  await _get_root_agent_async(app_name)
414
513
  )
415
514
 
416
- eval_set_data.append({
417
- "name": req.eval_id,
418
- "data": test_data,
419
- "initial_session": {
420
- "state": initial_session_state,
421
- "app_name": app_name,
422
- "user_id": req.user_id,
423
- },
424
- })
425
- # Serialize the test data to JSON and write to the eval set file.
426
- with open(eval_set_file_path, "w") as f:
427
- f.write(json.dumps(eval_set_data, indent=2))
515
+ new_eval_case = EvalCase(
516
+ eval_id=req.eval_id,
517
+ conversation=invocations,
518
+ session_input=SessionInput(
519
+ app_name=app_name, user_id=req.user_id, state=initial_session_state
520
+ ),
521
+ creation_timestamp=time.time(),
522
+ )
523
+
524
+ try:
525
+ eval_sets_manager.add_eval_case(app_name, eval_set_id, new_eval_case)
526
+ except ValueError as ve:
527
+ raise HTTPException(status_code=400, detail=str(ve)) from ve
428
528
 
429
529
  @app.get(
430
530
  "/apps/{app_name}/eval_sets/{eval_set_id}/evals",
@@ -435,14 +535,9 @@ def get_fast_api_app(
435
535
  eval_set_id: str,
436
536
  ) -> list[str]:
437
537
  """Lists all evals in an eval set."""
438
- # Load the eval set file data
439
- eval_set_file_path = _get_eval_set_file_path(
440
- app_name, agent_dir, eval_set_id
441
- )
442
- with open(eval_set_file_path, "r") as file:
443
- eval_set_data = json.load(file) # Load JSON into a list
538
+ eval_set_data = eval_sets_manager.get_eval_set(app_name, eval_set_id)
444
539
 
445
- return sorted([x["name"] for x in eval_set_data])
540
+ return sorted([x.eval_id for x in eval_set_data.eval_cases])
446
541
 
447
542
  @app.post(
448
543
  "/apps/{app_name}/eval_sets/{eval_set_id}/run_eval",
@@ -451,51 +546,136 @@ def get_fast_api_app(
451
546
  async def run_eval(
452
547
  app_name: str, eval_set_id: str, req: RunEvalRequest
453
548
  ) -> list[RunEvalResult]:
549
+ """Runs an eval given the details in the eval request."""
454
550
  from .cli_eval import run_evals
455
551
 
456
- """Runs an eval given the details in the eval request."""
457
552
  # Create a mapping from eval set file to all the evals that needed to be
458
553
  # run.
459
- eval_set_file_path = _get_eval_set_file_path(
460
- app_name, agent_dir, eval_set_id
461
- )
462
- eval_set_to_evals = {eval_set_file_path: req.eval_ids}
554
+ envs.load_dotenv_for_agent(os.path.basename(app_name), agent_dir)
463
555
 
464
- if not req.eval_ids:
465
- logger.info(
466
- "Eval ids to run list is empty. We will all evals in the eval set."
467
- )
468
- root_agent = await _get_root_agent_async(app_name)
469
- eval_results = list(
470
- await run_evals(
471
- eval_set_to_evals,
472
- root_agent,
473
- getattr(root_agent, "reset_data", None),
474
- req.eval_metrics,
475
- session_service=session_service,
476
- artifact_service=artifact_service,
477
- )
478
- )
556
+ eval_set = eval_sets_manager.get_eval_set(app_name, eval_set_id)
557
+
558
+ if req.eval_ids:
559
+ eval_cases = [e for e in eval_set.eval_cases if e.eval_id in req.eval_ids]
560
+ eval_set_to_evals = {eval_set_id: eval_cases}
561
+ else:
562
+ logger.info("Eval ids to run list is empty. We will run all eval cases.")
563
+ eval_set_to_evals = {eval_set_id: eval_set.eval_cases}
479
564
 
565
+ root_agent = await _get_root_agent_async(app_name)
480
566
  run_eval_results = []
481
- for eval_result in eval_results:
567
+ eval_case_results = []
568
+ async for eval_case_result in run_evals(
569
+ eval_set_to_evals,
570
+ root_agent,
571
+ getattr(root_agent, "reset_data", None),
572
+ req.eval_metrics,
573
+ session_service=session_service,
574
+ artifact_service=artifact_service,
575
+ ):
482
576
  run_eval_results.append(
483
577
  RunEvalResult(
484
578
  app_name=app_name,
579
+ eval_set_file=eval_case_result.eval_set_file,
485
580
  eval_set_id=eval_set_id,
486
- eval_id=eval_result.eval_id,
487
- final_eval_status=eval_result.final_eval_status,
488
- eval_metric_results=eval_result.eval_metric_results,
489
- session_id=eval_result.session_id,
581
+ eval_id=eval_case_result.eval_id,
582
+ final_eval_status=eval_case_result.final_eval_status,
583
+ eval_metric_results=eval_case_result.eval_metric_results,
584
+ overall_eval_metric_results=eval_case_result.overall_eval_metric_results,
585
+ eval_metric_result_per_invocation=eval_case_result.eval_metric_result_per_invocation,
586
+ user_id=eval_case_result.user_id,
587
+ session_id=eval_case_result.session_id,
490
588
  )
491
589
  )
590
+ eval_case_result.session_details = await session_service.get_session(
591
+ app_name=app_name,
592
+ user_id=eval_case_result.user_id,
593
+ session_id=eval_case_result.session_id,
594
+ )
595
+ eval_case_results.append(eval_case_result)
596
+
597
+ timestamp = time.time()
598
+ eval_set_result_name = app_name + "_" + eval_set_id + "_" + str(timestamp)
599
+ eval_set_result = EvalSetResult(
600
+ eval_set_result_id=eval_set_result_name,
601
+ eval_set_result_name=eval_set_result_name,
602
+ eval_set_id=eval_set_id,
603
+ eval_case_results=eval_case_results,
604
+ creation_timestamp=timestamp,
605
+ )
606
+
607
+ # Write eval result file, with eval_set_result_name.
608
+ app_eval_history_dir = os.path.join(
609
+ agent_dir, app_name, ".adk", "eval_history"
610
+ )
611
+ if not os.path.exists(app_eval_history_dir):
612
+ os.makedirs(app_eval_history_dir)
613
+ # Convert to json and write to file.
614
+ eval_set_result_json = eval_set_result.model_dump_json()
615
+ eval_set_result_file_path = os.path.join(
616
+ app_eval_history_dir,
617
+ eval_set_result_name + _EVAL_SET_RESULT_FILE_EXTENSION,
618
+ )
619
+ logger.info("Writing eval result to file: %s", eval_set_result_file_path)
620
+ with open(eval_set_result_file_path, "w") as f:
621
+ f.write(json.dumps(eval_set_result_json, indent=2))
622
+
492
623
  return run_eval_results
493
624
 
625
+ @app.get(
626
+ "/apps/{app_name}/eval_results/{eval_result_id}",
627
+ response_model_exclude_none=True,
628
+ )
629
+ def get_eval_result(
630
+ app_name: str,
631
+ eval_result_id: str,
632
+ ) -> EvalSetResult:
633
+ """Gets the eval result for the given eval id."""
634
+ # Load the eval set file data
635
+ maybe_eval_result_file_path = (
636
+ os.path.join(
637
+ agent_dir, app_name, ".adk", "eval_history", eval_result_id
638
+ )
639
+ + _EVAL_SET_RESULT_FILE_EXTENSION
640
+ )
641
+ if not os.path.exists(maybe_eval_result_file_path):
642
+ raise HTTPException(
643
+ status_code=404,
644
+ detail=f"Eval result `{eval_result_id}` not found.",
645
+ )
646
+ with open(maybe_eval_result_file_path, "r") as file:
647
+ eval_result_data = json.load(file) # Load JSON into a list
648
+ try:
649
+ eval_result = EvalSetResult.model_validate_json(eval_result_data)
650
+ return eval_result
651
+ except ValidationError as e:
652
+ logger.exception("get_eval_result validation error: %s", e)
653
+
654
+ @app.get(
655
+ "/apps/{app_name}/eval_results",
656
+ response_model_exclude_none=True,
657
+ )
658
+ def list_eval_results(app_name: str) -> list[str]:
659
+ """Lists all eval results for the given app."""
660
+ app_eval_history_directory = os.path.join(
661
+ agent_dir, app_name, ".adk", "eval_history"
662
+ )
663
+
664
+ if not os.path.exists(app_eval_history_directory):
665
+ return []
666
+
667
+ eval_result_files = [
668
+ file.removesuffix(_EVAL_SET_RESULT_FILE_EXTENSION)
669
+ for file in os.listdir(app_eval_history_directory)
670
+ if file.endswith(_EVAL_SET_RESULT_FILE_EXTENSION)
671
+ ]
672
+ return eval_result_files
673
+
494
674
  @app.delete("/apps/{app_name}/users/{user_id}/sessions/{session_id}")
495
- def delete_session(app_name: str, user_id: str, session_id: str):
675
+ async def delete_session(app_name: str, user_id: str, session_id: str):
496
676
  # Connect to managed session if agent_engine_id is set.
497
677
  app_name = agent_engine_id if agent_engine_id else app_name
498
- session_service.delete_session(
678
+ await session_service.delete_session(
499
679
  app_name=app_name, user_id=user_id, session_id=session_id
500
680
  )
501
681
 
@@ -590,7 +770,7 @@ def get_fast_api_app(
590
770
  async def agent_run(req: AgentRunRequest) -> list[Event]:
591
771
  # Connect to managed session if agent_engine_id is set.
592
772
  app_id = agent_engine_id if agent_engine_id else req.app_name
593
- session = session_service.get_session(
773
+ session = await session_service.get_session(
594
774
  app_name=app_id, user_id=req.user_id, session_id=req.session_id
595
775
  )
596
776
  if not session:
@@ -612,7 +792,7 @@ def get_fast_api_app(
612
792
  # Connect to managed session if agent_engine_id is set.
613
793
  app_id = agent_engine_id if agent_engine_id else req.app_name
614
794
  # SSE endpoint
615
- session = session_service.get_session(
795
+ session = await session_service.get_session(
616
796
  app_name=app_id, user_id=req.user_id, session_id=req.session_id
617
797
  )
618
798
  if not session:
@@ -653,7 +833,7 @@ def get_fast_api_app(
653
833
  ):
654
834
  # Connect to managed session if agent_engine_id is set.
655
835
  app_id = agent_engine_id if agent_engine_id else app_name
656
- session = session_service.get_session(
836
+ session = await session_service.get_session(
657
837
  app_name=app_id, user_id=user_id, session_id=session_id
658
838
  )
659
839
  session_events = session.events if session else []
@@ -673,7 +853,7 @@ def get_fast_api_app(
673
853
  from_name = event.author
674
854
  to_name = function_call.name
675
855
  function_call_highlights.append((from_name, to_name))
676
- dot_graph = agent_graph.get_agent_graph(
856
+ dot_graph = await agent_graph.get_agent_graph(
677
857
  root_agent, function_call_highlights
678
858
  )
679
859
  elif function_responses:
@@ -682,17 +862,17 @@ def get_fast_api_app(
682
862
  from_name = function_response.name
683
863
  to_name = event.author
684
864
  function_responses_highlights.append((from_name, to_name))
685
- dot_graph = agent_graph.get_agent_graph(
865
+ dot_graph = await agent_graph.get_agent_graph(
686
866
  root_agent, function_responses_highlights
687
867
  )
688
868
  else:
689
869
  from_name = event.author
690
870
  to_name = ""
691
- dot_graph = agent_graph.get_agent_graph(
871
+ dot_graph = await agent_graph.get_agent_graph(
692
872
  root_agent, [(from_name, to_name)]
693
873
  )
694
874
  if dot_graph and isinstance(dot_graph, graphviz.Digraph):
695
- return {"dot_src": dot_graph.source}
875
+ return GetEventGraphResult(dot_src=dot_graph.source)
696
876
  else:
697
877
  return {}
698
878
 
@@ -710,7 +890,7 @@ def get_fast_api_app(
710
890
 
711
891
  # Connect to managed session if agent_engine_id is set.
712
892
  app_id = agent_engine_id if agent_engine_id else app_name
713
- session = session_service.get_session(
893
+ session = await session_service.get_session(
714
894
  app_name=app_id, user_id=user_id, session_id=session_id
715
895
  )
716
896
  if not session:
@@ -766,6 +946,16 @@ def get_fast_api_app(
766
946
  for task in pending:
767
947
  task.cancel()
768
948
 
949
+ def _get_all_toolsets(agent: BaseAgent) -> set[BaseToolset]:
950
+ toolsets = set()
951
+ if isinstance(agent, LlmAgent):
952
+ for tool_union in agent.tools:
953
+ if isinstance(tool_union, BaseToolset):
954
+ toolsets.add(tool_union)
955
+ for sub_agent in agent.sub_agents:
956
+ toolsets.update(_get_all_toolsets(sub_agent))
957
+ return toolsets
958
+
769
959
  async def _get_root_agent_async(app_name: str) -> Agent:
770
960
  """Returns the root agent for the given app."""
771
961
  if app_name in root_agent_dict:
@@ -776,16 +966,8 @@ def get_fast_api_app(
776
966
  else:
777
967
  raise ValueError(f'Unable to find "root_agent" from {app_name}.')
778
968
 
779
- # Handle an awaitable root agent and await for the actual agent.
780
- if inspect.isawaitable(root_agent):
781
- try:
782
- agent, exit_stack = await root_agent
783
- exit_stacks.append(exit_stack)
784
- root_agent = agent
785
- except Exception as e:
786
- raise RuntimeError(f"error getting root agent, {e}") from e
787
-
788
969
  root_agent_dict[app_name] = root_agent
970
+ toolsets_to_close.update(_get_all_toolsets(root_agent))
789
971
  return root_agent
790
972
 
791
973
  async def _get_runner_async(app_name: str) -> Runner: