flaxdiff 0.2.7__py3-none-any.whl → 0.2.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,730 @@
+ import jax
+ import jax.numpy as jnp
+ from flax import linen as nn
+ from typing import Callable, Any, Optional, Tuple, Sequence, Union, List
+ import einops
+ from functools import partial
+ from flax.typing import Dtype, PrecisionLike
+
+ # Imports from local modules
+ from .vit_common import PatchEmbedding, unpatchify, RotaryEmbedding, RoPEAttention
+ from .common import kernel_init, FourierEmbedding, TimeProjection
+ from .attention import NormalAttention  # Base for RoPEAttention
+ # Replace common.hilbert_indices with improved implementation from hilbert.py
+ from .hilbert import hilbert_indices, inverse_permutation, hilbert_patchify, hilbert_unpatchify
+
+
+ # --- MM-DiT AdaLN-Zero ---
+ class MMAdaLNZero(nn.Module):
+     """
+     Adaptive Layer Normalization Zero (AdaLN-Zero) tailored for MM-DiT.
+     Projects time and text embeddings separately, combines them, and then
+     generates modulation parameters (scale, shift, gate) for attention and MLP paths.
+     """
+     features: int
+     dtype: Optional[Dtype] = None
+     precision: PrecisionLike = None
+     norm_epsilon: float = 1e-5
+     use_mean_pooling: bool = True  # Whether to use mean pooling for sequence inputs
+
+     @nn.compact
+     def __call__(self, x, t_emb, text_emb):
+         # x shape: [B, S, F]
+         # t_emb shape: [B, D_t]
+         # text_emb shape: [B, S_text, D_text] or [B, D_text]
+
+         # First normalize the input features
+         norm = nn.LayerNorm(epsilon=self.norm_epsilon,
+                             use_scale=False, use_bias=False, dtype=self.dtype)
+         norm_x = norm(x)  # Shape: [B, S, F]
+
+         # Process time embedding: ensure it has a sequence dimension for later broadcasting
+         if t_emb.ndim == 2:  # [B, D_t]
+             t_emb = jnp.expand_dims(t_emb, axis=1)  # [B, 1, D_t]
+
+         # Process text embedding: if it has a sequence dimension different from x
+         if text_emb.ndim == 2:  # [B, D_text]
+             text_emb = jnp.expand_dims(text_emb, axis=1)  # [B, 1, D_text]
+         elif text_emb.ndim == 3 and self.use_mean_pooling and text_emb.shape[1] != x.shape[1]:
+             # Mean pooling is standard in MM-DiT for handling different sequence lengths
+             text_emb = jnp.mean(
+                 text_emb, axis=1, keepdims=True)  # [B, 1, D_text]
+
+         # Project time embedding
+         t_params = nn.Dense(
+             features=6 * self.features,
+             dtype=self.dtype,
+             precision=self.precision,
+             kernel_init=nn.initializers.zeros,  # Zero init is standard in AdaLN-Zero
+             name="ada_t_proj"
+         )(t_emb)  # Shape: [B, 1, 6*F]
+
+         # Project text embedding
+         text_params = nn.Dense(
+             features=6 * self.features,
+             dtype=self.dtype,
+             precision=self.precision,
+             kernel_init=nn.initializers.zeros,  # Zero init
+             name="ada_text_proj"
+         )(text_emb)  # Shape: [B, 1, 6*F] or [B, S_text, 6*F]
+
+         # If text_params still has a sequence dim different from t_params, mean pool it
+         if t_params.shape[1] != text_params.shape[1]:
+             text_params = jnp.mean(text_params, axis=1, keepdims=True)
+
+         # Combine parameters (summing is standard in MM-DiT)
+         ada_params = t_params + text_params  # Shape: [B, 1, 6*F]
+
+         # Split into scale, shift, gate for MLP and Attention
+         scale_mlp, shift_mlp, gate_mlp, scale_attn, shift_attn, gate_attn = jnp.split(
+             ada_params, 6, axis=-1)  # Each shape: [B, 1, F]
+
+         scale_mlp = jnp.clip(scale_mlp, -10.0, 10.0)
+         shift_mlp = jnp.clip(shift_mlp, -10.0, 10.0)
+         # Apply modulation for Attention path (broadcasting handled by JAX)
+         x_attn = norm_x * (1 + scale_attn) + shift_attn
+
+         # Apply modulation for MLP path
+         x_mlp = norm_x * (1 + scale_mlp) + shift_mlp
+
+         # Return modulated outputs and gates
+         return x_attn, gate_attn, x_mlp, gate_mlp
+
+
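For reference, a minimal shape sketch of the module above (not part of the package; batch, sequence, and feature sizes are arbitrary). Because both projection kernels are zero-initialized, the returned gates are zero at initialization, which is what lets downstream blocks start as identity maps:

mod = MMAdaLNZero(features=768)
x = jnp.zeros((2, 256, 768))     # [B, S, F] patch tokens
t = jnp.zeros((2, 768))          # [B, D_t] time embedding
txt = jnp.zeros((2, 77, 768))    # [B, S_text, D_text] text embedding (mean-pooled internally)
variables = mod.init(jax.random.PRNGKey(0), x, t, txt)
x_attn, gate_attn, x_mlp, gate_mlp = mod.apply(variables, x, t, txt)
# x_attn, x_mlp: (2, 256, 768); gate_attn, gate_mlp: (2, 1, 768), all zeros at init
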
+ # --- MM-DiT Block ---
+ class MMDiTBlock(nn.Module):
+     """
+     A Transformer block adapted for MM-DiT, using MMAdaLNZero for conditioning.
+     """
+     features: int
+     num_heads: int
+     rope_emb: RotaryEmbedding  # Pass RoPE module
+     mlp_ratio: int = 4
+     dropout_rate: float = 0.0
+     dtype: Optional[Dtype] = None
+     precision: PrecisionLike = None
+     # Keep option, though RoPEAttention doesn't use it
+     use_flash_attention: bool = False
+     force_fp32_for_softmax: bool = True
+     norm_epsilon: float = 1e-5
+
+     def setup(self):
+         hidden_features = int(self.features * self.mlp_ratio)
+         # Use the new MMAdaLNZero block
+         self.ada_ln_zero = MMAdaLNZero(
+             self.features, dtype=self.dtype, precision=self.precision, norm_epsilon=self.norm_epsilon)
+
+         # RoPEAttention remains the same
+         self.attention = RoPEAttention(
+             query_dim=self.features,
+             heads=self.num_heads,
+             dim_head=self.features // self.num_heads,
+             dtype=self.dtype,
+             precision=self.precision,
+             use_bias=True,  # Bias is common in DiT attention proj
+             force_fp32_for_softmax=self.force_fp32_for_softmax,
+             rope_emb=self.rope_emb  # Pass RoPE module instance
+         )
+
+         # Standard MLP block remains the same
+         self.mlp = nn.Sequential([
+             nn.Dense(features=hidden_features, dtype=self.dtype,
+                      precision=self.precision),
+             nn.gelu,  # Consider swish/silu if preferred
+             nn.Dense(features=self.features, dtype=self.dtype,
+                      precision=self.precision)
+         ])
+
+     @nn.compact
+     def __call__(self, x, t_emb, text_emb, freqs_cis):
+         # x shape: [B, S, F]
+         # t_emb shape: [B, D_t] or [B, 1, D_t]
+         # text_emb shape: [B, D_text] or [B, 1, D_text]
+
+         residual = x
+
+         # Apply MMAdaLNZero with separate time and text embeddings
+         x_attn, gate_attn, x_mlp, gate_mlp = self.ada_ln_zero(
+             x, t_emb, text_emb)
+
+         # Attention block (remains the same)
+         attn_output = self.attention(
+             x_attn, context=None, freqs_cis=freqs_cis)  # Self-attention only
+         x = residual + gate_attn * attn_output
+
+         # MLP block (remains the same)
+         mlp_output = self.mlp(x_mlp)
+         x = x + gate_mlp * mlp_output
+
+         return x
+
+
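The gating pattern above is what makes AdaLN-Zero training stable: with zero gates, each block reduces to the identity on the residual stream. A standalone jnp illustration of that update rule (dummy values, independent of the module definitions):

residual = jnp.ones((1, 4, 8))
gate = jnp.zeros((1, 1, 8))            # AdaLN-Zero gate at initialization
branch_out = jnp.full((1, 4, 8), 5.0)  # stand-in for the attention/MLP output
out = residual + gate * branch_out     # same update as in MMDiTBlock.__call__
assert jnp.allclose(out, residual)     # the block is an identity map until the gates move off zero
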
+ # --- SimpleMMDiT ---
+ class SimpleMMDiT(nn.Module):
+     """
+     A Simple Multi-Modal Diffusion Transformer (MM-DiT) implementation.
+     Integrates time and text conditioning using separate projections within
+     each transformer block, following the MM-DiT approach. Uses RoPE for
+     patch positional encoding.
+     """
+     output_channels: int = 3
+     patch_size: int = 16
+     emb_features: int = 768
+     num_layers: int = 12
+     num_heads: int = 12
+     mlp_ratio: int = 4
+     dropout_rate: float = 0.0  # Typically 0 for diffusion
+     dtype: Optional[Dtype] = None
+     precision: PrecisionLike = None
+     # Passed down, but RoPEAttention uses NormalAttention
+     use_flash_attention: bool = False
+     force_fp32_for_softmax: bool = True
+     norm_epsilon: float = 1e-5
+     learn_sigma: bool = False  # Option to predict sigma like in DiT paper
+     use_hilbert: bool = False  # Toggle Hilbert patch reorder
+     norm_groups: int = 0  # Not used in this module
+     activation: Callable = jax.nn.swish  # Not used directly here; the MLP inside MMDiTBlock uses GELU
+
+     def setup(self):
+         self.patch_embed = PatchEmbedding(
+             patch_size=self.patch_size,
+             embedding_dim=self.emb_features,
+             dtype=self.dtype,
+             precision=self.precision
+         )
+
+         # Time embedding projection (output dim: emb_features)
+         self.time_embed = nn.Sequential([
+             FourierEmbedding(features=self.emb_features),
+             TimeProjection(features=self.emb_features *
+                            self.mlp_ratio),  # Intermediate projection
+             nn.Dense(features=self.emb_features, dtype=self.dtype,
+                      precision=self.precision)  # Final projection
+         ], name="time_embed")
+
+         # Add projection layer for Hilbert patches
+         if self.use_hilbert:
+             self.hilbert_proj = nn.Dense(
+                 features=self.emb_features,
+                 dtype=self.dtype,
+                 precision=self.precision,
+                 name="hilbert_projection"
+             )
+         # Text context projection (output dim: emb_features)
+         # Input dim depends on the text encoder output, assumed to be handled externally
+         self.text_proj = nn.Dense(features=self.emb_features, dtype=self.dtype,
+                                   precision=self.precision, name="text_context_proj")
+
+         # Rotary Positional Embedding (for patches)
+         # Dim per head, max_len should cover max number of patches
+         self.rope = RotaryEmbedding(
+             dim=self.emb_features // self.num_heads, max_seq_len=4096, dtype=self.dtype)
+
+         # Transformer Blocks (use MMDiTBlock)
+         self.blocks = [
+             MMDiTBlock(
+                 features=self.emb_features,
+                 num_heads=self.num_heads,
+                 mlp_ratio=self.mlp_ratio,
+                 dropout_rate=self.dropout_rate,
+                 dtype=self.dtype,
+                 precision=self.precision,
+                 use_flash_attention=self.use_flash_attention,
+                 force_fp32_for_softmax=self.force_fp32_for_softmax,
+                 norm_epsilon=self.norm_epsilon,
+                 rope_emb=self.rope,  # Pass RoPE instance
+                 name=f"mmdit_block_{i}"
+             ) for i in range(self.num_layers)
+         ]
+
+         # Final Layer (Normalization + Linear Projection)
+         self.final_norm = nn.LayerNorm(
+             epsilon=self.norm_epsilon, dtype=self.dtype, name="final_norm")
+         # self.final_norm = nn.RMSNorm(epsilon=self.norm_epsilon, dtype=self.dtype, name="final_norm")  # Alternative
+
+         # Predict patch pixels + potentially sigma
+         output_dim = self.patch_size * self.patch_size * self.output_channels
+         if self.learn_sigma:
+             output_dim *= 2  # Predict both mean and variance (or log_variance)
+
+         self.final_proj = nn.Dense(
+             features=output_dim,
+             dtype=self.dtype,
+             precision=self.precision,
+             kernel_init=nn.initializers.zeros,  # Initialize final layer to zero
+             name="final_proj"
+         )
+
+     @nn.compact
+     def __call__(self, x, temb, textcontext):  # textcontext is required
+         B, H, W, C = x.shape
+         assert H % self.patch_size == 0 and W % self.patch_size == 0, "Image dimensions must be divisible by patch size"
+         assert textcontext is not None, "textcontext must be provided for SimpleMMDiT"
+
+         # 1. Patch Embedding
+         if self.use_hilbert:
+             # Use hilbert_patchify which handles both patchification and reordering
+             patches_raw, hilbert_inv_idx = hilbert_patchify(
+                 x, self.patch_size)  # Shape [B, S, P*P*C]
+             # Apply projection
+             # Shape [B, S, emb_features]
+             patches = self.hilbert_proj(patches_raw)
+         else:
+             # Shape: [B, num_patches, emb_features]
+             patches = self.patch_embed(x)
+             hilbert_inv_idx = None
+
+         num_patches = patches.shape[1]
+         x_seq = patches
+
+         # 2. Prepare Conditioning Signals
+         t_emb = self.time_embed(temb)  # Shape: [B, emb_features]
+         # Assuming textcontext is [B, context_seq_len, context_dim] or [B, context_dim]
+         # If [B, context_seq_len, context_dim], usually mean/pool or take CLS token first.
+         # Assuming textcontext is already pooled/CLS token: [B, context_dim]
+         text_emb = self.text_proj(textcontext)  # Shape: [B, emb_features]
+
+         # 3. Apply RoPE Frequencies (only to patch tokens)
+         seq_len = x_seq.shape[1]
+         freqs_cos, freqs_sin = self.rope(seq_len)  # Shapes: [S, D_head/2]
+
+         # 4. Apply Transformer Blocks
+         for block in self.blocks:
+             # Pass t_emb and text_emb separately to the block
+             x_seq = block(x_seq, t_emb, text_emb,
+                           freqs_cis=(freqs_cos, freqs_sin))
+
+         # 5. Final Layer
+         x_seq = self.final_norm(x_seq)
+         # Shape: [B, num_patches, P*P*C (*2 if learn_sigma)]
+         x_seq = self.final_proj(x_seq)
+
+         # 6. Unpatchify
+         if self.use_hilbert:
+             # For Hilbert mode, we need to use the specialized unpatchify function
+             if self.learn_sigma:
+                 # Split into mean and variance predictions
+                 x_mean, x_logvar = jnp.split(x_seq, 2, axis=-1)
+                 x_image = hilbert_unpatchify(
+                     x_mean, hilbert_inv_idx, self.patch_size, H, W, self.output_channels)
+                 # If needed, also unpack the logvar
+                 # logvar_image = hilbert_unpatchify(x_logvar, hilbert_inv_idx, self.patch_size, H, W, self.output_channels)
+                 # return x_image, logvar_image
+                 return x_image
+             else:
+                 x_image = hilbert_unpatchify(
+                     x_seq, hilbert_inv_idx, self.patch_size, H, W, self.output_channels)
+                 return x_image
+         else:
+             # Standard patch ordering - use the existing unpatchify function
+             if self.learn_sigma:
+                 # Split into mean and variance predictions
+                 x_mean, x_logvar = jnp.split(x_seq, 2, axis=-1)
+                 x = unpatchify(x_mean, channels=self.output_channels)
+                 # Return both mean and logvar if needed by the loss function
+                 # For now, just returning the mean prediction like standard diffusion models
+                 # logvar = unpatchify(x_logvar, channels=self.output_channels)
+                 # return x, logvar
+                 return x
+             else:
+                 # Shape: [B, H, W, C]
+                 x = unpatchify(x_seq, channels=self.output_channels)
+                 return x
+
+
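A hedged end-to-end sketch of driving SimpleMMDiT. The timestep shape accepted by FourierEmbedding and the pooled text-context dimensionality are assumptions about the local modules; the class itself only asserts that H and W are divisible by patch_size and that textcontext is not None:

model = SimpleMMDiT(patch_size=16, emb_features=768, num_layers=12, num_heads=12)
x = jnp.zeros((1, 256, 256, 3))      # [B, H, W, C]; H and W must be divisible by patch_size
temb = jnp.zeros((1,))               # assumed: one diffusion timestep per batch element
textcontext = jnp.zeros((1, 768))    # assumed: already-pooled text embedding of any context_dim
variables = model.init(jax.random.PRNGKey(0), x, temb, textcontext)
out = model.apply(variables, x, temb, textcontext)   # expected shape: (1, 256, 256, 3)
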
+ # --- Hierarchical MM-DiT components ---
+
+ class PatchMerging(nn.Module):
+     """
+     Merges a group of patches into a single patch with increased feature dimensions.
+     Used in the hierarchical structure to reduce spatial resolution and increase channels.
+     """
+     out_features: int
+     merge_size: int = 2  # Default 2x2 patch merging
+     dtype: Optional[Dtype] = None
+     precision: PrecisionLike = None
+     norm_epsilon: float = 1e-5  # Add norm for stability like in Swin Transformer
+
+     @nn.compact
+     def __call__(self, x, H_patches, W_patches):
+         # x shape: [B, H*W, C]
+         B, L, C = x.shape
+         assert L == H_patches * \
+             W_patches, f"Input length {L} doesn't match {H_patches}*{W_patches}"
+         assert H_patches % self.merge_size == 0 and W_patches % self.merge_size == 0, f"Patch dimensions ({H_patches}, {W_patches}) not divisible by merge size {self.merge_size}"
+
+         # Reshape to [B, H, W, C]
+         x = x.reshape(B, H_patches, W_patches, C)
+
+         # Merge patches - rearrange to group nearby patches
+         merged = einops.rearrange(
+             x,
+             'b (h p1) (w p2) c -> b h w (p1 p2 c)',
+             p1=self.merge_size, p2=self.merge_size
+         )
+
+         # Apply LayerNorm before projection (common practice)
+         norm = nn.LayerNorm(epsilon=self.norm_epsilon, dtype=self.dtype, name="norm")
+         merged = norm(merged)  # Apply norm on [B, H/p, W/p, p*p*C]
+
+         # Project to new dimension
+         merged = nn.Dense(
+             features=self.out_features,
+             dtype=self.dtype,
+             precision=self.precision,
+             name="projection"
+         )(merged)
+
+         # Flatten back to sequence
+         new_H = H_patches // self.merge_size
+         new_W = W_patches // self.merge_size
+         merged = merged.reshape(B, new_H * new_W, self.out_features)
+
+         return merged, new_H, new_W
+
+ class PatchExpanding(nn.Module):
+     """
+     Expands patches to increase spatial resolution.
+     Used in the hierarchical structure decoder path.
+     """
+     out_features: int
+     expand_size: int = 2  # Default 2x2 patch expansion
+     dtype: Optional[Dtype] = None
+     precision: PrecisionLike = None
+     norm_epsilon: float = 1e-5  # Add norm for stability
+
+     @nn.compact
+     def __call__(self, x, H_patches, W_patches):
+         # x shape: [B, H*W, C]
+         B, L, C = x.shape
+         assert L == H_patches * W_patches, f"Input length {L} doesn't match {H_patches}*{W_patches}"
+
+         # Project to expanded dimension first
+         expanded_features = self.expand_size * self.expand_size * self.out_features
+         x = nn.Dense(
+             features=expanded_features,
+             dtype=self.dtype,
+             precision=self.precision,
+             name="projection"
+         )(x)  # Shape [B, L, P*P*C_out]
+
+         # Apply LayerNorm after projection
+         norm = nn.LayerNorm(epsilon=self.norm_epsilon, dtype=self.dtype, name="norm")
+         x = norm(x)
+
+         # Reshape to spatial grid before rearranging
+         x = x.reshape(B, H_patches, W_patches, expanded_features)
+
+         # Rearrange to expand spatial dims
+         expanded = einops.rearrange(
+             x,
+             'b h w (p1 p2 c) -> b (h p1) (w p2) c',
+             p1=self.expand_size, p2=self.expand_size, c=self.out_features
+         )
+
+         # Flatten back to sequence
+         new_H = H_patches * self.expand_size
+         new_W = W_patches * self.expand_size
+         expanded = expanded.reshape(B, new_H * new_W, self.out_features)
+
+         return expanded, new_H, new_W
+
+
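PatchMerging and PatchExpanding are inverse shape transformations (2x2 by default): merging quarters the token count while widening the channels, and expanding restores the finer grid. A quick round-trip check on dummy tokens (sizes are arbitrary):

tokens = jnp.zeros((2, 16 * 16, 512))                    # [B, H_p * W_p, C] on a 16x16 patch grid
merger = PatchMerging(out_features=1024)
m_vars = merger.init(jax.random.PRNGKey(0), tokens, 16, 16)
merged, h, w = merger.apply(m_vars, tokens, 16, 16)      # (2, 64, 1024), h == w == 8
expander = PatchExpanding(out_features=512)
e_vars = expander.init(jax.random.PRNGKey(1), merged, h, w)
expanded, h2, w2 = expander.apply(e_vars, merged, h, w)  # (2, 256, 512), back on the 16x16 grid
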
+ # --- Hierarchical MM-DiT ---
+ class HierarchicalMMDiT(nn.Module):
+     """
+     A Hierarchical Multi-Modal Diffusion Transformer (MM-DiT) implementation
+     based on the PixArt-α architecture. Processes images at multiple resolutions
+     with skip connections between encoder and decoder paths.
+     Follows a U-Net like structure: Fine -> Coarse (Encoder) -> Coarse -> Fine (Decoder).
+     """
+     output_channels: int = 3
+     base_patch_size: int = 8  # Patch size at the *finest* resolution level (stage 0)
+     emb_features: Sequence[int] = (512, 768, 1024)  # Feature dimensions for stages 0, 1, 2 (fine to coarse)
+     num_layers: Sequence[int] = (4, 4, 14)  # Layers per stage (can be asymmetric encoder/decoder if needed)
+     num_heads: Sequence[int] = (8, 12, 16)  # Heads per stage (fine to coarse)
+     mlp_ratio: int = 4
+     dropout_rate: float = 0.0
+     dtype: Optional[Dtype] = None
+     precision: PrecisionLike = None
+     use_flash_attention: bool = False
+     force_fp32_for_softmax: bool = True
+     norm_epsilon: float = 1e-5
+     learn_sigma: bool = False
+     use_hilbert: bool = False
+     norm_groups: int = 0  # Not used in this structure, maybe remove later
+     activation: Callable = jax.nn.swish  # Not used directly here, used in MLP inside MMDiTBlock
+
+     def setup(self):
+         assert len(self.emb_features) == len(self.num_layers) == len(self.num_heads), \
+             "Feature dimensions, layers, and heads must have the same number of stages"
+
+         num_stages = len(self.emb_features)
+
+         # 1. Initial Patch Embedding (FINEST level - stage 0)
+         self.patch_embed = PatchEmbedding(
+             patch_size=self.base_patch_size,
+             embedding_dim=self.emb_features[0],  # Finest embedding dim
+             dtype=self.dtype,
+             precision=self.precision
+         )
+
+         # 2. Time/Text Embeddings (Projected for each stage)
+         # Base projection to largest dimension first for stability/capacity
+         base_t_emb_dim = self.emb_features[-1]
+         self.time_embed_base = nn.Sequential([
+             FourierEmbedding(features=base_t_emb_dim),
+             TimeProjection(features=base_t_emb_dim * self.mlp_ratio),
+             nn.Dense(features=base_t_emb_dim, dtype=self.dtype, precision=self.precision)
+         ], name="time_embed_base")
+         self.text_proj_base = nn.Dense(
+             features=base_t_emb_dim,
+             dtype=self.dtype,
+             precision=self.precision,
+             name="text_context_proj_base"
+         )
+         # Projections for each stage (0 to N-1)
+         self.t_emb_projs = [
+             nn.Dense(features=self.emb_features[i], dtype=self.dtype, precision=self.precision, name=f"t_emb_proj_stage{i}")
+             for i in range(num_stages)
+         ]
+         self.text_emb_projs = [
+             nn.Dense(features=self.emb_features[i], dtype=self.dtype, precision=self.precision, name=f"text_emb_proj_stage{i}")
+             for i in range(num_stages)
+         ]
+
+         # 3. Hilbert projection (if used, applied after initial patch embedding)
+         if self.use_hilbert:
+             self.hilbert_proj = nn.Dense(
+                 features=self.emb_features[0],  # Match finest embedding dim
+                 dtype=self.dtype,
+                 precision=self.precision,
+                 name="hilbert_projection"
+             )
+
+         # 4. RoPE embeddings for each stage (0 to N-1)
+         self.ropes = [
+             RotaryEmbedding(
+                 dim=self.emb_features[i] // self.num_heads[i],
+                 max_seq_len=4096,  # Adjust if needed based on max patch count per stage
+                 dtype=self.dtype,
+                 name=f"rope_stage_{i}"
+             )
+             for i in range(num_stages)
+         ]
+
+         # 5. --- Encoder Path (Fine to Coarse) ---
+         encoder_blocks = []
+         patch_mergers = []
+         for stage in range(num_stages):
+             # Blocks for this stage
+             stage_blocks = [
+                 MMDiTBlock(
+                     features=self.emb_features[stage],
+                     num_heads=self.num_heads[stage],
+                     mlp_ratio=self.mlp_ratio,
+                     dropout_rate=self.dropout_rate,
+                     dtype=self.dtype,
+                     precision=self.precision,
+                     use_flash_attention=self.use_flash_attention,
+                     force_fp32_for_softmax=self.force_fp32_for_softmax,
+                     norm_epsilon=self.norm_epsilon,
+                     rope_emb=self.ropes[stage],
+                     name=f"encoder_block_stage{stage}_{i}"
+                 )
+                 # Assuming symmetric layers for now, adjust if needed (e.g., self.num_encoder_layers[stage])
+                 for i in range(self.num_layers[stage])
+             ]
+             encoder_blocks.append(stage_blocks)
+
+             # Patch Merging layer (except for the last/coarsest stage)
+             if stage < num_stages - 1:
+                 patch_mergers.append(
+                     PatchMerging(
+                         out_features=self.emb_features[stage + 1],  # Target next stage dim
+                         dtype=self.dtype,
+                         precision=self.precision,
+                         norm_epsilon=self.norm_epsilon,
+                         name=f"patch_merger_{stage}"
+                     )
+                 )
+         self.encoder_blocks = encoder_blocks
+         self.patch_mergers = patch_mergers
+
+         # 6. --- Decoder Path (Coarse to Fine) ---
+         decoder_blocks = []
+         patch_expanders = []
+         fusion_layers = []
+         # Iterate from second coarsest stage (N-2) down to finest (0)
+         for stage in range(num_stages - 2, -1, -1):
+             # Patch Expanding layer (Expands from stage+1 to stage)
+             patch_expanders.append(
+                 PatchExpanding(
+                     out_features=self.emb_features[stage],  # Target current stage dim
+                     dtype=self.dtype,
+                     precision=self.precision,
+                     norm_epsilon=self.norm_epsilon,
+                     name=f"patch_expander_{stage}"  # Naming indicates target stage
+                 )
+             )
+             # Fusion layer (Combines skip[stage] and expanded[stage+1]->[stage])
+             fusion_layers.append(
+                 nn.Sequential([  # Use Sequential for Norm + Dense
+                     nn.LayerNorm(epsilon=self.norm_epsilon, dtype=self.dtype, name=f"fusion_norm_{stage}"),
+                     nn.Dense(
+                         features=self.emb_features[stage],  # Output current stage dim
+                         dtype=self.dtype,
+                         precision=self.precision,
+                         name=f"fusion_dense_{stage}"
+                     )
+                 ])
+             )
+
+             # Blocks for this stage (stage N-2 down to 0)
+             # Assuming symmetric layers for now
+             stage_blocks = [
+                 MMDiTBlock(
+                     features=self.emb_features[stage],
+                     num_heads=self.num_heads[stage],
+                     mlp_ratio=self.mlp_ratio,
+                     dropout_rate=self.dropout_rate,
+                     dtype=self.dtype,
+                     precision=self.precision,
+                     use_flash_attention=self.use_flash_attention,
+                     force_fp32_for_softmax=self.force_fp32_for_softmax,
+                     norm_epsilon=self.norm_epsilon,
+                     rope_emb=self.ropes[stage],
+                     name=f"decoder_block_stage{stage}_{i}"
+                 )
+                 for i in range(self.num_layers[stage])
+             ]
+             # Append blocks in order: stage N-2, N-3, ..., 0
+             decoder_blocks.append(stage_blocks)
+
+         self.patch_expanders = patch_expanders
+         self.fusion_layers = fusion_layers
+         self.decoder_blocks = decoder_blocks
+
+         # Note: The lists expanders, fusion_layers, decoder_blocks are now ordered
+         # corresponding to stages N-2, N-3, ..., 0.
+
+         # 7. Final Layer
+         self.final_norm = nn.LayerNorm(
+             epsilon=self.norm_epsilon, dtype=self.dtype, name="final_norm")
+
+         # Output projection to pixels (at finest resolution)
+         output_dim = self.base_patch_size * self.base_patch_size * self.output_channels
+         if self.learn_sigma:
+             output_dim *= 2  # Predict both mean and variance
+
+         self.final_proj = nn.Dense(
+             features=output_dim,
+             dtype=self.dtype,
+             precision=self.precision,
+             kernel_init=nn.initializers.zeros,  # Zero init
+             name="final_proj"
+         )
+
+     def __call__(self, x, temb, textcontext):
+         B, H, W, C = x.shape
+         num_stages = len(self.emb_features)
+         finest_patch_size = self.base_patch_size
+
+         # Assertions
+         assert H % (finest_patch_size * (2**(num_stages - 1))) == 0 and \
+             W % (finest_patch_size * (2**(num_stages - 1))) == 0, \
+             f"Image dimensions ({H},{W}) must be divisible by effective coarsest patch size {finest_patch_size * (2**(num_stages - 1))}"
+         assert textcontext is not None, "textcontext must be provided"
+
+         # 1. Initial Patch Embedding (Finest Level - stage 0)
+         H_patches = H // finest_patch_size
+         W_patches = W // finest_patch_size
+         total_patches = H_patches * W_patches  # Calculate total patches
+         hilbert_inv_idx = None
+         if self.use_hilbert:
+             # Calculate Hilbert indices and inverse permutation for the *finest* grid
+             fine_idx = hilbert_indices(H_patches, W_patches)
+             # Pass the total number of patches as total_size
+             hilbert_inv_idx = inverse_permutation(fine_idx, total_size=total_patches)  # Store for unpatchify
+
+             # Apply Hilbert patchify at the finest level
+             patches_raw, _ = hilbert_patchify(x, finest_patch_size)  # We already have inv_idx
+             x_seq = self.hilbert_proj(patches_raw)  # Shape [B, S_fine, emb[0]]
+         else:
+             x_seq = self.patch_embed(x)  # Shape [B, S_fine, emb[0]]
+
+         # 2. Prepare Conditioning Signals for each stage
+         t_emb_base = self.time_embed_base(temb)
+         text_emb_base = self.text_proj_base(textcontext)
+         t_embs = [proj(t_emb_base) for proj in self.t_emb_projs]  # List for stages 0 to N-1
+         text_embs = [proj(text_emb_base) for proj in self.text_emb_projs]  # List for stages 0 to N-1
+
+         # --- Encoder Path (Fine to Coarse: stages 0 to N-1) ---
+         skip_features = {}
+         current_H_patches, current_W_patches = H_patches, W_patches
+         for stage in range(num_stages):
+             # Apply RoPE for current stage
+             seq_len = x_seq.shape[1]
+             freqs_cos, freqs_sin = self.ropes[stage](seq_len)
+
+             # Apply blocks for this stage
+             for block in self.encoder_blocks[stage]:
+                 x_seq = block(x_seq, t_embs[stage], text_embs[stage], freqs_cis=(freqs_cos, freqs_sin))
+
+             # Store skip features (before merging)
+             skip_features[stage] = x_seq
+
+             # Apply Patch Merging (if not the last/coarsest stage)
+             if stage < num_stages - 1:
+                 x_seq, current_H_patches, current_W_patches = self.patch_mergers[stage](
+                     x_seq, current_H_patches, current_W_patches
+                 )
+
+         # --- Bottleneck ---
+         # x_seq now holds the output of the coarsest stage (stage N-1)
+
+         # --- Decoder Path (Coarse to Fine: stages N-2 down to 0) ---
+         # Decoder lists (expanders, fusion, blocks) are ordered for stages N-2, ..., 0
+         for i, stage in enumerate(range(num_stages - 2, -1, -1)):  # stage = N-2, N-3, ..., 0; i = 0, 1, ..., N-2
+             # Apply Patch Expanding (Expand from stage+1 feature map to stage feature map)
+             x_seq, current_H_patches, current_W_patches = self.patch_expanders[i](
+                 x_seq, current_H_patches, current_W_patches
+             )
+
+             # Fusion with skip connection from corresponding encoder stage
+             skip = skip_features[stage]
+             x_seq = jnp.concatenate([x_seq, skip], axis=-1)  # Concatenate along feature dim
+             x_seq = self.fusion_layers[i](x_seq)  # Apply fusion (Norm + Dense)
+
+             # Apply RoPE for current stage
+             seq_len = x_seq.shape[1]
+             freqs_cos, freqs_sin = self.ropes[stage](seq_len)
+
+             # Apply blocks for this stage
+             for block in self.decoder_blocks[i]:  # Use index i for the decoder block list
+                 x_seq = block(x_seq, t_embs[stage], text_embs[stage], freqs_cis=(freqs_cos, freqs_sin))
+
+         # --- Final Layer ---
+         # x_seq should now be at the finest resolution (stage 0 features)
+         x_seq = self.final_norm(x_seq)
+         x_seq = self.final_proj(x_seq)  # Project to patch pixel values
+
+         # --- Unpatchify ---
+         if self.use_hilbert:
+             # Use the inverse Hilbert index calculated for the *finest* grid
+             assert hilbert_inv_idx is not None, "Hilbert inverse index should exist if use_hilbert is True"
+             if self.learn_sigma:
+                 x_mean, x_logvar = jnp.split(x_seq, 2, axis=-1)
+                 out = hilbert_unpatchify(x_mean, hilbert_inv_idx, finest_patch_size, H, W, self.output_channels)
+                 # Optionally return logvar: logvar_image = hilbert_unpatchify(x_logvar, hilbert_inv_idx, finest_patch_size, H, W, self.output_channels)
+             else:
+                 out = hilbert_unpatchify(x_seq, hilbert_inv_idx, finest_patch_size, H, W, self.output_channels)
+         else:
+             # Standard unpatchify
+             if self.learn_sigma:
+                 x_mean, x_logvar = jnp.split(x_seq, 2, axis=-1)
+                 out = unpatchify(x_mean, channels=self.output_channels)
+                 # Optionally return logvar: logvar = unpatchify(x_logvar, channels=self.output_channels)
+             else:
+                 out = unpatchify(x_seq, channels=self.output_channels)
+
+         return out
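
With the defaults above (base_patch_size=8 and three stages), the coarsest effective patch size is 8 * 2**(3 - 1) = 32, so H and W must be multiples of 32. A hedged usage sketch along the same lines as the SimpleMMDiT example (the timestep and text-context shapes are assumptions about the local modules; the per-stage layer counts are reduced here only to keep the example small):

model = HierarchicalMMDiT(
    base_patch_size=8,
    emb_features=(512, 768, 1024),
    num_layers=(2, 2, 4),
    num_heads=(8, 12, 16),
)
x = jnp.zeros((1, 256, 256, 3))      # 256 % 32 == 0, so the coarsest-stage assertion passes
temb = jnp.zeros((1,))               # assumed timestep shape, as for SimpleMMDiT
textcontext = jnp.zeros((1, 1024))   # projected by text_context_proj_base
variables = model.init(jax.random.PRNGKey(0), x, temb, textcontext)
out = model.apply(variables, x, temb, textcontext)   # expected shape: (1, 256, 256, 3)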