diffusers 0.26.2__py3-none-any.whl → 0.27.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (299) hide show
  1. diffusers/__init__.py +20 -1
  2. diffusers/commands/__init__.py +1 -1
  3. diffusers/commands/diffusers_cli.py +1 -1
  4. diffusers/commands/env.py +1 -1
  5. diffusers/commands/fp16_safetensors.py +1 -1
  6. diffusers/configuration_utils.py +7 -3
  7. diffusers/dependency_versions_check.py +1 -1
  8. diffusers/dependency_versions_table.py +2 -2
  9. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  10. diffusers/image_processor.py +110 -4
  11. diffusers/loaders/autoencoder.py +28 -8
  12. diffusers/loaders/controlnet.py +17 -8
  13. diffusers/loaders/ip_adapter.py +86 -23
  14. diffusers/loaders/lora.py +105 -310
  15. diffusers/loaders/lora_conversion_utils.py +1 -1
  16. diffusers/loaders/peft.py +1 -1
  17. diffusers/loaders/single_file.py +51 -12
  18. diffusers/loaders/single_file_utils.py +278 -49
  19. diffusers/loaders/textual_inversion.py +23 -4
  20. diffusers/loaders/unet.py +195 -41
  21. diffusers/loaders/utils.py +1 -1
  22. diffusers/models/__init__.py +3 -1
  23. diffusers/models/activations.py +9 -9
  24. diffusers/models/attention.py +26 -36
  25. diffusers/models/attention_flax.py +1 -1
  26. diffusers/models/attention_processor.py +171 -114
  27. diffusers/models/autoencoders/autoencoder_asym_kl.py +1 -1
  28. diffusers/models/autoencoders/autoencoder_kl.py +3 -1
  29. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
  30. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  31. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  32. diffusers/models/autoencoders/vae.py +1 -1
  33. diffusers/models/controlnet.py +1 -1
  34. diffusers/models/controlnet_flax.py +1 -1
  35. diffusers/models/downsampling.py +8 -12
  36. diffusers/models/dual_transformer_2d.py +1 -1
  37. diffusers/models/embeddings.py +3 -4
  38. diffusers/models/embeddings_flax.py +1 -1
  39. diffusers/models/lora.py +33 -10
  40. diffusers/models/modeling_flax_pytorch_utils.py +1 -1
  41. diffusers/models/modeling_flax_utils.py +1 -1
  42. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  43. diffusers/models/modeling_utils.py +4 -6
  44. diffusers/models/normalization.py +1 -1
  45. diffusers/models/resnet.py +31 -58
  46. diffusers/models/resnet_flax.py +1 -1
  47. diffusers/models/t5_film_transformer.py +1 -1
  48. diffusers/models/transformer_2d.py +1 -1
  49. diffusers/models/transformer_temporal.py +1 -1
  50. diffusers/models/transformers/dual_transformer_2d.py +1 -1
  51. diffusers/models/transformers/t5_film_transformer.py +1 -1
  52. diffusers/models/transformers/transformer_2d.py +29 -31
  53. diffusers/models/transformers/transformer_temporal.py +1 -1
  54. diffusers/models/unet_1d.py +1 -1
  55. diffusers/models/unet_1d_blocks.py +1 -1
  56. diffusers/models/unet_2d.py +1 -1
  57. diffusers/models/unet_2d_blocks.py +1 -1
  58. diffusers/models/unet_2d_condition.py +1 -1
  59. diffusers/models/unets/__init__.py +1 -0
  60. diffusers/models/unets/unet_1d.py +1 -1
  61. diffusers/models/unets/unet_1d_blocks.py +1 -1
  62. diffusers/models/unets/unet_2d.py +4 -4
  63. diffusers/models/unets/unet_2d_blocks.py +238 -98
  64. diffusers/models/unets/unet_2d_blocks_flax.py +1 -1
  65. diffusers/models/unets/unet_2d_condition.py +420 -323
  66. diffusers/models/unets/unet_2d_condition_flax.py +21 -12
  67. diffusers/models/unets/unet_3d_blocks.py +50 -40
  68. diffusers/models/unets/unet_3d_condition.py +47 -8
  69. diffusers/models/unets/unet_i2vgen_xl.py +75 -30
  70. diffusers/models/unets/unet_kandinsky3.py +1 -1
  71. diffusers/models/unets/unet_motion_model.py +48 -8
  72. diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
  73. diffusers/models/unets/unet_stable_cascade.py +610 -0
  74. diffusers/models/unets/uvit_2d.py +1 -1
  75. diffusers/models/upsampling.py +10 -16
  76. diffusers/models/vae_flax.py +1 -1
  77. diffusers/models/vq_model.py +1 -1
  78. diffusers/optimization.py +1 -1
  79. diffusers/pipelines/__init__.py +26 -0
  80. diffusers/pipelines/amused/pipeline_amused.py +1 -1
  81. diffusers/pipelines/amused/pipeline_amused_img2img.py +1 -1
  82. diffusers/pipelines/amused/pipeline_amused_inpaint.py +1 -1
  83. diffusers/pipelines/animatediff/pipeline_animatediff.py +162 -417
  84. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +165 -137
  85. diffusers/pipelines/animatediff/pipeline_output.py +7 -6
  86. diffusers/pipelines/audioldm/pipeline_audioldm.py +3 -19
  87. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  88. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +3 -3
  89. diffusers/pipelines/auto_pipeline.py +7 -16
  90. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  91. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  92. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
  93. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  94. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  95. diffusers/pipelines/controlnet/pipeline_controlnet.py +90 -90
  96. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  97. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +98 -90
  98. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +92 -90
  99. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +145 -70
  100. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +126 -89
  101. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +108 -96
  102. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
  103. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -1
  104. diffusers/pipelines/ddim/pipeline_ddim.py +1 -1
  105. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
  106. diffusers/pipelines/deepfloyd_if/pipeline_if.py +4 -4
  107. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +4 -4
  108. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +5 -5
  109. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +4 -4
  110. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +5 -5
  111. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +5 -5
  112. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +10 -120
  113. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +10 -91
  114. diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
  115. diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +1 -1
  116. diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
  117. diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +1 -1
  118. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
  119. diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
  120. diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
  121. diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
  122. diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
  123. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
  124. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +5 -4
  125. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +5 -4
  126. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +7 -22
  127. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -39
  128. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +5 -5
  129. diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
  130. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -22
  131. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
  132. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
  133. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -2
  134. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
  135. diffusers/pipelines/dit/pipeline_dit.py +1 -1
  136. diffusers/pipelines/free_init_utils.py +184 -0
  137. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +22 -104
  138. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +1 -1
  139. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
  140. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +1 -1
  141. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +2 -2
  142. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +1 -1
  143. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
  144. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -1
  145. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +1 -1
  146. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +1 -1
  147. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
  148. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +2 -2
  149. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +104 -93
  150. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +112 -74
  151. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  152. diffusers/pipelines/ledits_pp/__init__.py +55 -0
  153. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +1505 -0
  154. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +1797 -0
  155. diffusers/pipelines/ledits_pp/pipeline_output.py +43 -0
  156. diffusers/pipelines/musicldm/pipeline_musicldm.py +3 -19
  157. diffusers/pipelines/onnx_utils.py +1 -1
  158. diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
  159. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +3 -3
  160. diffusers/pipelines/pia/pipeline_pia.py +168 -327
  161. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  162. diffusers/pipelines/pipeline_loading_utils.py +508 -0
  163. diffusers/pipelines/pipeline_utils.py +188 -534
  164. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +56 -10
  165. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +3 -3
  166. diffusers/pipelines/shap_e/camera.py +1 -1
  167. diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
  168. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
  169. diffusers/pipelines/shap_e/renderer.py +1 -1
  170. diffusers/pipelines/stable_cascade/__init__.py +50 -0
  171. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +482 -0
  172. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +311 -0
  173. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +638 -0
  174. diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
  175. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +4 -1
  176. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  177. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +2 -2
  178. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
  179. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +1 -1
  180. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +1 -1
  181. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
  182. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
  183. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +90 -146
  184. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -4
  185. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +4 -32
  186. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +92 -119
  187. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +92 -119
  188. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +13 -59
  189. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +3 -31
  190. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -33
  191. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +5 -21
  192. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +7 -21
  193. diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
  194. diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
  195. diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
  196. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +5 -21
  197. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +9 -38
  198. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -34
  199. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +6 -35
  200. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +7 -6
  201. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +4 -124
  202. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +282 -80
  203. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +94 -46
  204. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +3 -3
  205. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  206. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +6 -22
  207. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  208. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +96 -148
  209. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +98 -154
  210. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +98 -153
  211. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +25 -87
  212. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +89 -80
  213. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +5 -49
  214. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +80 -88
  215. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +8 -6
  216. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +15 -86
  217. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +20 -93
  218. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +5 -5
  219. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +3 -19
  220. diffusers/pipelines/unclip/pipeline_unclip.py +1 -1
  221. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +1 -1
  222. diffusers/pipelines/unclip/text_proj.py +1 -1
  223. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +35 -35
  224. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  225. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +4 -21
  226. diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +2 -2
  227. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -5
  228. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
  229. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +1 -1
  230. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +2 -2
  231. diffusers/schedulers/__init__.py +7 -1
  232. diffusers/schedulers/deprecated/scheduling_karras_ve.py +1 -1
  233. diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
  234. diffusers/schedulers/scheduling_consistency_models.py +42 -19
  235. diffusers/schedulers/scheduling_ddim.py +2 -4
  236. diffusers/schedulers/scheduling_ddim_flax.py +13 -5
  237. diffusers/schedulers/scheduling_ddim_inverse.py +2 -4
  238. diffusers/schedulers/scheduling_ddim_parallel.py +2 -4
  239. diffusers/schedulers/scheduling_ddpm.py +2 -4
  240. diffusers/schedulers/scheduling_ddpm_flax.py +1 -1
  241. diffusers/schedulers/scheduling_ddpm_parallel.py +2 -4
  242. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +1 -1
  243. diffusers/schedulers/scheduling_deis_multistep.py +46 -19
  244. diffusers/schedulers/scheduling_dpmsolver_multistep.py +107 -21
  245. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +1 -1
  246. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +9 -7
  247. diffusers/schedulers/scheduling_dpmsolver_sde.py +35 -35
  248. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +52 -21
  249. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +683 -0
  250. diffusers/schedulers/scheduling_edm_euler.py +381 -0
  251. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +43 -15
  252. diffusers/schedulers/scheduling_euler_discrete.py +42 -17
  253. diffusers/schedulers/scheduling_euler_discrete_flax.py +1 -1
  254. diffusers/schedulers/scheduling_heun_discrete.py +35 -35
  255. diffusers/schedulers/scheduling_ipndm.py +37 -11
  256. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +44 -44
  257. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +44 -44
  258. diffusers/schedulers/scheduling_karras_ve_flax.py +1 -1
  259. diffusers/schedulers/scheduling_lcm.py +38 -14
  260. diffusers/schedulers/scheduling_lms_discrete.py +43 -15
  261. diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
  262. diffusers/schedulers/scheduling_pndm.py +2 -4
  263. diffusers/schedulers/scheduling_pndm_flax.py +2 -4
  264. diffusers/schedulers/scheduling_repaint.py +1 -1
  265. diffusers/schedulers/scheduling_sasolver.py +41 -9
  266. diffusers/schedulers/scheduling_sde_ve.py +1 -1
  267. diffusers/schedulers/scheduling_sde_ve_flax.py +1 -1
  268. diffusers/schedulers/scheduling_tcd.py +686 -0
  269. diffusers/schedulers/scheduling_unclip.py +1 -1
  270. diffusers/schedulers/scheduling_unipc_multistep.py +46 -19
  271. diffusers/schedulers/scheduling_utils.py +2 -1
  272. diffusers/schedulers/scheduling_utils_flax.py +1 -1
  273. diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
  274. diffusers/training_utils.py +9 -2
  275. diffusers/utils/__init__.py +2 -1
  276. diffusers/utils/accelerate_utils.py +1 -1
  277. diffusers/utils/constants.py +1 -1
  278. diffusers/utils/doc_utils.py +1 -1
  279. diffusers/utils/dummy_pt_objects.py +60 -0
  280. diffusers/utils/dummy_torch_and_transformers_objects.py +75 -0
  281. diffusers/utils/dynamic_modules_utils.py +1 -1
  282. diffusers/utils/export_utils.py +3 -3
  283. diffusers/utils/hub_utils.py +60 -16
  284. diffusers/utils/import_utils.py +15 -1
  285. diffusers/utils/loading_utils.py +2 -0
  286. diffusers/utils/logging.py +1 -1
  287. diffusers/utils/model_card_template.md +24 -0
  288. diffusers/utils/outputs.py +14 -7
  289. diffusers/utils/peft_utils.py +1 -1
  290. diffusers/utils/state_dict_utils.py +1 -1
  291. diffusers/utils/testing_utils.py +2 -0
  292. diffusers/utils/torch_utils.py +1 -1
  293. {diffusers-0.26.2.dist-info → diffusers-0.27.0.dist-info}/METADATA +5 -5
  294. diffusers-0.27.0.dist-info/RECORD +399 -0
  295. diffusers-0.26.2.dist-info/RECORD +0 -384
  296. {diffusers-0.26.2.dist-info → diffusers-0.27.0.dist-info}/LICENSE +0 -0
  297. {diffusers-0.26.2.dist-info → diffusers-0.27.0.dist-info}/WHEEL +0 -0
  298. {diffusers-0.26.2.dist-info → diffusers-0.27.0.dist-info}/entry_points.txt +0 -0
  299. {diffusers-0.26.2.dist-info → diffusers-0.27.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -75,6 +75,8 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
75
75
  The tuple of downsample blocks to use.
76
76
  up_block_types (`Tuple[str]`, *optional*, defaults to `("FlaxUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D", "FlaxCrossAttnUpBlock2D")`):
77
77
  The tuple of upsample blocks to use.
78
+ mid_block_type (`str`, *optional*, defaults to `"UNetMidBlock2DCrossAttn"`):
79
+ Block type for middle of UNet, it can be one of `UNetMidBlock2DCrossAttn`. If `None`, the mid block layer is skipped.
78
80
  block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
79
81
  The tuple of output channels for each block.
80
82
  layers_per_block (`int`, *optional*, defaults to 2):
@@ -107,6 +109,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
107
109
  "DownBlock2D",
108
110
  )
109
111
  up_block_types: Tuple[str, ...] = ("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")
112
+ mid_block_type: Optional[str] = "UNetMidBlock2DCrossAttn"
110
113
  only_cross_attention: Union[bool, Tuple[bool]] = False
111
114
  block_out_channels: Tuple[int, ...] = (320, 640, 1280, 1280)
112
115
  layers_per_block: int = 2
@@ -252,16 +255,21 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
252
255
  self.down_blocks = down_blocks
253
256
 
254
257
  # mid
255
- self.mid_block = FlaxUNetMidBlock2DCrossAttn(
256
- in_channels=block_out_channels[-1],
257
- dropout=self.dropout,
258
- num_attention_heads=num_attention_heads[-1],
259
- transformer_layers_per_block=transformer_layers_per_block[-1],
260
- use_linear_projection=self.use_linear_projection,
261
- use_memory_efficient_attention=self.use_memory_efficient_attention,
262
- split_head_dim=self.split_head_dim,
263
- dtype=self.dtype,
264
- )
258
+ if self.config.mid_block_type == "UNetMidBlock2DCrossAttn":
259
+ self.mid_block = FlaxUNetMidBlock2DCrossAttn(
260
+ in_channels=block_out_channels[-1],
261
+ dropout=self.dropout,
262
+ num_attention_heads=num_attention_heads[-1],
263
+ transformer_layers_per_block=transformer_layers_per_block[-1],
264
+ use_linear_projection=self.use_linear_projection,
265
+ use_memory_efficient_attention=self.use_memory_efficient_attention,
266
+ split_head_dim=self.split_head_dim,
267
+ dtype=self.dtype,
268
+ )
269
+ elif self.config.mid_block_type is None:
270
+ self.mid_block = None
271
+ else:
272
+ raise ValueError(f"Unexpected mid_block_type {self.config.mid_block_type}")
265
273
 
266
274
  # up
267
275
  up_blocks = []
@@ -412,7 +420,8 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
412
420
  down_block_res_samples = new_down_block_res_samples
413
421
 
414
422
  # 4. mid
415
- sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
423
+ if self.mid_block is not None:
424
+ sample = self.mid_block(sample, t_emb, encoder_hidden_states, deterministic=not train)
416
425
 
417
426
  if mid_block_additional_residual is not None:
418
427
  sample += mid_block_additional_residual
@@ -1,4 +1,4 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -17,7 +17,7 @@ from typing import Any, Dict, Optional, Tuple, Union
17
17
  import torch
18
18
  from torch import nn
19
19
 
20
- from ...utils import is_torch_version
20
+ from ...utils import deprecate, is_torch_version, logging
21
21
  from ...utils.torch_utils import apply_freeu
22
22
  from ..attention import Attention
23
23
  from ..resnet import (
@@ -35,6 +35,9 @@ from ..transformers.transformer_temporal import (
35
35
  )
36
36
 
37
37
 
38
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
39
+
40
+
38
41
  def get_down_block(
39
42
  down_block_type: str,
40
43
  num_layers: int,
@@ -1005,9 +1008,14 @@ class DownBlockMotion(nn.Module):
1005
1008
  self,
1006
1009
  hidden_states: torch.FloatTensor,
1007
1010
  temb: Optional[torch.FloatTensor] = None,
1008
- scale: float = 1.0,
1009
1011
  num_frames: int = 1,
1012
+ *args,
1013
+ **kwargs,
1010
1014
  ) -> Union[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
1015
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
1016
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
1017
+ deprecate("scale", "1.0.0", deprecation_message)
1018
+
1011
1019
  output_states = ()
1012
1020
 
1013
1021
  blocks = zip(self.resnets, self.motion_modules)
@@ -1029,24 +1037,18 @@ class DownBlockMotion(nn.Module):
1029
1037
  )
1030
1038
  else:
1031
1039
  hidden_states = torch.utils.checkpoint.checkpoint(
1032
- create_custom_forward(resnet), hidden_states, temb, scale
1040
+ create_custom_forward(resnet), hidden_states, temb
1033
1041
  )
1034
- hidden_states = torch.utils.checkpoint.checkpoint(
1035
- create_custom_forward(motion_module),
1036
- hidden_states.requires_grad_(),
1037
- temb,
1038
- num_frames,
1039
- )
1040
1042
 
1041
1043
  else:
1042
- hidden_states = resnet(hidden_states, temb, scale=scale)
1043
- hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
1044
+ hidden_states = resnet(hidden_states, temb)
1045
+ hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
1044
1046
 
1045
1047
  output_states = output_states + (hidden_states,)
1046
1048
 
1047
1049
  if self.downsamplers is not None:
1048
1050
  for downsampler in self.downsamplers:
1049
- hidden_states = downsampler(hidden_states, scale=scale)
1051
+ hidden_states = downsampler(hidden_states)
1050
1052
 
1051
1053
  output_states = output_states + (hidden_states,)
1052
1054
 
@@ -1179,9 +1181,11 @@ class CrossAttnDownBlockMotion(nn.Module):
1179
1181
  cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1180
1182
  additional_residuals: Optional[torch.FloatTensor] = None,
1181
1183
  ):
1182
- output_states = ()
1184
+ if cross_attention_kwargs is not None:
1185
+ if cross_attention_kwargs.get("scale", None) is not None:
1186
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
1183
1187
 
1184
- lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
1188
+ output_states = ()
1185
1189
 
1186
1190
  blocks = list(zip(self.resnets, self.attentions, self.motion_modules))
1187
1191
  for i, (resnet, attn, motion_module) in enumerate(blocks):
@@ -1212,7 +1216,7 @@ class CrossAttnDownBlockMotion(nn.Module):
1212
1216
  return_dict=False,
1213
1217
  )[0]
1214
1218
  else:
1215
- hidden_states = resnet(hidden_states, temb, scale=lora_scale)
1219
+ hidden_states = resnet(hidden_states, temb)
1216
1220
  hidden_states = attn(
1217
1221
  hidden_states,
1218
1222
  encoder_hidden_states=encoder_hidden_states,
@@ -1221,10 +1225,10 @@ class CrossAttnDownBlockMotion(nn.Module):
1221
1225
  encoder_attention_mask=encoder_attention_mask,
1222
1226
  return_dict=False,
1223
1227
  )[0]
1224
- hidden_states = motion_module(
1225
- hidden_states,
1226
- num_frames=num_frames,
1227
- )[0]
1228
+ hidden_states = motion_module(
1229
+ hidden_states,
1230
+ num_frames=num_frames,
1231
+ )[0]
1228
1232
 
1229
1233
  # apply additional residuals to the output of the last pair of resnet and attention blocks
1230
1234
  if i == len(blocks) - 1 and additional_residuals is not None:
@@ -1234,7 +1238,7 @@ class CrossAttnDownBlockMotion(nn.Module):
1234
1238
 
1235
1239
  if self.downsamplers is not None:
1236
1240
  for downsampler in self.downsamplers:
1237
- hidden_states = downsampler(hidden_states, scale=lora_scale)
1241
+ hidden_states = downsampler(hidden_states)
1238
1242
 
1239
1243
  output_states = output_states + (hidden_states,)
1240
1244
 
@@ -1361,7 +1365,10 @@ class CrossAttnUpBlockMotion(nn.Module):
1361
1365
  encoder_attention_mask: Optional[torch.FloatTensor] = None,
1362
1366
  num_frames: int = 1,
1363
1367
  ) -> torch.FloatTensor:
1364
- lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
1368
+ if cross_attention_kwargs is not None:
1369
+ if cross_attention_kwargs.get("scale", None) is not None:
1370
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
1371
+
1365
1372
  is_freeu_enabled = (
1366
1373
  getattr(self, "s1", None)
1367
1374
  and getattr(self, "s2", None)
@@ -1416,7 +1423,7 @@ class CrossAttnUpBlockMotion(nn.Module):
1416
1423
  return_dict=False,
1417
1424
  )[0]
1418
1425
  else:
1419
- hidden_states = resnet(hidden_states, temb, scale=lora_scale)
1426
+ hidden_states = resnet(hidden_states, temb)
1420
1427
  hidden_states = attn(
1421
1428
  hidden_states,
1422
1429
  encoder_hidden_states=encoder_hidden_states,
@@ -1425,14 +1432,14 @@ class CrossAttnUpBlockMotion(nn.Module):
1425
1432
  encoder_attention_mask=encoder_attention_mask,
1426
1433
  return_dict=False,
1427
1434
  )[0]
1428
- hidden_states = motion_module(
1429
- hidden_states,
1430
- num_frames=num_frames,
1431
- )[0]
1435
+ hidden_states = motion_module(
1436
+ hidden_states,
1437
+ num_frames=num_frames,
1438
+ )[0]
1432
1439
 
1433
1440
  if self.upsamplers is not None:
1434
1441
  for upsampler in self.upsamplers:
1435
- hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale)
1442
+ hidden_states = upsampler(hidden_states, upsample_size)
1436
1443
 
1437
1444
  return hidden_states
1438
1445
 
@@ -1513,9 +1520,14 @@ class UpBlockMotion(nn.Module):
1513
1520
  res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
1514
1521
  temb: Optional[torch.FloatTensor] = None,
1515
1522
  upsample_size=None,
1516
- scale: float = 1.0,
1517
1523
  num_frames: int = 1,
1524
+ *args,
1525
+ **kwargs,
1518
1526
  ) -> torch.FloatTensor:
1527
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
1528
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
1529
+ deprecate("scale", "1.0.0", deprecation_message)
1530
+
1519
1531
  is_freeu_enabled = (
1520
1532
  getattr(self, "s1", None)
1521
1533
  and getattr(self, "s2", None)
@@ -1563,19 +1575,14 @@ class UpBlockMotion(nn.Module):
1563
1575
  hidden_states = torch.utils.checkpoint.checkpoint(
1564
1576
  create_custom_forward(resnet), hidden_states, temb
1565
1577
  )
1566
- hidden_states = torch.utils.checkpoint.checkpoint(
1567
- create_custom_forward(resnet),
1568
- hidden_states,
1569
- temb,
1570
- )
1571
1578
 
1572
1579
  else:
1573
- hidden_states = resnet(hidden_states, temb, scale=scale)
1574
- hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
1580
+ hidden_states = resnet(hidden_states, temb)
1581
+ hidden_states = motion_module(hidden_states, num_frames=num_frames)[0]
1575
1582
 
1576
1583
  if self.upsamplers is not None:
1577
1584
  for upsampler in self.upsamplers:
1578
- hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
1585
+ hidden_states = upsampler(hidden_states, upsample_size)
1579
1586
 
1580
1587
  return hidden_states
1581
1588
 
@@ -1698,8 +1705,11 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
1698
1705
  encoder_attention_mask: Optional[torch.FloatTensor] = None,
1699
1706
  num_frames: int = 1,
1700
1707
  ) -> torch.FloatTensor:
1701
- lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
1702
- hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
1708
+ if cross_attention_kwargs is not None:
1709
+ if cross_attention_kwargs.get("scale", None) is not None:
1710
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
1711
+
1712
+ hidden_states = self.resnets[0](hidden_states, temb)
1703
1713
 
1704
1714
  blocks = zip(self.attentions, self.resnets[1:], self.motion_modules)
1705
1715
  for attn, resnet, motion_module in blocks:
@@ -1748,7 +1758,7 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
1748
1758
  hidden_states,
1749
1759
  num_frames=num_frames,
1750
1760
  )[0]
1751
- hidden_states = resnet(hidden_states, temb, scale=lora_scale)
1761
+ hidden_states = resnet(hidden_states, temb)
1752
1762
 
1753
1763
  return hidden_states
1754
1764
 
@@ -1,5 +1,5 @@
1
- # Copyright 2023 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved.
2
- # Copyright 2023 The ModelScope Team.
1
+ # Copyright 2024 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved.
2
+ # Copyright 2024 The ModelScope Team.
3
3
  #
4
4
  # Licensed under the Apache License, Version 2.0 (the "License");
5
5
  # you may not use this file except in compliance with the License.
@@ -27,6 +27,7 @@ from ..activations import get_activation
27
27
  from ..attention_processor import (
28
28
  ADDED_KV_ATTENTION_PROCESSORS,
29
29
  CROSS_ATTENTION_PROCESSORS,
30
+ Attention,
30
31
  AttentionProcessor,
31
32
  AttnAddedKVProcessor,
32
33
  AttnProcessor,
@@ -54,7 +55,7 @@ class UNet3DConditionOutput(BaseOutput):
54
55
  The output of [`UNet3DConditionModel`].
55
56
 
56
57
  Args:
57
- sample (`torch.FloatTensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
58
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, num_frames, height, width)`):
58
59
  The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
59
60
  """
60
61
 
@@ -74,9 +75,9 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
74
75
  Height and width of input/output sample.
75
76
  in_channels (`int`, *optional*, defaults to 4): The number of channels in the input sample.
76
77
  out_channels (`int`, *optional*, defaults to 4): The number of channels in the output.
77
- down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "CrossAttnDownBlock2D", "DownBlock2D")`):
78
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "CrossAttnDownBlock3D", "DownBlock3D")`):
78
79
  The tuple of downsample blocks to use.
79
- up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D", "CrossAttnUpBlock2D")`):
80
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D", "CrossAttnUpBlock3D")`):
80
81
  The tuple of upsample blocks to use.
81
82
  block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
82
83
  The tuple of output channels for each block.
@@ -87,8 +88,8 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
87
88
  norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
88
89
  If `None`, normalization and activation layers is skipped in post-processing.
89
90
  norm_eps (`float`, *optional*, defaults to 1e-5): The epsilon to use for the normalization.
90
- cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
91
- attention_head_dim (`int`, *optional*, defaults to 8): The dimension of the attention heads.
91
+ cross_attention_dim (`int`, *optional*, defaults to 1024): The dimension of the cross attention features.
92
+ attention_head_dim (`int`, *optional*, defaults to 64): The dimension of the attention heads.
92
93
  num_attention_heads (`int`, *optional*): The number of attention heads.
93
94
  """
94
95
 
@@ -503,6 +504,44 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
503
504
  if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
504
505
  setattr(upsample_block, k, None)
505
506
 
507
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
508
+ def fuse_qkv_projections(self):
509
+ """
510
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
511
+ key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
512
+
513
+ <Tip warning={true}>
514
+
515
+ This API is 🧪 experimental.
516
+
517
+ </Tip>
518
+ """
519
+ self.original_attn_processors = None
520
+
521
+ for _, attn_processor in self.attn_processors.items():
522
+ if "Added" in str(attn_processor.__class__.__name__):
523
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
524
+
525
+ self.original_attn_processors = self.attn_processors
526
+
527
+ for module in self.modules():
528
+ if isinstance(module, Attention):
529
+ module.fuse_projections(fuse=True)
530
+
531
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
532
+ def unfuse_qkv_projections(self):
533
+ """Disables the fused QKV projection if enabled.
534
+
535
+ <Tip warning={true}>
536
+
537
+ This API is 🧪 experimental.
538
+
539
+ </Tip>
540
+
541
+ """
542
+ if self.original_attn_processors is not None:
543
+ self.set_attn_processor(self.original_attn_processors)
544
+
506
545
  # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unload_lora
507
546
  def unload_lora(self):
508
547
  """Unloads LoRA weights."""
@@ -533,7 +572,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
533
572
 
534
573
  Args:
535
574
  sample (`torch.FloatTensor`):
536
- The noisy input tensor with the following shape `(batch, num_frames, channel, height, width`.
575
+ The noisy input tensor with the following shape `(batch, num_channels, num_frames, height, width`.
537
576
  timestep (`torch.FloatTensor` or `float` or `int`): The number of timesteps to denoise an input.
538
577
  encoder_hidden_states (`torch.FloatTensor`):
539
578
  The encoder hidden states with shape `(batch, sequence_length, feature_dim)`.
@@ -1,4 +1,4 @@
1
- # Copyright 2023 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved.
1
+ # Copyright 2024 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -48,29 +48,6 @@ from .unet_3d_condition import UNet3DConditionOutput
48
48
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
49
49
 
50
50
 
51
- def _to_tensor(inputs, device):
52
- if not torch.is_tensor(inputs):
53
- # TODO: this requires sync between CPU and GPU. So try to pass `inputs` as tensors if you can
54
- # This would be a good case for the `match` statement (Python 3.10+)
55
- is_mps = device.type == "mps"
56
- if isinstance(inputs, float):
57
- dtype = torch.float32 if is_mps else torch.float64
58
- else:
59
- dtype = torch.int32 if is_mps else torch.int64
60
- inputs = torch.tensor([inputs], dtype=dtype, device=device)
61
- elif len(inputs.shape) == 0:
62
- inputs = inputs[None].to(device)
63
-
64
- return inputs
65
-
66
-
67
- def _collapse_frames_into_batch(sample: torch.Tensor) -> torch.Tensor:
68
- batch_size, channels, num_frames, height, width = sample.shape
69
- sample = sample.permute(0, 2, 1, 3, 4).reshape(batch_size * num_frames, channels, height, width)
70
-
71
- return sample
72
-
73
-
74
51
  class I2VGenXLTransformerTemporalEncoder(nn.Module):
75
52
  def __init__(
76
53
  self,
@@ -112,7 +89,7 @@ class I2VGenXLTransformerTemporalEncoder(nn.Module):
112
89
  if hidden_states.ndim == 4:
113
90
  hidden_states = hidden_states.squeeze(1)
114
91
 
115
- ff_output = self.ff(hidden_states, scale=1.0)
92
+ ff_output = self.ff(hidden_states)
116
93
  hidden_states = ff_output + hidden_states
117
94
  if hidden_states.ndim == 4:
118
95
  hidden_states = hidden_states.squeeze(1)
@@ -143,6 +120,7 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
143
120
  norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for the normalization.
144
121
  If `None`, normalization and activation layers is skipped in post-processing.
145
122
  cross_attention_dim (`int`, *optional*, defaults to 1280): The dimension of the cross attention features.
123
+ attention_head_dim (`int`, *optional*, defaults to 64): Attention head dim.
146
124
  num_attention_heads (`int`, *optional*): The number of attention heads.
147
125
  """
148
126
 
@@ -170,11 +148,18 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
170
148
  layers_per_block: int = 2,
171
149
  norm_num_groups: Optional[int] = 32,
172
150
  cross_attention_dim: int = 1024,
173
- num_attention_heads: Optional[Union[int, Tuple[int]]] = 64,
151
+ attention_head_dim: Union[int, Tuple[int]] = 64,
152
+ num_attention_heads: Optional[Union[int, Tuple[int]]] = None,
174
153
  ):
175
154
  super().__init__()
176
155
 
177
- self.sample_size = sample_size
156
+ # When we first integrated the UNet into the library, we didn't have `attention_head_dim`. As a consequence
157
+ # of that, we used `num_attention_heads` for arguments that actually denote attention head dimension. This
158
+ # is why we ignore `num_attention_heads` and calculate it from `attention_head_dims` below.
159
+ # This is still an incorrect way of calculating `num_attention_heads` but we need to stick to it
160
+ # without running proper depcrecation cycles for the {down,mid,up} blocks which are a
161
+ # part of the public API.
162
+ num_attention_heads = attention_head_dim
178
163
 
179
164
  # Check inputs
180
165
  if len(down_block_types) != len(up_block_types):
@@ -489,6 +474,44 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
489
474
  if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
490
475
  setattr(upsample_block, k, None)
491
476
 
477
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
478
+ def fuse_qkv_projections(self):
479
+ """
480
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
481
+ key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
482
+
483
+ <Tip warning={true}>
484
+
485
+ This API is 🧪 experimental.
486
+
487
+ </Tip>
488
+ """
489
+ self.original_attn_processors = None
490
+
491
+ for _, attn_processor in self.attn_processors.items():
492
+ if "Added" in str(attn_processor.__class__.__name__):
493
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
494
+
495
+ self.original_attn_processors = self.attn_processors
496
+
497
+ for module in self.modules():
498
+ if isinstance(module, Attention):
499
+ module.fuse_projections(fuse=True)
500
+
501
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
502
+ def unfuse_qkv_projections(self):
503
+ """Disables the fused QKV projection if enabled.
504
+
505
+ <Tip warning={true}>
506
+
507
+ This API is 🧪 experimental.
508
+
509
+ </Tip>
510
+
511
+ """
512
+ if self.original_attn_processors is not None:
513
+ self.set_attn_processor(self.original_attn_processors)
514
+
492
515
  def forward(
493
516
  self,
494
517
  sample: torch.FloatTensor,
@@ -543,7 +566,18 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
543
566
  forward_upsample_size = True
544
567
 
545
568
  # 1. time
546
- timesteps = _to_tensor(timestep, sample.device)
569
+ timesteps = timestep
570
+ if not torch.is_tensor(timesteps):
571
+ # TODO: this requires sync between CPU and GPU. So try to pass `timesteps` as tensors if you can
572
+ # This would be a good case for the `match` statement (Python 3.10+)
573
+ is_mps = sample.device.type == "mps"
574
+ if isinstance(timesteps, float):
575
+ dtype = torch.float32 if is_mps else torch.float64
576
+ else:
577
+ dtype = torch.int32 if is_mps else torch.int64
578
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
579
+ elif len(timesteps.shape) == 0:
580
+ timesteps = timesteps[None].to(sample.device)
547
581
 
548
582
  # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
549
583
  timesteps = timesteps.expand(sample.shape[0])
@@ -572,7 +606,13 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
572
606
  context_emb = sample.new_zeros(batch_size, 0, self.config.cross_attention_dim)
573
607
  context_emb = torch.cat([context_emb, encoder_hidden_states], dim=1)
574
608
 
575
- image_latents_context_embs = _collapse_frames_into_batch(image_latents[:, :, :1, :])
609
+ image_latents_for_context_embds = image_latents[:, :, :1, :]
610
+ image_latents_context_embs = image_latents_for_context_embds.permute(0, 2, 1, 3, 4).reshape(
611
+ image_latents_for_context_embds.shape[0] * image_latents_for_context_embds.shape[2],
612
+ image_latents_for_context_embds.shape[1],
613
+ image_latents_for_context_embds.shape[3],
614
+ image_latents_for_context_embds.shape[4],
615
+ )
576
616
  image_latents_context_embs = self.image_latents_context_embedding(image_latents_context_embs)
577
617
 
578
618
  _batch_size, _channels, _height, _width = image_latents_context_embs.shape
@@ -586,7 +626,12 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
586
626
  context_emb = torch.cat([context_emb, image_emb], dim=1)
587
627
  context_emb = context_emb.repeat_interleave(repeats=num_frames, dim=0)
588
628
 
589
- image_latents = _collapse_frames_into_batch(image_latents)
629
+ image_latents = image_latents.permute(0, 2, 1, 3, 4).reshape(
630
+ image_latents.shape[0] * image_latents.shape[2],
631
+ image_latents.shape[1],
632
+ image_latents.shape[3],
633
+ image_latents.shape[4],
634
+ )
590
635
  image_latents = self.image_latents_proj_in(image_latents)
591
636
  image_latents = (
592
637
  image_latents[None, :]
@@ -1,4 +1,4 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
1
- # Copyright 2023 The HuggingFace Team. All rights reserved.
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -23,6 +23,7 @@ from ...utils import logging
23
23
  from ..attention_processor import (
24
24
  ADDED_KV_ATTENTION_PROCESSORS,
25
25
  CROSS_ATTENTION_PROCESSORS,
26
+ Attention,
26
27
  AttentionProcessor,
27
28
  AttnAddedKVProcessor,
28
29
  AttnProcessor,
@@ -217,6 +218,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
217
218
  use_motion_mid_block: int = True,
218
219
  encoder_hid_dim: Optional[int] = None,
219
220
  encoder_hid_dim_type: Optional[str] = None,
221
+ time_cond_proj_dim: Optional[int] = None,
220
222
  ):
221
223
  super().__init__()
222
224
 
@@ -252,9 +254,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
252
254
  timestep_input_dim = block_out_channels[0]
253
255
 
254
256
  self.time_embedding = TimestepEmbedding(
255
- timestep_input_dim,
256
- time_embed_dim,
257
- act_fn=act_fn,
257
+ timestep_input_dim, time_embed_dim, act_fn=act_fn, cond_proj_dim=time_cond_proj_dim
258
258
  )
259
259
 
260
260
  if encoder_hid_dim_type is None:
@@ -306,6 +306,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
306
306
  num_attention_heads=num_attention_heads[-1],
307
307
  resnet_groups=norm_num_groups,
308
308
  dual_cross_attention=False,
309
+ use_linear_projection=use_linear_projection,
309
310
  temporal_num_attention_heads=motion_num_attention_heads,
310
311
  temporal_max_seq_length=motion_max_seq_length,
311
312
  )
@@ -321,6 +322,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
321
322
  num_attention_heads=num_attention_heads[-1],
322
323
  resnet_groups=norm_num_groups,
323
324
  dual_cross_attention=False,
325
+ use_linear_projection=use_linear_projection,
324
326
  )
325
327
 
326
328
  # count how many layers upsample the images
@@ -700,6 +702,44 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
700
702
  if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
701
703
  setattr(upsample_block, k, None)
702
704
 
705
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
706
+ def fuse_qkv_projections(self):
707
+ """
708
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
709
+ key, value) are fused. For cross-attention modules, key and value projection matrices are fused.
710
+
711
+ <Tip warning={true}>
712
+
713
+ This API is 🧪 experimental.
714
+
715
+ </Tip>
716
+ """
717
+ self.original_attn_processors = None
718
+
719
+ for _, attn_processor in self.attn_processors.items():
720
+ if "Added" in str(attn_processor.__class__.__name__):
721
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
722
+
723
+ self.original_attn_processors = self.attn_processors
724
+
725
+ for module in self.modules():
726
+ if isinstance(module, Attention):
727
+ module.fuse_projections(fuse=True)
728
+
729
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
730
+ def unfuse_qkv_projections(self):
731
+ """Disables the fused QKV projection if enabled.
732
+
733
+ <Tip warning={true}>
734
+
735
+ This API is 🧪 experimental.
736
+
737
+ </Tip>
738
+
739
+ """
740
+ if self.original_attn_processors is not None:
741
+ self.set_attn_processor(self.original_attn_processors)
742
+
703
743
  def forward(
704
744
  self,
705
745
  sample: torch.FloatTensor,
@@ -792,6 +832,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
792
832
 
793
833
  emb = self.time_embedding(t_emb, timestep_cond)
794
834
  emb = emb.repeat_interleave(repeats=num_frames, dim=0)
835
+ encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
795
836
 
796
837
  if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
797
838
  if "image_embeds" not in added_cond_kwargs:
@@ -799,10 +840,9 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
799
840
  f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_conditions`"
800
841
  )
801
842
  image_embeds = added_cond_kwargs.get("image_embeds")
802
- image_embeds = self.encoder_hid_proj(image_embeds).to(encoder_hidden_states.dtype)
803
- encoder_hidden_states = torch.cat([encoder_hidden_states, image_embeds], dim=1)
804
-
805
- encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)
843
+ image_embeds = self.encoder_hid_proj(image_embeds)
844
+ image_embeds = [image_embed.repeat_interleave(repeats=num_frames, dim=0) for image_embed in image_embeds]
845
+ encoder_hidden_states = (encoder_hidden_states, image_embeds)
806
846
 
807
847
  # 2. pre-process
808
848
  sample = sample.permute(0, 2, 1, 3, 4).reshape((sample.shape[0] * num_frames, -1) + sample.shape[3:])