bat-cli 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- add/__init__.py +3 -0
- add/client.py +16 -0
- bat_cli-0.1.0.dist-info/METADATA +231 -0
- bat_cli-0.1.0.dist-info/RECORD +47 -0
- bat_cli-0.1.0.dist-info/WHEEL +5 -0
- bat_cli-0.1.0.dist-info/entry_points.txt +2 -0
- bat_cli-0.1.0.dist-info/top_level.txt +8 -0
- build/__init__.py +3 -0
- build/build.py +79 -0
- cli.py +260 -0
- create/__init__.py +3 -0
- create/agent.py +312 -0
- create/templates/agent/.dockerignore +3 -0
- create/templates/agent/.env.template +4 -0
- create/templates/agent/.python-version +1 -0
- create/templates/agent/Dockerfile +37 -0
- create/templates/agent/Makefile +34 -0
- create/templates/agent/README.md +1 -0
- create/templates/agent/__main__.py +2 -0
- create/templates/agent/agent.json.template +12 -0
- create/templates/agent/agent.spec +45 -0
- create/templates/agent/config.yaml +1 -0
- create/templates/agent/llm_client.py.template +36 -0
- create/templates/agent/pyproject.toml.template +9 -0
- create/templates/agent/src/__init__.py +0 -0
- create/templates/agent/src/graph.py +50 -0
- create/templates/agent/src/llm_clients/__init__.py +0 -0
- create/templates/agent/tests/__init__.py +0 -0
- eval/__init__.py +1 -0
- eval/commands.py +562 -0
- eval/engine/__init__.py +1 -0
- eval/engine/adapter.py +251 -0
- eval/engine/bench_runner.py +149 -0
- eval/engine/contracts.py +115 -0
- eval/engine/eval_config.py +294 -0
- eval/engine/evaluator.py +85 -0
- eval/engine/metrics/__init__.py +1 -0
- eval/engine/metrics/llm_evaluators.py +383 -0
- eval/engine/metrics/metrics.py +135 -0
- eval/engine/metrics/qualitative_helpers.py +64 -0
- eval/engine/orchestrator.py +157 -0
- eval/engine/plotter.py +347 -0
- image_defaults.py +80 -0
- push/__init__.py +3 -0
- push/push.py +58 -0
- set/__init__.py +3 -0
- set/env.py +50 -0
eval/engine/adapter.py
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import time
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
from a2a.client import A2ACardResolver, ClientConfig, ClientFactory
|
|
8
|
+
from a2a.helpers import (
|
|
9
|
+
get_artifact_text,
|
|
10
|
+
get_message_text,
|
|
11
|
+
new_text_message,
|
|
12
|
+
)
|
|
13
|
+
from a2a.types import Role, SendMessageRequest, StreamResponse, TaskState
|
|
14
|
+
from google.protobuf.json_format import MessageToDict
|
|
15
|
+
from httpx import AsyncClient
|
|
16
|
+
|
|
17
|
+
from .contracts import EpisodeResult, EpisodeTrace, TaskSpec, TraceEvent
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
# Normalized statuses that end a turn: when one is seen, the adapter records it
# (and its content) as the current final status/output of the episode.
TERMINAL_STATUSES = set(("completed", "error", "input-required"))

# Map A2A protobuf task states onto the eval engine's four-value status vocabulary.
# Any state not listed here is looked up with ``dict.get`` and yields None, which
# makes the adapter skip the chunk's status.
_TASK_STATE_TO_STR = {
    **dict.fromkeys(
        (TaskState.TASK_STATE_SUBMITTED, TaskState.TASK_STATE_WORKING),
        "working",
    ),
    TaskState.TASK_STATE_INPUT_REQUIRED: "input-required",
    TaskState.TASK_STATE_COMPLETED: "completed",
    **dict.fromkeys(
        (
            TaskState.TASK_STATE_FAILED,
            TaskState.TASK_STATE_CANCELED,
            TaskState.TASK_STATE_REJECTED,
        ),
        "error",
    ),
}
|
|
32
|
+
|
|
33
|
+
|
|
34
|
+
def _to_dict(value: Any) -> dict[str, Any]:
|
|
35
|
+
if value is None:
|
|
36
|
+
return {}
|
|
37
|
+
if isinstance(value, dict):
|
|
38
|
+
return value
|
|
39
|
+
if hasattr(value, "model_dump"):
|
|
40
|
+
dumped = value.model_dump(by_alias=True)
|
|
41
|
+
return dumped if isinstance(dumped, dict) else {}
|
|
42
|
+
return {}
|
|
43
|
+
|
|
44
|
+
|
|
45
|
+
def _struct_to_dict(metadata_struct: Any) -> dict[str, Any]:
|
|
46
|
+
if metadata_struct is None:
|
|
47
|
+
return {}
|
|
48
|
+
try:
|
|
49
|
+
return MessageToDict(metadata_struct) or {}
|
|
50
|
+
except Exception:
|
|
51
|
+
return {}
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
def _chunk_key(chunk: StreamResponse) -> str:
|
|
55
|
+
if chunk.HasField("status_update"):
|
|
56
|
+
message_id = chunk.status_update.status.message.message_id
|
|
57
|
+
if message_id:
|
|
58
|
+
return f"status:{message_id}"
|
|
59
|
+
if chunk.HasField("artifact_update"):
|
|
60
|
+
artifact_id = chunk.artifact_update.artifact.artifact_id
|
|
61
|
+
if artifact_id:
|
|
62
|
+
return f"artifact:{artifact_id}"
|
|
63
|
+
if chunk.HasField("message"):
|
|
64
|
+
if chunk.message.message_id:
|
|
65
|
+
return f"message:{chunk.message.message_id}"
|
|
66
|
+
if chunk.HasField("task"):
|
|
67
|
+
if chunk.task.id:
|
|
68
|
+
return f"task:{chunk.task.id}"
|
|
69
|
+
return f"raw:{chunk.SerializeToString().hex()}"
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def _extract_status_and_content(chunk: StreamResponse) -> tuple[str | None, str]:
|
|
73
|
+
if chunk.HasField("message"):
|
|
74
|
+
return "completed", get_message_text(chunk.message)
|
|
75
|
+
if chunk.HasField("artifact_update"):
|
|
76
|
+
return "completed", get_artifact_text(chunk.artifact_update.artifact)
|
|
77
|
+
if chunk.HasField("status_update"):
|
|
78
|
+
state = chunk.status_update.status.state
|
|
79
|
+
status = _TASK_STATE_TO_STR.get(state)
|
|
80
|
+
content = ""
|
|
81
|
+
if chunk.status_update.status.HasField("message"):
|
|
82
|
+
content = get_message_text(chunk.status_update.status.message)
|
|
83
|
+
return status, content
|
|
84
|
+
if chunk.HasField("task"):
|
|
85
|
+
state = chunk.task.status.state
|
|
86
|
+
status = _TASK_STATE_TO_STR.get(state)
|
|
87
|
+
texts = [get_artifact_text(a) for a in chunk.task.artifacts]
|
|
88
|
+
return status, "\n".join(t for t in texts if t)
|
|
89
|
+
return None, ""
|
|
90
|
+
|
|
91
|
+
|
|
92
|
+
def _extract_metadata(chunk: StreamResponse) -> dict[str, Any]:
    """Return the metadata dict attached to a streamed chunk.

    The metadata comes from the first payload that is set, checked in the order
    status update, artifact update, message, task. For artifact updates the
    artifact's own metadata is merged underneath, so keys on the update event
    override identical keys on the artifact.
    """
    source: Any = None
    for field_name in ("status_update", "artifact_update", "message", "task"):
        if chunk.HasField(field_name):
            source = getattr(chunk, field_name).metadata
            break

    event_metadata = _struct_to_dict(source)

    if not chunk.HasField("artifact_update"):
        return event_metadata

    artifact_metadata = _struct_to_dict(chunk.artifact_update.artifact.metadata)
    if not artifact_metadata:
        return event_metadata
    return {**artifact_metadata, **event_metadata}
|
|
113
|
+
|
|
114
|
+
|
|
115
|
+
def _normalize_usage(metadata: dict[str, Any]) -> dict[str, Any]:
|
|
116
|
+
usage = _to_dict(metadata.get("usage"))
|
|
117
|
+
if not usage:
|
|
118
|
+
return {}
|
|
119
|
+
|
|
120
|
+
input_tokens = int(usage.get("input_tokens") or 0)
|
|
121
|
+
output_tokens = int(usage.get("output_tokens") or 0)
|
|
122
|
+
total_tokens = int(usage.get("total_tokens") or (input_tokens + output_tokens))
|
|
123
|
+
inference_time = float(usage.get("inference_time") or 0.0)
|
|
124
|
+
|
|
125
|
+
if input_tokens == 0 and output_tokens == 0 and total_tokens == 0 and inference_time == 0.0:
|
|
126
|
+
return {}
|
|
127
|
+
|
|
128
|
+
return {
|
|
129
|
+
"input_tokens": input_tokens,
|
|
130
|
+
"output_tokens": output_tokens,
|
|
131
|
+
"total_tokens": total_tokens,
|
|
132
|
+
"inference_time": inference_time,
|
|
133
|
+
}
|
|
134
|
+
|
|
135
|
+
|
|
136
|
+
def _add_usage(total: dict[str, Any], incremental: dict[str, Any]) -> dict[str, Any]:
|
|
137
|
+
return {
|
|
138
|
+
"input_tokens": int(total.get("input_tokens") or 0) + int(incremental.get("input_tokens") or 0),
|
|
139
|
+
"output_tokens": int(total.get("output_tokens") or 0) + int(incremental.get("output_tokens") or 0),
|
|
140
|
+
"total_tokens": int(total.get("total_tokens") or 0) + int(incremental.get("total_tokens") or 0),
|
|
141
|
+
"inference_time": float(total.get("inference_time") or 0.0)
|
|
142
|
+
+ float(incremental.get("inference_time") or 0.0),
|
|
143
|
+
}
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def _extract_tool_calls(metadata: dict[str, Any]) -> list[dict[str, Any]]:
    """Return the dict entries of ``metadata["trace"]["tool_calls"]``.

    Yields an empty list when there is no trace or ``tool_calls`` is not a
    list. Entries inside the list that are not dicts are dropped.
    """
    calls = _to_dict(metadata.get("trace")).get("tool_calls")
    if isinstance(calls, list):
        return list(filter(lambda entry: isinstance(entry, dict), calls))
    return []
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def _tool_call_key(tool_call: dict[str, Any]) -> str:
|
|
155
|
+
call_id = tool_call.get("id")
|
|
156
|
+
if isinstance(call_id, str) and call_id:
|
|
157
|
+
return f"id:{call_id}"
|
|
158
|
+
return json.dumps(tool_call, sort_keys=True, ensure_ascii=True, default=str)
|
|
159
|
+
|
|
160
|
+
|
|
161
|
+
class BatA2AAdapter:
    """Drive an A2A agent through a :class:`TaskSpec` and record the episode.

    All turns of a task go through one streaming A2A client. Each is sent as
    a user text message, and all of them share ``thread_id`` as their
    ``context_id``. Streamed chunks are folded into an :class:`EpisodeTrace`
    (status events, de-duplicated token usage and tool calls). The last
    terminal status, with its content, becomes the :class:`EpisodeResult`.
    """

    def __init__(
        self,
        agent_url: str,
        request_timeout_s: float = 180.0,
        max_events: int = 200,
    ) -> None:
        # Base URL the agent card is resolved from.
        self.agent_url = agent_url
        # Timeout in seconds handed to the httpx.AsyncClient used for all requests.
        self.request_timeout_s = request_timeout_s
        # Maximum number of TraceEvents kept per episode; later events are dropped.
        self.max_events = max_events

    async def run_task(self, task: TaskSpec, *, thread_id: str) -> EpisodeResult:
        """Send each turn of *task* to the agent and collect the resulting episode.

        Args:
            task: Task whose ``turns`` are sent, in order, as user messages.
            thread_id: A2A ``context_id`` shared by every turn of this episode.

        Returns:
            An EpisodeResult whose ``final_status`` and ``final_output`` come
            from the last terminal status observed. If that terminal chunk
            carries no text, the latest non-empty terminal text of the same
            turn is kept, e.g. an artifact followed by a bare "completed"
            status update. Exceptions raised while sending or streaming are
            captured as status ``"error"`` with ``"<ExceptionType>: <message>"``
            as output. An episode that never reaches a terminal status is also
            reported as ``"error"``.

        Raises:
            Whatever fetching the agent card or building the client raises;
            those steps run outside the error capture.
        """
        t0_perf = time.perf_counter()
        trace = EpisodeTrace()

        usage_total: dict[str, Any] = {
            "input_tokens": 0,
            "output_tokens": 0,
            "total_tokens": 0,
            "inference_time": 0.0,
        }
        # Keys of usage reports and tool calls already counted, so chunks that
        # repeat the same metadata are not double-counted.
        usage_seen: set[str] = set()
        tool_calls_seen: set[str] = set()

        last_status: str | None = None
        last_content = ""

        async with AsyncClient(timeout=self.request_timeout_s) as httpx_client:
            resolver = A2ACardResolver(httpx_client=httpx_client, base_url=self.agent_url)
            agent_card = await resolver.get_agent_card()

            client = ClientFactory(
                ClientConfig(httpx_client=httpx_client, streaming=True)
            ).create(card=agent_card)

            try:
                for turn in task.turns:
                    turn_started = False
                    # Latest non-empty text seen on a terminal chunk of this turn.
                    # Agents often stream the answer as an artifact and then close
                    # the task with a status update that carries no message; keeping
                    # this prevents that empty closing update from erasing the answer.
                    turn_content = ""
                    message = new_text_message(
                        text=turn,
                        context_id=thread_id,
                        role=Role.ROLE_USER,
                    )
                    stream = client.send_message(SendMessageRequest(message=message))

                    async for chunk in stream:
                        metadata = _extract_metadata(chunk)

                        usage = _normalize_usage(metadata)
                        if usage:
                            usage_json = json.dumps(usage, sort_keys=True, ensure_ascii=True)
                            usage_key = f"{_chunk_key(chunk)}::{usage_json}"
                            if usage_key not in usage_seen:
                                usage_seen.add(usage_key)
                                usage_total = _add_usage(usage_total, usage)

                        for tool_call in _extract_tool_calls(metadata):
                            key = _tool_call_key(tool_call)
                            if key in tool_calls_seen:
                                continue
                            tool_calls_seen.add(key)
                            trace.tool_calls.append(tool_call)

                        status, content = _extract_status_and_content(chunk)
                        if status is None:
                            continue

                        if len(trace.events) < self.max_events:
                            trace.events.append(
                                TraceEvent(
                                    t_ms=(time.perf_counter() - t0_perf) * 1000.0,
                                    task_status=status,
                                    content_preview=content,
                                    # Only the turn's first status-bearing chunk carries the input.
                                    user_input=turn if not turn_started else None,
                                )
                            )
                        turn_started = True

                        if status in TERMINAL_STATUSES:
                            last_status = status
                            if content:
                                turn_content = content
                            last_content = turn_content

            except Exception as exc:
                # Best effort: surface transport/agent failures as an "error" episode.
                last_status = "error"
                last_content = f"{type(exc).__name__}: {exc}"

        trace.timings["wall_ms"] = (time.perf_counter() - t0_perf) * 1000.0
        trace.usage = usage_total

        final_status = last_status or "error"
        final_output = last_content or ""

        return EpisodeResult(
            task_id=task.id,
            final_status=final_status,
            final_output=final_output,
            trace=trace,
            aux={"agent_url": self.agent_url},
        )
|
|
eval/engine/bench_runner.py
ADDED
|
@@ -0,0 +1,149 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
import json
|
|
4
|
+
import time
|
|
5
|
+
from dataclasses import dataclass
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
|
|
8
|
+
from .contracts import EpisodeResult, TaskSpec
|
|
9
|
+
from .evaluator import EpisodeEvaluator
|
|
10
|
+
|
|
11
|
+
|
|
12
|
+
_QUALITATIVE_FIELDS = (
|
|
13
|
+
"response_relevance",
|
|
14
|
+
"task_completion_quality",
|
|
15
|
+
"hallucination_score",
|
|
16
|
+
"tool_call_appropriateness",
|
|
17
|
+
)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def _episode_passed(ep: EpisodeResult) -> bool:
|
|
21
|
+
return ep.verdict.passed if ep.verdict is not None else False
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
def _average_qualitative_scores(results: list[EpisodeResult]) -> dict[str, float]:
    """Average each qualitative score across *results*.

    Episodes without ``qualitative_scores``, and individual scores that are
    ``None``, are ignored. A field with no usable values is left out of the
    returned mapping rather than reported as zero.
    """
    scored = [r.qualitative_scores for r in results if r.qualitative_scores is not None]
    averages: dict[str, float] = {}
    for name in _QUALITATIVE_FIELDS:
        present = [value for value in (getattr(s, name) for s in scored) if value is not None]
        if present:
            averages[name] = sum(present) / len(present)
    return averages
|
|
36
|
+
|
|
37
|
+
|
|
38
|
+
def _safe_task_id(task_id: str) -> str:
|
|
39
|
+
return "".join(ch if ch.isalnum() or ch in ("-", "_", ".") else "_" for ch in task_id)
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def _write_json(path: Path, obj: object) -> None:
|
|
43
|
+
path.write_text(json.dumps(obj, indent=2, ensure_ascii=False), encoding="utf-8")
|
|
44
|
+
|
|
45
|
+
|
|
46
|
+
@dataclass
class RunConfig:
    """Configuration for a single BenchRunner run."""

    # Label for the run; combined with the model name to form the run directory name.
    run_name: str
    # Root directory under which run outputs are created.
    out_dir: str = "output"
    # Attempts per task; BenchRunner.run treats values below 1 as 1.
    k: int = 1
    # Model identifier stored on each episode; ':' becomes '-' in the run directory name.
    model: str = "default"
    # Name of the output subdirectory under out_dir; empty means a UTC timestamp is used.
    task_id: str = ""
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
class BenchRunner:
    """Run each task ``k`` times through an adapter, evaluate, and persist the results.

    :meth:`run` creates the output layout::

        <out_dir>/<task_id or UTC timestamp>/<run_name>_<model, ':' replaced by '-'>/
            episodes/<sanitized task id>__try<attempt>.json

    :meth:`write_summary` then adds ``summary.json`` to the same run directory.
    """

    def __init__(
        self,
        adapter: object,
        config: RunConfig,
        evaluator: EpisodeEvaluator | None = None,
    ):
        # Must provide ``async run_task(task=..., thread_id=...)`` returning an EpisodeResult.
        self.adapter = adapter
        self.config = config
        self.evaluator = evaluator or EpisodeEvaluator()
        # Output directories; assigned by run().
        self.task_dir: Path | None = None
        self.run_dir: Path | None = None
        # UTC timestamp ("%Y%m%d_%H%M%S") of the latest run(); "" until run() is called.
        self._run_timestamp: str = ""

    def _attempts_per_task(self) -> int:
        """Return how many attempts are made per task: ``config.k``, but never fewer than 1."""
        return max(1, int(self.config.k))

    def _episodes_dir(self) -> Path:
        """Return ``<run_dir>/episodes``.

        Raises:
            ValueError: if ``run_dir`` has not been set by :meth:`run`.
        """
        if self.run_dir is None:
            raise ValueError("run_dir is not initialized")
        return self.run_dir / "episodes"

    def persist_results(self, results: list[EpisodeResult]) -> None:
        """Write each episode as JSON to ``episodes/<sanitized task id>__try<attempt>.json``.

        The attempt index is read from ``episode.aux["attempt_index"]`` (0 if
        absent). Existing files with the same name are overwritten.

        Raises:
            ValueError: if ``run_dir`` has not been set by :meth:`run`.
        """
        episodes_dir = self._episodes_dir()
        episodes_dir.mkdir(parents=True, exist_ok=True)

        for episode in results:
            task_file_id = _safe_task_id(episode.task_id)
            attempt_index = int(episode.aux.get("attempt_index", 0))
            (episodes_dir / f"{task_file_id}__try{attempt_index}.json").write_text(
                episode.model_dump_json(indent=2),
                encoding="utf-8",
            )

    async def run(self, tasks: list[TaskSpec]) -> list[EpisodeResult]:
        """Run every task ``k`` times (at least once), evaluate and persist each attempt.

        Each attempt gets thread id ``"<task.id>__try<i>"``. Its verdict,
        expected outcome, model name and ``aux["attempt_index"]`` are filled in.
        The episode is written to disk as soon as it has been evaluated, so
        attempts that finished before a later failure are kept.

        Returns:
            All attempts, in task order and then attempt order.
        """
        stamp = time.strftime("%Y%m%d_%H%M%S", time.gmtime())
        task_id = self.config.task_id or stamp
        safe_model_name = self.config.model.replace(":", "-")

        self.task_dir = Path(self.config.out_dir) / task_id
        self.run_dir = self.task_dir / f"{self.config.run_name}_{safe_model_name}"
        self._episodes_dir().mkdir(parents=True, exist_ok=True)
        self._run_timestamp = stamp

        attempts_per_task = self._attempts_per_task()
        all_attempts: list[EpisodeResult] = []

        for task in tasks:
            for i in range(attempts_per_task):
                thread_id = f"{task.id}__try{i}"
                episode = await self.adapter.run_task(task=task, thread_id=thread_id)
                episode.verdict = self.evaluator.evaluate(
                    episode.final_status, episode.final_output, episode.trace.tool_calls, task.expected
                )
                episode.expected_outcome = task.expected.expected_outcome
                episode.model_name = self.config.model
                episode.aux["attempt_index"] = i
                all_attempts.append(episode)
                # Persist immediately so an error later in the run does not
                # discard episodes that already finished.
                self.persist_results([episode])

        return all_attempts

    def write_summary(self, results: list[EpisodeResult]) -> None:
        """Write ``summary.json`` with per-task pass counts and averaged qualitative scores.

        ``"k"`` in the summary is the number of attempts actually made per
        task, i.e. ``max(1, int(config.k))``.

        Raises:
            ValueError: if :meth:`run` has not been called first.
        """
        if self.run_dir is None:
            raise ValueError("run_dir is not initialized; call run() first")

        # Keep only the part after the last ':' (presumably a provider prefix is dropped).
        display_model_name = self.config.model.split(":")[-1]

        attempts_by_task: dict[str, list[EpisodeResult]] = {}
        for attempt in results:
            attempts_by_task.setdefault(attempt.task_id, []).append(attempt)

        attempts = []
        for task_id, task_attempts in attempts_by_task.items():
            total = len(task_attempts)
            passed = sum(1 for ep in task_attempts if _episode_passed(ep))
            attempts.append(
                {
                    "task_id": task_id,
                    "attempts": total,
                    "passed": passed,
                    "failed": total - passed,
                    "success_percentage": (passed / total) * 100.0 if total else 0.0,
                }
            )

        _write_json(
            self.run_dir / "summary.json",
            {
                "run_name": self.config.run_name,
                "timestamp_utc": self._run_timestamp,
                "k": self._attempts_per_task(),
                "model_name": display_model_name,
                "attempts": attempts,
                "qualitative_scores": _average_qualitative_scores(results),
                "passed": sum(1 for ep in results if _episode_passed(ep)),
            },
        )
|
eval/engine/contracts.py
ADDED
|
@@ -0,0 +1,115 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
from pathlib import Path
|
|
3
|
+
from typing import Any, Literal
|
|
4
|
+
|
|
5
|
+
from pydantic import BaseModel, ConfigDict, Field
|
|
6
|
+
|
|
7
|
+
|
|
8
|
+
# Normalized task status used throughout the eval engine; the A2A adapter maps
# protobuf task states onto exactly these four values.
AgentTaskStatus = Literal["working", "input-required", "completed", "error"]
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class ExpectedToolCall(BaseModel):
    """One tool call a task expects the agent to make.

    NOTE(review): the matching rules live in the evaluator, which is not part
    of this module; the field comments below describe the apparent intent.
    Confirm them there.
    """

    # Unknown keys in task definitions are silently dropped.
    model_config = ConfigDict(extra="ignore")

    # Name of the tool that should be called.
    name: str
    # Presumably arguments that the recorded call must contain (subset match); confirm in evaluator.
    args_subset: dict[str, Any] = Field(default_factory=dict)
    # Presumably how many matching calls are expected (default 1); confirm in evaluator.
    times: int = 1
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class TaskExpected(BaseModel):
    """Expectations that an episode of a task is checked against."""

    # Unknown keys in task definitions are silently dropped.
    model_config = ConfigDict(extra="ignore")

    # Expected final status from the agent. None = skip status check entirely.
    status: AgentTaskStatus | None = "completed"
    # Free-text description of the desired outcome, evaluated semantically by the LLM judges.
    expected_outcome: str | None = None
    # All phrases must appear in the final output text. None/empty = skip substring check.
    output_must_contain: list[str] | None = None
    # Tool calls the agent is expected to make; the evaluator receives these together
    # with the tool calls recorded in the episode trace.
    tool_calls: list[ExpectedToolCall] = Field(default_factory=list)
|
|
29
|
+
|
|
30
|
+
|
|
31
|
+
class TaskSpec(BaseModel):
    """A single benchmark task: its id, the user turns to send, and what to expect."""

    # Unknown keys in task definitions are silently dropped.
    model_config = ConfigDict(extra="ignore")

    # Task identifier; also used (sanitized) in episode file names and in thread ids.
    id: str
    # User messages sent to the agent in order, all within the same A2A context.
    turns: list[str]
    # Checks applied to the episode; by default only a "completed" final status is expected.
    expected: TaskExpected = Field(default_factory=TaskExpected)
    # Free-form extra data. NOTE(review): not read by the adapter or bench runner; confirm other consumers.
    meta: dict[str, Any] = Field(default_factory=dict)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
class TraceEvent(BaseModel):
    """One status-bearing chunk observed while streaming an episode."""

    model_config = ConfigDict(extra="ignore")

    # Milliseconds elapsed since the adapter started the episode.
    t_ms: float
    # Normalized status derived from the chunk.
    task_status: AgentTaskStatus
    # Text extracted from the chunk; despite the name, the adapter stores it untruncated.
    content_preview: str
    # The user turn that produced this event; only set on the first event of each turn.
    user_input: str | None = None
|
|
47
|
+
|
|
48
|
+
|
|
49
|
+
class EpisodeTrace(BaseModel):
    """Execution trace collected while running one episode."""

    model_config = ConfigDict(extra="ignore")

    # Status events in arrival order; the adapter caps their number (max_events).
    events: list[TraceEvent] = Field(default_factory=list)
    # Aggregated usage; the adapter fills input_tokens, output_tokens, total_tokens, inference_time.
    usage: dict[str, Any] = Field(default_factory=dict)
    # Named timing measurements; the adapter records "wall_ms" (milliseconds).
    timings: dict[str, float] = Field(default_factory=dict)
    # De-duplicated tool-call dicts reported in the agent's trace metadata.
    tool_calls: list[dict[str, Any]] = Field(default_factory=list)
|
|
56
|
+
|
|
57
|
+
|
|
58
|
+
class QualitativeScores(BaseModel):
    """Per-episode qualitative scores; ``None`` means the metric was not scored.

    NOTE(review): these are presumably filled in by the LLM judges; confirm the
    score scale and direction (e.g. whether a higher hallucination_score is
    better) there.
    """

    model_config = ConfigDict(extra="ignore")

    response_relevance: float | None = None
    task_completion_quality: float | None = None
    hallucination_score: float | None = None
    tool_call_appropriateness: float | None = None
    # Judge explanations, presumably keyed by metric name; confirm with the judge code.
    judge_reasoning: dict[str, str] = Field(default_factory=dict)
|
|
66
|
+
|
|
67
|
+
|
|
68
|
+
class EpisodeVerdict(BaseModel):
    """Pass/fail outcome of evaluating one episode against its task's expectations."""

    model_config = ConfigDict(extra="ignore")

    # Whether the episode satisfied the task's expectations.
    passed: bool
    # Explanation accompanying the verdict; empty by default.
    reason: str = ""
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
class EpisodeResult(BaseModel):
    """Outcome of one attempt at a task; persisted as one JSON file per attempt."""

    model_config = ConfigDict(extra="ignore")

    # Model identifier of the run; set by the bench runner after the episode finishes.
    model_name: str | None = None
    # Id of the TaskSpec this episode ran.
    task_id: str
    # Copy of the task's expected_outcome, set by the bench runner.
    expected_outcome: str | None = None
    # Final normalized status; the adapter reports "error" if no terminal status was seen.
    final_status: AgentTaskStatus
    # Final text output of the agent (or the captured exception text on error).
    final_output: str
    # Evaluation verdict; None until the episode has been evaluated.
    verdict: EpisodeVerdict | None = None
    # Qualitative judge scores; None unless qualitative scoring populated them.
    qualitative_scores: QualitativeScores | None = None
    # Extra data, e.g. "agent_url" (adapter) and "attempt_index" (bench runner).
    aux: dict[str, Any] = Field(default_factory=dict)
    # Detailed execution trace of the episode.
    trace: EpisodeTrace = Field(default_factory=EpisodeTrace)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class ModelSpec(BaseModel):
    """One model configuration to evaluate the agent with (an entry of EvalConfig.models).

    NOTE(review): how these fields are consumed is defined outside this module;
    the comments below are inferred from field names. Confirm them.
    """

    # Presumably the LLM provider name.
    provider: str
    # Presumably the model identifier passed to the provider.
    model: str
    # Optional custom endpoint URL.
    base_url: str | None = None
    # Presumably extra environment variables for the agent process.
    env: dict[str, str] = Field(default_factory=dict)
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
class JudgeSpec(BaseModel):
    """Configuration of the LLM judge used for qualitative scoring.

    NOTE(review): how these fields are consumed is defined outside this module;
    the comments below are inferred from field names. Confirm them.
    """

    # Presumably the judge's LLM provider name.
    provider: str
    # Presumably the judge model identifier.
    model: str
    # Optional custom endpoint URL.
    base_url: str | None = None
    # Presumably the name of the environment variable holding the judge's API key.
    api_key_env: str | None = None
    # Presumably extra environment variables applied when running the judge.
    env: dict[str, str] = Field(default_factory=dict)
    # Presumably prompt overrides, keyed by metric name.
    prompts: dict[str, str] = Field(default_factory=dict)
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
class EvalConfig(BaseModel):
    """Top-level configuration for an evaluation run.

    ``judge`` defaults to ``None``: configurations that do not use qualitative
    (LLM-judge) scoring can omit it instead of spelling out an explicit null.
    """

    # Location of the task dataset.
    dataset: Path
    # Directory where run results are written.
    output_dir: Path
    # URL of the A2A agent under test.
    agent_url: str
    # Seconds to wait for the agent to start / stop (presumably used by the orchestrator; confirm).
    agent_startup_timeout_s: int = 45
    agent_shutdown_timeout_s: int = 10
    # Attempts per task.
    k: int
    # Whether qualitative (LLM-judge) scoring is enabled.
    qualitative: bool
    # Label used for output naming.
    run_name: str
    # Models to evaluate the agent with.
    models: list[ModelSpec]
    # Judge configuration; presumably only needed when qualitative is true.
    judge: JudgeSpec | None = None
|