speedy-utils 1.1.27__py3-none-any.whl → 1.1.28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_utils/__init__.py +16 -4
- llm_utils/chat_format/__init__.py +10 -10
- llm_utils/chat_format/display.py +33 -21
- llm_utils/chat_format/transform.py +17 -19
- llm_utils/chat_format/utils.py +6 -4
- llm_utils/group_messages.py +17 -14
- llm_utils/lm/__init__.py +6 -5
- llm_utils/lm/async_lm/__init__.py +1 -0
- llm_utils/lm/async_lm/_utils.py +10 -9
- llm_utils/lm/async_lm/async_llm_task.py +141 -137
- llm_utils/lm/async_lm/async_lm.py +48 -42
- llm_utils/lm/async_lm/async_lm_base.py +59 -60
- llm_utils/lm/async_lm/lm_specific.py +4 -3
- llm_utils/lm/base_prompt_builder.py +93 -70
- llm_utils/lm/llm.py +126 -108
- llm_utils/lm/llm_signature.py +4 -2
- llm_utils/lm/lm_base.py +72 -73
- llm_utils/lm/mixins.py +102 -62
- llm_utils/lm/openai_memoize.py +124 -87
- llm_utils/lm/signature.py +105 -92
- llm_utils/lm/utils.py +42 -23
- llm_utils/scripts/vllm_load_balancer.py +23 -30
- llm_utils/scripts/vllm_serve.py +8 -7
- llm_utils/vector_cache/__init__.py +9 -3
- llm_utils/vector_cache/cli.py +1 -1
- llm_utils/vector_cache/core.py +59 -63
- llm_utils/vector_cache/types.py +7 -5
- llm_utils/vector_cache/utils.py +12 -8
- speedy_utils/__imports.py +244 -0
- speedy_utils/__init__.py +90 -194
- speedy_utils/all.py +125 -227
- speedy_utils/common/clock.py +37 -42
- speedy_utils/common/function_decorator.py +6 -12
- speedy_utils/common/logger.py +43 -52
- speedy_utils/common/notebook_utils.py +13 -21
- speedy_utils/common/patcher.py +21 -17
- speedy_utils/common/report_manager.py +42 -44
- speedy_utils/common/utils_cache.py +152 -169
- speedy_utils/common/utils_io.py +137 -103
- speedy_utils/common/utils_misc.py +15 -21
- speedy_utils/common/utils_print.py +22 -28
- speedy_utils/multi_worker/process.py +66 -79
- speedy_utils/multi_worker/thread.py +78 -155
- speedy_utils/scripts/mpython.py +38 -36
- speedy_utils/scripts/openapi_client_codegen.py +10 -10
- {speedy_utils-1.1.27.dist-info → speedy_utils-1.1.28.dist-info}/METADATA +1 -1
- speedy_utils-1.1.28.dist-info/RECORD +57 -0
- vision_utils/README.md +202 -0
- vision_utils/__init__.py +5 -0
- vision_utils/io_utils.py +470 -0
- vision_utils/plot.py +345 -0
- speedy_utils-1.1.27.dist-info/RECORD +0 -52
- {speedy_utils-1.1.27.dist-info → speedy_utils-1.1.28.dist-info}/WHEEL +0 -0
- {speedy_utils-1.1.27.dist-info → speedy_utils-1.1.28.dist-info}/entry_points.txt +0 -0
vision_utils/io_utils.py
ADDED
|
@@ -0,0 +1,470 @@
|
|
|
1
|
+
from __future__ import annotations
|
|
2
|
+
|
|
3
|
+
# type: ignore
|
|
4
|
+
import os
|
|
5
|
+
import time
|
|
6
|
+
from pathlib import Path
|
|
7
|
+
from typing import Sequence, Tuple, TYPE_CHECKING
|
|
8
|
+
from multiprocessing import cpu_count
|
|
9
|
+
|
|
10
|
+
import numpy as np
|
|
11
|
+
from PIL import Image
|
|
12
|
+
from speedy_utils import identify
|
|
13
|
+
|
|
14
|
+
try:
|
|
15
|
+
from torch.utils.data import Dataset
|
|
16
|
+
except ImportError:
|
|
17
|
+
Dataset = object
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
if TYPE_CHECKING:
|
|
21
|
+
from nvidia.dali import fn, pipeline_def
|
|
22
|
+
from nvidia.dali import types as dali_types
|
|
23
|
+
from tqdm import tqdm
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
PathLike = str | os.PathLike
|
|
27
|
+
|
|
28
|
+
|
|
29
|
+
def _to_str_paths(paths: Sequence[PathLike]) -> list[str]:
    """Return ``paths`` as a new list of plain strings, converted with ``os.fspath``."""
    converted: list[str] = list(map(os.fspath, paths))
    return converted
|
|
31
|
+
|
|
32
|
+
|
|
33
|
+
def _validate_image(path: PathLike) -> bool:
    """
    Report whether ``path`` exists and contains a fully decodable image.

    Two passes are made with Pillow: ``verify()`` checks the file's structure,
    then the file is reopened (``verify()`` leaves the image object unusable)
    and ``load()`` decodes the actual pixel data. Any exception raised by
    either pass is treated as a corrupted image.

    Returns True if valid, False otherwise.
    """
    from PIL import Image

    target = os.fspath(path)

    # Missing files are simply invalid; no need to touch Pillow.
    if os.path.exists(target):
        try:
            with Image.open(target) as probe:
                probe.verify()
            with Image.open(target) as probe:
                probe.load()
        except Exception:
            return False
        return True
    return False
|
|
54
|
+
|
|
55
|
+
|
|
56
|
+
def read_images_cpu(
    paths: Sequence[PathLike],
    hw: tuple[int, int] | None = None,
) -> dict[str, 'np.ndarray | None']:
    """
    Load images on the CPU with Pillow.

    Every image is converted to RGB and, when ``hw`` is given, resized with
    bilinear filtering. Failures are reported with a printed warning and
    recorded as None instead of raising.

    Args:
        paths: Sequence of image file paths.
        hw: Optional (height, width) for resizing.

    Returns:
        Dict mapping each path (as a string) to a numpy array (H, W, C, RGB),
        or to None when the image could not be loaded.
    """
    import numpy as np
    from PIL import Image
    from tqdm import tqdm

    path_strings = _to_str_paths(paths)

    # Newer Pillow groups filters under Image.Resampling; older releases
    # expose them directly on Image.
    bilinear = getattr(Image, 'Resampling', Image).BILINEAR

    # Pillow's resize() takes (width, height), the reverse of ``hw``.
    size_wh = None
    if hw is not None:
        height, width = hw
        size_wh = (width, height)

    loaded: dict[str, np.ndarray | None] = {}
    for img_path in tqdm(path_strings, desc='Loading images (CPU)', unit='img'):
        try:
            with Image.open(img_path) as source:
                rgb = source.convert('RGB')
                if size_wh is not None:
                    rgb = rgb.resize(size_wh, resample=bilinear)
                loaded[img_path] = np.asarray(rgb)
        except Exception as err:
            print(f'Warning: Failed to load {img_path}: {err}')
            loaded[img_path] = None
    return loaded
|
|
96
|
+
|
|
97
|
+
|
|
98
|
+
def read_images_gpu(
    paths: Sequence[PathLike],
    batch_size: int = 32,
    num_threads: int = 4,
    hw: tuple[int, int] | None = None,
    validate: bool = False,
    device: str = 'mixed',
    device_id: int = 0,
) -> dict[str, 'np.ndarray | None']:
    """
    GPU-accelerated image reader using NVIDIA DALI.

    Returns dict mapping paths -> numpy arrays (H, W, C, RGB) or None for
    invalid images. Every input path gets an entry: paths rejected by
    validation, and paths for which DALI produced no sample, map to None.

    Args:
        paths: Sequence of image file paths.
        batch_size: Batch size for DALI processing.
        num_threads: Number of threads for DALI decoding.
        hw: Optional (height, width) for resizing.
        validate: If True, pre-validate images with Pillow (slower) and skip
            the ones that fail instead of letting DALI error out.
        device: DALI decoder device: "mixed" (default), "cpu", or "gpu".
        device_id: GPU device id.

    Raises:
        ImportError: if NVIDIA DALI is not installed (only for non-empty input).
        Any exception raised by DALI while building or running the pipeline,
        e.g. when an undecodable file is passed with ``validate=False``.
    """
    import numpy as np

    str_paths = _to_str_paths(paths)
    if not str_paths:
        return {}

    # Imported only after the empty check so an empty request never needs DALI.
    from nvidia.dali import fn, pipeline_def
    from nvidia.dali import types as dali_types
    from tqdm import tqdm

    result: dict[str, np.ndarray | None] = {}
    valid_paths: list[str] = str_paths

    # Optional validation (slow but safer)
    if validate:
        print('Validating images...')
        valid_paths = []
        for path in tqdm(str_paths, desc='Validating', unit='img'):
            if _validate_image(path):
                valid_paths.append(path)
            else:
                result[path] = None
                print(f'Warning: Skipping invalid/corrupted image: {path}')

        if not valid_paths:
            print('No valid images found.')
            return result

    resize_h, resize_w = (None, None)
    if hw is not None:
        resize_h, resize_w = hw  # (H, W)

    files_for_reader = list(valid_paths)

    @pipeline_def
    def pipe():
        # random_shuffle=False keeps samples in the order of files_for_reader,
        # which is what lets outputs be zipped back onto valid_paths below.
        jpegs, _ = fn.readers.file(
            files=files_for_reader,
            random_shuffle=False,
            name='Reader',
        )
        imgs = fn.decoders.image(jpegs, device=device, output_type=dali_types.RGB)
        if resize_h is not None and resize_w is not None:
            # DALI resize expects (resize_x=width, resize_y=height)
            imgs = fn.resize(
                imgs,
                resize_x=resize_w,
                resize_y=resize_h,
                interp_type=dali_types.INTERP_TRIANGULAR,
            )
        return imgs

    dali_pipe = pipe(
        batch_size=batch_size,
        num_threads=num_threads,
        device_id=device_id,
        prefetch_queue_depth=2,
    )
    dali_pipe.build()

    num_files = len(valid_paths)
    num_batches = (num_files + batch_size - 1) // batch_size

    decoded: list[np.ndarray] = []
    for _ in tqdm(range(num_batches), desc='Decoding (DALI)', unit='batch'):
        (out,) = dali_pipe.run()
        # Outputs of a "mixed"/"gpu" decoder live on the GPU and must be copied
        # to host memory; a CPU-side TensorList needs no transfer.
        if hasattr(out, 'as_cpu'):
            out = out.as_cpu()
        decoded.extend(np.array(out.at(i)) for i in range(len(out)))

    if len(decoded) < num_files:
        print(
            f'Warning: DALI returned fewer samples ({len(decoded)}) than expected ({num_files}).'
        )
    # The final batch may carry padding / extra samples beyond the file list.
    decoded = decoded[:num_files]

    # Map decoded images back to their paths, in reader order.
    for path, img in zip(valid_paths, decoded):
        result[path] = img
    # Paths that DALI produced no sample for still get an entry, as documented.
    for path in valid_paths[len(decoded):]:
        result[path] = None

    return result
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
def read_images(
    paths: Sequence[PathLike],
    batch_size: int = 32,
    num_threads: int = 4,
    hw: tuple[int, int] | None = None,
    validate: bool = False,
    device: str = 'mixed',
    device_id: int = 0,
) -> dict[str, 'np.ndarray | None']:
    """
    Read images with NVIDIA DALI when possible, otherwise with Pillow.

    The GPU path (``read_images_gpu``) is attempted first; if it raises for any
    reason (DALI missing, no GPU, undecodable file, ...) the error is printed
    and the whole request is re-run on the CPU (``read_images_cpu``).

    Args:
        paths: Sequence of image file paths.
        batch_size: Batch size for DALI processing (GPU only).
        num_threads: Number of threads for decoding (GPU only).
        hw: Optional (height, width) for resizing.
        validate: If True, pre-validate images before GPU processing (slower).
        device: DALI decoder device: "mixed", "cpu", or "gpu".
        device_id: GPU device id for DALI.

    Returns:
        Dict mapping paths -> numpy arrays (H, W, C, RGB) or None for invalid images.
    """
    normalized = _to_str_paths(paths)
    if len(normalized) == 0:
        return {}

    gpu_options = dict(
        batch_size=batch_size,
        num_threads=num_threads,
        hw=hw,
        validate=validate,
        device=device,
        device_id=device_id,
    )
    try:
        return read_images_gpu(normalized, **gpu_options)
    except Exception as exc:
        print(f'GPU loading failed ({exc}), falling back to CPU...')

    # Only reached when the GPU attempt raised.
    return read_images_cpu(normalized, hw=hw)
|
|
259
|
+
|
|
260
|
+
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
|
|
264
|
+
class ImageMmap(Dataset):
    """
    One-time build + read-only mmap dataset.

    - First run (no valid cache): read all img_paths -> resize -> write mmap.
    - Next runs: only read from mmap (no filesystem image reads).

    Files on disk, all derived from ``mmap_path``
    (default ``.cache/mmap_dataset_<hash>.dat``):

    - ``<mmap_path>``: raw array of shape (n, H, W, C) with dtype ``dtype``.
    - ``<mmap_path>.hash``: fingerprint of the image paths, shape and dtype.
      It is deleted before a build starts and written only after the build
      finishes, so a matching hash file marks the cache as complete.
    - ``<mmap_path>.lock``: ``fcntl`` lock file serialising builds across
      processes (POSIX only). It is intentionally left on disk.

    Args:
        img_paths: Image file paths; dataset index ``i`` is ``img_paths[i]``.
        size: Target (height, width) every image is resized to.
        mmap_path: Cache file location; derived from the paths when None.
        dtype: Element dtype stored in the mmap.
        C: Channels per stored image (images are loaded as RGB).
        safe: If True, an image that fails to load aborts the build with
            ValueError; otherwise it is stored as zeros.

    Raises:
        ValueError: if ``img_paths`` is empty, or (with ``safe``) an image
            fails to load during the build.
    """

    def __init__(
        self,
        img_paths: Sequence[str | os.PathLike],
        size: Tuple[int, int] = (224, 224),
        mmap_path: str | os.PathLike | None = None,
        dtype: np.dtype = np.uint8,
        C=3,
        safe: bool = True,
    ) -> None:
        # Materialise the paths once so a one-shot iterable is not consumed
        # twice (the index map and the path list must see the same items).
        self.img_paths = [str(p) for p in img_paths]
        # path -> dataset index; for duplicated paths the last index wins.
        self.imgpath2idx = {p: i for i, p in enumerate(self.img_paths)}
        self.H, self.W = size
        self.C = C
        self.n = len(self.img_paths)
        self.dtype = np.dtype(dtype)
        self.safe = safe

        if self.n == 0:
            raise ValueError("Cannot create ImageMmap with empty img_paths list")

        # Generate default mmap path if not provided
        if mmap_path is None:
            hash_idx = identify(''.join(self.img_paths))
            mmap_path = Path('.cache') / f'mmap_dataset_{hash_idx}.dat'

        self.mmap_path = Path(mmap_path)
        self.hash_path = Path(str(self.mmap_path) + '.hash')
        self.lock_path = Path(str(self.mmap_path) + '.lock')
        self.shape = (self.n, self.H, self.W, self.C)

        current_hash = self._cache_fingerprint()
        needs_rebuild, reason = self._cache_status(current_hash)
        if needs_rebuild:
            print(reason)
            self._build_cache_with_lock(current_hash)

        # runtime: always open read-only; any required build completed above
        self.data = np.memmap(
            self.mmap_path,
            dtype=self.dtype,
            mode="r",
            shape=self.shape,
        )

    # --------------------------------------------------------------------- #
    # Cache validity
    # --------------------------------------------------------------------- #
    def _cache_fingerprint(self) -> str:
        """Hash identifying the cache contents: the image paths plus the layout.

        Shape and dtype are part of the fingerprint because two layouts with
        the same byte count (e.g. 100x200 vs 200x100) would otherwise pass the
        size check and silently reuse each other's data.
        """
        return identify([identify(self.img_paths), repr(self.shape), self.dtype.str])

    def _cache_status(self, current_hash: str) -> tuple[bool, str]:
        """Return ``(needs_rebuild, reason)`` for the cache currently on disk."""
        if not self.mmap_path.exists():
            return True, "Mmap file does not exist, building cache..."
        if not self.hash_path.exists():
            return True, "Hash file does not exist, rebuilding cache..."

        stored_hash = self.hash_path.read_text().strip()
        if stored_hash != current_hash:
            return True, (
                f"Hash mismatch (stored: {stored_hash[:16]}..., "
                f"current: {current_hash[:16]}...), rebuilding cache..."
            )

        # Verify file size matches expected
        expected_bytes = int(np.prod(self.shape)) * self.dtype.itemsize
        actual_size = self.mmap_path.stat().st_size
        if actual_size != expected_bytes:
            return True, (
                f"Mmap file size mismatch (expected: {expected_bytes}, "
                f"got: {actual_size}), rebuilding cache..."
            )
        return False, ""

    # --------------------------------------------------------------------- #
    # Build phase (only on first run)
    # --------------------------------------------------------------------- #
    def _build_cache_with_lock(self, current_hash: str, num_workers: int | None = None) -> None:
        """Build cache with lock file to prevent concurrent disk writes.

        If another process holds the lock, wait for it. Once the lock is held,
        re-check the cache and build only if it is still not valid (e.g. the
        other process crashed mid-build).

        The lock file is never unlinked: deleting it would let a late process
        lock a fresh inode at the same path while a waiter still holds the old
        one, so two processes could write the mmap at the same time.
        """
        import fcntl

        self.mmap_path.parent.mkdir(parents=True, exist_ok=True)

        # Closing the file (end of the with-block) releases the flock.
        with open(self.lock_path, 'w') as lock_fd:
            try:
                fcntl.flock(lock_fd.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)
            except BlockingIOError:
                # Another process is building, wait for it
                print("Another process is building the cache, waiting...")
                fcntl.flock(lock_fd.fileno(), fcntl.LOCK_EX)

            # Re-check under the lock: another process may already have
            # completed the cache between our first check and now.
            needs_rebuild, _ = self._cache_status(current_hash)
            if not needs_rebuild:
                print("Cache built by another process!")
                return

            self._build_cache(current_hash, num_workers)

    def _build_cache(self, current_hash: str, num_workers: int | None = None) -> None:
        """Decode, resize and write every image into a freshly allocated mmap file.

        The hash file is removed before the data file is touched and written
        only after all images are flushed, so an interrupted build is detected
        as incomplete on the next run.
        """
        from tqdm import tqdm

        # Invalidate any previous fingerprint first.
        self.hash_path.unlink(missing_ok=True)

        # Pre-allocate the file with the required size
        total_bytes = int(np.prod(self.shape)) * self.dtype.itemsize
        print(f"Pre-allocating {total_bytes / (1024**3):.2f} GB for mmap file...")
        with open(self.mmap_path, 'wb') as f:
            f.seek(total_bytes - 1)
            f.write(b'\0')

        mm = np.memmap(
            self.mmap_path,
            dtype=self.dtype,
            mode='r+',
            shape=self.shape,
        )

        # Process images in batches to avoid memory explosion
        batch_size = 4096
        num_batches = (self.n + batch_size - 1) // batch_size
        num_threads = num_workers or max(1, cpu_count() - 1)

        print(f"Loading {self.n} images in {num_batches} batches of up to {batch_size} images...")

        with tqdm(total=self.n, desc='Processing images', unit='img') as pbar:
            for start_idx in range(0, self.n, batch_size):
                batch_paths = self.img_paths[start_idx:start_idx + batch_size]

                # Load one batch at a time
                images_dict = read_images(
                    batch_paths,
                    hw=(self.H, self.W),
                    batch_size=32,
                    num_threads=num_threads,
                )

                # Write batch to mmap
                for global_idx, path in enumerate(batch_paths, start=start_idx):
                    img = images_dict.get(path)

                    if img is None:
                        if self.safe:
                            raise ValueError(f'Failed to load image: {path}')
                        # Failed to load, write zeros
                        print(f'Warning: Failed to load {path}, using zeros')
                        mm[global_idx] = np.zeros(
                            (self.H, self.W, self.C),
                            dtype=self.dtype,
                        )
                    else:
                        # Ensure correct dtype
                        if img.dtype != self.dtype:
                            img = img.astype(self.dtype)
                        mm[global_idx] = img

                    pbar.update(1)

                # Flush after each batch and clear memory
                mm.flush()
                del images_dict

        mm.flush()
        del mm  # ensure descriptor is closed

        # Save hash file last: its presence marks the cache as complete.
        self.hash_path.write_text(current_hash)
        print(f"Mmap cache built successfully! Hash saved to {self.hash_path}")

    def _load_and_resize(self, path: str) -> np.ndarray:
        """Load ``path`` as RGB, resize it to (self.H, self.W) and return it as ``self.dtype``.

        Not called elsewhere in this module; the build goes through ``read_images``.
        """
        # Newer Pillow groups filters under Image.Resampling; older releases
        # expose them directly on Image.
        resample = getattr(Image, 'Resampling', Image).BILINEAR
        with Image.open(path) as img:
            rgb = img.convert("RGB")
        # Pillow's resize() takes (width, height).
        rgb = rgb.resize((self.W, self.H), resample)
        return np.asarray(rgb, dtype=self.dtype)

    # --------------------------------------------------------------------- #
    # Dataset API
    # --------------------------------------------------------------------- #
    def __len__(self) -> int:
        return self.n

    def __getitem__(self, idx: int) -> np.ndarray:
        # At runtime: this is just a mmap read
        return np.array(self.data[idx])  # copy to normal ndarray

    def imread(self, image_path: str | os.PathLike) -> np.ndarray:
        """Return a copy of the cached image for ``image_path``.

        Raises:
            ValueError: if the path is not part of this dataset.
            AssertionError: if the stored image is entirely zero, which is what
                a failed load leaves behind when ``safe=False``.
        """
        idx = self.imgpath2idx.get(str(image_path))
        if idx is None:
            raise ValueError(f"Image path {image_path} not found in dataset")
        img = np.array(self.data[idx])  # copy to normal ndarray
        # Raised explicitly rather than via `assert` so the check survives `python -O`.
        if not img.any():
            raise AssertionError(f"Image at {image_path} appears to be all zeros")
        return img
|
|
468
|
+
|
|
469
|
+
|
|
470
|
+
# Public API of this module: the image readers and the mmap-backed dataset.
__all__ = [
    'read_images',
    'read_images_cpu',
    'read_images_gpu',
    'ImageMmap',
]
|