langwatch-scenario 0.3.0__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {langwatch_scenario-0.3.0.dist-info → langwatch_scenario-0.4.0.dist-info}/METADATA +140 -79
- langwatch_scenario-0.4.0.dist-info/RECORD +18 -0
- scenario/__init__.py +223 -9
- scenario/agent_adapter.py +111 -0
- scenario/cache.py +132 -8
- scenario/config.py +154 -10
- scenario/error_messages.py +8 -38
- scenario/judge_agent.py +435 -0
- scenario/pytest_plugin.py +223 -15
- scenario/scenario_executor.py +428 -136
- scenario/scenario_state.py +205 -0
- scenario/script.py +361 -0
- scenario/types.py +193 -20
- scenario/user_simulator_agent.py +249 -0
- scenario/utils.py +252 -2
- langwatch_scenario-0.3.0.dist-info/RECORD +0 -16
- scenario/scenario.py +0 -238
- scenario/scenario_agent_adapter.py +0 -16
- scenario/testing_agent.py +0 -279
- {langwatch_scenario-0.3.0.dist-info → langwatch_scenario-0.4.0.dist-info}/WHEEL +0 -0
- {langwatch_scenario-0.3.0.dist-info → langwatch_scenario-0.4.0.dist-info}/entry_points.txt +0 -0
- {langwatch_scenario-0.3.0.dist-info → langwatch_scenario-0.4.0.dist-info}/top_level.txt +0 -0
scenario/utils.py
CHANGED
@@ -1,3 +1,11 @@
|
|
1
|
+
"""
|
2
|
+
Utility functions for scenario execution and message handling.
|
3
|
+
|
4
|
+
This module provides various utility functions used throughout the Scenario framework,
|
5
|
+
including message formatting, validation, role reversal, and UI components like spinners
|
6
|
+
for better user experience during scenario execution.
|
7
|
+
"""
|
8
|
+
|
1
9
|
from contextlib import contextmanager
|
2
10
|
import sys
|
3
11
|
from typing import (
|
@@ -12,6 +20,7 @@ from typing import (
|
|
12
20
|
cast,
|
13
21
|
)
|
14
22
|
from pydantic import BaseModel
|
23
|
+
import copy
|
15
24
|
|
16
25
|
import json
|
17
26
|
|
@@ -31,6 +40,22 @@ T = TypeVar("T")
|
|
31
40
|
|
32
41
|
|
33
42
|
class SerializableAndPydanticEncoder(json.JSONEncoder):
|
43
|
+
"""
|
44
|
+
JSON encoder that handles Pydantic models and iterators.
|
45
|
+
|
46
|
+
This encoder extends the standard JSON encoder to handle Pydantic BaseModel
|
47
|
+
instances and iterator objects, converting them to serializable formats.
|
48
|
+
Used for caching and logging scenarios that contain complex objects.
|
49
|
+
|
50
|
+
Example:
|
51
|
+
```python
|
52
|
+
data = {
|
53
|
+
"model": SomeBaseModel(field="value"),
|
54
|
+
"iterator": iter([1, 2, 3])
|
55
|
+
}
|
56
|
+
json.dumps(data, cls=SerializableAndPydanticEncoder)
|
57
|
+
```
|
58
|
+
"""
|
34
59
|
def default(self, o):
|
35
60
|
if isinstance(o, BaseModel):
|
36
61
|
return o.model_dump(exclude_unset=True)
|
@@ -40,6 +65,21 @@ class SerializableAndPydanticEncoder(json.JSONEncoder):
|
|
40
65
|
|
41
66
|
|
42
67
|
class SerializableWithStringFallback(SerializableAndPydanticEncoder):
|
68
|
+
"""
|
69
|
+
JSON encoder with string fallback for non-serializable objects.
|
70
|
+
|
71
|
+
This encoder extends SerializableAndPydanticEncoder by providing a string
|
72
|
+
fallback for any object that cannot be serialized normally. This ensures
|
73
|
+
that logging and caching operations never fail due to serialization issues.
|
74
|
+
|
75
|
+
Example:
|
76
|
+
```python
|
77
|
+
# This will work even with complex non-serializable objects
|
78
|
+
data = {"function": lambda x: x, "complex_object": SomeComplexClass()}
|
79
|
+
json.dumps(data, cls=SerializableWithStringFallback)
|
80
|
+
# Result: {"function": "<function <lambda> at 0x...>", "complex_object": "..."}
|
81
|
+
```
|
82
|
+
"""
|
43
83
|
def default(self, o):
|
44
84
|
try:
|
45
85
|
return super().default(o)
|
@@ -48,6 +88,25 @@ class SerializableWithStringFallback(SerializableAndPydanticEncoder):
|
|
48
88
|
|
49
89
|
|
50
90
|
def safe_list_at(list, index, default=None):
|
91
|
+
"""
|
92
|
+
Safely get an item from a list by index with a default fallback.
|
93
|
+
|
94
|
+
Args:
|
95
|
+
list: The list to access
|
96
|
+
index: The index to retrieve
|
97
|
+
default: Value to return if index is out of bounds
|
98
|
+
|
99
|
+
Returns:
|
100
|
+
The item at the index, or the default value if index is invalid
|
101
|
+
|
102
|
+
Example:
|
103
|
+
```python
|
104
|
+
items = ["a", "b", "c"]
|
105
|
+
print(safe_list_at(items, 1)) # "b"
|
106
|
+
print(safe_list_at(items, 10)) # None
|
107
|
+
print(safe_list_at(items, 10, "default")) # "default"
|
108
|
+
```
|
109
|
+
"""
|
51
110
|
try:
|
52
111
|
return list[index]
|
53
112
|
except:
|
@@ -55,16 +114,85 @@ def safe_list_at(list, index, default=None):
|
|
55
114
|
|
56
115
|
|
57
116
|
def safe_attr_or_key(obj, attr_or_key, default=None):
    """
    Safely get an attribute or dictionary key from an object.

    Tries to get the value as an attribute first, then as a key via the
    object's ``get`` method (when it has a callable one), and returns
    ``default`` if neither lookup finds it.

    Args:
        obj: Object to access (can have attributes or be dict-like)
        attr_or_key: Name of attribute or key to retrieve
        default: Value to return if attribute/key doesn't exist

    Returns:
        The attribute/key value, or the default if not found

    Example:
        ```python
        class MyClass:
            attr = "value"

        obj = MyClass()
        dict_obj = {"key": "value"}

        print(safe_attr_or_key(obj, "attr"))  # "value"
        print(safe_attr_or_key(dict_obj, "key"))  # "value"
        print(safe_attr_or_key(obj, "missing"))  # None
        ```
    """
    # Local sentinel so an attribute whose value is None is still returned as-is.
    missing = object()
    value = getattr(obj, attr_or_key, missing)
    if value is not missing:
        return value

    # Only fall back to key lookup when the object actually exposes ``get``.
    # The fallback must be lazy: evaluating ``obj.get(...)`` eagerly crashes on
    # plain objects and models that have no ``get`` method.
    getter = getattr(obj, "get", None)
    if callable(getter):
        return getter(attr_or_key, default)
    return default
|
59
145
|
|
60
146
|
|
61
147
|
def title_case(string):
    """
    Turn a snake_case identifier into space-separated, capitalized words.

    Each underscore-delimited piece is passed through ``str.capitalize``
    (first letter upper-cased, the rest lower-cased), and the pieces are
    joined with single spaces.

    Args:
        string: Snake_case string to convert

    Returns:
        String converted to Title Case

    Example:
        ```python
        print(title_case("user_simulator_agent"))  # "User Simulator Agent"
        print(title_case("api_key"))  # "Api Key"
        ```
    """
    pieces = string.split("_")
    return " ".join(map(str.capitalize, pieces))
|
63
164
|
|
64
165
|
|
65
166
|
def print_openai_messages(
|
66
167
|
scenario_name: str, messages: list[ChatCompletionMessageParam]
|
67
168
|
):
|
169
|
+
"""
|
170
|
+
Print OpenAI-format messages with colored formatting for readability.
|
171
|
+
|
172
|
+
This function formats and prints conversation messages with appropriate
|
173
|
+
colors and formatting for different message types (user, assistant, tool calls, etc.).
|
174
|
+
Used for verbose output during scenario execution.
|
175
|
+
|
176
|
+
Args:
|
177
|
+
scenario_name: Name of the scenario (used as prefix)
|
178
|
+
messages: List of OpenAI-compatible messages to print
|
179
|
+
|
180
|
+
Example:
|
181
|
+
```python
|
182
|
+
messages = [
|
183
|
+
{"role": "user", "content": "Hello"},
|
184
|
+
{"role": "assistant", "content": "Hi there!"},
|
185
|
+
{"role": "assistant", "tool_calls": [{"function": {"name": "search"}}]}
|
186
|
+
]
|
187
|
+
print_openai_messages("Test Scenario", messages)
|
188
|
+
```
|
189
|
+
|
190
|
+
Note:
|
191
|
+
- User messages are printed in green
|
192
|
+
- Assistant messages are printed in blue
|
193
|
+
- Tool calls are printed in magenta with formatted JSON
|
194
|
+
- Long JSON content is truncated for readability
|
195
|
+
"""
|
68
196
|
for msg in messages:
|
69
197
|
role = safe_attr_or_key(msg, "role")
|
70
198
|
content = safe_attr_or_key(msg, "content")
|
@@ -99,6 +227,19 @@ def print_openai_messages(
|
|
99
227
|
|
100
228
|
|
101
229
|
def _take_maybe_json_first_lines(string, max_lines=5):
|
230
|
+
"""
|
231
|
+
Truncate string content and format JSON if possible.
|
232
|
+
|
233
|
+
Internal utility function that attempts to format content as JSON
|
234
|
+
and truncates it to a reasonable number of lines for display.
|
235
|
+
|
236
|
+
Args:
|
237
|
+
string: Content to format and truncate
|
238
|
+
max_lines: Maximum number of lines to show
|
239
|
+
|
240
|
+
Returns:
|
241
|
+
Formatted and potentially truncated string
|
242
|
+
"""
|
102
243
|
content = str(string)
|
103
244
|
try:
|
104
245
|
content = json.dumps(json.loads(content), indent=2)
|
@@ -114,6 +255,19 @@ console = Console()
|
|
114
255
|
|
115
256
|
|
116
257
|
class TextFirstSpinner(Spinner):
|
258
|
+
"""
|
259
|
+
Custom spinner that displays text before the spinning animation.
|
260
|
+
|
261
|
+
This class extends Rich's Spinner to show descriptive text followed
|
262
|
+
by the spinning animation, improving the user experience during
|
263
|
+
scenario execution by clearly indicating what operation is happening.
|
264
|
+
|
265
|
+
Args:
|
266
|
+
name: Name of the spinner animation style
|
267
|
+
text: Descriptive text to show before the spinner
|
268
|
+
color: Color for the descriptive text
|
269
|
+
**kwargs: Additional arguments passed to the base Spinner class
|
270
|
+
"""
|
117
271
|
def __init__(self, name, text: str, color: str, **kwargs):
|
118
272
|
super().__init__(
|
119
273
|
name, "", style="bold white", **kwargs
|
@@ -132,6 +286,32 @@ class TextFirstSpinner(Spinner):
|
|
132
286
|
def show_spinner(
|
133
287
|
text: str, color: str = "white", enabled: Optional[Union[bool, int]] = None
|
134
288
|
):
|
289
|
+
"""
|
290
|
+
Context manager for displaying a spinner during long-running operations.
|
291
|
+
|
292
|
+
Shows a spinning indicator with descriptive text while code executes
|
293
|
+
within the context. Automatically cleans up the spinner display when
|
294
|
+
the operation completes.
|
295
|
+
|
296
|
+
Args:
|
297
|
+
text: Descriptive text to show next to the spinner
|
298
|
+
color: Color for the descriptive text
|
299
|
+
enabled: Whether to show the spinner (respects verbose settings)
|
300
|
+
|
301
|
+
Example:
|
302
|
+
```python
|
303
|
+
with show_spinner("Calling agent...", color="blue", enabled=True):
|
304
|
+
response = await agent.call(input_data)
|
305
|
+
|
306
|
+
# Spinner automatically disappears when block completes
|
307
|
+
print("Agent call completed")
|
308
|
+
```
|
309
|
+
|
310
|
+
Note:
|
311
|
+
- Spinner is automatically cleaned up when context exits
|
312
|
+
- Gracefully handles multi-threading scenarios where multiple spinners might conflict
|
313
|
+
- Cursor positioning ensures clean terminal output
|
314
|
+
"""
|
135
315
|
if not enabled:
|
136
316
|
yield
|
137
317
|
else:
|
@@ -150,6 +330,31 @@ def show_spinner(
|
|
150
330
|
|
151
331
|
|
152
332
|
def check_valid_return_type(return_value: Any, class_name: str) -> None:
|
333
|
+
"""
|
334
|
+
Validate that an agent's return value is in the expected format.
|
335
|
+
|
336
|
+
This function ensures that agent adapters return values in one of the
|
337
|
+
supported formats (string, OpenAI message, list of messages, or ScenarioResult).
|
338
|
+
It also verifies that the returned data is JSON-serializable for caching.
|
339
|
+
|
340
|
+
Args:
|
341
|
+
return_value: The value returned by an agent's call method
|
342
|
+
class_name: Name of the agent class (for error messages)
|
343
|
+
|
344
|
+
Raises:
|
345
|
+
ValueError: If the return value is not in a supported format
|
346
|
+
|
347
|
+
Example:
|
348
|
+
```python
|
349
|
+
# Valid return values
|
350
|
+
check_valid_return_type("Hello world", "MyAgent") # OK
|
351
|
+
check_valid_return_type({"role": "assistant", "content": "Hi"}, "MyAgent") # OK
|
352
|
+
check_valid_return_type([{"role": "assistant", "content": "Hi"}], "MyAgent") # OK
|
353
|
+
|
354
|
+
# Invalid return value
|
355
|
+
check_valid_return_type(42, "MyAgent") # Raises ValueError
|
356
|
+
```
|
357
|
+
"""
|
153
358
|
def _is_valid_openai_message(message: Any) -> bool:
|
154
359
|
return (isinstance(message, dict) and "role" in message) or (
|
155
360
|
isinstance(message, BaseModel) and hasattr(message, "role")
|
@@ -181,6 +386,43 @@ def check_valid_return_type(return_value: Any, class_name: str) -> None:
|
|
181
386
|
def convert_agent_return_types_to_openai_messages(
|
182
387
|
agent_response: AgentReturnTypes, role: Literal["user", "assistant"]
|
183
388
|
) -> List[ChatCompletionMessageParam]:
|
389
|
+
"""
|
390
|
+
Convert various agent return types to standardized OpenAI message format.
|
391
|
+
|
392
|
+
This function normalizes different return types from agent adapters into
|
393
|
+
a consistent list of OpenAI-compatible messages that can be used throughout
|
394
|
+
the scenario execution pipeline.
|
395
|
+
|
396
|
+
Args:
|
397
|
+
agent_response: Response from an agent adapter call
|
398
|
+
role: The role to assign to string responses ("user" or "assistant")
|
399
|
+
|
400
|
+
Returns:
|
401
|
+
List of OpenAI-compatible messages
|
402
|
+
|
403
|
+
Raises:
|
404
|
+
ValueError: If agent_response is a ScenarioResult (which should be handled separately)
|
405
|
+
|
406
|
+
Example:
|
407
|
+
```python
|
408
|
+
# String response
|
409
|
+
messages = convert_agent_return_types_to_openai_messages("Hello", "assistant")
|
410
|
+
# Result: [{"role": "assistant", "content": "Hello"}]
|
411
|
+
|
412
|
+
# Dict response
|
413
|
+
response = {"role": "assistant", "content": "Hi", "tool_calls": [...]}
|
414
|
+
messages = convert_agent_return_types_to_openai_messages(response, "assistant")
|
415
|
+
# Result: [{"role": "assistant", "content": "Hi", "tool_calls": [...]}]
|
416
|
+
|
417
|
+
# List response
|
418
|
+
responses = [
|
419
|
+
{"role": "assistant", "content": "Thinking..."},
|
420
|
+
{"role": "assistant", "content": "Here's the answer"}
|
421
|
+
]
|
422
|
+
messages = convert_agent_return_types_to_openai_messages(responses, "assistant")
|
423
|
+
# Result: Same list, validated and normalized
|
424
|
+
```
|
425
|
+
"""
|
184
426
|
if isinstance(agent_response, ScenarioResult):
|
185
427
|
raise ValueError(
|
186
428
|
"Unexpectedly tried to convert a ScenarioResult to openai messages",
|
@@ -199,6 +441,7 @@ def convert_agent_return_types_to_openai_messages(
|
|
199
441
|
exclude_unset=True,
|
200
442
|
exclude_none=True,
|
201
443
|
exclude_defaults=True,
|
444
|
+
warnings=False,
|
202
445
|
),
|
203
446
|
)
|
204
447
|
else:
|
@@ -236,11 +479,16 @@ def reverse_roles(
|
|
236
479
|
messages: The list of messages to reverse the roles of.
|
237
480
|
"""
|
238
481
|
|
239
|
-
|
482
|
+
reversed_messages = []
|
483
|
+
for message in messages:
|
484
|
+
message = copy.deepcopy(message)
|
240
485
|
# Can't reverse tool calls
|
241
486
|
if not safe_attr_or_key(message, "content") or safe_attr_or_key(
|
242
487
|
message, "tool_calls"
|
243
488
|
):
|
489
|
+
# If no content nor tool calls, we should skip it entirely, as anthropic may generate some invalid ones e.g. pure {"role": "assistant"}
|
490
|
+
if safe_attr_or_key(message, "tool_calls"):
|
491
|
+
reversed_messages.append(message)
|
244
492
|
continue
|
245
493
|
|
246
494
|
if type(message) == dict:
|
@@ -254,7 +502,9 @@ def reverse_roles(
|
|
254
502
|
elif getattr(message, "role", None) == "assistant":
|
255
503
|
message.role = "user" # type: ignore
|
256
504
|
|
257
|
-
|
505
|
+
reversed_messages.append(message)
|
506
|
+
|
507
|
+
return reversed_messages
|
258
508
|
|
259
509
|
|
260
510
|
async def await_if_awaitable(value: T) -> T:
|
@@ -1,16 +0,0 @@
|
|
1
|
-
scenario/__init__.py,sha256=0OavO4hoZMFL6frlplNkR7BSHfGSOhuVtmKmTrOMFEs,844
|
2
|
-
scenario/cache.py,sha256=sYu16SAf-BnVYkWSlEDzpyynJGIQyNYsgMXPgCqEnmk,1719
|
3
|
-
scenario/config.py,sha256=NiCCmr8flds-VDzvF8ps4SChVTARtcWfEoHhK0UkDMQ,1076
|
4
|
-
scenario/error_messages.py,sha256=8_pa3HIaqkw08qOqeiRKDCNykr9jtofpNJoEV03aRWc,4690
|
5
|
-
scenario/pytest_plugin.py,sha256=oJtEPVPi5x50Z-UawVyVPNd6buvh_4msSZ-3hLFpw_Y,5770
|
6
|
-
scenario/scenario.py,sha256=K4Snu4-pJaoprEFyly7ZQT8qNlAamxt-eXibCJ0EIJU,7332
|
7
|
-
scenario/scenario_agent_adapter.py,sha256=Y2dP3z-2jLYCssQ20oHOphwwrRPQNo2HmLD2KBcJRu0,427
|
8
|
-
scenario/scenario_executor.py,sha256=geaP3Znd1he66L6ku3l2IAODj68TtAIk8b8Ssy494xA,15681
|
9
|
-
scenario/testing_agent.py,sha256=5S2PIl2hi9FBSVjjs9afXhEgiogryjBIyffH5iJBwdo,10676
|
10
|
-
scenario/types.py,sha256=-Uz0qg_fY5vAEkrZnM5CMqE5hiP8OtNErpDdHJmHtac,3179
|
11
|
-
scenario/utils.py,sha256=bx813RpZO3xyPfD-dTBbeLM9umWm3PGOq9pw48aJoHI,8113
|
12
|
-
langwatch_scenario-0.3.0.dist-info/METADATA,sha256=pywrVOVE2eE4Zk5wePzJoEfErNXWvgK-C8G-qfWp7EI,11040
|
13
|
-
langwatch_scenario-0.3.0.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
14
|
-
langwatch_scenario-0.3.0.dist-info/entry_points.txt,sha256=WlEnJ_gku0i18bIa3DSuGqXRX-QDQLe_s0YmRzK45TI,45
|
15
|
-
langwatch_scenario-0.3.0.dist-info/top_level.txt,sha256=45Mn28aedJsetnBMB5xSmrJ-yo701QLH89Zlz4r1clE,9
|
16
|
-
langwatch_scenario-0.3.0.dist-info/RECORD,,
|
scenario/scenario.py
DELETED
@@ -1,238 +0,0 @@
|
|
1
|
-
"""
|
2
|
-
Scenario module: defines the core Scenario class for agent testing.
|
3
|
-
"""
|
4
|
-
|
5
|
-
from typing import (
|
6
|
-
Awaitable,
|
7
|
-
Callable,
|
8
|
-
List,
|
9
|
-
Dict,
|
10
|
-
Any,
|
11
|
-
Optional,
|
12
|
-
Type,
|
13
|
-
TypedDict,
|
14
|
-
Union,
|
15
|
-
)
|
16
|
-
import asyncio
|
17
|
-
import concurrent.futures
|
18
|
-
|
19
|
-
from scenario.config import ScenarioConfig
|
20
|
-
from scenario.error_messages import (
|
21
|
-
default_config_error_message,
|
22
|
-
message_invalid_agent_type,
|
23
|
-
)
|
24
|
-
from scenario.scenario_agent_adapter import ScenarioAgentAdapter
|
25
|
-
from scenario.scenario_executor import ScenarioExecutor
|
26
|
-
|
27
|
-
from .types import ScenarioResult, ScriptStep
|
28
|
-
|
29
|
-
from openai.types.chat import ChatCompletionMessageParam
|
30
|
-
|
31
|
-
|
32
|
-
# Dict shape an agent may return. Every key is optional (total=False):
#   message  - a single text response
#   messages - a list of OpenAI-format chat messages
#   extra    - free-form additional data
AgentResult = TypedDict(
    "AgentResult",
    {
        "message": str,
        "messages": List[ChatCompletionMessageParam],
        "extra": Dict[str, Any],
    },
    total=False,
)
|
36
|
-
|
37
|
-
|
38
|
-
class Scenario(ScenarioConfig):
    """
    A scenario represents a specific testing case for an agent.

    It includes:
    - A description of the scenario
    - Criteria to determine if the agent behaved correctly
    - Optional additional parameters

    Per-instance configuration (testing agent, max turns, verbosity, cache key,
    debug) is merged with the class-wide defaults stored by
    ``Scenario.configure(...)`` when such defaults exist.
    """

    name: str
    description: str
    agents: List[Type[ScenarioAgentAdapter]]
    criteria: List[str]

    def __init__(
        self,
        name: str,
        description: str,
        criteria: Optional[List[str]] = None,
        agent: Optional[Type[ScenarioAgentAdapter]] = None,
        testing_agent: Optional[Type[ScenarioAgentAdapter]] = None,
        agents: Optional[List[Type[ScenarioAgentAdapter]]] = None,
        max_turns: Optional[int] = None,
        verbose: Optional[Union[bool, int]] = None,
        cache_key: Optional[str] = None,
        debug: Optional[bool] = None,
    ):
        """
        Validate the arguments and initialise the scenario.

        Args:
            name: Scenario name; must be non-empty.
            description: Scenario description; must be non-empty.
            criteria: Criteria stored on the scenario; defaults to an empty list.
            agent: Adapter class for the agent under test. Required unless
                ``agents`` is given.
            testing_agent: Adapter class placed first in the agents list when
                ``agents`` is not given; falls back to the configured default.
            agents: Explicit list of adapter classes. When non-empty it is used
                instead of ``[testing_agent, agent]``.
            max_turns: Forwarded to ScenarioConfig; must not be below 1.
            verbose: Forwarded to ScenarioConfig.
            cache_key: Forwarded to ScenarioConfig.
            debug: Forwarded to ScenarioConfig.

        Raises:
            ValueError: If name or description is empty, max_turns is below 1,
                neither ``agent`` nor ``agents`` is provided, or an entry in the
                agents list is not a ScenarioAgentAdapter subclass.
            Exception: If ``agents`` is not given and no testing agent is
                available (``default_config_error_message``).
        """
        # Fresh list per call instead of a shared mutable default argument.
        if criteria is None:
            criteria = []

        config = ScenarioConfig(
            testing_agent=testing_agent,
            max_turns=max_turns,
            verbose=verbose,
            cache_key=cache_key,
            debug=debug,
        )

        kwargs = config.items()
        # Class-wide defaults registered through ``configure``.
        # NOTE(review): this reads ``Scenario`` rather than ``type(self)``, so
        # defaults set via ``SubClass.configure(...)`` are not picked up here.
        default_config: Optional[ScenarioConfig] = getattr(
            Scenario, "default_config", None
        )
        if default_config:
            kwargs = default_config.merge(config).items()

        if not name:
            raise ValueError("Scenario name cannot be empty")
        kwargs["name"] = name

        if not description:
            raise ValueError("Scenario description cannot be empty")
        kwargs["description"] = description

        kwargs["criteria"] = criteria

        # Guard against an explicit None entry, which would otherwise make the
        # comparison raise TypeError instead of a meaningful validation error.
        effective_max_turns = kwargs.get("max_turns", 10)
        if effective_max_turns is not None and effective_max_turns < 1:
            raise ValueError("max_turns must be a positive integer")

        if not agents and not agent:
            raise ValueError(
                "Missing required argument `agent`. Either `agent` or `agents` argument must be provided for the Scenario"
            )

        if not agents and not kwargs.get("testing_agent"):
            raise Exception(default_config_error_message)

        # Without an explicit list, pair the testing agent with the agent under test.
        agents = agents or [
            kwargs.get("testing_agent"),
            agent,  # type: ignore
        ]

        # Ensure each entry is a ScenarioAgentAdapter subclass (a class, not an instance).
        for candidate in agents:
            if (
                not candidate
                or not isinstance(candidate, type)
                or not issubclass(candidate, ScenarioAgentAdapter)
            ):
                raise ValueError(message_invalid_agent_type(candidate))
        kwargs["agents"] = agents

        super().__init__(**kwargs)

    def script(self, script: List[ScriptStep]):
        """
        Bind a list of script steps to this scenario.

        Args:
            script: Steps to execute, in order.

        Returns:
            An object whose async ``run(context=None)`` executes this scenario
            with the given script.
        """

        class ScriptedScenario:
            def __init__(self, scenario: "Scenario"):
                self._scenario = scenario

            async def run(
                self, context: Optional[Dict[str, Any]] = None
            ) -> ScenarioResult:
                return await self._scenario._run(context, script)

        return ScriptedScenario(self)

    async def run(self, context: Optional[Dict[str, Any]] = None) -> ScenarioResult:
        """
        Run the scenario against the agent under test.

        Args:
            context: Optional initial context for the agent

        Returns:
            ScenarioResult containing the test outcome
        """

        return await self._run(context, None)

    async def _run(
        self,
        context: Optional[Dict[str, Any]] = None,
        script: Optional[List[ScriptStep]] = None,
    ) -> ScenarioResult:
        """Execute the scenario on a fresh event loop inside a worker thread."""
        # We'll use a thread pool to run the execution logic, we
        # require a separate thread because even though asyncio is
        # being used throughout, any user code on the callback can
        # be blocking, preventing them from running scenarios in parallel
        with concurrent.futures.ThreadPoolExecutor() as executor:

            def run_in_thread():
                loop = asyncio.new_event_loop()
                asyncio.set_event_loop(loop)

                try:
                    return loop.run_until_complete(
                        ScenarioExecutor(self, context, script).run()
                    )
                finally:
                    loop.close()

            # We are inside a coroutine, so the running loop is guaranteed to
            # exist; get_running_loop() is the correct API here (calling
            # get_event_loop() from coroutines is discouraged/deprecated).
            running_loop = asyncio.get_running_loop()
            return await running_loop.run_in_executor(executor, run_in_thread)

    @classmethod
    def configure(
        cls,
        testing_agent: Optional[Type[ScenarioAgentAdapter]] = None,
        max_turns: Optional[int] = None,
        verbose: Optional[Union[bool, int]] = None,
        cache_key: Optional[str] = None,
        debug: Optional[bool] = None,
    ) -> None:
        """
        Merge the given settings into the class-wide ``default_config``.

        Repeated calls are merged into the previously stored config via
        ``ScenarioConfig.merge``.
        """
        existing_config = getattr(cls, "default_config", ScenarioConfig())

        cls.default_config = existing_config.merge(
            ScenarioConfig(
                testing_agent=testing_agent,
                max_turns=max_turns,
                verbose=verbose,
                cache_key=cache_key,
                debug=debug,
            )
        )

    # Scenario Scripting

    def message(self, message: ChatCompletionMessageParam) -> ScriptStep:
        """Return a script step that calls ``state.message(message)``."""
        return lambda state: state.message(message)

    def user(
        self, content: Optional[Union[str, ChatCompletionMessageParam]] = None
    ) -> ScriptStep:
        """Return a script step that calls ``state.user(content)``."""
        return lambda state: state.user(content)

    def agent(
        self, content: Optional[Union[str, ChatCompletionMessageParam]] = None
    ) -> ScriptStep:
        """Return a script step that calls ``state.agent(content)``."""
        return lambda state: state.agent(content)

    def judge(
        self, content: Optional[Union[str, ChatCompletionMessageParam]] = None
    ) -> ScriptStep:
        """Return a script step that calls ``state.judge(content)``."""
        return lambda state: state.judge(content)

    def proceed(
        self,
        turns: Optional[int] = None,
        on_turn: Optional[
            Union[
                Callable[[ScenarioExecutor], None],
                Callable[[ScenarioExecutor], Awaitable[None]],
            ]
        ] = None,
        on_step: Optional[
            Union[
                Callable[[ScenarioExecutor], None],
                Callable[[ScenarioExecutor], Awaitable[None]],
            ]
        ] = None,
    ) -> ScriptStep:
        """Return a script step that calls ``state.proceed(turns, on_turn, on_step)``."""
        return lambda state: state.proceed(turns, on_turn, on_step)

    def succeed(self) -> ScriptStep:
        """Return a script step that calls ``state.succeed()``."""
        return lambda state: state.succeed()

    def fail(self) -> ScriptStep:
        """Return a script step that calls ``state.fail()``."""
        return lambda state: state.fail()
|
@@ -1,16 +0,0 @@
|
|
1
|
-
from abc import ABC, abstractmethod
|
2
|
-
from typing import ClassVar, Set
|
3
|
-
|
4
|
-
from .types import AgentInput, AgentReturnTypes, ScenarioAgentRole
|
5
|
-
|
6
|
-
|
7
|
-
class ScenarioAgentAdapter(ABC):
    """
    Abstract base class for adapters that plug an agent into a scenario.

    Concrete adapters must implement the async ``call`` method, which receives
    an ``AgentInput`` and returns one of the ``AgentReturnTypes``.
    """

    # Roles this adapter class declares; the default set is {ScenarioAgentRole.AGENT}.
    roles: ClassVar[Set[ScenarioAgentRole]] = {ScenarioAgentRole.AGENT}

    def __init__(self, input: AgentInput):
        # The constructor argument is accepted but not stored by the base class.
        super().__init__()

    @abstractmethod
    async def call(self, input: AgentInput) -> AgentReturnTypes:
        """Produce this adapter's response for the given input."""
|