diffusers 0.33.1__py3-none-any.whl → 0.35.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- diffusers/__init__.py +145 -1
- diffusers/callbacks.py +35 -0
- diffusers/commands/__init__.py +1 -1
- diffusers/commands/custom_blocks.py +134 -0
- diffusers/commands/diffusers_cli.py +3 -1
- diffusers/commands/env.py +1 -1
- diffusers/commands/fp16_safetensors.py +2 -2
- diffusers/configuration_utils.py +11 -2
- diffusers/dependency_versions_check.py +1 -1
- diffusers/dependency_versions_table.py +3 -3
- diffusers/experimental/rl/value_guided_sampling.py +1 -1
- diffusers/guiders/__init__.py +41 -0
- diffusers/guiders/adaptive_projected_guidance.py +188 -0
- diffusers/guiders/auto_guidance.py +190 -0
- diffusers/guiders/classifier_free_guidance.py +141 -0
- diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
- diffusers/guiders/frequency_decoupled_guidance.py +327 -0
- diffusers/guiders/guider_utils.py +309 -0
- diffusers/guiders/perturbed_attention_guidance.py +271 -0
- diffusers/guiders/skip_layer_guidance.py +262 -0
- diffusers/guiders/smoothed_energy_guidance.py +251 -0
- diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
- diffusers/hooks/__init__.py +17 -0
- diffusers/hooks/_common.py +56 -0
- diffusers/hooks/_helpers.py +293 -0
- diffusers/hooks/faster_cache.py +9 -8
- diffusers/hooks/first_block_cache.py +259 -0
- diffusers/hooks/group_offloading.py +332 -227
- diffusers/hooks/hooks.py +58 -3
- diffusers/hooks/layer_skip.py +263 -0
- diffusers/hooks/layerwise_casting.py +5 -10
- diffusers/hooks/pyramid_attention_broadcast.py +15 -12
- diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
- diffusers/hooks/utils.py +43 -0
- diffusers/image_processor.py +7 -2
- diffusers/loaders/__init__.py +10 -0
- diffusers/loaders/ip_adapter.py +260 -18
- diffusers/loaders/lora_base.py +261 -127
- diffusers/loaders/lora_conversion_utils.py +657 -35
- diffusers/loaders/lora_pipeline.py +2778 -1246
- diffusers/loaders/peft.py +78 -112
- diffusers/loaders/single_file.py +2 -2
- diffusers/loaders/single_file_model.py +64 -15
- diffusers/loaders/single_file_utils.py +395 -7
- diffusers/loaders/textual_inversion.py +3 -2
- diffusers/loaders/transformer_flux.py +10 -11
- diffusers/loaders/transformer_sd3.py +8 -3
- diffusers/loaders/unet.py +24 -21
- diffusers/loaders/unet_loader_utils.py +6 -3
- diffusers/loaders/utils.py +1 -1
- diffusers/models/__init__.py +23 -1
- diffusers/models/activations.py +5 -5
- diffusers/models/adapter.py +2 -3
- diffusers/models/attention.py +488 -7
- diffusers/models/attention_dispatch.py +1218 -0
- diffusers/models/attention_flax.py +10 -10
- diffusers/models/attention_processor.py +113 -667
- diffusers/models/auto_model.py +49 -12
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
- diffusers/models/autoencoders/autoencoder_dc.py +17 -4
- diffusers/models/autoencoders/autoencoder_kl.py +5 -5
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
- diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1110 -0
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
- diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
- diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_wan.py +626 -62
- diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
- diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
- diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
- diffusers/models/autoencoders/vae.py +13 -2
- diffusers/models/autoencoders/vq_model.py +2 -2
- diffusers/models/cache_utils.py +32 -10
- diffusers/models/controlnet.py +1 -1
- diffusers/models/controlnet_flux.py +1 -1
- diffusers/models/controlnet_sd3.py +1 -1
- diffusers/models/controlnet_sparsectrl.py +1 -1
- diffusers/models/controlnets/__init__.py +1 -0
- diffusers/models/controlnets/controlnet.py +3 -3
- diffusers/models/controlnets/controlnet_flax.py +1 -1
- diffusers/models/controlnets/controlnet_flux.py +21 -20
- diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
- diffusers/models/controlnets/controlnet_sana.py +290 -0
- diffusers/models/controlnets/controlnet_sd3.py +1 -1
- diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
- diffusers/models/controlnets/controlnet_union.py +5 -5
- diffusers/models/controlnets/controlnet_xs.py +7 -7
- diffusers/models/controlnets/multicontrolnet.py +4 -5
- diffusers/models/controlnets/multicontrolnet_union.py +5 -6
- diffusers/models/downsampling.py +2 -2
- diffusers/models/embeddings.py +36 -46
- diffusers/models/embeddings_flax.py +2 -2
- diffusers/models/lora.py +3 -3
- diffusers/models/model_loading_utils.py +233 -1
- diffusers/models/modeling_flax_utils.py +1 -2
- diffusers/models/modeling_utils.py +203 -108
- diffusers/models/normalization.py +4 -4
- diffusers/models/resnet.py +2 -2
- diffusers/models/resnet_flax.py +1 -1
- diffusers/models/transformers/__init__.py +7 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
- diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
- diffusers/models/transformers/consisid_transformer_3d.py +1 -1
- diffusers/models/transformers/dit_transformer_2d.py +2 -2
- diffusers/models/transformers/dual_transformer_2d.py +1 -1
- diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
- diffusers/models/transformers/latte_transformer_3d.py +4 -5
- diffusers/models/transformers/lumina_nextdit2d.py +2 -2
- diffusers/models/transformers/pixart_transformer_2d.py +3 -3
- diffusers/models/transformers/prior_transformer.py +1 -1
- diffusers/models/transformers/sana_transformer.py +8 -3
- diffusers/models/transformers/stable_audio_transformer.py +5 -9
- diffusers/models/transformers/t5_film_transformer.py +3 -3
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +1 -1
- diffusers/models/transformers/transformer_chroma.py +641 -0
- diffusers/models/transformers/transformer_cogview3plus.py +5 -10
- diffusers/models/transformers/transformer_cogview4.py +353 -27
- diffusers/models/transformers/transformer_cosmos.py +586 -0
- diffusers/models/transformers/transformer_flux.py +376 -138
- diffusers/models/transformers/transformer_hidream_image.py +942 -0
- diffusers/models/transformers/transformer_hunyuan_video.py +12 -8
- diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
- diffusers/models/transformers/transformer_ltx.py +105 -24
- diffusers/models/transformers/transformer_lumina2.py +1 -1
- diffusers/models/transformers/transformer_mochi.py +1 -1
- diffusers/models/transformers/transformer_omnigen.py +2 -2
- diffusers/models/transformers/transformer_qwenimage.py +645 -0
- diffusers/models/transformers/transformer_sd3.py +7 -7
- diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/transformers/transformer_wan.py +316 -87
- diffusers/models/transformers/transformer_wan_vace.py +387 -0
- diffusers/models/unets/unet_1d.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +1 -1
- diffusers/models/unets/unet_2d_blocks.py +1 -1
- diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
- diffusers/models/unets/unet_2d_condition.py +4 -3
- diffusers/models/unets/unet_2d_condition_flax.py +2 -2
- diffusers/models/unets/unet_3d_blocks.py +1 -1
- diffusers/models/unets/unet_3d_condition.py +3 -3
- diffusers/models/unets/unet_i2vgen_xl.py +3 -3
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +2 -2
- diffusers/models/unets/unet_stable_cascade.py +1 -1
- diffusers/models/upsampling.py +2 -2
- diffusers/models/vae_flax.py +2 -2
- diffusers/models/vq_model.py +1 -1
- diffusers/modular_pipelines/__init__.py +83 -0
- diffusers/modular_pipelines/components_manager.py +1068 -0
- diffusers/modular_pipelines/flux/__init__.py +66 -0
- diffusers/modular_pipelines/flux/before_denoise.py +689 -0
- diffusers/modular_pipelines/flux/decoders.py +109 -0
- diffusers/modular_pipelines/flux/denoise.py +227 -0
- diffusers/modular_pipelines/flux/encoders.py +412 -0
- diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
- diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
- diffusers/modular_pipelines/modular_pipeline.py +2446 -0
- diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
- diffusers/modular_pipelines/node_utils.py +665 -0
- diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
- diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
- diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
- diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
- diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
- diffusers/modular_pipelines/wan/__init__.py +66 -0
- diffusers/modular_pipelines/wan/before_denoise.py +365 -0
- diffusers/modular_pipelines/wan/decoders.py +105 -0
- diffusers/modular_pipelines/wan/denoise.py +261 -0
- diffusers/modular_pipelines/wan/encoders.py +242 -0
- diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
- diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
- diffusers/pipelines/__init__.py +68 -6
- diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
- diffusers/pipelines/amused/pipeline_amused.py +7 -6
- diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
- diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
- diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +22 -13
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
- diffusers/pipelines/auto_pipeline.py +23 -20
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
- diffusers/pipelines/chroma/__init__.py +49 -0
- diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
- diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
- diffusers/pipelines/chroma/pipeline_output.py +21 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +17 -16
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +17 -16
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +18 -17
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +17 -16
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
- diffusers/pipelines/cogview4/pipeline_cogview4.py +23 -22
- diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
- diffusers/pipelines/consisid/consisid_utils.py +2 -2
- diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +11 -10
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +226 -107
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +12 -8
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +207 -105
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
- diffusers/pipelines/cosmos/__init__.py +54 -0
- diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
- diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
- diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
- diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
- diffusers/pipelines/cosmos/pipeline_output.py +40 -0
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
- diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
- diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
- diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
- diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
- diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
- diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
- diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +8 -8
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
- diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +4 -2
- diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
- diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
- diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
- diffusers/pipelines/flux/__init__.py +4 -0
- diffusers/pipelines/flux/modeling_flux.py +1 -1
- diffusers/pipelines/flux/pipeline_flux.py +37 -36
- diffusers/pipelines/flux/pipeline_flux_control.py +9 -9
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +31 -23
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +3 -2
- diffusers/pipelines/flux/pipeline_flux_fill.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_img2img.py +40 -7
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +12 -7
- diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
- diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +2 -2
- diffusers/pipelines/flux/pipeline_output.py +6 -4
- diffusers/pipelines/free_init_utils.py +2 -2
- diffusers/pipelines/free_noise_utils.py +3 -3
- diffusers/pipelines/hidream_image/__init__.py +47 -0
- diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
- diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
- diffusers/pipelines/hunyuan_video/__init__.py +2 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +26 -25
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
- diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
- diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
- diffusers/pipelines/kolors/text_encoder.py +3 -3
- diffusers/pipelines/kolors/tokenizer.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
- diffusers/pipelines/latte/pipeline_latte.py +12 -12
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
- diffusers/pipelines/ltx/__init__.py +4 -0
- diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +64 -18
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +117 -38
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +63 -18
- diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
- diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
- diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
- diffusers/pipelines/mochi/pipeline_mochi.py +15 -14
- diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
- diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
- diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
- diffusers/pipelines/onnx_utils.py +15 -2
- diffusers/pipelines/pag/pag_utils.py +2 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
- diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
- diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
- diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
- diffusers/pipelines/pia/pipeline_pia.py +8 -6
- diffusers/pipelines/pipeline_flax_utils.py +5 -6
- diffusers/pipelines/pipeline_loading_utils.py +113 -15
- diffusers/pipelines/pipeline_utils.py +127 -48
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +14 -12
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +31 -11
- diffusers/pipelines/qwenimage/__init__.py +55 -0
- diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
- diffusers/pipelines/sana/__init__.py +4 -0
- diffusers/pipelines/sana/pipeline_sana.py +23 -21
- diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
- diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
- diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
- diffusers/pipelines/shap_e/camera.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
- diffusers/pipelines/shap_e/renderer.py +3 -3
- diffusers/pipelines/skyreels_v2/__init__.py +59 -0
- diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
- diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
- diffusers/pipelines/stable_diffusion/__init__.py +0 -7
- diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +12 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +11 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +10 -10
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
- diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
- diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +13 -12
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
- diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
- diffusers/pipelines/unclip/text_proj.py +2 -2
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
- diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
- diffusers/pipelines/visualcloze/__init__.py +52 -0
- diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
- diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
- diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
- diffusers/pipelines/wan/__init__.py +2 -0
- diffusers/pipelines/wan/pipeline_wan.py +91 -30
- diffusers/pipelines/wan/pipeline_wan_i2v.py +145 -45
- diffusers/pipelines/wan/pipeline_wan_vace.py +975 -0
- diffusers/pipelines/wan/pipeline_wan_video2video.py +14 -16
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
- diffusers/quantizers/__init__.py +3 -1
- diffusers/quantizers/base.py +17 -1
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
- diffusers/quantizers/bitsandbytes/utils.py +10 -7
- diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
- diffusers/quantizers/gguf/utils.py +108 -16
- diffusers/quantizers/pipe_quant_config.py +202 -0
- diffusers/quantizers/quantization_config.py +18 -16
- diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +31 -1
- diffusers/schedulers/__init__.py +3 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
- diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
- diffusers/schedulers/scheduling_consistency_models.py +1 -1
- diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
- diffusers/schedulers/scheduling_ddim.py +8 -8
- diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
- diffusers/schedulers/scheduling_ddim_flax.py +6 -6
- diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
- diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
- diffusers/schedulers/scheduling_ddpm.py +9 -9
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
- diffusers/schedulers/scheduling_deis_multistep.py +16 -9
- diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +18 -12
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
- diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +19 -13
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
- diffusers/schedulers/scheduling_edm_euler.py +20 -11
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
- diffusers/schedulers/scheduling_euler_discrete.py +3 -3
- diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
- diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
- diffusers/schedulers/scheduling_heun_discrete.py +2 -2
- diffusers/schedulers/scheduling_ipndm.py +2 -2
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
- diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
- diffusers/schedulers/scheduling_lcm.py +3 -3
- diffusers/schedulers/scheduling_lms_discrete.py +2 -2
- diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
- diffusers/schedulers/scheduling_pndm.py +4 -4
- diffusers/schedulers/scheduling_pndm_flax.py +4 -4
- diffusers/schedulers/scheduling_repaint.py +9 -9
- diffusers/schedulers/scheduling_sasolver.py +15 -15
- diffusers/schedulers/scheduling_scm.py +1 -2
- diffusers/schedulers/scheduling_sde_ve.py +1 -1
- diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
- diffusers/schedulers/scheduling_tcd.py +3 -3
- diffusers/schedulers/scheduling_unclip.py +5 -5
- diffusers/schedulers/scheduling_unipc_multistep.py +21 -12
- diffusers/schedulers/scheduling_utils.py +3 -3
- diffusers/schedulers/scheduling_utils_flax.py +2 -2
- diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
- diffusers/training_utils.py +91 -5
- diffusers/utils/__init__.py +15 -0
- diffusers/utils/accelerate_utils.py +1 -1
- diffusers/utils/constants.py +4 -0
- diffusers/utils/doc_utils.py +1 -1
- diffusers/utils/dummy_pt_objects.py +432 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +480 -0
- diffusers/utils/dynamic_modules_utils.py +85 -8
- diffusers/utils/export_utils.py +1 -1
- diffusers/utils/hub_utils.py +33 -17
- diffusers/utils/import_utils.py +151 -18
- diffusers/utils/logging.py +1 -1
- diffusers/utils/outputs.py +2 -1
- diffusers/utils/peft_utils.py +96 -10
- diffusers/utils/state_dict_utils.py +20 -3
- diffusers/utils/testing_utils.py +195 -17
- diffusers/utils/torch_utils.py +43 -5
- diffusers/video_processor.py +2 -2
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/METADATA +72 -57
- diffusers-0.35.0.dist-info/RECORD +703 -0
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/WHEEL +1 -1
- diffusers-0.33.1.dist-info/RECORD +0 -608
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
diffusers/models/controlnets/controlnet_union.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -752,7 +752,7 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             condition = self.controlnet_cond_embedding(cond)
             feat_seq = torch.mean(condition, dim=(2, 3))
             feat_seq = feat_seq + self.task_embedding[control_idx]
-            if from_multi:
+            if from_multi or len(control_type_idx) == 1:
                 inputs.append(feat_seq.unsqueeze(1))
                 condition_list.append(condition)
             else:
@@ -772,7 +772,7 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         for (idx, condition), scale in zip(enumerate(condition_list[:-1]), conditioning_scale):
             alpha = self.spatial_ch_projs(x[:, idx])
             alpha = alpha.unsqueeze(-1).unsqueeze(-1)
-            if from_multi:
+            if from_multi or len(control_type_idx) == 1:
                 controlnet_cond_fuser += condition + alpha
             else:
                 controlnet_cond_fuser += condition + alpha * scale
@@ -819,11 +819,11 @@ class ControlNetUnionModel(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         # 6. scaling
         if guess_mode and not self.config.global_pool_conditions:
             scales = torch.logspace(-1, 0, len(down_block_res_samples) + 1, device=sample.device)  # 0.1 to 1.0
-            if from_multi:
+            if from_multi or len(control_type_idx) == 1:
                 scales = scales * conditioning_scale[0]
             down_block_res_samples = [sample * scale for sample, scale in zip(down_block_res_samples, scales)]
             mid_block_res_sample = mid_block_res_sample * scales[-1]  # last one
-        elif from_multi:
+        elif from_multi or len(control_type_idx) == 1:
             down_block_res_samples = [sample * conditioning_scale[0] for sample in down_block_res_samples]
             mid_block_res_sample = mid_block_res_sample * conditioning_scale[0]

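Not part of the diff, but a minimal sketch of what the new guard means in practice: when a single control type is passed outside of `MultiControlNetUnionModel` (`from_multi=False`, one entry in `control_type_idx`), 0.35.0 now takes the same branch as the multi-controlnet path, so the first conditioning scale is applied globally rather than per condition.

```python
# Illustrative only; names mirror the variables in the hunk above.
from_multi = False
control_type_idx = [3]        # a single control type, e.g. one depth condition
conditioning_scale = [0.8]

if from_multi or len(control_type_idx) == 1:   # 0.35.0 behaviour
    scale_mode = "apply conditioning_scale[0] to all residuals"
else:                                          # 0.33.1 behaviour for the same inputs
    scale_mode = "apply a per-condition scale"
print(scale_mode)
```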
diffusers/models/controlnets/controlnet_xs.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -734,17 +734,17 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
             unet (`UNet2DConditionModel`):
                 The UNet model we want to control.
             controlnet (`ControlNetXSAdapter`):
-                The
+                The ControlNet-XS adapter with which the UNet will be fused. If none is given, a new ControlNet-XS
                 adapter will be created.
             size_ratio (float, *optional*, defaults to `None`):
-                Used to
+                Used to construct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details.
             ctrl_block_out_channels (`List[int]`, *optional*, defaults to `None`):
-                Used to
+                Used to construct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details,
                 where this parameter is called `block_out_channels`.
             time_embedding_mix (`float`, *optional*, defaults to None):
-                Used to
+                Used to construct the controlnet if none is given. See [`ControlNetXSAdapter.from_unet`] for details.
             ctrl_optional_kwargs (`Dict`, *optional*, defaults to `None`):
-                Passed to the `init` of the new
+                Passed to the `init` of the new controlnet if no controlnet was given.
        """
        if controlnet is None:
            controlnet = ControlNetXSAdapter.from_unet(
@@ -942,7 +942,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):

    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
    def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
-        r"""Enables the FreeU mechanism from https://
+        r"""Enables the FreeU mechanism from https://huggingface.co/papers/2309.11497.

        The suffixes after the scaling factors represent the stage blocks where they are being applied.

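The docstring fix above only touches the paper link; for reference, a hedged sketch of the `enable_freeu` call it documents. The method is copied from `UNet2DConditionModel`, so the signature is the same there; the tiny config below is made up purely so the snippet runs without downloading weights, and the scaling factors are examples rather than recommendations.

```python
from diffusers import UNet2DConditionModel

# Tiny, randomly initialised UNet just to demonstrate the call signature.
unet = UNet2DConditionModel(
    sample_size=16,
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=32,
    layers_per_block=1,
)
unet.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)  # illustrative values
unet.disable_freeu()                               # restores default behaviour
```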
diffusers/models/controlnets/multicontrolnet.py
CHANGED
@@ -4,9 +4,9 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import torch
 from torch import nn

-from ...models.controlnets.controlnet import ControlNetModel, ControlNetOutput
-from ...models.modeling_utils import ModelMixin
 from ...utils import logging
+from ..controlnets.controlnet import ControlNetModel, ControlNetOutput
+from ..modeling_utils import ModelMixin


 logger = logging.get_logger(__name__)
@@ -130,9 +130,8 @@ class MultiControlNetModel(ModelMixin):
                A path to a *directory* containing model weights saved using
                [`~models.controlnets.multicontrolnet.MultiControlNetModel.save_pretrained`], e.g.,
                `./my_model_directory/controlnet`.
-            torch_dtype (`
-                Override the default `torch.dtype` and load the model under this dtype.
-                will be automatically derived from the model's weights.
+            torch_dtype (`torch.dtype`, *optional*):
+                Override the default `torch.dtype` and load the model under this dtype.
            output_loading_info(`bool`, *optional*, defaults to `False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
diffusers/models/controlnets/multicontrolnet_union.py
CHANGED
@@ -4,10 +4,10 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 import torch
 from torch import nn

-from ...models.controlnets.controlnet import ControlNetOutput
-from ...models.controlnets.controlnet_union import ControlNetUnionModel
-from ...models.modeling_utils import ModelMixin
 from ...utils import logging
+from ..controlnets.controlnet import ControlNetOutput
+from ..controlnets.controlnet_union import ControlNetUnionModel
+from ..modeling_utils import ModelMixin


 logger = logging.get_logger(__name__)
@@ -143,9 +143,8 @@ class MultiControlNetUnionModel(ModelMixin):
                A path to a *directory* containing model weights saved using
                [`~models.controlnets.multicontrolnet.MultiControlNetUnionModel.save_pretrained`], e.g.,
                `./my_model_directory/controlnet`.
-            torch_dtype (`
-                Override the default `torch.dtype` and load the model under this dtype.
-                will be automatically derived from the model's weights.
+            torch_dtype (`torch.dtype`, *optional*):
+                Override the default `torch.dtype` and load the model under this dtype.
            output_loading_info(`bool`, *optional*, defaults to `False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
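Both docstring fixes above clarify that `torch_dtype` should be an actual `torch.dtype`. A minimal, hedged example of the documented behaviour; the checkpoint id is a familiar public one and is not taken from this diff.

```python
import torch
from diffusers import ControlNetModel

# Loads the weights under float16 instead of the checkpoint's default dtype.
controlnet = ControlNetModel.from_pretrained(
    "lllyasviel/sd-controlnet-canny", torch_dtype=torch.float16
)
print(next(controlnet.parameters()).dtype)  # torch.float16
```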
diffusers/models/downsampling.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -286,7 +286,7 @@ class KDownsample2D(nn.Module):


 class CogVideoXDownsample3D(nn.Module):
-    # Todo: Wait for paper
+    # Todo: Wait for paper release.
     r"""
     A 3D Downsampling layer using in [CogVideoX]() by Tsinghua University & ZhipuAI

diffusers/models/embeddings.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -31,7 +31,7 @@ def get_timestep_embedding(
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
-):
+) -> torch.Tensor:
    """
    This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.

@@ -97,7 +97,7 @@ def get_3d_sincos_pos_embed(
            The spatial dimension of positional embeddings. If an integer is provided, the same size is applied to both
            spatial dimensions (height and width).
        temporal_size (`int`):
-            The temporal dimension of
+            The temporal dimension of positional embeddings (number of frames).
        spatial_interpolation_scale (`float`, defaults to 1.0):
            Scale factor for spatial grid interpolation.
        temporal_interpolation_scale (`float`, defaults to 1.0):
@@ -169,7 +169,7 @@ def _get_3d_sincos_pos_embed_np(
            The spatial dimension of positional embeddings. If an integer is provided, the same size is applied to both
            spatial dimensions (height and width).
        temporal_size (`int`):
-            The temporal dimension of
+            The temporal dimension of positional embeddings (number of frames).
        spatial_interpolation_scale (`float`, defaults to 1.0):
            Scale factor for spatial grid interpolation.
        temporal_interpolation_scale (`float`, defaults to 1.0):
@@ -319,7 +319,7 @@ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid, output_type="np"):
    return emb


-def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"):
+def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np", flip_sin_to_cos=False):
    """
    This function generates 1D positional embeddings from a grid.

@@ -352,6 +352,11 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos, output_type="np"):
    emb_cos = torch.cos(out)  # (M, D/2)

    emb = torch.concat([emb_sin, emb_cos], dim=1)  # (M, D)
+
+    # flip sine and cosine embeddings
+    if flip_sin_to_cos:
+        emb = torch.cat([emb[:, embed_dim // 2 :], emb[:, : embed_dim // 2]], dim=1)
+
    return emb

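A small sketch of the new `flip_sin_to_cos` flag added above (assumes diffusers >= 0.35 and the torch code path via `output_type="pt"`): the default layout is `[sin | cos]`, and the flag swaps the halves to `[cos | sin]`.

```python
import torch
from diffusers.models.embeddings import get_1d_sincos_pos_embed_from_grid

pos = torch.arange(4, dtype=torch.float32)
emb = get_1d_sincos_pos_embed_from_grid(8, pos, output_type="pt")
emb_flipped = get_1d_sincos_pos_embed_from_grid(8, pos, output_type="pt", flip_sin_to_cos=True)

# The cosine half has simply been moved to the front.
assert torch.equal(emb_flipped[:, :4], emb[:, 4:])
assert torch.equal(emb_flipped[:, 4:], emb[:, :4])
```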
@@ -1149,9 +1154,7 @@ def get_1d_rotary_pos_embed(

    theta = theta * ntk_factor
    freqs = (
-        1.0
-        / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device)[: (dim // 2)] / dim))
-        / linear_factor
+        1.0 / (theta ** (torch.arange(0, dim, 2, dtype=freqs_dtype, device=pos.device) / dim)) / linear_factor
    )  # [D/2]
    freqs = torch.outer(pos, freqs)  # type: ignore  # [S, D/2]
    is_npu = freqs.device.type == "npu"
@@ -1178,6 +1181,7 @@ def apply_rotary_emb(
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    use_real: bool = True,
    use_real_unbind_dim: int = -1,
+    sequence_dim: int = 2,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
    """
    Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
@@ -1195,17 +1199,24 @@ def apply_rotary_emb(
    """
    if use_real:
        cos, sin = freqs_cis  # [S, D]
-
-
+        if sequence_dim == 2:
+            cos = cos[None, None, :, :]
+            sin = sin[None, None, :, :]
+        elif sequence_dim == 1:
+            cos = cos[None, :, None, :]
+            sin = sin[None, :, None, :]
+        else:
+            raise ValueError(f"`sequence_dim={sequence_dim}` but should be 1 or 2.")
+
        cos, sin = cos.to(x.device), sin.to(x.device)

        if use_real_unbind_dim == -1:
            # Used for flux, cogvideox, hunyuan-dit
-            x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B,
+            x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, H, S, D//2]
            x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        elif use_real_unbind_dim == -2:
-            # Used for Stable Audio, OmniGen and
-            x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B,
+            # Used for Stable Audio, OmniGen, CogView4 and Cosmos
+            x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, H, S, D//2]
            x_rotated = torch.cat([-x_imag, x_real], dim=-1)
        else:
            raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
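A hedged sketch of what the new `sequence_dim` argument buys callers of `apply_rotary_emb` (assuming diffusers >= 0.35): with the default `sequence_dim=2` the cos/sin tables broadcast against a `[batch, heads, seq, dim]` query, while `sequence_dim=1` handles a `[batch, seq, heads, dim]` layout without an explicit permute. Both layouts produce the same rotated values.

```python
import torch
from diffusers.models.embeddings import apply_rotary_emb, get_1d_rotary_pos_embed

B, H, S, D = 1, 2, 6, 8
freqs_cis = get_1d_rotary_pos_embed(D, S, use_real=True, repeat_interleave_real=True)

q = torch.randn(B, H, S, D)
out_bhsd = apply_rotary_emb(q, freqs_cis)                                      # sequence_dim=2 (default)
out_bshd = apply_rotary_emb(q.permute(0, 2, 1, 3), freqs_cis, sequence_dim=1)  # [B, S, H, D] layout

assert torch.allclose(out_bhsd, out_bshd.permute(0, 2, 1, 3))
```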
@@ -1240,37 +1251,6 @@ def apply_rotary_emb_allegro(x: torch.Tensor, freqs_cis, positions):
    return x


-class FluxPosEmbed(nn.Module):
-    # modified from https://github.com/black-forest-labs/flux/blob/c00d7c60b085fce8058b9df845e036090873f2ce/src/flux/modules/layers.py#L11
-    def __init__(self, theta: int, axes_dim: List[int]):
-        super().__init__()
-        self.theta = theta
-        self.axes_dim = axes_dim
-
-    def forward(self, ids: torch.Tensor) -> torch.Tensor:
-        n_axes = ids.shape[-1]
-        cos_out = []
-        sin_out = []
-        pos = ids.float()
-        is_mps = ids.device.type == "mps"
-        is_npu = ids.device.type == "npu"
-        freqs_dtype = torch.float32 if (is_mps or is_npu) else torch.float64
-        for i in range(n_axes):
-            cos, sin = get_1d_rotary_pos_embed(
-                self.axes_dim[i],
-                pos[:, i],
-                theta=self.theta,
-                repeat_interleave_real=True,
-                use_real=True,
-                freqs_dtype=freqs_dtype,
-            )
-            cos_out.append(cos)
-            sin_out.append(sin)
-        freqs_cos = torch.cat(cos_out, dim=-1).to(ids.device)
-        freqs_sin = torch.cat(sin_out, dim=-1).to(ids.device)
-        return freqs_cos, freqs_sin
-
-
 class TimestepEmbedding(nn.Module):
    def __init__(
        self,
@@ -1327,7 +1307,7 @@ class Timesteps(nn.Module):
        self.downscale_freq_shift = downscale_freq_shift
        self.scale = scale

-    def forward(self, timesteps):
+    def forward(self, timesteps: torch.Tensor) -> torch.Tensor:
        t_emb = get_timestep_embedding(
            timesteps,
            self.num_channels,
@@ -1401,7 +1381,7 @@ class ImagePositionalEmbeddings(nn.Module):
    Converts latent image classes into vector embeddings. Sums the vector embeddings with positional embeddings for the
    height and width of the latent space.

-    For more details, see figure 10 of the dall-e paper: https://
+    For more details, see figure 10 of the dall-e paper: https://huggingface.co/papers/2102.12092

    For VQ-diffusion:

@@ -2621,3 +2601,13 @@ class MultiIPAdapterImageProjection(nn.Module):
            projected_image_embeds.append(image_embed)

        return projected_image_embeds
+
+
+class FluxPosEmbed(nn.Module):
+    def __new__(cls, *args, **kwargs):
+        deprecation_message = "Importing and using `FluxPosEmbed` from `diffusers.models.embeddings` is deprecated. Please import it from `diffusers.models.transformers.transformer_flux`."
+        deprecate("FluxPosEmbed", "1.0.0", deprecation_message)
+
+        from .transformers.transformer_flux import FluxPosEmbed
+
+        return FluxPosEmbed(*args, **kwargs)
diffusers/models/embeddings_flax.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -89,7 +89,7 @@ class FlaxTimestepEmbedding(nn.Module):

 class FlaxTimesteps(nn.Module):
    r"""
-    Wrapper Module for sinusoidal Time step Embeddings as described in https://
+    Wrapper Module for sinusoidal Time step Embeddings as described in https://huggingface.co/papers/2006.11239

    Args:
        dim (`int`, *optional*, defaults to `32`):
diffusers/models/lora.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -38,7 +38,7 @@ if is_transformers_available():
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


-def text_encoder_attn_modules(text_encoder):
+def text_encoder_attn_modules(text_encoder: nn.Module):
    attn_modules = []

    if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
@@ -52,7 +52,7 @@ def text_encoder_attn_modules(text_encoder):
    return attn_modules


-def text_encoder_mlp_modules(text_encoder):
+def text_encoder_mlp_modules(text_encoder: nn.Module):
    mlp_modules = []

    if isinstance(text_encoder, (CLIPTextModel, CLIPTextModelWithProjection)):
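The annotation change above does not alter behaviour; for context, a small hedged sketch of what `text_encoder_attn_modules` returns for a CLIP text encoder (a tiny random config is used so nothing has to be downloaded, and the printed module name is an expected example rather than something this diff states):

```python
from transformers import CLIPTextConfig, CLIPTextModel
from diffusers.models.lora import text_encoder_attn_modules

config = CLIPTextConfig(hidden_size=32, intermediate_size=64, num_attention_heads=4, num_hidden_layers=2)
text_encoder = CLIPTextModel(config)

modules = text_encoder_attn_modules(text_encoder)
print(len(modules))    # 2 -- one (name, self-attention module) pair per transformer layer
print(modules[0][0])   # e.g. "text_model.encoder.layers.0.self_attn"
```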
diffusers/models/model_loading_utils.py
CHANGED
@@ -14,11 +14,13 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+import functools
 import importlib
 import inspect
 import os
 from array import array
-from collections import OrderedDict
+from collections import OrderedDict, defaultdict
+from concurrent.futures import ThreadPoolExecutor, as_completed
 from pathlib import Path
 from typing import Dict, List, Optional, Union
 from zipfile import is_zipfile
@@ -30,6 +32,7 @@ from huggingface_hub.utils import EntryNotFoundError

 from ..quantizers import DiffusersQuantizer
 from ..utils import (
+    DEFAULT_HF_PARALLEL_LOADING_WORKERS,
     GGUF_FILE_EXTENSION,
     SAFE_WEIGHTS_INDEX_NAME,
     SAFETENSORS_FILE_EXTENSION,
@@ -38,6 +41,7 @@ from ..utils import (
     _get_model_file,
     deprecate,
     is_accelerate_available,
+    is_accelerate_version,
     is_gguf_available,
     is_torch_available,
     is_torch_version,
@@ -252,6 +256,10 @@ def load_model_dict_into_meta(
            param = param.to(dtype)
            set_module_kwargs["dtype"] = dtype

+        if is_accelerate_version(">", "1.8.1"):
+            set_module_kwargs["non_blocking"] = True
+            set_module_kwargs["clear_cache"] = False
+
        # For compatibility with PyTorch load_state_dict which converts state dict dtype to existing dtype in model, and which
        # uses `param.copy_(input_param)` that preserves the contiguity of the parameter in the model.
        # Reference: https://github.com/pytorch/pytorch/blob/db79ceb110f6646523019a59bbd7b838f43d4a86/torch/nn/modules/module.py#L2040C29-L2040C29
@@ -304,6 +312,161 @@ def load_model_dict_into_meta(
|
|
304
312
|
return offload_index, state_dict_index
|
305
313
|
|
306
314
|
|
315
|
+
def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""):
|
316
|
+
"""
|
317
|
+
Checks if `model_to_load` supports param buffer assignment (such as when loading in empty weights) by first
|
318
|
+
checking if the model explicitly disables it, then by ensuring that the state dict keys are a subset of the model's
|
319
|
+
parameters.
|
320
|
+
|
321
|
+
"""
|
322
|
+
if model_to_load.device.type == "meta":
|
323
|
+
return False
|
324
|
+
|
325
|
+
if len([key for key in state_dict if key.startswith(start_prefix)]) == 0:
|
326
|
+
return False
|
327
|
+
|
328
|
+
# Some models explicitly do not support param buffer assignment
|
329
|
+
if not getattr(model_to_load, "_supports_param_buffer_assignment", True):
|
330
|
+
logger.debug(
|
331
|
+
f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower"
|
332
|
+
)
|
333
|
+
return False
|
334
|
+
|
335
|
+
# If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype
|
336
|
+
first_key = next(iter(model_to_load.state_dict().keys()))
|
337
|
+
if start_prefix + first_key in state_dict:
|
338
|
+
return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype
|
339
|
+
|
340
|
+
return False
|
341
|
+
|
342
|
+
|
```diff
+def _load_shard_file(
+    shard_file,
+    model,
+    model_state_dict,
+    device_map=None,
+    dtype=None,
+    hf_quantizer=None,
+    keep_in_fp32_modules=None,
+    dduf_entries=None,
+    loaded_keys=None,
+    unexpected_keys=None,
+    offload_index=None,
+    offload_folder=None,
+    state_dict_index=None,
+    state_dict_folder=None,
+    ignore_mismatched_sizes=False,
+    low_cpu_mem_usage=False,
+):
+    state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
+    mismatched_keys = _find_mismatched_keys(
+        state_dict,
+        model_state_dict,
+        loaded_keys,
+        ignore_mismatched_sizes,
+    )
+    error_msgs = []
+    if low_cpu_mem_usage:
+        offload_index, state_dict_index = load_model_dict_into_meta(
+            model,
+            state_dict,
+            device_map=device_map,
+            dtype=dtype,
+            hf_quantizer=hf_quantizer,
+            keep_in_fp32_modules=keep_in_fp32_modules,
+            unexpected_keys=unexpected_keys,
+            offload_folder=offload_folder,
+            offload_index=offload_index,
+            state_dict_index=state_dict_index,
+            state_dict_folder=state_dict_folder,
+        )
+    else:
+        assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)
+
+        error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
+    return offload_index, state_dict_index, mismatched_keys, error_msgs
+
+
```
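`_load_shard_file` wraps the per-shard work: read one shard, filter size-mismatched keys, then load it either onto meta/offloaded devices or directly into the model. A self-contained sketch of the per-shard idea using `safetensors` directly rather than diffusers' internal helpers (the shard files below are created by the snippet itself, purely for illustration):

```python
# Sequential per-shard loading: each shard covers only part of the model,
# so it is read and copied in on its own.
import torch
from safetensors.torch import load_file, save_file
from torch import nn

model = nn.Sequential(nn.Linear(4, 4), nn.Linear(4, 4))

# Write two toy shards, each holding a subset of the model's parameters.
full = model.state_dict()
shard_a = {k: v.contiguous() for k, v in full.items() if k.startswith("0.")}
shard_b = {k: v.contiguous() for k, v in full.items() if k.startswith("1.")}
save_file(shard_a, "shard-00001.safetensors")
save_file(shard_b, "shard-00002.safetensors")

# Per-shard loading: read one shard at a time and copy it into the model.
for shard_file in ["shard-00001.safetensors", "shard-00002.safetensors"]:
    state_dict = load_file(shard_file)
    # strict=False because each shard only covers a subset of the parameters.
    model.load_state_dict(state_dict, strict=False)
```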
```diff
+def _load_shard_files_with_threadpool(
+    shard_files,
+    model,
+    model_state_dict,
+    device_map=None,
+    dtype=None,
+    hf_quantizer=None,
+    keep_in_fp32_modules=None,
+    dduf_entries=None,
+    loaded_keys=None,
+    unexpected_keys=None,
+    offload_index=None,
+    offload_folder=None,
+    state_dict_index=None,
+    state_dict_folder=None,
+    ignore_mismatched_sizes=False,
+    low_cpu_mem_usage=False,
+):
+    # Do not spawn anymore workers than you need
+    num_workers = min(len(shard_files), DEFAULT_HF_PARALLEL_LOADING_WORKERS)
+
+    logger.info(f"Loading model weights in parallel with {num_workers} workers...")
+
+    error_msgs = []
+    mismatched_keys = []
+
+    load_one = functools.partial(
+        _load_shard_file,
+        model=model,
+        model_state_dict=model_state_dict,
+        device_map=device_map,
+        dtype=dtype,
+        hf_quantizer=hf_quantizer,
+        keep_in_fp32_modules=keep_in_fp32_modules,
+        dduf_entries=dduf_entries,
+        loaded_keys=loaded_keys,
+        unexpected_keys=unexpected_keys,
+        offload_index=offload_index,
+        offload_folder=offload_folder,
+        state_dict_index=state_dict_index,
+        state_dict_folder=state_dict_folder,
+        ignore_mismatched_sizes=ignore_mismatched_sizes,
+        low_cpu_mem_usage=low_cpu_mem_usage,
+    )
+
+    with ThreadPoolExecutor(max_workers=num_workers) as executor:
+        with logging.tqdm(total=len(shard_files), desc="Loading checkpoint shards") as pbar:
+            futures = [executor.submit(load_one, shard_file) for shard_file in shard_files]
+            for future in as_completed(futures):
+                result = future.result()
+                offload_index, state_dict_index, _mismatched_keys, _error_msgs = result
+                error_msgs += _error_msgs
+                mismatched_keys += _mismatched_keys
+                pbar.update(1)
+
+    return offload_index, state_dict_index, mismatched_keys, error_msgs
+
+
```
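`_load_shard_files_with_threadpool` fans the per-shard work out over a thread pool and merges the results as each future completes. The sketch below shows the same `ThreadPoolExecutor`/`as_completed` pattern with only the standard library; `load_shard`, `MAX_WORKERS`, and the shard names are stand-ins for `_load_shard_file`, `DEFAULT_HF_PARALLEL_LOADING_WORKERS`, and real checkpoint files:

```python
# Standalone sketch of parallel shard loading with a bounded thread pool.
import functools
from concurrent.futures import ThreadPoolExecutor, as_completed

MAX_WORKERS = 8  # stand-in for the library's default worker count


def load_shard(shard_file, strict=True):
    # In the real code this reads the shard and copies tensors into the model.
    return f"loaded {shard_file}", []


def load_all_shards(shard_files):
    # Do not spawn more workers than there are shards.
    num_workers = min(len(shard_files), MAX_WORKERS)
    load_one = functools.partial(load_shard, strict=True)
    results, errors = [], []
    with ThreadPoolExecutor(max_workers=num_workers) as executor:
        futures = [executor.submit(load_one, f) for f in shard_files]
        for future in as_completed(futures):  # collect shards as they finish
            result, shard_errors = future.result()
            results.append(result)
            errors.extend(shard_errors)
    return results, errors


print(load_all_shards(["shard-00001.safetensors", "shard-00002.safetensors"]))
```

Threads (rather than processes) are a reasonable fit here because shard loading is dominated by file I/O and tensor copies, which spend little time holding the GIL.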
```diff
+def _find_mismatched_keys(
+    state_dict,
+    model_state_dict,
+    loaded_keys,
+    ignore_mismatched_sizes,
+):
+    mismatched_keys = []
+    if ignore_mismatched_sizes:
+        for checkpoint_key in loaded_keys:
+            model_key = checkpoint_key
+            # If the checkpoint is sharded, we may not have the key here.
+            if checkpoint_key not in state_dict:
+                continue
+
+            if model_key in model_state_dict and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape:
+                mismatched_keys.append(
+                    (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
+                )
+                del state_dict[checkpoint_key]
+    return mismatched_keys
+
+
 def _load_state_dict_into_model(
     model_to_load, state_dict: OrderedDict, assign_to_params_buffers: bool = False
 ) -> List[str]:
```
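This `_find_mismatched_keys` (a second copy appears in a later hunk) implements a simple filter: when `ignore_mismatched_sizes` is requested, keys whose tensor shapes disagree with the model are recorded and dropped from the incoming state dict so loading can proceed. A toy illustration (not diffusers code):

```python
# Shape-mismatch filtering: record offending keys, then remove them so that
# the remaining, compatible tensors can still be loaded.
import torch
from torch import nn

model = nn.Linear(4, 4)
model_state_dict = model.state_dict()

checkpoint = {
    "weight": torch.zeros(8, 4),  # wrong shape on purpose
    "bias": torch.zeros(4),       # matches the model
}

mismatched = []
for key in list(checkpoint):
    if key in model_state_dict and checkpoint[key].shape != model_state_dict[key].shape:
        mismatched.append((key, checkpoint[key].shape, model_state_dict[key].shape))
        del checkpoint[key]

print(mismatched)          # [('weight', torch.Size([8, 4]), torch.Size([4, 4]))]
print(sorted(checkpoint))  # ['bias'] -> only compatible tensors remain
```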
```diff
@@ -520,3 +683,72 @@ def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
         parsed_parameters[name] = GGUFParameter(weights, quant_type=quant_type) if is_gguf_quant else weights

     return parsed_parameters
+
+
+def _find_mismatched_keys(state_dict, model_state_dict, loaded_keys, ignore_mismatched_sizes):
+    mismatched_keys = []
+    if not ignore_mismatched_sizes:
+        return mismatched_keys
+    for checkpoint_key in loaded_keys:
+        model_key = checkpoint_key
+        # If the checkpoint is sharded, we may not have the key here.
+        if checkpoint_key not in state_dict:
+            continue
+
+        if model_key in model_state_dict and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape:
+            mismatched_keys.append(
+                (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
+            )
+            del state_dict[checkpoint_key]
+    return mismatched_keys
+
+
+def _expand_device_map(device_map, param_names):
+    """
+    Expand a device map to return the correspondence parameter name to device.
+    """
+    new_device_map = {}
+    for module, device in device_map.items():
+        new_device_map.update(
+            {p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
+        )
+    return new_device_map
+
+
```
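`_expand_device_map` turns a module-level device map into a per-parameter map, which the caching-allocator warm-up added below consumes. A standalone usage sketch with made-up parameter names, mirroring the matching rule above rather than calling the private diffusers helper:

```python
# Expand {"module prefix": device} into {"parameter name": device}.
def expand_device_map(device_map, param_names):
    new_device_map = {}
    for module, device in device_map.items():
        new_device_map.update(
            {p: device for p in param_names if p == module or p.startswith(f"{module}.") or module == ""}
        )
    return new_device_map


param_names = ["encoder.weight", "encoder.bias", "decoder.weight"]

# A module-level map: everything under "encoder" on GPU 0, the rest on CPU.
print(expand_device_map({"encoder": 0, "decoder": "cpu"}, param_names))
# {'encoder.weight': 0, 'encoder.bias': 0, 'decoder.weight': 'cpu'}

# The empty-string key matches every parameter.
print(expand_device_map({"": "cuda:0"}, param_names))
# {'encoder.weight': 'cuda:0', 'encoder.bias': 'cuda:0', 'decoder.weight': 'cuda:0'}
```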
```diff
+# Adapted from: https://github.com/huggingface/transformers/blob/0687d481e2c71544501ef9cb3eef795a6e79b1de/src/transformers/modeling_utils.py#L5859
+def _caching_allocator_warmup(
+    model, expanded_device_map: Dict[str, torch.device], dtype: torch.dtype, hf_quantizer: Optional[DiffusersQuantizer]
+) -> None:
+    """
+    This function warm-ups the caching allocator based on the size of the model tensors that will reside on each
+    device. It allows to have one large call to Malloc, instead of recursively calling it later when loading the model,
+    which is actually the loading speed bottleneck. Calling this function allows to cut the model loading time by a
+    very large margin.
+    """
+    factor = 2 if hf_quantizer is None else hf_quantizer.get_cuda_warm_up_factor()
+
+    # Keep only accelerator devices
+    accelerator_device_map = {
+        param: torch.device(device)
+        for param, device in expanded_device_map.items()
+        if str(device) not in ["cpu", "disk"]
+    }
+    if not accelerator_device_map:
+        return
+
+    elements_per_device = defaultdict(int)
+    for param_name, device in accelerator_device_map.items():
+        try:
+            p = model.get_parameter(param_name)
+        except AttributeError:
+            try:
+                p = model.get_buffer(param_name)
+            except AttributeError:
+                raise AttributeError(f"Parameter or buffer with name={param_name} not found in model")
+        # TODO: account for TP when needed.
+        elements_per_device[device] += p.numel()
+
+    # This will kick off the caching allocator to avoid having to Malloc afterwards
+    for device, elem_count in elements_per_device.items():
+        warmup_elems = max(1, elem_count // factor)
+        _ = torch.empty(warmup_elems, dtype=dtype, device=device, requires_grad=False)
```
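`_caching_allocator_warmup` pre-allocates one large block per accelerator device so that the many small tensor allocations made while loading are served from PyTorch's caching allocator instead of repeated fresh device mallocs. A rough standalone sketch of the idea, assuming a single target device (it falls back to CPU when no GPU is available, and the halving mirrors the default `factor = 2` above):

```python
# Warm up the allocator with one big allocation sized from the model's
# parameter count; the tensor is dropped right away, but the caching
# allocator keeps the block for reuse during the actual weight loading.
from collections import defaultdict

import torch
from torch import nn

model = nn.Sequential(nn.Linear(256, 256), nn.Linear(256, 256))
device = "cuda:0" if torch.cuda.is_available() else "cpu"
dtype = torch.float16

# Count how many elements will end up on each device (here: all on one device).
elements_per_device = defaultdict(int)
for _, param in model.named_parameters():
    elements_per_device[device] += param.numel()

# One large allocation per device instead of many small ones later.
for dev, elem_count in elements_per_device.items():
    _ = torch.empty(max(1, elem_count // 2), dtype=dtype, device=dev, requires_grad=False)
```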
```diff
@@ -369,8 +369,7 @@ class FlaxModelMixin(PushToHubMixin):
             raise EnvironmentError(
                 f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
                 "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
-                "token having permission to this repo with `token` or log in with `huggingface-cli "
-                "login`."
+                "token having permission to this repo with `token` or log in with `hf auth login`."
             )
         except RevisionNotFoundError:
             raise EnvironmentError(
```