understudy 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
understudy/__init__.py ADDED
@@ -0,0 +1,67 @@
1
+ """understudy: simulation and trace-based evaluation for agentic systems.
2
+
3
+ The simulated user is an understudy standing in for a real user.
4
+ You write scenes, run rehearsals, and check the performance —
5
+ not by reading the script, but by inspecting what actually happened.
6
+ """
7
+
8
+ from .check import CheckItem, CheckResult, check
9
+ from .judges import Judge, JudgeResult
10
+ from .mocks import MockToolkit, ToolError
11
+ from .models import Expectations, Persona, PersonaPreset, Scene
12
+ from .prompts import (
13
+ ADVERSARIAL_ROBUSTNESS,
14
+ FACTUAL_GROUNDING,
15
+ INSTRUCTION_FOLLOWING,
16
+ POLICY_COMPLIANCE,
17
+ TASK_COMPLETION,
18
+ TONE_EMPATHY,
19
+ TOOL_USAGE_CORRECTNESS,
20
+ )
21
+ from .runner import AgentApp, AgentResponse, run
22
+ from .storage import RunStorage
23
+ from .suite import SceneResult, Suite, SuiteResults
24
+ from .trace import AgentTransfer, ToolCall, Trace, Turn
25
+
26
# Version of the understudy distribution this package was built from.
__version__ = "0.1.0"

# Names re-exported as the public API, grouped by the submodule that
# defines them. The order mirrors the import block above.
__all__ = [
    # models
    "Scene",
    "Persona",
    "PersonaPreset",
    "Expectations",
    # trace
    "Trace",
    "Turn",
    "ToolCall",
    "AgentTransfer",
    # runner
    "run",
    "AgentApp",
    "AgentResponse",
    # check
    "check",
    "CheckResult",
    "CheckItem",
    # suite
    "Suite",
    "SuiteResults",
    "SceneResult",
    # storage
    "RunStorage",
    # judges
    "Judge",
    "JudgeResult",
    # mocks
    "MockToolkit",
    "ToolError",
    # rubrics
    "TOOL_USAGE_CORRECTNESS",
    "POLICY_COMPLIANCE",
    "TONE_EMPATHY",
    "ADVERSARIAL_ROBUSTNESS",
    "TASK_COMPLETION",
    "FACTUAL_GROUNDING",
    "INSTRUCTION_FOLLOWING",
]
@@ -0,0 +1,194 @@
1
+ """ADK adapter: wraps Google ADK agents for use with understudy."""
2
+
3
+ from datetime import UTC, datetime
4
+ from typing import Any
5
+
6
+ from ..mocks import MockToolkit
7
+ from ..runner import AgentApp, AgentResponse
8
+ from ..trace import AgentTransfer, ToolCall
9
+
10
+
11
def _create_mock_callback(mocks: MockToolkit | None):
    """Build a before_tool_callback that answers tool calls from mocks.

    Args:
        mocks: MockToolkit instance holding the mock handlers, or None.

    Returns:
        A callable taking ``(tool, args, tool_context)``, intended for use as
        google-adk's before_tool_callback. It returns the mock's result (or an
        ``{"error": ...}`` dict if the mock raised) when a handler exists for
        the tool, and None otherwise so the real tool is allowed to run.
    """

    def callback(tool, args: dict[str, Any], tool_context) -> dict | None:
        # Without a toolkit there is nothing to intercept.
        if mocks is None:
            return None

        # Prefer the tool's ``name``, then ``__name__``, then its str() form.
        name = getattr(tool, "name", None) or getattr(tool, "__name__", str(tool))
        if not mocks.get_handler(name):
            return None

        try:
            return mocks.call(name, **args)
        except Exception as exc:
            # A failing mock is reported to the agent as an error payload
            # rather than propagated out of the callback.
            return {"error": str(exc)}

    return callback
36
+
37
+
38
class ADKApp(AgentApp):
    """Wraps a Google ADK Agent for use with understudy.

    Usage:
        from google.adk import Agent
        from understudy.adk import ADKApp

        agent = Agent(model="gemini-2.5-flash", name="my_agent", ...)
        app = ADKApp(agent=agent)
        trace = run(app, scene)
    """

    def __init__(self, agent: Any, session_id: str | None = None):
        """
        Args:
            agent: A google.adk.Agent instance.
            session_id: Optional session ID. If None, a random one is generated
                each time start() is called.
        """
        self.agent = agent
        self.session_id = session_id
        # Effective session id for the current run; assigned in start().
        self._session_id: str | None = None
        self._runner = None
        self._session = None
        self._mocks: MockToolkit | None = None
        self._current_agent: str | None = None
        self._agent_transfers: list[AgentTransfer] = []
        # Bookkeeping so the agent's own before_tool_callback can be put back
        # once the mock callback is no longer needed.
        self._mock_callback_installed = False
        self._saved_tool_callback: Any = None

    def start(self, mocks: MockToolkit | None = None) -> None:
        """Initialize the ADK session.

        Args:
            mocks: Optional toolkit of mock tool handlers. When truthy, a
                before_tool_callback that serves the mocks is installed on
                the agent until stop() (or the next start()) is called.

        Raises:
            ImportError: If the google-adk package is not installed.
        """
        try:
            from google.adk import Runner
            from google.adk.sessions import InMemorySessionService
        except ImportError as e:
            raise ImportError(
                "google-adk package required. Install with: pip install understudy[adk]"
            ) from e
        import uuid

        # If a previous run was never stopped, undo its callback first so the
        # value saved below is the agent's own callback, not an old mock.
        self._restore_tool_callback()

        self._mocks = mocks
        self._current_agent = getattr(self.agent, "name", None)
        self._agent_transfers = []
        self._session_id = self.session_id or str(uuid.uuid4())

        session_service = InMemorySessionService()
        if mocks:
            self._saved_tool_callback = getattr(self.agent, "before_tool_callback", None)
            self.agent.before_tool_callback = _create_mock_callback(mocks)
            self._mock_callback_installed = True

        self._runner = Runner(
            agent=self.agent,
            app_name="understudy_test",
            session_service=session_service,
        )
        self._session = session_service.create_session_sync(
            app_name="understudy_test",
            user_id="understudy_user",
            session_id=self._session_id,
        )

    def send(self, message: str) -> AgentResponse:
        """Send a user message to the ADK agent and capture the response.

        Args:
            message: The user's text for this turn.

        Returns:
            An AgentResponse carrying the joined agent text, the tool calls
            seen during this turn (with results filled in where a matching
            function response arrived), the last terminal state marker found,
            the name of the agent active at the end of the turn, and a copy
            of all agent transfers recorded since start().

        Raises:
            RuntimeError: If start() has not been called.
            ImportError: If the google-adk package is not installed.
        """
        if self._runner is None or self._session is None:
            raise RuntimeError("ADKApp.start() must be called before send()")

        try:
            from google.genai import types
        except ImportError as e:
            raise ImportError(
                "google-adk package required. Install with: pip install understudy[adk]"
            ) from e

        user_content = types.Content(
            role="user",
            parts=[types.Part(text=message)],
        )

        tool_calls: list[ToolCall] = []
        agent_text_parts: list[str] = []
        terminal_state: str | None = None
        current_agent_name = self._current_agent

        for event in self._runner.run(
            user_id="understudy_user",
            session_id=self._session.id,
            new_message=user_content,
        ):
            # track agent attribution from event.author
            if hasattr(event, "author") and event.author:
                event_agent = event.author
                if event_agent != current_agent_name and current_agent_name:
                    self._agent_transfers.append(
                        AgentTransfer(
                            from_agent=current_agent_name,
                            to_agent=event_agent,
                            timestamp=datetime.now(UTC),
                        )
                    )
                current_agent_name = event_agent

            # detect explicit transfer_to_agent actions
            if (
                hasattr(event, "actions")
                and event.actions
                and hasattr(event.actions, "transfer_to_agent")
                and event.actions.transfer_to_agent
            ):
                target_agent = event.actions.transfer_to_agent
                if current_agent_name and target_agent != current_agent_name:
                    self._agent_transfers.append(
                        AgentTransfer(
                            from_agent=current_agent_name,
                            to_agent=target_agent,
                            timestamp=datetime.now(UTC),
                        )
                    )
                current_agent_name = target_agent

            # capture tool calls using get_function_calls()
            for fc in event.get_function_calls():
                tool_calls.append(
                    ToolCall(
                        tool_name=fc.name,
                        arguments=dict(fc.args) if fc.args else {},
                        agent_name=current_agent_name,
                    )
                )

            # capture function responses and update tool call results;
            # each response fills the first same-named call still lacking one
            for fr in event.get_function_responses():
                for call in tool_calls:
                    if call.tool_name == fr.name and call.result is None:
                        call.result = fr.response
                        break

            # capture text responses from content parts
            if hasattr(event, "content") and event.content and hasattr(event.content, "parts"):
                # ``parts`` can be present but None; treat that as no parts.
                for part in event.content.parts or []:
                    text = getattr(part, "text", None)
                    if not text:
                        continue
                    agent_text_parts.append(text)

                    # check for terminal state markers
                    # convention: agent emits "TERMINAL_STATE: <state>"
                    found = self._parse_terminal_state(text)
                    if found is not None:
                        terminal_state = found

        self._current_agent = current_agent_name

        response = AgentResponse(
            content=" ".join(agent_text_parts),
            tool_calls=tool_calls,
            terminal_state=terminal_state,
        )
        response.agent_name = current_agent_name
        response.agent_transfers = list(self._agent_transfers)
        return response

    def stop(self) -> None:
        """Clean up the ADK session and restore the agent's tool callback."""
        self._restore_tool_callback()
        self._runner = None
        self._session = None
        self._mocks = None

    @staticmethod
    def _parse_terminal_state(text: str) -> str | None:
        """Return the token following the last "TERMINAL_STATE:" marker.

        Returns None when the marker is absent or nothing follows it, so a
        bare marker is ignored instead of raising IndexError.
        """
        if "TERMINAL_STATE:" not in text:
            return None
        tokens = text.split("TERMINAL_STATE:")[-1].split()
        return tokens[0] if tokens else None

    def _restore_tool_callback(self) -> None:
        """Put back the before_tool_callback the agent had before mocking."""
        if self._mock_callback_installed:
            self.agent.before_tool_callback = self._saved_tool_callback
            self._saved_tool_callback = None
            self._mock_callback_installed = False
understudy/check.py ADDED
@@ -0,0 +1,138 @@
1
+ """Check: validate a trace against scene expectations."""
2
+
3
+ from dataclasses import dataclass, field
4
+
5
+ from .models import Expectations
6
+ from .trace import Trace
7
+
8
+
9
@dataclass
class CheckResult:
    """Collected outcomes of checking one trace against its expectations."""

    # Individual check outcomes, in the order they were evaluated.
    checks: list["CheckItem"] = field(default_factory=list)

    @property
    def passed(self) -> bool:
        """True when no check failed (also true when there are no checks)."""
        for item in self.checks:
            if not item.passed:
                return False
        return True

    @property
    def failed_checks(self) -> list["CheckItem"]:
        """The checks that did not pass, in their original order."""
        failures = []
        for item in self.checks:
            if not item.passed:
                failures.append(item)
        return failures

    def summary(self) -> str:
        """Render one line per check, marked with a tick or a cross."""
        return "\n".join(
            f" {'✓' if item.passed else '✗'} {item.label}: {item.detail}"
            for item in self.checks
        )

    def __repr__(self) -> str:
        total = len(self.checks)
        n_pass = total - len(self.failed_checks)
        return f"CheckResult({n_pass}/{total} passed)"
33
+
34
+
35
@dataclass
class CheckItem:
    """Outcome of one individual expectation check."""

    # Category of the check, e.g. "required_tool".
    label: str
    # Whether the expectation held.
    passed: bool
    # Human-readable explanation of the outcome.
    detail: str
42
+
43
+
44
def check(trace: Trace, expectations: Expectations) -> CheckResult:
    """Evaluate every expectation of a scene against an execution trace.

    Checks are recorded in a fixed order: required tools, forbidden tools,
    allowed terminal states, forbidden terminal states, required agents,
    forbidden agents, and finally required per-agent tools.

    Args:
        trace: The execution trace from a rehearsal.
        expectations: The expectations from a scene.

    Returns:
        A CheckResult holding one CheckItem per evaluated expectation.
    """
    outcome = CheckResult()

    def record(label: str, passed, detail: str) -> None:
        # Append a single check outcome to the result being built.
        outcome.checks.append(CheckItem(label=label, passed=passed, detail=detail))

    tools_seen = set(trace.call_sequence())

    # required tools
    for name in expectations.required_tools:
        present = name in tools_seen
        verdict = "called" if present else "NOT called"
        record("required_tool", present, f"{name} {verdict}")

    # forbidden tools
    for name in expectations.forbidden_tools:
        violated = name in tools_seen
        verdict = "CALLED (violation)" if violated else "not called"
        record("forbidden_tool", not violated, f"{name} {verdict}")

    # terminal state: only checked when the corresponding collection is non-empty
    if expectations.allowed_terminal_states:
        final_state = trace.terminal_state
        permitted = final_state in expectations.allowed_terminal_states
        verdict = "allowed" if permitted else "NOT in allowed"
        record("terminal_state", permitted, f"{final_state} ({verdict})")

    if expectations.forbidden_terminal_states:
        final_state = trace.terminal_state
        banned = final_state in expectations.forbidden_terminal_states
        verdict = "FORBIDDEN (violation)" if banned else "not forbidden"
        record("forbidden_terminal_state", not banned, f"{final_state} {verdict}")

    agents_seen = set(trace.agents_invoked())

    # required agents
    for name in expectations.required_agents:
        present = name in agents_seen
        verdict = "invoked" if present else "NOT invoked"
        record("required_agent", present, f"{name} {verdict}")

    # forbidden agents
    for name in expectations.forbidden_agents:
        violated = name in agents_seen
        verdict = "INVOKED (violation)" if violated else "not invoked"
        record("forbidden_agent", not violated, f"{name} {verdict}")

    # required agent tools: each listed tool must have been called by that agent
    for agent_name, tool_names in expectations.required_agent_tools.items():
        for tool_name in tool_names:
            did_call = trace.agent_called(agent_name, tool_name)
            verdict = "called" if did_call else "NOT called"
            record("required_agent_tool", did_call, f"{agent_name}.{tool_name} {verdict}")

    return outcome
understudy/cli.py ADDED
@@ -0,0 +1,258 @@
1
+ """CLI: command-line interface for understudy."""
2
+
3
+ import sys
4
+ from pathlib import Path
5
+
6
+ import click
7
+
8
+ from .reports import ReportGenerator
9
+ from .storage import RunStorage
10
+
11
+
12
# Root command group; the subcommands below register themselves on it.
@click.group()
@click.version_option()
def main():
    """understudy - Test your AI agents with simulated users."""
17
+
18
+
19
@main.command()
@click.option(
    "--runs",
    "-r",
    type=click.Path(exists=True, path_type=Path),
    default=".understudy/runs",
    help="Path to runs directory",
)
@click.option(
    "--output",
    "-o",
    type=click.Path(path_type=Path),
    default="report.html",
    help="Output HTML file path",
)
def report(runs: Path, output: Path):
    """Generate a static HTML report from saved runs."""
    store = RunStorage(path=runs)
    available = store.list_runs()

    # Nothing to render: say so and exit with a non-zero status.
    if not available:
        click.echo(f"No runs found in {runs}")
        sys.exit(1)

    click.echo(f"Found {len(available)} runs")
    ReportGenerator(store).generate_static_report(output)
    click.echo(f"Report generated: {output}")
49
+
50
+
51
@main.command()
@click.option(
    "--runs",
    "-r",
    type=click.Path(exists=True, path_type=Path),
    default=".understudy/runs",
    help="Path to runs directory",
)
@click.option(
    "--port",
    "-p",
    type=int,
    default=8080,
    help="Port to serve on",
)
@click.option(
    "--host",
    "-h",
    type=str,
    default="127.0.0.1",
    help="Host to bind to",
)
def serve(runs: Path, port: int, host: str):
    """Start an interactive report browser."""
    store = RunStorage(path=runs)
    available = store.list_runs()

    # Refuse to start a server with nothing to show; exit non-zero instead.
    if not available:
        click.echo(f"No runs found in {runs}")
        sys.exit(1)

    click.echo(f"Found {len(available)} runs")
    ReportGenerator(store).serve(port=port, host=host)
86
+
87
+
88
@main.command("list")
@click.option(
    "--runs",
    "-r",
    type=click.Path(exists=True, path_type=Path),
    default=".understudy/runs",
    help="Path to runs directory",
)
def list_runs(runs: Path):
    """List all saved runs."""
    storage = RunStorage(path=runs)

    run_ids = storage.list_runs()
    if not run_ids:
        # An empty directory is not an error for listing; just report it.
        click.echo(f"No runs found in {runs}")
        return

    click.echo(f"Found {len(run_ids)} runs:\n")

    for run_id in run_ids:
        data = storage.load(run_id)
        # Use ``or`` fallbacks rather than ``.get`` defaults: a key stored
        # with an explicit None value would otherwise be printed as "None".
        meta = data.get("metadata") or {}
        status = "PASS" if meta.get("passed") else "FAIL"
        state = meta.get("terminal_state") or "none"
        turns = meta.get("turn_count") or 0
        click.echo(f" [{status}] {run_id} - {state} ({turns} turns)")
114
+
115
+
116
@main.command()
@click.option(
    "--runs",
    "-r",
    type=click.Path(exists=True, path_type=Path),
    default=".understudy/runs",
    help="Path to runs directory",
)
def summary(runs: Path):
    """Show aggregate metrics for saved runs."""
    store = RunStorage(path=runs)

    if not store.list_runs():
        click.echo(f"No runs found in {runs}")
        return

    stats = store.get_summary()

    # Headline figures.
    click.echo("understudy Summary")
    click.echo("=" * 40)
    click.echo(f"Total Runs: {stats['total_runs']}")
    click.echo(f"Pass Rate: {stats['pass_rate'] * 100:.1f}%")
    click.echo(f"Avg Turns: {stats['avg_turns']:.1f}")

    # Each breakdown is a name -> count mapping, printed most frequent
    # first and skipped entirely when empty.
    breakdowns = (
        ("\nTool Usage:", "tool_usage"),
        ("\nTerminal States:", "terminal_states"),
        ("\nAgents:", "agents"),
    )
    for heading, key in breakdowns:
        counts = stats[key]
        if not counts:
            continue
        click.echo(heading)
        for name, total in sorted(counts.items(), key=lambda pair: -pair[1]):
            click.echo(f" {name}: {total}")
155
+
156
+
157
@main.command()
@click.argument("run_id")
@click.option(
    "--runs",
    "-r",
    type=click.Path(exists=True, path_type=Path),
    default=".understudy/runs",
    help="Path to runs directory",
)
def show(run_id: str, runs: Path):
    """Show details for a specific run."""
    storage = RunStorage(path=runs)

    try:
        data = storage.load(run_id)
    except FileNotFoundError:
        click.echo(f"Run not found: {run_id}")
        sys.exit(1)

    # Use ``or`` fallbacks rather than ``.get`` defaults throughout: a key
    # stored with an explicit None value would otherwise print as "None"
    # (or break ``join`` for the list-valued fields).
    meta = data.get("metadata") or {}
    trace = data.get("trace")
    check = data.get("check")

    click.echo(f"Run: {run_id}")
    click.echo("=" * 40)
    click.echo(f"Scene: {meta.get('scene_id') or 'unknown'}")
    click.echo(f"Status: {'PASS' if meta.get('passed') else 'FAIL'}")
    click.echo(f"Terminal State: {meta.get('terminal_state') or 'none'}")
    click.echo(f"Turns: {meta.get('turn_count') or 0}")
    click.echo(f"Tools Called: {', '.join(meta.get('tools_called') or [])}")
    click.echo(f"Agents: {', '.join(meta.get('agents_invoked') or [])}")

    if check and check.get("checks"):
        click.echo("\nExpectation Checks:")
        for c in check["checks"]:
            icon = "+" if c["passed"] else "-"
            click.echo(f" {icon} {c['label']}: {c['detail']}")

    if trace:
        click.echo("\nConversation:")
        click.echo("-" * 40)
        for turn in trace.turns:
            # Label the turn with the agent's name when there is one,
            # otherwise with the upper-cased role.
            role = turn.agent_name or turn.role.upper()
            # Long messages are cut to 100 characters with an ellipsis.
            click.echo(f"[{role}]: {turn.content[:100]}{'...' if len(turn.content) > 100 else ''}")
            for call in turn.tool_calls:
                click.echo(f" -> {call.tool_name}({call.arguments})")
203
+
204
+
205
@main.command()
@click.argument("run_id")
@click.option(
    "--runs",
    "-r",
    type=click.Path(exists=True, path_type=Path),
    default=".understudy/runs",
    help="Path to runs directory",
)
@click.option("--yes", "-y", is_flag=True, help="Skip confirmation")
def delete(run_id: str, runs: Path, yes: bool):
    """Delete a specific run."""
    store = RunStorage(path=runs)

    # Loading serves only as an existence probe; the loaded data is discarded.
    try:
        store.load(run_id)
    except FileNotFoundError:
        click.echo(f"Run not found: {run_id}")
        sys.exit(1)

    # The prompt is shown only when --yes was not given.
    confirmed = yes or click.confirm(f"Delete run {run_id}?")
    if not confirmed:
        return

    store.delete(run_id)
    click.echo(f"Deleted: {run_id}")
230
+
231
+
232
@main.command()
@click.option(
    "--runs",
    "-r",
    type=click.Path(exists=True, path_type=Path),
    default=".understudy/runs",
    help="Path to runs directory",
)
@click.option("--yes", "-y", is_flag=True, help="Skip confirmation")
def clear(runs: Path, yes: bool):
    """Delete all saved runs."""
    store = RunStorage(path=runs)

    existing = store.list_runs()
    if not existing:
        click.echo(f"No runs found in {runs}")
        return

    # The prompt is shown only when --yes was not given.
    confirmed = yes or click.confirm(f"Delete all {len(existing)} runs?")
    if not confirmed:
        return

    store.clear()
    # The count reported is the number of runs listed before clearing.
    click.echo(f"Cleared {len(existing)} runs")
255
+
256
+
257
# Invoke the CLI group when this module is executed as a script.
if __name__ == "__main__":
    main()