flaxdiff 0.2.7__py3-none-any.whl → 0.2.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,275 @@
+ import jax
+ import jax.numpy as jnp
+ from flax import linen as nn
+ from typing import Callable, Any, Optional, Tuple, Sequence, Union
+ import einops
+ from functools import partial
+
+ # Re-use existing components if they are suitable
+ from .vit_common import PatchEmbedding, unpatchify, RotaryEmbedding, RoPEAttention, AdaLNParams
+ from .common import kernel_init, FourierEmbedding, TimeProjection
+ # Using NormalAttention for RoPE integration
+ from .attention import NormalAttention
+ from flax.typing import Dtype, PrecisionLike
+
+ # Use our improved Hilbert implementation
+ from .hilbert import hilbert_indices, inverse_permutation, hilbert_patchify, hilbert_unpatchify
+
+ # --- DiT Block ---
+ class DiTBlock(nn.Module):
+     features: int
+     num_heads: int
+     rope_emb: RotaryEmbedding
+     mlp_ratio: int = 4
+     dropout_rate: float = 0.0
+     dtype: Optional[Dtype] = None
+     precision: PrecisionLike = None
+     use_flash_attention: bool = False # Keep placeholder
+     force_fp32_for_softmax: bool = True
+     norm_epsilon: float = 1e-5
+     use_gating: bool = True # Add flag to easily disable gating
+
+     def setup(self):
+         hidden_features = int(self.features * self.mlp_ratio)
+         # Get modulation parameters (scale, shift, gates)
+         self.ada_params_module = AdaLNParams( # Use the modified module
+             self.features, dtype=self.dtype, precision=self.precision)
+
+         # Layer Norms - one before Attn, one before MLP
+         self.norm1 = nn.LayerNorm(epsilon=self.norm_epsilon, use_scale=False, use_bias=False, dtype=self.dtype, name="norm1")
+         self.norm2 = nn.LayerNorm(epsilon=self.norm_epsilon, use_scale=False, use_bias=False, dtype=self.dtype, name="norm2")
+
+         self.attention = RoPEAttention(
+             query_dim=self.features,
+             heads=self.num_heads,
+             dim_head=self.features // self.num_heads,
+             dtype=self.dtype,
+             precision=self.precision,
+             use_bias=True,
+             force_fp32_for_softmax=self.force_fp32_for_softmax,
+             rope_emb=self.rope_emb
+         )
+
+         self.mlp = nn.Sequential([
+             nn.Dense(features=hidden_features, dtype=self.dtype, precision=self.precision),
+             nn.gelu, # Or swish as specified in SimpleDiT? Consider consistency.
+             nn.Dense(features=self.features, dtype=self.dtype, precision=self.precision)
+         ])
+
+     @nn.compact
+     def __call__(self, x, conditioning, freqs_cis):
+         # Get scale/shift/gate parameters
+         # Shape: [B, 1, 6*F] -> split into 6 of [B, 1, F]
+         scale_mlp, shift_mlp, gate_mlp, scale_attn, shift_attn, gate_attn = jnp.split(
+             self.ada_params_module(conditioning), 6, axis=-1
+         )
+
+         # --- Attention Path ---
+         residual = x
+         norm_x_attn = self.norm1(x)
+         # Modulate after norm
+         x_attn_modulated = norm_x_attn * (1 + scale_attn) + shift_attn
+         attn_output = self.attention(x_attn_modulated, context=None, freqs_cis=freqs_cis)
+
+         if self.use_gating:
+             x = residual + gate_attn * attn_output
+         else:
+             x = residual + attn_output # Original DiT style without gate
+
+         # --- MLP Path ---
+         residual = x
+         norm_x_mlp = self.norm2(x) # Apply second LayerNorm
+         # Modulate after norm
+         x_mlp_modulated = norm_x_mlp * (1 + scale_mlp) + shift_mlp
+         mlp_output = self.mlp(x_mlp_modulated)
+
+         if self.use_gating:
+             x = residual + gate_mlp * mlp_output
+         else:
+             x = residual + mlp_output # Original DiT style without gate
+
+         return x
+
+
+ # --- Patch Embedding (reuse or define if needed) ---
+ # Assuming PatchEmbedding exists in simple_vit.py and is suitable
+
+ # --- DiT ---
+
+ class SimpleDiT(nn.Module):
+     output_channels: int = 3
+     patch_size: int = 16
+     emb_features: int = 768
+     num_layers: int = 12
+     num_heads: int = 12
+     mlp_ratio: int = 4
+     dropout_rate: float = 0.0 # Typically 0 for diffusion
+     dtype: Optional[Dtype] = None
+     precision: PrecisionLike = None
+     # Passed down, but RoPEAttention uses NormalAttention
+     use_flash_attention: bool = False
+     force_fp32_for_softmax: bool = True
+     norm_epsilon: float = 1e-5
+     learn_sigma: bool = False # Option to predict sigma like in DiT paper
+     use_hilbert: bool = False # Toggle Hilbert patch reorder
+     norm_groups: int = 0
+     activation: Callable = jax.nn.swish
+
+     def setup(self):
+         self.patch_embed = PatchEmbedding(
+             patch_size=self.patch_size,
+             embedding_dim=self.emb_features,
+             dtype=self.dtype,
+             precision=self.precision
+         )
+
+         # Add projection layer for Hilbert patches
+         if self.use_hilbert:
+             self.hilbert_proj = nn.Dense(
+                 features=self.emb_features,
+                 dtype=self.dtype,
+                 precision=self.precision,
+                 name="hilbert_projection"
+             )
+
+         # Time embedding projection
+         self.time_embed = nn.Sequential([
+             FourierEmbedding(features=self.emb_features),
+             TimeProjection(features=self.emb_features *
+                            self.mlp_ratio), # Project to MLP dim
+             nn.Dense(features=self.emb_features, dtype=self.dtype,
+                      precision=self.precision) # Final projection
+         ])
+
+         # Text context projection (if used)
+         # Assuming textcontext is already projected to some dimension, project it to match emb_features
+         # This might need adjustment based on how text context is provided
+         self.text_proj = nn.Dense(features=self.emb_features, dtype=self.dtype,
+                                   precision=self.precision, name="text_context_proj")
+
+         # Rotary Positional Embedding
+         # Max length needs to be estimated or set large enough.
+         # For images, seq len = (H/P) * (W/P). Example: 256/16 * 256/16 = 16*16 = 256
+         # Add 1 if a class token is used, or more for text tokens if concatenated.
+         # Let's assume max seq len accommodates patches + time + text tokens if needed, or just patches.
+         # If only patches use RoPE, max_len = max_image_tokens
+         # If time/text are concatenated *before* blocks, max_len needs to include them.
+         # DiT typically applies PE only to patch tokens. Let's follow that.
+         # max_len should be max number of patches.
+         # Example: max image size 512x512, patch 16 -> (512/16)^2 = 32^2 = 1024 patches
+         self.rope = RotaryEmbedding(
+             dim=self.emb_features // self.num_heads, max_seq_len=4096, dtype=self.dtype) # Dim per head
+
+         # Transformer Blocks
+         self.blocks = [
+             DiTBlock(
+                 features=self.emb_features,
+                 num_heads=self.num_heads,
+                 mlp_ratio=self.mlp_ratio,
+                 dropout_rate=self.dropout_rate,
+                 dtype=self.dtype,
+                 precision=self.precision,
+                 use_flash_attention=self.use_flash_attention,
+                 force_fp32_for_softmax=self.force_fp32_for_softmax,
+                 norm_epsilon=self.norm_epsilon,
+                 rope_emb=self.rope, # Pass RoPE instance
+                 name=f"dit_block_{i}"
+             ) for i in range(self.num_layers)
+         ]
+
+         # Final Layer (Normalization + Linear Projection)
+         self.final_norm = nn.LayerNorm(
+             epsilon=self.norm_epsilon, dtype=self.dtype, name="final_norm")
+         # self.final_norm = nn.RMSNorm(epsilon=self.norm_epsilon, dtype=self.dtype, name="final_norm")
+
+         # Predict patch pixels + potentially sigma
+         output_dim = self.patch_size * self.patch_size * self.output_channels
+         if self.learn_sigma:
+             output_dim *= 2 # Predict both mean and variance (or log_variance)
+
+         self.final_proj = nn.Dense(
+             features=output_dim,
+             dtype=self.dtype,
+             precision=self.precision,
+             kernel_init=nn.initializers.zeros, # Initialize final layer to zero
+             name="final_proj"
+         )
+
+     @nn.compact
+     def __call__(self, x, temb, textcontext=None):
+         B, H, W, C = x.shape
+         assert H % self.patch_size == 0 and W % self.patch_size == 0, "Image dimensions must be divisible by patch size"
+
+         # Compute dimensions in terms of patches
+         H_P = H // self.patch_size
+         W_P = W // self.patch_size
+
+         # 1. Patch Embedding
+         if self.use_hilbert:
+             # Use hilbert_patchify which handles both patchification and reordering
+             patches_raw, hilbert_inv_idx = hilbert_patchify(x, self.patch_size) # Shape [B, S, P*P*C]
+             # Apply projection
+             patches = self.hilbert_proj(patches_raw) # Shape [B, S, emb_features]
+         else:
+             patches = self.patch_embed(x) # Shape: [B, num_patches, emb_features]
+             hilbert_inv_idx = None
+
+         num_patches = patches.shape[1]
+         x_seq = patches
+
+         # 2. Prepare Conditioning Signal (Time + Text Context)
+         t_emb = self.time_embed(temb) # Shape: [B, emb_features]
+
+         cond_emb = t_emb
+         if textcontext is not None:
+             # Shape: [B, num_text_tokens, emb_features]
+             text_emb = self.text_proj(textcontext)
+             # Pool or select text embedding (e.g., mean pool or use CLS token)
+             # Assuming mean pooling for simplicity
+             # Shape: [B, emb_features]
+             text_emb_pooled = jnp.mean(text_emb, axis=1)
+             cond_emb = cond_emb + text_emb_pooled # Combine time and text embeddings
+
+         # 3. Apply RoPE
+         # Get RoPE frequencies for the sequence length (number of patches)
+         # Shape [num_patches, D_head/2]
+         freqs_cos, freqs_sin = self.rope(seq_len=num_patches)
+
+         # 4. Apply Transformer Blocks with adaLN-Zero conditioning
+         for block in self.blocks:
+             x_seq = block(x_seq, conditioning=cond_emb, freqs_cis=(freqs_cos, freqs_sin))
+
+         # 5. Final Layer
+         x_out = self.final_norm(x_seq)
+         # Shape: [B, num_patches, patch_pixels (*2 if learn_sigma)]
+         x_out = self.final_proj(x_out)
+
+         # 6. Unpatchify
+         if self.use_hilbert:
+             # For Hilbert mode, we need to use the specialized unpatchify function
+             if self.learn_sigma:
+                 # Split into mean and variance predictions
+                 x_mean, x_logvar = jnp.split(x_out, 2, axis=-1)
+                 x_image = hilbert_unpatchify(x_mean, hilbert_inv_idx, self.patch_size, H, W, self.output_channels)
+                 # If needed, also unpack the logvar
+                 # logvar_image = hilbert_unpatchify(x_logvar, hilbert_inv_idx, self.patch_size, H, W, self.output_channels)
+                 # return x_image, logvar_image
+                 return x_image
+             else:
+                 x_image = hilbert_unpatchify(x_out, hilbert_inv_idx, self.patch_size, H, W, self.output_channels)
+                 return x_image
+         else:
+             # Standard patch ordering - use the existing unpatchify function
+             if self.learn_sigma:
+                 # Split into mean and variance predictions
+                 x_mean, x_logvar = jnp.split(x_out, 2, axis=-1)
+                 x = unpatchify(x_mean, channels=self.output_channels)
+                 # Return both mean and logvar if needed by the loss function
+                 # For now, just returning the mean prediction like standard diffusion models
+                 # logvar = unpatchify(x_logvar, channels=self.output_channels)
+                 # return x, logvar
+                 return x
+             else:
+                 # Shape: [B, H, W, C]
+                 x = unpatchify(x_out, channels=self.output_channels)
+                 return x
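
Note on the conditioning scheme in DiTBlock above: AdaLNParams projects the conditioning vector to six per-feature terms (scale, shift, and gate for each of the attention and MLP paths), and each sub-layer applies them as residual + gate * f(norm(x) * (1 + scale) + shift), matching the adaLN-Zero comment in SimpleDiT.__call__. A minimal standalone sketch of that update, with assumed shapes, for readers skimming the diff:

def adaln_zero_sublayer(norm_x, residual, scale, shift, gate, f):
    # norm_x, residual: [B, S, F]; scale/shift/gate: [B, 1, F]; f is the attention or MLP sub-layer
    x_mod = norm_x * (1 + scale) + shift   # feature-wise modulation after LayerNorm
    return residual + gate * f(x_mod)      # gated residual update

# DiTBlock obtains the six terms by splitting a single [B, 1, 6*F] projection:
# scale_mlp, shift_mlp, gate_mlp, scale_attn, shift_attn, gate_attn = jnp.split(params, 6, axis=-1)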
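
For reference, a minimal usage sketch of the SimpleDiT module added in this diff. The import path is an assumption (the diff does not show the file name), and the batch size, image size, and text-context shape are placeholders:

import jax
import jax.numpy as jnp
from flaxdiff.models.simple_dit import SimpleDiT  # module path assumed for illustration

model = SimpleDiT(output_channels=3, patch_size=16, emb_features=768,
                  num_layers=12, num_heads=12)

rng = jax.random.PRNGKey(0)
x = jnp.zeros((2, 256, 256, 3))        # [B, H, W, C]; H and W must be divisible by patch_size
temb = jnp.zeros((2,))                 # diffusion timesteps, one per batch element
textcontext = jnp.zeros((2, 77, 768))  # optional [B, tokens, dim]; projected by text_context_proj

variables = model.init(rng, x, temb, textcontext)
out = model.apply(variables, x, temb, textcontext)  # [B, H, W, output_channels]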