diffusers 0.27.2__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (270)
  1. diffusers/__init__.py +18 -1
  2. diffusers/callbacks.py +156 -0
  3. diffusers/commands/env.py +110 -6
  4. diffusers/configuration_utils.py +16 -11
  5. diffusers/dependency_versions_table.py +2 -1
  6. diffusers/image_processor.py +158 -45
  7. diffusers/loaders/__init__.py +2 -5
  8. diffusers/loaders/autoencoder.py +4 -4
  9. diffusers/loaders/controlnet.py +4 -4
  10. diffusers/loaders/ip_adapter.py +80 -22
  11. diffusers/loaders/lora.py +134 -20
  12. diffusers/loaders/lora_conversion_utils.py +46 -43
  13. diffusers/loaders/peft.py +4 -3
  14. diffusers/loaders/single_file.py +401 -170
  15. diffusers/loaders/single_file_model.py +290 -0
  16. diffusers/loaders/single_file_utils.py +616 -672
  17. diffusers/loaders/textual_inversion.py +41 -20
  18. diffusers/loaders/unet.py +168 -115
  19. diffusers/loaders/unet_loader_utils.py +163 -0
  20. diffusers/models/__init__.py +2 -0
  21. diffusers/models/activations.py +11 -3
  22. diffusers/models/attention.py +10 -11
  23. diffusers/models/attention_processor.py +367 -148
  24. diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
  25. diffusers/models/autoencoders/autoencoder_kl.py +18 -19
  26. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
  27. diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
  28. diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
  29. diffusers/models/autoencoders/vae.py +23 -24
  30. diffusers/models/controlnet.py +12 -9
  31. diffusers/models/controlnet_flax.py +4 -4
  32. diffusers/models/controlnet_xs.py +1915 -0
  33. diffusers/models/downsampling.py +17 -18
  34. diffusers/models/embeddings.py +147 -24
  35. diffusers/models/model_loading_utils.py +149 -0
  36. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  37. diffusers/models/modeling_flax_utils.py +4 -4
  38. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  39. diffusers/models/modeling_utils.py +118 -98
  40. diffusers/models/resnet.py +18 -23
  41. diffusers/models/transformer_temporal.py +3 -3
  42. diffusers/models/transformers/dual_transformer_2d.py +4 -4
  43. diffusers/models/transformers/prior_transformer.py +7 -7
  44. diffusers/models/transformers/t5_film_transformer.py +17 -19
  45. diffusers/models/transformers/transformer_2d.py +272 -156
  46. diffusers/models/transformers/transformer_temporal.py +10 -10
  47. diffusers/models/unets/unet_1d.py +5 -5
  48. diffusers/models/unets/unet_1d_blocks.py +29 -29
  49. diffusers/models/unets/unet_2d.py +6 -6
  50. diffusers/models/unets/unet_2d_blocks.py +137 -128
  51. diffusers/models/unets/unet_2d_condition.py +19 -15
  52. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  53. diffusers/models/unets/unet_3d_blocks.py +79 -77
  54. diffusers/models/unets/unet_3d_condition.py +13 -9
  55. diffusers/models/unets/unet_i2vgen_xl.py +14 -13
  56. diffusers/models/unets/unet_kandinsky3.py +1 -1
  57. diffusers/models/unets/unet_motion_model.py +114 -14
  58. diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
  59. diffusers/models/unets/unet_stable_cascade.py +16 -13
  60. diffusers/models/upsampling.py +17 -20
  61. diffusers/models/vq_model.py +16 -15
  62. diffusers/pipelines/__init__.py +25 -3
  63. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  64. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  65. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  66. diffusers/pipelines/animatediff/__init__.py +2 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
  70. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  71. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  72. diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
  73. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
  74. diffusers/pipelines/auto_pipeline.py +21 -17
  75. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  76. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
  77. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  78. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  79. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
  80. diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
  81. diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
  82. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  83. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
  84. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
  85. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
  86. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
  87. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
  88. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
  90. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
  91. diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
  96. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
  97. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  98. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
  99. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
  100. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  101. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
  103. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -18
  104. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
  105. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
  106. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
  107. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
  108. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  109. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  110. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  111. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  112. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  113. diffusers/pipelines/dit/pipeline_dit.py +3 -0
  114. diffusers/pipelines/free_init_utils.py +39 -38
  115. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  116. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  117. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
  118. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  119. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  120. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  121. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  122. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
  123. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  124. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  125. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  126. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  127. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  128. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  129. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  130. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
  131. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
  132. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
  133. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
  134. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
  135. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
  136. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
  137. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  138. diffusers/pipelines/marigold/__init__.py +50 -0
  139. diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
  140. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  141. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  142. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  143. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  144. diffusers/pipelines/pia/pipeline_pia.py +39 -125
  145. diffusers/pipelines/pipeline_flax_utils.py +4 -4
  146. diffusers/pipelines/pipeline_loading_utils.py +268 -23
  147. diffusers/pipelines/pipeline_utils.py +266 -37
  148. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  149. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
  150. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
  151. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
  152. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  153. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  154. diffusers/pipelines/shap_e/renderer.py +1 -1
  155. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +18 -18
  156. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  157. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
  158. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  159. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
  160. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  161. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
  163. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
  164. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  165. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
  166. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
  167. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
  168. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
  169. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
  170. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
  171. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
  172. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  173. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
  174. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -39
  175. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
  176. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
  177. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
  178. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
  179. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
  180. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
  181. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  182. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  183. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
  184. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
  185. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
  186. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
  187. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
  188. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  189. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  190. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
  191. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
  192. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  193. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
  194. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
  195. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
  196. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
  197. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  198. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  199. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  200. diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
  201. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
  202. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  203. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  204. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
  205. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  206. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  207. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
  208. diffusers/schedulers/__init__.py +2 -2
  209. diffusers/schedulers/deprecated/__init__.py +1 -1
  210. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  211. diffusers/schedulers/scheduling_amused.py +5 -5
  212. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  213. diffusers/schedulers/scheduling_consistency_models.py +20 -26
  214. diffusers/schedulers/scheduling_ddim.py +22 -24
  215. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  216. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  217. diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
  218. diffusers/schedulers/scheduling_ddpm.py +20 -22
  219. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  220. diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
  221. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  222. diffusers/schedulers/scheduling_deis_multistep.py +42 -42
  223. diffusers/schedulers/scheduling_dpmsolver_multistep.py +103 -77
  224. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  225. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
  226. diffusers/schedulers/scheduling_dpmsolver_sde.py +23 -23
  227. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +86 -65
  228. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +75 -54
  229. diffusers/schedulers/scheduling_edm_euler.py +50 -31
  230. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +23 -29
  231. diffusers/schedulers/scheduling_euler_discrete.py +160 -68
  232. diffusers/schedulers/scheduling_heun_discrete.py +57 -39
  233. diffusers/schedulers/scheduling_ipndm.py +8 -8
  234. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +19 -19
  235. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +19 -19
  236. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  237. diffusers/schedulers/scheduling_lcm.py +21 -23
  238. diffusers/schedulers/scheduling_lms_discrete.py +24 -26
  239. diffusers/schedulers/scheduling_pndm.py +20 -20
  240. diffusers/schedulers/scheduling_repaint.py +20 -20
  241. diffusers/schedulers/scheduling_sasolver.py +55 -54
  242. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  243. diffusers/schedulers/scheduling_tcd.py +39 -30
  244. diffusers/schedulers/scheduling_unclip.py +15 -15
  245. diffusers/schedulers/scheduling_unipc_multistep.py +111 -41
  246. diffusers/schedulers/scheduling_utils.py +14 -5
  247. diffusers/schedulers/scheduling_utils_flax.py +3 -3
  248. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  249. diffusers/training_utils.py +56 -1
  250. diffusers/utils/__init__.py +7 -0
  251. diffusers/utils/doc_utils.py +1 -0
  252. diffusers/utils/dummy_pt_objects.py +30 -0
  253. diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
  254. diffusers/utils/dynamic_modules_utils.py +24 -11
  255. diffusers/utils/hub_utils.py +3 -2
  256. diffusers/utils/import_utils.py +91 -0
  257. diffusers/utils/loading_utils.py +2 -2
  258. diffusers/utils/logging.py +1 -1
  259. diffusers/utils/peft_utils.py +32 -5
  260. diffusers/utils/state_dict_utils.py +11 -2
  261. diffusers/utils/testing_utils.py +71 -6
  262. diffusers/utils/torch_utils.py +1 -0
  263. diffusers/video_processor.py +113 -0
  264. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/METADATA +47 -47
  265. diffusers-0.28.0.dist-info/RECORD +414 -0
  266. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/WHEEL +1 -1
  267. diffusers-0.27.2.dist-info/RECORD +0 -399
  268. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
  269. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
  270. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
diffusers/models/unets/unet_2d_condition.py
@@ -20,6 +20,7 @@ import torch.utils.checkpoint
 
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin, UNet2DConditionLoadersMixin
+from ...loaders.single_file_model import FromOriginalModelMixin
 from ...utils import USE_PEFT_BACKEND, BaseOutput, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from ..activations import get_activation
 from ..attention_processor import (
@@ -59,14 +60,16 @@ class UNet2DConditionOutput(BaseOutput):
     The output of [`UNet2DConditionModel`].
 
     Args:
-        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`):
+        sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
             The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
     """
 
-    sample: torch.FloatTensor = None
+    sample: torch.Tensor = None
 
 
-class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
+class UNet2DConditionModel(
+    ModelMixin, ConfigMixin, FromOriginalModelMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin
+):
     r"""
     A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
     shaped output.
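The new `FromOriginalModelMixin` base (added in `diffusers/loaders/single_file_model.py`, file 15 above) gives the UNet a `from_single_file` loader for original-layout checkpoints. A minimal sketch, with a hypothetical local checkpoint path; the exact keyword arguments accepted in 0.28.0 may differ in detail:

import torch
from diffusers import UNet2DConditionModel

# Hypothetical checkpoint path; `from_single_file` is provided by FromOriginalModelMixin.
unet = UNet2DConditionModel.from_single_file(
    "./checkpoints/sd15_original.safetensors",
    torch_dtype=torch.float16,
)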
@@ -161,6 +164,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
     """
 
     _supports_gradient_checkpointing = True
+    _no_split_modules = ["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"]
 
     @register_to_config
     def __init__(
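`_no_split_modules` declares which submodule classes must not be split across devices when a device map is computed with accelerate. A minimal sketch of how that list is typically consumed; the repo id and memory limits are illustrative, and this usage is an assumption about intent rather than an API documented by this diff:

from accelerate import infer_auto_device_map, init_empty_weights
from diffusers import UNet2DConditionModel

# Build the UNet without allocating weights, then plan a device map that keeps
# each of the listed block classes whole on a single device.
config = UNet2DConditionModel.load_config("runwayml/stable-diffusion-v1-5", subfolder="unet")
with init_empty_weights():
    unet = UNet2DConditionModel.from_config(config)

device_map = infer_auto_device_map(
    unet,
    max_memory={0: "4GiB", "cpu": "16GiB"},
    no_split_module_classes=["BasicTransformerBlock", "ResnetBlock2D", "CrossAttnUpBlock2D"],
)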
@@ -580,7 +584,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
         elif encoder_hid_dim_type == "text_image_proj":
             # image_embed_dim DOESN'T have to be `cross_attention_dim`. To not clutter the __init__ too much
             # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
-            # case when `addition_embed_type == "text_image_proj"` (Kadinsky 2.1)`
+            # case when `addition_embed_type == "text_image_proj"` (Kandinsky 2.1)`
             self.encoder_hid_proj = TextImageProjection(
                 text_embed_dim=encoder_hid_dim,
                 image_embed_dim=cross_attention_dim,
@@ -660,7 +664,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
         elif addition_embed_type == "text_image":
             # text_embed_dim and image_embed_dim DON'T have to be `cross_attention_dim`. To not clutter the __init__ too much
             # they are set to `cross_attention_dim` here as this is exactly the required dimension for the currently only use
-            # case when `addition_embed_type == "text_image"` (Kadinsky 2.1)`
+            # case when `addition_embed_type == "text_image"` (Kandinsky 2.1)`
             self.add_embedding = TextImageTimeEmbedding(
                 text_embed_dim=cross_attention_dim, image_embed_dim=cross_attention_dim, time_embed_dim=time_embed_dim
             )
@@ -681,7 +685,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
             positive_len = 768
             if isinstance(cross_attention_dim, int):
                 positive_len = cross_attention_dim
-            elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
+            elif isinstance(cross_attention_dim, (list, tuple)):
                 positive_len = cross_attention_dim[0]
 
             feature_type = "text-only" if attention_type == "gated" else "text-image"
@@ -865,8 +869,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
 
     def fuse_qkv_projections(self):
         """
-        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
-        key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+        are fused. For cross-attention modules, key and value projection matrices are fused.
 
         <Tip warning={true}>
 
@@ -1010,7 +1014,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
         if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_proj":
             encoder_hidden_states = self.encoder_hid_proj(encoder_hidden_states)
         elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "text_image_proj":
-            # Kadinsky 2.1 - style
+            # Kandinsky 2.1 - style
             if "image_embeds" not in added_cond_kwargs:
                 raise ValueError(
                     f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
@@ -1038,7 +1042,7 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
 
     def forward(
         self,
-        sample: torch.FloatTensor,
+        sample: torch.Tensor,
         timestep: Union[torch.Tensor, float, int],
         encoder_hidden_states: torch.Tensor,
         class_labels: Optional[torch.Tensor] = None,
@@ -1056,10 +1060,10 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
         The [`UNet2DConditionModel`] forward method.
 
         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The noisy input tensor with the following shape `(batch, channel, height, width)`.
-            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
-            encoder_hidden_states (`torch.FloatTensor`):
+            timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
+            encoder_hidden_states (`torch.Tensor`):
                 The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
             class_labels (`torch.Tensor`, *optional*, defaults to `None`):
                 Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
@@ -1093,8 +1097,8 @@ class UNet2DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin,
 
         Returns:
             [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
-                If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned, otherwise
-                a `tuple` is returned where the first element is the sample tensor.
+                If `return_dict` is True, an [`~models.unets.unet_2d_condition.UNet2DConditionOutput`] is returned,
+                otherwise a `tuple` is returned where the first element is the sample tensor.
         """
         # By default samples have to be AT least a multiple of the overall upsampling factor.
         # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
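The `fuse_qkv_projections` hunks in this file (and the matching ones in the 3D and I2VGen-XL UNets below) only re-wrap the docstring; the method itself is unchanged. For context, a minimal sketch of how the fused projections are typically toggled around inference; the repo id and prompt are illustrative:

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")

pipe.unet.fuse_qkv_projections()    # fuse q/k/v for self-attention, k/v for cross-attention
image = pipe("an astronaut riding a horse").images[0]
pipe.unet.unfuse_qkv_projections()  # restore the original, unfused projections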
diffusers/models/unets/unet_2d_condition_flax.py
@@ -76,7 +76,8 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
         up_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D")`):
             The tuple of upsample blocks to use.
         mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
-            Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`. If `None`, the mid block layer is skipped.
+            Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`. If `None`, the mid block layer
+            is skipped.
         block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
             The tuple of output channels for each block.
         layers_per_block (`int`, *optional*, defaults to 2):
@@ -350,15 +351,15 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
             mid_block_additional_residual: (`torch.Tensor`, *optional*):
                 A tensor that if specified is added to the residual of the middle unet block.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of a
-                plain tuple.
+                Whether or not to return a [`models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] instead of
+                a plain tuple.
             train (`bool`, *optional*, defaults to `False`):
                 Use deterministic functions and disable dropout when not training.
 
         Returns:
             [`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] or `tuple`:
-                [`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`.
-                When returning a tuple, the first element is the sample tensor.
+                [`~models.unets.unet_2d_condition_flax.FlaxUNet2DConditionOutput`] if `return_dict` is True, otherwise a
+                `tuple`. When returning a tuple, the first element is the sample tensor.
         """
         # 1. time
         if not isinstance(timesteps, jnp.ndarray):
diffusers/models/unets/unet_3d_blocks.py
@@ -121,6 +121,7 @@ def get_down_block(
             raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlockMotion")
         return CrossAttnDownBlockMotion(
             num_layers=num_layers,
+            transformer_layers_per_block=transformer_layers_per_block,
             in_channels=in_channels,
             out_channels=out_channels,
             temb_channels=temb_channels,
@@ -255,6 +256,7 @@ def get_up_block(
             raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlockMotion")
         return CrossAttnUpBlockMotion(
             num_layers=num_layers,
+            transformer_layers_per_block=transformer_layers_per_block,
             in_channels=in_channels,
             out_channels=out_channels,
             prev_output_channel=prev_output_channel,
@@ -409,13 +411,13 @@ class UNetMidBlock3DCrossAttn(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         num_frames: int = 1,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         hidden_states = self.resnets[0](hidden_states, temb)
         hidden_states = self.temp_convs[0](hidden_states, num_frames=num_frames)
         for attn, temp_attn, resnet, temp_conv in zip(
@@ -542,13 +544,13 @@ class CrossAttnDownBlock3D(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         num_frames: int = 1,
         cross_attention_kwargs: Dict[str, Any] = None,
-    ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
         # TODO(Patrick, William) - attention mask is not used
         output_states = ()
 
@@ -649,10 +651,10 @@ class DownBlock3D(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
         num_frames: int = 1,
-    ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
         output_states = ()
 
         for resnet, temp_conv in zip(self.resnets, self.temp_convs):
@@ -767,15 +769,15 @@ class CrossAttnUpBlock3D(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
         upsample_size: Optional[int] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         num_frames: int = 1,
         cross_attention_kwargs: Dict[str, Any] = None,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         is_freeu_enabled = (
             getattr(self, "s1", None)
             and getattr(self, "s2", None)
@@ -889,12 +891,12 @@ class UpBlock3D(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
-        temb: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+        temb: Optional[torch.Tensor] = None,
         upsample_size: Optional[int] = None,
         num_frames: int = 1,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         is_freeu_enabled = (
             getattr(self, "s1", None)
             and getattr(self, "s2", None)
@@ -1006,12 +1008,12 @@ class DownBlockMotion(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
         num_frames: int = 1,
         *args,
         **kwargs,
-    ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, ...]]:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -1172,18 +1174,18 @@ class CrossAttnDownBlockMotion(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         num_frames: int = 1,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        additional_residuals: Optional[torch.FloatTensor] = None,
+        additional_residuals: Optional[torch.Tensor] = None,
     ):
         if cross_attention_kwargs is not None:
             if cross_attention_kwargs.get("scale", None) is not None:
-                logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
 
         output_states = ()
 
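The warnings added to the Motion blocks above state that a `scale` entry is ignored once it reaches the block via `cross_attention_kwargs`; the LoRA scale is expected at the pipeline call instead. A minimal sketch of that pattern; the repo id, LoRA path, and the 0.5 value are illustrative placeholders:

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("path/to/lora")  # hypothetical LoRA weights

# The pipeline applies the scale to the LoRA layers; the UNet blocks no longer read it.
image = pipe(
    "a rocket launching at sunset",
    cross_attention_kwargs={"scale": 0.5},
).images[0]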
@@ -1355,19 +1357,19 @@ class CrossAttnUpBlockMotion(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         upsample_size: Optional[int] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
         num_frames: int = 1,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         if cross_attention_kwargs is not None:
             if cross_attention_kwargs.get("scale", None) is not None:
-                logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
 
         is_freeu_enabled = (
             getattr(self, "s1", None)
@@ -1516,14 +1518,14 @@ class UpBlockMotion(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
-        temb: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+        temb: Optional[torch.Tensor] = None,
         upsample_size=None,
         num_frames: int = 1,
         *args,
         **kwargs,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -1697,17 +1699,17 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
-        encoder_attention_mask: Optional[torch.FloatTensor] = None,
+        encoder_attention_mask: Optional[torch.Tensor] = None,
         num_frames: int = 1,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         if cross_attention_kwargs is not None:
             if cross_attention_kwargs.get("scale", None) is not None:
-                logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+                logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
 
         hidden_states = self.resnets[0](hidden_states, temb)
 
@@ -1809,8 +1811,8 @@ class MidBlockTemporalDecoder(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        image_only_indicator: torch.FloatTensor,
+        hidden_states: torch.Tensor,
+        image_only_indicator: torch.Tensor,
     ):
         hidden_states = self.resnets[0](
             hidden_states,
@@ -1860,9 +1862,9 @@ class UpBlockTemporalDecoder(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        image_only_indicator: torch.FloatTensor,
-    ) -> torch.FloatTensor:
+        hidden_states: torch.Tensor,
+        image_only_indicator: torch.Tensor,
+    ) -> torch.Tensor:
         for resnet in self.resnets:
             hidden_states = resnet(
                 hidden_states,
@@ -1933,11 +1935,11 @@ class UNetMidBlockSpatioTemporal(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
         image_only_indicator: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         hidden_states = self.resnets[0](
             hidden_states,
             temb,
@@ -2029,10 +2031,10 @@ class DownBlockSpatioTemporal(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
         image_only_indicator: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
         output_states = ()
         for resnet in self.resnets:
             if self.training and self.gradient_checkpointing:
@@ -2139,11 +2141,11 @@ class CrossAttnDownBlockSpatioTemporal(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
         image_only_indicator: Optional[torch.Tensor] = None,
-    ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
+    ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
         output_states = ()
 
         blocks = list(zip(self.resnets, self.attentions))
@@ -2238,11 +2240,11 @@ class UpBlockSpatioTemporal(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
-        temb: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+        temb: Optional[torch.Tensor] = None,
         image_only_indicator: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         for resnet in self.resnets:
             # pop res hidden states
             res_hidden_states = res_hidden_states_tuple[-1]
@@ -2347,12 +2349,12 @@ class CrossAttnUpBlockSpatioTemporal(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
-        temb: Optional[torch.FloatTensor] = None,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        res_hidden_states_tuple: Tuple[torch.Tensor, ...],
+        temb: Optional[torch.Tensor] = None,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
         image_only_indicator: Optional[torch.Tensor] = None,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         for resnet, attn in zip(self.resnets, self.attentions):
             # pop res hidden states
             res_hidden_states = res_hidden_states_tuple[-1]
diffusers/models/unets/unet_3d_condition.py
@@ -55,11 +55,11 @@ class UNet3DConditionOutput(BaseOutput):
     The output of [`UNet3DConditionModel`].
 
     Args:
-        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, num_frames, height, width)`):
+        sample (`torch.Tensor` of shape `(batch_size, num_channels, num_frames, height, width)`):
             The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
     """
 
-    sample: torch.FloatTensor
+    sample: torch.Tensor
 
 
 class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
@@ -91,6 +91,8 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         cross_attention_dim (`int`, *optional*, defaults to 1024): The dimension of the cross attention features.
         attention_head_dim (`int`, *optional*, defaults to 64): The dimension of the attention heads.
         num_attention_heads (`int`, *optional*): The number of attention heads.
+        time_cond_proj_dim (`int`, *optional*, defaults to `None`):
+            The dimension of `cond_proj` layer in the timestep embedding.
     """
 
     _supports_gradient_checkpointing = False
@@ -123,6 +125,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         cross_attention_dim: int = 1024,
         attention_head_dim: Union[int, Tuple[int]] = 64,
         num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
+        time_cond_proj_dim: Optional[int] = None,
     ):
         super().__init__()
 
@@ -174,6 +177,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
             timestep_input_dim,
             time_embed_dim,
             act_fn=act_fn,
+            cond_proj_dim=time_cond_proj_dim,
         )
 
         self.transformer_in = TransformerTemporalModel(
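These three hunks thread a new `time_cond_proj_dim` argument through `UNet3DConditionModel` so its timestep embedding can carry a conditioning projection (the same `cond_proj` mechanism the 2D UNet already exposes). A minimal sketch of constructing the model with it; the toy configuration values are illustrative, and only the `time_cond_proj_dim` keyword comes from this diff:

from diffusers import UNet3DConditionModel

# Toy configuration; `time_cond_proj_dim` is the newly added argument.
unet = UNet3DConditionModel(
    block_out_channels=(32, 64),
    down_block_types=("CrossAttnDownBlock3D", "DownBlock3D"),
    up_block_types=("UpBlock3D", "CrossAttnUpBlock3D"),
    cross_attention_dim=32,
    attention_head_dim=8,
    norm_num_groups=32,
    time_cond_proj_dim=256,  # enables the `cond_proj` layer in the timestep embedding
)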
@@ -507,8 +511,8 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
     def fuse_qkv_projections(self):
         """
-        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
-        key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+        are fused. For cross-attention modules, key and value projection matrices are fused.
 
         <Tip warning={true}>
 
@@ -556,7 +560,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
 
     def forward(
         self,
-        sample: torch.FloatTensor,
+        sample: torch.Tensor,
         timestep: Union[torch.Tensor, float, int],
         encoder_hidden_states: torch.Tensor,
         class_labels: Optional[torch.Tensor] = None,
@@ -566,15 +570,15 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
         down_block_additional_residuals: Optional[Tuple[torch.Tensor]] = None,
         mid_block_additional_residual: Optional[torch.Tensor] = None,
         return_dict: bool = True,
-    ) -> Union[UNet3DConditionOutput, Tuple[torch.FloatTensor]]:
+    ) -> Union[UNet3DConditionOutput, Tuple[torch.Tensor]]:
         r"""
         The [`UNet3DConditionModel`] forward method.
 
         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The noisy input tensor with the following shape `(batch, num_channels, num_frames, height, width`.
-            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
-            encoder_hidden_states (`torch.FloatTensor`):
+            timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
+            encoder_hidden_states (`torch.Tensor`):
                 The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
             class_labels (`torch.Tensor`, *optional*, defaults to `None`):
                 Optional class labels for conditioning. Their embeddings will be summed with the timestep embeddings.
diffusers/models/unets/unet_i2vgen_xl.py
@@ -81,8 +81,8 @@ class I2VGenXLTransformerTemporalEncoder(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-    ) -> torch.FloatTensor:
+        hidden_states: torch.Tensor,
+    ) -> torch.Tensor:
         norm_hidden_states = self.norm1(hidden_states)
         attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
         hidden_states = attn_output + hidden_states
@@ -99,8 +99,8 @@ class I2VGenXLTransformerTemporalEncoder(nn.Module):
 
 class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
     r"""
-    I2VGenXL UNet. It is a conditional 3D UNet model that takes a noisy sample, conditional state, and a timestep
-    and returns a sample-shaped output.
+    I2VGenXL UNet. It is a conditional 3D UNet model that takes a noisy sample, conditional state, and a timestep and
+    returns a sample-shaped output.
 
     This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
     for all models (such as downloading or saving).
@@ -477,8 +477,8 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
     def fuse_qkv_projections(self):
         """
-        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
-        key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+        are fused. For cross-attention modules, key and value projection matrices are fused.
 
         <Tip warning={true}>
 
@@ -514,7 +514,7 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
 
     def forward(
         self,
-        sample: torch.FloatTensor,
+        sample: torch.Tensor,
         timestep: Union[torch.Tensor, float, int],
         fps: torch.Tensor,
         image_latents: torch.Tensor,
@@ -523,18 +523,19 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         timestep_cond: Optional[torch.Tensor] = None,
         cross_attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
-    ) -> Union[UNet3DConditionOutput, Tuple[torch.FloatTensor]]:
+    ) -> Union[UNet3DConditionOutput, Tuple[torch.Tensor]]:
         r"""
         The [`I2VGenXLUNet`] forward method.
 
         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The noisy input tensor with the following shape `(batch, num_frames, channel, height, width`.
-            timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
+            timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
             fps (`torch.Tensor`): Frames per second for the video being generated. Used as a "micro-condition".
-            image_latents (`torch.FloatTensor`): Image encodings from the VAE.
-            image_embeddings (`torch.FloatTensor`): Projection embeddings of the conditioning image computed with a vision encoder.
-            encoder_hidden_states (`torch.FloatTensor`):
+            image_latents (`torch.Tensor`): Image encodings from the VAE.
+            image_embeddings (`torch.Tensor`):
+                Projection embeddings of the conditioning image computed with a vision encoder.
+            encoder_hidden_states (`torch.Tensor`):
                 The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
             cross_attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
diffusers/models/unets/unet_kandinsky3.py
@@ -31,7 +31,7 @@ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 @dataclass
 class Kandinsky3UNetOutput(BaseOutput):
-    sample: torch.FloatTensor = None
+    sample: torch.Tensor = None
 
 
 class Kandinsky3EncoderProj(nn.Module):