speedy-utils 1.1.27__py3-none-any.whl → 1.1.29__py3-none-any.whl
This diff compares publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- llm_utils/__init__.py +16 -4
- llm_utils/chat_format/__init__.py +10 -10
- llm_utils/chat_format/display.py +33 -21
- llm_utils/chat_format/transform.py +17 -19
- llm_utils/chat_format/utils.py +6 -4
- llm_utils/group_messages.py +17 -14
- llm_utils/lm/__init__.py +6 -5
- llm_utils/lm/async_lm/__init__.py +1 -0
- llm_utils/lm/async_lm/_utils.py +10 -9
- llm_utils/lm/async_lm/async_llm_task.py +141 -137
- llm_utils/lm/async_lm/async_lm.py +48 -42
- llm_utils/lm/async_lm/async_lm_base.py +59 -60
- llm_utils/lm/async_lm/lm_specific.py +4 -3
- llm_utils/lm/base_prompt_builder.py +93 -70
- llm_utils/lm/llm.py +126 -108
- llm_utils/lm/llm_signature.py +4 -2
- llm_utils/lm/lm_base.py +72 -73
- llm_utils/lm/mixins.py +102 -62
- llm_utils/lm/openai_memoize.py +124 -87
- llm_utils/lm/signature.py +105 -92
- llm_utils/lm/utils.py +42 -23
- llm_utils/scripts/vllm_load_balancer.py +23 -30
- llm_utils/scripts/vllm_serve.py +8 -7
- llm_utils/vector_cache/__init__.py +9 -3
- llm_utils/vector_cache/cli.py +1 -1
- llm_utils/vector_cache/core.py +59 -63
- llm_utils/vector_cache/types.py +7 -5
- llm_utils/vector_cache/utils.py +12 -8
- speedy_utils/__imports.py +244 -0
- speedy_utils/__init__.py +90 -194
- speedy_utils/all.py +125 -227
- speedy_utils/common/clock.py +37 -42
- speedy_utils/common/function_decorator.py +6 -12
- speedy_utils/common/logger.py +43 -52
- speedy_utils/common/notebook_utils.py +13 -21
- speedy_utils/common/patcher.py +21 -17
- speedy_utils/common/report_manager.py +42 -44
- speedy_utils/common/utils_cache.py +152 -169
- speedy_utils/common/utils_io.py +137 -103
- speedy_utils/common/utils_misc.py +15 -21
- speedy_utils/common/utils_print.py +22 -28
- speedy_utils/multi_worker/process.py +66 -79
- speedy_utils/multi_worker/thread.py +78 -155
- speedy_utils/scripts/mpython.py +38 -36
- speedy_utils/scripts/openapi_client_codegen.py +10 -10
- {speedy_utils-1.1.27.dist-info → speedy_utils-1.1.29.dist-info}/METADATA +1 -1
- speedy_utils-1.1.29.dist-info/RECORD +57 -0
- vision_utils/README.md +202 -0
- vision_utils/__init__.py +4 -0
- vision_utils/io_utils.py +735 -0
- vision_utils/plot.py +345 -0
- speedy_utils-1.1.27.dist-info/RECORD +0 -52
- {speedy_utils-1.1.27.dist-info → speedy_utils-1.1.29.dist-info}/WHEEL +0 -0
- {speedy_utils-1.1.27.dist-info → speedy_utils-1.1.29.dist-info}/entry_points.txt +0 -0
speedy_utils/common/utils_io.py
CHANGED
- Module imports are consolidated: the explicit standard-library imports (bz2, contextlib, gzip, io, json, lzma, os, os.path as osp, pickle, time, warnings, collections.abc.Iterable, glob.glob, pathlib.Path, and typing's IO/Any/Optional/Union/cast) plus `from pydantic import BaseModel` survive only as comments and are replaced by a single `from ..__imports import *`. The relative `from .utils_misc import mkdir_or_exist` import is unchanged, as is the module-level `try:` block that guards the optional imports below it.
- String literals move from double to single quotes throughout: `dump_jsonl` now reads `def dump_jsonl(list_dictionaries: list[dict], file_name: str = 'output.jsonl') -> None:` and writes each record with `open(file_name, 'w', encoding='utf-8')` and `json.dumps(dictionary, ensure_ascii=False) + '\n'`.
- `dump_json_or_pickle` reflows its signature over multiple lines, marks the pydantic fallback with `data = obj.model_dump()  # type: ignore`, and chains its errors: `raise ValueError(f'Error {e} while dumping {fname}') from e` for pickle failures and `raise NotImplementedError(f'File type {fname} not supported')` for unknown extensions.
- `load_json_or_pickle` reads `.json`/`.jsonl` files as text with `encoding='utf-8'`, reports corrupted caches with `print(f'Corrupted cache file {fname} removed; it will be regenerated on next access')` before removing the file and retrying, and re-raises other failures as `ValueError(f'Error {e} while loading {fname}') from e`.
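As a quick illustration of these serialization helpers, a minimal round-trip might look like the sketch below. The file paths are invented, and the functions are imported from the submodule path shown in this diff (the package may also re-export them at the top level):

```python
from datetime import datetime

from speedy_utils.common.utils_io import (
    dump_json_or_pickle,
    dump_jsonl,
    load_json_or_pickle,
)

# The extension selects the format: .json -> json.dump, .pkl -> pickle, .jsonl -> one object per line
dump_json_or_pickle({"name": "example", "values": [1, 2, 3]}, "/tmp/example.json")
print(load_json_or_pickle("/tmp/example.json"))

# Objects that are not JSON-serializable (e.g. datetime) go through the pickle path
dump_json_or_pickle({"created": datetime.now()}, "/tmp/example.pkl")
print(load_json_or_pickle("/tmp/example.pkl"))

# JSON Lines: one dictionary per line
dump_jsonl([{"id": 1}, {"id": 2}], "/tmp/example.jsonl")
```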
Further down in utils_io.py:

- `fast_load_jsonl` modernizes its signature with PEP 604 unions and single-quoted defaults: `path_or_file: str | os.PathLike | IO`, `desc: str = 'Reading JSONL'`, `encoding: str = 'utf-8'`, `errors: str = 'strict'`, `on_error: str = 'raise'  # 'raise' | 'warn' | 'skip'`, `max_lines: int | None = None`, and `workers: int | None = None`.
- A `ZstdWrapper` context manager is added inside `fast_load_jsonl` to handle zstd-compressed inputs:

      class ZstdWrapper:
          """Context manager wrapper for zstd decompression."""

          def __init__(self, path):
              self.path = path
              self.fh = None
              self.stream = None
              self.buffered = None

          def __enter__(self):
              if zstd is None:
                  raise ImportError('zstandard package required for .zst files')
              self.fh = open(self.path, 'rb')
              dctx = zstd.ZstdDecompressor()
              self.stream = dctx.stream_reader(self.fh)
              self.buffered = io.BufferedReader(self.stream)
              return self.buffered

          def __exit__(self, exc_type, exc_val, exc_tb):
              if self.fh:
                  self.fh.close()

- `_open_auto` re-encodes text streams to bytes via `io.BufferedReader(io.BytesIO(fobj.read().encode(encoding, errors)))`, opens `.gz`/`.bz2`/`.xz`/`.lzma` files in `'rb'` mode, routes `.zst`/`.zstd` files through `ZstdWrapper(pth_or_f).__enter__()` instead of building the decompressor stream inline, and opens plain files with `open(pth_or_f, 'rb', buffering=1024 * 1024)`. `_count_lines_fast` is annotated as `file_path: str | os.PathLike`.
- Line handling strips `b'\r\n'`, and malformed lines are handled per `on_error`: `'raise'` re-raises, `'warn'` emits `warnings.warn(f'Skipping malformed line: {e}', stacklevel=2)` (the streaming paths include the line number), and both `'warn'` and `'skip'` skip the line.
- The multi-worker path remains restricted to real file paths (`not hasattr(path_or_file, 'read')`) with no `max_lines` limit; it counts lines via `_count_lines_fast(cast(str | os.PathLike, path_or_file))`, prints `f'Processing {line_count} lines with {num_workers} workers ({len(chunks)} chunks)...'` when `progress=True`, and reflows the `multi_thread(_process_chunk, chunks, workers=num_workers, progress=progress)` call.
- Progress reporting raises `ImportError('tqdm is required when progress=True') from e` when tqdm is missing and builds the bar with `tqdm(total=total, unit='B', unit_scale=True, desc=desc)`; closing self-opened files replaces a bare try/except-pass with `with contextlib.suppress(Exception): f.close()`.
- `load_by_ext(fname: str | list[str], do_memoize: bool = False)` keeps its glob and list fan-out through `multi_process(load_by_ext, paths, workers=16)`, reads CSV/TSV with `pd.read_csv(path, engine='pyarrow', **pd_kwargs)`, reads text with `encoding='utf-8'`, dispatches `.jsonl` to `fast_load_jsonl(path, progress=True)` and `.json` through `load_json_or_pickle` (raising `ValueError('JSON decoding failed') from exc` on decode errors), rewrites the `handlers` mapping with single-quoted keys ('.csv', '.tsv', '.txt', '.pkl', '.json', '.jsonl'), and chains its errors: `NotImplementedError(f'File type {ext} not supported')` and `ValueError(f'Error {e} while loading {fname}') from e`.
- `__all__` is rewritten with single quotes: 'dump_json_or_pickle', 'dump_jsonl', 'load_by_ext', 'load_json_or_pickle', 'jdumps', 'jloads'.
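A sketch of how the updated loaders might be called. The file names and contents are invented; transparent `.zst` handling would additionally require the optional zstandard package:

```python
import gzip

from speedy_utils.common.utils_io import dump_jsonl, fast_load_jsonl, load_by_ext

# Write a small gzip-compressed JSONL file to read back (paths are illustrative)
with gzip.open("/tmp/events.jsonl.gz", "wt", encoding="utf-8") as f:
    f.write('{"id": 1}\n{"id": 2}\n')

# Stream records lazily; .gz/.bz2/.xz/.zst inputs are decompressed on the fly
for record in fast_load_jsonl(
    "/tmp/events.jsonl.gz",
    on_error="warn",      # 'raise' | 'warn' | 'skip'
    max_lines=10_000,     # stop early; this also disables the multi-worker path
):
    print(record)

# Or let the extension choose the handler (.csv/.tsv/.txt/.pkl/.json/.jsonl);
# .jsonl goes through fast_load_jsonl(progress=True), which needs tqdm installed
dump_jsonl([{"id": 3}], "/tmp/events.jsonl")
rows = load_by_ext("/tmp/events.jsonl")
```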
speedy_utils/common/utils_misc.py
CHANGED

- The explicit imports (inspect, os, collections.abc.Callable, and typing's Any/TypeVar) are replaced by `from ..__imports import *`; the type variable becomes `T = TypeVar('T')` with single quotes.
- `is_notebook` keeps the same logic with single-quoted lookups: `if 'get_ipython' in globals():`, `get_ipython = globals()['get_ipython']`, and `shell == 'ZMQInteractiveShell'` for the Jupyter notebook / qtconsole case.
- `convert_to_builtin_python` drops the blank lines between its branches: dicts and lists recurse, built-in scalars (`int`, `float`, `str`, `bool`, `None`) pass through, pydantic `BaseModel` instances are reduced via `model_dump_json()`, and anything else raises `ValueError(f'Unsupported type {type(input_data)}')`.
- `__all__` is rewritten with single quotes: 'mkdir_or_exist', 'flatten_list', 'get_arg_names', 'is_notebook', 'convert_to_builtin_python', 'dedup'.
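A small usage sketch of these helpers, imported from the submodule path above; the directory and values are invented for illustration:

```python
from speedy_utils.common.utils_misc import (
    convert_to_builtin_python,
    dedup,
    is_notebook,
    mkdir_or_exist,
)

mkdir_or_exist("/tmp/speedy_demo")   # create the directory if it does not already exist

# Nested dicts/lists of built-in scalars pass through unchanged;
# pydantic BaseModel instances are reduced via model_dump_json()
plain = convert_to_builtin_python({"scores": [1, 2.5, None], "ok": True})

# Deduplicate by a caller-supplied key function
unique = dedup([{"id": 1}, {"id": 1}, {"id": 2}], key=lambda d: d["id"])

print(is_notebook())                 # False when run as a plain script
```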
speedy_utils/common/utils_print.py
CHANGED

- The explicit imports (pprint, textwrap, `from typing import Any, Union`, `from tabulate import tabulate`, and the rest) are replaced by `from ..__imports import *`; `from .notebook_utils import display_pretty_table_html` is kept.
- `flatten_dict(d, parent_key='', sep='.')` and the rest of the module move to single-quoted string literals.
- `fprint` modernizes its signature with PEP 604 unions: `key_ignore: list[str] | None = None`, `key_keep: list[str] | None = None`, `depth: int | None = None`, `table_format: str = 'grid'`, and a `-> None | str` return annotation. The list branch renames the unused loop index (`for _i, item in enumerate(input_data):`) and prints the item separator as `'\n' + '-' * 100 + '\n'`.
- The duck-typing conversions check for `'toDict'`, `'to_dict'`, and `'model_dump'` with single-quoted attribute names, the non-dict/str guard raises `ValueError('Input data must be a dictionary or string')`, the `remove_keys`/`keep_keys` helpers split keys with `key.split('.')`, and the table is rendered with `tabulate(table, headers=['Key', 'Value'], tablefmt=table_format, maxcolwidths=[None, max_width])`.
- `__all__` becomes `['flatten_dict', 'fprint']`.