google-adk 0.5.0__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/adk/agents/base_agent.py +76 -30
- google/adk/agents/base_agent.py.orig +330 -0
- google/adk/agents/callback_context.py +0 -5
- google/adk/agents/llm_agent.py +122 -30
- google/adk/agents/loop_agent.py +1 -1
- google/adk/agents/parallel_agent.py +7 -0
- google/adk/agents/readonly_context.py +7 -1
- google/adk/agents/run_config.py +1 -1
- google/adk/agents/sequential_agent.py +31 -0
- google/adk/agents/transcription_entry.py +4 -2
- google/adk/artifacts/gcs_artifact_service.py +1 -1
- google/adk/artifacts/in_memory_artifact_service.py +1 -1
- google/adk/auth/auth_credential.py +6 -1
- google/adk/auth/auth_preprocessor.py +7 -1
- google/adk/auth/auth_tool.py +3 -4
- google/adk/cli/agent_graph.py +5 -5
- google/adk/cli/browser/index.html +2 -2
- google/adk/cli/browser/{main-ULN5R5I5.js → main-QOEMUXM4.js} +44 -45
- google/adk/cli/cli.py +7 -7
- google/adk/cli/cli_deploy.py +7 -2
- google/adk/cli/cli_eval.py +172 -99
- google/adk/cli/cli_tools_click.py +147 -64
- google/adk/cli/fast_api.py +330 -148
- google/adk/cli/fast_api.py.orig +174 -80
- google/adk/cli/utils/common.py +23 -0
- google/adk/cli/utils/evals.py +83 -1
- google/adk/cli/utils/logs.py +13 -5
- google/adk/code_executors/__init__.py +3 -1
- google/adk/code_executors/built_in_code_executor.py +52 -0
- google/adk/evaluation/__init__.py +1 -1
- google/adk/evaluation/agent_evaluator.py +168 -128
- google/adk/evaluation/eval_case.py +102 -0
- google/adk/evaluation/eval_set.py +37 -0
- google/adk/evaluation/eval_sets_manager.py +42 -0
- google/adk/evaluation/evaluation_generator.py +88 -113
- google/adk/evaluation/evaluator.py +56 -0
- google/adk/evaluation/local_eval_sets_manager.py +264 -0
- google/adk/evaluation/response_evaluator.py +106 -2
- google/adk/evaluation/trajectory_evaluator.py +83 -2
- google/adk/events/event.py +6 -1
- google/adk/events/event_actions.py +6 -1
- google/adk/examples/example_util.py +3 -2
- google/adk/flows/llm_flows/_code_execution.py +9 -1
- google/adk/flows/llm_flows/audio_transcriber.py +4 -3
- google/adk/flows/llm_flows/base_llm_flow.py +54 -15
- google/adk/flows/llm_flows/functions.py +9 -8
- google/adk/flows/llm_flows/instructions.py +13 -5
- google/adk/flows/llm_flows/single_flow.py +1 -1
- google/adk/memory/__init__.py +1 -1
- google/adk/memory/_utils.py +23 -0
- google/adk/memory/base_memory_service.py +23 -21
- google/adk/memory/base_memory_service.py.orig +76 -0
- google/adk/memory/in_memory_memory_service.py +57 -25
- google/adk/memory/memory_entry.py +37 -0
- google/adk/memory/vertex_ai_rag_memory_service.py +38 -15
- google/adk/models/anthropic_llm.py +16 -9
- google/adk/models/gemini_llm_connection.py +11 -11
- google/adk/models/google_llm.py +9 -2
- google/adk/models/google_llm.py.orig +305 -0
- google/adk/models/lite_llm.py +77 -21
- google/adk/models/llm_response.py +14 -2
- google/adk/models/registry.py +1 -1
- google/adk/runners.py +65 -41
- google/adk/sessions/__init__.py +1 -1
- google/adk/sessions/base_session_service.py +6 -33
- google/adk/sessions/database_session_service.py +58 -65
- google/adk/sessions/in_memory_session_service.py +106 -24
- google/adk/sessions/session.py +3 -0
- google/adk/sessions/vertex_ai_session_service.py +23 -45
- google/adk/telemetry.py +3 -0
- google/adk/tools/__init__.py +4 -7
- google/adk/tools/{built_in_code_execution_tool.py → _built_in_code_execution_tool.py} +11 -0
- google/adk/tools/_memory_entry_utils.py +30 -0
- google/adk/tools/agent_tool.py +9 -9
- google/adk/tools/apihub_tool/apihub_toolset.py +55 -74
- google/adk/tools/application_integration_tool/application_integration_toolset.py +107 -85
- google/adk/tools/application_integration_tool/clients/connections_client.py +20 -0
- google/adk/tools/application_integration_tool/clients/integration_client.py +6 -6
- google/adk/tools/application_integration_tool/integration_connector_tool.py +69 -26
- google/adk/tools/base_toolset.py +58 -0
- google/adk/tools/enterprise_search_tool.py +65 -0
- google/adk/tools/function_parameter_parse_util.py +2 -2
- google/adk/tools/google_api_tool/__init__.py +18 -70
- google/adk/tools/google_api_tool/google_api_tool.py +11 -5
- google/adk/tools/google_api_tool/google_api_toolset.py +126 -0
- google/adk/tools/google_api_tool/google_api_toolsets.py +102 -0
- google/adk/tools/google_api_tool/googleapi_to_openapi_converter.py +40 -42
- google/adk/tools/langchain_tool.py +96 -49
- google/adk/tools/load_memory_tool.py +14 -5
- google/adk/tools/mcp_tool/__init__.py +3 -2
- google/adk/tools/mcp_tool/mcp_session_manager.py +153 -16
- google/adk/tools/mcp_tool/mcp_session_manager.py.orig +322 -0
- google/adk/tools/mcp_tool/mcp_tool.py +12 -12
- google/adk/tools/mcp_tool/mcp_toolset.py +155 -195
- google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py +32 -7
- google/adk/tools/openapi_tool/openapi_spec_parser/operation_parser.py +31 -31
- google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py +1 -1
- google/adk/tools/preload_memory_tool.py +27 -18
- google/adk/tools/retrieval/__init__.py +1 -1
- google/adk/tools/retrieval/vertex_ai_rag_retrieval.py +1 -1
- google/adk/tools/toolbox_toolset.py +79 -0
- google/adk/tools/transfer_to_agent_tool.py +0 -1
- google/adk/version.py +1 -1
- {google_adk-0.5.0.dist-info → google_adk-1.0.0.dist-info}/METADATA +7 -5
- google_adk-1.0.0.dist-info/RECORD +195 -0
- google/adk/agents/remote_agent.py +0 -50
- google/adk/tools/google_api_tool/google_api_tool_set.py +0 -110
- google/adk/tools/google_api_tool/google_api_tool_sets.py +0 -112
- google/adk/tools/toolbox_tool.py +0 -46
- google_adk-0.5.0.dist-info/RECORD +0 -180
- {google_adk-0.5.0.dist-info → google_adk-1.0.0.dist-info}/WHEEL +0 -0
- {google_adk-0.5.0.dist-info → google_adk-1.0.0.dist-info}/entry_points.txt +0 -0
- {google_adk-0.5.0.dist-info → google_adk-1.0.0.dist-info}/licenses/LICENSE +0 -0
google/adk/cli/fast_api.py
CHANGED
@@ -12,6 +12,7 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
+
|
15
16
|
import asyncio
|
16
17
|
from contextlib import asynccontextmanager
|
17
18
|
import importlib
|
@@ -20,8 +21,9 @@ import json
|
|
20
21
|
import logging
|
21
22
|
import os
|
22
23
|
from pathlib import Path
|
23
|
-
import re
|
24
|
+
import signal
|
24
25
|
import sys
|
26
|
+
import time
|
25
27
|
import traceback
|
26
28
|
import typing
|
27
29
|
from typing import Any
|
@@ -30,7 +32,6 @@ from typing import Literal
|
|
30
32
|
from typing import Optional
|
31
33
|
|
32
34
|
import click
|
33
|
-
from click import Tuple
|
34
35
|
from fastapi import FastAPI
|
35
36
|
from fastapi import HTTPException
|
36
37
|
from fastapi import Query
|
@@ -48,16 +49,22 @@ from opentelemetry.exporter.cloud_trace import CloudTraceSpanExporter
|
|
48
49
|
from opentelemetry.sdk.trace import export
|
49
50
|
from opentelemetry.sdk.trace import ReadableSpan
|
50
51
|
from opentelemetry.sdk.trace import TracerProvider
|
51
|
-
from pydantic import BaseModel
|
52
|
+
from pydantic import Field
|
52
53
|
from pydantic import ValidationError
|
53
54
|
from starlette.types import Lifespan
|
55
|
+
from typing_extensions import override
|
54
56
|
|
55
57
|
from ..agents import RunConfig
|
58
|
+
from ..agents.base_agent import BaseAgent
|
56
59
|
from ..agents.live_request_queue import LiveRequest
|
57
60
|
from ..agents.live_request_queue import LiveRequestQueue
|
58
61
|
from ..agents.llm_agent import Agent
|
62
|
+
from ..agents.llm_agent import LlmAgent
|
59
63
|
from ..agents.run_config import StreamingMode
|
60
64
|
from ..artifacts import InMemoryArtifactService
|
65
|
+
from ..evaluation.eval_case import EvalCase
|
66
|
+
from ..evaluation.eval_case import SessionInput
|
67
|
+
from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager
|
61
68
|
from ..events.event import Event
|
62
69
|
from ..memory.in_memory_memory_service import InMemoryMemoryService
|
63
70
|
from ..runners import Runner
|
@@ -65,17 +72,23 @@ from ..sessions.database_session_service import DatabaseSessionService
|
|
65
72
|
from ..sessions.in_memory_session_service import InMemorySessionService
|
66
73
|
from ..sessions.session import Session
|
67
74
|
from ..sessions.vertex_ai_session_service import VertexAiSessionService
|
75
|
+
from ..tools.base_toolset import BaseToolset
|
68
76
|
from .cli_eval import EVAL_SESSION_ID_PREFIX
|
77
|
+
from .cli_eval import EvalCaseResult
|
69
78
|
from .cli_eval import EvalMetric
|
70
79
|
from .cli_eval import EvalMetricResult
|
80
|
+
from .cli_eval import EvalMetricResultPerInvocation
|
81
|
+
from .cli_eval import EvalSetResult
|
71
82
|
from .cli_eval import EvalStatus
|
83
|
+
from .utils import common
|
72
84
|
from .utils import create_empty_state
|
73
85
|
from .utils import envs
|
74
86
|
from .utils import evals
|
75
87
|
|
76
|
-
logger = logging.getLogger(__name__)
|
88
|
+
logger = logging.getLogger("google_adk." + __name__)
|
77
89
|
|
78
90
|
_EVAL_SET_FILE_EXTENSION = ".evalset.json"
|
91
|
+
_EVAL_SET_RESULT_FILE_EXTENSION = ".evalset_result.json"
|
79
92
|
|
80
93
|
|
81
94
|
class ApiServerSpanExporter(export.SpanExporter):
|
@@ -103,7 +116,45 @@ class ApiServerSpanExporter(export.SpanExporter):
|
|
103
116
|
return True
|
104
117
|
|
105
118
|
|
106
|
-
class AgentRunRequest(BaseModel):
|
119
|
+
class InMemoryExporter(export.SpanExporter):
|
120
|
+
|
121
|
+
def __init__(self, trace_dict):
|
122
|
+
super().__init__()
|
123
|
+
self._spans = []
|
124
|
+
self.trace_dict = trace_dict
|
125
|
+
|
126
|
+
@override
|
127
|
+
def export(
|
128
|
+
self, spans: typing.Sequence[ReadableSpan]
|
129
|
+
) -> export.SpanExportResult:
|
130
|
+
for span in spans:
|
131
|
+
trace_id = span.context.trace_id
|
132
|
+
if span.name == "call_llm":
|
133
|
+
attributes = dict(span.attributes)
|
134
|
+
session_id = attributes.get("gcp.vertex.agent.session_id", None)
|
135
|
+
if session_id:
|
136
|
+
if session_id not in self.trace_dict:
|
137
|
+
self.trace_dict[session_id] = [trace_id]
|
138
|
+
else:
|
139
|
+
self.trace_dict[session_id] += [trace_id]
|
140
|
+
self._spans.extend(spans)
|
141
|
+
return export.SpanExportResult.SUCCESS
|
142
|
+
|
143
|
+
@override
|
144
|
+
def force_flush(self, timeout_millis: int = 30000) -> bool:
|
145
|
+
return True
|
146
|
+
|
147
|
+
def get_finished_spans(self, session_id: str):
|
148
|
+
trace_ids = self.trace_dict.get(session_id, None)
|
149
|
+
if trace_ids is None or not trace_ids:
|
150
|
+
return []
|
151
|
+
return [x for x in self._spans if x.context.trace_id in trace_ids]
|
152
|
+
|
153
|
+
def clear(self):
|
154
|
+
self._spans.clear()
|
155
|
+
|
156
|
+
|
157
|
+
class AgentRunRequest(common.BaseModel):
|
107
158
|
app_name: str
|
108
159
|
user_id: str
|
109
160
|
session_id: str
|
@@ -111,25 +162,38 @@ class AgentRunRequest(BaseModel):
|
|
111
162
|
streaming: bool = False
|
112
163
|
|
113
164
|
|
114
|
-
class AddSessionToEvalSetRequest(BaseModel):
|
165
|
+
class AddSessionToEvalSetRequest(common.BaseModel):
|
115
166
|
eval_id: str
|
116
167
|
session_id: str
|
117
168
|
user_id: str
|
118
169
|
|
119
170
|
|
120
|
-
class RunEvalRequest(BaseModel):
|
171
|
+
class RunEvalRequest(common.BaseModel):
|
121
172
|
eval_ids: list[str] # if empty, then all evals in the eval set are run.
|
122
173
|
eval_metrics: list[EvalMetric]
|
123
174
|
|
124
175
|
|
125
|
-
class RunEvalResult(BaseModel):
|
176
|
+
class RunEvalResult(common.BaseModel):
|
177
|
+
eval_set_file: str
|
126
178
|
eval_set_id: str
|
127
179
|
eval_id: str
|
128
180
|
final_eval_status: EvalStatus
|
129
|
-
eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]]
|
181
|
+
eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]] = Field(
|
182
|
+
deprecated=True,
|
183
|
+
description=(
|
184
|
+
"This field is deprecated, use overall_eval_metric_results instead."
|
185
|
+
),
|
186
|
+
)
|
187
|
+
overall_eval_metric_results: list[EvalMetricResult]
|
188
|
+
eval_metric_result_per_invocation: list[EvalMetricResultPerInvocation]
|
189
|
+
user_id: str
|
130
190
|
session_id: str
|
131
191
|
|
132
192
|
|
193
|
+
class GetEventGraphResult(common.BaseModel):
|
194
|
+
dot_src: str
|
195
|
+
|
196
|
+
|
133
197
|
def get_fast_api_app(
|
134
198
|
*,
|
135
199
|
agent_dir: str,
|
@@ -141,12 +205,15 @@ def get_fast_api_app(
|
|
141
205
|
) -> FastAPI:
|
142
206
|
# InMemory tracing dict.
|
143
207
|
trace_dict: dict[str, Any] = {}
|
208
|
+
session_trace_dict: dict[str, Any] = {}
|
144
209
|
|
145
210
|
# Set up tracing in the FastAPI server.
|
146
211
|
provider = TracerProvider()
|
147
212
|
provider.add_span_processor(
|
148
213
|
export.SimpleSpanProcessor(ApiServerSpanExporter(trace_dict))
|
149
214
|
)
|
215
|
+
memory_exporter = InMemoryExporter(session_trace_dict)
|
216
|
+
provider.add_span_processor(export.SimpleSpanProcessor(memory_exporter))
|
150
217
|
if trace_to_cloud:
|
151
218
|
envs.load_dotenv_for_agent("", agent_dir)
|
152
219
|
if project_id := os.environ.get("GOOGLE_CLOUD_PROJECT", None):
|
@@ -155,26 +222,82 @@ def get_fast_api_app(
|
|
155
222
|
)
|
156
223
|
provider.add_span_processor(processor)
|
157
224
|
else:
|
158
|
-
|
225
|
+
logger.warning(
|
159
226
|
"GOOGLE_CLOUD_PROJECT environment variable is not set. Tracing will"
|
160
227
|
" not be enabled."
|
161
228
|
)
|
162
229
|
|
163
230
|
trace.set_tracer_provider(provider)
|
164
231
|
|
165
|
-
|
232
|
+
toolsets_to_close: set[BaseToolset] = set()
|
166
233
|
|
167
234
|
@asynccontextmanager
|
168
235
|
async def internal_lifespan(app: FastAPI):
|
169
|
-
|
170
|
-
|
236
|
+
# Set up signal handlers for graceful shutdown
|
237
|
+
original_sigterm = signal.getsignal(signal.SIGTERM)
|
238
|
+
original_sigint = signal.getsignal(signal.SIGINT)
|
239
|
+
|
240
|
+
def cleanup_handler(sig, frame):
|
241
|
+
# Log the signal
|
242
|
+
logger.info("Received signal %s, performing pre-shutdown cleanup", sig)
|
243
|
+
# Do synchronous cleanup if needed
|
244
|
+
# Then call original handler if it exists
|
245
|
+
if sig == signal.SIGTERM and callable(original_sigterm):
|
246
|
+
original_sigterm(sig, frame)
|
247
|
+
elif sig == signal.SIGINT and callable(original_sigint):
|
248
|
+
original_sigint(sig, frame)
|
249
|
+
|
250
|
+
# Install cleanup handlers
|
251
|
+
signal.signal(signal.SIGTERM, cleanup_handler)
|
252
|
+
signal.signal(signal.SIGINT, cleanup_handler)
|
253
|
+
|
254
|
+
try:
|
255
|
+
if lifespan:
|
256
|
+
async with lifespan(app) as lifespan_context:
|
257
|
+
yield lifespan_context
|
258
|
+
else:
|
171
259
|
yield
|
260
|
+
finally:
|
261
|
+
# During shutdown, properly clean up all toolsets
|
262
|
+
logger.info(
|
263
|
+
"Server shutdown initiated, cleaning up %s toolsets",
|
264
|
+
len(toolsets_to_close),
|
265
|
+
)
|
172
266
|
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
267
|
+
# Create tasks for all toolset closures to run concurrently
|
268
|
+
cleanup_tasks = []
|
269
|
+
for toolset in toolsets_to_close:
|
270
|
+
task = asyncio.create_task(close_toolset_safely(toolset))
|
271
|
+
cleanup_tasks.append(task)
|
272
|
+
|
273
|
+
if cleanup_tasks:
|
274
|
+
# Wait for all cleanup tasks with timeout
|
275
|
+
done, pending = await asyncio.wait(
|
276
|
+
cleanup_tasks,
|
277
|
+
timeout=10.0, # 10 second timeout for cleanup
|
278
|
+
return_when=asyncio.ALL_COMPLETED,
|
279
|
+
)
|
280
|
+
|
281
|
+
# If any tasks are still pending, log it
|
282
|
+
if pending:
|
283
|
+
logger.warning(
|
284
|
+
f"{len(pending)} toolset cleanup tasks didn't complete in time"
|
285
|
+
)
|
286
|
+
for task in pending:
|
287
|
+
task.cancel()
|
288
|
+
|
289
|
+
# Restore original signal handlers
|
290
|
+
signal.signal(signal.SIGTERM, original_sigterm)
|
291
|
+
signal.signal(signal.SIGINT, original_sigint)
|
292
|
+
|
293
|
+
async def close_toolset_safely(toolset):
|
294
|
+
"""Safely close a toolset with error handling."""
|
295
|
+
try:
|
296
|
+
logger.info(f"Closing toolset: {type(toolset).__name__}")
|
297
|
+
await toolset.close()
|
298
|
+
logger.info(f"Successfully closed toolset: {type(toolset).__name__}")
|
299
|
+
except Exception as e:
|
300
|
+
logger.error(f"Error closing toolset {type(toolset).__name__}: {e}")
|
178
301
|
|
179
302
|
# Run the FastAPI server.
|
180
303
|
app = FastAPI(lifespan=internal_lifespan)
|
@@ -198,6 +321,8 @@ def get_fast_api_app(
|
|
198
321
|
artifact_service = InMemoryArtifactService()
|
199
322
|
memory_service = InMemoryMemoryService()
|
200
323
|
|
324
|
+
eval_sets_manager = LocalEvalSetsManager(agent_dir=agent_dir)
|
325
|
+
|
201
326
|
# Build the Session service
|
202
327
|
agent_engine_id = ""
|
203
328
|
if session_db_url:
|
@@ -240,14 +365,34 @@ def get_fast_api_app(
|
|
240
365
|
raise HTTPException(status_code=404, detail="Trace not found")
|
241
366
|
return event_dict
|
242
367
|
|
368
|
+
@app.get("/debug/trace/session/{session_id}")
|
369
|
+
def get_session_trace(session_id: str) -> Any:
|
370
|
+
spans = memory_exporter.get_finished_spans(session_id)
|
371
|
+
if not spans:
|
372
|
+
return []
|
373
|
+
return [
|
374
|
+
{
|
375
|
+
"name": s.name,
|
376
|
+
"span_id": s.context.span_id,
|
377
|
+
"trace_id": s.context.trace_id,
|
378
|
+
"start_time": s.start_time,
|
379
|
+
"end_time": s.end_time,
|
380
|
+
"attributes": dict(s.attributes),
|
381
|
+
"parent_span_id": s.parent.span_id if s.parent else None,
|
382
|
+
}
|
383
|
+
for s in spans
|
384
|
+
]
|
385
|
+
|
243
386
|
@app.get(
|
244
387
|
"/apps/{app_name}/users/{user_id}/sessions/{session_id}",
|
245
388
|
response_model_exclude_none=True,
|
246
389
|
)
|
247
|
-
def get_session(
|
390
|
+
async def get_session(
|
391
|
+
app_name: str, user_id: str, session_id: str
|
392
|
+
) -> Session:
|
248
393
|
# Connect to managed session if agent_engine_id is set.
|
249
394
|
app_name = agent_engine_id if agent_engine_id else app_name
|
250
|
-
session = session_service.get_session(
|
395
|
+
session = await session_service.get_session(
|
251
396
|
app_name=app_name, user_id=user_id, session_id=session_id
|
252
397
|
)
|
253
398
|
if not session:
|
@@ -258,14 +403,15 @@ def get_fast_api_app(
|
|
258
403
|
"/apps/{app_name}/users/{user_id}/sessions",
|
259
404
|
response_model_exclude_none=True,
|
260
405
|
)
|
261
|
-
def list_sessions(app_name: str, user_id: str) -> list[Session]:
|
406
|
+
async def list_sessions(app_name: str, user_id: str) -> list[Session]:
|
262
407
|
# Connect to managed session if agent_engine_id is set.
|
263
408
|
app_name = agent_engine_id if agent_engine_id else app_name
|
409
|
+
list_sessions_response = await session_service.list_sessions(
|
410
|
+
app_name=app_name, user_id=user_id
|
411
|
+
)
|
264
412
|
return [
|
265
413
|
session
|
266
|
-
for session in session_service.list_sessions(
|
267
|
-
app_name=app_name, user_id=user_id
|
268
|
-
).sessions
|
414
|
+
for session in list_sessions_response.sessions
|
269
415
|
# Remove sessions that were generated as a part of Eval.
|
270
416
|
if not session.id.startswith(EVAL_SESSION_ID_PREFIX)
|
271
417
|
]
|
@@ -274,7 +420,7 @@ def get_fast_api_app(
|
|
274
420
|
"/apps/{app_name}/users/{user_id}/sessions/{session_id}",
|
275
421
|
response_model_exclude_none=True,
|
276
422
|
)
|
277
|
-
def create_session_with_id(
|
423
|
+
async def create_session_with_id(
|
278
424
|
app_name: str,
|
279
425
|
user_id: str,
|
280
426
|
session_id: str,
|
@@ -283,7 +429,7 @@ def get_fast_api_app(
|
|
283
429
|
# Connect to managed session if agent_engine_id is set.
|
284
430
|
app_name = agent_engine_id if agent_engine_id else app_name
|
285
431
|
if (
|
286
|
-
session_service.get_session(
|
432
|
+
await session_service.get_session(
|
287
433
|
app_name=app_name, user_id=user_id, session_id=session_id
|
288
434
|
)
|
289
435
|
is not None
|
@@ -292,9 +438,8 @@ def get_fast_api_app(
|
|
292
438
|
raise HTTPException(
|
293
439
|
status_code=400, detail=f"Session already exists: {session_id}"
|
294
440
|
)
|
295
|
-
|
296
441
|
logger.info("New session created: %s", session_id)
|
297
|
-
return session_service.create_session(
|
442
|
+
return await session_service.create_session(
|
298
443
|
app_name=app_name, user_id=user_id, state=state, session_id=session_id
|
299
444
|
)
|
300
445
|
|
@@ -302,16 +447,15 @@ def get_fast_api_app(
|
|
302
447
|
"/apps/{app_name}/users/{user_id}/sessions",
|
303
448
|
response_model_exclude_none=True,
|
304
449
|
)
|
305
|
-
def create_session(
|
450
|
+
async def create_session(
|
306
451
|
app_name: str,
|
307
452
|
user_id: str,
|
308
453
|
state: Optional[dict[str, Any]] = None,
|
309
454
|
) -> Session:
|
310
455
|
# Connect to managed session if agent_engine_id is set.
|
311
456
|
app_name = agent_engine_id if agent_engine_id else app_name
|
312
|
-
|
313
457
|
logger.info("New session created")
|
314
|
-
return session_service.create_session(
|
458
|
+
return await session_service.create_session(
|
315
459
|
app_name=app_name, user_id=user_id, state=state
|
316
460
|
)
|
317
461
|
|
@@ -331,28 +475,13 @@ def get_fast_api_app(
|
|
331
475
|
eval_set_id: str,
|
332
476
|
):
|
333
477
|
"""Creates an eval set, given the id."""
|
334
|
-
|
335
|
-
|
478
|
+
try:
|
479
|
+
eval_sets_manager.create_eval_set(app_name, eval_set_id)
|
480
|
+
except ValueError as ve:
|
336
481
|
raise HTTPException(
|
337
482
|
status_code=400,
|
338
|
-
detail=(
|
339
|
-
|
340
|
-
" format"
|
341
|
-
),
|
342
|
-
)
|
343
|
-
# Define the file path
|
344
|
-
new_eval_set_path = _get_eval_set_file_path(
|
345
|
-
app_name, agent_dir, eval_set_id
|
346
|
-
)
|
347
|
-
|
348
|
-
logger.info("Creating eval set file `%s`", new_eval_set_path)
|
349
|
-
|
350
|
-
if not os.path.exists(new_eval_set_path):
|
351
|
-
# Write the JSON string to the file
|
352
|
-
logger.info("Eval set file doesn't exist, we will create a new one.")
|
353
|
-
with open(new_eval_set_path, "w") as f:
|
354
|
-
empty_content = json.dumps([], indent=2)
|
355
|
-
f.write(empty_content)
|
483
|
+
detail=str(ve),
|
484
|
+
) from ve
|
356
485
|
|
357
486
|
@app.get(
|
358
487
|
"/apps/{app_name}/eval_sets",
|
@@ -360,15 +489,7 @@ def get_fast_api_app(
|
|
360
489
|
)
|
361
490
|
def list_eval_sets(app_name: str) -> list[str]:
|
362
491
|
"""Lists all eval sets for the given app."""
|
363
|
-
|
364
|
-
eval_sets = []
|
365
|
-
for file in os.listdir(eval_set_file_path):
|
366
|
-
if file.endswith(_EVAL_SET_FILE_EXTENSION):
|
367
|
-
eval_sets.append(
|
368
|
-
os.path.basename(file).removesuffix(_EVAL_SET_FILE_EXTENSION)
|
369
|
-
)
|
370
|
-
|
371
|
-
return sorted(eval_sets)
|
492
|
+
return eval_sets_manager.list_eval_sets(app_name)
|
372
493
|
|
373
494
|
@app.post(
|
374
495
|
"/apps/{app_name}/eval_sets/{eval_set_id}/add_session",
|
@@ -377,54 +498,33 @@ def get_fast_api_app(
|
|
377
498
|
async def add_session_to_eval_set(
|
378
499
|
app_name: str, eval_set_id: str, req: AddSessionToEvalSetRequest
|
379
500
|
):
|
380
|
-
pattern = r"^[a-zA-Z0-9_]+$"
|
381
|
-
if not bool(re.fullmatch(pattern, req.eval_id)):
|
382
|
-
raise HTTPException(
|
383
|
-
status_code=400,
|
384
|
-
detail=f"Invalid eval id. Eval id should have the `{pattern}` format",
|
385
|
-
)
|
386
|
-
|
387
501
|
# Get the session
|
388
|
-
session = session_service.get_session(
|
502
|
+
session = await session_service.get_session(
|
389
503
|
app_name=app_name, user_id=req.user_id, session_id=req.session_id
|
390
504
|
)
|
391
505
|
assert session, "Session not found."
|
392
|
-
# Load the eval set file data
|
393
|
-
eval_set_file_path = _get_eval_set_file_path(
|
394
|
-
app_name, agent_dir, eval_set_id
|
395
|
-
)
|
396
|
-
with open(eval_set_file_path, "r") as file:
|
397
|
-
eval_set_data = json.load(file) # Load JSON into a list
|
398
506
|
|
399
|
-
|
400
|
-
|
401
|
-
status_code=400,
|
402
|
-
detail=(
|
403
|
-
f"Eval id `{req.eval_id}` already exists in `{eval_set_id}`"
|
404
|
-
" eval set."
|
405
|
-
),
|
406
|
-
)
|
407
|
-
|
408
|
-
# Convert the session data to evaluation format
|
409
|
-
test_data = evals.convert_session_to_eval_format(session)
|
507
|
+
# Convert the session data to eval invocations
|
508
|
+
invocations = evals.convert_session_to_eval_invocations(session)
|
410
509
|
|
411
510
|
# Populate the session with initial session state.
|
412
511
|
initial_session_state = create_empty_state(
|
413
512
|
await _get_root_agent_async(app_name)
|
414
513
|
)
|
415
514
|
|
416
|
-
|
417
|
-
|
418
|
-
|
419
|
-
|
420
|
-
|
421
|
-
|
422
|
-
|
423
|
-
|
424
|
-
|
425
|
-
|
426
|
-
|
427
|
-
|
515
|
+
new_eval_case = EvalCase(
|
516
|
+
eval_id=req.eval_id,
|
517
|
+
conversation=invocations,
|
518
|
+
session_input=SessionInput(
|
519
|
+
app_name=app_name, user_id=req.user_id, state=initial_session_state
|
520
|
+
),
|
521
|
+
creation_timestamp=time.time(),
|
522
|
+
)
|
523
|
+
|
524
|
+
try:
|
525
|
+
eval_sets_manager.add_eval_case(app_name, eval_set_id, new_eval_case)
|
526
|
+
except ValueError as ve:
|
527
|
+
raise HTTPException(status_code=400, detail=str(ve)) from ve
|
428
528
|
|
429
529
|
@app.get(
|
430
530
|
"/apps/{app_name}/eval_sets/{eval_set_id}/evals",
|
@@ -435,14 +535,9 @@ def get_fast_api_app(
|
|
435
535
|
eval_set_id: str,
|
436
536
|
) -> list[str]:
|
437
537
|
"""Lists all evals in an eval set."""
|
438
|
-
|
439
|
-
eval_set_file_path = _get_eval_set_file_path(
|
440
|
-
app_name, agent_dir, eval_set_id
|
441
|
-
)
|
442
|
-
with open(eval_set_file_path, "r") as file:
|
443
|
-
eval_set_data = json.load(file) # Load JSON into a list
|
538
|
+
eval_set_data = eval_sets_manager.get_eval_set(app_name, eval_set_id)
|
444
539
|
|
445
|
-
return sorted([x
|
540
|
+
return sorted([x.eval_id for x in eval_set_data.eval_cases])
|
446
541
|
|
447
542
|
@app.post(
|
448
543
|
"/apps/{app_name}/eval_sets/{eval_set_id}/run_eval",
|
@@ -451,51 +546,136 @@ def get_fast_api_app(
|
|
451
546
|
async def run_eval(
|
452
547
|
app_name: str, eval_set_id: str, req: RunEvalRequest
|
453
548
|
) -> list[RunEvalResult]:
|
549
|
+
"""Runs an eval given the details in the eval request."""
|
454
550
|
from .cli_eval import run_evals
|
455
551
|
|
456
|
-
"""Runs an eval given the details in the eval request."""
|
457
552
|
# Create a mapping from eval set file to all the evals that needed to be
|
458
553
|
# run.
|
459
|
-
|
460
|
-
app_name, agent_dir, eval_set_id
|
461
|
-
)
|
462
|
-
eval_set_to_evals = {eval_set_file_path: req.eval_ids}
|
554
|
+
envs.load_dotenv_for_agent(os.path.basename(app_name), agent_dir)
|
463
555
|
|
464
|
-
|
465
|
-
|
466
|
-
|
467
|
-
|
468
|
-
|
469
|
-
|
470
|
-
|
471
|
-
|
472
|
-
root_agent,
|
473
|
-
getattr(root_agent, "reset_data", None),
|
474
|
-
req.eval_metrics,
|
475
|
-
session_service=session_service,
|
476
|
-
artifact_service=artifact_service,
|
477
|
-
)
|
478
|
-
)
|
556
|
+
eval_set = eval_sets_manager.get_eval_set(app_name, eval_set_id)
|
557
|
+
|
558
|
+
if req.eval_ids:
|
559
|
+
eval_cases = [e for e in eval_set.eval_cases if e.eval_id in req.eval_ids]
|
560
|
+
eval_set_to_evals = {eval_set_id: eval_cases}
|
561
|
+
else:
|
562
|
+
logger.info("Eval ids to run list is empty. We will run all eval cases.")
|
563
|
+
eval_set_to_evals = {eval_set_id: eval_set.eval_cases}
|
479
564
|
|
565
|
+
root_agent = await _get_root_agent_async(app_name)
|
480
566
|
run_eval_results = []
|
481
|
-
|
567
|
+
eval_case_results = []
|
568
|
+
async for eval_case_result in run_evals(
|
569
|
+
eval_set_to_evals,
|
570
|
+
root_agent,
|
571
|
+
getattr(root_agent, "reset_data", None),
|
572
|
+
req.eval_metrics,
|
573
|
+
session_service=session_service,
|
574
|
+
artifact_service=artifact_service,
|
575
|
+
):
|
482
576
|
run_eval_results.append(
|
483
577
|
RunEvalResult(
|
484
578
|
app_name=app_name,
|
579
|
+
eval_set_file=eval_case_result.eval_set_file,
|
485
580
|
eval_set_id=eval_set_id,
|
486
|
-
eval_id=
|
487
|
-
final_eval_status=
|
488
|
-
eval_metric_results=
|
489
|
-
|
581
|
+
eval_id=eval_case_result.eval_id,
|
582
|
+
final_eval_status=eval_case_result.final_eval_status,
|
583
|
+
eval_metric_results=eval_case_result.eval_metric_results,
|
584
|
+
overall_eval_metric_results=eval_case_result.overall_eval_metric_results,
|
585
|
+
eval_metric_result_per_invocation=eval_case_result.eval_metric_result_per_invocation,
|
586
|
+
user_id=eval_case_result.user_id,
|
587
|
+
session_id=eval_case_result.session_id,
|
490
588
|
)
|
491
589
|
)
|
590
|
+
eval_case_result.session_details = await session_service.get_session(
|
591
|
+
app_name=app_name,
|
592
|
+
user_id=eval_case_result.user_id,
|
593
|
+
session_id=eval_case_result.session_id,
|
594
|
+
)
|
595
|
+
eval_case_results.append(eval_case_result)
|
596
|
+
|
597
|
+
timestamp = time.time()
|
598
|
+
eval_set_result_name = app_name + "_" + eval_set_id + "_" + str(timestamp)
|
599
|
+
eval_set_result = EvalSetResult(
|
600
|
+
eval_set_result_id=eval_set_result_name,
|
601
|
+
eval_set_result_name=eval_set_result_name,
|
602
|
+
eval_set_id=eval_set_id,
|
603
|
+
eval_case_results=eval_case_results,
|
604
|
+
creation_timestamp=timestamp,
|
605
|
+
)
|
606
|
+
|
607
|
+
# Write eval result file, with eval_set_result_name.
|
608
|
+
app_eval_history_dir = os.path.join(
|
609
|
+
agent_dir, app_name, ".adk", "eval_history"
|
610
|
+
)
|
611
|
+
if not os.path.exists(app_eval_history_dir):
|
612
|
+
os.makedirs(app_eval_history_dir)
|
613
|
+
# Convert to json and write to file.
|
614
|
+
eval_set_result_json = eval_set_result.model_dump_json()
|
615
|
+
eval_set_result_file_path = os.path.join(
|
616
|
+
app_eval_history_dir,
|
617
|
+
eval_set_result_name + _EVAL_SET_RESULT_FILE_EXTENSION,
|
618
|
+
)
|
619
|
+
logger.info("Writing eval result to file: %s", eval_set_result_file_path)
|
620
|
+
with open(eval_set_result_file_path, "w") as f:
|
621
|
+
f.write(json.dumps(eval_set_result_json, indent=2))
|
622
|
+
|
492
623
|
return run_eval_results
|
493
624
|
|
625
|
+
@app.get(
|
626
|
+
"/apps/{app_name}/eval_results/{eval_result_id}",
|
627
|
+
response_model_exclude_none=True,
|
628
|
+
)
|
629
|
+
def get_eval_result(
|
630
|
+
app_name: str,
|
631
|
+
eval_result_id: str,
|
632
|
+
) -> EvalSetResult:
|
633
|
+
"""Gets the eval result for the given eval id."""
|
634
|
+
# Load the eval set file data
|
635
|
+
maybe_eval_result_file_path = (
|
636
|
+
os.path.join(
|
637
|
+
agent_dir, app_name, ".adk", "eval_history", eval_result_id
|
638
|
+
)
|
639
|
+
+ _EVAL_SET_RESULT_FILE_EXTENSION
|
640
|
+
)
|
641
|
+
if not os.path.exists(maybe_eval_result_file_path):
|
642
|
+
raise HTTPException(
|
643
|
+
status_code=404,
|
644
|
+
detail=f"Eval result `{eval_result_id}` not found.",
|
645
|
+
)
|
646
|
+
with open(maybe_eval_result_file_path, "r") as file:
|
647
|
+
eval_result_data = json.load(file) # Load JSON into a list
|
648
|
+
try:
|
649
|
+
eval_result = EvalSetResult.model_validate_json(eval_result_data)
|
650
|
+
return eval_result
|
651
|
+
except ValidationError as e:
|
652
|
+
logger.exception("get_eval_result validation error: %s", e)
|
653
|
+
|
654
|
+
@app.get(
|
655
|
+
"/apps/{app_name}/eval_results",
|
656
|
+
response_model_exclude_none=True,
|
657
|
+
)
|
658
|
+
def list_eval_results(app_name: str) -> list[str]:
|
659
|
+
"""Lists all eval results for the given app."""
|
660
|
+
app_eval_history_directory = os.path.join(
|
661
|
+
agent_dir, app_name, ".adk", "eval_history"
|
662
|
+
)
|
663
|
+
|
664
|
+
if not os.path.exists(app_eval_history_directory):
|
665
|
+
return []
|
666
|
+
|
667
|
+
eval_result_files = [
|
668
|
+
file.removesuffix(_EVAL_SET_RESULT_FILE_EXTENSION)
|
669
|
+
for file in os.listdir(app_eval_history_directory)
|
670
|
+
if file.endswith(_EVAL_SET_RESULT_FILE_EXTENSION)
|
671
|
+
]
|
672
|
+
return eval_result_files
|
673
|
+
|
494
674
|
@app.delete("/apps/{app_name}/users/{user_id}/sessions/{session_id}")
|
495
|
-
def delete_session(app_name: str, user_id: str, session_id: str):
|
675
|
+
async def delete_session(app_name: str, user_id: str, session_id: str):
|
496
676
|
# Connect to managed session if agent_engine_id is set.
|
497
677
|
app_name = agent_engine_id if agent_engine_id else app_name
|
498
|
-
session_service.delete_session(
|
678
|
+
await session_service.delete_session(
|
499
679
|
app_name=app_name, user_id=user_id, session_id=session_id
|
500
680
|
)
|
501
681
|
|
@@ -590,7 +770,7 @@ def get_fast_api_app(
|
|
590
770
|
async def agent_run(req: AgentRunRequest) -> list[Event]:
|
591
771
|
# Connect to managed session if agent_engine_id is set.
|
592
772
|
app_id = agent_engine_id if agent_engine_id else req.app_name
|
593
|
-
session = session_service.get_session(
|
773
|
+
session = await session_service.get_session(
|
594
774
|
app_name=app_id, user_id=req.user_id, session_id=req.session_id
|
595
775
|
)
|
596
776
|
if not session:
|
@@ -612,7 +792,7 @@ def get_fast_api_app(
|
|
612
792
|
# Connect to managed session if agent_engine_id is set.
|
613
793
|
app_id = agent_engine_id if agent_engine_id else req.app_name
|
614
794
|
# SSE endpoint
|
615
|
-
session = session_service.get_session(
|
795
|
+
session = await session_service.get_session(
|
616
796
|
app_name=app_id, user_id=req.user_id, session_id=req.session_id
|
617
797
|
)
|
618
798
|
if not session:
|
@@ -653,7 +833,7 @@ def get_fast_api_app(
|
|
653
833
|
):
|
654
834
|
# Connect to managed session if agent_engine_id is set.
|
655
835
|
app_id = agent_engine_id if agent_engine_id else app_name
|
656
|
-
session = session_service.get_session(
|
836
|
+
session = await session_service.get_session(
|
657
837
|
app_name=app_id, user_id=user_id, session_id=session_id
|
658
838
|
)
|
659
839
|
session_events = session.events if session else []
|
@@ -673,7 +853,7 @@ def get_fast_api_app(
|
|
673
853
|
from_name = event.author
|
674
854
|
to_name = function_call.name
|
675
855
|
function_call_highlights.append((from_name, to_name))
|
676
|
-
dot_graph = agent_graph.get_agent_graph(
|
856
|
+
dot_graph = await agent_graph.get_agent_graph(
|
677
857
|
root_agent, function_call_highlights
|
678
858
|
)
|
679
859
|
elif function_responses:
|
@@ -682,17 +862,17 @@ def get_fast_api_app(
|
|
682
862
|
from_name = function_response.name
|
683
863
|
to_name = event.author
|
684
864
|
function_responses_highlights.append((from_name, to_name))
|
685
|
-
dot_graph = agent_graph.get_agent_graph(
|
865
|
+
dot_graph = await agent_graph.get_agent_graph(
|
686
866
|
root_agent, function_responses_highlights
|
687
867
|
)
|
688
868
|
else:
|
689
869
|
from_name = event.author
|
690
870
|
to_name = ""
|
691
|
-
dot_graph = agent_graph.get_agent_graph(
|
871
|
+
dot_graph = await agent_graph.get_agent_graph(
|
692
872
|
root_agent, [(from_name, to_name)]
|
693
873
|
)
|
694
874
|
if dot_graph and isinstance(dot_graph, graphviz.Digraph):
|
695
|
-
return
|
875
|
+
return GetEventGraphResult(dot_src=dot_graph.source)
|
696
876
|
else:
|
697
877
|
return {}
|
698
878
|
|
@@ -710,7 +890,7 @@ def get_fast_api_app(
|
|
710
890
|
|
711
891
|
# Connect to managed session if agent_engine_id is set.
|
712
892
|
app_id = agent_engine_id if agent_engine_id else app_name
|
713
|
-
session = session_service.get_session(
|
893
|
+
session = await session_service.get_session(
|
714
894
|
app_name=app_id, user_id=user_id, session_id=session_id
|
715
895
|
)
|
716
896
|
if not session:
|
@@ -766,6 +946,16 @@ def get_fast_api_app(
|
|
766
946
|
for task in pending:
|
767
947
|
task.cancel()
|
768
948
|
|
949
|
+
def _get_all_toolsets(agent: BaseAgent) -> set[BaseToolset]:
  """Collects the toolsets of `agent` and of every agent beneath it.

  Args:
    agent: Root of the agent tree to walk.

  Returns:
    The set of `BaseToolset` instances found in the `tools` of each
    `LlmAgent` in the tree; other entries in `tools` are ignored.
  """
  if isinstance(agent, LlmAgent):
    collected = {
        tool for tool in agent.tools if isinstance(tool, BaseToolset)
    }
  else:
    # Only LlmAgent tools are inspected; other agent kinds contribute only
    # through their sub-agents.
    collected = set()

  for child in agent.sub_agents:
    collected |= _get_all_toolsets(child)
  return collected
|
958
|
+
|
769
959
|
async def _get_root_agent_async(app_name: str) -> Agent:
|
770
960
|
"""Returns the root agent for the given app."""
|
771
961
|
if app_name in root_agent_dict:
|
@@ -776,16 +966,8 @@ def get_fast_api_app(
|
|
776
966
|
else:
|
777
967
|
raise ValueError(f'Unable to find "root_agent" from {app_name}.')
|
778
968
|
|
779
|
-
# Handle an awaitable root agent and await for the actual agent.
|
780
|
-
if inspect.isawaitable(root_agent):
|
781
|
-
try:
|
782
|
-
agent, exit_stack = await root_agent
|
783
|
-
exit_stacks.append(exit_stack)
|
784
|
-
root_agent = agent
|
785
|
-
except Exception as e:
|
786
|
-
raise RuntimeError(f"error getting root agent, {e}") from e
|
787
|
-
|
788
969
|
root_agent_dict[app_name] = root_agent
|
970
|
+
toolsets_to_close.update(_get_all_toolsets(root_agent))
|
789
971
|
return root_agent
|
790
972
|
|
791
973
|
async def _get_runner_async(app_name: str) -> Runner:
|