langwatch-scenario 0.3.0-py3-none-any.whl → 0.6.0-py3-none-any.whl
This diff compares two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- langwatch_scenario-0.6.0.dist-info/METADATA +385 -0
- langwatch_scenario-0.6.0.dist-info/RECORD +27 -0
- scenario/__init__.py +128 -17
- scenario/{error_messages.py → _error_messages.py} +8 -38
- scenario/_utils/__init__.py +32 -0
- scenario/_utils/ids.py +58 -0
- scenario/_utils/message_conversion.py +103 -0
- scenario/_utils/utils.py +425 -0
- scenario/agent_adapter.py +115 -0
- scenario/cache.py +134 -9
- scenario/config.py +156 -10
- scenario/events/__init__.py +66 -0
- scenario/events/event_bus.py +175 -0
- scenario/events/event_reporter.py +83 -0
- scenario/events/events.py +169 -0
- scenario/events/messages.py +84 -0
- scenario/events/utils.py +86 -0
- scenario/judge_agent.py +414 -0
- scenario/pytest_plugin.py +177 -14
- scenario/scenario_executor.py +630 -154
- scenario/scenario_state.py +205 -0
- scenario/script.py +361 -0
- scenario/types.py +197 -20
- scenario/user_simulator_agent.py +242 -0
- langwatch_scenario-0.3.0.dist-info/METADATA +0 -302
- langwatch_scenario-0.3.0.dist-info/RECORD +0 -16
- scenario/scenario.py +0 -238
- scenario/scenario_agent_adapter.py +0 -16
- scenario/testing_agent.py +0 -279
- scenario/utils.py +0 -264
- {langwatch_scenario-0.3.0.dist-info → langwatch_scenario-0.6.0.dist-info}/WHEEL +0 -0
- {langwatch_scenario-0.3.0.dist-info → langwatch_scenario-0.6.0.dist-info}/entry_points.txt +0 -0
- {langwatch_scenario-0.3.0.dist-info → langwatch_scenario-0.6.0.dist-info}/top_level.txt +0 -0
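The 0.6.0 release replaces the old `scenario.py`, `scenario_agent_adapter.py`, and `testing_agent.py` entry points (removed above) with an agents-based API built around `AgentAdapter`, `UserSimulatorAgent`, and `JudgeAgent`. The sketch below shows roughly how the new surface fits together; it is assembled from the `AgentAdapter` docstring reproduced later in this diff, so names such as `scenario.run`, `scenario.UserSimulatorAgent`, and `scenario.JudgeAgent` are taken from that example rather than independently verified, and exact signatures may differ.

import asyncio

import scenario


class EchoAdapter(scenario.AgentAdapter):
    """Minimal adapter that echoes the simulated user's last message."""

    async def call(self, input: scenario.AgentInput) -> scenario.AgentReturnTypes:
        return f"You said: {input.last_new_user_message_str()}"


async def main() -> None:
    result = await scenario.run(
        name="echo smoke test",
        description="The user greets the agent and expects a reply",
        agents=[
            EchoAdapter(),
            scenario.UserSimulatorAgent(),
            scenario.JudgeAgent(criteria=["The agent replies to the greeting"]),
        ],
    )
    print(result)


if __name__ == "__main__":
    asyncio.run(main())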
scenario/_utils/ids.py
ADDED
@@ -0,0 +1,58 @@
"""
ID generation and management utilities for scenario execution.

This module provides functions for generating and managing unique identifiers
used throughout the scenario execution pipeline, particularly for batch runs
and scenario tracking.
"""

import os
import uuid


def get_or_create_batch_run_id() -> str:
    """
    Gets or creates a batch run ID for the current scenario execution.

    The batch run ID is consistent across all scenarios in the same process
    execution, allowing grouping of related scenario runs. This is useful
    for tracking and reporting on batches of scenarios run together.

    Returns:
        str: A unique batch run ID that persists for the process lifetime

    Example:
        ```python
        # All scenarios in same process will share this ID
        batch_id = get_or_create_batch_run_id()
        print(f"Running scenario in batch: {batch_id}")
        ```
    """

    # Check if batch ID already exists in environment
    if not os.environ.get("SCENARIO_BATCH_ID"):
        # Generate new batch ID if not set
        os.environ["SCENARIO_BATCH_ID"] = f"batch-run-{uuid.uuid4()}"

    return os.environ["SCENARIO_BATCH_ID"]


def generate_scenario_run_id() -> str:
    """
    Generates a unique scenario run ID for a single scenario execution.

    Each scenario run gets a unique identifier that distinguishes it from
    other runs, even within the same batch. This is used for tracking
    individual scenario executions and correlating events.

    Returns:
        str: A unique scenario run ID

    Example:
        ```python
        # Each scenario gets its own unique ID
        scenario_id = generate_scenario_run_id()
        print(f"Running scenario with ID: {scenario_id}")
        ```
    """
    return f"scenario-run-{uuid.uuid4()}"
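A quick usage sketch for the two helpers above. It assumes they are imported directly from their module path (`scenario/_utils/__init__.py` may also re-export them); setting SCENARIO_BATCH_ID before the run would pin the batch ID, e.g. to a CI job identifier.

import os

from scenario._utils.ids import generate_scenario_run_id, get_or_create_batch_run_id

first = get_or_create_batch_run_id()
second = get_or_create_batch_run_id()
assert first == second  # the batch ID is stable for the whole process
assert os.environ["SCENARIO_BATCH_ID"] == first  # persisted via the environment

assert generate_scenario_run_id() != generate_scenario_run_id()  # fresh per run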
scenario/_utils/message_conversion.py
ADDED
@@ -0,0 +1,103 @@
"""
Message conversion utilities for scenario execution.

This module provides functions for converting between different message formats
used in scenario execution, particularly for normalizing agent return types
to OpenAI-compatible message formats.
"""

import json
from typing import Any, List, Literal, TypeVar, cast
from pydantic import BaseModel
from openai.types.chat import ChatCompletionMessageParam

from scenario.types import AgentReturnTypes, ScenarioResult
from .utils import SerializableAndPydanticEncoder

T = TypeVar("T")


def convert_agent_return_types_to_openai_messages(
    agent_response: AgentReturnTypes, role: Literal["user", "assistant"]
) -> List[ChatCompletionMessageParam]:
    """
    Convert various agent return types to standardized OpenAI message format.

    This function normalizes different return types from agent adapters into
    a consistent list of OpenAI-compatible messages that can be used throughout
    the scenario execution pipeline.

    Args:
        agent_response: Response from an agent adapter call
        role: The role to assign to string responses ("user" or "assistant")

    Returns:
        List of OpenAI-compatible messages

    Raises:
        ValueError: If agent_response is a ScenarioResult (which should be handled separately)

    Example:
        ```
        # String response
        messages = convert_agent_return_types_to_openai_messages("Hello", "assistant")
        # Result: [{"role": "assistant", "content": "Hello"}]

        # Dict response
        response = {"role": "assistant", "content": "Hi", "tool_calls": [...]}
        messages = convert_agent_return_types_to_openai_messages(response, "assistant")
        # Result: [{"role": "assistant", "content": "Hi", "tool_calls": [...]}]

        # List response
        responses = [
            {"role": "assistant", "content": "Thinking..."},
            {"role": "assistant", "content": "Here's the answer"}
        ]
        messages = convert_agent_return_types_to_openai_messages(responses, "assistant")
        # Result: Same list, validated and normalized
        ```
    """
    if isinstance(agent_response, ScenarioResult):
        raise ValueError(
            "Unexpectedly tried to convert a ScenarioResult to openai messages",
            agent_response.__repr__(),
        )

    def convert_maybe_object_to_openai_message(
        obj: Any,
    ) -> ChatCompletionMessageParam:
        if isinstance(obj, dict):
            return cast(ChatCompletionMessageParam, obj)
        elif isinstance(obj, BaseModel):
            return cast(
                ChatCompletionMessageParam,
                obj.model_dump(
                    exclude_unset=True,
                    exclude_none=True,
                    exclude_defaults=True,
                    warnings=False,
                ),
            )
        else:
            raise ValueError(f"Unexpected agent response type: {type(obj).__name__}")

    def ensure_dict(
        obj: T,
    ) -> T:
        return json.loads(json.dumps(obj, cls=SerializableAndPydanticEncoder))

    if isinstance(agent_response, str):
        return [
            (
                {"role": "user", "content": agent_response}
                if role == "user"
                else {"role": "assistant", "content": agent_response}
            )
        ]
    elif isinstance(agent_response, list):
        return [
            ensure_dict(convert_maybe_object_to_openai_message(message))
            for message in agent_response
        ]
    else:
        return [ensure_dict(convert_maybe_object_to_openai_message(agent_response))]
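A small, hedged illustration of the normalization behavior above; `AssistantMessage` is an ad-hoc Pydantic model invented for this example, not part of the package.

from pydantic import BaseModel

from scenario._utils.message_conversion import (
    convert_agent_return_types_to_openai_messages,
)


class AssistantMessage(BaseModel):
    role: str
    content: str


# Plain strings are wrapped into a single message with the requested role
print(convert_agent_return_types_to_openai_messages("Hello", "assistant"))
# [{'role': 'assistant', 'content': 'Hello'}]

# Pydantic messages are dumped to plain dicts and round-tripped through JSON
print(
    convert_agent_return_types_to_openai_messages(
        [AssistantMessage(role="assistant", content="Hi there")], "assistant"
    )
)
# [{'role': 'assistant', 'content': 'Hi there'}]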
scenario/_utils/utils.py
ADDED
@@ -0,0 +1,425 @@
"""
Utility functions for scenario execution and message handling.

This module provides various utility functions used throughout the Scenario framework,
including message formatting, validation, role reversal, and UI components like spinners
for better user experience during scenario execution.
"""

from contextlib import contextmanager
import sys
from typing import (
    Any,
    Iterator,
    Optional,
    Union,
    TypeVar,
    Awaitable,
)
from pydantic import BaseModel
import copy

import json

import termcolor
from textwrap import indent
from openai.types.chat import ChatCompletionMessageParam
from rich.live import Live
from rich.spinner import Spinner
from rich.console import Console
from rich.text import Text
from rich.errors import LiveError

from scenario._error_messages import message_return_error_message
from scenario.types import ScenarioResult

T = TypeVar("T")


class SerializableAndPydanticEncoder(json.JSONEncoder):
    """
    JSON encoder that handles Pydantic models and iterators.

    This encoder extends the standard JSON encoder to handle Pydantic BaseModel
    instances and iterator objects, converting them to serializable formats.
    Used for caching and logging scenarios that contain complex objects.

    Example:
        ```
        data = {
            "model": SomeBaseModel(field="value"),
            "iterator": iter([1, 2, 3])
        }
        json.dumps(data, cls=SerializableAndPydanticEncoder)
        ```
    """
    def default(self, o: Any) -> Any:
        if isinstance(o, BaseModel):
            return o.model_dump(exclude_unset=True)
        if isinstance(o, Iterator):
            return list(o)
        return super().default(o)


class SerializableWithStringFallback(SerializableAndPydanticEncoder):
    """
    JSON encoder with string fallback for non-serializable objects.

    This encoder extends SerializableAndPydanticEncoder by providing a string
    fallback for any object that cannot be serialized normally. This ensures
    that logging and caching operations never fail due to serialization issues.

    Example:
        ```
        # This will work even with complex non-serializable objects
        data = {"function": lambda x: x, "complex_object": SomeComplexClass()}
        json.dumps(data, cls=SerializableWithStringFallback)
        # Result: {"function": "<function <lambda> at 0x...>", "complex_object": "..."}
        ```
    """
    def default(self, o: Any) -> Any:
        try:
            return super().default(o)
        except:
            return str(o)


def safe_list_at(list_obj: list, index: int, default: Any = None) -> Any:
    """
    Safely get an item from a list by index with a default fallback.

    Args:
        list_obj: The list to access
        index: The index to retrieve
        default: Value to return if index is out of bounds

    Returns:
        The item at the index, or the default value if index is invalid

    Example:
        ```
        items = ["a", "b", "c"]
        print(safe_list_at(items, 1))   # "b"
        print(safe_list_at(items, 10))  # None
        print(safe_list_at(items, 10, "default"))  # "default"
        ```
    """
    try:
        return list_obj[index]
    except:
        return default


def safe_attr_or_key(obj: Any, attr_or_key: str, default: Any = None) -> Any:
    """
    Safely get an attribute or dictionary key from an object.

    Tries to get the value as an attribute first, then as a dictionary key,
    returning the default if neither exists.

    Args:
        obj: Object to access (can have attributes or be dict-like)
        attr_or_key: Name of attribute or key to retrieve
        default: Value to return if attribute/key doesn't exist

    Returns:
        The attribute/key value, or the default if not found

    Example:
        ```
        class MyClass:
            attr = "value"

        obj = MyClass()
        dict_obj = {"key": "value"}

        print(safe_attr_or_key(obj, "attr"))      # "value"
        print(safe_attr_or_key(dict_obj, "key"))  # "value"
        print(safe_attr_or_key(obj, "missing"))   # None
        ```
    """
    return getattr(obj, attr_or_key, getattr(obj, 'get', lambda x, default=None: default)(attr_or_key, default))


def title_case(string: str) -> str:
    """
    Convert snake_case string to Title Case.

    Args:
        string: Snake_case string to convert

    Returns:
        String converted to Title Case

    Example:
        ```
        print(title_case("user_simulator_agent"))  # "User Simulator Agent"
        print(title_case("api_key"))               # "Api Key"
        ```
    """
    return " ".join(word.capitalize() for word in string.split("_"))


def print_openai_messages(
    scenario_name: str, messages: list[ChatCompletionMessageParam]
):
    """
    Print OpenAI-format messages with colored formatting for readability.

    This function formats and prints conversation messages with appropriate
    colors and formatting for different message types (user, assistant, tool calls, etc.).
    Used for verbose output during scenario execution.

    Args:
        scenario_name: Name of the scenario (used as prefix)
        messages: List of OpenAI-compatible messages to print

    Example:
        ```
        messages = [
            {"role": "user", "content": "Hello"},
            {"role": "assistant", "content": "Hi there!"},
            {"role": "assistant", "tool_calls": [{"function": {"name": "search"}}]}
        ]
        print_openai_messages("Test Scenario", messages)
        ```

    Note:
        - User messages are printed in green
        - Assistant messages are printed in blue
        - Tool calls are printed in magenta with formatted JSON
        - Long JSON content is truncated for readability
    """
    for msg in messages:
        role = safe_attr_or_key(msg, "role")
        content = safe_attr_or_key(msg, "content")
        if role == "assistant":
            tool_calls = safe_attr_or_key(msg, "tool_calls")
            if content:
                print(scenario_name + termcolor.colored("Agent:", "blue"), content)
            if tool_calls:
                for tool_call in tool_calls:
                    function = safe_attr_or_key(tool_call, "function")
                    name = safe_attr_or_key(function, "name")
                    args = safe_attr_or_key(function, "arguments", "{}")
                    args = _take_maybe_json_first_lines(args)
                    print(
                        scenario_name
                        + termcolor.colored(f"ToolCall({name}):", "magenta"),
                        f"\n\n{indent(args, ' ' * 4)}\n",
                    )
        elif role == "user":
            print(scenario_name + termcolor.colored("User:", "green"), content)
        elif role == "tool":
            content = _take_maybe_json_first_lines(content or msg.__repr__())
            print(
                scenario_name + termcolor.colored(f"ToolResult:", "magenta"),
                f"\n\n{indent(content, ' ' * 4)}\n",
            )
        else:
            print(
                scenario_name + termcolor.colored(f"{title_case(role)}:", "magenta"),
                msg.__repr__(),
            )


def _take_maybe_json_first_lines(string: str, max_lines: int = 5) -> str:
    """
    Truncate string content and format JSON if possible.

    Internal utility function that attempts to format content as JSON
    and truncates it to a reasonable number of lines for display.

    Args:
        string: Content to format and truncate
        max_lines: Maximum number of lines to show

    Returns:
        Formatted and potentially truncated string
    """
    content = str(string)
    try:
        content = json.dumps(json.loads(content), indent=2)
    except:
        pass
    content = content.split("\n")
    if len(content) > max_lines:
        content = content[:max_lines] + ["..."]
    return "\n".join(content)


console = Console()


class TextFirstSpinner(Spinner):
    """
    Custom spinner that displays text before the spinning animation.

    This class extends Rich's Spinner to show descriptive text followed
    by the spinning animation, improving the user experience during
    scenario execution by clearly indicating what operation is happening.

    Args:
        name: Name of the spinner animation style
        text: Descriptive text to show before the spinner
        color: Color for the descriptive text
        **kwargs: Additional arguments passed to the base Spinner class
    """
    def __init__(self, name: str, text: str, color: str, **kwargs: Any) -> None:
        super().__init__(
            name, "", style="bold white", **kwargs
        )  # Initialize with empty text
        self.text_before = text
        self.color = color

    def render(self, time: float) -> Text:
        # Get the original spinner frame
        spinner_frame = super().render(time)
        # Create a composite with text first, then spinner
        return Text(f"{self.text_before} ", style=self.color) + spinner_frame


@contextmanager
def show_spinner(
    text: str, color: str = "white", enabled: Optional[Union[bool, int]] = None
):
    """
    Context manager for displaying a spinner during long-running operations.

    Shows a spinning indicator with descriptive text while code executes
    within the context. Automatically cleans up the spinner display when
    the operation completes.

    Args:
        text: Descriptive text to show next to the spinner
        color: Color for the descriptive text
        enabled: Whether to show the spinner (respects verbose settings)

    Example:
        ```
        with show_spinner("Calling agent...", color="blue", enabled=True):
            response = await agent.call(input_data)

        # Spinner automatically disappears when block completes
        print("Agent call completed")
        ```

    Note:
        - Spinner is automatically cleaned up when context exits
        - Gracefully handles multi-threading scenarios where multiple spinners might conflict
        - Cursor positioning ensures clean terminal output
    """
    if not enabled:
        yield
    else:
        spinner = TextFirstSpinner("dots", text, color=color)
        try:
            with Live(spinner, console=console, refresh_per_second=20):
                yield
        # It happens when we are multi-threading, it's fine, just ignore it, you probably don't want multiple spinners at once anyway
        except LiveError:
            yield

        # Cursor up one line
        sys.stdout.write("\033[F")
        # Erase the line
        sys.stdout.write("\033[2K")


def check_valid_return_type(return_value: Any, class_name: str) -> None:
    """
    Validate that an agent's return value is in the expected format.

    This function ensures that agent adapters return values in one of the
    supported formats (string, OpenAI message, list of messages, or ScenarioResult).
    It also verifies that the returned data is JSON-serializable for caching.

    Args:
        return_value: The value returned by an agent's call method
        class_name: Name of the agent class (for error messages)

    Raises:
        ValueError: If the return value is not in a supported format

    Example:
        ```
        # Valid return values
        check_valid_return_type("Hello world", "MyAgent")  # OK
        check_valid_return_type({"role": "assistant", "content": "Hi"}, "MyAgent")  # OK
        check_valid_return_type([{"role": "assistant", "content": "Hi"}], "MyAgent")  # OK

        # Invalid return value
        check_valid_return_type(42, "MyAgent")  # Raises ValueError
        ```
    """
    def _is_valid_openai_message(message: Any) -> bool:
        return (isinstance(message, dict) and "role" in message) or (
            isinstance(message, BaseModel) and hasattr(message, "role")
        )

    if (
        isinstance(return_value, str)
        or _is_valid_openai_message(return_value)
        or (
            isinstance(return_value, list)
            and all(_is_valid_openai_message(message) for message in return_value)
        )
        or isinstance(return_value, ScenarioResult)
    ):
        try:
            json.dumps(return_value, cls=SerializableAndPydanticEncoder)
        except:
            raise ValueError(
                message_return_error_message(got=return_value, class_name=class_name)
            )

        return

    raise ValueError(
        message_return_error_message(got=return_value, class_name=class_name)
    )


def reverse_roles(
    messages: list[ChatCompletionMessageParam],
) -> list[ChatCompletionMessageParam]:
    """
    Reverses the roles of the messages in the list.

    Args:
        messages: The list of messages to reverse the roles of.
    """

    reversed_messages = []
    for message in messages:
        message = copy.deepcopy(message)
        # Can't reverse tool calls
        if not safe_attr_or_key(message, "content") or safe_attr_or_key(
            message, "tool_calls"
        ):
            # If no content nor tool calls, we should skip it entirely, as anthropic may generate some invalid ones e.g. pure {"role": "assistant"}
            if safe_attr_or_key(message, "tool_calls"):
                reversed_messages.append(message)
            continue

        if type(message) == dict:
            if message["role"] == "user":
                message["role"] = "assistant"
            elif message["role"] == "assistant":
                message["role"] = "user"
        else:
            if getattr(message, "role", None) == "user":
                message.role = "assistant"  # type: ignore
            elif getattr(message, "role", None) == "assistant":
                message.role = "user"  # type: ignore

        reversed_messages.append(message)

    return reversed_messages


async def await_if_awaitable(value: T) -> T:
    if isinstance(value, Awaitable):
        return await value
    else:
        return value
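A brief, hedged sketch of how a few of these helpers behave; the conversation content is invented for illustration and the import path assumes direct module access.

import asyncio

from scenario._utils.utils import await_if_awaitable, reverse_roles, show_spinner

conversation = [
    {"role": "user", "content": "Hi, I need help with my order"},
    {"role": "assistant", "content": "Sure, what is the order number?"},
]

# Swap perspectives so a user-simulator model sees the agent's turns as its input
print(reverse_roles(conversation))
# [{'role': 'assistant', ...}, {'role': 'user', ...}]

# With enabled=False the spinner is a no-op, so it is safe in quiet or CI runs
with show_spinner("Calling agent...", color="blue", enabled=False):
    pass

# await_if_awaitable lets callers accept both plain values and awaitables
print(asyncio.run(await_if_awaitable("already a plain value")))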
scenario/agent_adapter.py
ADDED
@@ -0,0 +1,115 @@
"""
Agent adapter module for integrating custom agents with the Scenario framework.

This module provides the abstract base class that users must implement to integrate
their existing agents with the Scenario testing framework. The adapter pattern allows
any agent implementation to work with the framework regardless of its underlying
architecture or API.
"""

from abc import ABC, abstractmethod
from typing import ClassVar

from .types import AgentInput, AgentReturnTypes, AgentRole


class AgentAdapter(ABC):
    """
    Abstract base class for integrating custom agents with the Scenario framework.

    This adapter pattern allows you to wrap any existing agent implementation
    (LLM calls, agent frameworks, or complex multi-step systems) to work with
    the Scenario testing framework. The adapter receives structured input about
    the conversation state and returns responses in a standardized format.

    Attributes:
        role: The role this agent plays in scenarios (USER, AGENT, or JUDGE)

    Example:
        ```
        import scenario
        from my_agent import MyCustomAgent

        class MyAgentAdapter(scenario.AgentAdapter):
            def __init__(self):
                self.agent = MyCustomAgent()

            async def call(self, input: scenario.AgentInput) -> scenario.AgentReturnTypes:
                # Get the latest user message
                user_message = input.last_new_user_message_str()

                # Call your existing agent
                response = await self.agent.process(
                    message=user_message,
                    history=input.messages,
                    thread_id=input.thread_id
                )

                # Return the response (can be string, message dict, or list of messages)
                return response

        # Use in a scenario
        result = await scenario.run(
            name="test my agent",
            description="User asks for help with a coding problem",
            agents=[
                MyAgentAdapter(),
                scenario.UserSimulatorAgent(),
                scenario.JudgeAgent(criteria=["Provides helpful coding advice"])
            ]
        )
        ```

    Note:
        - The call method must be async
        - Return types can be: str, ChatCompletionMessageParam, List[ChatCompletionMessageParam], or ScenarioResult
        - For stateful agents, use input.thread_id to maintain conversation context
        - For stateless agents, use input.messages for the full conversation history
    """

    role: ClassVar[AgentRole] = AgentRole.AGENT

    @abstractmethod
    async def call(self, input: AgentInput) -> AgentReturnTypes:
        """
        Process the input and generate a response.

        This is the main method that your agent implementation must provide.
        It receives structured information about the current conversation state
        and must return a response in one of the supported formats.

        Args:
            input: AgentInput containing conversation history, thread context, and scenario state

        Returns:
            AgentReturnTypes: The agent's response, which can be:

            - str: Simple text response

            - ChatCompletionMessageParam: Single OpenAI-format message

            - List[ChatCompletionMessageParam]: Multiple messages for complex responses

            - ScenarioResult: Direct test result (typically only used by judge agents)

        Example:
            ```
            async def call(self, input: AgentInput) -> AgentReturnTypes:
                # Simple string response
                user_msg = input.last_new_user_message_str()
                return f"I understand you said: {user_msg}"

                # Or structured message response
                return {
                    "role": "assistant",
                    "content": "Let me help you with that...",
                }

                # Or multiple messages for complex interactions
                return [
                    {"role": "assistant", "content": "Let me search for that information..."},
                    {"role": "assistant", "content": "Here's what I found: ..."}
                ]
            ```
        """
        pass