speedy-utils 1.1.30__py3-none-any.whl → 1.1.32__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
speedy_utils/__imports.py CHANGED
@@ -5,15 +5,6 @@ import time
  import warnings
 
 
- # Suppress lazy_loader subpackage warning
- warnings.filterwarnings(
-     'ignore',
-     message='subpackages can technically be lazily loaded',
-     category=RuntimeWarning,
-     module='lazy_loader',
- )
-
- t = time.time()
  # Third-party imports
  try:
      # Python 3.10+
@@ -86,7 +77,6 @@ from typing import (
  )
 
  import cachetools
- import lazy_loader as lazy
  import psutil
  from fastcore.parallel import parallel
  from json_repair import loads as jloads
@@ -94,27 +84,121 @@ from loguru import logger
  from tqdm import tqdm
 
 
- # Resolve long-import-time dependencies lazily
+ # Direct imports (previously lazy-loaded)
+ import numpy as np
+ tabulate = __import__('tabulate').tabulate
+ import xxhash
+
+ # Optional imports - lazy loaded for performance
+ def _get_pandas():
+     """Lazy import pandas."""
+     try:
+         import pandas as pd
+         return pd
+     except ImportError:
+         return None
+
+ def _get_ray():
+     """Lazy import ray."""
+     try:
+         import ray
+         return ray
+     except ImportError:
+         return None
+
+ def _get_matplotlib():
+     """Lazy import matplotlib."""
+     try:
+         import matplotlib
+         return matplotlib
+     except ImportError:
+         return None
+
+ def _get_matplotlib_pyplot():
+     """Lazy import matplotlib.pyplot."""
+     try:
+         import matplotlib.pyplot as plt
+         return plt
+     except ImportError:
+         return None
+
+ def _get_ipython_core():
+     """Lazy import IPython.core.getipython."""
+     try:
+         from IPython.core.getipython import get_ipython
+         return get_ipython
+     except ImportError:
+         return None
+
+ # Cache for lazy imports
+ _pandas_cache = None
+ _ray_cache = None
+ _matplotlib_cache = None
+ _plt_cache = None
+ _get_ipython_cache = None
+
+ # Lazy import classes for performance-critical modules
+ class _LazyModule:
+     """Lazy module loader that imports only when accessed."""
+     def __init__(self, import_func, cache_var_name):
+         self._import_func = import_func
+         self._cache_var_name = cache_var_name
+         self._module = None
+
+     def __call__(self):
+         """Allow calling as a function to get the module."""
+         if self._module is None:
+             # Use global cache
+             cache = globals().get(self._cache_var_name)
+             if cache is None:
+                 cache = self._import_func()
+                 globals()[self._cache_var_name] = cache
+             self._module = cache
+         return self._module
+
+     def __getattr__(self, name):
+         """Lazy attribute access."""
+         if self._module is None:
+             self()  # Load the module
+         return getattr(self._module, name)
 
- torch = lazy.load('torch')  # lazy at runtime
- np = lazy.load('numpy')
- pd = lazy.load('pandas')
- tqdm = lazy.load('tqdm').tqdm  # type: ignore # noqa: F811
- pd = lazy.load('pandas')
- tabulate = lazy.load('tabulate').tabulate
- xxhash = lazy.load('xxhash')
- get_ipython = lazy.load('IPython.core.getipython')
- HTML = lazy.load('IPython.display').HTML
- display = lazy.load('IPython.display').display
- # logger = lazy.load('loguru').logger
- BaseModel = lazy.load('pydantic').BaseModel
- _pil = lazy.load('PIL.Image')
- Image = _pil.Image
- matplotlib = lazy.load('matplotlib')
- plt = lazy.load('matplotlib.pyplot')
+     def __bool__(self):
+         """Support truthiness checks."""
+         return self._module is not None
 
+     def __repr__(self):
+         if self._module is None:
+             return f"<LazyModule: not loaded>"
+         return repr(self._module)
 
- ray = lazy.load('ray')  # lazy at runtime
+ # Create lazy loaders for top slow imports (import only when accessed)
+ pd = _LazyModule(_get_pandas, '_pandas_cache')
+ ray = _LazyModule(_get_ray, '_ray_cache')
+ matplotlib = _LazyModule(_get_matplotlib, '_matplotlib_cache')
+ plt = _LazyModule(_get_matplotlib_pyplot, '_plt_cache')
+ get_ipython = _LazyModule(_get_ipython_core, '_get_ipython_cache')
+
+ # Other optional imports (not lazy loaded as they're not in top slow imports)
+ try:
+     import torch
+ except ImportError:
+     torch = None
+
+ try:
+     from IPython.display import HTML, display
+ except ImportError:
+     HTML = None
+     display = None
+
+ try:
+     from PIL import Image
+ except ImportError:
+     Image = None
+
+ try:
+     from pydantic import BaseModel
+ except ImportError:
+     BaseModel = None
  if TYPE_CHECKING:
      import numpy as np
      import pandas as pd
@@ -133,7 +217,7 @@ if TYPE_CHECKING:
 
  __all__ = [
      # ------------------------------------------------------------------
-     # Lazy-loaded external modules / objects
+     # Direct imports (previously lazy-loaded)
      # ------------------------------------------------------------------
      'torch',
      'np',
@@ -147,6 +231,8 @@ __all__ = [
      'BaseModel',
      'Image',
      'ray',
+     'matplotlib',
+     'plt',
      # ------------------------------------------------------------------
      # Standard library modules imported
      # ------------------------------------------------------------------
@@ -235,7 +321,6 @@ __all__ = [
      # Third-party modules
      # ------------------------------------------------------------------
      'cachetools',
-     'lazy',
      'psutil',
      'parallel',
      'jloads',
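
Note: the hunks above replace the `lazy_loader` dependency with a hand-rolled `_LazyModule` proxy. Constructing the proxy is free; the real import only happens on first attribute access (or when the proxy is called). A minimal standalone sketch of the same pattern, not the package's exact code:

import importlib

class LazyProxy:
    """Defer a module import until first attribute access."""
    def __init__(self, module_name: str):
        self._module_name = module_name
        self._module = None

    def __getattr__(self, name):
        # Only reached for names not set in __init__, i.e. real module lookups.
        if self._module is None:
            self._module = importlib.import_module(self._module_name)
        return getattr(self._module, name)

pd = LazyProxy("pandas")          # cheap; pandas is not imported yet
df = pd.DataFrame({"a": [1, 2]})  # first attribute access triggers the import

One trade-off: the proxy is not a real module object, so `isinstance(pd, types.ModuleType)` is False, and a bare `if pd:` is always truthy unless, as in `_LazyModule.__bool__` above, truthiness is redefined to mean "already loaded".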
speedy_utils/__init__.py CHANGED
@@ -113,6 +113,8 @@ __all__ = [
      'tabulate',
      'tqdm',
      'np',
+     'matplotlib',
+     'plt',
      # Clock module
      'Clock',
      'speedy_timer',
{speedy_utils-1.1.30.dist-info → speedy_utils-1.1.32.dist-info}/METADATA RENAMED
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: speedy-utils
- Version: 1.1.30
+ Version: 1.1.32
  Summary: Fast and easy-to-use package for data science
  Project-URL: Homepage, https://github.com/anhvth/speedy
  Project-URL: Repository, https://github.com/anhvth/speedy
@@ -39,6 +39,7 @@ Requires-Dist: pydantic
  Requires-Dist: pytest
  Requires-Dist: ray
  Requires-Dist: requests
+ Requires-Dist: ruff
  Requires-Dist: scikit-learn
  Requires-Dist: tabulate
  Requires-Dist: tqdm
{speedy_utils-1.1.30.dist-info → speedy_utils-1.1.32.dist-info}/RECORD RENAMED
@@ -27,9 +27,8 @@ llm_utils/vector_cache/cli.py,sha256=MAvnmlZ7j7_0CvIcSyK4TvJlSRFWYkm4wE7zSq3KR8k
  llm_utils/vector_cache/core.py,sha256=VXuYJy1AX22NHKvIXRriETip5RrmQcNp73-g-ZT774o,30950
  llm_utils/vector_cache/types.py,sha256=CpMZanJSTeBVxQSqjBq6pBVWp7u2-JRcgY9t5jhykdQ,438
  llm_utils/vector_cache/utils.py,sha256=OsiRFydv8i8HiJtPL9hh40aUv8I5pYfg2zvmtDi4DME,1446
- speedy_utils/__imports.py,sha256=KQogps2TTzigJK8YkW875qaIEwf8PypmTtzyc7svvfw,5691
- speedy_utils/__init__.py,sha256=dxEwqFTQIXIsWlKEXm6b5pcEx2c8W8R5iblbHN2chdc,2656
- speedy_utils/all.py,sha256=5gN_mIvx9mtEWMXl64S0NvY3Wrj1bb5QyrMh-U03uiE,2508
+ speedy_utils/__imports.py,sha256=PhHqZWwVOKAbbXoWxZLVVyurGmZhui3boQ7Nji002cQ,7795
+ speedy_utils/__init__.py,sha256=VkKqS4eHXd8YeDu2TAQ3Osqy70RSufUL1sECDoYzqvM,2685
  speedy_utils/common/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
  speedy_utils/common/clock.py,sha256=raLtMGIgzrRej5kUt7hOUm2ZZw2THVPo-q8dMvdZOxw,7354
  speedy_utils/common/function_decorator.py,sha256=GKXqRs_hHFFmhyhql0Br0o52WzekUnpNlm99NfaVwgY,2025
@@ -49,9 +48,9 @@ speedy_utils/scripts/mpython.py,sha256=aZvusJLKa3APVhabbFUAEo873VBm8Fym7HKGmVW4L
  speedy_utils/scripts/openapi_client_codegen.py,sha256=GModmmhkvGnxljK4KczyixKDrk-VEcLaW5I0XT6tzWo,9657
  vision_utils/README.md,sha256=AIDZZj8jo_QNrEjFyHwd00iOO431s-js-M2dLtVTn3I,5740
  vision_utils/__init__.py,sha256=XsLxy1Fn33Zxu6hTFl3NEWfxGjuQQ-0Wmoh6lU9NZ_o,257
- vision_utils/io_utils.py,sha256=1FkG6k7uwZALh3-JkWXEHoGQJhjTqG1jC20SxObPRS0,25921
- vision_utils/plot.py,sha256=tJNuXmwUQ9GVe52RGBHzkRlBCbrGbpMNdBTtQ7eEljs,12055
- speedy_utils-1.1.30.dist-info/METADATA,sha256=i5X3HJZEX1UdLKvRIOAT6IQbmzuZA4kLLFUNFBZCc3I,8028
- speedy_utils-1.1.30.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
- speedy_utils-1.1.30.dist-info/entry_points.txt,sha256=1rrFMfqvaMUE9hvwGiD6vnVh98kmgy0TARBj-v0Lfhs,244
- speedy_utils-1.1.30.dist-info/RECORD,,
+ vision_utils/io_utils.py,sha256=q41pffN632HbMmzcBzfg2Z7DvZZgoAQCdD9jHLqDgjc,26603
+ vision_utils/plot.py,sha256=v73onfH8KbGHigw5KStUPqbLyJqIEOvvJaqtaoGKrls,12032
+ speedy_utils-1.1.32.dist-info/METADATA,sha256=ElLAOdGyTiqq33ON2WgHX4grgtlFqWRncupD0DivCBk,8048
+ speedy_utils-1.1.32.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
+ speedy_utils-1.1.32.dist-info/entry_points.txt,sha256=1rrFMfqvaMUE9hvwGiD6vnVh98kmgy0TARBj-v0Lfhs,244
+ speedy_utils-1.1.32.dist-info/RECORD,,
{speedy_utils-1.1.30.dist-info → speedy_utils-1.1.32.dist-info}/WHEEL RENAMED
@@ -1,4 +1,4 @@
  Wheel-Version: 1.0
- Generator: hatchling 1.27.0
+ Generator: hatchling 1.28.0
  Root-Is-Purelib: true
  Tag: py3-none-any
vision_utils/io_utils.py CHANGED
@@ -56,7 +56,8 @@ def _validate_image(path: PathLike) -> bool:
  def read_images_cpu(
      paths: Sequence[PathLike],
      hw: tuple[int, int] | None = None,
- ) -> dict[str, 'np.ndarray | None']:
+     verbose: bool = True,
+ ) -> dict[str, "np.ndarray | None"]:
      """
      CPU image loader using Pillow.
 
@@ -73,7 +74,7 @@ def read_images_cpu(
      str_paths = _to_str_paths(paths)
 
      # Pillow < 9.1.0 exposes resampling filters directly on Image
-     resample_attr = getattr(Image, 'Resampling', Image)
+     resample_attr = getattr(Image, "Resampling", Image)
      resample = resample_attr.BILINEAR
 
      target_size = None  # Pillow expects (width, height)
@@ -81,16 +82,20 @@ def read_images_cpu(
          h, w = hw
          target_size = (w, h)
 
-     result: dict[str, 'np.ndarray | None'] = {}
-     for path in tqdm(str_paths, desc='Loading images (CPU)', unit='img'):
+     result: dict[str, "np.ndarray | None"] = {}
+     if verbose:
+         pbar = tqdm(str_paths, desc="Loading images (CPU)", unit="img")
+     else:
+         pbar = str_paths
+     for path in pbar:
          try:
              with Image.open(path) as img:
-                 img = img.convert('RGB')
+                 img = img.convert("RGB")
                  if target_size is not None:
                      img = img.resize(target_size, resample=resample)
                  result[path] = np.asarray(img)
          except Exception as e:
-             print(f'Warning: Failed to load {path}: {e}')
+             print(f"Warning: Failed to load {path}: {e}")
              result[path] = None
      return result
 
@@ -101,9 +106,10 @@ def read_images_gpu(
      num_threads: int = 4,
      hw: tuple[int, int] | None = None,
      validate: bool = False,
-     device: str = 'mixed',
+     device: str = "mixed",
      device_id: int = 0,
- ) -> dict[str, 'np.ndarray | None']:
+     verbose: bool = True,
+ ) -> dict[str, "np.ndarray | None"]:
      """
      GPU-accelerated image reader using NVIDIA DALI.
 
@@ -117,6 +123,7 @@ def read_images_gpu(
          validate: If True, pre-validate images (slower).
          device: DALI decoder device: "mixed" (default), "cpu", or "gpu".
          device_id: GPU device id.
+         verbose: If True, show progress bar.
      """
      import numpy as np
      from nvidia.dali import fn, pipeline_def
@@ -127,23 +134,23 @@ def read_images_gpu(
      if not str_paths:
          return {}
 
-     result: dict[str, 'np.ndarray | None'] = {}
+     result: dict[str, "np.ndarray | None"] = {}
      valid_paths: list[str] = str_paths
 
      # Optional validation (slow but safer)
      if validate:
          from tqdm import tqdm
 
-         print('Validating images...')
+         print("Validating images...")
          tmp_valid: list[str] = []
          invalid_paths: list[str] = []
 
-         for path in tqdm(str_paths, desc='Validating', unit='img'):
+         for path in tqdm(str_paths, desc="Validating", unit="img"):
              if _validate_image(path):
                  tmp_valid.append(path)
              else:
                  invalid_paths.append(path)
-                 print(f'Warning: Skipping invalid/corrupted image: {path}')
+                 print(f"Warning: Skipping invalid/corrupted image: {path}")
 
          valid_paths = tmp_valid
          # pre-fill invalid paths with None
@@ -151,7 +158,7 @@ def read_images_gpu(
              result[p] = None
 
      if not valid_paths:
-         print('No valid images found.')
+         print("No valid images found.")
          return result
 
      resize_h, resize_w = (None, None)
@@ -166,7 +173,7 @@ def read_images_gpu(
          jpegs, _ = fn.readers.file(
              files=files_for_reader,
              random_shuffle=False,
-             name='Reader',
+             name="Reader",
          )
          imgs = fn.decoders.image(jpegs, device=device, output_type=dali_types.RGB)
          if resize_h is not None and resize_w is not None:
@@ -188,13 +195,17 @@ def read_images_gpu(
      )
      dali_pipe.build()
 
-     imgs: list['np.ndarray'] = []
+     imgs: list["np.ndarray"] = []
      num_files = len(valid_paths)
      num_batches = (num_files + batch_size - 1) // batch_size
 
      from tqdm import tqdm
+     if verbose:
+         pbar = tqdm(range(num_batches), desc="Decoding (DALI)", unit="batch")
+     else:
+         pbar = range(num_batches)
 
-     for _ in tqdm(range(num_batches), desc='Decoding (DALI)', unit='batch'):
+     for _ in pbar:
          (out,) = dali_pipe.run()
          out = out.as_cpu()
          for i in range(len(out)):
@@ -203,7 +214,7 @@ def read_images_gpu(
      # Handle possible padding / extra samples
      if len(imgs) < num_files:
          print(
-             f'Warning: DALI returned fewer samples ({len(imgs)}) than expected ({num_files}).'
+             f"Warning: DALI returned fewer samples ({len(imgs)}) than expected ({num_files})."
          )
      if len(imgs) > num_files:
          imgs = imgs[:num_files]
@@ -221,9 +232,10 @@ def read_images(
      num_threads: int = 4,
      hw: tuple[int, int] | None = None,
      validate: bool = False,
-     device: str = 'mixed',
+     device: str = "mixed",
      device_id: int = 0,
- ) -> dict[str, 'np.ndarray | None']:
+     verbose: bool = True,
+ ) -> dict[str, "np.ndarray | None"]:
      """
      Fast image reader that tries GPU (DALI) first, falls back to CPU (Pillow).
 
@@ -237,6 +249,7 @@ def read_images(
          validate: If True, pre-validate images before GPU processing (slower).
          device: DALI decoder device: "mixed", "cpu", or "gpu".
          device_id: GPU device id for DALI.
+         verbose: If True, show progress bars.
      """
      str_paths = _to_str_paths(paths)
 
@@ -252,13 +265,12 @@ def read_images(
              validate=validate,
              device=device,
              device_id=device_id,
+             verbose=verbose,
          )
      except Exception as exc:
-         print(f'GPU loading failed ({exc}), falling back to CPU...')
-         return read_images_cpu(str_paths, hw=hw)
-
-
-
+         if verbose:
+             print(f"GPU loading failed ({exc}), falling back to CPU...")
+         return read_images_cpu(str_paths, hw=hw, verbose=verbose)
 
 
  class ImageMmap(Dataset):
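
Note: all three readers now accept `verbose` and thread it through, so progress bars can be silenced in batch jobs, and the GPU path still degrades to the Pillow loader on any DALI failure. A hedged usage sketch (the file paths are hypothetical; the signature matches the diff above):

from vision_utils.io_utils import read_images

paths = ["/data/imgs/0001.jpg", "/data/imgs/0002.jpg"]  # hypothetical files
# Tries the DALI pipeline first; any exception falls back to CPU decoding.
images = read_images(paths, hw=(224, 224), verbose=False)
arr = images[paths[0]]  # (224, 224, 3) array, or None if that file failed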
@@ -287,22 +299,24 @@ class ImageMmap(Dataset):
          self.safe = safe
 
          # Generate default mmap path if not provided
+         current_hash = identify(
+             "".join(sorted(self.img_paths)) + f"_{self.H}x{self.W}x{self.C}"
+         )
          if mmap_path is None:
-             hash_idx = identify(''.join(self.img_paths))
-             mmap_path = Path('.cache') / f'mmap_dataset_{hash_idx}.dat'
-
+             # hash_idx = identify(''.join(sorted(self.img_paths)))
+             mmap_path = Path(".cache") / f"mmap_dataset_{current_hash}.dat"
+
          self.mmap_path = Path(mmap_path)
-         self.hash_path = Path(str(self.mmap_path) + '.hash')
-         self.lock_path = Path(str(self.mmap_path) + '.lock')
+         self.hash_path = Path(str(self.mmap_path) + ".hash")
+         self.lock_path = Path(str(self.mmap_path) + ".lock")
          self.shape = (self.n, self.H, self.W, self.C)
 
          if self.n == 0:
              raise ValueError("Cannot create ImageMmap with empty img_paths list")
 
          # Calculate hash of image paths
-         current_hash = identify(self.img_paths)
          needs_rebuild = False
-
+
          if not self.mmap_path.exists():
              needs_rebuild = True
              print("Mmap file does not exist, building cache...")
@@ -314,16 +328,20 @@ class ImageMmap(Dataset):
              stored_hash = self.hash_path.read_text().strip()
              if stored_hash != current_hash:
                  needs_rebuild = True
-                 print(f"Hash mismatch (stored: {stored_hash[:16]}..., current: {current_hash[:16]}...), rebuilding cache...")
-
+                 print(
+                     f"Hash mismatch (stored: {stored_hash[:16]}..., current: {current_hash[:16]}...), rebuilding cache..."
+                 )
+
          # Verify file size matches expected
          expected_bytes = np.prod(self.shape) * self.dtype.itemsize
          if self.mmap_path.exists():
              actual_size = self.mmap_path.stat().st_size
              if actual_size != expected_bytes:
                  needs_rebuild = True
-                 print(f"Mmap file size mismatch (expected: {expected_bytes}, got: {actual_size}), rebuilding cache...")
-
+                 print(
+                     f"Mmap file size mismatch (expected: {expected_bytes}, got: {actual_size}), rebuilding cache..."
+                 )
+
          if needs_rebuild:
              self._build_cache_with_lock(current_hash)
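
Note: the cache key now hashes the sorted path list together with the target HxWxC, so a rebuild is triggered when either the file set or the requested resolution changes; previously only the unsorted path list was hashed. A small illustration, assuming `identify` (the content-hash helper used above) is importable from speedy_utils:

from speedy_utils import identify  # assumed top-level export

paths = ["b.jpg", "a.jpg"]  # hypothetical
key_224 = identify("".join(sorted(paths)) + "_224x224x3")
key_448 = identify("".join(sorted(paths)) + "_448x448x3")
assert key_224 != key_448  # same files, different target shape -> different cache file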
 
@@ -338,30 +356,32 @@ class ImageMmap(Dataset):
      # --------------------------------------------------------------------- #
      # Build phase (only on first run)
      # --------------------------------------------------------------------- #
-     def _build_cache_with_lock(self, current_hash: str, num_workers: int = None) -> None:
+     def _build_cache_with_lock(
+         self, current_hash: str, num_workers: int = None
+     ) -> None:
          """Build cache with lock file to prevent concurrent disk writes"""
          import fcntl
-
+
          self.mmap_path.parent.mkdir(parents=True, exist_ok=True)
-
+
          # Try to acquire lock file
          lock_fd = None
          try:
-             lock_fd = open(self.lock_path, 'w')
+             lock_fd = open(self.lock_path, "w")
              fcntl.flock(lock_fd.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
-
+
              # We got the lock, build the cache
              self._build_cache(current_hash, num_workers)
-
+
          except BlockingIOError:
              # Another process is building, wait for it
              print("Another process is building the cache, waiting...")
              if lock_fd:
                  lock_fd.close()
-             lock_fd = open(self.lock_path, 'w')
+             lock_fd = open(self.lock_path, "w")
              fcntl.flock(lock_fd.fileno(), fcntl.LOCK_EX)  # Wait for lock
              print("Cache built by another process!")
-
+
          finally:
              if lock_fd:
                  lock_fd.close()
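
Note: `_build_cache_with_lock` is the standard try-nonblocking-then-block `fcntl.flock` dance: exactly one process wins the non-blocking lock and builds; the rest catch BlockingIOError and re-acquire with a blocking lock, which only returns after the builder releases it. A standalone sketch of the pattern (POSIX only; the lock path and `build` callable are arbitrary examples):

import fcntl

def build_once_across_processes(lock_path: str, build) -> None:
    lock_fd = open(lock_path, "w")
    try:
        # Non-blocking attempt: the winner proceeds, losers raise BlockingIOError.
        fcntl.flock(lock_fd.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
        build()
    except BlockingIOError:
        # Another process is building; block until its LOCK_EX is released.
        fcntl.flock(lock_fd.fileno(), fcntl.LOCK_EX)
    finally:
        lock_fd.close()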
@@ -370,36 +390,38 @@ class ImageMmap(Dataset):
                  self.lock_path.unlink()
              except:
                  pass
-
+
      def _build_cache(self, current_hash: str, num_workers: int = None) -> None:
          from tqdm import tqdm
-
+
          # Pre-allocate the file with the required size
          total_bytes = np.prod(self.shape) * self.dtype.itemsize
          print(f"Pre-allocating {total_bytes / (1024**3):.2f} GB for mmap file...")
-         with open(self.mmap_path, 'wb') as f:
+         with open(self.mmap_path, "wb") as f:
              f.seek(total_bytes - 1)
-             f.write(b'\0')
-
+             f.write(b"\0")
+
          mm = np.memmap(
              self.mmap_path,
              dtype=self.dtype,
-             mode='r+',
+             mode="r+",
              shape=self.shape,
          )
-
+
          # Process images in batches to avoid memory explosion
-         batch_size = 4096
+         batch_size = 40960
          num_batches = (self.n + batch_size - 1) // batch_size
-
-         print(f"Loading {self.n} images in {num_batches} batches of up to {batch_size} images...")
-
-         with tqdm(total=self.n, desc='Processing images', unit='img') as pbar:
+
+         print(
+             f"Loading {self.n} images in {num_batches} batches of up to {batch_size} images..."
+         )
+
+         with tqdm(total=self.n, desc="Processing images", unit="img") as pbar:
              for batch_idx in range(num_batches):
                  start_idx = batch_idx * batch_size
                  end_idx = min(start_idx + batch_size, self.n)
                  batch_paths = self.img_paths[start_idx:end_idx]
-
+
                  # Load one batch at a time
                  images_dict = read_images(
                      batch_paths,
@@ -407,37 +429,38 @@ class ImageMmap(Dataset):
                      batch_size=32,
                      num_threads=num_workers or max(1, cpu_count() - 1),
                  )
-
+
                  # Write batch to mmap
                  for local_idx, path in enumerate(batch_paths):
                      global_idx = start_idx + local_idx
                      img = images_dict.get(path)
-
+
                      if img is None:
                          if self.safe:
-                             raise ValueError(f'Failed to load image: {path}')
+                             raise ValueError(f"Failed to load image: {path}")
                          else:
                              # Failed to load, write zeros
-                             print(f'Warning: Failed to load {path}, using zeros')
+                             print(f"Warning: Failed to load {path}, using zeros")
                              mm[global_idx] = np.zeros(
-                                 (self.H, self.W, self.C),
-                                 dtype=self.dtype
+                                 (self.H, self.W, self.C), dtype=self.dtype
                              )
                      else:
-                         # Ensure correct dtype
+                         # Clip to valid range and ensure correct dtype
+                         if self.dtype == np.uint8:
+                             img = np.clip(img, 0, 255)
                          if img.dtype != self.dtype:
                              img = img.astype(self.dtype)
                          mm[global_idx] = img
-
+
                      pbar.update(1)
-
+
                  # Flush after each batch and clear memory
                  mm.flush()
                  del images_dict
-
+
          mm.flush()
          del mm  # ensure descriptor is closed
-
+
          # Save hash file
          self.hash_path.write_text(current_hash)
          print(f"Mmap cache built successfully! Hash saved to {self.hash_path}")
@@ -456,16 +479,17 @@ class ImageMmap(Dataset):
      def __getitem__(self, idx: int) -> np.ndarray:
          # At runtime: this is just a mmap read
          return np.array(self.data[idx])  # copy to normal ndarray
-
+
      def imread(self, image_path: str | os.PathLike) -> np.ndarray:
          idx = self.imgpath2idx.get(str(image_path))
          if idx is None:
              raise ValueError(f"Image path {image_path} not found in dataset")
-         img = np.array(self.data[idx]) # copy to normal ndarray
+         img = np.array(self.data[idx])  # copy to normal ndarray
          summary = img.sum()
          assert summary > 0, f"Image at {image_path} appears to be all zeros"
          return img
 
+
  class ImageMmapDynamic(Dataset):
      """
      Dynamic-shape mmap dataset.
@@ -488,20 +512,20 @@ class ImageMmapDynamic(Dataset):
          self.imgpath2idx = {p: i for i, p in enumerate(self.img_paths)}
          self.n = len(self.img_paths)
          if self.n == 0:
-             raise ValueError('Cannot create ImageMmapDynamic with empty img_paths list')
+             raise ValueError("Cannot create ImageMmapDynamic with empty img_paths list")
 
          self.dtype = np.dtype(dtype)
          self.safe = safe
 
          # Default path if not provided
          if mmap_path is None:
-             hash_idx = identify(''.join(self.img_paths))
-             mmap_path = Path('.cache') / f'mmap_dynamic_{hash_idx}.dat'
+             hash_idx = identify("".join(self.img_paths))
+             mmap_path = Path(".cache") / f"mmap_dynamic_{hash_idx}.dat"
 
          self.mmap_path = Path(mmap_path)
-         self.meta_path = Path(str(self.mmap_path) + '.meta')
-         self.hash_path = Path(str(self.mmap_path) + '.hash')
-         self.lock_path = Path(str(self.mmap_path) + '.lock')
+         self.meta_path = Path(str(self.mmap_path) + ".meta")
+         self.hash_path = Path(str(self.mmap_path) + ".hash")
+         self.lock_path = Path(str(self.mmap_path) + ".lock")
 
          # Hash of the path list to detect changes
          current_hash = identify(self.img_paths)
@@ -509,40 +533,42 @@ class ImageMmapDynamic(Dataset):
 
          if not self.mmap_path.exists() or not self.meta_path.exists():
              needs_rebuild = True
-             print('Dynamic mmap or meta file does not exist, building cache...')
+             print("Dynamic mmap or meta file does not exist, building cache...")
          elif not self.hash_path.exists():
              needs_rebuild = True
-             print('Hash file does not exist for dynamic mmap, rebuilding cache...')
+             print("Hash file does not exist for dynamic mmap, rebuilding cache...")
          else:
              stored_hash = self.hash_path.read_text().strip()
              if stored_hash != current_hash:
                  needs_rebuild = True
                  print(
-                     f'Dynamic mmap hash mismatch '
-                     f'(stored: {stored_hash[:16]}..., current: {current_hash[:16]}...), '
-                     'rebuilding cache...'
+                     f"Dynamic mmap hash mismatch "
+                     f"(stored: {stored_hash[:16]}..., current: {current_hash[:16]}...), "
+                     "rebuilding cache..."
                  )
              else:
                  # Check size vs meta
                  import json
 
                  try:
-                     with open(self.meta_path, 'r') as f:
+                     with open(self.meta_path, "r") as f:
                          meta = json.load(f)
-                     meta_dtype = np.dtype(meta.get('dtype', 'uint8'))
-                     total_elems = int(meta['total_elems'])
+                     meta_dtype = np.dtype(meta.get("dtype", "uint8"))
+                     total_elems = int(meta["total_elems"])
                      expected_bytes = total_elems * meta_dtype.itemsize
                      actual_bytes = self.mmap_path.stat().st_size
                      if actual_bytes != expected_bytes:
                          needs_rebuild = True
                          print(
-                             'Dynamic mmap file size mismatch '
-                             f'(expected: {expected_bytes}, got: {actual_bytes}), '
-                             'rebuilding cache...'
+                             "Dynamic mmap file size mismatch "
+                             f"(expected: {expected_bytes}, got: {actual_bytes}), "
+                             "rebuilding cache..."
                          )
                  except Exception as e:
                      needs_rebuild = True
-                     print(f'Failed to read dynamic mmap meta ({e}), rebuilding cache...')
+                     print(
+                         f"Failed to read dynamic mmap meta ({e}), rebuilding cache..."
+                     )
 
          if needs_rebuild:
              self._build_cache_with_lock(current_hash)
@@ -552,7 +578,7 @@ class ImageMmapDynamic(Dataset):
          self.data = np.memmap(
              self.mmap_path,
              dtype=self.dtype,
-             mode='r',
+             mode="r",
              shape=(self.total_elems,),
          )
 
@@ -567,21 +593,21 @@ class ImageMmapDynamic(Dataset):
          try:
              import fcntl  # POSIX only, same as ImageMmap
 
-             lock_fd = open(self.lock_path, 'w')
+             lock_fd = open(self.lock_path, "w")
              fcntl.flock(lock_fd.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
 
              # We got the lock -> build cache
              self._build_cache(current_hash)
          except BlockingIOError:
              # Another process is building -> wait
-             print('Another process is building the dynamic mmap cache, waiting...')
+             print("Another process is building the dynamic mmap cache, waiting...")
              if lock_fd:
                  lock_fd.close()
-             lock_fd = open(self.lock_path, 'w')
+             lock_fd = open(self.lock_path, "w")
              import fcntl as _fcntl
 
              _fcntl.flock(lock_fd.fileno(), _fcntl.LOCK_EX)  # block until released
-             print('Dynamic mmap cache built by another process!')
+             print("Dynamic mmap cache built by another process!")
          finally:
              if lock_fd:
                  lock_fd.close()
@@ -591,7 +617,7 @@ class ImageMmapDynamic(Dataset):
              except Exception:
                  pass
 
-     def _build_cache(self, current_hash: str) -> None:
+     def _build_cache(self, current_hash: str, batch_size: int = 4096) -> None:
          """
          Build the flat mmap + .meta file.
 
@@ -602,19 +628,19 @@ class ImageMmapDynamic(Dataset):
          from tqdm import tqdm
          import json
 
-         print(f'Building dynamic mmap cache for {self.n} images...')
+         print(f"Building dynamic mmap cache for {self.n} images...")
          # We don't know total size up front -> write sequentially
          offsets = np.zeros(self.n, dtype=np.int64)
          shapes = np.zeros((self.n, 3), dtype=np.int64)
 
-         batch_size = 4096
          num_batches = (self.n + batch_size - 1) // batch_size
 
          current_offset = 0  # in elements, not bytes
 
-         with open(self.mmap_path, 'wb') as f, tqdm(
-             total=self.n, desc='Processing images (dynamic)', unit='img'
-         ) as pbar:
+         with (
+             open(self.mmap_path, "wb") as f,
+             tqdm(total=self.n, desc="Processing images (dynamic)", unit="img") as pbar,
+         ):
              for batch_idx in range(num_batches):
                  start_idx = batch_idx * batch_size
                  end_idx = min(start_idx + batch_size, self.n)
@@ -623,7 +649,7 @@ class ImageMmapDynamic(Dataset):
                  images_dict = read_images(
                      batch_paths,
                      hw=None,  # keep original size
-                     batch_size=32,
+                     batch_size=128,
                      num_threads=max(1, cpu_count() - 1),
                  )
 
@@ -633,20 +659,23 @@ class ImageMmapDynamic(Dataset):
 
                  if img is None:
                      if self.safe:
-                         raise ValueError(f'Failed to load image: {path}')
+                         raise ValueError(f"Failed to load image: {path}")
                      else:
                          print(
-                             f'Warning: Failed to load {path}, storing 1x1x3 zeros'
+                             f"Warning: Failed to load {path}, storing 1x1x3 zeros"
                          )
                          img = np.zeros((1, 1, 3), dtype=self.dtype)
 
+                 # Clip to valid range for uint8
+                 if self.dtype == np.uint8:
+                     img = np.clip(img, 0, 255)
                  if img.dtype != self.dtype:
                      img = img.astype(self.dtype)
 
                  if img.ndim != 3:
                      raise ValueError(
-                         f'Expected image with 3 dims (H,W,C), got shape {img.shape} '
-                         f'for path {path}'
+                         f"Expected image with 3 dims (H,W,C), got shape {img.shape} "
+                         f"for path {path}"
                      )
 
                  h, w, c = img.shape
@@ -663,22 +692,22 @@ class ImageMmapDynamic(Dataset):
          self.total_elems = total_elems
 
          meta = {
-             'version': 1,
-             'dtype': self.dtype.name,
-             'n': self.n,
-             'paths': self.img_paths,
-             'offsets': offsets.tolist(),
-             'shapes': shapes.tolist(),
-             'total_elems': total_elems,
+             "version": 1,
+             "dtype": self.dtype.name,
+             "n": self.n,
+             "paths": self.img_paths,
+             "offsets": offsets.tolist(),
+             "shapes": shapes.tolist(),
+             "total_elems": total_elems,
          }
 
-         with open(self.meta_path, 'w') as mf:
+         with open(self.meta_path, "w") as mf:
              json.dump(meta, mf)
 
          self.hash_path.write_text(current_hash)
          print(
-             f'Dynamic mmap cache built successfully! '
-             f'Meta saved to {self.meta_path}, total_elems={total_elems}'
+             f"Dynamic mmap cache built successfully! "
+             f"Meta saved to {self.meta_path}, total_elems={total_elems}"
          )
 
      # ------------------------------------------------------------------ #
@@ -687,18 +716,18 @@ class ImageMmapDynamic(Dataset):
      def _load_metadata(self) -> None:
          import json
 
-         with open(self.meta_path, 'r') as f:
+         with open(self.meta_path, "r") as f:
              meta = json.load(f)
 
          # If paths order changed without hash mismatch, this will still keep
          # the meta-consistent order (but hash comparison should prevent that).
-         self.img_paths = [str(p) for p in meta['paths']]
+         self.img_paths = [str(p) for p in meta["paths"]]
          self.imgpath2idx = {p: i for i, p in enumerate(self.img_paths)}
-         self.n = int(meta['n'])
-         self.dtype = np.dtype(meta.get('dtype', 'uint8'))
-         self.offsets = np.asarray(meta['offsets'], dtype=np.int64)
-         self.shapes = np.asarray(meta['shapes'], dtype=np.int64)
-         self.total_elems = int(meta['total_elems'])
+         self.n = int(meta["n"])
+         self.dtype = np.dtype(meta.get("dtype", "uint8"))
+         self.offsets = np.asarray(meta["offsets"], dtype=np.int64)
+         self.shapes = np.asarray(meta["shapes"], dtype=np.int64)
+         self.total_elems = int(meta["total_elems"])
 
          assert len(self.offsets) == self.n
          assert self.shapes.shape == (self.n, 3)
@@ -725,11 +754,18 @@ class ImageMmapDynamic(Dataset):
      def imread(self, image_path: str | os.PathLike) -> np.ndarray:
          idx = self.imgpath2idx.get(str(image_path))
          if idx is None:
-             raise ValueError(f'Image path {image_path} not found in dynamic dataset')
+             raise ValueError(f"Image path {image_path} not found in dynamic dataset")
          img = self[idx]
          if self.safe:
              summary = img.sum()
-             assert summary > 0, f'Image at {image_path} appears to be all zeros'
+             assert summary > 0, f"Image at {image_path} appears to be all zeros"
          return img
 
- __all__ = ['read_images', 'read_images_cpu', 'read_images_gpu', 'ImageMmap', 'ImageMmapDynamic']
+
+ __all__ = [
+     "read_images",
+     "read_images_cpu",
+     "read_images_gpu",
+     "ImageMmap",
+     "ImageMmapDynamic",
+ ]
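
Note: unlike `ImageMmap`, the dynamic variant packs variable-size images back to back in one flat array and records per-image element `offsets` and `shapes` in the `.meta` JSON. A hedged sketch of how one record is sliced back out, mirroring the layout the meta describes (toy stand-ins for the memmap and meta arrays; names are illustrative):

import numpy as np

def read_record(flat, offsets, shapes, idx):
    """Recover image idx from the flat buffer and restore its (H, W, C) shape."""
    h, w, c = shapes[idx]
    start = offsets[idx]
    return flat[start:start + h * w * c].reshape(h, w, c)

# A 2x2x3 image followed by a 1x1x3 image, flattened:
flat = np.arange(2 * 2 * 3 + 1 * 1 * 3, dtype=np.uint8)
offsets = np.array([0, 12])
shapes = np.array([[2, 2, 3], [1, 1, 3]])
print(read_record(flat, offsets, shapes, 1))  # -> [[[12 13 14]]]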
vision_utils/plot.py CHANGED
@@ -1,12 +1,12 @@
  from pathlib import Path
  from typing import TYPE_CHECKING, Any, List, Optional, Tuple, Union
 
- from speedy_utils.__imports import np, plt
+ import numpy as np
+ import matplotlib.pyplot as plt
 
 
  if TYPE_CHECKING:
 
-     import lazy_loader as lazy
      import matplotlib.pyplot as plt
      import numpy as np
      import torch
speedy_utils/all.py DELETED
@@ -1,129 +0,0 @@
- # from speedy_utils import ( # Clock module; Function decorators; Cache utilities; IO utilities; Misc utilities; Print utilities; Multi-worker processing
- #     Clock,
- #     convert_to_builtin_python,
- #     display_pretty_table_html,
- #     dump_json_or_pickle,
- #     dump_jsonl,
- #     flatten_dict,
- #     flatten_list,
- #     fprint,
- #     get_arg_names,
- #     identify,
- #     identify_uuid,
- #     is_notebook,
- #     jdumps,
- #     jloads,
- #     load_by_ext,
- #     load_json_or_pickle,
- #     load_jsonl,
- #     log,
- #     memoize,
- #     mkdir_or_exist,
- #     multi_process,
- #     multi_thread,
- #     print_table,
- #     retry_runtime,
- #     setup_logger,
- #     speedy_timer,
- #     timef,
- # )
-
- # from .__imports import *
-
-
- # choice = random.choice
-
- # # Define __all__ explicitly with all exports
- # __all__ = [
- #     # Standard library
- #     'random',
- #     'copy',
- #     'functools',
- #     'gc',
- #     'inspect',
- #     'json',
- #     'multiprocessing',
- #     'os',
- #     'osp',
- #     'pickle',
- #     'pprint',
- #     're',
- #     'sys',
- #     'textwrap',
- #     'threading',
- #     'time',
- #     'traceback',
- #     'uuid',
- #     'Counter',
- #     'ThreadPoolExecutor',
- #     'as_completed',
- #     'glob',
- #     'Pool',
- #     'Path',
- #     'Lock',
- #     'defaultdict',
- #     # Typing
- #     'Any',
- #     'Awaitable',
- #     'Callable',
- #     'TypingCallable',
- #     'Dict',
- #     'Generic',
- #     'Iterable',
- #     'List',
- #     'Literal',
- #     'Mapping',
- #     'Optional',
- #     'Sequence',
- #     'Set',
- #     'Tuple',
- #     'Type',
- #     'TypeVar',
- #     'Union',
- #     # Third-party
- #     'pd',
- #     'xxhash',
- #     'get_ipython',
- #     'HTML',
- #     'display',
- #     'logger',
- #     'BaseModel',
- #     'tabulate',
- #     'tqdm',
- #     'np',
- #     # Clock module
- #     'Clock',
- #     'speedy_timer',
- #     'timef',
- #     # Function decorators
- #     'retry_runtime',
- #     # Cache utilities
- #     'memoize',
- #     'identify',
- #     'identify_uuid',
- #     # IO utilities
- #     'dump_json_or_pickle',
- #     'dump_jsonl',
- #     'load_by_ext',
- #     'load_json_or_pickle',
- #     'load_jsonl',
- #     'jdumps',
- #     'jloads',
- #     # Misc utilities
- #     'mkdir_or_exist',
- #     'flatten_list',
- #     'get_arg_names',
- #     'is_notebook',
- #     'convert_to_builtin_python',
- #     # Print utilities
- #     'display_pretty_table_html',
- #     'flatten_dict',
- #     'fprint',
- #     'print_table',
- #     'setup_logger',
- #     'log',
- #     # Multi-worker processing
- #     'multi_process',
- #     'multi_thread',
- #     'choice',
- # ]