speedy-utils 0.1.28__py3-none-any.whl → 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,482 @@
+ """
+ USAGE:
+ Serve models and LoRAs with vLLM:
+
+ Serve a LoRA model:
+ svllm serve --lora LORA_NAME LORA_PATH --gpus GPU_GROUPS
+
+ Serve a base model:
+ svllm serve --model MODEL_NAME --gpus GPU_GROUPS
+
+ Add a LoRA to a served model:
+ svllm add_lora --lora LORA_NAME LORA_PATH --host_port host:port (when adding, the port must be specified)
+ """
+
+ from glob import glob
+ import os
+ import subprocess
+ import time
+ from typing import List, Literal, Optional
+ from fastcore.script import call_parse
+ from loguru import logger
+ import argparse
+ import requests
+ import openai
+
+
+ LORA_DIR = os.environ.get("LORA_DIR", "/loras")
+ LORA_DIR = os.path.abspath(LORA_DIR)
+ HF_HOME = os.environ.get("HF_HOME", os.path.expanduser("~/.cache/huggingface"))
+ logger.info(f"LORA_DIR: {LORA_DIR}")
+
+
+ def model_list(host_port, api_key="abc"):
+     client = openai.OpenAI(base_url=f"http://{host_port}/v1", api_key=api_key)
+     models = client.models.list()
+     for model in models:
+         print(f"Model ID: {model.id}")
+
+
+ def kill_existing_vllm(vllm_binary: Optional[str] = None) -> None:
+     """Kill selected vLLM processes using fzf."""
+     if not vllm_binary:
+         vllm_binary = get_vllm()
+
+     # List running vLLM processes
+     result = subprocess.run(
+         f"ps aux | grep {vllm_binary} | grep -v grep",
+         shell=True,
+         capture_output=True,
+         text=True,
+     )
+     processes = result.stdout.strip().split("\n")
+
+     if not processes or processes == [""]:
+         print("No running vLLM processes found.")
+         return
+
+     # Use fzf to select processes to kill
+     fzf = subprocess.Popen(
+         ["fzf", "--multi"],
+         stdin=subprocess.PIPE,
+         stdout=subprocess.PIPE,
+         text=True,
+     )
+     selected, _ = fzf.communicate("\n".join(processes))
+
+     if not selected:
+         print("No processes selected.")
+         return
+
+     # Extract PIDs and kill selected processes
+     pids = [line.split()[1] for line in selected.strip().split("\n")]
+     for pid in pids:
+         subprocess.run(
+             f"kill -9 {pid}",
+             shell=True,
+             stdout=subprocess.DEVNULL,
+             stderr=subprocess.DEVNULL,
+         )
+     print(f"Killed processes: {', '.join(pids)}")
+
+
+ def add_lora(
+     lora_name_or_path: str,
+     host_port: str,
+     url: str = "http://HOST:PORT/v1/load_lora_adapter",
+     served_model_name: Optional[str] = None,
+     lora_module: Optional[str] = None,  # Added parameter
+ ) -> dict:
+     url = url.replace("HOST:PORT", host_port)
+     headers = {"Content-Type": "application/json"}
+
+     data = {
+         "lora_name": served_model_name,
+         "lora_path": os.path.abspath(lora_name_or_path),
+     }
+     if lora_module:  # Include lora_module if provided
+         data["lora_module"] = lora_module
+     logger.info(f"{data=}, {headers}, {url=}")
+     # logger.warning(f"Failed to unload LoRA adapter: {str(e)}")
+     try:
+         response = requests.post(url, headers=headers, json=data)
+         response.raise_for_status()
+
+         # Handle potential non-JSON responses
+         try:
+             return response.json()
+         except ValueError:
+             return {
+                 "status": "success",
+                 "message": (
+                     response.text
+                     if response.text.strip()
+                     else "Request completed with empty response"
+                 ),
+             }
+
+     except requests.exceptions.RequestException as e:
+         logger.error(f"Request failed: {str(e)}")
+         return {"error": f"Request failed: {str(e)}"}
+
+
+ def unload_lora(lora_name, host_port):
+     try:
+         url = f"http://{host_port}/v1/unload_lora_adapter"
+         logger.info(f"{url=}")
+         headers = {"Content-Type": "application/json"}
+         data = {"lora_name": lora_name}
+         logger.info(f"Unloading LoRA adapter: {data=}")
+         response = requests.post(url, headers=headers, json=data)
+         response.raise_for_status()
+         logger.success(f"Unloaded LoRA adapter: {lora_name}")
+     except requests.exceptions.RequestException as e:
+         return {"error": f"Request failed: {str(e)}"}
+
+
+ def serve(
+     model: str,
+     gpu_groups: str,
+     served_model_name: Optional[str] = None,
+     port_start: int = 8155,
+     gpu_memory_utilization: float = 0.93,
+     dtype: str = "bfloat16",
+     max_model_len: int = 8192,
+     enable_lora: bool = False,
+     is_bnb: bool = False,
+     eager: bool = False,
+     chat_template: Optional[str] = None,
+     lora_modules: Optional[List[str]] = None,  # Updated type
+ ):
+     """Start vLLM containers with dynamic args."""
+
+     print("Starting vLLM containers...")
+     gpu_groups_arr = gpu_groups.split(",")
+     VLLM_BINARY = get_vllm()
+     if enable_lora:
+         VLLM_BINARY = "VLLM_ALLOW_RUNTIME_LORA_UPDATING=True " + VLLM_BINARY
+
+     # Auto-detect quantization based on model name if not explicitly set
+     if not is_bnb and model and ("bnb" in model.lower() or "4bit" in model.lower()):
+         is_bnb = True
+         print(f"Auto-detected quantization for model: {model}")
+
+     # Set environment variables for LoRA if needed
+     if enable_lora:
+         os.environ["VLLM_ALLOW_RUNTIME_LORA_UPDATING"] = "True"
+         print("Enabled runtime LoRA updating")
+
+     for i, gpu_group in enumerate(gpu_groups_arr):
+         port = port_start + i
+         gpu_group = ",".join([str(x) for x in gpu_group])
+         tensor_parallel = len(gpu_group.split(","))
+
+         cmd = [
+             f"CUDA_VISIBLE_DEVICES={gpu_group}",
+             VLLM_BINARY,
+             "serve",
+             model,
+             "--port",
+             str(port),
+             "--tensor-parallel",
+             str(tensor_parallel),
+             "--gpu-memory-utilization",
+             str(gpu_memory_utilization),
+             "--dtype",
+             dtype,
+             "--max-model-len",
+             str(max_model_len),
+             "--enable-prefix-caching",
+             "--disable-log-requests",
+             "--uvicorn-log-level critical",
+         ]
+         if HF_HOME:
+             # insert
+             cmd.insert(0, f"HF_HOME={HF_HOME}")
+         if eager:
+             cmd.append("--enforce-eager")
+
+         if served_model_name:
+             cmd.extend(["--served-model-name", served_model_name])
+
+         if is_bnb:
+             cmd.extend(
+                 ["--quantization", "bitsandbytes", "--load-format", "bitsandbytes"]
+             )
+
+         if enable_lora:
+             cmd.extend(["--fully-sharded-loras", "--enable-lora"])
+
+         if chat_template:
+             chat_template = get_chat_template(chat_template)
+             cmd.extend(["--chat-template", chat_template])  # Add chat_template argument
+         if lora_modules:
+             # len must be even and we will join each (name, module) pair with `=`
+             assert len(lora_modules) % 2 == 0, "lora_modules must be even"
+             s = ""
+             for i in range(0, len(lora_modules), 2):
+                 name = lora_modules[i]
+                 module = lora_modules[i + 1]
+                 s += f"{name}={module} "
+
+             cmd.extend(["--lora-modules", s])
+         # add kwargs
+         final_cmd = " ".join(cmd)
+         log_file = f"/tmp/vllm_{port}.txt"
+         final_cmd_with_log = f'"{final_cmd} 2>&1 | tee {log_file}"'
+         run_in_tmux = (
+             f"tmux new-session -d -s vllm_{port} 'bash -c {final_cmd_with_log}'"
+         )
+
+         print(final_cmd)
+         print("Logging to", log_file)
+         os.system(run_in_tmux)
+
+
+ def get_vllm():
+     VLLM_BINARY = subprocess.check_output("which vllm", shell=True, text=True).strip()
+     VLLM_BINARY = os.getenv("VLLM_BINARY", VLLM_BINARY)
+     logger.info(f"vLLM binary: {VLLM_BINARY}")
+     assert os.path.exists(
+         VLLM_BINARY
+     ), f"vLLM binary not found at {VLLM_BINARY}, please set VLLM_BINARY env variable"
+     return VLLM_BINARY
+
+
+ def get_args():
+     """Parse command line arguments."""
+     example_args = [
+         "svllm serve --model MODEL_NAME --gpus 0,1,2,3",
+         "svllm serve --lora LORA_NAME LORA_PATH --gpus 0,1,2,3",
+         "svllm add_lora --lora LORA_NAME LORA_PATH --host_port localhost:8150",
+         "svllm kill",
+     ]
+
+     parser = argparse.ArgumentParser(
+         description="vLLM Serve Script", epilog="Example: " + " || ".join(example_args)
+     )
+     parser.add_argument(
+         "mode",
+         choices=["serve", "kill", "add_lora", "unload_lora", "list_models"],
+         help="Mode to run the script in",
+     )
+     parser.add_argument("--model", "-m", type=str, help="Model to serve")
+     parser.add_argument(
+         "--gpus",
+         "-g",
+         type=str,
+         help="Comma-separated list of GPU groups",
+         dest="gpu_groups",
+     )
+     parser.add_argument(
+         "--lora",
+         "-l",
+         nargs=2,
+         metavar=("LORA_NAME", "LORA_PATH"),
+         help="Name and path of the LoRA adapter",
+     )
+     parser.add_argument(
+         "--served_model_name", type=str, help="Name of the served model"
+     )
+     parser.add_argument(
+         "--gpu_memory_utilization",
+         "-gmu",
+         type=float,
+         default=0.9,
+         help="GPU memory utilization",
+     )
+     parser.add_argument("--dtype", type=str, default="auto", help="Data type")
+     parser.add_argument(
+         "--max_model_len", "-mml", type=int, default=8192, help="Maximum model length"
+     )
+     parser.add_argument(
+         "--disable_lora",
+         dest="enable_lora",
+         action="store_false",
+         help="Disable LoRA support",
+         default=True,
+     )
+     parser.add_argument("--bnb", action="store_true", help="Enable quantization")
+     parser.add_argument(
+         "--not_verbose", action="store_true", help="Disable verbose logging"
+     )
+     parser.add_argument("--vllm_binary", type=str, help="Path to the vLLM binary")
+     parser.add_argument(
+         "--pipeline_parallel",
+         "-pp",
+         default=1,
+         type=int,
+         help="Number of pipeline parallel stages",
+     )
+     # parser.add_argument(
+     #     "--extra_args",
+     #     nargs=argparse.REMAINDER,
+     #     help="Additional arguments for the serve command",
+     # )
+     parser.add_argument(
+         "--host_port",
+         "-hp",
+         type=str,
+         default="localhost:8150",
+         help="Host and port for the server format: host:port",
+     )
+     parser.add_argument("--eager", action="store_true", help="Enable eager execution")
+     parser.add_argument(
+         "--chat_template",
+         type=str,
+         help="Path to the chat template file",
+     )
+     parser.add_argument(
+         "--lora_modules",
+         "-lm",
+         nargs="+",
+         type=str,
+         help="List of LoRA modules in the format lora_name lora_module",
+     )
+     return parser.parse_args()
+
+
+ from speedy_utils import jloads, load_by_ext, memoize
+
+
+ def fetch_chat_template(template_name: str = "qwen") -> str:
+     """
+     Fetches a chat template file from a remote repository or local cache.
+
+     Args:
+         template_name (str): Name of the chat template. Defaults to 'qwen'.
+
+     Returns:
+         str: Path to the downloaded or cached chat template file.
+
+     Raises:
+         AssertionError: If the template_name is not supported.
+         ValueError: If the file URL is invalid.
+     """
+     supported_templates = [
+         "alpaca",
+         "chatml",
+         "gemma-it",
+         "llama-2-chat",
+         "mistral-instruct",
+         "qwen2.5-instruct",
+         "saiga",
+         "vicuna",
+         "qwen",
+     ]
+     assert template_name in supported_templates, (
+         f"Chat template '{template_name}' not supported. "
+         f"Please choose from {supported_templates}."
+     )
+
+     # Map 'qwen' to 'qwen2.5-instruct'
+     if template_name == "qwen":
+         template_name = "qwen2.5-instruct"
+
+     remote_url = (
+         f"https://raw.githubusercontent.com/chujiezheng/chat_templates/"
+         f"main/chat_templates/{template_name}.jinja"
+     )
+     local_cache_path = f"/tmp/chat_template_{template_name}.jinja"
+
+     if remote_url.startswith("http"):
+         import requests
+
+         response = requests.get(remote_url)
+         with open(local_cache_path, "w") as file:
+             file.write(response.text)
+         return local_cache_path
+
+     raise ValueError("The file URL must be a valid HTTP URL.")
+
+
+ def get_chat_template(template_name: str) -> str:
+     return fetch_chat_template(template_name)
+
+
+ def main():
+     """Main entry point for the script."""
+
+     args = get_args()
+
+     if args.mode == "serve":
+         # Handle LoRA model serving via the new --lora argument
+         if args.lora:
+             lora_name, lora_path = args.lora
+             if not args.lora_modules:
+                 args.lora_modules = [lora_name, lora_path]
+             # Try to get the model from LoRA config if not specified
+             if args.model is None:
+                 lora_config = os.path.join(lora_path, "adapter_config.json")
+                 if os.path.exists(lora_config):
+                     config = load_by_ext(lora_config)
+                     model_name = config.get("base_model_name_or_path")
+                     # Handle different quantization suffixes
+                     if model_name.endswith("-unsloth-bnb-4bit") and not args.bnb:
+                         model_name = model_name.replace("-unsloth-bnb-4bit", "")
+                     elif model_name.endswith("-bnb-4bit") and not args.bnb:
+                         model_name = model_name.replace("-bnb-4bit", "")
+                     logger.info(f"Model name from LoRA config: {model_name}")
+                     args.model = model_name
+
+         # Fall back to existing logic for other cases (already specified lora_modules)
+         if args.model is None and args.lora_modules is not None and not args.lora:
+             lora_config = os.path.join(args.lora_modules[1], "adapter_config.json")
+             if os.path.exists(lora_config):
+                 config = load_by_ext(lora_config)
+                 model_name = config.get("base_model_name_or_path")
+                 if model_name.endswith("-unsloth-bnb-4bit") and not args.bnb:
+                     model_name = model_name.replace("-unsloth-bnb-4bit", "")
+                 elif model_name.endswith("-bnb-4bit") and not args.bnb:
+                     model_name = model_name.replace("-bnb-4bit", "")
+                 logger.info(f"Model name from LoRA config: {model_name}")
+                 args.model = model_name
+         # port_start from hostport
+         port_start = int(args.host_port.split(":")[-1])
+         serve(
+             args.model,
+             args.gpu_groups,
+             args.served_model_name,
+             port_start,
+             args.gpu_memory_utilization,
+             args.dtype,
+             args.max_model_len,
+             args.enable_lora,
+             args.bnb,
+             args.eager,
+             args.chat_template,
+             args.lora_modules,
+         )
+
+     elif args.mode == "kill":
+         kill_existing_vllm(args.vllm_binary)
+     elif args.mode == "add_lora":
+         if args.lora:
+             lora_name, lora_path = args.lora
+             add_lora(lora_path, host_port=args.host_port, served_model_name=lora_name)
+         else:
+             # Fallback to old behavior
+             lora_name = args.model
+             add_lora(
+                 lora_name,
+                 host_port=args.host_port,
+                 served_model_name=args.served_model_name,
+             )
+     elif args.mode == "unload_lora":
+         if args.lora:
+             lora_name = args.lora[0]
+         else:
+             lora_name = args.model
+         unload_lora(lora_name, host_port=args.host_port)
+     elif args.mode == "list_models":
+         model_list(args.host_port)
+     else:
+         raise ValueError(f"Unknown mode: {args.mode}")
+
+
+ if __name__ == "__main__":
+     main()
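
For reference, the new add_lora() helper above is a thin wrapper around the server's runtime LoRA endpoint. Below is a minimal sketch of the equivalent raw request, assuming a server launched by this script on the default localhost:8150 and a hypothetical adapter name and path:

import requests

# Assumes the server was started by serve() with runtime LoRA updating enabled
# (VLLM_ALLOW_RUNTIME_LORA_UPDATING=True); the adapter name and path are placeholders.
resp = requests.post(
    "http://localhost:8150/v1/load_lora_adapter",
    headers={"Content-Type": "application/json"},
    json={"lora_name": "my-adapter", "lora_path": "/loras/my-adapter"},
)
resp.raise_for_status()
print(resp.text or "adapter loaded")
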
speedy_utils/__init__.py CHANGED
@@ -41,7 +41,7 @@ from .common.utils_print import (
 
  # Multi-worker processing
  from .multi_worker.process import multi_process
- from .multi_worker.thread import multi_threaad_standard, multi_thread
+ from .multi_worker.thread import multi_thread
 
  # Define __all__ explicitly
  __all__ = [
@@ -79,7 +79,6 @@ __all__ = [
      # Multi-worker processing
      "multi_process",
      "multi_thread",
-     "multi_threaad_standard",
  ]
 
  # Setup default logger
speedy_utils/all.py CHANGED
@@ -64,7 +64,6 @@ from speedy_utils import ( # Clock module; Function decorators; Cache utilities
      memoize,
      mkdir_or_exist,
      multi_process,
-     multi_threaad_standard,
      multi_thread,
      print_table,
      retry_runtime,
@@ -157,5 +156,4 @@ __all__ = [
      # Multi-worker processing
      "multi_process",
      "multi_thread",
-     "multi_threaad_standard",
  ]
@@ -96,6 +96,11 @@ class Clock:
              )
              return
          current_time = time.time()
+         if self.last_checkpoint is None:
+             logger.opt(depth=2).warning(
+                 "Last checkpoint is not set. Please call start() before using this method."
+             )
+             return
          elapsed = current_time - self.last_checkpoint
          self.last_checkpoint = current_time
          return elapsed
@@ -111,6 +116,11 @@
                  "Timer has not been started. Please call start() before using this method."
              )
              return
+         if self.last_checkpoint is None:
+             logger.opt(depth=2).warning(
+                 "Last checkpoint is not set. Please call start() before using this method."
+             )
+             return
          return time.time() - self.last_checkpoint
 
      def update_task(self, task_name):
@@ -6,7 +6,6 @@ import sys
  from collections.abc import Callable
  from typing import Any, List
 
- from IPython import get_ipython
  from pydantic import BaseModel
 
 
@@ -1,13 +1,12 @@
  """Provides thread-based parallel execution utilities."""
 
- import inspect
  import os
  import time
  import traceback
  from collections.abc import Callable, Iterable
  from concurrent.futures import ThreadPoolExecutor, as_completed
  from itertools import islice
- from typing import Any, TypeVar, cast
+ from typing import Any, TypeVar
 
  from loguru import logger
 
@@ -80,7 +79,7 @@ def multi_thread(
          ``False`` the failing task’s result becomes ``None``.
      **fixed_kwargs – static keyword args forwarded to every ``func()`` call.
      """
-     from speedy_utils import dump_json_or_pickle, identify, load_by_ext
+     from speedy_utils import dump_json_or_pickle, load_by_ext
 
      if n_proc > 1:
          import tempfile
@@ -297,10 +296,27 @@ def multi_thread(
      return results
 
 
- def multi_threaad_standard(fn, items, workers=4):
-     """
+ def multi_thread_standard(
+     fn: Callable[[Any], Any], items: Iterable[Any], workers: int = 4
+ ) -> list[Any]:
+     """Execute a function using standard ThreadPoolExecutor.
+
      A standard implementation of multi-threading using ThreadPoolExecutor.
      Ensures the order of results matches the input order.
+
+     Parameters
+     ----------
+     fn : Callable
+         The function to execute for each item.
+     items : Iterable
+         The items to process.
+     workers : int, optional
+         Number of worker threads, by default 4.
+
+     Returns
+     -------
+     list
+         Results in same order as input items.
      """
      with ThreadPoolExecutor(max_workers=workers) as executor:
          futures = [executor.submit(fn, item) for item in items]
@@ -308,4 +324,4 @@ def multi_threaad_standard(fn, items, workers=4):
          return results
 
 
- __all__ = ["multi_thread", "multi_threaad_standard"]
+ __all__ = ["multi_thread", "multi_thread_standard"]
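
Note that the misspelled multi_threaad_standard helper is renamed to multi_thread_standard here, gaining type hints and a NumPy-style docstring, but it is no longer re-exported from the package root (see the __init__.py and all.py hunks above). A minimal usage sketch, assuming the module path listed in the wheel's RECORD:

from speedy_utils.multi_worker.thread import multi_thread_standard

def square(x: int) -> int:
    return x * x

# Results are returned in the same order as the inputs, per the docstring above.
results = multi_thread_standard(square, list(range(8)), workers=4)
assert results == [x * x for x in range(8)]
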
@@ -1,6 +1,6 @@
  Metadata-Version: 2.3
  Name: speedy-utils
- Version: 0.1.28
+ Version: 1.0.0
  Summary: Fast and easy-to-use package for data science
  Author: AnhVTH
  Author-email: anhvth.226@gmail.com
@@ -19,11 +19,12 @@ Requires-Dist: fastprogress
  Requires-Dist: freezegun (>=1.5.1,<2.0.0)
  Requires-Dist: ipdb
  Requires-Dist: ipywidgets
- Requires-Dist: json-repair
+ Requires-Dist: json-repair (>=0.40.0,<0.41.0)
  Requires-Dist: jupyterlab
  Requires-Dist: loguru
  Requires-Dist: matplotlib
  Requires-Dist: numpy
+ Requires-Dist: packaging (>=23.2,<25)
  Requires-Dist: pandas
  Requires-Dist: pydantic
  Requires-Dist: requests
@@ -261,29 +262,6 @@ Ensure all dependencies are installed before running tests:
  pip install -r requirements.txt
  ```
 
- ## Data Arguments
-
- Define and parse data arguments using a dataclass:
-
- ```python
- from dataclasses import dataclass
- from speedy_utils.common.dataclass_parser import ArgsParser
-
- @dataclass
- class ExampleArgs(ArgsParser):
-     from_peft: str = "./outputs/llm_hn_qw32b/hn_results_r3/"
-     model_name_or_path: str = "Qwen/Qwen2.5-32B-Instruct-AWQ"
-     use_fp16: bool = False
-     batch_size: int = 1
-     max_length: int = 512
-     cache_dir: str = ".cache/run_embeds"
-     output_dir: str = ".cache"
-     input_file: str = ".cache/doc.csv"
-     output_name: str = "qw32b_r3"
-
- args = ExampleArgs.parse_args()
- print(args)
- ```
 
  Run the script to parse and display the arguments:
 
@@ -299,5 +277,3 @@ Example output:
 
  Please ensure your code adheres to the project's coding standards and includes appropriate tests.
 
-
-
@@ -0,0 +1,27 @@
+ llm_utils/__init__.py,sha256=4IO_cP7FlFvV57W7_AkJGdXWcZfWALkr6UwR9s9QEF4,701
+ llm_utils/chat_format.py,sha256=ZY2HYv3FPL2xiMxbbO-huIwT5LZrcJm_if_us-2eSZ4,15094
+ llm_utils/group_messages.py,sha256=GKMQkenQf-6DD_1EJa11UBj7-VfkGT7xVhR_B_zMzqY,3868
+ llm_utils/lm.py,sha256=3d3b9UMtIeKcG6vpXHwuDZ3QP46X1aSGeMvayy3tmHs,29044
+ llm_utils/lm_classification.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ llm_utils/load_chat_dataset.py,sha256=hsPPlOmZEDqsg7GQD7SgxTEGJky6JDm6jnG3JHbpjb4,1895
+ llm_utils/scripts/vllm_load_balancer.py,sha256=uSjGd_jOmI9W9eVOhiOXbeUnZkQq9xG4bCVzhmpupcA,16096
+ llm_utils/scripts/vllm_serve.py,sha256=-gXF3DVs7sfZ4mVmv_0OFheqytMQbr73YkS3bQlFdpE,15818
+ speedy_utils/__init__.py,sha256=I2bSfDIE9yRF77tnHW0vqfExDA2m1gUx4AH8C9XmGtg,1707
+ speedy_utils/all.py,sha256=A9jiKGjo950eg1pscS9x38OWAjKGyusoAN5mrfweY4E,3090
+ speedy_utils/common/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ speedy_utils/common/clock.py,sha256=3n4FkCW0dz46O8By09V5Pve1DSMgpLDRbWEVRryryeQ,7423
+ speedy_utils/common/function_decorator.py,sha256=r_r42qCWuNcu0_aH7musf2BWvcJfgZrD81G28mDcolw,2226
+ speedy_utils/common/logger.py,sha256=NIOlhhACpcc0BUTSJ8oDYrLp23J2gW_KJFyRVdLN2tY,6432
+ speedy_utils/common/report_manager.py,sha256=dgGfS_fHbZiQMsLzkgnj0OfB758t1x6B4MhjsetSl9A,3930
+ speedy_utils/common/utils_cache.py,sha256=gXX5qTXpCG3qDgjnOsSvxM4LkQurmcsg4QRv_zOBG1E,8378
+ speedy_utils/common/utils_io.py,sha256=vXhgrMSse_5yuP7yiSjdqKgOu8pz983glelquyNjbkE,4809
+ speedy_utils/common/utils_misc.py,sha256=nsQOu2jcplcFHVQ1CnOjEpNcdxIINveGxB493Cqo63U,1812
+ speedy_utils/common/utils_print.py,sha256=QRaL2QPbks4Mtol_gJy3ZdahgUfzUEtcOp4--lBlzYI,6709
+ speedy_utils/multi_worker/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ speedy_utils/multi_worker/process.py,sha256=XwQlffxzRFnCVeKjDNBZDwFfUQHiJiuFA12MRGJVru8,6708
+ speedy_utils/multi_worker/thread.py,sha256=9pXjvgjD0s0Hp0cZ6I3M0ndp1OlYZ1yvqbs_bcun_Kw,12775
+ speedy_utils/scripts/mpython.py,sha256=ZzkBWI5Xw3vPoMx8xQt2x4mOFRjtwWqfvAJ5_ngyWgw,3816
+ speedy_utils-1.0.0.dist-info/METADATA,sha256=nIcZcNBPA9_eUs7on52UMVtcV43lmVFRJLHLptk5AA0,7165
+ speedy_utils-1.0.0.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
+ speedy_utils-1.0.0.dist-info/entry_points.txt,sha256=fsv8_lMg62BeswoUHrqfj2u6q2l4YcDCw7AgQFg6GRw,61
+ speedy_utils-1.0.0.dist-info/RECORD,,