diffusers 0.32.2__py3-none-any.whl → 0.33.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +186 -3
- diffusers/configuration_utils.py +40 -12
- diffusers/dependency_versions_table.py +9 -2
- diffusers/hooks/__init__.py +9 -0
- diffusers/hooks/faster_cache.py +653 -0
- diffusers/hooks/group_offloading.py +793 -0
- diffusers/hooks/hooks.py +236 -0
- diffusers/hooks/layerwise_casting.py +245 -0
- diffusers/hooks/pyramid_attention_broadcast.py +311 -0
- diffusers/loaders/__init__.py +6 -0
- diffusers/loaders/ip_adapter.py +38 -30
- diffusers/loaders/lora_base.py +121 -86
- diffusers/loaders/lora_conversion_utils.py +504 -44
- diffusers/loaders/lora_pipeline.py +1769 -181
- diffusers/loaders/peft.py +167 -57
- diffusers/loaders/single_file.py +17 -2
- diffusers/loaders/single_file_model.py +53 -5
- diffusers/loaders/single_file_utils.py +646 -72
- diffusers/loaders/textual_inversion.py +9 -9
- diffusers/loaders/transformer_flux.py +8 -9
- diffusers/loaders/transformer_sd3.py +120 -39
- diffusers/loaders/unet.py +20 -7
- diffusers/models/__init__.py +22 -0
- diffusers/models/activations.py +9 -9
- diffusers/models/attention.py +0 -1
- diffusers/models/attention_processor.py +163 -25
- diffusers/models/auto_model.py +169 -0
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +2 -0
- diffusers/models/autoencoders/autoencoder_dc.py +106 -4
- diffusers/models/autoencoders/autoencoder_kl.py +0 -4
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +5 -23
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +17 -55
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +17 -97
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +326 -107
- diffusers/models/autoencoders/autoencoder_kl_magvit.py +1094 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +21 -56
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -42
- diffusers/models/autoencoders/autoencoder_kl_wan.py +855 -0
- diffusers/models/autoencoders/autoencoder_oobleck.py +1 -0
- diffusers/models/autoencoders/autoencoder_tiny.py +0 -4
- diffusers/models/autoencoders/consistency_decoder_vae.py +3 -1
- diffusers/models/autoencoders/vae.py +31 -141
- diffusers/models/autoencoders/vq_model.py +3 -0
- diffusers/models/cache_utils.py +108 -0
- diffusers/models/controlnets/__init__.py +1 -0
- diffusers/models/controlnets/controlnet.py +3 -8
- diffusers/models/controlnets/controlnet_flux.py +14 -42
- diffusers/models/controlnets/controlnet_sd3.py +58 -34
- diffusers/models/controlnets/controlnet_sparsectrl.py +4 -7
- diffusers/models/controlnets/controlnet_union.py +27 -18
- diffusers/models/controlnets/controlnet_xs.py +7 -46
- diffusers/models/controlnets/multicontrolnet_union.py +196 -0
- diffusers/models/embeddings.py +18 -7
- diffusers/models/model_loading_utils.py +122 -80
- diffusers/models/modeling_flax_pytorch_utils.py +1 -1
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +617 -272
- diffusers/models/normalization.py +67 -14
- diffusers/models/resnet.py +1 -1
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +9 -35
- diffusers/models/transformers/cogvideox_transformer_3d.py +13 -24
- diffusers/models/transformers/consisid_transformer_3d.py +789 -0
- diffusers/models/transformers/dit_transformer_2d.py +5 -19
- diffusers/models/transformers/hunyuan_transformer_2d.py +4 -3
- diffusers/models/transformers/latte_transformer_3d.py +20 -15
- diffusers/models/transformers/lumina_nextdit2d.py +3 -1
- diffusers/models/transformers/pixart_transformer_2d.py +4 -19
- diffusers/models/transformers/prior_transformer.py +5 -1
- diffusers/models/transformers/sana_transformer.py +144 -40
- diffusers/models/transformers/stable_audio_transformer.py +5 -20
- diffusers/models/transformers/transformer_2d.py +7 -22
- diffusers/models/transformers/transformer_allegro.py +9 -17
- diffusers/models/transformers/transformer_cogview3plus.py +6 -17
- diffusers/models/transformers/transformer_cogview4.py +462 -0
- diffusers/models/transformers/transformer_easyanimate.py +527 -0
- diffusers/models/transformers/transformer_flux.py +68 -110
- diffusers/models/transformers/transformer_hunyuan_video.py +404 -46
- diffusers/models/transformers/transformer_ltx.py +53 -35
- diffusers/models/transformers/transformer_lumina2.py +548 -0
- diffusers/models/transformers/transformer_mochi.py +6 -17
- diffusers/models/transformers/transformer_omnigen.py +469 -0
- diffusers/models/transformers/transformer_sd3.py +56 -86
- diffusers/models/transformers/transformer_temporal.py +5 -11
- diffusers/models/transformers/transformer_wan.py +469 -0
- diffusers/models/unets/unet_1d.py +3 -1
- diffusers/models/unets/unet_2d.py +21 -20
- diffusers/models/unets/unet_2d_blocks.py +19 -243
- diffusers/models/unets/unet_2d_condition.py +4 -6
- diffusers/models/unets/unet_3d_blocks.py +14 -127
- diffusers/models/unets/unet_3d_condition.py +8 -12
- diffusers/models/unets/unet_i2vgen_xl.py +5 -13
- diffusers/models/unets/unet_kandinsky3.py +0 -4
- diffusers/models/unets/unet_motion_model.py +20 -114
- diffusers/models/unets/unet_spatio_temporal_condition.py +7 -8
- diffusers/models/unets/unet_stable_cascade.py +8 -35
- diffusers/models/unets/uvit_2d.py +1 -4
- diffusers/optimization.py +2 -2
- diffusers/pipelines/__init__.py +57 -8
- diffusers/pipelines/allegro/pipeline_allegro.py +22 -2
- diffusers/pipelines/amused/pipeline_amused.py +15 -2
- diffusers/pipelines/amused/pipeline_amused_img2img.py +15 -2
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +15 -3
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +24 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +15 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +16 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +16 -4
- diffusers/pipelines/audioldm/pipeline_audioldm.py +13 -2
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +13 -68
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +39 -9
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +63 -7
- diffusers/pipelines/auto_pipeline.py +35 -14
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -8
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +12 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +22 -6
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +22 -6
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +22 -5
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +22 -6
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +12 -4
- diffusers/pipelines/cogview4/__init__.py +49 -0
- diffusers/pipelines/cogview4/pipeline_cogview4.py +684 -0
- diffusers/pipelines/cogview4/pipeline_cogview4_control.py +732 -0
- diffusers/pipelines/cogview4/pipeline_output.py +21 -0
- diffusers/pipelines/consisid/__init__.py +49 -0
- diffusers/pipelines/consisid/consisid_utils.py +357 -0
- diffusers/pipelines/consisid/pipeline_consisid.py +974 -0
- diffusers/pipelines/consisid/pipeline_output.py +20 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +11 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +6 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +13 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +17 -5
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +31 -12
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +26 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +20 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +22 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +26 -25
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +224 -109
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +25 -29
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +7 -4
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +3 -5
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +121 -10
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +122 -11
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +12 -1
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +20 -3
- diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +14 -2
- diffusers/pipelines/ddim/pipeline_ddim.py +14 -1
- diffusers/pipelines/ddpm/pipeline_ddpm.py +15 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +14 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +12 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +14 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +14 -1
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -7
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -7
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +10 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_onnx_stable_diffusion_inpaint_legacy.py +2 -2
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +11 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +10 -105
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +15 -2
- diffusers/pipelines/easyanimate/__init__.py +52 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate.py +770 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate_control.py +994 -0
- diffusers/pipelines/easyanimate/pipeline_easyanimate_inpaint.py +1234 -0
- diffusers/pipelines/easyanimate/pipeline_output.py +20 -0
- diffusers/pipelines/flux/pipeline_flux.py +53 -21
- diffusers/pipelines/flux/pipeline_flux_control.py +9 -12
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +6 -10
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +8 -10
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +185 -13
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +8 -10
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +16 -16
- diffusers/pipelines/flux/pipeline_flux_fill.py +107 -39
- diffusers/pipelines/flux/pipeline_flux_img2img.py +193 -15
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +199 -19
- diffusers/pipelines/free_noise_utils.py +3 -3
- diffusers/pipelines/hunyuan_video/__init__.py +4 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_skyreels_image2video.py +804 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +90 -23
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video_image2video.py +924 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +3 -5
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +13 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +12 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +12 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +13 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +13 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +12 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +12 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +12 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +12 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +12 -0
- diffusers/pipelines/kolors/pipeline_kolors.py +10 -8
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +6 -4
- diffusers/pipelines/kolors/text_encoder.py +7 -34
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +12 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +13 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +14 -13
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion_superresolution.py +12 -1
- diffusers/pipelines/latte/pipeline_latte.py +36 -7
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +67 -13
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +60 -15
- diffusers/pipelines/ltx/__init__.py +2 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +25 -13
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +1194 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +31 -17
- diffusers/pipelines/lumina/__init__.py +2 -2
- diffusers/pipelines/lumina/pipeline_lumina.py +83 -20
- diffusers/pipelines/lumina2/__init__.py +48 -0
- diffusers/pipelines/lumina2/pipeline_lumina2.py +790 -0
- diffusers/pipelines/marigold/__init__.py +2 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +127 -14
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +31 -16
- diffusers/pipelines/marigold/pipeline_marigold_intrinsics.py +721 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +31 -16
- diffusers/pipelines/mochi/pipeline_mochi.py +14 -18
- diffusers/pipelines/musicldm/pipeline_musicldm.py +16 -1
- diffusers/pipelines/omnigen/__init__.py +50 -0
- diffusers/pipelines/omnigen/pipeline_omnigen.py +512 -0
- diffusers/pipelines/omnigen/processor_omnigen.py +327 -0
- diffusers/pipelines/onnx_utils.py +5 -3
- diffusers/pipelines/pag/pag_utils.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +12 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +15 -4
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +20 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +20 -3
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_kolors.py +6 -4
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +16 -3
- diffusers/pipelines/pag/pipeline_pag_sana.py +65 -8
- diffusers/pipelines/pag/pipeline_pag_sd.py +23 -7
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +13 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +23 -7
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +26 -10
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +12 -4
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +7 -3
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +10 -6
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +13 -3
- diffusers/pipelines/pia/pipeline_pia.py +13 -1
- diffusers/pipelines/pipeline_flax_utils.py +7 -7
- diffusers/pipelines/pipeline_loading_utils.py +193 -83
- diffusers/pipelines/pipeline_utils.py +221 -106
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +17 -5
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +17 -4
- diffusers/pipelines/sana/__init__.py +2 -0
- diffusers/pipelines/sana/pipeline_sana.py +183 -58
- diffusers/pipelines/sana/pipeline_sana_sprint.py +889 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +12 -2
- diffusers/pipelines/shap_e/pipeline_shap_e.py +12 -0
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +12 -0
- diffusers/pipelines/shap_e/renderer.py +6 -6
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +15 -4
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +12 -8
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +12 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +3 -2
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +14 -10
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +3 -3
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +14 -10
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +5 -4
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -13
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +30 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +24 -10
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +28 -12
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +39 -18
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +13 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +20 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +14 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +13 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +16 -17
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +136 -18
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +150 -21
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +15 -3
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +26 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +15 -3
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +22 -4
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -13
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +12 -4
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +15 -3
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +15 -3
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +26 -12
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +12 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +7 -3
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +10 -6
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +11 -4
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +13 -2
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +18 -4
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +26 -5
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +13 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +13 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +26 -4
- diffusers/pipelines/transformers_loading_utils.py +121 -0
- diffusers/pipelines/unclip/pipeline_unclip.py +11 -1
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +11 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +19 -2
- diffusers/pipelines/wan/__init__.py +51 -0
- diffusers/pipelines/wan/pipeline_output.py +20 -0
- diffusers/pipelines/wan/pipeline_wan.py +593 -0
- diffusers/pipelines/wan/pipeline_wan_i2v.py +722 -0
- diffusers/pipelines/wan/pipeline_wan_video2video.py +725 -0
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +7 -31
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +12 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +12 -1
- diffusers/quantizers/auto.py +5 -1
- diffusers/quantizers/base.py +5 -9
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +41 -29
- diffusers/quantizers/bitsandbytes/utils.py +30 -20
- diffusers/quantizers/gguf/gguf_quantizer.py +1 -0
- diffusers/quantizers/gguf/utils.py +4 -2
- diffusers/quantizers/quantization_config.py +59 -4
- diffusers/quantizers/quanto/__init__.py +1 -0
- diffusers/quantizers/quanto/quanto_quantizer.py +177 -0
- diffusers/quantizers/quanto/utils.py +60 -0
- diffusers/quantizers/torchao/__init__.py +1 -1
- diffusers/quantizers/torchao/torchao_quantizer.py +47 -2
- diffusers/schedulers/__init__.py +2 -1
- diffusers/schedulers/scheduling_consistency_models.py +1 -2
- diffusers/schedulers/scheduling_ddim_inverse.py +1 -1
- diffusers/schedulers/scheduling_ddpm.py +2 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +1 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +12 -4
- diffusers/schedulers/scheduling_edm_euler.py +45 -10
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +116 -28
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +7 -6
- diffusers/schedulers/scheduling_heun_discrete.py +1 -1
- diffusers/schedulers/scheduling_lcm.py +1 -2
- diffusers/schedulers/scheduling_lms_discrete.py +1 -1
- diffusers/schedulers/scheduling_repaint.py +5 -1
- diffusers/schedulers/scheduling_scm.py +265 -0
- diffusers/schedulers/scheduling_tcd.py +1 -2
- diffusers/schedulers/scheduling_utils.py +2 -1
- diffusers/training_utils.py +14 -7
- diffusers/utils/__init__.py +9 -1
- diffusers/utils/constants.py +13 -1
- diffusers/utils/deprecation_utils.py +1 -1
- diffusers/utils/dummy_bitsandbytes_objects.py +17 -0
- diffusers/utils/dummy_gguf_objects.py +17 -0
- diffusers/utils/dummy_optimum_quanto_objects.py +17 -0
- diffusers/utils/dummy_pt_objects.py +233 -0
- diffusers/utils/dummy_torch_and_transformers_and_opencv_objects.py +17 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
- diffusers/utils/dummy_torchao_objects.py +17 -0
- diffusers/utils/dynamic_modules_utils.py +1 -1
- diffusers/utils/export_utils.py +28 -3
- diffusers/utils/hub_utils.py +52 -102
- diffusers/utils/import_utils.py +121 -221
- diffusers/utils/loading_utils.py +2 -1
- diffusers/utils/logging.py +1 -2
- diffusers/utils/peft_utils.py +6 -14
- diffusers/utils/remote_utils.py +425 -0
- diffusers/utils/source_code_parsing_utils.py +52 -0
- diffusers/utils/state_dict_utils.py +15 -1
- diffusers/utils/testing_utils.py +243 -13
- diffusers/utils/torch_utils.py +10 -0
- diffusers/utils/typing_utils.py +91 -0
- diffusers/video_processor.py +1 -1
- {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/METADATA +76 -44
- diffusers-0.33.0.dist-info/RECORD +608 -0
- {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/WHEEL +1 -1
- diffusers-0.32.2.dist-info/RECORD +0 -550
- {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/LICENSE +0 -0
- {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.32.2.dist-info → diffusers-0.33.0.dist-info}/top_level.txt +0 -0
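The largest additions in 0.33.0 are new pipeline families (CogView4, ConsisID, EasyAnimate, Lumina2, OmniGen, Sana Sprint, Wan) and the new `diffusers/hooks/` package (group offloading, layerwise casting, Pyramid Attention Broadcast, FasterCache), alongside a Quanto quantizer backend and `AutoModel`. As a minimal sketch, all of the new pipelines remain reachable through the unchanged `DiffusionPipeline` entry point — the repository id and call arguments below are placeholders chosen for illustration, not values taken from this diff:

```python
import torch
from diffusers import DiffusionPipeline

# Hypothetical checkpoint id; any repo that ships one of the new 0.33.0
# pipeline classes (e.g. WanPipeline, Lumina2Pipeline, OmniGenPipeline)
# would resolve to that class automatically.
pipe = DiffusionPipeline.from_pretrained(
    "some-org/some-new-0.33-model", torch_dtype=torch.bfloat16
)
pipe.to("cuda")

# Generic call; per-pipeline arguments (frames, guidance, etc.) differ.
result = pipe(prompt="A watercolor fox in a snowy forest", num_inference_steps=30)
```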
@@ -0,0 +1,327 @@
+# Copyright 2024 OmniGen team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import re
+from typing import Dict, List
+
+import numpy as np
+import torch
+from PIL import Image
+from torchvision import transforms
+
+
+def crop_image(pil_image, max_image_size):
+    """
+    Crop the image so that its height and width does not exceed `max_image_size`, while ensuring both the height and
+    width are multiples of 16.
+    """
+    while min(*pil_image.size) >= 2 * max_image_size:
+        pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)
+
+    if max(*pil_image.size) > max_image_size:
+        scale = max_image_size / max(*pil_image.size)
+        pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
+
+    if min(*pil_image.size) < 16:
+        scale = 16 / min(*pil_image.size)
+        pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
+
+    arr = np.array(pil_image)
+    crop_y1 = (arr.shape[0] % 16) // 2
+    crop_y2 = arr.shape[0] % 16 - crop_y1
+
+    crop_x1 = (arr.shape[1] % 16) // 2
+    crop_x2 = arr.shape[1] % 16 - crop_x1
+
+    arr = arr[crop_y1 : arr.shape[0] - crop_y2, crop_x1 : arr.shape[1] - crop_x2]
+    return Image.fromarray(arr)
+
+
+class OmniGenMultiModalProcessor:
+    def __init__(self, text_tokenizer, max_image_size: int = 1024):
+        self.text_tokenizer = text_tokenizer
+        self.max_image_size = max_image_size
+
+        self.image_transform = transforms.Compose(
+            [
+                transforms.Lambda(lambda pil_image: crop_image(pil_image, max_image_size)),
+                transforms.ToTensor(),
+                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+            ]
+        )
+
+        self.collator = OmniGenCollator()
+
+    def reset_max_image_size(self, max_image_size):
+        self.max_image_size = max_image_size
+        self.image_transform = transforms.Compose(
+            [
+                transforms.Lambda(lambda pil_image: crop_image(pil_image, max_image_size)),
+                transforms.ToTensor(),
+                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
+            ]
+        )
+
+    def process_image(self, image):
+        if isinstance(image, str):
+            image = Image.open(image).convert("RGB")
+        return self.image_transform(image)
+
+    def process_multi_modal_prompt(self, text, input_images):
+        text = self.add_prefix_instruction(text)
+        if input_images is None or len(input_images) == 0:
+            model_inputs = self.text_tokenizer(text)
+            return {"input_ids": model_inputs.input_ids, "pixel_values": None, "image_sizes": None}
+
+        pattern = r"<\|image_\d+\|>"
+        prompt_chunks = [self.text_tokenizer(chunk).input_ids for chunk in re.split(pattern, text)]
+
+        for i in range(1, len(prompt_chunks)):
+            if prompt_chunks[i][0] == 1:
+                prompt_chunks[i] = prompt_chunks[i][1:]
+
+        image_tags = re.findall(pattern, text)
+        image_ids = [int(s.split("|")[1].split("_")[-1]) for s in image_tags]
+
+        unique_image_ids = sorted(set(image_ids))
+        assert unique_image_ids == list(range(1, len(unique_image_ids) + 1)), (
+            f"image_ids must start from 1, and must be continuous int, e.g. [1, 2, 3], cannot be {unique_image_ids}"
+        )
+        # total images must be the same as the number of image tags
+        assert len(unique_image_ids) == len(input_images), (
+            f"total images must be the same as the number of image tags, got {len(unique_image_ids)} image tags and {len(input_images)} images"
+        )
+
+        input_images = [input_images[x - 1] for x in image_ids]
+
+        all_input_ids = []
+        img_inx = []
+        for i in range(len(prompt_chunks)):
+            all_input_ids.extend(prompt_chunks[i])
+            if i != len(prompt_chunks) - 1:
+                start_inx = len(all_input_ids)
+                size = input_images[i].size(-2) * input_images[i].size(-1) // 16 // 16
+                img_inx.append([start_inx, start_inx + size])
+                all_input_ids.extend([0] * size)
+
+        return {"input_ids": all_input_ids, "pixel_values": input_images, "image_sizes": img_inx}
+
+    def add_prefix_instruction(self, prompt):
+        user_prompt = "<|user|>\n"
+        generation_prompt = "Generate an image according to the following instructions\n"
+        assistant_prompt = "<|assistant|>\n<|diffusion|>"
+        prompt_suffix = "<|end|>\n"
+        prompt = f"{user_prompt}{generation_prompt}{prompt}{prompt_suffix}{assistant_prompt}"
+        return prompt
+
+    def __call__(
+        self,
+        instructions: List[str],
+        input_images: List[List[str]] = None,
+        height: int = 1024,
+        width: int = 1024,
+        negative_prompt: str = "low quality, jpeg artifacts, ugly, duplicate, morbid, mutilated, extra fingers, mutated hands, poorly drawn hands, poorly drawn face, mutation, deformed, blurry, dehydrated, bad anatomy, bad proportions, extra limbs, cloned face, disfigured, gross proportions, malformed limbs, missing arms, missing legs, extra arms, extra legs, fused fingers, too many fingers.",
+        use_img_cfg: bool = True,
+        separate_cfg_input: bool = False,
+        use_input_image_size_as_output: bool = False,
+        num_images_per_prompt: int = 1,
+    ) -> Dict:
+        if isinstance(instructions, str):
+            instructions = [instructions]
+            input_images = [input_images]
+
+        input_data = []
+        for i in range(len(instructions)):
+            cur_instruction = instructions[i]
+            cur_input_images = None if input_images is None else input_images[i]
+            if cur_input_images is not None and len(cur_input_images) > 0:
+                cur_input_images = [self.process_image(x) for x in cur_input_images]
+            else:
+                cur_input_images = None
+                assert "<img><|image_1|></img>" not in cur_instruction
+
+            mllm_input = self.process_multi_modal_prompt(cur_instruction, cur_input_images)
+
+            neg_mllm_input, img_cfg_mllm_input = None, None
+            neg_mllm_input = self.process_multi_modal_prompt(negative_prompt, None)
+            if use_img_cfg:
+                if cur_input_images is not None and len(cur_input_images) >= 1:
+                    img_cfg_prompt = [f"<img><|image_{i + 1}|></img>" for i in range(len(cur_input_images))]
+                    img_cfg_mllm_input = self.process_multi_modal_prompt(" ".join(img_cfg_prompt), cur_input_images)
+                else:
+                    img_cfg_mllm_input = neg_mllm_input
+
+            for _ in range(num_images_per_prompt):
+                if use_input_image_size_as_output:
+                    input_data.append(
+                        (
+                            mllm_input,
+                            neg_mllm_input,
+                            img_cfg_mllm_input,
+                            [mllm_input["pixel_values"][0].size(-2), mllm_input["pixel_values"][0].size(-1)],
+                        )
+                    )
+                else:
+                    input_data.append((mllm_input, neg_mllm_input, img_cfg_mllm_input, [height, width]))
+
+        return self.collator(input_data)
+
+
+class OmniGenCollator:
+    def __init__(self, pad_token_id=2, hidden_size=3072):
+        self.pad_token_id = pad_token_id
+        self.hidden_size = hidden_size
+
+    def create_position(self, attention_mask, num_tokens_for_output_images):
+        position_ids = []
+        text_length = attention_mask.size(-1)
+        img_length = max(num_tokens_for_output_images)
+        for mask in attention_mask:
+            temp_l = torch.sum(mask)
+            temp_position = [0] * (text_length - temp_l) + list(
+                range(temp_l + img_length + 1)
+            )  # we add a time embedding into the sequence, so add one more token
+            position_ids.append(temp_position)
+        return torch.LongTensor(position_ids)
+
+    def create_mask(self, attention_mask, num_tokens_for_output_images):
+        """
+        OmniGen applies causal attention to each element in the sequence, but applies bidirectional attention within
+        each image sequence References: [OmniGen](https://arxiv.org/pdf/2409.11340)
+        """
+        extended_mask = []
+        padding_images = []
+        text_length = attention_mask.size(-1)
+        img_length = max(num_tokens_for_output_images)
+        seq_len = text_length + img_length + 1  # we add a time embedding into the sequence, so add one more token
+        inx = 0
+        for mask in attention_mask:
+            temp_l = torch.sum(mask)
+            pad_l = text_length - temp_l
+
+            temp_mask = torch.tril(torch.ones(size=(temp_l + 1, temp_l + 1)))
+
+            image_mask = torch.zeros(size=(temp_l + 1, img_length))
+            temp_mask = torch.cat([temp_mask, image_mask], dim=-1)
+
+            image_mask = torch.ones(size=(img_length, temp_l + img_length + 1))
+            temp_mask = torch.cat([temp_mask, image_mask], dim=0)
+
+            if pad_l > 0:
+                pad_mask = torch.zeros(size=(temp_l + 1 + img_length, pad_l))
+                temp_mask = torch.cat([pad_mask, temp_mask], dim=-1)
+
+                pad_mask = torch.ones(size=(pad_l, seq_len))
+                temp_mask = torch.cat([pad_mask, temp_mask], dim=0)
+
+            true_img_length = num_tokens_for_output_images[inx]
+            pad_img_length = img_length - true_img_length
+            if pad_img_length > 0:
+                temp_mask[:, -pad_img_length:] = 0
+                temp_padding_imgs = torch.zeros(size=(1, pad_img_length, self.hidden_size))
+            else:
+                temp_padding_imgs = None
+
+            extended_mask.append(temp_mask.unsqueeze(0))
+            padding_images.append(temp_padding_imgs)
+            inx += 1
+        return torch.cat(extended_mask, dim=0), padding_images
+
+    def adjust_attention_for_input_images(self, attention_mask, image_sizes):
+        for b_inx in image_sizes.keys():
+            for start_inx, end_inx in image_sizes[b_inx]:
+                attention_mask[b_inx][start_inx:end_inx, start_inx:end_inx] = 1
+
+        return attention_mask
+
+    def pad_input_ids(self, input_ids, image_sizes):
+        max_l = max([len(x) for x in input_ids])
+        padded_ids = []
+        attention_mask = []
+
+        for i in range(len(input_ids)):
+            temp_ids = input_ids[i]
+            temp_l = len(temp_ids)
+            pad_l = max_l - temp_l
+            if pad_l == 0:
+                attention_mask.append([1] * max_l)
+                padded_ids.append(temp_ids)
+            else:
+                attention_mask.append([0] * pad_l + [1] * temp_l)
+                padded_ids.append([self.pad_token_id] * pad_l + temp_ids)
+
+            if i in image_sizes:
+                new_inx = []
+                for old_inx in image_sizes[i]:
+                    new_inx.append([x + pad_l for x in old_inx])
+                image_sizes[i] = new_inx
+
+        return torch.LongTensor(padded_ids), torch.LongTensor(attention_mask), image_sizes
+
+    def process_mllm_input(self, mllm_inputs, target_img_size):
+        num_tokens_for_output_images = []
+        for img_size in target_img_size:
+            num_tokens_for_output_images.append(img_size[0] * img_size[1] // 16 // 16)
+
+        pixel_values, image_sizes = [], {}
+        b_inx = 0
+        for x in mllm_inputs:
+            if x["pixel_values"] is not None:
+                pixel_values.extend(x["pixel_values"])
+                for size in x["image_sizes"]:
+                    if b_inx not in image_sizes:
+                        image_sizes[b_inx] = [size]
+                    else:
+                        image_sizes[b_inx].append(size)
+            b_inx += 1
+        pixel_values = [x.unsqueeze(0) for x in pixel_values]
+
+        input_ids = [x["input_ids"] for x in mllm_inputs]
+        padded_input_ids, attention_mask, image_sizes = self.pad_input_ids(input_ids, image_sizes)
+        position_ids = self.create_position(attention_mask, num_tokens_for_output_images)
+        attention_mask, padding_images = self.create_mask(attention_mask, num_tokens_for_output_images)
+        attention_mask = self.adjust_attention_for_input_images(attention_mask, image_sizes)
+
+        return padded_input_ids, position_ids, attention_mask, padding_images, pixel_values, image_sizes
+
+    def __call__(self, features):
+        mllm_inputs = [f[0] for f in features]
+        cfg_mllm_inputs = [f[1] for f in features]
+        img_cfg_mllm_input = [f[2] for f in features]
+        target_img_size = [f[3] for f in features]
+
+        if img_cfg_mllm_input[0] is not None:
+            mllm_inputs = mllm_inputs + cfg_mllm_inputs + img_cfg_mllm_input
+            target_img_size = target_img_size + target_img_size + target_img_size
+        else:
+            mllm_inputs = mllm_inputs + cfg_mllm_inputs
+            target_img_size = target_img_size + target_img_size
+
+        (
+            all_padded_input_ids,
+            all_position_ids,
+            all_attention_mask,
+            all_padding_images,
+            all_pixel_values,
+            all_image_sizes,
+        ) = self.process_mllm_input(mllm_inputs, target_img_size)
+
+        data = {
+            "input_ids": all_padded_input_ids,
+            "attention_mask": all_attention_mask,
+            "position_ids": all_position_ids,
+            "input_pixel_values": all_pixel_values,
+            "input_image_sizes": all_image_sizes,
+        }
+        return data
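A small usage sketch of the processor added above, to make its role concrete. The tokenizer checkpoint and subfolder below are assumptions for illustration; in normal use the OmniGen pipeline (`pipeline_omnigen.py`) is expected to construct and call this processor internally, so driving it by hand is mainly useful for inspection:

```python
from transformers import AutoTokenizer

from diffusers.pipelines.omnigen.processor_omnigen import OmniGenMultiModalProcessor

# Assumed tokenizer source; any tokenizer matching OmniGen's text backbone works the same way.
tokenizer = AutoTokenizer.from_pretrained("Shitao/OmniGen-v1-diffusers", subfolder="tokenizer")

processor = OmniGenMultiModalProcessor(tokenizer, max_image_size=1024)
batch = processor(
    ["A chair made of cactus"],  # instructions
    input_images=None,           # no <|image_k|> placeholders used here
    height=512,
    width=512,
)
# Padded ids plus the causal/bidirectional mask and positions built by OmniGenCollator:
print(batch["input_ids"].shape, batch["attention_mask"].shape, batch["position_ids"].shape)
```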
@@ -1,5 +1,5 @@
 # coding=utf-8
-# Copyright
+# Copyright 2025 The HuggingFace Inc. team.
 # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
@@ -61,7 +61,7 @@ class OnnxRuntimeModel:
         return self.model.run(None, inputs)

     @staticmethod
-    def load_model(path: Union[str, Path], provider=None, sess_options=None):
+    def load_model(path: Union[str, Path], provider=None, sess_options=None, provider_options=None):
         """
         Loads an ONNX Inference session with an ExecutionProvider. Default provider is `CPUExecutionProvider`

@@ -75,7 +75,9 @@ class OnnxRuntimeModel:
             logger.info("No onnxruntime provider specified, using CPUExecutionProvider")
             provider = "CPUExecutionProvider"

-        return ort.InferenceSession(
+        return ort.InferenceSession(
+            path, providers=[provider], sess_options=sess_options, provider_options=provider_options
+        )

     def _save_pretrained(self, save_directory: Union[str, Path], file_name: Optional[str] = None, **kwargs):
         """
@@ -158,7 +158,7 @@ class PAGMixin:
         ),
     ):
         r"""
-        Set the
+        Set the self-attention layers to apply PAG. Raise ValueError if the input is invalid.

         Args:
             pag_applied_layers (`str` or `List[str]`):
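For context, `set_pag_applied_layers` is the method PAG pipelines call with the `pag_applied_layers` argument passed at construction; a brief, hedged sketch of how it is typically reached (the checkpoint id and layer identifier are illustrative, not taken from this diff):

```python
from diffusers import AutoPipelineForText2Image

# enable_pag routes loading to the PAG variant of the pipeline, assuming PAG support
# exists for this model family; "mid" selects the mid-block self-attention layers.
pipe = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    enable_pag=True,
    pag_applied_layers=["mid"],
)
# The same selection can be changed later through the mixin method documented above:
pipe.set_pag_applied_layers(["mid"])
```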
@@ -30,6 +30,7 @@ from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     USE_PEFT_BACKEND,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -42,6 +43,13 @@ from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from .pag_utils import PAGMixin


+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@@ -251,7 +259,7 @@ class StableDiffusionControlNetPAGPipeline(
             feature_extractor=feature_extractor,
             image_encoder=image_encoder,
         )
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
         self.control_image_processor = VaeImageProcessor(
             vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
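The repeated `getattr(self, "vae", None)` change, applied to many pipelines in this release, lets a pipeline be constructed with `vae=None` (for example when components are loaded or swapped separately); the scale factor then falls back to 8, the usual 8x spatial downsampling of SD-family VAEs. The guard reduces to this standalone sketch:

```python
def compute_vae_scale_factor(vae) -> int:
    # With no VAE registered, assume the conventional 8x factor instead of crashing.
    if vae is not None:
        return 2 ** (len(vae.config.block_out_channels) - 1)
    return 8
```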
@@ -1293,6 +1301,9 @@ class StableDiffusionControlNetPAGPipeline(
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()

+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         # If we do sequential model offloading, let's offload unet and controlnet
         # manually for max memory savings
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
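The same two-part pattern is stitched into every pipeline touched below: a guarded module-level import of `torch_xla`, and an `xm.mark_step()` at the end of each denoising iteration so the lazily traced XLA graph is flushed once per step instead of growing across the whole loop. Schematically (a standalone sketch, not the pipeline code itself, and using try/except where diffusers uses `is_torch_xla_available()`):

```python
try:
    import torch_xla.core.xla_model as xm  # only present in torch_xla environments

    XLA_AVAILABLE = True
except ImportError:
    XLA_AVAILABLE = False


def denoising_loop(step_fns):
    # Each entry in step_fns stands in for one scheduler/UNet step.
    for i, step in enumerate(step_fns):
        step(i)
        if XLA_AVAILABLE:
            xm.mark_step()  # cut the traced graph once per step on XLA devices
```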
@@ -31,6 +31,7 @@ from ...models.lora import adjust_lora_scale_text_encoder
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import (
     USE_PEFT_BACKEND,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
     scale_lora_layers,
@@ -43,6 +44,13 @@ from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
 from .pag_utils import PAGMixin


+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@@ -228,7 +236,7 @@ class StableDiffusionControlNetPAGInpaintPipeline(
             feature_extractor=feature_extractor,
             image_encoder=image_encoder,
         )
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
         self.mask_processor = VaeImageProcessor(
             vae_scale_factor=self.vae_scale_factor, do_normalize=False, do_binarize=True, do_convert_grayscale=True
@@ -596,7 +604,7 @@ class StableDiffusionControlNetPAGInpaintPipeline(
         if padding_mask_crop is not None:
             if not isinstance(image, PIL.Image.Image):
                 raise ValueError(
-                    f"The image should be a PIL image when inpainting mask crop, but is of type
+                    f"The image should be a PIL image when inpainting mask crop, but is of type {type(image)}."
                 )
             if not isinstance(mask_image, PIL.Image.Image):
                 raise ValueError(
@@ -604,7 +612,7 @@ class StableDiffusionControlNetPAGInpaintPipeline(
                     f" {type(mask_image)}."
                 )
             if output_type != "pil":
-                raise ValueError(f"The output type should be PIL when inpainting mask crop, but is
+                raise ValueError(f"The output type should be PIL when inpainting mask crop, but is {output_type}.")

         # `prompt` needs more sophisticated handling when there are multiple
         # conditionings.
@@ -1332,7 +1340,7 @@ class StableDiffusionControlNetPAGInpaintPipeline(
                     f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
                     f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
                     f" `num_channels_mask`: {num_channels_mask} + `num_channels_masked_image`: {num_channels_masked_image}"
-                    f" = {num_channels_latents+num_channels_masked_image+num_channels_mask}. Please verify the config of"
+                    f" = {num_channels_latents + num_channels_masked_image + num_channels_mask}. Please verify the config of"
                     " `pipeline.unet` or your `mask_image` or `image` input."
                 )
             elif num_channels_unet != 4:
@@ -1505,6 +1513,9 @@ class StableDiffusionControlNetPAGInpaintPipeline(
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()

+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         # If we do sequential model offloading, let's offload unet and controlnet
         # manually for max memory savings
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
@@ -62,6 +62,16 @@ if is_invisible_watermark_available():
     from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker


+from ...utils import is_torch_xla_available
+
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@@ -280,7 +290,7 @@ class StableDiffusionXLControlNetPAGPipeline(
             feature_extractor=feature_extractor,
             image_encoder=image_encoder,
         )
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
         self.control_image_processor = VaeImageProcessor(
             vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
@@ -421,7 +431,9 @@ class StableDiffusionXLControlNetPAGPipeline(
                 prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)

                 # We are only ALWAYS interested in the pooled output of the final text encoder
-                pooled_prompt_embeds
+                if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+                    pooled_prompt_embeds = prompt_embeds[0]
+
                 if clip_skip is None:
                     prompt_embeds = prompt_embeds.hidden_states[-2]
                 else:
@@ -480,8 +492,10 @@ class StableDiffusionXLControlNetPAGPipeline(
                     uncond_input.input_ids.to(device),
                     output_hidden_states=True,
                 )
+
                 # We are only ALWAYS interested in the pooled output of the final text encoder
-                negative_pooled_prompt_embeds
+                if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+                    negative_pooled_prompt_embeds = negative_prompt_embeds[0]
                 negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

                 negative_prompt_embeds_list.append(negative_prompt_embeds)
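With the new `... is None` guard, user-supplied pooled embeddings are no longer overwritten inside `encode_prompt`. A hedged sketch of reusing precomputed embeddings across several generations with an SDXL-family PAG pipeline instance `pipe` (argument names follow the SDXL pipelines; for the ControlNet variants the usual `image=` control input is still required):

```python
# Encode once, reuse for several generations without re-running the text encoders.
(
    prompt_embeds,
    negative_prompt_embeds,
    pooled_prompt_embeds,
    negative_pooled_prompt_embeds,
) = pipe.encode_prompt(
    prompt="an isometric city at night",
    device="cuda",
    do_classifier_free_guidance=True,
)

image = pipe(
    prompt_embeds=prompt_embeds,
    negative_prompt_embeds=negative_prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
    # plus image=control_image for the ControlNet PAG variants
).images[0]
```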
@@ -1560,6 +1574,9 @@ class StableDiffusionXLControlNetPAGPipeline(
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()

+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if not output_type == "latent":
             # make sure the VAE is in float32 mode, as it overflows in float16
             needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
@@ -62,6 +62,16 @@ if is_invisible_watermark_available():
     from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker


+from ...utils import is_torch_xla_available
+
+
+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


@@ -270,7 +280,7 @@ class StableDiffusionXLControlNetPAGImg2ImgPipeline(
             feature_extractor=feature_extractor,
             image_encoder=image_encoder,
         )
-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True)
         self.control_image_processor = VaeImageProcessor(
             vae_scale_factor=self.vae_scale_factor, do_convert_rgb=True, do_normalize=False
@@ -413,7 +423,9 @@ class StableDiffusionXLControlNetPAGImg2ImgPipeline(
                 prompt_embeds = text_encoder(text_input_ids.to(device), output_hidden_states=True)

                 # We are only ALWAYS interested in the pooled output of the final text encoder
-                pooled_prompt_embeds
+                if pooled_prompt_embeds is None and prompt_embeds[0].ndim == 2:
+                    pooled_prompt_embeds = prompt_embeds[0]
+
                 if clip_skip is None:
                     prompt_embeds = prompt_embeds.hidden_states[-2]
                 else:
@@ -472,8 +484,10 @@ class StableDiffusionXLControlNetPAGImg2ImgPipeline(
                     uncond_input.input_ids.to(device),
                     output_hidden_states=True,
                 )
+
                 # We are only ALWAYS interested in the pooled output of the final text encoder
-                negative_pooled_prompt_embeds
+                if negative_pooled_prompt_embeds is None and negative_prompt_embeds[0].ndim == 2:
+                    negative_pooled_prompt_embeds = negative_prompt_embeds[0]
                 negative_prompt_embeds = negative_prompt_embeds.hidden_states[-2]

                 negative_prompt_embeds_list.append(negative_prompt_embeds)
@@ -1626,6 +1640,9 @@ class StableDiffusionXLControlNetPAGImg2ImgPipeline(
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                     progress_bar.update()

+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         # If we do sequential model offloading, let's offload unet and controlnet
         # manually for max memory savings
         if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
@@ -245,9 +245,7 @@ class HunyuanDiTPAGPipeline(DiffusionPipeline, PAGMixin):
                 " checker. If you do not want to use the safety checker, you can pass `'safety_checker=None'` instead."
             )

-        self.vae_scale_factor = (
-            2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
-        )
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
         self.register_to_config(requires_safety_checker=requires_safety_checker)
         self.default_sample_size = (
@@ -202,12 +202,14 @@ class KolorsPAGPipeline(
             feature_extractor=feature_extractor,
         )
         self.register_to_config(force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
-        self.vae_scale_factor = (
-            2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
-        )
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)

-        self.default_sample_size =
+        self.default_sample_size = (
+            self.unet.config.sample_size
+            if hasattr(self, "unet") and self.unet is not None and hasattr(self.unet.config, "sample_size")
+            else 128
+        )

         self.set_pag_applied_layers(pag_applied_layers)

@@ -29,6 +29,7 @@ from ...utils import (
     deprecate,
     is_bs4_available,
     is_ftfy_available,
+    is_torch_xla_available,
     logging,
     replace_example_docstring,
 )
@@ -43,8 +44,16 @@ from ..pixart_alpha.pipeline_pixart_sigma import ASPECT_RATIO_2048_BIN
 from .pag_utils import PAGMixin


+if is_torch_xla_available():
+    import torch_xla.core.xla_model as xm
+
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name

+
 if is_bs4_available():
     from bs4 import BeautifulSoup

@@ -172,7 +181,7 @@ class PixArtSigmaPAGPipeline(DiffusionPipeline, PAGMixin):
             tokenizer=tokenizer, text_encoder=text_encoder, vae=vae, transformer=transformer, scheduler=scheduler
         )

-        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
+        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
         self.image_processor = PixArtImageProcessor(vae_scale_factor=self.vae_scale_factor)

         self.set_pag_applied_layers(pag_applied_layers)
@@ -798,10 +807,11 @@ class PixArtSigmaPAGPipeline(DiffusionPipeline, PAGMixin):
                     # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
                     # This would be a good case for the `match` statement (Python 3.10+)
                     is_mps = latent_model_input.device.type == "mps"
+                    is_npu = latent_model_input.device.type == "npu"
                     if isinstance(current_timestep, float):
-                        dtype = torch.float32 if is_mps else torch.float64
+                        dtype = torch.float32 if (is_mps or is_npu) else torch.float64
                     else:
-                        dtype = torch.int32 if is_mps else torch.int64
+                        dtype = torch.int32 if (is_mps or is_npu) else torch.int64
                     current_timestep = torch.tensor([current_timestep], dtype=dtype, device=latent_model_input.device)
                 elif len(current_timestep.shape) == 0:
                     current_timestep = current_timestep[None].to(latent_model_input.device)
@@ -843,6 +853,9 @@ class PixArtSigmaPAGPipeline(DiffusionPipeline, PAGMixin):
                     step_idx = i // getattr(self.scheduler, "order", 1)
                     callback(step_idx, t, latents)

+                if XLA_AVAILABLE:
+                    xm.mark_step()
+
         if not output_type == "latent":
             image = self.vae.decode(latents / self.vae.config.scaling_factor, return_dict=False)[0]
             if use_resolution_binning:
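The `is_npu` addition extends the existing MPS special case: the MPS and Ascend NPU backends do not reliably support 64-bit float/int tensors, so the scalar timestep is materialized at 32-bit precision on those devices. As a standalone sketch, the selection logic boils down to:

```python
import torch


def timestep_dtype(device_type: str, timestep_is_float: bool) -> torch.dtype:
    # MPS and NPU lack dependable 64-bit dtypes, so fall back to 32-bit there.
    constrained = device_type in ("mps", "npu")
    if timestep_is_float:
        return torch.float32 if constrained else torch.float64
    return torch.int32 if constrained else torch.int64
```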
|