flaxdiff 0.2.7__py3-none-any.whl → 0.2.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,476 @@
+ from .simple_vit import PatchEmbedding, unpatchify
+ import jax
+ import jax.numpy as jnp
+ from flax import linen as nn
+ from typing import Callable, Any, Optional, Tuple, Sequence, Union
+ import einops
+ from functools import partial
+
+ # Re-use existing components if they are suitable
+ from .common import kernel_init, FourierEmbedding, TimeProjection
+ # Using NormalAttention for RoPE integration
+ from .attention import NormalAttention
+ from flax.typing import Dtype, PrecisionLike
+
+ # Use our improved Hilbert implementation
+ from .hilbert import hilbert_indices, inverse_permutation, hilbert_patchify, hilbert_unpatchify
+
+ # --- Rotary Positional Embedding (RoPE) ---
+ # Adapted from https://github.com/google-deepmind/ring_attention/blob/main/ring_attention/layers/rotary.py
+
+
+ def _rotate_half(x: jax.Array) -> jax.Array:
+     """Rotates half the hidden dims of the input."""
+     x1 = x[..., : x.shape[-1] // 2]
+     x2 = x[..., x.shape[-1] // 2:]
+     return jnp.concatenate((-x2, x1), axis=-1)
+
+
+ def apply_rotary_embedding(
+     x: jax.Array, freqs_cos: jax.Array, freqs_sin: jax.Array
+ ) -> jax.Array:
+     """Applies rotary embedding to the input tensor using the rotate_half method."""
+     # x shape: [..., Sequence, Dimension] e.g. [B, H, S, D] or [B, S, D]
+     # freqs_cos/sin shape: [Sequence, Dimension / 2]
+
+     # Expand dims for broadcasting: [1, 1, S, D/2] or [1, S, D/2]
+     if x.ndim == 4:  # [B, H, S, D]
+         cos_freqs = jnp.expand_dims(freqs_cos, axis=(0, 1))
+         sin_freqs = jnp.expand_dims(freqs_sin, axis=(0, 1))
+     elif x.ndim == 3:  # [B, S, D]
+         cos_freqs = jnp.expand_dims(freqs_cos, axis=0)
+         sin_freqs = jnp.expand_dims(freqs_sin, axis=0)
+     else:
+         raise ValueError(f"Expected a 3D or 4D input for RoPE, got ndim={x.ndim}")
+
+     # Duplicate cos and sin for the full dimension D
+     # Shape becomes [..., S, D]
+     cos_freqs = jnp.concatenate([cos_freqs, cos_freqs], axis=-1)
+     sin_freqs = jnp.concatenate([sin_freqs, sin_freqs], axis=-1)
+
+     # Apply rotation: x * cos + rotate_half(x) * sin
+     x_rotated = x * cos_freqs + _rotate_half(x) * sin_freqs
+     return x_rotated.astype(x.dtype)
+
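For readers new to the rotate-half formulation, here is a small standalone check (not part of the package, and assuming `apply_rotary_embedding` above is in scope): with a head dimension of 2, applying RoPE at a single position is exactly an ordinary 2-D rotation by that position's angle.

```python
import jax.numpy as jnp

theta = 0.3
freqs_cos = jnp.cos(jnp.array([[theta]]))   # [S=1, D/2=1]
freqs_sin = jnp.sin(jnp.array([[theta]]))

x = jnp.array([[[1.0, 2.0]]])               # [B=1, S=1, D=2]
rotated = apply_rotary_embedding(x, freqs_cos, freqs_sin)

# Same result as multiplying by the 2-D rotation matrix [[cos, -sin], [sin, cos]]
expected = jnp.array([[[1.0 * jnp.cos(theta) - 2.0 * jnp.sin(theta),
                        2.0 * jnp.cos(theta) + 1.0 * jnp.sin(theta)]]])
assert jnp.allclose(rotated, expected)
```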
+ class RotaryEmbedding(nn.Module):
+     dim: int  # Dimension of the head
+     max_seq_len: int = 2048
+     base: int = 10000
+     dtype: Dtype = jnp.float32
+
+     def setup(self):
+         inv_freq = 1.0 / (
+             self.base ** (jnp.arange(0, self.dim, 2, dtype=jnp.float32) / self.dim)
+         )
+         t = jnp.arange(self.max_seq_len, dtype=jnp.float32)
+         freqs = jnp.outer(t, inv_freq)  # Shape: [max_seq_len, dim / 2]
+
+         # Store cosine and sine separately instead of as complex numbers
+         self.freqs_cos = jnp.cos(freqs)  # Shape: [max_seq_len, dim / 2]
+         self.freqs_sin = jnp.sin(freqs)  # Shape: [max_seq_len, dim / 2]
+
+     def __call__(self, seq_len: int):
+         if seq_len > self.max_seq_len:
+             raise ValueError(
+                 f"Sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}")
+         # Return separate cos and sin components
+         return self.freqs_cos[:seq_len, :], self.freqs_sin[:seq_len, :]
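A quick usage sketch (not part of the package): since `RotaryEmbedding` creates no learnable parameters, an empty variable dictionary is enough to call it through `apply` and pull out the cos/sin tables for a given sequence length.

```python
import jax.numpy as jnp

rope = RotaryEmbedding(dim=64, max_seq_len=1024)   # dim = per-head dimension
freqs_cos, freqs_sin = rope.apply({}, 256)         # tables for a 256-token sequence
print(freqs_cos.shape, freqs_sin.shape)            # (256, 32) (256, 32)
```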
+ # --- Attention with RoPE ---
+
+
+ class RoPEAttention(NormalAttention):
+     rope_emb: RotaryEmbedding = None  # Instance of RotaryEmbedding
+
+     @nn.compact
+     def __call__(self, x, context=None, freqs_cis=None):
+         # x has shape [B, H, W, C] or [B, S, C]
+         orig_x_shape = x.shape
+         is_4d = len(orig_x_shape) == 4
+         if is_4d:
+             B, H, W, C = x.shape
+             seq_len = H * W
+             x = x.reshape((B, seq_len, C))
+         else:
+             B, seq_len, C = x.shape
+
+         context = x if context is None else context
+         if len(context.shape) == 4:
+             _B, _H, _W, _C = context.shape
+             context_seq_len = _H * _W
+             context = context.reshape((_B, context_seq_len, _C))
+         # else: context is already [B, S_ctx, C]
+
+         query = self.query(x)        # [B, S, H, D]
+         key = self.key(context)      # [B, S_ctx, H, D]
+         value = self.value(context)  # [B, S_ctx, H, D]
+
+         # Apply RoPE to query and key
+         if freqs_cis is None:
+             # Generate frequencies from the rope_emb instance
+             seq_len_q = query.shape[1]  # Use query's sequence length
+             freqs_cos, freqs_sin = self.rope_emb(seq_len_q)
+         else:
+             # freqs_cis is passed in as a (cos, sin) tuple
+             freqs_cos, freqs_sin = freqs_cis
+
+         # Permute to [B, H, S, D] for RoPE application
+         query = einops.rearrange(query, 'b s h d -> b h s d')
+         key = einops.rearrange(key, 'b s h d -> b h s d')
+
+         # Self-attention (or an equal context length) is assumed, so the same
+         # frequencies are applied to both query and key
+         query = apply_rotary_embedding(query, freqs_cos, freqs_sin)
+         key = apply_rotary_embedding(key, freqs_cos, freqs_sin)
+
+         # Permute back to [B, S, H, D] for dot_product_attention
+         query = einops.rearrange(query, 'b h s d -> b s h d')
+         key = einops.rearrange(key, 'b h s d -> b s h d')
+
+         hidden_states = nn.dot_product_attention(
+             query, key, value, dtype=self.dtype, broadcast_dropout=False,
+             dropout_rng=None, precision=self.precision,
+             force_fp32_for_softmax=self.force_fp32_for_softmax,
+             deterministic=True
+         )  # Output shape [B, S, H, D]
+
+         # proj_attn from NormalAttention expects [B, S, H, D]
+         proj = self.proj_attn(hidden_states)  # Output shape [B, S, C]
+
+         if is_4d:
+             proj = proj.reshape(orig_x_shape)  # Reshape back if input was 4D
+
+         return proj
+
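A usage sketch for the attention layer (not part of the package). The constructor arguments mirror how `DiTBlock` below instantiates `RoPEAttention`; the inherited `NormalAttention` fields are assumed to accept them as shown.

```python
import jax
import jax.numpy as jnp

rope = RotaryEmbedding(dim=256 // 8, max_seq_len=1024)
attn = RoPEAttention(
    query_dim=256,
    heads=8,
    dim_head=256 // 8,
    dtype=jnp.float32,
    use_bias=True,
    force_fp32_for_softmax=True,
    rope_emb=rope,
)
x = jnp.ones((2, 8, 8, 256))                    # [B, H, W, C] feature map
params = attn.init(jax.random.PRNGKey(0), x)    # frequencies come from rope_emb
y = attn.apply(params, x)                       # output keeps the input shape
```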
+ # --- adaLN-Zero ---
+
+
+ class AdaLNZero(nn.Module):
+     features: int
+     dtype: Optional[Dtype] = None
+     precision: PrecisionLike = None
+     norm_epsilon: float = 1e-5  # Standard LayerNorm epsilon
+
+     @nn.compact
+     def __call__(self, x, conditioning):
+         # Project the conditioning signal to get scale and shift parameters.
+         # Conditioning shape: [B, D_cond] -> [B, 1, 6 * features] so it
+         # broadcasts against x of shape [B, S, F].
+
+         # Ensure conditioning has a sequence dim if x does:
+         # x=[B,S,F], cond=[B,D_cond] -> cond=[B,1,D_cond]
+         if x.ndim == 3 and conditioning.ndim == 2:
+             conditioning = jnp.expand_dims(conditioning, axis=1)
+
+         # Project conditioning to get 6 params per feature
+         # (scale_mlp, shift_mlp, gate_mlp, scale_attn, shift_attn, gate_attn)
+         ada_params = nn.Dense(
+             features=6 * self.features,
+             dtype=self.dtype,
+             precision=self.precision,
+             # Initialize projection to zero (the "Zero" in adaLN-Zero)
+             kernel_init=nn.initializers.zeros,
+             name="ada_proj"
+         )(conditioning)
+
+         # Split into scale, shift, gate for MLP and Attention
+         scale_mlp, shift_mlp, gate_mlp, scale_attn, shift_attn, gate_attn = jnp.split(
+             ada_params, 6, axis=-1)
+
+         scale_mlp = jnp.clip(scale_mlp, -10.0, 10.0)
+         shift_mlp = jnp.clip(shift_mlp, -10.0, 10.0)
+
+         # Apply Layer Normalization (no learned scale/bias; adaLN provides them)
+         norm = nn.LayerNorm(epsilon=self.norm_epsilon,
+                             use_scale=False, use_bias=False, dtype=self.dtype)
+         # norm = nn.RMSNorm(epsilon=self.norm_epsilon, dtype=self.dtype)  # Alternative: RMSNorm
+
+         norm_x = norm(x)
+
+         # Modulate for the Attention path
+         x_attn = norm_x * (1 + scale_attn) + shift_attn
+
+         # Modulate for the MLP path
+         x_mlp = norm_x * (1 + scale_mlp) + shift_mlp
+
+         # Return modulated outputs and gates
+         return x_attn, gate_attn, x_mlp, gate_mlp
+
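The "Zero" in adaLN-Zero is the zero-initialized `ada_proj`: at the start of training every scale, shift, and gate is 0, so the modulation is a no-op and the gated residual branches vanish. A minimal standalone illustration of that identity-at-init behaviour:

```python
import jax.numpy as jnp

h = jnp.ones((2, 16, 8))                      # stands in for LayerNorm(x), shape [B, S, F]
scale = shift = gate = jnp.zeros((2, 1, 8))   # ada_proj output at initialization
branch_in = h * (1 + scale) + shift           # modulation is a no-op
branch_out = branch_in * 0.5                  # stand-in for the attention/MLP sub-layer
x = jnp.ones((2, 16, 8))
assert jnp.allclose(x + gate * branch_out, x)  # the block starts as the identity
```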
+ class AdaLNParams(nn.Module):
+     features: int
+     dtype: Optional[Dtype] = None
+     precision: PrecisionLike = None
+
+     @nn.compact
+     def __call__(self, conditioning):
+         # Ensure conditioning is broadcastable (e.g., [B, 1, D_cond])
+         if conditioning.ndim == 2:
+             conditioning = jnp.expand_dims(conditioning, axis=1)
+
+         # Project conditioning to get 6 params per feature
+         ada_params = nn.Dense(
+             features=6 * self.features,
+             dtype=self.dtype,
+             precision=self.precision,
+             kernel_init=nn.initializers.zeros,
+             name="ada_proj"
+         )(conditioning)
+         # Return the packed parameters; the caller splits them into 6 groups.
+         # Shape: [B, 1, 6*F]
+         return ada_params
+
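The packing order of the six parameter groups is a convention shared with `DiTBlock` below. A short standalone sketch of how the packed output is split (all zeros right after initialization because of the zero kernel init):

```python
import jax
import jax.numpy as jnp

cond = jnp.ones((2, 128))                      # [B, D_cond] conditioning vector
mod = AdaLNParams(features=256)
params = mod.init(jax.random.PRNGKey(0), cond)
ada = mod.apply(params, cond)                  # [2, 1, 6 * 256], zeros at init
scale_mlp, shift_mlp, gate_mlp, scale_attn, shift_attn, gate_attn = jnp.split(ada, 6, axis=-1)
```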
+ # --- DiT Block ---
+ class DiTBlock(nn.Module):
+     features: int
+     num_heads: int
+     rope_emb: RotaryEmbedding
+     mlp_ratio: int = 4
+     dropout_rate: float = 0.0
+     dtype: Optional[Dtype] = None
+     precision: PrecisionLike = None
+     use_flash_attention: bool = False  # Placeholder; RoPEAttention uses standard attention
+     force_fp32_for_softmax: bool = True
+     norm_epsilon: float = 1e-5
+     use_gating: bool = True  # Flag to easily disable adaLN-Zero gating
+
+     def setup(self):
+         hidden_features = int(self.features * self.mlp_ratio)
+         # Projection producing the modulation parameters (scales, shifts, gates)
+         self.ada_params_module = AdaLNParams(
+             self.features, dtype=self.dtype, precision=self.precision)
+
+         # Layer Norms - one before Attention, one before the MLP
+         self.norm1 = nn.LayerNorm(epsilon=self.norm_epsilon, use_scale=False,
+                                   use_bias=False, dtype=self.dtype, name="norm1")
+         self.norm2 = nn.LayerNorm(epsilon=self.norm_epsilon, use_scale=False,
+                                   use_bias=False, dtype=self.dtype, name="norm2")
+
+         self.attention = RoPEAttention(
+             query_dim=self.features,
+             heads=self.num_heads,
+             dim_head=self.features // self.num_heads,
+             dtype=self.dtype,
+             precision=self.precision,
+             use_bias=True,
+             force_fp32_for_softmax=self.force_fp32_for_softmax,
+             rope_emb=self.rope_emb
+         )
+
+         self.mlp = nn.Sequential([
+             nn.Dense(features=hidden_features, dtype=self.dtype, precision=self.precision),
+             nn.gelu,  # GELU here; SimpleDiT's swish `activation` field is not wired into the blocks
+             nn.Dense(features=self.features, dtype=self.dtype, precision=self.precision)
+         ])
+
+     @nn.compact
+     def __call__(self, x, conditioning, freqs_cis):
+         # Get scale/shift/gate parameters
+         # Shape: [B, 1, 6*F] -> split into 6 tensors of [B, 1, F]
+         scale_mlp, shift_mlp, gate_mlp, scale_attn, shift_attn, gate_attn = jnp.split(
+             self.ada_params_module(conditioning), 6, axis=-1
+         )
+
+         # --- Attention Path ---
+         residual = x
+         norm_x_attn = self.norm1(x)
+         # Modulate after norm
+         x_attn_modulated = norm_x_attn * (1 + scale_attn) + shift_attn
+         attn_output = self.attention(x_attn_modulated, context=None, freqs_cis=freqs_cis)
+
+         if self.use_gating:
+             x = residual + gate_attn * attn_output
+         else:
+             x = residual + attn_output  # Original DiT style without gate
+
+         # --- MLP Path ---
+         residual = x
+         norm_x_mlp = self.norm2(x)  # Apply second LayerNorm
+         # Modulate after norm
+         x_mlp_modulated = norm_x_mlp * (1 + scale_mlp) + shift_mlp
+         mlp_output = self.mlp(x_mlp_modulated)
+
+         if self.use_gating:
+             x = residual + gate_mlp * mlp_output
+         else:
+             x = residual + mlp_output  # Original DiT style without gate
+
+         return x
+
+
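In equation form (matching the code above, with LN denoting the parameter-free LayerNorm and s, b, g the adaLN scale, shift, and gate):

```latex
x \leftarrow x + g_{\mathrm{attn}} \odot \mathrm{Attn}\!\big(\mathrm{LN}_1(x)\,(1 + s_{\mathrm{attn}}) + b_{\mathrm{attn}}\big)
\qquad
x \leftarrow x + g_{\mathrm{mlp}} \odot \mathrm{MLP}\!\big(\mathrm{LN}_2(x)\,(1 + s_{\mathrm{mlp}}) + b_{\mathrm{mlp}}\big)
```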
+ # --- Patch Embedding ---
+ # PatchEmbedding is re-used from simple_vit.py
+
+ # --- DiT ---
+
+ class SimpleDiT(nn.Module):
+     output_channels: int = 3
+     patch_size: int = 16
+     emb_features: int = 768
+     num_layers: int = 12
+     num_heads: int = 12
+     mlp_ratio: int = 4
+     dropout_rate: float = 0.0  # Typically 0 for diffusion
+     dtype: Optional[Dtype] = None
+     precision: PrecisionLike = None
+     # Passed down, but RoPEAttention uses NormalAttention
+     use_flash_attention: bool = False
+     force_fp32_for_softmax: bool = True
+     norm_epsilon: float = 1e-5
+     learn_sigma: bool = False  # Option to predict sigma as in the DiT paper
+     use_hilbert: bool = False  # Toggle Hilbert patch reordering
+     norm_groups: int = 0
+     activation: Callable = jax.nn.swish
+
+     def setup(self):
+         self.patch_embed = PatchEmbedding(
+             patch_size=self.patch_size,
+             embedding_dim=self.emb_features,
+             dtype=self.dtype,
+             precision=self.precision
+         )
+
+         # Projection layer for Hilbert patches
+         if self.use_hilbert:
+             self.hilbert_proj = nn.Dense(
+                 features=self.emb_features,
+                 dtype=self.dtype,
+                 precision=self.precision,
+                 name="hilbert_projection"
+             )
+
+         # Time embedding projection
+         self.time_embed = nn.Sequential([
+             FourierEmbedding(features=self.emb_features),
+             TimeProjection(features=self.emb_features * self.mlp_ratio),  # Project to MLP dim
+             nn.Dense(features=self.emb_features, dtype=self.dtype,
+                      precision=self.precision)  # Final projection
+         ])
+
+         # Text context projection (if used)
+         # textcontext is assumed to be pre-computed embeddings; project them to emb_features.
+         # This may need adjustment depending on how the text context is provided.
+         self.text_proj = nn.Dense(features=self.emb_features, dtype=self.dtype,
+                                   precision=self.precision, name="text_context_proj")
+
+         # Rotary positional embedding. As in DiT, positional information is applied
+         # to the patch tokens only, so max_seq_len must cover the largest expected
+         # number of patches, e.g. a 512x512 image with patch size 16 gives
+         # (512/16)^2 = 1024 tokens. Include any extra tokens (class token,
+         # concatenated text tokens) if they are ever prepended before the blocks.
+         self.rope = RotaryEmbedding(
+             dim=self.emb_features // self.num_heads,  # Dimension per head
+             max_seq_len=4096,
+             dtype=self.dtype)
+
+         # Transformer Blocks
+         self.blocks = [
+             DiTBlock(
+                 features=self.emb_features,
+                 num_heads=self.num_heads,
+                 mlp_ratio=self.mlp_ratio,
+                 dropout_rate=self.dropout_rate,
+                 dtype=self.dtype,
+                 precision=self.precision,
+                 use_flash_attention=self.use_flash_attention,
+                 force_fp32_for_softmax=self.force_fp32_for_softmax,
+                 norm_epsilon=self.norm_epsilon,
+                 rope_emb=self.rope,  # Pass RoPE instance
+                 name=f"dit_block_{i}"
+             ) for i in range(self.num_layers)
+         ]
+
+         # Final Layer (Normalization + Linear Projection)
+         self.final_norm = nn.LayerNorm(
+             epsilon=self.norm_epsilon, dtype=self.dtype, name="final_norm")
+         # self.final_norm = nn.RMSNorm(epsilon=self.norm_epsilon, dtype=self.dtype, name="final_norm")
+
+         # Predict patch pixels + potentially sigma
+         output_dim = self.patch_size * self.patch_size * self.output_channels
+         if self.learn_sigma:
+             output_dim *= 2  # Predict both mean and variance (or log_variance)
+
+         self.final_proj = nn.Dense(
+             features=output_dim,
+             dtype=self.dtype,
+             precision=self.precision,
+             kernel_init=nn.initializers.zeros,  # Initialize final layer to zero
+             name="final_proj"
+         )
+
+     @nn.compact
+     def __call__(self, x, temb, textcontext=None):
+         B, H, W, C = x.shape
+         assert H % self.patch_size == 0 and W % self.patch_size == 0, "Image dimensions must be divisible by patch size"
+
+         # Compute dimensions in terms of patches
+         H_P = H // self.patch_size
+         W_P = W // self.patch_size
+
+         # 1. Patch Embedding
+         if self.use_hilbert:
+             # Use hilbert_patchify which handles both patchification and reordering
+             patches_raw, hilbert_inv_idx = hilbert_patchify(x, self.patch_size)  # Shape [B, S, P*P*C]
+             # Apply projection
+             patches = self.hilbert_proj(patches_raw)  # Shape [B, S, emb_features]
+         else:
+             patches = self.patch_embed(x)  # Shape: [B, num_patches, emb_features]
+             hilbert_inv_idx = None
+
+         num_patches = patches.shape[1]
+         x_seq = patches
+
+         # 2. Prepare Conditioning Signal (Time + Text Context)
+         t_emb = self.time_embed(temb)  # Shape: [B, emb_features]
+
+         cond_emb = t_emb
+         if textcontext is not None:
+             # Shape: [B, num_text_tokens, emb_features]
+             text_emb = self.text_proj(textcontext)
+             # Pool the text embedding (mean pooling for simplicity; a CLS token would also work)
+             text_emb_pooled = jnp.mean(text_emb, axis=1)  # Shape: [B, emb_features]
+             cond_emb = cond_emb + text_emb_pooled  # Combine time and text embeddings
+
+         # 3. Rotary frequencies for the patch sequence
+         # Shape: [num_patches, D_head/2] each
+         freqs_cos, freqs_sin = self.rope(seq_len=num_patches)
+
+         # 4. Apply Transformer Blocks with adaLN-Zero conditioning
+         for block in self.blocks:
+             x_seq = block(x_seq, conditioning=cond_emb, freqs_cis=(freqs_cos, freqs_sin))
+
+         # 5. Final Layer
+         x_out = self.final_norm(x_seq)
+         # Shape: [B, num_patches, patch_pixels (*2 if learn_sigma)]
+         x_out = self.final_proj(x_out)
+
+         # 6. Unpatchify
+         if self.use_hilbert:
+             # Hilbert mode needs the specialized unpatchify to undo the reordering
+             if self.learn_sigma:
+                 # Split into mean and variance predictions
+                 x_mean, x_logvar = jnp.split(x_out, 2, axis=-1)
+                 x_image = hilbert_unpatchify(x_mean, hilbert_inv_idx, self.patch_size, H, W, self.output_channels)
+                 # If needed, also unpack the logvar:
+                 # logvar_image = hilbert_unpatchify(x_logvar, hilbert_inv_idx, self.patch_size, H, W, self.output_channels)
+                 # return x_image, logvar_image
+                 return x_image
+             else:
+                 x_image = hilbert_unpatchify(x_out, hilbert_inv_idx, self.patch_size, H, W, self.output_channels)
+                 return x_image
+         else:
+             # Standard patch ordering - use the existing unpatchify function
+             if self.learn_sigma:
+                 # Split into mean and variance predictions
+                 x_mean, x_logvar = jnp.split(x_out, 2, axis=-1)
+                 x = unpatchify(x_mean, channels=self.output_channels)
+                 # Only the mean prediction is returned, like standard diffusion models;
+                 # the logvar could be returned as well if the loss needs it:
+                 # logvar = unpatchify(x_logvar, channels=self.output_channels)
+                 # return x, logvar
+                 return x
+             else:
+                 # Shape: [B, H, W, C]
+                 x = unpatchify(x_out, channels=self.output_channels)
+                 return x
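Finally, an end-to-end usage sketch (not part of the package; it assumes the classes above are in scope, that `FourierEmbedding` accepts a per-sample scalar timestep of shape `[B]`, and that `textcontext` is a batch of pre-computed text embeddings):

```python
import jax
import jax.numpy as jnp

model = SimpleDiT(
    output_channels=3,
    patch_size=16,
    emb_features=768,
    num_layers=2,          # shallow configuration, just for the demo
    num_heads=12,
)
x = jnp.ones((1, 256, 256, 3))        # [B, H, W, C] noisy input image
temb = jnp.ones((1,))                 # diffusion timestep per sample
textcontext = jnp.ones((1, 77, 768))  # [B, tokens, features] text embeddings

params = model.init(jax.random.PRNGKey(0), x, temb, textcontext)
out = model.apply(params, x, temb, textcontext)
print(out.shape)                      # (1, 256, 256, 3)
```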