google-adk 1.0.0__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff shows the changes between two publicly released package versions, as they appear in their public registry, and is provided for informational purposes only.
- google/adk/agents/callback_context.py +2 -1
- google/adk/agents/readonly_context.py +3 -1
- google/adk/auth/auth_credential.py +4 -1
- google/adk/cli/browser/index.html +4 -4
- google/adk/cli/browser/{main-QOEMUXM4.js → main-PKDNKWJE.js} +59 -59
- google/adk/cli/browser/polyfills-B6TNHZQ6.js +17 -0
- google/adk/cli/cli.py +3 -2
- google/adk/cli/cli_eval.py +6 -85
- google/adk/cli/cli_tools_click.py +39 -10
- google/adk/cli/fast_api.py +53 -184
- google/adk/cli/utils/agent_loader.py +137 -0
- google/adk/cli/utils/cleanup.py +40 -0
- google/adk/cli/utils/evals.py +2 -1
- google/adk/cli/utils/logs.py +2 -7
- google/adk/code_executors/code_execution_utils.py +2 -1
- google/adk/code_executors/container_code_executor.py +0 -1
- google/adk/code_executors/vertex_ai_code_executor.py +6 -8
- google/adk/evaluation/eval_case.py +3 -1
- google/adk/evaluation/eval_metrics.py +74 -0
- google/adk/evaluation/eval_result.py +86 -0
- google/adk/evaluation/eval_set.py +2 -0
- google/adk/evaluation/eval_set_results_manager.py +47 -0
- google/adk/evaluation/eval_sets_manager.py +2 -1
- google/adk/evaluation/evaluator.py +2 -0
- google/adk/evaluation/local_eval_set_results_manager.py +113 -0
- google/adk/evaluation/local_eval_sets_manager.py +4 -4
- google/adk/evaluation/response_evaluator.py +2 -1
- google/adk/evaluation/trajectory_evaluator.py +3 -2
- google/adk/examples/base_example_provider.py +1 -0
- google/adk/flows/llm_flows/base_llm_flow.py +4 -6
- google/adk/flows/llm_flows/contents.py +3 -1
- google/adk/flows/llm_flows/instructions.py +7 -77
- google/adk/flows/llm_flows/single_flow.py +1 -1
- google/adk/models/base_llm.py +2 -1
- google/adk/models/base_llm_connection.py +2 -0
- google/adk/models/google_llm.py +4 -1
- google/adk/models/lite_llm.py +3 -2
- google/adk/models/llm_response.py +2 -1
- google/adk/runners.py +36 -4
- google/adk/sessions/_session_util.py +2 -1
- google/adk/sessions/database_session_service.py +5 -8
- google/adk/sessions/vertex_ai_session_service.py +28 -13
- google/adk/telemetry.py +4 -2
- google/adk/tools/agent_tool.py +1 -1
- google/adk/tools/apihub_tool/apihub_toolset.py +1 -1
- google/adk/tools/apihub_tool/clients/apihub_client.py +10 -3
- google/adk/tools/apihub_tool/clients/secret_client.py +1 -0
- google/adk/tools/application_integration_tool/application_integration_toolset.py +6 -2
- google/adk/tools/application_integration_tool/clients/connections_client.py +8 -1
- google/adk/tools/application_integration_tool/clients/integration_client.py +3 -1
- google/adk/tools/application_integration_tool/integration_connector_tool.py +1 -1
- google/adk/tools/base_toolset.py +40 -2
- google/adk/tools/bigquery/__init__.py +28 -0
- google/adk/tools/bigquery/bigquery_credentials.py +216 -0
- google/adk/tools/bigquery/bigquery_tool.py +116 -0
- google/adk/tools/function_parameter_parse_util.py +7 -0
- google/adk/tools/function_tool.py +33 -3
- google/adk/tools/get_user_choice_tool.py +1 -0
- google/adk/tools/google_api_tool/__init__.py +17 -11
- google/adk/tools/google_api_tool/google_api_tool.py +1 -1
- google/adk/tools/google_api_tool/google_api_toolset.py +0 -14
- google/adk/tools/google_api_tool/google_api_toolsets.py +8 -2
- google/adk/tools/google_search_tool.py +2 -2
- google/adk/tools/mcp_tool/conversion_utils.py +6 -2
- google/adk/tools/mcp_tool/mcp_session_manager.py +62 -188
- google/adk/tools/mcp_tool/mcp_tool.py +27 -24
- google/adk/tools/mcp_tool/mcp_toolset.py +76 -131
- google/adk/tools/openapi_tool/auth/credential_exchangers/base_credential_exchanger.py +1 -3
- google/adk/tools/openapi_tool/auth/credential_exchangers/service_account_exchanger.py +6 -7
- google/adk/tools/openapi_tool/common/common.py +5 -1
- google/adk/tools/openapi_tool/openapi_spec_parser/__init__.py +7 -2
- google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py +2 -7
- google/adk/tools/openapi_tool/openapi_spec_parser/operation_parser.py +5 -1
- google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py +11 -1
- google/adk/tools/toolbox_toolset.py +31 -3
- google/adk/utils/__init__.py +13 -0
- google/adk/utils/instructions_utils.py +131 -0
- google/adk/version.py +1 -1
- {google_adk-1.0.0.dist-info → google_adk-1.1.0.dist-info}/METADATA +12 -15
- {google_adk-1.0.0.dist-info → google_adk-1.1.0.dist-info}/RECORD +83 -78
- google/adk/agents/base_agent.py.orig +0 -330
- google/adk/cli/browser/polyfills-FFHMD2TL.js +0 -18
- google/adk/cli/fast_api.py.orig +0 -822
- google/adk/memory/base_memory_service.py.orig +0 -76
- google/adk/models/google_llm.py.orig +0 -305
- google/adk/tools/_built_in_code_execution_tool.py +0 -70
- google/adk/tools/mcp_tool/mcp_session_manager.py.orig +0 -322
- {google_adk-1.0.0.dist-info → google_adk-1.1.0.dist-info}/WHEEL +0 -0
- {google_adk-1.0.0.dist-info → google_adk-1.1.0.dist-info}/entry_points.txt +0 -0
- {google_adk-1.0.0.dist-info → google_adk-1.1.0.dist-info}/licenses/LICENSE +0 -0
google/adk/cli/fast_api.py.orig
DELETED
@@ -1,822 +0,0 @@

```python
# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
from contextlib import asynccontextmanager
import importlib
import inspect
import json
import logging
import os
from pathlib import Path
import re
import sys
import traceback
import typing
from typing import Any
from typing import List
from typing import Literal
from typing import Optional

import click
from click import Tuple
from fastapi import FastAPI
from fastapi import HTTPException
from fastapi import Query
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.responses import RedirectResponse
from fastapi.responses import StreamingResponse
from fastapi.staticfiles import StaticFiles
from fastapi.websockets import WebSocket
from fastapi.websockets import WebSocketDisconnect
from google.genai import types
import graphviz
from opentelemetry import trace
from opentelemetry.exporter.cloud_trace import CloudTraceSpanExporter
from opentelemetry.sdk.trace import export
from opentelemetry.sdk.trace import ReadableSpan
from opentelemetry.sdk.trace import TracerProvider
from pydantic import BaseModel
from pydantic import ValidationError
from starlette.types import Lifespan

from ..agents import RunConfig
from ..agents.live_request_queue import LiveRequest
from ..agents.live_request_queue import LiveRequestQueue
from ..agents.llm_agent import Agent
from ..agents.run_config import StreamingMode
from ..artifacts import InMemoryArtifactService
from ..events.event import Event
from ..memory.in_memory_memory_service import InMemoryMemoryService
from ..runners import Runner
from ..sessions.database_session_service import DatabaseSessionService
from ..sessions.in_memory_session_service import InMemorySessionService
from ..sessions.session import Session
from ..sessions.vertex_ai_session_service import VertexAiSessionService
from .cli_eval import EVAL_SESSION_ID_PREFIX
from .cli_eval import EvalMetric
from .cli_eval import EvalMetricResult
from .cli_eval import EvalStatus
from .utils import create_empty_state
from .utils import envs
from .utils import evals

logger = logging.getLogger(__name__)

_EVAL_SET_FILE_EXTENSION = ".evalset.json"


class ApiServerSpanExporter(export.SpanExporter):

  def __init__(self, trace_dict):
    self.trace_dict = trace_dict

  def export(
      self, spans: typing.Sequence[ReadableSpan]
  ) -> export.SpanExportResult:
    for span in spans:
      if (
          span.name == "call_llm"
          or span.name == "send_data"
          or span.name.startswith("tool_response")
      ):
        attributes = dict(span.attributes)
        attributes["trace_id"] = span.get_span_context().trace_id
        attributes["span_id"] = span.get_span_context().span_id
        if attributes.get("gcp.vertex.agent.event_id", None):
          self.trace_dict[attributes["gcp.vertex.agent.event_id"]] = attributes
    return export.SpanExportResult.SUCCESS

  def force_flush(self, timeout_millis: int = 30000) -> bool:
    return True


class AgentRunRequest(BaseModel):
  app_name: str
  user_id: str
  session_id: str
  new_message: types.Content
  streaming: bool = False


class AddSessionToEvalSetRequest(BaseModel):
  eval_id: str
  session_id: str
  user_id: str


class RunEvalRequest(BaseModel):
  eval_ids: list[str]  # if empty, then all evals in the eval set are run.
  eval_metrics: list[EvalMetric]


class RunEvalResult(BaseModel):
  eval_set_id: str
  eval_id: str
  final_eval_status: EvalStatus
  eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]]
  session_id: str


def get_fast_api_app(
    *,
    agent_dir: str,
    session_db_url: str = "",
    allow_origins: Optional[list[str]] = None,
    web: bool,
    trace_to_cloud: bool = False,
    lifespan: Optional[Lifespan[FastAPI]] = None,
) -> FastAPI:
  # InMemory tracing dict.
  trace_dict: dict[str, Any] = {}

  # Set up tracing in the FastAPI server.
  provider = TracerProvider()
  provider.add_span_processor(
      export.SimpleSpanProcessor(ApiServerSpanExporter(trace_dict))
  )
  if trace_to_cloud:
    envs.load_dotenv_for_agent("", agent_dir)
    if project_id := os.environ.get("GOOGLE_CLOUD_PROJECT", None):
      processor = export.BatchSpanProcessor(
          CloudTraceSpanExporter(project_id=project_id)
      )
      provider.add_span_processor(processor)
    else:
      logging.warning(
          "GOOGLE_CLOUD_PROJECT environment variable is not set. Tracing will"
          " not be enabled."
      )

  trace.set_tracer_provider(provider)

  exit_stacks = []

  @asynccontextmanager
  async def internal_lifespan(app: FastAPI):
    if lifespan:
      async with lifespan(app) as lifespan_context:
        yield

        if exit_stacks:
          for stack in exit_stacks:
            await stack.aclose()
    else:
      yield

  # Run the FastAPI server.
  app = FastAPI(lifespan=internal_lifespan)

  if allow_origins:
    app.add_middleware(
        CORSMiddleware,
        allow_origins=allow_origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

  if agent_dir not in sys.path:
    sys.path.append(agent_dir)

  runner_dict = {}
  root_agent_dict = {}

  # Build the Artifact service
  artifact_service = InMemoryArtifactService()
  memory_service = InMemoryMemoryService()

  # Build the Session service
  agent_engine_id = ""
  if session_db_url:
    if session_db_url.startswith("agentengine://"):
      # Create vertex session service
      agent_engine_id = session_db_url.split("://")[1]
      if not agent_engine_id:
        raise click.ClickException("Agent engine id can not be empty.")
      envs.load_dotenv_for_agent("", agent_dir)
      session_service = VertexAiSessionService(
          os.environ["GOOGLE_CLOUD_PROJECT"],
          os.environ["GOOGLE_CLOUD_LOCATION"],
      )
    else:
      session_service = DatabaseSessionService(db_url=session_db_url)
  else:
    session_service = InMemorySessionService()

  @app.get("/list-apps")
  def list_apps() -> list[str]:
    base_path = Path.cwd() / agent_dir
    if not base_path.exists():
      raise HTTPException(status_code=404, detail="Path not found")
    if not base_path.is_dir():
      raise HTTPException(status_code=400, detail="Not a directory")
    agent_names = [
        x
        for x in os.listdir(base_path)
        if os.path.isdir(os.path.join(base_path, x))
        and not x.startswith(".")
        and x != "__pycache__"
    ]
    agent_names.sort()
    return agent_names

  @app.get("/debug/trace/{event_id}")
  def get_trace_dict(event_id: str) -> Any:
    event_dict = trace_dict.get(event_id, None)
    if event_dict is None:
      raise HTTPException(status_code=404, detail="Trace not found")
    return event_dict

  @app.get(
      "/apps/{app_name}/users/{user_id}/sessions/{session_id}",
      response_model_exclude_none=True,
  )
  def get_session(app_name: str, user_id: str, session_id: str) -> Session:
    # Connect to managed session if agent_engine_id is set.
    app_name = agent_engine_id if agent_engine_id else app_name
    session = session_service.get_session(
        app_name=app_name, user_id=user_id, session_id=session_id
    )
    if not session:
      raise HTTPException(status_code=404, detail="Session not found")
    return session

  @app.get(
      "/apps/{app_name}/users/{user_id}/sessions",
      response_model_exclude_none=True,
  )
  def list_sessions(app_name: str, user_id: str) -> list[Session]:
    # Connect to managed session if agent_engine_id is set.
    app_name = agent_engine_id if agent_engine_id else app_name
    return [
        session
        for session in session_service.list_sessions(
            app_name=app_name, user_id=user_id
        ).sessions
        # Remove sessions that were generated as a part of Eval.
        if not session.id.startswith(EVAL_SESSION_ID_PREFIX)
    ]

  @app.post(
      "/apps/{app_name}/users/{user_id}/sessions/{session_id}",
      response_model_exclude_none=True,
  )
  def create_session_with_id(
      app_name: str,
      user_id: str,
      session_id: str,
      state: Optional[dict[str, Any]] = None,
  ) -> Session:
    # Connect to managed session if agent_engine_id is set.
    app_name = agent_engine_id if agent_engine_id else app_name
    if (
        session_service.get_session(
            app_name=app_name, user_id=user_id, session_id=session_id
        )
        is not None
    ):
      logger.warning("Session already exists: %s", session_id)
      raise HTTPException(
          status_code=400, detail=f"Session already exists: {session_id}"
      )

    logger.info("New session created: %s", session_id)
    return session_service.create_session(
        app_name=app_name, user_id=user_id, state=state, session_id=session_id
    )

  @app.post(
      "/apps/{app_name}/users/{user_id}/sessions",
      response_model_exclude_none=True,
  )
  def create_session(
      app_name: str,
      user_id: str,
      state: Optional[dict[str, Any]] = None,
  ) -> Session:
    # Connect to managed session if agent_engine_id is set.
    app_name = agent_engine_id if agent_engine_id else app_name

    logger.info("New session created")
    return session_service.create_session(
        app_name=app_name, user_id=user_id, state=state
    )

  def _get_eval_set_file_path(app_name, agent_dir, eval_set_id) -> str:
    return os.path.join(
        agent_dir,
        app_name,
        eval_set_id + _EVAL_SET_FILE_EXTENSION,
    )

  @app.post(
      "/apps/{app_name}/eval_sets/{eval_set_id}",
      response_model_exclude_none=True,
  )
  def create_eval_set(
      app_name: str,
      eval_set_id: str,
  ):
    """Creates an eval set, given the id."""
    pattern = r"^[a-zA-Z0-9_]+$"
    if not bool(re.fullmatch(pattern, eval_set_id)):
      raise HTTPException(
          status_code=400,
          detail=(
              f"Invalid eval set id. Eval set id should have the `{pattern}`"
              " format"
          ),
      )
    # Define the file path
    new_eval_set_path = _get_eval_set_file_path(
        app_name, agent_dir, eval_set_id
    )

    logger.info("Creating eval set file `%s`", new_eval_set_path)

    if not os.path.exists(new_eval_set_path):
      # Write the JSON string to the file
      logger.info("Eval set file doesn't exist, we will create a new one.")
      with open(new_eval_set_path, "w") as f:
        empty_content = json.dumps([], indent=2)
        f.write(empty_content)

  @app.get(
      "/apps/{app_name}/eval_sets",
      response_model_exclude_none=True,
  )
  def list_eval_sets(app_name: str) -> list[str]:
    """Lists all eval sets for the given app."""
    eval_set_file_path = os.path.join(agent_dir, app_name)
    eval_sets = []
    for file in os.listdir(eval_set_file_path):
      if file.endswith(_EVAL_SET_FILE_EXTENSION):
        eval_sets.append(
            os.path.basename(file).removesuffix(_EVAL_SET_FILE_EXTENSION)
        )

    return sorted(eval_sets)

  @app.post(
      "/apps/{app_name}/eval_sets/{eval_set_id}/add_session",
      response_model_exclude_none=True,
  )
  async def add_session_to_eval_set(
      app_name: str, eval_set_id: str, req: AddSessionToEvalSetRequest
  ):
    pattern = r"^[a-zA-Z0-9_]+$"
    if not bool(re.fullmatch(pattern, req.eval_id)):
      raise HTTPException(
          status_code=400,
          detail=f"Invalid eval id. Eval id should have the `{pattern}` format",
      )

    # Get the session
    session = session_service.get_session(
        app_name=app_name, user_id=req.user_id, session_id=req.session_id
    )
    assert session, "Session not found."
    # Load the eval set file data
    eval_set_file_path = _get_eval_set_file_path(
        app_name, agent_dir, eval_set_id
    )
    with open(eval_set_file_path, "r") as file:
      eval_set_data = json.load(file)  # Load JSON into a list

    if [x for x in eval_set_data if x["name"] == req.eval_id]:
      raise HTTPException(
          status_code=400,
          detail=(
              f"Eval id `{req.eval_id}` already exists in `{eval_set_id}`"
              " eval set."
          ),
      )

    # Convert the session data to evaluation format
    test_data = evals.convert_session_to_eval_format(session)

    # Populate the session with initial session state.
    initial_session_state = create_empty_state(
        await _get_root_agent_async(app_name)
    )

    eval_set_data.append({
        "name": req.eval_id,
        "data": test_data,
        "initial_session": {
            "state": initial_session_state,
            "app_name": app_name,
            "user_id": req.user_id,
        },
    })
    # Serialize the test data to JSON and write to the eval set file.
    with open(eval_set_file_path, "w") as f:
      f.write(json.dumps(eval_set_data, indent=2))

  @app.get(
      "/apps/{app_name}/eval_sets/{eval_set_id}/evals",
      response_model_exclude_none=True,
  )
  def list_evals_in_eval_set(
      app_name: str,
      eval_set_id: str,
  ) -> list[str]:
    """Lists all evals in an eval set."""
    # Load the eval set file data
    eval_set_file_path = _get_eval_set_file_path(
        app_name, agent_dir, eval_set_id
    )
    with open(eval_set_file_path, "r") as file:
      eval_set_data = json.load(file)  # Load JSON into a list

    return sorted([x["name"] for x in eval_set_data])

  @app.post(
      "/apps/{app_name}/eval_sets/{eval_set_id}/run_eval",
      response_model_exclude_none=True,
  )
  async def run_eval(
      app_name: str, eval_set_id: str, req: RunEvalRequest
  ) -> list[RunEvalResult]:
    from .cli_eval import run_evals

    """Runs an eval given the details in the eval request."""
    # Create a mapping from eval set file to all the evals that needed to be
    # run.
    eval_set_file_path = _get_eval_set_file_path(
        app_name, agent_dir, eval_set_id
    )
    eval_set_to_evals = {eval_set_file_path: req.eval_ids}

    if not req.eval_ids:
      logger.info(
          "Eval ids to run list is empty. We will all evals in the eval set."
      )
    root_agent = await _get_root_agent_async(app_name)
    eval_results = list(
        run_evals(
            eval_set_to_evals,
            root_agent,
            getattr(root_agent, "reset_data", None),
            req.eval_metrics,
            session_service=session_service,
            artifact_service=artifact_service,
        )
    )

    run_eval_results = []
    for eval_result in eval_results:
      run_eval_results.append(
          RunEvalResult(
              app_name=app_name,
              eval_set_id=eval_set_id,
              eval_id=eval_result.eval_id,
              final_eval_status=eval_result.final_eval_status,
              eval_metric_results=eval_result.eval_metric_results,
              session_id=eval_result.session_id,
          )
      )
    return run_eval_results

  @app.delete("/apps/{app_name}/users/{user_id}/sessions/{session_id}")
  def delete_session(app_name: str, user_id: str, session_id: str):
    # Connect to managed session if agent_engine_id is set.
    app_name = agent_engine_id if agent_engine_id else app_name
    session_service.delete_session(
        app_name=app_name, user_id=user_id, session_id=session_id
    )

  @app.get(
      "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}",
      response_model_exclude_none=True,
  )
  def load_artifact(
      app_name: str,
      user_id: str,
      session_id: str,
      artifact_name: str,
      version: Optional[int] = Query(None),
  ) -> Optional[types.Part]:
    app_name = agent_engine_id if agent_engine_id else app_name
    artifact = artifact_service.load_artifact(
        app_name=app_name,
        user_id=user_id,
        session_id=session_id,
        filename=artifact_name,
        version=version,
    )
    if not artifact:
      raise HTTPException(status_code=404, detail="Artifact not found")
    return artifact

  @app.get(
      "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}/versions/{version_id}",
      response_model_exclude_none=True,
  )
  def load_artifact_version(
      app_name: str,
      user_id: str,
      session_id: str,
      artifact_name: str,
      version_id: int,
  ) -> Optional[types.Part]:
    app_name = agent_engine_id if agent_engine_id else app_name
    artifact = artifact_service.load_artifact(
        app_name=app_name,
        user_id=user_id,
        session_id=session_id,
        filename=artifact_name,
        version=version_id,
    )
    if not artifact:
      raise HTTPException(status_code=404, detail="Artifact not found")
    return artifact

  @app.get(
      "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts",
      response_model_exclude_none=True,
  )
  def list_artifact_names(
      app_name: str, user_id: str, session_id: str
  ) -> list[str]:
    app_name = agent_engine_id if agent_engine_id else app_name
    return artifact_service.list_artifact_keys(
        app_name=app_name, user_id=user_id, session_id=session_id
    )

  @app.get(
      "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}/versions",
      response_model_exclude_none=True,
  )
  def list_artifact_versions(
      app_name: str, user_id: str, session_id: str, artifact_name: str
  ) -> list[int]:
    app_name = agent_engine_id if agent_engine_id else app_name
    return artifact_service.list_versions(
        app_name=app_name,
        user_id=user_id,
        session_id=session_id,
        filename=artifact_name,
    )

  @app.delete(
      "/apps/{app_name}/users/{user_id}/sessions/{session_id}/artifacts/{artifact_name}",
  )
  def delete_artifact(
      app_name: str, user_id: str, session_id: str, artifact_name: str
  ):
    app_name = agent_engine_id if agent_engine_id else app_name
    artifact_service.delete_artifact(
        app_name=app_name,
        user_id=user_id,
        session_id=session_id,
        filename=artifact_name,
    )

  @app.post("/run", response_model_exclude_none=True)
  async def agent_run(req: AgentRunRequest) -> list[Event]:
    # Connect to managed session if agent_engine_id is set.
    app_id = agent_engine_id if agent_engine_id else req.app_name
    session = session_service.get_session(
        app_name=app_id, user_id=req.user_id, session_id=req.session_id
    )
    if not session:
      raise HTTPException(status_code=404, detail="Session not found")
    runner = await _get_runner_async(req.app_name)
    events = [
        event
        async for event in runner.run_async(
            user_id=req.user_id,
            session_id=req.session_id,
            new_message=req.new_message,
        )
    ]
    logger.info("Generated %s events in agent run: %s", len(events), events)
    return events

  @app.post("/run_sse")
  async def agent_run_sse(req: AgentRunRequest) -> StreamingResponse:
    # Connect to managed session if agent_engine_id is set.
    app_id = agent_engine_id if agent_engine_id else req.app_name
    # SSE endpoint
    session = session_service.get_session(
        app_name=app_id, user_id=req.user_id, session_id=req.session_id
    )
    if not session:
      raise HTTPException(status_code=404, detail="Session not found")

    # Convert the events to properly formatted SSE
    async def event_generator():
      try:
        stream_mode = StreamingMode.SSE if req.streaming else StreamingMode.NONE
        runner = await _get_runner_async(req.app_name)
        async for event in runner.run_async(
            user_id=req.user_id,
            session_id=req.session_id,
            new_message=req.new_message,
            run_config=RunConfig(streaming_mode=stream_mode),
        ):
          # Format as SSE data
          sse_event = event.model_dump_json(exclude_none=True, by_alias=True)
          logger.info("Generated event in agent run streaming: %s", sse_event)
          yield f"data: {sse_event}\n\n"
      except Exception as e:
        logger.exception("Error in event_generator: %s", e)
        # You might want to yield an error event here
        yield f'data: {{"error": "{str(e)}"}}\n\n'

    # Returns a streaming response with the proper media type for SSE
    return StreamingResponse(
        event_generator(),
        media_type="text/event-stream",
    )

  @app.get(
      "/apps/{app_name}/users/{user_id}/sessions/{session_id}/events/{event_id}/graph",
      response_model_exclude_none=True,
  )
  async def get_event_graph(
      app_name: str, user_id: str, session_id: str, event_id: str
  ):
    # Connect to managed session if agent_engine_id is set.
    app_id = agent_engine_id if agent_engine_id else app_name
    session = session_service.get_session(
        app_name=app_id, user_id=user_id, session_id=session_id
    )
    session_events = session.events if session else []
    event = next((x for x in session_events if x.id == event_id), None)
    if not event:
      return {}

    from . import agent_graph

    function_calls = event.get_function_calls()
    function_responses = event.get_function_responses()
    root_agent = await _get_root_agent_async(app_name)
    dot_graph = None
    if function_calls:
      function_call_highlights = []
      for function_call in function_calls:
        from_name = event.author
        to_name = function_call.name
        function_call_highlights.append((from_name, to_name))
        dot_graph = agent_graph.get_agent_graph(
            root_agent, function_call_highlights
        )
    elif function_responses:
      function_responses_highlights = []
      for function_response in function_responses:
        from_name = function_response.name
        to_name = event.author
        function_responses_highlights.append((from_name, to_name))
        dot_graph = agent_graph.get_agent_graph(
            root_agent, function_responses_highlights
        )
    else:
      from_name = event.author
      to_name = ""
      dot_graph = agent_graph.get_agent_graph(
          root_agent, [(from_name, to_name)]
      )
    if dot_graph and isinstance(dot_graph, graphviz.Digraph):
      return {"dot_src": dot_graph.source}
    else:
      return {}

  @app.websocket("/run_live")
  async def agent_live_run(
      websocket: WebSocket,
      app_name: str,
      user_id: str,
      session_id: str,
      modalities: List[Literal["TEXT", "AUDIO"]] = Query(
          default=["TEXT", "AUDIO"]
      ),  # Only allows "TEXT" or "AUDIO"
  ) -> None:
    await websocket.accept()

    # Connect to managed session if agent_engine_id is set.
    app_id = agent_engine_id if agent_engine_id else app_name
    session = session_service.get_session(
        app_name=app_id, user_id=user_id, session_id=session_id
    )
    if not session:
      # Accept first so that the client is aware of connection establishment,
      # then close with a specific code.
      await websocket.close(code=1002, reason="Session not found")
      return

    live_request_queue = LiveRequestQueue()

    async def forward_events():
      runner = await _get_runner_async(app_name)
      async for event in runner.run_live(
          session=session, live_request_queue=live_request_queue
      ):
        await websocket.send_text(
            event.model_dump_json(exclude_none=True, by_alias=True)
        )

    async def process_messages():
      try:
        while True:
          data = await websocket.receive_text()
          # Validate and send the received message to the live queue.
          live_request_queue.send(LiveRequest.model_validate_json(data))
      except ValidationError as ve:
        logger.error("Validation error in process_messages: %s", ve)

    # Run both tasks concurrently and cancel all if one fails.
    tasks = [
        asyncio.create_task(forward_events()),
        asyncio.create_task(process_messages()),
    ]
    done, pending = await asyncio.wait(
        tasks, return_when=asyncio.FIRST_EXCEPTION
    )
    try:
      # This will re-raise any exception from the completed tasks.
      for task in done:
        task.result()
    except WebSocketDisconnect:
      logger.info("Client disconnected during process_messages.")
    except Exception as e:
      logger.exception("Error during live websocket communication: %s", e)
      traceback.print_exc()
      WEBSOCKET_INTERNAL_ERROR_CODE = 1011
      WEBSOCKET_MAX_BYTES_FOR_REASON = 123
      await websocket.close(
          code=WEBSOCKET_INTERNAL_ERROR_CODE,
          reason=str(e)[:WEBSOCKET_MAX_BYTES_FOR_REASON],
      )
    finally:
      for task in pending:
        task.cancel()

  async def _get_root_agent_async(app_name: str) -> Agent:
    """Returns the root agent for the given app."""
    if app_name in root_agent_dict:
      return root_agent_dict[app_name]
    agent_module = importlib.import_module(app_name)
    if getattr(agent_module.agent, "root_agent"):
      root_agent = agent_module.agent.root_agent
    else:
      raise ValueError(f'Unable to find "root_agent" from {app_name}.')

    # Handle an awaitable root agent and await for the actual agent.
    if inspect.isawaitable(root_agent):
      try:
        agent, exit_stack = await root_agent
        exit_stacks.append(exit_stack)
        root_agent = agent
      except Exception as e:
        raise RuntimeError(f"error getting root agent, {e}") from e

    root_agent_dict[app_name] = root_agent
    return root_agent

  async def _get_runner_async(app_name: str) -> Runner:
    """Returns the runner for the given app."""
    envs.load_dotenv_for_agent(os.path.basename(app_name), agent_dir)
    if app_name in runner_dict:
      return runner_dict[app_name]
    root_agent = await _get_root_agent_async(app_name)
    runner = Runner(
        app_name=agent_engine_id if agent_engine_id else app_name,
        agent=root_agent,
        artifact_service=artifact_service,
        session_service=session_service,
        memory_service=memory_service,
    )
    runner_dict[app_name] = runner
    return runner

  if web:
    BASE_DIR = Path(__file__).parent.resolve()
    ANGULAR_DIST_PATH = BASE_DIR / "browser"

    @app.get("/")
    async def redirect_to_dev_ui():
      return RedirectResponse("/dev-ui")

    @app.get("/dev-ui")
    async def dev_ui():
      return FileResponse(BASE_DIR / "browser/index.html")

    app.mount(
        "/", StaticFiles(directory=ANGULAR_DIST_PATH, html=True), name="static"
    )
  return app
```
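
The deleted `.orig` file above shows the shape of the `get_fast_api_app` factory that the ADK CLI builds its dev server on. As a minimal sketch (not part of the package, and assuming the 1.1.0 release still exports `get_fast_api_app` from `google.adk.cli.fast_api` with compatible keyword arguments; the `./agents` path and port are placeholders), serving it directly with uvicorn could look roughly like this:

```python
# Hypothetical usage sketch; keyword arguments mirror the 1.0.0 signature shown above.
import uvicorn

from google.adk.cli.fast_api import get_fast_api_app

app = get_fast_api_app(
    agent_dir="./agents",  # directory containing one package per agent
    session_db_url="",     # empty string -> in-memory session service
    allow_origins=["*"],   # CORS origins for the dev UI
    web=True,              # also serve the bundled browser UI
)

if __name__ == "__main__":
  uvicorn.run(app, host="127.0.0.1", port=8000)
```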