flaxdiff 0.1.38.1__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/data/__init__.py +5 -1
- flaxdiff/data/benchmark_decord.py +443 -0
- flaxdiff/data/dataloaders.py +608 -0
- flaxdiff/data/dataset_map.py +61 -6
- flaxdiff/data/online_loader.py +779 -150
- flaxdiff/data/sources/audio_utils.py +142 -0
- flaxdiff/data/sources/av_example.py +125 -0
- flaxdiff/data/sources/av_utils.py +590 -0
- flaxdiff/data/sources/base.py +129 -0
- flaxdiff/data/sources/images.py +309 -0
- flaxdiff/data/sources/utils.py +158 -0
- flaxdiff/data/sources/videos.py +250 -0
- flaxdiff/data/sources/voxceleb2.py +412 -0
- flaxdiff/inference/__init__.py +0 -0
- flaxdiff/inference/pipeline.py +260 -0
- flaxdiff/inference/utils.py +320 -0
- flaxdiff/inputs/__init__.py +173 -0
- flaxdiff/inputs/encoders.py +98 -0
- flaxdiff/models/__init__.py +2 -1
- flaxdiff/models/autoencoder/autoencoder.py +141 -9
- flaxdiff/models/autoencoder/diffusers.py +88 -25
- flaxdiff/models/autoencoder/simple_autoenc.py +40 -8
- flaxdiff/models/simple_unet.py +5 -5
- flaxdiff/models/simple_vit.py +1 -1
- flaxdiff/models/unet_3d.py +446 -0
- flaxdiff/models/unet_3d_blocks.py +505 -0
- flaxdiff/samplers/common.py +358 -96
- flaxdiff/samplers/ddim.py +44 -5
- flaxdiff/schedulers/karras.py +20 -12
- flaxdiff/trainer/__init__.py +2 -1
- flaxdiff/trainer/autoencoder_trainer.py +1 -2
- flaxdiff/trainer/diffusion_trainer.py +33 -27
- flaxdiff/trainer/general_diffusion_trainer.py +583 -0
- flaxdiff/trainer/simple_trainer.py +48 -31
- flaxdiff/utils.py +128 -57
- {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/METADATA +1 -1
- flaxdiff-0.2.0.dist-info/RECORD +64 -0
- {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/WHEEL +1 -1
- flaxdiff/data/datasets.py +0 -169
- flaxdiff/data/sources/gcs.py +0 -81
- flaxdiff/data/sources/tfds.py +0 -79
- flaxdiff/trainer/video_diffusion_trainer.py +0 -62
- flaxdiff-0.1.38.1.dist-info/RECORD +0 -50
- {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/top_level.txt +0 -0
flaxdiff/data/dataloaders.py
ADDED
@@ -0,0 +1,608 @@
+import jax.numpy as jnp
+import grain.python as pygrain
+from typing import Dict, Any, Optional, Union, List, Callable
+import numpy as np
+import jax
+import cv2  # Added missing import
+from flaxdiff.utils import convert_to_global_tree, AutoTextTokenizer
+from .dataset_map import datasetMap, onlineDatasetMap, mediaDatasetMap
+import traceback
+from .online_loader import OnlineStreamingDataLoader
+import queue
+from jax.sharding import Mesh
+import threading
+from functools import partial
+
+
+def batch_mesh_map(mesh):
+    """Create an augmenter that maps batches to a mesh."""
+    class augmenters(pygrain.MapTransform):
+        def __init__(self, *args, **kwargs):
+            super().__init__(*args, **kwargs)
+
+        def map(self, batch) -> Dict[str, jnp.array]:
+            return convert_to_global_tree(mesh, batch)
+    return augmenters
+
+
+class DataLoaderWithMesh:
+    """A wrapper for data loaders that distributes data to a JAX mesh.
+
+    This class wraps any iterable dataset and maps the data to a JAX mesh.
+    It runs a background thread that fetches data from the loader and
+    distributes it to the mesh.
+    """
+
+    def __init__(self, dataloader, mesh, buffer_size=20):
+        """Initialize a DataLoaderWithMesh.
+
+        Args:
+            dataloader: The data loader to wrap.
+            mesh: The JAX mesh to distribute data to.
+            buffer_size: Size of the prefetch buffer.
+        """
+        self.dataloader = dataloader
+        self.mesh = mesh
+        self.buffer_size = buffer_size
+        self.tmp_queue = queue.Queue(buffer_size)
+        self.loader_thread = None
+        self._start_loader_thread()
+
+    def _start_loader_thread(self):
+        """Start the background thread for data loading."""
+        def batch_loader():
+            try:
+                for batch in self.dataloader:
+                    try:
+                        self.tmp_queue.put(convert_to_global_tree(self.mesh, batch))
+                    except Exception as e:
+                        print("Error processing batch", e)
+                        traceback.print_exc()
+            except Exception as e:
+                print("Error in batch loader thread", e)
+                traceback.print_exc()
+
+        self.loader_thread = threading.Thread(target=batch_loader, daemon=True)
+        self.loader_thread.start()
+
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        try:
+            return self.tmp_queue.get(timeout=60)  # Add timeout to prevent hanging
+        except queue.Empty:
+            if not self.loader_thread.is_alive():
+                raise StopIteration("Loader thread died")
+            raise queue.Empty("Timed out waiting for batch")
+
+    def __del__(self):
+        # Clean up resources
+        if hasattr(self, 'loader_thread') and self.loader_thread is not None:
+            self.loader_thread.join(timeout=1)
+
+
+def generate_collate_fn(media_type="image"):
+    """Generate a collate function based on media type.
+
+    Args:
+        media_type: Type of media ("image" or "video").
+
+    Returns:
+        A collate function for the specified media type.
+    """
+    auto_tokenize = AutoTextTokenizer(tensor_type="np")
+
+    def image_collate(batch):
+        try:
+            # Check if batch is valid
+            if not batch or len(batch) == 0:
+                print("Warning: Empty batch received")
+                # Return an empty batch with the correct structure
+                return {
+                    "image": np.zeros((0, 0, 0, 3), dtype=np.float32),
+                    "text": {
+                        "input_ids": np.zeros((0, 0), dtype=np.int32),
+                        "attention_mask": np.zeros((0, 0), dtype=np.int32),
+                    }
+                }
+
+            captions = [sample.get("caption", "") for sample in batch]
+            results = auto_tokenize(captions)
+
+            # Check if all images have the same shape
+            image_shapes = [sample["image"].shape for sample in batch]
+            if len(set(str(shape) for shape in image_shapes)) > 1:
+                # Different shapes, need to resize all to the same (height, width)
+                target_shape = max(shape[0] for shape in image_shapes), max(shape[1] for shape in image_shapes)
+                images = np.stack([
+                    cv2.resize(sample["image"], target_shape[::-1]) if sample["image"].shape[:2] != target_shape else sample["image"]  # cv2.resize expects (width, height)
+                    for sample in batch
+                ], axis=0)
+            else:
+                # All same shape, can just stack
+                images = np.stack([sample["image"] for sample in batch], axis=0)
+
+            return {
+                "image": images,
+                "text": {
+                    "input_ids": results['input_ids'],
+                    "attention_mask": results['attention_mask'],
+                }
+            }
+        except Exception as e:
+            print("Error in image collate function", e)
+            traceback.print_exc()
+            # Return a fallback batch
+            return fallback_batch(batch, media_type="image")
+
+    def video_collate(batch):
+        try:
+            # Check if batch is valid
+            if not batch or len(batch) == 0:
+                print("Warning: Empty batch received")
+                # Return an empty batch with the correct structure
+                return {
+                    "video": np.zeros((0, 0, 0, 0, 3), dtype=np.float32),
+                    "text": {
+                        "input_ids": np.zeros((0, 0), dtype=np.int32),
+                        "attention_mask": np.zeros((0, 0), dtype=np.int32),
+                    }
+                }
+
+            captions = [sample.get("caption", "") for sample in batch]
+            results = auto_tokenize(captions)
+
+            # Check if all videos have the same shape
+            video_shapes = [sample["video"].shape for sample in batch]
+            if len(set(str(shape) for shape in video_shapes)) > 1:
+                # Get max dimensions
+                max_frames = max(shape[0] for shape in video_shapes)
+                max_height = max(shape[1] for shape in video_shapes)
+                max_width = max(shape[2] for shape in video_shapes)
+
+                # Resize videos to the same shape
+                videos = []
+                for sample in batch:
+                    video = sample["video"]
+                    num_frames, height, width = video.shape[:3]
+
+                    if height != max_height or width != max_width:
+                        # Resize each frame
+                        resized_frames = np.array([
+                            cv2.resize(frame, (max_width, max_height))
+                            for frame in video
+                        ])
+                        video = resized_frames
+
+                    if num_frames < max_frames:
+                        # Pad with duplicates of the last frame
+                        padding = np.tile(video[-1:], (max_frames - num_frames, 1, 1, 1))
+                        video = np.concatenate([video, padding], axis=0)
+
+                    videos.append(video)
+
+                videos = np.stack(videos, axis=0)
+            else:
+                # All videos have the same shape, can just stack
+                videos = np.stack([sample["video"] for sample in batch], axis=0)
+
+            return {
+                "video": videos,
+                "text": {
+                    "input_ids": results['input_ids'],
+                    "attention_mask": results['attention_mask'],
+                }
+            }
+        except Exception as e:
+            print("Error in video collate function", e)
+            traceback.print_exc()
+            # Return a fallback batch
+            return fallback_batch(batch, media_type="video")
+
+    def fallback_batch(batch, media_type="image"):
+        """Create a fallback batch when an error occurs."""
+        try:
+            batch_size = len(batch) if batch else 1
+            if media_type == "video":
+                # Create a small valid video batch
+                dummy_video = np.zeros((batch_size, 4, 32, 32, 3), dtype=np.uint8)
+                dummy_text = auto_tokenize(["Error processing video"] * batch_size)
+                return {
+                    "video": dummy_video,
+                    "text": {
+                        "input_ids": dummy_text['input_ids'],
+                        "attention_mask": dummy_text['attention_mask'],
+                    }
+                }
+            else:
+                # Create a small valid image batch
+                dummy_image = np.zeros((batch_size, 32, 32, 3), dtype=np.uint8)
+                dummy_text = auto_tokenize(["Error processing image"] * batch_size)
+                return {
+                    "image": dummy_image,
+                    "text": {
+                        "input_ids": dummy_text['input_ids'],
+                        "attention_mask": dummy_text['attention_mask'],
+                    }
+                }
+        except Exception as e:
+            print("Error creating fallback batch", e)
+            # Last resort fallback
+            if media_type == "video":
+                return {
+                    "video": np.zeros((1, 4, 32, 32, 3), dtype=np.uint8),
+                    "text": {
+                        "input_ids": np.zeros((1, 16), dtype=np.int32),
+                        "attention_mask": np.zeros((1, 16), dtype=np.int32),
+                    }
+                }
+            else:
+                return {
+                    "image": np.zeros((1, 32, 32, 3), dtype=np.uint8),
+                    "text": {
+                        "input_ids": np.zeros((1, 16), dtype=np.int32),
+                        "attention_mask": np.zeros((1, 16), dtype=np.int32),
+                    }
+                }
+
+    if media_type == "video":
+        return video_collate
+    else:  # Default to image
+        return image_collate
+
+
+def get_dataset_grain(
+    data_name="cc12m",
+    batch_size=64,
+    image_scale=256,
+    count=None,
+    num_epochs=None,
+    method=jax.image.ResizeMethod.LANCZOS3,
+    worker_count=32,
+    read_thread_count=64,
+    read_buffer_size=50,
+    worker_buffer_size=20,
+    seed=0,
+    dataset_source="/mnt/gcs_mount/arrayrecord2/cc12m/",
+):
+    """Legacy function for getting grain dataset loaders for images.
+
+    Args:
+        data_name: Name of the dataset in datasetMap.
+        batch_size: Batch size for the dataset.
+        image_scale: Size to scale images to.
+        count: Optional count limit for the dataset.
+        num_epochs: Number of epochs to iterate.
+        method: Interpolation method for resizing.
+        worker_count: Number of worker processes.
+        read_thread_count: Number of read threads.
+        read_buffer_size: Size of the read buffer.
+        worker_buffer_size: Size of the worker buffer.
+        seed: Random seed.
+        dataset_source: Source path for the dataset.
+
+    Returns:
+        Dictionary with train dataset function and metadata.
+    """
+    dataset = datasetMap[data_name]
+    data_source = dataset["source"](dataset_source)
+    augmenter = dataset["augmenter"](image_scale, method)
+
+    local_batch_size = batch_size // jax.process_count()
+
+    sampler = pygrain.IndexSampler(
+        num_records=len(data_source) if count is None else count,
+        shuffle=True,
+        seed=seed,
+        num_epochs=num_epochs,
+        shard_options=pygrain.ShardByJaxProcess(),
+    )
+
+    def get_trainset():
+        transformations = [
+            augmenter(),
+            pygrain.Batch(local_batch_size, drop_remainder=True),
+        ]
+
+        loader = pygrain.DataLoader(
+            data_source=data_source,
+            sampler=sampler,
+            operations=transformations,
+            worker_count=worker_count,
+            read_options=pygrain.ReadOptions(
+                read_thread_count, read_buffer_size
+            ),
+            worker_buffer_size=worker_buffer_size,
+        )
+        return loader
+
+    return {
+        "train": get_trainset,
+        "train_len": len(data_source),
+        "local_batch_size": local_batch_size,
+        "global_batch_size": batch_size,
+    }
+
+
+def get_dataset_online(
+    data_name="combined_online",
+    batch_size=64,
+    image_scale=256,
+    count=None,
+    num_epochs=None,
+    method=jax.image.ResizeMethod.LANCZOS3,
+    worker_count=32,
+    read_thread_count=64,
+    read_buffer_size=50,
+    worker_buffer_size=20,
+    seed=0,
+    dataset_source="/mnt/gcs_mount/arrayrecord2/cc12m/",
+):
+    """Legacy function for getting online streaming dataloader for images.
+
+    Args:
+        data_name: Name of the dataset in onlineDatasetMap.
+        batch_size: Batch size for the dataset.
+        image_scale: Size to scale images to.
+        count: Optional count limit for the dataset.
+        num_epochs: Number of epochs to iterate.
+        method: Interpolation method for resizing.
+        worker_count: Number of worker processes.
+        read_thread_count: Number of read threads.
+        read_buffer_size: Size of the read buffer.
+        worker_buffer_size: Size of the worker buffer.
+        seed: Random seed.
+        dataset_source: Source path for the dataset.
+
+    Returns:
+        Dictionary with train dataset function and metadata.
+    """
+    local_batch_size = batch_size // jax.process_count()
+
+    sources = onlineDatasetMap[data_name]["source"]
+    dataloader = OnlineStreamingDataLoader(
+        sources,
+        batch_size=local_batch_size,
+        num_workers=worker_count,
+        num_threads=read_thread_count,
+        image_shape=(image_scale, image_scale),
+        global_process_count=jax.process_count(),
+        global_process_index=jax.process_index(),
+        prefetch=worker_buffer_size,
+        collate_fn=generate_collate_fn(),
+        default_split="train",
+    )
+
+    def get_trainset(mesh: Mesh = None):
+        if mesh is not None:
+            return DataLoaderWithMesh(dataloader, mesh, buffer_size=worker_buffer_size)
+        return dataloader
+
+    return {
+        "train": get_trainset,
+        "train_len": len(dataloader) * jax.process_count(),
+        "local_batch_size": local_batch_size,
+        "global_batch_size": batch_size,
+    }
+
+
+# ---------------------------------------------------------------------------------
+# New unified dataset loader for both images and videos
+# ---------------------------------------------------------------------------------
+
+def get_media_dataset_grain(
+    data_name: str,
+    batch_size: int = 64,
+    media_scale: int = 256,
+    sequence_length: int = 1,
+    count: Optional[int] = None,
+    num_epochs: Optional[int] = None,
+    method: Any = cv2.INTER_AREA,
+    worker_count: int = 32,
+    read_thread_count: int = 64,
+    read_buffer_size: int = 50,
+    worker_buffer_size: int = 20,
+    seed: int = 0,
+    dataset_source: str = None,
+    media_type: Optional[str] = None,  # Will be auto-detected if None
+    mesh: Optional[Mesh] = None,
+    additional_transform_kwargs: Dict[str, Any] = None,
+):
+    """Get a grain dataset loader for any media type (image or video).
+
+    Args:
+        data_name: Name of the dataset in mediaDatasetMap.
+        batch_size: Batch size for the dataset.
+        media_scale: Size to scale media (image or video frames) to.
+        sequence_length: Length of the sequence for video data.
+        count: Optional count limit for the dataset.
+        num_epochs: Number of epochs to iterate.
+        method: Interpolation method for resizing.
+        worker_count: Number of worker processes.
+        read_thread_count: Number of read threads.
+        read_buffer_size: Size of the read buffer.
+        worker_buffer_size: Size of the worker buffer.
+        seed: Random seed.
+        dataset_source: Source path for the dataset.
+        media_type: Type of media ("image" or "video"). Auto-detected if None.
+        mesh: Optional JAX mesh for distributed training.
+        additional_transform_kwargs: Additional arguments for the transform.
+
+    Returns:
+        Dictionary with train dataset function and metadata.
+    """
+    if data_name not in mediaDatasetMap:
+        raise ValueError(f"Dataset {data_name} not found in mediaDatasetMap")
+
+    media_dataset = mediaDatasetMap[data_name]
+
+    # Auto-detect media_type if not provided
+    if media_type is None:
+        media_type = media_dataset.media_type
+
+    # Get the data source and augmenter
+    data_source = media_dataset.get_source(dataset_source)
+
+    # Prepare transform kwargs
+    transform_kwargs = {
+        "image_scale" if media_type == "image" else "frame_size": media_scale,
+        "method": method,
+        "sequence_length": sequence_length,
+    }
+    if additional_transform_kwargs:
+        transform_kwargs.update(additional_transform_kwargs)
+
+    augmenter = media_dataset.get_augmenter(**transform_kwargs)
+
+    # Calculate local batch size for distributed training
+    local_batch_size = batch_size // jax.process_count()
+
+    # Create a sampler for the dataset
+    if hasattr(data_source, "__len__"):
+        dataset_length = len(data_source) if count is None else count
+    else:
+        # Some data sources like video files list don't have __len__
+        dataset_length = count if count is not None else 1000000  # Default large number
+
+    sampler = pygrain.IndexSampler(
+        num_records=dataset_length,
+        shuffle=True,
+        seed=seed,
+        num_epochs=num_epochs,
+        shard_options=pygrain.ShardByJaxProcess(),
+    )
+
+    def get_trainset(mesh_override: Optional[Mesh] = None):
+        """Get a training dataset iterator.
+
+        Args:
+            mesh_override: Optional mesh to override the default.
+
+        Returns:
+            A dataset iterator.
+        """
+        current_mesh = mesh_override or mesh
+
+        transformations = [
+            augmenter(),
+            pygrain.Batch(local_batch_size, drop_remainder=True),
+        ]
+
+        # # Add mesh mapping if needed
+        # if current_mesh is not None:
+        #     transformations.append(batch_mesh_map(current_mesh)())
+
+        loader = pygrain.DataLoader(
+            data_source=data_source,
+            sampler=sampler,
+            operations=transformations,
+            worker_count=worker_count,
+            read_options=pygrain.ReadOptions(
+                read_thread_count, read_buffer_size
+            ),
+            worker_buffer_size=worker_buffer_size,
+        )
+        return loader
+
+    return {
+        "train": get_trainset,
+        "train_len": dataset_length,
+        "local_batch_size": local_batch_size,
+        "global_batch_size": batch_size,
+        "media_type": media_type,
+    }
+
+
+def get_media_dataset_online(
+    data_name: str = "combined_online",
+    batch_size: int = 64,
+    media_scale: int = 256,
+    worker_count: int = 16,
+    read_thread_count: int = 512,
+    worker_buffer_size: int = 20,
+    dataset_sources: List[str] = None,
+    media_type: str = "image",  # Default to image for online datasets
+    mesh: Optional[Mesh] = None,
+    timeout: int = 15,
+    retries: int = 3,
+    min_media_scale: int = 128,
+):
+    """Get an online streaming dataset loader for any media type.
+
+    Args:
+        data_name: Name of the dataset in onlineDatasetMap, or "custom" for custom sources.
+        batch_size: Batch size for the dataset.
+        media_scale: Size to scale media (image or video frames) to.
+        worker_count: Number of worker processes.
+        read_thread_count: Number of read threads.
+        worker_buffer_size: Size of the worker buffer.
+        dataset_sources: Custom dataset sources if data_name is "custom".
+        media_type: Type of media ("image" or "video").
+        mesh: Optional JAX mesh for distributed training.
+        timeout: Timeout for dataset operations.
+        retries: Number of retries for dataset operations.
+        min_media_scale: Minimum scale for media items.
+
+    Returns:
+        Dictionary with train dataset function and metadata.
+    """
+    local_batch_size = batch_size // jax.process_count()
+
+    # Get dataset sources
+    if dataset_sources is None:
+        if data_name not in onlineDatasetMap:
+            raise ValueError(f"Dataset {data_name} not found in onlineDatasetMap")
+        sources = onlineDatasetMap[data_name]["source"]
+    else:
+        sources = dataset_sources
+
+    # Configure shape parameter based on media type
+    shape_param = "image_shape" if media_type == "image" else "frame_size"
+    shape_value = (media_scale, media_scale) if media_type == "image" else media_scale
+
+    # Configure min scale parameter based on media type
+    min_scale_param = "min_image_shape" if media_type == "image" else "min_frame_size"
+    min_scale_value = (min_media_scale, min_media_scale) if media_type == "image" else min_media_scale
+
+    # Prepare dataloader kwargs
+    dataloader_kwargs = {
+        "batch_size": local_batch_size,
+        "num_workers": worker_count,
+        "num_threads": read_thread_count,
+        shape_param: shape_value,
+        min_scale_param: min_scale_value,
+        "global_process_count": jax.process_count(),
+        "global_process_index": jax.process_index(),
+        "prefetch": worker_buffer_size,
+        "collate_fn": generate_collate_fn(media_type),
+        "default_split": "train",
+        "timeout": timeout,
+        "retries": retries,
+    }
+
+    dataloader = OnlineStreamingDataLoader(sources, **dataloader_kwargs)
+
+    def get_trainset(mesh_override: Optional[Mesh] = None):
+        """Get a training dataset iterator.
+
+        Args:
+            mesh_override: Optional mesh to override the default.
+
+        Returns:
+            A dataset iterator.
+        """
+        current_mesh = mesh_override or mesh
+
+        if current_mesh is not None:
+            return DataLoaderWithMesh(dataloader, current_mesh, buffer_size=worker_buffer_size)
+
+        return dataloader
+
+    return {
+        "train": get_trainset,
+        "train_len": len(dataloader) * jax.process_count(),
+        "local_batch_size": local_batch_size,
+        "global_batch_size": batch_size,
+        "media_type": media_type,
+    }
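The new loaders are consumed through the dict they return, whose `"train"` key is a factory for the actual iterator. Below is a minimal usage sketch, not part of the diff: it assumes the default GCS sources are reachable, the batch keys follow the collate functions above, and the mesh axis name `"data"` is illustrative.

```python
# Hypothetical usage sketch -- not part of the package diff above.
import numpy as np
import jax
from jax.sharding import Mesh
from flaxdiff.data.dataloaders import get_media_dataset_online

data = get_media_dataset_online(
    data_name="combined_online",
    batch_size=64,
    media_scale=256,
    media_type="image",
)

# Passing a mesh wraps the loader in DataLoaderWithMesh, which prefetches
# batches on a daemon thread and converts them to a global device tree.
mesh = Mesh(np.array(jax.devices()), ("data",))
loader = data["train"](mesh_override=mesh)

batch = next(iter(loader))
print(batch["image"].shape)              # (local_batch, H, W, 3)
print(batch["text"]["input_ids"].shape)  # (local_batch, seq_len)
```

Note that `train_len` reports a global count (`len(dataloader) * jax.process_count()`), while each batch arrives at `local_batch_size`.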
flaxdiff/data/dataset_map.py
CHANGED
@@ -1,5 +1,14 @@
-from .sources.
-from .sources.
+from .sources.base import MediaDataset, DataSource, DataAugmenter
+from .sources.images import ImageTFDSSource, ImageGCSSource, CombinedImageGCSSource
+from .sources.images import ImageTFDSAugmenter, ImageGCSAugmenter
+from .sources.videos import VideoTFDSSource, VideoLocalSource, AudioVideoAugmenter
+
+# ---------------------------------------------------------------------------------
+# Legacy compatibility mappings
+# ---------------------------------------------------------------------------------
+
+from .sources.images import data_source_tfds, tfds_augmenters, data_source_gcs
+from .sources.images import data_source_combined_gcs, gcs_augmenters
 
 # Configure the following for your datasets
 datasetMap = {
@@ -50,9 +59,6 @@ datasetMap = {
 onlineDatasetMap = {
     "combined_online": {
         "source": [
-            # "gs://flaxdiff-datasets-regional/datasets/laion-aesthetics-12m+mscoco-2017.parquet"
-            # "ChristophSchuhmann/MS_COCO_2017_URL_TEXT",
-            # "dclure/laion-aesthetics-12m-umap",
             "gs://flaxdiff-datasets-regional/datasets/laion-aesthetics-12m+mscoco-2017",
             "gs://flaxdiff-datasets-regional/datasets/coyo700m-aesthetic-5.4_25M",
             "gs://flaxdiff-datasets-regional/datasets/leonardo-liked-1.8m",
@@ -65,7 +71,56 @@ onlineDatasetMap = {
             "gs://flaxdiff-datasets-regional/datasets/cc3m",
             "gs://flaxdiff-datasets-regional/datasets/cc3m",
             "gs://flaxdiff-datasets-regional/datasets/laion2B-en-aesthetic-4.2_37M",
-            # "gs://flaxdiff-datasets-regional/datasets/laiion400m-185M"
         ]
     }
+}
+
+# ---------------------------------------------------------------------------------
+# New media datasets configuration with the unified architecture
+# ---------------------------------------------------------------------------------
+
+mediaDatasetMap = {
+    # Image datasets
+    "oxford_flowers102": MediaDataset(
+        source=ImageTFDSSource(name="oxford_flowers102", use_tf=False),
+        augmenter=ImageTFDSAugmenter(),
+        media_type="image"
+    ),
+    "cc12m": MediaDataset(
+        source=ImageGCSSource(source='arrayrecord2/cc12m'),
+        augmenter=ImageGCSAugmenter(),
+        media_type="image"
+    ),
+    "laiona_coco": MediaDataset(
+        source=ImageGCSSource(source='arrayrecord2/laion-aesthetics-12m+mscoco-2017'),
+        augmenter=ImageGCSAugmenter(),
+        media_type="image"
+    ),
+    "combined_aesthetic": MediaDataset(
+        source=CombinedImageGCSSource(sources=[
+            'arrayrecord2/laion-aesthetics-12m+mscoco-2017',
+            'arrayrecords/aestheticCoyo_0.25clip_6aesthetic',
+            'arrayrecord2/cc12m',
+            'arrayrecords/aestheticCoyo_0.25clip_6aesthetic',
+        ]),
+        augmenter=ImageGCSAugmenter(),
+        media_type="image"
+    ),
+    "combined_30m": MediaDataset(
+        source=CombinedImageGCSSource(sources=[
+            'arrayrecord2/laion-aesthetics-12m+mscoco-2017',
+            'arrayrecord2/cc12m',
+            'arrayrecord2/aestheticCoyo_0.26_clip_5.5aesthetic_256plus',
+            "arrayrecord2/playground+leonardo_x4+cc3m.parquet",
+        ]),
+        augmenter=ImageGCSAugmenter(),
+        media_type="image"
+    ),
+
+    # Video dataset
+    "voxceleb2": MediaDataset(
+        source=VideoLocalSource(),
+        augmenter=AudioVideoAugmenter(),
+        media_type="video"
+    ),
 }
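Because `mediaDatasetMap` values are ordinary `MediaDataset` records, registering another dataset is just one more entry. A hypothetical sketch mirroring the entries above; the bucket path `arrayrecord2/my_dataset` is a placeholder, not a real dataset:

```python
# Hypothetical registration sketch -- follows the pattern of the diff above.
from flaxdiff.data.dataset_map import mediaDatasetMap
from flaxdiff.data.sources.base import MediaDataset
from flaxdiff.data.sources.images import ImageGCSSource, ImageGCSAugmenter

mediaDatasetMap["my_dataset"] = MediaDataset(
    source=ImageGCSSource(source="arrayrecord2/my_dataset"),  # placeholder path
    augmenter=ImageGCSAugmenter(),
    media_type="image",
)

# The unified loader can then resolve it by name:
# get_media_dataset_grain("my_dataset", batch_size=64, media_scale=256)
```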