flaxdiff 0.2.8__py3-none-any.whl → 0.2.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/data/dataloaders.py +11 -19
- flaxdiff/data/dataset_map.py +2 -1
- flaxdiff/data/sources/images.py +29 -14
- flaxdiff/inference/utils.py +7 -1
- flaxdiff/models/simple_dit.py +1 -202
- flaxdiff/models/simple_mmdit.py +1 -132
- flaxdiff/models/simple_vit.py +217 -118
- flaxdiff/models/vit_common.py +262 -0
- flaxdiff/trainer/general_diffusion_trainer.py +2 -1
- {flaxdiff-0.2.8.dist-info → flaxdiff-0.2.9.dist-info}/METADATA +1 -1
- {flaxdiff-0.2.8.dist-info → flaxdiff-0.2.9.dist-info}/RECORD +13 -12
- {flaxdiff-0.2.8.dist-info → flaxdiff-0.2.9.dist-info}/WHEEL +0 -0
- {flaxdiff-0.2.8.dist-info → flaxdiff-0.2.9.dist-info}/top_level.txt +0 -0
flaxdiff/models/simple_vit.py
CHANGED
@@ -11,48 +11,8 @@ import einops
 from flax.typing import Dtype, PrecisionLike
 from functools import partial
 from .hilbert import hilbert_indices, inverse_permutation, hilbert_patchify, hilbert_unpatchify
-
-
-def unpatchify(x, channels=3):
-    patch_size = int((x.shape[2] // channels) ** 0.5)
-    h = w = int(x.shape[1] ** .5)
-    assert h * w == x.shape[1] and patch_size ** 2 * \
-        channels == x.shape[2], f"Invalid shape: {x.shape}, should be {h*w}, {patch_size**2*channels}"
-    x = einops.rearrange(
-        x, 'B (h w) (p1 p2 C) -> B (h p1) (w p2) C', h=h, p1=patch_size, p2=patch_size)
-    return x
-
-
-class PatchEmbedding(nn.Module):
-    patch_size: int
-    embedding_dim: int
-    dtype: Any = jnp.float32
-    precision: Any = jax.lax.Precision.HIGH
-
-    @nn.compact
-    def __call__(self, x):
-        batch, height, width, channels = x.shape
-        assert height % self.patch_size == 0 and width % self.patch_size == 0, "Image dimensions must be divisible by patch size"
-
-        x = nn.Conv(features=self.embedding_dim,
-                    kernel_size=(self.patch_size, self.patch_size),
-                    strides=(self.patch_size, self.patch_size),
-                    dtype=self.dtype,
-                    precision=self.precision)(x)
-        x = jnp.reshape(x, (batch, -1, self.embedding_dim))
-        return x
-
-
-class PositionalEncoding(nn.Module):
-    max_len: int
-    embedding_dim: int
-
-    @nn.compact
-    def __call__(self, x):
-        pe = self.param('pos_encoding',
-                        jax.nn.initializers.zeros,
-                        (1, self.max_len, self.embedding_dim))
-        return x + pe[:, :x.shape[1], :]
+from .vit_common import _rotate_half, unpatchify, PatchEmbedding, apply_rotary_embedding, RotaryEmbedding, RoPEAttention, AdaLNZero, AdaLNParams
+from .simple_dit import DiTBlock
 
 
 class UViT(nn.Module):
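Note: the helpers removed above (unpatchify, PatchEmbedding, the positional utilities) now come from flaxdiff.models.vit_common (added in this release as vit_common.py). A minimal shape round-trip, assuming the moved definitions keep the signatures shown in the removed lines:

import jax
import jax.numpy as jnp
from flaxdiff.models.vit_common import PatchEmbedding, unpatchify

# Assumption: the moved helpers keep the signatures visible in the removed code above.
patch_size, channels, emb = 16, 3, 768
x = jnp.ones((2, 64, 64, channels))                      # (B, H, W, C)
embed = PatchEmbedding(patch_size=patch_size, embedding_dim=emb)
tokens, variables = embed.init_with_output(jax.random.PRNGKey(0), x)
print(tokens.shape)                                      # (2, 16, 768): (64 // 16) ** 2 = 16 patches

# unpatchify expects patch_size**2 * channels features per token
patches = jnp.ones((2, 16, patch_size * patch_size * channels))
print(unpatchify(patches, channels=channels).shape)      # (2, 64, 64, 3)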
@@ -85,16 +45,10 @@ class UViT(nn.Module):
 
         # --- Norm Layer ---
         if self.norm_groups > 0:
-            # GroupNorm needs features arg, which varies. Define partial here, apply in __call__?
-            # Or maybe use LayerNorm/RMSNorm consistently? Let's use LayerNorm for simplicity here.
-            # If GroupNorm is essential, it needs careful handling with changing feature sizes.
-            # self.norm_factory = partial(nn.GroupNorm, self.norm_groups, epsilon=self.norm_epsilon, dtype=self.dtype)
             print(f"Warning: norm_groups > 0 not fully supported with standard LayerNorm fallback in UViT setup. Using LayerNorm.")
             self.norm_factory = partial(
                 nn.LayerNorm, epsilon=self.norm_epsilon, dtype=self.dtype)
         else:
-            # Use LayerNorm or RMSNorm for sequence normalization
-            # self.norm_factory = partial(nn.RMSNorm, epsilon=self.norm_epsilon, dtype=self.dtype)
             self.norm_factory = partial(
                 nn.LayerNorm, epsilon=self.norm_epsilon, dtype=self.dtype)
 
@@ -107,7 +61,6 @@ class UViT(nn.Module):
             name="patch_embed"
         )
         if self.use_hilbert:
-            # Projection layer needed after raw Hilbert patches
             self.hilbert_proj = nn.Dense(
                 features=self.emb_features,
                 dtype=self.dtype,
@@ -115,13 +68,8 @@ class UViT(nn.Module):
                 name="hilbert_projection"
             )
 
-        # Positional encoding (learned) - applied only to patch tokens
-        # Max length needs to accommodate max possible patches
-        # Example: 512x512 image, patch 16 -> (512/16)^2 = 1024 patches
-        # Estimate max patches, adjust if needed
         max_patches = (512 // self.patch_size)**2
         self.pos_encoding = self.param('pos_encoding',
-                                       # Standard init for ViT pos embeds
                                        jax.nn.initializers.normal(stddev=0.02),
                                        (1, max_patches, self.emb_features))
 
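The max_patches bound kept here follows the example in the removed comments (512x512 input with patch 16 gives 1024 patches); __call__ later adds only the first num_patches positions. A quick check of the arithmetic, with a 256x256 input as an illustrative assumption:

# Worked check of the positional-encoding bound above (patch_size=16 as an example).
patch_size = 16
max_patches = (512 // patch_size) ** 2                    # 1024 learned positions, as in setup()
H = W = 256                                               # illustrative input size (assumption)
num_patches = (H // patch_size) * (W // patch_size)       # 256 patches for a 256x256 input
assert num_patches <= max_patches                         # mirrors the assert in __call__
# only pos_encoding[:, :num_patches, :] of the (1, max_patches, emb_features) table is added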
@@ -131,7 +79,6 @@ class UViT(nn.Module):
             TimeProjection(features=self.emb_features)
         ], name="time_embed")
 
-        # Text projection
         self.text_proj = nn.DenseGeneral(
             features=self.emb_features,
             dtype=self.dtype,
@@ -169,7 +116,7 @@ class UViT(nn.Module):
         )
 
         self.up_dense = [
-            nn.DenseGeneral(
+            nn.DenseGeneral(
                 features=self.emb_features,
                 dtype=self.dtype,
                 precision=self.precision,
@@ -191,157 +138,309 @@ class UViT(nn.Module):
         ]
 
         # --- Output Path ---
-        self.final_norm = self.norm_factory(name="final_norm")
+        self.final_norm = self.norm_factory(name="final_norm")
 
         patch_dim = self.patch_size ** 2 * self.output_channels
         self.final_proj = nn.Dense(
             features=patch_dim,
-            dtype=self.dtype,
+            dtype=self.dtype,
             precision=self.precision,
-            kernel_init=nn.initializers.zeros,
+            kernel_init=nn.initializers.zeros,
             name="final_proj"
         )
 
         if self.add_residualblock_output:
-            # Define these layers only if needed
             self.final_conv1 = ConvLayer(
                 "conv",
                 features=64, kernel_size=(3, 3), strides=(1, 1),
                 dtype=self.dtype, precision=self.precision, name="final_conv1"
             )
             self.final_norm_conv = self.norm_factory(
-                name="final_norm_conv")
+                name="final_norm_conv")
             self.final_conv2 = ConvLayer(
                 "conv",
                 features=self.output_channels, kernel_size=(3, 3), strides=(1, 1),
-                dtype=jnp.float32,
+                dtype=jnp.float32,
                 precision=self.precision, name="final_conv2"
             )
         else:
-            # Final conv to map features to output channels directly after unpatchify
             self.final_conv_direct = ConvLayer(
                 "conv",
-                # Use 1x1 conv
                 features=self.output_channels, kernel_size=(1, 1), strides=(1, 1),
-                dtype=jnp.float32,
+                dtype=jnp.float32,
                 precision=self.precision, name="final_conv_direct"
             )
 
     @nn.compact
     def __call__(self, x, temb, textcontext=None):
-        original_img = x
+        original_img = x
         B, H, W, C = original_img.shape
         H_P = H // self.patch_size
         W_P = W // self.patch_size
         num_patches = H_P * W_P
         assert H % self.patch_size == 0 and W % self.patch_size == 0, "Image dimensions must be divisible by patch size"
 
-        # --- Patch Embedding ---
         hilbert_inv_idx = None
         if self.use_hilbert:
-            # Use hilbert_patchify to get raw patches and inverse index
             patches_raw, hilbert_inv_idx_calc = hilbert_patchify(
-                x, self.patch_size)
-            # Project raw patches
-            # Shape [B, S, emb_features]
+                x, self.patch_size)
             x_patches = self.hilbert_proj(patches_raw)
-            # Calculate inverse permutation (needs total_size)
             idx = hilbert_indices(H_P, W_P)
             hilbert_inv_idx = inverse_permutation(
-                idx, total_size=num_patches)
-            # Apply Hilbert reordering *after* projection
+                idx, total_size=num_patches)
             x_patches = x_patches[:, idx, :]
         else:
-            # Standard patch embedding
-            # Shape: [B, num_patches, emb_features]
             x_patches = self.patch_embed(x)
 
-        # --- Positional Encoding ---
-        # Add positional encoding only to patch tokens
         assert num_patches <= self.pos_encoding.shape[
             1], f"Number of patches {num_patches} exceeds max_len {self.pos_encoding.shape[1]} in positional encoding"
         x_patches = x_patches + self.pos_encoding[:, :num_patches, :]
 
-        # --- Conditioning Tokens ---
-        # Time embedding: [B, D] -> [B, 1, D]
         time_token = self.time_embed(temb.astype(
-            jnp.float32))
+            jnp.float32))
         time_token = jnp.expand_dims(time_token.astype(
-            self.dtype), axis=1)
+            self.dtype), axis=1)
 
-        # Text embedding: [B, S_text, D_in] -> [B, S_text, D]
         if textcontext is not None:
             text_tokens = self.text_proj(
-                textcontext.astype(self.dtype))
+                textcontext.astype(self.dtype))
             num_text_tokens = text_tokens.shape[1]
-            # Concatenate: [Patches+Pos, Time, Text]
             x = jnp.concatenate([x_patches, time_token, text_tokens], axis=1)
         else:
-            # Concatenate: [Patches+Pos, Time]
             num_text_tokens = 0
             x = jnp.concatenate([x_patches, time_token], axis=1)
 
-        # --- U-Net Transformer ---
         skips = []
-        # Down blocks (Encoder)
         for i in range(self.num_layers // 2):
-            x = self.down_blocks[i](x)
-            skips.append(x)
+            x = self.down_blocks[i](x)
+            skips.append(x)
 
-        # Middle block
         x = self.mid_block(x)
 
-        # Up blocks (Decoder)
         for i in range(self.num_layers // 2):
             skip_conn = skips.pop()
-            # Concatenate along feature dimension
             x = jnp.concatenate([x, skip_conn], axis=-1)
-            # Project back to emb_features
             x = self.up_dense[i](x)
-            # Apply transformer block
             x = self.up_blocks[i](x)
 
-
-        # Normalize before final projection
-        x = self.final_norm(x)  # Apply norm factory instance
+        x = self.final_norm(x)
 
-        # Extract only the image patch tokens (first num_patches tokens)
-        # Conditioning tokens (time, text) are discarded here
         x_patches_out = x[:, :num_patches, :]
 
-        # Project to patch pixel dimensions
-        # Shape: [B, num_patches, patch_dim]
         x_patches_out = self.final_proj(x_patches_out)
 
-        # --- Unpatchify ---
         if self.use_hilbert:
-            # Restore Hilbert order to row-major order and then to image
             assert hilbert_inv_idx is not None, "Hilbert inverse index missing"
             x_image = hilbert_unpatchify(
                 x_patches_out, hilbert_inv_idx, self.patch_size, H, W, self.output_channels)
         else:
-            # Standard unpatchify
-            # Shape: [B, H, W, C_out]
             x_image = unpatchify(x_patches_out, channels=self.output_channels)
 
-        # --- Final Convolutions ---
         if self.add_residualblock_output:
-            # Concatenate the original image (ensure dtype matches)
             x_image = jnp.concatenate(
                 [original_img.astype(self.dtype), x_image], axis=-1)
 
             x_image = self.final_conv1(x_image)
-            # Apply norm factory instance
             x_image = self.final_norm_conv(x_image)
             x_image = self.activation(x_image)
-            x_image = self.final_conv2(x_image)
+            x_image = self.final_conv2(x_image)
         else:
-
-
-            # If unpatchify output channels == self.output_channels, this might be redundant
-            # Let's assume unpatchify gives correct channels, but ensure float32
-            # x_image = self.final_conv_direct(x_image) # Use 1x1 conv if needed
-            pass  # Assuming unpatchify output is correct
-
-        # Ensure final output is float32
+            pass
+
         return x_image
+
+
+# --- Simple U-DiT ---
+
+class SimpleUDiT(nn.Module):
+    """
+    A Simple U-Net Diffusion Transformer (U-DiT) implementation.
+    Combines the U-Net structure with DiT blocks using RoPE and AdaLN-Zero conditioning.
+    Based on SimpleDiT and standard U-Net principles.
+    """
+    output_channels: int = 3
+    patch_size: int = 16
+    emb_features: int = 768
+    num_layers: int = 12  # Should be even for U-Net structure
+    num_heads: int = 12
+    mlp_ratio: int = 4
+    dropout_rate: float = 0.0  # Typically 0 for diffusion
+    dtype: Optional[Dtype] = None  # e.g., jnp.float32 or jnp.bfloat16
+    precision: PrecisionLike = None
+    use_flash_attention: bool = False  # Passed to DiTBlock -> RoPEAttention
+    force_fp32_for_softmax: bool = True  # Passed to DiTBlock -> RoPEAttention
+    norm_epsilon: float = 1e-5
+    learn_sigma: bool = False
+    use_hilbert: bool = False
+    norm_groups: int = 0
+    activation: Callable = jax.nn.swish
+
+    def setup(self):
+        assert self.num_layers % 2 == 0, "num_layers must be even for U-Net structure"
+        half_layers = self.num_layers // 2
+
+        self.patch_embed = PatchEmbedding(
+            patch_size=self.patch_size,
+            embedding_dim=self.emb_features,
+            dtype=self.dtype,
+            precision=self.precision,
+            name="patch_embed"
+        )
+        if self.use_hilbert:
+            self.hilbert_proj = nn.Dense(
+                features=self.emb_features,
+                dtype=self.dtype,
+                precision=self.precision,
+                name="hilbert_projection"
+            )
+
+        self.time_embed = nn.Sequential([
+            FourierEmbedding(features=self.emb_features),
+            TimeProjection(features=self.emb_features * self.mlp_ratio),
+            nn.Dense(features=self.emb_features, dtype=self.dtype, precision=self.precision)
+        ], name="time_embed")
+
+        self.text_proj = nn.Dense(
+            features=self.emb_features,
+            dtype=self.dtype,
+            precision=self.precision,
+            name="text_proj"
+        )
+
+        max_patches = (512 // self.patch_size)**2
+        self.rope = RotaryEmbedding(
+            dim=self.emb_features // self.num_heads,
+            max_seq_len=max_patches,
+            dtype=self.dtype,
+            name="rope_emb"
+        )
+
+        self.down_blocks = [
+            DiTBlock(
+                features=self.emb_features,
+                num_heads=self.num_heads,
+                mlp_ratio=self.mlp_ratio,
+                dropout_rate=self.dropout_rate,
+                dtype=self.dtype,
+                precision=self.precision,
+                use_flash_attention=self.use_flash_attention,
+                force_fp32_for_softmax=self.force_fp32_for_softmax,
+                norm_epsilon=self.norm_epsilon,
+                rope_emb=self.rope,
+                name=f"down_block_{i}"
+            ) for i in range(half_layers)
+        ]
+
+        self.mid_block = DiTBlock(
+            features=self.emb_features,
+            num_heads=self.num_heads,
+            mlp_ratio=self.mlp_ratio,
+            dropout_rate=self.dropout_rate,
+            dtype=self.dtype,
+            precision=self.precision,
+            use_flash_attention=self.use_flash_attention,
+            force_fp32_for_softmax=self.force_fp32_for_softmax,
+            norm_epsilon=self.norm_epsilon,
+            rope_emb=self.rope,
+            name="mid_block"
+        )
+
+        self.up_dense = [
+            nn.DenseGeneral(
+                features=self.emb_features,
+                dtype=self.dtype,
+                precision=self.precision,
+                name=f"up_dense_{i}"
+            ) for i in range(half_layers)
+        ]
+        self.up_blocks = [
+            DiTBlock(
+                features=self.emb_features,
+                num_heads=self.num_heads,
+                mlp_ratio=self.mlp_ratio,
+                dropout_rate=self.dropout_rate,
+                dtype=self.dtype,
+                precision=self.precision,
+                use_flash_attention=self.use_flash_attention,
+                force_fp32_for_softmax=self.force_fp32_for_softmax,
+                norm_epsilon=self.norm_epsilon,
+                rope_emb=self.rope,
+                name=f"up_block_{i}"
+            ) for i in range(half_layers)
+        ]
+
+        self.final_norm = nn.LayerNorm(
+            epsilon=self.norm_epsilon, dtype=self.dtype, name="final_norm")
+
+        output_dim = self.patch_size * self.patch_size * self.output_channels
+        if self.learn_sigma:
+            output_dim *= 2
+
+        self.final_proj = nn.Dense(
+            features=output_dim,
+            dtype=jnp.float32,
+            precision=self.precision,
+            kernel_init=nn.initializers.zeros,
+            name="final_proj"
+        )
+
+    @nn.compact
+    def __call__(self, x, temb, textcontext=None):
+        B, H, W, C = x.shape
+        H_P = H // self.patch_size
+        W_P = W // self.patch_size
+        num_patches = H_P * W_P
+        assert H % self.patch_size == 0 and W % self.patch_size == 0, "Image dimensions must be divisible by patch size"
+
+        x = x.astype(self.dtype)
+
+        hilbert_inv_idx = None
+        if self.use_hilbert:
+            patches_raw, _ = hilbert_patchify(x, self.patch_size)
+            x_seq = self.hilbert_proj(patches_raw)
+            idx = hilbert_indices(H_P, W_P)
+            hilbert_inv_idx = inverse_permutation(idx, total_size=num_patches)
+        else:
+            x_seq = self.patch_embed(x)
+
+        t_emb = self.time_embed(temb.astype(jnp.float32))
+        t_emb = t_emb.astype(self.dtype)
+
+        cond_emb = t_emb
+        if textcontext is not None:
+            text_emb = self.text_proj(textcontext.astype(self.dtype))
+            if text_emb.ndim == 3:
+                text_emb = jnp.mean(text_emb, axis=1)
+            cond_emb = cond_emb + text_emb
+
+        skips = []
+        for i in range(self.num_layers // 2):
+            x_seq = self.down_blocks[i](x_seq, conditioning=cond_emb, freqs_cis=None)
+            skips.append(x_seq)
+
+        x_seq = self.mid_block(x_seq, conditioning=cond_emb, freqs_cis=None)
+
+        for i in range(self.num_layers // 2):
+            skip_conn = skips.pop()
+            x_seq = jnp.concatenate([x_seq, skip_conn], axis=-1)
+            x_seq = self.up_dense[i](x_seq)
+            x_seq = self.up_blocks[i](x_seq, conditioning=cond_emb, freqs_cis=None)
+
+        x_out = self.final_norm(x_seq)
+        x_out = self.final_proj(x_out)
+
+        if self.use_hilbert:
+            assert hilbert_inv_idx is not None, "Hilbert inverse index missing"
+            if self.learn_sigma:
+                x_mean, x_logvar = jnp.split(x_out, 2, axis=-1)
+                x_image = hilbert_unpatchify(x_mean, hilbert_inv_idx, self.patch_size, H, W, self.output_channels)
+            else:
+                x_image = hilbert_unpatchify(x_out, hilbert_inv_idx, self.patch_size, H, W, self.output_channels)
+        else:
+            if self.learn_sigma:
+                x_mean, x_logvar = jnp.split(x_out, 2, axis=-1)
+                x_image = unpatchify(x_mean, channels=self.output_channels)
+            else:
+                x_image = unpatchify(x_out, channels=self.output_channels)
+
+        return x_image.astype(jnp.float32)
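For orientation, a hedged usage sketch of the new SimpleUDiT added above. Field names come from the diff; the input shapes (256x256 images, a 77-token text context, one scalar timestep per sample) are illustrative assumptions, not part of the package's documented API.

import jax
import jax.numpy as jnp
from flaxdiff.models.simple_vit import SimpleUDiT

# Constructor fields as listed in the diff; values here are illustrative.
model = SimpleUDiT(
    output_channels=3,
    patch_size=16,
    emb_features=768,
    num_layers=12,            # must be even: split into down/up halves
    num_heads=12,
    dtype=jnp.float32,
)

rng = jax.random.PRNGKey(0)
x = jnp.ones((1, 256, 256, 3))            # noisy image batch (B, H, W, C)
temb = jnp.ones((1,))                     # diffusion timestep per sample (assumed shape)
textcontext = jnp.ones((1, 77, 768))      # optional text conditioning (assumed width)

params = model.init(rng, x, temb, textcontext)
out = model.apply(params, x, temb, textcontext)   # float32 output, same spatial shape as x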