diffusers 0.23.0__py3-none-any.whl → 0.24.0__py3-none-any.whl

Files changed (177)
  1. diffusers/__init__.py +16 -2
  2. diffusers/configuration_utils.py +1 -0
  3. diffusers/dependency_versions_check.py +1 -14
  4. diffusers/dependency_versions_table.py +5 -4
  5. diffusers/image_processor.py +186 -14
  6. diffusers/loaders/__init__.py +82 -0
  7. diffusers/loaders/ip_adapter.py +157 -0
  8. diffusers/loaders/lora.py +1415 -0
  9. diffusers/loaders/lora_conversion_utils.py +284 -0
  10. diffusers/loaders/single_file.py +631 -0
  11. diffusers/loaders/textual_inversion.py +459 -0
  12. diffusers/loaders/unet.py +735 -0
  13. diffusers/loaders/utils.py +59 -0
  14. diffusers/models/__init__.py +12 -1
  15. diffusers/models/attention.py +165 -14
  16. diffusers/models/attention_flax.py +9 -1
  17. diffusers/models/attention_processor.py +286 -1
  18. diffusers/models/autoencoder_asym_kl.py +14 -9
  19. diffusers/models/autoencoder_kl.py +3 -18
  20. diffusers/models/autoencoder_kl_temporal_decoder.py +402 -0
  21. diffusers/models/autoencoder_tiny.py +20 -24
  22. diffusers/models/consistency_decoder_vae.py +37 -30
  23. diffusers/models/controlnet.py +59 -39
  24. diffusers/models/controlnet_flax.py +19 -18
  25. diffusers/models/embeddings_flax.py +2 -0
  26. diffusers/models/lora.py +131 -1
  27. diffusers/models/modeling_flax_utils.py +2 -1
  28. diffusers/models/modeling_outputs.py +17 -0
  29. diffusers/models/modeling_utils.py +27 -19
  30. diffusers/models/normalization.py +2 -2
  31. diffusers/models/resnet.py +390 -59
  32. diffusers/models/transformer_2d.py +20 -3
  33. diffusers/models/transformer_temporal.py +183 -1
  34. diffusers/models/unet_2d_blocks_flax.py +5 -0
  35. diffusers/models/unet_2d_condition.py +9 -0
  36. diffusers/models/unet_2d_condition_flax.py +13 -13
  37. diffusers/models/unet_3d_blocks.py +957 -173
  38. diffusers/models/unet_3d_condition.py +16 -8
  39. diffusers/models/unet_kandi3.py +589 -0
  40. diffusers/models/unet_motion_model.py +48 -33
  41. diffusers/models/unet_spatio_temporal_condition.py +489 -0
  42. diffusers/models/vae.py +63 -13
  43. diffusers/models/vae_flax.py +7 -0
  44. diffusers/models/vq_model.py +3 -1
  45. diffusers/optimization.py +16 -9
  46. diffusers/pipelines/__init__.py +65 -12
  47. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +93 -23
  48. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +97 -25
  49. diffusers/pipelines/animatediff/pipeline_animatediff.py +34 -4
  50. diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -0
  51. diffusers/pipelines/auto_pipeline.py +6 -0
  52. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -0
  53. diffusers/pipelines/controlnet/pipeline_controlnet.py +217 -31
  54. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +101 -32
  55. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +136 -39
  56. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +119 -37
  57. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +196 -35
  58. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +102 -31
  59. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -0
  60. diffusers/pipelines/ddim/pipeline_ddim.py +1 -0
  61. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -0
  62. diffusers/pipelines/deepfloyd_if/pipeline_if.py +13 -1
  63. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +13 -1
  64. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +13 -1
  65. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +13 -1
  66. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +13 -1
  67. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +13 -1
  68. diffusers/pipelines/dit/pipeline_dit.py +1 -0
  69. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
  70. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +3 -3
  71. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
  72. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +1 -1
  73. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +1 -1
  74. diffusers/pipelines/kandinsky3/__init__.py +49 -0
  75. diffusers/pipelines/kandinsky3/kandinsky3_pipeline.py +452 -0
  76. diffusers/pipelines/kandinsky3/kandinsky3img2img_pipeline.py +460 -0
  77. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +65 -6
  78. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +55 -3
  79. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -0
  80. diffusers/pipelines/musicldm/pipeline_musicldm.py +1 -1
  81. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +7 -2
  82. diffusers/pipelines/pipeline_flax_utils.py +4 -2
  83. diffusers/pipelines/pipeline_utils.py +33 -13
  84. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +196 -36
  85. diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py +1 -0
  86. diffusers/pipelines/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -0
  87. diffusers/pipelines/stable_diffusion/__init__.py +64 -21
  88. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +8 -3
  89. diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py +18 -2
  90. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +2 -2
  91. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +2 -4
  92. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -0
  93. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py +1 -0
  94. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +88 -9
  95. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +1 -0
  96. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +8 -3
  97. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +1 -0
  98. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py +1 -0
  99. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen_text_image.py +1 -0
  100. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +1 -0
  101. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +92 -9
  102. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +92 -9
  103. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +1 -0
  104. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -13
  105. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -0
  106. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +1 -0
  107. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +1 -0
  108. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py +1 -0
  109. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +1 -0
  110. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +1 -0
  111. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +1 -0
  112. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py +1 -0
  113. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +1 -0
  114. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +103 -8
  115. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +113 -8
  116. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +115 -9
  117. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -12
  118. diffusers/pipelines/stable_video_diffusion/__init__.py +58 -0
  119. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +649 -0
  120. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +108 -12
  121. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +109 -14
  122. diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -0
  123. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +1 -0
  124. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +18 -3
  125. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +4 -2
  126. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +872 -0
  127. diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +29 -40
  128. diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -0
  129. diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -0
  130. diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -0
  131. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +14 -4
  132. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +9 -5
  133. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +1 -1
  134. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +2 -2
  135. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +1 -1
  136. diffusers/schedulers/__init__.py +2 -4
  137. diffusers/schedulers/deprecated/__init__.py +50 -0
  138. diffusers/schedulers/{scheduling_karras_ve.py → deprecated/scheduling_karras_ve.py} +4 -4
  139. diffusers/schedulers/{scheduling_sde_vp.py → deprecated/scheduling_sde_vp.py} +4 -6
  140. diffusers/schedulers/scheduling_ddim.py +1 -3
  141. diffusers/schedulers/scheduling_ddim_inverse.py +1 -3
  142. diffusers/schedulers/scheduling_ddim_parallel.py +1 -3
  143. diffusers/schedulers/scheduling_ddpm.py +1 -3
  144. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -3
  145. diffusers/schedulers/scheduling_deis_multistep.py +15 -5
  146. diffusers/schedulers/scheduling_dpmsolver_multistep.py +15 -5
  147. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +15 -5
  148. diffusers/schedulers/scheduling_dpmsolver_sde.py +1 -3
  149. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +15 -5
  150. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +1 -3
  151. diffusers/schedulers/scheduling_euler_discrete.py +40 -13
  152. diffusers/schedulers/scheduling_heun_discrete.py +15 -5
  153. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +15 -5
  154. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +15 -5
  155. diffusers/schedulers/scheduling_lcm.py +123 -29
  156. diffusers/schedulers/scheduling_lms_discrete.py +1 -3
  157. diffusers/schedulers/scheduling_pndm.py +1 -3
  158. diffusers/schedulers/scheduling_repaint.py +1 -3
  159. diffusers/schedulers/scheduling_unipc_multistep.py +15 -5
  160. diffusers/utils/__init__.py +1 -0
  161. diffusers/utils/constants.py +11 -6
  162. diffusers/utils/dummy_pt_objects.py +45 -0
  163. diffusers/utils/dummy_torch_and_transformers_objects.py +60 -0
  164. diffusers/utils/dynamic_modules_utils.py +4 -4
  165. diffusers/utils/export_utils.py +8 -3
  166. diffusers/utils/logging.py +10 -10
  167. diffusers/utils/outputs.py +5 -5
  168. diffusers/utils/peft_utils.py +88 -44
  169. diffusers/utils/torch_utils.py +2 -2
  170. diffusers/utils/versions.py +117 -0
  171. {diffusers-0.23.0.dist-info → diffusers-0.24.0.dist-info}/METADATA +83 -64
  172. {diffusers-0.23.0.dist-info → diffusers-0.24.0.dist-info}/RECORD +176 -157
  173. {diffusers-0.23.0.dist-info → diffusers-0.24.0.dist-info}/WHEEL +1 -1
  174. {diffusers-0.23.0.dist-info → diffusers-0.24.0.dist-info}/entry_points.txt +1 -0
  175. diffusers/loaders.py +0 -3336
  176. {diffusers-0.23.0.dist-info → diffusers-0.24.0.dist-info}/LICENSE +0 -0
  177. {diffusers-0.23.0.dist-info → diffusers-0.24.0.dist-info}/top_level.txt +0 -0
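
The headline feature in this release is the new Stable Video Diffusion pipeline (files 118–119 above). A minimal image-to-video sketch, assuming the public `stabilityai/stable-video-diffusion-img2vid-xt` checkpoint, a CUDA GPU, and a local conditioning image (hypothetical filename):

import torch
from diffusers import StableVideoDiffusionPipeline
from diffusers.utils import load_image, export_to_video

pipe = StableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16"
)
pipe.enable_model_cpu_offload()  # lower peak VRAM at some speed cost

image = load_image("conditioning_frame.png").resize((1024, 576))  # placeholder input image
frames = pipe(image, decode_chunk_size=8).frames[0]
export_to_video(frames, "generated.mp4", fps=7)

The diff hunks below are excerpted from `diffusers/models/resnet.py` (file 31, +390 −59), which gains the temporal building blocks this pipeline relies on.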
@@ -164,7 +164,12 @@ class Upsample2D(nn.Module):
         else:
             self.Conv2d_0 = conv
 
-    def forward(self, hidden_states: torch.Tensor, output_size: Optional[int] = None, scale: float = 1.0):
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        output_size: Optional[int] = None,
+        scale: float = 1.0,
+    ) -> torch.FloatTensor:
         assert hidden_states.shape[1] == self.channels
 
         if self.use_conv_transpose:
@@ -256,7 +261,7 @@ class Downsample2D(nn.Module):
         else:
             self.conv = conv
 
-    def forward(self, hidden_states, scale: float = 1.0):
+    def forward(self, hidden_states: torch.FloatTensor, scale: float = 1.0) -> torch.FloatTensor:
         assert hidden_states.shape[1] == self.channels
 
         if self.use_conv and self.padding == 0:
@@ -280,7 +285,7 @@ class FirUpsample2D(nn.Module):
     """A 2D FIR upsampling layer with an optional convolution.
 
     Parameters:
-        channels (`int`):
+        channels (`int`, optional):
            number of channels in the inputs and outputs.
         use_conv (`bool`, default `False`):
            option to use a convolution.
@@ -292,7 +297,7 @@ class FirUpsample2D(nn.Module):
 
     def __init__(
         self,
-        channels: int = None,
+        channels: Optional[int] = None,
         out_channels: Optional[int] = None,
         use_conv: bool = False,
         fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1),
@@ -307,12 +312,12 @@ class FirUpsample2D(nn.Module):
 
     def _upsample_2d(
         self,
-        hidden_states: torch.Tensor,
-        weight: Optional[torch.Tensor] = None,
+        hidden_states: torch.FloatTensor,
+        weight: Optional[torch.FloatTensor] = None,
         kernel: Optional[torch.FloatTensor] = None,
         factor: int = 2,
         gain: float = 1,
-    ) -> torch.Tensor:
+    ) -> torch.FloatTensor:
         """Fused `upsample_2d()` followed by `Conv2d()`.
 
         Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
@@ -320,17 +325,21 @@ class FirUpsample2D(nn.Module):
         arbitrary order.
 
         Args:
-            hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
-            weight: Weight tensor of the shape `[filterH, filterW, inChannels,
-                outChannels]`. Grouped convolution can be performed by `inChannels = x.shape[0] // numGroups`.
-            kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
-                (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
-            factor: Integer upsampling factor (default: 2).
-            gain: Scaling factor for signal magnitude (default: 1.0).
+            hidden_states (`torch.FloatTensor`):
+                Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+            weight (`torch.FloatTensor`, *optional*):
+                Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
+                performed by `inChannels = x.shape[0] // numGroups`.
+            kernel (`torch.FloatTensor`, *optional*):
+                FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
+                corresponds to nearest-neighbor upsampling.
+            factor (`int`, *optional*): Integer upsampling factor (default: 2).
+            gain (`float`, *optional*): Scaling factor for signal magnitude (default: 1.0).
 
         Returns:
-            output: Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
-                datatype as `hidden_states`.
+            output (`torch.FloatTensor`):
+                Tensor of the shape `[N, C, H * factor, W * factor]` or `[N, H * factor, W * factor, C]`, and same
+                datatype as `hidden_states`.
         """
 
         assert isinstance(factor, int) and factor >= 1
@@ -373,7 +382,11 @@ class FirUpsample2D(nn.Module):
             weight = torch.reshape(weight, (num_groups * inC, -1, convH, convW))
 
             inverse_conv = F.conv_transpose2d(
-                hidden_states, weight, stride=stride, output_padding=output_padding, padding=0
+                hidden_states,
+                weight,
+                stride=stride,
+                output_padding=output_padding,
+                padding=0,
             )
 
             output = upfirdn2d_native(
@@ -392,7 +405,7 @@ class FirUpsample2D(nn.Module):
 
         return output
 
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
         if self.use_conv:
             height = self._upsample_2d(hidden_states, self.Conv2d_0.weight, kernel=self.fir_kernel)
             height = height + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
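
As a quick check of the annotated forward path, a minimal sketch (assuming the 0.24.0 import path `diffusers.models.resnet`) of the filter-only branch, which doubles spatial resolution with the default FIR kernel:

import torch
from diffusers.models.resnet import FirUpsample2D

up = FirUpsample2D(channels=8, use_conv=False)  # fir_kernel defaults to (1, 3, 3, 1)
x = torch.randn(1, 8, 16, 16)
print(up(x).shape)  # torch.Size([1, 8, 32, 32])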
@@ -418,7 +431,7 @@ class FirDownsample2D(nn.Module):
 
     def __init__(
         self,
-        channels: int = None,
+        channels: Optional[int] = None,
         out_channels: Optional[int] = None,
         use_conv: bool = False,
         fir_kernel: Tuple[int, int, int, int] = (1, 3, 3, 1),
@@ -433,30 +446,35 @@ class FirDownsample2D(nn.Module):
 
     def _downsample_2d(
         self,
-        hidden_states: torch.Tensor,
-        weight: Optional[torch.Tensor] = None,
+        hidden_states: torch.FloatTensor,
+        weight: Optional[torch.FloatTensor] = None,
         kernel: Optional[torch.FloatTensor] = None,
         factor: int = 2,
         gain: float = 1,
-    ) -> torch.Tensor:
+    ) -> torch.FloatTensor:
         """Fused `Conv2d()` followed by `downsample_2d()`.
         Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
         efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
         arbitrary order.
 
         Args:
-            hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
-            weight:
+            hidden_states (`torch.FloatTensor`):
+                Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+            weight (`torch.FloatTensor`, *optional*):
                 Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
                 performed by `inChannels = x.shape[0] // numGroups`.
-            kernel: FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] *
-                factor`, which corresponds to average pooling.
-            factor: Integer downsampling factor (default: 2).
-            gain: Scaling factor for signal magnitude (default: 1.0).
+            kernel (`torch.FloatTensor`, *optional*):
+                FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
+                corresponds to average pooling.
+            factor (`int`, *optional*, default to `2`):
+                Integer downsampling factor.
+            gain (`float`, *optional*, default to `1.0`):
+                Scaling factor for signal magnitude.
 
         Returns:
-            output: Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and
-                same datatype as `x`.
+            output (`torch.FloatTensor`):
+                Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same
+                datatype as `x`.
         """
 
         assert isinstance(factor, int) and factor >= 1
@@ -492,7 +510,7 @@ class FirDownsample2D(nn.Module):
 
         return output
 
-    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
         if self.use_conv:
             downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
             hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
@@ -519,7 +537,14 @@ class KDownsample2D(nn.Module):
 
     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         inputs = F.pad(inputs, (self.pad,) * 4, self.pad_mode)
-        weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
+        weight = inputs.new_zeros(
+            [
+                inputs.shape[1],
+                inputs.shape[1],
+                self.kernel.shape[0],
+                self.kernel.shape[1],
+            ]
+        )
         indices = torch.arange(inputs.shape[1], device=inputs.device)
         kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
         weight[indices, indices] = kernel
@@ -542,7 +567,14 @@ class KUpsample2D(nn.Module):
 
     def forward(self, inputs: torch.Tensor) -> torch.Tensor:
         inputs = F.pad(inputs, ((self.pad + 1) // 2,) * 4, self.pad_mode)
-        weight = inputs.new_zeros([inputs.shape[1], inputs.shape[1], self.kernel.shape[0], self.kernel.shape[1]])
+        weight = inputs.new_zeros(
+            [
+                inputs.shape[1],
+                inputs.shape[1],
+                self.kernel.shape[0],
+                self.kernel.shape[1],
+            ]
+        )
         indices = torch.arange(inputs.shape[1], device=inputs.device)
         kernel = self.kernel.to(weight)[None, :].expand(inputs.shape[1], -1, -1)
         weight[indices, indices] = kernel
@@ -679,10 +711,20 @@ class ResnetBlock2D(nn.Module):
         self.conv_shortcut = None
         if self.use_in_shortcut:
             self.conv_shortcut = conv_cls(
-                in_channels, conv_2d_out_channels, kernel_size=1, stride=1, padding=0, bias=conv_shortcut_bias
+                in_channels,
+                conv_2d_out_channels,
+                kernel_size=1,
+                stride=1,
+                padding=0,
+                bias=conv_shortcut_bias,
             )
 
-    def forward(self, input_tensor, temb, scale: float = 1.0):
+    def forward(
+        self,
+        input_tensor: torch.FloatTensor,
+        temb: torch.FloatTensor,
+        scale: float = 1.0,
+    ) -> torch.FloatTensor:
         hidden_states = input_tensor
 
         if self.time_embedding_norm == "ada_group" or self.time_embedding_norm == "spatial":
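
The re-annotated `ResnetBlock2D.forward` still takes a feature map plus a timestep embedding; a minimal sketch (channel count kept divisible by the default 32 GroupNorm groups):

import torch
from diffusers.models.resnet import ResnetBlock2D

block = ResnetBlock2D(in_channels=64, out_channels=64, temb_channels=512)
x = torch.randn(2, 64, 32, 32)
temb = torch.randn(2, 512)
print(block(x, temb).shape)  # torch.Size([2, 64, 32, 32])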
@@ -778,7 +820,7 @@ class Conv1dBlock(nn.Module):
         out_channels (`int`): Number of output channels.
         kernel_size (`int` or `tuple`): Size of the convolving kernel.
         n_groups (`int`, default `8`): Number of groups to separate the channels into.
-        activation (`str`, defaults `mish`): Name of the activation function.
+        activation (`str`, defaults to `mish`): Name of the activation function.
     """
 
     def __init__(
@@ -853,8 +895,11 @@ class ResidualTemporalBlock1D(nn.Module):
 
 
 def upsample_2d(
-    hidden_states: torch.Tensor, kernel: Optional[torch.FloatTensor] = None, factor: int = 2, gain: float = 1
-) -> torch.Tensor:
+    hidden_states: torch.FloatTensor,
+    kernel: Optional[torch.FloatTensor] = None,
+    factor: int = 2,
+    gain: float = 1,
+) -> torch.FloatTensor:
     r"""Upsample2D a batch of 2D images with the given filter.
     Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and upsamples each image with the given
     filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the specified
@@ -862,14 +907,19 @@ def upsample_2d(
    a: multiple of the upsampling factor.
 
    Args:
-        hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
-        kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
-            (separable). The default is `[1] * factor`, which corresponds to nearest-neighbor upsampling.
-        factor: Integer upsampling factor (default: 2).
-        gain: Scaling factor for signal magnitude (default: 1.0).
+        hidden_states (`torch.FloatTensor`):
+            Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+        kernel (`torch.FloatTensor`, *optional*):
+            FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
+            corresponds to nearest-neighbor upsampling.
+        factor (`int`, *optional*, default to `2`):
+            Integer upsampling factor.
+        gain (`float`, *optional*, default to `1.0`):
+            Scaling factor for signal magnitude (default: 1.0).
 
    Returns:
-        output: Tensor of the shape `[N, C, H * factor, W * factor]`
+        output (`torch.FloatTensor`):
+            Tensor of the shape `[N, C, H * factor, W * factor]`
    """
    assert isinstance(factor, int) and factor >= 1
    if kernel is None:
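
Per the reworked docstring, calling `upsample_2d` without a kernel amounts to nearest-neighbor upsampling; a short sketch:

import torch
from diffusers.models.resnet import upsample_2d

x = torch.randn(1, 3, 8, 8)
print(upsample_2d(x, factor=2).shape)  # torch.Size([1, 3, 16, 16])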
@@ -892,8 +942,11 @@ def upsample_2d(
 
 
 def downsample_2d(
-    hidden_states: torch.Tensor, kernel: Optional[torch.FloatTensor] = None, factor: int = 2, gain: float = 1
-) -> torch.Tensor:
+    hidden_states: torch.FloatTensor,
+    kernel: Optional[torch.FloatTensor] = None,
+    factor: int = 2,
+    gain: float = 1,
+) -> torch.FloatTensor:
     r"""Downsample2D a batch of 2D images with the given filter.
     Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
     given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
@@ -901,14 +954,19 @@ def downsample_2d(
    shape is a multiple of the downsampling factor.
 
    Args:
-        hidden_states: Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
-        kernel: FIR filter of the shape `[firH, firW]` or `[firN]`
-            (separable). The default is `[1] * factor`, which corresponds to average pooling.
-        factor: Integer downsampling factor (default: 2).
-        gain: Scaling factor for signal magnitude (default: 1.0).
+        hidden_states (`torch.FloatTensor`)
+            Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
+        kernel (`torch.FloatTensor`, *optional*):
+            FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
+            corresponds to average pooling.
+        factor (`int`, *optional*, default to `2`):
+            Integer downsampling factor.
+        gain (`float`, *optional*, default to `1.0`):
+            Scaling factor for signal magnitude.
 
    Returns:
-        output: Tensor of the shape `[N, C, H // factor, W // factor]`
+        output (`torch.FloatTensor`):
+            Tensor of the shape `[N, C, H // factor, W // factor]`
    """
 
    assert isinstance(factor, int) and factor >= 1
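
And the mirror-image `downsample_2d`, whose default kernel corresponds to average pooling:

import torch
from diffusers.models.resnet import downsample_2d

x = torch.randn(1, 3, 16, 16)
print(downsample_2d(x, factor=2).shape)  # torch.Size([1, 3, 8, 8])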
@@ -923,13 +981,20 @@ def downsample_2d(
    kernel = kernel * gain
    pad_value = kernel.shape[0] - factor
    output = upfirdn2d_native(
-        hidden_states, kernel.to(device=hidden_states.device), down=factor, pad=((pad_value + 1) // 2, pad_value // 2)
+        hidden_states,
+        kernel.to(device=hidden_states.device),
+        down=factor,
+        pad=((pad_value + 1) // 2, pad_value // 2),
    )
    return output
 
 
 def upfirdn2d_native(
-    tensor: torch.Tensor, kernel: torch.Tensor, up: int = 1, down: int = 1, pad: Tuple[int, int] = (0, 0)
+    tensor: torch.Tensor,
+    kernel: torch.Tensor,
+    up: int = 1,
+    down: int = 1,
+    pad: Tuple[int, int] = (0, 0),
 ) -> torch.Tensor:
    up_x = up_y = up
    down_x = down_y = down
@@ -985,7 +1050,13 @@ class TemporalConvLayer(nn.Module):
         dropout (`float`, *optional*, defaults to `0.0`): The dropout probability to use.
     """
 
-    def __init__(self, in_dim: int, out_dim: Optional[int] = None, dropout: float = 0.0):
+    def __init__(
+        self,
+        in_dim: int,
+        out_dim: Optional[int] = None,
+        dropout: float = 0.0,
+        norm_num_groups: int = 32,
+    ):
         super().__init__()
         out_dim = out_dim or in_dim
         self.in_dim = in_dim
@@ -993,22 +1064,24 @@ class TemporalConvLayer(nn.Module):
 
         # conv layers
         self.conv1 = nn.Sequential(
-            nn.GroupNorm(32, in_dim), nn.SiLU(), nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0))
+            nn.GroupNorm(norm_num_groups, in_dim),
+            nn.SiLU(),
+            nn.Conv3d(in_dim, out_dim, (3, 1, 1), padding=(1, 0, 0)),
         )
         self.conv2 = nn.Sequential(
-            nn.GroupNorm(32, out_dim),
+            nn.GroupNorm(norm_num_groups, out_dim),
             nn.SiLU(),
             nn.Dropout(dropout),
             nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
         )
         self.conv3 = nn.Sequential(
-            nn.GroupNorm(32, out_dim),
+            nn.GroupNorm(norm_num_groups, out_dim),
             nn.SiLU(),
             nn.Dropout(dropout),
             nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
         )
         self.conv4 = nn.Sequential(
-            nn.GroupNorm(32, out_dim),
+            nn.GroupNorm(norm_num_groups, out_dim),
             nn.SiLU(),
             nn.Dropout(dropout),
             nn.Conv3d(out_dim, in_dim, (3, 1, 1), padding=(1, 0, 0)),
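
With the new `norm_num_groups` argument, `TemporalConvLayer` no longer hard-codes 32 GroupNorm groups. A sketch, assuming the 0.24.0 forward signature `forward(hidden_states, num_frames=1)` operating on a flattened (batch * frames) layout:

import torch
from diffusers.models.resnet import TemporalConvLayer

layer = TemporalConvLayer(in_dim=64, dropout=0.1, norm_num_groups=16)  # 16 groups must divide 64 channels
x = torch.randn(2 * 8, 64, 32, 32)  # (batch * frames, channels, height, width)
print(layer(x, num_frames=8).shape)  # torch.Size([16, 64, 32, 32])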
@@ -1035,3 +1108,261 @@ class TemporalConvLayer(nn.Module):
             (hidden_states.shape[0] * hidden_states.shape[2], -1) + hidden_states.shape[3:]
         )
         return hidden_states
+
+
+class TemporalResnetBlock(nn.Module):
+    r"""
+    A Resnet block.
+
+    Parameters:
+        in_channels (`int`): The number of channels in the input.
+        out_channels (`int`, *optional*, default to be `None`):
+            The number of output channels for the first conv2d layer. If None, same as `in_channels`.
+        temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
+        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: Optional[int] = None,
+        temb_channels: int = 512,
+        eps: float = 1e-6,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        out_channels = in_channels if out_channels is None else out_channels
+        self.out_channels = out_channels
+
+        kernel_size = (3, 1, 1)
+        padding = [k // 2 for k in kernel_size]
+
+        self.norm1 = torch.nn.GroupNorm(num_groups=32, num_channels=in_channels, eps=eps, affine=True)
+        self.conv1 = nn.Conv3d(
+            in_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            stride=1,
+            padding=padding,
+        )
+
+        if temb_channels is not None:
+            self.time_emb_proj = nn.Linear(temb_channels, out_channels)
+        else:
+            self.time_emb_proj = None
+
+        self.norm2 = torch.nn.GroupNorm(num_groups=32, num_channels=out_channels, eps=eps, affine=True)
+
+        self.dropout = torch.nn.Dropout(0.0)
+        self.conv2 = nn.Conv3d(
+            out_channels,
+            out_channels,
+            kernel_size=kernel_size,
+            stride=1,
+            padding=padding,
+        )
+
+        self.nonlinearity = get_activation("silu")
+
+        self.use_in_shortcut = self.in_channels != out_channels
+
+        self.conv_shortcut = None
+        if self.use_in_shortcut:
+            self.conv_shortcut = nn.Conv3d(
+                in_channels,
+                out_channels,
+                kernel_size=1,
+                stride=1,
+                padding=0,
+            )
+
+    def forward(self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor:
+        hidden_states = input_tensor
+
+        hidden_states = self.norm1(hidden_states)
+        hidden_states = self.nonlinearity(hidden_states)
+        hidden_states = self.conv1(hidden_states)
+
+        if self.time_emb_proj is not None:
+            temb = self.nonlinearity(temb)
+            temb = self.time_emb_proj(temb)[:, :, :, None, None]
+            temb = temb.permute(0, 2, 1, 3, 4)
+            hidden_states = hidden_states + temb
+
+        hidden_states = self.norm2(hidden_states)
+        hidden_states = self.nonlinearity(hidden_states)
+        hidden_states = self.dropout(hidden_states)
+        hidden_states = self.conv2(hidden_states)
+
+        if self.conv_shortcut is not None:
+            input_tensor = self.conv_shortcut(input_tensor)
+
+        output_tensor = input_tensor + hidden_states
+
+        return output_tensor
+
+
+# VideoResBlock
+class SpatioTemporalResBlock(nn.Module):
+    r"""
+    A SpatioTemporal Resnet block.
+
+    Parameters:
+        in_channels (`int`): The number of channels in the input.
+        out_channels (`int`, *optional*, default to be `None`):
+            The number of output channels for the first conv2d layer. If None, same as `in_channels`.
+        temb_channels (`int`, *optional*, default to `512`): the number of channels in timestep embedding.
+        eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the spatial resenet.
+        temporal_eps (`float`, *optional*, defaults to `eps`): The epsilon to use for the temporal resnet.
+        merge_factor (`float`, *optional*, defaults to `0.5`): The merge factor to use for the temporal mixing.
+        merge_strategy (`str`, *optional*, defaults to `learned_with_images`):
+            The merge strategy to use for the temporal mixing.
+        switch_spatial_to_temporal_mix (`bool`, *optional*, defaults to `False`):
+            If `True`, switch the spatial and temporal mixing.
+    """
+
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: Optional[int] = None,
+        temb_channels: int = 512,
+        eps: float = 1e-6,
+        temporal_eps: Optional[float] = None,
+        merge_factor: float = 0.5,
+        merge_strategy="learned_with_images",
+        switch_spatial_to_temporal_mix: bool = False,
+    ):
+        super().__init__()
+
+        self.spatial_res_block = ResnetBlock2D(
+            in_channels=in_channels,
+            out_channels=out_channels,
+            temb_channels=temb_channels,
+            eps=eps,
+        )
+
+        self.temporal_res_block = TemporalResnetBlock(
+            in_channels=out_channels if out_channels is not None else in_channels,
+            out_channels=out_channels if out_channels is not None else in_channels,
+            temb_channels=temb_channels,
+            eps=temporal_eps if temporal_eps is not None else eps,
+        )
+
+        self.time_mixer = AlphaBlender(
+            alpha=merge_factor,
+            merge_strategy=merge_strategy,
+            switch_spatial_to_temporal_mix=switch_spatial_to_temporal_mix,
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        temb: Optional[torch.FloatTensor] = None,
+        image_only_indicator: Optional[torch.Tensor] = None,
+    ):
+        num_frames = image_only_indicator.shape[-1]
+        hidden_states = self.spatial_res_block(hidden_states, temb)
+
+        batch_frames, channels, height, width = hidden_states.shape
+        batch_size = batch_frames // num_frames
+
+        hidden_states_mix = (
+            hidden_states[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
+        )
+        hidden_states = (
+            hidden_states[None, :].reshape(batch_size, num_frames, channels, height, width).permute(0, 2, 1, 3, 4)
+        )
+
+        if temb is not None:
+            temb = temb.reshape(batch_size, num_frames, -1)
+
+        hidden_states = self.temporal_res_block(hidden_states, temb)
+        hidden_states = self.time_mixer(
+            x_spatial=hidden_states_mix,
+            x_temporal=hidden_states,
+            image_only_indicator=image_only_indicator,
+        )
+
+        hidden_states = hidden_states.permute(0, 2, 1, 3, 4).reshape(batch_frames, channels, height, width)
+        return hidden_states
+
+
+class AlphaBlender(nn.Module):
+    r"""
+    A module to blend spatial and temporal features.
+
+    Parameters:
+        alpha (`float`): The initial value of the blending factor.
+        merge_strategy (`str`, *optional*, defaults to `learned_with_images`):
+            The merge strategy to use for the temporal mixing.
+        switch_spatial_to_temporal_mix (`bool`, *optional*, defaults to `False`):
+            If `True`, switch the spatial and temporal mixing.
+    """
+
+    strategies = ["learned", "fixed", "learned_with_images"]
+
+    def __init__(
+        self,
+        alpha: float,
+        merge_strategy: str = "learned_with_images",
+        switch_spatial_to_temporal_mix: bool = False,
+    ):
+        super().__init__()
+        self.merge_strategy = merge_strategy
+        self.switch_spatial_to_temporal_mix = switch_spatial_to_temporal_mix  # For TemporalVAE
+
+        if merge_strategy not in self.strategies:
+            raise ValueError(f"merge_strategy needs to be in {self.strategies}")
+
+        if self.merge_strategy == "fixed":
+            self.register_buffer("mix_factor", torch.Tensor([alpha]))
+        elif self.merge_strategy == "learned" or self.merge_strategy == "learned_with_images":
+            self.register_parameter("mix_factor", torch.nn.Parameter(torch.Tensor([alpha])))
+        else:
+            raise ValueError(f"Unknown merge strategy {self.merge_strategy}")
+
+    def get_alpha(self, image_only_indicator: torch.Tensor, ndims: int) -> torch.Tensor:
+        if self.merge_strategy == "fixed":
+            alpha = self.mix_factor
+
+        elif self.merge_strategy == "learned":
+            alpha = torch.sigmoid(self.mix_factor)
+
+        elif self.merge_strategy == "learned_with_images":
+            if image_only_indicator is None:
+                raise ValueError("Please provide image_only_indicator to use learned_with_images merge strategy")
+
+            alpha = torch.where(
+                image_only_indicator.bool(),
+                torch.ones(1, 1, device=image_only_indicator.device),
+                torch.sigmoid(self.mix_factor)[..., None],
+            )
+
+            # (batch, channel, frames, height, width)
+            if ndims == 5:
+                alpha = alpha[:, None, :, None, None]
+            # (batch*frames, height*width, channels)
+            elif ndims == 3:
+                alpha = alpha.reshape(-1)[:, None, None]
+            else:
+                raise ValueError(f"Unexpected ndims {ndims}. Dimensions should be 3 or 5")
+
+        else:
+            raise NotImplementedError
+
+        return alpha
+
+    def forward(
+        self,
+        x_spatial: torch.Tensor,
+        x_temporal: torch.Tensor,
+        image_only_indicator: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        alpha = self.get_alpha(image_only_indicator, x_spatial.ndim)
+        alpha = alpha.to(x_spatial.dtype)
+
+        if self.switch_spatial_to_temporal_mix:
+            alpha = 1.0 - alpha
+
+        x = alpha * x_spatial + (1.0 - alpha) * x_temporal
+        return x
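
The new `AlphaBlender` rounds out the file: under the `learned_with_images` strategy, frames flagged as image-only keep the spatial branch (alpha forced to 1), while video frames use the learned sigmoid mix. A self-contained sketch over 5D features:

import torch
from diffusers.models.resnet import AlphaBlender

blender = AlphaBlender(alpha=0.5, merge_strategy="learned_with_images")

x_spatial = torch.randn(1, 4, 8, 16, 16)   # (batch, channels, frames, height, width)
x_temporal = torch.randn(1, 4, 8, 16, 16)

image_only_indicator = torch.zeros(1, 8)   # one flag per frame: 0 -> video, 1 -> image-only
mixed = blender(x_spatial, x_temporal, image_only_indicator=image_only_indicator)
print(mixed.shape)  # torch.Size([1, 4, 8, 16, 16])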