flaxdiff 0.1.38.1__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (44)
  1. flaxdiff/data/__init__.py +5 -1
  2. flaxdiff/data/benchmark_decord.py +443 -0
  3. flaxdiff/data/dataloaders.py +608 -0
  4. flaxdiff/data/dataset_map.py +61 -6
  5. flaxdiff/data/online_loader.py +779 -150
  6. flaxdiff/data/sources/audio_utils.py +142 -0
  7. flaxdiff/data/sources/av_example.py +125 -0
  8. flaxdiff/data/sources/av_utils.py +590 -0
  9. flaxdiff/data/sources/base.py +129 -0
  10. flaxdiff/data/sources/images.py +309 -0
  11. flaxdiff/data/sources/utils.py +158 -0
  12. flaxdiff/data/sources/videos.py +250 -0
  13. flaxdiff/data/sources/voxceleb2.py +412 -0
  14. flaxdiff/inference/__init__.py +0 -0
  15. flaxdiff/inference/pipeline.py +260 -0
  16. flaxdiff/inference/utils.py +320 -0
  17. flaxdiff/inputs/__init__.py +173 -0
  18. flaxdiff/inputs/encoders.py +98 -0
  19. flaxdiff/models/__init__.py +2 -1
  20. flaxdiff/models/autoencoder/autoencoder.py +141 -9
  21. flaxdiff/models/autoencoder/diffusers.py +88 -25
  22. flaxdiff/models/autoencoder/simple_autoenc.py +40 -8
  23. flaxdiff/models/simple_unet.py +5 -5
  24. flaxdiff/models/simple_vit.py +1 -1
  25. flaxdiff/models/unet_3d.py +446 -0
  26. flaxdiff/models/unet_3d_blocks.py +505 -0
  27. flaxdiff/samplers/common.py +358 -96
  28. flaxdiff/samplers/ddim.py +44 -5
  29. flaxdiff/schedulers/karras.py +20 -12
  30. flaxdiff/trainer/__init__.py +2 -1
  31. flaxdiff/trainer/autoencoder_trainer.py +1 -2
  32. flaxdiff/trainer/diffusion_trainer.py +33 -27
  33. flaxdiff/trainer/general_diffusion_trainer.py +583 -0
  34. flaxdiff/trainer/simple_trainer.py +48 -31
  35. flaxdiff/utils.py +128 -57
  36. {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/METADATA +1 -1
  37. flaxdiff-0.2.0.dist-info/RECORD +64 -0
  38. {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/WHEEL +1 -1
  39. flaxdiff/data/datasets.py +0 -169
  40. flaxdiff/data/sources/gcs.py +0 -81
  41. flaxdiff/data/sources/tfds.py +0 -79
  42. flaxdiff/trainer/video_diffusion_trainer.py +0 -62
  43. flaxdiff-0.1.38.1.dist-info/RECORD +0 -50
  44. {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/top_level.txt +0 -0
@@ -1,14 +1,12 @@
 import multiprocessing
 import threading
 from multiprocessing import Queue
-# from arrayqueues.shared_arrays import ArrayQueue
-# from faster_fifo import Queue
 import time
 import albumentations as A
 import queue
 import cv2
 from functools import partial
-from typing import Any, Dict, List, Tuple
+from typing import Any, Dict, List, Tuple, Optional, Union, Callable
 
 import numpy as np
 from functools import partial
@@ -18,18 +16,42 @@ from datasets.utils.file_utils import get_datasets_user_agent
 from concurrent.futures import ThreadPoolExecutor
 import io
 import urllib
+import os
 
 import PIL.Image
-import cv2
 import traceback
 
 USER_AGENT = get_datasets_user_agent()
 
-data_queue = Queue(16*2000)
+
+class ResourceManager:
+    """A manager for shared resources across data loading processes."""
+
+    def __init__(self, max_queue_size: int = 32000):
+        """Initialize a resource manager.
+
+        Args:
+            max_queue_size: Maximum size of the data queue.
+        """
+        self.data_queue = Queue(max_queue_size)
+
+    def get_data_queue(self) -> Queue:
+        """Get the data queue."""
+        return self.data_queue
 
 
-def fetch_single_image(image_url, timeout=None, retries=0):
-    for _ in range(retries + 1):
+def fetch_single_image(image_url: str, timeout: Optional[int] = None, retries: int = 0) -> Optional[PIL.Image.Image]:
+    """Fetch a single image from a URL.
+
+    Args:
+        image_url: URL of the image to fetch.
+        timeout: Timeout in seconds for the request.
+        retries: Number of times to retry the request.
+
+    Returns:
+        A PIL image or None if the image couldn't be fetched.
+    """
+    for attempt in range(retries + 1):
         try:
             request = urllib.request.Request(
                 image_url,
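
The headline change in this first stretch of the diff is that the module-level `data_queue = Queue(16*2000)` is gone; the queue now lives on a `ResourceManager` instance, so its size is configurable and it can be shared between loaders explicitly. A minimal sketch of the new pattern, using only the names introduced above:

```python
# Sketch only: the queue is now owned by a ResourceManager instance
# rather than being created at import time with a hard-coded size.
from flaxdiff.data.online_loader import ResourceManager

manager = ResourceManager(max_queue_size=32000)  # default shown in the diff
data_queue = manager.get_data_queue()  # multiprocessing.Queue shared by workers
```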
@@ -38,38 +60,135 @@ def fetch_single_image(image_url, timeout=None, retries=0):
             )
             with urllib.request.urlopen(request, timeout=timeout) as req:
                 image = PIL.Image.open(io.BytesIO(req.read()))
-            break
-        except Exception:
-            image = None
-    return image
+            return image
+        except Exception as e:
+            if attempt < retries:
+                # Wait a bit before retrying
+                time.sleep(0.1 * (attempt + 1))
+                continue
+            # Log the error on the final attempt
+            print(f"Error fetching image {image_url}: {e}")
+    return None
+
+
+def fetch_single_video(video_url: str, timeout: Optional[int] = None, retries: int = 0,
+                       max_frames: int = 32) -> Optional[List[np.ndarray]]:
+    """Fetch a single video from a URL.
+
+    Args:
+        video_url: URL of the video to fetch.
+        timeout: Timeout in seconds for the request.
+        retries: Number of times to retry the request.
+        max_frames: Maximum number of frames to extract.
+
+    Returns:
+        A list of video frames as numpy arrays or None if the video couldn't be fetched.
+    """
+    # Create a temporary file to download the video
+    import tempfile
+
+    for attempt in range(retries + 1):
+        try:
+            with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_file:
+                tmp_path = tmp_file.name
+
+            request = urllib.request.Request(
+                video_url,
+                data=None,
+                headers={"user-agent": USER_AGENT},
+            )
+            with urllib.request.urlopen(request, timeout=timeout) as req:
+                with open(tmp_path, 'wb') as f:
+                    f.write(req.read())
+
+            # Load the video frames
+            cap = cv2.VideoCapture(tmp_path)
+            frames = []
+
+            while len(frames) < max_frames:
+                ret, frame = cap.read()
+                if not ret:
+                    break
+                frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
+
+            cap.release()
+
+            # Delete the temporary file
+            try:
+                os.remove(tmp_path)
+            except:
+                pass
+
+            return frames if frames else None
+
+        except Exception as e:
+            if attempt < retries:
+                # Wait a bit before retrying
+                time.sleep(0.1 * (attempt + 1))
+                continue
+            # Log the error on the final attempt
+            print(f"Error fetching video {video_url}: {e}")
+
+            # Clean up the temporary file
+            try:
+                if 'tmp_path' in locals():
+                    os.remove(tmp_path)
+            except:
+                pass
+
+    return None
 
 
 def default_image_processor(
-    image, image_shape,
-    min_image_shape=(128, 128),
-    upscale_interpolation=cv2.INTER_CUBIC,
-    downscale_interpolation=cv2.INTER_AREA,
-):
+    image: PIL.Image.Image,
+    image_shape: Tuple[int, int],
+    min_image_shape: Tuple[int, int] = (128, 128),
+    upscale_interpolation: int = cv2.INTER_CUBIC,
+    downscale_interpolation: int = cv2.INTER_AREA,
+) -> Tuple[Optional[np.ndarray], int, int]:
+    """Process an image for training.
+
+    Args:
+        image: PIL image to process.
+        image_shape: Target shape (height, width).
+        min_image_shape: Minimum acceptable shape.
+        upscale_interpolation: Interpolation method for upscaling.
+        downscale_interpolation: Interpolation method for downscaling.
+
+    Returns:
+        Tuple of (processed image, original height, original width).
+        Processed image may be None if the image couldn't be processed.
+    """
     try:
+        # Convert to numpy
         image = np.array(image)
+
+        # Check if image has 3 channels
         if len(image.shape) != 3 or image.shape[2] != 3:
             return None, 0, 0
+
         original_height, original_width = image.shape[:2]
-        # check if the image is too small
+
+        # Check if the image is too small
        if min(original_height, original_width) < min(min_image_shape):
             return None, original_height, original_width
-        # check if wrong aspect ratio
+
+        # Check if wrong aspect ratio
         if max(original_height, original_width) / min(original_height, original_width) > 2.4:
             return None, original_height, original_width
-        # check if the variance is too low
+
+        # Check if the variance is too low (likely a blank/solid color image)
         if np.std(image) < 1e-5:
             return None, original_height, original_width
-        # image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
+
+        # Choose interpolation method based on whether we're upscaling or downscaling
         downscale = max(original_width, original_height) > max(image_shape)
         interpolation = downscale_interpolation if downscale else upscale_interpolation
 
-        image = A.longest_max_size(image, max(
-            image_shape), interpolation=interpolation)
+        # Resize while keeping aspect ratio
+        image = A.longest_max_size(image, max(image_shape), interpolation=interpolation)
+
+        # Pad to target shape
         image = A.pad(
             image,
             min_height=image_shape[0],
@@ -77,30 +196,114 @@ def default_image_processor(
             border_mode=cv2.BORDER_CONSTANT,
             value=[255, 255, 255],
         )
+
         return image, original_height, original_width
+
+    except Exception as e:
+        # Log the error
+        print(f"Error processing image: {e}")
+        return None, 0, 0
+
+
+def default_video_processor(
+    frames: List[np.ndarray],
+    frame_size: int = 256,
+    min_frame_size: int = 128,
+    num_frames: int = 16,
+    upscale_interpolation: int = cv2.INTER_CUBIC,
+    downscale_interpolation: int = cv2.INTER_AREA,
+) -> Tuple[Optional[np.ndarray], int, int]:
+    """Process video frames for training.
+
+    Args:
+        frames: List of video frames as numpy arrays.
+        frame_size: Target size for each frame.
+        min_frame_size: Minimum acceptable frame size.
+        num_frames: Target number of frames.
+        upscale_interpolation: Interpolation method for upscaling.
+        downscale_interpolation: Interpolation method for downscaling.
+
+    Returns:
+        Tuple of (processed video array, original height, original width).
+        Processed video may be None if the video couldn't be processed.
+    """
+    try:
+        if not frames or len(frames) == 0:
+            return None, 0, 0
+
+        # Get dimensions of the first frame
+        first_frame = frames[0]
+        original_height, original_width = first_frame.shape[:2]
+
+        # Check if frames are too small
+        if min(original_height, original_width) < min_frame_size:
+            return None, original_height, original_width
+
+        # Sample frames evenly
+        if len(frames) < num_frames:
+            # Not enough frames, duplicate some
+            indices = np.linspace(0, len(frames) - 1, num_frames, dtype=int)
+            sampled_frames = [frames[i] for i in indices]
+        else:
+            # Sample frames evenly
+            indices = np.linspace(0, len(frames) - 1, num_frames, dtype=int)
+            sampled_frames = [frames[i] for i in indices]
+
+        # Process each frame
+        processed_frames = []
+        for frame in sampled_frames:
+            # Choose interpolation method based on whether we're upscaling or downscaling
+            downscale = max(frame.shape[1], frame.shape[0]) > frame_size
+            interpolation = downscale_interpolation if downscale else upscale_interpolation
+
+            # Resize frame
+            resized_frame = cv2.resize(frame, (frame_size, frame_size), interpolation=interpolation)
+            processed_frames.append(resized_frame)
+
+        # Stack frames into a video tensor [num_frames, height, width, channels]
+        video_tensor = np.stack(processed_frames, axis=0)
+
+        return video_tensor, original_height, original_width
+
     except Exception as e:
-        # print("Error processing image", e, image_shape, interpolation)
-        # traceback.print_exc()
+        # Log the error
+        print(f"Error processing video: {e}")
         return None, 0, 0
 
 
-def map_sample(
-    url,
-    caption,
-    image_shape=(256, 256),
-    min_image_shape=(128, 128),
-    timeout=15,
-    retries=3,
-    upscale_interpolation=cv2.INTER_CUBIC,
-    downscale_interpolation=cv2.INTER_AREA,
-    image_processor=default_image_processor,
+def map_image_sample(
+    url: str,
+    caption: str,
+    data_queue: Queue,
+    image_shape: Tuple[int, int] = (256, 256),
+    min_image_shape: Tuple[int, int] = (128, 128),
+    timeout: int = 15,
+    retries: int = 3,
+    upscale_interpolation: int = cv2.INTER_CUBIC,
+    downscale_interpolation: int = cv2.INTER_AREA,
+    image_processor: Callable = default_image_processor,
 ):
+    """Process a single image sample and put it in the queue.
+
+    Args:
+        url: URL of the image.
+        caption: Caption for the image.
+        data_queue: Queue to put the processed sample in.
+        image_shape: Target shape for the image.
+        min_image_shape: Minimum acceptable shape.
+        timeout: Timeout for image fetching.
+        retries: Number of retries for image fetching.
+        upscale_interpolation: Interpolation method for upscaling.
+        downscale_interpolation: Interpolation method for downscaling.
+        image_processor: Function to process the image.
+    """
     try:
-        # Assuming fetch_single_image is defined elsewhere
+        # Fetch the image
         image = fetch_single_image(url, timeout=timeout, retries=retries)
         if image is None:
             return
 
+        # Process the image
         image, original_height, original_width = image_processor(
             image, image_shape, min_image_shape=min_image_shape,
             upscale_interpolation=upscale_interpolation,
@@ -110,6 +313,7 @@ def map_sample(
         if image is None:
             return
 
+        # Put the processed sample in the queue
         data_queue.put({
             "url": url,
             "caption": caption,
@@ -117,158 +321,426 @@ def map_sample(
             "original_height": original_height,
             "original_width": original_width,
         })
+
     except Exception as e:
-        # print(f"Error maping sample {url}", e)
-        # traceback.print_exc()
-        # error_queue.put_nowait({
-        #     "url": url,
-        #     "caption": caption,
-        #     "error": str(e)
-        # })
-        pass
-
-def default_feature_extractor(sample):
+        # Log the error
+        print(f"Error mapping image sample {url}: {e}")
+
+
+def map_video_sample(
+    url: str,
+    caption: str,
+    data_queue: Queue,
+    frame_size: int = 256,
+    min_frame_size: int = 128,
+    num_frames: int = 16,
+    timeout: int = 30,
+    retries: int = 3,
+    upscale_interpolation: int = cv2.INTER_CUBIC,
+    downscale_interpolation: int = cv2.INTER_AREA,
+    video_processor: Callable = default_video_processor,
+):
+    """Process a single video sample and put it in the queue.
+
+    Args:
+        url: URL of the video.
+        caption: Caption for the video.
+        data_queue: Queue to put the processed sample in.
+        frame_size: Target size for each frame.
+        min_frame_size: Minimum acceptable frame size.
+        num_frames: Target number of frames.
+        timeout: Timeout for video fetching.
+        retries: Number of retries for video fetching.
+        upscale_interpolation: Interpolation method for upscaling.
+        downscale_interpolation: Interpolation method for downscaling.
+        video_processor: Function to process the video.
+    """
+    try:
+        # Fetch the video frames
+        frames = fetch_single_video(url, timeout=timeout, retries=retries, max_frames=num_frames*2)
+        if frames is None or len(frames) == 0:
+            return
+
+        # Process the video
+        video, original_height, original_width = video_processor(
+            frames, frame_size, min_frame_size=min_frame_size,
+            num_frames=num_frames,
+            upscale_interpolation=upscale_interpolation,
+            downscale_interpolation=downscale_interpolation,
+        )
+
+        if video is None:
+            return
+
+        # Put the processed sample in the queue
+        data_queue.put({
+            "url": url,
+            "caption": caption,
+            "video": video,
+            "original_height": original_height,
+            "original_width": original_width,
+        })
+
+    except Exception as e:
+        # Log the error
+        print(f"Error mapping video sample {url}: {e}")
+
+
+def default_feature_extractor(sample: Dict[str, Any]) -> Dict[str, Any]:
+    """Extract features from a sample.
+
+    Args:
+        sample: Sample to extract features from.
+
+    Returns:
+        Dictionary with extracted url and caption.
+    """
+    # Extract URL
     url = None
-    if "url" in sample:
-        url = sample["url"]
-    elif "URL" in sample:
-        url = sample["URL"]
-    elif "image_url" in sample:
-        url = sample["image_url"]
-    else:
-        print("No url found in sample, skipping", sample.keys())
+    for key in ["url", "URL", "image_url", "video_url"]:
+        if key in sample:
+            url = sample[key]
+            break
+
+    if url is None:
+        print("No URL found in sample, keys:", sample.keys())
+        return {"url": None, "caption": None}
 
+    # Extract caption
     caption = None
-    if "caption" in sample:
-        caption = sample["caption"]
-    elif "CAPTION" in sample:
-        caption = sample["CAPTION"]
-    elif "txt" in sample:
-        caption = sample["txt"]
-    elif "TEXT" in sample:
-        caption = sample["TEXT"]
-    elif "text" in sample:
-        caption = sample["text"]
-    else:
-        print("No caption found in sample, skipping", sample.keys())
+    for key in ["caption", "CAPTION", "txt", "TEXT", "text"]:
+        if key in sample and sample[key] is not None:
+            caption = sample[key]
+            break
+
+    if caption is None:
+        caption = "No caption available"
 
     return {
         "url": url,
         "caption": caption,
     }
 
+
 def map_batch(
-    batch, num_threads=256, image_shape=(256, 256),
-    min_image_shape=(128, 128),
-    timeout=15, retries=3, image_processor=default_image_processor,
-    upscale_interpolation=cv2.INTER_CUBIC,
-    downscale_interpolation=cv2.INTER_AREA,
-    feature_extractor=default_feature_extractor,
+    batch: Dict[str, Any],
+    data_queue: Queue,
+    media_type: str = "image",
+    num_threads: int = 256,
+    image_shape: Tuple[int, int] = (256, 256),
+    min_image_shape: Tuple[int, int] = (128, 128),
+    frame_size: int = 256,
+    min_frame_size: int = 128,
+    num_frames: int = 16,
+    timeout: int = 15,
+    retries: int = 3,
+    image_processor: Callable = default_image_processor,
+    video_processor: Callable = default_video_processor,
+    upscale_interpolation: int = cv2.INTER_CUBIC,
+    downscale_interpolation: int = cv2.INTER_AREA,
+    feature_extractor: Callable = default_feature_extractor,
 ):
+    """Map a batch of samples and process them in parallel.
+
+    Args:
+        batch: Batch of samples to process.
+        data_queue: Queue to put processed samples in.
+        media_type: Type of media ("image" or "video").
+        num_threads: Number of threads to use for processing.
+        image_shape: Target shape for images.
+        min_image_shape: Minimum acceptable shape for images.
+        frame_size: Target size for video frames.
+        min_frame_size: Minimum acceptable size for video frames.
+        num_frames: Target number of frames for videos.
+        timeout: Timeout for fetching.
+        retries: Number of retries for fetching.
+        image_processor: Function to process images.
+        video_processor: Function to process videos.
+        upscale_interpolation: Interpolation method for upscaling.
+        downscale_interpolation: Interpolation method for downscaling.
+        feature_extractor: Function to extract features from samples.
+    """
     try:
-        map_sample_fn = partial(
-            map_sample, image_shape=image_shape, min_image_shape=min_image_shape,
-            timeout=timeout, retries=retries, image_processor=image_processor,
-            upscale_interpolation=upscale_interpolation,
-            downscale_interpolation=downscale_interpolation
-        )
+        # Choose mapping function based on media type
+        if media_type == "video":
+            map_func = partial(
+                map_video_sample,
+                data_queue=data_queue,
+                frame_size=frame_size,
+                min_frame_size=min_frame_size,
+                num_frames=num_frames,
+                timeout=timeout,
+                retries=retries,
+                video_processor=video_processor,
+                upscale_interpolation=upscale_interpolation,
+                downscale_interpolation=downscale_interpolation,
+            )
+        else:  # Default to image
+            map_func = partial(
+                map_image_sample,
+                data_queue=data_queue,
+                image_shape=image_shape,
+                min_image_shape=min_image_shape,
+                timeout=timeout,
+                retries=retries,
+                image_processor=image_processor,
+                upscale_interpolation=upscale_interpolation,
+                downscale_interpolation=downscale_interpolation,
+            )
+
+        # Extract features from batch
+        features = feature_extractor(batch)
+        urls, captions = features["url"], features["caption"]
+
+        if urls is None or captions is None:
+            return
+
+        # Process samples in parallel
         with ThreadPoolExecutor(max_workers=num_threads) as executor:
-            features = feature_extractor(batch)
-            url, caption = features["url"], features["caption"]
-            executor.map(map_sample_fn, url, caption)
+            executor.map(map_func, urls, captions)
+
     except Exception as e:
-        print(f"Error maping batch", e)
+        # Log the error
+        print(f"Error mapping batch: {e}")
         traceback.print_exc()
-        # error_queue.put_nowait({
-        #     "batch": batch,
-        #     "error": str(e)
-        # })
-        pass
-
-
-def parallel_image_loader(
-    dataset: Dataset, num_workers: int = 8, image_shape=(256, 256),
-    min_image_shape=(128, 128),
-    num_threads=256, timeout=15, retries=3, image_processor=default_image_processor,
-    upscale_interpolation=cv2.INTER_CUBIC,
-    downscale_interpolation=cv2.INTER_AREA,
-    feature_extractor=default_feature_extractor,
+
+
+def parallel_media_loader(
+    dataset: Dataset,
+    data_queue: Queue,
+    media_type: str = "image",
+    num_workers: int = 8,
+    image_shape: Tuple[int, int] = (256, 256),
+    min_image_shape: Tuple[int, int] = (128, 128),
+    frame_size: int = 256,
+    min_frame_size: int = 128,
+    num_frames: int = 16,
+    num_threads: int = 256,
+    timeout: int = 15,
+    retries: int = 3,
+    image_processor: Callable = default_image_processor,
+    video_processor: Callable = default_video_processor,
+    upscale_interpolation: int = cv2.INTER_CUBIC,
+    downscale_interpolation: int = cv2.INTER_AREA,
+    feature_extractor: Callable = default_feature_extractor,
 ):
+    """Load and process media from a dataset in parallel.
+
+    Args:
+        dataset: Dataset to load from.
+        data_queue: Queue to put processed samples in.
+        media_type: Type of media ("image" or "video").
+        num_workers: Number of worker processes.
+        image_shape: Target shape for images.
+        min_image_shape: Minimum acceptable shape for images.
+        frame_size: Target size for video frames.
+        min_frame_size: Minimum acceptable size for video frames.
+        num_frames: Target number of frames for videos.
+        num_threads: Number of threads per worker.
+        timeout: Timeout for fetching.
+        retries: Number of retries for fetching.
+        image_processor: Function to process images.
+        video_processor: Function to process videos.
+        upscale_interpolation: Interpolation method for upscaling.
+        downscale_interpolation: Interpolation method for downscaling.
+        feature_extractor: Function to extract features from samples.
+    """
+    # Create mapping function
     map_batch_fn = partial(
-        map_batch, num_threads=num_threads, image_shape=image_shape,
+        map_batch,
+        data_queue=data_queue,
+        media_type=media_type,
+        num_threads=num_threads,
+        image_shape=image_shape,
         min_image_shape=min_image_shape,
-        timeout=timeout, retries=retries, image_processor=image_processor,
+        frame_size=frame_size,
+        min_frame_size=min_frame_size,
+        num_frames=num_frames,
+        timeout=timeout,
+        retries=retries,
+        image_processor=image_processor,
+        video_processor=video_processor,
         upscale_interpolation=upscale_interpolation,
         downscale_interpolation=downscale_interpolation,
         feature_extractor=feature_extractor
     )
+
+    # Calculate shard length
     shard_len = len(dataset) // num_workers
-    print(f"Local Shard lengths: {shard_len}")
+    print(f"Local Shard length: {shard_len}")
+
+    # Process dataset in parallel
     with multiprocessing.Pool(num_workers) as pool:
         iteration = 0
         while True:
-            # Repeat forever
-            shards = [dataset[i*shard_len:(i+1)*shard_len]
-                      for i in range(num_workers)]
-            print(f"mapping {len(shards)} shards")
+            # Create shards for each worker
+            shards = [dataset[i*shard_len:(i+1)*shard_len] for i in range(num_workers)]
+            print(f"Mapping {len(shards)} shards")
+
+            # Process shards in parallel
             pool.map(map_batch_fn, shards)
+
+            # Shuffle dataset for next iteration
             iteration += 1
             print(f"Shuffling dataset with seed {iteration}")
             dataset = dataset.shuffle(seed=iteration)
-            # Clear the error queue
-            # while not error_queue.empty():
-            #     error_queue.get_nowait()
 
 
-class ImageBatchIterator:
+class MediaBatchIterator:
+    """Iterator for batches of media samples."""
+
     def __init__(
-        self, dataset: Dataset, batch_size: int = 64, image_shape=(256, 256),
-        min_image_shape=(128, 128),
-        num_workers: int = 8, num_threads=256, timeout=15, retries=3,
-        image_processor=default_image_processor,
-        upscale_interpolation=cv2.INTER_CUBIC,
-        downscale_interpolation=cv2.INTER_AREA,
-        feature_extractor=default_feature_extractor,
+        self,
+        dataset: Dataset,
+        batch_size: int = 64,
+        media_type: str = "image",
+        image_shape: Tuple[int, int] = (256, 256),
+        min_image_shape: Tuple[int, int] = (128, 128),
+        frame_size: int = 256,
+        min_frame_size: int = 128,
+        num_frames: int = 16,
+        num_workers: int = 8,
+        num_threads: int = 256,
+        timeout: int = 15,
+        retries: int = 3,
+        image_processor: Callable = default_image_processor,
+        video_processor: Callable = default_video_processor,
+        upscale_interpolation: int = cv2.INTER_CUBIC,
+        downscale_interpolation: int = cv2.INTER_AREA,
+        feature_extractor: Callable = default_feature_extractor,
+        resource_manager: Optional[ResourceManager] = None,
     ):
+        """Initialize a media batch iterator.
+
+        Args:
+            dataset: Dataset to iterate over.
+            batch_size: Batch size.
+            media_type: Type of media ("image" or "video").
+            image_shape: Target shape for images.
+            min_image_shape: Minimum acceptable shape for images.
+            frame_size: Target size for video frames.
+            min_frame_size: Minimum acceptable size for video frames.
+            num_frames: Target number of frames for videos.
+            num_workers: Number of worker processes.
+            num_threads: Number of threads per worker.
+            timeout: Timeout for fetching.
+            retries: Number of retries for fetching.
+            image_processor: Function to process images.
+            video_processor: Function to process videos.
+            upscale_interpolation: Interpolation method for upscaling.
+            downscale_interpolation: Interpolation method for downscaling.
+            feature_extractor: Function to extract features from samples.
+            resource_manager: Resource manager to use. Will create one if None.
+        """
         self.dataset = dataset
-        self.num_workers = num_workers
         self.batch_size = batch_size
+        self.media_type = media_type
+
+        # Create or use resource manager
+        self.resource_manager = resource_manager or ResourceManager()
+        self.data_queue = self.resource_manager.get_data_queue()
+
+        # Start loader thread
         loader = partial(
-            parallel_image_loader,
-            num_threads=num_threads,
+            parallel_media_loader,
+            data_queue=self.data_queue,
+            media_type=media_type,
+            num_workers=num_workers,
             image_shape=image_shape,
             min_image_shape=min_image_shape,
-            num_workers=num_workers,
-            timeout=timeout, retries=retries,
+            frame_size=frame_size,
+            min_frame_size=min_frame_size,
+            num_frames=num_frames,
+            num_threads=num_threads,
+            timeout=timeout,
+            retries=retries,
             image_processor=image_processor,
+            video_processor=video_processor,
             upscale_interpolation=upscale_interpolation,
             downscale_interpolation=downscale_interpolation,
             feature_extractor=feature_extractor
         )
-        self.thread = threading.Thread(target=loader, args=(dataset,))
+
+        # Start loader in background thread
+        self.thread = threading.Thread(target=loader, args=(dataset,), daemon=True)
        self.thread.start()
 
     def __iter__(self):
         return self
 
     def __next__(self):
+        """Get the next batch of samples."""
         def fetcher(_):
-            return data_queue.get()
+            try:
+                return self.data_queue.get(timeout=60)  # Add timeout to prevent hanging
+            except:
+                # Return a dummy sample on timeout
+                if self.media_type == "video":
+                    return {
+                        "url": "timeout",
+                        "caption": "Timeout occurred while waiting for sample",
+                        "video": np.zeros((4, 32, 32, 3), dtype=np.uint8),
+                        "original_height": 32,
+                        "original_width": 32,
+                    }
+                else:
+                    return {
+                        "url": "timeout",
+                        "caption": "Timeout occurred while waiting for sample",
+                        "image": np.zeros((32, 32, 3), dtype=np.uint8),
+                        "original_height": 32,
+                        "original_width": 32,
+                    }
+
+        # Fetch batch in parallel
         with ThreadPoolExecutor(max_workers=self.batch_size) as executor:
             batch = list(executor.map(fetcher, range(self.batch_size)))
+
         return batch
 
-    def __del__(self):
-        self.thread.join()
-
     def __len__(self):
+        """Get the number of batches in the dataset."""
         return len(self.dataset) // self.batch_size
 
 
-def default_collate(batch):
+def default_image_collate(batch):
+    """Default collate function for image batches.
+
+    Args:
+        batch: Batch of samples to collate.
+
+    Returns:
+        Collated batch.
+    """
     urls = [sample["url"] for sample in batch]
     captions = [sample["caption"] for sample in batch]
-    images = np.stack([sample["image"] for sample in batch], axis=0)
+
+    # Check if all images have the same shape
+    image_shapes = [sample["image"].shape for sample in batch]
+    if len(set(str(shape) for shape in image_shapes)) > 1:
+        # Get max height and width
+        max_height = max(shape[0] for shape in image_shapes)
+        max_width = max(shape[1] for shape in image_shapes)
+
+        # Resize all images to the same shape
+        images = []
+        for sample in batch:
+            image = sample["image"]
+            height, width = image.shape[:2]
+
+            if height != max_height or width != max_width:
+                # Pad with white
+                padded_image = np.ones((max_height, max_width, 3), dtype=image.dtype) * 255
+                padded_image[:height, :width] = image
+                images.append(padded_image)
+            else:
+                images.append(image)
+
+        images = np.stack(images, axis=0)
+    else:
+        # All images have the same shape, just stack them
+        images = np.stack([sample["image"] for sample in batch], axis=0)
+
     return {
         "url": urls,
         "caption": captions,
@@ -276,7 +748,83 @@ def default_collate(batch):
     }
 
 
+def default_video_collate(batch):
+    """Default collate function for video batches.
+
+    Args:
+        batch: Batch of samples to collate.
+
+    Returns:
+        Collated batch.
+    """
+    urls = [sample["url"] for sample in batch]
+    captions = [sample["caption"] for sample in batch]
+
+    # Check if all videos have the same shape
+    video_shapes = [sample["video"].shape for sample in batch]
+    if len(set(str(shape) for shape in video_shapes)) > 1:
+        # Get max dimensions
+        max_frames = max(shape[0] for shape in video_shapes)
+        max_height = max(shape[1] for shape in video_shapes)
+        max_width = max(shape[2] for shape in video_shapes)
+
+        # Resize all videos to the same shape
+        videos = []
+        for sample in batch:
+            video = sample["video"]
+            num_frames, height, width = video.shape[:3]
+
+            if num_frames != max_frames or height != max_height or width != max_width:
+                # Create a new video tensor with the max dimensions
+                padded_video = np.zeros((max_frames, max_height, max_width, 3), dtype=video.dtype)
+
+                # Copy the original video frames
+                padded_video[:num_frames, :height, :width] = video
+
+                # If we need more frames, duplicate the last frame
+                if num_frames < max_frames:
+                    padded_video[num_frames:] = padded_video[num_frames-1:num_frames]
+
+                videos.append(padded_video)
+            else:
+                videos.append(video)
+
+        videos = np.stack(videos, axis=0)
+    else:
+        # All videos have the same shape, just stack them
+        videos = np.stack([sample["video"] for sample in batch], axis=0)
+
+    return {
+        "url": urls,
+        "caption": captions,
+        "video": videos,
+    }
+
+
+def get_default_collate(media_type="image"):
+    """Get the default collate function for a media type.
+
+    Args:
+        media_type: Type of media ("image" or "video").
+
+    Returns:
+        Collate function for the specified media type.
+    """
+    if media_type == "video":
+        return default_video_collate
+    else:  # Default to image
+        return default_image_collate
+
+
 def dataMapper(map: Dict[str, Any]):
+    """Create a function to map dataset samples to a standard format.
+
+    Args:
+        map: Dictionary mapping standard keys to dataset-specific keys.
+
+    Returns:
+        Function that maps a sample to the standard format.
+    """
     def _map(sample) -> Dict[str, Any]:
         return {
             "url": sample[map["url"]],
@@ -285,13 +833,19 @@ def dataMapper(map: Dict[str, Any]):
     return _map
 
 
-class OnlineStreamingDataLoader():
+class OnlineStreamingDataLoader:
+    """Data loader for streaming media data from online sources."""
+
     def __init__(
         self,
         dataset,
         batch_size=64,
+        media_type="image",
         image_shape=(256, 256),
         min_image_shape=(128, 128),
+        frame_size=256,
+        min_frame_size=128,
+        num_frames=16,
         num_workers=16,
         num_threads=512,
         default_split="all",
@@ -303,17 +857,49 @@ class OnlineStreamingDataLoader():
         global_process_count=1,
         global_process_index=0,
         prefetch=1000,
-        collate_fn=default_collate,
+        collate_fn=None,
         timeout=15,
         retries=3,
         image_processor=default_image_processor,
+        video_processor=default_video_processor,
         upscale_interpolation=cv2.INTER_CUBIC,
         downscale_interpolation=cv2.INTER_AREA,
         feature_extractor=default_feature_extractor,
+        resource_manager=None,
     ):
+        """Initialize an online streaming data loader.
+
+        Args:
+            dataset: Dataset to load from, can be a path or a dataset object.
+            batch_size: Batch size.
+            media_type: Type of media ("image" or "video").
+            image_shape: Target shape for images.
+            min_image_shape: Minimum acceptable shape for images.
+            frame_size: Target size for video frames.
+            min_frame_size: Minimum acceptable size for video frames.
+            num_frames: Target number of frames for videos.
+            num_workers: Number of worker processes.
+            num_threads: Number of threads per worker.
+            default_split: Default split to use when loading datasets.
+            pre_map_maker: Function to create a mapping function.
+            pre_map_def: Default mapping definition.
+            global_process_count: Total number of processes.
+            global_process_index: Index of this process.
+            prefetch: Number of batches to prefetch.
+            collate_fn: Function to collate samples into batches.
+            timeout: Timeout for fetching.
+            retries: Number of retries for fetching.
+            image_processor: Function to process images.
+            video_processor: Function to process videos.
+            upscale_interpolation: Interpolation method for upscaling.
+            downscale_interpolation: Interpolation method for downscaling.
+            feature_extractor: Function to extract features from samples.
+            resource_manager: Resource manager to use.
+        """
+        # Load dataset from path if needed
         if isinstance(dataset, str):
             dataset_path = dataset
-            print("Loading dataset from path")
+            print(f"Loading dataset from path: {dataset_path}")
             if "gs://" in dataset:
                 dataset = load_from_disk(dataset_path)
             else:
@@ -321,43 +907,86 @@ class OnlineStreamingDataLoader():
         elif isinstance(dataset, list):
             if isinstance(dataset[0], str):
                 print("Loading multiple datasets from paths")
-                dataset = [load_from_disk(dataset_path) if "gs://" in dataset_path else load_dataset(
-                    dataset_path, split=default_split) for dataset_path in dataset]
-            print("Concatenating multiple datasets")
+                dataset = [
+                    load_from_disk(dataset_path) if "gs://" in dataset_path
+                    else load_dataset(dataset_path, split=default_split)
+                    for dataset_path in dataset
+                ]
+            print(f"Concatenating {len(dataset)} datasets")
             dataset = concatenate_datasets(dataset)
             dataset = dataset.shuffle(seed=0)
-        # dataset = dataset.map(pre_map_maker(pre_map_def), batched=True, batch_size=10000000)
+
+        # Shard dataset for distributed training
         self.dataset = dataset.shard(
             num_shards=global_process_count, index=global_process_index)
         print(f"Dataset length: {len(dataset)}")
-        self.iterator = ImageBatchIterator(self.dataset, image_shape=image_shape,
-                                           min_image_shape=min_image_shape,
-                                           num_workers=num_workers, batch_size=batch_size, num_threads=num_threads,
-                                           timeout=timeout, retries=retries, image_processor=image_processor,
-                                           upscale_interpolation=upscale_interpolation,
-                                           downscale_interpolation=downscale_interpolation,
-                                           feature_extractor=feature_extractor)
+
+        # Get or create resource manager
+        self.resource_manager = resource_manager or ResourceManager()
+
+        # Choose default collate function if not provided
+        if collate_fn is None:
+            collate_fn = get_default_collate(media_type)
+
+        # Create media batch iterator
+        self.iterator = MediaBatchIterator(
+            self.dataset,
+            batch_size=batch_size,
+            media_type=media_type,
+            image_shape=image_shape,
+            min_image_shape=min_image_shape,
+            frame_size=frame_size,
+            min_frame_size=min_frame_size,
+            num_frames=num_frames,
+            num_workers=num_workers,
+            num_threads=num_threads,
+            timeout=timeout,
+            retries=retries,
+            image_processor=image_processor,
+            video_processor=video_processor,
+            upscale_interpolation=upscale_interpolation,
+            downscale_interpolation=downscale_interpolation,
+            feature_extractor=feature_extractor,
+            resource_manager=self.resource_manager,
+        )
+
         self.batch_size = batch_size
+        self.collate_fn = collate_fn
 
-        # Launch a thread to load batches in the background
+        # Create batch queue for prefetching
         self.batch_queue = queue.Queue(prefetch)
-
+
+        # Start batch loader thread
         def batch_loader():
-            for batch in self.iterator:
-                try:
-                    self.batch_queue.put(collate_fn(batch))
-                except Exception as e:
-                    print("Error collating batch", e)
-
-        self.loader_thread = threading.Thread(target=batch_loader)
+            try:
+                for batch in self.iterator:
+                    try:
+                        if batch:
+                            self.batch_queue.put(collate_fn(batch))
+                    except Exception as e:
+                        print(f"Error collating batch: {e}")
+                        traceback.print_exc()
+            except Exception as e:
+                print(f"Error in batch loader thread: {e}")
+                traceback.print_exc()
+
+        self.loader_thread = threading.Thread(target=batch_loader, daemon=True)
         self.loader_thread.start()
 
     def __iter__(self):
+        """Get an iterator for the data loader."""
        return self
 
     def __next__(self):
-        return self.batch_queue.get()
-        # return self.collate_fn(next(self.iterator))
+        """Get the next batch."""
+        try:
+            return self.batch_queue.get(timeout=60)  # Add timeout to prevent hanging
+        except queue.Empty:
+            if not self.loader_thread.is_alive():
+                raise StopIteration("Loader thread died")
+            print("Timeout waiting for batch, retrying...")
+            return self.__next__()
 
     def __len__(self):
+        """Get the number of samples in the dataset."""
         return len(self.dataset)
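
Taken together, the rewrite replaces the image-only pipeline (`ImageBatchIterator`, `default_collate`) with a media-agnostic one (`MediaBatchIterator` plus per-media collate functions) behind the same `OnlineStreamingDataLoader` entry point. A minimal usage sketch based only on the signatures in this diff; the dataset path is a placeholder, and the batch keys follow the collate functions above:

```python
from flaxdiff.data.online_loader import OnlineStreamingDataLoader, ResourceManager

# One ResourceManager (one bounded queue) can be shared across loaders.
manager = ResourceManager(max_queue_size=32000)

loader = OnlineStreamingDataLoader(
    "user/dataset",             # placeholder Hugging Face dataset path
    batch_size=64,
    media_type="image",         # or "video" to use the new video pipeline
    image_shape=(256, 256),
    num_workers=8,
    num_threads=256,
    resource_manager=manager,
)

for batch in loader:
    urls = batch["url"]
    captions = batch["caption"]
    images = batch["image"]     # stacked numpy array (see default_image_collate)
    break                       # take one collated batch and stop
```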