flaxdiff 0.1.38.1__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (44) hide show
  1. flaxdiff/data/__init__.py +5 -1
  2. flaxdiff/data/benchmark_decord.py +443 -0
  3. flaxdiff/data/dataloaders.py +608 -0
  4. flaxdiff/data/dataset_map.py +61 -6
  5. flaxdiff/data/online_loader.py +779 -150
  6. flaxdiff/data/sources/audio_utils.py +142 -0
  7. flaxdiff/data/sources/av_example.py +125 -0
  8. flaxdiff/data/sources/av_utils.py +590 -0
  9. flaxdiff/data/sources/base.py +129 -0
  10. flaxdiff/data/sources/images.py +309 -0
  11. flaxdiff/data/sources/utils.py +158 -0
  12. flaxdiff/data/sources/videos.py +250 -0
  13. flaxdiff/data/sources/voxceleb2.py +412 -0
  14. flaxdiff/inference/__init__.py +0 -0
  15. flaxdiff/inference/pipeline.py +260 -0
  16. flaxdiff/inference/utils.py +320 -0
  17. flaxdiff/inputs/__init__.py +173 -0
  18. flaxdiff/inputs/encoders.py +98 -0
  19. flaxdiff/models/__init__.py +2 -1
  20. flaxdiff/models/autoencoder/autoencoder.py +141 -9
  21. flaxdiff/models/autoencoder/diffusers.py +88 -25
  22. flaxdiff/models/autoencoder/simple_autoenc.py +40 -8
  23. flaxdiff/models/simple_unet.py +5 -5
  24. flaxdiff/models/simple_vit.py +1 -1
  25. flaxdiff/models/unet_3d.py +446 -0
  26. flaxdiff/models/unet_3d_blocks.py +505 -0
  27. flaxdiff/samplers/common.py +358 -96
  28. flaxdiff/samplers/ddim.py +44 -5
  29. flaxdiff/schedulers/karras.py +20 -12
  30. flaxdiff/trainer/__init__.py +2 -1
  31. flaxdiff/trainer/autoencoder_trainer.py +1 -2
  32. flaxdiff/trainer/diffusion_trainer.py +33 -27
  33. flaxdiff/trainer/general_diffusion_trainer.py +583 -0
  34. flaxdiff/trainer/simple_trainer.py +48 -31
  35. flaxdiff/utils.py +128 -57
  36. {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/METADATA +1 -1
  37. flaxdiff-0.2.0.dist-info/RECORD +64 -0
  38. {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/WHEEL +1 -1
  39. flaxdiff/data/datasets.py +0 -169
  40. flaxdiff/data/sources/gcs.py +0 -81
  41. flaxdiff/data/sources/tfds.py +0 -79
  42. flaxdiff/trainer/video_diffusion_trainer.py +0 -62
  43. flaxdiff-0.1.38.1.dist-info/RECORD +0 -50
  44. {flaxdiff-0.1.38.1.dist-info → flaxdiff-0.2.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,505 @@
1
+ from typing import Tuple, Optional
2
+
3
+ import flax.linen as nn
4
+ import jax
5
+ import jax.numpy as jnp
6
+
7
+ from diffusers.models.attention_flax import (
8
+ FlaxBasicTransformerBlock,
9
+ FlaxTransformer2DModel,
10
+ )
11
+
12
+ from diffusers.models.resnet_flax import (
13
+ FlaxResnetBlock2D,
14
+ FlaxUpsample2D,
15
+ FlaxDownsample2D,
16
+ )
17
+
18
+ from diffusers.models.unets.unet_2d_blocks_flax import (
19
+ FlaxCrossAttnDownBlock2D,
20
+ FlaxDownBlock2D,
21
+ FlaxUNetMidBlock2DCrossAttn,
22
+ FlaxUpBlock2D,
23
+ FlaxCrossAttnUpBlock2D,
24
+ )
25
+
26
class FlaxTransformerTemporalModel(nn.Module):
    """
    Transformer for temporal attention in 3D UNet.

    The input is a stack of per-frame feature maps with shape
    (batch * num_frames, height, width, channels). Attention runs along the
    frame axis independently for every spatial position; the result is added
    to the input and returned in the same layout.

    Attributes:
        in_channels: channel width of the incoming feature map. The output
            projection maps back to this width so that the residual addition
            and the final reshape are valid even when
            ``n_heads * d_head != in_channels``.
        n_heads: number of attention heads.
        d_head: width of each attention head.
        depth: number of stacked transformer blocks.
        dropout: dropout rate forwarded to the transformer blocks.
        only_cross_attention: forwarded to FlaxBasicTransformerBlock.
        dtype: dtype passed to the dense projections and transformer blocks.
        use_memory_efficient_attention: forwarded to FlaxBasicTransformerBlock.
        split_head_dim: forwarded to FlaxBasicTransformerBlock.
    """
    in_channels: int
    n_heads: int
    d_head: int
    depth: int = 1
    dropout: float = 0.0
    only_cross_attention: bool = False
    dtype: jnp.dtype = jnp.float32
    use_memory_efficient_attention: bool = False
    split_head_dim: bool = False

    def setup(self):
        self.norm = nn.GroupNorm(num_groups=32, epsilon=1e-5)

        inner_dim = self.n_heads * self.d_head
        self.proj_in = nn.Dense(inner_dim, dtype=self.dtype)

        # Use existing FlaxBasicTransformerBlock from diffusers
        self.transformer_blocks = [
            FlaxBasicTransformerBlock(
                inner_dim,
                self.n_heads,
                self.d_head,
                dropout=self.dropout,
                only_cross_attention=self.only_cross_attention,
                dtype=self.dtype,
                use_memory_efficient_attention=self.use_memory_efficient_attention,
                split_head_dim=self.split_head_dim,
            )
            for _ in range(self.depth)
        ]

        # Project back to the input width (not inner_dim): the output is
        # reshaped with the input channel count and added to the residual,
        # which only works if the last dimension equals in_channels.
        self.proj_out = nn.Dense(self.in_channels, dtype=self.dtype)

        # Declared for parity with the existing attribute set; it is not
        # applied in __call__.
        self.dropout_layer = nn.Dropout(rate=self.dropout)

    def __call__(self, hidden_states: jnp.ndarray, context: jnp.ndarray, num_frames: int, deterministic=True):
        """
        Args:
            hidden_states: array of shape (batch * num_frames, height, width, channels).
            context: conditioning passed to every transformer block; the
                blocks in this file call this module with ``None``.
                NOTE(review): a non-None context would have to be compatible
                with a leading dimension of batch * height * width - confirm
                against callers.
            num_frames: number of frames per batch element.
            deterministic: forwarded to the transformer blocks.

        Returns:
            Array of shape (batch * num_frames, height, width, channels).
        """
        # Save original shape for later reshaping
        batch_depth, height, width, channels = hidden_states.shape
        batch = batch_depth // num_frames

        # Reshape to (batch, depth, height, width, channels)
        hidden_states = hidden_states.reshape(batch, num_frames, height, width, channels)
        residual = hidden_states

        # Apply normalization
        hidden_states = self.norm(hidden_states)

        # Reshape for temporal attention: (batch, depth, height, width, channels) ->
        # (batch*height*width, depth, channels)
        hidden_states = hidden_states.transpose(0, 2, 3, 1, 4)
        hidden_states = hidden_states.reshape(batch * height * width, num_frames, channels)

        # Project input
        hidden_states = self.proj_in(hidden_states)

        # Apply transformer blocks
        for block in self.transformer_blocks:
            hidden_states = block(hidden_states, context=context, deterministic=deterministic)

        # Project output back to the input channel width
        hidden_states = self.proj_out(hidden_states)

        # Reshape back to original shape
        hidden_states = hidden_states.reshape(batch, height, width, num_frames, channels)
        hidden_states = hidden_states.transpose(0, 3, 1, 2, 4)

        # Add residual connection
        hidden_states = hidden_states + residual

        # Reshape back to (batch*depth, height, width, channels)
        hidden_states = hidden_states.reshape(batch_depth, height, width, channels)

        return hidden_states
102
+
103
class TemporalConvLayer(nn.Module):
    """
    Residual stack of four convolutions that mix information along the frame axis.

    Every stage is GroupNorm -> SiLU -> (Dropout, except in the first stage)
    -> Conv with a (3, 1, 1) kernel, i.e. width 3 over frames and width 1
    over both spatial axes. The last convolution is zero-initialised.

    Attributes:
        in_channels: channel width produced by stages two to four.
        out_channels: channel width produced by the first stage; falls back
            to ``in_channels`` when unset.
        dropout: dropout rate used in stages two to four.
        norm_num_groups: number of groups for every GroupNorm.
        dtype: dtype passed to the convolutions.
    """
    in_channels: int
    out_channels: Optional[int] = None
    dropout: float = 0.0
    norm_num_groups: int = 32
    dtype: jnp.dtype = jnp.float32

    @nn.compact
    def __call__(self, x: jnp.ndarray, num_frames: int, deterministic=True) -> jnp.ndarray:
        """
        Args:
            x: shape (B*F, H, W, C)
            num_frames: number of frames F per batch element
            deterministic: forwarded to every dropout layer

        Returns:
            A jnp.ndarray of shape (B*F, H, W, C)
        """
        first_stage_width = self.out_channels or self.in_channels
        stacked, rows, cols, chans = x.shape
        clips = stacked // num_frames

        # View the frames as the depth axis of a 3D convolution: [B, F, H, W, C].
        skip = x.reshape(clips, num_frames, rows, cols, chans)

        frame_kernel = (3, 1, 1)
        frame_padding = ((1, 1), (0, 0), (0, 0))
        zero_init = dict(
            kernel_init=nn.initializers.zeros,
            bias_init=nn.initializers.zeros,
        )

        # (conv features, apply dropout?, extra conv kwargs); modules are
        # created in this order so their auto-generated names stay stable.
        stage_specs = (
            (first_stage_width, False, {}),
            (self.in_channels, True, {}),
            (self.in_channels, True, {}),
            (self.in_channels, True, zero_init),
        )

        y = skip
        for width, with_dropout, conv_extra in stage_specs:
            y = nn.GroupNorm(num_groups=self.norm_num_groups)(y)
            y = nn.silu(y)
            if with_dropout:
                y = nn.Dropout(rate=self.dropout)(y, deterministic=deterministic)
            y = nn.Conv(
                features=width,
                kernel_size=frame_kernel,
                padding=frame_padding,
                dtype=self.dtype,
                **conv_extra,
            )(y)

        # Residual connection, then fold frames back into the leading axis.
        return (skip + y).reshape(stacked, rows, cols, chans)
168
+
169
+
170
class FlaxCrossAttnDownBlock3D(FlaxCrossAttnDownBlock2D):
    """
    Cross attention 3D downsampling block.

    Each of the ``num_layers`` stages runs, in order: a 2D resnet, a temporal
    convolution, a 2D transformer conditioned on ``encoder_hidden_states``
    and a temporal transformer. Configuration fields are read from the
    FlaxCrossAttnDownBlock2D base class.
    """

    def setup(self):
        stage_ids = range(self.num_layers)
        head_width = self.out_channels // self.num_attention_heads
        shared_attention_args = dict(
            in_channels=self.out_channels,
            n_heads=self.num_attention_heads,
            d_head=head_width,
            depth=self.transformer_layers_per_block,
            use_memory_efficient_attention=self.use_memory_efficient_attention,
            split_head_dim=self.split_head_dim,
            dtype=self.dtype,
        )

        self.temp_convs = [
            TemporalConvLayer(
                in_channels=self.out_channels,
                out_channels=self.out_channels,
                dropout=self.dropout,
                dtype=self.dtype,
            )
            for _ in stage_ids
        ]
        self.temp_attentions = [
            FlaxTransformerTemporalModel(dropout=self.dropout, **shared_attention_args)
            for _ in stage_ids
        ]
        # Only the first resnet sees the block's input width.
        self.resnets = [
            FlaxResnetBlock2D(
                in_channels=self.out_channels if idx else self.in_channels,
                out_channels=self.out_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
            )
            for idx in stage_ids
        ]
        self.attentions = [
            FlaxTransformer2DModel(
                use_linear_projection=self.use_linear_projection,
                only_cross_attention=self.only_cross_attention,
                **shared_attention_args,
            )
            for _ in stage_ids
        ]

        if self.add_downsample:
            self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)

    def __call__(self, hidden_states, temb, encoder_hidden_states, num_frames, deterministic=True):
        """
        Run all stages and the optional downsampler.

        Returns:
            ``(hidden_states, output_states)`` where ``output_states`` is a
            tuple holding the activation after every stage and, when
            ``add_downsample`` is set, the downsampled activation as well.
        """
        collected = []
        stages = zip(self.resnets, self.temp_convs, self.attentions, self.temp_attentions)

        for spatial_res, frame_conv, spatial_attn, frame_attn in stages:
            hidden_states = spatial_res(hidden_states, temb, deterministic=deterministic)
            hidden_states = frame_conv(hidden_states, num_frames=num_frames, deterministic=deterministic)
            hidden_states = spatial_attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
            # The temporal transformer gets no external context.
            hidden_states = frame_attn(hidden_states, None, num_frames=num_frames, deterministic=deterministic)
            collected.append(hidden_states)

        if self.add_downsample:
            hidden_states = self.downsamplers_0(hidden_states)
            collected.append(hidden_states)

        return hidden_states, tuple(collected)
246
+
247
+
248
class FlaxDownBlock3D(FlaxDownBlock2D):
    """
    Basic downsampling block without attention.

    Each of the ``num_layers`` stages is a 2D resnet followed by a temporal
    convolution. Configuration fields are read from the FlaxDownBlock2D base
    class.
    """

    def setup(self):
        stage_ids = range(self.num_layers)

        self.temp_convs = [
            TemporalConvLayer(
                in_channels=self.out_channels,
                out_channels=self.out_channels,
                dropout=self.dropout,
                dtype=self.dtype,
            )
            for _ in stage_ids
        ]
        # Only the first resnet sees the block's input width.
        self.resnets = [
            FlaxResnetBlock2D(
                in_channels=self.out_channels if idx else self.in_channels,
                out_channels=self.out_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
            )
            for idx in stage_ids
        ]

        if self.add_downsample:
            self.downsamplers_0 = FlaxDownsample2D(self.out_channels, dtype=self.dtype)

    def __call__(self, hidden_states, temb, num_frames, deterministic=True):
        """
        Run all stages and the optional downsampler.

        Returns:
            ``(hidden_states, output_states)`` where ``output_states`` is a
            tuple holding the activation after every stage and, when
            ``add_downsample`` is set, the downsampled activation as well.
        """
        collected = []

        for spatial_res, frame_conv in zip(self.resnets, self.temp_convs):
            hidden_states = spatial_res(hidden_states, temb, deterministic=deterministic)
            hidden_states = frame_conv(hidden_states, num_frames=num_frames, deterministic=deterministic)
            collected.append(hidden_states)

        if self.add_downsample:
            hidden_states = self.downsamplers_0(hidden_states)
            collected.append(hidden_states)

        return hidden_states, tuple(collected)
292
+
293
+
294
class FlaxCrossAttnUpBlock3D(FlaxCrossAttnUpBlock2D):
    """
    Cross attention 3D upsampling block.

    Each of the ``num_layers`` stages concatenates one skip activation onto
    the channel axis and then runs a 2D resnet, a temporal convolution, a 2D
    transformer conditioned on ``encoder_hidden_states`` and a temporal
    transformer. Configuration fields are read from the
    FlaxCrossAttnUpBlock2D base class.
    """

    def setup(self):
        stage_ids = range(self.num_layers)
        final_stage = self.num_layers - 1
        head_width = self.out_channels // self.num_attention_heads
        shared_attention_args = dict(
            in_channels=self.out_channels,
            n_heads=self.num_attention_heads,
            d_head=head_width,
            depth=self.transformer_layers_per_block,
            use_memory_efficient_attention=self.use_memory_efficient_attention,
            split_head_dim=self.split_head_dim,
            dtype=self.dtype,
        )

        def resnet_input_width(idx):
            # Width of the running activation plus the width of the skip
            # activation concatenated onto it at this stage.
            carried = self.out_channels if idx else self.prev_output_channel
            skipped = self.in_channels if idx == final_stage else self.out_channels
            return carried + skipped

        self.resnets = [
            FlaxResnetBlock2D(
                in_channels=resnet_input_width(idx),
                out_channels=self.out_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
            )
            for idx in stage_ids
        ]
        self.attentions = [
            FlaxTransformer2DModel(
                use_linear_projection=self.use_linear_projection,
                only_cross_attention=self.only_cross_attention,
                **shared_attention_args,
            )
            for _ in stage_ids
        ]
        self.temp_convs = [
            TemporalConvLayer(
                in_channels=self.out_channels,
                out_channels=self.out_channels,
                dropout=self.dropout,
                dtype=self.dtype,
            )
            for _ in stage_ids
        ]
        self.temp_attentions = [
            FlaxTransformerTemporalModel(dropout=self.dropout, **shared_attention_args)
            for _ in stage_ids
        ]

        if self.add_upsample:
            self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)

    def __call__(self, hidden_states, res_hidden_states_tuple, temb, encoder_hidden_states, num_frames, deterministic=True):
        """
        Run all stages and the optional upsampler.

        ``res_hidden_states_tuple`` is consumed from its end: stage ``i``
        uses the element at position ``-1 - i``.

        Returns:
            The activation after the last stage (upsampled when
            ``add_upsample`` is set).
        """
        stages = zip(self.resnets, self.temp_convs, self.attentions, self.temp_attentions)

        for offset, (spatial_res, frame_conv, spatial_attn, frame_attn) in enumerate(stages):
            # Take skip activations last-to-first without re-slicing the tuple.
            skip = res_hidden_states_tuple[-1 - offset]
            hidden_states = jnp.concatenate([hidden_states, skip], axis=-1)

            hidden_states = spatial_res(hidden_states, temb, deterministic=deterministic)
            hidden_states = frame_conv(hidden_states, num_frames=num_frames, deterministic=deterministic)
            hidden_states = spatial_attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
            # The temporal transformer gets no external context.
            hidden_states = frame_attn(hidden_states, None, num_frames=num_frames, deterministic=deterministic)

        if self.add_upsample:
            hidden_states = self.upsamplers_0(hidden_states)

        return hidden_states
371
+
372
+
373
class FlaxUpBlock3D(FlaxUpBlock2D):
    """
    Basic upsampling block without attention.

    Each of the ``num_layers`` stages concatenates one skip activation onto
    the channel axis and then runs a 2D resnet followed by a temporal
    convolution. Configuration fields are read from the FlaxUpBlock2D base
    class.
    """

    def setup(self):
        stage_ids = range(self.num_layers)
        final_stage = self.num_layers - 1

        def resnet_input_width(idx):
            # Width of the running activation plus the width of the skip
            # activation concatenated onto it at this stage.
            carried = self.out_channels if idx else self.prev_output_channel
            skipped = self.in_channels if idx == final_stage else self.out_channels
            return carried + skipped

        self.resnets = [
            FlaxResnetBlock2D(
                in_channels=resnet_input_width(idx),
                out_channels=self.out_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
            )
            for idx in stage_ids
        ]
        self.temp_convs = [
            TemporalConvLayer(
                in_channels=self.out_channels,
                out_channels=self.out_channels,
                dropout=self.dropout,
                dtype=self.dtype,
            )
            for _ in stage_ids
        ]

        if self.add_upsample:
            self.upsamplers_0 = FlaxUpsample2D(self.out_channels, dtype=self.dtype)

    def __call__(self, hidden_states, res_hidden_states_tuple, temb, num_frames, deterministic=True):
        """
        Run all stages and the optional upsampler.

        ``res_hidden_states_tuple`` is consumed from its end: stage ``i``
        uses the element at position ``-1 - i``.

        Returns:
            The activation after the last stage (upsampled when
            ``add_upsample`` is set).
        """
        for offset, (spatial_res, frame_conv) in enumerate(zip(self.resnets, self.temp_convs)):
            # Take skip activations last-to-first without re-slicing the tuple.
            skip = res_hidden_states_tuple[-1 - offset]
            hidden_states = jnp.concatenate([hidden_states, skip], axis=-1)

            hidden_states = spatial_res(hidden_states, temb, deterministic=deterministic)
            hidden_states = frame_conv(hidden_states, num_frames=num_frames, deterministic=deterministic)

        if self.add_upsample:
            hidden_states = self.upsamplers_0(hidden_states)

        return hidden_states
420
+
421
+
422
class FlaxUNetMidBlock3DCrossAttn(FlaxUNetMidBlock2DCrossAttn):
    """
    Middle block with cross-attention for 3D UNet.

    Layout: one leading (resnet, temporal conv) pair, then ``num_layers``
    repetitions of (2D transformer, temporal transformer, resnet, temporal
    conv). The channel width stays at ``in_channels`` throughout.
    Configuration fields are read from the FlaxUNetMidBlock2DCrossAttn base
    class.
    """

    def setup(self):
        head_width = self.in_channels // self.num_attention_heads
        shared_attention_args = dict(
            in_channels=self.in_channels,
            n_heads=self.num_attention_heads,
            d_head=head_width,
            depth=self.transformer_layers_per_block,
            use_memory_efficient_attention=self.use_memory_efficient_attention,
            split_head_dim=self.split_head_dim,
            dtype=self.dtype,
        )
        # One pair before the attention stages plus one pair per stage,
        # so there is always at least one resnet.
        pair_count = self.num_layers + 1

        self.temp_convs = [
            TemporalConvLayer(
                in_channels=self.in_channels,
                out_channels=self.in_channels,
                dropout=self.dropout,
                dtype=self.dtype,
            )
            for _ in range(pair_count)
        ]
        self.temp_attentions = [
            FlaxTransformerTemporalModel(dropout=self.dropout, **shared_attention_args)
            for _ in range(self.num_layers)
        ]
        self.resnets = [
            FlaxResnetBlock2D(
                in_channels=self.in_channels,
                out_channels=self.in_channels,
                dropout_prob=self.dropout,
                dtype=self.dtype,
            )
            for _ in range(pair_count)
        ]
        self.attentions = [
            FlaxTransformer2DModel(
                use_linear_projection=self.use_linear_projection,
                **shared_attention_args,
            )
            for _ in range(self.num_layers)
        ]

    def __call__(self, hidden_states, temb, encoder_hidden_states, num_frames, deterministic=True):
        """
        Apply the leading resnet / temporal conv pair, then every attention stage.

        Returns:
            The activation after the last stage.
        """
        entry_resnet, *stage_resnets = self.resnets
        entry_conv, *stage_convs = self.temp_convs

        hidden_states = entry_resnet(hidden_states, temb, deterministic=deterministic)
        hidden_states = entry_conv(hidden_states, num_frames=num_frames, deterministic=deterministic)

        stages = zip(self.attentions, self.temp_attentions, stage_resnets, stage_convs)
        for spatial_attn, frame_attn, spatial_res, frame_conv in stages:
            hidden_states = spatial_attn(hidden_states, encoder_hidden_states, deterministic=deterministic)
            # The temporal transformer gets no external context.
            hidden_states = frame_attn(hidden_states, None, num_frames=num_frames, deterministic=deterministic)
            hidden_states = spatial_res(hidden_states, temb, deterministic=deterministic)
            hidden_states = frame_conv(hidden_states, num_frames=num_frames, deterministic=deterministic)

        return hidden_states