flaxdiff 0.2.7__py3-none-any.whl → 0.2.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/data/dataloaders.py +36 -24
- flaxdiff/data/dataset_map.py +2 -2
- flaxdiff/data/sources/base.py +12 -0
- flaxdiff/data/sources/images.py +68 -11
- flaxdiff/data/sources/videos.py +5 -0
- flaxdiff/models/common.py +1 -70
- flaxdiff/models/hilbert.py +617 -0
- flaxdiff/models/simple_dit.py +476 -0
- flaxdiff/models/simple_mmdit.py +861 -0
- flaxdiff/models/simple_vit.py +278 -117
- flaxdiff/trainer/general_diffusion_trainer.py +29 -10
- flaxdiff/trainer/simple_trainer.py +113 -19
- {flaxdiff-0.2.7.dist-info → flaxdiff-0.2.8.dist-info}/METADATA +1 -1
- {flaxdiff-0.2.7.dist-info → flaxdiff-0.2.8.dist-info}/RECORD +16 -14
- {flaxdiff-0.2.7.dist-info → flaxdiff-0.2.8.dist-info}/WHEEL +1 -1
- flaxdiff/models/better_uvit.py +0 -380
- {flaxdiff-0.2.7.dist-info → flaxdiff-0.2.8.dist-info}/top_level.txt +0 -0
flaxdiff/data/dataloaders.py
CHANGED
@@ -251,6 +251,12 @@ def generate_collate_fn(media_type="image"):
|
|
251
251
|
else: # Default to image
|
252
252
|
return image_collate
|
253
253
|
|
254
|
+
class CaptionDeletionTransform(pygrain.MapTransform):
|
255
|
+
def map(self, element):
|
256
|
+
"""Delete the caption from the element."""
|
257
|
+
if "caption" in element:
|
258
|
+
del element["caption"]
|
259
|
+
return element
|
254
260
|
|
255
261
|
def get_dataset_grain(
|
256
262
|
data_name="cc12m",
|
@@ -286,13 +292,14 @@ def get_dataset_grain(
|
|
286
292
|
Dictionary with train dataset function and metadata.
|
287
293
|
"""
|
288
294
|
dataset = datasetMap[data_name]
|
289
|
-
|
295
|
+
train_source = dataset["source"](dataset_source, split="train")
|
296
|
+
# val_source = dataset["source"](dataset_source, split="val")
|
290
297
|
augmenter = dataset["augmenter"](image_scale, method)
|
291
298
|
|
292
299
|
local_batch_size = batch_size // jax.process_count()
|
293
300
|
|
294
301
|
train_sampler = pygrain.IndexSampler(
|
295
|
-
num_records=len(
|
302
|
+
num_records=len(train_source) if count is None else count,
|
296
303
|
shuffle=True,
|
297
304
|
seed=seed,
|
298
305
|
num_epochs=num_epochs,
|
@@ -300,7 +307,7 @@ def get_dataset_grain(
|
|
300
307
|
)
|
301
308
|
|
302
309
|
# val_sampler = pygrain.IndexSampler(
|
303
|
-
# num_records=len(
|
310
|
+
# num_records=len(val_source) if count is None else count,
|
304
311
|
# shuffle=False,
|
305
312
|
# seed=seed,
|
306
313
|
# num_epochs=num_epochs,
|
@@ -310,11 +317,17 @@ def get_dataset_grain(
|
|
310
317
|
def get_trainset():
|
311
318
|
transformations = [
|
312
319
|
augmenter(),
|
313
|
-
pygrain.Batch(local_batch_size, drop_remainder=True),
|
314
320
|
]
|
321
|
+
|
322
|
+
# if filters:
|
323
|
+
# print("Adding filters to transformations")
|
324
|
+
# transformations.append(filters())
|
325
|
+
|
326
|
+
# transformations.append(CaptionDeletionTransform())
|
327
|
+
transformations.append(pygrain.Batch(local_batch_size, drop_remainder=True))
|
315
328
|
|
316
329
|
loader = pygrain.DataLoader(
|
317
|
-
data_source=
|
330
|
+
data_source=train_source,
|
318
331
|
sampler=train_sampler,
|
319
332
|
operations=transformations,
|
320
333
|
worker_count=worker_count,
|
@@ -325,30 +338,29 @@ def get_dataset_grain(
|
|
325
338
|
)
|
326
339
|
return loader
|
327
340
|
|
328
|
-
|
329
|
-
|
330
|
-
|
331
|
-
|
332
|
-
|
341
|
+
def get_valset():
|
342
|
+
transformations = [
|
343
|
+
augmenter(),
|
344
|
+
pygrain.Batch(local_batch_size, drop_remainder=True),
|
345
|
+
]
|
333
346
|
|
334
|
-
|
335
|
-
|
336
|
-
|
337
|
-
|
338
|
-
|
339
|
-
|
340
|
-
|
341
|
-
|
342
|
-
|
343
|
-
|
344
|
-
|
345
|
-
get_valset = get_trainset # For now, use the same function for validation
|
347
|
+
loader = pygrain.DataLoader(
|
348
|
+
data_source=train_source,
|
349
|
+
sampler=train_sampler,
|
350
|
+
operations=transformations,
|
351
|
+
worker_count=2,
|
352
|
+
read_options=pygrain.ReadOptions(
|
353
|
+
read_thread_count, read_buffer_size
|
354
|
+
),
|
355
|
+
worker_buffer_size=2,
|
356
|
+
)
|
357
|
+
return loader
|
346
358
|
|
347
359
|
return {
|
348
360
|
"train": get_trainset,
|
349
|
-
"train_len": len(
|
361
|
+
"train_len": len(train_source),
|
350
362
|
"val": get_valset,
|
351
|
-
"val_len": len(
|
363
|
+
"val_len": len(train_source),
|
352
364
|
"local_batch_size": local_batch_size,
|
353
365
|
"global_batch_size": batch_size,
|
354
366
|
}
|
flaxdiff/data/dataset_map.py
CHANGED
@@ -8,7 +8,7 @@ from .sources.videos import VideoTFDSSource, VideoLocalSource, AudioVideoAugment
|
|
8
8
|
# ---------------------------------------------------------------------------------
|
9
9
|
|
10
10
|
from .sources.images import data_source_tfds, tfds_augmenters, data_source_gcs
|
11
|
-
from .sources.images import data_source_combined_gcs, gcs_augmenters
|
11
|
+
from .sources.images import data_source_combined_gcs, gcs_augmenters, gcs_filters
|
12
12
|
|
13
13
|
# Configure the following for your datasets
|
14
14
|
datasetMap = {
|
@@ -21,7 +21,7 @@ datasetMap = {
|
|
21
21
|
"augmenter": gcs_augmenters,
|
22
22
|
},
|
23
23
|
"laiona_coco": {
|
24
|
-
"source": data_source_gcs('datasets/laion12m+
|
24
|
+
"source": data_source_gcs('datasets/laion12m+mscoco_filtered-new'),
|
25
25
|
"augmenter": gcs_augmenters,
|
26
26
|
},
|
27
27
|
"aesthetic_coyo": {
|
flaxdiff/data/sources/base.py
CHANGED
@@ -62,6 +62,18 @@ class DataAugmenter(ABC):
|
|
62
62
|
"""
|
63
63
|
pass
|
64
64
|
|
65
|
+
@abstractmethod
|
66
|
+
def create_filter(self, **kwargs) -> Callable[[], pygrain.FilterTransform]:
|
67
|
+
"""Create a filter function for the data.
|
68
|
+
|
69
|
+
Args:
|
70
|
+
**kwargs: Additional arguments for the filter.
|
71
|
+
|
72
|
+
Returns:
|
73
|
+
A callable that returns a pygrain.FilterTransform instance.
|
74
|
+
"""
|
75
|
+
pass
|
76
|
+
|
65
77
|
@staticmethod
|
66
78
|
def create(augmenter_type: str, **kwargs) -> 'DataAugmenter':
|
67
79
|
"""Factory method to create a data augmenter of the specified type.
|
flaxdiff/data/sources/images.py
CHANGED
@@ -82,7 +82,7 @@ def labelizer_oxford_flowers102(path):
|
|
82
82
|
class ImageTFDSSource(DataSource):
|
83
83
|
"""Data source for TensorFlow Datasets (TFDS) image datasets."""
|
84
84
|
|
85
|
-
def __init__(self, name: str, use_tf: bool = True
|
85
|
+
def __init__(self, name: str, use_tf: bool = True):
|
86
86
|
"""Initialize a TFDS image data source.
|
87
87
|
|
88
88
|
Args:
|
@@ -92,9 +92,8 @@ class ImageTFDSSource(DataSource):
|
|
92
92
|
"""
|
93
93
|
self.name = name
|
94
94
|
self.use_tf = use_tf
|
95
|
-
self.split = split
|
96
95
|
|
97
|
-
def get_source(self, path_override: str) -> Any:
|
96
|
+
def get_source(self, path_override: str, split: str = "all") -> Any:
|
98
97
|
"""Get the TFDS data source.
|
99
98
|
|
100
99
|
Args:
|
@@ -105,9 +104,9 @@ class ImageTFDSSource(DataSource):
|
|
105
104
|
"""
|
106
105
|
import tensorflow_datasets as tfds
|
107
106
|
if self.use_tf:
|
108
|
-
return tfds.load(self.name, split=
|
107
|
+
return tfds.load(self.name, split=split, shuffle_files=True)
|
109
108
|
else:
|
110
|
-
return tfds.data_source(self.name, split=
|
109
|
+
return tfds.data_source(self.name, split=split, try_gcs=False)
|
111
110
|
|
112
111
|
|
113
112
|
class ImageTFDSAugmenter(DataAugmenter):
|
@@ -168,7 +167,11 @@ class ImageTFDSAugmenter(DataAugmenter):
|
|
168
167
|
}
|
169
168
|
|
170
169
|
return TFDSTransform
|
171
|
-
|
170
|
+
|
171
|
+
def create_filter(self, image_scale: int = 256):
|
172
|
+
class FilterTransform(pygrain.FilterTransform):
|
173
|
+
def map(self, element) -> bool:
|
174
|
+
return True
|
172
175
|
"""
|
173
176
|
Batch structure:
|
174
177
|
{
|
@@ -195,7 +198,7 @@ class ImageGCSSource(DataSource):
|
|
195
198
|
"""
|
196
199
|
self.source = source
|
197
200
|
|
198
|
-
def get_source(self, path_override: str = "/home/mrwhite0racle/gcs_mount") -> Any:
|
201
|
+
def get_source(self, path_override: str = "/home/mrwhite0racle/gcs_mount", split: str = "train") -> Any:
|
199
202
|
"""Get the GCS data source.
|
200
203
|
|
201
204
|
Args:
|
@@ -207,6 +210,8 @@ class ImageGCSSource(DataSource):
|
|
207
210
|
records_path = os.path.join(path_override, self.source)
|
208
211
|
records = [os.path.join(records_path, i) for i in os.listdir(
|
209
212
|
records_path) if 'array_record' in i]
|
213
|
+
if split == "val":
|
214
|
+
records = records[:1]
|
210
215
|
return pygrain.ArrayRecordDataSource(records)
|
211
216
|
|
212
217
|
|
@@ -221,7 +226,7 @@ class CombinedImageGCSSource(DataSource):
|
|
221
226
|
"""
|
222
227
|
self.sources = sources
|
223
228
|
|
224
|
-
def get_source(self, path_override: str = "/home/mrwhite0racle/gcs_mount") -> Any:
|
229
|
+
def get_source(self, path_override: str = "/home/mrwhite0racle/gcs_mount", split: str = "train") -> Any:
|
225
230
|
"""Get the combined GCS data source.
|
226
231
|
|
227
232
|
Args:
|
@@ -235,9 +240,10 @@ class CombinedImageGCSSource(DataSource):
|
|
235
240
|
for records_path in records_paths:
|
236
241
|
records += [os.path.join(records_path, i) for i in os.listdir(
|
237
242
|
records_path) if 'array_record' in i]
|
243
|
+
if split == "val":
|
244
|
+
records = records[:1]
|
238
245
|
return pygrain.ArrayRecordDataSource(records)
|
239
246
|
|
240
|
-
|
241
247
|
class ImageGCSAugmenter(DataAugmenter):
|
242
248
|
"""Augmenter for GCS image datasets."""
|
243
249
|
|
@@ -297,6 +303,52 @@ class ImageGCSAugmenter(DataAugmenter):
|
|
297
303
|
}
|
298
304
|
|
299
305
|
return GCSTransform
|
306
|
+
|
307
|
+
def create_filter(self, image_scale: int = 256):
|
308
|
+
import torch.nn.functional as F
|
309
|
+
class FilterTransform(pygrain.FilterTransform):
|
310
|
+
"""
|
311
|
+
Filter transform for GCS data source.
|
312
|
+
"""
|
313
|
+
def __init__(self, model=None, processor=None, method=cv2.INTER_AREA):
|
314
|
+
super().__init__()
|
315
|
+
self.image_scale = image_scale
|
316
|
+
if model is None:
|
317
|
+
from transformers import AutoProcessor, CLIPVisionModelWithProjection, FlaxCLIPModel, CLIPModel
|
318
|
+
model_name = "openai/clip-vit-base-patch32"
|
319
|
+
model = CLIPModel.from_pretrained(model_name)
|
320
|
+
processor = AutoProcessor.from_pretrained(model_name, use_fast=False)
|
321
|
+
self.method = method
|
322
|
+
self.model = model
|
323
|
+
self.processor = processor
|
324
|
+
|
325
|
+
# def _filter_(pixel_values, input_ids):
|
326
|
+
# image_embeds = self.model.get_image_features(pixel_values=pixel_values)
|
327
|
+
# text_embeds = self.model.get_text_features(input_ids=input_ids)
|
328
|
+
# image_embeds = image_embeds / jnp.linalg.norm(image_embeds, axis=-1, keepdims=True)
|
329
|
+
# text_embeds = text_embeds / jnp.linalg.norm(text_embeds, axis=-1, keepdims=True)
|
330
|
+
# similarity = jnp.sum(image_embeds * text_embeds, axis=-1)
|
331
|
+
# return jnp.all(similarity >= 0.25)
|
332
|
+
|
333
|
+
# self._filter_ = _filter_
|
334
|
+
|
335
|
+
def filter(self, data: Dict[str, Any]) -> bool:
|
336
|
+
images = [data['image']]
|
337
|
+
texts = [data['caption']]
|
338
|
+
inputs = self.processor(text=texts, images=images, return_tensors="pt", padding=True, truncation=True)
|
339
|
+
# result = self._filter_(
|
340
|
+
# pixel_values=inputs['pixel_values'],
|
341
|
+
# input_ids=inputs['input_ids']
|
342
|
+
# )
|
343
|
+
# return result
|
344
|
+
|
345
|
+
image_embeds = self.model.get_image_features(pixel_values=inputs['pixel_values'])
|
346
|
+
text_embeds = self.model.get_text_features(input_ids=inputs['input_ids'])
|
347
|
+
similarity = F.cosine_similarity(image_embeds, text_embeds)
|
348
|
+
# Filter out images with similarity less than 0.25
|
349
|
+
return similarity[0] >= 0.25
|
350
|
+
|
351
|
+
return FilterTransform
|
300
352
|
|
301
353
|
|
302
354
|
# ----------------------------------------------------------------------------------
|
@@ -305,9 +357,9 @@ class ImageGCSAugmenter(DataAugmenter):
|
|
305
357
|
|
306
358
|
# These functions maintain backward compatibility with existing code
|
307
359
|
|
308
|
-
def data_source_tfds(name, use_tf=True
|
360
|
+
def data_source_tfds(name, use_tf=True):
|
309
361
|
"""Legacy function for TFDS data sources."""
|
310
|
-
source = ImageTFDSSource(name=name, use_tf=use_tf
|
362
|
+
source = ImageTFDSSource(name=name, use_tf=use_tf)
|
311
363
|
return source.get_source
|
312
364
|
|
313
365
|
|
@@ -333,3 +385,8 @@ def gcs_augmenters(image_scale, method):
|
|
333
385
|
"""Legacy function for GCS augmenters."""
|
334
386
|
augmenter = ImageGCSAugmenter()
|
335
387
|
return augmenter.create_transform(image_scale=image_scale, method=method)
|
388
|
+
|
389
|
+
def gcs_filters(image_scale):
|
390
|
+
"""Legacy function for GCS Filters."""
|
391
|
+
augmenter = ImageGCSAugmenter()
|
392
|
+
return augmenter.create_filter(image_scale=image_scale)
|
flaxdiff/data/sources/videos.py
CHANGED
@@ -216,6 +216,11 @@ class AudioVideoAugmenter(DataAugmenter):
|
|
216
216
|
|
217
217
|
return AudioVideoTransform
|
218
218
|
|
219
|
+
|
220
|
+
def create_filter(self, image_scale: int = 256):
|
221
|
+
class FilterTransform(pygrain.FilterTransform):
|
222
|
+
def map(self, element) -> bool:
|
223
|
+
return True
|
219
224
|
|
220
225
|
# ----------------------------------------------------------------------------------
|
221
226
|
# Helper functions for video datasets
|
flaxdiff/models/common.py
CHANGED
@@ -335,73 +335,4 @@ class ResidualBlock(nn.Module):
|
|
335
335
|
|
336
336
|
out = jnp.concatenate([out, extra_features], axis=-1) if extra_features is not None else out
|
337
337
|
|
338
|
-
return out
|
339
|
-
|
340
|
-
# Convert Hilbert index d to 2D coordinates (x, y) for an n x n grid
|
341
|
-
def _d2xy(n, d):
|
342
|
-
x = 0
|
343
|
-
y = 0
|
344
|
-
t = d
|
345
|
-
s = 1
|
346
|
-
while s < n:
|
347
|
-
rx = (t // 2) & 1
|
348
|
-
ry = (t ^ rx) & 1
|
349
|
-
if ry == 0:
|
350
|
-
if rx == 1:
|
351
|
-
x = n - 1 - x
|
352
|
-
y = n - 1 - y
|
353
|
-
x, y = y, x
|
354
|
-
x += s * rx
|
355
|
-
y += s * ry
|
356
|
-
t //= 4
|
357
|
-
s *= 2
|
358
|
-
return x, y
|
359
|
-
|
360
|
-
# Hilbert index mapping for a rectangular grid of patches H_P x W_P
|
361
|
-
|
362
|
-
def hilbert_indices(H_P, W_P):
|
363
|
-
size = max(H_P, W_P)
|
364
|
-
order = math.ceil(math.log2(size))
|
365
|
-
n = 1 << order
|
366
|
-
coords = []
|
367
|
-
for d in range(n * n):
|
368
|
-
x, y = _d2xy(n, d)
|
369
|
-
# x is column index, y is row index
|
370
|
-
if x < W_P and y < H_P:
|
371
|
-
coords.append((y, x)) # (row, col)
|
372
|
-
if len(coords) == H_P * W_P:
|
373
|
-
break
|
374
|
-
# Convert (row, col) to linear indices row-major
|
375
|
-
indices = [r * W_P + c for r, c in coords]
|
376
|
-
return jnp.array(indices, dtype=jnp.int32)
|
377
|
-
|
378
|
-
# Inverse permutation: given idx where idx[i] = new position of element i, return inv such that inv[idx[i]] = i
|
379
|
-
|
380
|
-
def inverse_permutation(idx):
|
381
|
-
inv = jnp.zeros_like(idx)
|
382
|
-
inv = inv.at[idx].set(jnp.arange(idx.shape[0], dtype=idx.dtype))
|
383
|
-
return inv
|
384
|
-
|
385
|
-
# Patchify using Hilbert ordering: extract patches and reorder sequence
|
386
|
-
|
387
|
-
def hilbert_patchify(x, patch_size):
|
388
|
-
B, H, W, C = x.shape
|
389
|
-
H_P = H // patch_size
|
390
|
-
W_P = W // patch_size
|
391
|
-
# Extract patches in row-major
|
392
|
-
patches = rearrange(x, 'b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size)
|
393
|
-
idx = hilbert_indices(H_P, W_P)
|
394
|
-
return patches[:, idx, :]
|
395
|
-
|
396
|
-
# Unpatchify from Hilbert ordering: reorder sequence back and reconstruct image
|
397
|
-
|
398
|
-
def hilbert_unpatchify(patches, patch_size, H, W, C):
|
399
|
-
B, N, D = patches.shape
|
400
|
-
H_P = H // patch_size
|
401
|
-
W_P = W // patch_size
|
402
|
-
inv = inverse_permutation(hilbert_indices(H_P, W_P))
|
403
|
-
# Reorder back to row-major
|
404
|
-
linear = patches[:, inv, :]
|
405
|
-
# Reconstruct image
|
406
|
-
x = rearrange(linear, 'b (h w) (p1 p2 c) -> b (h p1) (w p2) c', h=H_P, w=W_P, p1=patch_size, p2=patch_size, c=C)
|
407
|
-
return x
|
338
|
+
return out
|