speedy-utils 0.1.28__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_utils/__init__.py +30 -0
- llm_utils/chat_format.py +427 -0
- llm_utils/group_messages.py +119 -0
- llm_utils/lm.py +742 -0
- llm_utils/lm_classification.py +0 -0
- llm_utils/load_chat_dataset.py +41 -0
- llm_utils/scripts/vllm_load_balancer.py +353 -0
- llm_utils/scripts/vllm_serve.py +482 -0
- speedy_utils/__init__.py +1 -2
- speedy_utils/all.py +0 -2
- speedy_utils/common/clock.py +10 -0
- speedy_utils/common/utils_misc.py +0 -1
- speedy_utils/multi_worker/thread.py +22 -6
- {speedy_utils-0.1.28.dist-info → speedy_utils-1.0.0.dist-info}/METADATA +3 -27
- speedy_utils-1.0.0.dist-info/RECORD +27 -0
- speedy_utils/common/dataclass_parser.py +0 -101
- speedy_utils/multi_worker/_handle_inputs.py +0 -50
- speedy_utils-0.1.28.dist-info/RECORD +0 -21
- {speedy_utils-0.1.28.dist-info → speedy_utils-1.0.0.dist-info}/WHEEL +0 -0
- {speedy_utils-0.1.28.dist-info → speedy_utils-1.0.0.dist-info}/entry_points.txt +0 -0
llm_utils/scripts/vllm_serve.py
ADDED

@@ -0,0 +1,482 @@
+"""
+USAGE:
+Serve models and LoRAs with vLLM:
+
+Serve a LoRA model:
+svllm serve --lora LORA_NAME LORA_PATH --gpus GPU_GROUPS
+
+Serve a base model:
+svllm serve --model MODEL_NAME --gpus GPU_GROUPS
+
+Add a LoRA to a served model:
+svllm add-lora --lora LORA_NAME LORA_PATH --host_port host:port (when adding, the port must be specified)
+"""
+
+from glob import glob
+import os
+import subprocess
+import time
+from typing import List, Literal, Optional
+from fastcore.script import call_parse
+from loguru import logger
+import argparse
+import requests
+import openai
+
+
+LORA_DIR = os.environ.get("LORA_DIR", "/loras")
+LORA_DIR = os.path.abspath(LORA_DIR)
+HF_HOME = os.environ.get("HF_HOME", os.path.expanduser("~/.cache/huggingface"))
+logger.info(f"LORA_DIR: {LORA_DIR}")
+
+
+def model_list(host_port, api_key="abc"):
+    client = openai.OpenAI(base_url=f"http://{host_port}/v1", api_key=api_key)
+    models = client.models.list()
+    for model in models:
+        print(f"Model ID: {model.id}")
+
+
+def kill_existing_vllm(vllm_binary: Optional[str] = None) -> None:
+    """Kill selected vLLM processes using fzf."""
+    if not vllm_binary:
+        vllm_binary = get_vllm()
+
+    # List running vLLM processes
+    result = subprocess.run(
+        f"ps aux | grep {vllm_binary} | grep -v grep",
+        shell=True,
+        capture_output=True,
+        text=True,
+    )
+    processes = result.stdout.strip().split("\n")
+
+    if not processes or processes == [""]:
+        print("No running vLLM processes found.")
+        return
+
+    # Use fzf to select processes to kill
+    fzf = subprocess.Popen(
+        ["fzf", "--multi"],
+        stdin=subprocess.PIPE,
+        stdout=subprocess.PIPE,
+        text=True,
+    )
+    selected, _ = fzf.communicate("\n".join(processes))
+
+    if not selected:
+        print("No processes selected.")
+        return
+
+    # Extract PIDs and kill selected processes
+    pids = [line.split()[1] for line in selected.strip().split("\n")]
+    for pid in pids:
+        subprocess.run(
+            f"kill -9 {pid}",
+            shell=True,
+            stdout=subprocess.DEVNULL,
+            stderr=subprocess.DEVNULL,
+        )
+    print(f"Killed processes: {', '.join(pids)}")
+
+
+def add_lora(
+    lora_name_or_path: str,
+    host_port: str,
+    url: str = "http://HOST:PORT/v1/load_lora_adapter",
+    served_model_name: Optional[str] = None,
+    lora_module: Optional[str] = None,  # Added parameter
+) -> dict:
+    url = url.replace("HOST:PORT", host_port)
+    headers = {"Content-Type": "application/json"}
+
+    data = {
+        "lora_name": served_model_name,
+        "lora_path": os.path.abspath(lora_name_or_path),
+    }
+    if lora_module:  # Include lora_module if provided
+        data["lora_module"] = lora_module
+    logger.info(f"{data=}, {headers}, {url=}")
+    # logger.warning(f"Failed to unload LoRA adapter: {str(e)}")
+    try:
+        response = requests.post(url, headers=headers, json=data)
+        response.raise_for_status()
+
+        # Handle potential non-JSON responses
+        try:
+            return response.json()
+        except ValueError:
+            return {
+                "status": "success",
+                "message": (
+                    response.text
+                    if response.text.strip()
+                    else "Request completed with empty response"
+                ),
+            }
+
+    except requests.exceptions.RequestException as e:
+        logger.error(f"Request failed: {str(e)}")
+        return {"error": f"Request failed: {str(e)}"}
+
+
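For orientation, here is a minimal sketch of the HTTP call `add_lora()` above issues against a running server. The endpoint and payload mirror the function body; the host, adapter name, and path are illustrative placeholders, not values from the package.

```python
# Sketch of the request add_lora() sends (placeholder values).
import requests

host_port = "localhost:8155"  # port the vLLM server was started on
payload = {
    "lora_name": "my-adapter",         # name clients will request as `model`
    "lora_path": "/loras/my-adapter",  # absolute path on the server machine
}
response = requests.post(
    f"http://{host_port}/v1/load_lora_adapter",
    headers={"Content-Type": "application/json"},
    json=payload,
)
response.raise_for_status()  # the server answers 200 with JSON or plain text
```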
+def unload_lora(lora_name, host_port):
+    try:
+        url = f"http://{host_port}/v1/unload_lora_adapter"
+        logger.info(f"{url=}")
+        headers = {"Content-Type": "application/json"}
+        data = {"lora_name": lora_name}
+        logger.info(f"Unloading LoRA adapter: {data=}")
+        response = requests.post(url, headers=headers, json=data)
+        response.raise_for_status()
+        logger.success(f"Unloaded LoRA adapter: {lora_name}")
+    except requests.exceptions.RequestException as e:
+        return {"error": f"Request failed: {str(e)}"}
+
+
+def serve(
+    model: str,
+    gpu_groups: str,
+    served_model_name: Optional[str] = None,
+    port_start: int = 8155,
+    gpu_memory_utilization: float = 0.93,
+    dtype: str = "bfloat16",
+    max_model_len: int = 8192,
+    enable_lora: bool = False,
+    is_bnb: bool = False,
+    eager: bool = False,
+    chat_template: Optional[str] = None,
+    lora_modules: Optional[List[str]] = None,  # Updated type
+):
+    """Start vLLM servers with dynamic args, one per GPU group."""
+    print("Starting vLLM containers...")
+    gpu_groups_arr = gpu_groups.split(",")
+    VLLM_BINARY = get_vllm()
+    if enable_lora:
+        VLLM_BINARY = "VLLM_ALLOW_RUNTIME_LORA_UPDATING=True " + VLLM_BINARY
+
+    # Auto-detect quantization based on model name if not explicitly set
+    if not is_bnb and model and ("bnb" in model.lower() or "4bit" in model.lower()):
+        is_bnb = True
+        print(f"Auto-detected quantization for model: {model}")
+
+    # Set environment variables for LoRA if needed
+    if enable_lora:
+        os.environ["VLLM_ALLOW_RUNTIME_LORA_UPDATING"] = "True"
+        print("Enabled runtime LoRA updating")
+
+    for i, gpu_group in enumerate(gpu_groups_arr):
+        port = port_start + i
+        # Each group is a run of GPU ids, e.g. "01" -> "0,1"
+        gpu_group = ",".join([str(x) for x in gpu_group])
+        tensor_parallel = len(gpu_group.split(","))
+
+        cmd = [
+            f"CUDA_VISIBLE_DEVICES={gpu_group}",
+            VLLM_BINARY,
+            "serve",
+            model,
+            "--port",
+            str(port),
+            "--tensor-parallel-size",
+            str(tensor_parallel),
+            "--gpu-memory-utilization",
+            str(gpu_memory_utilization),
+            "--dtype",
+            dtype,
+            "--max-model-len",
+            str(max_model_len),
+            "--enable-prefix-caching",
+            "--disable-log-requests",
+            "--uvicorn-log-level",
+            "critical",
+        ]
+        if HF_HOME:
+            cmd.insert(0, f"HF_HOME={HF_HOME}")
+        if eager:
+            cmd.append("--enforce-eager")
+
+        if served_model_name:
+            cmd.extend(["--served-model-name", served_model_name])
+
+        if is_bnb:
+            cmd.extend(
+                ["--quantization", "bitsandbytes", "--load-format", "bitsandbytes"]
+            )
+
+        if enable_lora:
+            cmd.extend(["--fully-sharded-loras", "--enable-lora"])
+
+        if chat_template:
+            chat_template = get_chat_template(chat_template)
+            cmd.extend(["--chat-template", chat_template])  # Add chat_template argument
+        if lora_modules:
+            # Pairs of (name, path) are joined as "name=path" for --lora-modules
+            assert len(lora_modules) % 2 == 0, "lora_modules must be (name, path) pairs"
+            s = ""
+            for i in range(0, len(lora_modules), 2):
+                name = lora_modules[i]
+                module = lora_modules[i + 1]
+                s += f"{name}={module} "
+
+            cmd.extend(["--lora-modules", s])
+        final_cmd = " ".join(cmd)
+        log_file = f"/tmp/vllm_{port}.txt"
+        final_cmd_with_log = f'"{final_cmd} 2>&1 | tee {log_file}"'
+        run_in_tmux = (
+            f"tmux new-session -d -s vllm_{port} 'bash -c {final_cmd_with_log}'"
+        )
+
+        print(final_cmd)
+        print("Logging to", log_file)
+        os.system(run_in_tmux)
+
+
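To make the loop above concrete, here is roughly the command string one iteration assembles for `gpu_groups="01"` with the function's defaults. The exact flags depend on the arguments actually passed, so treat this as an illustration rather than the literal output.

```python
# Illustrative reconstruction of final_cmd for the group "01"
# ("01" expands to CUDA_VISIBLE_DEVICES=0,1, so tensor_parallel is 2).
final_cmd = (
    "HF_HOME=/root/.cache/huggingface "
    "CUDA_VISIBLE_DEVICES=0,1 vllm serve MODEL_NAME "
    "--port 8155 --tensor-parallel-size 2 --gpu-memory-utilization 0.93 "
    "--dtype bfloat16 --max-model-len 8192 "
    "--enable-prefix-caching --disable-log-requests "
    "--uvicorn-log-level critical"
)
# serve() then runs this inside a detached tmux session named vllm_8155,
# teeing stdout/stderr to /tmp/vllm_8155.txt.
```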
+def get_vllm():
+    VLLM_BINARY = subprocess.check_output("which vllm", shell=True, text=True).strip()
+    VLLM_BINARY = os.getenv("VLLM_BINARY", VLLM_BINARY)
+    logger.info(f"vLLM binary: {VLLM_BINARY}")
+    assert os.path.exists(
+        VLLM_BINARY
+    ), f"vLLM binary not found at {VLLM_BINARY}, please set VLLM_BINARY env variable"
+    return VLLM_BINARY
+
+
+def get_args():
+    """Parse command line arguments."""
+    example_args = [
+        "svllm serve --model MODEL_NAME --gpus 0,1,2,3",
+        "svllm serve --lora LORA_NAME LORA_PATH --gpus 0,1,2,3",
+        "svllm add_lora --lora LORA_NAME LORA_PATH --host_port localhost:8150",
+        "svllm kill",
+    ]
+
+    parser = argparse.ArgumentParser(
+        description="vLLM Serve Script", epilog="Example: " + " || ".join(example_args)
+    )
+    parser.add_argument(
+        "mode",
+        choices=["serve", "kill", "add_lora", "unload_lora", "list_models"],
+        help="Mode to run the script in",
+    )
+    parser.add_argument("--model", "-m", type=str, help="Model to serve")
+    parser.add_argument(
+        "--gpus",
+        "-g",
+        type=str,
+        help="Comma-separated list of GPU groups",
+        dest="gpu_groups",
+    )
+    parser.add_argument(
+        "--lora",
+        "-l",
+        nargs=2,
+        metavar=("LORA_NAME", "LORA_PATH"),
+        help="Name and path of the LoRA adapter",
+    )
+    parser.add_argument(
+        "--served_model_name", type=str, help="Name of the served model"
+    )
+    parser.add_argument(
+        "--gpu_memory_utilization",
+        "-gmu",
+        type=float,
+        default=0.9,
+        help="GPU memory utilization",
+    )
+    parser.add_argument("--dtype", type=str, default="auto", help="Data type")
+    parser.add_argument(
+        "--max_model_len", "-mml", type=int, default=8192, help="Maximum model length"
+    )
+    parser.add_argument(
+        "--disable_lora",
+        dest="enable_lora",
+        action="store_false",
+        help="Disable LoRA support",
+        default=True,
+    )
+    parser.add_argument("--bnb", action="store_true", help="Enable quantization")
+    parser.add_argument(
+        "--not_verbose", action="store_true", help="Disable verbose logging"
+    )
+    parser.add_argument("--vllm_binary", type=str, help="Path to the vLLM binary")
+    parser.add_argument(
+        "--pipeline_parallel",
+        "-pp",
+        default=1,
+        type=int,
+        help="Number of pipeline parallel stages",
+    )
+    # parser.add_argument(
+    #     "--extra_args",
+    #     nargs=argparse.REMAINDER,
+    #     help="Additional arguments for the serve command",
+    # )
+    parser.add_argument(
+        "--host_port",
+        "-hp",
+        type=str,
+        default="localhost:8150",
+        help="Host and port for the server, format: host:port",
+    )
+    parser.add_argument("--eager", action="store_true", help="Enable eager execution")
+    parser.add_argument(
+        "--chat_template",
+        type=str,
+        help="Path to the chat template file",
+    )
+    parser.add_argument(
+        "--lora_modules",
+        "-lm",
+        nargs="+",
+        type=str,
+        help="List of LoRA modules in the format lora_name lora_module",
+    )
+    return parser.parse_args()
+
+
+from speedy_utils import jloads, load_by_ext, memoize
+
+
+def fetch_chat_template(template_name: str = "qwen") -> str:
+    """
+    Fetches a chat template file from a remote repository or local cache.
+
+    Args:
+        template_name (str): Name of the chat template. Defaults to 'qwen'.
+
+    Returns:
+        str: Path to the downloaded or cached chat template file.
+
+    Raises:
+        AssertionError: If the template_name is not supported.
+        ValueError: If the file URL is invalid.
+    """
+    supported_templates = [
+        "alpaca",
+        "chatml",
+        "gemma-it",
+        "llama-2-chat",
+        "mistral-instruct",
+        "qwen2.5-instruct",
+        "saiga",
+        "vicuna",
+        "qwen",
+    ]
+    assert template_name in supported_templates, (
+        f"Chat template '{template_name}' not supported. "
+        f"Please choose from {supported_templates}."
+    )
+
+    # Map 'qwen' to 'qwen2.5-instruct'
+    if template_name == "qwen":
+        template_name = "qwen2.5-instruct"
+
+    remote_url = (
+        f"https://raw.githubusercontent.com/chujiezheng/chat_templates/"
+        f"main/chat_templates/{template_name}.jinja"
+    )
+    local_cache_path = f"/tmp/chat_template_{template_name}.jinja"
+
+    if remote_url.startswith("http"):
+        import requests
+
+        response = requests.get(remote_url)
+        with open(local_cache_path, "w") as file:
+            file.write(response.text)
+        return local_cache_path
+
+    raise ValueError("The file URL must be a valid HTTP URL.")
+
+
+def get_chat_template(template_name: str) -> str:
+    return fetch_chat_template(template_name)
+
+
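A brief usage sketch of the helper above: the `"qwen"` alias resolves to the `qwen2.5-instruct` template, and the file is written under `/tmp` exactly as the function body shows.

```python
# Usage sketch for fetch_chat_template()/get_chat_template().
path = get_chat_template("qwen")
print(path)  # /tmp/chat_template_qwen2.5-instruct.jinja
# This path is what serve() forwards to vLLM via --chat-template.
```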
+def main():
+    """Main entry point for the script."""
+
+    args = get_args()
+
+    if args.mode == "serve":
+        # Handle LoRA model serving via the new --lora argument
+        if args.lora:
+            lora_name, lora_path = args.lora
+            if not args.lora_modules:
+                args.lora_modules = [lora_name, lora_path]
+            # Try to get the model from LoRA config if not specified
+            if args.model is None:
+                lora_config = os.path.join(lora_path, "adapter_config.json")
+                if os.path.exists(lora_config):
+                    config = load_by_ext(lora_config)
+                    model_name = config.get("base_model_name_or_path")
+                    # Handle different quantization suffixes
+                    if model_name.endswith("-unsloth-bnb-4bit") and not args.bnb:
+                        model_name = model_name.replace("-unsloth-bnb-4bit", "")
+                    elif model_name.endswith("-bnb-4bit") and not args.bnb:
+                        model_name = model_name.replace("-bnb-4bit", "")
+                    logger.info(f"Model name from LoRA config: {model_name}")
+                    args.model = model_name
+
+        # Fall back to existing logic for other cases (already specified lora_modules)
+        if args.model is None and args.lora_modules is not None and not args.lora:
+            lora_config = os.path.join(args.lora_modules[1], "adapter_config.json")
+            if os.path.exists(lora_config):
+                config = load_by_ext(lora_config)
+                model_name = config.get("base_model_name_or_path")
+                if model_name.endswith("-unsloth-bnb-4bit") and not args.bnb:
+                    model_name = model_name.replace("-unsloth-bnb-4bit", "")
+                elif model_name.endswith("-bnb-4bit") and not args.bnb:
+                    model_name = model_name.replace("-bnb-4bit", "")
+                logger.info(f"Model name from LoRA config: {model_name}")
+                args.model = model_name
+        # port_start from host_port
+        port_start = int(args.host_port.split(":")[-1])
+        serve(
+            args.model,
+            args.gpu_groups,
+            args.served_model_name,
+            port_start,
+            args.gpu_memory_utilization,
+            args.dtype,
+            args.max_model_len,
+            args.enable_lora,
+            args.bnb,
+            args.eager,
+            args.chat_template,
+            args.lora_modules,
+        )
+
+    elif args.mode == "kill":
+        kill_existing_vllm(args.vllm_binary)
+    elif args.mode == "add_lora":
+        if args.lora:
+            lora_name, lora_path = args.lora
+            add_lora(lora_path, host_port=args.host_port, served_model_name=lora_name)
+        else:
+            # Fallback to old behavior
+            lora_name = args.model
+            add_lora(
+                lora_name,
+                host_port=args.host_port,
+                served_model_name=args.served_model_name,
+            )
+    elif args.mode == "unload_lora":
+        if args.lora:
+            lora_name = args.lora[0]
+        else:
+            lora_name = args.model
+        unload_lora(lora_name, host_port=args.host_port)
+    elif args.mode == "list_models":
+        model_list(args.host_port)
+    else:
+        raise ValueError(f"Unknown mode: {args.mode}")
+
+
+if __name__ == "__main__":
+    main()
speedy_utils/__init__.py
CHANGED

@@ -41,7 +41,7 @@ from .common.utils_print import (
 
 # Multi-worker processing
 from .multi_worker.process import multi_process
-from .multi_worker.thread import
+from .multi_worker.thread import multi_thread
 
 # Define __all__ explicitly
 __all__ = [
@@ -79,7 +79,6 @@ __all__ = [
     # Multi-worker processing
     "multi_process",
    "multi_thread",
-    "multi_threaad_standard",
 ]
 
 # Setup default logger
speedy_utils/all.py
CHANGED

@@ -64,7 +64,6 @@ from speedy_utils import ( # Clock module; Function decorators; Cache utilities
     memoize,
     mkdir_or_exist,
     multi_process,
-    multi_threaad_standard,
     multi_thread,
     print_table,
     retry_runtime,
@@ -157,5 +156,4 @@ __all__ = [
     # Multi-worker processing
     "multi_process",
     "multi_thread",
-    "multi_threaad_standard",
 ]
speedy_utils/common/clock.py
CHANGED

@@ -96,6 +96,11 @@ class Clock:
             )
             return
         current_time = time.time()
+        if self.last_checkpoint is None:
+            logger.opt(depth=2).warning(
+                "Last checkpoint is not set. Please call start() before using this method."
+            )
+            return
         elapsed = current_time - self.last_checkpoint
         self.last_checkpoint = current_time
         return elapsed
@@ -111,6 +116,11 @@ class Clock:
                 "Timer has not been started. Please call start() before using this method."
             )
             return
+        if self.last_checkpoint is None:
+            logger.opt(depth=2).warning(
+                "Last checkpoint is not set. Please call start() before using this method."
+            )
+            return
         return time.time() - self.last_checkpoint
 
     def update_task(self, task_name):
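Both hunks add the same guard: methods that read `last_checkpoint` now warn and bail out instead of failing on `None` arithmetic when `start()` was never called. A minimal sketch of the behaviour, with the class shape assumed from the hunk context rather than taken from the package:

```python
# Minimal sketch (not the actual Clock class) of the guard added in 1.0.0.
import time

class MiniClock:
    def __init__(self):
        self.last_checkpoint = None  # unset until start() is called

    def start(self):
        self.last_checkpoint = time.time()

    def time_since_last_checkpoint(self):
        if self.last_checkpoint is None:
            print("Last checkpoint is not set. Please call start() first.")
            return None  # without the guard: TypeError (float - None)
        return time.time() - self.last_checkpoint

clock = MiniClock()
clock.time_since_last_checkpoint()          # warns and returns None
clock.start()
print(clock.time_since_last_checkpoint())   # small positive float
```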
speedy_utils/multi_worker/thread.py
CHANGED

@@ -1,13 +1,12 @@
 """Provides thread-based parallel execution utilities."""
 
-import inspect
 import os
 import time
 import traceback
 from collections.abc import Callable, Iterable
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from itertools import islice
-from typing import Any, TypeVar
+from typing import Any, TypeVar
 
 from loguru import logger
 
@@ -80,7 +79,7 @@ def multi_thread(
         ``False`` the failing task’s result becomes ``None``.
     **fixed_kwargs – static keyword args forwarded to every ``func()`` call.
     """
-    from speedy_utils import dump_json_or_pickle,
+    from speedy_utils import dump_json_or_pickle, load_by_ext
 
     if n_proc > 1:
         import tempfile
@@ -297,10 +296,27 @@ def multi_thread(
     return results
 
 
-def
-
+def multi_thread_standard(
+    fn: Callable[[Any], Any], items: Iterable[Any], workers: int = 4
+) -> list[Any]:
+    """Execute a function using standard ThreadPoolExecutor.
+
     A standard implementation of multi-threading using ThreadPoolExecutor.
     Ensures the order of results matches the input order.
+
+    Parameters
+    ----------
+    fn : Callable
+        The function to execute for each item.
+    items : Iterable
+        The items to process.
+    workers : int, optional
+        Number of worker threads, by default 4.
+
+    Returns
+    -------
+    list
+        Results in same order as input items.
     """
     with ThreadPoolExecutor(max_workers=workers) as executor:
         futures = [executor.submit(fn, item) for item in items]
@@ -308,4 +324,4 @@ def multi_threaad_standard(fn, items, workers=4):
     return results
 
 
-__all__ = ["multi_thread", "
+__all__ = ["multi_thread", "multi_thread_standard"]
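A short usage sketch for the `multi_thread_standard` function added above (it replaces the misspelled `multi_threaad_standard` from 0.1.28); the import path follows the module layout shown in this diff.

```python
# Usage sketch for multi_thread_standard as defined in the hunk above.
from speedy_utils.multi_worker.thread import multi_thread_standard

def square(x: int) -> int:
    return x * x

# Results preserve input order, per the new docstring.
print(multi_thread_standard(square, list(range(8)), workers=4))
# [0, 1, 4, 9, 16, 25, 36, 49]
```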
{speedy_utils-0.1.28.dist-info → speedy_utils-1.0.0.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.3
 Name: speedy-utils
-Version: 0.
+Version: 1.0.0
 Summary: Fast and easy-to-use package for data science
 Author: AnhVTH
 Author-email: anhvth.226@gmail.com
@@ -19,11 +19,12 @@ Requires-Dist: fastprogress
 Requires-Dist: freezegun (>=1.5.1,<2.0.0)
 Requires-Dist: ipdb
 Requires-Dist: ipywidgets
-Requires-Dist: json-repair
+Requires-Dist: json-repair (>=0.40.0,<0.41.0)
 Requires-Dist: jupyterlab
 Requires-Dist: loguru
 Requires-Dist: matplotlib
 Requires-Dist: numpy
+Requires-Dist: packaging (>=23.2,<25)
 Requires-Dist: pandas
 Requires-Dist: pydantic
 Requires-Dist: requests
@@ -261,29 +262,6 @@ Ensure all dependencies are installed before running tests:
 pip install -r requirements.txt
 ```
 
-## Data Arguments
-
-Define and parse data arguments using a dataclass:
-
-```python
-from dataclasses import dataclass
-from speedy_utils.common.dataclass_parser import ArgsParser
-
-@dataclass
-class ExampleArgs(ArgsParser):
-    from_peft: str = "./outputs/llm_hn_qw32b/hn_results_r3/"
-    model_name_or_path: str = "Qwen/Qwen2.5-32B-Instruct-AWQ"
-    use_fp16: bool = False
-    batch_size: int = 1
-    max_length: int = 512
-    cache_dir: str = ".cache/run_embeds"
-    output_dir: str = ".cache"
-    input_file: str = ".cache/doc.csv"
-    output_name: str = "qw32b_r3"
-
-args = ExampleArgs.parse_args()
-print(args)
-```
 
 Run the script to parse and display the arguments:
 
@@ -299,5 +277,3 @@ Example output:
 
 Please ensure your code adheres to the project's coding standards and includes appropriate tests.
 
-
-
speedy_utils-1.0.0.dist-info/RECORD
ADDED

@@ -0,0 +1,27 @@
+llm_utils/__init__.py,sha256=4IO_cP7FlFvV57W7_AkJGdXWcZfWALkr6UwR9s9QEF4,701
+llm_utils/chat_format.py,sha256=ZY2HYv3FPL2xiMxbbO-huIwT5LZrcJm_if_us-2eSZ4,15094
+llm_utils/group_messages.py,sha256=GKMQkenQf-6DD_1EJa11UBj7-VfkGT7xVhR_B_zMzqY,3868
+llm_utils/lm.py,sha256=3d3b9UMtIeKcG6vpXHwuDZ3QP46X1aSGeMvayy3tmHs,29044
+llm_utils/lm_classification.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+llm_utils/load_chat_dataset.py,sha256=hsPPlOmZEDqsg7GQD7SgxTEGJky6JDm6jnG3JHbpjb4,1895
+llm_utils/scripts/vllm_load_balancer.py,sha256=uSjGd_jOmI9W9eVOhiOXbeUnZkQq9xG4bCVzhmpupcA,16096
+llm_utils/scripts/vllm_serve.py,sha256=-gXF3DVs7sfZ4mVmv_0OFheqytMQbr73YkS3bQlFdpE,15818
+speedy_utils/__init__.py,sha256=I2bSfDIE9yRF77tnHW0vqfExDA2m1gUx4AH8C9XmGtg,1707
+speedy_utils/all.py,sha256=A9jiKGjo950eg1pscS9x38OWAjKGyusoAN5mrfweY4E,3090
+speedy_utils/common/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+speedy_utils/common/clock.py,sha256=3n4FkCW0dz46O8By09V5Pve1DSMgpLDRbWEVRryryeQ,7423
+speedy_utils/common/function_decorator.py,sha256=r_r42qCWuNcu0_aH7musf2BWvcJfgZrD81G28mDcolw,2226
+speedy_utils/common/logger.py,sha256=NIOlhhACpcc0BUTSJ8oDYrLp23J2gW_KJFyRVdLN2tY,6432
+speedy_utils/common/report_manager.py,sha256=dgGfS_fHbZiQMsLzkgnj0OfB758t1x6B4MhjsetSl9A,3930
+speedy_utils/common/utils_cache.py,sha256=gXX5qTXpCG3qDgjnOsSvxM4LkQurmcsg4QRv_zOBG1E,8378
+speedy_utils/common/utils_io.py,sha256=vXhgrMSse_5yuP7yiSjdqKgOu8pz983glelquyNjbkE,4809
+speedy_utils/common/utils_misc.py,sha256=nsQOu2jcplcFHVQ1CnOjEpNcdxIINveGxB493Cqo63U,1812
+speedy_utils/common/utils_print.py,sha256=QRaL2QPbks4Mtol_gJy3ZdahgUfzUEtcOp4--lBlzYI,6709
+speedy_utils/multi_worker/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+speedy_utils/multi_worker/process.py,sha256=XwQlffxzRFnCVeKjDNBZDwFfUQHiJiuFA12MRGJVru8,6708
+speedy_utils/multi_worker/thread.py,sha256=9pXjvgjD0s0Hp0cZ6I3M0ndp1OlYZ1yvqbs_bcun_Kw,12775
+speedy_utils/scripts/mpython.py,sha256=ZzkBWI5Xw3vPoMx8xQt2x4mOFRjtwWqfvAJ5_ngyWgw,3816
+speedy_utils-1.0.0.dist-info/METADATA,sha256=nIcZcNBPA9_eUs7on52UMVtcV43lmVFRJLHLptk5AA0,7165
+speedy_utils-1.0.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+speedy_utils-1.0.0.dist-info/entry_points.txt,sha256=fsv8_lMg62BeswoUHrqfj2u6q2l4YcDCw7AgQFg6GRw,61
+speedy_utils-1.0.0.dist-info/RECORD,,