speedy-utils 1.0.9__py3-none-any.whl → 1.0.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_utils/__init__.py +4 -1
- llm_utils/lm/__init__.py +2 -1
- llm_utils/lm/alm.py +447 -0
- llm_utils/lm/lm.py +282 -28
- llm_utils/scripts/vllm_load_balancer.py +7 -6
- llm_utils/scripts/vllm_serve.py +66 -136
- {speedy_utils-1.0.9.dist-info → speedy_utils-1.0.12.dist-info}/METADATA +1 -1
- {speedy_utils-1.0.9.dist-info → speedy_utils-1.0.12.dist-info}/RECORD +10 -9
- {speedy_utils-1.0.9.dist-info → speedy_utils-1.0.12.dist-info}/WHEEL +0 -0
- {speedy_utils-1.0.9.dist-info → speedy_utils-1.0.12.dist-info}/entry_points.txt +0 -0
llm_utils/__init__.py
CHANGED

@@ -9,7 +9,8 @@ from .chat_format import (
     format_msgs,
     display_chat_messages_as_html,
 )
-from .lm import LM
+from .lm.lm import LM, LMReasoner
+from .lm.alm import AsyncLM
 from .group_messages import (
     split_indices_by_length,
     group_messages_by_len,
@@ -27,5 +28,7 @@ __all__ = [
     "split_indices_by_length",
     "group_messages_by_len",
     "LM",
+    "LMReasoner",
+    "AsyncLM",
     "display_chat_messages_as_html",
 ]
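
The re-export change above means all three wrappers are now importable from the package root. A minimal sketch of the updated import surface, assuming an OpenAI-compatible endpoint is configured; the model name below is a placeholder, not something this diff prescribes:

    # Hypothetical usage of the 1.0.12 exports; "gpt-4o-mini" is a placeholder model name.
    from llm_utils import LM, LMReasoner, AsyncLM

    lm = LM(model="gpt-4o-mini")            # synchronous wrapper, same entry point as 1.0.9
    print(lm(prompt="Say hi in one word"))  # plain-text mode returns a str
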
llm_utils/lm/__init__.py
CHANGED
llm_utils/lm/alm.py
ADDED

@@ -0,0 +1,447 @@
+from __future__ import annotations
+
+"""An **asynchronous** drop‑in replacement for the original `LM` class.
+
+Usage example (Python ≥3.8):
+
+    from async_lm import AsyncLM
+    import asyncio
+
+    async def main():
+        lm = AsyncLM(model="gpt-4o-mini")
+        reply: str = await lm(prompt="Hello, world!")
+        print(reply)
+
+    asyncio.run(main())
+"""
+
+import asyncio
+import base64
+import hashlib
+import json
+import os
+from typing import (
+    Any,
+    Dict,
+    List,
+    Optional,
+    Sequence,
+    Type,
+    TypeVar,
+    Union,
+    overload,
+    cast,
+)
+
+from httpx import URL
+from openai import AsyncOpenAI, AuthenticationError, RateLimitError
+
+# from openai.pagination import AsyncSyncPage
+from openai.types.chat import (
+    ChatCompletionAssistantMessageParam,
+    ChatCompletionMessageParam,
+    ChatCompletionSystemMessageParam,
+    ChatCompletionToolMessageParam,
+    ChatCompletionUserMessageParam,
+)
+from openai.types.chat.parsed_chat_completion import ParsedChatCompletion
+from openai.types.model import Model
+from pydantic import BaseModel
+from loguru import logger
+from openai.pagination import AsyncPage as AsyncSyncPage
+
+# --------------------------------------------------------------------------- #
+# type helpers
+# --------------------------------------------------------------------------- #
+TModel = TypeVar("TModel", bound=BaseModel)
+Messages = List[ChatCompletionMessageParam]
+LegacyMsgs = List[Dict[str, str]]
+RawMsgs = Union[Messages, LegacyMsgs]
+
+# --------------------------------------------------------------------------- #
+# color helpers (unchanged)
+# --------------------------------------------------------------------------- #
+
+
+def _color(code: int, text: str) -> str:
+    return f"\x1b[{code}m{text}\x1b[0m"
+
+
+_red = lambda t: _color(31, t)
+_green = lambda t: _color(32, t)
+_blue = lambda t: _color(34, t)
+_yellow = lambda t: _color(33, t)
+
+
+class AsyncLM:
+    """Unified **async** language‑model wrapper with optional JSON parsing."""
+
+    def __init__(
+        self,
+        model: str | None = None,
+        *,
+        temperature: float = 0.0,
+        max_tokens: int = 2_000,
+        host: str = "localhost",
+        port: Optional[int | str] = None,
+        base_url: Optional[str] = None,
+        api_key: Optional[str] = None,
+        cache: bool = True,
+        ports: Optional[List[int]] = None,
+        **openai_kwargs: Any,
+    ) -> None:
+        self.model = model
+        self.temperature = temperature
+        self.max_tokens = max_tokens
+        self.port = port
+        self.host = host
+        self.base_url = base_url or (f"http://{host}:{port}/v1" if port else None)
+        self.api_key = api_key or os.getenv("OPENAI_API_KEY", "abc")
+        self.openai_kwargs = openai_kwargs
+        self.do_cache = cache
+        self.ports = ports
+
+    # Async client
+
+    @property
+    def client(self) -> AsyncOpenAI:
+        # if have multiple ports
+        if self.ports:
+            import random
+            port = random.choice(self.ports)
+            api_base = f"http://{self.host}:{port}/v1"
+            logger.debug(f"Using port: {port}")
+        else:
+            api_base = self.base_url or f"http://{self.host}:{self.port}/v1"
+        client = AsyncOpenAI(
+            api_key=self.api_key, base_url=api_base, **self.openai_kwargs
+        )
+        return client
+
+    # ------------------------------------------------------------------ #
+    # Public API – typed overloads
+    # ------------------------------------------------------------------ #
+    @overload
+    async def __call__(
+        self,
+        *,
+        prompt: str | None = ...,
+        messages: RawMsgs | None = ...,
+        response_format: type[str] = str,
+        return_openai_response: bool = ...,
+        **kwargs: Any,
+    ) -> str: ...
+
+    @overload
+    async def __call__(
+        self,
+        *,
+        prompt: str | None = ...,
+        messages: RawMsgs | None = ...,
+        response_format: Type[TModel],
+        return_openai_response: bool = ...,
+        **kwargs: Any,
+    ) -> TModel: ...
+
+    async def __call__(
+        self,
+        prompt: Optional[str] = None,
+        messages: Optional[RawMsgs] = None,
+        response_format: Union[type[str], Type[BaseModel]] = str,
+        cache: Optional[bool] = None,
+        max_tokens: Optional[int] = None,
+        return_openai_response: bool = False,
+        **kwargs: Any,
+    ):
+        if (prompt is None) == (messages is None):
+            raise ValueError("Provide *either* `prompt` or `messages` (but not both).")
+
+        if prompt is not None:
+            messages = [{"role": "user", "content": prompt}]
+
+        assert messages is not None
+        # assert self.model is not None, "Model must be set before calling."
+        if not self.model:
+            models = await self.list_models(port=self.port, host=self.host)
+            self.model = models[0] if models else None
+            logger.info(
+                f"No model specified. Using the first available model. {self.model}"
+            )
+        openai_msgs: Messages = (
+            self._convert_messages(cast(LegacyMsgs, messages))
+            if isinstance(messages[0], dict)
+            else cast(Messages, messages)
+        )
+
+        kw = dict(
+            self.openai_kwargs,
+            temperature=self.temperature,
+            max_tokens=max_tokens or self.max_tokens,
+        )
+        kw.update(kwargs)
+        use_cache = self.do_cache if cache is None else cache
+
+        raw_response = await self._call_raw(
+            openai_msgs,
+            response_format=response_format,
+            use_cache=use_cache,
+            **kw,
+        )
+
+        if return_openai_response:
+            response = raw_response
+        else:
+            response = self._parse_output(raw_response, response_format)
+
+        self.last_log = [prompt, messages, raw_response]
+        return response
+
+    # ------------------------------------------------------------------ #
+    # Model invocation (async)
+    # ------------------------------------------------------------------ #
+    async def _call_raw(
+        self,
+        messages: Sequence[ChatCompletionMessageParam],
+        response_format: Union[type[str], Type[BaseModel]],
+        use_cache: bool,
+        **kw: Any,
+    ):
+        assert self.model is not None, "Model must be set before making a call."
+        model: str = self.model
+
+        cache_key = (
+            self._cache_key(messages, kw, response_format) if use_cache else None
+        )
+        if cache_key and (hit := self._load_cache(cache_key)) is not None:
+            return hit
+
+        try:
+            if response_format is not str and issubclass(response_format, BaseModel):
+                openai_response = await self.client.beta.chat.completions.parse(
+                    model=model,
+                    messages=list(messages),
+                    response_format=response_format,  # type: ignore[arg-type]
+                    **kw,
+                )
+            else:
+                openai_response = await self.client.chat.completions.create(
+                    model=model,
+                    messages=list(messages),
+                    **kw,
+                )
+
+        except (AuthenticationError, RateLimitError) as exc:
+            logger.error(exc)
+            raise
+
+        if cache_key:
+            self._dump_cache(cache_key, openai_response)
+
+        return openai_response
+
+    # ------------------------------------------------------------------ #
+    # Utilities below are unchanged (sync I/O is acceptable)
+    # ------------------------------------------------------------------ #
+    @staticmethod
+    def _convert_messages(msgs: LegacyMsgs) -> Messages:
+        converted: Messages = []
+        for msg in msgs:
+            role = msg["role"]
+            content = msg["content"]
+            if role == "user":
+                converted.append(
+                    ChatCompletionUserMessageParam(role="user", content=content)
+                )
+            elif role == "assistant":
+                converted.append(
+                    ChatCompletionAssistantMessageParam(
+                        role="assistant", content=content
+                    )
+                )
+            elif role == "system":
+                converted.append(
+                    ChatCompletionSystemMessageParam(role="system", content=content)
+                )
+            elif role == "tool":
+                converted.append(
+                    ChatCompletionToolMessageParam(
+                        role="tool",
+                        content=content,
+                        tool_call_id=msg.get("tool_call_id") or "",
+                    )
+                )
+            else:
+                converted.append({"role": role, "content": content})  # type: ignore[arg-type]
+        return converted
+
+    @staticmethod
+    def _parse_output(
+        raw_response: Any, response_format: Union[type[str], Type[BaseModel]]
+    ) -> str | BaseModel:
+        if hasattr(raw_response, "model_dump"):
+            raw_response = raw_response.model_dump()
+
+        if response_format is str:
+            if isinstance(raw_response, dict) and "choices" in raw_response:
+                message = raw_response["choices"][0]["message"]
+                return message.get("content", "") or ""
+            return cast(str, raw_response)
+
+        model_cls = cast(Type[BaseModel], response_format)
+
+        if isinstance(raw_response, dict) and "choices" in raw_response:
+            message = raw_response["choices"][0]["message"]
+            if "parsed" in message:
+                return model_cls.model_validate(message["parsed"])
+            content = message.get("content")
+            if content is None:
+                raise ValueError("Model returned empty content")
+            try:
+                data = json.loads(content)
+                return model_cls.model_validate(data)
+            except Exception as exc:
+                raise ValueError(
+                    f"Failed to parse model output as JSON:\n{content}"
+                ) from exc
+
+        if isinstance(raw_response, model_cls):
+            return raw_response
+        if isinstance(raw_response, dict):
+            return model_cls.model_validate(raw_response)
+
+        try:
+            data = json.loads(raw_response)
+            return model_cls.model_validate(data)
+        except Exception as exc:
+            raise ValueError(
+                f"Model did not return valid JSON:\n---\n{raw_response}"
+            ) from exc
+
+    # ------------------------------------------------------------------ #
+    # Simple disk cache (sync)
+    # ------------------------------------------------------------------ #
+    @staticmethod
+    def _cache_key(
+        messages: Any, kw: Any, response_format: Union[type[str], Type[BaseModel]]
+    ) -> str:
+        tag = response_format.__name__ if response_format is not str else "text"
+        blob = json.dumps([messages, kw, tag], sort_keys=True).encode()
+        return base64.urlsafe_b64encode(hashlib.sha256(blob).digest()).decode()[:22]
+
+    @staticmethod
+    def _cache_path(key: str) -> str:
+        return os.path.expanduser(f"~/.cache/lm/{key}.json")
+
+    def _dump_cache(self, key: str, val: Any) -> None:
+        try:
+            path = self._cache_path(key)
+            os.makedirs(os.path.dirname(path), exist_ok=True)
+            with open(path, "w") as fh:
+                if isinstance(val, BaseModel):
+                    json.dump(val.model_dump(mode="json"), fh)
+                else:
+                    json.dump(val, fh)
+        except Exception as exc:
+            logger.debug(f"cache write skipped: {exc}")
+
+    def _load_cache(self, key: str) -> Any | None:
+        path = self._cache_path(key)
+        if not os.path.exists(path):
+            return None
+        try:
+            with open(path) as fh:
+                return json.load(fh)
+        except Exception:
+            return None
+
+    # ------------------------------------------------------------------ #
+    # Utility helpers
+    # ------------------------------------------------------------------ #
+    async def inspect_history(self) -> None:
+        if not hasattr(self, "last_log"):
+            raise ValueError("No history available. Please call the model first.")
+
+        prompt, messages, response = self.last_log
+        if hasattr(response, "model_dump"):
+            response = response.model_dump()
+        if not messages:
+            messages = [{"role": "user", "content": prompt}]
+
+        print("\n\n")
+        print(_blue("[Conversation History]") + "\n")
+
+        for msg in messages:
+            role = msg["role"]
+            content = msg["content"]
+            print(_red(f"{role.capitalize()}:"))
+            if isinstance(content, str):
+                print(content.strip())
+            elif isinstance(content, list):
+                for item in content:
+                    if item.get("type") == "text":
+                        print(item["text"].strip())
+                    elif item.get("type") == "image_url":
+                        image_url = item["image_url"]["url"]
+                        if "base64" in image_url:
+                            len_base64 = len(image_url.split("base64,")[1])
+                            print(_blue(f"<IMAGE BASE64 ENCODED({len_base64})>"))
+                        else:
+                            print(_blue(f"<image_url: {image_url}>"))
+            print("\n")
+
+        print(_red("Response:"))
+        if isinstance(response, dict) and response.get("choices"):
+            message = response["choices"][0].get("message", {})
+            reasoning = message.get("reasoning_content")
+            parsed = message.get("parsed")
+            content = message.get("content")
+            if reasoning:
+                print(_yellow("<think>"))
+                print(reasoning.strip())
+                print(_yellow("</think>\n"))
+            if parsed:
+                print(
+                    json.dumps(
+                        (
+                            parsed.model_dump()
+                            if hasattr(parsed, "model_dump")
+                            else parsed
+                        ),
+                        indent=2,
+                    )
+                    + "\n"
+                )
+            elif content:
+                print(content.strip())
+            else:
+                print(_green("[No content]"))
+            if len(response["choices"]) > 1:
+                print(
+                    _blue(f"\n(Plus {len(response['choices']) - 1} other completions)")
+                )
+        else:
+            print(_yellow("Warning: Not a standard OpenAI response object"))
+            if isinstance(response, str):
+                print(_green(response.strip()))
+            elif isinstance(response, dict):
+                print(_green(json.dumps(response, indent=2)))
+            else:
+                print(_green(str(response)))
+
+    # ------------------------------------------------------------------ #
+    # Misc helpers
+    # ------------------------------------------------------------------ #
+    def set_model(self, model: str) -> None:
+        self.model = model
+
+    @staticmethod
+    async def list_models(port=None, host="localhost") -> List[str]:
+        try:
+            client: AsyncOpenAI = AsyncLM(port=port, host=host).client  # type: ignore[arg-type]
+            base_url: URL = client.base_url
+            logger.debug(f"Base URL: {base_url}")
+            models: AsyncSyncPage[Model] = await client.models.list()  # type: ignore[assignment]
+            return [model.id for model in models.data]
+        except Exception as exc:
+            logger.error(f"Failed to list models: {exc}")
+            return []
llm_utils/lm/lm.py
CHANGED

@@ -4,6 +4,7 @@ import base64
 import hashlib
 import json
 import os
+from token import OP
 from typing import (
     Any,
     Dict,
@@ -18,7 +19,9 @@ from typing import (
 )
 
 from httpx import URL
+from huggingface_hub import repo_info
 from loguru import logger
+from numpy import isin
 from openai import OpenAI, AuthenticationError, RateLimitError
 from openai.pagination import SyncPage
 from openai.types.chat import (
@@ -42,6 +45,29 @@ LegacyMsgs = List[Dict[str, str]]  # old “…role/content…” dicts
 RawMsgs = Union[Messages, LegacyMsgs]  # what __call__ accepts
 
 
+# --------------------------------------------------------------------------- #
+# color formatting helpers
+# --------------------------------------------------------------------------- #
+def _red(text: str) -> str:
+    """Format text with red color."""
+    return f"\x1b[31m{text}\x1b[0m"
+
+
+def _green(text: str) -> str:
+    """Format text with green color."""
+    return f"\x1b[32m{text}\x1b[0m"
+
+
+def _blue(text: str) -> str:
+    """Format text with blue color."""
+    return f"\x1b[34m{text}\x1b[0m"
+
+
+def _yellow(text: str) -> str:
+    """Format text with yellow color."""
+    return f"\x1b[33m{text}\x1b[0m"
+
+
 class LM:
     """
     Unified language-model wrapper.
@@ -73,6 +99,7 @@ class LM:
         self.api_key = api_key or os.getenv("OPENAI_API_KEY", "abc")
         self.openai_kwargs = openai_kwargs
         self.do_cache = cache
+        self._init_port = port  # <-- store the port provided at init
 
         self.client = OpenAI(api_key=self.api_key, base_url=self.base_url)
 
@@ -90,6 +117,7 @@ class LM:
         prompt: str | None = ...,
         messages: RawMsgs | None = ...,
         response_format: type[str] = str,
+        return_openai_response: bool = ...,
         **kwargs: Any,
     ) -> str: ...
 
@@ -100,6 +128,7 @@ class LM:
         prompt: str | None = ...,
         messages: RawMsgs | None = ...,
         response_format: Type[TModel],
+        return_openai_response: bool = ...,
         **kwargs: Any,
     ) -> TModel: ...
 
@@ -111,6 +140,7 @@ class LM:
         response_format: Union[type[str], Type[BaseModel]] = str,
         cache: Optional[bool] = None,
         max_tokens: Optional[int] = None,
+        return_openai_response: bool = False,
         **kwargs: Any,
     ):
         # argument validation ------------------------------------------------
@@ -121,7 +151,20 @@ class LM:
             messages = [{"role": "user", "content": prompt}]
 
         assert messages is not None  # for type-checker
-
+
+        # If model is not specified, but port is provided, use the first available model
+        if self.model is None:
+            port = self._init_port
+            if port:
+                available_models = self.list_models(port=port)
+                if available_models:
+                    self.model = available_models[0]
+                    logger.info(f"Auto-selected model: {self.model}")
+                else:
+                    raise ValueError("No models available to select from.")
+            else:
+                raise AssertionError("Model must be set before calling.")
+
         openai_msgs: Messages = (
             self._convert_messages(cast(LegacyMsgs, messages))
             if isinstance(messages[0], dict)  # legacy style
@@ -132,17 +175,119 @@ class LM:
             self.openai_kwargs,
             temperature=self.temperature,
             max_tokens=max_tokens or self.max_tokens,
-            **kwargs,
         )
+        kw.update(kwargs)
         use_cache = self.do_cache if cache is None else cache
 
-
+        raw_response = self._call_raw(
             openai_msgs,
             response_format=response_format,
             use_cache=use_cache,
             **kw,
         )
-
+
+        if return_openai_response:
+            response = raw_response
+        else:
+            response = self._parse_output(raw_response, response_format)
+
+        self.last_log = [prompt, messages, raw_response]
+        return response
+
+    def inspect_history(self) -> None:
+        if not hasattr(self, "last_log"):
+            raise ValueError("No history available. Please call the model first.")
+
+        prompt, messages, response = self.last_log
+        # Ensure response is a dictionary
+        if hasattr(response, "model_dump"):
+            response = response.model_dump()
+
+        if not messages:
+            messages = [{"role": "user", "content": prompt}]
+
+        print("\n\n")
+        print(_blue("[Conversation History]") + "\n")
+
+        # Print all messages in the conversation
+        for msg in messages:
+            role = msg["role"]
+            content = msg["content"]
+            print(_red(f"{role.capitalize()}:"))
+
+            if isinstance(content, str):
+                print(content.strip())
+            elif isinstance(content, list):
+                # Handle multimodal content
+                for item in content:
+                    if item.get("type") == "text":
+                        print(item["text"].strip())
+                    elif item.get("type") == "image_url":
+                        image_url = item["image_url"]["url"]
+                        if "base64" in image_url:
+                            len_base64 = len(image_url.split("base64,")[1])
+                            print(_blue(f"<IMAGE BASE64 ENCODED({len_base64})>"))
+                        else:
+                            print(_blue(f"<image_url: {image_url}>"))
+            print("\n")
+
+        # Print the response - now always an OpenAI completion
+        print(_red("Response:"))
+
+        # Handle OpenAI response object
+        if isinstance(response, dict) and "choices" in response and response["choices"]:
+            message = response["choices"][0].get("message", {})
+
+            # Check for reasoning content (if available)
+            reasoning = message.get("reasoning_content")
+
+            # Check for parsed content (structured mode)
+            parsed = message.get("parsed")
+
+            # Get regular content
+            content = message.get("content")
+
+            # Display reasoning if available
+            if reasoning:
+                print(_yellow("<think>"))
+                print(reasoning.strip())
+                print(_yellow("</think>"))
+                print()
+
+            # Display parsed content for structured responses
+            if parsed:
+                # print(_green('<Parsed Structure>'))
+                if hasattr(parsed, "model_dump"):
+                    print(json.dumps(parsed.model_dump(), indent=2))
+                else:
+                    print(json.dumps(parsed, indent=2))
+                # print(_green('</Parsed Structure>'))
+                print()
+
+            else:
+                if content:
+                    # print(_green("<Content>"))
+                    print(content.strip())
+                    # print(_green("</Content>"))
+                else:
+                    print(_green("[No content]"))
+
+            # Show if there were multiple completions
+            if len(response["choices"]) > 1:
+                print(
+                    _blue(f"\n(Plus {len(response['choices']) - 1} other completions)")
+                )
+        else:
+            # Fallback for non-standard response objects or cached responses
+            print(_yellow("Warning: Not a standard OpenAI response object"))
+            if isinstance(response, str):
+                print(_green(response.strip()))
+            elif isinstance(response, dict):
+                print(_green(json.dumps(response, indent=2)))
+            else:
+                print(_green(str(response)))
+
+        # print("\n\n")
 
     # --------------------------------------------------------------------- #
     # low-level OpenAI call
@@ -156,6 +301,7 @@ class LM:
     ):
         assert self.model is not None, "Model must be set before making a call."
         model: str = self.model
+
         cache_key = (
             self._cache_key(messages, kw, response_format) if use_cache else None
         )
@@ -165,31 +311,28 @@ class LM:
         try:
             # structured mode
             if response_format is not str and issubclass(response_format, BaseModel):
-
-
-
-
-
-                    **kw,
-                )
+                openai_response = self.client.beta.chat.completions.parse(
+                    model=model,
+                    messages=list(messages),
+                    response_format=response_format,  # type: ignore[arg-type]
+                    **kw,
                 )
-                result: Any = rsp.choices[0].message.parsed  # already a model
             # plain-text mode
             else:
-
+                openai_response = self.client.chat.completions.create(
                     model=model,
                     messages=list(messages),
                     **kw,
                 )
-
+
         except (AuthenticationError, RateLimitError) as exc:  # pragma: no cover
             logger.error(exc)
             raise
 
         if cache_key:
-            self._dump_cache(cache_key,
+            self._dump_cache(cache_key, openai_response)
 
-        return
+        return openai_response
 
     # --------------------------------------------------------------------- #
     # legacy → typed messages
@@ -232,31 +375,67 @@ class LM:
     # --------------------------------------------------------------------- #
     @staticmethod
     def _parse_output(
-
+        raw_response: Any,
         response_format: Union[type[str], Type[BaseModel]],
     ) -> str | BaseModel:
+        # Convert any object to dict if needed
+        if hasattr(raw_response, "model_dump"):
+            raw_response = raw_response.model_dump()
+
         if response_format is str:
-
+            # Extract the content from OpenAI response dict
+            if isinstance(raw_response, dict) and "choices" in raw_response:
+                message = raw_response["choices"][0]["message"]
+                return message.get("content", "") or ""
+            return cast(str, raw_response)
 
         # For the type-checker: we *know* it's a BaseModel subclass here.
         model_cls = cast(Type[BaseModel], response_format)
 
-
-
-
-
+        # Handle structured response
+        if isinstance(raw_response, dict) and "choices" in raw_response:
+            message = raw_response["choices"][0]["message"]
+
+            # Check if already parsed by OpenAI client
+            if "parsed" in message:
+                return model_cls.model_validate(message["parsed"])
+
+            # Need to parse the content
+            content = message.get("content")
+            if content is None:
+                raise ValueError("Model returned empty content")
+
+            try:
+                data = json.loads(content)
+                return model_cls.model_validate(data)
+            except Exception as exc:
+                raise ValueError(
+                    f"Failed to parse model output as JSON:\n{content}"
+                ) from exc
+
+        # Handle cached response or other formats
+        if isinstance(raw_response, model_cls):
+            return raw_response
+        if isinstance(raw_response, dict):
+            return model_cls.model_validate(raw_response)
+
+        # Try parsing as JSON string
         try:
-            data = json.loads(
-
-
+            data = json.loads(raw_response)
+            return model_cls.model_validate(data)
+        except Exception as exc:
+            raise ValueError(
+                f"Model did not return valid JSON:\n---\n{raw_response}"
+            ) from exc
 
     # --------------------------------------------------------------------- #
     # tiny disk cache
     # --------------------------------------------------------------------- #
     @staticmethod
     def _cache_key(
-        messages: Any,
+        messages: Any,
+        kw: Any,
+        response_format: Union[type[str], Type[BaseModel]],
     ) -> str:
         tag = response_format.__name__ if response_format is not str else "text"
         blob = json.dumps([messages, kw, tag], sort_keys=True).encode()
@@ -289,7 +468,7 @@ class LM:
         return None
 
     @staticmethod
-    def list_models(port=None, host=
+    def list_models(port=None, host="localhost") -> List[str]:
         """
         List available models.
         """
@@ -302,3 +481,78 @@ class LM:
         except Exception as exc:
             logger.error(f"Failed to list models: {exc}")
             return []
+
+
+from functools import cache
+from llm_utils.lm.lm import LM, RawMsgs
+from pydantic import BaseModel
+import re
+import json
+from typing import *
+import re
+
+
+class LMReasoner(LM):
+    "Regex-based reasoning wrapper for LM."
+
+    def build_regex_from_pydantic(self, model: type[BaseModel]) -> str:
+        """
+        Build a regex pattern string for validating output that should match a Pydantic model.
+
+        Args:
+            model: A Pydantic BaseModel class
+
+        Returns:
+            A regex string that matches a JSON representation of the model
+        """
+        # regex = f"<think>\\n.*?\\n</think>\\n\\n\\```json\\n.*"
+        print(f"{regex=}")
+
+        return regex
+
+    def __call__(
+        self,
+        response_format: type[BaseModel],
+        prompt: Optional[str] = None,
+        messages: Optional[RawMsgs] = None,
+        **kwargs,
+    ):
+
+        if prompt is not None:
+            output = super().__call__(
+                prompt=prompt
+                + "\nresponse_format:"
+                + str(response_format.model_json_schema()),
+                response_format=str,
+                # extra_body={"guided_regex": regex},
+                **kwargs,
+            )  # type: ignore
+        elif messages is not None:
+            # append last message with the json schema
+            messages[-1]["content"] += "\nresponse_format:" + str(  # type: ignore
+                response_format.model_json_schema()
+            )
+            output = super().__call__(
+                messages=messages,
+                response_format=str,
+                # extra_body={"guided_regex": regex},
+                **kwargs,
+            )
+        else:
+            raise ValueError("Either prompt or messages must be provided.")
+        # import ipdb; ipdb.set_trace()
+        # parse using regex
+        pattern = re.compile(
+            r"<think>\n(?P<think>.*?)\n</think>\n\n(?P<json>\{.*\})",
+            re.DOTALL,
+        )
+        match = pattern.search(output)
+        if not match:
+            raise ValueError("Output does not match expected format")
+        parsed_output = match.group(0)
+        think_part = match.group("think")
+        json_part = match.group("json")
+
+        pydantic_object = response_format.model_validate(json.loads(json_part))
+        return pydantic_object
+
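
A sketch of how the new LMReasoner subclass might be driven. As the diff shows, it appends the Pydantic JSON schema to the prompt, calls the base LM in plain-text mode, then expects a <think>…</think> block followed by a JSON object in the reply, which it validates. The schema, model name, and port below are placeholders, and the sketch assumes the LM constructor accepts the same model/port arguments surfaced in this diff.

    from pydantic import BaseModel
    from llm_utils import LMReasoner

    class Answer(BaseModel):        # hypothetical response schema
        final_answer: str
        confidence: float

    # Placeholder endpoint: assumes an OpenAI-compatible server whose replies
    # contain a <think>...</think> block followed by a JSON object.
    reasoner = LMReasoner(model="my-reasoning-model", port=8140)
    result = reasoner(response_format=Answer, prompt="Is 17 prime?")
    print(result.final_answer, result.confidence)
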
llm_utils/scripts/vllm_load_balancer.py
CHANGED

@@ -5,18 +5,19 @@ import time
 from tabulate import tabulate
 import contextlib
 import aiohttp  # <-- Import aiohttp
+from speedy_utils import setup_logger
 from loguru import logger
-
+setup_logger(min_interval=5)
 # --- Configuration ---
 LOAD_BALANCER_HOST = "0.0.0.0"
 LOAD_BALANCER_PORT = 8008
 
 SCAN_TARGET_HOST = "localhost"
-SCAN_PORT_START =
+SCAN_PORT_START = 8140
 SCAN_PORT_END = 8170  # Inclusive
 SCAN_INTERVAL = 30
 # Timeout applies to the HTTP health check request now
-HEALTH_CHECK_TIMEOUT = 2
+HEALTH_CHECK_TIMEOUT = 2  # Increased slightly for HTTP requests
 
 STATUS_PRINT_INTERVAL = 5
 BUFFER_SIZE = 4096
@@ -83,14 +84,14 @@ async def check_server_health(session, host, port):
         # Check for a successful status code (2xx range)
         if 200 <= response.status < 300:
             logger.debug(
-                f"Health check success for {url} (Status: {response.status})"
+                f"[{LOAD_BALANCER_PORT=}] Health check success for {url} (Status: {response.status})"
             )
             # Ensure the connection is released back to the pool
             await response.release()
             return True
         else:
             logger.debug(
-                f"Health check failed for {url} (Status: {response.status})"
+                f"[{LOAD_BALANCER_PORT=}] Health check failed for {url} (Status: {response.status})"
            )
             await response.release()
             return False
@@ -180,7 +181,7 @@ async def scan_and_update_servers():
             if server not in connection_counts:
                 connection_counts[server] = 0
 
-        logger.debug(f"Scan complete. Active servers: {available_servers}")
+        logger.debug(f"[{LOAD_BALANCER_PORT=}]Scan complete. Active servers: {available_servers}")
 
     except asyncio.CancelledError:
         logger.info("Server scan task cancelled.")
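
The health-check context above amounts to an HTTP GET with a short timeout against each candidate port. A standalone sketch of the same idea follows; the /health path is an assumption (the actual URL is outside the changed lines), and the port is taken from the SCAN_PORT_START value set in this diff.

    import asyncio
    import aiohttp

    async def is_healthy(host: str, port: int, timeout: float = 2.0) -> bool:
        """Return True if the server answers a GET with a 2xx status."""
        url = f"http://{host}:{port}/health"   # assumed health endpoint
        try:
            async with aiohttp.ClientSession() as session:
                async with session.get(url, timeout=aiohttp.ClientTimeout(total=timeout)) as resp:
                    return 200 <= resp.status < 300
        except (aiohttp.ClientError, asyncio.TimeoutError):
            return False

    # e.g. probe the first port of the configured scan range 8140..8170
    print(asyncio.run(is_healthy("localhost", 8140)))
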
llm_utils/scripts/vllm_serve.py
CHANGED

@@ -9,19 +9,17 @@ Serve a base model:
     svllm serve --model MODEL_NAME --gpus GPU_GROUPS
 
 Add a LoRA to a served model:
-    svllm add-lora --lora LORA_NAME LORA_PATH --host_port host:port
+    svllm add-lora --lora LORA_NAME LORA_PATH --host_port host:port
+    (if add then the port must be specify)
 """
 
-from glob import glob
 import os
 import subprocess
-import
-from typing import List, Literal, Optional
-from fastcore.script import call_parse
-from loguru import logger
+from typing import List, Optional
 import argparse
 import requests
 import openai
+from loguru import logger
 
 from speedy_utils.common.utils_io import load_by_ext
 
@@ -32,63 +30,22 @@ HF_HOME: str = os.environ.get("HF_HOME", os.path.expanduser("~/.cache/huggingfac
 logger.info(f"LORA_DIR: {LORA_DIR}")
 
 
-def model_list(host_port, api_key="abc"):
+def model_list(host_port: str, api_key: str = "abc") -> None:
+    """List models from the vLLM server."""
     client = openai.OpenAI(base_url=f"http://{host_port}/v1", api_key=api_key)
     models = client.models.list()
     for model in models:
         print(f"Model ID: {model.id}")
 
 
-def kill_existing_vllm(vllm_binary: Optional[str] = None) -> None:
-    """Kill selected vLLM processes using fzf."""
-    if not vllm_binary:
-        vllm_binary = get_vllm()
-
-    # List running vLLM processes
-    result = subprocess.run(
-        f"ps aux | grep {vllm_binary} | grep -v grep",
-        shell=True,
-        capture_output=True,
-        text=True,
-    )
-    processes = result.stdout.strip().split("\n")
-
-    if not processes or processes == [""]:
-        print("No running vLLM processes found.")
-        return
-
-    # Use fzf to select processes to kill
-    fzf = subprocess.Popen(
-        ["fzf", "--multi"],
-        stdin=subprocess.PIPE,
-        stdout=subprocess.PIPE,
-        text=True,
-    )
-    selected, _ = fzf.communicate("\n".join(processes))
-
-    if not selected:
-        print("No processes selected.")
-        return
-
-    # Extract PIDs and kill selected processes
-    pids = [line.split()[1] for line in selected.strip().split("\n")]
-    for pid in pids:
-        subprocess.run(
-            f"kill -9 {pid}",
-            shell=True,
-            stdout=subprocess.DEVNULL,
-            stderr=subprocess.DEVNULL,
-        )
-    print(f"Killed processes: {', '.join(pids)}")
-
-
 def add_lora(
     lora_name_or_path: str,
     host_port: str,
     url: str = "http://HOST:PORT/v1/load_lora_adapter",
     served_model_name: Optional[str] = None,
-    lora_module: Optional[str] = None,
+    lora_module: Optional[str] = None,
 ) -> dict:
+    """Add a LoRA adapter to a running vLLM server."""
     url = url.replace("HOST:PORT", host_port)
     headers = {"Content-Type": "application/json"}
 
@@ -96,15 +53,12 @@ def add_lora(
         "lora_name": served_model_name,
         "lora_path": os.path.abspath(lora_name_or_path),
     }
-    if lora_module:
+    if lora_module:
         data["lora_module"] = lora_module
     logger.info(f"{data=}, {headers}, {url=}")
-    # logger.warning(f"Failed to unload LoRA adapter: {str(e)}")
     try:
-        response = requests.post(url, headers=headers, json=data)
+        response = requests.post(url, headers=headers, json=data, timeout=10)
         response.raise_for_status()
-
-        # Handle potential non-JSON responses
         try:
             return response.json()
         except ValueError:
@@ -116,113 +70,100 @@ def add_lora(
                 else "Request completed with empty response"
             ),
         }
-
     except requests.exceptions.RequestException as e:
         logger.error(f"Request failed: {str(e)}")
         return {"error": f"Request failed: {str(e)}"}
 
 
-def unload_lora(lora_name, host_port):
+def unload_lora(lora_name: str, host_port: str) -> Optional[dict]:
+    """Unload a LoRA adapter from a running vLLM server."""
     try:
         url = f"http://{host_port}/v1/unload_lora_adapter"
         logger.info(f"{url=}")
         headers = {"Content-Type": "application/json"}
         data = {"lora_name": lora_name}
         logger.info(f"Unloading LoRA adapter: {data=}")
-        response = requests.post(url, headers=headers, json=data)
+        response = requests.post(url, headers=headers, json=data, timeout=10)
         response.raise_for_status()
         logger.success(f"Unloaded LoRA adapter: {lora_name}")
     except requests.exceptions.RequestException as e:
         return {"error": f"Request failed: {str(e)}"}
 
 
-def serve(
-    model: str,
-    gpu_groups: str,
-    served_model_name: Optional[str] = None,
-    port_start: int = 8155,
-    gpu_memory_utilization: float = 0.93,
-    dtype: str = "bfloat16",
-    max_model_len: int = 8192,
-    enable_lora: bool = False,
-    is_bnb: bool = False,
-    eager: bool = False,
-    lora_modules: Optional[List[str]] = None,  # Updated type
-) -> None:
-    """Main function to start or kill vLLM containers."""
-
+def serve(args) -> None:
     """Start vLLM containers with dynamic args."""
     print("Starting vLLM containers...,")
-    gpu_groups_arr: List[str] = gpu_groups.split(",")
-
-    if enable_lora:
-
-
-
-
-
-
-
-
-
+    gpu_groups_arr: List[str] = args.gpu_groups.split(",")
+    vllm_binary: str = get_vllm()
+    if args.enable_lora:
+        vllm_binary = "VLLM_ALLOW_RUNTIME_LORA_UPDATING=True " + vllm_binary
+
+    if (
+        not args.bnb
+        and args.model
+        and ("bnb" in args.model.lower() or "4bit" in args.model.lower())
+    ):
+        args.bnb = True
+        print(f"Auto-detected quantization for model: {args.model}")
+
+    if args.enable_lora:
         os.environ["VLLM_ALLOW_RUNTIME_LORA_UPDATING"] = "True"
         print("Enabled runtime LoRA updating")
 
     for i, gpu_group in enumerate(gpu_groups_arr):
-        port =
+        port = int(args.host_port.split(":")[-1]) + i
         gpu_group = ",".join([str(x) for x in gpu_group])
         tensor_parallel = len(gpu_group.split(","))
 
         cmd = [
             f"CUDA_VISIBLE_DEVICES={gpu_group}",
-
+            vllm_binary,
             "serve",
-            model,
+            args.model,
             "--port",
             str(port),
             "--tensor-parallel",
             str(tensor_parallel),
             "--gpu-memory-utilization",
-            str(gpu_memory_utilization),
+            str(args.gpu_memory_utilization),
             "--dtype",
-            dtype,
+            args.dtype,
             "--max-model-len",
-            str(max_model_len),
+            str(args.max_model_len),
             "--enable-prefix-caching",
             "--disable-log-requests",
-            "--uvicorn-log-level critical",
+            # "--uvicorn-log-level critical",
         ]
         if HF_HOME:
-            # insert
             cmd.insert(0, f"HF_HOME={HF_HOME}")
-        if eager:
+        if args.eager:
             cmd.append("--enforce-eager")
 
-        if served_model_name:
-            cmd.extend(["--served-model-name", served_model_name])
+        if args.served_model_name:
+            cmd.extend(["--served-model-name", args.served_model_name])
 
-        if
+        if args.bnb:
             cmd.extend(
                 ["--quantization", "bitsandbytes", "--load-format", "bitsandbytes"]
             )
 
-        if enable_lora:
+        if args.enable_lora:
             cmd.extend(["--fully-sharded-loras", "--enable-lora"])
 
-        if lora_modules:
-
-            # len must be even and we will join tuple with `=`
-            assert len(lora_modules) % 2 == 0, "lora_modules must be even"
-            # lora_modulle = [f'{name}={module}' for name, module in zip(lora_module[::2], lora_module[1::2])]
-            # import ipdb;ipdb.set_trace()
+        if args.lora_modules:
+            assert len(args.lora_modules) % 2 == 0, "lora_modules must be even"
            s = ""
-            for i in range(0, len(lora_modules), 2):
-                name = lora_modules[i]
-                module = lora_modules[i + 1]
+            for i in range(0, len(args.lora_modules), 2):
+                name = args.lora_modules[i]
+                module = args.lora_modules[i + 1]
                 s += f"{name}={module} "
-
             cmd.extend(["--lora-modules", s])
-
+
+        if hasattr(args, "enable_reasoning") and args.enable_reasoning:
+            cmd.extend(["--enable-reasoning", "--reasoning-parser", "deepseek_r1"])
+            # Add VLLM_USE_V1=0 to the environment for reasoning mode
+            cmd.insert(0, "VLLM_USE_V1=0")
+
         final_cmd = " ".join(cmd)
         log_file = f"/tmp/vllm_{port}.txt"
         final_cmd_with_log = f'"{final_cmd} 2>&1 | tee {log_file}"'
@@ -235,14 +176,15 @@ def serve(
         os.system(run_in_tmux)
 
 
-def get_vllm():
-
-
-
+def get_vllm() -> str:
+    """Get the vLLM binary path."""
+    vllm_binary = subprocess.check_output("which vllm", shell=True, text=True).strip()
+    vllm_binary = os.getenv("VLLM_BINARY", vllm_binary)
+    logger.info(f"vLLM binary: {vllm_binary}")
     assert os.path.exists(
-
-    ), f"vLLM binary not found at {
-    return
+        vllm_binary
+    ), f"vLLM binary not found at {vllm_binary}, please set VLLM_BINARY env variable"
+    return vllm_binary
 
 
 def get_args():
@@ -292,11 +234,11 @@ def get_args():
         "--max_model_len", "-mml", type=int, default=8192, help="Maximum model length"
     )
     parser.add_argument(
-        "--
+        "--enable_lora",
         dest="enable_lora",
-        action="
+        action="store_true",
         help="Disable LoRA support",
-        default=
+        default=False,
     )
     parser.add_argument("--bnb", action="store_true", help="Enable quantization")
     parser.add_argument(
@@ -330,6 +272,9 @@ def get_args():
         type=str,
         help="List of LoRA modules in the format lora_name lora_module",
     )
+    parser.add_argument(
+        "--enable-reasoning", action="store_true", help="Enable reasoning"
+    )
     return parser.parse_args()
 
 
@@ -371,23 +316,8 @@ def main():
         logger.info(f"Model name from LoRA config: {model_name}")
         args.model = model_name
         # port_start from hostport
-
-        serve(
-            args.model,
-            args.gpu_groups,
-            args.served_model_name,
-            port_start,
-            args.gpu_memory_utilization,
-            args.dtype,
-            args.max_model_len,
-            args.enable_lora,
-            args.bnb,
-            args.eager,
-            args.lora_modules,
-        )
+        serve(args)
 
-    elif args.mode == "kill":
-        kill_existing_vllm(args.vllm_binary)
     elif args.mode == "add_lora":
         if args.lora:
             lora_name, lora_path = args.lora
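
As a quick illustration of the client-side helpers touched above, a hedged sketch of listing models and hot-loading then unloading a LoRA adapter against a running server. Host, port, adapter path, and adapter name are placeholders.

    from llm_utils.scripts.vllm_serve import model_list, add_lora, unload_lora

    HOST_PORT = "localhost:8140"   # placeholder host:port of a running vLLM server

    model_list(HOST_PORT)          # prints the IDs of the currently served models

    # Load an adapter from disk under a chosen name, then remove it again.
    result = add_lora("/path/to/my-lora", HOST_PORT, served_model_name="my-lora")
    print(result)
    unload_lora("my-lora", HOST_PORT)
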
{speedy_utils-1.0.9.dist-info → speedy_utils-1.0.12.dist-info}/RECORD
CHANGED

@@ -1,14 +1,15 @@
-llm_utils/__init__.py,sha256=
+llm_utils/__init__.py,sha256=JcRRsx6dtGLD8nwIz90Iowj6PLOO2_hUi254VbDpc_I,773
 llm_utils/chat_format/__init__.py,sha256=8dBIUqFJvkgQYedxBtcyxt-4tt8JxAKVap2JlTXmgaM,737
 llm_utils/chat_format/display.py,sha256=a3zWzo47SUf4i-uic-dwf-vxtu6gZWLbnJrszjjZjQ8,9801
 llm_utils/chat_format/transform.py,sha256=328V18FOgRQzljAl9Mh8NF4Tl-N3cZZIPmAwHQspXCY,5461
 llm_utils/chat_format/utils.py,sha256=xTxN4HrLHcRO2PfCTR43nH1M5zCa7v0kTTdzAcGkZg0,1229
 llm_utils/group_messages.py,sha256=wyiZzs7O8yK2lyIakV2x-1CrrWVT12sjnP1vVnmPet4,3606
-llm_utils/lm/__init__.py,sha256=
-llm_utils/lm/
+llm_utils/lm/__init__.py,sha256=e8eCWlLo39GZjq9CokludZGHYVZ7BnbWZ6GOJoiWGzE,110
+llm_utils/lm/alm.py,sha256=mJvB6uAzfakIjA7We19-VJNI9UKKkdfqeef1rJlKR9A,15773
+llm_utils/lm/lm.py,sha256=3mzLYKRbo50XjHp6_WuqkfG2HqTwmozXtQjYQC81m28,19516
 llm_utils/lm/utils.py,sha256=-fDNueiXKQI6RDoNHJYNyORomf2XlCf2doJZ3GEV2Io,4762
-llm_utils/scripts/vllm_load_balancer.py,sha256=
-llm_utils/scripts/vllm_serve.py,sha256=
+llm_utils/scripts/vllm_load_balancer.py,sha256=17zaq8RJseikHVoAibGOz0p_MCLcNlnhZDkk7g4cuLc,17519
+llm_utils/scripts/vllm_serve.py,sha256=CbW_3Y9Vt7eQYoGGPT3yj1nhbLYOc3b1LdJBy1sVX-Y,11976
 speedy_utils/__init__.py,sha256=I2bSfDIE9yRF77tnHW0vqfExDA2m1gUx4AH8C9XmGtg,1707
 speedy_utils/all.py,sha256=A9jiKGjo950eg1pscS9x38OWAjKGyusoAN5mrfweY4E,3090
 speedy_utils/common/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -24,7 +25,7 @@ speedy_utils/multi_worker/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJW
 speedy_utils/multi_worker/process.py,sha256=XwQlffxzRFnCVeKjDNBZDwFfUQHiJiuFA12MRGJVru8,6708
 speedy_utils/multi_worker/thread.py,sha256=9pXjvgjD0s0Hp0cZ6I3M0ndp1OlYZ1yvqbs_bcun_Kw,12775
 speedy_utils/scripts/mpython.py,sha256=ZzkBWI5Xw3vPoMx8xQt2x4mOFRjtwWqfvAJ5_ngyWgw,3816
-speedy_utils-1.0.
-speedy_utils-1.0.
-speedy_utils-1.0.
-speedy_utils-1.0.
+speedy_utils-1.0.12.dist-info/METADATA,sha256=obiUx5u8QPzhUDupqzgjZW-pHWyR4tRUt8iNhoZtZ10,7392
+speedy_utils-1.0.12.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+speedy_utils-1.0.12.dist-info/entry_points.txt,sha256=rP43satgw1uHcKUAlmVxS-MTAQImL-03-WwLIB5a300,165
+speedy_utils-1.0.12.dist-info/RECORD,,
{speedy_utils-1.0.9.dist-info → speedy_utils-1.0.12.dist-info}/WHEEL
File without changes

{speedy_utils-1.0.9.dist-info → speedy_utils-1.0.12.dist-info}/entry_points.txt
File without changes