diffusers 0.33.1__py3-none-any.whl → 0.35.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +145 -1
- diffusers/callbacks.py +35 -0
- diffusers/commands/__init__.py +1 -1
- diffusers/commands/custom_blocks.py +134 -0
- diffusers/commands/diffusers_cli.py +3 -1
- diffusers/commands/env.py +1 -1
- diffusers/commands/fp16_safetensors.py +2 -2
- diffusers/configuration_utils.py +11 -2
- diffusers/dependency_versions_check.py +1 -1
- diffusers/dependency_versions_table.py +3 -3
- diffusers/experimental/rl/value_guided_sampling.py +1 -1
- diffusers/guiders/__init__.py +41 -0
- diffusers/guiders/adaptive_projected_guidance.py +188 -0
- diffusers/guiders/auto_guidance.py +190 -0
- diffusers/guiders/classifier_free_guidance.py +141 -0
- diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
- diffusers/guiders/frequency_decoupled_guidance.py +327 -0
- diffusers/guiders/guider_utils.py +309 -0
- diffusers/guiders/perturbed_attention_guidance.py +271 -0
- diffusers/guiders/skip_layer_guidance.py +262 -0
- diffusers/guiders/smoothed_energy_guidance.py +251 -0
- diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
- diffusers/hooks/__init__.py +17 -0
- diffusers/hooks/_common.py +56 -0
- diffusers/hooks/_helpers.py +293 -0
- diffusers/hooks/faster_cache.py +9 -8
- diffusers/hooks/first_block_cache.py +259 -0
- diffusers/hooks/group_offloading.py +332 -227
- diffusers/hooks/hooks.py +58 -3
- diffusers/hooks/layer_skip.py +263 -0
- diffusers/hooks/layerwise_casting.py +5 -10
- diffusers/hooks/pyramid_attention_broadcast.py +15 -12
- diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
- diffusers/hooks/utils.py +43 -0
- diffusers/image_processor.py +7 -2
- diffusers/loaders/__init__.py +10 -0
- diffusers/loaders/ip_adapter.py +260 -18
- diffusers/loaders/lora_base.py +261 -127
- diffusers/loaders/lora_conversion_utils.py +657 -35
- diffusers/loaders/lora_pipeline.py +2778 -1246
- diffusers/loaders/peft.py +78 -112
- diffusers/loaders/single_file.py +2 -2
- diffusers/loaders/single_file_model.py +64 -15
- diffusers/loaders/single_file_utils.py +395 -7
- diffusers/loaders/textual_inversion.py +3 -2
- diffusers/loaders/transformer_flux.py +10 -11
- diffusers/loaders/transformer_sd3.py +8 -3
- diffusers/loaders/unet.py +24 -21
- diffusers/loaders/unet_loader_utils.py +6 -3
- diffusers/loaders/utils.py +1 -1
- diffusers/models/__init__.py +23 -1
- diffusers/models/activations.py +5 -5
- diffusers/models/adapter.py +2 -3
- diffusers/models/attention.py +488 -7
- diffusers/models/attention_dispatch.py +1218 -0
- diffusers/models/attention_flax.py +10 -10
- diffusers/models/attention_processor.py +113 -667
- diffusers/models/auto_model.py +49 -12
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
- diffusers/models/autoencoders/autoencoder_dc.py +17 -4
- diffusers/models/autoencoders/autoencoder_kl.py +5 -5
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
- diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1110 -0
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
- diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
- diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_wan.py +626 -62
- diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
- diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
- diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
- diffusers/models/autoencoders/vae.py +13 -2
- diffusers/models/autoencoders/vq_model.py +2 -2
- diffusers/models/cache_utils.py +32 -10
- diffusers/models/controlnet.py +1 -1
- diffusers/models/controlnet_flux.py +1 -1
- diffusers/models/controlnet_sd3.py +1 -1
- diffusers/models/controlnet_sparsectrl.py +1 -1
- diffusers/models/controlnets/__init__.py +1 -0
- diffusers/models/controlnets/controlnet.py +3 -3
- diffusers/models/controlnets/controlnet_flax.py +1 -1
- diffusers/models/controlnets/controlnet_flux.py +21 -20
- diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
- diffusers/models/controlnets/controlnet_sana.py +290 -0
- diffusers/models/controlnets/controlnet_sd3.py +1 -1
- diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
- diffusers/models/controlnets/controlnet_union.py +5 -5
- diffusers/models/controlnets/controlnet_xs.py +7 -7
- diffusers/models/controlnets/multicontrolnet.py +4 -5
- diffusers/models/controlnets/multicontrolnet_union.py +5 -6
- diffusers/models/downsampling.py +2 -2
- diffusers/models/embeddings.py +36 -46
- diffusers/models/embeddings_flax.py +2 -2
- diffusers/models/lora.py +3 -3
- diffusers/models/model_loading_utils.py +233 -1
- diffusers/models/modeling_flax_utils.py +1 -2
- diffusers/models/modeling_utils.py +203 -108
- diffusers/models/normalization.py +4 -4
- diffusers/models/resnet.py +2 -2
- diffusers/models/resnet_flax.py +1 -1
- diffusers/models/transformers/__init__.py +7 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
- diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
- diffusers/models/transformers/consisid_transformer_3d.py +1 -1
- diffusers/models/transformers/dit_transformer_2d.py +2 -2
- diffusers/models/transformers/dual_transformer_2d.py +1 -1
- diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
- diffusers/models/transformers/latte_transformer_3d.py +4 -5
- diffusers/models/transformers/lumina_nextdit2d.py +2 -2
- diffusers/models/transformers/pixart_transformer_2d.py +3 -3
- diffusers/models/transformers/prior_transformer.py +1 -1
- diffusers/models/transformers/sana_transformer.py +8 -3
- diffusers/models/transformers/stable_audio_transformer.py +5 -9
- diffusers/models/transformers/t5_film_transformer.py +3 -3
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +1 -1
- diffusers/models/transformers/transformer_chroma.py +641 -0
- diffusers/models/transformers/transformer_cogview3plus.py +5 -10
- diffusers/models/transformers/transformer_cogview4.py +353 -27
- diffusers/models/transformers/transformer_cosmos.py +586 -0
- diffusers/models/transformers/transformer_flux.py +376 -138
- diffusers/models/transformers/transformer_hidream_image.py +942 -0
- diffusers/models/transformers/transformer_hunyuan_video.py +12 -8
- diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
- diffusers/models/transformers/transformer_ltx.py +105 -24
- diffusers/models/transformers/transformer_lumina2.py +1 -1
- diffusers/models/transformers/transformer_mochi.py +1 -1
- diffusers/models/transformers/transformer_omnigen.py +2 -2
- diffusers/models/transformers/transformer_qwenimage.py +645 -0
- diffusers/models/transformers/transformer_sd3.py +7 -7
- diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/transformers/transformer_wan.py +316 -87
- diffusers/models/transformers/transformer_wan_vace.py +387 -0
- diffusers/models/unets/unet_1d.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +1 -1
- diffusers/models/unets/unet_2d_blocks.py +1 -1
- diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
- diffusers/models/unets/unet_2d_condition.py +4 -3
- diffusers/models/unets/unet_2d_condition_flax.py +2 -2
- diffusers/models/unets/unet_3d_blocks.py +1 -1
- diffusers/models/unets/unet_3d_condition.py +3 -3
- diffusers/models/unets/unet_i2vgen_xl.py +3 -3
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +2 -2
- diffusers/models/unets/unet_stable_cascade.py +1 -1
- diffusers/models/upsampling.py +2 -2
- diffusers/models/vae_flax.py +2 -2
- diffusers/models/vq_model.py +1 -1
- diffusers/modular_pipelines/__init__.py +83 -0
- diffusers/modular_pipelines/components_manager.py +1068 -0
- diffusers/modular_pipelines/flux/__init__.py +66 -0
- diffusers/modular_pipelines/flux/before_denoise.py +689 -0
- diffusers/modular_pipelines/flux/decoders.py +109 -0
- diffusers/modular_pipelines/flux/denoise.py +227 -0
- diffusers/modular_pipelines/flux/encoders.py +412 -0
- diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
- diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
- diffusers/modular_pipelines/modular_pipeline.py +2446 -0
- diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
- diffusers/modular_pipelines/node_utils.py +665 -0
- diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
- diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
- diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
- diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
- diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
- diffusers/modular_pipelines/wan/__init__.py +66 -0
- diffusers/modular_pipelines/wan/before_denoise.py +365 -0
- diffusers/modular_pipelines/wan/decoders.py +105 -0
- diffusers/modular_pipelines/wan/denoise.py +261 -0
- diffusers/modular_pipelines/wan/encoders.py +242 -0
- diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
- diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
- diffusers/pipelines/__init__.py +68 -6
- diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
- diffusers/pipelines/amused/pipeline_amused.py +7 -6
- diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
- diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
- diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +22 -13
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
- diffusers/pipelines/auto_pipeline.py +23 -20
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
- diffusers/pipelines/chroma/__init__.py +49 -0
- diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
- diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
- diffusers/pipelines/chroma/pipeline_output.py +21 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +17 -16
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +17 -16
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +18 -17
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +17 -16
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
- diffusers/pipelines/cogview4/pipeline_cogview4.py +23 -22
- diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
- diffusers/pipelines/consisid/consisid_utils.py +2 -2
- diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +11 -10
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +226 -107
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +12 -8
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +207 -105
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
- diffusers/pipelines/cosmos/__init__.py +54 -0
- diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
- diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
- diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
- diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
- diffusers/pipelines/cosmos/pipeline_output.py +40 -0
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
- diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
- diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
- diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
- diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
- diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
- diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
- diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +8 -8
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
- diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +4 -2
- diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
- diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
- diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
- diffusers/pipelines/flux/__init__.py +4 -0
- diffusers/pipelines/flux/modeling_flux.py +1 -1
- diffusers/pipelines/flux/pipeline_flux.py +37 -36
- diffusers/pipelines/flux/pipeline_flux_control.py +9 -9
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +31 -23
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +3 -2
- diffusers/pipelines/flux/pipeline_flux_fill.py +7 -7
- diffusers/pipelines/flux/pipeline_flux_img2img.py +40 -7
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +12 -7
- diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
- diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +2 -2
- diffusers/pipelines/flux/pipeline_output.py +6 -4
- diffusers/pipelines/free_init_utils.py +2 -2
- diffusers/pipelines/free_noise_utils.py +3 -3
- diffusers/pipelines/hidream_image/__init__.py +47 -0
- diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
- diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
- diffusers/pipelines/hunyuan_video/__init__.py +2 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +26 -25
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
- diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
- diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
- diffusers/pipelines/kolors/text_encoder.py +3 -3
- diffusers/pipelines/kolors/tokenizer.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
- diffusers/pipelines/latte/pipeline_latte.py +12 -12
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
- diffusers/pipelines/ltx/__init__.py +4 -0
- diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +64 -18
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +117 -38
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +63 -18
- diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
- diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
- diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
- diffusers/pipelines/mochi/pipeline_mochi.py +15 -14
- diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
- diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
- diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
- diffusers/pipelines/onnx_utils.py +15 -2
- diffusers/pipelines/pag/pag_utils.py +2 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
- diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
- diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
- diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
- diffusers/pipelines/pia/pipeline_pia.py +8 -6
- diffusers/pipelines/pipeline_flax_utils.py +5 -6
- diffusers/pipelines/pipeline_loading_utils.py +113 -15
- diffusers/pipelines/pipeline_utils.py +127 -48
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +14 -12
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +31 -11
- diffusers/pipelines/qwenimage/__init__.py +55 -0
- diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
- diffusers/pipelines/sana/__init__.py +4 -0
- diffusers/pipelines/sana/pipeline_sana.py +23 -21
- diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
- diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
- diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
- diffusers/pipelines/shap_e/camera.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
- diffusers/pipelines/shap_e/renderer.py +3 -3
- diffusers/pipelines/skyreels_v2/__init__.py +59 -0
- diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
- diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
- diffusers/pipelines/stable_diffusion/__init__.py +0 -7
- diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +12 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +11 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +10 -10
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
- diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
- diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +13 -12
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
- diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
- diffusers/pipelines/unclip/text_proj.py +2 -2
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
- diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
- diffusers/pipelines/visualcloze/__init__.py +52 -0
- diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
- diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
- diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
- diffusers/pipelines/wan/__init__.py +2 -0
- diffusers/pipelines/wan/pipeline_wan.py +91 -30
- diffusers/pipelines/wan/pipeline_wan_i2v.py +145 -45
- diffusers/pipelines/wan/pipeline_wan_vace.py +975 -0
- diffusers/pipelines/wan/pipeline_wan_video2video.py +14 -16
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
- diffusers/quantizers/__init__.py +3 -1
- diffusers/quantizers/base.py +17 -1
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
- diffusers/quantizers/bitsandbytes/utils.py +10 -7
- diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
- diffusers/quantizers/gguf/utils.py +108 -16
- diffusers/quantizers/pipe_quant_config.py +202 -0
- diffusers/quantizers/quantization_config.py +18 -16
- diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +31 -1
- diffusers/schedulers/__init__.py +3 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
- diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
- diffusers/schedulers/scheduling_consistency_models.py +1 -1
- diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
- diffusers/schedulers/scheduling_ddim.py +8 -8
- diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
- diffusers/schedulers/scheduling_ddim_flax.py +6 -6
- diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
- diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
- diffusers/schedulers/scheduling_ddpm.py +9 -9
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
- diffusers/schedulers/scheduling_deis_multistep.py +16 -9
- diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +18 -12
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
- diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +19 -13
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
- diffusers/schedulers/scheduling_edm_euler.py +20 -11
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
- diffusers/schedulers/scheduling_euler_discrete.py +3 -3
- diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
- diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
- diffusers/schedulers/scheduling_heun_discrete.py +2 -2
- diffusers/schedulers/scheduling_ipndm.py +2 -2
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
- diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
- diffusers/schedulers/scheduling_lcm.py +3 -3
- diffusers/schedulers/scheduling_lms_discrete.py +2 -2
- diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
- diffusers/schedulers/scheduling_pndm.py +4 -4
- diffusers/schedulers/scheduling_pndm_flax.py +4 -4
- diffusers/schedulers/scheduling_repaint.py +9 -9
- diffusers/schedulers/scheduling_sasolver.py +15 -15
- diffusers/schedulers/scheduling_scm.py +1 -2
- diffusers/schedulers/scheduling_sde_ve.py +1 -1
- diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
- diffusers/schedulers/scheduling_tcd.py +3 -3
- diffusers/schedulers/scheduling_unclip.py +5 -5
- diffusers/schedulers/scheduling_unipc_multistep.py +21 -12
- diffusers/schedulers/scheduling_utils.py +3 -3
- diffusers/schedulers/scheduling_utils_flax.py +2 -2
- diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
- diffusers/training_utils.py +91 -5
- diffusers/utils/__init__.py +15 -0
- diffusers/utils/accelerate_utils.py +1 -1
- diffusers/utils/constants.py +4 -0
- diffusers/utils/doc_utils.py +1 -1
- diffusers/utils/dummy_pt_objects.py +432 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +480 -0
- diffusers/utils/dynamic_modules_utils.py +85 -8
- diffusers/utils/export_utils.py +1 -1
- diffusers/utils/hub_utils.py +33 -17
- diffusers/utils/import_utils.py +151 -18
- diffusers/utils/logging.py +1 -1
- diffusers/utils/outputs.py +2 -1
- diffusers/utils/peft_utils.py +96 -10
- diffusers/utils/state_dict_utils.py +20 -3
- diffusers/utils/testing_utils.py +195 -17
- diffusers/utils/torch_utils.py +43 -5
- diffusers/video_processor.py +2 -2
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/METADATA +72 -57
- diffusers-0.35.0.dist-info/RECORD +703 -0
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/WHEEL +1 -1
- diffusers-0.33.1.dist-info/RECORD +0 -608
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.33.1.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,4 @@
|
|
1
|
-
# Copyright
|
1
|
+
# Copyright 2025 OmniGen team and The HuggingFace Team. All rights reserved.
|
2
2
|
#
|
3
3
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
4
|
# you may not use this file except in compliance with the License.
|
@@ -283,7 +283,7 @@ class OmniGenBlock(nn.Module):
|
|
283
283
|
|
284
284
|
class OmniGenTransformer2DModel(ModelMixin, ConfigMixin):
|
285
285
|
"""
|
286
|
-
The Transformer model introduced in OmniGen (https://
|
286
|
+
The Transformer model introduced in OmniGen (https://huggingface.co/papers/2409.11340).
|
287
287
|
|
288
288
|
Parameters:
|
289
289
|
in_channels (`int`, defaults to `4`):
|
@@ -0,0 +1,645 @@
|
|
1
|
+
# Copyright 2025 Qwen-Image Team, The HuggingFace Team. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
import functools
|
16
|
+
import math
|
17
|
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
18
|
+
|
19
|
+
import torch
|
20
|
+
import torch.nn as nn
|
21
|
+
import torch.nn.functional as F
|
22
|
+
|
23
|
+
from ...configuration_utils import ConfigMixin, register_to_config
|
24
|
+
from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
|
25
|
+
from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
|
26
|
+
from ...utils.torch_utils import maybe_allow_in_graph
|
27
|
+
from ..attention import FeedForward
|
28
|
+
from ..attention_dispatch import dispatch_attention_fn
|
29
|
+
from ..attention_processor import Attention
|
30
|
+
from ..cache_utils import CacheMixin
|
31
|
+
from ..embeddings import TimestepEmbedding, Timesteps
|
32
|
+
from ..modeling_outputs import Transformer2DModelOutput
|
33
|
+
from ..modeling_utils import ModelMixin
|
34
|
+
from ..normalization import AdaLayerNormContinuous, RMSNorm
|
35
|
+
|
36
|
+
|
37
|
+
# Module-level logger; used below for warnings emitted from the model's forward pass.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
|
38
|
+
|
39
|
+
|
40
|
+
def get_timestep_embedding(
    timesteps: torch.Tensor,
    embedding_dim: int,
    flip_sin_to_cos: bool = False,
    downscale_freq_shift: float = 1,
    scale: float = 1,
    max_period: int = 10000,
) -> torch.Tensor:
    """
    Create sinusoidal timestep embeddings, matching the implementation in Denoising Diffusion Probabilistic Models.

    Args:
        timesteps (`torch.Tensor`):
            A 1-D tensor of N indices, one per batch element. Values may be fractional; integer tensors are also
            accepted and produce the same result as their float equivalents.
        embedding_dim (`int`):
            The dimension of the output. If odd, the last column is zero-padded.
        flip_sin_to_cos (`bool`):
            Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False).
        downscale_freq_shift (`float`):
            Controls the delta between frequencies between dimensions.
        scale (`float`):
            Scaling factor applied to the embeddings.
        max_period (`int`):
            Controls the maximum frequency of the embeddings.

    Returns:
        `torch.Tensor`: an `[N x embedding_dim]` tensor of positional embeddings.
    """
    assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

    half_dim = embedding_dim // 2
    exponent = -math.log(max_period) * torch.arange(
        start=0, end=half_dim, dtype=torch.float32, device=timesteps.device
    )
    exponent = exponent / (half_dim - downscale_freq_shift)

    # Keep the frequencies in float32. Casting them to `timesteps.dtype` truncates every frequency except the first
    # to 0 for integer timesteps, and loses precision for fp16/bf16 timesteps.
    emb = torch.exp(exponent)
    emb = timesteps[:, None].float() * emb[None, :]

    # scale embeddings
    emb = scale * emb

    # concat sine and cosine embeddings
    emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=-1)

    # flip sine and cosine embeddings
    if flip_sin_to_cos:
        emb = torch.cat([emb[:, half_dim:], emb[:, :half_dim]], dim=-1)

    # zero pad so the output has exactly `embedding_dim` columns
    if embedding_dim % 2 == 1:
        emb = torch.nn.functional.pad(emb, (0, 1, 0, 0))
    return emb
|
92
|
+
|
93
|
+
|
94
|
+
def apply_rotary_emb_qwen(
    x: torch.Tensor,
    freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
    use_real: bool = True,
    use_real_unbind_dim: int = -1,
) -> torch.Tensor:
    """
    Apply rotary position embeddings to a single query or key tensor.

    Args:
        x (`torch.Tensor`):
            Query or key tensor. The complex path (`use_real=False`, which is what the Qwen attention processor uses)
            expects `[B, S, H, D]`.
        freqs_cis (`torch.Tensor` or `Tuple[torch.Tensor]`):
            With `use_real=True`, a `(cos, sin)` pair of `[S, D]` tensors. With `use_real=False`, a complex tensor of
            shape `[S, D // 2]`.
        use_real (`bool`, defaults to `True`):
            Whether `freqs_cis` is given as real `(cos, sin)` tensors or as a single complex tensor.
        use_real_unbind_dim (`int`, defaults to `-1`):
            Only used when `use_real=True`. `-1` pairs adjacent channels as (real, imag); `-2` pairs the first half of
            the channels with the second half.

    Returns:
        `torch.Tensor`: a single tensor with the same shape and dtype as `x`, with rotary embeddings applied.

    Raises:
        ValueError: if `use_real=True` and `use_real_unbind_dim` is neither `-1` nor `-2`.
    """
    if use_real:
        cos, sin = freqs_cis  # [S, D]
        # NOTE(review): cos/sin are broadcast as [1, 1, S, D], so this path appears to expect `x` laid out as
        # [B, H, S, D]; verify before using it with [B, S, H, D] inputs.
        cos = cos[None, None]
        sin = sin[None, None]
        cos, sin = cos.to(x.device), sin.to(x.device)

        if use_real_unbind_dim == -1:
            # Used for flux, cogvideox, hunyuan-dit
            x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
            x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
        elif use_real_unbind_dim == -2:
            # Used for Stable Audio, OmniGen, CogView4 and Cosmos
            x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)
            x_rotated = torch.cat([-x_imag, x_real], dim=-1)
        else:
            raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")

        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)

        return out
    else:
        # Treat adjacent channel pairs of x as complex numbers: [B, S, H, D] -> [B, S, H, D // 2] (complex).
        x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
        # [S, D // 2] -> [S, 1, D // 2] so the frequencies broadcast over the head dimension.
        freqs_cis = freqs_cis.unsqueeze(1)
        x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)

        return x_out.type_as(x)
|
140
|
+
|
141
|
+
|
142
|
+
class QwenTimestepProjEmbeddings(nn.Module):
    """
    Timestep conditioning for the Qwen image transformer.

    A timestep is first mapped to a 256-channel sinusoidal projection (`Timesteps`) and then passed through a
    `TimestepEmbedding` MLP that outputs `embedding_dim` features.
    """

    def __init__(self, embedding_dim):
        super().__init__()

        # Width of the sinusoidal projection; shared by the projection output and the embedder input.
        sinusoid_channels = 256
        self.time_proj = Timesteps(
            num_channels=sinusoid_channels, flip_sin_to_cos=True, downscale_freq_shift=0, scale=1000
        )
        self.timestep_embedder = TimestepEmbedding(in_channels=sinusoid_channels, time_embed_dim=embedding_dim)

    def forward(self, timestep, hidden_states):
        """
        Args:
            timestep: timesteps handed to the sinusoidal projection.
            hidden_states: only its dtype is used; the projection is cast to it before the MLP.

        Returns:
            The embedder output (annotated in the original code as `(N, D)`).
        """
        target_dtype = hidden_states.dtype
        projected = self.time_proj(timestep)
        return self.timestep_embedder(projected.to(dtype=target_dtype))
|
156
|
+
|
157
|
+
|
158
|
+
class QwenEmbedRope(nn.Module):
|
159
|
+
def __init__(self, theta: int, axes_dim: List[int], scale_rope=False):
|
160
|
+
super().__init__()
|
161
|
+
self.theta = theta
|
162
|
+
self.axes_dim = axes_dim
|
163
|
+
pos_index = torch.arange(4096)
|
164
|
+
neg_index = torch.arange(4096).flip(0) * -1 - 1
|
165
|
+
self.pos_freqs = torch.cat(
|
166
|
+
[
|
167
|
+
self.rope_params(pos_index, self.axes_dim[0], self.theta),
|
168
|
+
self.rope_params(pos_index, self.axes_dim[1], self.theta),
|
169
|
+
self.rope_params(pos_index, self.axes_dim[2], self.theta),
|
170
|
+
],
|
171
|
+
dim=1,
|
172
|
+
)
|
173
|
+
self.neg_freqs = torch.cat(
|
174
|
+
[
|
175
|
+
self.rope_params(neg_index, self.axes_dim[0], self.theta),
|
176
|
+
self.rope_params(neg_index, self.axes_dim[1], self.theta),
|
177
|
+
self.rope_params(neg_index, self.axes_dim[2], self.theta),
|
178
|
+
],
|
179
|
+
dim=1,
|
180
|
+
)
|
181
|
+
self.rope_cache = {}
|
182
|
+
|
183
|
+
# DO NOT USING REGISTER BUFFER HERE, IT WILL CAUSE COMPLEX NUMBERS LOSE ITS IMAGINARY PART
|
184
|
+
self.scale_rope = scale_rope
|
185
|
+
|
186
|
+
def rope_params(self, index, dim, theta=10000):
|
187
|
+
"""
|
188
|
+
Args:
|
189
|
+
index: [0, 1, 2, 3] 1D Tensor representing the position index of the token
|
190
|
+
"""
|
191
|
+
assert dim % 2 == 0
|
192
|
+
freqs = torch.outer(index, 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float32).div(dim)))
|
193
|
+
freqs = torch.polar(torch.ones_like(freqs), freqs)
|
194
|
+
return freqs
|
195
|
+
|
196
|
+
def forward(self, video_fhw, txt_seq_lens, device):
|
197
|
+
"""
|
198
|
+
Args: video_fhw: [frame, height, width] a list of 3 integers representing the shape of the video Args:
|
199
|
+
txt_length: [bs] a list of 1 integers representing the length of the text
|
200
|
+
"""
|
201
|
+
if self.pos_freqs.device != device:
|
202
|
+
self.pos_freqs = self.pos_freqs.to(device)
|
203
|
+
self.neg_freqs = self.neg_freqs.to(device)
|
204
|
+
|
205
|
+
if isinstance(video_fhw, list):
|
206
|
+
video_fhw = video_fhw[0]
|
207
|
+
if not isinstance(video_fhw, list):
|
208
|
+
video_fhw = [video_fhw]
|
209
|
+
|
210
|
+
vid_freqs = []
|
211
|
+
max_vid_index = 0
|
212
|
+
for idx, fhw in enumerate(video_fhw):
|
213
|
+
frame, height, width = fhw
|
214
|
+
rope_key = f"{idx}_{height}_{width}"
|
215
|
+
|
216
|
+
if not torch.compiler.is_compiling():
|
217
|
+
if rope_key not in self.rope_cache:
|
218
|
+
self.rope_cache[rope_key] = self._compute_video_freqs(frame, height, width, idx)
|
219
|
+
video_freq = self.rope_cache[rope_key]
|
220
|
+
else:
|
221
|
+
video_freq = self._compute_video_freqs(frame, height, width, idx)
|
222
|
+
video_freq = video_freq.to(device)
|
223
|
+
vid_freqs.append(video_freq)
|
224
|
+
|
225
|
+
if self.scale_rope:
|
226
|
+
max_vid_index = max(height // 2, width // 2, max_vid_index)
|
227
|
+
else:
|
228
|
+
max_vid_index = max(height, width, max_vid_index)
|
229
|
+
|
230
|
+
max_len = max(txt_seq_lens)
|
231
|
+
txt_freqs = self.pos_freqs[max_vid_index : max_vid_index + max_len, ...]
|
232
|
+
vid_freqs = torch.cat(vid_freqs, dim=0)
|
233
|
+
|
234
|
+
return vid_freqs, txt_freqs
|
235
|
+
|
236
|
+
@functools.lru_cache(maxsize=None)
|
237
|
+
def _compute_video_freqs(self, frame, height, width, idx=0):
|
238
|
+
seq_lens = frame * height * width
|
239
|
+
freqs_pos = self.pos_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
240
|
+
freqs_neg = self.neg_freqs.split([x // 2 for x in self.axes_dim], dim=1)
|
241
|
+
|
242
|
+
freqs_frame = freqs_pos[0][idx : idx + frame].view(frame, 1, 1, -1).expand(frame, height, width, -1)
|
243
|
+
if self.scale_rope:
|
244
|
+
freqs_height = torch.cat([freqs_neg[1][-(height - height // 2) :], freqs_pos[1][: height // 2]], dim=0)
|
245
|
+
freqs_height = freqs_height.view(1, height, 1, -1).expand(frame, height, width, -1)
|
246
|
+
freqs_width = torch.cat([freqs_neg[2][-(width - width // 2) :], freqs_pos[2][: width // 2]], dim=0)
|
247
|
+
freqs_width = freqs_width.view(1, 1, width, -1).expand(frame, height, width, -1)
|
248
|
+
else:
|
249
|
+
freqs_height = freqs_pos[1][:height].view(1, height, 1, -1).expand(frame, height, width, -1)
|
250
|
+
freqs_width = freqs_pos[2][:width].view(1, 1, width, -1).expand(frame, height, width, -1)
|
251
|
+
|
252
|
+
freqs = torch.cat([freqs_frame, freqs_height, freqs_width], dim=-1).reshape(seq_lens, -1)
|
253
|
+
return freqs.clone().contiguous()
|
254
|
+
|
255
|
+
|
256
|
+
class QwenDoubleStreamAttnProcessor2_0:
    """
    Joint ("double-stream") attention processor for the Qwen image transformer.

    The image stream goes through the attention module's primary projections (`to_q`, `to_k`, `to_v`, `to_out`) and
    the text stream through its added projections (`add_q_proj`, `add_k_proj`, `add_v_proj`, `to_add_out`). Both
    streams are attended together in one call, with tokens ordered `[text, image]`, and the result is split back into
    the two streams.
    """

    _attention_backend = None

    def __init__(self):
        if hasattr(F, "scaled_dot_product_attention"):
            return
        raise ImportError(
            "QwenDoubleStreamAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
        )

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,  # Image stream
        encoder_hidden_states: torch.FloatTensor = None,  # Text stream
        encoder_hidden_states_mask: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[torch.Tensor] = None,
    ) -> torch.FloatTensor:
        """
        Args:
            attn: the `Attention` module owning the projection and normalization layers.
            hidden_states: image-stream tokens.
            encoder_hidden_states: text-stream tokens; required.
            encoder_hidden_states_mask: accepted but not used by this processor.
            attention_mask: forwarded unchanged to `dispatch_attention_fn` as `attn_mask`.
            image_rotary_emb: optional `(img_freqs, txt_freqs)` pair of complex rotary frequencies.

        Returns:
            `(img_attn_output, txt_attn_output)`: attention outputs for the image and text streams.

        Raises:
            ValueError: if `encoder_hidden_states` is None.
        """
        if encoder_hidden_states is None:
            raise ValueError("QwenDoubleStreamAttnProcessor2_0 requires encoder_hidden_states (text stream)")

        text_len = encoder_hidden_states.shape[1]

        def to_heads(t):
            # Split the last (feature) dimension into (heads, head_dim).
            return t.unflatten(-1, (attn.heads, -1))

        def maybe_norm(norm, t):
            return t if norm is None else norm(t)

        # Projections run image q/k/v first, then text q/k/v.
        image_qkv = [proj(hidden_states) for proj in (attn.to_q, attn.to_k, attn.to_v)]
        text_qkv = [proj(encoder_hidden_states) for proj in (attn.add_q_proj, attn.add_k_proj, attn.add_v_proj)]
        img_q, img_k, img_v = map(to_heads, image_qkv)
        txt_q, txt_k, txt_v = map(to_heads, text_qkv)

        # Optional QK normalization, per stream.
        img_q = maybe_norm(attn.norm_q, img_q)
        img_k = maybe_norm(attn.norm_k, img_k)
        txt_q = maybe_norm(attn.norm_added_q, txt_q)
        txt_k = maybe_norm(attn.norm_added_k, txt_k)

        if image_rotary_emb is not None:
            img_freqs, txt_freqs = image_rotary_emb
            img_q, img_k = (apply_rotary_emb_qwen(t, img_freqs, use_real=False) for t in (img_q, img_k))
            txt_q, txt_k = (apply_rotary_emb_qwen(t, txt_freqs, use_real=False) for t in (txt_q, txt_k))

        # Joint sequence layout: text tokens first, then image tokens.
        query, key, value = (
            torch.cat(pair, dim=1) for pair in ((txt_q, img_q), (txt_k, img_k), (txt_v, img_v))
        )

        attended = dispatch_attention_fn(
            query,
            key,
            value,
            attn_mask=attention_mask,
            dropout_p=0.0,
            is_causal=False,
            backend=self._attention_backend,
        )
        # Merge the two head dimensions back into one feature dimension.
        attended = attended.flatten(2, 3).to(query.dtype)

        txt_out = attended[:, :text_len, :]
        img_out = attended[:, text_len:, :]

        img_out = attn.to_out[0](img_out)
        if len(attn.to_out) > 1:
            img_out = attn.to_out[1](img_out)  # dropout
        txt_out = attn.to_add_out(txt_out)

        return img_out, txt_out
|
354
|
+
|
355
|
+
|
356
|
+
@maybe_allow_in_graph
class QwenImageTransformerBlock(nn.Module):
    """
    Dual-stream transformer block. Image and text tokens have separate modulation, norm and MLP layers, but share a
    single joint attention module.

    Args:
        dim: hidden size of both streams.
        num_attention_heads: number of attention heads.
        attention_head_dim: size of each attention head.
        qk_norm: query/key normalization type passed to `Attention`.
        eps: epsilon for the layer norms and the attention module.
    """

    def __init__(
        self, dim: int, num_attention_heads: int, attention_head_dim: int, qk_norm: str = "rms_norm", eps: float = 1e-6
    ):
        super().__init__()

        self.dim = dim
        self.num_attention_heads = num_attention_heads
        self.attention_head_dim = attention_head_dim

        # Image stream. The modulation projection yields 6 * dim values: (shift, scale, gate) for the attention
        # sub-layer followed by (shift, scale, gate) for the MLP sub-layer.
        self.img_mod = nn.Sequential(nn.SiLU(), nn.Linear(dim, 6 * dim, bias=True))
        self.img_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        # Joint attention; `added_kv_proj_dim` enables the added projections used for the text stream.
        self.attn = Attention(
            query_dim=dim,
            cross_attention_dim=None,
            added_kv_proj_dim=dim,
            dim_head=attention_head_dim,
            heads=num_attention_heads,
            out_dim=dim,
            context_pre_only=False,
            bias=True,
            processor=QwenDoubleStreamAttnProcessor2_0(),
            qk_norm=qk_norm,
            eps=eps,
        )
        self.img_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        self.img_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

        # Text stream: same layout as the image stream, minus attention (handled jointly by `self.attn`).
        self.txt_mod = nn.Sequential(nn.SiLU(), nn.Linear(dim, 6 * dim, bias=True))
        self.txt_norm1 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        self.txt_norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=eps)
        self.txt_mlp = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")

    def _modulate(self, x, mod_params):
        """Split `mod_params` into (shift, scale, gate); return `(x * (1 + scale) + shift, gate)` with a token axis added."""
        shift, scale, gate = (p.unsqueeze(1) for p in mod_params.chunk(3, dim=-1))
        return x * (1 + scale) + shift, gate

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
        encoder_hidden_states_mask: torch.Tensor,
        temb: torch.Tensor,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Run one dual-stream block.

        Returns:
            `(encoder_hidden_states, hidden_states)`: the updated text stream first, then the image stream. Both are
            clipped to the fp16 range when they are float16.
        """
        # Each stream's 6 * dim modulation splits into attention-half and MLP-half parameters.
        img_mod_attn, img_mod_mlp = self.img_mod(temb).chunk(2, dim=-1)
        txt_mod_attn, txt_mod_mlp = self.txt_mod(temb).chunk(2, dim=-1)

        # Attention sub-layer: modulated inputs, gated residual update.
        img_in, img_gate_attn = self._modulate(self.img_norm1(hidden_states), img_mod_attn)
        txt_in, txt_gate_attn = self._modulate(self.txt_norm1(encoder_hidden_states), txt_mod_attn)

        # The processor returns (image_output, text_output).
        img_attn_out, txt_attn_out = self.attn(
            hidden_states=img_in,
            encoder_hidden_states=txt_in,
            encoder_hidden_states_mask=encoder_hidden_states_mask,
            image_rotary_emb=image_rotary_emb,
            **(joint_attention_kwargs or {}),
        )
        hidden_states = hidden_states + img_gate_attn * img_attn_out
        encoder_hidden_states = encoder_hidden_states + txt_gate_attn * txt_attn_out

        # MLP sub-layer, image stream then text stream.
        img_mlp_in, img_gate_mlp = self._modulate(self.img_norm2(hidden_states), img_mod_mlp)
        hidden_states = hidden_states + img_gate_mlp * self.img_mlp(img_mlp_in)

        txt_mlp_in, txt_gate_mlp = self._modulate(self.txt_norm2(encoder_hidden_states), txt_mod_mlp)
        encoder_hidden_states = encoder_hidden_states + txt_gate_mlp * self.txt_mlp(txt_mlp_in)

        # Keep fp16 activations finite.
        if encoder_hidden_states.dtype == torch.float16:
            encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
        if hidden_states.dtype == torch.float16:
            hidden_states = hidden_states.clip(-65504, 65504)

        return encoder_hidden_states, hidden_states
|
470
|
+
|
471
|
+
|
472
|
+
class QwenImageTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
    """
    The Transformer model introduced in Qwen.

    Args:
        patch_size (`int`, defaults to `2`):
            Patch size to turn the input data into small patches.
        in_channels (`int`, defaults to `64`):
            The number of channels in the input.
        out_channels (`int`, *optional*, defaults to `16`):
            The number of channels in the output. If set to `None`, it defaults to `in_channels`.
        num_layers (`int`, defaults to `60`):
            The number of layers of dual stream DiT blocks to use.
        attention_head_dim (`int`, defaults to `128`):
            The number of dimensions to use for each attention head.
        num_attention_heads (`int`, defaults to `24`):
            The number of attention heads to use.
        joint_attention_dim (`int`, defaults to `3584`):
            The number of dimensions to use for the joint attention (embedding/channel dimension of
            `encoder_hidden_states`).
        guidance_embeds (`bool`, defaults to `False`):
            Recorded in the config only; this implementation has no guidance embedding layers.
        axes_dims_rope (`Tuple[int]`, defaults to `(16, 56, 56)`):
            The dimensions to use for the rotary positional embeddings.
    """

    _supports_gradient_checkpointing = True
    _no_split_modules = ["QwenImageTransformerBlock"]
    _skip_layerwise_casting_patterns = ["pos_embed", "norm"]
    _repeated_blocks = ["QwenImageTransformerBlock"]

    @register_to_config
    def __init__(
        self,
        patch_size: int = 2,
        in_channels: int = 64,
        out_channels: Optional[int] = 16,
        num_layers: int = 60,
        attention_head_dim: int = 128,
        num_attention_heads: int = 24,
        joint_attention_dim: int = 3584,
        guidance_embeds: bool = False,  # TODO: this should probably be removed
        axes_dims_rope: Tuple[int, int, int] = (16, 56, 56),
    ):
        super().__init__()
        self.out_channels = out_channels or in_channels
        self.inner_dim = num_attention_heads * attention_head_dim

        self.pos_embed = QwenEmbedRope(theta=10000, axes_dim=list(axes_dims_rope), scale_rope=True)

        self.time_text_embed = QwenTimestepProjEmbeddings(embedding_dim=self.inner_dim)

        self.txt_norm = RMSNorm(joint_attention_dim, eps=1e-6)

        self.img_in = nn.Linear(in_channels, self.inner_dim)
        self.txt_in = nn.Linear(joint_attention_dim, self.inner_dim)

        self.transformer_blocks = nn.ModuleList(
            [
                QwenImageTransformerBlock(
                    dim=self.inner_dim,
                    num_attention_heads=num_attention_heads,
                    attention_head_dim=attention_head_dim,
                )
                for _ in range(num_layers)
            ]
        )

        self.norm_out = AdaLayerNormContinuous(self.inner_dim, self.inner_dim, elementwise_affine=False, eps=1e-6)
        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)

        self.gradient_checkpointing = False

    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor = None,
        encoder_hidden_states_mask: torch.Tensor = None,
        timestep: torch.LongTensor = None,
        img_shapes: Optional[List[Tuple[int, int, int]]] = None,
        txt_seq_lens: Optional[List[int]] = None,
        guidance: torch.Tensor = None,  # TODO: this should probably be removed
        attention_kwargs: Optional[Dict[str, Any]] = None,
        return_dict: bool = True,
    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
        """
        The [`QwenImageTransformer2DModel`] forward method.

        Args:
            hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
                Input `hidden_states`.
            encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
            encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`):
                Mask of the input conditions. Passed through to every transformer block.
            timestep ( `torch.LongTensor`):
                Used to indicate denoising step.
            img_shapes (`List[Tuple[int, int, int]]`, *optional*):
                `(frame, height, width)` of the image tokens, used to build the rotary position embeddings.
            txt_seq_lens (`List[int]`, *optional*):
                Text sequence lengths; the maximum sets how many text rotary positions are generated.
            guidance (`torch.Tensor`, *optional*):
                Not supported by this model; must be `None`.
            attention_kwargs (`dict`, *optional*):
                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                `self.processor` in
                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
                A `"scale"` entry is removed first and used as the LoRA scale when the PEFT backend is active.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
                tuple.

        Returns:
            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
            `tuple` where the first element is the sample tensor.

        Raises:
            ValueError: if `guidance` is not `None`.
        """
        if guidance is not None:
            # `time_text_embed` only accepts `(timestep, hidden_states)` and there is no guidance embedder, so a
            # guidance tensor cannot be consumed. Fail before any LoRA scaling is applied.
            raise ValueError(
                "`guidance` is not supported by QwenImageTransformer2DModel because it has no guidance embedding "
                "layers; pass `guidance=None`."
            )

        # Copy before popping "scale" so the caller's dict is not mutated. Remember whether a scale was supplied so
        # the non-PEFT warning below can actually fire (it is popped before being checked).
        requested_lora_scale = None
        if attention_kwargs is not None:
            attention_kwargs = attention_kwargs.copy()
            requested_lora_scale = attention_kwargs.pop("scale", None)
        lora_scale = 1.0 if requested_lora_scale is None else requested_lora_scale

        if USE_PEFT_BACKEND:
            # weight the lora layers by setting `lora_scale` for each PEFT layer
            scale_lora_layers(self, lora_scale)
        elif requested_lora_scale is not None:
            logger.warning("Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective.")

        hidden_states = self.img_in(hidden_states)

        timestep = timestep.to(hidden_states.dtype)
        encoder_hidden_states = self.txt_norm(encoder_hidden_states)
        encoder_hidden_states = self.txt_in(encoder_hidden_states)

        temb = self.time_text_embed(timestep, hidden_states)

        image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)

        for block in self.transformer_blocks:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                # Pass attention_kwargs here too so checkpointed and regular runs see the same processor kwargs.
                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
                    block,
                    hidden_states,
                    encoder_hidden_states,
                    encoder_hidden_states_mask,
                    temb,
                    image_rotary_emb,
                    attention_kwargs,
                )
            else:
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_hidden_states_mask=encoder_hidden_states_mask,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                    joint_attention_kwargs=attention_kwargs,
                )

        # Use only the image part (hidden_states) from the dual-stream blocks
        hidden_states = self.norm_out(hidden_states, temb)
        output = self.proj_out(hidden_states)

        if USE_PEFT_BACKEND:
            # remove `lora_scale` from each PEFT layer
            unscale_lora_layers(self, lora_scale)

        if not return_dict:
            return (output,)

        return Transformer2DModelOutput(sample=output)
|