speedy-utils 1.1.6__py3-none-any.whl → 1.1.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,407 @@
+ # from ._utils import *
+ import base64
+ import hashlib
+ import json
+ import os
+ from typing import (
+     Any,
+     List,
+     Optional,
+     Type,
+     Union,
+     cast,
+     overload,
+ )
+
+ from httpx import URL
+ from loguru import logger
+ from openai import AsyncOpenAI
+ from openai.pagination import AsyncPage as AsyncSyncPage
+ from openai.types.chat import (
+     ChatCompletionAssistantMessageParam,
+     ChatCompletionSystemMessageParam,
+     ChatCompletionToolMessageParam,
+     ChatCompletionUserMessageParam,
+ )
+ from openai.types.model import Model
+ from pydantic import BaseModel
+
+ from ._utils import (
+     LegacyMsgs,
+     Messages,
+     RawMsgs,
+     TModel,
+ )
+
+
+ class AsyncLMBase:
+     """Unified **async** language-model wrapper with optional JSON parsing."""
+
+     def __init__(
+         self,
+         *,
+         host: str = "localhost",
+         port: Optional[int | str] = None,
+         base_url: Optional[str] = None,
+         api_key: Optional[str] = None,
+         cache: bool = True,
+         ports: Optional[List[int]] = None,
+     ) -> None:
+         self._port = port
+         self._host = host
+         self.base_url = base_url or (f"http://{host}:{port}/v1" if port else None)
+         self.api_key = api_key or os.getenv("OPENAI_API_KEY", "abc")
+         self._cache = cache
+         self.ports = ports
+         self._init_port = port  # store the port provided at init
+         self._last_client: Optional[AsyncOpenAI] = None
+
+     @property
+     def client(self) -> AsyncOpenAI:
+         # If multiple ports are configured, pick one at random per request.
+         if self.ports:
+             import random
+
+             port = random.choice(self.ports)
+             api_base = f"http://{self._host}:{port}/v1"
+             logger.debug(f"Using port: {port}")
+         else:
+             api_base = self.base_url or f"http://{self._host}:{self._port}/v1"
+         client = AsyncOpenAI(
+             api_key=self.api_key,
+             base_url=api_base,
+         )
+         self._last_client = client
+         return client
+
+     # ------------------------------------------------------------------ #
+     # Public API – typed overloads
+     # ------------------------------------------------------------------ #
+     @overload
+     async def __call__(  # type: ignore
+         self,
+         *,
+         prompt: str | None = ...,
+         messages: RawMsgs | None = ...,
+         response_format: type[str] = str,
+         return_openai_response: bool = ...,
+         **kwargs: Any,
+     ) -> str: ...
+
+     @overload
+     async def __call__(
+         self,
+         *,
+         prompt: str | None = ...,
+         messages: RawMsgs | None = ...,
+         response_format: Type[TModel],
+         return_openai_response: bool = ...,
+         **kwargs: Any,
+     ) -> TModel: ...
+
+     # ------------------------------------------------------------------ #
+     # Utilities below are unchanged (sync I/O is acceptable)
+     # ------------------------------------------------------------------ #
+     @staticmethod
+     def _convert_messages(msgs: LegacyMsgs) -> Messages:
+         converted: Messages = []
+         for msg in msgs:
+             role = msg["role"]
+             content = msg["content"]
+             if role == "user":
+                 converted.append(
+                     ChatCompletionUserMessageParam(role="user", content=content)
+                 )
+             elif role == "assistant":
+                 converted.append(
+                     ChatCompletionAssistantMessageParam(
+                         role="assistant", content=content
+                     )
+                 )
+             elif role == "system":
+                 converted.append(
+                     ChatCompletionSystemMessageParam(role="system", content=content)
+                 )
+             elif role == "tool":
+                 converted.append(
+                     ChatCompletionToolMessageParam(
+                         role="tool",
+                         content=content,
+                         tool_call_id=msg.get("tool_call_id") or "",
+                     )
+                 )
+             else:
+                 converted.append({"role": role, "content": content})  # type: ignore[arg-type]
+         return converted
+
+     @staticmethod
+     def _parse_output(
+         raw_response: Any, response_format: Union[type[str], Type[BaseModel]]
+     ) -> str | BaseModel:
+         if hasattr(raw_response, "model_dump"):
+             raw_response = raw_response.model_dump()
+
+         if response_format is str:
+             if isinstance(raw_response, dict) and "choices" in raw_response:
+                 message = raw_response["choices"][0]["message"]
+                 return message.get("content", "") or ""
+             return cast(str, raw_response)
+
+         model_cls = cast(Type[BaseModel], response_format)
+
+         if isinstance(raw_response, dict) and "choices" in raw_response:
+             message = raw_response["choices"][0]["message"]
+             if "parsed" in message:
+                 return model_cls.model_validate(message["parsed"])
+             content = message.get("content")
+             if content is None:
+                 raise ValueError("Model returned empty content")
+             try:
+                 data = json.loads(content)
+                 return model_cls.model_validate(data)
+             except Exception as exc:
+                 raise ValueError(
+                     f"Failed to parse model output as JSON:\n{content}"
+                 ) from exc
+
+         if isinstance(raw_response, model_cls):
+             return raw_response
+         if isinstance(raw_response, dict):
+             return model_cls.model_validate(raw_response)
+
+         try:
+             data = json.loads(raw_response)
+             return model_cls.model_validate(data)
+         except Exception as exc:
+             raise ValueError(
+                 f"Model did not return valid JSON:\n---\n{raw_response}"
+             ) from exc
+
+     # ------------------------------------------------------------------ #
+     # Simple disk cache (sync)
+     # ------------------------------------------------------------------ #
+     @staticmethod
+     def _cache_key(
+         messages: Any, kw: Any, response_format: Union[type[str], Type[BaseModel]]
+     ) -> str:
+         tag = response_format.__name__ if response_format is not str else "text"
+         blob = json.dumps([messages, kw, tag], sort_keys=True).encode()
+         return base64.urlsafe_b64encode(hashlib.sha256(blob).digest()).decode()[:22]
+
+     @staticmethod
+     def _cache_path(key: str) -> str:
+         return os.path.expanduser(f"~/.cache/lm/{key}.json")
+
+     def _dump_cache(self, key: str, val: Any) -> None:
+         try:
+             path = self._cache_path(key)
+             os.makedirs(os.path.dirname(path), exist_ok=True)
+             with open(path, "w") as fh:
+                 if isinstance(val, BaseModel):
+                     json.dump(val.model_dump(mode="json"), fh)
+                 else:
+                     json.dump(val, fh)
+         except Exception as exc:
+             logger.debug(f"cache write skipped: {exc}")
+
+     def _load_cache(self, key: str) -> Any | None:
+         path = self._cache_path(key)
+         if not os.path.exists(path):
+             return None
+         try:
+             with open(path) as fh:
+                 return json.load(fh)
+         except Exception:
+             return None
+
+     # async def inspect_word_probs(
+     #     self,
+     #     messages: Optional[List[Dict[str, Any]]] = None,
+     #     tokenizer: Optional[Any] = None,
+     #     do_print=True,
+     #     add_think: bool = True,
+     # ) -> tuple[List[Dict[str, Any]], Any, str]:
+     #     """
+     #     Inspect word probabilities in a language model response.
+
+     #     Args:
+     #         tokenizer: Tokenizer instance to encode words.
+     #         messages: List of messages to analyze.
+
+     #     Returns:
+     #         A tuple containing:
+     #             - List of word probabilities with their log probabilities.
+     #             - Token log probability dictionaries.
+     #             - Rendered string with colored word probabilities.
+     #     """
+     #     if messages is None:
+     #         messages = await self.last_messages(add_think=add_think)
+     #     if messages is None:
+     #         raise ValueError("No messages provided and no last messages available.")
+
+     #     if tokenizer is None:
+     #         tokenizer = get_tokenizer(self.model)
+
+     #     ret = await inspect_word_probs_async(self, tokenizer, messages)
+     #     if do_print:
+     #         print(ret[-1])
+     #     return ret
+
+     # async def last_messages(
+     #     self, add_think: bool = True
+     # ) -> Optional[List[Dict[str, str]]]:
+     #     """Get the last conversation messages including assistant response."""
+     #     if not hasattr(self, "last_log"):
+     #         return None
+
+     #     last_conv = self._last_log
+     #     messages = last_conv[1] if len(last_conv) > 1 else None
+     #     last_msg = last_conv[2]
+     #     if not isinstance(last_msg, dict):
+     #         last_conv[2] = last_conv[2].model_dump()  # type: ignore
+     #     msg = last_conv[2]
+     #     # Ensure msg is a dict
+     #     if hasattr(msg, "model_dump"):
+     #         msg = msg.model_dump()
+     #     message = msg["choices"][0]["message"]
+     #     reasoning = message.get("reasoning_content")
+     #     answer = message.get("content")
+     #     if reasoning and add_think:
+     #         final_answer = f"<think>{reasoning}</think>\n{answer}"
+     #     else:
+     #         final_answer = f"<think>\n\n</think>\n{answer}"
+     #     assistant = {"role": "assistant", "content": final_answer}
+     #     messages = messages + [assistant]  # type: ignore
+     #     return messages if messages else None
+
+     # async def inspect_history(self) -> None:
+     #     """Inspect the conversation history with proper formatting."""
+     #     if not hasattr(self, "last_log"):
+     #         raise ValueError("No history available. Please call the model first.")
+
+     #     prompt, messages, response = self._last_log
+     #     if hasattr(response, "model_dump"):
+     #         response = response.model_dump()
+     #     if not messages:
+     #         messages = [{"role": "user", "content": prompt}]
+
+     #     print("\n\n")
+     #     print(_blue("[Conversation History]") + "\n")
+
+     #     for msg in messages:
+     #         role = msg["role"]
+     #         content = msg["content"]
+     #         print(_red(f"{role.capitalize()}:"))
+     #         if isinstance(content, str):
+     #             print(content.strip())
+     #         elif isinstance(content, list):
+     #             for item in content:
+     #                 if item.get("type") == "text":
+     #                     print(item["text"].strip())
+     #                 elif item.get("type") == "image_url":
+     #                     image_url = item["image_url"]["url"]
+     #                     if "base64" in image_url:
+     #                         len_base64 = len(image_url.split("base64,")[1])
+     #                         print(_blue(f"<IMAGE BASE64 ENCODED({len_base64})>"))
+     #                     else:
+     #                         print(_blue(f"<image_url: {image_url}>"))
+     #         print("\n")
+
+     #     print(_red("Response:"))
+     #     if isinstance(response, dict) and response.get("choices"):
+     #         message = response["choices"][0].get("message", {})
+     #         reasoning = message.get("reasoning_content")
+     #         parsed = message.get("parsed")
+     #         content = message.get("content")
+     #         if reasoning:
+     #             print(_yellow("<think>"))
+     #             print(reasoning.strip())
+     #             print(_yellow("</think>\n"))
+     #         if parsed:
+     #             print(
+     #                 json.dumps(
+     #                     (
+     #                         parsed.model_dump()
+     #                         if hasattr(parsed, "model_dump")
+     #                         else parsed
+     #                     ),
+     #                     indent=2,
+     #                 )
+     #                 + "\n"
+     #             )
+     #         elif content:
+     #             print(content.strip())
+     #         else:
+     #             print(_green("[No content]"))
+     #         if len(response["choices"]) > 1:
+     #             print(
+     #                 _blue(f"\n(Plus {len(response['choices']) - 1} other completions)")
+     #             )
+     #     else:
+     #         print(_yellow("Warning: Not a standard OpenAI response object"))
+     #         if isinstance(response, str):
+     #             print(_green(response.strip()))
+     #         elif isinstance(response, dict):
+     #             print(_green(json.dumps(response, indent=2)))
+     #         else:
+     #             print(_green(str(response)))
+
+     # ------------------------------------------------------------------ #
+     # Misc helpers
+     # ------------------------------------------------------------------ #
+
+     @staticmethod
+     async def list_models(port=None, host="localhost") -> List[str]:
+         try:
+             client = AsyncLMBase(port=port, host=host).client  # type: ignore[arg-type]
+             base_url: URL = client.base_url
+             logger.debug(f"Base URL: {base_url}")
+             models: AsyncSyncPage[Model] = await client.models.list()  # type: ignore[assignment]
+             return [model.id for model in models.data]
+         except Exception as exc:
+             logger.error(f"Failed to list models: {exc}")
+             return []
+
+     def build_system_prompt(
+         self,
+         response_model,
+         add_json_schema_to_instruction,
+         json_schema,
+         system_content,
+         think,
+     ):
+         if add_json_schema_to_instruction and response_model:
+             schema_block = f"\n\n<output_json_schema>\n{json.dumps(json_schema, indent=2)}\n</output_json_schema>"
+             # if schema_block not in system_content:
+             if "<output_json_schema>" in system_content:
+                 # remove the existing schema block before appending the new one
+                 import re
+
+                 system_content = re.sub(
+                     r"<output_json_schema>.*?</output_json_schema>",
+                     "",
+                     system_content,
+                     flags=re.DOTALL,
+                 )
+                 system_content = system_content.strip()
+             system_content += schema_block
+
+         if think is True:
+             if "/think" in system_content:
+                 pass
+             elif "/no_think" in system_content:
+                 system_content = system_content.replace("/no_think", "/think")
+             else:
+                 system_content += "\n\n/think"
+         elif think is False:
+             if "/no_think" in system_content:
+                 pass
+             elif "/think" in system_content:
+                 system_content = system_content.replace("/think", "/no_think")
+             else:
+                 system_content += "\n\n/no_think"
+         return system_content
+
+     async def inspect_history(self):
+         """Inspect the history of the LLM calls."""
+         pass
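
Note on the disk cache added above: `_cache_key` derives a deterministic key from the request payload (the first 22 characters of a URL-safe base64 SHA-256 digest over the JSON dump of `[messages, kwargs, tag]`, where `tag` is the response-format class name or `"text"`), and `_cache_path` maps that key to `~/.cache/lm/<key>.json`. A minimal standalone sketch of the same keying scheme; the free functions below are illustrative stand-ins for the methods above, not part of the package:

import base64
import hashlib
import json
import os

def cache_key(messages, kwargs, response_format=str) -> str:
    # Same recipe as AsyncLMBase._cache_key: JSON-dump the request, hash it,
    # and keep the first 22 characters of the URL-safe base64 digest.
    tag = response_format.__name__ if response_format is not str else "text"
    blob = json.dumps([messages, kwargs, tag], sort_keys=True).encode()
    return base64.urlsafe_b64encode(hashlib.sha256(blob).digest()).decode()[:22]

def cache_path(key: str) -> str:
    # Same layout as AsyncLMBase._cache_path.
    return os.path.expanduser(f"~/.cache/lm/{key}.json")

msgs = [{"role": "user", "content": "hello"}]
key = cache_key(msgs, {"temperature": 0.6})
print(key, cache_path(key))  # identical inputs always map to the same file

Because the key covers the messages, the sampling kwargs, and the response-format tag, changing any of them selects a different cache file.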
@@ -0,0 +1,136 @@
+ from typing import List
+
+ from .async_lm import AsyncLM
+
+ KNOWN_CONFIG = {
+     # Qwen3 family (see model card "Best Practices" section)
+     "qwen3-think": {
+         "sampling_params": {
+             "temperature": 0.6,
+             "top_p": 0.95,
+             "top_k": 20,
+             "min_p": 0.0,
+             "presence_penalty": 1.5,
+         },
+     },
+     "qwen3-no-think": {
+         "sampling_params": {
+             "temperature": 0.7,
+             "top_p": 0.8,
+             "top_k": 20,
+             "min_p": 0.0,
+             "presence_penalty": 1.5,
+         },
+     },
+     # DeepSeek V3 (model card: temperature=0.3)
+     "deepseek-v3": {
+         "sampling_params": {
+             "temperature": 0.3,
+         },
+     },
+     # DeepSeek R1 (model card: temperature=0.6, top_p=0.95)
+     "deepseek-r1": {
+         "sampling_params": {
+             "temperature": 0.6,
+             "top_p": 0.95,
+         },
+     },
+     # Mistral Small 3.2-24B Instruct (model card: temperature=0.15)
+     "mistral-small-3.2-24b-instruct-2506": {
+         "sampling_params": {
+             "temperature": 0.15,
+         },
+     },
+     # Magistral Small 2506 (model card: temperature=0.7, top_p=0.95)
+     "magistral-small-2506": {
+         "sampling_params": {
+             "temperature": 0.7,
+             "top_p": 0.95,
+         },
+     },
+     # Phi-4 Reasoning (model card: temperature=0.8, top_k=50, top_p=0.95)
+     "phi-4-reasoning": {
+         "sampling_params": {
+             "temperature": 0.8,
+             "top_k": 50,
+             "top_p": 0.95,
+         },
+     },
+     # GLM-Z1-32B-0414 (model card: temperature=0.6, top_p=0.95, top_k=40, max_new_tokens=30000)
+     "glm-z1-32b-0414": {
+         "sampling_params": {
+             "temperature": 0.6,
+             "top_p": 0.95,
+             "top_k": 40,
+             "max_new_tokens": 30000,
+         },
+     },
+     # Llama-4-Scout-17B-16E-Instruct (generation_config.json: temperature=0.6, top_p=0.9)
+     "llama-4-scout-17b-16e-instruct": {
+         "sampling_params": {
+             "temperature": 0.6,
+             "top_p": 0.9,
+         },
+     },
+     # Gemma-3-27b-it (alleged: temperature=1.0, top_k=64, top_p=0.96)
+     "gemma-3-27b-it": {
+         "sampling_params": {
+             "temperature": 1.0,
+             "top_k": 64,
+             "top_p": 0.96,
+         },
+     },
+     # Add more as needed...
+ }
+
+ KNOWN_KEYS: List[str] = list(KNOWN_CONFIG.keys())
87
+
88
+
89
+ class AsyncLMQwenThink(AsyncLM):
90
+ def __init__(
91
+ self,
92
+ model: str = "Qwen32B",
93
+ temperature: float = KNOWN_CONFIG["qwen3-think"]["sampling_params"][
94
+ "temperature"
95
+ ],
96
+ top_p: float = KNOWN_CONFIG["qwen3-think"]["sampling_params"]["top_p"],
97
+ top_k: int = KNOWN_CONFIG["qwen3-think"]["sampling_params"]["top_k"],
98
+ presence_penalty: float = KNOWN_CONFIG["qwen3-think"]["sampling_params"][
99
+ "presence_penalty"
100
+ ],
101
+ **other_kwargs,
102
+ ):
103
+ super().__init__(
104
+ model="qwen3-think",
105
+ temperature=temperature,
106
+ top_p=top_p,
107
+ top_k=top_k,
108
+ presence_penalty=presence_penalty,
109
+ **other_kwargs,
110
+ think=True
111
+ )
112
+
113
+
114
+ class AsyncLMQwenNoThink(AsyncLM):
115
+ def __init__(
116
+ self,
117
+ model: str = "Qwen32B",
118
+ temperature: float = KNOWN_CONFIG["qwen3-no-think"]["sampling_params"][
119
+ "temperature"
120
+ ],
121
+ top_p: float = KNOWN_CONFIG["qwen3-no-think"]["sampling_params"]["top_p"],
122
+ top_k: int = KNOWN_CONFIG["qwen3-no-think"]["sampling_params"]["top_k"],
123
+ presence_penalty: float = KNOWN_CONFIG["qwen3-no-think"]["sampling_params"][
124
+ "presence_penalty"
125
+ ],
126
+ **other_kwargs,
127
+ ):
128
+ super().__init__(
129
+ model=model,
130
+ temperature=temperature,
131
+ top_p=top_p,
132
+ top_k=top_k,
133
+ presence_penalty=presence_penalty,
134
+ **other_kwargs,
135
+ think=False
136
+ )
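
The `KNOWN_CONFIG` table above pairs each model family with the sampling defaults quoted from its model card, and the two subclasses simply bake the Qwen3 presets into their constructor defaults. As a rough illustration of how such presets can be resolved into request parameters, assuming `KNOWN_CONFIG` is importable from the module above (the `resolve_sampling` helper is illustrative, not part of the package):

def resolve_sampling(key: str, **overrides) -> dict:
    # Copy the preset for the given family and let explicit overrides win.
    params = dict(KNOWN_CONFIG[key]["sampling_params"])
    params.update(overrides)
    return params

print(resolve_sampling("qwen3-think"))
# -> {'temperature': 0.6, 'top_p': 0.95, 'top_k': 20, 'min_p': 0.0, 'presence_penalty': 1.5}
print(resolve_sampling("deepseek-r1", temperature=0.2))
# -> {'temperature': 0.2, 'top_p': 0.95}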
llm_utils/lm/utils.py CHANGED
@@ -7,8 +7,6 @@ import numpy as np
  from loguru import logger
 
 
-
-
  def _atomic_save(array: np.ndarray, filename: str):
      tmp_dir = os.path.dirname(filename) or "."
      with tempfile.NamedTemporaryFile(dir=tmp_dir, delete=False) as tmp:
@@ -77,7 +75,7 @@ def retry_on_exception(max_retries=10, exceptions=(Exception,), sleep_time=3):
              try:
                  return func(self, *args, **kwargs)
              except exceptions as e:
-                 import litellm # type: ignore
+                 import litellm  # type: ignore
 
                  if isinstance(
                      e, (litellm.exceptions.APIError, litellm.exceptions.Timeout)
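
For context, `retry_on_exception` is a decorator whose visible pieces here are its signature (`max_retries=10, exceptions=(Exception,), sleep_time=3`), the call to the wrapped method, and a lazy `litellm` import used to special-case `APIError`/`Timeout`. A rough sketch of that decorator shape follows; the retry loop and sleep behaviour are assumptions, since only the fragment above appears in the diff:

import functools
import time

def retry_on_exception(max_retries=10, exceptions=(Exception,), sleep_time=3):
    def decorator(func):
        @functools.wraps(func)
        def wrapper(self, *args, **kwargs):
            for attempt in range(max_retries):
                try:
                    return func(self, *args, **kwargs)
                except exceptions:
                    # Hypothetical handling: re-raise on the final attempt,
                    # otherwise wait and retry.
                    if attempt == max_retries - 1:
                        raise
                    time.sleep(sleep_time)
        return wrapper
    return decorator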