diffusers 0.32.1__py3-none-any.whl → 0.33.0__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to their public registry. It is provided for informational purposes only and reflects the changes between those versions.
- diffusers/__init__.py +186 -3
- diffusers/configuration_utils.py +40 -12
- diffusers/dependency_versions_table.py +9 -2
- diffusers/hooks/__init__.py +9 -0
- diffusers/hooks/faster_cache.py +653 -0
- diffusers/hooks/group_offloading.py +793 -0
- diffusers/hooks/hooks.py +236 -0
- diffusers/hooks/layerwise_casting.py +245 -0
- diffusers/hooks/pyramid_attention_broadcast.py +311 -0
- diffusers/loaders/__init__.py +6 -0
- diffusers/loaders/ip_adapter.py +38 -30
- diffusers/loaders/lora_base.py +198 -28
- diffusers/loaders/lora_conversion_utils.py +679 -44
- diffusers/loaders/lora_pipeline.py +1963 -801
- diffusers/loaders/peft.py +169 -84
- diffusers/loaders/single_file.py +17 -2
- diffusers/loaders/single_file_model.py +53 -5
- diffusers/loaders/single_file_utils.py +653 -75
- diffusers/loaders/textual_inversion.py +9 -9
- diffusers/loaders/transformer_flux.py +8 -9
- diffusers/loaders/transformer_sd3.py +120 -39
- diffusers/loaders/unet.py +22 -32
- diffusers/models/__init__.py +22 -0
- diffusers/models/activations.py +9 -9
- diffusers/models/attention.py +0 -1
- diffusers/models/attention_processor.py +163 -25
- diffusers/models/auto_model.py +169 -0
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
- diffusers/models/autoencoders/autoencoder_dc.py +106 -4
- diffusers/models/autoencoders/autoencoder_kl.py +0 -4
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
- diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
- diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
- diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
- diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
- diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
- diffusers/models/autoencoders/vae.py +31 -141
- diffusers/models/autoencoders/vq_model.py +3 -0
- diffusers/models/cache_utils.py +108 -0
- diffusers/models/controlnets/__init__.py +1 -0
- diffusers/models/controlnets/controlnet.py +3 -8
- diffusers/models/controlnets/controlnet_flux.py +14 -42
- diffusers/models/controlnets/controlnet_sd3.py +58 -34
- diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
- diffusers/models/controlnets/controlnet_union.py +27 -18
- diffusers/models/controlnets/controlnet_xs.py +7 -46
- diffusers/models/controlnets/multicontrolnet_union.py +196 -0
- diffusers/models/embeddings.py +18 -7
- diffusers/models/model_loading_utils.py +122 -80
- diffusers/models/modeling_flax_pytorch_utils.py +1 -1
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +617 -272
- diffusers/models/normalization.py +67 -14
- diffusers/models/resnet.py +1 -1
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
- diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
- diffusers/models/transformers/consisid_transformer_3d.py +789 -0
- diffusers/models/transformers/dit_transformer_2d.py +5 -19
- diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
- diffusers/models/transformers/latte_transformer_3d.py +20 -15
- diffusers/models/transformers/lumina_nextdit2d.py +3 -1
- diffusers/models/transformers/pixart_transformer_2d.py +4 -19
- diffusers/models/transformers/prior_transformer.py +5 -1
- diffusers/models/transformers/sana_transformer.py +144 -40
- diffusers/models/transformers/stable_audio_transformer.py +5 -20
- diffusers/models/transformers/transformer_2d.py +7 -22
- diffusers/models/transformers/transformer_allegro.py +9 -17
- diffusers/models/transformers/transformer_cogview3plus.py +6 -17
- diffusers/models/transformers/transformer_cogview4.py +462 -0
- diffusers/models/transformers/transformer_easyanimate.py +527 -0
- diffusers/models/transformers/transformer_flux.py +68 -110
- diffusers/models/transformers/transformer_hunyuan_video.py +409 -49
- diffusers/models/transformers/transformer_ltx.py +53 -35
- diffusers/models/transformers/transformer_lumina2.py +548 -0
- diffusers/models/transformers/transformer_mochi.py +6 -17
- diffusers/models/transformers/transformer_omnigen.py +469 -0
- diffusers/models/transformers/transformer_sd3.py +56 -86
- diffusers/models/transformers/transformer_temporal.py +5 -11
- diffusers/models/transformers/transformer_wan.py +469 -0
- diffusers/models/unets/unet_1d.py +3 -1
- diffusers/models/unets/unet_2d.py +21 -20
- diffusers/models/unets/unet_2d_blocks.py +19 -243
- diffusers/models/unets/unet_2d_condition.py +4 -6
- diffusers/models/unets/unet_3d_blocks.py +14 -127
- diffusers/models/unets/unet_3d_condition.py +8 -12
- diffusers/models/unets/unet_i2vgen_xl.py +5 -13
- diffusers/models/unets/unet_kandinsky3.py +0 -4
- diffusers/models/unets/unet_motion_model.py +20 -114
- diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
- diffusers/models/unets/unet_stable_cascade.py +8 -35
- diffusers/models/unets/uvit_2d.py +1 -4
- diffusers/optimization.py +2 -2
- diffusers/pipelines/__init__.py +57 -8
- diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
- diffusers/pipelines/amused/pipeline_amused.py +15 -2
- diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
- diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
- diffusers/pipelines/auto_pipeline.py +35 -14
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
- diffusers/pipelines/cogview4/__init__.py +49 -0
- diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
- diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
- diffusers/pipelines/cogview4/pipeline_output.py +21 -0
- diffusers/pipelines/consisid/__init__.py +49 -0
- diffusers/pipelines/consisid/consisid_utils.py +357 -0
- diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
- diffusers/pipelines/consisid/pipeline_output.py +20 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
- diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
- diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +15 -2
- diffusers/pipelines/easyanimate/__init__.py +52 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
- diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
- diffusers/pipelines/flux/pipeline_flux.py +53 -21
- diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
- diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
- diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
- diffusers/pipelines/free_noise_utils.py +3 -3
- diffusers/pipelines/hunyuan_video/__init__.py +4 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
- diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
- diffusers/pipelines/kolors/text_encoder.py +7 -34
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
- diffusers/pipelines/latte/pipeline_latte.py +36 -7
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
- diffusers/pipelines/ltx/__init__.py +2 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
- diffusers/pipelines/lumina/__init__.py +2 -2
- diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
- diffusers/pipelines/lumina2/__init__.py +48 -0
- diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
- diffusers/pipelines/marigold/__init__.py +2 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
- diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
- diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
- diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
- diffusers/pipelines/omnigen/__init__.py +50 -0
- diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
- diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
- diffusers/pipelines/onnx_utils.py +5 -3
- diffusers/pipelines/pag/pag_utils.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
- diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
- diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
- diffusers/pipelines/pia/pipeline_pia.py +13 -1
- diffusers/pipelines/pipeline_flax_utils.py +7 -7
- diffusers/pipelines/pipeline_loading_utils.py +193 -83
- diffusers/pipelines/pipeline_utils.py +221 -106
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
- diffusers/pipelines/sana/__init__.py +2 -0
- diffusers/pipelines/sana/pipeline_sana.py +183 -58
- diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
- diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
- diffusers/pipelines/shap_e/renderer.py +6 -6
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
- diffusers/pipelines/transformers_loading_utils.py +121 -0
- diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
- diffusers/pipelines/wan/__init__.py +51 -0
- diffusers/pipelines/wan/pipeline_output.py +20 -0
- diffusers/pipelines/wan/pipeline_wan.py +593 -0
- diffusers/pipelines/wan/pipeline_wan_i2v.py +722 -0
- diffusers/pipelines/wan/pipeline_wan_video2video.py +725 -0
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
- diffusers/quantizers/auto.py +5 -1
- diffusers/quantizers/base.py +5 -9
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
- diffusers/quantizers/bitsandbytes/utils.py +30 -20
- diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
- diffusers/quantizers/gguf/utils.py +4 -2
- diffusers/quantizers/quantization_config.py +59 -4
- diffusers/quantizers/quanto/__init__.py +1 -0
- diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
- diffusers/quantizers/quanto/utils.py +60 -0
- diffusers/quantizers/torchao/__init__.py +1 -1
- diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
- diffusers/schedulers/__init__.py +2 -1
- diffusers/schedulers/scheduling_consistency_models.py +1 -2
- diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
- diffusers/schedulers/scheduling_ddpm.py +2 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
- diffusers/schedulers/scheduling_edm_euler.py +45 -10
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
- diffusers/schedulers/scheduling_heun_discrete.py +1 -1
- diffusers/schedulers/scheduling_lcm.py +1 -2
- diffusers/schedulers/scheduling_lms_discrete.py +1 -1
- diffusers/schedulers/scheduling_repaint.py +5 -1
- diffusers/schedulers/scheduling_scm.py +265 -0
- diffusers/schedulers/scheduling_tcd.py +1 -2
- diffusers/schedulers/scheduling_utils.py +2 -1
- diffusers/training_utils.py +14 -7
- diffusers/utils/__init__.py +10 -2
- diffusers/utils/constants.py +13 -1
- diffusers/utils/deprecation_utils.py +1 -1
- diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
- diffusers/utils/dummy_gguf_objects.py +17 -0
- diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
- diffusers/utils/dummy_pt_objects.py +233 -0
- diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
- diffusers/utils/dummy_torchao_objects.py +17 -0
- diffusers/utils/dynamic_modules_utils.py +1 -1
- diffusers/utils/export_utils.py +28 -3
- diffusers/utils/hub_utils.py +52 -102
- diffusers/utils/import_utils.py +121 -221
- diffusers/utils/loading_utils.py +14 -1
- diffusers/utils/logging.py +1 -2
- diffusers/utils/peft_utils.py +6 -14
- diffusers/utils/remote_utils.py +425 -0
- diffusers/utils/source_code_parsing_utils.py +52 -0
- diffusers/utils/state_dict_utils.py +15 -1
- diffusers/utils/testing_utils.py +243 -13
- diffusers/utils/torch_utils.py +10 -0
- diffusers/utils/typing_utils.py +91 -0
- diffusers/video_processor.py +1 -1
- {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/METADATA +76 -44
- diffusers-0.33.0.dist-info/RECORD +608 -0
- {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/WHEEL +1 -1
- diffusers-0.32.1.dist-info/RECORD +0 -550
- {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/LICENSE +0 -0
- {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.32.1.dist-info → diffusers-0.33.0.dist-info}/top_level.txt +0 -0

--- a/diffusers/models/attention_processor.py
+++ b/diffusers/models/attention_processor.py
@@ -213,7 +213,9 @@ class Attention(nn.Module):
             self.norm_q = LpNorm(p=2, dim=-1, eps=eps)
             self.norm_k = LpNorm(p=2, dim=-1, eps=eps)
         else:
-            raise ValueError(
+            raise ValueError(
+                f"unknown qk_norm: {qk_norm}. Should be one of None, 'layer_norm', 'fp32_layer_norm', 'layer_norm_across_heads', 'rms_norm', 'rms_norm_across_heads', 'l2'."
+            )
 
         if cross_attention_norm is None:
             self.norm_cross = None
@@ -272,12 +274,20 @@ class Attention(nn.Module):
             self.to_add_out = None
 
         if qk_norm is not None and added_kv_proj_dim is not None:
-            if qk_norm == "fp32_layer_norm":
+            if qk_norm == "layer_norm":
+                self.norm_added_q = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+                self.norm_added_k = nn.LayerNorm(dim_head, eps=eps, elementwise_affine=elementwise_affine)
+            elif qk_norm == "fp32_layer_norm":
                 self.norm_added_q = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
                 self.norm_added_k = FP32LayerNorm(dim_head, elementwise_affine=False, bias=False, eps=eps)
             elif qk_norm == "rms_norm":
                 self.norm_added_q = RMSNorm(dim_head, eps=eps)
                 self.norm_added_k = RMSNorm(dim_head, eps=eps)
+            elif qk_norm == "rms_norm_across_heads":
+                # Wan applies qk norm across all heads
+                # Wan also doesn't apply a q norm
+                self.norm_added_q = None
+                self.norm_added_k = RMSNorm(dim_head * kv_heads, eps=eps)
             else:
                 raise ValueError(
                     f"unknown qk_norm: {qk_norm}. Should be one of `None,'layer_norm','fp32_layer_norm','rms_norm'`"
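
The hunk above adds two qk-norm choices for the added key/value projections: plain `layer_norm`, and `rms_norm_across_heads`, which Wan uses to normalize the added keys across all heads while leaving the added queries unnormalized. A minimal sketch of requesting the new option follows; the constructor argument names (`query_dim`, `heads`, `kv_heads`, `dim_head`, `added_kv_proj_dim`, `qk_norm`) and the dimension values are assumptions made for illustration, not taken from this diff.

```python
# Illustrative only: argument names and values are assumed, not part of this diff.
from diffusers.models.attention_processor import Attention

attn = Attention(
    query_dim=1536,
    heads=12,
    kv_heads=12,
    dim_head=128,
    added_kv_proj_dim=1536,
    qk_norm="rms_norm_across_heads",  # new in 0.33.0
)
print(attn.norm_added_q)  # None -- the Wan-style branch applies no q norm
print(attn.norm_added_k)  # RMSNorm over dim_head * kv_heads features
```
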
@@ -297,7 +307,10 @@ class Attention(nn.Module):
         self.set_processor(processor)
 
     def set_use_xla_flash_attention(
-        self, use_xla_flash_attention: bool, partition_spec: Optional[Tuple[Optional[str], ...]] = None
+        self,
+        use_xla_flash_attention: bool,
+        partition_spec: Optional[Tuple[Optional[str], ...]] = None,
+        is_flux=False,
     ) -> None:
         r"""
         Set whether to use xla flash attention from `torch_xla` or not.
@@ -316,7 +329,10 @@ class Attention(nn.Module):
             elif is_spmd() and is_torch_xla_version("<", "2.4"):
                 raise "flash attention pallas kernel using SPMD is supported from torch_xla version 2.4"
             else:
-                processor = XLAFlashAttnProcessor2_0(partition_spec)
+                if is_flux:
+                    processor = XLAFluxFlashAttnProcessor2_0(partition_spec)
+                else:
+                    processor = XLAFlashAttnProcessor2_0(partition_spec)
         else:
             processor = (
                 AttnProcessor2_0() if hasattr(F, "scaled_dot_product_attention") and self.scale_qk else AttnProcessor()
@@ -399,11 +415,12 @@ class Attention(nn.Module):
         else:
             try:
                 # Make sure we can run the memory efficient attention
-                _ = xformers.ops.memory_efficient_attention(
-                    torch.randn((1, 2, 40), device="cuda"),
-                    torch.randn((1, 2, 40), device="cuda"),
-                    torch.randn((1, 2, 40), device="cuda"),
-                )
+                dtype = None
+                if attention_op is not None:
+                    op_fw, op_bw = attention_op
+                    dtype, *_ = op_fw.SUPPORTED_DTYPES
+                q = torch.randn((1, 2, 40), device="cuda", dtype=dtype)
+                _ = xformers.ops.memory_efficient_attention(q, q, q)
             except Exception as e:
                 raise e
 
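
The availability probe above now builds its test tensor in a dtype taken from the requested xformers op pair (`attention_op` is a `(forward_op, backward_op)` tuple, and the forward op advertises `SUPPORTED_DTYPES`), so ops that only support half precision no longer fail against an fp32 probe. Below is a hedged sketch of passing such an op pair through the public API; the op name comes from the xformers package, and the checkpoint id is reused from the `AutoModel` docstring later in this diff — both are assumptions, not part of this hunk.

```python
# Sketch under assumptions: a CUDA GPU and the xformers package are available.
import torch
import xformers.ops
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet", torch_dtype=torch.float16
).to("cuda")
# Flash-attention op pair only supports fp16/bf16, which is why the dtype-aware probe matters.
unet.enable_xformers_memory_efficient_attention(
    attention_op=xformers.ops.MemoryEfficientAttentionFlashAttentionOp
)
```
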
@@ -724,10 +741,14 @@ class Attention(nn.Module):
 
         if out_dim == 3:
             if attention_mask.shape[0] < batch_size * head_size:
-                attention_mask = attention_mask.repeat_interleave(head_size, dim=0)
+                attention_mask = attention_mask.repeat_interleave(
+                    head_size, dim=0, output_size=attention_mask.shape[0] * head_size
+                )
         elif out_dim == 4:
             attention_mask = attention_mask.unsqueeze(1)
-            attention_mask = attention_mask.repeat_interleave(head_size, dim=1)
+            attention_mask = attention_mask.repeat_interleave(
+                head_size, dim=1, output_size=attention_mask.shape[1] * head_size
+            )
 
         return attention_mask
 
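
Several hunks in this release (here, and in the Stable Audio and DC-AE changes further down) pass `output_size` to `repeat_interleave`. The result is identical; the extra argument only pre-declares the length of the repeated dimension so PyTorch does not have to derive it, which helps under compiled or graph execution. A quick check:

```python
import torch

attention_mask = torch.randn(2, 1, 77)
head_size = 8

a = attention_mask.repeat_interleave(head_size, dim=0)
b = attention_mask.repeat_interleave(head_size, dim=0, output_size=attention_mask.shape[0] * head_size)

assert torch.equal(a, b)    # same values
assert b.shape[0] == 2 * 8  # output_size matches the derived length
```
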
@@ -899,7 +920,7 @@ class SanaMultiscaleLinearAttention(nn.Module):
         scores = torch.matmul(key.transpose(-1, -2), query)
         scores = scores.to(dtype=torch.float32)
         scores = scores / (torch.sum(scores, dim=2, keepdim=True) + self.eps)
-        hidden_states = torch.matmul(value, scores)
+        hidden_states = torch.matmul(value, scores.to(value.dtype))
         return hidden_states
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
@@ -1401,7 +1422,7 @@ class JointAttnProcessor2_0:
 
     def __init__(self):
         if not hasattr(F, "scaled_dot_product_attention"):
-            raise ImportError("
+            raise ImportError("JointAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
 
     def __call__(
         self,
@@ -2321,6 +2342,7 @@ class FluxAttnProcessor2_0:
         hidden_states = F.scaled_dot_product_attention(
             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
         )
+
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
 
@@ -2522,6 +2544,7 @@ class FusedFluxAttnProcessor2_0:
             key = apply_rotary_emb(key, image_rotary_emb)
 
         hidden_states = F.scaled_dot_product_attention(query, key, value, dropout_p=0.0, is_causal=False)
+
         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
         hidden_states = hidden_states.to(query.dtype)
 
@@ -2771,9 +2794,8 @@ class FluxIPAdapterJointAttnProcessor2_0(torch.nn.Module):
 
             # IP-adapter
             ip_query = hidden_states_query_proj
-            ip_attn_output =
-
-            # TODO: support for multiple adapters
+            ip_attn_output = torch.zeros_like(hidden_states)
+
             for current_ip_hidden_states, scale, to_k_ip, to_v_ip in zip(
                 ip_hidden_states, self.scale, self.to_k_ip, self.to_v_ip
             ):
@@ -2784,12 +2806,14 @@ class FluxIPAdapterJointAttnProcessor2_0(torch.nn.Module):
                 ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
                 # the output of sdp = (batch, num_heads, seq_len, head_dim)
                 # TODO: add support for attn.scale when we move to Torch 2.1
-
+                current_ip_hidden_states = F.scaled_dot_product_attention(
                     ip_query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
                 )
-
-
-
+                current_ip_hidden_states = current_ip_hidden_states.transpose(1, 2).reshape(
+                    batch_size, -1, attn.heads * head_dim
+                )
+                current_ip_hidden_states = current_ip_hidden_states.to(ip_query.dtype)
+                ip_attn_output += scale * current_ip_hidden_states
 
             return hidden_states, encoder_hidden_states, ip_attn_output
         else:
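
Together with the `zeros_like` initialization two hunks up, the Flux IP-adapter path now accumulates a scale-weighted sum of every adapter's attention output instead of keeping only the last one. The reduction, stripped of the attention math and shown on purely illustrative tensors:

```python
import torch

hidden_states = torch.randn(1, 512, 3072)  # shapes are illustrative
adapter_outputs = [torch.randn_like(hidden_states) for _ in range(2)]
scales = [0.6, 0.4]

ip_attn_output = torch.zeros_like(hidden_states)
for current_ip_hidden_states, scale in zip(adapter_outputs, scales):
    ip_attn_output += scale * current_ip_hidden_states
```
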
@@ -2818,9 +2842,7 @@ class CogVideoXAttnProcessor2_0:
 
         hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
 
-        batch_size, sequence_length, _ = (
-            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
-        )
+        batch_size, sequence_length, _ = hidden_states.shape
 
         if attention_mask is not None:
             attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
@@ -3148,6 +3170,11 @@ class AttnProcessorNPU:
             # scaled_dot_product_attention expects attention_mask shape to be
             # (batch, heads, source_length, target_length)
             attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+            attention_mask = attention_mask.repeat(1, 1, hidden_states.shape[1], 1)
+            if attention_mask.dtype == torch.bool:
+                attention_mask = torch.logical_not(attention_mask.bool())
+            else:
+                attention_mask = attention_mask.bool()
 
         if attn.group_norm is not None:
             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
@@ -3422,6 +3449,106 @@ class XLAFlashAttnProcessor2_0:
         return hidden_states
 
 
+class XLAFluxFlashAttnProcessor2_0:
+    r"""
+    Processor for implementing scaled dot-product attention with pallas flash attention kernel if using `torch_xla`.
+    """
+
+    def __init__(self, partition_spec: Optional[Tuple[Optional[str], ...]] = None):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError(
+                "XLAFlashAttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+            )
+        if is_torch_xla_version("<", "2.3"):
+            raise ImportError("XLA flash attention requires torch_xla version >= 2.3.")
+        if is_spmd() and is_torch_xla_version("<", "2.4"):
+            raise ImportError("SPMD support for XLA flash attention needs torch_xla version >= 2.4.")
+        self.partition_spec = partition_spec
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        image_rotary_emb: Optional[torch.Tensor] = None,
+    ) -> torch.FloatTensor:
+        batch_size, _, _ = hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+
+        # `sample` projections.
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
+        # the attention in FluxSingleTransformerBlock does not use `encoder_hidden_states`
+        if encoder_hidden_states is not None:
+            # `context` projections.
+            encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+            encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+            encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+            encoder_hidden_states_query_proj = encoder_hidden_states_query_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_key_proj = encoder_hidden_states_key_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+            encoder_hidden_states_value_proj = encoder_hidden_states_value_proj.view(
+                batch_size, -1, attn.heads, head_dim
+            ).transpose(1, 2)
+
+            if attn.norm_added_q is not None:
+                encoder_hidden_states_query_proj = attn.norm_added_q(encoder_hidden_states_query_proj)
+            if attn.norm_added_k is not None:
+                encoder_hidden_states_key_proj = attn.norm_added_k(encoder_hidden_states_key_proj)
+
+            # attention
+            query = torch.cat([encoder_hidden_states_query_proj, query], dim=2)
+            key = torch.cat([encoder_hidden_states_key_proj, key], dim=2)
+            value = torch.cat([encoder_hidden_states_value_proj, value], dim=2)
+
+        if image_rotary_emb is not None:
+            from .embeddings import apply_rotary_emb
+
+            query = apply_rotary_emb(query, image_rotary_emb)
+            key = apply_rotary_emb(key, image_rotary_emb)
+
+        query /= math.sqrt(head_dim)
+        hidden_states = flash_attention(query, key, value, causal=False)
+
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        if encoder_hidden_states is not None:
+            encoder_hidden_states, hidden_states = (
+                hidden_states[:, : encoder_hidden_states.shape[1]],
+                hidden_states[:, encoder_hidden_states.shape[1] :],
+            )
+
+            # linear proj
+            hidden_states = attn.to_out[0](hidden_states)
+            # dropout
+            hidden_states = attn.to_out[1](hidden_states)
+
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+            return hidden_states, encoder_hidden_states
+        else:
+            return hidden_states
+
+
 class MochiVaeAttnProcessor2_0:
     r"""
     Attention processor used in Mochi VAE.
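
The new `XLAFluxFlashAttnProcessor2_0` is what `set_use_xla_flash_attention(..., is_flux=True)` from the earlier hunk installs: it runs the joint text/image projections of Flux attention blocks before calling the pallas `flash_attention` kernel. A hedged sketch of switching an already-loaded Flux transformer over on a TPU host (torch_xla >= 2.3 assumed; the helper name is hypothetical):

```python
import torch.nn as nn


def use_xla_flux_flash_attention(transformer: nn.Module) -> None:
    # Walk the module tree and flip every attention block that exposes the
    # setter added in this release over to the Flux-specific XLA processor.
    for module in transformer.modules():
        if hasattr(module, "set_use_xla_flash_attention"):
            module.set_use_xla_flash_attention(
                use_xla_flash_attention=True, partition_spec=None, is_flux=True
            )
```
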
@@ -3583,8 +3710,10 @@ class StableAudioAttnProcessor2_0:
         if kv_heads != attn.heads:
             # if GQA or MQA, repeat the key/value heads to reach the number of query heads.
             heads_per_kv_head = attn.heads // kv_heads
-            key = torch.repeat_interleave(key, heads_per_kv_head, dim=1)
-            value = torch.repeat_interleave(value, heads_per_kv_head, dim=1)
+            key = torch.repeat_interleave(key, heads_per_kv_head, dim=1, output_size=key.shape[1] * heads_per_kv_head)
+            value = torch.repeat_interleave(
+                value, heads_per_kv_head, dim=1, output_size=value.shape[1] * heads_per_kv_head
+            )
 
         if attn.norm_q is not None:
             query = attn.norm_q(query)
@@ -4839,6 +4968,8 @@ class IPAdapterAttnProcessor(nn.Module):
                 )
             else:
                 for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
+                    if mask is None:
+                        continue
                     if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
                         raise ValueError(
                             "Each element of the ip_adapter_masks array should be a tensor with shape "
@@ -5056,6 +5187,8 @@ class IPAdapterAttnProcessor2_0(torch.nn.Module):
                 )
             else:
                 for index, (mask, scale, ip_state) in enumerate(zip(ip_adapter_masks, self.scale, ip_hidden_states)):
+                    if mask is None:
+                        continue
                     if not isinstance(mask, torch.Tensor) or mask.ndim != 4:
                         raise ValueError(
                             "Each element of the ip_adapter_masks array should be a tensor with shape "
@@ -5887,6 +6020,11 @@ class SanaLinearAttnProcessor2_0:
         key = attn.to_k(encoder_hidden_states)
         value = attn.to_v(encoder_hidden_states)
 
+        if attn.norm_q is not None:
+            query = attn.norm_q(query)
+        if attn.norm_k is not None:
+            key = attn.norm_k(key)
+
         query = query.transpose(1, 2).unflatten(1, (attn.heads, -1))
         key = key.transpose(1, 2).unflatten(1, (attn.heads, -1)).transpose(2, 3)
         value = value.transpose(1, 2).unflatten(1, (attn.heads, -1))

--- /dev/null
+++ b/diffusers/models/auto_model.py
@@ -0,0 +1,169 @@
+# Copyright 2025 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import importlib
+import os
+from typing import Optional, Union
+
+from huggingface_hub.utils import validate_hf_hub_args
+
+from ..configuration_utils import ConfigMixin
+
+
+class AutoModel(ConfigMixin):
+    config_name = "config.json"
+
+    def __init__(self, *args, **kwargs):
+        raise EnvironmentError(
+            f"{self.__class__.__name__} is designed to be instantiated "
+            f"using the `{self.__class__.__name__}.from_pretrained(pretrained_model_name_or_path)` or "
+            f"`{self.__class__.__name__}.from_pipe(pipeline)` methods."
+        )
+
+    @classmethod
+    @validate_hf_hub_args
+    def from_pretrained(cls, pretrained_model_or_path: Optional[Union[str, os.PathLike]] = None, **kwargs):
+        r"""
+        Instantiate a pretrained PyTorch model from a pretrained model configuration.
+
+        The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To
+        train the model, set it back in training mode with `model.train()`.
+
+        Parameters:
+            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
+                Can be either:
+
+                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+                      the Hub.
+                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+                      with [`~ModelMixin.save_pretrained`].
+
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+                is not used.
+            torch_dtype (`str` or `torch.dtype`, *optional*):
+                Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
+                dtype is automatically derived from the model's weights.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            output_loading_info (`bool`, *optional*, defaults to `False`):
+                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
+            local_files_only(`bool`, *optional*, defaults to `False`):
+                Whether to only load local model weights and configuration files or not. If set to `True`, the model
+                won't be downloaded from the Hub.
+            token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+                `diffusers-cli login` (stored in `~/.huggingface`) is used.
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+                allowed by Git.
+            from_flax (`bool`, *optional*, defaults to `False`):
+                Load the model weights from a Flax checkpoint save file.
+            subfolder (`str`, *optional*, defaults to `""`):
+                The subfolder location of a model file within a larger model repository on the Hub or locally.
+            mirror (`str`, *optional*):
+                Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
+                guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
+                information.
+            device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
+                A map that specifies where each submodule should go. It doesn't need to be defined for each
+                parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
+                same device. Defaults to `None`, meaning that the model will be loaded on CPU.
+
+                Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
+                more information about each option see [designing a device
+                map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
+            max_memory (`Dict`, *optional*):
+                A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
+                each GPU and the available CPU RAM if unset.
+            offload_folder (`str` or `os.PathLike`, *optional*):
+                The path to offload weights if `device_map` contains the value `"disk"`.
+            offload_state_dict (`bool`, *optional*):
+                If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
+                the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
+                when there is some disk offload.
+            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                Speed up model loading only loading the pretrained weights and not initializing the weights. This also
+                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+                argument to `True` will raise an error.
+            variant (`str`, *optional*):
+                Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when
+                loading `from_flax`.
+            use_safetensors (`bool`, *optional*, defaults to `None`):
+                If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
+                `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
+                weights. If set to `False`, `safetensors` weights are not loaded.
+            disable_mmap ('bool', *optional*, defaults to 'False'):
+                Whether to disable mmap when loading a Safetensors model. This option can perform better when the model
+                is on a network mount or hard drive, which may not handle the seeky-ness of mmap very well.
+
+        <Tip>
+
+        To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
+        `huggingface-cli login`. You can also activate the special
+        ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
+        firewalled environment.
+
+        </Tip>
+
+        Example:
+
+        ```py
+        from diffusers import AutoModel
+
+        unet = AutoModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
+        ```
+
+        If you get the error message below, you need to finetune the weights for your downstream task:
+
+        ```bash
+        Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
+        - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
+        You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
+        ```
+        """
+        cache_dir = kwargs.pop("cache_dir", None)
+        force_download = kwargs.pop("force_download", False)
+        proxies = kwargs.pop("proxies", None)
+        token = kwargs.pop("token", None)
+        local_files_only = kwargs.pop("local_files_only", False)
+        revision = kwargs.pop("revision", None)
+        subfolder = kwargs.pop("subfolder", None)
+
+        load_config_kwargs = {
+            "cache_dir": cache_dir,
+            "force_download": force_download,
+            "proxies": proxies,
+            "token": token,
+            "local_files_only": local_files_only,
+            "revision": revision,
+            "subfolder": subfolder,
+        }
+
+        config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
+        orig_class_name = config["_class_name"]
+
+        library = importlib.import_module("diffusers")
+
+        model_cls = getattr(library, orig_class_name, None)
+        if model_cls is None:
+            raise ValueError(f"AutoModel can't find a model linked to {orig_class_name}.")
+
+        kwargs = {**load_config_kwargs, **kwargs}
+        return model_cls.from_pretrained(pretrained_model_or_path, **kwargs)
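
`AutoModel` resolves the concrete class from the `_class_name` field of the (sub)folder's `config.json` and forwards everything else to that class's own `from_pretrained`. Extending the docstring's example with one of the documented kwargs:

```python
import torch
from diffusers import AutoModel

unet = AutoModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    subfolder="unet",
    torch_dtype=torch.float16,
)
print(type(unet).__name__)  # UNet2DConditionModel, resolved from config["_class_name"]
```
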

--- a/diffusers/models/autoencoders/__init__.py
+++ b/diffusers/models/autoencoders/__init__.py
@@ -5,8 +5,10 @@ from .autoencoder_kl_allegro import AutoencoderKLAllegro
 from .autoencoder_kl_cogvideox import AutoencoderKLCogVideoX
 from .autoencoder_kl_hunyuan_video import AutoencoderKLHunyuanVideo
 from .autoencoder_kl_ltx import AutoencoderKLLTXVideo
+from .autoencoder_kl_magvit import AutoencoderKLMagvit
 from .autoencoder_kl_mochi import AutoencoderKLMochi
 from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
+from .autoencoder_kl_wan import AutoencoderKLWan
 from .autoencoder_oobleck import AutoencoderOobleck
 from .autoencoder_tiny import AutoencoderTiny
 from .consistency_decoder_vae import ConsistencyDecoderVAE

--- a/diffusers/models/autoencoders/autoencoder_dc.py
+++ b/diffusers/models/autoencoders/autoencoder_dc.py
@@ -190,7 +190,7 @@ class DCUpBlock2d(nn.Module):
         x = F.pixel_shuffle(x, self.factor)
 
         if self.shortcut:
-            y = hidden_states.repeat_interleave(self.repeats, dim=1)
+            y = hidden_states.repeat_interleave(self.repeats, dim=1, output_size=hidden_states.shape[1] * self.repeats)
             y = F.pixel_shuffle(y, self.factor)
             hidden_states = x + y
         else:
@@ -361,7 +361,9 @@ class Decoder(nn.Module):
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if self.in_shortcut:
-            x = hidden_states.repeat_interleave(self.in_shortcut_repeats, dim=1)
+            x = hidden_states.repeat_interleave(
+                self.in_shortcut_repeats, dim=1, output_size=hidden_states.shape[1] * self.in_shortcut_repeats
+            )
             hidden_states = self.conv_in(hidden_states) + x
         else:
             hidden_states = self.conv_in(hidden_states)
@@ -486,6 +488,9 @@ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         self.tile_sample_stride_height = 448
         self.tile_sample_stride_width = 448
 
+        self.tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        self.tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+
     def enable_tiling(
         self,
         tile_sample_min_height: Optional[int] = None,
@@ -515,6 +520,8 @@ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
         self.tile_sample_stride_height = tile_sample_stride_height or self.tile_sample_stride_height
         self.tile_sample_stride_width = tile_sample_stride_width or self.tile_sample_stride_width
+        self.tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        self.tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
 
     def disable_tiling(self) -> None:
         r"""
@@ -606,11 +613,106 @@ class AutoencoderDC(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             return (decoded,)
         return DecoderOutput(sample=decoded)
 
+    def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[2], b.shape[2], blend_extent)
+        for y in range(blend_extent):
+            b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
+        return b
+
+    def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
+        blend_extent = min(a.shape[3], b.shape[3], blend_extent)
+        for x in range(blend_extent):
+            b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
+        return b
+
     def tiled_encode(self, x: torch.Tensor, return_dict: bool = True) -> torch.Tensor:
-
+        batch_size, num_channels, height, width = x.shape
+        latent_height = height // self.spatial_compression_ratio
+        latent_width = width // self.spatial_compression_ratio
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+        blend_height = tile_latent_min_height - tile_latent_stride_height
+        blend_width = tile_latent_min_width - tile_latent_stride_width
+
+        # Split x into overlapping tiles and encode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, x.shape[2], self.tile_sample_stride_height):
+            row = []
+            for j in range(0, x.shape[3], self.tile_sample_stride_width):
+                tile = x[:, :, i : i + self.tile_sample_min_height, j : j + self.tile_sample_min_width]
+                if (
+                    tile.shape[2] % self.spatial_compression_ratio != 0
+                    or tile.shape[3] % self.spatial_compression_ratio != 0
+                ):
+                    pad_h = (self.spatial_compression_ratio - tile.shape[2]) % self.spatial_compression_ratio
+                    pad_w = (self.spatial_compression_ratio - tile.shape[3]) % self.spatial_compression_ratio
+                    tile = F.pad(tile, (0, pad_w, 0, pad_h))
+                tile = self.encoder(tile)
+                row.append(tile)
+            rows.append(row)
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, :tile_latent_stride_height, :tile_latent_stride_width])
+            result_rows.append(torch.cat(result_row, dim=3))
+
+        encoded = torch.cat(result_rows, dim=2)[:, :, :latent_height, :latent_width]
+
+        if not return_dict:
+            return (encoded,)
+        return EncoderOutput(latent=encoded)
 
     def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
-
+        batch_size, num_channels, height, width = z.shape
+
+        tile_latent_min_height = self.tile_sample_min_height // self.spatial_compression_ratio
+        tile_latent_min_width = self.tile_sample_min_width // self.spatial_compression_ratio
+        tile_latent_stride_height = self.tile_sample_stride_height // self.spatial_compression_ratio
+        tile_latent_stride_width = self.tile_sample_stride_width // self.spatial_compression_ratio
+
+        blend_height = self.tile_sample_min_height - self.tile_sample_stride_height
+        blend_width = self.tile_sample_min_width - self.tile_sample_stride_width
+
+        # Split z into overlapping tiles and decode them separately.
+        # The tiles have an overlap to avoid seams between tiles.
+        rows = []
+        for i in range(0, height, tile_latent_stride_height):
+            row = []
+            for j in range(0, width, tile_latent_stride_width):
+                tile = z[:, :, i : i + tile_latent_min_height, j : j + tile_latent_min_width]
+                decoded = self.decoder(tile)
+                row.append(decoded)
+            rows.append(row)
+
+        result_rows = []
+        for i, row in enumerate(rows):
+            result_row = []
+            for j, tile in enumerate(row):
+                # blend the above tile and the left tile
+                # to the current tile and add the current tile to the result row
+                if i > 0:
+                    tile = self.blend_v(rows[i - 1][j], tile, blend_height)
+                if j > 0:
+                    tile = self.blend_h(row[j - 1], tile, blend_width)
+                result_row.append(tile[:, :, : self.tile_sample_stride_height, : self.tile_sample_stride_width])
+            result_rows.append(torch.cat(result_row, dim=3))
+
+        decoded = torch.cat(result_rows, dim=2)
+
+        if not return_dict:
+            return (decoded,)
+        return DecoderOutput(sample=decoded)
 
     def forward(self, sample: torch.Tensor, return_dict: bool = True) -> torch.Tensor:
         encoded = self.encode(sample, return_dict=False)[0]
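
With `tiled_encode` and `tiled_decode` now implemented, enabling tiling on `AutoencoderDC` processes overlapping tiles and cross-fades the seams with the `blend_h`/`blend_v` helpers above. A hedged sketch of using it follows; the tiling arguments mirror `enable_tiling` from this diff, while the checkpoint id and image size are illustrative assumptions.

```python
import torch
from diffusers import AutoencoderDC

# Checkpoint id is an example; any AutoencoderDC checkpoint behaves the same way.
vae = AutoencoderDC.from_pretrained("mit-han-lab/dc-ae-f32c32-sana-1.0-diffusers")
vae.enable_tiling(
    tile_sample_min_height=512,
    tile_sample_min_width=512,
    tile_sample_stride_height=448,
    tile_sample_stride_width=448,
)

image = torch.randn(1, 3, 1024, 1024)
with torch.no_grad():
    latent = vae.encode(image).latent      # routed through tiled_encode for large inputs
    reconstruction = vae.decode(latent).sample
```
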