speedy-utils 1.1.27__py3-none-any.whl → 1.1.29__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_utils/__init__.py +16 -4
- llm_utils/chat_format/__init__.py +10 -10
- llm_utils/chat_format/display.py +33 -21
- llm_utils/chat_format/transform.py +17 -19
- llm_utils/chat_format/utils.py +6 -4
- llm_utils/group_messages.py +17 -14
- llm_utils/lm/__init__.py +6 -5
- llm_utils/lm/async_lm/__init__.py +1 -0
- llm_utils/lm/async_lm/_utils.py +10 -9
- llm_utils/lm/async_lm/async_llm_task.py +141 -137
- llm_utils/lm/async_lm/async_lm.py +48 -42
- llm_utils/lm/async_lm/async_lm_base.py +59 -60
- llm_utils/lm/async_lm/lm_specific.py +4 -3
- llm_utils/lm/base_prompt_builder.py +93 -70
- llm_utils/lm/llm.py +126 -108
- llm_utils/lm/llm_signature.py +4 -2
- llm_utils/lm/lm_base.py +72 -73
- llm_utils/lm/mixins.py +102 -62
- llm_utils/lm/openai_memoize.py +124 -87
- llm_utils/lm/signature.py +105 -92
- llm_utils/lm/utils.py +42 -23
- llm_utils/scripts/vllm_load_balancer.py +23 -30
- llm_utils/scripts/vllm_serve.py +8 -7
- llm_utils/vector_cache/__init__.py +9 -3
- llm_utils/vector_cache/cli.py +1 -1
- llm_utils/vector_cache/core.py +59 -63
- llm_utils/vector_cache/types.py +7 -5
- llm_utils/vector_cache/utils.py +12 -8
- speedy_utils/__imports.py +244 -0
- speedy_utils/__init__.py +90 -194
- speedy_utils/all.py +125 -227
- speedy_utils/common/clock.py +37 -42
- speedy_utils/common/function_decorator.py +6 -12
- speedy_utils/common/logger.py +43 -52
- speedy_utils/common/notebook_utils.py +13 -21
- speedy_utils/common/patcher.py +21 -17
- speedy_utils/common/report_manager.py +42 -44
- speedy_utils/common/utils_cache.py +152 -169
- speedy_utils/common/utils_io.py +137 -103
- speedy_utils/common/utils_misc.py +15 -21
- speedy_utils/common/utils_print.py +22 -28
- speedy_utils/multi_worker/process.py +66 -79
- speedy_utils/multi_worker/thread.py +78 -155
- speedy_utils/scripts/mpython.py +38 -36
- speedy_utils/scripts/openapi_client_codegen.py +10 -10
- {speedy_utils-1.1.27.dist-info → speedy_utils-1.1.29.dist-info}/METADATA +1 -1
- speedy_utils-1.1.29.dist-info/RECORD +57 -0
- vision_utils/README.md +202 -0
- vision_utils/__init__.py +4 -0
- vision_utils/io_utils.py +735 -0
- vision_utils/plot.py +345 -0
- speedy_utils-1.1.27.dist-info/RECORD +0 -52
- {speedy_utils-1.1.27.dist-info → speedy_utils-1.1.29.dist-info}/WHEEL +0 -0
- {speedy_utils-1.1.27.dist-info → speedy_utils-1.1.29.dist-info}/entry_points.txt +0 -0
llm_utils/__init__.py
CHANGED
@@ -1,8 +1,19 @@
+from llm_utils.lm import (
+    LLM,
+    AsyncLLMTask,
+    AsyncLM,
+    Input,
+    InputField,
+    LLMSignature,
+    Output,
+    OutputField,
+    Signature,
+)
+from llm_utils.lm.base_prompt_builder import BasePromptBuilder
+from llm_utils.lm.lm_base import get_model_name
 from llm_utils.lm.openai_memoize import MOpenAI
-from llm_utils.lm import LLM, AsyncLM, AsyncLLMTask, LLMSignature, Signature, InputField, OutputField, Input, Output
 from llm_utils.vector_cache import VectorCache
-
-from llm_utils.lm.base_prompt_builder import BasePromptBuilder
+


 LLM_TASK = LLM

@@ -24,13 +35,14 @@ from llm_utils.chat_format import (
     display_conversations,
     format_msgs,
     get_conversation_one_turn,
-    show_chat_v2,
     show_chat,
+    show_chat_v2,
     show_string_diff,
     transform_messages,
     transform_messages_to_chatml,
 )

+
 __all__ = [
     "transform_messages",
     "transform_messages_to_chatml",
llm_utils/chat_format/__init__.py
CHANGED

@@ -1,17 +1,17 @@
-from .transform import (
-    identify_format,
-    _transform_sharegpt_to_chatml,
-    transform_messages,
-    transform_messages_to_chatml,
-)
 from .display import (
-
+    display_chat_messages_as_html,
+    display_conversations,
     get_conversation_one_turn,
     highlight_diff_chars,
-
-    display_conversations,
-    display_chat_messages_as_html,
+    show_chat,
     show_chat_v2,
+    show_string_diff,
+)
+from .transform import (
+    _transform_sharegpt_to_chatml,
+    identify_format,
+    transform_messages,
+    transform_messages_to_chatml,
 )
 from .utils import (
     build_chatml_input,
llm_utils/chat_format/display.py
CHANGED
@@ -77,7 +77,7 @@ def show_chat(
     theme: str = "default",
     as_markdown: bool = False,
     as_json: bool = False,
-) ->
+) -> str | None:
     """
     Display chat messages as HTML.

@@ -168,7 +168,10 @@ def show_chat(
         content = content.replace("\t", " ")
         content = content.replace(" ", " ")
         content = (
-            content.replace("<br>", "TEMP_BR")
+            content.replace("<br>", "TEMP_BR")
+            .replace("<", "&lt;")
+            .replace(">", "&gt;")
+            .replace("TEMP_BR", "<br>")
         )
         if role in color_scheme:
             background_color = color_scheme[role]["background"]

@@ -239,15 +242,15 @@
         f.write(html)
     if return_html:
         return html
-
-
+    display(HTML(html))
+    return None


 def get_conversation_one_turn(
-    system_msg:
-    user_msg:
-    assistant_msg:
-    assistant_prefix:
+    system_msg: str | None = None,
+    user_msg: str | None = None,
+    assistant_msg: str | None = None,
+    assistant_prefix: str | None = None,
     return_format: str = "chatml",
 ) -> Any:
     """

@@ -261,7 +264,9 @@ def get_conversation_one_turn(
     if assistant_msg is not None:
         messages.append({"role": "assistant", "content": assistant_msg})
     if assistant_prefix is not None:
-        assert
+        assert (
+            return_format != "chatml"
+        ), 'Change return_format to "text" if you want to use assistant_prefix'
         assert messages[-1]["role"] == "user"
         from .transform import transform_messages

@@ -270,9 +275,8 @@ def get_conversation_one_turn(
         msg = str(msg)
         msg += assistant_prefix
         return msg
-
-
-    return messages
+    assert return_format in ["chatml"]
+    return messages


 def highlight_diff_chars(text1: str, text2: str) -> str:

@@ -286,13 +290,21 @@ def highlight_diff_chars(text1: str, text2: str) -> str:
            html.append(text1[i1:i2])
        elif tag == "replace":
            if i1 != i2:
-                html.append(
+                html.append(
+                    f'<span style="background-color:#ffd6d6; color:#b20000;">{text1[i1:i2]}</span>'
+                )
            if j1 != j2:
-                html.append(
+                html.append(
+                    f'<span style="background-color:#d6ffd6; color:#006600;">{text2[j1:j2]}</span>'
+                )
        elif tag == "delete":
-            html.append(
+            html.append(
+                f'<span style="background-color:#ffd6d6; color:#b20000;">{text1[i1:i2]}</span>'
+            )
        elif tag == "insert":
-            html.append(
+            html.append(
+                f'<span style="background-color:#d6ffd6; color:#006600;">{text2[j1:j2]}</span>'
+            )
    return "".join(html)


@@ -321,7 +333,7 @@ def show_chat_v2(messages: list[dict[str, str]]):

    if is_notebook:
        # Use HTML display in notebook
-        from IPython.display import
+        from IPython.display import HTML, display

        role_colors = {
            "system": "red",

@@ -353,9 +365,7 @@ def show_chat_v2(messages: list[dict[str, str]]):
            html += f"<div style='color:{color}'><strong>{label}</strong><br>{content}</div>"
            # Add separator except after last message
            if i < len(messages) - 1:
-                html +=
-                    "<div style='color:#888; margin:0.5em 0;'>───────────────────────────────────────────────────</div>"
-                )
+                html += "<div style='color:#888; margin:0.5em 0;'>───────────────────────────────────────────────────</div>"
        html += "</div>"

        display(HTML(html))

@@ -385,7 +395,9 @@ def show_chat_v2(messages: list[dict[str, str]]):
            print(f"{color}{content}{reset}")
            # Add separator except after last message
            if i < len(messages) - 1:
-                print(
+                print(
+                    f"{separator_color}─────────────────────────────────────────────────────────{reset}"
+                )


 def display_conversations(data1: Any, data2: Any, theme: str = "light") -> None:
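The `get_conversation_one_turn` and `highlight_diff_chars` signatures above are enough for a short usage sketch. Per the assertion in the hunks, `assistant_prefix` only works with a non-"chatml" `return_format`; the rest of the call below is illustrative, not taken from the package docs:

```python
from llm_utils.chat_format import get_conversation_one_turn, highlight_diff_chars

# assistant_prefix requires return_format="text"; "chatml" would trip the assert above.
prompt = get_conversation_one_turn(
    system_msg="You are a helpful assistant.",
    user_msg="List three prime numbers.",
    assistant_prefix="Sure, here are three primes:",
    return_format="text",
)

# Character-level diff rendered as HTML <span> highlights (red = removed, green = inserted).
html = highlight_diff_chars("the quick brown fox", "the quick red fox")
```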
llm_utils/chat_format/transform.py
CHANGED

@@ -1,13 +1,13 @@
 from __future__ import annotations
+
 from copy import deepcopy


 def identify_format(item):
     if isinstance(item, list) and "role" in item[0]:
         return "chatml"
-    if isinstance(item, dict):
-
-        return "sharegpt"
+    if isinstance(item, dict) and "conversations" in item:
+        return "sharegpt"
     raise ValueError(
         f"The format of the item is not recognized. \n{type(item)=}, \n{item=}"
     )

@@ -16,9 +16,9 @@ def identify_format(item):
 def _transform_sharegpt_to_chatml(
     item, default_system_message="You are a helpful assistant.", print_msg=False
 ):
-    assert isinstance(
-
-    )
+    assert isinstance(
+        item, dict
+    ), "The item is not in the correct format. Please check the format of the item."

     messages = []
     system_msg = item.get("system", "")

@@ -82,16 +82,16 @@ def transform_messages(
                 {"from": message["role"], "value": message["content"]}
             )
         return ret
-
+    if to == "chatml":
         return _transform_sharegpt_to_chatml(item)
-
+    if to == "text":
         text = ""
         for turn in chatml_messages:
             text += f"<|im_start|>{turn['role']}\n{turn['content']}<|im_end|>\n"
         if add_generation_prompt:
             text += "<|im_start|>assistant\n"
         return text
-
+    if to == "simulated_chat":
         text = "<role> Given the simulated chat, you are the assistant. Lets continue the conversation. \n\n"
         for turn in chatml_messages:
             prefix = {

@@ -104,11 +104,9 @@ def transform_messages(
         if add_generation_prompt:
             text += "AI: [continue the conversation here]"
         return text
-
-        raise ValueError(f"{to} is not supported.")
+    raise ValueError(f"{to} is not supported.")

-
-        return item
+    return item


 def transform_messages_to_chatml(input_data, input_format="auto"):

@@ -116,16 +114,16 @@ def transform_messages_to_chatml(input_data, input_format="auto"):
     input_data = deepcopy(input_data)
     if isinstance(input_data, list):
         input_format = "chatlm"
-        assert
-        "
-        )
+        assert (
+            input_data[0].get("role") is not None
+        ), "The input format is not recognized. Please specify the input format."
     elif isinstance(input_data, dict):
         input_data = _transform_sharegpt_to_chatml(input_data)
         input_format = "sharegpt"
     elif isinstance(input_data, str):
-        assert
-        "
-        )
+        assert (
+            "<|im_end|>" in input_data
+        ), "The input format is not recognized. Please specify the input format."
         input_format = "chatlm"
         parts = input_data.split("<|im_end|>")
         input_data = []
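A short sketch of the conversions handled above; the first positional argument name and the exact ShareGPT fields consumed by `_transform_sharegpt_to_chatml` are assumptions based only on these hunks:

```python
from llm_utils.chat_format import transform_messages, transform_messages_to_chatml

chatml = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hi!"},
]

# "text" renders ChatML markers; add_generation_prompt appends an open assistant turn.
prompt = transform_messages(chatml, to="text", add_generation_prompt=True)
# "<|im_start|>system\n...<|im_end|>\n<|im_start|>user\nHi!<|im_end|>\n<|im_start|>assistant\n"

# A ShareGPT-style dict (identified by its "conversations" key) converts back to ChatML.
sharegpt = {"conversations": [{"from": "human", "value": "Hi!"}]}
chatml_again = transform_messages_to_chatml(sharegpt)
```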
llm_utils/chat_format/utils.py
CHANGED
@@ -1,10 +1,12 @@
 from __future__ import annotations
-from typing import List, Dict, Callable

+from collections.abc import Callable
+from typing import Dict, List

-
-
-
+
+def build_chatml_input(template: str, params: list[str]) -> Callable:
+    def formator(**kwargs) -> list[list[dict[str, str]]]:
+        system_msg = kwargs.get("system_msg")
         kwargs.pop("system_msg", None)
         for param in params:
             if param not in kwargs:
llm_utils/group_messages.py
CHANGED
@@ -1,7 +1,8 @@
 from __future__ import annotations

 import random
-from
+from collections.abc import Sequence
+from typing import Optional, cast

 import numpy as np
 import pandas as pd

@@ -16,7 +17,7 @@ def split_indices_by_length(
     random_seed: int,
     verbose: bool,
     shuffle: bool,
-    mean_length:
+    mean_length: int | None = None,
 ) -> list[list[int]]:
     """
     Split indices into batches so that the sum of lengths in each batch does not exceed max_batch_length.

@@ -55,19 +56,19 @@
         desc = pd.Series(batch_lengths).describe()

         table = [
-            [
-            [
-            [
+            ['New avg item len', desc['mean']],
+            ['Number groups', len(batches)],
+            ['Max length', max_batch_length],
         ]

-        print(tabulate(table, headers=[
+        print(tabulate(table, headers=['Metric', 'Value'], tablefmt='pretty'))

     return batches


 def group_messages_by_len(
     messages: Sequence[dict],
-    model_name: str =
+    model_name: str = 'Qwen/Qwen2.5-7B-Instruct',
     batch_size: int = 4,
     mean_length: int = 512,
 ) -> list[dict]:

@@ -75,17 +76,19 @@ def group_messages_by_len(
     Groups messages into batches based on token length and concatenates them.
     """
     if messages is None:
-        raise ValueError(
+        raise ValueError('messages parameter cannot be None')
     from transformers.models.auto.tokenization_auto import AutoTokenizer  # type: ignore

     tokenizer = AutoTokenizer.from_pretrained(model_name)

     def create_batches(messages: Sequence[dict]) -> list[dict]:
         def get_token_length(message: dict) -> int:
-            ids = tokenizer.apply_chat_template(message[
+            ids = tokenizer.apply_chat_template(message['messages'][1:], tokenize=True)
             return len(ids)

-        lengths: list[int] =
+        lengths: list[int] = cast(
+            list[int], multi_thread(get_token_length, messages, workers=64)
+        )
         list_ids: list[list[int]] = split_indices_by_length(
             lengths,
             batch_size,

@@ -101,12 +104,12 @@ def group_messages_by_len(
             turns: list[dict] = []
             for conv in conversations:
                 turns.extend(conv[1:])
-            return {
+            return {'messages': [system_message] + turns}

         for batch_ids in list_ids:
             if not batch_ids:
                 continue
-            conversations = [messages[i][
+            conversations = [messages[i]['messages'] for i in batch_ids]
             concatenated_messages.append(concatenate_messages(conversations))
         return concatenated_messages

@@ -115,6 +118,6 @@ def group_messages_by_len(


 __all__ = [
-
-
+    'split_indices_by_length',
+    'group_messages_by_len',
 ]
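The signature and the `message['messages'][1:]` access above imply that each input item wraps a full conversation whose first turn is the system message. A sketch of that expected shape (downloading the tokenizer requires `transformers` and network access; the batching behaviour is as described in the docstring, not verified here):

```python
from llm_utils.group_messages import group_messages_by_len

items = [
    {
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": f"Question {i}"},
            {"role": "assistant", "content": f"Answer {i}"},
        ]
    }
    for i in range(8)
]

# Groups conversations by token length and concatenates them under one system message.
grouped = group_messages_by_len(items, batch_size=4, mean_length=512)
```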
llm_utils/lm/__init__.py
CHANGED
@@ -1,16 +1,17 @@
-from .async_lm.async_lm import AsyncLM
 from .async_lm.async_llm_task import AsyncLLMTask
-from .
-from .llm import LLM
+from .async_lm.async_lm import AsyncLM
 from .base_prompt_builder import BasePromptBuilder
+from .llm import LLM
 from .llm_signature import LLMSignature
-from .
+from .lm_base import LMBase, get_model_name
 from .mixins import (
+    ModelUtilsMixin,
     TemperatureRangeMixin,
     TwoStepPydanticMixin,
     VLLMMixin,
-    ModelUtilsMixin,
 )
+from .signature import Input, InputField, Output, OutputField, Signature
+

 __all__ = [
     "LMBase",
llm_utils/lm/async_lm/_utils.py
CHANGED
@@ -15,12 +15,13 @@ from openai.types.chat import (
 from pydantic import BaseModel
 from typing_extensions import TypedDict

+
 # --------------------------------------------------------------------------- #
 # type helpers
 # --------------------------------------------------------------------------- #
 TModel = TypeVar("TModel", bound=BaseModel)
-Messages =
-LegacyMsgs =
+Messages = list[ChatCompletionMessageParam]
+LegacyMsgs = list[dict[str, str]]
 RawMsgs = Union[Messages, LegacyMsgs]

 # --------------------------------------------------------------------------- #

@@ -55,10 +56,10 @@ OutputModelType = TypeVar("OutputModelType", bound=BaseModel)


 class ParsedOutput(TypedDict, Generic[OutputModelType]):
-    messages:
+    messages: list
     completion: Any
     parsed: OutputModelType
-    model_kwargs:
+    model_kwargs: dict[str, Any]


 # --------------------------------------------------------------------------- #

@@ -83,7 +84,7 @@ async def inspect_word_probs_async(lm, tokenizer, messages):
     async def compute_word_log_probs(
         tokenizer: Any,
         lm_client: Any,
-    ) -> tuple[
+    ) -> tuple[list[dict[str, Any]], Any]:
         # Build a prompt that preserves literal newlines
         prompt = tokenizer.apply_chat_template(
             messages,

@@ -112,7 +113,7 @@ async def inspect_word_probs_async(lm, tokenizer, messages):
         }

         # Flatten tokens
-        tokens:
+        tokens: list[dict[str, Any]] = [
             {"id": int(tid), **tdata}
             for td in token_logprob_dicts
             for tid, tdata in td.items()

@@ -133,7 +134,7 @@ async def inspect_word_probs_async(lm, tokenizer, messages):
         split_prompt = prompt.replace("\n", " <NL> ")
         words = split_prompt.split()

-        word_log_probs:
+        word_log_probs: list[dict[str, Any]] = []
         token_idx = 0

         for word in words:

@@ -152,7 +153,7 @@ async def inspect_word_probs_async(lm, tokenizer, messages):

         return word_log_probs, token_logprob_dicts  # type: ignore

-    def render_by_logprob(word_log_probs:
+    def render_by_logprob(word_log_probs: list[dict[str, Any]]) -> str:
         """
         Return an ANSI-colored string for word probabilities (red → green).
         """

@@ -161,7 +162,7 @@ async def inspect_word_probs_async(lm, tokenizer, messages):

         probs = [entry["probability"] for entry in word_log_probs]
         min_p, max_p = min(probs), max(probs)
-        parts:
+        parts: list[str] = []

         for entry in word_log_probs:
             word = entry["word"]