diffusers 0.33.0__py3-none-any.whl → 0.34.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +48 -1
- diffusers/commands/__init__.py +1 -1
- diffusers/commands/diffusers_cli.py +1 -1
- diffusers/commands/env.py +1 -1
- diffusers/commands/fp16_safetensors.py +1 -1
- diffusers/dependency_versions_check.py +1 -1
- diffusers/dependency_versions_table.py +1 -1
- diffusers/experimental/rl/value_guided_sampling.py +1 -1
- diffusers/hooks/faster_cache.py +2 -2
- diffusers/hooks/group_offloading.py +128 -29
- diffusers/hooks/hooks.py +2 -2
- diffusers/hooks/layerwise_casting.py +3 -3
- diffusers/hooks/pyramid_attention_broadcast.py +1 -1
- diffusers/image_processor.py +7 -2
- diffusers/loaders/__init__.py +4 -0
- diffusers/loaders/ip_adapter.py +5 -14
- diffusers/loaders/lora_base.py +212 -111
- diffusers/loaders/lora_conversion_utils.py +275 -34
- diffusers/loaders/lora_pipeline.py +1554 -819
- diffusers/loaders/peft.py +52 -109
- diffusers/loaders/single_file.py +2 -2
- diffusers/loaders/single_file_model.py +20 -4
- diffusers/loaders/single_file_utils.py +225 -5
- diffusers/loaders/textual_inversion.py +3 -2
- diffusers/loaders/transformer_flux.py +1 -1
- diffusers/loaders/transformer_sd3.py +2 -2
- diffusers/loaders/unet.py +2 -16
- diffusers/loaders/unet_loader_utils.py +1 -1
- diffusers/loaders/utils.py +1 -1
- diffusers/models/__init__.py +15 -1
- diffusers/models/activations.py +5 -5
- diffusers/models/adapter.py +2 -3
- diffusers/models/attention.py +4 -4
- diffusers/models/attention_flax.py +10 -10
- diffusers/models/attention_processor.py +14 -10
- diffusers/models/auto_model.py +47 -10
- diffusers/models/autoencoders/__init__.py +1 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
- diffusers/models/autoencoders/autoencoder_dc.py +3 -3
- diffusers/models/autoencoders/autoencoder_kl.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
- diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1108 -0
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
- diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_wan.py +256 -22
- diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
- diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
- diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
- diffusers/models/autoencoders/vae.py +13 -2
- diffusers/models/autoencoders/vq_model.py +2 -2
- diffusers/models/cache_utils.py +1 -1
- diffusers/models/controlnet.py +1 -1
- diffusers/models/controlnet_flux.py +1 -1
- diffusers/models/controlnet_sd3.py +1 -1
- diffusers/models/controlnet_sparsectrl.py +1 -1
- diffusers/models/controlnets/__init__.py +1 -0
- diffusers/models/controlnets/controlnet.py +3 -3
- diffusers/models/controlnets/controlnet_flax.py +1 -1
- diffusers/models/controlnets/controlnet_flux.py +16 -15
- diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
- diffusers/models/controlnets/controlnet_sana.py +290 -0
- diffusers/models/controlnets/controlnet_sd3.py +1 -1
- diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
- diffusers/models/controlnets/controlnet_union.py +1 -1
- diffusers/models/controlnets/controlnet_xs.py +7 -7
- diffusers/models/controlnets/multicontrolnet.py +4 -5
- diffusers/models/controlnets/multicontrolnet_union.py +5 -6
- diffusers/models/downsampling.py +2 -2
- diffusers/models/embeddings.py +10 -12
- diffusers/models/embeddings_flax.py +2 -2
- diffusers/models/lora.py +3 -3
- diffusers/models/modeling_utils.py +44 -14
- diffusers/models/normalization.py +4 -4
- diffusers/models/resnet.py +2 -2
- diffusers/models/resnet_flax.py +1 -1
- diffusers/models/transformers/__init__.py +5 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
- diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
- diffusers/models/transformers/consisid_transformer_3d.py +1 -1
- diffusers/models/transformers/dit_transformer_2d.py +2 -2
- diffusers/models/transformers/dual_transformer_2d.py +1 -1
- diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
- diffusers/models/transformers/latte_transformer_3d.py +4 -5
- diffusers/models/transformers/lumina_nextdit2d.py +2 -2
- diffusers/models/transformers/pixart_transformer_2d.py +3 -3
- diffusers/models/transformers/prior_transformer.py +1 -1
- diffusers/models/transformers/sana_transformer.py +8 -3
- diffusers/models/transformers/stable_audio_transformer.py +5 -9
- diffusers/models/transformers/t5_film_transformer.py +3 -3
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +1 -1
- diffusers/models/transformers/transformer_chroma.py +742 -0
- diffusers/models/transformers/transformer_cogview3plus.py +5 -10
- diffusers/models/transformers/transformer_cogview4.py +317 -25
- diffusers/models/transformers/transformer_cosmos.py +579 -0
- diffusers/models/transformers/transformer_flux.py +9 -11
- diffusers/models/transformers/transformer_hidream_image.py +942 -0
- diffusers/models/transformers/transformer_hunyuan_video.py +6 -8
- diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
- diffusers/models/transformers/transformer_ltx.py +2 -2
- diffusers/models/transformers/transformer_lumina2.py +1 -1
- diffusers/models/transformers/transformer_mochi.py +1 -1
- diffusers/models/transformers/transformer_omnigen.py +2 -2
- diffusers/models/transformers/transformer_sd3.py +7 -7
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/transformers/transformer_wan.py +24 -8
- diffusers/models/transformers/transformer_wan_vace.py +393 -0
- diffusers/models/unets/unet_1d.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +1 -1
- diffusers/models/unets/unet_2d_blocks.py +1 -1
- diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
- diffusers/models/unets/unet_2d_condition.py +2 -2
- diffusers/models/unets/unet_2d_condition_flax.py +2 -2
- diffusers/models/unets/unet_3d_blocks.py +1 -1
- diffusers/models/unets/unet_3d_condition.py +3 -3
- diffusers/models/unets/unet_i2vgen_xl.py +3 -3
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +2 -2
- diffusers/models/unets/unet_stable_cascade.py +1 -1
- diffusers/models/upsampling.py +2 -2
- diffusers/models/vae_flax.py +2 -2
- diffusers/models/vq_model.py +1 -1
- diffusers/pipelines/__init__.py +37 -6
- diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
- diffusers/pipelines/amused/pipeline_amused.py +7 -6
- diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
- diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
- diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +23 -13
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
- diffusers/pipelines/auto_pipeline.py +6 -7
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
- diffusers/pipelines/chroma/__init__.py +49 -0
- diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
- diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
- diffusers/pipelines/chroma/pipeline_output.py +21 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +8 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +8 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +8 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +8 -8
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
- diffusers/pipelines/cogview4/pipeline_cogview4.py +7 -7
- diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
- diffusers/pipelines/consisid/consisid_utils.py +2 -2
- diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +8 -8
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +14 -14
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +5 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +13 -13
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
- diffusers/pipelines/cosmos/__init__.py +54 -0
- diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
- diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
- diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
- diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
- diffusers/pipelines/cosmos/pipeline_output.py +40 -0
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
- diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
- diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
- diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
- diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
- diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
- diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
- diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +7 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
- diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +1 -1
- diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
- diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
- diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
- diffusers/pipelines/flux/modeling_flux.py +1 -1
- diffusers/pipelines/flux/pipeline_flux.py +10 -17
- diffusers/pipelines/flux/pipeline_flux_control.py +6 -6
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -6
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +6 -6
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +6 -6
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +30 -22
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +2 -1
- diffusers/pipelines/flux/pipeline_flux_fill.py +6 -6
- diffusers/pipelines/flux/pipeline_flux_img2img.py +39 -6
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +11 -6
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
- diffusers/pipelines/free_init_utils.py +2 -2
- diffusers/pipelines/free_noise_utils.py +3 -3
- diffusers/pipelines/hidream_image/__init__.py +47 -0
- diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
- diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
- diffusers/pipelines/hunyuan_video/__init__.py +2 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +8 -8
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
- diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
- diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
- diffusers/pipelines/kolors/text_encoder.py +3 -3
- diffusers/pipelines/kolors/tokenizer.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
- diffusers/pipelines/latte/pipeline_latte.py +12 -12
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
- diffusers/pipelines/ltx/__init__.py +4 -0
- diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +51 -6
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +107 -29
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +50 -6
- diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
- diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
- diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
- diffusers/pipelines/mochi/pipeline_mochi.py +6 -6
- diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
- diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
- diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
- diffusers/pipelines/onnx_utils.py +15 -2
- diffusers/pipelines/pag/pag_utils.py +2 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
- diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
- diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
- diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
- diffusers/pipelines/pia/pipeline_pia.py +8 -6
- diffusers/pipelines/pipeline_flax_utils.py +3 -4
- diffusers/pipelines/pipeline_loading_utils.py +89 -13
- diffusers/pipelines/pipeline_utils.py +105 -33
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +11 -11
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +11 -11
- diffusers/pipelines/sana/__init__.py +4 -0
- diffusers/pipelines/sana/pipeline_sana.py +23 -21
- diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
- diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
- diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
- diffusers/pipelines/shap_e/camera.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
- diffusers/pipelines/shap_e/renderer.py +3 -3
- diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
- diffusers/pipelines/stable_diffusion/__init__.py +0 -7
- diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +10 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +10 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +9 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +8 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
- diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
- diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -7
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
- diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
- diffusers/pipelines/unclip/text_proj.py +2 -2
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
- diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
- diffusers/pipelines/visualcloze/__init__.py +52 -0
- diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
- diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
- diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
- diffusers/pipelines/wan/__init__.py +2 -0
- diffusers/pipelines/wan/pipeline_wan.py +17 -12
- diffusers/pipelines/wan/pipeline_wan_i2v.py +42 -20
- diffusers/pipelines/wan/pipeline_wan_vace.py +976 -0
- diffusers/pipelines/wan/pipeline_wan_video2video.py +18 -18
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
- diffusers/quantizers/__init__.py +179 -1
- diffusers/quantizers/base.py +6 -1
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
- diffusers/quantizers/bitsandbytes/utils.py +10 -7
- diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
- diffusers/quantizers/gguf/utils.py +16 -13
- diffusers/quantizers/quantization_config.py +18 -16
- diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +5 -1
- diffusers/schedulers/__init__.py +3 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
- diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
- diffusers/schedulers/scheduling_consistency_models.py +1 -1
- diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
- diffusers/schedulers/scheduling_ddim.py +8 -8
- diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
- diffusers/schedulers/scheduling_ddim_flax.py +6 -6
- diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
- diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
- diffusers/schedulers/scheduling_ddpm.py +9 -9
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
- diffusers/schedulers/scheduling_deis_multistep.py +8 -8
- diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -12
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
- diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +13 -13
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
- diffusers/schedulers/scheduling_edm_euler.py +20 -11
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
- diffusers/schedulers/scheduling_euler_discrete.py +3 -3
- diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
- diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
- diffusers/schedulers/scheduling_heun_discrete.py +2 -2
- diffusers/schedulers/scheduling_ipndm.py +2 -2
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
- diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
- diffusers/schedulers/scheduling_lcm.py +3 -3
- diffusers/schedulers/scheduling_lms_discrete.py +2 -2
- diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
- diffusers/schedulers/scheduling_pndm.py +4 -4
- diffusers/schedulers/scheduling_pndm_flax.py +4 -4
- diffusers/schedulers/scheduling_repaint.py +9 -9
- diffusers/schedulers/scheduling_sasolver.py +15 -15
- diffusers/schedulers/scheduling_scm.py +1 -1
- diffusers/schedulers/scheduling_sde_ve.py +1 -1
- diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
- diffusers/schedulers/scheduling_tcd.py +3 -3
- diffusers/schedulers/scheduling_unclip.py +5 -5
- diffusers/schedulers/scheduling_unipc_multistep.py +11 -11
- diffusers/schedulers/scheduling_utils.py +1 -1
- diffusers/schedulers/scheduling_utils_flax.py +1 -1
- diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
- diffusers/training_utils.py +13 -5
- diffusers/utils/__init__.py +5 -0
- diffusers/utils/accelerate_utils.py +1 -1
- diffusers/utils/doc_utils.py +1 -1
- diffusers/utils/dummy_pt_objects.py +120 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +225 -0
- diffusers/utils/dynamic_modules_utils.py +21 -3
- diffusers/utils/export_utils.py +1 -1
- diffusers/utils/import_utils.py +81 -18
- diffusers/utils/logging.py +1 -1
- diffusers/utils/outputs.py +2 -1
- diffusers/utils/peft_utils.py +91 -8
- diffusers/utils/state_dict_utils.py +20 -3
- diffusers/utils/testing_utils.py +59 -7
- diffusers/utils/torch_utils.py +25 -5
- diffusers/video_processor.py +2 -2
- {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/METADATA +3 -3
- diffusers-0.34.0.dist-info/RECORD +639 -0
- diffusers-0.33.0.dist-info/RECORD +0 -608
- {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/LICENSE +0 -0
- {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/WHEEL +0 -0
- {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/top_level.txt +0 -0
--- diffusers/models/transformers/transformer_cogview3plus.py
+++ diffusers/models/transformers/transformer_cogview3plus.py
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -19,18 +19,13 @@ import torch
 import torch.nn as nn

 from ...configuration_utils import ConfigMixin, register_to_config
-from ...models.attention import FeedForward
-from ...models.attention_processor import (
-    Attention,
-    AttentionProcessor,
-    CogVideoXAttnProcessor2_0,
-)
-from ...models.modeling_utils import ModelMixin
-from ...models.normalization import AdaLayerNormContinuous
 from ...utils import logging
+from ..attention import FeedForward
+from ..attention_processor import Attention, AttentionProcessor, CogVideoXAttnProcessor2_0
 from ..embeddings import CogView3CombinedTimestepSizeEmbeddings, CogView3PlusPatchEmbed
 from ..modeling_outputs import Transformer2DModelOutput
-from ..
+from ..modeling_utils import ModelMixin
+from ..normalization import AdaLayerNormContinuous, CogView3PlusAdaLayerNormZeroTextImage


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
--- diffusers/models/transformers/transformer_cogview4.py
+++ diffusers/models/transformers/transformer_cogview4.py
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -73,8 +73,9 @@ class CogView4AdaLayerNormZero(nn.Module):
     def forward(
         self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        norm_hidden_states = self.norm(hidden_states)
-        norm_encoder_hidden_states = self.norm_context(encoder_hidden_states)
+        dtype = hidden_states.dtype
+        norm_hidden_states = self.norm(hidden_states).to(dtype=dtype)
+        norm_encoder_hidden_states = self.norm_context(encoder_hidden_states).to(dtype=dtype)

         emb = self.linear(temb)
         (
@@ -111,8 +112,11 @@ class CogView4AdaLayerNormZero(nn.Module):

 class CogView4AttnProcessor:
     """
-    Processor for implementing scaled dot-product attention for the
+    Processor for implementing scaled dot-product attention for the CogView4 model. It applies a rotary embedding on
     query and key vectors, but does not include spatial normalization.
+
+    The processor supports passing an attention mask for text tokens. The attention mask should have shape
+    (batch_size, text_seq_length) where 1 indicates a non-padded token and 0 indicates a padded token.
     """

     def __init__(self):
@@ -125,8 +129,10 @@ class CogView4AttnProcessor:
         hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        dtype = encoder_hidden_states.dtype
+
         batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape
         batch_size, image_seq_length, embed_dim = hidden_states.shape
         hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
@@ -142,9 +148,9 @@ class CogView4AttnProcessor:

         # 2. QK normalization
         if attn.norm_q is not None:
-            query = attn.norm_q(query)
+            query = attn.norm_q(query).to(dtype=dtype)
         if attn.norm_k is not None:
-            key = attn.norm_k(key)
+            key = attn.norm_k(key).to(dtype=dtype)

         # 3. Rotational positional embeddings applied to latent stream
         if image_rotary_emb is not None:
@@ -159,13 +165,14 @@ class CogView4AttnProcessor:

         # 4. Attention
         if attention_mask is not None:
-
-
-
-
-
-
+            text_attn_mask = attention_mask
+            assert text_attn_mask.dim() == 2, "the shape of text_attn_mask should be (batch_size, text_seq_length)"
+            text_attn_mask = text_attn_mask.float().to(query.device)
+            mix_attn_mask = torch.ones((batch_size, text_seq_length + image_seq_length), device=query.device)
+            mix_attn_mask[:, :text_seq_length] = text_attn_mask
+            mix_attn_mask = mix_attn_mask.unsqueeze(2)
+            attn_mask_matrix = mix_attn_mask @ mix_attn_mask.transpose(1, 2)
+            attention_mask = (attn_mask_matrix > 0).unsqueeze(1).to(query.dtype)

         hidden_states = F.scaled_dot_product_attention(
             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
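The change above builds a pairwise attention mask from a per-token text padding mask with an outer product: a (batch, seq) keep/pad vector becomes a (batch, seq, seq) matrix that is 1 only where both tokens are real. A minimal standalone sketch of the same construction (toy shapes and names, not taken from the diff):

import torch

batch_size, text_len, image_len = 2, 3, 4
text_mask = torch.tensor([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]])  # 0 marks padded text tokens

mix_mask = torch.ones(batch_size, text_len + image_len)  # image tokens are always kept
mix_mask[:, :text_len] = text_mask
mix_mask = mix_mask.unsqueeze(2)                          # (B, L, 1)
pairwise = mix_mask @ mix_mask.transpose(1, 2)            # (B, L, L): 1 only if both tokens are kept
attn_mask = (pairwise > 0).unsqueeze(1)                   # (B, 1, L, L), broadcast over heads

print(attn_mask.shape)  # torch.Size([2, 1, 7, 7])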
@@ -183,9 +190,276 @@ class CogView4AttnProcessor:
         return hidden_states, encoder_hidden_states


+class CogView4TrainingAttnProcessor:
+    """
+    Training Processor for implementing scaled dot-product attention for the CogView4 model. It applies a rotary
+    embedding on query and key vectors, but does not include spatial normalization.
+
+    This processor differs from CogView4AttnProcessor in several important ways:
+    1. It supports attention masking with variable sequence lengths for multi-resolution training
+    2. It unpacks and repacks sequences for efficient training with variable sequence lengths when batch_flag is
+       provided
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("CogView4AttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        latent_attn_mask: Optional[torch.Tensor] = None,
+        text_attn_mask: Optional[torch.Tensor] = None,
+        batch_flag: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[
+            Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
+        ] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Args:
+            attn (`Attention`):
+                The attention module.
+            hidden_states (`torch.Tensor`):
+                The input hidden states.
+            encoder_hidden_states (`torch.Tensor`):
+                The encoder hidden states for cross-attention.
+            latent_attn_mask (`torch.Tensor`, *optional*):
+                Mask for latent tokens where 0 indicates pad token and 1 indicates non-pad token. If None, full
+                attention is used for all latent tokens. Note: the shape of latent_attn_mask is (batch_size,
+                num_latent_tokens).
+            text_attn_mask (`torch.Tensor`, *optional*):
+                Mask for text tokens where 0 indicates pad token and 1 indicates non-pad token. If None, full attention
+                is used for all text tokens.
+            batch_flag (`torch.Tensor`, *optional*):
+                Values from 0 to n-1 indicating which samples belong to the same batch. Samples with the same
+                batch_flag are packed together. Example: [0, 1, 1, 2, 2] means sample 0 forms batch0, samples 1-2 form
+                batch1, and samples 3-4 form batch2. If None, no packing is used.
+            image_rotary_emb (`Tuple[torch.Tensor, torch.Tensor]` or `list[Tuple[torch.Tensor, torch.Tensor]]`, *optional*):
+                The rotary embedding for the image part of the input.
+        Returns:
+            `Tuple[torch.Tensor, torch.Tensor]`: The processed hidden states for both image and text streams.
+        """
+
+        # Get dimensions and device info
+        batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape
+        batch_size, image_seq_length, embed_dim = hidden_states.shape
+        dtype = encoder_hidden_states.dtype
+        device = encoder_hidden_states.device
+        latent_hidden_states = hidden_states
+        # Combine text and image streams for joint processing
+        mixed_hidden_states = torch.cat([encoder_hidden_states, latent_hidden_states], dim=1)
+
+        # 1. Construct attention mask and maybe packing input
+        # Create default masks if not provided
+        if text_attn_mask is None:
+            text_attn_mask = torch.ones((batch_size, text_seq_length), dtype=torch.int32, device=device)
+        if latent_attn_mask is None:
+            latent_attn_mask = torch.ones((batch_size, image_seq_length), dtype=torch.int32, device=device)
+
+        # Validate mask shapes and types
+        assert text_attn_mask.dim() == 2, "the shape of text_attn_mask should be (batch_size, text_seq_length)"
+        assert text_attn_mask.dtype == torch.int32, "the dtype of text_attn_mask should be torch.int32"
+        assert latent_attn_mask.dim() == 2, "the shape of latent_attn_mask should be (batch_size, num_latent_tokens)"
+        assert latent_attn_mask.dtype == torch.int32, "the dtype of latent_attn_mask should be torch.int32"
+
+        # Create combined mask for text and image tokens
+        mixed_attn_mask = torch.ones(
+            (batch_size, text_seq_length + image_seq_length), dtype=torch.int32, device=device
+        )
+        mixed_attn_mask[:, :text_seq_length] = text_attn_mask
+        mixed_attn_mask[:, text_seq_length:] = latent_attn_mask
+
+        # Convert mask to attention matrix format (where 1 means attend, 0 means don't attend)
+        mixed_attn_mask_input = mixed_attn_mask.unsqueeze(2).to(dtype=dtype)
+        attn_mask_matrix = mixed_attn_mask_input @ mixed_attn_mask_input.transpose(1, 2)
+
+        # Handle batch packing if enabled
+        if batch_flag is not None:
+            assert batch_flag.dim() == 1
+            # Determine packed batch size based on batch_flag
+            packing_batch_size = torch.max(batch_flag).item() + 1
+
+            # Calculate actual sequence lengths for each sample based on masks
+            text_seq_length = torch.sum(text_attn_mask, dim=1)
+            latent_seq_length = torch.sum(latent_attn_mask, dim=1)
+            mixed_seq_length = text_seq_length + latent_seq_length
+
+            # Calculate packed sequence lengths for each packed batch
+            mixed_seq_length_packed = [
+                torch.sum(mixed_attn_mask[batch_flag == batch_idx]).item() for batch_idx in range(packing_batch_size)
+            ]
+
+            assert len(mixed_seq_length_packed) == packing_batch_size
+
+            # Pack sequences by removing padding tokens
+            mixed_attn_mask_flatten = mixed_attn_mask.flatten(0, 1)
+            mixed_hidden_states_flatten = mixed_hidden_states.flatten(0, 1)
+            mixed_hidden_states_unpad = mixed_hidden_states_flatten[mixed_attn_mask_flatten == 1]
+            assert torch.sum(mixed_seq_length) == mixed_hidden_states_unpad.shape[0]
+
+            # Split the unpadded sequence into packed batches
+            mixed_hidden_states_packed = torch.split(mixed_hidden_states_unpad, mixed_seq_length_packed)
+
+            # Re-pad to create packed batches with right-side padding
+            mixed_hidden_states_packed_padded = torch.nn.utils.rnn.pad_sequence(
+                mixed_hidden_states_packed,
+                batch_first=True,
+                padding_value=0.0,
+                padding_side="right",
+            )
+
+            # Create attention mask for packed batches
+            l = mixed_hidden_states_packed_padded.shape[1]
+            attn_mask_matrix = torch.zeros(
+                (packing_batch_size, l, l),
+                dtype=dtype,
+                device=device,
+            )
+
+            # Fill attention mask with block diagonal matrices
+            # This ensures that tokens can only attend to other tokens within the same original sample
+            for idx, mask in enumerate(attn_mask_matrix):
+                seq_lengths = mixed_seq_length[batch_flag == idx]
+                offset = 0
+                for length in seq_lengths:
+                    # Create a block of 1s for each sample in the packed batch
+                    mask[offset : offset + length, offset : offset + length] = 1
+                    offset += length
+
+        attn_mask_matrix = attn_mask_matrix.to(dtype=torch.bool)
+        attn_mask_matrix = attn_mask_matrix.unsqueeze(1)  # Add attention head dim
+        attention_mask = attn_mask_matrix
+
+        # Prepare hidden states for attention computation
+        if batch_flag is None:
+            # If no packing, just combine text and image tokens
+            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+        else:
+            # If packing, use the packed sequence
+            hidden_states = mixed_hidden_states_packed_padded
+
+        # 2. QKV projections - convert hidden states to query, key, value
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        # Reshape for multi-head attention: [batch, seq_len, heads*dim] -> [batch, heads, seq_len, dim]
+        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+        # 3. QK normalization - apply layer norm to queries and keys if configured
+        if attn.norm_q is not None:
+            query = attn.norm_q(query).to(dtype=dtype)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key).to(dtype=dtype)
+
+        # 4. Apply rotary positional embeddings to image tokens only
+        if image_rotary_emb is not None:
+            from ..embeddings import apply_rotary_emb
+
+            if batch_flag is None:
+                # Apply RoPE only to image tokens (after text tokens)
+                query[:, :, text_seq_length:, :] = apply_rotary_emb(
+                    query[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
+                )
+                key[:, :, text_seq_length:, :] = apply_rotary_emb(
+                    key[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
+                )
+            else:
+                # For packed batches, need to carefully apply RoPE to appropriate tokens
+                assert query.shape[0] == packing_batch_size
+                assert key.shape[0] == packing_batch_size
+                assert len(image_rotary_emb) == batch_size
+
+                rope_idx = 0
+                for idx in range(packing_batch_size):
+                    offset = 0
+                    # Get text and image sequence lengths for samples in this packed batch
+                    text_seq_length_bi = text_seq_length[batch_flag == idx]
+                    latent_seq_length_bi = latent_seq_length[batch_flag == idx]
+
+                    # Apply RoPE to each image segment in the packed sequence
+                    for tlen, llen in zip(text_seq_length_bi, latent_seq_length_bi):
+                        mlen = tlen + llen
+                        # Apply RoPE only to image tokens (after text tokens)
+                        query[idx, :, offset + tlen : offset + mlen, :] = apply_rotary_emb(
+                            query[idx, :, offset + tlen : offset + mlen, :],
+                            image_rotary_emb[rope_idx],
+                            use_real_unbind_dim=-2,
+                        )
+                        key[idx, :, offset + tlen : offset + mlen, :] = apply_rotary_emb(
+                            key[idx, :, offset + tlen : offset + mlen, :],
+                            image_rotary_emb[rope_idx],
+                            use_real_unbind_dim=-2,
+                        )
+                        offset += mlen
+                        rope_idx += 1
+
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        # Reshape back: [batch, heads, seq_len, dim] -> [batch, seq_len, heads*dim]
+        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+        hidden_states = hidden_states.type_as(query)
+
+        # 5. Output projection - project attention output to model dimension
+        hidden_states = attn.to_out[0](hidden_states)
+        hidden_states = attn.to_out[1](hidden_states)
+
+        # Split the output back into text and image streams
+        if batch_flag is None:
+            # Simple split for non-packed case
+            encoder_hidden_states, hidden_states = hidden_states.split(
+                [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
+            )
+        else:
+            # For packed case: need to unpack, split text/image, then restore to original shapes
+            # First, unpad the sequence based on the packed sequence lengths
+            hidden_states_unpad = torch.nn.utils.rnn.unpad_sequence(
+                hidden_states,
+                lengths=torch.tensor(mixed_seq_length_packed),
+                batch_first=True,
+            )
+            # Concatenate all unpadded sequences
+            hidden_states_flatten = torch.cat(hidden_states_unpad, dim=0)
+            # Split by original sample sequence lengths
+            hidden_states_unpack = torch.split(hidden_states_flatten, mixed_seq_length.tolist())
+            assert len(hidden_states_unpack) == batch_size
+
+            # Further split each sample's sequence into text and image parts
+            hidden_states_unpack = [
+                torch.split(h, [tlen, llen])
+                for h, tlen, llen in zip(hidden_states_unpack, text_seq_length, latent_seq_length)
+            ]
+            # Separate text and image sequences
+            encoder_hidden_states_unpad = [h[0] for h in hidden_states_unpack]
+            hidden_states_unpad = [h[1] for h in hidden_states_unpack]
+
+            # Update the original tensors with the processed values, respecting the attention masks
+            for idx in range(batch_size):
+                # Place unpacked text tokens back in the encoder_hidden_states tensor
+                encoder_hidden_states[idx][text_attn_mask[idx] == 1] = encoder_hidden_states_unpad[idx]
+                # Place unpacked image tokens back in the latent_hidden_states tensor
+                latent_hidden_states[idx][latent_attn_mask[idx] == 1] = hidden_states_unpad[idx]
+
+            # Update the output hidden states
+            hidden_states = latent_hidden_states
+
+        return hidden_states, encoder_hidden_states
+
+
 class CogView4TransformerBlock(nn.Module):
     def __init__(
-        self, dim: int = 2560, num_attention_heads: int = 64, attention_head_dim: int = 40, time_embed_dim: int = 512
+        self,
+        dim: int = 2560,
+        num_attention_heads: int = 64,
+        attention_head_dim: int = 40,
+        time_embed_dim: int = 512,
     ) -> None:
         super().__init__()

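The packing path in CogView4TrainingAttnProcessor above removes padding tokens, concatenates samples that share a batch_flag value into one row, re-pads the packed rows to a common length, and builds a block-diagonal mask so packed samples cannot attend to each other. A minimal sketch of that flow with made-up toy shapes (names here are illustrative, not the diff's):

import torch
from torch.nn.utils.rnn import pad_sequence

hidden = torch.randn(3, 6, 8)                         # (batch, seq, dim), right-padded
mask = torch.tensor([[1] * 6, [1] * 4 + [0] * 2, [1] * 3 + [0] * 3])
batch_flag = torch.tensor([0, 1, 1])                  # samples 1 and 2 share one packed row

# Drop padding, then concatenate samples that share a batch_flag value
unpadded = [h[m == 1] for h, m in zip(hidden, mask)]
packed = [
    torch.cat([u for u, f in zip(unpadded, batch_flag) if f == i])
    for i in range(int(batch_flag.max()) + 1)
]
padded = pad_sequence(packed, batch_first=True)       # (2, 7, 8): row lengths 6 and 4 + 3

# Block-diagonal mask: tokens may only attend within their originating sample
lengths = [[6], [4, 3]]
attn_mask = torch.zeros(2, padded.shape[1], padded.shape[1], dtype=torch.bool)
for row, lens in enumerate(lengths):
    offset = 0
    for n in lens:
        attn_mask[row, offset : offset + n, offset : offset + n] = True
        offset += n

print(padded.shape, attn_mask.shape)  # torch.Size([2, 7, 8]) torch.Size([2, 7, 7])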
@@ -213,9 +487,11 @@ class CogView4TransformerBlock(nn.Module):
         hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
         temb: Optional[torch.Tensor] = None,
-        image_rotary_emb: Optional[
-
-
+        image_rotary_emb: Optional[
+            Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
+        ] = None,
+        attention_mask: Optional[Dict[str, torch.Tensor]] = None,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
     ) -> torch.Tensor:
         # 1. Timestep conditioning
         (
@@ -232,12 +508,14 @@ class CogView4TransformerBlock(nn.Module):
         ) = self.norm1(hidden_states, encoder_hidden_states, temb)

         # 2. Attention
+        if attention_kwargs is None:
+            attention_kwargs = {}
         attn_hidden_states, attn_encoder_hidden_states = self.attn1(
             hidden_states=norm_hidden_states,
             encoder_hidden_states=norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
             attention_mask=attention_mask,
-            **
+            **attention_kwargs,
         )
         hidden_states = hidden_states + attn_hidden_states * gate_msa.unsqueeze(1)
         encoder_hidden_states = encoder_hidden_states + attn_encoder_hidden_states * c_gate_msa.unsqueeze(1)
@@ -402,7 +680,9 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
         attention_mask: Optional[torch.Tensor] = None,
-
+        image_rotary_emb: Optional[
+            Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
+        ] = None,
     ) -> Union[torch.Tensor, Transformer2DModelOutput]:
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
@@ -422,7 +702,8 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
         batch_size, num_channels, height, width = hidden_states.shape

         # 1. RoPE
-        image_rotary_emb = self.rope(hidden_states)
+        if image_rotary_emb is None:
+            image_rotary_emb = self.rope(hidden_states)

         # 2. Patch & Timestep embeddings
         p = self.config.patch_size
@@ -438,11 +719,22 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
         for block in self.transformer_blocks:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
                 hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
-                    block,
+                    block,
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    image_rotary_emb,
+                    attention_mask,
+                    attention_kwargs,
                 )
             else:
                 hidden_states, encoder_hidden_states = block(
-                    hidden_states,
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    image_rotary_emb,
+                    attention_mask,
+                    attention_kwargs,
                 )

         # 4. Output norm & projection