flaxdiff-0.1.38.1-py3-none-any.whl → flaxdiff-0.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- flaxdiff/data/__init__.py +5 -1
- flaxdiff/data/benchmark_decord.py +443 -0
- flaxdiff/data/dataloaders.py +608 -0
- flaxdiff/data/dataset_map.py +61 -6
- flaxdiff/data/online_loader.py +779 -150
- flaxdiff/data/sources/audio_utils.py +142 -0
- flaxdiff/data/sources/av_example.py +125 -0
- flaxdiff/data/sources/av_utils.py +590 -0
- flaxdiff/data/sources/base.py +129 -0
- flaxdiff/data/sources/images.py +309 -0
- flaxdiff/data/sources/utils.py +158 -0
- flaxdiff/data/sources/videos.py +250 -0
- flaxdiff/data/sources/voxceleb2.py +412 -0
- flaxdiff/inference/__init__.py +0 -0
- flaxdiff/inference/pipeline.py +260 -0
- flaxdiff/inference/utils.py +320 -0
- flaxdiff/inputs/__init__.py +173 -0
- flaxdiff/inputs/encoders.py +98 -0
- flaxdiff/models/__init__.py +2 -1
- flaxdiff/models/autoencoder/autoencoder.py +141 -9
- flaxdiff/models/autoencoder/diffusers.py +88 -25
- flaxdiff/models/autoencoder/simple_autoenc.py +40 -8
- flaxdiff/models/simple_unet.py +5 -5
- flaxdiff/models/simple_vit.py +1 -1
- flaxdiff/models/unet_3d.py +446 -0
- flaxdiff/models/unet_3d_blocks.py +505 -0
- flaxdiff/samplers/common.py +358 -96
- flaxdiff/samplers/ddim.py +44 -5
- flaxdiff/schedulers/karras.py +20 -12
- flaxdiff/trainer/__init__.py +2 -1
- flaxdiff/trainer/autoencoder_trainer.py +1 -2
- flaxdiff/trainer/diffusion_trainer.py +33 -27
- flaxdiff/trainer/general_diffusion_trainer.py +583 -0
- flaxdiff/trainer/simple_trainer.py +48 -31
- flaxdiff/utils.py +128 -57
- {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/METADATA +1 -1
- flaxdiff-0.2.0.dist-info/RECORD +64 -0
- {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/WHEEL +1 -1
- flaxdiff/data/datasets.py +0 -169
- flaxdiff/data/sources/gcs.py +0 -81
- flaxdiff/data/sources/tfds.py +0 -79
- flaxdiff/trainer/video_diffusion_trainer.py +0 -62
- flaxdiff-0.1.38.1.dist-info/RECORD +0 -50
- {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/top_level.txt +0 -0
flaxdiff/models/unet_3d_blocks.py (new file)
@@ -0,0 +1,505 @@
from typing import Tuple, Optional

import flax.linen as nn
import jax
import jax.numpy as jnp

from diffusers.models.attention_flax import (
    FlaxBasicTransformerBlock,
    FlaxTransformer2DModel,
)

from diffusers.models.resnet_flax import (
    FlaxResnetBlock2D,
    FlaxUpsample2D,
    FlaxDownsample2D,
)

from diffusers.models.unets.unet_2d_blocks_flax import (
    FlaxCrossAttnDownBlock2D,
    FlaxDownBlock2D,
    FlaxUNetMidBlock2DCrossAttn,
    FlaxUpBlock2D,
    FlaxCrossAttnUpBlock2D,
)

class FlaxTransformerTemporalModel(nn.Module):
    """
    Transformer for temporal attention in 3D UNet.
    """
    in_channels: int
    n_heads: int
    d_head: int
    depth: int = 1
    dropout: float = 0.0
    only_cross_attention: bool = False
    dtype: jnp.dtype = jnp.float32
    use_memory_efficient_attention: bool = False
    split_head_dim: bool = False

    def setup(self):
        self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)

        inner_dim = self.n_heads * self.d_head
        self.proj_in = nn.Dense(inner_dim, dtype=self.dtype)
        # Use existing FlaxBasicTransformerBlock from diffusers
        self.transformer_blocks = [
            FlaxBasicTransformerBlock(
                inner_dim,
                self.n_heads,
                self.d_head,
                dropout=self.dropout,
                only_cross_attention=self.only_cross_attention,
                dtype=self.dtype,
                use_memory_efficient_attention=self.use_memory_efficient_attention,
                split_head_dim=self.split_head_dim,
            )
            for _ in range(self.depth)
        ]

        self.proj_out = nn.Dense(inner_dim, dtype=self.dtype)

        self.dropout_layer = nn.Dropout(rate=self.dropout)

    def __call__(self, hidden_states: jnp.ndarray, context: jnp.ndarray, num_frames: int, deterministic=True):
        # Save original shape for later reshaping
        batch_depth, height, width, channels = hidden_states.shape
        batch = batch_depth // num_frames

        # Reshape to (batch, depth, height, width, channels)
        hidden_states = hidden_states.reshape(batch, num_frames, height, width, channels)
        residual = hidden_states

        # Apply normalization
        hidden_states = self.norm(hidden_states)

        # Reshape for temporal attention: (batch, depth, height, width, channels) ->
        # (batch*height*width, depth, channels)
        hidden_states = hidden_states.transpose(0, 2, 3, 1, 4)
        hidden_states = hidden_states.reshape(batch * height * width, num_frames, channels)

        # Project input
        hidden_states = self.proj_in(hidden_states)

        # Apply transformer blocks
        for block in self.transformer_blocks:
            hidden_states = block(hidden_states, context=context, deterministic=deterministic)

        # Project output
        hidden_states = self.proj_out(hidden_states)

        # Reshape back to original shape
        hidden_states = hidden_states.reshape(batch, height, width, num_frames, channels)
        hidden_states = hidden_states.transpose(0, 3, 1, 2, 4)

        # Add residual connection
        hidden_states = hidden_states + residual

        # Reshape back to (batch*depth, height, width, channels)
        hidden_states = hidden_states.reshape(batch_depth, height, width, channels)

        return hidden_states

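Illustrative usage (not part of the package diff): a minimal sketch of initializing and applying FlaxTransformerTemporalModel on a video batch whose frames are flattened into the batch axis, shape (B*F, H, W, C). The sizes, the import path, and the context=None call (pure temporal self-attention, as the 3D blocks below use it) are assumptions for the example.

import jax
import jax.numpy as jnp
from flaxdiff.models.unet_3d_blocks import FlaxTransformerTemporalModel

# Hypothetical sizes: 2 clips of 4 frames, 8x8 feature maps, 64 channels (8 heads x 8 dims).
B, F, H, W, C = 2, 4, 8, 8, 64
temporal_attn = FlaxTransformerTemporalModel(in_channels=C, n_heads=8, d_head=C // 8)

x = jnp.ones((B * F, H, W, C))
params = temporal_attn.init(jax.random.PRNGKey(0), x, None, F)  # context=None -> self-attention over frames
y = temporal_attn.apply(params, x, None, F)                     # output keeps the (B*F, H, W, C) shape
assert y.shape == x.shape
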
class TemporalConvLayer(nn.Module):
    in_channels: int
    out_channels: Optional[int] = None
    dropout: float = 0.0
    norm_num_groups: int = 32
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, x: jnp.ndarray, num_frames: int, deterministic=True) -> jnp.ndarray:
        """
        Args:
            x: shape (B*F, H, W, C)
            num_frames: number of frames F per batch element

        Returns:
            A jnp.ndarray of shape (B*F, H, W, C)
        """
        out_channels = self.out_channels or self.in_channels
        bf, h, w, c = x.shape
        b = bf // num_frames

        # Reshape to [B, F, H, W, C], interpret F as "depth" for 3D conv
        x = x.reshape(b, num_frames, h, w, c)
        identity = x

        # conv1: in_channels -> out_channels
        x = nn.GroupNorm(num_groups=self.norm_num_groups)(x)
        x = nn.silu(x)
        x = nn.Conv(features=out_channels, kernel_size=(3, 1, 1),
                    dtype=self.dtype,
                    padding=((1, 1), (0, 0), (0, 0)))(x)

        # conv2: out_channels -> in_channels
        x = nn.GroupNorm(num_groups=self.norm_num_groups)(x)
        x = nn.silu(x)
        x = nn.Dropout(rate=self.dropout)(x, deterministic=deterministic)
        x = nn.Conv(features=self.in_channels, kernel_size=(3, 1, 1),
                    dtype=self.dtype,
                    padding=((1, 1), (0, 0), (0, 0)))(x)

        # conv3: in_channels -> in_channels
        x = nn.GroupNorm(num_groups=self.norm_num_groups)(x)
        x = nn.silu(x)
        x = nn.Dropout(rate=self.dropout)(x, deterministic=deterministic)
        x = nn.Conv(features=self.in_channels, kernel_size=(3, 1, 1),
                    dtype=self.dtype,
                    padding=((1, 1), (0, 0), (0, 0)))(x)

        # conv4 (zero-init): in_channels -> in_channels
        x = nn.GroupNorm(num_groups=self.norm_num_groups)(x)
        x = nn.silu(x)
        x = nn.Dropout(rate=self.dropout)(x, deterministic=deterministic)
        x = nn.Conv(
            features=self.in_channels,
            kernel_size=(3, 1, 1),
            padding=((1, 1), (0, 0), (0, 0)),
            kernel_init=nn.initializers.zeros,
            bias_init=nn.initializers.zeros,
            dtype=self.dtype,
        )(x)

        # Residual connection and reshape back to (B*F, H, W, C)
        x = identity + x
        x = x.reshape(bf, h, w, c)
        return x

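Illustrative usage (not part of the package diff): TemporalConvLayer mixes information across frames with (3, 1, 1) convolutions, and because its last convolution is zero-initialized the whole block starts out as an identity mapping. A minimal sketch with assumed sizes and import path:

import jax
import jax.numpy as jnp
from flaxdiff.models.unet_3d_blocks import TemporalConvLayer

B, F, H, W, C = 2, 4, 8, 8, 64              # hypothetical sizes; C must be divisible by norm_num_groups (32)
temp_conv = TemporalConvLayer(in_channels=C)

x = jnp.ones((B * F, H, W, C))
params = temp_conv.init(jax.random.PRNGKey(0), x, F)
y = temp_conv.apply(params, x, F)

# conv4 starts at zero, so freshly initialized parameters leave the input unchanged.
assert jnp.allclose(y, x)
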
class FlaxCrossAttnDownBlock3D(FlaxCrossAttnDownBlock2D):
    """
    Cross attention 3D downsampling block.
    """

    def setup(self):
        resnets = []
        temp_convs = []
        attentions = []
        temp_attentions = []

        for i in range(self.num_layers):
            in_channels = self.in_channels if i == 0 else self.out_channels

            res_block = FlaxResnetBlock2D(
                in_channels=in_channels,
                out_channels=self.out_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
            )
            resnets.append(res_block)
            temp_conv = TemporalConvLayer(
                in_channels=self.out_channels,
                out_channels=self.out_channels,
                dropout=self.dropout,
                dtype=self.dtype,
            )
            temp_convs.append(temp_conv)
            attn_block = FlaxTransformer2DModel(
                in_channels=self.out_channels,
                n_heads=self.num_attention_heads,
                d_head=self.out_channels // self.num_attention_heads,
                depth=self.transformer_layers_per_block,
                use_linear_projection=self.use_linear_projection,
                only_cross_attention=self.only_cross_attention,
                use_memory_efficient_attention=self.use_memory_efficient_attention,
                split_head_dim=self.split_head_dim,
                dtype=self.dtype,
            )
            attentions.append(attn_block)
            temp_attn_block = FlaxTransformerTemporalModel(
                in_channels=self.out_channels,
                n_heads=self.num_attention_heads,
                d_head=self.out_channels // self.num_attention_heads,
                depth=self.transformer_layers_per_block,
                dropout=self.dropout,
                use_memory_efficient_attention=self.use_memory_efficient_attention,
                split_head_dim=self.split_head_dim,
                dtype=self.dtype,
            )
            temp_attentions.append(temp_attn_block)

        self.temp_convs = temp_convs
        self.temp_attentions = temp_attentions
        self.resnets = resnets
        self.attentions = attentions

        if self.add_downsample:
            # self.downsamplers_0 = FlaxDownsample3D(self.out_channels, dtype=self.dtype)
            self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)

    def __call__(self, hidden_states, temb, encoder_hidden_states, num_frames, deterministic=True):
        output_states = ()

        for resnet, temp_conv, attn, temp_attn in zip(self.resnets, self.temp_convs, self.attentions, self.temp_attentions):
            hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
            hidden_states = temp_conv(hidden_states, num_frames=num_frames, deterministic=deterministic)
            hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
            hidden_states = temp_attn(hidden_states, None, num_frames=num_frames, deterministic=deterministic)
            output_states += (hidden_states,)

        if self.add_downsample:
            hidden_states = self.downsamplers_0(hidden_states)
            output_states += (hidden_states,)

        return hidden_states, output_states

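Illustrative usage (not part of the package diff): each layer of FlaxCrossAttnDownBlock3D runs a spatial resnet, a temporal convolution, spatial cross-attention, and temporal self-attention, then optionally downsamples. The constructor fields are inherited from diffusers' FlaxCrossAttnDownBlock2D; the concrete values, the time-embedding width, and the conditioning shape below are assumptions.

import jax
import jax.numpy as jnp
from flaxdiff.models.unet_3d_blocks import FlaxCrossAttnDownBlock3D

down = FlaxCrossAttnDownBlock3D(
    in_channels=64,
    out_channels=128,
    num_layers=2,
    num_attention_heads=8,
    add_downsample=True,
)

B, F, H, W = 2, 4, 16, 16
x = jnp.ones((B * F, H, W, 64))          # frames flattened into the batch axis
temb = jnp.ones((B * F, 256))            # time embedding (assumed width)
context = jnp.ones((B * F, 77, 768))     # conditioning tokens (assumed shape)

params = down.init(jax.random.PRNGKey(0), x, temb, context, F)
y, skips = down.apply(params, x, temb, context, F)
# y: (B*F, H/2, W/2, 128); skips holds one tensor per layer plus the downsampled output.
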
class FlaxDownBlock3D(FlaxDownBlock2D):
    """
    Basic downsampling block without attention.
    """
    def setup(self):
        resnets = []
        temp_convs = []

        for i in range(self.num_layers):
            in_channels = self.in_channels if i == 0 else self.out_channels

            res_block = FlaxResnetBlock2D(
                in_channels=in_channels,
                out_channels=self.out_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
            )
            resnets.append(res_block)
            temp_conv = TemporalConvLayer(
                in_channels=self.out_channels,
                out_channels=self.out_channels,
                dropout=self.dropout,
                dtype=self.dtype,
            )
            temp_convs.append(temp_conv)
        self.temp_convs = temp_convs
        self.resnets = resnets

        if self.add_downsample:
            self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)

    def __call__(self, hidden_states, temb, num_frames, deterministic=True):
        output_states = ()

        for resnet, temp_conv in zip(self.resnets, self.temp_convs):
            hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
            hidden_states = temp_conv(hidden_states, num_frames=num_frames, deterministic=deterministic)
            output_states += (hidden_states,)

        if self.add_downsample:
            hidden_states = self.downsamplers_0(hidden_states)
            output_states += (hidden_states,)

        return hidden_states, output_states

class FlaxCrossAttnUpBlock3D(FlaxCrossAttnUpBlock2D):
    """
    Cross attention 3D upsampling block.
    """

    def setup(self):
        resnets = []
        temp_convs = []
        attentions = []
        temp_attentions = []

        for i in range(self.num_layers):
            res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels
            resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels

            res_block = FlaxResnetBlock2D(
                in_channels=resnet_in_channels + res_skip_channels,
                out_channels=self.out_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
            )
            resnets.append(res_block)
            temp_conv = TemporalConvLayer(
                in_channels=self.out_channels,
                out_channels=self.out_channels,
                dropout=self.dropout,
                dtype=self.dtype,
            )
            temp_convs.append(temp_conv)
            attn_block = FlaxTransformer2DModel(
                in_channels=self.out_channels,
                n_heads=self.num_attention_heads,
                d_head=self.out_channels // self.num_attention_heads,
                depth=self.transformer_layers_per_block,
                use_linear_projection=self.use_linear_projection,
                only_cross_attention=self.only_cross_attention,
                use_memory_efficient_attention=self.use_memory_efficient_attention,
                split_head_dim=self.split_head_dim,
                dtype=self.dtype,
            )
            attentions.append(attn_block)
            temp_attn_block = FlaxTransformerTemporalModel(
                in_channels=self.out_channels,
                n_heads=self.num_attention_heads,
                d_head=self.out_channels // self.num_attention_heads,
                depth=self.transformer_layers_per_block,
                dropout=self.dropout,
                use_memory_efficient_attention=self.use_memory_efficient_attention,
                split_head_dim=self.split_head_dim,
                dtype=self.dtype,
            )
            temp_attentions.append(temp_attn_block)

        self.resnets = resnets
        self.attentions = attentions
        self.temp_convs = temp_convs
        self.temp_attentions = temp_attentions

        if self.add_upsample:
            self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)

    def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states, num_frames, deterministic=True):
        for resnet, temp_conv, attn, temp_attn in zip(self.resnets, self.temp_convs, self.attentions, self.temp_attentions):
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1)

            hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
            hidden_states = temp_conv(hidden_states, num_frames=num_frames, deterministic=deterministic)
            hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
            hidden_states = temp_attn(hidden_states, None, num_frames=num_frames, deterministic=deterministic)

        if self.add_upsample:
            hidden_states = self.upsamplers_0(hidden_states)

        return hidden_states

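Illustrative usage (not part of the package diff): FlaxCrossAttnUpBlock3D consumes the encoder skips from the end of res_hidden_states_tuple, one per layer, concatenating each onto the incoming features before the resnet. A minimal single-layer sketch; the constructor fields come from diffusers' FlaxCrossAttnUpBlock2D and all sizes are assumptions.

import jax
import jax.numpy as jnp
from flaxdiff.models.unet_3d_blocks import FlaxCrossAttnUpBlock3D

up = FlaxCrossAttnUpBlock3D(
    in_channels=64,            # channels of the matching encoder skip
    out_channels=128,
    prev_output_channel=128,   # channels arriving from the previous (e.g. mid) block
    num_layers=1,
    num_attention_heads=8,
    add_upsample=True,
)

B, F, H, W = 2, 4, 8, 8
x = jnp.ones((B * F, H, W, 128))
skips = (jnp.ones((B * F, H, W, 64)),)   # popped from the end, one skip per layer
temb = jnp.ones((B * F, 256))
context = jnp.ones((B * F, 77, 768))

params = up.init(jax.random.PRNGKey(0), x, skips, temb, context, F)
y = up.apply(params, x, skips, temb, context, F)   # (B*F, 2*H, 2*W, 128)
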
class FlaxUpBlock3D(FlaxUpBlock2D):
    """
    Basic upsampling block without attention.
    """
    def setup(self):
        resnets = []
        temp_convs = []

        for i in range(self.num_layers):
            res_skip_channels = self.in_channels if (i == self.num_layers - 1) else self.out_channels
            resnet_in_channels = self.prev_output_channel if i == 0 else self.out_channels

            res_block = FlaxResnetBlock2D(
                in_channels=resnet_in_channels + res_skip_channels,
                out_channels=self.out_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
            )
            resnets.append(res_block)
            temp_conv = TemporalConvLayer(
                in_channels=self.out_channels,
                out_channels=self.out_channels,
                dropout=self.dropout,
                dtype=self.dtype,
            )
            temp_convs.append(temp_conv)

        self.resnets = resnets
        self.temp_convs = temp_convs

        if self.add_upsample:
            self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)

    def __call__(self, hidden_states, res_hidden_states_tuple, temb, num_frames, deterministic=True):
        for resnet, temp_conv in zip(self.resnets, self.temp_convs):
            # pop res hidden states
            res_hidden_states = res_hidden_states_tuple[-1]
            res_hidden_states_tuple = res_hidden_states_tuple[:-1]
            hidden_states = jnp.concatenate((hidden_states, res_hidden_states), axis=-1)

            hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
            hidden_states = temp_conv(hidden_states, num_frames=num_frames, deterministic=deterministic)

        if self.add_upsample:
            hidden_states = self.upsamplers_0(hidden_states)

        return hidden_states

class FlaxUNetMidBlock3DCrossAttn(FlaxUNetMidBlock2DCrossAttn):
    """
    Middle block with cross-attention for 3D UNet.
    """
    def setup(self):
        # there is always at least one resnet
        resnets = [
            FlaxResnetBlock2D(
                in_channels=self.in_channels,
                out_channels=self.in_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
            )
        ]
        temp_convs = [
            TemporalConvLayer(
                in_channels=self.in_channels,
                out_channels=self.in_channels,
                dropout=self.dropout,
                dtype=self.dtype,
            )
        ]

        attentions = []
        temp_attentions = []

        for _ in range(self.num_layers):
            attn_block = FlaxTransformer2DModel(
                in_channels=self.in_channels,
                n_heads=self.num_attention_heads,
                d_head=self.in_channels // self.num_attention_heads,
                depth=self.transformer_layers_per_block,
                use_linear_projection=self.use_linear_projection,
                use_memory_efficient_attention=self.use_memory_efficient_attention,
                split_head_dim=self.split_head_dim,
                dtype=self.dtype,
            )
            attentions.append(attn_block)

            temp_block = FlaxTransformerTemporalModel(
                in_channels=self.in_channels,
                n_heads=self.num_attention_heads,
                d_head=self.in_channels // self.num_attention_heads,
                depth=self.transformer_layers_per_block,
                dropout=self.dropout,
                use_memory_efficient_attention=self.use_memory_efficient_attention,
                split_head_dim=self.split_head_dim,
                dtype=self.dtype,
            )
            temp_attentions.append(temp_block)

            res_block = FlaxResnetBlock2D(
                in_channels=self.in_channels,
                out_channels=self.in_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
            )
            resnets.append(res_block)
            temp_conv = TemporalConvLayer(
                in_channels=self.in_channels,
                out_channels=self.in_channels,
                dropout=self.dropout,
                dtype=self.dtype,
            )
            temp_convs.append(temp_conv)

        self.temp_convs = temp_convs
        self.temp_attentions = temp_attentions
        self.resnets = resnets
        self.attentions = attentions

    def __call__(self, hidden_states, temb, encoder_hidden_states, num_frames, deterministic=True):
        hidden_states = self.resnets[0](hidden_states, temb, deterministic=deterministic)
        hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames, deterministic=deterministic)

        for attn, temp_attn, resnet, temp_conv in zip(
            self.attentions, self.temp_attentions, self.resnets[1:], self.temp_convs[1:]
        ):
            hidden_states = attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
            hidden_states = temp_attn(hidden_states, None, num_frames=num_frames, deterministic=deterministic)
            hidden_states = resnet(hidden_states, temb, deterministic=deterministic)
            hidden_states = temp_conv(hidden_states, num_frames=num_frames, deterministic=deterministic)

        return hidden_states
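Taken together, these blocks keep the (B*F, H, W, C) layout of the diffusers 2D Flax blocks they extend; the only interface change is the extra num_frames argument, which the temporal layers use to un-flatten the frame axis. Illustrative usage of the mid block (not part of the package diff; sizes and import path are assumptions):

import jax
import jax.numpy as jnp
from flaxdiff.models.unet_3d_blocks import FlaxUNetMidBlock3DCrossAttn

# Hypothetical sizes; constructor fields come from the inherited diffusers FlaxUNetMidBlock2DCrossAttn.
mid = FlaxUNetMidBlock3DCrossAttn(in_channels=128, num_attention_heads=8)

B, F, H, W = 2, 4, 8, 8
x = jnp.ones((B * F, H, W, 128))
temb = jnp.ones((B * F, 256))
context = jnp.ones((B * F, 77, 768))

params = mid.init(jax.random.PRNGKey(0), x, temb, context, F)
y = mid.apply(params, x, temb, context, F)   # same (B*F, H, W, 128) shape out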