google-adk 0.5.0__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.
Files changed (139)
  1. google/adk/agents/base_agent.py +76 -30
  2. google/adk/agents/callback_context.py +2 -6
  3. google/adk/agents/llm_agent.py +122 -30
  4. google/adk/agents/loop_agent.py +1 -1
  5. google/adk/agents/parallel_agent.py +7 -0
  6. google/adk/agents/readonly_context.py +8 -0
  7. google/adk/agents/run_config.py +1 -1
  8. google/adk/agents/sequential_agent.py +31 -0
  9. google/adk/agents/transcription_entry.py +4 -2
  10. google/adk/artifacts/gcs_artifact_service.py +1 -1
  11. google/adk/artifacts/in_memory_artifact_service.py +1 -1
  12. google/adk/auth/auth_credential.py +10 -2
  13. google/adk/auth/auth_preprocessor.py +7 -1
  14. google/adk/auth/auth_tool.py +3 -4
  15. google/adk/cli/agent_graph.py +5 -5
  16. google/adk/cli/browser/index.html +4 -4
  17. google/adk/cli/browser/{main-ULN5R5I5.js → main-PKDNKWJE.js} +59 -60
  18. google/adk/cli/browser/polyfills-B6TNHZQ6.js +17 -0
  19. google/adk/cli/cli.py +10 -9
  20. google/adk/cli/cli_deploy.py +7 -2
  21. google/adk/cli/cli_eval.py +109 -115
  22. google/adk/cli/cli_tools_click.py +179 -67
  23. google/adk/cli/fast_api.py +248 -197
  24. google/adk/cli/utils/agent_loader.py +137 -0
  25. google/adk/cli/utils/cleanup.py +40 -0
  26. google/adk/cli/utils/common.py +23 -0
  27. google/adk/cli/utils/evals.py +83 -0
  28. google/adk/cli/utils/logs.py +8 -5
  29. google/adk/code_executors/__init__.py +3 -1
  30. google/adk/code_executors/built_in_code_executor.py +52 -0
  31. google/adk/code_executors/code_execution_utils.py +2 -1
  32. google/adk/code_executors/container_code_executor.py +0 -1
  33. google/adk/code_executors/vertex_ai_code_executor.py +6 -8
  34. google/adk/evaluation/__init__.py +1 -1
  35. google/adk/evaluation/agent_evaluator.py +168 -128
  36. google/adk/evaluation/eval_case.py +104 -0
  37. google/adk/evaluation/eval_metrics.py +74 -0
  38. google/adk/evaluation/eval_result.py +86 -0
  39. google/adk/evaluation/eval_set.py +39 -0
  40. google/adk/evaluation/eval_set_results_manager.py +47 -0
  41. google/adk/evaluation/eval_sets_manager.py +43 -0
  42. google/adk/evaluation/evaluation_generator.py +88 -113
  43. google/adk/evaluation/evaluator.py +58 -0
  44. google/adk/evaluation/local_eval_set_results_manager.py +113 -0
  45. google/adk/evaluation/local_eval_sets_manager.py +264 -0
  46. google/adk/evaluation/response_evaluator.py +106 -1
  47. google/adk/evaluation/trajectory_evaluator.py +84 -2
  48. google/adk/events/event.py +6 -1
  49. google/adk/events/event_actions.py +6 -1
  50. google/adk/examples/base_example_provider.py +1 -0
  51. google/adk/examples/example_util.py +3 -2
  52. google/adk/flows/llm_flows/_code_execution.py +9 -1
  53. google/adk/flows/llm_flows/audio_transcriber.py +4 -3
  54. google/adk/flows/llm_flows/base_llm_flow.py +58 -21
  55. google/adk/flows/llm_flows/contents.py +3 -1
  56. google/adk/flows/llm_flows/functions.py +9 -8
  57. google/adk/flows/llm_flows/instructions.py +18 -80
  58. google/adk/flows/llm_flows/single_flow.py +2 -2
  59. google/adk/memory/__init__.py +1 -1
  60. google/adk/memory/_utils.py +23 -0
  61. google/adk/memory/base_memory_service.py +23 -21
  62. google/adk/memory/in_memory_memory_service.py +57 -25
  63. google/adk/memory/memory_entry.py +37 -0
  64. google/adk/memory/vertex_ai_rag_memory_service.py +38 -15
  65. google/adk/models/anthropic_llm.py +16 -9
  66. google/adk/models/base_llm.py +2 -1
  67. google/adk/models/base_llm_connection.py +2 -0
  68. google/adk/models/gemini_llm_connection.py +11 -11
  69. google/adk/models/google_llm.py +12 -2
  70. google/adk/models/lite_llm.py +80 -23
  71. google/adk/models/llm_response.py +16 -3
  72. google/adk/models/registry.py +1 -1
  73. google/adk/runners.py +98 -42
  74. google/adk/sessions/__init__.py +1 -1
  75. google/adk/sessions/_session_util.py +2 -1
  76. google/adk/sessions/base_session_service.py +6 -33
  77. google/adk/sessions/database_session_service.py +57 -67
  78. google/adk/sessions/in_memory_session_service.py +106 -24
  79. google/adk/sessions/session.py +3 -0
  80. google/adk/sessions/vertex_ai_session_service.py +44 -51
  81. google/adk/telemetry.py +7 -2
  82. google/adk/tools/__init__.py +4 -7
  83. google/adk/tools/_memory_entry_utils.py +30 -0
  84. google/adk/tools/agent_tool.py +10 -10
  85. google/adk/tools/apihub_tool/apihub_toolset.py +55 -74
  86. google/adk/tools/apihub_tool/clients/apihub_client.py +10 -3
  87. google/adk/tools/apihub_tool/clients/secret_client.py +1 -0
  88. google/adk/tools/application_integration_tool/application_integration_toolset.py +111 -85
  89. google/adk/tools/application_integration_tool/clients/connections_client.py +28 -1
  90. google/adk/tools/application_integration_tool/clients/integration_client.py +7 -5
  91. google/adk/tools/application_integration_tool/integration_connector_tool.py +69 -26
  92. google/adk/tools/base_toolset.py +96 -0
  93. google/adk/tools/bigquery/__init__.py +28 -0
  94. google/adk/tools/bigquery/bigquery_credentials.py +216 -0
  95. google/adk/tools/bigquery/bigquery_tool.py +116 -0
  96. google/adk/tools/{built_in_code_execution_tool.py → enterprise_search_tool.py} +17 -11
  97. google/adk/tools/function_parameter_parse_util.py +9 -2
  98. google/adk/tools/function_tool.py +33 -3
  99. google/adk/tools/get_user_choice_tool.py +1 -0
  100. google/adk/tools/google_api_tool/__init__.py +24 -70
  101. google/adk/tools/google_api_tool/google_api_tool.py +12 -6
  102. google/adk/tools/google_api_tool/{google_api_tool_set.py → google_api_toolset.py} +57 -55
  103. google/adk/tools/google_api_tool/google_api_toolsets.py +108 -0
  104. google/adk/tools/google_api_tool/googleapi_to_openapi_converter.py +40 -42
  105. google/adk/tools/google_search_tool.py +2 -2
  106. google/adk/tools/langchain_tool.py +96 -49
  107. google/adk/tools/load_memory_tool.py +14 -5
  108. google/adk/tools/mcp_tool/__init__.py +3 -2
  109. google/adk/tools/mcp_tool/conversion_utils.py +6 -2
  110. google/adk/tools/mcp_tool/mcp_session_manager.py +80 -69
  111. google/adk/tools/mcp_tool/mcp_tool.py +35 -32
  112. google/adk/tools/mcp_tool/mcp_toolset.py +99 -194
  113. google/adk/tools/openapi_tool/auth/credential_exchangers/base_credential_exchanger.py +1 -3
  114. google/adk/tools/openapi_tool/auth/credential_exchangers/service_account_exchanger.py +6 -7
  115. google/adk/tools/openapi_tool/common/common.py +5 -1
  116. google/adk/tools/openapi_tool/openapi_spec_parser/__init__.py +7 -2
  117. google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py +27 -7
  118. google/adk/tools/openapi_tool/openapi_spec_parser/operation_parser.py +36 -32
  119. google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py +11 -1
  120. google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py +1 -1
  121. google/adk/tools/preload_memory_tool.py +27 -18
  122. google/adk/tools/retrieval/__init__.py +1 -1
  123. google/adk/tools/retrieval/vertex_ai_rag_retrieval.py +1 -1
  124. google/adk/tools/toolbox_toolset.py +107 -0
  125. google/adk/tools/transfer_to_agent_tool.py +0 -1
  126. google/adk/utils/__init__.py +13 -0
  127. google/adk/utils/instructions_utils.py +131 -0
  128. google/adk/version.py +1 -1
  129. {google_adk-0.5.0.dist-info → google_adk-1.1.0.dist-info}/METADATA +18 -19
  130. google_adk-1.1.0.dist-info/RECORD +200 -0
  131. google/adk/agents/remote_agent.py +0 -50
  132. google/adk/cli/browser/polyfills-FFHMD2TL.js +0 -18
  133. google/adk/cli/fast_api.py.orig +0 -728
  134. google/adk/tools/google_api_tool/google_api_tool_sets.py +0 -112
  135. google/adk/tools/toolbox_tool.py +0 -46
  136. google_adk-0.5.0.dist-info/RECORD +0 -180
  137. {google_adk-0.5.0.dist-info → google_adk-1.1.0.dist-info}/WHEEL +0 -0
  138. {google_adk-0.5.0.dist-info → google_adk-1.1.0.dist-info}/entry_points.txt +0 -0
  139. {google_adk-0.5.0.dist-info → google_adk-1.1.0.dist-info}/licenses/LICENSE +0 -0
@@ -12,16 +12,15 @@
12
12
  # See the License for the specific language governing permissions and
13
13
  # limitations under the License.
14
14
 
15
+
16
+ from __future__ import annotations
17
+
15
18
  import asyncio
16
19
  from contextlib import asynccontextmanager
17
- import importlib
18
- import inspect
19
- import json
20
20
  import logging
21
21
  import os
22
22
  from pathlib import Path
23
- import re
24
- import sys
23
+ import time
25
24
  import traceback
26
25
  import typing
27
26
  from typing import Any
@@ -30,7 +29,6 @@ from typing import Literal
30
29
  from typing import Optional
31
30
 
32
31
  import click
33
- from click import Tuple
34
32
  from fastapi import FastAPI
35
33
  from fastapi import HTTPException
36
34
  from fastapi import Query
@@ -48,16 +46,25 @@ from opentelemetry.exporter.cloud_trace import CloudTraceSpanExporter
48
46
  from opentelemetry.sdk.trace import export
49
47
  from opentelemetry.sdk.trace import ReadableSpan
50
48
  from opentelemetry.sdk.trace import TracerProvider
51
- from pydantic import BaseModel
49
+ from pydantic import Field
52
50
  from pydantic import ValidationError
53
51
  from starlette.types import Lifespan
52
+ from typing_extensions import override
54
53
 
55
54
  from ..agents import RunConfig
56
55
  from ..agents.live_request_queue import LiveRequest
57
56
  from ..agents.live_request_queue import LiveRequestQueue
58
57
  from ..agents.llm_agent import Agent
59
58
  from ..agents.run_config import StreamingMode
60
- from ..artifacts import InMemoryArtifactService
59
+ from ..artifacts.in_memory_artifact_service import InMemoryArtifactService
60
+ from ..evaluation.eval_case import EvalCase
61
+ from ..evaluation.eval_case import SessionInput
62
+ from ..evaluation.eval_metrics import EvalMetric
63
+ from ..evaluation.eval_metrics import EvalMetricResult
64
+ from ..evaluation.eval_metrics import EvalMetricResultPerInvocation
65
+ from ..evaluation.eval_result import EvalSetResult
66
+ from ..evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager
67
+ from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager
61
68
  from ..events.event import Event
62
69
  from ..memory.in_memory_memory_service import InMemoryMemoryService
63
70
  from ..runners import Runner
@@ -66,14 +73,15 @@ from ..sessions.in_memory_session_service import InMemorySessionService
66
73
  from ..sessions.session import Session
67
74
  from ..sessions.vertex_ai_session_service import VertexAiSessionService
68
75
  from .cli_eval import EVAL_SESSION_ID_PREFIX
69
- from .cli_eval import EvalMetric
70
- from .cli_eval import EvalMetricResult
71
76
  from .cli_eval import EvalStatus
77
+ from .utils import cleanup
78
+ from .utils import common
72
79
  from .utils import create_empty_state
73
80
  from .utils import envs
74
81
  from .utils import evals
82
+ from .utils.agent_loader import AgentLoader
75
83
 
76
- logger = logging.getLogger(__name__)
84
+ logger = logging.getLogger("google_adk." + __name__)
77
85
 
78
86
  _EVAL_SET_FILE_EXTENSION = ".evalset.json"
79
87
 
@@ -103,7 +111,45 @@ class ApiServerSpanExporter(export.SpanExporter):
103
111
  return True
104
112
 
105
113
 
106
- class AgentRunRequest(BaseModel):
114
+ class InMemoryExporter(export.SpanExporter):
115
+
116
+ def __init__(self, trace_dict):
117
+ super().__init__()
118
+ self._spans = []
119
+ self.trace_dict = trace_dict
120
+
121
+ @override
122
+ def export(
123
+ self, spans: typing.Sequence[ReadableSpan]
124
+ ) -> export.SpanExportResult:
125
+ for span in spans:
126
+ trace_id = span.context.trace_id
127
+ if span.name == "call_llm":
128
+ attributes = dict(span.attributes)
129
+ session_id = attributes.get("gcp.vertex.agent.session_id", None)
130
+ if session_id:
131
+ if session_id not in self.trace_dict:
132
+ self.trace_dict[session_id] = [trace_id]
133
+ else:
134
+ self.trace_dict[session_id] += [trace_id]
135
+ self._spans.extend(spans)
136
+ return export.SpanExportResult.SUCCESS
137
+
138
+ @override
139
+ def force_flush(self, timeout_millis: int = 30000) -> bool:
140
+ return True
141
+
142
+ def get_finished_spans(self, session_id: str):
143
+ trace_ids = self.trace_dict.get(session_id, None)
144
+ if trace_ids is None or not trace_ids:
145
+ return []
146
+ return [x for x in self._spans if x.context.trace_id in trace_ids]
147
+
148
+ def clear(self):
149
+ self._spans.clear()
150
+
151
+
152
+ class AgentRunRequest(common.BaseModel):
107
153
  app_name: str
108
154
  user_id: str
109
155
  session_id: str
@@ -111,28 +157,41 @@ class AgentRunRequest(BaseModel):
111
157
  streaming: bool = False
112
158
 
113
159
 
114
- class AddSessionToEvalSetRequest(BaseModel):
160
+ class AddSessionToEvalSetRequest(common.BaseModel):
115
161
  eval_id: str
116
162
  session_id: str
117
163
  user_id: str
118
164
 
119
165
 
120
- class RunEvalRequest(BaseModel):
166
+ class RunEvalRequest(common.BaseModel):
121
167
  eval_ids: list[str] # if empty, then all evals in the eval set are run.
122
168
  eval_metrics: list[EvalMetric]
123
169
 
124
170
 
125
- class RunEvalResult(BaseModel):
171
+ class RunEvalResult(common.BaseModel):
172
+ eval_set_file: str
126
173
  eval_set_id: str
127
174
  eval_id: str
128
175
  final_eval_status: EvalStatus
129
- eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]]
176
+ eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]] = Field(
177
+ deprecated=True,
178
+ description=(
179
+ "This field is deprecated, use overall_eval_metric_results instead."
180
+ ),
181
+ )
182
+ overall_eval_metric_results: list[EvalMetricResult]
183
+ eval_metric_result_per_invocation: list[EvalMetricResultPerInvocation]
184
+ user_id: str
130
185
  session_id: str
131
186
 
132
187
 
188
+ class GetEventGraphResult(common.BaseModel):
189
+ dot_src: str
190
+
191
+
133
192
  def get_fast_api_app(
134
193
  *,
135
- agent_dir: str,
194
+ agents_dir: str,
136
195
  session_db_url: str = "",
137
196
  allow_origins: Optional[list[str]] = None,
138
197
  web: bool,
@@ -141,40 +200,42 @@ def get_fast_api_app(
141
200
  ) -> FastAPI:
142
201
  # InMemory tracing dict.
143
202
  trace_dict: dict[str, Any] = {}
203
+ session_trace_dict: dict[str, Any] = {}
144
204
 
145
205
  # Set up tracing in the FastAPI server.
146
206
  provider = TracerProvider()
147
207
  provider.add_span_processor(
148
208
  export.SimpleSpanProcessor(ApiServerSpanExporter(trace_dict))
149
209
  )
210
+ memory_exporter = InMemoryExporter(session_trace_dict)
211
+ provider.add_span_processor(export.SimpleSpanProcessor(memory_exporter))
150
212
  if trace_to_cloud:
151
- envs.load_dotenv_for_agent("", agent_dir)
213
+ envs.load_dotenv_for_agent("", agents_dir)
152
214
  if project_id := os.environ.get("GOOGLE_CLOUD_PROJECT", None):
153
215
  processor = export.BatchSpanProcessor(
154
216
  CloudTraceSpanExporter(project_id=project_id)
155
217
  )
156
218
  provider.add_span_processor(processor)
157
219
  else:
158
- logging.warning(
220
+ logger.warning(
159
221
  "GOOGLE_CLOUD_PROJECT environment variable is not set. Tracing will"
160
222
  " not be enabled."
161
223
  )
162
224
 
163
225
  trace.set_tracer_provider(provider)
164
226
 
165
- exit_stacks = []
166
-
167
227
  @asynccontextmanager
168
228
  async def internal_lifespan(app: FastAPI):
169
- if lifespan:
170
- async with lifespan(app) as lifespan_context:
171
- yield
172
229
 
173
- if exit_stacks:
174
- for stack in exit_stacks:
175
- await stack.aclose()
176
- else:
177
- yield
230
+ try:
231
+ if lifespan:
232
+ async with lifespan(app) as lifespan_context:
233
+ yield lifespan_context
234
+ else:
235
+ yield
236
+ finally:
237
+ # Create tasks for all runner closures to run concurrently
238
+ await cleanup.close_runners(list(runner_dict.values()))
178
239
 
179
240
  # Run the FastAPI server.
180
241
  app = FastAPI(lifespan=internal_lifespan)
@@ -188,16 +249,15 @@ def get_fast_api_app(
188
249
  allow_headers=["*"],
189
250
  )
190
251
 
191
- if agent_dir not in sys.path:
192
- sys.path.append(agent_dir)
193
-
194
252
  runner_dict = {}
195
- root_agent_dict = {}
196
253
 
197
254
  # Build the Artifact service
198
255
  artifact_service = InMemoryArtifactService()
199
256
  memory_service = InMemoryMemoryService()
200
257
 
258
+ eval_sets_manager = LocalEvalSetsManager(agents_dir=agents_dir)
259
+ eval_set_results_manager = LocalEvalSetResultsManager(agents_dir=agents_dir)
260
+
201
261
  # Build the Session service
202
262
  agent_engine_id = ""
203
263
  if session_db_url:
@@ -206,7 +266,7 @@ def get_fast_api_app(
206
266
  agent_engine_id = session_db_url.split("://")[1]
207
267
  if not agent_engine_id:
208
268
  raise click.ClickException("Agent engine id can not be empty.")
209
- envs.load_dotenv_for_agent("", agent_dir)
269
+ envs.load_dotenv_for_agent("", agents_dir)
210
270
  session_service = VertexAiSessionService(
211
271
  os.environ["GOOGLE_CLOUD_PROJECT"],
212
272
  os.environ["GOOGLE_CLOUD_LOCATION"],
@@ -216,9 +276,12 @@ def get_fast_api_app(
216
276
  else:
217
277
  session_service = InMemorySessionService()
218
278
 
279
+ # initialize Agent Loader
280
+ agent_loader = AgentLoader(agents_dir)
281
+
219
282
  @app.get("/list-apps")
220
283
  def list_apps() -> list[str]:
221
- base_path = Path.cwd() / agent_dir
284
+ base_path = Path.cwd() / agents_dir
222
285
  if not base_path.exists():
223
286
  raise HTTPException(status_code=404, detail="Path not found")
224
287
  if not base_path.is_dir():
@@ -240,14 +303,34 @@ def get_fast_api_app(
240
303
  raise HTTPException(status_code=404, detail="Trace not found")
241
304
  return event_dict
242
305
 
306
+ @app.get("/debug/trace/session/{session_id}")
307
+ def get_session_trace(session_id: str) -> Any:
308
+ spans = memory_exporter.get_finished_spans(session_id)
309
+ if not spans:
310
+ return []
311
+ return [
312
+ {
313
+ "name": s.name,
314
+ "span_id": s.context.span_id,
315
+ "trace_id": s.context.trace_id,
316
+ "start_time": s.start_time,
317
+ "end_time": s.end_time,
318
+ "attributes": dict(s.attributes),
319
+ "parent_span_id": s.parent.span_id if s.parent else None,
320
+ }
321
+ for s in spans
322
+ ]
323
+
243
324
  @app.get(
244
325
  "/apps/{app_name}/users/{user_id}/sessions/{session_id}",
245
326
  response_model_exclude_none=True,
246
327
  )
247
- def get_session(app_name: str, user_id: str, session_id: str) -> Session:
328
+ async def get_session(
329
+ app_name: str, user_id: str, session_id: str
330
+ ) -> Session:
248
331
  # Connect to managed session if agent_engine_id is set.
249
332
  app_name = agent_engine_id if agent_engine_id else app_name
250
- session = session_service.get_session(
333
+ session = await session_service.get_session(
251
334
  app_name=app_name, user_id=user_id, session_id=session_id
252
335
  )
253
336
  if not session:
@@ -258,14 +341,15 @@ def get_fast_api_app(
258
341
  "/apps/{app_name}/users/{user_id}/sessions",
259
342
  response_model_exclude_none=True,
260
343
  )
261
- def list_sessions(app_name: str, user_id: str) -> list[Session]:
344
+ async def list_sessions(app_name: str, user_id: str) -> list[Session]:
262
345
  # Connect to managed session if agent_engine_id is set.
263
346
  app_name = agent_engine_id if agent_engine_id else app_name
347
+ list_sessions_response = await session_service.list_sessions(
348
+ app_name=app_name, user_id=user_id
349
+ )
264
350
  return [
265
351
  session
266
- for session in session_service.list_sessions(
267
- app_name=app_name, user_id=user_id
268
- ).sessions
352
+ for session in list_sessions_response.sessions
269
353
  # Remove sessions that were generated as a part of Eval.
270
354
  if not session.id.startswith(EVAL_SESSION_ID_PREFIX)
271
355
  ]
@@ -274,7 +358,7 @@ def get_fast_api_app(
274
358
  "/apps/{app_name}/users/{user_id}/sessions/{session_id}",
275
359
  response_model_exclude_none=True,
276
360
  )
277
- def create_session_with_id(
361
+ async def create_session_with_id(
278
362
  app_name: str,
279
363
  user_id: str,
280
364
  session_id: str,
@@ -283,7 +367,7 @@ def get_fast_api_app(
283
367
  # Connect to managed session if agent_engine_id is set.
284
368
  app_name = agent_engine_id if agent_engine_id else app_name
285
369
  if (
286
- session_service.get_session(
370
+ await session_service.get_session(
287
371
  app_name=app_name, user_id=user_id, session_id=session_id
288
372
  )
289
373
  is not None
@@ -292,9 +376,8 @@ def get_fast_api_app(
292
376
  raise HTTPException(
293
377
  status_code=400, detail=f"Session already exists: {session_id}"
294
378
  )
295
-
296
379
  logger.info("New session created: %s", session_id)
297
- return session_service.create_session(
380
+ return await session_service.create_session(
298
381
  app_name=app_name, user_id=user_id, state=state, session_id=session_id
299
382
  )
300
383
 
@@ -302,22 +385,21 @@ def get_fast_api_app(
302
385
  "/apps/{app_name}/users/{user_id}/sessions",
303
386
  response_model_exclude_none=True,
304
387
  )
305
- def create_session(
388
+ async def create_session(
306
389
  app_name: str,
307
390
  user_id: str,
308
391
  state: Optional[dict[str, Any]] = None,
309
392
  ) -> Session:
310
393
  # Connect to managed session if agent_engine_id is set.
311
394
  app_name = agent_engine_id if agent_engine_id else app_name
312
-
313
395
  logger.info("New session created")
314
- return session_service.create_session(
396
+ return await session_service.create_session(
315
397
  app_name=app_name, user_id=user_id, state=state
316
398
  )
317
399
 
318
- def _get_eval_set_file_path(app_name, agent_dir, eval_set_id) -> str:
400
+ def _get_eval_set_file_path(app_name, agents_dir, eval_set_id) -> str:
319
401
  return os.path.join(
320
- agent_dir,
402
+ agents_dir,
321
403
  app_name,
322
404
  eval_set_id + _EVAL_SET_FILE_EXTENSION,
323
405
  )
@@ -331,28 +413,13 @@ def get_fast_api_app(
331
413
  eval_set_id: str,
332
414
  ):
333
415
  """Creates an eval set, given the id."""
334
- pattern = r"^[a-zA-Z0-9_]+$"
335
- if not bool(re.fullmatch(pattern, eval_set_id)):
416
+ try:
417
+ eval_sets_manager.create_eval_set(app_name, eval_set_id)
418
+ except ValueError as ve:
336
419
  raise HTTPException(
337
420
  status_code=400,
338
- detail=(
339
- f"Invalid eval set id. Eval set id should have the `{pattern}`"
340
- " format"
341
- ),
342
- )
343
- # Define the file path
344
- new_eval_set_path = _get_eval_set_file_path(
345
- app_name, agent_dir, eval_set_id
346
- )
347
-
348
- logger.info("Creating eval set file `%s`", new_eval_set_path)
349
-
350
- if not os.path.exists(new_eval_set_path):
351
- # Write the JSON string to the file
352
- logger.info("Eval set file doesn't exist, we will create a new one.")
353
- with open(new_eval_set_path, "w") as f:
354
- empty_content = json.dumps([], indent=2)
355
- f.write(empty_content)
421
+ detail=str(ve),
422
+ ) from ve
356
423
 
357
424
  @app.get(
358
425
  "/apps/{app_name}/eval_sets",
@@ -360,15 +427,7 @@ def get_fast_api_app(
360
427
  )
361
428
  def list_eval_sets(app_name: str) -> list[str]:
362
429
  """Lists all eval sets for the given app."""
363
- eval_set_file_path = os.path.join(agent_dir, app_name)
364
- eval_sets = []
365
- for file in os.listdir(eval_set_file_path):
366
- if file.endswith(_EVAL_SET_FILE_EXTENSION):
367
- eval_sets.append(
368
- os.path.basename(file).removesuffix(_EVAL_SET_FILE_EXTENSION)
369
- )
370
-
371
- return sorted(eval_sets)
430
+ return eval_sets_manager.list_eval_sets(app_name)
372
431
 
373
432
  @app.post(
374
433
  "/apps/{app_name}/eval_sets/{eval_set_id}/add_session",
@@ -377,54 +436,33 @@ def get_fast_api_app(
377
436
  async def add_session_to_eval_set(
378
437
  app_name: str, eval_set_id: str, req: AddSessionToEvalSetRequest
379
438
  ):
380
- pattern = r"^[a-zA-Z0-9_]+$"
381
- if not bool(re.fullmatch(pattern, req.eval_id)):
382
- raise HTTPException(
383
- status_code=400,
384
- detail=f"Invalid eval id. Eval id should have the `{pattern}` format",
385
- )
386
-
387
439
  # Get the session
388
- session = session_service.get_session(
440
+ session = await session_service.get_session(
389
441
  app_name=app_name, user_id=req.user_id, session_id=req.session_id
390
442
  )
391
443
  assert session, "Session not found."
392
- # Load the eval set file data
393
- eval_set_file_path = _get_eval_set_file_path(
394
- app_name, agent_dir, eval_set_id
395
- )
396
- with open(eval_set_file_path, "r") as file:
397
- eval_set_data = json.load(file) # Load JSON into a list
398
-
399
- if [x for x in eval_set_data if x["name"] == req.eval_id]:
400
- raise HTTPException(
401
- status_code=400,
402
- detail=(
403
- f"Eval id `{req.eval_id}` already exists in `{eval_set_id}`"
404
- " eval set."
405
- ),
406
- )
407
444
 
408
- # Convert the session data to evaluation format
409
- test_data = evals.convert_session_to_eval_format(session)
445
+ # Convert the session data to eval invocations
446
+ invocations = evals.convert_session_to_eval_invocations(session)
410
447
 
411
448
  # Populate the session with initial session state.
412
449
  initial_session_state = create_empty_state(
413
- await _get_root_agent_async(app_name)
450
+ agent_loader.load_agent(app_name)
451
+ )
452
+
453
+ new_eval_case = EvalCase(
454
+ eval_id=req.eval_id,
455
+ conversation=invocations,
456
+ session_input=SessionInput(
457
+ app_name=app_name, user_id=req.user_id, state=initial_session_state
458
+ ),
459
+ creation_timestamp=time.time(),
414
460
  )
415
461
 
416
- eval_set_data.append({
417
- "name": req.eval_id,
418
- "data": test_data,
419
- "initial_session": {
420
- "state": initial_session_state,
421
- "app_name": app_name,
422
- "user_id": req.user_id,
423
- },
424
- })
425
- # Serialize the test data to JSON and write to the eval set file.
426
- with open(eval_set_file_path, "w") as f:
427
- f.write(json.dumps(eval_set_data, indent=2))
462
+ try:
463
+ eval_sets_manager.add_eval_case(app_name, eval_set_id, new_eval_case)
464
+ except ValueError as ve:
465
+ raise HTTPException(status_code=400, detail=str(ve)) from ve
428
466
 
429
467
  @app.get(
430
468
  "/apps/{app_name}/eval_sets/{eval_set_id}/evals",
@@ -435,14 +473,9 @@ def get_fast_api_app(
435
473
  eval_set_id: str,
436
474
  ) -> list[str]:
437
475
  """Lists all evals in an eval set."""
438
- # Load the eval set file data
439
- eval_set_file_path = _get_eval_set_file_path(
440
- app_name, agent_dir, eval_set_id
441
- )
442
- with open(eval_set_file_path, "r") as file:
443
- eval_set_data = json.load(file) # Load JSON into a list
476
+ eval_set_data = eval_sets_manager.get_eval_set(app_name, eval_set_id)
444
477
 
445
- return sorted([x["name"] for x in eval_set_data])
478
+ return sorted([x.eval_id for x in eval_set_data.eval_cases])
446
479
 
447
480
  @app.post(
448
481
  "/apps/{app_name}/eval_sets/{eval_set_id}/run_eval",
@@ -451,51 +484,89 @@ def get_fast_api_app(
451
484
  async def run_eval(
452
485
  app_name: str, eval_set_id: str, req: RunEvalRequest
453
486
  ) -> list[RunEvalResult]:
487
+ """Runs an eval given the details in the eval request."""
454
488
  from .cli_eval import run_evals
455
489
 
456
- """Runs an eval given the details in the eval request."""
457
490
  # Create a mapping from eval set file to all the evals that needed to be
458
491
  # run.
459
- eval_set_file_path = _get_eval_set_file_path(
460
- app_name, agent_dir, eval_set_id
461
- )
462
- eval_set_to_evals = {eval_set_file_path: req.eval_ids}
492
+ eval_set = eval_sets_manager.get_eval_set(app_name, eval_set_id)
463
493
 
464
- if not req.eval_ids:
465
- logger.info(
466
- "Eval ids to run list is empty. We will all evals in the eval set."
467
- )
468
- root_agent = await _get_root_agent_async(app_name)
469
- eval_results = list(
470
- await run_evals(
471
- eval_set_to_evals,
472
- root_agent,
473
- getattr(root_agent, "reset_data", None),
474
- req.eval_metrics,
475
- session_service=session_service,
476
- artifact_service=artifact_service,
477
- )
478
- )
494
+ if req.eval_ids:
495
+ eval_cases = [e for e in eval_set.eval_cases if e.eval_id in req.eval_ids]
496
+ eval_set_to_evals = {eval_set_id: eval_cases}
497
+ else:
498
+ logger.info("Eval ids to run list is empty. We will run all eval cases.")
499
+ eval_set_to_evals = {eval_set_id: eval_set.eval_cases}
479
500
 
501
+ root_agent = agent_loader.load_agent(app_name)
480
502
  run_eval_results = []
481
- for eval_result in eval_results:
503
+ eval_case_results = []
504
+ async for eval_case_result in run_evals(
505
+ eval_set_to_evals,
506
+ root_agent,
507
+ getattr(root_agent, "reset_data", None),
508
+ req.eval_metrics,
509
+ session_service=session_service,
510
+ artifact_service=artifact_service,
511
+ ):
482
512
  run_eval_results.append(
483
513
  RunEvalResult(
484
514
  app_name=app_name,
515
+ eval_set_file=eval_case_result.eval_set_file,
485
516
  eval_set_id=eval_set_id,
486
- eval_id=eval_result.eval_id,
487
- final_eval_status=eval_result.final_eval_status,
488
- eval_metric_results=eval_result.eval_metric_results,
489
- session_id=eval_result.session_id,
517
+ eval_id=eval_case_result.eval_id,
518
+ final_eval_status=eval_case_result.final_eval_status,
519
+ eval_metric_results=eval_case_result.eval_metric_results,
520
+ overall_eval_metric_results=eval_case_result.overall_eval_metric_results,
521
+ eval_metric_result_per_invocation=eval_case_result.eval_metric_result_per_invocation,
522
+ user_id=eval_case_result.user_id,
523
+ session_id=eval_case_result.session_id,
490
524
  )
491
525
  )
526
+ eval_case_result.session_details = await session_service.get_session(
527
+ app_name=app_name,
528
+ user_id=eval_case_result.user_id,
529
+ session_id=eval_case_result.session_id,
530
+ )
531
+ eval_case_results.append(eval_case_result)
532
+
533
+ eval_set_results_manager.save_eval_set_result(
534
+ app_name, eval_set_id, eval_case_results
535
+ )
536
+
492
537
  return run_eval_results
493
538
 
539
+ @app.get(
540
+ "/apps/{app_name}/eval_results/{eval_result_id}",
541
+ response_model_exclude_none=True,
542
+ )
543
+ def get_eval_result(
544
+ app_name: str,
545
+ eval_result_id: str,
546
+ ) -> EvalSetResult:
547
+ """Gets the eval result for the given eval id."""
548
+ try:
549
+ return eval_set_results_manager.get_eval_set_result(
550
+ app_name, eval_result_id
551
+ )
552
+ except ValueError as ve:
553
+ raise HTTPException(status_code=404, detail=str(ve)) from ve
554
+ except ValidationError as ve:
555
+ raise HTTPException(status_code=500, detail=str(ve)) from ve
556
+
557
+ @app.get(
558
+ "/apps/{app_name}/eval_results",
559
+ response_model_exclude_none=True,
560
+ )
561
+ def list_eval_results(app_name: str) -> list[str]:
562
+ """Lists all eval results for the given app."""
563
+ return eval_set_results_manager.list_eval_set_results(app_name)
564
+
494
565
  @app.delete("/apps/{app_name}/users/{user_id}/sessions/{session_id}")
495
- def delete_session(app_name: str, user_id: str, session_id: str):
566
+ async def delete_session(app_name: str, user_id: str, session_id: str):
496
567
  # Connect to managed session if agent_engine_id is set.
497
568
  app_name = agent_engine_id if agent_engine_id else app_name
498
- session_service.delete_session(
569
+ await session_service.delete_session(
499
570
  app_name=app_name, user_id=user_id, session_id=session_id
500
571
  )
501
572
 
@@ -589,9 +660,9 @@ def get_fast_api_app(
589
660
  @app.post("/run", response_model_exclude_none=True)
590
661
  async def agent_run(req: AgentRunRequest) -> list[Event]:
591
662
  # Connect to managed session if agent_engine_id is set.
592
- app_id = agent_engine_id if agent_engine_id else req.app_name
593
- session = session_service.get_session(
594
- app_name=app_id, user_id=req.user_id, session_id=req.session_id
663
+ app_name = agent_engine_id if agent_engine_id else req.app_name
664
+ session = await session_service.get_session(
665
+ app_name=app_name, user_id=req.user_id, session_id=req.session_id
595
666
  )
596
667
  if not session:
597
668
  raise HTTPException(status_code=404, detail="Session not found")
@@ -610,10 +681,10 @@ def get_fast_api_app(
610
681
  @app.post("/run_sse")
611
682
  async def agent_run_sse(req: AgentRunRequest) -> StreamingResponse:
612
683
  # Connect to managed session if agent_engine_id is set.
613
- app_id = agent_engine_id if agent_engine_id else req.app_name
684
+ app_name = agent_engine_id if agent_engine_id else req.app_name
614
685
  # SSE endpoint
615
- session = session_service.get_session(
616
- app_name=app_id, user_id=req.user_id, session_id=req.session_id
686
+ session = await session_service.get_session(
687
+ app_name=app_name, user_id=req.user_id, session_id=req.session_id
617
688
  )
618
689
  if not session:
619
690
  raise HTTPException(status_code=404, detail="Session not found")
@@ -652,9 +723,9 @@ def get_fast_api_app(
652
723
  app_name: str, user_id: str, session_id: str, event_id: str
653
724
  ):
654
725
  # Connect to managed session if agent_engine_id is set.
655
- app_id = agent_engine_id if agent_engine_id else app_name
656
- session = session_service.get_session(
657
- app_name=app_id, user_id=user_id, session_id=session_id
726
+ app_name = agent_engine_id if agent_engine_id else app_name
727
+ session = await session_service.get_session(
728
+ app_name=app_name, user_id=user_id, session_id=session_id
658
729
  )
659
730
  session_events = session.events if session else []
660
731
  event = next((x for x in session_events if x.id == event_id), None)
@@ -665,7 +736,7 @@ def get_fast_api_app(
665
736
 
666
737
  function_calls = event.get_function_calls()
667
738
  function_responses = event.get_function_responses()
668
- root_agent = await _get_root_agent_async(app_name)
739
+ root_agent = agent_loader.load_agent(app_name)
669
740
  dot_graph = None
670
741
  if function_calls:
671
742
  function_call_highlights = []
@@ -673,7 +744,7 @@ def get_fast_api_app(
673
744
  from_name = event.author
674
745
  to_name = function_call.name
675
746
  function_call_highlights.append((from_name, to_name))
676
- dot_graph = agent_graph.get_agent_graph(
747
+ dot_graph = await agent_graph.get_agent_graph(
677
748
  root_agent, function_call_highlights
678
749
  )
679
750
  elif function_responses:
@@ -682,17 +753,17 @@ def get_fast_api_app(
682
753
  from_name = function_response.name
683
754
  to_name = event.author
684
755
  function_responses_highlights.append((from_name, to_name))
685
- dot_graph = agent_graph.get_agent_graph(
756
+ dot_graph = await agent_graph.get_agent_graph(
686
757
  root_agent, function_responses_highlights
687
758
  )
688
759
  else:
689
760
  from_name = event.author
690
761
  to_name = ""
691
- dot_graph = agent_graph.get_agent_graph(
762
+ dot_graph = await agent_graph.get_agent_graph(
692
763
  root_agent, [(from_name, to_name)]
693
764
  )
694
765
  if dot_graph and isinstance(dot_graph, graphviz.Digraph):
695
- return {"dot_src": dot_graph.source}
766
+ return GetEventGraphResult(dot_src=dot_graph.source)
696
767
  else:
697
768
  return {}
698
769
 
@@ -709,9 +780,9 @@ def get_fast_api_app(
709
780
  await websocket.accept()
710
781
 
711
782
  # Connect to managed session if agent_engine_id is set.
712
- app_id = agent_engine_id if agent_engine_id else app_name
713
- session = session_service.get_session(
714
- app_name=app_id, user_id=user_id, session_id=session_id
783
+ app_name = agent_engine_id if agent_engine_id else app_name
784
+ session = await session_service.get_session(
785
+ app_name=app_name, user_id=user_id, session_id=session_id
715
786
  )
716
787
  if not session:
717
788
  # Accept first so that the client is aware of connection establishment,
@@ -766,34 +837,12 @@ def get_fast_api_app(
766
837
  for task in pending:
767
838
  task.cancel()
768
839
 
769
- async def _get_root_agent_async(app_name: str) -> Agent:
770
- """Returns the root agent for the given app."""
771
- if app_name in root_agent_dict:
772
- return root_agent_dict[app_name]
773
- agent_module = importlib.import_module(app_name)
774
- if getattr(agent_module.agent, "root_agent"):
775
- root_agent = agent_module.agent.root_agent
776
- else:
777
- raise ValueError(f'Unable to find "root_agent" from {app_name}.')
778
-
779
- # Handle an awaitable root agent and await for the actual agent.
780
- if inspect.isawaitable(root_agent):
781
- try:
782
- agent, exit_stack = await root_agent
783
- exit_stacks.append(exit_stack)
784
- root_agent = agent
785
- except Exception as e:
786
- raise RuntimeError(f"error getting root agent, {e}") from e
787
-
788
- root_agent_dict[app_name] = root_agent
789
- return root_agent
790
-
791
840
  async def _get_runner_async(app_name: str) -> Runner:
792
841
  """Returns the runner for the given app."""
793
- envs.load_dotenv_for_agent(os.path.basename(app_name), agent_dir)
842
+ envs.load_dotenv_for_agent(os.path.basename(app_name), agents_dir)
794
843
  if app_name in runner_dict:
795
844
  return runner_dict[app_name]
796
- root_agent = await _get_root_agent_async(app_name)
845
+ root_agent = agent_loader.load_agent(app_name)
797
846
  runner = Runner(
798
847
  app_name=agent_engine_id if agent_engine_id else app_name,
799
848
  agent=root_agent,
@@ -809,14 +858,16 @@ def get_fast_api_app(
809
858
  ANGULAR_DIST_PATH = BASE_DIR / "browser"
810
859
 
811
860
  @app.get("/")
812
- async def redirect_to_dev_ui():
813
- return RedirectResponse("/dev-ui")
861
+ async def redirect_root_to_dev_ui():
862
+ return RedirectResponse("/dev-ui/")
814
863
 
815
864
  @app.get("/dev-ui")
816
- async def dev_ui():
817
- return FileResponse(BASE_DIR / "browser/index.html")
865
+ async def redirect_dev_ui_add_slash():
866
+ return RedirectResponse("/dev-ui/")
818
867
 
819
868
  app.mount(
820
- "/", StaticFiles(directory=ANGULAR_DIST_PATH, html=True), name="static"
869
+ "/dev-ui/",
870
+ StaticFiles(directory=ANGULAR_DIST_PATH, html=True),
871
+ name="static",
821
872
  )
822
873
  return app