langwatch-scenario 0.2.0__py3-none-any.whl → 0.3.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,40 +2,201 @@
2
2
  ScenarioExecutor module: holds the scenario execution logic and state, orchestrating the conversation between the testing agent and the agent under test.
3
3
  """
4
4
 
5
- import json
6
5
  import sys
7
- from typing import TYPE_CHECKING, Awaitable, Dict, List, Any, Optional, Union
6
+ from typing import (
7
+ TYPE_CHECKING,
8
+ Awaitable,
9
+ Callable,
10
+ Dict,
11
+ List,
12
+ Any,
13
+ Optional,
14
+ Set,
15
+ Tuple,
16
+ Union,
17
+ )
8
18
  import time
9
19
  import termcolor
10
20
 
11
- from scenario.error_messages import message_return_error_message
12
- from scenario.utils import print_openai_messages, safe_attr_or_key, safe_list_at, show_spinner
13
- from openai.types.chat import ChatCompletionMessageParam
14
-
15
- from .result import ScenarioResult
16
- from .error_messages import default_config_error_message
21
+ from scenario.utils import (
22
+ await_if_awaitable,
23
+ check_valid_return_type,
24
+ convert_agent_return_types_to_openai_messages,
25
+ print_openai_messages,
26
+ show_spinner,
27
+ )
28
+ from openai.types.chat import (
29
+ ChatCompletionMessageParam,
30
+ ChatCompletionUserMessageParam,
31
+ ChatCompletionMessageToolCallParam,
32
+ )
33
+
34
+ from .types import AgentInput, ScenarioAgentRole, ScenarioResult, ScriptStep
35
+ from .error_messages import agent_response_not_awaitable
17
36
  from .cache import context_scenario
37
+ from .scenario_agent_adapter import ScenarioAgentAdapter
38
+ from pksuid import PKSUID
18
39
 
19
40
  if TYPE_CHECKING:
20
41
  from scenario.scenario import Scenario
21
42
 
22
43
 
23
-
24
44
  class ScenarioExecutor:
25
- def __init__(self, scenario: "Scenario"):
45
+ scenario: "Scenario"
46
+ messages: List[ChatCompletionMessageParam]
47
+ thread_id: str
48
+ current_turn: int
49
+
50
+ _context: Optional[Dict[str, Any]]
51
+ _script: List[ScriptStep]
52
+ _agents: List[ScenarioAgentAdapter]
53
+ _total_start_time: float
54
+ _pending_messages: Dict[int, List[ChatCompletionMessageParam]]
55
+
56
+ _pending_roles_on_turn: List[ScenarioAgentRole] = []
57
+ _pending_agents_on_turn: Set[ScenarioAgentAdapter] = set()
58
+ _agent_times: Dict[int, float] = {}
59
+
60
+ def __init__(
61
+ self,
62
+ scenario: "Scenario",
63
+ context: Optional[Dict[str, Any]] = None,
64
+ script: Optional[List[ScriptStep]] = None,
65
+ ):
66
+ super().__init__()
67
+
26
68
  self.scenario = scenario.model_copy()
69
+ self._context = context
70
+ self._script = script or [scenario.proceed()]
71
+ self.current_turn = 0
72
+ self.reset()
73
+
74
+ def reset(self):
75
+ self.messages = []
76
+ self._agents = []
77
+ self._pending_messages = {}
78
+ self.thread_id = str(PKSUID("thread"))
79
+ self._total_start_time = time.time()
80
+ self._agent_times = {}
81
+
82
+ for AgentClass in self.scenario.agents:
83
+ self._agents.append(
84
+ AgentClass(
85
+ input=AgentInput(
86
+ thread_id=self.thread_id,
87
+ messages=[],
88
+ new_messages=[],
89
+ context=self._context or {},
90
+ requested_role=list(AgentClass.roles)[0],
91
+ scenario_state=self,
92
+ )
93
+ )
94
+ )
95
+
96
+ self._new_turn()
97
+ self.current_turn = 0
27
98
 
28
- testing_agent = scenario.testing_agent
29
- if not testing_agent or not testing_agent.model:
30
- raise Exception(default_config_error_message)
31
- self.testing_agent = testing_agent
99
+ context_scenario.set(self.scenario)
100
+
101
+ def add_message(
102
+ self, message: ChatCompletionMessageParam, from_agent_idx: Optional[int] = None
103
+ ):
104
+ self.messages.append(message)
32
105
 
33
- self.conversation: List[Dict[str, Any]] = []
106
+ # Broadcast the message to other agents
107
+ for idx, _ in enumerate(self._agents):
108
+ if idx == from_agent_idx:
109
+ continue
110
+ if idx not in self._pending_messages:
111
+ self._pending_messages[idx] = []
112
+ self._pending_messages[idx].append(message)
34
113
 
35
- async def run(
114
+ def add_messages(
36
115
  self,
37
- context: Optional[Dict[str, Any]] = None,
38
- ) -> ScenarioResult:
116
+ messages: List[ChatCompletionMessageParam],
117
+ from_agent_idx: Optional[int] = None,
118
+ ):
119
+ for message in messages:
120
+ self.add_message(message, from_agent_idx)
121
+
122
+ def _new_turn(self):
123
+ self._pending_agents_on_turn = set(self._agents)
124
+ self._pending_roles_on_turn = [
125
+ ScenarioAgentRole.USER,
126
+ ScenarioAgentRole.AGENT,
127
+ ScenarioAgentRole.JUDGE,
128
+ ]
129
+ self.current_turn += 1
130
+
131
+ async def step(self) -> Union[List[ChatCompletionMessageParam], ScenarioResult]:
132
+ result = await self._step()
133
+ if result is None:
134
+ raise ValueError("No result from step")
135
+ return result
136
+
137
+ async def _step(
138
+ self,
139
+ go_to_next_turn=True,
140
+ on_turn: Optional[
141
+ Union[
142
+ Callable[["ScenarioExecutor"], None],
143
+ Callable[["ScenarioExecutor"], Awaitable[None]],
144
+ ]
145
+ ] = None,
146
+ ) -> Union[List[ChatCompletionMessageParam], ScenarioResult, None]:
147
+ if len(self._pending_roles_on_turn) == 0:
148
+ if not go_to_next_turn:
149
+ return None
150
+
151
+ self._new_turn()
152
+
153
+ if on_turn:
154
+ await await_if_awaitable(on_turn(self))
155
+
156
+ if self.current_turn >= (self.scenario.max_turns or 10):
157
+ return self._reached_max_turns()
158
+
159
+ current_role = self._pending_roles_on_turn[0]
160
+ idx, next_agent = self._next_agent_for_role(current_role)
161
+ if not next_agent:
162
+ self._pending_roles_on_turn.pop(0)
163
+ return await self._step(go_to_next_turn=go_to_next_turn, on_turn=on_turn)
164
+
165
+ self._pending_agents_on_turn.remove(next_agent)
166
+ return await self._call_agent(idx, role=current_role)
167
+
168
+ def _next_agent_for_role(
169
+ self, role: ScenarioAgentRole
170
+ ) -> Tuple[int, Optional[ScenarioAgentAdapter]]:
171
+ for idx, agent in enumerate(self._agents):
172
+ if role in agent.roles and agent in self._pending_agents_on_turn:
173
+ return idx, agent
174
+ return -1, None
175
+
176
+ def _reached_max_turns(self, error_message: Optional[str] = None) -> ScenarioResult:
177
+ # If we reached max turns without conclusion, fail the test
178
+ agent_roles_agents_idx = [
179
+ idx
180
+ for idx, agent in enumerate(self._agents)
181
+ if ScenarioAgentRole.AGENT in agent.roles
182
+ ]
183
+ agent_times = [
184
+ self._agent_times[idx]
185
+ for idx in agent_roles_agents_idx
186
+ if idx in self._agent_times
187
+ ]
188
+ agent_time = sum(agent_times)
189
+
190
+ return ScenarioResult(
191
+ success=False,
192
+ messages=self.messages,
193
+ reasoning=error_message
194
+ or f"Reached maximum turns ({self.scenario.max_turns or 10}) without conclusion",
195
+ total_time=time.time() - self._total_start_time,
196
+ agent_time=agent_time,
197
+ )
198
+
199
+ async def run(self) -> ScenarioResult:
39
200
  """
40
201
  Run a scenario against the agent under test.
41
202
 
@@ -49,156 +210,257 @@ class ScenarioExecutor:
49
210
  if self.scenario.verbose:
50
211
  print("") # new line
51
212
 
52
- # Run the initial testing agent prompt to get started
53
- total_start_time = time.time()
54
- context_scenario.set(self.scenario)
55
- next_message = self._generate_next_message(
56
- self.scenario, self.conversation, first_message=True
57
- )
213
+ self.reset()
58
214
 
59
- if isinstance(next_message, ScenarioResult):
60
- raise Exception(
61
- "Unexpectedly generated a ScenarioResult for the initial message",
62
- next_message.__repr__(),
63
- )
64
- elif self.scenario.verbose:
65
- print(self._scenario_name() + termcolor.colored("User:", "green"), next_message)
215
+ for script_step in self._script:
216
+ callable = script_step(self)
217
+ if isinstance(callable, Awaitable):
218
+ result = await callable
219
+ else:
220
+ result = callable
66
221
 
67
- # Execute the conversation
68
- current_turn = 0
69
- max_turns = self.scenario.max_turns or 10
70
- agent_time = 0
222
+ if isinstance(result, ScenarioResult):
223
+ return result
71
224
 
72
- # Start the test with the initial message
73
- while current_turn < max_turns:
74
- # Record the testing agent's message
75
- self.conversation.append({"role": "user", "content": next_message})
225
+ return self._reached_max_turns(
226
+ """Reached end of script without conclusion, add one of the following to the end of the script:
76
227
 
77
- # Get response from the agent under test
78
- start_time = time.time()
228
+ - `scenario.proceed()` to let the simulation continue to play out
229
+ - `scenario.judge()` to force criteria judgement
230
+ - `scenario.succeed()` or `scenario.fail()` to end the test with an explicit result
231
+ """
232
+ )
79
233
 
80
- context_scenario.set(self.scenario)
81
- with show_spinner(text="Agent:", color="blue", enabled=self.scenario.verbose):
82
- agent_response = self.scenario.agent(next_message, context)
83
- if isinstance(agent_response, Awaitable):
84
- agent_response = await agent_response
234
+ async def _call_agent(
235
+ self, idx: int, role: ScenarioAgentRole
236
+ ) -> Union[List[ChatCompletionMessageParam], ScenarioResult]:
237
+ agent = self._agents[idx]
85
238
 
86
- has_valid_message = (
87
- "message" in agent_response
88
- and type(agent_response["message"]) is str
89
- and agent_response["message"] is not None
239
+ if role == ScenarioAgentRole.USER and self.scenario.debug:
240
+ print(
241
+ f"\n{self._scenario_name()}{termcolor.colored('[Debug Mode]', 'yellow')} Press enter to continue or type a message to send"
90
242
  )
91
- has_valid_messages = (
92
- "messages" in agent_response
93
- and isinstance(agent_response["messages"], list)
94
- and all(
95
- "role" in msg or hasattr(msg, "role")
96
- for msg in agent_response["messages"]
243
+ input_message = input(
244
+ self._scenario_name() + termcolor.colored("User: ", "green")
245
+ )
246
+
247
+ # Clear the input prompt lines completely
248
+ for _ in range(3):
249
+ sys.stdout.write("\033[F") # Move up to the input line
250
+ sys.stdout.write("\033[2K") # Clear the entire input line
251
+ sys.stdout.flush() # Make sure the clearing is visible
252
+
253
+ if input_message:
254
+ return [
255
+ ChatCompletionUserMessageParam(role="user", content=input_message)
256
+ ]
257
+
258
+ with show_spinner(
259
+ text=(
260
+ "Judging..."
261
+ if role == ScenarioAgentRole.JUDGE
262
+ else f"{role.value if isinstance(role, ScenarioAgentRole) else role}:"
263
+ ),
264
+ color=(
265
+ "blue"
266
+ if role == ScenarioAgentRole.AGENT
267
+ else "green" if role == ScenarioAgentRole.USER else "yellow"
268
+ ),
269
+ enabled=self.scenario.verbose,
270
+ ):
271
+ start_time = time.time()
272
+
273
+ agent_response = agent.call(
274
+ AgentInput(
275
+ # TODO: test thread_id
276
+ thread_id=self.thread_id,
277
+ messages=self.messages,
278
+ new_messages=self._pending_messages.get(idx, []),
279
+ # TODO: test context
280
+ context=self._context or {},
281
+ requested_role=role,
282
+ scenario_state=self,
97
283
  )
98
284
  )
99
- if not has_valid_message and not has_valid_messages:
100
- raise Exception(message_return_error_message(agent_response))
285
+ if not isinstance(agent_response, Awaitable):
286
+ raise Exception(
287
+ agent_response_not_awaitable(agent.__class__.__name__),
288
+ )
101
289
 
102
- messages: list[ChatCompletionMessageParam] = []
103
- if has_valid_messages and len(agent_response["messages"]) > 0:
104
- messages = agent_response["messages"]
290
+ agent_response = await agent_response
105
291
 
106
- # Drop the first messages both if they are system or user messages
107
- if safe_attr_or_key(safe_list_at(messages, 0), "role") == "system":
108
- messages = messages[1:]
109
- if safe_attr_or_key(safe_list_at(messages, 0), "role") == "user":
110
- messages = messages[1:]
292
+ if idx not in self._agent_times:
293
+ self._agent_times[idx] = 0
294
+ self._agent_times[idx] += time.time() - start_time
111
295
 
112
- if has_valid_message and self.scenario.verbose:
113
- print(self._scenario_name() + termcolor.colored("Agent:", "blue"), agent_response["message"])
296
+ self._pending_messages[idx] = []
297
+ check_valid_return_type(agent_response, agent.__class__.__name__)
114
298
 
115
- if messages and self.scenario.verbose:
116
- print_openai_messages(self._scenario_name(), messages)
117
-
118
- if (
119
- self.scenario.verbose
120
- and "extra" in agent_response
121
- and len(agent_response["extra"].keys()) > 0
122
- ):
123
- print(
124
- termcolor.colored(
125
- "Extra:" + json.dumps(agent_response["extra"]),
126
- "magenta",
127
- )
299
+ messages = []
300
+ if isinstance(agent_response, ScenarioResult):
301
+ # TODO: should be an event
302
+ return agent_response
303
+ else:
304
+ messages = convert_agent_return_types_to_openai_messages(
305
+ agent_response,
306
+ role="user" if role == ScenarioAgentRole.USER else "assistant",
128
307
  )
129
- response_time = time.time() - start_time
130
- agent_time += response_time
131
-
132
- if messages:
133
- self.conversation.extend(agent_response["messages"])
134
- if "message" in agent_response:
135
- self.conversation.append(
136
- {"role": "assistant", "content": agent_response["message"]}
137
- )
138
- if "extra" in agent_response:
139
- self.conversation.append(
140
- {
141
- "role": "assistant",
142
- "content": json.dumps(agent_response["extra"]),
143
- }
308
+
309
+ self.add_messages(messages, from_agent_idx=idx)
310
+
311
+ if messages and self.scenario.verbose:
312
+ print_openai_messages(
313
+ self._scenario_name(),
314
+ [m for m in messages if m["role"] != "system"],
144
315
  )
145
316
 
146
- # Generate the next message OR finish the test based on the agent's evaluation
147
- result = self._generate_next_message(
148
- self.scenario,
149
- self.conversation,
150
- last_message=current_turn == max_turns - 1,
317
+ return messages
318
+
319
+ def _scenario_name(self):
320
+ if self.scenario.verbose == 2:
321
+ return termcolor.colored(f"[Scenario: {self.scenario.name}] ", "yellow")
322
+ else:
323
+ return ""
324
+
325
+ # State access utils
326
+
327
+ def last_message(self) -> ChatCompletionMessageParam:
328
+ if len(self.messages) == 0:
329
+ raise ValueError("No messages found")
330
+ return self.messages[-1]
331
+
332
+ def last_user_message(self) -> ChatCompletionUserMessageParam:
333
+ user_messages = [m for m in self.messages if m["role"] == "user"]
334
+ if not user_messages:
335
+ raise ValueError("No user messages found")
336
+ return user_messages[-1]
337
+
338
+ def last_tool_call(
339
+ self, tool_name: str
340
+ ) -> Optional[ChatCompletionMessageToolCallParam]:
341
+ for message in reversed(self.messages):
342
+ if message["role"] == "assistant" and "tool_calls" in message:
343
+ for tool_call in message["tool_calls"]:
344
+ if tool_call["function"]["name"] == tool_name:
345
+ return tool_call
346
+ return None
347
+
348
+ def has_tool_call(self, tool_name: str) -> bool:
349
+ return self.last_tool_call(tool_name) is not None
350
+
351
+ # Scripting utils
352
+
353
+ async def message(self, message: ChatCompletionMessageParam) -> None:
354
+ if message["role"] == "user":
355
+ await self._script_call_agent(ScenarioAgentRole.USER, message)
356
+ elif message["role"] == "assistant":
357
+ await self._script_call_agent(ScenarioAgentRole.AGENT, message)
358
+ else:
359
+ self.add_message(message)
360
+
361
+ async def user(
362
+ self, content: Optional[Union[str, ChatCompletionMessageParam]] = None
363
+ ) -> None:
364
+ await self._script_call_agent(ScenarioAgentRole.USER, content)
365
+
366
+ async def agent(
367
+ self, content: Optional[Union[str, ChatCompletionMessageParam]] = None
368
+ ) -> None:
369
+ await self._script_call_agent(ScenarioAgentRole.AGENT, content)
370
+
371
+ async def judge(
372
+ self, content: Optional[Union[str, ChatCompletionMessageParam]] = None
373
+ ) -> Optional[ScenarioResult]:
374
+ return await self._script_call_agent(ScenarioAgentRole.JUDGE, content)
375
+
376
+ async def proceed(
377
+ self,
378
+ turns: Optional[int] = None,
379
+ on_turn: Optional[
380
+ Union[
381
+ Callable[["ScenarioExecutor"], None],
382
+ Callable[["ScenarioExecutor"], Awaitable[None]],
383
+ ]
384
+ ] = None,
385
+ on_step: Optional[
386
+ Union[
387
+ Callable[["ScenarioExecutor"], None],
388
+ Callable[["ScenarioExecutor"], Awaitable[None]],
389
+ ]
390
+ ] = None,
391
+ ) -> Optional[ScenarioResult]:
392
+ initial_turn: Optional[int] = None
393
+ while True:
394
+ next_message = await self._step(
395
+ on_turn=on_turn,
396
+ go_to_next_turn=(
397
+ turns is None
398
+ or initial_turn is None
399
+ or (self.current_turn + 1 < initial_turn + turns)
400
+ ),
151
401
  )
152
402
 
153
- # Check if the result is a ScenarioResult (indicating test completion)
154
- if isinstance(result, ScenarioResult):
155
- result.total_time = time.time() - start_time
156
- result.agent_time = agent_time
157
- return result
158
- elif self.scenario.verbose:
159
- print(self._scenario_name() + termcolor.colored("User:", "green"), result)
403
+ if initial_turn is None:
404
+ initial_turn = self.current_turn
160
405
 
161
- # Otherwise, it's the next message to send to the agent
162
- next_message = result
406
+ if next_message is None:
407
+ break
163
408
 
164
- # Increment turn counter
165
- current_turn += 1
409
+ if on_step:
410
+ await await_if_awaitable(on_step(self))
166
411
 
167
- # If we reached max turns without conclusion, fail the test
168
- return ScenarioResult.failure_result(
169
- conversation=self.conversation,
170
- reasoning=f"Reached maximum turns ({max_turns}) without conclusion",
171
- total_time=time.time() - total_start_time,
172
- agent_time=agent_time,
412
+ if isinstance(next_message, ScenarioResult):
413
+ return next_message
414
+
415
+ async def succeed(self) -> ScenarioResult:
416
+ return ScenarioResult(
417
+ success=True,
418
+ messages=self.messages,
419
+ reasoning="Scenario marked as successful with scenario.succeed()",
420
+ passed_criteria=self.scenario.criteria,
173
421
  )
174
422
 
175
- def _generate_next_message(
423
+ async def fail(self) -> ScenarioResult:
424
+ return ScenarioResult(
425
+ success=False,
426
+ messages=self.messages,
427
+ reasoning="Scenario marked as failed with scenario.fail()",
428
+ passed_criteria=self.scenario.criteria,
429
+ )
430
+
431
+ async def _script_call_agent(
176
432
  self,
177
- scenario: "Scenario",
178
- conversation: List[Dict[str, Any]],
179
- first_message: bool = False,
180
- last_message: bool = False,
181
- ) -> Union[str, ScenarioResult]:
182
- if self.scenario.debug:
183
- print(f"\n{self._scenario_name()}{termcolor.colored('[Debug Mode]', 'yellow')} Press enter to continue or type a message to send")
184
- input_message = input(self._scenario_name() + termcolor.colored('User: ', 'green'))
433
+ role: ScenarioAgentRole,
434
+ content: Optional[Union[str, ChatCompletionMessageParam]] = None,
435
+ ) -> Optional[ScenarioResult]:
436
+ idx, next_agent = self._next_agent_for_role(role)
437
+ if not next_agent:
438
+ self._new_turn()
439
+ idx, next_agent = self._next_agent_for_role(role)
440
+
441
+ if not next_agent:
442
+ if content:
443
+ raise ValueError(
444
+ f"Cannot generate a message for role `{role.value}` with content `{content}` because no agent with this role was found"
445
+ )
446
+ raise ValueError(
447
+ f"Cannot generate a message for role `{role.value}` because no agent with this role was found"
448
+ )
185
449
 
186
- # Clear the input prompt lines completely
187
- for _ in range(3):
188
- sys.stdout.write("\033[F") # Move up to the input line
189
- sys.stdout.write("\033[2K") # Clear the entire input line
190
- sys.stdout.flush() # Make sure the clearing is visible
450
+ self._pending_agents_on_turn.remove(next_agent)
451
+ self._pending_roles_on_turn.remove(role)
191
452
 
192
- if input_message:
193
- return input_message
453
+ if content:
454
+ if isinstance(content, str):
455
+ message = ChatCompletionUserMessageParam(role="user", content=content)
456
+ else:
457
+ message = content
194
458
 
195
- with show_spinner(text=f"{self._scenario_name()}User:", color="green", enabled=self.scenario.verbose):
196
- return self.testing_agent.generate_next_message(
197
- scenario, conversation, first_message, last_message
198
- )
459
+ self.add_message(message)
460
+ if self.scenario.verbose:
461
+ print_openai_messages(self._scenario_name(), [message])
462
+ return
199
463
 
200
- def _scenario_name(self):
201
- if self.scenario.verbose == 2:
202
- return termcolor.colored(f"[Scenario: {self.scenario.name}] ", "yellow")
203
- else:
204
- return ""
464
+ result = await self._call_agent(idx, role=role)
465
+ if isinstance(result, ScenarioResult):
466
+ return result