diffusers 0.32.2__py3-none-any.whl → 0.33.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (389) hide show
  1. diffusers/__init__.py +186 -3
  2. diffusers/configuration_utils.py +40 -12
  3. diffusers/dependency_versions_table.py +9 -2
  4. diffusers/hooks/__init__.py +9 -0
  5. diffusers/hooks/faster_cache.py +653 -0
  6. diffusers/hooks/group_offloading.py +793 -0
  7. diffusers/hooks/hooks.py +236 -0
  8. diffusers/hooks/layerwise_casting.py +245 -0
  9. diffusers/hooks/pyramid_attention_broadcast.py +311 -0
  10. diffusers/loaders/__init__.py +6 -0
  11. diffusers/loaders/ip_adapter.py +38 -30
  12. diffusers/loaders/lora_base.py +121 -86
  13. diffusers/loaders/lora_conversion_utils.py +504 -44
  14. diffusers/loaders/lora_pipeline.py +1769 -181
  15. diffusers/loaders/peft.py +167 -57
  16. diffusers/loaders/single_file.py +17 -2
  17. diffusers/loaders/single_file_model.py +53 -5
  18. diffusers/loaders/single_file_utils.py +646 -72
  19. diffusers/loaders/textual_inversion.py +9 -9
  20. diffusers/loaders/transformer_flux.py +8 -9
  21. diffusers/loaders/transformer_sd3.py +120 -39
  22. diffusers/loaders/unet.py +20 -7
  23. diffusers/models/__init__.py +22 -0
  24. diffusers/models/activations.py +9 -9
  25. diffusers/models/attention.py +0 -1
  26. diffusers/models/attention_processor.py +163 -25
  27. diffusers/models/auto_model.py +169 -0
  28. diffusers/models/autoencoders/__init__.py +2 -0
  29. diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
  30. diffusers/models/autoencoders/autoencoder_dc.py +106 -4
  31. diffusers/models/autoencoders/autoencoder_kl.py +0 -4
  32. diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
  33. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
  34. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
  35. diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
  36. diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
  37. diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
  38. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
  39. diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
  40. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
  41. diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
  42. diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
  43. diffusers/models/autoencoders/vae.py +31 -141
  44. diffusers/models/autoencoders/vq_model.py +3 -0
  45. diffusers/models/cache_utils.py +108 -0
  46. diffusers/models/controlnets/__init__.py +1 -0
  47. diffusers/models/controlnets/controlnet.py +3 -8
  48. diffusers/models/controlnets/controlnet_flux.py +14 -42
  49. diffusers/models/controlnets/controlnet_sd3.py +58 -34
  50. diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
  51. diffusers/models/controlnets/controlnet_union.py +27 -18
  52. diffusers/models/controlnets/controlnet_xs.py +7 -46
  53. diffusers/models/controlnets/multicontrolnet_union.py +196 -0
  54. diffusers/models/embeddings.py +18 -7
  55. diffusers/models/model_loading_utils.py +122 -80
  56. diffusers/models/modeling_flax_pytorch_utils.py +1 -1
  57. diffusers/models/modeling_flax_utils.py +1 -1
  58. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  59. diffusers/models/modeling_utils.py +617 -272
  60. diffusers/models/normalization.py +67 -14
  61. diffusers/models/resnet.py +1 -1
  62. diffusers/models/transformers/__init__.py +6 -0
  63. diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
  64. diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
  65. diffusers/models/transformers/consisid_transformer_3d.py +789 -0
  66. diffusers/models/transformers/dit_transformer_2d.py +5 -19
  67. diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
  68. diffusers/models/transformers/latte_transformer_3d.py +20 -15
  69. diffusers/models/transformers/lumina_nextdit2d.py +3 -1
  70. diffusers/models/transformers/pixart_transformer_2d.py +4 -19
  71. diffusers/models/transformers/prior_transformer.py +5 -1
  72. diffusers/models/transformers/sana_transformer.py +144 -40
  73. diffusers/models/transformers/stable_audio_transformer.py +5 -20
  74. diffusers/models/transformers/transformer_2d.py +7 -22
  75. diffusers/models/transformers/transformer_allegro.py +9 -17
  76. diffusers/models/transformers/transformer_cogview3plus.py +6 -17
  77. diffusers/models/transformers/transformer_cogview4.py +462 -0
  78. diffusers/models/transformers/transformer_easyanimate.py +527 -0
  79. diffusers/models/transformers/transformer_flux.py +68 -110
  80. diffusers/models/transformers/transformer_hunyuan_video.py +404 -46
  81. diffusers/models/transformers/transformer_ltx.py +53 -35
  82. diffusers/models/transformers/transformer_lumina2.py +548 -0
  83. diffusers/models/transformers/transformer_mochi.py +6 -17
  84. diffusers/models/transformers/transformer_omnigen.py +469 -0
  85. diffusers/models/transformers/transformer_sd3.py +56 -86
  86. diffusers/models/transformers/transformer_temporal.py +5 -11
  87. diffusers/models/transformers/transformer_wan.py +469 -0
  88. diffusers/models/unets/unet_1d.py +3 -1
  89. diffusers/models/unets/unet_2d.py +21 -20
  90. diffusers/models/unets/unet_2d_blocks.py +19 -243
  91. diffusers/models/unets/unet_2d_condition.py +4 -6
  92. diffusers/models/unets/unet_3d_blocks.py +14 -127
  93. diffusers/models/unets/unet_3d_condition.py +8 -12
  94. diffusers/models/unets/unet_i2vgen_xl.py +5 -13
  95. diffusers/models/unets/unet_kandinsky3.py +0 -4
  96. diffusers/models/unets/unet_motion_model.py +20 -114
  97. diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
  98. diffusers/models/unets/unet_stable_cascade.py +8 -35
  99. diffusers/models/unets/uvit_2d.py +1 -4
  100. diffusers/optimization.py +2 -2
  101. diffusers/pipelines/__init__.py +57 -8
  102. diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
  103. diffusers/pipelines/amused/pipeline_amused.py +15 -2
  104. diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
  105. diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
  106. diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
  107. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
  108. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
  109. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
  110. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
  111. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
  112. diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
  113. diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
  114. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
  115. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
  116. diffusers/pipelines/auto_pipeline.py +35 -14
  117. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  118. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
  119. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
  120. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
  121. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
  122. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
  123. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
  124. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
  125. diffusers/pipelines/cogview4/__init__.py +49 -0
  126. diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
  127. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
  128. diffusers/pipelines/cogview4/pipeline_output.py +21 -0
  129. diffusers/pipelines/consisid/__init__.py +49 -0
  130. diffusers/pipelines/consisid/consisid_utils.py +357 -0
  131. diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
  132. diffusers/pipelines/consisid/pipeline_output.py +20 -0
  133. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
  134. diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
  135. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
  136. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
  137. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
  138. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
  139. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
  140. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
  141. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
  142. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
  143. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
  144. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
  145. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
  146. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
  147. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
  148. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
  149. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
  150. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
  151. diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
  152. diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
  153. diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
  154. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
  155. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
  156. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
  157. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
  158. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
  159. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
  160. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
  161. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
  162. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
  163. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
  164. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
  165. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
  166. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
  167. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
  168. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
  169. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
  170. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
  171. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
  172. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
  173. diffusers/pipelines/dit/pipeline_dit.py +15 -2
  174. diffusers/pipelines/easyanimate/__init__.py +52 -0
  175. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
  176. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
  177. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
  178. diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
  179. diffusers/pipelines/flux/pipeline_flux.py +53 -21
  180. diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
  181. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
  182. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
  183. diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
  184. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
  185. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
  186. diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
  187. diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
  188. diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
  189. diffusers/pipelines/free_noise_utils.py +3 -3
  190. diffusers/pipelines/hunyuan_video/__init__.py +4 -0
  191. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
  192. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
  193. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
  194. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
  195. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
  196. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
  197. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
  198. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
  199. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
  200. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
  201. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
  202. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
  203. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
  204. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
  205. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
  206. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
  207. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
  208. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
  209. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
  210. diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
  211. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
  212. diffusers/pipelines/kolors/text_encoder.py +7 -34
  213. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
  214. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
  215. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
  216. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
  217. diffusers/pipelines/latte/pipeline_latte.py +36 -7
  218. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
  219. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
  220. diffusers/pipelines/ltx/__init__.py +2 -0
  221. diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
  222. diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
  223. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
  224. diffusers/pipelines/lumina/__init__.py +2 -2
  225. diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
  226. diffusers/pipelines/lumina2/__init__.py +48 -0
  227. diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
  228. diffusers/pipelines/marigold/__init__.py +2 -0
  229. diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
  230. diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
  231. diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
  232. diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
  233. diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
  234. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
  235. diffusers/pipelines/omnigen/__init__.py +50 -0
  236. diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
  237. diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
  238. diffusers/pipelines/onnx_utils.py +5 -3
  239. diffusers/pipelines/pag/pag_utils.py +1 -1
  240. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
  241. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
  242. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
  243. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
  244. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
  245. diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
  246. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
  247. diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
  248. diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
  249. diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
  250. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
  251. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
  252. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
  253. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
  254. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
  255. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
  256. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
  257. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
  258. diffusers/pipelines/pia/pipeline_pia.py +13 -1
  259. diffusers/pipelines/pipeline_flax_utils.py +7 -7
  260. diffusers/pipelines/pipeline_loading_utils.py +193 -83
  261. diffusers/pipelines/pipeline_utils.py +221 -106
  262. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
  263. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
  264. diffusers/pipelines/sana/__init__.py +2 -0
  265. diffusers/pipelines/sana/pipeline_sana.py +183 -58
  266. diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
  267. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
  268. diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
  269. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
  270. diffusers/pipelines/shap_e/renderer.py +6 -6
  271. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
  272. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
  273. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
  274. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
  275. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
  276. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
  277. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
  278. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
  279. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  280. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
  281. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
  282. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
  283. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
  284. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
  285. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
  286. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
  287. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
  288. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
  289. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
  290. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
  291. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
  292. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
  293. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
  294. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
  295. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
  296. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
  297. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
  298. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
  299. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
  300. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
  301. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
  302. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
  303. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
  304. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
  305. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
  306. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  307. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
  308. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
  309. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
  310. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
  311. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
  312. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
  313. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
  314. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
  315. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
  316. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
  317. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
  318. diffusers/pipelines/transformers_loading_utils.py +121 -0
  319. diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
  320. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
  321. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
  322. diffusers/pipelines/wan/__init__.py +51 -0
  323. diffusers/pipelines/wan/pipeline_output.py +20 -0
  324. diffusers/pipelines/wan/pipeline_wan.py +593 -0
  325. diffusers/pipelines/wan/pipeline_wan_i2v.py +722 -0
  326. diffusers/pipelines/wan/pipeline_wan_video2video.py +725 -0
  327. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
  328. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
  329. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
  330. diffusers/quantizers/auto.py +5 -1
  331. diffusers/quantizers/base.py +5 -9
  332. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
  333. diffusers/quantizers/bitsandbytes/utils.py +30 -20
  334. diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
  335. diffusers/quantizers/gguf/utils.py +4 -2
  336. diffusers/quantizers/quantization_config.py +59 -4
  337. diffusers/quantizers/quanto/__init__.py +1 -0
  338. diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
  339. diffusers/quantizers/quanto/utils.py +60 -0
  340. diffusers/quantizers/torchao/__init__.py +1 -1
  341. diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
  342. diffusers/schedulers/__init__.py +2 -1
  343. diffusers/schedulers/scheduling_consistency_models.py +1 -2
  344. diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
  345. diffusers/schedulers/scheduling_ddpm.py +2 -3
  346. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
  347. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
  348. diffusers/schedulers/scheduling_edm_euler.py +45 -10
  349. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
  350. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
  351. diffusers/schedulers/scheduling_heun_discrete.py +1 -1
  352. diffusers/schedulers/scheduling_lcm.py +1 -2
  353. diffusers/schedulers/scheduling_lms_discrete.py +1 -1
  354. diffusers/schedulers/scheduling_repaint.py +5 -1
  355. diffusers/schedulers/scheduling_scm.py +265 -0
  356. diffusers/schedulers/scheduling_tcd.py +1 -2
  357. diffusers/schedulers/scheduling_utils.py +2 -1
  358. diffusers/training_utils.py +14 -7
  359. diffusers/utils/__init__.py +9 -1
  360. diffusers/utils/constants.py +13 -1
  361. diffusers/utils/deprecation_utils.py +1 -1
  362. diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
  363. diffusers/utils/dummy_gguf_objects.py +17 -0
  364. diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
  365. diffusers/utils/dummy_pt_objects.py +233 -0
  366. diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
  367. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  368. diffusers/utils/dummy_torchao_objects.py +17 -0
  369. diffusers/utils/dynamic_modules_utils.py +1 -1
  370. diffusers/utils/export_utils.py +28 -3
  371. diffusers/utils/hub_utils.py +52 -102
  372. diffusers/utils/import_utils.py +121 -221
  373. diffusers/utils/loading_utils.py +2 -1
  374. diffusers/utils/logging.py +1 -2
  375. diffusers/utils/peft_utils.py +6 -14
  376. diffusers/utils/remote_utils.py +425 -0
  377. diffusers/utils/source_code_parsing_utils.py +52 -0
  378. diffusers/utils/state_dict_utils.py +15 -1
  379. diffusers/utils/testing_utils.py +243 -13
  380. diffusers/utils/torch_utils.py +10 -0
  381. diffusers/utils/typing_utils.py +91 -0
  382. diffusers/video_processor.py +1 -1
  383. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/METADATA +76 -44
  384. diffusers-0.33.0.dist-info/RECORD +608 -0
  385. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/WHEEL +1 -1
  386. diffusers-0.32.2.dist-info/RECORD +0 -550
  387. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/LICENSE +0 -0
  388. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/entry_points.txt +0 -0
  389. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,236 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import functools
16
+ from typing import Any, Dict, Optional, Tuple
17
+
18
+ import torch
19
+
20
+ from ..utils.logging import get_logger
21
+
22
+
23
+ logger = get_logger(__name__) # pylint: disable=invalid-name
24
+
25
+
26
class ModelHook:
    r"""
    Base class for callbacks that run right before and right after the forward method of a model.

    Every method defined here is a pass-through that hands back what it received. Subclasses override only the
    callbacks they need.
    """

    # Subclasses that keep state between forward calls set this to `True` and override `reset_state`.
    _is_stateful = False

    def __init__(self):
        # Starts out empty; `HookRegistry.register_hook` assigns the `HookFunctionReference` it builds for this hook.
        self.fn_ref: "HookFunctionReference" = None

    def initialize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        r"""
        Callback invoked by `HookRegistry.register_hook` when the hook is attached to a module.

        Args:
            module (`torch.nn.Module`):
                The module this hook is being attached to.

        Returns:
            `torch.nn.Module`: The module the registry should keep using (unchanged by default).
        """
        return module

    def deinitalize_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        r"""
        Callback invoked by `HookRegistry.remove_hook` when the hook is removed from a module.

        Args:
            module (`torch.nn.Module`):
                The module this hook was attached to.

        Returns:
            `torch.nn.Module`: The module the registry should keep using (unchanged by default).
        """
        return module

    def pre_forward(self, module: torch.nn.Module, *args, **kwargs) -> Tuple[Tuple[Any], Dict[str, Any]]:
        r"""
        Callback executed just before the forward method of the model.

        Args:
            module (`torch.nn.Module`):
                The module whose forward pass will be executed right after this callback.
            args (`Tuple[Any]`):
                The positional arguments passed to the module.
            kwargs (`Dict[str, Any]`):
                The keyword arguments passed to the module.

        Returns:
            `Tuple[Tuple[Any], Dict[str, Any]]`:
                The (possibly rewritten) `args` and `kwargs` to run the forward pass with. The default
                implementation returns them untouched.
        """
        forwarded = (args, kwargs)
        return forwarded

    def post_forward(self, module: torch.nn.Module, output: Any) -> Any:
        r"""
        Callback executed just after the forward method of the model.

        Args:
            module (`torch.nn.Module`):
                The module whose forward pass has just been executed.
            output (`Any`):
                The output of the module.

        Returns:
            `Any`: The processed `output` (returned unchanged by default).
        """
        return output

    def detach_hook(self, module: torch.nn.Module) -> torch.nn.Module:
        r"""
        Callback executed when the hook is detached from a module.

        Args:
            module (`torch.nn.Module`):
                The module detached from this hook.

        Returns:
            `torch.nn.Module`: The module that was passed in (unchanged by default).
        """
        return module

    def reset_state(self, module: torch.nn.Module):
        r"""
        Reset any state the hook keeps between forward calls.

        Stateless hooks (`_is_stateful = False`) have nothing to reset and simply return `module`. A hook that
        declares itself stateful must override this method.

        Raises:
            `NotImplementedError`: If `_is_stateful` is `True` and the subclass did not override this method.
        """
        if not self._is_stateful:
            return module
        raise NotImplementedError("This hook is stateful and needs to implement the `reset_state` method.")
101
+
102
+
103
class HookFunctionReference:
    r"""
    A container class that maintains mutable references to forward pass functions in a hook chain.

    Its mutable nature allows the hook system to modify the execution chain dynamically without rebuilding the
    entire forward pass structure. When a hook is removed, only the relevant references need to be updated, which
    preserves the execution order of the remaining hooks.

    Attributes:
        pre_forward: A callable that processes inputs before the main forward pass.
        post_forward: A callable that processes outputs after the main forward pass.
        forward: The current forward function in the hook chain. For a hook that provides `new_forward`, this is
            that `new_forward` bound to the module.
        original_forward: The forward function that was in place before the hook was registered. Only set when
            the hook provides a custom `new_forward`; `None` otherwise.
    """

    def __init__(self) -> None:
        self.pre_forward = None
        self.post_forward = None
        self.forward = None
        self.original_forward = None


class HookRegistry:
    r"""
    Keeps track of the hooks attached to one module and of the wrapped `forward` chain they form.

    Each registered hook wraps whatever `forward` the module had at registration time, so hooks registered later
    run outermost. `check_if_exists_or_initialize` stores the registry on the module as `_diffusers_hook`.
    """

    def __init__(self, module_ref: torch.nn.Module) -> None:
        super().__init__()

        self.hooks: Dict[str, "ModelHook"] = {}

        self._module_ref = module_ref
        # `_hook_order[i]` is the name of the hook whose function references live in `_fn_refs[i]`.
        self._hook_order = []
        self._fn_refs = []

    def register_hook(self, hook: "ModelHook", name: str) -> None:
        r"""
        Attach `hook` to the module under `name` and wrap the module's current `forward` with it.

        Args:
            hook (`ModelHook`):
                The hook to attach. If it defines `new_forward`, that method replaces the wrapped forward call and
                the previous forward is kept in `hook.fn_ref.original_forward`.
            name (`str`):
                Unique name identifying the hook in this registry.

        Raises:
            `ValueError`: If a hook named `name` is already registered.
        """
        if name in self.hooks:
            raise ValueError(
                f"Hook with name {name} already exists in the registry. Please use a different name or "
                f"first remove the existing hook and then add a new one."
            )

        self._module_ref = hook.initialize_hook(self._module_ref)

        def create_new_forward(function_reference: "HookFunctionReference"):
            # The callables are looked up on `function_reference` at call time, so `remove_hook` can re-point
            # them later without rebuilding this wrapper.
            def new_forward(module, *args, **kwargs):
                args, kwargs = function_reference.pre_forward(module, *args, **kwargs)
                output = function_reference.forward(*args, **kwargs)
                return function_reference.post_forward(module, output)

            return new_forward

        forward = self._module_ref.forward

        fn_ref = HookFunctionReference()
        fn_ref.pre_forward = hook.pre_forward
        fn_ref.post_forward = hook.post_forward
        fn_ref.forward = forward

        if hasattr(hook, "new_forward"):
            fn_ref.original_forward = forward
            fn_ref.forward = functools.update_wrapper(
                functools.partial(hook.new_forward, self._module_ref), hook.new_forward
            )

        rewritten_forward = create_new_forward(fn_ref)
        self._module_ref.forward = functools.update_wrapper(
            functools.partial(rewritten_forward, self._module_ref), rewritten_forward
        )

        hook.fn_ref = fn_ref
        self.hooks[name] = hook
        self._hook_order.append(name)
        self._fn_refs.append(fn_ref)

    def get_hook(self, name: str) -> Optional["ModelHook"]:
        r"""Return the hook registered under `name`, or `None` if there is none."""
        return self.hooks.get(name, None)

    def remove_hook(self, name: str, recurse: bool = True) -> None:
        r"""
        Remove the hook registered under `name` and splice it out of the forward chain.

        Args:
            name (`str`):
                Name of the hook to remove. Nothing is removed from this registry if the name is unknown.
            recurse (`bool`, defaults to `True`):
                Whether to also remove same-named hooks from the registries of all submodules.
        """
        if name in self.hooks:
            num_hooks = len(self._hook_order)
            hook = self.hooks[name]
            index = self._hook_order.index(name)
            fn_ref = self._fn_refs[index]

            # The forward this hook was wrapping: for a hook with `new_forward` it is kept in
            # `original_forward`, otherwise it is the `forward` slot itself.
            old_forward = fn_ref.forward
            if fn_ref.original_forward is not None:
                old_forward = fn_ref.original_forward

            if index == num_hooks - 1:
                # Outermost hook: the module goes back to calling what this hook wrapped.
                self._module_ref.forward = old_forward
            else:
                # The next hook was wrapping the removed hook, so re-point it at `old_forward`. If that hook has
                # its own `new_forward`, its `forward` slot holds that `new_forward` and the inner chain lives in
                # `original_forward`; overwriting `forward` would silently drop the surviving hook's
                # `new_forward`.
                next_fn_ref = self._fn_refs[index + 1]
                if next_fn_ref.original_forward is not None:
                    next_fn_ref.original_forward = old_forward
                else:
                    next_fn_ref.forward = old_forward

            self._module_ref = hook.deinitalize_hook(self._module_ref)
            del self.hooks[name]
            self._hook_order.pop(index)
            self._fn_refs.pop(index)

        if recurse:
            for module_name, module in self._module_ref.named_modules():
                # The empty name is the module itself, which was handled above.
                if module_name == "":
                    continue
                if hasattr(module, "_diffusers_hook"):
                    module._diffusers_hook.remove_hook(name, recurse=False)

    def reset_stateful_hooks(self, recurse: bool = True) -> None:
        r"""
        Call `reset_state` on every stateful hook, most recently registered first.

        Args:
            recurse (`bool`, defaults to `True`):
                Whether to also reset the stateful hooks of all submodules that have a registry.
        """
        for hook_name in reversed(self._hook_order):
            hook = self.hooks[hook_name]
            if hook._is_stateful:
                hook.reset_state(self._module_ref)

        if recurse:
            for module_name, module in self._module_ref.named_modules():
                if module_name == "":
                    continue
                if hasattr(module, "_diffusers_hook"):
                    module._diffusers_hook.reset_stateful_hooks(recurse=False)

    @classmethod
    def check_if_exists_or_initialize(cls, module: torch.nn.Module) -> "HookRegistry":
        r"""Return the registry stored on `module` as `_diffusers_hook`, creating it first if it is missing."""
        if not hasattr(module, "_diffusers_hook"):
            module._diffusers_hook = cls(module)
        return module._diffusers_hook

    def __repr__(self) -> str:
        entries = []
        for i, hook_name in enumerate(self._hook_order):
            hook = self.hooks[hook_name]
            # Use the hook's own repr only if its class defines one; otherwise the class name is enough.
            if hook.__class__.__repr__ is not object.__repr__:
                hook_repr = hook.__repr__()
            else:
                hook_repr = hook.__class__.__name__
            entries.append(f" ({i}) {hook_name} - {hook_repr}")
        registry_repr = "\n".join(entries)
        return f"HookRegistry(\n{registry_repr}\n)"
@@ -0,0 +1,245 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import re
16
+ from typing import Optional, Tuple, Type, Union
17
+
18
+ import torch
19
+
20
+ from ..utils import get_logger, is_peft_available, is_peft_version
21
+ from .hooks import HookRegistry, ModelHook
22
+
23
+
24
logger = get_logger(__name__)  # pylint: disable=invalid-name


# fmt: off
# Names under which the hooks defined in this module are registered on a module's `HookRegistry`.
_LAYERWISE_CASTING_HOOK = "layerwise_casting"
_PEFT_AUTOCAST_DISABLE_HOOK = "peft_autocast_disable"
# Layer types that receive a `LayerwiseCastingHook` when `_apply_layerwise_casting` reaches them.
SUPPORTED_PYTORCH_LAYERS = (
    torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d,
    torch.nn.ConvTranspose1d, torch.nn.ConvTranspose2d, torch.nn.ConvTranspose3d,
    torch.nn.Linear,
)

# Regex patterns used when `skip_modules_pattern="auto"`. A module whose dotted name matches any of them (via
# `re.search`) is skipped by `_apply_layerwise_casting`, together with everything beneath it.
DEFAULT_SKIP_MODULES_PATTERN = ("pos_embed", "patch_embed", "norm", "^proj_in$", "^proj_out$")
# fmt: on

# The PEFT helpers below are imported only when PEFT is installed and its version is greater than 0.14.0;
# `_disable_peft_input_autocast` is a no-op otherwise.
_SHOULD_DISABLE_PEFT_INPUT_AUTOCAST = is_peft_available() and is_peft_version(">", "0.14.0")
if _SHOULD_DISABLE_PEFT_INPUT_AUTOCAST:
    from peft.helpers import disable_input_dtype_casting
    from peft.tuners.tuners_utils import BaseTunerLayer
43
+
44
+
45
class LayerwiseCastingHook(ModelHook):
    r"""
    Hook that stores a module's weights in a low precision dtype and casts them to a higher precision dtype only
    while the module's forward pass runs. Output quality may degrade as a result, but the memory footprint can be
    reduced significantly.
    """

    _is_stateful = False

    def __init__(self, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool) -> None:
        r"""
        Args:
            storage_dtype (`torch.dtype`):
                The dtype the module is cast to when the hook is initialized and after every forward pass.
            compute_dtype (`torch.dtype`):
                The dtype the module is cast to right before every forward pass.
            non_blocking (`bool`):
                Forwarded as `non_blocking` to every `module.to(...)` call made by this hook.
        """
        self.non_blocking = non_blocking
        self.compute_dtype = compute_dtype
        self.storage_dtype = storage_dtype

    def _cast_weights(self, module: torch.nn.Module, dtype: torch.dtype) -> None:
        # Single place where the hook changes the module's dtype.
        module.to(dtype=dtype, non_blocking=self.non_blocking)

    def initialize_hook(self, module: torch.nn.Module):
        r"""Cast `module` to the storage dtype and return it."""
        self._cast_weights(module, self.storage_dtype)
        return module

    def deinitalize_hook(self, module: torch.nn.Module):
        r"""Always raises `NotImplementedError`: this hook cannot be removed once applied."""
        message = (
            "LayerwiseCastingHook does not support deinitalization. A model once enabled with layerwise casting will "
            "have casted its weights to a lower precision dtype for storage. Casting this back to the original dtype "
            "will lead to precision loss, which might have an impact on the model's generation quality. The model should "
            "be re-initialized and loaded in the original dtype."
        )
        raise NotImplementedError(message)

    def pre_forward(self, module: torch.nn.Module, *args, **kwargs):
        r"""Cast `module` to the compute dtype and pass the call arguments through unchanged."""
        self._cast_weights(module, self.compute_dtype)
        return args, kwargs

    def post_forward(self, module: torch.nn.Module, output):
        r"""Cast `module` back to the storage dtype and return `output` unchanged."""
        self._cast_weights(module, self.storage_dtype)
        return output
78
+
79
+
80
class PeftInputAutocastDisableHook(ModelHook):
    r"""
    Hook that runs the wrapped forward pass with PEFT's input dtype casting disabled. By default, PEFT casts the
    inputs to the dtype of the module's weights, which can lead to precision loss.

    Why this is needed:
        - Unless the PEFT layers' weight names are added to `skip_modules_pattern` when layerwise casting is
          applied, the inputs get cast to the storage dtype, which may be of lower precision. Reference:
          https://github.com/huggingface/peft/blob/0facdebf6208139cbd8f3586875acb378813dd97/src/peft/tuners/lora/layer.py#L706
        - An alternative on our side would be something like accelerate's `send_to_device`, but for dtypes, so
          that inputs are always cast to the computation dtype correctly. However, two goals speak against it:
            1. Keeping forward implementations as independent of device/dtype casting operations as possible.
            2. Performing inference without losing information through casts between precisions. With the PEFT
               implementation linked above, and layerwise casting running with
               storage_dtype=torch.float8_e4m3fn and compute_dtype=torch.bfloat16, inputs are cast to
               torch.float8_e4m3fn inside the lora layer. They would then be upcast back to torch.bfloat16 when
               the forward pass continues in the PEFT linear forward or the Diffusers layer forward, through a
               `send_to_dtype` operation from LayerwiseCastingHook. That round trip is lossy and results in
               poorer generation quality.
    """

    def new_forward(self, module: torch.nn.Module, *args, **kwargs):
        r"""Call the original forward of `module` inside PEFT's `disable_input_dtype_casting` context."""
        with disable_input_dtype_casting(module):
            output = self.fn_ref.original_forward(*args, **kwargs)
        return output
104
+
105
+
106
def apply_layerwise_casting(
    module: torch.nn.Module,
    storage_dtype: torch.dtype,
    compute_dtype: torch.dtype,
    skip_modules_pattern: Optional[Union[str, Tuple[str, ...]]] = "auto",
    skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None,
    non_blocking: bool = False,
) -> None:
    r"""
    Applies layerwise casting to a given module. The module expected here is a Diffusers ModelMixin but it can be any
    nn.Module using diffusers layers or pytorch primitives.

    Example:

    ```python
    >>> import torch
    >>> from diffusers import CogVideoXTransformer3DModel

    >>> transformer = CogVideoXTransformer3DModel.from_pretrained(
    ...     model_id, subfolder="transformer", torch_dtype=torch.bfloat16
    ... )

    >>> apply_layerwise_casting(
    ...     transformer,
    ...     storage_dtype=torch.float8_e4m3fn,
    ...     compute_dtype=torch.bfloat16,
    ...     skip_modules_pattern=["patch_embed", "norm", "proj_out"],
    ...     non_blocking=True,
    ... )
    ```

    Args:
        module (`torch.nn.Module`):
            The module whose leaf modules will be cast to a high precision dtype for computation, and to a low
            precision dtype for storage.
        storage_dtype (`torch.dtype`):
            The dtype to cast the module to before/after the forward pass for storage.
        compute_dtype (`torch.dtype`):
            The dtype to cast the module to during the forward pass for computation.
        skip_modules_pattern (`str` or `Tuple[str, ...]`, *optional*, defaults to `"auto"`):
            Regex patterns matched (with `re.search`) against the dotted names of the modules to skip during the
            layerwise casting process. If set to `"auto"`, the default patterns are used. Any other single string
            is treated as one pattern. If set to `None`, no modules are skipped by name. If set to `None`
            alongside `skip_modules_classes` being `None`, the layerwise casting is applied directly to the module
            instead of its internal submodules.
        skip_modules_classes (`Tuple[Type[torch.nn.Module], ...]`, *optional*, defaults to `None`):
            Module classes to skip during the layerwise casting process. Other iterables of classes (e.g. a list)
            are accepted and converted to a tuple.
        non_blocking (`bool`, defaults to `False`):
            If `True`, the weight casting operations are non-blocking.
    """
    if isinstance(skip_modules_pattern, str):
        if skip_modules_pattern == "auto":
            skip_modules_pattern = DEFAULT_SKIP_MODULES_PATTERN
        else:
            # The patterns are iterated downstream; a bare string would be split into single characters, each
            # acting as its own pattern. Wrap it so it is matched as one pattern.
            skip_modules_pattern = (skip_modules_pattern,)

    if skip_modules_classes is not None and not isinstance(skip_modules_classes, (type, tuple)):
        # `isinstance(module, classes)` raises TypeError for lists and other iterables of classes.
        skip_modules_classes = tuple(skip_modules_classes)

    if skip_modules_classes is None and skip_modules_pattern is None:
        # Nothing to skip: hook the module itself rather than walking its submodules.
        apply_layerwise_casting_hook(module, storage_dtype, compute_dtype, non_blocking)
        return

    _apply_layerwise_casting(
        module,
        storage_dtype,
        compute_dtype,
        skip_modules_pattern,
        skip_modules_classes,
        non_blocking,
    )
    _disable_peft_input_autocast(module)
171
+
172
+
173
def _apply_layerwise_casting(
    module: torch.nn.Module,
    storage_dtype: torch.dtype,
    compute_dtype: torch.dtype,
    skip_modules_pattern: Optional[Tuple[str, ...]] = None,
    skip_modules_classes: Optional[Tuple[Type[torch.nn.Module], ...]] = None,
    non_blocking: bool = False,
    _prefix: str = "",
) -> None:
    r"""
    Recursively walk `module` and attach a layerwise casting hook to every supported layer that is not skipped.

    A module is skipped, together with all of its children, when it is an instance of one of
    `skip_modules_classes` or when any pattern in `skip_modules_pattern` is found in its dotted name (`_prefix`)
    by `re.search`. A module that is an instance of `SUPPORTED_PYTORCH_LAYERS` receives the hook and is not
    descended into. Any other module is only traversed.

    Args:
        module (`torch.nn.Module`):
            The module currently being visited.
        storage_dtype (`torch.dtype`):
            Passed through to `apply_layerwise_casting_hook`.
        compute_dtype (`torch.dtype`):
            Passed through to `apply_layerwise_casting_hook`.
        skip_modules_pattern (`Tuple[str, ...]`, *optional*):
            Regex patterns for module names to skip. `None` disables name-based skipping.
        skip_modules_classes (`Tuple[Type[torch.nn.Module], ...]`, *optional*):
            Module classes to skip. `None` disables class-based skipping.
        non_blocking (`bool`, defaults to `False`):
            Passed through to `apply_layerwise_casting_hook`.
        _prefix (`str`, defaults to `""`):
            Dotted name of `module` relative to the root of the traversal; empty for the root.
    """
    if skip_modules_classes is not None and isinstance(module, skip_modules_classes):
        matches_skip_rule = True
    elif skip_modules_pattern is not None:
        matches_skip_rule = any(re.search(candidate, _prefix) for candidate in skip_modules_pattern)
    else:
        matches_skip_rule = False

    if matches_skip_rule:
        logger.debug(f'Skipping layerwise casting for layer "{_prefix}"')
        return

    if isinstance(module, SUPPORTED_PYTORCH_LAYERS):
        logger.debug(f'Applying layerwise casting to layer "{_prefix}"')
        apply_layerwise_casting_hook(module, storage_dtype, compute_dtype, non_blocking)
        return

    for child_name, child in module.named_children():
        # Children of the root are addressed by their bare name, deeper ones by a dotted path.
        qualified_name = child_name if not _prefix else f"{_prefix}.{child_name}"
        _apply_layerwise_casting(
            child,
            storage_dtype,
            compute_dtype,
            skip_modules_pattern,
            skip_modules_classes,
            non_blocking,
            _prefix=qualified_name,
        )
205
+
206
+
207
def apply_layerwise_casting_hook(
    module: torch.nn.Module, storage_dtype: torch.dtype, compute_dtype: torch.dtype, non_blocking: bool
) -> None:
    r"""
    Registers a `LayerwiseCastingHook` on the hook registry of a given module, creating the registry if needed.

    Args:
        module (`torch.nn.Module`):
            The module to attach the hook to.
        storage_dtype (`torch.dtype`):
            The storage dtype handed to the `LayerwiseCastingHook`.
        compute_dtype (`torch.dtype`):
            The compute dtype handed to the `LayerwiseCastingHook`.
        non_blocking (`bool`):
            If `True`, the weight casting operations are non-blocking.
    """
    HookRegistry.check_if_exists_or_initialize(module).register_hook(
        LayerwiseCastingHook(
            storage_dtype=storage_dtype,
            compute_dtype=compute_dtype,
            non_blocking=non_blocking,
        ),
        _LAYERWISE_CASTING_HOOK,
    )
226
+
227
+
228
def _is_layerwise_casting_active(module: torch.nn.Module) -> bool:
    r"""
    Return `True` if `module` or any module yielded by `module.modules()` has a hook registered under the
    layerwise casting name in its `_diffusers_hook` registry, and `False` otherwise.
    """
    return any(
        hasattr(candidate, "_diffusers_hook")
        and candidate._diffusers_hook.get_hook(_LAYERWISE_CASTING_HOOK) is not None
        for candidate in module.modules()
    )
236
+
237
+
238
def _disable_peft_input_autocast(module: torch.nn.Module) -> None:
    r"""
    Register a `PeftInputAutocastDisableHook` on every PEFT tuner layer inside `module` that has layerwise
    casting active somewhere within it. Does nothing when `_SHOULD_DISABLE_PEFT_INPUT_AUTOCAST` is false.
    """
    if _SHOULD_DISABLE_PEFT_INPUT_AUTOCAST:
        for candidate in module.modules():
            if not isinstance(candidate, BaseTunerLayer):
                continue
            if not _is_layerwise_casting_active(candidate):
                continue
            HookRegistry.check_if_exists_or_initialize(candidate).register_hook(
                PeftInputAutocastDisableHook(), _PEFT_AUTOCAST_DISABLE_HOOK
            )