speedy-utils 0.1.28__py3-none-any.whl → 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_utils/__init__.py +30 -0
- llm_utils/chat_format.py +427 -0
- llm_utils/group_messages.py +119 -0
- llm_utils/lm.py +742 -0
- llm_utils/lm_classification.py +0 -0
- llm_utils/load_chat_dataset.py +41 -0
- llm_utils/scripts/vllm_load_balancer.py +353 -0
- llm_utils/scripts/vllm_serve.py +482 -0
- speedy_utils/__init__.py +1 -2
- speedy_utils/all.py +0 -2
- speedy_utils/common/clock.py +10 -0
- speedy_utils/common/utils_misc.py +0 -1
- speedy_utils/multi_worker/thread.py +22 -6
- {speedy_utils-0.1.28.dist-info → speedy_utils-1.0.0.dist-info}/METADATA +3 -27
- speedy_utils-1.0.0.dist-info/RECORD +27 -0
- speedy_utils/common/dataclass_parser.py +0 -101
- speedy_utils/multi_worker/_handle_inputs.py +0 -50
- speedy_utils-0.1.28.dist-info/RECORD +0 -21
- {speedy_utils-0.1.28.dist-info → speedy_utils-1.0.0.dist-info}/WHEEL +0 -0
- {speedy_utils-0.1.28.dist-info → speedy_utils-1.0.0.dist-info}/entry_points.txt +0 -0
|
@@ -1,101 +0,0 @@
|
|
|
1
|
-
import argparse
|
|
2
|
-
from dataclasses import dataclass, fields, is_dataclass
|
|
3
|
-
from typing import Any, Dict, Type, TypeVar
|
|
4
|
-
|
|
5
|
-
import yaml
|
|
6
|
-
from loguru import logger
|
|
7
|
-
from tabulate import tabulate
|
|
8
|
-
|
|
9
|
-
# Emit a deprecation warning when this module is imported.
|
|
10
|
-
# NOTE(review): the message recommends speedy_utils.common.dataclass_parser,
# which appears to be this very module according to the package file list --
# confirm the intended replacement before relying on this text.
_DEPRECATION_NOTICE = (
    "This module is deprecated. "
    "Please use speedy_utils.common.dataclass_parser instead."
)
logger.warning(_DEPRECATION_NOTICE)
|
|
13
|
-
|
|
14
|
-
T = TypeVar("T", bound="ArgsParser")


class ArgsParser:
    """Mixin that gives a dataclass an argparse-based command-line interface.

    Subclasses are expected to be dataclasses. Every dataclass field becomes a
    ``--<field name>`` option, and an extra ``--yaml_file`` option points at a
    YAML file whose top-level mapping supplies constructor keyword arguments.
    Values given explicitly on the command line take precedence over the YAML
    file; anything supplied by neither falls back to the dataclass default.
    """

    @staticmethod
    def _describe_default(spec) -> str:
        """Return help-safe text describing the default of dataclass field *spec*."""
        # Local import keeps this block independent of the file's import header.
        from dataclasses import MISSING

        if spec.default is not MISSING:
            text = str(spec.default)
        elif spec.default_factory is not MISSING:
            text = "<factory>"
        else:
            text = "<required>"
        # argparse %-formats help strings, so a literal "%" must be escaped.
        return text.replace("%", "%%")

    @classmethod
    def get_parser(cls) -> argparse.ArgumentParser:
        """Build an argument parser from the dataclass fields.

        Returns:
            A parser with ``--yaml_file`` plus one option per field:

            * ``bool`` fields get ``--name`` / ``--no-name`` switches,
            * ``list[...]`` / ``typing.List[...]`` fields accept one or more
              values converted with the element type (``str`` when the list
              is unparameterised),
            * every other field uses its annotated type as the converter.

            All field options default to ``None`` so that ``from_args`` can
            tell "not given on the command line" apart from a real value.
        """
        # Local import keeps this block independent of the file's import header.
        from typing import get_args, get_origin, get_type_hints

        parser = argparse.ArgumentParser(description=f"Parser for {cls.__name__}")
        parser.add_argument(
            "--yaml_file", type=str, help="Path to YAML file with arguments"
        )

        # Resolve string annotations (e.g. under
        # ``from __future__ import annotations``); fall back to the raw
        # ``field.type`` when resolution is not possible.
        try:
            hints = get_type_hints(cls)
        except (NameError, TypeError):
            hints = {}

        for spec in fields(cls):
            flag = f"--{spec.name}"
            field_type = hints.get(spec.name, spec.type)
            default_text = cls._describe_default(spec)

            if field_type is bool:
                # default=None (instead of store_true's implicit False) keeps
                # an absent flag from overriding YAML values or a dataclass
                # default of True.
                parser.add_argument(
                    flag,
                    action=argparse.BooleanOptionalAction,
                    default=None,
                    help=f"Enable {spec.name} (default: {default_text})",
                )
            elif field_type is list or get_origin(field_type) is list:
                # Inspect the annotation instead of eval()-ing a type name
                # carved out of its string representation.
                element_types = get_args(field_type)
                parser.add_argument(
                    flag,
                    type=element_types[0] if element_types else str,
                    nargs="+",
                    default=None,
                    help=f"Override {spec.name} (default: {default_text})",
                )
            else:
                parser.add_argument(
                    flag,
                    type=field_type,
                    default=None,
                    help=f"Override {spec.name} (default: {default_text})",
                )
        return parser

    @classmethod
    def from_args(cls: type[T], args: argparse.Namespace) -> T:
        """Create an instance of the dataclass from parsed arguments.

        Args:
            args: Namespace produced by the parser from ``get_parser``.

        Returns:
            An instance built from the YAML mapping (if ``--yaml_file`` was
            given) updated with every field whose parsed value is not ``None``.

        Raises:
            TypeError: If the YAML document is not a mapping.
        """
        config: dict[str, Any] = {}
        yaml_path = getattr(args, "yaml_file", None)
        if yaml_path:
            with open(yaml_path, encoding="utf-8") as file:
                loaded = yaml.safe_load(file)
            # An empty YAML document loads as None; treat it as "no settings".
            if loaded is not None:
                if not isinstance(loaded, dict):
                    raise TypeError(
                        f"Expected a mapping at the top level of {yaml_path}, "
                        f"got {type(loaded).__name__}"
                    )
                config = loaded

        for spec in fields(cls):
            value = getattr(args, spec.name, None)
            if value is not None:
                config[spec.name] = value

        return cls(**config)

    @classmethod
    def parse_args(cls: type[T]) -> T:
        """Parse ``sys.argv`` and return an instance of the dataclass."""
        return cls.from_args(cls.get_parser().parse_args())

    def __str__(self) -> str:
        """Render the field names and values as a GitHub-style table."""
        rows = [[spec.name, getattr(self, spec.name)] for spec in fields(self)]
        return tabulate(rows, headers=["Field", "Value"], tablefmt="github")
|
|
84
|
-
|
|
85
|
-
|
|
86
|
-
@dataclass
class ExampleArgs(ArgsParser):
    """Sample configuration demonstrating the ``ArgsParser`` mixin.

    Each field below is exposed as a ``--<field name>`` command-line option by
    ``ArgsParser.get_parser``; the values shown are the dataclass defaults.
    """

    # String-valued settings (paths and identifiers).
    from_peft: str = "./outputs/llm_hn_qw32b/hn_results_r3/"
    model_name_or_path: str = "Qwen/Qwen2.5-32B-Instruct-AWQ"

    # Boolean switch and integer settings.
    use_fp16: bool = False
    batch_size: int = 1
    max_length: int = 512

    # Remaining string-valued settings.
    cache_dir: str = ".cache/run_embeds"
    output_dir: str = ".cache"
    input_file: str = ".cache/doc.csv"
    output_name: str = "qw32b_r3"
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
if __name__ == "__main__":
    # Demo entry point: parse the command line into ExampleArgs and print the
    # resulting settings table.
    args = ExampleArgs.parse_args()
    print(str(args))
|
|
@@ -1,50 +0,0 @@
|
|
|
1
|
-
import functools
|
|
2
|
-
import inspect
|
|
3
|
-
from collections.abc import Callable, Iterable
|
|
4
|
-
from typing import Any, Dict, List, Union
|
|
5
|
-
|
|
6
|
-
import pandas as pd
|
|
7
|
-
|
|
8
|
-
# Example object
|
|
9
|
-
|
|
10
|
-
|
|
11
|
-
def _get_original_func(func):
|
|
12
|
-
"""
|
|
13
|
-
Recursively unwrap a decorated function to find the actual
|
|
14
|
-
original function object.
|
|
15
|
-
"""
|
|
16
|
-
while hasattr(func, "__wrapped__"):
|
|
17
|
-
func = func.__wrapped__
|
|
18
|
-
return func
|
|
19
|
-
|
|
20
|
-
|
|
21
|
-
def handle_inputs(
|
|
22
|
-
f: Callable, inputs: list[dict[str, Any]] | list[Any] | pd.DataFrame
|
|
23
|
-
) -> list[dict[str, Any]]:
|
|
24
|
-
# 1. Unwrap in case f is decorated (e.g., by @memoize).
|
|
25
|
-
real_func = _get_original_func(f)
|
|
26
|
-
|
|
27
|
-
# 2. Count parameters with inspect.signature.
|
|
28
|
-
# This handles normal or annotated arguments, etc.
|
|
29
|
-
sig = inspect.signature(real_func)
|
|
30
|
-
num_params = len(sig.parameters)
|
|
31
|
-
|
|
32
|
-
# Convert certain input types to list to unify processing
|
|
33
|
-
if isinstance(inputs, (range, list, tuple)):
|
|
34
|
-
inputs = list(inputs)
|
|
35
|
-
|
|
36
|
-
# 3. If exactly 1 parameter, we do the single-arg logic:
|
|
37
|
-
if num_params == 1:
|
|
38
|
-
# If the user passed a dataframe, break it into rows
|
|
39
|
-
if isinstance(inputs, pd.DataFrame):
|
|
40
|
-
inputs = [r for _, r in inputs.iterrows()]
|
|
41
|
-
|
|
42
|
-
# For a single-arg function, turn each item into a dict: {arg_name: item}
|
|
43
|
-
# so we can later call func(**inp)
|
|
44
|
-
arg_name = next(iter(sig.parameters)) # name of the single parameter
|
|
45
|
-
inputs = [{arg_name: input_} for input_ in inputs]
|
|
46
|
-
return f, inputs
|
|
47
|
-
|
|
48
|
-
else:
|
|
49
|
-
|
|
50
|
-
return lambda x: f(x), inputs
|
|
@@ -1,21 +0,0 @@
|
|
|
1
|
-
speedy_utils/__init__.py,sha256=SaYWsVADsRbaWz5QVUJs3vIjZATMnpByCd7TCpD2Wmc,1761
|
|
2
|
-
speedy_utils/all.py,sha256=Zj5fjzm38Ko8xkyfApW6flkTHD4JZtGK53uveD6ZGGI,3148
|
|
3
|
-
speedy_utils/common/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
4
|
-
speedy_utils/common/clock.py,sha256=yxGWJeNoXri0Lbz8e29XkWAQqQPtmmZytA8wEiu5o0U,7009
|
|
5
|
-
speedy_utils/common/dataclass_parser.py,sha256=vrkZYth2HUJrR2NB_78qlwrzxvf5Uqt2pzQpRKCM-XM,3158
|
|
6
|
-
speedy_utils/common/function_decorator.py,sha256=r_r42qCWuNcu0_aH7musf2BWvcJfgZrD81G28mDcolw,2226
|
|
7
|
-
speedy_utils/common/logger.py,sha256=NIOlhhACpcc0BUTSJ8oDYrLp23J2gW_KJFyRVdLN2tY,6432
|
|
8
|
-
speedy_utils/common/report_manager.py,sha256=dgGfS_fHbZiQMsLzkgnj0OfB758t1x6B4MhjsetSl9A,3930
|
|
9
|
-
speedy_utils/common/utils_cache.py,sha256=gXX5qTXpCG3qDgjnOsSvxM4LkQurmcsg4QRv_zOBG1E,8378
|
|
10
|
-
speedy_utils/common/utils_io.py,sha256=vXhgrMSse_5yuP7yiSjdqKgOu8pz983glelquyNjbkE,4809
|
|
11
|
-
speedy_utils/common/utils_misc.py,sha256=RNqHD2f9FEBgwuMNuhYXj3QSItBW9q0LjKvzMi4eJiI,1844
|
|
12
|
-
speedy_utils/common/utils_print.py,sha256=QRaL2QPbks4Mtol_gJy3ZdahgUfzUEtcOp4--lBlzYI,6709
|
|
13
|
-
speedy_utils/multi_worker/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
14
|
-
speedy_utils/multi_worker/_handle_inputs.py,sha256=93tSLluQgXUizZee89zRCgqMMBSfnxXZ-gz9opC9uxY,1520
|
|
15
|
-
speedy_utils/multi_worker/process.py,sha256=XwQlffxzRFnCVeKjDNBZDwFfUQHiJiuFA12MRGJVru8,6708
|
|
16
|
-
speedy_utils/multi_worker/thread.py,sha256=NXyoGZHTDn7lAeqvnmWP6Sf1Xd6QHdzbtxt9AEe9Qn0,12388
|
|
17
|
-
speedy_utils/scripts/mpython.py,sha256=ZzkBWI5Xw3vPoMx8xQt2x4mOFRjtwWqfvAJ5_ngyWgw,3816
|
|
18
|
-
speedy_utils-0.1.28.dist-info/METADATA,sha256=koTjqoJMCsf2KiXtGMBP5Xhp6AukERaMEOijW_MJEXk,7723
|
|
19
|
-
speedy_utils-0.1.28.dist-info/WHEEL,sha256=b4K_helf-jlQoXBBETfwnf4B04YC67LOev0jo4fX5m8,88
|
|
20
|
-
speedy_utils-0.1.28.dist-info/entry_points.txt,sha256=fsv8_lMg62BeswoUHrqfj2u6q2l4YcDCw7AgQFg6GRw,61
|
|
21
|
-
speedy_utils-0.1.28.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|