diffusers 0.33.1__py3-none-any.whl → 0.35.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +145 -1
- diffusers/callbacks.py +35 -0
- diffusers/commands/__init__.py +1 -1
- diffusers/commands/custom_blocks.py +134 -0
- diffusers/commands/diffusers_cli.py +3 -1
- diffusers/commands/env.py +1 -1
- diffusers/commands/fp16_safetensors.py +2 -2
- diffusers/configuration_utils.py +11 -2
- diffusers/dependency_versions_check.py +1 -1
- diffusers/dependency_versions_table.py +3 -3
- diffusers/experimental/rl/value_guided_sampling.py +1 -1
- diffusers/guiders/__init__.py +41 -0
- diffusers/guiders/adaptive_projected_guidance.py +188 -0
- diffusers/guiders/auto_guidance.py +190 -0
- diffusers/guiders/classifier_free_guidance.py +141 -0
- diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
- diffusers/guiders/frequency_decoupled_guidance.py +327 -0
- diffusers/guiders/guider_utils.py +309 -0
- diffusers/guiders/perturbed_attention_guidance.py +271 -0
- diffusers/guiders/skip_layer_guidance.py +262 -0
- diffusers/guiders/smoothed_energy_guidance.py +251 -0
- diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
- diffusers/hooks/__init__.py +17 -0
- diffusers/hooks/_common.py +56 -0
- diffusers/hooks/_helpers.py +293 -0
- diffusers/hooks/faster_cache.py +9 -8
- diffusers/hooks/first_block_cache.py +259 -0
- diffusers/hooks/group_offloading.py +332 -227
- diffusers/hooks/hooks.py +58 -3
- diffusers/hooks/layer_skip.py +263 -0
- diffusers/hooks/layerwise_casting.py +5 -10
- diffusers/hooks/pyramid_attention_broadcast.py +15 -12
- diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
- diffusers/hooks/utils.py +43 -0
- diffusers/image_processor.py +7 -2
- diffusers/loaders/__init__.py +10 -0
- diffusers/loaders/ip_adapter.py +260 -18
- diffusers/loaders/lora_base.py +261 -127
- diffusers/loaders/lora_conversion_utils.py +657 -35
- diffusers/loaders/lora_pipeline.py +2778 -1246
- diffusers/loaders/peft.py +78 -112
- diffusers/loaders/single_file.py +2 -2
- diffusers/loaders/single_file_model.py +64 -15
- diffusers/loaders/single_file_utils.py +395 -7
- diffusers/loaders/textual_inversion.py +3 -2
- diffusers/loaders/transformer_flux.py +10 -11
- diffusers/loaders/transformer_sd3.py +8 -3
- diffusers/loaders/unet.py +24 -21
- diffusers/loaders/unet_loader_utils.py +6 -3
- diffusers/loaders/utils.py +1 -1
- diffusers/models/__init__.py +23 -1
- diffusers/models/activations.py +5 -5
- diffusers/models/adapter.py +2 -3
- diffusers/models/attention.py +488 -7
- diffusers/models/attention_dispatch.py +1218 -0
- diffusers/models/attention_flax.py +10 -10
- diffusers/models/attention_processor.py +113 -667
- diffusers/models/auto_model.py +49 -12
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
- diffusers/models/autoencoders/autoencoder_dc.py +17 -4
- diffusers/models/autoencoders/autoencoder_kl.py +5 -5
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
- diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1110 -0
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
- diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
- diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_wan.py +626 -62
- diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
- diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
- diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
- diffusers/models/autoencoders/vae.py +13 -2
- diffusers/models/autoencoders/vq_model.py +2 -2
- diffusers/models/cache_utils.py +32 -10
- diffusers/models/controlnet.py +1 -1
- diffusers/models/controlnet_flux.py +1 -1
- diffusers/models/controlnet_sd3.py +1 -1
- diffusers/models/controlnet_sparsectrl.py +1 -1
- diffusers/models/controlnets/__init__.py +1 -0
- diffusers/models/controlnets/controlnet.py +3 -3
- diffusers/models/controlnets/controlnet_flax.py +1 -1
- diffusers/models/controlnets/controlnet_flux.py +21 -20
- diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
- diffusers/models/controlnets/controlnet_sana.py +290 -0
- diffusers/models/controlnets/controlnet_sd3.py +1 -1
- diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
- diffusers/models/controlnets/controlnet_union.py +5 -5
- diffusers/models/controlnets/controlnet_xs.py +7 -7
- diffusers/models/controlnets/multicontrolnet.py +4 -5
- diffusers/models/controlnets/multicontrolnet_union.py +5 -6
- diffusers/models/downsampling.py +2 -2
- diffusers/models/embeddings.py +36 -46
- diffusers/models/embeddings_flax.py +2 -2
- diffusers/models/lora.py +3 -3
- diffusers/models/model_loading_utils.py +233 -1
- diffusers/models/modeling_flax_utils.py +1 -2
- diffusers/models/modeling_utils.py +203 -108
- diffusers/models/normalization.py +4 -4
- diffusers/models/resnet.py +2 -2
- diffusers/models/resnet_flax.py +1 -1
- diffusers/models/transformers/__init__.py +7 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
- diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
- diffusers/models/transformers/consisid_transformer_3d.py +1 -1
- diffusers/models/transformers/dit_transformer_2d.py +2 -2
- diffusers/models/transformers/dual_transformer_2d.py +1 -1
- diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
- diffusers/models/transformers/latte_transformer_3d.py +4 -5
- diffusers/models/transformers/lumina_nextdit2d.py +2 -2
- diffusers/models/transformers/pixart_transformer_2d.py +3 -3
- diffusers/models/transformers/prior_transformer.py +1 -1
- diffusers/models/transformers/sana_transformer.py +8 -3
- diffusers/models/transformers/stable_audio_transformer.py +5 -9
- diffusers/models/transformers/t5_film_transformer.py +3 -3
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +1 -1
- diffusers/models/transformers/transformer_chroma.py +641 -0
- diffusers/models/transformers/transformer_cogview3plus.py +5 -10
- diffusers/models/transformers/transformer_cogview4.py +353 -27
- diffusers/models/transformers/transformer_cosmos.py +586 -0
- diffusers/models/transformers/transformer_flux.py +376 -138
- diffusers/models/transformers/transformer_hidream_image.py +942 -0
- diffusers/models/transformers/transformer_hunyuan_video.py +12 -8
- diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
- diffusers/models/transformers/transformer_ltx.py +105 -24
- diffusers/models/transformers/transformer_lumina2.py +1 -1
- diffusers/models/transformers/transformer_mochi.py +1 -1
- diffusers/models/transformers/transformer_omnigen.py +2 -2
- diffusers/models/transformers/transformer_qwenimage.py +645 -0
- diffusers/models/transformers/transformer_sd3.py +7 -7
- diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/transformers/transformer_wan.py +316 -87
- diffusers/models/transformers/transformer_wan_vace.py +387 -0
- diffusers/models/unets/unet_1d.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +1 -1
- diffusers/models/unets/unet_2d_blocks.py +1 -1
- diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
- diffusers/models/unets/unet_2d_condition.py +4 -3
- diffusers/models/unets/unet_2d_condition_flax.py +2 -2
- diffusers/models/unets/unet_3d_blocks.py +1 -1
- diffusers/models/unets/unet_3d_condition.py +3 -3
- diffusers/models/unets/unet_i2vgen_xl.py +3 -3
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +2 -2
- diffusers/models/unets/unet_stable_cascade.py +1 -1
- diffusers/models/upsampling.py +2 -2
- diffusers/models/vae_flax.py +2 -2
- diffusers/models/vq_model.py +1 -1
- diffusers/modular_pipelines/__init__.py +83 -0
- diffusers/modular_pipelines/components_manager.py +1068 -0
- diffusers/modular_pipelines/flux/__init__.py +66 -0
- diffusers/modular_pipelines/flux/before_denoise.py +689 -0
- diffusers/modular_pipelines/flux/decoders.py +109 -0
- diffusers/modular_pipelines/flux/denoise.py +227 -0
- diffusers/modular_pipelines/flux/encoders.py +412 -0
- diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
- diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
- diffusers/modular_pipelines/modular_pipeline.py +2446 -0
- diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
- diffusers/modular_pipelines/node_utils.py +665 -0
- diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
- diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
- diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
- diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
- diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
- diffusers/modular_pipelines/wan/__init__.py +66 -0
- diffusers/modular_pipelines/wan/before_denoise.py +365 -0
- diffusers/modular_pipelines/wan/decoders.py +105 -0
- diffusers/modular_pipelines/wan/denoise.py +261 -0
- diffusers/modular_pipelines/wan/encoders.py +242 -0
- diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
- diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
- diffusers/pipelines/__init__.py +68 -6
- diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
- diffusers/pipelines/amused/pipeline_amused.py +7 -6
- diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
- diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
- diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +22 -13
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
- diffusers/pipelines/auto_pipeline.py +23 -20
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
- diffusers/pipelines/chroma/__init__.py +49 -0
- diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
- diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
- diffusers/pipelines/chroma/pipeline_output.py +21 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +17 -16
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +17 -16
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +18 -17
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +17 -16
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
- diffusers/pipelines/cogview4/pipeline_cogview4.py +23 -22
- diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
- diffusers/pipelines/consisid/consisid_utils.py +2 -2
- diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +11 -10
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +226 -107
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +12 -8
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +207 -105
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
- diffusers/pipelines/cosmos/__init__.py +54 -0
- diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
- diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
- diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
- diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
- diffusers/pipelines/cosmos/pipeline_output.py +40 -0
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
- diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
- diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
- diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
- diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
- diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
- diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
- diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +8 -8
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
- diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +4 -2
- diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
- diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
- diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
- diffusers/pipelines/flux/__init__.py +4 -0
- diffusers/pipelines/flux/modeling_flux.py +1 -1
- diffusers/pipelines/flux/pipeline_flux.py +37 -36
- diffusers/pipelines/flux/pipeline_flux_control.py +9 -9
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +31 -23
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +3 -2
- diffusers/pipelines/flux/pipeline_flux_fill.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_img2img.py +40 -7
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +12 -7
- diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
- diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +2 -2
- diffusers/pipelines/flux/pipeline_output.py +6 -4
- diffusers/pipelines/free_init_utils.py +2 -2
- diffusers/pipelines/free_noise_utils.py +3 -3
- diffusers/pipelines/hidream_image/__init__.py +47 -0
- diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
- diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
- diffusers/pipelines/hunyuan_video/__init__.py +2 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +26 -25
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
- diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
- diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
- diffusers/pipelines/kolors/text_encoder.py +3 -3
- diffusers/pipelines/kolors/tokenizer.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
- diffusers/pipelines/latte/pipeline_latte.py +12 -12
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
- diffusers/pipelines/ltx/__init__.py +4 -0
- diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +64 -18
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +117 -38
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +63 -18
- diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
- diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
- diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
- diffusers/pipelines/mochi/pipeline_mochi.py +15 -14
- diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
- diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
- diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
- diffusers/pipelines/onnx_utils.py +15 -2
- diffusers/pipelines/pag/pag_utils.py +2 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
- diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
- diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
- diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
- diffusers/pipelines/pia/pipeline_pia.py +8 -6
- diffusers/pipelines/pipeline_flax_utils.py +5 -6
- diffusers/pipelines/pipeline_loading_utils.py +113 -15
- diffusers/pipelines/pipeline_utils.py +127 -48
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +14 -12
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +31 -11
- diffusers/pipelines/qwenimage/__init__.py +55 -0
- diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
- diffusers/pipelines/sana/__init__.py +4 -0
- diffusers/pipelines/sana/pipeline_sana.py +23 -21
- diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
- diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
- diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
- diffusers/pipelines/shap_e/camera.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
- diffusers/pipelines/shap_e/renderer.py +3 -3
- diffusers/pipelines/skyreels_v2/__init__.py +59 -0
- diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
- diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
- diffusers/pipelines/stable_diffusion/__init__.py +0 -7
- diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +12 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +11 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +10 -10
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
- diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
- diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +13 -12
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
- diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
- diffusers/pipelines/unclip/text_proj.py +2 -2
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
- diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
- diffusers/pipelines/visualcloze/__init__.py +52 -0
- diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
- diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
- diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
- diffusers/pipelines/wan/__init__.py +2 -0
- diffusers/pipelines/wan/pipeline_wan.py +91 -30
- diffusers/pipelines/wan/pipeline_wan_i2v.py +145 -45
- diffusers/pipelines/wan/pipeline_wan_vace.py +975 -0
- diffusers/pipelines/wan/pipeline_wan_video2video.py +14 -16
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
- diffusers/quantizers/__init__.py +3 -1
- diffusers/quantizers/base.py +17 -1
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
- diffusers/quantizers/bitsandbytes/utils.py +10 -7
- diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
- diffusers/quantizers/gguf/utils.py +108 -16
- diffusers/quantizers/pipe_quant_config.py +202 -0
- diffusers/quantizers/quantization_config.py +18 -16
- diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +31 -1
- diffusers/schedulers/__init__.py +3 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
- diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
- diffusers/schedulers/scheduling_consistency_models.py +1 -1
- diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
- diffusers/schedulers/scheduling_ddim.py +8 -8
- diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
- diffusers/schedulers/scheduling_ddim_flax.py +6 -6
- diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
- diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
- diffusers/schedulers/scheduling_ddpm.py +9 -9
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
- diffusers/schedulers/scheduling_deis_multistep.py +16 -9
- diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +18 -12
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
- diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +19 -13
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
- diffusers/schedulers/scheduling_edm_euler.py +20 -11
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
- diffusers/schedulers/scheduling_euler_discrete.py +3 -3
- diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
- diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
- diffusers/schedulers/scheduling_heun_discrete.py +2 -2
- diffusers/schedulers/scheduling_ipndm.py +2 -2
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
- diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
- diffusers/schedulers/scheduling_lcm.py +3 -3
- diffusers/schedulers/scheduling_lms_discrete.py +2 -2
- diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
- diffusers/schedulers/scheduling_pndm.py +4 -4
- diffusers/schedulers/scheduling_pndm_flax.py +4 -4
- diffusers/schedulers/scheduling_repaint.py +9 -9
- diffusers/schedulers/scheduling_sasolver.py +15 -15
- diffusers/schedulers/scheduling_scm.py +1 -2
- diffusers/schedulers/scheduling_sde_ve.py +1 -1
- diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
- diffusers/schedulers/scheduling_tcd.py +3 -3
- diffusers/schedulers/scheduling_unclip.py +5 -5
- diffusers/schedulers/scheduling_unipc_multistep.py +21 -12
- diffusers/schedulers/scheduling_utils.py +3 -3
- diffusers/schedulers/scheduling_utils_flax.py +2 -2
- diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
- diffusers/training_utils.py +91 -5
- diffusers/utils/__init__.py +15 -0
- diffusers/utils/accelerate_utils.py +1 -1
- diffusers/utils/constants.py +4 -0
- diffusers/utils/doc_utils.py +1 -1
- diffusers/utils/dummy_pt_objects.py +432 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +480 -0
- diffusers/utils/dynamic_modules_utils.py +85 -8
- diffusers/utils/export_utils.py +1 -1
- diffusers/utils/hub_utils.py +33 -17
- diffusers/utils/import_utils.py +151 -18
- diffusers/utils/logging.py +1 -1
- diffusers/utils/outputs.py +2 -1
- diffusers/utils/peft_utils.py +96 -10
- diffusers/utils/state_dict_utils.py +20 -3
- diffusers/utils/testing_utils.py +195 -17
- diffusers/utils/torch_utils.py +43 -5
- diffusers/video_processor.py +2 -2
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/METADATA +72 -57
- diffusers-0.35.0.dist-info/RECORD +703 -0
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/WHEEL +1 -1
- diffusers-0.33.1.dist-info/RECORD +0 -608
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
diffusers/utils/export_utils.py
CHANGED
@@ -155,7 +155,7 @@ def export_to_video(
         bitrate:
             Set a constant bitrate for the video encoding. Default is None causing `quality` parameter to be used instead.
             Better quality videos with smaller file sizes will result from using the `quality` variable bitrate parameter
-            rather than
+            rather than specifying a fixed bitrate with this parameter.

         macro_block_size:
             Size constraint for video. Width and height, must be divisible by this number. If not divisible by this number
diffusers/utils/hub_utils.py
CHANGED
@@ -304,8 +304,7 @@ def _get_model_file(
             raise EnvironmentError(
                 f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
                 "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
-                "token having permission to this repo with `token` or log in with `
-                "login`."
+                "token having permission to this repo with `token` or log in with `hf auth login`."
             ) from e
         except RevisionNotFoundError as e:
             raise EnvironmentError(
@@ -403,15 +402,17 @@ def _get_checkpoint_shard_files(
         allow_patterns = [os.path.join(subfolder, p) for p in allow_patterns]

     ignore_patterns = ["*.json", "*.md"]
-
-
-
-
-
-
-
-
-
+
+    # If the repo doesn't have the required shards, error out early even before downloading anything.
+    if not local_files_only:
+        model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token)
+        for shard_file in original_shard_filenames:
+            shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
+            if not shard_file_present:
+                raise EnvironmentError(
+                    f"{shards_path} does not appear to have a file named {shard_file} which is "
+                    "required according to the checkpoint index."
+                )

     try:
         # Load from URL
@@ -438,6 +439,11 @@ def _get_checkpoint_shard_files(
         ) from e

     cached_filenames = [os.path.join(cached_folder, f) for f in original_shard_filenames]
+    for cached_file in cached_filenames:
+        if not os.path.isfile(cached_file):
+            raise EnvironmentError(
+                f"{cached_folder} does not have a file named {cached_file} which is required according to the checkpoint index."
+            )

     return cached_filenames, sharded_metadata

@@ -467,6 +473,7 @@ class PushToHubMixin:
         token: Optional[str] = None,
         commit_message: Optional[str] = None,
         create_pr: bool = False,
+        subfolder: Optional[str] = None,
     ):
         """
         Uploads all files in `working_dir` to `repo_id`.
@@ -481,7 +488,12 @@ class PushToHubMixin:

         logger.info(f"Uploading the files of {working_dir} to {repo_id}.")
         return upload_folder(
-            repo_id=repo_id,
+            repo_id=repo_id,
+            folder_path=working_dir,
+            token=token,
+            commit_message=commit_message,
+            create_pr=create_pr,
+            path_in_repo=subfolder,
         )

     def push_to_hub(
@@ -493,6 +505,7 @@ class PushToHubMixin:
         create_pr: bool = False,
         safe_serialization: bool = True,
         variant: Optional[str] = None,
+        subfolder: Optional[str] = None,
     ) -> str:
         """
         Upload model, scheduler, or pipeline files to the 🤗 Hugging Face Hub.
@@ -508,8 +521,8 @@ class PushToHubMixin:
                 Whether to make the repo private. If `None` (default), the repo will be public unless the
                 organization's default is private. This value is ignored if the repo already exists.
             token (`str`, *optional*):
-                The token to use as HTTP bearer authorization for remote files. The token generated when running
-
+                The token to use as HTTP bearer authorization for remote files. The token generated when running `hf
+                auth login` (stored in `~/.huggingface`).
             create_pr (`bool`, *optional*, defaults to `False`):
                 Whether or not to create a PR with the uploaded files or directly commit.
             safe_serialization (`bool`, *optional*, defaults to `True`):
@@ -534,8 +547,9 @@ class PushToHubMixin:
        repo_id = create_repo(repo_id, private=private, token=token, exist_ok=True).repo_id

        # Create a new empty model card and eventually tag it
-
-
+        if not subfolder:
+            model_card = load_or_create_model_card(repo_id, token=token)
+            model_card = populate_model_card(model_card)

        # Save all files.
        save_kwargs = {"safe_serialization": safe_serialization}
@@ -546,7 +560,8 @@ class PushToHubMixin:
            self.save_pretrained(tmpdir, **save_kwargs)

            # Update model card if needed:
-
+            if not subfolder:
+                model_card.save(os.path.join(tmpdir, "README.md"))

            return self._upload_folder(
                tmpdir,
@@ -554,4 +569,5 @@ class PushToHubMixin:
                token=token,
                commit_message=commit_message,
                create_pr=create_pr,
+                subfolder=subfolder,
            )
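The hunks above add a `subfolder` argument to `PushToHubMixin.push_to_hub` and forward it to `huggingface_hub.upload_folder` as `path_in_repo`, so a single component can be pushed into a subdirectory of an existing repo. A minimal usage sketch, assuming network access and a logged-in token; the target repo id is a placeholder, not taken from the diff:

```python
# Sketch of the new `subfolder` argument on PushToHubMixin.push_to_hub (diffusers >= 0.35.0).
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="unet"
)

# Files land under `unet/` in the target repo because `subfolder` is forwarded to
# upload_folder(path_in_repo=...); per the hunk above, the auto-generated model card
# is skipped when a subfolder is given.
unet.push_to_hub("your-username/my-sdxl-finetune", subfolder="unet")
```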
diffusers/utils/import_utils.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,13 +16,14 @@ Import utilities: Utilities related to imports and our lazy inits.
 """

 import importlib.util
+import inspect
 import operator as op
 import os
 import sys
-from collections import OrderedDict
+from collections import OrderedDict, defaultdict
 from itertools import chain
 from types import ModuleType
-from typing import Any, Union
+from typing import Any, Tuple, Union

 from huggingface_hub.utils import is_jinja_available  # noqa: F401
 from packaging.version import Version, parse
@@ -35,7 +36,10 @@ if sys.version_info < (3, 8):
     import importlib_metadata
 else:
     import importlib.metadata as importlib_metadata
-
+try:
+    _package_map = importlib_metadata.packages_distributions()  # load-once to avoid expensive calls
+except Exception:
+    _package_map = None

 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

@@ -54,12 +58,33 @@ STR_OPERATION_TO_FUNC = {">": op.gt, ">=": op.ge, "==": op.eq, "!=": op.ne, "<="
 _is_google_colab = "google.colab" in sys.modules or any(k.startswith("COLAB_") for k in os.environ)


-def _is_package_available(pkg_name: str):
+def _is_package_available(pkg_name: str, get_dist_name: bool = False) -> Tuple[bool, str]:
+    global _package_map
     pkg_exists = importlib.util.find_spec(pkg_name) is not None
     pkg_version = "N/A"

     if pkg_exists:
+        if _package_map is None:
+            _package_map = defaultdict(list)
+            try:
+                # Fallback for Python < 3.10
+                for dist in importlib_metadata.distributions():
+                    _top_level_declared = (dist.read_text("top_level.txt") or "").split()
+                    _infered_opt_names = {
+                        f.parts[0] if len(f.parts) > 1 else inspect.getmodulename(f) for f in (dist.files or [])
+                    } - {None}
+                    _top_level_inferred = filter(lambda name: "." not in name, _infered_opt_names)
+                    for pkg in _top_level_declared or _top_level_inferred:
+                        _package_map[pkg].append(dist.metadata["Name"])
+            except Exception as _:
+                pass
         try:
+            if get_dist_name and pkg_name in _package_map and _package_map[pkg_name]:
+                if len(_package_map[pkg_name]) > 1:
+                    logger.warning(
+                        f"Multiple distributions found for package {pkg_name}. Picked distribution: {_package_map[pkg_name][0]}"
+                    )
+                pkg_name = _package_map[pkg_name][0]
             pkg_version = importlib_metadata.version(pkg_name)
             logger.debug(f"Successfully imported {pkg_name} version {pkg_version}")
         except (ImportError, importlib_metadata.PackageNotFoundError):
@@ -74,6 +99,7 @@ if USE_TORCH in ENV_VARS_TRUE_AND_AUTO_VALUES and USE_TF not in ENV_VARS_TRUE_VA
 else:
     logger.info("Disabling PyTorch because USE_TORCH is set")
     _torch_available = False
+    _torch_version = "N/A"

 _jax_version = "N/A"
 _flax_version = "N/A"
@@ -101,18 +127,20 @@ _onnx_available = importlib.util.find_spec("onnxruntime") is not None
 if _onnx_available:
     candidates = (
         "onnxruntime",
+        "onnxruntime-cann",
+        "onnxruntime-directml",
+        "ort_nightly_directml",
         "onnxruntime-gpu",
         "ort_nightly_gpu",
-        "onnxruntime-
+        "onnxruntime-migraphx",
         "onnxruntime-openvino",
-        "
+        "onnxruntime-qnn",
         "onnxruntime-rocm",
-        "onnxruntime-migraphx",
         "onnxruntime-training",
         "onnxruntime-vitisai",
     )
     _onnxruntime_version = None
-    # For the metadata, we have to look for both onnxruntime and onnxruntime-
+    # For the metadata, we have to look for both onnxruntime and onnxruntime-x
     for pkg in candidates:
         try:
             _onnxruntime_version = importlib_metadata.version(pkg)
@@ -164,6 +192,7 @@ _torch_xla_available, _torch_xla_version = _is_package_available("torch_xla")
 _torch_npu_available, _torch_npu_version = _is_package_available("torch_npu")
 _transformers_available, _transformers_version = _is_package_available("transformers")
 _hf_hub_available, _hf_hub_version = _is_package_available("huggingface_hub")
+_kernels_available, _kernels_version = _is_package_available("kernels")
 _inflect_available, _inflect_version = _is_package_available("inflect")
 _unidecode_available, _unidecode_version = _is_package_available("unidecode")
 _k_diffusion_available, _k_diffusion_version = _is_package_available("k_diffusion")
@@ -187,15 +216,15 @@ _xformers_available, _xformers_version = _is_package_available("xformers")
 _gguf_available, _gguf_version = _is_package_available("gguf")
 _torchao_available, _torchao_version = _is_package_available("torchao")
 _bitsandbytes_available, _bitsandbytes_version = _is_package_available("bitsandbytes")
-
-
-
-
-
-
-
-
-
+_optimum_quanto_available, _optimum_quanto_version = _is_package_available("optimum", get_dist_name=True)
+_pytorch_retinaface_available, _pytorch_retinaface_version = _is_package_available("pytorch_retinaface")
+_better_profanity_available, _better_profanity_version = _is_package_available("better_profanity")
+_nltk_available, _nltk_version = _is_package_available("nltk")
+_cosmos_guardrail_available, _cosmos_guardrail_version = _is_package_available("cosmos_guardrail")
+_sageattention_available, _sageattention_version = _is_package_available("sageattention")
+_flash_attn_available, _flash_attn_version = _is_package_available("flash_attn")
+_flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_attn_3")
+_kornia_available, _kornia_version = _is_package_available("kornia")


 def is_torch_available():
@@ -250,6 +279,10 @@ def is_accelerate_available():
     return _accelerate_available


+def is_kernels_available():
+    return _kernels_available
+
+
 def is_k_diffusion_available():
     return _k_diffusion_available

@@ -334,6 +367,42 @@ def is_timm_available():
     return _timm_available


+def is_pytorch_retinaface_available():
+    return _pytorch_retinaface_available
+
+
+def is_better_profanity_available():
+    return _better_profanity_available
+
+
+def is_nltk_available():
+    return _nltk_available
+
+
+def is_cosmos_guardrail_available():
+    return _cosmos_guardrail_available
+
+
+def is_hpu_available():
+    return all(importlib.util.find_spec(lib) for lib in ("habana_frameworks", "habana_frameworks.torch"))
+
+
+def is_sageattention_available():
+    return _sageattention_available
+
+
+def is_flash_attn_available():
+    return _flash_attn_available
+
+
+def is_flash_attn_3_available():
+    return _flash_attn_3_available
+
+
+def is_kornia_available():
+    return _kornia_available
+
+
 # docstyle-ignore
 FLAX_IMPORT_ERROR = """
 {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
@@ -482,6 +551,22 @@ QUANTO_IMPORT_ERROR = """
 install optimum-quanto`
 """

+# docstyle-ignore
+PYTORCH_RETINAFACE_IMPORT_ERROR = """
+{0} requires the pytorch_retinaface library but it was not found in your environment. You can install it with pip: `pip install pytorch_retinaface`
+"""
+
+# docstyle-ignore
+BETTER_PROFANITY_IMPORT_ERROR = """
+{0} requires the better_profanity library but it was not found in your environment. You can install it with pip: `pip install better_profanity`
+"""
+
+# docstyle-ignore
+NLTK_IMPORT_ERROR = """
+{0} requires the nltk library but it was not found in your environment. You can install it with pip: `pip install nltk`
+"""
+
+
 BACKENDS_MAPPING = OrderedDict(
     [
         ("bs4", (is_bs4_available, BS4_IMPORT_ERROR)),
@@ -510,6 +595,9 @@ BACKENDS_MAPPING = OrderedDict(
         ("gguf", (is_gguf_available, GGUF_IMPORT_ERROR)),
         ("torchao", (is_torchao_available, TORCHAO_IMPORT_ERROR)),
         ("quanto", (is_optimum_quanto_available, QUANTO_IMPORT_ERROR)),
+        ("pytorch_retinaface", (is_pytorch_retinaface_available, PYTORCH_RETINAFACE_IMPORT_ERROR)),
+        ("better_profanity", (is_better_profanity_available, BETTER_PROFANITY_IMPORT_ERROR)),
+        ("nltk", (is_nltk_available, NLTK_IMPORT_ERROR)),
     ]
 )

@@ -741,6 +829,51 @@ def is_optimum_quanto_version(operation: str, version: str):
     return compare_versions(parse(_optimum_quanto_version), operation, version)


+def is_xformers_version(operation: str, version: str):
+    """
+    Compares the current xformers version to a given reference with an operation.
+
+    Args:
+        operation (`str`):
+            A string representation of an operator, such as `">"` or `"<="`
+        version (`str`):
+            A version string
+    """
+    if not _xformers_available:
+        return False
+    return compare_versions(parse(_xformers_version), operation, version)
+
+
+def is_sageattention_version(operation: str, version: str):
+    """
+    Compares the current sageattention version to a given reference with an operation.
+
+    Args:
+        operation (`str`):
+            A string representation of an operator, such as `">"` or `"<="`
+        version (`str`):
+            A version string
+    """
+    if not _sageattention_available:
+        return False
+    return compare_versions(parse(_sageattention_version), operation, version)
+
+
+def is_flash_attn_version(operation: str, version: str):
+    """
+    Compares the current flash-attention version to a given reference with an operation.
+
+    Args:
+        operation (`str`):
+            A string representation of an operator, such as `">"` or `"<="`
+        version (`str`):
+            A version string
+    """
+    if not _flash_attn_available:
+        return False
+    return compare_versions(parse(_flash_attn_version), operation, version)
+
+
 def get_objects_from_module(module):
     """
     Returns a dict of object names and values in a module, while skipping private/internal objects
diffusers/utils/logging.py
CHANGED
diffusers/utils/outputs.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -71,6 +71,7 @@ class BaseOutput(OrderedDict):
                 cls,
                 torch.utils._pytree._dict_flatten,
                 lambda values, context: cls(**torch.utils._pytree._dict_unflatten(values, context)),
+                serialized_type_name=f"{cls.__module__}.{cls.__name__}",
             )

     def __post_init__(self) -> None:
diffusers/utils/peft_utils.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -21,9 +21,13 @@ from typing import Optional

 from packaging import version

-from .
+from . import logging
+from .import_utils import is_peft_available, is_peft_version, is_torch_available
+from .torch_utils import empty_device_cache


+logger = logging.get_logger(__name__)
+
 if is_torch_available():
     import torch

@@ -95,8 +99,7 @@ def recurse_remove_peft_layers(model):
             setattr(model, name, new_module)
             del module

-
-    torch.cuda.empty_cache()
+    empty_device_cache()
     return model


@@ -147,25 +150,27 @@ def unscale_lora_layers(model, weight: Optional[float] = None):
             module.set_scale(adapter_name, 1.0)


-def get_peft_kwargs(
+def get_peft_kwargs(
+    rank_dict, network_alpha_dict, peft_state_dict, is_unet=True, model_state_dict=None, adapter_name=None
+):
     rank_pattern = {}
     alpha_pattern = {}
     r = lora_alpha = list(rank_dict.values())[0]

     if len(set(rank_dict.values())) > 1:
-        # get the rank
+        # get the rank occurring the most number of times
         r = collections.Counter(rank_dict.values()).most_common()[0][0]

-        # for modules with rank different from the most
+        # for modules with rank different from the most occurring rank, add it to the `rank_pattern`
         rank_pattern = dict(filter(lambda x: x[1] != r, rank_dict.items()))
         rank_pattern = {k.split(".lora_B.")[0]: v for k, v in rank_pattern.items()}

     if network_alpha_dict is not None and len(network_alpha_dict) > 0:
         if len(set(network_alpha_dict.values())) > 1:
-            # get the alpha
+            # get the alpha occurring the most number of times
             lora_alpha = collections.Counter(network_alpha_dict.values()).most_common()[0][0]

-            # for modules with alpha different from the most
+            # for modules with alpha different from the most occurring alpha, add it to the `alpha_pattern`
             alpha_pattern = dict(filter(lambda x: x[1] != lora_alpha, network_alpha_dict.items()))
             if is_unet:
                 alpha_pattern = {
@@ -177,7 +182,6 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
     else:
         lora_alpha = set(network_alpha_dict.values()).pop()

-    # layer names without the Diffusers specific
     target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()})
     use_dora = any("lora_magnitude_vector" in k for k in peft_state_dict)
     # for now we know that the "bias" keys are only associated with `lora_B`.
@@ -192,6 +196,7 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
         "use_dora": use_dora,
         "lora_bias": lora_bias,
     }
+
     return lora_config_kwargs


@@ -288,3 +293,84 @@ def check_peft_version(min_version: str) -> None:
             f"The version of PEFT you are using is not compatible, please use a version that is greater"
             f" than {min_version}"
         )
+
+
+def _create_lora_config(
+    state_dict, network_alphas, metadata, rank_pattern_dict, is_unet=True, model_state_dict=None, adapter_name=None
+):
+    from peft import LoraConfig
+
+    if metadata is not None:
+        lora_config_kwargs = metadata
+    else:
+        lora_config_kwargs = get_peft_kwargs(
+            rank_pattern_dict,
+            network_alpha_dict=network_alphas,
+            peft_state_dict=state_dict,
+            is_unet=is_unet,
+            model_state_dict=model_state_dict,
+            adapter_name=adapter_name,
+        )
+
+    _maybe_raise_error_for_ambiguous_keys(lora_config_kwargs)
+
+    # Version checks for DoRA and lora_bias
+    if "use_dora" in lora_config_kwargs and lora_config_kwargs["use_dora"]:
+        if is_peft_version("<", "0.9.0"):
+            raise ValueError("DoRA requires PEFT >= 0.9.0. Please upgrade.")
+
+    if "lora_bias" in lora_config_kwargs and lora_config_kwargs["lora_bias"]:
+        if is_peft_version("<=", "0.13.2"):
+            raise ValueError("lora_bias requires PEFT >= 0.14.0. Please upgrade.")
+
+    try:
+        return LoraConfig(**lora_config_kwargs)
+    except TypeError as e:
+        raise TypeError("`LoraConfig` class could not be instantiated.") from e
+
+
+def _maybe_raise_error_for_ambiguous_keys(config):
+    rank_pattern = config["rank_pattern"].copy()
+    target_modules = config["target_modules"]
+
+    for key in list(rank_pattern.keys()):
+        # try to detect ambiguity
+        # `target_modules` can also be a str, in which case this loop would loop
+        # over the chars of the str. The technically correct way to match LoRA keys
+        # in PEFT is to use LoraModel._check_target_module_exists (lora_config, key).
+        # But this cuts it for now.
+        exact_matches = [mod for mod in target_modules if mod == key]
+        substring_matches = [mod for mod in target_modules if key in mod and mod != key]
+
+        if exact_matches and substring_matches:
+            if is_peft_version("<", "0.14.1"):
+                raise ValueError(
+                    "There are ambiguous keys present in this LoRA. To load it, please update your `peft` installation - `pip install -U peft`."
+                )
+
+
+def _maybe_warn_for_unhandled_keys(incompatible_keys, adapter_name):
+    warn_msg = ""
+    if incompatible_keys is not None:
+        # Check only for unexpected keys.
+        unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+        if unexpected_keys:
+            lora_unexpected_keys = [k for k in unexpected_keys if "lora_" in k and adapter_name in k]
+            if lora_unexpected_keys:
+                warn_msg = (
+                    f"Loading adapter weights from state_dict led to unexpected keys found in the model:"
+                    f" {', '.join(lora_unexpected_keys)}. "
+                )
+
+        # Filter missing keys specific to the current adapter.
+        missing_keys = getattr(incompatible_keys, "missing_keys", None)
+        if missing_keys:
+            lora_missing_keys = [k for k in missing_keys if "lora_" in k and adapter_name in k]
+            if lora_missing_keys:
+                warn_msg += (
+                    f"Loading adapter weights from state_dict led to missing keys in the model:"
+                    f" {', '.join(lora_missing_keys)}."
+                )
+
+    if warn_msg:
+        logger.warning(warn_msg)
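The new `_maybe_raise_error_for_ambiguous_keys` above flags LoRA keys that match one target module exactly while also being a substring of another, a pattern PEFT older than 0.14.1 resolves incorrectly. A self-contained sketch of that check, using made-up module names to show when it would trigger:

```python
# Standalone illustration of the exact-vs-substring ambiguity check; module names
# here are invented and only serve to demonstrate the condition.
target_modules = ["to_q", "to_k", "proj", "proj_out"]
rank_pattern = {"proj": 8}  # a per-module rank override extracted from a LoRA state dict

for key in rank_pattern:
    exact_matches = [mod for mod in target_modules if mod == key]
    substring_matches = [mod for mod in target_modules if key in mod and mod != key]
    if exact_matches and substring_matches:
        # "proj" matches exactly but is also a substring of "proj_out": ambiguous
        # under PEFT < 0.14.1, so diffusers asks the user to upgrade peft.
        print(f"ambiguous key: {key!r} (also a substring of {substring_matches})")
```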
diffusers/utils/state_dict_utils.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -16,6 +16,7 @@ State dict utilities: utility methods for converting state dicts easily
 """

 import enum
+import json

 from .import_utils import is_torch_available
 from .logging import get_logger
@@ -219,7 +220,7 @@ def convert_state_dict_to_diffusers(state_dict, original_type=None, **kwargs):
         kwargs (`dict`, *args*):
             Additional arguments to pass to the method.

-            - **adapter_name**: For example, in case of PEFT, some keys will be
+            - **adapter_name**: For example, in case of PEFT, some keys will be prepended
               with the adapter name, therefore needs a special handling. By default PEFT also takes care of that in
               `get_peft_model_state_dict` method:
               https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/utils/save_and_load.py#L92
@@ -290,7 +291,7 @@ def convert_state_dict_to_kohya(state_dict, original_type=None, **kwargs):
         kwargs (`dict`, *args*):
             Additional arguments to pass to the method.

-            - **adapter_name**: For example, in case of PEFT, some keys will be
+            - **adapter_name**: For example, in case of PEFT, some keys will be prepended
               with the adapter name, therefore needs a special handling. By default PEFT also takes care of that in
               `get_peft_model_state_dict` method:
               https://github.com/huggingface/peft/blob/ba0477f2985b1ba311b83459d29895c809404e99/src/peft/utils/save_and_load.py#L92
@@ -347,3 +348,19 @@ def state_dict_all_zero(state_dict, filter_str=None):
         state_dict = {k: v for k, v in state_dict.items() if any(f in k for f in filter_str)}

     return all(torch.all(param == 0).item() for param in state_dict.values())
+
+
+def _load_sft_state_dict_metadata(model_file: str):
+    import safetensors.torch
+
+    from ..loaders.lora_base import LORA_ADAPTER_METADATA_KEY
+
+    with safetensors.torch.safe_open(model_file, framework="pt", device="cpu") as f:
+        metadata = f.metadata() or {}
+
+    metadata.pop("format", None)
+    if metadata:
+        raw = metadata.get(LORA_ADAPTER_METADATA_KEY)
+        return json.loads(raw) if raw else None
+    else:
+        return None
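The new `_load_sft_state_dict_metadata` parses LoRA adapter configuration that newer diffusers releases embed in the safetensors file header under `LORA_ADAPTER_METADATA_KEY` (defined in `diffusers.loaders.lora_base`). A hedged sketch of inspecting that header directly with the `safetensors` API; the file path is a placeholder:

```python
# Inspect the safetensors header metadata that the new helper parses.
import json

from safetensors import safe_open

with safe_open("pytorch_lora_weights.safetensors", framework="pt", device="cpu") as f:
    metadata = f.metadata() or {}

metadata.pop("format", None)  # "format" is written by safetensors itself, not adapter config
for key, raw in metadata.items():
    try:
        print(key, json.loads(raw))  # adapter config values are stored as JSON-encoded strings
    except (TypeError, ValueError):
        print(key, raw)
```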