flaxdiff 0.1.38.1__py3-none-any.whl → 0.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/data/__init__.py +5 -1
- flaxdiff/data/benchmark_decord.py +443 -0
- flaxdiff/data/dataloaders.py +608 -0
- flaxdiff/data/dataset_map.py +61 -6
- flaxdiff/data/online_loader.py +779 -150
- flaxdiff/data/sources/audio_utils.py +142 -0
- flaxdiff/data/sources/av_example.py +125 -0
- flaxdiff/data/sources/av_utils.py +590 -0
- flaxdiff/data/sources/base.py +129 -0
- flaxdiff/data/sources/images.py +309 -0
- flaxdiff/data/sources/utils.py +158 -0
- flaxdiff/data/sources/videos.py +250 -0
- flaxdiff/data/sources/voxceleb2.py +412 -0
- flaxdiff/inference/__init__.py +0 -0
- flaxdiff/inference/pipeline.py +260 -0
- flaxdiff/inference/utils.py +320 -0
- flaxdiff/inputs/__init__.py +173 -0
- flaxdiff/inputs/encoders.py +98 -0
- flaxdiff/models/__init__.py +2 -1
- flaxdiff/models/autoencoder/autoencoder.py +141 -9
- flaxdiff/models/autoencoder/diffusers.py +88 -25
- flaxdiff/models/autoencoder/simple_autoenc.py +40 -8
- flaxdiff/models/simple_unet.py +5 -5
- flaxdiff/models/simple_vit.py +1 -1
- flaxdiff/models/unet_3d.py +446 -0
- flaxdiff/models/unet_3d_blocks.py +505 -0
- flaxdiff/samplers/common.py +358 -96
- flaxdiff/samplers/ddim.py +44 -5
- flaxdiff/schedulers/karras.py +20 -12
- flaxdiff/trainer/__init__.py +2 -1
- flaxdiff/trainer/autoencoder_trainer.py +1 -2
- flaxdiff/trainer/diffusion_trainer.py +33 -27
- flaxdiff/trainer/general_diffusion_trainer.py +583 -0
- flaxdiff/trainer/simple_trainer.py +48 -31
- flaxdiff/utils.py +128 -57
- {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/METADATA +1 -1
- flaxdiff-0.2.0.dist-info/RECORD +64 -0
- {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/WHEEL +1 -1
- flaxdiff/data/datasets.py +0 -169
- flaxdiff/data/sources/gcs.py +0 -81
- flaxdiff/data/sources/tfds.py +0 -79
- flaxdiff/trainer/video_diffusion_trainer.py +0 -62
- flaxdiff-0.1.38.1.dist-info/RECORD +0 -50
- {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/top_level.txt +0 -0
flaxdiff/models/unet_3d.py (new file)
@@ -0,0 +1,446 @@
from typing import Dict, Optional, Tuple, Union

import flax
import flax.linen as nn
import jax
import jax.numpy as jnp
from flax.core.frozen_dict import FrozenDict

from diffusers.configuration_utils import ConfigMixin, flax_register_to_config
from diffusers.utils import BaseOutput
from diffusers.models.embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
from diffusers.models.modeling_flax_utils import FlaxModelMixin

from .unet_3d_blocks import (
    FlaxCrossAttnDownBlock3D,
    FlaxCrossAttnUpBlock3D,
    FlaxDownBlock3D,
    FlaxUNetMidBlock3DCrossAttn,
    FlaxUpBlock3D,
)


@flax_register_to_config
class FlaxUNet3DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
    r"""
    A conditional 3D UNet model for video diffusion.

    Parameters:
        sample_size (`int` or `Tuple[int, int, int]`, *optional*, defaults to (16, 32, 32)):
            The spatial and temporal size of the input sample. Can be provided as a single integer for square spatial size and fixed temporal size.
        in_channels (`int`, *optional*, defaults to 4):
            The number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 4):
            The number of channels in the output.
        down_block_types (`Tuple[str]`, *optional*, defaults to ("CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D")):
            The tuple of downsample blocks to use.
        up_block_types (`Tuple[str]`, *optional*, defaults to ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D")):
            The tuple of upsample blocks to use.
        block_out_channels (`Tuple[int]`, *optional*, defaults to (320, 640, 1280, 1280)):
            The tuple of output channels for each block.
        layers_per_block (`int`, *optional*, defaults to 2):
            The number of layers per block.
        attention_head_dim (`int`, *optional*, defaults to 8):
            The dimension of the attention heads.
        cross_attention_dim (`int`, *optional*, defaults to 1280):
            The dimension of the cross attention features.
        dropout (`float`, *optional*, defaults to 0):
            Dropout probability for down, up and bottleneck blocks.
        use_linear_projection (`bool`, *optional*, defaults to False):
            Whether to use linear projection in attention blocks.
        dtype (`jnp.dtype`, *optional*, defaults to jnp.float32):
            The dtype of the model weights.
        flip_sin_to_cos (`bool`, *optional*, defaults to True):
            Whether to flip the sin to cos in the time embedding.
        freq_shift (`int`, *optional*, defaults to 0):
            The frequency shift to apply to the time embedding.
        use_memory_efficient_attention (`bool`, *optional*, defaults to False):
            Whether to use memory-efficient attention.
        split_head_dim (`bool`, *optional*, defaults to False):
            Whether to split the head dimension into a new axis for the self-attention computation.
    """

    sample_size: Union[int, Tuple[int, int, int]] = (16, 32, 32)
    in_channels: int = 4
    out_channels: int = 4
    down_block_types: Tuple[str, ...] = (
        "CrossAttnDownBlock3D",
        "CrossAttnDownBlock3D",
        "CrossAttnDownBlock3D",
        "DownBlock3D",
    )
    up_block_types: Tuple[str, ...] = ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D")
    block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280)
    layers_per_block: int = 2
    attention_head_dim: Union[int, Tuple[int, ...]] = 8
    num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None
    cross_attention_dim: int = 1280
    dropout: float = 0.0
    use_linear_projection: bool = False
    dtype: jnp.dtype = jnp.float32
    flip_sin_to_cos: bool = True
    freq_shift: int = 0
    use_memory_efficient_attention: bool = False
    split_head_dim: bool = False
    transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1
    addition_embed_type: Optional[str] = None
    addition_time_embed_dim: Optional[int] = None

    def init_weights(self, rng: jax.Array) -> FrozenDict:
        # init input tensors
        if isinstance(self.sample_size, int):
            sample_size = (self.sample_size, self.sample_size, self.sample_size)
        else:
            sample_size = self.sample_size

        # Shape: [batch, frames, height, width, channels]
        sample_shape = (1, sample_size[0], sample_size[1], sample_size[2], self.in_channels)
        sample = jnp.zeros(sample_shape, dtype=jnp.float32)
        timesteps = jnp.ones((1,), dtype=jnp.int32)
        encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32)

        params_rng, dropout_rng = jax.random.split(rng)
        rngs = {"params": params_rng, "dropout": dropout_rng}

        added_cond_kwargs = None
        if self.addition_embed_type == "text_time":
            # For text-time conditioning for video diffusion
            text_embeds_dim = self.cross_attention_dim
            time_ids_dims = 6  # Default value for video models
            added_cond_kwargs = {
                "text_embeds": jnp.zeros((1, text_embeds_dim), dtype=jnp.float32),
                "time_ids": jnp.zeros((1, time_ids_dims), dtype=jnp.float32),
            }

        return self.init(rngs, sample, timesteps, encoder_hidden_states, added_cond_kwargs)["params"]

    def setup(self) -> None:
        block_out_channels = self.block_out_channels
        time_embed_dim = block_out_channels[0] * 4

        if self.num_attention_heads is not None:
            raise ValueError(
                "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue. "
                "Use `attention_head_dim` instead."
            )

        # Default behavior: if num_attention_heads is not set, use attention_head_dim
        num_attention_heads = self.num_attention_heads or self.attention_head_dim

        # input
        self.conv_in = nn.Conv(
            block_out_channels[0],
            kernel_size=(3, 3, 3),
            strides=(1, 1, 1),
            padding=((1, 1), (1, 1), (1, 1)),
            dtype=self.dtype,
        )

        # time
        self.time_proj = FlaxTimesteps(
            block_out_channels[0], flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.freq_shift
        )
        self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)

        # Handle attention head configurations
        if isinstance(num_attention_heads, int):
            num_attention_heads = (num_attention_heads,) * len(self.down_block_types)

        # transformer layers per block
        transformer_layers_per_block = self.transformer_layers_per_block
        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * len(self.down_block_types)

        # addition embed types
        if self.addition_embed_type == "text_time":
            if self.addition_time_embed_dim is None:
                raise ValueError(
                    f"addition_embed_type {self.addition_embed_type} requires `addition_time_embed_dim` to not be None"
                )
            self.add_time_proj = FlaxTimesteps(self.addition_time_embed_dim, self.flip_sin_to_cos, self.freq_shift)
            self.add_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)
        else:
            self.add_embedding = None

        # down blocks
        down_blocks = []
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(self.down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            if down_block_type == "CrossAttnDownBlock3D":
                down_block = FlaxCrossAttnDownBlock3D(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    dropout=self.dropout,
                    num_layers=self.layers_per_block,
                    transformer_layers_per_block=transformer_layers_per_block[i],
                    num_attention_heads=num_attention_heads[i],
                    add_downsample=not is_final_block,
                    use_linear_projection=self.use_linear_projection,
                    only_cross_attention=False,  # We don't use only cross attention in 3D UNet
                    use_memory_efficient_attention=self.use_memory_efficient_attention,
                    split_head_dim=self.split_head_dim,
                    dtype=self.dtype,
                )
            elif down_block_type == "DownBlock3D":
                down_block = FlaxDownBlock3D(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    dropout=self.dropout,
                    num_layers=self.layers_per_block,
                    add_downsample=not is_final_block,
                    dtype=self.dtype,
                )
            else:
                raise ValueError(f"Unknown down block type: {down_block_type}")

            down_blocks.append(down_block)
        self.down_blocks = down_blocks

        # mid block
        self.mid_block = FlaxUNetMidBlock3DCrossAttn(
            in_channels=block_out_channels[-1],
            dropout=self.dropout,
            num_attention_heads=num_attention_heads[-1],
            transformer_layers_per_block=transformer_layers_per_block[-1],
            use_linear_projection=self.use_linear_projection,
            use_memory_efficient_attention=self.use_memory_efficient_attention,
            split_head_dim=self.split_head_dim,
            dtype=self.dtype,
        )

        # up blocks
        up_blocks = []
        reversed_block_out_channels = list(reversed(block_out_channels))
        reversed_num_attention_heads = list(reversed(num_attention_heads))
        output_channel = reversed_block_out_channels[0]
        reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))

        for i, up_block_type in enumerate(self.up_block_types):
            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

            is_final_block = i == len(block_out_channels) - 1

            if up_block_type == "CrossAttnUpBlock3D":
                up_block = FlaxCrossAttnUpBlock3D(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    prev_output_channel=prev_output_channel,
                    num_layers=self.layers_per_block + 1,
                    transformer_layers_per_block=reversed_transformer_layers_per_block[i],
                    num_attention_heads=reversed_num_attention_heads[i],
                    add_upsample=not is_final_block,
                    dropout=self.dropout,
                    use_linear_projection=self.use_linear_projection,
                    only_cross_attention=False,  # We don't use only cross attention in 3D UNet
                    use_memory_efficient_attention=self.use_memory_efficient_attention,
                    split_head_dim=self.split_head_dim,
                    dtype=self.dtype,
                )
            elif up_block_type == "UpBlock3D":
                up_block = FlaxUpBlock3D(
                    in_channels=input_channel,
                    out_channels=output_channel,
                    prev_output_channel=prev_output_channel,
                    num_layers=self.layers_per_block + 1,
                    add_upsample=not is_final_block,
                    dropout=self.dropout,
                    dtype=self.dtype,
                )
            else:
                raise ValueError(f"Unknown up block type: {up_block_type}")

            up_blocks.append(up_block)
            prev_output_channel = output_channel
        self.up_blocks = up_blocks

        # out
        self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-5)
        self.conv_out = nn.Conv(
            self.out_channels,
            kernel_size=(3, 3, 3),
            strides=(1, 1, 1),
            padding=((1, 1), (1, 1), (1, 1)),
            dtype=self.dtype,
        )

    def __call__(
        self,
        sample: jnp.ndarray,
        timesteps: Union[jnp.ndarray, float, int],
        encoder_hidden_states: jnp.ndarray,
        frame_encoder_hidden_states: Optional[jnp.ndarray] = None,
        added_cond_kwargs: Optional[Union[Dict, FrozenDict]] = None,
        down_block_additional_residuals: Optional[Tuple[jnp.ndarray, ...]] = None,
        mid_block_additional_residual: Optional[jnp.ndarray] = None,
        return_dict: bool = True,
        train: bool = False,
    ) -> jnp.ndarray:
        r"""
        Args:
            sample (`jnp.ndarray`): (batch, frames, height, width, channels) noisy inputs tensor
            timesteps (`jnp.ndarray` or `float` or `int`): timesteps
            encoder_hidden_states (`jnp.ndarray`): (batch_size, sequence_length, hidden_size) encoder hidden states
            frame_encoder_hidden_states (`jnp.ndarray`, *optional*):
                (batch_size, frames, sequence_length, hidden_size) per-frame encoder hidden states
            added_cond_kwargs (`dict`, *optional*):
                Additional embeddings to add to the time embeddings
            down_block_additional_residuals (`tuple` of `jnp.ndarray`, *optional*):
                Additional residual connections for down blocks
            mid_block_additional_residual (`jnp.ndarray`, *optional*):
                Additional residual connection for mid block
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a dict or tuple
            train (`bool`, *optional*, defaults to `False`):
                Training mode flag for dropout
        """
        # Extract the number of frames from the input
        batch, num_frames, height, width, channels = sample.shape

        # 1. Time embedding
        if not isinstance(timesteps, jnp.ndarray):
            timesteps = jnp.array([timesteps], dtype=jnp.int32)
        elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0:
            timesteps = timesteps.astype(dtype=jnp.float32)
            timesteps = jnp.expand_dims(timesteps, 0)

        t_emb = self.time_proj(timesteps)
        t_emb = self.time_embedding(t_emb)

        # Repeat time embedding for each frame
        t_emb = jnp.repeat(t_emb, repeats=num_frames, axis=0)

        # additional embeddings
        if self.add_embedding is not None and added_cond_kwargs is not None:
            if "text_embeds" not in added_cond_kwargs:
                raise ValueError(
                    "text_embeds must be provided for text_time addition_embed_type"
                )
            if "time_ids" not in added_cond_kwargs:
                raise ValueError(
                    "time_ids must be provided for text_time addition_embed_type"
                )

            text_embeds = added_cond_kwargs["text_embeds"]
            time_ids = added_cond_kwargs["time_ids"]

            # Compute time embeds
            time_embeds = self.add_time_proj(jnp.ravel(time_ids))
            time_embeds = jnp.reshape(time_embeds, (text_embeds.shape[0], -1))

            # Concatenate text and time embeds
            add_embeds = jnp.concatenate([text_embeds, time_embeds], axis=-1)

            # Project to time embedding dimension
            aug_emb = self.add_embedding(add_embeds)
            t_emb = t_emb + aug_emb

        # 2. Pre-process input - reshape from [B, F, H, W, C] to [B*F, H, W, C] for 2D operations
        sample = sample.reshape(batch * num_frames, height, width, channels)
        sample = self.conv_in(sample)

        # Process encoder hidden states - repeat for each frame and combine with frame-specific conditioning if provided
        if encoder_hidden_states is not None:
            # Repeat video-wide conditioning for each frame: (B, S, X) -> (B*F, S, X)
            encoder_hidden_states_expanded = jnp.repeat(
                encoder_hidden_states, repeats=num_frames, axis=0
            )

            # If we have frame-specific conditioning, reshape and concatenate with video conditioning
            if frame_encoder_hidden_states is not None:
                # Reshape from (B, F, S, X) to (B*F, S, X)
                frame_encoder_hidden_states = frame_encoder_hidden_states.reshape(
                    batch * num_frames, *frame_encoder_hidden_states.shape[2:]
                )

                # Concatenate along the sequence dimension
                encoder_hidden_states_combined = jnp.concatenate(
                    [encoder_hidden_states_expanded, frame_encoder_hidden_states],
                    axis=1
                )
            else:
                encoder_hidden_states_combined = encoder_hidden_states_expanded
        else:
            encoder_hidden_states_combined = None

        # 3. Down blocks
        down_block_res_samples = (sample,)
        for down_block in self.down_blocks:
            if isinstance(down_block, FlaxCrossAttnDownBlock3D):
                sample, res_samples = down_block(
                    sample,
                    t_emb,
                    encoder_hidden_states_combined,
                    num_frames=num_frames,
                    deterministic=not train,
                )
            else:
                sample, res_samples = down_block(
                    sample,
                    t_emb,
                    num_frames=num_frames,
                    deterministic=not train,
                )
            down_block_res_samples += res_samples

        # Add additional residuals if provided
        if down_block_additional_residuals is not None:
            new_down_block_res_samples = ()

            for down_block_res_sample, down_block_additional_residual in zip(
                down_block_res_samples, down_block_additional_residuals
            ):
                down_block_res_sample += down_block_additional_residual
                new_down_block_res_samples += (down_block_res_sample,)

            down_block_res_samples = new_down_block_res_samples

        # 4. Mid block
        sample = self.mid_block(
            sample,
            t_emb,
            encoder_hidden_states_combined,
            num_frames=num_frames,
            deterministic=not train,
        )

        # Add mid block residual if provided
        if mid_block_additional_residual is not None:
            sample += mid_block_additional_residual

        # 5. Up blocks
        for up_block in self.up_blocks:
            res_samples = down_block_res_samples[-(self.layers_per_block + 1) :]
            down_block_res_samples = down_block_res_samples[: -(self.layers_per_block + 1)]
            if isinstance(up_block, FlaxCrossAttnUpBlock3D):
                sample = up_block(
                    sample,
                    res_hidden_states_tuple=res_samples,
                    temb=t_emb,
                    encoder_hidden_states=encoder_hidden_states_combined,
                    num_frames=num_frames,
                    deterministic=not train,
                )
            else:
                sample = up_block(
                    sample,
                    res_hidden_states_tuple=res_samples,
                    temb=t_emb,
                    num_frames=num_frames,
                    deterministic=not train,
                )

        # 6. Post-process
        sample = self.conv_norm_out(sample)
        sample = nn.silu(sample)
        sample = self.conv_out(sample)

        # Reshape back to [B, F, H, W, C]
        sample = sample.reshape(batch, num_frames, height, width, self.out_channels)
        return sample
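For orientation, here is a minimal usage sketch of the new module. It is not part of the package diff; the latent shape, sequence length, and cross_attention_dim=768 are illustrative assumptions, and real training configurations will differ.

    import jax
    import jax.numpy as jnp

    from flaxdiff.models.unet_3d import FlaxUNet3DConditionModel

    # Hypothetical toy configuration (defaults otherwise): 8 latent frames of 32x32x4.
    model = FlaxUNet3DConditionModel(
        sample_size=(8, 32, 32),      # (frames, height, width) of the latent video
        in_channels=4,
        out_channels=4,
        cross_attention_dim=768,      # assumed to match the text encoder's hidden size
    )

    # Initialize parameters with the helper defined on the module.
    params = model.init_weights(jax.random.PRNGKey(0))

    # Noisy latents: [batch, frames, height, width, channels]
    sample = jnp.zeros((1, 8, 32, 32, 4))
    timesteps = jnp.array([10], dtype=jnp.int32)
    # Video-wide text conditioning: [batch, sequence_length, hidden_size]
    encoder_hidden_states = jnp.zeros((1, 77, 768))

    # train=False by default, so dropout runs deterministically and no dropout RNG is needed.
    pred = model.apply({"params": params}, sample, timesteps, encoder_hidden_states)
    print(pred.shape)  # (1, 8, 32, 32, 4) - same [B, F, H, W, C] layout as the input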