google-adk 0.5.0__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- google/adk/agents/base_agent.py +76 -30
- google/adk/agents/callback_context.py +2 -6
- google/adk/agents/llm_agent.py +122 -30
- google/adk/agents/loop_agent.py +1 -1
- google/adk/agents/parallel_agent.py +7 -0
- google/adk/agents/readonly_context.py +8 -0
- google/adk/agents/run_config.py +1 -1
- google/adk/agents/sequential_agent.py +31 -0
- google/adk/agents/transcription_entry.py +4 -2
- google/adk/artifacts/gcs_artifact_service.py +1 -1
- google/adk/artifacts/in_memory_artifact_service.py +1 -1
- google/adk/auth/auth_credential.py +10 -2
- google/adk/auth/auth_preprocessor.py +7 -1
- google/adk/auth/auth_tool.py +3 -4
- google/adk/cli/agent_graph.py +5 -5
- google/adk/cli/browser/index.html +4 -4
- google/adk/cli/browser/{main-ULN5R5I5.js → main-PKDNKWJE.js} +59 -60
- google/adk/cli/browser/polyfills-B6TNHZQ6.js +17 -0
- google/adk/cli/cli.py +10 -9
- google/adk/cli/cli_deploy.py +7 -2
- google/adk/cli/cli_eval.py +109 -115
- google/adk/cli/cli_tools_click.py +179 -67
- google/adk/cli/fast_api.py +248 -197
- google/adk/cli/utils/agent_loader.py +137 -0
- google/adk/cli/utils/cleanup.py +40 -0
- google/adk/cli/utils/common.py +23 -0
- google/adk/cli/utils/evals.py +83 -0
- google/adk/cli/utils/logs.py +8 -5
- google/adk/code_executors/__init__.py +3 -1
- google/adk/code_executors/built_in_code_executor.py +52 -0
- google/adk/code_executors/code_execution_utils.py +2 -1
- google/adk/code_executors/container_code_executor.py +0 -1
- google/adk/code_executors/vertex_ai_code_executor.py +6 -8
- google/adk/evaluation/__init__.py +1 -1
- google/adk/evaluation/agent_evaluator.py +168 -128
- google/adk/evaluation/eval_case.py +104 -0
- google/adk/evaluation/eval_metrics.py +74 -0
- google/adk/evaluation/eval_result.py +86 -0
- google/adk/evaluation/eval_set.py +39 -0
- google/adk/evaluation/eval_set_results_manager.py +47 -0
- google/adk/evaluation/eval_sets_manager.py +43 -0
- google/adk/evaluation/evaluation_generator.py +88 -113
- google/adk/evaluation/evaluator.py +58 -0
- google/adk/evaluation/local_eval_set_results_manager.py +113 -0
- google/adk/evaluation/local_eval_sets_manager.py +264 -0
- google/adk/evaluation/response_evaluator.py +106 -1
- google/adk/evaluation/trajectory_evaluator.py +84 -2
- google/adk/events/event.py +6 -1
- google/adk/events/event_actions.py +6 -1
- google/adk/examples/base_example_provider.py +1 -0
- google/adk/examples/example_util.py +3 -2
- google/adk/flows/llm_flows/_code_execution.py +9 -1
- google/adk/flows/llm_flows/audio_transcriber.py +4 -3
- google/adk/flows/llm_flows/base_llm_flow.py +58 -21
- google/adk/flows/llm_flows/contents.py +3 -1
- google/adk/flows/llm_flows/functions.py +9 -8
- google/adk/flows/llm_flows/instructions.py +18 -80
- google/adk/flows/llm_flows/single_flow.py +2 -2
- google/adk/memory/__init__.py +1 -1
- google/adk/memory/_utils.py +23 -0
- google/adk/memory/base_memory_service.py +23 -21
- google/adk/memory/in_memory_memory_service.py +57 -25
- google/adk/memory/memory_entry.py +37 -0
- google/adk/memory/vertex_ai_rag_memory_service.py +38 -15
- google/adk/models/anthropic_llm.py +16 -9
- google/adk/models/base_llm.py +2 -1
- google/adk/models/base_llm_connection.py +2 -0
- google/adk/models/gemini_llm_connection.py +11 -11
- google/adk/models/google_llm.py +12 -2
- google/adk/models/lite_llm.py +80 -23
- google/adk/models/llm_response.py +16 -3
- google/adk/models/registry.py +1 -1
- google/adk/runners.py +98 -42
- google/adk/sessions/__init__.py +1 -1
- google/adk/sessions/_session_util.py +2 -1
- google/adk/sessions/base_session_service.py +6 -33
- google/adk/sessions/database_session_service.py +57 -67
- google/adk/sessions/in_memory_session_service.py +106 -24
- google/adk/sessions/session.py +3 -0
- google/adk/sessions/vertex_ai_session_service.py +44 -51
- google/adk/telemetry.py +7 -2
- google/adk/tools/__init__.py +4 -7
- google/adk/tools/_memory_entry_utils.py +30 -0
- google/adk/tools/agent_tool.py +10 -10
- google/adk/tools/apihub_tool/apihub_toolset.py +55 -74
- google/adk/tools/apihub_tool/clients/apihub_client.py +10 -3
- google/adk/tools/apihub_tool/clients/secret_client.py +1 -0
- google/adk/tools/application_integration_tool/application_integration_toolset.py +111 -85
- google/adk/tools/application_integration_tool/clients/connections_client.py +28 -1
- google/adk/tools/application_integration_tool/clients/integration_client.py +7 -5
- google/adk/tools/application_integration_tool/integration_connector_tool.py +69 -26
- google/adk/tools/base_toolset.py +96 -0
- google/adk/tools/bigquery/__init__.py +28 -0
- google/adk/tools/bigquery/bigquery_credentials.py +216 -0
- google/adk/tools/bigquery/bigquery_tool.py +116 -0
- google/adk/tools/{built_in_code_execution_tool.py → enterprise_search_tool.py} +17 -11
- google/adk/tools/function_parameter_parse_util.py +9 -2
- google/adk/tools/function_tool.py +33 -3
- google/adk/tools/get_user_choice_tool.py +1 -0
- google/adk/tools/google_api_tool/__init__.py +24 -70
- google/adk/tools/google_api_tool/google_api_tool.py +12 -6
- google/adk/tools/google_api_tool/{google_api_tool_set.py → google_api_toolset.py} +57 -55
- google/adk/tools/google_api_tool/google_api_toolsets.py +108 -0
- google/adk/tools/google_api_tool/googleapi_to_openapi_converter.py +40 -42
- google/adk/tools/google_search_tool.py +2 -2
- google/adk/tools/langchain_tool.py +96 -49
- google/adk/tools/load_memory_tool.py +14 -5
- google/adk/tools/mcp_tool/__init__.py +3 -2
- google/adk/tools/mcp_tool/conversion_utils.py +6 -2
- google/adk/tools/mcp_tool/mcp_session_manager.py +80 -69
- google/adk/tools/mcp_tool/mcp_tool.py +35 -32
- google/adk/tools/mcp_tool/mcp_toolset.py +99 -194
- google/adk/tools/openapi_tool/auth/credential_exchangers/base_credential_exchanger.py +1 -3
- google/adk/tools/openapi_tool/auth/credential_exchangers/service_account_exchanger.py +6 -7
- google/adk/tools/openapi_tool/common/common.py +5 -1
- google/adk/tools/openapi_tool/openapi_spec_parser/__init__.py +7 -2
- google/adk/tools/openapi_tool/openapi_spec_parser/openapi_toolset.py +27 -7
- google/adk/tools/openapi_tool/openapi_spec_parser/operation_parser.py +36 -32
- google/adk/tools/openapi_tool/openapi_spec_parser/rest_api_tool.py +11 -1
- google/adk/tools/openapi_tool/openapi_spec_parser/tool_auth_handler.py +1 -1
- google/adk/tools/preload_memory_tool.py +27 -18
- google/adk/tools/retrieval/__init__.py +1 -1
- google/adk/tools/retrieval/vertex_ai_rag_retrieval.py +1 -1
- google/adk/tools/toolbox_toolset.py +107 -0
- google/adk/tools/transfer_to_agent_tool.py +0 -1
- google/adk/utils/__init__.py +13 -0
- google/adk/utils/instructions_utils.py +131 -0
- google/adk/version.py +1 -1
- {google_adk-0.5.0.dist-info → google_adk-1.1.0.dist-info}/METADATA +18 -19
- google_adk-1.1.0.dist-info/RECORD +200 -0
- google/adk/agents/remote_agent.py +0 -50
- google/adk/cli/browser/polyfills-FFHMD2TL.js +0 -18
- google/adk/cli/fast_api.py.orig +0 -728
- google/adk/tools/google_api_tool/google_api_tool_sets.py +0 -112
- google/adk/tools/toolbox_tool.py +0 -46
- google_adk-0.5.0.dist-info/RECORD +0 -180
- {google_adk-0.5.0.dist-info → google_adk-1.1.0.dist-info}/WHEEL +0 -0
- {google_adk-0.5.0.dist-info → google_adk-1.1.0.dist-info}/entry_points.txt +0 -0
- {google_adk-0.5.0.dist-info → google_adk-1.1.0.dist-info}/licenses/LICENSE +0 -0
google/adk/cli/fast_api.py
CHANGED
@@ -12,16 +12,15 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
+
|
16
|
+
from __future__ import annotations
|
17
|
+
|
15
18
|
import asyncio
|
16
19
|
from contextlib import asynccontextmanager
|
17
|
-
import importlib
|
18
|
-
import inspect
|
19
|
-
import json
|
20
20
|
import logging
|
21
21
|
import os
|
22
22
|
from pathlib import Path
|
23
|
-
import
|
24
|
-
import sys
|
23
|
+
import time
|
25
24
|
import traceback
|
26
25
|
import typing
|
27
26
|
from typing import Any
|
@@ -30,7 +29,6 @@ from typing import Literal
|
|
30
29
|
from typing import Optional
|
31
30
|
|
32
31
|
import click
|
33
|
-
from click import Tuple
|
34
32
|
from fastapi import FastAPI
|
35
33
|
from fastapi import HTTPException
|
36
34
|
from fastapi import Query
|
@@ -48,16 +46,25 @@ from opentelemetry.exporter.cloud_trace import CloudTraceSpanExporter
|
|
48
46
|
from opentelemetry.sdk.trace import export
|
49
47
|
from opentelemetry.sdk.trace import ReadableSpan
|
50
48
|
from opentelemetry.sdk.trace import TracerProvider
|
51
|
-
from pydantic import
|
49
|
+
from pydantic import Field
|
52
50
|
from pydantic import ValidationError
|
53
51
|
from starlette.types import Lifespan
|
52
|
+
from typing_extensions import override
|
54
53
|
|
55
54
|
from ..agents import RunConfig
|
56
55
|
from ..agents.live_request_queue import LiveRequest
|
57
56
|
from ..agents.live_request_queue import LiveRequestQueue
|
58
57
|
from ..agents.llm_agent import Agent
|
59
58
|
from ..agents.run_config import StreamingMode
|
60
|
-
from ..artifacts import InMemoryArtifactService
|
59
|
+
from ..artifacts.in_memory_artifact_service import InMemoryArtifactService
|
60
|
+
from ..evaluation.eval_case import EvalCase
|
61
|
+
from ..evaluation.eval_case import SessionInput
|
62
|
+
from ..evaluation.eval_metrics import EvalMetric
|
63
|
+
from ..evaluation.eval_metrics import EvalMetricResult
|
64
|
+
from ..evaluation.eval_metrics import EvalMetricResultPerInvocation
|
65
|
+
from ..evaluation.eval_result import EvalSetResult
|
66
|
+
from ..evaluation.local_eval_set_results_manager import LocalEvalSetResultsManager
|
67
|
+
from ..evaluation.local_eval_sets_manager import LocalEvalSetsManager
|
61
68
|
from ..events.event import Event
|
62
69
|
from ..memory.in_memory_memory_service import InMemoryMemoryService
|
63
70
|
from ..runners import Runner
|
@@ -66,14 +73,15 @@ from ..sessions.in_memory_session_service import InMemorySessionService
|
|
66
73
|
from ..sessions.session import Session
|
67
74
|
from ..sessions.vertex_ai_session_service import VertexAiSessionService
|
68
75
|
from .cli_eval import EVAL_SESSION_ID_PREFIX
|
69
|
-
from .cli_eval import EvalMetric
|
70
|
-
from .cli_eval import EvalMetricResult
|
71
76
|
from .cli_eval import EvalStatus
|
77
|
+
from .utils import cleanup
|
78
|
+
from .utils import common
|
72
79
|
from .utils import create_empty_state
|
73
80
|
from .utils import envs
|
74
81
|
from .utils import evals
|
82
|
+
from .utils.agent_loader import AgentLoader
|
75
83
|
|
76
|
-
logger = logging.getLogger(__name__)
|
84
|
+
logger = logging.getLogger("google_adk." + __name__)
|
77
85
|
|
78
86
|
_EVAL_SET_FILE_EXTENSION = ".evalset.json"
|
79
87
|
|
@@ -103,7 +111,45 @@ class ApiServerSpanExporter(export.SpanExporter):
|
|
103
111
|
return True
|
104
112
|
|
105
113
|
|
106
|
-
class
|
114
|
+
class InMemoryExporter(export.SpanExporter):
|
115
|
+
|
116
|
+
def __init__(self, trace_dict):
|
117
|
+
super().__init__()
|
118
|
+
self._spans = []
|
119
|
+
self.trace_dict = trace_dict
|
120
|
+
|
121
|
+
@override
|
122
|
+
def export(
|
123
|
+
self, spans: typing.Sequence[ReadableSpan]
|
124
|
+
) -> export.SpanExportResult:
|
125
|
+
for span in spans:
|
126
|
+
trace_id = span.context.trace_id
|
127
|
+
if span.name == "call_llm":
|
128
|
+
attributes = dict(span.attributes)
|
129
|
+
session_id = attributes.get("gcp.vertex.agent.session_id", None)
|
130
|
+
if session_id:
|
131
|
+
if session_id not in self.trace_dict:
|
132
|
+
self.trace_dict[session_id] = [trace_id]
|
133
|
+
else:
|
134
|
+
self.trace_dict[session_id] += [trace_id]
|
135
|
+
self._spans.extend(spans)
|
136
|
+
return export.SpanExportResult.SUCCESS
|
137
|
+
|
138
|
+
@override
|
139
|
+
def force_flush(self, timeout_millis: int = 30000) -> bool:
|
140
|
+
return True
|
141
|
+
|
142
|
+
def get_finished_spans(self, session_id: str):
|
143
|
+
trace_ids = self.trace_dict.get(session_id, None)
|
144
|
+
if trace_ids is None or not trace_ids:
|
145
|
+
return []
|
146
|
+
return [x for x in self._spans if x.context.trace_id in trace_ids]
|
147
|
+
|
148
|
+
def clear(self):
|
149
|
+
self._spans.clear()
|
150
|
+
|
151
|
+
|
152
|
+
class AgentRunRequest(common.BaseModel):
|
107
153
|
app_name: str
|
108
154
|
user_id: str
|
109
155
|
session_id: str
|
@@ -111,28 +157,41 @@ class AgentRunRequest(BaseModel):
|
|
111
157
|
streaming: bool = False
|
112
158
|
|
113
159
|
|
114
|
-
class AddSessionToEvalSetRequest(BaseModel):
|
160
|
+
class AddSessionToEvalSetRequest(common.BaseModel):
|
115
161
|
eval_id: str
|
116
162
|
session_id: str
|
117
163
|
user_id: str
|
118
164
|
|
119
165
|
|
120
|
-
class RunEvalRequest(BaseModel):
|
166
|
+
class RunEvalRequest(common.BaseModel):
|
121
167
|
eval_ids: list[str] # if empty, then all evals in the eval set are run.
|
122
168
|
eval_metrics: list[EvalMetric]
|
123
169
|
|
124
170
|
|
125
|
-
class RunEvalResult(BaseModel):
|
171
|
+
class RunEvalResult(common.BaseModel):
|
172
|
+
eval_set_file: str
|
126
173
|
eval_set_id: str
|
127
174
|
eval_id: str
|
128
175
|
final_eval_status: EvalStatus
|
129
|
-
eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]]
|
176
|
+
eval_metric_results: list[tuple[EvalMetric, EvalMetricResult]] = Field(
|
177
|
+
deprecated=True,
|
178
|
+
description=(
|
179
|
+
"This field is deprecated, use overall_eval_metric_results instead."
|
180
|
+
),
|
181
|
+
)
|
182
|
+
overall_eval_metric_results: list[EvalMetricResult]
|
183
|
+
eval_metric_result_per_invocation: list[EvalMetricResultPerInvocation]
|
184
|
+
user_id: str
|
130
185
|
session_id: str
|
131
186
|
|
132
187
|
|
188
|
+
class GetEventGraphResult(common.BaseModel):
|
189
|
+
dot_src: str
|
190
|
+
|
191
|
+
|
133
192
|
def get_fast_api_app(
|
134
193
|
*,
|
135
|
-
|
194
|
+
agents_dir: str,
|
136
195
|
session_db_url: str = "",
|
137
196
|
allow_origins: Optional[list[str]] = None,
|
138
197
|
web: bool,
|
@@ -141,40 +200,42 @@ def get_fast_api_app(
|
|
141
200
|
) -> FastAPI:
|
142
201
|
# InMemory tracing dict.
|
143
202
|
trace_dict: dict[str, Any] = {}
|
203
|
+
session_trace_dict: dict[str, Any] = {}
|
144
204
|
|
145
205
|
# Set up tracing in the FastAPI server.
|
146
206
|
provider = TracerProvider()
|
147
207
|
provider.add_span_processor(
|
148
208
|
export.SimpleSpanProcessor(ApiServerSpanExporter(trace_dict))
|
149
209
|
)
|
210
|
+
memory_exporter = InMemoryExporter(session_trace_dict)
|
211
|
+
provider.add_span_processor(export.SimpleSpanProcessor(memory_exporter))
|
150
212
|
if trace_to_cloud:
|
151
|
-
envs.load_dotenv_for_agent("",
|
213
|
+
envs.load_dotenv_for_agent("", agents_dir)
|
152
214
|
if project_id := os.environ.get("GOOGLE_CLOUD_PROJECT", None):
|
153
215
|
processor = export.BatchSpanProcessor(
|
154
216
|
CloudTraceSpanExporter(project_id=project_id)
|
155
217
|
)
|
156
218
|
provider.add_span_processor(processor)
|
157
219
|
else:
|
158
|
-
|
220
|
+
logger.warning(
|
159
221
|
"GOOGLE_CLOUD_PROJECT environment variable is not set. Tracing will"
|
160
222
|
" not be enabled."
|
161
223
|
)
|
162
224
|
|
163
225
|
trace.set_tracer_provider(provider)
|
164
226
|
|
165
|
-
exit_stacks = []
|
166
|
-
|
167
227
|
@asynccontextmanager
|
168
228
|
async def internal_lifespan(app: FastAPI):
|
169
|
-
if lifespan:
|
170
|
-
async with lifespan(app) as lifespan_context:
|
171
|
-
yield
|
172
229
|
|
173
|
-
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
230
|
+
try:
|
231
|
+
if lifespan:
|
232
|
+
async with lifespan(app) as lifespan_context:
|
233
|
+
yield lifespan_context
|
234
|
+
else:
|
235
|
+
yield
|
236
|
+
finally:
|
237
|
+
# Create tasks for all runner closures to run concurrently
|
238
|
+
await cleanup.close_runners(list(runner_dict.values()))
|
178
239
|
|
179
240
|
# Run the FastAPI server.
|
180
241
|
app = FastAPI(lifespan=internal_lifespan)
|
@@ -188,16 +249,15 @@ def get_fast_api_app(
|
|
188
249
|
allow_headers=["*"],
|
189
250
|
)
|
190
251
|
|
191
|
-
if agent_dir not in sys.path:
|
192
|
-
sys.path.append(agent_dir)
|
193
|
-
|
194
252
|
runner_dict = {}
|
195
|
-
root_agent_dict = {}
|
196
253
|
|
197
254
|
# Build the Artifact service
|
198
255
|
artifact_service = InMemoryArtifactService()
|
199
256
|
memory_service = InMemoryMemoryService()
|
200
257
|
|
258
|
+
eval_sets_manager = LocalEvalSetsManager(agents_dir=agents_dir)
|
259
|
+
eval_set_results_manager = LocalEvalSetResultsManager(agents_dir=agents_dir)
|
260
|
+
|
201
261
|
# Build the Session service
|
202
262
|
agent_engine_id = ""
|
203
263
|
if session_db_url:
|
@@ -206,7 +266,7 @@ def get_fast_api_app(
|
|
206
266
|
agent_engine_id = session_db_url.split("://")[1]
|
207
267
|
if not agent_engine_id:
|
208
268
|
raise click.ClickException("Agent engine id can not be empty.")
|
209
|
-
envs.load_dotenv_for_agent("",
|
269
|
+
envs.load_dotenv_for_agent("", agents_dir)
|
210
270
|
session_service = VertexAiSessionService(
|
211
271
|
os.environ["GOOGLE_CLOUD_PROJECT"],
|
212
272
|
os.environ["GOOGLE_CLOUD_LOCATION"],
|
@@ -216,9 +276,12 @@ def get_fast_api_app(
|
|
216
276
|
else:
|
217
277
|
session_service = InMemorySessionService()
|
218
278
|
|
279
|
+
# initialize Agent Loader
|
280
|
+
agent_loader = AgentLoader(agents_dir)
|
281
|
+
|
219
282
|
@app.get("/list-apps")
|
220
283
|
def list_apps() -> list[str]:
|
221
|
-
base_path = Path.cwd() /
|
284
|
+
base_path = Path.cwd() / agents_dir
|
222
285
|
if not base_path.exists():
|
223
286
|
raise HTTPException(status_code=404, detail="Path not found")
|
224
287
|
if not base_path.is_dir():
|
@@ -240,14 +303,34 @@ def get_fast_api_app(
|
|
240
303
|
raise HTTPException(status_code=404, detail="Trace not found")
|
241
304
|
return event_dict
|
242
305
|
|
306
|
+
@app.get("/debug/trace/session/{session_id}")
|
307
|
+
def get_session_trace(session_id: str) -> Any:
|
308
|
+
spans = memory_exporter.get_finished_spans(session_id)
|
309
|
+
if not spans:
|
310
|
+
return []
|
311
|
+
return [
|
312
|
+
{
|
313
|
+
"name": s.name,
|
314
|
+
"span_id": s.context.span_id,
|
315
|
+
"trace_id": s.context.trace_id,
|
316
|
+
"start_time": s.start_time,
|
317
|
+
"end_time": s.end_time,
|
318
|
+
"attributes": dict(s.attributes),
|
319
|
+
"parent_span_id": s.parent.span_id if s.parent else None,
|
320
|
+
}
|
321
|
+
for s in spans
|
322
|
+
]
|
323
|
+
|
243
324
|
@app.get(
|
244
325
|
"/apps/{app_name}/users/{user_id}/sessions/{session_id}",
|
245
326
|
response_model_exclude_none=True,
|
246
327
|
)
|
247
|
-
def get_session(
|
328
|
+
async def get_session(
|
329
|
+
app_name: str, user_id: str, session_id: str
|
330
|
+
) -> Session:
|
248
331
|
# Connect to managed session if agent_engine_id is set.
|
249
332
|
app_name = agent_engine_id if agent_engine_id else app_name
|
250
|
-
session = session_service.get_session(
|
333
|
+
session = await session_service.get_session(
|
251
334
|
app_name=app_name, user_id=user_id, session_id=session_id
|
252
335
|
)
|
253
336
|
if not session:
|
@@ -258,14 +341,15 @@ def get_fast_api_app(
|
|
258
341
|
"/apps/{app_name}/users/{user_id}/sessions",
|
259
342
|
response_model_exclude_none=True,
|
260
343
|
)
|
261
|
-
def list_sessions(app_name: str, user_id: str) -> list[Session]:
|
344
|
+
async def list_sessions(app_name: str, user_id: str) -> list[Session]:
|
262
345
|
# Connect to managed session if agent_engine_id is set.
|
263
346
|
app_name = agent_engine_id if agent_engine_id else app_name
|
347
|
+
list_sessions_response = await session_service.list_sessions(
|
348
|
+
app_name=app_name, user_id=user_id
|
349
|
+
)
|
264
350
|
return [
|
265
351
|
session
|
266
|
-
for session in
|
267
|
-
app_name=app_name, user_id=user_id
|
268
|
-
).sessions
|
352
|
+
for session in list_sessions_response.sessions
|
269
353
|
# Remove sessions that were generated as a part of Eval.
|
270
354
|
if not session.id.startswith(EVAL_SESSION_ID_PREFIX)
|
271
355
|
]
|
@@ -274,7 +358,7 @@ def get_fast_api_app(
|
|
274
358
|
"/apps/{app_name}/users/{user_id}/sessions/{session_id}",
|
275
359
|
response_model_exclude_none=True,
|
276
360
|
)
|
277
|
-
def create_session_with_id(
|
361
|
+
async def create_session_with_id(
|
278
362
|
app_name: str,
|
279
363
|
user_id: str,
|
280
364
|
session_id: str,
|
@@ -283,7 +367,7 @@ def get_fast_api_app(
|
|
283
367
|
# Connect to managed session if agent_engine_id is set.
|
284
368
|
app_name = agent_engine_id if agent_engine_id else app_name
|
285
369
|
if (
|
286
|
-
session_service.get_session(
|
370
|
+
await session_service.get_session(
|
287
371
|
app_name=app_name, user_id=user_id, session_id=session_id
|
288
372
|
)
|
289
373
|
is not None
|
@@ -292,9 +376,8 @@ def get_fast_api_app(
|
|
292
376
|
raise HTTPException(
|
293
377
|
status_code=400, detail=f"Session already exists: {session_id}"
|
294
378
|
)
|
295
|
-
|
296
379
|
logger.info("New session created: %s", session_id)
|
297
|
-
return session_service.create_session(
|
380
|
+
return await session_service.create_session(
|
298
381
|
app_name=app_name, user_id=user_id, state=state, session_id=session_id
|
299
382
|
)
|
300
383
|
|
@@ -302,22 +385,21 @@ def get_fast_api_app(
|
|
302
385
|
"/apps/{app_name}/users/{user_id}/sessions",
|
303
386
|
response_model_exclude_none=True,
|
304
387
|
)
|
305
|
-
def create_session(
|
388
|
+
async def create_session(
|
306
389
|
app_name: str,
|
307
390
|
user_id: str,
|
308
391
|
state: Optional[dict[str, Any]] = None,
|
309
392
|
) -> Session:
|
310
393
|
# Connect to managed session if agent_engine_id is set.
|
311
394
|
app_name = agent_engine_id if agent_engine_id else app_name
|
312
|
-
|
313
395
|
logger.info("New session created")
|
314
|
-
return session_service.create_session(
|
396
|
+
return await session_service.create_session(
|
315
397
|
app_name=app_name, user_id=user_id, state=state
|
316
398
|
)
|
317
399
|
|
318
|
-
def _get_eval_set_file_path(app_name,
|
400
|
+
def _get_eval_set_file_path(app_name, agents_dir, eval_set_id) -> str:
|
319
401
|
return os.path.join(
|
320
|
-
|
402
|
+
agents_dir,
|
321
403
|
app_name,
|
322
404
|
eval_set_id + _EVAL_SET_FILE_EXTENSION,
|
323
405
|
)
|
@@ -331,28 +413,13 @@ def get_fast_api_app(
|
|
331
413
|
eval_set_id: str,
|
332
414
|
):
|
333
415
|
"""Creates an eval set, given the id."""
|
334
|
-
|
335
|
-
|
416
|
+
try:
|
417
|
+
eval_sets_manager.create_eval_set(app_name, eval_set_id)
|
418
|
+
except ValueError as ve:
|
336
419
|
raise HTTPException(
|
337
420
|
status_code=400,
|
338
|
-
detail=(
|
339
|
-
|
340
|
-
" format"
|
341
|
-
),
|
342
|
-
)
|
343
|
-
# Define the file path
|
344
|
-
new_eval_set_path = _get_eval_set_file_path(
|
345
|
-
app_name, agent_dir, eval_set_id
|
346
|
-
)
|
347
|
-
|
348
|
-
logger.info("Creating eval set file `%s`", new_eval_set_path)
|
349
|
-
|
350
|
-
if not os.path.exists(new_eval_set_path):
|
351
|
-
# Write the JSON string to the file
|
352
|
-
logger.info("Eval set file doesn't exist, we will create a new one.")
|
353
|
-
with open(new_eval_set_path, "w") as f:
|
354
|
-
empty_content = json.dumps([], indent=2)
|
355
|
-
f.write(empty_content)
|
421
|
+
detail=str(ve),
|
422
|
+
) from ve
|
356
423
|
|
357
424
|
@app.get(
|
358
425
|
"/apps/{app_name}/eval_sets",
|
@@ -360,15 +427,7 @@ def get_fast_api_app(
|
|
360
427
|
)
|
361
428
|
def list_eval_sets(app_name: str) -> list[str]:
|
362
429
|
"""Lists all eval sets for the given app."""
|
363
|
-
|
364
|
-
eval_sets = []
|
365
|
-
for file in os.listdir(eval_set_file_path):
|
366
|
-
if file.endswith(_EVAL_SET_FILE_EXTENSION):
|
367
|
-
eval_sets.append(
|
368
|
-
os.path.basename(file).removesuffix(_EVAL_SET_FILE_EXTENSION)
|
369
|
-
)
|
370
|
-
|
371
|
-
return sorted(eval_sets)
|
430
|
+
return eval_sets_manager.list_eval_sets(app_name)
|
372
431
|
|
373
432
|
@app.post(
|
374
433
|
"/apps/{app_name}/eval_sets/{eval_set_id}/add_session",
|
@@ -377,54 +436,33 @@ def get_fast_api_app(
|
|
377
436
|
async def add_session_to_eval_set(
|
378
437
|
app_name: str, eval_set_id: str, req: AddSessionToEvalSetRequest
|
379
438
|
):
|
380
|
-
pattern = r"^[a-zA-Z0-9_]+$"
|
381
|
-
if not bool(re.fullmatch(pattern, req.eval_id)):
|
382
|
-
raise HTTPException(
|
383
|
-
status_code=400,
|
384
|
-
detail=f"Invalid eval id. Eval id should have the `{pattern}` format",
|
385
|
-
)
|
386
|
-
|
387
439
|
# Get the session
|
388
|
-
session = session_service.get_session(
|
440
|
+
session = await session_service.get_session(
|
389
441
|
app_name=app_name, user_id=req.user_id, session_id=req.session_id
|
390
442
|
)
|
391
443
|
assert session, "Session not found."
|
392
|
-
# Load the eval set file data
|
393
|
-
eval_set_file_path = _get_eval_set_file_path(
|
394
|
-
app_name, agent_dir, eval_set_id
|
395
|
-
)
|
396
|
-
with open(eval_set_file_path, "r") as file:
|
397
|
-
eval_set_data = json.load(file) # Load JSON into a list
|
398
|
-
|
399
|
-
if [x for x in eval_set_data if x["name"] == req.eval_id]:
|
400
|
-
raise HTTPException(
|
401
|
-
status_code=400,
|
402
|
-
detail=(
|
403
|
-
f"Eval id `{req.eval_id}` already exists in `{eval_set_id}`"
|
404
|
-
" eval set."
|
405
|
-
),
|
406
|
-
)
|
407
444
|
|
408
|
-
# Convert the session data to
|
409
|
-
|
445
|
+
# Convert the session data to eval invocations
|
446
|
+
invocations = evals.convert_session_to_eval_invocations(session)
|
410
447
|
|
411
448
|
# Populate the session with initial session state.
|
412
449
|
initial_session_state = create_empty_state(
|
413
|
-
|
450
|
+
agent_loader.load_agent(app_name)
|
451
|
+
)
|
452
|
+
|
453
|
+
new_eval_case = EvalCase(
|
454
|
+
eval_id=req.eval_id,
|
455
|
+
conversation=invocations,
|
456
|
+
session_input=SessionInput(
|
457
|
+
app_name=app_name, user_id=req.user_id, state=initial_session_state
|
458
|
+
),
|
459
|
+
creation_timestamp=time.time(),
|
414
460
|
)
|
415
461
|
|
416
|
-
|
417
|
-
|
418
|
-
|
419
|
-
|
420
|
-
"state": initial_session_state,
|
421
|
-
"app_name": app_name,
|
422
|
-
"user_id": req.user_id,
|
423
|
-
},
|
424
|
-
})
|
425
|
-
# Serialize the test data to JSON and write to the eval set file.
|
426
|
-
with open(eval_set_file_path, "w") as f:
|
427
|
-
f.write(json.dumps(eval_set_data, indent=2))
|
462
|
+
try:
|
463
|
+
eval_sets_manager.add_eval_case(app_name, eval_set_id, new_eval_case)
|
464
|
+
except ValueError as ve:
|
465
|
+
raise HTTPException(status_code=400, detail=str(ve)) from ve
|
428
466
|
|
429
467
|
@app.get(
|
430
468
|
"/apps/{app_name}/eval_sets/{eval_set_id}/evals",
|
@@ -435,14 +473,9 @@ def get_fast_api_app(
|
|
435
473
|
eval_set_id: str,
|
436
474
|
) -> list[str]:
|
437
475
|
"""Lists all evals in an eval set."""
|
438
|
-
|
439
|
-
eval_set_file_path = _get_eval_set_file_path(
|
440
|
-
app_name, agent_dir, eval_set_id
|
441
|
-
)
|
442
|
-
with open(eval_set_file_path, "r") as file:
|
443
|
-
eval_set_data = json.load(file) # Load JSON into a list
|
476
|
+
eval_set_data = eval_sets_manager.get_eval_set(app_name, eval_set_id)
|
444
477
|
|
445
|
-
return sorted([x
|
478
|
+
return sorted([x.eval_id for x in eval_set_data.eval_cases])
|
446
479
|
|
447
480
|
@app.post(
|
448
481
|
"/apps/{app_name}/eval_sets/{eval_set_id}/run_eval",
|
@@ -451,51 +484,89 @@ def get_fast_api_app(
|
|
451
484
|
async def run_eval(
|
452
485
|
app_name: str, eval_set_id: str, req: RunEvalRequest
|
453
486
|
) -> list[RunEvalResult]:
|
487
|
+
"""Runs an eval given the details in the eval request."""
|
454
488
|
from .cli_eval import run_evals
|
455
489
|
|
456
|
-
"""Runs an eval given the details in the eval request."""
|
457
490
|
# Create a mapping from eval set file to all the evals that needed to be
|
458
491
|
# run.
|
459
|
-
|
460
|
-
app_name, agent_dir, eval_set_id
|
461
|
-
)
|
462
|
-
eval_set_to_evals = {eval_set_file_path: req.eval_ids}
|
492
|
+
eval_set = eval_sets_manager.get_eval_set(app_name, eval_set_id)
|
463
493
|
|
464
|
-
if
|
465
|
-
|
466
|
-
|
467
|
-
|
468
|
-
|
469
|
-
|
470
|
-
await run_evals(
|
471
|
-
eval_set_to_evals,
|
472
|
-
root_agent,
|
473
|
-
getattr(root_agent, "reset_data", None),
|
474
|
-
req.eval_metrics,
|
475
|
-
session_service=session_service,
|
476
|
-
artifact_service=artifact_service,
|
477
|
-
)
|
478
|
-
)
|
494
|
+
if req.eval_ids:
|
495
|
+
eval_cases = [e for e in eval_set.eval_cases if e.eval_id in req.eval_ids]
|
496
|
+
eval_set_to_evals = {eval_set_id: eval_cases}
|
497
|
+
else:
|
498
|
+
logger.info("Eval ids to run list is empty. We will run all eval cases.")
|
499
|
+
eval_set_to_evals = {eval_set_id: eval_set.eval_cases}
|
479
500
|
|
501
|
+
root_agent = agent_loader.load_agent(app_name)
|
480
502
|
run_eval_results = []
|
481
|
-
|
503
|
+
eval_case_results = []
|
504
|
+
async for eval_case_result in run_evals(
|
505
|
+
eval_set_to_evals,
|
506
|
+
root_agent,
|
507
|
+
getattr(root_agent, "reset_data", None),
|
508
|
+
req.eval_metrics,
|
509
|
+
session_service=session_service,
|
510
|
+
artifact_service=artifact_service,
|
511
|
+
):
|
482
512
|
run_eval_results.append(
|
483
513
|
RunEvalResult(
|
484
514
|
app_name=app_name,
|
515
|
+
eval_set_file=eval_case_result.eval_set_file,
|
485
516
|
eval_set_id=eval_set_id,
|
486
|
-
eval_id=
|
487
|
-
final_eval_status=
|
488
|
-
eval_metric_results=
|
489
|
-
|
517
|
+
eval_id=eval_case_result.eval_id,
|
518
|
+
final_eval_status=eval_case_result.final_eval_status,
|
519
|
+
eval_metric_results=eval_case_result.eval_metric_results,
|
520
|
+
overall_eval_metric_results=eval_case_result.overall_eval_metric_results,
|
521
|
+
eval_metric_result_per_invocation=eval_case_result.eval_metric_result_per_invocation,
|
522
|
+
user_id=eval_case_result.user_id,
|
523
|
+
session_id=eval_case_result.session_id,
|
490
524
|
)
|
491
525
|
)
|
526
|
+
eval_case_result.session_details = await session_service.get_session(
|
527
|
+
app_name=app_name,
|
528
|
+
user_id=eval_case_result.user_id,
|
529
|
+
session_id=eval_case_result.session_id,
|
530
|
+
)
|
531
|
+
eval_case_results.append(eval_case_result)
|
532
|
+
|
533
|
+
eval_set_results_manager.save_eval_set_result(
|
534
|
+
app_name, eval_set_id, eval_case_results
|
535
|
+
)
|
536
|
+
|
492
537
|
return run_eval_results
|
493
538
|
|
539
|
+
@app.get(
|
540
|
+
"/apps/{app_name}/eval_results/{eval_result_id}",
|
541
|
+
response_model_exclude_none=True,
|
542
|
+
)
|
543
|
+
def get_eval_result(
|
544
|
+
app_name: str,
|
545
|
+
eval_result_id: str,
|
546
|
+
) -> EvalSetResult:
|
547
|
+
"""Gets the eval result for the given eval id."""
|
548
|
+
try:
|
549
|
+
return eval_set_results_manager.get_eval_set_result(
|
550
|
+
app_name, eval_result_id
|
551
|
+
)
|
552
|
+
except ValueError as ve:
|
553
|
+
raise HTTPException(status_code=404, detail=str(ve)) from ve
|
554
|
+
except ValidationError as ve:
|
555
|
+
raise HTTPException(status_code=500, detail=str(ve)) from ve
|
556
|
+
|
557
|
+
@app.get(
|
558
|
+
"/apps/{app_name}/eval_results",
|
559
|
+
response_model_exclude_none=True,
|
560
|
+
)
|
561
|
+
def list_eval_results(app_name: str) -> list[str]:
|
562
|
+
"""Lists all eval results for the given app."""
|
563
|
+
return eval_set_results_manager.list_eval_set_results(app_name)
|
564
|
+
|
494
565
|
@app.delete("/apps/{app_name}/users/{user_id}/sessions/{session_id}")
|
495
|
-
def delete_session(app_name: str, user_id: str, session_id: str):
|
566
|
+
async def delete_session(app_name: str, user_id: str, session_id: str):
|
496
567
|
# Connect to managed session if agent_engine_id is set.
|
497
568
|
app_name = agent_engine_id if agent_engine_id else app_name
|
498
|
-
session_service.delete_session(
|
569
|
+
await session_service.delete_session(
|
499
570
|
app_name=app_name, user_id=user_id, session_id=session_id
|
500
571
|
)
|
501
572
|
|
@@ -589,9 +660,9 @@ def get_fast_api_app(
|
|
589
660
|
@app.post("/run", response_model_exclude_none=True)
|
590
661
|
async def agent_run(req: AgentRunRequest) -> list[Event]:
|
591
662
|
# Connect to managed session if agent_engine_id is set.
|
592
|
-
|
593
|
-
session = session_service.get_session(
|
594
|
-
app_name=
|
663
|
+
app_name = agent_engine_id if agent_engine_id else req.app_name
|
664
|
+
session = await session_service.get_session(
|
665
|
+
app_name=app_name, user_id=req.user_id, session_id=req.session_id
|
595
666
|
)
|
596
667
|
if not session:
|
597
668
|
raise HTTPException(status_code=404, detail="Session not found")
|
@@ -610,10 +681,10 @@ def get_fast_api_app(
|
|
610
681
|
@app.post("/run_sse")
|
611
682
|
async def agent_run_sse(req: AgentRunRequest) -> StreamingResponse:
|
612
683
|
# Connect to managed session if agent_engine_id is set.
|
613
|
-
|
684
|
+
app_name = agent_engine_id if agent_engine_id else req.app_name
|
614
685
|
# SSE endpoint
|
615
|
-
session = session_service.get_session(
|
616
|
-
app_name=
|
686
|
+
session = await session_service.get_session(
|
687
|
+
app_name=app_name, user_id=req.user_id, session_id=req.session_id
|
617
688
|
)
|
618
689
|
if not session:
|
619
690
|
raise HTTPException(status_code=404, detail="Session not found")
|
@@ -652,9 +723,9 @@ def get_fast_api_app(
|
|
652
723
|
app_name: str, user_id: str, session_id: str, event_id: str
|
653
724
|
):
|
654
725
|
# Connect to managed session if agent_engine_id is set.
|
655
|
-
|
656
|
-
session = session_service.get_session(
|
657
|
-
app_name=
|
726
|
+
app_name = agent_engine_id if agent_engine_id else app_name
|
727
|
+
session = await session_service.get_session(
|
728
|
+
app_name=app_name, user_id=user_id, session_id=session_id
|
658
729
|
)
|
659
730
|
session_events = session.events if session else []
|
660
731
|
event = next((x for x in session_events if x.id == event_id), None)
|
@@ -665,7 +736,7 @@ def get_fast_api_app(
|
|
665
736
|
|
666
737
|
function_calls = event.get_function_calls()
|
667
738
|
function_responses = event.get_function_responses()
|
668
|
-
root_agent =
|
739
|
+
root_agent = agent_loader.load_agent(app_name)
|
669
740
|
dot_graph = None
|
670
741
|
if function_calls:
|
671
742
|
function_call_highlights = []
|
@@ -673,7 +744,7 @@ def get_fast_api_app(
|
|
673
744
|
from_name = event.author
|
674
745
|
to_name = function_call.name
|
675
746
|
function_call_highlights.append((from_name, to_name))
|
676
|
-
dot_graph = agent_graph.get_agent_graph(
|
747
|
+
dot_graph = await agent_graph.get_agent_graph(
|
677
748
|
root_agent, function_call_highlights
|
678
749
|
)
|
679
750
|
elif function_responses:
|
@@ -682,17 +753,17 @@ def get_fast_api_app(
|
|
682
753
|
from_name = function_response.name
|
683
754
|
to_name = event.author
|
684
755
|
function_responses_highlights.append((from_name, to_name))
|
685
|
-
dot_graph = agent_graph.get_agent_graph(
|
756
|
+
dot_graph = await agent_graph.get_agent_graph(
|
686
757
|
root_agent, function_responses_highlights
|
687
758
|
)
|
688
759
|
else:
|
689
760
|
from_name = event.author
|
690
761
|
to_name = ""
|
691
|
-
dot_graph = agent_graph.get_agent_graph(
|
762
|
+
dot_graph = await agent_graph.get_agent_graph(
|
692
763
|
root_agent, [(from_name, to_name)]
|
693
764
|
)
|
694
765
|
if dot_graph and isinstance(dot_graph, graphviz.Digraph):
|
695
|
-
return
|
766
|
+
return GetEventGraphResult(dot_src=dot_graph.source)
|
696
767
|
else:
|
697
768
|
return {}
|
698
769
|
|
@@ -709,9 +780,9 @@ def get_fast_api_app(
|
|
709
780
|
await websocket.accept()
|
710
781
|
|
711
782
|
# Connect to managed session if agent_engine_id is set.
|
712
|
-
|
713
|
-
session = session_service.get_session(
|
714
|
-
app_name=
|
783
|
+
app_name = agent_engine_id if agent_engine_id else app_name
|
784
|
+
session = await session_service.get_session(
|
785
|
+
app_name=app_name, user_id=user_id, session_id=session_id
|
715
786
|
)
|
716
787
|
if not session:
|
717
788
|
# Accept first so that the client is aware of connection establishment,
|
@@ -766,34 +837,12 @@ def get_fast_api_app(
|
|
766
837
|
for task in pending:
|
767
838
|
task.cancel()
|
768
839
|
|
769
|
-
async def _get_root_agent_async(app_name: str) -> Agent:
|
770
|
-
"""Returns the root agent for the given app."""
|
771
|
-
if app_name in root_agent_dict:
|
772
|
-
return root_agent_dict[app_name]
|
773
|
-
agent_module = importlib.import_module(app_name)
|
774
|
-
if getattr(agent_module.agent, "root_agent"):
|
775
|
-
root_agent = agent_module.agent.root_agent
|
776
|
-
else:
|
777
|
-
raise ValueError(f'Unable to find "root_agent" from {app_name}.')
|
778
|
-
|
779
|
-
# Handle an awaitable root agent and await for the actual agent.
|
780
|
-
if inspect.isawaitable(root_agent):
|
781
|
-
try:
|
782
|
-
agent, exit_stack = await root_agent
|
783
|
-
exit_stacks.append(exit_stack)
|
784
|
-
root_agent = agent
|
785
|
-
except Exception as e:
|
786
|
-
raise RuntimeError(f"error getting root agent, {e}") from e
|
787
|
-
|
788
|
-
root_agent_dict[app_name] = root_agent
|
789
|
-
return root_agent
|
790
|
-
|
791
840
|
async def _get_runner_async(app_name: str) -> Runner:
|
792
841
|
"""Returns the runner for the given app."""
|
793
|
-
envs.load_dotenv_for_agent(os.path.basename(app_name),
|
842
|
+
envs.load_dotenv_for_agent(os.path.basename(app_name), agents_dir)
|
794
843
|
if app_name in runner_dict:
|
795
844
|
return runner_dict[app_name]
|
796
|
-
root_agent =
|
845
|
+
root_agent = agent_loader.load_agent(app_name)
|
797
846
|
runner = Runner(
|
798
847
|
app_name=agent_engine_id if agent_engine_id else app_name,
|
799
848
|
agent=root_agent,
|
@@ -809,14 +858,16 @@ def get_fast_api_app(
|
|
809
858
|
ANGULAR_DIST_PATH = BASE_DIR / "browser"
|
810
859
|
|
811
860
|
@app.get("/")
|
812
|
-
async def
|
813
|
-
return RedirectResponse("/dev-ui")
|
861
|
+
async def redirect_root_to_dev_ui():
|
862
|
+
return RedirectResponse("/dev-ui/")
|
814
863
|
|
815
864
|
@app.get("/dev-ui")
|
816
|
-
async def
|
817
|
-
return
|
865
|
+
async def redirect_dev_ui_add_slash():
|
866
|
+
return RedirectResponse("/dev-ui/")
|
818
867
|
|
819
868
|
app.mount(
|
820
|
-
"/",
|
869
|
+
"/dev-ui/",
|
870
|
+
StaticFiles(directory=ANGULAR_DIST_PATH, html=True),
|
871
|
+
name="static",
|
821
872
|
)
|
822
873
|
return app
|