speedy-utils 1.0.11__tar.gz → 1.0.13__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {speedy_utils-1.0.11 → speedy_utils-1.0.13}/PKG-INFO +1 -1
- {speedy_utils-1.0.11 → speedy_utils-1.0.13}/pyproject.toml +2 -2
- {speedy_utils-1.0.11 → speedy_utils-1.0.13}/src/llm_utils/__init__.py +4 -1
- {speedy_utils-1.0.11 → speedy_utils-1.0.13}/src/llm_utils/lm/__init__.py +2 -1
- speedy_utils-1.0.13/src/llm_utils/lm/alm.py +447 -0
- {speedy_utils-1.0.11 → speedy_utils-1.0.13}/src/llm_utils/lm/lm.py +138 -44
- {speedy_utils-1.0.11 → speedy_utils-1.0.13}/src/llm_utils/scripts/vllm_load_balancer.py +7 -6
- {speedy_utils-1.0.11 → speedy_utils-1.0.13}/src/llm_utils/scripts/vllm_serve.py +4 -4
- {speedy_utils-1.0.11 → speedy_utils-1.0.13}/README.md +0 -0
- {speedy_utils-1.0.11 → speedy_utils-1.0.13}/src/llm_utils/chat_format/__init__.py +0 -0
- {speedy_utils-1.0.11 → speedy_utils-1.0.13}/src/llm_utils/chat_format/display.py +0 -0
- {speedy_utils-1.0.11 → speedy_utils-1.0.13}/src/llm_utils/chat_format/transform.py +0 -0
- {speedy_utils-1.0.11 → speedy_utils-1.0.13}/src/llm_utils/chat_format/utils.py +0 -0
- {speedy_utils-1.0.11 → speedy_utils-1.0.13}/src/llm_utils/group_messages.py +0 -0
- {speedy_utils-1.0.11 → speedy_utils-1.0.13}/src/llm_utils/lm/utils.py +0 -0
- {speedy_utils-1.0.11 → speedy_utils-1.0.13}/src/speedy_utils/__init__.py +0 -0
- {speedy_utils-1.0.11 → speedy_utils-1.0.13}/src/speedy_utils/all.py +0 -0
- {speedy_utils-1.0.11 → speedy_utils-1.0.13}/src/speedy_utils/common/__init__.py +0 -0
- {speedy_utils-1.0.11 → speedy_utils-1.0.13}/src/speedy_utils/common/clock.py +0 -0
- {speedy_utils-1.0.11 → speedy_utils-1.0.13}/src/speedy_utils/common/function_decorator.py +0 -0
- {speedy_utils-1.0.11 → speedy_utils-1.0.13}/src/speedy_utils/common/logger.py +0 -0
- {speedy_utils-1.0.11 → speedy_utils-1.0.13}/src/speedy_utils/common/report_manager.py +0 -0
- {speedy_utils-1.0.11 → speedy_utils-1.0.13}/src/speedy_utils/common/utils_cache.py +0 -0
- {speedy_utils-1.0.11 → speedy_utils-1.0.13}/src/speedy_utils/common/utils_io.py +0 -0
- {speedy_utils-1.0.11 → speedy_utils-1.0.13}/src/speedy_utils/common/utils_misc.py +0 -0
- {speedy_utils-1.0.11 → speedy_utils-1.0.13}/src/speedy_utils/common/utils_print.py +0 -0
- {speedy_utils-1.0.11 → speedy_utils-1.0.13}/src/speedy_utils/multi_worker/__init__.py +0 -0
- {speedy_utils-1.0.11 → speedy_utils-1.0.13}/src/speedy_utils/multi_worker/process.py +0 -0
- {speedy_utils-1.0.11 → speedy_utils-1.0.13}/src/speedy_utils/multi_worker/thread.py +0 -0
- {speedy_utils-1.0.11 → speedy_utils-1.0.13}/src/speedy_utils/scripts/mpython.py +0 -0
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -1,6 +1,6 @@
 [tool.poetry]
 name = "speedy-utils"
-version = "1.0.11"
+version = "1.0.13"
 description = "Fast and easy-to-use package for data science"
 authors = ["AnhVTH <anhvth.226@gmail.com>"]
 readme = "README.md"
@@ -11,7 +11,7 @@ packages = [
 ]

 [build-system]
-requires = ["poetry-core>=1.0.
+requires = ["poetry-core>=1.0.13"]
 build-backend = "poetry.core.masonry.api"

 [tool.black]
--- a/src/llm_utils/__init__.py
+++ b/src/llm_utils/__init__.py
@@ -9,7 +9,8 @@ from .chat_format import (
     format_msgs,
     display_chat_messages_as_html,
 )
-from .lm import LM
+from .lm.lm import LM, LMReasoner
+from .lm.alm import AsyncLM
 from .group_messages import (
     split_indices_by_length,
     group_messages_by_len,
@@ -27,5 +28,7 @@ __all__ = [
     "split_indices_by_length",
     "group_messages_by_len",
     "LM",
+    "LMReasoner",
+    "AsyncLM",
     "display_chat_messages_as_html",
 ]
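With this change the sync wrapper, the regex reasoning wrapper, and the new async wrapper are all re-exported from the package root. A minimal, illustrative sketch of the resulting import surface (nothing here beyond the names added to `__all__` above):

# Illustrative sketch of the re-exported names after this change.
from llm_utils import LM, LMReasoner, AsyncLM

print(LM, LMReasoner, AsyncLM)  # sync wrapper, regex reasoning wrapper, async wrapper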
--- /dev/null
+++ b/src/llm_utils/lm/alm.py
@@ -0,0 +1,447 @@
+from __future__ import annotations
+
+"""An **asynchronous** drop‑in replacement for the original `LM` class.
+
+Usage example (Python ≥3.8):
+
+    from async_lm import AsyncLM
+    import asyncio
+
+    async def main():
+        lm = AsyncLM(model="gpt-4o-mini")
+        reply: str = await lm(prompt="Hello, world!")
+        print(reply)
+
+    asyncio.run(main())
+"""
+
+import asyncio
+import base64
+import hashlib
+import json
+import os
+from typing import (
+    Any,
+    Dict,
+    List,
+    Optional,
+    Sequence,
+    Type,
+    TypeVar,
+    Union,
+    overload,
+    cast,
+)
+
+from httpx import URL
+from openai import AsyncOpenAI, AuthenticationError, RateLimitError
+
+# from openai.pagination import AsyncSyncPage
+from openai.types.chat import (
+    ChatCompletionAssistantMessageParam,
+    ChatCompletionMessageParam,
+    ChatCompletionSystemMessageParam,
+    ChatCompletionToolMessageParam,
+    ChatCompletionUserMessageParam,
+)
+from openai.types.chat.parsed_chat_completion import ParsedChatCompletion
+from openai.types.model import Model
+from pydantic import BaseModel
+from loguru import logger
+from openai.pagination import AsyncPage as AsyncSyncPage
+
+# --------------------------------------------------------------------------- #
+# type helpers
+# --------------------------------------------------------------------------- #
+TModel = TypeVar("TModel", bound=BaseModel)
+Messages = List[ChatCompletionMessageParam]
+LegacyMsgs = List[Dict[str, str]]
+RawMsgs = Union[Messages, LegacyMsgs]
+
+# --------------------------------------------------------------------------- #
+# color helpers (unchanged)
+# --------------------------------------------------------------------------- #
+
+
+def _color(code: int, text: str) -> str:
+    return f"\x1b[{code}m{text}\x1b[0m"
+
+
+_red = lambda t: _color(31, t)
+_green = lambda t: _color(32, t)
+_blue = lambda t: _color(34, t)
+_yellow = lambda t: _color(33, t)
+
+
+class AsyncLM:
+    """Unified **async** language‑model wrapper with optional JSON parsing."""
+
+    def __init__(
+        self,
+        model: str | None = None,
+        *,
+        temperature: float = 0.0,
+        max_tokens: int = 2_000,
+        host: str = "localhost",
+        port: Optional[int | str] = None,
+        base_url: Optional[str] = None,
+        api_key: Optional[str] = None,
+        cache: bool = True,
+        ports: Optional[List[int]] = None,
+        **openai_kwargs: Any,
+    ) -> None:
+        self.model = model
+        self.temperature = temperature
+        self.max_tokens = max_tokens
+        self.port = port
+        self.host = host
+        self.base_url = base_url or (f"http://{host}:{port}/v1" if port else None)
+        self.api_key = api_key or os.getenv("OPENAI_API_KEY", "abc")
+        self.openai_kwargs = openai_kwargs
+        self.do_cache = cache
+        self.ports = ports
+
+    # Async client
+
+    @property
+    def client(self) -> AsyncOpenAI:
+        # if have multiple ports
+        if self.ports:
+            import random
+            port = random.choice(self.ports)
+            api_base = f"http://{self.host}:{port}/v1"
+            logger.debug(f"Using port: {port}")
+        else:
+            api_base = self.base_url or f"http://{self.host}:{self.port}/v1"
+        client = AsyncOpenAI(
+            api_key=self.api_key, base_url=api_base, **self.openai_kwargs
+        )
+        return client
+
+    # ------------------------------------------------------------------ #
+    # Public API – typed overloads
+    # ------------------------------------------------------------------ #
+    @overload
+    async def __call__(
+        self,
+        *,
+        prompt: str | None = ...,
+        messages: RawMsgs | None = ...,
+        response_format: type[str] = str,
+        return_openai_response: bool = ...,
+        **kwargs: Any,
+    ) -> str: ...
+
+    @overload
+    async def __call__(
+        self,
+        *,
+        prompt: str | None = ...,
+        messages: RawMsgs | None = ...,
+        response_format: Type[TModel],
+        return_openai_response: bool = ...,
+        **kwargs: Any,
+    ) -> TModel: ...
+
+    async def __call__(
+        self,
+        prompt: Optional[str] = None,
+        messages: Optional[RawMsgs] = None,
+        response_format: Union[type[str], Type[BaseModel]] = str,
+        cache: Optional[bool] = None,
+        max_tokens: Optional[int] = None,
+        return_openai_response: bool = False,
+        **kwargs: Any,
+    ):
+        if (prompt is None) == (messages is None):
+            raise ValueError("Provide *either* `prompt` or `messages` (but not both).")
+
+        if prompt is not None:
+            messages = [{"role": "user", "content": prompt}]
+
+        assert messages is not None
+        # assert self.model is not None, "Model must be set before calling."
+        if not self.model:
+            models = await self.list_models(port=self.port, host=self.host)
+            self.model = models[0] if models else None
+            logger.info(
+                f"No model specified. Using the first available model. {self.model}"
+            )
+        openai_msgs: Messages = (
+            self._convert_messages(cast(LegacyMsgs, messages))
+            if isinstance(messages[0], dict)
+            else cast(Messages, messages)
+        )
+
+        kw = dict(
+            self.openai_kwargs,
+            temperature=self.temperature,
+            max_tokens=max_tokens or self.max_tokens,
+        )
+        kw.update(kwargs)
+        use_cache = self.do_cache if cache is None else cache
+
+        raw_response = await self._call_raw(
+            openai_msgs,
+            response_format=response_format,
+            use_cache=use_cache,
+            **kw,
+        )
+
+        if return_openai_response:
+            response = raw_response
+        else:
+            response = self._parse_output(raw_response, response_format)
+
+        self.last_log = [prompt, messages, raw_response]
+        return response
+
+    # ------------------------------------------------------------------ #
+    # Model invocation (async)
+    # ------------------------------------------------------------------ #
+    async def _call_raw(
+        self,
+        messages: Sequence[ChatCompletionMessageParam],
+        response_format: Union[type[str], Type[BaseModel]],
+        use_cache: bool,
+        **kw: Any,
+    ):
+        assert self.model is not None, "Model must be set before making a call."
+        model: str = self.model
+
+        cache_key = (
+            self._cache_key(messages, kw, response_format) if use_cache else None
+        )
+        if cache_key and (hit := self._load_cache(cache_key)) is not None:
+            return hit
+
+        try:
+            if response_format is not str and issubclass(response_format, BaseModel):
+                openai_response = await self.client.beta.chat.completions.parse(
+                    model=model,
+                    messages=list(messages),
+                    response_format=response_format,  # type: ignore[arg-type]
+                    **kw,
+                )
+            else:
+                openai_response = await self.client.chat.completions.create(
+                    model=model,
+                    messages=list(messages),
+                    **kw,
+                )
+
+        except (AuthenticationError, RateLimitError) as exc:
+            logger.error(exc)
+            raise
+
+        if cache_key:
+            self._dump_cache(cache_key, openai_response)
+
+        return openai_response
+
+    # ------------------------------------------------------------------ #
+    # Utilities below are unchanged (sync I/O is acceptable)
+    # ------------------------------------------------------------------ #
+    @staticmethod
+    def _convert_messages(msgs: LegacyMsgs) -> Messages:
+        converted: Messages = []
+        for msg in msgs:
+            role = msg["role"]
+            content = msg["content"]
+            if role == "user":
+                converted.append(
+                    ChatCompletionUserMessageParam(role="user", content=content)
+                )
+            elif role == "assistant":
+                converted.append(
+                    ChatCompletionAssistantMessageParam(
+                        role="assistant", content=content
+                    )
+                )
+            elif role == "system":
+                converted.append(
+                    ChatCompletionSystemMessageParam(role="system", content=content)
+                )
+            elif role == "tool":
+                converted.append(
+                    ChatCompletionToolMessageParam(
+                        role="tool",
+                        content=content,
+                        tool_call_id=msg.get("tool_call_id") or "",
+                    )
+                )
+            else:
+                converted.append({"role": role, "content": content})  # type: ignore[arg-type]
+        return converted
+
+    @staticmethod
+    def _parse_output(
+        raw_response: Any, response_format: Union[type[str], Type[BaseModel]]
+    ) -> str | BaseModel:
+        if hasattr(raw_response, "model_dump"):
+            raw_response = raw_response.model_dump()
+
+        if response_format is str:
+            if isinstance(raw_response, dict) and "choices" in raw_response:
+                message = raw_response["choices"][0]["message"]
+                return message.get("content", "") or ""
+            return cast(str, raw_response)
+
+        model_cls = cast(Type[BaseModel], response_format)
+
+        if isinstance(raw_response, dict) and "choices" in raw_response:
+            message = raw_response["choices"][0]["message"]
+            if "parsed" in message:
+                return model_cls.model_validate(message["parsed"])
+            content = message.get("content")
+            if content is None:
+                raise ValueError("Model returned empty content")
+            try:
+                data = json.loads(content)
+                return model_cls.model_validate(data)
+            except Exception as exc:
+                raise ValueError(
+                    f"Failed to parse model output as JSON:\n{content}"
+                ) from exc
+
+        if isinstance(raw_response, model_cls):
+            return raw_response
+        if isinstance(raw_response, dict):
+            return model_cls.model_validate(raw_response)
+
+        try:
+            data = json.loads(raw_response)
+            return model_cls.model_validate(data)
+        except Exception as exc:
+            raise ValueError(
+                f"Model did not return valid JSON:\n---\n{raw_response}"
+            ) from exc
+
+    # ------------------------------------------------------------------ #
+    # Simple disk cache (sync)
+    # ------------------------------------------------------------------ #
+    @staticmethod
+    def _cache_key(
+        messages: Any, kw: Any, response_format: Union[type[str], Type[BaseModel]]
+    ) -> str:
+        tag = response_format.__name__ if response_format is not str else "text"
+        blob = json.dumps([messages, kw, tag], sort_keys=True).encode()
+        return base64.urlsafe_b64encode(hashlib.sha256(blob).digest()).decode()[:22]
+
+    @staticmethod
+    def _cache_path(key: str) -> str:
+        return os.path.expanduser(f"~/.cache/lm/{key}.json")
+
+    def _dump_cache(self, key: str, val: Any) -> None:
+        try:
+            path = self._cache_path(key)
+            os.makedirs(os.path.dirname(path), exist_ok=True)
+            with open(path, "w") as fh:
+                if isinstance(val, BaseModel):
+                    json.dump(val.model_dump(mode="json"), fh)
+                else:
+                    json.dump(val, fh)
+        except Exception as exc:
+            logger.debug(f"cache write skipped: {exc}")
+
+    def _load_cache(self, key: str) -> Any | None:
+        path = self._cache_path(key)
+        if not os.path.exists(path):
+            return None
+        try:
+            with open(path) as fh:
+                return json.load(fh)
+        except Exception:
+            return None
+
+    # ------------------------------------------------------------------ #
+    # Utility helpers
+    # ------------------------------------------------------------------ #
+    async def inspect_history(self) -> None:
+        if not hasattr(self, "last_log"):
+            raise ValueError("No history available. Please call the model first.")
+
+        prompt, messages, response = self.last_log
+        if hasattr(response, "model_dump"):
+            response = response.model_dump()
+        if not messages:
+            messages = [{"role": "user", "content": prompt}]
+
+        print("\n\n")
+        print(_blue("[Conversation History]") + "\n")
+
+        for msg in messages:
+            role = msg["role"]
+            content = msg["content"]
+            print(_red(f"{role.capitalize()}:"))
+            if isinstance(content, str):
+                print(content.strip())
+            elif isinstance(content, list):
+                for item in content:
+                    if item.get("type") == "text":
+                        print(item["text"].strip())
+                    elif item.get("type") == "image_url":
+                        image_url = item["image_url"]["url"]
+                        if "base64" in image_url:
+                            len_base64 = len(image_url.split("base64,")[1])
+                            print(_blue(f"<IMAGE BASE64 ENCODED({len_base64})>"))
+                        else:
+                            print(_blue(f"<image_url: {image_url}>"))
+            print("\n")
+
+        print(_red("Response:"))
+        if isinstance(response, dict) and response.get("choices"):
+            message = response["choices"][0].get("message", {})
+            reasoning = message.get("reasoning_content")
+            parsed = message.get("parsed")
+            content = message.get("content")
+            if reasoning:
+                print(_yellow("<think>"))
+                print(reasoning.strip())
+                print(_yellow("</think>\n"))
+            if parsed:
+                print(
+                    json.dumps(
+                        (
+                            parsed.model_dump()
+                            if hasattr(parsed, "model_dump")
+                            else parsed
+                        ),
+                        indent=2,
+                    )
+                    + "\n"
+                )
+            elif content:
+                print(content.strip())
+            else:
+                print(_green("[No content]"))
+            if len(response["choices"]) > 1:
+                print(
+                    _blue(f"\n(Plus {len(response['choices']) - 1} other completions)")
+                )
+        else:
+            print(_yellow("Warning: Not a standard OpenAI response object"))
+            if isinstance(response, str):
+                print(_green(response.strip()))
+            elif isinstance(response, dict):
+                print(_green(json.dumps(response, indent=2)))
+            else:
+                print(_green(str(response)))
+
+    # ------------------------------------------------------------------ #
+    # Misc helpers
+    # ------------------------------------------------------------------ #
+    def set_model(self, model: str) -> None:
+        self.model = model
+
+    @staticmethod
+    async def list_models(port=None, host="localhost") -> List[str]:
+        try:
+            client: AsyncOpenAI = AsyncLM(port=port, host=host).client  # type: ignore[arg-type]
+            base_url: URL = client.base_url
+            logger.debug(f"Base URL: {base_url}")
+            models: AsyncSyncPage[Model] = await client.models.list()  # type: ignore[assignment]
+            return [model.id for model in models.data]
+        except Exception as exc:
+            logger.error(f"Failed to list models: {exc}")
+            return []
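`AsyncLM.__call__` accepts either `prompt` or `messages` plus an optional Pydantic `response_format`; with a model class it routes through `client.beta.chat.completions.parse` and returns a validated instance. A hedged usage sketch, mirroring the docstring and overloads above (the model name, the schema, and the availability of an OpenAI-compatible endpoint with a valid key are assumptions):

import asyncio
from pydantic import BaseModel
from llm_utils import AsyncLM


class City(BaseModel):  # hypothetical schema for illustration
    name: str
    country: str


async def demo() -> None:
    # Assumes an OpenAI-compatible endpoint and OPENAI_API_KEY (or base_url) are configured.
    lm = AsyncLM(model="gpt-4o-mini", cache=False)
    text = await lm(prompt="Name one city in Vietnam.")  # plain str result
    city = await lm(
        prompt="Return one Vietnamese city as JSON.",
        response_format=City,  # parsed into a City instance
    )
    print(text, city.name, city.country)


asyncio.run(demo())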
--- a/src/llm_utils/lm/lm.py
+++ b/src/llm_utils/lm/lm.py
@@ -4,6 +4,7 @@ import base64
 import hashlib
 import json
 import os
+from token import OP
 from typing import (
     Any,
     Dict,
@@ -98,6 +99,7 @@ class LM:
         self.api_key = api_key or os.getenv("OPENAI_API_KEY", "abc")
         self.openai_kwargs = openai_kwargs
         self.do_cache = cache
+        self._init_port = port  # <-- store the port provided at init

         self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)

@@ -149,7 +151,20 @@ class LM:
             messages = [{"role": "user", "content": prompt}]

         assert messages is not None  # for type-checker
-
+
+        # If model is not specified, but port is provided, use the first available model
+        if self.model is None:
+            port = self._init_port
+            if port:
+                available_models = self.list_models(port=port)
+                if available_models:
+                    self.model = available_models[0]
+                    logger.info(f"Auto-selected model: {self.model}")
+                else:
+                    raise ValueError("No models available to select from.")
+            else:
+                raise AssertionError("Model must be set before calling.")
+
         openai_msgs: Messages = (
             self._convert_messages(cast(LegacyMsgs, messages))
             if isinstance(messages[0], dict)  # legacy style
@@ -170,7 +185,7 @@ class LM:
             use_cache=use_cache,
             **kw,
         )
-
+
         if return_openai_response:
             response = raw_response
         else:
@@ -182,24 +197,24 @@ class LM:
     def inspect_history(self) -> None:
         if not hasattr(self, "last_log"):
             raise ValueError("No history available. Please call the model first.")
-
+
         prompt, messages, response = self.last_log
         # Ensure response is a dictionary
         if hasattr(response, "model_dump"):
             response = response.model_dump()
-
+
         if not messages:
             messages = [{"role": "user", "content": prompt}]
-
+
         print("\n\n")
         print(_blue("[Conversation History]") + "\n")
-
+
         # Print all messages in the conversation
         for msg in messages:
             role = msg["role"]
             content = msg["content"]
             print(_red(f"{role.capitalize()}:"))
-
+
             if isinstance(content, str):
                 print(content.strip())
             elif isinstance(content, list):
@@ -215,40 +230,40 @@ class LM:
                         else:
                             print(_blue(f"<image_url: {image_url}>"))
             print("\n")
-
+
         # Print the response - now always an OpenAI completion
         print(_red("Response:"))
-
+
         # Handle OpenAI response object
-        if isinstance(response, dict) and
-            message = response[
+        if isinstance(response, dict) and "choices" in response and response["choices"]:
+            message = response["choices"][0].get("message", {})

             # Check for reasoning content (if available)
-            reasoning = message.get(
+            reasoning = message.get("reasoning_content")

             # Check for parsed content (structured mode)
-            parsed = message.get(
+            parsed = message.get("parsed")

             # Get regular content
-            content = message.get(
+            content = message.get("content")

             # Display reasoning if available
             if reasoning:
-                print(_yellow(
+                print(_yellow("<think>"))
                 print(reasoning.strip())
-                print(_yellow(
+                print(_yellow("</think>"))
                 print()

             # Display parsed content for structured responses
             if parsed:
                 # print(_green('<Parsed Structure>'))
-                if hasattr(parsed,
+                if hasattr(parsed, "model_dump"):
                     print(json.dumps(parsed.model_dump(), indent=2))
                 else:
                     print(json.dumps(parsed, indent=2))
                 # print(_green('</Parsed Structure>'))
                 print()
-
+
             else:
                 if content:
                     # print(_green("<Content>"))
@@ -256,10 +271,12 @@ class LM:
                     # print(_green("</Content>"))
                 else:
                     print(_green("[No content]"))
-
+
             # Show if there were multiple completions
-            if len(response[
-                print(
+            if len(response["choices"]) > 1:
+                print(
+                    _blue(f"\n(Plus {len(response['choices']) - 1} other completions)")
+                )
         else:
             # Fallback for non-standard response objects or cached responses
             print(_yellow("Warning: Not a standard OpenAI response object"))
@@ -269,7 +286,7 @@ class LM:
                 print(_green(json.dumps(response, indent=2)))
             else:
                 print(_green(str(response)))
-
+
         # print("\n\n")

     # --------------------------------------------------------------------- #
@@ -286,9 +303,7 @@ class LM:
         model: str = self.model

         cache_key = (
-            self._cache_key(messages, kw, response_format)
-            if use_cache
-            else None
+            self._cache_key(messages, kw, response_format) if use_cache else None
         )
         if cache_key and (hit := self._load_cache(cache_key)) is not None:
             return hit
@@ -364,50 +379,54 @@ class LM:
         response_format: Union[type[str], Type[BaseModel]],
     ) -> str | BaseModel:
         # Convert any object to dict if needed
-        if hasattr(raw_response,
+        if hasattr(raw_response, "model_dump"):
             raw_response = raw_response.model_dump()
-
+
         if response_format is str:
             # Extract the content from OpenAI response dict
-            if isinstance(raw_response, dict) and
-                message = raw_response[
-                return message.get(
+            if isinstance(raw_response, dict) and "choices" in raw_response:
+                message = raw_response["choices"][0]["message"]
+                return message.get("content", "") or ""
             return cast(str, raw_response)
-
+
         # For the type-checker: we *know* it's a BaseModel subclass here.
         model_cls = cast(Type[BaseModel], response_format)

         # Handle structured response
-        if isinstance(raw_response, dict) and
-            message = raw_response[
-
+        if isinstance(raw_response, dict) and "choices" in raw_response:
+            message = raw_response["choices"][0]["message"]
+
             # Check if already parsed by OpenAI client
-            if
-                return model_cls.model_validate(message[
-
+            if "parsed" in message:
+                return model_cls.model_validate(message["parsed"])
+
             # Need to parse the content
-            content = message.get(
+            content = message.get("content")
             if content is None:
                 raise ValueError("Model returned empty content")
-
+
             try:
                 data = json.loads(content)
                 return model_cls.model_validate(data)
             except Exception as exc:
-                raise ValueError(
-
+                raise ValueError(
+                    f"Failed to parse model output as JSON:\n{content}"
+                ) from exc
+
         # Handle cached response or other formats
         if isinstance(raw_response, model_cls):
             return raw_response
         if isinstance(raw_response, dict):
             return model_cls.model_validate(raw_response)
-
+
         # Try parsing as JSON string
         try:
             data = json.loads(raw_response)
             return model_cls.model_validate(data)
         except Exception as exc:
-            raise ValueError(
+            raise ValueError(
+                f"Model did not return valid JSON:\n---\n{raw_response}"
+            ) from exc

     # --------------------------------------------------------------------- #
     # tiny disk cache
@@ -421,7 +440,7 @@ class LM:
         tag = response_format.__name__ if response_format is not str else "text"
         blob = json.dumps([messages, kw, tag], sort_keys=True).encode()
         return base64.urlsafe_b64encode(hashlib.sha256(blob).digest()).decode()[:22]
-
+
     @staticmethod
     def _cache_path(key: str) -> str:
         return os.path.expanduser(f"~/.cache/lm/{key}.json")
@@ -462,3 +481,78 @@ class LM:
         except Exception as exc:
             logger.error(f"Failed to list models: {exc}")
             return []
+
+
+from functools import cache
+from llm_utils.lm.lm import LM, RawMsgs
+from pydantic import BaseModel
+import re
+import json
+from typing import *
+import re
+
+
+class LMReasoner(LM):
+    "Regex-based reasoning wrapper for LM."
+
+    def build_regex_from_pydantic(self, model: type[BaseModel]) -> str:
+        """
+        Build a regex pattern string for validating output that should match a Pydantic model.
+
+        Args:
+            model: A Pydantic BaseModel class
+
+        Returns:
+            A regex string that matches a JSON representation of the model
+        """
+        # regex = f"<think>\\n.*?\\n</think>\\n\\n\\```json\\n.*"
+        print(f"{regex=}")
+
+        return regex
+
+    def __call__(
+        self,
+        response_format: type[BaseModel],
+        prompt: Optional[str] = None,
+        messages: Optional[RawMsgs] = None,
+        **kwargs,
+    ):
+
+        if prompt is not None:
+            output = super().__call__(
+                prompt=prompt
+                + "\nresponse_format:"
+                + str(response_format.model_json_schema()),
+                response_format=str,
+                # extra_body={"guided_regex": regex},
+                **kwargs,
+            )  # type: ignore
+        elif messages is not None:
+            # append last message with the json schema
+            messages[-1]["content"] += "\nresponse_format:" + str(  # type: ignore
+                response_format.model_json_schema()
+            )
+            output = super().__call__(
+                messages=messages,
+                response_format=str,
+                # extra_body={"guided_regex": regex},
+                **kwargs,
+            )
+        else:
+            raise ValueError("Either prompt or messages must be provided.")
+        # import ipdb; ipdb.set_trace()
+        # parse using regex
+        pattern = re.compile(
+            r"<think>\n(?P<think>.*?)\n</think>\n\n(?P<json>\{.*\})",
+            re.DOTALL,
+        )
+        match = pattern.search(output)
+        if not match:
+            raise ValueError("Output does not match expected format")
+        parsed_output = match.group(0)
+        think_part = match.group("think")
+        json_part = match.group("json")
+
+        pydantic_object = response_format.model_validate(json.loads(json_part))
+        return pydantic_object
+
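The new `LMReasoner` appends the Pydantic JSON schema to the prompt, requests a plain-string completion, and regex-parses a `<think>…</think>` block followed by a JSON object. A hedged sketch of driving it (model name, port, and schema are assumptions; the call only succeeds if the served model actually emits that `<think>` format):

from pydantic import BaseModel
from llm_utils import LMReasoner


class Answer(BaseModel):  # hypothetical schema for illustration
    final_answer: str


# Assumes LM's constructor accepts model= and port=, as the stored attributes in the diff suggest.
reasoner = LMReasoner(model="my-reasoning-model", port=8140)
result = reasoner(response_format=Answer, prompt="What is 2 + 2?")
print(result.final_answer)
# Expected raw completion shape:
# <think>\n...reasoning...\n</think>\n\n{"final_answer": "4"}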
--- a/src/llm_utils/scripts/vllm_load_balancer.py
+++ b/src/llm_utils/scripts/vllm_load_balancer.py
@@ -5,18 +5,19 @@ import time
 from tabulate import tabulate
 import contextlib
 import aiohttp  # <-- Import aiohttp
+from speedy_utils import setup_logger
 from loguru import logger
-
+setup_logger(min_interval=5)
 # --- Configuration ---
 LOAD_BALANCER_HOST = "0.0.0.0"
 LOAD_BALANCER_PORT = 8008

 SCAN_TARGET_HOST = "localhost"
-SCAN_PORT_START =
+SCAN_PORT_START = 8140
 SCAN_PORT_END = 8170  # Inclusive
 SCAN_INTERVAL = 30
 # Timeout applies to the HTTP health check request now
-HEALTH_CHECK_TIMEOUT = 2
+HEALTH_CHECK_TIMEOUT = 2  # Increased slightly for HTTP requests

 STATUS_PRINT_INTERVAL = 5
 BUFFER_SIZE = 4096
@@ -83,14 +84,14 @@ async def check_server_health(session, host, port):
             # Check for a successful status code (2xx range)
             if 200 <= response.status < 300:
                 logger.debug(
-                    f"Health check success for {url} (Status: {response.status})"
+                    f"[{LOAD_BALANCER_PORT=}] Health check success for {url} (Status: {response.status})"
                 )
                 # Ensure the connection is released back to the pool
                 await response.release()
                 return True
             else:
                 logger.debug(
-                    f"Health check failed for {url} (Status: {response.status})"
+                    f"[{LOAD_BALANCER_PORT=}] Health check failed for {url} (Status: {response.status})"
                 )
                 await response.release()
                 return False
@@ -180,7 +181,7 @@ async def scan_and_update_servers():
                 if server not in connection_counts:
                     connection_counts[server] = 0

-            logger.debug(f"Scan complete. Active servers: {available_servers}")
+            logger.debug(f"[{LOAD_BALANCER_PORT=}]Scan complete. Active servers: {available_servers}")

     except asyncio.CancelledError:
         logger.info("Server scan task cancelled.")
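The new log prefix uses the `{LOAD_BALANCER_PORT=}` f-string form, which renders as `LOAD_BALANCER_PORT=8008` inside the message. A standalone sketch of the same aiohttp health-check pattern (the `/health` path and overall structure are assumptions, not the script's literal code):

import asyncio
import aiohttp

LOAD_BALANCER_PORT = 8008  # same constant as in the script


async def is_healthy(host: str, port: int, timeout: float = 2) -> bool:
    url = f"http://{host}:{port}/health"  # assumed health endpoint
    try:
        async with aiohttp.ClientSession() as session:
            async with session.get(url, timeout=aiohttp.ClientTimeout(total=timeout)) as resp:
                print(f"[{LOAD_BALANCER_PORT=}] checked {url} -> {resp.status}")
                return 200 <= resp.status < 300
    except Exception:
        return False


print(asyncio.run(is_healthy("localhost", 8140)))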
--- a/src/llm_utils/scripts/vllm_serve.py
+++ b/src/llm_utils/scripts/vllm_serve.py
@@ -132,7 +132,7 @@ def serve(args) -> None:
         str(args.max_model_len),
         "--enable-prefix-caching",
         "--disable-log-requests",
-        "--uvicorn-log-level critical",
+        # "--uvicorn-log-level critical",
     ]
     if HF_HOME:
         cmd.insert(0, f"HF_HOME={HF_HOME}")
@@ -234,11 +234,11 @@ def get_args():
         "--max_model_len", "-mml", type=int, default=8192, help="Maximum model length"
     )
     parser.add_argument(
-        "--
+        "--enable_lora",
        dest="enable_lora",
-        action="
+        action="store_true",
         help="Disable LoRA support",
-        default=
+        default=False,
     )
     parser.add_argument("--bnb", action="store_true", help="Enable quantization")
     parser.add_argument(
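The argparse fix makes LoRA support an explicit opt-in flag. A self-contained sketch of how the corrected definition parses (it mirrors the `+` lines above; note the help string still reads "Disable LoRA support" even though the flag enables it):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument(
    "--enable_lora",
    dest="enable_lora",
    action="store_true",
    help="Disable LoRA support",  # help text kept exactly as in the diff
    default=False,
)

print(parser.parse_args([]).enable_lora)                  # False
print(parser.parse_args(["--enable_lora"]).enable_lora)   # True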
File without changes: the remaining 22 files listed above with +0 -0.