speedy-utils 1.0.4__py3-none-any.whl → 1.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,304 @@
+ from __future__ import annotations
+
+ import base64
+ import hashlib
+ import json
+ import os
+ from typing import (
+     Any,
+     Dict,
+     List,
+     Optional,
+     Sequence,
+     Type,
+     TypeVar,
+     Union,
+     overload,
+     cast,
+ )
+
+ from httpx import URL
+ from loguru import logger
+ from openai import OpenAI, AuthenticationError, RateLimitError
+ from openai.pagination import SyncPage
+ from openai.types.chat import (
+     ChatCompletionAssistantMessageParam,
+     ChatCompletionMessageParam,
+     ChatCompletionSystemMessageParam,
+     ChatCompletionToolMessageParam,
+     ChatCompletionUserMessageParam,
+ )
+ from openai.types.chat.parsed_chat_completion import ParsedChatCompletion
+ from openai.types.model import Model
+ from pydantic import BaseModel
+ import warnings
+
+ # --------------------------------------------------------------------------- #
+ # type helpers
+ # --------------------------------------------------------------------------- #
+ TModel = TypeVar("TModel", bound=BaseModel)
+ Messages = List[ChatCompletionMessageParam]  # final, already-typed messages
+ LegacyMsgs = List[Dict[str, str]]  # old “…role/content…” dicts
+ RawMsgs = Union[Messages, LegacyMsgs]  # what __call__ accepts
+
+
+ class LM:
+     """
+     Unified language-model wrapper.
+
+     • `response_format=str` → returns `str`
+     • `response_format=YourPydanticModel` → returns that model instance
+     """
+
+     # --------------------------------------------------------------------- #
+     # ctor / plumbing
+     # --------------------------------------------------------------------- #
+     def __init__(
+         self,
+         model: str | None = None,
+         *,
+         temperature: float = 0.0,
+         max_tokens: int = 2_000,
+         host: str = "localhost",
+         port: Optional[int] = None,
+         base_url: Optional[str] = None,
+         api_key: Optional[str] = None,
+         cache: bool = True,
+         **openai_kwargs: Any,
+     ) -> None:
+         self.model = model
+         self.temperature = temperature
+         self.max_tokens = max_tokens
+         self.base_url = base_url or (f"http://{host}:{port}/v1" if port else None)
+         self.api_key = api_key or os.getenv("OPENAI_API_KEY", "abc")
+         self.openai_kwargs = openai_kwargs
+         self.do_cache = cache
+
+         self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
+
+     def set_model(self, model: str) -> None:
+         """Set the model name after initialization."""
+         self.model = model
+
+     # --------------------------------------------------------------------- #
+     # public API – typed overloads
+     # --------------------------------------------------------------------- #
+     @overload
+     def __call__(
+         self,
+         *,
+         prompt: str | None = ...,
+         messages: RawMsgs | None = ...,
+         response_format: type[str] = str,
+         **kwargs: Any,
+     ) -> str: ...
+
+     @overload
+     def __call__(
+         self,
+         *,
+         prompt: str | None = ...,
+         messages: RawMsgs | None = ...,
+         response_format: Type[TModel],
+         **kwargs: Any,
+     ) -> TModel: ...
+
+     # single implementation
+     def __call__(
+         self,
+         prompt: Optional[str] = None,
+         messages: Optional[RawMsgs] = None,
+         response_format: Union[type[str], Type[BaseModel]] = str,
+         cache: Optional[bool] = None,
+         max_tokens: Optional[int] = None,
+         **kwargs: Any,
+     ):
+         # argument validation ------------------------------------------------
+         if (prompt is None) == (messages is None):
+             raise ValueError("Provide *either* `prompt` or `messages` (but not both).")
+
+         if prompt is not None:
+             messages = [{"role": "user", "content": prompt}]
+
+         assert messages is not None  # for type-checker
+         assert self.model is not None, "Model must be set before calling."
+         openai_msgs: Messages = (
+             self._convert_messages(cast(LegacyMsgs, messages))
+             if isinstance(messages[0], dict)  # legacy style
+             else cast(Messages, messages)  # already typed
+         )
+
+         kw = dict(
+             self.openai_kwargs,
+             temperature=self.temperature,
+             max_tokens=max_tokens or self.max_tokens,
+             **kwargs,
+         )
+         use_cache = self.do_cache if cache is None else cache
+
+         raw = self._call_raw(
+             openai_msgs,
+             response_format=response_format,
+             use_cache=use_cache,
+             **kw,
+         )
+         return self._parse_output(raw, response_format)
+
+     # --------------------------------------------------------------------- #
+     # low-level OpenAI call
+     # --------------------------------------------------------------------- #
+     def _call_raw(
+         self,
+         messages: Sequence[ChatCompletionMessageParam],
+         response_format: Union[type[str], Type[BaseModel]],
+         use_cache: bool,
+         **kw: Any,
+     ):
+         assert self.model is not None, "Model must be set before making a call."
+         model: str = self.model
+         cache_key = (
+             self._cache_key(messages, kw, response_format) if use_cache else None
+         )
+         if cache_key and (hit := self._load_cache(cache_key)) is not None:
+             return hit
+
+         try:
+             # structured mode
+             if response_format is not str and issubclass(response_format, BaseModel):
+                 rsp: ParsedChatCompletion[BaseModel] = (
+                     self.client.beta.chat.completions.parse(
+                         model=model,
+                         messages=list(messages),
+                         response_format=response_format,  # type: ignore[arg-type]
+                         **kw,
+                     )
+                 )
+                 result: Any = rsp.choices[0].message.parsed  # already a model
+             # plain-text mode
+             else:
+                 rsp = self.client.chat.completions.create(
+                     model=model,
+                     messages=list(messages),
+                     **kw,
+                 )
+                 result = rsp.choices[0].message.content  # str
+         except (AuthenticationError, RateLimitError) as exc:  # pragma: no cover
+             logger.error(exc)
+             raise
+
+         if cache_key:
+             self._dump_cache(cache_key, result)
+
+         return result
+
+     # --------------------------------------------------------------------- #
+     # legacy → typed messages
+     # --------------------------------------------------------------------- #
+     @staticmethod
+     def _convert_messages(msgs: LegacyMsgs) -> Messages:
+         converted: Messages = []
+         for msg in msgs:
+             role = msg["role"]
+             content = msg["content"]
+             if role == "user":
+                 converted.append(
+                     ChatCompletionUserMessageParam(role="user", content=content)
+                 )
+             elif role == "assistant":
+                 converted.append(
+                     ChatCompletionAssistantMessageParam(
+                         role="assistant", content=content
+                     )
+                 )
+             elif role == "system":
+                 converted.append(
+                     ChatCompletionSystemMessageParam(role="system", content=content)
+                 )
+             elif role == "tool":
+                 converted.append(
+                     ChatCompletionToolMessageParam(
+                         role="tool",
+                         content=content,
+                         tool_call_id=msg.get("tool_call_id") or "",  # str, never None
+                     )
+                 )
+             else:
+                 # fall back to raw dict for unknown roles
+                 converted.append({"role": role, "content": content})  # type: ignore[arg-type]
+         return converted
+
+     # --------------------------------------------------------------------- #
+     # final parse (needed for plain-text or cache hits only)
+     # --------------------------------------------------------------------- #
+     @staticmethod
+     def _parse_output(
+         raw: Any,
+         response_format: Union[type[str], Type[BaseModel]],
+     ) -> str | BaseModel:
+         if response_format is str:
+             return cast(str, raw)
+
+         # For the type-checker: we *know* it's a BaseModel subclass here.
+         model_cls = cast(Type[BaseModel], response_format)
+
+         if isinstance(raw, model_cls):
+             return raw
+         if isinstance(raw, dict):
+             return model_cls.model_validate(raw)
+         try:
+             data = json.loads(raw)
+         except Exception as exc:  # noqa: BLE001
+             raise ValueError(f"Model did not return JSON:\n---\n{raw}") from exc
+         return model_cls.model_validate(data)
+
+     # --------------------------------------------------------------------- #
+     # tiny disk cache
+     # --------------------------------------------------------------------- #
+     @staticmethod
+     def _cache_key(
+         messages: Any, kw: Any, response_format: Union[type[str], Type[BaseModel]]
+     ) -> str:
+         tag = response_format.__name__ if response_format is not str else "text"
+         blob = json.dumps([messages, kw, tag], sort_keys=True).encode()
+         return base64.urlsafe_b64encode(hashlib.sha256(blob).digest()).decode()[:22]
+
+     @staticmethod
+     def _cache_path(key: str) -> str:
+         return os.path.expanduser(f"~/.cache/lm/{key}.json")
+
+     def _dump_cache(self, key: str, val: Any) -> None:
+         try:
+             path = self._cache_path(key)
+             os.makedirs(os.path.dirname(path), exist_ok=True)
+             with open(path, "w") as fh:
+                 if isinstance(val, BaseModel):
+                     json.dump(val.model_dump(mode="json"), fh)
+                 else:
+                     json.dump(val, fh)
+         except Exception as exc:  # pragma: no cover
+             logger.debug(f"cache write skipped: {exc}")
+
+     def _load_cache(self, key: str) -> Any | None:
+         path = self._cache_path(key)
+         if not os.path.exists(path):
+             return None
+         try:
+             with open(path) as fh:
+                 return json.load(fh)
+         except Exception:  # pragma: no cover
+             return None
+
+     @staticmethod
+     def list_models(port=None) -> List[str]:
+         """
+         List available models.
+         """
+         try:
+             client: OpenAI = LM(port=port).client
+             base_url: URL = client.base_url
+             logger.debug(f"Base URL: {base_url}")
+             models: SyncPage[Model] = client.models.list()
+             return [model.id for model in models.data]
+         except Exception as exc:
+             logger.error(f"Failed to list models: {exc}")
+             return []
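
Usage sketch (editorial note, not part of the package diff). A minimal example of the wrapper above, assuming an OpenAI-compatible server is reachable on localhost port 8000 and serves a model named "my-model" — both names are placeholders, and the import of `LM` is left implicit since the first file's path is not shown in this diff. The default `response_format=str` returns plain text, a Pydantic class returns a parsed instance, and results are memoised under ~/.cache/lm/ unless `cache=False` is passed.

from pydantic import BaseModel


class Sentiment(BaseModel):
    label: str
    score: float


lm = LM(model="my-model", port=8000)  # placeholder model name / local port

reply = lm(prompt="Say hello in one word.")  # -> str

parsed = lm(
    prompt="Classify the sentiment of: 'great library!'",
    response_format=Sentiment,  # -> Sentiment instance
)

fresh = lm(prompt="Say hello again.", cache=False)  # bypass the disk cache
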
llm_utils/lm/utils.py ADDED
@@ -0,0 +1,130 @@
+ import fcntl
+ import os
+ import tempfile
+ import time
+ from typing import List, Dict
+ import numpy as np
+ from loguru import logger
+
+
+ def _clear_port_use(ports):
+     for port in ports:
+         file_counter: str = f"/tmp/port_use_counter_{port}.npy"
+         if os.path.exists(file_counter):
+             os.remove(file_counter)
+
+
+ def _atomic_save(array: np.ndarray, filename: str):
+     tmp_dir = os.path.dirname(filename) or "."
+     with tempfile.NamedTemporaryFile(dir=tmp_dir, delete=False) as tmp:
+         np.save(tmp, array)
+         temp_name = tmp.name
+     os.replace(temp_name, filename)
+
+
+ def _update_port_use(port: int, increment: int) -> None:
+     file_counter: str = f"/tmp/port_use_counter_{port}.npy"
+     file_counter_lock: str = f"/tmp/port_use_counter_{port}.lock"
+     with open(file_counter_lock, "w") as lock_file:
+         fcntl.flock(lock_file, fcntl.LOCK_EX)
+         try:
+             if os.path.exists(file_counter):
+                 try:
+                     counter = np.load(file_counter)
+                 except Exception as e:
+                     logger.warning(f"Corrupted usage file {file_counter}: {e}")
+                     counter = np.array([0])
+             else:
+                 counter: np.ndarray = np.array([0], dtype=np.int64)
+             counter[0] += increment
+             _atomic_save(counter, file_counter)
+         finally:
+             fcntl.flock(lock_file, fcntl.LOCK_UN)
+
+
+ def _pick_least_used_port(ports: List[int]) -> int:
+     global_lock_file = "/tmp/ports.lock"
+     with open(global_lock_file, "w") as lock_file:
+         fcntl.flock(lock_file, fcntl.LOCK_EX)
+         try:
+             port_use: Dict[int, int] = {}
+             for port in ports:
+                 file_counter = f"/tmp/port_use_counter_{port}.npy"
+                 if os.path.exists(file_counter):
+                     try:
+                         counter = np.load(file_counter)
+                     except Exception as e:
+                         logger.warning(f"Corrupted usage file {file_counter}: {e}")
+                         counter = np.array([0])
+                 else:
+                     counter = np.array([0])
+                 port_use[port] = counter[0]
+             if not port_use:
+                 if ports:
+                     raise ValueError("Port usage data is empty, cannot pick a port.")
+                 else:
+                     raise ValueError("No ports provided to pick from.")
+             lsp = min(port_use, key=lambda k: port_use[k])
+             _update_port_use(lsp, 1)
+         finally:
+             fcntl.flock(lock_file, fcntl.LOCK_UN)
+     return lsp
+
+
+ def retry_on_exception(max_retries=10, exceptions=(Exception,), sleep_time=3):
+     def decorator(func):
+         from functools import wraps
+
+         def wrapper(self, *args, **kwargs):
+             retry_count = kwargs.get("retry_count", 0)
+             last_exception = None
+             while retry_count <= max_retries:
+                 try:
+                     return func(self, *args, **kwargs)
+                 except exceptions as e:
+                     import litellm
+
+                     if isinstance(
+                         e, (litellm.exceptions.APIError, litellm.exceptions.Timeout)
+                     ):
+                         base_url_info = kwargs.get(
+                             "base_url", getattr(self, "base_url", None)
+                         )
+                         logger.warning(
+                             f"[{base_url_info=}] {type(e).__name__}: {str(e)[:100]}, will sleep for {sleep_time}s and retry"
+                         )
+                         time.sleep(sleep_time)
+                         retry_count += 1
+                         kwargs["retry_count"] = retry_count
+                         last_exception = e
+                         continue
+                     elif hasattr(
+                         litellm.exceptions, "ContextWindowExceededError"
+                     ) and isinstance(e, litellm.exceptions.ContextWindowExceededError):
+                         logger.error(f"Context window exceeded: {e}")
+                         raise
+                     else:
+                         logger.error(f"Generic error during LLM call: {e}")
+                         import traceback
+
+                         traceback.print_exc()
+                         raise
+             logger.error(f"Retry limit exceeded, error: {last_exception}")
+             if last_exception:
+                 raise last_exception
+             raise ValueError("Retry limit exceeded with no specific error.")
+
+         return wraps(func)(wrapper)
+
+     return decorator
+
+
+ def forward_only(func):
+     from functools import wraps
+
+     @wraps(func)
+     def wrapper(self, *args, **kwargs):
+         kwargs["retry_count"] = 0
+         return func(self, *args, **kwargs)
+
+     return wrapper
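
Illustrative sketch (editorial note, not from the package). One way the helpers in llm_utils/lm/utils.py might be combined: `_pick_least_used_port` balances calls across several local server ports via the /tmp usage counters, and `retry_on_exception` retries a method when litellm raises a transient APIError or Timeout. The `MyClient` class, the port numbers, and the request body are hypothetical placeholders.

from llm_utils.lm.utils import _pick_least_used_port, retry_on_exception


class MyClient:  # hypothetical caller, not part of the package
    def __init__(self, ports):
        self.ports = ports
        self.base_url = None

    @retry_on_exception(max_retries=3, sleep_time=1)
    def complete(self, prompt, **kwargs):
        # pick the port whose /tmp counter is lowest, then bump its counter
        port = _pick_least_used_port(self.ports)
        self.base_url = f"http://localhost:{port}/v1"
        # ... issue the actual request against self.base_url here ...
        return f"handled on port {port}"


client = MyClient(ports=[8000, 8001, 8002])
print(client.complete("hello"))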