flaxdiff 0.2.6.1__py3-none-any.whl → 0.2.8__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- flaxdiff/data/dataloaders.py +36 -24
- flaxdiff/data/dataset_map.py +2 -2
- flaxdiff/data/sources/base.py +12 -0
- flaxdiff/data/sources/images.py +71 -12
- flaxdiff/data/sources/videos.py +5 -0
- flaxdiff/inference/pipeline.py +9 -4
- flaxdiff/inference/utils.py +2 -2
- flaxdiff/models/common.py +1 -70
- flaxdiff/models/hilbert.py +617 -0
- flaxdiff/models/simple_dit.py +476 -0
- flaxdiff/models/simple_mmdit.py +861 -0
- flaxdiff/models/simple_vit.py +278 -117
- flaxdiff/trainer/general_diffusion_trainer.py +29 -10
- flaxdiff/trainer/simple_trainer.py +113 -19
- {flaxdiff-0.2.6.1.dist-info → flaxdiff-0.2.8.dist-info}/METADATA +1 -1
- {flaxdiff-0.2.6.1.dist-info → flaxdiff-0.2.8.dist-info}/RECORD +18 -16
- {flaxdiff-0.2.6.1.dist-info → flaxdiff-0.2.8.dist-info}/WHEEL +1 -1
- flaxdiff/models/better_uvit.py +0 -380
- {flaxdiff-0.2.6.1.dist-info → flaxdiff-0.2.8.dist-info}/top_level.txt +0 -0
flaxdiff/data/dataloaders.py
CHANGED
@@ -251,6 +251,12 @@ def generate_collate_fn(media_type="image"):
     else: # Default to image
         return image_collate
 
+class CaptionDeletionTransform(pygrain.MapTransform):
+    def map(self, element):
+        """Delete the caption from the element."""
+        if "caption" in element:
+            del element["caption"]
+        return element
 
 def get_dataset_grain(
     data_name="cc12m",
@@ -286,13 +292,14 @@ def get_dataset_grain(
         Dictionary with train dataset function and metadata.
     """
     dataset = datasetMap[data_name]
-
+    train_source = dataset["source"](dataset_source, split="train")
+    # val_source = dataset["source"](dataset_source, split="val")
     augmenter = dataset["augmenter"](image_scale, method)
 
     local_batch_size = batch_size // jax.process_count()
 
     train_sampler = pygrain.IndexSampler(
-        num_records=len(
+        num_records=len(train_source) if count is None else count,
         shuffle=True,
         seed=seed,
         num_epochs=num_epochs,
@@ -300,7 +307,7 @@ def get_dataset_grain(
     )
 
     # val_sampler = pygrain.IndexSampler(
-    #     num_records=len(
+    #     num_records=len(val_source) if count is None else count,
     #     shuffle=False,
     #     seed=seed,
     #     num_epochs=num_epochs,
@@ -310,11 +317,17 @@ def get_dataset_grain(
     def get_trainset():
         transformations = [
             augmenter(),
-            pygrain.Batch(local_batch_size, drop_remainder=True),
         ]
+
+        # if filters:
+        #     print("Adding filters to transformations")
+        #     transformations.append(filters())
+
+        # transformations.append(CaptionDeletionTransform())
+        transformations.append(pygrain.Batch(local_batch_size, drop_remainder=True))
 
         loader = pygrain.DataLoader(
-            data_source=
+            data_source=train_source,
             sampler=train_sampler,
             operations=transformations,
             worker_count=worker_count,
@@ -325,30 +338,29 @@ def get_dataset_grain(
         )
         return loader
 
-
-
-
-
-
+    def get_valset():
+        transformations = [
+            augmenter(),
+            pygrain.Batch(local_batch_size, drop_remainder=True),
+        ]
 
-
-
-
-
-
-
-
-
-
-
-
-    get_valset = get_trainset # For now, use the same function for validation
+        loader = pygrain.DataLoader(
+            data_source=train_source,
+            sampler=train_sampler,
+            operations=transformations,
+            worker_count=2,
+            read_options=pygrain.ReadOptions(
+                read_thread_count, read_buffer_size
+            ),
+            worker_buffer_size=2,
+        )
+        return loader
 
     return {
         "train": get_trainset,
-        "train_len": len(
+        "train_len": len(train_source),
         "val": get_valset,
-        "val_len": len(
+        "val_len": len(train_source),
         "local_batch_size": local_batch_size,
        "global_batch_size": batch_size,
    }
flaxdiff/data/dataset_map.py
CHANGED
@@ -8,7 +8,7 @@ from .sources.videos import VideoTFDSSource, VideoLocalSource, AudioVideoAugment
 # ---------------------------------------------------------------------------------
 
 from .sources.images import data_source_tfds, tfds_augmenters, data_source_gcs
-from .sources.images import data_source_combined_gcs, gcs_augmenters
+from .sources.images import data_source_combined_gcs, gcs_augmenters, gcs_filters
 
 # Configure the following for your datasets
 datasetMap = {
@@ -21,7 +21,7 @@ datasetMap = {
         "augmenter": gcs_augmenters,
     },
     "laiona_coco": {
-        "source": data_source_gcs('datasets/laion12m+
+        "source": data_source_gcs('datasets/laion12m+mscoco_filtered-new'),
         "augmenter": gcs_augmenters,
     },
     "aesthetic_coyo": {
flaxdiff/data/sources/base.py
CHANGED
@@ -62,6 +62,18 @@ class DataAugmenter(ABC):
         """
         pass
 
+    @abstractmethod
+    def create_filter(self, **kwargs) -> Callable[[], pygrain.FilterTransform]:
+        """Create a filter function for the data.
+
+        Args:
+            **kwargs: Additional arguments for the filter.
+
+        Returns:
+            A callable that returns a pygrain.FilterTransform instance.
+        """
+        pass
+
     @staticmethod
     def create(augmenter_type: str, **kwargs) -> 'DataAugmenter':
         """Factory method to create a data augmenter of the specified type.
flaxdiff/data/sources/images.py
CHANGED
@@ -82,7 +82,7 @@ def labelizer_oxford_flowers102(path):
 class ImageTFDSSource(DataSource):
     """Data source for TensorFlow Datasets (TFDS) image datasets."""
 
-    def __init__(self, name: str, use_tf: bool = True
+    def __init__(self, name: str, use_tf: bool = True):
         """Initialize a TFDS image data source.
 
         Args:
@@ -92,9 +92,8 @@ class ImageTFDSSource(DataSource):
         """
         self.name = name
         self.use_tf = use_tf
-        self.split = split
 
-    def get_source(self, path_override: str) -> Any:
+    def get_source(self, path_override: str, split: str = "all") -> Any:
         """Get the TFDS data source.
 
         Args:
@@ -105,20 +104,22 @@ class ImageTFDSSource(DataSource):
         """
         import tensorflow_datasets as tfds
         if self.use_tf:
-            return tfds.load(self.name, split=
+            return tfds.load(self.name, split=split, shuffle_files=True)
         else:
-            return tfds.data_source(self.name, split=
+            return tfds.data_source(self.name, split=split, try_gcs=False)
 
 
 class ImageTFDSAugmenter(DataAugmenter):
     """Augmenter for TFDS image datasets."""
 
-    def __init__(self, label_path: str =
+    def __init__(self, label_path: str = None):
         """Initialize a TFDS image augmenter.
 
         Args:
             label_path: Path to the labels file for datasets like Oxford Flowers.
         """
+        if label_path is None:
+            label_path = os.path.join(os.path.expanduser("~"), "tensorflow_datasets/oxford_flowers102/2.1.1/label.labels.txt")
         self.label_path = label_path
 
     def create_transform(self, image_scale: int = 256, method: Any = None) -> Callable[[], pygrain.MapTransform]:
@@ -166,7 +167,11 @@ class ImageTFDSAugmenter(DataAugmenter):
         }
 
         return TFDSTransform
-
+
+    def create_filter(self, image_scale: int = 256):
+        class FilterTransform(pygrain.FilterTransform):
+            def map(self, element) -> bool:
+                return True
     """
     Batch structure:
     {
@@ -193,7 +198,7 @@ class ImageGCSSource(DataSource):
         """
         self.source = source
 
-    def get_source(self, path_override: str = "/home/mrwhite0racle/gcs_mount") -> Any:
+    def get_source(self, path_override: str = "/home/mrwhite0racle/gcs_mount", split: str = "train") -> Any:
         """Get the GCS data source.
 
         Args:
@@ -205,6 +210,8 @@ class ImageGCSSource(DataSource):
         records_path = os.path.join(path_override, self.source)
         records = [os.path.join(records_path, i) for i in os.listdir(
             records_path) if 'array_record' in i]
+        if split == "val":
+            records = records[:1]
         return pygrain.ArrayRecordDataSource(records)
 
 
@@ -219,7 +226,7 @@ class CombinedImageGCSSource(DataSource):
         """
         self.sources = sources
 
-    def get_source(self, path_override: str = "/home/mrwhite0racle/gcs_mount") -> Any:
+    def get_source(self, path_override: str = "/home/mrwhite0racle/gcs_mount", split: str = "train") -> Any:
         """Get the combined GCS data source.
 
         Args:
@@ -233,9 +240,10 @@ class CombinedImageGCSSource(DataSource):
         for records_path in records_paths:
             records += [os.path.join(records_path, i) for i in os.listdir(
                 records_path) if 'array_record' in i]
+        if split == "val":
+            records = records[:1]
         return pygrain.ArrayRecordDataSource(records)
 
-
 class ImageGCSAugmenter(DataAugmenter):
     """Augmenter for GCS image datasets."""
 
@@ -295,6 +303,52 @@ class ImageGCSAugmenter(DataAugmenter):
         }
 
         return GCSTransform
+
+    def create_filter(self, image_scale: int = 256):
+        import torch.nn.functional as F
+        class FilterTransform(pygrain.FilterTransform):
+            """
+            Filter transform for GCS data source.
+            """
+            def __init__(self, model=None, processor=None, method=cv2.INTER_AREA):
+                super().__init__()
+                self.image_scale = image_scale
+                if model is None:
+                    from transformers import AutoProcessor, CLIPVisionModelWithProjection, FlaxCLIPModel, CLIPModel
+                    model_name = "openai/clip-vit-base-patch32"
+                    model = CLIPModel.from_pretrained(model_name)
+                    processor = AutoProcessor.from_pretrained(model_name, use_fast=False)
+                self.method = method
+                self.model = model
+                self.processor = processor
+
+                # def _filter_(pixel_values, input_ids):
+                #     image_embeds = self.model.get_image_features(pixel_values=pixel_values)
+                #     text_embeds = self.model.get_text_features(input_ids=input_ids)
+                #     image_embeds = image_embeds / jnp.linalg.norm(image_embeds, axis=-1, keepdims=True)
+                #     text_embeds = text_embeds / jnp.linalg.norm(text_embeds, axis=-1, keepdims=True)
+                #     similarity = jnp.sum(image_embeds * text_embeds, axis=-1)
+                #     return jnp.all(similarity >= 0.25)
+
+                # self._filter_ = _filter_
+
+            def filter(self, data: Dict[str, Any]) -> bool:
+                images = [data['image']]
+                texts = [data['caption']]
+                inputs = self.processor(text=texts, images=images, return_tensors="pt", padding=True, truncation=True)
+                # result = self._filter_(
+                #     pixel_values=inputs['pixel_values'],
+                #     input_ids=inputs['input_ids']
+                # )
+                # return result
+
+                image_embeds = self.model.get_image_features(pixel_values=inputs['pixel_values'])
+                text_embeds = self.model.get_text_features(input_ids=inputs['input_ids'])
+                similarity = F.cosine_similarity(image_embeds, text_embeds)
+                # Filter out images with similarity less than 0.25
+                return similarity[0] >= 0.25
+
+        return FilterTransform
 
 
 # ----------------------------------------------------------------------------------
@@ -303,9 +357,9 @@ class ImageGCSAugmenter(DataAugmenter):
 
 # These functions maintain backward compatibility with existing code
 
-def data_source_tfds(name, use_tf=True
+def data_source_tfds(name, use_tf=True):
     """Legacy function for TFDS data sources."""
-    source = ImageTFDSSource(name=name, use_tf=use_tf
+    source = ImageTFDSSource(name=name, use_tf=use_tf)
     return source.get_source
 
 
@@ -331,3 +385,8 @@ def gcs_augmenters(image_scale, method):
    """Legacy function for GCS augmenters."""
    augmenter = ImageGCSAugmenter()
    return augmenter.create_transform(image_scale=image_scale, method=method)
+
+def gcs_filters(image_scale):
+    """Legacy function for GCS Filters."""
+    augmenter = ImageGCSAugmenter()
+    return augmenter.create_filter(image_scale=image_scale)
flaxdiff/data/sources/videos.py
CHANGED
@@ -216,6 +216,11 @@ class AudioVideoAugmenter(DataAugmenter):
 
         return AudioVideoTransform
 
+
+    def create_filter(self, image_scale: int = 256):
+        class FilterTransform(pygrain.FilterTransform):
+            def map(self, element) -> bool:
+                return True
 
 # ----------------------------------------------------------------------------------
 # Helper functions for video datasets
flaxdiff/inference/pipeline.py
CHANGED
@@ -25,6 +25,7 @@ from flaxdiff.inference.utils import parse_config, load_from_wandb_run, load_fro
 @dataclass
 class InferencePipeline:
     """Inference pipeline for a general model."""
+    name: str = None
     model: nn.Module = None
     state: SimpleTrainState = None
     best_state: SimpleTrainState = None
@@ -44,6 +45,7 @@ class DiffusionInferencePipeline(InferencePipeline):
     This pipeline handles loading models from wandb and generating samples using the
     DiffusionSampler from FlaxDiff.
     """
+    artifact: Any = None
     state: TrainState = None
     best_state: TrainState = None
     rngstate: Optional[RandomMarkovState] = None
@@ -51,7 +53,6 @@ class DiffusionInferencePipeline(InferencePipeline):
     model_output_transform: DiffusionPredictionTransform = None
     autoencoder: AutoEncoder = None
     input_config: DiffusionInputConfig = None
-    wandb_run = None
     samplers: Dict[Type[DiffusionSampler], Dict[float, DiffusionSampler]] = field(default_factory=dict)
     config: Dict[str, Any] = field(default_factory=dict)
 
@@ -76,7 +77,7 @@ class DiffusionInferencePipeline(InferencePipeline):
         Returns:
             DiffusionInferencePipeline instance
         """
-        states, config, run = load_from_wandb_run(
+        states, config, run, artifact = load_from_wandb_run(
             wandb_run,
             project=project,
             entity=entity,
@@ -95,6 +96,7 @@ class DiffusionInferencePipeline(InferencePipeline):
             best_state=best_state,
             rngstate=RandomMarkovState(jax.random.PRNGKey(42)),
             run=run,
+            artifact=artifact,
         )
         return pipeline
 
@@ -119,7 +121,7 @@ class DiffusionInferencePipeline(InferencePipeline):
         Returns:
             DiffusionInferencePipeline instance
         """
-        states, config, run = load_from_wandb_registry(
+        states, config, run, artifact = load_from_wandb_registry(
             modelname=modelname,
             project=project,
             entity=entity,
@@ -140,6 +142,7 @@ class DiffusionInferencePipeline(InferencePipeline):
             best_state=best_state,
             rngstate=RandomMarkovState(jax.random.PRNGKey(42)),
             run=run,
+            artifact=artifact,
         )
         return pipeline
 
@@ -151,11 +154,14 @@ class DiffusionInferencePipeline(InferencePipeline):
         best_state: Optional[Dict[str, Any]] = None,
         rngstate: Optional[RandomMarkovState] = None,
         run=None,
+        artifact=None,
     ):
         if rngstate is None:
             rngstate = RandomMarkovState(jax.random.PRNGKey(42))
         # Build and return pipeline
         return cls(
+            name=run.name if run else None,
+            artifact=artifact,
             model=config['model'],
             state=state,
             best_state=best_state,
@@ -165,7 +171,6 @@ class DiffusionInferencePipeline(InferencePipeline):
             autoencoder=config['autoencoder'],
             input_config=config['input_config'],
             config=config,
-            wandb_run=run,
         )
 
     def get_sampler(
flaxdiff/inference/utils.py
CHANGED
@@ -292,7 +292,7 @@ def load_from_wandb_run(
         config = run.config
     except Exception as e:
         print(f"Warning: Failed to load model from wandb: {e}")
-    return states, config, run
+    return states, config, run, artifact
 
 def load_from_wandb_registry(
     modelname: str,
@@ -318,4 +318,4 @@ def load_from_wandb_registry(
         config = run.config
     except Exception as e:
         print(f"Warning: Failed to load model from wandb: {e}")
-    return states, config, run
+    return states, config, run, artifact
flaxdiff/models/common.py
CHANGED
@@ -335,73 +335,4 @@ class ResidualBlock(nn.Module):
 
         out = jnp.concatenate([out, extra_features], axis=-1) if extra_features is not None else out
 
-        return out
-
-# Convert Hilbert index d to 2D coordinates (x, y) for an n x n grid
-def _d2xy(n, d):
-    x = 0
-    y = 0
-    t = d
-    s = 1
-    while s < n:
-        rx = (t // 2) & 1
-        ry = (t ^ rx) & 1
-        if ry == 0:
-            if rx == 1:
-                x = n - 1 - x
-                y = n - 1 - y
-            x, y = y, x
-        x += s * rx
-        y += s * ry
-        t //= 4
-        s *= 2
-    return x, y
-
-# Hilbert index mapping for a rectangular grid of patches H_P x W_P
-
-def hilbert_indices(H_P, W_P):
-    size = max(H_P, W_P)
-    order = math.ceil(math.log2(size))
-    n = 1 << order
-    coords = []
-    for d in range(n * n):
-        x, y = _d2xy(n, d)
-        # x is column index, y is row index
-        if x < W_P and y < H_P:
-            coords.append((y, x))  # (row, col)
-        if len(coords) == H_P * W_P:
-            break
-    # Convert (row, col) to linear indices row-major
-    indices = [r * W_P + c for r, c in coords]
-    return jnp.array(indices, dtype=jnp.int32)
-
-# Inverse permutation: given idx where idx[i] = new position of element i, return inv such that inv[idx[i]] = i
-
-def inverse_permutation(idx):
-    inv = jnp.zeros_like(idx)
-    inv = inv.at[idx].set(jnp.arange(idx.shape[0], dtype=idx.dtype))
-    return inv
-
-# Patchify using Hilbert ordering: extract patches and reorder sequence
-
-def hilbert_patchify(x, patch_size):
-    B, H, W, C = x.shape
-    H_P = H // patch_size
-    W_P = W // patch_size
-    # Extract patches in row-major
-    patches = rearrange(x, 'b (h p1) (w p2) c -> b (h w) (p1 p2 c)', p1=patch_size, p2=patch_size)
-    idx = hilbert_indices(H_P, W_P)
-    return patches[:, idx, :]
-
-# Unpatchify from Hilbert ordering: reorder sequence back and reconstruct image
-
-def hilbert_unpatchify(patches, patch_size, H, W, C):
-    B, N, D = patches.shape
-    H_P = H // patch_size
-    W_P = W // patch_size
-    inv = inverse_permutation(hilbert_indices(H_P, W_P))
-    # Reorder back to row-major
-    linear = patches[:, inv, :]
-    # Reconstruct image
-    x = rearrange(linear, 'b (h w) (p1 p2 c) -> b (h p1) (w p2) c', h=H_P, w=W_P, p1=patch_size, p2=patch_size, c=C)
-    return x
+        return out