speedy-utils 1.0.3__py3-none-any.whl → 1.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_utils/__init__.py +29 -0
- llm_utils/chat_format.py +427 -0
- llm_utils/group_messages.py +120 -0
- llm_utils/lm/__init__.py +8 -0
- llm_utils/lm/base_lm.py +304 -0
- llm_utils/lm/utils.py +130 -0
- llm_utils/scripts/vllm_load_balancer.py +353 -0
- llm_utils/scripts/vllm_serve.py +416 -0
- speedy_utils/__init__.py +85 -0
- speedy_utils/all.py +159 -0
- {speedy → speedy_utils}/common/__init__.py +0 -0
- speedy_utils/common/clock.py +215 -0
- speedy_utils/common/function_decorator.py +66 -0
- speedy_utils/common/logger.py +207 -0
- speedy_utils/common/report_manager.py +112 -0
- speedy_utils/common/utils_cache.py +264 -0
- {speedy → speedy_utils}/common/utils_io.py +66 -19
- {speedy → speedy_utils}/common/utils_misc.py +25 -11
- speedy_utils/common/utils_print.py +216 -0
- speedy_utils/multi_worker/__init__.py +0 -0
- speedy_utils/multi_worker/process.py +198 -0
- speedy_utils/multi_worker/thread.py +327 -0
- speedy_utils/scripts/mpython.py +108 -0
- speedy_utils-1.0.5.dist-info/METADATA +279 -0
- speedy_utils-1.0.5.dist-info/RECORD +27 -0
- {speedy_utils-1.0.3.dist-info → speedy_utils-1.0.5.dist-info}/WHEEL +1 -2
- speedy_utils-1.0.5.dist-info/entry_points.txt +3 -0
- speedy/__init__.py +0 -53
- speedy/common/clock.py +0 -68
- speedy/common/utils_cache.py +0 -170
- speedy/common/utils_print.py +0 -138
- speedy/multi_worker.py +0 -121
- speedy_utils-1.0.3.dist-info/METADATA +0 -22
- speedy_utils-1.0.3.dist-info/RECORD +0 -12
- speedy_utils-1.0.3.dist-info/top_level.txt +0 -1
llm_utils/lm/base_lm.py
ADDED
@@ -0,0 +1,304 @@
from __future__ import annotations

import base64
import hashlib
import json
import os
from typing import (
    Any,
    Dict,
    List,
    Optional,
    Sequence,
    Type,
    TypeVar,
    Union,
    overload,
    cast,
)

from httpx import URL
from loguru import logger
from openai import OpenAI, AuthenticationError, RateLimitError
from openai.pagination import SyncPage
from openai.types.chat import (
    ChatCompletionAssistantMessageParam,
    ChatCompletionMessageParam,
    ChatCompletionSystemMessageParam,
    ChatCompletionToolMessageParam,
    ChatCompletionUserMessageParam,
)
from openai.types.chat.parsed_chat_completion import ParsedChatCompletion
from openai.types.model import Model
from pydantic import BaseModel
import warnings

# --------------------------------------------------------------------------- #
# type helpers
# --------------------------------------------------------------------------- #
TModel = TypeVar("TModel", bound=BaseModel)
Messages = List[ChatCompletionMessageParam]  # final, already-typed messages
LegacyMsgs = List[Dict[str, str]]  # old “…role/content…” dicts
RawMsgs = Union[Messages, LegacyMsgs]  # what __call__ accepts


class LM:
    """
    Unified language-model wrapper.

    • `response_format=str`               → returns `str`
    • `response_format=YourPydanticModel` → returns that model instance
    """

    # --------------------------------------------------------------------- #
    # ctor / plumbing
    # --------------------------------------------------------------------- #
    def __init__(
        self,
        model: str | None = None,
        *,
        temperature: float = 0.0,
        max_tokens: int = 2_000,
        host: str = "localhost",
        port: Optional[int] = None,
        base_url: Optional[str] = None,
        api_key: Optional[str] = None,
        cache: bool = True,
        **openai_kwargs: Any,
    ) -> None:
        self.model = model
        self.temperature = temperature
        self.max_tokens = max_tokens
        self.base_url = base_url or (f"http://{host}:{port}/v1" if port else None)
        self.api_key = api_key or os.getenv("OPENAI_API_KEY", "abc")
        self.openai_kwargs = openai_kwargs
        self.do_cache = cache

        self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)

    def set_model(self, model: str) -> None:
        """Set the model name after initialization."""
        self.model = model

    # --------------------------------------------------------------------- #
    # public API – typed overloads
    # --------------------------------------------------------------------- #
    @overload
    def __call__(
        self,
        *,
        prompt: str | None = ...,
        messages: RawMsgs | None = ...,
        response_format: type[str] = str,
        **kwargs: Any,
    ) -> str: ...

    @overload
    def __call__(
        self,
        *,
        prompt: str | None = ...,
        messages: RawMsgs | None = ...,
        response_format: Type[TModel],
        **kwargs: Any,
    ) -> TModel: ...

    # single implementation
    def __call__(
        self,
        prompt: Optional[str] = None,
        messages: Optional[RawMsgs] = None,
        response_format: Union[type[str], Type[BaseModel]] = str,
        cache: Optional[bool] = None,
        max_tokens: Optional[int] = None,
        **kwargs: Any,
    ):
        # argument validation ------------------------------------------------
        if (prompt is None) == (messages is None):
            raise ValueError("Provide *either* `prompt` or `messages` (but not both).")

        if prompt is not None:
            messages = [{"role": "user", "content": prompt}]

        assert messages is not None  # for type-checker
        assert self.model is not None, "Model must be set before calling."
        openai_msgs: Messages = (
            self._convert_messages(cast(LegacyMsgs, messages))
            if isinstance(messages[0], dict)  # legacy style
            else cast(Messages, messages)  # already typed
        )

        kw = dict(
            self.openai_kwargs,
            temperature=self.temperature,
            max_tokens=max_tokens or self.max_tokens,
            **kwargs,
        )
        use_cache = self.do_cache if cache is None else cache

        raw = self._call_raw(
            openai_msgs,
            response_format=response_format,
            use_cache=use_cache,
            **kw,
        )
        return self._parse_output(raw, response_format)

    # --------------------------------------------------------------------- #
    # low-level OpenAI call
    # --------------------------------------------------------------------- #
    def _call_raw(
        self,
        messages: Sequence[ChatCompletionMessageParam],
        response_format: Union[type[str], Type[BaseModel]],
        use_cache: bool,
        **kw: Any,
    ):
        assert self.model is not None, "Model must be set before making a call."
        model: str = self.model
        cache_key = (
            self._cache_key(messages, kw, response_format) if use_cache else None
        )
        if cache_key and (hit := self._load_cache(cache_key)) is not None:
            return hit

        try:
            # structured mode
            if response_format is not str and issubclass(response_format, BaseModel):
                rsp: ParsedChatCompletion[BaseModel] = (
                    self.client.beta.chat.completions.parse(
                        model=model,
                        messages=list(messages),
                        response_format=response_format,  # type: ignore[arg-type]
                        **kw,
                    )
                )
                result: Any = rsp.choices[0].message.parsed  # already a model
            # plain-text mode
            else:
                rsp = self.client.chat.completions.create(
                    model=model,
                    messages=list(messages),
                    **kw,
                )
                result = rsp.choices[0].message.content  # str
        except (AuthenticationError, RateLimitError) as exc:  # pragma: no cover
            logger.error(exc)
            raise

        if cache_key:
            self._dump_cache(cache_key, result)

        return result

    # --------------------------------------------------------------------- #
    # legacy → typed messages
    # --------------------------------------------------------------------- #
    @staticmethod
    def _convert_messages(msgs: LegacyMsgs) -> Messages:
        converted: Messages = []
        for msg in msgs:
            role = msg["role"]
            content = msg["content"]
            if role == "user":
                converted.append(
                    ChatCompletionUserMessageParam(role="user", content=content)
                )
            elif role == "assistant":
                converted.append(
                    ChatCompletionAssistantMessageParam(
                        role="assistant", content=content
                    )
                )
            elif role == "system":
                converted.append(
                    ChatCompletionSystemMessageParam(role="system", content=content)
                )
            elif role == "tool":
                converted.append(
                    ChatCompletionToolMessageParam(
                        role="tool",
                        content=content,
                        tool_call_id=msg.get("tool_call_id") or "",  # str, never None
                    )
                )
            else:
                # fall back to raw dict for unknown roles
                converted.append({"role": role, "content": content})  # type: ignore[arg-type]
        return converted

    # --------------------------------------------------------------------- #
    # final parse (needed for plain-text or cache hits only)
    # --------------------------------------------------------------------- #
    @staticmethod
    def _parse_output(
        raw: Any,
        response_format: Union[type[str], Type[BaseModel]],
    ) -> str | BaseModel:
        if response_format is str:
            return cast(str, raw)

        # For the type-checker: we *know* it's a BaseModel subclass here.
        model_cls = cast(Type[BaseModel], response_format)

        if isinstance(raw, model_cls):
            return raw
        if isinstance(raw, dict):
            return model_cls.model_validate(raw)
        try:
            data = json.loads(raw)
        except Exception as exc:  # noqa: BLE001
            raise ValueError(f"Model did not return JSON:\n---\n{raw}") from exc
        return model_cls.model_validate(data)

    # --------------------------------------------------------------------- #
    # tiny disk cache
    # --------------------------------------------------------------------- #
    @staticmethod
    def _cache_key(
        messages: Any, kw: Any, response_format: Union[type[str], Type[BaseModel]]
    ) -> str:
        tag = response_format.__name__ if response_format is not str else "text"
        blob = json.dumps([messages, kw, tag], sort_keys=True).encode()
        return base64.urlsafe_b64encode(hashlib.sha256(blob).digest()).decode()[:22]

    @staticmethod
    def _cache_path(key: str) -> str:
        return os.path.expanduser(f"~/.cache/lm/{key}.json")

    def _dump_cache(self, key: str, val: Any) -> None:
        try:
            path = self._cache_path(key)
            os.makedirs(os.path.dirname(path), exist_ok=True)
            with open(path, "w") as fh:
                if isinstance(val, BaseModel):
                    json.dump(val.model_dump(mode="json"), fh)
                else:
                    json.dump(val, fh)
        except Exception as exc:  # pragma: no cover
            logger.debug(f"cache write skipped: {exc}")

    def _load_cache(self, key: str) -> Any | None:
        path = self._cache_path(key)
        if not os.path.exists(path):
            return None
        try:
            with open(path) as fh:
                return json.load(fh)
        except Exception:  # pragma: no cover
            return None

    @staticmethod
    def list_models(port=None) -> List[str]:
        """
        List available models.
        """
        try:
            client: OpenAI = LM(port=port).client
            base_url: URL = client.base_url
            logger.debug(f"Base URL: {base_url}")
            models: SyncPage[Model] = client.models.list()
            return [model.id for model in models.data]
        except Exception as exc:
            logger.error(f"Failed to list models: {exc}")
            return []
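Usage sketch (not part of the diff): based on the signature and docstring above, a call might look like the following. The model name, port, and `Answer` schema are illustrative assumptions, not values shipped with the package.

from pydantic import BaseModel
from llm_utils.lm.base_lm import LM

class Answer(BaseModel):  # hypothetical response schema
    city: str
    population: int

lm = LM(model="gpt-4o-mini", port=8000)  # port=8000 expands to base_url http://localhost:8000/v1
text = lm(prompt="Name the largest city in Japan.")  # plain-text mode, returns str
answer = lm(prompt="Largest city in Japan, as JSON.", response_format=Answer)  # structured mode, returns Answer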
llm_utils/lm/utils.py
ADDED
@@ -0,0 +1,130 @@
import fcntl
import os
import tempfile
import time
from typing import List, Dict
import numpy as np
from loguru import logger


def _clear_port_use(ports):
    for port in ports:
        file_counter: str = f"/tmp/port_use_counter_{port}.npy"
        if os.path.exists(file_counter):
            os.remove(file_counter)


def _atomic_save(array: np.ndarray, filename: str):
    tmp_dir = os.path.dirname(filename) or "."
    with tempfile.NamedTemporaryFile(dir=tmp_dir, delete=False) as tmp:
        np.save(tmp, array)
        temp_name = tmp.name
    os.replace(temp_name, filename)


def _update_port_use(port: int, increment: int) -> None:
    file_counter: str = f"/tmp/port_use_counter_{port}.npy"
    file_counter_lock: str = f"/tmp/port_use_counter_{port}.lock"
    with open(file_counter_lock, "w") as lock_file:
        fcntl.flock(lock_file, fcntl.LOCK_EX)
        try:
            if os.path.exists(file_counter):
                try:
                    counter = np.load(file_counter)
                except Exception as e:
                    logger.warning(f"Corrupted usage file {file_counter}: {e}")
                    counter = np.array([0])
            else:
                counter: np.ndarray = np.array([0], dtype=np.int64)
            counter[0] += increment
            _atomic_save(counter, file_counter)
        finally:
            fcntl.flock(lock_file, fcntl.LOCK_UN)


def _pick_least_used_port(ports: List[int]) -> int:
    global_lock_file = "/tmp/ports.lock"
    with open(global_lock_file, "w") as lock_file:
        fcntl.flock(lock_file, fcntl.LOCK_EX)
        try:
            port_use: Dict[int, int] = {}
            for port in ports:
                file_counter = f"/tmp/port_use_counter_{port}.npy"
                if os.path.exists(file_counter):
                    try:
                        counter = np.load(file_counter)
                    except Exception as e:
                        logger.warning(f"Corrupted usage file {file_counter}: {e}")
                        counter = np.array([0])
                else:
                    counter = np.array([0])
                port_use[port] = counter[0]
            if not port_use:
                if ports:
                    raise ValueError("Port usage data is empty, cannot pick a port.")
                else:
                    raise ValueError("No ports provided to pick from.")
            lsp = min(port_use, key=lambda k: port_use[k])
            _update_port_use(lsp, 1)
        finally:
            fcntl.flock(lock_file, fcntl.LOCK_UN)
    return lsp


def retry_on_exception(max_retries=10, exceptions=(Exception,), sleep_time=3):
    def decorator(func):
        from functools import wraps

        def wrapper(self, *args, **kwargs):
            retry_count = kwargs.get("retry_count", 0)
            last_exception = None
            while retry_count <= max_retries:
                try:
                    return func(self, *args, **kwargs)
                except exceptions as e:
                    import litellm

                    if isinstance(
                        e, (litellm.exceptions.APIError, litellm.exceptions.Timeout)
                    ):
                        base_url_info = kwargs.get(
                            "base_url", getattr(self, "base_url", None)
                        )
                        logger.warning(
                            f"[{base_url_info=}] {type(e).__name__}: {str(e)[:100]}, will sleep for {sleep_time}s and retry"
                        )
                        time.sleep(sleep_time)
                        retry_count += 1
                        kwargs["retry_count"] = retry_count
                        last_exception = e
                        continue
                    elif hasattr(
                        litellm.exceptions, "ContextWindowExceededError"
                    ) and isinstance(e, litellm.exceptions.ContextWindowExceededError):
                        logger.error(f"Context window exceeded: {e}")
                        raise
                    else:
                        logger.error(f"Generic error during LLM call: {e}")
                        import traceback

                        traceback.print_exc()
                        raise
            logger.error(f"Retry limit exceeded, error: {last_exception}")
            if last_exception:
                raise last_exception
            raise ValueError("Retry limit exceeded with no specific error.")

        return wraps(func)(wrapper)

    return decorator


def forward_only(func):
    from functools import wraps

    @wraps(func)
    def wrapper(self, *args, **kwargs):
        kwargs["retry_count"] = 0
        return func(self, *args, **kwargs)

    return wrapper
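Usage sketch (not part of the diff): `retry_on_exception` wraps bound methods (its wrapper takes `self`) and threads the attempt count through `kwargs["retry_count"]`, which `forward_only` resets to 0. The client class below is hypothetical and only shows where the decorators attach.

from llm_utils.lm.utils import retry_on_exception, forward_only

class Client:  # hypothetical caller, not part of the package
    base_url = "http://localhost:8000/v1"

    @retry_on_exception(max_retries=3, sleep_time=1)
    def complete(self, prompt, **kwargs):
        # a real implementation would call the LLM here; litellm API errors
        # and timeouts are retried, other exceptions are re-raised
        return f"echo: {prompt}"

    @forward_only
    def complete_once(self, prompt, **kwargs):
        # resets retry_count to 0 before delegating to the retried method
        return self.complete(prompt, **kwargs)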