speedy-utils 1.0.3__py3-none-any.whl → 1.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_utils/__init__.py +29 -0
- llm_utils/chat_format.py +427 -0
- llm_utils/group_messages.py +120 -0
- llm_utils/lm/__init__.py +8 -0
- llm_utils/lm/base_lm.py +304 -0
- llm_utils/lm/utils.py +130 -0
- llm_utils/scripts/vllm_load_balancer.py +353 -0
- llm_utils/scripts/vllm_serve.py +416 -0
- speedy_utils/__init__.py +85 -0
- speedy_utils/all.py +159 -0
- {speedy → speedy_utils}/common/__init__.py +0 -0
- speedy_utils/common/clock.py +215 -0
- speedy_utils/common/function_decorator.py +66 -0
- speedy_utils/common/logger.py +207 -0
- speedy_utils/common/report_manager.py +112 -0
- speedy_utils/common/utils_cache.py +264 -0
- {speedy → speedy_utils}/common/utils_io.py +66 -19
- {speedy → speedy_utils}/common/utils_misc.py +25 -11
- speedy_utils/common/utils_print.py +216 -0
- speedy_utils/multi_worker/__init__.py +0 -0
- speedy_utils/multi_worker/process.py +198 -0
- speedy_utils/multi_worker/thread.py +327 -0
- speedy_utils/scripts/mpython.py +108 -0
- speedy_utils-1.0.5.dist-info/METADATA +279 -0
- speedy_utils-1.0.5.dist-info/RECORD +27 -0
- {speedy_utils-1.0.3.dist-info → speedy_utils-1.0.5.dist-info}/WHEEL +1 -2
- speedy_utils-1.0.5.dist-info/entry_points.txt +3 -0
- speedy/__init__.py +0 -53
- speedy/common/clock.py +0 -68
- speedy/common/utils_cache.py +0 -170
- speedy/common/utils_print.py +0 -138
- speedy/multi_worker.py +0 -121
- speedy_utils-1.0.3.dist-info/METADATA +0 -22
- speedy_utils-1.0.3.dist-info/RECORD +0 -12
- speedy_utils-1.0.3.dist-info/top_level.txt +0 -1
|
@@ -0,0 +1,264 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
import inspect
|
|
3
|
+
import json
|
|
4
|
+
import os
|
|
5
|
+
import os.path as osp
|
|
6
|
+
import pickle
|
|
7
|
+
import uuid
|
|
8
|
+
from typing import Any, List, Literal
|
|
9
|
+
|
|
10
|
+
import cachetools
|
|
11
|
+
import pandas as pd
|
|
12
|
+
import xxhash
|
|
13
|
+
from loguru import logger
|
|
14
|
+
from pydantic import BaseModel
|
|
15
|
+
|
|
16
|
+
from .utils_io import dump_json_or_pickle, load_json_or_pickle
|
|
17
|
+
from .utils_misc import mkdir_or_exist
|
|
18
|
+
|
|
19
|
+
# Default root directory for on-disk cache files, located under the user's home.
SPEED_CACHE_DIR = osp.join(osp.expanduser("~"), ".cache/speedy_cache")
# Module-wide in-memory LRU cache used by the memory-backed memoizers below.
LRU_MEM_CACHE = cachetools.LRUCache(maxsize=128_000)
from threading import Lock

# NOTE(review): not referenced in the visible part of this module; presumably
# kept for external importers -- confirm before removing.
thread_locker = Lock()

# Add two locks for thread-safe cache access
disk_lock = Lock()  # held while cache files are checked, read or written
mem_lock = Lock()  # held while the in-memory cache is read or written
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
def compute_func_id(func, args, kwargs, ignore_self, keys):
    """Build the cache identity of one call to ``func``.

    Args:
        func: The function being memoized.
        args: Positional arguments of the call.
        kwargs: Keyword arguments of the call.
        ignore_self: When true and the first declared parameter is named
            ``self``, the first positional argument is excluded from the hash.
        keys: Optional names of the arguments that alone determine the cache
            key. When falsy, every argument takes part in the hash.

    Returns:
        A ``(func_source, sub_dir, key_id)`` tuple. With ``keys`` the
        sub-directory is ``"<func name>_<source hash>"`` and the file name is
        built from the key names and the hash of their values; otherwise the
        sub-directory is ``"funcs"`` and the file name hashes the source
        together with all arguments.

    Raises:
        ValueError: If ``keys`` is given but none of them is among the
            arguments that were actually passed.
    """
    func_source = get_source(func)
    # Inspect the signature once; both branches below need the names.
    arg_names = inspect.getfullargspec(func).args

    if keys:
        # zip() stops at the shorter sequence, so surplus positionals that were
        # collected by *args can no longer index past the named parameters.
        used_args = dict(zip(arg_names, args))
        used_args.update(kwargs)
        values = [used_args[k] for k in keys if k in used_args]
        if not values:
            raise ValueError(f"Keys {keys} not found in function arguments")
        param_hash = identify(values)
        dir_path = f"{func.__name__}_{identify(func_source)}"
        key_id = f"{'_'.join(keys)}_{param_hash}.pkl"
        return func_source, dir_path, key_id

    if arg_names and arg_names[0] == "self" and ignore_self:
        # Leave the instance out so different instances share cache entries.
        fid = (func_source, args[1:], kwargs)
    else:
        fid = (func_source, args, kwargs)
    return func_source, "funcs", f"{identify(fid)}.pkl"
|
|
53
|
+
|
|
54
|
+
|
|
55
|
+
def fast_serialize(x: Any) -> bytes:
    """Turn ``x`` into bytes, preferring key-sorted JSON and falling back to pickle.

    JSON (UTF-8 encoded, keys sorted) is used whenever ``json.dumps`` accepts
    the value; a ``TypeError`` or ``ValueError`` switches to pickle with the
    highest available protocol.
    """
    try:
        as_text = json.dumps(x, sort_keys=True)
    except (TypeError, ValueError):
        return pickle.dumps(x, protocol=pickle.HIGHEST_PROTOCOL)
    return as_text.encode("utf-8")
|
|
60
|
+
|
|
61
|
+
|
|
62
|
+
def identify(obj: Any, depth=0, max_depth=2) -> str:
    """Return a hexadecimal xxh64 digest that identifies ``obj`` by content.

    Containers are reduced recursively: lists/tuples hash the newline-joined
    digests of their items, pandas objects hash ``str(obj.to_dict())``,
    objects exposing ``__code__`` hash their whitespace-stripped source,
    pydantic models hash their ``model_dump()``, dicts hash their sorted keys
    together with the digests of the matching values, and ``None`` hashes the
    string ``"None"``. Everything else is serialized with ``fast_serialize``;
    a warning is logged when its exact type is not int, float, str or bool.

    ``depth`` and ``max_depth`` are threaded through the recursion but are not
    compared anywhere in this function.
    """
    deeper = depth + 1

    if isinstance(obj, (list, tuple)):
        item_digests = [identify(item, deeper, max_depth) for item in obj]
        return identify("\n".join(item_digests), deeper, max_depth)

    # is pandas row or dict
    if isinstance(obj, (pd.DataFrame, pd.Series)):
        return identify(str(obj.to_dict()), deeper, max_depth)

    if hasattr(obj, "__code__"):
        return identify(get_source(obj), deeper, max_depth)

    if isinstance(obj, BaseModel):
        return identify(obj.model_dump(), deeper, max_depth)

    if isinstance(obj, dict):
        ordered_keys = sorted(obj.keys())
        value_digests = [identify(obj[key], deeper, max_depth) for key in ordered_keys]
        return identify([ordered_keys, value_digests], deeper, max_depth)

    if obj is None:
        return identify("None", deeper, max_depth)

    if type(obj) not in [int, float, str, bool]:
        logger.warning(f"Unknown type: {type(obj)}")
    return xxhash.xxh64_hexdigest(fast_serialize(obj), seed=0)
|
|
87
|
+
|
|
88
|
+
|
|
89
|
+
def identify_uuid(x: Any) -> str:
    """Return a UUID string built from the 128-bit xxhash digest of ``x``.

    ``x`` is serialized with ``fast_serialize`` and hashed with seed 0; the
    16 digest bytes become the UUID's bytes.
    """
    digest = xxhash.xxh128(fast_serialize(x), seed=0).digest()
    return str(uuid.UUID(bytes=digest))
|
|
93
|
+
|
|
94
|
+
|
|
95
|
+
def get_source(func):
    """Return the source of ``func`` with spaces, tabs, CR and LF removed.

    Stripping these characters makes the result insensitive to pure
    re-formatting of the function.

    Returns:
        The stripped source text, or ``None`` when ``inspect.getsource``
        cannot provide it (it raises ``TypeError`` for built-in objects and
        ``OSError`` when the source code cannot be retrieved). Callers in this
        module test for ``None`` to bypass caching.
    """
    try:
        code = inspect.getsource(func)
    except (OSError, TypeError):
        return None
    # One pass that deletes exactly the four characters the cache keys ignore.
    return code.translate(str.maketrans("", "", " \n\t\r"))
|
|
100
|
+
|
|
101
|
+
|
|
102
|
+
def _disk_memoize(func, keys, cache_dir, ignore_self, verbose):
    """Wrap ``func`` so that its results are cached as files under ``cache_dir``.

    Cache bookkeeping is best-effort: any failure while computing the cache
    path, reading or writing the cache file is logged as a warning and the
    call proceeds without the cache. Exceptions raised by ``func`` itself
    propagate to the caller, and ``func`` is executed at most once per call.

    Args:
        func: Function to memoize.
        keys: Optional argument names that determine the cache key.
        cache_dir: Root directory of the cache files.
        ignore_self: Forwarded to ``compute_func_id``.
        verbose: Accepted for interface compatibility; not used here.

    Returns:
        The wrapping function.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        cache_path = None
        try:
            func_source, sub_dir, key_id = compute_func_id(
                func, args, kwargs, ignore_self, keys
            )
            if func_source is not None:
                if sub_dir == "funcs":
                    cache_path = osp.join(cache_dir, sub_dir, func.__name__, key_id)
                else:
                    cache_path = osp.join(cache_dir, sub_dir, key_id)
                mkdir_or_exist(osp.dirname(cache_path))

                with disk_lock:
                    if osp.exists(cache_path):
                        try:
                            return load_json_or_pickle(cache_path)
                        except Exception as e:
                            # Unreadable entry: drop it and fall through to recompute.
                            if osp.exists(cache_path):
                                os.remove(cache_path)
                            logger.opt(depth=1).warning(
                                f"Error loading cache: {str(e)[:100]}, continue to recompute"
                            )
        except Exception as e:
            logger.opt(depth=1).warning(
                f"Failed to cache {func.__name__}: {e}, continue to recompute without cache"
            )
            cache_path = None  # do not try to write after a bookkeeping failure

        # Deliberately outside the try blocks: an error raised by func must
        # reach the caller instead of triggering a second execution.
        result = func(*args, **kwargs)

        if cache_path is not None:
            try:
                with disk_lock:
                    if not osp.exists(cache_path):
                        dump_json_or_pickle(result, cache_path)
            except Exception as e:
                logger.opt(depth=1).warning(
                    f"Failed to cache {func.__name__}: {e}, returning result without cache"
                )
        return result

    return wrapper
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def _memory_memoize(func, size, keys, ignore_self):
    """Wrap ``func`` so that its results are kept in an in-memory LRU cache.

    The module-level ``LRU_MEM_CACHE`` is replaced by a fresh cache when its
    ``maxsize`` differs from ``size``. The cache used by this wrapper is bound
    here, at decoration time, so a later decoration with another ``size``
    cannot change which cache (and which capacity) this function ends up with.

    Args:
        func: Function to memoize.
        size: Requested ``maxsize`` of the LRU cache.
        keys: Optional argument names that determine the cache key.
        ignore_self: Forwarded to ``compute_func_id``.

    Returns:
        The wrapping function.
    """
    global LRU_MEM_CACHE
    if LRU_MEM_CACHE.maxsize != size:
        LRU_MEM_CACHE = cachetools.LRUCache(maxsize=size)

    # Reuse a cache already attached to func, otherwise attach the current one.
    cache = getattr(func, "_mem_cache", None)
    if cache is None:
        cache = LRU_MEM_CACHE
        try:
            func._mem_cache = cache
        except AttributeError:
            # Some callables (e.g. bound methods) reject new attributes; the
            # closure variable is enough for the wrapper to work.
            pass

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        func_source, sub_dir, key_id = compute_func_id(
            func, args, kwargs, ignore_self, keys
        )
        if func_source is None:
            return func(*args, **kwargs)
        name = identify((func_source, sub_dir, key_id))

        with mem_lock:
            if name in cache:
                return cache[name]

        # Computed outside the lock so slow functions do not block other threads.
        result = func(*args, **kwargs)

        with mem_lock:
            if name not in cache:
                cache[name] = result
        return result

    return wrapper
|
|
177
|
+
|
|
178
|
+
|
|
179
|
+
def both_memoize(func, keys, cache_dir, ignore_self):
    """Wrap ``func`` with a memory cache backed by an on-disk cache.

    Lookup order is memory, then disk, then a real call. A disk hit is copied
    into memory; a computed result is written to disk (unless the file already
    exists) and stored in memory. An unreadable cache file is removed and the
    value recomputed, matching the behaviour of ``_disk_memoize``.

    Args:
        func: Function to memoize.
        keys: Optional argument names that determine the cache key.
        cache_dir: Root directory of the cache files.
        ignore_self: Forwarded to ``compute_func_id``.

    Returns:
        The wrapping function.
    """

    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        func_source, sub_dir, key_id = compute_func_id(
            func, args, kwargs, ignore_self, keys
        )
        if func_source is None:
            return func(*args, **kwargs)

        mem_key = identify((func_source, sub_dir, key_id))
        if not hasattr(func, "_mem_cache"):
            func._mem_cache = LRU_MEM_CACHE

        with mem_lock:
            if mem_key in func._mem_cache:
                return func._mem_cache[mem_key]

        if sub_dir == "funcs":
            cache_path = osp.join(cache_dir, sub_dir, func.__name__, key_id)
        else:
            cache_path = osp.join(cache_dir, sub_dir, key_id)
        mkdir_or_exist(osp.dirname(cache_path))

        disk_hit = False
        result = None
        with disk_lock:
            if osp.exists(cache_path):
                try:
                    result = load_json_or_pickle(cache_path)
                    disk_hit = True
                except Exception as e:
                    # Corrupted entry: remove it and recompute below.
                    if osp.exists(cache_path):
                        os.remove(cache_path)
                    logger.opt(depth=1).warning(
                        f"Error loading cache: {str(e)[:100]}, continue to recompute"
                    )

        if not disk_hit:
            result = func(*args, **kwargs)
            with disk_lock:
                if not osp.exists(cache_path):
                    dump_json_or_pickle(result, cache_path)

        # The memory cache is updated after the disk lock is released, so the
        # two locks are never held at the same time.
        with mem_lock:
            func._mem_cache[mem_key] = result
        return result

    return wrapper
|
|
221
|
+
|
|
222
|
+
|
|
223
|
+
def memoize(
    _func=None,
    *,
    keys=None,
    cache_dir=SPEED_CACHE_DIR,
    cache_type: Literal["memory", "disk", "both"] = "disk",
    size=10240,
    ignore_self=True,
    verbose=False,
):
    """Memoize a function in memory, on disk, or both.

    Usable bare (``@memoize``) or with options (``@memoize(cache_type=...)``).

    Args:
        _func: The function, when the decorator is applied without parentheses.
        keys: Optional argument names that alone determine the cache key.
        cache_dir: Root directory for disk caches; a leading ``~`` is expanded.
        cache_type: ``"memory"``, ``"disk"``, or ``"both"``. Any other value is
            handled like ``"both"``.
        size: Capacity passed to the memory cache (``"memory"`` only).
        ignore_self: Exclude a leading ``self`` argument from the cache key.
        verbose: Forwarded to the disk memoizer.

    Returns:
        The decorated function, or a decorator when ``_func`` is ``None``.
    """
    # expanduser() only rewrites a leading "~" and returns other paths as-is,
    # so it is safe to call unconditionally (also accepts path-like objects).
    cache_dir = osp.expanduser(cache_dir)

    def decorator(func):
        if cache_type == "memory":
            return _memory_memoize(func, size, keys, ignore_self)
        if cache_type == "disk":
            return _disk_memoize(func, keys, cache_dir, ignore_self, verbose)
        # The fourth parameter of both_memoize is ignore_self; passing
        # `verbose` there silently disabled ignore_self for "both" caches.
        return both_memoize(func, keys, cache_dir, ignore_self)

    if _func is None:
        return decorator
    return decorator(_func)
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
__all__ = ["memoize", "identify", "identify_uuid"]
|
|
@@ -4,13 +4,17 @@ import json
|
|
|
4
4
|
import os
|
|
5
5
|
import os.path as osp
|
|
6
6
|
import pickle
|
|
7
|
+
import time
|
|
7
8
|
from glob import glob
|
|
8
|
-
from
|
|
9
|
+
from pathlib import Path
|
|
10
|
+
from typing import Any
|
|
11
|
+
|
|
12
|
+
from json_repair import loads as jloads
|
|
9
13
|
|
|
10
14
|
from .utils_misc import mkdir_or_exist
|
|
11
15
|
|
|
12
16
|
|
|
13
|
-
def dump_jsonl(list_dictionaries:
|
|
17
|
+
def dump_jsonl(list_dictionaries: list[dict], file_name: str = "output.jsonl") -> None:
|
|
14
18
|
"""
|
|
15
19
|
Dumps a list of dictionaries to a file in JSON Lines format.
|
|
16
20
|
"""
|
|
@@ -25,10 +29,20 @@ def dump_json_or_pickle(
|
|
|
25
29
|
"""
|
|
26
30
|
Dump an object to a file, supporting both JSON and pickle formats.
|
|
27
31
|
"""
|
|
32
|
+
if isinstance(fname, Path):
|
|
33
|
+
fname = str(fname)
|
|
28
34
|
mkdir_or_exist(osp.abspath(os.path.dirname(osp.abspath(fname))))
|
|
29
35
|
if fname.endswith(".json"):
|
|
30
36
|
with open(fname, "w", encoding="utf-8") as f:
|
|
31
|
-
|
|
37
|
+
try:
|
|
38
|
+
json.dump(obj, f, ensure_ascii=ensure_ascii, indent=indent)
|
|
39
|
+
# TypeError: Object of type datetime is not JSON serializable
|
|
40
|
+
except TypeError:
|
|
41
|
+
print(
|
|
42
|
+
"Error: Object of type datetime is not JSON serializable",
|
|
43
|
+
str(obj)[:1000],
|
|
44
|
+
)
|
|
45
|
+
raise
|
|
32
46
|
elif fname.endswith(".jsonl"):
|
|
33
47
|
dump_jsonl(obj, fname)
|
|
34
48
|
elif fname.endswith(".pkl"):
|
|
@@ -38,29 +52,45 @@ def dump_json_or_pickle(
|
|
|
38
52
|
raise NotImplementedError(f"File type {fname} not supported")
|
|
39
53
|
|
|
40
54
|
|
|
41
|
-
def load_json_or_pickle(fname: str) -> Any:
|
|
55
|
+
def load_json_or_pickle(fname: str, counter=0) -> Any:
    """
    Load an object from a file, supporting both JSON and pickle formats.

    Args:
        fname: Path of the file (``str`` or path-like). Names ending in
            ``.json`` or ``.jsonl`` are parsed as JSON; anything else is
            unpickled.
        counter: Number of truncated-pickle retries already consumed.

    Returns:
        The decoded object. A ``.jsonl`` file that is not one single JSON
        document is returned as a list with one parsed value per non-blank line.

    Raises:
        json.JSONDecodeError: If JSON content cannot be parsed.
        EOFError: If a pickle file is still truncated after the retries; the
            file is removed before the error is re-raised.
        ValueError: For any other error while unpickling.
    """
    # Accept pathlib.Path and other path-likes, as dump_json_or_pickle does.
    fname = os.fspath(fname)

    if fname.endswith(".json") or fname.endswith(".jsonl"):
        with open(fname, encoding="utf-8") as f:
            text = f.read()
        try:
            return json.loads(text)
        except json.JSONDecodeError:
            if not fname.endswith(".jsonl"):
                raise
            # JSON Lines: one document per line.
            return [json.loads(line) for line in text.split("\n") if line.strip()]

    # NOTE: pickle can execute arbitrary code; only load trusted files.
    attempt = counter
    while True:
        try:
            with open(fname, "rb") as f:
                return pickle.load(f)
        # EOFError: Ran out of input (the file may still be being written)
        except EOFError:
            if attempt > 5:
                print("Error: Ran out of input", fname)
                os.remove(fname)
                raise
            attempt += 1
            time.sleep(1)  # give a concurrent writer time to finish
        except Exception as e:
            raise ValueError(f"Error {e} while loading {fname}") from e
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
def load_jsonl(path):
    """Load a JSON Lines file into a list with one parsed value per line.

    The file is read as UTF-8 and closed when done. Records are separated by
    newlines only, so characters such as U+2028 inside a JSON string do not
    split a record. Blank lines are skipped.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    with open(path, encoding="utf-8") as f:
        return [json.loads(line) for line in f if line.strip()]
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
def load_by_ext(fname: str | list[str], do_memoize: bool = False) -> Any:
|
|
56
84
|
"""
|
|
57
85
|
Load data based on file extension.
|
|
58
86
|
"""
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
87
|
+
if isinstance(fname, Path):
|
|
88
|
+
fname = str(fname)
|
|
89
|
+
from speedy_utils import multi_process
|
|
62
90
|
|
|
63
|
-
from
|
|
91
|
+
from .utils_cache import ( # Adjust import based on your actual multi_worker module
|
|
92
|
+
memoize,
|
|
93
|
+
)
|
|
64
94
|
|
|
65
95
|
try:
|
|
66
96
|
if isinstance(fname, str) and "*" in fname:
|
|
@@ -76,12 +106,14 @@ def load_by_ext(
|
|
|
76
106
|
|
|
77
107
|
return pd.read_csv(path, engine="pyarrow", **pd_kwargs)
|
|
78
108
|
|
|
79
|
-
def load_txt(path: str) ->
|
|
80
|
-
with open(path,
|
|
109
|
+
def load_txt(path: str) -> list[str]:
|
|
110
|
+
with open(path, encoding="utf-8") as f:
|
|
81
111
|
return f.read().splitlines()
|
|
82
112
|
|
|
83
113
|
def load_default(path: str) -> Any:
|
|
84
|
-
if path.endswith(".jsonl")
|
|
114
|
+
if path.endswith(".jsonl"):
|
|
115
|
+
return load_jsonl(path)
|
|
116
|
+
elif path.endswith(".json"):
|
|
85
117
|
try:
|
|
86
118
|
return load_json_or_pickle(path)
|
|
87
119
|
except json.JSONDecodeError as exc:
|
|
@@ -109,3 +141,18 @@ def load_by_ext(
|
|
|
109
141
|
return load_fn(fname)
|
|
110
142
|
except Exception as e:
|
|
111
143
|
raise ValueError(f"Error {e} while loading {fname}") from e
|
|
144
|
+
|
|
145
|
+
|
|
146
|
+
def jdumps(obj, ensure_ascii=False, indent=2, **kwargs):
    """Serialize ``obj`` to a JSON string.

    By default non-ASCII characters are kept as-is and the output is indented
    by two spaces; extra keyword arguments go straight to ``json.dumps``.
    """
    options = dict(kwargs, ensure_ascii=ensure_ascii, indent=indent)
    return json.dumps(obj, **options)
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
__all__ = [
|
|
151
|
+
"dump_json_or_pickle",
|
|
152
|
+
"dump_jsonl",
|
|
153
|
+
"load_by_ext",
|
|
154
|
+
"load_json_or_pickle",
|
|
155
|
+
"load_jsonl",
|
|
156
|
+
"jdumps",
|
|
157
|
+
"jloads",
|
|
158
|
+
]
|
|
@@ -3,32 +3,37 @@
|
|
|
3
3
|
import inspect
|
|
4
4
|
import os
|
|
5
5
|
import sys
|
|
6
|
-
from
|
|
7
|
-
from
|
|
8
|
-
|
|
6
|
+
from collections.abc import Callable
|
|
7
|
+
from typing import Any, List
|
|
8
|
+
|
|
9
|
+
from pydantic import BaseModel
|
|
9
10
|
|
|
10
11
|
|
|
11
12
|
def mkdir_or_exist(dir_name: str) -> None:
    """Make sure ``dir_name`` exists, creating it and any missing parents."""
    os.makedirs(name=dir_name, exist_ok=True)
|
|
14
15
|
|
|
15
|
-
|
|
16
|
+
|
|
17
|
+
def flatten_list(list_of_lists: list[list[Any]]) -> list[Any]:
    """Concatenate the inner lists, in order, into one new list."""
    flat: list[Any] = []
    for inner in list_of_lists:
        flat.extend(inner)
    return flat
|
|
18
20
|
|
|
19
|
-
|
|
21
|
+
|
|
22
|
+
def get_arg_names(func: Callable) -> list[str]:
    """Return the ``args`` names reported by ``inspect.getfullargspec`` for ``func``."""
    spec = inspect.getfullargspec(func)
    return spec.args
|
|
22
25
|
|
|
23
26
|
|
|
24
|
-
|
|
25
|
-
def is_interactive() -> bool:
|
|
26
|
-
"""Check if the environment is interactive (e.g., Jupyter notebook)."""
|
|
27
|
+
def is_notebook() -> bool:
    """Report whether the code runs in a Jupyter notebook or qtconsole.

    Returns:
        ``True`` when the active IPython shell class is named
        ``ZMQInteractiveShell``; ``False`` in every other case, including a
        plain Python interpreter.
    """
    # A `get_ipython` placed in this module's globals takes precedence.
    get_ipython = globals().get("get_ipython")
    if get_ipython is None:
        # Otherwise ask IPython itself, but only if it has already been
        # imported: without IPython loaded there cannot be an active shell.
        ipython_module = sys.modules.get("IPython")
        get_ipython = getattr(ipython_module, "get_ipython", None)
    if get_ipython is None:
        return False  # Probably standard Python interpreter

    shell = get_ipython().__class__.__name__
    return shell == "ZMQInteractiveShell"  # Jupyter notebook or qtconsole
|
|
32
37
|
|
|
33
38
|
|
|
34
39
|
def convert_to_builtin_python(input_data: Any) -> Any:
|
|
@@ -44,3 +49,12 @@ def convert_to_builtin_python(input_data: Any) -> Any:
|
|
|
44
49
|
return convert_to_builtin_python(data)
|
|
45
50
|
else:
|
|
46
51
|
raise ValueError(f"Unsupported type {type(input_data)}")
|
|
52
|
+
|
|
53
|
+
|
|
54
|
+
__all__ = [
|
|
55
|
+
"mkdir_or_exist",
|
|
56
|
+
"flatten_list",
|
|
57
|
+
"get_arg_names",
|
|
58
|
+
"is_notebook",
|
|
59
|
+
"convert_to_builtin_python",
|
|
60
|
+
]
|
|
@@ -0,0 +1,216 @@
|
|
|
1
|
+
# utils/utils_print.py
|
|
2
|
+
|
|
3
|
+
import copy
|
|
4
|
+
import inspect
|
|
5
|
+
import json
|
|
6
|
+
import pprint
|
|
7
|
+
import re
|
|
8
|
+
import sys
|
|
9
|
+
import textwrap
|
|
10
|
+
import time
|
|
11
|
+
from collections import OrderedDict
|
|
12
|
+
from typing import Annotated, Any, Dict, List, Literal, Optional
|
|
13
|
+
|
|
14
|
+
from IPython.display import HTML, display
|
|
15
|
+
from loguru import logger
|
|
16
|
+
from tabulate import tabulate
|
|
17
|
+
|
|
18
|
+
from .utils_misc import is_notebook
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
def display_pretty_table_html(data: dict) -> None:
    """
    Display a pretty HTML table in Jupyter notebooks.

    Each item of ``data`` becomes one two-cell row (key, value). Keys and
    values are converted with ``str`` and HTML-escaped, so text such as
    ``<class 'int'>`` is shown literally instead of being read as markup.
    """
    from html import escape

    rows = "".join(
        f"<tr><td>{escape(str(key))}</td><td>{escape(str(value))}</td></tr>"
        for key, value in data.items()
    )
    display(HTML(f"<table>{rows}</table>"))
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
# Flattening the dictionary using "." notation for keys
|
|
33
|
+
def flatten_dict(d, parent_key="", sep="."):
    """Flatten nested dictionaries into one level with ``sep``-joined keys.

    Args:
        d: Dictionary to flatten; nested ``dict`` values are expanded
            recursively (an empty nested dict contributes no entries).
        parent_key: Prefix for every produced key; ``""`` or ``None`` means
            no prefix.
        sep: Separator placed between the levels of a key.

    Returns:
        A new flat dictionary. Top-level keys are kept unchanged; keys below a
        parent are strings of the form ``"<parent><sep><key>"``, so non-string
        keys such as integers are supported.
    """
    has_parent = parent_key is not None and parent_key != ""
    flat = {}
    for key, value in d.items():
        # f-string formatting avoids the TypeError that `str + int` would raise.
        full_key = f"{parent_key}{sep}{key}" if has_parent else key
        if isinstance(value, dict):
            flat.update(flatten_dict(value, full_key, sep=sep))
        else:
            flat[full_key] = value
    return flat
|
|
42
|
+
|
|
43
|
+
|
|
44
|
+
def fprint(
|
|
45
|
+
input_data: Any,
|
|
46
|
+
key_ignore: list[str] | None = None,
|
|
47
|
+
key_keep: list[str] | None = None,
|
|
48
|
+
max_width: int = 100,
|
|
49
|
+
indent: int = 2,
|
|
50
|
+
depth: int | None = None,
|
|
51
|
+
table_format: str = "grid",
|
|
52
|
+
str_wrap_width: int = 80,
|
|
53
|
+
grep=None,
|
|
54
|
+
is_notebook=None,
|
|
55
|
+
f=print,
|
|
56
|
+
) -> None | str:
|
|
57
|
+
"""
|
|
58
|
+
Pretty print structured data.
|
|
59
|
+
"""
|
|
60
|
+
if isinstance(input_data, list):
|
|
61
|
+
for i, item in enumerate(input_data):
|
|
62
|
+
fprint(
|
|
63
|
+
item,
|
|
64
|
+
key_ignore,
|
|
65
|
+
key_keep,
|
|
66
|
+
max_width,
|
|
67
|
+
indent,
|
|
68
|
+
depth,
|
|
69
|
+
table_format,
|
|
70
|
+
str_wrap_width,
|
|
71
|
+
grep,
|
|
72
|
+
is_notebook,
|
|
73
|
+
f,
|
|
74
|
+
)
|
|
75
|
+
print("\n" + "-" * 100 + "\n")
|
|
76
|
+
|
|
77
|
+
from speedy_utils import is_notebook as is_interactive
|
|
78
|
+
|
|
79
|
+
# is_notebook = is_notebook or is_interactive()
|
|
80
|
+
if is_notebook is None:
|
|
81
|
+
is_notebook = is_interactive()
|
|
82
|
+
if isinstance(input_data, list):
|
|
83
|
+
if all(hasattr(item, "toDict") for item in input_data):
|
|
84
|
+
input_data = [item.toDict() for item in input_data]
|
|
85
|
+
elif hasattr(input_data, "toDict"):
|
|
86
|
+
input_data = input_data.toDict()
|
|
87
|
+
|
|
88
|
+
if isinstance(input_data, list):
|
|
89
|
+
if all(hasattr(item, "to_dict") for item in input_data):
|
|
90
|
+
input_data = [item.to_dict() for item in input_data]
|
|
91
|
+
elif hasattr(input_data, "to_dict"):
|
|
92
|
+
input_data = input_data.to_dict()
|
|
93
|
+
|
|
94
|
+
if isinstance(input_data, list):
|
|
95
|
+
if all(hasattr(item, "model_dump") for item in input_data):
|
|
96
|
+
input_data = [item.model_dump() for item in input_data]
|
|
97
|
+
elif hasattr(input_data, "model_dump"):
|
|
98
|
+
input_data = input_data.model_dump()
|
|
99
|
+
if not isinstance(input_data, (dict, str)):
|
|
100
|
+
raise ValueError("Input data must be a dictionary or string")
|
|
101
|
+
|
|
102
|
+
if isinstance(input_data, dict):
|
|
103
|
+
input_data = flatten_dict(input_data)
|
|
104
|
+
|
|
105
|
+
if grep is not None and isinstance(input_data, dict):
|
|
106
|
+
input_data = {k: v for k, v in input_data.items() if grep in str(k)}
|
|
107
|
+
|
|
108
|
+
def remove_keys(d: dict, keys: list[str]) -> dict:
|
|
109
|
+
"""Remove specified keys from a dictionary."""
|
|
110
|
+
for key in keys:
|
|
111
|
+
parts = key.split(".")
|
|
112
|
+
sub_dict = d
|
|
113
|
+
for part in parts[:-1]:
|
|
114
|
+
sub_dict = sub_dict.get(part, {})
|
|
115
|
+
sub_dict.pop(parts[-1], None)
|
|
116
|
+
return d
|
|
117
|
+
|
|
118
|
+
def keep_keys(d: dict, keys: list[str]) -> dict:
|
|
119
|
+
"""Keep only specified keys in a dictionary."""
|
|
120
|
+
result = {}
|
|
121
|
+
for key in keys:
|
|
122
|
+
parts = key.split(".")
|
|
123
|
+
sub_source = d
|
|
124
|
+
sub_result = result
|
|
125
|
+
for part in parts[:-1]:
|
|
126
|
+
if part not in sub_source:
|
|
127
|
+
break
|
|
128
|
+
sub_result = sub_result.setdefault(part, {})
|
|
129
|
+
sub_source = sub_source[part]
|
|
130
|
+
else:
|
|
131
|
+
sub_result[parts[-1]] = copy.deepcopy(sub_source.get(parts[-1]))
|
|
132
|
+
return result
|
|
133
|
+
|
|
134
|
+
if hasattr(input_data, "to_dict") and not isinstance(input_data, str):
|
|
135
|
+
input_data = input_data.to_dict()
|
|
136
|
+
|
|
137
|
+
processed_data = copy.deepcopy(input_data)
|
|
138
|
+
|
|
139
|
+
if isinstance(processed_data, dict) and is_notebook:
|
|
140
|
+
if key_keep is not None:
|
|
141
|
+
processed_data = keep_keys(processed_data, key_keep)
|
|
142
|
+
elif key_ignore is not None:
|
|
143
|
+
processed_data = remove_keys(processed_data, key_ignore)
|
|
144
|
+
|
|
145
|
+
if is_notebook:
|
|
146
|
+
display_pretty_table_html(processed_data)
|
|
147
|
+
return
|
|
148
|
+
|
|
149
|
+
if isinstance(processed_data, dict):
|
|
150
|
+
table = [[k, v] for k, v in processed_data.items()]
|
|
151
|
+
f(
|
|
152
|
+
tabulate(
|
|
153
|
+
table,
|
|
154
|
+
headers=["Key", "Value"],
|
|
155
|
+
tablefmt=table_format,
|
|
156
|
+
maxcolwidths=[None, max_width],
|
|
157
|
+
)
|
|
158
|
+
)
|
|
159
|
+
elif isinstance(processed_data, str):
|
|
160
|
+
wrapped_text = textwrap.fill(processed_data, width=str_wrap_width)
|
|
161
|
+
f(wrapped_text)
|
|
162
|
+
elif isinstance(processed_data, list):
|
|
163
|
+
f(tabulate(processed_data, tablefmt=table_format))
|
|
164
|
+
else:
|
|
165
|
+
printer = pprint.PrettyPrinter(width=max_width, indent=indent, depth=depth)
|
|
166
|
+
printer.pprint(processed_data)
|
|
167
|
+
|
|
168
|
+
|
|
169
|
+
def print_table(data: Any, use_html: bool = True) -> None:
    """
    Print data as a table. If use_html is True, display using IPython HTML.

    Args:
        data: A dictionary (rendered as Key/Value rows), a list of
            dictionaries (one row each), or a JSON string decoding to either.
        use_html: Render with ``tabulate``'s ``"html"`` format and show it via
            IPython; otherwise print a ``"grid"`` table.

    Raises:
        ValueError: If a string is not valid JSON or a list holds
            non-dictionary items.
        TypeError: If the data has any other type.
    """
    table_format = "html" if use_html else "grid"

    def __get_table(data: Any) -> str:
        if isinstance(data, str):
            try:
                data = json.loads(data)
            except json.JSONDecodeError as exc:
                raise ValueError("String input could not be decoded as JSON") from exc

        if isinstance(data, list):
            if not all(isinstance(item, dict) for item in data):
                raise ValueError("List must contain dictionaries")
            # Columns are the keys of all rows in first-seen order. This also
            # covers an empty list, where there is no first row to read from.
            headers = list(dict.fromkeys(key for item in data for key in item))
            # Look cells up by header so rows with another key order, or with
            # missing keys, stay aligned with their columns.
            rows = [[item.get(header) for header in headers] for item in data]
            return tabulate(rows, headers=headers, tablefmt=table_format)

        if isinstance(data, dict):
            return tabulate(
                list(data.items()), headers=["Key", "Value"], tablefmt=table_format
            )

        raise TypeError(
            "Input data must be a list of dictionaries, a dictionary, or a JSON string"
        )

    table = __get_table(data)
    if use_html:
        display(HTML(table))
    else:
        print(table)
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
__all__ = [
|
|
210
|
+
"display_pretty_table_html",
|
|
211
|
+
"flatten_dict",
|
|
212
|
+
"fprint",
|
|
213
|
+
"print_table",
|
|
214
|
+
# "setup_logger",
|
|
215
|
+
# "log",
|
|
216
|
+
]
|
|
File without changes
|