flaxdiff 0.2.8__py3-none-any.whl → 0.2.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -11,48 +11,8 @@ import einops
 from flax.typing import Dtype, PrecisionLike
 from functools import partial
 from .hilbert import hilbert_indices, inverse_permutation, hilbert_patchify, hilbert_unpatchify
-
-
-def unpatchify(x, channels=3):
-    patch_size = int((x.shape[2] // channels) ** 0.5)
-    h = w = int(x.shape[1] ** .5)
-    assert h * w == x.shape[1] and patch_size ** 2 * \
-        channels == x.shape[2], f"Invalid shape: {x.shape}, should be {h*w}, {patch_size**2*channels}"
-    x = einops.rearrange(
-        x, 'B (h w) (p1 p2 C) -> B (h p1) (w p2) C', h=h, p1=patch_size, p2=patch_size)
-    return x
-
-
-class PatchEmbedding(nn.Module):
-    patch_size: int
-    embedding_dim: int
-    dtype: Any = jnp.float32
-    precision: Any = jax.lax.Precision.HIGH
-
-    @nn.compact
-    def __call__(self, x):
-        batch, height, width, channels = x.shape
-        assert height % self.patch_size == 0 and width % self.patch_size == 0, "Image dimensions must be divisible by patch size"
-
-        x = nn.Conv(features=self.embedding_dim,
-                    kernel_size=(self.patch_size, self.patch_size),
-                    strides=(self.patch_size, self.patch_size),
-                    dtype=self.dtype,
-                    precision=self.precision)(x)
-        x = jnp.reshape(x, (batch, -1, self.embedding_dim))
-        return x
-
-
-class PositionalEncoding(nn.Module):
-    max_len: int
-    embedding_dim: int
-
-    @nn.compact
-    def __call__(self, x):
-        pe = self.param('pos_encoding',
-                        jax.nn.initializers.zeros,
-                        (1, self.max_len, self.embedding_dim))
-        return x + pe[:, :x.shape[1], :]
+from .vit_common import _rotate_half, unpatchify, PatchEmbedding, apply_rotary_embedding, RotaryEmbedding, RoPEAttention, AdaLNZero, AdaLNParams
+from .simple_dit import DiTBlock
 
 
 class UViT(nn.Module):
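
The unpatchify and PatchEmbedding definitions removed in this hunk are not dropped from the package; they are now imported from .vit_common instead of being defined locally, so UViT keeps the same patch/unpatch behaviour. As a reference, here is a minimal sketch of the round trip that unpatchify implements, assuming square images and the einops pattern visible in the removed code; the patchify helper is illustrative only and is not part of the flaxdiff API:

import jax.numpy as jnp
import einops

def patchify(x, patch_size=16):
    # [B, H, W, C] -> [B, (H/p)*(W/p), p*p*C], row-major patch order (illustrative helper)
    return einops.rearrange(
        x, 'B (h p1) (w p2) C -> B (h w) (p1 p2 C)', p1=patch_size, p2=patch_size)

def unpatchify(x, channels=3):
    # Inverse mapping for square images, mirroring the removed helper above
    patch_size = int((x.shape[2] // channels) ** 0.5)
    h = w = int(x.shape[1] ** 0.5)
    return einops.rearrange(
        x, 'B (h w) (p1 p2 C) -> B (h p1) (w p2) C', h=h, p1=patch_size, p2=patch_size)

imgs = jnp.zeros((2, 64, 64, 3))
patches = patchify(imgs)                       # (2, 16, 768)
assert unpatchify(patches).shape == imgs.shape
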
@@ -85,16 +45,10 @@ class UViT(nn.Module):
 
         # --- Norm Layer ---
         if self.norm_groups > 0:
-            # GroupNorm needs features arg, which varies. Define partial here, apply in __call__?
-            # Or maybe use LayerNorm/RMSNorm consistently? Let's use LayerNorm for simplicity here.
-            # If GroupNorm is essential, it needs careful handling with changing feature sizes.
-            # self.norm_factory = partial(nn.GroupNorm, self.norm_groups, epsilon=self.norm_epsilon, dtype=self.dtype)
             print(f"Warning: norm_groups > 0 not fully supported with standard LayerNorm fallback in UViT setup. Using LayerNorm.")
             self.norm_factory = partial(
                 nn.LayerNorm, epsilon=self.norm_epsilon, dtype=self.dtype)
         else:
-            # Use LayerNorm or RMSNorm for sequence normalization
-            # self.norm_factory = partial(nn.RMSNorm, epsilon=self.norm_epsilon, dtype=self.dtype)
             self.norm_factory = partial(
                 nn.LayerNorm, epsilon=self.norm_epsilon, dtype=self.dtype)
 
@@ -107,7 +61,6 @@ class UViT(nn.Module):
             name="patch_embed"
         )
         if self.use_hilbert:
-            # Projection layer needed after raw Hilbert patches
             self.hilbert_proj = nn.Dense(
                 features=self.emb_features,
                 dtype=self.dtype,
@@ -115,13 +68,8 @@ class UViT(nn.Module):
                 name="hilbert_projection"
             )
 
-        # Positional encoding (learned) - applied only to patch tokens
-        # Max length needs to accommodate max possible patches
-        # Example: 512x512 image, patch 16 -> (512/16)^2 = 1024 patches
-        # Estimate max patches, adjust if needed
         max_patches = (512 // self.patch_size)**2
         self.pos_encoding = self.param('pos_encoding',
-                                       # Standard init for ViT pos embeds
                                        jax.nn.initializers.normal(stddev=0.02),
                                        (1, max_patches, self.emb_features))
 
@@ -131,7 +79,6 @@ class UViT(nn.Module):
             TimeProjection(features=self.emb_features)
         ], name="time_embed")
 
-        # Text projection
         self.text_proj = nn.DenseGeneral(
             features=self.emb_features,
             dtype=self.dtype,
@@ -169,7 +116,7 @@ class UViT(nn.Module):
         )
 
         self.up_dense = [
-            nn.DenseGeneral( # Project concatenated skip + up_path features back to emb_features
+            nn.DenseGeneral(
                 features=self.emb_features,
                 dtype=self.dtype,
                 precision=self.precision,
@@ -191,157 +138,309 @@ class UViT(nn.Module):
         ]
 
         # --- Output Path ---
-        self.final_norm = self.norm_factory(name="final_norm") # Use factory
+        self.final_norm = self.norm_factory(name="final_norm")
 
         patch_dim = self.patch_size ** 2 * self.output_channels
         self.final_proj = nn.Dense(
             features=patch_dim,
-            dtype=self.dtype, # Keep model dtype for projection
+            dtype=self.dtype,
             precision=self.precision,
-            kernel_init=nn.initializers.zeros, # Zero init final layer
+            kernel_init=nn.initializers.zeros,
             name="final_proj"
         )
 
         if self.add_residualblock_output:
-            # Define these layers only if needed
             self.final_conv1 = ConvLayer(
                 "conv",
                 features=64, kernel_size=(3, 3), strides=(1, 1),
                 dtype=self.dtype, precision=self.precision, name="final_conv1"
             )
             self.final_norm_conv = self.norm_factory(
-                name="final_norm_conv") # Use factory
+                name="final_norm_conv")
             self.final_conv2 = ConvLayer(
                 "conv",
                 features=self.output_channels, kernel_size=(3, 3), strides=(1, 1),
-                dtype=jnp.float32, # Often good to have final conv output float32
+                dtype=jnp.float32,
                 precision=self.precision, name="final_conv2"
             )
         else:
-            # Final conv to map features to output channels directly after unpatchify
             self.final_conv_direct = ConvLayer(
                 "conv",
-                # Use 1x1 conv
                 features=self.output_channels, kernel_size=(1, 1), strides=(1, 1),
-                dtype=jnp.float32, # Output float32
+                dtype=jnp.float32,
                 precision=self.precision, name="final_conv_direct"
             )
 
     @nn.compact
     def __call__(self, x, temb, textcontext=None):
-        original_img = x # Keep original for potential residual connection
+        original_img = x
         B, H, W, C = original_img.shape
         H_P = H // self.patch_size
         W_P = W // self.patch_size
         num_patches = H_P * W_P
         assert H % self.patch_size == 0 and W % self.patch_size == 0, "Image dimensions must be divisible by patch size"
 
-        # --- Patch Embedding ---
         hilbert_inv_idx = None
         if self.use_hilbert:
-            # Use hilbert_patchify to get raw patches and inverse index
             patches_raw, hilbert_inv_idx_calc = hilbert_patchify(
-                x, self.patch_size) # Shape [B, S, P*P*C]
-            # Project raw patches
-            # Shape [B, S, emb_features]
+                x, self.patch_size)
             x_patches = self.hilbert_proj(patches_raw)
-            # Calculate inverse permutation (needs total_size)
             idx = hilbert_indices(H_P, W_P)
             hilbert_inv_idx = inverse_permutation(
-                idx, total_size=num_patches) # Corrected call
-            # Apply Hilbert reordering *after* projection
+                idx, total_size=num_patches)
             x_patches = x_patches[:, idx, :]
         else:
-            # Standard patch embedding
-            # Shape: [B, num_patches, emb_features]
             x_patches = self.patch_embed(x)
 
-        # --- Positional Encoding ---
-        # Add positional encoding only to patch tokens
         assert num_patches <= self.pos_encoding.shape[
             1], f"Number of patches {num_patches} exceeds max_len {self.pos_encoding.shape[1]} in positional encoding"
         x_patches = x_patches + self.pos_encoding[:, :num_patches, :]
 
-        # --- Conditioning Tokens ---
-        # Time embedding: [B, D] -> [B, 1, D]
         time_token = self.time_embed(temb.astype(
-            jnp.float32)) # Ensure input is float32
+            jnp.float32))
         time_token = jnp.expand_dims(time_token.astype(
-            self.dtype), axis=1) # Cast back and add seq dim
+            self.dtype), axis=1)
 
-        # Text embedding: [B, S_text, D_in] -> [B, S_text, D]
         if textcontext is not None:
             text_tokens = self.text_proj(
-                textcontext.astype(self.dtype)) # Cast context
+                textcontext.astype(self.dtype))
             num_text_tokens = text_tokens.shape[1]
-            # Concatenate: [Patches+Pos, Time, Text]
             x = jnp.concatenate([x_patches, time_token, text_tokens], axis=1)
         else:
-            # Concatenate: [Patches+Pos, Time]
             num_text_tokens = 0
             x = jnp.concatenate([x_patches, time_token], axis=1)
 
-        # --- U-Net Transformer ---
         skips = []
-        # Down blocks (Encoder)
         for i in range(self.num_layers // 2):
-            x = self.down_blocks[i](x) # Pass full sequence (patches+cond)
-            skips.append(x) # Store output for skip connection
+            x = self.down_blocks[i](x)
+            skips.append(x)
 
-        # Middle block
         x = self.mid_block(x)
 
-        # Up blocks (Decoder)
         for i in range(self.num_layers // 2):
             skip_conn = skips.pop()
-            # Concatenate along feature dimension
             x = jnp.concatenate([x, skip_conn], axis=-1)
-            # Project back to emb_features
             x = self.up_dense[i](x)
-            # Apply transformer block
             x = self.up_blocks[i](x)
 
-        # --- Output Processing ---
-        # Normalize before final projection
-        x = self.final_norm(x) # Apply norm factory instance
+        x = self.final_norm(x)
 
-        # Extract only the image patch tokens (first num_patches tokens)
-        # Conditioning tokens (time, text) are discarded here
         x_patches_out = x[:, :num_patches, :]
 
-        # Project to patch pixel dimensions
-        # Shape: [B, num_patches, patch_dim]
         x_patches_out = self.final_proj(x_patches_out)
 
-        # --- Unpatchify ---
         if self.use_hilbert:
-            # Restore Hilbert order to row-major order and then to image
             assert hilbert_inv_idx is not None, "Hilbert inverse index missing"
             x_image = hilbert_unpatchify(
                 x_patches_out, hilbert_inv_idx, self.patch_size, H, W, self.output_channels)
         else:
-            # Standard unpatchify
-            # Shape: [B, H, W, C_out]
             x_image = unpatchify(x_patches_out, channels=self.output_channels)
 
-        # --- Final Convolutions ---
         if self.add_residualblock_output:
-            # Concatenate the original image (ensure dtype matches)
             x_image = jnp.concatenate(
                 [original_img.astype(self.dtype), x_image], axis=-1)
 
             x_image = self.final_conv1(x_image)
-            # Apply norm factory instance
             x_image = self.final_norm_conv(x_image)
             x_image = self.activation(x_image)
-            x_image = self.final_conv2(x_image) # Outputs float32
+            x_image = self.final_conv2(x_image)
         else:
-            # Apply a simple 1x1 conv to map features if needed (unpatchify already gives C_out channels)
-            # Or just return x_image if channels match output_channels
-            # If unpatchify output channels == self.output_channels, this might be redundant
-            # Let's assume unpatchify gives correct channels, but ensure float32
-            # x_image = self.final_conv_direct(x_image) # Use 1x1 conv if needed
-            pass # Assuming unpatchify output is correct
-
-        # Ensure final output is float32
+            pass
+
         return x_image
+
+
+# --- Simple U-DiT ---
+
+class SimpleUDiT(nn.Module):
+    """
+    A Simple U-Net Diffusion Transformer (U-DiT) implementation.
+    Combines the U-Net structure with DiT blocks using RoPE and AdaLN-Zero conditioning.
+    Based on SimpleDiT and standard U-Net principles.
+    """
+    output_channels: int = 3
+    patch_size: int = 16
+    emb_features: int = 768
+    num_layers: int = 12 # Should be even for U-Net structure
+    num_heads: int = 12
+    mlp_ratio: int = 4
+    dropout_rate: float = 0.0 # Typically 0 for diffusion
+    dtype: Optional[Dtype] = None # e.g., jnp.float32 or jnp.bfloat16
+    precision: PrecisionLike = None
+    use_flash_attention: bool = False # Passed to DiTBlock -> RoPEAttention
+    force_fp32_for_softmax: bool = True # Passed to DiTBlock -> RoPEAttention
+    norm_epsilon: float = 1e-5
+    learn_sigma: bool = False
+    use_hilbert: bool = False
+    norm_groups: int = 0
+    activation: Callable = jax.nn.swish
+
+    def setup(self):
+        assert self.num_layers % 2 == 0, "num_layers must be even for U-Net structure"
+        half_layers = self.num_layers // 2
+
+        self.patch_embed = PatchEmbedding(
+            patch_size=self.patch_size,
+            embedding_dim=self.emb_features,
+            dtype=self.dtype,
+            precision=self.precision,
+            name="patch_embed"
+        )
+        if self.use_hilbert:
+            self.hilbert_proj = nn.Dense(
+                features=self.emb_features,
+                dtype=self.dtype,
+                precision=self.precision,
+                name="hilbert_projection"
+            )
+
+        self.time_embed = nn.Sequential([
+            FourierEmbedding(features=self.emb_features),
+            TimeProjection(features=self.emb_features * self.mlp_ratio),
+            nn.Dense(features=self.emb_features, dtype=self.dtype, precision=self.precision)
+        ], name="time_embed")
+
+        self.text_proj = nn.Dense(
+            features=self.emb_features,
+            dtype=self.dtype,
+            precision=self.precision,
+            name="text_proj"
+        )
+
+        max_patches = (512 // self.patch_size)**2
+        self.rope = RotaryEmbedding(
+            dim=self.emb_features // self.num_heads,
+            max_seq_len=max_patches,
+            dtype=self.dtype,
+            name="rope_emb"
+        )
+
+        self.down_blocks = [
+            DiTBlock(
+                features=self.emb_features,
+                num_heads=self.num_heads,
+                mlp_ratio=self.mlp_ratio,
+                dropout_rate=self.dropout_rate,
+                dtype=self.dtype,
+                precision=self.precision,
+                use_flash_attention=self.use_flash_attention,
+                force_fp32_for_softmax=self.force_fp32_for_softmax,
+                norm_epsilon=self.norm_epsilon,
+                rope_emb=self.rope,
+                name=f"down_block_{i}"
+            ) for i in range(half_layers)
+        ]
+
+        self.mid_block = DiTBlock(
+            features=self.emb_features,
+            num_heads=self.num_heads,
+            mlp_ratio=self.mlp_ratio,
+            dropout_rate=self.dropout_rate,
+            dtype=self.dtype,
+            precision=self.precision,
+            use_flash_attention=self.use_flash_attention,
+            force_fp32_for_softmax=self.force_fp32_for_softmax,
+            norm_epsilon=self.norm_epsilon,
+            rope_emb=self.rope,
+            name="mid_block"
+        )
+
+        self.up_dense = [
+            nn.DenseGeneral(
+                features=self.emb_features,
+                dtype=self.dtype,
+                precision=self.precision,
+                name=f"up_dense_{i}"
+            ) for i in range(half_layers)
+        ]
+        self.up_blocks = [
+            DiTBlock(
+                features=self.emb_features,
+                num_heads=self.num_heads,
+                mlp_ratio=self.mlp_ratio,
+                dropout_rate=self.dropout_rate,
+                dtype=self.dtype,
+                precision=self.precision,
+                use_flash_attention=self.use_flash_attention,
+                force_fp32_for_softmax=self.force_fp32_for_softmax,
+                norm_epsilon=self.norm_epsilon,
+                rope_emb=self.rope,
+                name=f"up_block_{i}"
+            ) for i in range(half_layers)
+        ]
+
+        self.final_norm = nn.LayerNorm(
+            epsilon=self.norm_epsilon, dtype=self.dtype, name="final_norm")
+
+        output_dim = self.patch_size * self.patch_size * self.output_channels
+        if self.learn_sigma:
+            output_dim *= 2
+
+        self.final_proj = nn.Dense(
+            features=output_dim,
+            dtype=jnp.float32,
+            precision=self.precision,
+            kernel_init=nn.initializers.zeros,
+            name="final_proj"
+        )
+
+    @nn.compact
+    def __call__(self, x, temb, textcontext=None):
+        B, H, W, C = x.shape
+        H_P = H // self.patch_size
+        W_P = W // self.patch_size
+        num_patches = H_P * W_P
+        assert H % self.patch_size == 0 and W % self.patch_size == 0, "Image dimensions must be divisible by patch size"
+
+        x = x.astype(self.dtype)
+
+        hilbert_inv_idx = None
+        if self.use_hilbert:
+            patches_raw, _ = hilbert_patchify(x, self.patch_size)
+            x_seq = self.hilbert_proj(patches_raw)
+            idx = hilbert_indices(H_P, W_P)
+            hilbert_inv_idx = inverse_permutation(idx, total_size=num_patches)
+        else:
+            x_seq = self.patch_embed(x)
+
+        t_emb = self.time_embed(temb.astype(jnp.float32))
+        t_emb = t_emb.astype(self.dtype)
+
+        cond_emb = t_emb
+        if textcontext is not None:
+            text_emb = self.text_proj(textcontext.astype(self.dtype))
+            if text_emb.ndim == 3:
+                text_emb = jnp.mean(text_emb, axis=1)
+            cond_emb = cond_emb + text_emb
+
+        skips = []
+        for i in range(self.num_layers // 2):
+            x_seq = self.down_blocks[i](x_seq, conditioning=cond_emb, freqs_cis=None)
+            skips.append(x_seq)
+
+        x_seq = self.mid_block(x_seq, conditioning=cond_emb, freqs_cis=None)
+
+        for i in range(self.num_layers // 2):
+            skip_conn = skips.pop()
+            x_seq = jnp.concatenate([x_seq, skip_conn], axis=-1)
+            x_seq = self.up_dense[i](x_seq)
+            x_seq = self.up_blocks[i](x_seq, conditioning=cond_emb, freqs_cis=None)
+
+        x_out = self.final_norm(x_seq)
+        x_out = self.final_proj(x_out)
+
+        if self.use_hilbert:
+            assert hilbert_inv_idx is not None, "Hilbert inverse index missing"
+            if self.learn_sigma:
+                x_mean, x_logvar = jnp.split(x_out, 2, axis=-1)
+                x_image = hilbert_unpatchify(x_mean, hilbert_inv_idx, self.patch_size, H, W, self.output_channels)
+            else:
+                x_image = hilbert_unpatchify(x_out, hilbert_inv_idx, self.patch_size, H, W, self.output_channels)
+        else:
+            if self.learn_sigma:
+                x_mean, x_logvar = jnp.split(x_out, 2, axis=-1)
+                x_image = unpatchify(x_mean, channels=self.output_channels)
+            else:
+                x_image = unpatchify(x_out, channels=self.output_channels)
+
+        return x_image.astype(jnp.float32)
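
The new SimpleUDiT keeps the U-Net skip-connection layout of UViT but, per its docstring, conditions through AdaLN-Zero DiT blocks with rotary embeddings instead of concatenating time/text tokens into the sequence. A minimal usage sketch based on the constructor fields and __call__ signature shown in this hunk; the import path, input shapes, and the (B,) timestep convention are assumptions, not taken from the package documentation:

import jax
import jax.numpy as jnp
# from flaxdiff.models... import SimpleUDiT  # exact module path is not shown in this diff

model = SimpleUDiT(
    output_channels=3,
    patch_size=16,
    emb_features=768,
    num_layers=12,          # must be even: num_layers // 2 down blocks and up blocks
    num_heads=12,
    dtype=jnp.bfloat16,     # compute dtype; the final projection and output are float32
)

rng = jax.random.PRNGKey(0)
x = jnp.zeros((2, 256, 256, 3))        # noisy images; H and W must be divisible by patch_size
temb = jnp.zeros((2,))                 # diffusion timesteps, assumed shape (B,)
textcontext = jnp.zeros((2, 77, 768))  # optional text embeddings; mean-pooled into the conditioning vector

params = model.init(rng, x, temb, textcontext)
out = model.apply(params, x, temb, textcontext)  # float32, same spatial shape as x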