flaxdiff-0.2.7.tar.gz → flaxdiff-0.2.8.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (77):
  1. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/PKG-INFO +1 -1
  2. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/data/dataloaders.py +36 -24
  3. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/data/dataset_map.py +2 -2
  4. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/data/sources/base.py +12 -0
  5. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/data/sources/images.py +68 -11
  6. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/data/sources/videos.py +5 -0
  7. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/models/common.py +1 -70
  8. flaxdiff-0.2.8/flaxdiff/models/hilbert.py +617 -0
  9. flaxdiff-0.2.8/flaxdiff/models/simple_dit.py +476 -0
  10. flaxdiff-0.2.8/flaxdiff/models/simple_mmdit.py +861 -0
  11. flaxdiff-0.2.8/flaxdiff/models/simple_vit.py +347 -0
  12. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/trainer/general_diffusion_trainer.py +29 -10
  13. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/trainer/simple_trainer.py +113 -19
  14. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff.egg-info/PKG-INFO +1 -1
  15. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff.egg-info/SOURCES.txt +3 -1
  16. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/pyproject.toml +1 -1
  17. flaxdiff-0.2.7/flaxdiff/models/better_uvit.py +0 -380
  18. flaxdiff-0.2.7/flaxdiff/models/simple_vit.py +0 -186
  19. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/README.md +0 -0
  20. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/__init__.py +0 -0
  21. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/data/__init__.py +0 -0
  22. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/data/benchmark_decord.py +0 -0
  23. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/data/online_loader.py +0 -0
  24. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/data/sources/audio_utils.py +0 -0
  25. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/data/sources/av_example.py +0 -0
  26. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/data/sources/av_utils.py +0 -0
  27. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/data/sources/utils.py +0 -0
  28. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/data/sources/voxceleb2.py +0 -0
  29. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/inference/__init__.py +0 -0
  30. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/inference/pipeline.py +0 -0
  31. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/inference/utils.py +0 -0
  32. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/inputs/__init__.py +0 -0
  33. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/inputs/encoders.py +0 -0
  34. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/metrics/__init__.py +0 -0
  35. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/metrics/common.py +0 -0
  36. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/metrics/images.py +0 -0
  37. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/metrics/inception.py +0 -0
  38. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/metrics/psnr.py +0 -0
  39. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/metrics/ssim.py +0 -0
  40. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/metrics/utils.py +0 -0
  41. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/models/__init__.py +0 -0
  42. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/models/attention.py +0 -0
  43. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/models/autoencoder/__init__.py +0 -0
  44. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/models/autoencoder/autoencoder.py +0 -0
  45. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/models/autoencoder/diffusers.py +0 -0
  46. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/models/autoencoder/simple_autoenc.py +0 -0
  47. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/models/favor_fastattn.py +0 -0
  48. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/models/general.py +0 -0
  49. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/models/simple_unet.py +0 -0
  50. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/models/unet_3d.py +0 -0
  51. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/models/unet_3d_blocks.py +0 -0
  52. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/predictors/__init__.py +0 -0
  53. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/samplers/__init__.py +0 -0
  54. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/samplers/common.py +0 -0
  55. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/samplers/ddim.py +0 -0
  56. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/samplers/ddpm.py +0 -0
  57. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/samplers/euler.py +0 -0
  58. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/samplers/heun_sampler.py +0 -0
  59. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/samplers/multistep_dpm.py +0 -0
  60. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/samplers/rk4_sampler.py +0 -0
  61. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/schedulers/__init__.py +0 -0
  62. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/schedulers/common.py +0 -0
  63. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/schedulers/continuous.py +0 -0
  64. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/schedulers/cosine.py +0 -0
  65. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/schedulers/discrete.py +0 -0
  66. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/schedulers/exp.py +0 -0
  67. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/schedulers/karras.py +0 -0
  68. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/schedulers/linear.py +0 -0
  69. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/schedulers/sqrt.py +0 -0
  70. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/trainer/__init__.py +0 -0
  71. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/trainer/autoencoder_trainer.py +0 -0
  72. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/trainer/diffusion_trainer.py +0 -0
  73. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff/utils.py +0 -0
  74. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff.egg-info/dependency_links.txt +0 -0
  75. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff.egg-info/requires.txt +0 -0
  76. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/flaxdiff.egg-info/top_level.txt +0 -0
  77. {flaxdiff-0.2.7 → flaxdiff-0.2.8}/setup.cfg +0 -0
File: PKG-INFO (+1 -1)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: flaxdiff
3
- Version: 0.2.7
3
+ Version: 0.2.8
4
4
  Summary: A versatile and easy to understand Diffusion library
5
5
  Author-email: Ashish Kumar Singh <ashishkmr472@gmail.com>
6
6
  License-Expression: MIT
File: flaxdiff/data/dataloaders.py (+36 -24)
@@ -251,6 +251,12 @@ def generate_collate_fn(media_type="image"):
251
251
  else: # Default to image
252
252
  return image_collate
253
253
 
254
+ class CaptionDeletionTransform(pygrain.MapTransform):
255
+ def map(self, element):
256
+ """Delete the caption from the element."""
257
+ if "caption" in element:
258
+ del element["caption"]
259
+ return element
254
260
 
255
261
  def get_dataset_grain(
256
262
  data_name="cc12m",
@@ -286,13 +292,14 @@ def get_dataset_grain(
286
292
  Dictionary with train dataset function and metadata.
287
293
  """
288
294
  dataset = datasetMap[data_name]
289
- data_source = dataset["source"](dataset_source)
295
+ train_source = dataset["source"](dataset_source, split="train")
296
+ # val_source = dataset["source"](dataset_source, split="val")
290
297
  augmenter = dataset["augmenter"](image_scale, method)
291
298
 
292
299
  local_batch_size = batch_size // jax.process_count()
293
300
 
294
301
  train_sampler = pygrain.IndexSampler(
295
- num_records=len(data_source) if count is None else count,
302
+ num_records=len(train_source) if count is None else count,
296
303
  shuffle=True,
297
304
  seed=seed,
298
305
  num_epochs=num_epochs,
@@ -300,7 +307,7 @@ def get_dataset_grain(
300
307
  )
301
308
 
302
309
  # val_sampler = pygrain.IndexSampler(
303
- # num_records=len(data_source) if count is None else count,
310
+ # num_records=len(val_source) if count is None else count,
304
311
  # shuffle=False,
305
312
  # seed=seed,
306
313
  # num_epochs=num_epochs,
@@ -310,11 +317,17 @@ def get_dataset_grain(
310
317
  def get_trainset():
311
318
  transformations = [
312
319
  augmenter(),
313
- pygrain.Batch(local_batch_size, drop_remainder=True),
314
320
  ]
321
+
322
+ # if filters:
323
+ # print("Adding filters to transformations")
324
+ # transformations.append(filters())
325
+
326
+ # transformations.append(CaptionDeletionTransform())
327
+ transformations.append(pygrain.Batch(local_batch_size, drop_remainder=True))
315
328
 
316
329
  loader = pygrain.DataLoader(
317
- data_source=data_source,
330
+ data_source=train_source,
318
331
  sampler=train_sampler,
319
332
  operations=transformations,
320
333
  worker_count=worker_count,
@@ -325,30 +338,29 @@ def get_dataset_grain(
325
338
  )
326
339
  return loader
327
340
 
328
- # def get_valset():
329
- # transformations = [
330
- # augmenter(),
331
- # pygrain.Batch(local_batch_size, drop_remainder=True),
332
- # ]
341
+ def get_valset():
342
+ transformations = [
343
+ augmenter(),
344
+ pygrain.Batch(local_batch_size, drop_remainder=True),
345
+ ]
333
346
 
334
- # loader = pygrain.DataLoader(
335
- # data_source=data_source,
336
- # sampler=val_sampler,
337
- # operations=transformations,
338
- # worker_count=worker_count,
339
- # read_options=pygrain.ReadOptions(
340
- # read_thread_count, read_buffer_size
341
- # ),
342
- # worker_buffer_size=worker_buffer_size,
343
- # )
344
- # return loader
345
- get_valset = get_trainset # For now, use the same function for validation
347
+ loader = pygrain.DataLoader(
348
+ data_source=train_source,
349
+ sampler=train_sampler,
350
+ operations=transformations,
351
+ worker_count=2,
352
+ read_options=pygrain.ReadOptions(
353
+ read_thread_count, read_buffer_size
354
+ ),
355
+ worker_buffer_size=2,
356
+ )
357
+ return loader
346
358
 
347
359
  return {
348
360
  "train": get_trainset,
349
- "train_len": len(data_source),
361
+ "train_len": len(train_source),
350
362
  "val": get_valset,
351
- "val_len": len(data_source),
363
+ "val_len": len(train_source),
352
364
  "local_batch_size": local_batch_size,
353
365
  "global_batch_size": batch_size,
354
366
  }
File: flaxdiff/data/dataset_map.py (+2 -2)
@@ -8,7 +8,7 @@ from .sources.videos import VideoTFDSSource, VideoLocalSource, AudioVideoAugment
8
8
  # ---------------------------------------------------------------------------------
9
9
 
10
10
  from .sources.images import data_source_tfds, tfds_augmenters, data_source_gcs
11
- from .sources.images import data_source_combined_gcs, gcs_augmenters
11
+ from .sources.images import data_source_combined_gcs, gcs_augmenters, gcs_filters
12
12
 
13
13
  # Configure the following for your datasets
14
14
  datasetMap = {
@@ -21,7 +21,7 @@ datasetMap = {
21
21
  "augmenter": gcs_augmenters,
22
22
  },
23
23
  "laiona_coco": {
24
- "source": data_source_gcs('datasets/laion12m+mscoco'),
24
+ "source": data_source_gcs('datasets/laion12m+mscoco_filtered-new'),
25
25
  "augmenter": gcs_augmenters,
26
26
  },
27
27
  "aesthetic_coyo": {
File: flaxdiff/data/sources/base.py (+12 -0)
@@ -62,6 +62,18 @@ class DataAugmenter(ABC):
62
62
  """
63
63
  pass
64
64
 
65
+ @abstractmethod
66
+ def create_filter(self, **kwargs) -> Callable[[], pygrain.FilterTransform]:
67
+ """Create a filter function for the data.
68
+
69
+ Args:
70
+ **kwargs: Additional arguments for the filter.
71
+
72
+ Returns:
73
+ A callable that returns a pygrain.FilterTransform instance.
74
+ """
75
+ pass
76
+
65
77
  @staticmethod
66
78
  def create(augmenter_type: str, **kwargs) -> 'DataAugmenter':
67
79
  """Factory method to create a data augmenter of the specified type.
File: flaxdiff/data/sources/images.py (+68 -11)
@@ -82,7 +82,7 @@ def labelizer_oxford_flowers102(path):
82
82
  class ImageTFDSSource(DataSource):
83
83
  """Data source for TensorFlow Datasets (TFDS) image datasets."""
84
84
 
85
- def __init__(self, name: str, use_tf: bool = True, split: str = "all"):
85
+ def __init__(self, name: str, use_tf: bool = True):
86
86
  """Initialize a TFDS image data source.
87
87
 
88
88
  Args:
@@ -92,9 +92,8 @@ class ImageTFDSSource(DataSource):
92
92
  """
93
93
  self.name = name
94
94
  self.use_tf = use_tf
95
- self.split = split
96
95
 
97
- def get_source(self, path_override: str) -> Any:
96
+ def get_source(self, path_override: str, split: str = "all") -> Any:
98
97
  """Get the TFDS data source.
99
98
 
100
99
  Args:
@@ -105,9 +104,9 @@ class ImageTFDSSource(DataSource):
105
104
  """
106
105
  import tensorflow_datasets as tfds
107
106
  if self.use_tf:
108
- return tfds.load(self.name, split=self.split, shuffle_files=True)
107
+ return tfds.load(self.name, split=split, shuffle_files=True)
109
108
  else:
110
- return tfds.data_source(self.name, split=self.split, try_gcs=False)
109
+ return tfds.data_source(self.name, split=split, try_gcs=False)
111
110
 
112
111
 
113
112
  class ImageTFDSAugmenter(DataAugmenter):
@@ -168,7 +167,11 @@ class ImageTFDSAugmenter(DataAugmenter):
168
167
  }
169
168
 
170
169
  return TFDSTransform
171
-
170
+
171
+ def create_filter(self, image_scale: int = 256):
172
+ class FilterTransform(pygrain.FilterTransform):
173
+ def map(self, element) -> bool:
174
+ return True
172
175
  """
173
176
  Batch structure:
174
177
  {
@@ -195,7 +198,7 @@ class ImageGCSSource(DataSource):
195
198
  """
196
199
  self.source = source
197
200
 
198
- def get_source(self, path_override: str = "/home/mrwhite0racle/gcs_mount") -> Any:
201
+ def get_source(self, path_override: str = "/home/mrwhite0racle/gcs_mount", split: str = "train") -> Any:
199
202
  """Get the GCS data source.
200
203
 
201
204
  Args:
@@ -207,6 +210,8 @@ class ImageGCSSource(DataSource):
207
210
  records_path = os.path.join(path_override, self.source)
208
211
  records = [os.path.join(records_path, i) for i in os.listdir(
209
212
  records_path) if 'array_record' in i]
213
+ if split == "val":
214
+ records = records[:1]
210
215
  return pygrain.ArrayRecordDataSource(records)
211
216
 
212
217
 
@@ -221,7 +226,7 @@ class CombinedImageGCSSource(DataSource):
221
226
  """
222
227
  self.sources = sources
223
228
 
224
- def get_source(self, path_override: str = "/home/mrwhite0racle/gcs_mount") -> Any:
229
+ def get_source(self, path_override: str = "/home/mrwhite0racle/gcs_mount", split: str = "train") -> Any:
225
230
  """Get the combined GCS data source.
226
231
 
227
232
  Args:
@@ -235,9 +240,10 @@ class CombinedImageGCSSource(DataSource):
235
240
  for records_path in records_paths:
236
241
  records += [os.path.join(records_path, i) for i in os.listdir(
237
242
  records_path) if 'array_record' in i]
243
+ if split == "val":
244
+ records = records[:1]
238
245
  return pygrain.ArrayRecordDataSource(records)
239
246
 
240
-
241
247
  class ImageGCSAugmenter(DataAugmenter):
242
248
  """Augmenter for GCS image datasets."""
243
249
 
@@ -297,6 +303,52 @@ class ImageGCSAugmenter(DataAugmenter):
297
303
  }
298
304
 
299
305
  return GCSTransform
306
+
307
+ def create_filter(self, image_scale: int = 256):
308
+ import torch.nn.functional as F
309
+ class FilterTransform(pygrain.FilterTransform):
310
+ """
311
+ Filter transform for GCS data source.
312
+ """
313
+ def __init__(self, model=None, processor=None, method=cv2.INTER_AREA):
314
+ super().__init__()
315
+ self.image_scale = image_scale
316
+ if model is None:
317
+ from transformers import AutoProcessor, CLIPVisionModelWithProjection, FlaxCLIPModel, CLIPModel
318
+ model_name = "openai/clip-vit-base-patch32"
319
+ model = CLIPModel.from_pretrained(model_name)
320
+ processor = AutoProcessor.from_pretrained(model_name, use_fast=False)
321
+ self.method = method
322
+ self.model = model
323
+ self.processor = processor
324
+
325
+ # def _filter_(pixel_values, input_ids):
326
+ # image_embeds = self.model.get_image_features(pixel_values=pixel_values)
327
+ # text_embeds = self.model.get_text_features(input_ids=input_ids)
328
+ # image_embeds = image_embeds / jnp.linalg.norm(image_embeds, axis=-1, keepdims=True)
329
+ # text_embeds = text_embeds / jnp.linalg.norm(text_embeds, axis=-1, keepdims=True)
330
+ # similarity = jnp.sum(image_embeds * text_embeds, axis=-1)
331
+ # return jnp.all(similarity >= 0.25)
332
+
333
+ # self._filter_ = _filter_
334
+
335
+ def filter(self, data: Dict[str, Any]) -> bool:
336
+ images = [data['image']]
337
+ texts = [data['caption']]
338
+ inputs = self.processor(text=texts, images=images, return_tensors="pt", padding=True, truncation=True)
339
+ # result = self._filter_(
340
+ # pixel_values=inputs['pixel_values'],
341
+ # input_ids=inputs['input_ids']
342
+ # )
343
+ # return result
344
+
345
+ image_embeds = self.model.get_image_features(pixel_values=inputs['pixel_values'])
346
+ text_embeds = self.model.get_text_features(input_ids=inputs['input_ids'])
347
+ similarity = F.cosine_similarity(image_embeds, text_embeds)
348
+ # Filter out images with similarity less than 0.25
349
+ return similarity[0] >= 0.25
350
+
351
+ return FilterTransform
300
352
 
301
353
 
302
354
  # ----------------------------------------------------------------------------------
@@ -305,9 +357,9 @@ class ImageGCSAugmenter(DataAugmenter):
305
357
 
306
358
  # These functions maintain backward compatibility with existing code
307
359
 
308
- def data_source_tfds(name, use_tf=True, split="all"):
360
+ def data_source_tfds(name, use_tf=True):
309
361
  """Legacy function for TFDS data sources."""
310
- source = ImageTFDSSource(name=name, use_tf=use_tf, split=split)
362
+ source = ImageTFDSSource(name=name, use_tf=use_tf)
311
363
  return source.get_source
312
364
 
313
365
 
@@ -333,3 +385,8 @@ def gcs_augmenters(image_scale, method):
333
385
  """Legacy function for GCS augmenters."""
334
386
  augmenter = ImageGCSAugmenter()
335
387
  return augmenter.create_transform(image_scale=image_scale, method=method)
388
+
389
+ def gcs_filters(image_scale):
390
+ """Legacy function for GCS Filters."""
391
+ augmenter = ImageGCSAugmenter()
392
+ return augmenter.create_filter(image_scale=image_scale)
File: flaxdiff/data/sources/videos.py (+5 -0)
@@ -216,6 +216,11 @@ class AudioVideoAugmenter(DataAugmenter):
216
216
 
217
217
  return AudioVideoTransform
218
218
 
219
+
220
+ def create_filter(self, image_scale: int = 256):
221
+ class FilterTransform(pygrain.FilterTransform):
222
+ def map(self, element) -> bool:
223
+ return True
219
224
 
220
225
  # ----------------------------------------------------------------------------------
221
226
  # Helper functions for video datasets
File: flaxdiff/models/common.py (+1 -70)
@@ -335,73 +335,4 @@ class ResidualBlock(nn.Module):
335
335
 
336
336
  out = jnp.concatenate([out, extra_features], axis=-1) if extra_features is not None else out
337
337
 
338
- return out
339
-
340
- # Convert Hilbert index d to 2D coordinates (x, y) for an n x n grid
341
- def _d2xy(n, d):
342
- x = 0
343
- y = 0
344
- t = d
345
- s = 1
346
- while s < n:
347
- rx = (t // 2) & 1
348
- ry = (t ^ rx) & 1
349
- if ry == 0:
350
- if rx == 1:
351
- x = n - 1 - x
352
- y = n - 1 - y
353
- x, y = y, x
354
- x += s * rx
355
- y += s * ry
356
- t //= 4
357
- s *= 2
358
- return x, y
359
-
360
- # Hilbert index mapping for a rectangular grid of patches H_P x W_P
361
-
362
- def hilbert_indices(H_P, W_P):
363
- size = max(H_P, W_P)
364
- order = math.ceil(math.log2(size))
365
- n = 1 << order
366
- coords = []
367
- for d in range(n * n):
368
- x, y = _d2xy(n, d)
369
- # x is column index, y is row index
370
- if x < W_P and y < H_P:
371
- coords.append((y, x)) # (row, col)
372
- if len(coords) == H_P * W_P:
373
- break
374
- # Convert (row, col) to linear indices row-major
375
- indices = [r * W_P + c for r, c in coords]
376
- return jnp.array(indices, dtype=jnp.int32)
377
-
378
- # Inverse permutation: given idx where idx[i] = new position of element i, return inv such that inv[idx[i]] = i
379
-
380
- def inverse_permutation(idx):
381
- inv = jnp.zeros_like(idx)
382
- inv = inv.at[idx].set(jnp.arange(idx.shape[0], dtype=idx.dtype))
383
- return inv
384
-
385
- # Patchify using Hilbert ordering: extract patches and reorder sequence
386
-
387
- def hilbert_patchify(x, patch_size):
388
- B, H, W, C = x.shape
389
- H_P = H // patch_size
390
- W_P = W // patch_size
391
- # Extract patches in row-major
392
- patches = rearrange(x, 'b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size)
393
- idx = hilbert_indices(H_P, W_P)
394
- return patches[:, idx, :]
395
-
396
- # Unpatchify from Hilbert ordering: reorder sequence back and reconstruct image
397
-
398
- def hilbert_unpatchify(patches, patch_size, H, W, C):
399
- B, N, D = patches.shape
400
- H_P = H // patch_size
401
- W_P = W // patch_size
402
- inv = inverse_permutation(hilbert_indices(H_P, W_P))
403
- # Reorder back to row-major
404
- linear = patches[:, inv, :]
405
- # Reconstruct image
406
- x = rearrange(linear, 'b (h w) (p1 p2 c) -> b (h p1) (w p2) c', h=H_P, w=W_P, p1=patch_size, p2=patch_size, c=C)
407
- return x
338
+ return out