flaxdiff 0.2.7__py3-none-any.whl → 0.2.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/data/dataloaders.py +23 -19
- flaxdiff/data/dataset_map.py +2 -1
- flaxdiff/data/sources/base.py +12 -0
- flaxdiff/data/sources/images.py +75 -3
- flaxdiff/data/sources/videos.py +5 -0
- flaxdiff/inference/utils.py +7 -1
- flaxdiff/models/common.py +1 -70
- flaxdiff/models/hilbert.py +617 -0
- flaxdiff/models/simple_dit.py +275 -0
- flaxdiff/models/simple_mmdit.py +730 -0
- flaxdiff/models/simple_vit.py +405 -145
- flaxdiff/models/vit_common.py +262 -0
- flaxdiff/trainer/general_diffusion_trainer.py +30 -10
- flaxdiff/trainer/simple_trainer.py +113 -19
- {flaxdiff-0.2.7.dist-info → flaxdiff-0.2.9.dist-info}/METADATA +1 -1
- {flaxdiff-0.2.7.dist-info → flaxdiff-0.2.9.dist-info}/RECORD +18 -15
- {flaxdiff-0.2.7.dist-info → flaxdiff-0.2.9.dist-info}/WHEEL +1 -1
- flaxdiff/models/better_uvit.py +0 -380
- {flaxdiff-0.2.7.dist-info → flaxdiff-0.2.9.dist-info}/top_level.txt +0 -0
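
The new `flaxdiff/models/hilbert.py` module backs the `use_hilbert` option added to `simple_vit.py` below: patch tokens are gathered into Hilbert-curve order before the transformer blocks (`x_patches[:, idx, :]`) and scattered back at the output via `inverse_permutation` / `hilbert_unpatchify`. The snippet below is only a generic sketch of that permute/inverse-permute pattern, not the package's `hilbert.py` implementation; the index values and shapes are invented for illustration.

```python
# Generic permute / inverse-permute pattern (illustrative only, not flaxdiff's hilbert.py).
import jax.numpy as jnp

def inverse_permutation(idx, total_size):
    # inv[idx[k]] = k, so gathering with inv undoes gathering with idx.
    return jnp.zeros((total_size,), dtype=jnp.int32).at[idx].set(
        jnp.arange(total_size, dtype=jnp.int32))

idx = jnp.array([2, 0, 3, 1])                 # stand-in for hilbert_indices(H_P, W_P)
tokens = jnp.arange(4 * 8).reshape(1, 4, 8)   # (B, num_patches, features)
reordered = tokens[:, idx, :]                 # curve order seen by the transformer
restored = reordered[:, inverse_permutation(idx, 4), :]
assert jnp.array_equal(restored, tokens)
```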
flaxdiff/models/simple_vit.py
CHANGED
@@ -10,177 +10,437 @@ from flaxdiff.models.simple_unet import FourierEmbedding, TimeProjection, ConvLayer
 import einops
 from flax.typing import Dtype, PrecisionLike
 from functools import partial
-from .
+from .hilbert import hilbert_indices, inverse_permutation, hilbert_patchify, hilbert_unpatchify
+from .vit_common import _rotate_half, unpatchify, PatchEmbedding, apply_rotary_embedding, RotaryEmbedding, RoPEAttention, AdaLNZero, AdaLNParams
+from .simple_dit import DiTBlock
 
-def unpatchify(x, channels=3):
-    patch_size = int((x.shape[2] // channels) ** 0.5)
-    h = w = int(x.shape[1] ** .5)
-    assert h * w == x.shape[1] and patch_size ** 2 * channels == x.shape[2], f"Invalid shape: {x.shape}, should be {h*w}, {patch_size**2*channels}"
-    x = einops.rearrange(x, 'B (h w) (p1 p2 C) -> B (h p1) (w p2) C', h=h, p1=patch_size, p2=patch_size)
-    return x
-
-class PatchEmbedding(nn.Module):
-    patch_size: int
-    embedding_dim: int
-    dtype: Any = jnp.float32
-    precision: Any = jax.lax.Precision.HIGH
-
-    @nn.compact
-    def __call__(self, x):
-        batch, height, width, channels = x.shape
-        assert height % self.patch_size == 0 and width % self.patch_size == 0, "Image dimensions must be divisible by patch size"
-
-        x = nn.Conv(features=self.embedding_dim,
-                    kernel_size=(self.patch_size, self.patch_size),
-                    strides=(self.patch_size, self.patch_size),
-                    dtype=self.dtype,
-                    precision=self.precision)(x)
-        x = jnp.reshape(x, (batch, -1, self.embedding_dim))
-        return x
-
-class PositionalEncoding(nn.Module):
-    max_len: int
-    embedding_dim: int
-
-    @nn.compact
-    def __call__(self, x):
-        pe = self.param('pos_encoding',
-                        jax.nn.initializers.zeros,
-                        (1, self.max_len, self.embedding_dim))
-        return x + pe[:, :x.shape[1], :]
 
 class UViT(nn.Module):
-    output_channels:int=3
+    output_channels: int = 3
     patch_size: int = 16
-    emb_features:int=768
-    num_layers: int = 12
+    emb_features: int = 768
+    num_layers: int = 12  # Should be even for U-Net structure
     num_heads: int = 12
-    dropout_rate: float = 0.1
-    use_projection: bool = False
-    use_flash_attention: bool = False
+    dropout_rate: float = 0.1  # Dropout is often 0 in diffusion models
+    use_projection: bool = False  # In TransformerBlock MLP
+    use_flash_attention: bool = False  # Passed to TransformerBlock
+    # Passed to TransformerBlock (likely False for UViT)
     use_self_and_cross: bool = False
-    force_fp32_for_softmax: bool = True
-
-
-
+    force_fp32_for_softmax: bool = True  # Passed to TransformerBlock
+    # Used in final convs if add_residualblock_output
+    activation: Callable = jax.nn.swish
+    norm_groups: int = 8
+    dtype: Optional[Dtype] = None  # e.g., jnp.float32 or jnp.bfloat16
     precision: PrecisionLike = None
     add_residualblock_output: bool = False
-    norm_inputs: bool = False
-    explicitly_add_residual: bool = True
-    norm_epsilon: float = 1e-
-    use_hilbert: bool = False
+    norm_inputs: bool = False  # Passed to TransformerBlock
+    explicitly_add_residual: bool = True  # Passed to TransformerBlock
+    norm_epsilon: float = 1e-5  # Adjusted default
+    use_hilbert: bool = False  # Toggle Hilbert patch reorder
+    use_remat: bool = False  # Add flag to use remat
 
     def setup(self):
+        assert self.num_layers % 2 == 0, "num_layers must be even for U-Net structure"
+        half_layers = self.num_layers // 2
+
+        # --- Norm Layer ---
         if self.norm_groups > 0:
-
+            print(f"Warning: norm_groups > 0 not fully supported with standard LayerNorm fallback in UViT setup. Using LayerNorm.")
+            self.norm_factory = partial(
+                nn.LayerNorm, epsilon=self.norm_epsilon, dtype=self.dtype)
         else:
-            self.
-
+            self.norm_factory = partial(
+                nn.LayerNorm, epsilon=self.norm_epsilon, dtype=self.dtype)
+
+        # --- Input Path ---
+        self.patch_embed = PatchEmbedding(
+            patch_size=self.patch_size,
+            embedding_dim=self.emb_features,
+            dtype=self.dtype,
+            precision=self.precision,
+            name="patch_embed"
+        )
+        if self.use_hilbert:
+            self.hilbert_proj = nn.Dense(
+                features=self.emb_features,
+                dtype=self.dtype,
+                precision=self.precision,
+                name="hilbert_projection"
+            )
+
+        max_patches = (512 // self.patch_size)**2
+        self.pos_encoding = self.param('pos_encoding',
+                                       jax.nn.initializers.normal(stddev=0.02),
+                                       (1, max_patches, self.emb_features))
+
+        # --- Conditioning ---
+        self.time_embed = nn.Sequential([
+            FourierEmbedding(features=self.emb_features),
+            TimeProjection(features=self.emb_features)
+        ], name="time_embed")
+
+        self.text_proj = nn.DenseGeneral(
+            features=self.emb_features,
+            dtype=self.dtype,
+            precision=self.precision,
+            name="text_proj"
+        )
+
+        # --- Transformer Blocks ---
+        BlockClass = TransformerBlock
+
+        self.down_blocks = [
+            BlockClass(
+                heads=self.num_heads,
+                dim_head=self.emb_features // self.num_heads,
+                dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
+                use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross,
+                force_fp32_for_softmax=self.force_fp32_for_softmax,
+                only_pure_attention=False, norm_inputs=self.norm_inputs,
+                explicitly_add_residual=self.explicitly_add_residual,
+                norm_epsilon=self.norm_epsilon,
+                name=f"down_block_{i}"
+            ) for i in range(half_layers)
+        ]
+
+        self.mid_block = BlockClass(
+            heads=self.num_heads,
+            dim_head=self.emb_features // self.num_heads,
+            dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
+            use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross,
+            force_fp32_for_softmax=self.force_fp32_for_softmax,
+            only_pure_attention=False, norm_inputs=self.norm_inputs,
+            explicitly_add_residual=self.explicitly_add_residual,
+            norm_epsilon=self.norm_epsilon,
+            name="mid_block"
+        )
+
+        self.up_dense = [
+            nn.DenseGeneral(
+                features=self.emb_features,
+                dtype=self.dtype,
+                precision=self.precision,
+                name=f"up_dense_{i}"
+            ) for i in range(half_layers)
+        ]
+        self.up_blocks = [
+            BlockClass(
+                heads=self.num_heads,
+                dim_head=self.emb_features // self.num_heads,
+                dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
+                use_flash_attention=self.use_flash_attention, use_self_and_cross=self.use_self_and_cross,
+                force_fp32_for_softmax=self.force_fp32_for_softmax,
+                only_pure_attention=False, norm_inputs=self.norm_inputs,
+                explicitly_add_residual=self.explicitly_add_residual,
+                norm_epsilon=self.norm_epsilon,
+                name=f"up_block_{i}"
+            ) for i in range(half_layers)
+        ]
+
+        # --- Output Path ---
+        self.final_norm = self.norm_factory(name="final_norm")
+
+        patch_dim = self.patch_size ** 2 * self.output_channels
+        self.final_proj = nn.Dense(
+            features=patch_dim,
+            dtype=self.dtype,
+            precision=self.precision,
+            kernel_init=nn.initializers.zeros,
+            name="final_proj"
+        )
+
+        if self.add_residualblock_output:
+            self.final_conv1 = ConvLayer(
+                "conv",
+                features=64, kernel_size=(3, 3), strides=(1, 1),
+                dtype=self.dtype, precision=self.precision, name="final_conv1"
+            )
+            self.final_norm_conv = self.norm_factory(
+                name="final_norm_conv")
+            self.final_conv2 = ConvLayer(
+                "conv",
+                features=self.output_channels, kernel_size=(3, 3), strides=(1, 1),
+                dtype=jnp.float32,
+                precision=self.precision, name="final_conv2"
+            )
+        else:
+            self.final_conv_direct = ConvLayer(
+                "conv",
+                features=self.output_channels, kernel_size=(1, 1), strides=(1, 1),
+                dtype=jnp.float32,
+                precision=self.precision, name="final_conv_direct"
+            )
+
     @nn.compact
     def __call__(self, x, temb, textcontext=None):
-        # Time embedding
-        temb = FourierEmbedding(features=self.emb_features)(temb)
-        temb = TimeProjection(features=self.emb_features)(temb)
-
         original_img = x
         B, H, W, C = original_img.shape
         H_P = H // self.patch_size
         W_P = W // self.patch_size
+        num_patches = H_P * W_P
+        assert H % self.patch_size == 0 and W % self.patch_size == 0, "Image dimensions must be divisible by patch size"
 
-
-        x = PatchEmbedding(patch_size=self.patch_size, embedding_dim=self.emb_features,
-                           dtype=self.dtype, precision=self.precision)(x)
-        num_patches = x.shape[1]
-
-        # Optional Hilbert reorder
+        hilbert_inv_idx = None
         if self.use_hilbert:
+            patches_raw, hilbert_inv_idx_calc = hilbert_patchify(
+                x, self.patch_size)
+            x_patches = self.hilbert_proj(patches_raw)
             idx = hilbert_indices(H_P, W_P)
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            hilbert_inv_idx = inverse_permutation(
+                idx, total_size=num_patches)
+            x_patches = x_patches[:, idx, :]
+        else:
+            x_patches = self.patch_embed(x)
+
+        assert num_patches <= self.pos_encoding.shape[
+            1], f"Number of patches {num_patches} exceeds max_len {self.pos_encoding.shape[1]} in positional encoding"
+        x_patches = x_patches + self.pos_encoding[:, :num_patches, :]
+
+        time_token = self.time_embed(temb.astype(
+            jnp.float32))
+        time_token = jnp.expand_dims(time_token.astype(
+            self.dtype), axis=1)
+
+        if textcontext is not None:
+            text_tokens = self.text_proj(
+                textcontext.astype(self.dtype))
+            num_text_tokens = text_tokens.shape[1]
+            x = jnp.concatenate([x_patches, time_token, text_tokens], axis=1)
+        else:
+            num_text_tokens = 0
+            x = jnp.concatenate([x_patches, time_token], axis=1)
+
         skips = []
-        # In blocks
         for i in range(self.num_layers // 2):
-            x =
-                dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
-                use_flash_attention=self.use_flash_attention, use_self_and_cross=False, force_fp32_for_softmax=self.force_fp32_for_softmax,
-                only_pure_attention=False,
-                norm_inputs=self.norm_inputs,
-                explicitly_add_residual=self.explicitly_add_residual,
-                norm_epsilon=self.norm_epsilon, # Pass epsilon
-            )(x)
+            x = self.down_blocks[i](x)
             skips.append(x)
-
-
-
-            dtype=self.dtype, precision=self.precision, use_projection=self.use_projection,
-            use_flash_attention=self.use_flash_attention, use_self_and_cross=False, force_fp32_for_softmax=self.force_fp32_for_softmax,
-            only_pure_attention=False,
-            norm_inputs=self.norm_inputs,
-            explicitly_add_residual=self.explicitly_add_residual,
-            norm_epsilon=self.norm_epsilon, # Pass epsilon
-        )(x)
-
-        # Out blocks
+
+        x = self.mid_block(x)
+
         for i in range(self.num_layers // 2):
-
-            x =
-
-            x =
-
-
-
-
-
-
-
-
-        x = self.norm()(x) # Uses norm_epsilon defined in setup
-
-        patch_dim = self.patch_size ** 2 * self.output_channels
-        x = nn.Dense(features=patch_dim, dtype=self.dtype, precision=self.precision)(x)
-        # If Hilbert, restore original patch order
+            skip_conn = skips.pop()
+            x = jnp.concatenate([x, skip_conn], axis=-1)
+            x = self.up_dense[i](x)
+            x = self.up_blocks[i](x)
+
+        x = self.final_norm(x)
+
+        x_patches_out = x[:, :num_patches, :]
+
+        x_patches_out = self.final_proj(x_patches_out)
+
         if self.use_hilbert:
-
-
-
-
-
+            assert hilbert_inv_idx is not None, "Hilbert inverse index missing"
+            x_image = hilbert_unpatchify(
+                x_patches_out, hilbert_inv_idx, self.patch_size, H, W, self.output_channels)
+        else:
+            x_image = unpatchify(x_patches_out, channels=self.output_channels)
+
         if self.add_residualblock_output:
-
-
-
-
-
-
-
-
-
+            x_image = jnp.concatenate(
+                [original_img.astype(self.dtype), x_image], axis=-1)
+
+            x_image = self.final_conv1(x_image)
+            x_image = self.final_norm_conv(x_image)
+            x_image = self.activation(x_image)
+            x_image = self.final_conv2(x_image)
+        else:
+            pass
+
+        return x_image
+
+
+# --- Simple U-DiT ---
+
+class SimpleUDiT(nn.Module):
+    """
+    A Simple U-Net Diffusion Transformer (U-DiT) implementation.
+    Combines the U-Net structure with DiT blocks using RoPE and AdaLN-Zero conditioning.
+    Based on SimpleDiT and standard U-Net principles.
+    """
+    output_channels: int = 3
+    patch_size: int = 16
+    emb_features: int = 768
+    num_layers: int = 12  # Should be even for U-Net structure
+    num_heads: int = 12
+    mlp_ratio: int = 4
+    dropout_rate: float = 0.0  # Typically 0 for diffusion
+    dtype: Optional[Dtype] = None  # e.g., jnp.float32 or jnp.bfloat16
+    precision: PrecisionLike = None
+    use_flash_attention: bool = False  # Passed to DiTBlock -> RoPEAttention
+    force_fp32_for_softmax: bool = True  # Passed to DiTBlock -> RoPEAttention
+    norm_epsilon: float = 1e-5
+    learn_sigma: bool = False
+    use_hilbert: bool = False
+    norm_groups: int = 0
+    activation: Callable = jax.nn.swish
+
+    def setup(self):
+        assert self.num_layers % 2 == 0, "num_layers must be even for U-Net structure"
+        half_layers = self.num_layers // 2
+
+        self.patch_embed = PatchEmbedding(
+            patch_size=self.patch_size,
+            embedding_dim=self.emb_features,
+            dtype=self.dtype,
+            precision=self.precision,
+            name="patch_embed"
+        )
+        if self.use_hilbert:
+            self.hilbert_proj = nn.Dense(
+                features=self.emb_features,
                 dtype=self.dtype,
-                precision=self.precision
-
-
-
-
-
-
-
-
-
-
+                precision=self.precision,
+                name="hilbert_projection"
+            )
+
+        self.time_embed = nn.Sequential([
+            FourierEmbedding(features=self.emb_features),
+            TimeProjection(features=self.emb_features * self.mlp_ratio),
+            nn.Dense(features=self.emb_features, dtype=self.dtype, precision=self.precision)
+        ], name="time_embed")
+
+        self.text_proj = nn.Dense(
+            features=self.emb_features,
+            dtype=self.dtype,
+            precision=self.precision,
+            name="text_proj"
+        )
+
+        max_patches = (512 // self.patch_size)**2
+        self.rope = RotaryEmbedding(
+            dim=self.emb_features // self.num_heads,
+            max_seq_len=max_patches,
+            dtype=self.dtype,
+            name="rope_emb"
+        )
+
+        self.down_blocks = [
+            DiTBlock(
+                features=self.emb_features,
+                num_heads=self.num_heads,
+                mlp_ratio=self.mlp_ratio,
+                dropout_rate=self.dropout_rate,
+                dtype=self.dtype,
+                precision=self.precision,
+                use_flash_attention=self.use_flash_attention,
+                force_fp32_for_softmax=self.force_fp32_for_softmax,
+                norm_epsilon=self.norm_epsilon,
+                rope_emb=self.rope,
+                name=f"down_block_{i}"
+            ) for i in range(half_layers)
+        ]
+
+        self.mid_block = DiTBlock(
+            features=self.emb_features,
+            num_heads=self.num_heads,
+            mlp_ratio=self.mlp_ratio,
+            dropout_rate=self.dropout_rate,
             dtype=self.dtype,
-            precision=self.precision
-
-
+            precision=self.precision,
+            use_flash_attention=self.use_flash_attention,
+            force_fp32_for_softmax=self.force_fp32_for_softmax,
+            norm_epsilon=self.norm_epsilon,
+            rope_emb=self.rope,
+            name="mid_block"
+        )
+
+        self.up_dense = [
+            nn.DenseGeneral(
+                features=self.emb_features,
+                dtype=self.dtype,
+                precision=self.precision,
+                name=f"up_dense_{i}"
+            ) for i in range(half_layers)
+        ]
+        self.up_blocks = [
+            DiTBlock(
+                features=self.emb_features,
+                num_heads=self.num_heads,
+                mlp_ratio=self.mlp_ratio,
+                dropout_rate=self.dropout_rate,
+                dtype=self.dtype,
+                precision=self.precision,
+                use_flash_attention=self.use_flash_attention,
+                force_fp32_for_softmax=self.force_fp32_for_softmax,
+                norm_epsilon=self.norm_epsilon,
+                rope_emb=self.rope,
+                name=f"up_block_{i}"
+            ) for i in range(half_layers)
+        ]
+
+        self.final_norm = nn.LayerNorm(
+            epsilon=self.norm_epsilon, dtype=self.dtype, name="final_norm")
+
+        output_dim = self.patch_size * self.patch_size * self.output_channels
+        if self.learn_sigma:
+            output_dim *= 2
+
+        self.final_proj = nn.Dense(
+            features=output_dim,
+            dtype=jnp.float32,
+            precision=self.precision,
+            kernel_init=nn.initializers.zeros,
+            name="final_proj"
+        )
+
+    @nn.compact
+    def __call__(self, x, temb, textcontext=None):
+        B, H, W, C = x.shape
+        H_P = H // self.patch_size
+        W_P = W // self.patch_size
+        num_patches = H_P * W_P
+        assert H % self.patch_size == 0 and W % self.patch_size == 0, "Image dimensions must be divisible by patch size"
+
+        x = x.astype(self.dtype)
+
+        hilbert_inv_idx = None
+        if self.use_hilbert:
+            patches_raw, _ = hilbert_patchify(x, self.patch_size)
+            x_seq = self.hilbert_proj(patches_raw)
+            idx = hilbert_indices(H_P, W_P)
+            hilbert_inv_idx = inverse_permutation(idx, total_size=num_patches)
+        else:
+            x_seq = self.patch_embed(x)
+
+        t_emb = self.time_embed(temb.astype(jnp.float32))
+        t_emb = t_emb.astype(self.dtype)
+
+        cond_emb = t_emb
+        if textcontext is not None:
+            text_emb = self.text_proj(textcontext.astype(self.dtype))
+            if text_emb.ndim == 3:
+                text_emb = jnp.mean(text_emb, axis=1)
+            cond_emb = cond_emb + text_emb
+
+        skips = []
+        for i in range(self.num_layers // 2):
+            x_seq = self.down_blocks[i](x_seq, conditioning=cond_emb, freqs_cis=None)
+            skips.append(x_seq)
+
+        x_seq = self.mid_block(x_seq, conditioning=cond_emb, freqs_cis=None)
+
+        for i in range(self.num_layers // 2):
+            skip_conn = skips.pop()
+            x_seq = jnp.concatenate([x_seq, skip_conn], axis=-1)
+            x_seq = self.up_dense[i](x_seq)
+            x_seq = self.up_blocks[i](x_seq, conditioning=cond_emb, freqs_cis=None)
+
+        x_out = self.final_norm(x_seq)
+        x_out = self.final_proj(x_out)
+
+        if self.use_hilbert:
+            assert hilbert_inv_idx is not None, "Hilbert inverse index missing"
+            if self.learn_sigma:
+                x_mean, x_logvar = jnp.split(x_out, 2, axis=-1)
+                x_image = hilbert_unpatchify(x_mean, hilbert_inv_idx, self.patch_size, H, W, self.output_channels)
+            else:
+                x_image = hilbert_unpatchify(x_out, hilbert_inv_idx, self.patch_size, H, W, self.output_channels)
+        else:
+            if self.learn_sigma:
+                x_mean, x_logvar = jnp.split(x_out, 2, axis=-1)
+                x_image = unpatchify(x_mean, channels=self.output_channels)
+            else:
+                x_image = unpatchify(x_out, channels=self.output_channels)
+
+        return x_image.astype(jnp.float32)
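
For orientation, here is a minimal forward-pass sketch of the new `SimpleUDiT`, based only on the fields and `__call__(x, temb, textcontext)` signature visible in this diff. The import path, input shapes, and the assumption that `init`/`apply` need no extra RNG collections (with `dropout_rate=0.0`) are guesses, not taken from the package's documentation.

```python
# Minimal sketch (assumed import path, shapes, and RNG handling; not from flaxdiff docs).
import jax
import jax.numpy as jnp
from flaxdiff.models.simple_vit import SimpleUDiT

model = SimpleUDiT(
    output_channels=3,
    patch_size=16,
    emb_features=768,
    num_layers=12,      # must be even: 6 down blocks, 1 mid block, 6 up blocks
    num_heads=12,
    use_hilbert=True,   # reorder patch tokens along the Hilbert curve
    dtype=jnp.bfloat16,
)

x = jnp.ones((1, 256, 256, 3), dtype=jnp.float32)         # (B, H, W, C), divisible by patch_size
temb = jnp.ones((1,), dtype=jnp.float32)                   # diffusion timestep(s)
textcontext = jnp.ones((1, 77, 768), dtype=jnp.float32)    # optional text conditioning tokens

params = model.init(jax.random.PRNGKey(0), x, temb, textcontext)
out = model.apply(params, x, temb, textcontext)            # (1, 256, 256, 3), float32
```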