pytest-agentcontract 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- agentcontract/__init__.py +25 -0
- agentcontract/adapters/__init__.py +36 -0
- agentcontract/adapters/langgraph.py +157 -0
- agentcontract/adapters/llamaindex.py +167 -0
- agentcontract/adapters/openai_agents.py +260 -0
- agentcontract/assertions/__init__.py +5 -0
- agentcontract/assertions/engine.py +360 -0
- agentcontract/cli.py +127 -0
- agentcontract/config.py +255 -0
- agentcontract/plugin.py +177 -0
- agentcontract/recorder/__init__.py +5 -0
- agentcontract/recorder/core.py +193 -0
- agentcontract/recorder/interceptors.py +236 -0
- agentcontract/replay/__init__.py +5 -0
- agentcontract/replay/engine.py +203 -0
- agentcontract/serialization.py +299 -0
- agentcontract/types.py +117 -0
- pytest_agentcontract-0.1.0.dist-info/METADATA +281 -0
- pytest_agentcontract-0.1.0.dist-info/RECORD +22 -0
- pytest_agentcontract-0.1.0.dist-info/WHEEL +4 -0
- pytest_agentcontract-0.1.0.dist-info/entry_points.txt +5 -0
- pytest_agentcontract-0.1.0.dist-info/licenses/LICENSE +21 -0
|
@@ -0,0 +1,25 @@
|
|
|
1
|
+
"""pytest-agentcontract: Deterministic CI tests for LLM agent trajectories."""
|
|
2
|
+
|
|
3
|
+
# Version string of the agentcontract package.
__version__ = "0.1.0"
|
|
4
|
+
|
|
5
|
+
# Lazy imports to avoid circular dependencies and speed up pytest plugin loading.
|
|
6
|
+
# The plugin.py entry point imports specific modules directly.
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def __getattr__(name: str):  # noqa: ANN001
    """Resolve one of the package's public classes on first access.

    The submodule that defines the requested class is imported only when the
    attribute is actually looked up, so importing the package itself does not
    pull in the recorder, replay, assertion or config modules.

    Args:
        name: Attribute being looked up on the package.

    Returns:
        The attribute called ``name`` taken from the submodule that defines it.

    Raises:
        AttributeError: If ``name`` is not one of the lazily exported classes.
    """
    module_by_export = {
        "Recorder": "agentcontract.recorder.core",
        "ReplayEngine": "agentcontract.replay.engine",
        "AssertionEngine": "agentcontract.assertions.engine",
        "AgentContractConfig": "agentcontract.config",
    }

    module_path = module_by_export.get(name)
    if module_path is None:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    # Deferred so that the import machinery is only touched for known names.
    from importlib import import_module

    return getattr(import_module(module_path), name)
|
|
23
|
+
|
|
24
|
+
|
|
25
|
+
# Public names of the package; each one is resolved on demand by __getattr__.
__all__ = ["Recorder", "ReplayEngine", "AssertionEngine", "AgentContractConfig"]
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
"""Framework adapters for popular agent libraries.
|
|
2
|
+
|
|
3
|
+
Each adapter wraps a framework's execution method to record trajectories
|
|
4
|
+
into the agentcontract format. All adapters follow the same pattern:
|
|
5
|
+
|
|
6
|
+
unpatch = record_<thing>(target, recorder)
|
|
7
|
+
# ... run your agent ...
|
|
8
|
+
unpatch()
|
|
9
|
+
|
|
10
|
+
Available adapters:
|
|
11
|
+
- langgraph: LangGraph CompiledGraph (invoke/ainvoke)
|
|
12
|
+
- llamaindex: LlamaIndex AgentRunner (chat/query)
|
|
13
|
+
- openai_agents: OpenAI Agents SDK Runner (run/run_sync)
|
|
14
|
+
"""
|
|
15
|
+
|
|
16
|
+
from __future__ import annotations
|
|
17
|
+
|
|
18
|
+
from typing import Any
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def __getattr__(name: str) -> Any:
    """Resolve an adapter entry point on first access.

    Each adapter module is imported only when its ``record_*`` function is
    requested, so importing this package does not import the adapter modules.

    Args:
        name: Attribute being looked up on the adapters package.

    Returns:
        The function called ``name`` from the adapter module that defines it.

    Raises:
        AttributeError: If ``name`` is not a known adapter entry point.
    """
    adapter_module_for = {
        "record_graph": "agentcontract.adapters.langgraph",
        "record_agent": "agentcontract.adapters.llamaindex",
        "record_runner": "agentcontract.adapters.openai_agents",
    }

    module_path = adapter_module_for.get(name)
    if module_path is None:
        raise AttributeError(f"module {__name__!r} has no attribute {name!r}")

    # Deferred so that the import machinery is only touched for known names.
    from importlib import import_module

    return getattr(import_module(module_path), name)
|
|
34
|
+
|
|
35
|
+
|
|
36
|
+
# Public adapter entry points; each one is resolved on demand by __getattr__.
__all__ = ["record_graph", "record_agent", "record_runner"]
|
|
@@ -0,0 +1,157 @@
|
|
|
1
|
+
"""LangGraph adapter for pytest-agentcontract.
|
|
2
|
+
|
|
3
|
+
Records agent trajectories from LangGraph graph executions by wrapping
|
|
4
|
+
the graph's invoke/ainvoke methods.
|
|
5
|
+
|
|
6
|
+
Usage:
|
|
7
|
+
from agentcontract.adapters.langgraph import record_graph
|
|
8
|
+
|
|
9
|
+
recorder = Recorder(scenario="customer-support")
|
|
10
|
+
with recorder.recording():
|
|
11
|
+
unpatch = record_graph(graph, recorder)
|
|
12
|
+
result = graph.invoke({"messages": [("user", "I need a refund")]})
|
|
13
|
+
unpatch()
|
|
14
|
+
recorder.save("tests/scenarios/customer-support.agentrun.json")
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
import functools
|
|
20
|
+
import time
|
|
21
|
+
from collections.abc import Callable
|
|
22
|
+
from typing import Any
|
|
23
|
+
|
|
24
|
+
from agentcontract.recorder.core import Recorder
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def record_graph(graph: Any, recorder: Recorder) -> Callable[[], None]:
    """Patch a graph object so that its invocations are recorded.

    ``invoke`` and ``ainvoke`` are each replaced, when present on ``graph``,
    by a wrapper that times the call, passes the result to
    ``_extract_turns`` and then returns the result unchanged.

    Args:
        graph: Object exposing ``invoke`` and/or ``ainvoke``; the module docs
            describe it as a LangGraph CompiledGraph.
        recorder: Recorder that receives the extracted turns.

    Returns:
        A zero-argument callable that assigns the original ``invoke`` /
        ``ainvoke`` back onto ``graph``.
    """
    # Capture both originals up front, before anything is replaced.
    saved: dict[str, Any] = {}
    for attr in ("invoke", "ainvoke"):
        bound = getattr(graph, attr, None)
        if bound is not None:
            saved[attr] = bound

    if "invoke" in saved:
        call_sync = saved["invoke"]

        @functools.wraps(call_sync)
        def recording_invoke(*args: Any, **kwargs: Any) -> Any:
            began = time.monotonic()
            outcome = call_sync(*args, **kwargs)
            elapsed_ms = (time.monotonic() - began) * 1000
            _extract_turns(outcome, recorder, elapsed_ms)
            return outcome

        graph.invoke = recording_invoke

    if "ainvoke" in saved:
        call_async = saved["ainvoke"]

        @functools.wraps(call_async)
        async def recording_ainvoke(*args: Any, **kwargs: Any) -> Any:
            began = time.monotonic()
            outcome = await call_async(*args, **kwargs)
            elapsed_ms = (time.monotonic() - began) * 1000
            _extract_turns(outcome, recorder, elapsed_ms)
            return outcome

        graph.ainvoke = recording_ainvoke

    def unpatch() -> None:
        # Only attributes that existed when patching are written back.
        for attr, bound in saved.items():
            setattr(graph, attr, bound)

    return unpatch
|
|
73
|
+
|
|
74
|
+
|
|
75
|
+
def _extract_turns(result: Any, recorder: Recorder, latency_ms: float) -> None:
|
|
76
|
+
"""Extract turns from a LangGraph result dict.
|
|
77
|
+
|
|
78
|
+
LangGraph returns a state dict. The standard ``messages`` key contains
|
|
79
|
+
the conversation as a list of LangChain message objects.
|
|
80
|
+
"""
|
|
81
|
+
if not isinstance(result, dict):
|
|
82
|
+
return
|
|
83
|
+
|
|
84
|
+
messages = result.get("messages", [])
|
|
85
|
+
if not isinstance(messages, (list, tuple)):
|
|
86
|
+
return
|
|
87
|
+
|
|
88
|
+
for msg in messages:
|
|
89
|
+
role = _get_role(msg)
|
|
90
|
+
content = _get_content(msg)
|
|
91
|
+
tool_calls = _get_tool_calls(msg)
|
|
92
|
+
|
|
93
|
+
if role in ("user", "assistant", "system", "tool"):
|
|
94
|
+
recorder.add_turn(
|
|
95
|
+
role=role,
|
|
96
|
+
content=content,
|
|
97
|
+
tool_calls=tool_calls or None,
|
|
98
|
+
latency_ms=latency_ms if role == "assistant" else None,
|
|
99
|
+
)
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def _get_role(msg: Any) -> str:
|
|
103
|
+
"""Extract role from a LangChain message object or dict."""
|
|
104
|
+
if isinstance(msg, dict):
|
|
105
|
+
return str(msg.get("role", msg.get("type", "")))
|
|
106
|
+
|
|
107
|
+
# LangChain message classes: HumanMessage, AIMessage, SystemMessage, ToolMessage
|
|
108
|
+
type_attr = getattr(msg, "type", None)
|
|
109
|
+
if type_attr:
|
|
110
|
+
role_map = {"human": "user", "ai": "assistant", "system": "system", "tool": "tool"}
|
|
111
|
+
return role_map.get(str(type_attr), str(type_attr))
|
|
112
|
+
|
|
113
|
+
return ""
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def _get_content(msg: Any) -> str | None:
|
|
117
|
+
"""Extract text content from a message."""
|
|
118
|
+
content = msg.get("content") if isinstance(msg, dict) else getattr(msg, "content", None)
|
|
119
|
+
|
|
120
|
+
if content is None:
|
|
121
|
+
return None
|
|
122
|
+
if isinstance(content, str):
|
|
123
|
+
return content or None
|
|
124
|
+
# LangChain can return list of content blocks
|
|
125
|
+
if isinstance(content, list):
|
|
126
|
+
texts = [str(b.get("text", "")) if isinstance(b, dict) else str(b) for b in content]
|
|
127
|
+
joined = "".join(texts)
|
|
128
|
+
return joined or None
|
|
129
|
+
return str(content) or None
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def _get_tool_calls(msg: Any) -> list[dict[str, Any]] | None:
|
|
133
|
+
"""Extract tool calls from a LangChain AIMessage."""
|
|
134
|
+
raw = msg.get("tool_calls", []) if isinstance(msg, dict) else getattr(msg, "tool_calls", None)
|
|
135
|
+
|
|
136
|
+
if not raw or not isinstance(raw, (list, tuple)):
|
|
137
|
+
return None
|
|
138
|
+
|
|
139
|
+
calls: list[dict[str, Any]] = []
|
|
140
|
+
for tc in raw:
|
|
141
|
+
if isinstance(tc, dict):
|
|
142
|
+
name = tc.get("name", "")
|
|
143
|
+
args = tc.get("args", tc.get("arguments", {}))
|
|
144
|
+
call_id = tc.get("id", "")
|
|
145
|
+
else:
|
|
146
|
+
name = getattr(tc, "name", "")
|
|
147
|
+
args = getattr(tc, "args", getattr(tc, "arguments", {}))
|
|
148
|
+
call_id = getattr(tc, "id", "")
|
|
149
|
+
|
|
150
|
+
if name:
|
|
151
|
+
calls.append({
|
|
152
|
+
"id": str(call_id) if call_id else "",
|
|
153
|
+
"function": str(name),
|
|
154
|
+
"arguments": args if isinstance(args, dict) else {},
|
|
155
|
+
})
|
|
156
|
+
|
|
157
|
+
return calls or None
|
|
@@ -0,0 +1,167 @@
|
|
|
1
|
+
"""LlamaIndex adapter for pytest-agentcontract.
|
|
2
|
+
|
|
3
|
+
Records agent trajectories from LlamaIndex AgentRunner / ReActAgent
|
|
4
|
+
by wrapping the chat/query methods.
|
|
5
|
+
|
|
6
|
+
Usage:
|
|
7
|
+
from agentcontract.adapters.llamaindex import record_agent
|
|
8
|
+
|
|
9
|
+
recorder = Recorder(scenario="rag-qa")
|
|
10
|
+
with recorder.recording():
|
|
11
|
+
unpatch = record_agent(agent, recorder)
|
|
12
|
+
response = agent.chat("What's the refund policy?")
|
|
13
|
+
unpatch()
|
|
14
|
+
recorder.save("tests/scenarios/rag-qa.agentrun.json")
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
import functools
|
|
20
|
+
import time
|
|
21
|
+
from collections.abc import Callable
|
|
22
|
+
from typing import Any
|
|
23
|
+
|
|
24
|
+
from agentcontract.recorder.core import Recorder
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def record_agent(agent: Any, recorder: Recorder) -> Callable[[], None]:
    """Patch an agent object so that its chat/query calls are recorded.

    Each of ``chat``, ``achat``, ``query`` and ``aquery`` that exists on
    ``agent`` is replaced by a wrapper which times the call, passes the
    result to ``_extract_from_response`` and returns the result unchanged.
    Names starting with ``a`` get a coroutine wrapper, the others a plain one.

    Args:
        agent: Object exposing some of the four methods; the module docs
            describe it as a LlamaIndex AgentRunner / ReActAgent.
        recorder: Recorder that receives the extracted turn.

    Returns:
        A zero-argument callable that assigns the original methods back onto
        ``agent``.
    """
    originals: dict[str, Any] = {}

    for method_name in ("chat", "achat", "query", "aquery"):
        original = getattr(agent, method_name, None)
        if original is None:
            continue
        originals[method_name] = original

        # ``_orig`` / ``_name`` are bound as keyword defaults so each wrapper
        # keeps the values of its own loop iteration. ``_name`` is bound
        # alongside ``_orig`` but is not read inside the wrappers.
        if method_name.startswith("a"):

            @functools.wraps(original)
            async def async_wrapper(
                *args: Any,
                _orig: Any = original,
                _name: str = method_name,
                **kwargs: Any,
            ) -> Any:
                began = time.monotonic()
                outcome = await _orig(*args, **kwargs)
                elapsed_ms = (time.monotonic() - began) * 1000
                _extract_from_response(outcome, recorder, elapsed_ms, agent)
                return outcome

            replacement = async_wrapper
        else:

            @functools.wraps(original)
            def sync_wrapper(
                *args: Any,
                _orig: Any = original,
                _name: str = method_name,
                **kwargs: Any,
            ) -> Any:
                began = time.monotonic()
                outcome = _orig(*args, **kwargs)
                elapsed_ms = (time.monotonic() - began) * 1000
                _extract_from_response(outcome, recorder, elapsed_ms, agent)
                return outcome

            replacement = sync_wrapper

        setattr(agent, method_name, replacement)

    def unpatch() -> None:
        # Only methods that existed when patching are written back.
        for patched_name, saved in originals.items():
            setattr(agent, patched_name, saved)

    return unpatch
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
def _extract_from_response(
    response: Any, recorder: Recorder, latency_ms: float, agent: Any
) -> None:
    """Record one assistant turn built from an agent response.

    The text is taken from ``response.response`` when that attribute exists
    (stringified, or ``None`` if falsy), otherwise from ``response.message``
    via ``_get_content``. Tool calls come from ``_extract_tool_calls``.
    Nothing is recorded when both text and tool calls are empty.

    Args:
        response: Object returned by the wrapped chat/query method.
        recorder: Recorder whose ``add_turn`` receives the turn.
        latency_ms: Duration of the wrapped call in milliseconds.
        agent: The patched agent. It is accepted but not consulted here.
    """
    if hasattr(response, "response"):
        raw_text = response.response
        response_text = str(raw_text) if raw_text else None
    elif hasattr(response, "message"):
        response_text = _get_content(response.message)
    else:
        response_text = None

    # Tool outputs and retrieved nodes are both reported as tool calls.
    tool_calls = _extract_tool_calls(response)

    if not (response_text or tool_calls):
        return

    recorder.add_turn(
        role="assistant",
        content=response_text,
        tool_calls=tool_calls or None,
        latency_ms=latency_ms,
    )
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
def _extract_tool_calls(response: Any) -> list[dict[str, Any]] | None:
|
|
117
|
+
"""Extract tool calls from LlamaIndex response sources."""
|
|
118
|
+
calls: list[dict[str, Any]] = []
|
|
119
|
+
|
|
120
|
+
# AgentChatResponse has .sources which are ToolOutput objects
|
|
121
|
+
sources = getattr(response, "sources", None)
|
|
122
|
+
if sources and isinstance(sources, (list, tuple)):
|
|
123
|
+
for source in sources:
|
|
124
|
+
tool_name = getattr(source, "tool_name", None)
|
|
125
|
+
if not tool_name:
|
|
126
|
+
continue
|
|
127
|
+
raw_input = getattr(source, "raw_input", {})
|
|
128
|
+
raw_output = getattr(source, "raw_output", None)
|
|
129
|
+
|
|
130
|
+
calls.append({
|
|
131
|
+
"id": "",
|
|
132
|
+
"function": str(tool_name),
|
|
133
|
+
"arguments": raw_input if isinstance(raw_input, dict) else {},
|
|
134
|
+
"result": str(raw_output) if raw_output is not None else None,
|
|
135
|
+
})
|
|
136
|
+
|
|
137
|
+
# Also check source_nodes for retrieval-based responses
|
|
138
|
+
source_nodes = getattr(response, "source_nodes", None)
|
|
139
|
+
if source_nodes and isinstance(source_nodes, (list, tuple)):
|
|
140
|
+
for node in source_nodes:
|
|
141
|
+
node_id = getattr(node, "node_id", "") or getattr(node, "id_", "")
|
|
142
|
+
score = getattr(node, "score", None)
|
|
143
|
+
text = ""
|
|
144
|
+
inner = getattr(node, "node", None) or getattr(node, "text", None)
|
|
145
|
+
if inner:
|
|
146
|
+
text = getattr(inner, "text", str(inner))[:200]
|
|
147
|
+
|
|
148
|
+
calls.append({
|
|
149
|
+
"id": str(node_id) if node_id else "",
|
|
150
|
+
"function": "_retrieve",
|
|
151
|
+
"arguments": {"score": score} if score is not None else {},
|
|
152
|
+
"result": text or None,
|
|
153
|
+
})
|
|
154
|
+
|
|
155
|
+
return calls or None
|
|
156
|
+
|
|
157
|
+
|
|
158
|
+
def _get_content(msg: Any) -> str | None:
|
|
159
|
+
"""Extract text from a LlamaIndex ChatMessage."""
|
|
160
|
+
if msg is None:
|
|
161
|
+
return None
|
|
162
|
+
content = getattr(msg, "content", None)
|
|
163
|
+
if content is None:
|
|
164
|
+
return None
|
|
165
|
+
if isinstance(content, str):
|
|
166
|
+
return content or None
|
|
167
|
+
return str(content) or None
|
|
@@ -0,0 +1,260 @@
|
|
|
1
|
+
"""OpenAI Agents SDK adapter for pytest-agentcontract.
|
|
2
|
+
|
|
3
|
+
Records agent trajectories from the OpenAI Agents SDK (openai-agents)
|
|
4
|
+
by wrapping Runner.run / Runner.run_sync.
|
|
5
|
+
|
|
6
|
+
Usage:
|
|
7
|
+
from agentcontract.adapters.openai_agents import record_runner
|
|
8
|
+
|
|
9
|
+
recorder = Recorder(scenario="triage-agent")
|
|
10
|
+
with recorder.recording():
|
|
11
|
+
unpatch = record_runner(recorder)
|
|
12
|
+
result = Runner.run_sync(agent, "I need help with billing")
|
|
13
|
+
unpatch()
|
|
14
|
+
recorder.save("tests/scenarios/triage-agent.agentrun.json")
|
|
15
|
+
"""
|
|
16
|
+
|
|
17
|
+
from __future__ import annotations
|
|
18
|
+
|
|
19
|
+
import functools
|
|
20
|
+
import time
|
|
21
|
+
from collections.abc import Callable
|
|
22
|
+
from typing import Any
|
|
23
|
+
|
|
24
|
+
from agentcontract.recorder.core import Recorder
|
|
25
|
+
|
|
26
|
+
|
|
27
|
+
def record_runner(recorder: Recorder) -> Callable[[], None]:
    """Patch ``agents.Runner`` at class level so runs are recorded.

    ``Runner.run`` (awaited) and ``Runner.run_sync`` are each replaced, when
    present, by a wrapper that times the call, passes the result to
    ``_extract_from_result`` and returns the result unchanged.

    The returned ``unpatch`` restores the entries exactly as they were stored
    in the class ``__dict__``. ``getattr(Runner, name)`` on a classmethod or
    staticmethod yields an already-bound callable; writing that back would
    replace the descriptor and pin ``cls`` for every subclass, so the raw
    class-dict entry is saved instead. A method that was only inherited is
    restored by deleting the override.

    NOTE(review): if the installed SDK implements ``run_sync`` by delegating
    to ``Runner.run``, a single ``run_sync`` call may be recorded by both
    wrappers - confirm against the SDK version in use.

    Args:
        recorder: Recorder that receives the extracted turns.

    Returns:
        A zero-argument callable that undoes the patching; safe to call more
        than once.

    Raises:
        ImportError: If the ``agents`` package cannot be imported.
    """
    try:
        from agents import Runner  # type: ignore[import-untyped]
    except ImportError as e:
        raise ImportError(
            "OpenAI Agents SDK not installed. Install with: pip install openai-agents"
        ) from e

    # Marks "attribute resolved through inheritance, not defined on Runner".
    not_defined = object()
    saved: dict[str, Any] = {}

    # Patch run (async)
    original_run = getattr(Runner, "run", None)
    if original_run is not None:
        saved["run"] = vars(Runner).get("run", not_defined)

        @functools.wraps(original_run)
        async def recording_run(*args: Any, **kwargs: Any) -> Any:
            start = time.monotonic()
            result = await original_run(*args, **kwargs)
            latency_ms = (time.monotonic() - start) * 1000
            _extract_from_result(result, recorder, latency_ms)
            return result

        Runner.run = recording_run  # type: ignore[assignment]

    # Patch run_sync
    original_run_sync = getattr(Runner, "run_sync", None)
    if original_run_sync is not None:
        saved["run_sync"] = vars(Runner).get("run_sync", not_defined)

        @functools.wraps(original_run_sync)
        def recording_run_sync(*args: Any, **kwargs: Any) -> Any:
            start = time.monotonic()
            result = original_run_sync(*args, **kwargs)
            latency_ms = (time.monotonic() - start) * 1000
            _extract_from_result(result, recorder, latency_ms)
            return result

        Runner.run_sync = recording_run_sync  # type: ignore[assignment]

    def unpatch() -> None:
        for name, entry in saved.items():
            if entry is not_defined:
                # Inherited method: dropping our override re-exposes it.
                if name in vars(Runner):
                    delattr(Runner, name)
            else:
                setattr(Runner, name, entry)

    return unpatch
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
def _extract_from_result(result: Any, recorder: Recorder, latency_ms: float) -> None:
|
|
85
|
+
"""Extract turns from a RunResult.
|
|
86
|
+
|
|
87
|
+
The OpenAI Agents SDK RunResult contains:
|
|
88
|
+
- result.final_output: the agent's final response
|
|
89
|
+
- result.new_items: list of RunItem objects (messages, tool calls, handoffs)
|
|
90
|
+
- result.raw_responses: list of ModelResponse objects
|
|
91
|
+
"""
|
|
92
|
+
if result is None:
|
|
93
|
+
return
|
|
94
|
+
|
|
95
|
+
# Extract from new_items (preferred -- gives full trajectory)
|
|
96
|
+
new_items = getattr(result, "new_items", None)
|
|
97
|
+
if new_items and isinstance(new_items, (list, tuple)):
|
|
98
|
+
_extract_from_items(new_items, recorder, latency_ms)
|
|
99
|
+
return
|
|
100
|
+
|
|
101
|
+
# Fallback: just record the final output
|
|
102
|
+
final = getattr(result, "final_output", None)
|
|
103
|
+
if final is not None:
|
|
104
|
+
content = str(final) if not isinstance(final, str) else final
|
|
105
|
+
recorder.add_turn(
|
|
106
|
+
role="assistant",
|
|
107
|
+
content=content,
|
|
108
|
+
latency_ms=latency_ms,
|
|
109
|
+
)
|
|
110
|
+
|
|
111
|
+
|
|
112
|
+
def _extract_from_items(
|
|
113
|
+
items: list[Any] | tuple[Any, ...], recorder: Recorder, latency_ms: float
|
|
114
|
+
) -> None:
|
|
115
|
+
"""Extract turns from RunItem list."""
|
|
116
|
+
for item in items:
|
|
117
|
+
item_type = type(item).__name__
|
|
118
|
+
|
|
119
|
+
if item_type == "MessageOutputItem":
|
|
120
|
+
# Agent message output
|
|
121
|
+
agent_msg = getattr(item, "raw_item", None)
|
|
122
|
+
content = _extract_message_content(agent_msg)
|
|
123
|
+
tool_calls = _extract_message_tool_calls(agent_msg)
|
|
124
|
+
if content or tool_calls:
|
|
125
|
+
recorder.add_turn(
|
|
126
|
+
role="assistant",
|
|
127
|
+
content=content,
|
|
128
|
+
tool_calls=tool_calls or None,
|
|
129
|
+
latency_ms=latency_ms,
|
|
130
|
+
)
|
|
131
|
+
|
|
132
|
+
elif item_type == "ToolCallItem":
|
|
133
|
+
# Individual tool call
|
|
134
|
+
raw = getattr(item, "raw_item", None)
|
|
135
|
+
if raw is not None:
|
|
136
|
+
name = getattr(raw, "name", "") or _get_nested(raw, "function", "name", default="")
|
|
137
|
+
args = _get_tool_arguments(raw)
|
|
138
|
+
if name:
|
|
139
|
+
recorder.add_turn(
|
|
140
|
+
role="assistant",
|
|
141
|
+
tool_calls=[{
|
|
142
|
+
"id": str(getattr(raw, "id", "") or getattr(raw, "call_id", "") or ""),
|
|
143
|
+
"function": str(name),
|
|
144
|
+
"arguments": args,
|
|
145
|
+
}],
|
|
146
|
+
)
|
|
147
|
+
|
|
148
|
+
elif item_type == "ToolCallOutputItem":
|
|
149
|
+
# Tool result
|
|
150
|
+
output = getattr(item, "output", None)
|
|
151
|
+
recorder.add_turn(
|
|
152
|
+
role="tool",
|
|
153
|
+
content=str(output) if output is not None else None,
|
|
154
|
+
)
|
|
155
|
+
|
|
156
|
+
elif item_type == "HandoffCallItem":
|
|
157
|
+
# Agent handoff
|
|
158
|
+
target = getattr(item, "target_agent", None)
|
|
159
|
+
target_name = getattr(target, "name", str(target)) if target else "unknown"
|
|
160
|
+
recorder.add_turn(
|
|
161
|
+
role="assistant",
|
|
162
|
+
content=f"[handoff to {target_name}]",
|
|
163
|
+
)
|
|
164
|
+
|
|
165
|
+
|
|
166
|
+
def _extract_message_content(msg: Any) -> str | None:
|
|
167
|
+
"""Extract text content from an Agents SDK message."""
|
|
168
|
+
if msg is None:
|
|
169
|
+
return None
|
|
170
|
+
|
|
171
|
+
# Check for content list (OpenAI format)
|
|
172
|
+
content = getattr(msg, "content", None)
|
|
173
|
+
if content is None:
|
|
174
|
+
return None
|
|
175
|
+
|
|
176
|
+
if isinstance(content, str):
|
|
177
|
+
return content or None
|
|
178
|
+
|
|
179
|
+
if isinstance(content, list):
|
|
180
|
+
texts = []
|
|
181
|
+
for block in content:
|
|
182
|
+
if isinstance(block, dict):
|
|
183
|
+
if block.get("type") in ("output_text", "text"):
|
|
184
|
+
texts.append(str(block.get("text", "")))
|
|
185
|
+
else:
|
|
186
|
+
block_type = getattr(block, "type", "")
|
|
187
|
+
if block_type in ("output_text", "text"):
|
|
188
|
+
texts.append(str(getattr(block, "text", "")))
|
|
189
|
+
joined = "".join(texts)
|
|
190
|
+
return joined or None
|
|
191
|
+
|
|
192
|
+
return str(content) or None
|
|
193
|
+
|
|
194
|
+
|
|
195
|
+
def _extract_message_tool_calls(msg: Any) -> list[dict[str, Any]] | None:
|
|
196
|
+
"""Extract tool calls from an Agents SDK message."""
|
|
197
|
+
if msg is None:
|
|
198
|
+
return None
|
|
199
|
+
|
|
200
|
+
raw_calls = getattr(msg, "tool_calls", None)
|
|
201
|
+
if not raw_calls or not isinstance(raw_calls, (list, tuple)):
|
|
202
|
+
return None
|
|
203
|
+
|
|
204
|
+
calls: list[dict[str, Any]] = []
|
|
205
|
+
for tc in raw_calls:
|
|
206
|
+
name = getattr(tc, "name", "") or _get_nested(tc, "function", "name", default="")
|
|
207
|
+
if name:
|
|
208
|
+
calls.append({
|
|
209
|
+
"id": str(getattr(tc, "id", "") or getattr(tc, "call_id", "") or ""),
|
|
210
|
+
"function": str(name),
|
|
211
|
+
"arguments": _get_tool_arguments(tc),
|
|
212
|
+
})
|
|
213
|
+
|
|
214
|
+
return calls or None
|
|
215
|
+
|
|
216
|
+
|
|
217
|
+
def _get_tool_arguments(tc: Any) -> dict[str, Any]:
|
|
218
|
+
"""Extract arguments dict from a tool call object."""
|
|
219
|
+
# Try direct args/arguments
|
|
220
|
+
for attr in ("args", "arguments", "input"):
|
|
221
|
+
val = getattr(tc, attr, None)
|
|
222
|
+
if isinstance(val, dict):
|
|
223
|
+
return val
|
|
224
|
+
if isinstance(val, str) and val:
|
|
225
|
+
import json
|
|
226
|
+
|
|
227
|
+
try:
|
|
228
|
+
parsed = json.loads(val)
|
|
229
|
+
if isinstance(parsed, dict):
|
|
230
|
+
return parsed
|
|
231
|
+
except (json.JSONDecodeError, TypeError):
|
|
232
|
+
return {"_raw": val}
|
|
233
|
+
|
|
234
|
+
# Try function.arguments (OpenAI chat format)
|
|
235
|
+
fn = getattr(tc, "function", None)
|
|
236
|
+
if fn:
|
|
237
|
+
fn_args = getattr(fn, "arguments", None)
|
|
238
|
+
if isinstance(fn_args, dict):
|
|
239
|
+
return fn_args
|
|
240
|
+
if isinstance(fn_args, str) and fn_args:
|
|
241
|
+
import json
|
|
242
|
+
|
|
243
|
+
try:
|
|
244
|
+
parsed = json.loads(fn_args)
|
|
245
|
+
if isinstance(parsed, dict):
|
|
246
|
+
return parsed
|
|
247
|
+
except (json.JSONDecodeError, TypeError):
|
|
248
|
+
pass
|
|
249
|
+
|
|
250
|
+
return {}
|
|
251
|
+
|
|
252
|
+
|
|
253
|
+
def _get_nested(obj: Any, *attrs: str, default: Any = None) -> Any:
|
|
254
|
+
"""Safely traverse nested attributes."""
|
|
255
|
+
current = obj
|
|
256
|
+
for attr in attrs:
|
|
257
|
+
if current is None:
|
|
258
|
+
return default
|
|
259
|
+
current = getattr(current, attr, None)
|
|
260
|
+
return current if current is not None else default
|