speedy-utils 1.0.3__py3-none-any.whl → 1.0.5__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in that registry.
- llm_utils/__init__.py +29 -0
- llm_utils/chat_format.py +427 -0
- llm_utils/group_messages.py +120 -0
- llm_utils/lm/__init__.py +8 -0
- llm_utils/lm/base_lm.py +304 -0
- llm_utils/lm/utils.py +130 -0
- llm_utils/scripts/vllm_load_balancer.py +353 -0
- llm_utils/scripts/vllm_serve.py +416 -0
- speedy_utils/__init__.py +85 -0
- speedy_utils/all.py +159 -0
- {speedy → speedy_utils}/common/__init__.py +0 -0
- speedy_utils/common/clock.py +215 -0
- speedy_utils/common/function_decorator.py +66 -0
- speedy_utils/common/logger.py +207 -0
- speedy_utils/common/report_manager.py +112 -0
- speedy_utils/common/utils_cache.py +264 -0
- {speedy → speedy_utils}/common/utils_io.py +66 -19
- {speedy → speedy_utils}/common/utils_misc.py +25 -11
- speedy_utils/common/utils_print.py +216 -0
- speedy_utils/multi_worker/__init__.py +0 -0
- speedy_utils/multi_worker/process.py +198 -0
- speedy_utils/multi_worker/thread.py +327 -0
- speedy_utils/scripts/mpython.py +108 -0
- speedy_utils-1.0.5.dist-info/METADATA +279 -0
- speedy_utils-1.0.5.dist-info/RECORD +27 -0
- {speedy_utils-1.0.3.dist-info → speedy_utils-1.0.5.dist-info}/WHEEL +1 -2
- speedy_utils-1.0.5.dist-info/entry_points.txt +3 -0
- speedy/__init__.py +0 -53
- speedy/common/clock.py +0 -68
- speedy/common/utils_cache.py +0 -170
- speedy/common/utils_print.py +0 -138
- speedy/multi_worker.py +0 -121
- speedy_utils-1.0.3.dist-info/METADATA +0 -22
- speedy_utils-1.0.3.dist-info/RECORD +0 -12
- speedy_utils-1.0.3.dist-info/top_level.txt +0 -1
llm_utils/__init__.py
ADDED
@@ -0,0 +1,29 @@
+from .chat_format import (
+    transform_messages,
+    transform_messages_to_chatml,
+    display_chat_messages_as_html,
+    get_conversation_one_turn,
+    display_diff_two_string,
+    display_conversations,
+    build_chatml_input,
+    format_msgs,
+)
+from .lm import LM
+from .group_messages import (
+    split_indices_by_length,
+    group_messages_by_len,
+)
+
+__all__ = [
+    "transform_messages",
+    "transform_messages_to_chatml",
+    "display_chat_messages_as_html",
+    "get_conversation_one_turn",
+    "display_diff_two_string",
+    "display_conversations",
+    "build_chatml_input",
+    "format_msgs",
+    "split_indices_by_length",
+    "group_messages_by_len",
+    "LM",
+]
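
For orientation, a minimal usage sketch of the re-exported API (illustrative only, not part of the diff; assumes the 1.0.5 wheel and its IPython/loguru dependencies are installed):

    from llm_utils import get_conversation_one_turn

    # Build a one-turn conversation in chatml form (a list of role/content dicts).
    msgs = get_conversation_one_turn(
        system_msg="You are a helpful assistant.",
        user_msg="Hello",
    )
    # msgs == [{'role': 'system', 'content': 'You are a helpful assistant.'},
    #          {'role': 'user', 'content': 'Hello'}]
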
llm_utils/chat_format.py
ADDED
@@ -0,0 +1,427 @@
+from copy import deepcopy
+from typing import Dict, List, Literal, Union
+
+from IPython.display import HTML, Markdown, display
+from loguru import logger
+
+
+def identify_format(item):
+    if isinstance(item, list) and "role" in item[0]:
+        return "chatml"
+    if isinstance(item, dict):
+        if "conversations" in item:
+            return "sharegpt"
+    raise ValueError(
+        f"The format of the item is not recognized. \n{type(item)=}, \n{item=}"
+    )
+
+
+def _transform_sharegpt_to_chatml(
+    item, default_system_message="You are a helpful assistant.", print_msg=False
+):
+    # if isinstance(item, list):
+    #     return [_transform_sharegpt_to_chatml(item) for item in item]
+    assert isinstance(
+        item, dict
+    ), "The item is not in the correct format. Please check the format of the item."
+
+    messages = []
+    system_msg = item.get("system", "")
+    if system_msg:
+        messages.append({"role": "system", "content": system_msg})
+    elif default_system_message:
+        messages.append({"role": "system", "content": default_system_message})
+    conversations = item.get("conversations", [])
+    if hasattr(conversations, "tolist"):
+        conversations = conversations.tolist()
+    # import ipdb; ipdb.set_trace()
+    assert conversations, "The item does not have any conversations."
+    for conversation in item.get("conversations", []):
+        role = conversation["from"]
+        content = conversation["value"]
+        messages.append({"role": role, "content": content})
+
+    return messages
+
+
+def transform_messages(
+    item,
+    frm="chatml",
+    to="text",
+    add_generation_prompt=True,
+    tokenizer=None,
+    assistant_prefix=None,
+):
+    assert to in [
+        "chatml",
+        "text",
+        "sharegpt",
+        "simulated_chat",
+    ], "The output format is not recognized. Please specify the output format."
+    item = deepcopy(item)
+
+    if tokenizer is not None:
+        assert frm == "chatml", "Tokenizer is only supported for chatml format."
+        prompt = tokenizer.apply_chat_template(
+            item, tokenize=False, add_generation_prompt=True
+        )
+        assert isinstance(prompt, str), "Prompt must be a string."
+        if assistant_prefix:
+            prompt += f"{assistant_prefix}"
+        return prompt
+
+    if frm != to:
+        # convert item to chatml format
+        chatml_messages = transform_messages_to_chatml(item, input_format=frm)
+        if to == "sharegpt":
+            if chatml_messages[0]["role"] == "system":
+                system_message = chatml_messages[0]["content"]
+                ret = {"conversations": [], "system": system_message.strip()}
+                for message in chatml_messages[1:]:
+                    ret["conversations"].append(
+                        {"from": message["role"], "value": message["content"]}
+                    )
+            else:
+                system_message = "You are a helpful assistant."
+                ret = {"conversations": [], "system": system_message.strip()}
+                for message in chatml_messages:
+                    ret["conversations"].append(
+                        {"from": message["role"], "value": message["content"]}
+                    )
+            return ret
+        elif to == "chatml":
+            return _transform_sharegpt_to_chatml(item)
+        elif to == "text":
+            text = ""
+            for turn in chatml_messages:
+                text += f"<|im_start|>{turn['role']}\n{turn['content']}<|im_end|>\n"
+            if add_generation_prompt:
+                text += "<|im_start|>assistant\n"
+            return text
+        elif to == "simulated_chat":
+            text = "<role> Given the simulated chat, you are the assistant. Lets continue the conversation. \n\n"
+            for turn in chatml_messages:
+                prefix = {
+                    "user": "Human",
+                    "assistant": "AI",
+                    "system": "System",
+                    "function": "Function",
+                }.get(turn["role"])
+                text += f"{prefix}: {turn['content'].strip()}\n\n"
+            if add_generation_prompt:
+                text += "AI: [continue the conversation here]"
+            return text
+        else:
+            raise ValueError(f"{to} is not supported.")
+
+    else:
+        return item
+
+
+def transform_messages_to_chatml(input_data, input_format="auto"):
+    if input_format == "auto":
+        input_data = raw_data = deepcopy(input_data)
+        if isinstance(input_data, list):
+            input_format = "chatlm"
+            assert (
+                input_data[0].get("role") is not None
+            ), "The input format is not recognized. Please specify the input format."
+        elif isinstance(input_data, dict):
+            input_data = _transform_sharegpt_to_chatml(input_data)
+            input_format = "sharegpt"
+        elif isinstance(input_data, str):
+            # assume it has format <|im_start|>role\n content<|im_end|> use regex to parse
+            assert (
+                "<|im_end|>" in input_data
+            ), "The input format is not recognized. Please specify the input format."
+            input_format = "chatlm"
+            parts = input_data.split("<|im_end|>")
+            # for each part, split by <|im_start|> to get role and content
+            input_data = []
+            for part in parts:
+                if not part.strip():
+                    continue
+                role = part.split("<|im_start|>")[1].split("\n")[0]
+                # content is after |>role\n
+                content = part.split(f"<|im_start|>{role}\n")[1]
+                content = content.split("<|im_end|>")[0]
+                input_data.append({"role": role.strip(), "content": content.strip()})
+
+    return input_data
+
+
+def display_chat_messages_as_html(
+    msgs, return_html=False, file="/tmp/conversation.html", theme="light"
+):
+    if isinstance(msgs, dict) and "messages" in msgs:
+        msgs = msgs["messages"]
+
+    # ensure is a list of dict each dict has role and content
+    assert isinstance(msgs, list) and all(
+        isinstance(msg, dict) and "role" in msg and "content" in msg for msg in msgs
+    ), "The input format is not recognized. Please specify the input format."
+
+    color_scheme = {
+        "system": {
+            "background": "#FFAAAA",
+            "text": "#000000",
+        },  # Light red background, black text
+        "user": {
+            "background": "#AAFFAA",
+            "text": "#000000",
+        },  # Light green background, black text
+        "assistant": {
+            "background": "#AAAAFF",
+            "text": "#000000",
+        },  # Light blue background, black text
+        "function": {
+            "background": "#AFFFFF",
+            "text": "#000000",
+        },  # Light yellow background, black text
+        "default": {
+            "background": "#FFFFFF",
+            "text": "#000000",
+        },  # White background, black text
+        "tool": {"background": "#FFAAFF", "text": "#000000"},
+    }
+
+    conversation_html = ""
+    for i, message in enumerate(msgs):
+        role = message["role"]
+        content = message.get("content", "")
+        if not content:
+            content = ""
+
+        tool_calls = message.get("tool_calls")
+        if not content and tool_calls:
+            # each tool call comes with name, and args
+
+            for tool_call in tool_calls:
+                tool_call = tool_call["function"]
+                name = tool_call["name"]
+                args = tool_call["arguments"]
+                content += "Tool: " + name + "\n" + "Arguments: " + str(args)
+
+        # Replace newlines with <br> tags
+        content = content.replace("\n", "<br>")
+
+        # Replace tabs with entities
+        content = content.replace("\t", "&nbsp;&nbsp;&nbsp;&nbsp;")
+
+        # Replace multiple consecutive spaces with entities
+        content = content.replace("  ", "&nbsp;&nbsp;")
+        # keep html tag without escaping
+        # content = content.replace('<', '&lt;')
+
+        content = (
+            content.replace("<br>", "TEMP_BR")
+            .replace("<", "&lt;")
+            .replace(">", "&gt;")
+            .replace("TEMP_BR", "<br>")
+        )
+
+        if role in color_scheme:
+            background_color = color_scheme[role]["background"]
+            text_color = color_scheme[role]["text"]
+        else:
+            background_color = color_scheme["default"]["background"]
+            text_color = color_scheme["default"]["text"]
+
+        if role == "system":
+            conversation_html += f'<div style="background-color: {background_color}; color: {text_color}; padding: 10px; margin-bottom: 10px;"><strong>System:</strong><br><pre id="system-{i}">{content}</pre></div>'
+        elif role == "user":
+            conversation_html += f'<div style="background-color: {background_color}; color: {text_color}; padding: 10px; margin-bottom: 10px;"><strong>User:</strong><br><pre id="user-{i}">{content}</pre></div>'
+        elif role == "assistant":
+            conversation_html += f'<div style="background-color: {background_color}; color: {text_color}; padding: 10px; margin-bottom: 10px;"><strong>Assistant:</strong><br><pre id="assistant-{i}">{content}</pre></div>'
+        elif role == "function":
+            conversation_html += f'<div style="background-color: {background_color}; color: {text_color}; padding: 10px; margin-bottom: 10px;"><strong>Function:</strong><br><pre id="function-{i}">{content}</pre></div>'
+        else:
+            # logger.warning(f"Unknown role: {role}")
+            conversation_html += f'<div style="background-color: {background_color}; color: {text_color}; padding: 10px; margin-bottom: 10px;"><strong>{role}:</strong><br><pre id="{role}-{i}">{content}</pre><br><button onclick="copyContent(\'{role}-{i}\')">Copy</button></div>'
+
+    html = f"""
+    <html>
+    <head>
+    <style>
+    pre {{
+        white-space: pre-wrap;
+    }}
+    </style>
+    </head>
+    <body>
+    {conversation_html}
+    <script>
+    function copyContent(elementId) {{
+        var element = document.getElementById(elementId);
+        var text = element.innerText;
+        navigator.clipboard.writeText(text)
+            .then(function() {{
+                alert("Content copied to clipboard!");
+            }})
+            .catch(function(error) {{
+                console.error("Error copying content: ", error);
+            }});
+    }}
+    </script>
+    </body>
+    </html>
+    """
+
+    if file:
+        with open(file, "w") as f:
+            f.write(html)
+    if return_html:
+        return html
+    else:
+        display(HTML(html))
+
+
+def get_conversation_one_turn(
+    system_msg=None,
+    user_msg=None,
+    assistant_msg=None,
+    assistant_prefix=None,
+    return_format="chatml",
+):
+    messages = []
+    if system_msg:
+        messages.append({"role": "system", "content": system_msg})
+    messages.append({"role": "user", "content": user_msg})
+    if assistant_msg:
+        messages.append({"role": "assistant", "content": assistant_msg})
+    if assistant_prefix is not None:
+        assert (
+            return_format != "chatml"
+        ), "Change return_format to 'text' if you want to use assistant_prefix"
+        assert messages[-1]["role"] == "user"
+        msg = transform_messages(messages, "chatml", "text", add_generation_prompt=True)
+        msg += assistant_prefix
+        return msg
+    else:
+        assert return_format in ["chatml"]
+        return messages
+
+
+from difflib import ndiff
+
+from IPython.display import HTML
+
+
+def display_diff_two_string(text1, text2):
+    # Split the texts into lines
+    lines1 = text1.splitlines()
+    lines2 = text2.splitlines()
+
+    # Perform the diff
+    diff = list(ndiff(lines1, lines2))
+
+    # Create the HTML table
+    table_rows = []
+    for line in diff:
+        if line.startswith("- "):
+            table_rows.append(
+                f'<tr><td style="background-color: #FFCCCB;">{line[2:]}</td><td></td></tr>'
+            )
+        elif line.startswith("+ "):
+            table_rows.append(
+                f'<tr><td></td><td style="background-color: #CCFFCC;">{line[2:]}</td></tr>'
+            )
+        elif line.startswith("? "):
+            continue
+        else:
+            table_rows.append(f"<tr><td>{line}</td><td>{line}</td></tr>")
+
+    table_html = '<table style="width: 100%; border-collapse: collapse;">'
+    table_html += '<tr><th style="width: 50%; text-align: left;">Text 1</th><th style="width: 50%; text-align: left;">Text 2</th></tr>'
+    table_html += "".join(table_rows)
+    table_html += "</table>"
+
+    # Display the HTML table
+    display(HTML(table_html))
+
+
+def display_conversations(data1, data2, theme="light"):
+    html1 = display_chat_messages_as_html(data1, return_html=True, theme=theme)
+    html2 = display_chat_messages_as_html(data2, return_html=True, theme=theme)
+
+    html = f"""
+    <html>
+    <head>
+    <style>
+    table {{
+        width: 100%;
+        border-collapse: collapse;
+    }}
+    td {{
+        width: 50%;
+        vertical-align: top;
+        padding: 10px;
+    }}
+    </style>
+    </head>
+    <body>
+    <table>
+    <tr>
+    <td>{html1}</td>
+    <td>{html2}</td>
+    </tr>
+    </table>
+    </body>
+    </html>
+    """
+    display(HTML(html))
+
+
+from typing import Callable, Dict, List
+
+
+def build_chatml_input(template: str, params: List[str]) -> Callable:
+    def formator(**kwargs) -> List[List[Dict[str, str]]]:
+        system_msg = kwargs.get("system_msg", None)
+        # remove system
+        kwargs.pop("system_msg", None)
+        # Ensure all required parameters are present in kwargs
+        for param in params:
+            if param not in kwargs:
+                raise ValueError(f"Missing parameter: {param}")
+
+        # Use the **kwargs directly in the format method
+        content = template.format(**kwargs)
+        msgs = []
+        if system_msg:
+            msgs += [{"role": "system", "content": system_msg}]
+        msgs += [{"role": "user", "content": content}]
+        return msgs
+
+    return formator
+
+
+def _color_text(text, color_code):
+    """Helper function to color text based on the provided ANSI color code."""
+    return f"\033[{color_code}m{text}\033[0m"
+
+
+def format_msgs(messages):
+    """Formats the role and content of a list of messages into a string."""
+    messages = transform_messages_to_chatml(messages)
+    output = []
+
+    for msg in messages:
+        role = msg.get("role", "unknown").lower()
+        content = msg.get("content", "").strip()
+        output.append(f"{role.capitalize()}:\t{content}")
+        output.append("---")
+
+    return "\n".join(output)
+
+
+__all__ = [
+    "transform_messages",
+    "transform_messages_to_chatml",
+    "display_chat_messages_as_html",
+    "get_conversation_one_turn",
+    "display_diff_two_string",
+    "display_conversations",
+    "build_chatml_input",
+    "format_msgs",
+]
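
Based on the file above, a short sketch of the chatml-to-text conversion path (illustrative only, not part of the diff):

    from llm_utils.chat_format import transform_messages

    chat = [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Hello"},
    ]
    # Renders the turns with <|im_start|>/<|im_end|> delimiters; because
    # add_generation_prompt defaults to True, an open assistant turn is appended.
    print(transform_messages(chat, frm="chatml", to="text"))
    # <|im_start|>system
    # You are a helpful assistant.<|im_end|>
    # <|im_start|>user
    # Hello<|im_end|>
    # <|im_start|>assistant
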
llm_utils/group_messages.py
ADDED

@@ -0,0 +1,120 @@
+from __future__ import annotations
+
+import random
+from typing import Sequence, Optional
+
+import numpy as np
+import pandas as pd
+from tabulate import tabulate
+
+from speedy_utils import multi_thread
+
+
+def split_indices_by_length(
+    lengths: Sequence[int],
+    batch_size_by_mean_length: int,
+    random_seed: int,
+    verbose: bool,
+    shuffle: bool,
+    mean_length: Optional[int] = None,
+) -> list[list[int]]:
+    """
+    Split indices into batches so that the sum of lengths in each batch does not exceed max_batch_length.
+    """
+    if mean_length is None:
+        mean_length = int(np.mean(lengths))
+    max_batch_length: int = mean_length * batch_size_by_mean_length
+
+    r: random.Random = random.Random(random_seed)
+    indices: list[int] = list(range(len(lengths)))
+
+    if shuffle:
+        r.shuffle(indices)
+
+    batches: list[list[int]] = []
+    current_batch: list[int] = []
+    current_batch_length: int = 0
+
+    for idx in indices:
+        length: int = lengths[idx]
+        if current_batch_length + length <= max_batch_length:
+            current_batch.append(idx)
+            current_batch_length += length
+        else:
+            batches.append(current_batch)
+            current_batch = [idx]
+            current_batch_length = length
+
+    if current_batch:
+        batches.append(current_batch)
+
+    if verbose:
+        batch_lengths: list[int] = [
+            sum(lengths[idx] for idx in batch) for batch in batches
+        ]
+        desc = pd.Series(batch_lengths).describe()
+
+        table = [
+            ["New avg item len", desc["mean"]],
+            ["Number groups", len(batches)],
+            ["Max length", max_batch_length],
+        ]
+
+        print(tabulate(table, headers=["Metric", "Value"], tablefmt="pretty"))
+
+    return batches
+
+
+def group_messages_by_len(
+    messages: Sequence[dict],
+    model_name: str = "Qwen/Qwen2.5-7B-Instruct",
+    batch_size: int = 4,
+    mean_length: int = 512,
+) -> list[dict]:
+    """
+    Groups messages into batches based on token length and concatenates them.
+    """
+    if messages is None:
+        raise ValueError("messages parameter cannot be None")
+    from transformers.models.auto.tokenization_auto import AutoTokenizer
+
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+    def create_batches(messages: Sequence[dict]) -> list[dict]:
+        def get_token_length(message: dict) -> int:
+            ids = tokenizer.apply_chat_template(message["messages"][1:], tokenize=True)
+            return len(ids)
+
+        lengths: list[int] = multi_thread(get_token_length, messages, workers=64)
+        list_ids: list[list[int]] = split_indices_by_length(
+            lengths,
+            batch_size,
+            random_seed=0,
+            verbose=True,
+            shuffle=True,
+            mean_length=mean_length,
+        )
+        concatenated_messages: list[dict] = []
+
+        def concatenate_messages(conversations: Sequence[Sequence[dict]]) -> dict:
+            system_message = conversations[0][0]
+            turns: list[dict] = []
+            for conv in conversations:
+                turns.extend(conv[1:])
+            return {"messages": [system_message] + turns}
+
+        for batch_ids in list_ids:
+            if not batch_ids:
+                continue
+            conversations = [messages[i]["messages"] for i in batch_ids]
+            concatenated_messages.append(concatenate_messages(conversations))
+        return concatenated_messages
+
+    chunked_messages: list[dict] = create_batches(messages)
+    return chunked_messages
+
+
+__all__ = [
+    "split_indices_by_length",
+    "group_messages_by_len",
+]
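
A quick sketch of how split_indices_by_length packs indices into batches (illustrative only, not part of the diff):

    from llm_utils.group_messages import split_indices_by_length

    lengths = [100, 900, 300, 700, 500]  # e.g. token counts per sample
    # mean(lengths) == 500, so each batch holds at most 2 * 500 = 1000 tokens
    batches = split_indices_by_length(
        lengths,
        batch_size_by_mean_length=2,
        random_seed=0,
        verbose=False,
        shuffle=False,
    )
    # batches == [[0, 1], [2, 3], [4]]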