diffusers 0.32.2__py3-none-any.whl → 0.33.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +186 -3
- diffusers/configuration_utils.py +40 -12
- diffusers/dependency_versions_table.py +9 -2
- diffusers/hooks/__init__.py +9 -0
- diffusers/hooks/faster_cache.py +653 -0
- diffusers/hooks/group_offloading.py +793 -0
- diffusers/hooks/hooks.py +236 -0
- diffusers/hooks/layerwise_casting.py +245 -0
- diffusers/hooks/pyramid_attention_broadcast.py +311 -0
- diffusers/loaders/__init__.py +6 -0
- diffusers/loaders/ip_adapter.py +38 -30
- diffusers/loaders/lora_base.py +121 -86
- diffusers/loaders/lora_conversion_utils.py +504 -44
- diffusers/loaders/lora_pipeline.py +1769 -181
- diffusers/loaders/peft.py +167 -57
- diffusers/loaders/single_file.py +17 -2
- diffusers/loaders/single_file_model.py +53 -5
- diffusers/loaders/single_file_utils.py +646 -72
- diffusers/loaders/textual_inversion.py +9 -9
- diffusers/loaders/transformer_flux.py +8 -9
- diffusers/loaders/transformer_sd3.py +120 -39
- diffusers/loaders/unet.py +20 -7
- diffusers/models/__init__.py +22 -0
- diffusers/models/activations.py +9 -9
- diffusers/models/attention.py +0 -1
- diffusers/models/attention_processor.py +163 -25
- diffusers/models/auto_model.py +169 -0
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
- diffusers/models/autoencoders/autoencoder_dc.py +106 -4
- diffusers/models/autoencoders/autoencoder_kl.py +0 -4
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
- diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
- diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
- diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
- diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
- diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
- diffusers/models/autoencoders/vae.py +31 -141
- diffusers/models/autoencoders/vq_model.py +3 -0
- diffusers/models/cache_utils.py +108 -0
- diffusers/models/controlnets/__init__.py +1 -0
- diffusers/models/controlnets/controlnet.py +3 -8
- diffusers/models/controlnets/controlnet_flux.py +14 -42
- diffusers/models/controlnets/controlnet_sd3.py +58 -34
- diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
- diffusers/models/controlnets/controlnet_union.py +27 -18
- diffusers/models/controlnets/controlnet_xs.py +7 -46
- diffusers/models/controlnets/multicontrolnet_union.py +196 -0
- diffusers/models/embeddings.py +18 -7
- diffusers/models/model_loading_utils.py +122 -80
- diffusers/models/modeling_flax_pytorch_utils.py +1 -1
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +617 -272
- diffusers/models/normalization.py +67 -14
- diffusers/models/resnet.py +1 -1
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
- diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
- diffusers/models/transformers/consisid_transformer_3d.py +789 -0
- diffusers/models/transformers/dit_transformer_2d.py +5 -19
- diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
- diffusers/models/transformers/latte_transformer_3d.py +20 -15
- diffusers/models/transformers/lumina_nextdit2d.py +3 -1
- diffusers/models/transformers/pixart_transformer_2d.py +4 -19
- diffusers/models/transformers/prior_transformer.py +5 -1
- diffusers/models/transformers/sana_transformer.py +144 -40
- diffusers/models/transformers/stable_audio_transformer.py +5 -20
- diffusers/models/transformers/transformer_2d.py +7 -22
- diffusers/models/transformers/transformer_allegro.py +9 -17
- diffusers/models/transformers/transformer_cogview3plus.py +6 -17
- diffusers/models/transformers/transformer_cogview4.py +462 -0
- diffusers/models/transformers/transformer_easyanimate.py +527 -0
- diffusers/models/transformers/transformer_flux.py +68 -110
- diffusers/models/transformers/transformer_hunyuan_video.py +404 -46
- diffusers/models/transformers/transformer_ltx.py +53 -35
- diffusers/models/transformers/transformer_lumina2.py +548 -0
- diffusers/models/transformers/transformer_mochi.py +6 -17
- diffusers/models/transformers/transformer_omnigen.py +469 -0
- diffusers/models/transformers/transformer_sd3.py +56 -86
- diffusers/models/transformers/transformer_temporal.py +5 -11
- diffusers/models/transformers/transformer_wan.py +469 -0
- diffusers/models/unets/unet_1d.py +3 -1
- diffusers/models/unets/unet_2d.py +21 -20
- diffusers/models/unets/unet_2d_blocks.py +19 -243
- diffusers/models/unets/unet_2d_condition.py +4 -6
- diffusers/models/unets/unet_3d_blocks.py +14 -127
- diffusers/models/unets/unet_3d_condition.py +8 -12
- diffusers/models/unets/unet_i2vgen_xl.py +5 -13
- diffusers/models/unets/unet_kandinsky3.py +0 -4
- diffusers/models/unets/unet_motion_model.py +20 -114
- diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
- diffusers/models/unets/unet_stable_cascade.py +8 -35
- diffusers/models/unets/uvit_2d.py +1 -4
- diffusers/optimization.py +2 -2
- diffusers/pipelines/__init__.py +57 -8
- diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
- diffusers/pipelines/amused/pipeline_amused.py +15 -2
- diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
- diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
- diffusers/pipelines/auto_pipeline.py +35 -14
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
- diffusers/pipelines/cogview4/__init__.py +49 -0
- diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
- diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
- diffusers/pipelines/cogview4/pipeline_output.py +21 -0
- diffusers/pipelines/consisid/__init__.py +49 -0
- diffusers/pipelines/consisid/consisid_utils.py +357 -0
- diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
- diffusers/pipelines/consisid/pipeline_output.py +20 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
- diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
- diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +15 -2
- diffusers/pipelines/easyanimate/__init__.py +52 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
- diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
- diffusers/pipelines/flux/pipeline_flux.py +53 -21
- diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
- diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
- diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
- diffusers/pipelines/free_noise_utils.py +3 -3
- diffusers/pipelines/hunyuan_video/__init__.py +4 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
- diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
- diffusers/pipelines/kolors/text_encoder.py +7 -34
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
- diffusers/pipelines/latte/pipeline_latte.py +36 -7
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
- diffusers/pipelines/ltx/__init__.py +2 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
- diffusers/pipelines/lumina/__init__.py +2 -2
- diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
- diffusers/pipelines/lumina2/__init__.py +48 -0
- diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
- diffusers/pipelines/marigold/__init__.py +2 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
- diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
- diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
- diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
- diffusers/pipelines/omnigen/__init__.py +50 -0
- diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
- diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
- diffusers/pipelines/onnx_utils.py +5 -3
- diffusers/pipelines/pag/pag_utils.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
- diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
- diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
- diffusers/pipelines/pia/pipeline_pia.py +13 -1
- diffusers/pipelines/pipeline_flax_utils.py +7 -7
- diffusers/pipelines/pipeline_loading_utils.py +193 -83
- diffusers/pipelines/pipeline_utils.py +221 -106
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
- diffusers/pipelines/sana/__init__.py +2 -0
- diffusers/pipelines/sana/pipeline_sana.py +183 -58
- diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
- diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
- diffusers/pipelines/shap_e/renderer.py +6 -6
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
- diffusers/pipelines/transformers_loading_utils.py +121 -0
- diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
- diffusers/pipelines/wan/__init__.py +51 -0
- diffusers/pipelines/wan/pipeline_output.py +20 -0
- diffusers/pipelines/wan/pipeline_wan.py +593 -0
- diffusers/pipelines/wan/pipeline_wan_i2v.py +722 -0
- diffusers/pipelines/wan/pipeline_wan_video2video.py +725 -0
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
- diffusers/quantizers/auto.py +5 -1
- diffusers/quantizers/base.py +5 -9
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
- diffusers/quantizers/bitsandbytes/utils.py +30 -20
- diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
- diffusers/quantizers/gguf/utils.py +4 -2
- diffusers/quantizers/quantization_config.py +59 -4
- diffusers/quantizers/quanto/__init__.py +1 -0
- diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
- diffusers/quantizers/quanto/utils.py +60 -0
- diffusers/quantizers/torchao/__init__.py +1 -1
- diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
- diffusers/schedulers/__init__.py +2 -1
- diffusers/schedulers/scheduling_consistency_models.py +1 -2
- diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
- diffusers/schedulers/scheduling_ddpm.py +2 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
- diffusers/schedulers/scheduling_edm_euler.py +45 -10
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
- diffusers/schedulers/scheduling_heun_discrete.py +1 -1
- diffusers/schedulers/scheduling_lcm.py +1 -2
- diffusers/schedulers/scheduling_lms_discrete.py +1 -1
- diffusers/schedulers/scheduling_repaint.py +5 -1
- diffusers/schedulers/scheduling_scm.py +265 -0
- diffusers/schedulers/scheduling_tcd.py +1 -2
- diffusers/schedulers/scheduling_utils.py +2 -1
- diffusers/training_utils.py +14 -7
- diffusers/utils/__init__.py +9 -1
- diffusers/utils/constants.py +13 -1
- diffusers/utils/deprecation_utils.py +1 -1
- diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
- diffusers/utils/dummy_gguf_objects.py +17 -0
- diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
- diffusers/utils/dummy_pt_objects.py +233 -0
- diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
- diffusers/utils/dummy_torchao_objects.py +17 -0
- diffusers/utils/dynamic_modules_utils.py +1 -1
- diffusers/utils/export_utils.py +28 -3
- diffusers/utils/hub_utils.py +52 -102
- diffusers/utils/import_utils.py +121 -221
- diffusers/utils/loading_utils.py +2 -1
- diffusers/utils/logging.py +1 -2
- diffusers/utils/peft_utils.py +6 -14
- diffusers/utils/remote_utils.py +425 -0
- diffusers/utils/source_code_parsing_utils.py +52 -0
- diffusers/utils/state_dict_utils.py +15 -1
- diffusers/utils/testing_utils.py +243 -13
- diffusers/utils/torch_utils.py +10 -0
- diffusers/utils/typing_utils.py +91 -0
- diffusers/video_processor.py +1 -1
- {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/METADATA +76 -44
- diffusers-0.33.0.dist-info/RECORD +608 -0
- {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/WHEEL +1 -1
- diffusers-0.32.2.dist-info/RECORD +0 -550
- {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/LICENSE +0 -0
- {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/top_level.txt +0 -0
diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py

@@ -29,6 +29,7 @@ from ...utils import (
     deprecate,
     is_bs4_available,
     is_ftfy_available,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
 )
@@ -36,8 +37,16 @@ from ...utils.torch_utils import randn_tensor
 from ..pipeline_utils import DiffusionPipeline, ImagePipelineOutput
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
+
 if is_bs4_available():
     from bs4 import BeautifulSoup
 
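The guard above is the pattern 0.33.0 threads through most pipelines: probe for torch_xla once at import time, remember the result, and flush the lazily traced graph once per denoising step. A minimal sketch of the same pattern outside diffusers (the loop and the `denoise`/`steps` names are illustrative, not diffusers API):

    from diffusers.utils import is_torch_xla_available

    if is_torch_xla_available():
        import torch_xla.core.xla_model as xm

        XLA_AVAILABLE = True
    else:
        XLA_AVAILABLE = False


    def denoise(steps):
        # `steps` is any iterable of per-step callables; illustrative only.
        for step in steps:
            step()
            if XLA_AVAILABLE:
                # XLA records ops lazily; mark_step() cuts and dispatches the
                # graph here so each iteration compiles to one small graph.
                xm.mark_step()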
@@ -285,7 +294,7 @@ class PixArtAlphaPipeline(DiffusionPipeline):
             tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
         )
 
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
 
     # Adapted from diffusers.pipelines.deepfloyd_if.pipeline_if.encode_prompt
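The `getattr(self, "vae", None)` fallback lets the pipeline be constructed without a VAE and defaults the scale factor to 8, which is exactly what a four-level VAE yields. The arithmetic, with an illustrative SD-style config:

    # A four-entry block_out_channels VAE downsamples by 2 ** (4 - 1) == 8,
    # hence 8 as the default when no VAE is attached. Values illustrative.
    block_out_channels = [128, 256, 512, 512]
    vae_scale_factor = 2 ** (len(block_out_channels) - 1)
    assert vae_scale_factor == 8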
@@ -898,10 +907,11 @@ class PixArtAlphaPipeline(DiffusionPipeline):
                 # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
                 # This would be a good case for the `match` statement (Python 3.10+)
                 is_mps = latent_model_input.device.type == "mps"
+                is_npu = latent_model_input.device.type == "npu"
                 if isinstance(current_timestep, float):
-                    dtype = torch.float32 if is_mps else torch.float64
+                    dtype = torch.float32 if (is_mps or is_npu) else torch.float64
                 else:
-                    dtype = torch.int32 if is_mps else torch.int64
+                    dtype = torch.int32 if (is_mps or is_npu) else torch.int64
                 current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
             elif len(current_timestep.shape) == 0:
                 current_timestep = current_timestep[None].to(latent_model_input.device)
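The `is_npu` branch mirrors the existing MPS special case: neither backend supports 64-bit tensor dtypes, so float timesteps drop to float32 and integer ones to int32 there. The selection, as a standalone helper (our naming, not a diffusers function):

    import torch


    def timestep_dtype(device: torch.device, timestep_is_float: bool) -> torch.dtype:
        # mps and npu lack 64-bit support, hence the narrower dtypes.
        narrow = device.type in ("mps", "npu")
        if timestep_is_float:
            return torch.float32 if narrow else torch.float64
        return torch.int32 if narrow else torch.int64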
@@ -931,8 +941,7 @@ class PixArtAlphaPipeline(DiffusionPipeline):
 
                 # compute previous image: x_t -> x_t-1
                 if num_inference_steps == 1:
-
-                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).pred_original_sample
+                    latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[1]
                 else:
                     latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]
 
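The rewritten one-step branch leans on the scheduler's tuple output: for schedulers whose output carries `pred_original_sample`, `return_dict=False` returns `(prev_sample, pred_original_sample)`, so `[1]` fetches the same x0 estimate as the removed attribute access without building the output dataclass. A self-contained check using `DDPMScheduler` (our choice for illustration; the shapes are arbitrary):

    import torch
    from diffusers import DDPMScheduler

    scheduler = DDPMScheduler(num_train_timesteps=1000)
    scheduler.set_timesteps(10)
    sample = torch.randn(1, 4, 8, 8)
    noise_pred = torch.randn_like(sample)
    t = scheduler.timesteps[0]

    out = scheduler.step(noise_pred, t, sample)                     # output dataclass
    tup = scheduler.step(noise_pred, t, sample, return_dict=False)  # plain tuple
    assert torch.equal(out.pred_original_sample, tup[1])            # same x0 estimate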
@@ -943,6 +952,9 @@ class PixArtAlphaPipeline(DiffusionPipeline):
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
             if use_resolution_binning:
diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py

@@ -29,6 +29,7 @@ from ...utils import (
     deprecate,
     is_bs4_available,
     is_ftfy_available,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
 )
@@ -41,8 +42,16 @@ from .pipeline_pixart_alpha import (
 )
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
+
 if is_bs4_available():
     from bs4 import BeautifulSoup
 
@@ -211,7 +220,7 @@ class PixArtSigmaPipeline(DiffusionPipeline):
             tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
         )
 
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
 
     # Copied from diffusers.pipelines.pixart_alpha.pipeline_pixart_alpha.PixArtAlphaPipeline.encode_prompt with 120->300
|
@@ -813,10 +822,11 @@ class PixArtSigmaPipeline(DiffusionPipeline):
|
|
813
822
|
# TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
|
814
823
|
# This would be a good case for the `match` statement (Python 3.10+)
|
815
824
|
is_mps = latent_model_input.device.type == "mps"
|
825
|
+
is_npu = latent_model_input.device.type == "npu"
|
816
826
|
if isinstance(current_timestep, float):
|
817
|
-
dtype = torch.float32 if is_mps else torch.float64
|
827
|
+
dtype = torch.float32 if (is_mps or is_npu) else torch.float64
|
818
828
|
else:
|
819
|
-
dtype = torch.int32 if is_mps else torch.int64
|
829
|
+
dtype = torch.int32 if (is_mps or is_npu) else torch.int64
|
820
830
|
current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
|
821
831
|
elif len(current_timestep.shape) == 0:
|
822
832
|
current_timestep = current_timestep[None].to(latent_model_input.device)
|
@@ -854,8 +864,11 @@ class PixArtSigmaPipeline(DiffusionPipeline):
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if not output_type == "latent":
-            image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+            image = self.vae.decode(latents.to(self.vae.dtype) / self.vae.config.scaling_factor, return_dict=False)[0]
             if use_resolution_binning:
                 image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
         else:
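The added `latents.to(self.vae.dtype)` cast matters when the transformer and VAE run in different precisions; without it the fp16/fp32 mismatch surfaces inside the decoder. A runnable illustration with a randomly initialized `AutoencoderKL` as a stand-in model (the real VAE comes from the checkpoint):

    import torch
    from diffusers import AutoencoderKL

    vae = AutoencoderKL()  # random fp32 weights are enough to show the cast
    latents = torch.randn(1, 4, 32, 32, dtype=torch.float16)  # e.g. from an fp16 transformer

    # Casting first keeps the decoder weights and inputs in one dtype:
    image = vae.decode(latents.to(vae.dtype) / vae.config.scaling_factor, return_dict=False)[0]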
diffusers/pipelines/sana/__init__.py

@@ -23,6 +23,7 @@ except OptionalDependencyNotAvailable:
     _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
 else:
     _import_structure["pipeline_sana"] = ["SanaPipeline"]
+    _import_structure["pipeline_sana_sprint"] = ["SanaSprintPipeline"]
 
 if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     try:
@@ -33,6 +34,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         from ...utils.dummy_torch_and_transformers_objects import *
     else:
         from .pipeline_sana import SanaPipeline
+        from .pipeline_sana_sprint import SanaSprintPipeline
 else:
     import sys
 
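Both registrations matter: the `_import_structure` entry serves the lazy module at runtime, while the `TYPE_CHECKING`/`DIFFUSERS_SLOW_IMPORT` branch gives type checkers a real import. After this change the new pipeline resolves from the package root; a hedged usage sketch (the checkpoint id is an assumption, check the model hub):

    import torch
    from diffusers import SanaSprintPipeline

    pipe = SanaSprintPipeline.from_pretrained(
        "Efficient-Large-Model/Sana_Sprint_1.6B_1024px_diffusers",  # assumed repo id
        torch_dtype=torch.bfloat16,
    ).to("cuda")
    image = pipe("a tiny astronaut hatching from an egg on the moon").images[0]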
diffusers/pipelines/sana/pipeline_sana.py

@@ -16,10 +16,11 @@ import html
 import inspect
 import re
 import urllib.parse as ul
+import warnings
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import torch
-from transformers import
+from transformers import Gemma2PreTrainedModel, GemmaTokenizer, GemmaTokenizerFast
 
 from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import PixArtImageProcessor
@@ -31,6 +32,7 @@ from ...utils import (
     USE_PEFT_BACKEND,
     is_bs4_available,
     is_ftfy_available,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -46,6 +48,13 @@ from ..pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN
 from .pipeline_output import SanaPipelineOutput
 
 
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 if is_bs4_available():
@@ -55,6 +64,49 @@ if is_ftfy_available():
     import ftfy
 
 
+ASPECT_RATIO_4096_BIN = {
+    "0.25": [2048.0, 8192.0],
+    "0.26": [2048.0, 7936.0],
+    "0.27": [2048.0, 7680.0],
+    "0.28": [2048.0, 7424.0],
+    "0.32": [2304.0, 7168.0],
+    "0.33": [2304.0, 6912.0],
+    "0.35": [2304.0, 6656.0],
+    "0.4": [2560.0, 6400.0],
+    "0.42": [2560.0, 6144.0],
+    "0.48": [2816.0, 5888.0],
+    "0.5": [2816.0, 5632.0],
+    "0.52": [2816.0, 5376.0],
+    "0.57": [3072.0, 5376.0],
+    "0.6": [3072.0, 5120.0],
+    "0.68": [3328.0, 4864.0],
+    "0.72": [3328.0, 4608.0],
+    "0.78": [3584.0, 4608.0],
+    "0.82": [3584.0, 4352.0],
+    "0.88": [3840.0, 4352.0],
+    "0.94": [3840.0, 4096.0],
+    "1.0": [4096.0, 4096.0],
+    "1.07": [4096.0, 3840.0],
+    "1.13": [4352.0, 3840.0],
+    "1.21": [4352.0, 3584.0],
+    "1.29": [4608.0, 3584.0],
+    "1.38": [4608.0, 3328.0],
+    "1.46": [4864.0, 3328.0],
+    "1.67": [5120.0, 3072.0],
+    "1.75": [5376.0, 3072.0],
+    "2.0": [5632.0, 2816.0],
+    "2.09": [5888.0, 2816.0],
+    "2.4": [6144.0, 2560.0],
+    "2.5": [6400.0, 2560.0],
+    "2.89": [6656.0, 2304.0],
+    "3.0": [6912.0, 2304.0],
+    "3.11": [7168.0, 2304.0],
+    "3.62": [7424.0, 2048.0],
+    "3.75": [7680.0, 2048.0],
+    "3.88": [7936.0, 2048.0],
+    "4.0": [8192.0, 2048.0],
+}
+
 EXAMPLE_DOC_STRING = """
     Examples:
         ```py
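`ASPECT_RATIO_4096_BIN` follows the 1024/2048 bins already imported above: keys are height/width ratios, values are (height, width) pairs that multiply out to roughly 4096 squared. Resolution binning snaps a request to the nearest key; diffusers routes this through `PixArtImageProcessor.classify_height_width_bin`, and the helper below is our simplified stand-in for the lookup:

    def closest_bin(height: int, width: int, ratios: dict) -> tuple:
        # Pick the bin whose ratio key is nearest to the requested aspect ratio.
        aspect = height / width
        key = min(ratios, key=lambda r: abs(float(r) - aspect))
        h, w = ratios[key]
        return int(h), int(w)

    # e.g. a 3000x3800 request (aspect ~0.79) lands in the "0.78" bucket:
    # closest_bin(3000, 3800, ASPECT_RATIO_4096_BIN) == (3584, 4608)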
@@ -148,8 +200,8 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
 
     def __init__(
         self,
-        tokenizer:
-        text_encoder:
+        tokenizer: Union[GemmaTokenizer, GemmaTokenizerFast],
+        text_encoder: Gemma2PreTrainedModel,
         vae: AutoencoderDC,
         transformer: SanaTransformer2DModel,
         scheduler: DPMSolverMultistepScheduler,
@@ -167,6 +219,93 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
         )
         self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)
 
+    def enable_vae_slicing(self):
+        r"""
+        Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+        compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+        """
+        self.vae.enable_slicing()
+
+    def disable_vae_slicing(self):
+        r"""
+        Disable sliced VAE decoding. If `enable_vae_slicing` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_slicing()
+
+    def enable_vae_tiling(self):
+        r"""
+        Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+        compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
+        processing larger images.
+        """
+        self.vae.enable_tiling()
+
+    def disable_vae_tiling(self):
+        r"""
+        Disable tiled VAE decoding. If `enable_vae_tiling` was previously enabled, this method will go back to
+        computing decoding in one step.
+        """
+        self.vae.disable_tiling()
+
+    def _get_gemma_prompt_embeds(
+        self,
+        prompt: Union[str, List[str]],
+        device: torch.device,
+        dtype: torch.dtype,
+        clean_caption: bool = False,
+        max_sequence_length: int = 300,
+        complex_human_instruction: Optional[List[str]] = None,
+    ):
+        r"""
+        Encodes the prompt into text encoder hidden states.
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                prompt to be encoded
+            device: (`torch.device`, *optional*):
+                torch device to place the resulting embeddings on
+            clean_caption (`bool`, defaults to `False`):
+                If `True`, the function will preprocess and clean the provided caption before encoding.
+            max_sequence_length (`int`, defaults to 300): Maximum sequence length to use for the prompt.
+            complex_human_instruction (`list[str]`, defaults to `complex_human_instruction`):
+                If `complex_human_instruction` is not empty, the function will use the complex Human instruction for
+                the prompt.
+        """
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+
+        if getattr(self, "tokenizer", None) is not None:
+            self.tokenizer.padding_side = "right"
+
+        prompt = self._text_preprocessing(prompt, clean_caption=clean_caption)
+
+        # prepare complex human instruction
+        if not complex_human_instruction:
+            max_length_all = max_sequence_length
+        else:
+            chi_prompt = "\n".join(complex_human_instruction)
+            prompt = [chi_prompt + p for p in prompt]
+            num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
+            max_length_all = num_chi_prompt_tokens + max_sequence_length - 2
+
+        text_inputs = self.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=max_length_all,
+            truncation=True,
+            add_special_tokens=True,
+            return_tensors="pt",
+        )
+        text_input_ids = text_inputs.input_ids
+
+        prompt_attention_mask = text_inputs.attention_mask
+        prompt_attention_mask = prompt_attention_mask.to(device)
+
+        prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
+        prompt_embeds = prompt_embeds[0].to(dtype=dtype, device=device)
+
+        return prompt_embeds, prompt_attention_mask
+
     def encode_prompt(
         self,
         prompt: Union[str, List[str]],
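The four toggles are thin wrappers over the `AutoencoderDC` switches, and `_get_gemma_prompt_embeds` centralizes the tokenize-and-encode path that `encode_prompt` previously inlined (see the following hunks). A hedged usage sketch of the memory toggles (the repo id is an assumption):

    import torch
    from diffusers import SanaPipeline

    pipe = SanaPipeline.from_pretrained(
        "Efficient-Large-Model/Sana_1600M_1024px_diffusers",  # assumed repo id
        torch_dtype=torch.bfloat16,
    ).to("cuda")

    pipe.enable_vae_slicing()  # decode batched outputs one slice at a time
    pipe.enable_vae_tiling()   # decode each image in spatial tiles
    images = pipe(["a cyberpunk cat"] * 4).images
    pipe.disable_vae_tiling()
    pipe.disable_vae_slicing()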
@@ -215,6 +354,13 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
         if device is None:
             device = self._execution_device
 
+        if self.transformer is not None:
+            dtype = self.transformer.dtype
+        elif self.text_encoder is not None:
+            dtype = self.text_encoder.dtype
+        else:
+            dtype = None
+
         # set lora scale so that monkey patched LoRA
         # function of text encoder can correctly access it
         if lora_scale is not None and isinstance(self, SanaLoraLoaderMixin):
@@ -231,50 +377,26 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
         else:
             batch_size = prompt_embeds.shape[0]
 
-        self
+        if getattr(self, "tokenizer", None) is not None:
+            self.tokenizer.padding_side = "right"
 
         # See Section 3.1. of the paper.
         max_length = max_sequence_length
         select_index = [0] + list(range(-max_length + 1, 0))
 
         if prompt_embeds is None:
-
-
-
-
-
-
-
-            prompt = [chi_prompt + p for p in prompt]
-            num_chi_prompt_tokens = len(self.tokenizer.encode(chi_prompt))
-            max_length_all = num_chi_prompt_tokens + max_length - 2
-
-            text_inputs = self.tokenizer(
-                prompt,
-                padding="max_length",
-                max_length=max_length_all,
-                truncation=True,
-                add_special_tokens=True,
-                return_tensors="pt",
+            prompt_embeds, prompt_attention_mask = self._get_gemma_prompt_embeds(
+                prompt=prompt,
+                device=device,
+                dtype=dtype,
+                clean_caption=clean_caption,
+                max_sequence_length=max_sequence_length,
+                complex_human_instruction=complex_human_instruction,
             )
-            text_input_ids = text_inputs.input_ids
 
-
-            prompt_attention_mask = prompt_attention_mask.to(device)
-
-            prompt_embeds = self.text_encoder(text_input_ids.to(device), attention_mask=prompt_attention_mask)
-            prompt_embeds = prompt_embeds[0][:, select_index]
+            prompt_embeds = prompt_embeds[:, select_index]
         prompt_attention_mask = prompt_attention_mask[:, select_index]
 
-        if self.transformer is not None:
-            dtype = self.transformer.dtype
-        elif self.text_encoder is not None:
-            dtype = self.text_encoder.dtype
-        else:
-            dtype = None
-
-        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
-
         bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
         prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
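With encoding delegated to `_get_gemma_prompt_embeds`, the method keeps only the Section 3.1 trick: `select_index` retains the BOS position plus the trailing `max_length - 1` token positions. Concretely:

    # What select_index keeps, for an illustrative max_length of 6:
    max_length = 6
    select_index = [0] + list(range(-max_length + 1, 0))
    # select_index == [0, -5, -4, -3, -2, -1]
    # i.e. prompt_embeds[:, select_index] takes the first position plus the
    # last five, dropping everything in between.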
@@ -284,25 +406,15 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
 
         # get unconditional embeddings for classifier free guidance
         if do_classifier_free_guidance and negative_prompt_embeds is None:
-
-
-
-
-
-
-
-
-                return_attention_mask=True,
-                add_special_tokens=True,
-                return_tensors="pt",
-            )
-            negative_prompt_attention_mask = uncond_input.attention_mask
-            negative_prompt_attention_mask = negative_prompt_attention_mask.to(device)
-
-            negative_prompt_embeds = self.text_encoder(
-                uncond_input.input_ids.to(device), attention_mask=negative_prompt_attention_mask
+            negative_prompt = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
+            negative_prompt_embeds, negative_prompt_attention_mask = self._get_gemma_prompt_embeds(
+                prompt=negative_prompt,
+                device=device,
+                dtype=dtype,
+                clean_caption=clean_caption,
+                max_sequence_length=max_sequence_length,
+                complex_human_instruction=False,
             )
-            negative_prompt_embeds = negative_prompt_embeds[0]
 
         if do_classifier_free_guidance:
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
if do_classifier_free_guidance:
|
308
420
|
# duplicate unconditional embeddings for each generation per prompt, using mps friendly method
|
@@ -611,7 +723,7 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
|
|
611
723
|
negative_prompt_attention_mask: Optional[torch.Tensor] = None,
|
612
724
|
output_type: Optional[str] = "pil",
|
613
725
|
return_dict: bool = True,
|
614
|
-
clean_caption: bool =
|
726
|
+
clean_caption: bool = False,
|
615
727
|
use_resolution_binning: bool = True,
|
616
728
|
attention_kwargs: Optional[Dict[str, Any]] = None,
|
617
729
|
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
@@ -726,7 +838,9 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
 
         # 1. Check inputs. Raise error if not correct
         if use_resolution_binning:
-            if self.transformer.config.sample_size ==
+            if self.transformer.config.sample_size == 128:
+                aspect_ratio_bin = ASPECT_RATIO_4096_BIN
+            elif self.transformer.config.sample_size == 64:
                 aspect_ratio_bin = ASPECT_RATIO_2048_BIN
             elif self.transformer.config.sample_size == 32:
                 aspect_ratio_bin = ASPECT_RATIO_1024_BIN
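The `sample_size` branches map one-to-one onto the ratio bins because Sana's DC-AE compresses 32x spatially, so the latent grid times 32 gives the pixel resolution each bin targets:

    # latent sample_size * 32 (DC-AE compression) = target pixel resolution
    for sample_size, bin_name in ((128, "ASPECT_RATIO_4096_BIN"),
                                  (64, "ASPECT_RATIO_2048_BIN"),
                                  (32, "ASPECT_RATIO_1024_BIN")):
        print(f"sample_size {sample_size} -> {sample_size * 32}px ({bin_name})")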
@@ -824,6 +938,7 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
 
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latent_model_input.shape[0]).to(latents.dtype)
+                timestep = timestep * self.transformer.config.timestep_scale
 
                 # predict noise model_output
                 noise_pred = self.transformer(
@@ -864,11 +979,21 @@ class SanaPipeline(DiffusionPipeline, SanaLoraLoaderMixin):
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()
 
+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if output_type == "latent":
             image = latents
         else:
             latents = latents.to(self.vae.dtype)
-
+            try:
+                image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
+            except torch.cuda.OutOfMemoryError as e:
+                warnings.warn(
+                    f"{e}. \n"
+                    f"Try to use VAE tiling for large images. For example: \n"
+                    f"pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)"
+                )
             if use_resolution_binning:
                 image = self.image_processor.resize_and_crop_tensor(image, orig_width, orig_height)
 
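The new `try/except` only warns when decoding runs out of memory, so the practical move is to follow the warning's own suggestion and enable tiling before a large decode. A hedged sketch continuing from a loaded `pipe` (prompt and sizes illustrative):

    # Pre-empt the OOM path for 4K outputs by tiling the DC-AE decode,
    # exactly as the warning message suggests:
    prompt = "a detailed city map, aerial view"
    pipe.vae.enable_tiling(tile_sample_min_width=512, tile_sample_min_height=512)
    image = pipe(prompt, height=4096, width=4096).images[0]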