speedy-utils 1.0.4__py3-none-any.whl → 1.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,416 @@
"""
USAGE:
Serve models and LoRAs with vLLM:

Serve a LoRA model:
    svllm serve --lora LORA_NAME LORA_PATH --gpus GPU_GROUPS

Serve a base model:
    svllm serve --model MODEL_NAME --gpus GPU_GROUPS

Add a LoRA to a served model:
    svllm add-lora --lora LORA_NAME LORA_PATH --host_port host:port
    (when adding, the port must be specified)
"""

from glob import glob
import os
import subprocess
import time
from typing import List, Literal, Optional

from fastcore.script import call_parse
from loguru import logger
import argparse
import requests
import openai

from speedy_utils.common.utils_io import load_by_ext


# Directory holding LoRA adapters; override with the LORA_DIR env var.
LORA_DIR: str = os.environ.get("LORA_DIR", "/loras")
LORA_DIR = os.path.abspath(LORA_DIR)
# Hugging Face cache location, forwarded to spawned vLLM processes.
HF_HOME: str = os.environ.get("HF_HOME", os.path.expanduser("~/.cache/huggingface"))
logger.info(f"LORA_DIR: {LORA_DIR}")
35
def model_list(host_port, api_key="abc"):
    """Print the ID of every model served at ``host_port``."""
    client = openai.OpenAI(base_url=f"http://{host_port}/v1", api_key=api_key)
    for model in client.models.list():
        print(f"Model ID: {model.id}")
42
def kill_existing_vllm(vllm_binary: Optional[str] = None) -> None:
    """Interactively kill running vLLM processes selected via fzf.

    Lists processes whose command line matches the vLLM binary, lets the
    user multi-select them with ``fzf --multi``, then SIGKILLs each PID.

    Args:
        vllm_binary: Path used to match processes; defaults to ``get_vllm()``.
    """
    if not vllm_binary:
        vllm_binary = get_vllm()

    # List running vLLM processes (grep -v grep drops the grep itself).
    result = subprocess.run(
        f"ps aux | grep {vllm_binary} | grep -v grep",
        shell=True,
        capture_output=True,
        text=True,
    )
    processes = result.stdout.strip().split("\n")

    if not processes or processes == [""]:
        print("No running vLLM processes found.")
        return

    # Use fzf to select processes to kill.
    fzf = subprocess.Popen(
        ["fzf", "--multi"],
        stdin=subprocess.PIPE,
        stdout=subprocess.PIPE,
        text=True,
    )
    selected, _ = fzf.communicate("\n".join(processes))

    if not selected:
        print("No processes selected.")
        return

    # Extract PIDs (second column of `ps aux`) and SIGKILL each one.
    pids = [line.split()[1] for line in selected.strip().split("\n")]
    for pid in pids:
        # List-form invocation: no shell, no string interpolation of the PID.
        subprocess.run(
            ["kill", "-9", pid],
            stdout=subprocess.DEVNULL,
            stderr=subprocess.DEVNULL,
        )
    print(f"Killed processes: {', '.join(pids)}")
85
def add_lora(
    lora_name_or_path: str,
    host_port: str,
    url: str = "http://HOST:PORT/v1/load_lora_adapter",
    served_model_name: Optional[str] = None,
    lora_module: Optional[str] = None,
) -> dict:
    """Register a LoRA adapter with a running vLLM server.

    Args:
        lora_name_or_path: Filesystem path of the LoRA adapter.
        host_port: ``host:port`` of the target vLLM server.
        url: Endpoint template; ``HOST:PORT`` is replaced with ``host_port``.
        served_model_name: Name to serve the adapter under; falls back to
            the basename of the adapter path when omitted.
        lora_module: Optional module spec forwarded to the server.

    Returns:
        The parsed JSON response, a synthesized success dict for non-JSON
        responses, or an ``{"error": ...}`` dict on request failure.
    """
    url = url.replace("HOST:PORT", host_port)
    headers = {"Content-Type": "application/json"}

    lora_path = os.path.abspath(lora_name_or_path)
    data = {
        # Fix: previously "lora_name" was None whenever served_model_name was
        # omitted; default to the adapter directory name so the server always
        # receives a usable name.
        "lora_name": served_model_name or os.path.basename(lora_path),
        "lora_path": lora_path,
    }
    if lora_module:  # Include lora_module only if provided
        data["lora_module"] = lora_module
    logger.info(f"{data=}, {headers}, {url=}")
    try:
        response = requests.post(url, headers=headers, json=data)
        response.raise_for_status()

        # Handle potential non-JSON responses
        try:
            return response.json()
        except ValueError:
            return {
                "status": "success",
                "message": (
                    response.text
                    if response.text.strip()
                    else "Request completed with empty response"
                ),
            }

    except requests.exceptions.RequestException as e:
        logger.error(f"Request failed: {str(e)}")
        return {"error": f"Request failed: {str(e)}"}
125
def unload_lora(lora_name, host_port):
    """Unload a LoRA adapter from a running vLLM server.

    Args:
        lora_name: Name the adapter was registered under.
        host_port: ``host:port`` of the target vLLM server.

    Returns:
        ``{"status": "success"}`` on success or ``{"error": ...}`` on
        failure. (Previously success returned None while failure returned
        a dict, so callers could not distinguish outcomes consistently.)
    """
    try:
        url = f"http://{host_port}/v1/unload_lora_adapter"
        logger.info(f"{url=}")
        headers = {"Content-Type": "application/json"}
        data = {"lora_name": lora_name}
        logger.info(f"Unloading LoRA adapter: {data=}")
        response = requests.post(url, headers=headers, json=data)
        response.raise_for_status()
        logger.success(f"Unloaded LoRA adapter: {lora_name}")
        return {"status": "success"}
    except requests.exceptions.RequestException as e:
        return {"error": f"Request failed: {str(e)}"}
139
def serve(
    model: str,
    gpu_groups: str,
    served_model_name: Optional[str] = None,
    port_start: int = 8155,
    gpu_memory_utilization: float = 0.93,
    dtype: str = "bfloat16",
    max_model_len: int = 8192,
    enable_lora: bool = False,
    is_bnb: bool = False,
    eager: bool = False,
    lora_modules: Optional[List[str]] = None,
) -> None:
    """Start one vLLM server per GPU group inside detached tmux sessions.

    Args:
        model: Model name or path passed to ``vllm serve``.
        gpu_groups: Comma-separated GPU groups, each group a run of
            single-digit GPU ids (e.g. ``"01,23"`` -> two 2-GPU servers).
            NOTE(review): ids >= 10 cannot be expressed in this scheme.
        served_model_name: Optional ``--served-model-name`` value.
        port_start: First port; group ``i`` listens on ``port_start + i``.
        gpu_memory_utilization: Fraction of GPU memory vLLM may use.
        dtype: Model dtype forwarded to vLLM.
        max_model_len: Maximum context length.
        enable_lora: Enable runtime LoRA loading on the servers.
        is_bnb: Use bitsandbytes quantization (auto-detected when the
            model name contains "bnb" or "4bit").
        eager: Pass ``--enforce-eager``.
        lora_modules: Flat ``[name, path, name, path, ...]`` list.
    """
    print("Starting vLLM containers...,")
    gpu_groups_arr: List[str] = gpu_groups.split(",")
    VLLM_BINARY: str = get_vllm()
    if enable_lora:
        # Prefix the env assignment so each spawned server allows runtime
        # LoRA updates regardless of the parent environment.
        VLLM_BINARY = "VLLM_ALLOW_RUNTIME_LORA_UPDATING=True " + VLLM_BINARY

    # Auto-detect quantization based on model name if not explicitly set
    if not is_bnb and model and ("bnb" in model.lower() or "4bit" in model.lower()):
        is_bnb = True
        print(f"Auto-detected quantization for model: {model}")

    # Also export in this process so children inherit it via os.system.
    if enable_lora:
        os.environ["VLLM_ALLOW_RUNTIME_LORA_UPDATING"] = "True"
        print("Enabled runtime LoRA updating")

    # Fix: inner LoRA loop previously reused the name `i`, shadowing this
    # loop index; use a distinct name for clarity.
    for group_idx, gpu_group in enumerate(gpu_groups_arr):
        port = port_start + group_idx
        # Expand a digit-run group into a CUDA_VISIBLE_DEVICES list: "01" -> "0,1".
        gpu_group = ",".join(gpu_group)
        tensor_parallel = len(gpu_group.split(","))

        cmd = [
            f"CUDA_VISIBLE_DEVICES={gpu_group}",
            VLLM_BINARY,
            "serve",
            model,
            "--port",
            str(port),
            # Fix: spell out the canonical flag; the abbreviated
            # "--tensor-parallel" relied on argparse prefix matching, which
            # breaks if the prefix ever becomes ambiguous.
            "--tensor-parallel-size",
            str(tensor_parallel),
            "--gpu-memory-utilization",
            str(gpu_memory_utilization),
            "--dtype",
            dtype,
            "--max-model-len",
            str(max_model_len),
            "--enable-prefix-caching",
            "--disable-log-requests",
            "--uvicorn-log-level critical",
        ]
        if HF_HOME:
            # Prepend so the env assignment precedes the binary invocation.
            cmd.insert(0, f"HF_HOME={HF_HOME}")
        if eager:
            cmd.append("--enforce-eager")

        if served_model_name:
            cmd.extend(["--served-model-name", served_model_name])

        if is_bnb:
            cmd.extend(
                ["--quantization", "bitsandbytes", "--load-format", "bitsandbytes"]
            )

        if enable_lora:
            cmd.extend(["--fully-sharded-loras", "--enable-lora"])

        if lora_modules:
            # Flat [name, path, name, path, ...] list -> "name=path" specs.
            assert len(lora_modules) % 2 == 0, "lora_modules must be even"
            specs = " ".join(
                f"{name}={module}"
                for name, module in zip(lora_modules[::2], lora_modules[1::2])
            )
            cmd.extend(["--lora-modules", specs])

        final_cmd = " ".join(cmd)
        log_file = f"/tmp/vllm_{port}.txt"
        final_cmd_with_log = f'"{final_cmd} 2>&1 | tee {log_file}"'
        run_in_tmux = (
            f"tmux new-session -d -s vllm_{port} 'bash -c {final_cmd_with_log}'"
        )

        print(final_cmd)
        print("Logging to", log_file)
        os.system(run_in_tmux)
238
def get_vllm():
    """Locate the vLLM executable.

    Resolution order: the ``VLLM_BINARY`` env var first, then ``which vllm``.
    (Previously ``which vllm`` ran unconditionally, so a binary missing from
    PATH crashed with CalledProcessError even when VLLM_BINARY was set.)

    Returns:
        Path to an existing vLLM executable.

    Raises:
        FileNotFoundError: If the resolved path does not exist.
    """
    VLLM_BINARY = os.getenv("VLLM_BINARY")
    if not VLLM_BINARY:
        VLLM_BINARY = subprocess.check_output(
            "which vllm", shell=True, text=True
        ).strip()
    logger.info(f"vLLM binary: {VLLM_BINARY}")
    # Raise instead of assert: asserts are stripped under `python -O`.
    if not os.path.exists(VLLM_BINARY):
        raise FileNotFoundError(
            f"vLLM binary not found at {VLLM_BINARY}, please set VLLM_BINARY env variable"
        )
    return VLLM_BINARY
248
def get_args():
    """Build the svllm argument parser and parse the command line."""
    example_args = [
        "svllm serve --model MODEL_NAME --gpus 0,1,2,3",
        "svllm serve --lora LORA_NAME LORA_PATH --gpus 0,1,2,3",
        "svllm add_lora --lora LORA_NAME LORA_PATH --host_port localhost:8150",
        "svllm kill",
    ]

    p = argparse.ArgumentParser(
        description="vLLM Serve Script",
        epilog="Example: " + " || ".join(example_args),
    )
    # Positional: which sub-command to run.
    p.add_argument(
        "mode",
        choices=["serve", "kill", "add_lora", "unload_lora", "list_models"],
        help="Mode to run the script in",
    )
    # Model / GPU selection.
    p.add_argument("--model", "-m", type=str, help="Model to serve")
    p.add_argument(
        "--gpus", "-g", type=str, dest="gpu_groups",
        help="Comma-separated list of GPU groups",
    )
    p.add_argument(
        "--lora", "-l", nargs=2, metavar=("LORA_NAME", "LORA_PATH"),
        help="Name and path of the LoRA adapter",
    )
    p.add_argument("--served_model_name", type=str, help="Name of the served model")
    # Engine tuning knobs.
    p.add_argument(
        "--gpu_memory_utilization", "-gmu", type=float, default=0.9,
        help="GPU memory utilization",
    )
    p.add_argument("--dtype", type=str, default="auto", help="Data type")
    p.add_argument(
        "--max_model_len", "-mml", type=int, default=8192,
        help="Maximum model length",
    )
    # LoRA is on by default; this flag turns it off.
    p.add_argument(
        "--disable_lora", dest="enable_lora", action="store_false",
        default=True, help="Disable LoRA support",
    )
    p.add_argument("--bnb", action="store_true", help="Enable quantization")
    p.add_argument(
        "--not_verbose", action="store_true", help="Disable verbose logging"
    )
    p.add_argument("--vllm_binary", type=str, help="Path to the vLLM binary")
    p.add_argument(
        "--pipeline_parallel", "-pp", default=1, type=int,
        help="Number of pipeline parallel stages",
    )
    p.add_argument(
        "--extra_args", nargs=argparse.REMAINDER,
        help="Additional arguments for the serve command",
    )
    p.add_argument(
        "--host_port", "-hp", type=str, default="localhost:8150",
        help="Host and port for the server format: host:port",
    )
    p.add_argument("--eager", action="store_true", help="Enable eager execution")
    p.add_argument(
        "--lora_modules", "-lm", nargs="+", type=str,
        help="List of LoRA modules in the format lora_name lora_module",
    )
    return p.parse_args()
336
def _base_model_from_lora_config(lora_dir: str, use_bnb: bool) -> Optional[str]:
    """Resolve the base model name from a LoRA adapter's config.

    Reads ``adapter_config.json`` in ``lora_dir`` and returns its
    ``base_model_name_or_path``, stripping unsloth/bnb 4-bit suffixes when
    quantization is not requested. Returns None when the config file or the
    entry is missing (previously a missing entry crashed on ``None.endswith``).
    """
    lora_config = os.path.join(lora_dir, "adapter_config.json")
    if not os.path.exists(lora_config):
        return None
    config = load_by_ext(lora_config)
    model_name = config.get("base_model_name_or_path")
    if not model_name:
        return None
    # Handle different quantization suffixes.
    if not use_bnb:
        if model_name.endswith("-unsloth-bnb-4bit"):
            model_name = model_name.replace("-unsloth-bnb-4bit", "")
        elif model_name.endswith("-bnb-4bit"):
            model_name = model_name.replace("-bnb-4bit", "")
    logger.info(f"Model name from LoRA config: {model_name}")
    return model_name


def main():
    """Main entry point: dispatch on the parsed ``mode`` argument."""
    args = get_args()

    if args.mode == "serve":
        # Handle LoRA model serving via the --lora argument.
        if args.lora:
            lora_name, lora_path = args.lora
            if not args.lora_modules:
                args.lora_modules = [lora_name, lora_path]
            # Try to get the base model from the LoRA config if not specified.
            if args.model is None:
                resolved = _base_model_from_lora_config(lora_path, args.bnb)
                if resolved is not None:
                    args.model = resolved

        # Fallback: only --lora_modules was given (no --lora).
        if args.model is None and args.lora_modules is not None and not args.lora:
            resolved = _base_model_from_lora_config(args.lora_modules[1], args.bnb)
            if resolved is not None:
                args.model = resolved

        # Derive the starting port from host_port ("host:port").
        port_start = int(args.host_port.split(":")[-1])
        serve(
            args.model,
            args.gpu_groups,
            args.served_model_name,
            port_start,
            args.gpu_memory_utilization,
            args.dtype,
            args.max_model_len,
            args.enable_lora,
            args.bnb,
            args.eager,
            args.lora_modules,
        )

    elif args.mode == "kill":
        kill_existing_vllm(args.vllm_binary)
    elif args.mode == "add_lora":
        if args.lora:
            lora_name, lora_path = args.lora
            add_lora(lora_path, host_port=args.host_port, served_model_name=lora_name)
        else:
            # Fallback to old behavior: --model holds the LoRA path/name.
            lora_name = args.model
            add_lora(
                lora_name,
                host_port=args.host_port,
                served_model_name=args.served_model_name,
            )
    elif args.mode == "unload_lora":
        lora_name = args.lora[0] if args.lora else args.model
        unload_lora(lora_name, host_port=args.host_port)
    elif args.mode == "list_models":
        model_list(args.host_port)
    else:
        raise ValueError(f"Unknown mode: {args.mode}, ")


if __name__ == "__main__":
    main()
@@ -0,0 +1,85 @@
1
+ # Import specific functions and classes from modules
2
+ # Logger
3
+ from speedy_utils.common.logger import log, setup_logger
4
+
5
+ # Clock module
6
+ from .common.clock import Clock, speedy_timer, timef
7
+
8
+ # Function decorators
9
+ from .common.function_decorator import retry_runtime
10
+
11
+ # Cache utilities
12
+ from .common.utils_cache import identify, identify_uuid, memoize
13
+
14
+ # IO utilities
15
+ from .common.utils_io import (
16
+ dump_json_or_pickle,
17
+ dump_jsonl,
18
+ jdumps,
19
+ jloads,
20
+ load_by_ext,
21
+ load_json_or_pickle,
22
+ load_jsonl,
23
+ )
24
+
25
+ # Misc utilities
26
+ from .common.utils_misc import (
27
+ convert_to_builtin_python,
28
+ flatten_list,
29
+ get_arg_names,
30
+ is_notebook,
31
+ mkdir_or_exist,
32
+ )
33
+
34
+ # Print utilities
35
+ from .common.utils_print import (
36
+ display_pretty_table_html,
37
+ flatten_dict,
38
+ fprint,
39
+ print_table,
40
+ )
41
+
42
+ # Multi-worker processing
43
+ from .multi_worker.process import multi_process
44
+ from .multi_worker.thread import multi_thread
45
+
46
+ # Define __all__ explicitly
47
+ __all__ = [
48
+ # Clock module
49
+ "Clock",
50
+ "speedy_timer",
51
+ "timef",
52
+ # Function decorators
53
+ "retry_runtime",
54
+ # Cache utilities
55
+ "memoize",
56
+ "identify",
57
+ "identify_uuid",
58
+ # IO utilities
59
+ "dump_json_or_pickle",
60
+ "dump_jsonl",
61
+ "load_by_ext",
62
+ "load_json_or_pickle",
63
+ "load_jsonl",
64
+ "jdumps",
65
+ "jloads",
66
+ # Misc utilities
67
+ "mkdir_or_exist",
68
+ "flatten_list",
69
+ "get_arg_names",
70
+ "is_notebook",
71
+ "convert_to_builtin_python",
72
+ # Print utilities
73
+ "display_pretty_table_html",
74
+ "flatten_dict",
75
+ "fprint",
76
+ "print_table",
77
+ "setup_logger",
78
+ "log",
79
+ # Multi-worker processing
80
+ "multi_process",
81
+ "multi_thread",
82
+ ]
83
+
84
+ # Setup default logger
85
+ # setup_logger('D')
speedy_utils/all.py ADDED
@@ -0,0 +1,159 @@
1
+ # speedy_utils/all.py
2
+
3
+ # Provide a consolidated set of imports for convenience
4
+
5
+ # Standard library imports
6
+ import copy
7
+ import functools
8
+ import gc
9
+ import inspect
10
+ import json
11
+ import multiprocessing
12
+ import os
13
+ import os.path as osp
14
+ import pickle
15
+ import pprint
16
+ import random
17
+ import re
18
+ import sys
19
+ import textwrap
20
+ import threading
21
+ import time
22
+ import traceback
23
+ import uuid
24
+ from collections import Counter, defaultdict
25
+ from collections.abc import Callable
26
+ from concurrent.futures import ThreadPoolExecutor, as_completed
27
+ from glob import glob
28
+ from multiprocessing import Pool
29
+ from pathlib import Path
30
+ from threading import Lock
31
+ from typing import Any, Dict, Generic, List, Literal, Optional, TypeVar, Union
32
+
33
+ # Third-party imports
34
+ import numpy as np
35
+ import pandas as pd
36
+ import xxhash
37
+ from IPython.core.getipython import get_ipython
38
+ from IPython.display import HTML, display
39
+ from loguru import logger
40
+ from pydantic import BaseModel
41
+ from tabulate import tabulate
42
+ from tqdm import tqdm
43
+
44
+ # Import specific functions from speedy_utils
45
+ from speedy_utils import ( # Clock module; Function decorators; Cache utilities; IO utilities; Misc utilities; Print utilities; Multi-worker processing
46
+ Clock,
47
+ convert_to_builtin_python,
48
+ display_pretty_table_html,
49
+ dump_json_or_pickle,
50
+ dump_jsonl,
51
+ flatten_dict,
52
+ flatten_list,
53
+ fprint,
54
+ get_arg_names,
55
+ identify,
56
+ identify_uuid,
57
+ is_notebook,
58
+ jdumps,
59
+ jloads,
60
+ load_by_ext,
61
+ load_json_or_pickle,
62
+ load_jsonl,
63
+ log,
64
+ memoize,
65
+ mkdir_or_exist,
66
+ multi_process,
67
+ multi_thread,
68
+ print_table,
69
+ retry_runtime,
70
+ setup_logger,
71
+ speedy_timer,
72
+ timef,
73
+ )
74
+
75
+ # Define __all__ explicitly with all exports
76
+ __all__ = [
77
+ # Standard library
78
+ "random",
79
+ "copy",
80
+ "functools",
81
+ "gc",
82
+ "inspect",
83
+ "json",
84
+ "multiprocessing",
85
+ "os",
86
+ "osp",
87
+ "pickle",
88
+ "pprint",
89
+ "re",
90
+ "sys",
91
+ "textwrap",
92
+ "threading",
93
+ "time",
94
+ "traceback",
95
+ "uuid",
96
+ "Counter",
97
+ "ThreadPoolExecutor",
98
+ "as_completed",
99
+ "glob",
100
+ "Pool",
101
+ "Path",
102
+ "Lock",
103
+ "defaultdict",
104
+ # Typing
105
+ "Any",
106
+ "Callable",
107
+ "Dict",
108
+ "Generic",
109
+ "List",
110
+ "Literal",
111
+ "Optional",
112
+ "TypeVar",
113
+ "Union",
114
+ # Third-party
115
+ "pd",
116
+ "xxhash",
117
+ "get_ipython",
118
+ "HTML",
119
+ "display",
120
+ "logger",
121
+ "BaseModel",
122
+ "tabulate",
123
+ "tqdm",
124
+ "np",
125
+ # Clock module
126
+ "Clock",
127
+ "speedy_timer",
128
+ "timef",
129
+ # Function decorators
130
+ "retry_runtime",
131
+ # Cache utilities
132
+ "memoize",
133
+ "identify",
134
+ "identify_uuid",
135
+ # IO utilities
136
+ "dump_json_or_pickle",
137
+ "dump_jsonl",
138
+ "load_by_ext",
139
+ "load_json_or_pickle",
140
+ "load_jsonl",
141
+ "jdumps",
142
+ "jloads",
143
+ # Misc utilities
144
+ "mkdir_or_exist",
145
+ "flatten_list",
146
+ "get_arg_names",
147
+ "is_notebook",
148
+ "convert_to_builtin_python",
149
+ # Print utilities
150
+ "display_pretty_table_html",
151
+ "flatten_dict",
152
+ "fprint",
153
+ "print_table",
154
+ "setup_logger",
155
+ "log",
156
+ # Multi-worker processing
157
+ "multi_process",
158
+ "multi_thread",
159
+ ]
File without changes