pytest-agentcontract 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,25 @@
1
+ """pytest-agentcontract: Deterministic CI tests for LLM agent trajectories."""
2
+
3
+ __version__ = "0.1.0"
4
+
5
+ # Lazy imports to avoid circular dependencies and speed up pytest plugin loading.
6
+ # The plugin.py entry point imports specific modules directly.
7
+
8
+
9
def __getattr__(name: str):  # noqa: ANN001
    """Resolve a public API class on first access (PEP 562 module ``__getattr__``).

    The defining module is imported only when the attribute is requested, which
    keeps ``import agentcontract`` cheap while the pytest plugin is loading.

    Raises:
        AttributeError: if ``name`` is not one of the lazily exported classes.
    """
    module_path = {
        "Recorder": "agentcontract.recorder.core",
        "ReplayEngine": "agentcontract.replay.engine",
        "AssertionEngine": "agentcontract.assertions.engine",
        "AgentContractConfig": "agentcontract.config",
    }.get(name)
    if module_path is None:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    import importlib

    return getattr(importlib.import_module(module_path), name)
23
+
24
+
25
+ __all__ = ["Recorder", "ReplayEngine", "AssertionEngine", "AgentContractConfig"]
@@ -0,0 +1,36 @@
1
+ """Framework adapters for popular agent libraries.
2
+
3
+ Each adapter wraps a framework's execution method to record trajectories
4
+ into the agentcontract format. All adapters follow the same pattern:
5
+
6
+ unpatch = record_<thing>(target, recorder)
7
+ # ... run your agent ...
8
+ unpatch()
9
+
10
+ Available adapters:
11
+ - langgraph: LangGraph CompiledGraph (invoke/ainvoke)
12
+ - llamaindex: LlamaIndex AgentRunner (chat/query)
13
+ - openai_agents: OpenAI Agents SDK Runner (run/run_sync)
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ from typing import Any
19
+
20
+
21
def __getattr__(name: str) -> Any:
    """Import an adapter's ``record_*`` entry point the first time it is accessed.

    No adapter module is imported by ``import agentcontract.adapters`` itself;
    each is loaded only when its entry point is requested.

    Raises:
        AttributeError: if ``name`` is not a known adapter entry point.
    """
    if name == "record_graph":
        module_path = "agentcontract.adapters.langgraph"
    elif name == "record_agent":
        module_path = "agentcontract.adapters.llamaindex"
    elif name == "record_runner":
        module_path = "agentcontract.adapters.openai_agents"
    else:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    import importlib

    adapter_module = importlib.import_module(module_path)
    return getattr(adapter_module, name)
34
+
35
+
36
+ __all__ = ["record_graph", "record_agent", "record_runner"]
@@ -0,0 +1,157 @@
1
+ """LangGraph adapter for pytest-agentcontract.
2
+
3
+ Records agent trajectories from LangGraph graph executions by wrapping
4
+ the graph's stream/invoke methods.
5
+
6
+ Usage:
7
+ from agentcontract.adapters.langgraph import record_graph
8
+
9
+ recorder = Recorder(scenario="customer-support")
10
+ with recorder.recording():
11
+ unpatch = record_graph(graph, recorder)
12
+ result = graph.invoke({"messages": [("user", "I need a refund")]})
13
+ unpatch()
14
+ recorder.save("tests/scenarios/customer-support.agentrun.json")
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ import functools
20
+ import time
21
+ from collections.abc import Callable
22
+ from typing import Any
23
+
24
+ from agentcontract.recorder.core import Recorder
25
+
26
+
27
def record_graph(graph: Any, recorder: Recorder) -> Callable[[], None]:
    """Wrap a LangGraph CompiledGraph to record trajectories.

    Intercepts ``invoke()`` and ``ainvoke()`` on this graph *instance* (the
    class is not modified). After each call returns, the messages in the
    result state are converted to recorder turns by ``_extract_turns``, with
    the call's wall-clock latency attached to assistant turns.

    Args:
        graph: A LangGraph CompiledGraph (from ``graph.compile()``).
        recorder: A Recorder instance to capture the trajectory.

    Returns:
        An ``unpatch()`` function that restores the instance exactly as it was
        found. A method that already lived on the instance is put back;
        otherwise the wrapper is deleted so the class-level method is visible
        again. Re-assigning the bound method instead would leave a bound copy
        (and a reference cycle) in the instance ``__dict__``, shadowing the
        class. Calling ``unpatch()`` more than once is harmless.
    """
    missing = object()
    # Per patched name: the value previously stored on the instance itself,
    # or ``missing`` when the method was resolved from the class.
    previous: dict[str, Any] = {}

    def _install(attr: str, original: Any, wrapper: Any) -> None:
        """Remember how ``attr`` was stored on the instance, then patch it."""
        instance_dict = getattr(graph, "__dict__", None)
        if isinstance(instance_dict, dict):
            previous[attr] = instance_dict.get(attr, missing)
        else:
            # No instance dict to inspect: fall back to re-assigning the original.
            previous[attr] = original
        setattr(graph, attr, wrapper)

    original_invoke = getattr(graph, "invoke", None)
    original_ainvoke = getattr(graph, "ainvoke", None)

    if original_invoke is not None:

        @functools.wraps(original_invoke)
        def recording_invoke(*args: Any, **kwargs: Any) -> Any:
            start = time.monotonic()
            result = original_invoke(*args, **kwargs)
            latency_ms = (time.monotonic() - start) * 1000
            _extract_turns(result, recorder, latency_ms)
            return result

        _install("invoke", original_invoke, recording_invoke)

    if original_ainvoke is not None:

        @functools.wraps(original_ainvoke)
        async def recording_ainvoke(*args: Any, **kwargs: Any) -> Any:
            start = time.monotonic()
            result = await original_ainvoke(*args, **kwargs)
            latency_ms = (time.monotonic() - start) * 1000
            _extract_turns(result, recorder, latency_ms)
            return result

        _install("ainvoke", original_ainvoke, recording_ainvoke)

    def unpatch() -> None:
        for attr, prior in previous.items():
            if prior is missing:
                try:
                    delattr(graph, attr)
                except AttributeError:
                    pass  # already removed, e.g. unpatch() called twice
            else:
                setattr(graph, attr, prior)

    return unpatch
73
+
74
+
75
def _extract_turns(result: Any, recorder: Recorder, latency_ms: float) -> None:
    """Record one turn per message found in a LangGraph result state.

    Only dict results with a list/tuple under ``"messages"`` are processed;
    anything else is ignored. Messages whose normalized role is not one of
    user/assistant/system/tool are skipped. ``latency_ms`` (the latency of the
    whole graph call) is attached to assistant turns only.
    """
    messages = result.get("messages", []) if isinstance(result, dict) else None
    if not isinstance(messages, (list, tuple)):
        return

    recordable_roles = ("user", "assistant", "system", "tool")
    for message in messages:
        role = _get_role(message)
        text = _get_content(message)
        calls = _get_tool_calls(message)
        if role not in recordable_roles:
            continue
        recorder.add_turn(
            role=role,
            content=text,
            tool_calls=calls or None,
            latency_ms=None if role != "assistant" else latency_ms,
        )
100
+
101
+
102
def _get_role(msg: Any) -> str:
    """Return the normalized role of a LangChain message object or message dict.

    Objects report their kind via ``.type`` (``HumanMessage`` -> "human",
    ``AIMessage`` -> "ai", ...); dicts use a ``"role"`` key or, failing that,
    a ``"type"`` key. The LangChain spellings "human"/"ai" are mapped to
    "user"/"assistant" for *both* forms. Streaming chunk types such as
    "AIMessageChunk" are reduced to their base kind, and a ``ChatMessage``
    (type "chat") reports its ``.role``. Unknown roles are returned unchanged;
    ``""`` means no role could be determined.
    """
    role_map = {"human": "user", "ai": "assistant", "system": "system", "tool": "tool"}

    if isinstance(msg, dict):
        # ``or`` (not a .get default) so an explicit ``role: None`` falls back to type.
        raw = msg.get("role") or msg.get("type")
    else:
        raw = getattr(msg, "type", None)
    if not raw:
        return ""

    role = str(raw)
    chunk_suffix = "MessageChunk"
    if role.endswith(chunk_suffix):
        # e.g. "AIMessageChunk" -> "ai", "HumanMessageChunk" -> "human"
        role = role[: -len(chunk_suffix)].lower()
    if role == "chat" and not isinstance(msg, dict):
        # ChatMessage carries an arbitrary role string; its type is just "chat".
        role = str(getattr(msg, "role", "") or role)
    return role_map.get(role, role)
114
+
115
+
116
def _get_content(msg: Any) -> str | None:
    """Return a message's text content, or ``None`` when it has no text.

    ``content`` may be a plain string or a list of content blocks. Dict blocks
    contribute their ``"text"`` value (``""`` if absent) and other blocks their
    ``str()``; the pieces are concatenated without a separator. Any other
    content value is converted with ``str()``. Empty results become ``None``.
    """
    if isinstance(msg, dict):
        raw = msg.get("content")
    else:
        raw = getattr(msg, "content", None)

    if raw is None:
        return None

    if isinstance(raw, str):
        text = raw
    elif isinstance(raw, list):
        pieces = []
        for item in raw:
            if isinstance(item, dict):
                pieces.append(str(item.get("text", "")))
            else:
                pieces.append(str(item))
        text = "".join(pieces)
    else:
        text = str(raw)

    return text if text else None
130
+
131
+
132
def _get_tool_calls(msg: Any) -> list[dict[str, Any]] | None:
    """Extract tool calls from a LangChain AIMessage or a message dict.

    Two entry shapes are understood:

    * LangChain ``ToolCall`` entries: ``{"name", "args", "id"}`` (dict or object);
    * OpenAI wire-format entries: ``{"id", "function": {"name", "arguments"}}``,
      where ``arguments`` is usually a JSON string.

    Each usable call becomes ``{"id": str, "function": str, "arguments": dict}``:
    a JSON-object string is decoded; any other non-empty string is kept as
    ``{"_raw": <string>}`` so it is not silently lost; other non-dict values
    become ``{}``. Calls without a name are skipped.

    Returns:
        The list of calls, or ``None`` when the message has no usable tool calls.
    """
    import json

    def _field(obj: Any, key: str) -> Any:
        # Messages and tool-call entries may be dicts or attribute-style objects.
        return obj.get(key) if isinstance(obj, dict) else getattr(obj, key, None)

    raw = _field(msg, "tool_calls")
    if not raw or not isinstance(raw, (list, tuple)):
        return None

    calls: list[dict[str, Any]] = []
    for tc in raw:
        fn = _field(tc, "function")
        name = _field(tc, "name") or _field(fn, "name")
        if not name:
            continue

        args = _field(tc, "args")
        if args is None:
            args = _field(tc, "arguments")
        if args is None and fn is not None:
            args = _field(fn, "arguments")
        if isinstance(args, str):
            if not args:
                args = {}
            else:
                try:
                    decoded = json.loads(args)
                except json.JSONDecodeError:
                    decoded = None
                args = decoded if isinstance(decoded, dict) else {"_raw": args}

        call_id = _field(tc, "id")
        calls.append({
            "id": str(call_id) if call_id else "",
            "function": str(name),
            "arguments": args if isinstance(args, dict) else {},
        })

    return calls or None
@@ -0,0 +1,167 @@
1
+ """LlamaIndex adapter for pytest-agentcontract.
2
+
3
+ Records agent trajectories from LlamaIndex AgentRunner / ReActAgent
4
+ by wrapping the chat/query methods.
5
+
6
+ Usage:
7
+ from agentcontract.adapters.llamaindex import record_agent
8
+
9
+ recorder = Recorder(scenario="rag-qa")
10
+ with recorder.recording():
11
+ unpatch = record_agent(agent, recorder)
12
+ response = agent.chat("What's the refund policy?")
13
+ unpatch()
14
+ recorder.save("tests/scenarios/rag-qa.agentrun.json")
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ import functools
20
+ import time
21
+ from collections.abc import Callable
22
+ from typing import Any
23
+
24
+ from agentcontract.recorder.core import Recorder
25
+
26
+
27
def record_agent(agent: Any, recorder: Recorder) -> Callable[[], None]:
    """Wrap a LlamaIndex agent to record trajectories.

    Intercepts ``chat()``, ``achat()``, ``query()`` and ``aquery()`` on this
    agent *instance* (the class is not modified). Each response is recorded as
    an assistant turn via ``_extract_from_response``.

    Nested calls are recorded only once. An agent may implement one wrapped
    method on top of another through ``self`` (e.g. a ``query()`` that calls
    ``self.chat()``). Because the wrappers live on the instance, that inner
    call is wrapped too and would otherwise record the same interaction a
    second time. A context variable tracks the innermost wrapped call in
    progress: that call records its response, and every enclosing wrapped call
    skips recording. Context variables are copied per asyncio task, so
    unrelated concurrent calls do not suppress each other.

    Args:
        agent: A LlamaIndex AgentRunner, ReActAgent, or similar.
        recorder: A Recorder instance to capture the trajectory.

    Returns:
        An ``unpatch()`` function that restores the instance exactly as found.
        A pre-existing instance attribute is put back; otherwise the wrapper is
        deleted so the class-level method is visible again. Calling it more
        than once is harmless.
    """
    import contextvars

    # Bookkeeping cell ({"recorded": bool}) of the innermost wrapped call that
    # is running in the current context; None outside any wrapped call.
    active_call = contextvars.ContextVar("agentcontract_llamaindex_active_call", default=None)
    missing = object()
    # Per patched name: the value previously stored on the instance itself,
    # or ``missing`` when the method was resolved from the class.
    previous: dict[str, Any] = {}

    def _finish(
        parent: dict[str, bool] | None, cell: dict[str, bool], result: Any, latency_ms: float
    ) -> None:
        """Record ``result`` unless a nested wrapped call already did; tell the parent."""
        if not cell["recorded"]:
            _extract_from_response(result, recorder, latency_ms, agent)
        if parent is not None:
            parent["recorded"] = True

    def _wrap_sync(original: Any) -> Any:
        """Return a recording wrapper around the synchronous method ``original``."""

        @functools.wraps(original)
        def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
            parent = active_call.get()
            cell = {"recorded": False}
            token = active_call.set(cell)
            try:
                start = time.monotonic()
                result = original(*args, **kwargs)
                latency_ms = (time.monotonic() - start) * 1000
            finally:
                active_call.reset(token)
            _finish(parent, cell, result, latency_ms)
            return result

        return sync_wrapper

    def _wrap_async(original: Any) -> Any:
        """Return a recording wrapper around the coroutine method ``original``."""

        @functools.wraps(original)
        async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
            parent = active_call.get()
            cell = {"recorded": False}
            token = active_call.set(cell)
            try:
                start = time.monotonic()
                result = await original(*args, **kwargs)
                latency_ms = (time.monotonic() - start) * 1000
            finally:
                active_call.reset(token)
            _finish(parent, cell, result, latency_ms)
            return result

        return async_wrapper

    for method_name in ("chat", "achat", "query", "aquery"):
        original = getattr(agent, method_name, None)
        if original is None:
            continue
        instance_dict = getattr(agent, "__dict__", None)
        if isinstance(instance_dict, dict):
            previous[method_name] = instance_dict.get(method_name, missing)
        else:
            # No instance dict to inspect: fall back to re-assigning the original.
            previous[method_name] = original
        # The "a"-prefixed names are the async variants.
        wrap = _wrap_async if method_name.startswith("a") else _wrap_sync
        setattr(agent, method_name, wrap(original))

    def unpatch() -> None:
        for name, prior in previous.items():
            if prior is missing:
                try:
                    delattr(agent, name)
                except AttributeError:
                    pass  # already removed, e.g. unpatch() called twice
            else:
                setattr(agent, name, prior)

    return unpatch
86
+
87
+
88
def _extract_from_response(
    response: Any, recorder: Recorder, latency_ms: float, agent: Any
) -> None:
    """Record one assistant turn for a LlamaIndex response.

    The text comes from ``response.response`` when that attribute exists
    (``Response`` / ``AgentChatResponse``). Failing that, it comes from the
    ChatMessage in ``response.message``. Tool calls come from
    ``_extract_tool_calls``. Nothing is recorded when both are empty.

    ``agent`` is part of the call signature but is not read here.
    """
    if hasattr(response, "response"):
        text = str(response.response) if response.response else None
    elif hasattr(response, "message"):
        text = _get_content(response.message)
    else:
        text = None

    calls = _extract_tool_calls(response)
    if not (text or calls):
        return

    recorder.add_turn(
        role="assistant",
        content=text,
        tool_calls=calls or None,
        latency_ms=latency_ms,
    )
114
+
115
+
116
def _extract_tool_calls(response: Any) -> list[dict[str, Any]] | None:
    """Build tool-call records from a LlamaIndex response.

    Two sources are combined, in this order:

    * ``response.sources``: ToolOutput-like objects with ``tool_name``,
      ``raw_input`` and ``raw_output``. Entries without a tool name are skipped.
    * ``response.source_nodes``: retrieved nodes. Each is recorded as a pseudo
      tool call named ``"_retrieve"``, whose ``arguments`` hold the node score
      (when present) and whose ``result`` is the first 200 characters of the
      node text.

    Returns:
        The list of records, or ``None`` when nothing was found.
    """
    missing = object()
    calls: list[dict[str, Any]] = []

    def _arguments(raw_input: Any) -> dict[str, Any]:
        """Normalize a ToolOutput ``raw_input`` into an arguments dict."""
        if not isinstance(raw_input, dict):
            return {}
        # LlamaIndex FunctionTool stores its call as
        # {"args": <positional tuple>, "kwargs": <keyword dict>}. When no
        # positional arguments were passed, record the keyword dict itself so
        # the arguments are the tool's real parameters. Anything else is kept.
        positional = raw_input.get("args")
        keywords = raw_input.get("kwargs")
        if (
            set(raw_input) == {"args", "kwargs"}
            and isinstance(positional, (list, tuple))
            and not positional
            and isinstance(keywords, dict)
        ):
            return keywords
        return raw_input

    def _node_text(node: Any) -> str:
        """Return up to 200 characters of text for a retrieved node ('' if none)."""
        inner = getattr(node, "node", None) or getattr(node, "text", None)
        if not inner:
            return ""
        value = getattr(inner, "text", missing)
        if value is missing:
            text = str(inner)  # e.g. ``inner`` is already the text string
        elif value is None:
            text = ""  # a node whose ``.text`` is None must not crash slicing
        else:
            text = str(value)
        return text[:200]

    # AgentChatResponse has .sources which are ToolOutput objects
    sources = getattr(response, "sources", None)
    if sources and isinstance(sources, (list, tuple)):
        for source in sources:
            tool_name = getattr(source, "tool_name", None)
            if not tool_name:
                continue
            raw_output = getattr(source, "raw_output", None)
            calls.append({
                "id": "",
                "function": str(tool_name),
                "arguments": _arguments(getattr(source, "raw_input", {})),
                "result": str(raw_output) if raw_output is not None else None,
            })

    # Retrieval-based responses expose the nodes they were grounded on.
    source_nodes = getattr(response, "source_nodes", None)
    if source_nodes and isinstance(source_nodes, (list, tuple)):
        for node in source_nodes:
            node_id = getattr(node, "node_id", "") or getattr(node, "id_", "")
            score = getattr(node, "score", None)
            text = _node_text(node)
            calls.append({
                "id": str(node_id) if node_id else "",
                "function": "_retrieve",
                "arguments": {"score": score} if score is not None else {},
                "result": text or None,
            })

    return calls or None
156
+
157
+
158
def _get_content(msg: Any) -> str | None:
    """Return the text of a LlamaIndex ChatMessage, or ``None`` if absent or empty.

    Non-string content is converted with ``str()``.
    """
    content = None if msg is None else getattr(msg, "content", None)
    if content is None:
        return None
    text = content if isinstance(content, str) else str(content)
    return text or None
@@ -0,0 +1,260 @@
1
+ """OpenAI Agents SDK adapter for pytest-agentcontract.
2
+
3
+ Records agent trajectories from the OpenAI Agents SDK (openai-agents)
4
+ by wrapping Runner.run / Runner.run_sync.
5
+
6
+ Usage:
7
+ from agentcontract.adapters.openai_agents import record_runner
8
+
9
+ recorder = Recorder(scenario="triage-agent")
10
+ with recorder.recording():
11
+ unpatch = record_runner(recorder)
12
+ result = Runner.run_sync(agent, "I need help with billing")
13
+ unpatch()
14
+ recorder.save("tests/scenarios/triage-agent.agentrun.json")
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ import functools
20
+ import time
21
+ from collections.abc import Callable
22
+ from typing import Any
23
+
24
+ from agentcontract.recorder.core import Recorder
25
+
26
+
27
def record_runner(recorder: Recorder) -> Callable[[], None]:
    """Patch the OpenAI Agents SDK ``Runner`` class to record trajectories.

    Replaces ``Runner.run()`` and ``Runner.run_sync()`` at the class level, so
    every run in the process is recorded until ``unpatch()`` is called.

    Each top-level call records its result once. If ``run_sync`` is
    implemented by calling ``cls.run(...)`` (as some SDK releases do), the
    inner ``run`` would otherwise record the same result a second time. A
    context variable marks "a wrapped call is in progress", so such an inner
    call passes straight through and the outermost call records.

    Args:
        recorder: A Recorder instance to capture the trajectory.

    Returns:
        An ``unpatch()`` function that restores the exact objects found in
        ``Runner.__dict__`` (e.g. the ``classmethod`` descriptors), not the
        bound methods returned by ``getattr``. ``Runner`` and its subclasses
        therefore behave as they did before patching.

    Raises:
        ImportError: if the ``agents`` package (openai-agents) is not installed.
    """
    try:
        from agents import Runner  # type: ignore[import-untyped]
    except ImportError as e:
        raise ImportError(
            "OpenAI Agents SDK not installed. Install with: pip install openai-agents"
        ) from e

    import contextvars

    missing = object()
    # Raw class-dict entries per patched name (``missing`` if inherited).
    saved: dict[str, Any] = {}
    # True while a wrapped Runner call is running in the current context.
    in_wrapped_call = contextvars.ContextVar("agentcontract_openai_agents_call", default=False)

    # Patch run (async)
    original_run = getattr(Runner, "run", None)
    if original_run is not None:

        @functools.wraps(original_run)
        async def recording_run(*args: Any, **kwargs: Any) -> Any:
            if in_wrapped_call.get():
                return await original_run(*args, **kwargs)
            token = in_wrapped_call.set(True)
            try:
                start = time.monotonic()
                result = await original_run(*args, **kwargs)
                latency_ms = (time.monotonic() - start) * 1000
            finally:
                in_wrapped_call.reset(token)
            _extract_from_result(result, recorder, latency_ms)
            return result

        saved["run"] = Runner.__dict__.get("run", missing)
        Runner.run = recording_run  # type: ignore[assignment]

    # Patch run_sync
    original_run_sync = getattr(Runner, "run_sync", None)
    if original_run_sync is not None:

        @functools.wraps(original_run_sync)
        def recording_run_sync(*args: Any, **kwargs: Any) -> Any:
            if in_wrapped_call.get():
                return original_run_sync(*args, **kwargs)
            token = in_wrapped_call.set(True)
            try:
                start = time.monotonic()
                result = original_run_sync(*args, **kwargs)
                latency_ms = (time.monotonic() - start) * 1000
            finally:
                in_wrapped_call.reset(token)
            _extract_from_result(result, recorder, latency_ms)
            return result

        saved["run_sync"] = Runner.__dict__.get("run_sync", missing)
        Runner.run_sync = recording_run_sync  # type: ignore[assignment]

    def unpatch() -> None:
        for name, raw in saved.items():
            if raw is missing:
                # The method was inherited: drop our override so lookup
                # falls back to the base class again.
                try:
                    delattr(Runner, name)
                except AttributeError:
                    pass
            else:
                setattr(Runner, name, raw)

    return unpatch
82
+
83
+
84
def _extract_from_result(result: Any, recorder: Recorder, latency_ms: float) -> None:
    """Record the turns of an Agents SDK ``RunResult``.

    If ``result.new_items`` is a non-empty list/tuple, every item is converted
    via ``_extract_from_items``. Otherwise a single assistant turn is recorded
    from ``result.final_output`` (converted with ``str()`` when it is not
    already a string). A ``None`` result, or one with neither, records nothing.
    """
    if result is None:
        return

    items = getattr(result, "new_items", None)
    if items and isinstance(items, (list, tuple)):
        _extract_from_items(items, recorder, latency_ms)
        return

    final_output = getattr(result, "final_output", None)
    if final_output is None:
        return

    if isinstance(final_output, str):
        text = final_output
    else:
        text = str(final_output)
    recorder.add_turn(role="assistant", content=text, latency_ms=latency_ms)
110
+
111
+
112
def _extract_from_items(
    items: list[Any] | tuple[Any, ...], recorder: Recorder, latency_ms: float
) -> None:
    """Convert Agents SDK ``RunItem`` objects into recorder turns, in order.

    Items are dispatched on their class name, so the SDK need not be imported:

    * ``MessageOutputItem``: assistant turn with the message text and any
      tool calls, carrying ``latency_ms``.
    * ``ToolCallItem``: assistant turn holding that single tool call.
    * ``ToolCallOutputItem``: tool turn with ``str(output)``.
    * ``HandoffOutputItem``: assistant turn ``"[handoff to <agent name>]"``.

    All other item types (including ``HandoffCallItem``) are ignored.
    """
    for item in items:
        item_type = type(item).__name__

        if item_type == "MessageOutputItem":
            # Agent message output
            agent_msg = getattr(item, "raw_item", None)
            content = _extract_message_content(agent_msg)
            tool_calls = _extract_message_tool_calls(agent_msg)
            if content or tool_calls:
                recorder.add_turn(
                    role="assistant",
                    content=content,
                    tool_calls=tool_calls or None,
                    latency_ms=latency_ms,
                )

        elif item_type == "ToolCallItem":
            # Individual tool call
            raw = getattr(item, "raw_item", None)
            if raw is not None:
                name = getattr(raw, "name", "") or _get_nested(raw, "function", "name", default="")
                args = _get_tool_arguments(raw)
                if name:
                    recorder.add_turn(
                        role="assistant",
                        tool_calls=[{
                            "id": str(getattr(raw, "id", "") or getattr(raw, "call_id", "") or ""),
                            "function": str(name),
                            "arguments": args,
                        }],
                    )

        elif item_type == "ToolCallOutputItem":
            # Tool result
            output = getattr(item, "output", None)
            recorder.add_turn(
                role="tool",
                content=str(output) if output is not None else None,
            )

        # Handoffs are recorded from HandoffOutputItem, which carries
        # ``target_agent``. HandoffCallItem only wraps the model's handoff tool
        # call (``raw_item``) and has no ``target_agent``, so reading the target
        # from it always produced "[handoff to unknown]".
        elif item_type == "HandoffOutputItem":
            target = getattr(item, "target_agent", None)
            target_name = getattr(target, "name", str(target)) if target else "unknown"
            recorder.add_turn(
                role="assistant",
                content=f"[handoff to {target_name}]",
            )
164
+
165
+
166
def _extract_message_content(msg: Any) -> str | None:
    """Return the text of an Agents SDK message, or ``None`` if it has none.

    A string ``content`` is returned as-is (empty -> ``None``). For a list of
    content parts, only parts whose ``type`` is ``"output_text"`` or
    ``"text"`` contribute, whether the part is a dict or an object. Their
    ``text`` values are concatenated. Any other content value is converted
    with ``str()``.
    """
    content = getattr(msg, "content", None) if msg is not None else None
    if content is None:
        return None
    if isinstance(content, str):
        return content or None
    if not isinstance(content, list):
        return str(content) or None

    text_types = ("output_text", "text")

    def _text_of(part: Any) -> str | None:
        # Parts may be plain dicts or SDK model objects with the same fields.
        if isinstance(part, dict):
            return str(part.get("text", "")) if part.get("type") in text_types else None
        if getattr(part, "type", "") in text_types:
            return str(getattr(part, "text", ""))
        return None

    pieces = [piece for piece in map(_text_of, content) if piece is not None]
    return "".join(pieces) or None
193
+
194
+
195
def _extract_message_tool_calls(msg: Any) -> list[dict[str, Any]] | None:
    """Normalize the ``tool_calls`` of an Agents SDK message.

    Each call with a name (``call.name`` or ``call.function.name``) becomes
    ``{"id", "function", "arguments"}``. The id is taken from ``id`` or else
    ``call_id``, and the arguments come from ``_get_tool_arguments``. Returns
    ``None`` when there are no named calls.
    """
    raw_calls = None if msg is None else getattr(msg, "tool_calls", None)
    if not (raw_calls and isinstance(raw_calls, (list, tuple))):
        return None

    normalized: list[dict[str, Any]] = []
    for call in raw_calls:
        fn_name = getattr(call, "name", "") or _get_nested(call, "function", "name", default="")
        if not fn_name:
            continue
        call_id = getattr(call, "id", "") or getattr(call, "call_id", "") or ""
        normalized.append({
            "id": str(call_id),
            "function": str(fn_name),
            "arguments": _get_tool_arguments(call),
        })

    return normalized or None
215
+
216
+
217
def _get_tool_arguments(tc: Any) -> dict[str, Any]:
    """Return a tool call's arguments as a dict.

    Candidates are checked in order: ``tc.args``, ``tc.arguments``,
    ``tc.input``, then ``tc.function.arguments`` (OpenAI chat format). The
    first candidate that is a dict or a non-empty string wins:

    * a dict is returned unchanged;
    * a string holding a JSON object is decoded;
    * any other string (invalid JSON, or JSON that is not an object, such as a
      free-form custom-tool ``input``) is preserved as ``{"_raw": <string>}``
      instead of being silently dropped.

    Returns ``{}`` when no candidate is present.
    """
    import json

    def _as_dict(value: Any) -> dict[str, Any] | None:
        """Convert one candidate; ``None`` means 'absent, keep looking'."""
        if isinstance(value, dict):
            return value
        if not isinstance(value, str) or not value:
            return None
        try:
            parsed = json.loads(value)
        except json.JSONDecodeError:
            return {"_raw": value}
        return parsed if isinstance(parsed, dict) else {"_raw": value}

    # Try direct args/arguments/input
    for attr in ("args", "arguments", "input"):
        found = _as_dict(getattr(tc, attr, None))
        if found is not None:
            return found

    # Try function.arguments (OpenAI chat format)
    fn = getattr(tc, "function", None)
    if fn:
        found = _as_dict(getattr(fn, "arguments", None))
        if found is not None:
            return found

    return {}
251
+
252
+
253
def _get_nested(obj: Any, *attrs: str, default: Any = None) -> Any:
    """Follow ``obj.<attrs[0]>.<attrs[1]>...``, returning ``default`` on any gap.

    A missing attribute and a ``None`` value are treated the same: either one
    ends the walk with ``default``. With no ``attrs``, ``obj`` itself is
    returned (or ``default`` if ``obj`` is ``None``).
    """
    if obj is None:
        return default
    if not attrs:
        return obj
    first, *remaining = attrs
    return _get_nested(getattr(obj, first, None), *remaining, default=default)
@@ -0,0 +1,5 @@
1
+ """Assertion engine for validating agent trajectories."""
2
+
3
+ from agentcontract.assertions.engine import AssertionEngine
4
+
5
+ __all__ = ["AssertionEngine"]