agenteval-py 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agenteval/__init__.py +46 -0
- agenteval/adapters/__init__.py +9 -0
- agenteval/adapters/anthropic_adapter.py +80 -0
- agenteval/adapters/langchain_adapter.py +135 -0
- agenteval/adapters/openai_adapter.py +80 -0
- agenteval/assertions.py +289 -0
- agenteval/cli.py +93 -0
- agenteval/exceptions.py +17 -0
- agenteval/models.py +123 -0
- agenteval/py.typed +0 -0
- agenteval/registry.py +99 -0
- agenteval/reporter.py +139 -0
- agenteval/runner.py +119 -0
- agenteval/suite.py +181 -0
- agenteval/tracer.py +303 -0
- agenteval_py-0.1.0.dist-info/METADATA +561 -0
- agenteval_py-0.1.0.dist-info/RECORD +20 -0
- agenteval_py-0.1.0.dist-info/WHEEL +4 -0
- agenteval_py-0.1.0.dist-info/entry_points.txt +2 -0
- agenteval_py-0.1.0.dist-info/licenses/LICENSE +21 -0
agenteval/__init__.py
ADDED
|
@@ -0,0 +1,46 @@
|
|
|
1
|
+
"""agenteval — evaluation toolkit for LLM agents.
|
|
2
|
+
|
|
3
|
+
Quick start::
|
|
4
|
+
|
|
5
|
+
import agenteval
|
|
6
|
+
|
|
7
|
+
@agenteval.test(n=20, threshold=0.8)
|
|
8
|
+
async def test_my_agent(tracer: agenteval.Tracer) -> None:
|
|
9
|
+
search = tracer.wrap(my_search_tool)
|
|
10
|
+
|
|
11
|
+
async with tracer.run(input="find Python tutorials") as run:
|
|
12
|
+
result = await my_agent("find Python tutorials", search=search)
|
|
13
|
+
run.set_output(result)
|
|
14
|
+
|
|
15
|
+
tracer.assert_that().called_tool("my_search_tool").no_errors().check()
|
|
16
|
+
|
|
17
|
+
# Run a single test directly:
|
|
18
|
+
result = agenteval.run(test_my_agent, n=10)
|
|
19
|
+
|
|
20
|
+
# Discover and run all @agenteval.test functions in a directory:
|
|
21
|
+
suite = agenteval.run_suite("tests/")
|
|
22
|
+
"""
|
|
23
|
+
|
|
24
|
+
from agenteval import adapters
|
|
25
|
+
from agenteval.assertions import AssertionSet
|
|
26
|
+
from agenteval.models import AgentTrace, SuiteResult, TestResult, ToolCall
|
|
27
|
+
from agenteval.registry import test
|
|
28
|
+
from agenteval.reporter import RichReporter
|
|
29
|
+
from agenteval.runner import run
|
|
30
|
+
from agenteval.suite import run_suite
|
|
31
|
+
from agenteval.tracer import Tracer
|
|
32
|
+
|
|
33
|
+
__version__ = "0.1.0"

# Names re-exported as the package's public API (what ``from agenteval import *``
# binds). Kept in ASCII-sorted order.
__all__ = [
    "AgentTrace", "AssertionSet", "RichReporter", "SuiteResult",
    "TestResult", "ToolCall", "Tracer",
    "adapters", "run", "run_suite", "test",
]
|
|
@@ -0,0 +1,9 @@
|
|
|
1
|
+
"""Framework adapters for agenteval.
|
|
2
|
+
|
|
3
|
+
Each adapter makes it easy to instrument a specific framework without
|
|
4
|
+
changing agent code. Available adapters:
|
|
5
|
+
|
|
6
|
+
- ``agenteval.adapters.openai_adapter`` — OpenAI function calling
|
|
7
|
+
- ``agenteval.adapters.anthropic_adapter`` — Anthropic tool use
|
|
8
|
+
- ``agenteval.adapters.langchain_adapter`` — LangChain callback handler
|
|
9
|
+
"""
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
"""Anthropic tool use adapter for agenteval.
|
|
2
|
+
|
|
3
|
+
Usage::
|
|
4
|
+
|
|
5
|
+
from agenteval.adapters.anthropic_adapter import wrap_tools, extract_token_usage
|
|
6
|
+
|
|
7
|
+
async def test_my_agent(tracer: Tracer) -> None:
|
|
8
|
+
tools = wrap_tools({"web_search": search_fn, "calculator": calc_fn}, tracer)
|
|
9
|
+
|
|
10
|
+
async with tracer.run(input=prompt) as run:
|
|
11
|
+
messages = [{"role": "user", "content": prompt}]
|
|
12
|
+
while True:
|
|
13
|
+
response = await client.messages.create(
|
|
14
|
+
model="claude-sonnet-4-6",
|
|
15
|
+
max_tokens=1024,
|
|
16
|
+
tools=anthropic_tool_schemas,
|
|
17
|
+
messages=messages,
|
|
18
|
+
)
|
|
19
|
+
run.set_token_usage(extract_token_usage(response))
|
|
20
|
+
if response.stop_reason == "tool_use":
|
|
21
|
+
for block in response.content:
|
|
22
|
+
if block.type == "tool_use":
|
|
23
|
+
result = await tools[block.name](**block.input)
|
|
24
|
+
# append tool result to messages ...
|
|
25
|
+
elif response.stop_reason == "end_turn":
|
|
26
|
+
text = next(
|
|
27
|
+
(b.text for b in response.content if b.type == "text"), ""
|
|
28
|
+
)
|
|
29
|
+
run.set_output(text)
|
|
30
|
+
break
|
|
31
|
+
|
|
32
|
+
tracer.assert_that().called_tool("web_search").no_errors().check()
|
|
33
|
+
"""
|
|
34
|
+
|
|
35
|
+
from __future__ import annotations
|
|
36
|
+
|
|
37
|
+
from typing import Any, Callable, Optional
|
|
38
|
+
|
|
39
|
+
from agenteval.tracer import Tracer
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def wrap_tools(
    tool_functions: dict[str, Callable[..., Any]],
    tracer: Tracer,
) -> dict[str, Callable[..., Any]]:
    """Return traced versions of a name → callable mapping of Anthropic tools.

    Each callable is passed through ``tracer.wrap`` with its dict key as the
    recorded tool name, so calls, timing, and errors land in the tracer under
    the key the tool was registered with.

    Args:
        tool_functions: Mapping of tool name → callable.
        tracer: The active Tracer for the current test run.

    Returns:
        A new dict with the same keys (in the same order) mapped to the
        wrapped callables; the input mapping is left untouched.
    """
    wrapped: dict[str, Callable[..., Any]] = {}
    for tool_name, tool_fn in tool_functions.items():
        wrapped[tool_name] = tracer.wrap(tool_fn, name=tool_name)
    return wrapped
|
|
57
|
+
|
|
58
|
+
|
|
59
|
+
def extract_token_usage(response: Any) -> Optional[dict[str, int]]:
    """Extract token usage from an Anthropic Message response.

    Args:
        response: An anthropic.types.Message object (any object exposing a
            ``usage`` attribute is accepted).

    Returns:
        Dict containing whichever of ``input_tokens``, ``output_tokens``,
        ``cache_read_input_tokens`` and ``cache_creation_input_tokens`` are
        present on ``response.usage`` with a non-None value, or None if
        ``response`` has no usage or none of those counters are set.
    """
    usage = getattr(response, "usage", None)
    if usage is None:
        return None

    result: dict[str, int] = {}
    for field in (
        "input_tokens",
        "output_tokens",
        "cache_read_input_tokens",
        "cache_creation_input_tokens",
    ):
        # The cache counters are optional in the SDK's Usage model: the
        # attribute exists but holds None when prompt caching was not used.
        # hasattr() alone would let those None values into an int-typed dict.
        value = getattr(usage, field, None)
        if value is not None:
            result[field] = value
    return result or None
|
|
@@ -0,0 +1,135 @@
|
|
|
1
|
+
"""LangChain callback handler adapter for agenteval.
|
|
2
|
+
|
|
3
|
+
Usage::
|
|
4
|
+
|
|
5
|
+
from agenteval.adapters.langchain_adapter import AgentEvalCallbackHandler
|
|
6
|
+
|
|
7
|
+
async def test_langchain_agent(tracer: Tracer) -> None:
|
|
8
|
+
handler = AgentEvalCallbackHandler()
|
|
9
|
+
|
|
10
|
+
async with tracer.run(input="find Italian restaurants") as run:
|
|
11
|
+
result = await agent.ainvoke(
|
|
12
|
+
{"input": "find Italian restaurants"},
|
|
13
|
+
config={"callbacks": [handler]},
|
|
14
|
+
)
|
|
15
|
+
run.set_output(result.get("output", ""))
|
|
16
|
+
|
|
17
|
+
tracer.assert_that().called_tool("restaurant_search").no_errors().check()
|
|
18
|
+
|
|
19
|
+
The handler reads the active Tracer from the _ACTIVE_TRACER ContextVar, so it
|
|
20
|
+
works automatically when used inside agenteval.run() — no explicit tracer
|
|
21
|
+
reference needed.
|
|
22
|
+
"""
|
|
23
|
+
|
|
24
|
+
from __future__ import annotations
|
|
25
|
+
|
|
26
|
+
import time
|
|
27
|
+
from typing import Any
|
|
28
|
+
from uuid import UUID
|
|
29
|
+
|
|
30
|
+
from agenteval.tracer import Tracer
|
|
31
|
+
|
|
32
|
+
try:
|
|
33
|
+
from langchain_core.callbacks.base import BaseCallbackHandler
|
|
34
|
+
from langchain_core.outputs import LLMResult
|
|
35
|
+
|
|
36
|
+
_LANGCHAIN_AVAILABLE = True
|
|
37
|
+
except ImportError:
|
|
38
|
+
_LANGCHAIN_AVAILABLE = False
|
|
39
|
+
BaseCallbackHandler = object # type: ignore[assignment,misc]
|
|
40
|
+
LLMResult = Any # type: ignore[assignment,misc]
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class AgentEvalCallbackHandler(BaseCallbackHandler):  # type: ignore[misc]
    """LangChain callback handler that records tool calls into the active Tracer.

    Records each tool invocation's name, arguments, result, duration, and any
    errors. Reads the active tracer via ``Tracer.current()`` (ContextVar), so
    multiple concurrent test runs each get their own tracer automatically.

    If no tracer is active (i.e., used outside of agenteval.run()), all
    callbacks are no-ops to avoid errors in non-test contexts.

    Raises:
        ImportError: From ``__init__`` when ``langchain-core`` is not installed.
    """

    def __init__(self) -> None:
        if not _LANGCHAIN_AVAILABLE:
            raise ImportError(
                "langchain-core is required for AgentEvalCallbackHandler. "
                "Install it with: pip install 'agenteval-py[langchain]'"
            )
        super().__init__()
        # Maps LangChain run_id (as str) → (perf_counter start, tool_name, parsed_args).
        # Entries are popped by on_tool_end / on_tool_error, so the map only
        # holds tool runs that are still in flight.
        self._pending: dict[str, tuple[float, str, dict[str, Any]]] = {}

    def on_tool_start(
        self,
        serialized: dict[str, Any],
        input_str: str,
        *,
        run_id: UUID,
        **kwargs: Any,
    ) -> None:
        """Remember when a tool started, its name, and its parsed arguments.

        Args:
            serialized: Serialized tool description; its ``"name"`` is used.
            input_str: The tool input rendered as a string.
            run_id: LangChain's id for this tool run; pairs start with end/error.
            **kwargs: Extra callback data. ``name`` is a fallback tool name, and
                ``inputs`` (when it is a dict) is preferred over ``input_str``
                as the source of the recorded arguments.
        """
        if Tracer.current() is None:
            # Outside an agenteval run: stay a no-op and keep nothing pending.
            return
        tool_name: str = (serialized or {}).get("name") or kwargs.get("name") or "unknown"
        args = self._parse_args(input_str, kwargs.get("inputs"))
        self._pending[str(run_id)] = (time.perf_counter(), tool_name, args)

    def on_tool_end(
        self,
        output: Any,
        *,
        run_id: UUID,
        **kwargs: Any,
    ) -> None:
        """Record a successful tool invocation with its output."""
        self._finish(run_id, result=output, error=None)

    def on_tool_error(
        self,
        error: BaseException,
        *,
        run_id: UUID,
        **kwargs: Any,
    ) -> None:
        """Record a failed tool invocation as ``"<ExceptionType>: <message>"``."""
        self._finish(run_id, result=None, error=f"{type(error).__name__}: {error}")

    @staticmethod
    def _parse_args(input_str: Any, inputs: Any) -> dict[str, Any]:
        """Best-effort conversion of a LangChain tool input into a kwargs-style dict."""
        import json

        # Structured tools pass their original dict input via ``inputs``;
        # ``input_str`` is then only ``str(dict)`` (a Python repr, not JSON),
        # so parsing it would lose the individual arguments.
        if isinstance(inputs, dict):
            return dict(inputs)
        if not isinstance(input_str, str):
            return {"input": input_str}
        try:
            parsed = json.loads(input_str)
        except (ValueError, RecursionError):
            # Plain-text tool input: keep it verbatim under a single key.
            return {"input": input_str}
        return parsed if isinstance(parsed, dict) else {"input": parsed}

    def _finish(self, run_id: UUID, *, result: Any, error: str | None) -> None:
        """Pop the pending entry for ``run_id`` and record it into the active tracer."""
        # Pop before checking for a tracer so the entry is always released,
        # even if the run finishes outside an active agenteval context.
        entry = self._pending.pop(str(run_id), None)
        if entry is None:
            return
        tracer = Tracer.current()
        if tracer is None:
            return

        start_time, tool_name, args = entry
        duration = time.perf_counter() - start_time
        tracer.record_tool_call(
            name=tool_name,
            arguments=args,
            result=result,
            duration_seconds=duration,
            # Wall-clock start time, reconstructed from the monotonic duration.
            timestamp=time.time() - duration,
            error=error,
        )
|
|
@@ -0,0 +1,80 @@
|
|
|
1
|
+
"""OpenAI function calling adapter for agenteval.
|
|
2
|
+
|
|
3
|
+
Usage::
|
|
4
|
+
|
|
5
|
+
from agenteval.adapters.openai_adapter import wrap_tools, extract_token_usage
|
|
6
|
+
|
|
7
|
+
async def test_my_agent(tracer: Tracer) -> None:
|
|
8
|
+
tools = wrap_tools({"search": search_fn, "calculator": calc_fn}, tracer)
|
|
9
|
+
|
|
10
|
+
async with tracer.run(input=prompt) as run:
|
|
11
|
+
# Your OpenAI tool-calling loop:
|
|
12
|
+
messages = [{"role": "user", "content": prompt}]
|
|
13
|
+
while True:
|
|
14
|
+
response = await client.chat.completions.create(
|
|
15
|
+
model="gpt-4o",
|
|
16
|
+
messages=messages,
|
|
17
|
+
tools=openai_tool_schemas,
|
|
18
|
+
)
|
|
19
|
+
run.set_token_usage(extract_token_usage(response))
|
|
20
|
+
choice = response.choices[0]
|
|
21
|
+
if choice.finish_reason == "tool_calls":
|
|
22
|
+
for tc in choice.message.tool_calls:
|
|
23
|
+
import json
|
|
24
|
+
fn_name = tc.function.name
|
|
25
|
+
fn_args = json.loads(tc.function.arguments)
|
|
26
|
+
result = await tools[fn_name](**fn_args)
|
|
27
|
+
# append tool result to messages ...
|
|
28
|
+
elif choice.finish_reason == "stop":
|
|
29
|
+
run.set_output(choice.message.content)
|
|
30
|
+
break
|
|
31
|
+
|
|
32
|
+
tracer.assert_that().called_tool("search").no_errors().check()
|
|
33
|
+
"""
|
|
34
|
+
|
|
35
|
+
from __future__ import annotations
|
|
36
|
+
|
|
37
|
+
from typing import Any, Callable, Optional
|
|
38
|
+
|
|
39
|
+
from agenteval.tracer import Tracer
|
|
40
|
+
|
|
41
|
+
|
|
42
|
+
def wrap_tools(
    tool_functions: dict[str, Callable[..., Any]],
    tracer: Tracer,
) -> dict[str, Callable[..., Any]]:
    """Return traced versions of a name → callable mapping of OpenAI tools.

    Every callable goes through ``tracer.wrap`` using its dict key as the
    recorded tool name, so calls, timing, and errors are captured under the
    key the tool was registered with.

    Args:
        tool_functions: Mapping of tool name → callable.
        tracer: The active Tracer for the current test run.

    Returns:
        A new dict with the same keys (in the same order) mapped to the
        wrapped callables; the input mapping is left untouched.

    Example::

        tools = wrap_tools({"search": search_fn, "weather": weather_fn}, tracer)
        result = await tools["search"](query="python news")
    """
    wrapped: dict[str, Callable[..., Any]] = {}
    for tool_name, tool_fn in tool_functions.items():
        wrapped[tool_name] = tracer.wrap(tool_fn, name=tool_name)
    return wrapped
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
def extract_token_usage(response: Any) -> Optional[dict[str, int]]:
    """Pull the token counters out of an OpenAI ChatCompletion response.

    Args:
        response: An openai.types.chat.ChatCompletion object.

    Returns:
        None when ``response`` has no (or a None) ``usage``; otherwise a dict
        with ``prompt_tokens``, ``completion_tokens`` and ``total_tokens``, in
        that order, where any counter missing from ``usage`` defaults to 0.
    """
    usage = getattr(response, "usage", None)
    if usage is None:
        return None
    counters = ("prompt_tokens", "completion_tokens", "total_tokens")
    return {counter: getattr(usage, counter, 0) for counter in counters}
|
agenteval/assertions.py
ADDED
|
@@ -0,0 +1,289 @@
|
|
|
1
|
+
"""Fluent assertion library for inspecting AgentTrace objects."""
|
|
2
|
+
|
|
3
|
+
from __future__ import annotations
|
|
4
|
+
|
|
5
|
+
import json
|
|
6
|
+
import math
|
|
7
|
+
from typing import Any, Callable, Literal, Optional, Union
|
|
8
|
+
|
|
9
|
+
from pydantic import BaseModel
|
|
10
|
+
|
|
11
|
+
from agenteval.models import AgentTrace
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class AssertionSet:
    """Fluent assertions on an AgentTrace.

    Failures are **collected**, not raised immediately. Call `.check()` at the
    end of a chain to raise a single AssertionError listing all failures.

    Usage::

        (
            tracer.assert_that()
            .called_tool("search")
            .never_called_tool("delete")
            .completed_within_steps(5)
            .no_errors()
            .check()
        )

    Every assertion method returns ``self`` so calls can be chained.
    """

    def __init__(self, trace: AgentTrace) -> None:
        # The trace under inspection; assertions only read it.
        self._trace = trace
        # Failure messages, in the order the assertions were evaluated.
        self._failures: list[str] = []

    # ------------------------------------------------------------------ #
    # Tool call assertions
    # ------------------------------------------------------------------ #

    def called_tool(self, name: str) -> "AssertionSet":
        """Assert that the tool was called at least once."""
        if not any(tc.name == name for tc in self._trace.tool_calls):
            all_tools = [tc.name for tc in self._trace.tool_calls]
            self._failures.append(
                f"Expected tool '{name}' to be called, but it was not. "
                f"Tools called: {all_tools or '(none)'}"
            )
        return self

    def never_called_tool(self, name: str) -> "AssertionSet":
        """Assert that the tool was never called."""
        calls = [tc for tc in self._trace.tool_calls if tc.name == name]
        if calls:
            self._failures.append(
                f"Expected tool '{name}' to never be called, but it was called {len(calls)} time(s)."
            )
        return self

    def tool_call_count(
        self,
        name: str,
        *,
        min: int = 0,  # noqa: A002 - public keyword name, kept for compatibility
        max: int = math.inf,  # type: ignore[assignment]  # noqa: A002
    ) -> "AssertionSet":
        """Assert that the tool was called between min and max times (inclusive).

        ``max`` defaults to ``math.inf``, i.e. no upper bound.
        """
        count = sum(1 for tc in self._trace.tool_calls if tc.name == name)
        if not (min <= count <= max):
            self._failures.append(
                f"Expected tool '{name}' to be called between {min} and "
                f"{'∞' if max == math.inf else max} times, but it was called {count} time(s)."
            )
        return self

    def tool_called_before(self, tool_a: str, tool_b: str) -> "AssertionSet":
        """Assert that tool_a was called before tool_b (at least one call each).

        Only the *first* call of each tool is compared.
        """
        calls = self._trace.tool_calls
        first_a = next((i for i, tc in enumerate(calls) if tc.name == tool_a), None)
        first_b = next((i for i, tc in enumerate(calls) if tc.name == tool_b), None)

        if first_a is None:
            self._failures.append(
                f"Ordering assertion failed: tool '{tool_a}' was never called."
            )
        elif first_b is None:
            self._failures.append(
                f"Ordering assertion failed: tool '{tool_b}' was never called."
            )
        elif first_a >= first_b:
            self._failures.append(
                f"Expected '{tool_a}' to be called before '{tool_b}', "
                f"but '{tool_b}' was called first (positions: {tool_a}={first_a}, {tool_b}={first_b})."
            )
        return self

    def tool_called_with_args(
        self,
        name: str,
        args: dict[str, Any],
        *,
        match: Literal["subset", "exact"] = "subset",
    ) -> "AssertionSet":
        """Assert that a tool was called with specific arguments.

        Passes if at least one call of ``name`` matches.

        Args:
            name: Tool name to check.
            args: Expected arguments.
            match: 'subset' (default) checks all provided keys are present with
                matching values. 'exact' requires the arguments dict to match exactly.
        """
        matching_calls = [tc for tc in self._trace.tool_calls if tc.name == name]
        if not matching_calls:
            self._failures.append(
                f"tool_called_with_args: tool '{name}' was never called."
            )
            return self

        def _matches(call_args: dict[str, Any]) -> bool:
            if match == "exact":
                return call_args == args
            # subset: every expected key must be *present* with an equal value.
            # A plain ``call_args.get(k) == v`` would let an expected None
            # value match a key that is missing altogether.
            return all(k in call_args and call_args[k] == v for k, v in args.items())

        if not any(_matches(tc.arguments) for tc in matching_calls):
            actual_args = [tc.arguments for tc in matching_calls]
            self._failures.append(
                f"tool '{name}' was called {len(matching_calls)} time(s), but none matched "
                f"the expected args {args} (match='{match}'). Actual args: {actual_args}"
            )
        return self

    # ------------------------------------------------------------------ #
    # Step / time assertions
    # ------------------------------------------------------------------ #

    def completed_within_steps(self, n: int) -> "AssertionSet":
        """Assert that the agent finished in n steps or fewer."""
        actual = self._trace.effective_steps
        if actual > n:
            self._failures.append(
                f"Expected agent to complete within {n} steps, but took {actual} steps."
            )
        return self

    def completed_within_seconds(self, n: float) -> "AssertionSet":
        """Assert that the agent finished within n seconds."""
        actual = self._trace.duration_seconds
        if actual > n:
            self._failures.append(
                f"Expected agent to complete within {n:.2f}s, but took {actual:.2f}s."
            )
        return self

    # ------------------------------------------------------------------ #
    # Output assertions
    # ------------------------------------------------------------------ #

    def response_contains(self, keyword: str, *, case_sensitive: bool = True) -> "AssertionSet":
        """Assert that the final response (converted with ``str()``) contains a keyword."""
        output = self._trace.output
        if output is None:
            self._failures.append(
                f"response_contains: agent output is None, expected to contain '{keyword}'."
            )
            return self

        text = str(output)
        # casefold() is the Unicode-aware normalization for caseless matching
        # (e.g. "straße" matches "STRASSE"); lower() misses such cases.
        haystack = text if case_sensitive else text.casefold()
        needle = keyword if case_sensitive else keyword.casefold()

        if needle not in haystack:
            preview = text[:200] + "..." if len(text) > 200 else text
            self._failures.append(
                f"Expected response to contain '{keyword}', but it did not. "
                f"Response: {preview!r}"
            )
        return self

    def response_matches_schema(
        self,
        schema: type[BaseModel],
        *,
        parse_json: bool = True,
    ) -> "AssertionSet":
        """Assert that the final response matches a Pydantic schema.

        If the output is a string and parse_json=True (default), it will be
        JSON-parsed first before validation. Any exception raised by
        ``schema.model_validate`` is recorded as a failure, not propagated.
        """
        output = self._trace.output
        if output is None:
            self._failures.append(
                f"response_matches_schema: agent output is None, "
                f"expected to match {schema.__name__}."
            )
            return self

        data: Any = output
        if isinstance(output, str) and parse_json:
            try:
                data = json.loads(output)
            except json.JSONDecodeError as e:
                self._failures.append(
                    f"response_matches_schema: failed to JSON-parse output before "
                    f"validating against {schema.__name__}: {e}. "
                    f"Output was: {output[:200]!r}"
                )
                return self

        try:
            schema.model_validate(data)
        except Exception as e:
            # Broad on purpose: any validation problem becomes a collected failure.
            self._failures.append(
                f"response_matches_schema: output does not match schema "
                f"{schema.__name__}: {e}"
            )
        return self

    # ------------------------------------------------------------------ #
    # Error assertions
    # ------------------------------------------------------------------ #

    def no_errors(self) -> "AssertionSet":
        """Assert that the agent completed without any exceptions."""
        if self._trace.error is not None:
            self._failures.append(
                f"Expected no errors, but agent raised: {self._trace.error}"
            )
        return self

    # ------------------------------------------------------------------ #
    # Custom / escape hatch
    # ------------------------------------------------------------------ #

    def custom(
        self,
        fn: Callable[[AgentTrace], Union[bool, str]],
        *,
        message: Optional[str] = None,
    ) -> "AssertionSet":
        """Run a custom assertion function against the trace.

        Args:
            fn: Callable that receives the AgentTrace and returns True (pass),
                False (fail), or a failure message string. Any other return
                value (including None) is treated as a pass; an exception
                raised by ``fn`` is recorded as a failure.
            message: Optional failure message to use when fn returns False.
        """
        try:
            result = fn(self._trace)
        except Exception as e:
            self._failures.append(
                f"custom assertion raised an exception: {type(e).__name__}: {e}"
            )
            return self

        if result is True:
            return self

        if result is False:
            self._failures.append(
                message or "custom assertion failed (returned False)."
            )
        elif isinstance(result, str):
            self._failures.append(result)

        return self

    # ------------------------------------------------------------------ #
    # Terminator
    # ------------------------------------------------------------------ #

    def check(self) -> None:
        """Raise AssertionError listing all collected failures. No-op if all passed."""
        if self._failures:
            lines = "\n".join(f"  • {f}" for f in self._failures)
            raise AssertionError(f"Trace assertions failed ({len(self._failures)} failure(s)):\n{lines}")

    # ------------------------------------------------------------------ #
    # Introspection (for use without raising)
    # ------------------------------------------------------------------ #

    @property
    def passed(self) -> bool:
        """True if no failures have been collected."""
        return len(self._failures) == 0

    @property
    def failures(self) -> list[str]:
        """Copy of all collected failure messages."""
        return list(self._failures)
|