diffusers 0.33.1__py3-none-any.whl → 0.35.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +145 -1
- diffusers/callbacks.py +35 -0
- diffusers/commands/__init__.py +1 -1
- diffusers/commands/custom_blocks.py +134 -0
- diffusers/commands/diffusers_cli.py +3 -1
- diffusers/commands/env.py +1 -1
- diffusers/commands/fp16_safetensors.py +2 -2
- diffusers/configuration_utils.py +11 -2
- diffusers/dependency_versions_check.py +1 -1
- diffusers/dependency_versions_table.py +3 -3
- diffusers/experimental/rl/value_guided_sampling.py +1 -1
- diffusers/guiders/__init__.py +41 -0
- diffusers/guiders/adaptive_projected_guidance.py +188 -0
- diffusers/guiders/auto_guidance.py +190 -0
- diffusers/guiders/classifier_free_guidance.py +141 -0
- diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
- diffusers/guiders/frequency_decoupled_guidance.py +327 -0
- diffusers/guiders/guider_utils.py +309 -0
- diffusers/guiders/perturbed_attention_guidance.py +271 -0
- diffusers/guiders/skip_layer_guidance.py +262 -0
- diffusers/guiders/smoothed_energy_guidance.py +251 -0
- diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
- diffusers/hooks/__init__.py +17 -0
- diffusers/hooks/_common.py +56 -0
- diffusers/hooks/_helpers.py +293 -0
- diffusers/hooks/faster_cache.py +9 -8
- diffusers/hooks/first_block_cache.py +259 -0
- diffusers/hooks/group_offloading.py +332 -227
- diffusers/hooks/hooks.py +58 -3
- diffusers/hooks/layer_skip.py +263 -0
- diffusers/hooks/layerwise_casting.py +5 -10
- diffusers/hooks/pyramid_attention_broadcast.py +15 -12
- diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
- diffusers/hooks/utils.py +43 -0
- diffusers/image_processor.py +7 -2
- diffusers/loaders/__init__.py +10 -0
- diffusers/loaders/ip_adapter.py +260 -18
- diffusers/loaders/lora_base.py +261 -127
- diffusers/loaders/lora_conversion_utils.py +657 -35
- diffusers/loaders/lora_pipeline.py +2778 -1246
- diffusers/loaders/peft.py +78 -112
- diffusers/loaders/single_file.py +2 -2
- diffusers/loaders/single_file_model.py +64 -15
- diffusers/loaders/single_file_utils.py +395 -7
- diffusers/loaders/textual_inversion.py +3 -2
- diffusers/loaders/transformer_flux.py +10 -11
- diffusers/loaders/transformer_sd3.py +8 -3
- diffusers/loaders/unet.py +24 -21
- diffusers/loaders/unet_loader_utils.py +6 -3
- diffusers/loaders/utils.py +1 -1
- diffusers/models/__init__.py +23 -1
- diffusers/models/activations.py +5 -5
- diffusers/models/adapter.py +2 -3
- diffusers/models/attention.py +488 -7
- diffusers/models/attention_dispatch.py +1218 -0
- diffusers/models/attention_flax.py +10 -10
- diffusers/models/attention_processor.py +113 -667
- diffusers/models/auto_model.py +49 -12
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
- diffusers/models/autoencoders/autoencoder_dc.py +17 -4
- diffusers/models/autoencoders/autoencoder_kl.py +5 -5
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
- diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1110 -0
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
- diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
- diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_wan.py +626 -62
- diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
- diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
- diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
- diffusers/models/autoencoders/vae.py +13 -2
- diffusers/models/autoencoders/vq_model.py +2 -2
- diffusers/models/cache_utils.py +32 -10
- diffusers/models/controlnet.py +1 -1
- diffusers/models/controlnet_flux.py +1 -1
- diffusers/models/controlnet_sd3.py +1 -1
- diffusers/models/controlnet_sparsectrl.py +1 -1
- diffusers/models/controlnets/__init__.py +1 -0
- diffusers/models/controlnets/controlnet.py +3 -3
- diffusers/models/controlnets/controlnet_flax.py +1 -1
- diffusers/models/controlnets/controlnet_flux.py +21 -20
- diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
- diffusers/models/controlnets/controlnet_sana.py +290 -0
- diffusers/models/controlnets/controlnet_sd3.py +1 -1
- diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
- diffusers/models/controlnets/controlnet_union.py +5 -5
- diffusers/models/controlnets/controlnet_xs.py +7 -7
- diffusers/models/controlnets/multicontrolnet.py +4 -5
- diffusers/models/controlnets/multicontrolnet_union.py +5 -6
- diffusers/models/downsampling.py +2 -2
- diffusers/models/embeddings.py +36 -46
- diffusers/models/embeddings_flax.py +2 -2
- diffusers/models/lora.py +3 -3
- diffusers/models/model_loading_utils.py +233 -1
- diffusers/models/modeling_flax_utils.py +1 -2
- diffusers/models/modeling_utils.py +203 -108
- diffusers/models/normalization.py +4 -4
- diffusers/models/resnet.py +2 -2
- diffusers/models/resnet_flax.py +1 -1
- diffusers/models/transformers/__init__.py +7 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
- diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
- diffusers/models/transformers/consisid_transformer_3d.py +1 -1
- diffusers/models/transformers/dit_transformer_2d.py +2 -2
- diffusers/models/transformers/dual_transformer_2d.py +1 -1
- diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
- diffusers/models/transformers/latte_transformer_3d.py +4 -5
- diffusers/models/transformers/lumina_nextdit2d.py +2 -2
- diffusers/models/transformers/pixart_transformer_2d.py +3 -3
- diffusers/models/transformers/prior_transformer.py +1 -1
- diffusers/models/transformers/sana_transformer.py +8 -3
- diffusers/models/transformers/stable_audio_transformer.py +5 -9
- diffusers/models/transformers/t5_film_transformer.py +3 -3
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +1 -1
- diffusers/models/transformers/transformer_chroma.py +641 -0
- diffusers/models/transformers/transformer_cogview3plus.py +5 -10
- diffusers/models/transformers/transformer_cogview4.py +353 -27
- diffusers/models/transformers/transformer_cosmos.py +586 -0
- diffusers/models/transformers/transformer_flux.py +376 -138
- diffusers/models/transformers/transformer_hidream_image.py +942 -0
- diffusers/models/transformers/transformer_hunyuan_video.py +12 -8
- diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
- diffusers/models/transformers/transformer_ltx.py +105 -24
- diffusers/models/transformers/transformer_lumina2.py +1 -1
- diffusers/models/transformers/transformer_mochi.py +1 -1
- diffusers/models/transformers/transformer_omnigen.py +2 -2
- diffusers/models/transformers/transformer_qwenimage.py +645 -0
- diffusers/models/transformers/transformer_sd3.py +7 -7
- diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/transformers/transformer_wan.py +316 -87
- diffusers/models/transformers/transformer_wan_vace.py +387 -0
- diffusers/models/unets/unet_1d.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +1 -1
- diffusers/models/unets/unet_2d_blocks.py +1 -1
- diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
- diffusers/models/unets/unet_2d_condition.py +4 -3
- diffusers/models/unets/unet_2d_condition_flax.py +2 -2
- diffusers/models/unets/unet_3d_blocks.py +1 -1
- diffusers/models/unets/unet_3d_condition.py +3 -3
- diffusers/models/unets/unet_i2vgen_xl.py +3 -3
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +2 -2
- diffusers/models/unets/unet_stable_cascade.py +1 -1
- diffusers/models/upsampling.py +2 -2
- diffusers/models/vae_flax.py +2 -2
- diffusers/models/vq_model.py +1 -1
- diffusers/modular_pipelines/__init__.py +83 -0
- diffusers/modular_pipelines/components_manager.py +1068 -0
- diffusers/modular_pipelines/flux/__init__.py +66 -0
- diffusers/modular_pipelines/flux/before_denoise.py +689 -0
- diffusers/modular_pipelines/flux/decoders.py +109 -0
- diffusers/modular_pipelines/flux/denoise.py +227 -0
- diffusers/modular_pipelines/flux/encoders.py +412 -0
- diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
- diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
- diffusers/modular_pipelines/modular_pipeline.py +2446 -0
- diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
- diffusers/modular_pipelines/node_utils.py +665 -0
- diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
- diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
- diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
- diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
- diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
- diffusers/modular_pipelines/wan/__init__.py +66 -0
- diffusers/modular_pipelines/wan/before_denoise.py +365 -0
- diffusers/modular_pipelines/wan/decoders.py +105 -0
- diffusers/modular_pipelines/wan/denoise.py +261 -0
- diffusers/modular_pipelines/wan/encoders.py +242 -0
- diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
- diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
- diffusers/pipelines/__init__.py +68 -6
- diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
- diffusers/pipelines/amused/pipeline_amused.py +7 -6
- diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
- diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
- diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +22 -13
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
- diffusers/pipelines/auto_pipeline.py +23 -20
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
- diffusers/pipelines/chroma/__init__.py +49 -0
- diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
- diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
- diffusers/pipelines/chroma/pipeline_output.py +21 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +17 -16
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +17 -16
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +18 -17
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +17 -16
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
- diffusers/pipelines/cogview4/pipeline_cogview4.py +23 -22
- diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
- diffusers/pipelines/consisid/consisid_utils.py +2 -2
- diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +11 -10
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +226 -107
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +12 -8
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +207 -105
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
- diffusers/pipelines/cosmos/__init__.py +54 -0
- diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
- diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
- diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
- diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
- diffusers/pipelines/cosmos/pipeline_output.py +40 -0
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
- diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
- diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
- diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
- diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
- diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
- diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
- diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +8 -8
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
- diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +4 -2
- diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
- diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
- diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
- diffusers/pipelines/flux/__init__.py +4 -0
- diffusers/pipelines/flux/modeling_flux.py +1 -1
- diffusers/pipelines/flux/pipeline_flux.py +37 -36
- diffusers/pipelines/flux/pipeline_flux_control.py +9 -9
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +31 -23
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +3 -2
- diffusers/pipelines/flux/pipeline_flux_fill.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_img2img.py +40 -7
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +12 -7
- diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
- diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +2 -2
- diffusers/pipelines/flux/pipeline_output.py +6 -4
- diffusers/pipelines/free_init_utils.py +2 -2
- diffusers/pipelines/free_noise_utils.py +3 -3
- diffusers/pipelines/hidream_image/__init__.py +47 -0
- diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
- diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
- diffusers/pipelines/hunyuan_video/__init__.py +2 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +26 -25
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
- diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
- diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
- diffusers/pipelines/kolors/text_encoder.py +3 -3
- diffusers/pipelines/kolors/tokenizer.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
- diffusers/pipelines/latte/pipeline_latte.py +12 -12
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
- diffusers/pipelines/ltx/__init__.py +4 -0
- diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +64 -18
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +117 -38
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +63 -18
- diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
- diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
- diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
- diffusers/pipelines/mochi/pipeline_mochi.py +15 -14
- diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
- diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
- diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
- diffusers/pipelines/onnx_utils.py +15 -2
- diffusers/pipelines/pag/pag_utils.py +2 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
- diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
- diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
- diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
- diffusers/pipelines/pia/pipeline_pia.py +8 -6
- diffusers/pipelines/pipeline_flax_utils.py +5 -6
- diffusers/pipelines/pipeline_loading_utils.py +113 -15
- diffusers/pipelines/pipeline_utils.py +127 -48
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +14 -12
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +31 -11
- diffusers/pipelines/qwenimage/__init__.py +55 -0
- diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
- diffusers/pipelines/sana/__init__.py +4 -0
- diffusers/pipelines/sana/pipeline_sana.py +23 -21
- diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
- diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
- diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
- diffusers/pipelines/shap_e/camera.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
- diffusers/pipelines/shap_e/renderer.py +3 -3
- diffusers/pipelines/skyreels_v2/__init__.py +59 -0
- diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
- diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
- diffusers/pipelines/stable_diffusion/__init__.py +0 -7
- diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +12 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +11 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +10 -10
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
- diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
- diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +13 -12
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
- diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
- diffusers/pipelines/unclip/text_proj.py +2 -2
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
- diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
- diffusers/pipelines/visualcloze/__init__.py +52 -0
- diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
- diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
- diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
- diffusers/pipelines/wan/__init__.py +2 -0
- diffusers/pipelines/wan/pipeline_wan.py +91 -30
- diffusers/pipelines/wan/pipeline_wan_i2v.py +145 -45
- diffusers/pipelines/wan/pipeline_wan_vace.py +975 -0
- diffusers/pipelines/wan/pipeline_wan_video2video.py +14 -16
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
- diffusers/quantizers/__init__.py +3 -1
- diffusers/quantizers/base.py +17 -1
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
- diffusers/quantizers/bitsandbytes/utils.py +10 -7
- diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
- diffusers/quantizers/gguf/utils.py +108 -16
- diffusers/quantizers/pipe_quant_config.py +202 -0
- diffusers/quantizers/quantization_config.py +18 -16
- diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +31 -1
- diffusers/schedulers/__init__.py +3 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
- diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
- diffusers/schedulers/scheduling_consistency_models.py +1 -1
- diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
- diffusers/schedulers/scheduling_ddim.py +8 -8
- diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
- diffusers/schedulers/scheduling_ddim_flax.py +6 -6
- diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
- diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
- diffusers/schedulers/scheduling_ddpm.py +9 -9
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
- diffusers/schedulers/scheduling_deis_multistep.py +16 -9
- diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +18 -12
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
- diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +19 -13
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
- diffusers/schedulers/scheduling_edm_euler.py +20 -11
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
- diffusers/schedulers/scheduling_euler_discrete.py +3 -3
- diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
- diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
- diffusers/schedulers/scheduling_heun_discrete.py +2 -2
- diffusers/schedulers/scheduling_ipndm.py +2 -2
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
- diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
- diffusers/schedulers/scheduling_lcm.py +3 -3
- diffusers/schedulers/scheduling_lms_discrete.py +2 -2
- diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
- diffusers/schedulers/scheduling_pndm.py +4 -4
- diffusers/schedulers/scheduling_pndm_flax.py +4 -4
- diffusers/schedulers/scheduling_repaint.py +9 -9
- diffusers/schedulers/scheduling_sasolver.py +15 -15
- diffusers/schedulers/scheduling_scm.py +1 -2
- diffusers/schedulers/scheduling_sde_ve.py +1 -1
- diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
- diffusers/schedulers/scheduling_tcd.py +3 -3
- diffusers/schedulers/scheduling_unclip.py +5 -5
- diffusers/schedulers/scheduling_unipc_multistep.py +21 -12
- diffusers/schedulers/scheduling_utils.py +3 -3
- diffusers/schedulers/scheduling_utils_flax.py +2 -2
- diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
- diffusers/training_utils.py +91 -5
- diffusers/utils/__init__.py +15 -0
- diffusers/utils/accelerate_utils.py +1 -1
- diffusers/utils/constants.py +4 -0
- diffusers/utils/doc_utils.py +1 -1
- diffusers/utils/dummy_pt_objects.py +432 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +480 -0
- diffusers/utils/dynamic_modules_utils.py +85 -8
- diffusers/utils/export_utils.py +1 -1
- diffusers/utils/hub_utils.py +33 -17
- diffusers/utils/import_utils.py +151 -18
- diffusers/utils/logging.py +1 -1
- diffusers/utils/outputs.py +2 -1
- diffusers/utils/peft_utils.py +96 -10
- diffusers/utils/state_dict_utils.py +20 -3
- diffusers/utils/testing_utils.py +195 -17
- diffusers/utils/torch_utils.py +43 -5
- diffusers/video_processor.py +2 -2
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/METADATA +72 -57
- diffusers-0.35.0.dist-info/RECORD +703 -0
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/WHEEL +1 -1
- diffusers-0.33.1.dist-info/RECORD +0 -608
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
The hunks below are from `diffusers/models/transformers/transformer_cogview4.py`:

```diff
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Any, Dict, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -21,13 +21,14 @@ import torch.nn.functional as F
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import FeedForward
 from ..attention_processor import Attention
 from ..cache_utils import CacheMixin
 from ..embeddings import CogView3CombinedTimestepSizeEmbeddings
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
-from ..normalization import
+from ..normalization import LayerNorm, RMSNorm
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -73,8 +74,9 @@ class CogView4AdaLayerNormZero(nn.Module):
     def forward(
         self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-
-
+        dtype = hidden_states.dtype
+        norm_hidden_states = self.norm(hidden_states).to(dtype=dtype)
+        norm_encoder_hidden_states = self.norm_context(encoder_hidden_states).to(dtype=dtype)
 
         emb = self.linear(temb)
         (
```
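The `dtype` bookkeeping added to `CogView4AdaLayerNormZero.forward` (and the matching `.to(dtype=dtype)` casts on the QK norms in the hunks below) guards against norm layers promoting activations to a wider dtype than the rest of the model runs in. A minimal, self-contained sketch of the failure mode and the fix; `UpcastLayerNorm` is a toy stand-in, not the diffusers module:

```python
import torch
import torch.nn as nn


class UpcastLayerNorm(nn.Module):
    """Toy norm that deliberately computes in fp32 for numerical stability."""

    def __init__(self, dim: int):
        super().__init__()
        self.norm = nn.LayerNorm(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.norm(x.float())  # output is fp32 regardless of input dtype


x = torch.randn(2, 16, 64, dtype=torch.bfloat16)
norm = UpcastLayerNorm(64)

assert norm(x).dtype == torch.float32              # dtype drifted upward
assert norm(x).to(dtype=x.dtype).dtype == x.dtype  # the cast-back restores bf16
```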
```diff
@@ -111,8 +113,11 @@ class CogView4AdaLayerNormZero(nn.Module):
 
 class CogView4AttnProcessor:
     """
-    Processor for implementing scaled dot-product attention for the
+    Processor for implementing scaled dot-product attention for the CogView4 model. It applies a rotary embedding on
     query and key vectors, but does not include spatial normalization.
+
+    The processor supports passing an attention mask for text tokens. The attention mask should have shape (batch_size,
+    text_seq_length) where 1 indicates a non-padded token and 0 indicates a padded token.
     """
 
     def __init__(self):
@@ -125,8 +130,10 @@ class CogView4AttnProcessor:
         hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
         attention_mask: Optional[torch.Tensor] = None,
-        image_rotary_emb: Optional[torch.Tensor] = None,
-    ) -> torch.Tensor:
+        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        dtype = encoder_hidden_states.dtype
+
         batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape
         batch_size, image_seq_length, embed_dim = hidden_states.shape
         hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
@@ -142,9 +149,9 @@
 
         # 2. QK normalization
         if attn.norm_q is not None:
-            query = attn.norm_q(query)
+            query = attn.norm_q(query).to(dtype=dtype)
         if attn.norm_k is not None:
-            key = attn.norm_k(key)
+            key = attn.norm_k(key).to(dtype=dtype)
 
         # 3. Rotational positional embeddings applied to latent stream
         if image_rotary_emb is not None:
@@ -159,13 +166,14 @@
 
         # 4. Attention
         if attention_mask is not None:
-
-
-
-
-
-
-
+            text_attn_mask = attention_mask
+            assert text_attn_mask.dim() == 2, "the shape of text_attn_mask should be (batch_size, text_seq_length)"
+            text_attn_mask = text_attn_mask.float().to(query.device)
+            mix_attn_mask = torch.ones((batch_size, text_seq_length + image_seq_length), device=query.device)
+            mix_attn_mask[:, :text_seq_length] = text_attn_mask
+            mix_attn_mask = mix_attn_mask.unsqueeze(2)
+            attn_mask_matrix = mix_attn_mask @ mix_attn_mask.transpose(1, 2)
+            attention_mask = (attn_mask_matrix > 0).unsqueeze(1).to(query.dtype)
 
         hidden_states = F.scaled_dot_product_attention(
             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
```
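The eight added lines turn a 1D text keep-mask into the 2D boolean matrix `F.scaled_dot_product_attention` expects: the outer product of the combined keep-mask with itself is positive exactly where both the query and key tokens are real. A standalone sketch with toy sizes (names are illustrative, not the diffusers API):

```python
import torch
import torch.nn.functional as F

batch_size, text_len, image_len, heads, head_dim = 2, 4, 6, 2, 8

# 1 = real token, 0 = padding; image tokens are never padded in this path.
text_attn_mask = torch.tensor([[1, 1, 0, 0], [1, 1, 1, 1]], dtype=torch.float32)
mix = torch.ones(batch_size, text_len + image_len)
mix[:, :text_len] = text_attn_mask

# Outer product: entry (i, j) is positive iff tokens i and j are both real.
mask_matrix = mix.unsqueeze(2) @ mix.unsqueeze(1)  # (B, S, S)
attn_mask = (mask_matrix > 0).unsqueeze(1)         # (B, 1, S, S), broadcasts over heads

q = k = v = torch.randn(batch_size, heads, text_len + image_len, head_dim)
out = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
print(out.shape)  # torch.Size([2, 2, 10, 8])
# Rows belonging to padded queries attend to nothing (all-False mask rows) and
# come out as NaN; callers must treat those positions as padding.
```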
```diff
@@ -183,9 +191,277 @@ class CogView4AttnProcessor:
         return hidden_states, encoder_hidden_states
 
 
+class CogView4TrainingAttnProcessor:
+    """
+    Training Processor for implementing scaled dot-product attention for the CogView4 model. It applies a rotary
+    embedding on query and key vectors, but does not include spatial normalization.
+
+    This processor differs from CogView4AttnProcessor in several important ways:
+    1. It supports attention masking with variable sequence lengths for multi-resolution training
+    2. It unpacks and repacks sequences for efficient training with variable sequence lengths when batch_flag is
+       provided
+    """
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("CogView4AttnProcessor requires PyTorch 2.0. To use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        latent_attn_mask: Optional[torch.Tensor] = None,
+        text_attn_mask: Optional[torch.Tensor] = None,
+        batch_flag: Optional[torch.Tensor] = None,
+        image_rotary_emb: Optional[
+            Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
+        ] = None,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        Args:
+            attn (`Attention`):
+                The attention module.
+            hidden_states (`torch.Tensor`):
+                The input hidden states.
+            encoder_hidden_states (`torch.Tensor`):
+                The encoder hidden states for cross-attention.
+            latent_attn_mask (`torch.Tensor`, *optional*):
+                Mask for latent tokens where 0 indicates pad token and 1 indicates non-pad token. If None, full
+                attention is used for all latent tokens. Note: the shape of latent_attn_mask is (batch_size,
+                num_latent_tokens).
+            text_attn_mask (`torch.Tensor`, *optional*):
+                Mask for text tokens where 0 indicates pad token and 1 indicates non-pad token. If None, full attention
+                is used for all text tokens.
+            batch_flag (`torch.Tensor`, *optional*):
+                Values from 0 to n-1 indicating which samples belong to the same batch. Samples with the same
+                batch_flag are packed together. Example: [0, 1, 1, 2, 2] means sample 0 forms batch0, samples 1-2 form
+                batch1, and samples 3-4 form batch2. If None, no packing is used.
+            image_rotary_emb (`Tuple[torch.Tensor, torch.Tensor]` or `list[Tuple[torch.Tensor, torch.Tensor]]`, *optional*):
+                The rotary embedding for the image part of the input.
+        Returns:
+            `Tuple[torch.Tensor, torch.Tensor]`: The processed hidden states for both image and text streams.
+        """
+
+        # Get dimensions and device info
+        batch_size, text_seq_length, embed_dim = encoder_hidden_states.shape
+        batch_size, image_seq_length, embed_dim = hidden_states.shape
+        dtype = encoder_hidden_states.dtype
+        device = encoder_hidden_states.device
+        latent_hidden_states = hidden_states
+        # Combine text and image streams for joint processing
+        mixed_hidden_states = torch.cat([encoder_hidden_states, latent_hidden_states], dim=1)
+
+        # 1. Construct attention mask and maybe packing input
+        # Create default masks if not provided
+        if text_attn_mask is None:
+            text_attn_mask = torch.ones((batch_size, text_seq_length), dtype=torch.int32, device=device)
+        if latent_attn_mask is None:
+            latent_attn_mask = torch.ones((batch_size, image_seq_length), dtype=torch.int32, device=device)
+
+        # Validate mask shapes and types
+        assert text_attn_mask.dim() == 2, "the shape of text_attn_mask should be (batch_size, text_seq_length)"
+        assert text_attn_mask.dtype == torch.int32, "the dtype of text_attn_mask should be torch.int32"
+        assert latent_attn_mask.dim() == 2, "the shape of latent_attn_mask should be (batch_size, num_latent_tokens)"
+        assert latent_attn_mask.dtype == torch.int32, "the dtype of latent_attn_mask should be torch.int32"
+
+        # Create combined mask for text and image tokens
+        mixed_attn_mask = torch.ones(
+            (batch_size, text_seq_length + image_seq_length), dtype=torch.int32, device=device
+        )
+        mixed_attn_mask[:, :text_seq_length] = text_attn_mask
+        mixed_attn_mask[:, text_seq_length:] = latent_attn_mask
+
+        # Convert mask to attention matrix format (where 1 means attend, 0 means don't attend)
+        mixed_attn_mask_input = mixed_attn_mask.unsqueeze(2).to(dtype=dtype)
+        attn_mask_matrix = mixed_attn_mask_input @ mixed_attn_mask_input.transpose(1, 2)
+
+        # Handle batch packing if enabled
+        if batch_flag is not None:
+            assert batch_flag.dim() == 1
+            # Determine packed batch size based on batch_flag
+            packing_batch_size = torch.max(batch_flag).item() + 1
+
+            # Calculate actual sequence lengths for each sample based on masks
+            text_seq_length = torch.sum(text_attn_mask, dim=1)
+            latent_seq_length = torch.sum(latent_attn_mask, dim=1)
+            mixed_seq_length = text_seq_length + latent_seq_length
+
+            # Calculate packed sequence lengths for each packed batch
+            mixed_seq_length_packed = [
+                torch.sum(mixed_attn_mask[batch_flag == batch_idx]).item() for batch_idx in range(packing_batch_size)
+            ]
+
+            assert len(mixed_seq_length_packed) == packing_batch_size
+
+            # Pack sequences by removing padding tokens
+            mixed_attn_mask_flatten = mixed_attn_mask.flatten(0, 1)
+            mixed_hidden_states_flatten = mixed_hidden_states.flatten(0, 1)
+            mixed_hidden_states_unpad = mixed_hidden_states_flatten[mixed_attn_mask_flatten == 1]
+            assert torch.sum(mixed_seq_length) == mixed_hidden_states_unpad.shape[0]
+
+            # Split the unpadded sequence into packed batches
+            mixed_hidden_states_packed = torch.split(mixed_hidden_states_unpad, mixed_seq_length_packed)
+
+            # Re-pad to create packed batches with right-side padding
+            mixed_hidden_states_packed_padded = torch.nn.utils.rnn.pad_sequence(
+                mixed_hidden_states_packed,
+                batch_first=True,
+                padding_value=0.0,
+                padding_side="right",
+            )
+
+            # Create attention mask for packed batches
+            l = mixed_hidden_states_packed_padded.shape[1]
+            attn_mask_matrix = torch.zeros(
+                (packing_batch_size, l, l),
+                dtype=dtype,
+                device=device,
+            )
+
+            # Fill attention mask with block diagonal matrices
+            # This ensures that tokens can only attend to other tokens within the same original sample
+            for idx, mask in enumerate(attn_mask_matrix):
+                seq_lengths = mixed_seq_length[batch_flag == idx]
+                offset = 0
+                for length in seq_lengths:
+                    # Create a block of 1s for each sample in the packed batch
+                    mask[offset : offset + length, offset : offset + length] = 1
+                    offset += length
+
+        attn_mask_matrix = attn_mask_matrix.to(dtype=torch.bool)
+        attn_mask_matrix = attn_mask_matrix.unsqueeze(1)  # Add attention head dim
+        attention_mask = attn_mask_matrix
+
+        # Prepare hidden states for attention computation
+        if batch_flag is None:
+            # If no packing, just combine text and image tokens
+            hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+        else:
+            # If packing, use the packed sequence
+            hidden_states = mixed_hidden_states_packed_padded
+
+        # 2. QKV projections - convert hidden states to query, key, value
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        # Reshape for multi-head attention: [batch, seq_len, heads*dim] -> [batch, heads, seq_len, dim]
+        query = query.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        key = key.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+        value = value.unflatten(2, (attn.heads, -1)).transpose(1, 2)
+
+        # 3. QK normalization - apply layer norm to queries and keys if configured
+        if attn.norm_q is not None:
+            query = attn.norm_q(query).to(dtype=dtype)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key).to(dtype=dtype)
+
+        # 4. Apply rotary positional embeddings to image tokens only
+        if image_rotary_emb is not None:
+            from ..embeddings import apply_rotary_emb
+
+            if batch_flag is None:
+                # Apply RoPE only to image tokens (after text tokens)
+                query[:, :, text_seq_length:, :] = apply_rotary_emb(
+                    query[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
+                )
+                key[:, :, text_seq_length:, :] = apply_rotary_emb(
+                    key[:, :, text_seq_length:, :], image_rotary_emb, use_real_unbind_dim=-2
+                )
+            else:
+                # For packed batches, need to carefully apply RoPE to appropriate tokens
+                assert query.shape[0] == packing_batch_size
+                assert key.shape[0] == packing_batch_size
+                assert len(image_rotary_emb) == batch_size
+
+                rope_idx = 0
+                for idx in range(packing_batch_size):
+                    offset = 0
+                    # Get text and image sequence lengths for samples in this packed batch
+                    text_seq_length_bi = text_seq_length[batch_flag == idx]
+                    latent_seq_length_bi = latent_seq_length[batch_flag == idx]
+
+                    # Apply RoPE to each image segment in the packed sequence
+                    for tlen, llen in zip(text_seq_length_bi, latent_seq_length_bi):
+                        mlen = tlen + llen
+                        # Apply RoPE only to image tokens (after text tokens)
+                        query[idx, :, offset + tlen : offset + mlen, :] = apply_rotary_emb(
+                            query[idx, :, offset + tlen : offset + mlen, :],
+                            image_rotary_emb[rope_idx],
+                            use_real_unbind_dim=-2,
+                        )
+                        key[idx, :, offset + tlen : offset + mlen, :] = apply_rotary_emb(
+                            key[idx, :, offset + tlen : offset + mlen, :],
+                            image_rotary_emb[rope_idx],
+                            use_real_unbind_dim=-2,
+                        )
+                        offset += mlen
+                        rope_idx += 1
+
+        hidden_states = F.scaled_dot_product_attention(
+            query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+        )
+
+        # Reshape back: [batch, heads, seq_len, dim] -> [batch, seq_len, heads*dim]
+        hidden_states = hidden_states.transpose(1, 2).flatten(2, 3)
+        hidden_states = hidden_states.type_as(query)
+
+        # 5. Output projection - project attention output to model dimension
+        hidden_states = attn.to_out[0](hidden_states)
+        hidden_states = attn.to_out[1](hidden_states)
+
+        # Split the output back into text and image streams
+        if batch_flag is None:
+            # Simple split for non-packed case
+            encoder_hidden_states, hidden_states = hidden_states.split(
+                [text_seq_length, hidden_states.size(1) - text_seq_length], dim=1
+            )
+        else:
+            # For packed case: need to unpack, split text/image, then restore to original shapes
+            # First, unpad the sequence based on the packed sequence lengths
+            hidden_states_unpad = torch.nn.utils.rnn.unpad_sequence(
+                hidden_states,
+                lengths=torch.tensor(mixed_seq_length_packed),
+                batch_first=True,
+            )
+            # Concatenate all unpadded sequences
+            hidden_states_flatten = torch.cat(hidden_states_unpad, dim=0)
+            # Split by original sample sequence lengths
+            hidden_states_unpack = torch.split(hidden_states_flatten, mixed_seq_length.tolist())
+            assert len(hidden_states_unpack) == batch_size
+
+            # Further split each sample's sequence into text and image parts
+            hidden_states_unpack = [
+                torch.split(h, [tlen, llen])
+                for h, tlen, llen in zip(hidden_states_unpack, text_seq_length, latent_seq_length)
+            ]
+            # Separate text and image sequences
+            encoder_hidden_states_unpad = [h[0] for h in hidden_states_unpack]
+            hidden_states_unpad = [h[1] for h in hidden_states_unpack]
+
+            # Update the original tensors with the processed values, respecting the attention masks
+            for idx in range(batch_size):
+                # Place unpacked text tokens back in the encoder_hidden_states tensor
+                encoder_hidden_states[idx][text_attn_mask[idx] == 1] = encoder_hidden_states_unpad[idx]
+                # Place unpacked image tokens back in the latent_hidden_states tensor
+                latent_hidden_states[idx][latent_attn_mask[idx] == 1] = hidden_states_unpad[idx]
+
+            # Update the output hidden states
+            hidden_states = latent_hidden_states
+
+        return hidden_states, encoder_hidden_states
+
+
+@maybe_allow_in_graph
 class CogView4TransformerBlock(nn.Module):
     def __init__(
-        self,
+        self,
+        dim: int = 2560,
+        num_attention_heads: int = 64,
+        attention_head_dim: int = 40,
+        time_embed_dim: int = 512,
     ) -> None:
         super().__init__()
 
```
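The core idea of the new training processor, isolated from the attention plumbing: samples that share a `batch_flag` value are concatenated into a single row, and a block-diagonal mask keeps attention from crossing sample boundaries. A toy sketch of just that bookkeeping (shapes and names are illustrative, not the diffusers API):

```python
import torch

# Three samples with true lengths 3, 2, 4; batch_flag [0, 1, 1] packs
# sample 0 alone and samples 1-2 together into one row.
lengths = torch.tensor([3, 2, 4])
batch_flag = torch.tensor([0, 1, 1])
packed_batches = int(batch_flag.max()) + 1

# Packed row lengths: [3] and [2 + 4] -> [3, 6]
packed_lengths = [int(lengths[batch_flag == b].sum()) for b in range(packed_batches)]
L = max(packed_lengths)

# Block-diagonal mask: tokens attend only within their original sample.
mask = torch.zeros(packed_batches, L, L, dtype=torch.bool)
for b in range(packed_batches):
    offset = 0
    for n in lengths[batch_flag == b].tolist():
        mask[b, offset : offset + n, offset : offset + n] = True
        offset += n

print(packed_lengths)  # [3, 6]
print(mask[1].int())   # a 2x2 block and a 4x4 block of ones on the diagonal
```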
```diff
@@ -213,9 +489,11 @@ class CogView4TransformerBlock(nn.Module):
         hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
         temb: Optional[torch.Tensor] = None,
-        image_rotary_emb: Optional[
-
-
+        image_rotary_emb: Optional[
+            Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
+        ] = None,
+        attention_mask: Optional[Dict[str, torch.Tensor]] = None,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
     ) -> torch.Tensor:
         # 1. Timestep conditioning
         (
@@ -232,12 +510,14 @@
         ) = self.norm1(hidden_states, encoder_hidden_states, temb)
 
         # 2. Attention
+        if attention_kwargs is None:
+            attention_kwargs = {}
         attn_hidden_states, attn_encoder_hidden_states = self.attn1(
             hidden_states=norm_hidden_states,
             encoder_hidden_states=norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
             attention_mask=attention_mask,
-            **
+            **attention_kwargs,
         )
         hidden_states = hidden_states + attn_hidden_states * gate_msa.unsqueeze(1)
         encoder_hidden_states = encoder_hidden_states + attn_encoder_hidden_states * c_gate_msa.unsqueeze(1)
@@ -304,6 +584,38 @@ class CogView4RotaryPosEmbed(nn.Module):
         return (freqs.cos(), freqs.sin())
 
 
+class CogView4AdaLayerNormContinuous(nn.Module):
+    """
+    CogView4-only final AdaLN: LN(x) -> Linear(cond) -> chunk -> affine. Matches Megatron: **no activation** before the
+    Linear on conditioning embedding.
+    """
+
+    def __init__(
+        self,
+        embedding_dim: int,
+        conditioning_embedding_dim: int,
+        elementwise_affine: bool = True,
+        eps: float = 1e-5,
+        bias: bool = True,
+        norm_type: str = "layer_norm",
+    ):
+        super().__init__()
+        self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
+        if norm_type == "layer_norm":
+            self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
+        elif norm_type == "rms_norm":
+            self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
+        else:
+            raise ValueError(f"unknown norm_type {norm_type}")
+
+    def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
+        # *** NO SiLU here ***
+        emb = self.linear(conditioning_embedding.to(x.dtype))
+        scale, shift = torch.chunk(emb, 2, dim=1)
+        x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
+        return x
+
+
 class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
     r"""
     Args:
```
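The forward math of the new `CogView4AdaLayerNormContinuous`, rewritten standalone with toy dimensions. The notable point, which the docstring calls out, is that scale and shift come straight from a `Linear` on the conditioning embedding with no SiLU in front, unlike the generic `AdaLayerNormContinuous`:

```python
import torch
import torch.nn as nn

dim, cond_dim = 8, 4
norm = nn.LayerNorm(dim, elementwise_affine=False)
linear = nn.Linear(cond_dim, dim * 2)

x = torch.randn(2, 5, dim)       # (batch, tokens, dim)
cond = torch.randn(2, cond_dim)  # (batch, cond_dim)

# No activation on `cond` before the projection.
scale, shift = linear(cond).chunk(2, dim=1)  # each (batch, dim)
out = norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
print(out.shape)  # torch.Size([2, 5, 8])
```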
```diff
@@ -386,7 +698,7 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
         )
 
         # 4. Output projection
-        self.norm_out =
+        self.norm_out = CogView4AdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False)
         self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels, bias=True)
 
         self.gradient_checkpointing = False
@@ -402,7 +714,9 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
         attention_mask: Optional[torch.Tensor] = None,
-
+        image_rotary_emb: Optional[
+            Union[Tuple[torch.Tensor, torch.Tensor], List[Tuple[torch.Tensor, torch.Tensor]]]
+        ] = None,
     ) -> Union[torch.Tensor, Transformer2DModelOutput]:
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
@@ -422,7 +736,8 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
         batch_size, num_channels, height, width = hidden_states.shape
 
         # 1. RoPE
-        image_rotary_emb
+        if image_rotary_emb is None:
+            image_rotary_emb = self.rope(hidden_states)
 
         # 2. Patch & Timestep embeddings
         p = self.config.patch_size
```
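Because `forward` now computes `self.rope(hidden_states)` only as a fallback, a caller can precompute the rotary embedding once per latent resolution and reuse it across steps. A sketch of that idea, assuming the `rope` submodule stays reachable on the model as the fallback suggests; the caching helper itself is hypothetical:

```python
import torch

_rope_cache = {}

def cached_rope(transformer, latents: torch.Tensor):
    """Hypothetical helper: reuse the (cos, sin) pair across steps at one resolution."""
    key = tuple(latents.shape[-2:])  # latent (height, width)
    if key not in _rope_cache:
        _rope_cache[key] = transformer.rope(latents)
    return _rope_cache[key]

# out = transformer(latents, ..., image_rotary_emb=cached_rope(transformer, latents))
```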
```diff
@@ -438,11 +753,22 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
         for block in self.transformer_blocks:
             if torch.is_grad_enabled() and self.gradient_checkpointing:
                 hidden_states, encoder_hidden_states = self._gradient_checkpointing_func(
-                    block,
+                    block,
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    image_rotary_emb,
+                    attention_mask,
+                    attention_kwargs,
                 )
             else:
                 hidden_states, encoder_hidden_states = block(
-                    hidden_states,
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    image_rotary_emb,
+                    attention_mask,
+                    attention_kwargs,
                 )
 
         # 4. Output norm & projection
```
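The expanded call sites pass every block input positionally through the checkpointing wrapper rather than closing over them, so the wrapper can stash the inputs and replay the block during backward. A minimal sketch with `torch.utils.checkpoint` (the diffusers-internal `_gradient_checkpointing_func` is assumed to wrap something equivalent):

```python
import torch
from torch.utils.checkpoint import checkpoint

layer = torch.nn.Linear(16, 16)
x = torch.randn(4, 16, requires_grad=True)

# Inputs go through the checkpoint call itself, so they can be saved cheaply
# and re-fed to the module when the forward is recomputed during backward.
y = checkpoint(layer, x, use_reentrant=False)
y.sum().backward()
print(x.grad.shape)  # torch.Size([4, 16])
```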
|