flaxdiff 0.2.7__py3-none-any.whl → 0.2.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,380 +0,0 @@
- # flaxdiff/models/better_uvit.py
- import jax
- import jax.numpy as jnp
- from flax import linen as nn
- from typing import Callable, Any, Optional, Tuple, Sequence, Union
- import einops
- from functools import partial
-
- # Re-use existing components if they are suitable
- from .common import kernel_init, FourierEmbedding, TimeProjection, hilbert_indices, inverse_permutation
- from .attention import NormalAttention  # Using NormalAttention for RoPE integration
- from flax.typing import Dtype, PrecisionLike
-
- # --- Rotary Positional Embedding (RoPE) ---
- # Adapted from https://github.com/google-deepmind/ring_attention/blob/main/ring_attention/layers/rotary.py
-
- def _rotate_half(x: jax.Array) -> jax.Array:
-     """Rotates half the hidden dims of the input."""
-     x1 = x[..., : x.shape[-1] // 2]
-     x2 = x[..., x.shape[-1] // 2 :]
-     return jnp.concatenate((-x2, x1), axis=-1)
-
- def apply_rotary_embedding(
-     x: jax.Array, freqs_cis: jax.Array
- ) -> jax.Array:
-     """Applies rotary embedding to the input tensor using the rotate_half method."""
-     # x shape: [..., Sequence, Dimension] e.g. [B, H, S, D] or [B, S, D]
-     # freqs_cis shape: complex [Sequence, Dimension / 2]
-
-     # Extract cos and sin from the complex freqs_cis
-     cos_freqs = jnp.real(freqs_cis)  # Shape [S, D/2]
-     sin_freqs = jnp.imag(freqs_cis)  # Shape [S, D/2]
-
-     # Expand dims for broadcasting: [1, 1, S, D/2] or [1, S, D/2]
-     if x.ndim == 4:  # [B, H, S, D]
-         cos_freqs = jnp.expand_dims(cos_freqs, axis=(0, 1))
-         sin_freqs = jnp.expand_dims(sin_freqs, axis=(0, 1))
-     elif x.ndim == 3:  # [B, S, D]
-         cos_freqs = jnp.expand_dims(cos_freqs, axis=0)
-         sin_freqs = jnp.expand_dims(sin_freqs, axis=0)
-
-     # Duplicate cos and sin for the full dimension D
-     # Shape becomes [..., S, D]
-     cos_freqs = jnp.concatenate([cos_freqs, cos_freqs], axis=-1)
-     sin_freqs = jnp.concatenate([sin_freqs, sin_freqs], axis=-1)
-
-     # Apply rotation: x * cos + rotate_half(x) * sin
-     x_rotated = x * cos_freqs + _rotate_half(x) * sin_freqs
-     return x_rotated.astype(x.dtype)
-
-
- class RotaryEmbedding(nn.Module):
-     dim: int  # Dimension of the head
-     max_seq_len: int = 2048
-     base: int = 10000
-     dtype: Dtype = jnp.float32
-
-     def setup(self):
-         inv_freq = 1.0 / (
-             self.base ** (jnp.arange(0, self.dim, 2, dtype=jnp.float32) / self.dim)
-         )
-         t = jnp.arange(self.max_seq_len, dtype=jnp.float32)
-         freqs = jnp.outer(t, inv_freq)  # Shape: [max_seq_len, dim / 2]
-
-         # Precompute the complex form: cos(theta) + i * sin(theta)
-         self.freqs_cis_complex = jnp.cos(freqs) + 1j * jnp.sin(freqs)
-         # Shape: [max_seq_len, dim / 2]
-
-     def __call__(self, seq_len: int):
-         if seq_len > self.max_seq_len:
-             raise ValueError(f"Sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}")
-         # Return complex shape [seq_len, dim / 2]
-         return self.freqs_cis_complex[:seq_len, :]
-
- # --- Attention with RoPE ---
-
- class RoPEAttention(NormalAttention):
-     rope_emb: Optional[RotaryEmbedding] = None  # RoPE module; defaulted so it can follow the inherited defaulted fields
-
-     @nn.compact
-     def __call__(self, x, context=None, freqs_cis=None):
-         # x has shape [B, H, W, C] or [B, S, C]
-         orig_x_shape = x.shape
-         is_4d = len(orig_x_shape) == 4
-         if is_4d:
-             B, H, W, C = x.shape
-             seq_len = H * W
-             x = x.reshape((B, seq_len, C))
-         else:
-             B, seq_len, C = x.shape
-
-         context = x if context is None else context
-         if len(context.shape) == 4:
-             _B, _H, _W, _C = context.shape
-             context_seq_len = _H * _W
-             context = context.reshape((_B, context_seq_len, _C))
-         else:
-             _B, context_seq_len, _C = context.shape
-
-         query = self.query(x)  # [B, S, H, D]
-         key = self.key(context)  # [B, S_ctx, H, D]
-         value = self.value(context)  # [B, S_ctx, H, D]
-
-         # Apply RoPE to query and key
-         if freqs_cis is not None:
-             # Permute to [B, H, S, D] for RoPE application if needed by apply_rotary_embedding
-             query = einops.rearrange(query, 'b s h d -> b h s d')
-             key = einops.rearrange(key, 'b s h d -> b h s d')
-
-             query = apply_rotary_embedding(query, freqs_cis)
-             key = apply_rotary_embedding(key, freqs_cis)  # Apply to key as well
-
-             # Permute back to [B, S, H, D] for dot_product_attention
-             query = einops.rearrange(query, 'b h s d -> b s h d')
-             key = einops.rearrange(key, 'b h s d -> b s h d')
-
-         hidden_states = nn.dot_product_attention(
-             query, key, value, dtype=self.dtype, broadcast_dropout=False,
-             dropout_rng=None, precision=self.precision, force_fp32_for_softmax=self.force_fp32_for_softmax,
-             deterministic=True
-         )  # Output shape [B, S, H, D]
-
-         proj = self.proj_attn(hidden_states)  # Output shape [B, S, C]
-
-         if is_4d:
-             proj = proj.reshape(orig_x_shape)  # Reshape back if input was 4D
-
-         return proj
-
- # --- adaLN-Zero ---
-
- class AdaLNZero(nn.Module):
-     features: int
-     dtype: Optional[Dtype] = None
-     precision: PrecisionLike = None
-     norm_epsilon: float = 1e-5  # Standard LayerNorm epsilon
-
-     @nn.compact
-     def __call__(self, x, conditioning):
-         # Project conditioning signal to get scale and shift parameters
-         # Conditioning shape: [B, D_cond] -> [B, 1, ..., 1, 6 * features] for broadcasting
-         # Or [B, 1, 6*features] if x is [B, S, F]
-
-         # Ensure conditioning has seq dim if x does
-         if x.ndim == 3 and conditioning.ndim == 2:  # x=[B,S,F], cond=[B,D_cond]
-             conditioning = jnp.expand_dims(conditioning, axis=1)  # cond=[B,1,D_cond]
-
-         # Project conditioning to get 6 params per feature (scale_mlp, shift_mlp, gate_mlp, scale_attn, shift_attn, gate_attn)
-         # Using nn.DenseGeneral for flexibility if needed, but nn.Dense is fine if cond is [B, D_cond] or [B, 1, D_cond]
-         ada_params = nn.Dense(
-             features=6 * self.features,
-             dtype=self.dtype,
-             precision=self.precision,
-             kernel_init=nn.initializers.zeros,  # Initialize projection to zero (Zero init)
-             name="ada_proj"
-         )(conditioning)
-
-         # Split into scale, shift, gate for MLP and Attention
-         scale_mlp, shift_mlp, gate_mlp, scale_attn, shift_attn, gate_attn = jnp.split(ada_params, 6, axis=-1)
-
-         # Apply Layer Normalization
-         norm = nn.LayerNorm(epsilon=self.norm_epsilon, use_scale=False, use_bias=False, dtype=self.dtype)
-         # norm = nn.RMSNorm(epsilon=self.norm_epsilon, dtype=self.dtype)  # Alternative: RMSNorm
-
-         norm_x = norm(x)
-
-         # Modulate for Attention path
-         x_attn = norm_x * (1 + scale_attn) + shift_attn
-
-         # Modulate for MLP path
-         x_mlp = norm_x * (1 + scale_mlp) + shift_mlp
-
-         # Return modulated outputs and gates
-         return x_attn, gate_attn, x_mlp, gate_mlp
-
-
- # --- DiT Block ---
-
- class DiTBlock(nn.Module):
-     features: int
-     num_heads: int
-     mlp_ratio: int = 4
-     dropout_rate: float = 0.0  # Typically dropout is not used in diffusion models
-     dtype: Optional[Dtype] = None
-     precision: PrecisionLike = None
-     use_flash_attention: bool = False  # Keep option, but RoPEAttention uses NormalAttention base
-     force_fp32_for_softmax: bool = True
-     norm_epsilon: float = 1e-5
-     rope_emb: Optional[RotaryEmbedding] = None  # RoPE module; defaulted since it follows defaulted fields
-
-     def setup(self):
-         hidden_features = int(self.features * self.mlp_ratio)
-         self.ada_ln_zero = AdaLNZero(self.features, dtype=self.dtype, precision=self.precision, norm_epsilon=self.norm_epsilon)
-
-         # Use RoPEAttention
-         self.attention = RoPEAttention(
-             query_dim=self.features,
-             heads=self.num_heads,
-             dim_head=self.features // self.num_heads,
-             dtype=self.dtype,
-             precision=self.precision,
-             use_bias=True,  # Bias is common in DiT attention proj
-             force_fp32_for_softmax=self.force_fp32_for_softmax,
-             rope_emb=self.rope_emb  # Pass RoPE module instance
-         )
-
-         # Standard MLP block
-         self.mlp = nn.Sequential([
-             nn.Dense(features=hidden_features, dtype=self.dtype, precision=self.precision),
-             nn.gelu,
-             nn.Dense(features=self.features, dtype=self.dtype, precision=self.precision)
-         ])
-
-     @nn.compact
-     def __call__(self, x, conditioning, freqs_cis):
-         # x shape: [B, S, F]
-         # conditioning shape: [B, D_cond]
-
-         residual = x
-
-         # Apply adaLN-Zero to get modulated inputs and gates
-         x_attn, gate_attn, x_mlp, gate_mlp = self.ada_ln_zero(x, conditioning)
-
-         # Attention block
-         attn_output = self.attention(x_attn, context=None, freqs_cis=freqs_cis)  # Self-attention only
-         x = residual + gate_attn * attn_output
-
-         # MLP block
-         mlp_output = self.mlp(x_mlp)
-         x = x + gate_mlp * mlp_output
-
-         return x
-
- # --- Patch Embedding (reuse or define if needed) ---
- # Assuming PatchEmbedding exists in simple_vit.py and is suitable
- from .simple_vit import PatchEmbedding, unpatchify
-
- # --- Better UViT (DiT Style) ---
-
- class BetterUViT(nn.Module):
-     output_channels: int = 3
-     patch_size: int = 16
-     emb_features: int = 768
-     num_layers: int = 12
-     num_heads: int = 12
-     mlp_ratio: int = 4
-     dropout_rate: float = 0.0  # Typically 0 for diffusion
-     dtype: Optional[Dtype] = None
-     precision: PrecisionLike = None
-     use_flash_attention: bool = False  # Passed down, but RoPEAttention uses NormalAttention
-     force_fp32_for_softmax: bool = True
-     norm_epsilon: float = 1e-5
-     learn_sigma: bool = False  # Option to predict sigma like in the DiT paper
-     use_hilbert: bool = False  # Toggle Hilbert patch reorder
-
-     def setup(self):
-         self.patch_embed = PatchEmbedding(
-             patch_size=self.patch_size,
-             embedding_dim=self.emb_features,
-             dtype=self.dtype,
-             precision=self.precision
-         )
-
-         # Time embedding projection
-         self.time_embed = nn.Sequential([
-             FourierEmbedding(features=self.emb_features),
-             TimeProjection(features=self.emb_features * self.mlp_ratio),  # Project to MLP dim
-             nn.Dense(features=self.emb_features, dtype=self.dtype, precision=self.precision)  # Final projection
-         ])
-
-         # Text context projection (if used)
-         # Assuming textcontext is already projected to some dimension, project it to match emb_features
-         # This might need adjustment based on how text context is provided
-         self.text_proj = nn.Dense(features=self.emb_features, dtype=self.dtype, precision=self.precision, name="text_context_proj")
-
-         # Rotary Positional Embedding
-         # Max length needs to be estimated or set large enough.
-         # For images, seq len = (H/P) * (W/P). Example: 256/16 * 256/16 = 16*16 = 256
-         # Add 1 if a class token is used, or more for text tokens if concatenated.
-         # max_seq_len must accommodate patches plus time/text tokens if those are concatenated, or just patches.
-         # If only patches use RoPE, max_seq_len = max_image_tokens.
-         # If time/text are concatenated *before* the blocks, max_seq_len needs to include them.
-         # DiT typically applies positional embeddings only to patch tokens; that convention is followed here,
-         # so max_seq_len is the maximum number of patches.
-         # Example: max image size 512x512, patch 16 -> (512/16)^2 = 32^2 = 1024 patches
-         self.rope = RotaryEmbedding(dim=self.emb_features // self.num_heads, max_seq_len=4096, dtype=self.dtype)  # Dim per head
-
-         # Transformer Blocks
-         self.blocks = [
-             DiTBlock(
-                 features=self.emb_features,
-                 num_heads=self.num_heads,
-                 mlp_ratio=self.mlp_ratio,
-                 dropout_rate=self.dropout_rate,
-                 dtype=self.dtype,
-                 precision=self.precision,
-                 use_flash_attention=self.use_flash_attention,
-                 force_fp32_for_softmax=self.force_fp32_for_softmax,
-                 norm_epsilon=self.norm_epsilon,
-                 rope_emb=self.rope,  # Pass RoPE instance
-                 name=f"dit_block_{i}"
-             ) for i in range(self.num_layers)
-         ]
-
-         # Final Layer (Normalization + Linear Projection)
-         self.final_norm = nn.LayerNorm(epsilon=self.norm_epsilon, dtype=self.dtype, name="final_norm")
-         # self.final_norm = nn.RMSNorm(epsilon=self.norm_epsilon, dtype=self.dtype, name="final_norm")
-
-         # Predict patch pixels + potentially sigma
-         output_dim = self.patch_size * self.patch_size * self.output_channels
-         if self.learn_sigma:
-             output_dim *= 2  # Predict both mean and variance (or log_variance)
-
-         self.final_proj = nn.Dense(
-             features=output_dim,
-             dtype=self.dtype,
-             precision=self.precision,
-             kernel_init=nn.initializers.zeros,  # Initialize final layer to zero
-             name="final_proj"
-         )
-
-     @nn.compact
-     def __call__(self, x, temb, textcontext=None):
-         B, H, W, C = x.shape
-         assert H % self.patch_size == 0 and W % self.patch_size == 0, "Image dimensions must be divisible by patch size"
-
-         # 1. Patch Embedding
-         patches = self.patch_embed(x)  # Shape: [B, num_patches, emb_features]
-         num_patches = patches.shape[1]
-
-         # Optional Hilbert reorder
-         if self.use_hilbert:
-             idx = hilbert_indices(H // self.patch_size, W // self.patch_size)
-             inv_idx = inverse_permutation(idx)
-             patches = patches[:, idx, :]
-
-         # Work on the patch sequence from here on
-         x_seq = patches
-
-         # 2. Prepare Conditioning Signal (Time + Text Context)
-         t_emb = self.time_embed(temb)  # Shape: [B, emb_features]
-
-         cond_emb = t_emb
-         if textcontext is not None:
-             text_emb = self.text_proj(textcontext)  # Shape: [B, num_text_tokens, emb_features]
-             # Pool or select text embedding (e.g., mean pool or use CLS token)
-             # Assuming mean pooling for simplicity
-             text_emb_pooled = jnp.mean(text_emb, axis=1)  # Shape: [B, emb_features]
-             cond_emb = cond_emb + text_emb_pooled  # Combine time and text embeddings
-
-         # 3. Apply RoPE
-         # Get RoPE frequencies for the sequence length (number of patches)
-         freqs_cis = self.rope(seq_len=num_patches)  # Shape [num_patches, D_head/2]
-
-         # 4. Apply Transformer Blocks with adaLN-Zero conditioning
-         for block in self.blocks:
-             x_seq = block(x_seq, conditioning=cond_emb, freqs_cis=freqs_cis)
-
-         # 5. Final Layer
-         x_out = self.final_norm(x_seq)
-         x_out = self.final_proj(x_out)  # Shape: [B, num_patches, patch_pixels (*2 if learn_sigma)]
-
-         # Optional Hilbert inverse reorder
-         if self.use_hilbert:
-             x_out = x_out[:, inv_idx, :]
-
-         # 6. Unpatchify
-         if self.learn_sigma:
-             # Split into mean and variance predictions
-             x_mean, x_logvar = jnp.split(x_out, 2, axis=-1)
-             x = unpatchify(x_mean, channels=self.output_channels)
-             # Return both mean and logvar if needed by the loss function
-             # For now, just returning the mean prediction like standard diffusion models
-             # logvar = unpatchify(x_logvar, channels=self.output_channels)
-             # return x, logvar
-             return x
-         else:
-             x = unpatchify(x_out, channels=self.output_channels)  # Shape: [B, H, W, C]
-             return x
-
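For reference, a minimal usage sketch of the removed module. This assumes an environment with flaxdiff 0.2.7 installed (so that flaxdiff.models.better_uvit and its .common/.simple_vit/.attention dependencies still resolve); the input shapes below are illustrative assumptions, not values taken from the package.

    import jax
    import jax.numpy as jnp
    from flaxdiff.models.better_uvit import BetterUViT  # present in 0.2.7, removed in 0.2.9

    # Instantiate with the defaults defined above (patch_size=16, emb_features=768, ...)
    model = BetterUViT(output_channels=3, patch_size=16, emb_features=768,
                       num_layers=12, num_heads=12)

    rng = jax.random.PRNGKey(0)
    x = jnp.zeros((1, 256, 256, 3))        # [B, H, W, C]; H and W must be divisible by patch_size
    temb = jnp.zeros((1,))                 # diffusion timesteps; assumed shape [B] for FourierEmbedding
    textcontext = jnp.zeros((1, 77, 768))  # assumed pre-encoded text tokens, or None

    # Standard Flax init/apply; per the code above the output is [B, H, W, output_channels]
    params = model.init(rng, x, temb, textcontext)
    out = model.apply(params, x, temb, textcontext)

Note that all DiTBlock instances share the single RotaryEmbedding created in setup, so the per-head rotary frequency table is computed once and sliced to the number of patches on each forward pass.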