agentevals-cli 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentevals/__init__.py +16 -0
- agentevals/_protocol.py +83 -0
- agentevals/api/__init__.py +0 -0
- agentevals/api/app.py +137 -0
- agentevals/api/debug_routes.py +268 -0
- agentevals/api/models.py +204 -0
- agentevals/api/otlp_app.py +25 -0
- agentevals/api/otlp_routes.py +383 -0
- agentevals/api/routes.py +554 -0
- agentevals/api/streaming_routes.py +373 -0
- agentevals/builtin_metrics.py +234 -0
- agentevals/cli.py +643 -0
- agentevals/config.py +108 -0
- agentevals/converter.py +328 -0
- agentevals/custom_evaluators.py +468 -0
- agentevals/eval_config_loader.py +147 -0
- agentevals/evaluator/__init__.py +24 -0
- agentevals/evaluator/resolver.py +70 -0
- agentevals/evaluator/sources.py +293 -0
- agentevals/evaluator/templates.py +224 -0
- agentevals/extraction.py +444 -0
- agentevals/genai_converter.py +538 -0
- agentevals/loader/__init__.py +7 -0
- agentevals/loader/base.py +53 -0
- agentevals/loader/jaeger.py +112 -0
- agentevals/loader/otlp.py +193 -0
- agentevals/mcp_server.py +236 -0
- agentevals/output.py +204 -0
- agentevals/runner.py +310 -0
- agentevals/sdk.py +433 -0
- agentevals/streaming/__init__.py +120 -0
- agentevals/streaming/incremental_processor.py +337 -0
- agentevals/streaming/processor.py +285 -0
- agentevals/streaming/session.py +36 -0
- agentevals/streaming/ws_server.py +806 -0
- agentevals/trace_attrs.py +32 -0
- agentevals/trace_metrics.py +126 -0
- agentevals/utils/__init__.py +0 -0
- agentevals/utils/genai_messages.py +142 -0
- agentevals/utils/log_buffer.py +43 -0
- agentevals/utils/log_enrichment.py +187 -0
- agentevals_cli-0.5.2.dist-info/METADATA +22 -0
- agentevals_cli-0.5.2.dist-info/RECORD +46 -0
- agentevals_cli-0.5.2.dist-info/WHEEL +4 -0
- agentevals_cli-0.5.2.dist-info/entry_points.txt +2 -0
- agentevals_cli-0.5.2.dist-info/licenses/LICENSE +201 -0
|
@@ -0,0 +1,373 @@
|
|
|
1
|
+
"""API routes for streaming session management."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import json
|
|
7
|
+
import logging
|
|
8
|
+
|
|
9
|
+
from fastapi import APIRouter, HTTPException
|
|
10
|
+
from fastapi.responses import FileResponse
|
|
11
|
+
from pydantic import BaseModel
|
|
12
|
+
|
|
13
|
+
from ..config import EvalRunConfig
|
|
14
|
+
from ..converter import convert_traces
|
|
15
|
+
from ..loader.otlp import OtlpJsonLoader
|
|
16
|
+
from ..runner import run_evaluation
|
|
17
|
+
from ..trace_attrs import OTEL_GENAI_INPUT_MESSAGES, OTEL_GENAI_REQUEST_MODEL
|
|
18
|
+
from ..utils.log_enrichment import enrich_spans_with_logs
|
|
19
|
+
from .models import (
|
|
20
|
+
CreateEvalSetData,
|
|
21
|
+
EvaluateSessionsData,
|
|
22
|
+
GetTraceData,
|
|
23
|
+
PrepareEvaluationData,
|
|
24
|
+
SessionEvalResult,
|
|
25
|
+
SessionInfo,
|
|
26
|
+
StandardResponse,
|
|
27
|
+
)
|
|
28
|
+
|
|
29
|
+
logger = logging.getLogger(__name__)

# Router holding every streaming-session endpoint below. The download URLs
# built in prepare_evaluation assume it is mounted under /api/streaming.
streaming_router = APIRouter()

# Module-wide handle to the object that owns the live sessions. It starts as
# None and is installed through set_trace_manager(); every route dereferences
# ``trace_manager.sessions``, so it must be set before requests are served.
trace_manager = None
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
def set_trace_manager(manager) -> None:
    """Install *manager* as the shared trace manager consulted by every route.

    The handlers in this module read ``trace_manager.sessions`` on each request,
    so this must run before the router starts serving traffic.
    """
    global trace_manager
    trace_manager = manager
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
class CreateEvalSetRequest(BaseModel):
    # Request body for POST /create-eval-set.
    # Key into trace_manager.sessions; that session's trace becomes the eval set.
    session_id: str
    # Written verbatim into the generated eval set's "eval_set_id" field.
    eval_set_id: str
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
class EvaluateSessionsRequest(BaseModel):
    # Request body for POST /evaluate-sessions.
    # Session whose trace is converted into the reference ("golden") eval set.
    golden_session_id: str
    # ID given to the eval set built from the golden session.
    eval_set_id: str
    # Metric names forwarded unchanged to EvalRunConfig.metrics.
    metrics: list[str] = ["tool_trajectory_avg_score"]
    # Judge model name forwarded unchanged to EvalRunConfig.judge_model.
    judge_model: str = "gemini-2.5-flash"
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
class PrepareEvaluationRequest(BaseModel):
    # Request body for POST /prepare-evaluation.
    # Session used to build the downloadable eval set.
    golden_session_id: str
    # Sessions whose traces are exported; unknown or incomplete IDs are skipped.
    session_ids: list[str]
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
class GetTraceRequest(BaseModel):
    # Request body for POST /get-trace.
    # Session whose log-enriched spans are returned as JSON Lines text.
    session_id: str
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
@streaming_router.get("/sessions", response_model=StandardResponse[list[SessionInfo]])
async def list_sessions():
    """Return a summary of every session currently held by the trace manager.

    ``invocations`` is filled in only for sessions that are complete and whose
    invocation list is truthy (non-empty); otherwise it is ``None``.
    """
    summaries = [
        SessionInfo(
            session_id=sid,
            trace_id=sess.trace_id,
            eval_set_id=sess.eval_set_id,
            span_count=len(sess.spans),
            is_complete=sess.is_complete,
            started_at=sess.started_at.isoformat(),
            metadata=sess.metadata,
            invocations=(sess.invocations if sess.is_complete and sess.invocations else None),
        )
        for sid, sess in trace_manager.sessions.items()
    ]
    return StandardResponse(data=summaries)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def _invocation_to_dict(inv) -> dict:
    """Serialize one converted invocation into the eval-set "conversation" entry shape.

    ``invocation_id`` and ``user_content`` are always present; ``final_response``
    and ``intermediate_data`` are added only when the invocation has them.
    """
    inv_dict = {
        "invocation_id": inv.invocation_id,
        "user_content": inv.user_content.model_dump(exclude_none=True) if inv.user_content else None,
    }
    if inv.final_response:
        inv_dict["final_response"] = inv.final_response.model_dump(exclude_none=True)
    if inv.intermediate_data:
        inv_dict["intermediate_data"] = inv.intermediate_data.model_dump(exclude_none=True)
    return inv_dict


@streaming_router.post("/create-eval-set", response_model=StandardResponse[CreateEvalSetData])
async def create_eval_set_from_session(request: CreateEvalSetRequest):
    """Convert a session's trace into an EvalSet.

    The session's spans are written out via ``trace_manager._save_spans_to_temp_file``,
    reloaded with :class:`OtlpJsonLoader`, converted with :func:`convert_traces`, and
    every resulting invocation is packed into a single eval case (``eval_id="case_1"``).

    Raises:
        HTTPException: 404 if the session is unknown; 400 if no traces could be
            loaded or converted; 500 for any other failure.
    """
    session = trace_manager.sessions.get(request.session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    try:
        trace_file = await trace_manager._save_spans_to_temp_file(session)
        logger.debug(
            "Session %s: %d spans, %d logs saved to %s",
            request.session_id,
            len(session.spans),
            len(session.logs),
            trace_file,
        )
        traces = OtlpJsonLoader().load(str(trace_file))

        if not traces:
            raise HTTPException(
                status_code=400,
                detail=(
                    f"No traces found in session (spans={len(session.spans)}, "
                    f"logs={len(session.logs)}). If using the SDK with langchain/openai, "
                    "ensure opentelemetry-instrumentation-openai-v2 is installed."
                ),
            )

        conversion_results = convert_traces(traces)
        if not conversion_results:
            raise HTTPException(status_code=400, detail="Failed to convert trace")

        all_invocations = [inv for conv_result in conversion_results for inv in conv_result.invocations]

        logger.debug("Creating eval set from %d invocations", len(all_invocations))
        if logger.isEnabledFor(logging.DEBUG):
            for i, inv in enumerate(all_invocations):
                # NOTE(review): ADK's intermediate_data may be a type without a
                # ``tool_uses`` attribute -- confirm against the installed ADK.
                # Read it defensively: a debug log line must never fail the request.
                tool_count = len(getattr(inv.intermediate_data, "tool_uses", None) or [])
                logger.debug(" Invocation %d: %d tool calls", i, tool_count)

        eval_set = {
            "eval_set_id": request.eval_set_id,
            "eval_cases": [
                {
                    "eval_id": "case_1",
                    "conversation": [_invocation_to_dict(inv) for inv in all_invocations],
                }
            ],
        }

        return StandardResponse(
            data=CreateEvalSetData(
                eval_set=eval_set,
                num_invocations=len(all_invocations),
            )
        )

    except HTTPException:
        raise
    except Exception as exc:
        logger.exception("Failed to create eval set")
        raise HTTPException(status_code=500, detail=str(exc)) from exc
|
|
160
|
+
|
|
161
|
+
|
|
162
|
+
@streaming_router.post("/evaluate-sessions", response_model=StandardResponse[EvaluateSessionsData])
async def evaluate_sessions(request: EvaluateSessionsRequest):
    """Evaluate all sessions against a golden session converted to EvalSet.

    The golden session is turned into an eval set (via
    :func:`create_eval_set_from_session`) and written to a temporary JSON file.
    Every *complete* session is then evaluated against it with
    :func:`run_evaluation`, at most 5 at a time. Per-session failures are
    reported in that session's ``SessionEvalResult.error`` rather than failing
    the whole request. The temporary eval-set file is deleted before returning.

    Raises:
        HTTPException: 404 if the golden session is unknown, 400 if its trace
            cannot be converted, 500 for any other failure.
    """
    import os
    import tempfile

    golden_session = trace_manager.sessions.get(request.golden_session_id)
    if not golden_session:
        raise HTTPException(status_code=404, detail="Golden session not found")

    # Path of the temporary eval-set file. It is removed in ``finally`` so that
    # repeated requests do not accumulate files in the temp directory.
    eval_set_path = None
    try:
        eval_set_response = await create_eval_set_from_session(
            CreateEvalSetRequest(
                session_id=request.golden_session_id,
                eval_set_id=request.eval_set_id,
            )
        )

        with tempfile.NamedTemporaryFile(mode="w", suffix=".json", delete=False) as eval_set_file:
            eval_set_path = eval_set_file.name
            json.dump(eval_set_response.data.eval_set, eval_set_file)

        sessions_to_evaluate = [
            (session_id, session) for session_id, session in trace_manager.sessions.items() if session.is_complete
        ]

        logger.info(
            "Evaluating %d complete sessions (of %d total)", len(sessions_to_evaluate), len(trace_manager.sessions)
        )

        # Bound concurrent evaluations.
        sem = asyncio.Semaphore(5)

        async def eval_one_session(session_id: str, session) -> SessionEvalResult:
            async with sem:
                try:
                    trace_file = await trace_manager._save_spans_to_temp_file(session)

                    config = EvalRunConfig(
                        trace_files=[str(trace_file)],
                        trace_format="otlp-json",
                        eval_set_file=eval_set_path,
                        metrics=request.metrics,
                        judge_model=request.judge_model,
                    )

                    eval_result = await run_evaluation(config)

                    if not eval_result.trace_results:
                        logger.warning("No trace results for session %s", session_id)
                        return SessionEvalResult(session_id=session_id, error="No trace results")

                    # Each session is evaluated from a single trace file, so only the
                    # first trace result is reported.
                    trace_result = eval_result.trace_results[0]
                    return SessionEvalResult(
                        session_id=session_id,
                        trace_id=trace_result.trace_id,
                        num_invocations=trace_result.num_invocations,
                        metric_results=[
                            {
                                "metricName": mr.metric_name,
                                "score": mr.score,
                                "evalStatus": mr.eval_status,
                                "error": mr.error,
                            }
                            for mr in trace_result.metric_results
                        ],
                    )

                except Exception as exc:
                    # Contain per-session failures so one bad session does not
                    # abort the whole batch.
                    logger.exception("Failed to evaluate session %s: %s", session_id, exc)
                    return SessionEvalResult(session_id=session_id, error=str(exc))

        results = await asyncio.gather(*[eval_one_session(sid, sess) for sid, sess in sessions_to_evaluate])

        logger.info("Evaluation complete. Total results: %d", len(results))

        return StandardResponse(
            data=EvaluateSessionsData(
                golden_session_id=request.golden_session_id,
                eval_set_id=request.eval_set_id,
                results=results,
            )
        )

    except HTTPException:
        raise
    except Exception as exc:
        logger.exception("Failed to evaluate sessions")
        raise HTTPException(status_code=500, detail=str(exc)) from exc
    finally:
        if eval_set_path is not None:
            try:
                os.unlink(eval_set_path)
            except OSError:
                logger.warning("Could not remove temporary eval set file %s", eval_set_path)
|
|
249
|
+
|
|
250
|
+
|
|
251
|
+
@streaming_router.post("/prepare-evaluation", response_model=StandardResponse[PrepareEvaluationData])
async def prepare_evaluation(request: PrepareEvaluationRequest):
    """Prepare evaluation by saving traces and eval set as downloadable files.

    The golden session is converted into an eval set whose ``eval_set_id`` is
    ``golden_<golden_session_id>``, and that eval set is written to the system
    temp directory. Each requested session that exists and is complete also
    has its spans written out. The response carries
    ``/api/streaming/download/<basename>`` URLs for all of these files. IDs in
    ``request.session_ids`` that are unknown or incomplete are skipped.

    Raises:
        HTTPException: 404 if the golden session is unknown, 400 if its trace
            cannot be converted, 500 for any other failure.
    """
    import os
    import re
    import tempfile

    golden_session = trace_manager.sessions.get(request.golden_session_id)
    if not golden_session:
        raise HTTPException(status_code=404, detail="Golden session not found")

    try:
        eval_set_response = await create_eval_set_from_session(
            CreateEvalSetRequest(
                session_id=request.golden_session_id,
                eval_set_id=f"golden_{request.golden_session_id}",
            )
        )

        # The session ID is embedded in a file name, but nothing in this module
        # validates it. Replace every character outside [A-Za-z0-9._-] so that:
        # - the file is always created directly inside the temp directory;
        # - its basename (used for the download URL) is the file actually written.
        safe_golden_id = re.sub(r"[^A-Za-z0-9._-]", "_", request.golden_session_id)
        eval_set_file = os.path.join(tempfile.gettempdir(), f"eval_set_{safe_golden_id}.json")
        with open(eval_set_file, "w") as f:  # noqa: ASYNC230
            json.dump(eval_set_response.data.eval_set, f)

        trace_paths = []
        for session_id in request.session_ids:
            session = trace_manager.sessions.get(session_id)
            if not session or not session.is_complete:
                continue
            trace_file = await trace_manager._save_spans_to_temp_file(session)
            trace_paths.append(str(trace_file))

        return StandardResponse(
            data=PrepareEvaluationData(
                eval_set_url=f"/api/streaming/download/{os.path.basename(eval_set_file)}",
                trace_urls=[f"/api/streaming/download/{os.path.basename(path)}" for path in trace_paths],
                num_traces=len(trace_paths),
            )
        )

    except HTTPException:
        raise
    except Exception as exc:
        logger.exception("Failed to prepare evaluation")
        raise HTTPException(status_code=500, detail=str(exc)) from exc
|
|
302
|
+
|
|
303
|
+
|
|
304
|
+
@streaming_router.get("/download/{filename}")
async def download_file(filename: str):
    """Download a prepared trace or eval set file.

    Only regular files that resolve to a location inside the system temp
    directory are served.

    Raises:
        HTTPException: 400 if *filename* resolves outside the temp directory
            (or to the directory itself); 404 if it does not name an existing
            regular file.
    """
    import os
    import tempfile

    temp_dir = os.path.realpath(tempfile.gettempdir())  # noqa: ASYNC240
    # Resolve ".." segments and symlinks *before* validating. A plain
    # ``startswith`` on the joined path would accept escapes such as
    # "..\\..\\secret" on Windows, and sibling directories such as "/tmpfoo".
    file_path = os.path.realpath(os.path.join(temp_dir, filename))  # noqa: ASYNC240

    try:
        inside_temp_dir = file_path != temp_dir and os.path.commonpath([temp_dir, file_path]) == temp_dir
    except ValueError:
        # commonpath raises for paths on different drives (Windows): never inside.
        inside_temp_dir = False
    if not inside_temp_dir:
        raise HTTPException(status_code=400, detail="Invalid file path")

    if not os.path.isfile(file_path):  # noqa: ASYNC240
        raise HTTPException(status_code=404, detail="File not found")

    return FileResponse(file_path, media_type="application/json", filename=filename)
|
|
320
|
+
|
|
321
|
+
|
|
322
|
+
@streaming_router.post("/get-trace", response_model=StandardResponse[GetTraceData])
async def get_trace(request: GetTraceRequest):
    """Return a session's spans, enriched with its logs, as JSON Lines text.

    Each span's ``traceId`` is overwritten with the session's ``trace_id``, so
    all lines belong to one trace. A warning is logged when the session has
    GenAI spans but no logs to enrich them with.

    Raises:
        HTTPException: 404 if the session is unknown, 500 on any other failure.
    """
    session = trace_manager.sessions.get(request.session_id)
    if not session:
        raise HTTPException(status_code=404, detail="Session not found")

    try:
        genai_keys = (OTEL_GENAI_REQUEST_MODEL, OTEL_GENAI_INPUT_MESSAGES)
        has_genai_spans = any(
            attr.get("key") in genai_keys for span in session.spans for attr in (span.get("attributes") or [])
        )

        if has_genai_spans and not session.logs:
            logger.warning(
                "Session %s has GenAI spans but no logs. "
                "Message content will be missing unless spans already enriched.",
                request.session_id,
            )

        enriched_spans = enrich_spans_with_logs(session.spans, session.logs)

        unified_trace_id = session.trace_id
        # Build the JSONL payload in memory. Writing it to a temp file and reading
        # it back added only I/O and left an undeleted file on every call.
        # ``{**span, ...}`` copies the span, so the session's spans are not mutated.
        trace_content = "".join(json.dumps({**span, "traceId": unified_trace_id}) + "\n" for span in enriched_spans)

        return StandardResponse(
            data=GetTraceData(
                session_id=request.session_id,
                trace_content=trace_content,
                num_spans=len(enriched_spans),
            )
        )

    except Exception as exc:
        logger.exception("Failed to get trace")
        raise HTTPException(status_code=500, detail=str(exc)) from exc
|
|
@@ -0,0 +1,234 @@
|
|
|
1
|
+
"""Built-in ADK metric evaluation — criteria construction and evaluator resolution."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import asyncio
|
|
6
|
+
import inspect
|
|
7
|
+
import logging
|
|
8
|
+
from typing import Any
|
|
9
|
+
|
|
10
|
+
from google.adk.evaluation.eval_case import Invocation, get_all_tool_calls
|
|
11
|
+
from google.adk.evaluation.eval_metrics import (
|
|
12
|
+
BaseCriterion,
|
|
13
|
+
EvalMetric,
|
|
14
|
+
HallucinationsCriterion,
|
|
15
|
+
JudgeModelOptions,
|
|
16
|
+
LlmAsAJudgeCriterion,
|
|
17
|
+
LlmBackedUserSimulatorCriterion,
|
|
18
|
+
RubricsBasedCriterion,
|
|
19
|
+
ToolTrajectoryCriterion,
|
|
20
|
+
)
|
|
21
|
+
from google.adk.evaluation.eval_rubrics import Rubric, RubricContent
|
|
22
|
+
from google.adk.evaluation.evaluator import EvaluationResult, Evaluator
|
|
23
|
+
|
|
24
|
+
logger = logging.getLogger(__name__)

# Metrics that compare actual invocations against a golden ("expected") list.
# evaluate_builtin_metric refuses to run them when no expected invocations exist.
METRICS_NEEDING_EXPECTED = {
    "final_response_match_v2",
    "response_evaluation_score",
    "response_match_score",
    "tool_trajectory_avg_score",
}

# Metrics scored by an LLM judge; build_eval_metric gives each of these a
# JudgeModelOptions-bearing criterion.
METRICS_NEEDING_LLM = {
    "hallucinations_v1",
    "final_response_match_v2",
    "per_turn_user_simulator_quality_v1",
    "rubric_based_final_response_quality_v1",
    "rubric_based_tool_use_quality_v1",
}

# Metrics flagged (by this set's name) as requiring Google Cloud access.
METRICS_NEEDING_GCP = {
    "safety_v1",
    "response_evaluation_score",
}
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
def rubric_strings_to_objects(rubric_texts: list[str]) -> list[Rubric]:
    """Wrap each plain-text rubric in an ADK ``Rubric``.

    IDs are positional (``rubric_0``, ``rubric_1``, ...), so the same input
    list always yields the same IDs.
    """
    rubrics: list[Rubric] = []
    for index, text in enumerate(rubric_texts):
        content = RubricContent(text_property=text)
        rubrics.append(Rubric(rubric_id=f"rubric_{index}", rubric_content=content))
    return rubrics
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
def build_eval_metric(
    metric_name: str,
    judge_model: str | None,
    threshold: float | None,
    rubrics: list[str] | None = None,
) -> EvalMetric:
    """Build an ADK ``EvalMetric`` carrying the criterion *metric_name* expects.

    Args:
        metric_name: Name of an ADK built-in metric.
        judge_model: Judge model override for LLM-judged metrics. When falsy,
            ``JudgeModelOptions``'s own default is left in place.
        threshold: Pass threshold; ``None`` means 0.5.
        rubrics: Plain-text rubrics. Only the rubric-based metrics use them.

    Returns:
        An ``EvalMetric`` whose ``criterion`` is ``None`` for unrecognised names.
    """
    effective_threshold = 0.5 if threshold is None else threshold

    def judge_options() -> JudgeModelOptions:
        # Override the judge only when one was given, so ADK's default applies otherwise.
        options = JudgeModelOptions()
        if judge_model:
            options.judge_model = judge_model
        return options

    # LLM-judged metrics whose criterion needs only a threshold and judge options.
    judged_criteria = {
        "final_response_match_v2": LlmAsAJudgeCriterion,
        "hallucinations_v1": HallucinationsCriterion,
        "per_turn_user_simulator_quality_v1": LlmBackedUserSimulatorCriterion,
    }

    criterion: BaseCriterion | None
    if metric_name == "tool_trajectory_avg_score":
        criterion = ToolTrajectoryCriterion(threshold=effective_threshold)
    elif metric_name in judged_criteria:
        criterion = judged_criteria[metric_name](
            threshold=effective_threshold,
            judge_model_options=judge_options(),
        )
    elif metric_name in ("rubric_based_final_response_quality_v1", "rubric_based_tool_use_quality_v1"):
        criterion = RubricsBasedCriterion(
            threshold=effective_threshold,
            judge_model_options=judge_options(),
            rubrics=rubric_strings_to_objects(rubrics) if rubrics else [],
        )
    elif metric_name in ("response_match_score", "response_evaluation_score", "safety_v1"):
        criterion = BaseCriterion(threshold=effective_threshold)
    else:
        criterion = None

    return EvalMetric(
        metric_name=metric_name,
        threshold=effective_threshold,
        criterion=criterion,
    )
|
|
116
|
+
|
|
117
|
+
|
|
118
|
+
def get_evaluator(eval_metric: EvalMetric) -> Evaluator:
    """Return an evaluator instance for *eval_metric*.

    Metrics in the local fast-path table are instantiated straight from their
    own module. This avoids the heavy dependencies (numpy/rouge_score) that the
    full ADK registry pulls in. All other metrics are resolved through
    ``DEFAULT_METRIC_EVALUATOR_REGISTRY``.
    """
    # metric name -> (module path, evaluator class name)
    direct_imports: dict[str, tuple[str, str]] = {
        "tool_trajectory_avg_score": ("google.adk.evaluation.trajectory_evaluator", "TrajectoryEvaluator"),
    }

    target = direct_imports.get(eval_metric.metric_name)
    if target is None:
        from google.adk.evaluation.metric_evaluator_registry import DEFAULT_METRIC_EVALUATOR_REGISTRY

        return DEFAULT_METRIC_EVALUATOR_REGISTRY.get_evaluator(eval_metric)

    import importlib

    module_path, class_name = target
    evaluator_cls = getattr(importlib.import_module(module_path), class_name)
    return evaluator_cls(eval_metric=eval_metric)  # type: ignore[call-arg]
|
|
143
|
+
|
|
144
|
+
|
|
145
|
+
def extract_trajectory_details(eval_result: EvaluationResult) -> dict[str, Any]:
    """Pair each invocation's expected and actual tool calls for reporting.

    Returns ``{"comparisons": [...]}`` with one entry per per-invocation result.
    Each entry holds:
    - the actual invocation's ``invocation_id`` (``None`` when absent);
    - the expected and actual tool calls as ``{"name", "args"}`` dicts;
    - ``matched``, which is true only for a per-invocation score of exactly 1.0.
    """

    def tool_summaries(invocation) -> list[dict[str, Any]]:
        # A missing invocation or missing intermediate data contributes no tool calls.
        if not (invocation and invocation.intermediate_data):
            return []
        return [{"name": call.name, "args": call.args} for call in get_all_tool_calls(invocation.intermediate_data)]

    def compare(result) -> dict[str, Any]:
        actual = result.actual_invocation
        expected = result.expected_invocation
        actual_calls = tool_summaries(actual)
        return {
            "invocation_id": actual.invocation_id if actual else None,
            "expected": tool_summaries(expected),
            "actual": actual_calls,
            "matched": result.score == 1.0,
        }

    return {"comparisons": [compare(result) for result in eval_result.per_invocation_results]}
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
async def evaluate_builtin_metric(
    metric_name: str,
    actual_invocations: list[Invocation],
    expected_invocations: list[Invocation] | None,
    judge_model: str | None,
    threshold: float | None,
    rubrics: list[str] | None = None,
) -> Any:
    """Evaluate a single built-in ADK metric.

    Args:
        metric_name: Name of the ADK built-in metric to run.
        actual_invocations: Invocations produced by the agent under test.
        expected_invocations: Golden invocations. Required for metrics in
            ``METRICS_NEEDING_EXPECTED``.
        judge_model: Optional judge model override for LLM-judged metrics.
        threshold: Pass threshold; ``None`` lets ``build_eval_metric`` use 0.5.
        rubrics: Plain-text rubrics forwarded to ``build_eval_metric``. Only the
            rubric-based metrics use them; without them those metrics run with
            an empty rubric list.

    Returns:
        A ``runner.MetricResult`` (imported lazily inside this function, not a
        plain dict). It carries ``metric_name``, ``score``, ``eval_status``,
        ``per_invocation_scores`` and ``details``, or ``error`` when the metric
        could not be evaluated. Failures are reported through ``error`` rather
        than raised.
    """
    from .runner import MetricResult

    if metric_name in METRICS_NEEDING_EXPECTED and not expected_invocations:
        return MetricResult(
            metric_name=metric_name,
            error=(
                f"Metric '{metric_name}' requires expected invocations "
                f"(golden eval set), but none were provided or matched."
            ),
        )

    try:
        eval_metric = build_eval_metric(metric_name, judge_model, threshold, rubrics=rubrics)
        evaluator: Evaluator = get_evaluator(eval_metric)

        evaluate = evaluator.evaluate_invocations
        eval_result: EvaluationResult
        if inspect.iscoroutinefunction(evaluate):
            eval_result = await evaluate(
                actual_invocations=actual_invocations,
                expected_invocations=expected_invocations,
            )
        else:
            # Synchronous evaluators run in a worker thread so they do not block the event loop.
            eval_result = await asyncio.to_thread(
                evaluate,
                actual_invocations=actual_invocations,
                expected_invocations=expected_invocations,
            )

        per_inv_scores = [r.score for r in eval_result.per_invocation_results]

        details = None
        if metric_name == "tool_trajectory_avg_score":
            details = extract_trajectory_details(eval_result)

        return MetricResult(
            metric_name=metric_name,
            score=eval_result.overall_score,
            eval_status=eval_result.overall_eval_status.name,
            per_invocation_scores=per_inv_scores,
            details=details,
        )

    except Exception as exc:
        logger.exception("Failed to evaluate metric '%s'", metric_name)
        return MetricResult(
            metric_name=metric_name,
            error=str(exc),
        )
|