bat-cli 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47) hide show
  1. add/__init__.py +3 -0
  2. add/client.py +16 -0
  3. bat_cli-0.1.0.dist-info/METADATA +231 -0
  4. bat_cli-0.1.0.dist-info/RECORD +47 -0
  5. bat_cli-0.1.0.dist-info/WHEEL +5 -0
  6. bat_cli-0.1.0.dist-info/entry_points.txt +2 -0
  7. bat_cli-0.1.0.dist-info/top_level.txt +8 -0
  8. build/__init__.py +3 -0
  9. build/build.py +79 -0
  10. cli.py +260 -0
  11. create/__init__.py +3 -0
  12. create/agent.py +312 -0
  13. create/templates/agent/.dockerignore +3 -0
  14. create/templates/agent/.env.template +4 -0
  15. create/templates/agent/.python-version +1 -0
  16. create/templates/agent/Dockerfile +37 -0
  17. create/templates/agent/Makefile +34 -0
  18. create/templates/agent/README.md +1 -0
  19. create/templates/agent/__main__.py +2 -0
  20. create/templates/agent/agent.json.template +12 -0
  21. create/templates/agent/agent.spec +45 -0
  22. create/templates/agent/config.yaml +1 -0
  23. create/templates/agent/llm_client.py.template +36 -0
  24. create/templates/agent/pyproject.toml.template +9 -0
  25. create/templates/agent/src/__init__.py +0 -0
  26. create/templates/agent/src/graph.py +50 -0
  27. create/templates/agent/src/llm_clients/__init__.py +0 -0
  28. create/templates/agent/tests/__init__.py +0 -0
  29. eval/__init__.py +1 -0
  30. eval/commands.py +562 -0
  31. eval/engine/__init__.py +1 -0
  32. eval/engine/adapter.py +251 -0
  33. eval/engine/bench_runner.py +149 -0
  34. eval/engine/contracts.py +115 -0
  35. eval/engine/eval_config.py +294 -0
  36. eval/engine/evaluator.py +85 -0
  37. eval/engine/metrics/__init__.py +1 -0
  38. eval/engine/metrics/llm_evaluators.py +383 -0
  39. eval/engine/metrics/metrics.py +135 -0
  40. eval/engine/metrics/qualitative_helpers.py +64 -0
  41. eval/engine/orchestrator.py +157 -0
  42. eval/engine/plotter.py +347 -0
  43. image_defaults.py +80 -0
  44. push/__init__.py +3 -0
  45. push/push.py +58 -0
  46. set/__init__.py +3 -0
  47. set/env.py +50 -0
eval/engine/adapter.py ADDED
@@ -0,0 +1,251 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import time
5
+ from typing import Any
6
+
7
+ from a2a.client import A2ACardResolver, ClientConfig, ClientFactory
8
+ from a2a.helpers import (
9
+ get_artifact_text,
10
+ get_message_text,
11
+ new_text_message,
12
+ )
13
+ from a2a.types import Role, SendMessageRequest, StreamResponse, TaskState
14
+ from google.protobuf.json_format import MessageToDict
15
+ from httpx import AsyncClient
16
+
17
+ from .contracts import EpisodeResult, EpisodeTrace, TaskSpec, TraceEvent
18
+
19
+
20
# Normalized statuses that BatA2AAdapter.run_task records as an episode's
# latest outcome (status + content) when they show up in the stream.
TERMINAL_STATUSES = {"completed", "error", "input-required"}


# Collapses A2A protobuf task states into the normalized status strings used in
# traces: submitted/working -> "working"; failed/canceled/rejected -> "error".
_TASK_STATE_TO_STR = {
    **dict.fromkeys(
        (TaskState.TASK_STATE_SUBMITTED, TaskState.TASK_STATE_WORKING),
        "working",
    ),
    TaskState.TASK_STATE_INPUT_REQUIRED: "input-required",
    TaskState.TASK_STATE_COMPLETED: "completed",
    **dict.fromkeys(
        (
            TaskState.TASK_STATE_FAILED,
            TaskState.TASK_STATE_CANCELED,
            TaskState.TASK_STATE_REJECTED,
        ),
        "error",
    ),
}
32
+
33
+
34
+ def _to_dict(value: Any) -> dict[str, Any]:
35
+ if value is None:
36
+ return {}
37
+ if isinstance(value, dict):
38
+ return value
39
+ if hasattr(value, "model_dump"):
40
+ dumped = value.model_dump(by_alias=True)
41
+ return dumped if isinstance(dumped, dict) else {}
42
+ return {}
43
+
44
+
45
+ def _struct_to_dict(metadata_struct: Any) -> dict[str, Any]:
46
+ if metadata_struct is None:
47
+ return {}
48
+ try:
49
+ return MessageToDict(metadata_struct) or {}
50
+ except Exception:
51
+ return {}
52
+
53
+
54
+ def _chunk_key(chunk: StreamResponse) -> str:
55
+ if chunk.HasField("status_update"):
56
+ message_id = chunk.status_update.status.message.message_id
57
+ if message_id:
58
+ return f"status:{message_id}"
59
+ if chunk.HasField("artifact_update"):
60
+ artifact_id = chunk.artifact_update.artifact.artifact_id
61
+ if artifact_id:
62
+ return f"artifact:{artifact_id}"
63
+ if chunk.HasField("message"):
64
+ if chunk.message.message_id:
65
+ return f"message:{chunk.message.message_id}"
66
+ if chunk.HasField("task"):
67
+ if chunk.task.id:
68
+ return f"task:{chunk.task.id}"
69
+ return f"raw:{chunk.SerializeToString().hex()}"
70
+
71
+
72
+ def _extract_status_and_content(chunk: StreamResponse) -> tuple[str | None, str]:
73
+ if chunk.HasField("message"):
74
+ return "completed", get_message_text(chunk.message)
75
+ if chunk.HasField("artifact_update"):
76
+ return "completed", get_artifact_text(chunk.artifact_update.artifact)
77
+ if chunk.HasField("status_update"):
78
+ state = chunk.status_update.status.state
79
+ status = _TASK_STATE_TO_STR.get(state)
80
+ content = ""
81
+ if chunk.status_update.status.HasField("message"):
82
+ content = get_message_text(chunk.status_update.status.message)
83
+ return status, content
84
+ if chunk.HasField("task"):
85
+ state = chunk.task.status.state
86
+ status = _TASK_STATE_TO_STR.get(state)
87
+ texts = [get_artifact_text(a) for a in chunk.task.artifacts]
88
+ return status, "\n".join(t for t in texts if t)
89
+ return None, ""
90
+
91
+
92
def _extract_metadata(chunk: StreamResponse) -> dict[str, Any]:
    """Return the metadata attached to *chunk* as a plain dict.

    The metadata of the first payload field that is set (status update,
    artifact update, message, task) is converted with ``_struct_to_dict``. For
    artifact updates the artifact's own metadata is merged in underneath, so
    keys from the update-level metadata win on conflict.
    """
    raw_metadata: Any = None
    for field_name in ("status_update", "artifact_update", "message", "task"):
        if chunk.HasField(field_name):
            raw_metadata = getattr(chunk, field_name).metadata
            break

    metadata = _struct_to_dict(raw_metadata)

    if chunk.HasField("artifact_update"):
        artifact_level = _struct_to_dict(chunk.artifact_update.artifact.metadata)
        if artifact_level:
            # Artifact metadata is the base layer; update-level keys override it.
            metadata = {**artifact_level, **metadata}

    return metadata
113
+
114
+
115
def _normalize_usage(metadata: dict[str, Any]) -> dict[str, Any]:
    """Extract token/latency usage from ``metadata["usage"]``.

    Returns a dict with ``input_tokens``, ``output_tokens``, ``total_tokens``
    and ``inference_time``. Missing or falsy values count as zero, and a
    missing/zero total falls back to input + output. Returns ``{}`` when there
    is no usage entry or every normalized value is zero.

    ``int()``/``float()`` conversion errors are not caught here and propagate
    to the caller.
    """
    raw_usage = _to_dict(metadata.get("usage"))
    if not raw_usage:
        return {}

    prompt_tokens = int(raw_usage.get("input_tokens") or 0)
    completion_tokens = int(raw_usage.get("output_tokens") or 0)
    normalized = {
        "input_tokens": prompt_tokens,
        "output_tokens": completion_tokens,
        "total_tokens": int(raw_usage.get("total_tokens") or (prompt_tokens + completion_tokens)),
        "inference_time": float(raw_usage.get("inference_time") or 0.0),
    }

    # An all-zero record carries no information; report it as "no usage".
    return normalized if any(normalized.values()) else {}
134
+
135
+
136
+ def _add_usage(total: dict[str, Any], incremental: dict[str, Any]) -> dict[str, Any]:
137
+ return {
138
+ "input_tokens": int(total.get("input_tokens") or 0) + int(incremental.get("input_tokens") or 0),
139
+ "output_tokens": int(total.get("output_tokens") or 0) + int(incremental.get("output_tokens") or 0),
140
+ "total_tokens": int(total.get("total_tokens") or 0) + int(incremental.get("total_tokens") or 0),
141
+ "inference_time": float(total.get("inference_time") or 0.0)
142
+ + float(incremental.get("inference_time") or 0.0),
143
+ }
144
+
145
+
146
def _extract_tool_calls(metadata: dict[str, Any]) -> list[dict[str, Any]]:
    """Return the dict entries of ``metadata["trace"]["tool_calls"]``.

    Yields an empty list when the trace is missing or ``tool_calls`` is not a
    list; non-dict items inside the list are dropped.
    """
    candidates = _to_dict(metadata.get("trace")).get("tool_calls")
    if isinstance(candidates, list):
        return [entry for entry in candidates if isinstance(entry, dict)]
    return []
152
+
153
+
154
+ def _tool_call_key(tool_call: dict[str, Any]) -> str:
155
+ call_id = tool_call.get("id")
156
+ if isinstance(call_id, str) and call_id:
157
+ return f"id:{call_id}"
158
+ return json.dumps(tool_call, sort_keys=True, ensure_ascii=True, default=str)
159
+
160
+
161
class BatA2AAdapter:
    """Runs a benchmark task against an agent over A2A streaming.

    The constructor stores its arguments verbatim: ``agent_url`` is the base
    URL used to resolve the agent card, ``request_timeout_s`` is passed as the
    httpx client timeout, and ``max_events`` caps how many trace events are
    recorded per episode.
    """

    def __init__(self, agent_url: str, request_timeout_s: float = 180.0, max_events: int = 200) -> None:
        self.agent_url = agent_url
        self.request_timeout_s = request_timeout_s
        self.max_events = max_events

    async def run_task(self, task: TaskSpec, *, thread_id: str) -> EpisodeResult:
        """Send every turn of *task* to the agent and collect the outcome.

        Each turn is sent as a user text message with ``thread_id`` as the
        context id, and the streamed chunks are folded into an ``EpisodeTrace``
        (events, de-duplicated usage totals, de-duplicated tool calls).

        The final status/output are those of the last chunk whose status is in
        ``TERMINAL_STATUSES``; with none seen the status defaults to
        ``"error"``. An exception raised while sending or streaming turns is
        caught and reported as an ``"error"`` episode whose output is
        ``"<ExceptionType>: <message>"``. Resolving the agent card and creating
        the client happen before that ``try`` block, so failures there
        propagate to the caller.
        """
        started_at = time.perf_counter()
        trace = EpisodeTrace()

        usage_total: dict[str, Any] = {
            "input_tokens": 0,
            "output_tokens": 0,
            "total_tokens": 0,
            "inference_time": 0.0,
        }
        seen_usage_keys: set[str] = set()
        seen_tool_call_keys: set[str] = set()

        last_status: str | None = None
        last_content = ""

        async with AsyncClient(timeout=self.request_timeout_s) as httpx_client:
            resolver = A2ACardResolver(httpx_client=httpx_client, base_url=self.agent_url)
            agent_card = await resolver.get_agent_card()

            client_factory = ClientFactory(ClientConfig(httpx_client=httpx_client, streaming=True))
            client = client_factory.create(card=agent_card)

            try:
                for turn in task.turns:
                    # Only the first recorded event of a turn carries the user input.
                    awaiting_first_event = True
                    outgoing = new_text_message(
                        text=turn,
                        context_id=thread_id,
                        role=Role.ROLE_USER,
                    )
                    stream = client.send_message(SendMessageRequest(message=outgoing))

                    async for chunk in stream:
                        metadata = _extract_metadata(chunk)

                        usage = _normalize_usage(metadata)
                        if usage:
                            # Key on chunk identity + usage payload so a repeated
                            # chunk is not counted twice.
                            usage_fingerprint = json.dumps(usage, sort_keys=True, ensure_ascii=True)
                            usage_key = f"{_chunk_key(chunk)}::{usage_fingerprint}"
                            if usage_key not in seen_usage_keys:
                                seen_usage_keys.add(usage_key)
                                usage_total = _add_usage(usage_total, usage)

                        for tool_call in _extract_tool_calls(metadata):
                            call_key = _tool_call_key(tool_call)
                            if call_key not in seen_tool_call_keys:
                                seen_tool_call_keys.add(call_key)
                                trace.tool_calls.append(tool_call)

                        status, content = _extract_status_and_content(chunk)
                        if status is None:
                            continue

                        if len(trace.events) < self.max_events:
                            elapsed_ms = (time.perf_counter() - started_at) * 1000.0
                            trace.events.append(
                                TraceEvent(
                                    t_ms=elapsed_ms,
                                    task_status=status,
                                    content_preview=content,
                                    user_input=turn if awaiting_first_event else None,
                                )
                            )
                            awaiting_first_event = False

                        if status in TERMINAL_STATUSES:
                            last_status = status
                            last_content = content or ""

            except Exception as exc:
                last_status = "error"
                last_content = f"{type(exc).__name__}: {exc}"

        trace.timings["wall_ms"] = (time.perf_counter() - started_at) * 1000.0
        trace.usage = usage_total

        return EpisodeResult(
            task_id=task.id,
            final_status=last_status or "error",
            final_output=last_content or "",
            trace=trace,
            aux={"agent_url": self.agent_url},
        )
eval/engine/bench_runner.py ADDED
@@ -0,0 +1,149 @@
1
+ from __future__ import annotations
2
+
3
+ import json
4
+ import time
5
+ from dataclasses import dataclass
6
+ from pathlib import Path
7
+
8
+ from .contracts import EpisodeResult, TaskSpec
9
+ from .evaluator import EpisodeEvaluator
10
+
11
+
12
+ _QUALITATIVE_FIELDS = (
13
+ "response_relevance",
14
+ "task_completion_quality",
15
+ "hallucination_score",
16
+ "tool_call_appropriateness",
17
+ )
18
+
19
+
20
+ def _episode_passed(ep: EpisodeResult) -> bool:
21
+ return ep.verdict.passed if ep.verdict is not None else False
22
+
23
+
24
def _average_qualitative_scores(results: list[EpisodeResult]) -> dict[str, float]:
    """Average each qualitative score field across *results*.

    Episodes without ``qualitative_scores`` and individual ``None`` scores are
    skipped. A field with no usable score is omitted from the result.
    """
    averages: dict[str, float] = {}
    for field_name in _QUALITATIVE_FIELDS:
        scores = []
        for episode in results:
            if episode.qualitative_scores is None:
                continue
            score = getattr(episode.qualitative_scores, field_name)
            if score is not None:
                scores.append(score)
        if scores:
            averages[field_name] = sum(scores) / len(scores)
    return averages
36
+
37
+
38
+ def _safe_task_id(task_id: str) -> str:
39
+ return "".join(ch if ch.isalnum() or ch in ("-", "_", ".") else "_" for ch in task_id)
40
+
41
+
42
+ def _write_json(path: Path, obj: object) -> None:
43
+ path.write_text(json.dumps(obj, indent=2, ensure_ascii=False), encoding="utf-8")
44
+
45
+
46
@dataclass
class RunConfig:
    """Settings for one benchmark run, consumed by ``BenchRunner``."""

    # Name of the run; used in the run directory name and in summary.json.
    run_name: str
    # Root directory under which the task/run directories are created.
    out_dir: str = "output"
    # Attempts per task; BenchRunner.run() runs at least one.
    k: int = 1
    # Model identifier; ":" is replaced with "-" when building the run directory name.
    model: str = "default"
    # Directory name under out_dir; when empty a UTC timestamp is used instead.
    task_id: str = ""
53
+
54
+
55
class BenchRunner:
    """Runs tasks through an adapter, evaluates each attempt and writes results to disk.

    Output layout: ``<out_dir>/<task_id or UTC stamp>/<run_name>_<model>/``
    containing ``episodes/<task>__try<i>.json`` and, after
    ``write_summary()``, ``summary.json``.
    """

    def __init__(
        self,
        adapter: object,
        config: RunConfig,
        evaluator: EpisodeEvaluator | None = None,
    ):
        """Store collaborators; a default ``EpisodeEvaluator`` is used when none is given."""
        self.adapter = adapter
        self.config = config
        self.evaluator = evaluator or EpisodeEvaluator()
        # Both directories are assigned by run().
        self.task_dir: Path | None = None
        self.run_dir: Path | None = None
        # Assigned by run(); declared here so write_summary() can read it
        # directly even if run_dir was set without calling run().
        self._run_timestamp: str = ""

    def _episodes_dir(self) -> Path:
        """Return ``run_dir / "episodes"``.

        Raises:
            ValueError: if ``run_dir`` has not been set yet.
        """
        if self.run_dir is None:
            raise ValueError("run_dir is not initialized")
        return self.run_dir / "episodes"

    def persist_results(self, results: list[EpisodeResult]) -> None:
        """Write each episode as ``<safe task id>__try<attempt>.json`` in the episodes dir.

        The attempt index is read from ``episode.aux["attempt_index"]`` and
        defaults to 0.

        Raises:
            ValueError: if ``run_dir`` has not been set yet.
        """
        episodes_dir = self._episodes_dir()
        episodes_dir.mkdir(parents=True, exist_ok=True)

        for episode in results:
            task_file_id = _safe_task_id(episode.task_id)
            attempt_index = int(episode.aux.get("attempt_index", 0))
            target = episodes_dir / f"{task_file_id}__try{attempt_index}.json"
            target.write_text(episode.model_dump_json(indent=2), encoding="utf-8")

    async def run(self, tasks: list[TaskSpec]) -> list[EpisodeResult]:
        """Run every task ``k`` times (at least once), evaluate and persist all attempts.

        Sets ``task_dir``, ``run_dir`` and the run timestamp, creates the
        episodes directory, then runs attempts sequentially. Each attempt uses
        the thread id ``"<task.id>__try<i>"``. Results are written to disk only
        after all attempts have finished.

        Returns:
            All attempts, in task order then attempt order.
        """
        stamp = time.strftime("%Y%m%d_%H%M%S", time.gmtime())
        task_id = self.config.task_id or stamp
        safe_model_name = self.config.model.replace(":", "-")

        self.task_dir = Path(self.config.out_dir) / task_id
        self.run_dir = self.task_dir / f"{self.config.run_name}_{safe_model_name}"
        self._episodes_dir().mkdir(parents=True, exist_ok=True)
        self._run_timestamp = stamp

        attempts_per_task = max(1, int(self.config.k))
        all_attempts: list[EpisodeResult] = []

        for task in tasks:
            for attempt_index in range(attempts_per_task):
                thread_id = f"{task.id}__try{attempt_index}"
                episode = await self.adapter.run_task(task=task, thread_id=thread_id)
                episode.verdict = self.evaluator.evaluate(
                    episode.final_status, episode.final_output, episode.trace.tool_calls, task.expected
                )
                episode.expected_outcome = task.expected.expected_outcome
                episode.model_name = self.config.model
                episode.aux["attempt_index"] = attempt_index
                all_attempts.append(episode)

        self.persist_results(all_attempts)
        return all_attempts

    def write_summary(self, results: list[EpisodeResult]) -> None:
        """Write ``summary.json`` with per-task pass counts and averaged qualitative scores.

        Raises:
            ValueError: if ``run_dir`` has not been set (``run()`` not called).
        """
        if self.run_dir is None:
            raise ValueError("run_dir is not initialized; call run() first")

        # Keep only the part after the last ":"; split() returns the whole
        # string when there is no ":", so no separate check is needed.
        display_model_name = self.config.model.split(":")[-1]

        attempts_by_task: dict[str, list[EpisodeResult]] = {}
        for attempt in results:
            attempts_by_task.setdefault(attempt.task_id, []).append(attempt)

        attempts = []
        for task_id, task_attempts in attempts_by_task.items():
            total = len(task_attempts)
            passed = sum(1 for ep in task_attempts if _episode_passed(ep))
            attempts.append(
                {
                    "task_id": task_id,
                    "attempts": total,
                    "passed": passed,
                    "failed": total - passed,
                    "success_percentage": (passed / total) * 100.0 if total else 0.0,
                }
            )

        _write_json(
            self.run_dir / "summary.json",
            {
                "run_name": self.config.run_name,
                "timestamp_utc": self._run_timestamp,
                "k": self.config.k,
                "model_name": display_model_name,
                "attempts": attempts,
                "qualitative_scores": _average_qualitative_scores(results),
                "passed": sum(1 for ep in results if _episode_passed(ep)),
            },
        )
+ )
eval/engine/contracts.py ADDED
@@ -0,0 +1,115 @@
1
+ from __future__ import annotations
2
+ from pathlib import Path
3
+ from typing import Any, Literal
4
+
5
+ from pydantic import BaseModel, ConfigDict, Field
6
+
7
+
8
# Closed set of normalized agent task statuses accepted by the models below.
AgentTaskStatus = Literal[
    "working",
    "input-required",
    "completed",
    "error",
]
9
+
10
+
11
class ExpectedToolCall(BaseModel):
    # One expected tool invocation in a task's expectations. Unknown input
    # keys are ignored.
    model_config = ConfigDict(extra="ignore")

    # Tool name.
    name: str
    # Arguments expected on the call; defaults to no constraint.
    # NOTE(review): presumably matched as a subset of the actual args -- confirm in the evaluator.
    args_subset: dict[str, Any] = Field(default_factory=dict)
    # Expected call count (default 1).
    # NOTE(review): exact vs. minimum semantics are defined by the evaluator -- confirm.
    times: int = 1
17
+
18
+
19
class TaskExpected(BaseModel):
    # Expectations an episode is evaluated against. Unknown input keys are ignored.
    model_config = ConfigDict(extra="ignore")

    # Expected final status from the agent. None = skip status check entirely.
    status: AgentTaskStatus | None = "completed"
    # Free-text description of the desired outcome, evaluated semantically by the LLM judges.
    expected_outcome: str | None = None
    # All phrases must appear in the final output text. None/empty = skip substring check.
    output_must_contain: list[str] | None = None
    # Expected tool invocations; empty by default.
    tool_calls: list[ExpectedToolCall] = Field(default_factory=list)
29
+
30
+
31
class TaskSpec(BaseModel):
    # A single benchmark task. Unknown input keys are ignored.
    model_config = ConfigDict(extra="ignore")

    # Task identifier; also used to build thread ids and episode file names.
    id: str
    # User messages, sent to the agent one turn at a time in order.
    turns: list[str]
    # Expectations for the episode; defaults to TaskExpected() (status "completed").
    expected: TaskExpected = Field(default_factory=TaskExpected)
    # Free-form extra data attached to the task.
    meta: dict[str, Any] = Field(default_factory=dict)
38
+
39
+
40
class TraceEvent(BaseModel):
    # One recorded status-bearing stream event. Unknown input keys are ignored.
    model_config = ConfigDict(extra="ignore")

    # Milliseconds elapsed since the episode started.
    t_ms: float
    # Normalized task status carried by the event.
    task_status: AgentTaskStatus
    # Text content extracted from the event (may be empty).
    content_preview: str
    # The user turn text, set only on the first recorded event of a turn.
    user_input: str | None = None
47
+
48
+
49
class EpisodeTrace(BaseModel):
    # Everything recorded while running one episode. Unknown input keys are ignored.
    model_config = ConfigDict(extra="ignore")

    # Status-bearing events in arrival order.
    events: list[TraceEvent] = Field(default_factory=list)
    # Accumulated usage totals (input/output/total tokens, inference time).
    usage: dict[str, Any] = Field(default_factory=dict)
    # Named timings in milliseconds, e.g. "wall_ms".
    timings: dict[str, float] = Field(default_factory=dict)
    # De-duplicated tool calls reported through chunk metadata.
    tool_calls: list[dict[str, Any]] = Field(default_factory=list)
56
+
57
+
58
class QualitativeScores(BaseModel):
    # Optional per-episode qualitative scores; a None score means "not scored".
    # Unknown input keys are ignored.
    # NOTE(review): score scale/range is not visible here -- confirm with the judge code.
    model_config = ConfigDict(extra="ignore")

    response_relevance: float | None = None
    task_completion_quality: float | None = None
    hallucination_score: float | None = None
    tool_call_appropriateness: float | None = None
    # Free-text reasoning strings.
    # NOTE(review): presumably keyed by score name -- confirm.
    judge_reasoning: dict[str, str] = Field(default_factory=dict)
66
+
67
+
68
class EpisodeVerdict(BaseModel):
    # Pass/fail outcome of evaluating one episode. Unknown input keys are ignored.
    model_config = ConfigDict(extra="ignore")

    # True when the episode met its expectations.
    passed: bool
    # Explanation accompanying the verdict; empty by default.
    reason: str = ""
73
+
74
+
75
class EpisodeResult(BaseModel):
    # Result of one attempt at one task: final status/output, optional verdict
    # and qualitative scores, plus the recorded trace. Unknown input keys are
    # ignored.
    #
    # protected_namespaces is cleared because the ``model_name`` field starts
    # with pydantic's protected "model_" prefix; pydantic v2 releases that
    # protect the whole prefix emit a namespace-conflict UserWarning for it
    # at class-definition time. Clearing it changes no field behavior.
    model_config = ConfigDict(extra="ignore", protected_namespaces=())

    # Model identifier the episode was run with (filled in by the runner).
    model_name: str | None = None
    # Id of the task this episode belongs to.
    task_id: str
    # Copy of the task's expected outcome description, if any.
    expected_outcome: str | None = None
    # Normalized final status reported for the episode.
    final_status: AgentTaskStatus
    # Final text output (or an error description).
    final_output: str
    # Set once the episode has been evaluated.
    verdict: EpisodeVerdict | None = None
    # Set when qualitative scoring has been performed.
    qualitative_scores: QualitativeScores | None = None
    # Auxiliary data such as "agent_url" and "attempt_index".
    aux: dict[str, Any] = Field(default_factory=dict)
    # Events, usage, timings and tool calls recorded during the run.
    trace: EpisodeTrace = Field(default_factory=EpisodeTrace)
87
+
88
+
89
class ModelSpec(BaseModel):
    # Describes one model entry of an evaluation config.
    # NOTE(review): how provider/base_url/env are consumed is not visible here -- confirm in eval_config/orchestrator.

    # Provider name.
    provider: str
    # Model identifier.
    model: str
    # Optional endpoint override.
    base_url: str | None = None
    # String key/value pairs; presumably environment variables for the agent process -- confirm.
    env: dict[str, str] = Field(default_factory=dict)
94
+
95
+
96
class JudgeSpec(BaseModel):
    # Describes the judge model entry of an evaluation config.
    # NOTE(review): consumption of these fields is not visible here -- confirm in the LLM evaluator code.

    # Provider name.
    provider: str
    # Model identifier.
    model: str
    # Optional endpoint override.
    base_url: str | None = None
    # Presumably the NAME of the environment variable holding the API key, not the key itself -- confirm.
    api_key_env: str | None = None
    # String key/value pairs; presumably extra environment variables -- confirm.
    env: dict[str, str] = Field(default_factory=dict)
    # Prompt texts by name; presumably per-metric overrides -- confirm.
    prompts: dict[str, str] = Field(default_factory=dict)
103
+
104
+
105
class EvalConfig(BaseModel):
    # Top-level evaluation configuration. Fields without a default are required.

    # Path to the task dataset.
    dataset: Path
    # Directory where results are written.
    output_dir: Path
    # Base URL of the agent under test.
    agent_url: str
    # Seconds allowed for agent startup / shutdown (defaults 45 and 10).
    agent_startup_timeout_s: int = 45
    agent_shutdown_timeout_s: int = 10
    # Attempts per task.
    k: int
    # Whether qualitative scoring is enabled.
    qualitative: bool
    # Name of the run.
    run_name: str
    # Models to evaluate.
    models: list[ModelSpec]
    # Judge configuration; has no default, so it must be supplied, but None is accepted.
    judge: JudgeSpec | None