speedy-utils 1.1.5__py3-none-any.whl → 1.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_utils/__init__.py +1 -5
- llm_utils/chat_format/display.py +17 -4
- llm_utils/chat_format/transform.py +9 -9
- llm_utils/group_messages.py +1 -1
- llm_utils/lm/async_lm/__init__.py +7 -0
- llm_utils/lm/async_lm/_utils.py +201 -0
- llm_utils/lm/async_lm/async_llm_task.py +509 -0
- llm_utils/lm/async_lm/async_lm.py +387 -0
- llm_utils/lm/async_lm/async_lm_base.py +405 -0
- llm_utils/lm/async_lm/lm_specific.py +136 -0
- llm_utils/lm/utils.py +1 -3
- llm_utils/scripts/vllm_load_balancer.py +244 -147
- speedy_utils/__init__.py +3 -1
- speedy_utils/common/notebook_utils.py +4 -4
- speedy_utils/common/report_manager.py +2 -3
- speedy_utils/common/utils_cache.py +233 -3
- speedy_utils/common/utils_io.py +2 -0
- speedy_utils/scripts/mpython.py +1 -3
- {speedy_utils-1.1.5.dist-info → speedy_utils-1.1.7.dist-info}/METADATA +1 -1
- speedy_utils-1.1.7.dist-info/RECORD +39 -0
- llm_utils/lm/async_lm.py +0 -942
- llm_utils/lm/chat_html.py +0 -246
- llm_utils/lm/lm_json.py +0 -68
- llm_utils/lm/sync_lm.py +0 -943
- speedy_utils-1.1.5.dist-info/RECORD +0 -37
- {speedy_utils-1.1.5.dist-info → speedy_utils-1.1.7.dist-info}/WHEEL +0 -0
- {speedy_utils-1.1.5.dist-info → speedy_utils-1.1.7.dist-info}/entry_points.txt +0 -0
|
@@ -0,0 +1,405 @@
|
|
|
1
|
+
# from ._utils import *
|
|
2
|
+
import base64
|
|
3
|
+
import hashlib
|
|
4
|
+
import json
|
|
5
|
+
import os
|
|
6
|
+
from typing import (
|
|
7
|
+
Any,
|
|
8
|
+
List,
|
|
9
|
+
Optional,
|
|
10
|
+
Type,
|
|
11
|
+
Union,
|
|
12
|
+
cast,
|
|
13
|
+
overload,
|
|
14
|
+
)
|
|
15
|
+
|
|
16
|
+
from httpx import URL
|
|
17
|
+
from loguru import logger
|
|
18
|
+
from openai import AsyncOpenAI
|
|
19
|
+
from openai.pagination import AsyncPage as AsyncSyncPage
|
|
20
|
+
from openai.types.chat import (
|
|
21
|
+
ChatCompletionAssistantMessageParam,
|
|
22
|
+
ChatCompletionSystemMessageParam,
|
|
23
|
+
ChatCompletionToolMessageParam,
|
|
24
|
+
ChatCompletionUserMessageParam,
|
|
25
|
+
)
|
|
26
|
+
from openai.types.model import Model
|
|
27
|
+
from pydantic import BaseModel
|
|
28
|
+
|
|
29
|
+
from ._utils import (
|
|
30
|
+
LegacyMsgs,
|
|
31
|
+
Messages,
|
|
32
|
+
RawMsgs,
|
|
33
|
+
TModel,
|
|
34
|
+
)
|
|
35
|
+
|
|
36
|
+
|
|
37
|
+
class AsyncLMBase:
    """Unified **async** language-model wrapper with optional JSON parsing.

    This base class stores connection settings, builds ``AsyncOpenAI``
    clients, and provides message conversion, output parsing, a simple
    on-disk JSON cache and system-prompt helpers. ``__call__`` is only
    declared here through typed overloads; no implementation is defined in
    this class.
    """

    def __init__(
        self,
        *,
        host: str = "localhost",
        port: Optional[int | str] = None,
        base_url: Optional[str] = None,
        api_key: Optional[str] = None,
        cache: bool = True,
        ports: Optional[List[int]] = None,
    ) -> None:
        """Store connection and cache settings.

        Args:
            host: Host name used when a URL is built from a port.
            port: Single port; when truthy and ``base_url`` is not given,
                ``http://{host}:{port}/v1`` becomes the base URL.
            base_url: Explicit endpoint URL; takes precedence over ``port``.
            api_key: API key; falls back to the ``OPENAI_API_KEY``
                environment variable, then to the placeholder ``"abc"``.
            cache: Stored as ``self._cache``; not read by this class itself.
            ports: Optional list of ports; when non-empty, ``client`` picks
                one at random on every access.
        """
        self._port = port
        self._host = host
        self.base_url = base_url or (f"http://{host}:{port}/v1" if port else None)
        self.api_key = api_key or os.getenv("OPENAI_API_KEY", "abc")
        self._cache = cache
        self.ports = ports
        self._init_port = port  # <-- store the port provided at init

    @property
    def client(self) -> AsyncOpenAI:
        """Build and return a new ``AsyncOpenAI`` client on every access.

        Endpoint resolution order:
            1. a random entry of ``self.ports`` (if any),
            2. ``self.base_url``,
            3. ``http://{host}:{port}/v1`` when ``self._port`` is set,
            4. ``None``, leaving endpoint selection to ``AsyncOpenAI``.
        """
        api_base: Optional[str]
        if self.ports:
            import random

            port = random.choice(self.ports)
            api_base = f"http://{self._host}:{port}/v1"
            logger.debug(f"Using port: {port}")
        elif self.base_url:
            api_base = self.base_url
        elif self._port:
            api_base = f"http://{self._host}:{self._port}/v1"
        else:
            # Nothing configured: do not fabricate "http://host:None/v1".
            api_base = None
        return AsyncOpenAI(
            api_key=self.api_key,
            base_url=api_base,
        )

    # ------------------------------------------------------------------ #
    # Public API – typed overloads
    # ------------------------------------------------------------------ #
    @overload
    async def __call__(  # type: ignore
        self,
        *,
        prompt: str | None = ...,
        messages: RawMsgs | None = ...,
        response_format: type[str] = str,
        return_openai_response: bool = ...,
        **kwargs: Any,
    ) -> str: ...

    @overload
    async def __call__(
        self,
        *,
        prompt: str | None = ...,
        messages: RawMsgs | None = ...,
        response_format: Type[TModel],
        return_openai_response: bool = ...,
        **kwargs: Any,
    ) -> TModel: ...

    # ------------------------------------------------------------------ #
    # Utilities below are unchanged (sync I/O is acceptable)
    # ------------------------------------------------------------------ #
    @staticmethod
    def _convert_messages(msgs: LegacyMsgs) -> Messages:
        """Convert ``{"role", "content"}`` dicts into OpenAI message params.

        ``user``, ``assistant``, ``system`` and ``tool`` roles are wrapped in
        the matching ``ChatCompletion*MessageParam``; a tool message without
        a ``tool_call_id`` gets an empty string. Any other role is passed
        through as a plain ``{"role", "content"}`` dict.
        """
        converted: Messages = []
        for msg in msgs:
            role = msg["role"]
            content = msg["content"]
            if role == "user":
                converted.append(
                    ChatCompletionUserMessageParam(role="user", content=content)
                )
            elif role == "assistant":
                converted.append(
                    ChatCompletionAssistantMessageParam(
                        role="assistant", content=content
                    )
                )
            elif role == "system":
                converted.append(
                    ChatCompletionSystemMessageParam(role="system", content=content)
                )
            elif role == "tool":
                converted.append(
                    ChatCompletionToolMessageParam(
                        role="tool",
                        content=content,
                        tool_call_id=msg.get("tool_call_id") or "",
                    )
                )
            else:
                converted.append({"role": role, "content": content})  # type: ignore[arg-type]
        return converted

    @staticmethod
    def _parse_output(
        raw_response: Any, response_format: Union[type[str], Type[BaseModel]]
    ) -> str | BaseModel:
        """Turn a raw response into ``str`` or an instance of ``response_format``.

        Args:
            raw_response: A completion-like object (anything exposing
                ``model_dump``), a dict, a JSON string, or an instance of the
                requested model.
            response_format: ``str`` for plain text, otherwise a model class
                exposing ``model_validate``.

        Returns:
            The first choice's message content (``""`` when missing) for
            ``str``; otherwise a validated model instance.

        Raises:
            ValueError: If the message content is ``None`` or the payload
                cannot be parsed/validated as JSON for the model.
        """
        if hasattr(raw_response, "model_dump"):
            raw_response = raw_response.model_dump()

        has_choices = isinstance(raw_response, dict) and "choices" in raw_response

        if response_format is str:
            if has_choices:
                message = raw_response["choices"][0]["message"]
                return message.get("content", "") or ""
            return cast(str, raw_response)

        model_cls = cast("Type[BaseModel]", response_format)

        if has_choices:
            message = raw_response["choices"][0]["message"]
            # A dumped message may carry "parsed": None; only trust a real
            # payload, otherwise fall back to decoding the text content.
            parsed = message.get("parsed")
            if parsed is not None:
                return model_cls.model_validate(parsed)
            content = message.get("content")
            if content is None:
                raise ValueError("Model returned empty content")
            try:
                data = json.loads(content)
                return model_cls.model_validate(data)
            except Exception as exc:
                raise ValueError(
                    f"Failed to parse model output as JSON:\n{content}"
                ) from exc

        if isinstance(raw_response, model_cls):
            return raw_response
        if isinstance(raw_response, dict):
            return model_cls.model_validate(raw_response)

        try:
            data = json.loads(raw_response)
            return model_cls.model_validate(data)
        except Exception as exc:
            raise ValueError(
                f"Model did not return valid JSON:\n---\n{raw_response}"
            ) from exc

    # ------------------------------------------------------------------ #
    # Simple disk cache (sync)
    # ------------------------------------------------------------------ #
    @staticmethod
    def _cache_key(
        messages: Any, kw: Any, response_format: Union[type[str], Type[BaseModel]]
    ) -> str:
        """Return a 22-character URL-safe key for a request.

        The key is the truncated base64 encoding of the SHA-256 digest of
        the JSON dump (sorted keys) of ``[messages, kw, tag]``, where ``tag``
        is ``"text"`` for ``str`` and the class name otherwise.
        """
        tag = response_format.__name__ if response_format is not str else "text"
        blob = json.dumps([messages, kw, tag], sort_keys=True).encode()
        return base64.urlsafe_b64encode(hashlib.sha256(blob).digest()).decode()[:22]

    @staticmethod
    def _cache_path(key: str) -> str:
        """Return the cache file path ``~/.cache/lm/{key}.json`` (user-expanded)."""
        return os.path.expanduser(f"~/.cache/lm/{key}.json")

    def _dump_cache(self, key: str, val: Any) -> None:
        """Write ``val`` as JSON to the cache file for ``key`` (best effort).

        ``BaseModel`` instances are dumped via ``model_dump(mode="json")``.
        Any failure is logged at debug level and otherwise ignored.
        """
        try:
            path = self._cache_path(key)
            os.makedirs(os.path.dirname(path), exist_ok=True)
            payload = val.model_dump(mode="json") if isinstance(val, BaseModel) else val
            with open(path, "w", encoding="utf-8") as fh:
                json.dump(payload, fh)
        except Exception as exc:
            logger.debug(f"cache write skipped: {exc}")

    def _load_cache(self, key: str) -> Any | None:
        """Return the cached JSON value for ``key``, or ``None``.

        ``None`` is returned both when the file is missing and when it
        cannot be read or decoded.
        """
        try:
            with open(self._cache_path(key), encoding="utf-8") as fh:
                return json.load(fh)
        except Exception:
            # Missing, unreadable or corrupt cache entries count as a miss.
            return None

    # NOTE: commented-out drafts of inspect_word_probs / last_messages /
    # inspect_history previously sat here; they were dead code and have been
    # dropped. inspect_history remains as a stub at the end of the class.

    # ------------------------------------------------------------------ #
    # Misc helpers
    # ------------------------------------------------------------------ #

    @staticmethod
    async def list_models(port=None, host="localhost") -> List[str]:
        """Return the model ids reported by the endpoint for ``host``/``port``.

        Any exception is logged and turned into an empty list.
        """
        try:
            client = AsyncLMBase(port=port, host=host).client  # type: ignore[arg-type]
            base_url: URL = client.base_url
            logger.debug(f"Base URL: {base_url}")
            models: AsyncSyncPage[Model] = await client.models.list()  # type: ignore[assignment]
            return [model.id for model in models.data]
        except Exception as exc:
            logger.error(f"Failed to list models: {exc}")
            return []

    def build_system_prompt(
        self,
        response_model,
        add_json_schema_to_instruction,
        json_schema,
        system_content,
        think,
    ):
        """Return ``system_content`` with schema block and think marker applied.

        Args:
            response_model: When truthy together with
                ``add_json_schema_to_instruction``, a schema block is added.
            add_json_schema_to_instruction: Whether to embed ``json_schema``.
            json_schema: JSON-serialisable schema placed inside an
                ``<output_json_schema>`` block.
            system_content: Existing system prompt text.
            think: ``True`` ensures a ``/think`` marker, ``False`` ensures
                ``/no_think``; any other value leaves the markers untouched.

        Returns:
            The updated system prompt string.
        """
        if add_json_schema_to_instruction and response_model:
            schema_block = f"\n\n<output_json_schema>\n{json.dumps(json_schema, indent=2)}\n</output_json_schema>"
            if "<output_json_schema>" in system_content:
                # Remove the existing schema block so only the fresh one remains.
                import re

                system_content = re.sub(
                    r"<output_json_schema>.*?</output_json_schema>",
                    "",
                    system_content,
                    flags=re.DOTALL,
                )
                system_content = system_content.strip()
            system_content += schema_block

        if think is True:
            if "/think" in system_content:
                pass
            elif "/no_think" in system_content:
                system_content = system_content.replace("/no_think", "/think")
            else:
                system_content += "\n\n/think"
        elif think is False:
            if "/no_think" in system_content:
                pass
            elif "/think" in system_content:
                system_content = system_content.replace("/think", "/no_think")
            else:
                system_content += "\n\n/no_think"
        return system_content

    async def inspect_history(self):
        """Inspect the history of the LLM calls (stub: does nothing here)."""
        pass
|
|
@@ -0,0 +1,136 @@
|
|
|
1
|
+
from typing import List
|
|
2
|
+
|
|
3
|
+
from .async_lm import AsyncLM
|
|
4
|
+
|
|
5
|
+
# Recommended sampling parameters per known model family. Each entry maps a
# config name to {"sampling_params": {...}}; the source noted for the values
# is given in the comment above each row.
KNOWN_CONFIG = {
    config_name: {"sampling_params": sampling_params}
    for config_name, sampling_params in (
        # Qwen3 family (see model card "Best Practices" section)
        (
            "qwen3-think",
            {
                "temperature": 0.6,
                "top_p": 0.95,
                "top_k": 20,
                "min_p": 0.0,
                "presence_penalty": 1.5,
            },
        ),
        (
            "qwen3-no-think",
            {
                "temperature": 0.7,
                "top_p": 0.8,
                "top_k": 20,
                "min_p": 0.0,
                "presence_penalty": 1.5,
            },
        ),
        # DeepSeek V3 (model card: temperature=0.3)
        ("deepseek-v3", {"temperature": 0.3}),
        # DeepSeek R1 (model card: temperature=0.6, top_p=0.95)
        ("deepseek-r1", {"temperature": 0.6, "top_p": 0.95}),
        # Mistral Small 3.2-24B Instruct (model card: temperature=0.15)
        ("mistral-small-3.2-24b-instruct-2506", {"temperature": 0.15}),
        # Magistral Small 2506 (model card: temperature=0.7, top_p=0.95)
        ("magistral-small-2506", {"temperature": 0.7, "top_p": 0.95}),
        # Phi-4 Reasoning (model card: temperature=0.8, top_k=50, top_p=0.95)
        ("phi-4-reasoning", {"temperature": 0.8, "top_k": 50, "top_p": 0.95}),
        # GLM-Z1-32B-0414 (model card: temperature=0.6, top_p=0.95, top_k=40, max_new_tokens=30000)
        (
            "glm-z1-32b-0414",
            {
                "temperature": 0.6,
                "top_p": 0.95,
                "top_k": 40,
                "max_new_tokens": 30000,
            },
        ),
        # Llama-4-Scout-17B-16E-Instruct (generation_config.json: temperature=0.6, top_p=0.9)
        ("llama-4-scout-17b-16e-instruct", {"temperature": 0.6, "top_p": 0.9}),
        # Gemma-3-27b-it (alleged: temperature=1.0, top_k=64, top_p=0.96)
        ("gemma-3-27b-it", {"temperature": 1.0, "top_k": 64, "top_p": 0.96}),
        # Add more as needed...
    )
}

# Config names in declaration order.
KNOWN_KEYS: List[str] = list(KNOWN_CONFIG)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
class AsyncLMQwenThink(AsyncLM):
    """``AsyncLM`` preset that forwards ``think=True``.

    Sampling defaults come from ``KNOWN_CONFIG["qwen3-think"]``.
    """

    def __init__(
        self,
        model: str = "Qwen32B",
        temperature: float = KNOWN_CONFIG["qwen3-think"]["sampling_params"][
            "temperature"
        ],
        top_p: float = KNOWN_CONFIG["qwen3-think"]["sampling_params"]["top_p"],
        top_k: int = KNOWN_CONFIG["qwen3-think"]["sampling_params"]["top_k"],
        presence_penalty: float = KNOWN_CONFIG["qwen3-think"]["sampling_params"][
            "presence_penalty"
        ],
        **other_kwargs,
    ):
        """Initialise the underlying ``AsyncLM`` with ``think=True``.

        Args:
            model: Model name forwarded to ``AsyncLM``.
            temperature: Sampling temperature.
            top_p: Nucleus-sampling threshold.
            top_k: Top-k sampling cutoff.
            presence_penalty: Presence penalty.
            **other_kwargs: Forwarded unchanged to ``AsyncLM.__init__``;
                must not contain ``think``.
        """
        super().__init__(
            # Forward the caller's model instead of a hard-coded name, in
            # line with AsyncLMQwenNoThink.
            model=model,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            presence_penalty=presence_penalty,
            **other_kwargs,
            think=True,
        )
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
class AsyncLMQwenNoThink(AsyncLM):
    """``AsyncLM`` preset that forwards ``think=False``.

    Sampling defaults come from ``KNOWN_CONFIG["qwen3-no-think"]``.
    """

    def __init__(
        self,
        model: str = "Qwen32B",
        temperature: float = (
            KNOWN_CONFIG["qwen3-no-think"]["sampling_params"]["temperature"]
        ),
        top_p: float = KNOWN_CONFIG["qwen3-no-think"]["sampling_params"]["top_p"],
        top_k: int = KNOWN_CONFIG["qwen3-no-think"]["sampling_params"]["top_k"],
        presence_penalty: float = (
            KNOWN_CONFIG["qwen3-no-think"]["sampling_params"]["presence_penalty"]
        ),
        **other_kwargs,
    ):
        """Initialise the underlying ``AsyncLM`` with ``think=False``.

        All named arguments and ``other_kwargs`` are forwarded unchanged to
        ``AsyncLM.__init__``; ``other_kwargs`` must not contain ``think``.
        """
        super().__init__(
            think=False,
            model=model,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            presence_penalty=presence_penalty,
            **other_kwargs,
        )
|
llm_utils/lm/utils.py
CHANGED
|
@@ -7,8 +7,6 @@ import numpy as np
|
|
|
7
7
|
from loguru import logger
|
|
8
8
|
|
|
9
9
|
|
|
10
|
-
|
|
11
|
-
|
|
12
10
|
def _atomic_save(array: np.ndarray, filename: str):
|
|
13
11
|
tmp_dir = os.path.dirname(filename) or "."
|
|
14
12
|
with tempfile.NamedTemporaryFile(dir=tmp_dir, delete=False) as tmp:
|
|
@@ -77,7 +75,7 @@ def retry_on_exception(max_retries=10, exceptions=(Exception,), sleep_time=3):
|
|
|
77
75
|
try:
|
|
78
76
|
return func(self, *args, **kwargs)
|
|
79
77
|
except exceptions as e:
|
|
80
|
-
import litellm
|
|
78
|
+
import litellm # type: ignore
|
|
81
79
|
|
|
82
80
|
if isinstance(
|
|
83
81
|
e, (litellm.exceptions.APIError, litellm.exceptions.Timeout)
|