diffusers 0.32.2__py3-none-any.whl → 0.33.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +186 -3
- diffusers/configuration_utils.py +40 -12
- diffusers/dependency_versions_table.py +9 -2
- diffusers/hooks/__init__.py +9 -0
- diffusers/hooks/faster_cache.py +653 -0
- diffusers/hooks/group_offloading.py +793 -0
- diffusers/hooks/hooks.py +236 -0
- diffusers/hooks/layerwise_casting.py +245 -0
- diffusers/hooks/pyramid_attention_broadcast.py +311 -0
- diffusers/loaders/__init__.py +6 -0
- diffusers/loaders/ip_adapter.py +38 -30
- diffusers/loaders/lora_base.py +121 -86
- diffusers/loaders/lora_conversion_utils.py +504 -44
- diffusers/loaders/lora_pipeline.py +1769 -181
- diffusers/loaders/peft.py +167 -57
- diffusers/loaders/single_file.py +17 -2
- diffusers/loaders/single_file_model.py +53 -5
- diffusers/loaders/single_file_utils.py +646 -72
- diffusers/loaders/textual_inversion.py +9 -9
- diffusers/loaders/transformer_flux.py +8 -9
- diffusers/loaders/transformer_sd3.py +120 -39
- diffusers/loaders/unet.py +20 -7
- diffusers/models/__init__.py +22 -0
- diffusers/models/activations.py +9 -9
- diffusers/models/attention.py +0 -1
- diffusers/models/attention_processor.py +163 -25
- diffusers/models/auto_model.py +169 -0
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
- diffusers/models/autoencoders/autoencoder_dc.py +106 -4
- diffusers/models/autoencoders/autoencoder_kl.py +0 -4
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
- diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
- diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
- diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
- diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
- diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
- diffusers/models/autoencoders/vae.py +31 -141
- diffusers/models/autoencoders/vq_model.py +3 -0
- diffusers/models/cache_utils.py +108 -0
- diffusers/models/controlnets/__init__.py +1 -0
- diffusers/models/controlnets/controlnet.py +3 -8
- diffusers/models/controlnets/controlnet_flux.py +14 -42
- diffusers/models/controlnets/controlnet_sd3.py +58 -34
- diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
- diffusers/models/controlnets/controlnet_union.py +27 -18
- diffusers/models/controlnets/controlnet_xs.py +7 -46
- diffusers/models/controlnets/multicontrolnet_union.py +196 -0
- diffusers/models/embeddings.py +18 -7
- diffusers/models/model_loading_utils.py +122 -80
- diffusers/models/modeling_flax_pytorch_utils.py +1 -1
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +617 -272
- diffusers/models/normalization.py +67 -14
- diffusers/models/resnet.py +1 -1
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
- diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
- diffusers/models/transformers/consisid_transformer_3d.py +789 -0
- diffusers/models/transformers/dit_transformer_2d.py +5 -19
- diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
- diffusers/models/transformers/latte_transformer_3d.py +20 -15
- diffusers/models/transformers/lumina_nextdit2d.py +3 -1
- diffusers/models/transformers/pixart_transformer_2d.py +4 -19
- diffusers/models/transformers/prior_transformer.py +5 -1
- diffusers/models/transformers/sana_transformer.py +144 -40
- diffusers/models/transformers/stable_audio_transformer.py +5 -20
- diffusers/models/transformers/transformer_2d.py +7 -22
- diffusers/models/transformers/transformer_allegro.py +9 -17
- diffusers/models/transformers/transformer_cogview3plus.py +6 -17
- diffusers/models/transformers/transformer_cogview4.py +462 -0
- diffusers/models/transformers/transformer_easyanimate.py +527 -0
- diffusers/models/transformers/transformer_flux.py +68 -110
- diffusers/models/transformers/transformer_hunyuan_video.py +404 -46
- diffusers/models/transformers/transformer_ltx.py +53 -35
- diffusers/models/transformers/transformer_lumina2.py +548 -0
- diffusers/models/transformers/transformer_mochi.py +6 -17
- diffusers/models/transformers/transformer_omnigen.py +469 -0
- diffusers/models/transformers/transformer_sd3.py +56 -86
- diffusers/models/transformers/transformer_temporal.py +5 -11
- diffusers/models/transformers/transformer_wan.py +469 -0
- diffusers/models/unets/unet_1d.py +3 -1
- diffusers/models/unets/unet_2d.py +21 -20
- diffusers/models/unets/unet_2d_blocks.py +19 -243
- diffusers/models/unets/unet_2d_condition.py +4 -6
- diffusers/models/unets/unet_3d_blocks.py +14 -127
- diffusers/models/unets/unet_3d_condition.py +8 -12
- diffusers/models/unets/unet_i2vgen_xl.py +5 -13
- diffusers/models/unets/unet_kandinsky3.py +0 -4
- diffusers/models/unets/unet_motion_model.py +20 -114
- diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
- diffusers/models/unets/unet_stable_cascade.py +8 -35
- diffusers/models/unets/uvit_2d.py +1 -4
- diffusers/optimization.py +2 -2
- diffusers/pipelines/__init__.py +57 -8
- diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
- diffusers/pipelines/amused/pipeline_amused.py +15 -2
- diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
- diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
- diffusers/pipelines/auto_pipeline.py +35 -14
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
- diffusers/pipelines/cogview4/__init__.py +49 -0
- diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
- diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
- diffusers/pipelines/cogview4/pipeline_output.py +21 -0
- diffusers/pipelines/consisid/__init__.py +49 -0
- diffusers/pipelines/consisid/consisid_utils.py +357 -0
- diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
- diffusers/pipelines/consisid/pipeline_output.py +20 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
- diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
- diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +15 -2
- diffusers/pipelines/easyanimate/__init__.py +52 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
- diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
- diffusers/pipelines/flux/pipeline_flux.py +53 -21
- diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
- diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
- diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
- diffusers/pipelines/free_noise_utils.py +3 -3
- diffusers/pipelines/hunyuan_video/__init__.py +4 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
- diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
- diffusers/pipelines/kolors/text_encoder.py +7 -34
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
- diffusers/pipelines/latte/pipeline_latte.py +36 -7
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
- diffusers/pipelines/ltx/__init__.py +2 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
- diffusers/pipelines/lumina/__init__.py +2 -2
- diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
- diffusers/pipelines/lumina2/__init__.py +48 -0
- diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
- diffusers/pipelines/marigold/__init__.py +2 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
- diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
- diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
- diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
- diffusers/pipelines/omnigen/__init__.py +50 -0
- diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
- diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
- diffusers/pipelines/onnx_utils.py +5 -3
- diffusers/pipelines/pag/pag_utils.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
- diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
- diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
- diffusers/pipelines/pia/pipeline_pia.py +13 -1
- diffusers/pipelines/pipeline_flax_utils.py +7 -7
- diffusers/pipelines/pipeline_loading_utils.py +193 -83
- diffusers/pipelines/pipeline_utils.py +221 -106
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
- diffusers/pipelines/sana/__init__.py +2 -0
- diffusers/pipelines/sana/pipeline_sana.py +183 -58
- diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
- diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
- diffusers/pipelines/shap_e/renderer.py +6 -6
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
- diffusers/pipelines/transformers_loading_utils.py +121 -0
- diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
- diffusers/pipelines/wan/__init__.py +51 -0
- diffusers/pipelines/wan/pipeline_output.py +20 -0
- diffusers/pipelines/wan/pipeline_wan.py +593 -0
- diffusers/pipelines/wan/pipeline_wan_i2v.py +722 -0
- diffusers/pipelines/wan/pipeline_wan_video2video.py +725 -0
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
- diffusers/quantizers/auto.py +5 -1
- diffusers/quantizers/base.py +5 -9
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
- diffusers/quantizers/bitsandbytes/utils.py +30 -20
- diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
- diffusers/quantizers/gguf/utils.py +4 -2
- diffusers/quantizers/quantization_config.py +59 -4
- diffusers/quantizers/quanto/__init__.py +1 -0
- diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
- diffusers/quantizers/quanto/utils.py +60 -0
- diffusers/quantizers/torchao/__init__.py +1 -1
- diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
- diffusers/schedulers/__init__.py +2 -1
- diffusers/schedulers/scheduling_consistency_models.py +1 -2
- diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
- diffusers/schedulers/scheduling_ddpm.py +2 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
- diffusers/schedulers/scheduling_edm_euler.py +45 -10
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
- diffusers/schedulers/scheduling_heun_discrete.py +1 -1
- diffusers/schedulers/scheduling_lcm.py +1 -2
- diffusers/schedulers/scheduling_lms_discrete.py +1 -1
- diffusers/schedulers/scheduling_repaint.py +5 -1
- diffusers/schedulers/scheduling_scm.py +265 -0
- diffusers/schedulers/scheduling_tcd.py +1 -2
- diffusers/schedulers/scheduling_utils.py +2 -1
- diffusers/training_utils.py +14 -7
- diffusers/utils/__init__.py +9 -1
- diffusers/utils/constants.py +13 -1
- diffusers/utils/deprecation_utils.py +1 -1
- diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
- diffusers/utils/dummy_gguf_objects.py +17 -0
- diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
- diffusers/utils/dummy_pt_objects.py +233 -0
- diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
- diffusers/utils/dummy_torchao_objects.py +17 -0
- diffusers/utils/dynamic_modules_utils.py +1 -1
- diffusers/utils/export_utils.py +28 -3
- diffusers/utils/hub_utils.py +52 -102
- diffusers/utils/import_utils.py +121 -221
- diffusers/utils/loading_utils.py +2 -1
- diffusers/utils/logging.py +1 -2
- diffusers/utils/peft_utils.py +6 -14
- diffusers/utils/remote_utils.py +425 -0
- diffusers/utils/source_code_parsing_utils.py +52 -0
- diffusers/utils/state_dict_utils.py +15 -1
- diffusers/utils/testing_utils.py +243 -13
- diffusers/utils/torch_utils.py +10 -0
- diffusers/utils/typing_utils.py +91 -0
- diffusers/video_processor.py +1 -1
- {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/METADATA +76 -44
- diffusers-0.33.0.dist-info/RECORD +608 -0
- {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/WHEEL +1 -1
- diffusers-0.32.2.dist-info/RECORD +0 -550
- {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/LICENSE +0 -0
- {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/top_level.txt +0 -0
diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py
CHANGED
@@ -29,7 +29,6 @@ from ...models.attention_processor import (
|
|
29
29
|
AttnProcessor,
|
30
30
|
)
|
31
31
|
from ...models.modeling_utils import ModelMixin
|
32
|
-
from ...utils import is_torch_version
|
33
32
|
from .modeling_wuerstchen_common import AttnBlock, ResBlock, TimestepBlock, WuerstchenLayerNorm
|
34
33
|
|
35
34
|
|
@@ -138,9 +137,6 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
|
|
138
137
|
|
139
138
|
self.set_attn_processor(processor)
|
140
139
|
|
141
|
-
def _set_gradient_checkpointing(self, module, value=False):
|
142
|
-
self.gradient_checkpointing = value
|
143
|
-
|
144
140
|
def gen_r_embedding(self, r, max_positions=10000):
|
145
141
|
r = r * max_positions
|
146
142
|
half_dim = self.c_r // 2
|
@@ -159,33 +155,13 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
|
|
159
155
|
r_embed = self.gen_r_embedding(r)
|
160
156
|
|
161
157
|
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
162
|
-
|
163
|
-
|
164
|
-
|
165
|
-
|
166
|
-
|
167
|
-
|
168
|
-
|
169
|
-
if is_torch_version(">=", "1.11.0"):
|
170
|
-
for block in self.blocks:
|
171
|
-
if isinstance(block, AttnBlock):
|
172
|
-
x = torch.utils.checkpoint.checkpoint(
|
173
|
-
create_custom_forward(block), x, c_embed, use_reentrant=False
|
174
|
-
)
|
175
|
-
elif isinstance(block, TimestepBlock):
|
176
|
-
x = torch.utils.checkpoint.checkpoint(
|
177
|
-
create_custom_forward(block), x, r_embed, use_reentrant=False
|
178
|
-
)
|
179
|
-
else:
|
180
|
-
x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, use_reentrant=False)
|
181
|
-
else:
|
182
|
-
for block in self.blocks:
|
183
|
-
if isinstance(block, AttnBlock):
|
184
|
-
x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, c_embed)
|
185
|
-
elif isinstance(block, TimestepBlock):
|
186
|
-
x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x, r_embed)
|
187
|
-
else:
|
188
|
-
x = torch.utils.checkpoint.checkpoint(create_custom_forward(block), x)
|
158
|
+
for block in self.blocks:
|
159
|
+
if isinstance(block, AttnBlock):
|
160
|
+
x = self._gradient_checkpointing_func(block, x, c_embed)
|
161
|
+
elif isinstance(block, TimestepBlock):
|
162
|
+
x = self._gradient_checkpointing_func(block, x, r_embed)
|
163
|
+
else:
|
164
|
+
x = self._gradient_checkpointing_func(block, x)
|
189
165
|
else:
|
190
166
|
for block in self.blocks:
|
191
167
|
if isinstance(block, AttnBlock):
|
diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py
CHANGED
@@ -19,15 +19,23 @@ import torch
|
|
19
19
|
from transformers import CLIPTextModel, CLIPTokenizer
|
20
20
|
|
21
21
|
from ...schedulers import DDPMWuerstchenScheduler
|
22
|
-
from ...utils import deprecate, logging, replace_example_docstring
|
22
|
+
from ...utils import deprecate, is_torch_xla_available, logging, replace_example_docstring
|
23
23
|
from ...utils.torch_utils import randn_tensor
|
24
24
|
from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
|
25
25
|
from .modeling_paella_vq_model import PaellaVQModel
|
26
26
|
from .modeling_wuerstchen_diffnext import WuerstchenDiffNeXt
|
27
27
|
|
28
28
|
|
29
|
+
if is_torch_xla_available():
|
30
|
+
import torch_xla.core.xla_model as xm
|
31
|
+
|
32
|
+
XLA_AVAILABLE = True
|
33
|
+
else:
|
34
|
+
XLA_AVAILABLE = False
|
35
|
+
|
29
36
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
30
37
|
|
38
|
+
|
31
39
|
EXAMPLE_DOC_STRING = """
|
32
40
|
Examples:
|
33
41
|
```py
|
@@ -413,6 +421,9 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
|
|
413
421
|
step_idx = i // getattr(self.scheduler, "order", 1)
|
414
422
|
callback(step_idx, t, latents)
|
415
423
|
|
424
|
+
if XLA_AVAILABLE:
|
425
|
+
xm.mark_step()
|
426
|
+
|
416
427
|
if output_type not in ["pt", "np", "pil", "latent"]:
|
417
428
|
raise ValueError(
|
418
429
|
f"Only the output types `pt`, `np`, `pil` and `latent` are supported not output_type={output_type}"
|
diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py
CHANGED
@@ -22,14 +22,22 @@ from transformers import CLIPTextModel, CLIPTokenizer
|
|
22
22
|
|
23
23
|
from ...loaders import StableDiffusionLoraLoaderMixin
|
24
24
|
from ...schedulers import DDPMWuerstchenScheduler
|
25
|
-
from ...utils import BaseOutput, deprecate, logging, replace_example_docstring
|
25
|
+
from ...utils import BaseOutput, deprecate, is_torch_xla_available, logging, replace_example_docstring
|
26
26
|
from ...utils.torch_utils import randn_tensor
|
27
27
|
from ..pipeline_utils import DiffusionPipeline
|
28
28
|
from .modeling_wuerstchen_prior import WuerstchenPrior
|
29
29
|
|
30
30
|
|
31
|
+
if is_torch_xla_available():
|
32
|
+
import torch_xla.core.xla_model as xm
|
33
|
+
|
34
|
+
XLA_AVAILABLE = True
|
35
|
+
else:
|
36
|
+
XLA_AVAILABLE = False
|
37
|
+
|
31
38
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
32
39
|
|
40
|
+
|
33
41
|
DEFAULT_STAGE_C_TIMESTEPS = list(np.linspace(1.0, 2 / 3, 20)) + list(np.linspace(2 / 3, 0.0, 11))[1:]
|
34
42
|
|
35
43
|
EXAMPLE_DOC_STRING = """
|
@@ -502,6 +510,9 @@ class WuerstchenPriorPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin)
|
|
502
510
|
step_idx = i // getattr(self.scheduler, "order", 1)
|
503
511
|
callback(step_idx, t, latents)
|
504
512
|
|
513
|
+
if XLA_AVAILABLE:
|
514
|
+
xm.mark_step()
|
515
|
+
|
505
516
|
# 10. Denormalize the latents
|
506
517
|
latents = latents * self.config.latent_mean - self.config.latent_std
|
507
518
|
|
diffusers/quantizers/auto.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -26,8 +26,10 @@ from .quantization_config import (
|
|
26
26
|
GGUFQuantizationConfig,
|
27
27
|
QuantizationConfigMixin,
|
28
28
|
QuantizationMethod,
|
29
|
+
QuantoConfig,
|
29
30
|
TorchAoConfig,
|
30
31
|
)
|
32
|
+
from .quanto import QuantoQuantizer
|
31
33
|
from .torchao import TorchAoHfQuantizer
|
32
34
|
|
33
35
|
|
@@ -35,6 +37,7 @@ AUTO_QUANTIZER_MAPPING = {
|
|
35
37
|
"bitsandbytes_4bit": BnB4BitDiffusersQuantizer,
|
36
38
|
"bitsandbytes_8bit": BnB8BitDiffusersQuantizer,
|
37
39
|
"gguf": GGUFQuantizer,
|
40
|
+
"quanto": QuantoQuantizer,
|
38
41
|
"torchao": TorchAoHfQuantizer,
|
39
42
|
}
|
40
43
|
|
@@ -42,6 +45,7 @@ AUTO_QUANTIZATION_CONFIG_MAPPING = {
|
|
42
45
|
"bitsandbytes_4bit": BitsAndBytesConfig,
|
43
46
|
"bitsandbytes_8bit": BitsAndBytesConfig,
|
44
47
|
"gguf": GGUFQuantizationConfig,
|
48
|
+
"quanto": QuantoConfig,
|
45
49
|
"torchao": TorchAoConfig,
|
46
50
|
}
|
47
51
|
|
diffusers/quantizers/base.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -215,19 +215,15 @@ class DiffusersQuantizer(ABC):
|
|
215
215
|
)
|
216
216
|
|
217
217
|
@abstractmethod
|
218
|
-
def _process_model_before_weight_loading(self, model, **kwargs):
|
219
|
-
...
|
218
|
+
def _process_model_before_weight_loading(self, model, **kwargs): ...
|
220
219
|
|
221
220
|
@abstractmethod
|
222
|
-
def _process_model_after_weight_loading(self, model, **kwargs):
|
223
|
-
...
|
221
|
+
def _process_model_after_weight_loading(self, model, **kwargs): ...
|
224
222
|
|
225
223
|
@property
|
226
224
|
@abstractmethod
|
227
|
-
def is_serializable(self):
|
228
|
-
...
|
225
|
+
def is_serializable(self): ...
|
229
226
|
|
230
227
|
@property
|
231
228
|
@abstractmethod
|
232
|
-
def is_trainable(self):
|
233
|
-
...
|
229
|
+
def is_trainable(self): ...
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -61,7 +61,7 @@ class BnB4BitDiffusersQuantizer(DiffusersQuantizer):
|
|
61
61
|
self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules
|
62
62
|
|
63
63
|
def validate_environment(self, *args, **kwargs):
|
64
|
-
if not torch.cuda.is_available():
|
64
|
+
if not (torch.cuda.is_available() or torch.xpu.is_available()):
|
65
65
|
raise RuntimeError("No GPU found. A GPU is needed for quantization.")
|
66
66
|
if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"):
|
67
67
|
raise ImportError(
|
@@ -135,6 +135,7 @@ class BnB4BitDiffusersQuantizer(DiffusersQuantizer):
|
|
135
135
|
target_device: "torch.device",
|
136
136
|
state_dict: Dict[str, Any],
|
137
137
|
unexpected_keys: Optional[List[str]] = None,
|
138
|
+
**kwargs,
|
138
139
|
):
|
139
140
|
import bitsandbytes as bnb
|
140
141
|
|
@@ -235,18 +236,20 @@ class BnB4BitDiffusersQuantizer(DiffusersQuantizer):
|
|
235
236
|
torch_dtype = torch.float16
|
236
237
|
return torch_dtype
|
237
238
|
|
238
|
-
|
239
|
-
|
240
|
-
|
241
|
-
|
242
|
-
|
243
|
-
|
244
|
-
|
245
|
-
|
246
|
-
|
247
|
-
|
248
|
-
|
249
|
-
|
239
|
+
def update_device_map(self, device_map):
|
240
|
+
if device_map is None:
|
241
|
+
if torch.xpu.is_available():
|
242
|
+
current_device = f"xpu:{torch.xpu.current_device()}"
|
243
|
+
else:
|
244
|
+
current_device = f"cuda:{torch.cuda.current_device()}"
|
245
|
+
device_map = {"": current_device}
|
246
|
+
logger.info(
|
247
|
+
"The device_map was not initialized. "
|
248
|
+
"Setting device_map to {"
|
249
|
+
": {current_device}}. "
|
250
|
+
"If you want to use the model for inference, please set device_map ='auto' "
|
251
|
+
)
|
252
|
+
return device_map
|
250
253
|
|
251
254
|
def _process_model_before_weight_loading(
|
252
255
|
self,
|
@@ -289,9 +292,9 @@ class BnB4BitDiffusersQuantizer(DiffusersQuantizer):
|
|
289
292
|
model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config
|
290
293
|
)
|
291
294
|
model.config.quantization_config = self.quantization_config
|
295
|
+
model.is_loaded_in_4bit = True
|
292
296
|
|
293
297
|
def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs):
|
294
|
-
model.is_loaded_in_4bit = True
|
295
298
|
model.is_4bit_serializable = self.is_serializable
|
296
299
|
return model
|
297
300
|
|
@@ -313,7 +316,10 @@ class BnB4BitDiffusersQuantizer(DiffusersQuantizer):
|
|
313
316
|
logger.info(
|
314
317
|
"Model was found to be on CPU (could happen as a result of `enable_model_cpu_offload()`). So, moving it to GPU. After dequantization, will move the model back to CPU again to preserve the previous device."
|
315
318
|
)
|
316
|
-
|
319
|
+
if torch.xpu.is_available():
|
320
|
+
model.to(torch.xpu.current_device())
|
321
|
+
else:
|
322
|
+
model.to(torch.cuda.current_device())
|
317
323
|
|
318
324
|
model = dequantize_and_replace(
|
319
325
|
model, self.modules_to_not_convert, quantization_config=self.quantization_config
|
@@ -344,7 +350,7 @@ class BnB8BitDiffusersQuantizer(DiffusersQuantizer):
|
|
344
350
|
self.modules_to_not_convert = self.quantization_config.llm_int8_skip_modules
|
345
351
|
|
346
352
|
def validate_environment(self, *args, **kwargs):
|
347
|
-
if not torch.cuda.is_available():
|
353
|
+
if not (torch.cuda.is_available() or torch.xpu.is_available()):
|
348
354
|
raise RuntimeError("No GPU found. A GPU is needed for quantization.")
|
349
355
|
if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"):
|
350
356
|
raise ImportError(
|
@@ -400,16 +406,21 @@ class BnB8BitDiffusersQuantizer(DiffusersQuantizer):
|
|
400
406
|
torch_dtype = torch.float16
|
401
407
|
return torch_dtype
|
402
408
|
|
403
|
-
#
|
404
|
-
|
405
|
-
|
406
|
-
|
407
|
-
|
408
|
-
|
409
|
-
|
410
|
-
|
411
|
-
|
412
|
-
|
409
|
+
# Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.update_device_map
|
410
|
+
def update_device_map(self, device_map):
|
411
|
+
if device_map is None:
|
412
|
+
if torch.xpu.is_available():
|
413
|
+
current_device = f"xpu:{torch.xpu.current_device()}"
|
414
|
+
else:
|
415
|
+
current_device = f"cuda:{torch.cuda.current_device()}"
|
416
|
+
device_map = {"": current_device}
|
417
|
+
logger.info(
|
418
|
+
"The device_map was not initialized. "
|
419
|
+
"Setting device_map to {"
|
420
|
+
": {current_device}}. "
|
421
|
+
"If you want to use the model for inference, please set device_map ='auto' "
|
422
|
+
)
|
423
|
+
return device_map
|
413
424
|
|
414
425
|
def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
|
415
426
|
if target_dtype != torch.int8:
|
@@ -446,6 +457,7 @@ class BnB8BitDiffusersQuantizer(DiffusersQuantizer):
|
|
446
457
|
target_device: "torch.device",
|
447
458
|
state_dict: Dict[str, Any],
|
448
459
|
unexpected_keys: Optional[List[str]] = None,
|
460
|
+
**kwargs,
|
449
461
|
):
|
450
462
|
import bitsandbytes as bnb
|
451
463
|
|
@@ -493,11 +505,10 @@ class BnB8BitDiffusersQuantizer(DiffusersQuantizer):
|
|
493
505
|
|
494
506
|
# Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer._process_model_after_weight_loading with 4bit->8bit
|
495
507
|
def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs):
|
496
|
-
model.is_loaded_in_8bit = True
|
497
508
|
model.is_8bit_serializable = self.is_serializable
|
498
509
|
return model
|
499
510
|
|
500
|
-
# Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer._process_model_before_weight_loading
|
511
|
+
# Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer._process_model_before_weight_loading with 4bit->8bit
|
501
512
|
def _process_model_before_weight_loading(
|
502
513
|
self,
|
503
514
|
model: "ModelMixin",
|
@@ -539,6 +550,7 @@ class BnB8BitDiffusersQuantizer(DiffusersQuantizer):
|
|
539
550
|
model, modules_to_not_convert=self.modules_to_not_convert, quantization_config=self.quantization_config
|
540
551
|
)
|
541
552
|
model.config.quantization_config = self.quantization_config
|
553
|
+
model.is_loaded_in_8bit = True
|
542
554
|
|
543
555
|
@property
|
544
556
|
# Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.is_serializable
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -139,10 +139,12 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name
|
|
139
139
|
models by reducing the precision of the weights and activations, thus making models more efficient in terms
|
140
140
|
of both storage and computation.
|
141
141
|
"""
|
142
|
-
model,
|
143
|
-
model, modules_to_not_convert, current_key_name, quantization_config
|
144
|
-
)
|
142
|
+
model, _ = _replace_with_bnb_linear(model, modules_to_not_convert, current_key_name, quantization_config)
|
145
143
|
|
144
|
+
has_been_replaced = any(
|
145
|
+
isinstance(replaced_module, (bnb.nn.Linear4bit, bnb.nn.Linear8bitLt))
|
146
|
+
for _, replaced_module in model.named_modules()
|
147
|
+
)
|
146
148
|
if not has_been_replaced:
|
147
149
|
logger.warning(
|
148
150
|
"You are loading your model in 8bit or 4bit but no linear modules were found in your model."
|
@@ -153,8 +155,8 @@ def replace_with_bnb_linear(model, modules_to_not_convert=None, current_key_name
|
|
153
155
|
return model
|
154
156
|
|
155
157
|
|
156
|
-
#
|
157
|
-
def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None):
|
158
|
+
# Adapted from PEFT: https://github.com/huggingface/peft/blob/6d458b300fc2ed82e19f796b53af4c97d03ea604/src/peft/utils/integrations.py#L81
|
159
|
+
def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None, dtype: "torch.dtype" = None):
|
158
160
|
"""
|
159
161
|
Helper function to dequantize 4bit or 8bit bnb weights.
|
160
162
|
|
@@ -177,13 +179,16 @@ def dequantize_bnb_weight(weight: "torch.nn.Parameter", state=None):
|
|
177
179
|
if state.SCB is None:
|
178
180
|
state.SCB = weight.SCB
|
179
181
|
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
185
|
-
|
186
|
-
|
182
|
+
if hasattr(bnb.functional, "int8_vectorwise_dequant"):
|
183
|
+
# Use bitsandbytes API if available (requires v0.45.0+)
|
184
|
+
dequantized = bnb.functional.int8_vectorwise_dequant(weight.data, state.SCB)
|
185
|
+
else:
|
186
|
+
# Multiply by (scale/127) to dequantize.
|
187
|
+
dequantized = weight.data * state.SCB.view(-1, 1) * 7.874015718698502e-3
|
188
|
+
|
189
|
+
if dtype:
|
190
|
+
dequantized = dequantized.to(dtype)
|
191
|
+
return dequantized
|
187
192
|
|
188
193
|
|
189
194
|
def _create_accelerate_new_hook(old_hook):
|
@@ -205,6 +210,7 @@ def _create_accelerate_new_hook(old_hook):
|
|
205
210
|
|
206
211
|
def _dequantize_and_replace(
|
207
212
|
model,
|
213
|
+
dtype,
|
208
214
|
modules_to_not_convert=None,
|
209
215
|
current_key_name=None,
|
210
216
|
quantization_config=None,
|
@@ -244,7 +250,7 @@ def _dequantize_and_replace(
|
|
244
250
|
else:
|
245
251
|
state = None
|
246
252
|
|
247
|
-
new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, state))
|
253
|
+
new_module.weight = torch.nn.Parameter(dequantize_bnb_weight(module.weight, state, dtype))
|
248
254
|
|
249
255
|
if bias is not None:
|
250
256
|
new_module.bias = bias
|
@@ -263,9 +269,10 @@ def _dequantize_and_replace(
|
|
263
269
|
if len(list(module.children())) > 0:
|
264
270
|
_, has_been_replaced = _dequantize_and_replace(
|
265
271
|
module,
|
266
|
-
|
267
|
-
|
268
|
-
|
272
|
+
dtype=dtype,
|
273
|
+
modules_to_not_convert=modules_to_not_convert,
|
274
|
+
current_key_name=current_key_name,
|
275
|
+
quantization_config=quantization_config,
|
269
276
|
has_been_replaced=has_been_replaced,
|
270
277
|
)
|
271
278
|
# Remove the last key for recursion
|
@@ -278,15 +285,18 @@ def dequantize_and_replace(
|
|
278
285
|
modules_to_not_convert=None,
|
279
286
|
quantization_config=None,
|
280
287
|
):
|
281
|
-
model,
|
288
|
+
model, _ = _dequantize_and_replace(
|
282
289
|
model,
|
290
|
+
dtype=model.dtype,
|
283
291
|
modules_to_not_convert=modules_to_not_convert,
|
284
292
|
quantization_config=quantization_config,
|
285
293
|
)
|
286
|
-
|
294
|
+
has_been_replaced = any(
|
295
|
+
isinstance(replaced_module, torch.nn.Linear) for _, replaced_module in model.named_modules()
|
296
|
+
)
|
287
297
|
if not has_been_replaced:
|
288
298
|
logger.warning(
|
289
|
-
"
|
299
|
+
"Some linear modules were not dequantized. This could lead to unexpected behaviour. Please check your model."
|
290
300
|
)
|
291
301
|
|
292
302
|
return model
|
@@ -108,6 +108,7 @@ class GGUFQuantizer(DiffusersQuantizer):
|
|
108
108
|
target_device: "torch.device",
|
109
109
|
state_dict: Optional[Dict[str, Any]] = None,
|
110
110
|
unexpected_keys: Optional[List[str]] = None,
|
111
|
+
**kwargs,
|
111
112
|
):
|
112
113
|
module, tensor_name = get_module_from_name(model, param_name)
|
113
114
|
if tensor_name not in module._parameters and tensor_name not in module._buffers:
|
@@ -400,6 +400,8 @@ class GGUFParameter(torch.nn.Parameter):
|
|
400
400
|
data = data if data is not None else torch.empty(0)
|
401
401
|
self = torch.Tensor._make_subclass(cls, data, requires_grad)
|
402
402
|
self.quant_type = quant_type
|
403
|
+
block_size, type_size = GGML_QUANT_SIZES[quant_type]
|
404
|
+
self.quant_shape = _quant_shape_from_byte_shape(self.shape, type_size, block_size)
|
403
405
|
|
404
406
|
return self
|
405
407
|
|
@@ -418,7 +420,7 @@ class GGUFParameter(torch.nn.Parameter):
|
|
418
420
|
# so that we preserve quant_type information
|
419
421
|
quant_type = None
|
420
422
|
for arg in args:
|
421
|
-
if isinstance(arg, list) and (arg[0], GGUFParameter):
|
423
|
+
if isinstance(arg, list) and isinstance(arg[0], GGUFParameter):
|
422
424
|
quant_type = arg[0].quant_type
|
423
425
|
break
|
424
426
|
if isinstance(arg, GGUFParameter):
|
@@ -450,7 +452,7 @@ class GGUFLinear(nn.Linear):
|
|
450
452
|
def forward(self, inputs):
|
451
453
|
weight = dequantize_gguf_tensor(self.weight)
|
452
454
|
weight = weight.to(self.compute_dtype)
|
453
|
-
bias = self.bias.to(self.compute_dtype)
|
455
|
+
bias = self.bias.to(self.compute_dtype) if self.bias is not None else None
|
454
456
|
|
455
457
|
output = torch.nn.functional.linear(inputs, weight, bias)
|
456
458
|
return output
|
@@ -45,6 +45,17 @@ class QuantizationMethod(str, Enum):
|
|
45
45
|
BITS_AND_BYTES = "bitsandbytes"
|
46
46
|
GGUF = "gguf"
|
47
47
|
TORCHAO = "torchao"
|
48
|
+
QUANTO = "quanto"
|
49
|
+
|
50
|
+
|
51
|
+
if is_torchao_available():
|
52
|
+
from torchao.quantization.quant_primitives import MappingType
|
53
|
+
|
54
|
+
class TorchAoJSONEncoder(json.JSONEncoder):
|
55
|
+
def default(self, obj):
|
56
|
+
if isinstance(obj, MappingType):
|
57
|
+
return obj.name
|
58
|
+
return super().default(obj)
|
48
59
|
|
49
60
|
|
50
61
|
@dataclass
|
@@ -481,8 +492,15 @@ class TorchAoConfig(QuantizationConfigMixin):
|
|
481
492
|
|
482
493
|
TORCHAO_QUANT_TYPE_METHODS = self._get_torchao_quant_type_to_method()
|
483
494
|
if self.quant_type not in TORCHAO_QUANT_TYPE_METHODS.keys():
|
495
|
+
is_floating_quant_type = self.quant_type.startswith("float") or self.quant_type.startswith("fp")
|
496
|
+
if is_floating_quant_type and not self._is_cuda_capability_atleast_8_9():
|
497
|
+
raise ValueError(
|
498
|
+
f"Requested quantization type: {self.quant_type} is not supported on GPUs with CUDA capability <= 8.9. You "
|
499
|
+
f"can check the CUDA capability of your GPU using `torch.cuda.get_device_capability()`."
|
500
|
+
)
|
501
|
+
|
484
502
|
raise ValueError(
|
485
|
-
f"Requested quantization type: {self.quant_type} is not supported
|
503
|
+
f"Requested quantization type: {self.quant_type} is not supported or is an incorrect `quant_type` name. If you think the "
|
486
504
|
f"provided quantization type should be supported, please open an issue at https://github.com/huggingface/diffusers/issues."
|
487
505
|
)
|
488
506
|
|
@@ -652,13 +670,13 @@ class TorchAoConfig(QuantizationConfigMixin):
|
|
652
670
|
|
653
671
|
def __repr__(self):
|
654
672
|
r"""
|
655
|
-
Example of how this looks for `TorchAoConfig("
|
673
|
+
Example of how this looks for `TorchAoConfig("uint4wo", group_size=32)`:
|
656
674
|
|
657
675
|
```
|
658
676
|
TorchAoConfig {
|
659
677
|
"modules_to_not_convert": null,
|
660
678
|
"quant_method": "torchao",
|
661
|
-
"quant_type": "
|
679
|
+
"quant_type": "uint4wo",
|
662
680
|
"quant_type_kwargs": {
|
663
681
|
"group_size": 32
|
664
682
|
}
|
@@ -666,4 +684,41 @@ class TorchAoConfig(QuantizationConfigMixin):
|
|
666
684
|
```
|
667
685
|
"""
|
668
686
|
config_dict = self.to_dict()
|
669
|
-
return
|
687
|
+
return (
|
688
|
+
f"{self.__class__.__name__} {json.dumps(config_dict, indent=2, sort_keys=True, cls=TorchAoJSONEncoder)}\n"
|
689
|
+
)
|
690
|
+
|
691
|
+
|
692
|
+
@dataclass
|
693
|
+
class QuantoConfig(QuantizationConfigMixin):
|
694
|
+
"""
|
695
|
+
This is a wrapper class about all possible attributes and features that you can play with a model that has been
|
696
|
+
loaded using `quanto`.
|
697
|
+
|
698
|
+
Args:
|
699
|
+
weights_dtype (`str`, *optional*, defaults to `"int8"`):
|
700
|
+
The target dtype for the weights after quantization. Supported values are ("float8","int8","int4","int2")
|
701
|
+
modules_to_not_convert (`list`, *optional*, default to `None`):
|
702
|
+
The list of modules to not quantize, useful for quantizing models that explicitly require to have some
|
703
|
+
modules left in their original precision (e.g. Whisper encoder, Llava encoder, Mixtral gate layers).
|
704
|
+
"""
|
705
|
+
|
706
|
+
def __init__(
|
707
|
+
self,
|
708
|
+
weights_dtype: str = "int8",
|
709
|
+
modules_to_not_convert: Optional[List[str]] = None,
|
710
|
+
**kwargs,
|
711
|
+
):
|
712
|
+
self.quant_method = QuantizationMethod.QUANTO
|
713
|
+
self.weights_dtype = weights_dtype
|
714
|
+
self.modules_to_not_convert = modules_to_not_convert
|
715
|
+
|
716
|
+
self.post_init()
|
717
|
+
|
718
|
+
def post_init(self):
|
719
|
+
r"""
|
720
|
+
Safety checker that arguments are correct
|
721
|
+
"""
|
722
|
+
accepted_weights = ["float8", "int8", "int4", "int2"]
|
723
|
+
if self.weights_dtype not in accepted_weights:
|
724
|
+
raise ValueError(f"Only support weights in {accepted_weights} but found {self.weights_dtype}")
|
@@ -0,0 +1 @@
|
|
1
|
+
from .quanto_quantizer import QuantoQuantizer
|