flaxdiff 0.2.6.1__py3-none-any.whl → 0.2.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,861 @@
+ import jax
+ import jax.numpy as jnp
+ from flax import linen as nn
+ from typing import Callable, Any, Optional, Tuple, Sequence, Union, List
+ import einops
+ from functools import partial
+ from flax.typing import Dtype, PrecisionLike
+
+ # Imports from local modules
+ from .simple_vit import PatchEmbedding, unpatchify
+ from .common import kernel_init, FourierEmbedding, TimeProjection
+ from .attention import NormalAttention  # Base for RoPEAttention
+ # Replace common.hilbert_indices with improved implementation from hilbert.py
+ from .hilbert import hilbert_indices, inverse_permutation, hilbert_patchify, hilbert_unpatchify
+
+ # --- Rotary Positional Embedding (RoPE) ---
+ # Re-used from simple_dit.py
+
+
+ def _rotate_half(x: jax.Array) -> jax.Array:
+     """Rotates half the hidden dims of the input."""
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2:]
+     return jnp.concatenate((-x2, x1), axis=-1)
+
+
+ def apply_rotary_embedding(
+     x: jax.Array, freqs_cos: jax.Array, freqs_sin: jax.Array
+ ) -> jax.Array:
+     """Applies rotary embedding to the input tensor using the rotate_half method."""
+     if x.ndim == 4:  # [B, H, S, D]
+         cos_freqs = jnp.expand_dims(freqs_cos, axis=(0, 1))
+         sin_freqs = jnp.expand_dims(freqs_sin, axis=(0, 1))
+     elif x.ndim == 3:  # [B, S, D]
+         cos_freqs = jnp.expand_dims(freqs_cos, axis=0)
+         sin_freqs = jnp.expand_dims(freqs_sin, axis=0)
+     else:
+         raise ValueError(f"Unsupported input dimension: {x.ndim}")
+
+     cos_freqs = jnp.concatenate([cos_freqs, cos_freqs], axis=-1)
+     sin_freqs = jnp.concatenate([sin_freqs, sin_freqs], axis=-1)
+
+     x_rotated = x * cos_freqs + _rotate_half(x) * sin_freqs
+     return x_rotated.astype(x.dtype)
+
+
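apply_rotary_embedding and _rotate_half implement the standard rotate-half form of RoPE. The short check below is illustrative only (it is not part of the package contents) and verifies, with hand-built frequencies, that the post-rotation dot product between fixed query/key vectors depends only on the relative offset between positions:

    dim, seq_len = 8, 16
    inv_freq = 1.0 / (10000 ** (jnp.arange(0, dim, 2, dtype=jnp.float32) / dim))
    freqs = jnp.outer(jnp.arange(seq_len, dtype=jnp.float32), inv_freq)  # [S, dim/2]
    cos, sin = jnp.cos(freqs), jnp.sin(freqs)
    q = jnp.ones((1, 1, seq_len, dim))  # [B, H, S, D]
    k = jnp.ones((1, 1, seq_len, dim))
    scores = jnp.einsum('bhqd,bhkd->bhqk',
                        apply_rotary_embedding(q, cos, sin),
                        apply_rotary_embedding(k, cos, sin))
    # Positions (2, 0) and (9, 7) share the same offset of 2, so the scores match.
    assert jnp.allclose(scores[0, 0, 2, 0], scores[0, 0, 9, 7], atol=1e-5)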
+ class RotaryEmbedding(nn.Module):
+     dim: int
+     max_seq_len: int = 4096  # Increased default based on SimpleDiT
+     base: int = 10000
+     dtype: Dtype = jnp.float32
+
+     def setup(self):
+         inv_freq = 1.0 / (
+             self.base ** (jnp.arange(0, self.dim, 2,
+                           dtype=jnp.float32) / self.dim)
+         )
+         t = jnp.arange(self.max_seq_len, dtype=jnp.float32)
+         freqs = jnp.outer(t, inv_freq)
+         self.freqs_cos = jnp.cos(freqs)
+         self.freqs_sin = jnp.sin(freqs)
+
+     def __call__(self, seq_len: int):
+         if seq_len > self.max_seq_len:
+             # Dynamically extend frequencies if needed (more robust)
+             t = jnp.arange(seq_len, dtype=jnp.float32)
+             inv_freq = 1.0 / (
+                 self.base ** (jnp.arange(0, self.dim, 2,
+                               dtype=jnp.float32) / self.dim)
+             )
+             freqs = jnp.outer(t, inv_freq)
+             freqs_cos = jnp.cos(freqs)
+             freqs_sin = jnp.sin(freqs)
+             # Consider caching extended freqs if this happens often
+             return freqs_cos, freqs_sin
+             # Or raise an error instead:
+             # raise ValueError(f"Sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}")
+         return self.freqs_cos[:seq_len, :], self.freqs_sin[:seq_len, :]
+
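A minimal usage sketch for RotaryEmbedding (illustrative, not part of the diff): the module holds no trainable parameters, so init/apply mainly serve to run setup, and the cached tables are sliced to the requested sequence length.

    rope = RotaryEmbedding(dim=64, max_seq_len=256)
    variables = rope.init(jax.random.PRNGKey(0), 16)
    freqs_cos, freqs_sin = rope.apply(variables, 16)  # each [16, 32]
    q = jnp.zeros((2, 8, 16, 64))                     # [B, heads, S, dim_head]
    q_rotated = apply_rotary_embedding(q, freqs_cos, freqs_sin)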
+ # --- Attention with RoPE ---
+ # Re-used from simple_dit.py
+
+
+ class RoPEAttention(NormalAttention):
+     rope_emb: RotaryEmbedding = None
+
+     @nn.compact
+     def __call__(self, x, context=None, freqs_cis=None):
+         orig_x_shape = x.shape
+         is_4d = len(orig_x_shape) == 4
+         if is_4d:
+             B, H, W, C = x.shape
+             seq_len = H * W
+             x = x.reshape((B, seq_len, C))
+         else:
+             B, seq_len, C = x.shape
+
+         context = x if context is None else context
+         if len(context.shape) == 4:
+             _B, _H, _W, _C = context.shape
+             context_seq_len = _H * _W
+             context = context.reshape((B, context_seq_len, _C))
+         # else: context is already [B, S_ctx, C]
+
+         query = self.query(x)  # [B, S, H, D]
+         key = self.key(context)  # [B, S_ctx, H, D]
+         value = self.value(context)  # [B, S_ctx, H, D]
+
+         if freqs_cis is None and self.rope_emb is not None:
+             seq_len_q = query.shape[1]  # Use query's sequence length
+             freqs_cos, freqs_sin = self.rope_emb(seq_len_q)
+         elif freqs_cis is not None:
+             freqs_cos, freqs_sin = freqs_cis
+         else:
+             # Should not happen if rope_emb is provided or freqs_cis are passed
+             raise ValueError("RoPE frequencies not provided.")
+
+         # Apply RoPE to query and key
+         # Permute to [B, H, S, D] for RoPE application
+         query = einops.rearrange(query, 'b s h d -> b h s d')
+         key = einops.rearrange(key, 'b s h d -> b h s d')
+
+         # Assuming self-attention, or a context with the same sequence length,
+         # for simplicity here
+         query = apply_rotary_embedding(query, freqs_cos, freqs_sin)
+         key = apply_rotary_embedding(
+             key, freqs_cos, freqs_sin)  # Apply same freqs to key
+
+         # Permute back to [B, S, H, D] for dot_product_attention
+         query = einops.rearrange(query, 'b h s d -> b s h d')
+         key = einops.rearrange(key, 'b h s d -> b s h d')
+
+         hidden_states = nn.dot_product_attention(
+             query, key, value, dtype=self.dtype, broadcast_dropout=False,
+             dropout_rng=None, precision=self.precision, force_fp32_for_softmax=self.force_fp32_for_softmax,
+             deterministic=True
+         )
+
+         proj = self.proj_attn(hidden_states)
+
+         if is_4d:
+             proj = proj.reshape(orig_x_shape)
+
+         return proj
+
+
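RoPEAttention only adds the rotation step; the query/key/value/proj_attn projections and the remaining constructor fields come from NormalAttention in .attention. A hedged usage sketch, constructed the same way MMDiTBlock below constructs it (assuming NormalAttention accepts these fields and provides sensible defaults for the rest):

    features, num_heads = 256, 8
    rope = RotaryEmbedding(dim=features // num_heads, max_seq_len=1024)
    attn = RoPEAttention(
        query_dim=features,
        heads=num_heads,
        dim_head=features // num_heads,
        use_bias=True,
        force_fp32_for_softmax=True,
        rope_emb=rope,
    )
    x = jnp.zeros((2, 16, 16, features))  # [B, H, W, C] patch grid
    params = attn.init(jax.random.PRNGKey(0), x)
    y = attn.apply(params, x)             # expected [2, 16, 16, 256]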
+ # --- MM-DiT AdaLN-Zero ---
+ class MMAdaLNZero(nn.Module):
+     """
+     Adaptive Layer Normalization Zero (AdaLN-Zero) tailored for MM-DiT.
+     Projects time and text embeddings separately, combines them, and then
+     generates modulation parameters (scale, shift, gate) for attention and MLP paths.
+     """
+     features: int
+     dtype: Optional[Dtype] = None
+     precision: PrecisionLike = None
+     norm_epsilon: float = 1e-5
+     use_mean_pooling: bool = True  # Whether to use mean pooling for sequence inputs
+
+     @nn.compact
+     def __call__(self, x, t_emb, text_emb):
+         # x shape: [B, S, F]
+         # t_emb shape: [B, D_t]
+         # text_emb shape: [B, S_text, D_text] or [B, D_text]
+
+         # First normalize the input features
+         norm = nn.LayerNorm(epsilon=self.norm_epsilon,
+                             use_scale=False, use_bias=False, dtype=self.dtype)
+         norm_x = norm(x)  # Shape: [B, S, F]
+
+         # Process time embedding: ensure it has a sequence dimension for later broadcasting
+         if t_emb.ndim == 2:  # [B, D_t]
+             t_emb = jnp.expand_dims(t_emb, axis=1)  # [B, 1, D_t]
+
+         # Process text embedding: if it has a sequence dimension different from x
+         if text_emb.ndim == 2:  # [B, D_text]
+             text_emb = jnp.expand_dims(text_emb, axis=1)  # [B, 1, D_text]
+         elif text_emb.ndim == 3 and self.use_mean_pooling and text_emb.shape[1] != x.shape[1]:
+             # Mean pooling is standard in MM-DiT for handling different sequence lengths
+             text_emb = jnp.mean(
+                 text_emb, axis=1, keepdims=True)  # [B, 1, D_text]
+
+         # Project time embedding
+         t_params = nn.Dense(
+             features=6 * self.features,
+             dtype=self.dtype,
+             precision=self.precision,
+             kernel_init=nn.initializers.zeros,  # Zero init is standard in AdaLN-Zero
+             name="ada_t_proj"
+         )(t_emb)  # Shape: [B, 1, 6*F]
+
+         # Project text embedding
+         text_params = nn.Dense(
+             features=6 * self.features,
+             dtype=self.dtype,
+             precision=self.precision,
+             kernel_init=nn.initializers.zeros,  # Zero init
+             name="ada_text_proj"
+         )(text_emb)  # Shape: [B, 1, 6*F] or [B, S_text, 6*F]
+
+         # If text_params still has a sequence dim different from t_params, mean pool it
+         if t_params.shape[1] != text_params.shape[1]:
+             text_params = jnp.mean(text_params, axis=1, keepdims=True)
+
+         # Combine parameters (summing is standard in MM-DiT)
+         ada_params = t_params + text_params  # Shape: [B, 1, 6*F]
+
+         # Split into scale, shift, gate for MLP and Attention
+         scale_mlp, shift_mlp, gate_mlp, scale_attn, shift_attn, gate_attn = jnp.split(
+             ada_params, 6, axis=-1)  # Each shape: [B, 1, F]
+
+         scale_mlp = jnp.clip(scale_mlp, -10.0, 10.0)
+         shift_mlp = jnp.clip(shift_mlp, -10.0, 10.0)
+         # Apply modulation for Attention path (broadcasting handled by JAX)
+         x_attn = norm_x * (1 + scale_attn) + shift_attn
+
+         # Apply modulation for MLP path
+         x_mlp = norm_x * (1 + scale_mlp) + shift_mlp
+
+         # Return modulated outputs and gates
+         return x_attn, gate_attn, x_mlp, gate_mlp
+
+
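A shape-level sketch of the modulation interface (illustrative only, with made-up dimensions): with pooled text conditioning, both projections collapse to [B, 1, 6*F] before being summed and split into the six modulation tensors.

    mod = MMAdaLNZero(features=256)
    x = jnp.zeros((2, 64, 256))             # [B, S, F] patch tokens
    t_emb = jnp.zeros((2, 256))             # [B, D_t]
    text_emb = jnp.zeros((2, 77, 256))      # [B, S_text, D_text]; mean-pooled internally
    params = mod.init(jax.random.PRNGKey(0), x, t_emb, text_emb)
    x_attn, gate_attn, x_mlp, gate_mlp = mod.apply(params, x, t_emb, text_emb)
    # x_attn, x_mlp: [2, 64, 256]; gate_attn, gate_mlp: [2, 1, 256]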
+ # --- MM-DiT Block ---
+ class MMDiTBlock(nn.Module):
+     """
+     A Transformer block adapted for MM-DiT, using MMAdaLNZero for conditioning.
+     """
+     features: int
+     num_heads: int
+     rope_emb: RotaryEmbedding  # Pass RoPE module
+     mlp_ratio: int = 4
+     dropout_rate: float = 0.0
+     dtype: Optional[Dtype] = None
+     precision: PrecisionLike = None
+     # Keep the option, though RoPEAttention doesn't use it
+     use_flash_attention: bool = False
+     force_fp32_for_softmax: bool = True
+     norm_epsilon: float = 1e-5
+
+     def setup(self):
+         hidden_features = int(self.features * self.mlp_ratio)
+         # Use the new MMAdaLNZero block
+         self.ada_ln_zero = MMAdaLNZero(
+             self.features, dtype=self.dtype, precision=self.precision, norm_epsilon=self.norm_epsilon)
+
+         # RoPEAttention re-used from simple_dit.py
+         self.attention = RoPEAttention(
+             query_dim=self.features,
+             heads=self.num_heads,
+             dim_head=self.features // self.num_heads,
+             dtype=self.dtype,
+             precision=self.precision,
+             use_bias=True,  # Bias is common in DiT attention proj
+             force_fp32_for_softmax=self.force_fp32_for_softmax,
+             rope_emb=self.rope_emb  # Pass RoPE module instance
+         )
+
+         # Standard MLP block
+         self.mlp = nn.Sequential([
+             nn.Dense(features=hidden_features, dtype=self.dtype,
+                      precision=self.precision),
+             nn.gelu,  # Consider swish/silu if preferred
+             nn.Dense(features=self.features, dtype=self.dtype,
+                      precision=self.precision)
+         ])
+
+     @nn.compact
+     def __call__(self, x, t_emb, text_emb, freqs_cis):
+         # x shape: [B, S, F]
+         # t_emb shape: [B, D_t] or [B, 1, D_t]
+         # text_emb shape: [B, D_text] or [B, 1, D_text]
+
+         residual = x
+
+         # Apply MMAdaLNZero with separate time and text embeddings
+         x_attn, gate_attn, x_mlp, gate_mlp = self.ada_ln_zero(
+             x, t_emb, text_emb)
+
+         # Attention block
+         attn_output = self.attention(
+             x_attn, context=None, freqs_cis=freqs_cis)  # Self-attention only
+         x = residual + gate_attn * attn_output
+
+         # MLP block
+         mlp_output = self.mlp(x_mlp)
+         x = x + gate_mlp * mlp_output
+
+         return x
+
+
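One block step, wired the way SimpleMMDiT below wires it (a hedged sketch that depends on NormalAttention from .attention): the RoPE frequencies are computed once per sequence length and passed in as freqs_cis.

    features, num_heads, seq_len = 256, 8, 64
    rope = RotaryEmbedding(dim=features // num_heads, max_seq_len=1024)
    freqs = rope.apply(rope.init(jax.random.PRNGKey(0), seq_len), seq_len)  # (cos, sin)
    block = MMDiTBlock(features=features, num_heads=num_heads, rope_emb=rope)
    x = jnp.zeros((2, seq_len, features))   # [B, S, F]
    t_emb = jnp.zeros((2, features))        # [B, D_t]
    text_emb = jnp.zeros((2, features))     # [B, D_text]
    params = block.init(jax.random.PRNGKey(1), x, t_emb, text_emb, freqs)
    y = block.apply(params, x, t_emb, text_emb, freqs)  # [2, 64, 256]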
+ # --- SimpleMMDiT ---
+ class SimpleMMDiT(nn.Module):
+     """
+     A Simple Multi-Modal Diffusion Transformer (MM-DiT) implementation.
+     Integrates time and text conditioning using separate projections within
+     each transformer block, following the MM-DiT approach. Uses RoPE for
+     patch positional encoding.
+     """
+     output_channels: int = 3
+     patch_size: int = 16
+     emb_features: int = 768
+     num_layers: int = 12
+     num_heads: int = 12
+     mlp_ratio: int = 4
+     dropout_rate: float = 0.0  # Typically 0 for diffusion
+     dtype: Optional[Dtype] = None
+     precision: PrecisionLike = None
+     # Passed down, but RoPEAttention uses NormalAttention
+     use_flash_attention: bool = False
+     force_fp32_for_softmax: bool = True
+     norm_epsilon: float = 1e-5
+     learn_sigma: bool = False  # Option to predict sigma like in DiT paper
+     use_hilbert: bool = False  # Toggle Hilbert patch reorder
+     norm_groups: int = 0
+     activation: Callable = jax.nn.swish
+
+     def setup(self):
+         self.patch_embed = PatchEmbedding(
+             patch_size=self.patch_size,
+             embedding_dim=self.emb_features,
+             dtype=self.dtype,
+             precision=self.precision
+         )
+
+         # Time embedding projection (output dim: emb_features)
+         self.time_embed = nn.Sequential([
+             FourierEmbedding(features=self.emb_features),
+             TimeProjection(features=self.emb_features *
+                            self.mlp_ratio),  # Intermediate projection
+             nn.Dense(features=self.emb_features, dtype=self.dtype,
+                      precision=self.precision)  # Final projection
+         ], name="time_embed")
+
+         # Add projection layer for Hilbert patches
+         if self.use_hilbert:
+             self.hilbert_proj = nn.Dense(
+                 features=self.emb_features,
+                 dtype=self.dtype,
+                 precision=self.precision,
+                 name="hilbert_projection"
+             )
+         # Text context projection (output dim: emb_features)
+         # Input dim depends on the text encoder output, assumed to be handled externally
+         self.text_proj = nn.Dense(features=self.emb_features, dtype=self.dtype,
+                                   precision=self.precision, name="text_context_proj")
+
+         # Rotary Positional Embedding (for patches)
+         # Dim per head, max_seq_len should cover max number of patches
+         self.rope = RotaryEmbedding(
+             dim=self.emb_features // self.num_heads, max_seq_len=4096, dtype=self.dtype)
+
+         # Transformer Blocks (use MMDiTBlock)
+         self.blocks = [
+             MMDiTBlock(
+                 features=self.emb_features,
+                 num_heads=self.num_heads,
+                 mlp_ratio=self.mlp_ratio,
+                 dropout_rate=self.dropout_rate,
+                 dtype=self.dtype,
+                 precision=self.precision,
+                 use_flash_attention=self.use_flash_attention,
+                 force_fp32_for_softmax=self.force_fp32_for_softmax,
+                 norm_epsilon=self.norm_epsilon,
+                 rope_emb=self.rope,  # Pass RoPE instance
+                 name=f"mmdit_block_{i}"
+             ) for i in range(self.num_layers)
+         ]
+
+         # Final Layer (Normalization + Linear Projection)
+         self.final_norm = nn.LayerNorm(
+             epsilon=self.norm_epsilon, dtype=self.dtype, name="final_norm")
+         # self.final_norm = nn.RMSNorm(epsilon=self.norm_epsilon, dtype=self.dtype, name="final_norm")  # Alternative
+
+         # Predict patch pixels + potentially sigma
+         output_dim = self.patch_size * self.patch_size * self.output_channels
+         if self.learn_sigma:
+             output_dim *= 2  # Predict both mean and variance (or log_variance)
+
+         self.final_proj = nn.Dense(
+             features=output_dim,
+             dtype=self.dtype,
+             precision=self.precision,
+             kernel_init=nn.initializers.zeros,  # Initialize final layer to zero
+             name="final_proj"
+         )
+
+     @nn.compact
+     def __call__(self, x, temb, textcontext):  # textcontext is required
+         B, H, W, C = x.shape
+         assert H % self.patch_size == 0 and W % self.patch_size == 0, "Image dimensions must be divisible by patch size"
+         assert textcontext is not None, "textcontext must be provided for SimpleMMDiT"
+
+         # 1. Patch Embedding
+         if self.use_hilbert:
+             # Use hilbert_patchify which handles both patchification and reordering
+             patches_raw, hilbert_inv_idx = hilbert_patchify(
+                 x, self.patch_size)  # Shape [B, S, P*P*C]
+             # Apply projection
+             # Shape [B, S, emb_features]
+             patches = self.hilbert_proj(patches_raw)
+         else:
+             # Shape: [B, num_patches, emb_features]
+             patches = self.patch_embed(x)
+             hilbert_inv_idx = None
+
+         num_patches = patches.shape[1]
+         x_seq = patches
+
+         # 2. Prepare Conditioning Signals
+         t_emb = self.time_embed(temb)  # Shape: [B, emb_features]
+         # Assuming textcontext is [B, context_seq_len, context_dim] or [B, context_dim]
+         # If [B, context_seq_len, context_dim], usually mean/pool or take CLS token first.
+         # Assuming textcontext is already pooled/CLS token: [B, context_dim]
+         text_emb = self.text_proj(textcontext)  # Shape: [B, emb_features]
+
+         # 3. Apply RoPE Frequencies (only to patch tokens)
+         seq_len = x_seq.shape[1]
+         freqs_cos, freqs_sin = self.rope(seq_len)  # Shapes: [S, D_head/2]
+
+         # 4. Apply Transformer Blocks
+         for block in self.blocks:
+             # Pass t_emb and text_emb separately to the block
+             x_seq = block(x_seq, t_emb, text_emb,
+                           freqs_cis=(freqs_cos, freqs_sin))
+
+         # 5. Final Layer
+         x_seq = self.final_norm(x_seq)
+         # Shape: [B, num_patches, P*P*C (*2 if learn_sigma)]
+         x_seq = self.final_proj(x_seq)
+
+         # 6. Unpatchify
+         if self.use_hilbert:
+             # For Hilbert mode, we need to use the specialized unpatchify function
+             if self.learn_sigma:
+                 # Split into mean and variance predictions
+                 x_mean, x_logvar = jnp.split(x_seq, 2, axis=-1)
+                 x_image = hilbert_unpatchify(
+                     x_mean, hilbert_inv_idx, self.patch_size, H, W, self.output_channels)
+                 # If needed, also unpack the logvar
+                 # logvar_image = hilbert_unpatchify(x_logvar, hilbert_inv_idx, self.patch_size, H, W, self.output_channels)
+                 # return x_image, logvar_image
+                 return x_image
+             else:
+                 x_image = hilbert_unpatchify(
+                     x_seq, hilbert_inv_idx, self.patch_size, H, W, self.output_channels)
+                 return x_image
+         else:
+             # Standard patch ordering - use the existing unpatchify function
+             if self.learn_sigma:
+                 # Split into mean and variance predictions
+                 x_mean, x_logvar = jnp.split(x_seq, 2, axis=-1)
+                 x = unpatchify(x_mean, channels=self.output_channels)
+                 # Return both mean and logvar if needed by the loss function
+                 # For now, just returning the mean prediction like standard diffusion models
+                 # logvar = unpatchify(x_logvar, channels=self.output_channels)
+                 # return x, logvar
+                 return x
+             else:
+                 # Shape: [B, H, W, C]
+                 x = unpatchify(x_seq, channels=self.output_channels)
+                 return x
+
+
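An end-to-end shape sketch for SimpleMMDiT (illustrative, not part of the diff). It assumes PatchEmbedding, FourierEmbedding, TimeProjection, NormalAttention, and unpatchify from the sibling modules behave as used above, that temb is a per-sample scalar timestep array, and that textcontext is already a pooled [B, D] text embedding.

    model = SimpleMMDiT(patch_size=16, emb_features=384, num_layers=2, num_heads=6)
    x = jnp.zeros((1, 64, 64, 3))       # [B, H, W, C] noisy image
    temb = jnp.zeros((1,))              # diffusion timestep per sample
    textcontext = jnp.zeros((1, 768))   # pooled text embedding [B, D]
    params = model.init(jax.random.PRNGKey(0), x, temb, textcontext)
    out = model.apply(params, x, temb, textcontext)  # expected [1, 64, 64, 3]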
+ # --- Hierarchical MM-DiT components ---
+
+ class PatchMerging(nn.Module):
+     """
+     Merges a group of patches into a single patch with increased feature dimensions.
+     Used in the hierarchical structure to reduce spatial resolution and increase channels.
+     """
+     out_features: int
+     merge_size: int = 2  # Default 2x2 patch merging
+     dtype: Optional[Dtype] = None
+     precision: PrecisionLike = None
+     norm_epsilon: float = 1e-5  # Add norm for stability like in Swin Transformer
+
+     @nn.compact
+     def __call__(self, x, H_patches, W_patches):
+         # x shape: [B, H*W, C]
+         B, L, C = x.shape
+         assert L == H_patches * \
+             W_patches, f"Input length {L} doesn't match {H_patches}*{W_patches}"
+         assert H_patches % self.merge_size == 0 and W_patches % self.merge_size == 0, f"Patch dimensions ({H_patches}, {W_patches}) not divisible by merge size {self.merge_size}"
+
+         # Reshape to [B, H, W, C]
+         x = x.reshape(B, H_patches, W_patches, C)
+
+         # Merge patches - rearrange to group nearby patches
+         merged = einops.rearrange(
+             x,
+             'b (h p1) (w p2) c -> b h w (p1 p2 c)',
+             p1=self.merge_size, p2=self.merge_size
+         )
+
+         # Apply LayerNorm before projection (common practice)
+         norm = nn.LayerNorm(epsilon=self.norm_epsilon, dtype=self.dtype, name="norm")
+         merged = norm(merged)  # Apply norm on [B, H/p, W/p, p*p*C]
+
+         # Project to new dimension
+         merged = nn.Dense(
+             features=self.out_features,
+             dtype=self.dtype,
+             precision=self.precision,
+             name="projection"
+         )(merged)
+
+         # Flatten back to sequence
+         new_H = H_patches // self.merge_size
+         new_W = W_patches // self.merge_size
+         merged = merged.reshape(B, new_H * new_W, self.out_features)
+
+         return merged, new_H, new_W
+
+ class PatchExpanding(nn.Module):
+     """
+     Expands patches to increase spatial resolution.
+     Used in the hierarchical structure decoder path.
+     """
+     out_features: int
+     expand_size: int = 2  # Default 2x2 patch expansion
+     dtype: Optional[Dtype] = None
+     precision: PrecisionLike = None
+     norm_epsilon: float = 1e-5  # Add norm for stability
+
+     @nn.compact
+     def __call__(self, x, H_patches, W_patches):
+         # x shape: [B, H*W, C]
+         B, L, C = x.shape
+         assert L == H_patches * W_patches, f"Input length {L} doesn't match {H_patches}*{W_patches}"
+
+         # Project to expanded dimension first
+         expanded_features = self.expand_size * self.expand_size * self.out_features
+         x = nn.Dense(
+             features=expanded_features,
+             dtype=self.dtype,
+             precision=self.precision,
+             name="projection"
+         )(x)  # Shape [B, L, P*P*C_out]
+
+         # Apply LayerNorm after projection
+         norm = nn.LayerNorm(epsilon=self.norm_epsilon, dtype=self.dtype, name="norm")
+         x = norm(x)
+
+         # Reshape to spatial grid before rearranging
+         x = x.reshape(B, H_patches, W_patches, expanded_features)
+
+         # Rearrange to expand spatial dims
+         expanded = einops.rearrange(
+             x,
+             'b h w (p1 p2 c) -> b (h p1) (w p2) c',
+             p1=self.expand_size, p2=self.expand_size, c=self.out_features
+         )
+
+         # Flatten back to sequence
+         new_H = H_patches * self.expand_size
+         new_W = W_patches * self.expand_size
+         expanded = expanded.reshape(B, new_H * new_W, self.out_features)
+
+         return expanded, new_H, new_W
+
+
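A round-trip shape sketch for the two helpers (illustrative only): merging halves each side of the patch grid while changing the channel count, and expanding reverses it.

    B, Hp, Wp, C = 2, 8, 8, 256
    tokens = jnp.zeros((B, Hp * Wp, C))
    merge = PatchMerging(out_features=512)
    m_params = merge.init(jax.random.PRNGKey(0), tokens, Hp, Wp)
    merged, Hm, Wm = merge.apply(m_params, tokens, Hp, Wp)      # [2, 16, 512], Hm = Wm = 4
    expand = PatchExpanding(out_features=256)
    e_params = expand.init(jax.random.PRNGKey(1), merged, Hm, Wm)
    expanded, He, We = expand.apply(e_params, merged, Hm, Wm)   # [2, 64, 256], He = We = 8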
+ # --- Hierarchical MM-DiT ---
+ class HierarchicalMMDiT(nn.Module):
+     """
+     A Hierarchical Multi-Modal Diffusion Transformer (MM-DiT) implementation
+     based on the PixArt-α architecture. Processes images at multiple resolutions
+     with skip connections between encoder and decoder paths.
+     Follows a U-Net like structure: Fine -> Coarse (Encoder) -> Coarse -> Fine (Decoder).
+     """
+     output_channels: int = 3
+     base_patch_size: int = 8  # Patch size at the *finest* resolution level (stage 0)
+     emb_features: Sequence[int] = (512, 768, 1024)  # Feature dimensions for stages 0, 1, 2 (fine to coarse)
+     num_layers: Sequence[int] = (4, 4, 14)  # Layers per stage (can be asymmetric encoder/decoder if needed)
+     num_heads: Sequence[int] = (8, 12, 16)  # Heads per stage (fine to coarse)
+     mlp_ratio: int = 4
+     dropout_rate: float = 0.0
+     dtype: Optional[Dtype] = None
+     precision: PrecisionLike = None
+     use_flash_attention: bool = False
+     force_fp32_for_softmax: bool = True
+     norm_epsilon: float = 1e-5
+     learn_sigma: bool = False
+     use_hilbert: bool = False
+     norm_groups: int = 0  # Not used in this structure, maybe remove later
+     activation: Callable = jax.nn.swish  # Not used directly here, used in MLP inside MMDiTBlock
+
+     def setup(self):
+         assert len(self.emb_features) == len(self.num_layers) == len(self.num_heads), \
+             "Feature dimensions, layers, and heads must have the same number of stages"
+
+         num_stages = len(self.emb_features)
+
+         # 1. Initial Patch Embedding (FINEST level - stage 0)
+         self.patch_embed = PatchEmbedding(
+             patch_size=self.base_patch_size,
+             embedding_dim=self.emb_features[0],  # Finest embedding dim
+             dtype=self.dtype,
+             precision=self.precision
+         )
+
+         # 2. Time/Text Embeddings (Projected for each stage)
+         # Base projection to largest dimension first for stability/capacity
+         base_t_emb_dim = self.emb_features[-1]
+         self.time_embed_base = nn.Sequential([
+             FourierEmbedding(features=base_t_emb_dim),
+             TimeProjection(features=base_t_emb_dim * self.mlp_ratio),
+             nn.Dense(features=base_t_emb_dim, dtype=self.dtype, precision=self.precision)
+         ], name="time_embed_base")
+         self.text_proj_base = nn.Dense(
+             features=base_t_emb_dim,
+             dtype=self.dtype,
+             precision=self.precision,
+             name="text_context_proj_base"
+         )
+         # Projections for each stage (0 to N-1)
+         self.t_emb_projs = [
+             nn.Dense(features=self.emb_features[i], dtype=self.dtype, precision=self.precision, name=f"t_emb_proj_stage{i}")
+             for i in range(num_stages)
+         ]
+         self.text_emb_projs = [
+             nn.Dense(features=self.emb_features[i], dtype=self.dtype, precision=self.precision, name=f"text_emb_proj_stage{i}")
+             for i in range(num_stages)
+         ]
+
+         # 3. Hilbert projection (if used, applied after initial patch embedding)
+         if self.use_hilbert:
+             self.hilbert_proj = nn.Dense(
+                 features=self.emb_features[0],  # Match finest embedding dim
+                 dtype=self.dtype,
+                 precision=self.precision,
+                 name="hilbert_projection"
+             )
+
+         # 4. RoPE embeddings for each stage (0 to N-1)
+         self.ropes = [
+             RotaryEmbedding(
+                 dim=self.emb_features[i] // self.num_heads[i],
+                 max_seq_len=4096,  # Adjust if needed based on max patch count per stage
+                 dtype=self.dtype,
+                 name=f"rope_stage_{i}"
+             )
+             for i in range(num_stages)
+         ]
+
+         # 5. --- Encoder Path (Fine to Coarse) ---
+         encoder_blocks = []
+         patch_mergers = []
+         for stage in range(num_stages):
+             # Blocks for this stage
+             stage_blocks = [
+                 MMDiTBlock(
+                     features=self.emb_features[stage],
+                     num_heads=self.num_heads[stage],
+                     mlp_ratio=self.mlp_ratio,
+                     dropout_rate=self.dropout_rate,
+                     dtype=self.dtype,
+                     precision=self.precision,
+                     use_flash_attention=self.use_flash_attention,
+                     force_fp32_for_softmax=self.force_fp32_for_softmax,
+                     norm_epsilon=self.norm_epsilon,
+                     rope_emb=self.ropes[stage],
+                     name=f"encoder_block_stage{stage}_{i}"
+                 )
+                 # Assuming symmetric layers for now, adjust if needed (e.g., self.num_encoder_layers[stage])
+                 for i in range(self.num_layers[stage])
+             ]
+             encoder_blocks.append(stage_blocks)
+
+             # Patch Merging layer (except for the last/coarsest stage)
+             if stage < num_stages - 1:
+                 patch_mergers.append(
+                     PatchMerging(
+                         out_features=self.emb_features[stage + 1],  # Target next stage dim
+                         dtype=self.dtype,
+                         precision=self.precision,
+                         norm_epsilon=self.norm_epsilon,
+                         name=f"patch_merger_{stage}"
+                     )
+                 )
+         self.encoder_blocks = encoder_blocks
+         self.patch_mergers = patch_mergers
+
+         # 6. --- Decoder Path (Coarse to Fine) ---
+         decoder_blocks = []
+         patch_expanders = []
+         fusion_layers = []
+         # Iterate from second coarsest stage (N-2) down to finest (0)
+         for stage in range(num_stages - 2, -1, -1):
+             # Patch Expanding layer (Expands from stage+1 to stage)
+             patch_expanders.append(
+                 PatchExpanding(
+                     out_features=self.emb_features[stage],  # Target current stage dim
+                     dtype=self.dtype,
+                     precision=self.precision,
+                     norm_epsilon=self.norm_epsilon,
+                     name=f"patch_expander_{stage}"  # Naming indicates target stage
+                 )
+             )
+             # Fusion layer (Combines skip[stage] and expanded[stage+1]->[stage])
+             fusion_layers.append(
+                 nn.Sequential([  # Use Sequential for Norm + Dense
+                     nn.LayerNorm(epsilon=self.norm_epsilon, dtype=self.dtype, name=f"fusion_norm_{stage}"),
+                     nn.Dense(
+                         features=self.emb_features[stage],  # Output current stage dim
+                         dtype=self.dtype,
+                         precision=self.precision,
+                         name=f"fusion_dense_{stage}"
+                     )
+                 ])
+             )
+
+             # Blocks for this stage (stage N-2 down to 0)
+             # Assuming symmetric layers for now
+             stage_blocks = [
+                 MMDiTBlock(
+                     features=self.emb_features[stage],
+                     num_heads=self.num_heads[stage],
+                     mlp_ratio=self.mlp_ratio,
+                     dropout_rate=self.dropout_rate,
+                     dtype=self.dtype,
+                     precision=self.precision,
+                     use_flash_attention=self.use_flash_attention,
+                     force_fp32_for_softmax=self.force_fp32_for_softmax,
+                     norm_epsilon=self.norm_epsilon,
+                     rope_emb=self.ropes[stage],
+                     name=f"decoder_block_stage{stage}_{i}"
+                 )
+                 for i in range(self.num_layers[stage])
+             ]
+             # Append blocks in order: stage N-2, N-3, ..., 0
+             decoder_blocks.append(stage_blocks)
+
+         self.patch_expanders = patch_expanders
+         self.fusion_layers = fusion_layers
+         self.decoder_blocks = decoder_blocks
+
+         # Note: The lists expanders, fusion_layers, decoder_blocks are now ordered
+         # corresponding to stages N-2, N-3, ..., 0.
+
+         # 7. Final Layer
+         self.final_norm = nn.LayerNorm(
+             epsilon=self.norm_epsilon, dtype=self.dtype, name="final_norm")
+
+         # Output projection to pixels (at finest resolution)
+         output_dim = self.base_patch_size * self.base_patch_size * self.output_channels
+         if self.learn_sigma:
+             output_dim *= 2  # Predict both mean and variance
+
+         self.final_proj = nn.Dense(
+             features=output_dim,
+             dtype=self.dtype,
+             precision=self.precision,
+             kernel_init=nn.initializers.zeros,  # Zero init
+             name="final_proj"
+         )
+
+     def __call__(self, x, temb, textcontext):
+         B, H, W, C = x.shape
+         num_stages = len(self.emb_features)
+         finest_patch_size = self.base_patch_size
+
+         # Assertions
+         assert H % (finest_patch_size * (2**(num_stages - 1))) == 0 and \
+             W % (finest_patch_size * (2**(num_stages - 1))) == 0, \
+             f"Image dimensions ({H},{W}) must be divisible by effective coarsest patch size {finest_patch_size * (2**(num_stages - 1))}"
+         assert textcontext is not None, "textcontext must be provided"
+
+         # 1. Initial Patch Embedding (Finest Level - stage 0)
+         H_patches = H // finest_patch_size
+         W_patches = W // finest_patch_size
+         total_patches = H_patches * W_patches  # Calculate total patches
+         hilbert_inv_idx = None
+         if self.use_hilbert:
+             # Calculate Hilbert indices and inverse permutation for the *finest* grid
+             fine_idx = hilbert_indices(H_patches, W_patches)
+             # Pass the total number of patches as total_size
+             hilbert_inv_idx = inverse_permutation(fine_idx, total_size=total_patches)  # Store for unpatchify
+
+             # Apply Hilbert patchify at the finest level
+             patches_raw, _ = hilbert_patchify(x, finest_patch_size)  # We already have inv_idx
+             x_seq = self.hilbert_proj(patches_raw)  # Shape [B, S_fine, emb[0]]
+         else:
+             x_seq = self.patch_embed(x)  # Shape [B, S_fine, emb[0]]
+
+         # 2. Prepare Conditioning Signals for each stage
+         t_emb_base = self.time_embed_base(temb)
+         text_emb_base = self.text_proj_base(textcontext)
+         t_embs = [proj(t_emb_base) for proj in self.t_emb_projs]  # List for stages 0 to N-1
+         text_embs = [proj(text_emb_base) for proj in self.text_emb_projs]  # List for stages 0 to N-1
+
+         # --- Encoder Path (Fine to Coarse: stages 0 to N-1) ---
+         skip_features = {}
+         current_H_patches, current_W_patches = H_patches, W_patches
+         for stage in range(num_stages):
+             # Apply RoPE for current stage
+             seq_len = x_seq.shape[1]
+             freqs_cos, freqs_sin = self.ropes[stage](seq_len)
+
+             # Apply blocks for this stage
+             for block in self.encoder_blocks[stage]:
+                 x_seq = block(x_seq, t_embs[stage], text_embs[stage], freqs_cis=(freqs_cos, freqs_sin))
+
+             # Store skip features (before merging)
+             skip_features[stage] = x_seq
+
+             # Apply Patch Merging (if not the last/coarsest stage)
+             if stage < num_stages - 1:
+                 x_seq, current_H_patches, current_W_patches = self.patch_mergers[stage](
+                     x_seq, current_H_patches, current_W_patches
+                 )
+
+         # --- Bottleneck ---
+         # x_seq now holds the output of the coarsest stage (stage N-1)
+
+         # --- Decoder Path (Coarse to Fine: stages N-2 down to 0) ---
+         # Decoder lists (expanders, fusion, blocks) are ordered for stages N-2, ..., 0
+         for i, stage in enumerate(range(num_stages - 2, -1, -1)):  # stage = N-2, N-3, ..., 0; i = 0, 1, ..., N-2
+             # Apply Patch Expanding (Expand from stage+1 feature map to stage feature map)
+             x_seq, current_H_patches, current_W_patches = self.patch_expanders[i](
+                 x_seq, current_H_patches, current_W_patches
+             )
+
+             # Fusion with skip connection from corresponding encoder stage
+             skip = skip_features[stage]
+             x_seq = jnp.concatenate([x_seq, skip], axis=-1)  # Concatenate along feature dim
+             x_seq = self.fusion_layers[i](x_seq)  # Apply fusion (Norm + Dense)
+
+             # Apply RoPE for current stage
+             seq_len = x_seq.shape[1]
+             freqs_cos, freqs_sin = self.ropes[stage](seq_len)
+
+             # Apply blocks for this stage
+             for block in self.decoder_blocks[i]:  # Use index i for the decoder block list
+                 x_seq = block(x_seq, t_embs[stage], text_embs[stage], freqs_cis=(freqs_cos, freqs_sin))
+
+         # --- Final Layer ---
+         # x_seq should now be at the finest resolution (stage 0 features)
+         x_seq = self.final_norm(x_seq)
+         x_seq = self.final_proj(x_seq)  # Project to patch pixel values
+
+         # --- Unpatchify ---
+         if self.use_hilbert:
+             # Use the inverse Hilbert index calculated for the *finest* grid
+             assert hilbert_inv_idx is not None, "Hilbert inverse index should exist if use_hilbert is True"
+             if self.learn_sigma:
+                 x_mean, x_logvar = jnp.split(x_seq, 2, axis=-1)
+                 out = hilbert_unpatchify(x_mean, hilbert_inv_idx, finest_patch_size, H, W, self.output_channels)
+                 # Optionally return logvar: logvar_image = hilbert_unpatchify(x_logvar, hilbert_inv_idx, finest_patch_size, H, W, self.output_channels)
+             else:
+                 out = hilbert_unpatchify(x_seq, hilbert_inv_idx, finest_patch_size, H, W, self.output_channels)
+         else:
+             # Standard unpatchify
+             if self.learn_sigma:
+                 x_mean, x_logvar = jnp.split(x_seq, 2, axis=-1)
+                 out = unpatchify(x_mean, channels=self.output_channels)
+                 # Optionally return logvar: logvar = unpatchify(x_logvar, channels=self.output_channels)
+             else:
+                 out = unpatchify(x_seq, channels=self.output_channels)
+
+         return out
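A hedged usage sketch for HierarchicalMMDiT (not part of the diff; the same sibling-module assumptions as for SimpleMMDiT apply). With three stages and base_patch_size=8, the image side lengths must be divisible by 8 * 2**(3-1) = 32.

    model = HierarchicalMMDiT(
        base_patch_size=8,
        emb_features=(128, 192, 256),   # per-head dim stays 32 at every stage
        num_layers=(1, 1, 2),
        num_heads=(4, 6, 8),
    )
    x = jnp.zeros((1, 64, 64, 3))       # 64 is divisible by 32
    temb = jnp.zeros((1,))
    textcontext = jnp.zeros((1, 768))
    params = model.init(jax.random.PRNGKey(0), x, temb, textcontext)
    out = model.apply(params, x, temb, textcontext)  # expected [1, 64, 64, 3]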