flaxdiff 0.2.7__py3-none-any.whl → 0.2.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -10,177 +10,437 @@ from flaxdiff.models.simple_unet import FourierEmbedding, TimeProjection, ConvLayer
 import einops
 from flax.typing import Dtype, PrecisionLike
 from functools import partial
-from .common import hilbert_indices, inverse_permutation
+from .hilbert import hilbert_indices, inverse_permutation, hilbert_patchify, hilbert_unpatchify
+from .vit_common import _rotate_half, unpatchify, PatchEmbedding, apply_rotary_embedding, RotaryEmbedding, RoPEAttention, AdaLNZero, AdaLNParams
+from .simple_dit import DiTBlock

-def unpatchify(x, channels=3):
-    patch_size = int((x.shape[2] // channels) ** 0.5)
-    h = w = int(x.shape[1] ** .5)
-    assert h * w == x.shape[1] and patch_size ** 2 * channels == x.shape[2], f"Invalid shape: {x.shape}, should be {h*w}, {patch_size**2*channels}"
-    x = einops.rearrange(x, 'B (h w) (p1 p2 C) -> B (h p1) (w p2) C', h=h, p1=patch_size, p2=patch_size)
-    return x
-
-class PatchEmbedding(nn.Module):
-    patch_size: int
-    embedding_dim: int
-    dtype: Any = jnp.float32
-    precision: Any = jax.lax.Precision.HIGH
-
-    @nn.compact
-    def __call__(self, x):
-        batch, height, width, channels = x.shape
-        assert height % self.patch_size == 0 and width % self.patch_size == 0, "Image dimensions must be divisible by patch size"
-
-        x = nn.Conv(features=self.embedding_dim,
-                    kernel_size=(self.patch_size, self.patch_size),
-                    strides=(self.patch_size, self.patch_size),
-                    dtype=self.dtype,
-                    precision=self.precision)(x)
-        x = jnp.reshape(x, (batch, -1, self.embedding_dim))
-        return x
-
-class PositionalEncoding(nn.Module):
-    max_len: int
-    embedding_dim: int
-
-    @nn.compact
-    def __call__(self, x):
-        pe = self.param('pos_encoding',
-                        jax.nn.initializers.zeros,
-                        (1, self.max_len, self.embedding_dim))
-        return x + pe[:, :x.shape[1], :]

 class UViT(nn.Module):
-    output_channels:int=3
+    output_channels: int = 3
     patch_size: int = 16
-    emb_features:int=768
-    num_layers: int = 12
+    emb_features: int = 768
+    num_layers: int = 12  # Should be even for U-Net structure
     num_heads: int = 12
-    dropout_rate: float = 0.1
-    use_projection: bool = False
-    use_flash_attention: bool = False
+    dropout_rate: float = 0.1  # Dropout is often 0 in diffusion models
+    use_projection: bool = False  # In TransformerBlock MLP
+    use_flash_attention: bool = False  # Passed to TransformerBlock
+    # Passed to TransformerBlock (likely False for UViT)
     use_self_and_cross: bool = False
-    force_fp32_for_softmax: bool = True
-    activation:Callable = jax.nn.swish
-    norm_groups:int=8
-    dtype: Optional[Dtype] = None
+    force_fp32_for_softmax: bool = True  # Passed to TransformerBlock
+    # Used in final convs if add_residualblock_output
+    activation: Callable = jax.nn.swish
+    norm_groups: int = 8
+    dtype: Optional[Dtype] = None  # e.g., jnp.float32 or jnp.bfloat16
     precision: PrecisionLike = None
     add_residualblock_output: bool = False
-    norm_inputs: bool = False
-    explicitly_add_residual: bool = True
-    norm_epsilon: float = 1e-4  # Added epsilon parameter, increased default
-    use_hilbert: bool = False  # Toggle Hilbert patch reorder
+    norm_inputs: bool = False  # Passed to TransformerBlock
+    explicitly_add_residual: bool = True  # Passed to TransformerBlock
+    norm_epsilon: float = 1e-5  # Adjusted default
+    use_hilbert: bool = False  # Toggle Hilbert patch reorder
+    use_remat: bool = False  # Add flag to use remat

     def setup(self):
+        assert self.num_layers % 2 == 0, "num_layers must be even for U-Net structure"
+        half_layers = self.num_layers // 2
+
+        # --- Norm Layer ---
         if self.norm_groups > 0:
-            self.norm = partial(nn.GroupNorm, self.norm_groups, epsilon=self.norm_epsilon)
+            print(f"Warning: norm_groups > 0 not fully supported with standard LayerNorm fallback in UViT setup. Using LayerNorm.")
+            self.norm_factory = partial(
+                nn.LayerNorm, epsilon=self.norm_epsilon, dtype=self.dtype)
         else:
-            self.norm = partial(nn.RMSNorm, epsilon=self.norm_epsilon)
-
+            self.norm_factory = partial(
+                nn.LayerNorm, epsilon=self.norm_epsilon, dtype=self.dtype)
+
+        # --- Input Path ---
+        self.patch_embed = PatchEmbedding(
+            patch_size=self.patch_size,
+            embedding_dim=self.emb_features,
+            dtype=self.dtype,
+            precision=self.precision,
+            name="patch_embed"
+        )
+        if self.use_hilbert:
+            self.hilbert_proj = nn.Dense(
+                features=self.emb_features,
+                dtype=self.dtype,
+                precision=self.precision,
+                name="hilbert_projection"
+            )
+
+        max_patches = (512 // self.patch_size)**2
+        self.pos_encoding = self.param('pos_encoding',
+                                       jax.nn.initializers.normal(stddev=0.02),
+                                       (1, max_patches, self.emb_features))
+
+        # --- Conditioning ---
+        self.time_embed = nn.Sequential([
+            FourierEmbedding(features=self.emb_features),
+            TimeProjection(features=self.emb_features)
+        ], name="time_embed")
+
+        self.text_proj = nn.DenseGeneral(
+            features=self.emb_features,
+            dtype=self.dtype,
+            precision=self.precision,
+            name="text_proj"
+        )
+
+        # --- Transformer Blocks ---
+        BlockClass = TransformerBlock
+
+        self.down_blocks = [
+            BlockClass(
+                heads=self.num_heads,
+                dim_head=self.emb_features // self.num_heads,
+                dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
+                use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross,
+                force_fp32_for_softmax=self.force_fp32_for_softmax,
+                only_pure_attention=False, norm_inputs=self.norm_inputs,
+                explicitly_add_residual=self.explicitly_add_residual,
+                norm_epsilon=self.norm_epsilon,
+                name=f"down_block_{i}"
+            ) for i in range(half_layers)
+        ]
+
+        self.mid_block = BlockClass(
+            heads=self.num_heads,
+            dim_head=self.emb_features // self.num_heads,
+            dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
+            use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross,
+            force_fp32_for_softmax=self.force_fp32_for_softmax,
+            only_pure_attention=False, norm_inputs=self.norm_inputs,
+            explicitly_add_residual=self.explicitly_add_residual,
+            norm_epsilon=self.norm_epsilon,
+            name="mid_block"
+        )
+
+        self.up_dense = [
+            nn.DenseGeneral(
+                features=self.emb_features,
+                dtype=self.dtype,
+                precision=self.precision,
+                name=f"up_dense_{i}"
+            ) for i in range(half_layers)
+        ]
+        self.up_blocks = [
+            BlockClass(
+                heads=self.num_heads,
+                dim_head=self.emb_features // self.num_heads,
+                dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
+                use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross,
+                force_fp32_for_softmax=self.force_fp32_for_softmax,
+                only_pure_attention=False, norm_inputs=self.norm_inputs,
+                explicitly_add_residual=self.explicitly_add_residual,
+                norm_epsilon=self.norm_epsilon,
+                name=f"up_block_{i}"
+            ) for i in range(half_layers)
+        ]
+
+        # --- Output Path ---
+        self.final_norm = self.norm_factory(name="final_norm")
+
+        patch_dim = self.patch_size ** 2 * self.output_channels
+        self.final_proj = nn.Dense(
+            features=patch_dim,
+            dtype=self.dtype,
+            precision=self.precision,
+            kernel_init=nn.initializers.zeros,
+            name="final_proj"
+        )
+
+        if self.add_residualblock_output:
+            self.final_conv1 = ConvLayer(
+                "conv",
+                features=64, kernel_size=(3, 3), strides=(1, 1),
+                dtype=self.dtype, precision=self.precision, name="final_conv1"
+            )
+            self.final_norm_conv = self.norm_factory(
+                name="final_norm_conv")
+            self.final_conv2 = ConvLayer(
+                "conv",
+                features=self.output_channels, kernel_size=(3, 3), strides=(1, 1),
+                dtype=jnp.float32,
+                precision=self.precision, name="final_conv2"
+            )
+        else:
+            self.final_conv_direct = ConvLayer(
+                "conv",
+                features=self.output_channels, kernel_size=(1, 1), strides=(1, 1),
+                dtype=jnp.float32,
+                precision=self.precision, name="final_conv_direct"
+            )
+
     @nn.compact
     def __call__(self, x, temb, textcontext=None):
-        # Time embedding
-        temb = FourierEmbedding(features=self.emb_features)(temb)
-        temb = TimeProjection(features=self.emb_features)(temb)
-
         original_img = x
         B, H, W, C = original_img.shape
         H_P = H // self.patch_size
         W_P = W // self.patch_size
+        num_patches = H_P * W_P
+        assert H % self.patch_size == 0 and W % self.patch_size == 0, "Image dimensions must be divisible by patch size"

-        # Patch embedding
-        x = PatchEmbedding(patch_size=self.patch_size, embedding_dim=self.emb_features,
-                           dtype=self.dtype, precision=self.precision)(x)
-        num_patches = x.shape[1]
-
-        # Optional Hilbert reorder
+        hilbert_inv_idx = None
         if self.use_hilbert:
+            patches_raw, hilbert_inv_idx_calc = hilbert_patchify(
+                x, self.patch_size)
+            x_patches = self.hilbert_proj(patches_raw)
             idx = hilbert_indices(H_P, W_P)
-            inv_idx = inverse_permutation(idx)
-            x = x[:, idx, :]
-
-        context_emb = nn.DenseGeneral(features=self.emb_features,
-                                      dtype=self.dtype, precision=self.precision)(textcontext)
-        num_text_tokens = textcontext.shape[1]
-
-        # Add time embedding
-        temb = jnp.expand_dims(temb, axis=1)
-        x = jnp.concatenate([x, temb, context_emb], axis=1)
-
-        # Add positional encoding
-        x = PositionalEncoding(max_len=x.shape[1], embedding_dim=self.emb_features)(x)
-
+            hilbert_inv_idx = inverse_permutation(
+                idx, total_size=num_patches)
+            x_patches = x_patches[:, idx, :]
+        else:
+            x_patches = self.patch_embed(x)
+
+        assert num_patches <= self.pos_encoding.shape[
+            1], f"Number of patches {num_patches} exceeds max_len {self.pos_encoding.shape[1]} in positional encoding"
+        x_patches = x_patches + self.pos_encoding[:, :num_patches, :]
+
+        time_token = self.time_embed(temb.astype(
+            jnp.float32))
+        time_token = jnp.expand_dims(time_token.astype(
+            self.dtype), axis=1)
+
+        if textcontext is not None:
+            text_tokens = self.text_proj(
+                textcontext.astype(self.dtype))
+            num_text_tokens = text_tokens.shape[1]
+            x = jnp.concatenate([x_patches, time_token, text_tokens], axis=1)
+        else:
+            num_text_tokens = 0
+            x = jnp.concatenate([x_patches, time_token], axis=1)

         skips = []
-        # In blocks
         for i in range(self.num_layers // 2):
-            x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
-                                 dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
-                                 use_flash_attention=self.use_flash_attention, use_self_and_cross=False, force_fp32_for_softmax=self.force_fp32_for_softmax,
-                                 only_pure_attention=False,
-                                 norm_inputs=self.norm_inputs,
-                                 explicitly_add_residual=self.explicitly_add_residual,
-                                 norm_epsilon=self.norm_epsilon,  # Pass epsilon
-                                 )(x)
+            x = self.down_blocks[i](x)
             skips.append(x)
-
-        # Middle block
-        x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
-                             dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
-                             use_flash_attention=self.use_flash_attention, use_self_and_cross=False, force_fp32_for_softmax=self.force_fp32_for_softmax,
-                             only_pure_attention=False,
-                             norm_inputs=self.norm_inputs,
-                             explicitly_add_residual=self.explicitly_add_residual,
-                             norm_epsilon=self.norm_epsilon,  # Pass epsilon
-                             )(x)
-
-        # Out blocks
+
+        x = self.mid_block(x)
+
         for i in range(self.num_layers // 2):
-            x = jnp.concatenate([x, skips.pop()], axis=-1)
-            x = nn.DenseGeneral(features=self.emb_features,
-                                dtype=self.dtype, precision=self.precision)(x)
-            x = TransformerBlock(heads=self.num_heads, dim_head=self.emb_features // self.num_heads,
-                                 dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
-                                 use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross, force_fp32_for_softmax=self.force_fp32_for_softmax,
-                                 only_pure_attention=False,
-                                 norm_inputs=self.norm_inputs,
-                                 explicitly_add_residual=self.explicitly_add_residual,
-                                 norm_epsilon=self.norm_epsilon,  # Pass epsilon
-                                 )(x)
-
-        x = self.norm()(x)  # Uses norm_epsilon defined in setup
-
-        patch_dim = self.patch_size ** 2 * self.output_channels
-        x = nn.Dense(features=patch_dim, dtype=self.dtype, precision=self.precision)(x)
-        # If Hilbert, restore original patch order
+            skip_conn = skips.pop()
+            x = jnp.concatenate([x, skip_conn], axis=-1)
+            x = self.up_dense[i](x)
+            x = self.up_blocks[i](x)
+
+        x = self.final_norm(x)
+
+        x_patches_out = x[:, :num_patches, :]
+
+        x_patches_out = self.final_proj(x_patches_out)
+
         if self.use_hilbert:
-            x = x[:, inv_idx, :]
-        # Extract only the image patch tokens (first num_patches tokens)
-        x = x[:, :num_patches, :]
-        x = unpatchify(x, channels=self.output_channels)
-
+            assert hilbert_inv_idx is not None, "Hilbert inverse index missing"
+            x_image = hilbert_unpatchify(
+                x_patches_out, hilbert_inv_idx, self.patch_size, H, W, self.output_channels)
+        else:
+            x_image = unpatchify(x_patches_out, channels=self.output_channels)
+
         if self.add_residualblock_output:
-            # Concatenate the original image
-            x = jnp.concatenate([original_img, x], axis=-1)
-
-            x = ConvLayer(
-                "conv",
-                features=64,
-                kernel_size=(3, 3),
-                strides=(1, 1),
-                # activation=jax.nn.mish
+            x_image = jnp.concatenate(
+                [original_img.astype(self.dtype), x_image], axis=-1)
+
+            x_image = self.final_conv1(x_image)
+            x_image = self.final_norm_conv(x_image)
+            x_image = self.activation(x_image)
+            x_image = self.final_conv2(x_image)
+        else:
+            pass
+
+        return x_image
+
+
+# --- Simple U-DiT ---
+
+class SimpleUDiT(nn.Module):
+    """
+    A Simple U-Net Diffusion Transformer (U-DiT) implementation.
+    Combines the U-Net structure with DiT blocks using RoPE and AdaLN-Zero conditioning.
+    Based on SimpleDiT and standard U-Net principles.
+    """
+    output_channels: int = 3
+    patch_size: int = 16
+    emb_features: int = 768
+    num_layers: int = 12  # Should be even for U-Net structure
+    num_heads: int = 12
+    mlp_ratio: int = 4
+    dropout_rate: float = 0.0  # Typically 0 for diffusion
+    dtype: Optional[Dtype] = None  # e.g., jnp.float32 or jnp.bfloat16
+    precision: PrecisionLike = None
+    use_flash_attention: bool = False  # Passed to DiTBlock -> RoPEAttention
+    force_fp32_for_softmax: bool = True  # Passed to DiTBlock -> RoPEAttention
+    norm_epsilon: float = 1e-5
+    learn_sigma: bool = False
+    use_hilbert: bool = False
+    norm_groups: int = 0
+    activation: Callable = jax.nn.swish
+
+    def setup(self):
+        assert self.num_layers % 2 == 0, "num_layers must be even for U-Net structure"
+        half_layers = self.num_layers // 2
+
+        self.patch_embed = PatchEmbedding(
+            patch_size=self.patch_size,
+            embedding_dim=self.emb_features,
+            dtype=self.dtype,
+            precision=self.precision,
+            name="patch_embed"
+        )
+        if self.use_hilbert:
+            self.hilbert_proj = nn.Dense(
+                features=self.emb_features,
                 dtype=self.dtype,
-                precision=self.precision
-            )(x)
-
-            x = self.norm()(x)
-            x = self.activation(x)
-
-            x = ConvLayer(
-                "conv",
-                features=self.output_channels,
-                kernel_size=(3, 3),
-                strides=(1, 1),
-                # activation=jax.nn.mish
+                precision=self.precision,
+                name="hilbert_projection"
+            )
+
+        self.time_embed = nn.Sequential([
+            FourierEmbedding(features=self.emb_features),
+            TimeProjection(features=self.emb_features * self.mlp_ratio),
+            nn.Dense(features=self.emb_features, dtype=self.dtype, precision=self.precision)
+        ], name="time_embed")
+
+        self.text_proj = nn.Dense(
+            features=self.emb_features,
+            dtype=self.dtype,
+            precision=self.precision,
+            name="text_proj"
+        )
+
+        max_patches = (512 // self.patch_size)**2
+        self.rope = RotaryEmbedding(
+            dim=self.emb_features // self.num_heads,
+            max_seq_len=max_patches,
+            dtype=self.dtype,
+            name="rope_emb"
+        )
+
+        self.down_blocks = [
+            DiTBlock(
+                features=self.emb_features,
+                num_heads=self.num_heads,
+                mlp_ratio=self.mlp_ratio,
+                dropout_rate=self.dropout_rate,
+                dtype=self.dtype,
+                precision=self.precision,
+                use_flash_attention=self.use_flash_attention,
+                force_fp32_for_softmax=self.force_fp32_for_softmax,
+                norm_epsilon=self.norm_epsilon,
+                rope_emb=self.rope,
+                name=f"down_block_{i}"
+            ) for i in range(half_layers)
+        ]
+
+        self.mid_block = DiTBlock(
+            features=self.emb_features,
+            num_heads=self.num_heads,
+            mlp_ratio=self.mlp_ratio,
+            dropout_rate=self.dropout_rate,
             dtype=self.dtype,
-            precision=self.precision
-            )(x)
-        return x
+            precision=self.precision,
+            use_flash_attention=self.use_flash_attention,
+            force_fp32_for_softmax=self.force_fp32_for_softmax,
+            norm_epsilon=self.norm_epsilon,
+            rope_emb=self.rope,
+            name="mid_block"
+        )
+
+        self.up_dense = [
+            nn.DenseGeneral(
+                features=self.emb_features,
+                dtype=self.dtype,
+                precision=self.precision,
+                name=f"up_dense_{i}"
+            ) for i in range(half_layers)
+        ]
+        self.up_blocks = [
+            DiTBlock(
+                features=self.emb_features,
+                num_heads=self.num_heads,
+                mlp_ratio=self.mlp_ratio,
+                dropout_rate=self.dropout_rate,
+                dtype=self.dtype,
+                precision=self.precision,
+                use_flash_attention=self.use_flash_attention,
+                force_fp32_for_softmax=self.force_fp32_for_softmax,
+                norm_epsilon=self.norm_epsilon,
+                rope_emb=self.rope,
+                name=f"up_block_{i}"
+            ) for i in range(half_layers)
+        ]
+
+        self.final_norm = nn.LayerNorm(
+            epsilon=self.norm_epsilon, dtype=self.dtype, name="final_norm")
+
+        output_dim = self.patch_size * self.patch_size * self.output_channels
+        if self.learn_sigma:
+            output_dim *= 2
+
+        self.final_proj = nn.Dense(
+            features=output_dim,
+            dtype=jnp.float32,
+            precision=self.precision,
+            kernel_init=nn.initializers.zeros,
+            name="final_proj"
+        )
+
+    @nn.compact
+    def __call__(self, x, temb, textcontext=None):
+        B, H, W, C = x.shape
+        H_P = H // self.patch_size
+        W_P = W // self.patch_size
+        num_patches = H_P * W_P
+        assert H % self.patch_size == 0 and W % self.patch_size == 0, "Image dimensions must be divisible by patch size"
+
+        x = x.astype(self.dtype)
+
+        hilbert_inv_idx = None
+        if self.use_hilbert:
+            patches_raw, _ = hilbert_patchify(x, self.patch_size)
+            x_seq = self.hilbert_proj(patches_raw)
+            idx = hilbert_indices(H_P, W_P)
+            hilbert_inv_idx = inverse_permutation(idx, total_size=num_patches)
+        else:
+            x_seq = self.patch_embed(x)
+
+        t_emb = self.time_embed(temb.astype(jnp.float32))
+        t_emb = t_emb.astype(self.dtype)
+
+        cond_emb = t_emb
+        if textcontext is not None:
+            text_emb = self.text_proj(textcontext.astype(self.dtype))
+            if text_emb.ndim == 3:
+                text_emb = jnp.mean(text_emb, axis=1)
+            cond_emb = cond_emb + text_emb
+
+        skips = []
+        for i in range(self.num_layers // 2):
+            x_seq = self.down_blocks[i](x_seq, conditioning=cond_emb, freqs_cis=None)
+            skips.append(x_seq)
+
+        x_seq = self.mid_block(x_seq, conditioning=cond_emb, freqs_cis=None)
+
+        for i in range(self.num_layers // 2):
+            skip_conn = skips.pop()
+            x_seq = jnp.concatenate([x_seq, skip_conn], axis=-1)
+            x_seq = self.up_dense[i](x_seq)
+            x_seq = self.up_blocks[i](x_seq, conditioning=cond_emb, freqs_cis=None)
+
+        x_out = self.final_norm(x_seq)
+        x_out = self.final_proj(x_out)
+
+        if self.use_hilbert:
+            assert hilbert_inv_idx is not None, "Hilbert inverse index missing"
+            if self.learn_sigma:
+                x_mean, x_logvar = jnp.split(x_out, 2, axis=-1)
+                x_image = hilbert_unpatchify(x_mean, hilbert_inv_idx, self.patch_size, H, W, self.output_channels)
+            else:
+                x_image = hilbert_unpatchify(x_out, hilbert_inv_idx, self.patch_size, H, W, self.output_channels)
+        else:
+            if self.learn_sigma:
+                x_mean, x_logvar = jnp.split(x_out, 2, axis=-1)
+                x_image = unpatchify(x_mean, channels=self.output_channels)
+            else:
+                x_image = unpatchify(x_out, channels=self.output_channels)
+
+        return x_image.astype(jnp.float32)
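
For orientation, a minimal forward-pass sketch of the newly added SimpleUDiT follows. The import path is assumed (the hunk shows this file sits next to flaxdiff.models.simple_unet but does not name it), and the shapes of temb (per-sample diffusion timesteps) and textcontext (text-token embeddings) are illustrative rather than taken from the package.

import jax
import jax.numpy as jnp
from flaxdiff.models.simple_vit import SimpleUDiT  # assumed module path

# Defaults mirror the field values shown in the diff; num_layers must stay even
# so it can split into equal down/up halves.
model = SimpleUDiT(
    output_channels=3,
    patch_size=16,
    emb_features=768,
    num_layers=12,
    num_heads=12,
    use_hilbert=True,   # route patches through the Hilbert-curve ordering
    dtype=jnp.bfloat16,
)

x = jnp.zeros((2, 256, 256, 3))        # image batch (B, H, W, C); H and W divisible by patch_size
temb = jnp.array([0.1, 0.7])           # one diffusion timestep per sample
textcontext = jnp.zeros((2, 77, 768))  # text tokens (B, T, D); mean-pooled internally into the conditioning vector

# dropout_rate defaults to 0.0, so only a 'params' RNG is passed here.
variables = model.init(jax.random.PRNGKey(0), x, temb, textcontext)
out = model.apply(variables, x, temb, textcontext)
print(out.shape, out.dtype)            # expected: (2, 256, 256, 3) float32

The refactored UViT keeps the same (x, temb, textcontext) call signature, so the same driver code applies with UViT substituted for SimpleUDiT.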