flaxdiff 0.2.7__py3-none-any.whl → 0.2.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/data/dataloaders.py +23 -19
- flaxdiff/data/dataset_map.py +2 -1
- flaxdiff/data/sources/base.py +12 -0
- flaxdiff/data/sources/images.py +75 -3
- flaxdiff/data/sources/videos.py +5 -0
- flaxdiff/inference/utils.py +7 -1
- flaxdiff/models/common.py +1 -70
- flaxdiff/models/hilbert.py +617 -0
- flaxdiff/models/simple_dit.py +275 -0
- flaxdiff/models/simple_mmdit.py +730 -0
- flaxdiff/models/simple_vit.py +405 -145
- flaxdiff/models/vit_common.py +262 -0
- flaxdiff/trainer/general_diffusion_trainer.py +30 -10
- flaxdiff/trainer/simple_trainer.py +113 -19
- {flaxdiff-0.2.7.dist-info → flaxdiff-0.2.9.dist-info}/METADATA +1 -1
- {flaxdiff-0.2.7.dist-info → flaxdiff-0.2.9.dist-info}/RECORD +18 -15
- {flaxdiff-0.2.7.dist-info → flaxdiff-0.2.9.dist-info}/WHEEL +1 -1
- flaxdiff/models/better_uvit.py +0 -380
- {flaxdiff-0.2.7.dist-info → flaxdiff-0.2.9.dist-info}/top_level.txt +0 -0
flaxdiff/models/simple_mmdit.py (new file)
@@ -0,0 +1,730 @@
import jax
import jax.numpy as jnp
from flax import linen as nn
from typing import Callable, Any, Optional, Tuple, Sequence, Union, List
import einops
from functools import partial
from flax.typing import Dtype, PrecisionLike

# Imports from local modules
from .vit_common import PatchEmbedding, unpatchify, RotaryEmbedding, RoPEAttention
from .common import kernel_init, FourierEmbedding, TimeProjection
from .attention import NormalAttention  # Base for RoPEAttention
# Replace common.hilbert_indices with improved implementation from hilbert.py
from .hilbert import hilbert_indices, inverse_permutation, hilbert_patchify, hilbert_unpatchify

# --- MM-DiT AdaLN-Zero ---
class MMAdaLNZero(nn.Module):
    """
    Adaptive Layer Normalization Zero (AdaLN-Zero) tailored for MM-DiT.
    Projects time and text embeddings separately, combines them, and then
    generates modulation parameters (scale, shift, gate) for attention and MLP paths.
    """
    features: int
    dtype: Optional[Dtype] = None
    precision: PrecisionLike = None
    norm_epsilon: float = 1e-5
    use_mean_pooling: bool = True  # Whether to use mean pooling for sequence inputs

    @nn.compact
    def __call__(self, x, t_emb, text_emb):
        # x shape: [B, S, F]
        # t_emb shape: [B, D_t]
        # text_emb shape: [B, S_text, D_text] or [B, D_text]

        # First normalize the input features
        norm = nn.LayerNorm(epsilon=self.norm_epsilon,
                            use_scale=False, use_bias=False, dtype=self.dtype)
        norm_x = norm(x)  # Shape: [B, S, F]

        # Process time embedding: ensure it has a sequence dimension for later broadcasting
        if t_emb.ndim == 2:  # [B, D_t]
            t_emb = jnp.expand_dims(t_emb, axis=1)  # [B, 1, D_t]

        # Process text embedding: if it has a sequence dimension different from x
        if text_emb.ndim == 2:  # [B, D_text]
            text_emb = jnp.expand_dims(text_emb, axis=1)  # [B, 1, D_text]
        elif text_emb.ndim == 3 and self.use_mean_pooling and text_emb.shape[1] != x.shape[1]:
            # Mean pooling is standard in MM-DiT for handling different sequence lengths
            text_emb = jnp.mean(
                text_emb, axis=1, keepdims=True)  # [B, 1, D_text]

        # Project time embedding
        t_params = nn.Dense(
            features=6 * self.features,
            dtype=self.dtype,
            precision=self.precision,
            kernel_init=nn.initializers.zeros,  # Zero init is standard in AdaLN-Zero
            name="ada_t_proj"
        )(t_emb)  # Shape: [B, 1, 6*F]

        # Project text embedding
        text_params = nn.Dense(
            features=6 * self.features,
            dtype=self.dtype,
            precision=self.precision,
            kernel_init=nn.initializers.zeros,  # Zero init
            name="ada_text_proj"
        )(text_emb)  # Shape: [B, 1, 6*F] or [B, S_text, 6*F]

        # If text_params still has a sequence dim different from t_params, mean pool it
        if t_params.shape[1] != text_params.shape[1]:
            text_params = jnp.mean(text_params, axis=1, keepdims=True)

        # Combine parameters (summing is standard in MM-DiT)
        ada_params = t_params + text_params  # Shape: [B, 1, 6*F]

        # Split into scale, shift, gate for MLP and Attention
        scale_mlp, shift_mlp, gate_mlp, scale_attn, shift_attn, gate_attn = jnp.split(
            ada_params, 6, axis=-1)  # Each shape: [B, 1, F]

        scale_mlp = jnp.clip(scale_mlp, -10.0, 10.0)
        shift_mlp = jnp.clip(shift_mlp, -10.0, 10.0)
        # Apply modulation for Attention path (broadcasting handled by JAX)
        x_attn = norm_x * (1 + scale_attn) + shift_attn

        # Apply modulation for MLP path
        x_mlp = norm_x * (1 + scale_mlp) + shift_mlp

        # Return modulated outputs and gates
        return x_attn, gate_attn, x_mlp, gate_mlp


# --- MM-DiT Block ---
class MMDiTBlock(nn.Module):
    """
    A Transformer block adapted for MM-DiT, using MMAdaLNZero for conditioning.
    """
    features: int
    num_heads: int
    rope_emb: RotaryEmbedding  # Pass RoPE module
    mlp_ratio: int = 4
    dropout_rate: float = 0.0
    dtype: Optional[Dtype] = None
    precision: PrecisionLike = None
    # Keep option, though RoPEAttention doesn't use it
    use_flash_attention: bool = False
    force_fp32_for_softmax: bool = True
    norm_epsilon: float = 1e-5

    def setup(self):
        hidden_features = int(self.features * self.mlp_ratio)
        # Use the new MMAdaLNZero block
        self.ada_ln_zero = MMAdaLNZero(
            self.features, dtype=self.dtype, precision=self.precision, norm_epsilon=self.norm_epsilon)

        # RoPEAttention remains the same
        self.attention = RoPEAttention(
            query_dim=self.features,
            heads=self.num_heads,
            dim_head=self.features // self.num_heads,
            dtype=self.dtype,
            precision=self.precision,
            use_bias=True,  # Bias is common in DiT attention proj
            force_fp32_for_softmax=self.force_fp32_for_softmax,
            rope_emb=self.rope_emb  # Pass RoPE module instance
        )

        # Standard MLP block remains the same
        self.mlp = nn.Sequential([
            nn.Dense(features=hidden_features, dtype=self.dtype,
                     precision=self.precision),
            nn.gelu,  # Consider swish/silu if preferred
            nn.Dense(features=self.features, dtype=self.dtype,
                     precision=self.precision)
        ])

    @nn.compact
    def __call__(self, x, t_emb, text_emb, freqs_cis):
        # x shape: [B, S, F]
        # t_emb shape: [B, D_t] or [B, 1, D_t]
        # text_emb shape: [B, D_text] or [B, 1, D_text]

        residual = x

        # Apply MMAdaLNZero with separate time and text embeddings
        x_attn, gate_attn, x_mlp, gate_mlp = self.ada_ln_zero(
            x, t_emb, text_emb)

        # Attention block (remains the same)
        attn_output = self.attention(
            x_attn, context=None, freqs_cis=freqs_cis)  # Self-attention only
        x = residual + gate_attn * attn_output

        # MLP block (remains the same)
        mlp_output = self.mlp(x_mlp)
        x = x + gate_mlp * mlp_output

        return x


# --- SimpleMMDiT ---
class SimpleMMDiT(nn.Module):
    """
    A Simple Multi-Modal Diffusion Transformer (MM-DiT) implementation.
    Integrates time and text conditioning using separate projections within
    each transformer block, following the MM-DiT approach. Uses RoPE for
    patch positional encoding.
    """
    output_channels: int = 3
    patch_size: int = 16
    emb_features: int = 768
    num_layers: int = 12
    num_heads: int = 12
    mlp_ratio: int = 4
    dropout_rate: float = 0.0  # Typically 0 for diffusion
    dtype: Optional[Dtype] = None
    precision: PrecisionLike = None
    # Passed down, but RoPEAttention uses NormalAttention
    use_flash_attention: bool = False
    force_fp32_for_softmax: bool = True
    norm_epsilon: float = 1e-5
    learn_sigma: bool = False  # Option to predict sigma like in DiT paper
    use_hilbert: bool = False  # Toggle Hilbert patch reorder
    norm_groups: int = 0
    activation: Callable = jax.nn.swish

    def setup(self):
        self.patch_embed = PatchEmbedding(
            patch_size=self.patch_size,
            embedding_dim=self.emb_features,
            dtype=self.dtype,
            precision=self.precision
        )

        # Time embedding projection (output dim: emb_features)
        self.time_embed = nn.Sequential([
            FourierEmbedding(features=self.emb_features),
            TimeProjection(features=self.emb_features *
                           self.mlp_ratio),  # Intermediate projection
            nn.Dense(features=self.emb_features, dtype=self.dtype,
                     precision=self.precision)  # Final projection
        ], name="time_embed")

        # Add projection layer for Hilbert patches
        if self.use_hilbert:
            self.hilbert_proj = nn.Dense(
                features=self.emb_features,
                dtype=self.dtype,
                precision=self.precision,
                name="hilbert_projection"
            )
        # Text context projection (output dim: emb_features)
        # Input dim depends on the text encoder output, assumed to be handled externally
        self.text_proj = nn.Dense(features=self.emb_features, dtype=self.dtype,
                                  precision=self.precision, name="text_context_proj")

        # Rotary Positional Embedding (for patches)
        # Dim per head, max_len should cover max number of patches
        self.rope = RotaryEmbedding(
            dim=self.emb_features // self.num_heads, max_seq_len=4096, dtype=self.dtype)

        # Transformer Blocks (use MMDiTBlock)
        self.blocks = [
            MMDiTBlock(
                features=self.emb_features,
                num_heads=self.num_heads,
                mlp_ratio=self.mlp_ratio,
                dropout_rate=self.dropout_rate,
                dtype=self.dtype,
                precision=self.precision,
                use_flash_attention=self.use_flash_attention,
                force_fp32_for_softmax=self.force_fp32_for_softmax,
                norm_epsilon=self.norm_epsilon,
                rope_emb=self.rope,  # Pass RoPE instance
                name=f"mmdit_block_{i}"
            ) for i in range(self.num_layers)
        ]

        # Final Layer (Normalization + Linear Projection)
        self.final_norm = nn.LayerNorm(
            epsilon=self.norm_epsilon, dtype=self.dtype, name="final_norm")
        # self.final_norm = nn.RMSNorm(epsilon=self.norm_epsilon, dtype=self.dtype, name="final_norm") # Alternative

        # Predict patch pixels + potentially sigma
        output_dim = self.patch_size * self.patch_size * self.output_channels
        if self.learn_sigma:
            output_dim *= 2  # Predict both mean and variance (or log_variance)

        self.final_proj = nn.Dense(
            features=output_dim,
            dtype=self.dtype,
            precision=self.precision,
            kernel_init=nn.initializers.zeros,  # Initialize final layer to zero
            name="final_proj"
        )

    @nn.compact
    def __call__(self, x, temb, textcontext):  # textcontext is required
        B, H, W, C = x.shape
        assert H % self.patch_size == 0 and W % self.patch_size == 0, "Image dimensions must be divisible by patch size"
        assert textcontext is not None, "textcontext must be provided for SimpleMMDiT"

        # 1. Patch Embedding
        if self.use_hilbert:
            # Use hilbert_patchify which handles both patchification and reordering
            patches_raw, hilbert_inv_idx = hilbert_patchify(
                x, self.patch_size)  # Shape [B, S, P*P*C]
            # Apply projection
            # Shape [B, S, emb_features]
            patches = self.hilbert_proj(patches_raw)
        else:
            # Shape: [B, num_patches, emb_features]
            patches = self.patch_embed(x)
            hilbert_inv_idx = None

        num_patches = patches.shape[1]
        x_seq = patches

        # 2. Prepare Conditioning Signals
        t_emb = self.time_embed(temb)  # Shape: [B, emb_features]
        # Assuming textcontext is [B, context_seq_len, context_dim] or [B, context_dim]
        # If [B, context_seq_len, context_dim], usually mean/pool or take CLS token first.
        # Assuming textcontext is already pooled/CLS token: [B, context_dim]
        text_emb = self.text_proj(textcontext)  # Shape: [B, emb_features]

        # 3. Apply RoPE Frequencies (only to patch tokens)
        seq_len = x_seq.shape[1]
        freqs_cos, freqs_sin = self.rope(seq_len)  # Shapes: [S, D_head/2]

        # 4. Apply Transformer Blocks
        for block in self.blocks:
            # Pass t_emb and text_emb separately to the block
            x_seq = block(x_seq, t_emb, text_emb,
                          freqs_cis=(freqs_cos, freqs_sin))

        # 5. Final Layer
        x_seq = self.final_norm(x_seq)
        # Shape: [B, num_patches, P*P*C (*2 if learn_sigma)]
        x_seq = self.final_proj(x_seq)

        # 6. Unpatchify
        if self.use_hilbert:
            # For Hilbert mode, we need to use the specialized unpatchify function
            if self.learn_sigma:
                # Split into mean and variance predictions
                x_mean, x_logvar = jnp.split(x_seq, 2, axis=-1)
                x_image = hilbert_unpatchify(
                    x_mean, hilbert_inv_idx, self.patch_size, H, W, self.output_channels)
                # If needed, also unpack the logvar
                # logvar_image = hilbert_unpatchify(x_logvar, hilbert_inv_idx, self.patch_size, H, W, self.output_channels)
                # return x_image, logvar_image
                return x_image
            else:
                x_image = hilbert_unpatchify(
                    x_seq, hilbert_inv_idx, self.patch_size, H, W, self.output_channels)
                return x_image
        else:
            # Standard patch ordering - use the existing unpatchify function
            if self.learn_sigma:
                # Split into mean and variance predictions
                x_mean, x_logvar = jnp.split(x_seq, 2, axis=-1)
                x = unpatchify(x_mean, channels=self.output_channels)
                # Return both mean and logvar if needed by the loss function
                # For now, just returning the mean prediction like standard diffusion models
                # logvar = unpatchify(x_logvar, channels=self.output_channels)
                # return x, logvar
                return x
            else:
                # Shape: [B, H, W, C]
                x = unpatchify(x_seq, channels=self.output_channels)
                return x


# --- Hierarchical MM-DiT components ---

class PatchMerging(nn.Module):
    """
    Merges a group of patches into a single patch with increased feature dimensions.
    Used in the hierarchical structure to reduce spatial resolution and increase channels.
    """
    out_features: int
    merge_size: int = 2  # Default 2x2 patch merging
    dtype: Optional[Dtype] = None
    precision: PrecisionLike = None
    norm_epsilon: float = 1e-5  # Add norm for stability like in Swin Transformer

    @nn.compact
    def __call__(self, x, H_patches, W_patches):
        # x shape: [B, H*W, C]
        B, L, C = x.shape
        assert L == H_patches * \
            W_patches, f"Input length {L} doesn't match {H_patches}*{W_patches}"
        assert H_patches % self.merge_size == 0 and W_patches % self.merge_size == 0, f"Patch dimensions ({H_patches}, {W_patches}) not divisible by merge size {self.merge_size}"

        # Reshape to [B, H, W, C]
        x = x.reshape(B, H_patches, W_patches, C)

        # Merge patches - rearrange to group nearby patches
        merged = einops.rearrange(
            x,
            'b (h p1) (w p2) c -> b h w (p1 p2 c)',
            p1=self.merge_size, p2=self.merge_size
        )

        # Apply LayerNorm before projection (common practice)
        norm = nn.LayerNorm(epsilon=self.norm_epsilon, dtype=self.dtype, name="norm")
        merged = norm(merged)  # Apply norm on [B, H/p, W/p, p*p*C]

        # Project to new dimension
        merged = nn.Dense(
            features=self.out_features,
            dtype=self.dtype,
            precision=self.precision,
            name="projection"
        )(merged)

        # Flatten back to sequence
        new_H = H_patches // self.merge_size
        new_W = W_patches // self.merge_size
        merged = merged.reshape(B, new_H * new_W, self.out_features)

        return merged, new_H, new_W

class PatchExpanding(nn.Module):
    """
    Expands patches to increase spatial resolution.
    Used in the hierarchical structure decoder path.
    """
    out_features: int
    expand_size: int = 2  # Default 2x2 patch expansion
    dtype: Optional[Dtype] = None
    precision: PrecisionLike = None
    norm_epsilon: float = 1e-5  # Add norm for stability

    @nn.compact
    def __call__(self, x, H_patches, W_patches):
        # x shape: [B, H*W, C]
        B, L, C = x.shape
        assert L == H_patches * W_patches, f"Input length {L} doesn't match {H_patches}*{W_patches}"

        # Project to expanded dimension first
        expanded_features = self.expand_size * self.expand_size * self.out_features
        x = nn.Dense(
            features=expanded_features,
            dtype=self.dtype,
            precision=self.precision,
            name="projection"
        )(x)  # Shape [B, L, P*P*C_out]

        # Apply LayerNorm after projection
        norm = nn.LayerNorm(epsilon=self.norm_epsilon, dtype=self.dtype, name="norm")
        x = norm(x)

        # Reshape to spatial grid before rearranging
        x = x.reshape(B, H_patches, W_patches, expanded_features)

        # Rearrange to expand spatial dims
        expanded = einops.rearrange(
            x,
            'b h w (p1 p2 c) -> b (h p1) (w p2) c',
            p1=self.expand_size, p2=self.expand_size, c=self.out_features
        )

        # Flatten back to sequence
        new_H = H_patches * self.expand_size
        new_W = W_patches * self.expand_size
        expanded = expanded.reshape(B, new_H * new_W, self.out_features)

        return expanded, new_H, new_W


# --- Hierarchical MM-DiT ---
class HierarchicalMMDiT(nn.Module):
    """
    A Hierarchical Multi-Modal Diffusion Transformer (MM-DiT) implementation
    based on the PixArt-α architecture. Processes images at multiple resolutions
    with skip connections between encoder and decoder paths.
    Follows a U-Net like structure: Fine -> Coarse (Encoder) -> Coarse -> Fine (Decoder).
    """
    output_channels: int = 3
    base_patch_size: int = 8  # Patch size at the *finest* resolution level (stage 0)
    emb_features: Sequence[int] = (512, 768, 1024)  # Feature dimensions for stages 0, 1, 2 (fine to coarse)
    num_layers: Sequence[int] = (4, 4, 14)  # Layers per stage (can be asymmetric encoder/decoder if needed)
    num_heads: Sequence[int] = (8, 12, 16)  # Heads per stage (fine to coarse)
    mlp_ratio: int = 4
    dropout_rate: float = 0.0
    dtype: Optional[Dtype] = None
    precision: PrecisionLike = None
    use_flash_attention: bool = False
    force_fp32_for_softmax: bool = True
    norm_epsilon: float = 1e-5
    learn_sigma: bool = False
    use_hilbert: bool = False
    norm_groups: int = 0  # Not used in this structure, maybe remove later
    activation: Callable = jax.nn.swish  # Not used directly here, used in MLP inside MMDiTBlock

    def setup(self):
        assert len(self.emb_features) == len(self.num_layers) == len(self.num_heads), \
            "Feature dimensions, layers, and heads must have the same number of stages"

        num_stages = len(self.emb_features)

        # 1. Initial Patch Embedding (FINEST level - stage 0)
        self.patch_embed = PatchEmbedding(
            patch_size=self.base_patch_size,
            embedding_dim=self.emb_features[0],  # Finest embedding dim
            dtype=self.dtype,
            precision=self.precision
        )

        # 2. Time/Text Embeddings (Projected for each stage)
        # Base projection to largest dimension first for stability/capacity
        base_t_emb_dim = self.emb_features[-1]
        self.time_embed_base = nn.Sequential([
            FourierEmbedding(features=base_t_emb_dim),
            TimeProjection(features=base_t_emb_dim * self.mlp_ratio),
            nn.Dense(features=base_t_emb_dim, dtype=self.dtype, precision=self.precision)
        ], name="time_embed_base")
        self.text_proj_base = nn.Dense(
            features=base_t_emb_dim,
            dtype=self.dtype,
            precision=self.precision,
            name="text_context_proj_base"
        )
        # Projections for each stage (0 to N-1)
        self.t_emb_projs = [
            nn.Dense(features=self.emb_features[i], dtype=self.dtype, precision=self.precision, name=f"t_emb_proj_stage{i}")
            for i in range(num_stages)
        ]
        self.text_emb_projs = [
            nn.Dense(features=self.emb_features[i], dtype=self.dtype, precision=self.precision, name=f"text_emb_proj_stage{i}")
            for i in range(num_stages)
        ]

        # 3. Hilbert projection (if used, applied after initial patch embedding)
        if self.use_hilbert:
            self.hilbert_proj = nn.Dense(
                features=self.emb_features[0],  # Match finest embedding dim
                dtype=self.dtype,
                precision=self.precision,
                name="hilbert_projection"
            )

        # 4. RoPE embeddings for each stage (0 to N-1)
        self.ropes = [
            RotaryEmbedding(
                dim=self.emb_features[i] // self.num_heads[i],
                max_seq_len=4096,  # Adjust if needed based on max patch count per stage
                dtype=self.dtype,
                name=f"rope_stage_{i}"
            )
            for i in range(num_stages)
        ]

        # 5. --- Encoder Path (Fine to Coarse) ---
        encoder_blocks = []
        patch_mergers = []
        for stage in range(num_stages):
            # Blocks for this stage
            stage_blocks = [
                MMDiTBlock(
                    features=self.emb_features[stage],
                    num_heads=self.num_heads[stage],
                    mlp_ratio=self.mlp_ratio,
                    dropout_rate=self.dropout_rate,
                    dtype=self.dtype,
                    precision=self.precision,
                    use_flash_attention=self.use_flash_attention,
                    force_fp32_for_softmax=self.force_fp32_for_softmax,
                    norm_epsilon=self.norm_epsilon,
                    rope_emb=self.ropes[stage],
                    name=f"encoder_block_stage{stage}_{i}"
                )
                # Assuming symmetric layers for now, adjust if needed (e.g., self.num_encoder_layers[stage])
                for i in range(self.num_layers[stage])
            ]
            encoder_blocks.append(stage_blocks)

            # Patch Merging layer (except for the last/coarsest stage)
            if stage < num_stages - 1:
                patch_mergers.append(
                    PatchMerging(
                        out_features=self.emb_features[stage + 1],  # Target next stage dim
                        dtype=self.dtype,
                        precision=self.precision,
                        norm_epsilon=self.norm_epsilon,
                        name=f"patch_merger_{stage}"
                    )
                )
        self.encoder_blocks = encoder_blocks
        self.patch_mergers = patch_mergers

        # 6. --- Decoder Path (Coarse to Fine) ---
        decoder_blocks = []
        patch_expanders = []
        fusion_layers = []
        # Iterate from second coarsest stage (N-2) down to finest (0)
        for stage in range(num_stages - 2, -1, -1):
            # Patch Expanding layer (Expands from stage+1 to stage)
            patch_expanders.append(
                PatchExpanding(
                    out_features=self.emb_features[stage],  # Target current stage dim
                    dtype=self.dtype,
                    precision=self.precision,
                    norm_epsilon=self.norm_epsilon,
                    name=f"patch_expander_{stage}"  # Naming indicates target stage
                )
            )
            # Fusion layer (Combines skip[stage] and expanded[stage+1]->[stage])
            fusion_layers.append(
                nn.Sequential([  # Use Sequential for Norm + Dense
                    nn.LayerNorm(epsilon=self.norm_epsilon, dtype=self.dtype, name=f"fusion_norm_{stage}"),
                    nn.Dense(
                        features=self.emb_features[stage],  # Output current stage dim
                        dtype=self.dtype,
                        precision=self.precision,
                        name=f"fusion_dense_{stage}"
                    )
                ])
            )

            # Blocks for this stage (stage N-2 down to 0)
            # Assuming symmetric layers for now
            stage_blocks = [
                MMDiTBlock(
                    features=self.emb_features[stage],
                    num_heads=self.num_heads[stage],
                    mlp_ratio=self.mlp_ratio,
                    dropout_rate=self.dropout_rate,
                    dtype=self.dtype,
                    precision=self.precision,
                    use_flash_attention=self.use_flash_attention,
                    force_fp32_for_softmax=self.force_fp32_for_softmax,
                    norm_epsilon=self.norm_epsilon,
                    rope_emb=self.ropes[stage],
                    name=f"decoder_block_stage{stage}_{i}"
                )
                for i in range(self.num_layers[stage])
            ]
            # Append blocks in order: stage N-2, N-3, ..., 0
            decoder_blocks.append(stage_blocks)

        self.patch_expanders = patch_expanders
        self.fusion_layers = fusion_layers
        self.decoder_blocks = decoder_blocks

        # Note: The lists expanders, fusion_layers, decoder_blocks are now ordered
        # corresponding to stages N-2, N-3, ..., 0.

        # 7. Final Layer
        self.final_norm = nn.LayerNorm(
            epsilon=self.norm_epsilon, dtype=self.dtype, name="final_norm")

        # Output projection to pixels (at finest resolution)
        output_dim = self.base_patch_size * self.base_patch_size * self.output_channels
        if self.learn_sigma:
            output_dim *= 2  # Predict both mean and variance

        self.final_proj = nn.Dense(
            features=output_dim,
            dtype=self.dtype,
            precision=self.precision,
            kernel_init=nn.initializers.zeros,  # Zero init
            name="final_proj"
        )

    def __call__(self, x, temb, textcontext):
        B, H, W, C = x.shape
        num_stages = len(self.emb_features)
        finest_patch_size = self.base_patch_size

        # Assertions
        assert H % (finest_patch_size * (2**(num_stages - 1))) == 0 and \
            W % (finest_patch_size * (2**(num_stages - 1))) == 0, \
            f"Image dimensions ({H},{W}) must be divisible by effective coarsest patch size {finest_patch_size * (2**(num_stages - 1))}"
        assert textcontext is not None, "textcontext must be provided"

        # 1. Initial Patch Embedding (Finest Level - stage 0)
        H_patches = H // finest_patch_size
        W_patches = W // finest_patch_size
        total_patches = H_patches * W_patches  # Calculate total patches
        hilbert_inv_idx = None
        if self.use_hilbert:
            # Calculate Hilbert indices and inverse permutation for the *finest* grid
            fine_idx = hilbert_indices(H_patches, W_patches)
            # Pass the total number of patches as total_size
            hilbert_inv_idx = inverse_permutation(fine_idx, total_size=total_patches)  # Store for unpatchify

            # Apply Hilbert patchify at the finest level
            patches_raw, _ = hilbert_patchify(x, finest_patch_size)  # We already have inv_idx
            x_seq = self.hilbert_proj(patches_raw)  # Shape [B, S_fine, emb[0]]
        else:
            x_seq = self.patch_embed(x)  # Shape [B, S_fine, emb[0]]

        # 2. Prepare Conditioning Signals for each stage
        t_emb_base = self.time_embed_base(temb)
        text_emb_base = self.text_proj_base(textcontext)
        t_embs = [proj(t_emb_base) for proj in self.t_emb_projs]  # List for stages 0 to N-1
        text_embs = [proj(text_emb_base) for proj in self.text_emb_projs]  # List for stages 0 to N-1

        # --- Encoder Path (Fine to Coarse: stages 0 to N-1) ---
        skip_features = {}
        current_H_patches, current_W_patches = H_patches, W_patches
        for stage in range(num_stages):
            # Apply RoPE for current stage
            seq_len = x_seq.shape[1]
            freqs_cos, freqs_sin = self.ropes[stage](seq_len)

            # Apply blocks for this stage
            for block in self.encoder_blocks[stage]:
                x_seq = block(x_seq, t_embs[stage], text_embs[stage], freqs_cis=(freqs_cos, freqs_sin))

            # Store skip features (before merging)
            skip_features[stage] = x_seq

            # Apply Patch Merging (if not the last/coarsest stage)
            if stage < num_stages - 1:
                x_seq, current_H_patches, current_W_patches = self.patch_mergers[stage](
                    x_seq, current_H_patches, current_W_patches
                )

        # --- Bottleneck ---
        # x_seq now holds the output of the coarsest stage (stage N-1)

        # --- Decoder Path (Coarse to Fine: stages N-2 down to 0) ---
        # Decoder lists (expanders, fusion, blocks) are ordered for stages N-2, ..., 0
        for i, stage in enumerate(range(num_stages - 2, -1, -1)):  # stage = N-2, N-3, ..., 0; i = 0, 1, ..., N-2
            # Apply Patch Expanding (Expand from stage+1 feature map to stage feature map)
            x_seq, current_H_patches, current_W_patches = self.patch_expanders[i](
                x_seq, current_H_patches, current_W_patches
            )

            # Fusion with skip connection from corresponding encoder stage
            skip = skip_features[stage]
            x_seq = jnp.concatenate([x_seq, skip], axis=-1)  # Concatenate along feature dim
            x_seq = self.fusion_layers[i](x_seq)  # Apply fusion (Norm + Dense)

            # Apply RoPE for current stage
            seq_len = x_seq.shape[1]
            freqs_cos, freqs_sin = self.ropes[stage](seq_len)

            # Apply blocks for this stage
            for block in self.decoder_blocks[i]:  # Use index i for the decoder block list
                x_seq = block(x_seq, t_embs[stage], text_embs[stage], freqs_cis=(freqs_cos, freqs_sin))

        # --- Final Layer ---
        # x_seq should now be at the finest resolution (stage 0 features)
        x_seq = self.final_norm(x_seq)
        x_seq = self.final_proj(x_seq)  # Project to patch pixel values

        # --- Unpatchify ---
        if self.use_hilbert:
            # Use the inverse Hilbert index calculated for the *finest* grid
            assert hilbert_inv_idx is not None, "Hilbert inverse index should exist if use_hilbert is True"
            if self.learn_sigma:
                x_mean, x_logvar = jnp.split(x_seq, 2, axis=-1)
                out = hilbert_unpatchify(x_mean, hilbert_inv_idx, finest_patch_size, H, W, self.output_channels)
                # Optionally return logvar: logvar_image = hilbert_unpatchify(x_logvar, hilbert_inv_idx, finest_patch_size, H, W, self.output_channels)
            else:
                out = hilbert_unpatchify(x_seq, hilbert_inv_idx, finest_patch_size, H, W, self.output_channels)
        else:
            # Standard unpatchify
            if self.learn_sigma:
                x_mean, x_logvar = jnp.split(x_seq, 2, axis=-1)
                out = unpatchify(x_mean, channels=self.output_channels)
                # Optionally return logvar: logvar = unpatchify(x_logvar, channels=self.output_channels)
            else:
                out = unpatchify(x_seq, channels=self.output_channels)

        return out
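For orientation, below is a minimal, hedged sketch of how the new SimpleMMDiT module appears intended to be driven, based on its constructor fields and the __call__(x, temb, textcontext) signature added above. The [B] timestep shape, the 768-dimensional pooled text embedding, and the reduced layer/head counts are illustrative assumptions, not values taken from the package; HierarchicalMMDiT follows the same calling convention with per-stage tuples for emb_features, num_layers, and num_heads.

import jax
import jax.numpy as jnp
from flaxdiff.models.simple_mmdit import SimpleMMDiT

# Small configuration purely for a shape check; real configs use the defaults shown in the diff.
model = SimpleMMDiT(emb_features=256, num_layers=2, num_heads=4, patch_size=16)

rng = jax.random.PRNGKey(0)
x = jnp.zeros((2, 64, 64, 3))      # [B, H, W, C]; H and W must be divisible by patch_size
temb = jnp.zeros((2,))             # diffusion timestep per sample (assumed shape [B])
textcontext = jnp.zeros((2, 768))  # pooled text embedding (assumed dimension)

params = model.init(rng, x, temb, textcontext)
out = model.apply(params, x, temb, textcontext)  # [B, H, W, output_channels] after unpatchify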