speedy-utils 1.1.5__py3-none-any.whl → 1.1.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
llm_utils/__init__.py CHANGED
@@ -10,7 +10,6 @@ from .chat_format import (
      transform_messages_to_chatml,
  )
  from .lm.async_lm import AsyncLLMTask, AsyncLM
- from .lm.sync_lm import LM, LLMTask

  __all__ = [
      "transform_messages",
@@ -21,10 +20,7 @@ __all__ = [
      "display_conversations",
      "build_chatml_input",
      "format_msgs",
-     # "group_messages_by_len",
-     "LM",
-     "AsyncLM",
      "display_chat_messages_as_html",
-     "LLMTask",
+     "AsyncLM",
      "AsyncLLMTask",
  ]
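The practical effect of these two hunks: the synchronous `LM` and `LLMTask` classes are no longer re-exported from the package root, so root-level imports of them break on upgrade. A minimal migration sketch, assuming the async classes are the intended replacements (the diff does not show whether `llm_utils.lm.sync_lm` still exists at its old path):

# 1.1.5 (worked): sync and async classes both re-exported from the root
# from llm_utils import LM, LLMTask, AsyncLM, AsyncLLMTask

# 1.1.7: only the async pair remains in the root __all__
from llm_utils import AsyncLM, AsyncLLMTask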
@@ -1,7 +1,9 @@
  from __future__ import annotations
+
+ from difflib import SequenceMatcher
  from typing import Any, Optional
+
  from IPython.display import HTML, display
- from difflib import SequenceMatcher


  def show_chat(
@@ -19,6 +21,17 @@ def show_chat(
          isinstance(msg, dict) and "role" in msg and "content" in msg for msg in msgs
      ), "The input format is not recognized. Please specify the input format."

+     if isinstance(msgs[-1], dict) and "choices" in msgs[-1]:
+         message = msgs[-1]["choices"][0]["message"]
+         reasoning_content = message.get("reasoning_content")
+         content = message.get("content", "")
+         if reasoning_content:
+             content = reasoning_content + "\n" + content
+         msgs[-1] = {
+             "role": message["role"],
+             "content": content,
+         }
+
      themes: dict[str, dict[str, dict[str, str]]] = {
          "default": {
              "system": {"background": "#ffaaaa", "text": "#222222"},  # More red
@@ -156,9 +169,9 @@ def get_conversation_one_turn(
      if assistant_msg is not None:
          messages.append({"role": "assistant", "content": assistant_msg})
      if assistant_prefix is not None:
-         assert (
-             return_format != "chatml"
-         ), 'Change return_format to "text" if you want to use assistant_prefix'
+         assert return_format != "chatml", (
+             'Change return_format to "text" if you want to use assistant_prefix'
+         )
          assert messages[-1]["role"] == "user"
          from .transform import transform_messages

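For orientation, a hedged call sketch: the hunk implies `get_conversation_one_turn` takes `assistant_msg`, `assistant_prefix`, and `return_format`, and that a prefix is only legal with text output. The other parameter name below is a guess, since the full signature is not in this diff:

convo = get_conversation_one_turn(
    user_msg="Name three primes.",  # hypothetical parameter name
    assistant_prefix="Sure: ",      # seeds the assistant turn
    return_format="text",           # "chatml" would trip the assert above
)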
@@ -16,9 +16,9 @@ def identify_format(item):
  def _transform_sharegpt_to_chatml(
      item, default_system_message="You are a helpful assistant.", print_msg=False
  ):
-     assert isinstance(
-         item, dict
-     ), "The item is not in the correct format. Please check the format of the item."
+     assert isinstance(item, dict), (
+         "The item is not in the correct format. Please check the format of the item."
+     )

      messages = []
      system_msg = item.get("system", "")
@@ -116,16 +116,16 @@ def transform_messages_to_chatml(input_data, input_format="auto"):
      input_data = deepcopy(input_data)
      if isinstance(input_data, list):
          input_format = "chatlm"
-         assert (
-             input_data[0].get("role") is not None
-         ), "The input format is not recognized. Please specify the input format."
+         assert input_data[0].get("role") is not None, (
+             "The input format is not recognized. Please specify the input format."
+         )
      elif isinstance(input_data, dict):
          input_data = _transform_sharegpt_to_chatml(input_data)
          input_format = "sharegpt"
      elif isinstance(input_data, str):
-         assert (
-             "<|im_end|>" in input_data
-         ), "The input format is not recognized. Please specify the input format."
+         assert "<|im_end|>" in input_data, (
+             "The input format is not recognized. Please specify the input format."
+         )
          input_format = "chatlm"
          parts = input_data.split("<|im_end|>")
          input_data = []
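The dispatch above, restated as calls (note the literal format string is "chatlm" in the source, not "chatml"). The ShareGPT dict layout beyond the `system` key is not shown in this diff, so that example is schematic:

transform_messages_to_chatml([{"role": "user", "content": "hi"}])     # list -> role-keyed messages
transform_messages_to_chatml({"system": "...", "conversations": []})  # dict -> ShareGPT (keys assumed)
transform_messages_to_chatml("<|im_start|>user\nhi<|im_end|>")        # str -> split on "<|im_end|>"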
@@ -76,7 +76,7 @@ def group_messages_by_len(
      """
      if messages is None:
          raise ValueError("messages parameter cannot be None")
-     from transformers.models.auto.tokenization_auto import AutoTokenizer # type: ignore
+     from transformers.models.auto.tokenization_auto import AutoTokenizer  # type: ignore

      tokenizer = AutoTokenizer.from_pretrained(model_name)

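Only the comment spacing changed here, but the hunk does show how `group_messages_by_len` loads its tokenizer: lazily, keyed by `model_name`. A hedged call sketch (parameters other than `messages` and `model_name` are not visible in this diff):

groups = group_messages_by_len(
    messages=[{"role": "user", "content": "hello"}],
    model_name="Qwen/Qwen2.5-7B-Instruct",  # hypothetical model id
)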
@@ -0,0 +1,7 @@
+ from .async_llm_task import AsyncLLMTask
+ from .async_lm import AsyncLM
+
+ __all__ = [
+     "AsyncLM",
+     "AsyncLLMTask",
+ ]
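This added file's path is not shown in the diff, but its relative imports suggest it is `llm_utils/lm/__init__.py`; if so, the subpackage itself becomes a usable import target. A sketch under that assumption:

from llm_utils.lm import AsyncLM           # via the new __init__
from llm_utils.lm.async_lm import AsyncLM  # pre-existing module path, equivalent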
@@ -0,0 +1,201 @@
+ from functools import lru_cache
+ from typing import (
+     Any,
+     Dict,
+     Generic,
+     List,
+     TypeVar,
+     Union,
+ )
+
+ # from openai.pagination import AsyncSyncPage
+ from openai.types.chat import (
+     ChatCompletionMessageParam,
+ )
+ from pydantic import BaseModel
+ from typing_extensions import TypedDict
+
+ # --------------------------------------------------------------------------- #
+ # type helpers
+ # --------------------------------------------------------------------------- #
+ TModel = TypeVar("TModel", bound=BaseModel)
+ Messages = List[ChatCompletionMessageParam]
+ LegacyMsgs = List[Dict[str, str]]
+ RawMsgs = Union[Messages, LegacyMsgs]
+
+ # --------------------------------------------------------------------------- #
+ # color helpers (unchanged)
+ # --------------------------------------------------------------------------- #
+
+
+ def _color(code: int, text: str) -> str:
+     return f"\x1b[{code}m{text}\x1b[0m"
+
+
+ def _red(t):
+     return _color(31, t)
+
+
+ def _green(t):
+     return _color(32, t)
+
+
+ def _blue(t):
+     return _color(34, t)
+
+
+ def _yellow(t):
+     return _color(33, t)
+
+
+ # TParsed = TypeVar("TParsed", bound=BaseModel)
+
+ InputModelType = TypeVar("InputModelType", bound=BaseModel)
+ OutputModelType = TypeVar("OutputModelType", bound=BaseModel)
+
+
+ class ParsedOutput(TypedDict, Generic[OutputModelType]):
+     messages: List
+     completion: Any
+     parsed: OutputModelType
+     model_kwargs: Dict[str, Any]
+
+
+ # --------------------------------------------------------------------------- #
+ # Module-level utility functions (async versions)
+ # --------------------------------------------------------------------------- #
+
+
+ @lru_cache(maxsize=10)
+ def get_tokenizer(model_name: str) -> Any:
+     """Get tokenizer for the given model."""
+     from transformers import AutoTokenizer  # type: ignore
+
+     tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
+     return tokenizer
+
+
+ async def inspect_word_probs_async(lm, tokenizer, messages):
+     """Async version of inspect_word_probs."""
+
+     import numpy as np
+
+     async def compute_word_log_probs(
+         tokenizer: Any,
+         lm_client: Any,
+     ) -> tuple[List[Dict[str, Any]], Any]:
+         # Build a prompt that preserves literal newlines
+         prompt = tokenizer.apply_chat_template(
+             messages,
+             tokenize=False,  # Don't tokenize yet, we need raw text
+             add_generation_prompt=False,  # No generation prompt needed
+         )
+
+         # Request token logprobs
+         response = await lm_client.client.completions.create(
+             model=lm_client.model,  # type: ignore
+             prompt=prompt,
+             max_tokens=1,
+             logprobs=1,
+             extra_body={"prompt_logprobs": 0},
+         )
+         token_logprob_dicts = response.choices[0].prompt_logprobs  # type: ignore
+
+         # Override first token to known start marker
+         start_id = tokenizer.encode("<|im_start|>")[0]
+         token_logprob_dicts[0] = {
+             str(start_id): {
+                 "logprob": -1,
+                 "rank": 1,
+                 "decoded_token": "<|im_start|>",
+             }
+         }
+
+         # Flatten tokens
+         tokens: List[Dict[str, Any]] = [
+             {"id": int(tid), **tdata}
+             for td in token_logprob_dicts
+             for tid, tdata in td.items()
+         ]
+
+         # Validate tokenization
+         tokenized = tokenizer.tokenize(prompt)
+         if len(tokenized) != len(tokens):
+             raise ValueError(f"Token count mismatch: {len(tokenized)} vs {len(tokens)}")
+         for idx, tok in enumerate(tokens):
+             if tokenized[idx] != tok["decoded_token"]:
+                 raise AssertionError(
+                     f"Token mismatch at {idx}: "
+                     f"{tokenized[idx]} != {tok['decoded_token']}"
+                 )
+
+         # Split on newline sentinel
+         split_prompt = prompt.replace("\n", " <NL> ")
+         words = split_prompt.split()
+
+         word_log_probs: List[Dict[str, Any]] = []
+         token_idx = 0
+
+         for word in words:
+             # Map sentinel back to actual newline for encoding
+             target = "\n" if word == "<NL>" else word
+             sub_ids = tokenizer.encode(target, add_special_tokens=False)
+             count = len(sub_ids)
+             if count == 0:
+                 continue
+
+             subs = tokens[token_idx : token_idx + count]
+             avg_logprob = sum(s["logprob"] for s in subs) / count
+             prob = float(np.exp(avg_logprob))
+             word_log_probs.append({"word": target, "probability": prob})
+             token_idx += count
+
+         return word_log_probs, token_logprob_dicts  # type: ignore
+
+     def render_by_logprob(word_log_probs: List[Dict[str, Any]]) -> str:
+         """
+         Return an ANSI-colored string for word probabilities (red → green).
+         """
+         if not word_log_probs:
+             return ""
+
+         probs = [entry["probability"] for entry in word_log_probs]
+         min_p, max_p = min(probs), max(probs)
+         parts: List[str] = []
+
+         for entry in word_log_probs:
+             word = entry["word"]
+             # Preserve actual line breaks
+             if word == "\n":
+                 parts.append("\n")
+                 continue
+
+             p = entry["probability"]
+             norm = (p - min_p) / (max_p - min_p or 1.0)
+             r = int(255 * (1 - norm))  # red component (high when prob is low)
+             g = int(255 * norm)  # green component (high when prob is high)
+             b = 0  # no blue for red-green gradient
+             colored = f"\x1b[38;2;{r};{g};{b}m{word}\x1b[0m"
+             parts.append(colored + " ")
+
+         return "".join(parts).rstrip()
+
+     word_probs, token_logprob_dicts = await compute_word_log_probs(tokenizer, lm)
+     return word_probs, token_logprob_dicts, render_by_logprob(word_probs)
+
+
+ __all__ = [
+     "TModel",
+     "Messages",
+     "LegacyMsgs",
+     "RawMsgs",
+     "ParsedOutput",
+     "get_tokenizer",
+     "inspect_word_probs_async",
+     "_color",
+     "_red",
+     "_green",
+     "_blue",
+     "_yellow",
+ ]
+ # --------------------------------------------------------------------------- #
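A hedged usage sketch for the pair above. It assumes `lm` is an AsyncLM-style object exposing `.client` (an async OpenAI-compatible client) and `.model`, since the function body dereferences both, and that the backing server honors the vLLM `prompt_logprobs` extension passed via `extra_body`; the model id is a placeholder:

import asyncio

async def main(lm) -> None:  # lm: AsyncLM-like, constructed elsewhere
    tokenizer = get_tokenizer("Qwen/Qwen2.5-7B-Instruct")  # placeholder model id
    messages = [{"role": "user", "content": "2 + 2 = ?"}]
    word_probs, token_dicts, rendered = await inspect_word_probs_async(
        lm, tokenizer, messages
    )
    print(rendered)  # each word ANSI-colored red -> green by probability

# asyncio.run(main(AsyncLM(...)))  # constructor signature not shown in this diff

Note the per-word value is exp of the mean token logprob over the word's sub-tokens, i.e. the geometric mean of the sub-token probabilities, so multi-token words are smoothed rather than multiplied down.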