flaxdiff 0.2.7.tar.gz → 0.2.9.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/PKG-INFO +1 -1
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/data/dataloaders.py +23 -19
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/data/dataset_map.py +2 -1
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/data/sources/base.py +12 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/data/sources/images.py +75 -3
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/data/sources/videos.py +5 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/inference/utils.py +7 -1
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/models/common.py +1 -70
- flaxdiff-0.2.9/flaxdiff/models/hilbert.py +617 -0
- flaxdiff-0.2.9/flaxdiff/models/simple_dit.py +275 -0
- flaxdiff-0.2.9/flaxdiff/models/simple_mmdit.py +730 -0
- flaxdiff-0.2.9/flaxdiff/models/simple_vit.py +446 -0
- flaxdiff-0.2.9/flaxdiff/models/vit_common.py +262 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/trainer/general_diffusion_trainer.py +30 -10
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/trainer/simple_trainer.py +113 -19
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff.egg-info/PKG-INFO +1 -1
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff.egg-info/SOURCES.txt +4 -1
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/pyproject.toml +1 -1
- flaxdiff-0.2.7/flaxdiff/models/better_uvit.py +0 -380
- flaxdiff-0.2.7/flaxdiff/models/simple_vit.py +0 -186
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/README.md +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/__init__.py +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/data/__init__.py +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/data/benchmark_decord.py +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/data/online_loader.py +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/data/sources/audio_utils.py +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/data/sources/av_example.py +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/data/sources/av_utils.py +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/data/sources/utils.py +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/data/sources/voxceleb2.py +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/inference/__init__.py +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/inference/pipeline.py +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/inputs/__init__.py +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/inputs/encoders.py +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/metrics/__init__.py +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/metrics/common.py +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/metrics/images.py +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/metrics/inception.py +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/metrics/psnr.py +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/metrics/ssim.py +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/metrics/utils.py +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/models/__init__.py +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/models/attention.py +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/models/autoencoder/__init__.py +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/models/autoencoder/autoencoder.py +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/models/autoencoder/diffusers.py +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/models/autoencoder/simple_autoenc.py +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/models/favor_fastattn.py +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/models/general.py +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/models/simple_unet.py +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/models/unet_3d.py +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/models/unet_3d_blocks.py +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/predictors/__init__.py +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/samplers/__init__.py +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/samplers/common.py +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/samplers/ddim.py +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/samplers/ddpm.py +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/samplers/euler.py +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/samplers/heun_sampler.py +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/samplers/multistep_dpm.py +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/samplers/rk4_sampler.py +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/schedulers/__init__.py +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/schedulers/common.py +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/schedulers/continuous.py +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/schedulers/cosine.py +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/schedulers/discrete.py +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/schedulers/exp.py +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/schedulers/karras.py +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/schedulers/linear.py +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/schedulers/sqrt.py +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/trainer/__init__.py +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/trainer/autoencoder_trainer.py +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/trainer/diffusion_trainer.py +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff/utils.py +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff.egg-info/dependency_links.txt +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff.egg-info/requires.txt +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/flaxdiff.egg-info/top_level.txt +0 -0
- {flaxdiff-0.2.7 → flaxdiff-0.2.9}/setup.cfg +0 -0
--- a/flaxdiff/data/dataloaders.py
+++ b/flaxdiff/data/dataloaders.py
@@ -251,6 +251,12 @@ def generate_collate_fn(media_type="image"):
     else:  # Default to image
         return image_collate

+class CaptionDeletionTransform(pygrain.MapTransform):
+    def map(self, element):
+        """Delete the caption from the element."""
+        if "caption" in element:
+            del element["caption"]
+        return element

 def get_dataset_grain(
     data_name="cc12m",
@@ -288,7 +294,6 @@ def get_dataset_grain(
     dataset = datasetMap[data_name]
     data_source = dataset["source"](dataset_source)
     augmenter = dataset["augmenter"](image_scale, method)
-
     local_batch_size = batch_size // jax.process_count()

     train_sampler = pygrain.IndexSampler(
@@ -310,8 +315,8 @@ def get_dataset_grain(
     def get_trainset():
         transformations = [
             augmenter(),
-            pygrain.Batch(local_batch_size, drop_remainder=True),
         ]
+        transformations.append(pygrain.Batch(local_batch_size, drop_remainder=True))

         loader = pygrain.DataLoader(
             data_source=data_source,
@@ -325,24 +330,23 @@ def get_dataset_grain(
         )
         return loader

-
-
-
-
-
+    def get_valset():
+        transformations = [
+            augmenter(),
+            pygrain.Batch(32, drop_remainder=True),
+        ]

-
-
-
-
-
-
-
-
-
-
-
-    get_valset = get_trainset # For now, use the same function for validation
+        loader = pygrain.DataLoader(
+            data_source=data_source,
+            sampler=train_sampler,
+            operations=transformations,
+            worker_count=8,
+            read_options=pygrain.ReadOptions(
+                32, 128
+            ),
+            worker_buffer_size=32,
+        )
+        return loader

     return {
         "train": get_trainset,
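For orientation, both the new `CaptionDeletionTransform` and the dedicated validation loader follow standard `grain` pipeline conventions: map transforms and a `pygrain.Batch` op are listed in `operations`. A minimal, self-contained sketch, assuming placeholder record paths and sampler settings that are not taken from this diff:

import grain.python as pygrain

class CaptionDeletionTransform(pygrain.MapTransform):
    def map(self, element):
        """Delete the caption from the element."""
        if "caption" in element:
            del element["caption"]
        return element

# Placeholder source and sampler, for illustration only.
source = pygrain.ArrayRecordDataSource(["records/shard-00000.array_record"])
sampler = pygrain.IndexSampler(
    num_records=len(source),
    shard_options=pygrain.NoSharding(),
    shuffle=True,
    num_epochs=1,
    seed=0,
)

loader = pygrain.DataLoader(
    data_source=source,
    sampler=sampler,
    operations=[
        CaptionDeletionTransform(),              # strip captions from each element
        pygrain.Batch(64, drop_remainder=True),  # then batch
    ],
    worker_count=8,
)

Keeping `pygrain.Batch` as the last operation, as `get_trainset` now does via `append`, lets callers splice extra per-element transforms in front of the batching step.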
--- a/flaxdiff/data/dataset_map.py
+++ b/flaxdiff/data/dataset_map.py
@@ -8,7 +8,7 @@ from .sources.videos import VideoTFDSSource, VideoLocalSource, AudioVideoAugmenter
 # ---------------------------------------------------------------------------------

 from .sources.images import data_source_tfds, tfds_augmenters, data_source_gcs
-from .sources.images import data_source_combined_gcs, gcs_augmenters
+from .sources.images import data_source_combined_gcs, gcs_augmenters, gcs_filters

 # Configure the following for your datasets
 datasetMap = {
@@ -23,6 +23,7 @@ datasetMap = {
     "laiona_coco": {
         "source": data_source_gcs('datasets/laion12m+mscoco'),
         "augmenter": gcs_augmenters,
+        "filter": gcs_filters,
     },
     "aesthetic_coyo": {
         "source": data_source_gcs('arrayrecords/aestheticCoyo_0.25clip_6aesthetic'),
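Registering a dataset keeps the same shape; a hypothetical entry (the name and GCS prefix are placeholders) using the new "filter" key added in this release:

datasetMap["my_dataset"] = {
    "source": data_source_gcs('datasets/my-dataset'),  # array_record shards under this prefix
    "augmenter": gcs_augmenters,   # resize/normalize transform factory
    "filter": gcs_filters,         # new in 0.2.9: CLIP-similarity filter factory
}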
--- a/flaxdiff/data/sources/base.py
+++ b/flaxdiff/data/sources/base.py
@@ -62,6 +62,18 @@ class DataAugmenter(ABC):
         """
         pass

+    @abstractmethod
+    def create_filter(self, **kwargs) -> Callable[[], pygrain.FilterTransform]:
+        """Create a filter function for the data.
+
+        Args:
+            **kwargs: Additional arguments for the filter.
+
+        Returns:
+            A callable that returns a pygrain.FilterTransform instance.
+        """
+        pass
+
     @staticmethod
     def create(augmenter_type: str, **kwargs) -> 'DataAugmenter':
         """Factory method to create a data augmenter of the specified type.
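Since `create_filter` is now abstract, every `DataAugmenter` subclass must provide one. A hypothetical subclass satisfying the contract; note that `pygrain.FilterTransform` subclasses implement `filter()`, which returns `True` to keep an element:

import grain.python as pygrain

class MinSizeAugmenter(DataAugmenter):
    """Hypothetical augmenter illustrating the new abstract method."""

    def create_transform(self, image_scale: int = 256, **kwargs):
        class Identity(pygrain.MapTransform):
            def map(self, element):
                return element  # a real augmenter would resize/normalize here
        return Identity

    def create_filter(self, image_scale: int = 256, **kwargs):
        class MinSizeFilter(pygrain.FilterTransform):
            def filter(self, element) -> bool:
                h, w = element["image"].shape[:2]
                return min(h, w) >= image_scale  # drop images that are too small
        return MinSizeFilter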
--- a/flaxdiff/data/sources/images.py
+++ b/flaxdiff/data/sources/images.py
@@ -11,7 +11,7 @@ import struct as st
 from functools import partial
 import numpy as np
 from .base import DataSource, DataAugmenter
-
+import traceback

 # ----------------------------------------------------------------------------------
 # Utility functions
@@ -79,6 +79,24 @@ def labelizer_oxford_flowers102(path):
 # TFDS Image Source
 # ----------------------------------------------------------------------------------

+def get_oxford_valset(text_encoder):
+    # Construct a validation set by the prompts for consistency
+    val_prompts = ['water tulip', ' a water lily', ' a water lily', ' a photo of a rose', ' a photo of a rose', ' a water lily', ' a water lily', ' a photo of a marigold', ' a photo of a marigold', ' a photo of a marigold', ' a water lily', ' a photo of a sunflower', ' a photo of a lotus', ' columbine', ' columbine', ' an orchid', ' an orchid', ' an orchid', ' a water lily', ' a water lily', ' a water lily', ' columbine', ' columbine', ' a photo of a sunflower', ' a photo of a sunflower', ' a photo of a sunflower', ' a photo of a lotus', ' a photo of a lotus', ' a photo of a marigold', ' a photo of a marigold', ' a photo of a rose', ' a photo of a rose', ' a photo of a rose', ' orange dahlia', ' orange dahlia', ' a lenten rose', ' a lenten rose', ' a water lily', ' a water lily', ' a water lily', ' a water lily', ' an orchid', ' an orchid', ' an orchid', ' hard-leaved pocket orchid', ' bird of paradise', ' bird of paradise', ' a photo of a lovely rose', ' a photo of a lovely rose', ' a photo of a globe-flower', ' a photo of a globe-flower', ' a photo of a lovely rose', ' a photo of a lovely rose', ' a photo of a ruby-lipped cattleya', ' a photo of a ruby-lipped cattleya', ' a photo of a lovely rose', ' a water lily', ' a osteospermum', ' a osteospermum', ' a water lily', ' a water lily', ' a water lily', ' a red rose', ' a red rose']
+    val_prompts *= 100
+
+    def get_val_dataset(batch_size=128):
+        for i in range(0, len(val_prompts), batch_size):
+            try:
+                prompts = val_prompts[i:i + batch_size]
+                tokens = text_encoder.tokenize(prompts)
+                yield {"text": tokens}
+            except Exception as e:
+                print(f"Error in get_val_dataset: {e}")
+                traceback.print_exc()
+                continue
+
+    return get_val_dataset, len(val_prompts)
+
 class ImageTFDSSource(DataSource):
     """Data source for TensorFlow Datasets (TFDS) image datasets."""

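A short usage sketch of the generator above; `text_encoder` stands in for any object exposing a `tokenize(list[str])` method, such as the package's default text encoder:

# Hypothetical usage of get_oxford_valset from the hunk above.
get_val_dataset, num_prompts = get_oxford_valset(text_encoder)

for batch in get_val_dataset(batch_size=128):
    tokens = batch["text"]  # tokenized prompt batch, ready for conditioning
    # ... run validation sampling with these conditions ...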
@@ -168,7 +186,11 @@ class ImageTFDSAugmenter(DataAugmenter):
         }

         return TFDSTransform
-
+
+    def create_filter(self, image_scale: int = 256):
+        class FilterTransform(pygrain.FilterTransform):
+            def map(self, element) -> bool:
+                return True
        """
        Batch structure:
        {
@@ -237,7 +259,6 @@ class CombinedImageGCSSource(DataSource):
             records_path) if 'array_record' in i]
         return pygrain.ArrayRecordDataSource(records)

-
 class ImageGCSAugmenter(DataAugmenter):
     """Augmenter for GCS image datasets."""

@@ -297,6 +318,52 @@ class ImageGCSAugmenter(DataAugmenter):
         }

         return GCSTransform
+
+    def create_filter(self, image_scale: int = 256):
+        import torch.nn.functional as F
+        class FilterTransform(pygrain.FilterTransform):
+            """
+            Filter transform for GCS data source.
+            """
+            def __init__(self, model=None, processor=None, method=cv2.INTER_AREA):
+                super().__init__()
+                self.image_scale = image_scale
+                if model is None:
+                    from transformers import AutoProcessor, CLIPVisionModelWithProjection, FlaxCLIPModel, CLIPModel
+                    model_name = "openai/clip-vit-base-patch32"
+                    model = CLIPModel.from_pretrained(model_name)
+                    processor = AutoProcessor.from_pretrained(model_name, use_fast=False)
+                self.method = method
+                self.model = model
+                self.processor = processor
+
+                # def _filter_(pixel_values, input_ids):
+                #     image_embeds = self.model.get_image_features(pixel_values=pixel_values)
+                #     text_embeds = self.model.get_text_features(input_ids=input_ids)
+                #     image_embeds = image_embeds / jnp.linalg.norm(image_embeds, axis=-1, keepdims=True)
+                #     text_embeds = text_embeds / jnp.linalg.norm(text_embeds, axis=-1, keepdims=True)
+                #     similarity = jnp.sum(image_embeds * text_embeds, axis=-1)
+                #     return jnp.all(similarity >= 0.25)
+
+                # self._filter_ = _filter_
+
+            def filter(self, data: Dict[str, Any]) -> bool:
+                images = [data['image']]
+                texts = [data['caption']]
+                inputs = self.processor(text=texts, images=images, return_tensors="pt", padding=True, truncation=True)
+                # result = self._filter_(
+                #     pixel_values=inputs['pixel_values'],
+                #     input_ids=inputs['input_ids']
+                # )
+                # return result
+
+                image_embeds = self.model.get_image_features(pixel_values=inputs['pixel_values'])
+                text_embeds = self.model.get_text_features(input_ids=inputs['input_ids'])
+                similarity = F.cosine_similarity(image_embeds, text_embeds)
+                # Filter out images with similarity less than 0.25
+                return similarity[0] >= 0.25
+
+        return FilterTransform


 # ----------------------------------------------------------------------------------
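The new filter keeps an image-caption pair only when the CLIP image and text embeddings have cosine similarity of at least 0.25. A self-contained sketch of the same check; the model name and threshold come from the hunk above, while the standalone `keep` function is hypothetical:

import torch
import torch.nn.functional as F
from transformers import AutoProcessor, CLIPModel

model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
processor = AutoProcessor.from_pretrained("openai/clip-vit-base-patch32")

def keep(image, caption, threshold=0.25) -> bool:
    inputs = processor(text=[caption], images=[image],
                       return_tensors="pt", padding=True, truncation=True)
    with torch.no_grad():
        image_embeds = model.get_image_features(pixel_values=inputs["pixel_values"])
        text_embeds = model.get_text_features(input_ids=inputs["input_ids"])
    # cosine_similarity normalizes both embeddings internally
    return F.cosine_similarity(image_embeds, text_embeds).item() >= threshold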
@@ -333,3 +400,8 @@ def gcs_augmenters(image_scale, method):
     """Legacy function for GCS augmenters."""
     augmenter = ImageGCSAugmenter()
     return augmenter.create_transform(image_scale=image_scale, method=method)
+
+def gcs_filters(image_scale):
+    """Legacy function for GCS Filters."""
+    augmenter = ImageGCSAugmenter()
+    return augmenter.create_filter(image_scale=image_scale)
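Putting the two factory functions together, a loader would typically run the filter before augmentation and batching. The ordering below, and the assumption that the returned transform classes can be instantiated with default arguments, are illustrative rather than taken from this diff:

import cv2
import grain.python as pygrain

FilterT = gcs_filters(image_scale=256)                             # CLIP-similarity filter class
AugmentT = gcs_augmenters(image_scale=256, method=cv2.INTER_AREA)  # resize/normalize class

operations = [
    FilterT(),        # drop image/caption pairs with CLIP similarity < 0.25
    AugmentT(),       # then augment the survivors
    pygrain.Batch(64, drop_remainder=True),
]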
--- a/flaxdiff/data/sources/videos.py
+++ b/flaxdiff/data/sources/videos.py
@@ -216,6 +216,11 @@ class AudioVideoAugmenter(DataAugmenter):

         return AudioVideoTransform

+
+    def create_filter(self, image_scale: int = 256):
+        class FilterTransform(pygrain.FilterTransform):
+            def map(self, element) -> bool:
+                return True

 # ----------------------------------------------------------------------------------
 # Helper functions for video datasets
--- a/flaxdiff/inference/utils.py
+++ b/flaxdiff/inference/utils.py
@@ -25,6 +25,9 @@ from flaxdiff.models.autoencoder.diffusers import StableDiffusionVAE
 from flaxdiff.inputs import DiffusionInputConfig, ConditionalInputConfig
 from flaxdiff.utils import defaultTextEncodeModel

+from flaxdiff.models.simple_vit import UViT, SimpleUDiT
+from flaxdiff.models.simple_dit import SimpleDiT
+from flaxdiff.models.simple_mmdit import SimpleMMDiT, HierarchicalMMDiT
 from orbax.checkpoint import CheckpointManager, CheckpointManagerOptions, PyTreeCheckpointer
 import os

@@ -116,7 +119,10 @@ def parse_config(config, overrides=None):
     MODEL_CLASSES = {
         'unet': Unet,
         'uvit': UViT,
-        'diffusers_unet_simple': FlaxUNet2DConditionModel
+        'diffusers_unet_simple': FlaxUNet2DConditionModel,
+        'simple_dit': SimpleDiT,
+        'simple_uvit': SimpleUDiT,
+        'simple_mmdit': SimpleMMDiT,
     }

     # Map all the leaves of the model config, converting strings to appropriate types
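With the registry extended, a config can select the new transformer backbones by name. A hedged sketch; only the registry names come from the hunk above, and the config keys are hypothetical:

# Hypothetical config fragment naming one of the new architectures.
config = {"architecture": "simple_mmdit", "model_kwargs": {}}

model_class = MODEL_CLASSES[config["architecture"]]  # -> SimpleMMDiT
model = model_class(**config["model_kwargs"])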
--- a/flaxdiff/models/common.py
+++ b/flaxdiff/models/common.py
@@ -335,73 +335,4 @@ class ResidualBlock(nn.Module):

         out = jnp.concatenate([out, extra_features], axis=-1) if extra_features is not None else out

-        return out
-
-# Convert Hilbert index d to 2D coordinates (x, y) for an n x n grid
-def _d2xy(n, d):
-    x = 0
-    y = 0
-    t = d
-    s = 1
-    while s < n:
-        rx = (t // 2) & 1
-        ry = (t ^ rx) & 1
-        if ry == 0:
-            if rx == 1:
-                x = n - 1 - x
-                y = n - 1 - y
-            x, y = y, x
-        x += s * rx
-        y += s * ry
-        t //= 4
-        s *= 2
-    return x, y
-
-# Hilbert index mapping for a rectangular grid of patches H_P x W_P
-
-def hilbert_indices(H_P, W_P):
-    size = max(H_P, W_P)
-    order = math.ceil(math.log2(size))
-    n = 1 << order
-    coords = []
-    for d in range(n * n):
-        x, y = _d2xy(n, d)
-        # x is column index, y is row index
-        if x < W_P and y < H_P:
-            coords.append((y, x))  # (row, col)
-        if len(coords) == H_P * W_P:
-            break
-    # Convert (row, col) to linear indices row-major
-    indices = [r * W_P + c for r, c in coords]
-    return jnp.array(indices, dtype=jnp.int32)
-
-# Inverse permutation: given idx where idx[i] = new position of element i, return inv such that inv[idx[i]] = i
-
-def inverse_permutation(idx):
-    inv = jnp.zeros_like(idx)
-    inv = inv.at[idx].set(jnp.arange(idx.shape[0], dtype=idx.dtype))
-    return inv
-
-# Patchify using Hilbert ordering: extract patches and reorder sequence
-
-def hilbert_patchify(x, patch_size):
-    B, H, W, C = x.shape
-    H_P = H // patch_size
-    W_P = W // patch_size
-    # Extract patches in row-major
-    patches = rearrange(x, 'b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size)
-    idx = hilbert_indices(H_P, W_P)
-    return patches[:, idx, :]
-
-# Unpatchify from Hilbert ordering: reorder sequence back and reconstruct image
-
-def hilbert_unpatchify(patches, patch_size, H, W, C):
-    B, N, D = patches.shape
-    H_P = H // patch_size
-    W_P = W // patch_size
-    inv = inverse_permutation(hilbert_indices(H_P, W_P))
-    # Reorder back to row-major
-    linear = patches[:, inv, :]
-    # Reconstruct image
-    x = rearrange(linear, 'b (h w) (p1 p2 c) -> b (h p1) (w p2) c', h=H_P, w=W_P, p1=patch_size, p2=patch_size, c=C)
-    return x
+        return out