pytest-agentcontract 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,360 @@
1
+ """Assertion engine: validates agent trajectories against contracts."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import re
6
+ from dataclasses import dataclass, field
7
+ from typing import Any
8
+
9
+ import jsonschema
10
+
11
+ from agentcontract.config import AssertionSpec, PolicySpec
12
+ from agentcontract.types import AgentRun, TurnRole
13
+
14
+
15
@dataclass
class AssertionResult:
    """Outcome of evaluating one assertion (or one policy) against a trajectory."""

    # The spec that was evaluated; policy checks carry a synthetic
    # ``policy:<name>`` spec built by the engine.
    assertion: AssertionSpec
    # True when the check succeeded.
    passed: bool
    # Human-readable explanation; the engine leaves it empty on success.
    message: str = field(default="")
    # Free-form extra data; each instance gets its own fresh dict.
    details: dict[str, Any] = field(default_factory=dict)
23
+
24
+
25
@dataclass
class ContractResult:
    """Aggregated outcome of every assertion and policy checked for one scenario."""

    # Name of the scenario the trajectory belongs to.
    scenario: str
    # One result per check, in the order the checks were evaluated.
    results: list[AssertionResult] = field(default_factory=list)

    def failures(self) -> list[AssertionResult]:
        """Return the failed results, preserving evaluation order."""
        return [outcome for outcome in self.results if not outcome.passed]

    @property
    def passed(self) -> bool:
        """True when no result failed (vacuously True when there are no results)."""
        return not any(not outcome.passed for outcome in self.results)

    @property
    def failed_count(self) -> int:
        """Number of results that did not pass."""
        return len(self.failures())
42
+
43
+
44
class AssertionEngine:
    """Evaluates assertions and policies against an AgentRun trajectory.

    Checks never raise to the caller: an unknown assertion/policy type, or an
    exception raised while a checker runs, is reported as a failed
    ``AssertionResult`` whose message explains the problem.

    Usage:
        engine = AssertionEngine()
        result = engine.check(run, assertions, policies)
        assert result.passed
    """

    # Optional prefix accepted on tool-oriented targets, so that
    # "tool:search" and "search" name the same tool.
    _TOOL_PREFIX = "tool:"

    def check(
        self,
        run: AgentRun,
        assertions: list[AssertionSpec] | None = None,
        policies: list[PolicySpec] | None = None,
    ) -> ContractResult:
        """Run all assertions and policies against a trajectory.

        Args:
            run: The recorded trajectory to evaluate.
            assertions: Assertion specs to evaluate; ``None`` means none.
            policies: Policy specs to evaluate after the assertions; ``None``
                means none.

        Returns:
            A ``ContractResult`` for ``run.metadata.scenario`` holding one
            result per assertion followed by one result per policy.
        """
        result = ContractResult(scenario=run.metadata.scenario)
        for assertion in assertions or []:
            result.results.append(self._check_assertion(run, assertion))
        for policy in policies or []:
            result.results.append(self._check_policy(run, policy))
        return result

    def _check_assertion(self, run: AgentRun, spec: AssertionSpec) -> AssertionResult:
        """Dispatch ``spec`` to the checker registered for ``spec.type``."""
        checkers = {
            "exact": self._check_exact,
            "contains": self._check_contains,
            "regex": self._check_regex,
            "json_schema": self._check_json_schema,
            "not_called": self._check_not_called,
            "called_with": self._check_called_with,
            "called_count": self._check_called_count,
        }

        checker = checkers.get(spec.type)
        if checker is None:
            return AssertionResult(
                assertion=spec,
                passed=False,
                message=f"Unknown assertion type: {spec.type}",
            )

        try:
            return checker(run, spec)
        except Exception as e:
            # Deliberately broad: a malformed spec or trajectory must surface
            # as a failed assertion rather than abort the whole contract check.
            return AssertionResult(
                assertion=spec,
                passed=False,
                message=f"Assertion error: {e}",
            )

    @classmethod
    def _tool_name(cls, target: str) -> str:
        """Return ``target`` with one leading ``tool:`` prefix removed.

        Only the prefix is stripped; ``str.replace`` would also delete any
        later occurrence of ``"tool:"`` inside the name.
        """
        if target.startswith(cls._TOOL_PREFIX):
            return target[len(cls._TOOL_PREFIX):]
        return target

    def _resolve_target(self, run: AgentRun, target: str) -> Any:
        """Resolve a target string to content taken from ``run``.

        Targets:
            final_response                    - content of the last assistant turn
                                                whose content is not None
            turn:N                            - content of turn N (0-based; out of
                                                range -> None)
            full_conversation                 - "role: content" lines for every turn
                                                with content, joined by newlines
            tool_call:function_name           - same as ...:arguments
            tool_call:function_name:arguments - arguments of the first matching call
            tool_call:function_name:result    - result of the first matching call

        Returns ``None`` for unknown or unresolvable targets. A non-integer N in
        ``turn:N`` raises ``ValueError``.
        """
        if target == "final_response":
            for turn in reversed(run.turns):
                if turn.role == TurnRole.ASSISTANT and turn.content is not None:
                    return turn.content
            return None

        if target == "full_conversation":
            lines = [
                f"{turn.role.value}: {turn.content}"
                for turn in run.turns
                if turn.content is not None
            ]
            return "\n".join(lines)

        if target.startswith("turn:"):
            idx = int(target.split(":", 1)[1])
            if 0 <= idx < len(run.turns):
                return run.turns[idx].content
            return None

        if target.startswith("tool_call:"):
            pieces = target.split(":", 2)
            func_name = pieces[1]
            if not func_name:
                return None
            field_name = pieces[2] if len(pieces) > 2 else "arguments"
            if field_name not in ("arguments", "result"):
                # Unknown field: nothing can match, so skip scanning the calls.
                return None
            for turn in run.turns:
                for tc in turn.tool_calls:
                    if tc.function == func_name:
                        return tc.arguments if field_name == "arguments" else tc.result
            return None

        return None

    def _get_all_tool_calls(self, run: AgentRun) -> list[tuple[str, dict[str, Any], Any]]:
        """Return every tool call in ``run`` as ``(function, arguments, result)``, in turn order."""
        return [
            (tc.function, tc.arguments, tc.result)
            for turn in run.turns
            for tc in turn.tool_calls
        ]

    def _check_exact(self, run: AgentRun, spec: AssertionSpec) -> AssertionResult:
        """Pass when the resolved target equals ``spec.value`` (compared with ``==``)."""
        if spec.value is None:
            return AssertionResult(
                assertion=spec,
                passed=False,
                message="'exact' requires a non-null 'value'",
            )
        actual = self._resolve_target(run, spec.target)
        passed = actual == spec.value
        return AssertionResult(
            assertion=spec,
            passed=passed,
            message="" if passed else f"Expected exact '{spec.value}', got '{actual}'",
        )

    def _check_contains(self, run: AgentRun, spec: AssertionSpec) -> AssertionResult:
        """Pass when ``str(spec.value)`` is a substring of ``str(target)``."""
        if spec.value is None:
            return AssertionResult(
                assertion=spec,
                passed=False,
                message="'contains' requires a non-null 'value'",
            )
        actual = self._resolve_target(run, spec.target)
        if actual is None:
            return AssertionResult(
                assertion=spec,
                passed=False,
                message=f"Target '{spec.target}' resolved to None",
            )
        # Coerce the needle as well, so a numeric/bool value from config is
        # matched textually instead of raising TypeError from `in`.
        passed = str(spec.value) in str(actual)
        return AssertionResult(
            assertion=spec,
            passed=passed,
            message="" if passed else f"'{spec.value}' not found in target",
        )

    def _check_regex(self, run: AgentRun, spec: AssertionSpec) -> AssertionResult:
        """Pass when the pattern ``spec.value`` matches anywhere in ``str(target)``."""
        if spec.value is None:
            return AssertionResult(
                assertion=spec,
                passed=False,
                message="'regex' requires a non-null 'value'",
            )
        actual = self._resolve_target(run, spec.target)
        if actual is None:
            return AssertionResult(
                assertion=spec,
                passed=False,
                message=f"Target '{spec.target}' resolved to None",
            )
        passed = re.search(str(spec.value), str(actual)) is not None
        return AssertionResult(
            assertion=spec,
            passed=passed,
            message="" if passed else f"Pattern '{spec.value}' not matched",
        )

    def _check_json_schema(self, run: AgentRun, spec: AssertionSpec) -> AssertionResult:
        """Validate the resolved target (as-is, not JSON-decoded) against ``spec.schema``."""
        if spec.schema is None:
            return AssertionResult(
                assertion=spec,
                passed=False,
                message="'json_schema' requires a 'schema'",
            )
        actual = self._resolve_target(run, spec.target)
        if actual is None:
            return AssertionResult(
                assertion=spec,
                passed=False,
                message=f"Target '{spec.target}' resolved to None",
            )
        try:
            jsonschema.validate(instance=actual, schema=spec.schema)
        except jsonschema.ValidationError as e:
            return AssertionResult(
                assertion=spec,
                passed=False,
                message=f"Schema validation failed: {e.message}",
            )
        except jsonschema.SchemaError as e:
            # The schema itself is malformed; report that distinctly from a
            # target that fails a valid schema.
            return AssertionResult(
                assertion=spec,
                passed=False,
                message=f"Invalid schema: {e.message}",
            )
        return AssertionResult(assertion=spec, passed=True)

    def _check_not_called(self, run: AgentRun, spec: AssertionSpec) -> AssertionResult:
        """Assert a tool was NOT called (target: "tool:function_name" or "function_name")."""
        func_name = self._tool_name(spec.target)
        called = any(name == func_name for name, _, _ in self._get_all_tool_calls(run))
        return AssertionResult(
            assertion=spec,
            passed=not called,
            message=f"Tool '{func_name}' was called but should not have been" if called else "",
        )

    def _check_called_with(self, run: AgentRun, spec: AssertionSpec) -> AssertionResult:
        """Assert a tool was called at least once with specific arguments.

        ``spec.schema`` holds the expected arguments as a dict. A call matches
        when every expected key is present in its arguments with an equal
        value; extra arguments on the call are ignored.
        """
        func_name = self._tool_name(spec.target)
        expected_args = spec.schema  # the schema field is reused for expected args
        if expected_args is None:
            return AssertionResult(
                assertion=spec,
                passed=False,
                message="'called_with' requires expected arguments in 'schema'",
            )
        if not isinstance(expected_args, dict):
            return AssertionResult(
                assertion=spec,
                passed=False,
                message="'called_with' expects a dict in 'schema'",
            )

        for name, args, _ in self._get_all_tool_calls(run):
            if name != func_name or not isinstance(args, dict):
                continue
            # Require key presence: `args.get(k)` would let a missing key
            # satisfy an expected value of None.
            if all(key in args and args[key] == value for key, value in expected_args.items()):
                return AssertionResult(assertion=spec, passed=True)

        return AssertionResult(
            assertion=spec,
            passed=False,
            message=f"Tool '{func_name}' was not called with expected arguments",
        )

    def _check_called_count(self, run: AgentRun, spec: AssertionSpec) -> AssertionResult:
        """Assert a tool was called exactly ``int(spec.value)`` times."""
        func_name = self._tool_name(spec.target)
        raw_count = spec.value
        if raw_count is None:
            return AssertionResult(
                assertion=spec,
                passed=False,
                message="'called_count' requires an integer 'value'",
            )
        # bool is an int subclass and int() truncates floats; reject both so
        # `true` is not read as 1 and 2.5 is not silently read as 2.
        is_bool = isinstance(raw_count, bool)
        is_fractional = isinstance(raw_count, float) and not raw_count.is_integer()
        if is_bool or is_fractional:
            return AssertionResult(
                assertion=spec,
                passed=False,
                message="'called_count' expects an integer 'value'",
            )
        try:
            expected_count = int(raw_count)
        except (TypeError, ValueError):
            return AssertionResult(
                assertion=spec,
                passed=False,
                message="'called_count' expects an integer 'value'",
            )

        actual_count = sum(1 for name, _, _ in self._get_all_tool_calls(run) if name == func_name)
        passed = actual_count == expected_count
        return AssertionResult(
            assertion=spec,
            passed=passed,
            message=""
            if passed
            else f"Tool '{func_name}' called {actual_count} times, expected {expected_count}",
        )

    def _check_policy(self, run: AgentRun, policy: PolicySpec) -> AssertionResult:
        """Dispatch ``policy`` to the checker registered for ``policy.type``.

        Unknown types and checker exceptions become failed results tagged with
        a synthetic ``policy:<name>`` assertion spec.
        """
        policy_checkers = {
            "tool_allowlist": self._policy_tool_allowlist,
            "requires_confirmation": self._policy_requires_confirmation,
        }

        checker = policy_checkers.get(policy.type)
        if checker is None:
            return AssertionResult(
                assertion=AssertionSpec(type=f"policy:{policy.name}"),
                passed=False,
                message=f"Unknown policy type: {policy.type}",
            )

        try:
            return checker(run, policy)
        except Exception as e:
            # Same contract as assertions: report, never raise.
            return AssertionResult(
                assertion=AssertionSpec(type=f"policy:{policy.name}"),
                passed=False,
                message=f"Policy error: {e}",
            )

    def _policy_tool_allowlist(self, run: AgentRun, policy: PolicySpec) -> AssertionResult:
        """Only tools listed in ``policy.tools`` may be called; every other call is a violation."""
        calls = self._get_all_tool_calls(run)
        violations = [name for name, _, _ in calls if name not in policy.tools]

        spec = AssertionSpec(type=f"policy:{policy.name}", target="all_tool_calls")
        if violations:
            return AssertionResult(
                assertion=spec,
                passed=False,
                message=f"Disallowed tools called: {violations}",
            )
        return AssertionResult(assertion=spec, passed=True)

    def _policy_requires_confirmation(self, run: AgentRun, policy: PolicySpec) -> AssertionResult:
        """Protected tools must be called in a turn directly preceded by a user turn.

        A call to any tool in ``policy.tools`` fails the policy when it occurs
        in turn 0, or when the immediately preceding turn's role is not
        ``TurnRole.USER``. Only the previous turn's role is checked, not its
        content. The first violation found is reported.
        """
        spec = AssertionSpec(type=f"policy:{policy.name}", target="tool_sequence")

        for i, turn in enumerate(run.turns):
            for tc in turn.tool_calls:
                if tc.function not in policy.tools:
                    continue
                if i == 0:
                    return AssertionResult(
                        assertion=spec,
                        passed=False,
                        message=(
                            f"Tool '{tc.function}' called at turn 0 "
                            "with no prior confirmation"
                        ),
                    )
                if run.turns[i - 1].role != TurnRole.USER:
                    return AssertionResult(
                        assertion=spec,
                        passed=False,
                        message=(
                            f"Tool '{tc.function}' at turn {i} "
                            "not preceded by user confirmation"
                        ),
                    )

        return AssertionResult(assertion=spec, passed=True)
agentcontract/cli.py ADDED
@@ -0,0 +1,127 @@
1
+ """CLI entry point for agentcontract commands."""
2
+
3
+ from __future__ import annotations
4
+
5
+ import argparse
6
+ import sys
7
+ from pathlib import Path
8
+
9
+
10
+ def main(argv: list[str] | None = None) -> int:
11
+ """Main CLI entry point."""
12
+ parser = argparse.ArgumentParser(
13
+ prog="agentcontract",
14
+ description="pytest-agentcontract: Deterministic CI tests for LLM agent trajectories",
15
+ )
16
+ subparsers = parser.add_subparsers(dest="command")
17
+
18
+ # info command
19
+ info_parser = subparsers.add_parser("info", help="Show cassette info")
20
+ info_parser.add_argument("path", type=Path, help="Path to .agentrun.json file")
21
+
22
+ # validate command
23
+ validate_parser = subparsers.add_parser("validate", help="Validate a cassette file")
24
+ validate_parser.add_argument("path", type=Path, help="Path to .agentrun.json file")
25
+
26
+ # init command
27
+ subparsers.add_parser("init", help="Create a starter agentcontract.yml")
28
+
29
+ args = parser.parse_args(argv)
30
+
31
+ if args.command == "info":
32
+ return _cmd_info(args.path)
33
+ elif args.command == "validate":
34
+ return _cmd_validate(args.path)
35
+ elif args.command == "init":
36
+ return _cmd_init()
37
+ else:
38
+ parser.print_help()
39
+ return 0
40
+
41
+
42
def _cmd_info(path: Path) -> int:
    """Print a summary of the cassette at ``path``.

    Returns 0 after printing; returns 1 with a message on stderr when the
    path does not exist or loading/summarizing the cassette raises.
    """
    # Imported at call time rather than at CLI module import.
    from agentcontract.serialization import load_run

    if not path.exists():
        print(f"Error: {path} not found", file=sys.stderr)
        return 1

    try:
        cassette = load_run(path)
        print(f"Scenario: {cassette.metadata.scenario}")
        print(f"Run ID: {cassette.run_id}")
        print(f"Recorded: {cassette.recorded_at}")
        print(f"Model: {cassette.model.provider}/{cassette.model.model}")
        stats = cassette.summary
        print(f"Turns: {stats.total_turns}")
        print(f"Tool calls: {stats.total_tool_calls}")
        print(f"Duration: {stats.total_duration_ms:.0f}ms")
        print(f"Tokens: {stats.total_tokens.total}")
        print(f"Est. cost: ${stats.estimated_cost_usd:.4f}")
    except Exception as exc:
        print(f"Error: failed to read cassette '{path}': {exc}", file=sys.stderr)
        return 1
    return 0
65
+
66
+
67
def _cmd_validate(path: Path) -> int:
    """Check that the file at ``path`` loads as a cassette.

    Prints a success line to stdout and returns 0 when loading works; returns
    1 with a message on stderr when the path is missing or loading raises.
    """
    # Imported at call time rather than at CLI module import.
    from agentcontract.serialization import load_run

    if not path.exists():
        print(f"Error: {path} not found", file=sys.stderr)
        return 1

    try:
        cassette = load_run(path)
        scenario_name = cassette.metadata.scenario
        turn_count = len(cassette.turns)
        print(f"✓ Valid cassette: {scenario_name} ({turn_count} turns)")
    except Exception as exc:
        print(f"✗ Invalid cassette: {exc}", file=sys.stderr)
        return 1
    return 0
82
+
83
+
84
+ def _cmd_init() -> int:
85
+ """Create a starter agentcontract.yml in the current directory."""
86
+ target = Path("agentcontract.yml")
87
+ if target.exists():
88
+ print(f"Error: {target} already exists", file=sys.stderr)
89
+ return 1
90
+
91
+ template = """\
92
+ version: "1"
93
+
94
+ scenarios:
95
+ include: ["tests/scenarios/**/*.agentrun.json"]
96
+
97
+ replay:
98
+ stub_tools: true
99
+ concurrency: 5
100
+
101
+ defaults:
102
+ assertions:
103
+ - type: contains
104
+ target: final_response
105
+ value: "" # customize this
106
+
107
+ policies:
108
+ - name: allowed-tools
109
+ type: tool_allowlist
110
+ tools: [] # list your agent's tools here
111
+
112
+ budgets:
113
+ per_scenario:
114
+ max_cost_usd: 0.05
115
+ max_turns: 15
116
+
117
+ reporting:
118
+ github_comment: true
119
+ artifact_path: "agentci-results/"
120
+ """
121
+ target.write_text(template)
122
+ print(f"Created {target}")
123
+ return 0
124
+
125
+
126
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    raise SystemExit(main())