flaxdiff 0.1.38__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/data/__init__.py +5 -1
- flaxdiff/data/benchmark_decord.py +443 -0
- flaxdiff/data/dataloaders.py +608 -0
- flaxdiff/data/dataset_map.py +61 -6
- flaxdiff/data/online_loader.py +779 -150
- flaxdiff/data/sources/audio_utils.py +142 -0
- flaxdiff/data/sources/av_example.py +125 -0
- flaxdiff/data/sources/av_utils.py +590 -0
- flaxdiff/data/sources/base.py +129 -0
- flaxdiff/data/sources/images.py +309 -0
- flaxdiff/data/sources/utils.py +158 -0
- flaxdiff/data/sources/videos.py +250 -0
- flaxdiff/data/sources/voxceleb2.py +412 -0
- flaxdiff/inference/__init__.py +0 -0
- flaxdiff/inference/pipeline.py +260 -0
- flaxdiff/inference/utils.py +320 -0
- flaxdiff/inputs/__init__.py +173 -0
- flaxdiff/inputs/encoders.py +98 -0
- flaxdiff/models/__init__.py +2 -1
- flaxdiff/models/attention.py +22 -16
- flaxdiff/models/autoencoder/autoencoder.py +141 -9
- flaxdiff/models/autoencoder/diffusers.py +88 -25
- flaxdiff/models/autoencoder/simple_autoenc.py +40 -8
- flaxdiff/models/common.py +8 -18
- flaxdiff/models/simple_unet.py +6 -17
- flaxdiff/models/simple_vit.py +9 -13
- flaxdiff/models/unet_3d.py +446 -0
- flaxdiff/models/unet_3d_blocks.py +505 -0
- flaxdiff/samplers/common.py +358 -96
- flaxdiff/samplers/ddim.py +44 -5
- flaxdiff/schedulers/karras.py +20 -12
- flaxdiff/trainer/__init__.py +2 -1
- flaxdiff/trainer/autoencoder_trainer.py +1 -2
- flaxdiff/trainer/diffusion_trainer.py +35 -29
- flaxdiff/trainer/general_diffusion_trainer.py +583 -0
- flaxdiff/trainer/simple_trainer.py +51 -16
- flaxdiff/utils.py +128 -57
- {flaxdiff-0.1.38.dist-info → flaxdiff-0.2.0.dist-info}/METADATA +1 -1
- flaxdiff-0.2.0.dist-info/RECORD +64 -0
- {flaxdiff-0.1.38.dist-info → flaxdiff-0.2.0.dist-info}/WHEEL +1 -1
- flaxdiff/data/datasets.py +0 -169
- flaxdiff/data/sources/gcs.py +0 -81
- flaxdiff/data/sources/tfds.py +0 -79
- flaxdiff/trainer/video_diffusion_trainer.py +0 -62
- flaxdiff-0.1.38.dist-info/RECORD +0 -50
- {flaxdiff-0.1.38.dist-info → flaxdiff-0.2.0.dist-info}/top_level.txt +0 -0
flaxdiff/data/online_loader.py (CHANGED)
```diff
@@ -1,14 +1,12 @@
 import multiprocessing
 import threading
 from multiprocessing import Queue
-# from arrayqueues.shared_arrays import ArrayQueue
-# from faster_fifo import Queue
 import time
 import albumentations as A
 import queue
 import cv2
 from functools import partial
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Tuple, Optional, Union, Callable
 
 import numpy as np
 from functools import partial
@@ -18,18 +16,42 @@ from datasets.utils.file_utils import get_datasets_user_agent
 from concurrent.futures import ThreadPoolExecutor
 import io
 import urllib
+import os
 
 import PIL.Image
-import cv2
 import traceback
 
 USER_AGENT = get_datasets_user_agent()
 
-
+
+class ResourceManager:
+    """A manager for shared resources across data loading processes."""
+
+    def __init__(self, max_queue_size: int = 32000):
+        """Initialize a resource manager.
+
+        Args:
+            max_queue_size: Maximum size of the data queue.
+        """
+        self.data_queue = Queue(max_queue_size)
+
+    def get_data_queue(self) -> Queue:
+        """Get the data queue."""
+        return self.data_queue
 
 
-def fetch_single_image(image_url, timeout=None, retries=0):
-    for _ in range(retries + 1):
+def fetch_single_image(image_url: str, timeout: Optional[int] = None, retries: int = 0) -> Optional[PIL.Image.Image]:
+    """Fetch a single image from a URL.
+
+    Args:
+        image_url: URL of the image to fetch.
+        timeout: Timeout in seconds for the request.
+        retries: Number of times to retry the request.
+
+    Returns:
+        A PIL image or None if the image couldn't be fetched.
+    """
+    for attempt in range(retries + 1):
         try:
             request = urllib.request.Request(
                 image_url,
@@ -38,38 +60,135 @@ def fetch_single_image(image_url, timeout=None, retries=0):
             )
             with urllib.request.urlopen(request, timeout=timeout) as req:
                 image = PIL.Image.open(io.BytesIO(req.read()))
-            break
-        except Exception:
-            image = None
-    return image
+            return image
+        except Exception as e:
+            if attempt < retries:
+                # Wait a bit before retrying
+                time.sleep(0.1 * (attempt + 1))
+                continue
+            # Log the error on the final attempt
+            print(f"Error fetching image {image_url}: {e}")
+            return None
+
+
+def fetch_single_video(video_url: str, timeout: Optional[int] = None, retries: int = 0,
+                       max_frames: int = 32) -> Optional[List[np.ndarray]]:
+    """Fetch a single video from a URL.
+
+    Args:
+        video_url: URL of the video to fetch.
+        timeout: Timeout in seconds for the request.
+        retries: Number of times to retry the request.
+        max_frames: Maximum number of frames to extract.
+
+    Returns:
+        A list of video frames as numpy arrays or None if the video couldn't be fetched.
+    """
+    # Create a temporary file to download the video
+    import tempfile
+
+    for attempt in range(retries + 1):
+        try:
+            with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_file:
+                tmp_path = tmp_file.name
+
+            request = urllib.request.Request(
+                video_url,
+                data=None,
+                headers={"user-agent": USER_AGENT},
+            )
+            with urllib.request.urlopen(request, timeout=timeout) as req:
+                with open(tmp_path, 'wb') as f:
+                    f.write(req.read())
+
+            # Load the video frames
+            cap = cv2.VideoCapture(tmp_path)
+            frames = []
+
+            while len(frames) < max_frames:
+                ret, frame = cap.read()
+                if not ret:
+                    break
+                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+
+            cap.release()
+
+            # Delete the temporary file
+            try:
+                os.remove(tmp_path)
+            except:
+                pass
+
+            return frames if frames else None
+
+        except Exception as e:
+            if attempt < retries:
+                # Wait a bit before retrying
+                time.sleep(0.1 * (attempt + 1))
+                continue
+            # Log the error on the final attempt
+            print(f"Error fetching video {video_url}: {e}")
+
+            # Clean up the temporary file
+            try:
+                if 'tmp_path' in locals():
+                    os.remove(tmp_path)
+            except:
+                pass
+
+    return None
 
 
 def default_image_processor(
-    image, image_shape,
-    min_image_shape=(128, 128),
-    upscale_interpolation=cv2.INTER_CUBIC,
-    downscale_interpolation=cv2.INTER_AREA,
-):
+    image: PIL.Image.Image,
+    image_shape: Tuple[int, int],
+    min_image_shape: Tuple[int, int] = (128, 128),
+    upscale_interpolation: int = cv2.INTER_CUBIC,
+    downscale_interpolation: int = cv2.INTER_AREA,
+) -> Tuple[Optional[np.ndarray], int, int]:
+    """Process an image for training.
+
+    Args:
+        image: PIL image to process.
+        image_shape: Target shape (height, width).
+        min_image_shape: Minimum acceptable shape.
+        upscale_interpolation: Interpolation method for upscaling.
+        downscale_interpolation: Interpolation method for downscaling.
+
+    Returns:
+        Tuple of (processed image, original height, original width).
+        Processed image may be None if the image couldn't be processed.
+    """
     try:
+        # Convert to numpy
         image = np.array(image)
+
+        # Check if image has 3 channels
         if len(image.shape) != 3 or image.shape[2] != 3:
             return None, 0, 0
+
         original_height, original_width = image.shape[:2]
-
+
+        # Check if the image is too small
         if min(original_height, original_width) < min(min_image_shape):
             return None, original_height, original_width
-
+
+        # Check if wrong aspect ratio
        if max(original_height, original_width) / min(original_height, original_width) > 2.4:
             return None, original_height, original_width
-
+
+        # Check if the variance is too low (likely a blank/solid color image)
         if np.std(image) < 1e-5:
             return None, original_height, original_width
-
+
+        # Choose interpolation method based on whether we're upscaling or downscaling
         downscale = max(original_width, original_height) > max(image_shape)
         interpolation = downscale_interpolation if downscale else upscale_interpolation
 
-        image = A.longest_max_size(image, max(
-            image_shape), interpolation=interpolation)
+        # Resize while keeping aspect ratio
+        image = A.longest_max_size(image, max(image_shape), interpolation=interpolation)
+
+        # Pad to target shape
         image = A.pad(
             image,
             min_height=image_shape[0],
@@ -77,30 +196,114 @@ def default_image_processor(
             border_mode=cv2.BORDER_CONSTANT,
             value=[255, 255, 255],
         )
+
         return image, original_height, original_width
+
+    except Exception as e:
+        # Log the error
+        print(f"Error processing image: {e}")
+        return None, 0, 0
+
+
+def default_video_processor(
+    frames: List[np.ndarray],
+    frame_size: int = 256,
+    min_frame_size: int = 128,
+    num_frames: int = 16,
+    upscale_interpolation: int = cv2.INTER_CUBIC,
+    downscale_interpolation: int = cv2.INTER_AREA,
+) -> Tuple[Optional[np.ndarray], int, int]:
+    """Process video frames for training.
+
+    Args:
+        frames: List of video frames as numpy arrays.
+        frame_size: Target size for each frame.
+        min_frame_size: Minimum acceptable frame size.
+        num_frames: Target number of frames.
+        upscale_interpolation: Interpolation method for upscaling.
+        downscale_interpolation: Interpolation method for downscaling.
+
+    Returns:
+        Tuple of (processed video array, original height, original width).
+        Processed video may be None if the video couldn't be processed.
+    """
+    try:
+        if not frames or len(frames) == 0:
+            return None, 0, 0
+
+        # Get dimensions of the first frame
+        first_frame = frames[0]
+        original_height, original_width = first_frame.shape[:2]
+
+        # Check if frames are too small
+        if min(original_height, original_width) < min_frame_size:
+            return None, original_height, original_width
+
+        # Sample frames evenly
+        if len(frames) < num_frames:
+            # Not enough frames, duplicate some
+            indices = np.linspace(0, len(frames) - 1, num_frames, dtype=int)
+            sampled_frames = [frames[i] for i in indices]
+        else:
+            # Sample frames evenly
+            indices = np.linspace(0, len(frames) - 1, num_frames, dtype=int)
+            sampled_frames = [frames[i] for i in indices]
+
+        # Process each frame
+        processed_frames = []
+        for frame in sampled_frames:
+            # Choose interpolation method based on whether we're upscaling or downscaling
+            downscale = max(frame.shape[1], frame.shape[0]) > frame_size
+            interpolation = downscale_interpolation if downscale else upscale_interpolation
+
+            # Resize frame
+            resized_frame = cv2.resize(frame, (frame_size, frame_size), interpolation=interpolation)
+            processed_frames.append(resized_frame)
+
+        # Stack frames into a video tensor [num_frames, height, width, channels]
+        video_tensor = np.stack(processed_frames, axis=0)
+
+        return video_tensor, original_height, original_width
+
     except Exception as e:
-        #
-
+        # Log the error
+        print(f"Error processing video: {e}")
         return None, 0, 0
 
 
-def map_sample(
-    url,
-    caption,
-    image_shape=(256, 256),
-    min_image_shape=(128, 128),
-    timeout=15,
-    retries=3,
-    upscale_interpolation=cv2.INTER_CUBIC,
-    downscale_interpolation=cv2.INTER_AREA,
-    image_processor=default_image_processor,
+def map_image_sample(
+    url: str,
+    caption: str,
+    data_queue: Queue,
+    image_shape: Tuple[int, int] = (256, 256),
+    min_image_shape: Tuple[int, int] = (128, 128),
+    timeout: int = 15,
+    retries: int = 3,
+    upscale_interpolation: int = cv2.INTER_CUBIC,
+    downscale_interpolation: int = cv2.INTER_AREA,
+    image_processor: Callable = default_image_processor,
 ):
+    """Process a single image sample and put it in the queue.
+
+    Args:
+        url: URL of the image.
+        caption: Caption for the image.
+        data_queue: Queue to put the processed sample in.
+        image_shape: Target shape for the image.
+        min_image_shape: Minimum acceptable shape.
+        timeout: Timeout for image fetching.
+        retries: Number of retries for image fetching.
+        upscale_interpolation: Interpolation method for upscaling.
+        downscale_interpolation: Interpolation method for downscaling.
+        image_processor: Function to process the image.
+    """
     try:
-        #
+        # Fetch the image
         image = fetch_single_image(url, timeout=timeout, retries=retries)
         if image is None:
             return
 
+        # Process the image
         image, original_height, original_width = image_processor(
             image, image_shape, min_image_shape=min_image_shape,
             upscale_interpolation=upscale_interpolation,
@@ -110,6 +313,7 @@ def map_sample(
         if image is None:
             return
 
+        # Put the processed sample in the queue
         data_queue.put({
             "url": url,
             "caption": caption,
@@ -117,158 +321,426 @@ def map_sample(
             "original_height": original_height,
             "original_width": original_width,
         })
+
     except Exception as e:
-        #
-
-
-
-
-
-
-
-
-def default_feature_extractor(sample):
+        # Log the error
+        print(f"Error mapping image sample {url}: {e}")
+
+
+def map_video_sample(
+    url: str,
+    caption: str,
+    data_queue: Queue,
+    frame_size: int = 256,
+    min_frame_size: int = 128,
+    num_frames: int = 16,
+    timeout: int = 30,
+    retries: int = 3,
+    upscale_interpolation: int = cv2.INTER_CUBIC,
+    downscale_interpolation: int = cv2.INTER_AREA,
+    video_processor: Callable = default_video_processor,
+):
+    """Process a single video sample and put it in the queue.
+
+    Args:
+        url: URL of the video.
+        caption: Caption for the video.
+        data_queue: Queue to put the processed sample in.
+        frame_size: Target size for each frame.
+        min_frame_size: Minimum acceptable frame size.
+        num_frames: Target number of frames.
+        timeout: Timeout for video fetching.
+        retries: Number of retries for video fetching.
+        upscale_interpolation: Interpolation method for upscaling.
+        downscale_interpolation: Interpolation method for downscaling.
+        video_processor: Function to process the video.
+    """
+    try:
+        # Fetch the video frames
+        frames = fetch_single_video(url, timeout=timeout, retries=retries, max_frames=num_frames*2)
+        if frames is None or len(frames) == 0:
+            return
+
+        # Process the video
+        video, original_height, original_width = video_processor(
+            frames, frame_size, min_frame_size=min_frame_size,
+            num_frames=num_frames,
+            upscale_interpolation=upscale_interpolation,
+            downscale_interpolation=downscale_interpolation,
+        )
+
+        if video is None:
+            return
+
+        # Put the processed sample in the queue
+        data_queue.put({
+            "url": url,
+            "caption": caption,
+            "video": video,
+            "original_height": original_height,
+            "original_width": original_width,
+        })
+
+    except Exception as e:
+        # Log the error
+        print(f"Error mapping video sample {url}: {e}")
+
+
+def default_feature_extractor(sample: Dict[str, Any]) -> Dict[str, Any]:
+    """Extract features from a sample.
+
+    Args:
+        sample: Sample to extract features from.
+
+    Returns:
+        Dictionary with extracted url and caption.
+    """
+    # Extract URL
     url = None
-    if "url" in sample:
-        url = sample["url"]
-    elif "URL" in sample:
-        url = sample["URL"]
-    elif "image_url" in sample:
-        url = sample["image_url"]
-    else:
-        print("No url found in sample, skipping", sample.keys())
+    for key in ["url", "URL", "image_url", "video_url"]:
+        if key in sample:
+            url = sample[key]
+            break
+
+    if url is None:
+        print("No URL found in sample, keys:", sample.keys())
+        return {"url": None, "caption": None}
 
+    # Extract caption
     caption = None
-    if "caption" in sample:
-        caption = sample["caption"]
-    elif "CAPTION" in sample:
-        caption = sample["CAPTION"]
-    elif "txt" in sample:
-        caption = sample["txt"]
-    elif "TEXT" in sample:
-        caption = sample["TEXT"]
-    elif "text" in sample:
-        caption = sample["text"]
-    else:
-        print("No caption found in sample, skipping", sample.keys())
+    for key in ["caption", "CAPTION", "txt", "TEXT", "text"]:
+        if key in sample and sample[key] is not None:
+            caption = sample[key]
+            break
+
+    if caption is None:
+        caption = "No caption available"
 
     return {
         "url": url,
         "caption": caption,
     }
 
+
 def map_batch(
-    batch, num_threads=256, image_shape=(256, 256),
-    min_image_shape=(128, 128),
-    timeout=15, retries=3, image_processor=default_image_processor,
-    upscale_interpolation=cv2.INTER_CUBIC,
-    downscale_interpolation=cv2.INTER_AREA,
-    feature_extractor=default_feature_extractor,
+    batch: Dict[str, Any],
+    data_queue: Queue,
+    media_type: str = "image",
+    num_threads: int = 256,
+    image_shape: Tuple[int, int] = (256, 256),
+    min_image_shape: Tuple[int, int] = (128, 128),
+    frame_size: int = 256,
+    min_frame_size: int = 128,
+    num_frames: int = 16,
+    timeout: int = 15,
+    retries: int = 3,
+    image_processor: Callable = default_image_processor,
+    video_processor: Callable = default_video_processor,
+    upscale_interpolation: int = cv2.INTER_CUBIC,
+    downscale_interpolation: int = cv2.INTER_AREA,
+    feature_extractor: Callable = default_feature_extractor,
 ):
+    """Map a batch of samples and process them in parallel.
+
+    Args:
+        batch: Batch of samples to process.
+        data_queue: Queue to put processed samples in.
+        media_type: Type of media ("image" or "video").
+        num_threads: Number of threads to use for processing.
+        image_shape: Target shape for images.
+        min_image_shape: Minimum acceptable shape for images.
+        frame_size: Target size for video frames.
+        min_frame_size: Minimum acceptable size for video frames.
+        num_frames: Target number of frames for videos.
+        timeout: Timeout for fetching.
+        retries: Number of retries for fetching.
+        image_processor: Function to process images.
+        video_processor: Function to process videos.
+        upscale_interpolation: Interpolation method for upscaling.
+        downscale_interpolation: Interpolation method for downscaling.
+        feature_extractor: Function to extract features from samples.
+    """
     try:
-        map_sample_fn = partial(map_sample, image_shape=image_shape,
-                                min_image_shape=min_image_shape,
-                                timeout=timeout, retries=retries,
-                                image_processor=image_processor,
-                                upscale_interpolation=upscale_interpolation,
-                                downscale_interpolation=downscale_interpolation)
+        # Choose mapping function based on media type
+        if media_type == "video":
+            map_func = partial(
+                map_video_sample,
+                data_queue=data_queue,
+                frame_size=frame_size,
+                min_frame_size=min_frame_size,
+                num_frames=num_frames,
+                timeout=timeout,
+                retries=retries,
+                video_processor=video_processor,
+                upscale_interpolation=upscale_interpolation,
+                downscale_interpolation=downscale_interpolation,
+            )
+        else:  # Default to image
+            map_func = partial(
+                map_image_sample,
+                data_queue=data_queue,
+                image_shape=image_shape,
+                min_image_shape=min_image_shape,
+                timeout=timeout,
+                retries=retries,
+                image_processor=image_processor,
+                upscale_interpolation=upscale_interpolation,
+                downscale_interpolation=downscale_interpolation,
+            )
+
+        # Extract features from batch
+        features = feature_extractor(batch)
+        urls, captions = features["url"], features["caption"]
+
+        if urls is None or captions is None:
+            return
+
+        # Process samples in parallel
         with ThreadPoolExecutor(max_workers=num_threads) as executor:
-            features = feature_extractor(batch)
-            url, caption = features["url"], features["caption"]
-            executor.map(map_sample_fn, url, caption)
+            executor.map(map_func, urls, captions)
+
     except Exception as e:
-
+        # Log the error
+        print(f"Error mapping batch: {e}")
         traceback.print_exc()
-
-
-def parallel_image_loader(
-    dataset: Dataset,
-    num_workers: int = 8,
-    image_shape=(256, 256),
-    min_image_shape=(128, 128),
-    num_threads=256,
-    timeout=15,
-    retries=3,
-    image_processor=default_image_processor,
-    upscale_interpolation=cv2.INTER_CUBIC,
-    downscale_interpolation=cv2.INTER_AREA,
-    feature_extractor=default_feature_extractor,
+
+
+def parallel_media_loader(
+    dataset: Dataset,
+    data_queue: Queue,
+    media_type: str = "image",
+    num_workers: int = 8,
+    image_shape: Tuple[int, int] = (256, 256),
+    min_image_shape: Tuple[int, int] = (128, 128),
+    frame_size: int = 256,
+    min_frame_size: int = 128,
+    num_frames: int = 16,
+    num_threads: int = 256,
+    timeout: int = 15,
+    retries: int = 3,
+    image_processor: Callable = default_image_processor,
+    video_processor: Callable = default_video_processor,
+    upscale_interpolation: int = cv2.INTER_CUBIC,
+    downscale_interpolation: int = cv2.INTER_AREA,
+    feature_extractor: Callable = default_feature_extractor,
 ):
+    """Load and process media from a dataset in parallel.
+
+    Args:
+        dataset: Dataset to load from.
+        data_queue: Queue to put processed samples in.
+        media_type: Type of media ("image" or "video").
+        num_workers: Number of worker processes.
+        image_shape: Target shape for images.
+        min_image_shape: Minimum acceptable shape for images.
+        frame_size: Target size for video frames.
+        min_frame_size: Minimum acceptable size for video frames.
+        num_frames: Target number of frames for videos.
+        num_threads: Number of threads per worker.
+        timeout: Timeout for fetching.
+        retries: Number of retries for fetching.
+        image_processor: Function to process images.
+        video_processor: Function to process videos.
+        upscale_interpolation: Interpolation method for upscaling.
+        downscale_interpolation: Interpolation method for downscaling.
+        feature_extractor: Function to extract features from samples.
+    """
+    # Create mapping function
     map_batch_fn = partial(
-        map_batch, num_threads=num_threads, image_shape=image_shape,
+        map_batch,
+        data_queue=data_queue,
+        media_type=media_type,
+        num_threads=num_threads,
+        image_shape=image_shape,
         min_image_shape=min_image_shape,
-        timeout=timeout, retries=retries, image_processor=image_processor,
+        frame_size=frame_size,
+        min_frame_size=min_frame_size,
+        num_frames=num_frames,
+        timeout=timeout,
+        retries=retries,
+        image_processor=image_processor,
+        video_processor=video_processor,
         upscale_interpolation=upscale_interpolation,
         downscale_interpolation=downscale_interpolation,
        feature_extractor=feature_extractor
     )
+
+    # Calculate shard length
     shard_len = len(dataset) // num_workers
-    print(f"Local Shard lengths: {shard_len}")
+    print(f"Local Shard length: {shard_len}")
+
+    # Process dataset in parallel
     with multiprocessing.Pool(num_workers) as pool:
         iteration = 0
         while True:
-            #
-            shards = [dataset[i*shard_len:(i+1)*shard_len]
-                      for i in range(num_workers)]
-            print(f"mapping {len(shards)} shards")
+            # Create shards for each worker
+            shards = [dataset[i*shard_len:(i+1)*shard_len] for i in range(num_workers)]
+            print(f"Mapping {len(shards)} shards")
+
+            # Process shards in parallel
             pool.map(map_batch_fn, shards)
+
+            # Shuffle dataset for next iteration
             iteration += 1
             print(f"Shuffling dataset with seed {iteration}")
             dataset = dataset.shuffle(seed=iteration)
-            # Clear the error queue
-            # while not error_queue.empty():
-            #     error_queue.get_nowait()
 
 
-class ImageBatchIterator():
+class MediaBatchIterator:
+    """Iterator for batches of media samples."""
+
     def __init__(
-        self,
-        dataset: Dataset, batch_size: int = 64,
-        image_shape=(256, 256), min_image_shape=(128, 128),
-        num_workers: int = 8, num_threads=256, timeout=15, retries=3,
-        image_processor=default_image_processor,
-        upscale_interpolation=cv2.INTER_CUBIC,
-        downscale_interpolation=cv2.INTER_AREA, feature_extractor=default_feature_extractor,
+        self,
+        dataset: Dataset,
+        batch_size: int = 64,
+        media_type: str = "image",
+        image_shape: Tuple[int, int] = (256, 256),
+        min_image_shape: Tuple[int, int] = (128, 128),
+        frame_size: int = 256,
+        min_frame_size: int = 128,
+        num_frames: int = 16,
+        num_workers: int = 8,
+        num_threads: int = 256,
+        timeout: int = 15,
+        retries: int = 3,
+        image_processor: Callable = default_image_processor,
+        video_processor: Callable = default_video_processor,
+        upscale_interpolation: int = cv2.INTER_CUBIC,
+        downscale_interpolation: int = cv2.INTER_AREA,
+        feature_extractor: Callable = default_feature_extractor,
+        resource_manager: Optional[ResourceManager] = None,
     ):
+        """Initialize a media batch iterator.
+
+        Args:
+            dataset: Dataset to iterate over.
+            batch_size: Batch size.
+            media_type: Type of media ("image" or "video").
+            image_shape: Target shape for images.
+            min_image_shape: Minimum acceptable shape for images.
+            frame_size: Target size for video frames.
+            min_frame_size: Minimum acceptable size for video frames.
+            num_frames: Target number of frames for videos.
+            num_workers: Number of worker processes.
+            num_threads: Number of threads per worker.
+            timeout: Timeout for fetching.
+            retries: Number of retries for fetching.
+            image_processor: Function to process images.
+            video_processor: Function to process videos.
+            upscale_interpolation: Interpolation method for upscaling.
+            downscale_interpolation: Interpolation method for downscaling.
+            feature_extractor: Function to extract features from samples.
+            resource_manager: Resource manager to use. Will create one if None.
+        """
         self.dataset = dataset
-        self.num_workers = num_workers
         self.batch_size = batch_size
+        self.media_type = media_type
+
+        # Create or use resource manager
+        self.resource_manager = resource_manager or ResourceManager()
+        self.data_queue = self.resource_manager.get_data_queue()
+
+        # Start loader thread
         loader = partial(
-            parallel_image_loader,
-            num_threads=num_threads,
+            parallel_media_loader,
+            data_queue=self.data_queue,
+            media_type=media_type,
+            num_workers=num_workers,
             image_shape=image_shape,
             min_image_shape=min_image_shape,
-            num_workers=num_workers,
-            timeout=timeout, retries=retries,
+            frame_size=frame_size,
+            min_frame_size=min_frame_size,
+            num_frames=num_frames,
+            num_threads=num_threads,
+            timeout=timeout,
+            retries=retries,
            image_processor=image_processor,
+            video_processor=video_processor,
             upscale_interpolation=upscale_interpolation,
             downscale_interpolation=downscale_interpolation,
             feature_extractor=feature_extractor
         )
-        self.thread = threading.Thread(target=loader, args=(dataset,))
+
+        # Start loader in background thread
+        self.thread = threading.Thread(target=loader, args=(dataset,), daemon=True)
         self.thread.start()
 
     def __iter__(self):
         return self
 
     def __next__(self):
+        """Get the next batch of samples."""
         def fetcher(_):
-            return data_queue.get()
+            try:
+                return self.data_queue.get(timeout=60)  # Add timeout to prevent hanging
+            except:
+                # Return a dummy sample on timeout
+                if self.media_type == "video":
+                    return {
+                        "url": "timeout",
+                        "caption": "Timeout occurred while waiting for sample",
+                        "video": np.zeros((4, 32, 32, 3), dtype=np.uint8),
+                        "original_height": 32,
+                        "original_width": 32,
+                    }
+                else:
+                    return {
+                        "url": "timeout",
+                        "caption": "Timeout occurred while waiting for sample",
+                        "image": np.zeros((32, 32, 3), dtype=np.uint8),
+                        "original_height": 32,
+                        "original_width": 32,
+                    }
+
+        # Fetch batch in parallel
         with ThreadPoolExecutor(max_workers=self.batch_size) as executor:
             batch = list(executor.map(fetcher, range(self.batch_size)))
+
         return batch
 
-    def __del__(self):
-        self.thread.join()
-
     def __len__(self):
+        """Get the number of batches in the dataset."""
         return len(self.dataset) // self.batch_size
 
 
-def default_collate(batch):
+def default_image_collate(batch):
+    """Default collate function for image batches.
+
+    Args:
+        batch: Batch of samples to collate.
+
+    Returns:
+        Collated batch.
+    """
     urls = [sample["url"] for sample in batch]
     captions = [sample["caption"] for sample in batch]
-    images = [sample["image"] for sample in batch]
+
+    # Check if all images have the same shape
+    image_shapes = [sample["image"].shape for sample in batch]
+    if len(set(str(shape) for shape in image_shapes)) > 1:
+        # Get max height and width
+        max_height = max(shape[0] for shape in image_shapes)
+        max_width = max(shape[1] for shape in image_shapes)
+
+        # Resize all images to the same shape
+        images = []
+        for sample in batch:
+            image = sample["image"]
+            height, width = image.shape[:2]
+
+            if height != max_height or width != max_width:
+                # Pad with white
+                padded_image = np.ones((max_height, max_width, 3), dtype=image.dtype) * 255
+                padded_image[:height, :width] = image
+                images.append(padded_image)
+            else:
+                images.append(image)
+
+        images = np.stack(images, axis=0)
+    else:
+        # All images have the same shape, just stack them
+        images = np.stack([sample["image"] for sample in batch], axis=0)
+
     return {
         "url": urls,
         "caption": captions,
@@ -276,7 +748,83 @@ def default_collate(batch):
     }
 
 
+def default_video_collate(batch):
+    """Default collate function for video batches.
+
+    Args:
+        batch: Batch of samples to collate.
+
+    Returns:
+        Collated batch.
+    """
+    urls = [sample["url"] for sample in batch]
+    captions = [sample["caption"] for sample in batch]
+
+    # Check if all videos have the same shape
+    video_shapes = [sample["video"].shape for sample in batch]
+    if len(set(str(shape) for shape in video_shapes)) > 1:
+        # Get max dimensions
+        max_frames = max(shape[0] for shape in video_shapes)
+        max_height = max(shape[1] for shape in video_shapes)
+        max_width = max(shape[2] for shape in video_shapes)
+
+        # Resize all videos to the same shape
+        videos = []
+        for sample in batch:
+            video = sample["video"]
+            num_frames, height, width = video.shape[:3]
+
+            if num_frames != max_frames or height != max_height or width != max_width:
+                # Create a new video tensor with the max dimensions
+                padded_video = np.zeros((max_frames, max_height, max_width, 3), dtype=video.dtype)
+
+                # Copy the original video frames
+                padded_video[:num_frames, :height, :width] = video
+
+                # If we need more frames, duplicate the last frame
+                if num_frames < max_frames:
+                    padded_video[num_frames:] = padded_video[num_frames-1:num_frames]
+
+                videos.append(padded_video)
+            else:
+                videos.append(video)
+
+        videos = np.stack(videos, axis=0)
+    else:
+        # All videos have the same shape, just stack them
+        videos = np.stack([sample["video"] for sample in batch], axis=0)
+
+    return {
+        "url": urls,
+        "caption": captions,
+        "video": videos,
+    }
+
+
+def get_default_collate(media_type="image"):
+    """Get the default collate function for a media type.
+
+    Args:
+        media_type: Type of media ("image" or "video").
+
+    Returns:
+        Collate function for the specified media type.
+    """
+    if media_type == "video":
+        return default_video_collate
+    else:  # Default to image
+        return default_image_collate
+
+
 def dataMapper(map: Dict[str, Any]):
+    """Create a function to map dataset samples to a standard format.
+
+    Args:
+        map: Dictionary mapping standard keys to dataset-specific keys.
+
+    Returns:
+        Function that maps a sample to the standard format.
+    """
     def _map(sample) -> Dict[str, Any]:
         return {
             "url": sample[map["url"]],
@@ -285,13 +833,19 @@ def dataMapper(map: Dict[str, Any]):
     return _map
 
 
-class OnlineStreamingDataLoader():
+class OnlineStreamingDataLoader:
+    """Data loader for streaming media data from online sources."""
+
     def __init__(
         self,
         dataset,
         batch_size=64,
+        media_type="image",
         image_shape=(256, 256),
         min_image_shape=(128, 128),
+        frame_size=256,
+        min_frame_size=128,
+        num_frames=16,
         num_workers=16,
         num_threads=512,
         default_split="all",
@@ -303,17 +857,49 @@ class OnlineStreamingDataLoader():
         global_process_count=1,
         global_process_index=0,
         prefetch=1000,
-        collate_fn=default_collate,
+        collate_fn=None,
         timeout=15,
         retries=3,
         image_processor=default_image_processor,
+        video_processor=default_video_processor,
         upscale_interpolation=cv2.INTER_CUBIC,
         downscale_interpolation=cv2.INTER_AREA,
         feature_extractor=default_feature_extractor,
+        resource_manager=None,
     ):
+        """Initialize an online streaming data loader.
+
+        Args:
+            dataset: Dataset to load from, can be a path or a dataset object.
+            batch_size: Batch size.
+            media_type: Type of media ("image" or "video").
+            image_shape: Target shape for images.
+            min_image_shape: Minimum acceptable shape for images.
+            frame_size: Target size for video frames.
+            min_frame_size: Minimum acceptable size for video frames.
+            num_frames: Target number of frames for videos.
+            num_workers: Number of worker processes.
+            num_threads: Number of threads per worker.
+            default_split: Default split to use when loading datasets.
+            pre_map_maker: Function to create a mapping function.
+            pre_map_def: Default mapping definition.
+            global_process_count: Total number of processes.
+            global_process_index: Index of this process.
+            prefetch: Number of batches to prefetch.
+            collate_fn: Function to collate samples into batches.
+            timeout: Timeout for fetching.
+            retries: Number of retries for fetching.
+            image_processor: Function to process images.
+            video_processor: Function to process videos.
+            upscale_interpolation: Interpolation method for upscaling.
+            downscale_interpolation: Interpolation method for downscaling.
+            feature_extractor: Function to extract features from samples.
+            resource_manager: Resource manager to use.
+        """
+        # Load dataset from path if needed
         if isinstance(dataset, str):
             dataset_path = dataset
-            print("Loading dataset from path")
+            print(f"Loading dataset from path: {dataset_path}")
             if "gs://" in dataset:
                 dataset = load_from_disk(dataset_path)
             else:
@@ -321,43 +907,86 @@ class OnlineStreamingDataLoader():
         elif isinstance(dataset, list):
             if isinstance(dataset[0], str):
                 print("Loading multiple datasets from paths")
-                dataset = [
-                    dataset_path if isinstance(dataset_path, Dataset) else load_dataset(dataset_path, split=default_split)
-                    for dataset_path in dataset]
+                dataset = [
+                    load_from_disk(dataset_path) if "gs://" in dataset_path
+                    else load_dataset(dataset_path, split=default_split)
+                    for dataset_path in dataset
+                ]
+            print(f"Concatenating {len(dataset)} datasets")
             dataset = concatenate_datasets(dataset)
             dataset = dataset.shuffle(seed=0)
-
+
+        # Shard dataset for distributed training
         self.dataset = dataset.shard(
             num_shards=global_process_count, index=global_process_index)
         print(f"Dataset length: {len(dataset)}")
-        self.iterator = ImageBatchIterator(self.dataset, image_shape=image_shape,
-                                           min_image_shape=min_image_shape,
-                                           num_workers=num_workers, batch_size=batch_size, num_threads=num_threads,
-                                           timeout=timeout, retries=retries, image_processor=image_processor,
-                                           upscale_interpolation=upscale_interpolation,
-                                           downscale_interpolation=downscale_interpolation,
-                                           feature_extractor=feature_extractor)
+
+        # Get or create resource manager
+        self.resource_manager = resource_manager or ResourceManager()
+
+        # Choose default collate function if not provided
+        if collate_fn is None:
+            collate_fn = get_default_collate(media_type)
+
+        # Create media batch iterator
+        self.iterator = MediaBatchIterator(
+            self.dataset,
+            batch_size=batch_size,
+            media_type=media_type,
+            image_shape=image_shape,
+            min_image_shape=min_image_shape,
+            frame_size=frame_size,
+            min_frame_size=min_frame_size,
+            num_frames=num_frames,
+            num_workers=num_workers,
+            num_threads=num_threads,
+            timeout=timeout,
+            retries=retries,
+            image_processor=image_processor,
+            video_processor=video_processor,
+            upscale_interpolation=upscale_interpolation,
+            downscale_interpolation=downscale_interpolation,
+            feature_extractor=feature_extractor,
+            resource_manager=self.resource_manager,
+        )
+
         self.batch_size = batch_size
+        self.collate_fn = collate_fn
 
-        #
+        # Create batch queue for prefetching
         self.batch_queue = queue.Queue(prefetch)
-
+
+        # Start batch loader thread
         def batch_loader():
-            for batch in self.iterator:
-                try:
-                    self.batch_queue.put(collate_fn(batch))
-                except Exception as e:
-                    print("Error collating batch", e)
-
-        self.loader_thread = threading.Thread(target=batch_loader)
+            try:
+                for batch in self.iterator:
+                    try:
+                        if batch:
+                            self.batch_queue.put(collate_fn(batch))
+                    except Exception as e:
+                        print(f"Error collating batch: {e}")
+                        traceback.print_exc()
+            except Exception as e:
+                print(f"Error in batch loader thread: {e}")
+                traceback.print_exc()
+
+        self.loader_thread = threading.Thread(target=batch_loader, daemon=True)
         self.loader_thread.start()
 
     def __iter__(self):
+        """Get an iterator for the data loader."""
         return self
 
     def __next__(self):
-        return self.batch_queue.get()
-
+        """Get the next batch."""
+        try:
+            return self.batch_queue.get(timeout=60)  # Add timeout to prevent hanging
+        except queue.Empty:
+            if not self.loader_thread.is_alive():
+                raise StopIteration("Loader thread died")
+            print("Timeout waiting for batch, retrying...")
+            return self.__next__()
 
     def __len__(self):
+        """Get the number of samples in the dataset."""
        return len(self.dataset)
```