langwatch-scenario 0.7.9__py3-none-any.whl → 0.7.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {langwatch_scenario-0.7.9.dist-info → langwatch_scenario-0.7.10.dist-info}/METADATA +3 -2
- {langwatch_scenario-0.7.9.dist-info → langwatch_scenario-0.7.10.dist-info}/RECORD +12 -12
- scenario/_events/event_alert_message_logger.py +5 -0
- scenario/_events/utils.py +43 -27
- scenario/judge_agent.py +3 -2
- scenario/scenario_executor.py +116 -59
- scenario/scenario_state.py +2 -1
- scenario/types.py +54 -2
- scenario/user_simulator_agent.py +3 -2
- {langwatch_scenario-0.7.9.dist-info → langwatch_scenario-0.7.10.dist-info}/WHEEL +0 -0
- {langwatch_scenario-0.7.9.dist-info → langwatch_scenario-0.7.10.dist-info}/entry_points.txt +0 -0
- {langwatch_scenario-0.7.9.dist-info → langwatch_scenario-0.7.10.dist-info}/top_level.txt +0 -0
{langwatch_scenario-0.7.9.dist-info → langwatch_scenario-0.7.10.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: langwatch-scenario
-Version: 0.7.9
+Version: 0.7.10
 Summary: The end-to-end agent testing library
 Author-email: LangWatch Team <support@langwatch.ai>
 License: MIT
@@ -14,7 +14,7 @@ Classifier: Programming Language :: Python :: 3.8
 Classifier: Programming Language :: Python :: 3.9
 Classifier: Programming Language :: Python :: 3.10
 Classifier: Programming Language :: Python :: 3.11
-Requires-Python: >=3.
+Requires-Python: >=3.10
 Description-Content-Type: text/markdown
 Requires-Dist: pytest>=8.1.1
 Requires-Dist: litellm>=1.49.0
@@ -31,6 +31,7 @@ Requires-Dist: httpx>=0.27.0
 Requires-Dist: rx>=3.2.0
 Requires-Dist: python-dateutil>=2.9.0.post0
 Requires-Dist: pydantic-settings>=2.9.1
+Requires-Dist: langwatch>=0.2.19
 Provides-Extra: dev
 Requires-Dist: black; extra == "dev"
 Requires-Dist: isort; extra == "dev"
{langwatch_scenario-0.7.9.dist-info → langwatch_scenario-0.7.10.dist-info}/RECORD
CHANGED

@@ -2,21 +2,21 @@ scenario/__init__.py,sha256=4WO8TjY8Lc0NhYL7b9LvaB1xCBqwUkLuI0uIA6PQP6c,4223
 scenario/_error_messages.py,sha256=QVFSbhzsVNGz2GOBOaoQFW6w6AOyZCWLTt0ySWPfnGw,3882
 scenario/agent_adapter.py,sha256=PoY2KQqYuqzIIb3-nhIU-MPXwHJc1vmwdweMy7ut-hk,4255
 scenario/cache.py,sha256=J6s6Sia_Ce6TrnsInlhfxm6SF8tygo3sH-_cQCRX1WA,6213
-scenario/judge_agent.py,sha256=
+scenario/judge_agent.py,sha256=hHQ2nKsOgSyTtN0LdE6xIF0wZnnlYLN6RcxTPecFHDU,16770
 scenario/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 scenario/pytest_plugin.py,sha256=wRCuGD9uwrrLt2fY15zK6mnmY9W_dO_m0WalPJYE5II,11491
-scenario/scenario_executor.py,sha256=
-scenario/scenario_state.py,sha256=
+scenario/scenario_executor.py,sha256=v41UgSHebosXf95FfYIeVUm6s4IbMP_U58FdGoZ_kZU,35653
+scenario/scenario_state.py,sha256=R8PhPHW3obYo3DCjBH5XDdZ6bp4uol7wCXO8K2Tz30I,7101
 scenario/script.py,sha256=A0N5pP0l4FFn1xdKc78U_wkwWhEWH3EFeU_LRDtNyEI,12241
-scenario/types.py,sha256=
-scenario/user_simulator_agent.py,sha256=
+scenario/types.py,sha256=CRSCHUplXEXhj6EYQsncwJBzbd2128YTGlFxlk-rrG8,11193
+scenario/user_simulator_agent.py,sha256=gXRaeoivEAcenIEqMDU6bWzv8cOrJaaooNrTdpC9TE4,9630
 scenario/_events/__init__.py,sha256=4cj6H9zuXzvWhT2P2JNdjWzeF1PUepTjqIDw85Vid9s,1500
-scenario/_events/event_alert_message_logger.py,sha256=
+scenario/_events/event_alert_message_logger.py,sha256=XcofGgXjeiTC75NPYheBpHxqA6R4pYAuHZa7-kH9Grg,2975
 scenario/_events/event_bus.py,sha256=IsKNsClF1JFYj728EcxX1hw_KbfDkfJq3Y2Kv4h94n4,9871
 scenario/_events/event_reporter.py,sha256=-6NNbBMy_FYr1O-1FuZ6eIUnLuI8NGRMUr0pybLJrCI,3873
 scenario/_events/events.py,sha256=UtEGY-_1B0LrwpgsNKgrvJBZhRtxuj3K_i6ZBfF7E4Q,6387
 scenario/_events/messages.py,sha256=quwP2OkeaGasNOoaV8GUeosZVKc5XDsde08T0xx_YQo,2297
-scenario/_events/utils.py,sha256=
+scenario/_events/utils.py,sha256=CRrdDHBD2ptcNIjzW0eEG1V5-Vw1gFnp_UTz5zMQ_Ak,4051
 scenario/_generated/langwatch_api_client/README.md,sha256=Az5f2L4ChOnG_ZtrdBagzRVgeTCtBkbD_S5cIeAry2o,5424
 scenario/_generated/langwatch_api_client/pyproject.toml,sha256=Z8wxuGp4H9BJYVVJB8diW7rRU9XYxtPfw9mU4_wq4cA,560
 scenario/_generated/langwatch_api_client/lang_watch_api_client/__init__.py,sha256=vVrn17y-3l3fOqeJk8aN3GlStRm2fo0f313l_0LtJNs,368
@@ -235,8 +235,8 @@ scenario/config/__init__.py,sha256=b2X_bqkIrd7jZY9dRrXk2wOqoPe87Nl_SRGuZhlolxA,1
 scenario/config/langwatch.py,sha256=ijWchFbUsLbQooAZmwyTw4rxfRLQseZ1GoVSiPPbzpw,1677
 scenario/config/model.py,sha256=T4HYA79CW1NxXDkFlyftYR6JzZcowbtIx0H-ijxRyfg,1297
 scenario/config/scenario.py,sha256=6jrtcm0Fo7FpxQta7QIKdGMgl7cXrn374Inzx29hRuk,5406
-langwatch_scenario-0.7.9.dist-info/METADATA,sha256=
-langwatch_scenario-0.7.9.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-langwatch_scenario-0.7.9.dist-info/entry_points.txt,sha256=WlEnJ_gku0i18bIa3DSuGqXRX-QDQLe_s0YmRzK45TI,45
-langwatch_scenario-0.7.9.dist-info/top_level.txt,sha256=45Mn28aedJsetnBMB5xSmrJ-yo701QLH89Zlz4r1clE,9
-langwatch_scenario-0.7.9.dist-info/RECORD,,
+langwatch_scenario-0.7.10.dist-info/METADATA,sha256=pbLZM8UXj1_1TWHjheHP6QREOvRWfX7nHEdfY2ZX4aA,20065
+langwatch_scenario-0.7.10.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+langwatch_scenario-0.7.10.dist-info/entry_points.txt,sha256=WlEnJ_gku0i18bIa3DSuGqXRX-QDQLe_s0YmRzK45TI,45
+langwatch_scenario-0.7.10.dist-info/top_level.txt,sha256=45Mn28aedJsetnBMB5xSmrJ-yo701QLH89Zlz4r1clE,9
+langwatch_scenario-0.7.10.dist-info/RECORD,,
scenario/_events/event_alert_message_logger.py
CHANGED

@@ -15,6 +15,7 @@ class EventAlertMessageLogger:
     """
 
     _shown_batch_ids: Set[str] = set()
+    _shown_watch_urls: Set[str] = set()
 
     def handle_greeting(self) -> None:
         """
@@ -40,6 +41,10 @@ class EventAlertMessageLogger:
         if self._is_greeting_disabled():
             return
 
+        if set_url in EventAlertMessageLogger._shown_watch_urls:
+            return
+
+        EventAlertMessageLogger._shown_watch_urls.add(set_url)
         self._display_watch_message(set_url)
 
     def _is_greeting_disabled(self) -> bool:
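
The new `_shown_watch_urls` set mirrors the existing `_shown_batch_ids` pattern: a class-level set deduplicates the "watch" banner so each URL is announced at most once per process. A minimal self-contained sketch of the pattern (class, method, and URL here are illustrative, not the library's API):

```python
from typing import Set


class BannerLogger:
    # Class-level set: shared across all instances, so the banner prints
    # at most once per URL for the lifetime of the process.
    _shown_watch_urls: Set[str] = set()

    def handle_watch_message(self, set_url: str) -> None:
        if set_url in BannerLogger._shown_watch_urls:
            return  # already shown for this URL, stay quiet
        BannerLogger._shown_watch_urls.add(set_url)
        print(f"Watch live: {set_url}")


logger_a, logger_b = BannerLogger(), BannerLogger()
logger_a.handle_watch_message("https://example.com/simulations/abc")
logger_b.handle_watch_message("https://example.com/simulations/abc")  # no output
```
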
scenario/_events/utils.py
CHANGED

@@ -1,5 +1,6 @@
 import warnings
 
+from ..types import ChatCompletionMessageParamWithTrace
 from .events import MessageType
 from .messages import (
     SystemMessage,
@@ -12,7 +13,10 @@ from .messages import (
 from typing import List
 from pksuid import PKSUID
 
-def convert_messages_to_api_client_messages(messages: list[ChatCompletionMessageParam]) -> list[MessageType]:
+
+def convert_messages_to_api_client_messages(
+    messages: list[ChatCompletionMessageParamWithTrace],
+) -> list[MessageType]:
     """
     Converts OpenAI ChatCompletionMessageParam messages to API client Message format.
 
@@ -33,7 +37,7 @@ def convert_messages_to_api_client_messages(
 
     for i, message in enumerate(messages):
         # Generate unique ID for each message
-        message_id = message.get("id") or str(PKSUID(
+        message_id = message.get("id") or str(PKSUID("scenariomsg"))
 
         role = message.get("role")
         content = message.get("content")
@@ -41,11 +45,13 @@ def convert_messages_to_api_client_messages(
         if role == "user":
             if not content:
                 raise ValueError(f"User message at index {i} missing required content")
-            converted_messages.append(UserMessage(
+            message_ = UserMessage(
                 id=message_id,
                 role="user",
-                content=str(content)
-            ))
+                content=str(content),
+            )
+            message_.additional_properties = {"trace_id": message.get("trace_id")}
+            converted_messages.append(message_)
         elif role == "assistant":
             # Handle tool calls if present
             tool_calls = message.get("tool_calls")
@@ -53,44 +59,54 @@ def convert_messages_to_api_client_messages(
 
             if tool_calls:
                 for tool_call in tool_calls:
-                    api_tool_calls.append(ToolCall(
-                        id=tool_call.get("id", str(PKSUID("scenariotoolcall"))),
-                        type_="function",
-                        function=FunctionCall(
-                            name=tool_call["function"].get("name", "unknown"),
-                            arguments=tool_call["function"].get("arguments", "{}")
+                    api_tool_calls.append(
+                        ToolCall(
+                            id=tool_call.get("id", str(PKSUID("scenariotoolcall"))),
+                            type_="function",
+                            function=FunctionCall(
+                                name=tool_call["function"].get("name", "unknown"),
+                                arguments=tool_call["function"].get("arguments", "{}"),
+                            ),
                         )
-                    ))
+                    )
 
-            converted_messages.append(AssistantMessage(
+            message_ = AssistantMessage(
                 id=message_id,
                 role="assistant",
                 content=str(content),
-                tool_calls=api_tool_calls
-            ))
+                tool_calls=api_tool_calls,
+            )
+            message_.additional_properties = {"trace_id": message.get("trace_id")}
+            converted_messages.append(message_)
         elif role == "system":
             if not content:
-                raise ValueError(f"System message at index {i} missing required content")
-            converted_messages.append(SystemMessage(
-                id=message_id,
-                role="system",
-                content=str(content)
-            ))
+                raise ValueError(
+                    f"System message at index {i} missing required content"
+                )
+            message_ = SystemMessage(id=message_id, role="system", content=str(content))
+            message_.additional_properties = {"trace_id": message.get("trace_id")}
+            converted_messages.append(message_)
         elif role == "tool":
             tool_call_id = message.get("tool_call_id")
            if not tool_call_id:
-                warnings.warn(f"Tool message at index {i} missing required tool_call_id, skipping tool message")
+                warnings.warn(
+                    f"Tool message at index {i} missing required tool_call_id, skipping tool message"
+                )
                 continue
             if not content:
-                warnings.warn(f"Tool message at index {i} missing required content, skipping tool message")
+                warnings.warn(
+                    f"Tool message at index {i} missing required content, skipping tool message"
+                )
                 continue
 
-            converted_messages.append(ToolMessage(
+            message_ = ToolMessage(
                 id=message_id,
                 role="tool",
                 content=str(content),
-                tool_call_id=tool_call_id
-            ))
+                tool_call_id=tool_call_id,
+            )
+            message_.additional_properties = {"trace_id": message.get("trace_id")}
+            converted_messages.append(message_)
         else:
             raise ValueError(f"Unsupported message role '{role}' at index {i}")
 
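
The net effect of the `utils.py` change: every converted message now carries its `trace_id` in the API client's `additional_properties` bag. A simplified sketch of the pattern, using a stand-in dataclass instead of the generated `UserMessage` client class:

```python
from dataclasses import dataclass, field
from typing import Any, Dict, Optional


@dataclass
class UserMessage:
    # Stand-in for the generated API client class of the same name.
    id: str
    role: str
    content: str
    additional_properties: Dict[str, Optional[str]] = field(default_factory=dict)


def convert_user_message(message: Dict[str, Any]) -> UserMessage:
    if not message.get("content"):
        raise ValueError("User message missing required content")
    converted = UserMessage(
        id=message.get("id") or "msg_01",  # the real code generates a PKSUID here
        role="user",
        content=str(message["content"]),
    )
    # The trace_id travels in additional_properties, tying the message back to
    # the LangWatch trace of the turn that produced it.
    converted.additional_properties = {"trace_id": message.get("trace_id")}
    return converted


print(convert_user_message({"role": "user", "content": "Hi", "trace_id": "trace_123"}))
```
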
scenario/judge_agent.py
CHANGED

@@ -12,7 +12,8 @@ import logging
 import re
 from typing import List, Optional, cast
 
-from litellm import Choices, completion
+import litellm
+from litellm import Choices
 from litellm.files.main import ModelResponse
 
 from scenario.cache import scenario_cache
@@ -356,7 +357,7 @@ if you don't have enough information to make a verdict, say inconclusive with ma
 
         response = cast(
             ModelResponse,
-            completion(
+            litellm.completion(
                 model=self.model,
                 messages=messages,
                 temperature=self.temperature,
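
A plausible reading of this import change: `scenario_executor.py` (below) calls `self._trace.autotrack_litellm_calls(litellm)`, and auto-tracking of this kind typically works by replacing attributes on the `litellm` module. Calling `litellm.completion(...)` resolves the attribute at call time and therefore sees the patch, while a name bound earlier via `from litellm import completion` would keep pointing at the unpatched function. The same change appears in `user_simulator_agent.py` below. A self-contained sketch of that mechanism (using a fake module object, not the real `litellm`):

```python
import types

# Build a stand-in module with a `completion` function.
fake_litellm = types.ModuleType("fake_litellm")
fake_litellm.completion = lambda **kwargs: "response"

# Alias bound early, like `from litellm import completion`.
completion = fake_litellm.completion

# A tracker patches the module attribute, as autotrack-style helpers do.
_original = fake_litellm.completion
fake_litellm.completion = lambda **kwargs: f"tracked:{_original(**kwargs)}"

print(fake_litellm.completion(model="x"))  # tracked:response  (patch seen)
print(completion(model="x"))               # response          (patch missed)
```
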
scenario/scenario_executor.py
CHANGED

@@ -6,6 +6,7 @@ of scenario tests, managing the interaction between user simulators, agents unde
 and judge agents to determine test success or failure.
 """
 
+import json
 import sys
 from typing import (
     Awaitable,
@@ -17,6 +18,7 @@ from typing import (
     Tuple,
     Union,
     TypedDict,
+    cast,
 )
 import time
 import warnings
@@ -33,6 +35,7 @@ from scenario._utils import (
     await_if_awaitable,
     get_batch_run_id,
     generate_scenario_run_id,
+    SerializableWithStringFallback,
 )
 from openai.types.chat import (
     ChatCompletionMessageParam,
@@ -40,7 +43,7 @@ from openai.types.chat import (
     ChatCompletionAssistantMessageParam,
 )
 
-from .types import AgentInput, AgentRole, ScenarioResult, ScriptStep
+from .types import AgentInput, AgentRole, ChatCompletionMessageParamWithTrace, ScenarioResult, ScriptStep
 from ._error_messages import agent_response_not_awaitable
 from .cache import context_scenario
 from .agent_adapter import AgentAdapter
@@ -62,6 +65,11 @@ from ._events import (
 from rx.subject.subject import Subject
 from rx.core.observable.observable import Observable
 
+import litellm
+import langwatch
+import langwatch.telemetry.context
+from langwatch.telemetry.tracing import LangWatchTrace
+
 
 class ScenarioExecutor:
     """
@@ -101,6 +109,7 @@ class ScenarioExecutor:
     _pending_agents_on_turn: Set[AgentAdapter] = set()
     _agent_times: Dict[int, float] = {}
     _events: Subject
+    _trace: LangWatchTrace
 
     event_bus: ScenarioEventBus
 
@@ -157,7 +166,8 @@ class ScenarioExecutor:
         )
         self.config = (ScenarioConfig.default_config or ScenarioConfig()).merge(config)
 
-        self.
+        self.batch_run_id = get_batch_run_id()
+        self.scenario_set_id = set_id or "default"
 
         # Create executor's own event stream
         self._events = Subject()
@@ -166,9 +176,6 @@ class ScenarioExecutor:
         self.event_bus = event_bus or ScenarioEventBus()
         self.event_bus.subscribe_to_events(self._events)
 
-        self.batch_run_id = get_batch_run_id()
-        self.scenario_set_id = set_id or "default"
-
     @property
     def events(self) -> Observable:
         """Expose event stream for subscribers like the event bus."""
@@ -253,6 +260,8 @@ class ScenarioExecutor:
         )
         ```
         """
+        message = cast(ChatCompletionMessageParamWithTrace, message)
+        message["trace_id"] = self._trace.trace_id
         self._state.messages.append(message)
 
         # Broadcast the message to other agents
@@ -263,6 +272,21 @@ class ScenarioExecutor:
             self._pending_messages[idx] = []
         self._pending_messages[idx].append(message)
 
+        # Update trace with input/output
+        if message["role"] == "user":
+            self._trace.update(input={"type": "text", "value": str(message["content"])})
+        elif message["role"] == "assistant":
+            self._trace.update(
+                output={
+                    "type": "text",
+                    "value": str(
+                        message["content"]
+                        if "content" in message
+                        else json.dumps(message, cls=SerializableWithStringFallback)
+                    ),
+                }
+            )
+
     def add_messages(
         self,
         messages: List[ChatCompletionMessageParam],
@@ -292,6 +316,21 @@ class ScenarioExecutor:
         self.add_message(message, from_agent_idx)
 
     def _new_turn(self):
+        if hasattr(self, "_trace") and self._trace is not None:
+            self._trace.__exit__(None, None, None)
+
+        self._trace = langwatch.trace(
+            name="Scenario Turn",
+            metadata={
+                "labels": ["scenario"],
+                "thread_id": self._state.thread_id,
+                "scenario.name": self.name,
+                "scenario.batch_id": self.batch_run_id,
+                "scenario.set_id": self.scenario_set_id,
+                "scenario.turn": self._state.current_turn,
+            },
+        ).__enter__()
+
         self._pending_agents_on_turn = set(self.agents)
         self._pending_roles_on_turn = [
             AgentRole.USER,
@@ -460,7 +499,7 @@ class ScenarioExecutor:
 
     async def _call_agent(
         self, idx: int, role: AgentRole, request_judgment: bool = False
-    ) -> Union[List[ChatCompletionMessageParam], ScenarioResult]:
+    ) -> Union[List[ChatCompletionMessageParam], ScenarioResult, None]:
         agent = self.agents[idx]
 
         if role == AgentRole.USER and self.config.debug:
@@ -482,67 +521,84 @@ class ScenarioExecutor:
             ChatCompletionUserMessageParam(role="user", content=input_message)
         ]
 
-        with show_spinner(
-            text=(
-                "Judging..."
-                if role == AgentRole.JUDGE
-                else f"{role.value if isinstance(role, AgentRole) else role}:"
-            ),
-            color=(
-                "blue"
-                if role == AgentRole.AGENT
-                else "green" if role == AgentRole.USER else "yellow"
-            ),
-            enabled=self.config.verbose,
-        ):
-            start_time = time.time()
-
-            # Prevent pydantic validation warnings which should already be disabled
-            with warnings.catch_warnings():
-                warnings.simplefilter("ignore")
-
-                agent_response = agent.call(
-                    AgentInput(
-                        # TODO: test thread_id
-                        thread_id=self._state.thread_id,
-                        messages=self._state.messages,
-                        new_messages=self._pending_messages.get(idx, []),
-                        judgment_request=request_judgment,
-                        scenario_state=self._state,
-                    )
-                )
-                if not isinstance(agent_response, Awaitable):
-                    raise Exception(
-                        agent_response_not_awaitable(agent.__class__.__name__),
-                    )
-
-                agent_response = await agent_response
-
-            if idx not in self._agent_times:
-                self._agent_times[idx] = 0
-            self._agent_times[idx] += time.time() - start_time
-
-            self._pending_messages[idx] = []
-            check_valid_return_type(agent_response, agent.__class__.__name__)
-
-            messages = []
-            if isinstance(agent_response, ScenarioResult):
-                return agent_response
-            else:
-                messages = convert_agent_return_types_to_openai_messages(
-                    agent_response,
-                    role="user" if role == AgentRole.USER else "assistant",
-                )
-
-            self.add_messages(messages, from_agent_idx=idx)
-
-            if messages and self.config.verbose:
-                print_openai_messages(
-                    self._scenario_name(),
-                    [m for m in messages if m["role"] != "system"],
-                )
-
-            return messages
+        with self._trace.span(type="agent", name=f"{agent.__class__.__name__}.call") as span:
+            with show_spinner(
+                text=(
+                    "Judging..."
+                    if role == AgentRole.JUDGE
+                    else f"{role.value if isinstance(role, AgentRole) else role}:"
+                ),
+                color=(
+                    "blue"
+                    if role == AgentRole.AGENT
+                    else "green" if role == AgentRole.USER else "yellow"
+                ),
+                enabled=self.config.verbose,
+            ):
+                start_time = time.time()
+
+                # Prevent pydantic validation warnings which should already be disabled
+                with warnings.catch_warnings():
+                    warnings.simplefilter("ignore")
+
+                    self._trace.autotrack_litellm_calls(litellm)
+
+                    agent_response = agent.call(
+                        AgentInput(
+                            # TODO: test thread_id
+                            thread_id=self._state.thread_id,
+                            messages=cast(List[ChatCompletionMessageParam], self._state.messages),
+                            new_messages=self._pending_messages.get(idx, []),
+                            judgment_request=request_judgment,
+                            scenario_state=self._state,
+                        )
+                    )
+                    if not isinstance(agent_response, Awaitable):
+                        raise Exception(
+                            agent_response_not_awaitable(agent.__class__.__name__),
+                        )
+
+                    agent_response = await agent_response
+
+                if idx not in self._agent_times:
+                    self._agent_times[idx] = 0
+                self._agent_times[idx] += time.time() - start_time
+
+                self._pending_messages[idx] = []
+                check_valid_return_type(agent_response, agent.__class__.__name__)
+
+                messages = []
+                if isinstance(agent_response, ScenarioResult):
+                    # TODO: should be an event
+                    span.add_evaluation(
+                        name=f"{agent.__class__.__name__} Judgment",
+                        status="processed",
+                        passed=agent_response.success,
+                        details=agent_response.reasoning,
+                        score=(
+                            len(agent_response.passed_criteria)
+                            / len(agent_response.failed_criteria)
+                            if agent_response.failed_criteria
+                            else 1.0
+                        ),
+                    )
+
+                    return agent_response
+                else:
+                    messages = convert_agent_return_types_to_openai_messages(
+                        agent_response,
+                        role="user" if role == AgentRole.USER else "assistant",
+                    )
+
+                self.add_messages(messages, from_agent_idx=idx)
+
+                if messages and self.config.verbose:
+                    print_openai_messages(
+                        self._scenario_name(),
+                        [m for m in messages if m["role"] != "system"],
+                    )
+
+                return messages
 
     def _scenario_name(self):
         if self.config.verbose == 2:
@@ -817,6 +873,7 @@ class ScenarioExecutor:
 
         # Signal end of event stream
         self._events.on_completed()
+        self._trace.__exit__(None, None, None)
 
 
     async def run(
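
The executor now manages one LangWatch trace per turn, driving the context manager by hand: `_new_turn` closes the previous turn's trace and `__enter__`s a new one, and the run's teardown `__exit__`s the last open trace. A minimal sketch of this manual lifecycle (with a stand-in trace class, not the `langwatch` SDK):

```python
from typing import Optional


class FakeTrace:
    # Stand-in for the object returned by langwatch.trace(...).
    def __init__(self, name: str) -> None:
        self.name = name

    def __enter__(self) -> "FakeTrace":
        print(f"open  {self.name}")
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        print(f"close {self.name}")


class TurnTracer:
    def __init__(self) -> None:
        self._trace: Optional[FakeTrace] = None

    def new_turn(self, turn: int) -> None:
        # Close the previous turn's trace before opening the next one,
        # mirroring _new_turn above.
        if self._trace is not None:
            self._trace.__exit__(None, None, None)
        self._trace = FakeTrace(f"Scenario Turn {turn}").__enter__()

    def finish(self) -> None:
        # Mirrors the run teardown, which exits the last open trace.
        if self._trace is not None:
            self._trace.__exit__(None, None, None)
            self._trace = None


tracer = TurnTracer()
tracer.new_turn(1)
tracer.new_turn(2)  # closes turn 1, opens turn 2
tracer.finish()     # closes turn 2
```
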
scenario/scenario_state.py
CHANGED

@@ -14,6 +14,7 @@ from openai.types.chat import (
 )
 from pydantic import BaseModel
 
+from scenario.types import ChatCompletionMessageParamWithTrace
 from scenario.config import ScenarioConfig
 
 if TYPE_CHECKING:
@@ -70,7 +71,7 @@ class ScenarioState(BaseModel):
     """
 
     description: str
-    messages: List[ChatCompletionMessageParam]
+    messages: List[ChatCompletionMessageParamWithTrace]
     thread_id: str
     current_turn: int
     config: ScenarioConfig
scenario/types.py
CHANGED

@@ -8,10 +8,20 @@ from typing import (
     Callable,
     List,
     Optional,
+    TypeAlias,
     Union,
 )
 
-from openai.types.chat import ChatCompletionMessageParam, ChatCompletionUserMessageParam
+from openai.types.chat import (
+    ChatCompletionMessageParam,
+    ChatCompletionUserMessageParam,
+    ChatCompletionToolMessageParam,
+    ChatCompletionUserMessageParam,
+    ChatCompletionSystemMessageParam,
+    ChatCompletionFunctionMessageParam,
+    ChatCompletionAssistantMessageParam,
+    ChatCompletionDeveloperMessageParam,
+)
 
 # Prevent circular imports + Pydantic breaking
 if TYPE_CHECKING:
@@ -22,6 +32,48 @@ else:
     ScenarioStateType = Any
 
 
+# Since Python types do not support intersection, we need to wrap ALL the chat completion
+# message types with the trace_id field
+
+
+class ChatCompletionDeveloperMessageParamWithTrace(ChatCompletionDeveloperMessageParam):
+    trace_id: Optional[str]
+
+
+class ChatCompletionSystemMessageParamWithTrace(ChatCompletionSystemMessageParam):
+    trace_id: Optional[str]
+
+
+class ChatCompletionUserMessageParamWithTrace(ChatCompletionUserMessageParam):
+    trace_id: Optional[str]
+
+
+class ChatCompletionAssistantMessageParamWithTrace(ChatCompletionAssistantMessageParam):
+    trace_id: Optional[str]
+
+
+class ChatCompletionToolMessageParamWithTrace(ChatCompletionToolMessageParam):
+    trace_id: Optional[str]
+
+
+class ChatCompletionFunctionMessageParamWithTrace(ChatCompletionFunctionMessageParam):
+    trace_id: Optional[str]
+
+
+"""
+A wrapper around ChatCompletionMessageParam that adds a trace_id field to be able to
+tie back each message of the scenario run to a trace.
+"""
+ChatCompletionMessageParamWithTrace: TypeAlias = Union[
+    ChatCompletionDeveloperMessageParamWithTrace,
+    ChatCompletionSystemMessageParamWithTrace,
+    ChatCompletionUserMessageParamWithTrace,
+    ChatCompletionAssistantMessageParamWithTrace,
+    ChatCompletionToolMessageParamWithTrace,
+    ChatCompletionFunctionMessageParamWithTrace,
+]
+
+
 class AgentRole(Enum):
     """
     Defines the different roles that agents can play in a scenario.
@@ -171,7 +223,7 @@ class ScenarioResult(BaseModel):
 
     success: bool
     # Prevent issues with slightly inconsistent message types for example when comming from Gemini right at the result level
-    messages: Annotated[List[ChatCompletionMessageParam], SkipValidation]
+    messages: Annotated[List[ChatCompletionMessageParamWithTrace], SkipValidation]
     reasoning: Optional[str] = None
     passed_criteria: List[str] = []
     failed_criteria: List[str] = []
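
Because `TypedDict`s cannot be intersected, `types.py` subclasses each OpenAI message shape to add `trace_id` and then unions the wrappers back into a single alias. Note that `typing.TypeAlias` only exists on Python 3.10+, which lines up with the `Requires-Python: >=3.10` bump in the METADATA above. A reduced sketch of the same pattern with two stand-in message shapes:

```python
from typing import Optional, TypedDict, Union


class UserMessageParam(TypedDict):
    role: str
    content: str


class SystemMessageParam(TypedDict):
    role: str
    content: str


# Subclassing is the only way to add the extra key; there is no way to say
# "UserMessageParam AND {trace_id}" directly.
class UserMessageParamWithTrace(UserMessageParam):
    trace_id: Optional[str]


class SystemMessageParamWithTrace(SystemMessageParam):
    trace_id: Optional[str]


# The real code unions six wrapped variants, one per OpenAI message role.
MessageParamWithTrace = Union[UserMessageParamWithTrace, SystemMessageParamWithTrace]

msg: MessageParamWithTrace = {"role": "user", "content": "Hi", "trace_id": "trace_123"}
```
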
scenario/user_simulator_agent.py
CHANGED

@@ -10,7 +10,8 @@ conversation history.
 import logging
 from typing import Optional, cast
 
-from litellm import Choices, completion
+import litellm
+from litellm import Choices
 from litellm.files.main import ModelResponse
 
 from scenario.cache import scenario_cache
@@ -228,7 +229,7 @@ Your goal (assistant) is to interact with the Agent Under Test (user) as if you
 
         response = cast(
             ModelResponse,
-            completion(
+            litellm.completion(
                 model=self.model,
                 messages=messages,
                 temperature=self.temperature,
{langwatch_scenario-0.7.9.dist-info → langwatch_scenario-0.7.10.dist-info}/WHEEL
File without changes

{langwatch_scenario-0.7.9.dist-info → langwatch_scenario-0.7.10.dist-info}/entry_points.txt
File without changes

{langwatch_scenario-0.7.9.dist-info → langwatch_scenario-0.7.10.dist-info}/top_level.txt
File without changes