langwatch-scenario 0.3.0__py3-none-any.whl → 0.6.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- langwatch_scenario-0.6.0.dist-info/METADATA +385 -0
- langwatch_scenario-0.6.0.dist-info/RECORD +27 -0
- scenario/__init__.py +128 -17
- scenario/{error_messages.py → _error_messages.py} +8 -38
- scenario/_utils/__init__.py +32 -0
- scenario/_utils/ids.py +58 -0
- scenario/_utils/message_conversion.py +103 -0
- scenario/_utils/utils.py +425 -0
- scenario/agent_adapter.py +115 -0
- scenario/cache.py +134 -9
- scenario/config.py +156 -10
- scenario/events/__init__.py +66 -0
- scenario/events/event_bus.py +175 -0
- scenario/events/event_reporter.py +83 -0
- scenario/events/events.py +169 -0
- scenario/events/messages.py +84 -0
- scenario/events/utils.py +86 -0
- scenario/judge_agent.py +414 -0
- scenario/pytest_plugin.py +177 -14
- scenario/scenario_executor.py +630 -154
- scenario/scenario_state.py +205 -0
- scenario/script.py +361 -0
- scenario/types.py +197 -20
- scenario/user_simulator_agent.py +242 -0
- langwatch_scenario-0.3.0.dist-info/METADATA +0 -302
- langwatch_scenario-0.3.0.dist-info/RECORD +0 -16
- scenario/scenario.py +0 -238
- scenario/scenario_agent_adapter.py +0 -16
- scenario/testing_agent.py +0 -279
- scenario/utils.py +0 -264
- {langwatch_scenario-0.3.0.dist-info → langwatch_scenario-0.6.0.dist-info}/WHEEL +0 -0
- {langwatch_scenario-0.3.0.dist-info → langwatch_scenario-0.6.0.dist-info}/entry_points.txt +0 -0
- {langwatch_scenario-0.3.0.dist-info → langwatch_scenario-0.6.0.dist-info}/top_level.txt +0 -0
scenario/testing_agent.py
DELETED
@@ -1,279 +0,0 @@
"""
TestingAgent module: defines the testing agent that interacts with the agent under test.
"""

import json
import logging
import re
from typing import Optional, Type, cast

from litellm import Choices, completion
from litellm.files.main import ModelResponse

from scenario.cache import scenario_cache
from scenario.scenario_agent_adapter import ScenarioAgentAdapter
from scenario.utils import reverse_roles

from .error_messages import testing_agent_not_configured_error_message
from .types import AgentInput, AgentReturnTypes, ScenarioAgentRole, ScenarioResult


logger = logging.getLogger("scenario")


class TestingAgent(ScenarioAgentAdapter):
    """
    The Testing Agent that interacts with the agent under test.

    This agent is responsible for:
    1. Generating messages to send to the agent based on the scenario
    2. Evaluating the responses from the agent against the success/failure criteria
    3. Determining when to end the test and return a result
    """

    roles = {ScenarioAgentRole.USER, ScenarioAgentRole.JUDGE}

    model: str = ""
    api_key: Optional[str] = None
    temperature: float = 0.0
    max_tokens: Optional[int] = None

    # To prevent pytest from thinking this is actually a test class
    __test__ = False

    def __init__(self, input: AgentInput):
        super().__init__(input)

        if not self.model:
            raise Exception(testing_agent_not_configured_error_message)

    @classmethod
    def with_config(
        cls,
        model: str,
        api_key: Optional[str] = None,
        temperature: float = 0.0,
        max_tokens: Optional[int] = None,
    ) -> Type["TestingAgent"]:
        class TestingAgentWithConfig(cls):
            def __init__(self, input: AgentInput):
                self.model = model
                self.api_key = api_key
                self.temperature = temperature
                self.max_tokens = max_tokens

                super().__init__(input)

        return TestingAgentWithConfig

    @scenario_cache(ignore=["scenario"])
    async def call(
        self,
        input: AgentInput,
    ) -> AgentReturnTypes:
        """
        Generate the next message in the conversation based on history OR
        return a ScenarioResult if the test should conclude.

        Returns either:
        - A string message to send to the agent (if conversation should continue)
        - A ScenarioResult (if the test should conclude)
        """

        scenario = input.scenario_state.scenario

        messages = [
            {
                "role": "system",
                "content": f"""
<role>
You are pretending to be a user, you are testing an AI Agent (shown as the user role) based on a scenario.
Approach this naturally, as a human user would, with very short inputs, few words, all lowercase, imperative, not periods, like when they google or talk to chatgpt.
</role>

<goal>
Your goal (assistant) is to interact with the Agent Under Test (user) as if you were a human user to see if it can complete the scenario successfully.
</goal>

<scenario>
{scenario.description}
</scenario>

<criteria>
{"\n".join([f"{idx + 1}. {criterion}" for idx, criterion in enumerate(scenario.criteria)])}
</criteria>

<execution_flow>
1. Generate the first message to start the scenario
2. After the Agent Under Test (user) responds, generate the next message to send to the Agent Under Test, keep repeating step 2 until criterias match
3. If the test should end, use the finish_test tool to determine if all the criteria have been met
</execution_flow>

<rules>
1. Test should end immediately if a criteria mentioning something the agent should NOT do is met
2. Test should continue until all scenario goals have been met to try going through all the criteria
3. DO NOT make any judgment calls that are not explicitly listed in the success or failure criteria, withhold judgement if necessary
4. DO NOT carry over any requests yourself, YOU ARE NOT the assistant today, wait for the user to do it
</rules>
""",
            },
            {"role": "assistant", "content": "Hello, how can I help you today?"},
            *input.messages,
        ]

        is_first_message = len(input.messages) == 0
        is_last_message = (
            input.scenario_state.current_turn == input.scenario_state.scenario.max_turns
        )

        if is_last_message:
            messages.append(
                {
                    "role": "user",
                    "content": """
System:

<finish_test>
This is the last message, conversation has reached the maximum number of turns, give your final verdict,
if you don't have enough information to make a verdict, say inconclusive with max turns reached.
</finish_test>
""",
                }
            )

        # User to assistant role reversal
        # LLM models are biased to always be the assistant not the user, so we need to do this reversal otherwise models like GPT 4.5 is
        # super confused, and Claude 3.7 even starts throwing exceptions.
        messages = reverse_roles(messages)

        # Define the tool
        criteria_names = [
            re.sub(
                r"[^a-zA-Z0-9]",
                "_",
                criterion.replace(" ", "_").replace("'", "").lower(),
            )[:70]
            for criterion in scenario.criteria
        ]
        tools = [
            {
                "type": "function",
                "function": {
                    "name": "finish_test",
                    "description": "Complete the test with a final verdict",
                    "strict": True,
                    "parameters": {
                        "type": "object",
                        "properties": {
                            "criteria": {
                                "type": "object",
                                "properties": {
                                    criteria_names[idx]: {
                                        "enum": [True, False, "inconclusive"],
                                        "description": criterion,
                                    }
                                    for idx, criterion in enumerate(scenario.criteria)
                                },
                                "required": criteria_names,
                                "additionalProperties": False,
                                "description": "Strict verdict for each criterion",
                            },
                            "reasoning": {
                                "type": "string",
                                "description": "Explanation of what the final verdict should be",
                            },
                            "verdict": {
                                "type": "string",
                                "enum": ["success", "failure", "inconclusive"],
                                "description": "The final verdict of the test",
                            },
                        },
                        "required": ["criteria", "reasoning", "verdict"],
                        "additionalProperties": False,
                    },
                },
            }
        ]

        enforce_judgment = input.requested_role == ScenarioAgentRole.JUDGE
        has_criteria = len(scenario.criteria) > 0

        if enforce_judgment and not has_criteria:
            return ScenarioResult(
                success=False,
                messages=[],
                reasoning="TestingAgent was called as a judge, but it has no criteria to judge against",
            )

        response = cast(
            ModelResponse,
            completion(
                model=self.model,
                messages=messages,
                temperature=self.temperature,
                max_tokens=self.max_tokens,
                tools=(
                    tools
                    if (not is_first_message or enforce_judgment) and has_criteria
                    else None
                ),
                tool_choice=(
                    "required"
                    if (is_last_message or enforce_judgment) and has_criteria
                    else None
                ),
            ),
        )

        # Extract the content from the response
        if hasattr(response, "choices") and len(response.choices) > 0:
            message = cast(Choices, response.choices[0]).message

            # Check if the LLM chose to use the tool
            if message.tool_calls:
                tool_call = message.tool_calls[0]
                if tool_call.function.name == "finish_test":
                    # Parse the tool call arguments
                    try:
                        args = json.loads(tool_call.function.arguments)
                        verdict = args.get("verdict", "inconclusive")
                        reasoning = args.get("reasoning", "No reasoning provided")
                        criteria = args.get("criteria", {})

                        passed_criteria = [
                            scenario.criteria[idx]
                            for idx, criterion in enumerate(criteria.values())
                            if criterion == True
                        ]
                        failed_criteria = [
                            scenario.criteria[idx]
                            for idx, criterion in enumerate(criteria.values())
                            if criterion == False
                        ]

                        # Return the appropriate ScenarioResult based on the verdict
                        return ScenarioResult(
                            success=verdict == "success",
                            messages=messages,
                            reasoning=reasoning,
                            passed_criteria=passed_criteria,
                            failed_criteria=failed_criteria,
                        )
                    except json.JSONDecodeError:
                        logger.error("Failed to parse tool call arguments")

            # If no tool call use the message content as next message
            message_content = message.content
            if message_content is None:
                # If invalid tool call, raise an error
                if message.tool_calls:
                    raise Exception(
                        f"Invalid tool call from testing agent: {message.tool_calls.__repr__()}"
                    )
                raise Exception(f"No response from LLM: {response.__repr__()}")

            return {"role": "user", "content": message_content}
        else:
            raise Exception(
                f"Unexpected response format from LLM: {response.__repr__()}"
            )
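For reference, the finish_test tool above constrains the judge's final answer to a criteria map (keyed by the sanitized criterion names), a reasoning string, and a success/failure/inconclusive verdict. The snippet below is a minimal, illustrative sketch of such a payload and of how the deleted code roughly splits it into passed and failed criteria; the criterion names and values are made up and are not part of the package.

import json

# Hypothetical finish_test arguments matching the JSON schema above
arguments = json.dumps(
    {
        "criteria": {
            "agent_replies_in_lowercase": True,
            "agent_does_not_give_legal_advice": False,
        },
        "reasoning": "the agent answered casually but gave legal advice on turn 2",
        "verdict": "failure",
    }
)

args = json.loads(arguments)
passed = [name for name, value in args["criteria"].items() if value is True]
failed = [name for name, value in args["criteria"].items() if value is False]
print(args["verdict"], passed, failed)
# failure ['agent_replies_in_lowercase'] ['agent_does_not_give_legal_advice']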
scenario/utils.py
DELETED
@@ -1,264 +0,0 @@
from contextlib import contextmanager
import sys
from typing import (
    Any,
    Iterator,
    List,
    Literal,
    Optional,
    Union,
    TypeVar,
    Awaitable,
    cast,
)
from pydantic import BaseModel

import json

import termcolor
from textwrap import indent
from openai.types.chat import ChatCompletionMessageParam
from rich.live import Live
from rich.spinner import Spinner
from rich.console import Console
from rich.text import Text
from rich.errors import LiveError

from scenario.error_messages import message_return_error_message
from scenario.types import AgentReturnTypes, ScenarioResult

T = TypeVar("T")


class SerializableAndPydanticEncoder(json.JSONEncoder):
    def default(self, o):
        if isinstance(o, BaseModel):
            return o.model_dump(exclude_unset=True)
        if isinstance(o, Iterator):
            return list(o)
        return super().default(o)


class SerializableWithStringFallback(SerializableAndPydanticEncoder):
    def default(self, o):
        try:
            return super().default(o)
        except:
            return str(o)


def safe_list_at(list, index, default=None):
    try:
        return list[index]
    except:
        return default


def safe_attr_or_key(obj, attr_or_key, default=None):
    return getattr(obj, attr_or_key, obj.get(attr_or_key))


def title_case(string):
    return " ".join(word.capitalize() for word in string.split("_"))


def print_openai_messages(
    scenario_name: str, messages: list[ChatCompletionMessageParam]
):
    for msg in messages:
        role = safe_attr_or_key(msg, "role")
        content = safe_attr_or_key(msg, "content")
        if role == "assistant":
            tool_calls = safe_attr_or_key(msg, "tool_calls")
            if content:
                print(scenario_name + termcolor.colored("Agent:", "blue"), content)
            if tool_calls:
                for tool_call in tool_calls:
                    function = safe_attr_or_key(tool_call, "function")
                    name = safe_attr_or_key(function, "name")
                    args = safe_attr_or_key(function, "arguments", "{}")
                    args = _take_maybe_json_first_lines(args)
                    print(
                        scenario_name
                        + termcolor.colored(f"ToolCall({name}):", "magenta"),
                        f"\n\n{indent(args, ' ' * 4)}\n",
                    )
        elif role == "user":
            print(scenario_name + termcolor.colored("User:", "green"), content)
        elif role == "tool":
            content = _take_maybe_json_first_lines(content or msg.__repr__())
            print(
                scenario_name + termcolor.colored(f"ToolResult:", "magenta"),
                f"\n\n{indent(content, ' ' * 4)}\n",
            )
        else:
            print(
                scenario_name + termcolor.colored(f"{title_case(role)}:", "magenta"),
                msg.__repr__(),
            )


def _take_maybe_json_first_lines(string, max_lines=5):
    content = str(string)
    try:
        content = json.dumps(json.loads(content), indent=2)
    except:
        pass
    content = content.split("\n")
    if len(content) > max_lines:
        content = content[:max_lines] + ["..."]
    return "\n".join(content)


console = Console()


class TextFirstSpinner(Spinner):
    def __init__(self, name, text: str, color: str, **kwargs):
        super().__init__(
            name, "", style="bold white", **kwargs
        )  # Initialize with empty text
        self.text_before = text
        self.color = color

    def render(self, time):
        # Get the original spinner frame
        spinner_frame = super().render(time)
        # Create a composite with text first, then spinner
        return Text(f"{self.text_before} ", style=self.color) + spinner_frame


@contextmanager
def show_spinner(
    text: str, color: str = "white", enabled: Optional[Union[bool, int]] = None
):
    if not enabled:
        yield
    else:
        spinner = TextFirstSpinner("dots", text, color=color)
        try:
            with Live(spinner, console=console, refresh_per_second=20):
                yield
        # It happens when we are multi-threading, it's fine, just ignore it, you probably don't want multiple spinners at once anyway
        except LiveError:
            yield

        # Cursor up one line
        sys.stdout.write("\033[F")
        # Erase the line
        sys.stdout.write("\033[2K")


def check_valid_return_type(return_value: Any, class_name: str) -> None:
    def _is_valid_openai_message(message: Any) -> bool:
        return (isinstance(message, dict) and "role" in message) or (
            isinstance(message, BaseModel) and hasattr(message, "role")
        )

    if (
        isinstance(return_value, str)
        or _is_valid_openai_message(return_value)
        or (
            isinstance(return_value, list)
            and all(_is_valid_openai_message(message) for message in return_value)
        )
        or isinstance(return_value, ScenarioResult)
    ):
        try:
            json.dumps(return_value, cls=SerializableAndPydanticEncoder)
        except:
            raise ValueError(
                message_return_error_message(got=return_value, class_name=class_name)
            )

        return

    raise ValueError(
        message_return_error_message(got=return_value, class_name=class_name)
    )


def convert_agent_return_types_to_openai_messages(
    agent_response: AgentReturnTypes, role: Literal["user", "assistant"]
) -> List[ChatCompletionMessageParam]:
    if isinstance(agent_response, ScenarioResult):
        raise ValueError(
            "Unexpectedly tried to convert a ScenarioResult to openai messages",
            agent_response.__repr__(),
        )

    def convert_maybe_object_to_openai_message(
        obj: Any,
    ) -> ChatCompletionMessageParam:
        if isinstance(obj, dict):
            return cast(ChatCompletionMessageParam, obj)
        elif isinstance(obj, BaseModel):
            return cast(
                ChatCompletionMessageParam,
                obj.model_dump(
                    exclude_unset=True,
                    exclude_none=True,
                    exclude_defaults=True,
                ),
            )
        else:
            raise ValueError(f"Unexpected agent response type: {type(obj).__name__}")

    def ensure_dict(
        obj: T,
    ) -> T:
        return json.loads(json.dumps(obj, cls=SerializableAndPydanticEncoder))

    if isinstance(agent_response, str):
        return [
            (
                {"role": "user", "content": agent_response}
                if role == "user"
                else {"role": "assistant", "content": agent_response}
            )
        ]
    elif isinstance(agent_response, list):
        return [
            ensure_dict(convert_maybe_object_to_openai_message(message))
            for message in agent_response
        ]
    else:
        return [ensure_dict(convert_maybe_object_to_openai_message(agent_response))]


def reverse_roles(
    messages: list[ChatCompletionMessageParam],
) -> list[ChatCompletionMessageParam]:
    """
    Reverses the roles of the messages in the list.

    Args:
        messages: The list of messages to reverse the roles of.
    """

    for message in messages.copy():
        # Can't reverse tool calls
        if not safe_attr_or_key(message, "content") or safe_attr_or_key(
            message, "tool_calls"
        ):
            continue

        if type(message) == dict:
            if message["role"] == "user":
                message["role"] = "assistant"
            elif message["role"] == "assistant":
                message["role"] = "user"
        else:
            if getattr(message, "role", None) == "user":
                message.role = "assistant"  # type: ignore
            elif getattr(message, "role", None) == "assistant":
                message.role = "user"  # type: ignore

    return messages


async def await_if_awaitable(value: T) -> T:
    if isinstance(value, Awaitable):
        return await value
    else:
        return value
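As a quick illustration of the two most-used helpers in the module above, the sketch below assumes the 0.3.0 package is installed (so scenario.utils still exists) and uses made-up message values; it is not part of the diff itself.

from scenario.utils import (
    convert_agent_return_types_to_openai_messages,
    reverse_roles,
)

# A plain string response is wrapped into a single OpenAI-style message
print(convert_agent_return_types_to_openai_messages("hello!", role="assistant"))
# [{'role': 'assistant', 'content': 'hello!'}]

# reverse_roles swaps user/assistant on plain text messages, but skips
# messages that carry tool calls or have no content
messages = [
    {"role": "user", "content": "book me a flight"},
    {"role": "assistant", "content": "where to?"},
]
print(reverse_roles(messages))
# [{'role': 'assistant', 'content': 'book me a flight'}, {'role': 'user', 'content': 'where to?'}]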
{langwatch_scenario-0.3.0.dist-info → langwatch_scenario-0.6.0.dist-info}/WHEEL
File without changes
{langwatch_scenario-0.3.0.dist-info → langwatch_scenario-0.6.0.dist-info}/entry_points.txt
File without changes
{langwatch_scenario-0.3.0.dist-info → langwatch_scenario-0.6.0.dist-info}/top_level.txt
File without changes