flaxdiff 0.2.7__py3-none-any.whl → 0.2.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/data/dataloaders.py +36 -24
- flaxdiff/data/dataset_map.py +2 -2
- flaxdiff/data/sources/base.py +12 -0
- flaxdiff/data/sources/images.py +68 -11
- flaxdiff/data/sources/videos.py +5 -0
- flaxdiff/models/common.py +1 -70
- flaxdiff/models/hilbert.py +617 -0
- flaxdiff/models/simple_dit.py +476 -0
- flaxdiff/models/simple_mmdit.py +861 -0
- flaxdiff/models/simple_vit.py +278 -117
- flaxdiff/trainer/general_diffusion_trainer.py +29 -10
- flaxdiff/trainer/simple_trainer.py +113 -19
- {flaxdiff-0.2.7.dist-info → flaxdiff-0.2.8.dist-info}/METADATA +1 -1
- {flaxdiff-0.2.7.dist-info → flaxdiff-0.2.8.dist-info}/RECORD +16 -14
- {flaxdiff-0.2.7.dist-info → flaxdiff-0.2.8.dist-info}/WHEEL +1 -1
- flaxdiff/models/better_uvit.py +0 -380
- {flaxdiff-0.2.7.dist-info → flaxdiff-0.2.8.dist-info}/top_level.txt +0 -0
flaxdiff/models/simple_dit.py (new file)
@@ -0,0 +1,476 @@
+from .simple_vit import PatchEmbedding, unpatchify
+import jax
+import jax.numpy as jnp
+from flax import linen as nn
+from typing import Callable, Any, Optional, Tuple, Sequence, Union
+import einops
+from functools import partial
+
+# Re-use existing components if they are suitable
+from .common import kernel_init, FourierEmbedding, TimeProjection
+# Using NormalAttention for RoPE integration
+from .attention import NormalAttention
+from flax.typing import Dtype, PrecisionLike
+
+# Use our improved Hilbert implementation
+from .hilbert import hilbert_indices, inverse_permutation, hilbert_patchify, hilbert_unpatchify
+
+# --- Rotary Positional Embedding (RoPE) ---
+# Adapted from https://github.com/google-deepmind/ring_attention/blob/main/ring_attention/layers/rotary.py
+
+
+def _rotate_half(x: jax.Array) -> jax.Array:
+    """Rotates half the hidden dims of the input."""
+    x1 = x[..., : x.shape[-1] // 2]
+    x2 = x[..., x.shape[-1] // 2:]
+    return jnp.concatenate((-x2, x1), axis=-1)
+
+def apply_rotary_embedding(
+    x: jax.Array, freqs_cos: jax.Array, freqs_sin: jax.Array
+) -> jax.Array:
+    """Applies rotary embedding to the input tensor using rotate_half method."""
+    # x shape: [..., Sequence, Dimension] e.g. [B, H, S, D] or [B, S, D]
+    # freqs_cos/sin shape: [Sequence, Dimension / 2]
+
+    # Expand dims for broadcasting: [1, 1, S, D/2] or [1, S, D/2]
+    if x.ndim == 4:  # [B, H, S, D]
+        cos_freqs = jnp.expand_dims(freqs_cos, axis=(0, 1))
+        sin_freqs = jnp.expand_dims(freqs_sin, axis=(0, 1))
+    elif x.ndim == 3:  # [B, S, D]
+        cos_freqs = jnp.expand_dims(freqs_cos, axis=0)
+        sin_freqs = jnp.expand_dims(freqs_sin, axis=0)
+
+    # Duplicate cos and sin for the full dimension D
+    # Shape becomes [..., S, D]
+    cos_freqs = jnp.concatenate([cos_freqs, cos_freqs], axis=-1)
+    sin_freqs = jnp.concatenate([sin_freqs, sin_freqs], axis=-1)
+
+    # Apply rotation: x * cos + rotate_half(x) * sin
+    x_rotated = x * cos_freqs + _rotate_half(x) * sin_freqs
+    return x_rotated.astype(x.dtype)
+
+class RotaryEmbedding(nn.Module):
+    dim: int  # Dimension of the head
+    max_seq_len: int = 2048
+    base: int = 10000
+    dtype: Dtype = jnp.float32
+
+    def setup(self):
+        inv_freq = 1.0 / (
+            self.base ** (jnp.arange(0, self.dim, 2,
+                                     dtype=jnp.float32) / self.dim)
+        )
+        t = jnp.arange(self.max_seq_len, dtype=jnp.float32)
+        freqs = jnp.outer(t, inv_freq)  # Shape: [max_seq_len, dim / 2]
+
+        # Store cosine and sine separately instead of as complex numbers
+        self.freqs_cos = jnp.cos(freqs)  # Shape: [max_seq_len, dim / 2]
+        self.freqs_sin = jnp.sin(freqs)  # Shape: [max_seq_len, dim / 2]
+
+    def __call__(self, seq_len: int):
+        if seq_len > self.max_seq_len:
+            raise ValueError(
+                f"Sequence length {seq_len} exceeds max_seq_len {self.max_seq_len}")
+        # Return separate cos and sin components
+        return self.freqs_cos[:seq_len, :], self.freqs_sin[:seq_len, :]
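
A quick shape check (not part of the diff) for how these helpers are meant to be combined: the frequency table below follows the same recipe as RotaryEmbedding.setup, and apply_rotary_embedding rotates a [B, H, S, D] query/key tensor without changing its shape. It assumes the 0.2.8 wheel is installed so that flaxdiff.models.simple_dit is importable.

import jax
import jax.numpy as jnp
from flaxdiff.models.simple_dit import apply_rotary_embedding

B, H, S, D = 2, 12, 256, 64  # batch, heads, tokens, per-head dim
q = jax.random.normal(jax.random.PRNGKey(0), (B, H, S, D))

# Same frequency recipe as RotaryEmbedding.setup (base 10000)
inv_freq = 1.0 / (10000 ** (jnp.arange(0, D, 2, dtype=jnp.float32) / D))
freqs = jnp.outer(jnp.arange(S, dtype=jnp.float32), inv_freq)  # [S, D/2]

q_rot = apply_rotary_embedding(q, jnp.cos(freqs), jnp.sin(freqs))
assert q_rot.shape == (B, H, S, D)  # rotation changes values, not shape
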
+# --- Attention with RoPE ---
+
+
+class RoPEAttention(NormalAttention):
+    rope_emb: RotaryEmbedding = None  # Instance of RotaryEmbedding
+
+    @nn.compact
+    def __call__(self, x, context=None, freqs_cis=None):
+        # x has shape [B, H, W, C] or [B, S, C]
+        orig_x_shape = x.shape
+        is_4d = len(orig_x_shape) == 4
+        if is_4d:
+            B, H, W, C = x.shape
+            seq_len = H * W
+            x = x.reshape((B, seq_len, C))
+        else:
+            B, seq_len, C = x.shape
+
+        context = x if context is None else context
+        if len(context.shape) == 4:
+            _B, _H, _W, _C = context.shape
+            context_seq_len = _H * _W
+            context = context.reshape((B, context_seq_len, _C))
+        # else: context is already [B, S_ctx, C]
+
+        query = self.query(x)  # [B, S, H, D]
+        key = self.key(context)  # [B, S_ctx, H, D]
+        value = self.value(context)  # [B, S_ctx, H, D]
+
+        # Apply RoPE to query and key
+        if freqs_cis is None:
+            # Generate frequencies using the rope_emb instance
+            seq_len_q = query.shape[1]  # Use query's sequence length
+            freqs_cos, freqs_sin = self.rope_emb(seq_len_q)
+        else:
+            # If freqs_cis is passed in as a tuple
+            freqs_cos, freqs_sin = freqs_cis
+
+        # Apply RoPE to query and key
+        # Permute to [B, H, S, D] for RoPE application
+        query = einops.rearrange(query, 'b s h d -> b h s d')
+        key = einops.rearrange(key, 'b s h d -> b h s d')
+
+        # Apply RoPE only up to the context sequence length for keys if different
+        # Assuming self-attention or context has same seq len for simplicity here
+        query = apply_rotary_embedding(query, freqs_cos, freqs_sin)
+        key = apply_rotary_embedding(key, freqs_cos, freqs_sin)  # Apply same freqs to key
+
+        # Permute back to [B, S, H, D] for dot_product_attention
+        query = einops.rearrange(query, 'b h s d -> b s h d')
+        key = einops.rearrange(key, 'b h s d -> b s h d')
+
+        hidden_states = nn.dot_product_attention(
+            query, key, value, dtype=self.dtype, broadcast_dropout=False,
+            dropout_rng=None, precision=self.precision, force_fp32_for_softmax=self.force_fp32_for_softmax,
+            deterministic=True
+        )  # Output shape [B, S, H, D]
+
+        # Use the proj_attn from NormalAttention which expects [B, S, H, D]
+        proj = self.proj_attn(hidden_states)  # Output shape [B, S, C]
+
+        if is_4d:
+            proj = proj.reshape(orig_x_shape)  # Reshape back if input was 4D
+
+        return proj
+
+# --- adaLN-Zero ---
+
+
+class AdaLNZero(nn.Module):
+    features: int
+    dtype: Optional[Dtype] = None
+    precision: PrecisionLike = None
+    norm_epsilon: float = 1e-5  # Standard LayerNorm epsilon
+
+    @nn.compact
+    def __call__(self, x, conditioning):
+        # Project conditioning signal to get scale and shift parameters
+        # Conditioning shape: [B, D_cond] -> [B, 1, ..., 1, 6 * features] for broadcasting
+        # Or [B, 1, 6*features] if x is [B, S, F]
+
+        # Ensure conditioning has seq dim if x does
+        # x=[B,S,F], cond=[B,D_cond]
+        if x.ndim == 3 and conditioning.ndim == 2:
+            conditioning = jnp.expand_dims(
+                conditioning, axis=1)  # cond=[B,1,D_cond]
+
+        # Project conditioning to get 6 params per feature (scale_mlp, shift_mlp, gate_mlp, scale_attn, shift_attn, gate_attn)
+        # Using nn.DenseGeneral for flexibility if needed, but nn.Dense is fine if cond is [B, D_cond] or [B, 1, D_cond]
+        ada_params = nn.Dense(
+            features=6 * self.features,
+            dtype=self.dtype,
+            precision=self.precision,
+            # Initialize projection to zero (Zero init)
+            kernel_init=nn.initializers.zeros,
+            name="ada_proj"
+        )(conditioning)
+
+        # Split into scale, shift, gate for MLP and Attention
+        scale_mlp, shift_mlp, gate_mlp, scale_attn, shift_attn, gate_attn = jnp.split(
+            ada_params, 6, axis=-1)
+
+        scale_mlp = jnp.clip(scale_mlp, -10.0, 10.0)
+        shift_mlp = jnp.clip(shift_mlp, -10.0, 10.0)
+        # Apply Layer Normalization
+        norm = nn.LayerNorm(epsilon=self.norm_epsilon,
+                            use_scale=False, use_bias=False, dtype=self.dtype)
+        # norm = nn.RMSNorm(epsilon=self.norm_epsilon, dtype=self.dtype)  # Alternative: RMSNorm
+
+        norm_x = norm(x)
+
+        # Modulate for Attention path
+        x_attn = norm_x * (1 + scale_attn) + shift_attn
+
+        # Modulate for MLP path
+        x_mlp = norm_x * (1 + scale_mlp) + shift_mlp
+
+        # Return modulated outputs and gates
+        return x_attn, gate_attn, x_mlp, gate_mlp
+
+class AdaLNParams(nn.Module):  # Renamed for clarity
+    features: int
+    dtype: Optional[Dtype] = None
+    precision: PrecisionLike = None
+
+    @nn.compact
+    def __call__(self, conditioning):
+        # Ensure conditioning is broadcastable if needed (e.g., [B, 1, D_cond])
+        if conditioning.ndim == 2:
+            conditioning = jnp.expand_dims(conditioning, axis=1)
+
+        # Project conditioning to get 6 params per feature
+        ada_params = nn.Dense(
+            features=6 * self.features,
+            dtype=self.dtype,
+            precision=self.precision,
+            kernel_init=nn.initializers.zeros,
+            name="ada_proj"
+        )(conditioning)
+        # Return all params (or split if preferred, but maybe return tuple/dict)
+        # Shape: [B, 1, 6*F]
+        return ada_params  # Or split and return tuple: jnp.split(ada_params, 6, axis=-1)
+
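
A note on the zero initialization above, with a minimal check (not part of the diff): because ada_proj uses a zero kernel and nn.Dense's default zero bias, every scale, shift, and gate starts at exactly zero, so a gated DiTBlock (defined next) begins training as a pure residual/identity mapping, which is the point of adaLN-Zero. The import path assumes the installed 0.2.8 wheel.

import jax
import jax.numpy as jnp
from flaxdiff.models.simple_dit import AdaLNParams

mod = AdaLNParams(features=8)
cond = jax.random.normal(jax.random.PRNGKey(0), (2, 16))  # [B, D_cond]
variables = mod.init(jax.random.PRNGKey(1), cond)
ada = mod.apply(variables, cond)  # [B, 1, 6 * features] = (2, 1, 48)
assert ada.shape == (2, 1, 48) and jnp.allclose(ada, 0.0)  # all modulations start at zero
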
+# --- DiT Block ---
+class DiTBlock(nn.Module):
+    features: int
+    num_heads: int
+    rope_emb: RotaryEmbedding
+    mlp_ratio: int = 4
+    dropout_rate: float = 0.0
+    dtype: Optional[Dtype] = None
+    precision: PrecisionLike = None
+    use_flash_attention: bool = False  # Keep placeholder
+    force_fp32_for_softmax: bool = True
+    norm_epsilon: float = 1e-5
+    use_gating: bool = True  # Add flag to easily disable gating
+
+    def setup(self):
+        hidden_features = int(self.features * self.mlp_ratio)
+        # Get modulation parameters (scale, shift, gates)
+        self.ada_params_module = AdaLNParams(  # Use the modified module
+            self.features, dtype=self.dtype, precision=self.precision)
+
+        # Layer Norms - one before Attn, one before MLP
+        self.norm1 = nn.LayerNorm(epsilon=self.norm_epsilon, use_scale=False, use_bias=False, dtype=self.dtype, name="norm1")
+        self.norm2 = nn.LayerNorm(epsilon=self.norm_epsilon, use_scale=False, use_bias=False, dtype=self.dtype, name="norm2")
+
+        self.attention = RoPEAttention(
+            query_dim=self.features,
+            heads=self.num_heads,
+            dim_head=self.features // self.num_heads,
+            dtype=self.dtype,
+            precision=self.precision,
+            use_bias=True,
+            force_fp32_for_softmax=self.force_fp32_for_softmax,
+            rope_emb=self.rope_emb
+        )
+
+        self.mlp = nn.Sequential([
+            nn.Dense(features=hidden_features, dtype=self.dtype, precision=self.precision),
+            nn.gelu,  # Or swish as specified in SimpleDiT? Consider consistency.
+            nn.Dense(features=self.features, dtype=self.dtype, precision=self.precision)
+        ])
+
+    @nn.compact
+    def __call__(self, x, conditioning, freqs_cis):
+        # Get scale/shift/gate parameters
+        # Shape: [B, 1, 6*F] -> split into 6 of [B, 1, F]
+        scale_mlp, shift_mlp, gate_mlp, scale_attn, shift_attn, gate_attn = jnp.split(
+            self.ada_params_module(conditioning), 6, axis=-1
+        )
+
+        # --- Attention Path ---
+        residual = x
+        norm_x_attn = self.norm1(x)
+        # Modulate after norm
+        x_attn_modulated = norm_x_attn * (1 + scale_attn) + shift_attn
+        attn_output = self.attention(x_attn_modulated, context=None, freqs_cis=freqs_cis)
+
+        if self.use_gating:
+            x = residual + gate_attn * attn_output
+        else:
+            x = residual + attn_output  # Original DiT style without gate
+
+        # --- MLP Path ---
+        residual = x
+        norm_x_mlp = self.norm2(x)  # Apply second LayerNorm
+        # Modulate after norm
+        x_mlp_modulated = norm_x_mlp * (1 + scale_mlp) + shift_mlp
+        mlp_output = self.mlp(x_mlp_modulated)
+
+        if self.use_gating:
+            x = residual + gate_mlp * mlp_output
+        else:
+            x = residual + mlp_output  # Original DiT style without gate
+
+        return x
+
+
+# --- Patch Embedding (reuse or define if needed) ---
+# Assuming PatchEmbedding exists in simple_vit.py and is suitable
+
+# --- DiT ---
+
+class SimpleDiT(nn.Module):
+    output_channels: int = 3
+    patch_size: int = 16
+    emb_features: int = 768
+    num_layers: int = 12
+    num_heads: int = 12
+    mlp_ratio: int = 4
+    dropout_rate: float = 0.0  # Typically 0 for diffusion
+    dtype: Optional[Dtype] = None
+    precision: PrecisionLike = None
+    # Passed down, but RoPEAttention uses NormalAttention
+    use_flash_attention: bool = False
+    force_fp32_for_softmax: bool = True
+    norm_epsilon: float = 1e-5
+    learn_sigma: bool = False  # Option to predict sigma like in DiT paper
+    use_hilbert: bool = False  # Toggle Hilbert patch reorder
+    norm_groups: int = 0
+    activation: Callable = jax.nn.swish
+
+    def setup(self):
+        self.patch_embed = PatchEmbedding(
+            patch_size=self.patch_size,
+            embedding_dim=self.emb_features,
+            dtype=self.dtype,
+            precision=self.precision
+        )
+
+        # Add projection layer for Hilbert patches
+        if self.use_hilbert:
+            self.hilbert_proj = nn.Dense(
+                features=self.emb_features,
+                dtype=self.dtype,
+                precision=self.precision,
+                name="hilbert_projection"
+            )
+
+        # Time embedding projection
+        self.time_embed = nn.Sequential([
+            FourierEmbedding(features=self.emb_features),
+            TimeProjection(features=self.emb_features *
+                           self.mlp_ratio),  # Project to MLP dim
+            nn.Dense(features=self.emb_features, dtype=self.dtype,
+                     precision=self.precision)  # Final projection
+        ])
+
+        # Text context projection (if used)
+        # Assuming textcontext is already projected to some dimension, project it to match emb_features
+        # This might need adjustment based on how text context is provided
+        self.text_proj = nn.Dense(features=self.emb_features, dtype=self.dtype,
+                                  precision=self.precision, name="text_context_proj")
+
+        # Rotary Positional Embedding
+        # Max length needs to be estimated or set large enough.
+        # For images, seq len = (H/P) * (W/P). Example: 256/16 * 256/16 = 16*16 = 256
+        # Add 1 if a class token is used, or more for text tokens if concatenated.
+        # Let's assume max seq len accommodates patches + time + text tokens if needed, or just patches.
+        # If only patches use RoPE, max_len = max_image_tokens
+        # If time/text are concatenated *before* blocks, max_len needs to include them.
+        # DiT typically applies PE only to patch tokens. Let's follow that.
+        # max_len should be max number of patches.
+        # Example: max image size 512x512, patch 16 -> (512/16)^2 = 32^2 = 1024 patches
+        self.rope = RotaryEmbedding(
+            dim=self.emb_features // self.num_heads, max_seq_len=4096, dtype=self.dtype)  # Dim per head
+
+        # Transformer Blocks
+        self.blocks = [
+            DiTBlock(
+                features=self.emb_features,
+                num_heads=self.num_heads,
+                mlp_ratio=self.mlp_ratio,
+                dropout_rate=self.dropout_rate,
+                dtype=self.dtype,
+                precision=self.precision,
+                use_flash_attention=self.use_flash_attention,
+                force_fp32_for_softmax=self.force_fp32_for_softmax,
+                norm_epsilon=self.norm_epsilon,
+                rope_emb=self.rope,  # Pass RoPE instance
+                name=f"dit_block_{i}"
+            ) for i in range(self.num_layers)
+        ]
+
+        # Final Layer (Normalization + Linear Projection)
+        self.final_norm = nn.LayerNorm(
+            epsilon=self.norm_epsilon, dtype=self.dtype, name="final_norm")
+        # self.final_norm = nn.RMSNorm(epsilon=self.norm_epsilon, dtype=self.dtype, name="final_norm")
+
+        # Predict patch pixels + potentially sigma
+        output_dim = self.patch_size * self.patch_size * self.output_channels
+        if self.learn_sigma:
+            output_dim *= 2  # Predict both mean and variance (or log_variance)
+
+        self.final_proj = nn.Dense(
+            features=output_dim,
+            dtype=self.dtype,
+            precision=self.precision,
+            kernel_init=nn.initializers.zeros,  # Initialize final layer to zero
+            name="final_proj"
+        )
+
+    @nn.compact
+    def __call__(self, x, temb, textcontext=None):
+        B, H, W, C = x.shape
+        assert H % self.patch_size == 0 and W % self.patch_size == 0, "Image dimensions must be divisible by patch size"
+
+        # Compute dimensions in terms of patches
+        H_P = H // self.patch_size
+        W_P = W // self.patch_size
+
+        # 1. Patch Embedding
+        if self.use_hilbert:
+            # Use hilbert_patchify which handles both patchification and reordering
+            patches_raw, hilbert_inv_idx = hilbert_patchify(x, self.patch_size)  # Shape [B, S, P*P*C]
+            # Apply projection
+            patches = self.hilbert_proj(patches_raw)  # Shape [B, S, emb_features]
+        else:
+            patches = self.patch_embed(x)  # Shape: [B, num_patches, emb_features]
+            hilbert_inv_idx = None
+
+        num_patches = patches.shape[1]
+        x_seq = patches
+
+        # 2. Prepare Conditioning Signal (Time + Text Context)
+        t_emb = self.time_embed(temb)  # Shape: [B, emb_features]
+
+        cond_emb = t_emb
+        if textcontext is not None:
+            # Shape: [B, num_text_tokens, emb_features]
+            text_emb = self.text_proj(textcontext)
+            # Pool or select text embedding (e.g., mean pool or use CLS token)
+            # Assuming mean pooling for simplicity
+            # Shape: [B, emb_features]
+            text_emb_pooled = jnp.mean(text_emb, axis=1)
+            cond_emb = cond_emb + text_emb_pooled  # Combine time and text embeddings
+
+        # 3. Apply RoPE
+        # Get RoPE frequencies for the sequence length (number of patches)
+        # Shape [num_patches, D_head/2]
+        freqs_cos, freqs_sin = self.rope(seq_len=num_patches)
+
+        # 4. Apply Transformer Blocks with adaLN-Zero conditioning
+        for block in self.blocks:
+            x_seq = block(x_seq, conditioning=cond_emb, freqs_cis=(freqs_cos, freqs_sin))
+
+        # 5. Final Layer
+        x_out = self.final_norm(x_seq)
+        # Shape: [B, num_patches, patch_pixels (*2 if learn_sigma)]
+        x_out = self.final_proj(x_out)
+
+        # 6. Unpatchify
+        if self.use_hilbert:
+            # For Hilbert mode, we need to use the specialized unpatchify function
+            if self.learn_sigma:
+                # Split into mean and variance predictions
+                x_mean, x_logvar = jnp.split(x_out, 2, axis=-1)
+                x_image = hilbert_unpatchify(x_mean, hilbert_inv_idx, self.patch_size, H, W, self.output_channels)
+                # If needed, also unpack the logvar
+                # logvar_image = hilbert_unpatchify(x_logvar, hilbert_inv_idx, self.patch_size, H, W, self.output_channels)
+                # return x_image, logvar_image
+                return x_image
+            else:
+                x_image = hilbert_unpatchify(x_out, hilbert_inv_idx, self.patch_size, H, W, self.output_channels)
+                return x_image
+        else:
+            # Standard patch ordering - use the existing unpatchify function
+            if self.learn_sigma:
+                # Split into mean and variance predictions
+                x_mean, x_logvar = jnp.split(x_out, 2, axis=-1)
+                x = unpatchify(x_mean, channels=self.output_channels)
+                # Return both mean and logvar if needed by the loss function
+                # For now, just returning the mean prediction like standard diffusion models
+                # logvar = unpatchify(x_logvar, channels=self.output_channels)
+                # return x, logvar
+                return x
+            else:
+                # Shape: [B, H, W, C]
+                x = unpatchify(x_out, channels=self.output_channels)
+                return x
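
For orientation, a rough end-to-end usage sketch (not part of the diff). The init/apply calls are standard Flax; the shapes chosen for temb and textcontext are assumptions, since the exact timestep and text-embedding conventions depend on FourierEmbedding, TimeProjection, and the trainer code, which live outside this file.

import jax
import jax.numpy as jnp
from flaxdiff.models.simple_dit import SimpleDiT

model = SimpleDiT(output_channels=3, patch_size=8, emb_features=256,
                  num_layers=4, num_heads=8)

x = jnp.zeros((1, 64, 64, 3))           # [B, H, W, C] noisy image, H and W divisible by patch_size
temb = jnp.array([0.5])                  # assumed: one diffusion timestep per batch element
textcontext = jnp.zeros((1, 77, 768))    # assumed: precomputed text token embeddings [B, T, D_text]

variables = model.init(jax.random.PRNGKey(0), x, temb, textcontext)
out = model.apply(variables, x, temb, textcontext)  # [1, 64, 64, 3] prediction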