langwatch-scenario 0.7.9__py3-none-any.whl → 0.7.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
langwatch_scenario-0.7.9.dist-info/METADATA → langwatch_scenario-0.7.10.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: langwatch-scenario
- Version: 0.7.9
+ Version: 0.7.10
  Summary: The end-to-end agent testing library
  Author-email: LangWatch Team <support@langwatch.ai>
  License: MIT
@@ -14,7 +14,7 @@ Classifier: Programming Language :: Python :: 3.8
  Classifier: Programming Language :: Python :: 3.9
  Classifier: Programming Language :: Python :: 3.10
  Classifier: Programming Language :: Python :: 3.11
- Requires-Python: >=3.9
+ Requires-Python: >=3.10
  Description-Content-Type: text/markdown
  Requires-Dist: pytest>=8.1.1
  Requires-Dist: litellm>=1.49.0
@@ -31,6 +31,7 @@ Requires-Dist: httpx>=0.27.0
  Requires-Dist: rx>=3.2.0
  Requires-Dist: python-dateutil>=2.9.0.post0
  Requires-Dist: pydantic-settings>=2.9.1
+ Requires-Dist: langwatch>=0.2.19
  Provides-Extra: dev
  Requires-Dist: black; extra == "dev"
  Requires-Dist: isort; extra == "dev"
langwatch_scenario-0.7.9.dist-info/RECORD → langwatch_scenario-0.7.10.dist-info/RECORD CHANGED
@@ -2,21 +2,21 @@ scenario/__init__.py,sha256=4WO8TjY8Lc0NhYL7b9LvaB1xCBqwUkLuI0uIA6PQP6c,4223
  scenario/_error_messages.py,sha256=QVFSbhzsVNGz2GOBOaoQFW6w6AOyZCWLTt0ySWPfnGw,3882
  scenario/agent_adapter.py,sha256=PoY2KQqYuqzIIb3-nhIU-MPXwHJc1vmwdweMy7ut-hk,4255
  scenario/cache.py,sha256=J6s6Sia_Ce6TrnsInlhfxm6SF8tygo3sH-_cQCRX1WA,6213
- scenario/judge_agent.py,sha256=TSwykEWhoBA9F__sUsSuUMpu7pOkT1lIJo8YlEj2eiA,16759
+ scenario/judge_agent.py,sha256=hHQ2nKsOgSyTtN0LdE6xIF0wZnnlYLN6RcxTPecFHDU,16770
  scenario/py.typed,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  scenario/pytest_plugin.py,sha256=wRCuGD9uwrrLt2fY15zK6mnmY9W_dO_m0WalPJYE5II,11491
- scenario/scenario_executor.py,sha256=_GRpFpw_WtgtaGpxWh0A0HNNf-aU78PdIiVdgEFm9MY,33136
- scenario/scenario_state.py,sha256=LWGqEQN-Yz0DIiC-TyMRHd-9rEiuBVUHKllMmKv-qGg,7029
+ scenario/scenario_executor.py,sha256=v41UgSHebosXf95FfYIeVUm6s4IbMP_U58FdGoZ_kZU,35653
+ scenario/scenario_state.py,sha256=R8PhPHW3obYo3DCjBH5XDdZ6bp4uol7wCXO8K2Tz30I,7101
  scenario/script.py,sha256=A0N5pP0l4FFn1xdKc78U_wkwWhEWH3EFeU_LRDtNyEI,12241
- scenario/types.py,sha256=qH5KFzJBDG1fEJB_qFRVtL3EZulxq3G1mztYczIzIAY,9613
- scenario/user_simulator_agent.py,sha256=kqnSd4_gytzEwtkc06r58UdE1EycZBzejRPzfORDjdo,9619
+ scenario/types.py,sha256=CRSCHUplXEXhj6EYQsncwJBzbd2128YTGlFxlk-rrG8,11193
+ scenario/user_simulator_agent.py,sha256=gXRaeoivEAcenIEqMDU6bWzv8cOrJaaooNrTdpC9TE4,9630
  scenario/_events/__init__.py,sha256=4cj6H9zuXzvWhT2P2JNdjWzeF1PUepTjqIDw85Vid9s,1500
- scenario/_events/event_alert_message_logger.py,sha256=n2W3uT8y4x6KKL3H9Ez6CfzJOFlvOfvjDKsdhHUJkxs,2787
+ scenario/_events/event_alert_message_logger.py,sha256=XcofGgXjeiTC75NPYheBpHxqA6R4pYAuHZa7-kH9Grg,2975
  scenario/_events/event_bus.py,sha256=IsKNsClF1JFYj728EcxX1hw_KbfDkfJq3Y2Kv4h94n4,9871
  scenario/_events/event_reporter.py,sha256=-6NNbBMy_FYr1O-1FuZ6eIUnLuI8NGRMUr0pybLJrCI,3873
  scenario/_events/events.py,sha256=UtEGY-_1B0LrwpgsNKgrvJBZhRtxuj3K_i6ZBfF7E4Q,6387
  scenario/_events/messages.py,sha256=quwP2OkeaGasNOoaV8GUeosZVKc5XDsde08T0xx_YQo,2297
- scenario/_events/utils.py,sha256=KKqWFGkj4XtofKxM2yi-DBhBQp8wQOdls48iPHGCmUY,3473
+ scenario/_events/utils.py,sha256=CRrdDHBD2ptcNIjzW0eEG1V5-Vw1gFnp_UTz5zMQ_Ak,4051
  scenario/_generated/langwatch_api_client/README.md,sha256=Az5f2L4ChOnG_ZtrdBagzRVgeTCtBkbD_S5cIeAry2o,5424
  scenario/_generated/langwatch_api_client/pyproject.toml,sha256=Z8wxuGp4H9BJYVVJB8diW7rRU9XYxtPfw9mU4_wq4cA,560
  scenario/_generated/langwatch_api_client/lang_watch_api_client/__init__.py,sha256=vVrn17y-3l3fOqeJk8aN3GlStRm2fo0f313l_0LtJNs,368
@@ -235,8 +235,8 @@ scenario/config/__init__.py,sha256=b2X_bqkIrd7jZY9dRrXk2wOqoPe87Nl_SRGuZhlolxA,1
  scenario/config/langwatch.py,sha256=ijWchFbUsLbQooAZmwyTw4rxfRLQseZ1GoVSiPPbzpw,1677
  scenario/config/model.py,sha256=T4HYA79CW1NxXDkFlyftYR6JzZcowbtIx0H-ijxRyfg,1297
  scenario/config/scenario.py,sha256=6jrtcm0Fo7FpxQta7QIKdGMgl7cXrn374Inzx29hRuk,5406
- langwatch_scenario-0.7.9.dist-info/METADATA,sha256=0s-yAn8iE1N-5dbqugYFpSl8btZrTyyDgWQDat8szxI,20030
- langwatch_scenario-0.7.9.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
- langwatch_scenario-0.7.9.dist-info/entry_points.txt,sha256=WlEnJ_gku0i18bIa3DSuGqXRX-QDQLe_s0YmRzK45TI,45
- langwatch_scenario-0.7.9.dist-info/top_level.txt,sha256=45Mn28aedJsetnBMB5xSmrJ-yo701QLH89Zlz4r1clE,9
- langwatch_scenario-0.7.9.dist-info/RECORD,,
+ langwatch_scenario-0.7.10.dist-info/METADATA,sha256=pbLZM8UXj1_1TWHjheHP6QREOvRWfX7nHEdfY2ZX4aA,20065
+ langwatch_scenario-0.7.10.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ langwatch_scenario-0.7.10.dist-info/entry_points.txt,sha256=WlEnJ_gku0i18bIa3DSuGqXRX-QDQLe_s0YmRzK45TI,45
+ langwatch_scenario-0.7.10.dist-info/top_level.txt,sha256=45Mn28aedJsetnBMB5xSmrJ-yo701QLH89Zlz4r1clE,9
+ langwatch_scenario-0.7.10.dist-info/RECORD,,
scenario/_events/event_alert_message_logger.py CHANGED
@@ -15,6 +15,7 @@ class EventAlertMessageLogger:
      """
  
      _shown_batch_ids: Set[str] = set()
+     _shown_watch_urls: Set[str] = set()
  
      def handle_greeting(self) -> None:
          """
@@ -40,6 +41,10 @@ class EventAlertMessageLogger:
          if self._is_greeting_disabled():
              return
  
+         if set_url in EventAlertMessageLogger._shown_watch_urls:
+             return
+
+         EventAlertMessageLogger._shown_watch_urls.add(set_url)
          self._display_watch_message(set_url)
  
      def _is_greeting_disabled(self) -> bool:
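The new `_shown_watch_urls` set follows the same pattern as the existing `_shown_batch_ids`: the deduplication state lives on the class rather than the instance, so the "watch" URL is printed at most once per process even if a fresh `EventAlertMessageLogger` is constructed per scenario. A minimal sketch of the pattern (the class name and URL below are illustrative, not from the package):

    from typing import Set

    class OncePerProcessNotifier:
        # Class-level set shared by every instance in the process,
        # mirroring EventAlertMessageLogger._shown_watch_urls above.
        _shown_urls: Set[str] = set()

        def show_once(self, url: str) -> None:
            if url in OncePerProcessNotifier._shown_urls:
                return  # already printed by this or another instance
            OncePerProcessNotifier._shown_urls.add(url)
            print(f"Watch the simulation live: {url}")

    a, b = OncePerProcessNotifier(), OncePerProcessNotifier()
    a.show_once("https://example.com/simulations/run-1")  # prints
    b.show_once("https://example.com/simulations/run-1")  # deduplicated, silent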
scenario/_events/utils.py CHANGED
@@ -1,5 +1,6 @@
  import warnings
- from openai.types.chat.chat_completion_message_param import ChatCompletionMessageParam
+
+ from ..types import ChatCompletionMessageParamWithTrace
  from .events import MessageType
  from .messages import (
      SystemMessage,
@@ -12,7 +13,10 @@ from .messages import (
  from typing import List
  from pksuid import PKSUID
  
- def convert_messages_to_api_client_messages(messages: list[ChatCompletionMessageParam]) -> list[MessageType]:
+
+ def convert_messages_to_api_client_messages(
+     messages: list[ChatCompletionMessageParamWithTrace],
+ ) -> list[MessageType]:
      """
      Converts OpenAI ChatCompletionMessageParam messages to API client Message format.
  
@@ -33,7 +37,7 @@ def convert_messages_to_api_client_messages(messages: list[ChatCompletionMessage
  
      for i, message in enumerate(messages):
          # Generate unique ID for each message
-         message_id = message.get("id") or str(PKSUID('scenariomsg'))
+         message_id = message.get("id") or str(PKSUID("scenariomsg"))
  
          role = message.get("role")
          content = message.get("content")
@@ -41,11 +45,13 @@ def convert_messages_to_api_client_messages(messages: list[ChatCompletionMessage
          if role == "user":
              if not content:
                  raise ValueError(f"User message at index {i} missing required content")
-             converted_messages.append(UserMessage(
+             message_ = UserMessage(
                  id=message_id,
                  role="user",
-                 content=str(content)
-             ))
+                 content=str(content),
+             )
+             message_.additional_properties = {"trace_id": message.get("trace_id")}
+             converted_messages.append(message_)
          elif role == "assistant":
              # Handle tool calls if present
              tool_calls = message.get("tool_calls")
@@ -53,44 +59,54 @@ def convert_messages_to_api_client_messages(messages: list[ChatCompletionMessage
  
              if tool_calls:
                  for tool_call in tool_calls:
-                     api_tool_calls.append(ToolCall(
-                         id=tool_call.get("id", str(PKSUID('scenariotoolcall'))),
-                         type_="function",
-                         function=FunctionCall(
-                             name=tool_call["function"].get("name", "unknown"),
-                             arguments=tool_call["function"].get("arguments", "{}")
-                         )
-                     ))
+                     api_tool_calls.append(
+                         ToolCall(
+                             id=tool_call.get("id", str(PKSUID("scenariotoolcall"))),
+                             type_="function",
+                             function=FunctionCall(
+                                 name=tool_call["function"].get("name", "unknown"),
+                                 arguments=tool_call["function"].get("arguments", "{}"),
+                             ),
+                         )
+                     )
  
-             converted_messages.append(AssistantMessage(
+             message_ = AssistantMessage(
                  id=message_id,
                  role="assistant",
                  content=str(content),
-                 tool_calls=api_tool_calls
-             ))
+                 tool_calls=api_tool_calls,
+             )
+             message_.additional_properties = {"trace_id": message.get("trace_id")}
+             converted_messages.append(message_)
          elif role == "system":
              if not content:
-                 raise ValueError(f"System message at index {i} missing required content")
-             converted_messages.append(SystemMessage(
-                 id=message_id,
-                 role="system",
-                 content=str(content)
-             ))
+                 raise ValueError(
+                     f"System message at index {i} missing required content"
+                 )
+             message_ = SystemMessage(id=message_id, role="system", content=str(content))
+             message_.additional_properties = {"trace_id": message.get("trace_id")}
+             converted_messages.append(message_)
          elif role == "tool":
              tool_call_id = message.get("tool_call_id")
              if not tool_call_id:
-                 warnings.warn(f"Tool message at index {i} missing required tool_call_id, skipping tool message")
+                 warnings.warn(
+                     f"Tool message at index {i} missing required tool_call_id, skipping tool message"
+                 )
                  continue
              if not content:
-                 warnings.warn(f"Tool message at index {i} missing required content, skipping tool message")
+                 warnings.warn(
+                     f"Tool message at index {i} missing required content, skipping tool message"
+                 )
                  continue
  
-             converted_messages.append(ToolMessage(
+             message_ = ToolMessage(
                  id=message_id,
                  role="tool",
                  content=str(content),
-                 tool_call_id=tool_call_id
-             ))
+                 tool_call_id=tool_call_id,
+             )
+             message_.additional_properties = {"trace_id": message.get("trace_id")}
+             converted_messages.append(message_)
          else:
              raise ValueError(f"Unsupported message role '{role}' at index {i}")
  
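Every converted message now carries the originating trace id in the API model's `additional_properties`, which is how each message of a scenario run gets tied back to a LangWatch trace. A rough usage sketch, assuming the executor has already stamped `trace_id` onto the stored message dict (the literal ids here are illustrative):

    from scenario._events.utils import convert_messages_to_api_client_messages

    # Shaped like ChatCompletionMessageParamWithTrace: a normal OpenAI
    # message dict plus the "trace_id" key added by ScenarioExecutor.
    messages = [
        {"role": "user", "content": "What's the weather in Paris?", "trace_id": "trace_123"},
    ]

    converted = convert_messages_to_api_client_messages(messages)
    # The trace id rides along on the API client model:
    assert converted[0].additional_properties == {"trace_id": "trace_123"}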
scenario/judge_agent.py CHANGED
@@ -12,7 +12,8 @@ import logging
  import re
  from typing import List, Optional, cast
  
- from litellm import Choices, completion
+ import litellm
+ from litellm import Choices
  from litellm.files.main import ModelResponse
  
  from scenario.cache import scenario_cache
@@ -356,7 +357,7 @@ if you don't have enough information to make a verdict, say inconclusive with ma
  
          response = cast(
              ModelResponse,
-             completion(
+             litellm.completion(
                  model=self.model,
                  messages=messages,
                  temperature=self.temperature,
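The import change here (and the identical one in `scenario/user_simulator_agent.py` below) is what lets the new `trace.autotrack_litellm_calls(litellm)` call in `scenario_executor.py` take effect: patching replaces the `completion` attribute on the `litellm` module, and only call sites that look the attribute up at call time see the replacement. A self-contained sketch of the difference, using a stand-in namespace rather than litellm itself:

    import types

    fake_litellm = types.SimpleNamespace(completion=lambda: "real call")

    # Early binding: `from litellm import completion` keeps a direct
    # reference to the original function object.
    completion = fake_litellm.completion

    # A tracer monkey-patches the module attribute.
    fake_litellm.completion = lambda: "traced call"

    print(completion())               # "real call"   -- bypasses the patch
    print(fake_litellm.completion())  # "traced call" -- sees the patch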
scenario/scenario_executor.py CHANGED
@@ -6,6 +6,7 @@ of scenario tests, managing the interaction between user simulators, agents unde
  and judge agents to determine test success or failure.
  """
  
+ import json
  import sys
  from typing import (
      Awaitable,
@@ -17,6 +18,7 @@ from typing import (
      Tuple,
      Union,
      TypedDict,
+     cast,
  )
  import time
  import warnings
@@ -33,6 +35,7 @@ from scenario._utils import (
      await_if_awaitable,
      get_batch_run_id,
      generate_scenario_run_id,
+     SerializableWithStringFallback,
  )
  from openai.types.chat import (
      ChatCompletionMessageParam,
@@ -40,7 +43,7 @@ from openai.types.chat import (
      ChatCompletionAssistantMessageParam,
  )
  
- from .types import AgentInput, AgentRole, ScenarioResult, ScriptStep
+ from .types import AgentInput, AgentRole, ChatCompletionMessageParamWithTrace, ScenarioResult, ScriptStep
  from ._error_messages import agent_response_not_awaitable
  from .cache import context_scenario
  from .agent_adapter import AgentAdapter
@@ -62,6 +65,11 @@ from ._events import (
  from rx.subject.subject import Subject
  from rx.core.observable.observable import Observable
  
+ import litellm
+ import langwatch
+ import langwatch.telemetry.context
+ from langwatch.telemetry.tracing import LangWatchTrace
+
  
  class ScenarioExecutor:
      """
@@ -101,6 +109,7 @@ class ScenarioExecutor:
      _pending_agents_on_turn: Set[AgentAdapter] = set()
      _agent_times: Dict[int, float] = {}
      _events: Subject
+     _trace: LangWatchTrace
  
      event_bus: ScenarioEventBus
  
@@ -157,7 +166,8 @@
          )
          self.config = (ScenarioConfig.default_config or ScenarioConfig()).merge(config)
  
-         self.reset()
+         self.batch_run_id = get_batch_run_id()
+         self.scenario_set_id = set_id or "default"
  
          # Create executor's own event stream
          self._events = Subject()
@@ -166,9 +176,6 @@
          self.event_bus = event_bus or ScenarioEventBus()
          self.event_bus.subscribe_to_events(self._events)
  
-         self.batch_run_id = get_batch_run_id()
-         self.scenario_set_id = set_id or "default"
-
      @property
      def events(self) -> Observable:
          """Expose event stream for subscribers like the event bus."""
@@ -253,6 +260,8 @@
          )
          ```
          """
+         message = cast(ChatCompletionMessageParamWithTrace, message)
+         message["trace_id"] = self._trace.trace_id
          self._state.messages.append(message)
  
          # Broadcast the message to other agents
@@ -263,6 +272,21 @@
              self._pending_messages[idx] = []
          self._pending_messages[idx].append(message)
  
+         # Update trace with input/output
+         if message["role"] == "user":
+             self._trace.update(input={"type": "text", "value": str(message["content"])})
+         elif message["role"] == "assistant":
+             self._trace.update(
+                 output={
+                     "type": "text",
+                     "value": str(
+                         message["content"]
+                         if "content" in message
+                         else json.dumps(message, cls=SerializableWithStringFallback)
+                     ),
+                 }
+             )
+
      def add_messages(
          self,
          messages: List[ChatCompletionMessageParam],
@@ -292,6 +316,21 @@
              self.add_message(message, from_agent_idx)
  
      def _new_turn(self):
+         if hasattr(self, "_trace") and self._trace is not None:
+             self._trace.__exit__(None, None, None)
+
+         self._trace = langwatch.trace(
+             name="Scenario Turn",
+             metadata={
+                 "labels": ["scenario"],
+                 "thread_id": self._state.thread_id,
+                 "scenario.name": self.name,
+                 "scenario.batch_id": self.batch_run_id,
+                 "scenario.set_id": self.scenario_set_id,
+                 "scenario.turn": self._state.current_turn,
+             },
+         ).__enter__()
+
          self._pending_agents_on_turn = set(self.agents)
          self._pending_roles_on_turn = [
              AgentRole.USER,
@@ -460,7 +499,7 @@
  
      async def _call_agent(
          self, idx: int, role: AgentRole, request_judgment: bool = False
-     ) -> Union[List[ChatCompletionMessageParam], ScenarioResult]:
+     ) -> Union[List[ChatCompletionMessageParam], ScenarioResult, None]:
          agent = self.agents[idx]
  
          if role == AgentRole.USER and self.config.debug:
@@ -482,67 +521,84 @@
              ChatCompletionUserMessageParam(role="user", content=input_message)
          ]
  
-         with show_spinner(
-             text=(
-                 "Judging..."
-                 if role == AgentRole.JUDGE
-                 else f"{role.value if isinstance(role, AgentRole) else role}:"
-             ),
-             color=(
-                 "blue"
-                 if role == AgentRole.AGENT
-                 else "green" if role == AgentRole.USER else "yellow"
-             ),
-             enabled=self.config.verbose,
-         ):
-             start_time = time.time()
-
-             # Prevent pydantic validation warnings which should already be disabled
-             with warnings.catch_warnings():
-                 warnings.simplefilter("ignore")
-                 agent_response = agent.call(
-                     AgentInput(
-                         # TODO: test thread_id
-                         thread_id=self._state.thread_id,
-                         messages=self._state.messages,
-                         new_messages=self._pending_messages.get(idx, []),
-                         judgment_request=request_judgment,
-                         scenario_state=self._state,
-                     )
-                 )
-                 if not isinstance(agent_response, Awaitable):
-                     raise Exception(
-                         agent_response_not_awaitable(agent.__class__.__name__),
-                     )
-
-                 agent_response = await agent_response
-
-                 if idx not in self._agent_times:
-                     self._agent_times[idx] = 0
-                 self._agent_times[idx] += time.time() - start_time
-
-             self._pending_messages[idx] = []
-             check_valid_return_type(agent_response, agent.__class__.__name__)
-
-             messages = []
-             if isinstance(agent_response, ScenarioResult):
-                 # TODO: should be an event
-                 return agent_response
-             else:
-                 messages = convert_agent_return_types_to_openai_messages(
-                     agent_response,
-                     role="user" if role == AgentRole.USER else "assistant",
-                 )
-
-             self.add_messages(messages, from_agent_idx=idx)
-
-             if messages and self.config.verbose:
-                 print_openai_messages(
-                     self._scenario_name(),
-                     [m for m in messages if m["role"] != "system"],
-                 )
-
-             return messages
+         with self._trace.span(type="agent", name=f"{agent.__class__.__name__}.call") as span:
+             with show_spinner(
+                 text=(
+                     "Judging..."
+                     if role == AgentRole.JUDGE
+                     else f"{role.value if isinstance(role, AgentRole) else role}:"
+                 ),
+                 color=(
+                     "blue"
+                     if role == AgentRole.AGENT
+                     else "green" if role == AgentRole.USER else "yellow"
+                 ),
+                 enabled=self.config.verbose,
+             ):
+                 start_time = time.time()
+
+                 # Prevent pydantic validation warnings which should already be disabled
+                 with warnings.catch_warnings():
+                     warnings.simplefilter("ignore")
+
+                     self._trace.autotrack_litellm_calls(litellm)
+
+                     agent_response = agent.call(
+                         AgentInput(
+                             # TODO: test thread_id
+                             thread_id=self._state.thread_id,
+                             messages=cast(List[ChatCompletionMessageParam], self._state.messages),
+                             new_messages=self._pending_messages.get(idx, []),
+                             judgment_request=request_judgment,
+                             scenario_state=self._state,
+                         )
+                     )
+                     if not isinstance(agent_response, Awaitable):
+                         raise Exception(
+                             agent_response_not_awaitable(agent.__class__.__name__),
+                         )
+
+                     agent_response = await agent_response
+
+                     if idx not in self._agent_times:
+                         self._agent_times[idx] = 0
+                     self._agent_times[idx] += time.time() - start_time
+
+                 self._pending_messages[idx] = []
+                 check_valid_return_type(agent_response, agent.__class__.__name__)
+
+                 messages = []
+                 if isinstance(agent_response, ScenarioResult):
+                     # TODO: should be an event
+                     span.add_evaluation(
+                         name=f"{agent.__class__.__name__} Judgment",
+                         status="processed",
+                         passed=agent_response.success,
+                         details=agent_response.reasoning,
+                         score=(
+                             len(agent_response.passed_criteria)
+                             / len(agent_response.failed_criteria)
+                             if agent_response.failed_criteria
+                             else 1.0
+                         ),
+                     )
+
+                     return agent_response
+                 else:
+                     messages = convert_agent_return_types_to_openai_messages(
+                         agent_response,
+                         role="user" if role == AgentRole.USER else "assistant",
+                     )
+
+                 self.add_messages(messages, from_agent_idx=idx)
+
+                 if messages and self.config.verbose:
+                     print_openai_messages(
+                         self._scenario_name(),
+                         [m for m in messages if m["role"] != "system"],
+                     )
+
+                 return messages
  
      def _scenario_name(self):
          if self.config.verbose == 2:
@@ -817,6 +873,7 @@
  
          # Signal end of event stream
          self._events.on_completed()
+         self._trace.__exit__(None, None, None)
  
  
      async def run(
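Taken together, the executor changes give each scenario turn its own LangWatch trace: `_new_turn` closes the previous turn's trace and enters a new one, `_call_agent` wraps every agent invocation in an `agent` span (recording judge verdicts via `span.add_evaluation`), and the `__exit__` added above closes the final trace when the event stream completes. A condensed sketch of that lifecycle, using only the calls that appear in this diff (the metadata values and agent name are illustrative):

    import litellm
    import langwatch

    # Entered/exited manually instead of via `with`, because the turn
    # boundary is driven by the executor's loop, not lexical scope.
    trace = langwatch.trace(
        name="Scenario Turn",
        metadata={"labels": ["scenario"], "scenario.turn": 1},
    ).__enter__()
    try:
        with trace.span(type="agent", name="MyAgent.call") as span:
            # LLM calls made through litellm inside this span are
            # captured onto the current trace.
            trace.autotrack_litellm_calls(litellm)
            ...  # run the agent; a judge verdict would call span.add_evaluation(...)
    finally:
        trace.__exit__(None, None, None)  # close the turn's trace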
scenario/scenario_state.py CHANGED
@@ -14,6 +14,7 @@ from openai.types.chat import (
  )
  from pydantic import BaseModel
  
+ from scenario.types import ChatCompletionMessageParamWithTrace
  from scenario.config import ScenarioConfig
  
  if TYPE_CHECKING:
@@ -70,7 +71,7 @@ class ScenarioState(BaseModel):
      """
  
      description: str
-     messages: List[ChatCompletionMessageParam]
+     messages: List[ChatCompletionMessageParamWithTrace]
      thread_id: str
      current_turn: int
      config: ScenarioConfig
scenario/types.py CHANGED
@@ -8,10 +8,20 @@ from typing import (
      Callable,
      List,
      Optional,
+     TypeAlias,
      Union,
  )
  
- from openai.types.chat import ChatCompletionMessageParam, ChatCompletionUserMessageParam
+ from openai.types.chat import (
+     ChatCompletionMessageParam,
+     ChatCompletionUserMessageParam,
+     ChatCompletionToolMessageParam,
+     ChatCompletionUserMessageParam,
+     ChatCompletionSystemMessageParam,
+     ChatCompletionFunctionMessageParam,
+     ChatCompletionAssistantMessageParam,
+     ChatCompletionDeveloperMessageParam,
+ )
  
  # Prevent circular imports + Pydantic breaking
  if TYPE_CHECKING:
@@ -22,6 +32,48 @@ else:
      ScenarioStateType = Any
  
  
+ # Since Python types do not support intersection, we need to wrap ALL the chat completion
+ # message types with the trace_id field
+
+
+ class ChatCompletionDeveloperMessageParamWithTrace(ChatCompletionDeveloperMessageParam):
+     trace_id: Optional[str]
+
+
+ class ChatCompletionSystemMessageParamWithTrace(ChatCompletionSystemMessageParam):
+     trace_id: Optional[str]
+
+
+ class ChatCompletionUserMessageParamWithTrace(ChatCompletionUserMessageParam):
+     trace_id: Optional[str]
+
+
+ class ChatCompletionAssistantMessageParamWithTrace(ChatCompletionAssistantMessageParam):
+     trace_id: Optional[str]
+
+
+ class ChatCompletionToolMessageParamWithTrace(ChatCompletionToolMessageParam):
+     trace_id: Optional[str]
+
+
+ class ChatCompletionFunctionMessageParamWithTrace(ChatCompletionFunctionMessageParam):
+     trace_id: Optional[str]
+
+
+ """
+ A wrapper around ChatCompletionMessageParam that adds a trace_id field to be able to
+ tie back each message of the scenario run to a trace.
+ """
+ ChatCompletionMessageParamWithTrace: TypeAlias = Union[
+     ChatCompletionDeveloperMessageParamWithTrace,
+     ChatCompletionSystemMessageParamWithTrace,
+     ChatCompletionUserMessageParamWithTrace,
+     ChatCompletionAssistantMessageParamWithTrace,
+     ChatCompletionToolMessageParamWithTrace,
+     ChatCompletionFunctionMessageParamWithTrace,
+ ]
+
+
  class AgentRole(Enum):
      """
      Defines the different roles that agents can play in a scenario.
@@ -171,7 +223,7 @@ class ScenarioResult(BaseModel):
  
      success: bool
      # Prevent issues with slightly inconsistent message types for example when comming from Gemini right at the result level
-     messages: Annotated[List[ChatCompletionMessageParam], SkipValidation]
+     messages: Annotated[List[ChatCompletionMessageParamWithTrace], SkipValidation]
      reasoning: Optional[str] = None
      passed_criteria: List[str] = []
      failed_criteria: List[str] = []
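Two notes on this file. First, `typing.TypeAlias` only exists on Python 3.10+, which lines up with the `Requires-Python: >=3.10` bump in the METADATA above. Second, since TypedDicts cannot be intersected, the only way to add a `trace_id` key across the whole OpenAI message union is to subclass every member and rebuild the union, as done here. A reduced sketch of the pattern with two hypothetical message shapes (not the library's own):

    from typing import Optional, TypeAlias, TypedDict, Union

    class UserMsg(TypedDict):
        role: str
        content: str

    class ToolMsg(TypedDict):
        role: str
        tool_call_id: str

    # No "UserMsg & {trace_id: ...}" intersection exists in Python's
    # type system, so each member gains the key via subclassing ...
    class UserMsgWithTrace(UserMsg):
        trace_id: Optional[str]

    class ToolMsgWithTrace(ToolMsg):
        trace_id: Optional[str]

    # ... and the union is rebuilt from the widened members.
    MsgWithTrace: TypeAlias = Union[UserMsgWithTrace, ToolMsgWithTrace]

    msg: MsgWithTrace = {"role": "user", "content": "hi", "trace_id": "t-1"}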
scenario/user_simulator_agent.py CHANGED
@@ -10,7 +10,8 @@ conversation history.
  import logging
  from typing import Optional, cast
  
- from litellm import Choices, completion
+ import litellm
+ from litellm import Choices
  from litellm.files.main import ModelResponse
  
  from scenario.cache import scenario_cache
@@ -228,7 +229,7 @@ Your goal (assistant) is to interact with the Agent Under Test (user) as if you
  
          response = cast(
              ModelResponse,
-             completion(
+             litellm.completion(
                  model=self.model,
                  messages=messages,
                  temperature=self.temperature,