diffusers 0.33.0__py3-none-any.whl → 0.34.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +48 -1
- diffusers/commands/__init__.py +1 -1
- diffusers/commands/diffusers_cli.py +1 -1
- diffusers/commands/env.py +1 -1
- diffusers/commands/fp16_safetensors.py +1 -1
- diffusers/dependency_versions_check.py +1 -1
- diffusers/dependency_versions_table.py +1 -1
- diffusers/experimental/rl/value_guided_sampling.py +1 -1
- diffusers/hooks/faster_cache.py +2 -2
- diffusers/hooks/group_offloading.py +128 -29
- diffusers/hooks/hooks.py +2 -2
- diffusers/hooks/layerwise_casting.py +3 -3
- diffusers/hooks/pyramid_attention_broadcast.py +1 -1
- diffusers/image_processor.py +7 -2
- diffusers/loaders/__init__.py +4 -0
- diffusers/loaders/ip_adapter.py +5 -14
- diffusers/loaders/lora_base.py +212 -111
- diffusers/loaders/lora_conversion_utils.py +275 -34
- diffusers/loaders/lora_pipeline.py +1554 -819
- diffusers/loaders/peft.py +52 -109
- diffusers/loaders/single_file.py +2 -2
- diffusers/loaders/single_file_model.py +20 -4
- diffusers/loaders/single_file_utils.py +225 -5
- diffusers/loaders/textual_inversion.py +3 -2
- diffusers/loaders/transformer_flux.py +1 -1
- diffusers/loaders/transformer_sd3.py +2 -2
- diffusers/loaders/unet.py +2 -16
- diffusers/loaders/unet_loader_utils.py +1 -1
- diffusers/loaders/utils.py +1 -1
- diffusers/models/__init__.py +15 -1
- diffusers/models/activations.py +5 -5
- diffusers/models/adapter.py +2 -3
- diffusers/models/attention.py +4 -4
- diffusers/models/attention_flax.py +10 -10
- diffusers/models/attention_processor.py +14 -10
- diffusers/models/auto_model.py +47 -10
- diffusers/models/autoencoders/__init__.py +1 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +4 -4
- diffusers/models/autoencoders/autoencoder_dc.py +3 -3
- diffusers/models/autoencoders/autoencoder_kl.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +6 -6
- diffusers/models/autoencoders/autoencoder_kl_cosmos.py +1108 -0
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +2 -2
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +3 -3
- diffusers/models/autoencoders/autoencoder_kl_magvit.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +3 -3
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -4
- diffusers/models/autoencoders/autoencoder_kl_wan.py +256 -22
- diffusers/models/autoencoders/autoencoder_oobleck.py +1 -1
- diffusers/models/autoencoders/autoencoder_tiny.py +3 -3
- diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
- diffusers/models/autoencoders/vae.py +13 -2
- diffusers/models/autoencoders/vq_model.py +2 -2
- diffusers/models/cache_utils.py +1 -1
- diffusers/models/controlnet.py +1 -1
- diffusers/models/controlnet_flux.py +1 -1
- diffusers/models/controlnet_sd3.py +1 -1
- diffusers/models/controlnet_sparsectrl.py +1 -1
- diffusers/models/controlnets/__init__.py +1 -0
- diffusers/models/controlnets/controlnet.py +3 -3
- diffusers/models/controlnets/controlnet_flax.py +1 -1
- diffusers/models/controlnets/controlnet_flux.py +16 -15
- diffusers/models/controlnets/controlnet_hunyuan.py +2 -2
- diffusers/models/controlnets/controlnet_sana.py +290 -0
- diffusers/models/controlnets/controlnet_sd3.py +1 -1
- diffusers/models/controlnets/controlnet_sparsectrl.py +2 -2
- diffusers/models/controlnets/controlnet_union.py +1 -1
- diffusers/models/controlnets/controlnet_xs.py +7 -7
- diffusers/models/controlnets/multicontrolnet.py +4 -5
- diffusers/models/controlnets/multicontrolnet_union.py +5 -6
- diffusers/models/downsampling.py +2 -2
- diffusers/models/embeddings.py +10 -12
- diffusers/models/embeddings_flax.py +2 -2
- diffusers/models/lora.py +3 -3
- diffusers/models/modeling_utils.py +44 -14
- diffusers/models/normalization.py +4 -4
- diffusers/models/resnet.py +2 -2
- diffusers/models/resnet_flax.py +1 -1
- diffusers/models/transformers/__init__.py +5 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +70 -24
- diffusers/models/transformers/cogvideox_transformer_3d.py +1 -1
- diffusers/models/transformers/consisid_transformer_3d.py +1 -1
- diffusers/models/transformers/dit_transformer_2d.py +2 -2
- diffusers/models/transformers/dual_transformer_2d.py +1 -1
- diffusers/models/transformers/hunyuan_transformer_2d.py +2 -2
- diffusers/models/transformers/latte_transformer_3d.py +4 -5
- diffusers/models/transformers/lumina_nextdit2d.py +2 -2
- diffusers/models/transformers/pixart_transformer_2d.py +3 -3
- diffusers/models/transformers/prior_transformer.py +1 -1
- diffusers/models/transformers/sana_transformer.py +8 -3
- diffusers/models/transformers/stable_audio_transformer.py +5 -9
- diffusers/models/transformers/t5_film_transformer.py +3 -3
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +1 -1
- diffusers/models/transformers/transformer_chroma.py +742 -0
- diffusers/models/transformers/transformer_cogview3plus.py +5 -10
- diffusers/models/transformers/transformer_cogview4.py +317 -25
- diffusers/models/transformers/transformer_cosmos.py +579 -0
- diffusers/models/transformers/transformer_flux.py +9 -11
- diffusers/models/transformers/transformer_hidream_image.py +942 -0
- diffusers/models/transformers/transformer_hunyuan_video.py +6 -8
- diffusers/models/transformers/transformer_hunyuan_video_framepack.py +416 -0
- diffusers/models/transformers/transformer_ltx.py +2 -2
- diffusers/models/transformers/transformer_lumina2.py +1 -1
- diffusers/models/transformers/transformer_mochi.py +1 -1
- diffusers/models/transformers/transformer_omnigen.py +2 -2
- diffusers/models/transformers/transformer_sd3.py +7 -7
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/transformers/transformer_wan.py +24 -8
- diffusers/models/transformers/transformer_wan_vace.py +393 -0
- diffusers/models/unets/unet_1d.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +1 -1
- diffusers/models/unets/unet_2d_blocks.py +1 -1
- diffusers/models/unets/unet_2d_blocks_flax.py +8 -7
- diffusers/models/unets/unet_2d_condition.py +2 -2
- diffusers/models/unets/unet_2d_condition_flax.py +2 -2
- diffusers/models/unets/unet_3d_blocks.py +1 -1
- diffusers/models/unets/unet_3d_condition.py +3 -3
- diffusers/models/unets/unet_i2vgen_xl.py +3 -3
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +2 -2
- diffusers/models/unets/unet_stable_cascade.py +1 -1
- diffusers/models/upsampling.py +2 -2
- diffusers/models/vae_flax.py +2 -2
- diffusers/models/vq_model.py +1 -1
- diffusers/pipelines/__init__.py +37 -6
- diffusers/pipelines/allegro/pipeline_allegro.py +11 -11
- diffusers/pipelines/amused/pipeline_amused.py +7 -6
- diffusers/pipelines/amused/pipeline_amused_img2img.py +6 -5
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +6 -5
- diffusers/pipelines/animatediff/pipeline_animatediff.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +16 -15
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +6 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +5 -5
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +5 -5
- diffusers/pipelines/audioldm/pipeline_audioldm.py +8 -7
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +23 -13
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +48 -11
- diffusers/pipelines/auto_pipeline.py +6 -7
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +11 -10
- diffusers/pipelines/chroma/__init__.py +49 -0
- diffusers/pipelines/chroma/pipeline_chroma.py +949 -0
- diffusers/pipelines/chroma/pipeline_chroma_img2img.py +1034 -0
- diffusers/pipelines/chroma/pipeline_output.py +21 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +8 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +8 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +8 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +8 -8
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +9 -9
- diffusers/pipelines/cogview4/pipeline_cogview4.py +7 -7
- diffusers/pipelines/cogview4/pipeline_cogview4_control.py +7 -7
- diffusers/pipelines/consisid/consisid_utils.py +2 -2
- diffusers/pipelines/consisid/pipeline_consisid.py +8 -8
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +8 -8
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +7 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +14 -14
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +10 -6
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +13 -13
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +14 -14
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +5 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +13 -13
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +8 -8
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +7 -7
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -10
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +9 -7
- diffusers/pipelines/cosmos/__init__.py +54 -0
- diffusers/pipelines/cosmos/pipeline_cosmos2_text2image.py +673 -0
- diffusers/pipelines/cosmos/pipeline_cosmos2_video2world.py +792 -0
- diffusers/pipelines/cosmos/pipeline_cosmos_text2world.py +664 -0
- diffusers/pipelines/cosmos/pipeline_cosmos_video2world.py +826 -0
- diffusers/pipelines/cosmos/pipeline_output.py +40 -0
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +5 -4
- diffusers/pipelines/ddim/pipeline_ddim.py +4 -4
- diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +10 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +10 -10
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +8 -8
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -5
- diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
- diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +3 -3
- diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
- diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +2 -2
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +4 -3
- diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +7 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +9 -9
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +10 -10
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -8
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -5
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +18 -18
- diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +2 -2
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +6 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +5 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +5 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +5 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +1 -1
- diffusers/pipelines/easyanimate/pipeline_easyanimate.py +4 -4
- diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +4 -4
- diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +7 -6
- diffusers/pipelines/flux/modeling_flux.py +1 -1
- diffusers/pipelines/flux/pipeline_flux.py +10 -17
- diffusers/pipelines/flux/pipeline_flux_control.py +6 -6
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -6
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +6 -6
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +6 -6
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +30 -22
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +2 -1
- diffusers/pipelines/flux/pipeline_flux_fill.py +6 -6
- diffusers/pipelines/flux/pipeline_flux_img2img.py +39 -6
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +11 -6
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
- diffusers/pipelines/free_init_utils.py +2 -2
- diffusers/pipelines/free_noise_utils.py +3 -3
- diffusers/pipelines/hidream_image/__init__.py +47 -0
- diffusers/pipelines/hidream_image/pipeline_hidream_image.py +1026 -0
- diffusers/pipelines/hidream_image/pipeline_output.py +35 -0
- diffusers/pipelines/hunyuan_video/__init__.py +2 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +8 -8
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +8 -8
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_framepack.py +1114 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +71 -15
- diffusers/pipelines/hunyuan_video/pipeline_output.py +19 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +8 -8
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +10 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +6 -6
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +34 -34
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +19 -26
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +7 -7
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +11 -11
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +35 -35
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +17 -39
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +17 -45
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +10 -10
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +7 -7
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +17 -38
- diffusers/pipelines/kolors/pipeline_kolors.py +10 -10
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +12 -12
- diffusers/pipelines/kolors/text_encoder.py +3 -3
- diffusers/pipelines/kolors/tokenizer.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +2 -2
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +2 -2
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +3 -3
- diffusers/pipelines/latte/pipeline_latte.py +12 -12
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +13 -13
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +17 -16
- diffusers/pipelines/ltx/__init__.py +4 -0
- diffusers/pipelines/ltx/modeling_latent_upsampler.py +188 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +51 -6
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +107 -29
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +50 -6
- diffusers/pipelines/ltx/pipeline_ltx_latent_upsample.py +277 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +13 -13
- diffusers/pipelines/lumina2/pipeline_lumina2.py +10 -10
- diffusers/pipelines/marigold/marigold_image_processing.py +2 -2
- diffusers/pipelines/mochi/pipeline_mochi.py +6 -6
- diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -13
- diffusers/pipelines/omnigen/pipeline_omnigen.py +13 -11
- diffusers/pipelines/omnigen/processor_omnigen.py +8 -3
- diffusers/pipelines/onnx_utils.py +15 -2
- diffusers/pipelines/pag/pag_utils.py +2 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -8
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +10 -6
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +14 -14
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_kolors.py +10 -10
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +11 -11
- diffusers/pipelines/pag/pipeline_pag_sana.py +18 -12
- diffusers/pipelines/pag/pipeline_pag_sd.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +7 -7
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +6 -6
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +5 -5
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +8 -8
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +16 -15
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +18 -17
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +12 -12
- diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +8 -7
- diffusers/pipelines/pia/pipeline_pia.py +8 -6
- diffusers/pipelines/pipeline_flax_utils.py +3 -4
- diffusers/pipelines/pipeline_loading_utils.py +89 -13
- diffusers/pipelines/pipeline_utils.py +105 -33
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +11 -11
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +11 -11
- diffusers/pipelines/sana/__init__.py +4 -0
- diffusers/pipelines/sana/pipeline_sana.py +23 -21
- diffusers/pipelines/sana/pipeline_sana_controlnet.py +1106 -0
- diffusers/pipelines/sana/pipeline_sana_sprint.py +23 -19
- diffusers/pipelines/sana/pipeline_sana_sprint_img2img.py +981 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +7 -6
- diffusers/pipelines/shap_e/camera.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
- diffusers/pipelines/shap_e/renderer.py +3 -3
- diffusers/pipelines/stable_audio/modeling_stable_audio.py +1 -1
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +5 -5
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +8 -8
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +13 -13
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +9 -9
- diffusers/pipelines/stable_diffusion/__init__.py +0 -7
- diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +11 -4
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +10 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +10 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +10 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +9 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +8 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +7 -7
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -5
- diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
- diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +7 -7
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +7 -7
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +7 -7
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +12 -8
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +15 -9
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +11 -9
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -9
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +18 -12
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +11 -8
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +11 -8
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -12
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +8 -6
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +15 -11
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -15
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +18 -17
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +12 -12
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -15
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +3 -3
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +12 -12
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -17
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +12 -7
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +12 -7
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +15 -13
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +24 -21
- diffusers/pipelines/unclip/pipeline_unclip.py +4 -3
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +4 -3
- diffusers/pipelines/unclip/text_proj.py +2 -2
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +2 -2
- diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +8 -7
- diffusers/pipelines/visualcloze/__init__.py +52 -0
- diffusers/pipelines/visualcloze/pipeline_visualcloze_combined.py +444 -0
- diffusers/pipelines/visualcloze/pipeline_visualcloze_generation.py +952 -0
- diffusers/pipelines/visualcloze/visualcloze_utils.py +251 -0
- diffusers/pipelines/wan/__init__.py +2 -0
- diffusers/pipelines/wan/pipeline_wan.py +17 -12
- diffusers/pipelines/wan/pipeline_wan_i2v.py +42 -20
- diffusers/pipelines/wan/pipeline_wan_vace.py +976 -0
- diffusers/pipelines/wan/pipeline_wan_video2video.py +18 -18
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +1 -1
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +16 -15
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +6 -6
- diffusers/quantizers/__init__.py +179 -1
- diffusers/quantizers/base.py +6 -1
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -0
- diffusers/quantizers/bitsandbytes/utils.py +10 -7
- diffusers/quantizers/gguf/gguf_quantizer.py +13 -4
- diffusers/quantizers/gguf/utils.py +16 -13
- diffusers/quantizers/quantization_config.py +18 -16
- diffusers/quantizers/quanto/quanto_quantizer.py +4 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +5 -1
- diffusers/schedulers/__init__.py +3 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +4 -3
- diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
- diffusers/schedulers/scheduling_consistency_models.py +1 -1
- diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +10 -5
- diffusers/schedulers/scheduling_ddim.py +8 -8
- diffusers/schedulers/scheduling_ddim_cogvideox.py +5 -5
- diffusers/schedulers/scheduling_ddim_flax.py +6 -6
- diffusers/schedulers/scheduling_ddim_inverse.py +6 -6
- diffusers/schedulers/scheduling_ddim_parallel.py +22 -22
- diffusers/schedulers/scheduling_ddpm.py +9 -9
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +18 -18
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +2 -2
- diffusers/schedulers/scheduling_deis_multistep.py +8 -8
- diffusers/schedulers/scheduling_dpm_cogvideox.py +5 -5
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -12
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +22 -20
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +11 -11
- diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +13 -13
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +13 -8
- diffusers/schedulers/scheduling_edm_euler.py +20 -11
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +3 -3
- diffusers/schedulers/scheduling_euler_discrete.py +3 -3
- diffusers/schedulers/scheduling_euler_discrete_flax.py +3 -3
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +20 -5
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +1 -1
- diffusers/schedulers/scheduling_flow_match_lcm.py +561 -0
- diffusers/schedulers/scheduling_heun_discrete.py +2 -2
- diffusers/schedulers/scheduling_ipndm.py +2 -2
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +2 -2
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +2 -2
- diffusers/schedulers/scheduling_karras_ve_flax.py +5 -5
- diffusers/schedulers/scheduling_lcm.py +3 -3
- diffusers/schedulers/scheduling_lms_discrete.py +2 -2
- diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
- diffusers/schedulers/scheduling_pndm.py +4 -4
- diffusers/schedulers/scheduling_pndm_flax.py +4 -4
- diffusers/schedulers/scheduling_repaint.py +9 -9
- diffusers/schedulers/scheduling_sasolver.py +15 -15
- diffusers/schedulers/scheduling_scm.py +1 -1
- diffusers/schedulers/scheduling_sde_ve.py +1 -1
- diffusers/schedulers/scheduling_sde_ve_flax.py +2 -2
- diffusers/schedulers/scheduling_tcd.py +3 -3
- diffusers/schedulers/scheduling_unclip.py +5 -5
- diffusers/schedulers/scheduling_unipc_multistep.py +11 -11
- diffusers/schedulers/scheduling_utils.py +1 -1
- diffusers/schedulers/scheduling_utils_flax.py +1 -1
- diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
- diffusers/training_utils.py +13 -5
- diffusers/utils/__init__.py +5 -0
- diffusers/utils/accelerate_utils.py +1 -1
- diffusers/utils/doc_utils.py +1 -1
- diffusers/utils/dummy_pt_objects.py +120 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +225 -0
- diffusers/utils/dynamic_modules_utils.py +21 -3
- diffusers/utils/export_utils.py +1 -1
- diffusers/utils/import_utils.py +81 -18
- diffusers/utils/logging.py +1 -1
- diffusers/utils/outputs.py +2 -1
- diffusers/utils/peft_utils.py +91 -8
- diffusers/utils/state_dict_utils.py +20 -3
- diffusers/utils/testing_utils.py +59 -7
- diffusers/utils/torch_utils.py +25 -5
- diffusers/video_processor.py +2 -2
- {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/METADATA +3 -3
- diffusers-0.34.0.dist-info/RECORD +639 -0
- diffusers-0.33.0.dist-info/RECORD +0 -608
- {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/LICENSE +0 -0
- {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/WHEEL +0 -0
- {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.33.0.dist-info → diffusers-0.34.0.dist-info}/top_level.txt +0 -0
diffusers/models/transformers/transformer_wan_vace.py
ADDED
@@ -0,0 +1,393 @@
+# Copyright 2025 The Wan Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import math
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ..attention import FeedForward
+from ..attention_processor import Attention
+from ..cache_utils import CacheMixin
+from ..modeling_outputs import Transformer2DModelOutput
+from ..modeling_utils import ModelMixin
+from ..normalization import FP32LayerNorm
+from .transformer_wan import WanAttnProcessor2_0, WanRotaryPosEmbed, WanTimeTextImageEmbedding, WanTransformerBlock
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+class WanVACETransformerBlock(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        ffn_dim: int,
+        num_heads: int,
+        qk_norm: str = "rms_norm_across_heads",
+        cross_attn_norm: bool = False,
+        eps: float = 1e-6,
+        added_kv_proj_dim: Optional[int] = None,
+        apply_input_projection: bool = False,
+        apply_output_projection: bool = False,
+    ):
+        super().__init__()
+
+        # 1. Input projection
+        self.proj_in = None
+        if apply_input_projection:
+            self.proj_in = nn.Linear(dim, dim)
+
+        # 2. Self-attention
+        self.norm1 = FP32LayerNorm(dim, eps, elementwise_affine=False)
+        self.attn1 = Attention(
+            query_dim=dim,
+            heads=num_heads,
+            kv_heads=num_heads,
+            dim_head=dim // num_heads,
+            qk_norm=qk_norm,
+            eps=eps,
+            bias=True,
+            cross_attention_dim=None,
+            out_bias=True,
+            processor=WanAttnProcessor2_0(),
+        )
+
+        # 3. Cross-attention
+        self.attn2 = Attention(
+            query_dim=dim,
+            heads=num_heads,
+            kv_heads=num_heads,
+            dim_head=dim // num_heads,
+            qk_norm=qk_norm,
+            eps=eps,
+            bias=True,
+            cross_attention_dim=None,
+            out_bias=True,
+            added_kv_proj_dim=added_kv_proj_dim,
+            added_proj_bias=True,
+            processor=WanAttnProcessor2_0(),
+        )
+        self.norm2 = FP32LayerNorm(dim, eps, elementwise_affine=True) if cross_attn_norm else nn.Identity()
+
+        # 4. Feed-forward
+        self.ffn = FeedForward(dim, inner_dim=ffn_dim, activation_fn="gelu-approximate")
+        self.norm3 = FP32LayerNorm(dim, eps, elementwise_affine=False)
+
+        # 5. Output projection
+        self.proj_out = None
+        if apply_output_projection:
+            self.proj_out = nn.Linear(dim, dim)
+
+        self.scale_shift_table = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        control_hidden_states: torch.Tensor,
+        temb: torch.Tensor,
+        rotary_emb: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if self.proj_in is not None:
+            control_hidden_states = self.proj_in(control_hidden_states)
+            control_hidden_states = control_hidden_states + hidden_states
+
+        shift_msa, scale_msa, gate_msa, c_shift_msa, c_scale_msa, c_gate_msa = (
+            self.scale_shift_table + temb.float()
+        ).chunk(6, dim=1)
+
+        # 1. Self-attention
+        norm_hidden_states = (self.norm1(control_hidden_states.float()) * (1 + scale_msa) + shift_msa).type_as(
+            control_hidden_states
+        )
+        attn_output = self.attn1(hidden_states=norm_hidden_states, rotary_emb=rotary_emb)
+        control_hidden_states = (control_hidden_states.float() + attn_output * gate_msa).type_as(control_hidden_states)
+
+        # 2. Cross-attention
+        norm_hidden_states = self.norm2(control_hidden_states.float()).type_as(control_hidden_states)
+        attn_output = self.attn2(hidden_states=norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
+        control_hidden_states = control_hidden_states + attn_output
+
+        # 3. Feed-forward
+        norm_hidden_states = (self.norm3(control_hidden_states.float()) * (1 + c_scale_msa) + c_shift_msa).type_as(
+            control_hidden_states
+        )
+        ff_output = self.ffn(norm_hidden_states)
+        control_hidden_states = (control_hidden_states.float() + ff_output.float() * c_gate_msa).type_as(
+            control_hidden_states
+        )
+
+        conditioning_states = None
+        if self.proj_out is not None:
+            conditioning_states = self.proj_out(control_hidden_states)
+
+        return conditioning_states, control_hidden_states
+
+
+class WanVACETransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, CacheMixin):
+    r"""
+    A Transformer model for video-like data used in the Wan model.
+
+    Args:
+        patch_size (`Tuple[int]`, defaults to `(1, 2, 2)`):
+            3D patch dimensions for video embedding (t_patch, h_patch, w_patch).
+        num_attention_heads (`int`, defaults to `40`):
+            The number of heads to use for multi-head attention.
+        attention_head_dim (`int`, defaults to `128`):
+            The number of channels in each head.
+        in_channels (`int`, defaults to `16`):
+            The number of channels in the input.
+        out_channels (`int`, defaults to `16`):
+            The number of channels in the output.
+        text_dim (`int`, defaults to `4096`):
+            Input dimension for text embeddings.
+        freq_dim (`int`, defaults to `256`):
+            Dimension for sinusoidal time embeddings.
+        ffn_dim (`int`, defaults to `13824`):
+            Intermediate dimension in feed-forward network.
+        num_layers (`int`, defaults to `40`):
+            The number of layers of transformer blocks to use.
+        window_size (`Tuple[int]`, defaults to `(-1, -1)`):
+            Window size for local attention (-1 indicates global attention).
+        cross_attn_norm (`bool`, defaults to `True`):
+            Enable cross-attention normalization.
+        qk_norm (`bool`, defaults to `True`):
+            Enable query/key normalization.
+        eps (`float`, defaults to `1e-6`):
+            Epsilon value for normalization layers.
+        add_img_emb (`bool`, defaults to `False`):
+            Whether to use img_emb.
+        added_kv_proj_dim (`int`, *optional*, defaults to `None`):
+            The number of channels to use for the added key and value projections. If `None`, no projection is used.
+    """

+
+    _supports_gradient_checkpointing = True
+    _skip_layerwise_casting_patterns = ["patch_embedding", "vace_patch_embedding", "condition_embedder", "norm"]
+    _no_split_modules = ["WanTransformerBlock", "WanVACETransformerBlock"]
+    _keep_in_fp32_modules = ["time_embedder", "scale_shift_table", "norm1", "norm2", "norm3"]
+    _keys_to_ignore_on_load_unexpected = ["norm_added_q"]
+
+    @register_to_config
+    def __init__(
+        self,
+        patch_size: Tuple[int] = (1, 2, 2),
+        num_attention_heads: int = 40,
+        attention_head_dim: int = 128,
+        in_channels: int = 16,
+        out_channels: int = 16,
+        text_dim: int = 4096,
+        freq_dim: int = 256,
+        ffn_dim: int = 13824,
+        num_layers: int = 40,
+        cross_attn_norm: bool = True,
+        qk_norm: Optional[str] = "rms_norm_across_heads",
+        eps: float = 1e-6,
+        image_dim: Optional[int] = None,
+        added_kv_proj_dim: Optional[int] = None,
+        rope_max_seq_len: int = 1024,
+        pos_embed_seq_len: Optional[int] = None,
+        vace_layers: List[int] = [0, 5, 10, 15, 20, 25, 30, 35],
+        vace_in_channels: int = 96,
+    ) -> None:
+        super().__init__()
+
+        inner_dim = num_attention_heads * attention_head_dim
+        out_channels = out_channels or in_channels
+
+        if max(vace_layers) >= num_layers:
+            raise ValueError(f"VACE layers {vace_layers} exceed the number of transformer layers {num_layers}.")
+        if 0 not in vace_layers:
+            raise ValueError("VACE layers must include layer 0.")
+
+        # 1. Patch & position embedding
+        self.rope = WanRotaryPosEmbed(attention_head_dim, patch_size, rope_max_seq_len)
+        self.patch_embedding = nn.Conv3d(in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
+        self.vace_patch_embedding = nn.Conv3d(vace_in_channels, inner_dim, kernel_size=patch_size, stride=patch_size)
+
+        # 2. Condition embeddings
+        # image_embedding_dim=1280 for I2V model
+        self.condition_embedder = WanTimeTextImageEmbedding(
+            dim=inner_dim,
+            time_freq_dim=freq_dim,
+            time_proj_dim=inner_dim * 6,
+            text_embed_dim=text_dim,
+            image_embed_dim=image_dim,
+            pos_embed_seq_len=pos_embed_seq_len,
+        )
+
+        # 3. Transformer blocks
+        self.blocks = nn.ModuleList(
+            [
+                WanTransformerBlock(
+                    inner_dim, ffn_dim, num_attention_heads, qk_norm, cross_attn_norm, eps, added_kv_proj_dim
+                )
+                for _ in range(num_layers)
+            ]
+        )
+
+        self.vace_blocks = nn.ModuleList(
+            [
+                WanVACETransformerBlock(
+                    inner_dim,
+                    ffn_dim,
+                    num_attention_heads,
+                    qk_norm,
+                    cross_attn_norm,
+                    eps,
+                    added_kv_proj_dim,
+                    apply_input_projection=i == 0,  # Layer 0 always has input projection and is in vace_layers
+                    apply_output_projection=True,
+                )
+                for i in range(len(vace_layers))
+            ]
+        )
+
+        # 4. Output norm & projection
+        self.norm_out = FP32LayerNorm(inner_dim, eps, elementwise_affine=False)
+        self.proj_out = nn.Linear(inner_dim, out_channels * math.prod(patch_size))
+        self.scale_shift_table = nn.Parameter(torch.randn(1, 2, inner_dim) / inner_dim**0.5)
+
+        self.gradient_checkpointing = False
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        timestep: torch.LongTensor,
+        encoder_hidden_states: torch.Tensor,
+        encoder_hidden_states_image: Optional[torch.Tensor] = None,
+        control_hidden_states: torch.Tensor = None,
+        control_hidden_states_scale: torch.Tensor = None,
+        return_dict: bool = True,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
+    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
+        if attention_kwargs is not None:
+            attention_kwargs = attention_kwargs.copy()
+            lora_scale = attention_kwargs.pop("scale", 1.0)
+        else:
+            lora_scale = 1.0
+
+        if USE_PEFT_BACKEND:
+            # weight the lora layers by setting `lora_scale` for each PEFT layer
+            scale_lora_layers(self, lora_scale)
+        else:
+            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+                logger.warning(
+                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+                )
+
+        batch_size, num_channels, num_frames, height, width = hidden_states.shape
+        p_t, p_h, p_w = self.config.patch_size
+        post_patch_num_frames = num_frames // p_t
+        post_patch_height = height // p_h
+        post_patch_width = width // p_w
+
+        if control_hidden_states_scale is None:
+            control_hidden_states_scale = control_hidden_states.new_ones(len(self.config.vace_layers))
+        control_hidden_states_scale = torch.unbind(control_hidden_states_scale)
+        if len(control_hidden_states_scale) != len(self.config.vace_layers):
+            raise ValueError(
+                f"Length of `control_hidden_states_scale` {len(control_hidden_states_scale)} should be "
+                f"equal to {len(self.config.vace_layers)}."
+            )
+
+        # 1. Rotary position embedding
+        rotary_emb = self.rope(hidden_states)
+
+        # 2. Patch embedding
+        hidden_states = self.patch_embedding(hidden_states)
+        hidden_states = hidden_states.flatten(2).transpose(1, 2)
+
+        control_hidden_states = self.vace_patch_embedding(control_hidden_states)
+        control_hidden_states = control_hidden_states.flatten(2).transpose(1, 2)
+        control_hidden_states_padding = control_hidden_states.new_zeros(
+            batch_size, hidden_states.size(1) - control_hidden_states.size(1), control_hidden_states.size(2)
+        )
+        control_hidden_states = torch.cat([control_hidden_states, control_hidden_states_padding], dim=1)
+
+        # 3. Time embedding
+        temb, timestep_proj, encoder_hidden_states, encoder_hidden_states_image = self.condition_embedder(
+            timestep, encoder_hidden_states, encoder_hidden_states_image
+        )
+        timestep_proj = timestep_proj.unflatten(1, (6, -1))
+
+        # 4. Image embedding
+        if encoder_hidden_states_image is not None:
+            encoder_hidden_states = torch.concat([encoder_hidden_states_image, encoder_hidden_states], dim=1)
+
+        # 5. Transformer blocks
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+            # Prepare VACE hints
+            control_hidden_states_list = []
+            for i, block in enumerate(self.vace_blocks):
+                conditioning_states, control_hidden_states = self._gradient_checkpointing_func(
+                    block, hidden_states, encoder_hidden_states, control_hidden_states, timestep_proj, rotary_emb
+                )
+                control_hidden_states_list.append((conditioning_states, control_hidden_states_scale[i]))
+            control_hidden_states_list = control_hidden_states_list[::-1]
+
+            for i, block in enumerate(self.blocks):
+                hidden_states = self._gradient_checkpointing_func(
+                    block, hidden_states, encoder_hidden_states, timestep_proj, rotary_emb
+                )
+                if i in self.config.vace_layers:
+                    control_hint, scale = control_hidden_states_list.pop()
+                    hidden_states = hidden_states + control_hint * scale
+        else:
+            # Prepare VACE hints
+            control_hidden_states_list = []
+            for i, block in enumerate(self.vace_blocks):
+                conditioning_states, control_hidden_states = block(
+                    hidden_states, encoder_hidden_states, control_hidden_states, timestep_proj, rotary_emb
+                )
+                control_hidden_states_list.append((conditioning_states, control_hidden_states_scale[i]))
+            control_hidden_states_list = control_hidden_states_list[::-1]
+
+            for i, block in enumerate(self.blocks):
+                hidden_states = block(hidden_states, encoder_hidden_states, timestep_proj, rotary_emb)
+                if i in self.config.vace_layers:
+                    control_hint, scale = control_hidden_states_list.pop()
+                    hidden_states = hidden_states + control_hint * scale
+
+        # 6. Output norm, projection & unpatchify
+        shift, scale = (self.scale_shift_table + temb.unsqueeze(1)).chunk(2, dim=1)
+
+        # Move the shift and scale tensors to the same device as hidden_states.
+        # When using multi-GPU inference via accelerate these will be on the
+        # first device rather than the last device, which hidden_states ends up
+        # on.
+        shift = shift.to(hidden_states.device)
+        scale = scale.to(hidden_states.device)
+
+        hidden_states = (self.norm_out(hidden_states.float()) * (1 + scale) + shift).type_as(hidden_states)
+        hidden_states = self.proj_out(hidden_states)
+
+        hidden_states = hidden_states.reshape(
+            batch_size, post_patch_num_frames, post_patch_height, post_patch_width, p_t, p_h, p_w, -1
+        )
+        hidden_states = hidden_states.permute(0, 7, 1, 4, 2, 5, 3, 6)
+        output = hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
+
+        if USE_PEFT_BACKEND:
+            # remove `lora_scale` from each PEFT layer
+            unscale_lora_layers(self, lora_scale)
+
+        if not return_dict:
+            return (output,)
+
+        return Transformer2DModelOutput(sample=output)
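The block below is a minimal CPU smoke test of the new WanVACETransformer3DModel. Only the constructor and forward() signatures come from the file above; the top-level import is implied by this release's __init__.py changes, and the tiny hyperparameters and tensor shapes are illustrative assumptions chosen so the sketch runs quickly, not Wan defaults.

import torch

from diffusers import WanVACETransformer3DModel

# Toy configuration (assumed values): inner_dim = 2 * 8 = 16.
model = WanVACETransformer3DModel(
    num_attention_heads=2,
    attention_head_dim=8,
    in_channels=16,
    out_channels=16,
    text_dim=32,
    ffn_dim=64,
    num_layers=2,
    vace_layers=[0],  # layer 0 is mandatory, see the __init__ validation above
    vace_in_channels=96,
)

batch = 1
hidden_states = torch.randn(batch, 16, 4, 16, 16)          # (B, C, F, H, W); F/H/W divisible by patch_size
control_hidden_states = torch.randn(batch, 96, 4, 16, 16)  # VACE conditioning with vace_in_channels channels
encoder_hidden_states = torch.randn(batch, 77, 32)         # (B, seq_len, text_dim)
timestep = torch.tensor([10])

sample = model(
    hidden_states=hidden_states,
    timestep=timestep,
    encoder_hidden_states=encoder_hidden_states,
    control_hidden_states=control_hidden_states,
).sample
print(sample.shape)  # torch.Size([1, 16, 4, 16, 16])

Note how the control branch mirrors the main branch: control_hidden_states must sit on the same frame/height/width grid as hidden_states, and a per-layer hint strength can be supplied through control_hidden_states_scale, one value per entry in vace_layers.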
diffusers/models/unets/unet_2d_blocks_flax.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -22,7 +22,7 @@ from ..resnet_flax import FlaxDownsample2D, FlaxResnetBlock2D, FlaxUpsample2D
 class FlaxCrossAttnDownBlock2D(nn.Module):
     r"""
     Cross Attention 2D Downsizing block - original architecture from Unet transformers:
-    https://
+    https://huggingface.co/papers/2103.06104
 
     Parameters:
         in_channels (:obj:`int`):
@@ -38,7 +38,7 @@ class FlaxCrossAttnDownBlock2D(nn.Module):
         add_downsample (:obj:`bool`, *optional*, defaults to `True`):
             Whether to add downsampling layer before each final output
         use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
-            enable memory efficient attention https://
+            enable memory efficient attention https://huggingface.co/papers/2112.05682
         split_head_dim (`bool`, *optional*, defaults to `False`):
             Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
             enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
@@ -169,7 +169,7 @@ class FlaxDownBlock2D(nn.Module):
 class FlaxCrossAttnUpBlock2D(nn.Module):
     r"""
     Cross Attention 2D Upsampling block - original architecture from Unet transformers:
-    https://
+    https://huggingface.co/papers/2103.06104
 
     Parameters:
         in_channels (:obj:`int`):
@@ -185,7 +185,7 @@ class FlaxCrossAttnUpBlock2D(nn.Module):
         add_upsample (:obj:`bool`, *optional*, defaults to `True`):
             Whether to add upsampling layer before each final output
         use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
-            enable memory efficient attention https://
+            enable memory efficient attention https://huggingface.co/papers/2112.05682
         split_head_dim (`bool`, *optional*, defaults to `False`):
             Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
             enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
@@ -324,7 +324,8 @@ class FlaxUpBlock2D(nn.Module):
 
 class FlaxUNetMidBlock2DCrossAttn(nn.Module):
     r"""
-    Cross Attention 2D Mid-level block - original architecture from Unet transformers:
+    Cross Attention 2D Mid-level block - original architecture from Unet transformers:
+    https://huggingface.co/papers/2103.06104
 
     Parameters:
         in_channels (:obj:`int`):
@@ -336,7 +337,7 @@ class FlaxUNetMidBlock2DCrossAttn(nn.Module):
         num_attention_heads (:obj:`int`, *optional*, defaults to 1):
             Number of attention heads of each spatial transformer block
         use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
-            enable memory efficient attention https://
+            enable memory efficient attention https://huggingface.co/papers/2112.05682
         split_head_dim (`bool`, *optional*, defaults to `False`):
             Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
             enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
diffusers/models/unets/unet_2d_condition.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -835,7 +835,7 @@ class UNet2DConditionModel(
         fn_recursive_set_attention_slice(module, reversed_slice_size)
 
     def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
-        r"""Enables the FreeU mechanism from https://
+        r"""Enables the FreeU mechanism from https://huggingface.co/papers/2309.11497.
 
         The suffixes after the scaling factors represent the stage blocks where they are being applied.
 
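Several hunks in this release only restore the FreeU paper link in enable_freeu docstrings. For context, the pipelines wrapping these UNets expose the same knobs; the sketch below is a hedged illustration in which the checkpoint name and the s1/s2/b1/b2 values (the FreeU authors' Stable Diffusion 1.x suggestions) are assumptions, not defaults enforced by diffusers.

import torch

from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)  # s*: skip-connection scales, b*: backbone boosts
image = pipe("an astronaut riding a horse on mars").images[0]
pipe.disable_freeu()  # restore the vanilla UNet path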
diffusers/models/unets/unet_2d_condition_flax.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -94,7 +94,7 @@ class FlaxUNet2DConditionModel(nn.Module, FlaxModelMixin, ConfigMixin):
             Whether to flip the sin to cos in the time embedding.
         freq_shift (`int`, *optional*, defaults to 0): The frequency shift to apply to the time embedding.
         use_memory_efficient_attention (`bool`, *optional*, defaults to `False`):
-            Enable memory efficient attention as described [here](https://
+            Enable memory efficient attention as described [here](https://huggingface.co/papers/2112.05682).
         split_head_dim (`bool`, *optional*, defaults to `False`):
             Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
             enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
diffusers/models/unets/unet_3d_condition.py
CHANGED
@@ -1,5 +1,5 @@
-# Copyright
-# Copyright
+# Copyright 2025 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved.
+# Copyright 2025 The ModelScope Team.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -470,7 +470,7 @@ class UNet3DConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin)
 
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
     def enable_freeu(self, s1, s2, b1, b2):
-        r"""Enables the FreeU mechanism from https://
+        r"""Enables the FreeU mechanism from https://huggingface.co/papers/2309.11497.
 
         The suffixes after the scaling factors represent the stage blocks where they are being applied.
 
diffusers/models/unets/unet_i2vgen_xl.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 Alibaba DAMO-VILAB and The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -154,7 +154,7 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
         # of that, we used `num_attention_heads` for arguments that actually denote attention head dimension. This
         # is why we ignore `num_attention_heads` and calculate it from `attention_head_dims` below.
         # This is still an incorrect way of calculating `num_attention_heads` but we need to stick to it
-        # without running proper
+        # without running proper deprecation cycles for the {down,mid,up} blocks which are a
         # part of the public API.
         num_attention_heads = attention_head_dim
 
@@ -434,7 +434,7 @@ class I2VGenXLUNet(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
 
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
    def enable_freeu(self, s1, s2, b1, b2):
-        r"""Enables the FreeU mechanism from https://
+        r"""Enables the FreeU mechanism from https://huggingface.co/papers/2309.11497.
 
         The suffixes after the scaling factors represent the stage blocks where they are being applied.
 
diffusers/models/unets/unet_motion_model.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -1873,7 +1873,7 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
 
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
     def enable_freeu(self, s1: float, s2: float, b1: float, b2: float) -> None:
-        r"""Enables the FreeU mechanism from https://
+        r"""Enables the FreeU mechanism from https://huggingface.co/papers/2309.11497.
 
         The suffixes after the scaling factors represent the stage blocks where they are being applied.
 
diffusers/models/upsampling.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -358,7 +358,7 @@ class KUpsample2D(nn.Module):
 
 class CogVideoXUpsample3D(nn.Module):
     r"""
-    A 3D Upsample layer using in CogVideoX by Tsinghua University & ZhipuAI # Todo: Wait for paper
+    A 3D Upsample layer using in CogVideoX by Tsinghua University & ZhipuAI # Todo: Wait for paper release.
 
     Args:
         in_channels (`int`):
diffusers/models/vae_flax.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright
+# Copyright 2025 The HuggingFace Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -769,7 +769,7 @@ class FlaxAutoencoderKL(nn.Module, FlaxModelMixin, ConfigMixin):
             model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
             diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
             / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
-            Synthesis with Latent Diffusion Models](https://
+            Synthesis with Latent Diffusion Models](https://huggingface.co/papers/2112.10752) paper.
         dtype (`jnp.dtype`, *optional*, defaults to `jnp.float32`):
            The `dtype` of the parameters.
     """
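The scaling_factor convention documented in this hunk is shared with the PyTorch autoencoders, so a short round-trip sketch may help; the checkpoint name and input shape below are assumptions for illustration.

import torch

from diffusers import AutoencoderKL

vae = AutoencoderKL.from_pretrained("stabilityai/sd-vae-ft-mse")
images = torch.randn(1, 3, 256, 256)

with torch.no_grad():
    latents = vae.encode(images).latent_dist.sample()
    z = latents * vae.config.scaling_factor  # z = z * scaling_factor before the diffusion model
    decoded = vae.decode(z / vae.config.scaling_factor).sample  # z = 1 / scaling_factor * z when decoding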
diffusers/models/vq_model.py
CHANGED