diffusers 0.32.2__py3-none-any.whl → 0.33.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (389) hide show
  1. diffusers/__init__.py +186 -3
  2. diffusers/configuration_utils.py +40 -12
  3. diffusers/dependency_versions_table.py +9 -2
  4. diffusers/hooks/__init__.py +9 -0
  5. diffusers/hooks/faster_cache.py +653 -0
  6. diffusers/hooks/group_offloading.py +793 -0
  7. diffusers/hooks/hooks.py +236 -0
  8. diffusers/hooks/layerwise_casting.py +245 -0
  9. diffusers/hooks/pyramid_attention_broadcast.py +311 -0
  10. diffusers/loaders/__init__.py +6 -0
  11. diffusers/loaders/ip_adapter.py +38 -30
  12. diffusers/loaders/lora_base.py +121 -86
  13. diffusers/loaders/lora_conversion_utils.py +504 -44
  14. diffusers/loaders/lora_pipeline.py +1769 -181
  15. diffusers/loaders/peft.py +167 -57
  16. diffusers/loaders/single_file.py +17 -2
  17. diffusers/loaders/single_file_model.py +53 -5
  18. diffusers/loaders/single_file_utils.py +646 -72
  19. diffusers/loaders/textual_inversion.py +9 -9
  20. diffusers/loaders/transformer_flux.py +8 -9
  21. diffusers/loaders/transformer_sd3.py +120 -39
  22. diffusers/loaders/unet.py +20 -7
  23. diffusers/models/__init__.py +22 -0
  24. diffusers/models/activations.py +9 -9
  25. diffusers/models/attention.py +0 -1
  26. diffusers/models/attention_processor.py +163 -25
  27. diffusers/models/auto_model.py +169 -0
  28. diffusers/models/autoencoders/__init__.py +2 -0
  29. diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
  30. diffusers/models/autoencoders/autoencoder_dc.py +106 -4
  31. diffusers/models/autoencoders/autoencoder_kl.py +0 -4
  32. diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
  33. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
  34. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
  35. diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
  36. diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
  37. diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
  38. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
  39. diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
  40. diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
  41. diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
  42. diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
  43. diffusers/models/autoencoders/vae.py +31 -141
  44. diffusers/models/autoencoders/vq_model.py +3 -0
  45. diffusers/models/cache_utils.py +108 -0
  46. diffusers/models/controlnets/__init__.py +1 -0
  47. diffusers/models/controlnets/controlnet.py +3 -8
  48. diffusers/models/controlnets/controlnet_flux.py +14 -42
  49. diffusers/models/controlnets/controlnet_sd3.py +58 -34
  50. diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
  51. diffusers/models/controlnets/controlnet_union.py +27 -18
  52. diffusers/models/controlnets/controlnet_xs.py +7 -46
  53. diffusers/models/controlnets/multicontrolnet_union.py +196 -0
  54. diffusers/models/embeddings.py +18 -7
  55. diffusers/models/model_loading_utils.py +122 -80
  56. diffusers/models/modeling_flax_pytorch_utils.py +1 -1
  57. diffusers/models/modeling_flax_utils.py +1 -1
  58. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  59. diffusers/models/modeling_utils.py +617 -272
  60. diffusers/models/normalization.py +67 -14
  61. diffusers/models/resnet.py +1 -1
  62. diffusers/models/transformers/__init__.py +6 -0
  63. diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
  64. diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
  65. diffusers/models/transformers/consisid_transformer_3d.py +789 -0
  66. diffusers/models/transformers/dit_transformer_2d.py +5 -19
  67. diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
  68. diffusers/models/transformers/latte_transformer_3d.py +20 -15
  69. diffusers/models/transformers/lumina_nextdit2d.py +3 -1
  70. diffusers/models/transformers/pixart_transformer_2d.py +4 -19
  71. diffusers/models/transformers/prior_transformer.py +5 -1
  72. diffusers/models/transformers/sana_transformer.py +144 -40
  73. diffusers/models/transformers/stable_audio_transformer.py +5 -20
  74. diffusers/models/transformers/transformer_2d.py +7 -22
  75. diffusers/models/transformers/transformer_allegro.py +9 -17
  76. diffusers/models/transformers/transformer_cogview3plus.py +6 -17
  77. diffusers/models/transformers/transformer_cogview4.py +462 -0
  78. diffusers/models/transformers/transformer_easyanimate.py +527 -0
  79. diffusers/models/transformers/transformer_flux.py +68 -110
  80. diffusers/models/transformers/transformer_hunyuan_video.py +404 -46
  81. diffusers/models/transformers/transformer_ltx.py +53 -35
  82. diffusers/models/transformers/transformer_lumina2.py +548 -0
  83. diffusers/models/transformers/transformer_mochi.py +6 -17
  84. diffusers/models/transformers/transformer_omnigen.py +469 -0
  85. diffusers/models/transformers/transformer_sd3.py +56 -86
  86. diffusers/models/transformers/transformer_temporal.py +5 -11
  87. diffusers/models/transformers/transformer_wan.py +469 -0
  88. diffusers/models/unets/unet_1d.py +3 -1
  89. diffusers/models/unets/unet_2d.py +21 -20
  90. diffusers/models/unets/unet_2d_blocks.py +19 -243
  91. diffusers/models/unets/unet_2d_condition.py +4 -6
  92. diffusers/models/unets/unet_3d_blocks.py +14 -127
  93. diffusers/models/unets/unet_3d_condition.py +8 -12
  94. diffusers/models/unets/unet_i2vgen_xl.py +5 -13
  95. diffusers/models/unets/unet_kandinsky3.py +0 -4
  96. diffusers/models/unets/unet_motion_model.py +20 -114
  97. diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
  98. diffusers/models/unets/unet_stable_cascade.py +8 -35
  99. diffusers/models/unets/uvit_2d.py +1 -4
  100. diffusers/optimization.py +2 -2
  101. diffusers/pipelines/__init__.py +57 -8
  102. diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
  103. diffusers/pipelines/amused/pipeline_amused.py +15 -2
  104. diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
  105. diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
  106. diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
  107. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
  108. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
  109. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
  110. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
  111. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
  112. diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
  113. diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
  114. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
  115. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
  116. diffusers/pipelines/auto_pipeline.py +35 -14
  117. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  118. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
  119. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
  120. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
  121. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
  122. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
  123. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
  124. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
  125. diffusers/pipelines/cogview4/__init__.py +49 -0
  126. diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
  127. diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
  128. diffusers/pipelines/cogview4/pipeline_output.py +21 -0
  129. diffusers/pipelines/consisid/__init__.py +49 -0
  130. diffusers/pipelines/consisid/consisid_utils.py +357 -0
  131. diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
  132. diffusers/pipelines/consisid/pipeline_output.py +20 -0
  133. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
  134. diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
  135. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
  136. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
  137. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
  138. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
  139. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
  140. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
  141. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
  142. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
  143. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
  144. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
  145. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
  146. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
  147. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
  148. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
  149. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
  150. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
  151. diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
  152. diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
  153. diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
  154. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
  155. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
  156. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
  157. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
  158. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
  159. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
  160. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
  161. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
  162. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
  163. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
  164. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
  165. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
  166. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
  167. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
  168. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
  169. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
  170. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
  171. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
  172. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
  173. diffusers/pipelines/dit/pipeline_dit.py +15 -2
  174. diffusers/pipelines/easyanimate/__init__.py +52 -0
  175. diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
  176. diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
  177. diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
  178. diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
  179. diffusers/pipelines/flux/pipeline_flux.py +53 -21
  180. diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
  181. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
  182. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
  183. diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
  184. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
  185. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
  186. diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
  187. diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
  188. diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
  189. diffusers/pipelines/free_noise_utils.py +3 -3
  190. diffusers/pipelines/hunyuan_video/__init__.py +4 -0
  191. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
  192. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
  193. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
  194. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
  195. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
  196. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
  197. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
  198. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
  199. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
  200. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
  201. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
  202. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
  203. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
  204. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
  205. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
  206. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
  207. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
  208. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
  209. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
  210. diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
  211. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
  212. diffusers/pipelines/kolors/text_encoder.py +7 -34
  213. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
  214. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
  215. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
  216. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
  217. diffusers/pipelines/latte/pipeline_latte.py +36 -7
  218. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
  219. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
  220. diffusers/pipelines/ltx/__init__.py +2 -0
  221. diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
  222. diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
  223. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
  224. diffusers/pipelines/lumina/__init__.py +2 -2
  225. diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
  226. diffusers/pipelines/lumina2/__init__.py +48 -0
  227. diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
  228. diffusers/pipelines/marigold/__init__.py +2 -0
  229. diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
  230. diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
  231. diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
  232. diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
  233. diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
  234. diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
  235. diffusers/pipelines/omnigen/__init__.py +50 -0
  236. diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
  237. diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
  238. diffusers/pipelines/onnx_utils.py +5 -3
  239. diffusers/pipelines/pag/pag_utils.py +1 -1
  240. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
  241. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
  242. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
  243. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
  244. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
  245. diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
  246. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
  247. diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
  248. diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
  249. diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
  250. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
  251. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
  252. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
  253. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
  254. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
  255. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
  256. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
  257. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
  258. diffusers/pipelines/pia/pipeline_pia.py +13 -1
  259. diffusers/pipelines/pipeline_flax_utils.py +7 -7
  260. diffusers/pipelines/pipeline_loading_utils.py +193 -83
  261. diffusers/pipelines/pipeline_utils.py +221 -106
  262. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
  263. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
  264. diffusers/pipelines/sana/__init__.py +2 -0
  265. diffusers/pipelines/sana/pipeline_sana.py +183 -58
  266. diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
  267. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
  268. diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
  269. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
  270. diffusers/pipelines/shap_e/renderer.py +6 -6
  271. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
  272. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
  273. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
  274. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
  275. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
  276. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
  277. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
  278. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
  279. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  280. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
  281. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
  282. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
  283. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
  284. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
  285. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
  286. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
  287. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
  288. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
  289. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
  290. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
  291. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
  292. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
  293. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
  294. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
  295. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
  296. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
  297. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
  298. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
  299. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
  300. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
  301. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
  302. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
  303. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
  304. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
  305. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
  306. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  307. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
  308. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
  309. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
  310. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
  311. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
  312. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
  313. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
  314. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
  315. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
  316. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
  317. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
  318. diffusers/pipelines/transformers_loading_utils.py +121 -0
  319. diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
  320. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
  321. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
  322. diffusers/pipelines/wan/__init__.py +51 -0
  323. diffusers/pipelines/wan/pipeline_output.py +20 -0
  324. diffusers/pipelines/wan/pipeline_wan.py +593 -0
  325. diffusers/pipelines/wan/pipeline_wan_i2v.py +722 -0
  326. diffusers/pipelines/wan/pipeline_wan_video2video.py +725 -0
  327. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
  328. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
  329. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
  330. diffusers/quantizers/auto.py +5 -1
  331. diffusers/quantizers/base.py +5 -9
  332. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
  333. diffusers/quantizers/bitsandbytes/utils.py +30 -20
  334. diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
  335. diffusers/quantizers/gguf/utils.py +4 -2
  336. diffusers/quantizers/quantization_config.py +59 -4
  337. diffusers/quantizers/quanto/__init__.py +1 -0
  338. diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
  339. diffusers/quantizers/quanto/utils.py +60 -0
  340. diffusers/quantizers/torchao/__init__.py +1 -1
  341. diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
  342. diffusers/schedulers/__init__.py +2 -1
  343. diffusers/schedulers/scheduling_consistency_models.py +1 -2
  344. diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
  345. diffusers/schedulers/scheduling_ddpm.py +2 -3
  346. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
  347. diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
  348. diffusers/schedulers/scheduling_edm_euler.py +45 -10
  349. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
  350. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
  351. diffusers/schedulers/scheduling_heun_discrete.py +1 -1
  352. diffusers/schedulers/scheduling_lcm.py +1 -2
  353. diffusers/schedulers/scheduling_lms_discrete.py +1 -1
  354. diffusers/schedulers/scheduling_repaint.py +5 -1
  355. diffusers/schedulers/scheduling_scm.py +265 -0
  356. diffusers/schedulers/scheduling_tcd.py +1 -2
  357. diffusers/schedulers/scheduling_utils.py +2 -1
  358. diffusers/training_utils.py +14 -7
  359. diffusers/utils/__init__.py +9 -1
  360. diffusers/utils/constants.py +13 -1
  361. diffusers/utils/deprecation_utils.py +1 -1
  362. diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
  363. diffusers/utils/dummy_gguf_objects.py +17 -0
  364. diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
  365. diffusers/utils/dummy_pt_objects.py +233 -0
  366. diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
  367. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  368. diffusers/utils/dummy_torchao_objects.py +17 -0
  369. diffusers/utils/dynamic_modules_utils.py +1 -1
  370. diffusers/utils/export_utils.py +28 -3
  371. diffusers/utils/hub_utils.py +52 -102
  372. diffusers/utils/import_utils.py +121 -221
  373. diffusers/utils/loading_utils.py +2 -1
  374. diffusers/utils/logging.py +1 -2
  375. diffusers/utils/peft_utils.py +6 -14
  376. diffusers/utils/remote_utils.py +425 -0
  377. diffusers/utils/source_code_parsing_utils.py +52 -0
  378. diffusers/utils/state_dict_utils.py +15 -1
  379. diffusers/utils/testing_utils.py +243 -13
  380. diffusers/utils/torch_utils.py +10 -0
  381. diffusers/utils/typing_utils.py +91 -0
  382. diffusers/video_processor.py +1 -1
  383. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/METADATA +76 -44
  384. diffusers-0.33.0.dist-info/RECORD +608 -0
  385. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/WHEEL +1 -1
  386. diffusers-0.32.2.dist-info/RECORD +0 -550
  387. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/LICENSE +0 -0
  388. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/entry_points.txt +0 -0
  389. {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/top_level.txt +0 -0
@@ -29,7 +29,6 @@ from ...models.attention_processor import (
29
29
  AttnProcessor,
30
30
  )
31
31
  from ...models.modeling_utils import ModelMixin
32
- from ...utils import is_torch_version
33
32
  from .modeling_wuerstchen_common import AttnBlock, ResBlock, TimestepBlock, WuerstchenLayerNorm
34
33
 
35
34
 
@@ -138,9 +137,6 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
138
137
 
139
138
  self.set_attn_processor(processor)
140
139
 
141
- def _set_gradient_checkpointing(self, module, value=False):
142
- self.gradient_checkpointing = value
143
-
144
140
  def gen_r_embedding(self, r, max_positions=10000):
145
141
  r = r * max_positions
146
142
  half_dim = self.c_r // 2
@@ -159,33 +155,13 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
159
155
  r_embed = self.gen_r_embedding(r)
160
156
 
161
157
  if torch.is_grad_enabled() and self.gradient_checkpointing:
162
-
163
- def create_custom_forward(module):
164
- def custom_forward(*inputs):
165
- return module(*inputs)
166
-
167
- return custom_forward
168
-
169
- if is_torch_version(">=", "1.11.0"):
170
- for block in self.blocks:
171
- if isinstance(block, AttnBlock):
172
- x = torch.utils.checkpoint.checkpoint(
173
- create_custom_forward(block), x, c_embed, use_reentrant=False
174
- )
175
- elif isinstance(block, TimestepBlock):
176
- x = torch.utils.checkpoint.checkpoint(
177
- create_custom_forward(block), x, r_embed, use_reentrant=False
178
- )
179
- else:
180
- x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False)
181
- else:
182
- for block in self.blocks:
183
- if isinstance(block, AttnBlock):
184
- x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, c_embed)
185
- elif isinstance(block, TimestepBlock):
186
- x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, r_embed)
187
- else:
188
- x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x)
158
+ for block in self.blocks:
159
+ if isinstance(block, AttnBlock):
160
+ x = self._gradient_checkpointing_func(block, x, c_embed)
161
+ elif isinstance(block, TimestepBlock):
162
+ x = self._gradient_checkpointing_func(block, x, r_embed)
163
+ else:
164
+ x = self._gradient_checkpointing_func(block, x)
189
165
  else:
190
166
  for block in self.blocks:
191
167
  if isinstance(block, AttnBlock):
@@ -19,15 +19,23 @@ import torch
19
19
  from transformers import CLIPTextModel, CLIPTokenizer
20
20
 
21
21
  from ...schedulers import DDPMWuerstchenScheduler
22
- from ...utils import deprecate, logging, replace_example_docstring
22
+ from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
23
23
  from ...utils.torch_utils import randn_tensor
24
24
  from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
25
25
  from .modeling_paella_vq_model import PaellaVQModel
26
26
  from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt
27
27
 
28
28
 
29
+ if is_torch_xla_available():
30
+ import torch_xla.core.xla_model as xm
31
+
32
+ XLA_AVAILABLE = True
33
+ else:
34
+ XLA_AVAILABLE = False
35
+
29
36
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
30
37
 
38
+
31
39
  EXAMPLE_DOC_STRING = """
32
40
  Examples:
33
41
  ```py
@@ -413,6 +421,9 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
413
421
  step_idx = i // getattr(self.scheduler, "order", 1)
414
422
  callback(step_idx, t, latents)
415
423
 
424
+ if XLA_AVAILABLE:
425
+ xm.mark_step()
426
+
416
427
  if output_type not in ["pt", "np", "pil", "latent"]:
417
428
  raise ValueError(
418
429
  f"Only the output types `pt`, `np`, `pil` and `latent` are supported not output_type={output_type}"
@@ -22,14 +22,22 @@ from transformers import CLIPTextModel, CLIPTokenizer
22
22
 
23
23
  from ...loaders import StableDiffusionLoraLoaderMixin
24
24
  from ...schedulers import DDPMWuerstchenScheduler
25
- from ...utils import BaseOutput, deprecate, logging, replace_example_docstring
25
+ from ...utils import BaseOutput, deprecate, is_torch_xla_available, logging, replace_example_docstring
26
26
  from ...utils.torch_utils import randn_tensor
27
27
  from ..pipeline_utils import DiffusionPipeline
28
28
  from .modeling_wuerstchen_prior import WuerstchenPrior
29
29
 
30
30
 
31
+ if is_torch_xla_available():
32
+ import torch_xla.core.xla_model as xm
33
+
34
+ XLA_AVAILABLE = True
35
+ else:
36
+ XLA_AVAILABLE = False
37
+
31
38
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
32
39
 
40
+
33
41
  DEFAULT_STAGE_C_TIMESTEPS = list(np.linspace(1.0, 2 / 3, 20)) + list(np.linspace(2 / 3, 0.0, 11))[1:]
34
42
 
35
43
  EXAMPLE_DOC_STRING = """
@@ -502,6 +510,9 @@ class WuerstchenPriorPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin)
502
510
  step_idx = i // getattr(self.scheduler, "order", 1)
503
511
  callback(step_idx, t, latents)
504
512
 
513
+ if XLA_AVAILABLE:
514
+ xm.mark_step()
515
+
505
516
  # 10. Denormalize the latents
506
517
  latents = latents * self.config.latent_mean - self.config.latent_std
507
518
 
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
1
+ # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -26,8 +26,10 @@ from .quantization_config import (
26
26
  GGUFQuantizationConfig,
27
27
  QuantizationConfigMixin,
28
28
  QuantizationMethod,
29
+ QuantoConfig,
29
30
  TorchAoConfig,
30
31
  )
32
+ from .quanto import QuantoQuantizer
31
33
  from .torchao import TorchAoHfQuantizer
32
34
 
33
35
 
@@ -35,6 +37,7 @@ AUTO_QUANTIZER_MAPPING = {
35
37
  "bitsandbytes_4bit": BnB4BitDiffusersQuantizer,
36
38
  "bitsandbytes_8bit": BnB8BitDiffusersQuantizer,
37
39
  "gguf": GGUFQuantizer,
40
+ "quanto": QuantoQuantizer,
38
41
  "torchao": TorchAoHfQuantizer,
39
42
  }
40
43
 
@@ -42,6 +45,7 @@ AUTO_QUANTIZATION_CONFIG_MAPPING = {
42
45
  "bitsandbytes_4bit": BitsAndBytesConfig,
43
46
  "bitsandbytes_8bit": BitsAndBytesConfig,
44
47
  "gguf": GGUFQuantizationConfig,
48
+ "quanto": QuantoConfig,
45
49
  "torchao": TorchAoConfig,
46
50
  }
47
51
 
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
1
+ # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -215,19 +215,15 @@ class DiffusersQuantizer(ABC):
215
215
  )
216
216
 
217
217
  @abstractmethod
218
- def _process_model_before_weight_loading(self, model, **kwargs):
219
- ...
218
+ def _process_model_before_weight_loading(self, model, **kwargs): ...
220
219
 
221
220
  @abstractmethod
222
- def _process_model_after_weight_loading(self, model, **kwargs):
223
- ...
221
+ def _process_model_after_weight_loading(self, model, **kwargs): ...
224
222
 
225
223
  @property
226
224
  @abstractmethod
227
- def is_serializable(self):
228
- ...
225
+ def is_serializable(self): ...
229
226
 
230
227
  @property
231
228
  @abstractmethod
232
- def is_trainable(self):
233
- ...
229
+ def is_trainable(self): ...
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
1
+ # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -61,7 +61,7 @@ class BnB4BitDiffusersQuantizer(DiffusersQuantizer):
61
61
  self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules
62
62
 
63
63
  def validate_environment(self, *args, **kwargs):
64
- if not torch.cuda.is_available():
64
+ if not (torch.cuda.is_available() or torch.xpu.is_available()):
65
65
  raise RuntimeError("No GPU found. A GPU is needed for quantization.")
66
66
  if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"):
67
67
  raise ImportError(
@@ -135,6 +135,7 @@ class BnB4BitDiffusersQuantizer(DiffusersQuantizer):
135
135
  target_device: "torch.device",
136
136
  state_dict: Dict[str, Any],
137
137
  unexpected_keys: Optional[List[str]] = None,
138
+ **kwargs,
138
139
  ):
139
140
  import bitsandbytes as bnb
140
141
 
@@ -235,18 +236,20 @@ class BnB4BitDiffusersQuantizer(DiffusersQuantizer):
235
236
  torch_dtype = torch.float16
236
237
  return torch_dtype
237
238
 
238
- # (sayakpaul): I think it could be better to disable custom `device_map`s
239
- # for the first phase of the integration in the interest of simplicity.
240
- # Commenting this for discussions on the PR.
241
- # def update_device_map(self, device_map):
242
- # if device_map is None:
243
- # device_map = {"": torch.cuda.current_device()}
244
- # logger.info(
245
- # "The device_map was not initialized. "
246
- # "Setting device_map to {'':torch.cuda.current_device()}. "
247
- # "If you want to use the model for inference, please set device_map ='auto' "
248
- # )
249
- # return device_map
239
+ def update_device_map(self, device_map):
240
+ if device_map is None:
241
+ if torch.xpu.is_available():
242
+ current_device = f"xpu:{torch.xpu.current_device()}"
243
+ else:
244
+ current_device = f"cuda:{torch.cuda.current_device()}"
245
+ device_map = {"": current_device}
246
+ logger.info(
247
+ "The device_map was not initialized. "
248
+ "Setting device_map to {"
249
+ ": {current_device}}. "
250
+ "If you want to use the model for inference, please set device_map ='auto' "
251
+ )
252
+ return device_map
250
253
 
251
254
  def _process_model_before_weight_loading(
252
255
  self,
@@ -289,9 +292,9 @@ class BnB4BitDiffusersQuantizer(DiffusersQuantizer):
289
292
  model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config
290
293
  )
291
294
  model.config.quantization_config = self.quantization_config
295
+ model.is_loaded_in_4bit = True
292
296
 
293
297
  def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs):
294
- model.is_loaded_in_4bit = True
295
298
  model.is_4bit_serializable = self.is_serializable
296
299
  return model
297
300
 
@@ -313,7 +316,10 @@ class BnB4BitDiffusersQuantizer(DiffusersQuantizer):
313
316
  logger.info(
314
317
  "Model was found to be on CPU (could happen as a result of `enable_model_cpu_offload()`). So, moving it to GPU. After dequantization, will move the model back to CPU again to preserve the previous device."
315
318
  )
316
- model.to(torch.cuda.current_device())
319
+ if torch.xpu.is_available():
320
+ model.to(torch.xpu.current_device())
321
+ else:
322
+ model.to(torch.cuda.current_device())
317
323
 
318
324
  model = dequantize_and_replace(
319
325
  model, self.modules_to_not_convert, quantization_config=self.quantization_config
@@ -344,7 +350,7 @@ class BnB8BitDiffusersQuantizer(DiffusersQuantizer):
344
350
  self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules
345
351
 
346
352
  def validate_environment(self, *args, **kwargs):
347
- if not torch.cuda.is_available():
353
+ if not (torch.cuda.is_available() or torch.xpu.is_available()):
348
354
  raise RuntimeError("No GPU found. A GPU is needed for quantization.")
349
355
  if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"):
350
356
  raise ImportError(
@@ -400,16 +406,21 @@ class BnB8BitDiffusersQuantizer(DiffusersQuantizer):
400
406
  torch_dtype = torch.float16
401
407
  return torch_dtype
402
408
 
403
- # # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.update_device_map
404
- # def update_device_map(self, device_map):
405
- # if device_map is None:
406
- # device_map = {"": torch.cuda.current_device()}
407
- # logger.info(
408
- # "The device_map was not initialized. "
409
- # "Setting device_map to {'':torch.cuda.current_device()}. "
410
- # "If you want to use the model for inference, please set device_map ='auto' "
411
- # )
412
- # return device_map
409
+ # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.update_device_map
410
+ def update_device_map(self, device_map):
411
+ if device_map is None:
412
+ if torch.xpu.is_available():
413
+ current_device = f"xpu:{torch.xpu.current_device()}"
414
+ else:
415
+ current_device = f"cuda:{torch.cuda.current_device()}"
416
+ device_map = {"": current_device}
417
+ logger.info(
418
+ "The device_map was not initialized. "
419
+ "Setting device_map to {"
420
+ ": {current_device}}. "
421
+ "If you want to use the model for inference, please set device_map ='auto' "
422
+ )
423
+ return device_map
413
424
 
414
425
  def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
415
426
  if target_dtype != torch.int8:
@@ -446,6 +457,7 @@ class BnB8BitDiffusersQuantizer(DiffusersQuantizer):
446
457
  target_device: "torch.device",
447
458
  state_dict: Dict[str, Any],
448
459
  unexpected_keys: Optional[List[str]] = None,
460
+ **kwargs,
449
461
  ):
450
462
  import bitsandbytes as bnb
451
463
 
@@ -493,11 +505,10 @@ class BnB8BitDiffusersQuantizer(DiffusersQuantizer):
493
505
 
494
506
  # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer._process_model_after_weight_loading with 4bit->8bit
495
507
  def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs):
496
- model.is_loaded_in_8bit = True
497
508
  model.is_8bit_serializable = self.is_serializable
498
509
  return model
499
510
 
500
- # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer._process_model_before_weight_loading
511
+ # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer._process_model_before_weight_loading with 4bit->8bit
501
512
  def _process_model_before_weight_loading(
502
513
  self,
503
514
  model: "ModelMixin",
@@ -539,6 +550,7 @@ class BnB8BitDiffusersQuantizer(DiffusersQuantizer):
539
550
  model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config
540
551
  )
541
552
  model.config.quantization_config = self.quantization_config
553
+ model.is_loaded_in_8bit = True
542
554
 
543
555
  @property
544
556
  # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.is_serializable
@@ -1,4 +1,4 @@
1
- # Copyright 2024 The HuggingFace Inc. team. All rights reserved.
1
+ # Copyright 2025 The HuggingFace Inc. team. All rights reserved.
2
2
  #
3
3
  # Licensed under the Apache License, Version 2.0 (the "License");
4
4
  # you may not use this file except in compliance with the License.
@@ -139,10 +139,12 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name
139
139
  models by reducing the precision of the weights and activations, thus making models more efficient in terms
140
140
  of both storage and computation.
141
141
  """
142
- model, has_been_replaced = _replace_with_bnb_linear(
143
- model, modules_to_not_convert, current_key_name, quantization_config
144
- )
142
+ model, _ = _replace_with_bnb_linear(model, modules_to_not_convert, current_key_name, quantization_config)
145
143
 
144
+ has_been_replaced = any(
145
+ isinstance(replaced_module, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt))
146
+ for _, replaced_module in model.named_modules()
147
+ )
146
148
  if not has_been_replaced:
147
149
  logger.warning(
148
150
  "You are loading your model in 8bit or 4bit but no linear modules were found in your model."
@@ -153,8 +155,8 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name
153
155
  return model
154
156
 
155
157
 
156
- # Copied from PEFT: https://github.com/huggingface/peft/blob/47b3712898539569c02ec5b3ed4a6c36811331a1/src/peft/utils/integrations.py#L41
157
- def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None):
158
+ # Adapted from PEFT: https://github.com/huggingface/peft/blob/6d458b300fc2ed82e19f796b53af4c97d03ea604/src/peft/utils/integrations.py#L81
159
+ def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None, dtype: "torch.dtype" = None):
158
160
  """
159
161
  Helper function to dequantize 4bit or 8bit bnb weights.
160
162
 
@@ -177,13 +179,16 @@ def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None):
177
179
  if state.SCB is None:
178
180
  state.SCB = weight.SCB
179
181
 
180
- im = torch.eye(weight.data.shape[-1]).contiguous().half().to(weight.device)
181
- im, imt, SCim, SCimt, coo_tensorim = bnb.functional.double_quant(im)
182
- im, Sim = bnb.functional.transform(im, "col32")
183
- if state.CxB is None:
184
- state.CxB, state.SB = bnb.functional.transform(weight.data, to_order=state.formatB)
185
- out32, Sout32 = bnb.functional.igemmlt(im, state.CxB, Sim, state.SB)
186
- return bnb.functional.mm_dequant(out32, Sout32, SCim, state.SCB, bias=None).t()
182
+ if hasattr(bnb.functional, "int8_vectorwise_dequant"):
183
+ # Use bitsandbytes API if available (requires v0.45.0+)
184
+ dequantized = bnb.functional.int8_vectorwise_dequant(weight.data, state.SCB)
185
+ else:
186
+ # Multiply by (scale/127) to dequantize.
187
+ dequantized = weight.data * state.SCB.view(-1, 1) * 7.874015718698502e-3
188
+
189
+ if dtype:
190
+ dequantized = dequantized.to(dtype)
191
+ return dequantized
187
192
 
188
193
 
189
194
  def _create_accelerate_new_hook(old_hook):
@@ -205,6 +210,7 @@ def _create_accelerate_new_hook(old_hook):
205
210
 
206
211
  def _dequantize_and_replace(
207
212
  model,
213
+ dtype,
208
214
  modules_to_not_convert=None,
209
215
  current_key_name=None,
210
216
  quantization_config=None,
@@ -244,7 +250,7 @@ def _dequantize_and_replace(
244
250
  else:
245
251
  state = None
246
252
 
247
- new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, state))
253
+ new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, state, dtype))
248
254
 
249
255
  if bias is not None:
250
256
  new_module.bias = bias
@@ -263,9 +269,10 @@ def _dequantize_and_replace(
263
269
  if len(list(module.children())) > 0:
264
270
  _, has_been_replaced = _dequantize_and_replace(
265
271
  module,
266
- modules_to_not_convert,
267
- current_key_name,
268
- quantization_config,
272
+ dtype=dtype,
273
+ modules_to_not_convert=modules_to_not_convert,
274
+ current_key_name=current_key_name,
275
+ quantization_config=quantization_config,
269
276
  has_been_replaced=has_been_replaced,
270
277
  )
271
278
  # Remove the last key for recursion
@@ -278,15 +285,18 @@ def dequantize_and_replace(
278
285
  modules_to_not_convert=None,
279
286
  quantization_config=None,
280
287
  ):
281
- model, has_been_replaced = _dequantize_and_replace(
288
+ model, _ = _dequantize_and_replace(
282
289
  model,
290
+ dtype=model.dtype,
283
291
  modules_to_not_convert=modules_to_not_convert,
284
292
  quantization_config=quantization_config,
285
293
  )
286
-
294
+ has_been_replaced = any(
295
+ isinstance(replaced_module, torch.nn.Linear) for _, replaced_module in model.named_modules()
296
+ )
287
297
  if not has_been_replaced:
288
298
  logger.warning(
289
- "For some reason the model has not been properly dequantized. You might see unexpected behavior."
299
+ "Some linear modules were not dequantized. This could lead to unexpected behaviour. Please check your model."
290
300
  )
291
301
 
292
302
  return model
@@ -108,6 +108,7 @@ class GGUFQuantizer(DiffusersQuantizer):
108
108
  target_device: "torch.device",
109
109
  state_dict: Optional[Dict[str, Any]] = None,
110
110
  unexpected_keys: Optional[List[str]] = None,
111
+ **kwargs,
111
112
  ):
112
113
  module, tensor_name = get_module_from_name(model, param_name)
113
114
  if tensor_name not in module._parameters and tensor_name not in module._buffers:
@@ -400,6 +400,8 @@ class GGUFParameter(torch.nn.Parameter):
400
400
  data = data if data is not None else torch.empty(0)
401
401
  self = torch.Tensor._make_subclass(cls, data, requires_grad)
402
402
  self.quant_type = quant_type
403
+ block_size, type_size = GGML_QUANT_SIZES[quant_type]
404
+ self.quant_shape = _quant_shape_from_byte_shape(self.shape, type_size, block_size)
403
405
 
404
406
  return self
405
407
 
@@ -418,7 +420,7 @@ class GGUFParameter(torch.nn.Parameter):
418
420
  # so that we preserve quant_type information
419
421
  quant_type = None
420
422
  for arg in args:
421
- if isinstance(arg, list) and (arg[0], GGUFParameter):
423
+ if isinstance(arg, list) and isinstance(arg[0], GGUFParameter):
422
424
  quant_type = arg[0].quant_type
423
425
  break
424
426
  if isinstance(arg, GGUFParameter):
@@ -450,7 +452,7 @@ class GGUFLinear(nn.Linear):
450
452
  def forward(self, inputs):
451
453
  weight = dequantize_gguf_tensor(self.weight)
452
454
  weight = weight.to(self.compute_dtype)
453
- bias = self.bias.to(self.compute_dtype)
455
+ bias = self.bias.to(self.compute_dtype) if self.bias is not None else None
454
456
 
455
457
  output = torch.nn.functional.linear(inputs, weight, bias)
456
458
  return output
@@ -45,6 +45,17 @@ class QuantizationMethod(str, Enum):
45
45
  BITS_AND_BYTES = "bitsandbytes"
46
46
  GGUF = "gguf"
47
47
  TORCHAO = "torchao"
48
+ QUANTO = "quanto"
49
+
50
+
51
+ if is_torchao_available():
52
+ from torchao.quantization.quant_primitives import MappingType
53
+
54
+ class TorchAoJSONEncoder(json.JSONEncoder):
55
+ def default(self, obj):
56
+ if isinstance(obj, MappingType):
57
+ return obj.name
58
+ return super().default(obj)
48
59
 
49
60
 
50
61
  @dataclass
@@ -481,8 +492,15 @@ class TorchAoConfig(QuantizationConfigMixin):
481
492
 
482
493
  TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
483
494
  if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys():
495
+ is_floating_quant_type = self.quant_type.startswith("float") or self.quant_type.startswith("fp")
496
+ if is_floating_quant_type and not self._is_cuda_capability_atleast_8_9():
497
+ raise ValueError(
498
+ f"Requested quantization type: {self.quant_type} is not supported on GPUs with CUDA capability <= 8.9. You "
499
+ f"can check the CUDA capability of your GPU using `torch.cuda.get_device_capability()`."
500
+ )
501
+
484
502
  raise ValueError(
485
- f"Requested quantization type: {self.quant_type} is not supported yet or is incorrect. If you think the "
503
+ f"Requested quantization type: {self.quant_type} is not supported or is an incorrect `quant_type` name. If you think the "
486
504
  f"provided quantization type should be supported, please open an issue at https://github.com/huggingface/diffusers/issues."
487
505
  )
488
506
 
@@ -652,13 +670,13 @@ class TorchAoConfig(QuantizationConfigMixin):
652
670
 
653
671
  def __repr__(self):
654
672
  r"""
655
- Example of how this looks for `TorchAoConfig("uint_a16w4", group_size=32)`:
673
+ Example of how this looks for `TorchAoConfig("uint4wo", group_size=32)`:
656
674
 
657
675
  ```
658
676
  TorchAoConfig {
659
677
  "modules_to_not_convert": null,
660
678
  "quant_method": "torchao",
661
- "quant_type": "uint_a16w4",
679
+ "quant_type": "uint4wo",
662
680
  "quant_type_kwargs": {
663
681
  "group_size": 32
664
682
  }
@@ -666,4 +684,41 @@ class TorchAoConfig(QuantizationConfigMixin):
666
684
  ```
667
685
  """
668
686
  config_dict = self.to_dict()
669
- return f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True)}\n"
687
+ return (
688
+ f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True, cls=TorchAoJSONEncoder)}\n"
689
+ )
690
+
691
+
692
+ @dataclass
693
+ class QuantoConfig(QuantizationConfigMixin):
694
+ """
695
+ This is a wrapper class about all possible attributes and features that you can play with a model that has been
696
+ loaded using `quanto`.
697
+
698
+ Args:
699
+ weights_dtype (`str`, *optional*, defaults to `"int8"`):
700
+ The target dtype for the weights after quantization. Supported values are ("float8","int8","int4","int2")
701
+ modules_to_not_convert (`list`, *optional*, default to `None`):
702
+ The list of modules to not quantize, useful for quantizing models that explicitly require to have some
703
+ modules left in their original precision (e.g. Whisper encoder, Llava encoder, Mixtral gate layers).
704
+ """
705
+
706
+ def __init__(
707
+ self,
708
+ weights_dtype: str = "int8",
709
+ modules_to_not_convert: Optional[List[str]] = None,
710
+ **kwargs,
711
+ ):
712
+ self.quant_method = QuantizationMethod.QUANTO
713
+ self.weights_dtype = weights_dtype
714
+ self.modules_to_not_convert = modules_to_not_convert
715
+
716
+ self.post_init()
717
+
718
+ def post_init(self):
719
+ r"""
720
+ Safety checker that arguments are correct
721
+ """
722
+ accepted_weights = ["float8", "int8", "int4", "int2"]
723
+ if self.weights_dtype not in accepted_weights:
724
+ raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights_dtype}")
@@ -0,0 +1 @@
1
+ from .quanto_quantizer import QuantoQuantizer