diffusers 0.33.1__py3-none-any.whl → 0.35.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +145 -1
- diffusers/callbacks.py +35 -0
- diffusers/commands/__init__.py +1 -1
- diffusers/commands/custom_blocks.py +134 -0
- diffusers/commands/diffusers_cli.py +3 -1
- diffusers/commands/env.py +1 -1
- diffusers/commands/fp16_safetensors.py +2 -2
- diffusers/configuration_utils.py +11 -2
- diffusers/dependency_versions_check.py +1 -1
- diffusers/dependency_versions_table.py +3 -3
- diffusers/experimental/rl/value_guided_sampling.py +1 -1
- diffusers/guiders/__init__.py +41 -0
- diffusers/guiders/adaptive_projected_guidance.py +188 -0
- diffusers/guiders/auto_guidance.py +190 -0
- diffusers/guiders/classifier_free_guidance.py +141 -0
- diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
- diffusers/guiders/frequency_decoupled_guidance.py +327 -0
- diffusers/guiders/guider_utils.py +309 -0
- diffusers/guiders/perturbed_attention_guidance.py +271 -0
- diffusers/guiders/skip_layer_guidance.py +262 -0
- diffusers/guiders/smoothed_energy_guidance.py +251 -0
- diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
- diffusers/hooks/__init__.py +17 -0
- diffusers/hooks/_common.py +56 -0
- diffusers/hooks/_helpers.py +293 -0
- diffusers/hooks/faster_cache.py +9 -8
- diffusers/hooks/first_block_cache.py +259 -0
- diffusers/hooks/group_offloading.py +332 -227
- diffusers/hooks/hooks.py +58 -3
- diffusers/hooks/layer_skip.py +263 -0
- diffusers/hooks/layerwise_casting.py +5 -10
- diffusers/hooks/pyramid_attention_broadcast.py +15 -12
- diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
- diffusers/hooks/utils.py +43 -0
- diffusers/image_processor.py +7 -2
- diffusers/loaders/__init__.py +10 -0
- diffusers/loaders/ip_adapter.py +260 -18
- diffusers/loaders/lora_base.py +261 -127
- diffusers/loaders/lora_conversion_utils.py +657 -35
- diffusers/loaders/lora_pipeline.py +2778 -1246
- diffusers/loaders/peft.py +78 -112
- diffusers/loaders/single_file.py +2 -2
- diffusers/loaders/single_file_model.py +64 -15
- diffusers/loaders/single_file_utils.py +395 -7
- diffusers/loaders/textual_inversion.py +3 -2
- diffusers/loaders/transformer_flux.py +10 -11
- diffusers/loaders/transformer_sd3.py +8 -3
- diffusers/loaders/unet.py +24 -21
- diffusers/loaders/unet_loader_utils.py +6 -3
- diffusers/loaders/utils.py +1 -1
- diffusers/models/__init__.py +23 -1
- diffusers/models/activations.py +5 -5
- diffusers/models/adapter.py +2 -3
- diffusers/models/attention.py +488 -7
- diffusers/models/attention_dispatch.py +1218 -0
- diffusers/models/attention_flax.py +10 -10
- diffusers/models/attention_processor.py +113 -667
- diffusers/models/auto_model.py +49 -12
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
- diffusers/models/autoencoders/autoencoder_dc.py +17 -4
- diffusers/models/autoencoders/autoencoder_kl.py +5 -5
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
- diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1110 -0
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
- diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
- diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_wan.py +626 -62
- diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
- diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
- diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
- diffusers/models/autoencoders/vae.py +13 -2
- diffusers/models/autoencoders/vq_model.py +2 -2
- diffusers/models/cache_utils.py +32 -10
- diffusers/models/controlnet.py +1 -1
- diffusers/models/controlnet_flux.py +1 -1
- diffusers/models/controlnet_sd3.py +1 -1
- diffusers/models/controlnet_sparsectrl.py +1 -1
- diffusers/models/controlnets/__init__.py +1 -0
- diffusers/models/controlnets/controlnet.py +3 -3
- diffusers/models/controlnets/controlnet_flax.py +1 -1
- diffusers/models/controlnets/controlnet_flux.py +21 -20
- diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
- diffusers/models/controlnets/controlnet_sana.py +290 -0
- diffusers/models/controlnets/controlnet_sd3.py +1 -1
- diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
- diffusers/models/controlnets/controlnet_union.py +5 -5
- diffusers/models/controlnets/controlnet_xs.py +7 -7
- diffusers/models/controlnets/multicontrolnet.py +4 -5
- diffusers/models/controlnets/multicontrolnet_union.py +5 -6
- diffusers/models/downsampling.py +2 -2
- diffusers/models/embeddings.py +36 -46
- diffusers/models/embeddings_flax.py +2 -2
- diffusers/models/lora.py +3 -3
- diffusers/models/model_loading_utils.py +233 -1
- diffusers/models/modeling_flax_utils.py +1 -2
- diffusers/models/modeling_utils.py +203 -108
- diffusers/models/normalization.py +4 -4
- diffusers/models/resnet.py +2 -2
- diffusers/models/resnet_flax.py +1 -1
- diffusers/models/transformers/__init__.py +7 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
- diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
- diffusers/models/transformers/consisid_transformer_3d.py +1 -1
- diffusers/models/transformers/dit_transformer_2d.py +2 -2
- diffusers/models/transformers/dual_transformer_2d.py +1 -1
- diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
- diffusers/models/transformers/latte_transformer_3d.py +4 -5
- diffusers/models/transformers/lumina_nextdit2d.py +2 -2
- diffusers/models/transformers/pixart_transformer_2d.py +3 -3
- diffusers/models/transformers/prior_transformer.py +1 -1
- diffusers/models/transformers/sana_transformer.py +8 -3
- diffusers/models/transformers/stable_audio_transformer.py +5 -9
- diffusers/models/transformers/t5_film_transformer.py +3 -3
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +1 -1
- diffusers/models/transformers/transformer_chroma.py +641 -0
- diffusers/models/transformers/transformer_cogview3plus.py +5 -10
- diffusers/models/transformers/transformer_cogview4.py +353 -27
- diffusers/models/transformers/transformer_cosmos.py +586 -0
- diffusers/models/transformers/transformer_flux.py +376 -138
- diffusers/models/transformers/transformer_hidream_image.py +942 -0
- diffusers/models/transformers/transformer_hunyuan_video.py +12 -8
- diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
- diffusers/models/transformers/transformer_ltx.py +105 -24
- diffusers/models/transformers/transformer_lumina2.py +1 -1
- diffusers/models/transformers/transformer_mochi.py +1 -1
- diffusers/models/transformers/transformer_omnigen.py +2 -2
- diffusers/models/transformers/transformer_qwenimage.py +645 -0
- diffusers/models/transformers/transformer_sd3.py +7 -7
- diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/transformers/transformer_wan.py +316 -87
- diffusers/models/transformers/transformer_wan_vace.py +387 -0
- diffusers/models/unets/unet_1d.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +1 -1
- diffusers/models/unets/unet_2d_blocks.py +1 -1
- diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
- diffusers/models/unets/unet_2d_condition.py +4 -3
- diffusers/models/unets/unet_2d_condition_flax.py +2 -2
- diffusers/models/unets/unet_3d_blocks.py +1 -1
- diffusers/models/unets/unet_3d_condition.py +3 -3
- diffusers/models/unets/unet_i2vgen_xl.py +3 -3
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +2 -2
- diffusers/models/unets/unet_stable_cascade.py +1 -1
- diffusers/models/upsampling.py +2 -2
- diffusers/models/vae_flax.py +2 -2
- diffusers/models/vq_model.py +1 -1
- diffusers/modular_pipelines/__init__.py +83 -0
- diffusers/modular_pipelines/components_manager.py +1068 -0
- diffusers/modular_pipelines/flux/__init__.py +66 -0
- diffusers/modular_pipelines/flux/before_denoise.py +689 -0
- diffusers/modular_pipelines/flux/decoders.py +109 -0
- diffusers/modular_pipelines/flux/denoise.py +227 -0
- diffusers/modular_pipelines/flux/encoders.py +412 -0
- diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
- diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
- diffusers/modular_pipelines/modular_pipeline.py +2446 -0
- diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
- diffusers/modular_pipelines/node_utils.py +665 -0
- diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
- diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
- diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
- diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
- diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
- diffusers/modular_pipelines/wan/__init__.py +66 -0
- diffusers/modular_pipelines/wan/before_denoise.py +365 -0
- diffusers/modular_pipelines/wan/decoders.py +105 -0
- diffusers/modular_pipelines/wan/denoise.py +261 -0
- diffusers/modular_pipelines/wan/encoders.py +242 -0
- diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
- diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
- diffusers/pipelines/__init__.py +68 -6
- diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
- diffusers/pipelines/amused/pipeline_amused.py +7 -6
- diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
- diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
- diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +22 -13
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
- diffusers/pipelines/auto_pipeline.py +23 -20
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
- diffusers/pipelines/chroma/__init__.py +49 -0
- diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
- diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
- diffusers/pipelines/chroma/pipeline_output.py +21 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +17 -16
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +17 -16
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +18 -17
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +17 -16
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
- diffusers/pipelines/cogview4/pipeline_cogview4.py +23 -22
- diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
- diffusers/pipelines/consisid/consisid_utils.py +2 -2
- diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +11 -10
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +226 -107
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +12 -8
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +207 -105
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
- diffusers/pipelines/cosmos/__init__.py +54 -0
- diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
- diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
- diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
- diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
- diffusers/pipelines/cosmos/pipeline_output.py +40 -0
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
- diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
- diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
- diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
- diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
- diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
- diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
- diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +8 -8
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
- diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +4 -2
- diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
- diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
- diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
- diffusers/pipelines/flux/__init__.py +4 -0
- diffusers/pipelines/flux/modeling_flux.py +1 -1
- diffusers/pipelines/flux/pipeline_flux.py +37 -36
- diffusers/pipelines/flux/pipeline_flux_control.py +9 -9
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +31 -23
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +3 -2
- diffusers/pipelines/flux/pipeline_flux_fill.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_img2img.py +40 -7
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +12 -7
- diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
- diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +2 -2
- diffusers/pipelines/flux/pipeline_output.py +6 -4
- diffusers/pipelines/free_init_utils.py +2 -2
- diffusers/pipelines/free_noise_utils.py +3 -3
- diffusers/pipelines/hidream_image/__init__.py +47 -0
- diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
- diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
- diffusers/pipelines/hunyuan_video/__init__.py +2 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +26 -25
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
- diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
- diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
- diffusers/pipelines/kolors/text_encoder.py +3 -3
- diffusers/pipelines/kolors/tokenizer.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
- diffusers/pipelines/latte/pipeline_latte.py +12 -12
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
- diffusers/pipelines/ltx/__init__.py +4 -0
- diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +64 -18
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +117 -38
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +63 -18
- diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
- diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
- diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
- diffusers/pipelines/mochi/pipeline_mochi.py +15 -14
- diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
- diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
- diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
- diffusers/pipelines/onnx_utils.py +15 -2
- diffusers/pipelines/pag/pag_utils.py +2 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
- diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
- diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
- diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
- diffusers/pipelines/pia/pipeline_pia.py +8 -6
- diffusers/pipelines/pipeline_flax_utils.py +5 -6
- diffusers/pipelines/pipeline_loading_utils.py +113 -15
- diffusers/pipelines/pipeline_utils.py +127 -48
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +14 -12
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +31 -11
- diffusers/pipelines/qwenimage/__init__.py +55 -0
- diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
- diffusers/pipelines/sana/__init__.py +4 -0
- diffusers/pipelines/sana/pipeline_sana.py +23 -21
- diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
- diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
- diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
- diffusers/pipelines/shap_e/camera.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
- diffusers/pipelines/shap_e/renderer.py +3 -3
- diffusers/pipelines/skyreels_v2/__init__.py +59 -0
- diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
- diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
- diffusers/pipelines/stable_diffusion/__init__.py +0 -7
- diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +12 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +11 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +10 -10
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
- diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
- diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +13 -12
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
- diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
- diffusers/pipelines/unclip/text_proj.py +2 -2
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
- diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
- diffusers/pipelines/visualcloze/__init__.py +52 -0
- diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
- diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
- diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
- diffusers/pipelines/wan/__init__.py +2 -0
- diffusers/pipelines/wan/pipeline_wan.py +91 -30
- diffusers/pipelines/wan/pipeline_wan_i2v.py +145 -45
- diffusers/pipelines/wan/pipeline_wan_vace.py +975 -0
- diffusers/pipelines/wan/pipeline_wan_video2video.py +14 -16
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
- diffusers/quantizers/__init__.py +3 -1
- diffusers/quantizers/base.py +17 -1
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
- diffusers/quantizers/bitsandbytes/utils.py +10 -7
- diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
- diffusers/quantizers/gguf/utils.py +108 -16
- diffusers/quantizers/pipe_quant_config.py +202 -0
- diffusers/quantizers/quantization_config.py +18 -16
- diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +31 -1
- diffusers/schedulers/__init__.py +3 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
- diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
- diffusers/schedulers/scheduling_consistency_models.py +1 -1
- diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
- diffusers/schedulers/scheduling_ddim.py +8 -8
- diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
- diffusers/schedulers/scheduling_ddim_flax.py +6 -6
- diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
- diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
- diffusers/schedulers/scheduling_ddpm.py +9 -9
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
- diffusers/schedulers/scheduling_deis_multistep.py +16 -9
- diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +18 -12
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
- diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +19 -13
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
- diffusers/schedulers/scheduling_edm_euler.py +20 -11
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
- diffusers/schedulers/scheduling_euler_discrete.py +3 -3
- diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
- diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
- diffusers/schedulers/scheduling_heun_discrete.py +2 -2
- diffusers/schedulers/scheduling_ipndm.py +2 -2
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
- diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
- diffusers/schedulers/scheduling_lcm.py +3 -3
- diffusers/schedulers/scheduling_lms_discrete.py +2 -2
- diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
- diffusers/schedulers/scheduling_pndm.py +4 -4
- diffusers/schedulers/scheduling_pndm_flax.py +4 -4
- diffusers/schedulers/scheduling_repaint.py +9 -9
- diffusers/schedulers/scheduling_sasolver.py +15 -15
- diffusers/schedulers/scheduling_scm.py +1 -2
- diffusers/schedulers/scheduling_sde_ve.py +1 -1
- diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
- diffusers/schedulers/scheduling_tcd.py +3 -3
- diffusers/schedulers/scheduling_unclip.py +5 -5
- diffusers/schedulers/scheduling_unipc_multistep.py +21 -12
- diffusers/schedulers/scheduling_utils.py +3 -3
- diffusers/schedulers/scheduling_utils_flax.py +2 -2
- diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
- diffusers/training_utils.py +91 -5
- diffusers/utils/__init__.py +15 -0
- diffusers/utils/accelerate_utils.py +1 -1
- diffusers/utils/constants.py +4 -0
- diffusers/utils/doc_utils.py +1 -1
- diffusers/utils/dummy_pt_objects.py +432 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +480 -0
- diffusers/utils/dynamic_modules_utils.py +85 -8
- diffusers/utils/export_utils.py +1 -1
- diffusers/utils/hub_utils.py +33 -17
- diffusers/utils/import_utils.py +151 -18
- diffusers/utils/logging.py +1 -1
- diffusers/utils/outputs.py +2 -1
- diffusers/utils/peft_utils.py +96 -10
- diffusers/utils/state_dict_utils.py +20 -3
- diffusers/utils/testing_utils.py +195 -17
- diffusers/utils/torch_utils.py +43 -5
- diffusers/video_processor.py +2 -2
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/METADATA +72 -57
- diffusers-0.35.0.dist-info/RECORD +703 -0
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/WHEEL +1 -1
- diffusers-0.33.1.dist-info/RECORD +0 -608
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
@@ -46,6 +46,7 @@ from ..utils import (
|
|
46
46
|
)
|
47
47
|
from ..utils.constants import DIFFUSERS_REQUEST_TIMEOUT
|
48
48
|
from ..utils.hub_utils import _get_model_file
|
49
|
+
from ..utils.torch_utils import empty_device_cache
|
49
50
|
|
50
51
|
|
51
52
|
if is_transformers_available():
|
@@ -54,11 +55,12 @@ if is_transformers_available():
|
|
54
55
|
if is_accelerate_available():
|
55
56
|
from accelerate import init_empty_weights
|
56
57
|
|
57
|
-
from ..models.
|
58
|
+
from ..models.model_loading_utils import load_model_dict_into_meta
|
58
59
|
|
59
60
|
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
60
61
|
|
61
62
|
CHECKPOINT_KEY_NAMES = {
|
63
|
+
"v1": "model.diffusion_model.output_blocks.11.0.skip_connection.weight",
|
62
64
|
"v2": "model.diffusion_model.input_blocks.2.1.transformer_blocks.0.attn2.to_k.weight",
|
63
65
|
"xl_base": "conditioner.embedders.1.model.transformer.resblocks.9.mlp.c_proj.bias",
|
64
66
|
"xl_refiner": "conditioner.embedders.0.model.transformer.resblocks.9.mlp.c_proj.bias",
|
@@ -126,6 +128,18 @@ CHECKPOINT_KEY_NAMES = {
|
|
126
128
|
],
|
127
129
|
"wan": ["model.diffusion_model.head.modulation", "head.modulation"],
|
128
130
|
"wan_vae": "decoder.middle.0.residual.0.gamma",
|
131
|
+
"wan_vace": "vace_blocks.0.after_proj.bias",
|
132
|
+
"hidream": "double_stream_blocks.0.block.adaLN_modulation.1.bias",
|
133
|
+
"cosmos-1.0": [
|
134
|
+
"net.x_embedder.proj.1.weight",
|
135
|
+
"net.blocks.block1.blocks.0.block.attn.to_q.0.weight",
|
136
|
+
"net.extra_pos_embedder.pos_emb_h",
|
137
|
+
],
|
138
|
+
"cosmos-2.0": [
|
139
|
+
"net.x_embedder.proj.1.weight",
|
140
|
+
"net.blocks.0.self_attn.q_proj.weight",
|
141
|
+
"net.pos_embedder.dim_spatial_range",
|
142
|
+
],
|
129
143
|
}
|
130
144
|
|
131
145
|
DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
@@ -177,6 +191,8 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
|
177
191
|
"flux-schnell": {"pretrained_model_name_or_path": "black-forest-labs/FLUX.1-schnell"},
|
178
192
|
"ltx-video": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.0"},
|
179
193
|
"ltx-video-0.9.1": {"pretrained_model_name_or_path": "diffusers/LTX-Video-0.9.1"},
|
194
|
+
"ltx-video-0.9.5": {"pretrained_model_name_or_path": "Lightricks/LTX-Video-0.9.5"},
|
195
|
+
"ltx-video-0.9.7": {"pretrained_model_name_or_path": "Lightricks/LTX-Video-0.9.7-dev"},
|
180
196
|
"autoencoder-dc-f128c512": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f128c512-mix-1.0-diffusers"},
|
181
197
|
"autoencoder-dc-f64c128": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f64c128-mix-1.0-diffusers"},
|
182
198
|
"autoencoder-dc-f32c32": {"pretrained_model_name_or_path": "mit-han-lab/dc-ae-f32c32-mix-1.0-diffusers"},
|
@@ -189,6 +205,17 @@ DIFFUSERS_DEFAULT_PIPELINE_PATHS = {
|
|
189
205
|
"wan-t2v-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers"},
|
190
206
|
"wan-t2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-T2V-14B-Diffusers"},
|
191
207
|
"wan-i2v-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-I2V-14B-480P-Diffusers"},
|
208
|
+
"wan-vace-1.3B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-VACE-1.3B-diffusers"},
|
209
|
+
"wan-vace-14B": {"pretrained_model_name_or_path": "Wan-AI/Wan2.1-VACE-14B-diffusers"},
|
210
|
+
"hidream": {"pretrained_model_name_or_path": "HiDream-ai/HiDream-I1-Dev"},
|
211
|
+
"cosmos-1.0-t2w-7B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-7B-Text2World"},
|
212
|
+
"cosmos-1.0-t2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-14B-Text2World"},
|
213
|
+
"cosmos-1.0-v2w-7B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-7B-Video2World"},
|
214
|
+
"cosmos-1.0-v2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-1.0-Diffusion-14B-Video2World"},
|
215
|
+
"cosmos-2.0-t2i-2B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-2B-Text2Image"},
|
216
|
+
"cosmos-2.0-t2i-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-14B-Text2Image"},
|
217
|
+
"cosmos-2.0-v2w-2B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-2B-Video2World"},
|
218
|
+
"cosmos-2.0-v2w-14B": {"pretrained_model_name_or_path": "nvidia/Cosmos-Predict2-14B-Video2World"},
|
192
219
|
}
|
193
220
|
|
194
221
|
# Use to configure model sample size when original config is provided
|
@@ -404,13 +431,16 @@ def load_single_file_checkpoint(
|
|
404
431
|
local_files_only=None,
|
405
432
|
revision=None,
|
406
433
|
disable_mmap=False,
|
434
|
+
user_agent=None,
|
407
435
|
):
|
436
|
+
if user_agent is None:
|
437
|
+
user_agent = {"file_type": "single_file", "framework": "pytorch"}
|
438
|
+
|
408
439
|
if os.path.isfile(pretrained_model_link_or_path):
|
409
440
|
pretrained_model_link_or_path = pretrained_model_link_or_path
|
410
441
|
|
411
442
|
else:
|
412
443
|
repo_id, weights_name = _extract_repo_id_and_weights_name(pretrained_model_link_or_path)
|
413
|
-
user_agent = {"file_type": "single_file", "framework": "pytorch"}
|
414
444
|
pretrained_model_link_or_path = _get_model_file(
|
415
445
|
repo_id,
|
416
446
|
weights_name=weights_name,
|
@@ -638,7 +668,12 @@ def infer_diffusers_model_type(checkpoint):
|
|
638
668
|
model_type = "flux-schnell"
|
639
669
|
|
640
670
|
elif any(key in checkpoint for key in CHECKPOINT_KEY_NAMES["ltx-video"]):
|
641
|
-
|
671
|
+
has_vae = "vae.encoder.conv_in.conv.bias" in checkpoint
|
672
|
+
if any(key.endswith("transformer_blocks.47.scale_shift_table") for key in checkpoint):
|
673
|
+
model_type = "ltx-video-0.9.7"
|
674
|
+
elif has_vae and checkpoint["vae.encoder.conv_out.conv.weight"].shape[1] == 2048:
|
675
|
+
model_type = "ltx-video-0.9.5"
|
676
|
+
elif "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in checkpoint:
|
642
677
|
model_type = "ltx-video-0.9.1"
|
643
678
|
else:
|
644
679
|
model_type = "ltx-video"
|
@@ -686,15 +721,44 @@ def infer_diffusers_model_type(checkpoint):
|
|
686
721
|
else:
|
687
722
|
target_key = "patch_embedding.weight"
|
688
723
|
|
689
|
-
if
|
724
|
+
if CHECKPOINT_KEY_NAMES["wan_vace"] in checkpoint:
|
725
|
+
if checkpoint[target_key].shape[0] == 1536:
|
726
|
+
model_type = "wan-vace-1.3B"
|
727
|
+
elif checkpoint[target_key].shape[0] == 5120:
|
728
|
+
model_type = "wan-vace-14B"
|
729
|
+
|
730
|
+
elif checkpoint[target_key].shape[0] == 1536:
|
690
731
|
model_type = "wan-t2v-1.3B"
|
691
732
|
elif checkpoint[target_key].shape[0] == 5120 and checkpoint[target_key].shape[1] == 16:
|
692
733
|
model_type = "wan-t2v-14B"
|
693
734
|
else:
|
694
735
|
model_type = "wan-i2v-14B"
|
736
|
+
|
695
737
|
elif CHECKPOINT_KEY_NAMES["wan_vae"] in checkpoint:
|
696
738
|
# All Wan models use the same VAE so we can use the same default model repo to fetch the config
|
697
739
|
model_type = "wan-t2v-14B"
|
740
|
+
|
741
|
+
elif CHECKPOINT_KEY_NAMES["hidream"] in checkpoint:
|
742
|
+
model_type = "hidream"
|
743
|
+
|
744
|
+
elif all(key in checkpoint for key in CHECKPOINT_KEY_NAMES["cosmos-1.0"]):
|
745
|
+
x_embedder_shape = checkpoint[CHECKPOINT_KEY_NAMES["cosmos-1.0"][0]].shape
|
746
|
+
if x_embedder_shape[1] == 68:
|
747
|
+
model_type = "cosmos-1.0-t2w-7B" if x_embedder_shape[0] == 4096 else "cosmos-1.0-t2w-14B"
|
748
|
+
elif x_embedder_shape[1] == 72:
|
749
|
+
model_type = "cosmos-1.0-v2w-7B" if x_embedder_shape[0] == 4096 else "cosmos-1.0-v2w-14B"
|
750
|
+
else:
|
751
|
+
raise ValueError(f"Unexpected x_embedder shape: {x_embedder_shape} when loading Cosmos 1.0 model.")
|
752
|
+
|
753
|
+
elif all(key in checkpoint for key in CHECKPOINT_KEY_NAMES["cosmos-2.0"]):
|
754
|
+
x_embedder_shape = checkpoint[CHECKPOINT_KEY_NAMES["cosmos-2.0"][0]].shape
|
755
|
+
if x_embedder_shape[1] == 68:
|
756
|
+
model_type = "cosmos-2.0-t2i-2B" if x_embedder_shape[0] == 2048 else "cosmos-2.0-t2i-14B"
|
757
|
+
elif x_embedder_shape[1] == 72:
|
758
|
+
model_type = "cosmos-2.0-v2w-2B" if x_embedder_shape[0] == 2048 else "cosmos-2.0-v2w-14B"
|
759
|
+
else:
|
760
|
+
raise ValueError(f"Unexpected x_embedder shape: {x_embedder_shape} when loading Cosmos 2.0 model.")
|
761
|
+
|
698
762
|
else:
|
699
763
|
model_type = "v1"
|
700
764
|
|
@@ -1627,6 +1691,7 @@ def create_diffusers_clip_model_from_ldm(
|
|
1627
1691
|
|
1628
1692
|
if is_accelerate_available():
|
1629
1693
|
load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
|
1694
|
+
empty_device_cache()
|
1630
1695
|
else:
|
1631
1696
|
model.load_state_dict(diffusers_format_checkpoint, strict=False)
|
1632
1697
|
|
@@ -2086,6 +2151,7 @@ def create_diffusers_t5_model_from_checkpoint(
|
|
2086
2151
|
|
2087
2152
|
if is_accelerate_available():
|
2088
2153
|
load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
|
2154
|
+
empty_device_cache()
|
2089
2155
|
else:
|
2090
2156
|
model.load_state_dict(diffusers_format_checkpoint)
|
2091
2157
|
|
@@ -2272,7 +2338,7 @@ def convert_flux_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
|
2272
2338
|
f"double_blocks.{i}.txt_attn.proj.bias"
|
2273
2339
|
)
|
2274
2340
|
|
2275
|
-
# single
|
2341
|
+
# single transformer blocks
|
2276
2342
|
for i in range(num_single_layers):
|
2277
2343
|
block_prefix = f"single_transformer_blocks.{i}."
|
2278
2344
|
# norm.linear <- single_blocks.0.modulation.lin
|
@@ -2403,13 +2469,41 @@ def convert_ltx_vae_checkpoint_to_diffusers(checkpoint, **kwargs):
|
|
2403
2469
|
"last_scale_shift_table": "scale_shift_table",
|
2404
2470
|
}
|
2405
2471
|
|
2472
|
+
VAE_095_RENAME_DICT = {
|
2473
|
+
# decoder
|
2474
|
+
"up_blocks.0": "mid_block",
|
2475
|
+
"up_blocks.1": "up_blocks.0.upsamplers.0",
|
2476
|
+
"up_blocks.2": "up_blocks.0",
|
2477
|
+
"up_blocks.3": "up_blocks.1.upsamplers.0",
|
2478
|
+
"up_blocks.4": "up_blocks.1",
|
2479
|
+
"up_blocks.5": "up_blocks.2.upsamplers.0",
|
2480
|
+
"up_blocks.6": "up_blocks.2",
|
2481
|
+
"up_blocks.7": "up_blocks.3.upsamplers.0",
|
2482
|
+
"up_blocks.8": "up_blocks.3",
|
2483
|
+
# encoder
|
2484
|
+
"down_blocks.0": "down_blocks.0",
|
2485
|
+
"down_blocks.1": "down_blocks.0.downsamplers.0",
|
2486
|
+
"down_blocks.2": "down_blocks.1",
|
2487
|
+
"down_blocks.3": "down_blocks.1.downsamplers.0",
|
2488
|
+
"down_blocks.4": "down_blocks.2",
|
2489
|
+
"down_blocks.5": "down_blocks.2.downsamplers.0",
|
2490
|
+
"down_blocks.6": "down_blocks.3",
|
2491
|
+
"down_blocks.7": "down_blocks.3.downsamplers.0",
|
2492
|
+
"down_blocks.8": "mid_block",
|
2493
|
+
# common
|
2494
|
+
"last_time_embedder": "time_embedder",
|
2495
|
+
"last_scale_shift_table": "scale_shift_table",
|
2496
|
+
}
|
2497
|
+
|
2406
2498
|
VAE_SPECIAL_KEYS_REMAP = {
|
2407
2499
|
"per_channel_statistics.channel": remove_keys_,
|
2408
2500
|
"per_channel_statistics.mean-of-means": remove_keys_,
|
2409
2501
|
"per_channel_statistics.mean-of-stds": remove_keys_,
|
2410
2502
|
}
|
2411
2503
|
|
2412
|
-
if "vae.
|
2504
|
+
if converted_state_dict["vae.encoder.conv_out.conv.weight"].shape[1] == 2048:
|
2505
|
+
VAE_KEYS_RENAME_DICT.update(VAE_095_RENAME_DICT)
|
2506
|
+
elif "vae.decoder.last_time_embedder.timestep_embedder.linear_1.weight" in converted_state_dict:
|
2413
2507
|
VAE_KEYS_RENAME_DICT.update(VAE_091_RENAME_DICT)
|
2414
2508
|
|
2415
2509
|
for key in list(converted_state_dict.keys()):
|
@@ -2838,7 +2932,7 @@ def convert_auraflow_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
|
2838
2932
|
def convert_lumina2_to_diffusers(checkpoint, **kwargs):
|
2839
2933
|
converted_state_dict = {}
|
2840
2934
|
|
2841
|
-
# Original Lumina-Image-2 has an extra norm
|
2935
|
+
# Original Lumina-Image-2 has an extra norm parameter that is unused
|
2842
2936
|
# We just remove it here
|
2843
2937
|
checkpoint.pop("norm_final.weight", None)
|
2844
2938
|
|
@@ -3051,6 +3145,9 @@ def convert_wan_transformer_to_diffusers(checkpoint, **kwargs):
|
|
3051
3145
|
"img_emb.proj.1": "condition_embedder.image_embedder.ff.net.0.proj",
|
3052
3146
|
"img_emb.proj.3": "condition_embedder.image_embedder.ff.net.2",
|
3053
3147
|
"img_emb.proj.4": "condition_embedder.image_embedder.norm2",
|
3148
|
+
# For the VACE model
|
3149
|
+
"before_proj": "proj_in",
|
3150
|
+
"after_proj": "proj_out",
|
3054
3151
|
}
|
3055
3152
|
|
3056
3153
|
for key in list(checkpoint.keys()):
|
@@ -3259,3 +3356,294 @@ def convert_wan_vae_to_diffusers(checkpoint, **kwargs):
|
|
3259
3356
|
converted_state_dict[key] = value
|
3260
3357
|
|
3261
3358
|
return converted_state_dict
|
3359
|
+
|
3360
|
+
|
3361
|
+
def convert_hidream_transformer_to_diffusers(checkpoint, **kwargs):
|
3362
|
+
keys = list(checkpoint.keys())
|
3363
|
+
for k in keys:
|
3364
|
+
if "model.diffusion_model." in k:
|
3365
|
+
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
|
3366
|
+
|
3367
|
+
return checkpoint
|
3368
|
+
|
3369
|
+
|
3370
|
+
def convert_chroma_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
3371
|
+
converted_state_dict = {}
|
3372
|
+
keys = list(checkpoint.keys())
|
3373
|
+
|
3374
|
+
for k in keys:
|
3375
|
+
if "model.diffusion_model." in k:
|
3376
|
+
checkpoint[k.replace("model.diffusion_model.", "")] = checkpoint.pop(k)
|
3377
|
+
|
3378
|
+
num_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "double_blocks." in k))[-1] + 1 # noqa: C401
|
3379
|
+
num_single_layers = list(set(int(k.split(".", 2)[1]) for k in checkpoint if "single_blocks." in k))[-1] + 1 # noqa: C401
|
3380
|
+
num_guidance_layers = (
|
3381
|
+
list(set(int(k.split(".", 3)[2]) for k in checkpoint if "distilled_guidance_layer.layers." in k))[-1] + 1 # noqa: C401
|
3382
|
+
)
|
3383
|
+
mlp_ratio = 4.0
|
3384
|
+
inner_dim = 3072
|
3385
|
+
|
3386
|
+
# in SD3 original implementation of AdaLayerNormContinuous, it split linear projection output into shift, scale;
|
3387
|
+
# while in diffusers it split into scale, shift. Here we swap the linear projection weights in order to be able to use diffusers implementation
|
3388
|
+
def swap_scale_shift(weight):
|
3389
|
+
shift, scale = weight.chunk(2, dim=0)
|
3390
|
+
new_weight = torch.cat([scale, shift], dim=0)
|
3391
|
+
return new_weight
|
3392
|
+
|
3393
|
+
# guidance
|
3394
|
+
converted_state_dict["distilled_guidance_layer.in_proj.bias"] = checkpoint.pop(
|
3395
|
+
"distilled_guidance_layer.in_proj.bias"
|
3396
|
+
)
|
3397
|
+
converted_state_dict["distilled_guidance_layer.in_proj.weight"] = checkpoint.pop(
|
3398
|
+
"distilled_guidance_layer.in_proj.weight"
|
3399
|
+
)
|
3400
|
+
converted_state_dict["distilled_guidance_layer.out_proj.bias"] = checkpoint.pop(
|
3401
|
+
"distilled_guidance_layer.out_proj.bias"
|
3402
|
+
)
|
3403
|
+
converted_state_dict["distilled_guidance_layer.out_proj.weight"] = checkpoint.pop(
|
3404
|
+
"distilled_guidance_layer.out_proj.weight"
|
3405
|
+
)
|
3406
|
+
for i in range(num_guidance_layers):
|
3407
|
+
block_prefix = f"distilled_guidance_layer.layers.{i}."
|
3408
|
+
converted_state_dict[f"{block_prefix}linear_1.bias"] = checkpoint.pop(
|
3409
|
+
f"distilled_guidance_layer.layers.{i}.in_layer.bias"
|
3410
|
+
)
|
3411
|
+
converted_state_dict[f"{block_prefix}linear_1.weight"] = checkpoint.pop(
|
3412
|
+
f"distilled_guidance_layer.layers.{i}.in_layer.weight"
|
3413
|
+
)
|
3414
|
+
converted_state_dict[f"{block_prefix}linear_2.bias"] = checkpoint.pop(
|
3415
|
+
f"distilled_guidance_layer.layers.{i}.out_layer.bias"
|
3416
|
+
)
|
3417
|
+
converted_state_dict[f"{block_prefix}linear_2.weight"] = checkpoint.pop(
|
3418
|
+
f"distilled_guidance_layer.layers.{i}.out_layer.weight"
|
3419
|
+
)
|
3420
|
+
converted_state_dict[f"distilled_guidance_layer.norms.{i}.weight"] = checkpoint.pop(
|
3421
|
+
f"distilled_guidance_layer.norms.{i}.scale"
|
3422
|
+
)
|
3423
|
+
|
3424
|
+
# context_embedder
|
3425
|
+
converted_state_dict["context_embedder.weight"] = checkpoint.pop("txt_in.weight")
|
3426
|
+
converted_state_dict["context_embedder.bias"] = checkpoint.pop("txt_in.bias")
|
3427
|
+
|
3428
|
+
# x_embedder
|
3429
|
+
converted_state_dict["x_embedder.weight"] = checkpoint.pop("img_in.weight")
|
3430
|
+
converted_state_dict["x_embedder.bias"] = checkpoint.pop("img_in.bias")
|
3431
|
+
|
3432
|
+
# double transformer blocks
|
3433
|
+
for i in range(num_layers):
|
3434
|
+
block_prefix = f"transformer_blocks.{i}."
|
3435
|
+
# Q, K, V
|
3436
|
+
sample_q, sample_k, sample_v = torch.chunk(checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.weight"), 3, dim=0)
|
3437
|
+
context_q, context_k, context_v = torch.chunk(
|
3438
|
+
checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.weight"), 3, dim=0
|
3439
|
+
)
|
3440
|
+
sample_q_bias, sample_k_bias, sample_v_bias = torch.chunk(
|
3441
|
+
checkpoint.pop(f"double_blocks.{i}.img_attn.qkv.bias"), 3, dim=0
|
3442
|
+
)
|
3443
|
+
context_q_bias, context_k_bias, context_v_bias = torch.chunk(
|
3444
|
+
checkpoint.pop(f"double_blocks.{i}.txt_attn.qkv.bias"), 3, dim=0
|
3445
|
+
)
|
3446
|
+
converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([sample_q])
|
3447
|
+
converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([sample_q_bias])
|
3448
|
+
converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([sample_k])
|
3449
|
+
converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([sample_k_bias])
|
3450
|
+
converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([sample_v])
|
3451
|
+
converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([sample_v_bias])
|
3452
|
+
converted_state_dict[f"{block_prefix}attn.add_q_proj.weight"] = torch.cat([context_q])
|
3453
|
+
converted_state_dict[f"{block_prefix}attn.add_q_proj.bias"] = torch.cat([context_q_bias])
|
3454
|
+
converted_state_dict[f"{block_prefix}attn.add_k_proj.weight"] = torch.cat([context_k])
|
3455
|
+
converted_state_dict[f"{block_prefix}attn.add_k_proj.bias"] = torch.cat([context_k_bias])
|
3456
|
+
converted_state_dict[f"{block_prefix}attn.add_v_proj.weight"] = torch.cat([context_v])
|
3457
|
+
converted_state_dict[f"{block_prefix}attn.add_v_proj.bias"] = torch.cat([context_v_bias])
|
3458
|
+
# qk_norm
|
3459
|
+
converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
|
3460
|
+
f"double_blocks.{i}.img_attn.norm.query_norm.scale"
|
3461
|
+
)
|
3462
|
+
converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
|
3463
|
+
f"double_blocks.{i}.img_attn.norm.key_norm.scale"
|
3464
|
+
)
|
3465
|
+
converted_state_dict[f"{block_prefix}attn.norm_added_q.weight"] = checkpoint.pop(
|
3466
|
+
f"double_blocks.{i}.txt_attn.norm.query_norm.scale"
|
3467
|
+
)
|
3468
|
+
converted_state_dict[f"{block_prefix}attn.norm_added_k.weight"] = checkpoint.pop(
|
3469
|
+
f"double_blocks.{i}.txt_attn.norm.key_norm.scale"
|
3470
|
+
)
|
3471
|
+
# ff img_mlp
|
3472
|
+
converted_state_dict[f"{block_prefix}ff.net.0.proj.weight"] = checkpoint.pop(
|
3473
|
+
f"double_blocks.{i}.img_mlp.0.weight"
|
3474
|
+
)
|
3475
|
+
converted_state_dict[f"{block_prefix}ff.net.0.proj.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.0.bias")
|
3476
|
+
converted_state_dict[f"{block_prefix}ff.net.2.weight"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.weight")
|
3477
|
+
converted_state_dict[f"{block_prefix}ff.net.2.bias"] = checkpoint.pop(f"double_blocks.{i}.img_mlp.2.bias")
|
3478
|
+
converted_state_dict[f"{block_prefix}ff_context.net.0.proj.weight"] = checkpoint.pop(
|
3479
|
+
f"double_blocks.{i}.txt_mlp.0.weight"
|
3480
|
+
)
|
3481
|
+
converted_state_dict[f"{block_prefix}ff_context.net.0.proj.bias"] = checkpoint.pop(
|
3482
|
+
f"double_blocks.{i}.txt_mlp.0.bias"
|
3483
|
+
)
|
3484
|
+
converted_state_dict[f"{block_prefix}ff_context.net.2.weight"] = checkpoint.pop(
|
3485
|
+
f"double_blocks.{i}.txt_mlp.2.weight"
|
3486
|
+
)
|
3487
|
+
converted_state_dict[f"{block_prefix}ff_context.net.2.bias"] = checkpoint.pop(
|
3488
|
+
f"double_blocks.{i}.txt_mlp.2.bias"
|
3489
|
+
)
|
3490
|
+
# output projections.
|
3491
|
+
converted_state_dict[f"{block_prefix}attn.to_out.0.weight"] = checkpoint.pop(
|
3492
|
+
f"double_blocks.{i}.img_attn.proj.weight"
|
3493
|
+
)
|
3494
|
+
converted_state_dict[f"{block_prefix}attn.to_out.0.bias"] = checkpoint.pop(
|
3495
|
+
f"double_blocks.{i}.img_attn.proj.bias"
|
3496
|
+
)
|
3497
|
+
converted_state_dict[f"{block_prefix}attn.to_add_out.weight"] = checkpoint.pop(
|
3498
|
+
f"double_blocks.{i}.txt_attn.proj.weight"
|
3499
|
+
)
|
3500
|
+
converted_state_dict[f"{block_prefix}attn.to_add_out.bias"] = checkpoint.pop(
|
3501
|
+
f"double_blocks.{i}.txt_attn.proj.bias"
|
3502
|
+
)
|
3503
|
+
|
3504
|
+
# single transformer blocks
|
3505
|
+
for i in range(num_single_layers):
|
3506
|
+
block_prefix = f"single_transformer_blocks.{i}."
|
3507
|
+
# Q, K, V, mlp
|
3508
|
+
mlp_hidden_dim = int(inner_dim * mlp_ratio)
|
3509
|
+
split_size = (inner_dim, inner_dim, inner_dim, mlp_hidden_dim)
|
3510
|
+
q, k, v, mlp = torch.split(checkpoint.pop(f"single_blocks.{i}.linear1.weight"), split_size, dim=0)
|
3511
|
+
q_bias, k_bias, v_bias, mlp_bias = torch.split(
|
3512
|
+
checkpoint.pop(f"single_blocks.{i}.linear1.bias"), split_size, dim=0
|
3513
|
+
)
|
3514
|
+
converted_state_dict[f"{block_prefix}attn.to_q.weight"] = torch.cat([q])
|
3515
|
+
converted_state_dict[f"{block_prefix}attn.to_q.bias"] = torch.cat([q_bias])
|
3516
|
+
converted_state_dict[f"{block_prefix}attn.to_k.weight"] = torch.cat([k])
|
3517
|
+
converted_state_dict[f"{block_prefix}attn.to_k.bias"] = torch.cat([k_bias])
|
3518
|
+
converted_state_dict[f"{block_prefix}attn.to_v.weight"] = torch.cat([v])
|
3519
|
+
converted_state_dict[f"{block_prefix}attn.to_v.bias"] = torch.cat([v_bias])
|
3520
|
+
converted_state_dict[f"{block_prefix}proj_mlp.weight"] = torch.cat([mlp])
|
3521
|
+
converted_state_dict[f"{block_prefix}proj_mlp.bias"] = torch.cat([mlp_bias])
|
3522
|
+
# qk norm
|
3523
|
+
converted_state_dict[f"{block_prefix}attn.norm_q.weight"] = checkpoint.pop(
|
3524
|
+
f"single_blocks.{i}.norm.query_norm.scale"
|
3525
|
+
)
|
3526
|
+
converted_state_dict[f"{block_prefix}attn.norm_k.weight"] = checkpoint.pop(
|
3527
|
+
f"single_blocks.{i}.norm.key_norm.scale"
|
3528
|
+
)
|
3529
|
+
# output projections.
|
3530
|
+
converted_state_dict[f"{block_prefix}proj_out.weight"] = checkpoint.pop(f"single_blocks.{i}.linear2.weight")
|
3531
|
+
converted_state_dict[f"{block_prefix}proj_out.bias"] = checkpoint.pop(f"single_blocks.{i}.linear2.bias")
|
3532
|
+
|
3533
|
+
converted_state_dict["proj_out.weight"] = checkpoint.pop("final_layer.linear.weight")
|
3534
|
+
converted_state_dict["proj_out.bias"] = checkpoint.pop("final_layer.linear.bias")
|
3535
|
+
|
3536
|
+
return converted_state_dict
|
3537
|
+
|
3538
|
+
|
3539
|
+
def convert_cosmos_transformer_checkpoint_to_diffusers(checkpoint, **kwargs):
|
3540
|
+
converted_state_dict = {key: checkpoint.pop(key) for key in list(checkpoint.keys())}
|
3541
|
+
|
3542
|
+
def remove_keys_(key: str, state_dict):
|
3543
|
+
state_dict.pop(key)
|
3544
|
+
|
3545
|
+
def rename_transformer_blocks_(key: str, state_dict):
|
3546
|
+
block_index = int(key.split(".")[1].removeprefix("block"))
|
3547
|
+
new_key = key
|
3548
|
+
old_prefix = f"blocks.block{block_index}"
|
3549
|
+
new_prefix = f"transformer_blocks.{block_index}"
|
3550
|
+
new_key = new_prefix + new_key.removeprefix(old_prefix)
|
3551
|
+
state_dict[new_key] = state_dict.pop(key)
|
3552
|
+
|
3553
|
+
TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0 = {
|
3554
|
+
"t_embedder.1": "time_embed.t_embedder",
|
3555
|
+
"affline_norm": "time_embed.norm",
|
3556
|
+
".blocks.0.block.attn": ".attn1",
|
3557
|
+
".blocks.1.block.attn": ".attn2",
|
3558
|
+
".blocks.2.block": ".ff",
|
3559
|
+
".blocks.0.adaLN_modulation.1": ".norm1.linear_1",
|
3560
|
+
".blocks.0.adaLN_modulation.2": ".norm1.linear_2",
|
3561
|
+
".blocks.1.adaLN_modulation.1": ".norm2.linear_1",
|
3562
|
+
".blocks.1.adaLN_modulation.2": ".norm2.linear_2",
|
3563
|
+
".blocks.2.adaLN_modulation.1": ".norm3.linear_1",
|
3564
|
+
".blocks.2.adaLN_modulation.2": ".norm3.linear_2",
|
3565
|
+
"to_q.0": "to_q",
|
3566
|
+
"to_q.1": "norm_q",
|
3567
|
+
"to_k.0": "to_k",
|
3568
|
+
"to_k.1": "norm_k",
|
3569
|
+
"to_v.0": "to_v",
|
3570
|
+
"layer1": "net.0.proj",
|
3571
|
+
"layer2": "net.2",
|
3572
|
+
"proj.1": "proj",
|
3573
|
+
"x_embedder": "patch_embed",
|
3574
|
+
"extra_pos_embedder": "learnable_pos_embed",
|
3575
|
+
"final_layer.adaLN_modulation.1": "norm_out.linear_1",
|
3576
|
+
"final_layer.adaLN_modulation.2": "norm_out.linear_2",
|
3577
|
+
"final_layer.linear": "proj_out",
|
3578
|
+
}
|
3579
|
+
|
3580
|
+
TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0 = {
|
3581
|
+
"blocks.block": rename_transformer_blocks_,
|
3582
|
+
"logvar.0.freqs": remove_keys_,
|
3583
|
+
"logvar.0.phases": remove_keys_,
|
3584
|
+
"logvar.1.weight": remove_keys_,
|
3585
|
+
"pos_embedder.seq": remove_keys_,
|
3586
|
+
}
|
3587
|
+
|
3588
|
+
TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0 = {
|
3589
|
+
"t_embedder.1": "time_embed.t_embedder",
|
3590
|
+
"t_embedding_norm": "time_embed.norm",
|
3591
|
+
"blocks": "transformer_blocks",
|
3592
|
+
"adaln_modulation_self_attn.1": "norm1.linear_1",
|
3593
|
+
"adaln_modulation_self_attn.2": "norm1.linear_2",
|
3594
|
+
"adaln_modulation_cross_attn.1": "norm2.linear_1",
|
3595
|
+
"adaln_modulation_cross_attn.2": "norm2.linear_2",
|
3596
|
+
"adaln_modulation_mlp.1": "norm3.linear_1",
|
3597
|
+
"adaln_modulation_mlp.2": "norm3.linear_2",
|
3598
|
+
"self_attn": "attn1",
|
3599
|
+
"cross_attn": "attn2",
|
3600
|
+
"q_proj": "to_q",
|
3601
|
+
"k_proj": "to_k",
|
3602
|
+
"v_proj": "to_v",
|
3603
|
+
"output_proj": "to_out.0",
|
3604
|
+
"q_norm": "norm_q",
|
3605
|
+
"k_norm": "norm_k",
|
3606
|
+
"mlp.layer1": "ff.net.0.proj",
|
3607
|
+
"mlp.layer2": "ff.net.2",
|
3608
|
+
"x_embedder.proj.1": "patch_embed.proj",
|
3609
|
+
"final_layer.adaln_modulation.1": "norm_out.linear_1",
|
3610
|
+
"final_layer.adaln_modulation.2": "norm_out.linear_2",
|
3611
|
+
"final_layer.linear": "proj_out",
|
3612
|
+
}
|
3613
|
+
|
3614
|
+
TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0 = {
|
3615
|
+
"accum_video_sample_counter": remove_keys_,
|
3616
|
+
"accum_image_sample_counter": remove_keys_,
|
3617
|
+
"accum_iteration": remove_keys_,
|
3618
|
+
"accum_train_in_hours": remove_keys_,
|
3619
|
+
"pos_embedder.seq": remove_keys_,
|
3620
|
+
"pos_embedder.dim_spatial_range": remove_keys_,
|
3621
|
+
"pos_embedder.dim_temporal_range": remove_keys_,
|
3622
|
+
"_extra_state": remove_keys_,
|
3623
|
+
}
|
3624
|
+
|
3625
|
+
PREFIX_KEY = "net."
|
3626
|
+
if "net.blocks.block1.blocks.0.block.attn.to_q.0.weight" in checkpoint:
|
3627
|
+
TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_1_0
|
3628
|
+
TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_1_0
|
3629
|
+
else:
|
3630
|
+
TRANSFORMER_KEYS_RENAME_DICT = TRANSFORMER_KEYS_RENAME_DICT_COSMOS_2_0
|
3631
|
+
TRANSFORMER_SPECIAL_KEYS_REMAP = TRANSFORMER_SPECIAL_KEYS_REMAP_COSMOS_2_0
|
3632
|
+
|
3633
|
+
state_dict_keys = list(converted_state_dict.keys())
|
3634
|
+
for key in state_dict_keys:
|
3635
|
+
new_key = key[:]
|
3636
|
+
if new_key.startswith(PREFIX_KEY):
|
3637
|
+
new_key = new_key.removeprefix(PREFIX_KEY)
|
3638
|
+
for replace_key, rename_key in TRANSFORMER_KEYS_RENAME_DICT.items():
|
3639
|
+
new_key = new_key.replace(replace_key, rename_key)
|
3640
|
+
converted_state_dict[new_key] = converted_state_dict.pop(key)
|
3641
|
+
|
3642
|
+
state_dict_keys = list(converted_state_dict.keys())
|
3643
|
+
for key in state_dict_keys:
|
3644
|
+
for special_key, handler_fn_inplace in TRANSFORMER_SPECIAL_KEYS_REMAP.items():
|
3645
|
+
if special_key not in key:
|
3646
|
+
continue
|
3647
|
+
handler_fn_inplace(key, converted_state_dict)
|
3648
|
+
|
3649
|
+
return converted_state_dict
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -427,7 +427,8 @@ class TextualInversionLoaderMixin:
|
|
427
427
|
logger.info(
|
428
428
|
"Accelerate hooks detected. Since you have called `load_textual_inversion()`, the previous hooks will be first removed. Then the textual inversion parameters will be loaded and the hooks will be applied again."
|
429
429
|
)
|
430
|
-
|
430
|
+
if is_sequential_cpu_offload or is_model_cpu_offload:
|
431
|
+
remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
|
431
432
|
|
432
433
|
# 7.2 save expected device and dtype
|
433
434
|
device = text_encoder.device
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -17,12 +17,10 @@ from ..models.embeddings import (
|
|
17
17
|
ImageProjection,
|
18
18
|
MultiIPAdapterImageProjection,
|
19
19
|
)
|
20
|
-
from ..models.
|
21
|
-
from ..
|
22
|
-
|
23
|
-
|
24
|
-
logging,
|
25
|
-
)
|
20
|
+
from ..models.model_loading_utils import load_model_dict_into_meta
|
21
|
+
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
|
22
|
+
from ..utils import is_accelerate_available, is_torch_version, logging
|
23
|
+
from ..utils.torch_utils import empty_device_cache
|
26
24
|
|
27
25
|
|
28
26
|
if is_accelerate_available():
|
@@ -84,13 +82,12 @@ class FluxTransformer2DLoadersMixin:
|
|
84
82
|
else:
|
85
83
|
device_map = {"": self.device}
|
86
84
|
load_model_dict_into_meta(image_projection, updated_state_dict, device_map=device_map, dtype=self.dtype)
|
85
|
+
empty_device_cache()
|
87
86
|
|
88
87
|
return image_projection
|
89
88
|
|
90
89
|
def _convert_ip_adapter_attn_to_diffusers(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
|
91
|
-
from ..models.
|
92
|
-
FluxIPAdapterJointAttnProcessor2_0,
|
93
|
-
)
|
90
|
+
from ..models.transformers.transformer_flux import FluxIPAdapterAttnProcessor
|
94
91
|
|
95
92
|
if low_cpu_mem_usage:
|
96
93
|
if is_accelerate_available():
|
@@ -122,7 +119,7 @@ class FluxTransformer2DLoadersMixin:
|
|
122
119
|
else:
|
123
120
|
cross_attention_dim = self.config.joint_attention_dim
|
124
121
|
hidden_size = self.inner_dim
|
125
|
-
attn_processor_class =
|
122
|
+
attn_processor_class = FluxIPAdapterAttnProcessor
|
126
123
|
num_image_text_embeds = []
|
127
124
|
for state_dict in state_dicts:
|
128
125
|
if "proj.weight" in state_dict["image_proj"]:
|
@@ -158,6 +155,8 @@ class FluxTransformer2DLoadersMixin:
|
|
158
155
|
|
159
156
|
key_id += 1
|
160
157
|
|
158
|
+
empty_device_cache()
|
159
|
+
|
161
160
|
return attn_procs
|
162
161
|
|
163
162
|
def _load_ip_adapter_weights(self, state_dicts, low_cpu_mem_usage=_LOW_CPU_MEM_USAGE_DEFAULT):
|
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -16,8 +16,10 @@ from typing import Dict
|
|
16
16
|
|
17
17
|
from ..models.attention_processor import SD3IPAdapterJointAttnProcessor2_0
|
18
18
|
from ..models.embeddings import IPAdapterTimeImageProjection
|
19
|
-
from ..models.
|
19
|
+
from ..models.model_loading_utils import load_model_dict_into_meta
|
20
|
+
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
|
20
21
|
from ..utils import is_accelerate_available, is_torch_version, logging
|
22
|
+
from ..utils.torch_utils import empty_device_cache
|
21
23
|
|
22
24
|
|
23
25
|
logger = logging.get_logger(__name__)
|
@@ -80,6 +82,8 @@ class SD3Transformer2DLoadersMixin:
|
|
80
82
|
attn_procs[name], layer_state_dict[idx], device_map=device_map, dtype=self.dtype
|
81
83
|
)
|
82
84
|
|
85
|
+
empty_device_cache()
|
86
|
+
|
83
87
|
return attn_procs
|
84
88
|
|
85
89
|
def _convert_ip_adapter_image_proj_to_diffusers(
|
@@ -123,7 +127,7 @@ class SD3Transformer2DLoadersMixin:
|
|
123
127
|
key = key.replace(f"layers.{idx}.2.1", f"layers.{idx}.adaln_proj")
|
124
128
|
updated_state_dict[key] = value
|
125
129
|
|
126
|
-
# Image
|
130
|
+
# Image projection parameters
|
127
131
|
embed_dim = updated_state_dict["proj_in.weight"].shape[1]
|
128
132
|
output_dim = updated_state_dict["proj_out.weight"].shape[0]
|
129
133
|
hidden_dim = updated_state_dict["proj_in.weight"].shape[0]
|
@@ -147,6 +151,7 @@ class SD3Transformer2DLoadersMixin:
|
|
147
151
|
else:
|
148
152
|
device_map = {"": self.device}
|
149
153
|
load_model_dict_into_meta(image_proj, updated_state_dict, device_map=device_map, dtype=self.dtype)
|
154
|
+
empty_device_cache()
|
150
155
|
|
151
156
|
return image_proj
|
152
157
|
|