speedy-utils 1.1.27__py3-none-any.whl → 1.1.29__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- llm_utils/__init__.py +16 -4
- llm_utils/chat_format/__init__.py +10 -10
- llm_utils/chat_format/display.py +33 -21
- llm_utils/chat_format/transform.py +17 -19
- llm_utils/chat_format/utils.py +6 -4
- llm_utils/group_messages.py +17 -14
- llm_utils/lm/__init__.py +6 -5
- llm_utils/lm/async_lm/__init__.py +1 -0
- llm_utils/lm/async_lm/_utils.py +10 -9
- llm_utils/lm/async_lm/async_llm_task.py +141 -137
- llm_utils/lm/async_lm/async_lm.py +48 -42
- llm_utils/lm/async_lm/async_lm_base.py +59 -60
- llm_utils/lm/async_lm/lm_specific.py +4 -3
- llm_utils/lm/base_prompt_builder.py +93 -70
- llm_utils/lm/llm.py +126 -108
- llm_utils/lm/llm_signature.py +4 -2
- llm_utils/lm/lm_base.py +72 -73
- llm_utils/lm/mixins.py +102 -62
- llm_utils/lm/openai_memoize.py +124 -87
- llm_utils/lm/signature.py +105 -92
- llm_utils/lm/utils.py +42 -23
- llm_utils/scripts/vllm_load_balancer.py +23 -30
- llm_utils/scripts/vllm_serve.py +8 -7
- llm_utils/vector_cache/__init__.py +9 -3
- llm_utils/vector_cache/cli.py +1 -1
- llm_utils/vector_cache/core.py +59 -63
- llm_utils/vector_cache/types.py +7 -5
- llm_utils/vector_cache/utils.py +12 -8
- speedy_utils/__imports.py +244 -0
- speedy_utils/__init__.py +90 -194
- speedy_utils/all.py +125 -227
- speedy_utils/common/clock.py +37 -42
- speedy_utils/common/function_decorator.py +6 -12
- speedy_utils/common/logger.py +43 -52
- speedy_utils/common/notebook_utils.py +13 -21
- speedy_utils/common/patcher.py +21 -17
- speedy_utils/common/report_manager.py +42 -44
- speedy_utils/common/utils_cache.py +152 -169
- speedy_utils/common/utils_io.py +137 -103
- speedy_utils/common/utils_misc.py +15 -21
- speedy_utils/common/utils_print.py +22 -28
- speedy_utils/multi_worker/process.py +66 -79
- speedy_utils/multi_worker/thread.py +78 -155
- speedy_utils/scripts/mpython.py +38 -36
- speedy_utils/scripts/openapi_client_codegen.py +10 -10
- {speedy_utils-1.1.27.dist-info → speedy_utils-1.1.29.dist-info}/METADATA +1 -1
- speedy_utils-1.1.29.dist-info/RECORD +57 -0
- vision_utils/README.md +202 -0
- vision_utils/__init__.py +4 -0
- vision_utils/io_utils.py +735 -0
- vision_utils/plot.py +345 -0
- speedy_utils-1.1.27.dist-info/RECORD +0 -52
- {speedy_utils-1.1.27.dist-info → speedy_utils-1.1.29.dist-info}/WHEEL +0 -0
- {speedy_utils-1.1.27.dist-info → speedy_utils-1.1.29.dist-info}/entry_points.txt +0 -0
vision_utils/io_utils.py
ADDED
@@ -0,0 +1,735 @@
from __future__ import annotations

# type: ignore
import os
import time
from pathlib import Path
from typing import Sequence, Tuple, TYPE_CHECKING
from multiprocessing import cpu_count

import numpy as np
from PIL import Image
from speedy_utils import identify

try:
    from torch.utils.data import Dataset
except ImportError:
    Dataset = object


if TYPE_CHECKING:
    from nvidia.dali import fn, pipeline_def
    from nvidia.dali import types as dali_types
    from tqdm import tqdm


PathLike = str | os.PathLike


def _to_str_paths(paths: Sequence[PathLike]) -> list[str]:
    return [os.fspath(p) for p in paths]


def _validate_image(path: PathLike) -> bool:
    """
    Validate if an image file is readable and not corrupted.
    Returns True if valid, False otherwise.
    """
    from PIL import Image

    path = os.fspath(path)

    if not os.path.exists(path):
        return False

    try:
        with Image.open(path) as img:
            img.verify()  # Verify it's a valid image
        # Re-open after verify (verify closes the file)
        with Image.open(path) as img:
            img.load()  # Actually decode the image data
        return True
    except Exception:
        return False

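# [Illustrative sketch, not part of io_utils.py: _validate_image is the
#  helper behind the validate=True paths below; it can also pre-filter a
#  folder on its own. The directory and glob pattern are hypothetical.]
#
#     from pathlib import Path
#     candidates = sorted(str(p) for p in Path('photos').glob('*.jpg'))
#     good = [p for p in candidates if _validate_image(p)]
#     print(f'{len(good)}/{len(candidates)} images decode cleanly')
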
def read_images_cpu(
    paths: Sequence[PathLike],
    hw: tuple[int, int] | None = None,
) -> dict[str, 'np.ndarray | None']:
    """
    CPU image loader using Pillow.

    Returns dict mapping paths -> numpy arrays (H, W, C, RGB) or None for invalid images.

    Args:
        paths: Sequence of image file paths.
        hw: Optional (height, width) for resizing.
    """
    import numpy as np
    from PIL import Image
    from tqdm import tqdm

    str_paths = _to_str_paths(paths)

    # Pillow < 9.1.0 exposes resampling filters directly on Image
    resample_attr = getattr(Image, 'Resampling', Image)
    resample = resample_attr.BILINEAR

    target_size = None  # Pillow expects (width, height)
    if hw is not None:
        h, w = hw
        target_size = (w, h)

    result: dict[str, 'np.ndarray | None'] = {}
    for path in tqdm(str_paths, desc='Loading images (CPU)', unit='img'):
        try:
            with Image.open(path) as img:
                img = img.convert('RGB')
                if target_size is not None:
                    img = img.resize(target_size, resample=resample)
                result[path] = np.asarray(img)
        except Exception as e:
            print(f'Warning: Failed to load {path}: {e}')
            result[path] = None
    return result

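# [Illustrative sketch, not part of io_utils.py: note that hw is
#  (height, width) while Pillow's resize takes (width, height); the
#  function performs that swap internally. File names are hypothetical.]
#
#     arrays = read_images_cpu(['a.jpg', 'b.jpg'], hw=(224, 224))
#     loaded = {p: a for p, a in arrays.items() if a is not None}
#     # each loaded value is a uint8 RGB ndarray of shape (224, 224, 3)
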
def read_images_gpu(
    paths: Sequence[PathLike],
    batch_size: int = 32,
    num_threads: int = 4,
    hw: tuple[int, int] | None = None,
    validate: bool = False,
    device: str = 'mixed',
    device_id: int = 0,
) -> dict[str, 'np.ndarray | None']:
    """
    GPU-accelerated image reader using NVIDIA DALI.

    Returns dict mapping paths -> numpy arrays (H, W, C, RGB) or None for invalid images.

    Args:
        paths: Sequence of image file paths.
        batch_size: Batch size for DALI processing.
        num_threads: Number of threads for DALI decoding.
        hw: Optional (height, width) for resizing.
        validate: If True, pre-validate images (slower).
        device: DALI decoder device: "mixed" (default), "cpu", or "gpu".
        device_id: GPU device id.
    """
    import numpy as np
    from nvidia.dali import fn, pipeline_def
    from nvidia.dali import types as dali_types

    str_paths = _to_str_paths(paths)

    if not str_paths:
        return {}

    result: dict[str, 'np.ndarray | None'] = {}
    valid_paths: list[str] = str_paths

    # Optional validation (slow but safer)
    if validate:
        from tqdm import tqdm

        print('Validating images...')
        tmp_valid: list[str] = []
        invalid_paths: list[str] = []

        for path in tqdm(str_paths, desc='Validating', unit='img'):
            if _validate_image(path):
                tmp_valid.append(path)
            else:
                invalid_paths.append(path)
                print(f'Warning: Skipping invalid/corrupted image: {path}')

        valid_paths = tmp_valid
        # pre-fill invalid paths with None
        for p in invalid_paths:
            result[p] = None

    if not valid_paths:
        print('No valid images found.')
        return result

    resize_h, resize_w = (None, None)
    if hw is not None:
        resize_h, resize_w = hw  # (H, W)

    files_for_reader = list(valid_paths)

    @pipeline_def
    def pipe():
        # Keep deterministic order to match valid_paths
        jpegs, _ = fn.readers.file(
            files=files_for_reader,
            random_shuffle=False,
            name='Reader',
        )
        imgs = fn.decoders.image(jpegs, device=device, output_type=dali_types.RGB)
        if resize_h is not None and resize_w is not None:
            # DALI resize expects (resize_x=width, resize_y=height)
            imgs_resized = fn.resize(
                imgs,
                resize_x=resize_w,
                resize_y=resize_h,
                interp_type=dali_types.INTERP_TRIANGULAR,
            )
            return imgs_resized
        return imgs

    dali_pipe = pipe(
        batch_size=batch_size,
        num_threads=num_threads,
        device_id=device_id,
        prefetch_queue_depth=2,
    )
    dali_pipe.build()

    imgs: list['np.ndarray'] = []
    num_files = len(valid_paths)
    num_batches = (num_files + batch_size - 1) // batch_size

    from tqdm import tqdm

    for _ in tqdm(range(num_batches), desc='Decoding (DALI)', unit='batch'):
        (out,) = dali_pipe.run()
        out = out.as_cpu()
        for i in range(len(out)):
            imgs.append(np.array(out.at(i)))

    # Handle possible padding / extra samples
    if len(imgs) < num_files:
        print(
            f'Warning: DALI returned fewer samples ({len(imgs)}) than expected ({num_files}).'
        )
    if len(imgs) > num_files:
        imgs = imgs[:num_files]

    # Map valid images to result
    for path, img in zip(valid_paths, imgs, strict=False):
        result[path] = img

    return result

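# [Illustrative sketch, not part of io_utils.py: with default reader
#  settings DALI fills a partial final batch by wrapping to the start of
#  the file list, which is why the code above truncates the collected
#  samples back to num_files before zipping them with valid_paths.]
#
#     arrays = read_images_gpu(paths, hw=(256, 256), batch_size=64,
#                              num_threads=8, device='mixed', device_id=0)
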
def read_images(
    paths: Sequence[PathLike],
    batch_size: int = 32,
    num_threads: int = 4,
    hw: tuple[int, int] | None = None,
    validate: bool = False,
    device: str = 'mixed',
    device_id: int = 0,
) -> dict[str, 'np.ndarray | None']:
    """
    Fast image reader that tries GPU (DALI) first, falls back to CPU (Pillow).

    Returns dict mapping paths -> numpy arrays (H, W, C, RGB) or None for invalid images.

    Args:
        paths: Sequence of image file paths.
        batch_size: Batch size for DALI processing (GPU only).
        num_threads: Number of threads for decoding (GPU only).
        hw: Optional (height, width) for resizing.
        validate: If True, pre-validate images before GPU processing (slower).
        device: DALI decoder device: "mixed", "cpu", or "gpu".
        device_id: GPU device id for DALI.
    """
    str_paths = _to_str_paths(paths)

    if not str_paths:
        return {}

    try:
        return read_images_gpu(
            str_paths,
            batch_size=batch_size,
            num_threads=num_threads,
            hw=hw,
            validate=validate,
            device=device,
            device_id=device_id,
        )
    except Exception as exc:
        print(f'GPU loading failed ({exc}), falling back to CPU...')
        return read_images_cpu(str_paths, hw=hw)

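# [Illustrative sketch, not part of io_utils.py: read_images is the
#  high-level entry point; it degrades to the Pillow path whenever DALI
#  is missing or the GPU decode raises. Paths are hypothetical.]
#
#     arrays = read_images(paths, hw=(224, 224), validate=True)
#     bad = [p for p, a in arrays.items() if a is None]
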
class ImageMmap(Dataset):
    """
    One-time build + read-only mmap dataset.

    - First run (no mmap file): read all img_paths -> resize -> write mmap.
    - Next runs: only read from mmap (no filesystem image reads).
    """

    def __init__(
        self,
        img_paths: Sequence[str | os.PathLike],
        size: Tuple[int, int] = (224, 224),
        mmap_path: str | os.PathLike | None = None,
        dtype: np.dtype = np.uint8,
        C: int = 3,
        safe: bool = True,
    ) -> None:
        self.imgpath2idx = {str(p): i for i, p in enumerate(img_paths)}
        self.img_paths = [str(p) for p in img_paths]
        self.H, self.W = size
        self.C = C
        self.n = len(self.img_paths)
        self.dtype = np.dtype(dtype)
        self.safe = safe

        # Generate default mmap path if not provided
        if mmap_path is None:
            hash_idx = identify(''.join(self.img_paths))
            mmap_path = Path('.cache') / f'mmap_dataset_{hash_idx}.dat'

        self.mmap_path = Path(mmap_path)
        self.hash_path = Path(str(self.mmap_path) + '.hash')
        self.lock_path = Path(str(self.mmap_path) + '.lock')
        self.shape = (self.n, self.H, self.W, self.C)

        if self.n == 0:
            raise ValueError('Cannot create ImageMmap with empty img_paths list')

        # Calculate hash of image paths
        current_hash = identify(self.img_paths)
        needs_rebuild = False

        if not self.mmap_path.exists():
            needs_rebuild = True
            print('Mmap file does not exist, building cache...')
        elif not self.hash_path.exists():
            needs_rebuild = True
            print('Hash file does not exist, rebuilding cache...')
        else:
            # Check if hash matches
            stored_hash = self.hash_path.read_text().strip()
            if stored_hash != current_hash:
                needs_rebuild = True
                print(
                    f'Hash mismatch (stored: {stored_hash[:16]}..., '
                    f'current: {current_hash[:16]}...), rebuilding cache...'
                )

        # Verify file size matches expected
        expected_bytes = np.prod(self.shape) * self.dtype.itemsize
        if self.mmap_path.exists():
            actual_size = self.mmap_path.stat().st_size
            if actual_size != expected_bytes:
                needs_rebuild = True
                print(
                    f'Mmap file size mismatch (expected: {expected_bytes}, '
                    f'got: {actual_size}), rebuilding cache...'
                )

        if needs_rebuild:
            self._build_cache_with_lock(current_hash)

        # runtime: always open read-only; assume cache is complete
        self.data = np.memmap(
            self.mmap_path,
            dtype=self.dtype,
            mode='r',
            shape=self.shape,
        )

    # --------------------------------------------------------------------- #
    # Build phase (only on first run)
    # --------------------------------------------------------------------- #
    def _build_cache_with_lock(self, current_hash: str, num_workers: int | None = None) -> None:
        """Build cache with lock file to prevent concurrent disk writes."""
        import fcntl

        self.mmap_path.parent.mkdir(parents=True, exist_ok=True)

        # Try to acquire lock file
        lock_fd = None
        try:
            lock_fd = open(self.lock_path, 'w')
            fcntl.flock(lock_fd.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)

            # We got the lock, build the cache
            self._build_cache(current_hash, num_workers)

        except BlockingIOError:
            # Another process is building, wait for it
            print('Another process is building the cache, waiting...')
            if lock_fd:
                lock_fd.close()
            lock_fd = open(self.lock_path, 'w')
            fcntl.flock(lock_fd.fileno(), fcntl.LOCK_EX)  # Wait for lock
            print('Cache built by another process!')

        finally:
            if lock_fd:
                lock_fd.close()
            if self.lock_path.exists():
                try:
                    self.lock_path.unlink()
                except Exception:
                    pass

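    # [Illustrative sketch, not part of io_utils.py: the method above uses
    #  the standard POSIX advisory-lock idiom. A process that loses the
    #  non-blocking race hits BlockingIOError, then blocks on a plain
    #  LOCK_EX until the winner finishes writing the cache. fcntl is
    #  POSIX-only, so the locking (and this class) will not work on Windows.]
    #
    #     import fcntl
    #     with open('cache.lock', 'w') as fd:
    #         fcntl.flock(fd.fileno(), fcntl.LOCK_EX)  # blocks until exclusive
    #         ...  # critical section
    #     # lock released when fd is closed
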
    def _build_cache(self, current_hash: str, num_workers: int | None = None) -> None:
        from tqdm import tqdm

        # Pre-allocate the file with the required size
        total_bytes = np.prod(self.shape) * self.dtype.itemsize
        print(f'Pre-allocating {total_bytes / (1024**3):.2f} GB for mmap file...')
        with open(self.mmap_path, 'wb') as f:
            f.seek(total_bytes - 1)
            f.write(b'\0')

        mm = np.memmap(
            self.mmap_path,
            dtype=self.dtype,
            mode='r+',
            shape=self.shape,
        )

        # Process images in batches to avoid memory explosion
        batch_size = 4096
        num_batches = (self.n + batch_size - 1) // batch_size

        print(f'Loading {self.n} images in {num_batches} batches of up to {batch_size} images...')

        with tqdm(total=self.n, desc='Processing images', unit='img') as pbar:
            for batch_idx in range(num_batches):
                start_idx = batch_idx * batch_size
                end_idx = min(start_idx + batch_size, self.n)
                batch_paths = self.img_paths[start_idx:end_idx]

                # Load one batch at a time
                images_dict = read_images(
                    batch_paths,
                    hw=(self.H, self.W),
                    batch_size=32,
                    num_threads=num_workers or max(1, cpu_count() - 1),
                )

                # Write batch to mmap
                for local_idx, path in enumerate(batch_paths):
                    global_idx = start_idx + local_idx
                    img = images_dict.get(path)

                    if img is None:
                        if self.safe:
                            raise ValueError(f'Failed to load image: {path}')
                        else:
                            # Failed to load, write zeros
                            print(f'Warning: Failed to load {path}, using zeros')
                            mm[global_idx] = np.zeros(
                                (self.H, self.W, self.C),
                                dtype=self.dtype,
                            )
                    else:
                        # Ensure correct dtype
                        if img.dtype != self.dtype:
                            img = img.astype(self.dtype)
                        mm[global_idx] = img

                    pbar.update(1)

                # Flush after each batch and clear memory
                mm.flush()
                del images_dict

        mm.flush()
        del mm  # ensure descriptor is closed

        # Save hash file
        self.hash_path.write_text(current_hash)
        print(f'Mmap cache built successfully! Hash saved to {self.hash_path}')

    def _load_and_resize(self, path: str) -> np.ndarray:
        img = Image.open(path).convert('RGB')
        img = img.resize((self.W, self.H), Image.BILINEAR)
        return np.asarray(img, dtype=self.dtype)

    # --------------------------------------------------------------------- #
    # Dataset API
    # --------------------------------------------------------------------- #
    def __len__(self) -> int:
        return self.n

    def __getitem__(self, idx: int) -> np.ndarray:
        # At runtime: this is just a mmap read
        return np.array(self.data[idx])  # copy to normal ndarray

    def imread(self, image_path: str | os.PathLike) -> np.ndarray:
        idx = self.imgpath2idx.get(str(image_path))
        if idx is None:
            raise ValueError(f'Image path {image_path} not found in dataset')
        img = np.array(self.data[idx])  # copy to normal ndarray
        summary = img.sum()
        assert summary > 0, f'Image at {image_path} appears to be all zeros'
        return img

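# [Illustrative sketch, not part of io_utils.py: the first construction
#  pays the decode cost once and writes .cache/mmap_dataset_<hash>.dat;
#  later runs with the same path list reattach read-only, so __getitem__
#  is a pure mmap read. Paths are hypothetical.]
#
#     ds = ImageMmap(paths, size=(224, 224))  # builds or reuses the cache
#     img = ds[0]                             # uint8 array, (224, 224, 3)
#     same = ds.imread(paths[0])              # lookup by original path
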
class ImageMmapDynamic(Dataset):
    """
    Dynamic-shape mmap dataset.

    - First run (no mmap/meta or hash mismatch): read all img_paths, keep original H/W,
      append flattened bytes sequentially into a flat mmap file.
    - Also writes a .meta file with mapping:
        img_path -> [offset, H, W, C]
    - Next runs: only open mmap + meta and do constant-time slice + reshape.
    """

    def __init__(
        self,
        img_paths: Sequence[str | os.PathLike],
        mmap_path: str | os.PathLike | None = None,
        dtype: np.dtype | str = np.uint8,
        safe: bool = True,
    ) -> None:
        self.img_paths = [str(p) for p in img_paths]
        self.imgpath2idx = {p: i for i, p in enumerate(self.img_paths)}
        self.n = len(self.img_paths)
        if self.n == 0:
            raise ValueError('Cannot create ImageMmapDynamic with empty img_paths list')

        self.dtype = np.dtype(dtype)
        self.safe = safe

        # Default path if not provided
        if mmap_path is None:
            hash_idx = identify(''.join(self.img_paths))
            mmap_path = Path('.cache') / f'mmap_dynamic_{hash_idx}.dat'

        self.mmap_path = Path(mmap_path)
        self.meta_path = Path(str(self.mmap_path) + '.meta')
        self.hash_path = Path(str(self.mmap_path) + '.hash')
        self.lock_path = Path(str(self.mmap_path) + '.lock')

        # Hash of the path list to detect changes
        current_hash = identify(self.img_paths)
        needs_rebuild = False

        if not self.mmap_path.exists() or not self.meta_path.exists():
            needs_rebuild = True
            print('Dynamic mmap or meta file does not exist, building cache...')
        elif not self.hash_path.exists():
            needs_rebuild = True
            print('Hash file does not exist for dynamic mmap, rebuilding cache...')
        else:
            stored_hash = self.hash_path.read_text().strip()
            if stored_hash != current_hash:
                needs_rebuild = True
                print(
                    f'Dynamic mmap hash mismatch '
                    f'(stored: {stored_hash[:16]}..., current: {current_hash[:16]}...), '
                    'rebuilding cache...'
                )
            else:
                # Check size vs meta
                import json

                try:
                    with open(self.meta_path, 'r') as f:
                        meta = json.load(f)
                    meta_dtype = np.dtype(meta.get('dtype', 'uint8'))
                    total_elems = int(meta['total_elems'])
                    expected_bytes = total_elems * meta_dtype.itemsize
                    actual_bytes = self.mmap_path.stat().st_size
                    if actual_bytes != expected_bytes:
                        needs_rebuild = True
                        print(
                            'Dynamic mmap file size mismatch '
                            f'(expected: {expected_bytes}, got: {actual_bytes}), '
                            'rebuilding cache...'
                        )
                except Exception as e:
                    needs_rebuild = True
                    print(f'Failed to read dynamic mmap meta ({e}), rebuilding cache...')

        if needs_rebuild:
            self._build_cache_with_lock(current_hash)

        # After build (or if cache was already OK), load meta + mmap
        self._load_metadata()
        self.data = np.memmap(
            self.mmap_path,
            dtype=self.dtype,
            mode='r',
            shape=(self.total_elems,),
        )

    # ------------------------------------------------------------------ #
    # Build phase with lock (same pattern as ImageMmap)
    # ------------------------------------------------------------------ #
    def _build_cache_with_lock(self, current_hash: str) -> None:
        """Build dynamic mmap with a lock file to prevent concurrent writes."""
        self.mmap_path.parent.mkdir(parents=True, exist_ok=True)

        lock_fd = None
        try:
            import fcntl  # POSIX only, same as ImageMmap

            lock_fd = open(self.lock_path, 'w')
            fcntl.flock(lock_fd.fileno(), fcntl.LOCK_EX | fcntl.LOCK_NB)

            # We got the lock -> build cache
            self._build_cache(current_hash)
        except BlockingIOError:
            # Another process is building -> wait
            print('Another process is building the dynamic mmap cache, waiting...')
            if lock_fd:
                lock_fd.close()
            lock_fd = open(self.lock_path, 'w')
            import fcntl as _fcntl

            _fcntl.flock(lock_fd.fileno(), _fcntl.LOCK_EX)  # block until released
            print('Dynamic mmap cache built by another process!')
        finally:
            if lock_fd:
                lock_fd.close()
            if self.lock_path.exists():
                try:
                    self.lock_path.unlink()
                except Exception:
                    pass

    def _build_cache(self, current_hash: str) -> None:
        """
        Build the flat mmap + .meta file.

        Layout:
        - data file: concatenated flattened images in path order
        - meta: JSON with offsets, shapes, dtype, total_elems, paths, n
        """
        from tqdm import tqdm
        import json

        print(f'Building dynamic mmap cache for {self.n} images...')
        # We don't know total size up front -> write sequentially
        offsets = np.zeros(self.n, dtype=np.int64)
        shapes = np.zeros((self.n, 3), dtype=np.int64)

        batch_size = 4096
        num_batches = (self.n + batch_size - 1) // batch_size

        current_offset = 0  # in elements, not bytes

        with open(self.mmap_path, 'wb') as f, tqdm(
            total=self.n, desc='Processing images (dynamic)', unit='img'
        ) as pbar:
            for batch_idx in range(num_batches):
                start_idx = batch_idx * batch_size
                end_idx = min(start_idx + batch_size, self.n)
                batch_paths = self.img_paths[start_idx:end_idx]

                images_dict = read_images(
                    batch_paths,
                    hw=None,  # keep original size
                    batch_size=32,
                    num_threads=max(1, cpu_count() - 1),
                )

                for local_idx, path in enumerate(batch_paths):
                    global_idx = start_idx + local_idx
                    img = images_dict.get(path)

                    if img is None:
                        if self.safe:
                            raise ValueError(f'Failed to load image: {path}')
                        else:
                            print(f'Warning: Failed to load {path}, storing 1x1x3 zeros')
                            img = np.zeros((1, 1, 3), dtype=self.dtype)

                    if img.dtype != self.dtype:
                        img = img.astype(self.dtype)

                    if img.ndim != 3:
                        raise ValueError(
                            f'Expected image with 3 dims (H,W,C), got shape {img.shape} '
                            f'for path {path}'
                        )

                    h, w, c = img.shape
                    shapes[global_idx] = (h, w, c)
                    offsets[global_idx] = current_offset

                    flat = img.reshape(-1)
                    f.write(flat.tobytes())

                    current_offset += flat.size
                    pbar.update(1)

        total_elems = int(current_offset)
        self.total_elems = total_elems

        meta = {
            'version': 1,
            'dtype': self.dtype.name,
            'n': self.n,
            'paths': self.img_paths,
            'offsets': offsets.tolist(),
            'shapes': shapes.tolist(),
            'total_elems': total_elems,
        }

        with open(self.meta_path, 'w') as mf:
            json.dump(meta, mf)

        self.hash_path.write_text(current_hash)
        print(
            f'Dynamic mmap cache built successfully! '
            f'Meta saved to {self.meta_path}, total_elems={total_elems}'
        )

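    # [Illustrative sketch, not part of io_utils.py: a .meta payload for two
    #  hypothetical images, 480x640 and 100x200, both RGB uint8. Offsets are
    #  in elements, not bytes, so image i occupies
    #  data[offsets[i] : offsets[i] + h * w * c].]
    #
    #     {"version": 1, "dtype": "uint8", "n": 2,
    #      "paths": ["a.jpg", "b.jpg"],
    #      "offsets": [0, 921600],
    #      "shapes": [[480, 640, 3], [100, 200, 3]],
    #      "total_elems": 981600}
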
    # ------------------------------------------------------------------ #
    # Metadata loader
    # ------------------------------------------------------------------ #
    def _load_metadata(self) -> None:
        import json

        with open(self.meta_path, 'r') as f:
            meta = json.load(f)

        # If paths order changed without hash mismatch, this will still keep
        # the meta-consistent order (but hash comparison should prevent that).
        self.img_paths = [str(p) for p in meta['paths']]
        self.imgpath2idx = {p: i for i, p in enumerate(self.img_paths)}
        self.n = int(meta['n'])
        self.dtype = np.dtype(meta.get('dtype', 'uint8'))
        self.offsets = np.asarray(meta['offsets'], dtype=np.int64)
        self.shapes = np.asarray(meta['shapes'], dtype=np.int64)
        self.total_elems = int(meta['total_elems'])

        assert len(self.offsets) == self.n
        assert self.shapes.shape == (self.n, 3)

    # ------------------------------------------------------------------ #
    # Dataset API
    # ------------------------------------------------------------------ #
    def __len__(self) -> int:
        return self.n

    def _get_flat_slice(self, idx: int) -> tuple[np.ndarray, int, int, int]:
        """Return the flat mmap view for image idx (no copy) plus its (h, w, c)."""
        offset = int(self.offsets[idx])
        h, w, c = [int(x) for x in self.shapes[idx]]
        num_elems = h * w * c
        flat = self.data[offset : offset + num_elems]
        return flat, h, w, c

    def __getitem__(self, idx: int) -> np.ndarray:
        flat, h, w, c = self._get_flat_slice(idx)
        img = np.array(flat).reshape(h, w, c)  # copy to normal ndarray
        return img

    def imread(self, image_path: str | os.PathLike) -> np.ndarray:
        idx = self.imgpath2idx.get(str(image_path))
        if idx is None:
            raise ValueError(f'Image path {image_path} not found in dynamic dataset')
        img = self[idx]
        if self.safe:
            summary = img.sum()
            assert summary > 0, f'Image at {image_path} appears to be all zeros'
        return img


__all__ = ['read_images', 'read_images_cpu', 'read_images_gpu', 'ImageMmap', 'ImageMmapDynamic']
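# [Illustrative sketch, not part of io_utils.py: ImageMmapDynamic keeps each
#  image at its native resolution, so it suits corpora with mixed sizes;
#  ImageMmap is the better fit when training needs fixed-size tensors.
#  Paths are hypothetical.]
#
#     ds = ImageMmapDynamic(paths)   # builds the flat cache + .meta once
#     img = ds.imread(paths[3])      # native-resolution (H, W, 3) uint8
#     print(len(ds), img.shape)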