speedy-utils 1.0.9__py3-none-any.whl → 1.0.11__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, exactly as they appear in their public registry. It is provided for informational purposes only.
- llm_utils/lm/lm.py +190 -30
- llm_utils/scripts/vllm_serve.py +62 -132
- {speedy_utils-1.0.9.dist-info → speedy_utils-1.0.11.dist-info}/METADATA +1 -1
- {speedy_utils-1.0.9.dist-info → speedy_utils-1.0.11.dist-info}/RECORD +6 -6
- {speedy_utils-1.0.9.dist-info → speedy_utils-1.0.11.dist-info}/WHEEL +0 -0
- {speedy_utils-1.0.9.dist-info → speedy_utils-1.0.11.dist-info}/entry_points.txt +0 -0
llm_utils/lm/lm.py
CHANGED

@@ -18,7 +18,9 @@ from typing import (
 )

 from httpx import URL
+from huggingface_hub import repo_info
 from loguru import logger
+from numpy import isin
 from openai import OpenAI, AuthenticationError, RateLimitError
 from openai.pagination import SyncPage
 from openai.types.chat import (
@@ -42,6 +44,29 @@ LegacyMsgs = List[Dict[str, str]] # old “…role/content…” dicts
 RawMsgs = Union[Messages, LegacyMsgs] # what __call__ accepts


+# --------------------------------------------------------------------------- #
+# color formatting helpers
+# --------------------------------------------------------------------------- #
+def _red(text: str) -> str:
+    """Format text with red color."""
+    return f"\x1b[31m{text}\x1b[0m"
+
+
+def _green(text: str) -> str:
+    """Format text with green color."""
+    return f"\x1b[32m{text}\x1b[0m"
+
+
+def _blue(text: str) -> str:
+    """Format text with blue color."""
+    return f"\x1b[34m{text}\x1b[0m"
+
+
+def _yellow(text: str) -> str:
+    """Format text with yellow color."""
+    return f"\x1b[33m{text}\x1b[0m"
+
+
 class LM:
     """
     Unified language-model wrapper.
@@ -90,6 +115,7 @@ class LM:
         prompt: str | None = ...,
         messages: RawMsgs | None = ...,
         response_format: type[str] = str,
+        return_openai_response: bool = ...,
         **kwargs: Any,
     ) -> str: ...

@@ -100,6 +126,7 @@ class LM:
         prompt: str | None = ...,
         messages: RawMsgs | None = ...,
         response_format: Type[TModel],
+        return_openai_response: bool = ...,
         **kwargs: Any,
     ) -> TModel: ...

@@ -111,6 +138,7 @@ class LM:
         response_format: Union[type[str], Type[BaseModel]] = str,
         cache: Optional[bool] = None,
         max_tokens: Optional[int] = None,
+        return_openai_response: bool = False,
         **kwargs: Any,
     ):
         # argument validation ------------------------------------------------
@@ -132,17 +160,117 @@ class LM:
             self.openai_kwargs,
             temperature=self.temperature,
             max_tokens=max_tokens or self.max_tokens,
-            **kwargs,
         )
+        kw.update(kwargs)
         use_cache = self.do_cache if cache is None else cache

-
+        raw_response = self._call_raw(
             openai_msgs,
             response_format=response_format,
             use_cache=use_cache,
             **kw,
         )
-
+
+        if return_openai_response:
+            response = raw_response
+        else:
+            response = self._parse_output(raw_response, response_format)
+
+        self.last_log = [prompt, messages, raw_response]
+        return response
+
+    def inspect_history(self) -> None:
+        if not hasattr(self, "last_log"):
+            raise ValueError("No history available. Please call the model first.")
+
+        prompt, messages, response = self.last_log
+        # Ensure response is a dictionary
+        if hasattr(response, "model_dump"):
+            response = response.model_dump()
+
+        if not messages:
+            messages = [{"role": "user", "content": prompt}]
+
+        print("\n\n")
+        print(_blue("[Conversation History]") + "\n")
+
+        # Print all messages in the conversation
+        for msg in messages:
+            role = msg["role"]
+            content = msg["content"]
+            print(_red(f"{role.capitalize()}:"))
+
+            if isinstance(content, str):
+                print(content.strip())
+            elif isinstance(content, list):
+                # Handle multimodal content
+                for item in content:
+                    if item.get("type") == "text":
+                        print(item["text"].strip())
+                    elif item.get("type") == "image_url":
+                        image_url = item["image_url"]["url"]
+                        if "base64" in image_url:
+                            len_base64 = len(image_url.split("base64,")[1])
+                            print(_blue(f"<IMAGE BASE64 ENCODED({len_base64})>"))
+                        else:
+                            print(_blue(f"<image_url: {image_url}>"))
+            print("\n")
+
+        # Print the response - now always an OpenAI completion
+        print(_red("Response:"))
+
+        # Handle OpenAI response object
+        if isinstance(response, dict) and 'choices' in response and response['choices']:
+            message = response['choices'][0].get('message', {})
+
+            # Check for reasoning content (if available)
+            reasoning = message.get('reasoning_content')
+
+            # Check for parsed content (structured mode)
+            parsed = message.get('parsed')
+
+            # Get regular content
+            content = message.get('content')
+
+            # Display reasoning if available
+            if reasoning:
+                print(_yellow('<think>'))
+                print(reasoning.strip())
+                print(_yellow('</think>'))
+                print()
+
+            # Display parsed content for structured responses
+            if parsed:
+                # print(_green('<Parsed Structure>'))
+                if hasattr(parsed, 'model_dump'):
+                    print(json.dumps(parsed.model_dump(), indent=2))
+                else:
+                    print(json.dumps(parsed, indent=2))
+                # print(_green('</Parsed Structure>'))
+                print()
+
+            else:
+                if content:
+                    # print(_green("<Content>"))
+                    print(content.strip())
+                    # print(_green("</Content>"))
+                else:
+                    print(_green("[No content]"))
+
+            # Show if there were multiple completions
+            if len(response['choices']) > 1:
+                print(_blue(f"\n(Plus {len(response['choices']) - 1} other completions)"))
+        else:
+            # Fallback for non-standard response objects or cached responses
+            print(_yellow("Warning: Not a standard OpenAI response object"))
+            if isinstance(response, str):
+                print(_green(response.strip()))
+            elif isinstance(response, dict):
+                print(_green(json.dumps(response, indent=2)))
+            else:
+                print(_green(str(response)))
+
+        # print("\n\n")

     # --------------------------------------------------------------------- #
     # low-level OpenAI call
@@ -156,8 +284,11 @@
     ):
         assert self.model is not None, "Model must be set before making a call."
         model: str = self.model
+
         cache_key = (
-            self._cache_key(messages, kw, response_format)
+            self._cache_key(messages, kw, response_format)
+            if use_cache
+            else None
         )
         if cache_key and (hit := self._load_cache(cache_key)) is not None:
             return hit
@@ -165,31 +296,28 @@
         try:
             # structured mode
             if response_format is not str and issubclass(response_format, BaseModel):
-
-
-
-
-
-                    **kw,
-                )
+                openai_response = self.client.beta.chat.completions.parse(
+                    model=model,
+                    messages=list(messages),
+                    response_format=response_format, # type: ignore[arg-type]
+                    **kw,
                 )
-                result: Any = rsp.choices[0].message.parsed # already a model
             # plain-text mode
             else:
-
+                openai_response = self.client.chat.completions.create(
                     model=model,
                     messages=list(messages),
                     **kw,
                 )
-
+
         except (AuthenticationError, RateLimitError) as exc: # pragma: no cover
             logger.error(exc)
             raise

         if cache_key:
-            self._dump_cache(cache_key,
+            self._dump_cache(cache_key, openai_response)

-        return
+        return openai_response

     # --------------------------------------------------------------------- #
     # legacy → typed messages
@@ -232,36 +360,68 @@ class LM:
     # --------------------------------------------------------------------- #
     @staticmethod
     def _parse_output(
-
+        raw_response: Any,
         response_format: Union[type[str], Type[BaseModel]],
     ) -> str | BaseModel:
+        # Convert any object to dict if needed
+        if hasattr(raw_response, 'model_dump'):
+            raw_response = raw_response.model_dump()
+
         if response_format is str:
-
-
+            # Extract the content from OpenAI response dict
+            if isinstance(raw_response, dict) and 'choices' in raw_response:
+                message = raw_response['choices'][0]['message']
+                return message.get('content', '') or ''
+            return cast(str, raw_response)
+
         # For the type-checker: we *know* it's a BaseModel subclass here.
         model_cls = cast(Type[BaseModel], response_format)

-
-
-
-
+        # Handle structured response
+        if isinstance(raw_response, dict) and 'choices' in raw_response:
+            message = raw_response['choices'][0]['message']
+
+            # Check if already parsed by OpenAI client
+            if 'parsed' in message:
+                return model_cls.model_validate(message['parsed'])
+
+            # Need to parse the content
+            content = message.get('content')
+            if content is None:
+                raise ValueError("Model returned empty content")
+
+            try:
+                data = json.loads(content)
+                return model_cls.model_validate(data)
+            except Exception as exc:
+                raise ValueError(f"Failed to parse model output as JSON:\n{content}") from exc
+
+        # Handle cached response or other formats
+        if isinstance(raw_response, model_cls):
+            return raw_response
+        if isinstance(raw_response, dict):
+            return model_cls.model_validate(raw_response)
+
+        # Try parsing as JSON string
         try:
-            data = json.loads(
-
-
-
+            data = json.loads(raw_response)
+            return model_cls.model_validate(data)
+        except Exception as exc:
+            raise ValueError(f"Model did not return valid JSON:\n---\n{raw_response}") from exc

     # --------------------------------------------------------------------- #
     # tiny disk cache
     # --------------------------------------------------------------------- #
     @staticmethod
     def _cache_key(
-        messages: Any,
+        messages: Any,
+        kw: Any,
+        response_format: Union[type[str], Type[BaseModel]],
     ) -> str:
         tag = response_format.__name__ if response_format is not str else "text"
         blob = json.dumps([messages, kw, tag], sort_keys=True).encode()
         return base64.urlsafe_b64encode(hashlib.sha256(blob).digest()).decode()[:22]
-
+
     @staticmethod
     def _cache_path(key: str) -> str:
         return os.path.expanduser(f"~/.cache/lm/{key}.json")
@@ -289,7 +449,7 @@ class LM:
         return None

     @staticmethod
-    def list_models(port=None, host=
+    def list_models(port=None, host="localhost") -> List[str]:
         """
         List available models.
         """
llm_utils/scripts/vllm_serve.py
CHANGED

@@ -9,19 +9,17 @@ Serve a base model:
 svllm serve --model MODEL_NAME --gpus GPU_GROUPS

 Add a LoRA to a served model:
-svllm add-lora --lora LORA_NAME LORA_PATH --host_port host:port
+svllm add-lora --lora LORA_NAME LORA_PATH --host_port host:port
+(if add then the port must be specify)
 """

-from glob import glob
 import os
 import subprocess
-import
-from typing import List, Literal, Optional
-from fastcore.script import call_parse
-from loguru import logger
+from typing import List, Optional
 import argparse
 import requests
 import openai
+from loguru import logger

 from speedy_utils.common.utils_io import load_by_ext

@@ -32,63 +30,22 @@ HF_HOME: str = os.environ.get("HF_HOME", os.path.expanduser("~/.cache/huggingfac
 logger.info(f"LORA_DIR: {LORA_DIR}")


-def model_list(host_port, api_key="abc"):
+def model_list(host_port: str, api_key: str = "abc") -> None:
+    """List models from the vLLM server."""
     client = openai.OpenAI(base_url=f"http://{host_port}/v1", api_key=api_key)
     models = client.models.list()
     for model in models:
         print(f"Model ID: {model.id}")


-def kill_existing_vllm(vllm_binary: Optional[str] = None) -> None:
-    """Kill selected vLLM processes using fzf."""
-    if not vllm_binary:
-        vllm_binary = get_vllm()
-
-    # List running vLLM processes
-    result = subprocess.run(
-        f"ps aux | grep {vllm_binary} | grep -v grep",
-        shell=True,
-        capture_output=True,
-        text=True,
-    )
-    processes = result.stdout.strip().split("\n")
-
-    if not processes or processes == [""]:
-        print("No running vLLM processes found.")
-        return
-
-    # Use fzf to select processes to kill
-    fzf = subprocess.Popen(
-        ["fzf", "--multi"],
-        stdin=subprocess.PIPE,
-        stdout=subprocess.PIPE,
-        text=True,
-    )
-    selected, _ = fzf.communicate("\n".join(processes))
-
-    if not selected:
-        print("No processes selected.")
-        return
-
-    # Extract PIDs and kill selected processes
-    pids = [line.split()[1] for line in selected.strip().split("\n")]
-    for pid in pids:
-        subprocess.run(
-            f"kill -9 {pid}",
-            shell=True,
-            stdout=subprocess.DEVNULL,
-            stderr=subprocess.DEVNULL,
-        )
-    print(f"Killed processes: {', '.join(pids)}")
-
-
 def add_lora(
     lora_name_or_path: str,
     host_port: str,
     url: str = "http://HOST:PORT/v1/load_lora_adapter",
     served_model_name: Optional[str] = None,
-    lora_module: Optional[str] = None,
+    lora_module: Optional[str] = None,
 ) -> dict:
+    """Add a LoRA adapter to a running vLLM server."""
     url = url.replace("HOST:PORT", host_port)
     headers = {"Content-Type": "application/json"}

@@ -96,15 +53,12 @@ def add_lora(
         "lora_name": served_model_name,
         "lora_path": os.path.abspath(lora_name_or_path),
     }
-    if lora_module:
+    if lora_module:
         data["lora_module"] = lora_module
     logger.info(f"{data=}, {headers}, {url=}")
-    # logger.warning(f"Failed to unload LoRA adapter: {str(e)}")
     try:
-        response = requests.post(url, headers=headers, json=data)
+        response = requests.post(url, headers=headers, json=data, timeout=10)
         response.raise_for_status()
-
-        # Handle potential non-JSON responses
         try:
             return response.json()
         except ValueError:
@@ -116,113 +70,100 @@ def add_lora(
                 else "Request completed with empty response"
             ),
         }
-
    except requests.exceptions.RequestException as e:
         logger.error(f"Request failed: {str(e)}")
         return {"error": f"Request failed: {str(e)}"}


-def unload_lora(lora_name, host_port):
+def unload_lora(lora_name: str, host_port: str) -> Optional[dict]:
+    """Unload a LoRA adapter from a running vLLM server."""
     try:
         url = f"http://{host_port}/v1/unload_lora_adapter"
         logger.info(f"{url=}")
         headers = {"Content-Type": "application/json"}
         data = {"lora_name": lora_name}
         logger.info(f"Unloading LoRA adapter: {data=}")
-        response = requests.post(url, headers=headers, json=data)
+        response = requests.post(url, headers=headers, json=data, timeout=10)
         response.raise_for_status()
         logger.success(f"Unloaded LoRA adapter: {lora_name}")
     except requests.exceptions.RequestException as e:
         return {"error": f"Request failed: {str(e)}"}


-def serve(
-    model: str,
-    gpu_groups: str,
-    served_model_name: Optional[str] = None,
-    port_start: int = 8155,
-    gpu_memory_utilization: float = 0.93,
-    dtype: str = "bfloat16",
-    max_model_len: int = 8192,
-    enable_lora: bool = False,
-    is_bnb: bool = False,
-    eager: bool = False,
-    lora_modules: Optional[List[str]] = None,  # Updated type
-) -> None:
-    """Main function to start or kill vLLM containers."""
-
+def serve(args) -> None:
     """Start vLLM containers with dynamic args."""
     print("Starting vLLM containers...,")
-    gpu_groups_arr: List[str] = gpu_groups.split(",")
-
-    if enable_lora:
-
-
-
-
-
-
-
-
-
+    gpu_groups_arr: List[str] = args.gpu_groups.split(",")
+    vllm_binary: str = get_vllm()
+    if args.enable_lora:
+        vllm_binary = "VLLM_ALLOW_RUNTIME_LORA_UPDATING=True " + vllm_binary
+
+    if (
+        not args.bnb
+        and args.model
+        and ("bnb" in args.model.lower() or "4bit" in args.model.lower())
+    ):
+        args.bnb = True
+        print(f"Auto-detected quantization for model: {args.model}")
+
+    if args.enable_lora:
         os.environ["VLLM_ALLOW_RUNTIME_LORA_UPDATING"] = "True"
         print("Enabled runtime LoRA updating")

     for i, gpu_group in enumerate(gpu_groups_arr):
-        port =
+        port = int(args.host_port.split(":")[-1]) + i
         gpu_group = ",".join([str(x) for x in gpu_group])
         tensor_parallel = len(gpu_group.split(","))

         cmd = [
             f"CUDA_VISIBLE_DEVICES={gpu_group}",
-
+            vllm_binary,
             "serve",
-            model,
+            args.model,
             "--port",
             str(port),
             "--tensor-parallel",
             str(tensor_parallel),
             "--gpu-memory-utilization",
-            str(gpu_memory_utilization),
+            str(args.gpu_memory_utilization),
             "--dtype",
-            dtype,
+            args.dtype,
             "--max-model-len",
-            str(max_model_len),
+            str(args.max_model_len),
             "--enable-prefix-caching",
             "--disable-log-requests",
             "--uvicorn-log-level critical",
         ]
         if HF_HOME:
-            # insert
             cmd.insert(0, f"HF_HOME={HF_HOME}")
-        if eager:
+        if args.eager:
             cmd.append("--enforce-eager")

-        if served_model_name:
-            cmd.extend(["--served-model-name", served_model_name])
+        if args.served_model_name:
+            cmd.extend(["--served-model-name", args.served_model_name])

-        if
+        if args.bnb:
             cmd.extend(
                 ["--quantization", "bitsandbytes", "--load-format", "bitsandbytes"]
             )

-        if enable_lora:
+        if args.enable_lora:
             cmd.extend(["--fully-sharded-loras", "--enable-lora"])

-        if lora_modules:
-
-            # len must be even and we will join tuple with `=`
-            assert len(lora_modules) % 2 == 0, "lora_modules must be even"
-            # lora_modulle = [f'{name}={module}' for name, module in zip(lora_module[::2], lora_module[1::2])]
-            # import ipdb;ipdb.set_trace()
+        if args.lora_modules:
+            assert len(args.lora_modules) % 2 == 0, "lora_modules must be even"
             s = ""
-            for i in range(0, len(lora_modules), 2):
-                name = lora_modules[i]
-                module = lora_modules[i + 1]
+            for i in range(0, len(args.lora_modules), 2):
+                name = args.lora_modules[i]
+                module = args.lora_modules[i + 1]
                 s += f"{name}={module} "
-
             cmd.extend(["--lora-modules", s])
-
+
+        if hasattr(args, "enable_reasoning") and args.enable_reasoning:
+            cmd.extend(["--enable-reasoning", "--reasoning-parser", "deepseek_r1"])
+            # Add VLLM_USE_V1=0 to the environment for reasoning mode
+            cmd.insert(0, "VLLM_USE_V1=0")
+
         final_cmd = " ".join(cmd)
         log_file = f"/tmp/vllm_{port}.txt"
         final_cmd_with_log = f'"{final_cmd} 2>&1 | tee {log_file}"'
@@ -235,14 +176,15 @@ def serve(
         os.system(run_in_tmux)


-def get_vllm():
-
-
-
+def get_vllm() -> str:
+    """Get the vLLM binary path."""
+    vllm_binary = subprocess.check_output("which vllm", shell=True, text=True).strip()
+    vllm_binary = os.getenv("VLLM_BINARY", vllm_binary)
+    logger.info(f"vLLM binary: {vllm_binary}")
     assert os.path.exists(
-
-    ), f"vLLM binary not found at {
-    return
+        vllm_binary
+    ), f"vLLM binary not found at {vllm_binary}, please set VLLM_BINARY env variable"
+    return vllm_binary


 def get_args():
@@ -330,6 +272,9 @@ def get_args():
         type=str,
         help="List of LoRA modules in the format lora_name lora_module",
     )
+    parser.add_argument(
+        "--enable-reasoning", action="store_true", help="Enable reasoning"
+    )
     return parser.parse_args()


@@ -371,23 +316,8 @@ def main():
         logger.info(f"Model name from LoRA config: {model_name}")
         args.model = model_name
         # port_start from hostport
-
-        serve(
-            args.model,
-            args.gpu_groups,
-            args.served_model_name,
-            port_start,
-            args.gpu_memory_utilization,
-            args.dtype,
-            args.max_model_len,
-            args.enable_lora,
-            args.bnb,
-            args.eager,
-            args.lora_modules,
-        )
+        serve(args)

-    elif args.mode == "kill":
-        kill_existing_vllm(args.vllm_binary)
     elif args.mode == "add_lora":
         if args.lora:
             lora_name, lora_path = args.lora
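
For reference, a sketch of driving the reworked `serve()` with an argparse-style namespace, which is what `get_args()` normally builds and `main()` now forwards via `serve(args)`. The attribute names mirror the flags referenced in this diff; the concrete values and defaults below are placeholders, not the package's.

```python
# Hedged sketch: serve() now takes the parsed CLI namespace instead of a long
# positional signature. Values are placeholders, not package defaults.
import argparse

from llm_utils.scripts.vllm_serve import serve

args = argparse.Namespace(
    model="my-org/my-model-bnb-4bit",   # "bnb"/"4bit" in the name auto-enables quantization
    gpu_groups="01,23",                 # one vLLM instance per comma-separated GPU group
    host_port="localhost:8150",         # instance i listens on port 8150 + i
    served_model_name=None,
    gpu_memory_utilization=0.93,
    dtype="bfloat16",
    max_model_len=8192,
    enable_lora=False,
    bnb=False,
    eager=False,
    lora_modules=None,
    enable_reasoning=True,              # adds --enable-reasoning --reasoning-parser deepseek_r1
)

serve(args)  # builds one vLLM command per GPU group and launches it in tmux
```

Passing the namespace whole replaces the old long positional call and lets optional flags such as `--enable-reasoning` be probed with `hasattr`.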
{speedy_utils-1.0.9.dist-info → speedy_utils-1.0.11.dist-info}/RECORD
CHANGED

@@ -5,10 +5,10 @@ llm_utils/chat_format/transform.py,sha256=328V18FOgRQzljAl9Mh8NF4Tl-N3cZZIPmAwHQ
 llm_utils/chat_format/utils.py,sha256=xTxN4HrLHcRO2PfCTR43nH1M5zCa7v0kTTdzAcGkZg0,1229
 llm_utils/group_messages.py,sha256=wyiZzs7O8yK2lyIakV2x-1CrrWVT12sjnP1vVnmPet4,3606
 llm_utils/lm/__init__.py,sha256=vXFILZLBmmpg39cy5XniQPSMzoFQCE3wdfz39EtqDKU,71
-llm_utils/lm/lm.py,sha256=
+llm_utils/lm/lm.py,sha256=4bEo4nnyCi_ybTOYfzrJz9AwpxJNkzRFAUPq7KpBklw,16695
 llm_utils/lm/utils.py,sha256=-fDNueiXKQI6RDoNHJYNyORomf2XlCf2doJZ3GEV2Io,4762
 llm_utils/scripts/vllm_load_balancer.py,sha256=MgMnnoKWJQc-l2fspUSkyA9wxL1RkXd7wdBLJNQBlr4,17384
-llm_utils/scripts/vllm_serve.py,sha256=
+llm_utils/scripts/vllm_serve.py,sha256=LlrkwfWLxdMDhfOJ-eL1VJnA4AY1Beh_cI8U6l9Xl-A,11975
 speedy_utils/__init__.py,sha256=I2bSfDIE9yRF77tnHW0vqfExDA2m1gUx4AH8C9XmGtg,1707
 speedy_utils/all.py,sha256=A9jiKGjo950eg1pscS9x38OWAjKGyusoAN5mrfweY4E,3090
 speedy_utils/common/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -24,7 +24,7 @@ speedy_utils/multi_worker/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJW
 speedy_utils/multi_worker/process.py,sha256=XwQlffxzRFnCVeKjDNBZDwFfUQHiJiuFA12MRGJVru8,6708
 speedy_utils/multi_worker/thread.py,sha256=9pXjvgjD0s0Hp0cZ6I3M0ndp1OlYZ1yvqbs_bcun_Kw,12775
 speedy_utils/scripts/mpython.py,sha256=ZzkBWI5Xw3vPoMx8xQt2x4mOFRjtwWqfvAJ5_ngyWgw,3816
-speedy_utils-1.0.
-speedy_utils-1.0.
-speedy_utils-1.0.
-speedy_utils-1.0.
+speedy_utils-1.0.11.dist-info/METADATA,sha256=F48tr0hmL3k-r9O2tPbUdfbBU5JHnwxVGB547eQXElU,7392
+speedy_utils-1.0.11.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+speedy_utils-1.0.11.dist-info/entry_points.txt,sha256=rP43satgw1uHcKUAlmVxS-MTAQImL-03-WwLIB5a300,165
+speedy_utils-1.0.11.dist-info/RECORD,,

{speedy_utils-1.0.9.dist-info → speedy_utils-1.0.11.dist-info}/WHEEL
File without changes

{speedy_utils-1.0.9.dist-info → speedy_utils-1.0.11.dist-info}/entry_points.txt
File without changes