speedy-utils 0.1.28__tar.gz → 0.1.30__tar.gz

This diff shows the changes between publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (28)
  1. {speedy_utils-0.1.28 → speedy_utils-0.1.30}/PKG-INFO +3 -27
  2. {speedy_utils-0.1.28 → speedy_utils-0.1.30}/README.md +0 -25
  3. {speedy_utils-0.1.28 → speedy_utils-0.1.30}/pyproject.toml +7 -3
  4. speedy_utils-0.1.30/src/llm_utils/__init__.py +30 -0
  5. speedy_utils-0.1.30/src/llm_utils/chat_format.py +427 -0
  6. speedy_utils-0.1.30/src/llm_utils/group_messages.py +119 -0
  7. speedy_utils-0.1.30/src/llm_utils/lm.py +742 -0
  8. speedy_utils-0.1.28/src/speedy_utils/common/__init__.py → speedy_utils-0.1.30/src/llm_utils/lm_classification.py +0 -0
  9. speedy_utils-0.1.30/src/llm_utils/load_chat_dataset.py +41 -0
  10. speedy_utils-0.1.30/src/llm_utils/scripts/vllm_load_balancer.py +353 -0
  11. speedy_utils-0.1.30/src/llm_utils/scripts/vllm_serve.py +482 -0
  12. {speedy_utils-0.1.28 → speedy_utils-0.1.30}/src/speedy_utils/__init__.py +1 -2
  13. {speedy_utils-0.1.28 → speedy_utils-0.1.30}/src/speedy_utils/all.py +0 -2
  14. {speedy_utils-0.1.28/src/speedy_utils/multi_worker → speedy_utils-0.1.30/src/speedy_utils/common}/__init__.py +0 -0
  15. {speedy_utils-0.1.28 → speedy_utils-0.1.30}/src/speedy_utils/common/clock.py +10 -0
  16. {speedy_utils-0.1.28 → speedy_utils-0.1.30}/src/speedy_utils/common/utils_misc.py +0 -1
  17. speedy_utils-0.1.30/src/speedy_utils/multi_worker/__init__.py +0 -0
  18. {speedy_utils-0.1.28 → speedy_utils-0.1.30}/src/speedy_utils/multi_worker/thread.py +22 -6
  19. speedy_utils-0.1.28/src/speedy_utils/common/dataclass_parser.py +0 -101
  20. speedy_utils-0.1.28/src/speedy_utils/multi_worker/_handle_inputs.py +0 -50
  21. {speedy_utils-0.1.28 → speedy_utils-0.1.30}/src/speedy_utils/common/function_decorator.py +0 -0
  22. {speedy_utils-0.1.28 → speedy_utils-0.1.30}/src/speedy_utils/common/logger.py +0 -0
  23. {speedy_utils-0.1.28 → speedy_utils-0.1.30}/src/speedy_utils/common/report_manager.py +0 -0
  24. {speedy_utils-0.1.28 → speedy_utils-0.1.30}/src/speedy_utils/common/utils_cache.py +0 -0
  25. {speedy_utils-0.1.28 → speedy_utils-0.1.30}/src/speedy_utils/common/utils_io.py +0 -0
  26. {speedy_utils-0.1.28 → speedy_utils-0.1.30}/src/speedy_utils/common/utils_print.py +0 -0
  27. {speedy_utils-0.1.28 → speedy_utils-0.1.30}/src/speedy_utils/multi_worker/process.py +0 -0
  28. {speedy_utils-0.1.28 → speedy_utils-0.1.30}/src/speedy_utils/scripts/mpython.py +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: speedy-utils
-Version: 0.1.28
+Version: 0.1.30
 Summary: Fast and easy-to-use package for data science
 Author: AnhVTH
 Author-email: anhvth.226@gmail.com
@@ -19,11 +19,12 @@ Requires-Dist: fastprogress
 Requires-Dist: freezegun (>=1.5.1,<2.0.0)
 Requires-Dist: ipdb
 Requires-Dist: ipywidgets
-Requires-Dist: json-repair
+Requires-Dist: json-repair (>=0.40.0,<0.41.0)
 Requires-Dist: jupyterlab
 Requires-Dist: loguru
 Requires-Dist: matplotlib
 Requires-Dist: numpy
+Requires-Dist: packaging (>=23.2,<25)
 Requires-Dist: pandas
 Requires-Dist: pydantic
 Requires-Dist: requests
@@ -261,29 +262,6 @@ Ensure all dependencies are installed before running tests:
 pip install -r requirements.txt
 ```
 
-## Data Arguments
-
-Define and parse data arguments using a dataclass:
-
-```python
-from dataclasses import dataclass
-from speedy_utils.common.dataclass_parser import ArgsParser
-
-@dataclass
-class ExampleArgs(ArgsParser):
-    from_peft: str = "./outputs/llm_hn_qw32b/hn_results_r3/"
-    model_name_or_path: str = "Qwen/Qwen2.5-32B-Instruct-AWQ"
-    use_fp16: bool = False
-    batch_size: int = 1
-    max_length: int = 512
-    cache_dir: str = ".cache/run_embeds"
-    output_dir: str = ".cache"
-    input_file: str = ".cache/doc.csv"
-    output_name: str = "qw32b_r3"
-
-args = ExampleArgs.parse_args()
-print(args)
-```
 
 Run the script to parse and display the arguments:
 
@@ -299,5 +277,3 @@ Example output:
 
 Please ensure your code adheres to the project's coding standards and includes appropriate tests.
 
-
-
@@ -225,29 +225,6 @@ Ensure all dependencies are installed before running tests:
 pip install -r requirements.txt
 ```
 
-## Data Arguments
-
-Define and parse data arguments using a dataclass:
-
-```python
-from dataclasses import dataclass
-from speedy_utils.common.dataclass_parser import ArgsParser
-
-@dataclass
-class ExampleArgs(ArgsParser):
-    from_peft: str = "./outputs/llm_hn_qw32b/hn_results_r3/"
-    model_name_or_path: str = "Qwen/Qwen2.5-32B-Instruct-AWQ"
-    use_fp16: bool = False
-    batch_size: int = 1
-    max_length: int = 512
-    cache_dir: str = ".cache/run_embeds"
-    output_dir: str = ".cache"
-    input_file: str = ".cache/doc.csv"
-    output_name: str = "qw32b_r3"
-
-args = ExampleArgs.parse_args()
-print(args)
-```
 
 Run the script to parse and display the arguments:
 
@@ -262,5 +239,3 @@
 | from_peft | ./outputs/llm_hn_qw32b/hn_results_r3/ |
 
 Please ensure your code adheres to the project's coding standards and includes appropriate tests.
-
-
@@ -1,11 +1,14 @@
 [tool.poetry]
 name = "speedy-utils"
-version = "0.1.28"
+version = "0.1.30"
 description = "Fast and easy-to-use package for data science"
 authors = ["AnhVTH <anhvth.226@gmail.com>"]
 readme = "README.md"
 homepage = "https://github.com/anhvth/speedy"
-packages = [{ include = "speedy_utils", from = "src" }]
+packages = [
+    { include = "speedy_utils", from = "src" },
+    { include = "llm_utils", from = "src" },
+]
 
 [build-system]
 requires = ["poetry-core>=1.0.0"]
@@ -51,9 +54,10 @@ pydantic = "*"
 tqdm = "*"
 cachetools = "*"
 bump2version = "*"
-json-repair = "*"
+json-repair = ">=0.40.0,<0.41.0"
 fastprogress = "*"
 freezegun = "^1.5.1"
+packaging = ">=23.2,<25"
 
 [tool.poetry.scripts]
 mpython = "speedy_utils.scripts.mpython:main"
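
With the expanded `packages` list, the 0.1.30 sdist and wheel ship two top-level packages from `src/`: `speedy_utils` and `llm_utils`. A minimal sanity check, assuming `speedy-utils==0.1.30` is installed:

```python
# Sanity check (not part of the diff): both top-level packages declared in
# [tool.poetry].packages should be importable after installing 0.1.30.
import llm_utils
import speedy_utils

print(llm_utils.__file__)    # .../site-packages/llm_utils/__init__.py
print(speedy_utils.__file__)
```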
@@ -0,0 +1,30 @@
+from .chat_format import (
+    transform_messages,
+    transform_messages_to_chatml,
+    display_chat_messages_as_html,
+    get_conversation_one_turn,
+    display_diff_two_string,
+    display_conversations,
+    build_chatml_input,
+    format_msgs,
+)
+from .lm import OAI_LM, LM
+from .group_messages import (
+    split_indices_by_length,
+    group_messages_by_len,
+)
+
+__all__ = [
+    "transform_messages",
+    "transform_messages_to_chatml",
+    "display_chat_messages_as_html",
+    "get_conversation_one_turn",
+    "display_diff_two_string",
+    "display_conversations",
+    "build_chatml_input",
+    "format_msgs",
+    "OAI_LM",
+    "LM",
+    "split_indices_by_length",
+    "group_messages_by_len",
+]
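
The new `llm_utils/__init__.py` re-exports all of its helpers at the package top level. A usage sketch (not part of the diff) built on the definitions in `chat_format.py` below:

```python
# Build a one-turn ChatML conversation, then render it as
# <|im_start|>/<|im_end|> text with a trailing generation prompt.
from llm_utils import get_conversation_one_turn, transform_messages

msgs = get_conversation_one_turn(
    system_msg="You are a helpful assistant.",
    user_msg="What is 2 + 2?",
)
text = transform_messages(msgs, frm="chatml", to="text")
print(text)
# <|im_start|>system
# You are a helpful assistant.<|im_end|>
# <|im_start|>user
# What is 2 + 2?<|im_end|>
# <|im_start|>assistant
```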
@@ -0,0 +1,427 @@
+from copy import deepcopy
+from typing import Dict, List, Literal, Union
+
+from IPython.display import HTML, Markdown, display
+from loguru import logger
+
+
+def identify_format(item):
+    if isinstance(item, list) and "role" in item[0]:
+        return "chatml"
+    if isinstance(item, dict):
+        if "conversations" in item:
+            return "sharegpt"
+    raise ValueError(
+        f"The format of the item is not recognized. \n{type(item)=}, \n{item=}"
+    )
+
+
+def _transform_sharegpt_to_chatml(
+    item, default_system_message="You are a helpful assistant.", print_msg=False
+):
+    # if isinstance(item, list):
+    #     return [_transform_sharegpt_to_chatml(item) for item in item]
+    assert isinstance(
+        item, dict
+    ), "The item is not in the correct format. Please check the format of the item."
+
+    messages = []
+    system_msg = item.get("system", "")
+    if system_msg:
+        messages.append({"role": "system", "content": system_msg})
+    elif default_system_message:
+        messages.append({"role": "system", "content": default_system_message})
+    conversations = item.get("conversations", [])
+    if hasattr(conversations, "tolist"):
+        conversations = conversations.tolist()
+    # import ipdb; ipdb.set_trace()
+    assert conversations, "The item does not have any conversations."
+    for conversation in item.get("conversations", []):
+        role = conversation["from"]
+        content = conversation["value"]
+        messages.append({"role": role, "content": content})
+
+    return messages
+
+
+def transform_messages(
+    item,
+    frm="chatml",
+    to="text",
+    add_generation_prompt=True,
+    tokenizer=None,
+    assistant_prefix=None,
+):
+    assert to in [
+        "chatml",
+        "text",
+        "sharegpt",
+        "simulated_chat",
+    ], "The output format is not recognized. Please specify the output format."
+    item = deepcopy(item)
+
+    if tokenizer is not None:
+        assert frm == "chatml", "Tokenizer is only supported for chatml format."
+        prompt = tokenizer.apply_chat_template(
+            item, tokenize=False, add_generation_prompt=True
+        )
+        assert isinstance(prompt, str), "Prompt must be a string."
+        if assistant_prefix:
+            prompt += f"{assistant_prefix}"
+        return prompt
+
+    if frm != to:
+        # convert item to chatml format
+        chatml_messages = transform_messages_to_chatml(item, input_format=frm)
+        if to == "sharegpt":
+            if chatml_messages[0]["role"] == "system":
+                system_message = chatml_messages[0]["content"]
+                ret = {"conversations": [], "system": system_message.strip()}
+                for message in chatml_messages[1:]:
+                    ret["conversations"].append(
+                        {"from": message["role"], "value": message["content"]}
+                    )
+            else:
+                system_message = "You are a helpful assistant."
+                ret = {"conversations": [], "system": system_message.strip()}
+                for message in chatml_messages:
+                    ret["conversations"].append(
+                        {"from": message["role"], "value": message["content"]}
+                    )
+            return ret
+        elif to == "chatml":
+            return _transform_sharegpt_to_chatml(item)
+        elif to == "text":
+            text = ""
+            for turn in chatml_messages:
+                text += f"<|im_start|>{turn['role']}\n{turn['content']}<|im_end|>\n"
+            if add_generation_prompt:
+                text += "<|im_start|>assistant\n"
+            return text
+        elif to == "simulated_chat":
+            text = "<role> Given the simulated chat, you are the assistant. Lets continue the conversation. \n\n"
+            for turn in chatml_messages:
+                prefix = {
+                    "user": "Human",
+                    "assistant": "AI",
+                    "system": "System",
+                    "function": "Function",
+                }.get(turn["role"])
+                text += f"{prefix}: {turn['content'].strip()}\n\n"
+            if add_generation_prompt:
+                text += "AI: [continue the conversation here]"
+            return text
+        else:
+            raise ValueError(f"{to} is not supported.")
+
+    else:
+        return item
+
+
+def transform_messages_to_chatml(input_data, input_format="auto"):
+    if input_format == "auto":
+        input_data = raw_data = deepcopy(input_data)
+        if isinstance(input_data, list):
+            input_format = "chatlm"
+            assert (
+                input_data[0].get("role") is not None
+            ), "The input format is not recognized. Please specify the input format."
+        elif isinstance(input_data, dict):
+            input_data = _transform_sharegpt_to_chatml(input_data)
+            input_format = "sharegpt"
+        elif isinstance(input_data, str):
+            # assume it has format <|im_start|>role\n content<|im_end|> use regex to parse
+            assert (
+                "<|im_end|>" in input_data
+            ), "The input format is not recognized. Please specify the input format."
+            input_format = "chatlm"
+            parts = input_data.split("<|im_end|>")
+            # for each part, split by <|im_start|> to get role and content
+            input_data = []
+            for part in parts:
+                if not part.strip():
+                    continue
+                role = part.split("<|im_start|>")[1].split("\n")[0]
+                # content is after |>role\n
+                content = part.split(f"<|im_start|>{role}\n")[1]
+                content = content.split("<|im_end|>")[0]
+                input_data.append({"role": role.strip(), "content": content.strip()})
+
+    return input_data
+
+
+def display_chat_messages_as_html(
+    msgs, return_html=False, file="/tmp/conversation.html", theme="light"
+):
+    if isinstance(msgs, dict) and "messages" in msgs:
+        msgs = msgs["messages"]
+
+    # ensure is a list of dict each dict has role and content
+    assert isinstance(msgs, list) and all(
+        isinstance(msg, dict) and "role" in msg and "content" in msg for msg in msgs
+    ), "The input format is not recognized. Please specify the input format."
+
+    color_scheme = {
+        "system": {
+            "background": "#FFAAAA",
+            "text": "#000000",
+        },  # Light red background, black text
+        "user": {
+            "background": "#AAFFAA",
+            "text": "#000000",
+        },  # Light green background, black text
+        "assistant": {
+            "background": "#AAAAFF",
+            "text": "#000000",
+        },  # Light blue background, black text
+        "function": {
+            "background": "#AFFFFF",
+            "text": "#000000",
+        },  # Light yellow background, black text
+        "default": {
+            "background": "#FFFFFF",
+            "text": "#000000",
+        },  # White background, black text
+        "tool": {"background": "#FFAAFF", "text": "#000000"},
+    }
+
+    conversation_html = ""
+    for i, message in enumerate(msgs):
+        role = message["role"]
+        content = message.get("content", "")
+        if not content:
+            content = ""
+
+        tool_calls = message.get("tool_calls")
+        if not content and tool_calls:
+            # each tool call comes with name, and args
+
+            for tool_call in tool_calls:
+                tool_call = tool_call["function"]
+                name = tool_call["name"]
+                args = tool_call["arguments"]
+                content += "Tool: " + name + "\n" + "Arguments: " + str(args)
+
+        # Replace newlines with <br> tags
+        content = content.replace("\n", "<br>")
+
+        # Replace tabs with &nbsp; entities
+        content = content.replace("\t", "&nbsp;&nbsp;&nbsp;&nbsp;")
+
+        # Replace multiple consecutive spaces with &nbsp; entities
+        content = content.replace("  ", "&nbsp;&nbsp;")
+        # keep html tag without escaping
+        # content = content.replace('&lt;', '<')
+
+        content = (
+            content.replace("<br>", "TEMP_BR")
+            .replace("<", "&lt;")
+            .replace(">", "&gt;")
+            .replace("TEMP_BR", "<br>")
+        )
+
+        if role in color_scheme:
+            background_color = color_scheme[role]["background"]
+            text_color = color_scheme[role]["text"]
+        else:
+            background_color = color_scheme["default"]["background"]
+            text_color = color_scheme["default"]["text"]
+
+        if role == "system":
+            conversation_html += f'<div style="background-color: {background_color}; color: {text_color}; padding: 10px; margin-bottom: 10px;"><strong>System:</strong><br><pre id="system-{i}">{content}</pre></div>'
+        elif role == "user":
+            conversation_html += f'<div style="background-color: {background_color}; color: {text_color}; padding: 10px; margin-bottom: 10px;"><strong>User:</strong><br><pre id="user-{i}">{content}</pre></div>'
+        elif role == "assistant":
+            conversation_html += f'<div style="background-color: {background_color}; color: {text_color}; padding: 10px; margin-bottom: 10px;"><strong>Assistant:</strong><br><pre id="assistant-{i}">{content}</pre></div>'
+        elif role == "function":
+            conversation_html += f'<div style="background-color: {background_color}; color: {text_color}; padding: 10px; margin-bottom: 10px;"><strong>Function:</strong><br><pre id="function-{i}">{content}</pre></div>'
+        else:
+            # logger.warning(f"Unknown role: {role}")
+            conversation_html += f'<div style="background-color: {background_color}; color: {text_color}; padding: 10px; margin-bottom: 10px;"><strong>{role}:</strong><br><pre id="{role}-{i}">{content}</pre><br><button onclick="copyContent(\'{role}-{i}\')">Copy</button></div>'
+
+    html = f"""
+    <html>
+    <head>
+    <style>
+    pre {{
+        white-space: pre-wrap;
+    }}
+    </style>
+    </head>
+    <body>
+    {conversation_html}
+    <script>
+    function copyContent(elementId) {{
+        var element = document.getElementById(elementId);
+        var text = element.innerText;
+        navigator.clipboard.writeText(text)
+            .then(function() {{
+                alert("Content copied to clipboard!");
+            }})
+            .catch(function(error) {{
+                console.error("Error copying content: ", error);
+            }});
+    }}
+    </script>
+    </body>
+    </html>
+    """
+
+    if file:
+        with open(file, "w") as f:
+            f.write(html)
+    if return_html:
+        return html
+    else:
+        display(HTML(html))
+
+
+def get_conversation_one_turn(
+    system_msg=None,
+    user_msg=None,
+    assistant_msg=None,
+    assistant_prefix=None,
+    return_format="chatml",
+):
+    messages = []
+    if system_msg:
+        messages.append({"role": "system", "content": system_msg})
+    messages.append({"role": "user", "content": user_msg})
+    if assistant_msg:
+        messages.append({"role": "assistant", "content": assistant_msg})
+    if assistant_prefix is not None:
+        assert (
+            return_format != "chatml"
+        ), "Change return_format to 'text' if you want to use assistant_prefix"
+        assert messages[-1]["role"] == "user"
+        msg = transform_messages(messages, "chatml", "text", add_generation_prompt=True)
+        msg += assistant_prefix
+        return msg
+    else:
+        assert return_format in ["chatml"]
+        return messages
+
+
+from difflib import ndiff
+
+from IPython.display import HTML
+
+
+def display_diff_two_string(text1, text2):
+    # Split the texts into lines
+    lines1 = text1.splitlines()
+    lines2 = text2.splitlines()
+
+    # Perform the diff
+    diff = list(ndiff(lines1, lines2))
+
+    # Create the HTML table
+    table_rows = []
+    for line in diff:
+        if line.startswith("- "):
+            table_rows.append(
+                f'<tr><td style="background-color: #FFCCCB;">{line[2:]}</td><td></td></tr>'
+            )
+        elif line.startswith("+ "):
+            table_rows.append(
+                f'<tr><td></td><td style="background-color: #CCFFCC;">{line[2:]}</td></tr>'
+            )
+        elif line.startswith("? "):
+            continue
+        else:
+            table_rows.append(f"<tr><td>{line}</td><td>{line}</td></tr>")
+
+    table_html = '<table style="width: 100%; border-collapse: collapse;">'
+    table_html += '<tr><th style="width: 50%; text-align: left;">Text 1</th><th style="width: 50%; text-align: left;">Text 2</th></tr>'
+    table_html += "".join(table_rows)
+    table_html += "</table>"
+
+    # Display the HTML table
+    display(HTML(table_html))
+
+
+def display_conversations(data1, data2, theme="light"):
+    html1 = display_chat_messages_as_html(data1, return_html=True, theme=theme)
+    html2 = display_chat_messages_as_html(data2, return_html=True, theme=theme)
+
+    html = f"""
+    <html>
+    <head>
+    <style>
+    table {{
+        width: 100%;
+        border-collapse: collapse;
+    }}
+    td {{
+        width: 50%;
+        vertical-align: top;
+        padding: 10px;
+    }}
+    </style>
+    </head>
+    <body>
+    <table>
+    <tr>
+        <td>{html1}</td>
+        <td>{html2}</td>
+    </tr>
+    </table>
+    </body>
+    </html>
+    """
+    display(HTML(html))
+
+
+from typing import Callable, Dict, List
+
+
+def build_chatml_input(template: str, params: List[str]) -> Callable:
+    def formator(**kwargs) -> List[List[Dict[str, str]]]:
+        system_msg = kwargs.get("system_msg", None)
+        # remove system
+        kwargs.pop("system_msg", None)
+        # Ensure all required parameters are present in kwargs
+        for param in params:
+            if param not in kwargs:
+                raise ValueError(f"Missing parameter: {param}")
+
+        # Use the **kwargs directly in the format method
+        content = template.format(**kwargs)
+        msgs = []
+        if system_msg:
+            msgs += [{"role": "system", "content": system_msg}]
+        msgs += [{"role": "user", "content": content}]
+        return msgs
+
+    return formator
+
+
+def _color_text(text, color_code):
+    """Helper function to color text based on the provided ANSI color code."""
+    return f"\033[{color_code}m{text}\033[0m"
+
+
+def format_msgs(messages):
+    """Formats the role and content of a list of messages into a string."""
+    messages = transform_messages_to_chatml(messages)
+    output = []
+
+    for msg in messages:
+        role = msg.get("role", "unknown").lower()
+        content = msg.get("content", "").strip()
+        output.append(f"{role.capitalize()}:\t{content}")
+        output.append("---")
+
+    return "\n".join(output)
+
+
+__all__ = [
+    "transform_messages",
+    "transform_messages_to_chatml",
+    "display_chat_messages_as_html",
+    "get_conversation_one_turn",
+    "display_diff_two_string",
+    "display_conversations",
+    "build_chatml_input",
+    "format_msgs",
+]
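
`transform_messages` dispatches on the `frm`/`to` pair, so the same call site can emit ChatML lists, ShareGPT dicts, or `<|im_start|>`-delimited text. A round-trip sketch (not part of the diff) based on the definitions above:

```python
from llm_utils import transform_messages, transform_messages_to_chatml

chatml = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
    {"role": "assistant", "content": "Hi there."},
]

# ChatML -> ShareGPT: the leading system turn is hoisted into the "system" key.
sharegpt = transform_messages(chatml, frm="chatml", to="sharegpt")
# {'conversations': [{'from': 'user', 'value': 'Hello!'},
#                    {'from': 'assistant', 'value': 'Hi there.'}],
#  'system': 'You are a helpful assistant.'}

# ShareGPT -> ChatML: "auto" detects the dict form and converts back.
restored = transform_messages_to_chatml(sharegpt, input_format="auto")
assert restored == chatml
```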
@@ -0,0 +1,119 @@
+import random
+from typing import Optional
+
+import numpy as np
+import pandas as pd
+from tabulate import tabulate
+
+from speedy_utils import multi_thread
+
+
+def split_indices_by_length(
+    lengths: list[int],
+    batch_size_by_mean_length: int,
+    random_seed: int,
+    verbose: bool,
+    shuffle: bool,
+    mean_length: Optional[int] = None,
+) -> list[list[int]]:
+    if mean_length is None:
+        mean_length = int(np.mean(lengths))
+    max_batch_length = mean_length * batch_size_by_mean_length
+
+    r = random.Random(random_seed)
+    indices = list(range(len(lengths)))
+
+    if shuffle:
+        r.shuffle(indices)
+
+    batches = []
+    current_batch = []
+    current_batch_length = 0
+
+    for idx in indices:
+        length = lengths[idx]
+        if current_batch_length + length <= max_batch_length:
+            current_batch.append(idx)
+            current_batch_length += length
+        else:
+            batches.append(current_batch)
+            current_batch = [idx]
+            current_batch_length = length
+
+    if current_batch:
+        batches.append(current_batch)
+
+    if verbose:
+        batch_lengths = [sum(lengths[idx] for idx in batch) for batch in batches]
+        desc = pd.Series(batch_lengths).describe()
+
+        table = [
+            ["New avg item len", desc["mean"]],
+            ["Number groups", len(batches)],
+            ["Max length", max_batch_length],
+        ]
+
+        print(tabulate(table, headers=["Metric", "Value"], tablefmt="pretty"))
+
+    return batches
+
+
+def group_messages_by_len(
+    messages, model_name="Qwen/Qwen2.5-7B-Instruct", batch_size=4, mean_length=512
+):
+    """
+    Groups a list of messages into batches based on token length and concatenates them.
+    Args:
+        messages (list[dict]): OpenAI message format, each dict should contain a "messages" key with a list of messages. ensure the system prompt are shared.
+        model_name (str): The name of the model to use for tokenization. Default is "Qwen/Qwen2.5-7B-Instruct".
+        batch_size (int): The number of messages to include in each batch. Default is 4.
+        mean_length (int): The mean length of tokens for each batch. Default is 512.
+    Returns:
+        list: A list of concatenated message dictionaries, where each dictionary contains a "messages" key with the grouped messages.
+    Raises:
+        ValueError: If the messages parameter is None.
+    """
+    if messages is None:
+        raise ValueError("messages parameter cannot be None")
+    from transformers.models.auto.tokenization_auto import AutoTokenizer
+
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+    def create_batches(messages):
+        def get_token_length(message):
+            ids = tokenizer.apply_chat_template(message["messages"][1:], tokenize=True)
+            return len(ids)
+
+        # lengths = [get_token_length(msg) for msg in messages]
+        lengths = multi_thread(get_token_length, messages, workers=64)
+        list_ids = split_indices_by_length(
+            lengths,
+            batch_size,
+            random_seed=0,
+            verbose=True,
+            shuffle=True,
+            mean_length=mean_length,
+        )
+        concatenated_messages = []
+
+        def concatenate_messages(conversations):
+            system_message = conversations[0][0]
+            turns = []
+            for conv in conversations:
+                turns.extend(conv[1:])
+            return {"messages": [system_message] + turns}
+
+        for batch_ids in list_ids:
+            if not batch_ids:
+                continue
+            conversations = [messages[i]["messages"] for i in batch_ids]
+            concatenated_messages.append(concatenate_messages(conversations))
+        return concatenated_messages
+
+    chunked_messages = create_batches(messages)
+    return chunked_messages
+
+__all__ = [
+    "split_indices_by_length",
+    "group_messages_by_len",
+]
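
`split_indices_by_length` packs greedily: indices accumulate into the current batch until the running total would exceed `mean_length * batch_size_by_mean_length`, at which point a new batch starts. A small worked sketch (not part of the diff) of those semantics:

```python
from llm_utils import split_indices_by_length

lengths = [120, 480, 300, 90, 510, 60]  # e.g. token counts per conversation
batches = split_indices_by_length(
    lengths,
    batch_size_by_mean_length=4,  # cap each batch at mean_length * 4 tokens
    random_seed=0,
    verbose=False,
    shuffle=False,                # keep original order for a deterministic demo
    mean_length=256,              # cap = 256 * 4 = 1024 tokens per batch
)
print(batches)  # [[0, 1, 2, 3], [4, 5]] -- sums 990 and 570, both <= 1024
```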