diffusers 0.32.2__py3-none-any.whl → 0.33.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +186 -3
- diffusers/configuration_utils.py +40 -12
- diffusers/dependency_versions_table.py +9 -2
- diffusers/hooks/__init__.py +9 -0
- diffusers/hooks/faster_cache.py +653 -0
- diffusers/hooks/group_offloading.py +793 -0
- diffusers/hooks/hooks.py +236 -0
- diffusers/hooks/layerwise_casting.py +245 -0
- diffusers/hooks/pyramid_attention_broadcast.py +311 -0
- diffusers/loaders/__init__.py +6 -0
- diffusers/loaders/ip_adapter.py +38 -30
- diffusers/loaders/lora_base.py +121 -86
- diffusers/loaders/lora_conversion_utils.py +504 -44
- diffusers/loaders/lora_pipeline.py +1769 -181
- diffusers/loaders/peft.py +167 -57
- diffusers/loaders/single_file.py +17 -2
- diffusers/loaders/single_file_model.py +53 -5
- diffusers/loaders/single_file_utils.py +646 -72
- diffusers/loaders/textual_inversion.py +9 -9
- diffusers/loaders/transformer_flux.py +8 -9
- diffusers/loaders/transformer_sd3.py +120 -39
- diffusers/loaders/unet.py +20 -7
- diffusers/models/__init__.py +22 -0
- diffusers/models/activations.py +9 -9
- diffusers/models/attention.py +0 -1
- diffusers/models/attention_processor.py +163 -25
- diffusers/models/auto_model.py +169 -0
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
- diffusers/models/autoencoders/autoencoder_dc.py +106 -4
- diffusers/models/autoencoders/autoencoder_kl.py +0 -4
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
- diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
- diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
- diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
- diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
- diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
- diffusers/models/autoencoders/vae.py +31 -141
- diffusers/models/autoencoders/vq_model.py +3 -0
- diffusers/models/cache_utils.py +108 -0
- diffusers/models/controlnets/__init__.py +1 -0
- diffusers/models/controlnets/controlnet.py +3 -8
- diffusers/models/controlnets/controlnet_flux.py +14 -42
- diffusers/models/controlnets/controlnet_sd3.py +58 -34
- diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
- diffusers/models/controlnets/controlnet_union.py +27 -18
- diffusers/models/controlnets/controlnet_xs.py +7 -46
- diffusers/models/controlnets/multicontrolnet_union.py +196 -0
- diffusers/models/embeddings.py +18 -7
- diffusers/models/model_loading_utils.py +122 -80
- diffusers/models/modeling_flax_pytorch_utils.py +1 -1
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +617 -272
- diffusers/models/normalization.py +67 -14
- diffusers/models/resnet.py +1 -1
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
- diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
- diffusers/models/transformers/consisid_transformer_3d.py +789 -0
- diffusers/models/transformers/dit_transformer_2d.py +5 -19
- diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
- diffusers/models/transformers/latte_transformer_3d.py +20 -15
- diffusers/models/transformers/lumina_nextdit2d.py +3 -1
- diffusers/models/transformers/pixart_transformer_2d.py +4 -19
- diffusers/models/transformers/prior_transformer.py +5 -1
- diffusers/models/transformers/sana_transformer.py +144 -40
- diffusers/models/transformers/stable_audio_transformer.py +5 -20
- diffusers/models/transformers/transformer_2d.py +7 -22
- diffusers/models/transformers/transformer_allegro.py +9 -17
- diffusers/models/transformers/transformer_cogview3plus.py +6 -17
- diffusers/models/transformers/transformer_cogview4.py +462 -0
- diffusers/models/transformers/transformer_easyanimate.py +527 -0
- diffusers/models/transformers/transformer_flux.py +68 -110
- diffusers/models/transformers/transformer_hunyuan_video.py +404 -46
- diffusers/models/transformers/transformer_ltx.py +53 -35
- diffusers/models/transformers/transformer_lumina2.py +548 -0
- diffusers/models/transformers/transformer_mochi.py +6 -17
- diffusers/models/transformers/transformer_omnigen.py +469 -0
- diffusers/models/transformers/transformer_sd3.py +56 -86
- diffusers/models/transformers/transformer_temporal.py +5 -11
- diffusers/models/transformers/transformer_wan.py +469 -0
- diffusers/models/unets/unet_1d.py +3 -1
- diffusers/models/unets/unet_2d.py +21 -20
- diffusers/models/unets/unet_2d_blocks.py +19 -243
- diffusers/models/unets/unet_2d_condition.py +4 -6
- diffusers/models/unets/unet_3d_blocks.py +14 -127
- diffusers/models/unets/unet_3d_condition.py +8 -12
- diffusers/models/unets/unet_i2vgen_xl.py +5 -13
- diffusers/models/unets/unet_kandinsky3.py +0 -4
- diffusers/models/unets/unet_motion_model.py +20 -114
- diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
- diffusers/models/unets/unet_stable_cascade.py +8 -35
- diffusers/models/unets/uvit_2d.py +1 -4
- diffusers/optimization.py +2 -2
- diffusers/pipelines/__init__.py +57 -8
- diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
- diffusers/pipelines/amused/pipeline_amused.py +15 -2
- diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
- diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
- diffusers/pipelines/auto_pipeline.py +35 -14
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
- diffusers/pipelines/cogview4/__init__.py +49 -0
- diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
- diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
- diffusers/pipelines/cogview4/pipeline_output.py +21 -0
- diffusers/pipelines/consisid/__init__.py +49 -0
- diffusers/pipelines/consisid/consisid_utils.py +357 -0
- diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
- diffusers/pipelines/consisid/pipeline_output.py +20 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
- diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
- diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +15 -2
- diffusers/pipelines/easyanimate/__init__.py +52 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
- diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
- diffusers/pipelines/flux/pipeline_flux.py +53 -21
- diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
- diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
- diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
- diffusers/pipelines/free_noise_utils.py +3 -3
- diffusers/pipelines/hunyuan_video/__init__.py +4 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
- diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
- diffusers/pipelines/kolors/text_encoder.py +7 -34
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
- diffusers/pipelines/latte/pipeline_latte.py +36 -7
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
- diffusers/pipelines/ltx/__init__.py +2 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
- diffusers/pipelines/lumina/__init__.py +2 -2
- diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
- diffusers/pipelines/lumina2/__init__.py +48 -0
- diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
- diffusers/pipelines/marigold/__init__.py +2 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
- diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
- diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
- diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
- diffusers/pipelines/omnigen/__init__.py +50 -0
- diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
- diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
- diffusers/pipelines/onnx_utils.py +5 -3
- diffusers/pipelines/pag/pag_utils.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
- diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
- diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
- diffusers/pipelines/pia/pipeline_pia.py +13 -1
- diffusers/pipelines/pipeline_flax_utils.py +7 -7
- diffusers/pipelines/pipeline_loading_utils.py +193 -83
- diffusers/pipelines/pipeline_utils.py +221 -106
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
- diffusers/pipelines/sana/__init__.py +2 -0
- diffusers/pipelines/sana/pipeline_sana.py +183 -58
- diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
- diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
- diffusers/pipelines/shap_e/renderer.py +6 -6
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
- diffusers/pipelines/transformers_loading_utils.py +121 -0
- diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
- diffusers/pipelines/wan/__init__.py +51 -0
- diffusers/pipelines/wan/pipeline_output.py +20 -0
- diffusers/pipelines/wan/pipeline_wan.py +593 -0
- diffusers/pipelines/wan/pipeline_wan_i2v.py +722 -0
- diffusers/pipelines/wan/pipeline_wan_video2video.py +725 -0
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
- diffusers/quantizers/auto.py +5 -1
- diffusers/quantizers/base.py +5 -9
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
- diffusers/quantizers/bitsandbytes/utils.py +30 -20
- diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
- diffusers/quantizers/gguf/utils.py +4 -2
- diffusers/quantizers/quantization_config.py +59 -4
- diffusers/quantizers/quanto/__init__.py +1 -0
- diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
- diffusers/quantizers/quanto/utils.py +60 -0
- diffusers/quantizers/torchao/__init__.py +1 -1
- diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
- diffusers/schedulers/__init__.py +2 -1
- diffusers/schedulers/scheduling_consistency_models.py +1 -2
- diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
- diffusers/schedulers/scheduling_ddpm.py +2 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
- diffusers/schedulers/scheduling_edm_euler.py +45 -10
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
- diffusers/schedulers/scheduling_heun_discrete.py +1 -1
- diffusers/schedulers/scheduling_lcm.py +1 -2
- diffusers/schedulers/scheduling_lms_discrete.py +1 -1
- diffusers/schedulers/scheduling_repaint.py +5 -1
- diffusers/schedulers/scheduling_scm.py +265 -0
- diffusers/schedulers/scheduling_tcd.py +1 -2
- diffusers/schedulers/scheduling_utils.py +2 -1
- diffusers/training_utils.py +14 -7
- diffusers/utils/__init__.py +9 -1
- diffusers/utils/constants.py +13 -1
- diffusers/utils/deprecation_utils.py +1 -1
- diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
- diffusers/utils/dummy_gguf_objects.py +17 -0
- diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
- diffusers/utils/dummy_pt_objects.py +233 -0
- diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
- diffusers/utils/dummy_torchao_objects.py +17 -0
- diffusers/utils/dynamic_modules_utils.py +1 -1
- diffusers/utils/export_utils.py +28 -3
- diffusers/utils/hub_utils.py +52 -102
- diffusers/utils/import_utils.py +121 -221
- diffusers/utils/loading_utils.py +2 -1
- diffusers/utils/logging.py +1 -2
- diffusers/utils/peft_utils.py +6 -14
- diffusers/utils/remote_utils.py +425 -0
- diffusers/utils/source_code_parsing_utils.py +52 -0
- diffusers/utils/state_dict_utils.py +15 -1
- diffusers/utils/testing_utils.py +243 -13
- diffusers/utils/torch_utils.py +10 -0
- diffusers/utils/typing_utils.py +91 -0
- diffusers/video_processor.py +1 -1
- {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/METADATA +76 -44
- diffusers-0.33.0.dist-info/RECORD +608 -0
- {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/WHEEL +1 -1
- diffusers-0.32.2.dist-info/RECORD +0 -550
- {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/LICENSE +0 -0
- {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/top_level.txt +0 -0
diffusers/models/embeddings.py
CHANGED
@@ -139,7 +139,9 @@ def get_3d_sincos_pos_embed(
|
|
139
139
|
|
140
140
|
# 3. Concat
|
141
141
|
pos_embed_spatial = pos_embed_spatial[None, :, :]
|
142
|
-
pos_embed_spatial = pos_embed_spatial.repeat_interleave(
|
142
|
+
pos_embed_spatial = pos_embed_spatial.repeat_interleave(
|
143
|
+
temporal_size, dim=0, output_size=pos_embed_spatial.shape[0] * temporal_size
|
144
|
+
) # [T, H*W, D // 4 * 3]
|
143
145
|
|
144
146
|
pos_embed_temporal = pos_embed_temporal[:, None, :]
|
145
147
|
pos_embed_temporal = pos_embed_temporal.repeat_interleave(
|
@@ -334,7 +336,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"):
|
|
334
336
|
" `from_numpy` is no longer required."
|
335
337
|
" Pass `output_type='pt' to use the new version now."
|
336
338
|
)
|
337
|
-
deprecate("output_type=='np'", "0.33.0", deprecation_message, standard_warn=False)
|
339
|
+
deprecate("output_type=='np'", "0.34.0", deprecation_message, standard_warn=False)
|
338
340
|
return get_1d_sincos_pos_embed_from_grid_np(embed_dim=embed_dim, pos=pos)
|
339
341
|
if embed_dim % 2 != 0:
|
340
342
|
raise ValueError("embed_dim must be divisible by 2")
|
@@ -1152,10 +1154,13 @@ def get_1d_rotary_pos_embed(
|
|
1152
1154
|
/ linear_factor
|
1153
1155
|
) # [D/2]
|
1154
1156
|
freqs = torch.outer(pos, freqs) # type: ignore # [S, D/2]
|
1157
|
+
is_npu = freqs.device.type == "npu"
|
1158
|
+
if is_npu:
|
1159
|
+
freqs = freqs.float()
|
1155
1160
|
if use_real and repeat_interleave_real:
|
1156
1161
|
# flux, hunyuan-dit, cogvideox
|
1157
|
-
freqs_cos = freqs.cos().repeat_interleave(2, dim=1).float() # [S, D]
|
1158
|
-
freqs_sin = freqs.sin().repeat_interleave(2, dim=1).float() # [S, D]
|
1162
|
+
freqs_cos = freqs.cos().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float() # [S, D]
|
1163
|
+
freqs_sin = freqs.sin().repeat_interleave(2, dim=1, output_size=freqs.shape[1] * 2).float() # [S, D]
|
1159
1164
|
return freqs_cos, freqs_sin
|
1160
1165
|
elif use_real:
|
1161
1166
|
# stable audio, allegro
|
@@ -1199,7 +1204,7 @@ def apply_rotary_emb(
|
|
1199
1204
|
x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
|
1200
1205
|
x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
|
1201
1206
|
elif use_real_unbind_dim == -2:
|
1202
|
-
# Used for Stable Audio
|
1207
|
+
# Used for Stable Audio, OmniGen and CogView4
|
1203
1208
|
x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
|
1204
1209
|
x_rotated = torch.cat([-x_imag, x_real], dim=-1)
|
1205
1210
|
else:
|
@@ -1248,7 +1253,8 @@ class FluxPosEmbed(nn.Module):
|
|
1248
1253
|
sin_out = []
|
1249
1254
|
pos = ids.float()
|
1250
1255
|
is_mps = ids.device.type == "mps"
|
1251
|
-
freqs_dtype = torch.float32 if is_mps else torch.float64
|
1256
|
+
is_npu = ids.device.type == "npu"
|
1257
|
+
freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
|
1252
1258
|
for i in range(n_axes):
|
1253
1259
|
cos, sin = get_1d_rotary_pos_embed(
|
1254
1260
|
self.axes_dim[i],
|
@@ -1786,7 +1792,7 @@ class LuminaCombinedTimestepCaptionEmbedding(nn.Module):
|
|
1786
1792
|
def forward(self, timestep, caption_feat, caption_mask):
|
1787
1793
|
# timestep embedding:
|
1788
1794
|
time_freq = self.time_proj(timestep)
|
1789
|
-
time_embed = self.timestep_embedder(time_freq.to(dtype=self.timestep_embedder.linear_1.weight.dtype))
|
1795
|
+
time_embed = self.timestep_embedder(time_freq.to(dtype=caption_feat.dtype))
|
1790
1796
|
|
1791
1797
|
# caption condition embedding:
|
1792
1798
|
caption_mask_float = caption_mask.float().unsqueeze(-1)
|
@@ -2582,6 +2588,11 @@ class MultiIPAdapterImageProjection(nn.Module):
|
|
2582
2588
|
super().__init__()
|
2583
2589
|
self.image_projection_layers = nn.ModuleList(IPAdapterImageProjectionLayers)
|
2584
2590
|
|
2591
|
+
@property
|
2592
|
+
def num_ip_adapters(self) -> int:
|
2593
|
+
"""Number of IP-Adapters loaded."""
|
2594
|
+
return len(self.image_projection_layers)
|
2595
|
+
|
2585
2596
|
def forward(self, image_embeds: List[torch.Tensor]):
|
2586
2597
|
projected_image_embeds = []
|
2587
2598
|
|
diffusers/models/model_loading_utils.py
CHANGED
@@ -1,5 +1,5 @@
|
|
1
1
|
# coding=utf-8
|
2
|
-
# Copyright 2024 The HuggingFace Inc. team.
|
2
|
+
# Copyright 2025 The HuggingFace Inc. team.
|
3
3
|
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
4
4
|
#
|
5
5
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
@@ -20,12 +20,15 @@ import os
|
|
20
20
|
from array import array
|
21
21
|
from collections import OrderedDict
|
22
22
|
from pathlib import Path
|
23
|
-
from typing import List, Optional, Union
|
23
|
+
from typing import Dict, List, Optional, Union
|
24
|
+
from zipfile import is_zipfile
|
24
25
|
|
25
26
|
import safetensors
|
26
27
|
import torch
|
28
|
+
from huggingface_hub import DDUFEntry
|
27
29
|
from huggingface_hub.utils import EntryNotFoundError
|
28
30
|
|
31
|
+
from ..quantizers import DiffusersQuantizer
|
29
32
|
from ..utils import (
|
30
33
|
GGUF_FILE_EXTENSION,
|
31
34
|
SAFE_WEIGHTS_INDEX_NAME,
|
@@ -54,7 +57,7 @@ _CLASS_REMAPPING_DICT = {
|
|
54
57
|
|
55
58
|
if is_accelerate_available():
|
56
59
|
from accelerate import infer_auto_device_map
|
57
|
-
from accelerate.utils import get_balanced_memory, get_max_memory, set_module_tensor_to_device
|
60
|
+
from accelerate.utils import get_balanced_memory, get_max_memory, offload_weight, set_module_tensor_to_device
|
58
61
|
|
59
62
|
|
60
63
|
# Adapted from `transformers` (see modeling_utils.py)
|
@@ -131,27 +134,61 @@ def _fetch_remapped_cls_from_config(config, old_class):
|
|
131
134
|
return old_class
|
132
135
|
|
133
136
|
|
134
|
-
def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
|
137
|
+
def _determine_param_device(param_name: str, device_map: Optional[Dict[str, Union[int, str, torch.device]]]):
|
138
|
+
"""
|
139
|
+
Find the device of param_name from the device_map.
|
140
|
+
"""
|
141
|
+
if device_map is None:
|
142
|
+
return "cpu"
|
143
|
+
else:
|
144
|
+
module_name = param_name
|
145
|
+
# find next higher level module that is defined in device_map:
|
146
|
+
# bert.lm_head.weight -> bert.lm_head -> bert -> ''
|
147
|
+
while len(module_name) > 0 and module_name not in device_map:
|
148
|
+
module_name = ".".join(module_name.split(".")[:-1])
|
149
|
+
if module_name == "" and "" not in device_map:
|
150
|
+
raise ValueError(f"{param_name} doesn't have any device set.")
|
151
|
+
return device_map[module_name]
|
152
|
+
|
153
|
+
|
154
|
+
def load_state_dict(
|
155
|
+
checkpoint_file: Union[str, os.PathLike],
|
156
|
+
dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
|
157
|
+
disable_mmap: bool = False,
|
158
|
+
map_location: Union[str, torch.device] = "cpu",
|
159
|
+
):
|
135
160
|
"""
|
136
161
|
Reads a checkpoint file, returning properly formatted errors if they arise.
|
137
162
|
"""
|
138
|
-
# TODO:
|
139
|
-
# when refactoring the _merge_sharded_checkpoints() method later.
|
163
|
+
# TODO: maybe refactor a bit this part where we pass a dict here
|
140
164
|
if isinstance(checkpoint_file, dict):
|
141
165
|
return checkpoint_file
|
142
166
|
try:
|
143
167
|
file_extension = os.path.basename(checkpoint_file).split(".")[-1]
|
144
168
|
if file_extension == SAFETENSORS_FILE_EXTENSION:
|
145
|
-
|
169
|
+
if dduf_entries:
|
170
|
+
# tensors are loaded on cpu
|
171
|
+
with dduf_entries[checkpoint_file].as_mmap() as mm:
|
172
|
+
return safetensors.torch.load(mm)
|
173
|
+
if disable_mmap:
|
174
|
+
return safetensors.torch.load(open(checkpoint_file, "rb").read())
|
175
|
+
else:
|
176
|
+
return safetensors.torch.load_file(checkpoint_file, device=map_location)
|
146
177
|
elif file_extension == GGUF_FILE_EXTENSION:
|
147
178
|
return load_gguf_checkpoint(checkpoint_file)
|
148
179
|
else:
|
180
|
+
extra_args = {}
|
149
181
|
weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
|
150
|
-
|
151
|
-
|
152
|
-
|
153
|
-
|
154
|
-
|
182
|
+
# mmap can only be used with files serialized with zipfile-based format.
|
183
|
+
if (
|
184
|
+
isinstance(checkpoint_file, str)
|
185
|
+
and map_location != "meta"
|
186
|
+
and is_torch_version(">=", "2.1.0")
|
187
|
+
and is_zipfile(checkpoint_file)
|
188
|
+
and not disable_mmap
|
189
|
+
):
|
190
|
+
extra_args = {"mmap": True}
|
191
|
+
return torch.load(checkpoint_file, map_location=map_location, **weights_only_kwarg, **extra_args)
|
155
192
|
except Exception as e:
|
156
193
|
try:
|
157
194
|
with open(checkpoint_file) as f:
|
@@ -168,29 +205,31 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
|
|
168
205
|
) from e
|
169
206
|
except (UnicodeDecodeError, ValueError):
|
170
207
|
raise OSError(
|
171
|
-
f"Unable to load weights from checkpoint file for '{checkpoint_file}'
|
208
|
+
f"Unable to load weights from checkpoint file for '{checkpoint_file}' at '{checkpoint_file}'. "
|
172
209
|
)
|
173
210
|
|
174
211
|
|
175
212
|
def load_model_dict_into_meta(
|
176
213
|
model,
|
177
214
|
state_dict: OrderedDict,
|
178
|
-
device: Optional[Union[str, torch.device]] = None,
|
179
215
|
dtype: Optional[Union[str, torch.dtype]] = None,
|
180
216
|
model_name_or_path: Optional[str] = None,
|
181
|
-
hf_quantizer=None,
|
182
|
-
keep_in_fp32_modules=None,
|
217
|
+
hf_quantizer: Optional[DiffusersQuantizer] = None,
|
218
|
+
keep_in_fp32_modules: Optional[List] = None,
|
219
|
+
device_map: Optional[Dict[str, Union[int, str, torch.device]]] = None,
|
220
|
+
unexpected_keys: Optional[List[str]] = None,
|
221
|
+
offload_folder: Optional[Union[str, os.PathLike]] = None,
|
222
|
+
offload_index: Optional[Dict] = None,
|
223
|
+
state_dict_index: Optional[Dict] = None,
|
224
|
+
state_dict_folder: Optional[Union[str, os.PathLike]] = None,
|
183
225
|
) -> List[str]:
|
184
|
-
|
185
|
-
|
186
|
-
|
187
|
-
|
188
|
-
dtype = dtype or torch.float32
|
189
|
-
is_quantized = hf_quantizer is not None
|
226
|
+
"""
|
227
|
+
This is somewhat similar to `_load_state_dict_into_model`, but deals with a model that has some or all of its
|
228
|
+
params on a `meta` device. It replaces the model params with the data from the `state_dict`
|
229
|
+
"""
|
190
230
|
|
191
|
-
|
231
|
+
is_quantized = hf_quantizer is not None
|
192
232
|
empty_state_dict = model.state_dict()
|
193
|
-
unexpected_keys = [param_name for param_name in state_dict if param_name not in empty_state_dict]
|
194
233
|
|
195
234
|
for param_name, param in state_dict.items():
|
196
235
|
if param_name not in empty_state_dict:
|
@@ -200,21 +239,38 @@ def load_model_dict_into_meta(
|
|
200
239
|
# We convert floating dtypes to the `dtype` passed. We also want to keep the buffers/params
|
201
240
|
# in int/uint/bool and not cast them.
|
202
241
|
# TODO: revisit cases when param.dtype == torch.float8_e4m3fn
|
203
|
-
if torch.is_floating_point(param):
|
204
|
-
if (
|
205
|
-
|
206
|
-
and any(
|
207
|
-
module_to_keep_in_fp32 in param_name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules
|
208
|
-
)
|
209
|
-
and dtype == torch.float16
|
242
|
+
if dtype is not None and torch.is_floating_point(param):
|
243
|
+
if keep_in_fp32_modules is not None and any(
|
244
|
+
module_to_keep_in_fp32 in param_name.split(".") for module_to_keep_in_fp32 in keep_in_fp32_modules
|
210
245
|
):
|
211
246
|
param = param.to(torch.float32)
|
212
|
-
|
213
|
-
|
247
|
+
set_module_kwargs["dtype"] = torch.float32
|
248
|
+
# For quantizers have save weights using torch.float8_e4m3fn
|
249
|
+
elif hf_quantizer is not None and param.dtype == getattr(torch, "float8_e4m3fn", None):
|
250
|
+
pass
|
214
251
|
else:
|
215
252
|
param = param.to(dtype)
|
216
|
-
|
217
|
-
|
253
|
+
set_module_kwargs["dtype"] = dtype
|
254
|
+
|
255
|
+
# For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
|
256
|
+
# uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
|
257
|
+
# Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
|
258
|
+
old_param = model
|
259
|
+
splits = param_name.split(".")
|
260
|
+
for split in splits:
|
261
|
+
old_param = getattr(old_param, split)
|
262
|
+
|
263
|
+
if not isinstance(old_param, (torch.nn.Parameter, torch.Tensor)):
|
264
|
+
old_param = None
|
265
|
+
|
266
|
+
if old_param is not None:
|
267
|
+
if dtype is None:
|
268
|
+
param = param.to(old_param.dtype)
|
269
|
+
|
270
|
+
if old_param.is_contiguous():
|
271
|
+
param = param.contiguous()
|
272
|
+
|
273
|
+
param_device = _determine_param_device(param_name, device_map)
|
218
274
|
|
219
275
|
# bnb params are flattened.
|
220
276
|
# gguf quants have a different shape based on the type of quantization applied
|
@@ -222,7 +278,9 @@ def load_model_dict_into_meta(
|
|
222
278
|
if (
|
223
279
|
is_quantized
|
224
280
|
and hf_quantizer.pre_quantized
|
225
|
-
and hf_quantizer.check_if_quantized_param(
|
281
|
+
and hf_quantizer.check_if_quantized_param(
|
282
|
+
model, param, param_name, state_dict, param_device=param_device
|
283
|
+
)
|
226
284
|
):
|
227
285
|
hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name], param)
|
228
286
|
else:
|
@@ -230,21 +288,25 @@ def load_model_dict_into_meta(
|
|
230
288
|
raise ValueError(
|
231
289
|
f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name].shape}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
|
232
290
|
)
|
233
|
-
|
234
|
-
|
235
|
-
|
291
|
+
if param_device == "disk":
|
292
|
+
offload_index = offload_weight(param, param_name, offload_folder, offload_index)
|
293
|
+
elif param_device == "cpu" and state_dict_index is not None:
|
294
|
+
state_dict_index = offload_weight(param, param_name, state_dict_folder, state_dict_index)
|
295
|
+
elif is_quantized and (
|
296
|
+
hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=param_device)
|
236
297
|
):
|
237
|
-
hf_quantizer.create_quantized_param(
|
298
|
+
hf_quantizer.create_quantized_param(
|
299
|
+
model, param, param_name, param_device, state_dict, unexpected_keys, dtype=dtype
|
300
|
+
)
|
238
301
|
else:
|
239
|
-
|
240
|
-
set_module_tensor_to_device(model, param_name, device, value=param, **set_module_kwargs)
|
241
|
-
else:
|
242
|
-
set_module_tensor_to_device(model, param_name, device, value=param)
|
302
|
+
set_module_tensor_to_device(model, param_name, param_device, value=param, **set_module_kwargs)
|
243
303
|
|
244
|
-
return
|
304
|
+
return offload_index, state_dict_index
|
245
305
|
|
246
306
|
|
247
|
-
def _load_state_dict_into_model(
|
307
|
+
def _load_state_dict_into_model(
|
308
|
+
model_to_load, state_dict: OrderedDict, assign_to_params_buffers: bool = False
|
309
|
+
) -> List[str]:
|
248
310
|
# Convert old format to new format if needed from a PyTorch state_dict
|
249
311
|
# copy state_dict so _load_from_state_dict can modify it
|
250
312
|
state_dict = state_dict.copy()
|
@@ -252,15 +314,19 @@ def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[
|
|
252
314
|
|
253
315
|
# PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
|
254
316
|
# so we need to apply the function recursively.
|
255
|
-
def load(module: torch.nn.Module, prefix: str = ""):
|
256
|
-
|
317
|
+
def load(module: torch.nn.Module, prefix: str = "", assign_to_params_buffers: bool = False):
|
318
|
+
local_metadata = {}
|
319
|
+
local_metadata["assign_to_params_buffers"] = assign_to_params_buffers
|
320
|
+
if assign_to_params_buffers and not is_torch_version(">=", "2.1"):
|
321
|
+
logger.info("You need to have torch>=2.1 in order to load the model with assign_to_params_buffers=True")
|
322
|
+
args = (state_dict, prefix, local_metadata, True, [], [], error_msgs)
|
257
323
|
module._load_from_state_dict(*args)
|
258
324
|
|
259
325
|
for name, child in module._modules.items():
|
260
326
|
if child is not None:
|
261
|
-
load(child, prefix + name + ".")
|
327
|
+
load(child, prefix + name + ".", assign_to_params_buffers)
|
262
328
|
|
263
|
-
load(model_to_load)
|
329
|
+
load(model_to_load, assign_to_params_buffers=assign_to_params_buffers)
|
264
330
|
|
265
331
|
return error_msgs
|
266
332
|
|
@@ -279,6 +345,7 @@ def _fetch_index_file(
|
|
279
345
|
revision,
|
280
346
|
user_agent,
|
281
347
|
commit_hash,
|
348
|
+
dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
|
282
349
|
):
|
283
350
|
if is_local:
|
284
351
|
index_file = Path(
|
@@ -304,43 +371,16 @@ def _fetch_index_file(
|
|
304
371
|
subfolder=None,
|
305
372
|
user_agent=user_agent,
|
306
373
|
commit_hash=commit_hash,
|
374
|
+
dduf_entries=dduf_entries,
|
307
375
|
)
|
308
|
-
|
376
|
+
if not dduf_entries:
|
377
|
+
index_file = Path(index_file)
|
309
378
|
except (EntryNotFoundError, EnvironmentError):
|
310
379
|
index_file = None
|
311
380
|
|
312
381
|
return index_file
|
313
382
|
|
314
383
|
|
315
|
-
# Adapted from
|
316
|
-
# https://github.com/bghira/SimpleTuner/blob/cea2457ab063f6dedb9e697830ae68a96be90641/helpers/training/save_hooks.py#L64
|
317
|
-
def _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata):
|
318
|
-
weight_map = sharded_metadata.get("weight_map", None)
|
319
|
-
if weight_map is None:
|
320
|
-
raise KeyError("'weight_map' key not found in the shard index file.")
|
321
|
-
|
322
|
-
# Collect all unique safetensors files from weight_map
|
323
|
-
files_to_load = set(weight_map.values())
|
324
|
-
is_safetensors = all(f.endswith(".safetensors") for f in files_to_load)
|
325
|
-
merged_state_dict = {}
|
326
|
-
|
327
|
-
# Load tensors from each unique file
|
328
|
-
for file_name in files_to_load:
|
329
|
-
part_file_path = os.path.join(sharded_ckpt_cached_folder, file_name)
|
330
|
-
if not os.path.exists(part_file_path):
|
331
|
-
raise FileNotFoundError(f"Part file {file_name} not found.")
|
332
|
-
|
333
|
-
if is_safetensors:
|
334
|
-
with safetensors.safe_open(part_file_path, framework="pt", device="cpu") as f:
|
335
|
-
for tensor_key in f.keys():
|
336
|
-
if tensor_key in weight_map:
|
337
|
-
merged_state_dict[tensor_key] = f.get_tensor(tensor_key)
|
338
|
-
else:
|
339
|
-
merged_state_dict.update(torch.load(part_file_path, weights_only=True, map_location="cpu"))
|
340
|
-
|
341
|
-
return merged_state_dict
|
342
|
-
|
343
|
-
|
344
384
|
def _fetch_index_file_legacy(
|
345
385
|
is_local,
|
346
386
|
pretrained_model_name_or_path,
|
@@ -355,6 +395,7 @@ def _fetch_index_file_legacy(
|
|
355
395
|
revision,
|
356
396
|
user_agent,
|
357
397
|
commit_hash,
|
398
|
+
dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
|
358
399
|
):
|
359
400
|
if is_local:
|
360
401
|
index_file = Path(
|
@@ -395,6 +436,7 @@ def _fetch_index_file_legacy(
|
|
395
436
|
subfolder=None,
|
396
437
|
user_agent=user_agent,
|
397
438
|
commit_hash=commit_hash,
|
439
|
+
dduf_entries=dduf_entries,
|
398
440
|
)
|
399
441
|
index_file = Path(index_file)
|
400
442
|
deprecation_message = f"This serialization format is now deprecated to standardize the serialization format between `transformers` and `diffusers`. We recommend you to remove the existing files associated with the current variant ({variant}) and re-obtain them by running a `save_pretrained()`."
|