toolcallcheck 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- toolcallcheck/__init__.py +71 -0
- toolcallcheck/adapters/__init__.py +130 -0
- toolcallcheck/assertions.py +232 -0
- toolcallcheck/builders.py +265 -0
- toolcallcheck/diff.py +108 -0
- toolcallcheck/fake_model.py +111 -0
- toolcallcheck/fixtures.py +58 -0
- toolcallcheck/markers.py +14 -0
- toolcallcheck/mock_server.py +191 -0
- toolcallcheck/multi_turn.py +104 -0
- toolcallcheck/offline.py +60 -0
- toolcallcheck/plugins.py +86 -0
- toolcallcheck/recording.py +118 -0
- toolcallcheck/result.py +85 -0
- toolcallcheck/runner.py +256 -0
- toolcallcheck/scenario.py +52 -0
- toolcallcheck/snapshot.py +88 -0
- toolcallcheck/trajectory.py +129 -0
- toolcallcheck-0.1.0.dist-info/METADATA +1019 -0
- toolcallcheck-0.1.0.dist-info/RECORD +23 -0
- toolcallcheck-0.1.0.dist-info/WHEEL +4 -0
- toolcallcheck-0.1.0.dist-info/entry_points.txt +2 -0
- toolcallcheck-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,71 @@
|
|
|
1
|
+
"""toolcallcheck: Deterministic, pytest-native testing for tool-using AI agents.
|
|
2
|
+
|
|
3
|
+
Mock MCP tools, assert exact tool calls and trajectories, verify headers
|
|
4
|
+
and routing, and reproduce failures locally without depending on cloud
|
|
5
|
+
dashboards or live models.
|
|
6
|
+
"""
|
|
7
|
+
|
|
8
|
+
# Package version string; matches the toolcallcheck 0.1.0 wheel this module ships in.
__version__ = "0.1.0"
|
|
9
|
+
|
|
10
|
+
# ---------------------------------------------------------------------------
|
|
11
|
+
# P0 Core
|
|
12
|
+
# ---------------------------------------------------------------------------
|
|
13
|
+
from toolcallcheck.assertions import (
|
|
14
|
+
assert_headers,
|
|
15
|
+
assert_model_used,
|
|
16
|
+
assert_no_tool_calls,
|
|
17
|
+
assert_response_contains,
|
|
18
|
+
assert_response_equals,
|
|
19
|
+
assert_response_matches,
|
|
20
|
+
assert_tool_args_contain,
|
|
21
|
+
assert_tool_call_count,
|
|
22
|
+
assert_tool_call_order,
|
|
23
|
+
assert_tool_calls,
|
|
24
|
+
)
|
|
25
|
+
|
|
26
|
+
# ---------------------------------------------------------------------------
|
|
27
|
+
# P1 Adoption & Coverage
|
|
28
|
+
# ---------------------------------------------------------------------------
|
|
29
|
+
from toolcallcheck.builders import ScenarioBuilder, ToolResponseBuilder, UserMessageBuilder
|
|
30
|
+
from toolcallcheck.fake_model import FakeModel
|
|
31
|
+
from toolcallcheck.mock_server import MockMCPServer, MockTool
|
|
32
|
+
from toolcallcheck.multi_turn import Conversation
|
|
33
|
+
from toolcallcheck.offline import offline
|
|
34
|
+
from toolcallcheck.plugins import register_assertion, run_custom_assertion
|
|
35
|
+
from toolcallcheck.recording import Recorder
|
|
36
|
+
from toolcallcheck.result import AgentResult, ToolCall
|
|
37
|
+
from toolcallcheck.runner import AgentRunner
|
|
38
|
+
from toolcallcheck.scenario import scenario
|
|
39
|
+
from toolcallcheck.snapshot import assert_snapshot
|
|
40
|
+
from toolcallcheck.trajectory import assert_trajectory
|
|
41
|
+
|
|
42
|
+
# Names exported by ``from toolcallcheck import *``, kept in ASCII-sorted order
# (capitalized names, then ``__version__``, then lowercase names).
__all__ = [
    "AgentResult", "AgentRunner", "Conversation", "FakeModel",
    "MockMCPServer", "MockTool", "Recorder", "ScenarioBuilder",
    "ToolCall", "ToolResponseBuilder", "UserMessageBuilder",
    "__version__",
    "assert_headers", "assert_model_used", "assert_no_tool_calls",
    "assert_response_contains", "assert_response_equals",
    "assert_response_matches", "assert_snapshot",
    "assert_tool_args_contain", "assert_tool_call_count",
    "assert_tool_call_order", "assert_tool_calls", "assert_trajectory",
    "offline", "register_assertion", "run_custom_assertion", "scenario",
]
|
|
@@ -0,0 +1,130 @@
|
|
|
1
|
+
"""Framework adapter protocol and stubs.
|
|
2
|
+
|
|
3
|
+
Defines the base protocol for adapting toolcallcheck to different
|
|
4
|
+
agent frameworks: OpenAI Agents, LangGraph, PydanticAI, CrewAI.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from typing import Any, Protocol, runtime_checkable
|
|
10
|
+
|
|
11
|
+
from toolcallcheck.result import AgentResult
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
@runtime_checkable
class FrameworkAdapter(Protocol):
    """Protocol that framework adapters must implement.

    An adapter bridges between toolcallcheck's mock infrastructure and
    a specific agent framework's execution model.

    Because the protocol is ``runtime_checkable``, ``isinstance(obj,
    FrameworkAdapter)`` works, but it only checks that the members exist;
    it does not verify signatures or that ``invoke`` is a coroutine function.
    """

    @property
    def framework_name(self) -> str:
        """Return the name of the framework this adapter supports."""
        ...

    async def invoke(
        self,
        message: str,
        *,
        tools: list[dict[str, Any]] | None = None,
        headers: dict[str, str] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> AgentResult:
        """Run the agent with the given message and return a result.

        Parameters
        ----------
        message:
            The user message to send to the agent.
        tools:
            Optional tool definitions to expose to the agent.
            NOTE(review): the expected dict schema is not defined here —
            confirm against the runner / mock server.
        headers:
            Optional request headers to forward with the invocation.
        metadata:
            Optional free-form metadata for the invocation.
        """
        ...
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class OpenAIAgentsAdapter:
    """Placeholder adapter targeting the OpenAI Agents SDK.

    No integration ships yet: ``framework_name`` identifies the framework
    and ``invoke()`` always raises ``NotImplementedError``. Subclass and
    override ``invoke()`` to wire up the OpenAI Agents SDK.
    """

    @property
    def framework_name(self) -> str:
        label = "openai-agents"
        return label

    async def invoke(
        self,
        message: str,
        *,
        tools: list[dict[str, Any]] | None = None,
        headers: dict[str, str] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> AgentResult:
        """Always raise; every argument is accepted only for protocol compatibility."""
        detail = (
            "OpenAI Agents adapter is a stub. "
            "Implement invoke() to connect to the OpenAI Agents SDK."
        )
        raise NotImplementedError(detail)
|
|
62
|
+
|
|
63
|
+
|
|
64
|
+
class LangGraphAdapter:
    """Placeholder adapter targeting LangGraph.

    ``invoke()`` always raises ``NotImplementedError`` until a real
    integration is provided by overriding it.
    """

    @property
    def framework_name(self) -> str:
        label = "langgraph"
        return label

    async def invoke(
        self,
        message: str,
        *,
        tools: list[dict[str, Any]] | None = None,
        headers: dict[str, str] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> AgentResult:
        """Always raise; every argument is accepted only for protocol compatibility."""
        detail = "LangGraph adapter is a stub. Implement invoke() to connect to LangGraph."
        raise NotImplementedError(detail)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class PydanticAIAdapter:
    """Placeholder adapter targeting PydanticAI.

    ``invoke()`` always raises ``NotImplementedError`` until a real
    integration is provided by overriding it.
    """

    @property
    def framework_name(self) -> str:
        label = "pydantic-ai"
        return label

    async def invoke(
        self,
        message: str,
        *,
        tools: list[dict[str, Any]] | None = None,
        headers: dict[str, str] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> AgentResult:
        """Always raise; every argument is accepted only for protocol compatibility."""
        detail = "PydanticAI adapter is a stub. Implement invoke() to connect to PydanticAI."
        raise NotImplementedError(detail)
|
|
102
|
+
|
|
103
|
+
|
|
104
|
+
class CrewAIAdapter:
    """Placeholder adapter targeting CrewAI.

    ``invoke()`` always raises ``NotImplementedError`` until a real
    integration is provided by overriding it.
    """

    @property
    def framework_name(self) -> str:
        label = "crewai"
        return label

    async def invoke(
        self,
        message: str,
        *,
        tools: list[dict[str, Any]] | None = None,
        headers: dict[str, str] | None = None,
        metadata: dict[str, Any] | None = None,
    ) -> AgentResult:
        """Always raise; every argument is accepted only for protocol compatibility."""
        detail = "CrewAI adapter is a stub. Implement invoke() to connect to CrewAI."
        raise NotImplementedError(detail)
|
|
122
|
+
|
|
123
|
+
|
|
124
|
+
# Public names of this subpackage: the adapter protocol plus one stub per framework.
__all__ = [
    "CrewAIAdapter", "FrameworkAdapter", "LangGraphAdapter",
    "OpenAIAgentsAdapter", "PydanticAIAdapter",
]
|
|
@@ -0,0 +1,232 @@
|
|
|
1
|
+
"""Assertion helpers for toolcallcheck.
|
|
2
|
+
|
|
3
|
+
Every function raises ``AssertionError`` with a structured, human-readable
|
|
4
|
+
diff message on failure so that CI output is immediately actionable.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
import re
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
from toolcallcheck.diff import format_tool_call_diff, format_value_diff
|
|
13
|
+
from toolcallcheck.result import AgentResult
|
|
14
|
+
|
|
15
|
+
# ---------------------------------------------------------------------------
|
|
16
|
+
# Tool-call assertions
|
|
17
|
+
# ---------------------------------------------------------------------------
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
def assert_tool_calls(
    result: AgentResult,
    expected: list[dict[str, Any]],
    *,
    strict_order: bool = True,
) -> None:
    """Assert that the agent made exactly the expected tool calls.

    Parameters
    ----------
    result:
        The :class:`AgentResult` to inspect.
    expected:
        A list of dicts, each with ``"name"`` and ``"args"`` keys.
    strict_order:
        If ``True`` (default), the order of tool calls must match. If
        ``False``, calls are compared as a multiset: every expected call must
        be matched by exactly one equal actual call, and no actual call may
        be left over.

    Raises
    ------
    AssertionError
        If the calls differ; the message includes a structured diff.
    """
    actual = [{"name": tc.name, "args": tc.args} for tc in result.tool_calls]

    if strict_order:
        if actual != expected:
            diff = format_tool_call_diff(expected, actual)
            raise AssertionError(f"Tool calls do not match (strict order):\n{diff}")
        return

    # Order-insensitive: match by ``==`` instead of sorting on a string key.
    # A ``str()``-based sort key depends on dict insertion order in nested
    # args (and on how equal values stringify, e.g. 1 vs 1.0), which could
    # line up equal multisets differently and report a false mismatch.
    if not _same_multiset(actual, expected):
        diff = format_tool_call_diff(expected, actual)
        raise AssertionError(f"Tool calls do not match (any order):\n{diff}")


def _same_multiset(actual: list[Any], expected: list[Any]) -> bool:
    """Return True if both lists hold ``==``-equal items with equal multiplicity."""
    if len(actual) != len(expected):
        return False
    unmatched = list(actual)
    for item in expected:
        for idx, candidate in enumerate(unmatched):
            if candidate == item:
                del unmatched[idx]
                break
        else:
            return False
    return not unmatched
|
|
51
|
+
|
|
52
|
+
|
|
53
|
+
def assert_tool_call_count(result: AgentResult, expected_count: int) -> None:
    """Assert that ``result`` records exactly ``expected_count`` tool calls.

    Raises ``AssertionError`` listing the called tool names otherwise.
    """
    observed = result.tool_call_count
    if observed == expected_count:
        return
    message = (
        f"Expected {expected_count} tool call(s), got {observed}.\n"
        f" Actual calls: {result.tool_names}"
    )
    raise AssertionError(message)
|
|
61
|
+
|
|
62
|
+
|
|
63
|
+
def assert_no_tool_calls(result: AgentResult) -> None:
    """Assert that the agent answered without calling any tool.

    Raises ``AssertionError`` naming every tool that was called otherwise.
    """
    if not result.tool_calls:
        return
    called = result.tool_names
    raise AssertionError(
        f"Expected no tool calls, but {len(called)} call(s) were made:\n Tools called: {called}"
    )
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
def assert_tool_call_order(result: AgentResult, expected_names: list[str]) -> None:
    """Assert that tools were called in the specified order.

    Only checks the order of the *named* tools; calls to tools that do not
    appear in ``expected_names`` are ignored. Every call to a named tool
    counts, so ``["a", "b"]`` fails against actual calls ``["a", "b", "a"]``.
    An empty ``expected_names`` therefore always passes; use
    ``assert_no_tool_calls`` to require that nothing was called.

    Raises
    ------
    AssertionError
        If the sequence of named-tool calls differs from ``expected_names``.
    """
    wanted = set(expected_names)
    # Keep only calls to tools the caller named, so unrelated calls in
    # between do not affect the ordering check.
    actual_names = [name for name in result.tool_names if name in wanted]
    if actual_names != expected_names:
        diff = format_value_diff(expected_names, actual_names, label="tool call order")
        raise AssertionError(f"Tool call order mismatch:\n{diff}")
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def assert_tool_args_contain(
    result: AgentResult,
    tool_name: str,
    partial_args: dict[str, Any],
    *,
    call_index: int = 0,
) -> None:
    """Assert that a tool call's arguments contain the given subset.

    Useful when the agent adds auto-generated fields (timestamps, IDs)
    that you don't want to assert on.

    Parameters
    ----------
    result:
        The :class:`AgentResult` to inspect.
    tool_name:
        Name of the tool to check.
    partial_args:
        A dict of key-value pairs that must be present in the actual args.
        Values are compared with ``==``; extra keys in the actual args are
        ignored.
    call_index:
        If the tool was called multiple times, which invocation to check
        (0-indexed, default 0). Negative values index from the end, as with
        Python sequences (``-1`` selects the last matching call).

    Raises
    ------
    AssertionError
        If the tool was never called, ``call_index`` is out of range in
        either direction, or any expected key is missing or has a different
        value.
    """
    matching = result.get_all_tool_calls(tool_name)
    if not matching:
        raise AssertionError(
            f"Tool '{tool_name}' was never called.\n Actual calls: {result.tool_names}"
        )

    call_count = len(matching)
    # Check both ends: an out-of-range negative index must yield the same
    # actionable AssertionError rather than a bare IndexError.
    if not -call_count <= call_index < call_count:
        raise AssertionError(
            f"Tool '{tool_name}' was called {call_count} time(s), "
            f"but call_index={call_index} requested."
        )

    actual_args = matching[call_index].args
    missing: dict[str, Any] = {}
    mismatched: dict[str, tuple[Any, Any]] = {}

    for key, expected_val in partial_args.items():
        if key not in actual_args:
            missing[key] = expected_val
        elif actual_args[key] != expected_val:
            mismatched[key] = (expected_val, actual_args[key])

    if missing or mismatched:
        parts: list[str] = [f"Partial argument mismatch for tool '{tool_name}':"]
        if missing:
            parts.append(f" Missing keys: {missing}")
        for k, (exp, act) in mismatched.items():
            parts.append(f" Key '{k}': expected {exp!r}, got {act!r}")
        raise AssertionError("\n".join(parts))
|
|
137
|
+
|
|
138
|
+
|
|
139
|
+
# ---------------------------------------------------------------------------
|
|
140
|
+
# Response assertions
|
|
141
|
+
# ---------------------------------------------------------------------------
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
def assert_response_contains(result: AgentResult, substring: str) -> None:
    """Assert that ``substring`` occurs somewhere in the agent's text response.

    Raises ``AssertionError`` showing both the substring and the full response.
    """
    response = result.response
    if substring in response:
        return
    raise AssertionError(
        "Response does not contain expected substring.\n"
        f" Expected substring: {substring!r}\n"
        f" Actual response: {response!r}"
    )
|
|
152
|
+
|
|
153
|
+
|
|
154
|
+
def assert_response_matches(result: AgentResult, pattern: str) -> None:
    """Assert that the regex ``pattern`` is found in the agent's text response.

    Uses :func:`re.search`, so the pattern may match anywhere unless it is
    anchored explicitly.
    """
    response = result.response
    if re.search(pattern, response) is not None:
        return
    raise AssertionError(
        "Response does not match expected pattern.\n"
        f" Pattern: {pattern!r}\n"
        f" Actual response: {response!r}"
    )
|
|
162
|
+
|
|
163
|
+
|
|
164
|
+
def assert_response_equals(result: AgentResult, expected: str) -> None:
    """Assert that the agent's text response is exactly ``expected``.

    On mismatch, raises ``AssertionError`` carrying a value diff of the two.
    """
    actual = result.response
    if actual == expected:
        return
    rendered = format_value_diff(expected, actual, label="response")
    raise AssertionError(f"Response does not match exactly:\n{rendered}")
|
|
169
|
+
|
|
170
|
+
|
|
171
|
+
# ---------------------------------------------------------------------------
|
|
172
|
+
# Model / routing assertions
|
|
173
|
+
# ---------------------------------------------------------------------------
|
|
174
|
+
|
|
175
|
+
|
|
176
|
+
def assert_model_used(result: AgentResult, expected_model: str) -> None:
    """Assert that ``result.model_used`` equals ``expected_model``.

    The value may name a model or a routing strategy, whichever the runner
    recorded.
    """
    used = result.model_used
    if used == expected_model:
        return
    raise AssertionError(f"Expected model '{expected_model}', got '{used}'.")
|
|
180
|
+
|
|
181
|
+
|
|
182
|
+
# ---------------------------------------------------------------------------
|
|
183
|
+
# Header assertions
|
|
184
|
+
# ---------------------------------------------------------------------------
|
|
185
|
+
|
|
186
|
+
|
|
187
|
+
def assert_headers(
    result: AgentResult,
    expected_headers: dict[str, str],
    *,
    exact: bool = False,
) -> None:
    """Assert that the captured request headers contain the expected values.

    Keys are looked up with the header mapping's own ``in`` / ``[]``
    semantics, which is case-sensitive for a plain ``dict``.

    Parameters
    ----------
    result:
        The :class:`AgentResult` to inspect.
    expected_headers:
        Key-value pairs that must be present.
    exact:
        If ``True``, the headers must match exactly with no extra keys.

    Raises
    ------
    AssertionError
        If ``exact`` is set and the key sets differ (reported before any
        value comparison), or if an expected header is missing or has a
        different value. Key lists in messages are sorted so failure output
        is identical across runs.
    """
    actual = result.headers

    if exact:
        actual_keys = set(actual.keys())
        expected_keys = set(expected_headers.keys())
        if actual_keys != expected_keys:
            extra = actual_keys - expected_keys
            missing_keys = expected_keys - actual_keys
            parts = ["Headers do not match exactly:"]
            # Sort: a raw set repr follows hash order, which changes between
            # interpreter runs (string hash randomization) and makes CI
            # output nondeterministic.
            if extra:
                parts.append(f" Unexpected keys: {sorted(extra)}")
            if missing_keys:
                parts.append(f" Missing keys: {sorted(missing_keys)}")
            raise AssertionError("\n".join(parts))

    missing: dict[str, str] = {}
    mismatched: dict[str, tuple[str, str | None]] = {}

    for key, expected_val in expected_headers.items():
        if key not in actual:
            missing[key] = expected_val
        elif actual[key] != expected_val:
            mismatched[key] = (expected_val, actual[key])

    if missing or mismatched:
        parts = ["Header assertion failed:"]
        if missing:
            parts.append(f" Missing headers: {missing}")
        for k, (exp, act) in mismatched.items():
            parts.append(f" Header '{k}': expected {exp!r}, got {act!r}")
        raise AssertionError("\n".join(parts))
|
|
@@ -0,0 +1,265 @@
|
|
|
1
|
+
"""Test data builders and factories.
|
|
2
|
+
|
|
3
|
+
Fluent APIs for constructing test messages, tool responses, and complete
|
|
4
|
+
test scenarios without repetitive boilerplate.
|
|
5
|
+
"""
|
|
6
|
+
|
|
7
|
+
from __future__ import annotations
|
|
8
|
+
|
|
9
|
+
from typing import Any
|
|
10
|
+
|
|
11
|
+
from toolcallcheck.fake_model import FakeModel
|
|
12
|
+
from toolcallcheck.mock_server import MockMCPServer, MockTool
|
|
13
|
+
from toolcallcheck.runner import AgentRunner
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class UserMessageBuilder:
    """Fluent builder for constructing agent invocation parameters.

    Usage::

        msg = (
            UserMessageBuilder("Create user Jane")
            .with_token("test-jwt")
            .with_site("1234")
            .with_header("X-Custom", "value")
            .build()
        )
        result = await runner.invoke(**msg)

    Each ``with_*`` method mutates the builder and returns it for chaining.
    ``build()`` returns fresh ``headers`` / ``metadata`` dicts, so the same
    builder can be extended and built again without altering earlier results.
    """

    def __init__(self, message: str) -> None:
        self._message = message
        self._token: str | None = None
        self._site_id: str | None = None
        self._headers: dict[str, str] = {}
        self._metadata: dict[str, Any] = {}

    def with_token(self, token: str) -> UserMessageBuilder:
        """Set the access token."""
        self._token = token
        return self

    def with_site(self, site_id: str) -> UserMessageBuilder:
        """Set the site ID."""
        self._site_id = site_id
        return self

    def with_header(self, key: str, value: str) -> UserMessageBuilder:
        """Add (or overwrite) a request header."""
        self._headers[key] = value
        return self

    def with_metadata(self, key: str, value: Any) -> UserMessageBuilder:
        """Add (or overwrite) a metadata entry."""
        self._metadata[key] = value
        return self

    def build(self) -> dict[str, Any]:
        """Build the invocation kwargs dict.

        ``"message"`` is always present. ``access_token``, ``site_id``,
        ``headers`` and ``metadata`` are included only when set to a truthy
        value, so an empty-string token or site ID is omitted.
        """
        result: dict[str, Any] = {"message": self._message}
        if self._token:
            result["access_token"] = self._token
        if self._site_id:
            result["site_id"] = self._site_id
        # Shallow-copy the mutable containers: returning the builder's own
        # dicts let later with_header()/with_metadata() calls, or callers
        # mutating the result, leak between builds.
        if self._headers:
            result["headers"] = dict(self._headers)
        if self._metadata:
            result["metadata"] = dict(self._metadata)
        return result
|
|
70
|
+
|
|
71
|
+
|
|
72
|
+
class ToolResponseBuilder:
    """Fluent builder for constructing mock tool responses.

    Usage::

        tool = (
            ToolResponseBuilder("create_user")
            .with_description("Create a new user")
            .with_param("firstName", "string")
            .with_param("email", "string")
            .with_response({"status": "success"})
            .build()
        )

    Each ``with_*`` method mutates the builder and returns it for chaining.
    ``build()`` gives the resulting :class:`MockTool` its own copy of the
    parameter definitions, so the builder can be extended and built again
    without changing tools built earlier.
    """

    def __init__(self, name: str) -> None:
        self._name = name
        self._description = ""
        # Parameter name -> {"type": <type name>, "required": <bool>}.
        self._params: dict[str, Any] = {}
        # Static response value, or the callable set by with_conditional_response().
        self._response: Any = None
        self._error: str | dict[str, Any] | None = None
        self._error_after: int | None = None

    def with_description(self, desc: str) -> ToolResponseBuilder:
        """Set the tool description."""
        self._description = desc
        return self

    def with_param(
        self,
        name: str,
        type_: str = "string",
        *,
        required: bool = True,
    ) -> ToolResponseBuilder:
        """Add (or replace) a parameter definition."""
        self._params[name] = {"type": type_, "required": required}
        return self

    def with_response(self, response: Any) -> ToolResponseBuilder:
        """Set the canned success response."""
        self._response = response
        return self

    def with_conditional_response(self, fn: Any) -> ToolResponseBuilder:
        """Set a conditional response function ``(args) -> response``.

        This shares storage with :meth:`with_response`; whichever is called
        last wins.
        """
        self._response = fn
        return self

    def with_error(
        self,
        error: str | dict[str, Any],
        *,
        after: int | None = None,
    ) -> ToolResponseBuilder:
        """Set error injection.

        ``after`` is forwarded to :class:`MockTool` as ``error_after``.
        NOTE(review): presumably the number of successful calls before the
        error fires — confirm in mock_server.
        """
        self._error = error
        self._error_after = after
        return self

    def build(self) -> MockTool:
        """Build the MockTool from the accumulated settings."""
        return MockTool(
            name=self._name,
            description=self._description,
            # Copy the definitions: passing the builder's live dict meant a
            # later with_param() on a reused builder mutated tools already built.
            parameters={param: dict(spec) for param, spec in self._params.items()},
            response=self._response,
            error=self._error,
            error_after=self._error_after,
        )
|
|
142
|
+
|
|
143
|
+
|
|
144
|
+
class ScenarioBuilder:
    """Fluent builder for constructing complete test scenarios.

    Usage::

        scenario = (
            ScenarioBuilder("user_flow")
            .with_tool(
                ToolResponseBuilder("create_user")
                .with_response({"status": "ok"})
                .build()
            )
            .with_model_responses([
                {"tool_calls": [{"name": "create_user", "args": {"email": "a@example.com"}}]},
                {"content": "Done!"},
            ])
            .with_message("Create user a@example.com")
            .build()
        )
        result = await scenario.runner.invoke(**scenario.invocation)

    Each ``with_*`` method mutates the builder and returns it for chaining.
    ``build()`` gives every scenario its own copies of the containers (tool
    list, model responses/rules, headers); the MockTool objects themselves
    are shared between scenarios built from the same builder.
    """

    def __init__(self, name: str) -> None:
        self._name = name
        self._tools: list[MockTool] = []
        self._model_responses: list[dict[str, Any]] = []
        self._model_rules: list[tuple[str, dict[str, Any]]] = []
        self._message: str = ""
        self._token: str | None = None
        self._site_id: str | None = None
        self._headers: dict[str, str] = {}
        self._model_name: str | None = None

    def with_tool(self, tool: MockTool) -> ScenarioBuilder:
        """Add a mock tool."""
        self._tools.append(tool)
        return self

    def with_model_responses(self, responses: list[dict[str, Any]]) -> ScenarioBuilder:
        """Set scripted model responses (replacing any set previously)."""
        self._model_responses = responses
        return self

    def with_model_rules(self, rules: list[tuple[str, dict[str, Any]]]) -> ScenarioBuilder:
        """Set pattern-matching model rules (replacing any set previously)."""
        self._model_rules = rules
        return self

    def with_message(self, message: str) -> ScenarioBuilder:
        """Set the user message."""
        self._message = message
        return self

    def with_token(self, token: str) -> ScenarioBuilder:
        """Set the access token."""
        self._token = token
        return self

    def with_site_id(self, site_id: str) -> ScenarioBuilder:
        """Set the site ID."""
        self._site_id = site_id
        return self

    def with_header(self, key: str, value: str) -> ScenarioBuilder:
        """Add (or overwrite) a request header."""
        self._headers[key] = value
        return self

    def with_model_name(self, name: str) -> ScenarioBuilder:
        """Set the model name passed to the :class:`AgentRunner`."""
        self._model_name = name
        return self

    def build(self) -> BuiltScenario:
        """Build the scenario.

        Creates a :class:`MockMCPServer` holding the registered tools, a
        :class:`FakeModel` scripted with the configured responses and rules,
        and an :class:`AgentRunner` wired to both. The returned
        :class:`BuiltScenario` also carries ``invocation``: ``message`` is
        always present, while ``access_token``, ``site_id`` and ``headers``
        are included only when set to a truthy value.
        """
        server = MockMCPServer()
        # Hand out copies so state held or mutated by the server/model
        # (NOTE(review): e.g. a model consuming its scripted responses —
        # confirm in fake_model) cannot leak into scenarios built later.
        server.add_tools(list(self._tools))

        model = FakeModel(
            responses=list(self._model_responses),
            rules=list(self._model_rules),
        )

        runner = AgentRunner(
            mcp_server=server,
            model=model,
            model_name=self._model_name,
        )

        invocation: dict[str, Any] = {"message": self._message}
        if self._token:
            invocation["access_token"] = self._token
        if self._site_id:
            invocation["site_id"] = self._site_id
        if self._headers:
            # Copy so later with_header() calls don't alter this invocation.
            invocation["headers"] = dict(self._headers)

        return BuiltScenario(
            name=self._name,
            runner=runner,
            server=server,
            model=model,
            invocation=invocation,
        )
|
|
248
|
+
|
|
249
|
+
|
|
250
|
+
class BuiltScenario:
    """Bundle of objects produced by :meth:`ScenarioBuilder.build`.

    Attributes
    ----------
    name:
        The scenario name.
    runner:
        The :class:`AgentRunner` to invoke.
    server:
        The :class:`MockMCPServer` holding the scenario's tools.
    model:
        The scripted :class:`FakeModel`.
    invocation:
        Keyword arguments intended for ``runner.invoke(**invocation)``.
    """

    def __init__(
        self,
        name: str,
        runner: AgentRunner,
        server: MockMCPServer,
        model: FakeModel,
        invocation: dict[str, Any],
    ) -> None:
        # Plain attribute storage; assigned left to right like the field list above.
        (self.name, self.runner, self.server, self.model, self.invocation) = (
            name,
            runner,
            server,
            model,
            invocation,
        )
|