langwatch-scenario 0.7.8__py3-none-any.whl → 0.7.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {langwatch_scenario-0.7.8.dist-info → langwatch_scenario-0.7.10.dist-info}/METADATA +4 -3
- {langwatch_scenario-0.7.8.dist-info → langwatch_scenario-0.7.10.dist-info}/RECORD +19 -18
- scenario/_events/event_alert_message_logger.py +20 -29
- scenario/_events/event_bus.py +4 -1
- scenario/_events/event_reporter.py +8 -3
- scenario/_events/utils.py +44 -28
- scenario/_utils/__init__.py +2 -2
- scenario/_utils/ids.py +12 -12
- scenario/config/scenario.py +8 -0
- scenario/judge_agent.py +4 -3
- scenario/py.typed +0 -0
- scenario/pytest_plugin.py +5 -0
- scenario/scenario_executor.py +118 -60
- scenario/scenario_state.py +2 -1
- scenario/types.py +54 -2
- scenario/user_simulator_agent.py +3 -2
- {langwatch_scenario-0.7.8.dist-info → langwatch_scenario-0.7.10.dist-info}/WHEEL +0 -0
- {langwatch_scenario-0.7.8.dist-info → langwatch_scenario-0.7.10.dist-info}/entry_points.txt +0 -0
- {langwatch_scenario-0.7.8.dist-info → langwatch_scenario-0.7.10.dist-info}/top_level.txt +0 -0
@@ -1,6 +1,6 @@
|
|
1
1
|
Metadata-Version: 2.4
|
2
2
|
Name: langwatch-scenario
|
3
|
-
Version: 0.7.
|
3
|
+
Version: 0.7.10
|
4
4
|
Summary: The end-to-end agent testing library
|
5
5
|
Author-email: LangWatch Team <support@langwatch.ai>
|
6
6
|
License: MIT
|
@@ -14,7 +14,7 @@ Classifier: Programming Language :: Python :: 3.8
|
|
14
14
|
Classifier: Programming Language :: Python :: 3.9
|
15
15
|
Classifier: Programming Language :: Python :: 3.10
|
16
16
|
Classifier: Programming Language :: Python :: 3.11
|
17
|
-
Requires-Python: >=3.
|
17
|
+
Requires-Python: >=3.10
|
18
18
|
Description-Content-Type: text/markdown
|
19
19
|
Requires-Dist: pytest>=8.1.1
|
20
20
|
Requires-Dist: litellm>=1.49.0
|
@@ -31,6 +31,7 @@ Requires-Dist: httpx>=0.27.0
|
|
31
31
|
Requires-Dist: rx>=3.2.0
|
32
32
|
Requires-Dist: python-dateutil>=2.9.0.post0
|
33
33
|
Requires-Dist: pydantic-settings>=2.9.1
|
34
|
+
Requires-Dist: langwatch>=0.2.19
|
34
35
|
Provides-Extra: dev
|
35
36
|
Requires-Dist: black; extra == "dev"
|
36
37
|
Requires-Dist: isort; extra == "dev"
|
@@ -457,7 +458,7 @@ This will cache any function call you decorate when running the tests and make t
|
|
457
458
|
While optional, we strongly recommend setting stable identifiers for your scenarios, sets, and batches for better organization and tracking in LangWatch.
|
458
459
|
|
459
460
|
- **set_id**: Groups related scenarios into a test suite. This corresponds to the "Simulation Set" in the UI.
|
460
|
-
- **
|
461
|
+
- **SCENARIO_BATCH_RUN_ID**: Env variable that groups all scenarios that were run together in a single execution (e.g., a single CI job). This is automatically generated but can be overridden.
|
461
462
|
|
462
463
|
```python
|
463
464
|
import os
|
@@ -2,20 +2,21 @@ scenario/__init__.py,sha256=4WO8TjY8Lc0NhYL7b9LvaB1xCBqwUkLuI0uIA6PQP6c,4223
|
|
2
2
|
scenario/_error_messages.py,sha256=QVFSbhzsVNGz2GOBOaoQFW6w6AOyZCWLTt0ySWPfnGw,3882
|
3
3
|
scenario/agent_adapter.py,sha256=PoY2KQqYuqzIIb3-nhIU-MPXwHJc1vmwdweMy7ut-hk,4255
|
4
4
|
scenario/cache.py,sha256=J6s6Sia_Ce6TrnsInlhfxm6SF8tygo3sH-_cQCRX1WA,6213
|
5
|
-
scenario/judge_agent.py,sha256=
|
6
|
-
scenario/
|
7
|
-
scenario/
|
8
|
-
scenario/
|
5
|
+
scenario/judge_agent.py,sha256=hHQ2nKsOgSyTtN0LdE6xIF0wZnnlYLN6RcxTPecFHDU,16770
|
6
|
+
scenario/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
7
|
+
scenario/pytest_plugin.py,sha256=wRCuGD9uwrrLt2fY15zK6mnmY9W_dO_m0WalPJYE5II,11491
|
8
|
+
scenario/scenario_executor.py,sha256=v41UgSHebosXf95FfYIeVUm6s4IbMP_U58FdGoZ_kZU,35653
|
9
|
+
scenario/scenario_state.py,sha256=R8PhPHW3obYo3DCjBH5XDdZ6bp4uol7wCXO8K2Tz30I,7101
|
9
10
|
scenario/script.py,sha256=A0N5pP0l4FFn1xdKc78U_wkwWhEWH3EFeU_LRDtNyEI,12241
|
10
|
-
scenario/types.py,sha256=
|
11
|
-
scenario/user_simulator_agent.py,sha256=
|
11
|
+
scenario/types.py,sha256=CRSCHUplXEXhj6EYQsncwJBzbd2128YTGlFxlk-rrG8,11193
|
12
|
+
scenario/user_simulator_agent.py,sha256=gXRaeoivEAcenIEqMDU6bWzv8cOrJaaooNrTdpC9TE4,9630
|
12
13
|
scenario/_events/__init__.py,sha256=4cj6H9zuXzvWhT2P2JNdjWzeF1PUepTjqIDw85Vid9s,1500
|
13
|
-
scenario/_events/event_alert_message_logger.py,sha256=
|
14
|
-
scenario/_events/event_bus.py,sha256=
|
15
|
-
scenario/_events/event_reporter.py,sha256
|
14
|
+
scenario/_events/event_alert_message_logger.py,sha256=XcofGgXjeiTC75NPYheBpHxqA6R4pYAuHZa7-kH9Grg,2975
|
15
|
+
scenario/_events/event_bus.py,sha256=IsKNsClF1JFYj728EcxX1hw_KbfDkfJq3Y2Kv4h94n4,9871
|
16
|
+
scenario/_events/event_reporter.py,sha256=-6NNbBMy_FYr1O-1FuZ6eIUnLuI8NGRMUr0pybLJrCI,3873
|
16
17
|
scenario/_events/events.py,sha256=UtEGY-_1B0LrwpgsNKgrvJBZhRtxuj3K_i6ZBfF7E4Q,6387
|
17
18
|
scenario/_events/messages.py,sha256=quwP2OkeaGasNOoaV8GUeosZVKc5XDsde08T0xx_YQo,2297
|
18
|
-
scenario/_events/utils.py,sha256=
|
19
|
+
scenario/_events/utils.py,sha256=CRrdDHBD2ptcNIjzW0eEG1V5-Vw1gFnp_UTz5zMQ_Ak,4051
|
19
20
|
scenario/_generated/langwatch_api_client/README.md,sha256=Az5f2L4ChOnG_ZtrdBagzRVgeTCtBkbD_S5cIeAry2o,5424
|
20
21
|
scenario/_generated/langwatch_api_client/pyproject.toml,sha256=Z8wxuGp4H9BJYVVJB8diW7rRU9XYxtPfw9mU4_wq4cA,560
|
21
22
|
scenario/_generated/langwatch_api_client/lang_watch_api_client/__init__.py,sha256=vVrn17y-3l3fOqeJk8aN3GlStRm2fo0f313l_0LtJNs,368
|
@@ -226,16 +227,16 @@ scenario/_generated/langwatch_api_client/lang_watch_api_client/models/search_req
|
|
226
227
|
scenario/_generated/langwatch_api_client/lang_watch_api_client/models/search_response.py,sha256=zDYmJ8bFBSJyF9D3cEn_ffrey-ITIfwr-_7eu72zLyk,2832
|
227
228
|
scenario/_generated/langwatch_api_client/lang_watch_api_client/models/timestamps.py,sha256=-nRKUPZTAJQNxiKz128xF7DKgZNbFo4G3mr5xNXrkaw,2173
|
228
229
|
scenario/_generated/langwatch_api_client/lang_watch_api_client/models/trace.py,sha256=K9Lc_EQOrJ2dqMXx9EpiUXReT1_uYF7WRfYyhlfbi3I,7537
|
229
|
-
scenario/_utils/__init__.py,sha256=
|
230
|
-
scenario/_utils/ids.py,sha256=
|
230
|
+
scenario/_utils/__init__.py,sha256=xPVjLXnHTTq9fuRFh5lsMvwtIpEeJ3jy1vf5yTUMPsc,1313
|
231
|
+
scenario/_utils/ids.py,sha256=W4tVMCf9ky0KLTDA_qOfErNhb4tCmxwa8zEuo1K1ZuY,2071
|
231
232
|
scenario/_utils/message_conversion.py,sha256=AWHn31E7J0mz9sBXWruVVAgtsrJz1R_xEf-dGbX6jjs,3636
|
232
233
|
scenario/_utils/utils.py,sha256=msQgUWaLh3U9jIIHmxkEbOaklga63AF0KJzsaKa_mZc,14008
|
233
234
|
scenario/config/__init__.py,sha256=b2X_bqkIrd7jZY9dRrXk2wOqoPe87Nl_SRGuZhlolxA,1123
|
234
235
|
scenario/config/langwatch.py,sha256=ijWchFbUsLbQooAZmwyTw4rxfRLQseZ1GoVSiPPbzpw,1677
|
235
236
|
scenario/config/model.py,sha256=T4HYA79CW1NxXDkFlyftYR6JzZcowbtIx0H-ijxRyfg,1297
|
236
|
-
scenario/config/scenario.py,sha256=
|
237
|
-
langwatch_scenario-0.7.
|
238
|
-
langwatch_scenario-0.7.
|
239
|
-
langwatch_scenario-0.7.
|
240
|
-
langwatch_scenario-0.7.
|
241
|
-
langwatch_scenario-0.7.
|
237
|
+
scenario/config/scenario.py,sha256=6jrtcm0Fo7FpxQta7QIKdGMgl7cXrn374Inzx29hRuk,5406
|
238
|
+
langwatch_scenario-0.7.10.dist-info/METADATA,sha256=pbLZM8UXj1_1TWHjheHP6QREOvRWfX7nHEdfY2ZX4aA,20065
|
239
|
+
langwatch_scenario-0.7.10.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
240
|
+
langwatch_scenario-0.7.10.dist-info/entry_points.txt,sha256=WlEnJ_gku0i18bIa3DSuGqXRX-QDQLe_s0YmRzK45TI,45
|
241
|
+
langwatch_scenario-0.7.10.dist-info/top_level.txt,sha256=45Mn28aedJsetnBMB5xSmrJ-yo701QLH89Zlz4r1clE,9
|
242
|
+
langwatch_scenario-0.7.10.dist-info/RECORD,,
|
@@ -1,5 +1,8 @@
|
|
1
1
|
import os
|
2
|
+
import webbrowser
|
2
3
|
from typing import Set
|
4
|
+
|
5
|
+
from ..config.scenario import ScenarioConfig
|
3
6
|
from .._utils.ids import get_batch_run_id
|
4
7
|
|
5
8
|
|
@@ -12,6 +15,7 @@ class EventAlertMessageLogger:
|
|
12
15
|
"""
|
13
16
|
|
14
17
|
_shown_batch_ids: Set[str] = set()
|
18
|
+
_shown_watch_urls: Set[str] = set()
|
15
19
|
|
16
20
|
def handle_greeting(self) -> None:
|
17
21
|
"""
|
@@ -37,6 +41,10 @@ class EventAlertMessageLogger:
|
|
37
41
|
if self._is_greeting_disabled():
|
38
42
|
return
|
39
43
|
|
44
|
+
if set_url in EventAlertMessageLogger._shown_watch_urls:
|
45
|
+
return
|
46
|
+
|
47
|
+
EventAlertMessageLogger._shown_watch_urls.add(set_url)
|
40
48
|
self._display_watch_message(set_url)
|
41
49
|
|
42
50
|
def _is_greeting_disabled(self) -> bool:
|
@@ -49,35 +57,13 @@ class EventAlertMessageLogger:
|
|
49
57
|
|
50
58
|
if not os.getenv("LANGWATCH_API_KEY"):
|
51
59
|
print(f"\n{separator}")
|
52
|
-
print("
|
60
|
+
print("🎭 Running Scenario Tests")
|
53
61
|
print(f"{separator}")
|
54
|
-
print("➡️ API key not configured")
|
62
|
+
print("➡️ LangWatch API key not configured")
|
55
63
|
print(" Simulations will only output final results")
|
56
64
|
print("")
|
57
65
|
print("💡 To visualize conversations in real time:")
|
58
66
|
print(" • Set LANGWATCH_API_KEY environment variable")
|
59
|
-
print(" • Or configure apiKey in scenario.config.js")
|
60
|
-
print("")
|
61
|
-
print(f"📦 Batch Run ID: {batch_run_id}")
|
62
|
-
print("")
|
63
|
-
print("🔇 To disable these messages:")
|
64
|
-
print(" • Set SCENARIO_DISABLE_SIMULATION_REPORT_INFO=true")
|
65
|
-
print(f"{separator}\n")
|
66
|
-
else:
|
67
|
-
endpoint = os.getenv("LANGWATCH_ENDPOINT", "https://app.langwatch.ai")
|
68
|
-
api_key = os.getenv("LANGWATCH_API_KEY", "")
|
69
|
-
|
70
|
-
print(f"\n{separator}")
|
71
|
-
print("🚀 LangWatch Simulation Reporting")
|
72
|
-
print(f"{separator}")
|
73
|
-
print("✅ Simulation reporting enabled")
|
74
|
-
print(f" Endpoint: {endpoint}")
|
75
|
-
print(f" API Key: {'Configured' if api_key else 'Not configured'}")
|
76
|
-
print("")
|
77
|
-
print(f"📦 Batch Run ID: {batch_run_id}")
|
78
|
-
print("")
|
79
|
-
print("🔇 To disable these messages:")
|
80
|
-
print(" • Set SCENARIO_DISABLE_SIMULATION_REPORT_INFO=true")
|
81
67
|
print(f"{separator}\n")
|
82
68
|
|
83
69
|
def _display_watch_message(self, set_url: str) -> None:
|
@@ -86,10 +72,15 @@ class EventAlertMessageLogger:
|
|
86
72
|
batch_url = f"{set_url}/{get_batch_run_id()}"
|
87
73
|
|
88
74
|
print(f"\n{separator}")
|
89
|
-
print("
|
75
|
+
print("🎭 Running Scenario Tests")
|
90
76
|
print(f"{separator}")
|
91
|
-
print("
|
92
|
-
print(f" Scenario Set: {set_url}")
|
93
|
-
print(f" Batch Run: {batch_url}")
|
94
|
-
print("")
|
77
|
+
print(f"Follow it live: {batch_url}")
|
95
78
|
print(f"{separator}\n")
|
79
|
+
|
80
|
+
config = ScenarioConfig.default_config
|
81
|
+
if config and not config.headless:
|
82
|
+
# Open the URL in the default browser (cross-platform)
|
83
|
+
try:
|
84
|
+
webbrowser.open(batch_url)
|
85
|
+
except Exception:
|
86
|
+
pass
|
scenario/_events/event_bus.py
CHANGED
@@ -3,6 +3,7 @@ from typing import Optional, Any, Dict
|
|
3
3
|
from .events import ScenarioEvent
|
4
4
|
from .event_reporter import EventReporter
|
5
5
|
from .event_alert_message_logger import EventAlertMessageLogger
|
6
|
+
from ..config.scenario import ScenarioConfig
|
6
7
|
|
7
8
|
import asyncio
|
8
9
|
import queue
|
@@ -35,7 +36,9 @@ class ScenarioEventBus:
|
|
35
36
|
"""
|
36
37
|
|
37
38
|
def __init__(
|
38
|
-
self,
|
39
|
+
self,
|
40
|
+
event_reporter: Optional[EventReporter] = None,
|
41
|
+
max_retries: int = 3,
|
39
42
|
):
|
40
43
|
"""
|
41
44
|
Initialize the event bus with optional event reporter and retry configuration.
|
@@ -3,7 +3,7 @@ import httpx
|
|
3
3
|
from typing import Optional, Dict, Any
|
4
4
|
from .events import ScenarioEvent
|
5
5
|
from .event_alert_message_logger import EventAlertMessageLogger
|
6
|
-
from scenario.config import LangWatchSettings
|
6
|
+
from scenario.config import LangWatchSettings, ScenarioConfig
|
7
7
|
|
8
8
|
|
9
9
|
class EventReporter:
|
@@ -26,7 +26,11 @@ class EventReporter:
|
|
26
26
|
reporter = EventReporter(api_key="your-api-key")
|
27
27
|
"""
|
28
28
|
|
29
|
-
def __init__(
|
29
|
+
def __init__(
|
30
|
+
self,
|
31
|
+
endpoint: Optional[str] = None,
|
32
|
+
api_key: Optional[str] = None,
|
33
|
+
):
|
30
34
|
# Load settings from environment variables
|
31
35
|
langwatch_settings = LangWatchSettings()
|
32
36
|
|
@@ -69,6 +73,7 @@ class EventReporter:
|
|
69
73
|
"Content-Type": "application/json",
|
70
74
|
"X-Auth-Token": self.api_key,
|
71
75
|
},
|
76
|
+
timeout=httpx.Timeout(30.0),
|
72
77
|
)
|
73
78
|
self.logger.info(
|
74
79
|
f"[{event_type}] POST response status: {response.status_code} ({event.scenario_run_id})"
|
@@ -92,7 +97,7 @@ class EventReporter:
|
|
92
97
|
)
|
93
98
|
except Exception as error:
|
94
99
|
self.logger.error(
|
95
|
-
f"[{event_type}] Event POST error: {error}, event={event}, endpoint={self.endpoint}"
|
100
|
+
f"[{event_type}] Event POST error: {repr(error)}, event={event}, endpoint={self.endpoint}"
|
96
101
|
)
|
97
102
|
|
98
103
|
return result
|
scenario/_events/utils.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1
1
|
import warnings
|
2
|
-
|
2
|
+
|
3
|
+
from ..types import ChatCompletionMessageParamWithTrace
|
3
4
|
from .events import MessageType
|
4
5
|
from .messages import (
|
5
6
|
SystemMessage,
|
@@ -10,9 +11,12 @@ from .messages import (
|
|
10
11
|
FunctionCall,
|
11
12
|
)
|
12
13
|
from typing import List
|
13
|
-
import
|
14
|
+
from pksuid import PKSUID
|
15
|
+
|
14
16
|
|
15
|
-
def convert_messages_to_api_client_messages(
|
17
|
+
def convert_messages_to_api_client_messages(
|
18
|
+
messages: list[ChatCompletionMessageParamWithTrace],
|
19
|
+
) -> list[MessageType]:
|
16
20
|
"""
|
17
21
|
Converts OpenAI ChatCompletionMessageParam messages to API client Message format.
|
18
22
|
|
@@ -33,7 +37,7 @@ def convert_messages_to_api_client_messages(messages: list[ChatCompletionMessage
|
|
33
37
|
|
34
38
|
for i, message in enumerate(messages):
|
35
39
|
# Generate unique ID for each message
|
36
|
-
message_id = message.get("id") or str(
|
40
|
+
message_id = message.get("id") or str(PKSUID("scenariomsg"))
|
37
41
|
|
38
42
|
role = message.get("role")
|
39
43
|
content = message.get("content")
|
@@ -41,11 +45,13 @@ def convert_messages_to_api_client_messages(messages: list[ChatCompletionMessage
|
|
41
45
|
if role == "user":
|
42
46
|
if not content:
|
43
47
|
raise ValueError(f"User message at index {i} missing required content")
|
44
|
-
|
48
|
+
message_ = UserMessage(
|
45
49
|
id=message_id,
|
46
50
|
role="user",
|
47
|
-
content=str(content)
|
48
|
-
)
|
51
|
+
content=str(content),
|
52
|
+
)
|
53
|
+
message_.additional_properties = {"trace_id": message.get("trace_id")}
|
54
|
+
converted_messages.append(message_)
|
49
55
|
elif role == "assistant":
|
50
56
|
# Handle tool calls if present
|
51
57
|
tool_calls = message.get("tool_calls")
|
@@ -53,44 +59,54 @@ def convert_messages_to_api_client_messages(messages: list[ChatCompletionMessage
|
|
53
59
|
|
54
60
|
if tool_calls:
|
55
61
|
for tool_call in tool_calls:
|
56
|
-
api_tool_calls.append(
|
57
|
-
|
58
|
-
|
59
|
-
|
60
|
-
|
61
|
-
|
62
|
+
api_tool_calls.append(
|
63
|
+
ToolCall(
|
64
|
+
id=tool_call.get("id", str(PKSUID("scenariotoolcall"))),
|
65
|
+
type_="function",
|
66
|
+
function=FunctionCall(
|
67
|
+
name=tool_call["function"].get("name", "unknown"),
|
68
|
+
arguments=tool_call["function"].get("arguments", "{}"),
|
69
|
+
),
|
62
70
|
)
|
63
|
-
)
|
71
|
+
)
|
64
72
|
|
65
|
-
|
73
|
+
message_ = AssistantMessage(
|
66
74
|
id=message_id,
|
67
75
|
role="assistant",
|
68
76
|
content=str(content),
|
69
|
-
tool_calls=api_tool_calls
|
70
|
-
)
|
77
|
+
tool_calls=api_tool_calls,
|
78
|
+
)
|
79
|
+
message_.additional_properties = {"trace_id": message.get("trace_id")}
|
80
|
+
converted_messages.append(message_)
|
71
81
|
elif role == "system":
|
72
82
|
if not content:
|
73
|
-
raise ValueError(
|
74
|
-
|
75
|
-
|
76
|
-
|
77
|
-
|
78
|
-
)
|
83
|
+
raise ValueError(
|
84
|
+
f"System message at index {i} missing required content"
|
85
|
+
)
|
86
|
+
message_ = SystemMessage(id=message_id, role="system", content=str(content))
|
87
|
+
message_.additional_properties = {"trace_id": message.get("trace_id")}
|
88
|
+
converted_messages.append(message_)
|
79
89
|
elif role == "tool":
|
80
90
|
tool_call_id = message.get("tool_call_id")
|
81
91
|
if not tool_call_id:
|
82
|
-
warnings.warn(
|
92
|
+
warnings.warn(
|
93
|
+
f"Tool message at index {i} missing required tool_call_id, skipping tool message"
|
94
|
+
)
|
83
95
|
continue
|
84
96
|
if not content:
|
85
|
-
warnings.warn(
|
97
|
+
warnings.warn(
|
98
|
+
f"Tool message at index {i} missing required content, skipping tool message"
|
99
|
+
)
|
86
100
|
continue
|
87
101
|
|
88
|
-
|
102
|
+
message_ = ToolMessage(
|
89
103
|
id=message_id,
|
90
104
|
role="tool",
|
91
105
|
content=str(content),
|
92
|
-
tool_call_id=tool_call_id
|
93
|
-
)
|
106
|
+
tool_call_id=tool_call_id,
|
107
|
+
)
|
108
|
+
message_.additional_properties = {"trace_id": message.get("trace_id")}
|
109
|
+
converted_messages.append(message_)
|
94
110
|
else:
|
95
111
|
raise ValueError(f"Unsupported message role '{role}' at index {i}")
|
96
112
|
|
scenario/_utils/__init__.py
CHANGED
@@ -14,7 +14,7 @@ from .ids import (
|
|
14
14
|
generate_scenario_id,
|
15
15
|
generate_thread_id,
|
16
16
|
generate_message_id,
|
17
|
-
|
17
|
+
safe_parse_ksuid,
|
18
18
|
)
|
19
19
|
from .utils import (
|
20
20
|
SerializableAndPydanticEncoder,
|
@@ -34,7 +34,7 @@ __all__ = [
|
|
34
34
|
"generate_scenario_id",
|
35
35
|
"generate_thread_id",
|
36
36
|
"generate_message_id",
|
37
|
-
"
|
37
|
+
"safe_parse_ksuid",
|
38
38
|
"SerializableAndPydanticEncoder",
|
39
39
|
"SerializableWithStringFallback",
|
40
40
|
"print_openai_messages",
|
scenario/_utils/ids.py
CHANGED
@@ -7,7 +7,7 @@ and scenario tracking.
|
|
7
7
|
"""
|
8
8
|
|
9
9
|
import os
|
10
|
-
import
|
10
|
+
from pksuid import PKSUID
|
11
11
|
|
12
12
|
|
13
13
|
def generate_thread_id() -> str:
|
@@ -17,7 +17,7 @@ def generate_thread_id() -> str:
|
|
17
17
|
Returns:
|
18
18
|
str: A new thread ID.
|
19
19
|
"""
|
20
|
-
return f"
|
20
|
+
return f"{PKSUID('scenariothread')}"
|
21
21
|
|
22
22
|
|
23
23
|
def generate_scenario_run_id() -> str:
|
@@ -27,7 +27,7 @@ def generate_scenario_run_id() -> str:
|
|
27
27
|
Returns:
|
28
28
|
str: A new scenario run ID.
|
29
29
|
"""
|
30
|
-
return f"
|
30
|
+
return f"{PKSUID('scenariorun')}"
|
31
31
|
|
32
32
|
|
33
33
|
def generate_scenario_id() -> str:
|
@@ -37,7 +37,7 @@ def generate_scenario_id() -> str:
|
|
37
37
|
Returns:
|
38
38
|
str: A new scenario ID.
|
39
39
|
"""
|
40
|
-
return f"
|
40
|
+
return f"{PKSUID('scenario')}"
|
41
41
|
|
42
42
|
|
43
43
|
def get_batch_run_id() -> str:
|
@@ -52,7 +52,7 @@ def get_batch_run_id() -> str:
|
|
52
52
|
batch_run_id = os.environ.get("SCENARIO_BATCH_RUN_ID")
|
53
53
|
if not batch_run_id:
|
54
54
|
# Generate new batch ID if not set
|
55
|
-
batch_run_id = f"
|
55
|
+
batch_run_id = f"{PKSUID('scenariobatch')}"
|
56
56
|
os.environ["SCENARIO_BATCH_RUN_ID"] = batch_run_id
|
57
57
|
|
58
58
|
return batch_run_id
|
@@ -65,23 +65,23 @@ def generate_message_id() -> str:
|
|
65
65
|
Returns:
|
66
66
|
str: A new message ID.
|
67
67
|
"""
|
68
|
-
return f"
|
68
|
+
return f"{PKSUID('scenariomsg')}"
|
69
69
|
|
70
70
|
|
71
|
-
def
|
71
|
+
def safe_parse_ksuid(id_str: str) -> bool:
|
72
72
|
"""
|
73
|
-
Safely parses a
|
73
|
+
Safely parses a Ksuid string.
|
74
74
|
|
75
75
|
Args:
|
76
|
-
id_str: The
|
76
|
+
id_str: The Ksuid string to parse.
|
77
77
|
|
78
78
|
Returns:
|
79
|
-
bool: True if the
|
79
|
+
bool: True if the Ksuid string is valid, false otherwise.
|
80
80
|
"""
|
81
81
|
try:
|
82
|
-
|
82
|
+
PKSUID.parse(id_str)
|
83
83
|
return True
|
84
|
-
except
|
84
|
+
except Exception:
|
85
85
|
return False
|
86
86
|
|
87
87
|
|
scenario/config/scenario.py
CHANGED
@@ -5,6 +5,7 @@ This module provides the main configuration class for customizing the behavior
|
|
5
5
|
of the Scenario testing framework, including execution parameters and debugging options.
|
6
6
|
"""
|
7
7
|
|
8
|
+
import os
|
8
9
|
from typing import Optional, Union, ClassVar
|
9
10
|
from pydantic import BaseModel
|
10
11
|
|
@@ -53,6 +54,11 @@ class ScenarioConfig(BaseModel):
|
|
53
54
|
verbose: Optional[Union[bool, int]] = True
|
54
55
|
cache_key: Optional[str] = None
|
55
56
|
debug: Optional[bool] = False
|
57
|
+
headless: Optional[bool] = os.getenv("SCENARIO_HEADLESS", "false").lower() not in [
|
58
|
+
"false",
|
59
|
+
"0",
|
60
|
+
"",
|
61
|
+
]
|
56
62
|
|
57
63
|
default_config: ClassVar[Optional["ScenarioConfig"]] = None
|
58
64
|
|
@@ -64,6 +70,7 @@ class ScenarioConfig(BaseModel):
|
|
64
70
|
verbose: Optional[Union[bool, int]] = None,
|
65
71
|
cache_key: Optional[str] = None,
|
66
72
|
debug: Optional[bool] = None,
|
73
|
+
headless: Optional[bool] = None,
|
67
74
|
) -> None:
|
68
75
|
"""
|
69
76
|
Set global configuration settings for all scenario executions.
|
@@ -107,6 +114,7 @@ class ScenarioConfig(BaseModel):
|
|
107
114
|
verbose=verbose,
|
108
115
|
cache_key=cache_key,
|
109
116
|
debug=debug,
|
117
|
+
headless=headless,
|
110
118
|
)
|
111
119
|
)
|
112
120
|
|
scenario/judge_agent.py
CHANGED
@@ -12,7 +12,8 @@ import logging
|
|
12
12
|
import re
|
13
13
|
from typing import List, Optional, cast
|
14
14
|
|
15
|
-
|
15
|
+
import litellm
|
16
|
+
from litellm import Choices
|
16
17
|
from litellm.files.main import ModelResponse
|
17
18
|
|
18
19
|
from scenario.cache import scenario_cache
|
@@ -356,7 +357,7 @@ if you don't have enough information to make a verdict, say inconclusive with ma
|
|
356
357
|
|
357
358
|
response = cast(
|
358
359
|
ModelResponse,
|
359
|
-
completion(
|
360
|
+
litellm.completion(
|
360
361
|
model=self.model,
|
361
362
|
messages=messages,
|
362
363
|
temperature=self.temperature,
|
@@ -398,7 +399,7 @@ if you don't have enough information to make a verdict, say inconclusive with ma
|
|
398
399
|
failed_criteria = [
|
399
400
|
self.criteria[idx]
|
400
401
|
for idx, criterion in enumerate(criteria.values())
|
401
|
-
if criterion == False
|
402
|
+
if criterion == False or criterion == "inconclusive"
|
402
403
|
]
|
403
404
|
|
404
405
|
# Return the appropriate ScenarioResult based on the verdict
|
scenario/py.typed
ADDED
File without changes
|
scenario/pytest_plugin.py
CHANGED
@@ -199,6 +199,8 @@ class ScenarioReporter:
|
|
199
199
|
# Store the original run method
|
200
200
|
original_run = ScenarioExecutor.run
|
201
201
|
|
202
|
+
def pytest_addoption(parser):
|
203
|
+
parser.addoption("--headless", action="store_true")
|
202
204
|
|
203
205
|
@pytest.hookimpl(trylast=True)
|
204
206
|
def pytest_configure(config):
|
@@ -240,6 +242,9 @@ def pytest_configure(config):
|
|
240
242
|
print(colored("\nScenario debug mode enabled (--debug).", "yellow"))
|
241
243
|
ScenarioConfig.configure(verbose=True, debug=True)
|
242
244
|
|
245
|
+
if config.getoption("--headless"):
|
246
|
+
ScenarioConfig.configure(headless=True)
|
247
|
+
|
243
248
|
# Create a global reporter instance
|
244
249
|
config._scenario_reporter = ScenarioReporter()
|
245
250
|
|
scenario/scenario_executor.py
CHANGED
@@ -6,6 +6,7 @@ of scenario tests, managing the interaction between user simulators, agents unde
|
|
6
6
|
and judge agents to determine test success or failure.
|
7
7
|
"""
|
8
8
|
|
9
|
+
import json
|
9
10
|
import sys
|
10
11
|
from typing import (
|
11
12
|
Awaitable,
|
@@ -17,6 +18,7 @@ from typing import (
|
|
17
18
|
Tuple,
|
18
19
|
Union,
|
19
20
|
TypedDict,
|
21
|
+
cast,
|
20
22
|
)
|
21
23
|
import time
|
22
24
|
import warnings
|
@@ -33,6 +35,7 @@ from scenario._utils import (
|
|
33
35
|
await_if_awaitable,
|
34
36
|
get_batch_run_id,
|
35
37
|
generate_scenario_run_id,
|
38
|
+
SerializableWithStringFallback,
|
36
39
|
)
|
37
40
|
from openai.types.chat import (
|
38
41
|
ChatCompletionMessageParam,
|
@@ -40,7 +43,7 @@ from openai.types.chat import (
|
|
40
43
|
ChatCompletionAssistantMessageParam,
|
41
44
|
)
|
42
45
|
|
43
|
-
from .types import AgentInput, AgentRole, ScenarioResult, ScriptStep
|
46
|
+
from .types import AgentInput, AgentRole, ChatCompletionMessageParamWithTrace, ScenarioResult, ScriptStep
|
44
47
|
from ._error_messages import agent_response_not_awaitable
|
45
48
|
from .cache import context_scenario
|
46
49
|
from .agent_adapter import AgentAdapter
|
@@ -62,6 +65,11 @@ from ._events import (
|
|
62
65
|
from rx.subject.subject import Subject
|
63
66
|
from rx.core.observable.observable import Observable
|
64
67
|
|
68
|
+
import litellm
|
69
|
+
import langwatch
|
70
|
+
import langwatch.telemetry.context
|
71
|
+
from langwatch.telemetry.tracing import LangWatchTrace
|
72
|
+
|
65
73
|
|
66
74
|
class ScenarioExecutor:
|
67
75
|
"""
|
@@ -101,6 +109,7 @@ class ScenarioExecutor:
|
|
101
109
|
_pending_agents_on_turn: Set[AgentAdapter] = set()
|
102
110
|
_agent_times: Dict[int, float] = {}
|
103
111
|
_events: Subject
|
112
|
+
_trace: LangWatchTrace
|
104
113
|
|
105
114
|
event_bus: ScenarioEventBus
|
106
115
|
|
@@ -153,10 +162,12 @@ class ScenarioExecutor:
|
|
153
162
|
verbose=verbose,
|
154
163
|
cache_key=cache_key,
|
155
164
|
debug=debug,
|
165
|
+
headless=None,
|
156
166
|
)
|
157
167
|
self.config = (ScenarioConfig.default_config or ScenarioConfig()).merge(config)
|
158
168
|
|
159
|
-
self.
|
169
|
+
self.batch_run_id = get_batch_run_id()
|
170
|
+
self.scenario_set_id = set_id or "default"
|
160
171
|
|
161
172
|
# Create executor's own event stream
|
162
173
|
self._events = Subject()
|
@@ -165,9 +176,6 @@ class ScenarioExecutor:
|
|
165
176
|
self.event_bus = event_bus or ScenarioEventBus()
|
166
177
|
self.event_bus.subscribe_to_events(self._events)
|
167
178
|
|
168
|
-
self.batch_run_id = get_batch_run_id()
|
169
|
-
self.scenario_set_id = set_id or "default"
|
170
|
-
|
171
179
|
@property
|
172
180
|
def events(self) -> Observable:
|
173
181
|
"""Expose event stream for subscribers like the event bus."""
|
@@ -198,7 +206,7 @@ class ScenarioExecutor:
|
|
198
206
|
self._state = ScenarioState(
|
199
207
|
description=self.description,
|
200
208
|
messages=[],
|
201
|
-
thread_id=str(PKSUID("
|
209
|
+
thread_id=str(PKSUID("scenariothread")),
|
202
210
|
current_turn=0,
|
203
211
|
config=self.config,
|
204
212
|
_executor=self,
|
@@ -252,6 +260,8 @@ class ScenarioExecutor:
|
|
252
260
|
)
|
253
261
|
```
|
254
262
|
"""
|
263
|
+
message = cast(ChatCompletionMessageParamWithTrace, message)
|
264
|
+
message["trace_id"] = self._trace.trace_id
|
255
265
|
self._state.messages.append(message)
|
256
266
|
|
257
267
|
# Broadcast the message to other agents
|
@@ -262,6 +272,21 @@ class ScenarioExecutor:
|
|
262
272
|
self._pending_messages[idx] = []
|
263
273
|
self._pending_messages[idx].append(message)
|
264
274
|
|
275
|
+
# Update trace with input/output
|
276
|
+
if message["role"] == "user":
|
277
|
+
self._trace.update(input={"type": "text", "value": str(message["content"])})
|
278
|
+
elif message["role"] == "assistant":
|
279
|
+
self._trace.update(
|
280
|
+
output={
|
281
|
+
"type": "text",
|
282
|
+
"value": str(
|
283
|
+
message["content"]
|
284
|
+
if "content" in message
|
285
|
+
else json.dumps(message, cls=SerializableWithStringFallback)
|
286
|
+
),
|
287
|
+
}
|
288
|
+
)
|
289
|
+
|
265
290
|
def add_messages(
|
266
291
|
self,
|
267
292
|
messages: List[ChatCompletionMessageParam],
|
@@ -291,6 +316,21 @@ class ScenarioExecutor:
|
|
291
316
|
self.add_message(message, from_agent_idx)
|
292
317
|
|
293
318
|
def _new_turn(self):
|
319
|
+
if hasattr(self, "_trace") and self._trace is not None:
|
320
|
+
self._trace.__exit__(None, None, None)
|
321
|
+
|
322
|
+
self._trace = langwatch.trace(
|
323
|
+
name="Scenario Turn",
|
324
|
+
metadata={
|
325
|
+
"labels": ["scenario"],
|
326
|
+
"thread_id": self._state.thread_id,
|
327
|
+
"scenario.name": self.name,
|
328
|
+
"scenario.batch_id": self.batch_run_id,
|
329
|
+
"scenario.set_id": self.scenario_set_id,
|
330
|
+
"scenario.turn": self._state.current_turn,
|
331
|
+
},
|
332
|
+
).__enter__()
|
333
|
+
|
294
334
|
self._pending_agents_on_turn = set(self.agents)
|
295
335
|
self._pending_roles_on_turn = [
|
296
336
|
AgentRole.USER,
|
@@ -459,7 +499,7 @@ class ScenarioExecutor:
|
|
459
499
|
|
460
500
|
async def _call_agent(
|
461
501
|
self, idx: int, role: AgentRole, request_judgment: bool = False
|
462
|
-
) -> Union[List[ChatCompletionMessageParam], ScenarioResult]:
|
502
|
+
) -> Union[List[ChatCompletionMessageParam], ScenarioResult, None]:
|
463
503
|
agent = self.agents[idx]
|
464
504
|
|
465
505
|
if role == AgentRole.USER and self.config.debug:
|
@@ -481,67 +521,84 @@ class ScenarioExecutor:
|
|
481
521
|
ChatCompletionUserMessageParam(role="user", content=input_message)
|
482
522
|
]
|
483
523
|
|
484
|
-
with
|
485
|
-
|
486
|
-
|
487
|
-
|
488
|
-
|
489
|
-
|
490
|
-
|
491
|
-
|
492
|
-
|
493
|
-
|
494
|
-
|
495
|
-
|
496
|
-
|
497
|
-
|
498
|
-
|
499
|
-
|
500
|
-
|
501
|
-
warnings.
|
502
|
-
|
503
|
-
|
504
|
-
|
505
|
-
|
506
|
-
|
507
|
-
|
508
|
-
|
509
|
-
|
524
|
+
with self._trace.span(type="agent", name=f"{agent.__class__.__name__}.call") as span:
|
525
|
+
with show_spinner(
|
526
|
+
text=(
|
527
|
+
"Judging..."
|
528
|
+
if role == AgentRole.JUDGE
|
529
|
+
else f"{role.value if isinstance(role, AgentRole) else role}:"
|
530
|
+
),
|
531
|
+
color=(
|
532
|
+
"blue"
|
533
|
+
if role == AgentRole.AGENT
|
534
|
+
else "green" if role == AgentRole.USER else "yellow"
|
535
|
+
),
|
536
|
+
enabled=self.config.verbose,
|
537
|
+
):
|
538
|
+
start_time = time.time()
|
539
|
+
|
540
|
+
# Prevent pydantic validation warnings which should already be disabled
|
541
|
+
with warnings.catch_warnings():
|
542
|
+
warnings.simplefilter("ignore")
|
543
|
+
|
544
|
+
self._trace.autotrack_litellm_calls(litellm)
|
545
|
+
|
546
|
+
agent_response = agent.call(
|
547
|
+
AgentInput(
|
548
|
+
# TODO: test thread_id
|
549
|
+
thread_id=self._state.thread_id,
|
550
|
+
messages=cast(List[ChatCompletionMessageParam], self._state.messages),
|
551
|
+
new_messages=self._pending_messages.get(idx, []),
|
552
|
+
judgment_request=request_judgment,
|
553
|
+
scenario_state=self._state,
|
554
|
+
)
|
555
|
+
)
|
556
|
+
if not isinstance(agent_response, Awaitable):
|
557
|
+
raise Exception(
|
558
|
+
agent_response_not_awaitable(agent.__class__.__name__),
|
510
559
|
)
|
511
|
-
)
|
512
|
-
if not isinstance(agent_response, Awaitable):
|
513
|
-
raise Exception(
|
514
|
-
agent_response_not_awaitable(agent.__class__.__name__),
|
515
|
-
)
|
516
560
|
|
517
|
-
|
561
|
+
agent_response = await agent_response
|
518
562
|
|
519
|
-
|
520
|
-
|
521
|
-
|
563
|
+
if idx not in self._agent_times:
|
564
|
+
self._agent_times[idx] = 0
|
565
|
+
self._agent_times[idx] += time.time() - start_time
|
522
566
|
|
523
|
-
|
524
|
-
|
567
|
+
self._pending_messages[idx] = []
|
568
|
+
check_valid_return_type(agent_response, agent.__class__.__name__)
|
569
|
+
|
570
|
+
messages = []
|
571
|
+
if isinstance(agent_response, ScenarioResult):
|
572
|
+
# TODO: should be an event
|
573
|
+
span.add_evaluation(
|
574
|
+
name=f"{agent.__class__.__name__} Judgment",
|
575
|
+
status="processed",
|
576
|
+
passed=agent_response.success,
|
577
|
+
details=agent_response.reasoning,
|
578
|
+
score=(
|
579
|
+
len(agent_response.passed_criteria)
|
580
|
+
/ len(agent_response.failed_criteria)
|
581
|
+
if agent_response.failed_criteria
|
582
|
+
else 1.0
|
583
|
+
),
|
584
|
+
)
|
525
585
|
|
526
|
-
|
527
|
-
|
528
|
-
|
529
|
-
|
530
|
-
|
531
|
-
|
532
|
-
agent_response,
|
533
|
-
role="user" if role == AgentRole.USER else "assistant",
|
534
|
-
)
|
586
|
+
return agent_response
|
587
|
+
else:
|
588
|
+
messages = convert_agent_return_types_to_openai_messages(
|
589
|
+
agent_response,
|
590
|
+
role="user" if role == AgentRole.USER else "assistant",
|
591
|
+
)
|
535
592
|
|
536
|
-
|
593
|
+
self.add_messages(messages, from_agent_idx=idx)
|
537
594
|
|
538
|
-
|
539
|
-
|
540
|
-
|
541
|
-
|
542
|
-
|
595
|
+
if messages and self.config.verbose:
|
596
|
+
print_openai_messages(
|
597
|
+
self._scenario_name(),
|
598
|
+
[m for m in messages if m["role"] != "system"],
|
599
|
+
)
|
543
600
|
|
544
|
-
|
601
|
+
return messages
|
545
602
|
|
546
603
|
def _scenario_name(self):
|
547
604
|
if self.config.verbose == 2:
|
@@ -816,6 +873,7 @@ class ScenarioExecutor:
|
|
816
873
|
|
817
874
|
# Signal end of event stream
|
818
875
|
self._events.on_completed()
|
876
|
+
self._trace.__exit__(None, None, None)
|
819
877
|
|
820
878
|
|
821
879
|
async def run(
|
scenario/scenario_state.py
CHANGED
@@ -14,6 +14,7 @@ from openai.types.chat import (
|
|
14
14
|
)
|
15
15
|
from pydantic import BaseModel
|
16
16
|
|
17
|
+
from scenario.types import ChatCompletionMessageParamWithTrace
|
17
18
|
from scenario.config import ScenarioConfig
|
18
19
|
|
19
20
|
if TYPE_CHECKING:
|
@@ -70,7 +71,7 @@ class ScenarioState(BaseModel):
|
|
70
71
|
"""
|
71
72
|
|
72
73
|
description: str
|
73
|
-
messages: List[
|
74
|
+
messages: List[ChatCompletionMessageParamWithTrace]
|
74
75
|
thread_id: str
|
75
76
|
current_turn: int
|
76
77
|
config: ScenarioConfig
|
scenario/types.py
CHANGED
@@ -8,10 +8,20 @@ from typing import (
|
|
8
8
|
Callable,
|
9
9
|
List,
|
10
10
|
Optional,
|
11
|
+
TypeAlias,
|
11
12
|
Union,
|
12
13
|
)
|
13
14
|
|
14
|
-
from openai.types.chat import
|
15
|
+
from openai.types.chat import (
|
16
|
+
ChatCompletionMessageParam,
|
17
|
+
ChatCompletionUserMessageParam,
|
18
|
+
ChatCompletionToolMessageParam,
|
19
|
+
ChatCompletionUserMessageParam,
|
20
|
+
ChatCompletionSystemMessageParam,
|
21
|
+
ChatCompletionFunctionMessageParam,
|
22
|
+
ChatCompletionAssistantMessageParam,
|
23
|
+
ChatCompletionDeveloperMessageParam,
|
24
|
+
)
|
15
25
|
|
16
26
|
# Prevent circular imports + Pydantic breaking
|
17
27
|
if TYPE_CHECKING:
|
@@ -22,6 +32,48 @@ else:
|
|
22
32
|
ScenarioStateType = Any
|
23
33
|
|
24
34
|
|
35
|
+
# Since Python types do not support intersection, we need to wrap ALL the chat completion
|
36
|
+
# message types with the trace_id field
|
37
|
+
|
38
|
+
|
39
|
+
class ChatCompletionDeveloperMessageParamWithTrace(ChatCompletionDeveloperMessageParam):
|
40
|
+
trace_id: Optional[str]
|
41
|
+
|
42
|
+
|
43
|
+
class ChatCompletionSystemMessageParamWithTrace(ChatCompletionSystemMessageParam):
|
44
|
+
trace_id: Optional[str]
|
45
|
+
|
46
|
+
|
47
|
+
class ChatCompletionUserMessageParamWithTrace(ChatCompletionUserMessageParam):
|
48
|
+
trace_id: Optional[str]
|
49
|
+
|
50
|
+
|
51
|
+
class ChatCompletionAssistantMessageParamWithTrace(ChatCompletionAssistantMessageParam):
|
52
|
+
trace_id: Optional[str]
|
53
|
+
|
54
|
+
|
55
|
+
class ChatCompletionToolMessageParamWithTrace(ChatCompletionToolMessageParam):
|
56
|
+
trace_id: Optional[str]
|
57
|
+
|
58
|
+
|
59
|
+
class ChatCompletionFunctionMessageParamWithTrace(ChatCompletionFunctionMessageParam):
|
60
|
+
trace_id: Optional[str]
|
61
|
+
|
62
|
+
|
63
|
+
"""
|
64
|
+
A wrapper around ChatCompletionMessageParam that adds a trace_id field to be able to
|
65
|
+
tie back each message of the scenario run to a trace.
|
66
|
+
"""
|
67
|
+
ChatCompletionMessageParamWithTrace: TypeAlias = Union[
|
68
|
+
ChatCompletionDeveloperMessageParamWithTrace,
|
69
|
+
ChatCompletionSystemMessageParamWithTrace,
|
70
|
+
ChatCompletionUserMessageParamWithTrace,
|
71
|
+
ChatCompletionAssistantMessageParamWithTrace,
|
72
|
+
ChatCompletionToolMessageParamWithTrace,
|
73
|
+
ChatCompletionFunctionMessageParamWithTrace,
|
74
|
+
]
|
75
|
+
|
76
|
+
|
25
77
|
class AgentRole(Enum):
|
26
78
|
"""
|
27
79
|
Defines the different roles that agents can play in a scenario.
|
@@ -171,7 +223,7 @@ class ScenarioResult(BaseModel):
|
|
171
223
|
|
172
224
|
success: bool
|
173
225
|
# Prevent issues with slightly inconsistent message types for example when comming from Gemini right at the result level
|
174
|
-
messages: Annotated[List[
|
226
|
+
messages: Annotated[List[ChatCompletionMessageParamWithTrace], SkipValidation]
|
175
227
|
reasoning: Optional[str] = None
|
176
228
|
passed_criteria: List[str] = []
|
177
229
|
failed_criteria: List[str] = []
|
scenario/user_simulator_agent.py
CHANGED
@@ -10,7 +10,8 @@ conversation history.
|
|
10
10
|
import logging
|
11
11
|
from typing import Optional, cast
|
12
12
|
|
13
|
-
|
13
|
+
import litellm
|
14
|
+
from litellm import Choices
|
14
15
|
from litellm.files.main import ModelResponse
|
15
16
|
|
16
17
|
from scenario.cache import scenario_cache
|
@@ -228,7 +229,7 @@ Your goal (assistant) is to interact with the Agent Under Test (user) as if you
|
|
228
229
|
|
229
230
|
response = cast(
|
230
231
|
ModelResponse,
|
231
|
-
completion(
|
232
|
+
litellm.completion(
|
232
233
|
model=self.model,
|
233
234
|
messages=messages,
|
234
235
|
temperature=self.temperature,
|
File without changes
|
File without changes
|
File without changes
|