langwatch-scenario 0.2.0__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
scenario/utils.py CHANGED
@@ -1,7 +1,26 @@
+"""
+Utility functions for scenario execution and message handling.
+
+This module provides various utility functions used throughout the Scenario framework,
+including message formatting, validation, role reversal, and UI components like spinners
+for better user experience during scenario execution.
+"""
+
 from contextlib import contextmanager
 import sys
-from typing import Optional, Union
+from typing import (
+    Any,
+    Iterator,
+    List,
+    Literal,
+    Optional,
+    Union,
+    TypeVar,
+    Awaitable,
+    cast,
+)
 from pydantic import BaseModel
+import copy
 
 import json
 
@@ -14,16 +33,53 @@ from rich.console import Console
 from rich.text import Text
 from rich.errors import LiveError
 
+from scenario.error_messages import message_return_error_message
+from scenario.types import AgentReturnTypes, ScenarioResult
+
+T = TypeVar("T")
 
 
 class SerializableAndPydanticEncoder(json.JSONEncoder):
+    """
+    JSON encoder that handles Pydantic models and iterators.
+
+    This encoder extends the standard JSON encoder to handle Pydantic BaseModel
+    instances and iterator objects, converting them to serializable formats.
+    Used for caching and logging scenarios that contain complex objects.
+
+    Example:
+        ```python
+        data = {
+            "model": SomeBaseModel(field="value"),
+            "iterator": iter([1, 2, 3])
+        }
+        json.dumps(data, cls=SerializableAndPydanticEncoder)
+        ```
+    """
     def default(self, o):
         if isinstance(o, BaseModel):
             return o.model_dump(exclude_unset=True)
+        if isinstance(o, Iterator):
+            return list(o)
         return super().default(o)
 
 
 class SerializableWithStringFallback(SerializableAndPydanticEncoder):
+    """
+    JSON encoder with string fallback for non-serializable objects.
+
+    This encoder extends SerializableAndPydanticEncoder by providing a string
+    fallback for any object that cannot be serialized normally. This ensures
+    that logging and caching operations never fail due to serialization issues.
+
+    Example:
+        ```python
+        # This will work even with complex non-serializable objects
+        data = {"function": lambda x: x, "complex_object": SomeComplexClass()}
+        json.dumps(data, cls=SerializableWithStringFallback)
+        # Result: {"function": "<function <lambda> at 0x...>", "complex_object": "..."}
+        ```
+    """
     def default(self, o):
         try:
             return super().default(o)
@@ -32,6 +88,25 @@ class SerializableWithStringFallback(SerializableAndPydanticEncoder):
 
 
 def safe_list_at(list, index, default=None):
+    """
+    Safely get an item from a list by index with a default fallback.
+
+    Args:
+        list: The list to access
+        index: The index to retrieve
+        default: Value to return if index is out of bounds
+
+    Returns:
+        The item at the index, or the default value if index is invalid
+
+    Example:
+        ```python
+        items = ["a", "b", "c"]
+        print(safe_list_at(items, 1)) # "b"
+        print(safe_list_at(items, 10)) # None
+        print(safe_list_at(items, 10, "default")) # "default"
+        ```
+    """
     try:
         return list[index]
     except:
@@ -39,14 +114,85 @@ def safe_list_at(list, index, default=None):
 
 
 def safe_attr_or_key(obj, attr_or_key, default=None):
+    """
+    Safely get an attribute or dictionary key from an object.
+
+    Tries to get the value as an attribute first, then as a dictionary key,
+    returning the default if neither exists.
+
+    Args:
+        obj: Object to access (can have attributes or be dict-like)
+        attr_or_key: Name of attribute or key to retrieve
+        default: Value to return if attribute/key doesn't exist
+
+    Returns:
+        The attribute/key value, or the default if not found
+
+    Example:
+        ```python
+        class MyClass:
+            attr = "value"
+
+        obj = MyClass()
+        dict_obj = {"key": "value"}
+
+        print(safe_attr_or_key(obj, "attr")) # "value"
+        print(safe_attr_or_key(dict_obj, "key")) # "value"
+        print(safe_attr_or_key(obj, "missing")) # None
+        ```
+    """
     return getattr(obj, attr_or_key, obj.get(attr_or_key))
 
 
 def title_case(string):
+    """
+    Convert snake_case string to Title Case.
+
+    Args:
+        string: Snake_case string to convert
+
+    Returns:
+        String converted to Title Case
+
+    Example:
+        ```python
+        print(title_case("user_simulator_agent")) # "User Simulator Agent"
+        print(title_case("api_key")) # "Api Key"
+        ```
+    """
     return " ".join(word.capitalize() for word in string.split("_"))
 
 
-def print_openai_messages(scenario_name: str, messages: list[ChatCompletionMessageParam]):
+def print_openai_messages(
+    scenario_name: str, messages: list[ChatCompletionMessageParam]
+):
+    """
+    Print OpenAI-format messages with colored formatting for readability.
+
+    This function formats and prints conversation messages with appropriate
+    colors and formatting for different message types (user, assistant, tool calls, etc.).
+    Used for verbose output during scenario execution.
+
+    Args:
+        scenario_name: Name of the scenario (used as prefix)
+        messages: List of OpenAI-compatible messages to print
+
+    Example:
+        ```python
+        messages = [
+            {"role": "user", "content": "Hello"},
+            {"role": "assistant", "content": "Hi there!"},
+            {"role": "assistant", "tool_calls": [{"function": {"name": "search"}}]}
+        ]
+        print_openai_messages("Test Scenario", messages)
+        ```
+
+    Note:
+        - User messages are printed in green
+        - Assistant messages are printed in blue
+        - Tool calls are printed in magenta with formatted JSON
+        - Long JSON content is truncated for readability
+    """
     for msg in messages:
         role = safe_attr_or_key(msg, "role")
         content = safe_attr_or_key(msg, "content")
@@ -61,9 +207,12 @@ def print_openai_messages(scenario_name: str, messages: list[ChatCompletionMessa
                     args = safe_attr_or_key(function, "arguments", "{}")
                     args = _take_maybe_json_first_lines(args)
                     print(
-                        scenario_name + termcolor.colored(f"ToolCall({name}):", "magenta"),
+                        scenario_name
+                        + termcolor.colored(f"ToolCall({name}):", "magenta"),
                         f"\n\n{indent(args, ' ' * 4)}\n",
                     )
+        elif role == "user":
+            print(scenario_name + termcolor.colored("User:", "green"), content)
         elif role == "tool":
             content = _take_maybe_json_first_lines(content or msg.__repr__())
             print(
@@ -78,6 +227,19 @@ def print_openai_messages(scenario_name: str, messages: list[ChatCompletionMessa
 
 
 def _take_maybe_json_first_lines(string, max_lines=5):
+    """
+    Truncate string content and format JSON if possible.
+
+    Internal utility function that attempts to format content as JSON
+    and truncates it to a reasonable number of lines for display.
+
+    Args:
+        string: Content to format and truncate
+        max_lines: Maximum number of lines to show
+
+    Returns:
+        Formatted and potentially truncated string
+    """
     content = str(string)
     try:
         content = json.dumps(json.loads(content), indent=2)
@@ -91,9 +253,25 @@ def _take_maybe_json_first_lines(string, max_lines=5):
 
 console = Console()
 
+
 class TextFirstSpinner(Spinner):
+    """
+    Custom spinner that displays text before the spinning animation.
+
+    This class extends Rich's Spinner to show descriptive text followed
+    by the spinning animation, improving the user experience during
+    scenario execution by clearly indicating what operation is happening.
+
+    Args:
+        name: Name of the spinner animation style
+        text: Descriptive text to show before the spinner
+        color: Color for the descriptive text
+        **kwargs: Additional arguments passed to the base Spinner class
+    """
     def __init__(self, name, text: str, color: str, **kwargs):
-        super().__init__(name, "", style="bold white", **kwargs) # Initialize with empty text
+        super().__init__(
+            name, "", style="bold white", **kwargs
+        ) # Initialize with empty text
         self.text_before = text
         self.color = color
 
@@ -105,7 +283,35 @@ class TextFirstSpinner(Spinner):
 
 
 @contextmanager
-def show_spinner(text: str, color: str = "white", enabled: Optional[Union[bool, int]] = None):
+def show_spinner(
+    text: str, color: str = "white", enabled: Optional[Union[bool, int]] = None
+):
+    """
+    Context manager for displaying a spinner during long-running operations.
+
+    Shows a spinning indicator with descriptive text while code executes
+    within the context. Automatically cleans up the spinner display when
+    the operation completes.
+
+    Args:
+        text: Descriptive text to show next to the spinner
+        color: Color for the descriptive text
+        enabled: Whether to show the spinner (respects verbose settings)
+
+    Example:
+        ```python
+        with show_spinner("Calling agent...", color="blue", enabled=True):
+            response = await agent.call(input_data)
+
+        # Spinner automatically disappears when block completes
+        print("Agent call completed")
+        ```
+
+    Note:
+        - Spinner is automatically cleaned up when context exits
+        - Gracefully handles multi-threading scenarios where multiple spinners might conflict
+        - Cursor positioning ensures clean terminal output
+    """
     if not enabled:
         yield
     else:
@@ -119,3 +325,190 @@ def show_spinner(text: str, color: str = "white", enabled: Optional[Union[bool,
 
         # Cursor up one line
         sys.stdout.write("\033[F")
+        # Erase the line
+        sys.stdout.write("\033[2K")
+
+
+def check_valid_return_type(return_value: Any, class_name: str) -> None:
+    """
+    Validate that an agent's return value is in the expected format.
+
+    This function ensures that agent adapters return values in one of the
+    supported formats (string, OpenAI message, list of messages, or ScenarioResult).
+    It also verifies that the returned data is JSON-serializable for caching.
+
+    Args:
+        return_value: The value returned by an agent's call method
+        class_name: Name of the agent class (for error messages)
+
+    Raises:
+        ValueError: If the return value is not in a supported format
+
+    Example:
+        ```python
+        # Valid return values
+        check_valid_return_type("Hello world", "MyAgent") # OK
+        check_valid_return_type({"role": "assistant", "content": "Hi"}, "MyAgent") # OK
+        check_valid_return_type([{"role": "assistant", "content": "Hi"}], "MyAgent") # OK
+
+        # Invalid return value
+        check_valid_return_type(42, "MyAgent") # Raises ValueError
+        ```
+    """
+    def _is_valid_openai_message(message: Any) -> bool:
+        return (isinstance(message, dict) and "role" in message) or (
+            isinstance(message, BaseModel) and hasattr(message, "role")
+        )
+
+    if (
+        isinstance(return_value, str)
+        or _is_valid_openai_message(return_value)
+        or (
+            isinstance(return_value, list)
+            and all(_is_valid_openai_message(message) for message in return_value)
+        )
+        or isinstance(return_value, ScenarioResult)
+    ):
+        try:
+            json.dumps(return_value, cls=SerializableAndPydanticEncoder)
+        except:
+            raise ValueError(
+                message_return_error_message(got=return_value, class_name=class_name)
+            )
+
+        return
+
+    raise ValueError(
+        message_return_error_message(got=return_value, class_name=class_name)
+    )
+
+
+def convert_agent_return_types_to_openai_messages(
+    agent_response: AgentReturnTypes, role: Literal["user", "assistant"]
+) -> List[ChatCompletionMessageParam]:
+    """
+    Convert various agent return types to standardized OpenAI message format.
+
+    This function normalizes different return types from agent adapters into
+    a consistent list of OpenAI-compatible messages that can be used throughout
+    the scenario execution pipeline.
+
+    Args:
+        agent_response: Response from an agent adapter call
+        role: The role to assign to string responses ("user" or "assistant")
+
+    Returns:
+        List of OpenAI-compatible messages
+
+    Raises:
+        ValueError: If agent_response is a ScenarioResult (which should be handled separately)
+
+    Example:
+        ```python
+        # String response
+        messages = convert_agent_return_types_to_openai_messages("Hello", "assistant")
+        # Result: [{"role": "assistant", "content": "Hello"}]
+
+        # Dict response
+        response = {"role": "assistant", "content": "Hi", "tool_calls": [...]}
+        messages = convert_agent_return_types_to_openai_messages(response, "assistant")
+        # Result: [{"role": "assistant", "content": "Hi", "tool_calls": [...]}]
+
+        # List response
+        responses = [
+            {"role": "assistant", "content": "Thinking..."},
+            {"role": "assistant", "content": "Here's the answer"}
+        ]
+        messages = convert_agent_return_types_to_openai_messages(responses, "assistant")
+        # Result: Same list, validated and normalized
+        ```
+    """
+    if isinstance(agent_response, ScenarioResult):
+        raise ValueError(
+            "Unexpectedly tried to convert a ScenarioResult to openai messages",
+            agent_response.__repr__(),
+        )
+
+    def convert_maybe_object_to_openai_message(
+        obj: Any,
+    ) -> ChatCompletionMessageParam:
+        if isinstance(obj, dict):
+            return cast(ChatCompletionMessageParam, obj)
+        elif isinstance(obj, BaseModel):
+            return cast(
+                ChatCompletionMessageParam,
+                obj.model_dump(
+                    exclude_unset=True,
+                    exclude_none=True,
+                    exclude_defaults=True,
+                    warnings=False,
+                ),
+            )
+        else:
+            raise ValueError(f"Unexpected agent response type: {type(obj).__name__}")
+
+    def ensure_dict(
+        obj: T,
+    ) -> T:
+        return json.loads(json.dumps(obj, cls=SerializableAndPydanticEncoder))
+
+    if isinstance(agent_response, str):
+        return [
+            (
+                {"role": "user", "content": agent_response}
+                if role == "user"
+                else {"role": "assistant", "content": agent_response}
+            )
+        ]
+    elif isinstance(agent_response, list):
+        return [
+            ensure_dict(convert_maybe_object_to_openai_message(message))
+            for message in agent_response
+        ]
+    else:
+        return [ensure_dict(convert_maybe_object_to_openai_message(agent_response))]
+
+
+def reverse_roles(
+    messages: list[ChatCompletionMessageParam],
+) -> list[ChatCompletionMessageParam]:
+    """
+    Reverses the roles of the messages in the list.
+
+    Args:
+        messages: The list of messages to reverse the roles of.
+    """
+
+    reversed_messages = []
+    for message in messages:
+        message = copy.deepcopy(message)
+        # Can't reverse tool calls
+        if not safe_attr_or_key(message, "content") or safe_attr_or_key(
+            message, "tool_calls"
+        ):
+            # If no content nor tool calls, we should skip it entirely, as anthropic may generate some invalid ones e.g. pure {"role": "assistant"}
+            if safe_attr_or_key(message, "tool_calls"):
+                reversed_messages.append(message)
+            continue
+
+        if type(message) == dict:
+            if message["role"] == "user":
+                message["role"] = "assistant"
+            elif message["role"] == "assistant":
+                message["role"] = "user"
+        else:
+            if getattr(message, "role", None) == "user":
+                message.role = "assistant" # type: ignore
+            elif getattr(message, "role", None) == "assistant":
+                message.role = "user" # type: ignore
+
+        reversed_messages.append(message)
+
+    return reversed_messages
+
+
+async def await_if_awaitable(value: T) -> T:
+    if isinstance(value, Awaitable):
+        return await value
+    else:
+        return value
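
The helpers added in 0.4.0 are self-contained, so their intended use can be illustrated directly from the signatures shown in the diff. Below is a minimal sketch of how `convert_agent_return_types_to_openai_messages`, `reverse_roles`, and `await_if_awaitable` might be combined to normalize an agent reply and replay it from the simulated user's side. The `fake_agent_reply` value and the surrounding script are hypothetical, invented for illustration; only the imported functions and their behavior come from the diff above.

```python
# Illustrative sketch only: exercises the utilities added in 0.4.0 as shown in the
# diff above. The input data below is hypothetical and not part of the package.
import asyncio

from scenario.utils import (
    convert_agent_return_types_to_openai_messages,
    reverse_roles,
    await_if_awaitable,
)


async def main():
    # A hypothetical agent reply in one of the accepted return formats (a plain dict).
    fake_agent_reply = {"role": "assistant", "content": "Sure, I can help with that."}

    # Normalize the reply into a list of OpenAI-style message dicts.
    messages = convert_agent_return_types_to_openai_messages(
        fake_agent_reply, role="assistant"
    )

    # Swap user/assistant roles, e.g. to feed the transcript back to a user-simulator agent.
    swapped = reverse_roles(messages)
    print(swapped)  # [{"role": "user", "content": "Sure, I can help with that."}]

    # await_if_awaitable passes plain values through and awaits coroutines.
    value = await await_if_awaitable("already a value")
    print(value)


asyncio.run(main())
```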