speedy-utils 1.1.17__py3-none-any.whl → 1.1.19__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_utils/__init__.py +9 -1
- llm_utils/chat_format/display.py +109 -14
- llm_utils/lm/__init__.py +12 -11
- llm_utils/lm/async_lm/async_llm_task.py +1 -10
- llm_utils/lm/async_lm/async_lm.py +13 -4
- llm_utils/lm/async_lm/async_lm_base.py +24 -14
- llm_utils/lm/base_prompt_builder.py +288 -0
- llm_utils/lm/llm_task.py +693 -0
- llm_utils/lm/lm.py +207 -0
- llm_utils/lm/lm_base.py +285 -0
- llm_utils/lm/openai_memoize.py +2 -2
- llm_utils/vector_cache/core.py +285 -89
- speedy_utils/__init__.py +2 -1
- speedy_utils/common/patcher.py +68 -0
- speedy_utils/common/utils_cache.py +6 -6
- speedy_utils/common/utils_io.py +238 -8
- speedy_utils/multi_worker/process.py +180 -192
- speedy_utils/multi_worker/thread.py +94 -2
- {speedy_utils-1.1.17.dist-info → speedy_utils-1.1.19.dist-info}/METADATA +36 -14
- {speedy_utils-1.1.17.dist-info → speedy_utils-1.1.19.dist-info}/RECORD +24 -19
- {speedy_utils-1.1.17.dist-info → speedy_utils-1.1.19.dist-info}/WHEEL +1 -1
- speedy_utils-1.1.19.dist-info/entry_points.txt +5 -0
- speedy_utils-1.1.17.dist-info/entry_points.txt +0 -6
speedy_utils/common/utils_io.py
CHANGED
```diff
@@ -1,13 +1,18 @@
 # utils/utils_io.py
 
+import bz2
+import gzip
+import io
 import json
+import lzma
 import os
 import os.path as osp
 import pickle
 import time
+import warnings
 from glob import glob
 from pathlib import Path
-from typing import Any, Union
+from typing import IO, Any, Iterable, Optional, Union, cast
 
 from json_repair import loads as jloads
 from pydantic import BaseModel
```
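The new stdlib imports (`bz2`, `gzip`, `io`, `lzma`) are what let the reader below open compressed JSONL transparently. As a minimal sketch (names here are illustrative, not part of the package), the extension-based dispatch works because these modules all expose the same `open(path, "rb")` interface as the built-in `open`:

```python
import bz2
import gzip
import lzma

# Illustrative only: map extensions to stdlib openers that all accept (path, "rb").
_OPENERS = {".gz": gzip.open, ".bz2": bz2.open, ".xz": lzma.open, ".lzma": lzma.open}

def open_by_extension(path: str):
    for ext, opener in _OPENERS.items():
        if path.lower().endswith(ext):
            return opener(path, "rb")
    return open(path, "rb")  # plain, uncompressed file
```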
```diff
@@ -53,7 +58,7 @@ def dump_json_or_pickle(
     except Exception as e:
         if isinstance(obj, BaseModel):
             data = obj.model_dump()
-            from fastcore.all import dict2obj
+            from fastcore.all import dict2obj, obj2dict
             obj2 = dict2obj(data)
             with open(fname, "wb") as f:
                 pickle.dump(obj2, f)
```
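When pickling the original object fails, the fallback dumps the Pydantic model's `model_dump()` dict wrapped by fastcore's `dict2obj`. A rough sketch of that path, assuming `pydantic` and `fastcore` are installed (the model and field names are hypothetical):

```python
import pickle

from fastcore.all import dict2obj
from pydantic import BaseModel

class Point(BaseModel):  # hypothetical model, for illustration only
    x: int
    y: int

data = Point(x=1, y=2).model_dump()   # plain dict extracted from the model
obj = dict2obj(data)                  # attribute-style wrapper used as the pickle payload
restored = pickle.loads(pickle.dumps(obj))
print(restored.x, restored.y)         # -> 1 2
```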
```diff
@@ -87,9 +92,235 @@ def load_json_or_pickle(fname: str, counter=0) -> Any:
         raise ValueError(f"Error {e} while loading {fname}") from e
 
 
-
-
-
+
+
+try:
+    import orjson  # type: ignore[import-not-found] # fastest JSON parser when available
+except Exception:
+    orjson = None
+
+try:
+    import zstandard as zstd  # type: ignore[import-not-found] # optional .zst support
+except Exception:
+    zstd = None
+
+
+def fast_load_jsonl(
+    path_or_file: Union[str, os.PathLike, IO],
+    *,
+    progress: bool = False,
+    desc: str = "Reading JSONL",
+    use_orjson: bool = True,
+    encoding: str = "utf-8",
+    errors: str = "strict",
+    on_error: str = "raise",  # 'raise' | 'warn' | 'skip'
+    skip_empty: bool = True,
+    max_lines: Optional[int] = None,
+    use_multiworker: bool = True,
+    multiworker_threshold: int = 50000,
+    workers: Optional[int] = None,
+) -> Iterable[Any]:
+    """
+    Lazily iterate objects from a JSON Lines file.
+
+    - Streams line-by-line (constant memory).
+    - Optional tqdm progress over bytes (compressed size if gz/bz2/xz/zst).
+    - Auto-detects compression by extension: .gz, .bz2, .xz/.lzma, .zst/.zstd.
+    - Uses orjson if available (use_orjson=True), falls back to json.
+    - Automatically uses multi-worker processing for large files (>50k lines).
+
+    Args:
+        path_or_file: Path-like or file-like object. File-like can be binary or text.
+        progress: Show a tqdm progress bar (bytes). Requires `tqdm` if True.
+        desc: tqdm description if progress=True.
+        use_orjson: Prefer orjson for speed if installed.
+        encoding, errors: Used when decoding text or when falling back to `json`.
+        on_error: What to do on a malformed line: 'raise', 'warn', or 'skip'.
+        skip_empty: Skip blank/whitespace-only lines.
+        max_lines: Stop after reading this many lines (useful for sampling).
+        use_multiworker: Enable multi-worker processing for large files.
+        multiworker_threshold: Line count threshold to trigger multi-worker processing.
+        workers: Number of worker threads (defaults to CPU count).
+
+    Yields:
+        Parsed Python objects per line.
+    """
+    def _open_auto(pth_or_f) -> IO[Any]:
+        if hasattr(pth_or_f, "read"):
+            # ensure binary buffer for consistent byte-length progress
+            fobj = pth_or_f
+            # If it's text, wrap it to binary via encoding; else just return
+            if isinstance(fobj, io.TextIOBase):
+                # TextIO -> re-encode to bytes on the fly
+                return io.BufferedReader(io.BytesIO(fobj.read().encode(encoding, errors)))
+            return pth_or_f  # assume binary
+        s = str(pth_or_f).lower()
+        if s.endswith(".gz"):
+            return gzip.open(pth_or_f, "rb")  # type: ignore
+        if s.endswith(".bz2"):
+            return bz2.open(pth_or_f, "rb")  # type: ignore
+        if s.endswith((".xz", ".lzma")):
+            return lzma.open(pth_or_f, "rb")  # type: ignore
+        if s.endswith((".zst", ".zstd")) and zstd is not None:
+            fh = open(pth_or_f, "rb")
+            dctx = zstd.ZstdDecompressor()
+            stream = dctx.stream_reader(fh)
+            return io.BufferedReader(stream)  # type: ignore
+        # plain
+        return open(pth_or_f, "rb", buffering=1024 * 1024)
+
+    def _count_lines_fast(file_path: Union[str, os.PathLike]) -> int:
+        """Quickly count lines in a file, handling compression."""
+        try:
+            f = _open_auto(file_path)
+            count = 0
+            for _ in f:
+                count += 1
+            f.close()
+            return count
+        except Exception:
+            # If we can't count lines, assume it's small
+            return 0
+
+    def _process_chunk(chunk_lines: list[bytes]) -> list[Any]:
+        """Process a chunk of lines and return parsed objects."""
+        results = []
+        for line_bytes in chunk_lines:
+            if skip_empty and not line_bytes.strip():
+                continue
+            line_bytes = line_bytes.rstrip(b"\r\n")
+            try:
+                if use_orjson and orjson is not None:
+                    obj = orjson.loads(line_bytes)
+                else:
+                    obj = json.loads(line_bytes.decode(encoding, errors))
+                results.append(obj)
+            except Exception as e:
+                if on_error == "raise":
+                    raise
+                if on_error == "warn":
+                    warnings.warn(f"Skipping malformed line: {e}")
+                # 'skip' and 'warn' both skip the line
+                continue
+        return results
+
+    # Check if we should use multi-worker processing
+    should_use_multiworker = (
+        use_multiworker
+        and not hasattr(path_or_file, "read")  # Only for file paths, not file objects
+        and max_lines is None  # Don't use multiworker if we're limiting lines
+    )
+
+    if should_use_multiworker:
+        line_count = _count_lines_fast(cast(Union[str, os.PathLike], path_or_file))
+        if line_count > multiworker_threshold:
+            # Use multi-worker processing
+            from ..multi_worker.thread import multi_thread
+
+            # Read all lines into chunks
+            f = _open_auto(path_or_file)
+            all_lines = list(f)
+            f.close()
+
+            # Split into chunks for workers
+            num_workers = workers or os.cpu_count() or 4
+            chunk_size = max(len(all_lines) // num_workers, 1000)
+            chunks = []
+            for i in range(0, len(all_lines), chunk_size):
+                chunks.append(all_lines[i:i + chunk_size])
+
+            # Process chunks in parallel
+            if progress:
+                print(f"Processing {line_count} lines with {num_workers} workers...")
+
+            chunk_results = multi_thread(_process_chunk, chunks, workers=num_workers, progress=progress)
+
+            # Flatten results and yield
+            for chunk_result in chunk_results:
+                for obj in chunk_result:
+                    yield obj
+            return
+
+    # Single-threaded processing (original logic)
+
+    f = _open_auto(path_or_file)
+
+    pbar = None
+    if progress:
+        try:
+            from tqdm import tqdm  # type: ignore
+        except Exception as e:
+            raise ImportError("tqdm is required when progress=True") from e
+        total = None
+        if not hasattr(path_or_file, "read"):
+            try:
+                path_for_size = cast(Union[str, os.PathLike], path_or_file)
+                total = os.path.getsize(path_for_size)  # compressed size if any
+            except Exception:
+                total = None
+        pbar = tqdm(total=total, unit="B", unit_scale=True, desc=desc)
+
+    line_no = 0
+    try:
+        for raw_line in f:
+            line_no += 1
+            if pbar is not None:
+                # raw_line is bytes here; if not, compute byte length
+                nbytes = len(raw_line) if isinstance(raw_line, (bytes, bytearray)) else len(str(raw_line).encode(encoding, errors))
+                pbar.update(nbytes)
+
+            # Normalize to bytes -> str only if needed
+            if isinstance(raw_line, (bytes, bytearray)):
+                if skip_empty and not raw_line.strip():
+                    if max_lines and line_no >= max_lines:
+                        break
+                    continue
+                line_bytes = raw_line.rstrip(b"\r\n")
+                # Parse
+                try:
+                    if use_orjson and orjson is not None:
+                        obj = orjson.loads(line_bytes)
+                    else:
+                        obj = json.loads(line_bytes.decode(encoding, errors))
+                except Exception as e:
+                    if on_error == "raise":
+                        raise
+                    if on_error == "warn":
+                        warnings.warn(f"Skipping malformed line {line_no}: {e}")
+                    # 'skip' and 'warn' both skip the line
+                    if max_lines and line_no >= max_lines:
+                        break
+                    continue
+            else:
+                # Text line path (unlikely)
+                if skip_empty and not raw_line.strip():
+                    if max_lines and line_no >= max_lines:
+                        break
+                    continue
+                try:
+                    obj = json.loads(raw_line)
+                except Exception as e:
+                    if on_error == "raise":
+                        raise
+                    if on_error == "warn":
+                        warnings.warn(f"Skipping malformed line {line_no}: {e}")
+                    if max_lines and line_no >= max_lines:
+                        break
+                    continue
+
+            yield obj
+            if max_lines and line_no >= max_lines:
+                break
+    finally:
+        if pbar is not None:
+            pbar.close()
+        # Close only if we opened it (i.e., not an external stream)
+        if not hasattr(path_or_file, "read"):
+            try:
+                f.close()
+            except Exception:
+                pass
+
 
 
 def load_by_ext(fname: Union[str, list[str]], do_memoize: bool = False) -> Any:
```
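Taken together, the new `fast_load_jsonl` streams and parses a JSONL file with optional compression, a byte-based progress bar, configurable error tolerance, and an automatic multi-thread path for large inputs. A usage sketch based on the signature above (the file path and record shape are hypothetical; `progress=True` additionally requires `tqdm`):

```python
from speedy_utils.common.utils_io import fast_load_jsonl

# Lazy iteration with a progress bar; malformed lines emit a warning and are
# skipped instead of aborting the whole read.
rows = []
for record in fast_load_jsonl("events.jsonl.gz", progress=True, on_error="warn"):
    rows.append(record)

# Sampling: max_lines stops early and also keeps the reader on the
# single-threaded streaming path.
head = list(fast_load_jsonl("events.jsonl.gz", max_lines=100))
```

Note the trade-off in the multi-worker branch: once a path-like input exceeds `multiworker_threshold` lines (and `max_lines` is unset), the file is read fully into memory and parsed in chunks via `multi_thread`, so the constant-memory guarantee only holds for the streaming path.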
```diff
@@ -124,7 +355,7 @@ def load_by_ext(fname: Union[str, list[str]], do_memoize: bool = False) -> Any:
 
     def load_default(path: str) -> Any:
         if path.endswith(".jsonl"):
-            return
+            return list(fast_load_jsonl(path, progress=True))
         elif path.endswith(".json"):
             try:
                 return load_json_or_pickle(path)
```
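With this change, `.jsonl` files loaded through `load_by_ext` are materialized via `fast_load_jsonl` with a progress bar. A quick sketch, assuming a plain local path reaches the `load_default` dispatch shown above (the path is hypothetical):

```python
from speedy_utils.common.utils_io import load_by_ext

# Eagerly parsed into a list by fast_load_jsonl(path, progress=True).
rows = load_by_ext("events.jsonl")
```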
```diff
@@ -159,14 +390,13 @@ def jdumps(obj, ensure_ascii=False, indent=2, **kwargs):
     return json.dumps(obj, ensure_ascii=ensure_ascii, indent=indent, **kwargs)
 
 
-
+load_jsonl = lambda path: list(fast_load_jsonl(path))
 
 __all__ = [
     "dump_json_or_pickle",
     "dump_jsonl",
     "load_by_ext",
     "load_json_or_pickle",
-    "load_jsonl",
     "jdumps",
     "jloads",
 ]
```
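Finally, `load_jsonl` becomes an eager wrapper around `fast_load_jsonl` and is dropped from `__all__`, so it no longer travels through `from ... import *` but stays importable directly. A small sketch of the two entry points (hypothetical path):

```python
from speedy_utils.common.utils_io import fast_load_jsonl, load_jsonl

rows = load_jsonl("events.jsonl")         # eager: whole file parsed into a list
stream = fast_load_jsonl("events.jsonl")  # lazy: generator yielding one object per line
first = next(iter(stream))
```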