diffusers-0.27.0-py3-none-any.whl → diffusers-0.32.2-py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- diffusers/__init__.py +233 -6
- diffusers/callbacks.py +209 -0
- diffusers/commands/env.py +102 -6
- diffusers/configuration_utils.py +45 -16
- diffusers/dependency_versions_table.py +4 -3
- diffusers/image_processor.py +434 -110
- diffusers/loaders/__init__.py +42 -9
- diffusers/loaders/ip_adapter.py +626 -36
- diffusers/loaders/lora_base.py +900 -0
- diffusers/loaders/lora_conversion_utils.py +991 -125
- diffusers/loaders/lora_pipeline.py +3812 -0
- diffusers/loaders/peft.py +571 -7
- diffusers/loaders/single_file.py +405 -173
- diffusers/loaders/single_file_model.py +385 -0
- diffusers/loaders/single_file_utils.py +1783 -713
- diffusers/loaders/textual_inversion.py +41 -23
- diffusers/loaders/transformer_flux.py +181 -0
- diffusers/loaders/transformer_sd3.py +89 -0
- diffusers/loaders/unet.py +464 -540
- diffusers/loaders/unet_loader_utils.py +163 -0
- diffusers/models/__init__.py +76 -7
- diffusers/models/activations.py +65 -10
- diffusers/models/adapter.py +53 -53
- diffusers/models/attention.py +605 -18
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +4304 -687
- diffusers/models/autoencoders/__init__.py +8 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +15 -17
- diffusers/models/autoencoders/autoencoder_dc.py +620 -0
- diffusers/models/autoencoders/autoencoder_kl.py +110 -28
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1482 -0
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +19 -24
- diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
- diffusers/models/autoencoders/autoencoder_tiny.py +21 -18
- diffusers/models/autoencoders/consistency_decoder_vae.py +45 -20
- diffusers/models/autoencoders/vae.py +41 -29
- diffusers/models/autoencoders/vq_model.py +182 -0
- diffusers/models/controlnet.py +47 -800
- diffusers/models/controlnet_flux.py +70 -0
- diffusers/models/controlnet_sd3.py +68 -0
- diffusers/models/controlnet_sparsectrl.py +116 -0
- diffusers/models/controlnets/__init__.py +23 -0
- diffusers/models/controlnets/controlnet.py +872 -0
- diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +9 -9
- diffusers/models/controlnets/controlnet_flux.py +536 -0
- diffusers/models/controlnets/controlnet_hunyuan.py +401 -0
- diffusers/models/controlnets/controlnet_sd3.py +489 -0
- diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
- diffusers/models/controlnets/controlnet_union.py +832 -0
- diffusers/models/controlnets/controlnet_xs.py +1946 -0
- diffusers/models/controlnets/multicontrolnet.py +183 -0
- diffusers/models/downsampling.py +85 -18
- diffusers/models/embeddings.py +1856 -158
- diffusers/models/embeddings_flax.py +23 -9
- diffusers/models/model_loading_utils.py +480 -0
- diffusers/models/modeling_flax_pytorch_utils.py +2 -1
- diffusers/models/modeling_flax_utils.py +2 -7
- diffusers/models/modeling_outputs.py +14 -0
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +611 -146
- diffusers/models/normalization.py +361 -20
- diffusers/models/resnet.py +18 -23
- diffusers/models/transformers/__init__.py +16 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +544 -0
- diffusers/models/transformers/cogvideox_transformer_3d.py +542 -0
- diffusers/models/transformers/dit_transformer_2d.py +240 -0
- diffusers/models/transformers/dual_transformer_2d.py +9 -8
- diffusers/models/transformers/hunyuan_transformer_2d.py +578 -0
- diffusers/models/transformers/latte_transformer_3d.py +327 -0
- diffusers/models/transformers/lumina_nextdit2d.py +340 -0
- diffusers/models/transformers/pixart_transformer_2d.py +445 -0
- diffusers/models/transformers/prior_transformer.py +13 -13
- diffusers/models/transformers/sana_transformer.py +488 -0
- diffusers/models/transformers/stable_audio_transformer.py +458 -0
- diffusers/models/transformers/t5_film_transformer.py +17 -19
- diffusers/models/transformers/transformer_2d.py +297 -187
- diffusers/models/transformers/transformer_allegro.py +422 -0
- diffusers/models/transformers/transformer_cogview3plus.py +386 -0
- diffusers/models/transformers/transformer_flux.py +593 -0
- diffusers/models/transformers/transformer_hunyuan_video.py +791 -0
- diffusers/models/transformers/transformer_ltx.py +469 -0
- diffusers/models/transformers/transformer_mochi.py +499 -0
- diffusers/models/transformers/transformer_sd3.py +461 -0
- diffusers/models/transformers/transformer_temporal.py +21 -19
- diffusers/models/unets/unet_1d.py +8 -8
- diffusers/models/unets/unet_1d_blocks.py +31 -31
- diffusers/models/unets/unet_2d.py +17 -10
- diffusers/models/unets/unet_2d_blocks.py +225 -149
- diffusers/models/unets/unet_2d_condition.py +50 -53
- diffusers/models/unets/unet_2d_condition_flax.py +6 -5
- diffusers/models/unets/unet_3d_blocks.py +192 -1057
- diffusers/models/unets/unet_3d_condition.py +22 -27
- diffusers/models/unets/unet_i2vgen_xl.py +22 -18
- diffusers/models/unets/unet_kandinsky3.py +2 -2
- diffusers/models/unets/unet_motion_model.py +1413 -89
- diffusers/models/unets/unet_spatio_temporal_condition.py +40 -16
- diffusers/models/unets/unet_stable_cascade.py +19 -18
- diffusers/models/unets/uvit_2d.py +2 -2
- diffusers/models/upsampling.py +95 -26
- diffusers/models/vq_model.py +12 -164
- diffusers/optimization.py +1 -1
- diffusers/pipelines/__init__.py +202 -3
- diffusers/pipelines/allegro/__init__.py +48 -0
- diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
- diffusers/pipelines/allegro/pipeline_output.py +23 -0
- diffusers/pipelines/amused/pipeline_amused.py +12 -12
- diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
- diffusers/pipelines/animatediff/__init__.py +8 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +122 -109
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1106 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1288 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1010 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +236 -180
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
- diffusers/pipelines/animatediff/pipeline_output.py +3 -2
- diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +58 -39
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +121 -36
- diffusers/pipelines/aura_flow/__init__.py +48 -0
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +584 -0
- diffusers/pipelines/auto_pipeline.py +196 -28
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +6 -6
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
- diffusers/pipelines/cogvideo/__init__.py +54 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +772 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +885 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +851 -0
- diffusers/pipelines/cogvideo/pipeline_output.py +20 -0
- diffusers/pipelines/cogview3/__init__.py +47 -0
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
- diffusers/pipelines/cogview3/pipeline_output.py +21 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +6 -6
- diffusers/pipelines/controlnet/__init__.py +86 -80
- diffusers/pipelines/controlnet/multicontrolnet.py +7 -182
- diffusers/pipelines/controlnet/pipeline_controlnet.py +134 -87
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +93 -77
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +88 -197
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +136 -90
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +176 -80
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +125 -89
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
- diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1060 -0
- diffusers/pipelines/controlnet_sd3/__init__.py +57 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +1133 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
- diffusers/pipelines/controlnet_xs/__init__.py +68 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +916 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1111 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +16 -30
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +20 -35
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +23 -41
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +22 -38
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +25 -41
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +19 -34
- diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
- diffusers/pipelines/deepfloyd_if/watermark.py +1 -1
- diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +70 -30
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +48 -25
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +21 -20
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +27 -29
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +33 -27
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +33 -23
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +36 -30
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +102 -69
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
- diffusers/pipelines/dit/pipeline_dit.py +7 -4
- diffusers/pipelines/flux/__init__.py +69 -0
- diffusers/pipelines/flux/modeling_flux.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +957 -0
- diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
- diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
- diffusers/pipelines/flux/pipeline_output.py +37 -0
- diffusers/pipelines/free_init_utils.py +41 -38
- diffusers/pipelines/free_noise_utils.py +596 -0
- diffusers/pipelines/hunyuan_video/__init__.py +48 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
- diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
- diffusers/pipelines/hunyuandit/__init__.py +48 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +916 -0
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +32 -29
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +34 -31
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
- diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +22 -35
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +26 -37
- diffusers/pipelines/kolors/__init__.py +54 -0
- diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1250 -0
- diffusers/pipelines/kolors/pipeline_output.py +21 -0
- diffusers/pipelines/kolors/text_encoder.py +889 -0
- diffusers/pipelines/kolors/tokenizer.py +338 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +82 -62
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +77 -60
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +12 -12
- diffusers/pipelines/latte/__init__.py +48 -0
- diffusers/pipelines/latte/pipeline_latte.py +881 -0
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +80 -74
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +85 -76
- diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
- diffusers/pipelines/ltx/__init__.py +50 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
- diffusers/pipelines/ltx/pipeline_output.py +20 -0
- diffusers/pipelines/lumina/__init__.py +48 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +890 -0
- diffusers/pipelines/marigold/__init__.py +50 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +576 -0
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
- diffusers/pipelines/mochi/__init__.py +48 -0
- diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
- diffusers/pipelines/mochi/pipeline_output.py +20 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
- diffusers/pipelines/pag/__init__.py +80 -0
- diffusers/pipelines/pag/pag_utils.py +243 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1328 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1610 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +969 -0
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +865 -0
- diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
- diffusers/pipelines/pag/pipeline_pag_sd.py +1062 -0
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +994 -0
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +866 -0
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1345 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1544 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1776 -0
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
- diffusers/pipelines/pia/pipeline_pia.py +74 -164
- diffusers/pipelines/pipeline_flax_utils.py +5 -10
- diffusers/pipelines/pipeline_loading_utils.py +515 -53
- diffusers/pipelines/pipeline_utils.py +411 -222
- diffusers/pipelines/pixart_alpha/__init__.py +8 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +76 -93
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +873 -0
- diffusers/pipelines/sana/__init__.py +47 -0
- diffusers/pipelines/sana/pipeline_output.py +21 -0
- diffusers/pipelines/sana/pipeline_sana.py +884 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +27 -23
- diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
- diffusers/pipelines/shap_e/renderer.py +1 -1
- diffusers/pipelines/stable_audio/__init__.py +50 -0
- diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +756 -0
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +71 -25
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +35 -34
- diffusers/pipelines/stable_diffusion/__init__.py +0 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +20 -11
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +145 -79
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +43 -28
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +100 -68
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +109 -201
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +131 -32
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +247 -87
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +30 -29
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +35 -27
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +49 -42
- diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
- diffusers/pipelines/stable_diffusion_3/__init__.py +54 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_output.py +21 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +1140 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +1036 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1250 -0
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +29 -20
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +59 -58
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +31 -25
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +38 -22
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +30 -24
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -23
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +107 -67
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +316 -69
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +98 -30
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +121 -83
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +161 -105
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +142 -218
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -29
- diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +69 -39
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +105 -74
- diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +29 -49
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +32 -93
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +37 -25
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +54 -40
- diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
- diffusers/pipelines/unidiffuser/modeling_uvit.py +12 -12
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +29 -28
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +6 -8
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +15 -14
- diffusers/{models/dual_transformer_2d.py → quantizers/__init__.py} +2 -6
- diffusers/quantizers/auto.py +139 -0
- diffusers/quantizers/base.py +233 -0
- diffusers/quantizers/bitsandbytes/__init__.py +2 -0
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
- diffusers/quantizers/bitsandbytes/utils.py +306 -0
- diffusers/quantizers/gguf/__init__.py +1 -0
- diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
- diffusers/quantizers/gguf/utils.py +456 -0
- diffusers/quantizers/quantization_config.py +669 -0
- diffusers/quantizers/torchao/__init__.py +15 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +292 -0
- diffusers/schedulers/__init__.py +12 -2
- diffusers/schedulers/deprecated/__init__.py +1 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
- diffusers/schedulers/scheduling_amused.py +5 -5
- diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
- diffusers/schedulers/scheduling_consistency_models.py +23 -25
- diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
- diffusers/schedulers/scheduling_ddim.py +27 -26
- diffusers/schedulers/scheduling_ddim_cogvideox.py +452 -0
- diffusers/schedulers/scheduling_ddim_flax.py +2 -1
- diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
- diffusers/schedulers/scheduling_ddim_parallel.py +32 -31
- diffusers/schedulers/scheduling_ddpm.py +27 -30
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +33 -36
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
- diffusers/schedulers/scheduling_deis_multistep.py +150 -50
- diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +221 -84
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +158 -52
- diffusers/schedulers/scheduling_dpmsolver_sde.py +153 -34
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +275 -86
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +81 -57
- diffusers/schedulers/scheduling_edm_euler.py +62 -39
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +30 -29
- diffusers/schedulers/scheduling_euler_discrete.py +255 -74
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +458 -0
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +320 -0
- diffusers/schedulers/scheduling_heun_discrete.py +174 -46
- diffusers/schedulers/scheduling_ipndm.py +9 -9
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +138 -29
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +132 -26
- diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
- diffusers/schedulers/scheduling_lcm.py +23 -29
- diffusers/schedulers/scheduling_lms_discrete.py +105 -28
- diffusers/schedulers/scheduling_pndm.py +20 -20
- diffusers/schedulers/scheduling_repaint.py +21 -21
- diffusers/schedulers/scheduling_sasolver.py +157 -60
- diffusers/schedulers/scheduling_sde_ve.py +19 -19
- diffusers/schedulers/scheduling_tcd.py +41 -36
- diffusers/schedulers/scheduling_unclip.py +19 -16
- diffusers/schedulers/scheduling_unipc_multistep.py +243 -47
- diffusers/schedulers/scheduling_utils.py +12 -5
- diffusers/schedulers/scheduling_utils_flax.py +1 -3
- diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
- diffusers/training_utils.py +214 -30
- diffusers/utils/__init__.py +17 -1
- diffusers/utils/constants.py +3 -0
- diffusers/utils/doc_utils.py +1 -0
- diffusers/utils/dummy_pt_objects.py +592 -7
- diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
- diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +1001 -71
- diffusers/utils/dynamic_modules_utils.py +34 -29
- diffusers/utils/export_utils.py +50 -6
- diffusers/utils/hub_utils.py +131 -17
- diffusers/utils/import_utils.py +210 -8
- diffusers/utils/loading_utils.py +118 -5
- diffusers/utils/logging.py +4 -2
- diffusers/utils/peft_utils.py +37 -7
- diffusers/utils/state_dict_utils.py +13 -2
- diffusers/utils/testing_utils.py +193 -11
- diffusers/utils/torch_utils.py +4 -0
- diffusers/video_processor.py +113 -0
- {diffusers-0.27.0.dist-info → diffusers-0.32.2.dist-info}/METADATA +82 -91
- diffusers-0.32.2.dist-info/RECORD +550 -0
- {diffusers-0.27.0.dist-info → diffusers-0.32.2.dist-info}/WHEEL +1 -1
- diffusers/loaders/autoencoder.py +0 -146
- diffusers/loaders/controlnet.py +0 -136
- diffusers/loaders/lora.py +0 -1349
- diffusers/models/prior_transformer.py +0 -12
- diffusers/models/t5_film_transformer.py +0 -70
- diffusers/models/transformer_2d.py +0 -25
- diffusers/models/transformer_temporal.py +0 -34
- diffusers/models/unet_1d.py +0 -26
- diffusers/models/unet_1d_blocks.py +0 -203
- diffusers/models/unet_2d.py +0 -27
- diffusers/models/unet_2d_blocks.py +0 -375
- diffusers/models/unet_2d_condition.py +0 -25
- diffusers-0.27.0.dist-info/RECORD +0 -399
- {diffusers-0.27.0.dist-info → diffusers-0.32.2.dist-info}/LICENSE +0 -0
- {diffusers-0.27.0.dist-info → diffusers-0.32.2.dist-info}/entry_points.txt +0 -0
- {diffusers-0.27.0.dist-info → diffusers-0.32.2.dist-info}/top_level.txt +0 -0
@@ -13,6 +13,7 @@
|
|
13
13
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14
14
|
# See the License for the specific language governing permissions and
|
15
15
|
# limitations under the License.
|
16
|
+
import enum
|
16
17
|
import fnmatch
|
17
18
|
import importlib
|
18
19
|
import inspect
|
@@ -21,7 +22,7 @@ import re
|
|
21
22
|
import sys
|
22
23
|
from dataclasses import dataclass
|
23
24
|
from pathlib import Path
|
24
|
-
from typing import Any, Callable, Dict, List, Optional, Union
|
25
|
+
from typing import Any, Callable, Dict, List, Optional, Union, get_args, get_origin
|
25
26
|
|
26
27
|
import numpy as np
|
27
28
|
import PIL.Image
|
@@ -43,38 +44,45 @@ from .. import __version__
|
|
43
44
|
from ..configuration_utils import ConfigMixin
|
44
45
|
from ..models import AutoencoderKL
|
45
46
|
from ..models.attention_processor import FusedAttnProcessor2_0
|
46
|
-
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT
|
47
|
+
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin
|
48
|
+
from ..quantizers.bitsandbytes.utils import _check_bnb_status
|
47
49
|
from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
|
48
50
|
from ..utils import (
|
49
51
|
CONFIG_NAME,
|
50
52
|
DEPRECATED_REVISION_ARGS,
|
51
53
|
BaseOutput,
|
52
54
|
PushToHubMixin,
|
53
|
-
deprecate,
|
54
55
|
is_accelerate_available,
|
55
56
|
is_accelerate_version,
|
56
57
|
is_torch_npu_available,
|
57
58
|
is_torch_version,
|
59
|
+
is_transformers_version,
|
58
60
|
logging,
|
59
61
|
numpy_to_pil,
|
60
62
|
)
|
61
|
-
from ..utils.hub_utils import load_or_create_model_card, populate_model_card
|
63
|
+
from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card
|
62
64
|
from ..utils.torch_utils import is_compiled_module
|
63
65
|
|
64
66
|
|
65
67
|
if is_torch_npu_available():
|
66
68
|
import torch_npu # noqa: F401
|
67
69
|
|
68
|
-
|
69
70
|
from .pipeline_loading_utils import (
|
70
71
|
ALL_IMPORTABLE_CLASSES,
|
71
72
|
CONNECTED_PIPES_KEYS,
|
72
73
|
CUSTOM_PIPELINE_FILE_NAME,
|
73
74
|
LOADABLE_CLASSES,
|
74
75
|
_fetch_class_library_tuple,
|
76
|
+
_get_custom_components_and_folders,
|
77
|
+
_get_custom_pipeline_class,
|
78
|
+
_get_final_device_map,
|
79
|
+
_get_ignore_patterns,
|
75
80
|
_get_pipeline_class,
|
81
|
+
_identify_model_variants,
|
82
|
+
_maybe_raise_warning_for_inpainting,
|
83
|
+
_resolve_custom_pipeline_and_cls,
|
76
84
|
_unwrap_model,
|
77
|
-
|
85
|
+
_update_init_kwargs_with_connected_pipeline,
|
78
86
|
load_sub_model,
|
79
87
|
maybe_raise_or_warn,
|
80
88
|
variant_compatible_siblings,
|
@@ -90,6 +98,8 @@ LIBRARIES = []
|
|
90
98
|
for library in LOADABLE_CLASSES:
|
91
99
|
LIBRARIES.append(library)
|
92
100
|
|
101
|
+
SUPPORTED_DEVICE_MAP = ["balanced"]
|
102
|
+
|
93
103
|
logger = logging.get_logger(__name__)
|
94
104
|
|
95
105
|
|
@@ -140,6 +150,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
140
150
|
|
141
151
|
config_name = "model_index.json"
|
142
152
|
model_cpu_offload_seq = None
|
153
|
+
hf_device_map = None
|
143
154
|
_optional_components = []
|
144
155
|
_exclude_from_cpu_offload = []
|
145
156
|
_load_connected_pipes = False
|
@@ -180,6 +191,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
180
191
|
save_directory: Union[str, os.PathLike],
|
181
192
|
safe_serialization: bool = True,
|
182
193
|
variant: Optional[str] = None,
|
194
|
+
max_shard_size: Optional[Union[int, str]] = None,
|
183
195
|
push_to_hub: bool = False,
|
184
196
|
**kwargs,
|
185
197
|
):
|
@@ -195,6 +207,13 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
195
207
|
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
|
196
208
|
variant (`str`, *optional*):
|
197
209
|
If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
|
210
|
+
max_shard_size (`int` or `str`, defaults to `None`):
|
211
|
+
The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
|
212
|
+
lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5GB"`).
|
213
|
+
If expressed as an integer, the unit is bytes. Note that this limit will be decreased after a certain
|
214
|
+
period of time (starting from Oct 2024) to allow users to upgrade to the latest version of `diffusers`.
|
215
|
+
This is to establish a common default size for this argument across different libraries in the Hugging
|
216
|
+
Face ecosystem (`transformers`, and `accelerate`, for example).
|
198
217
|
push_to_hub (`bool`, *optional*, defaults to `False`):
|
199
218
|
Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
|
200
219
|
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
|
@@ -210,7 +229,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
210
229
|
|
211
230
|
if push_to_hub:
|
212
231
|
commit_message = kwargs.pop("commit_message", None)
|
213
|
-
private = kwargs.pop("private",
|
232
|
+
private = kwargs.pop("private", None)
|
214
233
|
create_pr = kwargs.pop("create_pr", False)
|
215
234
|
token = kwargs.pop("token", None)
|
216
235
|
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
|
@@ -269,12 +288,16 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
269
288
|
save_method_signature = inspect.signature(save_method)
|
270
289
|
save_method_accept_safe = "safe_serialization" in save_method_signature.parameters
|
271
290
|
save_method_accept_variant = "variant" in save_method_signature.parameters
|
291
|
+
save_method_accept_max_shard_size = "max_shard_size" in save_method_signature.parameters
|
272
292
|
|
273
293
|
save_kwargs = {}
|
274
294
|
if save_method_accept_safe:
|
275
295
|
save_kwargs["safe_serialization"] = safe_serialization
|
276
296
|
if save_method_accept_variant:
|
277
297
|
save_kwargs["variant"] = variant
|
298
|
+
if save_method_accept_max_shard_size and max_shard_size is not None:
|
299
|
+
# max_shard_size is expected to not be None in ModelMixin
|
300
|
+
save_kwargs["max_shard_size"] = max_shard_size
|
278
301
|
|
279
302
|
save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)
|
280
303
|
|
@@ -365,14 +388,17 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
365
388
|
)
|
366
389
|
|
367
390
|
device = device or device_arg
|
391
|
+
pipeline_has_bnb = any(any((_check_bnb_status(module))) for _, module in self.components.items())
|
368
392
|
|
369
393
|
# throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU.
|
370
394
|
def module_is_sequentially_offloaded(module):
|
371
395
|
if not is_accelerate_available() or is_accelerate_version("<", "0.14.0"):
|
372
396
|
return False
|
373
397
|
|
374
|
-
return hasattr(module, "_hf_hook") and
|
375
|
-
module._hf_hook,
|
398
|
+
return hasattr(module, "_hf_hook") and (
|
399
|
+
isinstance(module._hf_hook, accelerate.hooks.AlignDevicesHook)
|
400
|
+
or hasattr(module._hf_hook, "hooks")
|
401
|
+
and isinstance(module._hf_hook.hooks[0], accelerate.hooks.AlignDevicesHook)
|
376
402
|
)
|
377
403
|
|
378
404
|
def module_is_offloaded(module):
|
@@ -385,9 +411,21 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
385
411
|
pipeline_is_sequentially_offloaded = any(
|
386
412
|
module_is_sequentially_offloaded(module) for _, module in self.components.items()
|
387
413
|
)
|
388
|
-
if
|
414
|
+
if device and torch.device(device).type == "cuda":
|
415
|
+
if pipeline_is_sequentially_offloaded and not pipeline_has_bnb:
|
416
|
+
raise ValueError(
|
417
|
+
"It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
|
418
|
+
)
|
419
|
+
# PR: https://github.com/huggingface/accelerate/pull/3223/
|
420
|
+
elif pipeline_has_bnb and is_accelerate_version("<", "1.1.0.dev0"):
|
421
|
+
raise ValueError(
|
422
|
+
"You are trying to call `.to('cuda')` on a pipeline that has models quantized with `bitsandbytes`. Your current `accelerate` installation does not support it. Please upgrade the installation."
|
423
|
+
)
|
424
|
+
|
425
|
+
is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
|
426
|
+
if is_pipeline_device_mapped:
|
389
427
|
raise ValueError(
|
390
|
-
"It seems like you have activated
|
428
|
+
"It seems like you have activated a device mapping strategy on the pipeline which doesn't allow explicit device placement using `to()`. You can call `reset_device_map()` first and then call `to()`."
|
391
429
|
)
|
392
430
|
|
393
431
|
# Display a warning in this case (the operation succeeds but the benefits are lost)
|
@@ -403,18 +441,23 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
403
441
|
|
404
442
|
is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded
|
405
443
|
for module in modules:
|
406
|
-
|
444
|
+
_, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb = _check_bnb_status(module)
|
407
445
|
|
408
|
-
if
|
446
|
+
if (is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb) and dtype is not None:
|
409
447
|
logger.warning(
|
410
|
-
f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to {dtype} is not
|
448
|
+
f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {'4bit' if is_loaded_in_4bit_bnb else '8bit'} and conversion to {dtype} is not supported. Module is still in {'4bit' if is_loaded_in_4bit_bnb else '8bit'} precision."
|
411
449
|
)
|
412
450
|
|
413
|
-
if
|
451
|
+
if is_loaded_in_8bit_bnb and device is not None:
|
414
452
|
logger.warning(
|
415
|
-
f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {
|
453
|
+
f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` 8bit and moving it to {device} via `.to()` is not supported. Module is still on {module.device}."
|
416
454
|
)
|
417
|
-
|
455
|
+
|
456
|
+
# This can happen for `transformer` models. CPU placement was added in
|
457
|
+
# https://github.com/huggingface/transformers/pull/33122. So, we guard this accordingly.
|
458
|
+
if is_loaded_in_4bit_bnb and device is not None and is_transformers_version(">", "4.44.0"):
|
459
|
+
module.to(device=device)
|
460
|
+
elif not is_loaded_in_4bit_bnb and not is_loaded_in_8bit_bnb:
|
418
461
|
module.to(device, dtype)
|
419
462
|
|
420
463
|
if (
|
@@ -520,9 +563,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
520
563
|
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
521
564
|
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
522
565
|
is not used.
|
523
|
-
|
524
|
-
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
525
|
-
incompletely downloaded files are deleted.
|
566
|
+
|
526
567
|
proxies (`Dict[str, str]`, *optional*):
|
527
568
|
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
528
569
|
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
@@ -539,7 +580,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
539
580
|
allowed by Git.
|
540
581
|
custom_revision (`str`, *optional*):
|
541
582
|
The specific model version to use. It can be a branch name, a tag name, or a commit id similar to
|
542
|
-
`revision` when loading a custom pipeline from the Hub. Defaults to the latest stable 🤗 Diffusers
|
583
|
+
`revision` when loading a custom pipeline from the Hub. Defaults to the latest stable 🤗 Diffusers
|
584
|
+
version.
|
543
585
|
mirror (`str`, *optional*):
|
544
586
|
Mirror source to resolve accessibility issues if you’re downloading a model in China. We do not
|
545
587
|
guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
|
@@ -610,8 +652,10 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
610
652
|
>>> pipeline.scheduler = scheduler
|
611
653
|
```
|
612
654
|
"""
|
655
|
+
# Copy the kwargs to re-use during loading connected pipeline.
|
656
|
+
kwargs_copied = kwargs.copy()
|
657
|
+
|
613
658
|
cache_dir = kwargs.pop("cache_dir", None)
|
614
|
-
resume_download = kwargs.pop("resume_download", False)
|
615
659
|
force_download = kwargs.pop("force_download", False)
|
616
660
|
proxies = kwargs.pop("proxies", None)
|
617
661
|
local_files_only = kwargs.pop("local_files_only", None)
|
@@ -642,18 +686,35 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
642
686
|
" install accelerate\n```\n."
|
643
687
|
)
|
644
688
|
|
689
|
+
if low_cpu_mem_usage is True and not is_torch_version(">=", "1.9.0"):
|
690
|
+
raise NotImplementedError(
|
691
|
+
"Low memory initialization requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
692
|
+
" `low_cpu_mem_usage=False`."
|
693
|
+
)
|
694
|
+
|
645
695
|
if device_map is not None and not is_torch_version(">=", "1.9.0"):
|
646
696
|
raise NotImplementedError(
|
647
697
|
"Loading and dispatching requires torch >= 1.9.0. Please either update your PyTorch version or set"
|
648
698
|
" `device_map=None`."
|
649
699
|
)
|
650
700
|
|
651
|
-
if
|
701
|
+
if device_map is not None and not is_accelerate_available():
|
652
702
|
raise NotImplementedError(
|
653
|
-
"
|
654
|
-
" `low_cpu_mem_usage=False`."
|
703
|
+
"Using `device_map` requires the `accelerate` library. Please install it using: `pip install accelerate`."
|
655
704
|
)
|
656
705
|
|
706
|
+
if device_map is not None and not isinstance(device_map, str):
|
707
|
+
raise ValueError("`device_map` must be a string.")
|
708
|
+
|
709
|
+
if device_map is not None and device_map not in SUPPORTED_DEVICE_MAP:
|
710
|
+
raise NotImplementedError(
|
711
|
+
f"{device_map} not supported. Supported strategies are: {', '.join(SUPPORTED_DEVICE_MAP)}"
|
712
|
+
)
|
713
|
+
|
714
|
+
if device_map is not None and device_map in SUPPORTED_DEVICE_MAP:
|
715
|
+
if is_accelerate_version("<", "0.28.0"):
|
716
|
+
raise NotImplementedError("Device placement requires `accelerate` version `0.28.0` or later.")
|
717
|
+
|
657
718
|
if low_cpu_mem_usage is False and device_map is not None:
|
658
719
|
raise ValueError(
|
659
720
|
f"You cannot set `low_cpu_mem_usage` to False while using device_map={device_map} for loading and"
|
@@ -671,7 +732,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
671
732
|
cached_folder = cls.download(
|
672
733
|
pretrained_model_name_or_path,
|
673
734
|
cache_dir=cache_dir,
|
674
|
-
resume_download=resume_download,
|
675
735
|
force_download=force_download,
|
676
736
|
proxies=proxies,
|
677
737
|
local_files_only=local_files_only,
|
@@ -689,39 +749,43 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
689
749
|
else:
|
690
750
|
cached_folder = pretrained_model_name_or_path
|
691
751
|
|
752
|
+
# The variant filenames can have the legacy sharding checkpoint format that we check and throw
|
753
|
+
# a warning if detected.
|
754
|
+
if variant is not None and _check_legacy_sharding_variant_format(folder=cached_folder, variant=variant):
|
755
|
+
warn_msg = (
|
756
|
+
f"Warning: The repository contains sharded checkpoints for variant '{variant}' maybe in a deprecated format. "
|
757
|
+
"Please check your files carefully:\n\n"
|
758
|
+
"- Correct format example: diffusion_pytorch_model.fp16-00003-of-00003.safetensors\n"
|
759
|
+
"- Deprecated format example: diffusion_pytorch_model-00001-of-00002.fp16.safetensors\n\n"
|
760
|
+
"If you find any files in the deprecated format:\n"
|
761
|
+
"1. Remove all existing checkpoint files for this variant.\n"
|
762
|
+
"2. Re-obtain the correct files by running `save_pretrained()`.\n\n"
|
763
|
+
"This will ensure you're using the most up-to-date and compatible checkpoint format."
|
764
|
+
)
|
765
|
+
logger.warning(warn_msg)
|
766
|
+
|
692
767
|
config_dict = cls.load_config(cached_folder)
|
693
768
|
|
694
769
|
# pop out "_ignore_files" as it is only needed for download
|
695
770
|
config_dict.pop("_ignore_files", None)
|
696
771
|
|
697
772
|
# 2. Define which model components should load variants
|
698
|
-
# We retrieve the information by matching whether variant
|
699
|
-
#
|
700
|
-
|
701
|
-
|
702
|
-
|
703
|
-
|
704
|
-
|
705
|
-
variant_exists = is_folder and any(
|
706
|
-
p.split(".")[1].startswith(variant) for p in os.listdir(folder_path)
|
707
|
-
)
|
708
|
-
if variant_exists:
|
709
|
-
model_variants[folder] = variant
|
773
|
+
# We retrieve the information by matching whether variant model checkpoints exist in the subfolders.
|
774
|
+
# Example: `diffusion_pytorch_model.safetensors` -> `diffusion_pytorch_model.fp16.safetensors`
|
775
|
+
# with variant being `"fp16"`.
|
776
|
+
model_variants = _identify_model_variants(folder=cached_folder, variant=variant, config=config_dict)
|
777
|
+
if len(model_variants) == 0 and variant is not None:
|
778
|
+
error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
|
779
|
+
raise ValueError(error_message)
|
710
780
|
|
711
781
|
# 3. Load the pipeline class, if using custom module then load it from the hub
|
712
782
|
# if we load from explicit class, let's use it
|
713
|
-
custom_class_name =
|
714
|
-
|
715
|
-
|
716
|
-
elif isinstance(config_dict["_class_name"], (list, tuple)) and os.path.isfile(
|
717
|
-
os.path.join(cached_folder, f"{config_dict['_class_name'][0]}.py")
|
718
|
-
):
|
719
|
-
custom_pipeline = os.path.join(cached_folder, f"{config_dict['_class_name'][0]}.py")
|
720
|
-
custom_class_name = config_dict["_class_name"][1]
|
721
|
-
|
783
|
+
custom_pipeline, custom_class_name = _resolve_custom_pipeline_and_cls(
|
784
|
+
folder=cached_folder, config=config_dict, custom_pipeline=custom_pipeline
|
785
|
+
)
|
722
786
|
pipeline_class = _get_pipeline_class(
|
723
787
|
cls,
|
724
|
-
config_dict,
|
788
|
+
config=config_dict,
|
725
789
|
load_connected_pipeline=load_connected_pipeline,
|
726
790
|
custom_pipeline=custom_pipeline,
|
727
791
|
class_name=custom_class_name,
|
@@ -729,24 +793,17 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
729
793
|
revision=custom_revision,
|
730
794
|
)
|
731
795
|
|
796
|
+
if device_map is not None and pipeline_class._load_connected_pipes:
|
797
|
+
raise NotImplementedError("`device_map` is not yet supported for connected pipelines.")
|
798
|
+
|
732
799
|
# DEPRECATED: To be removed in 1.0.0
|
733
|
-
|
734
|
-
|
735
|
-
|
736
|
-
|
737
|
-
|
738
|
-
|
739
|
-
|
740
|
-
deprecation_message = (
|
741
|
-
"You are using a legacy checkpoint for inpainting with Stable Diffusion, therefore we are loading the"
|
742
|
-
f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}. For"
|
743
|
-
" better inpainting results, we strongly suggest using Stable Diffusion's official inpainting"
|
744
|
-
" checkpoint: https://huggingface.co/runwayml/stable-diffusion-inpainting instead or adapting your"
|
745
|
-
f" checkpoint {pretrained_model_name_or_path} to the format of"
|
746
|
-
" https://huggingface.co/runwayml/stable-diffusion-inpainting. Note that we do not actively maintain"
|
747
|
-
" the {StableDiffusionInpaintPipelineLegacy} class and will likely remove it in version 1.0.0."
|
748
|
-
)
|
749
|
-
deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False)
|
800
|
+
# we are deprecating the `StableDiffusionInpaintPipelineLegacy` pipeline which gets loaded
|
801
|
+
# when a user requests for a `StableDiffusionInpaintPipeline` with `diffusers` version being <= 0.5.1.
|
802
|
+
_maybe_raise_warning_for_inpainting(
|
803
|
+
pipeline_class=pipeline_class,
|
804
|
+
pretrained_model_name_or_path=pretrained_model_name_or_path,
|
805
|
+
config=config_dict,
|
806
|
+
)
|
750
807
|
|
751
808
|
# 4. Define expected modules given pipeline signature
|
752
809
|
# and define non-None initialized modules (=`init_kwargs`)
|
@@ -755,9 +812,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
755
812
|
# in this case they are already instantiated in `kwargs`
|
756
813
|
# extract them here
|
757
814
|
expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
|
815
|
+
expected_types = pipeline_class._get_signature_types()
|
758
816
|
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
|
759
817
|
passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
|
760
|
-
|
761
818
|
init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
|
762
819
|
|
763
820
|
# define init kwargs and make sure that optional component modules are filtered out
|
@@ -778,6 +835,26 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
778
835
|
|
779
836
|
init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
|
780
837
|
|
838
|
+
for key in init_dict.keys():
|
839
|
+
if key not in passed_class_obj:
|
840
|
+
continue
|
841
|
+
if "scheduler" in key:
|
842
|
+
continue
|
843
|
+
|
844
|
+
class_obj = passed_class_obj[key]
|
845
|
+
_expected_class_types = []
|
846
|
+
for expected_type in expected_types[key]:
|
847
|
+
if isinstance(expected_type, enum.EnumMeta):
|
848
|
+
_expected_class_types.extend(expected_type.__members__.keys())
|
849
|
+
else:
|
850
|
+
_expected_class_types.append(expected_type.__name__)
|
851
|
+
|
852
|
+
_is_valid_type = class_obj.__class__.__name__ in _expected_class_types
|
853
|
+
if not _is_valid_type:
|
854
|
+
logger.warning(
|
855
|
+
f"Expected types for {key}: {_expected_class_types}, got {class_obj.__class__.__name__}."
|
856
|
+
)
|
857
|
+
|
781
858
|
# Special case: safety_checker must be loaded separately when using `from_flax`
|
782
859
|
if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj:
|
783
860
|
raise NotImplementedError(
|
@@ -795,17 +872,45 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
795
872
|
# import it here to avoid circular import
|
796
873
|
from diffusers import pipelines
|
797
874
|
|
798
|
-
# 6.
|
875
|
+
# 6. device map delegation
|
876
|
+
final_device_map = None
|
877
|
+
if device_map is not None:
|
878
|
+
final_device_map = _get_final_device_map(
|
879
|
+
device_map=device_map,
|
880
|
+
pipeline_class=pipeline_class,
|
881
|
+
passed_class_obj=passed_class_obj,
|
882
|
+
init_dict=init_dict,
|
883
|
+
library=library,
|
884
|
+
max_memory=max_memory,
|
885
|
+
torch_dtype=torch_dtype,
|
886
|
+
cached_folder=cached_folder,
|
887
|
+
force_download=force_download,
|
888
|
+
proxies=proxies,
|
889
|
+
local_files_only=local_files_only,
|
890
|
+
token=token,
|
891
|
+
revision=revision,
|
892
|
+
)
|
893
|
+
|
894
|
+
# 7. Load each module in the pipeline
|
895
|
+
current_device_map = None
|
799
896
|
for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."):
|
800
|
-
#
|
897
|
+
# 7.1 device_map shenanigans
|
898
|
+
if final_device_map is not None and len(final_device_map) > 0:
|
899
|
+
component_device = final_device_map.get(name, None)
|
900
|
+
if component_device is not None:
|
901
|
+
current_device_map = {"": component_device}
|
902
|
+
else:
|
903
|
+
current_device_map = None
|
904
|
+
|
905
|
+
# 7.2 - now that JAX/Flax is an official framework of the library, we might load from Flax names
|
801
906
|
class_name = class_name[4:] if class_name.startswith("Flax") else class_name
|
802
907
|
|
803
|
-
#
|
908
|
+
# 7.3 Define all importable classes
|
804
909
|
is_pipeline_module = hasattr(pipelines, library_name)
|
805
910
|
importable_classes = ALL_IMPORTABLE_CLASSES
|
806
911
|
loaded_sub_model = None
|
807
912
|
|
808
|
-
#
|
913
|
+
# 7.4 Use passed sub model or load class_name from library_name
|
809
914
|
if name in passed_class_obj:
|
810
915
|
# if the model is in a pipeline module, then we load it from the pipeline
|
811
916
|
# check that passed_class_obj has correct parent class
|
@@ -826,7 +931,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
826
931
|
torch_dtype=torch_dtype,
|
827
932
|
provider=provider,
|
828
933
|
sess_options=sess_options,
|
829
|
-
device_map=
|
934
|
+
device_map=current_device_map,
|
830
935
|
max_memory=max_memory,
|
831
936
|
offload_folder=offload_folder,
|
832
937
|
offload_state_dict=offload_state_dict,
|
@@ -836,6 +941,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
836
941
|
variant=variant,
|
837
942
|
low_cpu_mem_usage=low_cpu_mem_usage,
|
838
943
|
cached_folder=cached_folder,
|
944
|
+
use_safetensors=use_safetensors,
|
839
945
|
)
|
840
946
|
logger.info(
|
841
947
|
f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."
|
@@ -843,57 +949,17 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
843
949
|
|
844
950
|
init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
|
845
951
|
|
952
|
+
# 8. Handle connected pipelines.
|
846
953
|
if pipeline_class._load_connected_pipes and os.path.isfile(os.path.join(cached_folder, "README.md")):
|
847
|
-
|
848
|
-
|
849
|
-
|
850
|
-
|
851
|
-
|
852
|
-
|
853
|
-
|
854
|
-
"local_files_only": local_files_only,
|
855
|
-
"token": token,
|
856
|
-
"revision": revision,
|
857
|
-
"torch_dtype": torch_dtype,
|
858
|
-
"custom_pipeline": custom_pipeline,
|
859
|
-
"custom_revision": custom_revision,
|
860
|
-
"provider": provider,
|
861
|
-
"sess_options": sess_options,
|
862
|
-
"device_map": device_map,
|
863
|
-
"max_memory": max_memory,
|
864
|
-
"offload_folder": offload_folder,
|
865
|
-
"offload_state_dict": offload_state_dict,
|
866
|
-
"low_cpu_mem_usage": low_cpu_mem_usage,
|
867
|
-
"variant": variant,
|
868
|
-
"use_safetensors": use_safetensors,
|
869
|
-
}
|
870
|
-
|
871
|
-
def get_connected_passed_kwargs(prefix):
|
872
|
-
connected_passed_class_obj = {
|
873
|
-
k.replace(f"{prefix}_", ""): w for k, w in passed_class_obj.items() if k.split("_")[0] == prefix
|
874
|
-
}
|
875
|
-
connected_passed_pipe_kwargs = {
|
876
|
-
k.replace(f"{prefix}_", ""): w for k, w in passed_pipe_kwargs.items() if k.split("_")[0] == prefix
|
877
|
-
}
|
878
|
-
|
879
|
-
connected_passed_kwargs = {**connected_passed_class_obj, **connected_passed_pipe_kwargs}
|
880
|
-
return connected_passed_kwargs
|
881
|
-
|
882
|
-
connected_pipes = {
|
883
|
-
prefix: DiffusionPipeline.from_pretrained(
|
884
|
-
repo_id, **load_kwargs.copy(), **get_connected_passed_kwargs(prefix)
|
885
|
-
)
|
886
|
-
for prefix, repo_id in connected_pipes.items()
|
887
|
-
if repo_id is not None
|
888
|
-
}
|
889
|
-
|
890
|
-
for prefix, connected_pipe in connected_pipes.items():
|
891
|
-
# add connected pipes to `init_kwargs` with <prefix>_<component_name>, e.g. "prior_text_encoder"
|
892
|
-
init_kwargs.update(
|
893
|
-
{"_".join([prefix, name]): component for name, component in connected_pipe.components.items()}
|
894
|
-
)
|
954
|
+
init_kwargs = _update_init_kwargs_with_connected_pipeline(
|
955
|
+
init_kwargs=init_kwargs,
|
956
|
+
passed_pipe_kwargs=passed_pipe_kwargs,
|
957
|
+
passed_class_objs=passed_class_obj,
|
958
|
+
folder=cached_folder,
|
959
|
+
**kwargs_copied,
|
960
|
+
)
|
895
961
|
|
896
|
-
#
|
962
|
+
# 9. Potentially add passed objects if expected
|
897
963
|
missing_modules = set(expected_modules) - set(init_kwargs.keys())
|
898
964
|
passed_modules = list(passed_class_obj.keys())
|
899
965
|
optional_modules = pipeline_class._optional_components
|
@@ -906,11 +972,13 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
906
972
|
f"Pipeline {pipeline_class} expected {expected_modules}, but only {passed_modules} were passed."
|
907
973
|
)
|
908
974
|
|
909
|
-
#
|
975
|
+
# 10. Instantiate the pipeline
|
910
976
|
model = pipeline_class(**init_kwargs)
|
911
977
|
|
912
|
-
#
|
978
|
+
# 11. Save where the model was instantiated from
|
913
979
|
model.register_to_config(_name_or_path=pretrained_model_name_or_path)
|
980
|
+
if device_map is not None:
|
981
|
+
setattr(model, "hf_device_map", final_device_map)
|
914
982
|
return model
|
915
983
|
|
916
984
|
@property
|
@@ -939,6 +1007,15 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
939
1007
|
return torch.device(module._hf_hook.execution_device)
|
940
1008
|
return self.device
|
941
1009
|
|
1010
|
+
def remove_all_hooks(self):
|
1011
|
+
r"""
|
1012
|
+
Removes all hooks that were added when using `enable_sequential_cpu_offload` or `enable_model_cpu_offload`.
|
1013
|
+
"""
|
1014
|
+
for _, model in self.components.items():
|
1015
|
+
if isinstance(model, torch.nn.Module) and hasattr(model, "_hf_hook"):
|
1016
|
+
accelerate.hooks.remove_hook_from_module(model, recurse=True)
|
1017
|
+
self._all_hooks = []
|
1018
|
+
|
942
1019
|
def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
|
943
1020
|
r"""
|
944
1021
|
Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
|
@@ -953,6 +1030,12 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
953
1030
|
The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
|
954
1031
|
default to "cuda".
|
955
1032
|
"""
|
1033
|
+
is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
|
1034
|
+
if is_pipeline_device_mapped:
|
1035
|
+
raise ValueError(
|
1036
|
+
"It seems like you have activated a device mapping strategy on the pipeline so calling `enable_model_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_model_cpu_offload()`."
|
1037
|
+
)
|
1038
|
+
|
956
1039
|
if self.model_cpu_offload_seq is None:
|
957
1040
|
raise ValueError(
|
958
1041
|
"Model CPU offload cannot be enabled because no `model_cpu_offload_seq` class attribute is set."
|
@@ -963,6 +1046,8 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
963
1046
|
else:
|
964
1047
|
raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
|
965
1048
|
|
1049
|
+
self.remove_all_hooks()
|
1050
|
+
|
966
1051
|
torch_device = torch.device(device)
|
967
1052
|
device_index = torch_device.index
|
968
1053
|
|
@@ -979,11 +1064,10 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
979
1064
|
device = torch.device(f"{device_type}:{self._offload_gpu_id}")
|
980
1065
|
self._offload_device = device
|
981
1066
|
|
982
|
-
|
983
|
-
|
984
|
-
|
985
|
-
|
986
|
-
device_mod.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
1067
|
+
self.to("cpu", silence_dtype_warnings=True)
|
1068
|
+
device_mod = getattr(torch, device.type, None)
|
1069
|
+
if hasattr(device_mod, "empty_cache") and device_mod.is_available():
|
1070
|
+
device_mod.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
|
987
1071
|
|
988
1072
|
all_model_components = {k: v for k, v in self.components.items() if isinstance(v, torch.nn.Module)}
|
989
1073
|
|
@@ -991,9 +1075,18 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
991
1075
|
hook = None
|
992
1076
|
for model_str in self.model_cpu_offload_seq.split("->"):
|
993
1077
|
model = all_model_components.pop(model_str, None)
|
1078
|
+
|
994
1079
|
if not isinstance(model, torch.nn.Module):
|
995
1080
|
continue
|
996
1081
|
|
1082
|
+
# This is because the model would already be placed on a CUDA device.
|
1083
|
+
_, _, is_loaded_in_8bit_bnb = _check_bnb_status(model)
|
1084
|
+
if is_loaded_in_8bit_bnb:
|
1085
|
+
logger.info(
|
1086
|
+
f"Skipping the hook placement for the {model.__class__.__name__} as it is loaded in `bitsandbytes` 8bit."
|
1087
|
+
)
|
1088
|
+
continue
|
1089
|
+
|
997
1090
|
_, hook = cpu_offload_with_hook(model, device, prev_module_hook=hook)
|
998
1091
|
self._all_hooks.append(hook)
|
999
1092
|
|
@@ -1021,11 +1114,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
1021
1114
|
# `enable_model_cpu_offload` has not be called, so silently do nothing
|
1022
1115
|
return
|
1023
1116
|
|
1024
|
-
for hook in self._all_hooks:
|
1025
|
-
# offload model and remove hook from model
|
1026
|
-
hook.offload()
|
1027
|
-
hook.remove()
|
1028
|
-
|
1029
1117
|
# make sure the model is in the same state as before calling it
|
1030
1118
|
self.enable_model_cpu_offload(device=getattr(self, "_offload_device", "cuda"))
|
1031
1119
|
|
@@ -1048,6 +1136,13 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
1048
1136
|
from accelerate import cpu_offload
|
1049
1137
|
else:
|
1050
1138
|
raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
|
1139
|
+
self.remove_all_hooks()
|
1140
|
+
|
1141
|
+
is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
|
1142
|
+
if is_pipeline_device_mapped:
|
1143
|
+
raise ValueError(
|
1144
|
+
"It seems like you have activated a device mapping strategy on the pipeline so calling `enable_sequential_cpu_offload() isn't allowed. You can call `reset_device_map()` first and then call `enable_sequential_cpu_offload()`."
|
1145
|
+
)
|
1051
1146
|
|
1052
1147
|
torch_device = torch.device(device)
|
1053
1148
|
device_index = torch_device.index
|
@@ -1083,6 +1178,19 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
1083
1178
|
offload_buffers = len(model._parameters) > 0
|
1084
1179
|
cpu_offload(model, device, offload_buffers=offload_buffers)
|
1085
1180
|
|
1181
|
+
def reset_device_map(self):
|
1182
|
+
r"""
|
1183
|
+
Resets the device maps (if any) to None.
|
1184
|
+
"""
|
1185
|
+
if self.hf_device_map is None:
|
1186
|
+
return
|
1187
|
+
else:
|
1188
|
+
self.remove_all_hooks()
|
1189
|
+
for name, component in self.components.items():
|
1190
|
+
if isinstance(component, torch.nn.Module):
|
1191
|
+
component.to("cpu")
|
1192
|
+
self.hf_device_map = None
|
1193
|
+
|
1086
1194
|
@classmethod
|
1087
1195
|
@validate_hf_hub_args
|
1088
1196
|
def download(cls, pretrained_model_name, **kwargs) -> Union[str, os.PathLike]:
|
@@ -1121,9 +1229,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
1121
1229
|
force_download (`bool`, *optional*, defaults to `False`):
|
1122
1230
|
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
1123
1231
|
cached versions if they exist.
|
1124
|
-
|
1125
|
-
Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
|
1126
|
-
incompletely downloaded files are deleted.
|
1232
|
+
|
1127
1233
|
proxies (`Dict[str, str]`, *optional*):
|
1128
1234
|
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
1129
1235
|
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
@@ -1176,7 +1282,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
1176
1282
|
|
1177
1283
|
"""
|
1178
1284
|
cache_dir = kwargs.pop("cache_dir", None)
|
1179
|
-
resume_download = kwargs.pop("resume_download", False)
|
1180
1285
|
force_download = kwargs.pop("force_download", False)
|
1181
1286
|
proxies = kwargs.pop("proxies", None)
|
1182
1287
|
local_files_only = kwargs.pop("local_files_only", None)
|
@@ -1209,6 +1314,22 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
1209
1314
|
model_info_call_error = e # save error to reraise it if model is not cached locally
|
1210
1315
|
|
1211
1316
|
if not local_files_only:
|
1317
|
+
filenames = {sibling.rfilename for sibling in info.siblings}
|
1318
|
+
if variant is not None and _check_legacy_sharding_variant_format(filenames=filenames, variant=variant):
|
1319
|
+
warn_msg = (
|
1320
|
+
f"Warning: The repository contains sharded checkpoints for variant '{variant}' maybe in a deprecated format. "
|
1321
|
+
"Please check your files carefully:\n\n"
|
1322
|
+
"- Correct format example: diffusion_pytorch_model.fp16-00003-of-00003.safetensors\n"
|
1323
|
+
"- Deprecated format example: diffusion_pytorch_model-00001-of-00002.fp16.safetensors\n\n"
|
1324
|
+
"If you find any files in the deprecated format:\n"
|
1325
|
+
"1. Remove all existing checkpoint files for this variant.\n"
|
1326
|
+
"2. Re-obtain the correct files by running `save_pretrained()`.\n\n"
|
1327
|
+
"This will ensure you're using the most up-to-date and compatible checkpoint format."
|
1328
|
+
)
|
1329
|
+
logger.warning(warn_msg)
|
1330
|
+
|
1331
|
+
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
|
1332
|
+
|
1212
1333
|
config_file = hf_hub_download(
|
1213
1334
|
pretrained_model_name,
|
1214
1335
|
cls.config_name,
|
@@ -1216,59 +1337,24 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
1216
1337
|
revision=revision,
|
1217
1338
|
proxies=proxies,
|
1218
1339
|
force_download=force_download,
|
1219
|
-
resume_download=resume_download,
|
1220
1340
|
token=token,
|
1221
1341
|
)
|
1222
1342
|
|
1223
1343
|
config_dict = cls._dict_from_json_file(config_file)
|
1224
1344
|
ignore_filenames = config_dict.pop("_ignore_files", [])
|
1225
1345
|
|
1226
|
-
# retrieve all folder_names that contain relevant files
|
1227
|
-
folder_names = [k for k, v in config_dict.items() if isinstance(v, list) and k != "_class_name"]
|
1228
|
-
|
1229
|
-
filenames = {sibling.rfilename for sibling in info.siblings}
|
1230
|
-
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
|
1231
|
-
|
1232
|
-
diffusers_module = importlib.import_module(__name__.split(".")[0])
|
1233
|
-
pipelines = getattr(diffusers_module, "pipelines")
|
1234
|
-
|
1235
|
-
# optionally create a custom component <> custom file mapping
|
1236
|
-
custom_components = {}
|
1237
|
-
for component in folder_names:
|
1238
|
-
module_candidate = config_dict[component][0]
|
1239
|
-
|
1240
|
-
if module_candidate is None or not isinstance(module_candidate, str):
|
1241
|
-
continue
|
1242
|
-
|
1243
|
-
# We compute candidate file path on the Hub. Do not use `os.path.join`.
|
1244
|
-
candidate_file = f"{component}/{module_candidate}.py"
|
1245
|
-
|
1246
|
-
if candidate_file in filenames:
|
1247
|
-
custom_components[component] = module_candidate
|
1248
|
-
elif module_candidate not in LOADABLE_CLASSES and not hasattr(pipelines, module_candidate):
|
1249
|
-
raise ValueError(
|
1250
|
-
f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'."
|
1251
|
-
)
|
1252
|
-
|
1253
|
-
if len(variant_filenames) == 0 and variant is not None:
|
1254
|
-
deprecation_message = (
|
1255
|
-
f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
|
1256
|
-
f"The default model files: {model_filenames} will be loaded instead. Make sure to not load from `variant={variant}`"
|
1257
|
-
"if such variant modeling files are not available. Doing so will lead to an error in v0.24.0 as defaulting to non-variant"
|
1258
|
-
"modeling files is deprecated."
|
1259
|
-
)
|
1260
|
-
deprecate("no variant default", "0.24.0", deprecation_message, standard_warn=False)
|
1261
|
-
|
1262
1346
|
# remove ignored filenames
|
1263
1347
|
model_filenames = set(model_filenames) - set(ignore_filenames)
|
1264
1348
|
variant_filenames = set(variant_filenames) - set(ignore_filenames)
|
1265
1349
|
|
1266
|
-
# if the whole pipeline is cached we don't have to ping the Hub
|
1267
1350
|
if revision in DEPRECATED_REVISION_ARGS and version.parse(
|
1268
1351
|
version.parse(__version__).base_version
|
1269
1352
|
) >= version.parse("0.22.0"):
|
1270
1353
|
warn_deprecated_model_variant(pretrained_model_name, token, variant, revision, model_filenames)
|
1271
1354
|
|
1355
|
+
custom_components, folder_names = _get_custom_components_and_folders(
|
1356
|
+
pretrained_model_name, config_dict, filenames, variant_filenames, variant
|
1357
|
+
)
|
1272
1358
|
model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names}
|
1273
1359
|
|
1274
1360
|
custom_class_name = None
|
@@ -1328,49 +1414,19 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
1328
1414
|
expected_components, _ = cls._get_signature_keys(pipeline_class)
|
1329
1415
|
passed_components = [k for k in expected_components if k in kwargs]
|
1330
1416
|
|
1331
|
-
|
1332
|
-
|
1333
|
-
|
1334
|
-
|
1335
|
-
|
1336
|
-
|
1337
|
-
|
1338
|
-
|
1339
|
-
|
1340
|
-
|
1341
|
-
|
1342
|
-
|
1343
|
-
|
1344
|
-
model_filenames, variant=variant, passed_components=passed_components
|
1345
|
-
):
|
1346
|
-
ignore_patterns = ["*.bin", "*.msgpack"]
|
1347
|
-
|
1348
|
-
use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx
|
1349
|
-
if not use_onnx:
|
1350
|
-
ignore_patterns += ["*.onnx", "*.pb"]
|
1351
|
-
|
1352
|
-
safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
|
1353
|
-
safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")}
|
1354
|
-
if (
|
1355
|
-
len(safetensors_variant_filenames) > 0
|
1356
|
-
and safetensors_model_filenames != safetensors_variant_filenames
|
1357
|
-
):
|
1358
|
-
logger.warning(
|
1359
|
-
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
|
1360
|
-
)
|
1361
|
-
else:
|
1362
|
-
ignore_patterns = ["*.safetensors", "*.msgpack"]
|
1363
|
-
|
1364
|
-
use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx
|
1365
|
-
if not use_onnx:
|
1366
|
-
ignore_patterns += ["*.onnx", "*.pb"]
|
1367
|
-
|
1368
|
-
bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")}
|
1369
|
-
bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")}
|
1370
|
-
if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
|
1371
|
-
logger.warning(
|
1372
|
-
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
|
1373
|
-
)
|
1417
|
+
# retrieve all patterns that should not be downloaded and error out when needed
|
1418
|
+
ignore_patterns = _get_ignore_patterns(
|
1419
|
+
passed_components,
|
1420
|
+
model_folder_names,
|
1421
|
+
model_filenames,
|
1422
|
+
variant_filenames,
|
1423
|
+
use_safetensors,
|
1424
|
+
from_flax,
|
1425
|
+
allow_pickle,
|
1426
|
+
use_onnx,
|
1427
|
+
pipeline_class._is_onnx,
|
1428
|
+
variant,
|
1429
|
+
)
|
1374
1430
|
|
1375
1431
|
# Don't download any objects that are passed
|
1376
1432
|
allow_patterns = [
|
@@ -1382,7 +1438,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
1382
1438
|
|
1383
1439
|
# Don't download index files of forbidden patterns either
|
1384
1440
|
ignore_patterns = ignore_patterns + [f"{i}.index.*json" for i in ignore_patterns]
|
1385
|
-
|
1386
1441
|
re_ignore_pattern = [re.compile(fnmatch.translate(p)) for p in ignore_patterns]
|
1387
1442
|
re_allow_pattern = [re.compile(fnmatch.translate(p)) for p in allow_patterns]
|
1388
1443
|
|
@@ -1406,7 +1461,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
1406
1461
|
cached_folder = snapshot_download(
|
1407
1462
|
pretrained_model_name,
|
1408
1463
|
cache_dir=cache_dir,
|
1409
|
-
resume_download=resume_download,
|
1410
1464
|
proxies=proxies,
|
1411
1465
|
local_files_only=local_files_only,
|
1412
1466
|
token=token,
|
@@ -1429,7 +1483,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
1429
1483
|
for connected_pipe_repo_id in connected_pipes:
|
1430
1484
|
download_kwargs = {
|
1431
1485
|
"cache_dir": cache_dir,
|
1432
|
-
"resume_download": resume_download,
|
1433
1486
|
"force_download": force_download,
|
1434
1487
|
"proxies": proxies,
|
1435
1488
|
"local_files_only": local_files_only,
|
@@ -1472,6 +1525,18 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
1472
1525
|
|
1473
1526
|
return expected_modules, optional_parameters
|
1474
1527
|
|
1528
|
+
@classmethod
def _get_signature_types(cls):
    r"""
    Map every parameter of `cls.__init__` to the tuple of types accepted by its annotation.

    Returns:
        `Dict[str, tuple]`:
            One entry per parameter whose annotation is a class (stored as a 1-tuple) or a union (stored as the
            tuple of union members; both `typing.Union[A, B]` and PEP 604 `A | B` are recognised). Parameters
            without an annotation, including `self`, carry `inspect.Parameter.empty`, which is itself a class,
            so they are stored as `(inspect.Parameter.empty,)`. Any other annotation is reported with a warning
            and gets no entry.
    """
    import types  # local import so the block does not depend on a file-level import

    # `types.UnionType` is the origin of `A | B` annotations and only exists on Python >= 3.10.
    union_origins = (Union, getattr(types, "UnionType", Union))

    signature_types = {}
    for k, v in inspect.signature(cls.__init__).parameters.items():
        annotation = v.annotation
        if inspect.isclass(annotation):
            signature_types[k] = (annotation,)
        elif get_origin(annotation) in union_origins:
            signature_types[k] = get_args(annotation)
        else:
            logger.warning(f"cannot get type annotation for Parameter {k} of {cls}.")
    return signature_types
|
1539
|
+
|
1475
1540
|
@property
|
1476
1541
|
def components(self) -> Dict[str, Any]:
|
1477
1542
|
r"""
|
@@ -1515,6 +1580,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
1515
1580
|
"""
|
1516
1581
|
return numpy_to_pil(images)
|
1517
1582
|
|
1583
|
+
@torch.compiler.disable
|
1518
1584
|
def progress_bar(self, iterable=None, total=None):
|
1519
1585
|
if not hasattr(self, "_progress_bar_config"):
|
1520
1586
|
self._progress_bar_config = {}
|
@@ -1650,6 +1716,129 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
1650
1716
|
for module in modules:
|
1651
1717
|
module.set_attention_slice(slice_size)
|
1652
1718
|
|
1719
|
+
@classmethod
def from_pipe(cls, pipeline, **kwargs):
    r"""
    Create a new pipeline from a given pipeline. This method is useful to create a new pipeline from the existing
    pipeline components without reallocating additional memory.

    Arguments:
        pipeline (`DiffusionPipeline`):
            The pipeline from which to create a new pipeline.
        torch_dtype (*optional*):
            If not `None`, the new pipeline is converted with `new_pipeline.to(dtype=torch_dtype)`.
        custom_pipeline (*optional*):
            If not `None`, the class to instantiate is resolved with `_get_custom_pipeline_class` instead of
            using `cls`.
        custom_revision (*optional*):
            Forwarded as `revision` to `_get_custom_pipeline_class`.
        kwargs (remaining keyword arguments, *optional*):
            Keys naming an expected component override the component of `pipeline`; keys naming an optional
            `__init__` argument override the matching entry of `pipeline.config`. Everything else is passed to
            the new pipeline's `__init__` unchanged.

    Returns:
        `DiffusionPipeline`:
            A new pipeline with the same weights and configurations as `pipeline`.

    Raises:
        `ValueError`: If a required component of the new pipeline is neither passed nor taken over from
        `pipeline`.

    Examples:

    ```py
    >>> from diffusers import StableDiffusionPipeline, StableDiffusionSAGPipeline

    >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5")
    >>> new_pipe = StableDiffusionSAGPipeline.from_pipe(pipe)
    ```
    """

    original_config = dict(pipeline.config)
    torch_dtype = kwargs.pop("torch_dtype", None)

    # derive the pipeline class to instantiate
    custom_pipeline = kwargs.pop("custom_pipeline", None)
    custom_revision = kwargs.pop("custom_revision", None)

    if custom_pipeline is not None:
        pipeline_class = _get_custom_pipeline_class(custom_pipeline, revision=custom_revision)
    else:
        pipeline_class = cls

    expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
    # true_optional_modules are optional components with default value in signature so it is ok not to pass them to `__init__`
    # e.g. `image_encoder` for StableDiffusionPipeline
    # The defaults are read from the class that is actually instantiated, which differs from `cls` when
    # `custom_pipeline` is given.
    parameters = inspect.signature(pipeline_class.__init__).parameters
    true_optional_modules = {
        k for k, v in parameters.items() if v.default is not inspect.Parameter.empty and k in expected_modules
    }

    # get the class of each component based on its type hint
    # e.g. {"unet": UNet2DConditionModel, "text_encoder": CLIPTextModel}
    component_types = pipeline_class._get_signature_types()

    pretrained_model_name_or_path = original_config.pop("_name_or_path", None)
    # allow users pass modules in `kwargs` to override the original pipeline's components
    passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}

    original_class_obj = {}
    for name, component in pipeline.components.items():
        if name not in expected_modules or name in passed_class_obj:
            continue

        # A parameter whose type hint could not be resolved has no entry; treat it as "no accepted type"
        # instead of failing with a KeyError.
        accepted_types = component_types.get(name, ())

        # for model components, we will not switch over if the class does not match the type hint in the new
        # pipeline's signature. Non-model components (including `None`) are always carried over.
        if not isinstance(component, ModelMixin) or type(component) in accepted_types:
            original_class_obj[name] = component
        else:
            logger.warning(
                f"component {name} is not switched over to new pipeline because type does not match the expected."
                f" {name} is {type(component)} while the new pipeline expect {accepted_types}."
                f" please pass the component of the correct type to the new pipeline. `from_pipe(..., {name}={name})`"
            )

    # allow users pass optional kwargs to override the original pipelines config attribute
    passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
    original_pipe_kwargs = {
        k: original_config[k]
        for k in original_config.keys()
        if k in optional_kwargs and k not in passed_pipe_kwargs
    }

    # config attribute that were not expected by pipeline is stored as its private attribute
    # (i.e. when the original pipeline was also instantiated with `from_pipe` from another pipeline that has this config)
    # in this case, we will pass them as optional arguments if they can be accepted by the new pipeline
    additional_pipe_kwargs = [
        k[1:]
        for k in original_config.keys()
        if k.startswith("_") and k[1:] in optional_kwargs and k[1:] not in passed_pipe_kwargs
    ]
    for k in additional_pipe_kwargs:
        original_pipe_kwargs[k] = original_config.pop(f"_{k}")

    pipeline_kwargs = {
        **passed_class_obj,
        **original_class_obj,
        **passed_pipe_kwargs,
        **original_pipe_kwargs,
        **kwargs,
    }

    # store unused config as private attribute in the new pipeline
    unused_original_config = {
        f"{'' if k.startswith('_') else '_'}{k}": v for k, v in original_config.items() if k not in pipeline_kwargs
    }

    missing_modules = (
        set(expected_modules)
        - set(pipeline._optional_components)
        - set(pipeline_kwargs.keys())
        - set(true_optional_modules)
    )

    if len(missing_modules) > 0:
        raise ValueError(
            f"Pipeline {pipeline_class} expected {expected_modules}, but only {set(list(passed_class_obj.keys()) + list(original_class_obj.keys()))} were passed"
        )

    new_pipeline = pipeline_class(**pipeline_kwargs)
    if pretrained_model_name_or_path is not None:
        new_pipeline.register_to_config(_name_or_path=pretrained_model_name_or_path)
    new_pipeline.register_to_config(**unused_original_config)

    if torch_dtype is not None:
        new_pipeline.to(dtype=torch_dtype)

    return new_pipeline
|
1841
|
+
|
1653
1842
|
|
1654
1843
|
class StableDiffusionMixin:
|
1655
1844
|
r"""
|
@@ -1713,8 +1902,8 @@ class StableDiffusionMixin:
|
|
1713
1902
|
|
1714
1903
|
def fuse_qkv_projections(self, unet: bool = True, vae: bool = True):
|
1715
1904
|
"""
|
1716
|
-
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query,
|
1717
|
-
|
1905
|
+
Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
|
1906
|
+
are fused. For cross-attention modules, key and value projection matrices are fused.
|
1718
1907
|
|
1719
1908
|
<Tip warning={true}>
|
1720
1909
|
|