speedy-utils 1.0.4__py3-none-any.whl → 1.0.9__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (38)
  1. llm_utils/__init__.py +31 -0
  2. llm_utils/chat_format/__init__.py +34 -0
  3. llm_utils/chat_format/display.py +274 -0
  4. llm_utils/chat_format/transform.py +149 -0
  5. llm_utils/chat_format/utils.py +43 -0
  6. llm_utils/group_messages.py +120 -0
  7. llm_utils/lm/__init__.py +8 -0
  8. llm_utils/lm/lm.py +304 -0
  9. llm_utils/lm/utils.py +130 -0
  10. llm_utils/scripts/vllm_load_balancer.py +435 -0
  11. llm_utils/scripts/vllm_serve.py +416 -0
  12. speedy_utils/__init__.py +85 -0
  13. speedy_utils/all.py +159 -0
  14. {speedy → speedy_utils}/common/__init__.py +0 -0
  15. speedy_utils/common/clock.py +215 -0
  16. speedy_utils/common/function_decorator.py +66 -0
  17. speedy_utils/common/logger.py +207 -0
  18. speedy_utils/common/report_manager.py +112 -0
  19. speedy_utils/common/utils_cache.py +264 -0
  20. {speedy → speedy_utils}/common/utils_io.py +66 -19
  21. {speedy → speedy_utils}/common/utils_misc.py +25 -11
  22. speedy_utils/common/utils_print.py +216 -0
  23. speedy_utils/multi_worker/__init__.py +0 -0
  24. speedy_utils/multi_worker/process.py +198 -0
  25. speedy_utils/multi_worker/thread.py +327 -0
  26. speedy_utils/scripts/mpython.py +108 -0
  27. speedy_utils-1.0.9.dist-info/METADATA +287 -0
  28. speedy_utils-1.0.9.dist-info/RECORD +30 -0
  29. {speedy_utils-1.0.4.dist-info → speedy_utils-1.0.9.dist-info}/WHEEL +1 -2
  30. speedy_utils-1.0.9.dist-info/entry_points.txt +5 -0
  31. speedy/__init__.py +0 -53
  32. speedy/common/clock.py +0 -68
  33. speedy/common/utils_cache.py +0 -170
  34. speedy/common/utils_print.py +0 -138
  35. speedy/multi_worker.py +0 -121
  36. speedy_utils-1.0.4.dist-info/METADATA +0 -22
  37. speedy_utils-1.0.4.dist-info/RECORD +0 -12
  38. speedy_utils-1.0.4.dist-info/top_level.txt +0 -1
llm_utils/__init__.py ADDED
@@ -0,0 +1,31 @@
+from .chat_format import (
+    transform_messages,
+    transform_messages_to_chatml,
+    show_chat,
+    get_conversation_one_turn,
+    show_string_diff,
+    display_conversations,
+    build_chatml_input,
+    format_msgs,
+    display_chat_messages_as_html,
+)
+from .lm import LM
+from .group_messages import (
+    split_indices_by_length,
+    group_messages_by_len,
+)
+
+__all__ = [
+    "transform_messages",
+    "transform_messages_to_chatml",
+    "show_chat",
+    "get_conversation_one_turn",
+    "show_string_diff",
+    "display_conversations",
+    "build_chatml_input",
+    "format_msgs",
+    "split_indices_by_length",
+    "group_messages_by_len",
+    "LM",
+    "display_chat_messages_as_html",
+]
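The re-exports above define the public surface of `llm_utils`. A minimal usage sketch, assuming the 1.0.9 wheel is installed and an IPython display frontend is available (the message content here is illustrative):

```python
# Minimal sketch of the llm_utils public API re-exported above.
from llm_utils import show_chat, transform_messages

msgs = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello!"},
]

show_chat(msgs)  # renders colour-coded HTML bubbles in a notebook
print(transform_messages(msgs, frm="chatml", to="text"))  # ChatML-style text
```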
llm_utils/chat_format/__init__.py ADDED
@@ -0,0 +1,34 @@
+from .transform import (
+    identify_format,
+    _transform_sharegpt_to_chatml,
+    transform_messages,
+    transform_messages_to_chatml,
+)
+from .display import (
+    show_chat,
+    get_conversation_one_turn,
+    highlight_diff_chars,
+    show_string_diff,
+    display_conversations,
+    display_chat_messages_as_html,
+)
+from .utils import (
+    build_chatml_input,
+    format_msgs,
+)
+
+
+__all__ = [
+    "identify_format",
+    "_transform_sharegpt_to_chatml",
+    "transform_messages",
+    "transform_messages_to_chatml",
+    "show_chat",
+    "get_conversation_one_turn",
+    "highlight_diff_chars",
+    "build_chatml_input",
+    "format_msgs",
+    "show_string_diff",
+    "display_conversations",
+    "display_chat_messages_as_html",
+]
llm_utils/chat_format/display.py ADDED
@@ -0,0 +1,274 @@
+from __future__ import annotations
+from typing import List, Tuple, Sequence, Any, Dict, Optional
+from IPython.display import HTML, display
+from difflib import SequenceMatcher
+
+
+def show_chat(
+    msgs: Any,
+    return_html: bool = False,
+    file: str = "/tmp/conversation.html",
+    theme: str = "default",
+) -> Optional[str]:
+    """
+    Display chat messages as HTML.
+    """
+    if isinstance(msgs, dict) and "messages" in msgs:
+        msgs = msgs["messages"]
+    assert isinstance(msgs, list) and all(
+        isinstance(msg, dict) and "role" in msg and "content" in msg for msg in msgs
+    ), "The input format is not recognized. Please specify the input format."
+
+    themes: dict[str, dict[str, dict[str, str]]] = {
+        "default": {
+            "system": {"background": "#ffaaaa", "text": "#222222"},  # More red
+            "user": {"background": "#f8c57e", "text": "#222222"},  # More orange
+            "assistant": {"background": "#9dfebd", "text": "#222222"},  # More green
+            "function": {"background": "#eafde7", "text": "#222222"},
+            "tool": {"background": "#fde7fa", "text": "#222222"},
+            "default": {"background": "#ffffff", "text": "#222222"},
+        },
+        "light": {
+            "system": {"background": "#ff6666", "text": "#000000"},  # More red
+            "user": {"background": "#ffd580", "text": "#000000"},  # More orange
+            "assistant": {"background": "#80ffb3", "text": "#000000"},  # More green
+            "function": {"background": "#AFFFFF", "text": "#000000"},
+            "tool": {"background": "#FFAAFF", "text": "#000000"},
+            "default": {"background": "#FFFFFF", "text": "#000000"},
+        },
+        "dark": {
+            "system": {"background": "#b22222", "text": "#fffbe7"},  # More red
+            "user": {"background": "#ff8800", "text": "#18181b"},  # More orange
+            "assistant": {"background": "#22c55e", "text": "#e0ffe0"},  # More green
+            "function": {"background": "#134e4a", "text": "#e0fff7"},
+            "tool": {"background": "#701a75", "text": "#ffe0fa"},
+            "default": {"background": "#18181b", "text": "#f4f4f5"},
+        },
+    }
+
+    color_scheme = themes.get(theme, themes["default"])
+
+    conversation_html = ""
+    for i, message in enumerate(msgs):
+        role = message["role"]
+        content = message.get("content", "")
+        if not content:
+            content = ""
+        tool_calls = message.get("tool_calls")
+        if not content and tool_calls:
+            for tool_call in tool_calls:
+                tool_call = tool_call["function"]
+                name = tool_call["name"]
+                args = tool_call["arguments"]
+                content += f"Tool: {name}\nArguments: {args}"
+        content = content.replace("\n", "<br>")
+        content = content.replace("\t", "&nbsp;&nbsp;&nbsp;&nbsp;")
+        content = content.replace(" ", "&nbsp;&nbsp;")
+        content = (
+            content.replace("<br>", "TEMP_BR")
+            .replace("<", "&lt;")
+            .replace(">", "&gt;")
+            .replace("TEMP_BR", "<br>")
+        )
+        if role in color_scheme:
+            background_color = color_scheme[role]["background"]
+            text_color = color_scheme[role]["text"]
+        else:
+            background_color = color_scheme["default"]["background"]
+            text_color = color_scheme["default"]["text"]
+        if role == "system":
+            conversation_html += (
+                f'<div style="background-color: {background_color}; color: {text_color}; padding: 10px; margin-bottom: 10px;">'
+                f'<strong>System:</strong><br><pre id="system-{i}">{content}</pre></div>'
+            )
+        elif role == "user":
+            conversation_html += (
+                f'<div style="background-color: {background_color}; color: {text_color}; padding: 10px; margin-bottom: 10px;">'
+                f'<strong>User:</strong><br><pre id="user-{i}">{content}</pre></div>'
+            )
+        elif role == "assistant":
+            conversation_html += (
+                f'<div style="background-color: {background_color}; color: {text_color}; padding: 10px; margin-bottom: 10px;">'
+                f'<strong>Assistant:</strong><br><pre id="assistant-{i}">{content}</pre></div>'
+            )
+        elif role == "function":
+            conversation_html += (
+                f'<div style="background-color: {background_color}; color: {text_color}; padding: 10px; margin-bottom: 10px;">'
+                f'<strong>Function:</strong><br><pre id="function-{i}">{content}</pre></div>'
+            )
+        else:
+            conversation_html += (
+                f'<div style="background-color: {background_color}; color: {text_color}; padding: 10px; margin-bottom: 10px;">'
+                f'<strong>{role}:</strong><br><pre id="{role}-{i}">{content}</pre><br>'
+                f"<button onclick=\"copyContent('{role}-{i}')\">Copy</button></div>"
+            )
+    html: str = f"""
+    <html>
+    <head>
+        <style>
+            pre {{
+                white-space: pre-wrap;
+            }}
+        </style>
+    </head>
+    <body>
+        {conversation_html}
+        <script>
+        function copyContent(elementId) {{
+            var element = document.getElementById(elementId);
+            var text = element.innerText;
+            navigator.clipboard.writeText(text)
+                .then(function() {{
+                    alert("Content copied to clipboard!");
+                }})
+                .catch(function(error) {{
+                    console.error("Error copying content: ", error);
+                }});
+        }}
+        </script>
+    </body>
+    </html>
+    """
+    if file:
+        with open(file, "w") as f:
+            f.write(html)
+    if return_html:
+        return html
+    else:
+        display(HTML(html))
+
+
+def get_conversation_one_turn(
+    system_msg: Optional[str] = None,
+    user_msg: Optional[str] = None,
+    assistant_msg: Optional[str] = None,
+    assistant_prefix: Optional[str] = None,
+    return_format: str = "chatml",
+) -> Any:
+    """
+    Build a one-turn conversation.
+    """
+    messages: list[dict[str, str]] = []
+    if system_msg is not None:
+        messages.append({"role": "system", "content": system_msg})
+    if user_msg is not None:
+        messages.append({"role": "user", "content": user_msg})
+    if assistant_msg is not None:
+        messages.append({"role": "assistant", "content": assistant_msg})
+    if assistant_prefix is not None:
+        assert (
+            return_format != "chatml"
+        ), 'Change return_format to "text" if you want to use assistant_prefix'
+        assert messages[-1]["role"] == "user"
+        from .transform import transform_messages
+
+        msg = transform_messages(messages, "chatml", "text", add_generation_prompt=True)
+        if not isinstance(msg, str):
+            msg = str(msg)
+        msg += assistant_prefix
+        return msg
+    else:
+        assert return_format in ["chatml"]
+        return messages
+
+
+def highlight_diff_chars(text1: str, text2: str) -> str:
+    """
+    Return a string with deletions in red and additions in green.
+    """
+    matcher = SequenceMatcher(None, text1, text2)
+    html: list[str] = []
+    for tag, i1, i2, j1, j2 in matcher.get_opcodes():
+        if tag == "equal":
+            html.append(text1[i1:i2])
+        elif tag == "replace":
+            if i1 != i2:
+                html.append(
+                    f'<span style="background-color:#ffd6d6; color:#b20000;">{text1[i1:i2]}</span>'
+                )
+            if j1 != j2:
+                html.append(
+                    f'<span style="background-color:#d6ffd6; color:#006600;">{text2[j1:j2]}</span>'
+                )
+        elif tag == "delete":
+            html.append(
+                f'<span style="background-color:#ffd6d6; color:#b20000;">{text1[i1:i2]}</span>'
+            )
+        elif tag == "insert":
+            html.append(
+                f'<span style="background-color:#d6ffd6; color:#006600;">{text2[j1:j2]}</span>'
+            )
+    return "".join(html)
+
+
+def show_string_diff(old: str, new: str) -> None:
+    """
+    Display a one-line visual diff between two strings (old -> new).
+    """
+    html1 = highlight_diff_chars(old, new)
+    display(HTML(html1))
+
+
+def display_conversations(data1: Any, data2: Any, theme: str = "light") -> None:
+    """
+    Display two conversations side by side.
+    """
+    import warnings
+
+    warnings.warn(
+        "display_conversations will be deprecated in the next version.",
+        DeprecationWarning,
+        stacklevel=2,
+    )
+    html1 = show_chat(data1, return_html=True, theme=theme)
+    html2 = show_chat(data2, return_html=True, theme=theme)
+    html = f"""
+    <html>
+    <head>
+        <style>
+            table {{
+                width: 100%;
+                border-collapse: collapse;
+            }}
+            td {{
+                width: 50%;
+                vertical-align: top;
+                padding: 10px;
+            }}
+        </style>
+    </head>
+    <body>
+        <table>
+            <tr>
+                <td>{html1}</td>
+                <td>{html2}</td>
+            </tr>
+        </table>
+    </body>
+    </html>
+    """
+    display(HTML(html))
+
+
+def display_chat_messages_as_html(*args, **kwargs):
+    """
+    Use as show_chat and warn about the deprecated function.
+    """
+    import warnings
+
+    warnings.warn(
+        "display_chat_messages_as_html is deprecated, use show_chat instead.",
+        DeprecationWarning,
+        stacklevel=2,
+    )
+    return show_chat(*args, **kwargs)
+
+
+__all__ = [
+    "show_chat",
+    "get_conversation_one_turn",
+    "highlight_diff_chars",
+    "show_string_diff",
+    "display_conversations",
+    "display_chat_messages_as_html",
+]
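A short usage sketch of the display helpers above, assuming a Jupyter/IPython session (both functions call `IPython.display.display`); `show_chat` also writes the rendered page to `/tmp/conversation.html` by default:

```python
from llm_utils.chat_format.display import show_chat, show_string_diff

conversation = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is 2 + 2?"},
    {"role": "assistant", "content": "4"},
]

# Colour-coded bubbles; unknown roles additionally get a Copy button
# via the inline copyContent() script.
show_chat(conversation, theme="dark")

# Character-level diff: deletions highlighted in red, insertions in green.
show_string_diff("The quick brown fox", "The quick red fox")
```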
llm_utils/chat_format/transform.py ADDED
@@ -0,0 +1,149 @@
+from __future__ import annotations
+from copy import deepcopy
+from typing import Callable, Dict, List, Sequence
+
+
+def identify_format(item):
+    if isinstance(item, list) and "role" in item[0]:
+        return "chatml"
+    if isinstance(item, dict):
+        if "conversations" in item:
+            return "sharegpt"
+    raise ValueError(
+        f"The format of the item is not recognized. \n{type(item)=}, \n{item=}"
+    )
+
+
+def _transform_sharegpt_to_chatml(
+    item, default_system_message="You are a helpful assistant.", print_msg=False
+):
+    assert isinstance(
+        item, dict
+    ), "The item is not in the correct format. Please check the format of the item."
+
+    messages = []
+    system_msg = item.get("system", "")
+    if system_msg:
+        messages.append({"role": "system", "content": system_msg})
+    elif default_system_message:
+        messages.append({"role": "system", "content": default_system_message})
+    conversations = item.get("conversations", [])
+    if hasattr(conversations, "tolist"):
+        conversations = conversations.tolist()
+    assert conversations, "The item does not have any conversations."
+    for conversation in item.get("conversations", []):
+        role = conversation["from"]
+        content = conversation["value"]
+        messages.append({"role": role, "content": content})
+
+    return messages
+
+
+def transform_messages(
+    item,
+    frm="chatml",
+    to="text",
+    add_generation_prompt=True,
+    tokenizer=None,
+    assistant_prefix=None,
+):
+    assert to in [
+        "chatml",
+        "text",
+        "sharegpt",
+        "simulated_chat",
+    ], "The output format is not recognized. Please specify the output format."
+    item = deepcopy(item)
+
+    if tokenizer is not None:
+        assert frm == "chatml", "Tokenizer is only supported for chatml format."
+        prompt = tokenizer.apply_chat_template(
+            item, tokenize=False, add_generation_prompt=True
+        )
+        assert isinstance(prompt, str), "Prompt must be a string."
+        if assistant_prefix:
+            prompt += f"{assistant_prefix}"
+        return prompt
+
+    if frm != to:
+        chatml_messages = transform_messages_to_chatml(item, input_format=frm)
+        if to == "sharegpt":
+            if chatml_messages[0]["role"] == "system":
+                system_message = chatml_messages[0]["content"]
+                ret = {"conversations": [], "system": system_message.strip()}
+                for message in chatml_messages[1:]:
+                    ret["conversations"].append(
+                        {"from": message["role"], "value": message["content"]}
+                    )
+            else:
+                system_message = "You are a helpful assistant."
+                ret = {"conversations": [], "system": system_message.strip()}
+                for message in chatml_messages:
+                    ret["conversations"].append(
+                        {"from": message["role"], "value": message["content"]}
+                    )
+            return ret
+        elif to == "chatml":
+            return _transform_sharegpt_to_chatml(item)
+        elif to == "text":
+            text = ""
+            for turn in chatml_messages:
+                text += f"<|im_start|>{turn['role']}\n{turn['content']}<|im_end|>\n"
+            if add_generation_prompt:
+                text += "<|im_start|>assistant\n"
+            return text
+        elif to == "simulated_chat":
+            text = "<role> Given the simulated chat, you are the assistant. Lets continue the conversation. \n\n"
+            for turn in chatml_messages:
+                prefix = {
+                    "user": "Human",
+                    "assistant": "AI",
+                    "system": "System",
+                    "function": "Function",
+                }.get(turn["role"])
+                text += f"{prefix}: {turn['content'].strip()}\n\n"
+            if add_generation_prompt:
+                text += "AI: [continue the conversation here]"
+            return text
+        else:
+            raise ValueError(f"{to} is not supported.")
+
+    else:
+        return item
+
+
+def transform_messages_to_chatml(input_data, input_format="auto"):
+    if input_format == "auto":
+        input_data = raw_data = deepcopy(input_data)
+        if isinstance(input_data, list):
+            input_format = "chatlm"
+            assert (
+                input_data[0].get("role") is not None
+            ), "The input format is not recognized. Please specify the input format."
+        elif isinstance(input_data, dict):
+            input_data = _transform_sharegpt_to_chatml(input_data)
+            input_format = "sharegpt"
+        elif isinstance(input_data, str):
+            assert (
+                "<|im_end|>" in input_data
+            ), "The input format is not recognized. Please specify the input format."
+            input_format = "chatlm"
+            parts = input_data.split("<|im_end|>")
+            input_data = []
+            for part in parts:
+                if not part.strip():
+                    continue
+                role = part.split("<|im_start|>")[1].split("\n")[0]
+                content = part.split(f"<|im_start|>{role}\n")[1]
+                content = content.split("<|im_end|>")[0]
+                input_data.append({"role": role.strip(), "content": content.strip()})
+
+    return input_data
+
+
+__all__ = [
+    "identify_format",
+    "_transform_sharegpt_to_chatml",
+    "transform_messages",
+    "transform_messages_to_chatml",
+]
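To make the conversion paths concrete, a small sketch of `transform_messages`; the expected outputs follow directly from the template strings in the code above:

```python
from llm_utils.chat_format.transform import transform_messages

chatml = [
    {"role": "user", "content": "Hi"},
    {"role": "assistant", "content": "Hello!"},
]

text = transform_messages(chatml, frm="chatml", to="text", add_generation_prompt=False)
# '<|im_start|>user\nHi<|im_end|>\n<|im_start|>assistant\nHello!<|im_end|>\n'

sharegpt = transform_messages(chatml, frm="chatml", to="sharegpt")
# {'conversations': [{'from': 'user', 'value': 'Hi'},
#                    {'from': 'assistant', 'value': 'Hello!'}],
#  'system': 'You are a helpful assistant.'}
```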
llm_utils/chat_format/utils.py ADDED
@@ -0,0 +1,43 @@
+from __future__ import annotations
+from typing import List, Dict, Callable
+
+
+def build_chatml_input(template: str, params: List[str]) -> Callable:
+    def formator(**kwargs) -> List[List[Dict[str, str]]]:
+        system_msg = kwargs.get("system_msg", None)
+        kwargs.pop("system_msg", None)
+        for param in params:
+            if param not in kwargs:
+                raise ValueError(f"Missing parameter: {param}")
+        content = template.format(**kwargs)
+        msgs = []
+        if system_msg:
+            msgs += [{"role": "system", "content": system_msg}]
+        msgs += [{"role": "user", "content": content}]
+        return msgs
+
+    return formator
+
+
+def _color_text(text, color_code):
+    return f"\033[{color_code}m{text}\033[0m"
+
+
+def format_msgs(messages):
+    from .transform import transform_messages_to_chatml
+
+    messages = transform_messages_to_chatml(messages)
+    output = []
+    for msg in messages:
+        role = msg.get("role", "unknown").lower()
+        content = msg.get("content", "").strip()
+        output.append(f"{role.capitalize()}:\t{content}")
+        output.append("---")
+    return "\n".join(output)
+
+
+__all__ = [
+    "build_chatml_input",
+    "_color_text",
+    "format_msgs",
+]
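A sketch of the prompt-builder factory above; the template and parameter names here are made up for illustration:

```python
from llm_utils.chat_format.utils import build_chatml_input

# Hypothetical template/params, just to exercise the factory.
to_msgs = build_chatml_input(
    template="Translate the following text to {language}:\n{text}",
    params=["language", "text"],
)

msgs = to_msgs(system_msg="You are a translation engine.",
               language="French", text="Good morning")
# [{'role': 'system', 'content': 'You are a translation engine.'},
#  {'role': 'user', 'content': 'Translate the following text to French:\nGood morning'}]
```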
llm_utils/group_messages.py ADDED
@@ -0,0 +1,120 @@
+from __future__ import annotations
+
+import random
+from typing import Sequence, Optional
+
+import numpy as np
+import pandas as pd
+from tabulate import tabulate
+
+from speedy_utils import multi_thread
+
+
+def split_indices_by_length(
+    lengths: Sequence[int],
+    batch_size_by_mean_length: int,
+    random_seed: int,
+    verbose: bool,
+    shuffle: bool,
+    mean_length: Optional[int] = None,
+) -> list[list[int]]:
+    """
+    Split indices into batches so that the sum of lengths in each batch does not exceed max_batch_length.
+    """
+    if mean_length is None:
+        mean_length = int(np.mean(lengths))
+    max_batch_length: int = mean_length * batch_size_by_mean_length
+
+    r: random.Random = random.Random(random_seed)
+    indices: list[int] = list(range(len(lengths)))
+
+    if shuffle:
+        r.shuffle(indices)
+
+    batches: list[list[int]] = []
+    current_batch: list[int] = []
+    current_batch_length: int = 0
+
+    for idx in indices:
+        length: int = lengths[idx]
+        if current_batch_length + length <= max_batch_length:
+            current_batch.append(idx)
+            current_batch_length += length
+        else:
+            batches.append(current_batch)
+            current_batch = [idx]
+            current_batch_length = length
+
+    if current_batch:
+        batches.append(current_batch)
+
+    if verbose:
+        batch_lengths: list[int] = [
+            sum(lengths[idx] for idx in batch) for batch in batches
+        ]
+        desc = pd.Series(batch_lengths).describe()
+
+        table = [
+            ["New avg item len", desc["mean"]],
+            ["Number groups", len(batches)],
+            ["Max length", max_batch_length],
+        ]
+
+        print(tabulate(table, headers=["Metric", "Value"], tablefmt="pretty"))
+
+    return batches
+
+
+def group_messages_by_len(
+    messages: Sequence[dict],
+    model_name: str = "Qwen/Qwen2.5-7B-Instruct",
+    batch_size: int = 4,
+    mean_length: int = 512,
+) -> list[dict]:
+    """
+    Groups messages into batches based on token length and concatenates them.
+    """
+    if messages is None:
+        raise ValueError("messages parameter cannot be None")
+    from transformers.models.auto.tokenization_auto import AutoTokenizer
+
+    tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+    def create_batches(messages: Sequence[dict]) -> list[dict]:
+        def get_token_length(message: dict) -> int:
+            ids = tokenizer.apply_chat_template(message["messages"][1:], tokenize=True)
+            return len(ids)
+
+        lengths: list[int] = multi_thread(get_token_length, messages, workers=64)
+        list_ids: list[list[int]] = split_indices_by_length(
+            lengths,
+            batch_size,
+            random_seed=0,
+            verbose=True,
+            shuffle=True,
+            mean_length=mean_length,
+        )
+        concatenated_messages: list[dict] = []
+
+        def concatenate_messages(conversations: Sequence[Sequence[dict]]) -> dict:
+            system_message = conversations[0][0]
+            turns: list[dict] = []
+            for conv in conversations:
+                turns.extend(conv[1:])
+            return {"messages": [system_message] + turns}
+
+        for batch_ids in list_ids:
+            if not batch_ids:
+                continue
+            conversations = [messages[i]["messages"] for i in batch_ids]
+            concatenated_messages.append(concatenate_messages(conversations))
+        return concatenated_messages
+
+    chunked_messages: list[dict] = create_batches(messages)
+    return chunked_messages
+
+
+__all__ = [
+    "split_indices_by_length",
+    "group_messages_by_len",
+]
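`split_indices_by_length` is a pure function and easy to exercise in isolation; a small sketch with made-up lengths (`group_messages_by_len` additionally requires `transformers` and a tokenizer download):

```python
from llm_utils.group_messages import split_indices_by_length

lengths = [120, 480, 900, 300, 60, 512, 256]  # mean ~375 -> cap = 375 * 4 = 1500

batches = split_indices_by_length(
    lengths,
    batch_size_by_mean_length=4,
    random_seed=0,
    verbose=False,
    shuffle=False,
)
print(batches)  # [[0, 1, 2], [3, 4, 5, 6]] -> sums 1500 and 1128, both <= 1500
```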
llm_utils/lm/__init__.py ADDED
@@ -0,0 +1,8 @@
+from .lm import LM
+
+OAI_LM = LM
+
+__all__ = [
+    "LM",
+    "OAI_LM",
+]