flaxdiff 0.2.7__py3-none-any.whl → 0.2.9__py3-none-any.whl

This diff compares the contents of two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
@@ -251,6 +251,12 @@ def generate_collate_fn(media_type="image"):
     else:  # Default to image
         return image_collate

+class CaptionDeletionTransform(pygrain.MapTransform):
+    def map(self, element):
+        """Delete the caption from the element."""
+        if "caption" in element:
+            del element["caption"]
+        return element

 def get_dataset_grain(
     data_name="cc12m",
@@ -288,7 +294,6 @@ def get_dataset_grain(
     dataset = datasetMap[data_name]
     data_source = dataset["source"](dataset_source)
     augmenter = dataset["augmenter"](image_scale, method)
-
     local_batch_size = batch_size // jax.process_count()

     train_sampler = pygrain.IndexSampler(
@@ -310,8 +315,8 @@ def get_dataset_grain(
     def get_trainset():
         transformations = [
             augmenter(),
-            pygrain.Batch(local_batch_size, drop_remainder=True),
         ]
+        transformations.append(pygrain.Batch(local_batch_size, drop_remainder=True))

         loader = pygrain.DataLoader(
             data_source=data_source,
@@ -325,24 +330,23 @@ def get_dataset_grain(
         )
         return loader

-    # def get_valset():
-    #     transformations = [
-    #         augmenter(),
-    #         pygrain.Batch(local_batch_size, drop_remainder=True),
-    #     ]
+    def get_valset():
+        transformations = [
+            augmenter(),
+            pygrain.Batch(32, drop_remainder=True),
+        ]

-    #     loader = pygrain.DataLoader(
-    #         data_source=data_source,
-    #         sampler=val_sampler,
-    #         operations=transformations,
-    #         worker_count=worker_count,
-    #         read_options=pygrain.ReadOptions(
-    #             read_thread_count, read_buffer_size
-    #         ),
-    #         worker_buffer_size=worker_buffer_size,
-    #     )
-    #     return loader
-    get_valset = get_trainset # For now, use the same function for validation
+        loader = pygrain.DataLoader(
+            data_source=data_source,
+            sampler=train_sampler,
+            operations=transformations,
+            worker_count=8,
+            read_options=pygrain.ReadOptions(
+                32, 128
+            ),
+            worker_buffer_size=32,
+        )
+        return loader

     return {
         "train": get_trainset,
@@ -8,7 +8,7 @@ from .sources.videos import VideoTFDSSource, VideoLocalSource, AudioVideoAugment
 # ---------------------------------------------------------------------------------

 from .sources.images import data_source_tfds, tfds_augmenters, data_source_gcs
-from .sources.images import data_source_combined_gcs, gcs_augmenters
+from .sources.images import data_source_combined_gcs, gcs_augmenters, gcs_filters

 # Configure the following for your datasets
 datasetMap = {
@@ -23,6 +23,7 @@ datasetMap = {
     "laiona_coco": {
         "source": data_source_gcs('datasets/laion12m+mscoco'),
         "augmenter": gcs_augmenters,
+        "filter": gcs_filters,
     },
     "aesthetic_coyo": {
         "source": data_source_gcs('arrayrecords/aestheticCoyo_0.25clip_6aesthetic'),
@@ -62,6 +62,18 @@ class DataAugmenter(ABC):
         """
         pass

+    @abstractmethod
+    def create_filter(self, **kwargs) -> Callable[[], pygrain.FilterTransform]:
+        """Create a filter function for the data.
+
+        Args:
+            **kwargs: Additional arguments for the filter.
+
+        Returns:
+            A callable that returns a pygrain.FilterTransform instance.
+        """
+        pass
+
     @staticmethod
     def create(augmenter_type: str, **kwargs) -> 'DataAugmenter':
         """Factory method to create a data augmenter of the specified type.
@@ -11,7 +11,7 @@ import struct as st
 from functools import partial
 import numpy as np
 from .base import DataSource, DataAugmenter
-
+import traceback

 # ----------------------------------------------------------------------------------
 # Utility functions
@@ -79,6 +79,24 @@ def labelizer_oxford_flowers102(path):
 # TFDS Image Source
 # ----------------------------------------------------------------------------------

+def get_oxford_valset(text_encoder):
+    # Construct a validation set by the prompts for consistency
+    val_prompts = ['water tulip', ' a water lily', ' a water lily', ' a photo of a rose', ' a photo of a rose', ' a water lily', ' a water lily', ' a photo of a marigold', ' a photo of a marigold', ' a photo of a marigold', ' a water lily', ' a photo of a sunflower', ' a photo of a lotus', ' columbine', ' columbine', ' an orchid', ' an orchid', ' an orchid', ' a water lily', ' a water lily', ' a water lily', ' columbine', ' columbine', ' a photo of a sunflower', ' a photo of a sunflower', ' a photo of a sunflower', ' a photo of a lotus', ' a photo of a lotus', ' a photo of a marigold', ' a photo of a marigold', ' a photo of a rose', ' a photo of a rose', ' a photo of a rose', ' orange dahlia', ' orange dahlia', ' a lenten rose', ' a lenten rose', ' a water lily', ' a water lily', ' a water lily', ' a water lily', ' an orchid', ' an orchid', ' an orchid', ' hard-leaved pocket orchid', ' bird of paradise', ' bird of paradise', ' a photo of a lovely rose', ' a photo of a lovely rose', ' a photo of a globe-flower', ' a photo of a globe-flower', ' a photo of a lovely rose', ' a photo of a lovely rose', ' a photo of a ruby-lipped cattleya', ' a photo of a ruby-lipped cattleya', ' a photo of a lovely rose', ' a water lily', ' a osteospermum', ' a osteospermum', ' a water lily', ' a water lily', ' a water lily', ' a red rose', ' a red rose']
+    val_prompts *= 100
+
+    def get_val_dataset(batch_size=128):
+        for i in range(0, len(val_prompts), batch_size):
+            try:
+                prompts = val_prompts[i:i + batch_size]
+                tokens = text_encoder.tokenize(prompts)
+                yield {"text": tokens}
+            except Exception as e:
+                print(f"Error in get_val_dataset: {e}")
+                traceback.print_exc()
+                continue
+
+    return get_val_dataset, len(val_prompts)
+
 class ImageTFDSSource(DataSource):
     """Data source for TensorFlow Datasets (TFDS) image datasets."""

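A sketch of consuming the new helper (`text_encoder` is assumed to be the package's `defaultTextEncodeModel`, which provides the `tokenize` call used above):

```python
get_val_dataset, num_prompts = get_oxford_valset(text_encoder)

for batch in get_val_dataset(batch_size=128):
    tokens = batch["text"]  # tokenized prompts for conditional sampling
    break                   # illustrative: inspect a single batch
```
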
@@ -168,7 +186,11 @@ class ImageTFDSAugmenter(DataAugmenter):
         }

         return TFDSTransform
-
+
+    def create_filter(self, image_scale: int = 256):
+        class FilterTransform(pygrain.FilterTransform):
+            def map(self, element) -> bool:
+                return True
     """
     Batch structure:
     {
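
Note: in both this stub and the matching `AudioVideoAugmenter.create_filter` stub further down, the inner class defines `map()` rather than `pygrain.FilterTransform`'s `filter()`, and `create_filter` falls through without returning the class, so as released these are no-op placeholders.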
@@ -237,7 +259,6 @@ class CombinedImageGCSSource(DataSource):
             records_path) if 'array_record' in i]
         return pygrain.ArrayRecordDataSource(records)

-
 class ImageGCSAugmenter(DataAugmenter):
     """Augmenter for GCS image datasets."""

@@ -297,6 +318,52 @@ class ImageGCSAugmenter(DataAugmenter):
         }

         return GCSTransform
+
+    def create_filter(self, image_scale: int = 256):
+        import torch.nn.functional as F
+        class FilterTransform(pygrain.FilterTransform):
+            """
+            Filter transform for GCS data source.
+            """
+            def __init__(self, model=None, processor=None, method=cv2.INTER_AREA):
+                super().__init__()
+                self.image_scale = image_scale
+                if model is None:
+                    from transformers import AutoProcessor, CLIPVisionModelWithProjection, FlaxCLIPModel, CLIPModel
+                    model_name = "openai/clip-vit-base-patch32"
+                    model = CLIPModel.from_pretrained(model_name)
+                    processor = AutoProcessor.from_pretrained(model_name, use_fast=False)
+                self.method = method
+                self.model = model
+                self.processor = processor
+
+            # def _filter_(pixel_values, input_ids):
+            #     image_embeds = self.model.get_image_features(pixel_values=pixel_values)
+            #     text_embeds = self.model.get_text_features(input_ids=input_ids)
+            #     image_embeds = image_embeds / jnp.linalg.norm(image_embeds, axis=-1, keepdims=True)
+            #     text_embeds = text_embeds / jnp.linalg.norm(text_embeds, axis=-1, keepdims=True)
+            #     similarity = jnp.sum(image_embeds * text_embeds, axis=-1)
+            #     return jnp.all(similarity >= 0.25)
+
+            # self._filter_ = _filter_
+
+            def filter(self, data: Dict[str, Any]) -> bool:
+                images = [data['image']]
+                texts = [data['caption']]
+                inputs = self.processor(text=texts, images=images, return_tensors="pt", padding=True, truncation=True)
+                # result = self._filter_(
+                #     pixel_values=inputs['pixel_values'],
+                #     input_ids=inputs['input_ids']
+                # )
+                # return result
+
+                image_embeds = self.model.get_image_features(pixel_values=inputs['pixel_values'])
+                text_embeds = self.model.get_text_features(input_ids=inputs['input_ids'])
+                similarity = F.cosine_similarity(image_embeds, text_embeds)
+                # Filter out images with similarity less than 0.25
+                return similarity[0] >= 0.25
+
+        return FilterTransform


 # ----------------------------------------------------------------------------------
@@ -333,3 +400,8 @@ def gcs_augmenters(image_scale, method):
     """Legacy function for GCS augmenters."""
     augmenter = ImageGCSAugmenter()
     return augmenter.create_transform(image_scale=image_scale, method=method)
+
+def gcs_filters(image_scale):
+    """Legacy function for GCS Filters."""
+    augmenter = ImageGCSAugmenter()
+    return augmenter.create_filter(image_scale=image_scale)
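
The CLIP-based similarity filter above is exposed through this `gcs_filters` helper. A minimal sketch of wiring it into a Grain pipeline (`augmenter` and `local_batch_size` are the names from `get_dataset_grain` earlier in this diff; the composition is illustrative, since the diff does not show where the `"filter"` entry is consumed):

```python
filter_transform = gcs_filters(image_scale=256)  # returns the FilterTransform class

operations = [
    filter_transform(),  # drop image/caption pairs with CLIP cosine similarity < 0.25
    augmenter(),
    pygrain.Batch(local_batch_size, drop_remainder=True),
]
```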
@@ -216,6 +216,11 @@ class AudioVideoAugmenter(DataAugmenter):

         return AudioVideoTransform

+
+    def create_filter(self, image_scale: int = 256):
+        class FilterTransform(pygrain.FilterTransform):
+            def map(self, element) -> bool:
+                return True

 # ----------------------------------------------------------------------------------
 # Helper functions for video datasets
@@ -25,6 +25,9 @@ from flaxdiff.models.autoencoder.diffusers import StableDiffusionVAE
 from flaxdiff.inputs import DiffusionInputConfig, ConditionalInputConfig
 from flaxdiff.utils import defaultTextEncodeModel

+from flaxdiff.models.simple_vit import UViT, SimpleUDiT
+from flaxdiff.models.simple_dit import SimpleDiT
+from flaxdiff.models.simple_mmdit import SimpleMMDiT, HierarchicalMMDiT
 from orbax.checkpoint import CheckpointManager, CheckpointManagerOptions, PyTreeCheckpointer
 import os

@@ -116,7 +119,10 @@ def parse_config(config, overrides=None):
     MODEL_CLASSES = {
         'unet': Unet,
         'uvit': UViT,
-        'diffusers_unet_simple': FlaxUNet2DConditionModel
+        'diffusers_unet_simple': FlaxUNet2DConditionModel,
+        'simple_dit': SimpleDiT,
+        'simple_uvit': SimpleUDiT,
+        'simple_mmdit': SimpleMMDiT,
     }

     # Map all the leaves of the model config, converting strings to appropriate types
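
The three new entries make the DiT variants selectable by the same string keys the parser already uses; note that `HierarchicalMMDiT` is imported above but not registered here. Illustrative lookups:

```python
model_cls = MODEL_CLASSES['simple_dit']    # -> SimpleDiT
model_cls = MODEL_CLASSES['simple_mmdit']  # -> SimpleMMDiT
```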
flaxdiff/models/common.py CHANGED
@@ -335,73 +335,4 @@ class ResidualBlock(nn.Module):

         out = jnp.concatenate([out, extra_features], axis=-1) if extra_features is not None else out

-        return out
-
-# Convert Hilbert index d to 2D coordinates (x, y) for an n x n grid
-def _d2xy(n, d):
-    x = 0
-    y = 0
-    t = d
-    s = 1
-    while s < n:
-        rx = (t // 2) & 1
-        ry = (t ^ rx) & 1
-        if ry == 0:
-            if rx == 1:
-                x = n - 1 - x
-                y = n - 1 - y
-            x, y = y, x
-        x += s * rx
-        y += s * ry
-        t //= 4
-        s *= 2
-    return x, y
-
-# Hilbert index mapping for a rectangular grid of patches H_P x W_P
-
-def hilbert_indices(H_P, W_P):
-    size = max(H_P, W_P)
-    order = math.ceil(math.log2(size))
-    n = 1 << order
-    coords = []
-    for d in range(n * n):
-        x, y = _d2xy(n, d)
-        # x is column index, y is row index
-        if x < W_P and y < H_P:
-            coords.append((y, x))  # (row, col)
-        if len(coords) == H_P * W_P:
-            break
-    # Convert (row, col) to linear indices row-major
-    indices = [r * W_P + c for r, c in coords]
-    return jnp.array(indices, dtype=jnp.int32)
-
-# Inverse permutation: given idx where idx[i] = new position of element i, return inv such that inv[idx[i]] = i
-
-def inverse_permutation(idx):
-    inv = jnp.zeros_like(idx)
-    inv = inv.at[idx].set(jnp.arange(idx.shape[0], dtype=idx.dtype))
-    return inv
-
-# Patchify using Hilbert ordering: extract patches and reorder sequence
-
-def hilbert_patchify(x, patch_size):
-    B, H, W, C = x.shape
-    H_P = H // patch_size
-    W_P = W // patch_size
-    # Extract patches in row-major
-    patches = rearrange(x, 'b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size)
-    idx = hilbert_indices(H_P, W_P)
-    return patches[:, idx, :]
-
-# Unpatchify from Hilbert ordering: reorder sequence back and reconstruct image
-
-def hilbert_unpatchify(patches, patch_size, H, W, C):
-    B, N, D = patches.shape
-    H_P = H // patch_size
-    W_P = W // patch_size
-    inv = inverse_permutation(hilbert_indices(H_P, W_P))
-    # Reorder back to row-major
-    linear = patches[:, inv, :]
-    # Reconstruct image
-    x = rearrange(linear, 'b (h w) (p1 p2 c) -> b (h p1) (w p2) c', h=H_P, w=W_P, p1=patch_size, p2=patch_size, c=C)
-    return x
+        return out
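
This hunk drops the Hilbert-curve patch-ordering helpers (`_d2xy`, `hilbert_indices`, `inverse_permutation`, `hilbert_patchify`, `hilbert_unpatchify`) from `flaxdiff/models/common.py`; the diff does not show whether they reappear elsewhere in 0.2.9. For reference, the row-major patch extraction those helpers permuted is the single einops `rearrange` the removed code already used:

```python
from einops import rearrange

def patchify_row_major(x, patch_size):
    # (B, H, W, C) -> (B, H_P * W_P, patch_size * patch_size * C),
    # with patches in plain row-major order (no Hilbert reordering).
    return rearrange(x, 'b (h p1) (w p2) c -> b (h w) (p1 p2 c)',
                     p1=patch_size, p2=patch_size)
```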