flaxdiff 0.1.38.1__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44)
  1. flaxdiff/data/__init__.py +5 -1
  2. flaxdiff/data/benchmark_decord.py +443 -0
  3. flaxdiff/data/dataloaders.py +608 -0
  4. flaxdiff/data/dataset_map.py +61 -6
  5. flaxdiff/data/online_loader.py +779 -150
  6. flaxdiff/data/sources/audio_utils.py +142 -0
  7. flaxdiff/data/sources/av_example.py +125 -0
  8. flaxdiff/data/sources/av_utils.py +590 -0
  9. flaxdiff/data/sources/base.py +129 -0
  10. flaxdiff/data/sources/images.py +309 -0
  11. flaxdiff/data/sources/utils.py +158 -0
  12. flaxdiff/data/sources/videos.py +250 -0
  13. flaxdiff/data/sources/voxceleb2.py +412 -0
  14. flaxdiff/inference/__init__.py +0 -0
  15. flaxdiff/inference/pipeline.py +260 -0
  16. flaxdiff/inference/utils.py +320 -0
  17. flaxdiff/inputs/__init__.py +173 -0
  18. flaxdiff/inputs/encoders.py +98 -0
  19. flaxdiff/models/__init__.py +2 -1
  20. flaxdiff/models/autoencoder/autoencoder.py +141 -9
  21. flaxdiff/models/autoencoder/diffusers.py +88 -25
  22. flaxdiff/models/autoencoder/simple_autoenc.py +40 -8
  23. flaxdiff/models/simple_unet.py +5 -5
  24. flaxdiff/models/simple_vit.py +1 -1
  25. flaxdiff/models/unet_3d.py +446 -0
  26. flaxdiff/models/unet_3d_blocks.py +505 -0
  27. flaxdiff/samplers/common.py +358 -96
  28. flaxdiff/samplers/ddim.py +44 -5
  29. flaxdiff/schedulers/karras.py +20 -12
  30. flaxdiff/trainer/__init__.py +2 -1
  31. flaxdiff/trainer/autoencoder_trainer.py +1 -2
  32. flaxdiff/trainer/diffusion_trainer.py +33 -27
  33. flaxdiff/trainer/general_diffusion_trainer.py +583 -0
  34. flaxdiff/trainer/simple_trainer.py +48 -31
  35. flaxdiff/utils.py +128 -57
  36. {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/METADATA +1 -1
  37. flaxdiff-0.2.0.dist-info/RECORD +64 -0
  38. {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/WHEEL +1 -1
  39. flaxdiff/data/datasets.py +0 -169
  40. flaxdiff/data/sources/gcs.py +0 -81
  41. flaxdiff/data/sources/tfds.py +0 -79
  42. flaxdiff/trainer/video_diffusion_trainer.py +0 -62
  43. flaxdiff-0.1.38.1.dist-info/RECORD +0 -50
  44. {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/top_level.txt +0 -0
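The removals in items 39-41 fold the old per-backend source modules into the new sources package; the dataset_map.py hunk below re-exports the legacy names from sources.images, so old imports can be migrated mechanically. A minimal migration sketch (illustrative only, based on the re-exports visible in the diff):

    # Before (0.1.38.1):
    # from flaxdiff.data.sources.tfds import data_source_tfds, tfds_augmenters
    # from flaxdiff.data.sources.gcs import data_source_gcs, gcs_augmenters

    # After (0.2.0) — the same names are re-exported from sources.images:
    from flaxdiff.data.sources.images import data_source_tfds, tfds_augmenters
    from flaxdiff.data.sources.images import data_source_gcs, gcs_augmenters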
flaxdiff/data/dataloaders.py (new file)
@@ -0,0 +1,608 @@
+ import jax.numpy as jnp
+ import grain.python as pygrain
+ from typing import Dict, Any, Optional, Union, List, Callable
+ import numpy as np
+ import jax
+ import cv2  # Added missing import
+ from flaxdiff.utils import convert_to_global_tree, AutoTextTokenizer
+ from .dataset_map import datasetMap, onlineDatasetMap, mediaDatasetMap
+ import traceback
+ from .online_loader import OnlineStreamingDataLoader
+ import queue
+ from jax.sharding import Mesh
+ import threading
+ from functools import partial
+
+
+ def batch_mesh_map(mesh):
+     """Create an augmenter that maps batches to a mesh."""
+     class augmenters(pygrain.MapTransform):
+         def __init__(self, *args, **kwargs):
+             super().__init__(*args, **kwargs)
+
+         def map(self, batch) -> Dict[str, jnp.ndarray]:  # jnp.ndarray is the array type; jnp.array is a function
+             return convert_to_global_tree(mesh, batch)
+     return augmenters
+
+
+ class DataLoaderWithMesh:
+     """A wrapper for data loaders that distributes data to a JAX mesh.
+
+     This class wraps any iterable dataset and maps the data to a JAX mesh.
+     It runs a background thread that fetches data from the loader and
+     distributes it to the mesh.
+     """
+
+     def __init__(self, dataloader, mesh, buffer_size=20):
+         """Initialize a DataLoaderWithMesh.
+
+         Args:
+             dataloader: The data loader to wrap.
+             mesh: The JAX mesh to distribute data to.
+             buffer_size: Size of the prefetch buffer.
+         """
+         self.dataloader = dataloader
+         self.mesh = mesh
+         self.buffer_size = buffer_size
+         self.tmp_queue = queue.Queue(buffer_size)
+         self.loader_thread = None
+         self._start_loader_thread()
+
+     def _start_loader_thread(self):
+         """Start the background thread for data loading."""
+         def batch_loader():
+             try:
+                 for batch in self.dataloader:
+                     try:
+                         self.tmp_queue.put(convert_to_global_tree(self.mesh, batch))
+                     except Exception as e:
+                         print("Error processing batch", e)
+                         traceback.print_exc()
+             except Exception as e:
+                 print("Error in batch loader thread", e)
+                 traceback.print_exc()
+
+         self.loader_thread = threading.Thread(target=batch_loader, daemon=True)
+         self.loader_thread.start()
+
+     def __iter__(self):
+         return self
+
+     def __next__(self):
+         try:
+             return self.tmp_queue.get(timeout=60)  # Add timeout to prevent hanging
+         except queue.Empty:
+             if not self.loader_thread.is_alive():
+                 raise StopIteration("Loader thread died")
+             raise queue.Empty("Timed out waiting for batch")
+
+     def __del__(self):
+         # Clean up resources
+         if hasattr(self, 'loader_thread') and self.loader_thread is not None:
+             self.loader_thread.join(timeout=1)
+
+
+ def generate_collate_fn(media_type="image"):
+     """Generate a collate function based on media type.
+
+     Args:
+         media_type: Type of media ("image" or "video").
+
+     Returns:
+         A collate function for the specified media type.
+     """
+     auto_tokenize = AutoTextTokenizer(tensor_type="np")
+
+     def image_collate(batch):
+         try:
+             # Check if batch is valid
+             if not batch or len(batch) == 0:
+                 print("Warning: Empty batch received")
+                 # Return an empty batch with the correct structure
+                 return {
+                     "image": np.zeros((0, 0, 0, 3), dtype=np.float32),
+                     "text": {
+                         "input_ids": np.zeros((0, 0), dtype=np.int32),
+                         "attention_mask": np.zeros((0, 0), dtype=np.int32),
+                     }
+                 }
+
+             captions = [sample.get("caption", "") for sample in batch]
+             results = auto_tokenize(captions)
+
+             # Check if all images have the same shape
+             image_shapes = [sample["image"].shape for sample in batch]
+             if len(set(str(shape) for shape in image_shapes)) > 1:
+                 # Different shapes, need to resize all to the same (height, width)
+                 target_shape = max(shape[0] for shape in image_shapes), max(shape[1] for shape in image_shapes)
+                 # Note: cv2.resize expects dsize as (width, height), so flip the tuple
+                 images = np.stack([
+                     cv2.resize(sample["image"], (target_shape[1], target_shape[0])) if sample["image"].shape[:2] != target_shape else sample["image"]
+                     for sample in batch
+                 ], axis=0)
+             else:
+                 # All same shape, can just stack
+                 images = np.stack([sample["image"] for sample in batch], axis=0)
+
+             return {
+                 "image": images,
+                 "text": {
+                     "input_ids": results['input_ids'],
+                     "attention_mask": results['attention_mask'],
+                 }
+             }
+         except Exception as e:
+             print("Error in image collate function", e)
+             traceback.print_exc()
+             # Return a fallback batch
+             return fallback_batch(batch, media_type="image")
+
+     def video_collate(batch):
+         try:
+             # Check if batch is valid
+             if not batch or len(batch) == 0:
+                 print("Warning: Empty batch received")
+                 # Return an empty batch with the correct structure
+                 return {
+                     "video": np.zeros((0, 0, 0, 0, 3), dtype=np.float32),
+                     "text": {
+                         "input_ids": np.zeros((0, 0), dtype=np.int32),
+                         "attention_mask": np.zeros((0, 0), dtype=np.int32),
+                     }
+                 }
+
+             captions = [sample.get("caption", "") for sample in batch]
+             results = auto_tokenize(captions)
+
+             # Check if all videos have the same shape
+             video_shapes = [sample["video"].shape for sample in batch]
+             if len(set(str(shape) for shape in video_shapes)) > 1:
+                 # Get max dimensions
+                 max_frames = max(shape[0] for shape in video_shapes)
+                 max_height = max(shape[1] for shape in video_shapes)
+                 max_width = max(shape[2] for shape in video_shapes)
+
+                 # Resize videos to the same shape
+                 videos = []
+                 for sample in batch:
+                     video = sample["video"]
+                     num_frames, height, width = video.shape[:3]
+
+                     if height != max_height or width != max_width:
+                         # Resize each frame
+                         resized_frames = np.array([
+                             cv2.resize(frame, (max_width, max_height))
+                             for frame in video
+                         ])
+                         video = resized_frames
+
+                     if num_frames < max_frames:
+                         # Pad with duplicates of the last frame
+                         padding = np.tile(video[-1:], (max_frames - num_frames, 1, 1, 1))
+                         video = np.concatenate([video, padding], axis=0)
+
+                     videos.append(video)
+
+                 videos = np.stack(videos, axis=0)
+             else:
+                 # All videos have the same shape, can just stack
+                 videos = np.stack([sample["video"] for sample in batch], axis=0)
+
+             return {
+                 "video": videos,
+                 "text": {
+                     "input_ids": results['input_ids'],
+                     "attention_mask": results['attention_mask'],
+                 }
+             }
+         except Exception as e:
+             print("Error in video collate function", e)
+             traceback.print_exc()
+             # Return a fallback batch
+             return fallback_batch(batch, media_type="video")
+
+     def fallback_batch(batch, media_type="image"):
+         """Create a fallback batch when an error occurs."""
+         try:
+             batch_size = len(batch) if batch else 1
+             if media_type == "video":
+                 # Create a small valid video batch
+                 dummy_video = np.zeros((batch_size, 4, 32, 32, 3), dtype=np.uint8)
+                 dummy_text = auto_tokenize(["Error processing video"] * batch_size)
+                 return {
+                     "video": dummy_video,
+                     "text": {
+                         "input_ids": dummy_text['input_ids'],
+                         "attention_mask": dummy_text['attention_mask'],
+                     }
+                 }
+             else:
+                 # Create a small valid image batch
+                 dummy_image = np.zeros((batch_size, 32, 32, 3), dtype=np.uint8)
+                 dummy_text = auto_tokenize(["Error processing image"] * batch_size)
+                 return {
+                     "image": dummy_image,
+                     "text": {
+                         "input_ids": dummy_text['input_ids'],
+                         "attention_mask": dummy_text['attention_mask'],
+                     }
+                 }
+         except Exception as e:
+             print("Error creating fallback batch", e)
+             # Last resort fallback
+             if media_type == "video":
+                 return {
+                     "video": np.zeros((1, 4, 32, 32, 3), dtype=np.uint8),
+                     "text": {
+                         "input_ids": np.zeros((1, 16), dtype=np.int32),
+                         "attention_mask": np.zeros((1, 16), dtype=np.int32),
+                     }
+                 }
+             else:
+                 return {
+                     "image": np.zeros((1, 32, 32, 3), dtype=np.uint8),
+                     "text": {
+                         "input_ids": np.zeros((1, 16), dtype=np.int32),
+                         "attention_mask": np.zeros((1, 16), dtype=np.int32),
+                     }
+                 }
+
+     if media_type == "video":
+         return video_collate
+     else:  # Default to image
+         return image_collate
+
+
+ def get_dataset_grain(
+     data_name="cc12m",
+     batch_size=64,
+     image_scale=256,
+     count=None,
+     num_epochs=None,
+     method=jax.image.ResizeMethod.LANCZOS3,
+     worker_count=32,
+     read_thread_count=64,
+     read_buffer_size=50,
+     worker_buffer_size=20,
+     seed=0,
+     dataset_source="/mnt/gcs_mount/arrayrecord2/cc12m/",
+ ):
+     """Legacy function for getting grain dataset loaders for images.
+
+     Args:
+         data_name: Name of the dataset in datasetMap.
+         batch_size: Batch size for the dataset.
+         image_scale: Size to scale images to.
+         count: Optional count limit for the dataset.
+         num_epochs: Number of epochs to iterate.
+         method: Interpolation method for resizing.
+         worker_count: Number of worker processes.
+         read_thread_count: Number of read threads.
+         read_buffer_size: Size of the read buffer.
+         worker_buffer_size: Size of the worker buffer.
+         seed: Random seed.
+         dataset_source: Source path for the dataset.
+
+     Returns:
+         Dictionary with train dataset function and metadata.
+     """
+     dataset = datasetMap[data_name]
+     data_source = dataset["source"](dataset_source)
+     augmenter = dataset["augmenter"](image_scale, method)
+
+     local_batch_size = batch_size // jax.process_count()
+
+     sampler = pygrain.IndexSampler(
+         num_records=len(data_source) if count is None else count,
+         shuffle=True,
+         seed=seed,
+         num_epochs=num_epochs,
+         shard_options=pygrain.ShardByJaxProcess(),
+     )
+
+     def get_trainset():
+         transformations = [
+             augmenter(),
+             pygrain.Batch(local_batch_size, drop_remainder=True),
+         ]
+
+         loader = pygrain.DataLoader(
+             data_source=data_source,
+             sampler=sampler,
+             operations=transformations,
+             worker_count=worker_count,
+             read_options=pygrain.ReadOptions(
+                 read_thread_count, read_buffer_size
+             ),
+             worker_buffer_size=worker_buffer_size,
+         )
+         return loader
+
+     return {
+         "train": get_trainset,
+         "train_len": len(data_source),
+         "local_batch_size": local_batch_size,
+         "global_batch_size": batch_size,
+     }
+
+
+ def get_dataset_online(
+     data_name="combined_online",
+     batch_size=64,
+     image_scale=256,
+     count=None,
+     num_epochs=None,
+     method=jax.image.ResizeMethod.LANCZOS3,
+     worker_count=32,
+     read_thread_count=64,
+     read_buffer_size=50,
+     worker_buffer_size=20,
+     seed=0,
+     dataset_source="/mnt/gcs_mount/arrayrecord2/cc12m/",
+ ):
+     """Legacy function for getting an online streaming dataloader for images.
+
+     Args:
+         data_name: Name of the dataset in onlineDatasetMap.
+         batch_size: Batch size for the dataset.
+         image_scale: Size to scale images to.
+         count: Optional count limit for the dataset.
+         num_epochs: Number of epochs to iterate.
+         method: Interpolation method for resizing.
+         worker_count: Number of worker processes.
+         read_thread_count: Number of read threads.
+         read_buffer_size: Size of the read buffer.
+         worker_buffer_size: Size of the worker buffer.
+         seed: Random seed.
+         dataset_source: Source path for the dataset.
+
+     Returns:
+         Dictionary with train dataset function and metadata.
+     """
+     local_batch_size = batch_size // jax.process_count()
+
+     sources = onlineDatasetMap[data_name]["source"]
+     dataloader = OnlineStreamingDataLoader(
+         sources,
+         batch_size=local_batch_size,
+         num_workers=worker_count,
+         num_threads=read_thread_count,
+         image_shape=(image_scale, image_scale),
+         global_process_count=jax.process_count(),
+         global_process_index=jax.process_index(),
+         prefetch=worker_buffer_size,
+         collate_fn=generate_collate_fn(),
+         default_split="train",
+     )
+
+     def get_trainset(mesh: Optional[Mesh] = None):
+         if mesh is not None:
+             return DataLoaderWithMesh(dataloader, mesh, buffer_size=worker_buffer_size)
+         return dataloader
+
+     return {
+         "train": get_trainset,
+         "train_len": len(dataloader) * jax.process_count(),
+         "local_batch_size": local_batch_size,
+         "global_batch_size": batch_size,
+     }
+
+
+ # ---------------------------------------------------------------------------------
+ # New unified dataset loader for both images and videos
+ # ---------------------------------------------------------------------------------
+
+ def get_media_dataset_grain(
+     data_name: str,
+     batch_size: int = 64,
+     media_scale: int = 256,
+     sequence_length: int = 1,
+     count: Optional[int] = None,
+     num_epochs: Optional[int] = None,
+     method: Any = cv2.INTER_AREA,
+     worker_count: int = 32,
+     read_thread_count: int = 64,
+     read_buffer_size: int = 50,
+     worker_buffer_size: int = 20,
+     seed: int = 0,
+     dataset_source: Optional[str] = None,
+     media_type: Optional[str] = None,  # Will be auto-detected if None
+     mesh: Optional[Mesh] = None,
+     additional_transform_kwargs: Optional[Dict[str, Any]] = None,
+ ):
+     """Get a grain dataset loader for any media type (image or video).
+
+     Args:
+         data_name: Name of the dataset in mediaDatasetMap.
+         batch_size: Batch size for the dataset.
+         media_scale: Size to scale media (image or video frames) to.
+         sequence_length: Length of the sequence for video data.
+         count: Optional count limit for the dataset.
+         num_epochs: Number of epochs to iterate.
+         method: Interpolation method for resizing.
+         worker_count: Number of worker processes.
+         read_thread_count: Number of read threads.
+         read_buffer_size: Size of the read buffer.
+         worker_buffer_size: Size of the worker buffer.
+         seed: Random seed.
+         dataset_source: Source path for the dataset.
+         media_type: Type of media ("image" or "video"). Auto-detected if None.
+         mesh: Optional JAX mesh for distributed training.
+         additional_transform_kwargs: Additional arguments for the transform.
+
+     Returns:
+         Dictionary with train dataset function and metadata.
+     """
+     if data_name not in mediaDatasetMap:
+         raise ValueError(f"Dataset {data_name} not found in mediaDatasetMap")
+
+     media_dataset = mediaDatasetMap[data_name]
+
+     # Auto-detect media_type if not provided
+     if media_type is None:
+         media_type = media_dataset.media_type
+
+     # Get the data source and augmenter
+     data_source = media_dataset.get_source(dataset_source)
+
+     # Prepare transform kwargs
+     transform_kwargs = {
+         "image_scale" if media_type == "image" else "frame_size": media_scale,
+         "method": method,
+         "sequence_length": sequence_length,
+     }
+     if additional_transform_kwargs:
+         transform_kwargs.update(additional_transform_kwargs)
+
+     augmenter = media_dataset.get_augmenter(**transform_kwargs)
+
+     # Calculate local batch size for distributed training
+     local_batch_size = batch_size // jax.process_count()
+
+     # Create a sampler for the dataset
+     if hasattr(data_source, "__len__"):
+         dataset_length = len(data_source) if count is None else count
+     else:
+         # Some data sources, such as lists of video files, don't define __len__
+         dataset_length = count if count is not None else 1000000  # Default large number
+
+     sampler = pygrain.IndexSampler(
+         num_records=dataset_length,
+         shuffle=True,
+         seed=seed,
+         num_epochs=num_epochs,
+         shard_options=pygrain.ShardByJaxProcess(),
+     )
+
+     def get_trainset(mesh_override: Optional[Mesh] = None):
+         """Get a training dataset iterator.
+
+         Args:
+             mesh_override: Optional mesh to override the default.
+
+         Returns:
+             A dataset iterator.
+         """
+         current_mesh = mesh_override or mesh
+
+         transformations = [
+             augmenter(),
+             pygrain.Batch(local_batch_size, drop_remainder=True),
+         ]
+
+         # # Add mesh mapping if needed
+         # if current_mesh is not None:
+         #     transformations.append(batch_mesh_map(current_mesh)())
+
+         loader = pygrain.DataLoader(
+             data_source=data_source,
+             sampler=sampler,
+             operations=transformations,
+             worker_count=worker_count,
+             read_options=pygrain.ReadOptions(
+                 read_thread_count, read_buffer_size
+             ),
+             worker_buffer_size=worker_buffer_size,
+         )
+         return loader
+
+     return {
+         "train": get_trainset,
+         "train_len": dataset_length,
+         "local_batch_size": local_batch_size,
+         "global_batch_size": batch_size,
+         "media_type": media_type,
+     }
+
+
+ def get_media_dataset_online(
+     data_name: str = "combined_online",
+     batch_size: int = 64,
+     media_scale: int = 256,
+     worker_count: int = 16,
+     read_thread_count: int = 512,
+     worker_buffer_size: int = 20,
+     dataset_sources: Optional[List[str]] = None,
+     media_type: str = "image",  # Default to image for online datasets
+     mesh: Optional[Mesh] = None,
+     timeout: int = 15,
+     retries: int = 3,
+     min_media_scale: int = 128,
+ ):
+     """Get an online streaming dataset loader for any media type.
+
+     Args:
+         data_name: Name of the dataset in onlineDatasetMap, or "custom" for custom sources.
+         batch_size: Batch size for the dataset.
+         media_scale: Size to scale media (image or video frames) to.
+         worker_count: Number of worker processes.
+         read_thread_count: Number of read threads.
+         worker_buffer_size: Size of the worker buffer.
+         dataset_sources: Custom dataset sources if data_name is "custom".
+         media_type: Type of media ("image" or "video").
+         mesh: Optional JAX mesh for distributed training.
+         timeout: Timeout for dataset operations.
+         retries: Number of retries for dataset operations.
+         min_media_scale: Minimum scale for media items.
+
+     Returns:
+         Dictionary with train dataset function and metadata.
+     """
+     local_batch_size = batch_size // jax.process_count()
+
+     # Get dataset sources
+     if dataset_sources is None:
+         if data_name not in onlineDatasetMap:
+             raise ValueError(f"Dataset {data_name} not found in onlineDatasetMap")
+         sources = onlineDatasetMap[data_name]["source"]
+     else:
+         sources = dataset_sources
+
+     # Configure shape parameter based on media type
+     shape_param = "image_shape" if media_type == "image" else "frame_size"
+     shape_value = (media_scale, media_scale) if media_type == "image" else media_scale
+
+     # Configure min scale parameter based on media type
+     min_scale_param = "min_image_shape" if media_type == "image" else "min_frame_size"
+     min_scale_value = (min_media_scale, min_media_scale) if media_type == "image" else min_media_scale
+
+     # Prepare dataloader kwargs
+     dataloader_kwargs = {
+         "batch_size": local_batch_size,
+         "num_workers": worker_count,
+         "num_threads": read_thread_count,
+         shape_param: shape_value,
+         min_scale_param: min_scale_value,
+         "global_process_count": jax.process_count(),
+         "global_process_index": jax.process_index(),
+         "prefetch": worker_buffer_size,
+         "collate_fn": generate_collate_fn(media_type),
+         "default_split": "train",
+         "timeout": timeout,
+         "retries": retries,
+     }
+
+     dataloader = OnlineStreamingDataLoader(sources, **dataloader_kwargs)
+
+     def get_trainset(mesh_override: Optional[Mesh] = None):
+         """Get a training dataset iterator.
+
+         Args:
+             mesh_override: Optional mesh to override the default.
+
+         Returns:
+             A dataset iterator.
+         """
+         current_mesh = mesh_override or mesh
+
+         if current_mesh is not None:
+             return DataLoaderWithMesh(dataloader, current_mesh, buffer_size=worker_buffer_size)
+
+         return dataloader
+
+     return {
+         "train": get_trainset,
+         "train_len": len(dataloader) * jax.process_count(),
+         "local_batch_size": local_batch_size,
+         "global_batch_size": batch_size,
+         "media_type": media_type,
+     }
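
A minimal usage sketch of the two new entry points above. The mesh construction is an assumed single-host setup, and dataset availability (mounted array-record files, GCS access) is a prerequisite not shown in the diff:

    import jax
    import numpy as np
    from jax.sharding import Mesh
    from flaxdiff.data.dataloaders import (
        get_media_dataset_grain,
        get_media_dataset_online,
    )

    # Offline (array-record) path: returns a factory for a pygrain.DataLoader.
    data = get_media_dataset_grain("oxford_flowers102", batch_size=64, media_scale=256)
    loader = data["train"]()
    batch = next(iter(loader))  # first sharded, augmented batch

    # Online (streaming) path: with a mesh, batches are prefetched on a
    # background thread and placed as global arrays via DataLoaderWithMesh.
    mesh = Mesh(np.array(jax.devices()), ("data",))
    online = get_media_dataset_online("combined_online", batch_size=64, mesh=mesh)
    mesh_loader = online["train"]()  # DataLoaderWithMesh wrapping the streamer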
flaxdiff/data/dataset_map.py
@@ -1,5 +1,14 @@
- from .sources.tfds import data_source_tfds, tfds_augmenters
- from .sources.gcs import data_source_gcs, data_source_combined_gcs, gcs_augmenters
+ from .sources.base import MediaDataset, DataSource, DataAugmenter
+ from .sources.images import ImageTFDSSource, ImageGCSSource, CombinedImageGCSSource
+ from .sources.images import ImageTFDSAugmenter, ImageGCSAugmenter
+ from .sources.videos import VideoTFDSSource, VideoLocalSource, AudioVideoAugmenter
+
+ # ---------------------------------------------------------------------------------
+ # Legacy compatibility mappings
+ # ---------------------------------------------------------------------------------
+
+ from .sources.images import data_source_tfds, tfds_augmenters, data_source_gcs
+ from .sources.images import data_source_combined_gcs, gcs_augmenters
 
  # Configure the following for your datasets
  datasetMap = {
@@ -50,9 +59,6 @@ datasetMap = {
  onlineDatasetMap = {
      "combined_online": {
          "source": [
-             # "gs://flaxdiff-datasets-regional/datasets/laion-aesthetics-12m+mscoco-2017.parquet"
-             # "ChristophSchuhmann/MS_COCO_2017_URL_TEXT",
-             # "dclure/laion-aesthetics-12m-umap",
              "gs://flaxdiff-datasets-regional/datasets/laion-aesthetics-12m+mscoco-2017",
              "gs://flaxdiff-datasets-regional/datasets/coyo700m-aesthetic-5.4_25M",
              "gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
@@ -65,7 +71,56 @@ onlineDatasetMap = {
              "gs://flaxdiff-datasets-regional/datasets/cc3m",
              "gs://flaxdiff-datasets-regional/datasets/cc3m",
              "gs://flaxdiff-datasets-regional/datasets/laion2B-en-aesthetic-4.2_37M",
-             # "gs://flaxdiff-datasets-regional/datasets/laiion400m-185M"
          ]
      }
+ }
+
+ # ---------------------------------------------------------------------------------
+ # New media datasets configuration with the unified architecture
+ # ---------------------------------------------------------------------------------
+
+ mediaDatasetMap = {
+     # Image datasets
+     "oxford_flowers102": MediaDataset(
+         source=ImageTFDSSource(name="oxford_flowers102", use_tf=False),
+         augmenter=ImageTFDSAugmenter(),
+         media_type="image"
+     ),
+     "cc12m": MediaDataset(
+         source=ImageGCSSource(source='arrayrecord2/cc12m'),
+         augmenter=ImageGCSAugmenter(),
+         media_type="image"
+     ),
+     "laiona_coco": MediaDataset(
+         source=ImageGCSSource(source='arrayrecord2/laion-aesthetics-12m+mscoco-2017'),
+         augmenter=ImageGCSAugmenter(),
+         media_type="image"
+     ),
+     "combined_aesthetic": MediaDataset(
+         source=CombinedImageGCSSource(sources=[
+             'arrayrecord2/laion-aesthetics-12m+mscoco-2017',
+             'arrayrecords/aestheticCoyo_0.25clip_6aesthetic',
+             'arrayrecord2/cc12m',
+             'arrayrecords/aestheticCoyo_0.25clip_6aesthetic',
+         ]),
+         augmenter=ImageGCSAugmenter(),
+         media_type="image"
+     ),
+     "combined_30m": MediaDataset(
+         source=CombinedImageGCSSource(sources=[
+             'arrayrecord2/laion-aesthetics-12m+mscoco-2017',
+             'arrayrecord2/cc12m',
+             'arrayrecord2/aestheticCoyo_0.26_clip_5.5aesthetic_256plus',
+             "arrayrecord2/playground+leonardo_x4+cc3m.parquet",
+         ]),
+         augmenter=ImageGCSAugmenter(),
+         media_type="image"
+     ),
+
+     # Video dataset
+     "voxceleb2": MediaDataset(
+         source=VideoLocalSource(),
+         augmenter=AudioVideoAugmenter(),
+         media_type="video"
+     ),
  }
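
Since mediaDatasetMap is a plain dict, additional datasets can be registered by constructing a MediaDataset, mirroring the entries added above. A sketch; the "my_dataset" key and array-record path are placeholders, not part of this release:

    from flaxdiff.data.dataset_map import mediaDatasetMap
    from flaxdiff.data.sources.base import MediaDataset
    from flaxdiff.data.sources.images import ImageGCSSource, ImageGCSAugmenter

    # Hypothetical array-record location; substitute a real mounted dataset.
    mediaDatasetMap["my_dataset"] = MediaDataset(
        source=ImageGCSSource(source='arrayrecord2/my-dataset'),
        augmenter=ImageGCSAugmenter(),
        media_type="image",
    )

    # The new entry is then addressable through the unified loader,
    # e.g. get_media_dataset_grain("my_dataset", batch_size=64).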