flaxdiff 0.1.38__py3-none-any.whl → 0.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (46)
  1. flaxdiff/data/__init__.py +5 -1
  2. flaxdiff/data/benchmark_decord.py +443 -0
  3. flaxdiff/data/dataloaders.py +608 -0
  4. flaxdiff/data/dataset_map.py +61 -6
  5. flaxdiff/data/online_loader.py +779 -150
  6. flaxdiff/data/sources/audio_utils.py +142 -0
  7. flaxdiff/data/sources/av_example.py +125 -0
  8. flaxdiff/data/sources/av_utils.py +590 -0
  9. flaxdiff/data/sources/base.py +129 -0
  10. flaxdiff/data/sources/images.py +309 -0
  11. flaxdiff/data/sources/utils.py +158 -0
  12. flaxdiff/data/sources/videos.py +250 -0
  13. flaxdiff/data/sources/voxceleb2.py +412 -0
  14. flaxdiff/inference/__init__.py +0 -0
  15. flaxdiff/inference/pipeline.py +260 -0
  16. flaxdiff/inference/utils.py +320 -0
  17. flaxdiff/inputs/__init__.py +173 -0
  18. flaxdiff/inputs/encoders.py +98 -0
  19. flaxdiff/models/__init__.py +2 -1
  20. flaxdiff/models/attention.py +22 -16
  21. flaxdiff/models/autoencoder/autoencoder.py +141 -9
  22. flaxdiff/models/autoencoder/diffusers.py +88 -25
  23. flaxdiff/models/autoencoder/simple_autoenc.py +40 -8
  24. flaxdiff/models/common.py +8 -18
  25. flaxdiff/models/simple_unet.py +6 -17
  26. flaxdiff/models/simple_vit.py +9 -13
  27. flaxdiff/models/unet_3d.py +446 -0
  28. flaxdiff/models/unet_3d_blocks.py +505 -0
  29. flaxdiff/samplers/common.py +358 -96
  30. flaxdiff/samplers/ddim.py +44 -5
  31. flaxdiff/schedulers/karras.py +20 -12
  32. flaxdiff/trainer/__init__.py +2 -1
  33. flaxdiff/trainer/autoencoder_trainer.py +1 -2
  34. flaxdiff/trainer/diffusion_trainer.py +35 -29
  35. flaxdiff/trainer/general_diffusion_trainer.py +583 -0
  36. flaxdiff/trainer/simple_trainer.py +51 -16
  37. flaxdiff/utils.py +128 -57
  38. {flaxdiff-0.1.38.dist-info → flaxdiff-0.2.0.dist-info}/METADATA +1 -1
  39. flaxdiff-0.2.0.dist-info/RECORD +64 -0
  40. {flaxdiff-0.1.38.dist-info → flaxdiff-0.2.0.dist-info}/WHEEL +1 -1
  41. flaxdiff/data/datasets.py +0 -169
  42. flaxdiff/data/sources/gcs.py +0 -81
  43. flaxdiff/data/sources/tfds.py +0 -79
  44. flaxdiff/trainer/video_diffusion_trainer.py +0 -62
  45. flaxdiff-0.1.38.dist-info/RECORD +0 -50
  46. {flaxdiff-0.1.38.dist-info → flaxdiff-0.2.0.dist-info}/top_level.txt +0 -0
flaxdiff/models/unet_3d.py (new file)
@@ -0,0 +1,446 @@
+ from typing import Dict, Optional, Tuple, Union
+
+ import flax
+ import flax.linen as nn
+ import jax
+ import jax.numpy as jnp
+ from flax.core.frozen_dict import FrozenDict
+
+ from diffusers.configuration_utils import ConfigMixin, flax_register_to_config
+ from diffusers.utils import BaseOutput
+ from diffusers.models.embeddings_flax import FlaxTimestepEmbedding, FlaxTimesteps
+ from diffusers.models.modeling_flax_utils import FlaxModelMixin
+
+ from .unet_3d_blocks import (
+     FlaxCrossAttnDownBlock3D,
+     FlaxCrossAttnUpBlock3D,
+     FlaxDownBlock3D,
+     FlaxUNetMidBlock3DCrossAttn,
+     FlaxUpBlock3D,
+ )
+
+
+ @flax_register_to_config
+ class FlaxUNet3DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
+     r"""
+     A conditional 3D UNet model for video diffusion.
+
+     Parameters:
+         sample_size (`int` or `Tuple[int, int, int]`, *optional*, defaults to (16, 32, 32)):
+             The temporal and spatial size of the input sample, as (frames, height, width). A single integer is
+             used for the temporal size and both spatial sizes.
+         in_channels (`int`, *optional*, defaults to 4):
+             The number of channels in the input sample.
+         out_channels (`int`, *optional*, defaults to 4):
+             The number of channels in the output.
+         down_block_types (`Tuple[str]`, *optional*, defaults to ("CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D")):
+             The tuple of downsample blocks to use.
+         up_block_types (`Tuple[str]`, *optional*, defaults to ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D")):
+             The tuple of upsample blocks to use.
+         block_out_channels (`Tuple[int]`, *optional*, defaults to (320, 640, 1280, 1280)):
+             The tuple of output channels for each block.
+         layers_per_block (`int`, *optional*, defaults to 2):
+             The number of layers per block.
+         attention_head_dim (`int`, *optional*, defaults to 8):
+             The dimension of the attention heads.
+         cross_attention_dim (`int`, *optional*, defaults to 1280):
+             The dimension of the cross attention features.
+         dropout (`float`, *optional*, defaults to 0):
+             Dropout probability for down, up and bottleneck blocks.
+         use_linear_projection (`bool`, *optional*, defaults to False):
+             Whether to use linear projection in attention blocks.
+         dtype (`jnp.dtype`, *optional*, defaults to jnp.float32):
+             The dtype of the model weights.
+         flip_sin_to_cos (`bool`, *optional*, defaults to True):
+             Whether to flip the sin to cos in the time embedding.
+         freq_shift (`int`, *optional*, defaults to 0):
+             The frequency shift to apply to the time embedding.
+         use_memory_efficient_attention (`bool`, *optional*, defaults to False):
+             Whether to use memory-efficient attention.
+         split_head_dim (`bool`, *optional*, defaults to False):
+             Whether to split the head dimension into a new axis for the self-attention computation.
+     """
+
+     sample_size: Union[int, Tuple[int, int, int]] = (16, 32, 32)
+     in_channels: int = 4
+     out_channels: int = 4
+     down_block_types: Tuple[str, ...] = (
+         "CrossAttnDownBlock3D",
+         "CrossAttnDownBlock3D",
+         "CrossAttnDownBlock3D",
+         "DownBlock3D",
+     )
+     up_block_types: Tuple[str, ...] = ("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D")
+     block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280)
+     layers_per_block: int = 2
+     attention_head_dim: Union[int, Tuple[int, ...]] = 8
+     num_attention_heads: Optional[Union[int, Tuple[int, ...]]] = None
+     cross_attention_dim: int = 1280
+     dropout: float = 0.0
+     use_linear_projection: bool = False
+     dtype: jnp.dtype = jnp.float32
+     flip_sin_to_cos: bool = True
+     freq_shift: int = 0
+     use_memory_efficient_attention: bool = False
+     split_head_dim: bool = False
+     transformer_layers_per_block: Union[int, Tuple[int, ...]] = 1
+     addition_embed_type: Optional[str] = None
+     addition_time_embed_dim: Optional[int] = None
+
+     def init_weights(self, rng: jax.Array) -> FrozenDict:
+         # init input tensors
+         if isinstance(self.sample_size, int):
+             sample_size = (self.sample_size, self.sample_size, self.sample_size)
+         else:
+             sample_size = self.sample_size
+
+         # Shape: [batch, frames, height, width, channels]
+         sample_shape = (1, sample_size[0], sample_size[1], sample_size[2], self.in_channels)
+         sample = jnp.zeros(sample_shape, dtype=jnp.float32)
+         timesteps = jnp.ones((1,), dtype=jnp.int32)
+         encoder_hidden_states = jnp.zeros((1, 1, self.cross_attention_dim), dtype=jnp.float32)
+
+         params_rng, dropout_rng = jax.random.split(rng)
+         rngs = {"params": params_rng, "dropout": dropout_rng}
+
+         added_cond_kwargs = None
+         if self.addition_embed_type == "text_time":
+             # For text-time conditioning for video diffusion
+             text_embeds_dim = self.cross_attention_dim
+             time_ids_dims = 6  # Default value for video models
+             added_cond_kwargs = {
+                 "text_embeds": jnp.zeros((1, text_embeds_dim), dtype=jnp.float32),
+                 "time_ids": jnp.zeros((1, time_ids_dims), dtype=jnp.float32),
+             }
+
+         # Pass added_cond_kwargs by keyword so it is not consumed as frame_encoder_hidden_states
+         return self.init(rngs, sample, timesteps, encoder_hidden_states, added_cond_kwargs=added_cond_kwargs)["params"]
+
+     def setup(self) -> None:
+         block_out_channels = self.block_out_channels
+         time_embed_dim = block_out_channels[0] * 4
+
+         if self.num_attention_heads is not None:
+             raise ValueError(
+                 "At the moment it is not possible to define the number of attention heads via `num_attention_heads` because of a naming issue. "
+                 "Use `attention_head_dim` instead."
+             )
+
+         # Default behavior: if num_attention_heads is not set, use attention_head_dim
+         num_attention_heads = self.num_attention_heads or self.attention_head_dim
+
+         # input
+         self.conv_in = nn.Conv(
+             block_out_channels[0],
+             kernel_size=(3, 3, 3),
+             strides=(1, 1, 1),
+             padding=((1, 1), (1, 1), (1, 1)),
+             dtype=self.dtype,
+         )
+
+         # time
+         self.time_proj = FlaxTimesteps(
+             block_out_channels[0], flip_sin_to_cos=self.flip_sin_to_cos, freq_shift=self.freq_shift
+         )
+         self.time_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)
+
+         # Handle attention head configurations
+         if isinstance(num_attention_heads, int):
+             num_attention_heads = (num_attention_heads,) * len(self.down_block_types)
+
+         # transformer layers per block
+         transformer_layers_per_block = self.transformer_layers_per_block
+         if isinstance(transformer_layers_per_block, int):
+             transformer_layers_per_block = [transformer_layers_per_block] * len(self.down_block_types)
+
+         # addition embed types
+         if self.addition_embed_type == "text_time":
+             if self.addition_time_embed_dim is None:
+                 raise ValueError(
+                     f"addition_embed_type {self.addition_embed_type} requires `addition_time_embed_dim` to not be None"
+                 )
+             self.add_time_proj = FlaxTimesteps(self.addition_time_embed_dim, self.flip_sin_to_cos, self.freq_shift)
+             self.add_embedding = FlaxTimestepEmbedding(time_embed_dim, dtype=self.dtype)
+         else:
+             self.add_embedding = None
+
+         # down blocks
+         down_blocks = []
+         output_channel = block_out_channels[0]
+         for i, down_block_type in enumerate(self.down_block_types):
+             input_channel = output_channel
+             output_channel = block_out_channels[i]
+             is_final_block = i == len(block_out_channels) - 1
+
+             if down_block_type == "CrossAttnDownBlock3D":
+                 down_block = FlaxCrossAttnDownBlock3D(
+                     in_channels=input_channel,
+                     out_channels=output_channel,
+                     dropout=self.dropout,
+                     num_layers=self.layers_per_block,
+                     transformer_layers_per_block=transformer_layers_per_block[i],
+                     num_attention_heads=num_attention_heads[i],
+                     add_downsample=not is_final_block,
+                     use_linear_projection=self.use_linear_projection,
+                     only_cross_attention=False,  # We don't use only cross attention in 3D UNet
+                     use_memory_efficient_attention=self.use_memory_efficient_attention,
+                     split_head_dim=self.split_head_dim,
+                     dtype=self.dtype,
+                 )
+             elif down_block_type == "DownBlock3D":
+                 down_block = FlaxDownBlock3D(
+                     in_channels=input_channel,
+                     out_channels=output_channel,
+                     dropout=self.dropout,
+                     num_layers=self.layers_per_block,
+                     add_downsample=not is_final_block,
+                     dtype=self.dtype,
+                 )
+             else:
+                 raise ValueError(f"Unknown down block type: {down_block_type}")
+
+             down_blocks.append(down_block)
+         self.down_blocks = down_blocks
+
+         # mid block
+         self.mid_block = FlaxUNetMidBlock3DCrossAttn(
+             in_channels=block_out_channels[-1],
+             dropout=self.dropout,
+             num_attention_heads=num_attention_heads[-1],
+             transformer_layers_per_block=transformer_layers_per_block[-1],
+             use_linear_projection=self.use_linear_projection,
+             use_memory_efficient_attention=self.use_memory_efficient_attention,
+             split_head_dim=self.split_head_dim,
+             dtype=self.dtype,
+         )
+
+         # up blocks
+         up_blocks = []
+         reversed_block_out_channels = list(reversed(block_out_channels))
+         reversed_num_attention_heads = list(reversed(num_attention_heads))
+         output_channel = reversed_block_out_channels[0]
+         reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))
+
+         for i, up_block_type in enumerate(self.up_block_types):
+             prev_output_channel = output_channel
+             output_channel = reversed_block_out_channels[i]
+             input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+             is_final_block = i == len(block_out_channels) - 1
+
+             if up_block_type == "CrossAttnUpBlock3D":
+                 up_block = FlaxCrossAttnUpBlock3D(
+                     in_channels=input_channel,
+                     out_channels=output_channel,
+                     prev_output_channel=prev_output_channel,
+                     num_layers=self.layers_per_block + 1,
+                     transformer_layers_per_block=reversed_transformer_layers_per_block[i],
+                     num_attention_heads=reversed_num_attention_heads[i],
+                     add_upsample=not is_final_block,
+                     dropout=self.dropout,
+                     use_linear_projection=self.use_linear_projection,
+                     only_cross_attention=False,  # We don't use only cross attention in 3D UNet
+                     use_memory_efficient_attention=self.use_memory_efficient_attention,
+                     split_head_dim=self.split_head_dim,
+                     dtype=self.dtype,
+                 )
+             elif up_block_type == "UpBlock3D":
+                 up_block = FlaxUpBlock3D(
+                     in_channels=input_channel,
+                     out_channels=output_channel,
+                     prev_output_channel=prev_output_channel,
+                     num_layers=self.layers_per_block + 1,
+                     add_upsample=not is_final_block,
+                     dropout=self.dropout,
+                     dtype=self.dtype,
+                 )
+             else:
+                 raise ValueError(f"Unknown up block type: {up_block_type}")
+
+             up_blocks.append(up_block)
+             prev_output_channel = output_channel
+         self.up_blocks = up_blocks
+
+         # out
+         self.conv_norm_out = nn.GroupNorm(num_groups=32, epsilon=1e-5)
+         self.conv_out = nn.Conv(
+             self.out_channels,
+             kernel_size=(3, 3, 3),
+             strides=(1, 1, 1),
+             padding=((1, 1), (1, 1), (1, 1)),
+             dtype=self.dtype,
+         )
+
+     def __call__(
+         self,
+         sample: jnp.ndarray,
+         timesteps: Union[jnp.ndarray, float, int],
+         encoder_hidden_states: jnp.ndarray,
+         frame_encoder_hidden_states: Optional[jnp.ndarray] = None,
+         added_cond_kwargs: Optional[Union[Dict, FrozenDict]] = None,
+         down_block_additional_residuals: Optional[Tuple[jnp.ndarray, ...]] = None,
+         mid_block_additional_residual: Optional[jnp.ndarray] = None,
+         return_dict: bool = True,
+         train: bool = False,
+     ) -> jnp.ndarray:
+         r"""
+         Args:
+             sample (`jnp.ndarray`): (batch, frames, height, width, channels) noisy inputs tensor
+             timesteps (`jnp.ndarray` or `float` or `int`): timesteps
+             encoder_hidden_states (`jnp.ndarray`): (batch_size, sequence_length, hidden_size) encoder hidden states
+             frame_encoder_hidden_states (`jnp.ndarray`, *optional*):
+                 (batch_size, frames, sequence_length, hidden_size) per-frame encoder hidden states
+             added_cond_kwargs (`dict`, *optional*):
+                 Additional embeddings to add to the time embeddings
+             down_block_additional_residuals (`tuple` of `jnp.ndarray`, *optional*):
+                 Additional residual connections for down blocks
+             mid_block_additional_residual (`jnp.ndarray`, *optional*):
+                 Additional residual connection for mid block
+             return_dict (`bool`, *optional*, defaults to `True`):
+                 Whether to return a dict or tuple
+             train (`bool`, *optional*, defaults to `False`):
+                 Training mode flag for dropout
+         """
+         # Extract the number of frames from the input
+         batch, num_frames, height, width, channels = sample.shape
+
+         # 1. Time embedding
+         if not isinstance(timesteps, jnp.ndarray):
+             timesteps = jnp.array([timesteps], dtype=jnp.int32)
+         elif isinstance(timesteps, jnp.ndarray) and len(timesteps.shape) == 0:
+             timesteps = timesteps.astype(dtype=jnp.float32)
+             timesteps = jnp.expand_dims(timesteps, 0)
+
+         t_emb = self.time_proj(timesteps)
+         t_emb = self.time_embedding(t_emb)
+
+         # Repeat time embedding for each frame
+         t_emb = jnp.repeat(t_emb, repeats=num_frames, axis=0)
+
+         # additional embeddings
+         if self.add_embedding is not None and added_cond_kwargs is not None:
+             if "text_embeds" not in added_cond_kwargs:
+                 raise ValueError(
+                     "text_embeds must be provided for text_time addition_embed_type"
+                 )
+             if "time_ids" not in added_cond_kwargs:
+                 raise ValueError(
+                     "time_ids must be provided for text_time addition_embed_type"
+                 )
+
+             text_embeds = added_cond_kwargs["text_embeds"]
+             time_ids = added_cond_kwargs["time_ids"]
+
+             # Compute time embeds
+             time_embeds = self.add_time_proj(jnp.ravel(time_ids))
+             time_embeds = jnp.reshape(time_embeds, (text_embeds.shape[0], -1))
+
+             # Concatenate text and time embeds
+             add_embeds = jnp.concatenate([text_embeds, time_embeds], axis=-1)
+
+             # Project to time embedding dimension
+             aug_emb = self.add_embedding(add_embeds)
+             t_emb = t_emb + aug_emb
+
+         # 2. Pre-process input - reshape from [B, F, H, W, C] to [B*F, H, W, C] for 2D operations
+         sample = sample.reshape(batch * num_frames, height, width, channels)
+         sample = self.conv_in(sample)
+
+         # Process encoder hidden states - repeat for each frame and combine with frame-specific conditioning if provided
+         if encoder_hidden_states is not None:
+             # Repeat video-wide conditioning for each frame: (B, S, X) -> (B*F, S, X)
+             encoder_hidden_states_expanded = jnp.repeat(
+                 encoder_hidden_states, repeats=num_frames, axis=0
+             )
+
+             # If we have frame-specific conditioning, reshape and concatenate with video conditioning
+             if frame_encoder_hidden_states is not None:
+                 # Reshape from (B, F, S, X) to (B*F, S, X)
+                 frame_encoder_hidden_states = frame_encoder_hidden_states.reshape(
+                     batch * num_frames, *frame_encoder_hidden_states.shape[2:]
+                 )
+
+                 # Concatenate along the sequence dimension
+                 encoder_hidden_states_combined = jnp.concatenate(
+                     [encoder_hidden_states_expanded, frame_encoder_hidden_states],
+                     axis=1
+                 )
+             else:
+                 encoder_hidden_states_combined = encoder_hidden_states_expanded
+         else:
+             encoder_hidden_states_combined = None
+
+         # 3. Down blocks
+         down_block_res_samples = (sample,)
+         for down_block in self.down_blocks:
+             if isinstance(down_block, FlaxCrossAttnDownBlock3D):
+                 sample, res_samples = down_block(
+                     sample,
+                     t_emb,
+                     encoder_hidden_states_combined,
+                     num_frames=num_frames,
+                     deterministic=not train
+                 )
+             else:
+                 sample, res_samples = down_block(
+                     sample,
+                     t_emb,
+                     num_frames=num_frames,
+                     deterministic=not train
+                 )
+             down_block_res_samples += res_samples
+
+         # Add additional residuals if provided
+         if down_block_additional_residuals is not None:
+             new_down_block_res_samples = ()
+
+             for down_block_res_sample, down_block_additional_residual in zip(
+                 down_block_res_samples, down_block_additional_residuals
+             ):
+                 down_block_res_sample += down_block_additional_residual
+                 new_down_block_res_samples += (down_block_res_sample,)
+
+             down_block_res_samples = new_down_block_res_samples
+
+         # 4. Mid block
+         sample = self.mid_block(
+             sample,
+             t_emb,
+             encoder_hidden_states_combined,
+             num_frames=num_frames,
+             deterministic=not train
+         )
+
+         # Add mid block residual if provided
+         if mid_block_additional_residual is not None:
+             sample += mid_block_additional_residual
+
+         # 5. Up blocks
+         for up_block in self.up_blocks:
+             res_samples = down_block_res_samples[-(self.layers_per_block + 1) :]
+             down_block_res_samples = down_block_res_samples[: -(self.layers_per_block + 1)]
+             if isinstance(up_block, FlaxCrossAttnUpBlock3D):
+                 sample = up_block(
+                     sample,
+                     res_hidden_states_tuple=res_samples,
+                     temb=t_emb,
+                     encoder_hidden_states=encoder_hidden_states_combined,
+                     num_frames=num_frames,
+                     deterministic=not train,
+                 )
+             else:
+                 sample = up_block(
+                     sample,
+                     res_hidden_states_tuple=res_samples,
+                     temb=t_emb,
+                     num_frames=num_frames,
+                     deterministic=not train
+                 )
+
+         # 6. Post-process
+         sample = self.conv_norm_out(sample)
+         sample = nn.silu(sample)
+         sample = self.conv_out(sample)
+
+         # Reshape back to [B, F, H, W, C]
+         sample = sample.reshape(batch, num_frames, height, width, self.out_channels)
+         return sample
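
For orientation, the sketch below shows how this new FlaxUNet3DConditionModel might be instantiated and invoked. It is a minimal, hedged example based only on the module attributes, init_weights, and __call__ shown in the diff above; the batch size, frame count, resolution, token length, and timestep value are illustrative assumptions and are not taken from the released package.

import jax
import jax.numpy as jnp

from flaxdiff.models.unet_3d import FlaxUNet3DConditionModel

# Illustrative configuration: 16 frames at 32x32 latent resolution with 4 latent channels (module defaults).
model = FlaxUNet3DConditionModel(sample_size=(16, 32, 32), in_channels=4, out_channels=4)

# init_weights builds dummy inputs internally and returns the parameter tree.
params = model.init_weights(jax.random.PRNGKey(0))

# Inputs follow the [batch, frames, height, width, channels] layout used by the model.
sample = jnp.zeros((1, 16, 32, 32, 4), dtype=jnp.float32)
timesteps = jnp.array([10], dtype=jnp.int32)
# (batch, sequence_length, cross_attention_dim) conditioning; 77 tokens is an assumed sequence length.
encoder_hidden_states = jnp.zeros((1, 77, 1280), dtype=jnp.float32)

# train=False by default, so dropout is deterministic and no dropout RNG is required.
noise_pred = model.apply({"params": params}, sample, timesteps, encoder_hidden_states)
print(noise_pred.shape)  # expected: (1, 16, 32, 32, 4)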