flaxdiff 0.2.8__py3-none-any.whl → 0.2.10__py3-none-any.whl
This diff compares the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- flaxdiff/data/dataloaders.py +11 -19
- flaxdiff/data/dataset_map.py +2 -1
- flaxdiff/data/sources/images.py +29 -14
- flaxdiff/inference/utils.py +7 -1
- flaxdiff/models/simple_dit.py +1 -202
- flaxdiff/models/simple_mmdit.py +1 -132
- flaxdiff/models/simple_vit.py +217 -118
- flaxdiff/models/vit_common.py +262 -0
- flaxdiff/trainer/general_diffusion_trainer.py +22 -11
- flaxdiff/trainer/simple_trainer.py +14 -12
- {flaxdiff-0.2.8.dist-info → flaxdiff-0.2.10.dist-info}/METADATA +1 -1
- {flaxdiff-0.2.8.dist-info → flaxdiff-0.2.10.dist-info}/RECORD +14 -13
- {flaxdiff-0.2.8.dist-info → flaxdiff-0.2.10.dist-info}/WHEEL +0 -0
- {flaxdiff-0.2.8.dist-info → flaxdiff-0.2.10.dist-info}/top_level.txt +0 -0
flaxdiff/data/dataloaders.py
CHANGED
@@ -292,14 +292,12 @@ def get_dataset_grain(
|
|
292
292
|
Dictionary with train dataset function and metadata.
|
293
293
|
"""
|
294
294
|
dataset = datasetMap[data_name]
|
295
|
-
|
296
|
-
# val_source = dataset["source"](dataset_source, split="val")
|
295
|
+
data_source = dataset["source"](dataset_source)
|
297
296
|
augmenter = dataset["augmenter"](image_scale, method)
|
298
|
-
|
299
297
|
local_batch_size = batch_size // jax.process_count()
|
300
298
|
|
301
299
|
train_sampler = pygrain.IndexSampler(
|
302
|
-
num_records=len(
|
300
|
+
num_records=len(data_source) if count is None else count,
|
303
301
|
shuffle=True,
|
304
302
|
seed=seed,
|
305
303
|
num_epochs=num_epochs,
|
@@ -307,7 +305,7 @@ def get_dataset_grain(
|
|
307
305
|
)
|
308
306
|
|
309
307
|
# val_sampler = pygrain.IndexSampler(
|
310
|
-
# num_records=len(
|
308
|
+
# num_records=len(data_source) if count is None else count,
|
311
309
|
# shuffle=False,
|
312
310
|
# seed=seed,
|
313
311
|
# num_epochs=num_epochs,
|
@@ -318,16 +316,10 @@ def get_dataset_grain(
|
|
318
316
|
transformations = [
|
319
317
|
augmenter(),
|
320
318
|
]
|
321
|
-
|
322
|
-
# if filters:
|
323
|
-
# print("Adding filters to transformations")
|
324
|
-
# transformations.append(filters())
|
325
|
-
|
326
|
-
# transformations.append(CaptionDeletionTransform())
|
327
319
|
transformations.append(pygrain.Batch(local_batch_size, drop_remainder=True))
|
328
320
|
|
329
321
|
loader = pygrain.DataLoader(
|
330
|
-
data_source=
|
322
|
+
data_source=data_source,
|
331
323
|
sampler=train_sampler,
|
332
324
|
operations=transformations,
|
333
325
|
worker_count=worker_count,
|
@@ -341,26 +333,26 @@ def get_dataset_grain(
|
|
341
333
|
def get_valset():
|
342
334
|
transformations = [
|
343
335
|
augmenter(),
|
344
|
-
pygrain.Batch(
|
336
|
+
pygrain.Batch(32, drop_remainder=True),
|
345
337
|
]
|
346
338
|
|
347
339
|
loader = pygrain.DataLoader(
|
348
|
-
data_source=
|
340
|
+
data_source=data_source,
|
349
341
|
sampler=train_sampler,
|
350
342
|
operations=transformations,
|
351
|
-
worker_count=
|
343
|
+
worker_count=8,
|
352
344
|
read_options=pygrain.ReadOptions(
|
353
|
-
|
345
|
+
32, 128
|
354
346
|
),
|
355
|
-
worker_buffer_size=
|
347
|
+
worker_buffer_size=32,
|
356
348
|
)
|
357
349
|
return loader
|
358
350
|
|
359
351
|
return {
|
360
352
|
"train": get_trainset,
|
361
|
-
"train_len": len(
|
353
|
+
"train_len": len(data_source),
|
362
354
|
"val": get_valset,
|
363
|
-
"val_len": len(
|
355
|
+
"val_len": len(data_source),
|
364
356
|
"local_batch_size": local_batch_size,
|
365
357
|
"global_batch_size": batch_size,
|
366
358
|
}
|
flaxdiff/data/dataset_map.py
CHANGED
@@ -21,8 +21,9 @@ datasetMap = {
|
|
21
21
|
"augmenter": gcs_augmenters,
|
22
22
|
},
|
23
23
|
"laiona_coco": {
|
24
|
-
"source": data_source_gcs('datasets/laion12m+
|
24
|
+
"source": data_source_gcs('datasets/laion12m+mscoco'),
|
25
25
|
"augmenter": gcs_augmenters,
|
26
|
+
"filter": gcs_filters,
|
26
27
|
},
|
27
28
|
"aesthetic_coyo": {
|
28
29
|
"source": data_source_gcs('arrayrecords/aestheticCoyo_0.25clip_6aesthetic'),
|
flaxdiff/data/sources/images.py
CHANGED
@@ -11,7 +11,7 @@ import struct as st
|
|
11
11
|
from functools import partial
|
12
12
|
import numpy as np
|
13
13
|
from .base import DataSource, DataAugmenter
|
14
|
-
|
14
|
+
import traceback
|
15
15
|
|
16
16
|
# ----------------------------------------------------------------------------------
|
17
17
|
# Utility functions
|
@@ -79,10 +79,28 @@ def labelizer_oxford_flowers102(path):
|
|
79
79
|
# TFDS Image Source
|
80
80
|
# ----------------------------------------------------------------------------------
|
81
81
|
|
82
|
+
def get_oxford_valset(text_encoder):
|
83
|
+
# Construct a validation set by the prompts for consistency
|
84
|
+
val_prompts = ['water tulip', ' a water lily', ' a water lily', ' a photo of a rose', ' a photo of a rose', ' a water lily', ' a water lily', ' a photo of a marigold', ' a photo of a marigold', ' a photo of a marigold', ' a water lily', ' a photo of a sunflower', ' a photo of a lotus', ' columbine', ' columbine', ' an orchid', ' an orchid', ' an orchid', ' a water lily', ' a water lily', ' a water lily', ' columbine', ' columbine', ' a photo of a sunflower', ' a photo of a sunflower', ' a photo of a sunflower', ' a photo of a lotus', ' a photo of a lotus', ' a photo of a marigold', ' a photo of a marigold', ' a photo of a rose', ' a photo of a rose', ' a photo of a rose', ' orange dahlia', ' orange dahlia', ' a lenten rose', ' a lenten rose', ' a water lily', ' a water lily', ' a water lily', ' a water lily', ' an orchid', ' an orchid', ' an orchid', ' hard-leaved pocket orchid', ' bird of paradise', ' bird of paradise', ' a photo of a lovely rose', ' a photo of a lovely rose', ' a photo of a globe-flower', ' a photo of a globe-flower', ' a photo of a lovely rose', ' a photo of a lovely rose', ' a photo of a ruby-lipped cattleya', ' a photo of a ruby-lipped cattleya', ' a photo of a lovely rose', ' a water lily', ' a osteospermum', ' a osteospermum', ' a water lily', ' a water lily', ' a water lily', ' a red rose', ' a red rose']
|
85
|
+
val_prompts *= 100
|
86
|
+
|
87
|
+
def get_val_dataset(batch_size=128):
|
88
|
+
for i in range(0, len(val_prompts), batch_size):
|
89
|
+
try:
|
90
|
+
prompts = val_prompts[i:i + batch_size]
|
91
|
+
tokens = text_encoder.tokenize(prompts)
|
92
|
+
yield {"text": tokens}
|
93
|
+
except Exception as e:
|
94
|
+
print(f"Error in get_val_dataset: {e}")
|
95
|
+
traceback.print_exc()
|
96
|
+
continue
|
97
|
+
|
98
|
+
return get_val_dataset, len(val_prompts)
|
99
|
+
|
82
100
|
class ImageTFDSSource(DataSource):
|
83
101
|
"""Data source for TensorFlow Datasets (TFDS) image datasets."""
|
84
102
|
|
85
|
-
def __init__(self, name: str, use_tf: bool = True):
|
103
|
+
def __init__(self, name: str, use_tf: bool = True, split: str = "all"):
|
86
104
|
"""Initialize a TFDS image data source.
|
87
105
|
|
88
106
|
Args:
|
@@ -92,8 +110,9 @@ class ImageTFDSSource(DataSource):
|
|
92
110
|
"""
|
93
111
|
self.name = name
|
94
112
|
self.use_tf = use_tf
|
113
|
+
self.split = split
|
95
114
|
|
96
|
-
def get_source(self, path_override: str
|
115
|
+
def get_source(self, path_override: str) -> Any:
|
97
116
|
"""Get the TFDS data source.
|
98
117
|
|
99
118
|
Args:
|
@@ -104,9 +123,9 @@ class ImageTFDSSource(DataSource):
|
|
104
123
|
"""
|
105
124
|
import tensorflow_datasets as tfds
|
106
125
|
if self.use_tf:
|
107
|
-
return tfds.load(self.name, split=split, shuffle_files=True)
|
126
|
+
return tfds.load(self.name, split=self.split, shuffle_files=True)
|
108
127
|
else:
|
109
|
-
return tfds.data_source(self.name, split=split, try_gcs=False)
|
128
|
+
return tfds.data_source(self.name, split=self.split, try_gcs=False)
|
110
129
|
|
111
130
|
|
112
131
|
class ImageTFDSAugmenter(DataAugmenter):
|
@@ -198,7 +217,7 @@ class ImageGCSSource(DataSource):
|
|
198
217
|
"""
|
199
218
|
self.source = source
|
200
219
|
|
201
|
-
def get_source(self, path_override: str = "/home/mrwhite0racle/gcs_mount"
|
220
|
+
def get_source(self, path_override: str = "/home/mrwhite0racle/gcs_mount") -> Any:
|
202
221
|
"""Get the GCS data source.
|
203
222
|
|
204
223
|
Args:
|
@@ -210,8 +229,6 @@ class ImageGCSSource(DataSource):
|
|
210
229
|
records_path = os.path.join(path_override, self.source)
|
211
230
|
records = [os.path.join(records_path, i) for i in os.listdir(
|
212
231
|
records_path) if 'array_record' in i]
|
213
|
-
if split == "val":
|
214
|
-
records = records[:1]
|
215
232
|
return pygrain.ArrayRecordDataSource(records)
|
216
233
|
|
217
234
|
|
@@ -226,7 +243,7 @@ class CombinedImageGCSSource(DataSource):
|
|
226
243
|
"""
|
227
244
|
self.sources = sources
|
228
245
|
|
229
|
-
def get_source(self, path_override: str = "/home/mrwhite0racle/gcs_mount"
|
246
|
+
def get_source(self, path_override: str = "/home/mrwhite0racle/gcs_mount") -> Any:
|
230
247
|
"""Get the combined GCS data source.
|
231
248
|
|
232
249
|
Args:
|
@@ -240,8 +257,6 @@ class CombinedImageGCSSource(DataSource):
|
|
240
257
|
for records_path in records_paths:
|
241
258
|
records += [os.path.join(records_path, i) for i in os.listdir(
|
242
259
|
records_path) if 'array_record' in i]
|
243
|
-
if split == "val":
|
244
|
-
records = records[:1]
|
245
260
|
return pygrain.ArrayRecordDataSource(records)
|
246
261
|
|
247
262
|
class ImageGCSAugmenter(DataAugmenter):
|
@@ -357,9 +372,9 @@ class ImageGCSAugmenter(DataAugmenter):
|
|
357
372
|
|
358
373
|
# These functions maintain backward compatibility with existing code
|
359
374
|
|
360
|
-
def data_source_tfds(name, use_tf=True):
|
375
|
+
def data_source_tfds(name, use_tf=True, split="all"):
|
361
376
|
"""Legacy function for TFDS data sources."""
|
362
|
-
source = ImageTFDSSource(name=name, use_tf=use_tf)
|
377
|
+
source = ImageTFDSSource(name=name, use_tf=use_tf, split=split)
|
363
378
|
return source.get_source
|
364
379
|
|
365
380
|
|
@@ -389,4 +404,4 @@ def gcs_augmenters(image_scale, method):
|
|
389
404
|
def gcs_filters(image_scale):
|
390
405
|
"""Legacy function for GCS Filters."""
|
391
406
|
augmenter = ImageGCSAugmenter()
|
392
|
-
return augmenter.create_filter(image_scale=image_scale)
|
407
|
+
return augmenter.create_filter(image_scale=image_scale)
|
flaxdiff/inference/utils.py
CHANGED
@@ -25,6 +25,9 @@ from flaxdiff.models.autoencoder.diffusers import StableDiffusionVAE
|
|
25
25
|
from flaxdiff.inputs import DiffusionInputConfig, ConditionalInputConfig
|
26
26
|
from flaxdiff.utils import defaultTextEncodeModel
|
27
27
|
|
28
|
+
from flaxdiff.models.simple_vit import UViT, SimpleUDiT
|
29
|
+
from flaxdiff.models.simple_dit import SimpleDiT
|
30
|
+
from flaxdiff.models.simple_mmdit import SimpleMMDiT, HierarchicalMMDiT
|
28
31
|
from orbax.checkpoint import CheckpointManager, CheckpointManagerOptions, PyTreeCheckpointer
|
29
32
|
import os
|
30
33
|
|
@@ -116,7 +119,10 @@ def parse_config(config, overrides=None):
|
|
116
119
|
MODEL_CLASSES = {
|
117
120
|
'unet': Unet,
|
118
121
|
'uvit': UViT,
|
119
|
-
'diffusers_unet_simple': FlaxUNet2DConditionModel
|
122
|
+
'diffusers_unet_simple': FlaxUNet2DConditionModel,
|
123
|
+
'simple_dit': SimpleDiT,
|
124
|
+
'simple_mmdit': SimpleMMDiT,
|
125
|
+
'simple_udit': SimpleUDiT,
|
120
126
|
}
|
121
127
|
|
122
128
|
# Map all the leaves of the model config, converting strings to appropriate types
|
flaxdiff/models/simple_dit.py
CHANGED
@@ -1,4 +1,3 @@
|
|
1
|
-
from .simple_vit import PatchEmbedding, unpatchify
|
2
1
|
import jax
|
3
2
|
import jax.numpy as jnp
|
4
3
|
from flax import linen as nn
|
@@ -7,6 +6,7 @@ import einops
|
|
7
6
|
from functools import partial
|
8
7
|
|
9
8
|
# Re-use existing components if they are suitable
|
9
|
+
from .vit_common import PatchEmbedding, unpatchify, RotaryEmbedding, RoPEAttention, AdaLNParams
|
10
10
|
from .common import kernel_init, FourierEmbedding, TimeProjection
|
11
11
|
# Using NormalAttention for RoPE integration
|
12
12
|
from .attention import NormalAttention
|
@@ -15,207 +15,6 @@ from flax.typing import Dtype, PrecisionLike
|
|
15
15
|
# Use our improved Hilbert implementation
|
16
16
|
from .hilbert import hilbert_indices, inverse_permutation, hilbert_patchify, hilbert_unpatchify
|
17
17
|
|
18
|
-
# --- Rotary Positional Embedding (RoPE) ---
|
19
|
-
# Adapted from https://github.com/google-deepmind/ring_attention/blob/main/ring_attention/layers/rotary.py
|
20
|
-
|
21
|
-
|
22
|
-
def _rotate_half(x: jax.Array) -> jax.Array:
|
23
|
-
"""Rotates half the hidden dims of the input."""
|
24
|
-
x1 = x[..., : x.shape[-1] // 2]
|
25
|
-
x2 = x[..., x.shape[-1] // 2:]
|
26
|
-
return jnp.concatenate((-x2, x1), axis=-1)
|
27
|
-
|
28
|
-
def apply_rotary_embedding(
|
29
|
-
x: jax.Array, freqs_cos: jax.Array, freqs_sin: jax.Array
|
30
|
-
) -> jax.Array:
|
31
|
-
"""Applies rotary embedding to the input tensor using rotate_half method."""
|
32
|
-
# x shape: [..., Sequence, Dimension] e.g. [B, H, S, D] or [B, S, D]
|
33
|
-
# freqs_cos/sin shape: [Sequence, Dimension / 2]
|
34
|
-
|
35
|
-
# Expand dims for broadcasting: [1, 1, S, D/2] or [1, S, D/2]
|
36
|
-
if x.ndim == 4: # [B, H, S, D]
|
37
|
-
cos_freqs = jnp.expand_dims(freqs_cos, axis=(0, 1))
|
38
|
-
sin_freqs = jnp.expand_dims(freqs_sin, axis=(0, 1))
|
39
|
-
elif x.ndim == 3: # [B, S, D]
|
40
|
-
cos_freqs = jnp.expand_dims(freqs_cos, axis=0)
|
41
|
-
sin_freqs = jnp.expand_dims(freqs_sin, axis=0)
|
42
|
-
|
43
|
-
# Duplicate cos and sin for the full dimension D
|
44
|
-
# Shape becomes [..., S, D]
|
45
|
-
cos_freqs = jnp.concatenate([cos_freqs, cos_freqs], axis=-1)
|
46
|
-
sin_freqs = jnp.concatenate([sin_freqs, sin_freqs], axis=-1)
|
47
|
-
|
48
|
-
# Apply rotation: x * cos + rotate_half(x) * sin
|
49
|
-
x_rotated = x * cos_freqs + _rotate_half(x) * sin_freqs
|
50
|
-
return x_rotated.astype(x.dtype)
|
51
|
-
|
52
|
-
class RotaryEmbedding(nn.Module):
|
53
|
-
dim: int # Dimension of the head
|
54
|
-
max_seq_len: int = 2048
|
55
|
-
base: int = 10000
|
56
|
-
dtype: Dtype = jnp.float32
|
57
|
-
|
58
|
-
def setup(self):
|
59
|
-
inv_freq = 1.0 / (
|
60
|
-
self.base ** (jnp.arange(0, self.dim, 2,
|
61
|
-
dtype=jnp.float32) / self.dim)
|
62
|
-
)
|
63
|
-
t = jnp.arange(self.max_seq_len, dtype=jnp.float32)
|
64
|
-
freqs = jnp.outer(t, inv_freq) # Shape: [max_seq_len, dim / 2]
|
65
|
-
|
66
|
-
# Store cosine and sine separately instead of as complex numbers
|
67
|
-
self.freqs_cos = jnp.cos(freqs) # Shape: [max_seq_len, dim / 2]
|
68
|
-
self.freqs_sin = jnp.sin(freqs) # Shape: [max_seq_len, dim / 2]
|
69
|
-
|
70
|
-
def __call__(self, seq_len: int):
|
71
|
-
if seq_len > self.max_seq_len:
|
72
|
-
raise ValueError(
|
73
|
-
f"Sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}")
|
74
|
-
# Return separate cos and sin components
|
75
|
-
return self.freqs_cos[:seq_len, :], self.freqs_sin[:seq_len, :]
|
76
|
-
# --- Attention with RoPE ---
|
77
|
-
|
78
|
-
|
79
|
-
class RoPEAttention(NormalAttention):
|
80
|
-
rope_emb: RotaryEmbedding = None # Instance of RotaryEmbedding
|
81
|
-
|
82
|
-
@nn.compact
|
83
|
-
def __call__(self, x, context=None, freqs_cis=None):
|
84
|
-
# x has shape [B, H, W, C] or [B, S, C]
|
85
|
-
orig_x_shape = x.shape
|
86
|
-
is_4d = len(orig_x_shape) == 4
|
87
|
-
if is_4d:
|
88
|
-
B, H, W, C = x.shape
|
89
|
-
seq_len = H * W
|
90
|
-
x = x.reshape((B, seq_len, C))
|
91
|
-
else:
|
92
|
-
B, seq_len, C = x.shape
|
93
|
-
|
94
|
-
context = x if context is None else context
|
95
|
-
if len(context.shape) == 4:
|
96
|
-
_B, _H, _W, _C = context.shape
|
97
|
-
context_seq_len = _H * _W
|
98
|
-
context = context.reshape((B, context_seq_len, _C))
|
99
|
-
# else: context is already [B, S_ctx, C]
|
100
|
-
|
101
|
-
query = self.query(x) # [B, S, H, D]
|
102
|
-
key = self.key(context) # [B, S_ctx, H, D]
|
103
|
-
value = self.value(context) # [B, S_ctx, H, D]
|
104
|
-
|
105
|
-
# Apply RoPE to query and key
|
106
|
-
if freqs_cis is None:
|
107
|
-
# Generate frequencies using the rope_emb instance
|
108
|
-
seq_len_q = query.shape[1] # Use query's sequence length
|
109
|
-
freqs_cos, freqs_sin = self.rope_emb(seq_len_q)
|
110
|
-
else:
|
111
|
-
# If freqs_cis is passed in as a tuple
|
112
|
-
freqs_cos, freqs_sin = freqs_cis
|
113
|
-
|
114
|
-
# Apply RoPE to query and key
|
115
|
-
# Permute to [B, H, S, D] for RoPE application
|
116
|
-
query = einops.rearrange(query, 'b s h d -> b h s d')
|
117
|
-
key = einops.rearrange(key, 'b s h d -> b h s d')
|
118
|
-
|
119
|
-
# Apply RoPE only up to the context sequence length for keys if different
|
120
|
-
# Assuming self-attention or context has same seq len for simplicity here
|
121
|
-
query = apply_rotary_embedding(query, freqs_cos, freqs_sin)
|
122
|
-
key = apply_rotary_embedding(key, freqs_cos, freqs_sin) # Apply same freqs to key
|
123
|
-
|
124
|
-
# Permute back to [B, S, H, D] for dot_product_attention
|
125
|
-
query = einops.rearrange(query, 'b h s d -> b s h d')
|
126
|
-
key = einops.rearrange(key, 'b h s d -> b s h d')
|
127
|
-
|
128
|
-
hidden_states = nn.dot_product_attention(
|
129
|
-
query, key, value, dtype=self.dtype, broadcast_dropout=False,
|
130
|
-
dropout_rng=None, precision=self.precision, force_fp32_for_softmax=self.force_fp32_for_softmax,
|
131
|
-
deterministic=True
|
132
|
-
) # Output shape [B, S, H, D]
|
133
|
-
|
134
|
-
# Use the proj_attn from NormalAttention which expects [B, S, H, D]
|
135
|
-
proj = self.proj_attn(hidden_states) # Output shape [B, S, C]
|
136
|
-
|
137
|
-
if is_4d:
|
138
|
-
proj = proj.reshape(orig_x_shape) # Reshape back if input was 4D
|
139
|
-
|
140
|
-
return proj
|
141
|
-
|
142
|
-
# --- adaLN-Zero ---
|
143
|
-
|
144
|
-
|
145
|
-
class AdaLNZero(nn.Module):
|
146
|
-
features: int
|
147
|
-
dtype: Optional[Dtype] = None
|
148
|
-
precision: PrecisionLike = None
|
149
|
-
norm_epsilon: float = 1e-5 # Standard LayerNorm epsilon
|
150
|
-
|
151
|
-
@nn.compact
|
152
|
-
def __call__(self, x, conditioning):
|
153
|
-
# Project conditioning signal to get scale and shift parameters
|
154
|
-
# Conditioning shape: [B, D_cond] -> [B, 1, ..., 1, 6 * features] for broadcasting
|
155
|
-
# Or [B, 1, 6*features] if x is [B, S, F]
|
156
|
-
|
157
|
-
# Ensure conditioning has seq dim if x does
|
158
|
-
# x=[B,S,F], cond=[B,D_cond]
|
159
|
-
if x.ndim == 3 and conditioning.ndim == 2:
|
160
|
-
conditioning = jnp.expand_dims(
|
161
|
-
conditioning, axis=1) # cond=[B,1,D_cond]
|
162
|
-
|
163
|
-
# Project conditioning to get 6 params per feature (scale_mlp, shift_mlp, gate_mlp, scale_attn, shift_attn, gate_attn)
|
164
|
-
# Using nn.DenseGeneral for flexibility if needed, but nn.Dense is fine if cond is [B, D_cond] or [B, 1, D_cond]
|
165
|
-
ada_params = nn.Dense(
|
166
|
-
features=6 * self.features,
|
167
|
-
dtype=self.dtype,
|
168
|
-
precision=self.precision,
|
169
|
-
# Initialize projection to zero (Zero init)
|
170
|
-
kernel_init=nn.initializers.zeros,
|
171
|
-
name="ada_proj"
|
172
|
-
)(conditioning)
|
173
|
-
|
174
|
-
# Split into scale, shift, gate for MLP and Attention
|
175
|
-
scale_mlp, shift_mlp, gate_mlp, scale_attn, shift_attn, gate_attn = jnp.split(
|
176
|
-
ada_params, 6, axis=-1)
|
177
|
-
|
178
|
-
scale_mlp = jnp.clip(scale_mlp, -10.0, 10.0)
|
179
|
-
shift_mlp = jnp.clip(shift_mlp, -10.0, 10.0)
|
180
|
-
# Apply Layer Normalization
|
181
|
-
norm = nn.LayerNorm(epsilon=self.norm_epsilon,
|
182
|
-
use_scale=False, use_bias=False, dtype=self.dtype)
|
183
|
-
# norm = nn.RMSNorm(epsilon=self.norm_epsilon, dtype=self.dtype) # Alternative: RMSNorm
|
184
|
-
|
185
|
-
norm_x = norm(x)
|
186
|
-
|
187
|
-
# Modulate for Attention path
|
188
|
-
x_attn = norm_x * (1 + scale_attn) + shift_attn
|
189
|
-
|
190
|
-
# Modulate for MLP path
|
191
|
-
x_mlp = norm_x * (1 + scale_mlp) + shift_mlp
|
192
|
-
|
193
|
-
# Return modulated outputs and gates
|
194
|
-
return x_attn, gate_attn, x_mlp, gate_mlp
|
195
|
-
|
196
|
-
class AdaLNParams(nn.Module): # Renamed for clarity
|
197
|
-
features: int
|
198
|
-
dtype: Optional[Dtype] = None
|
199
|
-
precision: PrecisionLike = None
|
200
|
-
|
201
|
-
@nn.compact
|
202
|
-
def __call__(self, conditioning):
|
203
|
-
# Ensure conditioning is broadcastable if needed (e.g., [B, 1, D_cond])
|
204
|
-
if conditioning.ndim == 2:
|
205
|
-
conditioning = jnp.expand_dims(conditioning, axis=1)
|
206
|
-
|
207
|
-
# Project conditioning to get 6 params per feature
|
208
|
-
ada_params = nn.Dense(
|
209
|
-
features=6 * self.features,
|
210
|
-
dtype=self.dtype,
|
211
|
-
precision=self.precision,
|
212
|
-
kernel_init=nn.initializers.zeros,
|
213
|
-
name="ada_proj"
|
214
|
-
)(conditioning)
|
215
|
-
# Return all params (or split if preferred, but maybe return tuple/dict)
|
216
|
-
# Shape: [B, 1, 6*F]
|
217
|
-
return ada_params # Or split and return tuple: jnp.split(ada_params, 6, axis=-1)
|
218
|
-
|
219
18
|
# --- DiT Block ---
|
220
19
|
class DiTBlock(nn.Module):
|
221
20
|
features: int
|
flaxdiff/models/simple_mmdit.py
CHANGED
@@ -7,143 +7,12 @@ from functools import partial
|
|
7
7
|
from flax.typing import Dtype, PrecisionLike
|
8
8
|
|
9
9
|
# Imports from local modules
|
10
|
-
from .
|
10
|
+
from .vit_common import PatchEmbedding, unpatchify, RotaryEmbedding, RoPEAttention
|
11
11
|
from .common import kernel_init, FourierEmbedding, TimeProjection
|
12
12
|
from .attention import NormalAttention # Base for RoPEAttention
|
13
13
|
# Replace common.hilbert_indices with improved implementation from hilbert.py
|
14
14
|
from .hilbert import hilbert_indices, inverse_permutation, hilbert_patchify, hilbert_unpatchify
|
15
15
|
|
16
|
-
# --- Rotary Positional Embedding (RoPE) ---
|
17
|
-
# Re-used from simple_dit.py
|
18
|
-
|
19
|
-
|
20
|
-
def _rotate_half(x: jax.Array) -> jax.Array:
|
21
|
-
"""Rotates half the hidden dims of the input."""
|
22
|
-
x1 = x[..., : x.shape[-1] // 2]
|
23
|
-
x2 = x[..., x.shape[-1] // 2:]
|
24
|
-
return jnp.concatenate((-x2, x1), axis=-1)
|
25
|
-
|
26
|
-
|
27
|
-
def apply_rotary_embedding(
|
28
|
-
x: jax.Array, freqs_cos: jax.Array, freqs_sin: jax.Array
|
29
|
-
) -> jax.Array:
|
30
|
-
"""Applies rotary embedding to the input tensor using rotate_half method."""
|
31
|
-
if x.ndim == 4: # [B, H, S, D]
|
32
|
-
cos_freqs = jnp.expand_dims(freqs_cos, axis=(0, 1))
|
33
|
-
sin_freqs = jnp.expand_dims(freqs_sin, axis=(0, 1))
|
34
|
-
elif x.ndim == 3: # [B, S, D]
|
35
|
-
cos_freqs = jnp.expand_dims(freqs_cos, axis=0)
|
36
|
-
sin_freqs = jnp.expand_dims(freqs_sin, axis=0)
|
37
|
-
else:
|
38
|
-
raise ValueError(f"Unsupported input dimension: {x.ndim}")
|
39
|
-
|
40
|
-
cos_freqs = jnp.concatenate([cos_freqs, cos_freqs], axis=-1)
|
41
|
-
sin_freqs = jnp.concatenate([sin_freqs, sin_freqs], axis=-1)
|
42
|
-
|
43
|
-
x_rotated = x * cos_freqs + _rotate_half(x) * sin_freqs
|
44
|
-
return x_rotated.astype(x.dtype)
|
45
|
-
|
46
|
-
|
47
|
-
class RotaryEmbedding(nn.Module):
|
48
|
-
dim: int
|
49
|
-
max_seq_len: int = 4096 # Increased default based on SimpleDiT
|
50
|
-
base: int = 10000
|
51
|
-
dtype: Dtype = jnp.float32
|
52
|
-
|
53
|
-
def setup(self):
|
54
|
-
inv_freq = 1.0 / (
|
55
|
-
self.base ** (jnp.arange(0, self.dim, 2,
|
56
|
-
dtype=jnp.float32) / self.dim)
|
57
|
-
)
|
58
|
-
t = jnp.arange(self.max_seq_len, dtype=jnp.float32)
|
59
|
-
freqs = jnp.outer(t, inv_freq)
|
60
|
-
self.freqs_cos = jnp.cos(freqs)
|
61
|
-
self.freqs_sin = jnp.sin(freqs)
|
62
|
-
|
63
|
-
def __call__(self, seq_len: int):
|
64
|
-
if seq_len > self.max_seq_len:
|
65
|
-
# Dynamically extend frequencies if needed (more robust)
|
66
|
-
t = jnp.arange(seq_len, dtype=jnp.float32)
|
67
|
-
inv_freq = 1.0 / (
|
68
|
-
self.base ** (jnp.arange(0, self.dim, 2,
|
69
|
-
dtype=jnp.float32) / self.dim)
|
70
|
-
)
|
71
|
-
freqs = jnp.outer(t, inv_freq)
|
72
|
-
freqs_cos = jnp.cos(freqs)
|
73
|
-
freqs_sin = jnp.sin(freqs)
|
74
|
-
# Consider caching extended freqs if this happens often
|
75
|
-
return freqs_cos, freqs_sin
|
76
|
-
# Or raise error like before:
|
77
|
-
# raise ValueError(f"Sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}")
|
78
|
-
return self.freqs_cos[:seq_len, :], self.freqs_sin[:seq_len, :]
|
79
|
-
|
80
|
-
# --- Attention with RoPE ---
|
81
|
-
# Re-used from simple_dit.py
|
82
|
-
|
83
|
-
|
84
|
-
class RoPEAttention(NormalAttention):
|
85
|
-
rope_emb: RotaryEmbedding = None
|
86
|
-
|
87
|
-
@nn.compact
|
88
|
-
def __call__(self, x, context=None, freqs_cis=None):
|
89
|
-
orig_x_shape = x.shape
|
90
|
-
is_4d = len(orig_x_shape) == 4
|
91
|
-
if is_4d:
|
92
|
-
B, H, W, C = x.shape
|
93
|
-
seq_len = H * W
|
94
|
-
x = x.reshape((B, seq_len, C))
|
95
|
-
else:
|
96
|
-
B, seq_len, C = x.shape
|
97
|
-
|
98
|
-
context = x if context is None else context
|
99
|
-
if len(context.shape) == 4:
|
100
|
-
_B, _H, _W, _C = context.shape
|
101
|
-
context_seq_len = _H * _W
|
102
|
-
context = context.reshape((B, context_seq_len, _C))
|
103
|
-
# else: # context is already [B, S_ctx, C]
|
104
|
-
|
105
|
-
query = self.query(x) # [B, S, H, D]
|
106
|
-
key = self.key(context) # [B, S_ctx, H, D]
|
107
|
-
value = self.value(context) # [B, S_ctx, H, D]
|
108
|
-
|
109
|
-
if freqs_cis is None and self.rope_emb is not None:
|
110
|
-
seq_len_q = query.shape[1] # Use query's sequence length
|
111
|
-
freqs_cos, freqs_sin = self.rope_emb(seq_len_q)
|
112
|
-
elif freqs_cis is not None:
|
113
|
-
freqs_cos, freqs_sin = freqs_cis
|
114
|
-
else:
|
115
|
-
# Should not happen if rope_emb is provided or freqs_cis are passed
|
116
|
-
raise ValueError("RoPE frequencies not provided.")
|
117
|
-
|
118
|
-
# Apply RoPE to query and key
|
119
|
-
# Permute to [B, H, S, D] for RoPE application
|
120
|
-
query = einops.rearrange(query, 'b s h d -> b h s d')
|
121
|
-
key = einops.rearrange(key, 'b s h d -> b h s d')
|
122
|
-
|
123
|
-
# Apply RoPE only up to the context sequence length for keys if different
|
124
|
-
# Assuming self-attention or context has same seq len for simplicity here
|
125
|
-
query = apply_rotary_embedding(query, freqs_cos, freqs_sin)
|
126
|
-
key = apply_rotary_embedding(
|
127
|
-
key, freqs_cos, freqs_sin) # Apply same freqs to key
|
128
|
-
|
129
|
-
# Permute back to [B, S, H, D] for dot_product_attention
|
130
|
-
query = einops.rearrange(query, 'b h s d -> b s h d')
|
131
|
-
key = einops.rearrange(key, 'b h s d -> b s h d')
|
132
|
-
|
133
|
-
hidden_states = nn.dot_product_attention(
|
134
|
-
query, key, value, dtype=self.dtype, broadcast_dropout=False,
|
135
|
-
dropout_rng=None, precision=self.precision, force_fp32_for_softmax=self.force_fp32_for_softmax,
|
136
|
-
deterministic=True
|
137
|
-
)
|
138
|
-
|
139
|
-
proj = self.proj_attn(hidden_states)
|
140
|
-
|
141
|
-
if is_4d:
|
142
|
-
proj = proj.reshape(orig_x_shape)
|
143
|
-
|
144
|
-
return proj
|
145
|
-
|
146
|
-
|
147
16
|
# --- MM-DiT AdaLN-Zero ---
|
148
17
|
class MMAdaLNZero(nn.Module):
|
149
18
|
"""
|