speedy-utils 1.1.17__py3-none-any.whl → 1.1.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,13 +1,18 @@
  # utils/utils_io.py

+ import bz2
+ import gzip
+ import io
  import json
+ import lzma
  import os
  import os.path as osp
  import pickle
  import time
+ import warnings
  from glob import glob
  from pathlib import Path
- from typing import Any, Union
+ from typing import IO, Any, Iterable, Optional, Union, cast

  from json_repair import loads as jloads
  from pydantic import BaseModel
@@ -53,7 +58,7 @@ def dump_json_or_pickle(
      except Exception as e:
          if isinstance(obj, BaseModel):
              data = obj.model_dump()
-             from fastcore.all import obj2dict, dict2obj
+             from fastcore.all import dict2obj, obj2dict
              obj2 = dict2obj(data)
              with open(fname, "wb") as f:
                  pickle.dump(obj2, f)
@@ -87,9 +92,235 @@ def load_json_or_pickle(fname: str, counter=0) -> Any:
          raise ValueError(f"Error {e} while loading {fname}") from e


- def load_jsonl(path):
-     lines = open(path, encoding="utf-8").read().splitlines()
-     return [json.loads(line) for line in lines]
+
+
+ try:
+     import orjson  # type: ignore[import-not-found]  # fastest JSON parser when available
+ except Exception:
+     orjson = None
+
+ try:
+     import zstandard as zstd  # type: ignore[import-not-found]  # optional .zst support
+ except Exception:
+     zstd = None
+
+
+ def fast_load_jsonl(
+     path_or_file: Union[str, os.PathLike, IO],
+     *,
+     progress: bool = False,
+     desc: str = "Reading JSONL",
+     use_orjson: bool = True,
+     encoding: str = "utf-8",
+     errors: str = "strict",
+     on_error: str = "raise",  # 'raise' | 'warn' | 'skip'
+     skip_empty: bool = True,
+     max_lines: Optional[int] = None,
+     use_multiworker: bool = True,
+     multiworker_threshold: int = 50000,
+     workers: Optional[int] = None,
+ ) -> Iterable[Any]:
+     """
+     Lazily iterate objects from a JSON Lines file.
+
+     - Streams line-by-line (constant memory).
+     - Optional tqdm progress over bytes (compressed size if gz/bz2/xz/zst).
+     - Auto-detects compression by extension: .gz, .bz2, .xz/.lzma, .zst/.zstd.
+     - Uses orjson if available (use_orjson=True), falls back to json.
+     - Automatically uses multi-worker processing for large files (>50k lines).
+
+     Args:
+         path_or_file: Path-like or file-like object. File-like can be binary or text.
+         progress: Show a tqdm progress bar (bytes). Requires `tqdm` if True.
+         desc: tqdm description if progress=True.
+         use_orjson: Prefer orjson for speed if installed.
+         encoding, errors: Used when decoding text or when falling back to `json`.
+         on_error: What to do on a malformed line: 'raise', 'warn', or 'skip'.
+         skip_empty: Skip blank/whitespace-only lines.
+         max_lines: Stop after reading this many lines (useful for sampling).
+         use_multiworker: Enable multi-worker processing for large files.
+         multiworker_threshold: Line count threshold to trigger multi-worker processing.
+         workers: Number of worker threads (defaults to CPU count).
+
+     Yields:
+         Parsed Python objects per line.
+     """
+     def _open_auto(pth_or_f) -> IO[Any]:
+         if hasattr(pth_or_f, "read"):
+             # ensure binary buffer for consistent byte-length progress
+             fobj = pth_or_f
+             # If it's text, wrap it to binary via encoding; else just return
+             if isinstance(fobj, io.TextIOBase):
+                 # TextIO -> re-encode to bytes on the fly
+                 return io.BufferedReader(io.BytesIO(fobj.read().encode(encoding, errors)))
+             return pth_or_f  # assume binary
+         s = str(pth_or_f).lower()
+         if s.endswith(".gz"):
+             return gzip.open(pth_or_f, "rb")  # type: ignore
+         if s.endswith(".bz2"):
+             return bz2.open(pth_or_f, "rb")  # type: ignore
+         if s.endswith((".xz", ".lzma")):
+             return lzma.open(pth_or_f, "rb")  # type: ignore
+         if s.endswith((".zst", ".zstd")) and zstd is not None:
+             fh = open(pth_or_f, "rb")
+             dctx = zstd.ZstdDecompressor()
+             stream = dctx.stream_reader(fh)
+             return io.BufferedReader(stream)  # type: ignore
+         # plain
+         return open(pth_or_f, "rb", buffering=1024 * 1024)
+
+     def _count_lines_fast(file_path: Union[str, os.PathLike]) -> int:
+         """Quickly count lines in a file, handling compression."""
+         try:
+             f = _open_auto(file_path)
+             count = 0
+             for _ in f:
+                 count += 1
+             f.close()
+             return count
+         except Exception:
+             # If we can't count lines, assume it's small
+             return 0
+
+     def _process_chunk(chunk_lines: list[bytes]) -> list[Any]:
+         """Process a chunk of lines and return parsed objects."""
+         results = []
+         for line_bytes in chunk_lines:
+             if skip_empty and not line_bytes.strip():
+                 continue
+             line_bytes = line_bytes.rstrip(b"\r\n")
+             try:
+                 if use_orjson and orjson is not None:
+                     obj = orjson.loads(line_bytes)
+                 else:
+                     obj = json.loads(line_bytes.decode(encoding, errors))
+                 results.append(obj)
+             except Exception as e:
+                 if on_error == "raise":
+                     raise
+                 if on_error == "warn":
+                     warnings.warn(f"Skipping malformed line: {e}")
+                 # 'skip' and 'warn' both skip the line
+                 continue
+         return results
+
+     # Check if we should use multi-worker processing
+     should_use_multiworker = (
+         use_multiworker
+         and not hasattr(path_or_file, "read")  # Only for file paths, not file objects
+         and max_lines is None  # Don't use multiworker if we're limiting lines
+     )
+
+     if should_use_multiworker:
+         line_count = _count_lines_fast(cast(Union[str, os.PathLike], path_or_file))
+         if line_count > multiworker_threshold:
+             # Use multi-worker processing
+             from ..multi_worker.thread import multi_thread
+
+             # Read all lines into chunks
+             f = _open_auto(path_or_file)
+             all_lines = list(f)
+             f.close()
+
+             # Split into chunks for workers
+             num_workers = workers or os.cpu_count() or 4
+             chunk_size = max(len(all_lines) // num_workers, 1000)
+             chunks = []
+             for i in range(0, len(all_lines), chunk_size):
+                 chunks.append(all_lines[i:i + chunk_size])
+
+             # Process chunks in parallel
+             if progress:
+                 print(f"Processing {line_count} lines with {num_workers} workers...")
+
+             chunk_results = multi_thread(_process_chunk, chunks, workers=num_workers, progress=progress)
+
+             # Flatten results and yield
+             for chunk_result in chunk_results:
+                 for obj in chunk_result:
+                     yield obj
+             return
+
+     # Single-threaded processing (original logic)
+
+     f = _open_auto(path_or_file)
+
+     pbar = None
+     if progress:
+         try:
+             from tqdm import tqdm  # type: ignore
+         except Exception as e:
+             raise ImportError("tqdm is required when progress=True") from e
+         total = None
+         if not hasattr(path_or_file, "read"):
+             try:
+                 path_for_size = cast(Union[str, os.PathLike], path_or_file)
+                 total = os.path.getsize(path_for_size)  # compressed size if any
+             except Exception:
+                 total = None
+         pbar = tqdm(total=total, unit="B", unit_scale=True, desc=desc)
+
+     line_no = 0
+     try:
+         for raw_line in f:
+             line_no += 1
+             if pbar is not None:
+                 # raw_line is bytes here; if not, compute byte length
+                 nbytes = len(raw_line) if isinstance(raw_line, (bytes, bytearray)) else len(str(raw_line).encode(encoding, errors))
+                 pbar.update(nbytes)
+
+             # Normalize to bytes -> str only if needed
+             if isinstance(raw_line, (bytes, bytearray)):
+                 if skip_empty and not raw_line.strip():
+                     if max_lines and line_no >= max_lines:
+                         break
+                     continue
+                 line_bytes = raw_line.rstrip(b"\r\n")
+                 # Parse
+                 try:
+                     if use_orjson and orjson is not None:
+                         obj = orjson.loads(line_bytes)
+                     else:
+                         obj = json.loads(line_bytes.decode(encoding, errors))
+                 except Exception as e:
+                     if on_error == "raise":
+                         raise
+                     if on_error == "warn":
+                         warnings.warn(f"Skipping malformed line {line_no}: {e}")
+                     # 'skip' and 'warn' both skip the line
+                     if max_lines and line_no >= max_lines:
+                         break
+                     continue
+             else:
+                 # Text line path (unlikely)
+                 if skip_empty and not raw_line.strip():
+                     if max_lines and line_no >= max_lines:
+                         break
+                     continue
+                 try:
+                     obj = json.loads(raw_line)
+                 except Exception as e:
+                     if on_error == "raise":
+                         raise
+                     if on_error == "warn":
+                         warnings.warn(f"Skipping malformed line {line_no}: {e}")
+                     if max_lines and line_no >= max_lines:
+                         break
+                     continue
+
+             yield obj
+             if max_lines and line_no >= max_lines:
+                 break
+     finally:
+         if pbar is not None:
+             pbar.close()
+         # Close only if we opened it (i.e., not an external stream)
+         if not hasattr(path_or_file, "read"):
+             try:
+                 f.close()
+             except Exception:
+                 pass
+


  def load_by_ext(fname: Union[str, list[str]], do_memoize: bool = False) -> Any:
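
For orientation, here is a minimal usage sketch of the fast_load_jsonl generator introduced in the hunk above. The import path is an assumption based on the wheel layout (utils/utils_io.py inside the speedy_utils package), and the file names are placeholders:

    from speedy_utils.utils.utils_io import fast_load_jsonl  # assumed module path

    # Stream records lazily; compression is inferred from the extension (.gz/.bz2/.xz/.zst).
    for record in fast_load_jsonl("events.jsonl.gz", progress=True, on_error="warn"):
        print(record)

    # Grab a small sample without reading the whole file.
    sample = list(fast_load_jsonl("events.jsonl", max_lines=100))
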
@@ -124,7 +355,7 @@ def load_by_ext(fname: Union[str, list[str]], do_memoize: bool = False) -> Any:

      def load_default(path: str) -> Any:
          if path.endswith(".jsonl"):
-             return load_jsonl(path)
+             return list(fast_load_jsonl(path, progress=True))
          elif path.endswith(".json"):
              try:
                  return load_json_or_pickle(path)
@@ -159,14 +390,13 @@ def jdumps(obj, ensure_ascii=False, indent=2, **kwargs):
      return json.dumps(obj, ensure_ascii=ensure_ascii, indent=indent, **kwargs)


-
+ load_jsonl = lambda path: list(fast_load_jsonl(path))


  __all__ = [
      "dump_json_or_pickle",
      "dump_jsonl",
      "load_by_ext",
      "load_json_or_pickle",
-     "load_jsonl",
      "jdumps",
      "jloads",
  ]
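
The final hunk keeps load_jsonl as an eager wrapper around the new generator but removes it from __all__ (fast_load_jsonl is not listed there either), so star-imports no longer expose it while explicit imports still work. A short sketch of the behavioral difference, using the same assumed module path and placeholder file names:

    from speedy_utils.utils.utils_io import fast_load_jsonl, load_jsonl  # assumed module path

    rows = load_jsonl("data.jsonl")         # eager: materializes the whole file as a list
    stream = fast_load_jsonl("data.jsonl")  # lazy: a generator, consumed once by iteration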