speedy-utils 1.0.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- speedy/__init__.py +35 -0
- speedy/common/__init__.py +0 -0
- speedy/common/clock.py +68 -0
- speedy/common/utils_cache.py +170 -0
- speedy/common/utils_io.py +111 -0
- speedy/common/utils_misc.py +46 -0
- speedy/common/utils_print.py +138 -0
- speedy/multi_worker.py +121 -0
- speedy_utils-1.0.2.dist-info/METADATA +22 -0
- speedy_utils-1.0.2.dist-info/RECORD +12 -0
- speedy_utils-1.0.2.dist-info/WHEEL +5 -0
- speedy_utils-1.0.2.dist-info/top_level.txt +1 -0
speedy/__init__.py
ADDED
@@ -0,0 +1,35 @@
from .common.clock import Clock, timef
from .common.utils_cache import (ICACHE, SPEED_CACHE_DIR, identify, imemoize,
                                 imemoize_v2, memoize, memoize_v2)
from .common.utils_io import (dump_json_or_pickle, dump_jsonl, load_by_ext,
                              load_json_or_pickle)
from .common.utils_misc import (convert_to_builtin_python, flatten_list,
                                get_arg_names, is_interactive, mkdir_or_exist)
from .common.utils_print import fprint, print_table
from .multi_worker import async_multi_thread, multi_process, multi_thread

__all__ = [
    "SPEED_CACHE_DIR",
    "ICACHE",
    "mkdir_or_exist",
    "dump_jsonl",
    "dump_json_or_pickle",
    "timef",
    "load_json_or_pickle",
    "load_by_ext",
    "identify",
    "memoize",
    "imemoize",
    "imemoize_v2",
    "flatten_list",
    "fprint",
    "get_arg_names",
    "memoize_v2",
    "is_interactive",
    "print_table",
    "convert_to_builtin_python",
    "Clock",
    "multi_thread",
    "multi_process",
    "async_multi_thread",
]
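
For orientation, a minimal usage sketch of the API re-exported here; it assumes the wheel is installed and importable as speedy (top_level.txt below lists speedy):

from speedy import Clock, multi_thread

clock = Clock()                                # starts timing immediately (start_now=True)
squares = multi_thread(lambda x: x * x, list(range(100)), workers=4)
clock.log()                                    # logs the elapsed wall-clock time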

speedy/common/__init__.py
ADDED
File without changes

speedy/common/clock.py
ADDED
@@ -0,0 +1,68 @@
import time

from loguru import logger

__all__ = ["Clock", "timef"]


def timef(func):
    def wrapper(*args, **kwargs):
        start_time = time.time()
        result = func(*args, **kwargs)
        end_time = time.time()
        execution_time = end_time - start_time
        print(f"{func.__name__} took {execution_time:0.2f} seconds to execute.")
        return result

    return wrapper


class Clock:
    def __init__(self, start_now=True):
        self.start_time = None
        self.time_table = {}
        self.last_check = None
        if start_now:
            self.start()
        self.pbar_counter = 0
        self.last_print = time.time()

    def start(self):
        self.start_time = time.time() if self.start_time is None else self.start_time
        self.last_check = self.start_time

    def since_start(self):
        if self.start_time is None:
            raise ValueError("Clock has not been started.")
        return time.time() - self.start_time

    def log(self, custom_logger=None):
        msg = f"Time elapsed: {self.since_start():.2f} seconds."
        if custom_logger:
            custom_logger(msg)
        else:
            logger.info(msg)

    def since_last_check(self):
        now = time.time()
        elapsed = now - self.last_check
        self.last_check = now
        return elapsed

    def update(self, name):
        if name not in self.time_table:
            self.time_table[name] = 0
        self.time_table[name] += self.since_last_check()

    def print_table(self, every=1):
        now = time.time()
        if now - self.last_print > every:
            self.pbar_counter += 1
            total_time = sum(self.time_table.values())
            desc = "Time table: "
            for name, t in self.time_table.items():
                percentage = (t / total_time) * 100
                desc += "{}: {:.2f} s ({:.2f}%), total: {:.2f} s | ".format(
                    name, t, percentage, total_time
                )
            logger.info(desc)
            self.last_print = now

speedy/common/utils_cache.py
ADDED
@@ -0,0 +1,170 @@
# utils/utils_cache.py

import functools
import inspect
import os
import os.path as osp
import pickle
import traceback
from typing import Any, Callable, Dict, List, Optional

import xxhash
from loguru import logger

from .utils_io import dump_json_or_pickle, load_json_or_pickle
from .utils_misc import mkdir_or_exist

SPEED_CACHE_DIR = osp.join(osp.expanduser("~"), ".cache/av")
ICACHE: Dict[str, Any] = {}


def identify(x: Any) -> str:
    """Return a hex digest of the input."""
    return xxhash.xxh64(pickle.dumps(x), seed=0).hexdigest()


def memoize(
    func: Callable,
    ignore_self: bool = True,
    cache_dir: str = SPEED_CACHE_DIR,
    cache_type: str = ".pkl",
    verbose: bool = False,
    cache_key: Optional[str] = None,
) -> Callable:
    """Cache the result of a function call on disk."""
    assert cache_type in [".pkl", ".json"]
    if os.environ.get("AV_MEMOIZE_DISABLE", "0") == "1":
        logger.info("Memoize is disabled")
        return func

    @functools.wraps(func)
    def memoized_func(*args, **kwargs):
        try:
            arg_names = inspect.getfullargspec(func).args
            func_source = inspect.getsource(func).replace(" ", "")
            if cache_key is not None:
                logger.info(f"Use cache_key={kwargs[cache_key]}")
                fid = [func_source, kwargs[cache_key]]
                func_id = identify(fid)
            elif len(arg_names) > 0 and arg_names[0] == "self" and ignore_self:
                func_id = identify((func_source, args[1:], kwargs))
            else:
                func_id = identify((func_source, args, kwargs))

            cache_path = osp.join(
                cache_dir, "funcs", func.__name__, f"{func_id}{cache_type}"
            )
            mkdir_or_exist(os.path.dirname(cache_path))
            if osp.exists(cache_path):
                if verbose:
                    logger.info(f"Load from cache file: {cache_path}")
                result = load_json_or_pickle(cache_path)
            else:
                result = func(*args, **kwargs)
                dump_json_or_pickle(result, cache_path)
            return result
        except Exception as e:
            traceback.print_exc()
            logger.warning(f"Exception: {e}, using default function call")
            return func(*args, **kwargs)

    return memoized_func


def imemoize(func: Callable) -> Callable:
    """Memoize a function into memory."""

    @functools.wraps(func)
    def _f(*args, **kwargs):
        ident_name = identify((inspect.getsource(func), args, kwargs))
        try:
            return ICACHE[ident_name]
        except KeyError:
            result = func(*args, **kwargs)
            ICACHE[ident_name] = result
            return result

    return _f


def imemoize_v2(keys: List[str]) -> Callable:
    """Memoize a function into memory based on specified keys."""

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            arg_names = inspect.getfullargspec(func).args
            args_dict = dict(zip(arg_names, args))
            all_args = {**args_dict, **kwargs}
            key_values = {key: all_args[key] for key in keys if key in all_args}
            if not key_values:
                return func(*args, **kwargs)

            ident_name = identify((func.__name__, tuple(sorted(key_values.items()))))
            try:
                return ICACHE[ident_name]
            except KeyError:
                result = func(*args, **kwargs)
                ICACHE[ident_name] = result
                return result

        return wrapper

    return decorator


def memoize_v2(keys: List[str], cache_dir: str = SPEED_CACHE_DIR) -> Callable:
    """Decorator to memoize function results on disk based on specific keys."""

    def decorator(func: Callable) -> Callable:
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            args_key_values = {}
            for i, arg in enumerate(args):
                arg_name = inspect.getfullargspec(func).args[i]
                args_key_values[arg_name] = arg
            args_key_values.update(kwargs)

            values = [args_key_values[key] for key in keys if key in args_key_values]
            if not values:
                return func(*args, **kwargs)

            key_id = identify(values)
            func_source = inspect.getsource(func).replace(" ", "")
            func_id = identify(func_source)
            key_names = "_".join(keys)
            cache_path = osp.join(
                cache_dir, f"{func.__name__}_{func_id}", f"{key_names}_{key_id}.pkl"
            )
            if osp.exists(cache_path):
                return load_json_or_pickle(cache_path)
            result = func(*args, **kwargs)
            dump_json_or_pickle(result, cache_path)
            return result

        return wrapper

    return decorator


def memoize_method(method):
    """
    Decorator function to memoize (cache) results of a class method.

    This decorator caches the output of the wrapped method based on its input arguments
    (both positional and keyword). If the method is called again with the same arguments,
    it returns the cached result instead of executing the method again.

    Args:
        method (Callable): The decorated method whose result will be memoized.
    """
    cache = {}

    def cached_method(cls, *args, **kwargs):
        cache_key = identify([args, kwargs])
        logger.debug("HIT" if cache_key in cache else "MISS")
        if cache_key not in cache:
            cache[cache_key] = method(cls, *args, **kwargs)
        return cache[cache_key]

    return cached_method

speedy/common/utils_io.py
ADDED
@@ -0,0 +1,111 @@
# utils/utils_io.py

import json
import os
import os.path as osp
import pickle
from glob import glob
from typing import Any, List, Dict, Union

from .utils_misc import mkdir_or_exist


def dump_jsonl(list_dictionaries: List[Dict], file_name: str = "output.jsonl") -> None:
    """
    Dumps a list of dictionaries to a file in JSON Lines format.
    """
    with open(file_name, "w", encoding="utf-8") as file:
        for dictionary in list_dictionaries:
            file.write(json.dumps(dictionary, ensure_ascii=False) + "\n")


def dump_json_or_pickle(
    obj: Any, fname: str, ensure_ascii: bool = False, indent: int = 4
) -> None:
    """
    Dump an object to a file, supporting both JSON and pickle formats.
    """
    mkdir_or_exist(osp.abspath(os.path.dirname(osp.abspath(fname))))
    if fname.endswith(".json"):
        with open(fname, "w", encoding="utf-8") as f:
            json.dump(obj, f, ensure_ascii=ensure_ascii, indent=indent)
    elif fname.endswith(".jsonl"):
        dump_jsonl(obj, fname)
    elif fname.endswith(".pkl"):
        with open(fname, "wb") as f:
            pickle.dump(obj, f)
    else:
        raise NotImplementedError(f"File type {fname} not supported")


def load_json_or_pickle(fname: str) -> Any:
    """
    Load an object from a file, supporting both JSON and pickle formats.
    """
    if fname.endswith(".json") or fname.endswith(".jsonl"):
        with open(fname, "r", encoding="utf-8") as f:
            return json.load(f)
    else:
        with open(fname, "rb") as f:
            return pickle.load(f)


def load_by_ext(fname: Union[str, List[str]], do_memoize: bool = False) -> Any:
    """
    Load data based on file extension.
    """
    from .utils_cache import memoize  # imported locally to avoid a circular import

    from speedy import multi_process  # imported locally to avoid a circular import

    try:
        if isinstance(fname, str) and "*" in fname:
            paths = glob(fname)
            paths = sorted(paths)
            return multi_process(load_by_ext, paths, workers=16)
        elif isinstance(fname, list):
            paths = fname
            return multi_process(load_by_ext, paths, workers=16)

        def load_csv(path: str, **pd_kwargs) -> Any:
            import pandas as pd

            return pd.read_csv(path, engine="pyarrow", **pd_kwargs)

        def load_txt(path: str) -> List[str]:
            with open(path, "r", encoding="utf-8") as f:
                return f.read().splitlines()

        def load_default(path: str) -> Any:
            if path.endswith(".jsonl") or path.endswith(".json"):
                try:
                    return load_json_or_pickle(path)
                except json.JSONDecodeError as exc:
                    raise ValueError("JSON decoding failed") from exc
            return load_json_or_pickle(path)

        handlers = {
            ".csv": load_csv,
            ".tsv": load_csv,
            ".txt": load_txt,
            ".pkl": load_default,
            ".json": load_default,
            ".jsonl": load_default,
        }

        ext = os.path.splitext(fname)[-1]
        load_fn = handlers.get(ext)

        if not load_fn:
            raise NotImplementedError(f"File type {ext} not supported")

        if do_memoize:
            load_fn = memoize(load_fn)

        return load_fn(fname)
    except Exception as e:
        raise ValueError(f"Error {e} while loading {fname}") from e

speedy/common/utils_misc.py
ADDED
@@ -0,0 +1,46 @@
# utils/utils_misc.py

import inspect
import json
import os
import sys
from typing import Any, Callable, List

from IPython import get_ipython
from pydantic import BaseModel


def mkdir_or_exist(dir_name: str) -> None:
    """Create a directory if it doesn't exist."""
    os.makedirs(dir_name, exist_ok=True)


def flatten_list(list_of_lists: List[List[Any]]) -> List[Any]:
    """Flatten a list of lists into a single list."""
    return [item for sublist in list_of_lists for item in sublist]


def get_arg_names(func: Callable) -> List[str]:
    """Retrieve argument names of a function."""
    return inspect.getfullargspec(func).args


def is_interactive() -> bool:
    """Check if the environment is interactive (e.g., Jupyter notebook)."""
    try:
        return get_ipython() is not None
    except NameError:
        return len(sys.argv) == 1


def convert_to_builtin_python(input_data: Any) -> Any:
    """Convert input data to built-in Python types."""
    if isinstance(input_data, dict):
        return {k: convert_to_builtin_python(v) for k, v in input_data.items()}
    elif isinstance(input_data, list):
        return [convert_to_builtin_python(v) for v in input_data]
    elif isinstance(input_data, (int, float, str, bool, type(None))):
        return input_data
    elif isinstance(input_data, BaseModel):
        data = json.loads(input_data.model_dump_json())
        return convert_to_builtin_python(data)
    else:
        raise ValueError(f"Unsupported type {type(input_data)}")

speedy/common/utils_print.py
ADDED
@@ -0,0 +1,138 @@
# utils/utils_print.py

import copy
import json
import pprint
import textwrap
from typing import Any, Dict, List, Optional

from IPython import get_ipython
from IPython.display import HTML, display
from tabulate import tabulate


def display_pretty_table_html(data: Dict) -> None:
    """
    Display a pretty HTML table in Jupyter notebooks.
    """
    table = "<table>"
    for key, value in data.items():
        table += f"<tr><td>{key}</td><td>{value}</td></tr>"
    table += "</table>"
    display(HTML(table))


def fprint(
    input_data: Any,
    key_ignore: Optional[List[str]] = None,
    key_keep: Optional[List[str]] = None,
    max_width: int = 100,
    indent: int = 2,
    depth: Optional[int] = None,
    table_format: str = "grid",
    str_wrap_width: int = 80,
) -> None:
    """
    Pretty print structured data.
    """

    def is_interactive_env():
        """Check if the environment is interactive (e.g., Jupyter notebook)."""
        try:
            shell = get_ipython().__class__.__name__
            return shell == "ZMQInteractiveShell"
        except NameError:
            return False

    def remove_keys(d: Dict, keys: List[str]) -> Dict:
        """Remove specified keys from a dictionary."""
        for key in keys:
            parts = key.split(".")
            sub_dict = d
            for part in parts[:-1]:
                sub_dict = sub_dict.get(part, {})
            sub_dict.pop(parts[-1], None)
        return d

    def keep_keys(d: Dict, keys: List[str]) -> Dict:
        """Keep only specified keys in a dictionary."""
        result = {}
        for key in keys:
            parts = key.split(".")
            sub_source = d
            sub_result = result
            for part in parts[:-1]:
                if part not in sub_source:
                    break
                sub_result = sub_result.setdefault(part, {})
                sub_source = sub_source[part]
            else:
                sub_result[parts[-1]] = copy.deepcopy(sub_source.get(parts[-1]))
        return result

    if hasattr(input_data, "to_dict"):
        input_data = input_data.to_dict()

    processed_data = copy.deepcopy(input_data)

    if isinstance(processed_data, dict):
        if key_keep is not None:
            processed_data = keep_keys(processed_data, key_keep)
        elif key_ignore is not None:
            processed_data = remove_keys(processed_data, key_ignore)

        if is_interactive_env():
            display_pretty_table_html(processed_data)
            return

    if isinstance(processed_data, dict):
        table = [[k, v] for k, v in processed_data.items()]
        print(
            tabulate(
                table,
                headers=["Key", "Value"],
                tablefmt=table_format,
                maxcolwidths=[None, max_width],
            )
        )
    elif isinstance(processed_data, str):
        wrapped_text = textwrap.fill(processed_data, width=str_wrap_width)
        print(wrapped_text)
    else:
        printer = pprint.PrettyPrinter(width=max_width, indent=indent, depth=depth)
        printer.pprint(processed_data)


def print_table(data: Any) -> None:
    """
    Print data as a table.
    """

    def __get_table(data: Any) -> str:
        if isinstance(data, str):
            try:
                data = json.loads(data)
            except json.JSONDecodeError as exc:
                raise ValueError("String input could not be decoded as JSON") from exc

        if isinstance(data, list):
            if all(isinstance(item, dict) for item in data):
                headers = list(data[0].keys())
                rows = [list(item.values()) for item in data]
                return tabulate(rows, headers=headers)
            else:
                raise ValueError("List must contain dictionaries")

        if isinstance(data, dict):
            headers = ["Key", "Value"]
            rows = list(data.items())
            return tabulate(rows, headers=headers)

        raise TypeError(
            "Input data must be a list of dictionaries, a dictionary, or a JSON string"
        )

    table = __get_table(data)
    print(table)

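A brief sketch of the two printers on placeholder data; in a notebook fprint renders an HTML table, otherwise it falls back to a tabulate grid:

from speedy.common.utils_print import fprint, print_table

record = {
    "model": "speedy-demo",
    "metrics": {"accuracy": 0.91, "loss": 0.27},
    "notes": "placeholder values for illustration",
}

fprint(record, key_ignore=["metrics.loss"], max_width=60)   # dotted keys address nested fields

rows = [
    {"name": "alice", "score": 10},
    {"name": "bob", "score": 12},
]
print_table(rows)            # list of dicts: keys become the table headers
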
speedy/multi_worker.py
ADDED
@@ -0,0 +1,121 @@
import inspect
import os
from concurrent.futures import ThreadPoolExecutor
from multiprocessing import Pool
from typing import Any, Callable, List
import asyncio

from loguru import logger
from tqdm import tqdm


def multi_thread(
    func: Callable,
    inputs: List[Any],
    workers: int = 4,
    verbose: bool = True,
    desc: str | None = None,
) -> List[Any]:
    if desc is None:
        fn_name = func.__name__
        try:
            source_file = inspect.getsourcefile(func) or "<string>"
            source_line = inspect.getsourcelines(func)[1]
            file_line = f"{source_file}:{source_line}"
        except (TypeError, OSError):
            file_line = "Unknown location"
        desc = f"{fn_name} at {file_line}"

    with ThreadPoolExecutor(max_workers=workers) as executor:
        # Use executor.map to apply func to inputs in order
        map_func = executor.map(func, inputs)
        if verbose:
            results = list(tqdm(map_func, total=len(inputs), desc=desc))
        else:
            results = list(map_func)
    return results


def _init_pool_processes(func):
    global _func
    _func = func


def _pool_process_executor(args):
    # Unpack arguments if necessary
    if isinstance(args, tuple):
        return _func(*args)
    else:
        return _func(args)


def multi_process(
    func: Callable,
    inputs: List[Any],
    workers: int = 16,
    verbose: bool = True,
    desc: str = "",
) -> List[Any]:
    if not desc:
        fn_name = func.__name__
        try:
            source_file = inspect.getsourcefile(func) or "<string>"
            source_line = inspect.getsourcelines(func)[1]
            file_line = f"{source_file}:{source_line}"
        except (TypeError, OSError):
            file_line = "Unknown location"
        desc = f"{fn_name} at {file_line}"

    if os.environ.get("DEBUG", "0") == "1":
        logger.info("DEBUGGING set num workers to 1")
        workers = 1

    logger.info("Multi-processing {} | Num samples: {}", desc, len(inputs))

    results = []
    with Pool(
        processes=workers, initializer=_init_pool_processes, initargs=(func,)
    ) as pool:
        try:
            if verbose:
                for result in tqdm(
                    pool.imap(_pool_process_executor, inputs),
                    total=len(inputs),
                    desc=desc,
                ):
                    results.append(result)
            else:
                results = pool.map(_pool_process_executor, inputs)
        except Exception as e:
            logger.error(f"[multiprocess] Error {e}")

    return results


async def async_multi_thread(f, inputs, desc="", user_tqdm=True):
    """
    Usage:
        inputs = list(range(1, 11))

        def function(i):
            time.sleep(1)
            return 1 / i

        results = await async_multi_thread(function, inputs)
    """

    def ensure_output_idx(idx_i):
        idx, i = idx_i
        return idx, f(i)

    tasks = [asyncio.to_thread(ensure_output_idx, i) for i in enumerate(inputs)]
    if not desc:
        desc = f"{f.__name__}"

    pbar = tqdm(total=len(inputs), desc=desc, disable=not user_tqdm)
    results = [None] * len(inputs)
    for task in asyncio.as_completed(tasks):
        idx, result = await task
        results[idx] = result
        pbar.update(1)
    return results


__all__ = ["multi_thread", "multi_process", "async_multi_thread"]

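A usage sketch for the three entry points; the worker function is illustrative, and on platforms that spawn worker processes the multi_process call belongs under an `if __name__ == "__main__":` guard:

import asyncio
import time
from speedy.multi_worker import multi_thread, multi_process, async_multi_thread

def slow_double(x):                     # top-level function so multi_process can pickle it
    time.sleep(0.1)
    return 2 * x

if __name__ == "__main__":              # guard for spawn-based multiprocessing
    items = list(range(20))
    threaded = multi_thread(slow_double, items, workers=8)      # ThreadPoolExecutor, ordered results
    processed = multi_process(slow_double, items, workers=4)    # multiprocessing.Pool, ordered results
    ordered = asyncio.run(async_multi_thread(slow_double, items))  # asyncio.to_thread, order preserved
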
speedy_utils-1.0.2.dist-info/METADATA
ADDED
@@ -0,0 +1,22 @@
Metadata-Version: 2.1
Name: speedy-utils
Version: 1.0.2
Summary: Fast and easy-to-use package for data science
Home-page: https://github.com/anhvth/speedy
Author: AnhVTH
Author-email: anhvth.226@gmail.com
Requires-Dist: numpy
Requires-Dist: requests
Requires-Dist: xxhash
Requires-Dist: loguru
Requires-Dist: fastcore
Requires-Dist: debugpy
Requires-Dist: ipywidgets
Requires-Dist: jupyterlab
Requires-Dist: ipdb
Requires-Dist: scikit-learn
Requires-Dist: matplotlib
Requires-Dist: pandas
Requires-Dist: tabulate
Requires-Dist: pydantic

speedy_utils-1.0.2.dist-info/RECORD
ADDED
@@ -0,0 +1,12 @@
speedy/__init__.py,sha256=BPGiT3ehXK3poEFR-j6dNejnB9GBvkX6IR7ANdD-Qy4,1127
speedy/multi_worker.py,sha256=kmk_Km6LkOUVntxmPKSYubMXFJBPoLoZW9NVi4UA9kc,3394
speedy/common/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
speedy/common/clock.py,sha256=q92SvdJLZyCtuq8C6m7GjS6_foI8yyspsKzeR_M1owo,2033
speedy/common/utils_cache.py,sha256=_qxGyISJBM-r0NDSUCYSEWnEsqprORdw3yXmIkns1D4,5656
speedy/common/utils_io.py,sha256=X2zFgVzfYM094auApGjYOO43NgnG5P_g7rOObzPYTrQ,3472
speedy/common/utils_misc.py,sha256=ooMb0xEjC-HrQSQQCTRlEqcrGTJptmA43azSi6BhiD4,1479
speedy/common/utils_print.py,sha256=jh_ihzWTEiPpFRiLoj6It8Zo_3tPxYi8MU1J3Nw3vwk,4273
speedy_utils-1.0.2.dist-info/METADATA,sha256=zCKCFD7z_s67PE93IKWCDk8-fCbgv9pXZSF6uquQyZI,538
speedy_utils-1.0.2.dist-info/WHEEL,sha256=GV9aMThwP_4oNCtvEC2ec3qUYutgWeAzklro_0m4WJQ,91
speedy_utils-1.0.2.dist-info/top_level.txt,sha256=eJxFW_gum7StgovqwA4v-9UndgnnWr4kUqcozY-aBmI,7
speedy_utils-1.0.2.dist-info/RECORD,,

speedy_utils-1.0.2.dist-info/top_level.txt
ADDED
@@ -0,0 +1 @@
speedy