diffusers 0.32.2__py3-none-any.whl → 0.33.1__py3-none-any.whl

This diff compares two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (389)
  1. diffusers/__init__.py +186 -3
  2. diffusers/configuration_utils.py +40 -12
  3. diffusers/dependency_versions_table.py +9 -2
  4. diffusers/hooks/__init__.py +9 -0
  5. diffusers/hooks/faster_cache.py +653 -0
  6. diffusers/hooks/group_offloading.py +793 -0
  7. diffusers/hooks/hooks.py +236 -0
  8. diffusers/hooks/layerwise_casting.py +245 -0
  9. diffusers/hooks/pyramid_attention_broadcast.py +311 -0
  10. diffusers/loaders/__init__.py +6 -0
  11. diffusers/loaders/ip_adapter.py +38 -30
  12. diffusers/loaders/lora_base.py +121 -86
  13. diffusers/loaders/lora_conversion_utils.py +504 -44
  14. diffusers/loaders/lora_pipeline.py +1769 -181
  15. diffusers/loaders/peft.py +167 -57
  16. diffusers/loaders/single_file.py +17 -2
  17. diffusers/loaders/single_file_model.py +53 -5
  18. diffusers/loaders/single_file_utils.py +646 -72
  19. diffusers/loaders/textual_inversion.py +9 -9
  20. diffusers/loaders/transformer_flux.py +8 -9
  21. diffusers/loaders/transformer_sd3.py +120 -39
  22. diffusers/loaders/unet.py +20 -7
  23. diffusers/models/__init__.py +22 -0
  24. diffusers/models/activations.py +9 -9
  25. diffusers/models/attention.py +0 -1
  26. diffusers/models/attention_processor.py +163 -25
  27. diffusers/models/auto_model.py +169 -0
  28. diffusers/models/autoencoders/__init__.py +2 -0
  29. diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
  30. diffusers/models/autoencoders/autoencoder_dc.py +106 -4
  31. diffusers/models/autoencoders/autoencoder_kl.py +0 -4
  32. diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
  33. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
  34. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
  35. diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
  36. diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
  37. diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
  38. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
  39. diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
  40. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
  41. diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
  42. diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
  43. diffusers/models/autoencoders/vae.py +31 -141
  44. diffusers/models/autoencoders/vq_model.py +3 -0
  45. diffusers/models/cache_utils.py +108 -0
  46. diffusers/models/controlnets/__init__.py +1 -0
  47. diffusers/models/controlnets/controlnet.py +3 -8
  48. diffusers/models/controlnets/controlnet_flux.py +14 -42
  49. diffusers/models/controlnets/controlnet_sd3.py +58 -34
  50. diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
  51. diffusers/models/controlnets/controlnet_union.py +27 -18
  52. diffusers/models/controlnets/controlnet_xs.py +7 -46
  53. diffusers/models/controlnets/multicontrolnet_union.py +196 -0
  54. diffusers/models/embeddings.py +18 -7
  55. diffusers/models/model_loading_utils.py +122 -80
  56. diffusers/models/modeling_flax_pytorch_utils.py +1 -1
  57. diffusers/models/modeling_flax_utils.py +1 -1
  58. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  59. diffusers/models/modeling_utils.py +617 -272
  60. diffusers/models/normalization.py +67 -14
  61. diffusers/models/resnet.py +1 -1
  62. diffusers/models/transformers/__init__.py +6 -0
  63. diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
  64. diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
  65. diffusers/models/transformers/consisid_transformer_3d.py +789 -0
  66. diffusers/models/transformers/dit_transformer_2d.py +5 -19
  67. diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
  68. diffusers/models/transformers/latte_transformer_3d.py +20 -15
  69. diffusers/models/transformers/lumina_nextdit2d.py +3 -1
  70. diffusers/models/transformers/pixart_transformer_2d.py +4 -19
  71. diffusers/models/transformers/prior_transformer.py +5 -1
  72. diffusers/models/transformers/sana_transformer.py +144 -40
  73. diffusers/models/transformers/stable_audio_transformer.py +5 -20
  74. diffusers/models/transformers/transformer_2d.py +7 -22
  75. diffusers/models/transformers/transformer_allegro.py +9 -17
  76. diffusers/models/transformers/transformer_cogview3plus.py +6 -17
  77. diffusers/models/transformers/transformer_cogview4.py +462 -0
  78. diffusers/models/transformers/transformer_easyanimate.py +527 -0
  79. diffusers/models/transformers/transformer_flux.py +68 -110
  80. diffusers/models/transformers/transformer_hunyuan_video.py +404 -46
  81. diffusers/models/transformers/transformer_ltx.py +53 -35
  82. diffusers/models/transformers/transformer_lumina2.py +548 -0
  83. diffusers/models/transformers/transformer_mochi.py +6 -17
  84. diffusers/models/transformers/transformer_omnigen.py +469 -0
  85. diffusers/models/transformers/transformer_sd3.py +56 -86
  86. diffusers/models/transformers/transformer_temporal.py +5 -11
  87. diffusers/models/transformers/transformer_wan.py +469 -0
  88. diffusers/models/unets/unet_1d.py +3 -1
  89. diffusers/models/unets/unet_2d.py +21 -20
  90. diffusers/models/unets/unet_2d_blocks.py +19 -243
  91. diffusers/models/unets/unet_2d_condition.py +4 -6
  92. diffusers/models/unets/unet_3d_blocks.py +14 -127
  93. diffusers/models/unets/unet_3d_condition.py +8 -12
  94. diffusers/models/unets/unet_i2vgen_xl.py +5 -13
  95. diffusers/models/unets/unet_kandinsky3.py +0 -4
  96. diffusers/models/unets/unet_motion_model.py +20 -114
  97. diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
  98. diffusers/models/unets/unet_stable_cascade.py +8 -35
  99. diffusers/models/unets/uvit_2d.py +1 -4
  100. diffusers/optimization.py +2 -2
  101. diffusers/pipelines/__init__.py +57 -8
  102. diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
  103. diffusers/pipelines/amused/pipeline_amused.py +15 -2
  104. diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
  105. diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
  106. diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
  107. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
  108. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
  109. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
  110. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
  111. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
  112. diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
  113. diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
  114. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
  115. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
  116. diffusers/pipelines/auto_pipeline.py +35 -14
  117. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  118. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
  119. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
  120. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
  121. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
  122. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
  123. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
  124. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
  125. diffusers/pipelines/cogview4/__init__.py +49 -0
  126. diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
  127. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
  128. diffusers/pipelines/cogview4/pipeline_output.py +21 -0
  129. diffusers/pipelines/consisid/__init__.py +49 -0
  130. diffusers/pipelines/consisid/consisid_utils.py +357 -0
  131. diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
  132. diffusers/pipelines/consisid/pipeline_output.py +20 -0
  133. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
  134. diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
  135. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
  136. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
  137. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
  138. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
  139. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
  140. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
  141. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
  142. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
  143. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
  144. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
  145. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
  146. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
  147. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
  148. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
  149. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
  150. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
  151. diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
  152. diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
  153. diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
  154. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
  155. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
  156. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
  157. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
  158. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
  159. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
  160. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
  161. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
  162. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
  163. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
  164. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
  165. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
  166. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
  167. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
  168. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
  169. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
  170. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
  171. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
  172. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
  173. diffusers/pipelines/dit/pipeline_dit.py +15 -2
  174. diffusers/pipelines/easyanimate/__init__.py +52 -0
  175. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
  176. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
  177. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
  178. diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
  179. diffusers/pipelines/flux/pipeline_flux.py +53 -21
  180. diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
  181. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
  182. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
  183. diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
  184. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
  185. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
  186. diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
  187. diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
  188. diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
  189. diffusers/pipelines/free_noise_utils.py +3 -3
  190. diffusers/pipelines/hunyuan_video/__init__.py +4 -0
  191. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
  192. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
  193. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
  194. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
  195. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
  196. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
  197. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
  198. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
  199. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
  200. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
  201. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
  202. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
  203. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
  204. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
  205. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
  206. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
  207. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
  208. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
  209. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
  210. diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
  211. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
  212. diffusers/pipelines/kolors/text_encoder.py +7 -34
  213. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
  214. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
  215. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
  216. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
  217. diffusers/pipelines/latte/pipeline_latte.py +36 -7
  218. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
  219. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
  220. diffusers/pipelines/ltx/__init__.py +2 -0
  221. diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
  222. diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
  223. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
  224. diffusers/pipelines/lumina/__init__.py +2 -2
  225. diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
  226. diffusers/pipelines/lumina2/__init__.py +48 -0
  227. diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
  228. diffusers/pipelines/marigold/__init__.py +2 -0
  229. diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
  230. diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
  231. diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
  232. diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
  233. diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
  234. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
  235. diffusers/pipelines/omnigen/__init__.py +50 -0
  236. diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
  237. diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
  238. diffusers/pipelines/onnx_utils.py +5 -3
  239. diffusers/pipelines/pag/pag_utils.py +1 -1
  240. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
  241. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
  242. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
  243. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
  244. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
  245. diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
  246. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
  247. diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
  248. diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
  249. diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
  250. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
  251. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
  252. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
  253. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
  254. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
  255. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
  256. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
  257. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
  258. diffusers/pipelines/pia/pipeline_pia.py +13 -1
  259. diffusers/pipelines/pipeline_flax_utils.py +7 -7
  260. diffusers/pipelines/pipeline_loading_utils.py +193 -83
  261. diffusers/pipelines/pipeline_utils.py +221 -106
  262. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
  263. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
  264. diffusers/pipelines/sana/__init__.py +2 -0
  265. diffusers/pipelines/sana/pipeline_sana.py +183 -58
  266. diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
  267. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
  268. diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
  269. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
  270. diffusers/pipelines/shap_e/renderer.py +6 -6
  271. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
  272. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
  273. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
  274. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
  275. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
  276. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
  277. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
  278. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
  279. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  280. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
  281. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
  282. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
  283. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
  284. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
  285. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
  286. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
  287. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
  288. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
  289. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
  290. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
  291. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
  292. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
  293. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
  294. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
  295. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
  296. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
  297. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
  298. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
  299. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
  300. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
  301. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
  302. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
  303. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
  304. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
  305. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
  306. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  307. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
  308. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
  309. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
  310. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
  311. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
  312. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
  313. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
  314. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
  315. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
  316. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
  317. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
  318. diffusers/pipelines/transformers_loading_utils.py +121 -0
  319. diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
  320. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
  321. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
  322. diffusers/pipelines/wan/__init__.py +51 -0
  323. diffusers/pipelines/wan/pipeline_output.py +20 -0
  324. diffusers/pipelines/wan/pipeline_wan.py +595 -0
  325. diffusers/pipelines/wan/pipeline_wan_i2v.py +724 -0
  326. diffusers/pipelines/wan/pipeline_wan_video2video.py +727 -0
  327. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
  328. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
  329. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
  330. diffusers/quantizers/auto.py +5 -1
  331. diffusers/quantizers/base.py +5 -9
  332. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
  333. diffusers/quantizers/bitsandbytes/utils.py +30 -20
  334. diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
  335. diffusers/quantizers/gguf/utils.py +4 -2
  336. diffusers/quantizers/quantization_config.py +59 -4
  337. diffusers/quantizers/quanto/__init__.py +1 -0
  338. diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
  339. diffusers/quantizers/quanto/utils.py +60 -0
  340. diffusers/quantizers/torchao/__init__.py +1 -1
  341. diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
  342. diffusers/schedulers/__init__.py +2 -1
  343. diffusers/schedulers/scheduling_consistency_models.py +1 -2
  344. diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
  345. diffusers/schedulers/scheduling_ddpm.py +2 -3
  346. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
  347. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
  348. diffusers/schedulers/scheduling_edm_euler.py +45 -10
  349. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
  350. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
  351. diffusers/schedulers/scheduling_heun_discrete.py +1 -1
  352. diffusers/schedulers/scheduling_lcm.py +1 -2
  353. diffusers/schedulers/scheduling_lms_discrete.py +1 -1
  354. diffusers/schedulers/scheduling_repaint.py +5 -1
  355. diffusers/schedulers/scheduling_scm.py +265 -0
  356. diffusers/schedulers/scheduling_tcd.py +1 -2
  357. diffusers/schedulers/scheduling_utils.py +2 -1
  358. diffusers/training_utils.py +14 -7
  359. diffusers/utils/__init__.py +9 -1
  360. diffusers/utils/constants.py +13 -1
  361. diffusers/utils/deprecation_utils.py +1 -1
  362. diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
  363. diffusers/utils/dummy_gguf_objects.py +17 -0
  364. diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
  365. diffusers/utils/dummy_pt_objects.py +233 -0
  366. diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
  367. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  368. diffusers/utils/dummy_torchao_objects.py +17 -0
  369. diffusers/utils/dynamic_modules_utils.py +1 -1
  370. diffusers/utils/export_utils.py +28 -3
  371. diffusers/utils/hub_utils.py +52 -102
  372. diffusers/utils/import_utils.py +121 -221
  373. diffusers/utils/loading_utils.py +2 -1
  374. diffusers/utils/logging.py +1 -2
  375. diffusers/utils/peft_utils.py +6 -14
  376. diffusers/utils/remote_utils.py +425 -0
  377. diffusers/utils/source_code_parsing_utils.py +52 -0
  378. diffusers/utils/state_dict_utils.py +15 -1
  379. diffusers/utils/testing_utils.py +243 -13
  380. diffusers/utils/torch_utils.py +10 -0
  381. diffusers/utils/typing_utils.py +91 -0
  382. diffusers/video_processor.py +1 -1
  383. {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/METADATA +21 -4
  384. diffusers-0.33.1.dist-info/RECORD +608 -0
  385. {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/WHEEL +1 -1
  386. diffusers-0.32.2.dist-info/RECORD +0 -550
  387. {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/LICENSE +0 -0
  388. {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/entry_points.txt +0 -0
  389. {diffusers-0.32.2.dist-info → diffusers-0.33.1.dist-info}/top_level.txt +0 -0
diffusers/models/transformers/dit_transformer_2d.py
@@ -18,7 +18,7 @@ import torch.nn.functional as F
 from torch import nn
 
 from ...configuration_utils import ConfigMixin, register_to_config
-from ...utils import is_torch_version, logging
+from ...utils import logging
 from ..attention import BasicTransformerBlock
 from ..embeddings import PatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput
@@ -64,7 +64,9 @@ class DiTTransformer2DModel(ModelMixin, ConfigMixin):
             A small constant added to the denominator in normalization layers to prevent division by zero.
     """
 
+    _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
     _supports_gradient_checkpointing = True
+    _supports_group_offloading = False
 
     @register_to_config
     def __init__(
@@ -143,10 +145,6 @@ class DiTTransformer2DModel(ModelMixin, ConfigMixin):
             self.inner_dim, self.config.patch_size * self.config.patch_size * self.out_channels
         )
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -185,19 +183,8 @@ class DiTTransformer2DModel(ModelMixin, ConfigMixin):
         # 2. Blocks
         for block in self.transformer_blocks:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     None,
                     None,
@@ -205,7 +192,6 @@ class DiTTransformer2DModel(ModelMixin, ConfigMixin):
                     timestep,
                     cross_attention_kwargs,
                     class_labels,
-                    **ckpt_kwargs,
                 )
             else:
                 hidden_states = block(
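Across these model files, the release replaces each model's hand-rolled `create_custom_forward` wrapper with a single `self._gradient_checkpointing_func` call provided by the base model class. A minimal sketch of that centralization, assuming non-reentrant checkpointing as in the removed inline code (the real helper in diffusers' `ModelMixin` may differ in detail):

    import torch
    import torch.nn as nn
    import torch.utils.checkpoint

    class CheckpointingMixin:
        """Sketch of the shared helper; the actual one lives in ModelMixin."""

        gradient_checkpointing = False

        def enable_gradient_checkpointing(self) -> None:
            self.gradient_checkpointing = True

        def _gradient_checkpointing_func(self, module: nn.Module, *args):
            # Non-reentrant activation checkpointing, matching the
            # use_reentrant=False branch each model previously inlined.
            return torch.utils.checkpoint.checkpoint(module.__call__, *args, use_reentrant=False)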
diffusers/models/transformers/hunyuan_transformer_2d.py
@@ -244,6 +244,9 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
             Whether or not to use style condition and image meta size. True for version <=1.1, False for version >= 1.2
     """
 
+    _skip_layerwise_casting_patterns = ["pos_embed", "norm", "pooler"]
+    _supports_group_offloading = False
+
     @register_to_config
     def __init__(
         self,
@@ -277,9 +280,7 @@ class HunyuanDiT2DModel(ModelMixin, ConfigMixin):
             act_fn="silu_fp32",
         )
 
-        self.text_embedding_padding = nn.Parameter(
-            torch.randn(text_len + text_len_t5, cross_attention_dim, dtype=torch.float32)
-        )
+        self.text_embedding_padding = nn.Parameter(torch.randn(text_len + text_len_t5, cross_attention_dim))
 
         self.pos_embed = PatchEmbed(
             height=sample_size,
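`_skip_layerwise_casting_patterns` pairs with the new layerwise-casting hooks (`diffusers/hooks/layerwise_casting.py`): weights can be stored in a low-precision dtype and upcast per layer at compute time, while modules matching the patterns (position embeddings, norms, the pooler here) keep the compute dtype to protect numerically sensitive layers. A hedged usage sketch, assuming the `enable_layerwise_casting` entry point added to `ModelMixin` in this release:

    import torch
    from diffusers import HunyuanDiT2DModel

    model = HunyuanDiT2DModel.from_pretrained(
        "Tencent-Hunyuan/HunyuanDiT-Diffusers", subfolder="transformer", torch_dtype=torch.bfloat16
    )
    # Store weights in fp8 and upcast layer by layer to bf16 in forward; modules
    # matching _skip_layerwise_casting_patterns are left in the compute dtype.
    model.enable_layerwise_casting(storage_dtype=torch.float8_e4m3fn, compute_dtype=torch.bfloat16)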
diffusers/models/transformers/latte_transformer_3d.py
@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+
 from typing import Optional
 
 import torch
@@ -19,13 +20,14 @@ from torch import nn
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...models.embeddings import PixArtAlphaTextProjection, get_1d_sincos_pos_embed_from_grid
 from ..attention import BasicTransformerBlock
+from ..cache_utils import CacheMixin
 from ..embeddings import PatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormSingle
 
 
-class LatteTransformer3DModel(ModelMixin, ConfigMixin):
+class LatteTransformer3DModel(ModelMixin, ConfigMixin, CacheMixin):
     _supports_gradient_checkpointing = True
 
     """
@@ -65,6 +67,8 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
             The number of frames in the video-like data.
     """
 
+    _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
+
     @register_to_config
     def __init__(
         self,
@@ -162,9 +166,6 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
 
         self.gradient_checkpointing = False
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        self.gradient_checkpointing = value
-
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -226,20 +227,24 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
         # Prepare text embeddings for spatial block
         # batch_size num_tokens hidden_size -> (batch_size * num_frame) num_tokens hidden_size
         encoder_hidden_states = self.caption_projection(encoder_hidden_states)  # 3 120 1152
-        encoder_hidden_states_spatial = encoder_hidden_states.repeat_interleave(num_frame, dim=0).view(
-            -1, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1]
-        )
+        encoder_hidden_states_spatial = encoder_hidden_states.repeat_interleave(
+            num_frame, dim=0, output_size=encoder_hidden_states.shape[0] * num_frame
+        ).view(-1, encoder_hidden_states.shape[-2], encoder_hidden_states.shape[-1])
 
         # Prepare timesteps for spatial and temporal block
-        timestep_spatial = timestep.repeat_interleave(num_frame, dim=0).view(-1, timestep.shape[-1])
-        timestep_temp = timestep.repeat_interleave(num_patches, dim=0).view(-1, timestep.shape[-1])
+        timestep_spatial = timestep.repeat_interleave(
+            num_frame, dim=0, output_size=timestep.shape[0] * num_frame
+        ).view(-1, timestep.shape[-1])
+        timestep_temp = timestep.repeat_interleave(
+            num_patches, dim=0, output_size=timestep.shape[0] * num_patches
+        ).view(-1, timestep.shape[-1])
 
         # Spatial and temporal transformer blocks
         for i, (spatial_block, temp_block) in enumerate(
             zip(self.transformer_blocks, self.temporal_transformer_blocks)
         ):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                hidden_states = torch.utils.checkpoint.checkpoint(
+                hidden_states = self._gradient_checkpointing_func(
                     spatial_block,
                     hidden_states,
                     None,  # attention_mask
@@ -248,7 +253,6 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
                     timestep_spatial,
                     None,  # cross_attention_kwargs
                     None,  # class_labels
-                    use_reentrant=False,
                 )
             else:
                 hidden_states = spatial_block(
@@ -269,10 +273,10 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
             hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1])
 
             if i == 0 and num_frame > 1:
-                hidden_states = hidden_states + self.temp_pos_embed
+                hidden_states = hidden_states + self.temp_pos_embed.to(hidden_states.dtype)
 
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                hidden_states = torch.utils.checkpoint.checkpoint(
+                hidden_states = self._gradient_checkpointing_func(
                     temp_block,
                     hidden_states,
                     None,  # attention_mask
@@ -281,7 +285,6 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
                     timestep_temp,
                     None,  # cross_attention_kwargs
                     None,  # class_labels
-                    use_reentrant=False,
                 )
             else:
                 hidden_states = temp_block(
@@ -300,7 +303,9 @@ class LatteTransformer3DModel(ModelMixin, ConfigMixin):
         ).permute(0, 2, 1, 3)
         hidden_states = hidden_states.reshape(-1, hidden_states.shape[-2], hidden_states.shape[-1])
 
-        embedded_timestep = embedded_timestep.repeat_interleave(num_frame, dim=0).view(-1, embedded_timestep.shape[-1])
+        embedded_timestep = embedded_timestep.repeat_interleave(
+            num_frame, dim=0, output_size=embedded_timestep.shape[0] * num_frame
+        ).view(-1, embedded_timestep.shape[-1])
         shift, scale = (self.scale_shift_table[None] + embedded_timestep[:, None]).chunk(2, dim=1)
         hidden_states = self.norm_out(hidden_states)
         # Modulation
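`LatteTransformer3DModel` now also inherits `CacheMixin` (backed by the new `diffusers/models/cache_utils.py`), the entry point for the caching techniques added in this release such as pyramid attention broadcast. A sketch of how that is typically wired up, assuming the `enable_cache` API and the `PyramidAttentionBroadcastConfig` export; the Latte-specific `current_timestep` callback is an assumption:

    import torch
    from diffusers import LattePipeline, PyramidAttentionBroadcastConfig

    pipe = LattePipeline.from_pretrained("maxin-cn/Latte-1", torch_dtype=torch.float16)
    # Skip recomputing spatial attention on nearby timesteps and broadcast the
    # cached result instead, within the configured timestep window.
    config = PyramidAttentionBroadcastConfig(
        spatial_attention_block_skip_range=2,
        spatial_attention_timestep_skip_range=(100, 800),
        current_timestep_callback=lambda: pipe.current_timestep,
    )
    pipe.transformer.enable_cache(config)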
diffusers/models/transformers/lumina_nextdit2d.py
@@ -98,7 +98,7 @@ class LuminaNextDiTBlock(nn.Module):
 
         self.feed_forward = LuminaFeedForward(
             dim=dim,
-            inner_dim=4 * dim,
+            inner_dim=int(4 * 2 * dim / 3),
             multiple_of=multiple_of,
             ffn_dim_multiplier=ffn_dim_multiplier,
         )
@@ -221,6 +221,8 @@ class LuminaNextDiT2DModel(ModelMixin, ConfigMixin):
             overall scale of the model's operations.
     """
 
+    _skip_layerwise_casting_patterns = ["patch_embedder", "norm", "ffn_norm"]
+
     @register_to_config
     def __init__(
         self,
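The `inner_dim` fix restores the LLaMA-style SwiGLU sizing: a gated feed-forward has three weight matrices rather than two, so the nominal 4*dim hidden width is scaled by 2/3 to keep the parameter count roughly unchanged. A hypothetical helper showing the arithmetic, including the `multiple_of` rounding that `LuminaFeedForward` applies internally:

    def swiglu_inner_dim(dim: int, multiple_of: int = 256) -> int:
        # Gate, up, and down projections instead of two matrices, so shrink
        # the nominal 4*dim width by 2/3 to keep parameters comparable.
        inner = int(4 * 2 * dim / 3)
        # Round up to a hardware-friendly multiple.
        return multiple_of * ((inner + multiple_of - 1) // multiple_of)

    print(swiglu_inner_dim(2304))  # 6144 for Lumina-Next's 2304 hidden size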
diffusers/models/transformers/pixart_transformer_2d.py
@@ -17,7 +17,7 @@ import torch
 from torch import nn
 
 from ...configuration_utils import ConfigMixin, register_to_config
-from ...utils import is_torch_version, logging
+from ...utils import logging
 from ..attention import BasicTransformerBlock
 from ..attention_processor import Attention, AttentionProcessor, AttnProcessor, FusedAttnProcessor2_0
 from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
@@ -79,6 +79,7 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
 
     _supports_gradient_checkpointing = True
     _no_split_modules = ["BasicTransformerBlock", "PatchEmbed"]
+    _skip_layerwise_casting_patterns = ["pos_embed", "norm", "adaln_single"]
 
     @register_to_config
     def __init__(
@@ -183,10 +184,6 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
             in_features=self.config.caption_channels, hidden_size=self.inner_dim
         )
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
-
     @property
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
     def attn_processors(self) -> Dict[str, AttentionProcessor]:
@@ -387,19 +384,8 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
         # 2. Blocks
         for block in self.transformer_blocks:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     attention_mask,
                     encoder_hidden_states,
@@ -407,7 +393,6 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
                     timestep,
                     cross_attention_kwargs,
                     None,
-                    **ckpt_kwargs,
                 )
             else:
                 hidden_states = block(
diffusers/models/transformers/prior_transformer.py
@@ -353,7 +353,11 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Pef
             attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
             attention_mask = F.pad(attention_mask, (0, self.additional_embeddings), value=0.0)
             attention_mask = (attention_mask[:, None, :] + self.causal_attention_mask).to(hidden_states.dtype)
-            attention_mask = attention_mask.repeat_interleave(self.config.num_attention_heads, dim=0)
+            attention_mask = attention_mask.repeat_interleave(
+                self.config.num_attention_heads,
+                dim=0,
+                output_size=attention_mask.shape[0] * self.config.num_attention_heads,
+            )
 
         if self.norm_in is not None:
             hidden_states = self.norm_in(hidden_states)
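This change, like the Latte ones above, passes `output_size` to `repeat_interleave`. Supplying the output length up front spares PyTorch from deriving it from `repeats`, which otherwise forces a device-to-host synchronization on CUDA and interferes with graph capture. A quick standalone check:

    import torch

    mask = torch.randn(2, 77, 77)
    heads = 8
    # Identical result to repeat_interleave(heads, dim=0); output_size just
    # tells PyTorch the final length so it never has to compute it.
    out = mask.repeat_interleave(heads, dim=0, output_size=mask.shape[0] * heads)
    assert out.shape == (16, 77, 77)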
diffusers/models/transformers/sana_transformer.py
@@ -15,18 +15,18 @@
 from typing import Any, Dict, Optional, Tuple, Union
 
 import torch
+import torch.nn.functional as F
 from torch import nn
 
 from ...configuration_utils import ConfigMixin, register_to_config
-from ...loaders import PeftAdapterMixin
-from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ..attention_processor import (
     Attention,
     AttentionProcessor,
-    AttnProcessor2_0,
     SanaLinearAttnProcessor2_0,
 )
-from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
+from ..embeddings import PatchEmbed, PixArtAlphaTextProjection, TimestepEmbedding, Timesteps
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNormSingle, RMSNorm
@@ -82,6 +82,109 @@ class GLUMBConv(nn.Module):
         return hidden_states
 
 
+class SanaModulatedNorm(nn.Module):
+    def __init__(self, dim: int, elementwise_affine: bool = False, eps: float = 1e-6):
+        super().__init__()
+        self.norm = nn.LayerNorm(dim, elementwise_affine=elementwise_affine, eps=eps)
+
+    def forward(
+        self, hidden_states: torch.Tensor, temb: torch.Tensor, scale_shift_table: torch.Tensor
+    ) -> torch.Tensor:
+        hidden_states = self.norm(hidden_states)
+        shift, scale = (scale_shift_table[None] + temb[:, None].to(scale_shift_table.device)).chunk(2, dim=1)
+        hidden_states = hidden_states * (1 + scale) + shift
+        return hidden_states
+
+
+class SanaCombinedTimestepGuidanceEmbeddings(nn.Module):
+    def __init__(self, embedding_dim):
+        super().__init__()
+        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+        self.guidance_condition_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+
+        self.silu = nn.SiLU()
+        self.linear = nn.Linear(embedding_dim, 6 * embedding_dim, bias=True)
+
+    def forward(self, timestep: torch.Tensor, guidance: torch.Tensor = None, hidden_dtype: torch.dtype = None):
+        timesteps_proj = self.time_proj(timestep)
+        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=hidden_dtype))  # (N, D)
+
+        guidance_proj = self.guidance_condition_proj(guidance)
+        guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=hidden_dtype))
+        conditioning = timesteps_emb + guidance_emb
+
+        return self.linear(self.silu(conditioning)), conditioning
+
+
+class SanaAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention (enabled by default if you're using PyTorch 2.0).
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("SanaAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        batch_size, sequence_length, _ = (
+            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+        )
+
+        if attention_mask is not None:
+            attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+            # scaled_dot_product_attention expects attention_mask shape to be
+            # (batch, heads, source_length, target_length)
+            attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+        query = attn.to_q(hidden_states)
+
+        if encoder_hidden_states is None:
+            encoder_hidden_states = hidden_states
+
+        key = attn.to_k(encoder_hidden_states)
+        value = attn.to_v(encoder_hidden_states)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        # the output of sdp = (batch, num_heads, seq_len, head_dim)
+        # TODO: add support for attn.scale when we move to Torch 2.1
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+
+        hidden_states = hidden_states / attn.rescale_output_factor
+
+        return hidden_states
+
+
 class SanaTransformerBlock(nn.Module):
     r"""
     Transformer block introduced in [Sana](https://huggingface.co/papers/2410.10629).
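The new `SanaCombinedTimestepGuidanceEmbeddings` (selected when `guidance_embeds=True`, see below) embeds the timestep and a per-sample guidance scale separately, sums the two, and projects the result to the 6*dim modulation vector. A shape check against the class as added above; the import path is an assumption:

    import torch
    from diffusers.models.transformers.sana_transformer import SanaCombinedTimestepGuidanceEmbeddings

    emb = SanaCombinedTimestepGuidanceEmbeddings(embedding_dim=1152)
    t = torch.randint(0, 1000, (2,))  # per-sample timesteps
    g = torch.full((2,), 4.5)         # per-sample (distilled) guidance scale
    mod, cond = emb(t, guidance=g, hidden_dtype=torch.float32)
    print(mod.shape, cond.shape)  # torch.Size([2, 6912]) torch.Size([2, 1152])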
diffusers/models/transformers/sana_transformer.py (continued)
@@ -101,6 +204,7 @@ class SanaTransformerBlock(nn.Module):
         norm_eps: float = 1e-6,
         attention_out_bias: bool = True,
         mlp_ratio: float = 2.5,
+        qk_norm: Optional[str] = None,
     ) -> None:
         super().__init__()
 
@@ -110,6 +214,8 @@ class SanaTransformerBlock(nn.Module):
             query_dim=dim,
             heads=num_attention_heads,
             dim_head=attention_head_dim,
+            kv_heads=num_attention_heads if qk_norm is not None else None,
+            qk_norm=qk_norm,
             dropout=dropout,
             bias=attention_bias,
             cross_attention_dim=None,
@@ -121,13 +227,15 @@ class SanaTransformerBlock(nn.Module):
         self.norm2 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)
         self.attn2 = Attention(
             query_dim=dim,
+            qk_norm=qk_norm,
+            kv_heads=num_cross_attention_heads if qk_norm is not None else None,
             cross_attention_dim=cross_attention_dim,
             heads=num_cross_attention_heads,
             dim_head=cross_attention_head_dim,
             dropout=dropout,
             bias=True,
             out_bias=attention_out_bias,
-            processor=AttnProcessor2_0(),
+            processor=SanaAttnProcessor2_0(),
         )
 
         # 3. Feed-forward
@@ -181,7 +289,7 @@ class SanaTransformerBlock(nn.Module):
         return hidden_states
 
 
-class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
+class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
     r"""
     A 2D Transformer model introduced in [Sana](https://huggingface.co/papers/2410.10629) family of models.
 
@@ -218,10 +326,15 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
             Whether to use elementwise affinity in the normalization layer.
         norm_eps (`float`, defaults to `1e-6`):
             The epsilon value for the normalization layer.
+        qk_norm (`str`, *optional*, defaults to `None`):
+            The normalization to use for the query and key.
+        timestep_scale (`float`, defaults to `1.0`):
+            The scale to use for the timesteps.
     """
 
     _supports_gradient_checkpointing = True
-    _no_split_modules = ["SanaTransformerBlock", "PatchEmbed"]
+    _no_split_modules = ["SanaTransformerBlock", "PatchEmbed", "SanaModulatedNorm"]
+    _skip_layerwise_casting_patterns = ["patch_embed", "norm"]
 
     @register_to_config
     def __init__(
@@ -243,6 +356,10 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         norm_elementwise_affine: bool = False,
         norm_eps: float = 1e-6,
         interpolation_scale: Optional[int] = None,
+        guidance_embeds: bool = False,
+        guidance_embeds_scale: float = 0.1,
+        qk_norm: Optional[str] = None,
+        timestep_scale: float = 1.0,
     ) -> None:
         super().__init__()
 
@@ -250,7 +367,6 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         inner_dim = num_attention_heads * attention_head_dim
 
         # 1. Patch Embedding
-        interpolation_scale = interpolation_scale if interpolation_scale is not None else max(sample_size // 64, 1)
         self.patch_embed = PatchEmbed(
             height=sample_size,
             width=sample_size,
@@ -258,10 +374,14 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
             in_channels=in_channels,
             embed_dim=inner_dim,
             interpolation_scale=interpolation_scale,
+            pos_embed_type="sincos" if interpolation_scale is not None else None,
         )
 
         # 2. Additional condition embeddings
-        self.time_embed = AdaLayerNormSingle(inner_dim)
+        if guidance_embeds:
+            self.time_embed = SanaCombinedTimestepGuidanceEmbeddings(inner_dim)
+        else:
+            self.time_embed = AdaLayerNormSingle(inner_dim)
 
         self.caption_projection = PixArtAlphaTextProjection(in_features=caption_channels, hidden_size=inner_dim)
         self.caption_norm = RMSNorm(inner_dim, eps=1e-5, elementwise_affine=True)
@@ -281,6 +401,7 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                     norm_elementwise_affine=norm_elementwise_affine,
                     norm_eps=norm_eps,
                     mlp_ratio=mlp_ratio,
+                    qk_norm=qk_norm,
                 )
                 for _ in range(num_layers)
             ]
@@ -288,16 +409,11 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
 
         # 4. Output blocks
         self.scale_shift_table = nn.Parameter(torch.randn(2, inner_dim) / inner_dim**0.5)
-
-        self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
+        self.norm_out = SanaModulatedNorm(inner_dim, elementwise_affine=False, eps=1e-6)
         self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
 
         self.gradient_checkpointing = False
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
-
     @property
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
     def attn_processors(self) -> Dict[str, AttentionProcessor]:
@@ -362,7 +478,8 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         self,
         hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
-        timestep: torch.LongTensor,
+        timestep: torch.Tensor,
+        guidance: Optional[torch.Tensor] = None,
         encoder_attention_mask: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.Tensor] = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
@@ -413,9 +530,14 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
 
         hidden_states = self.patch_embed(hidden_states)
 
-        timestep, embedded_timestep = self.time_embed(
-            timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
-        )
+        if guidance is not None:
+            timestep, embedded_timestep = self.time_embed(
+                timestep, guidance=guidance, hidden_dtype=hidden_states.dtype
+            )
+        else:
+            timestep, embedded_timestep = self.time_embed(
+                timestep, batch_size=batch_size, hidden_dtype=hidden_states.dtype
+            )
 
         encoder_hidden_states = self.caption_projection(encoder_hidden_states)
         encoder_hidden_states = encoder_hidden_states.view(batch_size, -1, hidden_states.shape[-1])
@@ -424,21 +546,9 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
 
         # 2. Transformer blocks
         if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-            def create_custom_forward(module, return_dict=None):
-                def custom_forward(*inputs):
-                    if return_dict is not None:
-                        return module(*inputs, return_dict=return_dict)
-                    else:
-                        return module(*inputs)
-
-                return custom_forward
-
-            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-
             for block in self.transformer_blocks:
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     attention_mask,
                     encoder_hidden_states,
@@ -446,7 +556,6 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
                     timestep,
                     post_patch_height,
                     post_patch_width,
-                    **ckpt_kwargs,
                 )
 
         else:
@@ -462,13 +571,8 @@ class SanaTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
             )
 
         # 3. Normalization
-        shift, scale = (
-            self.scale_shift_table[None] + embedded_timestep[:, None].to(self.scale_shift_table.device)
-        ).chunk(2, dim=1)
-        hidden_states = self.norm_out(hidden_states)
+        hidden_states = self.norm_out(hidden_states, embedded_timestep, self.scale_shift_table)
 
-        # 4. Modulation
-        hidden_states = hidden_states * (1 + scale) + shift
         hidden_states = self.proj_out(hidden_states)
 
         # 5. Unpatchify
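Several models in this diff declare `_supports_group_offloading = False` (DiT and HunyuanDiT above); models without that opt-out can stream groups of blocks between devices via the new `diffusers/hooks/group_offloading.py`. A hedged sketch, assuming the `enable_group_offload` method added to `ModelMixin` in this release:

    import torch
    from diffusers import SanaTransformer2DModel

    transformer = SanaTransformer2DModel.from_pretrained(
        "Efficient-Large-Model/Sana_1600M_1024px_diffusers",
        subfolder="transformer",
        torch_dtype=torch.bfloat16,
    )
    # Keep only a couple of transformer blocks resident on the GPU at a time;
    # the remainder waits on the CPU and is onloaded group by group.
    transformer.enable_group_offload(
        onload_device=torch.device("cuda"),
        offload_device=torch.device("cpu"),
        offload_type="block_level",
        num_blocks_per_group=2,
    )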
diffusers/models/transformers/stable_audio_transformer.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 
-from typing import Any, Dict, Optional, Union
+from typing import Dict, Optional, Union
 
 import numpy as np
 import torch
@@ -29,7 +29,7 @@ from ...models.attention_processor import (
 )
 from ...models.modeling_utils import ModelMixin
 from ...models.transformers.transformer_2d import Transformer2DModelOutput
-from ...utils import is_torch_version, logging
+from ...utils import logging
 from ...utils.torch_utils import maybe_allow_in_graph
 
 
@@ -211,6 +211,7 @@ class StableAudioDiTModel(ModelMixin, ConfigMixin):
     """
 
     _supports_gradient_checkpointing = True
+    _skip_layerwise_casting_patterns = ["preprocess_conv", "postprocess_conv", "^proj_in$", "^proj_out$", "norm"]
 
     @register_to_config
     def __init__(
@@ -345,10 +346,6 @@ class StableAudioDiTModel(ModelMixin, ConfigMixin):
         """
         self.set_attn_processor(StableAudioAttnProcessor2_0())
 
-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
-
     def forward(
         self,
         hidden_states: torch.FloatTensor,
@@ -415,25 +412,13 @@ class StableAudioDiTModel(ModelMixin, ConfigMixin):
 
         for block in self.transformer_blocks:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
+                hidden_states = self._gradient_checkpointing_func(
+                    block,
                     hidden_states,
                     attention_mask,
                     cross_attention_hidden_states,
                     encoder_attention_mask,
                     rotary_embedding,
-                    **ckpt_kwargs,
                 )
 
             else:
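Note the anchored entries in the skip list above (`"^proj_in$"`, `"^proj_out$"`): the patterns behave as regular expressions over module names, so anchors pin exact-name modules while a bare substring such as `"norm"` catches every norm layer. A small illustration of that matching rule (assuming `re.search` semantics, which the anchors imply):

    import re

    patterns = ["preprocess_conv", "postprocess_conv", "^proj_in$", "^proj_out$", "norm"]
    for name in ["proj_in", "transformer_blocks.0.norm1", "transformer_blocks.0.proj_in"]:
        skipped = any(re.search(p, name) for p in patterns)
        print(name, "->", "skip casting" if skipped else "cast")
    # proj_in -> skip casting
    # transformer_blocks.0.norm1 -> skip casting
    # transformer_blocks.0.proj_in -> cast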