diffusers 0.27.1__py3-none-any.whl → 0.32.2__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (445) hide show
  1. diffusers/__init__.py +233 -6
  2. diffusers/callbacks.py +209 -0
  3. diffusers/commands/env.py +102 -6
  4. diffusers/configuration_utils.py +45 -16
  5. diffusers/dependency_versions_table.py +4 -3
  6. diffusers/image_processor.py +434 -110
  7. diffusers/loaders/__init__.py +42 -9
  8. diffusers/loaders/ip_adapter.py +626 -36
  9. diffusers/loaders/lora_base.py +900 -0
  10. diffusers/loaders/lora_conversion_utils.py +991 -125
  11. diffusers/loaders/lora_pipeline.py +3812 -0
  12. diffusers/loaders/peft.py +571 -7
  13. diffusers/loaders/single_file.py +405 -173
  14. diffusers/loaders/single_file_model.py +385 -0
  15. diffusers/loaders/single_file_utils.py +1783 -713
  16. diffusers/loaders/textual_inversion.py +41 -23
  17. diffusers/loaders/transformer_flux.py +181 -0
  18. diffusers/loaders/transformer_sd3.py +89 -0
  19. diffusers/loaders/unet.py +464 -540
  20. diffusers/loaders/unet_loader_utils.py +163 -0
  21. diffusers/models/__init__.py +76 -7
  22. diffusers/models/activations.py +65 -10
  23. diffusers/models/adapter.py +53 -53
  24. diffusers/models/attention.py +605 -18
  25. diffusers/models/attention_flax.py +1 -1
  26. diffusers/models/attention_processor.py +4304 -687
  27. diffusers/models/autoencoders/__init__.py +8 -0
  28. diffusers/models/autoencoders/autoencoder_asym_kl.py +15 -17
  29. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  30. diffusers/models/autoencoders/autoencoder_kl.py +110 -28
  31. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  32. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1482 -0
  33. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  34. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  35. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  36. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +19 -24
  37. diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
  38. diffusers/models/autoencoders/autoencoder_tiny.py +21 -18
  39. diffusers/models/autoencoders/consistency_decoder_vae.py +45 -20
  40. diffusers/models/autoencoders/vae.py +41 -29
  41. diffusers/models/autoencoders/vq_model.py +182 -0
  42. diffusers/models/controlnet.py +47 -800
  43. diffusers/models/controlnet_flux.py +70 -0
  44. diffusers/models/controlnet_sd3.py +68 -0
  45. diffusers/models/controlnet_sparsectrl.py +116 -0
  46. diffusers/models/controlnets/__init__.py +23 -0
  47. diffusers/models/controlnets/controlnet.py +872 -0
  48. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +9 -9
  49. diffusers/models/controlnets/controlnet_flux.py +536 -0
  50. diffusers/models/controlnets/controlnet_hunyuan.py +401 -0
  51. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  52. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  53. diffusers/models/controlnets/controlnet_union.py +832 -0
  54. diffusers/models/controlnets/controlnet_xs.py +1946 -0
  55. diffusers/models/controlnets/multicontrolnet.py +183 -0
  56. diffusers/models/downsampling.py +85 -18
  57. diffusers/models/embeddings.py +1856 -158
  58. diffusers/models/embeddings_flax.py +23 -9
  59. diffusers/models/model_loading_utils.py +480 -0
  60. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  61. diffusers/models/modeling_flax_utils.py +2 -7
  62. diffusers/models/modeling_outputs.py +14 -0
  63. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  64. diffusers/models/modeling_utils.py +611 -146
  65. diffusers/models/normalization.py +361 -20
  66. diffusers/models/resnet.py +18 -23
  67. diffusers/models/transformers/__init__.py +16 -0
  68. diffusers/models/transformers/auraflow_transformer_2d.py +544 -0
  69. diffusers/models/transformers/cogvideox_transformer_3d.py +542 -0
  70. diffusers/models/transformers/dit_transformer_2d.py +240 -0
  71. diffusers/models/transformers/dual_transformer_2d.py +9 -8
  72. diffusers/models/transformers/hunyuan_transformer_2d.py +578 -0
  73. diffusers/models/transformers/latte_transformer_3d.py +327 -0
  74. diffusers/models/transformers/lumina_nextdit2d.py +340 -0
  75. diffusers/models/transformers/pixart_transformer_2d.py +445 -0
  76. diffusers/models/transformers/prior_transformer.py +13 -13
  77. diffusers/models/transformers/sana_transformer.py +488 -0
  78. diffusers/models/transformers/stable_audio_transformer.py +458 -0
  79. diffusers/models/transformers/t5_film_transformer.py +17 -19
  80. diffusers/models/transformers/transformer_2d.py +297 -187
  81. diffusers/models/transformers/transformer_allegro.py +422 -0
  82. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  83. diffusers/models/transformers/transformer_flux.py +593 -0
  84. diffusers/models/transformers/transformer_hunyuan_video.py +791 -0
  85. diffusers/models/transformers/transformer_ltx.py +469 -0
  86. diffusers/models/transformers/transformer_mochi.py +499 -0
  87. diffusers/models/transformers/transformer_sd3.py +461 -0
  88. diffusers/models/transformers/transformer_temporal.py +21 -19
  89. diffusers/models/unets/unet_1d.py +8 -8
  90. diffusers/models/unets/unet_1d_blocks.py +31 -31
  91. diffusers/models/unets/unet_2d.py +17 -10
  92. diffusers/models/unets/unet_2d_blocks.py +225 -149
  93. diffusers/models/unets/unet_2d_condition.py +41 -40
  94. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  95. diffusers/models/unets/unet_3d_blocks.py +192 -1057
  96. diffusers/models/unets/unet_3d_condition.py +22 -27
  97. diffusers/models/unets/unet_i2vgen_xl.py +22 -18
  98. diffusers/models/unets/unet_kandinsky3.py +2 -2
  99. diffusers/models/unets/unet_motion_model.py +1413 -89
  100. diffusers/models/unets/unet_spatio_temporal_condition.py +40 -16
  101. diffusers/models/unets/unet_stable_cascade.py +19 -18
  102. diffusers/models/unets/uvit_2d.py +2 -2
  103. diffusers/models/upsampling.py +95 -26
  104. diffusers/models/vq_model.py +12 -164
  105. diffusers/optimization.py +1 -1
  106. diffusers/pipelines/__init__.py +202 -3
  107. diffusers/pipelines/allegro/__init__.py +48 -0
  108. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  109. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  110. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  111. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  112. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  113. diffusers/pipelines/animatediff/__init__.py +8 -0
  114. diffusers/pipelines/animatediff/pipeline_animatediff.py +122 -109
  115. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1106 -0
  116. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1288 -0
  117. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1010 -0
  118. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +236 -180
  119. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  120. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  121. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  122. diffusers/pipelines/audioldm2/modeling_audioldm2.py +58 -39
  123. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +121 -36
  124. diffusers/pipelines/aura_flow/__init__.py +48 -0
  125. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +584 -0
  126. diffusers/pipelines/auto_pipeline.py +196 -28
  127. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  128. diffusers/pipelines/blip_diffusion/modeling_blip2.py +6 -6
  129. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  130. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  131. diffusers/pipelines/cogvideo/__init__.py +54 -0
  132. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +772 -0
  133. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
  134. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +885 -0
  135. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +851 -0
  136. diffusers/pipelines/cogvideo/pipeline_output.py +20 -0
  137. diffusers/pipelines/cogview3/__init__.py +47 -0
  138. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  139. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  140. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +6 -6
  141. diffusers/pipelines/controlnet/__init__.py +86 -80
  142. diffusers/pipelines/controlnet/multicontrolnet.py +7 -182
  143. diffusers/pipelines/controlnet/pipeline_controlnet.py +134 -87
  144. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  145. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +93 -77
  146. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +88 -197
  147. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +136 -90
  148. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +176 -80
  149. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +125 -89
  150. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  151. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  152. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  153. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
  154. diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
  155. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1060 -0
  156. diffusers/pipelines/controlnet_sd3/__init__.py +57 -0
  157. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +1133 -0
  158. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  159. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  160. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +916 -0
  161. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1111 -0
  162. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  163. diffusers/pipelines/deepfloyd_if/pipeline_if.py +16 -30
  164. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +20 -35
  165. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +23 -41
  166. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +22 -38
  167. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +25 -41
  168. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +19 -34
  169. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  170. diffusers/pipelines/deepfloyd_if/watermark.py +1 -1
  171. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  172. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +70 -30
  173. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +48 -25
  174. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  175. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  176. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +21 -20
  177. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +27 -29
  178. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +33 -27
  179. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +33 -23
  180. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +36 -30
  181. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +102 -69
  182. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  183. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  184. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  185. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  186. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  187. diffusers/pipelines/dit/pipeline_dit.py +7 -4
  188. diffusers/pipelines/flux/__init__.py +69 -0
  189. diffusers/pipelines/flux/modeling_flux.py +47 -0
  190. diffusers/pipelines/flux/pipeline_flux.py +957 -0
  191. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  192. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  193. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  194. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
  195. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
  196. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
  197. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  198. diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
  199. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
  200. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  201. diffusers/pipelines/flux/pipeline_output.py +37 -0
  202. diffusers/pipelines/free_init_utils.py +41 -38
  203. diffusers/pipelines/free_noise_utils.py +596 -0
  204. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  205. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  206. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  207. diffusers/pipelines/hunyuandit/__init__.py +48 -0
  208. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +916 -0
  209. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  210. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  211. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +32 -29
  212. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  213. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  214. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  215. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  216. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +34 -31
  217. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  218. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  219. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  220. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  221. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  222. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  223. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  224. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +22 -35
  225. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +26 -37
  226. diffusers/pipelines/kolors/__init__.py +54 -0
  227. diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
  228. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1250 -0
  229. diffusers/pipelines/kolors/pipeline_output.py +21 -0
  230. diffusers/pipelines/kolors/text_encoder.py +889 -0
  231. diffusers/pipelines/kolors/tokenizer.py +338 -0
  232. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +82 -62
  233. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +77 -60
  234. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +12 -12
  235. diffusers/pipelines/latte/__init__.py +48 -0
  236. diffusers/pipelines/latte/pipeline_latte.py +881 -0
  237. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +80 -74
  238. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +85 -76
  239. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  240. diffusers/pipelines/ltx/__init__.py +50 -0
  241. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  242. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  243. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  244. diffusers/pipelines/lumina/__init__.py +48 -0
  245. diffusers/pipelines/lumina/pipeline_lumina.py +890 -0
  246. diffusers/pipelines/marigold/__init__.py +50 -0
  247. diffusers/pipelines/marigold/marigold_image_processing.py +576 -0
  248. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  249. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  250. diffusers/pipelines/mochi/__init__.py +48 -0
  251. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  252. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  253. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  254. diffusers/pipelines/pag/__init__.py +80 -0
  255. diffusers/pipelines/pag/pag_utils.py +243 -0
  256. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1328 -0
  257. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
  258. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1610 -0
  259. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
  260. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +969 -0
  261. diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
  262. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +865 -0
  263. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  264. diffusers/pipelines/pag/pipeline_pag_sd.py +1062 -0
  265. diffusers/pipelines/pag/pipeline_pag_sd_3.py +994 -0
  266. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  267. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +866 -0
  268. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
  269. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  270. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1345 -0
  271. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1544 -0
  272. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1776 -0
  273. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  274. diffusers/pipelines/pia/pipeline_pia.py +74 -164
  275. diffusers/pipelines/pipeline_flax_utils.py +5 -10
  276. diffusers/pipelines/pipeline_loading_utils.py +515 -53
  277. diffusers/pipelines/pipeline_utils.py +411 -222
  278. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  279. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +76 -93
  280. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +873 -0
  281. diffusers/pipelines/sana/__init__.py +47 -0
  282. diffusers/pipelines/sana/pipeline_output.py +21 -0
  283. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  284. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +27 -23
  285. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  286. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  287. diffusers/pipelines/shap_e/renderer.py +1 -1
  288. diffusers/pipelines/stable_audio/__init__.py +50 -0
  289. diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
  290. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +756 -0
  291. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +71 -25
  292. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  293. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +35 -34
  294. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  295. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +20 -11
  296. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  297. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  298. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  299. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +145 -79
  300. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +43 -28
  301. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  302. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +100 -68
  303. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +109 -201
  304. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +131 -32
  305. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +247 -87
  306. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +30 -29
  307. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +35 -27
  308. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +49 -42
  309. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  310. diffusers/pipelines/stable_diffusion_3/__init__.py +54 -0
  311. diffusers/pipelines/stable_diffusion_3/pipeline_output.py +21 -0
  312. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +1140 -0
  313. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +1036 -0
  314. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1250 -0
  315. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +29 -20
  316. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +59 -58
  317. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +31 -25
  318. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +38 -22
  319. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -24
  320. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -23
  321. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +107 -67
  322. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +316 -69
  323. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  324. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  325. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +98 -30
  326. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +121 -83
  327. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +161 -105
  328. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +142 -218
  329. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -29
  330. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  331. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  332. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +69 -39
  333. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +105 -74
  334. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  335. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +29 -49
  336. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +32 -93
  337. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +37 -25
  338. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +54 -40
  339. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  340. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  341. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  342. diffusers/pipelines/unidiffuser/modeling_uvit.py +12 -12
  343. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +29 -28
  344. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  345. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  346. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +6 -8
  347. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  348. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  349. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +15 -14
  350. diffusers/{models/dual_transformer_2d.py → quantizers/__init__.py} +2 -6
  351. diffusers/quantizers/auto.py +139 -0
  352. diffusers/quantizers/base.py +233 -0
  353. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  354. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
  355. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  356. diffusers/quantizers/gguf/__init__.py +1 -0
  357. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  358. diffusers/quantizers/gguf/utils.py +456 -0
  359. diffusers/quantizers/quantization_config.py +669 -0
  360. diffusers/quantizers/torchao/__init__.py +15 -0
  361. diffusers/quantizers/torchao/torchao_quantizer.py +292 -0
  362. diffusers/schedulers/__init__.py +12 -2
  363. diffusers/schedulers/deprecated/__init__.py +1 -1
  364. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  365. diffusers/schedulers/scheduling_amused.py +5 -5
  366. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  367. diffusers/schedulers/scheduling_consistency_models.py +23 -25
  368. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
  369. diffusers/schedulers/scheduling_ddim.py +27 -26
  370. diffusers/schedulers/scheduling_ddim_cogvideox.py +452 -0
  371. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  372. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  373. diffusers/schedulers/scheduling_ddim_parallel.py +32 -31
  374. diffusers/schedulers/scheduling_ddpm.py +27 -30
  375. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  376. diffusers/schedulers/scheduling_ddpm_parallel.py +33 -36
  377. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  378. diffusers/schedulers/scheduling_deis_multistep.py +150 -50
  379. diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
  380. diffusers/schedulers/scheduling_dpmsolver_multistep.py +221 -84
  381. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  382. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +158 -52
  383. diffusers/schedulers/scheduling_dpmsolver_sde.py +153 -34
  384. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +275 -86
  385. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +81 -57
  386. diffusers/schedulers/scheduling_edm_euler.py +62 -39
  387. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +30 -29
  388. diffusers/schedulers/scheduling_euler_discrete.py +255 -74
  389. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +458 -0
  390. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +320 -0
  391. diffusers/schedulers/scheduling_heun_discrete.py +174 -46
  392. diffusers/schedulers/scheduling_ipndm.py +9 -9
  393. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +138 -29
  394. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +132 -26
  395. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  396. diffusers/schedulers/scheduling_lcm.py +23 -29
  397. diffusers/schedulers/scheduling_lms_discrete.py +105 -28
  398. diffusers/schedulers/scheduling_pndm.py +20 -20
  399. diffusers/schedulers/scheduling_repaint.py +21 -21
  400. diffusers/schedulers/scheduling_sasolver.py +157 -60
  401. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  402. diffusers/schedulers/scheduling_tcd.py +41 -36
  403. diffusers/schedulers/scheduling_unclip.py +19 -16
  404. diffusers/schedulers/scheduling_unipc_multistep.py +243 -47
  405. diffusers/schedulers/scheduling_utils.py +12 -5
  406. diffusers/schedulers/scheduling_utils_flax.py +1 -3
  407. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  408. diffusers/training_utils.py +214 -30
  409. diffusers/utils/__init__.py +17 -1
  410. diffusers/utils/constants.py +3 -0
  411. diffusers/utils/doc_utils.py +1 -0
  412. diffusers/utils/dummy_pt_objects.py +592 -7
  413. diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
  414. diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
  415. diffusers/utils/dummy_torch_and_transformers_objects.py +1001 -71
  416. diffusers/utils/dynamic_modules_utils.py +34 -29
  417. diffusers/utils/export_utils.py +50 -6
  418. diffusers/utils/hub_utils.py +131 -17
  419. diffusers/utils/import_utils.py +210 -8
  420. diffusers/utils/loading_utils.py +118 -5
  421. diffusers/utils/logging.py +4 -2
  422. diffusers/utils/peft_utils.py +37 -7
  423. diffusers/utils/state_dict_utils.py +13 -2
  424. diffusers/utils/testing_utils.py +193 -11
  425. diffusers/utils/torch_utils.py +4 -0
  426. diffusers/video_processor.py +113 -0
  427. {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/METADATA +82 -91
  428. diffusers-0.32.2.dist-info/RECORD +550 -0
  429. {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/WHEEL +1 -1
  430. diffusers/loaders/autoencoder.py +0 -146
  431. diffusers/loaders/controlnet.py +0 -136
  432. diffusers/loaders/lora.py +0 -1349
  433. diffusers/models/prior_transformer.py +0 -12
  434. diffusers/models/t5_film_transformer.py +0 -70
  435. diffusers/models/transformer_2d.py +0 -25
  436. diffusers/models/transformer_temporal.py +0 -34
  437. diffusers/models/unet_1d.py +0 -26
  438. diffusers/models/unet_1d_blocks.py +0 -203
  439. diffusers/models/unet_2d.py +0 -27
  440. diffusers/models/unet_2d_blocks.py +0 -375
  441. diffusers/models/unet_2d_condition.py +0 -25
  442. diffusers-0.27.1.dist-info/RECORD +0 -399
  443. {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/LICENSE +0 -0
  444. {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/entry_points.txt +0 -0
  445. {diffusers-0.27.1.dist-info → diffusers-0.32.2.dist-info}/top_level.txt +0 -0
@@ -561,7 +561,7 @@ class AutoencoderTinyBlock(nn.Module):
561
561
  ` The activation function to use. Supported values are `"swish"`, `"mish"`, `"gelu"`, and `"relu"`.
562
562
 
563
563
  Returns:
564
- `torch.FloatTensor`: A tensor with the same shape as the input tensor, but with the number of channels equal to
564
+ `torch.Tensor`: A tensor with the same shape as the input tensor, but with the number of channels equal to
565
565
  `out_channels`.
566
566
  """
567
567
 
@@ -582,7 +582,7 @@ class AutoencoderTinyBlock(nn.Module):
582
582
  )
583
583
  self.fuse = nn.ReLU()
584
584
 
585
- def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
585
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
586
586
  return self.fuse(self.conv(x) + self.skip(x))
587
587
 
588
588
 
@@ -612,8 +612,8 @@ class UNetMidBlock2D(nn.Module):
612
612
  output_scale_factor (`float`, *optional*, defaults to 1.0): The output scale factor.
613
613
 
614
614
  Returns:
615
- `torch.FloatTensor`: The output of the last residual block, which is a tensor of shape `(batch_size,
616
- in_channels, height, width)`.
615
+ `torch.Tensor`: The output of the last residual block, which is a tensor of shape `(batch_size, in_channels,
616
+ height, width)`.
617
617
 
618
618
  """
619
619
 
@@ -731,12 +731,35 @@ class UNetMidBlock2D(nn.Module):
731
731
  self.attentions = nn.ModuleList(attentions)
732
732
  self.resnets = nn.ModuleList(resnets)
733
733
 
734
- def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
734
+ self.gradient_checkpointing = False
735
+
736
+ def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
735
737
  hidden_states = self.resnets[0](hidden_states, temb)
736
738
  for attn, resnet in zip(self.attentions, self.resnets[1:]):
737
- if attn is not None:
738
- hidden_states = attn(hidden_states, temb=temb)
739
- hidden_states = resnet(hidden_states, temb)
739
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
740
+
741
+ def create_custom_forward(module, return_dict=None):
742
+ def custom_forward(*inputs):
743
+ if return_dict is not None:
744
+ return module(*inputs, return_dict=return_dict)
745
+ else:
746
+ return module(*inputs)
747
+
748
+ return custom_forward
749
+
750
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
751
+ if attn is not None:
752
+ hidden_states = attn(hidden_states, temb=temb)
753
+ hidden_states = torch.utils.checkpoint.checkpoint(
754
+ create_custom_forward(resnet),
755
+ hidden_states,
756
+ temb,
757
+ **ckpt_kwargs,
758
+ )
759
+ else:
760
+ if attn is not None:
761
+ hidden_states = attn(hidden_states, temb=temb)
762
+ hidden_states = resnet(hidden_states, temb)
740
763
 
741
764
  return hidden_states
742
765
 
@@ -746,6 +769,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
746
769
  self,
747
770
  in_channels: int,
748
771
  temb_channels: int,
772
+ out_channels: Optional[int] = None,
749
773
  dropout: float = 0.0,
750
774
  num_layers: int = 1,
751
775
  transformer_layers_per_block: Union[int, Tuple[int]] = 1,
@@ -753,6 +777,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
753
777
  resnet_time_scale_shift: str = "default",
754
778
  resnet_act_fn: str = "swish",
755
779
  resnet_groups: int = 32,
780
+ resnet_groups_out: Optional[int] = None,
756
781
  resnet_pre_norm: bool = True,
757
782
  num_attention_heads: int = 1,
758
783
  output_scale_factor: float = 1.0,
@@ -764,6 +789,10 @@ class UNetMidBlock2DCrossAttn(nn.Module):
764
789
  ):
765
790
  super().__init__()
766
791
 
792
+ out_channels = out_channels or in_channels
793
+ self.in_channels = in_channels
794
+ self.out_channels = out_channels
795
+
767
796
  self.has_cross_attention = True
768
797
  self.num_attention_heads = num_attention_heads
769
798
  resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
@@ -772,14 +801,17 @@ class UNetMidBlock2DCrossAttn(nn.Module):
772
801
  if isinstance(transformer_layers_per_block, int):
773
802
  transformer_layers_per_block = [transformer_layers_per_block] * num_layers
774
803
 
804
+ resnet_groups_out = resnet_groups_out or resnet_groups
805
+
775
806
  # there is always at least one resnet
776
807
  resnets = [
777
808
  ResnetBlock2D(
778
809
  in_channels=in_channels,
779
- out_channels=in_channels,
810
+ out_channels=out_channels,
780
811
  temb_channels=temb_channels,
781
812
  eps=resnet_eps,
782
813
  groups=resnet_groups,
814
+ groups_out=resnet_groups_out,
783
815
  dropout=dropout,
784
816
  time_embedding_norm=resnet_time_scale_shift,
785
817
  non_linearity=resnet_act_fn,
@@ -794,11 +826,11 @@ class UNetMidBlock2DCrossAttn(nn.Module):
794
826
  attentions.append(
795
827
  Transformer2DModel(
796
828
  num_attention_heads,
797
- in_channels // num_attention_heads,
798
- in_channels=in_channels,
829
+ out_channels // num_attention_heads,
830
+ in_channels=out_channels,
799
831
  num_layers=transformer_layers_per_block[i],
800
832
  cross_attention_dim=cross_attention_dim,
801
- norm_num_groups=resnet_groups,
833
+ norm_num_groups=resnet_groups_out,
802
834
  use_linear_projection=use_linear_projection,
803
835
  upcast_attention=upcast_attention,
804
836
  attention_type=attention_type,
@@ -808,8 +840,8 @@ class UNetMidBlock2DCrossAttn(nn.Module):
808
840
  attentions.append(
809
841
  DualTransformer2DModel(
810
842
  num_attention_heads,
811
- in_channels // num_attention_heads,
812
- in_channels=in_channels,
843
+ out_channels // num_attention_heads,
844
+ in_channels=out_channels,
813
845
  num_layers=1,
814
846
  cross_attention_dim=cross_attention_dim,
815
847
  norm_num_groups=resnet_groups,
@@ -817,11 +849,11 @@ class UNetMidBlock2DCrossAttn(nn.Module):
817
849
  )
818
850
  resnets.append(
819
851
  ResnetBlock2D(
820
- in_channels=in_channels,
821
- out_channels=in_channels,
852
+ in_channels=out_channels,
853
+ out_channels=out_channels,
822
854
  temb_channels=temb_channels,
823
855
  eps=resnet_eps,
824
- groups=resnet_groups,
856
+ groups=resnet_groups_out,
825
857
  dropout=dropout,
826
858
  time_embedding_norm=resnet_time_scale_shift,
827
859
  non_linearity=resnet_act_fn,
@@ -837,20 +869,20 @@ class UNetMidBlock2DCrossAttn(nn.Module):
837
869
 
838
870
  def forward(
839
871
  self,
840
- hidden_states: torch.FloatTensor,
841
- temb: Optional[torch.FloatTensor] = None,
842
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
843
- attention_mask: Optional[torch.FloatTensor] = None,
872
+ hidden_states: torch.Tensor,
873
+ temb: Optional[torch.Tensor] = None,
874
+ encoder_hidden_states: Optional[torch.Tensor] = None,
875
+ attention_mask: Optional[torch.Tensor] = None,
844
876
  cross_attention_kwargs: Optional[Dict[str, Any]] = None,
845
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
846
- ) -> torch.FloatTensor:
877
+ encoder_attention_mask: Optional[torch.Tensor] = None,
878
+ ) -> torch.Tensor:
847
879
  if cross_attention_kwargs is not None:
848
880
  if cross_attention_kwargs.get("scale", None) is not None:
849
- logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
881
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
850
882
 
851
883
  hidden_states = self.resnets[0](hidden_states, temb)
852
884
  for attn, resnet in zip(self.attentions, self.resnets[1:]):
853
- if self.training and self.gradient_checkpointing:
885
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
854
886
 
855
887
  def create_custom_forward(module, return_dict=None):
856
888
  def custom_forward(*inputs):
@@ -977,16 +1009,16 @@ class UNetMidBlock2DSimpleCrossAttn(nn.Module):
977
1009
 
978
1010
  def forward(
979
1011
  self,
980
- hidden_states: torch.FloatTensor,
981
- temb: Optional[torch.FloatTensor] = None,
982
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
983
- attention_mask: Optional[torch.FloatTensor] = None,
1012
+ hidden_states: torch.Tensor,
1013
+ temb: Optional[torch.Tensor] = None,
1014
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1015
+ attention_mask: Optional[torch.Tensor] = None,
984
1016
  cross_attention_kwargs: Optional[Dict[str, Any]] = None,
985
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
986
- ) -> torch.FloatTensor:
1017
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1018
+ ) -> torch.Tensor:
987
1019
  cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
988
1020
  if cross_attention_kwargs.get("scale", None) is not None:
989
- logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
1021
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
990
1022
 
991
1023
  if attention_mask is None:
992
1024
  # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
@@ -1107,23 +1139,46 @@ class AttnDownBlock2D(nn.Module):
1107
1139
  else:
1108
1140
  self.downsamplers = None
1109
1141
 
1142
+ self.gradient_checkpointing = False
1143
+
1110
1144
  def forward(
1111
1145
  self,
1112
- hidden_states: torch.FloatTensor,
1113
- temb: Optional[torch.FloatTensor] = None,
1146
+ hidden_states: torch.Tensor,
1147
+ temb: Optional[torch.Tensor] = None,
1114
1148
  upsample_size: Optional[int] = None,
1115
1149
  cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1116
- ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
1150
+ ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
1117
1151
  cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
1118
1152
  if cross_attention_kwargs.get("scale", None) is not None:
1119
- logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
1153
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
1120
1154
 
1121
1155
  output_states = ()
1122
1156
 
1123
1157
  for resnet, attn in zip(self.resnets, self.attentions):
1124
- hidden_states = resnet(hidden_states, temb)
1125
- hidden_states = attn(hidden_states, **cross_attention_kwargs)
1126
- output_states = output_states + (hidden_states,)
1158
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
1159
+
1160
+ def create_custom_forward(module, return_dict=None):
1161
+ def custom_forward(*inputs):
1162
+ if return_dict is not None:
1163
+ return module(*inputs, return_dict=return_dict)
1164
+ else:
1165
+ return module(*inputs)
1166
+
1167
+ return custom_forward
1168
+
1169
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1170
+ hidden_states = torch.utils.checkpoint.checkpoint(
1171
+ create_custom_forward(resnet),
1172
+ hidden_states,
1173
+ temb,
1174
+ **ckpt_kwargs,
1175
+ )
1176
+ hidden_states = attn(hidden_states, **cross_attention_kwargs)
1177
+ output_states = output_states + (hidden_states,)
1178
+ else:
1179
+ hidden_states = resnet(hidden_states, temb)
1180
+ hidden_states = attn(hidden_states, **cross_attention_kwargs)
1181
+ output_states = output_states + (hidden_states,)
1127
1182
 
1128
1183
  if self.downsamplers is not None:
1129
1184
  for downsampler in self.downsamplers:
@@ -1231,24 +1286,24 @@ class CrossAttnDownBlock2D(nn.Module):
1231
1286
 
1232
1287
  def forward(
1233
1288
  self,
1234
- hidden_states: torch.FloatTensor,
1235
- temb: Optional[torch.FloatTensor] = None,
1236
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
1237
- attention_mask: Optional[torch.FloatTensor] = None,
1289
+ hidden_states: torch.Tensor,
1290
+ temb: Optional[torch.Tensor] = None,
1291
+ encoder_hidden_states: Optional[torch.Tensor] = None,
1292
+ attention_mask: Optional[torch.Tensor] = None,
1238
1293
  cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1239
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
1240
- additional_residuals: Optional[torch.FloatTensor] = None,
1241
- ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
1294
+ encoder_attention_mask: Optional[torch.Tensor] = None,
1295
+ additional_residuals: Optional[torch.Tensor] = None,
1296
+ ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
1242
1297
  if cross_attention_kwargs is not None:
1243
1298
  if cross_attention_kwargs.get("scale", None) is not None:
1244
- logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
1299
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
1245
1300
 
1246
1301
  output_states = ()
1247
1302
 
1248
1303
  blocks = list(zip(self.resnets, self.attentions))
1249
1304
 
1250
1305
  for i, (resnet, attn) in enumerate(blocks):
1251
- if self.training and self.gradient_checkpointing:
1306
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
1252
1307
 
1253
1308
  def create_custom_forward(module, return_dict=None):
1254
1309
  def custom_forward(*inputs):
@@ -1353,8 +1408,8 @@ class DownBlock2D(nn.Module):
1353
1408
  self.gradient_checkpointing = False
1354
1409
 
1355
1410
  def forward(
1356
- self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, *args, **kwargs
1357
- ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
1411
+ self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, *args, **kwargs
1412
+ ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
1358
1413
  if len(args) > 0 or kwargs.get("scale", None) is not None:
1359
1414
  deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
1360
1415
  deprecate("scale", "1.0.0", deprecation_message)
@@ -1362,7 +1417,7 @@ class DownBlock2D(nn.Module):
1362
1417
  output_states = ()
1363
1418
 
1364
1419
  for resnet in self.resnets:
1365
- if self.training and self.gradient_checkpointing:
1420
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
1366
1421
 
1367
1422
  def create_custom_forward(module):
1368
1423
  def custom_forward(*inputs):
@@ -1456,7 +1511,7 @@ class DownEncoderBlock2D(nn.Module):
1456
1511
  else:
1457
1512
  self.downsamplers = None
1458
1513
 
1459
- def forward(self, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
1514
+ def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
1460
1515
  if len(args) > 0 or kwargs.get("scale", None) is not None:
1461
1516
  deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
1462
1517
  deprecate("scale", "1.0.0", deprecation_message)
@@ -1558,7 +1613,7 @@ class AttnDownEncoderBlock2D(nn.Module):
1558
1613
  else:
1559
1614
  self.downsamplers = None
1560
1615
 
1561
- def forward(self, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
1616
+ def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
1562
1617
  if len(args) > 0 or kwargs.get("scale", None) is not None:
1563
1618
  deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
1564
1619
  deprecate("scale", "1.0.0", deprecation_message)
@@ -1657,12 +1712,12 @@ class AttnSkipDownBlock2D(nn.Module):
1657
1712
 
1658
1713
  def forward(
1659
1714
  self,
1660
- hidden_states: torch.FloatTensor,
1661
- temb: Optional[torch.FloatTensor] = None,
1662
- skip_sample: Optional[torch.FloatTensor] = None,
1715
+ hidden_states: torch.Tensor,
1716
+ temb: Optional[torch.Tensor] = None,
1717
+ skip_sample: Optional[torch.Tensor] = None,
1663
1718
  *args,
1664
1719
  **kwargs,
1665
- ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]:
1720
+ ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...], torch.Tensor]:
1666
1721
  if len(args) > 0 or kwargs.get("scale", None) is not None:
1667
1722
  deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
1668
1723
  deprecate("scale", "1.0.0", deprecation_message)
@@ -1748,12 +1803,12 @@ class SkipDownBlock2D(nn.Module):
1748
1803
 
1749
1804
  def forward(
1750
1805
  self,
1751
- hidden_states: torch.FloatTensor,
1752
- temb: Optional[torch.FloatTensor] = None,
1753
- skip_sample: Optional[torch.FloatTensor] = None,
1806
+ hidden_states: torch.Tensor,
1807
+ temb: Optional[torch.Tensor] = None,
1808
+ skip_sample: Optional[torch.Tensor] = None,
1754
1809
  *args,
1755
1810
  **kwargs,
1756
- ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...], torch.FloatTensor]:
1811
+ ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...], torch.Tensor]:
1757
1812
  if len(args) > 0 or kwargs.get("scale", None) is not None:
1758
1813
  deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
1759
1814
  deprecate("scale", "1.0.0", deprecation_message)
@@ -1841,8 +1896,8 @@ class ResnetDownsampleBlock2D(nn.Module):
1841
1896
  self.gradient_checkpointing = False
1842
1897
 
1843
1898
  def forward(
1844
- self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, *args, **kwargs
1845
- ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
1899
+ self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, *args, **kwargs
1900
+ ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
1846
1901
  if len(args) > 0 or kwargs.get("scale", None) is not None:
1847
1902
  deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
1848
1903
  deprecate("scale", "1.0.0", deprecation_message)
@@ -1850,7 +1905,7 @@ class ResnetDownsampleBlock2D(nn.Module):
1850
1905
  output_states = ()
1851
1906
 
1852
1907
  for resnet in self.resnets:
1853
- if self.training and self.gradient_checkpointing:
1908
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
1854
1909
 
1855
1910
  def create_custom_forward(module):
1856
1911
  def custom_forward(*inputs):
@@ -1977,16 +2032,16 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
1977
2032
 
1978
2033
  def forward(
1979
2034
  self,
1980
- hidden_states: torch.FloatTensor,
1981
- temb: Optional[torch.FloatTensor] = None,
1982
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
1983
- attention_mask: Optional[torch.FloatTensor] = None,
2035
+ hidden_states: torch.Tensor,
2036
+ temb: Optional[torch.Tensor] = None,
2037
+ encoder_hidden_states: Optional[torch.Tensor] = None,
2038
+ attention_mask: Optional[torch.Tensor] = None,
1984
2039
  cross_attention_kwargs: Optional[Dict[str, Any]] = None,
1985
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
1986
- ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
2040
+ encoder_attention_mask: Optional[torch.Tensor] = None,
2041
+ ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
1987
2042
  cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
1988
2043
  if cross_attention_kwargs.get("scale", None) is not None:
1989
- logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
2044
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
1990
2045
 
1991
2046
  output_states = ()
1992
2047
 
@@ -2002,7 +2057,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
2002
2057
  mask = attention_mask
2003
2058
 
2004
2059
  for resnet, attn in zip(self.resnets, self.attentions):
2005
- if self.training and self.gradient_checkpointing:
2060
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
2006
2061
 
2007
2062
  def create_custom_forward(module, return_dict=None):
2008
2063
  def custom_forward(*inputs):
@@ -2088,8 +2143,8 @@ class KDownBlock2D(nn.Module):
2088
2143
  self.gradient_checkpointing = False
2089
2144
 
2090
2145
  def forward(
2091
- self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, *args, **kwargs
2092
- ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
2146
+ self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None, *args, **kwargs
2147
+ ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
2093
2148
  if len(args) > 0 or kwargs.get("scale", None) is not None:
2094
2149
  deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
2095
2150
  deprecate("scale", "1.0.0", deprecation_message)
@@ -2097,7 +2152,7 @@ class KDownBlock2D(nn.Module):
2097
2152
  output_states = ()
2098
2153
 
2099
2154
  for resnet in self.resnets:
2100
- if self.training and self.gradient_checkpointing:
2155
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
2101
2156
 
2102
2157
  def create_custom_forward(module):
2103
2158
  def custom_forward(*inputs):
@@ -2192,21 +2247,21 @@ class KCrossAttnDownBlock2D(nn.Module):
2192
2247
 
2193
2248
  def forward(
2194
2249
  self,
2195
- hidden_states: torch.FloatTensor,
2196
- temb: Optional[torch.FloatTensor] = None,
2197
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
2198
- attention_mask: Optional[torch.FloatTensor] = None,
2250
+ hidden_states: torch.Tensor,
2251
+ temb: Optional[torch.Tensor] = None,
2252
+ encoder_hidden_states: Optional[torch.Tensor] = None,
2253
+ attention_mask: Optional[torch.Tensor] = None,
2199
2254
  cross_attention_kwargs: Optional[Dict[str, Any]] = None,
2200
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
2201
- ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
2255
+ encoder_attention_mask: Optional[torch.Tensor] = None,
2256
+ ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
2202
2257
  cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
2203
2258
  if cross_attention_kwargs.get("scale", None) is not None:
2204
- logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
2259
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
2205
2260
 
2206
2261
  output_states = ()
2207
2262
 
2208
2263
  for resnet, attn in zip(self.resnets, self.attentions):
2209
- if self.training and self.gradient_checkpointing:
2264
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
2210
2265
 
2211
2266
  def create_custom_forward(module, return_dict=None):
2212
2267
  def custom_forward(*inputs):
@@ -2345,17 +2400,18 @@ class AttnUpBlock2D(nn.Module):
2345
2400
  else:
2346
2401
  self.upsamplers = None
2347
2402
 
2403
+ self.gradient_checkpointing = False
2348
2404
  self.resolution_idx = resolution_idx
2349
2405
 
2350
2406
  def forward(
2351
2407
  self,
2352
- hidden_states: torch.FloatTensor,
2353
- res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
2354
- temb: Optional[torch.FloatTensor] = None,
2408
+ hidden_states: torch.Tensor,
2409
+ res_hidden_states_tuple: Tuple[torch.Tensor, ...],
2410
+ temb: Optional[torch.Tensor] = None,
2355
2411
  upsample_size: Optional[int] = None,
2356
2412
  *args,
2357
2413
  **kwargs,
2358
- ) -> torch.FloatTensor:
2414
+ ) -> torch.Tensor:
2359
2415
  if len(args) > 0 or kwargs.get("scale", None) is not None:
2360
2416
  deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
2361
2417
  deprecate("scale", "1.0.0", deprecation_message)
@@ -2366,8 +2422,28 @@ class AttnUpBlock2D(nn.Module):
2366
2422
  res_hidden_states_tuple = res_hidden_states_tuple[:-1]
2367
2423
  hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
2368
2424
 
2369
- hidden_states = resnet(hidden_states, temb)
2370
- hidden_states = attn(hidden_states)
2425
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
2426
+
2427
+ def create_custom_forward(module, return_dict=None):
2428
+ def custom_forward(*inputs):
2429
+ if return_dict is not None:
2430
+ return module(*inputs, return_dict=return_dict)
2431
+ else:
2432
+ return module(*inputs)
2433
+
2434
+ return custom_forward
2435
+
2436
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
2437
+ hidden_states = torch.utils.checkpoint.checkpoint(
2438
+ create_custom_forward(resnet),
2439
+ hidden_states,
2440
+ temb,
2441
+ **ckpt_kwargs,
2442
+ )
2443
+ hidden_states = attn(hidden_states)
2444
+ else:
2445
+ hidden_states = resnet(hidden_states, temb)
2446
+ hidden_states = attn(hidden_states)
2371
2447
 
2372
2448
  if self.upsamplers is not None:
2373
2449
  for upsampler in self.upsamplers:
@@ -2472,18 +2548,18 @@ class CrossAttnUpBlock2D(nn.Module):
2472
2548
 
2473
2549
  def forward(
2474
2550
  self,
2475
- hidden_states: torch.FloatTensor,
2476
- res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
2477
- temb: Optional[torch.FloatTensor] = None,
2478
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
2551
+ hidden_states: torch.Tensor,
2552
+ res_hidden_states_tuple: Tuple[torch.Tensor, ...],
2553
+ temb: Optional[torch.Tensor] = None,
2554
+ encoder_hidden_states: Optional[torch.Tensor] = None,
2479
2555
  cross_attention_kwargs: Optional[Dict[str, Any]] = None,
2480
2556
  upsample_size: Optional[int] = None,
2481
- attention_mask: Optional[torch.FloatTensor] = None,
2482
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
2483
- ) -> torch.FloatTensor:
2557
+ attention_mask: Optional[torch.Tensor] = None,
2558
+ encoder_attention_mask: Optional[torch.Tensor] = None,
2559
+ ) -> torch.Tensor:
2484
2560
  if cross_attention_kwargs is not None:
2485
2561
  if cross_attention_kwargs.get("scale", None) is not None:
2486
- logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
2562
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
2487
2563
 
2488
2564
  is_freeu_enabled = (
2489
2565
  getattr(self, "s1", None)
@@ -2511,7 +2587,7 @@ class CrossAttnUpBlock2D(nn.Module):
2511
2587
 
2512
2588
  hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
2513
2589
 
2514
- if self.training and self.gradient_checkpointing:
2590
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
2515
2591
 
2516
2592
  def create_custom_forward(module, return_dict=None):
2517
2593
  def custom_forward(*inputs):
@@ -2607,13 +2683,13 @@ class UpBlock2D(nn.Module):
2607
2683
 
2608
2684
  def forward(
2609
2685
  self,
2610
- hidden_states: torch.FloatTensor,
2611
- res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
2612
- temb: Optional[torch.FloatTensor] = None,
2686
+ hidden_states: torch.Tensor,
2687
+ res_hidden_states_tuple: Tuple[torch.Tensor, ...],
2688
+ temb: Optional[torch.Tensor] = None,
2613
2689
  upsample_size: Optional[int] = None,
2614
2690
  *args,
2615
2691
  **kwargs,
2616
- ) -> torch.FloatTensor:
2692
+ ) -> torch.Tensor:
2617
2693
  if len(args) > 0 or kwargs.get("scale", None) is not None:
2618
2694
  deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
2619
2695
  deprecate("scale", "1.0.0", deprecation_message)
@@ -2644,7 +2720,7 @@ class UpBlock2D(nn.Module):
2644
2720
 
2645
2721
  hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
2646
2722
 
2647
- if self.training and self.gradient_checkpointing:
2723
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
2648
2724
 
2649
2725
  def create_custom_forward(module):
2650
2726
  def custom_forward(*inputs):
@@ -2732,7 +2808,7 @@ class UpDecoderBlock2D(nn.Module):
2732
2808
 
2733
2809
  self.resolution_idx = resolution_idx
2734
2810
 
2735
- def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
2811
+ def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
2736
2812
  for resnet in self.resnets:
2737
2813
  hidden_states = resnet(hidden_states, temb=temb)
2738
2814
 
@@ -2830,7 +2906,7 @@ class AttnUpDecoderBlock2D(nn.Module):
2830
2906
 
2831
2907
  self.resolution_idx = resolution_idx
2832
2908
 
2833
- def forward(self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None) -> torch.FloatTensor:
2909
+ def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
2834
2910
  for resnet, attn in zip(self.resnets, self.attentions):
2835
2911
  hidden_states = resnet(hidden_states, temb=temb)
2836
2912
  hidden_states = attn(hidden_states, temb=temb)
@@ -2938,13 +3014,13 @@ class AttnSkipUpBlock2D(nn.Module):
2938
3014
 
2939
3015
  def forward(
2940
3016
  self,
2941
- hidden_states: torch.FloatTensor,
2942
- res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
2943
- temb: Optional[torch.FloatTensor] = None,
3017
+ hidden_states: torch.Tensor,
3018
+ res_hidden_states_tuple: Tuple[torch.Tensor, ...],
3019
+ temb: Optional[torch.Tensor] = None,
2944
3020
  skip_sample=None,
2945
3021
  *args,
2946
3022
  **kwargs,
2947
- ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
3023
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
2948
3024
  if len(args) > 0 or kwargs.get("scale", None) is not None:
2949
3025
  deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
2950
3026
  deprecate("scale", "1.0.0", deprecation_message)
@@ -3050,13 +3126,13 @@ class SkipUpBlock2D(nn.Module):
3050
3126
 
3051
3127
  def forward(
3052
3128
  self,
3053
- hidden_states: torch.FloatTensor,
3054
- res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
3055
- temb: Optional[torch.FloatTensor] = None,
3129
+ hidden_states: torch.Tensor,
3130
+ res_hidden_states_tuple: Tuple[torch.Tensor, ...],
3131
+ temb: Optional[torch.Tensor] = None,
3056
3132
  skip_sample=None,
3057
3133
  *args,
3058
3134
  **kwargs,
3059
- ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
3135
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
3060
3136
  if len(args) > 0 or kwargs.get("scale", None) is not None:
3061
3137
  deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
3062
3138
  deprecate("scale", "1.0.0", deprecation_message)
@@ -3157,13 +3233,13 @@ class ResnetUpsampleBlock2D(nn.Module):
3157
3233
 
3158
3234
  def forward(
3159
3235
  self,
3160
- hidden_states: torch.FloatTensor,
3161
- res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
3162
- temb: Optional[torch.FloatTensor] = None,
3236
+ hidden_states: torch.Tensor,
3237
+ res_hidden_states_tuple: Tuple[torch.Tensor, ...],
3238
+ temb: Optional[torch.Tensor] = None,
3163
3239
  upsample_size: Optional[int] = None,
3164
3240
  *args,
3165
3241
  **kwargs,
3166
- ) -> torch.FloatTensor:
3242
+ ) -> torch.Tensor:
3167
3243
  if len(args) > 0 or kwargs.get("scale", None) is not None:
3168
3244
  deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
3169
3245
  deprecate("scale", "1.0.0", deprecation_message)
@@ -3174,7 +3250,7 @@ class ResnetUpsampleBlock2D(nn.Module):
3174
3250
  res_hidden_states_tuple = res_hidden_states_tuple[:-1]
3175
3251
  hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
3176
3252
 
3177
- if self.training and self.gradient_checkpointing:
3253
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
3178
3254
 
3179
3255
  def create_custom_forward(module):
3180
3256
  def custom_forward(*inputs):
@@ -3301,18 +3377,18 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
3301
3377
 
3302
3378
  def forward(
3303
3379
  self,
3304
- hidden_states: torch.FloatTensor,
3305
- res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
3306
- temb: Optional[torch.FloatTensor] = None,
3307
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
3380
+ hidden_states: torch.Tensor,
3381
+ res_hidden_states_tuple: Tuple[torch.Tensor, ...],
3382
+ temb: Optional[torch.Tensor] = None,
3383
+ encoder_hidden_states: Optional[torch.Tensor] = None,
3308
3384
  upsample_size: Optional[int] = None,
3309
- attention_mask: Optional[torch.FloatTensor] = None,
3385
+ attention_mask: Optional[torch.Tensor] = None,
3310
3386
  cross_attention_kwargs: Optional[Dict[str, Any]] = None,
3311
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
3312
- ) -> torch.FloatTensor:
3387
+ encoder_attention_mask: Optional[torch.Tensor] = None,
3388
+ ) -> torch.Tensor:
3313
3389
  cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
3314
3390
  if cross_attention_kwargs.get("scale", None) is not None:
3315
- logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
3391
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
3316
3392
 
3317
3393
  if attention_mask is None:
3318
3394
  # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
@@ -3332,7 +3408,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
3332
3408
  res_hidden_states_tuple = res_hidden_states_tuple[:-1]
3333
3409
  hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
3334
3410
 
3335
- if self.training and self.gradient_checkpointing:
3411
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
3336
3412
 
3337
3413
  def create_custom_forward(module, return_dict=None):
3338
3414
  def custom_forward(*inputs):
@@ -3419,13 +3495,13 @@ class KUpBlock2D(nn.Module):
3419
3495
 
3420
3496
  def forward(
3421
3497
  self,
3422
- hidden_states: torch.FloatTensor,
3423
- res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
3424
- temb: Optional[torch.FloatTensor] = None,
3498
+ hidden_states: torch.Tensor,
3499
+ res_hidden_states_tuple: Tuple[torch.Tensor, ...],
3500
+ temb: Optional[torch.Tensor] = None,
3425
3501
  upsample_size: Optional[int] = None,
3426
3502
  *args,
3427
3503
  **kwargs,
3428
- ) -> torch.FloatTensor:
3504
+ ) -> torch.Tensor:
3429
3505
  if len(args) > 0 or kwargs.get("scale", None) is not None:
3430
3506
  deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
3431
3507
  deprecate("scale", "1.0.0", deprecation_message)
@@ -3435,7 +3511,7 @@ class KUpBlock2D(nn.Module):
3435
3511
  hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)
3436
3512
 
3437
3513
  for resnet in self.resnets:
3438
- if self.training and self.gradient_checkpointing:
3514
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
3439
3515
 
3440
3516
  def create_custom_forward(module):
3441
3517
  def custom_forward(*inputs):
@@ -3549,21 +3625,21 @@ class KCrossAttnUpBlock2D(nn.Module):
3549
3625
 
3550
3626
  def forward(
3551
3627
  self,
3552
- hidden_states: torch.FloatTensor,
3553
- res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
3554
- temb: Optional[torch.FloatTensor] = None,
3555
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
3628
+ hidden_states: torch.Tensor,
3629
+ res_hidden_states_tuple: Tuple[torch.Tensor, ...],
3630
+ temb: Optional[torch.Tensor] = None,
3631
+ encoder_hidden_states: Optional[torch.Tensor] = None,
3556
3632
  cross_attention_kwargs: Optional[Dict[str, Any]] = None,
3557
3633
  upsample_size: Optional[int] = None,
3558
- attention_mask: Optional[torch.FloatTensor] = None,
3559
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
3560
- ) -> torch.FloatTensor:
3634
+ attention_mask: Optional[torch.Tensor] = None,
3635
+ encoder_attention_mask: Optional[torch.Tensor] = None,
3636
+ ) -> torch.Tensor:
3561
3637
  res_hidden_states_tuple = res_hidden_states_tuple[-1]
3562
3638
  if res_hidden_states_tuple is not None:
3563
3639
  hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)
3564
3640
 
3565
3641
  for resnet, attn in zip(self.resnets, self.attentions):
3566
- if self.training and self.gradient_checkpointing:
3642
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
3567
3643
 
3568
3644
  def create_custom_forward(module, return_dict=None):
3569
3645
  def custom_forward(*inputs):
@@ -3675,26 +3751,26 @@ class KAttentionBlock(nn.Module):
3675
3751
  cross_attention_norm=cross_attention_norm,
3676
3752
  )
3677
3753
 
3678
- def _to_3d(self, hidden_states: torch.FloatTensor, height: int, weight: int) -> torch.FloatTensor:
3754
+ def _to_3d(self, hidden_states: torch.Tensor, height: int, weight: int) -> torch.Tensor:
3679
3755
  return hidden_states.permute(0, 2, 3, 1).reshape(hidden_states.shape[0], height * weight, -1)
3680
3756
 
3681
- def _to_4d(self, hidden_states: torch.FloatTensor, height: int, weight: int) -> torch.FloatTensor:
3757
+ def _to_4d(self, hidden_states: torch.Tensor, height: int, weight: int) -> torch.Tensor:
3682
3758
  return hidden_states.permute(0, 2, 1).reshape(hidden_states.shape[0], -1, height, weight)
3683
3759
 
3684
3760
  def forward(
3685
3761
  self,
3686
- hidden_states: torch.FloatTensor,
3687
- encoder_hidden_states: Optional[torch.FloatTensor] = None,
3762
+ hidden_states: torch.Tensor,
3763
+ encoder_hidden_states: Optional[torch.Tensor] = None,
3688
3764
  # TODO: mark emb as non-optional (self.norm2 requires it).
3689
3765
  # requires assessing impact of change to positional param interface.
3690
- emb: Optional[torch.FloatTensor] = None,
3691
- attention_mask: Optional[torch.FloatTensor] = None,
3766
+ emb: Optional[torch.Tensor] = None,
3767
+ attention_mask: Optional[torch.Tensor] = None,
3692
3768
  cross_attention_kwargs: Optional[Dict[str, Any]] = None,
3693
- encoder_attention_mask: Optional[torch.FloatTensor] = None,
3694
- ) -> torch.FloatTensor:
3769
+ encoder_attention_mask: Optional[torch.Tensor] = None,
3770
+ ) -> torch.Tensor:
3695
3771
  cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
3696
3772
  if cross_attention_kwargs.get("scale", None) is not None:
3697
- logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
3773
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")
3698
3774
 
3699
3775
  # 1. Self-Attention
3700
3776
  if self.add_self_attention: